transformer¶
- class hanlp.layers.embeddings.contextual_word_embedding.ContextualWordEmbedding(field: str, transformer: str, average_subwords=False, scalar_mix: Optional[Union[hanlp.layers.scalar_mix.ScalarMixWithDropoutBuilder, int]] = None, word_dropout: Optional[Union[float, Tuple[float, str]]] = None, max_sequence_length=None, truncate_long_sequences=False, cls_is_bos=False, sep_is_eos=False, ret_token_span=True, ret_subtokens=False, ret_subtokens_group=False, ret_prefix_mask=False, ret_raw_hidden_states=False, transformer_args: Optional[Dict[str, Any]] = None, use_fast=True, do_basic_tokenize=True, trainable=True)[source]¶
A contextual word embedding builder which builds a ContextualWordEmbeddingModule and a TransformerSequenceTokenizer. See the example after the parameter list for a typical construction.
- Parameters
field – The field to work on. Usually some token fields.
transformer – An identifier of a PreTrainedModel.
average_subwords – True to average subword representations.
scalar_mix – Layer attention.
word_dropout – Dropout rate of randomly replacing a subword with MASK.
max_sequence_length – The maximum sequence length. Sequences longer than this will be handled by a sliding window.
truncate_long_sequences – True to truncate sequences longer than max_sequence_length instead of handling them with a sliding window.
cls_is_bos – True means the first token of input is treated as [CLS] no matter what its surface form is. False (default) means the first token is not [CLS]; it will have its own embedding other than the embedding of [CLS].
sep_is_eos – True means the last token of input is [SEP]. False means it is not, but [SEP] will be appended. None means it depends on input[-1] == [EOS].
ret_token_span – True to return the span of each token measured by subtoken offsets.
ret_subtokens – True to return the list of subtokens belonging to each token.
ret_subtokens_group – True to return the list of offsets of subtokens belonging to each token.
ret_prefix_mask – True to generate a mask where each non-zero element corresponds to a prefix of a token.
ret_raw_hidden_states – True to return the hidden states of each layer.
transformer_args – Extra arguments passed to the transformer.
use_fast – Whether or not to try to load the fast version of the tokenizer.
do_basic_tokenize – Whether to do basic tokenization before wordpiece.
trainable – False to use static embeddings.
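- Example
A minimal construction sketch, not taken from the docstring above: the field name 'token' and the Hugging Face identifier 'bert-base-chinese' are illustrative choices, and downloading the pretrained weights requires network access.
    from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding

    # 'token' is the sample field holding a list of tokens; any Hugging Face
    # model identifier can replace 'bert-base-chinese'.
    embed = ContextualWordEmbedding(
        'token',
        'bert-base-chinese',
        average_subwords=True,        # average the subword vectors of each token
        word_dropout=0.1,             # randomly replace 10% of subwords with MASK
        max_sequence_length=512,      # longer inputs are handled by a sliding window
        truncate_long_sequences=False,
    )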
- module(training=True, **kwargs) → Optional[torch.nn.modules.module.Module] [source]¶
Build a module for this embedding.
- Parameters
**kwargs – Containing vocabs, training etc. Not finalized for now.
- Returns
A module.
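- Example
Continuing the construction sketch above; the docstring leaves **kwargs unfinalized, so calling with no extra arguments is an assumption.
    # Build the underlying ContextualWordEmbeddingModule, a torch.nn.Module.
    embed_module = embed.module(training=True)
    print(type(embed_module).__name__)  # expected: ContextualWordEmbeddingModule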
- transform(**kwargs) → hanlp.transform.transformer_tokenizer.TransformerSequenceTokenizer [source]¶
Build a transform function for this embedding.
- Parameters
**kwargs – Containing vocabs, training etc. Not finalized for now.
- Returns
A transform function.
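- Example
Continuing the same sketch, and assuming the returned TransformerSequenceTokenizer is applied directly to a sample dict keyed by the field:
    # Build the tokenizer transform and apply it to a single sample.
    tokenize = embed.transform()
    sample = tokenize({'token': ['商品', '和', '服务']})
    print(sorted(sample.keys()))  # inspect the subtoken fields added to the sample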
- class hanlp.layers.embeddings.contextual_word_embedding.ContextualWordEmbeddingModule(field: str, transformer: str, transformer_tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, average_subwords=False, scalar_mix: Optional[Union[hanlp.layers.scalar_mix.ScalarMixWithDropoutBuilder, int]] = None, word_dropout=None, max_sequence_length=None, ret_raw_hidden_states=False, transformer_args: Optional[Dict[str, Any]] = None, trainable=True, training=True)[source]¶
A contextualized word embedding module.
- Parameters
field – The field to work on. Usually some token fields.
transformer – An identifier of a PreTrainedModel.
transformer_tokenizer – A PreTrainedTokenizer matching the transformer.
average_subwords – True to average subword representations.
scalar_mix – Layer attention.
word_dropout – Dropout rate of randomly replacing a subword with MASK.
max_sequence_length – The maximum sequence length. Sequences longer than this will be handled by a sliding window.
ret_raw_hidden_states – True to return the hidden states of each layer.
transformer_args – Extra arguments passed to the transformer.
trainable – False to use static embeddings.
training – False to skip loading weights from pre-trained transformers.
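- Example
In practice this module is usually obtained via ContextualWordEmbedding.module(); the direct construction below is a hypothetical sketch based only on the signature above, with 'bert-base-chinese' again standing in for any transformer identifier.
    from transformers import AutoTokenizer
    from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbeddingModule

    # The tokenizer must match the transformer it tokenizes for.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
    module = ContextualWordEmbeddingModule(
        'token',                      # field name, as in the builder above
        'bert-base-chinese',
        transformer_tokenizer=tokenizer,
        average_subwords=True,
        training=True,                # False would skip loading pre-trained weights
    )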
- forward(batch: dict, mask=None, **kwargs)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.