Source code for hanlp.components.lemmatizer
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-12-08 18:35
from typing import List, Dict, Any

from hanlp.common.transform import TransformList
from hanlp.components.mtl.tasks.pos import TransformerTagging
from hanlp.components.parsers.ud.lemma_edit import gen_lemma_rule, apply_lemma_rule
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp_common.document import Document


def add_lemma_rules_to_sample(sample: dict):
    # Rewrite gold lemmas (carried in the 'tag' field) as lemma edit rules, so
    # that lemmatization becomes an ordinary tagging task over rule labels.
    # The placeholder '_' (no lemma annotated) is passed through unchanged.
    if 'tag' in sample and 'lemma' not in sample:
        lemma_rules = [gen_lemma_rule(word, lemma)
                       if lemma != "_" else "_"
                       for word, lemma in zip(sample['token'], sample['tag'])]
        sample['lemma'] = sample['tag'] = lemma_rules
    return sample
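

# A minimal demo of the conversion above (illustrative only; _demo_* helpers
# are not part of HanLP): gen_lemma_rule encodes how a form is turned into its
# lemma as a rule string, and apply_lemma_rule replays such a rule, so the two
# are inverses on gold data.
def _demo_add_lemma_rules():
    sample = {'token': ['ran', 'dogs'], 'tag': ['run', 'dog']}
    rules = add_lemma_rules_to_sample(sample)['lemma']
    # Replaying each predicted rule on its token recovers the gold lemma.
    assert [apply_lemma_rule(t, r)
            for t, r in zip(['ran', 'dogs'], rules)] == ['run', 'dog']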


class TransformerLemmatizer(TransformerTagger):

    def __init__(self, **kwargs) -> None:
        """A lemmatizer that casts lemmatization as the tagging of lemma edit
        rules, using a transformer as the encoder.

        Args:
            **kwargs: Predefined config.
        """
        super().__init__(**kwargs)

    def build_dataset(self, data, transform=None, **kwargs):
        # Ensure transform is a list, then append the rule conversion so gold
        # lemmas are rewritten into edit rules when the dataset is built.
        if not isinstance(transform, list):
            transform = TransformList()
        transform.append(add_lemma_rules_to_sample)
        return super().build_dataset(data, transform, **kwargs)

    def prediction_to_human(self, pred, vocab: List[str], batch, token=None):
        if token is None:
            token = batch['token']
        # The tagger predicts one lemma edit rule per token; replay each rule
        # on its token to recover the lemma. Numeric tokens are kept verbatim,
        # since a number is its own lemma.
        rules = super().prediction_to_human(pred, vocab, batch)
        for token_per_sent, rule_per_sent in zip(token, rules):
            yield [t if t.isdigit() else apply_lemma_rule(t, r)
                   for t, r in zip(token_per_sent, rule_per_sent)]
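

# A hedged usage sketch (illustrative only; file names are placeholders, and
# fit()/predict() are inherited from TransformerTagger, so their exact
# signatures and the expected data format follow that class):
#
#   lemmatizer = TransformerLemmatizer()
#   lemmatizer.fit('trn.tsv', 'dev.tsv', 'save_dir',
#                  transformer='bert-base-cased')
#   lemmatizer.predict([['He', 'ran']])  # e.g. -> [['he', 'run']]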


class TransformerTaggingLemmatizer(TransformerTagging):

    def transform_batch(self, batch: Dict[str, Any], results: Dict[str, Any] = None, cls_is_bos=False,
                        sep_is_eos=False) -> Dict[str, Any]:
        # Nothing to adjust for this task: return the batch untouched.
        return batch

    def finalize_document(self, doc: Document, task_name: str):
        tok = doc.get_by_prefix('tok')
        if tok:
            for tokens, lemmas in zip(tok, doc.get_by_prefix('lem')):
                if len(''.join(tokens)) == len(''.join(lemmas)):
                    # The lemmas were predicted per character; regroup them
                    # into token-level lemmas by each token's length.
                    mapped_lemmas = []
                    offset = 0
                    for token in tokens:
                        mapped_lemmas.append(''.join(lemmas[offset:offset + len(token)]))
                        offset += len(token)
                    lemmas.clear()
                    lemmas.extend(mapped_lemmas)
        super().finalize_document(doc, task_name)
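

# The character-to-token regrouping in finalize_document, as a standalone
# sketch (illustrative only; not part of HanLP): when lemmas were predicted
# per character, slice the lemma list by each token's character length.
def _demo_regroup(tokens, char_lemmas):
    mapped, offset = [], 0
    for token in tokens:
        mapped.append(''.join(char_lemmas[offset:offset + len(token)]))
        offset += len(token)
    return mapped

# e.g. _demo_regroup(['他们', '跑'], ['他', '们', '跑']) == ['他们', '跑']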