rnn_tagger
RNN based tagger.
- class hanlp.components.taggers.rnn_tagger.RNNTagger(**kwargs)
An old-school tagger using non-contextualized embeddings and RNNs as context layer.
- Parameters
**kwargs – Predefined config.
- build_dataloader(data, batch_size, shuffle, device, logger=None, **kwargs) → torch.utils.data.dataloader.DataLoader
Build dataloader for training, dev and test sets. It’s suggested to build vocabs in this method if they are not built yet.
- Parameters
data – Data representing samples, which can be a path or a list of samples.
batch_size – Number of samples per batch.
shuffle – Whether to shuffle this dataloader.
device – Device tensors should be loaded onto.
logger – Logger for reporting messages if building the dataloader takes a long time or if vocabs have to be built.
**kwargs – Arguments from **self.config.
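To make the batch_size and shuffle parameters concrete, here is a minimal pure-Python sketch of grouping samples into batches. It uses neither HanLP nor PyTorch and is not how the library builds its DataLoader; the helper name iter_batches and its seed argument are hypothetical, introduced only for illustration.

```python
import random


def iter_batches(samples, batch_size, shuffle, seed=None):
    """Yield lists of at most ``batch_size`` samples.

    Hypothetical helper: it only illustrates what ``batch_size`` and
    ``shuffle`` mean, not how HanLP builds its DataLoader.
    """
    order = list(range(len(samples)))
    if shuffle:
        # A seeded generator keeps the shuffled order reproducible.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


# Toy samples keyed by "token", matching the default token_key of fit().
samples = [{"token": ["a"]}, {"token": ["b", "c"]}, {"token": ["d"]}]
batches = list(iter_batches(samples, batch_size=2, shuffle=False))
```

With three samples and batch_size=2, the last batch holds the single remaining sample.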
- build_model(rnn_input, rnn_hidden, drop, crf, **kwargs) → torch.nn.modules.module.Module
Build model.
- Parameters
training – True if called during training.
**kwargs – **self.config.
- build_vocabs(dataset, logger)
Override this method to build vocabs.
- Parameters
dataset – Training set.
logger – Logger for reporting progress.
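As an illustration of what building vocabs from a training set can involve, the sketch below maps tokens and tags to integer ids with the standard library only. It is not HanLP's vocabulary implementation; the helper build_vocab, the special symbols, and the "tag" sample key are assumptions made for this example.

```python
from collections import Counter


def build_vocab(sequences, specials=("<pad>", "<unk>"), min_freq=1):
    """Map each item seen at least ``min_freq`` times to an integer id."""
    counts = Counter(item for seq in sequences for item in seq)
    vocab = {s: i for i, s in enumerate(specials)}
    # Sort by descending frequency, then alphabetically, for a stable id order.
    for item, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and item not in vocab:
            vocab[item] = len(vocab)
    return vocab


# Toy training samples; the "tag" key is an assumption for this sketch.
trn = [
    {"token": ["I", "like", "tea"], "tag": ["PRP", "VBP", "NN"]},
    {"token": ["I", "like", "rice"], "tag": ["PRP", "VBP", "NN"]},
]
token_vocab = build_vocab(s["token"] for s in trn)
tag_vocab = build_vocab((s["tag"] for s in trn), specials=())
```

The tag vocab is built without special symbols here, since every tag in the training set is assumed to be known at prediction time.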
- execute_training_loop(trn: torch.utils.data.dataloader.DataLoader, dev: torch.utils.data.dataloader.DataLoader, epochs, criterion, optimizer, metric, save_dir, logger, patience, **kwargs)
Implement this to run the training loop.
- Parameters
trn – Training set.
dev – Development set.
epochs – Number of epochs.
criterion – Loss function.
optimizer – Optimizer(s).
metric – Metric(s).
save_dir – The directory to save this component.
logger – Logger for reporting progress.
devices – Devices this component and dataloader will live on.
ratio_width – The width of dataset size measured in number of characters. Used for logger to align messages.
**kwargs – Other hyper-parameters passed from sub-class.
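The interplay of epochs and patience can be sketched as a generic early-stopping loop. This is a minimal sketch under the assumption that patience counts consecutive epochs without improvement on the dev set; it is not the loop HanLP implements, and train_one_epoch and evaluate are hypothetical callables standing in for fitting on trn and scoring dev.

```python
def run_training_loop(train_one_epoch, evaluate, epochs, patience):
    """Run up to ``epochs`` epochs; stop after ``patience`` epochs without improvement.

    ``train_one_epoch`` and ``evaluate`` are hypothetical callables standing in
    for fitting on the training dataloader and scoring the dev set.
    """
    best_score, best_epoch, stale = float("-inf"), 0, 0
    epoch = 0
    for epoch in range(1, epochs + 1):
        train_one_epoch(epoch)
        score = evaluate(epoch)
        if score > best_score:
            best_score, best_epoch, stale = score, epoch, 0
            # A real component would save its weights to save_dir here.
        else:
            stale += 1
            if stale >= patience:
                break
    return best_score, best_epoch, epoch


# Made-up dev scores, one per epoch, used only to drive the sketch.
dev_scores = [0.70, 0.80, 0.79, 0.78, 0.90]
best, best_epoch, last_epoch = run_training_loop(
    train_one_epoch=lambda epoch: None,
    evaluate=lambda epoch: dev_scores[epoch - 1],
    epochs=5,
    patience=2,
)
```

With patience=2 the loop stops at epoch 4, after two epochs without improvement over epoch 2, so the higher score at epoch 5 is never reached.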
- fit(trn_data, dev_data, save_dir, batch_size=50, epochs=100, embed=100, rnn_input=None, rnn_hidden=256, drop=0.5, lr=0.001, patience=10, crf=True, optimizer='adam', token_key='token', tagging_scheme=None, anneal_factor: float = 0.5, anneal_patience=2, devices=None, logger=None, verbose=True, **kwargs)
Fit to data, which triggers the training procedure. The training set and dev set shall be local or remote files.
- Parameters
trn_data – Training set.
dev_data – Development set.
save_dir – The directory to save trained component.
batch_size – The number of samples in a batch.
epochs – Number of epochs.
devices – Devices this component will live on.
logger – Any logging.Logger instance.
seed – Random seed to reproduce this training.
finetune – True to load from save_dir instead of creating a randomly initialized component. str to specify a different save_dir to load from.
eval_trn – Evaluate training set after each update. This can slow down the training but provides a quick diagnostic for debugging.
_device_placeholder – True to create a placeholder tensor which triggers PyTorch to occupy devices so other components won’t take these devices as first choices.
**kwargs – Hyperparameters used by sub-classes.
- Returns
Any results sub-classes would like to return. Usually the best metrics on the training set.
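A call to fit might be prepared as in the sketch below, which collects the hyperparameter defaults from the signature above and overrides a few of them. The helpers make_fit_kwargs and train_tagger and all file paths are hypothetical; train_tagger assumes HanLP is installed and that the data files exist in a format the tagger accepts, so it is defined but not called here.

```python
# Hyperparameter defaults copied from the fit() signature documented above.
FIT_DEFAULTS = {
    "batch_size": 50, "epochs": 100, "embed": 100, "rnn_input": None,
    "rnn_hidden": 256, "drop": 0.5, "lr": 0.001, "patience": 10,
    "crf": True, "optimizer": "adam", "token_key": "token",
    "tagging_scheme": None, "anneal_factor": 0.5, "anneal_patience": 2,
}


def make_fit_kwargs(trn_data, dev_data, save_dir, **overrides):
    """Hypothetical helper: merge overrides into the documented defaults."""
    kwargs = dict(FIT_DEFAULTS)
    kwargs.update(overrides)
    kwargs.update(trn_data=trn_data, dev_data=dev_data, save_dir=save_dir)
    return kwargs


def train_tagger(fit_kwargs):
    """Hypothetical wrapper; needs HanLP installed, so it is not called here."""
    from hanlp.components.taggers.rnn_tagger import RNNTagger
    return RNNTagger().fit(**fit_kwargs)


# Placeholder paths; replace them with real local or remote files.
fit_kwargs = make_fit_kwargs(
    "data/pos/train.tsv", "data/pos/dev.tsv", "data/model/rnn_pos",
    epochs=30, rnn_hidden=128,
)
```

Keeping the arguments in a plain dict makes it easy to log the exact configuration of a run before training starts.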
- fit_dataloader(trn: torch.utils.data.dataloader.DataLoader, criterion, optimizer, metric, logger: logging.Logger, ratio_width=None, **kwargs)
Fit onto a dataloader.
- Parameters
trn – Training set.
criterion – Loss function.
optimizer – Optimizer.
metric – Metric(s).
logger – Logger for reporting progress.
**kwargs – Other hyper-parameters passed from sub-class.