Hugging Face Model Fine-Tuning: Text Classification

Task Objectives

1. Continue training HIT's RoBERTa model (already pre-trained on the MLM objective) on a text dataset for the target task

2. Enable it to perform a downstream text classification task

Loading the Model

Model download address: only the model-related files are needed, i.e. config.json, pytorch_model.bin, and vocab.txt.

from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

model_path = 'D:/Models/chinese-roberta-wwm-ext'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=15)  # number of target classes
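
The classification head added by AutoModelForSequenceClassification is newly initialized (the checkpoint was pre-trained only with MLM), so its weights are random until fine-tuning. A quick sanity check that the tokenizer and config loaded as expected (illustrative only; the printed shape depends on the sentence):

sample = tokenizer("今天天气不错", return_tensors="pt")
print(sample["input_ids"].shape)  # e.g. torch.Size([1, 8]), including [CLS] and [SEP]
print(model.config.num_labels)    # 15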

Building the Training Data

Dataset download address: the internal data structure looks as follows (see the readme at the download link for details, and the parsing sketch after the samples).

6551700932705387022_!_101_!_news_culture_!_京城最值得你来场文化之旅的博物馆_!_保利集团,马未都,中国科学技术馆,博物馆,新中国
6552368441838272771_!_101_!_news_culture_!_发酵床的垫料种类有哪些?哪种更好?_!_
6552310157706002702_!_102_!_news_entertainment_!_成龙改口决定不裸捐了,20亿财产给儿子一半,你怎么看?_!_
6552309039697494532_!_103_!_news_sports_!_亚洲杯夺冠赔率:日本、伊朗领衔 中国竟与泰国并列_!_土库曼斯坦,乌兹别克斯坦,亚洲杯,赔率,小组赛
6552477789642031623_!_103_!_news_sports_!_9轮4球本土射手仅次武磊 黄紫昌要抢最强U23头衔_!_黄紫昌,武磊,卡佩罗,惠家康,韦世豪
6552495859798376712_!_103_!_news_sports_!_如果今年勇士夺冠,下赛季詹姆斯何去何从?_!_
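
Each line carries five fields separated by `_!_`: news ID, category code, category name, title, and keywords (possibly empty). A minimal parsing sketch using one of the sample lines above:

line = "6552368441838272771_!_101_!_news_culture_!_发酵床的垫料种类有哪些?哪种更好?_!_"
news_id, code, category, title, keywords = line.split("_!_")
print(category, title)  # news_culture 发酵床的垫料种类有哪些?哪种更好?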
import re
from sklearn.utils import shuffle
import pandas as pd

label_dic = {
    'news_story': 0,
    'news_culture': 1,
    'news_entertainment': 2,
    'news_sports': 3,
    'news_finance': 4,
    'news_house': 5,
    'news_car': 6,
    'news_edu': 7,
    'news_tech': 8,
    'news_military': 9,
    'news_travel': 10,
    'news_world': 11,
    'stock': 12,
    'news_agriculture': 13,
    'news_game': 14
}

def get_train_data(file_path, col_num):
    # col_num: maximum number of lines to read from the file
    content = []
    label = []
    with open(file_path, "r", encoding="utf-8") as f:
        num = 0
        for i in f.readlines():
            if num > col_num:
                break
            lines = i.split("_!_")
            content.append(re.sub('[^\u4e00-\u9fa5]', "", lines[3]))  # strip non-Chinese characters from the title
            label.append(label_dic.get(lines[2]))                     # map category name to integer label
            num += 1
    return content, label

content, label = get_train_data("./file/toutiao_cat_data.txt", 8000)
data = pd.DataFrame({"content": content, "label": label})
data = shuffle(data)

train_data = tokenizer(data.content.to_list(), padding="max_length", max_length=100, truncation=True, return_tensors="pt")
train_label = data.label.to_list()
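
Before building the DataLoader you can check what the tokenizer produced; every value is a tensor of shape (num_samples, max_length). A sketch (the row count shown is an assumption based on reading ~8000 lines):

print(train_data.keys())                # input_ids, token_type_ids, attention_mask
print(train_data["input_ids"].shape)    # e.g. torch.Size([8001, 100])
print(train_data["attention_mask"][0])  # 1 for real tokens, 0 for padding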

After preprocessing, the training samples in the data variable look like this:

index  content                                                  label
4383   以色列警告称如果战机被击落将会轰炸俄军事基地你怎么看     9
5244   月份北京楼市各区成交排名昌平丰台密云三区热度高           5
5608   市值与业绩倒挂华大基因是第二个乐视网吗                   12

Defining the Optimizer and Learning Rate

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

batch_size = 16
train = TensorDataset(train_data["input_ids"], train_data["attention_mask"], torch.tensor(train_label))
train_sampler = RandomSampler(train)
train_dataloader = DataLoader(train, sampler=train_sampler, batch_size=batch_size)

# Define the optimizer
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=1e-4)
# Define the learning-rate schedule and number of epochs
num_epochs = 1
from transformers import get_scheduler
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
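
With num_warmup_steps=0, the linear schedule simply decays the learning rate from 1e-4 down to 0 across num_training_steps. A standalone sketch of the decay, using a dummy optimizer so the real one is untouched:

dummy_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
dummy_sched = get_scheduler("linear", optimizer=dummy_opt, num_warmup_steps=0, num_training_steps=4)
for step in range(4):
    dummy_opt.step()
    dummy_sched.step()
    print(step, dummy_sched.get_last_lr())  # [7.5e-05], [5e-05], [2.5e-05], [0.0]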

Model Training

from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
for epoch in range(num_epochs):
    total_loss = 0
    model.train()
    with tqdm(list(enumerate(train_dataloader)), ncols=100) as _tqdm:
        for step, batch in _tqdm:
            _tqdm.set_description('epoch {}/{}'.format(epoch + 1, num_epochs))
            if step > 0:
                # cur_loss: running loss per sample; avg_train_loss: total loss over all
                # batches (equals the epoch average once the epoch completes)
                cur_loss = total_loss / (step * batch_size)
                avg_train_loss = total_loss / len(train_dataloader)
                _tqdm.set_postfix(loss=cur_loss, avg_loss=avg_train_loss)
            else:
                _tqdm.set_postfix(loss=0.0)
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            model.zero_grad()  # clear gradients from the previous step
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
            optimizer.step()
            lr_scheduler.step()
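
Once training finishes, the fine-tuned weights can be persisted with save_pretrained and reloaded later via from_pretrained (the output directory below is just an example):

save_dir = "./finetuned-roberta-toutiao"  # example path
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)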

Model Prediction

import numpy as np

test = tokenizer("通过考研大纲谈农学复习之化学部分年农学将第二次实行全国统一考试这使农学考生的复习备考十分迷茫",
                 return_tensors="pt", padding="max_length", max_length=100, truncation=True)
test.to(device)

model.eval()
with torch.no_grad():
    outputs = model(test["input_ids"],
                    token_type_ids=None,
                    attention_mask=test["attention_mask"])

logits = outputs["logits"].cpu()
pred = int(np.argmax(logits.numpy(), axis=1).squeeze())  # index of the predicted class
print(pred)
print(list(label_dic.keys())[list(label_dic.values()).index(pred)])
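
The reverse lookup through keys() and values() works but is awkward; inverting the label map once reads better (equivalent result):

id2label = {v: k for k, v in label_dic.items()}
print(id2label[pred])  # e.g. news_edu for this sentence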

Note: fine-tuning a pre-trained model basically follows this recipe. The model itself rarely needs to change; in general you only need to preprocess the dataset into the input format the model expects.