# Source code for hanlp.common.vocab

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-06-13 22:42
from collections import Counter
from typing import List, Dict, Union, Iterable

from hanlp_common.constant import UNK, PAD
from hanlp_common.structure import Serializable
from hanlp_common.reflection import classpath_of


class Vocab(Serializable):
    def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mutable=True, pad_token=PAD,
                 unk_token=UNK) -> None:
        """Vocabulary base class which converts tokens to indices and vice versa.

        Args:
            idx_to_token: id to token mapping.
            token_to_idx: token to id mapping.
            mutable: ``True`` to allow adding new tokens, ``False`` to map OOV to ``unk``.
            pad_token: The token representing padding.
            unk_token: The token representing OOV.
        """
        super().__init__()
        if idx_to_token:
            # Derive token->idx from the list; explicit token_to_idx entries override it.
            t2i = dict((token, idx) for idx, token in enumerate(idx_to_token))
            if token_to_idx:
                t2i.update(token_to_idx)
            token_to_idx = t2i
        if token_to_idx is None:
            # Fresh vocab: seed it with pad/unk when they are requested.
            token_to_idx = {}
            if pad_token is not None:
                token_to_idx[pad_token] = len(token_to_idx)
            if unk_token is not None:
                # ``get`` presumably guards the case unk_token == pad_token — both map to one id.
                token_to_idx[unk_token] = token_to_idx.get(unk_token, len(token_to_idx))
        self.token_to_idx = token_to_idx
        # Reverse mapping; built by ``lock()`` and reset to None by ``unlock()``.
        self.idx_to_token: List[str] = None
        self.mutable = mutable
        self.pad_token = pad_token
        self.unk_token = unk_token

    def __setitem__(self, token: str, idx: int):
        # Direct index assignment, only permitted while the vocab is mutable.
        assert self.mutable, 'Update an immutable Vocab object is not allowed'
        self.token_to_idx[token] = idx
[docs] def __getitem__(self, key: Union[str, int, List]) -> Union[int, str, List]: """ Get the index/indices associated with a token or a list of tokens or vice versa. Args: key: ``str`` for token(s) and ``int`` for index/indices. Returns: Associated indices or tokens. """ if isinstance(key, str): return self.get_idx(key) elif isinstance(key, int): return self.get_token(key) elif isinstance(key, list): if len(key) == 0: return [] elif isinstance(key[0], str): return [self.get_idx(x) for x in key] elif isinstance(key[0], int): return [self.get_token(x) for x in key]
def __contains__(self, key: Union[str, int]): if isinstance(key, str): return key in self.token_to_idx elif isinstance(key, int): return 0 <= key < len(self.idx_to_token) else: return False
[docs] def add(self, token: str) -> int: """ Tries to add a token into a vocab and returns its id. If it has already been there, its id will be returned and the vocab won't be updated. If the vocab is locked, an assertion failure will occur. Args: token: A new or existing token. Returns: Its associated id. """ assert self.mutable, 'It is not allowed to call add on an immutable Vocab' assert isinstance(token, str), f'Token type must be str but got {type(token)} from {token}' assert token is not None, 'Token must not be None' idx = self.token_to_idx.get(token, None) if idx is None: idx = len(self.token_to_idx) self.token_to_idx[token] = idx return idx
[docs] def update(self, tokens: Iterable[str]) -> None: """Update the vocab with these tokens by adding them to vocab one by one. Args: tokens (Iterable[str]): A list of tokens. """ assert self.mutable, 'It is not allowed to update an immutable Vocab' for token in tokens: self.add(token)
[docs] def get_idx(self, token: str) -> int: """Get the idx of a token. If it's not there, it will be added to the vocab when the vocab is locked otherwise the id of UNK will be returned. Args: token: A token. Returns: The id of that token. """ assert isinstance(token, str), 'token has to be `str`' idx = self.token_to_idx.get(token, None) if idx is None: if self.mutable: idx = len(self.token_to_idx) self.token_to_idx[token] = idx else: idx = self.token_to_idx.get(self.unk_token, None) return idx
def get_idx_without_add(self, token: str) -> int: idx = self.token_to_idx.get(token, None) if idx is None: idx = self.token_to_idx.get(self.safe_unk_token, None) return idx
[docs] def get_token(self, idx: int) -> str: """Get the token using its index. Args: idx: The index to a token. Returns: """ if self.idx_to_token: return self.idx_to_token[idx] if self.mutable: for token in self.token_to_idx: if self.token_to_idx[token] == idx: return token
    def has_key(self, token):
        # Legacy dict-style membership check; equivalent to ``token in vocab``.
        return token in self.token_to_idx

    def __len__(self):
        # Number of tokens in the vocab, counting pad/unk entries if present.
        return len(self.token_to_idx)
