Skip to content

Commit

Permalink
Address feedback: Use walk for traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Jun 11, 2024
1 parent 762f8b0 commit 509f4a7
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions stablehlo/reference/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ FailureOr<func::FuncOp> getMainFunction(ModuleOp module, StringRef mainName) {
class DefaultInterpreterFallback : public InterpreterFallback {
public:
DefaultInterpreterFallback(const InterpreterConfiguration &config)
: config(config) {};
: config(config){};

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
Expand Down Expand Up @@ -170,7 +170,7 @@ LogicalResult removeDynamism(ModuleOp module, func::FuncOp func,
return success();
}

bool isQuantizedTypes(TypeRange types) {
bool isAnyQuantizedTypes(TypeRange types) {
return llvm::any_of(types, [](Type type) {
return isa<quant::QuantizedType>(getElementTypeOrSelf(type));
});
Expand All @@ -185,20 +185,19 @@ bool isQuantizedTypes(TypeRange types) {
// Returns:
// True if the operation or any nested operation uses quantized types,
// false otherwise.
bool operationUsesQuantType(Operation &op) {
for (Region &region : op.getRegions()) {
for (Block &block : region) {
for (Operation &op : block) {
TypeRange operandTypes = op.getOperandTypes();
TypeRange resultTypes = op.getResultTypes();
if (isQuantizedTypes(operandTypes) || isQuantizedTypes(resultTypes) ||
operationUsesQuantType(op)) {
return true;
}
}
bool funcUsesQuantType(func::FuncOp func_op) {
bool usesQuantizedType = false;

func_op.walk([&](Operation *op) {
if (isAnyQuantizedTypes(op->getOperandTypes()) ||
isAnyQuantizedTypes(op->getResultTypes())) {
usesQuantizedType = true;
return WalkResult::interrupt();
}
}
return false;
return WalkResult::advance();
});

return usesQuantizedType;
}

// Lowers quantization-related operations and types within a function to
Expand All @@ -217,9 +216,9 @@ bool operationUsesQuantType(Operation &op) {
// Returns:
// A `LogicalResult` indicating success or failure of the lowering process.
LogicalResult lowerQuantization(ModuleOp module, func::FuncOp func) {
if (!(isQuantizedTypes(func.getFunctionType().getInputs()) ||
operationUsesQuantType(*func.getOperation()) ||
isQuantizedTypes(func.getFunctionType().getResults()))) {
if (!(isAnyQuantizedTypes(func.getFunctionType().getInputs()) ||
funcUsesQuantType(func) ||
isAnyQuantizedTypes(func.getFunctionType().getResults()))) {
return success();
}

Expand Down

0 comments on commit 509f4a7

Please sign in to comment.