# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-06-20 19:55
import functools
import itertools
import logging
import os
from collections import defaultdict
from copy import copy
from itertools import chain
from typing import Union, List, Callable, Dict, Optional, Any, Iterable, Tuple
import numpy as np
import torch
from hanlp_common.constant import IDX, BOS, EOS
from hanlp_common.document import Document
from hanlp_common.util import merge_locals_kwargs, topological_sort, reorder, prefix_match
from hanlp_common.visualization import markdown_table
from toposort import toposort
from torch.utils.data import DataLoader
import hanlp.utils.torch_util
from hanlp.common.dataset import PadSequenceDataLoader, PrefetchDataLoader, CachedDataLoader
from hanlp.common.structure import History
from hanlp.common.torch_component import TorchComponent
from hanlp.common.transform import FieldLength, TransformList
from hanlp.components.mtl.tasks import Task
from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding, ContextualWordEmbeddingModule
from hanlp.layers.embeddings.embedding import Embedding
from hanlp.layers.transformers.utils import pick_tensor_for_each_token
from hanlp.metrics.metric import Metric
from hanlp.metrics.mtl import MetricDict
from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
from hanlp.utils.time_util import CountdownTimer
from hanlp.utils.torch_util import clip_grad_norm
class MultiTaskModel(torch.nn.Module):
def __init__(self,
encoder: torch.nn.Module,
scalar_mixes: torch.nn.ModuleDict,
decoders: torch.nn.ModuleDict,
use_raw_hidden_states: dict) -> None:
super().__init__()
self.use_raw_hidden_states = use_raw_hidden_states
self.encoder: ContextualWordEmbeddingModule = encoder
self.scalar_mixes = scalar_mixes
self.decoders = decoders
class MultiTaskDataLoader(DataLoader):
def __init__(self, training=True, tau: float = 0.8, **dataloaders) -> None:
# noinspection PyTypeChecker
super().__init__(None)
self.tau = tau
self.training = training
self.dataloaders: Dict[str, DataLoader] = dataloaders if dataloaders else {}
# self.iterators = dict((k, iter(v)) for k, v in dataloaders.items())
def __len__(self) -> int:
if self.dataloaders:
return sum(len(x) for x in self.dataloaders.values())
return 0
def __iter__(self):
if self.training:
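            # Sample one task per step with probability proportional to the temperature-scaled dataset sizes;
            # each dataloader is cycled so smaller tasks can be revisited within the same pass.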
sampling_weights, total_size = self.sampling_weights
task_names = list(self.dataloaders.keys())
iterators = dict((k, itertools.cycle(v)) for k, v in self.dataloaders.items())
for i in range(total_size):
task_name = np.random.choice(task_names, p=sampling_weights)
yield task_name, next(iterators[task_name])
else:
for task_name, dataloader in self.dataloaders.items():
for batch in dataloader:
yield task_name, batch
@property
def sampling_weights(self):
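        # Temperature-based task sampling: weight_i = |D_i|^tau / sum_j |D_j|^tau. With tau < 1 this
        # up-samples small tasks relative to purely proportional sampling.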
sampling_weights = self.sizes
total_size = sum(sampling_weights)
Z = sum(pow(v, self.tau) for v in sampling_weights)
sampling_weights = [pow(v, self.tau) / Z for v in sampling_weights]
return sampling_weights, total_size
@property
def sizes(self):
return [len(v) for v in self.dataloaders.values()]
class MultiTaskLearning(TorchComponent):
def __init__(self, **kwargs) -> None:
""" A multi-task learning (MTL) framework. It shares the same encoder across multiple decoders. These decoders
can have dependencies on each other which will be properly handled during decoding. To integrate a component
into this MTL framework, a component needs to implement the :class:`~hanlp.components.mtl.tasks.Task` interface.
This framework mostly follows the architecture of :cite:`clark-etal-2019-bam` and :cite:`he-choi-2021-stem`, with additional scalar mix
tricks (:cite:`kondratyuk-straka-2019-75`) allowing each task to attend to any subset of layers. We also
experimented with knowledge distillation on single tasks, the performance gain was nonsignificant on a large
dataset. In the near future, we have no plan to invest more efforts in distillation, since most datasets HanLP
uses are relatively large, and our hardware is relatively powerful.
Args:
**kwargs: Arguments passed to config.
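
        Examples:
            >>> # A hedged sketch of typical usage; the pretrained identifier is only illustrative,
            >>> # pick an actual one from ``hanlp.pretrained.mtl``.
            >>> import hanlp
            >>> mtl = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH)
            >>> doc = mtl('语言学是一门关于人类语言的科学研究。', tasks='tok')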
"""
super().__init__(**kwargs)
self.model: Optional[MultiTaskModel] = None
        self.tasks: Optional[Dict[str, Task]] = None
self.vocabs = None
    def build_dataloader(self,
data,
batch_size,
shuffle=False,
device=None,
logger: logging.Logger = None,
gradient_accumulation=1,
tau: float = 0.8,
prune=None,
prefetch=None,
tasks_need_custom_eval=None,
cache=False,
debug=False,
**kwargs) -> DataLoader:
# This method is only called during training or evaluation but not prediction
dataloader = MultiTaskDataLoader(training=shuffle, tau=tau)
for i, (task_name, task) in enumerate(self.tasks.items()):
encoder_transform, transform = self.build_transform(task)
training = None
if data == 'trn':
if debug:
_data = task.dev
else:
_data = task.trn
training = True
elif data == 'dev':
_data = task.dev
training = False
elif data == 'tst':
_data = task.tst
training = False
else:
_data = data
if isinstance(data, str):
logger.info(f'[yellow]{i + 1} / {len(self.tasks)}[/yellow] Building [blue]{data}[/blue] dataset for '
f'[cyan]{task_name}[/cyan] ...')
# Adjust Tokenizer according to task config
config = copy(task.config)
config.pop('transform', None)
task_dataloader: DataLoader = task.build_dataloader(_data, transform, training, device, logger,
tokenizer=encoder_transform.tokenizer,
gradient_accumulation=gradient_accumulation,
cache=isinstance(data, str), **config)
# if prune:
# # noinspection PyTypeChecker
# task_dataset: TransformDataset = task_dataloader.dataset
# size_before = len(task_dataset)
# task_dataset.prune(prune)
# size_after = len(task_dataset)
# num_pruned = size_before - size_after
# logger.info(f'Pruned [yellow]{num_pruned} ({num_pruned / size_before:.1%})[/yellow] '
# f'samples out of {size_before}.')
if cache and data in ('trn', 'dev'):
task_dataloader: CachedDataLoader = CachedDataLoader(
task_dataloader,
f'{cache}/{os.getpid()}-{data}-{task_name.replace("/", "-")}-cache.pt' if isinstance(cache,
str) else None
)
dataloader.dataloaders[task_name] = task_dataloader
if data == 'trn':
sampling_weights, total_size = dataloader.sampling_weights
headings = ['task', '#batches', '%batches', '#scaled', '%scaled', '#epoch']
matrix = []
min_epochs = []
for (task_name, dataset), weight in zip(dataloader.dataloaders.items(), sampling_weights):
epochs = len(dataset) / weight / total_size
matrix.append(
[f'{task_name}', len(dataset), f'{len(dataset) / total_size:.2%}', int(total_size * weight),
f'{weight:.2%}', f'{epochs:.2f}'])
min_epochs.append(epochs)
longest = int(torch.argmax(torch.tensor(min_epochs)))
table = markdown_table(headings, matrix)
rows = table.splitlines()
cells = rows[longest + 2].split('|')
cells[-2] = cells[-2].replace(f'{min_epochs[longest]:.2f}',
f'[bold][red]{min_epochs[longest]:.2f}[/red][/bold]')
rows[longest + 2] = '|'.join(cells)
logger.info(f'[bold][yellow]{"Samples Distribution": ^{len(rows[0])}}[/yellow][/bold]')
logger.info('\n'.join(rows))
if prefetch and (data == 'trn' or not tasks_need_custom_eval):
dataloader = PrefetchDataLoader(dataloader, prefetch=prefetch)
return dataloader
def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
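        # Chain the task-specific tokenizer wrapper around the shared encoder transform, then append a
        # field recording token lengths, which downstream code uses to build masks.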
encoder: ContextualWordEmbedding = self.config.encoder
encoder_transform: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
length_transform = FieldLength('token', 'token_length')
transform = TransformList(encoder_transform, length_transform)
extra_transform = self.config.get('transform', None)
if extra_transform:
transform.insert(0, extra_transform)
return encoder_transform, transform
    def build_optimizer(self,
trn,
epochs,
adam_epsilon,
weight_decay,
warmup_steps,
lr,
encoder_lr,
**kwargs):
model = self.model_
encoder = model.encoder
num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
encoder_parameters = list(encoder.parameters())
parameter_groups: List[Dict[str, Any]] = []
decoders = model.decoders
decoder_optimizers = dict()
for k, task in self.tasks.items():
decoder: torch.nn.Module = decoders[k]
decoder_parameters = list(decoder.parameters())
if task.separate_optimizer:
decoder_optimizers[k] = task.build_optimizer(decoder=decoder, **kwargs)
else:
task_lr = task.lr or lr
parameter_groups.append({"params": decoder_parameters, 'lr': task_lr})
parameter_groups.append({"params": encoder_parameters, 'lr': encoder_lr})
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
no_decay_parameters = set()
for n, p in model.named_parameters():
if any(nd in n for nd in no_decay):
no_decay_parameters.add(p)
no_decay_by_lr = defaultdict(list)
for group in parameter_groups:
_lr = group['lr']
ps = group['params']
group['params'] = decay_parameters = []
group['weight_decay'] = weight_decay
for p in ps:
if p in no_decay_parameters:
no_decay_by_lr[_lr].append(p)
else:
decay_parameters.append(p)
for _lr, ps in no_decay_by_lr.items():
parameter_groups.append({"params": ps, 'lr': _lr, 'weight_decay': 0.0})
# noinspection PyTypeChecker
from transformers import optimization
encoder_optimizer = optimization.AdamW(
parameter_groups,
lr=lr,
weight_decay=weight_decay,
eps=adam_epsilon,
)
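        # ``warmup_steps`` is a fraction of the total training steps (e.g. 0.1), so it is converted to an
        # absolute number of warmup steps for the linear schedule.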
encoder_scheduler = optimization.get_linear_schedule_with_warmup(encoder_optimizer,
num_training_steps * warmup_steps,
num_training_steps)
return encoder_optimizer, encoder_scheduler, decoder_optimizers
    def build_criterion(self, **kwargs):
return dict((k, v.build_criterion(decoder=self.model_.decoders[k], **kwargs)) for k, v in self.tasks.items())
    def build_metric(self, **kwargs):
metrics = MetricDict()
for key, task in self.tasks.items():
metric = task.build_metric(**kwargs)
assert metric, f'Please implement `build_metric` of {type(task)} to return a metric.'
metrics[key] = metric
return metrics
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
logger: logging.Logger, devices, patience=0.5, **kwargs):
if isinstance(patience, float):
patience = int(patience * epochs)
best_epoch, best_metric = 0, -1
timer = CountdownTimer(epochs)
ratio_width = len(f'{len(trn)}/{len(trn)}')
epoch = 0
history = History()
for epoch in range(1, epochs + 1):
logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
self.fit_dataloader(trn, criterion, optimizer, metric, logger, history, ratio_width=ratio_width,
**self.config)
if dev:
self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width, input='dev')
report = f'{timer.elapsed_human}/{timer.total_time_human}'
dev_score = metric.score
if dev_score > best_metric:
self.save_weights(save_dir)
best_metric = dev_score
best_epoch = epoch
report += ' [red]saved[/red]'
else:
report += f' ({epoch - best_epoch})'
if epoch - best_epoch >= patience:
report += ' early stop'
break
timer.log(report, ratio_percentage=False, newline=True, ratio=False)
for d in [trn, dev]:
self._close_dataloader(d)
if best_epoch != epoch:
logger.info(f'Restoring best model saved [red]{epoch - best_epoch}[/red] epochs ago')
self.load_weights(save_dir)
return best_metric
def _close_dataloader(self, d):
if isinstance(d, PrefetchDataLoader):
d.close()
if hasattr(d.dataset, 'close'):
self._close_dataloader(d.dataset)
elif isinstance(d, CachedDataLoader):
d.close()
elif isinstance(d, MultiTaskDataLoader):
for d in d.dataloaders.values():
self._close_dataloader(d)
# noinspection PyMethodOverriding
    def fit_dataloader(self,
trn: DataLoader,
criterion,
optimizer,
metric,
logger: logging.Logger,
history: History,
ratio_width=None,
gradient_accumulation=1,
encoder_grad_norm=None,
decoder_grad_norm=None,
patience=0.5,
eval_trn=False,
**kwargs):
self.model.train()
encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
timer = CountdownTimer(len(trn))
total_loss = 0
self.reset_metrics(metric)
model = self.model_
encoder_parameters = model.encoder.parameters()
decoder_parameters = model.decoders.parameters()
for idx, (task_name, batch) in enumerate(trn):
decoder_optimizer = decoder_optimizers.get(task_name, None)
output_dict, _ = self.feed_batch(batch, task_name)
loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
self.tasks[task_name])
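            # Scale the loss so that gradients accumulated over several micro-batches average out to the
            # same magnitude as one large batch.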
if gradient_accumulation and gradient_accumulation > 1:
loss /= gradient_accumulation
loss.backward()
total_loss += float(loss.item())
if history.step(gradient_accumulation):
if self.config.get('grad_norm', None):
clip_grad_norm(model, self.config.grad_norm)
if encoder_grad_norm:
torch.nn.utils.clip_grad_norm_(encoder_parameters, encoder_grad_norm)
if decoder_grad_norm:
torch.nn.utils.clip_grad_norm_(decoder_parameters, decoder_grad_norm)
encoder_optimizer.step()
encoder_optimizer.zero_grad()
encoder_scheduler.step()
if decoder_optimizer:
if isinstance(decoder_optimizer, tuple):
decoder_optimizer, decoder_scheduler = decoder_optimizer
else:
decoder_scheduler = None
decoder_optimizer.step()
decoder_optimizer.zero_grad()
if decoder_scheduler:
decoder_scheduler.step()
if eval_trn:
self.decode_output(output_dict, batch, task_name)
self.update_metrics(batch, output_dict, metric, task_name)
timer.log(self.report_metrics(total_loss / (timer.current + 1), metric if eval_trn else None),
ratio_percentage=None,
ratio_width=ratio_width,
logger=logger)
del loss
del output_dict
return total_loss / timer.total
def report_metrics(self, loss, metrics: MetricDict):
return f'loss: {loss:.4f} {metrics.cstr()}' if metrics else f'loss: {loss:.4f}'
# noinspection PyMethodOverriding
    @torch.no_grad()
def evaluate_dataloader(self,
data: MultiTaskDataLoader,
criterion,
metric: MetricDict,
logger,
ratio_width=None,
input: str = None,
**kwargs):
self.model.eval()
self.reset_metrics(metric)
tasks_need_custom_eval = self.config.get('tasks_need_custom_eval', None)
tasks_need_custom_eval = tasks_need_custom_eval or {}
tasks_need_custom_eval = dict((k, None) for k in tasks_need_custom_eval)
for each in tasks_need_custom_eval:
tasks_need_custom_eval[each] = data.dataloaders.pop(each)
timer = CountdownTimer(len(data) + len(tasks_need_custom_eval))
total_loss = 0
for idx, (task_name, batch) in enumerate(data):
output_dict, _ = self.feed_batch(batch, task_name)
loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
self.tasks[task_name])
total_loss += loss.item()
self.decode_output(output_dict, batch, task_name)
self.update_metrics(batch, output_dict, metric, task_name)
timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
logger=logger,
ratio_width=ratio_width)
del loss
del output_dict
for task_name, dataset in tasks_need_custom_eval.items():
task = self.tasks[task_name]
decoder = self.model_.decoders[task_name]
task.evaluate_dataloader(
dataset, task.build_criterion(decoder=decoder),
metric=metric[task_name],
input=task.dev if input == 'dev' else task.tst,
split=input,
decoder=decoder,
h=functools.partial(self._encode, task_name=task_name,
cls_is_bos=task.cls_is_bos, sep_is_eos=task.sep_is_eos)
)
data.dataloaders[task_name] = dataset
timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
logger=logger,
ratio_width=ratio_width)
return total_loss / timer.total, metric, data
    def build_model(self, training=False, **kwargs) -> torch.nn.Module:
tasks = self.tasks
encoder: ContextualWordEmbedding = self.config.encoder
transformer_module = encoder.module(training=training)
encoder_size = transformer_module.get_output_dim()
scalar_mixes = torch.nn.ModuleDict()
decoders = torch.nn.ModuleDict()
use_raw_hidden_states = dict()
for task_name, task in tasks.items():
decoder = task.build_model(encoder_size, training=training, **task.config)
assert decoder, f'Please implement `build_model` of {type(task)} to return a decoder.'
decoders[task_name] = decoder
if task.scalar_mix:
scalar_mix = task.scalar_mix.build()
scalar_mixes[task_name] = scalar_mix
# Activate scalar mix starting from 0-th layer
encoder.scalar_mix = 0
use_raw_hidden_states[task_name] = task.use_raw_hidden_states
encoder.ret_raw_hidden_states = any(use_raw_hidden_states.values())
return MultiTaskModel(transformer_module, scalar_mixes, decoders, use_raw_hidden_states)
    def predict(self,
data: Union[str, List[str]],
tasks: Optional[Union[str, List[str]]] = None,
skip_tasks: Optional[Union[str, List[str]]] = None,
resolved_tasks=None,
**kwargs) -> Document:
"""Predict on data.
Args:
data: A sentence or a list of sentences.
tasks: The tasks to predict.
skip_tasks: The tasks to skip.
resolved_tasks: The resolved tasks to override ``tasks`` and ``skip_tasks``.
**kwargs: Not used.
Returns:
A :class:`~hanlp_common.document.Document`.
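
        Examples:
            >>> # A hedged sketch; ``mtl`` stands for a loaded MultiTaskLearning component and the task
            >>> # names are illustrative.
            >>> doc = mtl.predict('商品和服务')
            >>> doc = mtl.predict(['第一句。', '第二句。'], tasks='tok')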
"""
doc = Document()
target_tasks = resolved_tasks or self.resolve_tasks(tasks, skip_tasks)
if data == []:
for group in target_tasks:
for task_name in group:
doc[task_name] = []
return doc
flatten_target_tasks = [self.tasks[t] for group in target_tasks for t in group]
cls_is_bos = any([x.cls_is_bos for x in flatten_target_tasks])
sep_is_eos = any([x.sep_is_eos for x in flatten_target_tasks])
# Now build the dataloaders and execute tasks
first_task_name: str = list(target_tasks[0])[0]
first_task: Task = self.tasks[first_task_name]
encoder_transform, transform = self.build_transform(first_task)
# Override the tokenizer config of the 1st task
encoder_transform.sep_is_eos = sep_is_eos
encoder_transform.cls_is_bos = cls_is_bos
average_subwords = self.model.encoder.average_subwords
flat = first_task.input_is_flat(data)
if flat:
data = [data]
device = self.device
samples = first_task.build_samples(data, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
dataloader = first_task.build_dataloader(samples, transform=transform, device=device)
results = defaultdict(list)
order = []
for batch in dataloader:
order.extend(batch[IDX])
# Run the first task, let it make the initial batch for the successors
output_dict = self.predict_task(first_task, first_task_name, batch, results, run_transform=True,
cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
# Run each task group in order
for group_id, group in enumerate(target_tasks):
# We could parallelize this in the future
for task_name in group:
if task_name == first_task_name:
continue
output_dict = self.predict_task(self.tasks[task_name], task_name, batch, results, output_dict,
run_transform=True, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
if group_id == 0:
# We are kind of hard coding here. If the first task is a tokenizer,
# we need to convert the hidden and mask to token level
if first_task_name.startswith('tok'):
spans = []
tokens = []
output_spans = first_task.config.get('output_spans', None)
for span_per_sent, token_per_sent in zip(output_dict[first_task_name]['prediction'],
results[first_task_name][-len(batch[IDX]):]):
if output_spans:
token_per_sent = [x[0] for x in token_per_sent]
if cls_is_bos:
span_per_sent = [(-1, 0)] + span_per_sent
token_per_sent = [BOS] + token_per_sent
if sep_is_eos:
span_per_sent = span_per_sent + [(span_per_sent[-1][0] + 1, span_per_sent[-1][1] + 1)]
token_per_sent = token_per_sent + [EOS]
                            # Span offsets are 0-based over the subwords, while position 0 is taken by [CLS], so shift each span by 1
if average_subwords:
span_per_sent = [list(range(x[0] + 1, x[1] + 1)) for x in span_per_sent]
else:
span_per_sent = [x[0] + 1 for x in span_per_sent]
spans.append(span_per_sent)
tokens.append(token_per_sent)
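                        # Pad the subword-to-token spans and pool subword hidden states into one vector per
                        # token, so downstream tasks receive token-level hidden states and masks.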
spans = PadSequenceDataLoader.pad_data(spans, 0, torch.long, device=device)
output_dict['hidden'] = pick_tensor_for_each_token(output_dict['hidden'], spans,
average_subwords)
batch['token_token_span'] = spans
batch['token'] = tokens
# noinspection PyTypeChecker
batch['token_length'] = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=device)
batch.pop('mask', None)
# Put results into doc in the order of tasks
for k in self.config.task_names:
v = results.get(k, None)
if v is None:
continue
doc[k] = reorder(v, order)
# Allow task to perform finalization on document
for group in target_tasks:
for task_name in group:
task = self.tasks[task_name]
task.finalize_document(doc, task_name)
# If no tok in doc, use raw input as tok
if not any(k.startswith('tok') for k in doc):
doc['tok'] = data
if flat:
for k, v in list(doc.items()):
doc[k] = v[0]
# If there is only one field, don't bother to wrap it
# if len(doc) == 1:
# return list(doc.values())[0]
return doc
def resolve_tasks(self, tasks, skip_tasks) -> List[Iterable[str]]:
        # Decide which tasks to perform and in what order
tasks_in_topological_order = self._tasks_in_topological_order
task_topological_order = self._task_topological_order
computation_graph = self._computation_graph
target_tasks = self._resolve_task_name(tasks)
if not target_tasks:
target_tasks = tasks_in_topological_order
else:
target_topological_order = defaultdict(set)
for task_name in target_tasks:
for dependency in topological_sort(computation_graph, task_name):
target_topological_order[task_topological_order[dependency]].add(dependency)
target_tasks = [item[1] for item in sorted(target_topological_order.items())]
if skip_tasks:
skip_tasks = self._resolve_task_name(skip_tasks)
target_tasks = [x - skip_tasks for x in target_tasks]
target_tasks = [x for x in target_tasks if x]
assert target_tasks, f'No task to perform due to `tasks = {tasks}`.'
# Sort target tasks within the same group in a defined order
target_tasks = [sorted(x, key=lambda _x: self.config.task_names.index(_x)) for x in target_tasks]
return target_tasks
def predict_task(self, task: Task, output_key, batch, results, output_dict=None, run_transform=True,
cls_is_bos=True, sep_is_eos=True):
output_dict, batch = self.feed_batch(batch, output_key, output_dict, run_transform, cls_is_bos, sep_is_eos,
results)
self.decode_output(output_dict, batch, output_key)
results[output_key].extend(task.prediction_to_result(output_dict[output_key]['prediction'], batch))
return output_dict
def _resolve_task_name(self, dependencies):
resolved_dependencies = set()
if isinstance(dependencies, str):
if dependencies in self.tasks:
resolved_dependencies.add(dependencies)
elif dependencies.endswith('*'):
resolved_dependencies.update(x for x in self.tasks if x.startswith(dependencies[:-1]))
else:
prefix_matched = prefix_match(dependencies, self.config.task_names)
assert prefix_matched, f'No prefix matching for {dependencies}. ' \
f'Check your dependencies definition: {list(self.tasks.values())}'
resolved_dependencies.add(prefix_matched)
elif isinstance(dependencies, Iterable):
resolved_dependencies.update(set(chain.from_iterable(self._resolve_task_name(x) for x in dependencies)))
return resolved_dependencies
    def fit(self,
encoder: Embedding,
tasks: Dict[str, Task],
save_dir,
epochs,
patience=0.5,
lr=1e-3,
encoder_lr=5e-5,
adam_epsilon=1e-8,
weight_decay=0.0,
warmup_steps=0.1,
gradient_accumulation=1,
grad_norm=5.0,
encoder_grad_norm=None,
decoder_grad_norm=None,
tau: float = 0.8,
transform=None,
# prune: Callable = None,
eval_trn=True,
prefetch=None,
tasks_need_custom_eval=None,
_device_placeholder=False,
cache=False,
devices=None,
logger=None,
seed=None,
**kwargs):
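        """Fit the MTL component on the training sets of all tasks, sharing one encoder across all decoders.
        Hyper-parameters such as ``lr``/``encoder_lr`` (decoder vs. encoder learning rates), ``tau`` (task
        sampling temperature) and ``gradient_accumulation`` are merged into the config and forwarded to training.

        Examples:
            >>> # A hedged sketch; the embedding arguments and task objects are placeholders, not exact values.
            >>> mtl = MultiTaskLearning()
            >>> mtl.fit(ContextualWordEmbedding('token', 'bert-base-chinese'), {'tok': tok_task, 'pos': pos_task},
            ...         save_dir='mtl-model', epochs=30)
        """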
trn_data, dev_data, batch_size = 'trn', 'dev', None
task_names = list(tasks.keys())
return super().fit(**merge_locals_kwargs(locals(), kwargs, excludes=('self', 'kwargs', '__class__', 'tasks')),
**tasks)
# noinspection PyAttributeOutsideInit
    def on_config_ready(self, **kwargs):
self.tasks = dict((key, task) for key, task in self.config.items() if isinstance(task, Task))
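        # Build a dependency graph over tasks so that, at prediction time, they can be executed in
        # topological order (e.g. tokenization before tagging and parsing).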
computation_graph = dict()
for task_name, task in self.tasks.items():
dependencies = task.dependencies
resolved_dependencies = self._resolve_task_name(dependencies)
computation_graph[task_name] = resolved_dependencies
# We can cache this order
tasks_in_topological_order = list(toposort(computation_graph))
task_topological_order = dict()
for i, group in enumerate(tasks_in_topological_order):
for task_name in group:
task_topological_order[task_name] = i
self._tasks_in_topological_order = tasks_in_topological_order
self._task_topological_order = task_topological_order
self._computation_graph = computation_graph
@staticmethod
def reset_metrics(metrics: Dict[str, Metric]):
for metric in metrics.values():
metric.reset()
def feed_batch(self,
batch: Dict[str, Any],
task_name,
output_dict=None,
run_transform=False,
cls_is_bos=False,
sep_is_eos=False,
results=None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
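        """Encode the batch with the shared encoder, optionally run the task-specific batch transform,
        then feed the (possibly scalar-mixed) hidden states to the decoder of ``task_name``.
        """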
h, output_dict = self._encode(batch, task_name, output_dict, cls_is_bos, sep_is_eos)
task = self.tasks[task_name]
if run_transform:
batch = task.transform_batch(batch, results=results, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
batch['mask'] = mask = hanlp.utils.torch_util.lengths_to_mask(batch['token_length'])
output_dict[task_name] = {
'output': task.feed_batch(h,
batch=batch,
mask=mask,
decoder=self.model.decoders[task_name]),
'mask': mask
}
return output_dict, batch
def _encode(self, batch, task_name, output_dict=None, cls_is_bos=False, sep_is_eos=False):
model = self.model
if output_dict:
hidden, raw_hidden = output_dict['hidden'], output_dict['raw_hidden']
else:
hidden = model.encoder(batch)
if isinstance(hidden, tuple):
hidden, raw_hidden = hidden
else:
raw_hidden = None
output_dict = {'hidden': hidden, 'raw_hidden': raw_hidden}
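        # Tasks that declared ``use_raw_hidden_states`` consume the encoder's raw (pre-pooling) hidden states;
        # the remaining tasks consume the regular hidden states.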
hidden_states = raw_hidden if model.use_raw_hidden_states[task_name] else hidden
if task_name in model.scalar_mixes:
scalar_mix = model.scalar_mixes[task_name]
h = scalar_mix(hidden_states)
else:
if model.scalar_mixes: # If any task enables scalar_mix, hidden_states will be a 4d tensor
hidden_states = hidden_states[-1, :, :, :]
h = hidden_states
# If the task doesn't need cls while h has cls, remove cls
task = self.tasks[task_name]
if cls_is_bos and not task.cls_is_bos:
h = h[:, 1:, :]
if sep_is_eos and not task.sep_is_eos:
h = h[:, :-1, :]
return h, output_dict
def decode_output(self, output_dict, batch, task_name=None):
if not task_name:
for task_name, task in self.tasks.items():
output_per_task = output_dict.get(task_name, None)
if output_per_task is not None:
output_per_task['prediction'] = task.decode_output(
output_per_task['output'],
output_per_task['mask'],
batch, self.model.decoders[task_name])
else:
output_per_task = output_dict[task_name]
output_per_task['prediction'] = self.tasks[task_name].decode_output(
output_per_task['output'],
output_per_task['mask'],
batch,
self.model.decoders[task_name])
def update_metrics(self, batch: Dict[str, Any], output_dict: Dict[str, Any], metrics: MetricDict, task_name):
task = self.tasks[task_name]
output_per_task = output_dict.get(task_name, None)
if output_per_task:
output = output_per_task['output']
prediction = output_per_task['prediction']
metric = metrics.get(task_name, None)
task.update_metrics(batch, output, prediction, metric)
def compute_loss(self,
batch: Dict[str, Any],
output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
criterion: Callable,
task: Task) -> torch.FloatTensor:
return task.compute_loss(batch, output, criterion)
    def evaluate(self, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs):
rets = super().evaluate('tst', save_dir, logger, batch_size, output, **kwargs)
tst = rets[-1]
self._close_dataloader(tst)
return rets
    def save_vocabs(self, save_dir, filename='vocabs.json'):
for task_name, task in self.tasks.items():
task.save_vocabs(save_dir, f'{task_name}_{filename}')
    def load_vocabs(self, save_dir, filename='vocabs.json'):
for task_name, task in self.tasks.items():
task.load_vocabs(save_dir, f'{task_name}_{filename}')
def parallelize(self, devices: List[Union[int, torch.device]]):
raise NotImplementedError('Parallelization is not implemented yet.')
    def __call__(self, data, **kwargs) -> Document:
return super().__call__(data, **kwargs)
def __getitem__(self, task_name: str) -> Task:
return self.tasks[task_name]
    def __delitem__(self, task_name: str):
"""Delete a task (and every resource it owns) from this component.
Args:
task_name: The name of the task to be deleted.
Examples:
>>> del mtl['dep'] # Delete dep from MTL
"""
del self.config[task_name]
self.config.task_names.remove(task_name)
del self.tasks[task_name]
del self.model.decoders[task_name]
del self._computation_graph[task_name]
self._task_topological_order.pop(task_name)
for group in self._tasks_in_topological_order:
group: set = group
group.discard(task_name)
def __repr__(self):
return repr(self.config)
def items(self):
yield from self.tasks.items()
    def __setattr__(self, key: str, value):
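        # Guard against setting tokenizer-style attributes (e.g. ``dict_force``) directly on the MTL
        # component; point the user to the task that actually owns the attribute instead.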
if key and key.startswith('dict') and not hasattr(self, key):
please_read_the_doc_ok = f'This MTL component has no {key}.'
matched_children = []
for name in self.config.task_names:
if hasattr(self[name], key):
matched_children.append(name)
if matched_children:
please_read_the_doc_ok += f' Maybe you are looking for one of its tasks: {matched_children}. ' \
f'For example, HanLP["{matched_children[0]}"].{key} = ...'
raise TypeError(please_read_the_doc_ok)
object.__setattr__(self, key, value)