Revert "[TorchToLinalg] perform rank0 elementwise computations outsid…
Browse files Browse the repository at this point in the history
…e linalg generic ops (#3762)" (#3767)

Reverted due to downstream model changes. Will reland with fixes post
integration.

This reverts commit 6e8c7be.
rsuderman authored Oct 4, 2024
1 parent e9ed4af commit 53f7532
Showing 2 changed files with 7 additions and 24 deletions.
19 changes: 0 additions & 19 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1627,25 +1627,6 @@ class ConvertElementwiseOp : public ConversionPattern {
         operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
     auto resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op->getResult(0).getType()));
-    bool isScalarOp = resultType.getShape().size() == 0;
-    if (isScalarOp) {
-      // for elementwise ops that are actually rank0 scalar computations,
-      // perform the payload outside a linalg generic op.
-      SmallVector<Value> payloadArgs;
-      for (auto t : tensorOperands) {
-        payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
-      }
-      Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
-          rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
-      if (!scalarResult)
-        return rewriter.notifyMatchFailure(
-            op, "Failed to create payload for scalar elementwise op");
-      Value rank0Result =
-          createInitTensor(rewriter, loc, ValueRange{},
-                           resultType.getElementType(), scalarResult);
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
-      return success();
-    }
     bool hadErrorCreatingPayload = false;
     Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
         rewriter, loc, tensorOperands, resultType.getElementType(),
12 changes: 7 additions & 5 deletions test/Conversion/TorchToLinalg/elementwise.mlir
@@ -4,11 +4,13 @@
 // CHECK-LABEL: func.func @elementwise$unary(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
 // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
-// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
-// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
-// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
-// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
+// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
+// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
+// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
+// CHECK: linalg.yield %[[TANH]] : f32
+// CHECK: } -> tensor<f32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
 // CHECK: }
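For reference, the two lowerings this test contrasts look roughly as follows when written out in full. This is a sketch reconstructed from the CHECK lines above, not IR captured from the compiler; the function names are illustrative only.

// Reverted scalar path (deleted in this commit): extract the scalar,
// compute the payload directly, then materialize a rank-0 result.
func.func @tanh_rank0_scalar_path(%arg0: tensor<f32>) -> tensor<f32> {
  %extracted = tensor.extract %arg0[] : tensor<f32>
  %tanh = math.tanh %extracted : f32
  %empty = tensor.empty() : tensor<f32>
  %filled = linalg.fill ins(%tanh : f32) outs(%empty : tensor<f32>) -> tensor<f32>
  return %filled : tensor<f32>
}

// Restored path: the payload stays inside a linalg.generic with a
// rank-0 (empty) iteration space.
func.func @tanh_rank0_generic_path(%arg0: tensor<f32>) -> tensor<f32> {
  %empty = tensor.empty() : tensor<f32>
  %generic = linalg.generic
      {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
       iterator_types = []}
      ins(%arg0 : tensor<f32>) outs(%empty : tensor<f32>) {
  ^bb0(%in: f32, %out: f32):
    %tanh = math.tanh %in : f32
    linalg.yield %tanh : f32
  } -> tensor<f32>
  return %generic : tensor<f32>
}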
