# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-06-02 23:49
from typing import Optional, Callable, Union

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

from hanlp_common.configurable import AutoConfigurable
from hanlp.common.transform import VocabDict, ToChar
from hanlp.common.vocab import Vocab
from hanlp.layers.embeddings.embedding import Embedding, EmbeddingDim


class CharRNN(nn.Module, EmbeddingDim):
    def __init__(self,
                 field,
                 vocab_size,
                 embed: Union[int, nn.Embedding],
                 hidden_size):
        """Character level RNN embedding module.

        Args:
            field: The field in samples this encoder will work on.
            vocab_size: The size of the character vocab.
            embed: An ``Embedding`` object or the feature size to create an ``Embedding`` object.
            hidden_size: The hidden size of RNNs.
        """
        super(CharRNN, self).__init__()
        self.field = field
        # the embedding layer
        if isinstance(embed, int):
            self.embed = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed)
        elif isinstance(embed, nn.Module):
            self.embed = embed
            embed = embed.embedding_dim
        else:
            raise ValueError(f'Unrecognized type of embed: {embed}')
        # the lstm layer
        self.lstm = nn.LSTM(input_size=embed,
                            hidden_size=hidden_size,
                            batch_first=True,
                            bidirectional=True)

    def forward(self, batch, mask, **kwargs):
        x = batch[f'{self.field}_char_id']
        # [batch_size, seq_len, fix_len]; nonzero ids mark real characters
        mask = x.ne(0)
        # [batch_size, seq_len]: the number of characters in each word
        lens = mask.sum(-1)
        char_mask = lens.gt(0)
        # [n, fix_len, n_embed], where n is the number of non-padding words
        x = self.embed(batch)[char_mask] if isinstance(self.embed, EmbeddingDim) else self.embed(x[char_mask])
        x = pack_padded_sequence(x, lens[char_mask].cpu(), batch_first=True, enforce_sorted=False)
        x, (h, _) = self.lstm(x)
        # [n, n_out]: concatenate the final hidden states of both directions
        h = torch.cat(torch.unbind(h), -1)
        # [batch_size, seq_len, n_out]: scatter word vectors back to their positions
        embed = h.new_zeros(*lens.shape, h.size(-1))
        embed = embed.masked_scatter_(char_mask.unsqueeze(-1), h)
        return embed

    @property
    def embedding_dim(self) -> int:
        return self.lstm.hidden_size * 2


class CharRNNEmbedding(Embedding, AutoConfigurable):
    def __init__(self,
                 field,
                 embed,
                 hidden_size,
                 max_word_length=None) -> None:
        """Character level RNN embedding module builder.

        Args:
            field: The field in samples this encoder will work on.
            embed: An ``Embedding`` object or the feature size to create an ``Embedding`` object.
            hidden_size: The hidden size of RNNs.
            max_word_length: Character sequences longer than ``max_word_length`` will be truncated.
        """
        super().__init__()
        self.field = field
        self.hidden_size = hidden_size
        self.embed = embed
        self.max_word_length = max_word_length

    @property
    def vocab_name(self):
        vocab_name = f'{self.field}_char'
        return vocab_name
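
    # The Callable/ToChar/Vocab imports at the top imply a transform() step that
    # registers the character vocab and converts each token into character ids.
    # A minimal sketch, assuming ToChar accepts (src, dst, max_word_length);
    # verify the actual signature in hanlp.common.transform before relying on it:
    def transform(self, vocabs: VocabDict, **kwargs) -> Optional[Callable]:
        vocab_name = self.vocab_name
        if vocab_name not in vocabs:
            # Register a fresh character vocab, e.g. 'token_char' for field 'token'
            vocabs[vocab_name] = Vocab()
        return ToChar(self.field, vocab_name, max_word_length=self.max_word_length)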

    def module(self, vocabs: VocabDict, **kwargs) -> Optional[nn.Module]:
        embed = self.embed
        if isinstance(self.embed, Embedding):
            # A nested builder: materialize it into an nn.Module first
            embed = self.embed.module(vocabs=vocabs)
        return CharRNN(self.field, len(vocabs[self.vocab_name]), embed, self.hidden_size)
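

if __name__ == '__main__':
    # A small smoke-test sketch, not part of the original module; the field name
    # 'token' is hypothetical. Char ids are [batch_size, seq_len, fix_len] with
    # 0 as the padding id, matching the shape comments in forward().
    char_rnn = CharRNN(field='token', vocab_size=10, embed=8, hidden_size=16)
    batch = {'token_char_id': torch.tensor([[[1, 2, 0], [3, 0, 0]],
                                            [[4, 5, 6], [0, 0, 0]]])}
    out = char_rnn(batch, mask=None)
    # Each word vector concatenates the final hidden states of both LSTM
    # directions, so n_out = hidden_size * 2 = 32.
    print(out.shape)  # torch.Size([2, 2, 32])
    print(char_rnn.embedding_dim)  # 32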