Add region to linalg_ext.attention (iree-org#18728)
More attention variants lead to more `linalg_ext.attention` features. By
including a region we can support post-Q@K elementwise operators; the main
example is `tanh` soft-cap operations.
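As an illustration only (the shapes, the `#map*` names, and the cap value are assumed, loosely following the `online_attention` test below rather than this commit's own tests), a `tanh` soft cap could then be expressed by filling the new region:

%capped = iree_linalg_ext.attention
    {indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO]}
    ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
    outs(%init : tensor<192x1024x64xf16>) {
  ^bb0(%score: f32):
    // Hypothetical soft cap: cap * tanh(score / cap); cap = 30.0 is chosen
    // purely for illustration.
    %cap = arith.constant 30.0 : f32
    %div = arith.divf %score, %cap : f32
    %tanh = math.tanh %div : f32
    %soft = arith.mulf %tanh, %cap : f32
    iree_linalg_ext.yield %soft : f32
} -> tensor<192x1024x64xf16>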
rsuderman authored Oct 16, 2024
1 parent 8568efa commit a488d38
Showing 25 changed files with 410 additions and 63 deletions.
@@ -166,6 +166,15 @@ struct AttentionOpConversion
loc, result.getType(), query, key, value, scale, result,
rewriter.getAffineMapArrayAttr(indexingMaps), optionalMask);

{
auto *block = rewriter.createBlock(&attention.getRegion());
OpBuilder::InsertionGuard g(rewriter);
block->addArgument(rewriter.getF32Type(), loc);
rewriter.setInsertionPoint(block, block->begin());

rewriter.create<IREE::LinalgExt::YieldOp>(loc, block->getArgument(0));
}

rewriter.replaceOp(op, attention.getResult(0));
return success();
}
20 changes: 16 additions & 4 deletions compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -16,7 +16,10 @@ func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar
// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) {
// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
// CHECK: linalg_ext.yield %[[SCORE]]
// CHECK: } -> tensor<5x2x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x3x4xf32>

// -----
@@ -36,7 +39,10 @@ func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar
// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) {
// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
// CHECK: linalg_ext.yield %[[SCORE]]
// CHECK: } -> tensor<5x2x8x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x8x4xf32>

// -----
@@ -56,7 +62,10 @@ func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2:
// CHECK: %[[ARG3:.*]]: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) {
// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
// CHECK: linalg_ext.yield %[[SCORE]]
// CHECK: } -> tensor<1x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<1x3x4xf32>

// -----
@@ -80,5 +89,8 @@ func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %ar
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) {
// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
// CHECK: linalg_ext.yield %[[SCORE]]
// CHECK: } -> tensor<?x?x4xf32>
// CHECK: return %[[ATTN]] : tensor<?x?x4xf32>
@@ -276,7 +276,10 @@ func.func @online_attention(%query: tensor<192x1024x64xf16>,
{ indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR],
lowering_config = #config }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
^bb0(%score : f32):
iree_linalg_ext.yield %score : f32
}
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>

return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
@@ -1750,7 +1750,10 @@ func.func @attention() attributes {hal.executable.target = #executable_target_em
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
outs(%7 : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
return
}
@@ -344,7 +344,10 @@ func.func @attention_20x4096x64x4096x64() {
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
return
}
@@ -386,7 +389,10 @@ func.func @attention_large_head_dim_shared_mem() {
affine_map<(d1, d2, d3, d4) -> (d3, d4)>,
affine_map<(d1, d2, d3, d4) -> ()>,
affine_map<(d1, d2, d3, d4) -> (d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<1024x512xf16>, tensor<128x512xf16>, tensor<128x512xf16>, f16) outs(%7 : tensor<1024x512xf16>) -> tensor<1024x512xf16>
ins(%4, %5, %6, %cst : tensor<1024x512xf16>, tensor<128x512xf16>, tensor<128x512xf16>, f16) outs(%7 : tensor<1024x512xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<1024x512xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : tensor<1024x512xf16> -> !flow.dispatch.tensor<writeonly:tensor<1024x512xf16>>
return
}
@@ -638,7 +638,10 @@ hal.executable private @attention_20x4096x64x4096x64 {
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
lowering_config = #config}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
return
}
@@ -699,7 +702,10 @@ hal.executable private @attention_multiple_m_transpose {
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<24x64x4608x128xf16>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
@@ -754,7 +760,10 @@ hal.executable private @attention_mfma_32x32x8 {
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<24x64x4608x128xf16>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
@@ -178,6 +178,28 @@ static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap,
return genericOp.getResult(0);
}

static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc,
Region &region, Value value) {
auto rank = cast<RankedTensorType>(value.getType()).getRank();
AffineMap identityMap =
AffineMap::getMultiDimIdentityMap(rank, builder.getContext());
SmallVector<AffineMap> indexingMaps{identityMap};
SmallVector<utils::IteratorType> iteratorTypes(rank,
utils::IteratorType::parallel);
auto genericOp = builder.create<linalg::GenericOp>(
loc, value.getType(), ValueRange{}, value, indexingMaps, iteratorTypes);
auto &dstRegion = genericOp.getRegion();
builder.cloneRegionBefore(region, dstRegion, dstRegion.end());
{
OpBuilder::InsertionGuard withinRegion(builder);
builder.setInsertionPoint(dstRegion.back().getTerminator());
builder.create<linalg::YieldOp>(
loc, dstRegion.back().getTerminator()->getOperands());
dstRegion.back().getTerminator()->erase();
}
return genericOp.getResult(0);
}

static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
AffineMap maskMap, Value qk, Value mask) {

@@ -339,11 +361,15 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);

s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);

// TODO: We shouldn't be relying on such attributes. We need a better
// mechanism to identify attention matmuls.
s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr());

s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);

if (qETy.getIntOrFloatBitWidth() <= 8) {
// For low bit-depth types we perform post Q @ K scaling. This is to avoid
// losing numerical precision due to the low dynamic range of fp8 types when
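For reference, the `applyPostQKMatmulElementwise` helper added above wraps the attention op's region in a `linalg.generic` over the intermediate score tensor, replacing the `iree_linalg_ext.yield` terminator with `linalg.yield`. A rough sketch of the op it produces for the default identity region (the score shape and value names here are illustrative, not taken from this change):

%post = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%s : tensor<20x4096x4096xf32>) {
  ^bb0(%score: f32):
    // Body cloned from the attention op's region; its terminator becomes linalg.yield.
    linalg.yield %score : f32
} -> tensor<20x4096x4096xf32>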
33 changes: 33 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1307,6 +1307,20 @@ LogicalResult AttentionOp::verify() {
return failure();
}

auto &block = getRegion().front();
auto blockTys = block.getArgumentTypes();
if (!isa<FloatType>(blockTys[0]))
return attnOp->emitOpError("block argument 0 should be float");

auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator());
if (!yieldOp) {
return attnOp->emitOpError("expected linalg_ext.yield");
}

if (yieldOp->getNumOperands() != 1) {
return emitOpError("expected only one return");
}

return success();
}

@@ -1462,6 +1476,25 @@ LogicalResult OnlineAttentionOp::verify() {
return failure();
}

Block &block = attnOp.getRegion().front();
auto blockTys = block.getArgumentTypes();
if (blockTys.size() != 1) {
return attnOp->emitOpError("expects single block argument for score");
}

if (!isa<FloatType>(blockTys[0])) {
return attnOp->emitOpError("block argument 0 should be float");
}

auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator());
if (!yieldOp) {
return attnOp->emitOpError("expected linalg_ext.yield");
}

if (yieldOp->getNumOperands() != 1) {
return emitOpError("expected only one return");
}

return success();
}

@@ -469,6 +469,7 @@ def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[

def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
@@ -502,6 +503,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
AnyShaped:$output,
AffineMapArrayAttr:$indexing_maps
);
let regions = (region SizedRegion<1>:$region);

let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
@@ -510,6 +512,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)`
`outs` `(` $output `:` type($output) `)`
$region
(`->` type($results)^)?
}];

@@ -564,6 +567,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",

def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
@@ -610,6 +614,7 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
AnyShaped:$sum,
AffineMapArrayAttr:$indexing_maps
);
let regions = (region SizedRegion<1>:$region);

let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
@@ -618,6 +623,7 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
$region
(`->` type($results)^)?
}];
