Skip to content

Commit

Permalink
[CodeGen] Fix the argument replacements in scf.forall op lowering. (i…
Browse files Browse the repository at this point in the history
…ree-org#18613)

They should be scaled by tile sizes. Otherwise, we always access the
same memory chunk.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Sep 26, 2024
1 parent 66d0c31 commit 76c3e61
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
Block *parentBlock = forallOp->getBlock();
Block *remainingBlock =
rewriter.splitBlock(parentBlock, Block::iterator(forallOp));
for (auto [id, step] : llvm::zip_equal(procId, mixedStep)) {
rewriter.setInsertionPointToEnd(parentBlock);
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
AffineExpr expr = s1 * s0;
id = affine::makeComposedFoldedAffineApply(rewriter, forallOp.getLoc(),
expr, {id, step});
}
auto argReplacements =
getValueOrCreateConstantIndexOp(rewriter, forallOp.getLoc(), procId);
Block *loopBody = forallOp.getBody();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ hal.executable private @scf_forall_2D {
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: hal.executable.export public @scf_forall_2D layout
// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
Expand All @@ -203,7 +205,9 @@ hal.executable private @scf_forall_2D {
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-NOT: scf.forall
// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
// CHECK: %[[I:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[WG_ID_X]]]
// CHECK: "use"(%[[I]], %[[J]])

// -----

Expand Down Expand Up @@ -236,6 +240,7 @@ hal.executable private @scf_forall_2D_dynamic_tile_size {
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
// CHECK: hal.executable.export public @scf_forall_2D_dynamic_tile_size layout
// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
Expand All @@ -246,10 +251,14 @@ hal.executable private @scf_forall_2D_dynamic_tile_size {
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
// CHECK: func @scf_forall_2D_dynamic_tile_size()
// CHECK-DAG: %[[STEP_Y:.+]] = hal.interface.constant.load {{.+}} ordinal(2)
// CHECK-DAG: %[[STEP_X:.+]] = hal.interface.constant.load {{.+}} ordinal(3)
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-NOT: scf.forall
// CHECK: "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
// CHECK: %[[I:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]], %[[STEP_Y]]]
// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]], %[[STEP_X]]]
// CHECK: "use"(%[[I]], %[[J]])

// -----

Expand Down Expand Up @@ -305,6 +314,7 @@ hal.executable private @scf_forall_4D {
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
// CHECK: hal.executable.export public @scf_forall_4D layout
// CHECK-NEXT: %[[ARG1:[a-zA-z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-z0-9]+]]: index
Expand All @@ -329,14 +339,20 @@ hal.executable private @scf_forall_4D {
// CHECK-DAG: %[[UB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(5)
// CHECK-DAG: %[[STEP0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(8)
// CHECK-DAG: %[[STEP1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(9)
// CHECK-DAG: %[[STEP2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(10)
// CHECK-DAG: %[[STEP3:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(11)
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[NITERS1:.+]] = affine.apply #[[MAP0]]()[%[[LB1]], %[[UB1]], %[[STEP1]]]
// CHECK-DAG: %[[NITERS0:.+]] = affine.apply #[[MAP0]]()[%[[LB0]], %[[UB0]], %[[STEP0]]]
// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
// CHECK-NOT: scf.forall
// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[WG_ID_Z]] into (%[[NITERS0]], %[[NITERS1]])
// CHECK: "use"(%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[WG_ID_Y]], %[[WG_ID_X]])
// CHECK: %[[I:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#0, %[[STEP0]]]
// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#1, %[[STEP1]]]
// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]], %[[STEP2]]]
// CHECK: %[[L:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_X]], %[[STEP3]]]
// CHECK: "use"(%[[I]], %[[J]], %[[K]], %[[L]])

// -----

Expand Down Expand Up @@ -364,6 +380,10 @@ hal.executable private @scf_forall_4D_static_interchange {
}
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 5)>
// CHECK: hal.executable.export public @scf_forall_4D_static_interchange layout
// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
Expand All @@ -378,7 +398,11 @@ hal.executable private @scf_forall_4D_static_interchange {
// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
// CHECK-NOT: scf.forall
// CHECK: %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]])
// CHECK: "use"(%[[DELINEARIZE]]#2, %[[DELINEARIZE]]#0, %[[WG_ID_X]], %[[WG_ID_Y]], %[[DELINEARIZE]]#1)
// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[DELINEARIZE]]#0]
// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
// CHECK: %[[L:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#1]
// CHECK: "use"(%[[DELINEARIZE]]#2, %[[I]], %[[J]], %[[K]], %[[L]])

// -----

Expand Down

0 comments on commit 76c3e61

Please sign in to comment.