torch_component
- class hanlp.common.torch_component.TorchComponent(**kwargs)
The base class for all components that use PyTorch as their backend. It provides common workflows for building vocabs, datasets, dataloaders and models. These workflows are a conventional guideline rather than enforced protocols, which means subclasses are free to override or completely skip some steps.
- Parameters
**kwargs – Additional arguments to be stored in the config property.
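The workflow described above is essentially a template method: a driver calls a fixed sequence of build hooks that subclasses fill in. The plain-Python sketch below illustrates that idea under this assumption; MiniComponent, ToyComponent and the exact call order inside fit are hypothetical and are not HanLP's actual implementation.

```python
from abc import ABC, abstractmethod


class MiniComponent(ABC):
    """Hypothetical, stripped-down stand-in for a TorchComponent-like base class."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are kept in a config dict.
        self.config = dict(kwargs)
        self.model = None

    def fit(self, trn_data, dev_data, epochs):
        # Fixed order of steps; each one is a hook a subclass may override.
        trn = self.build_dataloader(trn_data, shuffle=True)
        dev = self.build_dataloader(dev_data, shuffle=False)
        self.model = self.build_model(training=True)
        criterion = self.build_criterion()
        optimizer = self.build_optimizer()
        metric = self.build_metric()
        return self.execute_training_loop(trn, dev, epochs, criterion, optimizer, metric)

    @abstractmethod
    def build_dataloader(self, data, shuffle=False): ...

    @abstractmethod
    def build_model(self, training=True): ...

    @abstractmethod
    def build_criterion(self): ...

    @abstractmethod
    def build_optimizer(self): ...

    @abstractmethod
    def build_metric(self): ...

    @abstractmethod
    def execute_training_loop(self, trn, dev, epochs, criterion, optimizer, metric): ...


class ToyComponent(MiniComponent):
    """Toy subclass that only records which hooks were called."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.calls = []

    def build_dataloader(self, data, shuffle=False):
        self.calls.append('build_dataloader')
        return list(data)

    def build_model(self, training=True):
        self.calls.append('build_model')
        return object()

    def build_criterion(self):
        self.calls.append('build_criterion')

    def build_optimizer(self):
        self.calls.append('build_optimizer')

    def build_metric(self):
        self.calls.append('build_metric')

    def execute_training_loop(self, trn, dev, epochs, criterion, optimizer, metric):
        self.calls.append('execute_training_loop')
        return {'epochs': epochs, 'n_trn': len(trn), 'n_dev': len(dev)}


component = ToyComponent(hidden_size=8)
result = component.fit(['a', 'b', 'c'], ['d'], epochs=2)
print(result)           # {'epochs': 2, 'n_trn': 3, 'n_dev': 1}
print(component.calls)  # build_dataloader twice, then the remaining hooks in order
```

Because the base class only fixes the order of the hooks, a subclass that needs no criterion, for example, can simply return None from that hook, which matches the "guideline, not protocol" wording above.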
- abstract build_criterion(**kwargs)
Implement this method to build the criterion (loss function).
- Parameters
**kwargs – The subclass decides the method signature.
- abstract build_dataloader(data, batch_size, shuffle=False, device=None, logger: Optional[logging.Logger] = None, **kwargs) → torch.utils.data.dataloader.DataLoader
Build a dataloader for the training, dev and test sets. It is recommended to build vocabs in this method if they have not been built yet.
- Parameters
data – Data representing samples, which can be a path or a list of samples.
batch_size – Number of samples per batch.
shuffle – Whether to shuffle this dataloader.
device – The device tensors should be loaded onto.
logger – Logger for reporting messages if the dataloader takes a long time to build or if vocabs have to be built.
**kwargs – Arguments from **self.config.
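As a rough illustration of what batch_size and shuffle control, the sketch below splits a list of samples into batches, optionally shuffling them with a fixed seed. It is a plain-Python stand-in for torch.utils.data.DataLoader; the helper name simple_batches and its seed parameter are hypothetical.

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar('T')


def simple_batches(samples: Sequence[T], batch_size: int, shuffle: bool = False,
                   seed: Optional[int] = None) -> Iterator[List[T]]:
    """Yield consecutive batches of at most batch_size samples (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError('batch_size must be positive')
    order = list(range(len(samples)))
    if shuffle:
        # A dedicated Random instance keeps shuffling reproducible without touching global state.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


batches = list(simple_batches(['s1', 's2', 's3', 's4', 's5'], batch_size=2))
print(batches)  # [['s1', 's2'], ['s3', 's4'], ['s5']]
```

Shuffling typically applies to the training set only, which is why shuffle defaults to False for dev and test data.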
- build_logger(name, save_dir)
Build a logging.Logger.
- Parameters
name – The name of this logger.
save_dir – The directory this logger should save logs into.
- Returns
A logger.
- Return type
logging.Logger
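A minimal sketch of what such a method might do with Python's standard logging module, assuming logs are written to a file named after the logger inside save_dir and echoed to the console. The file naming scheme and format string are assumptions, not HanLP's behavior.

```python
import logging
import os
import tempfile


def build_logger(name: str, save_dir: str) -> logging.Logger:
    """Return a logger writing to stderr and to <save_dir>/<name>.log (assumed layout)."""
    os.makedirs(save_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Avoid stacking duplicate handlers if the function is called twice with the same name.
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        file_handler = logging.FileHandler(os.path.join(save_dir, f'{name}.log'))
        file_handler.setFormatter(formatter)
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)
    return logger


save_dir = tempfile.mkdtemp()
logger = build_logger('train', save_dir)
logger.info('Training started')
```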
- abstract build_metric(**kwargs)
Implement this to build metric(s).
- Parameters
**kwargs – The subclass decides the method signature.
- abstract build_model(training=True, **kwargs) → torch.nn.modules.module.Module
Build the model.
- Parameters
training – True if called during training.
**kwargs – **self.config.
- abstract build_optimizer(**kwargs)
Implement this method to build an optimizer.
- Parameters
**kwargs – The subclass decides the method signature.
- build_vocabs(trn: torch.utils.data.dataset.Dataset, logger: logging.Logger)
Override this method to build vocabs.
- Parameters
trn – Training set.
logger – Logger for reporting progress.
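For a sequence tagger, building vocabs typically means mapping every token seen in the training set to an integer index. The sketch below assumes the training set is a list of token lists and reserves index 0 for padding and 1 for unknown tokens; the Vocab class and its special tokens are hypothetical and are not HanLP's own vocabulary class.

```python
from typing import Dict, Iterable, List


class Vocab:
    """Hypothetical token-to-index mapping with reserved padding and unknown tokens."""

    def __init__(self, pad: str = '<pad>', unk: str = '<unk>'):
        self.unk = unk
        self.token_to_idx: Dict[str, int] = {pad: 0, unk: 1}

    def update(self, tokens: Iterable[str]) -> None:
        # Assign the next free index to every unseen token.
        for token in tokens:
            if token not in self.token_to_idx:
                self.token_to_idx[token] = len(self.token_to_idx)

    def __call__(self, tokens: Iterable[str]) -> List[int]:
        # Tokens never seen during training fall back to the unknown index.
        unk_idx = self.token_to_idx[self.unk]
        return [self.token_to_idx.get(t, unk_idx) for t in tokens]

    def __len__(self) -> int:
        return len(self.token_to_idx)


def build_vocabs(trn: List[List[str]]) -> Vocab:
    vocab = Vocab()
    for sentence in trn:
        vocab.update(sentence)
    return vocab


vocab = build_vocabs([['I', 'love', 'NLP'], ['I', 'love', 'HanLP']])
print(len(vocab))                   # 6
print(vocab(['I', 'like', 'NLP']))  # [2, 1, 4]
```

Vocabs are built from the training set only, so dev and test tokens outside it map to the unknown index.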
- property device
The first device this component lives on.
- property devices
The devices this component lives on.
- evaluate(tst_data, save_dir=None, logger: Optional[logging.Logger] = None, batch_size=None, output=False, **kwargs)
Evaluate on a test set.
- Parameters
tst_data – Test set, which is usually a file path.
save_dir – The directory to save evaluation scores or predictions.
logger – Logger for reporting progress.
batch_size – Batch size for test dataloader.
output – Whether to save outputs into some file.
**kwargs – Not used.
- Returns
(metric, outputs) where outputs are the return values of evaluate_dataloader.
- abstract evaluate_dataloader(data: torch.utils.data.dataloader.DataLoader, criterion: Callable, metric=None, output=False, **kwargs)
Evaluate on a dataloader.
- Parameters
data – A dataloader, which can be built from any data source.
criterion – Loss function.
metric – Metric(s).
output – Whether to save outputs into some file.
**kwargs – Not used.
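Below is a framework-free sketch of one common shape for such an evaluation loop: accumulate the loss from criterion, update a metric, and optionally collect predictions. Batches are assumed to be lists of (prediction, gold) pairs and criterion a plain function; the batch layout, the Accuracy class and zero_one_loss are hypothetical.

```python
from typing import Callable, List, Optional, Sequence, Tuple


class Accuracy:
    """Hypothetical metric: fraction of correct predictions."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, pred, gold) -> None:
        self.correct += int(pred == gold)
        self.total += 1

    @property
    def score(self) -> float:
        return self.correct / self.total if self.total else 0.0


def evaluate_dataloader(data: Sequence[List[Tuple[str, str]]], criterion: Callable,
                        metric: Optional[Accuracy] = None, output: bool = False):
    """Return (average batch loss, collected predictions or None)."""
    total_loss, num_batches, outputs = 0.0, 0, []
    for batch in data:
        preds = [p for p, _ in batch]
        golds = [g for _, g in batch]
        total_loss += criterion(preds, golds)
        num_batches += 1
        if metric is not None:
            for p, g in zip(preds, golds):
                metric.update(p, g)
        if output:
            outputs.extend(preds)
    return total_loss / max(num_batches, 1), (outputs if output else None)


def zero_one_loss(preds, golds):
    # Fraction of mistakes in a batch.
    return sum(p != g for p, g in zip(preds, golds)) / len(golds)


dev_batches = [[('NN', 'NN'), ('VB', 'NN')], [('DT', 'DT')]]
accuracy = Accuracy()
loss, outputs = evaluate_dataloader(dev_batches, zero_one_loss, accuracy, output=True)
print(loss, outputs)  # 0.25 ['NN', 'VB', 'DT']
```

The metric object is updated in place, so a caller such as evaluate can return it alongside whatever evaluate_dataloader produced.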
- abstract execute_training_loop(trn: torch.utils.data.dataloader.DataLoader, dev: torch.utils.data.dataloader.DataLoader, epochs, criterion, optimizer, metric, save_dir, logger: logging.Logger, devices, ratio_width=None, **kwargs)
Implement this to run the training loop.
- Parameters
trn – Training set.
dev – Development set.
epochs – Number of epochs.
criterion – Loss function.
optimizer – Optimizer(s).
metric – Metric(s).
save_dir – The directory to save this component.
logger – Logger for reporting progress.
devices – Devices this component and dataloader will live on.
ratio_width – The width of the dataset size, measured in number of characters. Used by the logger to align messages.
**kwargs – Other hyper-parameters passed from subclasses.
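The sketch below shows one common shape for a training loop: run every batch, evaluate on the dev set after each epoch, and save whenever the dev score improves. It also shows how a ratio_width derived from the dataset size can right-align progress counters such as " 1/12". The callables train_step, evaluate and save are hypothetical stand-ins for fit_dataloader, evaluate_dataloader and save_weights; this is not HanLP's implementation.

```python
from typing import Any, Callable, List


def execute_training_loop(trn: List[Any], dev: List[Any], epochs: int,
                          train_step: Callable[[Any], float],
                          evaluate: Callable[[List[Any]], float],
                          save: Callable[[int], None],
                          log: Callable[[str], None] = print) -> dict:
    """Train for `epochs` epochs and keep the epoch with the best dev score (higher is better)."""
    # Width of the widest counter, e.g. len("12/12") == 5.
    ratio_width = len(f'{len(trn)}/{len(trn)}')
    best_score, best_epoch = float('-inf'), 0
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for i, batch in enumerate(trn, start=1):
            total_loss += train_step(batch)
            # Right-align "i/n" so every progress line has the same width.
            log(f'{i}/{len(trn)}'.rjust(ratio_width) + f' loss: {total_loss / i:.4f}')
        dev_score = evaluate(dev)
        log(f'epoch {epoch}/{epochs} dev score: {dev_score:.4f}')
        if dev_score > best_score:
            best_score, best_epoch = dev_score, epoch
            save(epoch)  # e.g. persist the weights of the best model so far
    return {'best_epoch': best_epoch, 'best_score': best_score}


scores = iter([0.70, 0.80, 0.75])
saved, lines = [], []
result = execute_training_loop(
    trn=[[1, 2]] * 12, dev=[[3]], epochs=3,
    train_step=lambda batch: 0.5,
    evaluate=lambda dev: next(scores),
    save=saved.append,
    log=lines.append,
)
print(result)  # {'best_epoch': 2, 'best_score': 0.8}
```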
- fit(trn_data, dev_data, save_dir, batch_size, epochs, devices=None, logger=None, seed=None, finetune: Union[bool, str] = False, eval_trn=True, _device_placeholder=False, **kwargs)
Fit to data, triggering the training procedure. The training and dev sets shall be local or remote files.
- Parameters
trn_data – Training set.
dev_data – Development set.
save_dir – The directory to save trained component.
batch_size – The number of samples in a batch.
epochs – Number of epochs.
devices – Devices this component will live on.
logger – Any logging.Logger instance.
seed – Random seed to reproduce this training.
finetune – True to load from save_dir instead of creating a randomly initialized component, or a str to specify a different save_dir to load from.
eval_trn – Evaluate the training set after each update. This can slow down training but provides a quick diagnostic for debugging.
_device_placeholder – True to create a placeholder tensor that makes PyTorch occupy the devices, so that other components won’t take these devices as their first choices.
**kwargs – Hyperparameters used by subclasses.
- Returns
Any results subclasses would like to return, usually the best metrics on the training set.
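Two of the parameters above lend themselves to a small sketch. Reproducing a run means seeding every random number generator in play (Python's random, plus NumPy and PyTorch when they are installed), and finetune, being either a bool or a str, has to be resolved into a directory to load from, or None for a fresh start. Both helpers, set_seed and resolve_finetune_dir, are hypothetical illustrations rather than HanLP functions.

```python
import random
from typing import Optional, Union


def set_seed(seed: int) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch if they are importable."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds PyTorch's random number generators
    except ImportError:
        pass


def resolve_finetune_dir(finetune: Union[bool, str], save_dir: str) -> Optional[str]:
    """Map the finetune argument to the directory to load from, or None to train from scratch."""
    if isinstance(finetune, str):
        return finetune  # a different directory to load from
    if finetune:
        return save_dir  # continue from the component already in save_dir
    return None


set_seed(42)
first = [random.random() for _ in range(3)]
set_seed(42)
assert first == [random.random() for _ in range(3)]  # same seed, same draws
print(resolve_finetune_dir(True, 'data/model'))  # data/model
```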
- abstract fit_dataloader(trn: torch.utils.data.dataloader.DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs)
Fit onto a dataloader.
- Parameters
trn – Training set.
criterion – Loss function.
optimizer – Optimizer.
metric – Metric(s).
logger – Logger for reporting progress.
**kwargs – Other hyper-parameters passed from subclasses.
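To make the roles of criterion and optimizer concrete without PyTorch, the sketch below runs one epoch of plain stochastic gradient descent on a one-parameter model y = w·x with a squared-error criterion. ToyModel, SGD, mse and the hand-derived gradient are illustrative assumptions; a PyTorch implementation would call loss.backward() and a torch.optim optimizer instead, and the metric argument is omitted here for brevity.

```python
from typing import Callable, List, Tuple

Batch = List[Tuple[float, float]]


class ToyModel:
    """One-parameter model y = w * x (hypothetical stand-in for a torch.nn.Module)."""

    def __init__(self, w: float = 0.0):
        self.w = w

    def __call__(self, x: float) -> float:
        return self.w * x


class SGD:
    """Minimal optimizer applying w <- w - lr * grad."""

    def __init__(self, model: ToyModel, lr: float):
        self.model, self.lr = model, lr

    def step(self, grad: float) -> None:
        self.model.w -= self.lr * grad


def mse(model: ToyModel, batch: Batch) -> float:
    """Criterion: mean squared error over a batch."""
    return sum((model(x) - y) ** 2 for x, y in batch) / len(batch)


def fit_dataloader(trn: List[Batch], model: ToyModel,
                   criterion: Callable[[ToyModel, Batch], float], optimizer: SGD) -> float:
    """Run one epoch over the batches and return the mean batch loss."""
    total = 0.0
    for batch in trn:
        total += criterion(model, batch)
        # Hand-derived gradient of the MSE w.r.t. w: mean of 2 * (w*x - y) * x.
        grad = sum(2 * (model(x) - y) * x for x, y in batch) / len(batch)
        optimizer.step(grad)
    return total / len(trn)


model = ToyModel()
optimizer = SGD(model, lr=0.05)
trn = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # samples of y = 2x
for epoch in range(50):
    loss = fit_dataloader(trn, model, mse, optimizer)
print(round(model.w, 4))  # 2.0
```

Each call to fit_dataloader corresponds to one epoch; repeating it across epochs, with evaluation in between, is the job of execute_training_loop.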
- load(save_dir: str, devices=None, verbose=False, **kwargs)
Load from a local/remote component.
- Parameters
save_dir – An identifier which can be a local path or a remote URL or a pre-defined string.
devices – The devices this component will be moved onto.
verbose – True to log loading progress.
**kwargs – To override some configs.
- load_config(save_dir, filename='config.json', **kwargs)
Load config from a directory.
- Parameters
save_dir – The directory to load config.
filename – A file name for config.
**kwargs – K-V pairs to override config.
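Saving and loading a config can be pictured as a JSON round trip in which keyword arguments override stored values at load time. The sketch below assumes the config is a flat, JSON-serializable dict; the two functions mirror the names of the methods documented here but are standalone, hypothetical helpers.

```python
import json
import os
import tempfile


def save_config(config: dict, save_dir: str, filename: str = 'config.json') -> None:
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, filename), 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=2)


def load_config(save_dir: str, filename: str = 'config.json', **kwargs) -> dict:
    with open(os.path.join(save_dir, filename), encoding='utf-8') as f:
        config = json.load(f)
    # Keyword arguments take precedence over the stored values.
    config.update(kwargs)
    return config


save_dir = tempfile.mkdtemp()
save_config({'batch_size': 32, 'lr': 1e-3}, save_dir)
print(load_config(save_dir, batch_size=8))  # {'batch_size': 8, 'lr': 0.001}
```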
- load_vocabs(save_dir, filename='vocabs.json')
Load vocabularies from a directory.
- Parameters
save_dir – The directory to load vocabularies.
filename – The name for vocabularies.
- load_weights(save_dir, filename='model.pt', **kwargs)
Load weights from a directory.
- Parameters
save_dir – The directory to load weights from.
filename – A file name for weights.
**kwargs – Not used.
- property model_: torch.nn.modules.module.Module
The actual model when it is wrapped by a DataParallel.
- Returns
The “real” model.
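Unwrapping usually amounts to returning the wrapper's .module attribute, which is where torch.nn.DataParallel keeps the wrapped model. To stay runnable without PyTorch, the sketch uses a hypothetical FakeDataParallel class in place of torch.nn.DataParallel; with PyTorch installed, the isinstance check would target torch.nn.DataParallel instead.

```python
class FakeDataParallel:
    """Stand-in for torch.nn.DataParallel, which stores the wrapped model as `.module`."""

    def __init__(self, module):
        self.module = module


class Component:
    def __init__(self, model):
        self.model = model

    @property
    def model_(self):
        # Return the "real" model whether or not it is wrapped.
        if isinstance(self.model, FakeDataParallel):
            return self.model.module
        return self.model


net = object()
assert Component(net).model_ is net
assert Component(FakeDataParallel(net)).model_ is net
```

Going through model_ keeps code such as weight saving independent of whether multi-device wrapping was applied.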
- on_config_ready(**kwargs)
Called when the config is ready, either during fit or load. Subclasses can perform extra initialization tasks in this callback.
- Parameters
**kwargs – Not used.
- abstract predict(*args, **kwargs)
Predict on data fed by users. Users should avoid calling this method directly, since it is not guarded with torch.no_grad and will introduce unnecessary gradient computation. Use __call__ instead.
- Parameters
*args – Sentences or tokens.
**kwargs – Used in subclasses.
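The advice to prefer __call__ follows a common pattern: __call__ wraps predict in torch.no_grad() so that no autograd graph is recorded. The sketch below mimics that pattern with a hypothetical grad-mode flag and a contextlib stand-in for torch.no_grad, so it runs without PyTorch; the class and flag names are illustrative.

```python
import contextlib

GRAD_ENABLED = True  # hypothetical global standing in for PyTorch's grad mode


@contextlib.contextmanager
def no_grad():
    """Stand-in for torch.no_grad(): disable gradient tracking inside the block."""
    global GRAD_ENABLED
    previous, GRAD_ENABLED = GRAD_ENABLED, False
    try:
        yield
    finally:
        GRAD_ENABLED = previous  # restore the previous mode even if predict raises


class Component:
    def predict(self, *args, **kwargs):
        # Report whether gradients would have been tracked during prediction.
        return {'inputs': args, 'grad_enabled': GRAD_ENABLED}

    def __call__(self, *args, **kwargs):
        with no_grad():
            return self.predict(*args, **kwargs)


tagger = Component()
print(tagger.predict('sentence')['grad_enabled'])  # True: direct call is unguarded
print(tagger('sentence')['grad_enabled'])          # False: __call__ disables gradients
```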
- save(save_dir: str, **kwargs)
Save this component to a directory.
- Parameters
save_dir – The directory to save this component.
**kwargs – Not used.
- save_config(save_dir, filename='config.json')
Save config into a directory.
- Parameters
save_dir – The directory to save config.
filename – A file name for config.
- save_vocabs(save_dir, filename='vocabs.json')
Save vocabularies to a directory.
- Parameters
save_dir – The directory to save vocabularies.
filename – The name for vocabularies.
- save_weights(save_dir, filename='model.pt', trainable_only=True, **kwargs)
Save model weights to a directory.
- Parameters
save_dir – The directory to save weights into.
filename – A file name for weights.
trainable_only – True to only save trainable weights. Useful when the model contains lots of static embeddings.
**kwargs – Not used for now.
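With trainable_only=True, frozen parameters (for instance, pretrained static embeddings) are left out of the saved file, which can shrink it considerably. The sketch below models parameters as records carrying a requires_grad flag in plain Python; in PyTorch the equivalent check would use each parameter's requires_grad attribute. The Param class and weights_to_save helper are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Param:
    value: List[float]
    requires_grad: bool


def weights_to_save(params: Dict[str, Param], trainable_only: bool = True) -> Dict[str, List[float]]:
    """Select which parameter values would be written to disk."""
    return {name: p.value for name, p in params.items()
            if p.requires_grad or not trainable_only}


params = {
    'embed.weight': Param([0.1] * 1000, requires_grad=False),  # frozen static embeddings
    'classifier.weight': Param([0.5, -0.2], requires_grad=True),
}
print(list(weights_to_save(params)))                        # ['classifier.weight']
print(list(weights_to_save(params, trainable_only=False)))  # ['embed.weight', 'classifier.weight']
```

When loading such a file, the frozen parameters have to come from elsewhere (for example, the original embedding file), since they were never saved.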
- to(devices: Optional[Union[int, float, List[int], Dict[str, Union[int, torch.device]]]] = None, logger: Optional[logging.Logger] = None, verbose=False)
Move this component to devices.
- Parameters
devices – Target devices.
logger – Logger for printing a progress report, as copying a model from CPU to GPU can take several seconds.
verbose – True to print progress when logger is None.
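The device and devices properties suggest that a component keeps an ordered list of devices and treats the first one as its primary device. The sketch below normalizes only the None, int and List[int] forms of the devices argument into such a list; the float and Dict forms accepted by to are deliberately not modeled, and the Placement class and its handling of None are hypothetical.

```python
from typing import List, Optional, Union


class Placement:
    """Hypothetical holder for the devices a component lives on."""

    def __init__(self, devices: Optional[Union[int, List[int]]] = None):
        if devices is None:
            self.devices: List[int] = []  # no device chosen in this sketch
        elif isinstance(devices, int):
            self.devices = [devices]
        elif isinstance(devices, list) and all(isinstance(d, int) for d in devices):
            self.devices = list(devices)
        else:
            raise TypeError('only None, int or List[int] are modeled in this sketch')

    @property
    def device(self) -> Optional[int]:
        # The first device in the list is treated as the primary one.
        return self.devices[0] if self.devices else None


print(Placement([1, 0]).device)  # 1
print(Placement(2).devices)      # [2]
```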