From 3f013ecef227dbac2383c1dfefc56de5e2ba8a4a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 13 Jan 2024 11:48:30 -0500 Subject: [PATCH] feat: add support for constants in imported modules (#3726) - add a case to ConstantFolder, and move the constant folder slightly down in the pipeline to after imports have been resolved. - rename `pre_typecheck` to `constant_fold` (since that is all it does, and it doesn't strictly happen before typechecking anymore). --- .../codegen/modules/test_module_constants.py | 78 +++++++++++++++++++ tests/utils.py | 4 +- vyper/ast/nodes.py | 2 +- vyper/ast/nodes.pyi | 4 +- .../{pre_typecheck.py => constant_folding.py} | 31 +++++++- vyper/semantics/analysis/module.py | 7 +- 6 files changed, 116 insertions(+), 10 deletions(-) create mode 100644 tests/functional/codegen/modules/test_module_constants.py rename vyper/semantics/analysis/{pre_typecheck.py => constant_folding.py} (89%) diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py new file mode 100644 index 0000000000..aafbb69252 --- /dev/null +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -0,0 +1,78 @@ +def test_module_constant(make_input_bundle, get_contract): + mod1 = """ +X: constant(uint256) = 12345 + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 12345 + + +def test_nested_module_constant(make_input_bundle, get_contract): + # test nested module constants + # test at least 3 modules deep to test the `path.reverse()` gizmo + # in ConstantFolder.visit_Attribute() + mod1 = """ +X: constant(uint256) = 12345 + """ + mod2 = """ +import mod1 +X: constant(uint256) = 54321 + """ + mod3 = """ +import mod2 +X: constant(uint256) = 98765 + """ + + contract = """ +import mod1 +import mod2 +import mod3 + +@external +def test_foo() -> bool: + assert mod1.X == 12345 + assert mod2.X == 54321 + assert mod3.X == 98765 + assert mod2.mod1.X == mod1.X + assert mod3.mod2.mod1.X == mod1.X + assert mod3.mod2.X == mod2.X + return True + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1, "mod2.vy": mod2, "mod3.vy": mod3}) + + c = get_contract(contract, input_bundle=input_bundle) + assert c.test_foo() is True + + +def test_import_constant_array(make_input_bundle, get_contract, tx_failed): + mod1 = """ +X: constant(uint256[3]) = [1,2,3] + """ + contract = """ +import mod1 + +@external +def foo(ix: uint256) -> uint256: + return mod1.X[ix] + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo(0) == 1 + assert c.foo(1) == 2 + assert c.foo(2) == 3 + with tx_failed(): + c.foo(3) diff --git a/tests/utils.py b/tests/utils.py index b8a6b493d8..25dad818ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,7 @@ import os from vyper import ast as vy_ast -from vyper.semantics.analysis.pre_typecheck import pre_typecheck +from vyper.semantics.analysis.constant_folding import constant_fold @contextlib.contextmanager @@ -17,5 +17,5 @@ def working_directory(directory): def parse_and_fold(source_code): ast = vy_ast.parse_to_ast(source_code) - pre_typecheck(ast) + constant_fold(ast) return ast diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index fa1fb63673..df419daa25 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -396,7 +396,7 @@ def has_folded_value(self): """ return "folded_value" in self._metadata - def get_folded_value(self) -> "VyperNode": + def get_folded_value(self) -> "ExprNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 4a5bc0d001..7f8c902d45 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -30,8 +30,8 @@ class VyperNode: def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def get_folded_value(self) -> VyperNode: ... - def _set_folded_value(self, node: VyperNode) -> None: ... + def get_folded_value(self) -> ExprNode: ... + def _set_folded_value(self, node: ExprNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/constant_folding.py similarity index 89% rename from vyper/semantics/analysis/pre_typecheck.py rename to vyper/semantics/analysis/constant_folding.py index 1c2a5392c3..b165a6dae9 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -1,11 +1,11 @@ from vyper import ast as vy_ast -from vyper.exceptions import InvalidLiteral, UnfoldableNode +from vyper.exceptions import InvalidLiteral, UnfoldableNode, VyperException from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.namespace import get_namespace -def pre_typecheck(module_ast: vy_ast.Module): +def constant_fold(module_ast: vy_ast.Module): ConstantFolder(module_ast).run() @@ -89,6 +89,33 @@ def visit_Name(self, node) -> vy_ast.ExprNode: except KeyError: raise UnfoldableNode("unknown name", node) + def visit_Attribute(self, node) -> vy_ast.ExprNode: + namespace = get_namespace() + path = [] + value = node.value + while isinstance(value, vy_ast.Attribute): + path.append(value.attr) + value = value.value + + path.reverse() + + if not isinstance(value, vy_ast.Name): + raise UnfoldableNode("not a module", value) + + # not super type-safe but we don't care. just catch AttributeErrors + # and move on + try: + module_t = namespace[value.id].module_t + + for module_name in path: + module_t = module_t.members[module_name].module_t + + varinfo = module_t.get_member(node.attr, node) + + return varinfo.decl_node.value.get_folded_value() + except (VyperException, AttributeError): + raise UnfoldableNode("not a module") + def visit_UnaryOp(self, node): operand = node.operand.get_folded_value() diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 2972ed2917..100819526b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -23,9 +23,9 @@ ) from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace @@ -51,8 +51,6 @@ def validate_semantics_r( """ validate_literal_nodes(module_ast) - pre_typecheck(module_ast) - # validate semantics and annotate AST with type/semantics information namespace = get_namespace() @@ -140,6 +138,9 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + # we can resolve constants after imports are handled. + constant_fold(self.ast) + # keep trying to process all the nodes until we finish or can # no longer progress. this makes it so we don't need to # calculate a dependency tree between top-level items.