class RNNTagger(Tagger):

    def __init__(self, **kwargs) -> None:
        """An old-school tagger using non-contextualized embeddings and RNNs as the context layer.

        Args:
            **kwargs: Predefined config.
        """
        super().__init__(**kwargs)
        self.model: RNNTaggingModel = None
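For readers unfamiliar with the architecture the docstring refers to, the following is a minimal, self-contained sketch in plain PyTorch, not HanLP's actual RNNTaggingModel, of a tagger built from non-contextualized embeddings, a bidirectional RNN context layer, and a per-token classifier. Every name in it is hypothetical and chosen only for illustration.

    import torch
    import torch.nn as nn

    class ToyRNNTaggingModel(nn.Module):
        """Illustrative only: static embeddings -> BiLSTM context layer -> per-token tag scores."""

        def __init__(self, vocab_size: int, embed_dim: int, hidden_size: int, n_tags: int):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
            self.rnn = nn.LSTM(embed_dim, hidden_size, batch_first=True, bidirectional=True)
            self.classifier = nn.Linear(hidden_size * 2, n_tags)

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            # token_ids: [batch, seq_len] -> tag scores: [batch, seq_len, n_tags]
            context, _ = self.rnn(self.embed(token_ids))
            return self.classifier(context)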
    # noinspection PyMethodOverriding
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric,
                              save_dir, logger, patience, **kwargs):
        max_e, max_metric = 0, -1
        criterion = self.build_criterion()
        timer = CountdownTimer(epochs)
        ratio_width = len(f'{len(trn)}/{len(trn)}')
        scheduler = self.build_scheduler(**merge_dict(self.config, optimizer=optimizer, overwrite=True))
        if not patience:
            patience = epochs
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width)
            loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger)
            if scheduler:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(dev_metric.score)
                else:
                    scheduler.step(epoch)
            report_patience = f'Patience: {epoch - max_e}/{patience}'
            # Save the model if it is the best so far.
            if dev_metric > max_metric:
                self.save_weights(save_dir)
                max_e, max_metric = epoch, dev_metric
                report_patience = '[red]Saved[/red] '
            stop = epoch - max_e >= patience
            if stop:
                timer.stop()
            timer.log(f'{report_patience} lr: {optimizer.param_groups[0]["lr"]:.4f}',
                      ratio_percentage=False, newline=True, ratio=False)
            if stop:
                break
        timer.stop()
        if max_e != epoch:
            self.load_weights(save_dir)
        logger.info(f"Max score of dev is {max_metric.score:.2%} at epoch {max_e}")
        logger.info(f"{timer.elapsed_human} elapsed, average time of each epoch is {timer.elapsed_average_human}")
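The loop above follows a standard early-stopping pattern: checkpoint whenever the dev metric improves, count the epochs since the last improvement against patience, and reload the best checkpoint before reporting. The stripped-down sketch below shows just that pattern; the four callables are hypothetical stand-ins for HanLP's fit_dataloader, evaluate_dataloader, save_weights and load_weights, not its real API.

    def train_with_early_stopping(train_one_epoch, evaluate, save, load, epochs: int, patience: int):
        """Generic sketch of the loop above; the four callables are hypothetical stand-ins."""
        best_epoch, best_score = 0, float('-inf')
        for epoch in range(1, epochs + 1):
            train_one_epoch()
            score = evaluate()
            if score > best_score:              # dev metric improved: checkpoint and reset patience
                best_epoch, best_score = epoch, score
                save()
            if epoch - best_epoch >= patience:  # no improvement for `patience` epochs: stop early
                break
        if best_epoch != epoch:                 # restore the best checkpoint, not the last one
            load()
        return best_epoch, best_score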
    def build_dataloader(self, data, batch_size, shuffle, device, logger=None, **kwargs) -> DataLoader:
        vocabs = self.vocabs
        token_embed = self._convert_embed()
        dataset = data if isinstance(data, TransformableDataset) else self.build_dataset(data, transform=[vocabs])
        if vocabs.mutable:
            # Before building vocabs, let embeddings submit their vocabs; some embeddings will possibly opt out
            # as their transforms are not relevant to vocabs.
            if isinstance(token_embed, Embedding):
                transform = token_embed.transform(vocabs=vocabs)
                if transform:
                    dataset.transform.insert(-1, transform)
            self.build_vocabs(dataset, logger)
        if isinstance(token_embed, Embedding):
            # Vocabs built, now add all transforms to the pipeline. Be careful about redundant ones.
            transform = token_embed.transform(vocabs=vocabs)
            if transform and transform not in dataset.transform:
                dataset.transform.insert(-1, transform)
        sampler = SortingSampler([len(sample[self.config.token_key]) for sample in dataset],
                                 batch_size, shuffle=shuffle)
        return PadSequenceDataLoader(dataset, device=device, batch_sampler=sampler, vocabs=vocabs)
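SortingSampler and PadSequenceDataLoader group samples of similar length into the same batch and pad each batch only to its longest sequence, which cuts down wasted padding. A minimal plain-PyTorch sketch of the same idea follows; the helper names are hypothetical, and HanLP's actual classes additionally handle shuffling, vocab lookup and device placement.

    import torch
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    def length_sorted_batches(lengths, batch_size):
        # Sort sample indices by sequence length, then cut them into batches,
        # so every batch contains sequences of similar length (less padding).
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

    def pad_collate(batch):
        # batch: list of (token_id_tensor, tag_id_tensor) pairs; pad both to the batch maximum length.
        tokens, tags = zip(*batch)
        return pad_sequence(tokens, batch_first=True), pad_sequence(tags, batch_first=True)

    # Usage sketch, assuming `dataset[i]` returns a (token_ids, tag_ids) pair of 1-D tensors:
    # loader = DataLoader(dataset,
    #                     batch_sampler=length_sorted_batches([len(x[0]) for x in dataset], 32),
    #                     collate_fn=pad_collate)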