# Source code for hanlp.components.tokenizers.transformer
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-08-11 02:48
import functools
from typing import TextIO, Union, List, Dict, Any, Set
import torch
from hanlp.common.dataset import SamplerBuilder
from hanlp.common.transform import TransformList
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.datasets.tokenization.loaders.txt import TextTokenizingDataset, generate_tags_for_subtokens
from hanlp.metrics.f1 import F1
from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
from hanlp.utils.span_util import bmes_to_spans
from hanlp.utils.string_util import possible_tokenization
from hanlp_common.util import merge_locals_kwargs
from hanlp_trie import DictInterface, TrieDict
from hanlp_trie.dictionary import TupleTrieDict


class TransformerTaggingTokenizer(TransformerTagger):
def __init__(self, **kwargs) -> None:
""" A tokenizer using transformer tagger for span prediction. It features with 2 high performance dictionaries
to handle edge cases in real application.
- ``dict_force``: High priority dictionary performs longest-prefix-matching on input text which takes higher
priority over model predictions.
- ``dict_combine``: Low priority dictionary performs longest-prefix-matching on model predictions then
combines them.
.. Note:: For algorithm beginners, longest-prefix-matching is the prerequisite to understand what dictionary can
do and what it can't do. The tutorial in `this book <http://nlp.hankcs.com/book.php>`_ can be very helpful.
It also supports outputting the span of each token by setting ``config.output_spans = True``.
Args:
**kwargs: Predefined config.
"""
super().__init__(**kwargs)
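    # A minimal usage sketch, kept as comments so this module stays importable. The identifier
    # ``COARSE_ELECTRA_SMALL_ZH`` and the exact outputs below are illustrative assumptions; any pretrained
    # tok model can be substituted.
    #
    #   >>> import hanlp
    #   >>> tok = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
    #   >>> tok('商品和服务')
    #   ['商品', '和', '服务']
    #   >>> tok.config.output_spans = True  # also return the [begin, end) character offsets of each token
    #   >>> tok('商品和服务')
    #   [['商品', 0, 2], ['和', 2, 3], ['服务', 3, 5]]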
@property
def dict_force(self) -> DictInterface:
r""" The high priority dictionary which perform longest-prefix-matching on inputs to split them into two subsets:
1. spans containing no keywords, which are then fed into tokenizer for further tokenization.
2. keywords, which will be outputed without furthur tokenization.
.. Caution::
Longest-prefix-matching **NEVER** guarantee the presence of any keywords. Abuse of
``dict_force`` can lead to low quality results. For more details, refer to
`this book <http://nlp.hankcs.com/book.php>`_.
Examples:
>>> tok.dict_force = {'和服', '服务行业'} # Force '和服' and '服务行业' by longest-prefix-matching
>>> tok("商品和服务行业")
['商品', '和服', '务行业']
>>> tok.dict_force = {'和服务': ['和', '服务']} # Force '和服务' to be tokenized as ['和', '服务']
>>> tok("商品和服务行业")
['商品', '和', '服务', '行业']
"""
return self.config.get('dict_force', None)
@dict_force.setter
def dict_force(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
if dictionary is not None and not isinstance(dictionary, DictInterface):
dictionary = TrieDict(dictionary)
self.config.dict_force = dictionary
self.tokenizer_transform.dict = dictionary
@property
def dict_combine(self) -> DictInterface:
""" The low priority dictionary which perform longest-prefix-matching on model predictions and combing them.
Examples:
>>> tok.dict_combine = {'和服', '服务行业'}
>>> tok("商品和服务行业") # '和服' is not in the original results ['商品', '和', '服务']. '服务', '行业' are combined to '服务行业'
['商品', '和', '服务行业']
"""
return self.config.get('dict_combine', None)
@dict_combine.setter
def dict_combine(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
if dictionary is not None and not isinstance(dictionary, DictInterface):
if all(isinstance(k, str) for k in dictionary):
dictionary = TrieDict(dictionary)
else:
_d = set()
for k in dictionary:
if isinstance(k, str):
_d.update(possible_tokenization(k))
else:
_d.add(k)
dictionary = TupleTrieDict(_d)
self.config.dict_combine = dictionary
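    # Note on mixed keys: if ``dictionary`` contains any non-string key (e.g., a tuple of tokens), every string
    # key is expanded via ``possible_tokenization`` into all of its contiguous splits so it can match however
    # the model happens to segment it, and the whole dictionary is stored as a TupleTrieDict keyed by token
    # tuples.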
# noinspection PyMethodOverriding
def update_metrics(self, metric, logits, y, mask, batch, prediction):
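        # Compare predicted spans against gold spans (converted on the fly from gold tags) as sets of
        # (begin, end) offsets, feeding the span-level F1 metric.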
for p, g in zip(prediction, self.tag_to_span(batch['tag'], batch)):
pred = set(p)
gold = set(g)
metric(pred, gold)
def decode_output(self, logits, mask, batch, model=None):
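        # Decode tag ids into tag strings for every sequence, then convert the tag sequences into token spans
        # (applying any ``dict_force`` matches recorded in the batch; see ``tag_to_span``).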
output = super().decode_output(logits, mask, batch, model)
if isinstance(output, torch.Tensor):
output = output.tolist()
prediction = self.id_to_tags(output, [len(x) for x in batch['token']])
return self.tag_to_span(prediction, batch)
def tag_to_span(self, batch_tags, batch: dict):
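        # Convert BMES/BI tag sequences into (begin, end) spans over subtokens. Two fixes are applied first:
        # (1) spans matched by ``dict_force`` (recorded in the batch as ``custom_words``) overwrite the
        # predicted tags so that matched keywords are kept intact; (2) when one character was split into
        # several subtokens by the transformer tokenizer, their tags are merged so the character is never
        # torn apart.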
spans = []
if 'custom_words' in batch:
if self.config.tagging_scheme == 'BMES':
S = 'S'
M = 'M'
E = 'E'
else:
S = 'B'
M = 'I'
E = 'I'
for tags, custom_words in zip(batch_tags, batch['custom_words']):
# [batch['raw_token'][0][x[0]:x[1]] for x in subwords]
if custom_words:
for start, end, label in custom_words:
if end - start == 1:
tags[start] = S
else:
tags[start] = 'B'
tags[end - 1] = E
for i in range(start + 1, end - 1):
tags[i] = M
if end < len(tags):
tags[end] = 'B'
if 'token_subtoken_offsets_group' not in batch: # only check prediction on raw text for now
            # Handle cases where a single char gets split into multiple subtokens, e.g., ‥ -> . + .
for tags, subtoken_offsets in zip(batch_tags, batch['token_subtoken_offsets']):
offset = -1 # BERT produces 'ᄒ', '##ᅡ', '##ᆫ' for '한' and they share the same span
prev_tag = None
for i, (tag, (b, e)) in enumerate(zip(tags, subtoken_offsets)):
if b < offset:
if prev_tag == 'S':
tags[i - 1] = 'B'
elif prev_tag == 'E':
tags[i - 1] = 'M'
tags[i] = tag = 'M'
offset = e
prev_tag = tag
for tags in batch_tags:
spans.append(bmes_to_spans(tags))
return spans
def write_prediction(self, prediction, batch, output: TextIO):
batch_tokens = self.spans_to_tokens(prediction, batch)
for tokens in batch_tokens:
output.write(' '.join(tokens))
output.write('\n')
@property
def tokenizer_transform(self):
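        # Lazily build a TransformerSequenceTokenizer that emits subtokens and their grouping per original
        # token, so tags can be generated per subtoken; ``dict_force`` is injected here so that matched
        # keywords are recorded as ``custom_words`` in each batch.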
if not self._tokenizer_transform:
self._tokenizer_transform = TransformerSequenceTokenizer(self.transformer_tokenizer,
self.config.token_key,
ret_subtokens=True,
ret_subtokens_group=True,
ret_token_span=False,
dict_force=self.dict_force)
return self._tokenizer_transform
def spans_to_tokens(self, spans, batch, rebuild_span=False):
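        # Turn (begin, end) subtoken spans into surface tokens. When the raw text is available, tokens are
        # sliced from it using subtoken character offsets (restoring tokens like 'iPhone X' verbatim);
        # otherwise subtokens are simply concatenated. Finally, ``dict_combine`` merges adjacent tokens that
        # form a dictionary entry.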
batch_tokens = []
dict_combine = self.dict_combine
raw_text = batch.get('token_', None) # Use raw text to rebuild the token according to its offset
for b, (spans_per_sent, sub_tokens) in enumerate(zip(spans, batch[self.config.token_key])):
if raw_text: # This will restore iPhone X as a whole
text = raw_text[b]
offsets = batch['token_subtoken_offsets'][b]
tokens = [text[offsets[b][0]:offsets[e - 1][-1]] for b, e in spans_per_sent]
else: # This will merge iPhone X into iPhoneX
tokens = [''.join(sub_tokens[span[0]:span[1]]) for span in spans_per_sent]
if dict_combine:
buffer = []
offset = 0
delta = 0
for start, end, label in dict_combine.tokenize(tokens):
if offset < start:
buffer.extend(tokens[offset:start])
if raw_text:
# noinspection PyUnboundLocalVariable
combined = text[offsets[spans_per_sent[start - delta][0]][0]:
offsets[spans_per_sent[end - delta - 1][1] - 1][1]]
else:
combined = ''.join(tokens[start:end])
buffer.append(combined)
offset = end
if rebuild_span:
start -= delta
end -= delta
combined_span = (spans_per_sent[start][0], spans_per_sent[end - 1][1])
del spans_per_sent[start:end]
delta += end - start - 1
spans_per_sent.insert(start, combined_span)
if offset < len(tokens):
buffer.extend(tokens[offset:])
tokens = buffer
batch_tokens.append(tokens)
return batch_tokens
def generate_prediction_filename(self, tst_data, save_dir):
return super().generate_prediction_filename(tst_data.replace('.tsv', '.txt'), save_dir)
def prediction_to_human(self, pred, vocab, batch, rebuild_span=False):
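        # Convert predicted spans into tokens; when ``output_spans`` is enabled, each token is returned as a
        # ``[token, begin, end]`` triple whose character offsets come from its first and last subtoken offsets.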
output_spans = self.config.get('output_spans', None)
tokens = self.spans_to_tokens(pred, batch, rebuild_span or output_spans)
if output_spans:
subtoken_spans = batch['token_subtoken_offsets']
results = []
for toks, offs, subs in zip(tokens, pred, subtoken_spans):
r = []
results.append(r)
for t, (b, e) in zip(toks, offs):
r.append([t, subs[b][0], subs[e - 1][-1]])
return results
return tokens
def input_is_flat(self, tokens):
return isinstance(tokens, str)
def build_dataset(self, data, **kwargs):
return TextTokenizingDataset(data, **kwargs)
def last_transform(self):
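        # Prepend a transform that generates per-subtoken tags (BMES or BI, per ``tagging_scheme``) from the
        # gold tokenization, followed by the parent tagger's own final transforms.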
return TransformList(functools.partial(generate_tags_for_subtokens, tagging_scheme=self.config.tagging_scheme),
super().last_transform())
    def fit(self, trn_data, dev_data, save_dir, transformer, average_subwords=False, word_dropout: float = 0.2,
hidden_dropout=None, layer_dropout=0, scalar_mix=None, grad_norm=5.0,
transformer_grad_norm=None, lr=5e-5, eval_trn=True,
transformer_lr=None, transformer_layers=None, gradient_accumulation=1,
adam_epsilon=1e-8, weight_decay=0, warmup_steps=0.1, crf=False, reduction='sum',
batch_size=32, sampler_builder: SamplerBuilder = None, epochs=30, patience=5, token_key=None,
tagging_scheme='BMES', delimiter=None,
max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False, transform=None, logger=None,
devices: Union[float, int, List[int]] = None, **kwargs):
"""
Args:
trn_data: Training set.
dev_data: Development set.
save_dir: The directory to save trained component.
transformer: An identifier of a pre-trained transformer.
average_subwords: ``True`` to average subword representations.
word_dropout: Dropout rate to randomly replace a subword with MASK.
hidden_dropout: Dropout rate applied to hidden states.
layer_dropout: Randomly zero out hidden states of a transformer layer.
scalar_mix: Layer attention.
grad_norm: Gradient norm for clipping.
transformer_grad_norm: Gradient norm for clipping transformer gradient.
lr: Learning rate for decoder.
            transformer_lr: Learning rate for the encoder.
transformer_layers: The number of bottom layers to use.
gradient_accumulation: Number of batches per update.
adam_epsilon: The epsilon to use in Adam.
weight_decay: The weight decay to use.
warmup_steps: The number of warmup steps.
crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
reduction: The loss reduction used in aggregating losses.
batch_size: The number of samples in a batch.
sampler_builder: The builder to build sampler, which will override batch_size.
epochs: The number of epochs to train.
patience: The number of patience epochs before early stopping.
token_key: The key to tokens in dataset.
tagging_scheme: Either ``BMES`` or ``BI``.
delimiter: Delimiter between tokens used to split a line in the corpus.
max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
be split here.
char_level: Whether the sequence length is measured at char level.
hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
in a sentence, it will be split at a token anyway.
transform: An optional transform to be applied to samples. Usually a character normalization transform is
passed in.
devices: Devices this component will live on.
logger: Any :class:`logging.Logger` instance.
seed: Random seed to reproduce this training.
**kwargs: Not used.
Returns:
Best metrics on dev set.
"""
return super().fit(**merge_locals_kwargs(locals(), kwargs))
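    # A minimal training sketch, kept as comments. The corpus paths and the transformer identifier below are
    # placeholders rather than resources shipped with this module:
    #
    #   >>> tokenizer = TransformerTaggingTokenizer()
    #   >>> tokenizer.fit('data/tok/train.txt',  # one space-delimited sentence per line (see ``delimiter``)
    #   ...               'data/tok/dev.txt',
    #   ...               'data/model/tok_electra',
    #   ...               transformer='hfl/chinese-electra-180g-small-discriminator')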
def feed_batch(self, batch: dict):
x, mask = super().feed_batch(batch)
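        # Drop the hidden states of the leading [CLS] and trailing [SEP] positions so that the remaining
        # positions align one-to-one with subtokens.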
return x[:, 1:-1, :], mask