Source code for hanlp.layers.transformers.encoder
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-06-22 21:06
import warnings
from typing import Union, Dict, Any, Sequence, Tuple, Optional
import torch
from torch import nn
from hanlp.layers.dropout import WordDropout
from hanlp.layers.scalar_mix import ScalarMixWithDropout, ScalarMixWithDropoutBuilder
from hanlp.layers.transformers.resource import get_tokenizer_mirror
from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModel_, \
    BertTokenizer, AutoTokenizer_
from hanlp.layers.transformers.utils import transformer_encode


# noinspection PyAbstractClass
class TransformerEncoder(nn.Module):
    def __init__(self,
                 transformer: Union[PreTrainedModel, str],
                 transformer_tokenizer: PreTrainedTokenizer,
                 average_subwords=False,
                 scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
                 word_dropout=None,
                 max_sequence_length=None,
                 ret_raw_hidden_states=False,
                 transformer_args: Dict[str, Any] = None,
                 trainable: Union[bool, Tuple[int, int]] = True,
                 training=True) -> None:
        """A pre-trained transformer encoder.
        Args:
            transformer: A ``PreTrainedModel`` or an identifier of a ``PreTrainedModel``.
            transformer_tokenizer: A ``PreTrainedTokenizer``.
            average_subwords: ``True`` to average subword representations.
            scalar_mix: Layer attention.
            word_dropout: Dropout rate of randomly replacing a subword with MASK.
            max_sequence_length: The maximum sequence length. Sequence longer than this will be handled by sliding
                window. If ``None``, then the ``max_position_embeddings`` of the transformer will be used.
            ret_raw_hidden_states: ``True`` to return hidden states of each layer.
            transformer_args: Extra arguments passed to the transformer.
            trainable: ``False`` to use static embeddings.
            training: ``False`` to skip loading weights from pre-trained transformers.
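
        Examples:
            A minimal construction sketch; the identifier ``'bert-base-chinese'`` below is illustrative,
            not required by this class::

                tokenizer = TransformerEncoder.build_transformer_tokenizer('bert-base-chinese')
                encoder = TransformerEncoder('bert-base-chinese', tokenizer, average_subwords=True)

            ``forward`` then consumes subword-level ``input_ids`` (and, in HanLP pipelines, a ``token_span``
            mapping produced by the corresponding transform) and returns token-level representations.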
        """
        super().__init__()
        self.ret_raw_hidden_states = ret_raw_hidden_states
        self.average_subwords = average_subwords
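        # word_dropout may be a plain rate (e.g. 0.1, in which case dropped subwords become [MASK])
        # or a (rate, replacement) pair such as (0.1, 'unk'), handled by the branch below.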
        if word_dropout:
            oov = transformer_tokenizer.mask_token_id
            if isinstance(word_dropout, Sequence):
                word_dropout, replacement = word_dropout
                if replacement == 'unk':
                    # Electra English has to use unk
                    oov = transformer_tokenizer.unk_token_id
                elif replacement == 'mask':
                    # UDify uses [MASK]
                    oov = transformer_tokenizer.mask_token_id
                else:
                    oov = replacement
            pad = transformer_tokenizer.pad_token_id
            cls = transformer_tokenizer.cls_token_id
            sep = transformer_tokenizer.sep_token_id
            excludes = [pad, cls, sep]
            self.word_dropout = WordDropout(p=word_dropout, oov_token=oov, exclude_tokens=excludes)
        else:
            self.word_dropout = None
        if isinstance(transformer, str):
            output_hidden_states = scalar_mix is not None
            if transformer_args is None:
                transformer_args = dict()
            transformer_args['output_hidden_states'] = output_hidden_states
            transformer = AutoModel_.from_pretrained(transformer, training=training or not trainable,
                                                     **transformer_args)
            if max_sequence_length is None:
                max_sequence_length = transformer.config.max_position_embeddings
        self.max_sequence_length = max_sequence_length
        if hasattr(transformer, 'encoder') and hasattr(transformer, 'decoder'):
            # For seq2seq model, use its encoder
            transformer = transformer.encoder
        self.transformer = transformer
        if not trainable:
            transformer.requires_grad_(False)
        elif isinstance(trainable, tuple):
            layers = []
            if hasattr(transformer, 'embeddings'):
                layers.append(transformer.embeddings)
            layers.extend(transformer.encoder.layer)
            for i, layer in enumerate(layers):
                if i < trainable[0] or i >= trainable[1]:
                    layer.requires_grad_(False)
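            # Illustrative example: with the layer list built above, trainable=(0, 3) keeps the
            # embeddings (index 0) and the first two encoder layers trainable and freezes the rest.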
        if isinstance(scalar_mix, ScalarMixWithDropoutBuilder):
            self.scalar_mix: ScalarMixWithDropout = scalar_mix.build()
        else:
            self.scalar_mix = None

    def forward(self, input_ids: torch.LongTensor, attention_mask=None, token_type_ids=None, token_span=None,
                **kwargs):
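        # Encode input_ids with the underlying transformer via transformer_encode. The inputs are
        # subword-level tensors; token_span, when provided, maps each token to its subword positions so
        # the output can be pooled back to token level (first subword or mean, depending on
        # average_subwords). This is a best-effort summary of transformer_encode's contract.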
        if self.word_dropout:
            input_ids = self.word_dropout(input_ids)
        x = transformer_encode(self.transformer,
                               input_ids,
                               attention_mask,
                               token_type_ids,
                               token_span,
                               layer_range=self.scalar_mix.mixture_range if self.scalar_mix else 0,
                               max_sequence_length=self.max_sequence_length,
                               average_subwords=self.average_subwords,
                               ret_raw_hidden_states=self.ret_raw_hidden_states)
        if self.ret_raw_hidden_states:
            x, raw_hidden_states = x
        if self.scalar_mix:
            x = self.scalar_mix(x)
        if self.ret_raw_hidden_states:
            # noinspection PyUnboundLocalVariable
            return x, raw_hidden_states
        return x

    @staticmethod
    def build_transformer(config, training=True) -> PreTrainedModel:
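        # config is expected to expose transformer (an identifier) and scalar_mix attributes;
        # hidden states of all layers are requested whenever a scalar mixture is configured.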
        kwargs = {}
        if config.scalar_mix and config.scalar_mix > 0:
            kwargs['output_hidden_states'] = True
        transformer = AutoModel_.from_pretrained(config.transformer, training=training, **kwargs)
        return transformer

    @staticmethod
    def build_transformer_tokenizer(config_or_str, use_fast=True, do_basic_tokenize=True) -> PreTrainedTokenizer:
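        # config_or_str may be a transformer identifier string or a config object carrying one;
        # resolution is delegated to AutoTokenizer_.from_pretrained.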
        return AutoTokenizer_.from_pretrained(config_or_str, use_fast, do_basic_tokenize)
