Skip to content

Commit

Permalink
chore: Update train_cie.py to improve CIE training
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Aug 1, 2024
1 parent a164550 commit fad23cb
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion relik/reader/trainer/train_cie.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@ def train(cfg: DictConfig) -> None:
# model declaration
model = RelikReaderREPLModule(
cfg=OmegaConf.to_container(cfg),
transformer_model=cfg.model.model.transformer_model,
# transformer_model=cfg.model.model.transformer_model,
additional_special_symbols=len(special_symbols),
additional_special_symbols_types=len(special_symbols_types),
entity_type_loss=True,
add_entity_embedding=True,
training=True,
**cfg.model.model,
)

model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
# optimizer declaration
opt_conf = cfg.model.optimizer
Expand Down

0 comments on commit fad23cb

Please sign in to comment.