torch_component

class hanlp.common.torch_component.TorchComponent(**kwargs)

The base class for all components using PyTorch as backend. It provides common workflows for building vocabs, datasets, dataloaders and models. These workflows are more of a conventional guideline than enforced protocols, which means subclasses are free to override or completely skip some steps.

Parameters

**kwargs – Additional arguments to be stored in the config property.
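The conventional workflow can be pictured with the stdlib-only sketch below. WorkflowSketch and TraceComponent are hypothetical names, not part of HanLP, and the order of the build steps inside fit is an assumption made for illustration; the real fit also deals with devices, logging, seeding and saving. The toy subclass records which steps run and simply returns None for steps it does not need, which mirrors the freedom to skip steps described above.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class WorkflowSketch(ABC):
    """Hypothetical stand-in for TorchComponent's conventional workflow (not HanLP code)."""

    def __init__(self, **kwargs):
        self.config = dict(kwargs)  # additional kwargs are stored in the config
        self.model = None

    def fit(self, trn_data, dev_data, **kwargs) -> Any:
        self.config.update(kwargs)
        # Assumed order of steps; a subclass may override or skip any of them.
        trn = self.build_dataloader(trn_data, shuffle=True, **self.config)
        dev = self.build_dataloader(dev_data, shuffle=False, **self.config)
        self.model = self.build_model(training=True, **self.config)
        criterion = self.build_criterion(**self.config)
        optimizer = self.build_optimizer(**self.config)
        metric = self.build_metric(**self.config)
        return self.execute_training_loop(trn, dev, criterion, optimizer, metric, **self.config)

    @abstractmethod
    def build_dataloader(self, data, batch_size, shuffle=False, **kwargs): ...

    @abstractmethod
    def build_model(self, training=True, **kwargs): ...

    @abstractmethod
    def build_criterion(self, **kwargs): ...

    @abstractmethod
    def build_optimizer(self, **kwargs): ...

    @abstractmethod
    def build_metric(self, **kwargs): ...

    @abstractmethod
    def execute_training_loop(self, trn, dev, criterion, optimizer, metric, **kwargs): ...


class TraceComponent(WorkflowSketch):
    """Toy subclass that records which steps run; unneeded steps just return None."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.calls: List[str] = []

    def build_dataloader(self, data, batch_size, shuffle=False, **kwargs):
        self.calls.append('build_dataloader')
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    def build_model(self, training=True, **kwargs):
        self.calls.append('build_model')
        return object()

    def build_criterion(self, **kwargs):
        self.calls.append('build_criterion')
        return None

    def build_optimizer(self, **kwargs):
        self.calls.append('build_optimizer')
        return None

    def build_metric(self, **kwargs):
        self.calls.append('build_metric')
        return None

    def execute_training_loop(self, trn, dev, criterion, optimizer, metric, **kwargs):
        self.calls.append('execute_training_loop')
        return len(trn), len(dev)  # number of batches in each split


component = TraceComponent()
print(component.fit([1, 2, 3, 4, 5], [6, 7], batch_size=2))  # (3, 1)
```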

abstract build_criterion(**kwargs)

Implement this method to build criterion (loss function).

Parameters

**kwargs – The subclass decides the method signature.

abstract build_dataloader(data, batch_size, shuffle=False, device=None, logger: logging.Logger = None, **kwargs) → torch.utils.data.dataloader.DataLoader

Build dataloader for training, dev and test sets. It is suggested to build vocabs in this method if they have not been built yet.

Parameters
  • data – Data representing samples, which can be a path or a list of samples.

  • batch_size – Number of samples per batch.

  • shuffle – Whether to shuffle this dataloader.

  • device – Device tensors should be loaded onto.

  • logger – Logger for reporting messages if the dataloader takes a long time to build or if vocabs have to be built.

  • **kwargs – Arguments from **self.config.
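A real implementation returns a torch.utils.data.DataLoader. The stdlib-only sketch below uses a hypothetical helper, batchify, that is not part of HanLP; it only illustrates what batch_size and shuffle control: samples are grouped into consecutive batches, and shuffling permutes the sample order first (here with a local, seedable RNG so the global random state is untouched).

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar('T')


def batchify(data: Sequence[T], batch_size: int, shuffle: bool = False,
             seed: Optional[int] = None) -> List[List[T]]:
    """Hypothetical helper: split samples into batches, optionally shuffled first."""
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    indices = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # local RNG, reproducible when seed is given
    return [[data[i] for i in indices[start:start + batch_size]]
            for start in range(0, len(indices), batch_size)]


print(batchify(['a', 'b', 'c', 'd', 'e'], batch_size=2))  # [['a', 'b'], ['c', 'd'], ['e']]
```

The last batch may be smaller than batch_size; whether to keep or drop it is a design choice left to the subclass.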

build_logger(name, save_dir)

Build a logging.Logger.

Parameters
  • name – The name of this logger.

  • save_dir – The directory this logger should save logs into.

Returns

A logger.

Return type

logging.Logger
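A logger of this kind can be approximated with the standard logging module. The sketch below is an assumption about the behaviour, not HanLP's implementation: build_logger_sketch is a hypothetical name, and writing to <save_dir>/<name>.log plus the console is an illustrative choice.

```python
import logging
import os


def build_logger_sketch(name: str, save_dir: str) -> logging.Logger:
    """Hypothetical approximation: log to the console and to <save_dir>/<name>.log."""
    os.makedirs(save_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid printing twice through the root logger
    for handler in list(logger.handlers):  # drop handlers left over from a previous call
        logger.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    for handler in (logging.StreamHandler(),
                    logging.FileHandler(os.path.join(save_dir, f'{name}.log'), encoding='utf-8')):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

Removing old handlers matters because logging.getLogger returns the same object for the same name, so repeated calls would otherwise duplicate every message.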

abstract build_metric(**kwargs)

Implement this to build metric(s).

Parameters

**kwargs – The subclass decides the method signature.

abstract build_model(training=True, **kwargs) → torch.nn.modules.module.Module

Build model.

Parameters
  • training – True if called during training.

  • **kwargs – **self.config.

abstract build_optimizer(**kwargs)

Implement this method to build an optimizer.

Parameters

**kwargs – The subclass decides the method signature.

build_vocabs(trn: torch.utils.data.dataset.Dataset, logger: logging.Logger)

Override this method to build vocabs.

Parameters
  • trn – Training set.

  • logger – Logger for reporting progress.
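What "building vocabs" typically means can be shown with a frequency-based token-to-id mapping. This is a hypothetical sketch (build_vocab_sketch, min_freq and the <pad>/<unk> specials are assumptions, not HanLP's Vocab class); it treats each training sample as a list of tokens.

```python
import logging
from collections import Counter
from typing import Dict, Iterable, List, Optional


def build_vocab_sketch(trn: Iterable[List[str]], logger: Optional[logging.Logger] = None,
                       min_freq: int = 1, specials=('<pad>', '<unk>')) -> Dict[str, int]:
    """Hypothetical: map every token seen at least min_freq times in the training set to an id."""
    counter = Counter(token for sample in trn for token in sample)
    vocab = {token: idx for idx, token in enumerate(specials)}  # reserved ids come first
    # Most frequent tokens get the smallest ids; ties are broken by string order.
    for token, freq in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and token not in vocab:
            vocab[token] = len(vocab)
    if logger:
        logger.info('Built a vocab of %d entries from %d distinct tokens', len(vocab), len(counter))
    return vocab


print(build_vocab_sketch([['I', 'like', 'NLP'], ['I', 'like', 'HanLP']]))
# {'<pad>': 0, '<unk>': 1, 'I': 2, 'like': 3, 'HanLP': 4, 'NLP': 5}
```

Only the training set is scanned, so tokens that appear solely in dev or test data fall back to <unk>.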

property device

The first device this component lives on.

property devices

The devices this component lives on.

evaluate(tst_data, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs)

Evaluate test set.

Parameters
  • tst_data – Test set, which is usually a file path.

  • save_dir – The directory to save evaluation scores or predictions.

  • logger – Logger for reporting progress.

  • batch_size – Batch size for test dataloader.

  • output – Whether to save outputs into some file.

  • **kwargs – Not used.

Returns

(metric, outputs) where outputs are the return values of evaluate_dataloader.

abstract evaluate_dataloader(data: torch.utils.data.dataloader.DataLoader, criterion: Callable, metric=None, output=False, **kwargs)

Evaluate on a dataloader.

Parameters
  • data – Dataloader, which can be built from any data source.

  • criterion – Loss function.

  • metric – Metric(s).

  • output – Whether to save outputs into some file.

  • **kwargs – Not used.

abstract execute_training_loop(trn: torch.utils.data.dataloader.DataLoader, dev: torch.utils.data.dataloader.DataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, ratio_width=None, **kwargs)

Implement this to run training loop.

Parameters
  • trn – Training set.

  • dev – Development set.

  • epochs – Number of epochs.

  • criterion – Loss function.

  • optimizer – Optimizer(s).

  • metric – Metric(s).

  • save_dir – The directory to save this component.

  • logger – Logger for reporting progress.

  • devices – Devices this component and dataloader will live on.

  • ratio_width – The width, in number of characters, of the dataset size. Used by the logger to align messages.

  • **kwargs – Other hyper-parameters passed from sub-class.
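A typical training loop and the ratio_width alignment trick can be sketched as follows. All names here are hypothetical, and the loop is simplified: instead of a criterion, optimizer and metric, it takes two callables standing in for fit_dataloader and evaluate_dataloader, plus a save callback that is invoked whenever the dev score improves. ratio_width is computed as the length of the widest "n/n" progress string.

```python
from typing import Callable, Sequence


def ratio_width_of(n: int) -> int:
    """Width of the widest progress string 'n/n', e.g. 1000 samples -> '1000/1000' -> 9."""
    return len(f'{n}/{n}')


def progress(i: int, n: int, ratio_width: int) -> str:
    """Right-align 'i/n' so that consecutive log lines stay aligned."""
    return f'{i}/{n}'.rjust(ratio_width)


def training_loop_sketch(trn: Sequence, dev: Sequence, epochs: int,
                         fit_dataloader: Callable[[Sequence], float],
                         evaluate_dataloader: Callable[[Sequence], float],
                         save: Callable[[], None]) -> float:
    """Hypothetical loop: train one epoch, score the dev set, keep the best checkpoint."""
    best = float('-inf')
    for epoch in range(1, epochs + 1):
        loss = fit_dataloader(trn)
        score = evaluate_dataloader(dev)
        if score > best:  # only save when the dev score improves
            best = score
            save()
        print(f'Epoch {epoch}/{epochs} loss: {loss:.4f} dev: {score:.4f}')
    return best


print(progress(5, 1000, ratio_width_of(1000)))  # prints "   5/1000"
```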

fit(trn_data, dev_data, save_dir, batch_size, epochs, devices=None, logger=None, seed=None, finetune: Union[bool, str] = False, eval_trn=True, _device_placeholder=False, **kwargs)

Fit to data, which triggers the training procedure. The training set and dev set shall be local or remote files.

Parameters
  • trn_data – Training set.

  • dev_data – Development set.

  • save_dir – The directory to save trained component.

  • batch_size – The number of samples in a batch.

  • epochs – Number of epochs.

  • devices – Devices this component will live on.

  • logger – Any logging.Logger instance.

  • seed – Random seed to reproduce this training.

  • finetune – True to load from save_dir instead of creating a randomly initialized component; a str specifies a different save_dir to load from.

  • eval_trn – Evaluate training set after each update. This can slow down the training but provides a quick diagnostic for debugging.

  • _device_placeholder – True to create a placeholder tensor which triggers PyTorch to occupy devices so other components won’t take these devices as first choices.

  • **kwargs – Hyperparameters used by sub-classes.

Returns

Any results sub-classes would like to return. Usually the best metrics on training set.
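The three cases of the finetune parameter can be summarised by a small hypothetical helper (resolve_finetune_dir is not a HanLP function) that returns the directory fit would load from, or None when training starts from a randomly initialized component.

```python
from typing import Optional, Union


def resolve_finetune_dir(finetune: Union[bool, str], save_dir: str) -> Optional[str]:
    """Hypothetical helper: which directory, if any, fit should load the component from."""
    if isinstance(finetune, str):
        return finetune  # a str names a different directory to load from
    if finetune:
        return save_dir  # True: continue from the component stored in save_dir
    return None  # False: start from a randomly initialized component


print(resolve_finetune_dir(True, 'data/model'))  # data/model
```

The str check comes first so that a non-empty path is never mistaken for a plain True.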

abstract fit_dataloader(trn: torch.utils.data.dataloader.DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs)

Fit onto a dataloader.

Parameters
  • trn – Training set.

  • criterion – Loss function.

  • optimizer – Optimizer.

  • metric – Metric(s).

  • logger – Logger for reporting progress.

  • **kwargs – Other hyper-parameters passed from sub-class.

load(save_dir: str, devices=None, verbose=False, **kwargs)

Load from a local/remote component.

Parameters
  • save_dir – An identifier which can be a local path or a remote URL or a pre-defined string.

  • devices – The devices this component will be moved onto.

  • verbose – True to log loading progress.

  • **kwargs – To override some configs.

load_config(save_dir, filename='config.json', **kwargs)

Load config from a directory.

Parameters
  • save_dir – The directory to load config.

  • filename – A file name for config.

  • **kwargs – K-V pairs to override config.
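The override semantics of load_config can be illustrated with a JSON round trip. load_config_sketch is a hypothetical stand-in, and the assumption that the config is a flat JSON object is made only for this example: values passed as keyword arguments replace those read from disk, and new keys are added.

```python
import json
import os
import tempfile
from typing import Any, Dict


def load_config_sketch(save_dir: str, filename: str = 'config.json', **kwargs) -> Dict[str, Any]:
    """Hypothetical: read a JSON config, then let caller-supplied K-V pairs override it."""
    with open(os.path.join(save_dir, filename), encoding='utf-8') as f:
        config = json.load(f)
    config.update(kwargs)  # overrides win over values stored on disk
    return config


with tempfile.TemporaryDirectory() as save_dir:
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump({'batch_size': 32, 'lr': 0.001}, f)
    print(load_config_sketch(save_dir, batch_size=8))  # {'batch_size': 8, 'lr': 0.001}
```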

load_vocabs(save_dir, filename='vocabs.json')

Load vocabularies from a directory.

Parameters
  • save_dir – The directory to load vocabularies.

  • filename – The name for vocabularies.

load_weights(save_dir, filename='model.pt', **kwargs)

Load weights from a directory.

Parameters
  • save_dir – The directory to load weights from.

  • filename – A file name for weights.

  • **kwargs – Not used.

property model_

The actual model when it is wrapped by a DataParallel.

Returns

The “real” model.

on_config_ready(**kwargs)

Called when the config is ready, either during fit or load. Subclasses can perform extra initialization tasks in this callback.

Parameters

**kwargs – Not used.

abstract predict(data: Union[str, List[str]], batch_size: int = None, **kwargs)

Predict on data fed by the user. Users should avoid calling this method directly, since it is not guarded by torch.no_grad and will introduce unnecessary gradient computation. Use __call__ instead.

Parameters
  • data – Sentences or tokens.

  • batch_size – Decoding batch size.

  • **kwargs – Used in sub-classes.
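A common shape for predict is "normalize the input, decode in mini-batches, restore the input shape". The sketch below is hypothetical: predict_sketch and decode_batch (a stand-in for running the model) are illustrative names, and it assumes a List[str] means several sentences, whereas in a real component a list of strings may also be the tokens of a single sentence. In a PyTorch component this logic would run inside __call__ under torch.no_grad.

```python
from typing import Callable, List, Optional, Union


def predict_sketch(data: Union[str, List[str]], decode_batch: Callable[[List[str]], list],
                   batch_size: Optional[int] = None):
    """Hypothetical: decode in mini-batches; a single str in gives a single result out."""
    flat = isinstance(data, str)
    samples = [data] if flat else list(data)
    batch_size = batch_size or len(samples) or 1  # None means one batch for everything
    outputs = []
    for start in range(0, len(samples), batch_size):
        outputs.extend(decode_batch(samples[start:start + batch_size]))
    return outputs[0] if flat else outputs


print(predict_sketch(['hello world', 'hi'], lambda batch: [s.split() for s in batch], batch_size=1))
# [['hello', 'world'], ['hi']]
```

Unwrapping the single-string case keeps the call site symmetric: one sentence in, one prediction out.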

save(save_dir: str, **kwargs)

Save this component to a directory.

Parameters
  • save_dir – The directory to save this component.

  • **kwargs – Not used.

save_config(save_dir, filename='config.json')

Save config into a directory.

Parameters
  • save_dir – The directory to save config.

  • filename – A file name for config.

save_vocabs(save_dir, filename='vocabs.json')

Save vocabularies to a directory.

Parameters
  • save_dir – The directory to save vocabularies.

  • filename – The name for vocabularies.

save_weights(save_dir, filename='model.pt', trainable_only=True, **kwargs)

Save model weights to a directory.

Parameters
  • save_dir – The directory to save weights into.

  • filename – A file name for weights.

  • trainable_only – True to only save trainable weights. Useful when the model contains lots of static embeddings.

  • **kwargs – Not used for now.

to(devices=typing.Union[int, float, typing.List[int], typing.Dict[str, typing.Union[int, torch.device]]], logger: logging.Logger = None, verbose=False)

Move this component to devices.

Parameters
  • devices – Target devices.

  • logger – Logger for printing progress reports, as copying a model from CPU to GPU can take several seconds.

  • verbose – True to print progress when logger is None.