Source code for hanlp.datasets.eos.eos

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-07-26 18:12
import itertools
from collections import Counter
from typing import Union, List, Callable

from hanlp.common.dataset import TransformableDataset
from hanlp.utils.io_util import TimingFileIterator
from hanlp.utils.log_util import cprint
from hanlp.utils.string_util import ispunct


class SentenceBoundaryDetectionDataset(TransformableDataset):

    def __init__(self,
                 data: Union[str, List],
                 transform: Union[Callable, List] = None,
                 cache=None,
                 append_after_sentence=None,
                 eos_chars=None,
                 eos_char_min_freq=200,
                 eos_char_is_punct=True,
                 window_size=5,
                 **kwargs) -> None:
        """Dataset for sentence boundary detection (eos).

        Args:
            data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
            transform: Predefined transform(s).
            cache: ``True`` to enable caching, so that transforms won't be called twice.
            append_after_sentence: A :class:`str` to insert at the tail of each sentence. For example, English
                always has a space between sentences.
            eos_chars: Punctuation characters at the tail of sentences. If ``None``, they will be built from the
                training samples.
            eos_char_min_freq: Minimal frequency to keep an eos char.
            eos_char_is_punct: Limit eos chars to punctuation.
            window_size: Window size to extract ngram features.
            kwargs: Not used.
        """
        self.eos_char_is_punct = eos_char_is_punct
        self.append_after_sentence = append_after_sentence
        self.window_size = window_size
        self.eos_chars = eos_chars
        self.eos_char_min_freq = eos_char_min_freq
        super().__init__(data, transform, cache)
    def load_file(self, filepath: str):
        """Load an eos corpus.

        Args:
            filepath: Path to the corpus, with one sentence per line, e.g.

        .. highlight:: bash
        .. code-block:: bash

            $ head -n 2 ctb8.txt
            中国经济简讯
            新华社北京十月二十九日电中国经济简讯
        """
        f = TimingFileIterator(filepath)
        sents = []
        eos_offsets = []
        offset = 0
        for line in f:
            if not line.strip():
                continue
            line = line.rstrip('\n')
            # Record the offset of the last non-space char of each sentence.
            eos_offsets.append(offset + len(line.rstrip()) - 1)
            offset += len(line)
            if self.append_after_sentence:
                line += self.append_after_sentence
                offset += len(self.append_after_sentence)
            f.log(line)
            sents.append(line)
        f.erase()
        # Flatten the sentences into one long sequence of characters.
        corpus = list(itertools.chain.from_iterable(sents))

        if self.eos_chars:
            if not isinstance(self.eos_chars, set):
                self.eos_chars = set(self.eos_chars)
        else:
            # Build the eos char set from chars observed at sentence tails,
            # keeping only frequent (and optionally punctuation-only) chars.
            eos_chars = Counter()
            for i in eos_offsets:
                eos_chars[corpus[i]] += 1
            self.eos_chars = set(k for (k, v) in eos_chars.most_common()
                                 if v >= self.eos_char_min_freq
                                 and (not self.eos_char_is_punct or ispunct(k)))
            cprint(f'eos_chars = [yellow]{self.eos_chars}[/yellow]')

        eos_index = 0
        # Keep only the gold offsets whose tail char survived the filtering.
        eos_offsets = [i for i in eos_offsets if corpus[i] in self.eos_chars]
        window_size = self.window_size
        for i, c in enumerate(corpus):
            if c in self.eos_chars:
                # Clamp the left edge so candidates near the start of the
                # corpus don't produce a wrapped-around (negative) slice.
                window = corpus[max(0, i - window_size): i + window_size + 1]
                # The bounds check prevents an IndexError on candidate chars
                # occurring after the last gold offset.
                label_id = 1. if eos_index < len(eos_offsets) and eos_offsets[eos_index] == i else 0.
                if label_id > 0:
                    eos_index += 1
                yield {'char': window, 'label_id': label_id}
        assert eos_index == len(eos_offsets), f'{eos_index} != {len(eos_offsets)}'
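A minimal usage sketch, separate from the library source above. It assumes a local ``ctb8.txt`` laid out as in the ``load_file`` docstring and the usual ``__getitem__`` access that :class:`TransformableDataset` provides; the path and the printed values are illustrative, not guaranteed output.

.. code-block:: python

    from hanlp.datasets.eos.eos import SentenceBoundaryDetectionDataset

    # 'ctb8.txt' is a hypothetical local corpus, one sentence per line.
    dataset = SentenceBoundaryDetectionDataset(
        'ctb8.txt',
        eos_char_min_freq=200,  # drop rare sentence-tail characters
        window_size=5,          # 5 chars of context on each side
    )
    sample = dataset[0]
    print(sample['char'])      # up to 2 * window_size + 1 characters
    print(sample['label_id'])  # 1.0 at a true sentence boundary, else 0.0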