Skip to content

Commit

Permalink
feat: add support for constants in imported modules (vyperlang#3726)
Browse files Browse the repository at this point in the history
- add a case to ConstantFolder, and move the constant folder slightly
  down in the pipeline to after imports have been resolved.
- rename `pre_typecheck` to `constant_fold` (since that is all it does,
  and it doesn't strictly happen before typechecking anymore).
  • Loading branch information
charles-cooper authored Jan 13, 2024
1 parent 785f09d commit 3f013ec
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 10 deletions.
78 changes: 78 additions & 0 deletions tests/functional/codegen/modules/test_module_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
def test_module_constant(make_input_bundle, get_contract):
mod1 = """
X: constant(uint256) = 12345
"""
contract = """
import mod1
@external
def foo() -> uint256:
return mod1.X
"""

input_bundle = make_input_bundle({"mod1.vy": mod1})

c = get_contract(contract, input_bundle=input_bundle)

assert c.foo() == 12345


def test_nested_module_constant(make_input_bundle, get_contract):
# test nested module constants
# test at least 3 modules deep to test the `path.reverse()` gizmo
# in ConstantFolder.visit_Attribute()
mod1 = """
X: constant(uint256) = 12345
"""
mod2 = """
import mod1
X: constant(uint256) = 54321
"""
mod3 = """
import mod2
X: constant(uint256) = 98765
"""

contract = """
import mod1
import mod2
import mod3
@external
def test_foo() -> bool:
assert mod1.X == 12345
assert mod2.X == 54321
assert mod3.X == 98765
assert mod2.mod1.X == mod1.X
assert mod3.mod2.mod1.X == mod1.X
assert mod3.mod2.X == mod2.X
return True
"""

input_bundle = make_input_bundle({"mod1.vy": mod1, "mod2.vy": mod2, "mod3.vy": mod3})

c = get_contract(contract, input_bundle=input_bundle)
assert c.test_foo() is True


def test_import_constant_array(make_input_bundle, get_contract, tx_failed):
mod1 = """
X: constant(uint256[3]) = [1,2,3]
"""
contract = """
import mod1
@external
def foo(ix: uint256) -> uint256:
return mod1.X[ix]
"""

input_bundle = make_input_bundle({"mod1.vy": mod1})

c = get_contract(contract, input_bundle=input_bundle)

assert c.foo(0) == 1
assert c.foo(1) == 2
assert c.foo(2) == 3
with tx_failed():
c.foo(3)
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from vyper import ast as vy_ast
from vyper.semantics.analysis.pre_typecheck import pre_typecheck
from vyper.semantics.analysis.constant_folding import constant_fold


@contextlib.contextmanager
Expand All @@ -17,5 +17,5 @@ def working_directory(directory):

def parse_and_fold(source_code):
ast = vy_ast.parse_to_ast(source_code)
pre_typecheck(ast)
constant_fold(ast)
return ast
2 changes: 1 addition & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def has_folded_value(self):
"""
return "folded_value" in self._metadata

def get_folded_value(self) -> "VyperNode":
def get_folded_value(self) -> "ExprNode":
"""
Attempt to get the folded value, bubbling up UnfoldableNode if the node
is not foldable.
Expand Down
4 changes: 2 additions & 2 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class VyperNode:
def has_folded_value(self): ...
@classmethod
def get_fields(cls: Any) -> set: ...
def get_folded_value(self) -> VyperNode: ...
def _set_folded_value(self, node: VyperNode) -> None: ...
def get_folded_value(self) -> ExprNode: ...
def _set_folded_value(self, node: ExprNode) -> None: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
def to_dict(self) -> dict: ...
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from vyper import ast as vy_ast
from vyper.exceptions import InvalidLiteral, UnfoldableNode
from vyper.exceptions import InvalidLiteral, UnfoldableNode, VyperException
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.namespace import get_namespace


def pre_typecheck(module_ast: vy_ast.Module):
def constant_fold(module_ast: vy_ast.Module):
ConstantFolder(module_ast).run()


Expand Down Expand Up @@ -89,6 +89,33 @@ def visit_Name(self, node) -> vy_ast.ExprNode:
except KeyError:
raise UnfoldableNode("unknown name", node)

def visit_Attribute(self, node) -> vy_ast.ExprNode:
namespace = get_namespace()
path = []
value = node.value
while isinstance(value, vy_ast.Attribute):
path.append(value.attr)
value = value.value

path.reverse()

if not isinstance(value, vy_ast.Name):
raise UnfoldableNode("not a module", value)

# not super type-safe but we don't care. just catch AttributeErrors
# and move on
try:
module_t = namespace[value.id].module_t

for module_name in path:
module_t = module_t.members[module_name].module_t

varinfo = module_t.get_member(node.attr, node)

return varinfo.decl_node.value.get_folded_value()
except (VyperException, AttributeError):
raise UnfoldableNode("not a module")

def visit_UnaryOp(self, node):
operand = node.operand.get_folded_value()

Expand Down
7 changes: 4 additions & 3 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
)
from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.constant_folding import constant_fold
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.local import ExprVisitor, validate_functions
from vyper.semantics.analysis.pre_typecheck import pre_typecheck
from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace
Expand All @@ -51,8 +51,6 @@ def validate_semantics_r(
"""
validate_literal_nodes(module_ast)

pre_typecheck(module_ast)

# validate semantics and annotate AST with type/semantics information
namespace = get_namespace()

Expand Down Expand Up @@ -140,6 +138,9 @@ def analyze(self) -> ModuleT:
self.visit(node)
to_visit.remove(node)

# we can resolve constants after imports are handled.
constant_fold(self.ast)

# keep trying to process all the nodes until we finish or can
# no longer progress. this makes it so we don't need to
# calculate a dependency tree between top-level items.
Expand Down

0 comments on commit 3f013ec

Please sign in to comment.