Source code for hanlp.datasets.tokenization.loaders.multi_criteria_cws.mcws_dataset

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-10-21 19:11
import os
from typing import Union, List, Callable, Dict, Iterable

from hanlp.datasets.tokenization.loaders.txt import TextTokenizingDataset
from hanlp.utils.io_util import get_resource


class MultiCriteriaTextTokenizingDataset(TextTokenizingDataset):
    def __init__(self,
                 data: Union[str, List],
                 transform: Union[Callable, List] = None,
                 cache=None,
                 generate_idx=None,
                 delimiter=None,
                 max_seq_len=None,
                 sent_delimiter=None,
                 char_level=False,
                 hard_constraint=False) -> None:
        super().__init__(data, transform, cache, generate_idx, delimiter, max_seq_len, sent_delimiter,
                         char_level, hard_constraint)

    def should_load_file(self, data) -> bool:
        # Multi-criteria corpora are specified as a tuple of paths or a dict of criterion-path pairs,
        # not as a single path string.
        return isinstance(data, (tuple, dict))

    def load_file(self, filepath: Union[Iterable[str], Dict[str, str]]):
        """Load multi-criteria corpora specified in filepath.

        Args:
            filepath: A list of files where the name of each file's parent folder is its criterion,
                or a dict mapping each criterion to its filepath.

        .. highlight:: bash
        .. code-block:: bash

            $ tree -L 2 .
            .
            ├── cnc
            │   ├── dev.txt
            │   ├── test.txt
            │   ├── train-all.txt
            │   └── train.txt
            ├── ctb
            │   ├── dev.txt
            │   ├── test.txt
            │   ├── train-all.txt
            │   └── train.txt
            ├── sxu
            │   ├── dev.txt
            │   ├── test.txt
            │   ├── train-all.txt
            │   └── train.txt
            ├── udc
            │   ├── dev.txt
            │   ├── test.txt
            │   ├── train-all.txt
            │   └── train.txt
            ├── wtb
            │   ├── dev.txt
            │   ├── test.txt
            │   ├── train-all.txt
            │   └── train.txt
            └── zx
                ├── dev.txt
                ├── test.txt
                ├── train-all.txt
                └── train.txt

            $ head -n 2 ctb/dev.txt
            上海 浦东 开发 与 法制 建设 同步
            新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
        """
        for eachpath in (filepath.items() if isinstance(filepath, dict) else filepath):
            if isinstance(eachpath, tuple):
                # A (criterion, path) pair coming from a dict: the criterion is given explicitly.
                criteria, eachpath = eachpath
                eachpath = get_resource(eachpath)
            else:
                # A bare path: use the name of its parent folder as the criterion.
                eachpath = get_resource(eachpath)
                criteria = os.path.basename(os.path.dirname(eachpath))
            for sample in super().load_file(eachpath):
                sample['criteria'] = criteria
                yield sample
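

# --- Usage sketch (illustration only, not part of the upstream module) ---
# Assuming local corpora laid out as in the docstring above, a dict maps each criterion name to a file,
# while bare paths would instead take their criterion from the parent folder name. The paths below are
# hypothetical; the extra guard skips the demo when they are absent.
if __name__ == '__main__' and os.path.isdir('ctb'):
    _files = {
        'ctb': 'ctb/train.txt',
        'sxu': 'sxu/train.txt',
    }
    _dataset = MultiCriteriaTextTokenizingDataset(_files, generate_idx=True)
    for _sample in _dataset:
        # Each loaded sample carries a 'criteria' key identifying which corpus it came from.
        print(_sample['criteria'], _sample.get('token'))
        break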


def append_criteria_token(sample: dict, criteria_tokens: Dict[str, int], criteria_token_map: dict) -> dict:
    criteria = sample['criteria']
    token = criteria_token_map.get(criteria, None)
    if not token:
        # Assign the next unused vocabulary token to this criterion.
        unused_tokens = list(criteria_tokens.keys())
        size = len(criteria_token_map)
        assert size + 1 < len(unused_tokens), f'No unused token available for criteria {criteria}. ' \
                                              f'Current criteria_token_map = {criteria_token_map}'
        token = criteria_token_map[criteria] = unused_tokens[size]
    # Append the criterion token id to the end of the subword sequence and mark it with segment id 1.
    sample['token_token_type_ids'] = [0] * len(sample['token_input_ids']) + [1]
    sample['token_input_ids'] = sample['token_input_ids'] + [criteria_tokens[token]]
    return sample
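

# --- Illustration only (not part of the upstream module): what append_criteria_token does to a sample. ---
# Here 'criteria_tokens' is assumed to map reserved vocabulary entries (e.g. BERT-style [unused*] tokens)
# to their ids, and 'criteria_token_map' records which reserved token each criterion has been assigned.
# The ids below are made up for the demonstration.
if __name__ == '__main__':
    _criteria_tokens = {'[unused1]': 1, '[unused2]': 2, '[unused3]': 3}
    _criteria_token_map = {}
    _sample = {'criteria': 'ctb', 'token_input_ids': [101, 2769, 4263, 102]}
    append_criteria_token(_sample, _criteria_tokens, _criteria_token_map)
    # 'ctb' gets bound to '[unused1]', whose id (1) is appended with segment id 1:
    # token_input_ids -> [101, 2769, 4263, 102, 1]; token_token_type_ids -> [0, 0, 0, 0, 1]
    print(_sample, _criteria_token_map)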