亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

訓練深度學習模型時出錯

訓練深度學習模型時出錯

HUX布斯 2023-03-22 10:54:01
所以我設計了一個 CNN 并使用以下參數進行編譯,training_file_loc = "8-SignLanguageMNIST/sign_mnist_train.csv"testing_file_loc = "8-SignLanguageMNIST/sign_mnist_test.csv"def getData(filename):    images = []    labels = []    with open(filename) as csv_file:        file = csv.reader(csv_file, delimiter = ",")        next(file, None)                for row in file:            label = row[0]            data = row[1:]            img = np.array(data).reshape(28,28)                        images.append(img)            labels.append(label)                images = np.array(images).astype("float64")        labels = np.array(labels).astype("float64")            return images, labelstraining_images, training_labels = getData(training_file_loc)testing_images, testing_labels = getData(testing_file_loc)print(training_images.shape, training_labels.shape)print(testing_images.shape, testing_labels.shape)training_images = np.expand_dims(training_images, axis = 3)testing_images = np.expand_dims(testing_images, axis = 3)training_datagen = ImageDataGenerator(    rescale = 1/255,    rotation_range = 45,    width_shift_range = 0.2,    height_shift_range = 0.2,    shear_range = 0.2,    zoom_range = 0.2,    horizontal_flip = True,    fill_mode = "nearest")training_generator = training_datagen.flow(    training_images,    training_labels,    batch_size = 64,)validation_datagen = ImageDataGenerator(    rescale = 1/255,    rotation_range = 45,    width_shift_range = 0.2,    height_shift_range = 0.2,    shear_range = 0.2,    zoom_range = 0.2,    horizontal_flip = True,    fill_mode = "nearest")validation_generator = training_datagen.flow(    testing_images,    testing_labels,    batch_size = 64,])但是,當我運行 model.fit() 時,出現以下錯誤,ValueError: Shapes (None, 1) and (None, 24) are incompatible將損失函數更改為 后sparse_categorical_crossentropy,程序運行良好。我不明白為什么會這樣。誰能解釋這一點以及這些損失函數之間的區別?
查看完整描述

2 回答

?
largeQ

TA貢獻2039條經驗 獲得超8個贊

問題是,categorical_crossentropy期望單熱編碼標簽,這意味著,對于每個樣本,它期望一個長度張量,num_classes其中label第 th 個元素設置為 1,其他所有元素都為 0。


另一方面,sparse_categorical_crossentropy直接使用整數標簽(因為這里的用例是大量的類,所以單熱編碼標簽會浪費大量零的內存)。我相信,但我無法證實這一點,它categorical_crossentropy比它的稀疏對應物運行得更快。


對于您的情況,對于 26 個類,我建議使用非稀疏版本并將您的標簽轉換為單熱編碼,如下所示:


def getData(filename):

    images = []

    labels = []

    with open(filename) as csv_file:

        file = csv.reader(csv_file, delimiter = ",")

        next(file, None)

        

        for row in file:

            label = row[0]

            data = row[1:]

            img = np.array(data).reshape(28,28)

            

            images.append(img)

            labels.append(label)

        

        images = np.array(images).astype("float64")

        labels = np.array(labels).astype("float64")

        

    return images, tf.keras.utils.to_categorical(labels, num_classes=26) # you can omit num_classes to have it computed from the data

旁注:除非你有理由使用float64圖像,否則我會切換到float32(它將數據集所需的內存減半,并且模型可能會將它們轉換為float32第一個操作)


查看完整回答
反對 回復 2023-03-22
?
BIG陽

TA貢獻1859條經驗 獲得超6個贊

很簡單,對于輸出類為整數的分類問題,使用 sparse_categorical_crosentropy,對于標簽在一個熱編碼標簽中轉換的問題,我們使用 categorical_crosentropy。



查看完整回答
反對 回復 2023-03-22
  • 2 回答
  • 0 關注
  • 173 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號