亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

訓練期間每類驗證的準確性

訓練期間每類驗證的準確性

躍然一笑 2023-05-23 10:15:08
Keras 在訓練時給出了整體training和validation準確率。有沒有辦法在培訓期間獲得a per-class validation accuracy?更新:來自 Pycharm 的錯誤日志File "C:/Users/wj96hq/PycharmProjects/PedestrianClassification/Awareness.py", line 82, in <module>shuffle=True, callbacks=callbacks)File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapperreturn method(self, *args, **kwargs)File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py", line 876, in fitcallbacks.on_epoch_end(epoch, epoch_logs)File "C:\Users\wj96hq\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\callbacks.py", line 365, in on_epoch_endcallback.on_epoch_end(epoch, logs)File "C:/Users/wj96hq/PycharmProjects/PedestrianClassification/Awareness.py", line 36, in on_epoch_endx_test, y_test = self.validation_data[0], self.validation_data[1]TypeError: 'NoneType' object is not subscriptable
查看完整描述

3 回答

?
慕斯709654

TA貢獻1840條經驗 獲得超5個贊

使用它來獲得每類準確性:



model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])



class Metrics(keras.callbacks.Callback):

    def on_train_begin(self, logs={}):

        self._data = []


    def on_epoch_end(self, batch, logs={}):

        x_test, y_test = self.validation_data[0], self.validation_data[1]

        y_predict = np.asarray(model.predict(x_test))


        true = np.argmax(y_test, axis=1)

        pred = np.argmax(y_predict, axis=1)

        

        cm = confusion_matrix(true, pred)

        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        self._data.append({

            'classLevelaccuracy':cm.diagonal() ,

        })

        return


    def get_data(self):

        return self._data


metrics = Metrics()

history = model.fit(x_train, y_train, epochs=100, validation_data=(x_test, y_test), callbacks=[metrics])

metrics.get_data()

您可以在指標類中更改代碼。隨心所欲..并且這個工作。你只是用來metrics.get_data()獲取所有信息..


查看完整回答
反對 回復 2023-05-23
?
猛跑小豬

TA貢獻1858條經驗 獲得超8個贊

好吧,準確性是一個global指標,沒有per-class accuracy.?也許你的意思是,這就是orproportion of the class correctly identified的確切定義。TPRrecall


查看完整回答
反對 回復 2023-05-23
?
倚天杖

TA貢獻1828條經驗 獲得超3個贊

如果您想獲得某個類別或一組特定類別的準確性,掩碼可能是一個很好的解決方案。看這段代碼:


def cus_accuracy(real, pred):


    score = accuracy(real, pred)

    mask = tf.math.greater_equal(real, 5)

    mask = tf.cast(mask, dtype=real.dtype)

    score *= mask


    mask2 = tf.math.less_equal(real, 10)

    mask2 = tf.cast(mask2, dtype=real.dtype)

    score *= mask2


return tf.reduce_mean(score)

這個指標給出了 5 到 10 類的準確度。我用它來測量 seq2seq 模型中某些單詞的準確度。


查看完整回答
反對 回復 2023-05-23
  • 3 回答
  • 0 關注
  • 174 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號