# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2021-05-20 17:03
import logging
from typing import Union, List
import torch
from torch.utils.data import DataLoader
from hanlp.common.structure import History
from hanlp.layers.transformers.pt_imports import AutoConfig_, AutoTokenizer_
from transformers import AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from hanlp.common.dataset import SortingSamplerBuilder, PadSequenceDataLoader
from hanlp.common.torch_component import TorchComponent
from hanlp.datasets.sts.stsb import SemanticTextualSimilarityDataset
from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
from hanlp.metrics.spearman_correlation import SpearmanCorrelation
from hanlp.transform.transformer_tokenizer import TransformerTextTokenizer
from hanlp.utils.time_util import CountdownTimer
from hanlp_common.util import merge_locals_kwargs, reorder
from hanlp_common.constant import IDX


class TransformerSemanticTextualSimilarity(TorchComponent):
    def __init__(self, **kwargs) -> None:
        """
        A simple Semantic Textual Similarity (STS) baseline which fine-tunes a transformer with a regression layer
        on top of it.

        Args:
            **kwargs: Predefined config.
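
        Examples:
            A minimal usage sketch; the file paths, transformer name and column names below are
            hypothetical and depend on your data format::

                sts = TransformerSemanticTextualSimilarity()
                sts.fit('sts-train.tsv', 'sts-dev.tsv', 'sts-model',
                        transformer='bert-base-uncased',
                        sent_a_col='sentence1', sent_b_col='sentence2', similarity_col='score')
                sts.predict(['A man is playing a guitar.', 'A person plays the guitar.'])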
"""
super().__init__(**kwargs)
self._tokenizer = None

    # noinspection PyMethodOverriding
    def build_dataloader(self, data, batch_size, sent_a_col=None,
                         sent_b_col=None,
                         similarity_col=None,
                         delimiter='auto',
                         gradient_accumulation=1,
                         sampler_builder=None,
                         shuffle=False, device=None, logger: logging.Logger = None,
                         split=None,
                         **kwargs) -> DataLoader:
        dataset = SemanticTextualSimilarityDataset(data,
                                                   sent_a_col,
                                                   sent_b_col,
                                                   similarity_col,
                                                   delimiter=delimiter,
                                                   transform=self._tokenizer,
                                                   cache=isinstance(data, str))
        if split == 'trn':
            # Record the training-set score range; decode() clips predictions to it.
            scores = [x['similarity'] for x in dataset]
            self.config.max_score = max(scores)
            self.config.min_score = min(scores)
        if not sampler_builder:
            sampler_builder = SortingSamplerBuilder(batch_size=batch_size)
        lens = [len(x['input_ids']) for x in dataset]
        return PadSequenceDataLoader(dataset,
                                     batch_sampler=sampler_builder.build(lens, shuffle, gradient_accumulation),
                                     device=device,
                                     pad={'similarity': 0.0, 'input_ids': self._tokenizer.tokenizer.pad_token_id})

    def build_optimizer(self, trn, epochs, gradient_accumulation=1, lr=1e-3, transformer_lr=5e-5, adam_epsilon=1e-8,
                        weight_decay=0.0, warmup_steps=0.1, **kwargs):
        num_training_steps = len(trn) * epochs // gradient_accumulation
        optimizer, scheduler = build_optimizer_scheduler_with_transformer(self.model,
                                                                          self.model.base_model,
                                                                          lr, transformer_lr,
                                                                          num_training_steps, warmup_steps,
                                                                          weight_decay, adam_epsilon)
        return optimizer, scheduler

    def build_criterion(self, **kwargs):
        # No standalone criterion: with num_labels=1, the transformers model computes an MSE loss
        # internally whenever labels are fed in.
        pass

    def build_metric(self, **kwargs):
        return SpearmanCorrelation()

    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                              logger: logging.Logger, devices, ratio_width=None, gradient_accumulation=1, **kwargs):
        best_epoch, best_metric = 0, -1
        timer = CountdownTimer(epochs)
        history = History()
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width,
                                gradient_accumulation=gradient_accumulation, history=history, save_dir=save_dir)
            report = f'{timer.elapsed_human}/{timer.total_time_human}'
            self.evaluate_dataloader(dev, logger, ratio_width=ratio_width, save_dir=save_dir, metric=metric)
            if metric > best_metric:
                self.save_weights(save_dir)
                best_metric = float(metric)
                best_epoch = epoch
                report += ' [red]saved[/red]'
            timer.log(report, ratio_percentage=False, newline=True, ratio=False)
        if best_epoch and best_epoch != epochs:
            logger.info(f'Restored the best model with {best_metric} saved {epochs - best_epoch} epochs ago')
            self.load_weights(save_dir)

    def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric: SpearmanCorrelation,
                       logger: logging.Logger, history=None, gradient_accumulation=1, **kwargs):
        self.model.train()
        optimizer, scheduler = optimizer
        timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
        total_loss = 0
        metric.reset()
        for batch in trn:
            output = self.feed_batch(batch)
            prediction = self.decode(output)
            metric(prediction, batch['similarity'])
            loss = output['loss']
            if gradient_accumulation and gradient_accumulation > 1:
                # Scale the loss so gradients accumulated across several batches average out.
                loss /= gradient_accumulation
            loss.backward()
            total_loss += loss.item()
            # Only step the optimizer once every `gradient_accumulation` batches.
            if history.step(gradient_accumulation):
                if self.config.grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
                optimizer.step()
                if scheduler:
                    scheduler.step()
                optimizer.zero_grad()
                timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                          logger=logger)
            del loss
        return total_loss / timer.total

    @torch.no_grad()
    def evaluate_dataloader(self, data: DataLoader, logger: logging.Logger, metric=None, output=False, **kwargs):
        self.model.eval()
        timer = CountdownTimer(len(data))
        total_loss = 0
        metric.reset()
        if output:
            # `output` doubles as a flag and a file path: when truthy, predictions are written to it.
            predictions = []
            orders = []
            samples = []
        for batch in data:
            output_dict = self.feed_batch(batch)
            prediction = self.decode(output_dict)
            metric(prediction, batch['similarity'])
            if output:
                predictions.extend(prediction.tolist())
                orders.extend(batch[IDX])
                samples.extend(list(zip(batch['sent_a'], batch['sent_b'])))
            loss = output_dict['loss']
            total_loss += loss.item()
            timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                      logger=logger)
            del loss
        if output:
            # Restore the original sample order before writing, since batches may be sorted by length.
            predictions = reorder(predictions, orders)
            samples = reorder(samples, orders)
            with open(output, 'w') as out:
                for s, p in zip(samples, predictions):
                    out.write('\t'.join(s + (str(p),)))
                    out.write('\n')
        return total_loss / timer.total

    # noinspection PyMethodOverriding
    def build_model(self, transformer, training=True, **kwargs) -> torch.nn.Module:
        # num_labels=1 turns the sequence classification head into a single-output regression layer.
        config = AutoConfig_.from_pretrained(transformer, num_labels=1)
        if training:
            model = AutoModelForSequenceClassification.from_pretrained(transformer, config=config)
        else:
            # At inference time, weights are loaded from a checkpoint later, so skip downloading pretrained ones.
            model = AutoModelForSequenceClassification.from_config(config)
        return model

    def predict(self, data: Union[List[str], List[List[str]]], batch_size: int = None,
                **kwargs) -> Union[float, List[float]]:
        """Predict the similarity between sentence pairs.

        Args:
            data: A single pair of sentences ``[sent_a, sent_b]``, or a list of such pairs.
            batch_size: The number of samples in a batch.
            **kwargs: Not used.

        Returns:
            A single similarity score, or a list of scores, matching the shape of ``data``.
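
        Examples:
            Both calling conventions, on made-up sentences::

                sts.predict(['A man is eating food.', 'A man is eating.'])      # -> a single float
                sts.predict([['A man is eating food.', 'A man is eating.'],
                             ['A woman is dancing.', 'A man is sleeping.']])    # -> a list of floats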
"""
if not data:
return []
flat = isinstance(data[0], str)
if flat:
data = [data]
dataloader = self.build_dataloader([{'sent_a': x[0], 'sent_b': x[1]} for x in data],
batch_size=batch_size or self.config.batch_size,
device=self.device)
orders = []
predictions = []
for batch in dataloader:
output_dict = self.feed_batch(batch)
prediction = self.decode(output_dict)
predictions.extend(prediction.tolist())
orders.extend(batch[IDX])
predictions = reorder(predictions, orders)
if flat:
return predictions[0]
return predictions

    # noinspection PyMethodOverriding
    def fit(self, trn_data, dev_data, save_dir,
            transformer,
            sent_a_col,
            sent_b_col,
            similarity_col,
            delimiter='auto',
            batch_size=32,
            max_seq_len=128,
            epochs=3,
            lr=1e-3,
            transformer_lr=5e-5,
            adam_epsilon=1e-8,
            weight_decay=0.0,
            warmup_steps=0.1,
            gradient_accumulation=1,
            grad_norm=1.0,
            sampler_builder=None,
            devices=None,
            logger=None,
            seed=None,
            finetune: Union[bool, str] = False, eval_trn=True, _device_placeholder=False, **kwargs):
        return super().fit(**merge_locals_kwargs(locals(), kwargs))

    def on_config_ready(self, transformer, max_seq_len, **kwargs):
        super().on_config_ready(**kwargs)
        self._tokenizer = TransformerTextTokenizer(AutoTokenizer_.from_pretrained(transformer),
                                                   text_a_key='sent_a',
                                                   text_b_key='sent_b',
                                                   output_key='',
                                                   max_seq_length=max_seq_len)

    def feed_batch(self, batch) -> SequenceClassifierOutput:
        return self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                          token_type_ids=batch['token_type_ids'], labels=batch.get('similarity', None))

    def decode(self, output: SequenceClassifierOutput):
        # Clip regression outputs to the score range observed on the training set (see build_dataloader).
        return output.logits.squeeze(-1).detach().clip(self.config.min_score, self.config.max_score)

    def report_metrics(self, loss, metric):
        return f'loss: {loss:.4f} {metric}'
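

# A minimal end-to-end sketch of how this component might be trained and used.
# The file paths, transformer name and column names below are illustrative assumptions,
# not part of this module or any published recipe.
if __name__ == '__main__':
    sts = TransformerSemanticTextualSimilarity()
    sts.fit('data/sts/train.tsv',  # hypothetical TSV of sentence pairs with gold scores
            'data/sts/dev.tsv',
            'data/model/sts',
            transformer='bert-base-uncased',
            sent_a_col='sentence1', sent_b_col='sentence2', similarity_col='score')
    print(sts.predict(['A man is playing a guitar.', 'A person plays the guitar.']))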