Define the parameters
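The post never shows its imports; the snippets below assume this set (torchcrf comes from the pytorch-crf package, and classification_report from seqeval, which scores at the entity level):

```python
import os
import json

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel, BertConfig
from torchcrf import CRF  # pip install pytorch-crf
from seqeval.metrics import classification_report  # pip install seqeval
```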
```python
class CommonConfig:
    bert_dir = "hfl/chinese-macbert-base"
    output_dir = "./checkpoint/"
    data_dir = "./data"


class NerConfig:
    def __init__(self):
        cf = CommonConfig()
        self.bert_dir = cf.bert_dir
        self.output_dir = cf.output_dir
        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
        self.data_dir = cf.data_dir

        labels_list = ['O', 'B-HCCX', 'I-HCCX', 'B-HPPX', 'I-HPPX', 'B-XH', 'I-XH', 'B-MISC', 'I-MISC']
        self.num_labels = len(labels_list)
        self.label2id = {l: i for i, l in enumerate(labels_list)}
        self.id2label = {i: l for i, l in enumerate(labels_list)}

        self.max_seq_len = 512
        self.epochs = 3
        self.train_batch_size = 64
        self.dev_batch_size = 64
        self.bert_learning_rate = 3e-5
        self.crf_learning_rate = 3e-3
        self.adam_epsilon = 1e-8
        self.weight_decay = 0.01
        self.warmup_proportion = 0.01
        self.save_step = 500
```
Load and save the parameters
```python
args = NerConfig()

# Save the config so inference can reload the same label mapping.
with open(os.path.join(args.output_dir, "ner_args.json"), "w") as fp:
    json.dump(vars(args), fp, ensure_ascii=False, indent=2)
```
Define the Dataset
```python
class NerDataset(Dataset):
    def __init__(self, data, args, tokenizer):
        self.data = data
        self.args = args
        self.tokenizer = tokenizer
        self.label2id = args.label2id
        self.max_seq_len = args.max_seq_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        example = self.data[item]
        tokenized_example = self.tokenizer(example["tokens"], max_length=self.max_seq_len, truncation=True,
                                           padding="max_length", is_split_into_words=True)
        # Map each token back to the word it came from so the labels stay aligned.
        word_ids = tokenized_example.word_ids(batch_index=0)
        label_ids = []
        for word_id in word_ids:
            if word_id is None:
                # Special tokens ([CLS], [SEP], [PAD]) get the 'O' label; the mask excludes them later.
                label_ids.append(0)
            else:
                label_ids.append(example["label"][word_id])
        tokenized_example["labels"] = label_ids
        for k, v in tokenized_example.items():
            tokenized_example[k] = torch.tensor(np.array(v))
        return tokenized_example
```
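To see what the word_ids alignment does, here is a quick illustrative check (it assumes the hfl/chinese-macbert-base fast tokenizer loaded in a later step):

```python
enc = tokenizer(["川", "珍", "香", "菇"], is_split_into_words=True)
print(enc.tokens())    # roughly: ['[CLS]', '川', '珍', '香', '菇', '[SEP]']
print(enc.word_ids())  # roughly: [None, 0, 1, 2, 3, None] -> special tokens map to None
```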
Define the data-reading function
```python
# Read CoNLL-style data: one "char TAG" pair per line, blank line between sentences.
def read_data(file, label2id):
    lists = []
    with open(file, 'r') as f:
        lines = f.readlines()
    id = 0
    tokens = []
    ner_tags = []
    ner_labels = []
    for line in tqdm(lines):
        line = line.replace("\n", "")
        if len(line) == 0:
            # A blank line closes the current sentence.
            lists.append({
                "id": str(id),
                "tokens": tokens,
                "ner_tags": ner_tags,
                "label": ner_labels
            })
            tokens = []
            ner_tags = []
            ner_labels = []
            id += 1
            continue
        texts = line.split(" ")
        tokens.append(texts[0])
        ner_tags.append(texts[1])
        ner_labels.append(label2id[texts[1]])
    return lists
```
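read_data therefore expects files in CoNLL style: one character and its BIO tag separated by a space, with a blank line between sentences. An illustrative fragment (made-up rows, not from the actual dataset):

```
川 B-HPPX
珍 I-HPPX
香 B-HCCX
菇 I-HCCX

5 B-MISC
0 I-MISC
0 I-MISC
g I-MISC
```

Note that the code only emits a sentence when it hits a blank line, so the file needs a trailing blank line or the last sentence is silently dropped.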
Process the data
```python
tokenizer = AutoTokenizer.from_pretrained(args.bert_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = NerDataset(read_data("../origin/train.txt", args.label2id), args, tokenizer=tokenizer)
dev_dataset = NerDataset(read_data("../origin/dev.txt", args.label2id), args, tokenizer=tokenizer)
# test_dataset = NerDataset(read_data("../origin/test.txt"), tokenizer=tokenizer)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.train_batch_size, num_workers=2)
dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=args.dev_batch_size, num_workers=2)
```
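A quick sanity check on one batch (optional; the shapes assume max_seq_len=512 and train_batch_size=64 from NerConfig):

```python
batch = next(iter(train_loader))
print({k: tuple(v.shape) for k, v in batch.items()})
# roughly: {'input_ids': (64, 512), 'token_type_ids': (64, 512),
#           'attention_mask': (64, 512), 'labels': (64, 512)}
```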
Define the model
```python
class BertNer(nn.Module):
    def __init__(self, args):
        super(BertNer, self).__init__()
        self.bert = BertModel.from_pretrained(args.bert_dir)
        self.bert_config = BertConfig.from_pretrained(args.bert_dir)
        self.config = self.bert_config
        hidden_size = self.bert_config.hidden_size
        self.lstm_hidden = 128
        self.max_seq_len = args.max_seq_len
        self.bilstm = nn.LSTM(hidden_size, self.lstm_hidden, 1, bidirectional=True, batch_first=True,
                              dropout=0.1)
        self.linear = nn.Linear(self.lstm_hidden * 2, args.num_labels)
        self.crf = CRF(args.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_out = bert_output[0]  # [batch_size, max_seq_len, 768]
        seq_out, _ = self.bilstm(seq_out)  # [batch_size, max_seq_len, lstm_hidden * 2]
        seq_out = self.linear(seq_out)  # emission scores: [batch_size, max_seq_len, num_labels]
        # Viterbi decoding: one best tag-id sequence per sentence, following the mask.
        logits = self.crf.decode(seq_out, mask=attention_mask.bool())
        loss = None
        if labels is not None:
            # The CRF returns a log-likelihood, so negate it to get a loss.
            loss = -self.crf(seq_out, labels, mask=attention_mask.bool(), reduction='token_mean')
        return (logits, labels, loss)
```
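One detail worth noting: crf.decode returns plain Python lists whose lengths follow the mask, not a padded tensor, which is why the evaluation code below indexes them directly. A small standalone illustration (hypothetical tensors):

```python
emissions = torch.randn(2, 5, args.num_labels)  # fake emission scores
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
crf = CRF(args.num_labels, batch_first=True)
print([len(seq) for seq in crf.decode(emissions, mask=mask)])  # [3, 5]
```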
Load the model
```python
model = BertNer(args)
model.to(device)
```
Define the Trainer
Only the main training and evaluation methods are shown here; the rest of the code is at the code link below.
```python
class Trainer:
    def train(self):
        global_step = 1
        epoch_loss = 0
        for epoch in range(1, self.epochs + 1):
            for step, batch_data in enumerate(tqdm(self.train_loader)):
                self.model.train()
                for key, value in batch_data.items():
                    batch_data[key] = value.to(self.device)
                input_ids = batch_data["input_ids"]
                attention_mask = batch_data["attention_mask"]
                labels = batch_data["labels"]
                logits, _labels, loss = self.model(input_ids, attention_mask, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.schedule.step()
                global_step += 1
                epoch_loss = loss
                if global_step % self.save_step == 0:
                    torch.save(self.model.state_dict(), os.path.join(self.output_dir, "pytorch_model_ner.bin"))
            print(f"【train】{epoch}/{self.epochs} {global_step}/{self.total_step} loss:{epoch_loss.item()}")
        torch.save(self.model.state_dict(), os.path.join(self.output_dir, "pytorch_model_ner.bin"))

    def test(self):
        self.model.load_state_dict(torch.load(os.path.join(self.output_dir, "pytorch_model_ner.bin")))
        self.model.eval()
        preds = []
        trues = []
        for step, batch_data in enumerate(tqdm(self.test_loader)):
            for key, value in batch_data.items():
                batch_data[key] = value.to(self.device)
            input_ids = batch_data["input_ids"]
            attention_mask = batch_data["attention_mask"]
            labels = batch_data["labels"]
            with torch.no_grad():  # no gradients needed at evaluation time
                logits, _labels, loss = self.model(input_ids, attention_mask, labels)
            attention_mask = attention_mask.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            batch_size = input_ids.size(0)
            for i in range(batch_size):
                # length = number of real (non-padding) tokens; [1:length] drops [CLS].
                length = sum(attention_mask[i])
                logit = logits[i][1:length]
                logit = [self.id2label[t] for t in logit]
                label = labels[i][1:length]
                label = [self.id2label[t] for t in label]
                preds.append(logit)
                trues.append(label)
        # seqeval computes entity-level precision/recall/F1 from the BIO sequences.
        report = classification_report(trues, preds)
        return report
```
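The optimizer and scheduler construction is elided above. A plausible sketch of what it could look like given NerConfig's fields (an assumption, not the author's exact code), with the small learning rate on BERT and the larger one on the BiLSTM/CRF head, using transformers' get_linear_schedule_with_warmup:

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

def build_optimizer_and_scheduler(args, model, t_total):
    # Hypothetical helper matching the config fields defined earlier.
    no_decay = ["bias", "LayerNorm.weight"]
    bert_params = list(model.bert.named_parameters())
    head_params = [(n, p) for n, p in model.named_parameters() if not n.startswith("bert.")]
    grouped = [
        {"params": [p for n, p in bert_params if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay, "lr": args.bert_learning_rate},
        {"params": [p for n, p in bert_params if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, "lr": args.bert_learning_rate},
        {"params": [p for n, p in head_params],
         "weight_decay": args.weight_decay, "lr": args.crf_learning_rate},
    ]
    optimizer = AdamW(grouped, lr=args.bert_learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total,
    )
    return optimizer, scheduler
```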
Instantiate the Trainer with the parameters
```python
train = Trainer(
    output_dir=args.output_dir,
    model=model,
    train_loader=train_loader,
    dev_loader=dev_loader,
    test_loader=dev_loader,  # the dev set doubles as the test set here
    epochs=args.epochs,
    device=device,
    id2label=args.id2label,
    t_total_num=len(train_loader) * args.epochs,
    optimizer_args=args
)
```
Train
```python
train.train()
```

```
【train】1/3 95/282 loss:0.32739126682281494
【train】2/3 189/282 loss:0.2356387972831726
【train】3/3 283/282 loss:0.21381179988384247
```
Evaluate
```python
report = train.test()
print(report)
```
| | precision | recall | f1-score | support |
|---|---|---|---|---|
| HCCX | 0.83 | 0.84 | 0.83 | |
| HPPX | 0.78 | 0.79 | 0.78 | |
| MISC | 0.76 | 0.81 | 0.79 | |
| XH | 0.71 | 0.79 | 0.75 | |
| micro avg | 0.81 | 0.83 | 0.82 | |
| macro avg | 0.77 | 0.81 | 0.79 | |
| weighted avg | 0.81 | 0.83 | 0.82 | |
Inference
The prediction function lives in the Predictor class in the full code (see the code link below); a sketch of the span decoding it performs follows.
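As a rough illustration of what such a class does once the model has predicted a tag sequence, here is a hypothetical helper (not the author's code) that groups character-level BIO tags into the (entity, start, end) tuples seen in the output below:

```python
def decode_bio(text, tags):
    # tags: one BIO label per character of text; returns inclusive (start, end) spans.
    entities = {}
    i = 0
    while i < len(tags):
        if tags[i].startswith("B-"):
            etype = tags[i][2:]
            j = i + 1
            while j < len(tags) and tags[j] == f"I-{etype}":
                j += 1
            entities.setdefault(etype, []).append((text[i:j], i, j - 1))
            i = j
        else:
            i += 1
    return entities
```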
```python
from predict import Predictor

predictor = Predictor()
text = "川珍浓香型香菇干货去根肉厚干香菇500g热销品质抢购"
ner_result = predictor.ner_predict(text)
print("文本>>>>>:", text)
print("实体>>>>>:", ner_result)
```
```
文本>>>>>: 川珍浓香型香菇干货去根肉厚干香菇500g热销品质抢购
实体>>>>>: {'HPPX': [('川珍', 0, 1)], 'HCCX': [('香菇', 5, 6), ('干货', 7, 8), ('干香菇', 13, 15)], 'MISC': [('500g', 16, 19)]}
```
Code link
More content
- 小鱼吃猫 blog: Transformers tutorials
- WeChat official account: codeCraft编程工艺