From 1511fc1d7fc9c72b635c6433f2de6c0e2785bf74 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Fri, 6 Sep 2024 16:54:12 -0700
Subject: [PATCH] [ExecuTorch] Allow setting dtype to bf16 in export_llama

Differential Revision: D61981363

Pull Request resolved: https://github.com/pytorch/executorch/pull/4985
---
 examples/models/llama2/export_llama_lib.py | 4 ++--
 extension/llm/export/builder.py            | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index c19ddd58a2..e56d2fe848 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -256,9 +256,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--dtype-override",
         default="fp32",
         type=str,
-        choices=["fp32", "fp16"],
+        choices=["fp32", "fp16", "bf16"],
         help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16. Please be aware that only some backends support fp16.",
+        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
     )
 
     parser.add_argument(
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 2c2e52c744..6eecebb946 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -46,6 +46,7 @@ def to_torch_dtype(self) -> torch.dtype:
         mapping = {
             DType.fp32: torch.float32,
             DType.fp16: torch.float16,
+            DType.bf16: torch.bfloat16,
         }
         if self not in mapping:
             raise ValueError(f"Unsupported dtype {self}")
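
For reference, a minimal standalone sketch of the DType-to-torch.dtype mapping this patch extends: with the patch applied, "bf16" becomes a valid --dtype-override choice and DType.bf16 maps to torch.bfloat16. The enum member values and the final return statement are assumptions filled in for illustration; only the mapping entries and the error path are taken from the hunk above.

    # Sketch only: DType member values and the trailing return are assumed,
    # not shown in the diff above.
    from enum import Enum

    import torch


    class DType(Enum):
        fp32 = "fp32"
        fp16 = "fp16"
        bf16 = "bf16"

        def to_torch_dtype(self) -> torch.dtype:
            # Mapping entries match the patched extension/llm/export/builder.py.
            mapping = {
                DType.fp32: torch.float32,
                DType.fp16: torch.float16,
                DType.bf16: torch.bfloat16,
            }
            if self not in mapping:
                raise ValueError(f"Unsupported dtype {self}")
            return mapping[self]  # assumed completion; the hunk ends at the raise


    if __name__ == "__main__":
        # With bf16 in the mapping, the new CLI choice resolves to torch.bfloat16.
        assert DType.bf16.to_torch_dtype() is torch.bfloat16
        print(DType.bf16.to_torch_dtype())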