import re

import pandas as pd
from sklearn.utils import shuffle
# Toutiao category name -> integer class id (15 classes).
# Ids follow the position of each name in the tuple below.
_CATEGORY_NAMES = (
    'news_story', 'news_culture', 'news_entertainment', 'news_sports',
    'news_finance', 'news_house', 'news_car', 'news_edu', 'news_tech',
    'news_military', 'news_travel', 'news_world', 'stock',
    'news_agriculture', 'news_game',
)
label_dic = {name: idx for idx, name in enumerate(_CATEGORY_NAMES)}
def get_train_data(file_path, col_num):
    """Read the Toutiao news dataset and return (titles, labels).

    Each line of the file is one record with fields joined by the literal
    separator "_!_":
        id _!_ category_code _!_ category_name _!_ title _!_ keywords

    Parameters
    ----------
    file_path : str
        Path to the "_!_"-separated dataset file (UTF-8).
    col_num : int
        Maximum number of lines to read (despite the name, this limits
        rows, not columns).

    Returns
    -------
    tuple[list[str], list[int]]
        Titles with every non-Chinese character stripped, and the class id
        for each title looked up in ``label_dic`` (None for unknown
        categories, matching the original ``dict.get`` behavior).
    """
    content = []
    label = []
    # Hoisted out of the loop: keep only CJK ideographs U+4E00-U+9FA5;
    # digits, latin letters and punctuation are all removed from the title.
    non_chinese = re.compile(r'[^\u4e00-\u9fa5]')
    with open(file_path, "r", encoding="utf-8") as f:
        # Iterate the file lazily instead of f.readlines(), which loaded
        # the whole file into memory before the limit was applied.
        for num, line in enumerate(f):
            # Bug fix: the original used `if num > col_num: break`, which
            # read col_num + 1 records instead of col_num.
            if num >= col_num:
                break
            fields = line.split("_!_")
            if len(fields) < 4:
                # Skip malformed lines instead of raising IndexError.
                continue
            content.append(non_chinese.sub("", fields[3]))
            label.append(label_dic.get(fields[2]))
    return content, label


# Script-level driver: load up to 8000 records and shuffle them.
content, label = get_train_data("./file/toutiao_cat_data.txt", 8000)
data = pd.DataFrame({"content": content, "label": label})
data = shuffle(data)  # sklearn.utils.shuffle returns a new shuffled frame
# Tokenize every (shuffled) title into fixed-length id tensors:
# pad/truncate to 100 tokens, returned as PyTorch tensors ("pt").
# NOTE(review): `tokenizer` is not defined in this chunk — presumably a
# HuggingFace-style tokenizer created earlier in the file; confirm.
train_data = tokenizer(data.content.to_list(), padding = "max_length", max_length = 100, truncation=True ,return_tensors = "pt")
# Integer class ids, row-aligned with `train_data` after the shuffle.
train_label = data.label.to_list()