[docs] def lock(self): """Lock this vocab up so that it won't accept new tokens. Returns: Itself. """ if self.locked: return self self.mutable = False self.build_idx_to_token() return self
def build_idx_to_token(self): max_idx = max(self.token_to_idx.values()) self.idx_to_token = [None] * (max_idx + 1) for token, idx in self.token_to_idx.items(): self.idx_to_token[idx] = token
[docs] def unlock(self): """Unlock this vocab so that new tokens can be added in. Returns: Itself. """ if not self.locked: return self.mutable = True self.idx_to_token = None return self
@property def locked(self): """ ``True`` indicates this vocab is locked. """ return not self.mutable @property def unk_idx(self): """ The index of ``UNK`` token. """ if self.unk_token is None: return None else: return self.token_to_idx.get(self.unk_token, None) @property def pad_idx(self): """ The index of ``PAD`` token. """ if self.pad_token is None: return None else: return self.token_to_idx.get(self.pad_token, None) @property def tokens(self): """ A set of all tokens in this vocab. """ return self.token_to_idx.keys() def __str__(self) -> str: return self.token_to_idx.__str__()
[docs] def summary(self, verbose=True) -> str: """Get or print a summary of this vocab. Args: verbose: ``True`` to print the summary to stdout. Returns: Summary in text form. """ # report = 'Length: {}\n'.format(len(self)) # report += 'Samples: {}\n'.format(str(list(self.token_to_idx.keys())[:min(50, len(self))])) # report += 'Mutable: {}'.format(self.mutable) # report = report.strip() report = '[{}] = '.format(len(self)) report += str(list(self.token_to_idx.keys())[:min(50, len(self))]) if verbose: print(report) return report
def __call__(self, some_token: Union[str, Iterable[str]]) -> Union[int, List[int]]: if isinstance(some_token, (list, tuple, set)): indices = [] if len(some_token) and isinstance(some_token[0], (list, tuple, set)): for sent in some_token: inside = [] for token in sent: inside.append(self.get_idx(token)) indices.append(inside) return indices for token in some_token: indices.append(self.get_idx(token)) return indices else: return self.get_idx(some_token)
[docs] def to_dict(self) -> dict: """Convert this vocab to a dict so that it can be json serialized. Returns: A dict. """ idx_to_token = self.idx_to_token pad_token = self.pad_token unk_token = self.unk_token mutable = self.mutable items = locals().copy() items.pop('self') return items
[docs] def copy_from(self, item: dict): """Copy properties from a dict so that it can json de-serialized. Args: item: A dict holding ``token_to_idx`` Returns: Itself. """ for key, value in item.items(): setattr(self, key, value) self.token_to_idx = {k: v for v, k in enumerate(self.idx_to_token)} return self
[docs] def lower(self): """Convert all tokens to lower case. Returns: Itself. """ self.unlock() token_to_idx = self.token_to_idx self.token_to_idx = {} for token in token_to_idx.keys(): self.add(token.lower()) return self
@property def first_token(self): """The first token in this vocab. """ if self.idx_to_token: return self.idx_to_token[0] if self.token_to_idx: return next(iter(self.token_to_idx)) return None
[docs] def merge(self, other): """Merge this with another vocab inplace. Args: other (Vocab): Another vocab. """ for word, idx in other.token_to_idx.items(): self.get_idx(word)
@property def safe_pad_token(self) -> str: """Get the pad token safely. It always returns a pad token, which is the pad token or the first token if pad does not present in the vocab. """ if self.pad_token: return self.pad_token if self.first_token: return self.first_token return PAD @property def safe_pad_token_idx(self) -> int: """Get the idx to the pad token safely. It always returns an index, which corresponds to the pad token or the first token if pad does not present in the vocab. """ return self.token_to_idx.get(self.safe_pad_token, 0) @property def safe_unk_token(self) -> str: """Get the unk token safely. It always returns a unk token, which is the unk token or the first token if unk does not presented in the vocab. """ if self.unk_token: return self.unk_token if self.first_token: return self.first_token return UNK def __repr__(self) -> str: if self.idx_to_token is not None: return self.idx_to_token.__repr__() return self.token_to_idx.__repr__() def extend(self, tokens: Iterable[str]): self.unlock() self(tokens) def reload_idx_to_token(self, idx_to_token: List[str], pad_idx=0, unk_idx=1): self.idx_to_token = idx_to_token self.token_to_idx = dict((s, i) for i, s in enumerate(idx_to_token)) if pad_idx is not None: self.pad_token = idx_to_token[pad_idx] if unk_idx is not None: self.unk_token = idx_to_token[unk_idx]
    def set_unk_as_safe_unk(self):
        """Set ``self.unk_token = self.safe_unk_token``. It's useful when the dev/test set contains OOV labels.
        """
        # After this call, unk_token is guaranteed non-empty (unk, or the first token, or UNK).
        self.unk_token = self.safe_unk_token
    def clear(self):
        # Unlock (resets idx_to_token to None and makes the vocab mutable again),
        # then drop every entry — pad/unk mappings included — from token_to_idx.
        self.unlock()
        self.token_to_idx.clear()
class CustomVocab(Vocab):
    """A vocab whose ``to_dict`` records its concrete class path so it can be re-created on load."""

    def to_dict(self) -> dict:
        d = super().to_dict()
        d['type'] = classpath_of(self)
        return d


