diff --git a/docs/resources.rst b/docs/resources.rst index a3dfa480ed..7bb3c99df4 100644 --- a/docs/resources.rst +++ b/docs/resources.rst @@ -24,6 +24,7 @@ Frameworks and tooling - `🐍 snekmate – Vyper smart contract building blocks `_ - `Serpentor – A set of smart contracts tools for governance `_ - `Smart contract development frameworks and tools for Vyper on Ethreum.org `_ +- `Vyper Online Compiler - an online platform for compiling and deploying Vyper smart contracts `_ Security -------- diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 04f908f6d0..597aed57c7 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -107,7 +107,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets: # Calculate refund for sender refund: uint256 = 0 - for i in range(MAX_BIDS): + for i: int128 in range(MAX_BIDS): # Note that loop may break sooner than 128 iterations if i >= _numBids if (i >= _numBids): break diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index 30057582e8..e105a79133 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -205,7 +205,7 @@ def balanceOfBatch(accounts: DynArray[address, BATCH_SIZE], ids: DynArray[uint25 assert len(accounts) == len(ids), "ERC1155: accounts and ids length mismatch" batchBalances: DynArray[uint256, BATCH_SIZE] = [] j: uint256 = 0 - for i in ids: + for i: uint256 in ids: batchBalances.append(self.balanceOf[accounts[j]][i]) j += 1 return batchBalances @@ -243,7 +243,7 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[receiver][ids[i]] += amounts[i] @@ -277,7 +277,7 @@ def burnBatch(ids: DynArray[uint256, BATCH_SIZE], amounts: DynArray[uint256, BAT assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[msg.sender][ids[i]] -= amounts[i] @@ -333,7 +333,7 @@ def safeBatchTransferFrom(sender: address, receiver: address, ids: DynArray[uint assert sender == msg.sender or self.isApprovedForAll[sender][msg.sender], "Caller is neither owner nor approved operator for this ID" assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break id: uint256 = ids[i] diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 0b568784a9..107716accf 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -54,7 +54,7 @@ def directlyVoted(addr: address) -> bool: def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 - for i in range(2): + for i: int128 in range(2): self.proposals[i] = Proposal({ name: _proposalNames[i], voteCount: 0 @@ -82,7 +82,7 @@ def _forwardWeight(delegate_with_weight_to_forward: address): assert self.voters[delegate_with_weight_to_forward].weight > 0 target: address = self.voters[delegate_with_weight_to_forward].delegate - for i in range(4): + for i: int128 in range(4): if self._delegated(target): target = self.voters[target].delegate # The following effectively detects cycles of length <= 5, @@ -157,7 +157,7 @@ def vote(proposal: int128): def _winningProposal() -> int128: winning_vote_count: int128 = 0 winning_proposal: int128 = 0 - for i in range(2): + for i: int128 in range(2): if self.proposals[i].voteCount > winning_vote_count: winning_vote_count = self.proposals[i].voteCount winning_proposal = i diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index e2515d9e62..231f538ecf 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -14,7 +14,7 @@ seq: public(int128) @external def __init__(_owners: address[5], _threshold: int128): - for i in range(5): + for i: uint256 in range(5): if _owners[i] != empty(address): self.owners[i] = _owners[i] self.threshold = _threshold @@ -47,7 +47,7 @@ def approve(_seq: int128, to: address, _value: uint256, data: Bytes[4096], sigda assert self.seq == _seq # # Iterates through all the owners and verifies that there signatures, # # given as the sigdata argument are correct - for i in range(5): + for i: uint256 in range(5): if sigdata[i][0] != 0: # If an invalid signature is given for an owner then the contract throws assert ecrecover(h2, sigdata[i][0], sigdata[i][1], sigdata[i][2]) == self.owners[i] diff --git a/tests/conftest.py b/tests/conftest.py index 51b4b4459a..e673f17b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,14 @@ def pytest_addoption(parser): parser.addoption("--enable-compiler-debug-mode", action="store_true") +@pytest.fixture(scope="module") +def output_formats(): + output_formats = compiler.OUTPUT_FORMATS.copy() + del output_formats["bb"] + del output_formats["bb_runtime"] + return output_formats + + @pytest.fixture(scope="module") def optimize(pytestconfig): flag = pytestconfig.getoption("optimize") @@ -281,7 +289,14 @@ def ir_compiler(ir, *args, **kwargs): def _get_contract( - w3, source_code, optimize, *args, override_opt_level=None, input_bundle=None, **kwargs + w3, + source_code, + optimize, + output_formats, + *args, + override_opt_level=None, + input_bundle=None, + **kwargs, ): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) @@ -289,7 +304,7 @@ def _get_contract( out = compiler.compile_code( source_code, # test that all output formats can get generated - output_formats=list(compiler.OUTPUT_FORMATS.keys()), + output_formats=output_formats, settings=settings, input_bundle=input_bundle, show_gas_estimates=True, # Enable gas estimates for testing @@ -309,17 +324,17 @@ def _get_contract( @pytest.fixture(scope="module") -def get_contract(w3, optimize): +def get_contract(w3, optimize, output_formats): def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return fn @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, optimize): +def get_contract_with_gas_estimation(tester, w3, optimize, output_formats): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, optimize, *args, **kwargs) + contract = _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -329,15 +344,15 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, optimize): +def get_contract_with_gas_estimation_for_constants(w3, optimize, output_formats): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return get_contract_with_gas_estimation_for_constants @pytest.fixture(scope="module") -def get_contract_module(optimize): +def get_contract_module(optimize, output_formats): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. @@ -350,18 +365,18 @@ def get_contract_module(optimize): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return get_contract_module -def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs): +def _deploy_blueprint_for(w3, source_code, optimize, output_formats, initcode_prefix=b"", **kwargs): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) settings.optimize = optimize out = compiler.compile_code( source_code, - output_formats=list(compiler.OUTPUT_FORMATS.keys()), + output_formats=output_formats, settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) @@ -394,9 +409,9 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize): +def deploy_blueprint_for(w3, optimize, output_formats): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs) + return _deploy_blueprint_for(w3, source_code, optimize, output_formats, *args, **kwargs) return deploy_blueprint_for diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index c3627785dc..896c845da2 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -423,7 +423,7 @@ def test_empty(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: view @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @internal def priv(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: @@ -469,7 +469,7 @@ def test_return_empty(get_contract_with_gas_estimation): @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @external diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index ba82ebd5b8..31de1d9f22 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -20,7 +20,7 @@ def test_uint256_mulmod_complex(get_contract_with_gas_estimation): @external def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: o: uint256 = 1 - for i in range(256): + for i: uint256 in range(256): o = uint256_mulmod(o, o, modulus) if exponent & shift(1, 255 - i) != 0: o = uint256_mulmod(o, base, modulus) diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index a15a3eeb35..80936bbf82 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -17,7 +17,7 @@ def test_basic_slice(get_contract_with_gas_estimation): @external def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: inp: Bytes[50] = inp1 - for i in range(1, 11): + for i: uint256 in range(1, 11): inp = slice(inp, 1, 30 - i * 2) return inp """ diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index 68131678fa..c954380def 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -2,8 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType @@ -19,9 +18,9 @@ def foo(a: int256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"abs({a})") + vyper_ast = parse_and_fold(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a) == new_node.value == abs(a) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 1d789f1655..e6a9fc193f 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -2,8 +2,7 @@ from hypothesis import assume, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256 - 1) @@ -22,8 +21,8 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") + vyper_ast = parse_and_fold(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 53a6d333a0..c1ff7674bb 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -29,7 +29,7 @@ def foo(a: uint256, b: uint256) -> uint256: contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") + vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -48,10 +48,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case @@ -78,10 +77,9 @@ def foo(a: int256, b: uint256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. @@ -105,7 +103,7 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"~{value}") + vyper_ast = parse_and_fold(f"~{value}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 4f5e9434ec..7bc2afe757 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("typ_name", ["decimal"]) @@ -13,8 +12,8 @@ def foo() -> {typ_name}: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") + vyper_ast = parse_and_fold(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 04921e504e..9e63c7b099 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -4,8 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_decimals = st.decimals( min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10 @@ -28,8 +27,8 @@ def foo(a: decimal) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 4287615bab..01af646a16 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.builtins import functions as vy_fn from vyper.utils import SizeLimits @@ -30,9 +30,9 @@ def foo(a: decimal) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -49,8 +49,8 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index 8da420538f..3b5f99891f 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~' # NOQA: E501 @@ -20,9 +19,9 @@ def foo(a: String[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") + vyper_ast = parse_and_fold(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -39,9 +38,9 @@ def foo(a: Bytes[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -60,8 +59,8 @@ def foo(a: Bytes[100]) -> bytes32: value = f"0x{value.hex()}" - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index 967f906555..6d59751748 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("length", [0, 1, 32, 33, 64, 65, 1024]) @@ -15,9 +14,9 @@ def foo(a: String[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") + vyper_ast = parse_and_fold(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -33,9 +32,9 @@ def foo(a: Bytes[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") + vyper_ast = parse_and_fold(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value.encode()) == new_node.value @@ -51,8 +50,8 @@ def foo(a: Bytes[1024]) -> uint256: value = f"0x{'00' * length}" - vyper_ast = vy_ast.parse_to_ast(f"len({value})") + vyper_ast = parse_and_fold(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 36a611fa1b..752b64eb04 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.utils import SizeLimits st_decimals = st.decimals( @@ -29,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -48,9 +47,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -67,8 +66,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index a3c2567f58..ad1197e8e3 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256) @@ -19,8 +18,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") + vyper_ast = parse_and_fold(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/codegen/features/iteration/test_break.py b/tests/functional/codegen/features/iteration/test_break.py index 8a08a11cc2..4abde9c617 100644 --- a/tests/functional/codegen/features/iteration/test_break.py +++ b/tests/functional/codegen/features/iteration/test_break.py @@ -11,7 +11,7 @@ def test_break_test(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(400): + for i: int128 in range(400): c = c / 1.2589 if c < 1.0: output = i @@ -35,12 +35,12 @@ def test_break_test_2(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c = c / 10.0 - for i in range(10): + for i: int128 in range(10): c = c / 1.2589 if c < 1.0: output = output + i @@ -63,12 +63,12 @@ def test_break_test_3(get_contract_with_gas_estimation): def foo(n: int128) -> int128: c: decimal = convert(n, decimal) output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c /= 10.0 - for i in range(10): + for i: int128 in range(10): c /= 1.2589 if c < 1.0: output = output + i @@ -108,7 +108,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: break diff --git a/tests/functional/codegen/features/iteration/test_continue.py b/tests/functional/codegen/features/iteration/test_continue.py index 5f4f82a2de..1b2fcab460 100644 --- a/tests/functional/codegen/features/iteration/test_continue.py +++ b/tests/functional/codegen/features/iteration/test_continue.py @@ -7,7 +7,7 @@ def test_continue1(get_contract_with_gas_estimation): code = """ @external def foo() -> bool: - for i in range(2): + for i: uint256 in range(2): continue return False return True @@ -21,7 +21,7 @@ def test_continue2(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += 1 continue x -= 1 @@ -36,7 +36,7 @@ def test_continue3(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += i continue return x @@ -50,7 +50,7 @@ def test_continue4(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): if i % 2 == 0: continue x += 1 @@ -83,7 +83,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: continue diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index bc1a12ae9e..7f5658e485 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -11,7 +11,9 @@ NamespaceCollision, StateAccessViolation, StructureException, + SyntaxException, TypeMismatch, + UnknownType, ) BASIC_FOR_LOOP_CODE = [ @@ -21,7 +23,7 @@ @external def data() -> int128: s: int128[5] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -33,7 +35,7 @@ def data() -> int128: @external def data() -> int128: s: DynArray[int128, 10] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -53,8 +55,8 @@ def data() -> int128: [S({x:3, y:4}), S({x:5, y:6}), S({x:7, y:8}), S({x:9, y:10})] ] ret: int128 = 0 - for ss in sss: - for s in ss: + for ss: DynArray[S, 10] in sss: + for s: S in ss: ret += s.x + s.y return ret""", sum(range(1, 11)), @@ -64,7 +66,7 @@ def data() -> int128: """ @external def data() -> int128: - for i in [3, 5, 7, 9]: + for i: int128 in [3, 5, 7, 9]: if i > 5: return i return -1""", @@ -76,7 +78,7 @@ def data() -> int128: @external def data() -> String[33]: xs: DynArray[String[33], 3] = ["hello", ",", "world"] - for x in xs: + for x: String[33] in xs: if x == ",": return x return "" @@ -88,7 +90,7 @@ def data() -> String[33]: """ @external def data() -> String[33]: - for x in ["hello", ",", "world"]: + for x: String[33] in ["hello", ",", "world"]: if x == ",": return x return "" @@ -100,7 +102,7 @@ def data() -> String[33]: """ @external def data() -> DynArray[String[33], 2]: - for x in [["hello", "world"], ["goodbye", "world!"]]: + for x: DynArray[String[33], 2] in [["hello", "world"], ["goodbye", "world!"]]: if x[1] == "world": return x return [] @@ -114,8 +116,8 @@ def data() -> DynArray[String[33], 2]: def data() -> int128: ret: int128 = 0 xss: int128[3][3] = [[1,2,3],[4,5,6],[7,8,9]] - for xs in xss: - for x in xs: + for xs: int128[3] in xss: + for x: int128 in xs: ret += x return ret""", sum(range(1, 10)), @@ -130,8 +132,8 @@ def data() -> int128: @external def data() -> int128: ret: int128 = 0 - for ss in [[S({x:1, y:2})]]: - for s in ss: + for ss: S[1] in [[S({x:1, y:2})]]: + for s: S in ss: ret += s.x + s.y return ret""", 1 + 2, @@ -147,7 +149,7 @@ def data() -> address: 0xDCEceAF3fc5C0a63d195d69b1A90011B7B19650D ] count: int128 = 0 - for i in addresses: + for i: address in addresses: count += 1 if count == 2: return i @@ -174,7 +176,7 @@ def set(): @external def data() -> int128: - for i in self.x: + for i: int128 in self.x: if i > 5: return i return -1 @@ -198,7 +200,7 @@ def set(xs: DynArray[int128, 4]): @external def data() -> int128: t: int128 = 0 - for i in self.x: + for i: int128 in self.x: t += i return t """ @@ -227,7 +229,7 @@ def ret(i: int128) -> address: @external def iterate_return_second() -> address: count: int128 = 0 - for i in self.addresses: + for i: address in self.addresses: count += 1 if count == 2: return i @@ -258,7 +260,7 @@ def ret(i: int128) -> decimal: @external def i_return(break_count: int128) -> decimal: count: int128 = 0 - for i in self.readings: + for i: decimal in self.readings: if count == break_count: return i count += 1 @@ -284,7 +286,7 @@ def func(amounts: uint256[3]) -> uint256: total: uint256 = as_wei_value(0, "wei") # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -303,7 +305,7 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: total: uint256 = 0 # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -321,42 +323,42 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in range(4): + for i: int128 in range(4): p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass - for i in range(20): + for i: uint256 in range(20): pass """, # using index variable after loop """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass i: int128 = 100 # create new variable i i = 200 # look up the variable i and check whether it is in forvars @@ -372,25 +374,25 @@ def test_good_code(code, get_contract): RANGE_CONSTANT_CODE = [ ( """ -TREE_FIDDY: constant(int128) = 350 +TREE_FIDDY: constant(uint256) = 350 @external def a() -> uint256: x: uint256 = 0 - for i in range(TREE_FIDDY): + for i: uint256 in range(TREE_FIDDY): x += 1 return x""", 350, ), ( """ -ONE_HUNDRED: constant(int128) = 100 +ONE_HUNDRED: constant(uint256) = 100 @external def a() -> uint256: x: uint256 = 0 - for i in range(1, 1 + ONE_HUNDRED): + for i: uint256 in range(1, 1 + ONE_HUNDRED): x += 1 return x""", 100, @@ -401,9 +403,9 @@ def a() -> uint256: END: constant(int128) = 199 @external -def a() -> uint256: - x: uint256 = 0 - for i in range(START, END): +def a() -> int128: + x: int128 = 0 + for i: int128 in range(START, END): x += 1 return x""", 99, @@ -413,11 +415,23 @@ def a() -> uint256: @external def a() -> int128: x: int128 = 0 - for i in range(-5, -1): + for i: int128 in range(-5, -1): x += i return x""", -14, ), + ( + """ +@external +def a() -> uint256: + a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] + x: uint256 = 0 + for i: uint256 in a[2]: + x += i + return x + """, + 9, + ), ] @@ -436,7 +450,7 @@ def test_range_constant(get_contract, code, result): def data() -> int128: s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -451,7 +465,7 @@ def data() -> int128: def foo(): s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] += 1 """, ImmutableViolation, @@ -468,7 +482,7 @@ def set(): @external def data() -> int128: count: int128 = 0 - for i in self.s: + for i: int128 in self.s: self.s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -493,7 +507,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff(i) i += 1 """, @@ -519,7 +533,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.bar.foo: + for item: uint256 in self.my_array2.bar.foo: self.doStuff(i) i += 1 """, @@ -545,7 +559,7 @@ def doStuff(): @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff() i += 1 """, @@ -556,8 +570,8 @@ def _helper(): """ @external def foo(x: int128): - for i in range(4): - for i in range(5): + for i: int128 in range(4): + for i: int128 in range(5): pass """, NamespaceCollision, @@ -566,8 +580,8 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: - for i in [1,2]: + for i: int128 in [1,2]: + for i: int128 in [1,2]: pass """, NamespaceCollision, @@ -577,7 +591,7 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i = 2 """, ImmutableViolation, @@ -588,7 +602,7 @@ def foo(x: int128): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.pop() """, ImmutableViolation, @@ -599,7 +613,7 @@ def foo(): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.append(x) """, ImmutableViolation, @@ -610,7 +624,7 @@ def foo(): @external def foo(): xs: DynArray[DynArray[uint256, 5], 5] = [[1,2,3]] - for x in xs: + for x: DynArray[uint256, 5] in xs: x.pop() """, ImmutableViolation, @@ -629,7 +643,7 @@ def b(): @external def foo(): - for x in self.array: + for x: uint256 in self.array: self.a() """, ImmutableViolation, @@ -638,7 +652,7 @@ def foo(): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i += 2 """, ImmutableViolation, @@ -648,7 +662,7 @@ def foo(x: int128): """ @external def foo(): - for i in range(-3): + for i: int128 in range(-3): pass """, StructureException, @@ -656,13 +670,13 @@ def foo(): """ @external def foo(): - for i in range(0): + for i: uint256 in range(0): pass """, """ @external def foo(): - for i in []: + for i: uint256 in []: pass """, """ @@ -670,14 +684,14 @@ def foo(): @external def foo(): - for i in FOO: + for i: uint256 in FOO: pass """, ( """ @external def foo(): - for i in range(5,3): + for i: uint256 in range(5,3): pass """, StructureException, @@ -686,7 +700,7 @@ def foo(): """ @external def foo(): - for i in range(5,3,-1): + for i: int128 in range(5,3,-1): pass """, ArgumentException, @@ -696,7 +710,7 @@ def foo(): @external def foo(): a: uint256 = 2 - for i in range(a): + for i: uint256 in range(a): pass """, StateAccessViolation, @@ -706,7 +720,7 @@ def foo(): @external def foo(): a: int128 = 6 - for i in range(a,a-3): + for i: int128 in range(a,a-3): pass """, StateAccessViolation, @@ -716,7 +730,7 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, ArgumentException, @@ -725,7 +739,7 @@ def foo(): """ @external def foo(): - for i in range(0,1,2): + for i: uint256 in range(0,1,2): pass """, ArgumentException, @@ -735,7 +749,7 @@ def foo(): """ @external def foo(): - for i in b"asdf": + for i: Bytes[1] in b"asdf": pass """, InvalidType, @@ -744,7 +758,7 @@ def foo(): """ @external def foo(): - for i in 31337: + for i: uint256 in 31337: pass """, InvalidType, @@ -753,7 +767,7 @@ def foo(): """ @external def foo(): - for i in bar(): + for i: uint256 in bar(): pass """, IteratorException, @@ -762,7 +776,7 @@ def foo(): """ @external def foo(): - for i in self.bar(): + for i: uint256 in self.bar(): pass """, IteratorException, @@ -772,11 +786,11 @@ def foo(): @external def test_for() -> int128: a: int128 = 0 - for i in range(max_value(int128), max_value(int128)+2): + for i: int128 in range(max_value(int128), max_value(int128)+2): a = i return a """, - TypeMismatch, + InvalidType, ), ( """ @@ -784,13 +798,40 @@ def test_for() -> int128: def test_for() -> int128: a: int128 = 0 b: uint256 = 0 - for i in range(5): + for i: int128 in range(5): a = i b = i return a """, TypeMismatch, ), + ( + """ +@external +def foo(): + for i in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: $$$ in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: uint9 in [1, 2, 3]: + pass + """, + UnknownType, + ), ] BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index e946447285..c661c46553 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -6,7 +6,7 @@ def test_basic_repeater(get_contract_with_gas_estimation): @external def repeat(z: int128) -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): x = x + z return(x) """ @@ -19,7 +19,7 @@ def test_range_bound(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, bound=6): + for i: uint256 in range(n, bound=6): x += i + 1 return x """ @@ -37,7 +37,7 @@ def test_range_bound_constant_end(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, 7, bound=6): + for i: uint256 in range(n, 7, bound=6): x += i + 1 return x """ @@ -58,7 +58,7 @@ def test_range_bound_two_args(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(1, n, bound=6): + for i: uint256 in range(1, n, bound=6): x += i + 1 return x """ @@ -80,7 +80,7 @@ def test_range_bound_two_runtime_args(get_contract, tx_failed): @external def repeat(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x += i return x """ @@ -109,7 +109,7 @@ def test_range_overflow(get_contract, tx_failed): @external def get_last(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x = i return x """ @@ -134,11 +134,11 @@ def test_digit_reverser(get_contract_with_gas_estimation): def reverse_digits(x: int128) -> int128: dig: int128[6] = [0, 0, 0, 0, 0, 0] z: int128 = x - for i in range(6): + for i: uint256 in range(6): dig[i] = z % 10 z = z / 10 o: int128 = 0 - for i in range(6): + for i: uint256 in range(6): o = o * 10 + dig[i] return o @@ -153,9 +153,9 @@ def test_more_complex_repeater(get_contract_with_gas_estimation): @external def repeat() -> int128: out: int128 = 0 - for i in range(6): + for i: uint256 in range(6): out = out * 10 - for j in range(4): + for j: int128 in range(4): out = out + j return(out) """ @@ -170,7 +170,7 @@ def test_offset_repeater(get_contract_with_gas_estimation, typ): @external def sum() -> {typ}: out: {typ} = 0 - for i in range(80, 121): + for i: {typ} in range(80, 121): out = out + i return out """ @@ -185,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101, bound=101): + for i: {typ} in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -205,7 +205,7 @@ def _bar() -> bool: @external def foo() -> bool: - for i in range(3): + for i: uint256 in range(3): self._bar() return True """ @@ -219,8 +219,8 @@ def test_return_inside_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for j in range(10): + for i: {typ} in range(10): + for j: {typ} in range(10): if j > 5: if i > a: return i @@ -254,14 +254,14 @@ def test_for_range_edge(get_contract, typ): def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x - 1, x, bound=1): + for i: {typ} in range(x - 1, x, bound=1): if i + 1 == max_value({typ}): found = True assert found found = False x = max_value({typ}) - 1 - for i in range(x - 1, x + 1, bound=2): + for i: {typ} in range(x - 1, x + 1, bound=2): if i + 1 == max_value({typ}): found = True assert found @@ -276,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x + 2, bound=2): + for i: {typ} in range(x, x + 2, bound=2): pass """ c = get_contract(code) @@ -289,8 +289,8 @@ def test_return_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -318,8 +318,8 @@ def test_return_void_nested_repeater(get_contract, typ, val): result: {typ} @internal def _final(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -347,8 +347,8 @@ def test_external_nested_repeater(get_contract, typ, val): code = f""" @external def foo(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -368,8 +368,8 @@ def test_external_void_nested_repeater(get_contract, typ, val): result: public({typ}) @external def foo(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -388,8 +388,8 @@ def test_breaks_and_returns_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if a < 2: break return 6 diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index 244f820537..de9dd17ef6 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -151,7 +151,7 @@ def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5 return True """ @@ -171,7 +171,7 @@ def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5, "because reasons" return True """ diff --git a/tests/functional/codegen/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py index f10d22ec99..422f53fdeb 100644 --- a/tests/functional/codegen/features/test_internal_call.py +++ b/tests/functional/codegen/features/test_internal_call.py @@ -152,7 +152,7 @@ def _increment(): @external def returnten() -> int128: - for i in range(10): + for i: uint256 in range(10): self._increment() return self.counter """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 671d424d60..891ed5aebe 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py new file mode 100644 index 0000000000..084ad26e6b --- /dev/null +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -0,0 +1,31 @@ +def test_import_interface_types(make_input_bundle, get_contract): + ifaces = """ +interface IFoo: + def foo() -> uint256: nonpayable + """ + + foo_impl = """ +import ifaces + +implements: ifaces.IFoo + +@external +def foo() -> uint256: + return block.number + """ + + contract = """ +import ifaces + +@external +def test_foo(s: ifaces.IFoo) -> bool: + assert s.foo() == block.number + return True + """ + + input_bundle = make_input_bundle({"ifaces.vy": ifaces}) + + foo = get_contract(foo_impl, input_bundle=input_bundle) + c = get_contract(contract, input_bundle=input_bundle) + + assert c.test_foo(foo.address) is True diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py new file mode 100644 index 0000000000..aafbb69252 --- /dev/null +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -0,0 +1,78 @@ +def test_module_constant(make_input_bundle, get_contract): + mod1 = """ +X: constant(uint256) = 12345 + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 12345 + + +def test_nested_module_constant(make_input_bundle, get_contract): + # test nested module constants + # test at least 3 modules deep to test the `path.reverse()` gizmo + # in ConstantFolder.visit_Attribute() + mod1 = """ +X: constant(uint256) = 12345 + """ + mod2 = """ +import mod1 +X: constant(uint256) = 54321 + """ + mod3 = """ +import mod2 +X: constant(uint256) = 98765 + """ + + contract = """ +import mod1 +import mod2 +import mod3 + +@external +def test_foo() -> bool: + assert mod1.X == 12345 + assert mod2.X == 54321 + assert mod3.X == 98765 + assert mod2.mod1.X == mod1.X + assert mod3.mod2.mod1.X == mod1.X + assert mod3.mod2.X == mod2.X + return True + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1, "mod2.vy": mod2, "mod3.vy": mod3}) + + c = get_contract(contract, input_bundle=input_bundle) + assert c.test_foo() is True + + +def test_import_constant_array(make_input_bundle, get_contract, tx_failed): + mod1 = """ +X: constant(uint256[3]) = [1,2,3] + """ + contract = """ +import mod1 + +@external +def foo(ix: uint256) -> uint256: + return mod1.X[ix] + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo(0) == 1 + assert c.foo(1) == 2 + assert c.foo(2) == 3 + with tx_failed(): + c.foo(3) diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/modules/test_stateless_functions.py similarity index 100% rename from tests/functional/codegen/test_stateless_modules.py rename to tests/functional/codegen/modules/test_stateless_functions.py diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index fcf71f12f0..72171dd4b5 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -125,7 +125,7 @@ def test_harder_decimal_test(get_contract_with_gas_estimation): @external def phooey(inp: decimal) -> decimal: x: decimal = 10000.0 - for i in range(4): + for i: uint256 in range(4): x = x * inp return x diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 1ee9b8d835..882629de65 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -268,7 +268,7 @@ def test_zero_padding_with_private(get_contract): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_bytes_zero_padding.py b/tests/functional/codegen/types/test_bytes_zero_padding.py index f9fcf37b25..6597facd1b 100644 --- a/tests/functional/codegen/types/test_bytes_zero_padding.py +++ b/tests/functional/codegen/types/test_bytes_zero_padding.py @@ -10,7 +10,7 @@ def little_endian_contract(get_contract_module): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 70a68e3206..e47eda6042 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -969,7 +969,7 @@ def foo() -> (uint256, uint256, uint256, uint256, uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -981,7 +981,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: some_var: uint256 @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.some_var = x # test that typechecker for append args works self.my_array.append(self.some_var) @@ -994,9 +994,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) - for x in xs: + for x: uint256 in xs: self.my_array.pop() return self.my_array """, @@ -1008,7 +1008,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array, self.my_array.pop() """, @@ -1020,7 +1020,7 @@ def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array.pop(), self.my_array """, @@ -1033,7 +1033,7 @@ def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] i: uint256 = 0 - for x in xs: + for x: uint256 in xs: if i >= len(xs) - 1: break ys.append(x) @@ -1049,7 +1049,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -1061,9 +1061,9 @@ def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() return ys """, @@ -1075,9 +1075,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() ys.pop() # fail return ys @@ -1328,7 +1328,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: DynArray[Foo, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): e: Foobar = _baz[i].z f: uint256 = convert(e, uint256) sum += _baz[i].x * _baz[i].y + f @@ -1397,7 +1397,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: DynArray[Bar, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index b5b9538c20..ee287064e8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -566,7 +566,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: Foo[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _baz[i].x * _baz[i].y return sum """ @@ -608,7 +608,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: Bar[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 7dd8c35929..351793b28e 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -4,7 +4,7 @@ import hypothesis import hypothesis.strategies as st import pytest -from hypothesis import assume, given +from hypothesis import HealthCheck, assume, given from hypothesis.extra.lark import LarkStrategy from vyper.ast import Module, parse_to_ast @@ -103,9 +103,9 @@ def has_no_docstrings(c): @pytest.mark.fuzzing @given(code=from_grammar().filter(lambda c: utf8_encodable(c))) -@hypothesis.settings(max_examples=500) +@hypothesis.settings(max_examples=500, suppress_health_check=[HealthCheck.too_slow]) def test_grammar_bruteforce(code): if utf8_encodable(code): - _, _, reformatted_code = pre_parse(code + "\n") + _, _, _, reformatted_code = pre_parse(code + "\n") tree = parse_to_ast(reformatted_code) assert isinstance(tree, Module) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index 0b7ec21bdb..4240aec8d2 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -80,13 +80,13 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, """ @external def foo(): - for i in range(1, 2, 3, 4): + for i: uint256 in range(1, 2, 3, 4): pass """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 4bd0b4fcb9..7adf9538c7 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -57,7 +57,7 @@ def foo() -> int128: return 5 @external def bar(): - for i in range(self.foo(), self.foo() + 1): + for i: int128 in range(self.foo(), self.foo() + 1): pass""", """ glob: int128 @@ -67,13 +67,13 @@ def foo() -> int128: return 5 @external def bar(): - for i in [1,2,3,4,self.foo()]: + for i: int128 in [1,2,3,4,self.foo()]: pass""", """ @external def foo(): x: int128 = 5 - for i in range(x): + for i: int128 in range(x): pass""", """ f:int128 diff --git a/tests/functional/syntax/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py index 9ab9b6c677..53a9550a7d 100644 --- a/tests/functional/syntax/exceptions/test_syntax_exception.py +++ b/tests/functional/syntax/exceptions/test_syntax_exception.py @@ -86,6 +86,18 @@ def f(a:uint256,/): # test posonlyargs blocked def g(): self.f() """, + """ +@external +def foo(): + for i in range(0, 10): + pass + """, + """ +@external +def foo(): + for i: $$$ in range(0, 10): + pass + """, ] diff --git a/tests/functional/syntax/test_blockscope.py b/tests/functional/syntax/test_blockscope.py index 942aa3fa68..466b5509ca 100644 --- a/tests/functional/syntax/test_blockscope.py +++ b/tests/functional/syntax/test_blockscope.py @@ -33,7 +33,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a = 1 """, @@ -41,7 +41,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a += 1 """, diff --git a/tests/functional/syntax/test_bool.py b/tests/functional/syntax/test_bool.py index 48ed37321a..5388a92b95 100644 --- a/tests/functional/syntax/test_bool.py +++ b/tests/functional/syntax/test_bool.py @@ -37,7 +37,7 @@ def foo(): def foo() -> bool: return (1 == 2) <= (1 == 1) """, - TypeMismatch, + InvalidOperation, ), """ @external diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index ffd2f1faa0..7089dee3bb 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -240,7 +240,7 @@ def test1(): @external @view def test(): - for i in range(CONST / 4): + for i: uint256 in range(CONST / 4): pass """, """ diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a9c3ad5cab..a486d11738 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -8,6 +8,7 @@ StateAccessViolation, StructureException, TypeMismatch, + UnknownType, ) fail_list = [ @@ -15,7 +16,7 @@ """ @external def foo(): - for a[1] in range(10): + for a[1]: uint256 in range(10): pass """, StructureException, @@ -26,7 +27,7 @@ def foo(): """ @external def bar(): - for i in range(1,2,bound=0): + for i: uint256 in range(1,2,bound=0): pass """, StructureException, @@ -38,7 +39,7 @@ def bar(): @external def foo(): x: uint256 = 100 - for _ in range(10, bound=x): + for _: uint256 in range(10, bound=x): pass """, StateAccessViolation, @@ -49,7 +50,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=5): + for _: uint256 in range(10, 20, bound=5): pass """, StructureException, @@ -60,7 +61,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=0): + for _: uint256 in range(10, 20, bound=0): pass """, StructureException, @@ -72,7 +73,7 @@ def foo(): @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2,extra=3): + for i: uint256 in range(x,x+1,bound=2,extra=3): pass """, ArgumentException, @@ -83,7 +84,7 @@ def bar(): """ @external def bar(): - for i in range(0): + for i: uint256 in range(0): pass """, StructureException, @@ -95,7 +96,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x): + for i: uint256 in range(x): pass """, StateAccessViolation, @@ -107,7 +108,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(0, x): + for i: uint256 in range(0, x): pass """, StateAccessViolation, @@ -118,7 +119,7 @@ def bar(): """ @external def repeat(n: uint256) -> uint256: - for i in range(0, n * 10): + for i: uint256 in range(0, n * 10): pass return n """, @@ -131,7 +132,7 @@ def repeat(n: uint256) -> uint256: @external def bar(): x:uint256 = 1 - for i in range(0, x + 1): + for i: uint256 in range(0, x + 1): pass """, StateAccessViolation, @@ -142,7 +143,7 @@ def bar(): """ @external def bar(): - for i in range(2, 1): + for i: uint256 in range(2, 1): pass """, StructureException, @@ -154,7 +155,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x, x): + for i: uint256 in range(x, x): pass """, StateAccessViolation, @@ -166,7 +167,7 @@ def bar(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i: int128 in range(x, x + 10): pass """, StateAccessViolation, @@ -177,7 +178,7 @@ def foo(): """ @external def repeat(n: uint256) -> uint256: - for i in range(n, 6): + for i: uint256 in range(n, 6): pass return x """, @@ -190,7 +191,7 @@ def repeat(n: uint256) -> uint256: @external def foo(x: int128): y: int128 = 7 - for i in range(x, x + y): + for i: int128 in range(x, x + y): pass """, StateAccessViolation, @@ -201,7 +202,7 @@ def foo(x: int128): """ @external def bar(x: uint256): - for i in range(3, x): + for i: uint256 in range(3, x): pass """, StateAccessViolation, @@ -215,12 +216,12 @@ def bar(x: uint256): @external def foo(): - for i in range(FOO, BAR): + for i: uint256 in range(FOO, BAR): pass """, TypeMismatch, - "Iterator values are of different types", - "range(FOO, BAR)", + "Given reference has type int128, expected uint256", + "FOO", ), ( """ @@ -228,16 +229,61 @@ def foo(): @external def foo(): - for i in range(10, bound=FOO): + for i: int128 in range(10, bound=FOO): pass """, StructureException, "Bound must be at least 1", - "-1", + "FOO", + ), + ( + """ +@external +def foo(): + for i: DynArra[uint256, 3] in [1, 2, 3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", + "DynArra", + ), + ( + # test for loop target broken into multiple lines + """ +@external +def foo(): + for i: \\ + \\ + \\ + \\ + \\ + \\ + uint9 in [1,2,3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "uint9", + ), + ( + # test an even more deranged example + """ +@external +def foo(): + for i: \\ + \\ + DynArray[\\ + uint9, 3\\ + ] in [1,2,3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "uint9", ), ] -for_code_regex = re.compile(r"for .+ in (.*):") +for_code_regex = re.compile(r"for .+ in (.*):", re.DOTALL) fail_test_names = [ ( f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] @@ -252,41 +298,41 @@ def test_range_fail(bad_code, error_type, message, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) assert message == exc_info.value.message - assert source_code == exc_info.value.args[1].node_source_code + assert source_code == exc_info.value.args[1].get_original_node().node_source_code valid_list = [ """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass """, """ @external def foo(): - for i in range(10, 20): + for i: uint256 in range(10, 20): pass """, """ @external def foo(): x: int128 = 5 - for i in range(1, x, bound=4): + for i: int128 in range(1, x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(x, bound=4): + for i: int128 in range(x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(0, x, bound=4): + for i: int128 in range(0, x, bound=4): pass """, """ @@ -295,7 +341,7 @@ def kick(): nonpayable foos: Foo[3] @external def kick_foos(): - for foo in self.foos: + for foo: Foo in self.foos: foo.kick() """, ] diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index a672ed7b88..ca96adca91 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -90,7 +90,7 @@ def foo(): nonpayable """ implements: self.x """, - StructureException, + InvalidType, ), ( """ diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index db41de5526..3936f8c220 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -306,7 +306,7 @@ def foo(): @external def foo(): x: DynArray[uint256, 3] = [1, 2, 3] - for i in [[], []]: + for i: DynArray[uint256, 3] in [[], []]: x = i """, ] diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index e426a11de9..a75d114f88 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -4,7 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException st_decimals = st.decimals( @@ -28,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -45,11 +45,8 @@ def foo(a: decimal, b: decimal) -> decimal: def test_binop_pow(): # raises because Vyper does not support decimal exponentiation - vyper_ast = vy_ast.parse_to_ast("3.1337 ** 4.2") - old_node = vyper_ast.body[0].value - with pytest.raises(TypeMismatch): - old_node.get_folded_value() + _ = parse_and_fold("3.1337 ** 4.2") @pytest.mark.fuzzing @@ -72,8 +69,8 @@ def foo({input_value}) -> decimal: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = -(2**127) <= expected < 2**127 diff --git a/tests/unit/ast/nodes/test_fold_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py index 904b36c167..d9340927fe 100644 --- a/tests/unit/ast/nodes/test_fold_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -2,7 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import ZeroDivisionException st_int32 = st.integers(min_value=-(2**32), max_value=2**32) @@ -24,9 +24,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -54,9 +54,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = new_node.value >= 0 except ZeroDivisionException: @@ -83,7 +83,7 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") + vyper_ast = parse_and_fold(f"{left} ** {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -112,9 +112,8 @@ def foo({input_value}) -> int128: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) - try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = True diff --git a/tests/unit/ast/nodes/test_fold_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py index 3c42da0d26..082e6f35c3 100644 --- a/tests/unit/ast/nodes/test_fold_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold variables = "abcdefghij" @@ -24,7 +24,7 @@ def foo({input_value}) -> bool: literal_op = f" {comparator} ".join(str(i) for i in values) - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -52,7 +52,7 @@ def foo({input_value}) -> bool: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, comparators)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index 2b7c0f09d7..aab8ac0b2d 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import UnfoldableNode @@ -19,7 +19,7 @@ def foo(a: int128, b: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -39,7 +39,7 @@ def foo(a: uint128, b: uint128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -63,7 +63,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") + vyper_ast = parse_and_fold(f"{left} in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -92,7 +92,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") + vyper_ast = parse_and_fold(f"{left} not in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -106,7 +106,7 @@ def bar(a: int128) -> bool: @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) def test_compare_type_mismatch(op): - vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") + vyper_ast = parse_and_fold(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py index 1884abf73b..3ed26d07b7 100644 --- a/tests/unit/ast/nodes/test_fold_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.fuzzing @@ -19,7 +19,7 @@ def foo(array: int128[10], idx: uint256) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") + vyper_ast = parse_and_fold(f"{array}[{idx}]") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py index ff48adfe71..af72f5f8b0 100644 --- a/tests/unit/ast/nodes/test_fold_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -1,6 +1,6 @@ import pytest -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.parametrize("bool_cond", [True, False]) @@ -12,7 +12,7 @@ def foo(a: bool) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") + vyper_ast = parse_and_fold(f"not {bool_cond}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -30,7 +30,7 @@ def foo(a: bool) -> bool: contract = get_contract(source) literal_op = f"{'not ' * count}{bool_cond}" - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index d413340083..a6bc3147e6 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -24,7 +24,7 @@ def foo(): """ @external def foo(): - for i in [0x6b175474e89094c44da98b954eedeac495271d0F]: + for i: address in [0x6b175474e89094c44da98b954eedeac495271d0F]: pass """, """ diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 16ce6fe631..b202f6d8a3 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,10 +28,10 @@ def foo() -> int128: def get_contract_info(source_code): - _, class_types, reformatted_code = pre_parse(source_code) + _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) - annotate_python_ast(py_ast, reformatted_code, class_types) + annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) return py_ast, reformatted_code diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 682c13ca84..020e83627c 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -173,7 +173,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _ = pre_parse(code) + settings, _, _, _ = pre_parse(code) assert settings == pre_parse_settings diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index a50946ba21..c805e2b5b1 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -113,18 +113,23 @@ def test_keyerror_becomes_jsonerror(input_json): def test_compile_json(input_json, input_bundle): foo_input = input_bundle.load_file("contracts/foo.vy") + # remove bb and bb_runtime from output formats + # because they require venom (experimental) + output_formats = OUTPUT_FORMATS.copy() + del output_formats["bb"] + del output_formats["bb_runtime"] foo = compile_from_file_input( - foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + foo_input, output_formats=output_formats, input_bundle=input_bundle ) library_input = input_bundle.load_file("contracts/library.vy") library = compile_from_file_input( - library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + library_input, output_formats=output_formats, input_bundle=input_bundle ) bar_input = input_bundle.load_file("contracts/bar.vy") bar = compile_from_file_input( - bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + bar_input, output_formats=output_formats, input_bundle=input_bundle ) compile_code_results = { diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 44b823757c..b2851e908a 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -58,7 +58,7 @@ def ctor_only(): @internal def runtime_only(): - for i in range(10): + for i: uint256 in range(10): self.s += 1 @external diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index c9a152b09c..5b478dd2aa 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -6,7 +6,7 @@ @internal def _baz(a: int128) -> int128: b: int128 = a - for i in range(2, 5): + for i: int128 in range(2, 5): b *= i if b > 31337: break diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index a51992df67..b96c7f3351 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -18,10 +18,10 @@ def test_duplicate_operands(): ctx = IRFunction() bb = ctx.get_basic_block() op = bb.append_instruction("store", 10) - sum = bb.append_instruction("add", op, op) - bb.append_instruction("mul", sum, op) + sum_ = bb.append_instruction("add", op, op) + bb.append_instruction("mul", sum_, op) bb.append_instruction("stop") - asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) + asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.CODESIZE) assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 104697432b..6d8b074994 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -41,7 +41,7 @@ def test_multi_entry_block_1(): finish_bb = ctx.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in.keys()) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" @@ -93,7 +93,7 @@ def test_multi_entry_block_2(): finish_bb = ctx.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in.keys()) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" @@ -134,5 +134,5 @@ def test_multi_entry_block_with_dynamic_jump(): finish_bb = ctx.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in.keys()) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index e2c0f555af..607587cc28 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -22,7 +22,7 @@ def foo(): @internal def bar(): self.foo() - for i in self.a: + for i: uint256 in self.a: pass """ vyper_module = parse_to_ast(code) @@ -42,7 +42,7 @@ def foo(a: uint256[3]) -> uint256[3]: @internal def bar(): a: uint256[3] = [1,2,3] - for i in a: + for i: uint256 in a: self.foo(a) """ vyper_module = parse_to_ast(code) @@ -56,7 +56,7 @@ def test_modify_iterator(dummy_input_bundle): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.a[0] = 1 """ vyper_module = parse_to_ast(code) @@ -70,7 +70,7 @@ def test_bad_keywords(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, boundddd=10): + for i: uint256 in range(n, boundddd=10): x += i """ vyper_module = parse_to_ast(code) @@ -84,7 +84,7 @@ def test_bad_bound(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, bound=n): + for i: uint256 in range(n, bound=n): x += i """ vyper_module = parse_to_ast(code) @@ -103,7 +103,7 @@ def foo(): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.foo() """ vyper_module = parse_to_ast(code) @@ -126,7 +126,7 @@ def bar(): @internal def baz(): - for i in self.a: + for i: uint256 in self.a: self.bar() """ vyper_module = parse_to_ast(code) @@ -138,32 +138,32 @@ def baz(): """ @external def main(): - for j in range(3): + for j: uint256 in range(3): x: uint256 = j y: uint16 = j """, # GH issue 3212 """ @external def foo(): - for i in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1]: - for j in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + for j: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1,2,3]: - for j in [1,2,3]: - b:uint256 = j + i - c:uint16 = i + for i: uint256 in [1,2,3]: + for j: uint256 in [1,2,3]: + b: uint256 = j + i + c: uint16 = i """, # GH issue 3374 ] diff --git a/tests/utils.py b/tests/utils.py index 0c89c39ff3..25dad818ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,9 @@ import contextlib import os +from vyper import ast as vy_ast +from vyper.semantics.analysis.constant_folding import constant_fold + @contextlib.contextmanager def working_directory(directory): @@ -10,3 +13,9 @@ def working_directory(directory): yield finally: os.chdir(tmp) + + +def parse_and_fold(source_code): + ast = vy_ast.parse_to_ast(source_code) + constant_fold(ast) + return ast diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 7889473b19..234e96e552 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -178,8 +178,7 @@ body: _NEWLINE _INDENT ([COMMENT] _NEWLINE | _stmt)+ _DEDENT cond_exec: _expr ":" body default_exec: body if_stmt: "if" cond_exec ("elif" cond_exec)* ["else" ":" default_exec] -// TODO: make this into a variable definition e.g. `for i: uint256 in range(0, 5): ...` -loop_variable: NAME [":" NAME] +loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 82afcb9217..de15fb9075 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,15 @@ ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code -NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata") +NODE_BASE_ATTRIBUTES = ( + "_children", + "_depth", + "_parent", + "ast_type", + "node_id", + "_metadata", + "_original_node", +) NODE_SRC_ATTRIBUTES = ( "col_offset", "end_col_offset", @@ -255,6 +263,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): self.set_parent(parent) self._children: set = set() self._metadata: NodeMetadata = NodeMetadata() + self._original_node = None for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset @@ -392,46 +401,31 @@ def has_folded_value(self): """ return "folded_value" in self._metadata - def get_folded_value(self) -> "VyperNode": + def get_folded_value(self) -> "ExprNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. - - - The returned value is cached on `_metadata["folded_value"]`. - - For constant/literal nodes, the node should be directly returned - without caching to the metadata. """ - if self.is_literal_value: - return self - - if "folded_value" not in self._metadata: - res = self._try_fold() # possibly throws UnfoldableNode - self._set_folded_value(res) - - return self._metadata["folded_value"] + try: + return self._metadata["folded_value"] + except KeyError: + raise UnfoldableNode("not foldable", self) def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once assert "folded_value" not in self._metadata - # set the folded node's parent so that get_ancestor works - # this is mainly important for error messages. - node._parent = self._parent + # set the "original node" so that exceptions can point to the original + # node and not the folded node + cls = node.__class__ + # make a fresh copy so that the node metadata is fresh. + node = cls(**{i: getattr(node, i) for i in node.get_fields() if hasattr(node, i)}) + node._original_node = self self._metadata["folded_value"] = node - def _try_fold(self) -> "VyperNode": - """ - Attempt to constant-fold the content of a node, returning the result of - constant-folding if possible. - - If a node cannot be folded, it should raise `UnfoldableNode`. This - base implementation acts as a catch-all to raise on any inherited - classes that do not implement the method. - """ - raise UnfoldableNode(f"{type(self)} cannot be folded") + def get_original_node(self) -> "VyperNode": + return self._original_node or self def validate(self) -> None: """ @@ -918,10 +912,6 @@ class List(ExprNode): def is_literal_value(self): return all(e.is_literal_value for e in self.elements) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class Tuple(ExprNode): __slots__ = ("elements",) @@ -935,10 +925,6 @@ def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class NameConstant(Constant): __slots__ = () @@ -959,10 +945,6 @@ class Dict(ExprNode): def is_literal_value(self): return all(v.is_literal_value for v in self.values) - def _try_fold(self) -> ExprNode: - values = [v.get_folded_value() for v in self.values] - return type(self).from_node(self, values=values) - class Name(ExprNode): __slots__ = ("id",) @@ -971,27 +953,6 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the unary operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - operand = self.operand.get_folded_value() - - if isinstance(self.op, Not) and not isinstance(operand, NameConstant): - raise UnfoldableNode("not a boolean!", self.operand) - if isinstance(self.op, USub) and not isinstance(operand, Num): - raise UnfoldableNode("not a number!", self.operand) - if isinstance(self.op, Invert) and not isinstance(operand, Int): - raise UnfoldableNode("not an int!", self.operand) - - value = self.op._op(operand.value) - return type(operand).from_node(self, value=value) - class Operator(VyperNode): pass @@ -1020,30 +981,6 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the arithmetic operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if type(left) is not type(right): - raise UnfoldableNode("invalid operation", self) - if not isinstance(left, Num): - raise UnfoldableNode("not a number!", self.left) - - # this validation is performed to prevent the compiler from hanging - # on very large shifts and improve the error message for negative - # values. - if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) - - value = self.op._op(left.value, right.value) - return type(left).from_node(self, value=value) - class Add(Operator): __slots__ = () @@ -1169,24 +1106,6 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the boolean operation. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - values = [v.get_folded_value() for v in self.values] - - if any(not isinstance(v, NameConstant) for v in values): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - values = [v.value for v in values] - value = self.op._op(values) - return NameConstant.from_node(self, value=value) - class And(Operator): __slots__ = () @@ -1224,40 +1143,6 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the comparison. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if not isinstance(left, Constant): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - # CMC 2022-08-04 we could probably remove these evaluation rules as they - # are taken care of in the IR optimizer now. - if isinstance(self.op, (In, NotIn)): - if not isinstance(right, List): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if next((i for i in right.elements if not isinstance(i, Constant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if len(set([type(i) for i in right.elements])) > 1: - raise UnfoldableNode("List contains multiple literal types") - value = self.op._op(left.value, [i.value for i in right.elements]) - return NameConstant.from_node(self, value=value) - - if not isinstance(left, type(right)): - raise UnfoldableNode("Cannot compare different literal types") - - if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)): - raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self) - - value = self.op._op(left.value, right.value) - return NameConstant.from_node(self, value=value) - class Eq(Operator): __slots__ = () @@ -1319,27 +1204,15 @@ def is_terminus(self): # cursed import cycle! from vyper.builtins.functions import get_builtin_functions - func_name = self.func.get("id") - if not func_name: - return False - - builtin_t = get_builtin_functions().get(func_name) - return getattr(builtin_t, "_is_terminus", False) - - # try checking if this is a builtin, which is foldable - def _try_fold(self): if not isinstance(self.func, Name): - raise UnfoldableNode("not a builtin", self) - - # cursed import cycle! - from vyper.builtins.functions import DISPATCH_TABLE + return False - func_name = self.func.id - if func_name not in DISPATCH_TABLE: - raise UnfoldableNode("not a builtin", self) + funcname = self.func.id + builtin_t = get_builtin_functions().get(funcname) + if builtin_t is None: + return False - builtin_t = DISPATCH_TABLE[func_name] - return builtin_t._try_fold(self) + return builtin_t._is_terminus class keyword(VyperNode): @@ -1353,37 +1226,6 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the subscript. - - This method reduces an indexed reference to a literal array into the value - within the array, e.g. `["foo", "bar"][1]` becomes `"bar"` - - Returns - ------- - ExprNode - Node representing the result of the evaluation. - """ - slice_ = self.slice.value.get_folded_value() - value = self.value.get_folded_value() - - if not isinstance(value, List): - raise UnfoldableNode("Subscript object is not a literal list") - - elements = value.elements - if len(set([type(i) for i in elements])) > 1: - raise UnfoldableNode("List contains multiple node types") - - if not isinstance(slice_, Int): - raise UnfoldableNode("invalid index type", slice_) - - idx = slice_.value - if idx < 0 or idx >= len(elements): - raise UnfoldableNode("invalid index value") - - return elements[idx] - class Index(VyperNode): __slots__ = ("value",) @@ -1560,8 +1402,8 @@ class ImplementsDecl(Stmt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not isinstance(self.annotation, Name): - raise StructureException("not an identifier", self.annotation) + if not isinstance(self.annotation, (Name, Attribute)): + raise StructureException("invalid implements", self.annotation) class If(Stmt): @@ -1573,7 +1415,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "target", "body") + __slots__ = ("target", "iter", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 8bc4a4eb57..7f8c902d45 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -30,9 +30,8 @@ class VyperNode: def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def get_folded_value(self) -> VyperNode: ... - def _try_fold(self) -> VyperNode: ... - def _set_folded_value(self, node: VyperNode) -> None: ... + def get_folded_value(self) -> ExprNode: ... + def _set_folded_value(self, node: ExprNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 38a9d31695..cc0a47824c 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -54,7 +54,7 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) except SyntaxError as e: @@ -73,11 +73,15 @@ def parse_to_ast_with_settings( py_ast, source_code, class_types, + for_loop_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) + # postcondition: consumed all the for loop annotations + assert len(for_loop_annotations) == 0 + # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint @@ -110,14 +114,60 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') +def annotate_python_ast( + parsed_ast: python_ast.AST, + source_code: str, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, +) -> python_ast.AST: + """ + Annotate and optimize a Python AST in preparation conversion to a Vyper AST. + + Parameters + ---------- + parsed_ast : AST + The AST to be annotated and optimized. + source_code : str + The originating source code of the AST. + loop_var_annotations: dict + A mapping of line numbers of `For` nodes to the tokens of the type + annotation of the iterator extracted during pre-parsing. + modification_offsets : dict + A mapping of class names to their original class types. + + Returns + ------- + The annotated and optimized AST. + """ + + tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + for_loop_annotations, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + visitor.visit(parsed_ast) + + return parsed_ast + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets + _loop_var_annotations: dict[int, dict[str, Any]] def __init__( self, source_code: str, - modification_offsets: Optional[ModificationOffsets], + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -127,11 +177,11 @@ def __init__( self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path - self._source_code: str = source_code + self._source_code = source_code + self._modification_offsets = modification_offsets + self._for_loop_annotations = for_loop_annotations + self.counter: int = 0 - self._modification_offsets = {} - if modification_offsets is not None: - self._modification_offsets = modification_offsets def generic_visit(self, node): """ @@ -144,7 +194,9 @@ def generic_visit(self, node): self.counter += 1 # Decorate every node with source end offsets - start = node.first_token.start if hasattr(node, "first_token") else (None, None) + start = (None, None) + if hasattr(node, "first_token"): + start = node.first_token.start end = (None, None) if hasattr(node, "last_token"): end = node.last_token.end @@ -161,6 +213,7 @@ def generic_visit(self, node): if hasattr(node, "last_token"): start_pos = node.first_token.startpos end_pos = node.last_token.endpos + if node.last_token.type == 4: # ignore trailing newline once more end_pos -= 1 @@ -213,6 +266,62 @@ def visit_ClassDef(self, node): node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] return node + def visit_For(self, node): + """ + Visit a For node, splicing in the loop variable annotation provided by + the pre-parser + """ + annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + + if not annotation_tokens: + # a common case for people migrating to 0.4.0, provide a more + # specific error message than "invalid type annotation" + raise SyntaxException( + "missing type annotation\n\n" + "(hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)\n", + self._source_code, + node.lineno, + node.col_offset, + ) + + # some kind of black magic. untokenize preserves the line and column + # offsets, giving us something like `\ + # \ + # \ + # uint8` + # that's not a valid python Expr because it is indented. + # but it's good because the code is indented to exactly the same + # offset as it did in the original source! + # (to best understand this, print out annotation_str and + # self._source_code and compare them side-by-side). + # + # what we do here is add in a dummy target which we will remove + # in a bit, but for now lets us keep the line/col offset, and + # *also* gives us a valid AST. it doesn't matter what the dummy + # target name is, since it gets removed in a few lines. + annotation_str = tokenize.untokenize(annotation_tokens) + annotation_str = "dummy_target:" + annotation_str + + try: + fake_node = python_ast.parse(annotation_str).body[0] + except SyntaxError as e: + raise SyntaxException( + "invalid type annotation", self._source_code, node.lineno, node.col_offset + ) from e + + # fill in with asttokens info. note we can use `self._tokens` because + # it is indented to exactly the same position where it appeared + # in the original source! + self._tokens.mark_tokens(fake_node) + + # replace the dummy target name with the real target name. + fake_node.target = node.target + # replace the For node target with the new ann_assign + node.target = fake_node + + return self.generic_visit(node) + def visit_Expr(self, node): """ Convert the `Yield` node into a Vyper-specific node type. @@ -350,42 +459,3 @@ def visit_UnaryOp(self, node): return node.operand else: return node - - -def annotate_python_ast( - parsed_ast: python_ast.AST, - source_code: str, - modification_offsets: Optional[ModificationOffsets] = None, - source_id: int = 0, - module_path: Optional[str] = None, - resolved_path: Optional[str] = None, -) -> python_ast.AST: - """ - Annotate and optimize a Python AST in preparation conversion to a Vyper AST. - - Parameters - ---------- - parsed_ast : AST - The AST to be annotated and optimized. - source_code : str - The originating source code of the AST. - modification_offsets : dict, optional - A mapping of class names to their original class types. - - Returns - ------- - The annotated and optimized AST. - """ - - tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor( - source_code, - modification_offsets, - tokens, - source_id, - module_path=module_path, - resolved_path=resolved_path, - ) - visitor.visit(parsed_ast) - - return parsed_ast diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index b949a242bb..159dfc0ace 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,3 +1,4 @@ +import enum import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize @@ -43,6 +44,64 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: ) +class ForParserState(enum.Enum): + NOT_RUNNING = enum.auto() + START_SOON = enum.auto() + RUNNING = enum.auto() + + +# a simple state machine which allows us to handle loop variable annotations +# (which are rejected by the python parser due to pep-526, so we scoop up the +# tokens between `:` and `in` and parse them and add them back in later). +class ForParser: + def __init__(self, code): + self._code = code + self.annotations = {} + self._current_annotation = None + + self._state = ForParserState.NOT_RUNNING + self._current_for_loop = None + + def consume(self, token): + # state machine: we can start slurping tokens soon + if token.type == NAME and token.string == "for": + # note: self._state should be NOT_RUNNING here, but we don't sanity + # check here as that should be an error the parser will handle. + self._state = ForParserState.START_SOON + self._current_for_loop = token.start + + if self._state == ForParserState.NOT_RUNNING: + return False + + # state machine: start slurping tokens + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + + self._current_annotation = [] + return True # do not add ":" to tokens. + + # state machine: end slurping tokens + if token.type == NAME and token.string == "in": + self._state = ForParserState.NOT_RUNNING + self.annotations[self._current_for_loop] = self._current_annotation or [] + self._current_annotation = None + return False + + if self._state != ForParserState.RUNNING: + return False + + # slurp the token + self._current_annotation.append(token) + return True + + # compound statements that are replaced with `class` # TODO remove enum in favor of flag VYPER_CLASS_TYPES = {"flag", "enum", "event", "interface", "struct"} @@ -51,7 +110,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -60,9 +119,11 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"). + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. Parameters ---------- @@ -71,21 +132,25 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Returns ------- - dict - Mapping of offsets where source was modified. + Settings + Compilation settings based on the directives in the source code + ModificationOffsets + A mapping of class names to their original class types. + dict[tuple[int, int], list[TokenInfo]] + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() + for_parser = ForParser(code) try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) - for i in range(len(token_list)): - token = token_list[i] + for token in token_list: toks = [token] typ = token.type @@ -146,8 +211,15 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - result.extend(toks) + + if not for_parser.consume(token): + result.extend(toks) + except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - return settings, modification_offsets, untokenize(result).decode("utf-8") + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): + for_loop_annotations[k] = v.copy() + + return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index aac008ad1e..1a488f39e0 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,6 +85,7 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _is_terminus = False # helper function to deal with TYPE_DEFINITIONs def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index c896fc7ef6..4f8101dfbe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -90,6 +90,7 @@ ceil32, fourbytes_to_int, keccak256, + method_id, method_id_int, vyper_warn, ) @@ -723,12 +724,12 @@ def _try_fold(self, node): raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) return_type = self.infer_kwarg_types(node)["output_type"].typedef - value = method_id_int(value.value) + value = method_id(value.value) if return_type.compare_type(BYTES4_T): - return vy_ast.Hex.from_node(node, value=hex(value)) + return vy_ast.Hex.from_node(node, value="0x" + value.hex()) else: - return vy_ast.Bytes.from_node(node, value=value.to_bytes(4, "big")) + return vy_ast.Bytes.from_node(node, value=value) def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) @@ -2157,7 +2158,7 @@ def build_IR(self, expr, args, kwargs, context): z = x / 2.0 + 0.5 y: decimal = x - for i in range(256): + for i: uint256 in range(256): if z == y: break y = z diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 475ffe3cfc..7d4938f287 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -32,7 +32,7 @@ ) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.shortcuts import INT256_T, UINT256_T +from vyper.semantics.types.shortcuts import UINT256_T class Stmt: @@ -230,19 +230,17 @@ def parse_For(self): return self._parse_For_list() def _parse_For_range(self): - # TODO make sure type always gets annotated - if "type" in self.stmt.target._metadata: - iter_typ = self.stmt.target._metadata["type"] - else: - iter_typ = INT256_T + assert "type" in self.stmt.target.target._metadata + target_type = self.stmt.target.target._metadata["type"] # Get arg0 - for_iter: vy_ast.Call = self.stmt.iter - args_len = len(for_iter.args) + range_call: vy_ast.Call = self.stmt.iter + assert isinstance(range_call, vy_ast.Call) + args_len = len(range_call.args) if args_len == 1: - arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + arg0, arg1 = (IRnode.from_list(0, typ=target_type), range_call.args[0]) elif args_len == 2: - arg0, arg1 = for_iter.args + arg0, arg1 = range_call.args else: # pragma: nocover raise TypeCheckFailure("unreachable: bad # of arguments to range()") @@ -250,7 +248,7 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) end = Expr.parse_value_expr(arg1, self.context) kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + s.arg: Expr.parse_value_expr(s.value, self.context) for s in range_call.keywords } if "bound" in kwargs: @@ -269,9 +267,9 @@ def _parse_For_range(self): if rounds_bound < 1: # pragma: nocover raise TypeCheckFailure("unreachable: unchecked 0 bound") - varname = self.stmt.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) - iptr = self.context.new_variable(varname, iter_typ) + varname = self.stmt.target.target.id + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) + iptr = self.context.new_variable(varname, target_type) self.context.forvars[varname] = True @@ -296,11 +294,11 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - target_type = self.stmt.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] assert target_type == iter_list.typ.value_type # user-supplied name for loop variable - varname = self.stmt.target.id + varname = self.stmt.target.target.id loop_var = IRnode.from_list( self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY ) diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0f7d7a8014..9297f9e3c3 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -23,6 +23,8 @@ # requires ir_node "external_interface": output.build_external_interface_output, "interface": output.build_interface_output, + "bb": output.build_bb_output, + "bb_runtime": output.build_bb_runtime_output, "ir": output.build_ir_output, "ir_runtime": output.build_ir_runtime_output, "ir_dict": output.build_ir_dict_output, @@ -84,6 +86,8 @@ def compile_from_file_input( two arguments - the name of the contract, and the exception that was raised no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool + Use experimental codegen. Defaults to False Returns ------- diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 8ccf6abee1..5e11a20139 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -84,6 +84,14 @@ def build_interface_output(compiler_data: CompilerData) -> str: return out +def build_bb_output(compiler_data: CompilerData) -> IRnode: + return compiler_data.venom_functions[0] + + +def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: + return compiler_data.venom_functions[1] + + def build_ir_output(compiler_data: CompilerData) -> IRnode: if compiler_data.show_gas_estimates: IRnode.repr_show_gas = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 8cbcfb1da9..ba6ccbda20 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -174,13 +174,7 @@ def global_ctx(self) -> ModuleT: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - nodes = generate_ir_nodes( - self.global_ctx, self.settings.optimize, self.settings.experimental_codegen - ) - if self.settings.experimental_codegen: - return [generate_ir(nodes[0]), generate_ir(nodes[1])] - else: - return nodes + return generate_ir_nodes(self.global_ctx, self.settings.optimize) @property def ir_nodes(self) -> IRnode: @@ -201,11 +195,17 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: fs = self.annotated_vyper_module.get_children(vy_ast.FunctionDef) return {f.name: f._metadata["func_type"] for f in fs} + @cached_property + def venom_functions(self): + return generate_ir(self.ir_nodes, self.settings.optimize) + @cached_property def assembly(self) -> list: if self.settings.experimental_codegen: + deploy_code, runtime_code = self.venom_functions + assert self.settings.optimize is not None # mypy hint return generate_assembly_experimental( - self.ir_nodes, self.settings.optimize # type: ignore + runtime_code, deploy_code=deploy_code, optimize=self.settings.optimize ) else: return generate_assembly(self.ir_nodes, self.settings.optimize) @@ -213,9 +213,9 @@ def assembly(self) -> list: @cached_property def assembly_runtime(self) -> list: if self.settings.experimental_codegen: - return generate_assembly_experimental( - self.ir_runtime, self.settings.optimize # type: ignore - ) + _, runtime_code = self.venom_functions + assert self.settings.optimize is not None # mypy hint + return generate_assembly_experimental(runtime_code, optimize=self.settings.optimize) else: return generate_assembly(self.ir_runtime, self.settings.optimize) @@ -270,9 +270,7 @@ def generate_annotated_ast( return vyper_module, symbol_tables -def generate_ir_nodes( - global_ctx: ModuleT, optimize: OptimizationLevel, experimental_codegen: bool -) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..04667aaa59 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -92,6 +92,10 @@ def __str__(self): node = value[1] if isinstance(value, tuple) else value node_msg = "" + if isinstance(node, vy_ast.VyperNode): + # folded AST nodes contain pointers to the original source + node = node.get_original_node() + try: source_annotation = annotate_source_code( # add trailing space because EOF exceptions point one char beyond the length @@ -369,7 +373,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): raise e from None except Exception as e: tb = e.__traceback__ - fallback_message = "unhandled exception" + fallback_message = f"unhandled exception {e}" if note: fallback_message += f", {note}" raise fallback_exception_type(fallback_message, node).with_traceback(tb) diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py new file mode 100644 index 0000000000..b165a6dae9 --- /dev/null +++ b/vyper/semantics/analysis/constant_folding.py @@ -0,0 +1,237 @@ +from vyper import ast as vy_ast +from vyper.exceptions import InvalidLiteral, UnfoldableNode, VyperException +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.namespace import get_namespace + + +def constant_fold(module_ast: vy_ast.Module): + ConstantFolder(module_ast).run() + + +class ConstantFolder(VyperNodeVisitorBase): + def __init__(self, module_ast): + self._constants = {} + self._module_ast = module_ast + + def run(self): + self._get_constants() + self.visit(self._module_ast) + + def _get_constants(self): + module = self._module_ast + const_var_decls = module.get_children(vy_ast.VariableDecl, {"is_constant": True}) + + while True: + n_processed = 0 + + for c in const_var_decls.copy(): + # visit the entire constant node in case its type annotation + # has unfolded constants in it. + self.visit(c) + + assert c.value is not None # guaranteed by VariableDecl.validate() + try: + val = c.value.get_folded_value() + except UnfoldableNode: + # not foldable, maybe it depends on other constants + # so try again later + continue + + # note that if a constant is redefined, its value will be + # overwritten, but it is okay because the error is handled + # downstream + name = c.target.id + self._constants[name] = val + + n_processed += 1 + const_var_decls.remove(c) + + if n_processed == 0: + # this condition means that there are some constant vardecls + # whose values are not foldable. this can happen for struct + # and interface constants for instance. these are valid constant + # declarations, but we just can't fold them at this stage. + break + + def visit(self, node): + if node.has_folded_value: + return node.get_folded_value() + + for c in node.get_children(): + try: + self.visit(c) + except UnfoldableNode: + # ignore bubbled up exceptions + pass + + try: + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + folded_value = visitor_fn(node) + node._set_folded_value(folded_value) + return folded_value + except UnfoldableNode: + # ignore bubbled up exceptions + pass + + return node + + def visit_Constant(self, node) -> vy_ast.ExprNode: + return node + + def visit_Name(self, node) -> vy_ast.ExprNode: + try: + return self._constants[node.id] + except KeyError: + raise UnfoldableNode("unknown name", node) + + def visit_Attribute(self, node) -> vy_ast.ExprNode: + namespace = get_namespace() + path = [] + value = node.value + while isinstance(value, vy_ast.Attribute): + path.append(value.attr) + value = value.value + + path.reverse() + + if not isinstance(value, vy_ast.Name): + raise UnfoldableNode("not a module", value) + + # not super type-safe but we don't care. just catch AttributeErrors + # and move on + try: + module_t = namespace[value.id].module_t + + for module_name in path: + module_t = module_t.members[module_name].module_t + + varinfo = module_t.get_member(node.attr, node) + + return varinfo.decl_node.value.get_folded_value() + except (VyperException, AttributeError): + raise UnfoldableNode("not a module") + + def visit_UnaryOp(self, node): + operand = node.operand.get_folded_value() + + if isinstance(node.op, vy_ast.Not) and not isinstance(operand, vy_ast.NameConstant): + raise UnfoldableNode("not a boolean!", node.operand) + if isinstance(node.op, vy_ast.USub) and not isinstance(operand, vy_ast.Num): + raise UnfoldableNode("not a number!", node.operand) + if isinstance(node.op, vy_ast.Invert) and not isinstance(operand, vy_ast.Int): + raise UnfoldableNode("not an int!", node.operand) + + value = node.op._op(operand.value) + return type(operand).from_node(node, value=value) + + def visit_BinOp(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if type(left) is not type(right): + raise UnfoldableNode("invalid operation", node) + if not isinstance(left, vy_ast.Num): + raise UnfoldableNode("not a number!", node.left) + + # this validation is performed to prevent the compiler from hanging + # on very large shifts and improve the error message for negative + # values. + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256): + raise InvalidLiteral("Shift bits must be between 0 and 256", node.right) + + value = node.op._op(left.value, right.value) + return type(left).from_node(node, value=value) + + def visit_BoolOp(self, node): + values = [v.get_folded_value() for v in node.values] + + if any(not isinstance(v, vy_ast.NameConstant) for v in values): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + values = [v.value for v in values] + value = node.op._op(values) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_Compare(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if not isinstance(left, vy_ast.Constant): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + # CMC 2022-08-04 we could probably remove these evaluation rules as they + # are taken care of in the IR optimizer now. + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + if not isinstance(right, vy_ast.List): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if next((i for i in right.elements if not isinstance(i, vy_ast.Constant)), None): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if len(set([type(i) for i in right.elements])) > 1: + raise UnfoldableNode("List contains multiple literal types") + value = node.op._op(left.value, [i.value for i in right.elements]) + return vy_ast.NameConstant.from_node(node, value=value) + + if not isinstance(left, type(right)): + raise UnfoldableNode("Cannot compare different literal types") + + # this is maybe just handled in the type checker. + if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(left, vy_ast.Num): + raise UnfoldableNode( + f"Invalid literal types for {node.op.description} comparison", node + ) + + value = node.op._op(left.value, right.value) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_List(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Tuple(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Dict(self, node) -> vy_ast.ExprNode: + values = [v.get_folded_value() for v in node.values] + return type(node).from_node(node, values=values) + + def visit_Call(self, node) -> vy_ast.ExprNode: + if not isinstance(node.func, vy_ast.Name): + raise UnfoldableNode("not a builtin", node) + + namespace = get_namespace() + + func_name = node.func.id + if func_name not in namespace: + raise UnfoldableNode("unknown", node) + + varinfo = namespace[func_name] + if not isinstance(varinfo, VarInfo): + raise UnfoldableNode("unfoldable", node) + + typ = varinfo.typ + # TODO: rename to vyper_type.try_fold_call_expr + if not hasattr(typ, "_try_fold"): + raise UnfoldableNode("unfoldable", node) + return typ._try_fold(node) # type: ignore + + def visit_Subscript(self, node) -> vy_ast.ExprNode: + slice_ = node.slice.value.get_folded_value() + value = node.value.get_folded_value() + + if not isinstance(value, vy_ast.List): + raise UnfoldableNode("Subscript object is not a literal list") + + elements = value.elements + if len(set([type(i) for i in elements])) > 1: + raise UnfoldableNode("List contains multiple node types") + + if not isinstance(slice_, vy_ast.Int): + raise UnfoldableNode("invalid index type", slice_) + + idx = slice_.value + if idx < 0 or idx >= len(elements): + raise UnfoldableNode("invalid index value") + + return elements[idx] diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 00804bfec9..c4af5b1e3a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,13 +1,11 @@ from typing import Optional from vyper import ast as vy_ast -from vyper.ast.metadata import NodeMetadata from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -40,7 +38,6 @@ EventT, FlagT, HashMapT, - IntegerT, SArrayT, StringT, StructT, @@ -351,8 +348,10 @@ def visit_Expr(self, node): self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): - if isinstance(node.iter, vy_ast.Subscript): - raise StructureException("Cannot iterate over a nested list", node.iter) + if not isinstance(node.target.target, vy_ast.Name): + raise StructureException("Invalid syntax for loop iterator", node.target.target) + + target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -360,7 +359,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - type_list = _analyse_range_call(node.iter) + _validate_range_call(node.iter) else: # iteration over a variable or literal list @@ -368,14 +367,10 @@ def visit_For(self, node): if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) - type_list = [ - i.value_type - for i in get_possible_types_from_node(node.iter) - if isinstance(i, (DArrayT, SArrayT)) - ] - - if not type_list: - raise InvalidType("Not an iterable type", node.iter) + if not any( + isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) + ): + raise InvalidType("Not an iterable type", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): # check for references to the iterated value within the body of the loop @@ -419,65 +414,28 @@ def visit_For(self, node): call_node, ) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for loop iterator", node.target) + target_name = node.target.target.id + with self.namespace.enter_scope(): + self.namespace[target_name] = VarInfo( + target_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) - for_loop_exceptions = [] - iter_name = node.target.id - for possible_target_type in type_list: - # type check the for loop body using each possible type for iterator value + for stmt in node.body: + self.visit(stmt) - with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo( - possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT - ) + self.expr_visitor.visit(node.target.target, target_type) - try: - with NodeMetadata.enter_typechecker_speculation(): - for stmt in node.body: - self.visit(stmt) - - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - - except (TypeMismatch, InvalidOperation) as exc: - for_loop_exceptions.append(exc) - else: - # success -- do not enter error handling section - return - - # failed to find a good type. bail out - if len(set(str(i) for i in for_loop_exceptions)) == 1: - # if every attempt at type checking raised the same exception - raise for_loop_exceptions[0] - - # return an aggregate TypeMismatch that shows all possible exceptions - # depending on which type is used - types_str = [str(i) for i in type_list] - given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - raise TypeMismatch( - f"Iterator value '{iter_name}' may be cast as {given_str}, " - "but type checking fails with all possible types:", - node, - *( - (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - for typ, exc in zip(type_list, for_loop_exceptions) - ), - ) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) + elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + args = node.iter.args + kwargs = [s.value for s in node.iter.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + else: + iter_type = get_exact_type_from_node(node.iter) + self.expr_visitor.visit(node.iter, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -556,8 +514,7 @@ def visit(self, node, typ): # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() - validate_expected_type(folded_node, typ) - folded_node._metadata["type"] = typ + self.visit(folded_node, typ) def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) @@ -754,25 +711,18 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: +def _validate_range_call(node: vy_ast.Call): """ Check that the arguments to a range() call are valid. :param node: call to range() :return: None """ + assert node.func.get("id") == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] - all_args = (start, end, *kwargs.values()) - for arg1 in all_args: - validate_expected_type(arg1, IntegerT.any()) - - type_list = get_common_types(*all_args) - if not type_list: - raise TypeMismatch("Iterator values are of different types", node) - if "bound" in kwargs: bound = kwargs["bound"] if bound.has_folded_value: @@ -791,5 +741,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: raise StateAccessViolation(error, arg) if end.value <= start.value: raise StructureException("End must be greater than start", end) - - return type_list diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8e435f870f..100819526b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -23,14 +23,10 @@ ) from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.pre_typecheck import pre_typecheck -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - validate_expected_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT @@ -55,8 +51,6 @@ def validate_semantics_r( """ validate_literal_nodes(module_ast) - pre_typecheck(module_ast) - # validate semantics and annotate AST with type/semantics information namespace = get_namespace() @@ -144,6 +138,9 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + # we can resolve constants after imports are handled. + constant_fold(self.ast) + # keep trying to process all the nodes until we finish or can # no longer progress. this makes it so we don't need to # calculate a dependency tree between top-level items. @@ -315,12 +312,11 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) + ExprVisitor().visit(node.value, type_) # performs validate_expected_type if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) - validate_expected_type(node.value, type_) _validate_self_namespace() return _finalize() @@ -388,8 +384,9 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_InterfaceDef(node) - self.namespace[node.name] = obj + interface_t = InterfaceT.from_InterfaceDef(node) + node._metadata["interface_type"] = interface_t + self.namespace[node.name] = interface_t def visit_StructDef(self, node): struct_t = StructT.from_StructDef(node) diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py deleted file mode 100644 index a1302ce9c9..0000000000 --- a/vyper/semantics/analysis/pre_typecheck.py +++ /dev/null @@ -1,94 +0,0 @@ -from vyper import ast as vy_ast -from vyper.exceptions import UnfoldableNode - - -# try to fold a node, swallowing exceptions. this function is very similar to -# `VyperNode.get_folded_value()` but additionally checks in the constants -# table if the node is a `Name` node. -# -# CMC 2023-12-30 a potential refactor would be to move this function into -# `Name._try_fold` (which would require modifying the signature of _try_fold to -# take an optional constants table as parameter). this would remove the -# need to use this function in conjunction with `get_descendants` since -# `VyperNode._try_fold()` already recurses. it would also remove the need -# for `VyperNode._set_folded_value()`. -def _fold_with_constants(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): - if node.has_folded_value: - return - - if isinstance(node, vy_ast.Name): - # check if it's in constants table - var_name = node.id - - if var_name not in constants: - return - - res = constants[var_name] - node._set_folded_value(res) - return - - try: - # call get_folded_value for its side effects - node.get_folded_value() - except UnfoldableNode: - pass - - -def _get_constants(node: vy_ast.Module) -> dict: - constants: dict[str, vy_ast.VyperNode] = {} - const_var_decls = node.get_children(vy_ast.VariableDecl, {"is_constant": True}) - - while True: - n_processed = 0 - - for c in const_var_decls.copy(): - assert c.value is not None # guaranteed by VariableDecl.validate() - - for n in c.get_descendants(reverse=True): - _fold_with_constants(n, constants) - - try: - val = c.value.get_folded_value() - except UnfoldableNode: - # not foldable, maybe it depends on other constants - # so try again later - continue - - # note that if a constant is redefined, its value will be - # overwritten, but it is okay because the error is handled - # downstream - name = c.target.id - constants[name] = val - - n_processed += 1 - const_var_decls.remove(c) - - if n_processed == 0: - # this condition means that there are some constant vardecls - # whose values are not foldable. this can happen for struct - # and interface constants for instance. these are valid constant - # declarations, but we just can't fold them at this stage. - break - - return constants - - -# perform constant folding on a module AST -def pre_typecheck(node: vy_ast.Module) -> None: - """ - Perform pre-typechecking steps on a Module AST node. - At this point, this is limited to performing constant folding. - """ - constants = _get_constants(node) - - # note: use reverse to get descendants in leaf-first order - for n in node.get_descendants(reverse=True): - # try folding every single node. note this should be done before - # type checking because the typechecker requires literals or - # foldable nodes in type signatures and some other places (e.g. - # certain builtin kwargs). - # - # note we could limit to only folding nodes which are required - # during type checking, but it's easier to just fold everything - # and be done with it! - _fold_with_constants(n, constants) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ba1b02b8d6..359b51b71e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -650,6 +650,8 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) + + # builtins call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) return call_type_modifiability >= modifiability diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 429ba807e1..14949f693f 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -19,7 +19,7 @@ # type of type `type_` class _GenericTypeAcceptor: def __repr__(self): - return repr(self.type_) + return f"GenericTypeAcceptor({self.type_})" def __init__(self, type_): self.type_ = type_ diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b0d7800011..f2c3d74525 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -4,7 +4,12 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args -from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.exceptions import ( + InterfaceViolation, + NamespaceCollision, + StructureException, + UnfoldableNode, +) from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids from vyper.semantics.namespace import get_namespace @@ -34,6 +39,7 @@ def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> No self._helper = VyperType(events | structs) self._id = _id + self._helper._id = _id self.functions = functions self.events = events self.structs = structs @@ -53,6 +59,15 @@ def abi_type(self) -> ABIType: def __repr__(self): return f"interface {self._id}" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + arg = node.args[0].get_folded_value() + if not isinstance(arg, vy_ast.Hex): + raise UnfoldableNode("not an address", arg) + + return node + # when using the type itself (not an instance) in the call position def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": self._ctor_arg_types(node) @@ -253,6 +268,8 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): + _attribute_in_annotation = True + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -262,7 +279,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): # compute the interface, note this has the side effect of checking # for function collisions - self._helper = self.interface + _ = self.interface + + self._helper = VyperType() + self._helper._id = self._id for f in self.function_defs: # note: this checks for collisions @@ -275,6 +295,12 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for s in self.struct_defs: # add the type of the struct so it can be used in call position self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + self._helper.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for i in self.interface_defs: + # add the type of the interface so it can be used in call position + self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore + self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -308,6 +334,10 @@ def event_defs(self): def struct_defs(self): return self._module.get_children(vy_ast.StructDef) + @property + def interface_defs(self): + return self._module.get_children(vy_ast.InterfaceDef) + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a4e782349d..8ef9aa8d4a 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -10,6 +10,7 @@ InvalidAttribute, NamespaceCollision, StructureException, + UnfoldableNode, UnknownAttribute, VariableDeclarationException, ) @@ -357,6 +358,16 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": def __repr__(self): return f"{self._id} declaration object" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + args = [arg.get_folded_value() for arg in node.args] + if not isinstance(args[0], vy_ast.Dict): + raise UnfoldableNode("not a dict") + + # it can't be reduced, but this lets upstream code know it's constant + return node + @property def size_in_bytes(self): return sum(i.size_in_bytes for i in self.member_types.values()) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index eb96375404..c82eb73afc 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -127,14 +127,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: except UndeclaredDefinition: raise InvalidType(err_msg, node) from None - interface = module_or_interface if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo - interface = module_or_interface.module_t.interface + module_or_interface = module_or_interface.module_t - if not interface._attribute_in_annotation: + if not isinstance(module_or_interface, VyperType): raise InvalidType(err_msg, node) - type_t = interface.get_type_member(node.attr, node) + if not module_or_interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = module_or_interface.get_type_member(node.attr, node) # type: ignore assert isinstance(type_t, TYPE_T) # sanity check return type_t.typedef diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 5a09f8378e..570aba771a 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -1,7 +1,7 @@ # maybe rename this `main.py` or `venom.py` # (can have an `__init__.py` which exposes the API). -from typing import Optional +from typing import Any, Optional from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel @@ -17,19 +17,26 @@ from vyper.venom.passes.dft import DFTPass from vyper.venom.venom_to_assembly import VenomCompiler +DEFAULT_OPT_LEVEL = OptimizationLevel.default() + def generate_assembly_experimental( - ctx: IRFunction, optimize: Optional[OptimizationLevel] = None + runtime_code: IRFunction, + deploy_code: Optional[IRFunction] = None, + optimize: OptimizationLevel = DEFAULT_OPT_LEVEL, ) -> list[str]: - compiler = VenomCompiler(ctx) - return compiler.generate_evm(optimize is OptimizationLevel.NONE) + # note: VenomCompiler is sensitive to the order of these! + if deploy_code is not None: + functions = [deploy_code, runtime_code] + else: + functions = [runtime_code] + compiler = VenomCompiler(functions) + return compiler.generate_evm(optimize == OptimizationLevel.NONE) -def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRFunction: - # Convert "old" IR to "new" IR - ctx = convert_ir_basicblock(ir) - # Run passes on "new" IR +def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: + # Run passes on Venom IR # TODO: Add support for optimization levels while True: changes = 0 @@ -53,4 +60,12 @@ def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRF if changes == 0: break - return ctx + +def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> tuple[IRFunction, IRFunction]: + # Convert "old" IR to "new" IR + ctx, ctx_runtime = convert_ir_basicblock(ir) + + _run_passes(ctx, optimize) + _run_passes(ctx_runtime, optimize) + + return ctx, ctx_runtime diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 6dfc3c3d7c..eed579463e 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -19,27 +19,6 @@ def calculate_cfg(ctx: IRFunction) -> None: bb.cfg_out = OrderedSet() bb.out_vars = OrderedSet() - # TODO: This is a hack to support the old IR format where `deploy` is - # an instruction. in the future we should have two entry points, one - # for the initcode and one for the runtime code. - deploy_bb = None - after_deploy_bb = None - for i, bb in enumerate(ctx.basic_blocks): - if bb.instructions[0].opcode == "deploy": - deploy_bb = bb - after_deploy_bb = ctx.basic_blocks[i + 1] - break - - if deploy_bb is not None: - assert after_deploy_bb is not None, "No block after deploy block" - entry_block = after_deploy_bb - has_constructor = ctx.basic_blocks[0].instructions[0].opcode != "deploy" - if has_constructor: - deploy_bb.add_cfg_in(ctx.basic_blocks[0]) - entry_block.add_cfg_in(deploy_bb) - else: - entry_block = ctx.basic_blocks[0] - for bb in ctx.basic_blocks: assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 9afaa5e6fd..598b8af7d5 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -4,7 +4,7 @@ from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -33,7 +33,6 @@ NO_OUTPUT_INSTRUCTIONS = frozenset( [ - "deploy", "mstore", "sstore", "dstore", @@ -56,9 +55,7 @@ ] ) -CFG_ALTERING_INSTRUCTIONS = frozenset( - ["jmp", "djmp", "jnz", "call", "staticcall", "invoke", "deploy"] -) +CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "djmp", "jnz", "call", "staticcall", "invoke"]) if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -273,7 +270,7 @@ def _ir_operand_from_value(val: Any) -> IROperand: if isinstance(val, IROperand): return val - assert isinstance(val, int) + assert isinstance(val, int), val return IRLiteral(val) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 665fa0c6c2..9f26fa8ec0 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -9,7 +9,7 @@ MemType, ) -GLOBAL_LABEL = IRLabel("global") +GLOBAL_LABEL = IRLabel("__global") class IRFunction: @@ -18,7 +18,10 @@ class IRFunction: """ name: IRLabel # symbol name + entry_points: list[IRLabel] # entry points args: list + ctor_mem_size: Optional[int] + immutables_len: Optional[int] basic_blocks: list[IRBasicBlock] data_segment: list[IRInstruction] last_label: int @@ -28,14 +31,30 @@ def __init__(self, name: IRLabel = None) -> None: if name is None: name = GLOBAL_LABEL self.name = name + self.entry_points = [] self.args = [] + self.ctor_mem_size = None + self.immutables_len = None self.basic_blocks = [] self.data_segment = [] self.last_label = 0 self.last_variable = 0 + self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) + def add_entry_point(self, label: IRLabel) -> None: + """ + Add entry point. + """ + self.entry_points.append(label) + + def remove_entry_point(self, label: IRLabel) -> None: + """ + Remove entry point. + """ + self.entry_points.remove(label) + def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: """ Append basic block to function. @@ -91,7 +110,7 @@ def remove_unreachable_blocks(self) -> int: removed = 0 new_basic_blocks = [] for bb in self.basic_blocks: - if not bb.is_reachable and bb.label.value != "global": + if not bb.is_reachable and bb.label not in self.entry_points: removed += 1 else: new_basic_blocks.append(bb) @@ -119,16 +138,10 @@ def normalized(self) -> bool: if len(bb.cfg_in) <= 1: continue - # Check if there is a conditional jump at the end + # Check if there is a branching jump at the end # of one of the predecessors - # - # TODO: this check could be: - # `if len(in_bb.cfg_out) > 1: return False` - # but the cfg is currently not calculated "correctly" for - # the special deploy instruction. for in_bb in bb.cfg_in: - jump_inst = in_bb.instructions[-1] - if jump_inst.opcode in ("jnz", "djmp"): + if len(in_bb.cfg_out) > 1: return False # The function is normalized diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 9f5c23df0b..c86d3a3d67 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -87,19 +87,35 @@ def _get_symbols_common(a: dict, b: dict) -> dict: return ret -def convert_ir_basicblock(ir: IRnode) -> IRFunction: - global_function = IRFunction() - _convert_ir_basicblock(global_function, ir, {}, OrderedSet(), {}) +def _findIRnode(ir: IRnode, value: str) -> Optional[IRnode]: + if ir.value == value: + return ir + for arg in ir.args: + if isinstance(arg, IRnode): + ret = _findIRnode(arg, value) + if ret is not None: + return ret + return None + + +def convert_ir_basicblock(ir: IRnode) -> tuple[IRFunction, IRFunction]: + deploy_ir = _findIRnode(ir, "deploy") + assert deploy_ir is not None + + deploy_venom = IRFunction() + _convert_ir_basicblock(deploy_venom, ir, {}, OrderedSet(), {}) + deploy_venom.get_basic_block().append_instruction("stop") - for i, bb in enumerate(global_function.basic_blocks): - if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: - bb.append_instruction("jmp", global_function.basic_blocks[i + 1].label) + runtime_ir = deploy_ir.args[1] + runtime_venom = IRFunction() + _convert_ir_basicblock(runtime_venom, runtime_ir, {}, OrderedSet(), {}) - revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) - revert_bb = global_function.append_basic_block(revert_bb) - revert_bb.append_instruction("revert", 0, 0) + # Connect unterminated blocks to the next with a jump + for i, bb in enumerate(runtime_venom.basic_blocks): + if not bb.is_terminated and i < len(runtime_venom.basic_blocks) - 1: + bb.append_instruction("jmp", runtime_venom.basic_blocks[i + 1].label) - return global_function + return deploy_venom, runtime_venom def _convert_binary_op( @@ -279,20 +295,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif ir.value in ["pass", "stop", "return"]: pass elif ir.value == "deploy": - memsize = ir.args[0].value - ir_runtime = ir.args[1] - padding = ir.args[2].value - assert isinstance(memsize, int), "non-int memsize" - assert isinstance(padding, int), "non-int padding" - - runtimeLabel = ctx.get_next_label() - - ctx.get_basic_block().append_instruction("deploy", memsize, runtimeLabel, padding) - - bb = IRBasicBlock(runtimeLabel, ctx) - ctx.append_basic_block(bb) - - _convert_ir_basicblock(ctx, ir_runtime, symbols, variables, allocated_variables) + ctx.ctor_mem_size = ir.args[0].value + ctx.immutables_len = ir.args[2].value + return None elif ir.value == "seq": func_t = ir.passthrough_metadata.get("func_t", None) if ir.is_self_call: diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 43e8d47235..26699099b2 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -14,13 +14,11 @@ class NormalizationPass(IRPass): changes = 0 def _split_basic_block(self, bb: IRBasicBlock) -> None: - # Iterate over the predecessors of the basic block + # Iterate over the predecessors to this basic block for in_bb in list(bb.cfg_in): - jump_inst = in_bb.instructions[-1] assert bb in in_bb.cfg_out - - # Handle branching - if jump_inst.opcode in ("jnz", "djmp"): + # Handle branching in the predecessor bb + if len(in_bb.cfg_out) > 1: self._insert_split_basicblock(bb, in_bb) self.changes += 1 break diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 0c32c3b816..926f8df8a3 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -65,6 +65,8 @@ ] ) +_REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] + # TODO: "assembly" gets into the recursion due to how the original # IR was structured recursively in regards with the deploy instruction. @@ -75,13 +77,13 @@ # with the assembler. My suggestion is to let this be for now, and we can # refactor it later when we are finished phasing out the old IR. class VenomCompiler: - ctx: IRFunction + ctxs: list[IRFunction] label_counter = 0 visited_instructions: OrderedSet # {IRInstruction} visited_basicblocks: OrderedSet # {IRBasicBlock} - def __init__(self, ctx: IRFunction): - self.ctx = ctx + def __init__(self, ctxs: list[IRFunction]): + self.ctxs = ctxs self.label_counter = 0 self.visited_instructions = OrderedSet() self.visited_basicblocks = OrderedSet() @@ -91,8 +93,8 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: self.visited_basicblocks = OrderedSet() self.label_counter = 0 - stack = StackModel() - asm: list[str] = [] + asm: list[Any] = [] + top_asm = asm # Before emitting the assembly, we need to make sure that the # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) @@ -101,41 +103,49 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: # assembly generation. # This is a side-effect of how dynamic jumps are temporarily being used # to support the O(1) dispatcher. -> look into calculate_cfg() - calculate_cfg(self.ctx) - NormalizationPass.run_pass(self.ctx) - calculate_liveness(self.ctx) - - assert self.ctx.normalized, "Non-normalized CFG!" - - self._generate_evm_for_basicblock_r(asm, self.ctx.basic_blocks[0], stack) - - # Append postambles - revert_postamble = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] - runtime = None - if isinstance(asm[-1], list) and isinstance(asm[-1][0], RuntimeHeader): - runtime = asm.pop() - - asm.extend(revert_postamble) - if runtime: - runtime.extend(revert_postamble) - asm.append(runtime) + for ctx in self.ctxs: + calculate_cfg(ctx) + NormalizationPass.run_pass(ctx) + calculate_liveness(ctx) + + assert ctx.normalized, "Non-normalized CFG!" + + self._generate_evm_for_basicblock_r(asm, ctx.basic_blocks[0], StackModel()) + + # TODO make this property on IRFunction + if ctx.immutables_len is not None and ctx.ctor_mem_size is not None: + while asm[-1] != "JUMPDEST": + asm.pop() + asm.extend( + ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] + ) + asm.extend(["_OFST", "_sym_subcode_size", ctx.immutables_len]) # stack: len + asm.extend(["_mem_deploy_start"]) # stack: len mem_ofst + asm.extend(["RETURN"]) + asm.extend(_REVERT_POSTAMBLE) + runtime_asm = [ + RuntimeHeader("_sym_runtime_begin", ctx.ctor_mem_size, ctx.immutables_len) + ] + asm.append(runtime_asm) + asm = runtime_asm + else: + asm.extend(_REVERT_POSTAMBLE) - # Append data segment - data_segments: dict[Any, list[Any]] = dict() - for inst in self.ctx.data_segment: - if inst.opcode == "dbname": - label = inst.operands[0].value - data_segments[label] = [DataHeader(f"_sym_{label}")] - elif inst.opcode == "db": - data_segments[label].append(f"_sym_{inst.operands[0].value}") + # Append data segment + data_segments: dict = dict() + for inst in ctx.data_segment: + if inst.opcode == "dbname": + label = inst.operands[0].value + data_segments[label] = [DataHeader(f"_sym_{label}")] + elif inst.opcode == "db": + data_segments[label].append(f"_sym_{inst.operands[0].value}") - extent_point = asm if not isinstance(asm[-1], list) else asm[-1] - extent_point.extend([data_segments[label] for label in data_segments]) # type: ignore + asm.extend(list(data_segments.values())) if no_optimize is False: - optimize_assembly(asm) + optimize_assembly(top_asm) - return asm + return top_asm def _stack_reorder( self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] @@ -397,20 +407,6 @@ def _generate_evm_for_instruction( assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) elif opcode == "assert": assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) - elif opcode == "deploy": - memsize = inst.operands[0].value - padding = inst.operands[2].value - # TODO: fix this by removing deploy opcode altogether me move emition to ir translation - while assembly[-1] != "JUMPDEST": - assembly.pop() - assembly.extend( - ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] - ) - assembly.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len - assembly.extend(["_mem_deploy_start"]) # stack: len mem_ofst - assembly.extend(["RETURN"]) - assembly.append([RuntimeHeader("_sym_runtime_begin", memsize, padding)]) # type: ignore - assembly = assembly[-1] elif opcode == "iload": loc = inst.operands[0].value assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"])