Source code for hanlp.components.tokenizers.multi_criteria_cws_transformer
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-10-21 19:55
import functools
from typing import List, Union

from hanlp.common.dataset import SamplerBuilder
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.components.tokenizers.transformer import TransformerTaggingTokenizer
from hanlp.datasets.tokenization.loaders.multi_criteria_cws.mcws_dataset import MultiCriteriaTextTokenizingDataset, \
    append_criteria_token
from hanlp.metrics.f1 import F1
from hanlp.metrics.mtl import MetricDict
from hanlp_common.util import merge_locals_kwargs

class MultiCriteriaTransformerTaggingTokenizer(TransformerTaggingTokenizer):

    def __init__(self, **kwargs) -> None:
        r"""Transformer based implementation of "Effective Neural Solution for Multi-Criteria Word Segmentation"
        (:cite:`he2019effective`). It adds an artificial token ``[unused_i]`` to the input_ids (alongside ``[CLS]``
        and ``[SEP]``) to mark the i-th segmentation criteria.

        Args:
            **kwargs: Not used.
        """
        super().__init__(**kwargs)
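    # A sketch of the input layout this component works with (inferred from last_transform()
    # and feed_batch() below; shown for illustration only):
    #
    #   input_ids: [CLS] c_1 c_2 ... c_n [SEP] [unused_i]
    #   tags:            B   M  ...  E                      (one BMES tag per character c_j)
    #
    # where [unused_i] is the artificial token assigned to the i-th segmentation criteria,
    # e.g. one token for PKU-style and another for MSR-style segmentation.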

    def build_dataset(self, data, **kwargs):
        return MultiCriteriaTextTokenizingDataset(data, **kwargs)

    def on_config_ready(self, **kwargs):
        super().on_config_ready(**kwargs)
        # noinspection PyAttributeOutsideInit
        if 'criteria_token_map' not in self.config:
            unused_tokens = [f'[unused{i}]' for i in range(1, 100)]
            ids = self.transformer_tokenizer.convert_tokens_to_ids(unused_tokens)
            self.config.unused_tokens = dict((x, ids[i]) for i, x in enumerate(unused_tokens)
                                             if ids[i] != self.transformer_tokenizer.unk_token_id)
            self.config.criteria_token_map = dict()
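    # In a BERT-style vocabulary the [unused1] ... [unused99] tokens above typically resolve to
    # the reserved ids 1-99; any token that falls back to the unk id is skipped. criteria_token_map
    # starts out empty and is filled while the training data is transformed: append_criteria_token
    # assigns one of these unused tokens to each new criteria it encounters (hence the count logged
    # in build_vocabs below).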

    def last_transform(self):
        transforms = super().last_transform()
        transforms.append(functools.partial(append_criteria_token,
                                            criteria_tokens=self.config.unused_tokens,
                                            criteria_token_map=self.config.criteria_token_map))
        return transforms

    def build_vocabs(self, trn, logger, **kwargs):
        super().build_vocabs(trn, logger, **kwargs)
        logger.info(f'criteria[{len(self.config.criteria_token_map)}] = {list(self.config.criteria_token_map)}')

    def feed_batch(self, batch: dict):
        x, mask = TransformerTagger.feed_batch(self, batch)
        # strip [CLS], [SEP] and [unused_i]
        return x[:, 1:-2, :], mask
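    # For a single sentence the encoder output covers [CLS] c_1 ... c_n [SEP] [unused_i], so
    # dropping the first position and the last two leaves exactly the n character positions
    # that the BMES tagging head scores.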

    def build_samples(self, data: List[str], criteria=None, **kwargs):
        if not criteria:
            criteria = next(iter(self.config.criteria_token_map.keys()))
        else:
            assert criteria in self.config.criteria_token_map, \
                f'Unsupported criteria {criteria}. Choose one from {list(self.config.criteria_token_map.keys())}'
        samples = super().build_samples(data, **kwargs)
        for sample in samples:
            sample['criteria'] = criteria
        return samples
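    # At prediction time a specific segmentation standard can be requested with criteria='...'
    # (any criteria seen during training); when omitted, the first entry of criteria_token_map
    # is used as the default.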

    def build_metric(self, **kwargs):
        metrics = MetricDict()
        for criteria in self.config.criteria_token_map:
            metrics[criteria] = F1()
        return metrics
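    # One span-level F1 is kept per criteria, so evaluation reports a separate score for each
    # segmentation standard; MetricDict combines them into the single aggregate score used for
    # model selection.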

    def update_metrics(self, metric, logits, y, mask, batch, prediction):
        for p, g, c in zip(prediction, self.tag_to_span(batch['tag']), batch['criteria']):
            pred = set(p)
            gold = set(g)
            metric[c](pred, gold)

    def fit(self, trn_data, dev_data, save_dir, transformer, average_subwords=False, word_dropout: float = 0.2,
            hidden_dropout=None, layer_dropout=0, scalar_mix=None, mix_embedding: int = 0, grad_norm=5.0,
            transformer_grad_norm=None, lr=5e-5,
            transformer_lr=None, transformer_layers=None, gradient_accumulation=1,
            adam_epsilon=1e-8, weight_decay=0, warmup_steps=0.1, crf=False, reduction='sum',
            batch_size=32, sampler_builder: SamplerBuilder = None, epochs=30, patience=5, token_key=None,
            tagging_scheme='BMES', delimiter=None,
            max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False, transform=None,
            logger=None, devices: Union[float, int, List[int]] = None, **kwargs):
        return super().fit(**merge_locals_kwargs(locals(), kwargs))
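
# A minimal training / prediction sketch. The corpus paths, the save directory and the 'pku'
# criteria name below are hypothetical placeholders (the data must be in the format expected by
# MultiCriteriaTextTokenizingDataset), and forwarding criteria='...' from __call__ down to
# build_samples is assumed here rather than guaranteed.
if __name__ == '__main__':
    tokenizer = MultiCriteriaTransformerTaggingTokenizer()
    tokenizer.fit(trn_data='data/mcws/train.tsv',  # placeholder path
                  dev_data='data/mcws/dev.tsv',  # placeholder path
                  save_dir='data/model/mcws_bert',  # placeholder save directory
                  transformer='bert-base-chinese')
    # Pick a segmentation standard seen during training when predicting:
    print(tokenizer(['商品和服务'], criteria='pku'))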