class LowercaseVocab(CustomVocab):
    """Vocab which retries the lowercase spelling of a token before treating it as OOV."""

    def get_idx(self, token: str) -> int:
        idx = self.token_to_idx.get(token, None)
        if idx is None:
            # Second chance: try the lowercase form before falling back to unk.
            idx = self.token_to_idx.get(token.lower(), None)
        if idx is None:
            if self.mutable:
                idx = len(self.token_to_idx)
                self.token_to_idx[token] = idx
            else:
                idx = self.token_to_idx.get(self.unk_token, None)
        return idx


class VocabWithNone(CustomVocab):
    """Vocab which maps ``None`` to the special index ``-1``."""

    def get_idx(self, token: str) -> int:
        if token is None:
            return -1
        return super().get_idx(token)


class VocabWithFrequency(CustomVocab):

    def __init__(self, counter: Counter = None, min_occur_cnt=0, pad_token=PAD, unk_token=UNK,
                 specials=None) -> None:
        """A vocab built from a ``Counter`` which also records each token's frequency.

        Args:
            counter: Token frequencies.
            min_occur_cnt: Minimum occurrence a token needs in order to be kept.
            pad_token: The token representing padding.
            unk_token: The token representing OOV.
            specials: Extra tokens inserted first regardless of frequency.
        """
        super().__init__(None, None, True, pad_token, unk_token)
        if specials:
            for each in specials:
                counter.pop(each, None)
                self.add(each)
        # pad/unk/specials all get a placeholder frequency of 1.
        self.frequencies = [1] * len(self)
        if counter:
            for token, freq in counter.most_common():
                if freq >= min_occur_cnt:
                    self.add(token)
                    self.frequencies.append(freq)
        self.lock()

    def to_dict(self) -> dict:
        d = super().to_dict()
        d['frequencies'] = self.frequencies
        return d

    def copy_from(self, item: dict):
        super().copy_from(item)
        self.frequencies = item['frequencies']
        # Was missing: Vocab.copy_from returns self for fluent chaining — keep that contract.
        return self

    def get_frequency(self, token):
        """Frequency recorded for ``token``; OOV maps to the unk slot, unknown ids yield 0."""
        idx = self.get_idx(token)
        if idx is not None:
            return self.frequencies[idx]
        return 0


class VocabCounter(CustomVocab):
    """Vocab that counts each lookup while mutable, enabling frequency-based ``trim``."""

    def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mutable=True,
                 pad_token=PAD, unk_token=UNK) -> None:
        super().__init__(idx_to_token, token_to_idx, mutable, pad_token, unk_token)
        self.counter = Counter()

    def get_idx(self, token: str) -> int:
        if self.mutable:
            self.counter[token] += 1
        return super().get_idx(token)

    def trim(self, min_frequency):
        """Drop every token seen fewer than ``min_frequency`` times; pad/unk always survive."""
        assert self.mutable
        specials = {self.unk_token, self.pad_token}
        survivors = list((token, freq) for token, freq in self.counter.most_common()
                         if freq >= min_frequency and token not in specials)
        # Specials go first with a sentinel frequency of -1.
        survivors = [(x, -1) for x in specials if x] + survivors
        self.counter = Counter(dict(survivors))
        self.token_to_idx = dict()
        self.idx_to_token = None
        for token, freq in survivors:
            idx = len(self.token_to_idx)
            self.token_to_idx[token] = idx

    def copy_from(self, item: dict):
        super().copy_from(item)
        self.counter = Counter(item['counter'].items()) if 'counter' in item else Counter()
        # Was missing: keep the fluent return-self contract of Vocab.copy_from.
        return self

    def to_dict(self) -> dict:
        d = super().to_dict()
        d['counter'] = dict(self.counter.items())
        return d


class Vocab3D(CustomVocab):

    def __call__(self, some_token: Union[str, Iterable[str], Iterable[Iterable[str]]]) \
            -> Union[int, List[int], List[List[int]]]:
        """It supports 3D arrays of tokens.

        Args:
            some_token: Tokens of 1D to 3D.

        Returns:
            A (possibly nested) list of indices.
        """
        if isinstance(some_token, (list, tuple, set)):
            indices = []
            # Peek with an iterator: ``set`` is accepted by the isinstance check above,
            # but indexing ``some_token[0]`` raised TypeError on sets.
            if len(some_token) and isinstance(next(iter(some_token)), (list, tuple, set)):
                for sent in some_token:
                    inside = []
                    for token in sent:
                        inside.append(self.get_idx(token))
                    indices.append(inside)
                return indices
            for token in some_token:
                if isinstance(token, str):
                    indices.append(self.get_idx(token))
                else:
                    # 3D case: each element is itself a collection of tokens.
                    indices.append([self.get_idx(x) for x in token])
            return indices
        else:
            return self.get_idx(some_token)


def create_label_vocab() -> Vocab:
    """Create a vocab suitable for labels: no pad token, no unk token."""
    return Vocab(pad_token=None, unk_token=None)