MultiTaskLearning
- class hanlp.components.mtl.multi_task_learning.MultiTaskLearning(**kwargs)
A multi-task learning (MTL) framework. It shares one encoder across multiple decoders. These decoders can depend on each other, and such dependencies are resolved during decoding. To integrate a component into this MTL framework, the component needs to implement the Task interface.
This framework mostly follows the architecture of Clark et al. (2019) and He & Choi (2021), with additional scalar mix tricks (Kondratyuk & Straka 2019) allowing each task to attend to any subset of layers. We also experimented with knowledge distillation on single tasks, but the performance gain was nonsignificant on a large dataset. We have no plan to invest more effort in distillation in the near future, since most datasets HanLP uses are relatively large and our hardware is relatively powerful.
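The shared-encoder design with dependent decoders can be pictured with a toy sketch. Everything below (the `shared_encoder` function, the `TASKS` table and the three toy decoders) is hypothetical and only illustrates the idea of encoding once and decoding in dependency order; it is not HanLP's implementation.

```python
from graphlib import TopologicalSorter


def shared_encoder(tokens):
    # Stand-in for a shared encoder: one feature (the token length) per token.
    return [len(token) for token in tokens]


# Hypothetical toy tasks: name -> (tasks it depends on, decoder function).
# Each decoder sees the shared features plus the outputs of finished tasks.
TASKS = {
    "tok": ((), lambda feats, done: list(feats)),
    "pos": (("tok",), lambda feats, done: [f % 2 for f in done["tok"]]),
    "dep": (
        ("tok", "pos"),
        lambda feats, done: [t + p for t, p in zip(done["tok"], done["pos"])],
    ),
}


def decode_all(tokens):
    feats = shared_encoder(tokens)  # encoded once, reused by every decoder
    graph = {name: deps for name, (deps, _) in TASKS.items()}
    done = {}
    # static_order() yields each task only after all of its dependencies.
    for name in TopologicalSorter(graph).static_order():
        _, decoder = TASKS[name]
        done[name] = decoder(feats, done)
    return done


result = decode_all(["multi", "task", "learning"])
```

Here `dep` is decoded last because it consumes the outputs of both `tok` and `pos`, while the encoder runs only once per input.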
- Parameters
**kwargs – Arguments passed to config.
- __call__(data, **kwargs) → hanlp_common.document.Document
Predict on data fed by user. This method calls predict() but decorates it with torch.no_grad.
- Parameters
*args – Sentences or tokens.
**kwargs – Used in sub-classes.
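The relationship between `__call__` and `predict()` can be sketched without PyTorch. In this minimal illustration, `AUTOGRAD`, `no_grad` and `ToyComponent` are hypothetical stand-ins: the context manager merely mimics the way `torch.no_grad` switches gradient tracking off for the duration of a call.

```python
from contextlib import contextmanager

# Stand-in for the autograd switch that torch.no_grad flips (an assumption).
AUTOGRAD = {"enabled": True}


@contextmanager
def no_grad():
    previous = AUTOGRAD["enabled"]
    AUTOGRAD["enabled"] = False
    try:
        yield
    finally:
        AUTOGRAD["enabled"] = previous  # always restore the previous state


class ToyComponent:
    def predict(self, data, **kwargs):
        # Report whether gradients were being tracked while predicting.
        return {"tokens": data.split(), "grad_enabled": AUTOGRAD["enabled"]}

    def __call__(self, data, **kwargs):
        # Same prediction, but with gradient tracking switched off.
        with no_grad():
            return self.predict(data, **kwargs)


component = ToyComponent()
called = component("hello world")
direct = component.predict("hello world")
```

Calling the component is therefore the convenient entry point for inference, while `predict()` remains available when the caller manages gradient tracking itself.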
- __delitem__(task_name: str)
Delete a task (and every resource it owns) from this component.
- Parameters
task_name – The name of the task to be deleted.
Examples
>>> del mtl['dep'] # Delete dep from MTL
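What "every resource it owns" might cover can be illustrated with a toy container. The `ToyMTL` class and its `decoders` and `vocabs` attributes are assumptions made for this sketch; the resources that HanLP actually releases are not listed on this page.

```python
class ToyMTL:
    """Hypothetical container holding per-task resources."""

    def __init__(self, task_names):
        self.tasks = {name: f"task<{name}>" for name in task_names}
        self.decoders = {name: f"decoder<{name}>" for name in task_names}
        self.vocabs = {name: f"vocab<{name}>" for name in task_names}

    def __delitem__(self, task_name):
        del self.tasks[task_name]  # raises KeyError for an unknown task
        # Drop everything else that belongs to the task.
        self.decoders.pop(task_name, None)
        self.vocabs.pop(task_name, None)


mtl = ToyMTL(["tok", "pos", "dep"])
del mtl["dep"]  # same syntax as the example above
```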
- build_criterion(**kwargs)
Implement this method to build criterion (loss function).
- Parameters
**kwargs – The subclass decides the method signature.
- build_dataloader(data, batch_size, shuffle=False, device=None, logger: Optional[logging.Logger] = None, gradient_accumulation=1, tau: float = 0.8, prune=None, prefetch=None, tasks_need_custom_eval=None, cache=False, debug=False, **kwargs) → torch.utils.data.dataloader.DataLoader
Build dataloader for training, dev and test sets. It’s suggested to build vocabs in this method if they are not built yet.
- Parameters
data – Data representing samples, which can be a path or a list of samples.
batch_size – Number of samples per batch.
shuffle – Whether to shuffle this dataloader.
device – Device tensors should be loaded onto.
logger – Logger for reporting some message if dataloader takes a long time or if vocabs have to be built.
**kwargs – Arguments from **self.config.
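The signature also exposes `tau` (default 0.8), which the parameter list above does not describe. One common use of such a value in multi-task training, assumed here and not confirmed by this page, is a temperature that smooths how often each task is sampled relative to its dataset size. The `sampling_weights` helper below is hypothetical.

```python
def sampling_weights(sizes, tau=0.8):
    """Return per-task sampling probabilities proportional to size ** tau."""
    scaled = {task: size ** tau for task, size in sizes.items()}
    total = sum(scaled.values())
    return {task: value / total for task, value in scaled.items()}


sizes = {"tok": 10000, "dep": 100}  # number of batches per task (made up)
proportional = sampling_weights(sizes, tau=1.0)  # follows dataset size
smoothed = sampling_weights(sizes, tau=0.8)      # small tasks sampled more often
uniform = sampling_weights(sizes, tau=0.0)       # every task equally likely
```

Under this reading, `tau=1.0` samples tasks strictly by size, `tau=0.0` ignores size, and values in between give smaller datasets a larger share than their size alone would.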
- build_metric(**kwargs)
Implement this to build metric(s).
- Parameters
**kwargs – The subclass decides the method signature.
- build_model(training=False, **kwargs) → torch.nn.modules.module.Module
Build model.
- Parameters
training – True if called during training.
**kwargs – **self.config.
- build_optimizer(trn, epochs, adam_epsilon, weight_decay, warmup_steps, lr, encoder_lr, **kwargs)
Implement this method to build an optimizer.
- Parameters
**kwargs – The subclass decides the method signature.
- evaluate(save_dir=None, logger: Optional[logging.Logger] = None, batch_size=None, output=False, **kwargs)
Evaluate test set.
- Parameters
tst_data – Test set, which is usually a file path.
save_dir – The directory to save evaluation scores or predictions.
logger – Logger for reporting progress.
batch_size – Batch size for test dataloader.
output – Whether to save outputs into some file.
**kwargs – Not used.
- Returns
(metric, outputs) where outputs are the return values of evaluate_dataloader.
- evaluate_dataloader(data: hanlp.components.mtl.multi_task_learning.MultiTaskDataLoader, criterion, metric: hanlp.metrics.mtl.MetricDict, logger, ratio_width=None, input: str = None, **kwargs)
Evaluate on a dataloader.
- Parameters
data – A dataloader, which can be built from any data source.
criterion – Loss function.
metric – Metric(s).
output – Whether to save outputs into some file.
**kwargs – Not used.
- execute_training_loop(trn: torch.utils.data.dataloader.DataLoader, dev: torch.utils.data.dataloader.DataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, patience=0.5, **kwargs)
Implement this to run training loop.
- Parameters
trn – Training set.
dev – Development set.
epochs – Number of epochs.
criterion – Loss function.
optimizer – Optimizer(s).
metric – Metric(s).
save_dir – The directory to save this component.
logger – Logger for reporting progress.
devices – Devices this component and dataloader will live on.
ratio_width – The width of dataset size measured in number of characters. Used for logger to align messages.
**kwargs – Other hyper-parameters passed from sub-class.
- fit(encoder: hanlp.layers.embeddings.embedding.Embedding, tasks: Dict[str, hanlp.components.mtl.tasks.Task], save_dir, epochs, patience=0.5, lr=0.001, encoder_lr=5e-05, adam_epsilon=1e-08, weight_decay=0.0, warmup_steps=0.1, gradient_accumulation=1, grad_norm=5.0, encoder_grad_norm=None, decoder_grad_norm=None, tau: float = 0.8, transform=None, eval_trn=True, prefetch=None, tasks_need_custom_eval=None, _device_placeholder=False, cache=False, devices=None, logger=None, seed=None, **kwargs)
Fit to data and trigger the training procedure. The training set and dev set shall be local or remote files.
- Parameters
trn_data – Training set.
dev_data – Development set.
save_dir – The directory to save trained component.
batch_size – The number of samples in a batch.
epochs – Number of epochs.
devices – Devices this component will live on.
logger – Any logging.Logger instance.
seed – Random seed to reproduce this training.
finetune – True to load from save_dir instead of creating a randomly initialized component. str to specify a different save_dir to load from.
eval_trn – Evaluate training set after each update. This can slow down the training but provides a quick diagnostic for debugging.
_device_placeholder – True to create a placeholder tensor which triggers PyTorch to occupy devices so other components won’t take these devices as first choices.
**kwargs – Hyperparameters used by sub-classes.
- Returns
Any results sub-classes would like to return. Usually the best metrics on training set.
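Two defaults in the `fit` signature, `warmup_steps=0.1` and `patience=0.5`, look like fractions rather than absolute counts. This page does not say how they are interpreted, so the sketch below simply assumes that a float below 1 means "a fraction of the total". Both `resolve` and `lr_at`, and the linear warmup and decay shape, are hypothetical.

```python
def resolve(value, total):
    """Read a float below 1 as a fraction of `total`, anything else as a count."""
    if isinstance(value, float) and value < 1:
        return int(value * total)
    return int(value)


def lr_at(step, total_steps, base_lr=1e-3, warmup_steps=0.1):
    """Linear warmup to base_lr followed by linear decay to zero."""
    warmup = resolve(warmup_steps, total_steps)
    if step < warmup:
        return base_lr * step / warmup
    return base_lr * (total_steps - step) / max(1, total_steps - warmup)


warmup = resolve(0.1, 100)    # fraction of 100 training steps
patience = resolve(0.5, 30)   # fraction of 30 epochs
schedule = [lr_at(step, 100) for step in (0, 5, 10, 100)]
```

Under this assumption the learning rate rises over the first tenth of training, and early stopping would tolerate half of the scheduled epochs without improvement.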
- fit_dataloader(trn: torch.utils.data.dataloader.DataLoader, criterion, optimizer, metric, logger: logging.Logger, history: hanlp.common.structure.History, ratio_width=None, gradient_accumulation=1, encoder_grad_norm=None, decoder_grad_norm=None, patience=0.5, eval_trn=False, **kwargs)
Fit onto a dataloader.
- Parameters
trn – Training set.
criterion – Loss function.
optimizer – Optimizer.
metric – Metric(s).
logger – Logger for reporting progress.
**kwargs – Other hyper-parameters passed from sub-class.
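The `gradient_accumulation` argument in the signature suggests that gradients from several batches are combined before each optimizer update. A minimal sketch of that general technique on a one-parameter model follows; `train_epoch` is a hypothetical helper, and gradient clipping (`encoder_grad_norm`, `decoder_grad_norm`) is not shown.

```python
def train_epoch(batches, weight=0.0, lr=0.1, gradient_accumulation=2):
    """Fit y = weight * x, updating once every `gradient_accumulation` batches."""
    accumulated, updates = 0.0, 0
    for step, (x, y) in enumerate(batches, start=1):
        grad = 2 * (weight * x - y) * x  # derivative of (weight * x - y) ** 2
        # Average rather than sum, so the update size matches one large batch.
        accumulated += grad / gradient_accumulation
        if step % gradient_accumulation == 0:
            weight -= lr * accumulated
            accumulated, updates = 0.0, updates + 1
    # Batches left over after the last full group are ignored in this sketch.
    return weight, updates


batches = [(1.0, 2.0)] * 8
weight, updates = train_epoch(batches)
```

With eight batches and `gradient_accumulation=2`, the parameter is updated four times, which emulates a batch twice as large without holding it in memory at once.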
- load_vocabs(save_dir, filename='vocabs.json')
Load vocabularies from a directory.
- Parameters
save_dir – The directory to load vocabularies.
filename – The name for vocabularies.
- on_config_ready(**kwargs)
Called when config is ready, either during fit or load. Subclass can perform extra initialization tasks in this callback.
- Parameters
**kwargs – Not used.
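The callback pattern described here can be sketched with two toy classes. `ToyBase`, `ToyTagger` and the way the config is passed as keyword arguments are assumptions made for illustration; only the idea that the hook fires from both `fit` and `load` comes from the description above.

```python
class ToyBase:
    def __init__(self):
        self.config = {}

    def on_config_ready(self, **kwargs):
        """Hook for subclasses; does nothing by default."""

    def fit(self, **hyperparameters):
        self.config.update(hyperparameters)
        self.on_config_ready(**self.config)  # config is complete: notify

    def load(self, saved_config):
        self.config.update(saved_config)
        self.on_config_ready(**self.config)  # same hook on the load path


class ToyTagger(ToyBase):
    def on_config_ready(self, **kwargs):
        # Extra initialization that depends on the final config.
        self.labels = sorted(kwargs.get("labels", []))


trained = ToyTagger()
trained.fit(labels=["VERB", "NOUN"])

restored = ToyTagger()
restored.load({"labels": ["VERB", "NOUN"]})
```

Because both paths end in the same hook, a subclass only has to write its extra initialization once.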
- predict(data: Union[str, List[str]], tasks: Optional[Union[str, List[str]]] = None, skip_tasks: Optional[Union[str, List[str]]] = None, resolved_tasks=None, **kwargs) → hanlp_common.document.Document
Predict on data.
- Parameters
data – A sentence or a list of sentences.
tasks – The tasks to predict.
skip_tasks – The tasks to skip.
resolved_tasks – The resolved tasks to override tasks and skip_tasks.
**kwargs – Not used.
- Returns
A Document.
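How `tasks`, `skip_tasks` and `resolved_tasks` could interact is sketched below. The `resolve_tasks` helper is hypothetical and deliberately simple: it accepts a string or a list for either argument, as the type hints above do, and lets `resolved_tasks` override both. HanLP's own resolution may additionally pull in the tasks that a requested task depends on, which this sketch does not attempt.

```python
def as_list(value):
    """Normalize None, a single name, or a list of names into a list."""
    if value is None:
        return []
    return [value] if isinstance(value, str) else list(value)


def resolve_tasks(available, tasks=None, skip_tasks=None, resolved_tasks=None):
    if resolved_tasks is not None:
        return list(resolved_tasks)  # overrides tasks and skip_tasks
    wanted = as_list(tasks) or list(available)  # default: run every task
    skipped = set(as_list(skip_tasks))
    return [name for name in available if name in wanted and name not in skipped]


available = ["tok", "pos", "dep"]
only_pos = resolve_tasks(available, tasks="pos")
without_dep = resolve_tasks(available, skip_tasks=["dep"])
overridden = resolve_tasks(available, tasks="tok", resolved_tasks=["dep"])
```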