Skip to content

Commit

Permalink
fix tests; relax type propagation in folding builtins
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 21, 2023
1 parent 0f2302b commit 4256dca
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 35 deletions.
68 changes: 34 additions & 34 deletions tests/unit/ast/test_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vyper.semantics import validate_semantics


def test_integration():
def test_integration(dummy_input_bundle):
test = """
@external
def foo():
Expand All @@ -22,13 +22,13 @@ def foo():
test_ast = vy_ast.parse_to_ast(test)
expected_ast = vy_ast.parse_to_ast(expected)

validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)
folding.fold(test_ast)

assert vy_ast.compare_nodes(test_ast, expected_ast)


def test_replace_binop_simple():
def test_replace_binop_simple(dummy_input_bundle):
test = """
@external
def foo():
Expand All @@ -44,13 +44,13 @@ def foo():
test_ast = vy_ast.parse_to_ast(test)
expected_ast = vy_ast.parse_to_ast(expected)

validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)
folding.replace_literal_ops(test_ast)

assert vy_ast.compare_nodes(test_ast, expected_ast)


def test_replace_binop_nested():
def test_replace_binop_nested(dummy_input_bundle):
test = """
@external
def foo():
Expand All @@ -65,57 +65,57 @@ def foo():
test_ast = vy_ast.parse_to_ast(test)
expected_ast = vy_ast.parse_to_ast(expected)

validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)
folding.replace_literal_ops(test_ast)

assert vy_ast.compare_nodes(test_ast, expected_ast)


def test_replace_binop_nested_intermediate_overflow():
def test_replace_binop_nested_intermediate_overflow(dummy_input_bundle):
test = """
@external
def foo():
a: uint256 = 2**255 * 2 / 10
"""
test_ast = vy_ast.parse_to_ast(test)
with pytest.raises(OverflowException):
validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)


def test_replace_binop_nested_intermediate_underflow():
def test_replace_binop_nested_intermediate_underflow(dummy_input_bundle):
test = """
@external
def foo():
a: int256 = -2**255 * 2 - 10 + 100
"""
test_ast = vy_ast.parse_to_ast(test)
with pytest.raises(InvalidType):
validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)


def test_replace_decimal_nested_intermediate_overflow():
def test_replace_decimal_nested_intermediate_overflow(dummy_input_bundle):
test = """
@external
def foo():
a: decimal = 18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10
"""
test_ast = vy_ast.parse_to_ast(test)
with pytest.raises(OverflowException):
validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)


def test_replace_decimal_nested_intermediate_underflow():
def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle):
test = """
@external
def foo():
a: decimal = -18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10
"""
test_ast = vy_ast.parse_to_ast(test)
with pytest.raises(OverflowException):
validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)


def test_replace_literal_ops():
def test_replace_literal_ops(dummy_input_bundle):
test = """
@external
def foo():
Expand All @@ -130,13 +130,13 @@ def foo():
test_ast = vy_ast.parse_to_ast(test)
expected_ast = vy_ast.parse_to_ast(expected)

validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)
folding.replace_literal_ops(test_ast)

assert vy_ast.compare_nodes(test_ast, expected_ast)


def test_replace_subscripts_simple():
def test_replace_subscripts_simple(dummy_input_bundle):
test = """
@external
def foo():
Expand All @@ -151,13 +151,13 @@ def foo():
test_ast = vy_ast.parse_to_ast(test)
expected_ast = vy_ast.parse_to_ast(expected)

validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)
folding.replace_subscripts(test_ast)

assert vy_ast.compare_nodes(test_ast, expected_ast)


def test_replace_subscripts_nested():
def test_replace_subscripts_nested(dummy_input_bundle):
test = """
@external
def foo():
Expand All @@ -172,7 +172,7 @@ def foo():
test_ast = vy_ast.parse_to_ast(test)
expected_ast = vy_ast.parse_to_ast(expected)

validate_semantics(test_ast, {})
validate_semantics(test_ast, dummy_input_bundle)
folding.replace_subscripts(test_ast)

assert vy_ast.compare_nodes(test_ast, expected_ast)
Expand Down Expand Up @@ -240,11 +240,11 @@ def bar(x: DynArray[uint256, 4]):


@pytest.mark.parametrize("source", constants_modified)
def test_replace_constant(source):
def test_replace_constant(dummy_input_bundle, source):
unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

