# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-31 04:16
import json
import re
import warnings
from typing import List, Union
from phrasetree.tree import Tree
from hanlp_common.conll import CoNLLUWord, CoNLLSentence, CoNLLSentenceList
from hanlp_common.constant import PRED, IPYTHON
from hanlp_common.util import collapse_json, prefix_match
from hanlp_common.visualization import tree_to_list, list_to_tree, render_labeled_span, make_table
class Document(dict):
    def __init__(self, *args, **kwargs) -> None:
        r"""A dict structure holding parsed annotations. A document is a subclass of ``dict`` and it supports every
        interface of ``dict``\. Additionally, it supports interfaces to deal with various linguistic structures. Its
        ``str`` and ``dict`` representations are made to be compatible with JSON serialization.

        On construction, a non-empty ``con`` value given as nested lists is converted with ``list_to_tree``, and a
        non-empty ``amr`` value given as triples is converted to ``AMRGraph``. Values which are already trees or
        graphs are left untouched.

        Args:
            *args: An iterator of key-value pairs.
            **kwargs: Arguments from ``**`` operator.

        Examples::

            # Create a document
            doc = Document(
                tok=[["晓美焰", "来到", "北京", "立方庭", "参观", "自然", "语义", "科技", "公司"]],
                pos=[["NR", "VV", "NR", "NR", "VV", "NN", "NN", "NN", "NN"]],
                ner=[[["晓美焰", "PERSON", 0, 1], ["北京立方庭", "LOCATION", 2, 4],
                      ["自然语义科技公司", "ORGANIZATION", 5, 9]]],
                dep=[[[2, "nsubj"], [0, "root"], [4, "name"], [2, "dobj"], [2, "conj"],
                      [9, "compound"], [9, "compound"], [9, "compound"], [5, "dobj"]]]
            )

            # print(doc) or str(doc) to get its JSON representation
            print(doc)

            # Access an annotation by its task name
            print(doc['tok'])

            # Get number of sentences
            print(f'It has {doc.count_sentences()} sentence(s)')

            # Access the n-th sentence
            print(doc.squeeze(0)['tok'])

            # Pretty print it right in your console or notebook
            doc.pretty_print()

            # To save the pretty prints in a str
            pretty_text: str = '\n\n'.join(doc.to_pretty())

        """
        super().__init__(*args, **kwargs)
        # Iterate over a snapshot because converted values are written back into ``self``.
        for k, v in list(self.items()):
            if not v:
                continue
            if k == 'con':
                if isinstance(v, Tree) or isinstance(v[0], Tree):
                    continue
                # A single tree in list form starts with its label (a str).
                flat = isinstance(v[0], str)
                if flat:
                    v = [v]
                # Keep elements which are already trees so that sentences stay aligned.
                ls = [each if isinstance(each, Tree) else list_to_tree(each) for each in v]
                if flat:
                    ls = ls[0]
                self[k] = ls
            elif k == 'amr':
                # Imported lazily so that these packages are only needed when AMR annotations are present.
                from hanlp_common.amr import AMRGraph
                import penman
                if isinstance(v, AMRGraph) or isinstance(v[0], AMRGraph):
                    continue
                # A single graph is a list of triples, whose first triple starts with a str.
                flat = isinstance(v[0][0], str)
                if flat:
                    v = [v]
                graphs = [AMRGraph(penman.Graph(triples)) for triples in v]
                if flat:
                    graphs = graphs[0]
                self[k] = graphs

    def to_json(self, ensure_ascii=False, indent=2) -> str:
        """Convert to json string.

        Args:
            ensure_ascii: ``False`` to allow for non-ascii text.
            indent: Indent per nested structure.

        Returns:
            A text representation in ``str``. Objects which ``json`` cannot serialize are replaced by their ``repr``.

        """
        d = self.to_dict()
        text = json.dumps(d, ensure_ascii=ensure_ascii, indent=indent, default=repr)
        text = collapse_json(text, 4)
        return text

    def to_dict(self):
        """Convert to a json compatible dict.

        Only the ``con`` field is transformed: trees are converted with ``tree_to_list``. Every other value is
        shared with this document (shallow copy).

        Returns:
            A dict representation.

        """
        d = dict(self)
        for k, v in self.items():
            if v == [] or v is None:
                continue
            if k == 'con':
                if not isinstance(v, Tree) and not isinstance(v[0], Tree):
                    continue
                flat = isinstance(v, Tree)
                if flat:
                    v = [v]
                # Keep elements which are not trees as they are so that sentences stay aligned.
                ls = [tree_to_list(each) if isinstance(each, Tree) else each for each in v]
                if flat:
                    ls = ls[0]
                d[k] = ls
        return d

    def __str__(self) -> str:
        return self.to_json()

    def to_conll(self, tok='tok', lem='lem', pos='pos', xpos='pos/xpos', fea='fea', dep='dep', sdp='sdp') -> Union[
        CoNLLSentence, List[CoNLLSentence]]:
        """
        Convert to :class:`~hanlp_common.conll.CoNLLSentence`.

        Args:
            tok (str): Field name for tok.
            lem (str): Field name for lem.
            pos (str): Field name for upos.
            xpos (str): Field name for xpos.
            fea (str): Field name for feats.
            dep (str): Field name for dependency parsing.
            sdp (str): Field name for semantic dependency parsing.

        Returns:
            A :class:`~hanlp_common.conll.CoNLLSentence` representation. A single sentence is returned for a flat
            document, otherwise a list of sentences. An empty list is returned when there are no tokens.

        """
        tok = prefix_match(tok, self)
        lem = prefix_match(lem, self)
        pos = prefix_match(pos, self)
        xpos = prefix_match(xpos, self)
        fea = prefix_match(fea, self)
        dep = prefix_match(dep, self)
        sdp = prefix_match(sdp, self)
        results = CoNLLSentenceList()
        if not tok or not self[tok]:
            return results
        doc = self._to_doc_without_spans(tok)
        flat = isinstance(doc[tok][0], str)
        # Collect sentence-aligned columns. A flat document is wrapped into a one-sentence batch. In a nested
        # document, fields which are not sequences (e.g. ``None`` or a number) cannot be aligned with sentences
        # and are skipped instead of breaking ``zip``.
        keys = []
        columns = []
        for key, value in doc.items():
            if flat:
                value = [value]
            elif not isinstance(value, (list, tuple)):
                continue
            keys.append(key)
            columns.append(value)

        def get(_sample, _k, _i):
            # ``_k`` is None when no field matches the requested name.
            _v = _sample.get(_k, None)
            if not _v:
                return None
            return _v[_i]

        for per_sentence in zip(*columns):
            sample = dict(zip(keys, per_sentence))
            sent = CoNLLSentence()
            for i, _tok in enumerate(sample[tok]):
                _dep = get(sample, dep, i)
                if not _dep:
                    _dep = (None, None)
                _sdp = get(sample, sdp, i)
                sent.append(
                    CoNLLUWord(i + 1, form=_tok, lemma=get(sample, lem, i), upos=get(sample, pos, i),
                               xpos=get(sample, xpos, i), feats=get(sample, fea, i), head=_dep[0], deprel=_dep[1],
                               deps=None if not _sdp else '|'.join(f'{x[0]}:{x[1]}' for x in _sdp)))
            results.append(sent)
        if flat:
            return results[0]
        return results

    def to_pretty(self, tok='tok', lem='lem', pos='pos', dep='dep', sdp='sdp', ner='ner', srl='srl', con='con',
                  show_header=True, html=False) -> Union[str, List[str]]:
        """
        Convert to a pretty text representation which can be printed to visualize linguistic structures.

        Args:
            tok: Token key.
            lem: Lemma key.
            pos: Part-of-speech key.
            dep: Dependency parse tree key.
            sdp: Semantic dependency tree/graph key. SDP visualization has not been implemented yet.
            ner: Named entity key.
            srl: Semantic role labeling key.
            con: Constituency parsing key.
            show_header: ``True`` to include a header which indicates each field with its name.
            html: ``True`` to output HTML format so that non-ASCII characters can align correctly.

        Returns:
            A pretty string for a flat document, otherwise a list of pretty strings, one per sentence.

        """
        results = []
        tok = prefix_match(tok, self)
        pos = prefix_match(pos, self)
        ner = prefix_match(ner, self)
        # Resolve these keys the same way ``to_conll`` does so both agree on which field is used. ``con`` is kept
        # exact because only the exact ``con`` key is converted to trees on construction.
        dep = prefix_match(dep, self)
        srl = prefix_match(srl, self)
        conlls = self.to_conll(tok=tok, lem=lem, pos=pos, dep=dep, sdp=sdp)
        flat = isinstance(conlls, CoNLLSentence)
        if flat:
            conlls = [conlls]

        def condense(block_, extras_=None):
            # Render the block as a table, then merge every column after the first into one text column.
            text_ = make_table(block_, insert_header=False)
            text_ = [x.split('\t', 1) for x in text_.split('\n')]
            text_ = [[x[0], x[1].replace('\t', '')] for x in text_]
            if extras_:
                for r, s in zip(extras_, text_):
                    r.extend(s)
            return text_

        for i, conll in enumerate(conlls):
            tokens = [x.form for x in conll]
            length = len(conll)
            # Row 0 is the header, rows 1..length hold one token each.
            extras = [[] for _ in range(length + 1)]
            if ner in self:
                ner_samples = self[ner]
                if flat:
                    ner_samples = [ner_samples]
                ner_per_sample = ner_samples[i]
                # For nested NER, use the longest span
                start_offsets = [None for _ in range(length)]
                for ent, label, b, e in ner_per_sample:
                    if not start_offsets[b] or e > start_offsets[b][-1]:
                        start_offsets[b] = (ent, label, b, e)
                ner_per_sample = [y for y in start_offsets if y]
                header = ['Token', 'NER', 'Type']
                block = [[] for _ in range(length + 1)]
                _ner = []
                _type = []
                offset = 0
                for ent, label, b, e in ner_per_sample:
                    render_labeled_span(b, e, _ner, _type, label, offset)
                    offset = e
                if offset != length:
                    _ner.extend([''] * (length - offset))
                    _type.extend([''] * (length - offset))
                if any(_type):
                    block[0].extend(header)
                    for j, (_s, _t) in enumerate(zip(_ner, _type)):
                        block[j + 1].extend((tokens[j], _s, _t))
                    condense(block, extras)

            if srl in self:
                srl_samples = self[srl]
                if flat:
                    srl_samples = [srl_samples]
                srl_per_sample = srl_samples[i]
                for k, pas in enumerate(srl_per_sample):
                    if not pas:
                        continue
                    block = [[] for _ in range(length + 1)]
                    header = ['Token', 'SRL', f'PA{k + 1}']
                    _srl = []
                    _type = []
                    offset = 0
                    p_index = None
                    for _, label, b, e in pas:
                        render_labeled_span(b, e, _srl, _type, label, offset)
                        offset = e
                        if label == PRED:
                            p_index = b
                    if len(_srl) != length:
                        _srl.extend([''] * (length - offset))
                        _type.extend([''] * (length - offset))
                    if p_index is not None:
                        _srl[p_index] = '╟──►'
                        if len(block) != len(_srl) + 1:
                            # Overlapped spans cannot be visualized, skip this predicate-argument structure.
                            continue
                        block[0].extend(header)
                        while len(_srl) < length:
                            _srl.append('')
                        while len(_type) < length:
                            _type.append('')
                        for j, (_s, _t) in enumerate(zip(_srl, _type)):
                            block[j + 1].extend((tokens[j], _s, _t))
                        condense(block, extras)

            if con in self:
                con_samples = self[con]
                if flat:
                    con_samples = [con_samples]
                tree = con_samples[i]
                block = [[] for _ in range(length + 1)]
                block[0].extend(('Token', 'PoS'))
                for j, t in enumerate(tree.pos()):
                    block[j + 1].extend(t)

                # Draw one pair of columns (arrow, label) per tree level, from the bottom up.
                for height in range(2, tree.height() + (0 if len(tree) == 1 else 1)):
                    offset = 0
                    spans = []
                    labels = []
                    for subtree in tree.subtrees(lambda x: x.height() == height):
                        b, e = offset, offset + len(subtree.leaves())
                        if height >= 3:
                            # Attach to the centers of the first and last child, which were recorded while
                            # processing a lower level.
                            b, e = subtree[0].center, subtree[-1].center + 1
                        subtree.center = b + (e - b) // 2
                        render_labeled_span(b, e, spans, labels, subtree.label(), offset, unidirectional=True)
                        offset = e
                    if len(spans) != length:
                        spans.extend([''] * (length - len(spans)))
                    if len(labels) != length:
                        labels.extend([''] * (length - len(labels)))
                    if height < 3:
                        continue
                    block[0].extend(['', f'{height}'])
                    for j, (_s, _t) in enumerate(zip(spans, labels)):
                        block[j + 1].extend((_s, _t))
                    # check short arrows and increase their length
                    for j, arrow in enumerate(spans):
                        if not arrow:
                            # -1 current tag ; -2 arrow to current tag ; -3 = prev tag ; -4 = arrow to prev tag
                            if block[j + 1][-3] or block[j + 1][-4] == '───►':
                                if height > 3:
                                    if block[j + 1][-3]:
                                        block[j + 1][-1] = block[j + 1][-3]
                                        block[j + 1][-2] = '───►'
                                    else:
                                        block[j + 1][-1] = '────'
                                        block[j + 1][-2] = '────'
                                    block[j + 1][-3] = '────'
                                    if block[j + 1][-4] == '───►':
                                        block[j + 1][-4] = '────'
                                else:
                                    block[j + 1][-1] = '────'
                                if block[j + 1][-1] == '────':
                                    block[j + 1][-2] = '────'
                                    if not block[j + 1][-4]:
                                        block[j + 1][-4] = '────'
                # If the root label is shorter than the level number, extend it to the same length
                level_len = len(block[0][-1])
                for row in block[1:]:
                    if row[-1] and len(row[-1]) < level_len:
                        row[-1] = row[-1] + ' ' * (level_len - len(row[-1]))

                text = condense(block)
                # Cosmetic issues
                for row in text[1:]:
                    # Turn padding in front of a line into line, one space per pass. The pattern must be TWO
                    # spaces: with a single space the replacement would contain the pattern again and this loop
                    # would never terminate.
                    while '  ─' in row[1]:
                        row[1] = row[1].replace('  ─', ' ──')
                    row[1] = row[1].replace('─ ─', '───')
                    # Extend a line across the gap in front of a vertical bar and draw a junction there.
                    row[1] = re.sub(
                        r'([►─])([\w-]*)(\s+)([│├])',
                        lambda m: f'{m.group(1)}{m.group(2)}{"─" * len(m.group(3))}'
                                  f'{"┤" if m.group(4) == "│" else "┼"}',
                        row[1])
                    row[1] = re.sub(r'►(─+)►', r'─\1►', row[1])
                for r, s in zip(extras, text):
                    r.extend(s)
            # Non-projective trees cannot be visualized, so they fall back to the table or plain form.
            if dep in self and conll.projective:
                text = conll.to_tree(extras, main_pos=True)
                if not show_header:
                    text = text.split('\n')
                    text = '\n'.join(text[2:])
                results.append(text)
            elif any(extras):
                results.append(make_table(extras, insert_header=True))
            else:
                results.append(' '.join(['/'.join(str(f) for f in x.nonempty_fields) for x in conll]))
        if html:
            # Imported here as ``escape`` only, since the module name ``html`` is also the name of the flag.
            from html import escape

            def to_html(pretty_text: str) -> str:
                cells = [line.split('\t') for line in pretty_text.split('\n') if line]
                num_cols = len(cells[0]) if cells else 0
                # Transpose rows into columns; each column becomes one table cell so that it aligns by itself.
                cols = [[row[c] if c < len(row) else '' for row in cells] for c in range(num_cols)]
                markup = '<div style="display: table; padding-bottom: 1rem;">'
                for c, each in enumerate(cols):
                    markup += '<pre style="display: table-cell; font-family: SFMono-Regular,Menlo,Monaco,Consolas,' \
                              'Liberation Mono,Courier New,monospace; white-space: nowrap; line-height: 128%; ' \
                              'padding: 0;">'
                    if c != len(cols) - 1:
                        each = [x + ' ' for x in each]
                    # Escape markup characters found in the text, then protect spaces from being collapsed by
                    # the browser.
                    markup += '<br>'.join([escape(x, quote=False).replace(' ', '&nbsp;') for x in each])
                    markup += '</pre>'
                markup += '</div>'
                return markup

            results = [to_html(x) for x in results]
        if flat:
            return results[0]
        return results

    def pretty_print(self, tok='tok', lem='lem', pos='pos', dep='dep', sdp='sdp', ner='ner', srl='srl', con='con',
                     show_header=True, html=IPYTHON):
        """
        Print a pretty text representation which visualizes linguistic structures.

        Args:
            tok: Token key.
            lem: Lemma key.
            pos: Part-of-speech key.
            dep: Dependency parse tree key.
            sdp: Semantic dependency tree/graph key. SDP visualization has not been implemented yet.
            ner: Named entity key.
            srl: Semantic role labeling key.
            con: Constituency parsing key.
            show_header: ``True`` to print a header which indicates each field with its name.
            html: ``True`` to output HTML format so that non-ASCII characters can align correctly.

        """
        results = self.to_pretty(tok, lem, pos, dep, sdp, ner, srl, con, show_header, html=html)
        if isinstance(results, str):
            results = [results]
        if html and IPYTHON:
            # ``IPython.display`` is the public location of these names.
            from IPython.display import display, HTML
            display(HTML('<br>'.join(results)))
        else:
            # Multi-line visualizations are separated by a blank line.
            sent_new_line = '\n\n' if any('\n' in x for x in results) else '\n'
            print(sent_new_line.join(results))

    def translate(self, lang, tok='tok', pos='pos', dep='dep', sdp='sdp', ner='ner', srl='srl'):
        """
        Translate tags for each annotation. This is an inplace operation.

        .. Attention:: Note that the translated document might not print well in terminal due to non-ASCII characters.

        Args:
            lang: Target language to be translated to.
            tok: Token key.
            pos: Part-of-speech key.
            dep: Dependency parse tree key.
            sdp: Semantic dependency tree/graph key. SDP visualization has not been implemented yet.
            ner: Named entity key.
            srl: Semantic role labeling key.

        Returns:
            The translated document.

        Raises:
            NotImplementedError: If there is no translation for ``lang``.
            KeyError: If no field matches ``tok``.

        """
        if lang == 'zh':
            from hanlp.utils.lang.zh import localization
        else:
            raise NotImplementedError(f'No translation for {lang}. '
                                      f'Please contribute to our translation at https://github.com/hankcs/HanLP')
        tok_key = prefix_match(tok, self)
        if tok_key is None:
            raise KeyError(tok)
        tokens = self[tok_key]
        if not tokens:
            # Without tokens there is no way to tell the nesting of annotations.
            return self
        flat, _ = self._token_layout(tokens)
        for task, name in zip(['pos', 'ner', 'dep', 'sdp', 'srl'], [pos, ner, dep, sdp, srl]):
            key = prefix_match(name, self)
            annotations = self.get(key, None)
            if not annotations:
                continue
            if flat:
                annotations = [annotations]
            # Prefer a table named after the field, fall back to the table of the task so that customized field
            # names (e.g. ``pos/ctb``) are translated too.
            table = getattr(localization, name, None) or getattr(localization, task, None)
            if not table:
                continue
            for anno_per_sent in annotations:
                for i, v in enumerate(anno_per_sent):
                    if not v:
                        continue
                    if isinstance(v, str):
                        # A bare tag, e.g. part-of-speech.
                        anno_per_sent[i] = table.get(v, v)
                    elif len(v) > 1 and isinstance(v[1], str):
                        # A labeled item whose label is its second element, e.g. an entity or a dependency arc.
                        self._translate_label(anno_per_sent, i, table)
                    else:
                        # A group of labeled items, e.g. the arcs of one token in a semantic dependency graph
                        # or the arguments of one predicate.
                        for j, item in enumerate(v):
                            if len(item) > 1:
                                self._translate_label(v, j, table)
        return self

    @staticmethod
    def _translate_label(items, index, table):
        """
        Translate the label stored as the second element of ``items[index]``.

        Args:
            items: A mutable sequence of labeled items.
            index: The position of the item to translate.
            table: A dict mapping labels to their translations. Unknown labels are kept.

        """
        item = items[index]
        label = table.get(item[1], item[1])
        if isinstance(item, tuple):
            # Tuples are immutable, so replace the whole item.
            items[index] = item[:1] + (label,) + item[2:]
        else:
            item[1] = label

    def squeeze(self, i=0):
        r"""
        Squeeze the dimension of each field into one. It's intended to convert a nested document like ``[[sent_i]]``
        to ``[sent_i]``. When there are multiple sentences, only the ``i-th`` one will be returned. Note this is not an
        inplace operation.

        Args:
            i: Keep the element at ``index`` for all ``list``\s.

        Returns:
            A squeezed document with only one sentence.

        """
        sq = Document()
        for k, v in self.items():
            sq[k] = v[i] if isinstance(v, list) else v
        return sq

    @staticmethod
    def _token_layout(tokens):
        """
        Tell how a token field is organized.

        Args:
            tokens: The value of a token field.

        Returns:
            A pair ``(flat, has_spans)``. ``flat`` is ``True`` when ``tokens`` holds one sentence instead of a list
            of sentences. ``has_spans`` is ``True`` when each token is a sequence starting with its text instead of
            a plain ``str``.

        """
        if not tokens:
            return False, False
        first = tokens[0]
        if isinstance(first, str):
            return True, False
        if first and isinstance(first[-1], int):
            # ``first`` is a single token with offsets attached, which ends with an int.
            return True, True
        # ``first`` is a sentence. It might be empty, so inspect the first non-empty sentence.
        sample = next((sent for sent in tokens if sent), None)
        has_spans = sample is not None and not isinstance(sample[-1], str)
        return False, has_spans

    def _to_doc_without_spans(self, tok: str):
        """
        Remove the spans attached to tokens and return a new document.

        Args:
            tok: The key to tokens.

        Returns:
            A new document or itself.

        """
        tokens = self[tok]
        flat, has_spans = self._token_layout(tokens)
        if not has_spans:
            return self
        if flat:
            tokens = [x[0] for x in tokens]
        else:
            tokens = [[t[0] for t in x] for x in tokens]
        # Copy by passing the mapping itself, which also works for keys that are not valid keyword names.
        d = Document(self)
        d[tok] = tokens
        return d

    def get_by_prefix(self, prefix: str):
        """
        Get value by the prefix of a key.

        Args:
            prefix: The prefix of a key. If multiple keys are matched, only the first one will be used.

        Returns:
            The value assigned with the matched key, or ``None`` if no key matches.

        """
        key = prefix_match(prefix, self)
        if not key:
            return None
        return self[key]

    def count_sentences(self) -> int:
        """
        Count number of sentences in this document.

        Returns:
            Number of sentences, which is ``0`` when the document has no tokens.

        """
        tok = self.get_by_prefix('tok')
        if not tok:
            return 0
        flat, _ = self._token_layout(tok)
        return 1 if flat else len(tok)