From 6c220b02b4d5e5e84d98508cb7979453a4384e7b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:09:12 +0800 Subject: [PATCH] stricter replace rules in folding --- tests/ast/test_folding.py | 5 +++-- vyper/ast/folding.py | 21 ++++++++------------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/tests/ast/test_folding.py b/tests/ast/test_folding.py index aab0a7fae3..28e2e00fa6 100644 --- a/tests/ast/test_folding.py +++ b/tests/ast/test_folding.py @@ -4,6 +4,7 @@ from vyper.ast import folding from vyper.exceptions import OverflowException from vyper.semantics import validate_semantics +from vyper.semantics.types.shortcuts import UINT256_T def test_integration(): @@ -175,7 +176,7 @@ def test_replace_constant(source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) + folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True) assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) @@ -198,7 +199,7 @@ def test_replace_constant_no(source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) + folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True) assert vy_ast.compare_nodes(unmodified_ast, folded_ast) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 4f128aa465..aca96a4595 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Union from vyper.ast import nodes as vy_ast from vyper.builtins.functions import DISPATCH_TABLE @@ -161,9 +161,7 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: continue type_ = node._metadata["type"] - changed_nodes += replace_constant( - vyper_module, node.target.id, node.value, False, type_=type_ - ) + changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False) return changed_nodes @@ -171,18 +169,16 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: # TODO constant folding on log events -def _replace(old_node, new_node, type_=None): +def _replace(old_node, new_node, type_): if isinstance(new_node, vy_ast.Constant): new_node = new_node.from_node(old_node, value=new_node.value) - if type_ is not None: - new_node._metadata["type"] = type_ + new_node._metadata["type"] = type_ return new_node elif isinstance(new_node, vy_ast.List): base_type = type_.value_type if type_ else None list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements] new_node = new_node.from_node(old_node, elements=list_values) - if type_ is not None: - new_node._metadata["type"] = type_ + new_node._metadata["type"] = type_ return new_node elif isinstance(new_node, vy_ast.Call): # Replace `Name` node with `Call` node @@ -194,8 +190,7 @@ def _replace(old_node, new_node, type_=None): new_node = new_node.from_node( old_node, func=new_node.func, args=new_node.args, keyword=keyword, keywords=keywords ) - if type_ is not None: - new_node._metadata["type"] = type_ + new_node._metadata["type"] = type_ return new_node else: raise UnfoldableNode @@ -205,8 +200,8 @@ def replace_constant( vyper_module: vy_ast.Module, id_: str, replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], + type_: VyperType, raise_on_error: bool, - type_: Optional[VyperType] = None, ) -> int: """ Replace references to a variable name with a literal value. @@ -259,7 +254,7 @@ def replace_constant( try: # note: _replace creates a copy of the replacement_node - new_node = _replace(node, replacement_node, type_=type_) + new_node = _replace(node, replacement_node, type_) except UnfoldableNode: if raise_on_error: raise