MultiTaskLearning

class hanlp.components.mtl.multi_task_learning.MultiTaskLearning(**kwargs)

A multi-task learning (MTL) framework. It shares the same encoder across multiple decoders. These decoders can depend on each other, and such dependencies are handled properly during decoding. To integrate a component into this MTL framework, the component needs to implement the Task interface.
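Because decoders may depend on each other, they have to be decoded in an order where every task's prerequisites run first. The following is a minimal sketch of such an ordering, built on a topological sort from the standard library's `graphlib` (Python 3.9+). The task names and the `dependencies` mapping are hypothetical and are not how HanLP stores task dependencies internally.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set


def decoding_order(dependencies: Dict[str, Set[str]]) -> List[str]:
    """Return task names so that every task comes after the tasks it depends on.

    `dependencies` maps a task name to the set of task names it needs
    (a hypothetical representation used only for this sketch).
    """
    sorter = TopologicalSorter(dependencies)
    # static_order() raises graphlib.CycleError (a ValueError) on circular dependencies.
    return list(sorter.static_order())


# Hypothetical tasks: POS tagging needs tokens; dependency parsing needs tokens and tags.
deps = {
    "tok": set(),
    "pos": {"tok"},
    "dep": {"tok", "pos"},
}
order = decoding_order(deps)
print(order)  # ['tok', 'pos', 'dep']
```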

This framework mostly follows the architecture of Clark et al. (2019) and He & Choi (2021), with additional scalar mix tricks (Kondratyuk & Straka 2019) that allow each task to attend to any subset of layers. We also experimented with knowledge distillation on single tasks, but the performance gain was not significant on a large dataset. We have no plans to invest more effort in distillation in the near future, since most datasets HanLP uses are relatively large and our hardware is relatively powerful.
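A scalar mix lets each task learn its own weighted average over encoder layers. Below is a framework-free sketch of that computation, assuming the common formulation h = γ · Σ_j softmax(s)_j · h_j, with one learnable scalar s_j per layer and a global scale γ. Plain Python lists stand in for tensors. Restricting a task to a subset of layers is modeled by a hypothetical `layers` argument that selects which layers take part in the mix.

```python
import math
from typing import List, Optional, Sequence


def scalar_mix(layer_outputs: Sequence[Sequence[float]],
               scalars: Sequence[float],
               gamma: float = 1.0,
               layers: Optional[Sequence[int]] = None) -> List[float]:
    """Mix per-layer vectors with softmax-normalized weights, then scale by gamma."""
    selected = list(range(len(layer_outputs))) if layers is None else list(layers)
    if len(scalars) != len(selected):
        raise ValueError("need exactly one scalar per selected layer")
    # Softmax over the selected layers' scalars (subtract the max for stability).
    peak = max(scalars)
    exps = [math.exp(s - peak) for s in scalars]
    total = sum(exps)
    weights = [e / total for e in exps]
    mixed = [0.0] * len(layer_outputs[selected[0]])
    for w, idx in zip(weights, selected):
        for d, value in enumerate(layer_outputs[idx]):
            mixed[d] += w * value
    return [gamma * v for v in mixed]


# Three hypothetical encoder layers, each producing a 2-dim vector for one token.
outputs = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
# Equal scalars give a uniform average over all layers.
print(scalar_mix(outputs, [0.0, 0.0, 0.0]))
# A task that attends only to layers 0 and 2.
print(scalar_mix(outputs, [0.0, 0.0], layers=[0, 2]))  # [1.5, 1.0]
```

Because the weights pass through a softmax, they stay positive and sum to one regardless of how the raw scalars drift during training.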

Parameters

**kwargs – Arguments passed to config.

__call__(data, **kwargs) → hanlp_common.document.Document

Predict on data fed by the user. This method calls predict() but decorates it with torch.no_grad.

Parameters
  • data – Sentences or tokens.

  • **kwargs – Used in sub-classes.

__delitem__(task_name: str)

Delete a task (and every resource it owns) from this component.

Parameters

task_name – The name of the task to be deleted.

Examples

>>> del mtl['dep']  # Delete dep from MTL
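
To make "every resource it owns" concrete, here is a toy container. It assumes that a task owns entries in several per-task registries. The `ToyMTL` class, its `tasks`, `decoders` and `vocabs` dictionaries, and the `"<task>/<vocab>"` key scheme are all hypothetical and are not HanLP's actual attributes. `__delitem__` removes the task from each registry so that deleting it leaves no stale state behind.

```python
from typing import Any, Dict


class ToyMTL:
    """A hypothetical stand-in illustrating per-task resource cleanup."""

    def __init__(self) -> None:
        self.tasks: Dict[str, Any] = {}
        self.decoders: Dict[str, Any] = {}
        # Vocabularies keyed as "<task>/<vocab name>" (hypothetical scheme).
        self.vocabs: Dict[str, Any] = {}

    def add(self, name: str, task: Any, decoder: Any, **vocabs: Any) -> None:
        self.tasks[name] = task
        self.decoders[name] = decoder
        for vocab_name, vocab in vocabs.items():
            self.vocabs[f"{name}/{vocab_name}"] = vocab

    def __delitem__(self, task_name: str) -> None:
        if task_name not in self.tasks:
            raise KeyError(task_name)
        del self.tasks[task_name]
        del self.decoders[task_name]
        prefix = task_name + "/"
        # Collect keys first: a dict must not change size while it is iterated.
        for key in [k for k in self.vocabs if k.startswith(prefix)]:
            del self.vocabs[key]


mtl = ToyMTL()
mtl.add("tok", task="tok-task", decoder="tok-decoder", label=["B", "I"])
mtl.add("dep", task="dep-task", decoder="dep-decoder", rel=["nsubj", "obj"])
del mtl["dep"]  # Delete dep from the toy MTL
print(sorted(mtl.tasks), sorted(mtl.vocabs))  # ['tok'] ['tok/label']
```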

__setattr__(key: str, value)

Implement setattr(self, name, value).

build_criterion(**kwargs)

Implement this method to build the criterion (loss function).

Parameters

**kwargs – The subclass decides the method signature.

build_dataloader(data, batch_size, shuffle=False, device=None, logger: Optional[logging.Logger] = None, gradient_accumulation=1, tau: float = 0.8, prune=None, prefetch=None, tasks_need_custom_eval=None, cache=False, debug=False, **kwargs) → torch.utils.data.dataloader.DataLoader

Build a dataloader for the training, dev and test sets. It is suggested to build vocabs in this method if they have not been built yet.

Parameters
  • data – Data representing samples, which can be a path or a list of samples.

  • batch_size – Number of samples per batch.

  • shuffle – Whether to shuffle this dataloader.

  • device – Device tensors should be loaded onto.

  • logger – Logger for reporting messages if the dataloader takes a long time or if vocabs have to be built.

  • **kwargs – Arguments from **self.config.
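The signature exposes a `tau` parameter (default 0.8) that the parameter list does not describe. One common use of such a parameter in multi-task training, and only an assumption here rather than documented behavior, is temperature-based task sampling: each task is drawn with probability proportional to its dataset size raised to `tau`, so `tau < 1` up-samples small datasets. A sketch of the resulting weights, with hypothetical task sizes:

```python
from typing import Dict


def sampling_weights(sizes: Dict[str, int], tau: float = 0.8) -> Dict[str, float]:
    """Probability of drawing a batch from each task, proportional to size ** tau (assumed)."""
    scaled = {name: size ** tau for name, size in sizes.items()}
    z = sum(scaled.values())
    return {name: value / z for name, value in scaled.items()}


sizes = {"tok": 90_000, "ner": 10_000}  # hypothetical number of samples per task
for tau in (1.0, 0.8, 0.0):
    # tau=1.0: proportional to size; tau=0.0: uniform over tasks.
    print(tau, sampling_weights(sizes, tau))
```

Values of `tau` between 0 and 1 interpolate between sampling proportional to size and sampling every task equally often.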

build_metric(**kwargs)

Implement this to build metric(s).

Parameters

**kwargs – The subclass decides the method signature.

build_model(training=False, **kwargs) → torch.nn.modules.module.Module

Build model.

Parameters
  • training – True if called during training.

  • **kwargs – **self.config.

build_optimizer(trn, epochs, adam_epsilon, weight_decay, warmup_steps, lr, encoder_lr, **kwargs)

Implement this method to build an optimizer.

Parameters

**kwargs – The subclass decides the method signature.
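`build_optimizer` receives two learning rates, `lr` and `encoder_lr`, which `fit` defaults to 0.001 and 5e-05: presumably a larger rate for the freshly initialized decoders and a smaller one for the pretrained encoder. The sketch below computes the learning rate of each parameter group at a given update step under two assumptions that this documentation does not state: a float `warmup_steps` in (0, 1) is a fraction of the total number of updates, and the schedule is linear warmup followed by linear decay to zero. The group names are hypothetical. The sketch is pure Python; a real implementation would hand such groups to a PyTorch optimizer and scheduler.

```python
from typing import Dict, Union


def resolve_warmup(warmup_steps: Union[int, float], total_steps: int) -> int:
    """Interpret a float in (0, 1) as a fraction of total_steps (assumed convention)."""
    if isinstance(warmup_steps, float) and 0 < warmup_steps < 1:
        return int(warmup_steps * total_steps)
    return int(warmup_steps)


def linear_schedule(step: int, warmup: int, total_steps: int) -> float:
    """Multiplier in [0, 1]: linear warmup, then linear decay to zero."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup))


def group_lrs(step: int, total_steps: int, lr: float = 1e-3, encoder_lr: float = 5e-5,
              warmup_steps: Union[int, float] = 0.1) -> Dict[str, float]:
    """Learning rate of each (hypothetical) parameter group at a given update step."""
    warmup = resolve_warmup(warmup_steps, total_steps)
    factor = linear_schedule(step, warmup, total_steps)
    return {"decoders": lr * factor, "encoder": encoder_lr * factor}


total = 1000  # e.g. epochs * updates per epoch
for step in (0, 50, 100, 550, 1000):
    print(step, group_lrs(step, total))
```

Both groups share one schedule shape, so their ratio stays fixed at `lr / encoder_lr` throughout training.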

evaluate(save_dir=None, logger: Optional[logging.Logger] = None, batch_size=None, output=False, **kwargs)

Evaluate on the test set.

Parameters
  • tst_data – Test set, which is usually a file path.

  • save_dir – The directory to save evaluation scores or predictions.

  • logger – Logger for reporting progress.

  • batch_size – Batch size for test dataloader.

  • output – Whether to save outputs into some file.

  • **kwargs – Not used.

Returns

(metric, outputs) where outputs are the return values of evaluate_dataloader.

evaluate_dataloader(data: hanlp.components.mtl.multi_task_learning.MultiTaskDataLoader, criterion, metric: hanlp.metrics.mtl.MetricDict, logger, ratio_width=None, input: str = None, **kwargs)

Evaluate on a dataloader.

Parameters
  • data – Dataloader, which can be built from any data source.

  • criterion – Loss function.

  • metric – Metric(s).

  • output – Whether to save outputs into some file.

  • **kwargs – Not used.

execute_training_loop(trn: torch.utils.data.dataloader.DataLoader, dev: torch.utils.data.dataloader.DataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, patience=0.5, **kwargs)

Implement this to run the training loop.

Parameters
  • trn – Training set.

  • dev – Development set.

  • epochs – Number of epochs.

  • criterion – Loss function.

  • optimizer – Optimizer(s).

  • metric – Metric(s).

  • save_dir – The directory to save this component.

  • logger – Logger for reporting progress.

  • devices – Devices this component and dataloader will live on.

  • ratio_width – The width of dataset size measured in number of characters. Used for logger to align messages.

  • **kwargs – Other hyper-parameters passed from sub-class.
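Both `execute_training_loop` and `fit` default `patience` to 0.5, a float. Assuming, as the fractional default suggests, that a float patience is read as a fraction of `epochs` while an int is an absolute number of epochs, early stopping can be sketched as follows. The dev scores are made-up toy numbers for illustration only, not results.

```python
from typing import Sequence, Tuple, Union


def resolve_patience(patience: Union[int, float], epochs: int) -> int:
    """Interpret a float patience as a fraction of epochs (assumed convention)."""
    if isinstance(patience, float):
        return max(1, int(patience * epochs))
    return patience


def early_stop(dev_scores: Sequence[float], epochs: int,
               patience: Union[int, float] = 0.5) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch_run), both 1-based."""
    limit = resolve_patience(patience, epochs)
    best_score, best_epoch, waited = float("-inf"), 0, 0
    for epoch, score in enumerate(dev_scores[:epochs], start=1):
        if score > best_score:
            best_score, best_epoch, waited = score, epoch, 0
        else:
            waited += 1
            if waited >= limit:  # no improvement for `limit` consecutive epochs
                return best_epoch, epoch
    return best_epoch, min(len(dev_scores), epochs)


# Toy dev scores: with epochs=6 and patience=0.5, training stops after 3 epochs without improvement.
scores = [0.70, 0.75, 0.74, 0.73, 0.72, 0.80]
print(early_stop(scores, epochs=6))  # (2, 5)
```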

fit(encoder: hanlp.layers.embeddings.embedding.Embedding, tasks: Dict[str, hanlp.components.mtl.tasks.Task], save_dir, epochs, patience=0.5, lr=0.001, encoder_lr=5e-05, adam_epsilon=1e-08, weight_decay=0.0, warmup_steps=0.1, gradient_accumulation=1, grad_norm=5.0, encoder_grad_norm=None, decoder_grad_norm=None, tau: float = 0.8, transform=None, eval_trn=True, prefetch=None, tasks_need_custom_eval=None, _device_placeholder=False, cache=False, devices=None, logger=None, seed=None, **kwargs)

Fit to data, triggering the training procedure. The training and dev sets should be local or remote files.

Parameters
  • trn_data – Training set.

  • dev_data – Development set.

  • save_dir – The directory to save trained component.

  • batch_size – The number of samples in a batch.

  • epochs – Number of epochs.

  • devices – Devices this component will live on.

  • logger – Any logging.Logger instance.

  • seed – Random seed to reproduce this training.

  • finetune – True to load from save_dir instead of creating a randomly initialized component. Pass a str to specify a different save_dir to load from.

  • eval_trn – Evaluate the training set after each update. This can slow down training but provides a quick diagnostic for debugging.

  • _device_placeholder – True to create a placeholder tensor which triggers PyTorch to occupy devices so other components won’t take these devices as first choices.

  • **kwargs – Hyperparameters used by sub-classes.

Returns

Any results sub-classes would like to return. Usually the best metrics on training set.

fit_dataloader(trn: torch.utils.data.dataloader.DataLoader, criterion, optimizer, metric, logger: logging.Logger, history: hanlp.common.structure.History, ratio_width=None, gradient_accumulation=1, encoder_grad_norm=None, decoder_grad_norm=None, patience=0.5, eval_trn=False, **kwargs)

Fit onto a dataloader.

Parameters
  • trn – Training set.

  • criterion – Loss function.

  • optimizer – Optimizer.

  • metric – Metric(s).

  • logger – Logger for reporting progress.

  • **kwargs – Other hyper-parameters passed from sub-class.
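`fit_dataloader` also takes `gradient_accumulation`, `encoder_grad_norm` and `decoder_grad_norm`, which the parameter list leaves undocumented. Assuming they carry their usual meaning (update the parameters once every `gradient_accumulation` batches, and clip gradients to a maximum L2 norm), the control flow can be sketched in plain Python. Gradients are modeled as lists of floats, and the helper names `clip_by_norm` and `accumulate` are hypothetical.

```python
import math
from typing import List, Optional


def clip_by_norm(grad: List[float], max_norm: Optional[float]) -> List[float]:
    """Rescale grad so its L2 norm is at most max_norm; no-op when max_norm is None."""
    if max_norm is None:
        return grad
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return grad
    return [g * (max_norm / norm) for g in grad]


def accumulate(batch_grads: List[List[float]], gradient_accumulation: int = 1,
               max_norm: Optional[float] = None) -> List[List[float]]:
    """Average gradients over groups of batches; return one clipped gradient per update."""
    updates: List[List[float]] = []
    buffer: Optional[List[float]] = None
    for i, grad in enumerate(batch_grads, start=1):
        # Divide by the group size so the summed gradient is the group average.
        scaled = [g / gradient_accumulation for g in grad]
        buffer = scaled if buffer is None else [a + b for a, b in zip(buffer, scaled)]
        if i % gradient_accumulation == 0:
            updates.append(clip_by_norm(buffer, max_norm))
            buffer = None
    if buffer is not None:  # flush a trailing, partially filled group
        updates.append(clip_by_norm(buffer, max_norm))
    return updates


print(accumulate([[2.0, 0.0], [0.0, 2.0], [4.0, 0.0], [0.0, 4.0]], gradient_accumulation=2))
# [[1.0, 1.0], [2.0, 2.0]]
```

In a component with separate norms, the clipping step would presumably run twice per update: once over encoder gradients with `encoder_grad_norm` and once over decoder gradients with `decoder_grad_norm`.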

load_vocabs(save_dir, filename='vocabs.json')

Load vocabularies from a directory.

Parameters
  • save_dir – The directory to load vocabularies from.

  • filename – The file name of the vocabularies.

on_config_ready(**kwargs)

Called when the config is ready, either during fit or load. Subclasses can perform extra initialization in this callback.

Parameters

**kwargs – Not used.

predict(data: Union[str, List[str]], tasks: Optional[Union[str, List[str]]] = None, skip_tasks: Optional[Union[str, List[str]]] = None, resolved_tasks=None, **kwargs) → hanlp_common.document.Document

Predict on data.

Parameters
  • data – A sentence or a list of sentences.

  • tasks – The tasks to predict.

  • skip_tasks – The tasks to skip.

  • resolved_tasks – The resolved tasks to override tasks and skip_tasks.

  • **kwargs – Not used.

Returns

A Document.
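The `tasks` and `skip_tasks` arguments each accept either a single name or a list. Below is a sketch of how such arguments could be resolved into a final list of tasks to run. It assumes that a requested task pulls in the tasks it depends on and that skipped tasks are removed afterwards. The `dependencies` mapping is hypothetical, and HanLP's actual resolution (including how `resolved_tasks` overrides both arguments) may differ.

```python
from typing import Dict, List, Optional, Set, Union

Names = Optional[Union[str, List[str]]]


def _as_list(names: Names) -> List[str]:
    """Normalize None, a single name, or a list of names into a list."""
    if names is None:
        return []
    return [names] if isinstance(names, str) else list(names)


def resolve_tasks(all_tasks: List[str], dependencies: Dict[str, Set[str]],
                  tasks: Names = None, skip_tasks: Names = None) -> List[str]:
    """Return the tasks to run, in the order given by all_tasks."""
    requested = _as_list(tasks) or list(all_tasks)  # None means every task
    needed: Set[str] = set()
    stack = list(requested)
    while stack:  # add each requested task plus its transitive dependencies
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(dependencies.get(name, ()))
    # Note: skipping a dependency of a requested task is not checked here.
    needed -= set(_as_list(skip_tasks))
    return [name for name in all_tasks if name in needed]


all_tasks = ["tok", "pos", "ner", "dep"]
deps = {"pos": {"tok"}, "ner": {"tok"}, "dep": {"tok", "pos"}}
print(resolve_tasks(all_tasks, deps, tasks="dep"))         # ['tok', 'pos', 'dep']
print(resolve_tasks(all_tasks, deps, skip_tasks=["dep"]))  # ['tok', 'pos', 'ner']
```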

save_vocabs(save_dir, filename='vocabs.json')

Save vocabularies to a directory.

Parameters
  • save_dir – The directory to save vocabularies.

  • filename – The file name of the vocabularies.
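`save_vocabs` and `load_vocabs` share the default file name `vocabs.json`. Below is a round-trip sketch using the standard `json` module. It assumes that vocabularies can be represented as a plain dict that maps a vocabulary name to its token list. This stand-in format and the helper names `save_vocab_dict` and `load_vocab_dict` are hypothetical, and they do not reflect HanLP's on-disk schema.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_vocab_dict(vocabs: Dict[str, List[str]], save_dir: str,
                    filename: str = "vocabs.json") -> str:
    """Write vocabularies as JSON into save_dir and return the file path."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    with open(path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII tokens (e.g. Chinese) readable in the file.
        json.dump(vocabs, f, ensure_ascii=False, indent=2)
    return path


def load_vocab_dict(save_dir: str, filename: str = "vocabs.json") -> Dict[str, List[str]]:
    """Read vocabularies previously written by save_vocab_dict."""
    with open(os.path.join(save_dir, filename), encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    original = {"tok/label": ["B", "I", "E", "S"], "pos/tag": ["NN", "VV"]}
    save_vocab_dict(original, tmp)
    restored = load_vocab_dict(tmp)
    print(restored == original)  # True
```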