# Source code for hanlp.components.lemmatizer
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-12-08 18:35
from typing import List
from hanlp.common.transform import TransformList
from hanlp.components.parsers.ud.lemma_edit import gen_lemma_rule, apply_lemma_rule
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
def add_lemma_rules_to_sample(sample: dict):
    """Replace the gold lemmas stored under ``'tag'`` with lemma edit rules.

    The conversion only happens when the sample has a ``'tag'`` field and no
    ``'lemma'`` field yet. Each ``(token, tag)`` pair is turned into a rule by
    ``gen_lemma_rule``; the placeholder ``"_"`` is kept verbatim. The resulting
    list is stored under both ``'lemma'`` and ``'tag'`` (the very same list
    object).

    Args:
        sample: A sample dict; it is modified in place.

    Returns:
        The same ``sample`` object that was passed in.
    """
    # Nothing to do for unlabeled samples or samples that were already converted.
    if 'lemma' in sample or 'tag' not in sample:
        return sample

    rules = []
    for form, gold in zip(sample['token'], sample['tag']):
        # "_" marks a missing lemma, so no rule is generated for it.
        rules.append("_" if gold == "_" else gen_lemma_rule(form, gold))

    sample['lemma'] = sample['tag'] = rules
    return sample
class TransformerLemmatizer(TransformerTagger):
    """Lemmatizer built on top of :class:`TransformerTagger`.

    Training samples have their lemmas rewritten into lemma rules (see
    ``add_lemma_rules_to_sample``), and predicted rules are applied back to the
    tokens in :meth:`prediction_to_human` to produce the lemmas.
    """

    def __init__(self, **kwargs) -> None:
        """A transition based lemmatizer using transformer as encoder.

        Args:
            **kwargs: Predefined config.
        """
        super().__init__(**kwargs)

    def build_dataset(self, data, transform=None, **kwargs):
        """Build the dataset with the lemma-rule conversion appended to its transforms.

        Args:
            data: The data forwarded unchanged to the parent ``build_dataset``.
            transform: Either a list of transforms (extended in place), a single
                callable transform, or ``None``.
            **kwargs: Extra arguments forwarded to the parent ``build_dataset``.

        Returns:
            Whatever the parent ``build_dataset`` returns.
        """
        if not isinstance(transform, list):
            supplied = transform
            transform = TransformList()
            # A single callable used to be silently discarded here; keep it as
            # the first step so the caller's transform still runs.
            if callable(supplied):
                transform.append(supplied)
        transform.append(add_lemma_rules_to_sample)
        return super().build_dataset(data, transform, **kwargs)

    def prediction_to_human(self, pred, vocab: List[str], batch, token=None):
        """Turn predicted lemma rules into lemmas, one sentence at a time.

        Args:
            pred: Raw predictions, decoded by the parent ``prediction_to_human``.
            vocab: The tag vocabulary passed to the parent implementation.
            batch: The batch the predictions were made on.
            token: Tokens per sentence; defaults to ``batch['token']`` when ``None``.

        Yields:
            A list of lemmas for each sentence, obtained by applying each
            predicted rule to its token with ``apply_lemma_rule``.
        """
        if token is None:
            token = batch['token']
        rules = super().prediction_to_human(pred, vocab, batch)
        for sent_tokens, sent_rules in zip(token, rules):
            yield [apply_lemma_rule(form, rule) for form, rule in zip(sent_tokens, sent_rules)]