我按照以下方式構建了我的數據集:dataset/train/0/456.jpgdataset/train/1/456456.jpgdataset/train/2/456.jpgdataset/train/...dataset/val/0/878.jpgdataset/val/1/234.jpgdataset/val/2/34554.jpgdataset/val/...所以我曾經torchvision.datasets.ImageFolder將我的數據集導入 PyTorch。然而,它似乎沒有給正確的圖像貼上正確的標簽。我在下面添加了我的代碼:data_transforms = { 'train': transforms.Compose( [transforms.Resize((176,176)), transforms.RandomRotation((0,360)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.CenterCrop(128), transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]), 'val': transforms.Compose( [transforms.Resize((128,128)), transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]),}data_dir = 'dataset'image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")我發現標簽是錯誤的,使用以下函數:def imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()dataiter = iter(dataloaders['val'])images, labels = dataiter.next()imshow(torchvision.utils.make_grid(images))print(labels)使用顯示的圖像和標簽,我手動檢查它們是否正確。不幸的是,標簽與圖像不對應。有人能告訴我我做錯了什么嗎?
2 回答

當年話下
TA貢獻1890條經驗 獲得超9個贊
有人幫我解決了這個問題。ImageFolder 創建自己的內部標簽。通過打印,image_datasets['train'].class_to_idx
您可以看到哪個標簽與哪個內部標簽配對。使用這本詞典,您可以追溯原始標簽。
添加回答
舉報
0/150
提交
取消