Source code for hanlp.layers.embeddings.contextual_word_embedding

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-07-05 13:50
from typing import Optional, Union, List, Any, Dict, Tuple

import torch
from torch import nn

from hanlp_common.configurable import AutoConfigurable
from hanlp.layers.embeddings.embedding import Embedding
from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
from hanlp.layers.transformers.encoder import TransformerEncoder
from hanlp.layers.transformers.pt_imports import PreTrainedTokenizer, AutoConfig_, AutoTokenizer_
from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer


class ContextualWordEmbeddingModule(TransformerEncoder):
    def __init__(self,
                 field: str,
                 transformer: str,
                 transformer_tokenizer: PreTrainedTokenizer,
                 average_subwords=False,
                 scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
                 word_dropout=None,
                 max_sequence_length=None,
                 ret_raw_hidden_states=False,
                 transformer_args: Dict[str, Any] = None,
                 trainable=True,
                 training=True) -> None:
        """A contextualized word embedding module.

        Args:
            field: The field to work on. Usually some token fields.
            transformer: An identifier of a ``PreTrainedModel``.
            transformer_tokenizer: The tokenizer corresponding to ``transformer``.
            average_subwords: ``True`` to average subword representations.
            scalar_mix: Layer attention.
            word_dropout: Dropout rate of randomly replacing a subword with MASK.
            max_sequence_length: The maximum sequence length. Sequences longer than this will be handled by a
                sliding window.
            ret_raw_hidden_states: ``True`` to return hidden states of each layer.
            transformer_args: Extra arguments passed to the transformer.
            trainable: ``False`` to use static embeddings.
            training: ``False`` to skip loading weights from pre-trained transformers.
        """
        super().__init__(transformer, transformer_tokenizer, average_subwords, scalar_mix, word_dropout,
                         max_sequence_length, ret_raw_hidden_states, transformer_args, trainable, training)
        self.field = field

    # noinspection PyMethodOverriding
    # noinspection PyTypeChecker
    def forward(self, batch: dict, mask=None, **kwargs):
        input_ids: torch.LongTensor = batch[f'{self.field}_input_ids']
        token_span: torch.LongTensor = batch.get(f'{self.field}_token_span', None)
        # input_device = input_ids.device
        # this_device = self.get_device()
        # if input_device != this_device:
        #     input_ids = input_ids.to(this_device)
        #     token_span = token_span.to(this_device)
        # We might want to apply mask here
        output: Union[torch.Tensor, List[torch.Tensor]] = super().forward(input_ids, token_span=token_span, **kwargs)
        # if input_device != this_device:
        #     if isinstance(output, torch.Tensor):
        #         output = output.to(input_device)
        #     else:
        #         output = [x.to(input_device) for x in output]
        return output
    def get_output_dim(self):
        return self.transformer.config.hidden_size

    def get_device(self):
        device: torch.device = next(self.parameters()).device
        return device
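
# Usage sketch (illustrative, not part of the original source): the batch-dict
# contract of forward() for a module built with field='token'. The subword ids
# below are arbitrary placeholders, and every token is assumed to map to exactly
# one subword so that token_span can be written as a dense [batch, tokens, 1]
# tensor; in practice both tensors are produced and padded by HanLP's data
# pipeline from TransformerSequenceTokenizer output.
def _forward_contract_sketch(module: ContextualWordEmbeddingModule):
    batch = {
        # [CLS] t1 t2 t3 [SEP] -> 5 subword positions (placeholder ids)
        'token_input_ids': torch.LongTensor([[101, 1037, 1038, 1039, 102]]),
        # each of the 3 tokens spans exactly one subword position
        'token_token_span': torch.LongTensor([[[1], [2], [3]]]),
    }
    # Expected to return a [1, 3, hidden_size] tensor with the default settings
    # (no scalar_mix, ret_raw_hidden_states=False).
    return module(batch)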


class ContextualWordEmbedding(Embedding, AutoConfigurable):
    def __init__(self,
                 field: str,
                 transformer: str,
                 average_subwords=False,
                 scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
                 word_dropout: Optional[Union[float, Tuple[float, str]]] = None,
                 max_sequence_length=None,
                 truncate_long_sequences=False,
                 cls_is_bos=False,
                 sep_is_eos=False,
                 ret_token_span=True,
                 ret_subtokens=False,
                 ret_subtokens_group=False,
                 ret_prefix_mask=False,
                 ret_raw_hidden_states=False,
                 transformer_args: Dict[str, Any] = None,
                 use_fast=True,
                 do_basic_tokenize=True,
                 trainable=True) -> None:
        """A contextual word embedding builder which builds a
        :class:`~hanlp.layers.embeddings.contextual_word_embedding.ContextualWordEmbeddingModule` and a
        :class:`~hanlp.transform.transformer_tokenizer.TransformerSequenceTokenizer`.

        Args:
            field: The field to work on. Usually some token fields.
            transformer: An identifier of a ``PreTrainedModel``.
            average_subwords: ``True`` to average subword representations.
            scalar_mix: Layer attention.
            word_dropout: Dropout rate of randomly replacing a subword with MASK.
            max_sequence_length: The maximum sequence length. Sequences longer than this will be handled by a
                sliding window.
            truncate_long_sequences: ``True`` to truncate sequences longer than ``max_sequence_length``;
                ``False`` to handle them with a sliding window.
            cls_is_bos: ``True`` means the first token of input is treated as [CLS] no matter what its surface form
                is. ``False`` (default) means the first token is not [CLS]; it will have its own embedding other than
                the embedding of [CLS].
            sep_is_eos: ``True`` means the last token of input is [SEP]. ``False`` means it's not, but [SEP] will be
                appended. ``None`` means it depends on ``input[-1] == [EOS]``.
            ret_token_span: ``True`` to return the span of each token measured by subtoken offsets.
            ret_subtokens: ``True`` to return the list of subtokens belonging to each token.
            ret_subtokens_group: ``True`` to return the list of offsets of subtokens belonging to each token.
            ret_prefix_mask: ``True`` to generate a mask where each non-zero element corresponds to a prefix of a
                token.
            ret_raw_hidden_states: ``True`` to return hidden states of each layer.
            transformer_args: Extra arguments passed to the transformer.
            use_fast: Whether or not to try to load the fast version of the tokenizer.
            do_basic_tokenize: Whether to do basic tokenization before wordpiece.
            trainable: ``False`` to use static embeddings.
        """
        super().__init__()
        self.truncate_long_sequences = truncate_long_sequences
        self.transformer_args = transformer_args
        self.trainable = trainable
        self.ret_subtokens_group = ret_subtokens_group
        self.ret_subtokens = ret_subtokens
        self.ret_raw_hidden_states = ret_raw_hidden_states
        self.sep_is_eos = sep_is_eos
        self.cls_is_bos = cls_is_bos
        self.max_sequence_length = max_sequence_length
        self.word_dropout = word_dropout
        self.scalar_mix = scalar_mix
        self.average_subwords = average_subwords
        self.transformer = transformer
        self.field = field
        self._transformer_tokenizer = AutoTokenizer_.from_pretrained(self.transformer,
                                                                     use_fast=use_fast,
                                                                     do_basic_tokenize=do_basic_tokenize)
        self._tokenizer_transform = TransformerSequenceTokenizer(self._transformer_tokenizer,
                                                                 field,
                                                                 truncate_long_sequences=truncate_long_sequences,
                                                                 ret_prefix_mask=ret_prefix_mask,
                                                                 ret_token_span=ret_token_span,
                                                                 cls_is_bos=cls_is_bos,
                                                                 sep_is_eos=sep_is_eos,
                                                                 ret_subtokens=ret_subtokens,
                                                                 ret_subtokens_group=ret_subtokens_group,
                                                                 max_seq_length=self.max_sequence_length)

    def transform(self, **kwargs) -> TransformerSequenceTokenizer:
        return self._tokenizer_transform

    def module(self, training=True, **kwargs) -> Optional[nn.Module]:
        return ContextualWordEmbeddingModule(self.field,
                                             self.transformer,
                                             self._transformer_tokenizer,
                                             self.average_subwords,
                                             self.scalar_mix,
                                             self.word_dropout,
                                             self.max_sequence_length,
                                             self.ret_raw_hidden_states,
                                             self.transformer_args,
                                             self.trainable,
                                             training=training)

    def get_output_dim(self):
        config = AutoConfig_.from_pretrained(self.transformer)
        return config.hidden_size

    def get_tokenizer(self):
        return self._transformer_tokenizer
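
# Usage sketch (illustrative, not part of the original source): the typical call
# sequence of the builder. The transformer identifier 'bert-base-uncased' and the
# sample tokens are assumptions chosen for demonstration, and the sketch assumes
# the tokenizer transform is callable on a dict sample, as HanLP transforms
# generally are. The function is defined but never invoked here.
def _contextual_word_embedding_usage_sketch():
    embed = ContextualWordEmbedding('token', 'bert-base-uncased', average_subwords=True)
    # Tokenizing a sample adds subword-level fields such as 'token_input_ids'
    # and 'token_token_span' next to the original 'token' field.
    tokenizer = embed.transform()
    sample = tokenizer({'token': ['HanLP', 'is', 'a', 'multilingual', 'NLP', 'toolkit', '.']})
    # Build the nn.Module that consumes batches of such samples; training=False
    # skips loading pre-trained weights, which is enough for a dry run.
    module = embed.module(training=False)
    hidden_size = embed.get_output_dim()  # e.g. 768 for a BERT-base model
    return sample, module, hidden_size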


def find_transformer(embed: nn.Module):
    if isinstance(embed, ContextualWordEmbeddingModule):
        return embed
    if isinstance(embed, nn.ModuleList):
        for child in embed:
            found = find_transformer(child)
            if found:
                return found
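
# Usage sketch (illustrative, not part of the original source): find_transformer
# locates the first ContextualWordEmbeddingModule inside an embedding module or
# an nn.ModuleList of embedding modules, which is handy e.g. when the transformer
# should be frozen or given its own learning rate separately from other embeddings.
def _find_transformer_sketch(embed: nn.Module):
    transformer = find_transformer(embed)
    if transformer is not None:
        # Freeze the transformer while leaving any other embeddings trainable
        # (an illustrative use; not prescribed by the original source).
        for p in transformer.parameters():
            p.requires_grad = False
    return transformer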