diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 25dc1f1a1e..fcf71f12f0 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -3,7 +3,13 @@ import pytest -from vyper.exceptions import DecimalOverrideException, InvalidOperation, TypeMismatch +from vyper import compile_code +from vyper.exceptions import ( + DecimalOverrideException, + InvalidOperation, + OverflowException, + TypeMismatch, +) from vyper.utils import DECIMAL_EPSILON, SizeLimits @@ -24,23 +30,25 @@ def test_decimal_override(): @pytest.mark.parametrize("op", ["**", "&", "|", "^"]) -def test_invalid_ops(get_contract, assert_compile_failed, op): +def test_invalid_ops(op): code = f""" @external def foo(x: decimal, y: decimal) -> decimal: return x {op} y """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) @pytest.mark.parametrize("op", ["not"]) -def test_invalid_unary_ops(get_contract, assert_compile_failed, op): +def test_invalid_unary_ops(op): code = f""" @external def foo(x: decimal) -> decimal: return {op} x """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) def quantize(x: Decimal) -> Decimal: @@ -263,11 +271,32 @@ def bar(num: decimal) -> decimal: assert c.bar(Decimal("1e37")) == Decimal("-9e37") # Math lines up -def test_exponents(assert_compile_failed, get_contract): +def test_exponents(): code = """ @external def foo() -> decimal: return 2.2 ** 2.0 """ - assert_compile_failed(lambda: get_contract(code), TypeMismatch) + with pytest.raises(TypeMismatch): + compile_code(code) + + +def test_decimal_nested_intermediate_overflow(): + code = """ +@external +def foo(): + a: decimal = 18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10 + """ + with pytest.raises(OverflowException): + compile_code(code) + + +def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): + code = """ +@external +def foo(): + a: decimal = -18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10 + """ + with pytest.raises(OverflowException): + compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 52de5b649f..46fdfcafac 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -4,6 +4,7 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException from vyper.semantics.types import IntegerT from vyper.utils import evm_div, evm_mod @@ -206,17 +207,16 @@ def _num_min() -> {typ}: @pytest.mark.parametrize("typ", types) -def test_overflow_out_of_range(get_contract, assert_compile_failed, typ): +def test_overflow_out_of_range(get_contract, typ): code = f""" @external def num_sub() -> {typ}: return 1-2**{typ.bits} """ - if typ.bits == 256: - assert_compile_failed(lambda: get_contract(code), OverflowException) - else: - assert_compile_failed(lambda: get_contract(code), InvalidType) + exc = OverflowException if typ.bits == 256 else InvalidType + with pytest.raises(exc): + compile_code(code) ARITHMETIC_OPS = { @@ -231,7 +231,7 @@ def num_sub() -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ): # both variables code_1 = f""" @external @@ -318,10 +318,12 @@ def foo() -> {typ}: elif div_by_zero: with tx_failed(): c.foo(x, y) - assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) + with pytest.raises(ZeroDivisionException): + compile_code(code_2) with tx_failed(): get_contract(code_3).foo(y) - assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) + with pytest.raises(ZeroDivisionException): + compile_code(code_4) else: with tx_failed(): c.foo(x, y) @@ -329,9 +331,8 @@ def foo() -> {typ}: get_contract(code_2).foo(x) with tx_failed(): get_contract(code_3).foo(y) - assert_compile_failed( - lambda code=code_4: get_contract(code), (InvalidType, OverflowException) - ) + with pytest.raises((InvalidType, OverflowException)): + compile_code(code_4) COMPARISON_OPS = { @@ -413,10 +414,21 @@ def foo(a: {typ}) -> {typ}: @pytest.mark.parametrize("typ", types) @pytest.mark.parametrize("op", ["not"]) -def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): +def test_invalid_unary_ops(typ, op): code = f""" @external def foo(a: {typ}) -> {typ}: return {op} a """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) + + +def test_binop_nested_intermediate_underflow(): + code = """ +@external +def foo(): + a: int256 = -2**255 * 2 - 10 + 100 + """ + with pytest.raises(InvalidType): + compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 8982065b5d..ed489c634d 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -4,6 +4,7 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException from vyper.semantics.types import IntegerT from vyper.utils import evm_div, evm_mod @@ -85,7 +86,7 @@ def foo(x: {typ}) -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ): # both variables code_1 = f""" @external @@ -204,7 +205,7 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_uint_literal(get_contract, assert_compile_failed, typ): +def test_uint_literal(get_contract, typ): lo, hi = typ.ast_bounds good_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 1, hi] @@ -221,7 +222,8 @@ def test() -> {typ}: assert c.test() == val for val in bad_cases: - assert_compile_failed(lambda v=val: get_contract(code_template.format(typ=typ, val=v))) + with pytest.raises(): + compile_code(code_template.format(typ=typ, val=val)) @pytest.mark.parametrize("typ", types) @@ -232,4 +234,15 @@ def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): def foo(a: {typ}) -> {typ}: return {op} a """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) + + +def test_binop_nested_intermediate_overflow(): + code = """ +@external +def foo(): + a: uint256 = 2**255 * 2 / 10 + """ + with pytest.raises(OverflowException): + compile_code(code) diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index 0a586e1704..b2a6d6be54 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -74,8 +74,8 @@ def foo({input_value}) -> decimal: literal_op = literal_op.rsplit(maxsplit=1)[0] vyper_ast = vy_ast.parse_to_ast(literal_op) try: - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.fold() + expected = new_node.value is_valid = -(2**127) <= expected < 2**127 except (OverflowException, ZeroDivisionException): # for overflow or division/modulus by 0, expect the contract call to revert diff --git a/tests/unit/ast/nodes/test_fold_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py index c603daee46..407da7a720 100644 --- a/tests/unit/ast/nodes/test_fold_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -115,8 +115,8 @@ def foo({input_value}) -> int128: vyper_ast = vy_ast.parse_to_ast(literal_op) try: - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.fold() + expected = new_node.value is_valid = True except ZeroDivisionException: is_valid = False diff --git a/tests/unit/ast/nodes/test_fold_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py index 5de4b60bda..a496f84b18 100644 --- a/tests/unit/ast/nodes/test_fold_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -53,7 +53,7 @@ def foo({input_value}) -> bool: literal_op = literal_op.rsplit(maxsplit=1)[0] vyper_ast = vy_ast.parse_to_ast(literal_op) - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.fold() + expected = new_node.value assert contract.foo(*values) == expected diff --git a/tests/unit/ast/nodes/test_fold_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py index dc447955ed..e40496e3d2 100644 --- a/tests/unit/ast/nodes/test_fold_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -31,7 +31,7 @@ def foo(a: bool) -> bool: literal_op = f"{'not ' * count}{bool_cond}" vyper_ast = vy_ast.parse_to_ast(literal_op) - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.fold() + expected = new_node.value assert contract.foo(bool_cond) == expected diff --git a/tests/unit/ast/test_folding.py b/tests/unit/ast/test_folding.py deleted file mode 100644 index 8347fa90dd..0000000000 --- a/tests/unit/ast/test_folding.py +++ /dev/null @@ -1,525 +0,0 @@ -import pytest - -from vyper import ast as vy_ast -from vyper.ast import folding -from vyper.exceptions import InvalidType, OverflowException -from vyper.semantics import validate_semantics - - -def test_integration(dummy_input_bundle): - test = """ -@external -def foo(): - a: uint256 = [1+2, 6+7][8-8] - """ - - expected = """ -@external -def foo(): - a: uint256 = 3 - """ - - test_ast = vy_ast.parse_to_ast(test) - expected_ast = vy_ast.parse_to_ast(expected) - - validate_semantics(test_ast, dummy_input_bundle) - folding.fold(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_simple(dummy_input_bundle): - test = """ -@external -def foo(): - a: uint256 = 1 + 2 - """ - - expected = """ -@external -def foo(): - a: uint256 = 3 - """ - - test_ast = vy_ast.parse_to_ast(test) - expected_ast = vy_ast.parse_to_ast(expected) - - validate_semantics(test_ast, dummy_input_bundle) - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_nested(dummy_input_bundle): - test = """ -@external -def foo(): - a: uint256 = ((6 + (2**4)) * 4) / 2 - """ - - expected = """ -@external -def foo(): - a: uint256 = 44 - """ - test_ast = vy_ast.parse_to_ast(test) - expected_ast = vy_ast.parse_to_ast(expected) - - validate_semantics(test_ast, dummy_input_bundle) - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_nested_intermediate_overflow(dummy_input_bundle): - test = """ -@external -def foo(): - a: uint256 = 2**255 * 2 / 10 - """ - test_ast = vy_ast.parse_to_ast(test) - with pytest.raises(OverflowException): - validate_semantics(test_ast, dummy_input_bundle) - - -def test_replace_binop_nested_intermediate_underflow(dummy_input_bundle): - test = """ -@external -def foo(): - a: int256 = -2**255 * 2 - 10 + 100 - """ - test_ast = vy_ast.parse_to_ast(test) - with pytest.raises(InvalidType): - validate_semantics(test_ast, dummy_input_bundle) - - -def test_replace_decimal_nested_intermediate_overflow(dummy_input_bundle): - test = """ -@external -def foo(): - a: decimal = 18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10 - """ - test_ast = vy_ast.parse_to_ast(test) - with pytest.raises(OverflowException): - validate_semantics(test_ast, dummy_input_bundle) - - -def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): - test = """ -@external -def foo(): - a: decimal = -18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10 - """ - test_ast = vy_ast.parse_to_ast(test) - with pytest.raises(OverflowException): - validate_semantics(test_ast, dummy_input_bundle) - - -def test_replace_literal_ops(dummy_input_bundle): - test = """ -@external -def foo(): - a: bool[3] = [not True, True and False, True or False] - """ - - expected = """ -@external -def foo(): - a: bool[3] = [False, False, True] - """ - test_ast = vy_ast.parse_to_ast(test) - expected_ast = vy_ast.parse_to_ast(expected) - - validate_semantics(test_ast, dummy_input_bundle) - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_subscripts_simple(dummy_input_bundle): - test = """ -@external -def foo(): - a: uint256 = [1, 2, 3][1] - """ - - expected = """ -@external -def foo(): - a: uint256 = 2 - """ - test_ast = vy_ast.parse_to_ast(test) - expected_ast = vy_ast.parse_to_ast(expected) - - validate_semantics(test_ast, dummy_input_bundle) - folding.replace_subscripts(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_subscripts_nested(dummy_input_bundle): - test = """ -@external -def foo(): - a: uint256 = [[0, 1], [2, 3], [4, 5]][2][1] - """ - - expected = """ -@external -def foo(): - a: uint256 = 5 - """ - test_ast = vy_ast.parse_to_ast(test) - expected_ast = vy_ast.parse_to_ast(expected) - - validate_semantics(test_ast, dummy_input_bundle) - folding.replace_subscripts(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -constants_modified = [ - """ -FOO: constant(uint256) = 4 - -@external -def foo(): - bar: uint256 = 1 - bar = FOO - """, - """ -FOO: constant(uint256) = 4 -bar: int128[FOO] - """, - """ -FOO: constant(uint256) = 4 - -@external -def foo(): - a: uint256[3] = [1, 2, FOO] - """, - """ -FOO: constant(uint256) = 4 -@external -def bar(a: uint256 = FOO): - pass - """, - """ -FOO: constant(uint256) = 4 - -event bar: - a: uint256 - -@external -def foo(): - log bar(FOO) - """, - """ -FOO: constant(uint256) = 4 - -@external -def foo(): - a: uint256 = FOO + 1 - """, - """ -FOO: constant(uint256) = 4 - -@external -def foo(): - a: int128[FOO / 2] = [1, 2] - """, - """ -FOO: constant(uint256) = 4 - -@external -def bar(x: DynArray[uint256, 4]): - a: DynArray[uint256, 4] = x - a[FOO - 1] = 44 - """, -] - - -@pytest.mark.parametrize("source", constants_modified) -def test_replace_constant(dummy_input_bundle, source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - validate_semantics(folded_ast, dummy_input_bundle) - folding.replace_user_defined_constants(folded_ast) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -constants_unmodified = [ - """ -FOO: immutable(uint256) - -@external -def __init__(): - FOO = 42 - """, - """ -FOO: uint256 - -@external -def foo(): - self.FOO = 42 - """, - """ -bar: uint256 - -@internal -def FOO() -> uint256: - return 123 - -@external -def foo(): - bar: uint256 = 456 - bar = self.FOO() - """, - """ -@internal -def FOO(): - pass - -@external -def foo(): - self.FOO() - """, - """ -FOO: uint256 - -@external -def foo(): - bar: uint256 = 1 - bar = self.FOO - """, - """ -event FOO: - a: uint256 - -@external -def foo(bar: uint256): - log FOO(bar) - """, - """ -@internal -def FOO() -> uint256: - return 3 - -@external -def foo(): - a: uint256[3] = [1, 2, self.FOO()] - """, - """ -@external -def foo(): - FOO: DynArray[uint256, 5] = [1, 2, 3, 4, 5] - FOO[4] = 2 - """, -] - - -@pytest.mark.parametrize("source", constants_unmodified) -def test_replace_constant_no(dummy_input_bundle, source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - validate_semantics(folded_ast, dummy_input_bundle) - folding.replace_user_defined_constants(folded_ast) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -userdefined_modified = [ - """ -@external -def foo(): - foo: int128 = FOO - """, - """ -@external -def foo(): - foo: DynArray[int128, FOO] = [] - """, - """ -@external -def foo(): - foo: int128[1] = [FOO] - """, - """ -@external -def foo(): - foo: int128 = 3 - foo += FOO - """, - """ -@external -def foo(bar: int128 = FOO): - pass - """, - """ -@external -def foo() -> int128: - return FOO - """, -] - - -@pytest.mark.parametrize("source", userdefined_modified) -def test_replace_userdefined_constant(dummy_input_bundle, source): - source = f"FOO: constant(int128) = 42\n{source}" - - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - validate_semantics(folded_ast, dummy_input_bundle) - folding.replace_user_defined_constants(folded_ast) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -dummy_address = "0x000000000000000000000000000000000000dEaD" -userdefined_attributes = [ - ( - """ -@external -def foo(): - b: uint256 = ADDR.balance - """, - f""" -@external -def foo(): - b: uint256 = {dummy_address}.balance - """, - ) -] - - -@pytest.mark.parametrize("source", userdefined_attributes) -def test_replace_userdefined_attribute(dummy_input_bundle, source): - preamble = f"ADDR: constant(address) = {dummy_address}" - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - validate_semantics(l_ast, dummy_input_bundle) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -userdefined_struct = [ - ( - """ -@external -def foo(): - b: Foo = FOO - """, - """ -@external -def foo(): - b: Foo = Foo({a: 123, b: 456}) - """, - ) -] - - -@pytest.mark.parametrize("source", userdefined_struct) -def test_replace_userdefined_struct(dummy_input_bundle, source): - preamble = """ -struct Foo: - a: uint256 - b: uint256 - -FOO: constant(Foo) = Foo({a: 123, b: 456}) - """ - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - validate_semantics(l_ast, dummy_input_bundle) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -userdefined_nested_struct = [ - ( - """ -@external -def foo(): - b: Foo = FOO - """, - """ -@external -def foo(): - b: Foo = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789}) - """, - ) -] - - -@pytest.mark.parametrize("source", userdefined_nested_struct) -def test_replace_userdefined_nested_struct(dummy_input_bundle, source): - preamble = """ -struct Bar: - b1: uint256 - b2: uint256 - -struct Foo: - f1: Bar - f2: uint256 - -FOO: constant(Foo) = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789}) - """ - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - validate_semantics(l_ast, dummy_input_bundle) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -builtin_folding_functions = [("ceil(4.2)", "5"), ("floor(4.2)", "4")] - -builtin_folding_sources = [ - """ -@external -def foo(): - foo: int256 = {} - """, - """ -foo: constant(int256[2]) = [{0}, {0}] - """, - """ -@external -def foo() -> int256: - return {} - """, - """ -@external -def foo(bar: int256 = {}): - pass - """, -] - - -@pytest.mark.parametrize("source", builtin_folding_sources) -@pytest.mark.parametrize("original,result", builtin_folding_functions) -def test_replace_builtins(dummy_input_bundle, source, original, result): - original_ast = vy_ast.parse_to_ast(source.format(original)) - target_ast = vy_ast.parse_to_ast(source.format(result)) - - validate_semantics(original_ast, dummy_input_bundle) - folding.replace_builtin_functions(original_ast) - - assert vy_ast.compare_nodes(original_ast, target_ast) diff --git a/tests/unit/ast/test_natspec.py b/tests/unit/ast/test_natspec.py index c2133468aa..5207a0ce76 100644 --- a/tests/unit/ast/test_natspec.py +++ b/tests/unit/ast/test_natspec.py @@ -60,7 +60,7 @@ def doesEat(food: String[30], qty: uint256) -> bool: def parse_natspec(code): - vyper_ast = CompilerData(code).vyper_module_folded + vyper_ast = CompilerData(code).vyper_module_annotated return vy_ast.parse_natspec(vyper_ast) diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index 4b46801153..bc08626b59 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -17,4 +17,4 @@ # required to avoid circular dependency -from . import expansion, folding # noqa: E402 +from . import expansion # noqa: E402 diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index eac8ffdef5..5581e82fe2 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -1,7 +1,7 @@ import ast as python_ast from typing import Any, Optional, Union -from . import expansion, folding, nodes, validation +from . import expansion, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * from .parse import parse_to_ast as parse_to_ast diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py deleted file mode 100644 index 3eb3e163b1..0000000000 --- a/vyper/ast/folding.py +++ /dev/null @@ -1,260 +0,0 @@ -from typing import Union - -from vyper.ast import nodes as vy_ast -from vyper.builtins.functions import DISPATCH_TABLE -from vyper.exceptions import UnfoldableNode -from vyper.semantics.types.base import VyperType - - -def fold(vyper_module: vy_ast.Module) -> None: - """ - Perform literal folding operations on a Vyper AST. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - changed_nodes = 1 - while changed_nodes: - changed_nodes = 0 - changed_nodes += replace_user_defined_constants(vyper_module) - changed_nodes += replace_literal_ops(vyper_module) - changed_nodes += replace_subscripts(vyper_module) - changed_nodes += replace_builtin_functions(vyper_module) - - -def replace_literal_ops(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate operation and comparison nodes within the Vyper AST, - replacing them with Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - node_types = (vy_ast.BoolOp, vy_ast.BinOp, vy_ast.UnaryOp, vy_ast.Compare) - for node in vyper_module.get_descendants(node_types, reverse=True): - try: - new_node = node.fold() - except UnfoldableNode: - continue - - # type may not be available if it is within a type's annotation - # e.g. DynArray[uint256, 2 ** 8] - typ = node._metadata.get("type") - if typ: - new_node._metadata["type"] = node._metadata["type"] - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_subscripts(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate Subscript nodes within the Vyper AST, replacing them with - Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Subscript, reverse=True): - try: - new_node = node.fold() - except UnfoldableNode: - continue - - new_node._metadata["type"] = node._metadata["type"] - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate builtin function calls within the Vyper AST, replacing - them with Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Call, reverse=True): - if not isinstance(node.func, vy_ast.Name): - continue - - name = node.func.id - func = DISPATCH_TABLE.get(name) - if func is None or not hasattr(func, "fold"): - continue - try: - new_node = func.fold(node) # type: ignore - except UnfoldableNode: - continue - - if "type" in node._metadata: - new_node._metadata["type"] = node._metadata["type"] - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: - """ - Find user-defined constant assignments, and replace references - to the constants with their literal values. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_children(vy_ast.VariableDecl): - if not node.is_constant: - # annotation is not wrapped in `constant(...)` - continue - - # Extract type definition from propagated annotation - type_ = node._metadata["type"] - - changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False) - - return changed_nodes - - -# TODO constant folding on log events - - -def _replace(old_node, new_node, type_): - if isinstance(new_node, vy_ast.Constant): - new_node = new_node.from_node(old_node, value=new_node.value) - new_node._metadata["type"] = type_ - return new_node - elif isinstance(new_node, vy_ast.List): - base_type = type_.value_type - list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements] - new_node = new_node.from_node(old_node, elements=list_values) - new_node._metadata["type"] = type_ - return new_node - elif isinstance(new_node, vy_ast.Call): - # Replace `Name` node with `Call` node - keyword = keywords = None - if hasattr(new_node, "keyword"): - keyword = new_node.keyword - if hasattr(new_node, "keywords"): - keywords = new_node.keywords - new_node = new_node.from_node( - old_node, func=new_node.func, args=new_node.args, keyword=keyword, keywords=keywords - ) - new_node._metadata["type"] = type_ - return new_node - else: - raise UnfoldableNode - - -def replace_constant( - vyper_module: vy_ast.Module, - id_: str, - replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], - type_: VyperType, - raise_on_error: bool, -) -> int: - """ - Replace references to a variable name with a literal value. - - Arguments - --------- - vyper_module : Module - Module-level ast node to perform replacement in. - id_ : str - String representing the `.id` attribute of the node(s) to be replaced. - replacement_node : Constant | List | Call - Vyper ast node representing the literal value to be substituted in. - `Call` nodes are for struct constants. - raise_on_error: bool - Boolean indicating if `UnfoldableNode` exception should be raised or ignored. - type_ : VyperType, optional - Type definition to be propagated to type checker. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Name, {"id": id_}, reverse=True): - parent = node.get_ancestor() - - if isinstance(parent, vy_ast.Call) and node == parent.func: - # do not replace calls because splicing a constant into a callable site is - # never valid and it worsens the error message - continue - - # do not replace dictionary keys - if isinstance(parent, vy_ast.Dict) and node in parent.keys: - continue - - if not node.get_ancestor(vy_ast.Index): - # do not replace left-hand side of assignments - assign = node.get_ancestor( - (vy_ast.Assign, vy_ast.AnnAssign, vy_ast.AugAssign, vy_ast.VariableDecl) - ) - - if assign and node in assign.target.get_descendants(include_self=True): - continue - - # do not replace enum members - if node.get_ancestor(vy_ast.FlagDef): - continue - - try: - # note: _replace creates a copy of the replacement_node - new_node = _replace(node, replacement_node, type_=type_) - except UnfoldableNode: - if raise_on_error: - raise - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes