Source code for hanlp.components.parsers.constituency.crf_constituency_parser

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-11-28 21:24
import logging
from typing import Union, List

import torch
from phrasetree.tree import Tree
from torch.utils.data import DataLoader

from hanlp_common.constant import BOS, EOS, IDX
from hanlp.common.dataset import TransformableDataset, SamplerBuilder, PadSequenceDataLoader
from hanlp.common.structure import History
from hanlp.common.torch_component import TorchComponent
from hanlp.common.transform import FieldLength, TransformList
from hanlp.common.vocab import VocabWithNone
from hanlp.components.classifiers.transformer_classifier import TransformerComponent
from hanlp.datasets.parsing.loaders.constituency_dataset import ConstituencyDataset, unpack_tree_to_features, \
    build_tree, factorize, remove_subcategory
from hanlp.components.parsers.constituency.crf_constituency_model import CRFConstituencyDecoder, CRFConstituencyModel
from hanlp.metrics.parsing.span import SpanMetric
from hanlp.utils.time_util import CountdownTimer
from hanlp.utils.torch_util import clip_grad_norm
from hanlp_common.util import merge_locals_kwargs, merge_dict, reorder


class CRFConstituencyParser(TorchComponent):
    def __init__(self, **kwargs) -> None:
        """Two-stage CRF Parsing (:cite:`ijcai2020-560`).

        Args:
            **kwargs: Predefined config.
        """
        super().__init__(**kwargs)
        self.model: CRFConstituencyModel = self.model

    def build_optimizer(self, trn, **kwargs):
        # noinspection PyCallByClass,PyTypeChecker
        return TransformerComponent.build_optimizer(self, trn, **kwargs)

    def build_criterion(self, decoder=None, **kwargs):
        return decoder

    def build_metric(self, **kwargs):
        return SpanMetric()

    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                              logger: logging.Logger, devices, ratio_width=None, patience=0.5, eval_trn=True,
                              **kwargs):
        if isinstance(patience, float):
            patience = int(patience * epochs)
        best_epoch, best_metric = 0, -1
        timer = CountdownTimer(epochs)
        history = History()
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history, ratio_width=ratio_width,
                                eval_trn=eval_trn, **self.config)
            loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger=logger, ratio_width=ratio_width)
            timer.update()
            report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
            if dev_metric > best_metric:
                best_epoch, best_metric = epoch, dev_metric
                self.save_weights(save_dir)
                report += ' [red](saved)[/red]'
            else:
                report += f' ({epoch - best_epoch})'
                if epoch - best_epoch >= patience:
                    report += ' early stop'
            logger.info(report)
            if epoch - best_epoch >= patience:
                break
        if not best_epoch:
            self.save_weights(save_dir)
        elif best_epoch != epoch:
            self.load_weights(save_dir)
        logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
        logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
        logger.info(f"{timer.elapsed_human} elapsed")

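    # Editorial note (not part of the original source): a float ``patience`` is interpreted
    # as a fraction of ``epochs``. For instance, with the defaults passed down from ``fit``
    # (patience=0.5, epochs=30):
    #
    #     patience = int(0.5 * 30)  # -> 15 epochs without dev improvement before early stop
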
    # noinspection PyMethodOverriding
    def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric: SpanMetric, logger: logging.Logger,
                       history: History, gradient_accumulation=1, grad_norm=None, ratio_width=None, eval_trn=True,
                       **kwargs):
        optimizer, scheduler = optimizer
        metric.reset()
        self.model.train()
        timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
        total_loss = 0
        for idx, batch in enumerate(trn):
            out, mask = self.feed_batch(batch)
            y = batch['chart_id']
            loss, span_probs = self.compute_loss(out, y, mask)
            if gradient_accumulation and gradient_accumulation > 1:
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            if eval_trn:
                prediction = self.decode_output(out, mask, batch, span_probs)
                self.update_metrics(metric, batch, prediction)
            if history.step(gradient_accumulation):
                self._step(optimizer, scheduler, grad_norm)
                report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn \
                    else f'loss: {total_loss / (idx + 1):.4f}'
                timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
            del loss
            del out
            del mask

    def decode_output(self, out, mask, batch, span_probs=None, decoder=None, tokens=None):
        s_span, s_label = out
        if not decoder:
            decoder = self.model.decoder
        if mask.any().item():
            if span_probs is None:
                if self.config.mbr:
                    s_span = decoder.crf(s_span, mask, mbr=True)
            else:
                s_span = span_probs
            chart_preds = decoder.decode(s_span, s_label, mask)
        else:
            chart_preds = [[]] * len(tokens)
        idx_to_token = self.vocabs.chart.idx_to_token
        if tokens is None:
            tokens = batch.get('token_', None)  # Use the original tokens if any
            if tokens is None:
                tokens = batch['token']
            tokens = [x[1:-1] for x in tokens]
        trees = [build_tree(token, [(i, j, idx_to_token[label]) for i, j, label in chart])
                 for token, chart in zip(tokens, chart_preds)]
        # probs = [prob[:i - 1, 1:i].cpu() for i, prob in zip(lens, s_span.unbind())]
        return trees

    def update_metrics(self, metric, batch, prediction):
        # Add pre-terminals (pos tags) back to prediction for safe factorization (deletion based on pos)
        for pred, gold in zip(prediction, batch['constituency']):
            pred: Tree = pred
            gold: Tree = gold
            for p, g in zip(pred.subtrees(lambda t: t.height() == 2), gold.pos()):
                token, pos = g
                p: Tree = p
                assert p.label() == '_'
                p.set_label(pos)
        metric([factorize(tree, self.config.delete, self.config.equal) for tree in prediction],
               [factorize(tree, self.config.delete, self.config.equal) for tree in batch['constituency']])
        return metric

    def feed_batch(self, batch: dict):
        mask = self.compute_mask(batch)
        s_span, s_label = self.model(batch)
        return (s_span, s_label), mask

    def compute_mask(self, batch, offset=1):
        lens = batch['token_length'] - offset
        seq_len = lens.max()
        mask = lens.new_tensor(range(seq_len)) < lens.view(-1, 1, 1)
        mask = mask & mask.new_ones(seq_len, seq_len).triu_(1)
        return mask

    def compute_loss(self, out, y, mask, crf_decoder=None):
        if not crf_decoder:
            crf_decoder = self.model.decoder
        loss, span_probs = crf_decoder.loss(out[0], out[1], y, mask, self.config.mbr)
        if loss < 0:  # weird negative loss
            loss *= 0
        return loss, span_probs

    def _step(self, optimizer, scheduler, grad_norm):
        clip_grad_norm(self.model, grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

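    # Editorial sketch (not part of the original source) of what ``compute_mask`` above
    # returns for a single sentence of 4 tokens wrapped with BOS/EOS, i.e.
    # batch['token_length'] == 6. With offset=1 the chart side is 5 (number of words + 1),
    # and the boolean mask of shape [1, 5, 5] is True strictly above the diagonal,
    # marking every candidate span (i, j) with i < j:
    #
    #     lens = torch.tensor([6]) - 1                        # tensor([5])
    #     mask = lens.new_tensor(range(5)) < lens.view(-1, 1, 1)
    #     mask = mask & mask.new_ones(5, 5).triu_(1)          # strictly upper-triangular chart
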
    @torch.no_grad()
    def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, metric=None, output=None, **kwargs):
        self.model.eval()
        total_loss = 0
        if not metric:
            metric = self.build_metric()
        else:
            metric.reset()
        timer = CountdownTimer(len(data))
        for idx, batch in enumerate(data):
            out, mask = self.feed_batch(batch)
            y = batch['chart_id']
            loss, span_probs = self.compute_loss(out, y, mask)
            total_loss += loss.item()
            prediction = self.decode_output(out, mask, batch, span_probs)
            self.update_metrics(metric, batch, prediction)
            timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
                      ratio_width=ratio_width)
        total_loss /= len(data)
        if output:
            output.close()
        return total_loss, metric

    # noinspection PyMethodOverriding
    def build_model(self, encoder, training=True, **kwargs) -> torch.nn.Module:
        decoder = CRFConstituencyDecoder(n_labels=len(self.vocabs.chart), n_hidden=encoder.get_output_dim(), **kwargs)
        encoder = encoder.module(vocabs=self.vocabs, training=training)
        return CRFConstituencyModel(encoder, decoder)

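    # Editorial note (not part of the original source): ``build_model`` plugs a contextual
    # encoder into ``CRFConstituencyModel``, whose ``CRFConstituencyDecoder`` scores spans
    # and their labels separately, following the two-stage parser of :cite:`ijcai2020-560`:
    # bracketing is decided first by a TreeCRF over spans, then each predicted span is labeled.
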
    def build_dataloader(self, data, batch_size, sampler_builder: SamplerBuilder = None, gradient_accumulation=1,
                         shuffle=False, device=None, logger: logging.Logger = None, **kwargs) -> DataLoader:
        if isinstance(data, TransformableDataset):
            dataset = data
        else:
            transform = self.config.encoder.transform()
            if self.config.get('transform', None):
                transform = TransformList(self.config.transform, transform)
            dataset = self.build_dataset(data, transform, logger)
        if self.vocabs.mutable:
            # noinspection PyTypeChecker
            self.build_vocabs(dataset, logger)
        lens = [len(x['token_input_ids']) for x in dataset]
        if sampler_builder:
            sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
        else:
            sampler = None
        return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)

    def predict(self, data: Union[str, List[str]], **kwargs):
        if not data:
            return []
        flat = self.input_is_flat(data)
        if flat:
            data = [data]
        samples = self.build_samples(data)
        dataloader = self.build_dataloader(samples, device=self.device, **kwargs)
        outputs = []
        orders = []
        for idx, batch in enumerate(dataloader):
            out, mask = self.feed_batch(batch)
            prediction = self.decode_output(out, mask, batch, span_probs=None)
            # prediction = [x[0] for x in prediction]
            outputs.extend(prediction)
            orders.extend(batch[IDX])
        outputs = reorder(outputs, orders)
        if flat:
            return outputs[0]
        return outputs

    def input_is_flat(self, data):
        return isinstance(data[0], str)

    def build_samples(self, data):
        return [{'token': [BOS] + token + [EOS]} for token in data]

    # noinspection PyMethodOverriding
    def fit(self, trn_data, dev_data, save_dir, encoder, lr=5e-5, transformer_lr=None, adam_epsilon=1e-8,
            weight_decay=0, warmup_steps=0.1, grad_norm=1.0, n_mlp_span=500, n_mlp_label=100, mlp_dropout=.33,
            batch_size=None, batch_max_tokens=5000, gradient_accumulation=1, epochs=30, patience=0.5, mbr=True,
            sampler_builder=None, delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP', ',', 'S1'),
            equal=(('ADVP', 'PRT'),), no_subcategory=True, eval_trn=True, transform=None, devices=None,
            logger=None, seed=None, **kwargs):
        if isinstance(equal, tuple):
            equal = dict(equal)
        return super().fit(**merge_locals_kwargs(locals(), kwargs))

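    # Editorial note (not part of the original source): ``delete`` and ``equal`` mirror the
    # usual EVALB-style evaluation settings; constituents whose labels are in ``delete`` are
    # dropped when trees are factorized into spans, while ``equal`` merges label pairs. The
    # tuple default is converted to a dict before training, e.g.:
    #
    #     dict((('ADVP', 'PRT'),))  # -> {'ADVP': 'PRT'}
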
    def build_dataset(self, data, transform, logger=None):
        _transform = [
            unpack_tree_to_features,
            self.vocabs,
            FieldLength('token'),
            transform
        ]
        if self.config.get('no_subcategory', True):
            _transform.insert(0, remove_subcategory)
        dataset = ConstituencyDataset(data,
                                      transform=_transform,
                                      cache=isinstance(data, str))
        return dataset

    def build_vocabs(self, trn, logger, **kwargs):
        self.vocabs.chart = VocabWithNone(pad_token=None, unk_token=None)
        timer = CountdownTimer(len(trn))
        max_seq_len = 0
        for each in trn:
            max_seq_len = max(max_seq_len, len(each['token_input_ids']))
            timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
        self.vocabs.chart.set_unk_as_safe_unk()
        self.vocabs.lock()
        self.vocabs.summary(logger)
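

# Illustrative usage sketch (editorial addition, not part of the original module). The
# checkpoint path and tokens below are hypothetical and assume a parser previously trained
# and saved with CRFConstituencyParser.fit(); extra keyword arguments to predict() are
# forwarded to build_dataloader().
if __name__ == '__main__':
    parser = CRFConstituencyParser()
    parser.load('/path/to/crf_constituency_model')  # hypothetical save_dir from a prior fit()
    # predict() takes one tokenized sentence (flat list of str) or a list of sentences;
    # BOS/EOS are added internally by build_samples().
    tree = parser.predict(['The', 'stock', 'market', 'rebounded', '.'], batch_size=1)
    print(tree)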