validate_semantics(folded_ast, {})
validate_semantics(folded_ast, dummy_input_bundle)
folding.replace_user_defined_constants(folded_ast)

assert not vy_ast.compare_nodes(unmodified_ast, folded_ast)
Expand Down Expand Up @@ -321,11 +321,11 @@ def foo():


@pytest.mark.parametrize("source", constants_unmodified)
def test_replace_constant_no(source):
def test_replace_constant_no(dummy_input_bundle, source):
unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

validate_semantics(folded_ast, {})
validate_semantics(folded_ast, dummy_input_bundle)
folding.replace_user_defined_constants(folded_ast)

assert vy_ast.compare_nodes(unmodified_ast, folded_ast)
Expand Down Expand Up @@ -367,13 +367,13 @@ def foo() -> int128:


@pytest.mark.parametrize("source", userdefined_modified)
def test_replace_userdefined_constant(source):
def test_replace_userdefined_constant(dummy_input_bundle, source):
source = f"FOO: constant(int128) = 42\n{source}"

unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

validate_semantics(folded_ast, {})
validate_semantics(folded_ast, dummy_input_bundle)
folding.replace_user_defined_constants(folded_ast)

assert not vy_ast.compare_nodes(unmodified_ast, folded_ast)
Expand All @@ -397,13 +397,13 @@ def foo():


@pytest.mark.parametrize("source", userdefined_attributes)
def test_replace_userdefined_attribute(source):
def test_replace_userdefined_attribute(dummy_input_bundle, source):
preamble = f"ADDR: constant(address) = {dummy_address}"
l_source = f"{preamble}\n{source[0]}"
r_source = f"{preamble}\n{source[1]}"

l_ast = vy_ast.parse_to_ast(l_source)
validate_semantics(l_ast, {})
validate_semantics(l_ast, dummy_input_bundle)
folding.replace_user_defined_constants(l_ast)

r_ast = vy_ast.parse_to_ast(r_source)
Expand All @@ -428,7 +428,7 @@ def foo():


@pytest.mark.parametrize("source", userdefined_struct)
def test_replace_userdefined_struct(source):
def test_replace_userdefined_struct(dummy_input_bundle, source):
preamble = """
struct Foo:
a: uint256
Expand All @@ -440,7 +440,7 @@ def test_replace_userdefined_struct(source):
r_source = f"{preamble}\n{source[1]}"

l_ast = vy_ast.parse_to_ast(l_source)
validate_semantics(l_ast, {})
validate_semantics(l_ast, dummy_input_bundle)
folding.replace_user_defined_constants(l_ast)

r_ast = vy_ast.parse_to_ast(r_source)
Expand All @@ -465,7 +465,7 @@ def foo():


@pytest.mark.parametrize("source", userdefined_nested_struct)
def test_replace_userdefined_nested_struct(source):
def test_replace_userdefined_nested_struct(dummy_input_bundle, source):
preamble = """
struct Bar:
b1: uint256
Expand All @@ -481,7 +481,7 @@ def test_replace_userdefined_nested_struct(source):
r_source = f"{preamble}\n{source[1]}"

l_ast = vy_ast.parse_to_ast(l_source)
validate_semantics(l_ast, {})
validate_semantics(l_ast, dummy_input_bundle)
folding.replace_user_defined_constants(l_ast)

r_ast = vy_ast.parse_to_ast(r_source)
Expand Down Expand Up @@ -515,11 +515,11 @@ def foo(bar: int256 = {}):

@pytest.mark.parametrize("source", builtin_folding_sources)
@pytest.mark.parametrize("original,result", builtin_folding_functions)
def test_replace_builtins(source, original, result):
def test_replace_builtins(dummy_input_bundle, source, original, result):
original_ast = vy_ast.parse_to_ast(source.format(original))
target_ast = vy_ast.parse_to_ast(source.format(result))

validate_semantics(original_ast, {})
validate_semantics(original_ast, dummy_input_bundle)
folding.replace_builtin_functions(original_ast)

assert vy_ast.compare_nodes(original_ast, target_ast)
3 changes: 2 additions & 1 deletion vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int:
except UnfoldableNode:
continue

new_node._metadata["type"] = node._metadata["type"]
if "type" in node._metadata:
new_node._metadata["type"] = node._metadata["type"]

changed_nodes += 1
vyper_module.replace_in_tree(node, new_node)
Expand Down

0 comments on commit 4256dca

Please sign in to comment.