[TIR][CUDA] Add native FP8 support to codegen #16548
Conversation
(force-pushed from e4a60ad to 38693aa)
please check in on the CI issues, likely we need a requires_cuda tag?
(force-pushed from ef2be51 to 4593359)
Co-authored-by: Joseph McMahan <jmcmahan@octoml.ai>
…conversion between float and half vector types with equal lanes
* test_half_misaligned_vector_load
* test_half_broadcast
* test_e4m3_packing
* test_half4_vector_add
…egal memory access
…from fp16 weights to fp16 dequantized weights for comparison.
* WIP: Use small test size to narrow down numerics issue beyond first vector store
…ersion for fp8 support on cuda target
(force-pushed from 54fa150 to 21566a4)
@tvm-bot rerun
legalize bf16 before downstream passes. This can be fixed to support conditional, target-dependent lowering for BF16, but that is outside the scope of this change, so prefer to revert the API changes to BF16 and keep them the same as before, even though the API will be different for FP8.
src/target/llvm/codegen_llvm.cc (outdated)
@@ -594,6 +596,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
     return llvm::VectorType::get(etype, dtype.lanes());
 #endif
   } else {
+    ICHECK(etype != nullptr) << "No suitable llvm type found for dtype: " << dtype;
this breaks CI
Thank you @vinx13. I was trying to leave the code better than I found it, but I guess some downstream handling is relying on this behavior. I have removed the check.
…types (lanes > 4).
* Add TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly) (see the sketch below)
* Move fp8 tests to codegen tests
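For context, a minimal CUDA sketch of the quant/dequant-with-scale pattern such a test exercises. This is not the PR's test code: the kernel name and the per-tensor __half scale are illustrative, and it assumes CUDA 11.8+ for cuda_fp8.h.

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Quantize half -> e4m3 with a per-tensor scale, then dequantize back.
// A round-trip like this can be compared against the original fp16
// values within fp8 tolerance. (Sketch only; names are hypothetical.)
__global__ void quant_dequant_scale(const __half* x, __half* y, __half scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __half scaled = __hdiv(x[i], scale);  // scale down before quantizing
    __nv_fp8_e4m3 q(scaled);              // half -> e4m3 (saturating)
    y[i] = __hmul(__half(q), scale);      // e4m3 -> half, then rescale
  }
}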
* [TIR][CUDA] Add native FP8 support to codegen

  Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide explicit type conversions that target hardware native conversion ops.

  * Conditionally run Storage and Compute legalization for targets that don't support FP8. This could be changed to only support conversion operators and do legalization on any compute operations other than builtin wmma calls.
  * Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3)
  * Add test for e4m3 <-> half conversion which lowers to ptx intrins.
  * Introduce half4 and support native fp8 vector types (1, 2, 4), and conversion between float and half vector types with equal lanes
  * Only cast to half2 for vector loads/stores of non-native half struct types (lanes > 4).
  * Test e4m3 x4 vector quant/dequant

  ---------

  Co-authored-by: Joseph McMahan <jmcmahan@octoml.ai>
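As a rough illustration of the "e4m3 <-> half conversion which lowers to ptx intrins" item above, a hand-written CUDA equivalent (not TVM's generated code) can go through the packed conversion intrinsics in cuda_fp8.h. The function name here is hypothetical; on sm_89+ these calls map to native cvt instructions rather than emulation sequences.

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Round-trip a pair of halves through the packed e4m3x2 conversion
// intrinsics (half2 -> fp8x2 -> half2).
__device__ __half2 roundtrip_e4m3x2(__half2 h2) {
  __nv_fp8x2_storage_t p =
      __nv_cvt_halfraw2_to_fp8x2(__half2_raw(h2), __NV_SATFINITE, __NV_E4M3);
  return __half2(__nv_cvt_fp8x2_to_halfraw2(p, __NV_E4M3));
}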
Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide explicit type conversions that target hardware native conversion ops.

* Conditionally run Storage and Compute legalization for targets that don't support FP8. This could be changed to only support conversion operators and do legalization on any compute operations other than builtin wmma calls.
* Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3), e.g. as sketched below.
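A minimal sketch of that pairing at the CUDA source level, assuming CUDA 11.8+. The half4 struct and helper names here are illustrative stand-ins; TVM's actual generated helper type may be laid out differently.

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// CUDA has no built-in half4, so a 4-lane half vector is modeled as a
// struct; the PR introduces such a helper on the codegen side.
struct __align__(8) half4 {
  __half x, y, z, w;
};

__device__ __nv_fp8x4_e4m3 half4_to_fp8x4(half4 v) {
  // Pack the four lanes as two __half2 and use the fp8x4 constructor,
  // which converts all four lanes to e4m3 at once.
  return __nv_fp8x4_e4m3(__halves2half2(v.x, v.y), __halves2half2(v.z, v.w));
}

__device__ half4 fp8x4_to_half4(__nv_fp8x4_e4m3 q) {
  // Unpack via the float4 conversion operator, then narrow back to half.
  float4 f = static_cast<float4>(q);
  return half4{__float2half(f.x), __float2half(f.y),
               __float2half(f.z), __float2half(f.w)};
}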