
[TIR][CUDA] Add native FP8 support to codegen #16548

Merged
merged 17 commits into apache:main from feature/2024-02-08/cuda_native_fp8
Mar 15, 2024

Conversation

csullivan
Contributor

@csullivan csullivan commented Feb 8, 2024

Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide explicit type conversions that target hardware-native conversion ops.

* Conditionally run Storage and Compute legalization for targets that don't support FP8. This could be narrowed to support only the conversion operators, legalizing any compute operations other than builtin wmma calls.

* Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3)

e.g.

#include <cuda_fp8.h>
using fp8_e4_t = __nv_fp8_e4m3;
using fp8_e4_2_t = __nv_fp8x2_e4m3;
using fp8_e4_4_t = __nv_fp8x4_e4m3;
using fp8_e5_t = __nv_fp8_e5m2;
using fp8_e5_2_t = __nv_fp8x2_e5m2;
using fp8_e5_4_t = __nv_fp8x4_e5m2;

struct __align__(4) half4 {
  __half x, y, z, w;
  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
  __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4);
  __host__ __device__ explicit operator __nv_fp8x4_e4m3() const;
};

extern "C" __global__ void __launch_bounds__(32) add_kernel(fp8_e4_4_t* __restrict__ A, fp8_e4_4_t* __restrict__ B, fp8_e4_4_t* __restrict__ C) {
  half4 __1;
  half4 v_ = (half4)(A[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))]);
  half4 v__1 = (half4)(B[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))]);
  __1.x = (v_.x + v__1.x);
  __1.y = (v_.y + v__1.y);
  __1.z = (v_.z + v__1.z);
  __1.w = (v_.w + v__1.w);
  C[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] = (fp8_e4_4_t)(__1);
}

@csullivan csullivan force-pushed the feature/2024-02-08/cuda_native_fp8 branch 3 times, most recently from e4a60ad to 38693aa Compare February 12, 2024 17:12
@tqchen
Member

tqchen commented Feb 15, 2024

please check in on the CI issues, likely we need a requires_cuda tag?

@csullivan csullivan force-pushed the feature/2024-02-08/cuda_native_fp8 branch from ef2be51 to 4593359 Compare February 21, 2024 20:42
csullivan and others added 13 commits February 21, 2024 22:04

Co-authored-by: Joseph McMahan <jmcmahan@octoml.ai>
* conversion between float and half vector types with equal lanes
* test_half_misaligned_vector_load
* test_half_broadcast
* test_e4m3_packing
* test_half4_vector_add
* from fp16 weights to fp16 dequantized weights for comparison
* WIP: Use small test size to narrow down numerics issue beyond first vector store
@csullivan csullivan force-pushed the feature/2024-02-08/cuda_native_fp8 branch from 54fa150 to 21566a4 Compare February 21, 2024 22:16
@tqchen
Member

tqchen commented Mar 11, 2024

@tvm-bot rerun

legalize bf16 before downstream passes. This can be fixed to support
conditional target-dependent lowering for BF16, but that is outside the
scope of this change, so prefer to revert the API changes to BF16 and
keep them the same as before, even though the API will be different for FP8.
@@ -594,6 +596,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
return llvm::VectorType::get(etype, dtype.lanes());
#endif
} else {
ICHECK(etype != nullptr) << "No suitable llvm type found for dtype: " << dtype;
Member
this breaks CI

Contributor Author
Thank you @vinx13. I was trying to leave the code better than I found it, but I guess some downstream handling is relying on this behavior. I have removed the check.

* Add TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly)
* Move fp8 tests to codegen tests
@csullivan csullivan merged commit feb1043 into apache:main Mar 15, 2024
18 checks passed
JosephTheOctonaut added a commit to JosephTheOctonaut/tvm that referenced this pull request Mar 15, 2024
* [TIR][CUDA] Add native FP8 support to codegen

Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide explicit type conversions that target hardware-native conversion ops.

* Conditionally run Storage and Compute legalization for targets that don't support FP8. This could be narrowed to support only the conversion operators, legalizing any compute operations other than builtin wmma calls.

* Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3)

* Add a test for e4m3 <-> half conversion, which lowers to PTX intrinsics.

* Introduce half4 and support native fp8 vector types (1, 2, 4), and conversion between float and half vector types with equal lanes

* Only cast to half2 for vector loads/stores of non-native half struct types (lanes > 4).

* Test e4m3 x4 vector quant/dequant

---------

Co-authored-by: Joseph McMahan <jmcmahan@octoml.ai>
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024