Source code for hanlp.components.ner.transformer_ner

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-10-07 11:08
import functools
from typing import Union, List, Dict, Any, Set

from hanlp_trie import DictInterface, TrieDict

from hanlp.common.dataset import SamplerBuilder
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.metrics.chunking.sequence_labeling import get_entities
from hanlp.metrics.f1 import F1
from hanlp.datasets.ner.loaders.json_ner import prune_ner_tagset
from hanlp.utils.string_util import guess_delimiter
from hanlp_common.util import merge_locals_kwargs


class TransformerNamedEntityRecognizer(TransformerTagger):

    def __init__(self, **kwargs) -> None:
        r"""A simple tagger using transformers and a linear layer with an optional CRF
        (:cite:`lafferty2001conditional`) layer for the NER task. It can utilize a whitelist gazetteer, which is a
        dict mapping entity names to entity types. During decoding, it performs longest-prefix-matching of these
        entries to override the predictions of the underlying statistical model. It also uses a blacklist to mask out
        mis-predicted entities.

        .. Note:: For algorithm beginners, longest-prefix-matching is the prerequisite to understanding what a
            dictionary can and cannot do. The tutorial in `this book <http://nlp.hankcs.com/book.php>`_ can be very
            helpful.

        Args:
            **kwargs: Not used.
        """
        super().__init__(**kwargs)
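    # Usage sketch (illustrative, not part of the original source): how the whitelist and blacklist
    # described in the docstring above are applied at prediction time. The model identifier below is a
    # placeholder; substitute any pretrained NER model backed by this class (see `hanlp.pretrained.ner`).
    #
    #   import hanlp
    #   ner = hanlp.load('SOME_PRETRAINED_NER_IDENTIFIER')   # placeholder; e.g. a Chinese char-level model
    #   ner.dict_whitelist = {'午饭后': 'TIME'}                # surface form -> entity type
    #   ner.dict_blacklist = {'138'}                           # surface forms never emitted as entities
    #   ner([list('时间是午饭后2点')])                          # char-level input for such models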
    def build_metric(self, **kwargs):
        return F1()
    # noinspection PyMethodOverriding
    def update_metrics(self, metric, logits, y, mask, batch, prediction):
        for p, g in zip(prediction, self.tag_to_span(batch['tag'], batch)):
            pred = set(p)
            gold = set(g)
            metric(pred, gold)

    # noinspection PyMethodOverriding
    def decode_output(self, logits, mask, batch, model=None):
        output = super().decode_output(logits, mask, batch, model)
        prediction = super().prediction_to_human(output, self.vocabs['tag'].idx_to_token, batch)
        return self.tag_to_span(prediction, batch)

    def tag_to_span(self, batch_tags, batch):
        spans = []
        sents = batch[self.config.token_key]
        dict_whitelist = self.dict_whitelist
        dict_blacklist = self.dict_blacklist
        merge_types = self.config.get('merge_types', None)
        for tags, tokens in zip(batch_tags, sents):
            entities = get_entities(tags)
            if dict_whitelist:
                # Longest-prefix-match whitelist entries over the tokens
                matches = dict_whitelist.tokenize(tokens)
                if matches:
                    # Fix O E-LOC O like predictions by re-normalizing existing entities into clean IOBES tags
                    entities = get_entities(tags)
                    for label, start, end in entities:
                        if end - start == 1:
                            tags[start] = 'S-' + label
                        else:
                            tags[start] = 'B-' + label
                            for i in range(start + 1, end - 1):
                                tags[i] = 'I-' + label
                            tags[end - 1] = 'E-' + label
                    for start, end, label in matches:
                        # Only overwrite when the match does not start or end inside an existing entity
                        if (not tags[start][0] in 'ME') and (not tags[end - 1][0] in 'BM'):
                            if end - start == 1:
                                tags[start] = 'S-' + label
                            else:
                                tags[start] = 'B-' + label
                                for i in range(start + 1, end - 1):
                                    tags[i] = 'I-' + label
                                tags[end - 1] = 'E-' + label
                    entities = get_entities(tags)
            if merge_types and len(entities) > 1:
                # Merge consecutive entities whose type is listed in merge_types
                merged_entities = []
                begin = 0
                for i in range(1, len(entities)):
                    if entities[begin][0] != entities[i][0] or entities[i - 1][2] != entities[i][1] \
                            or entities[i][0] not in merge_types:
                        merged_entities.append((entities[begin][0], entities[begin][1], entities[i - 1][2]))
                        begin = i
                merged_entities.append((entities[begin][0], entities[begin][1], entities[-1][2]))
                entities = merged_entities
            if dict_blacklist:
                # Drop entities whose surface form appears in the blacklist
                pruned = []
                delimiter_in_entity = self.config.get('delimiter_in_entity', ' ')
                for label, start, end in entities:
                    entity = delimiter_in_entity.join(tokens[start:end])
                    if entity not in dict_blacklist:
                        pruned.append((label, start, end))
                entities = pruned
            spans.append(entities)
        return spans

    def decorate_spans(self, spans, batch):
        # Attach surface forms to spans: (text, label, start, end)
        batch_ner = []
        delimiter_in_entity = self.config.get('delimiter_in_entity', ' ')
        for spans_per_sent, tokens in zip(spans,
                                          batch.get(f'{self.config.token_key}_', batch[self.config.token_key])):
            ner_per_sent = []
            for label, start, end in spans_per_sent:
                ner_per_sent.append((delimiter_in_entity.join(tokens[start:end]), label, start, end))
            batch_ner.append(ner_per_sent)
        return batch_ner

    def generate_prediction_filename(self, tst_data, save_dir):
        return super().generate_prediction_filename(tst_data.replace('.tsv', '.txt'), save_dir)

    def prediction_to_human(self, pred, vocab, batch):
        return self.decorate_spans(pred, batch)

    def input_is_flat(self, tokens):
        return tokens and isinstance(tokens, list) and isinstance(tokens[0], str)
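    # Decoding walkthrough (illustrative comment, added for clarity): `get_entities` turns an IOBES/BIO
    # tag sequence into (label, start, end) spans with an exclusive `end` (that is how this file consumes
    # them, e.g. `tokens[start:end]`), and `merge_types` then fuses adjacent spans of the same type.
    #
    #   tags = ['B-ORG', 'E-ORG', 'O', 'S-LOC', 'S-LOC']
    #   get_entities(tags)               # -> [('ORG', 0, 2), ('LOC', 3, 4), ('LOC', 4, 5)]
    #   # with merge_types=['LOC'], the two adjacent LOC spans above are merged into ('LOC', 3, 5)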
    def fit(self, trn_data, dev_data, save_dir, transformer,
            delimiter_in_entity=None,
            merge_types: List[str] = None,
            average_subwords=False,
            word_dropout: float = 0.2,
            hidden_dropout=None,
            layer_dropout=0,
            scalar_mix=None,
            grad_norm=5.0,
            lr=5e-5,
            transformer_lr=None,
            adam_epsilon=1e-8,
            weight_decay=0,
            warmup_steps=0.1,
            crf=False,
            secondary_encoder=None,
            reduction='sum',
            batch_size=32,
            sampler_builder: SamplerBuilder = None,
            epochs=3,
            tagset=None,
            token_key='token',
            max_seq_len=None,
            sent_delimiter=None,
            char_level=False,
            hard_constraint=False,
            transform=None,
            logger=None,
            seed=None,
            devices: Union[float, int, List[int]] = None,
            **kwargs):
        """Fit component to training set.

        Args:
            trn_data: Training set.
            dev_data: Development set.
            save_dir: The directory to save trained component.
            transformer: An identifier of a pre-trained transformer.
            delimiter_in_entity: The delimiter between tokens in an entity, which is used to rebuild the entity by
                joining tokens during decoding.
            merge_types: The types of consecutive entities to be merged.
            average_subwords: ``True`` to average subword representations.
            word_dropout: Dropout rate to randomly replace a subword with MASK.
            hidden_dropout: Dropout rate applied to hidden states.
            layer_dropout: Randomly zero out hidden states of a transformer layer.
            scalar_mix: Layer attention.
            grad_norm: Gradient norm for clipping.
            lr: Learning rate for decoder.
            transformer_lr: Learning rate for encoder.
            adam_epsilon: The epsilon to use in Adam.
            weight_decay: The weight decay to use.
            warmup_steps: The number of warmup steps.
            crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
            secondary_encoder: An optional secondary encoder to provide enhanced representation by taking the hidden
                states from the main encoder as input.
            reduction: The loss reduction used in aggregating losses.
            batch_size: The number of samples in a batch.
            sampler_builder: The builder to build sampler, which will override batch_size.
            epochs: The number of epochs to train.
            tagset: Optional tagset to prune entities outside of this tagset from datasets.
            token_key: The key to tokens in dataset.
            max_seq_len: The maximum sequence length. Sequences longer than this will be handled by a sliding window.
            sent_delimiter: Delimiter between sentences, like period or comma, which indicates where a long sentence
                can be split.
            char_level: Whether the sequence length is measured at char level.
            hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
                in a sentence, it will be split at a token anyway.
            transform: An optional transform to be applied to samples. Usually a character normalization transform is
                passed in.
            devices: Devices this component will live on.
            logger: Any :class:`logging.Logger` instance.
            seed: Random seed to reproduce this training.
            **kwargs: Not used.

        Returns:
            The best metrics on training set.
        """
        return super().fit(**merge_locals_kwargs(locals(), kwargs))
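    # Training sketch (illustrative): a minimal `fit()` call. The data paths and the transformer
    # identifier are placeholders, not resources shipped with HanLP; any CoNLL/TSV-style NER corpus
    # accepted by the tagger's dataset loader would work.
    #
    #   recognizer = TransformerNamedEntityRecognizer()
    #   recognizer.fit('data/ner/train.tsv', 'data/ner/dev.tsv', save_dir='ner_model',
    #                  transformer='bert-base-cased', crf=True, epochs=3)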
    def build_vocabs(self, trn, logger, **kwargs):
        super().build_vocabs(trn, logger, **kwargs)
        if self.config.get('delimiter_in_entity', None) is None:
            # Check the first sample to guess the delimiter between tokens in a NE
            tokens = trn[0][self.config.token_key]
            delimiter_in_entity = guess_delimiter(tokens)
            logger.info(f'Guess the delimiter between tokens in named entity could be [blue]"{delimiter_in_entity}'
                        f'"[/blue]. If not, specify `delimiter_in_entity` in `fit()`')
            self.config.delimiter_in_entity = delimiter_in_entity
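    # Behaviour note (assumed from `guess_delimiter`'s intent, not verified here): for samples whose
    # tokens are mostly CJK characters the guessed delimiter is typically the empty string, so
    # ['北', '京'] is rebuilt as '北京'; for Latin-script tokens it is typically a single space, so
    # ['New', 'York'] is rebuilt as 'New York'. Pass `delimiter_in_entity` to `fit()` to override the guess.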
    def build_dataset(self, data, transform=None, **kwargs):
        dataset = super().build_dataset(data, transform, **kwargs)
        if isinstance(data, str):
            tagset = self.config.get('tagset', None)
            if tagset:
                dataset.append_transform(functools.partial(prune_ner_tagset, tagset=tagset))
        return dataset

    @property
    def dict_whitelist(self) -> DictInterface:
        return self.config.get('dict_whitelist', None)

    @dict_whitelist.setter
    def dict_whitelist(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
        if dictionary is not None and not isinstance(dictionary, DictInterface):
            dictionary = TrieDict(dictionary)
        self.config.dict_whitelist = dictionary

    @property
    def dict_blacklist(self) -> DictInterface:
        return self.config.get('dict_blacklist', None)

    @dict_blacklist.setter
    def dict_blacklist(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
        if dictionary is not None and not isinstance(dictionary, DictInterface):
            dictionary = TrieDict(dictionary)
        self.config.dict_blacklist = dictionary
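    # Property usage sketch (illustrative): plain Python containers are accepted by the setters above and
    # wrapped in a `TrieDict`, so longest-prefix-matching works out of the box. The entries are made up.
    #
    #   recognizer.dict_whitelist = {'午饭后': 'TIME'}   # dict: surface form -> entity type
    #   recognizer.dict_blacklist = {'的'}               # set: surface forms never emitted as entities
    #   assert isinstance(recognizer.config.dict_whitelist, TrieDict)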