# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-05-20 13:12
import logging
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from hanlp.common.dataset import PadSequenceDataLoader, SortingSampler, TransformableDataset
from hanlp_common.configurable import Configurable
from hanlp.common.transform import EmbeddingNamedTransform
from hanlp.common.vocab import Vocab
from hanlp.components.taggers.rnn.rnntaggingmodel import RNNTaggingModel
from hanlp.components.taggers.tagger import Tagger
from hanlp.datasets.ner.loaders.tsv import TSVTaggingDataset
from hanlp.layers.embeddings.embedding import Embedding
from hanlp.layers.embeddings.util import build_word2vec_with_vocab
from hanlp.utils.time_util import CountdownTimer
from hanlp_common.util import merge_locals_kwargs, merge_dict


class RNNTagger(Tagger):

    def __init__(self, **kwargs) -> None:
        """An old-school tagger using non-contextualized embeddings and an RNN as the context layer.

        Args:
            **kwargs: Predefined config.
        """
        super().__init__(**kwargs)
        self.model: RNNTaggingModel = None
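
    # A minimal usage sketch (paths and data are hypothetical, and exact call signatures may differ;
    # the TSV loader typically expects one "token<TAB>tag" pair per line, with blank lines between sentences):
    #
    #   tagger = RNNTagger()
    #   tagger.fit('data/ner/train.tsv', 'data/ner/dev.tsv', 'save/rnn_tagger',
    #              embed=100, rnn_hidden=256, crf=True)
    #   tagger.predict([['I', 'love', 'HanLP']])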

    # noinspection PyMethodOverriding
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion,
                              optimizer,
                              metric,
                              save_dir,
                              logger,
                              patience,
                              **kwargs):
        max_e, max_metric = 0, -1
        criterion = self.build_criterion()
        timer = CountdownTimer(epochs)
        ratio_width = len(f'{len(trn)}/{len(trn)}')
        scheduler = self.build_scheduler(**merge_dict(self.config, optimizer=optimizer, overwrite=True))
        if not patience:
            patience = epochs
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width)
            loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger)
            if scheduler:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(dev_metric.score)
                else:
                    scheduler.step(epoch)
            report_patience = f'Patience: {epoch - max_e}/{patience}'
            # Save the model if it is the best so far
            if dev_metric > max_metric:
                self.save_weights(save_dir)
                max_e, max_metric = epoch, dev_metric
                report_patience = '[red]Saved[/red] '
            stop = epoch - max_e >= patience
            if stop:
                timer.stop()
            timer.log(f'{report_patience} lr: {optimizer.param_groups[0]["lr"]:.4f}',
                      ratio_percentage=False, newline=True, ratio=False)
            if stop:
                break
        timer.stop()
        if max_e != epoch:
            self.load_weights(save_dir)
        logger.info(f"Max score of dev is {max_metric.score:.2%} at epoch {max_e}")
        logger.info(f"{timer.elapsed_human} elapsed, average time of each epoch is {timer.elapsed_average_human}")

    def build_scheduler(self, optimizer, anneal_factor, anneal_patience, **kwargs):
        scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer,
                                                         factor=anneal_factor,
                                                         patience=anneal_patience,
                                                         mode='max') if anneal_factor and anneal_patience else None
        return scheduler

    def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, ratio_width=None,
                       **kwargs):
        self.model.train()
        timer = CountdownTimer(len(trn))
        total_loss = 0
        for idx, batch in enumerate(trn):
            optimizer.zero_grad()
            out, mask = self.feed_batch(batch)
            y = batch['tag_id']
            loss = self.compute_loss(criterion, out, y, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            optimizer.step()
            total_loss += loss.item()
            prediction = self.decode_output(out, mask, batch)
            self.update_metrics(metric, out, y, mask, batch, prediction)
            # Report the running average loss rather than the last batch's loss
            timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                      ratio_width=ratio_width)
            # Drop references to the graph early to free memory before the next batch
            del loss
            del out
            del mask

    def feed_batch(self, batch):
        x = batch[f'{self.config.token_key}_id']
        out, mask = self.model(x, **batch, batch=batch)
        return out, mask

    # noinspection PyMethodOverriding
    def build_model(self, rnn_input, rnn_hidden, drop, crf, **kwargs) -> torch.nn.Module:
        vocabs = self.vocabs
        token_embed = self._convert_embed()
        if isinstance(token_embed, EmbeddingNamedTransform):
            # The embedding is applied as a data transform, so only its output dimension is passed on
            token_embed = token_embed.output_dim
        elif isinstance(token_embed, Embedding):
            # Instantiate the embedding module against the vocabs
            token_embed = token_embed.module(vocabs=vocabs)
        else:
            # Treat it as a word2vec specification (a dimension or a pretrained embedding identifier)
            token_embed = build_word2vec_with_vocab(token_embed, vocabs[self.config.token_key])
        model = RNNTaggingModel(token_embed, rnn_input, rnn_hidden, len(vocabs['tag']), drop, crf)
        return model

    def _convert_embed(self):
        embed = self.config['embed']
        if isinstance(embed, dict):
            self.config['embed'] = embed = Configurable.from_config(embed)
        return embed

    def build_dataloader(self, data, batch_size, shuffle, device, logger=None, **kwargs) -> DataLoader:
        vocabs = self.vocabs
        token_embed = self._convert_embed()
        dataset = data if isinstance(data, TransformableDataset) else self.build_dataset(data, transform=[vocabs])
        if vocabs.mutable:
            # Before building vocabs, let the embedding submit its transform. Some embeddings may opt out
            # because their transforms are not relevant to vocabs.
            if isinstance(token_embed, Embedding):
                transform = token_embed.transform(vocabs=vocabs)
                if transform:
                    dataset.transform.insert(-1, transform)
            self.build_vocabs(dataset, logger)
        if isinstance(token_embed, Embedding):
            # Vocabs are built, now add all transforms to the pipeline. Be careful about redundant ones.
            transform = token_embed.transform(vocabs=vocabs)
            if transform and transform not in dataset.transform:
                dataset.transform.insert(-1, transform)
        sampler = SortingSampler([len(sample[self.config.token_key]) for sample in dataset], batch_size,
                                 shuffle=shuffle)
        return PadSequenceDataLoader(dataset,
                                     device=device,
                                     batch_sampler=sampler,
                                     vocabs=vocabs)
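
    # For illustration only (a sketch, not guaranteed by the library): with the default token_key='token',
    # each batch yielded by the dataloader above is expected to be a dict roughly of the form
    #   {'token': [['I', 'love', 'HanLP'], ...],   # raw tokens
    #    'token_id': LongTensor[batch, max_len],   # padded ids consumed by feed_batch()
    #    'tag': [['O', 'O', 'S-ORG'], ...],        # gold tags (during training)
    #    'tag_id': LongTensor[batch, max_len]}     # padded ids consumed by compute_loss()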

    def build_dataset(self, data, transform):
        return TSVTaggingDataset(data, transform)

    def build_vocabs(self, dataset, logger):
        self.vocabs.tag = Vocab(unk_token=None, pad_token=None)
        self.vocabs[self.config.token_key] = Vocab()
        # Iterating over the dataset runs its transforms, which populate the mutable vocabs
        for each in dataset:
            pass
        self.vocabs.lock()
        self.vocabs.summary(logger)

    def fit(self, trn_data, dev_data, save_dir,
            batch_size=50,
            epochs=100,
            embed=100,
            rnn_input=None,
            rnn_hidden=256,
            drop=0.5,
            lr=0.001,
            patience=10,
            crf=True,
            optimizer='adam',
            token_key='token',
            tagging_scheme=None,
            anneal_factor: float = 0.5,
            anneal_patience=2,
            devices=None, logger=None, verbose=True, **kwargs):
        return super().fit(**merge_locals_kwargs(locals(), kwargs))

    def _id_to_tags(self, ids):
        batch = []
        vocab = self.vocabs['tag'].idx_to_token
        for b in ids:
            batch.append([])
            for i in b:
                batch[-1].append(vocab[i])
        return batch

    def write_output(self, yhat, y, mask, batch, prediction, output):
        pass