diff --git a/model/model_training/utils/utils.py b/model/model_training/utils/utils.py
index 677da36a22..0397822d3c 100644
--- a/model/model_training/utils/utils.py
+++ b/model/model_training/utils/utils.py
@@ -330,15 +330,16 @@ def get_model(conf, tokenizer, pad_vocab_size_to_multiple_of=16, check_freeze_la
     )
 
     n_embs = model.get_input_embeddings().num_embeddings
-    if len(tokenizer) != n_embs and check_freeze_layer:
-        assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
-
     if len(tokenizer) != n_embs or pad_vocab_size_to_multiple_of:
         p = pad_vocab_size_to_multiple_of
         target_size = len(tokenizer) if not p else math.ceil(len(tokenizer) / p) * p
         print("Resizing embeddings to", target_size)
         model.resize_token_embeddings(target_size)
 
+    new_n_embs = model.get_input_embeddings().num_embeddings
+    if new_n_embs != n_embs and check_freeze_layer:
+        assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
+
     if conf.freeze_layer:
         model = freeze_top_n_layers(model, conf.freeze_layer)
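
The patch moves the frozen-model assertion to after the resize and compares the embedding count before and after, so the assertion only fires when the embedding table actually changed size. For reference, here is a minimal standalone sketch of the resize-then-check logic using Hugging Face `transformers`; the `gpt2` checkpoint and the added special token are illustrative choices, not part of the patch:

```python
import math

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model and token; the patch itself is model-agnostic.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"sep_token": "<|sep|>"})  # vocab grows to 50258
model = AutoModelForCausalLM.from_pretrained("gpt2")

n_embs = model.get_input_embeddings().num_embeddings  # 50257 for gpt2

p = 16  # pad_vocab_size_to_multiple_of
# Round the tokenizer length up to the next multiple of p, e.g.
# ceil(50258 / 16) * 16 = 50272; padded embedding sizes tend to map
# better onto GPU kernels than odd vocabulary sizes.
target_size = len(tokenizer) if not p else math.ceil(len(tokenizer) / p) * p
model.resize_token_embeddings(target_size)

# As in the patch: measure again *after* resizing, so a model whose
# embedding count did not actually change never trips the freeze check.
new_n_embs = model.get_input_embeddings().num_embeddings
assert new_n_embs == target_size
```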