侧边栏壁纸
博主头像
小鱼吃猫博客博主等级

你所热爱的,便是你的生活。

  • 累计撰写 115 篇文章
  • 累计创建 47 个标签
  • 累计收到 14 条评论

目 录CONTENT

文章目录

基于BERT+LSTM+CRF的命名实体识别

小鱼吃猫
2023-08-18 / 0 评论 / 4 点赞 / 89 阅读 / 9734 字

定义参数

class CommonConfig:
    """Locations shared by every task: pretrained weights, outputs and data."""

    # Model identifier (or local path) handed to ``from_pretrained``.
    bert_dir: str = "hfl/chinese-macbert-base"
    # Directory that receives checkpoints and the dumped configuration.
    output_dir: str = "./checkpoint/"
    # Root directory of the dataset files.
    data_dir: str = "./data"

class NerConfig:
    """Paths, label vocabulary and hyper-parameters for the NER task.

    Instantiating the config has one side effect: ``output_dir`` is created
    on disk when it does not exist yet.

    Attribute assignment order is kept stable on purpose, because the
    instance is serialized with ``vars()`` and the order ends up in the
    dumped JSON file.
    """

    def __init__(self):
        cf = CommonConfig()
        self.bert_dir = cf.bert_dir
        self.output_dir = cf.output_dir
        # makedirs(exist_ok=True) also creates missing parent directories and
        # does not fail if another process creates the directory first, both
        # of which the previous exists()/mkdir() pair could not handle.
        os.makedirs(self.output_dir, exist_ok=True)
        self.data_dir = cf.data_dir

        # BIO tag set; the position in this list is the integer label id,
        # so 'O' is always id 0.
        labels_list = ['O', 'B-HCCX', 'I-HCCX', 'B-HPPX', 'I-HPPX', 'B-XH', 'I-XH', 'B-MISC', 'I-MISC']
        self.num_labels = len(labels_list)
        self.label2id = {l: i for i, l in enumerate(labels_list)}
        self.id2label = {i: l for i, l in enumerate(labels_list)}

        # Every example is padded/truncated to this many tokens.
        self.max_seq_len = 512
        self.epochs = 3
        self.train_batch_size = 64
        self.dev_batch_size = 64
        # Two learning rates are configured; presumably the optimizer setup
        # (not shown here) applies them to different parameter groups.
        self.bert_learning_rate = 3e-5
        self.crf_learning_rate = 3e-3
        self.adam_epsilon = 1e-8
        self.weight_decay = 0.01
        self.warmup_proportion = 0.01
        # Interval, in optimizer steps, between intermediate checkpoints.
        self.save_step = 500

加载参数

# Build the run configuration; NerConfig() also creates args.output_dir.
args = NerConfig()

# Persist the configuration next to the checkpoints. The file is written with
# ensure_ascii=False, so the encoding is pinned to UTF-8 instead of relying on
# the platform default. Note that JSON turns the integer keys of id2label
# into strings.
with open(os.path.join(args.output_dir, "ner_args.json"), "w", encoding="utf-8") as fp:
    json.dump(vars(args), fp, ensure_ascii=False, indent=2)

定义DataLoader

class NerDataset(Dataset):
    """Map-style dataset that tokenizes pre-split NER examples on access.

    Each element of ``data`` is indexed with ``"tokens"`` (the pre-split
    words handed to the tokenizer) and ``"label"`` (one integer label id per
    word).

    Args:
        data: Indexable collection of examples as described above.
        args: Configuration object providing ``label2id`` and ``max_seq_len``.
        tokenizer: Callable tokenizer whose result supports item assignment,
            ``keys()`` and ``word_ids(batch_index=0)``.
    """

    def __init__(self, data, args, tokenizer):
        self.data = data
        self.args = args
        self.tokenizer = tokenizer
        self.label2id = args.label2id
        self.max_seq_len = args.max_seq_len

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, item):
        """Tokenize example ``item`` and attach token-aligned label ids.

        Returns:
            The tokenizer output with an added ``"labels"`` entry, every
            value converted to a ``torch.Tensor``.
        """
        example = self.data[item]
        encoded = self.tokenizer(
            example["tokens"],
            max_length=self.max_seq_len,
            truncation=True,
            padding="max_length",
            is_split_into_words=True,
        )

        # Positions without a source word (word id None) get label id 0;
        # every other position inherits the label of the word it came from.
        word_labels = example["label"]
        encoded["labels"] = [
            0 if word_id is None else word_labels[word_id]
            for word_id in encoded.word_ids(batch_index=0)
        ]

        # Snapshot the keys first so that reassigning values cannot disturb
        # the iteration.
        for key in list(encoded.keys()):
            encoded[key] = torch.tensor(np.array(encoded[key]))

        return encoded

定义读取数据的方法

