From 509f4a7029058454bcb143f940e4883ca9ece9ae Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 11 Jun 2024 00:25:36 +0000 Subject: [PATCH] Address feedback: Use walk for traversal --- stablehlo/reference/Api.cpp | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/stablehlo/reference/Api.cpp b/stablehlo/reference/Api.cpp index a23d91b38bf..3c24908f84e 100644 --- a/stablehlo/reference/Api.cpp +++ b/stablehlo/reference/Api.cpp @@ -68,7 +68,7 @@ FailureOr getMainFunction(ModuleOp module, StringRef mainName) { class DefaultInterpreterFallback : public InterpreterFallback { public: DefaultInterpreterFallback(const InterpreterConfiguration &config) - : config(config) {}; + : config(config){}; virtual llvm::Error operator()(Operation &op, Scope &scope, Process *process) final { @@ -170,7 +170,7 @@ LogicalResult removeDynamism(ModuleOp module, func::FuncOp func, return success(); } -bool isQuantizedTypes(TypeRange types) { +bool isAnyQuantizedTypes(TypeRange types) { return llvm::any_of(types, [](Type type) { return isa(getElementTypeOrSelf(type)); }); @@ -185,20 +185,19 @@ bool isQuantizedTypes(TypeRange types) { // Returns: // True if the operation or any nested operation uses quantized types, // false otherwise. -bool operationUsesQuantType(Operation &op) { - for (Region ®ion : op.getRegions()) { - for (Block &block : region) { - for (Operation &op : block) { - TypeRange operandTypes = op.getOperandTypes(); - TypeRange resultTypes = op.getResultTypes(); - if (isQuantizedTypes(operandTypes) || isQuantizedTypes(resultTypes) || - operationUsesQuantType(op)) { - return true; - } - } +bool funcUsesQuantType(func::FuncOp func_op) { + bool usesQuantizedType = false; + + func_op.walk([&](Operation *op) { + if (isAnyQuantizedTypes(op->getOperandTypes()) || + isAnyQuantizedTypes(op->getResultTypes())) { + usesQuantizedType = true; + return WalkResult::interrupt(); } - } - return false; + return WalkResult::advance(); + }); + + return usesQuantizedType; } // Lowers quantization-related operations and types within a function to @@ -217,9 +216,9 @@ bool operationUsesQuantType(Operation &op) { // Returns: // A `LogicalResult` indicating success or failure of the lowering process. LogicalResult lowerQuantization(ModuleOp module, func::FuncOp func) { - if (!(isQuantizedTypes(func.getFunctionType().getInputs()) || - operationUsesQuantType(*func.getOperation()) || - isQuantizedTypes(func.getFunctionType().getResults()))) { + if (!(isAnyQuantizedTypes(func.getFunctionType().getInputs()) || + funcUsesQuantType(func) || + isAnyQuantizedTypes(func.getFunctionType().getResults()))) { return success(); }