[ExecuTorch] Allow setting dtype to bf16 in export_llama
Differential Revision: D61981363

Pull Request resolved: #4985
swolchok authored Sep 6, 2024
1 parent 1d420c9 commit 1511fc1
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/models/llama2/export_llama_lib.py
@@ -256,9 +256,9 @@ def build_args_parser() -> argparse.ArgumentParser:
"--dtype-override",
default="fp32",
type=str,
choices=["fp32", "fp16"],
choices=["fp32", "fp16", "bf16"],
help="Override the dtype of the model (default is the checkpoint dtype)."
"Options: fp32, fp16. Please be aware that only some backends support fp16.",
"Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
)

parser.add_argument(
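For context, a minimal standalone sketch of the updated flag (assumptions: an isolated argparse.ArgumentParser rather than the full build_args_parser() from export_llama_lib.py; the attribute name dtype_override follows argparse's dash-to-underscore convention):

# Minimal sketch of the updated --dtype-override flag, not the full
# build_args_parser() from export_llama_lib.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype-override",
    default="fp32",
    type=str,
    choices=["fp32", "fp16", "bf16"],
    help="Override the dtype of the model (default is the checkpoint dtype)."
    "Options: fp32, fp16, bf16. Please be aware that only some backends "
    "support fp16 and bf16.",
)

# "bf16" is now an accepted choice; any other value still raises an argparse error.
args = parser.parse_args(["--dtype-override", "bf16"])
print(args.dtype_override)  # -> bf16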
1 change: 1 addition & 0 deletions extension/llm/export/builder.py
@@ -46,6 +46,7 @@ def to_torch_dtype(self) -> torch.dtype:
         mapping = {
             DType.fp32: torch.float32,
             DType.fp16: torch.float16,
+            DType.bf16: torch.bfloat16,
         }
         if self not in mapping:
             raise ValueError(f"Unsupported dtype {self}")
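Taken together, the two hunks let a bf16 request flow from the CLI string into a torch dtype. A minimal sketch of the pattern (assumption: DType is a string-valued Enum; the actual DType definition in builder.py may differ):

# Sketch of the DType -> torch.dtype mapping, assuming a string-valued Enum;
# the real DType in extension/llm/export/builder.py may be defined differently.
from enum import Enum

import torch

class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"  # newly reachable after this commit

    def to_torch_dtype(self) -> torch.dtype:
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,  # the entry added by this commit
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]

# The parsed CLI string maps straight onto the enum and then to torch.
print(DType("bf16").to_torch_dtype())  # -> torch.bfloat16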
