transformer_tagger

Transformer-based tagger.

class hanlp.components.taggers.transformers.transformer_tagger.TransformerTagger(**kwargs)[source]

A simple tagger using a linear layer with an optional CRF (Lafferty et al. 2001) layer for any tagging task, including PoS tagging and many others.

Parameters

**kwargs – Not used.
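
A minimal usage sketch follows; the data paths, file format, and transformer identifier are assumptions for illustration, not values prescribed by this API.

    from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger

    tagger = TransformerTagger()
    tagger.fit(
        trn_data='data/pos/train.tsv',    # assumed path to a local training file
        dev_data='data/pos/dev.tsv',      # assumed path to a local dev file
        save_dir='models/pos_tagger',     # directory to save the trained component
        transformer='bert-base-cased',    # assumed transformer identifier
        crf=True,                         # enable the optional CRF layer
        epochs=3,
    )
    print(tagger(['I', 'love', 'HanLP']))  # predict tags for one tokenized sentence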

build_dataloader(data, batch_size, shuffle, device, logger: Optional[logging.Logger] = None, sampler_builder: Optional[hanlp.common.dataset.SamplerBuilder] = None, gradient_accumulation=1, extra_embeddings: Optional[hanlp.layers.embeddings.embedding.Embedding] = None, transform=None, max_seq_len=None, **kwargs) torch.utils.data.dataloader.DataLoader[source]

Build dataloader for training, dev and test sets. It’s suggested to build vocabs in this method if they are not built yet.

Parameters
  • data – Data representing samples, which can be a path or a list of samples.

  • batch_size – Number of samples per batch.

  • shuffle – Whether to shuffle this dataloader.

  • device – Device tensors should be loaded onto.

  • logger – Logger for reporting messages if the dataloader takes a long time to build or if vocabs have to be built.

  • **kwargs – Arguments from **self.config.
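
Normally fit() and evaluate() call this method internally; the sketch below only illustrates the documented signature, and the data path and batch contents are assumptions.

    import torch

    dataloader = tagger.build_dataloader(
        data='data/pos/train.tsv',     # a path or a list of samples
        batch_size=32,
        shuffle=True,
        device=torch.device('cpu'),
        logger=None,
    )
    for batch in dataloader:
        # Each batch is expected to be a dict of padded tensors (assumption).
        break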

build_model(training=True, extra_embeddings: Optional[hanlp.layers.embeddings.embedding.Embedding] = None, finetune=False, logger=None, **kwargs) torch.nn.modules.module.Module[source]

Build model.

Parameters
  • training – True if called during training.

  • **kwargs – Arguments from **self.config.
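
Conceptually, the built model pairs a transformer encoder with a linear scoring layer (plus an optional CRF). The following is an illustrative sketch of such a tagging head, not HanLP's internal implementation:

    import torch.nn as nn

    class LinearTaggingHead(nn.Module):
        """Illustrative head: project contextual embeddings to per-token tag scores."""

        def __init__(self, hidden_size: int, num_tags: int):
            super().__init__()
            self.classifier = nn.Linear(hidden_size, num_tags)

        def forward(self, contextual_embeddings):
            # (batch, seq_len, hidden_size) -> (batch, seq_len, num_tags)
            return self.classifier(contextual_embeddings)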

build_vocabs(trn, logger, **kwargs)[source]

Override this method to build vocabs.

Parameters
  • trn – Training set.

  • logger – Logger for reporting progress.
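
For intuition, collecting the tag set from a training set might look like the sketch below; the sample structure {'token': [...], 'tag': [...]} is an assumption, not HanLP's actual data format.

    def collect_tags(trn_samples):
        # Illustrative only: gather every tag seen in the (assumed) training samples.
        tag_set = set()
        for sample in trn_samples:
            tag_set.update(sample['tag'])
        return sorted(tag_set)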

fit(trn_data, dev_data, save_dir, transformer, average_subwords=False, word_dropout: float = 0.2, hidden_dropout=None, layer_dropout=0, scalar_mix=None, mix_embedding: int = 0, grad_norm=5.0, transformer_grad_norm=None, lr=5e-05, transformer_lr=None, transformer_layers=None, gradient_accumulation=1, adam_epsilon=1e-06, weight_decay=0, warmup_steps=0.1, secondary_encoder=None, extra_embeddings: Optional[hanlp.layers.embeddings.embedding.Embedding] = None, crf=False, reduction='sum', batch_size=32, sampler_builder: Optional[hanlp.common.dataset.SamplerBuilder] = None, epochs=3, patience=5, token_key=None, max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False, transform=None, logger=None, devices: Optional[Union[float, int, List[int]]] = None, **kwargs)[source]

Fit to data and trigger the training procedure. The training set and dev set shall be local or remote files.

Parameters
  • trn_data – Training set.

  • dev_data – Development set.

  • save_dir – The directory to save trained component.

  • batch_size – The number of samples in a batch.

  • epochs – Number of epochs.

  • devices – Devices this component will live on.

  • logger – Any logging.Logger instance.

  • seed – Random seed to reproduce this training.

  • finetune – True to load from save_dir instead of creating a randomly initialized component. Pass a str to specify a different save_dir to load from.

  • eval_trn – Evaluate training set after each update. This can slow down the training but provides a quick diagnostic for debugging.

  • _device_placeholder – True to create a placeholder tensor, which triggers PyTorch to occupy devices so that other components won't take these devices as their first choices.

  • **kwargs – Hyperparameters used by sub-classes.

Returns

Any results sub-classes would like to return. Usually the best metrics on the training set.
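
For illustration, a fuller fit() call that exercises several of the documented hyperparameters; the data paths and the transformer identifier remain assumptions.

    tagger.fit(
        trn_data='data/pos/train.tsv',
        dev_data='data/pos/dev.tsv',
        save_dir='models/pos_tagger_crf',
        transformer='bert-base-cased',   # assumed transformer identifier
        lr=5e-5,                         # base learning rate
        transformer_lr=2e-5,             # separate learning rate for the transformer encoder
        gradient_accumulation=2,         # accumulate gradients over 2 batches
        word_dropout=0.2,
        crf=True,
        batch_size=32,
        epochs=3,
    )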

fit_dataloader(trn: torch.utils.data.dataloader.DataLoader, criterion, optimizer, metric, logger: logging.Logger, history: hanlp.common.structure.History, gradient_accumulation=1, grad_norm=None, transformer_grad_norm=None, teacher: Optional[hanlp.components.taggers.tagger.Tagger] = None, kd_criterion=None, temperature_scheduler=None, ratio_width=None, eval_trn=True, **kwargs)[source]

Fit onto a dataloader.

Parameters
  • trn – Training set.

  • criterion – Loss function.

  • optimizer – Optimizer.

  • metric – Metric(s).

  • logger – Logger for reporting progress.

  • **kwargs – Other hyper-parameters passed from the sub-class.
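
As a conceptual sketch (not HanLP's code), a single-epoch loop with gradient accumulation and gradient clipping, which is roughly the job of fit_dataloader:

    import torch

    def fit_one_epoch(model, trn, criterion, optimizer, gradient_accumulation=1, grad_norm=None):
        model.train()
        for step, batch in enumerate(trn, start=1):
            loss = criterion(model(batch), batch)      # assumed batch/loss interface
            (loss / gradient_accumulation).backward()  # scale the loss when accumulating
            if step % gradient_accumulation == 0:
                if grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
                optimizer.step()
                optimizer.zero_grad()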