我們在使用pytorch訓(xùn)練的時候一般會把數(shù)據(jù)集放到dataloader里。但在訓(xùn)練前我們也需要看一下訓(xùn)練數(shù)據(jù)長啥樣(檢驗數(shù)據(jù)集是否有問題),這就需要訓(xùn)練數(shù)據(jù)集可視化了。在訓(xùn)練數(shù)據(jù)集中的圖像一般都是帶batch的tensor類型的圖像,那么pytorch中帶batch的tensor類型圖像如何顯示呢?看完這篇文章你將得到答案。
顯示圖像
繪圖最常用的庫就是matplotlib:
pip install matplotlib
顯示圖像會用到matplotlib.pyplot.imshow方法。查閱官方文檔可知,該方法接收的圖像的通道數(shù)要放到后面:
數(shù)據(jù)加載器中數(shù)據(jù)的維度是[B, C, H, W],我們每次只拿一個數(shù)據(jù)出來就是[C, H, W],而matplotlib.pyplot.imshow要求的輸入維度是[H, W, C],所以我們需要交換一下數(shù)據(jù)維度,把通道數(shù)放到最后面,這里用到pytorch里面的permute方法(transpose方法也行,不過要交換兩次,沒這個方便,numpy中的transpose方法倒是可以一次交換完成)
用法示例如下:
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])
代碼示例
#%% 導(dǎo)入模塊
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下載數(shù)據(jù)集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作數(shù)據(jù)加載器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 訓(xùn)練數(shù)據(jù)可視化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
plt.subplot(3, 3, i+1)
plt.title(labels[i].item())
plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.show()
這里以mnist數(shù)據(jù)集為例,演示一下顯示效果。我這個代碼其實還有一點小問題。數(shù)據(jù)增強的時候我不是進(jìn)行標(biāo)準(zhǔn)化了嘛,就是在第7行代碼:Normalize((0.1307,), (0.3081,))。
所以,如果你想查看訓(xùn)練集的原始圖像,還得反標(biāo)準(zhǔn)化。
標(biāo)準(zhǔn)化:image = (image-mean)/std
反標(biāo)準(zhǔn)化:image = image*std+mean
我拿imagenet中的一個螞蟻和蜜蜂的子集做了一下實驗,標(biāo)準(zhǔn)化前后的區(qū)別還是很明顯的:
最終效果
補充:PIL,plt顯示tensor類型的圖像
該方法針對顯示Dataloader讀取的圖像
PIL 與plt中對應(yīng)操作不同,但原理是一樣的,我試過用下方代碼Image的方法在plt上show失敗了,原因暫且不知。
# 方法1:Image.show()
# transforms.ToPILImage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一維
img = transforms.ToPILImage(image[0])
img.show()
# 方法2:plt.imshow(ndarray)
img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一維
img = img.numpy() # FloatTensor轉(zhuǎn)為ndarray
img = np.transpose(img, (1,2,0)) # 把channel那一維放到最后
# 顯示圖片
plt.imshow(img)
plt.show()
cnt += 1
以上就是pytorch中帶batch的tensor類型圖像如何顯示的方法了,希望能給大家一個參考,也希望大家多多支持W3Cschool。