Skip to content

Commit

Permalink
tune num_stages of sum kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Aug 28, 2024
1 parent 78bd3d7 commit ac92216
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,8 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
BLOCK_SIZE_M = BLOCK_SIZE_M,
use_bf16 = use_bf16,
propagation_alg_id = propagation_alg_id,
**propagation_alg_kwargs
**propagation_alg_kwargs,
num_stages = 1
)

elif TILE_SIZE_M >= 8 and TILE_SIZE_K >= 8 and BLOCK_B >= 8:
Expand All @@ -904,7 +905,8 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
BLOCK_SIZE_M = BLOCK_SIZE_M,
use_bf16 = use_bf16,
propagation_alg_id = propagation_alg_id,
**propagation_alg_kwargs
**propagation_alg_kwargs,
num_stages = 1
)

else:
Expand All @@ -927,7 +929,8 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
BLOCK_SIZE_M = BLOCK_SIZE_M,
use_bf16 = use_bf16,
propagation_alg_id = propagation_alg_id,
**propagation_alg_kwargs
**propagation_alg_kwargs,
num_stages = 1
)

return None
Expand Down Expand Up @@ -1955,7 +1958,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo
BLOCK_SIZE_M = BLOCK_SIZE_M,
BLOCK_SIZE_K = BLOCK_SIZE_K,
TL_DOT = TL_DOT,
num_warps = 2, # TODO: test for different devices
num_stages = 1,
propagation_alg_id = propagation_alg_id,
accumulate_ch_flows = accumulate_ch_flows,
Expand Down Expand Up @@ -1987,7 +1989,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo
BLOCK_SIZE_M = BLOCK_SIZE_M,
BLOCK_SIZE_K = BLOCK_SIZE_K,
TL_DOT = TL_DOT,
num_warps = 2, # TODO: test for different devices
num_stages = 1,
propagation_alg_id = propagation_alg_id,
accumulate_ch_flows = accumulate_ch_flows,
Expand Down Expand Up @@ -2320,7 +2321,8 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor
propagation_alg_id = propagation_alg_id,
negate_pflows = negate_pflows,
allow_neg_flows = allow_neg_flows,
**propagation_alg_kwargs
**propagation_alg_kwargs,
num_stages = 1
)

else:
Expand All @@ -2347,7 +2349,8 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor
propagation_alg_id = propagation_alg_id,
negate_pflows = negate_pflows,
allow_neg_flows = allow_neg_flows,
**propagation_alg_kwargs
**propagation_alg_kwargs,
num_stages = 1
)

return None
Expand Down

0 comments on commit ac92216

Please sign in to comment.