diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 599cc3b..7fe7caf 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -881,7 +881,8 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten BLOCK_SIZE_M = BLOCK_SIZE_M, use_bf16 = use_bf16, propagation_alg_id = propagation_alg_id, - **propagation_alg_kwargs + **propagation_alg_kwargs, + num_stages = 1 ) elif TILE_SIZE_M >= 8 and TILE_SIZE_K >= 8 and BLOCK_B >= 8: @@ -904,7 +905,8 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten BLOCK_SIZE_M = BLOCK_SIZE_M, use_bf16 = use_bf16, propagation_alg_id = propagation_alg_id, - **propagation_alg_kwargs + **propagation_alg_kwargs, + num_stages = 1 ) else: @@ -927,7 +929,8 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten BLOCK_SIZE_M = BLOCK_SIZE_M, use_bf16 = use_bf16, propagation_alg_id = propagation_alg_id, - **propagation_alg_kwargs + **propagation_alg_kwargs, + num_stages = 1 ) return None @@ -1955,7 +1958,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo BLOCK_SIZE_M = BLOCK_SIZE_M, BLOCK_SIZE_K = BLOCK_SIZE_K, TL_DOT = TL_DOT, - num_warps = 2, # TODO: test for different devices num_stages = 1, propagation_alg_id = propagation_alg_id, accumulate_ch_flows = accumulate_ch_flows, @@ -1987,7 +1989,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo BLOCK_SIZE_M = BLOCK_SIZE_M, BLOCK_SIZE_K = BLOCK_SIZE_K, TL_DOT = TL_DOT, - num_warps = 2, # TODO: test for different devices num_stages = 1, propagation_alg_id = propagation_alg_id, accumulate_ch_flows = accumulate_ch_flows, @@ -2320,7 +2321,8 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor propagation_alg_id = propagation_alg_id, negate_pflows = negate_pflows, allow_neg_flows = allow_neg_flows, - **propagation_alg_kwargs + **propagation_alg_kwargs, + num_stages = 1 ) else: @@ -2347,7 +2349,8 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor propagation_alg_id = propagation_alg_id, negate_pflows = negate_pflows, allow_neg_flows = allow_neg_flows, - **propagation_alg_kwargs + **propagation_alg_kwargs, + num_stages = 1 ) return None