# Source code for hanlp.components.taggers.transformers.transformer_tagger
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-06-15 20:55
import logging
from typing import Union, List
import torch
from torch import nn
from torch.utils.data import DataLoader
from hanlp.common.dataset import PadSequenceDataLoader, SamplerBuilder, TransformableDataset
from hanlp.common.structure import History
from hanlp.common.transform import FieldLength, TransformList
from hanlp.common.vocab import Vocab
from hanlp.components.classifiers.transformer_classifier import TransformerComponent
from hanlp.components.taggers.tagger import Tagger
from hanlp.datasets.ner.loaders.tsv import TSVTaggingDataset
from hanlp.layers.crf.crf import CRF
from hanlp.layers.embeddings.embedding import EmbeddingDim, Embedding
from hanlp.layers.transformers.encoder import TransformerEncoder
from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
from hanlp.utils.time_util import CountdownTimer
from hanlp.utils.torch_util import clip_grad_norm, lengths_to_mask, filter_state_dict_safely
from hanlp_common.util import merge_locals_kwargs
# noinspection PyAbstractClass
class TransformerTaggingModel(nn.Module):
def __init__(self,
encoder: TransformerEncoder,
num_labels,
crf=False,
secondary_encoder=None,
extra_embeddings: EmbeddingDim = None) -> None:
"""
A shallow tagging model that uses a pretrained transformer as its encoder, followed by a linear classification layer.
Args:
encoder: A pretrained transformer encoder.
num_labels: Size of the tagset.
crf: ``True`` to enable a CRF layer.
secondary_encoder: An optional secondary encoder (e.g., an RNN) applied on top of the transformer outputs.
extra_embeddings: Extra embeddings which will be concatenated to the encoder outputs.
"""
super().__init__()
self.encoder = encoder
self.secondary_encoder = secondary_encoder
self.extra_embeddings = extra_embeddings
# noinspection PyUnresolvedReferences
feature_size = encoder.transformer.config.hidden_size
if extra_embeddings:
feature_size += extra_embeddings.get_output_dim()
self.classifier = nn.Linear(feature_size, num_labels)
self.crf = CRF(num_labels) if crf else None
def forward(self, lens: torch.LongTensor, input_ids, token_span, token_type_ids=None, batch=None):
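# Forward pass: build a token-level mask from sequence lengths, pool subword
# representations back into token representations via `token_span`, optionally
# run a secondary encoder and concatenate extra embeddings, then project the
# features to per-token tag logits.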
mask = lengths_to_mask(lens)
x = self.encoder(input_ids, token_span=token_span, token_type_ids=token_type_ids)
if self.secondary_encoder:
x = self.secondary_encoder(x, mask=mask)
if self.extra_embeddings:
# noinspection PyCallingNonCallable
embed = self.extra_embeddings(batch, mask=mask)
x = torch.cat([x, embed], dim=-1)
x = self.classifier(x)
return x, mask
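# Note: when `crf` is enabled, the scores returned by `forward` serve as emission
# scores; CRF decoding and the CRF loss are applied outside of this module, by the
# owning tagger component's loss and decoding routines.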
class TransformerTagger(TransformerComponent, Tagger):
def __init__(self, **kwargs) -> None:
"""A simple tagger using a linear layer with an optional CRF (:cite:`lafferty2001conditional`) layer for
any tagging task, including PoS tagging and many others.
Args:
**kwargs: Not used.
"""
super().__init__(**kwargs)
self._tokenizer_transform = None
self.model: TransformerTaggingModel = None
# noinspection PyMethodOverriding
def fit_dataloader(self,
trn: DataLoader,
criterion,
optimizer,
metric,
logger: logging.Logger,
history: History,
gradient_accumulation=1,
grad_norm=None,
transformer_grad_norm=None,
teacher: Tagger = None,
kd_criterion=None,
temperature_scheduler=None,
ratio_width=None,
eval_trn=True,
**kwargs):
optimizer, scheduler = optimizer
if teacher:
scheduler, lambda_scheduler = scheduler
else:
lambda_scheduler = None
self.model.train()
timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
total_loss = 0
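# Standard training loop: run the model, compute the (optionally scaled) loss,
# mix in a distillation loss when a teacher is given, backpropagate, and perform
# an optimizer step once every `gradient_accumulation` batches.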
for idx, batch in enumerate(trn):
out, mask = self.feed_batch(batch)
y = batch['tag_id']
loss = self.compute_loss(criterion, out, y, mask)
if gradient_accumulation and gradient_accumulation > 1:
loss /= gradient_accumulation
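# Knowledge distillation: teacher logits are computed without gradients, and the
# final loss interpolates the hard-label loss with the KD loss using a weight
# taken from `lambda_scheduler`.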
if teacher:
with torch.no_grad():
out_T, _ = teacher.feed_batch(batch)
# noinspection PyNoneFunctionAssignment
kd_loss = self.compute_distill_loss(kd_criterion, out, out_T, mask, temperature_scheduler)
_lambda = float(lambda_scheduler)
loss = _lambda * loss + (1 - _lambda) * kd_loss
loss.backward()
total_loss += loss.item()
if eval_trn:
prediction = self.decode_output(out, mask, batch)
self.update_metrics(metric, out, y, mask, batch, prediction)
if history.step(gradient_accumulation):
self._step(optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler)
report = f'loss: {total_loss / (idx + 1):.4f} {metric if eval_trn else ""}'
timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
del loss
del out
del mask
def _step(self, optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler):
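# One optimizer step: clip gradients (with a separate norm for the transformer
# encoder), update parameters, advance the scheduler(s), and reset gradients.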
clip_grad_norm(self.model, grad_norm, self.model.encoder.transformer, transformer_grad_norm)
optimizer.step()
scheduler.step()
if lambda_scheduler:
lambda_scheduler.step()
optimizer.zero_grad()
def compute_distill_loss(self, kd_criterion, out_S, out_T, mask, temperature_scheduler):
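# Keep only logits at valid token positions, ask the scheduler for a temperature
# based on the student/teacher logits, then apply the distillation criterion.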
logits_S = out_S[mask]
logits_T = out_T[mask]
temperature = temperature_scheduler(logits_S, logits_T)
return kd_criterion(logits_S, logits_T, temperature)
def build_model(self, training=True, extra_embeddings: Embedding = None, finetune=False, logger=None,
**kwargs) -> torch.nn.Module:
model = TransformerTaggingModel(
self.build_transformer(training=training),
len(self.vocabs.tag),
self.config.crf,
self.config.get('secondary_encoder', None),
extra_embeddings=extra_embeddings.module(self.vocabs) if extra_embeddings else None,
)
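# Fine-tuning from an existing checkpoint: keep only parameters whose shapes match,
# report the ones that could not be restored, and copy the classifier rows of
# previously seen labels when the tagset size has changed.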
if finetune and self.model:
model_state = model.state_dict()
load_state = self.model.state_dict()
safe_state = filter_state_dict_safely(model_state, load_state)
missing_params = model_state.keys() - safe_state.keys()
if missing_params:
logger.info(f'The following parameters were missing from the checkpoint: '
f'{", ".join(sorted(missing_params))}.')
model.load_state_dict(safe_state, strict=False)
n = self.model.classifier.bias.size(0)
if model.classifier.bias.size(0) != n:
model.classifier.weight.data[:n, :] = self.model.classifier.weight.data[:n, :]
model.classifier.bias.data[:n] = self.model.classifier.bias.data[:n]
return model
# noinspection PyMethodOverriding
def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None,
sampler_builder: SamplerBuilder = None, gradient_accumulation=1,
extra_embeddings: Embedding = None, transform=None, max_seq_len=None, **kwargs) -> DataLoader:
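# Dataset pipeline: wrap raw data in a TSV tagging dataset (unless it is already a
# TransformableDataset), attach transforms (user transform, extra-embedding transform,
# transformer tokenizer, vocab + field length), build vocabularies on the first pass,
# optionally prune overlong sentences, then batch with a padding dataloader and an
# optional custom sampler.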
if isinstance(data, TransformableDataset):
dataset = data
else:
args = dict((k, self.config.get(k, None)) for k in
['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'])
dataset = self.build_dataset(data, **args)
if self.config.token_key is None:
self.config.token_key = next(iter(dataset[0]))
logger.info(
f'Guess [bold][blue]token_key={self.config.token_key}[/blue][/bold] according to the '
f'training dataset: [blue]{dataset}[/blue]')
if transform:
dataset.append_transform(transform)
if extra_embeddings:
dataset.append_transform(extra_embeddings.transform(self.vocabs))
dataset.append_transform(self.tokenizer_transform)
dataset.append_transform(self.last_transform())
if not isinstance(data, list):
dataset.purge_cache()
if self.vocabs.mutable:
self.build_vocabs(dataset, logger)
if isinstance(data, str) and max_seq_len:
token_key = self.config.token_key
dataset.prune(lambda x: len(x[token_key]) > max_seq_len, logger)
if sampler_builder is not None:
sampler = sampler_builder.build([len(x[f'{self.config.token_key}_input_ids']) for x in dataset], shuffle,
gradient_accumulation=gradient_accumulation if shuffle else 1)
else:
sampler = None
return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
def build_dataset(self, data, transform=None, **kwargs):
return TSVTaggingDataset(data, transform=transform, **kwargs)
def last_transform(self):
transforms = TransformList(self.vocabs, FieldLength(self.config.token_key))
return transforms
@property
def tokenizer_transform(self) -> TransformerSequenceTokenizer:
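# Lazily construct the subword tokenizer transform; `ret_token_span=True` keeps the
# subword offsets of each token so the encoder can pool subwords back into tokens.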
if not self._tokenizer_transform:
self._tokenizer_transform = TransformerSequenceTokenizer(self.transformer_tokenizer,
self.config.token_key,
ret_token_span=True)
return self._tokenizer_transform
def build_vocabs(self, trn, logger, **kwargs):
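# Iterating over the training set once applies the attached transforms, which lets
# the tag vocab collect every label; the longest sequence is tracked only for logging.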
if 'tag' not in self.vocabs:
self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
timer = CountdownTimer(len(trn))
max_seq_len = 0
token_key = self.config.token_key
for each in trn:
max_seq_len = max(max_seq_len, len(each[token_key]))
timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
self.vocabs.tag.set_unk_as_safe_unk()
self.vocabs.lock()
self.vocabs.summary(logger)
# noinspection PyMethodOverriding
def fit(self,
trn_data,
dev_data,
save_dir,
transformer,
average_subwords=False,
word_dropout: float = 0.2,
hidden_dropout=None,
layer_dropout=0,
scalar_mix=None,
mix_embedding: int = 0,
grad_norm=5.0,
transformer_grad_norm=None,
lr=5e-5,
transformer_lr=None,
transformer_layers=None,
gradient_accumulation=1,
adam_epsilon=1e-6,
weight_decay=0,
warmup_steps=0.1,
secondary_encoder=None,
extra_embeddings: Embedding = None,
crf=False,
reduction='sum',
batch_size=32,
sampler_builder: SamplerBuilder = None,
epochs=3,
patience=5,
token_key=None,
max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False,
transform=None,
logger=None,
devices: Union[float, int, List[int]] = None,
**kwargs):
return super().fit(**merge_locals_kwargs(locals(), kwargs))
def feed_batch(self, batch: dict):
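# The tokenizer transform yields either (input_ids, token_span) or input_ids alone;
# fetch the token-level lengths and optional token_type_ids from the batch and run
# the model to obtain per-token logits and the corresponding mask.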
features = [batch[k] for k in self.tokenizer_transform.output_key]
if len(features) == 2:
input_ids, token_span = features
else:
input_ids, token_span = features[0], None
lens = batch[f'{self.config.token_key}_length']
x, mask = self.model(lens, input_ids, token_span, batch.get(f'{self.config.token_key}_token_type_ids'),
batch=batch)
return x, mask
# noinspection PyMethodOverriding
def distill(self,
teacher: str,
trn_data,
dev_data,
save_dir,
transformer: str,
batch_size=None,
temperature_scheduler='flsw',
epochs=None,
devices=None,
logger=None,
seed=None,
**kwargs):
return super().distill(**merge_locals_kwargs(locals(), kwargs))
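# A minimal usage sketch (not part of the original module). The file paths and the
# transformer name below are placeholders; any TSV corpus with one token-tag pair per
# line and a blank line between sentences should work with the default
# TSVTaggingDataset loader.
#
#     tagger = TransformerTagger()
#     tagger.fit(trn_data='data/pos/train.tsv',
#                dev_data='data/pos/dev.tsv',
#                save_dir='data/model/pos',
#                transformer='bert-base-cased',
#                epochs=3,
#                crf=False)
#     tagger.load('data/model/pos')
#     print(tagger(['HanLP', 'is', 'a', 'multilingual', 'NLP', 'library', '.']))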