def read_data(file, label2id):
    """Read a CoNLL-style file with one ``token label`` pair per line.

    Sentences are separated by blank (or whitespace-only) lines. A final
    sentence that is not followed by a blank line is still returned, and
    repeated separator lines do not produce empty samples.

    Args:
        file: Path of the text file to read. The file is decoded as UTF-8
            (assumed encoding of the corpus; adjust if the data differs).
        label2id: Mapping from tag string to integer label id.

    Returns:
        A list of dicts with the keys ``"id"`` (running index as a string),
        ``"tokens"``, ``"ner_tags"`` (tag strings) and ``"label"`` (tag ids).

    Raises:
        KeyError: If a tag in the file is missing from ``label2id``.
        IndexError: If a non-blank line has no space-separated tag column.
    """
    samples = []
    tokens, ner_tags, ner_labels = [], [], []

    def _flush():
        # Close the sentence collected so far, if there is one.
        nonlocal tokens, ner_tags, ner_labels
        if tokens:
            samples.append({
                "id": str(len(samples)),
                "tokens": tokens,
                "ner_tags": ner_tags,
                "label": ner_labels
            })
            tokens, ner_tags, ner_labels = [], [], []

    with open(file, 'r', encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm(lines):
        line = line.rstrip("\n")
        if not line.strip():
            _flush()
            continue
        texts = line.split(" ")
        tokens.append(texts[0])
        ner_tags.append(texts[1])
        ner_labels.append(label2id[texts[1]])

    # The file may end without a trailing blank line.
    _flush()
    return samples

处理数据

# Pick the GPU when one is available, and load the tokenizer that belongs to
# the pretrained encoder.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(args.bert_dir)

# Wrap the parsed train/dev files in datasets that tokenize on access.
train_dataset = NerDataset(
    read_data("../origin/train.txt", args.label2id),
    args,
    tokenizer=tokenizer,
)
dev_dataset = NerDataset(
    read_data("../origin/dev.txt", args.label2id),
    args,
    tokenizer=tokenizer,
)

# Only the training data is shuffled; evaluation keeps the file order.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=args.train_batch_size,
    num_workers=2,
)
dev_loader = DataLoader(
    dev_dataset,
    shuffle=False,
    batch_size=args.dev_batch_size,
    num_workers=2,
)

定义模型

 class BertNer(nn.Module):
     """BERT encoder followed by a BiLSTM, a linear projection and a CRF.

     Args:
         args: Configuration object providing ``bert_dir``, ``max_seq_len``
             and ``num_labels``.
     """

     def __init__(self, args):
         super(BertNer, self).__init__()
         self.bert = BertModel.from_pretrained(args.bert_dir)
         self.bert_config = BertConfig.from_pretrained(args.bert_dir)
         self.config = self.bert_config
         hidden_size = self.bert_config.hidden_size
         self.lstm_hiden = 128
         # Kept as an attribute for callers; forward() no longer depends on it.
         self.max_seq_len = args.max_seq_len
         # Single-layer BiLSTM. No ``dropout`` argument: nn.LSTM only applies
         # it between stacked layers, so with one layer it had no effect and
         # merely triggered a warning.
         self.bilstm = nn.LSTM(hidden_size, self.lstm_hiden, 1, bidirectional=True, batch_first=True)
         self.linear = nn.Linear(self.lstm_hiden * 2, args.num_labels)
         self.crf = CRF(args.num_labels, batch_first=True)

     def forward(self, input_ids, attention_mask, labels=None):
         """Decode the best tag sequence and optionally compute the CRF loss.

         Args:
             input_ids: Token ids, batch-first.
             attention_mask: Mask with non-zero entries on real tokens; it is
                 converted to bool and passed to the CRF as its mask.
             labels: Optional gold label ids aligned with ``input_ids``.

         Returns:
             A tuple ``(logits, labels, loss)`` where ``logits`` is the result
             of ``self.crf.decode`` (decoded tag ids, not raw scores),
             ``labels`` is passed through unchanged, and ``loss`` is the
             negated CRF output with ``reduction='token_mean'``, or ``None``
             when no labels are given.
         """
         bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
         seq_out = bert_output[0]  # encoder hidden states, batch-first
         seq_out, _ = self.bilstm(seq_out)
         # nn.Linear acts on the last dimension, so the emissions are computed
         # without reshaping and for whatever sequence length the batch has,
         # rather than assuming every batch is padded to max_seq_len.
         seq_out = self.linear(seq_out)
         mask = attention_mask.bool()
         logits = self.crf.decode(seq_out, mask=mask)
         loss = None
         if labels is not None:
             loss = -self.crf(seq_out, labels, mask=mask, reduction='token_mean')
         return (logits, labels, loss)

加载模型

# Build the network and move its parameters to the selected device
# (Module.to returns the module itself, so the call can be chained).
model = BertNer(args).to(device)

定义Trainer

这里只介绍主要的训练和测试的方法,其他代码见后文代码地址

