Skip to content

Commit

Permalink
[NFC][LinalgExt] Rename op functions from outdated naming conventions (
Browse files Browse the repository at this point in the history
…iree-org#17333)

Most ops in LinalgExt used outdated naming conventions. This PR updates
the names of op member functions to match the `getX()` convention.
  • Loading branch information
Max191 authored May 10, 2024
1 parent 7baef75 commit c81496c
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 172 deletions.
45 changes: 20 additions & 25 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,12 @@ LogicalResult ScanOp::verify() {
if (getNumDpsInits() != 2) {
return op->emitOpError("expected two output operands");
}
if (!isa<ShapedType>(input().getType())) {
if (!isa<ShapedType>(getInput().getType())) {
return op->emitOpError("expected first input element type to be shaped");
}
auto accumulatorType = cast<ShapedType>(accumulator().getType());
auto inputType = cast<ShapedType>(input().getType());
auto outputType = cast<ShapedType>(output().getType());
auto accumulatorType = cast<ShapedType>(getAccumulator().getType());
auto inputType = cast<ShapedType>(getInput().getType());
auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) {
Expand Down Expand Up @@ -435,8 +435,8 @@ LogicalResult ReverseOp::verify() {
if (getNumDpsInits() != 1) {
return op->emitOpError("expected exactly one output");
}
auto inputType = cast<ShapedType>(input().getType());
auto outputType = cast<ShapedType>(output().getType());
auto inputType = cast<ShapedType>(getInput().getType());
auto outputType = cast<ShapedType>(getOutput().getType());
if (inputType.getElementType() != outputType.getElementType()) {
return op->emitOpError(
"expected input/output element types to be identical");
Expand All @@ -457,7 +457,7 @@ LogicalResult ReverseOp::verify() {

int64_t rank = getOperandRank();
llvm::SmallSetVector<int64_t, 4> s;
for (auto dim : dims()) {
for (auto dim : getDimensionsArray()) {
if (dim < 0 || dim >= rank) {
return op->emitOpError("all the dimensions must be within [0, ")
<< rank << ")";
Expand Down Expand Up @@ -494,14 +494,14 @@ LogicalResult TopkOp::verify() {
return op->emitOpError("dimension exceeds rank");
}
// Ensure input/output element types match
auto inputValuesType = cast<ShapedType>(values().getType());
auto inputValuesType = cast<ShapedType>(getValues().getType());
auto outputValuesType = cast<ShapedType>(outputValues().getType());
if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
return op->emitOpError("expected input/output value types to be identical");
}
// Indices must be int if provided
auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
if (auto inputIndices = indices()) {
if (auto inputIndices = getIndices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (!inputIndicesType.getElementType().isInteger(32) ||
!outputIndicesType.getElementType().isInteger(32)) {
Expand All @@ -513,14 +513,14 @@ LogicalResult TopkOp::verify() {
if (inputValuesType.getRank() != outputValuesType.getRank()) {
return op->emitOpError("expected input/output to have the same rank");
}
if (auto inputIndices = indices()) {
if (auto inputIndices = getIndices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
return op->emitOpError("expected input/output to have the same rank");
}
}
// Input indicies and values must have the same shape.
if (auto inputIndices = indices()) {
if (auto inputIndices = getIndices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) {
return op->emitOpError("input indices/values shape must match");
Expand Down Expand Up @@ -996,8 +996,7 @@ LogicalResult WinogradInputTransformOp::verify() {
return op->emitOpError(
"expected output rank to be equal to input rank + 2");
}
const SmallVector<int64_t> imageDims = imageDimensions();
const size_t numImageDims = imageDims.size();
ArrayRef<int64_t> imageDims = getImageDimensions();
llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
imageDims.end());
if (imageDims.size() != 2) {
Expand All @@ -1007,14 +1006,11 @@ LogicalResult WinogradInputTransformOp::verify() {
return op->emitOpError(
"expect image dimensions to be either [1, 2] or [2, 3]");
}
const int64_t outputTileSize = getOutputTileSize();
const int64_t kernelSize = getKernelSize();
const int64_t inputTileSize = getInputTileSize();
SmallVector<int64_t> expectedOutputShape(getOutputRank(), inputTileSize);
SmallVector<int64_t> expectedOutputShape(getOutputRank(), getInputTileSize());
int outputIndex;
ArrayRef<int64_t> inputShape = inputType.getShape();
for (int i = 0; i < inputShape.size(); i++) {
outputIndex = i + numImageDims;
outputIndex = i + imageDims.size();
if (ShapedType::isDynamic(inputShape[i])) {
expectedOutputShape[outputIndex] = inputShape[i];
continue;
Expand All @@ -1023,7 +1019,8 @@ LogicalResult WinogradInputTransformOp::verify() {
expectedOutputShape[outputIndex] = inputShape[i];
} else {
expectedOutputShape[outputIndex] =
std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize);
std::ceil(static_cast<float>(inputShape[i] - getKernelSize() + 1) /
getOutputTileSize());
}
}
if (isNchw()) {
Expand Down Expand Up @@ -1189,8 +1186,7 @@ LogicalResult WinogradOutputTransformOp::verify() {
return op->emitOpError(
"expected output rank to be equal to input rank - 2");
}
const SmallVector<int64_t> imageDims = imageDimensions();
const size_t numImageDims = imageDims.size();
ArrayRef<int64_t> imageDims = getImageDimensions();
llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
imageDims.end());
if (imageDims.size() != 2) {
Expand All @@ -1204,19 +1200,18 @@ LogicalResult WinogradOutputTransformOp::verify() {
if (isNchw()) {
permute<Permutation::TTNHWC_TO_TTNCHW>(inputShape);
}
const int64_t outputTileSize = getOutputTileSize();
SmallVector<int64_t> expectedOutputShape(getOutputRank(), 1);
int outputIndex;
for (int i = numImageDims; i < inputShape.size(); i++) {
outputIndex = i - numImageDims;
for (int i = imageDims.size(); i < inputShape.size(); i++) {
outputIndex = i - imageDims.size();
if (ShapedType::isDynamic(inputShape[i])) {
expectedOutputShape[outputIndex] = inputShape[i];
continue;
}
if (!imageDimsSet.contains(outputIndex)) {
expectedOutputShape[outputIndex] = inputShape[i];
} else {
expectedOutputShape[outputIndex] = outputTileSize * inputShape[i];
expectedOutputShape[outputIndex] = getOutputTileSize() * inputShape[i];
}
}
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
Expand Down
Loading

0 comments on commit c81496c

Please sign in to comment.