Skip to content

Commit

Permalink
[BF16] The BF16ComputeLegalizer expects to be run in lowering to
Browse files Browse the repository at this point in the history
legalize bf16 before downstream passes. This can be fixed to
support conditional target dependent lowering for BF16, but that
is outside the scope of this change, so prefer to revert API changes
to BF16 and keep them the same as before, even though the API will
be different for FP8.
  • Loading branch information
csullivan committed Mar 11, 2024
1 parent 21566a4 commit b97081d
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 14 deletions.
6 changes: 2 additions & 4 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
* \param target The target used for checking native bf16 support
* \return The pass.
*/
TVM_DLL Pass BF16ComputeLegalize(Target target);
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
Expand All @@ -414,14 +414,12 @@ TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float

/*!
* \brief Legalize bf16 storage types to u16.
* \param target The target used for checking native bf16 support
* \return The pass.
*/
TVM_DLL Pass BF16StorageLegalize(Target target);
TVM_DLL Pass BF16StorageLegalize();

/*!
* \brief Legalize fp8 storage types to u8.
* \param target The target used for checking native fp8 support
* \return The pass.
*/
TVM_DLL Pass FP8StorageLegalize(Target target);
Expand Down
4 changes: 2 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::TransformMmaBufferLayout());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());

Expand Down Expand Up @@ -569,7 +570,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
Array<Pass> mixed_pass_list;

mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
mixed_pass_list.push_back(tir::transform::BF16ComputeLegalize(target));

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
Expand Down Expand Up @@ -621,7 +621,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize(target));
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

Expand Down
7 changes: 2 additions & 5 deletions src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ namespace meta_schedule {
class DisallowAsyncStridedMemCopyNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {
this->target_ = context->target.value();
}
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final {
IRModule mod = sch->mod();
Expand All @@ -140,7 +138,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16ComputeLegalize(this->target_));
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::InjectVirtualThread());
Expand Down Expand Up @@ -168,7 +166,6 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
return Postproc(n);
}

Target target_{nullptr};
static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
};
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16ComputeLegalize(this->target_));
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
// Phase 2
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_
return has_native_support;
}

Pass BF16ComputeLegalize(Target target) {
Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports bf16
return BF16ComputeLegalizer().Legalize(f);
Expand All @@ -717,7 +717,7 @@ Pass BF16ComputeLegalize(Target target) {

TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize);

Pass BF16StorageLegalize(Target target) {
Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports bf16
return BF16StorageLegalizer().Legalize(f);
Expand Down

0 comments on commit b97081d

Please sign in to comment.