1 回答

TA貢獻1772條經驗 獲得超6個贊
class TextLoader():
def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
self.data_dir = data_dir
self.batch_size = batch_size
self.seq_length = seq_length
self.encoding = encoding
#第一次運行程序時只有input.txt一個文件,剩下兩個文件是運行之后產生的
input_file = os.path.join(data_dir, "input.txt")
vocab_file = os.path.join(data_dir, "vocab.pkl")
tensor_file = os.path.join(data_dir, "data.npy")
#如果是第一次執行則調用preprocess函數,否則調用load_preprocessed函數。
if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
print("reading text file")
self.preprocess(input_file, vocab_file, tensor_file)
else:
print("loading preprocessed files")
self.load_preprocessed(vocab_file, tensor_file)
self.create_batches()
self.reset_batch_pointer()
def preprocess(self, input_file, vocab_file, tensor_file):
with codecs.open(input_file, "r", encoding=self.encoding) as f:
data = f.read()
#使用Counter函數對輸入數據進行統計。counter保存data中每個字符出現的次數
counter = collections.Counter(data)
#對counter進行排序,出現次數最多的排在前面
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
#將data中出現的所有字符保存,這里有65個,所以voacb_size=65
self.chars, _ = zip(*count_pairs)
self.vocab_size = len(self.chars)
#按照字符出現次數多少順序將chars保存,vocab中存儲的是char和順序,這樣方便將data轉化為索引
self.vocab = dict(zip(self.chars, range(len(self.chars))))
with open(vocab_file, 'wb') as f:
#保存chars
cPickle.dump(self.chars, f)
#將data中每個字符轉化為索引下標。
self.tensor = np.array(list(map(self.vocab.get, data)))
np.save(tensor_file, self.tensor)
def load_preprocessed(self, vocab_file, tensor_file):
#如果是第二次運行,則可以直接讀取之前保存的chars和tensor
with open(vocab_file, 'rb') as f:
self.chars = cPickle.load(f)
self.vocab_size = len(self.chars)
self.vocab = dict(zip(self.chars, range(len(self.chars))))
self.tensor = np.load(tensor_file)
self.num_batches = int(self.tensor.size / (self.batch_size *
self.seq_length))
def create_batches(self):
#首先將數據按batch_size切割,然后每個batch_size在按照seq_length進行切割
self.num_batches = int(self.tensor.size / (self.batch_size *
self.seq_length))
if self.num_batches == 0:
assert False, "Not enough data. Make seq_length and batch_size small."
self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
xdata = self.tensor
#構造target,這里使用上一個詞預測下一個詞,所以直接將x向后一個字符即可
ydata = np.copy(self.tensor)
ydata[:-1] = xdata[1:]
ydata[-1] = xdata[0]
#將數據進行切分,這里我們假設數據總長度為10000,batch_size為100, seq_length為10.
# 所以num_batches=10,所以,xdata在reshape之后變成[100, 100],然后在第二個維度上切成10份,
# 所以最終得到[100, 10, 10]的數據
self.x_batches = np.split(xdata.reshape(self.batch_size, -1),
self.num_batches, 1)
self.y_batches = np.split(ydata.reshape(self.batch_size, -1),
self.num_batches, 1)
def next_batch(self):
x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
self.pointer += 1
return x, y
def reset_batch_pointer(self):
self.pointer = 0
- 1 回答
- 0 關注
- 977 瀏覽
添加回答
舉報