# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-11-29 16:52
import logging
from typing import Dict, Any, List, Union, Iterable, Callable
import torch
from phrasetree.tree import Tree
from hanlp_common.constant import BOS, EOS
from hanlp_common.document import Document
from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
from torch.utils.data import DataLoader
from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
from hanlp.common.transform import VocabDict
from hanlp.components.mtl.tasks import Task
from hanlp.components.parsers.constituency.crf_constituency_model import CRFConstituencyDecoder
from hanlp.components.parsers.constituency.crf_constituency_parser import CRFConstituencyParser
from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
from hanlp.metrics.metric import Metric
from hanlp.metrics.mtl import MetricDict
from hanlp.utils.time_util import CountdownTimer
from hanlp_common.util import merge_locals_kwargs, prefix_match
class CRFConstituencyParsing(Task, CRFConstituencyParser):
    """Constituency parsing exposed as a multi-task-learning :class:`Task`.

    Most methods delegate to the matching :class:`CRFConstituencyParser` method by
    calling it explicitly with ``self``. This lets the task reuse the standalone
    parser's dataset, vocabulary, loss, decoding and metric code.
    """

    def __init__(self,
                 trn: str = None,
                 dev: str = None,
                 tst: str = None,
                 sampler_builder: SamplerBuilder = None,
                 dependencies: str = None,
                 scalar_mix: ScalarMixWithDropoutBuilder = None,
                 delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP', ',', 'S1'),
                 equal=(('ADVP', 'PRT'),),
                 use_raw_hidden_states=False,
                 lr=None,
                 separate_optimizer=False,
                 cls_is_bos=True,
                 sep_is_eos=True,
                 mbr=True,
                 n_mlp_span=500,
                 n_mlp_label=100,
                 mlp_dropout=.33,
                 no_subcategory=True,
                 **kwargs
                 ) -> None:
        r"""Two-stage CRF Parsing (:cite:`ijcai2020-560`).

        Args:
            trn: Path to training set.
            dev: Path to dev set.
            tst: Path to test set.
            sampler_builder: A builder which builds a sampler.
            dependencies: Its dependencies on other tasks.
            scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
            delete: Constituencies to be deleted from training and evaluation.
            equal: Constituencies that are regarded as equal during evaluation, given as
                ``(label, equivalent_label)`` pairs. A tuple of pairs is converted to a ``dict``.
            use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
            lr: Learning rate for this task.
            separate_optimizer: Use customized separate optimizer for this task.
            cls_is_bos: ``True`` to treat the first token as ``BOS``.
            sep_is_eos: ``True`` to treat the last token as ``EOS``.
            mbr: ``True`` to enable Minimum Bayes Risk (MBR) decoding (:cite:`smith-smith-2007-probabilistic`).
            n_mlp_span: Number of features for span decoder.
            n_mlp_label: Number of features for label decoder.
            mlp_dropout: Dropout applied to MLPs.
            no_subcategory: Strip out subcategories.
            **kwargs: Extra options. They are forwarded to the parent initializers together
                with the named arguments.
        """
        # e.g. (('ADVP', 'PRT'),) -> {'ADVP': 'PRT'}
        if isinstance(equal, tuple):
            equal = dict(equal)
        # Forward every constructor argument (locals plus extra kwargs) to the parent initializers.
        super().__init__(**merge_locals_kwargs(locals(), kwargs))
        self.vocabs = VocabDict()

    # noinspection DuplicatedCode
    def build_dataloader(self,
                         data,
                         transform: Callable = None,
                         training=False,
                         device=None,
                         logger: logging.Logger = None,
                         gradient_accumulation=1,
                         **kwargs) -> DataLoader:
        """Build a padded dataloader over ``data``.

        Args:
            data: Input accepted by ``CRFConstituencyParser.build_dataset``. When it is a
                ``str``, the built dataset is additionally cached via
                ``BiaffineDependencyParser.cache_dataset``.
            transform: Transform passed to the dataset.
            training: ``True`` to shuffle batches. Also passed to the caching routine.
            device: Device handed to :class:`PadSequenceDataLoader`.
            logger: Logger used while building vocabularies and caching.
            gradient_accumulation: Forwarded to the sampler builder.
            **kwargs: Not used.

        Returns:
            A :class:`PadSequenceDataLoader` driven by a batch sampler from ``self.sampler_builder``.
        """
        dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
        # Vocabularies are only (re)built while they are still mutable.
        if self.vocabs.mutable:
            CRFConstituencyParsing.build_vocabs(self, dataset, logger)
        if isinstance(data, str):
            timer = CountdownTimer(len(dataset))
            # noinspection PyCallByClass
            BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
        return PadSequenceDataLoader(
            batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
                                                     gradient_accumulation=gradient_accumulation),
            device=device,
            dataset=dataset)

    def feed_batch(self,
                   h: torch.FloatTensor,
                   batch: Dict[str, torch.Tensor],
                   mask: torch.BoolTensor,
                   decoder: torch.nn.Module):
        """Run ``decoder`` on the encoder states ``h`` and compute the span mask.

        The incoming ``mask`` is ignored. A fresh mask is computed by
        ``CRFConstituencyParser.compute_mask``.

        Returns:
            A dict with keys ``'output'`` (the decoder output) and ``'mask'``.
        """
        decoded = decoder(h)
        # The offset depends on whether gold trees are present ('constituency' key) or the
        # first sentence ends with EOS. Its exact meaning is defined by
        # CRFConstituencyParser.compute_mask.
        offset = 1 if 'constituency' in batch or batch['token'][0][-1] == EOS else -1
        return {
            'output': decoded,
            'mask': CRFConstituencyParser.compute_mask(self, batch, offset=offset),
        }

    def compute_loss(self,
                     batch: Dict[str, Any],
                     output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                     criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
        """Compute the CRF loss against ``batch['chart_id']``.

        The span probabilities returned alongside the loss are stored in
        ``output['span_probs']``, where :meth:`decode_output` can pick them up.
        ``criterion`` is passed to the parser as its ``crf_decoder``.
        """
        out, mask = output['output'], output['mask']
        loss, span_probs = CRFConstituencyParser.compute_loss(self, out, batch['chart_id'], mask, crf_decoder=criterion)
        output['span_probs'] = span_probs
        return loss

    def decode_output(self,
                      output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                      mask: torch.BoolTensor,
                      batch: Dict[str, Any],
                      decoder: torch.nn.Module, **kwargs) -> Union[Dict[str, Any], Any]:
        """Decode predictions from the decoder output.

        Args:
            output: Dict built by :meth:`feed_batch` (``'output'``, ``'mask'``). It may also
                carry ``'span_probs'`` added by :meth:`compute_loss`.
            mask: Ignored. The mask stored in ``output`` is used instead.
            batch: The batch. Its ``'token'`` sentences are passed to the parser with a
                leading ``BOS`` and a trailing ``EOS`` removed when present.
            decoder: The task decoder.
            **kwargs: Not used.
        """
        out, mask = output['output'], output['mask']
        tokens = []
        for sent in batch['token']:
            # Guard empty sentences so that indexing [0]/[-1] cannot raise.
            if sent and sent[0] == BOS:
                sent = sent[1:]
            if sent and sent[-1] == EOS:
                sent = sent[:-1]
            tokens.append(sent)
        return CRFConstituencyParser.decode_output(self, out, mask, batch, output.get('span_probs', None),
                                                   decoder=decoder, tokens=tokens)

    def update_metrics(self,
                       batch: Dict[str, Any],
                       output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                       prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
        """Delegate metric updates to ``CRFConstituencyParser.update_metrics`` (``output`` is unused)."""
        return CRFConstituencyParser.update_metrics(self, metric, batch, prediction)

    def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
        """Build a :class:`CRFConstituencyDecoder` with one output per ``chart`` vocabulary entry."""
        return CRFConstituencyDecoder(n_labels=len(self.vocabs.chart), n_hidden=encoder_size)

    def build_metric(self, **kwargs):
        """Delegate to ``CRFConstituencyParser.build_metric``."""
        return CRFConstituencyParser.build_metric(self)

    def prediction_to_result(self, prediction: List, batch: Dict[str, Any]) -> List:
        """Return ``prediction`` unchanged."""
        return prediction

    def finalize_document(self, doc: Document, task_name: str):
        """Replace placeholder ``'_'`` preterminal labels with POS tags from ``doc``.

        If ``doc`` contains a key matching the ``'pos'`` prefix, each tree stored under
        ``task_name`` is paired with that sentence's tag list. Every preterminal (subtree of
        height 2) labelled ``'_'`` then gets the tag at its position. Trees are modified in place.

        Args:
            doc: Document holding this task's trees under ``task_name``.
            task_name: Key of this task's output in ``doc``.
        """
        pos_key = prefix_match('pos', doc)
        pos: List[List[str]] = doc.get(pos_key, None)
        if not pos:
            return
        for tree, pos_per_sent in zip(doc[task_name], pos):
            tree: Tree = tree
            # Assumes preterminals are yielded in token order, so they line up with
            # pos_per_sent position by position. zip stops at the shorter side.
            preterminals = tree.subtrees(lambda t: t.height() == 2)
            for subtree, tag in zip(preterminals, pos_per_sent):
                if subtree.label() == '_':
                    subtree.set_label(tag)

    def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
        """Delegate to ``CRFConstituencyParser.build_samples``.

        ``cls_is_bos`` and ``sep_is_eos`` are accepted for interface compatibility
        but are not used here.
        """
        return CRFConstituencyParser.build_samples(self, inputs)