Source code for hanlp.transform.transformer_tokenizer
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-05-03 16:23
import warnings
from typing import Union, Optional
from hanlp_common.constant import BOS, EOS
from hanlp_common.structure import SerializableDict
from hanlp.layers.transformers.pt_imports import PreTrainedTokenizer, PretrainedConfig, AutoTokenizer_
from hanlp_trie import DictInterface
class TransformerTokenizer(object):
def __init__(self, max_seq_length=512, truncate_long_sequences=True) -> None:
self.truncate_long_sequences = truncate_long_sequences
self.max_seq_length = max_seq_length
def sliding_window(self, flat_wordpiece_ids, same_tail=True):
if same_tail:
start_piece_ids, flat_wordpiece_ids, end_piece_ids = flat_wordpiece_ids[:1], \
flat_wordpiece_ids[1:-1], flat_wordpiece_ids[-1:]
else:
start_piece_ids, flat_wordpiece_ids, end_piece_ids = flat_wordpiece_ids[:1], \
flat_wordpiece_ids[1:], []
window_length = self.max_seq_length - len(start_piece_ids) - len(end_piece_ids)
stride = window_length // 2
wordpiece_windows = [start_piece_ids + flat_wordpiece_ids[i:i + window_length] + end_piece_ids
for i in range(0, len(flat_wordpiece_ids), stride)]
# Check for overlap in the last window. Throw it away if it is redundant.
last_window = wordpiece_windows[-1][1:]
penultimate_window = wordpiece_windows[-2]
if last_window == penultimate_window[-len(last_window):]:
wordpiece_windows = wordpiece_windows[:-1]
wordpiece_ids = [wordpiece for sequence in wordpiece_windows for wordpiece in sequence]
return wordpiece_ids
class TransformerTextTokenizer(TransformerTokenizer):
_KEY = ['input_ids', 'attention_mask', 'token_type_ids']
def __init__(self,
tokenizer: Union[PreTrainedTokenizer, str],
text_a_key: str,
text_b_key: str = None,
output_key=None,
max_seq_length=512, truncate_long_sequences=True) -> None:
super().__init__(max_seq_length, truncate_long_sequences)
self.text_b = text_b_key
self.text_a = text_a_key
if output_key is None:
output_key = self.text_a
if text_b_key:
output_key += '_' + text_b_key
if output_key == '':
output_key = self._KEY
else:
output_key = [f'{output_key}_{key}' for key in self._KEY]
self.output_key = output_key
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer_.from_pretrained(tokenizer)
self.tokenizer = tokenizer
def __call__(self, sample: dict):
text_a = sample[self.text_a]
text_b = sample[self.text_b] if self.text_b else None
max_seq_length = self.max_seq_length if self.truncate_long_sequences else None
encoding = self.tokenizer.encode_plus(text_a, text_b, max_length=max_seq_length)
results = dict((k, encoding.data.get(k, None)) for k in self._KEY)
if not self.truncate_long_sequences and len(results['input_ids']) > self.max_seq_length:
# TODO: other fields should be properly handled too
results['input_ids'] = self.sliding_window(results['input_ids'])
if not results['token_type_ids']:
results['token_type_ids'] = encoding[0].type_ids
for k, v in zip(self.output_key, [results[_] for _ in self._KEY]):
sample[k] = v
return sample
[docs]class TransformerSequenceTokenizer(TransformerTokenizer):
def __init__(self,
tokenizer: Union[PreTrainedTokenizer, str],
input_key,
output_key=None,
max_seq_length=512,
truncate_long_sequences=False,
config: PretrainedConfig = None,
cls_token_at_end=False,
cls_token_segment_id=0,
pad_token_segment_id=0,
pad_on_left=False,
do_padding=False,
sep_token_extra=False,
ret_mask_and_type=False,
ret_prefix_mask=False,
ret_token_span=True,
ret_subtokens=False,
ret_subtokens_group=False,
cls_is_bos=False,
sep_is_eos=False,
do_basic_tokenize=True,
use_fast=True,
dict_force=None,
strip_cls_sep=True,
check_space_before=None,
) -> None:
"""A transformer tokenizer for token-level tasks. It honors the boundary of tokens and tokenize each token into
several subtokens then merge them. The information about each subtoken belongs to which token are kept and
returned as a new field in the sample. It also provides out-of-box sliding window trick on long sequences.
Args:
tokenizer: The identifier of a pre-trained tokenizer or a ``PreTrainedTokenizer``.
input_key: The token key in samples.
output_key: The output keys to store results.
max_seq_length: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
truncate_long_sequences: ``True`` to truncate exceeded parts of long sequences. ``False`` to enable
sliding window.
config: The ``PretrainedConfig`` to determine the model structure of the transformer, so that special
tokenization can be applied.
cls_token_at_end: ``True`` to put ``[CLS]`` at the end of input tokens.
cls_token_segment_id: The id of ``[CLS]``.
pad_token_segment_id: The id of ``[SEP]``.
pad_on_left: ``True`` to put ``[PAD]`` at the left side of input tokens.
do_padding: ``True`` to pad sequence to the left.
sep_token_extra: ``True`` to have two ``[SEP]``.
ret_mask_and_type: ``True`` to return masks and type ids.
ret_prefix_mask: ``True`` to generate a mask where each non-zero element corresponds to a prefix of a token.
ret_token_span: ``True`` to return span of each token measured by subtoken offsets.
ret_subtokens: ``True`` to return list of subtokens belonging to each token for tokenization purpose.
When enabled, the prefix mask for each subtoken is set to True as each subtoken is a token unit in
tokenization task. Similarity, the token span for each token will be a continuous integer sequence.
ret_subtokens_group: ``True`` to return list of offsets of subtokens belonging to each token.
cls_is_bos: ``True`` means the first token of input is treated as [CLS] no matter what its surface form is.
``False`` (default) means the first token is not [CLS], it will have its own embedding other than
the embedding of [CLS].
sep_is_eos: ``True`` means the last token of input is [SEP].
``False`` means it's not but [SEP] will be appended,
``None`` means it dependents on `input[-1] == [EOS]`.
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
use_fast: Whether or not to try to load the fast version of the tokenizer.
dict_force: A dictionary doing longest-prefix-match on input text so that the head and tail of each keyword
won't be concatenated to other tokens by transformer tokenizers.
strip_cls_sep: ``True`` to strip [CLS] and [SEP] off the input tokens.
check_space_before: ``True`` to detect the space before each token to handle underline in sentence piece
tokenization.
Examples:
.. highlight:: python
.. code-block:: python
transform = TransformerSequenceTokenizer('bert-base-uncased', 'token')
sample = {'token': 'HanLP good'.split()}
print(transform(sample))
"""
super().__init__(max_seq_length, truncate_long_sequences)
tokenizer_name = tokenizer if isinstance(tokenizer, str) else tokenizer.name_or_path
if check_space_before is None:
# These tokenizer is BPE-based which appends a space before each token and tokenizes loving into
# ['▁lo', 'ving'], tokenize 商品 into ['▁', '商品']. For the later case, the prefix '▁' has to be removed
# as there is no space between some languages like Chinese
check_space_before = tokenizer_name in (
'xlm-roberta-base', 'xlm-roberta-large', 'google/mt5-small', 'google/mt5-base',
'xlm-roberta-base-no-space', 'mMiniLMv2L6-no-space', 'mMiniLMv2L12-no-space',
"answerdotai/ModernBERT-base", "answerdotai/ModernBERT-large"
)
self.check_space_before = check_space_before
self.ret_subtokens_group = ret_subtokens_group
self.ret_subtokens = ret_subtokens
self.sep_is_eos = sep_is_eos
self.ret_prefix_mask = ret_prefix_mask
self.ret_mask_and_type = ret_mask_and_type
self.cls_is_bos = cls_is_bos
self.ret_token_span = ret_token_span
if not output_key or isinstance(output_key, str):
suffixes = ['input_ids']
if ret_mask_and_type:
suffixes += 'attention_mask', 'token_type_ids'
if ret_prefix_mask:
suffixes += ['prefix_mask']
if ret_token_span:
suffixes.append('token_span')
if output_key is None:
output_key = [f'{input_key}_{key}' for key in suffixes]
elif output_key == '':
output_key = suffixes
else:
output_key = [f'{output_key}_{key}' for key in suffixes]
self.input_key = input_key
self.output_key = output_key
if config:
xlnet = config_is(config, 'xlnet')
pad_token_segment_id = 4 if xlnet else 0
cls_token_segment_id = 2 if xlnet else 0
cls_token_at_end = xlnet
pad_on_left = xlnet
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer_.from_pretrained(tokenizer, use_fast=use_fast,
do_basic_tokenize=do_basic_tokenize)
if use_fast:
# Dirty fix upstream bug: https://github.com/hankcs/HanLP/issues/1602
if hasattr(tokenizer, '_tokenizer') and hasattr(tokenizer._tokenizer, 'no_truncation'):
_t = tokenizer._tokenizer
_t.no_truncation()
_t.no_padding()
_t.no_truncation = _t.no_padding = lambda: None
pad_token = tokenizer.pad_token
self.pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
self.pad_token_segment_id = pad_token_segment_id
if tokenizer_name in ('google/mt5-small', 'google/mt5-base'):
# mt5 doesn't have cls or sep, but we can use something similar
self.has_cls = False
self.cls_token = '▁'
self.cls_token_id = tokenizer.convert_tokens_to_ids(self.cls_token)
self.sep_token = tokenizer.eos_token
self.sep_token_id = tokenizer.eos_token_id
else:
self.has_cls = True
self.cls_token = tokenizer.cls_token
self.sep_token = tokenizer.sep_token
self.cls_token_segment_id = cls_token_segment_id
self.cls_token_id = tokenizer.cls_token_id
self.sep_token_id = tokenizer.sep_token_id
self.sep_token_extra = sep_token_extra
self.cls_token_at_end = cls_token_at_end
self.tokenizer = tokenizer
self.pad_on_left = pad_on_left
self.do_padding = do_padding
if self.ret_token_span or not self.truncate_long_sequences:
assert not self.cls_token_at_end
assert not self.pad_on_left
# if self.ret_subtokens:
# if not use_fast:
# raise NotImplementedError(
# 'ret_subtokens is not available when using Python tokenizers. '
# 'To use this feature, set use_fast = True.')
self.dict: Optional[DictInterface] = dict_force # For tokenization of raw text
self.strip_cls_sep = strip_cls_sep
def __call__(self, sample: dict):
input_tokens = sample[self.input_key]
input_is_str = isinstance(input_tokens, str)
tokenizer = self.tokenizer
ret_token_span = self.ret_token_span
if input_is_str: # This happens in a tokenizer component where the raw sentence is fed.
# noinspection PyShadowingNames
def tokenize_str(input_str, add_special_tokens=True):
if tokenizer.is_fast:
encoding = tokenizer.encode_plus(input_str,
return_offsets_mapping=True,
add_special_tokens=add_special_tokens).encodings[0]
subtoken_offsets = encoding.offsets
input_tokens = encoding.tokens
input_ids = encoding.ids
# Fill up missing non-blank characters swallowed by HF tokenizer
offset = 0
fixed_offsets = []
fixed_tokens = []
fixed_ids = []
for token, id, (b, e) in zip(input_tokens, input_ids, subtoken_offsets):
if b > offset:
missing_token = input_str[offset: b]
if not missing_token.isspace(): # In the future, we may want space back
fixed_tokens.append(missing_token)
fixed_ids.append(tokenizer.unk_token_id)
fixed_offsets.append((offset, b))
if e == offset: # LI™ -> LIT + M
if fixed_offsets and fixed_offsets[-1][0] < b:
fixed_offsets[-1] = (fixed_offsets[-1][0], b)
fixed_tokens.append(token)
fixed_ids.append(id)
fixed_offsets.append((b, e))
offset = e
subtoken_offsets = fixed_offsets
input_tokens = fixed_tokens
input_ids = fixed_ids
if add_special_tokens:
subtoken_offsets = subtoken_offsets[1 if self.has_cls else 0:-1]
# Edge case that the input_str is swallowed in whole
if input_str and not subtoken_offsets and not input_str.isspace():
__index = 1 if add_special_tokens and self.has_cls else 0
input_tokens.insert(__index, input_str)
input_ids.insert(__index, tokenizer.unk_token_id)
subtoken_offsets.append((0, len(input_str)))
if not self.has_cls:
input_tokens = [self.cls_token] + input_tokens
input_ids = [self.cls_token_id] + input_ids
else:
input_tokens = tokenizer.tokenize(input_str)
subtoken_offsets = []
_o = 0
for each in input_tokens:
subtoken_offsets.append((_o, _o + len(each)))
_o += len(each)
if add_special_tokens:
input_tokens = [self.cls_token] + input_tokens + [self.sep_token]
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
if self.check_space_before:
non_blank_offsets = [i for i in range(len(input_tokens)) if input_tokens[i] != '▁']
if add_special_tokens and not self.has_cls:
non_blank_offsets.insert(0, 0)
input_tokens = [input_tokens[i] for i in non_blank_offsets]
input_ids = [input_ids[i] for i in non_blank_offsets]
if add_special_tokens:
non_blank_offsets = non_blank_offsets[1:-1]
subtoken_offsets = [subtoken_offsets[i - 1] for i in non_blank_offsets]
else:
subtoken_offsets = [subtoken_offsets[i] for i in non_blank_offsets]
# MT5 generates tokens like ▁of, which is bad for the tokenizer. So we want to remove the prefix.
for i, token in enumerate(input_tokens[1:-1] if add_special_tokens else input_tokens):
if input_str[subtoken_offsets[i][0]] == ' ':
subtoken_offsets[i] = (subtoken_offsets[i][0] + 1, subtoken_offsets[i][1])
# The following block will tokenize each empty string (space) into an unk token
# if add_special_tokens:
# if len(input_tokens) == 2: # bos and eos, meaning that the text contains only some spaces
# input_tokens.insert(1, input_str)
# input_ids.insert(1, tokenizer.unk_token_id)
# subtoken_offsets.append((0, len(input_str)))
# else:
# if not input_ids: # This chunk might be some control chars getting removed by tokenizer
# input_tokens = [input_str]
# input_ids = [tokenizer.unk_token_id]
# subtoken_offsets = [(0, len(input_str))]
return input_tokens, input_ids, subtoken_offsets
if self.dict:
chunks = self.dict.split(sample.get(f'{self.input_key}_', input_tokens)) # Match original text directly
_input_tokens, _input_ids, _subtoken_offsets = [self.cls_token], [self.cls_token_id], []
_offset = 0
custom_words = sample['custom_words'] = []
char_offset = 0
for chunk in chunks:
if isinstance(chunk, str): # Use transformed text as it's what models are trained on
chunk = input_tokens[char_offset:char_offset + len(chunk)]
tokens, ids, offsets = tokenize_str(chunk, add_special_tokens=False)
char_offset += len(chunk)
else:
begin, end, label = chunk
_offset = begin
# chunk offset is on char level, at this moment, there is no concept of tokens, just subtokens
if isinstance(label, list):
tokens, ids, offsets, delta = [], [], [], 0
for token in label:
_tokens, _ids, _offsets = tokenize_str(token, add_special_tokens=False)
tokens.extend(_tokens)
# track the subword offset of this chunk, -1 for [CLS]
custom_words.append(
(len(_input_ids) + len(ids) - 1, len(_input_ids) + len(ids) - 1 + len(_ids), token))
ids.extend(_ids)
offsets.extend((x[0] + delta, x[1] + delta) for x in _offsets)
delta = offsets[-1][-1]
else:
tokens, ids, offsets = tokenize_str(input_tokens[begin:end], add_special_tokens=False)
# offsets = [(offsets[0][0], offsets[-1][-1])]
custom_words.append((len(_input_ids) - 1, len(_input_ids) + len(ids) - 1, label))
char_offset = end
_input_tokens.extend(tokens)
_input_ids.extend(ids)
_subtoken_offsets.extend((x[0] + _offset, x[1] + _offset) for x in offsets)
_offset = _subtoken_offsets[-1][-1]
subtoken_offsets = _subtoken_offsets
input_tokens = _input_tokens + [self.sep_token]
input_ids = _input_ids + [self.sep_token_id]
else:
input_tokens, input_ids, subtoken_offsets = tokenize_str(input_tokens, add_special_tokens=True)
if self.ret_subtokens:
sample[f'{self.input_key}_subtoken_offsets'] = subtoken_offsets
cls_is_bos = self.cls_is_bos
if cls_is_bos is None:
cls_is_bos = input_tokens[0] == BOS
sep_is_eos = self.sep_is_eos
if sep_is_eos is None:
sep_is_eos = input_tokens[-1] == EOS
if self.strip_cls_sep:
if cls_is_bos:
input_tokens = input_tokens[1:]
if sep_is_eos:
input_tokens = input_tokens[:-1]
if not self.ret_mask_and_type: # only need input_ids and token_span, use a light version
if input_is_str:
prefix_mask = self._init_prefix_mask(input_ids)
else:
if input_tokens:
return_offsets_mapping = tokenizer.is_fast and self.ret_subtokens
encodings = tokenizer.batch_encode_plus(
input_tokens,
return_offsets_mapping=return_offsets_mapping, # Many tokenizers do not offer fast version
add_special_tokens=False
)
subtoken_ids_per_token = encodings.data['input_ids']
if return_offsets_mapping:
offsets_mapping = [encoding.offsets for encoding in encodings.encodings]
else:
offsets_mapping = []
for token, subtoken_ids in zip(input_tokens, subtoken_ids_per_token):
if len(subtoken_ids) > len(token): # … --> ...
del subtoken_ids[len(token):]
if not subtoken_ids:
subtoken_ids = [tokenizer.unk_token_id]
# Since non-fast tok generates no mapping, we have to guess
char_per_subtoken = max(len(token) // len(subtoken_ids), 1)
bes = [(b, b + char_per_subtoken) for b in range(0, len(token), char_per_subtoken)]
if not bes: # the token is an empty string
bes = [(0, 0)]
if len(bes) != len(subtoken_ids):
bes[len(subtoken_ids) - 1] = (bes[len(subtoken_ids) - 1][0], len(token))
del bes[len(subtoken_ids):]
offsets_mapping.append(bes)
else:
encodings = SerializableDict()
subtoken_ids_per_token = []
encodings.data = {'input_ids': subtoken_ids_per_token}
if self.check_space_before:
# noinspection PyUnboundLocalVariable
for token, subtokens, mapping, encoding in zip(input_tokens, subtoken_ids_per_token,
offsets_mapping, encodings.encodings):
# Remove ▁ generated by spm for 2 reasons:
# 1. During decoding, mostly no ▁ will be created unless blanks are placed between tokens (which
# is true for English but in English it will likely be concatenated to the token following it)
# 2. For T5, '▁' is used as CLS
if len(subtokens) > 1 and encoding.tokens[0] == '▁':
subtokens.pop(0)
if mapping:
mapping.pop(0)
# Some tokens get stripped out
subtoken_ids_per_token = [ids if ids else [tokenizer.unk_token_id] for ids in subtoken_ids_per_token]
input_ids = sum(subtoken_ids_per_token, [self.cls_token_id])
if self.sep_is_eos is None:
# None means to check whether sep is at the tail or between tokens
if sep_is_eos:
input_ids += [self.sep_token_id]
elif self.sep_token_id not in input_ids:
input_ids += [self.sep_token_id]
else:
input_ids += [self.sep_token_id]
# else self.sep_is_eos == False means sep is between tokens and don't bother to check
if self.ret_subtokens:
prefix_mask = self._init_prefix_mask(input_ids)
# if self.check_space_before:
# if offsets_mapping[0] and not input_tokens[0].startswith(' '):
# prefix_mask[1] = False
else:
prefix_mask = [False] * len(input_ids)
offset = 1
for _subtokens in subtoken_ids_per_token:
prefix_mask[offset] = True
offset += len(_subtokens)
if self.ret_subtokens:
subtoken_offsets = []
for token, offsets in zip(input_tokens, offsets_mapping):
if offsets:
subtoken_offsets.append(offsets)
else:
subtoken_offsets.append([(0, len(token))])
if self.ret_subtokens_group:
sample[f'{self.input_key}_subtoken_offsets_group'] = subtoken_offsets
else:
sample[f'{self.input_key}_subtoken_offsets'] = sum(subtoken_offsets, [])
else:
input_ids, attention_mask, token_type_ids, prefix_mask = \
convert_examples_to_features(input_tokens,
None,
tokenizer,
cls_token_at_end=self.cls_token_at_end,
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=self.cls_token_segment_id,
sep_token=self.sep_token,
sep_token_extra=self.sep_token_extra,
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=self.pad_on_left,
# pad on the left for xlnet
pad_token_id=self.pad_token_id,
pad_token_segment_id=self.pad_token_segment_id,
pad_token_label_id=0,
do_padding=self.do_padding)
if len(input_ids) > self.max_seq_length:
if self.truncate_long_sequences:
# raise SequenceTooLong(
# f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
# f'For sequence tasks, truncate_long_sequences = True is not supported.'
# f'You are recommended to split your long text into several sentences within '
# f'{self.max_seq_length - 2} tokens beforehand. '
# f'Or simply set truncate_long_sequences = False to enable sliding window.')
input_ids = input_ids[:self.max_seq_length]
prefix_mask = prefix_mask[:self.max_seq_length]
warnings.warn(
f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
f'The exceeded part will be truncated and ignored. '
f'You are recommended to split your long text into several sentences within '
f'{self.max_seq_length - 2} tokens beforehand.'
f'Or simply set truncate_long_sequences = False to enable sliding window.'
)
else:
input_ids = self.sliding_window(input_ids, input_ids[-1] == self.sep_token_id)
if prefix_mask:
if cls_is_bos:
prefix_mask[0] = True
if sep_is_eos:
prefix_mask[-1] = True
outputs = [input_ids]
if self.ret_mask_and_type:
# noinspection PyUnboundLocalVariable
outputs += [attention_mask, token_type_ids]
if self.ret_prefix_mask:
outputs += [prefix_mask]
if ret_token_span and prefix_mask:
if cls_is_bos:
token_span = [[0]]
else:
token_span = []
offset = 1
span = []
for mask in prefix_mask[1:len(prefix_mask) if sep_is_eos is None else -1]: # skip [CLS] and [SEP]
if mask and span:
token_span.append(span)
span = []
span.append(offset)
offset += 1
if span:
token_span.append(span)
if sep_is_eos:
assert offset == len(prefix_mask) - 1
token_span.append([offset])
outputs.append(token_span)
for k, v in zip(self.output_key, outputs):
sample[k] = v
return sample
def _init_prefix_mask(self, input_ids):
prefix_mask = [True] * len(input_ids)
if not self.cls_is_bos:
prefix_mask[0] = False
if not self.sep_is_eos:
prefix_mask[-1] = False
return prefix_mask
def config_is(config, model='bert'):
return model in type(config).__name__.lower()
def convert_examples_to_features(
words,
max_seq_length: Optional[int],
tokenizer,
labels=None,
label_map=None,
cls_token_at_end=False,
cls_token="[CLS]",
cls_token_segment_id=1,
sep_token="[SEP]",
sep_token_extra=False,
pad_on_left=False,
pad_token_id=0,
pad_token_segment_id=0,
pad_token_label_id=0,
sequence_a_segment_id=0,
mask_padding_with_zero=True,
unk_token='[UNK]',
do_padding=True
):
"""Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token:
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
Args:
words:
max_seq_length:
tokenizer:
labels: (Default value = None)
label_map: (Default value = None)
cls_token_at_end: (Default value = False)
cls_token: (Default value = "[CLS]")
cls_token_segment_id: (Default value = 1)
sep_token: (Default value = "[SEP]")
sep_token_extra: (Default value = False)
pad_on_left: (Default value = False)
pad_token_id: (Default value = 0)
pad_token_segment_id: (Default value = 0)
pad_token_label_id: (Default value = 0)
sequence_a_segment_id: (Default value = 0)
mask_padding_with_zero: (Default value = True)
unk_token: (Default value = '[UNK]')
do_padding: (Default value = True)
Returns:
"""
args = locals()
if not labels:
labels = words
pad_token_label_id = False
tokens = []
label_ids = []
for word, label in zip(words, labels):
word_tokens = tokenizer.tokenize(word)
if not word_tokens:
# some wired chars cause the tagger to return empty list
word_tokens = [unk_token] * len(word)
tokens.extend(word_tokens)
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
label_ids.extend([label_map[label] if label_map else True] + [pad_token_label_id] * (len(word_tokens) - 1))
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
special_tokens_count = 3 if sep_token_extra else 2
if max_seq_length and len(tokens) > max_seq_length - special_tokens_count:
warnings.warn(
f'Input tokens {words} exceed the max sequence length of {max_seq_length - special_tokens_count}. '
f'The exceeded part will be truncated and ignored. '
f'You are recommended to split your long text into several sentences within '
f'{max_seq_length - special_tokens_count} tokens beforehand.')
tokens = tokens[: (max_seq_length - special_tokens_count)]
label_ids = label_ids[: (max_seq_length - special_tokens_count)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# token_type_ids: 0 0 0 0 0 0 0
#
# Where "token_type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens += [sep_token]
label_ids += [pad_token_label_id]
if sep_token_extra:
# roberta uses an extra separator b/w pairs of sentences
tokens += [sep_token]
label_ids += [pad_token_label_id]
segment_ids = [sequence_a_segment_id] * len(tokens)
if cls_token_at_end:
tokens += [cls_token]
label_ids += [pad_token_label_id]
segment_ids += [cls_token_segment_id]
else:
tokens = [cls_token] + tokens
label_ids = [pad_token_label_id] + label_ids
segment_ids = [cls_token_segment_id] + segment_ids
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
if do_padding:
# Zero-pad up to the sequence length.
padding_length = max_seq_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token_id] * padding_length) + input_ids
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
label_ids = ([pad_token_label_id] * padding_length) + label_ids
else:
input_ids += [pad_token_id] * padding_length
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
segment_ids += [pad_token_segment_id] * padding_length
label_ids += [pad_token_label_id] * padding_length
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length, f'failed for:\n {args}'
else:
assert len(set(len(x) for x in [input_ids, input_mask, segment_ids, label_ids])) == 1
return input_ids, input_mask, segment_ids, label_ids
def main():
transformer = 'bert-base-uncased'
tokenizer: PreTrainedTokenizer = AutoTokenizer_.from_pretrained(transformer)
# _test_text_transform(tokenizer)
_test_sequence_transform(tokenizer)
def _test_text_transform(tokenizer):
transform = TransformerTextTokenizer(tokenizer, 'text')
sample = {'text': 'HanLP good'}
print(transform(sample))
def _test_sequence_transform(tokenizer):
transform = TransformerSequenceTokenizer(tokenizer, 'token')
sample = {'token': 'HanLP good'.split()}
print(transform(sample))
if __name__ == '__main__':
main()