diff --git a/tests/unit/ast/test_folding.py b/tests/unit/ast/test_folding.py index 6564da1c3b..8347fa90dd 100644 --- a/tests/unit/ast/test_folding.py +++ b/tests/unit/ast/test_folding.py @@ -6,7 +6,7 @@ from vyper.semantics import validate_semantics -def test_integration(): +def test_integration(dummy_input_bundle): test = """ @external def foo(): @@ -22,13 +22,13 @@ def foo(): test_ast = vy_ast.parse_to_ast(test) expected_ast = vy_ast.parse_to_ast(expected) - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) folding.fold(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_binop_simple(): +def test_replace_binop_simple(dummy_input_bundle): test = """ @external def foo(): @@ -44,13 +44,13 @@ def foo(): test_ast = vy_ast.parse_to_ast(test) expected_ast = vy_ast.parse_to_ast(expected) - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_literal_ops(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_binop_nested(): +def test_replace_binop_nested(dummy_input_bundle): test = """ @external def foo(): @@ -65,13 +65,13 @@ def foo(): test_ast = vy_ast.parse_to_ast(test) expected_ast = vy_ast.parse_to_ast(expected) - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_literal_ops(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_binop_nested_intermediate_overflow(): +def test_replace_binop_nested_intermediate_overflow(dummy_input_bundle): test = """ @external def foo(): @@ -79,10 +79,10 @@ def foo(): """ test_ast = vy_ast.parse_to_ast(test) with pytest.raises(OverflowException): - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_binop_nested_intermediate_underflow(): +def test_replace_binop_nested_intermediate_underflow(dummy_input_bundle): test = """ @external def foo(): @@ -90,10 +90,10 @@ def foo(): """ test_ast = vy_ast.parse_to_ast(test) with pytest.raises(InvalidType): - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_decimal_nested_intermediate_overflow(): +def test_replace_decimal_nested_intermediate_overflow(dummy_input_bundle): test = """ @external def foo(): @@ -101,10 +101,10 @@ def foo(): """ test_ast = vy_ast.parse_to_ast(test) with pytest.raises(OverflowException): - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_decimal_nested_intermediate_underflow(): +def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): test = """ @external def foo(): @@ -112,10 +112,10 @@ def foo(): """ test_ast = vy_ast.parse_to_ast(test) with pytest.raises(OverflowException): - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) -def test_replace_literal_ops(): +def test_replace_literal_ops(dummy_input_bundle): test = """ @external def foo(): @@ -130,13 +130,13 @@ def foo(): test_ast = vy_ast.parse_to_ast(test) expected_ast = vy_ast.parse_to_ast(expected) - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_literal_ops(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_subscripts_simple(): +def test_replace_subscripts_simple(dummy_input_bundle): test = """ @external def foo(): @@ -151,13 +151,13 @@ def foo(): test_ast = vy_ast.parse_to_ast(test) expected_ast = vy_ast.parse_to_ast(expected) - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_subscripts(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) -def test_replace_subscripts_nested(): +def test_replace_subscripts_nested(dummy_input_bundle): test = """ @external def foo(): @@ -172,7 +172,7 @@ def foo(): test_ast = vy_ast.parse_to_ast(test) expected_ast = vy_ast.parse_to_ast(expected) - validate_semantics(test_ast, {}) + validate_semantics(test_ast, dummy_input_bundle) folding.replace_subscripts(test_ast) assert vy_ast.compare_nodes(test_ast, expected_ast) @@ -240,11 +240,11 @@ def bar(x: DynArray[uint256, 4]): @pytest.mark.parametrize("source", constants_modified) -def test_replace_constant(source): +def test_replace_constant(dummy_input_bundle, source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - validate_semantics(folded_ast, {}) + validate_semantics(folded_ast, dummy_input_bundle) folding.replace_user_defined_constants(folded_ast) assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) @@ -321,11 +321,11 @@ def foo(): @pytest.mark.parametrize("source", constants_unmodified) -def test_replace_constant_no(source): +def test_replace_constant_no(dummy_input_bundle, source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - validate_semantics(folded_ast, {}) + validate_semantics(folded_ast, dummy_input_bundle) folding.replace_user_defined_constants(folded_ast) assert vy_ast.compare_nodes(unmodified_ast, folded_ast) @@ -367,13 +367,13 @@ def foo() -> int128: @pytest.mark.parametrize("source", userdefined_modified) -def test_replace_userdefined_constant(source): +def test_replace_userdefined_constant(dummy_input_bundle, source): source = f"FOO: constant(int128) = 42\n{source}" unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - validate_semantics(folded_ast, {}) + validate_semantics(folded_ast, dummy_input_bundle) folding.replace_user_defined_constants(folded_ast) assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) @@ -397,13 +397,13 @@ def foo(): @pytest.mark.parametrize("source", userdefined_attributes) -def test_replace_userdefined_attribute(source): +def test_replace_userdefined_attribute(dummy_input_bundle, source): preamble = f"ADDR: constant(address) = {dummy_address}" l_source = f"{preamble}\n{source[0]}" r_source = f"{preamble}\n{source[1]}" l_ast = vy_ast.parse_to_ast(l_source) - validate_semantics(l_ast, {}) + validate_semantics(l_ast, dummy_input_bundle) folding.replace_user_defined_constants(l_ast) r_ast = vy_ast.parse_to_ast(r_source) @@ -428,7 +428,7 @@ def foo(): @pytest.mark.parametrize("source", userdefined_struct) -def test_replace_userdefined_struct(source): +def test_replace_userdefined_struct(dummy_input_bundle, source): preamble = """ struct Foo: a: uint256 @@ -440,7 +440,7 @@ def test_replace_userdefined_struct(source): r_source = f"{preamble}\n{source[1]}" l_ast = vy_ast.parse_to_ast(l_source) - validate_semantics(l_ast, {}) + validate_semantics(l_ast, dummy_input_bundle) folding.replace_user_defined_constants(l_ast) r_ast = vy_ast.parse_to_ast(r_source) @@ -465,7 +465,7 @@ def foo(): @pytest.mark.parametrize("source", userdefined_nested_struct) -def test_replace_userdefined_nested_struct(source): +def test_replace_userdefined_nested_struct(dummy_input_bundle, source): preamble = """ struct Bar: b1: uint256 @@ -481,7 +481,7 @@ def test_replace_userdefined_nested_struct(source): r_source = f"{preamble}\n{source[1]}" l_ast = vy_ast.parse_to_ast(l_source) - validate_semantics(l_ast, {}) + validate_semantics(l_ast, dummy_input_bundle) folding.replace_user_defined_constants(l_ast) r_ast = vy_ast.parse_to_ast(r_source) @@ -515,11 +515,11 @@ def foo(bar: int256 = {}): @pytest.mark.parametrize("source", builtin_folding_sources) @pytest.mark.parametrize("original,result", builtin_folding_functions) -def test_replace_builtins(source, original, result): +def test_replace_builtins(dummy_input_bundle, source, original, result): original_ast = vy_ast.parse_to_ast(source.format(original)) target_ast = vy_ast.parse_to_ast(source.format(result)) - validate_semantics(original_ast, {}) + validate_semantics(original_ast, dummy_input_bundle) folding.replace_builtin_functions(original_ast) assert vy_ast.compare_nodes(original_ast, target_ast) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 0863dba174..51a58f0bb8 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -121,7 +121,8 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue - new_node._metadata["type"] = node._metadata["type"] + if "type" in node._metadata: + new_node._metadata["type"] = node._metadata["type"] changed_nodes += 1 vyper_module.replace_in_tree(node, new_node)