# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-07-09 18:13
import logging
from bisect import bisect
from typing import Union, List, Callable, Tuple, Dict, Any
from hanlp_common.constant import IDX
from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
import torch
from torch.utils.data import DataLoader
from hanlp.common.dataset import PadSequenceDataLoader, SortingSampler
from hanlp.common.torch_component import TorchComponent
from hanlp.common.transform import FieldLength
from hanlp.common.vocab import Vocab
from hanlp.components.srl.span_rank.inference_utils import srl_decode
from hanlp.components.srl.span_rank.span_ranking_srl_model import SpanRankingSRLModel
from hanlp.components.srl.span_rank.srl_eval_utils import compute_srl_f1
from hanlp.datasets.srl.loaders.conll2012 import CoNLL2012SRLDataset, filter_v_args, unpack_srl, \
group_pa_by_p
from hanlp.layers.embeddings.embedding import Embedding
from hanlp.metrics.f1 import F1
from hanlp_common.visualization import markdown_table
from hanlp.utils.time_util import CountdownTimer
from hanlp_common.util import merge_locals_kwargs, reorder
class SpanRankingSemanticRoleLabeler(TorchComponent):
def __init__(self, **kwargs) -> None:
"""An implementation of "Jointly Predicting Predicates and Arguments in Neural Semantic Role Labeling"
        (:cite:`he-etal-2018-jointly`). It generates candidate triples of (predicate, arg_start, arg_end) and ranks them.
Args:
**kwargs: Predefined config.
"""
super().__init__(**kwargs)
self.model: SpanRankingSRLModel = None
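    # Typical usage (illustrative sketch; the data, save_dir, embedding and context layer below
    # are placeholders, not values shipped with this module):
    #   srl = SpanRankingSemanticRoleLabeler()
    #   srl.fit(trn_data, dev_data, save_dir, embed=some_embedding, context_layer=some_encoder)
    #   srl.load(save_dir)
    #   srl.predict(['The', 'cat', 'sat'])  # tokens of a single sentence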
    def build_optimizer(self,
trn,
epochs,
lr,
adam_epsilon,
weight_decay,
warmup_steps,
transformer_lr,
**kwargs):
# noinspection PyProtectedMember
transformer = self._get_transformer()
if transformer:
num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
optimizer, scheduler = build_optimizer_scheduler_with_transformer(self.model,
transformer,
lr, transformer_lr,
num_training_steps, warmup_steps,
weight_decay, adam_epsilon)
else:
optimizer = torch.optim.Adam(self.model.parameters(), self.config.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer=optimizer,
mode='max',
factor=0.5,
patience=2,
verbose=True,
)
return optimizer, scheduler
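    # Design note: with a transformer embedding, the linear-warmup scheduler is stepped per
    # optimizer update (see _step below); otherwise ReduceLROnPlateau is stepped once per epoch
    # on the dev score (see execute_training_loop).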
    def _get_transformer(self):
        # model_ unwraps nn.DataParallel (if any) before accessing the embedding layer
        return getattr(self.model_.embed, 'transformer', None)
    def build_criterion(self, **kwargs):
pass
# noinspection PyProtectedMember
    def build_metric(self, **kwargs) -> Tuple[F1, F1]:
predicate_f1 = F1()
end_to_end_f1 = F1()
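        # predicate_f1 scores predicate identification alone; end_to_end_f1 scores complete
        # (predicate, argument span, label) tuples. See update_metrics for how each is fed.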
return predicate_f1, end_to_end_f1
    def execute_training_loop(self,
trn: DataLoader,
dev: DataLoader,
epochs,
criterion,
optimizer,
metric,
save_dir,
logger: logging.Logger,
devices,
**kwargs):
best_epoch, best_metric = 0, -1
predicate, end_to_end = metric
optimizer, scheduler = optimizer
timer = CountdownTimer(epochs)
ratio_width = len(f'{len(trn)}/{len(trn)}')
for epoch in range(1, epochs + 1):
logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
self.fit_dataloader(trn, criterion, optimizer, metric, logger,
linear_scheduler=scheduler if self._get_transformer() else None)
if dev:
self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
report = f'{timer.elapsed_human}/{timer.total_time_human}'
dev_score = end_to_end.score
if not self._get_transformer():
scheduler.step(dev_score)
if dev_score > best_metric:
self.save_weights(save_dir)
best_metric = dev_score
report += ' [red]saved[/red]'
timer.log(report, ratio_percentage=False, newline=True, ratio=False)
    def fit_dataloader(self,
trn: DataLoader,
criterion,
optimizer,
metric,
logger: logging.Logger,
linear_scheduler=None,
gradient_accumulation=1,
**kwargs):
self.model.train()
timer = CountdownTimer(len(trn) // gradient_accumulation)
total_loss = 0
self.reset_metrics(metric)
for idx, batch in enumerate(trn):
output_dict = self.feed_batch(batch)
self.update_metrics(batch, output_dict, metric)
            loss = output_dict['loss']
            loss = loss.sum()  # For data parallel
            if gradient_accumulation and gradient_accumulation > 1:
                # Scale before backward so accumulated gradients average rather than sum
                loss /= gradient_accumulation
            if torch.isnan(loss):  # w/ gold pred, some batches do not have PAs at all, resulting in empty scores
                loss = torch.zeros((1,), device=loss.device)
            else:
                loss.backward()
if self.config.grad_norm:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
            total_loss += loss.item()
            if (idx + 1) % gradient_accumulation == 0:
                self._step(optimizer, linear_scheduler)
                timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                          logger=logger)
            del loss
if len(trn) % gradient_accumulation:
self._step(optimizer, linear_scheduler)
return total_loss / timer.total
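    # Accumulation sketch: with gradient_accumulation=4 and 102 batches, backward() runs on
    # every batch, optimizer.step() fires on every 4th batch (25 updates), and the trailing
    # 2 batches are flushed by the final _step call above (since 102 % 4 != 0).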
def _step(self, optimizer, linear_scheduler):
optimizer.step()
optimizer.zero_grad()
if linear_scheduler:
linear_scheduler.step()
# noinspection PyMethodOverriding
    @torch.no_grad()
def evaluate_dataloader(self,
data: DataLoader,
criterion: Callable,
metric,
logger,
ratio_width=None,
output=False,
official=False,
confusion_matrix=False,
**kwargs):
self.model.eval()
self.reset_metrics(metric)
timer = CountdownTimer(len(data))
total_loss = 0
if official:
sentences = []
gold = []
pred = []
for batch in data:
output_dict = self.feed_batch(batch)
if official:
sentences += batch['token']
gold += batch['srl']
pred += output_dict['prediction']
self.update_metrics(batch, output_dict, metric)
loss = output_dict['loss']
total_loss += loss.item()
timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
logger=logger,
ratio_width=ratio_width)
del loss
if official:
scores = compute_srl_f1(sentences, gold, pred)
if logger:
if confusion_matrix:
labels = sorted(set(y for x in scores.label_confusions.keys() for y in x))
headings = ['GOLD↓PRED→'] + labels
matrix = []
                    for gold in labels:
                        row = [gold]
                        matrix.append(row)
                        for pred in labels:
                            row.append(scores.label_confusions.get((gold, pred), 0))
matrix = markdown_table(headings, matrix)
logger.info(f'{"Confusion Matrix": ^{len(matrix.splitlines()[0])}}')
logger.info(matrix)
headings = ['Settings', 'Precision', 'Recall', 'F1']
                rows = []
                for h, (p, r, f) in zip(['Unlabeled', 'Labeled', 'Official'], [
                    [scores.unlabeled_precision, scores.unlabeled_recall, scores.unlabeled_f1],
                    [scores.precision, scores.recall, scores.f1],
                    [scores.conll_precision, scores.conll_recall, scores.conll_f1],
                ]):
                    rows.append([h] + [f'{x:.2%}' for x in [p, r, f]])
                table = markdown_table(headings, rows)
logger.info(f'{"Scores": ^{len(table.splitlines()[0])}}')
logger.info(table)
else:
scores = metric
return total_loss / timer.total, scores
    def build_model(self,
training=True,
**kwargs) -> torch.nn.Module:
# noinspection PyTypeChecker
# embed: torch.nn.Embedding = self.config.embed.module(vocabs=self.vocabs)[0].embed
model = SpanRankingSRLModel(self.config,
self.config.embed.module(vocabs=self.vocabs, training=training),
self.config.context_layer,
len(self.vocabs.srl_label))
return model
# noinspection PyMethodOverriding
    def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger,
generate_idx=False, transform=None, **kwargs) -> DataLoader:
batch_max_tokens = self.config.batch_max_tokens
gradient_accumulation = self.config.get('gradient_accumulation', 1)
if batch_size:
batch_size //= gradient_accumulation
if batch_max_tokens:
batch_max_tokens //= gradient_accumulation
dataset = self.build_dataset(data, generate_idx, logger, transform)
sampler = SortingSampler([x['token_length'] for x in dataset],
batch_size=batch_size,
batch_max_tokens=batch_max_tokens,
shuffle=shuffle)
return PadSequenceDataLoader(batch_sampler=sampler,
device=device,
dataset=dataset)
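    # Batching note: SortingSampler groups sentences of similar token_length to reduce padding;
    # batch_size and batch_max_tokens are divided by gradient_accumulation above so the effective
    # batch size after accumulation stays what the caller asked for.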
def build_dataset(self, data, generate_idx, logger, transform=None):
dataset = CoNLL2012SRLDataset(data, transform=[filter_v_args, unpack_srl, group_pa_by_p],
doc_level_offset=self.config.doc_level_offset, generate_idx=generate_idx)
if transform:
dataset.append_transform(transform)
if isinstance(self.config.get('embed', None), Embedding):
transform = self.config.embed.transform(vocabs=self.vocabs)
if transform:
dataset.append_transform(transform)
dataset.append_transform(self.vocabs)
dataset.append_transform(FieldLength('token'))
if isinstance(data, str):
dataset.purge_cache() # Enable cache
if self.vocabs.mutable:
self.build_vocabs(dataset, logger)
return dataset
    def predict(self, data: Union[str, List[str]], batch_size: int = None, fmt='dict', **kwargs):
if not data:
return []
flat = self.input_is_flat(data)
if flat:
data = [data]
samples = []
for token in data:
sample = dict()
sample['token'] = token
samples.append(sample)
batch_size = batch_size or self.config.batch_size
dataloader = self.build_dataloader(samples, batch_size, False, self.device, None, generate_idx=True)
outputs = []
order = []
for batch in dataloader:
output_dict = self.feed_batch(batch)
outputs.extend(output_dict['prediction'])
order.extend(batch[IDX])
outputs = reorder(outputs, order)
if fmt == 'list':
outputs = self.format_dict_to_results(data, outputs)
if flat:
return outputs[0]
return outputs
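    # Output shapes: with fmt='dict' (the default), each sentence maps a predicate token index to
    # a list of (arg_start, arg_end, label) tuples with inclusive offsets; fmt='list' reshapes that
    # via format_dict_to_results below. A flat (single-sentence) input returns a single result
    # instead of a one-element list.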
@staticmethod
def format_dict_to_results(data, outputs, exclusive_offset=False, with_predicate=False, with_argument=False,
label_first=False):
results = []
for i in range(len(outputs)):
tokens = data[i]
output = []
for p, a in outputs[i].items():
# a: [(0, 0, 'ARG0')]
if with_predicate:
a.insert(bisect([x[0] for x in a], p), (p, p, 'PRED'))
if with_argument is not False:
a = [x + (tokens[x[0]:x[1] + 1],) for x in a]
if isinstance(with_argument, str):
a = [x[:-1] + (with_argument.join(x[-1]),) for x in a]
if exclusive_offset:
a = [(x[0], x[1] + 1) + x[2:] for x in a]
if label_first:
a = [tuple(reversed(x[2:])) + x[:2] for x in a]
output.append(a)
results.append(output)
return results
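    # Illustrative example (hypothetical 3-token sentence with one predicate at index 1):
    #   outputs = [{1: [(0, 0, 'ARG0'), (2, 2, 'ARG1')]}]
    #   with_predicate=True inserts (1, 1, 'PRED') into the span list in start-offset order;
    #   with_argument=' ' appends the argument tokens joined by ' ' to each tuple;
    #   exclusive_offset=True turns the inclusive end offset into an exclusive one;
    #   label_first=True moves the trailing fields in front of the offsets.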
def input_is_flat(self, data):
return isinstance(data[0], str)
# noinspection PyMethodOverriding
    def fit(self,
trn_data,
dev_data,
save_dir,
embed,
context_layer,
batch_size=40,
batch_max_tokens=700,
lexical_dropout=0.5,
dropout=0.2,
span_width_feature_size=20,
ffnn_size=150,
ffnn_depth=2,
argument_ratio=0.8,
predicate_ratio=0.4,
max_arg_width=30,
mlp_label_size=100,
enforce_srl_constraint=False,
use_gold_predicates=False,
doc_level_offset=True,
use_biaffine=False,
lr=1e-3,
transformer_lr=1e-5,
adam_epsilon=1e-6,
weight_decay=0.01,
warmup_steps=0.1,
grad_norm=5.0,
gradient_accumulation=1,
loss_reduction='sum',
transform=None,
devices=None,
logger=None,
seed=None,
**kwargs
):
return super().fit(**merge_locals_kwargs(locals(), kwargs))
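    # Hyperparameter note: warmup_steps=0.1 is passed through to the transformer scheduler
    # builder; a float here is presumably interpreted as a ratio of total training steps rather
    # than an absolute count. argument_ratio and predicate_ratio cap how many candidate argument
    # spans and predicates are kept per sentence relative to its length, following the pruning
    # scheme of He et al. (2018).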
    def build_vocabs(self, dataset, logger, **kwargs):
self.vocabs.srl_label = Vocab(pad_token=None, unk_token=None)
# Use null to indicate no relationship
self.vocabs.srl_label.add('<null>')
timer = CountdownTimer(len(dataset))
max_seq_len = 0
for each in dataset:
max_seq_len = max(max_seq_len, len(each['token_input_ids']))
timer.log(f'Building vocabs (max sequence length {max_seq_len}) [blink][yellow]...[/yellow][/blink]')
timer.stop()
timer.erase()
self.vocabs['srl_label'].set_unk_as_safe_unk()
self.vocabs.lock()
self.vocabs.summary(logger)
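    # Label id 0 is '<null>' (no relation), which is why decode_output keeps a span only when
    # its argmax label id is non-zero.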
def reset_metrics(self, metrics):
for each in metrics:
each.reset()
def report_metrics(self, loss, metrics):
predicate, end_to_end = metrics
return f'loss: {loss:.4f} predicate: {predicate.score:.2%} end_to_end: {end_to_end.score:.2%}'
def feed_batch(self, batch) -> Dict[str, Any]:
output_dict = self.model(batch)
prediction = self.decode_output(output_dict, batch, self.model.training)
output_dict['prediction'] = prediction
return output_dict
def decode_output(self, output_dict, batch, training=False):
idx_to_label = self.vocabs['srl_label'].idx_to_token
if training:
            # Use fast argmax decoding during training; the constrained srl_decode runs at inference time (else branch)
prediction = []
top_predicate_indices = output_dict['predicates'].tolist()
top_spans = torch.stack([output_dict['arg_starts'], output_dict['arg_ends']], dim=-1).tolist()
srl_mask = output_dict['srl_mask'].tolist()
srl_scores = output_dict['srl_scores']
pal_list = srl_scores.argmax(-1).tolist() if srl_scores.numel() else []
for n, (pal, predicate_indices, argument_spans) in enumerate(
zip(pal_list, top_predicate_indices, top_spans)):
srl_per_sentence = {}
for p, (al, predicate_index) in enumerate(zip(pal, predicate_indices)):
for a, (l, argument_span) in enumerate(zip(al, argument_spans)):
if l and srl_mask[n][p][a]:
args = srl_per_sentence.get(p, None)
if args is None:
args = srl_per_sentence[p] = []
args.append((*argument_span, idx_to_label[l]))
prediction.append(srl_per_sentence)
else:
prediction = srl_decode(batch['token_length'], output_dict, idx_to_label, self.config)
return prediction
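    # Resulting structure, e.g. {1: [(0, 0, 'ARG0'), (2, 4, 'ARG1')]}: each predicate token index
    # maps to its labeled argument spans with inclusive offsets.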
def update_metrics(self, batch: dict, output_dict: dict, metrics):
def unpack(y: dict):
return set((p, bel) for p, a in y.items() for bel in a)
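        # e.g. unpack({3: [(0, 1, 'ARG0')]}) == {(3, (0, 1, 'ARG0'))}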
predicate, end_to_end = metrics
for pred, gold in zip(output_dict['prediction'], batch['srl']):
predicate(pred.keys(), gold.keys())
end_to_end(unpack(pred), unpack(gold))