亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

GPT2 on Hugging Face(pytorch 變壓器)運行時錯誤:

GPT2 on Hugging Face(pytorch 變壓器)運行時錯誤:

慕哥9229398 2023-08-08 18:03:12
我正在嘗試使用我的自定義數據集微調 gpt2。我使用擁抱面變壓器的文檔創建了一個基本示例。我收到上述錯誤。我知道這意味著什么:(基本上它是在非標量張量上向后調用)但由于我幾乎只使用 API 調用,所以我不知道如何解決這個問題。有什么建議么?from pathlib import Pathfrom absl import flags, appimport IPythonimport torchfrom transformers import GPT2LMHeadModel, Trainer,  TrainingArgumentsfrom data_reader import GetDataAsPython# this is my custom data, but i get the same error for the basic case below# data = GetDataAsPython('data.json')# data = [data_point.GetText2Text() for data_point in data]# print("Number of data samples is", len(data))data = ["this is a trial text", "this is another trial text"]train_texts = datafrom transformers import GPT2Tokenizertokenizer = GPT2Tokenizer.from_pretrained('gpt2')special_tokens_dict = {'pad_token': '<PAD>'}num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)train_encodigs = tokenizer(train_texts, truncation=True, padding=True)class BugFixDataset(torch.utils.data.Dataset):    def __init__(self, encodings):        self.encodings = encodings        def __getitem__(self, index):        item = {key: torch.tensor(val[index]) for key, val in self.encodings.items()}        return item    def __len__(self):        return len(self.encodings['input_ids'])train_dataset = BugFixDataset(train_encodigs)training_args = TrainingArguments(    output_dir='./results',              num_train_epochs=3,                  per_device_train_batch_size=1,      per_device_eval_batch_size=1,       warmup_steps=500,                    weight_decay=0.01,                   logging_dir='./logs',    logging_steps=10,)model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)model.resize_token_embeddings(len(tokenizer))trainer = Trainer(    model=model,    args=training_args,    train_dataset=train_dataset,)trainer.train()
查看完整描述

1 回答

?
海綿寶寶撒

TA貢獻1809條經驗 獲得超8個贊

我終于弄明白了。問題在于數據樣本不包含目標輸出。即使很難的 gpt 也是自我監督的,這必須明確地告訴模型。

你必須添加以下行:

item['labels'] = torch.tensor(self.encodings['input_ids'][index])

到Dataset類的getitem函數,然后就可以正常運行了!


查看完整回答
反對 回復 2023-08-08
  • 1 回答
  • 0 關注
  • 187 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號