class Trainer:
    """Training and evaluation loops for the NER model.

    Only ``train`` and ``test`` are part of this excerpt; the constructor is
    defined elsewhere. The two methods read these attributes from ``self``:
    ``model``, ``device``, ``train_loader``, ``test_loader``, ``optimizer``,
    ``schedule``, ``epochs``, ``save_step``, ``total_step``, ``output_dir``
    and ``id2label``.
    """

    def train(self):
        """Run the training loop and write the final weights to disk.

        An intermediate checkpoint is written whenever the step counter is a
        multiple of ``self.save_step``; one progress line is printed at the
        end of each epoch.
        """
        checkpoint_path = os.path.join(self.output_dir, "pytorch_model_ner.bin")
        # NOTE: the counter starts at 1, so the step number reported below is
        # one ahead of the number of completed optimizer updates.
        global_step = 1
        epoch_loss = 0
        for epoch in range(1, self.epochs + 1):
            for batch_data in tqdm(self.train_loader):
                self.model.train()
                # Move the whole batch to the device first, then run exactly
                # one optimization step for it.
                for key, value in batch_data.items():
                    batch_data[key] = value.to(self.device)
                input_ids = batch_data["input_ids"]
                attention_mask = batch_data["attention_mask"]
                labels = batch_data["labels"]
                logits, _labels, loss = self.model(input_ids, attention_mask, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.schedule.step()
                global_step += 1
                # Keep only the value; the autograd graph is not needed.
                epoch_loss = loss.detach()
                if global_step % self.save_step == 0:
                    torch.save(self.model.state_dict(), checkpoint_path)
            # float() also handles the initial 0 when the loader is empty.
            print(f"【train】{epoch}/{self.epochs} {global_step}/{self.total_step} loss:{float(epoch_loss)}")

        torch.save(self.model.state_dict(), checkpoint_path)

    def test(self):
        """Evaluate the saved checkpoint on ``self.test_loader``.

        Returns:
            The value returned by ``classification_report(trues, preds)``,
            computed over the tag sequences of every batch.
        """
        # map_location lets a checkpoint saved on one device be loaded on
        # another (for example trained on GPU, evaluated on CPU).
        state_dict = torch.load(
            os.path.join(self.output_dir, "pytorch_model_ner.bin"),
            map_location=self.device,
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()
        preds = []
        trues = []
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            for batch_data in tqdm(self.test_loader):
                for key, value in batch_data.items():
                    batch_data[key] = value.to(self.device)
                input_ids = batch_data["input_ids"]
                attention_mask = batch_data["attention_mask"]
                labels = batch_data["labels"]
                logits, _labels, loss = self.model(input_ids, attention_mask, labels)
                attention_mask = attention_mask.detach().cpu().numpy()
                labels = labels.detach().cpu().numpy()

                # Collect the sequences of this batch before moving on, so
                # every batch contributes to the report.
                for row in range(input_ids.size(0)):
                    length = int(attention_mask[row].sum())
                    # Drop the first and the last unmasked position. This
                    # assumes the tokenizer wraps each sequence in one leading
                    # and one trailing special token, as BERT tokenizers do.
                    pred_ids = logits[row][1:length - 1]
                    true_ids = labels[row][1:length - 1]
                    preds.append([self.id2label[int(tag)] for tag in pred_ids])
                    trues.append([self.id2label[int(tag)] for tag in true_ids])

        report = classification_report(trues, preds)
        return report

使用Trainer传入参数

# Wire the model, the data loaders and the hyper-parameters into the trainer.
train = Trainer(
    model=model,
    device=device,
    output_dir=args.output_dir,
    train_loader=train_loader,
    dev_loader=dev_loader,
    # The dev split doubles as the test split in this walkthrough.
    test_loader=dev_loader,
    epochs=args.epochs,
    # Total number of optimizer steps: batches per epoch times epochs.
    t_total_num=len(train_loader) * args.epochs,
    id2label=args.id2label,
    optimizer_args=args,
)

训练

# Run training; checkpoints are written under the trainer's output_dir.
train.train()
【train】1/3 95/282 loss:0.32739126682281494
【train】2/3 189/282 loss:0.2356387972831726
【train】3/3 283/282 loss:0.21381179988384247

评估

# Reload the saved checkpoint, evaluate it on the trainer's test loader and
# print the resulting classification report.
report = train.test()
print(report)
              precision    recall  f1-score   support
        HCCX       0.83      0.84      0.83
        HPPX       0.78      0.79      0.78
        MISC       0.76      0.81      0.79
          XH       0.71      0.79      0.75
   micro avg       0.81      0.83      0.82
   macro avg       0.77      0.81      0.79
weighted avg       0.81      0.83      0.82

推理预测

预测函数见全部代码Predictor类

from predict import Predictor

# Run the predictor on a single product title and show the extracted
# entities next to the input text.
predictor = Predictor()
text = "川珍浓香型香菇干货去根肉厚干香菇500g热销品质抢购"
ner_result = predictor.ner_predict(text)
print("文本>>>>>:", text)
print("实体>>>>>:", ner_result)
文本>>>>>: 川珍浓香型香菇干货去根肉厚干香菇500g热销品质抢购
实体>>>>>: {'HPPX': [('川珍', 0, 1)], 'HCCX': [('香菇', 5, 6), ('干货', 7, 8), ('干香菇', 13, 15)], 'MISC': [('500g', 16, 19)]}

代码地址

基于BERT+LSTM+CRF的命名实体识别

更多内容

4

评论区