From 3e33e8258db88ab7d358b821193bcbbebf350e83 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 17 Sep 2024 15:42:27 -0700
Subject: [PATCH] [not for land yet] hack max and abs out of ops eligible for AC

Summary:

For now, this is not for land; it is just saving work and starting a
discussion.

We need to calculate max(abs(tensor)) for each float8 gemm input when using
per-tensor scaling. I realized that today this does not work efficiently with
AC, because max(abs(tensor)) is usually recomputed. Since the output size is
1, it is more efficient to save it and never recompute it.

For now, just hack these ops into the do-not-recompute list to get a perf
measurement. This seems to save ~1% on LLaMa 3B on 8 H100 GPUs. I verified in
the before/after traces that the redundant triton kernels that calculate
max(abs(activation)) and max(abs(weight)) are gone with this hack.

Heading to PTC, but we should get a measurement on a larger model and figure
out a better way to land this.

Test Plan:

https://gist.github.com/vkuzo/375230e30e1cb599ad31a87e0be25d75

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/parallelisms/parallelize_llama.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703d..68865ddb 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -206,6 +206,11 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    # Not for land in the current state, need to align on best way to expose this
+    # for various AC options. For now just hack it in here to get a clean
+    # measurement.
+    torch.ops.aten.abs.default,
+    torch.ops.aten.max.default,
 }
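
As context for reviewers (not part of the patch to apply), here is a minimal
runnable sketch of how a save list like the one above is consumed by selective
op activation checkpointing. It assumes a recent PyTorch with the prototype
create_selective_checkpoint_contexts API; the policy helper and the toy
block() function are illustrative stand-ins, not torchtitan's actual code.

    from functools import partial

    import torch
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        checkpoint,
        create_selective_checkpoint_contexts,
    )

    # Outputs of these ops are saved during the forward pass instead of being
    # recomputed in the backward pass. abs/max are included because the
    # per-tensor amax used for float8 scaling has output size 1, so saving it
    # is cheaper than recomputing it.
    _save_list = {
        torch.ops.aten.mm.default,
        torch.ops.aten.abs.default,
        torch.ops.aten.max.default,
    }

    def _policy(ctx, func, *args, **kwargs):
        # save outputs of ops in the list, prefer recompute for everything else
        return (
            CheckpointPolicy.MUST_SAVE
            if func in _save_list
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    context_fn = partial(create_selective_checkpoint_contexts, _policy)

    def block(x, w):
        # toy stand-in for a checkpointed transformer block; max(abs(x))
        # mimics the per-tensor amax computed for float8 scaling
        scale = torch.max(torch.abs(x))
        return torch.mm(x / scale, w)

    x = torch.randn(16, 16, requires_grad=True)
    w = torch.randn(16, 16, requires_grad=True)
    out = checkpoint(block, x, w, use_reentrant=False, context_fn=context_fn)
    out.sum().backward()

With this policy, the outputs of aten.abs and aten.max are stored in the
forward pass, so the backward pass of the checkpointed region does not
re-launch the amax kernels; everything else is still recomputed as usual.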