From 45a225c438918f21611d4c91de887052f020b002 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 3 Apr 2024 08:57:59 -0400 Subject: [PATCH] fix[lang]: allow type expressions inside pure functions (#3906) 20432c505c706ed introduced a regression where type expressions like the following would raise a compiler error instead of successfully compiling: ``` @pure def f(): convert(..., uint256) # raises `not a variable or literal: 'uint256'` ``` the reason is because `get_expr_info` is called on `uint256`, which is not a regular expr. this commit introduces a fastpath return to address the issue. longer-term, we should generalize the rules in `vyper/semantics/analysis/local.py` so that AST traversal does not progress into type expressions. --- .../codegen/features/decorators/test_pure.py | 13 +++++++++++++ vyper/semantics/analysis/local.py | 9 ++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/functional/codegen/features/decorators/test_pure.py b/tests/functional/codegen/features/decorators/test_pure.py index 7c49c2091b..fb081f62f8 100644 --- a/tests/functional/codegen/features/decorators/test_pure.py +++ b/tests/functional/codegen/features/decorators/test_pure.py @@ -175,6 +175,19 @@ def foo() -> uint256: compile_code(code) +def test_type_in_pure(get_contract): + code = """ +@pure +@external +def _convert(x: bytes32) -> uint256: + return convert(x, uint256) + """ + c = get_contract(code) + x = 123456 + bs = x.to_bytes(32, "big") + assert x == c._convert(bs) + + def test_invalid_conflicting_decorators(): code = """ @pure diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 1b2e3252c8..b0a6e38d10 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -165,7 +165,10 @@ def _validate_msg_value_access(node: vy_ast.Attribute) -> None: raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) -def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name) -> None: +def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + return + info = get_expr_info(node) env_vars = CONSTANT_ENVIRONMENT_VARS @@ -705,7 +708,7 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_value_access(node) if self.func and self.func.mutability == StateMutability.PURE: - _validate_pure_access(node) + _validate_pure_access(node, typ) value_type = get_exact_type_from_node(node.value) @@ -886,7 +889,7 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: if self.func and self.func.mutability == StateMutability.PURE: - _validate_pure_access(node) + _validate_pure_access(node, typ) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T):