From fad23cb09a79f2221bebafda8901ef7d21b4d51b Mon Sep 17 00:00:00 2001 From: Riccardo Orlando Date: Thu, 1 Aug 2024 13:56:29 +0000 Subject: [PATCH] chore: Update train_cie.py to improve CIE training --- relik/reader/trainer/train_cie.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/relik/reader/trainer/train_cie.py b/relik/reader/trainer/train_cie.py index 77d065d..6801384 100644 --- a/relik/reader/trainer/train_cie.py +++ b/relik/reader/trainer/train_cie.py @@ -55,13 +55,15 @@ def train(cfg: DictConfig) -> None: # model declaration model = RelikReaderREPLModule( cfg=OmegaConf.to_container(cfg), - transformer_model=cfg.model.model.transformer_model, + # transformer_model=cfg.model.model.transformer_model, additional_special_symbols=len(special_symbols), additional_special_symbols_types=len(special_symbols_types), entity_type_loss=True, add_entity_embedding=True, training=True, + **cfg.model.model, ) + model.relik_reader_re_model._tokenizer = train_dataset.tokenizer # optimizer declaration opt_conf = cfg.model.optimizer