diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 609f0f06c4..dfbbedf77d 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -48,10 +48,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     """
 
     def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], gen_kwargs: Optional[Dict[str, Any]], **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        self.gen_kwargs = gen_kwargs
 
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
@@ -102,7 +103,7 @@ def prediction_step(
                 inputs["labels"] = inputs["labels"][:, :prompt_len]
 
         loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **self.gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
             generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 43a9aef16f..2e566e03fe 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -73,10 +73,17 @@ def run_sft(
         metric_module["compute_metrics"] = ComputeAccuracy()
         metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
 
+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
+        gen_kwargs=gen_kwargs,
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
@@ -85,12 +92,6 @@ def run_sft(
         **metric_module,
     )
 
-    # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
-    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
-    gen_kwargs["logits_processor"] = get_logits_processor()
-
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
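
For context only (not part of the patch), below is a minimal sketch of the call flow after this change. It reuses the names that appear in the diff (generating_args, tokenizer, get_logits_processor, CustomSeq2SeqTrainer); eval_dataset is a hypothetical placeholder for whatever dataset the surrounding run_sft workflow already supplies, and the trainer call is trimmed to the arguments shown in the diff.

    # Sketch, assuming the rest of run_sft is unchanged.
    # Generation kwargs are now assembled *before* the trainer is constructed ...
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    # ... and handed to the trainer, which stores them as self.gen_kwargs in __init__.
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        gen_kwargs=gen_kwargs,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
    )

    # During evaluation/prediction with predict_with_generate, prediction_step forwards
    # the stored kwargs to the parent Seq2SeqTrainer:
    #     super().prediction_step(model, inputs, ..., **self.gen_kwargs)
    # so callers no longer need to thread gen_kwargs through trainer.predict()/evaluate().
    predict_results = trainer.predict(eval_dataset, metric_key_prefix="predict")  # eval_dataset: placeholder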