# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-07-28 15:12
import functools
from collections import Counter
from typing import Union, List
import torch
from torch import nn
from hanlp_common.constant import UNK
from hanlp.common.transform import TransformList
from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
from hanlp_common.conll import CoNLLUWord, CoNLLSentence
from hanlp.datasets.parsing.semeval15 import unpack_deps_to_head_deprel, append_bos_to_form_pos
from hanlp.metrics.parsing.labeled_f1 import LabeledF1
from hanlp_common.util import merge_locals_kwargs
class BiaffineSemanticDependencyParser(BiaffineDependencyParser):
def __init__(self) -> None:
r"""Implementation of "Stanford's graph-based neural dependency parser at
the conll 2017 shared task" (:cite:`dozat2017stanford`) and "Establishing Strong Baselines for the New Decade"
(:cite:`he-choi-2019`).
"""
super().__init__()
def get_pad_dict(self):
return {'arc': False}
    def build_metric(self, **kwargs):
return LabeledF1()
# noinspection PyMethodOverriding
def build_dataset(self, data, transform=None):
transforms = TransformList(functools.partial(append_bos_to_form_pos, pos_key='UPOS'),
functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel))
if transform:
transforms.append(transform)
return super(BiaffineSemanticDependencyParser, self).build_dataset(data, transforms)
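    # A rough sketch (assumed field values, not from the original source) of what
    # the two transforms above do to one sample before batching:
    #
    #   {'FORM': ['He', 'works'], 'UPOS': ['PRON', 'VERB'],
    #    'DEPS': ['2:nsubj', '0:root']}
    #
    # append_bos_to_form_pos prepends a BOS token so that index 0 can act as the
    # virtual root, and unpack_deps_to_head_deprel expands each DEPS entry
    # ("head:rel" pairs joined by "|") into per-token head indicators ('arc')
    # and relation labels ('rel'), filling non-arc positions with pad_rel.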
    def build_criterion(self, **kwargs):
return nn.BCEWithLogitsLoss(), nn.CrossEntropyLoss()
def feed_batch(self, batch):
arc_scores, rel_scores, mask, puncts = super().feed_batch(batch)
mask = self.convert_to_3d_mask(arc_scores, mask)
puncts = self.convert_to_3d_puncts(puncts, mask)
return arc_scores, rel_scores, mask, puncts
@staticmethod
def convert_to_3d_puncts(puncts, mask):
if puncts is not None:
puncts = puncts.unsqueeze(-1).expand_as(mask)
return puncts
@staticmethod
def convert_to_3d_mask(arc_scores, mask):
        # Expand the 2-D padding mask into a 3-D (dependent, head) mask
        mask = mask.unsqueeze(-1).expand_as(arc_scores).clone()
        mask[:, :, 1:] = mask[:, :, 1:] & mask.transpose(1, 2)[:, :, 1:]  # Keep the 1st column because it predicts root
return mask
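    # A minimal runnable sketch of the mask expansion above; the batch size and
    # sequence length are illustrative, not from the original source:
    #
    #   import torch
    #   mask = torch.tensor([[True, True, True, False]])      # [batch, seq], last token is padding
    #   mask3d = mask.unsqueeze(-1).expand(1, 4, 4).clone()   # mask3d[b, i, j] = mask[b, i]
    #   mask3d[:, :, 1:] &= mask3d.transpose(1, 2)[:, :, 1:]  # heads must be real tokens too
    #   # mask3d[0, i, j] is now True iff dependent i is a real token and head j
    #   # is either a real token or the virtual root in column 0.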
def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask: torch.BoolTensor, criterion, batch=None):
bce, ce = criterion
arc_scores, arcs = arc_scores[mask], arcs[mask]
rel_scores, rels = rel_scores[mask], rels[mask]
rel_scores, rels = rel_scores[arcs], rels[arcs]
arc_loss = bce(arc_scores, arcs.to(torch.float))
arc_loss_interpolation = self.config.get('arc_loss_interpolation', None)
loss = arc_loss * arc_loss_interpolation if arc_loss_interpolation else arc_loss
if len(rels):
rel_loss = ce(rel_scores, rels)
loss += (rel_loss * (1 - arc_loss_interpolation)) if arc_loss_interpolation else rel_loss
if arc_loss_interpolation:
loss *= 2
return loss
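    # With arc_loss_interpolation = lambda set, the branches above compose to
    #   loss = 2 * (lambda * arc_loss + (1 - lambda) * rel_loss)
    # The trailing factor of 2 keeps the magnitude comparable to the plain sum
    # arc_loss + rel_loss (the two coincide at lambda = 0.5); with the default
    # lambda = 0.4 from fit(), this works out to 0.8 * arc_loss + 1.2 * rel_loss.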
def cache_dataset(self, dataset, timer, training=False, logger=None):
if not self.config.apply_constraint:
            return super(BiaffineSemanticDependencyParser, self).cache_dataset(dataset, timer, training, logger)
num_roots = Counter()
no_zero_head = True
root_rels = Counter()
for each in dataset:
if training:
num_roots[sum([x[0] for x in each['arc']])] += 1
no_zero_head &= all([x != '_' for x in each['DEPS']])
head_is_root = [i for i in range(len(each['arc'])) if each['arc'][i][0]]
if head_is_root:
for i in head_is_root:
root_rels[each['rel'][i][0]] += 1
timer.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
if training:
if self.config.single_root is None:
self.config.single_root = len(num_roots) == 1 and num_roots.most_common()[0][0] == 1
if self.config.no_zero_head is None:
self.config.no_zero_head = no_zero_head
root_rel = root_rels.most_common()[0][0]
self.config.root_rel_id = self.vocabs['rel'].get_idx(root_rel)
if logger:
logger.info(f'Training set properties: [blue]single_root = {self.config.single_root}[/blue], '
f'[blue]no_zero_head = {no_zero_head}[/blue], '
f'[blue]root_rel = {root_rel}[/blue]')
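    # The statistics gathered above turn into decoding-time constraints:
    #   single_root  - every training graph attaches exactly one token to the
    #                  virtual root, i.e. num_roots collapses to {1: n_sentences}
    #   no_zero_head - no token in the training data has an empty DEPS field,
    #                  so every token must receive at least one head
    #   root_rel_id  - the relation most frequently used on root arcs, later
    #                  forced onto the predicted root attachment in decode()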
def decode(self, arc_scores, rel_scores, mask, batch=None):
        # Ban self-loops: write -inf onto the diagonal so no token can pick itself as a head
        diag = torch.arange(0, arc_scores.size(1), device=arc_scores.device).view(1, 1, -1).expand(
            arc_scores.size(0), -1, -1)
        inf = float('inf')
        arc_scores.scatter_(dim=1, index=diag, value=-inf)
if self.config.apply_constraint:
if self.config.get('single_root', False):
arc_scores[~mask] = -inf # the biaffine decoder doesn't apply 3d mask for now
root_mask = arc_scores[:, :, 0].argmax(dim=-1).unsqueeze_(-1).expand_as(arc_scores[:, :, 0])
arc_scores[:, :, 0] = -inf
arc_scores[:, :, 0].scatter_(dim=-1, index=root_mask, value=inf)
root_rel_id = self.config.root_rel_id
rel_scores[:, :, 0, root_rel_id] = inf
rel_scores[:, :, 1:, root_rel_id] = -inf
arc_scores_T = arc_scores.transpose(-1, -2)
arc = ((arc_scores > 0) & (arc_scores_T < arc_scores))
if self.config.get('no_zero_head', False):
arc_scores_T[arc] = -inf # avoid cycle between a pair of nodes
arc_scores_fix = arc_scores_T.argmax(dim=-2).unsqueeze_(-1).expand_as(arc_scores)
arc.scatter_(dim=-1, index=arc_scores_fix, value=True)
else:
arc = arc_scores > 0
rel = rel_scores.argmax(dim=-1)
return arc, rel
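    # Reading the decoded output: `arc` is a [batch, seq, seq] boolean tensor in
    # which arc[b, d, h] marks token d taking h as one of its heads (h = 0 being
    # the virtual root), and rel[b, d, h] is the best label id for that pair.
    # An illustrative (made-up) row for a 3-token sentence plus BOS:
    #
    #   arc[0, 2] == tensor([True, False, False, True])
    #
    # i.e. token 2 attaches to both the root and token 3, with labels
    # self.vocabs['rel'][rel[0, 2, 0]] and self.vocabs['rel'][rel[0, 2, 3]].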
def collect_outputs_extend(self, predictions, arc_preds, rel_preds, lens, mask):
predictions.extend(zip(arc_preds.tolist(), rel_preds.tolist(), mask.tolist()))
# all_arcs.extend(seq.tolist() for seq in arc_preds[mask].split([x * x for x in lens]))
# all_rels.extend(seq.tolist() for seq in rel_preds[mask].split([x * x for x in lens]))
def predictions_to_human(self, predictions, outputs, data, use_pos, conll=True):
for d, (arcs, rels, masks) in zip(data, predictions):
sent = CoNLLSentence()
for idx, (cell, a, r) in enumerate(zip(d, arcs[1:], rels[1:])):
if use_pos:
token, pos = cell
else:
token, pos = cell, None
heads = [i for i in range(len(d) + 1) if a[i]]
deprels = [self.vocabs['rel'][r[i]] for i in range(len(d) + 1) if a[i]]
sent.append(
CoNLLUWord(idx + 1, token, upos=pos, head=None, deprel=None, deps=list(zip(heads, deprels))))
outputs.append(sent)
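    # Each emitted CoNLLUWord keeps HEAD/DEPREL empty and stores the full head
    # set in the DEPS column, so a token attached to both the root and token 3
    # serializes in standard CoNLL-U as, e.g.:
    #
    #   ID  FORM   LEMMA  UPOS  XPOS  FEATS  HEAD  DEPREL  DEPS           MISC
    #   2   works  _      VERB  _     _      _     _       0:root|3:conj  _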
    def fit(self, trn_data, dev_data, save_dir,
feat=None,
n_embed=100,
pretrained_embed=None,
transformer=None,
average_subwords=False,
word_dropout: float = 0.2,
transformer_hidden_dropout=None,
layer_dropout=0,
mix_embedding: int = None,
embed_dropout=.33,
n_lstm_hidden=400,
n_lstm_layers=3,
hidden_dropout=.33,
n_mlp_arc=500,
n_mlp_rel=100,
mlp_dropout=.33,
arc_dropout=None,
rel_dropout=None,
arc_loss_interpolation=0.4,
lr=2e-3,
transformer_lr=5e-5,
mu=.9,
nu=.9,
epsilon=1e-12,
clip=5.0,
decay=.75,
decay_steps=5000,
weight_decay=0,
warmup_steps=0.1,
separate_optimizer=True,
patience=100,
batch_size=None,
sampler_builder=None,
lowercase=False,
epochs=50000,
apply_constraint=False,
single_root=None,
no_zero_head=None,
punct=False,
min_freq=2,
logger=None,
verbose=True,
unk=UNK,
pad_rel=None,
max_sequence_length=512,
gradient_accumulation=1,
devices: Union[float, int, List[int]] = None,
transform=None,
**kwargs):
return super().fit(**merge_locals_kwargs(locals(), kwargs))
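

# A hypothetical end-to-end usage sketch; the data paths, save directory and
# transformer name below are placeholders, not part of the original source:
#
#   parser = BiaffineSemanticDependencyParser()
#   parser.fit('data/sdp/en.dm.train.conllu',
#              'data/sdp/en.dm.dev.conllu',
#              'models/dm-sdp',
#              transformer='bert-base-cased',
#              apply_constraint=True)
#   parser.evaluate('data/sdp/en.dm.test.conllu', 'models/dm-sdp')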