diff --git a/main.py b/main.py
index 1a67f26fa..503dc181f 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,8 @@
 import torch
 import transformers
 from accelerate import Accelerator
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
+from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+                          AutoTokenizer, HfArgumentParser)
 
 from lm_eval.arguments import EvalArguments
 from lm_eval.evaluator import Evaluator
@@ -29,6 +30,9 @@ def __iter__(self):
             yield choice
 
 
+MODEL_CLASSES = {"CausalLM": AutoModelForCausalLM, "Seq2SeqLM": AutoModelForSeq2SeqLM}
+
+
 def parse_args():
     parser = HfArgumentParser(EvalArguments)
 
@@ -37,6 +41,12 @@ def parse_args():
         default="codeparrot/codeparrot-small",
         help="Model to evaluate, provide a repo name in Hugging Face hub or a local path",
     )
+    parser.add_argument(
+        "--model_class",
+        default="CausalLM",
+        choices=["CausalLM", "Seq2SeqLM"],
+        help="AutoModel class used to load the model: AutoModelForCausalLM for decoder-only models, AutoModelForSeq2SeqLM for encoder-decoder models.",
+    )
     parser.add_argument(
         "--revision",
         default=None,
@@ -187,11 +197,14 @@ def main():
         raise ValueError(
             f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16"
         )
+    model_class = MODEL_CLASSES[args.model_class]
+    if accelerator.is_main_process:
+        print(f"The model will be loaded using the class: {model_class.__name__}")
     if args.load_in_8bit:
         print("Loading model in 8bit")
         current_device = accelerator.process_index
         # the model needs to fit in one GPU
-        model = AutoModelForCausalLM.from_pretrained(
+        model = model_class.from_pretrained(
            args.model,
             revision=args.revision,
             load_in_8bit=args.load_in_8bit,
@@ -203,7 +216,7 @@
         print("Loading model in 4bit")
         current_device = accelerator.process_index
         # the model needs to fit in one GPU
-        model = AutoModelForCausalLM.from_pretrained(
+        model = model_class.from_pretrained(
             args.model,
             revision=args.revision,
             load_in_4bit=args.load_in_4bit,
@@ -213,7 +226,7 @@
         )
     else:
         print(f"Loading model in {args.precision}")
-        model = AutoModelForCausalLM.from_pretrained(
+        model = model_class.from_pretrained(
             args.model,
             revision=args.revision,
             torch_dtype=dict_precisions[args.precision],
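
Usage sketch (illustrative, not part of the patch): with the new --model_class flag, an encoder-decoder checkpoint can be evaluated by selecting the Seq2SeqLM class. Salesforce/codet5-base below is just an example of a T5-style model that loads via AutoModelForSeq2SeqLM; all other harness flags are unchanged:

    accelerate launch main.py \
        --model Salesforce/codet5-base \
        --model_class Seq2SeqLM

Existing decoder-only setups keep working without modification, since --model_class defaults to CausalLM and MODEL_CLASSES then resolves to AutoModelForCausalLM as before.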