# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-05-08 21:20
import logging
import os
import re
import time
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Union, Callable
import torch
from torch import nn
from torch.utils.data import DataLoader
import hanlp
from hanlp.common.component import Component
from hanlp.common.dataset import TransformableDataset
from hanlp.common.transform import VocabDict
from hanlp.utils.io_util import get_resource, basename_no_ext
from hanlp.utils.log_util import init_logger, flash
from hanlp.utils.torch_util import cuda_devices, set_seed
from hanlp_common.configurable import Configurable
from hanlp_common.constant import IDX, HANLP_VERBOSE
from hanlp_common.reflection import classpath_of
from hanlp_common.structure import SerializableDict
from hanlp_common.util import merge_dict, isdebugging
class TorchComponent(Component, ABC):
def __init__(self, **kwargs) -> None:
"""The base class for all components using PyTorch as backend. It provides common workflows of building vocabs,
datasets, dataloaders and models. These workflows are more of a conventional guideline than en-forced
protocols, which means subclass has the freedom to override or completely skip some steps.
Args:
**kwargs: Additional arguments to be stored in the ``config`` property.
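Example:
    A minimal sketch of the typical lifecycle. ``MyTagger`` and all paths below are
    hypothetical placeholders; any concrete subclass follows the same pattern::

        tagger = MyTagger()  # a subclass implementing the abstract build_* methods
        tagger.fit('data/trn.tsv', 'data/dev.tsv', 'data/model/my_tagger',
                   batch_size=32, epochs=10)
        tagger.load('data/model/my_tagger')  # restores config, vocabs and weights
        tagger(['An', 'example'])  # __call__ wraps predict() with torch.no_grad()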
"""
super().__init__()
self.model: Optional[torch.nn.Module] = None
self.config = SerializableDict(**kwargs)
self.vocabs = VocabDict()
def _capture_config(self, locals_: Dict,
exclude=(
'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
'dev_batch_size', '__class__', 'devices', 'eval_trn')):
"""Save arguments to config
Args:
locals_: Dict:
exclude: (Default value = ('trn_data')
'dev_data':
'save_dir':
'kwargs':
'self':
'logger':
'verbose':
'dev_batch_size':
'__class__':
'devices'):
Returns:
"""
if 'kwargs' in locals_:
locals_.update(locals_['kwargs'])
locals_ = dict((k, v) for k, v in locals_.items() if k not in exclude and not k.startswith('_'))
self.config.update(locals_)
return self.config
def save_weights(self, save_dir, filename='model.pt', trainable_only=True, **kwargs):
"""Save model weights to a directory.
Args:
save_dir: The directory to save weights into.
filename: A file name for weights.
trainable_only: ``True`` to only save trainable weights. Useful when the model contains lots of static
embeddings.
**kwargs: Not used for now.
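Example:
    A hedged sketch (``tagger`` denotes a fitted subclass instance; the path is a
    hypothetical, existing directory)::

        tagger.save_weights('data/model/my_tagger')  # writes data/model/my_tagger/model.pt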
"""
model = self.model_
state_dict = model.state_dict()
if trainable_only:
trainable_names = set(n for n, p in model.named_parameters() if p.requires_grad)
state_dict = dict((n, p) for n, p in state_dict.items() if n in trainable_names)
torch.save(state_dict, os.path.join(save_dir, filename))
def load_weights(self, save_dir, filename='model.pt', **kwargs):
"""Load weights from a directory.
Args:
save_dir: The directory to load weights from.
filename: A file name for weights.
**kwargs: Not used.
"""
save_dir = get_resource(save_dir)
filename = os.path.join(save_dir, filename)
# flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
try:
self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True), strict=False)
except TypeError:
self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
# flash('')
def save_config(self, save_dir, filename='config.json'):
"""Save config into a directory.
Args:
save_dir: The directory to save config.
filename: A file name for config.
"""
self._savable_config.save_json(os.path.join(save_dir, filename))
def load_config(self, save_dir, filename='config.json', **kwargs):
"""Load config from a directory.
Args:
save_dir: The directory to load config from.
filename: A file name for config.
**kwargs: K-V pairs to override config.
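Example:
    A hedged sketch (path hypothetical); keyword overrides take precedence over the
    values loaded from disk::

        tagger.load_config('data/model/my_tagger', batch_size=64)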
"""
save_dir = get_resource(save_dir)
self.config.load_json(os.path.join(save_dir, filename))
self.config.update(kwargs) # overwrite config loaded from disk
for k, v in self.config.items():
if isinstance(v, dict) and 'classpath' in v:
self.config[k] = Configurable.from_config(v)
self.on_config_ready(**self.config, save_dir=save_dir)
def save_vocabs(self, save_dir, filename='vocabs.json'):
"""Save vocabularies to a directory.
Args:
save_dir: The directory to save vocabularies.
filename: The name for vocabularies.
"""
if hasattr(self, 'vocabs'):
self.vocabs.save_vocabs(save_dir, filename)
def load_vocabs(self, save_dir, filename='vocabs.json'):
"""Load vocabularies from a directory.
Args:
save_dir: The directory to load vocabularies from.
filename: The name for vocabularies.
"""
if hasattr(self, 'vocabs'):
self.vocabs = VocabDict()
self.vocabs.load_vocabs(save_dir, filename)
def save(self, save_dir: str, **kwargs):
"""Save this component to a directory.
Args:
save_dir: The directory to save this component.
**kwargs: Not used.
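Example:
    A hedged sketch (``tagger`` denotes a fitted subclass instance; the path is hypothetical)::

        tagger.save('data/model/my_tagger')  # writes config.json, vocabs.json and model.pt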
"""
self.save_config(save_dir)
self.save_vocabs(save_dir)
self.save_weights(save_dir)
def load(self, save_dir: str, devices=None, verbose=HANLP_VERBOSE, **kwargs):
"""Load from a local/remote component.
Args:
save_dir: An identifier which can be a local path or a remote URL or a pre-defined string.
devices: The devices this component will be moved onto.
verbose: ``True`` to log loading progress.
**kwargs: To override some configs.
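Example:
    A hedged sketch (the path is hypothetical); ``save_dir`` may also be a URL or a
    pre-defined identifier resolvable by ``get_resource``::

        tagger.load('data/model/my_tagger', devices=-1)  # -1 keeps the component on CPU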
"""
save_dir = get_resource(save_dir)
# flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
if devices is None and self.model:
devices = self.devices
self.load_config(save_dir, **kwargs)
self.load_vocabs(save_dir)
if verbose:
flash('Building model [blink][yellow]...[/yellow][/blink]')
self.config.pop('training', None) # Some legacy versions accidentally put training into config file
self.model = self.build_model(
**merge_dict(self.config, **kwargs, overwrite=True, inplace=True), training=False, save_dir=save_dir)
if verbose:
flash('')
self.load_weights(save_dir, **kwargs)
self.to(devices, verbose=verbose)
self.model.eval()
def fit(self,
trn_data,
dev_data,
save_dir,
batch_size,
epochs,
devices=None,
logger=None,
seed=None,
finetune: Union[bool, str] = False,
eval_trn=True,
_device_placeholder=False,
**kwargs):
"""Fit to data, triggers the training procedure. For training set and dev set, they shall be local or remote
files.
Args:
trn_data: Training set.
dev_data: Development set.
save_dir: The directory to save trained component.
batch_size: The number of samples in a batch.
epochs: Number of epochs.
devices: Devices this component will live on.
logger: Any :class:`logging.Logger` instance.
seed: Random seed to reproduce this training.
finetune: ``True`` to load from ``save_dir`` instead of creating a randomly initialized component. ``str``
to specify a different ``save_dir`` to load from.
eval_trn: Evaluate training set after each update. This can slow down the training but provides a quick
diagnostic for debugging.
_device_placeholder: ``True`` to create a placeholder tensor which triggers PyTorch to occupy devices so
other components won't take these devices as first choices.
**kwargs: Hyperparameters used by sub-classes.
Returns:
Any results sub-classes would like to return. Usually the best metrics on training set.
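Example:
    A hedged sketch with hypothetical file paths and hyper-parameters (concrete
    subclasses usually accept extra keyword arguments)::

        tagger.fit('data/trn.tsv', 'data/dev.tsv', 'data/model/my_tagger',
                   batch_size=32, epochs=10)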
"""
# Common initialization steps
config = self._capture_config(locals())
if not logger:
logger = self.build_logger('train', save_dir)
if seed is None:
self.config.seed = 233 if isdebugging() else int(time.time())
set_seed(self.config.seed)
logger.info(self._savable_config.to_json(sort=True))
if isinstance(devices, list) or devices is None or isinstance(devices, float):
flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
devices = -1 if isdebugging() else cuda_devices(devices)
flash('')
# flash(f'Available GPUs: {devices}')
if isinstance(devices, list):
first_device = (devices[0] if devices else -1)
elif isinstance(devices, dict):
first_device = next(iter(devices.values()))
elif isinstance(devices, int):
first_device = devices
else:
first_device = -1
if _device_placeholder and first_device >= 0:
_dummy_placeholder = self._create_dummy_placeholder_on(first_device)
if finetune:
if isinstance(finetune, str):
self.load(finetune, devices=devices)
else:
self.load(save_dir, devices=devices)
self.config.finetune = finetune
self.vocabs.unlock() # For extending vocabs
logger.info(
f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
self.on_config_ready(**self.config, save_dir=save_dir)
trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True,
training=True, device=first_device, logger=logger, vocabs=self.vocabs,
overwrite=True))
dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
training=None, device=first_device, logger=logger, vocabs=self.vocabs,
overwrite=True)) if dev_data else None
flash('[yellow]Building model [blink]...[/blink][/yellow]')
self.model = self.build_model(**merge_dict(config, training=True), logger=logger)
flash('')
logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
assert self.model, 'build_model is not properly implemented.'
_description = repr(self.model)
if len(_description.split('\n')) < 10:
logger.info(_description)
self.save_config(save_dir)
self.save_vocabs(save_dir)
self.to(devices, logger)
if _device_placeholder and first_device >= 0:
del _dummy_placeholder
criterion = self.build_criterion(**merge_dict(config, trn=trn))
optimizer = self.build_optimizer(**merge_dict(config, trn=trn, criterion=criterion))
metric = self.build_metric(**self.config)
if hasattr(trn, 'dataset') and dev and hasattr(dev, 'dataset'):
if trn.dataset and dev.dataset:
logger.info(f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
if hasattr(trn, '__len__') and dev and hasattr(dev, '__len__'):
trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
ratio_width = len(f'{trn_size}/{trn_size}')
else:
ratio_width = None
return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs, criterion=criterion,
optimizer=optimizer, metric=metric, logger=logger,
save_dir=save_dir,
devices=devices,
ratio_width=ratio_width,
trn_data=trn_data,
dev_data=dev_data,
eval_trn=eval_trn,
overwrite=True))
def build_logger(self, name, save_dir):
"""Build a :class:`logging.Logger`.
Args:
name: The name of this logger.
save_dir: The directory this logger should save logs into.
Returns:
logging.Logger: A logger.
"""
logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO, fmt="%(message)s")
return logger
@abstractmethod
def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None,
**kwargs) -> DataLoader:
"""Build dataloader for training, dev and test sets. It's suggested to build vocabs in this method if they are
not built yet.
Args:
data: Data representing samples, which can be a path or a list of samples.
batch_size: Number of samples per batch.
shuffle: Whether to shuffle this dataloader.
device: Device tensors should be loaded onto.
logger: Logger for reporting messages if building the dataloader takes a long time or if vocabs have to be built.
**kwargs: Arguments from ``**self.config``.
"""
pass
def build_vocabs(self, trn: torch.utils.data.Dataset, logger: logging.Logger):
"""Override this method to build vocabs.
Args:
trn: Training set.
logger: Logger for reporting progress.
"""
pass
@property
def _savable_config(self):
def convert(k, v):
if not isinstance(v, SerializableDict) and hasattr(v, 'config'):
v = v.config
elif isinstance(v, (set, tuple)):
v = list(v)
if isinstance(v, dict):
v = dict(convert(_k, _v) for _k, _v in v.items())
return k, v
config = SerializableDict(
convert(k, v) for k, v in sorted(self.config.items()))
config.update({
# 'create_time': now_datetime(),
'classpath': classpath_of(self),
'hanlp_version': hanlp.__version__,
})
return config
@abstractmethod
def build_optimizer(self, **kwargs):
"""Implement this method to build an optimizer.
Args:
**kwargs: The subclass decides the method signature.
"""
pass
@abstractmethod
def build_criterion(self, **kwargs):
"""Implement this method to build criterion (loss function).
Args:
**kwargs: The subclass decides the method signature.
"""
pass
@abstractmethod
def build_metric(self, **kwargs):
"""Implement this to build metric(s).
Args:
**kwargs: The subclass decides the method signature.
"""
pass
@abstractmethod
def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
logger: logging.Logger, devices, ratio_width=None,
**kwargs):
"""Implement this to run training loop.
Args:
trn: Training set.
dev: Development set.
epochs: Number of epochs.
criterion: Loss function.
optimizer: Optimizer(s).
metric: Metric(s).
save_dir: The directory to save this component.
logger: Logger for reporting progress.
devices: Devices this component and dataloader will live on.
ratio_width: The width of dataset size measured in number of characters. Used for logger to align messages.
**kwargs: Other hyper-parameters passed from sub-class.
"""
pass
@abstractmethod
def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
"""Fit onto a dataloader.
Args:
trn: Training set.
criterion: Loss function.
optimizer: Optimizer.
metric: Metric(s).
logger: Logger for reporting progress.
**kwargs: Other hyper-parameters passed from sub-class.
"""
pass
@abstractmethod
def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, **kwargs):
"""Evaluate on a dataloader.
Args:
data: Dataloader which can be built from any data source.
criterion: Loss function.
metric: Metric(s).
output: Whether to save outputs into some file.
**kwargs: Not used.
"""
pass
@abstractmethod
def build_model(self, training=True, **kwargs) -> torch.nn.Module:
"""Build model.
Args:
training: ``True`` if called during training.
**kwargs: ``**self.config``.
"""
raise NotImplementedError
def evaluate(self, tst_data, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs):
"""Evaluate test set.
Args:
tst_data: Test set, which is usually a file path.
save_dir: The directory to save evaluation scores or predictions.
logger: Logger for reporting progress.
batch_size: Batch size for test dataloader.
output: Whether to save outputs into some file.
**kwargs: Not used.
Returns:
(metric, outputs) where outputs are the return values of ``evaluate_dataloader``.
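Example:
    A hedged sketch (paths hypothetical; ``tagger`` is a fitted or loaded subclass instance)::

        metric, outputs = tagger.evaluate('data/tst.tsv', save_dir='data/model/my_tagger')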
"""
if not self.model:
raise RuntimeError('Call fit or load before evaluate.')
if isinstance(tst_data, str):
tst_data = get_resource(tst_data)
filename = os.path.basename(tst_data)
else:
filename = None
if output is True:
output = self.generate_prediction_filename(tst_data if isinstance(tst_data, str) else 'test.txt', save_dir)
if logger is None:
_logger_name = basename_no_ext(filename) if filename else None
logger = self.build_logger(_logger_name, save_dir)
if not batch_size:
batch_size = self.config.get('batch_size', 32)
data = self.build_dataloader(**merge_dict(self.config, data=tst_data, batch_size=batch_size, shuffle=False,
device=self.devices[0], logger=logger, overwrite=True))
dataset = data
while dataset and hasattr(dataset, 'dataset'):
dataset = dataset.dataset
num_samples = len(dataset) if dataset else None
if output and isinstance(dataset, TransformableDataset):
def add_idx(samples):
for idx, sample in enumerate(samples):
if sample:
sample[IDX] = idx
add_idx(dataset.data)
if dataset.cache:
add_idx(dataset.cache)
criterion = self.build_criterion(**self.config)
metric = self.build_metric(**self.config)
start = time.time()
outputs = self.evaluate_dataloader(data, criterion=criterion, filename=filename, output=output, input=tst_data,
save_dir=save_dir,
test=True,
num_samples=num_samples,
**merge_dict(self.config, batch_size=batch_size, metric=metric,
logger=logger, **kwargs))
elapsed = time.time() - start
if logger:
if num_samples:
logger.info(f'speed: {num_samples / elapsed:.0f} samples/second')
else:
logger.info(f'speed: {len(data) / elapsed:.0f} batches/second')
return metric, outputs
def generate_prediction_filename(self, tst_data, save_dir):
assert isinstance(tst_data,
str), 'tst_data has to be a str in order to infer the output name'
output = os.path.splitext(os.path.basename(tst_data))
output = os.path.join(save_dir, output[0] + '.pred' + output[1])
return output
def to(self,
devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
logger: logging.Logger = None, verbose=HANLP_VERBOSE):
"""Move this component to devices.
Args:
devices: Target devices.
logger: Logger for printing a progress report, as copying a model from CPU to GPU can take several seconds.
verbose: ``True`` to print progress when logger is None.
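Example:
    Hedged sketches (``tagger`` is a loaded subclass instance; actual placement depends
    on the GPUs available)::

        tagger.to(-1)      # -1 means CPU
        tagger.to([0])     # a single CUDA device
        tagger.to([0, 1])  # more than one device enables DataParallel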
"""
if devices is None:
# if getattr(torch, 'has_mps', None): # mac M1 chips
# devices = torch.device('mps:0')
# else:
devices = cuda_devices(devices)
elif devices == -1 or devices == [-1]:
devices = []
elif isinstance(devices, (int, float)):
devices = cuda_devices(devices)
if devices:
if logger:
logger.info(f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]')
if isinstance(devices, list):
if verbose:
flash(f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]')
self.model = self.model.to(devices[0])
if len(devices) > 1 and not isdebugging() and not isinstance(self.model, nn.DataParallel):
self.model = self.parallelize(devices)
elif isinstance(devices, dict):
for name, module in self.model.named_modules():
for regex, device in devices.items():
try:
on_device: torch.device = next(module.parameters()).device
except StopIteration:
continue
if on_device == device:
continue
if isinstance(device, int):
if on_device.index == device:
continue
if re.match(regex, name):
if not name:
name = '*'
flash(f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n')
module.to(device)
elif isinstance(devices, torch.device):
if verbose:
flash(f'Moving model to {devices} [blink][yellow]...[/yellow][/blink]')
self.model = self.model.to(devices)
else:
raise ValueError(f'Unrecognized devices {devices}')
if verbose:
flash('')
else:
if logger:
logger.info('Using [red]CPU[/red]')
def parallelize(self, devices: List[Union[int, torch.device]]):
return nn.DataParallel(self.model, device_ids=devices)
@property
def devices(self):
"""The devices this component lives on.
"""
if self.model is None:
return None
# next(parser.model.parameters()).device
if hasattr(self.model, 'device_ids'):
return self.model.device_ids
device: torch.device = next(self.model.parameters()).device
return [device]
@property
def device(self):
"""The first device this component lives on.
"""
devices = self.devices
if not devices:
return None
return devices[0]
def on_config_ready(self, **kwargs):
"""Called when config is ready, either during ``fit`` or ``load``. Subclass can perform extra initialization
tasks in this callback.
Args:
**kwargs: Not used.
"""
pass
@property
def model_(self) -> nn.Module:
"""
The underlying model, unwrapped from ``nn.DataParallel`` if necessary.
Returns: The "real" model.
"""
if isinstance(self.model, nn.DataParallel):
return self.model.module
return self.model
# noinspection PyMethodOverriding
@abstractmethod
def predict(self, *args, **kwargs):
"""Predict on data fed by user. Users shall avoid directly call this method since it is not guarded with
``torch.no_grad`` and will introduces unnecessary gradient computation. Use ``__call__`` instead.
Args:
*args: Sentences or tokens.
**kwargs: Used in sub-classes.
"""
pass
@staticmethod
def _create_dummy_placeholder_on(device):
if device < 0:
device = 'cpu:0'
return torch.zeros(16, 16, device=device)
@torch.no_grad()
def __call__(self, *args, **kwargs):
"""Predict on data fed by user. This method calls :meth:`~hanlp.common.torch_component.predict` but decorates
it with ``torch.no_grad``.
Args:
*args: Sentences or tokens.
**kwargs: Used in sub-classes.
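Example:
    A hedged sketch (the input format depends on the concrete subclass)::

        outputs = tagger(['An', 'example', 'input'])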
"""
return super().__call__(*args, **merge_dict(self.config, overwrite=True, **kwargs))