由于我下载的 ImageNet 2012 验证集中,所有图片都放在同一个文件夹里,所有标签都写在一个 txt 文件中,因此我使用自定义的 Dataset 配合 DataLoader 进行读取。
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils import data
from torchvision import datasets, transforms
# Preprocessing pipeline shared by the dataset below: resize every image to
# 224x224, center-crop to 224, then convert the PIL image to a tensor.
# NOTE(review): after Resize((224, 224)) the CenterCrop(224) step cannot crop
# anything. The common ImageNet evaluation recipe is Resize(256) followed by
# CenterCrop(224); confirm which was intended.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)
class MyDataSet(data.Dataset):
    """Dataset over a flat image folder whose labels live in one text file.

    Each non-empty line of ``label_file`` holds an image path (joined onto
    ``root``) followed by an integer class label, separated by whitespace.
    Only the first two fields of a line are used.

    Args:
        root: Directory prepended to every image path read from the label file.
        target_transform: Optional callable applied to each integer label in
            ``__getitem__``.
        label_file: Path of the label text file. The default is the path this
            script originally hard-coded.
        mode: PIL mode every image is converted to before the transform.
            - ``'L'`` (default) gives 1-channel grayscale tensors.
            - ``'RGB'`` keeps colour and gives 3-channel tensors.
            The validation images mix 3-channel and 1-channel files, and the
            DataLoader's default collate can only stack equal-sized tensors.
            Converting every image to one common mode is what avoids the
            "stack expects each tensor to be equal size" error.
        verbose: If true (default), print every parsed path/label pair.
    """

    def __init__(self, root, target_transform=None, *,
                 label_file='imagenet/caffe_ilsvrc12/val.txt',
                 mode='L', verbose=True):
        imgs = []
        # Use a context manager so the label file is always closed.
        with open(label_file, 'r', encoding='utf-8') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing on words[1].
                    continue
                img_path = os.path.join(root, words[0])
                if verbose:
                    print('img path:', img_path, 'label:', words[1])
                imgs.append((img_path, int(words[1])))
        self.imgs = imgs
        # `transform` is the module-level preprocessing pipeline.
        self.transforms = transform
        self.target_transform = target_transform
        self.mode = mode

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        img_path, label = self.imgs[index]
        # convert() loads the pixels into a new image, so the file handle can
        # be closed immediately. This avoids leaking descriptors across many reads.
        with Image.open(img_path) as img:
            pil_img = img.convert(self.mode)
        if self.transforms:
            sample = self.transforms(pil_img)
        else:
            sample = torch.from_numpy(np.asarray(pil_img))
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of (image, label) pairs read from the label file."""
        return len(self.imgs)

# 自定义的MyDataSet类继承于torch.utils.data.Dataset类。由于图片本身一部分是三通道的,一部分却是单通道的。因此如果不在读取的时候统一读入灰度图,就会报一个错误:
RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [1, 224, 224] at entry 25
我原本是想读入彩色图,以便在下面直接进行展示。但这个错误我不知道如何解决,因此读取时统一使用 `pil_img = Image.open(img_path).convert('L')` 读入单通道图片,后面展示出来的也就是灰度图片。
# Number of images per batch. BATCH_SIZE was used but never defined in the
# original; 64 matches the sample output below.
BATCH_SIZE = 64

# These are the ImageNet-2012 *validation* images. The variable name
# `train_dataset` is kept from the original post.
train_dataset = MyDataSet('imagenet/val')
print(len(train_dataset))
valid_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Take a single batch (same effect as the original for-loop with `break`) and
# show its first image.
valid_image, valid_label = next(iter(valid_loader))
print('valid_label:', valid_label)
print('valid_image shape', valid_image.shape)
print(valid_image[0].shape)
# Each image has shape (1, 224, 224). squeeze() drops the channel axis so
# imshow receives a 2-D grayscale array.
plt.imshow(valid_image[0].squeeze(), cmap='gray')
plt.show()

# Output:
# valid_label: tensor([658, 283, 202, 619, 32, 758, 646, 690, 100, 546, 942, 728, 343, 969, 80, 530, 296, 412, 163, 128, 858, 702, 507, 500, 303, 478, 342, 10, 524, 703, 277, 777, 600, 806, 768, 353, 718, 981, 598, 519, 413, 817, 774, 302, 263, 366, 31, 600, 48, 986, 98, 602, 409, 39, 894, 747, 200, 384, 140, 386, 191, 952, 128, 990])
# valid_image shape torch.Size([64, 1, 224, 224])
# torch.Size([1, 224, 224])
點擊查看更多內容
為 TA 點贊
評論
評論
共同學習,寫下你的評論
評論加載中...
作者其他優質文章
正在加載中
感謝您的支持,我會繼續努力的~
掃碼打賞,你說多少就多少
贊賞金額會直接到老師賬戶
支付方式
打開微信掃一掃,即可進行掃碼打賞哦
