Skip to content

Commit

Permalink
stricter replace rules in folding
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Oct 23, 2023
1 parent d508d3f commit 6c220b0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
5 changes: 3 additions & 2 deletions tests/ast/test_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from vyper.ast import folding
from vyper.exceptions import OverflowException
from vyper.semantics import validate_semantics
from vyper.semantics.types.shortcuts import UINT256_T


def test_integration():
Expand Down Expand Up @@ -175,7 +176,7 @@ def test_replace_constant(source):
unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True)
folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True)

assert not vy_ast.compare_nodes(unmodified_ast, folded_ast)

Expand All @@ -198,7 +199,7 @@ def test_replace_constant_no(source):
unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True)
folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True)

assert vy_ast.compare_nodes(unmodified_ast, folded_ast)

Expand Down
21 changes: 8 additions & 13 deletions vyper/ast/folding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union

from vyper.ast import nodes as vy_ast
from vyper.builtins.functions import DISPATCH_TABLE
Expand Down Expand Up @@ -161,28 +161,24 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int:
continue

type_ = node._metadata["type"]
changed_nodes += replace_constant(
vyper_module, node.target.id, node.value, False, type_=type_
)
changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False)

return changed_nodes


# TODO constant folding on log events


def _replace(old_node, new_node, type_=None):
def _replace(old_node, new_node, type_):
if isinstance(new_node, vy_ast.Constant):
new_node = new_node.from_node(old_node, value=new_node.value)
if type_ is not None:
new_node._metadata["type"] = type_
new_node._metadata["type"] = type_
return new_node
elif isinstance(new_node, vy_ast.List):
base_type = type_.value_type if type_ else None
list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements]
new_node = new_node.from_node(old_node, elements=list_values)
if type_ is not None:
new_node._metadata["type"] = type_
new_node._metadata["type"] = type_
return new_node
elif isinstance(new_node, vy_ast.Call):
# Replace `Name` node with `Call` node
Expand All @@ -194,8 +190,7 @@ def _replace(old_node, new_node, type_=None):
new_node = new_node.from_node(
old_node, func=new_node.func, args=new_node.args, keyword=keyword, keywords=keywords
)
if type_ is not None:
new_node._metadata["type"] = type_
new_node._metadata["type"] = type_
return new_node
else:
raise UnfoldableNode
Expand All @@ -205,8 +200,8 @@ def replace_constant(
vyper_module: vy_ast.Module,
id_: str,
replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call],
type_: VyperType,
raise_on_error: bool,
type_: Optional[VyperType] = None,
) -> int:
"""
Replace references to a variable name with a literal value.
Expand Down Expand Up @@ -259,7 +254,7 @@ def replace_constant(

try:
# note: _replace creates a copy of the replacement_node
new_node = _replace(node, replacement_node, type_=type_)
new_node = _replace(node, replacement_node, type_)
except UnfoldableNode:
if raise_on_error:
raise
Expand Down

0 comments on commit 6c220b0

Please sign in to comment.