# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-06-22 21:06
import warnings
from typing import Union, Dict, Any, Sequence, Tuple, Optional
import torch
from torch import nn
from hanlp.layers.dropout import WordDropout
from hanlp.layers.scalar_mix import ScalarMixWithDropout, ScalarMixWithDropoutBuilder
from hanlp.layers.transformers.resource import get_tokenizer_mirror
from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModel_, \
BertTokenizer, AutoTokenizer_
from hanlp.layers.transformers.utils import transformer_encode
# noinspection PyAbstractClass
class TransformerEncoder(nn.Module):
def __init__(self,
transformer: Union[PreTrainedModel, str],
transformer_tokenizer: PreTrainedTokenizer,
average_subwords=False,
scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
word_dropout=None,
max_sequence_length=None,
ret_raw_hidden_states=False,
transformer_args: Dict[str, Any] = None,
                 trainable: Union[bool, Optional[Tuple[int, int]]] = True,
training=True) -> None:
"""A pre-trained transformer encoder.
Args:
transformer: A ``PreTrainedModel`` or an identifier of a ``PreTrainedModel``.
transformer_tokenizer: A ``PreTrainedTokenizer``.
average_subwords: ``True`` to average subword representations.
            scalar_mix: Layer attention: a ``ScalarMixWithDropoutBuilder`` that mixes the hidden states of
                multiple layers.
            word_dropout: The probability of randomly replacing a subword with ``[MASK]``. It can also be a
                ``(rate, replacement)`` tuple, where ``replacement`` is ``'unk'``, ``'mask'`` or a token id.
            max_sequence_length: The maximum sequence length. Sequences longer than this are handled by a
                sliding window. If ``None``, the ``max_position_embeddings`` of the transformer is used.
            ret_raw_hidden_states: ``True`` to also return the raw hidden states of every layer.
transformer_args: Extra arguments passed to the transformer.
            trainable: ``False`` to freeze the transformer and use static embeddings, ``True`` to fine-tune
                all layers, or a ``(start, end)`` tuple to fine-tune only the layers whose indices fall in
                ``[start, end)``, with index ``0`` being the embedding layer.
training: ``False`` to skip loading weights from pre-trained transformers.
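
        Examples:
            A construction sketch; the checkpoint name, dropout rate and layer range below are illustrative
            placeholders rather than recommended settings::

                tokenizer = TransformerEncoder.build_transformer_tokenizer('bert-base-chinese')
                encoder = TransformerEncoder('bert-base-chinese', tokenizer,
                                             word_dropout=(0.1, 'mask'),  # replace 10% of subwords with [MASK]
                                             trainable=(1, 13))  # freeze embeddings, tune 12 encoder layers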
"""
super().__init__()
self.ret_raw_hidden_states = ret_raw_hidden_states
self.average_subwords = average_subwords
if word_dropout:
oov = transformer_tokenizer.mask_token_id
if isinstance(word_dropout, Sequence):
word_dropout, replacement = word_dropout
if replacement == 'unk':
# Electra English has to use unk
oov = transformer_tokenizer.unk_token_id
elif replacement == 'mask':
# UDify uses [MASK]
oov = transformer_tokenizer.mask_token_id
else:
oov = replacement
pad = transformer_tokenizer.pad_token_id
cls = transformer_tokenizer.cls_token_id
sep = transformer_tokenizer.sep_token_id
excludes = [pad, cls, sep]
self.word_dropout = WordDropout(p=word_dropout, oov_token=oov, exclude_tokens=excludes)
else:
self.word_dropout = None
if isinstance(transformer, str):
output_hidden_states = scalar_mix is not None
if transformer_args is None:
transformer_args = dict()
transformer_args['output_hidden_states'] = output_hidden_states
transformer = AutoModel_.from_pretrained(transformer, training=training or not trainable,
**transformer_args)
if max_sequence_length is None:
max_sequence_length = transformer.config.max_position_embeddings
self.max_sequence_length = max_sequence_length
if hasattr(transformer, 'encoder') and hasattr(transformer, 'decoder'):
# For seq2seq model, use its encoder
transformer = transformer.encoder
self.transformer = transformer
if not trainable:
transformer.requires_grad_(False)
elif isinstance(trainable, tuple):
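            # Freeze every layer whose index falls outside the half-open range [trainable[0], trainable[1]);
            # index 0 is the embedding layer when the backbone exposes ``embeddings``.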
layers = []
if hasattr(transformer, 'embeddings'):
layers.append(transformer.embeddings)
layers.extend(transformer.encoder.layer)
for i, layer in enumerate(layers):
if i < trainable[0] or i >= trainable[1]:
layer.requires_grad_(False)
if isinstance(scalar_mix, ScalarMixWithDropoutBuilder):
self.scalar_mix: ScalarMixWithDropout = scalar_mix.build()
else:
self.scalar_mix = None

    def forward(self, input_ids: torch.LongTensor, attention_mask=None, token_type_ids=None, token_span=None,
                **kwargs):
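        """Encode subword ids into contextual hidden states.

        Args:
            input_ids: Subword ids of shape ``(batch, num_subwords)``.
            attention_mask: Mask distinguishing real subwords from padding.
            token_type_ids: Segment ids for models that use them.
            token_span: Optional mapping from each token to the positions of its subwords; when provided,
                subword states are pooled into token-level states (averaged when ``average_subwords`` is
                ``True``).
            **kwargs: Ignored.
        """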
if self.word_dropout:
input_ids = self.word_dropout(input_ids)
x = transformer_encode(self.transformer,
input_ids,
attention_mask,
token_type_ids,
token_span,
layer_range=self.scalar_mix.mixture_range if self.scalar_mix else 0,
max_sequence_length=self.max_sequence_length,
average_subwords=self.average_subwords,
ret_raw_hidden_states=self.ret_raw_hidden_states)
if self.ret_raw_hidden_states:
x, raw_hidden_states = x
if self.scalar_mix:
x = self.scalar_mix(x)
if self.ret_raw_hidden_states:
# noinspection PyUnboundLocalVariable
return x, raw_hidden_states
return x

    @staticmethod
def build_transformer(config, training=True) -> PreTrainedModel:
kwargs = {}
if config.scalar_mix and config.scalar_mix > 0:
kwargs['output_hidden_states'] = True
transformer = AutoModel_.from_pretrained(config.transformer, training=training, **kwargs)
return transformer

    @staticmethod
def build_transformer_tokenizer(config_or_str, use_fast=True, do_basic_tokenize=True) -> PreTrainedTokenizer:
return AutoTokenizer_.from_pretrained(config_or_str, use_fast, do_basic_tokenize)
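

# A minimal usage sketch, assuming the 'bert-base-chinese' checkpoint can be resolved by
# AutoModel_/AutoTokenizer_ and that omitting ``token_span`` leaves the output at subword level.
if __name__ == '__main__':
    tokenizer = TransformerEncoder.build_transformer_tokenizer('bert-base-chinese')
    encoder = TransformerEncoder('bert-base-chinese', tokenizer, average_subwords=True, trainable=False)
    batch = tokenizer(['商品和服务'], return_tensors='pt')
    with torch.no_grad():
        hidden = encoder(batch['input_ids'], attention_mask=batch['attention_mask'])
    # Expected shape under the assumptions above: (batch_size, num_subwords, hidden_size)
    print(hidden.shape)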