# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-06-12 18:00
from typing import Any
import torch
from hanlp_common.util import merge_locals_kwargs
import hanlp.utils.span_util
from hanlp.components.taggers.rnn_tagger import RNNTagger
from hanlp.metrics.chunking.conlleval import SpanF1
class RNNNamedEntityRecognizer(RNNTagger):
    """An old-school RNN-based NER component built on top of :class:`RNNTagger`.

    It tags tokens with a span scheme (IOBES / BIO / BIOUL), evaluates with
    span-level F1, and post-processes predicted tag sequences into entity
    tuples ``(text, type, begin, end)``.
    """

    def __init__(self, **kwargs) -> None:
        """An old-school RNN tagger using word2vec or fasttext embeddings.

        Args:
            **kwargs: Predefined config.
        """
        super().__init__(**kwargs)

    def build_metric(self, **kwargs):
        """Build the evaluation metric.

        NER is scored on entity spans, not per-token accuracy, so a CoNLL-style
        span F1 is used. ``self.tagging_scheme`` selects how tags are grouped
        into spans.

        Args:
            **kwargs: Unused; kept for interface compatibility.

        Returns:
            A :class:`~hanlp.metrics.chunking.conlleval.SpanF1` metric.
        """
        return SpanF1(self.tagging_scheme)

    def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, **kwargs):
        """Evaluate on a dataloader and additionally log the detailed span-F1 report.

        Args:
            data: The dataloader to evaluate on.
            criterion: The loss function.
            logger: Optional logger; when provided, the human-readable metric
                breakdown is written to it.
            ratio_width: Passed through to the parent for progress formatting.
            **kwargs: Extra arguments forwarded to the parent implementation.

        Returns:
            A ``(loss, metric)`` tuple as produced by the parent class.
        """
        loss, metric = super().evaluate_dataloader(data, criterion, logger, ratio_width, **kwargs)
        if logger:
            # result(True, False)[-1] is the verbose textual report of the span F1.
            logger.info(metric.result(True, False)[-1])
        return loss, metric

    def fit(self, trn_data, dev_data, save_dir, batch_size=50, epochs=100, embed=100, rnn_input=None, rnn_hidden=256,
            drop=0.5, lr=0.001, patience=10, crf=True, optimizer='adam', token_key='token', tagging_scheme=None,
            anneal_factor: float = 0.5, delimiter=None, anneal_patience=2, devices=None,
            token_delimiter=None,
            logger=None,
            verbose=True, **kwargs):
        """Train the recognizer; all hyper-parameters are collected into the config
        via ``merge_locals_kwargs`` and forwarded to the parent ``fit``.

        Notable parameters: ``crf`` toggles a CRF output layer, ``tagging_scheme``
        selects the span encoding (IOBES/BIO/BIOUL), and ``token_delimiter`` is the
        string used to join tokens when rendering entity text at prediction time
        (auto-detected in :meth:`save_config` when left as ``None``).
        """
        return super().fit(**merge_locals_kwargs(locals(), kwargs))

    def update_metrics(self, metric, logits, y, mask, batch, prediction):
        """Feed one batch of predictions into the span metric.

        ``logits`` are decoded to tag ids first (CRF or argmax, per the parent's
        ``decode_output``), converted to tag strings, then compared against the
        gold tags in ``batch['tag']``. ``y`` and ``prediction`` are unused here
        but kept for interface compatibility with the parent class.
        """
        logits = self.decode_output(logits, mask, batch)
        if isinstance(logits, torch.Tensor):
            # decode_output may return either a tensor of ids or nested lists
            # (e.g. from CRF decoding); normalize to lists for id->tag mapping.
            logits = logits.tolist()
        metric(self._id_to_tags(logits), batch['tag'])

    def predict(self, tokens: Any, batch_size: int = None, **kwargs):
        """Predict named entities for the given tokens.

        Args:
            tokens: A token sequence, or a list of token sequences.
            batch_size: Optional batch size; the parent picks a default when ``None``.
            **kwargs: Extra arguments forwarded to the parent implementation.

        Returns:
            Entities per sentence as ``(text, type, begin, end)`` tuples
            (see :meth:`predict_data`).
        """
        return super().predict(tokens, batch_size, **kwargs)

    def predict_data(self, data, batch_size, **kwargs):
        """Run tag prediction then convert tag sequences to entity tuples.

        Args:
            data: A list of token sequences.
            batch_size: Batch size for the parent's prediction.
            **kwargs: Unused; kept for interface compatibility.

        Returns:
            For each input sentence, a list of ``(entity_text, type, begin, end)``
            where ``end`` is exclusive and ``entity_text`` joins the covered
            tokens with ``self.config.token_delimiter``.

        Raises:
            ValueError: If ``self.tagging_scheme`` is not one of
                ``'IOBES'``, ``'BIO'`` or ``'BIOUL'``.
        """
        outputs = super().predict_data(data, batch_size)
        tagging_scheme = self.tagging_scheme
        if tagging_scheme == 'IOBES':
            entities = [hanlp.utils.span_util.iobes_tags_to_spans(y) for y in outputs]
        elif tagging_scheme == 'BIO':
            entities = [hanlp.utils.span_util.bio_tags_to_spans(y) for y in outputs]
        elif tagging_scheme == 'BIOUL':
            entities = [hanlp.utils.span_util.bioul_tags_to_spans(y) for y in outputs]
        else:
            raise ValueError(f'Unrecognized tag scheme {tagging_scheme}')
        for i, (tokens, es) in enumerate(zip(data, entities)):
            # span_util returns inclusive (begin, end); expose exclusive end to callers.
            outputs[i] = [(self.config.token_delimiter.join(tokens[b:e + 1]), t, b, e + 1) for t, (b, e) in es]
        return outputs

    def save_config(self, save_dir, filename='config.json'):
        """Persist the config, auto-detecting ``token_delimiter`` first if unset.

        Heuristic: if the last 100 vocabulary tokens are all single characters
        (typical for character-level Chinese models), join entities with no
        delimiter; otherwise use a space (word-level models).
        """
        if self.config.token_delimiter is None:
            self.config.token_delimiter = '' if all(
                [len(x) == 1 for x in self.vocabs[self.config.token_key].idx_to_token[-100:]]) else ' '
        super().save_config(save_dir, filename)