2 回答
TA貢獻1829條經驗 獲得超7個贊
也許你可以試試這一行而不是這兩行:
label = tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
你會得到類似的東西[0, 1, 0](標簽的索引在CLASS_NAMES)。
功能和可重現的例子:
import tensorflow as tf
import numpy as np
from string import ascii_lowercase as letters
CLASS_NAMES = [b'class_1', b'class_2', b'class_3']
files = ['\\'.join([np.random.choice(CLASS_NAMES).decode(),
''.join(np.random.choice(list(letters), 5)) + '.jpg'])
for i in range(10)]
ds = tf.data.Dataset.from_tensor_slices(files)
這是我生成的假文件:
['class_3\\jrxog.jpg',
'class_1\\slfiq.jpg',
'class_2\\svldd.jpg',
'class_2\\avrgt.jpg',
'class_3\\wqwuv.jpg']
現在實現這個:
def get_label(file_path):
parts = tf.strings.split(file_path, '\\')
return file_path, tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
ds = ds.map(get_label)
next(iter(ds))
(<tf.Tensor: shape=(), dtype=string, numpy=b'class_1\\bbqrx.jpg'>,
<tf.Tensor: shape=(), dtype=int64, numpy=0>)
添加回答
舉報
