diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b6399b3ae9..fd78e2fff8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -78,11 +78,18 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [["3.10", "310"], ["3.11", "311"]] + python-version: [["3.11", "311"]] # run in modes: --optimize [gas, none, codesize] - flag: ["core", "no-opt", "codesize"] + opt-mode: ["gas", "none", "codesize"] + debug: [true, false] + # run across other python versions.# we don't really need to run all + # modes across all python versions - one is enough + include: + - python-version: ["3.10", "310"] + opt-mode: gas + debug: false - name: py${{ matrix.python-version[1] }}-${{ matrix.flag }} + name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }} steps: - uses: actions/checkout@v1 @@ -97,7 +104,7 @@ jobs: run: pip install tox - name: Run Tox - run: TOXENV=py${{ matrix.python-version[1] }}-${{ matrix.flag }} tox -r -- --reruns 10 --reruns-delay 1 -r aR tests/ + run: TOXENV=py${{ matrix.python-version[1] }} tox -r -- --optimize ${{ matrix.opt-mode }} ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} --reruns 10 --reruns-delay 1 -r aR tests/ - name: Upload Coverage uses: codecov/codecov-action@v1 diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 208771a5a9..6d1cdf98d7 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -113,6 +113,23 @@ Remix IDE While the Vyper version of the Remix IDE compiler is updated on a regular basis, it might be a bit behind the latest version found in the master branch of the repository. Make sure the byte code matches the output from your local compiler. +.. _optimization-mode: + +Compiler Optimization Modes +=========================== + +The vyper CLI tool accepts an optimization mode ``"none"``, ``"codesize"``, or ``"gas"`` (default). It can be set using the ``--optimize`` flag. For example, invoking ``vyper --optimize codesize MyContract.vy`` will compile the contract, optimizing for code size. As a rough summary of the differences between gas and codesize mode, in gas optimized mode, the compiler will try to generate bytecode which minimizes gas (up to a point), including: + +* using a sparse selector table which optimizes for gas over codesize +* inlining some constants, and +* trying to unroll some loops, especially for data copies. + +In codesize optimized mode, the compiler will try hard to minimize codesize by + +* using a dense selector table +* out-lining code, and +* using more loops for data copies. + .. _evm-version: diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index c7abb3e645..f58ab3b067 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -37,13 +37,13 @@ In the above examples, the contract will only compile with Vyper versions ``0.3. Optimization Mode ----------------- -The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For instance, the following contract will be compiled in a way which tries to minimize codesize: +The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For example, adding the following line to a contract will cause it to try to optimize for codesize: .. code-block:: python #pragma optimize codesize -The optimization mode can also be set as a compiler option. 
If the compiler option conflicts with the source code pragma, an exception will be raised and compilation will not continue. +The optimization mode can also be set as a compiler option, which is documented in :ref:`optimization-mode`. If the compiler option conflicts with the source code pragma, an exception will be raised and compilation will not continue. EVM Version ----------------- diff --git a/setup.py b/setup.py index 36a138aacd..bbf6e60f55 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import re import subprocess -from setuptools import find_packages, setup +from setuptools import setup extras_require = { "test": [ @@ -88,7 +88,7 @@ def _global_version(version): license="Apache License 2.0", keywords="ethereum evm smart contract language", include_package_data=True, - packages=find_packages(exclude=("tests", "docs")), + packages=["vyper"], python_requires=">=3.10,<4", py_modules=["vyper"], install_requires=[ diff --git a/tests/base_conftest.py b/tests/base_conftest.py index a78562e982..81e8dedc36 100644 --- a/tests/base_conftest.py +++ b/tests/base_conftest.py @@ -112,10 +112,10 @@ def w3(tester): return w3 -def _get_contract(w3, source_code, optimize, *args, **kwargs): +def _get_contract(w3, source_code, optimize, *args, override_opt_level=None, **kwargs): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) - settings.optimize = optimize + settings.optimize = override_opt_level or optimize out = compiler.compile_code( source_code, # test that metadata gets generated diff --git a/tests/cli/vyper_json/test_parse_args_vyperjson.py b/tests/cli/vyper_json/test_parse_args_vyperjson.py index 08da5f1888..11e527843a 100644 --- a/tests/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/cli/vyper_json/test_parse_args_vyperjson.py @@ -57,7 +57,7 @@ def test_to_stdout(tmp_path, capfd): _parse_args([path.absolute().as_posix()]) out, _ = capfd.readouterr() output_json = json.loads(out) - assert _no_errors(output_json) + assert _no_errors(output_json), (INPUT_JSON, output_json) assert "contracts/foo.vy" in output_json["sources"] assert "contracts/bar.vy" in output_json["sources"] @@ -71,7 +71,7 @@ def test_to_file(tmp_path): assert output_path.exists() with output_path.open() as fp: output_json = json.load(fp) - assert _no_errors(output_json) + assert _no_errors(output_json), (INPUT_JSON, output_json) assert "contracts/foo.vy" in output_json["sources"] assert "contracts/bar.vy" in output_json["sources"] diff --git a/tests/compiler/__init__.py b/tests/compiler/__init__.py index e69de29bb2..35a11f851b 100644 --- a/tests/compiler/__init__.py +++ b/tests/compiler/__init__.py @@ -0,0 +1,2 @@ +# prevent module name collision between tests/compiler/test_pre_parser.py +# and tests/ast/test_pre_parser.py diff --git a/tests/compiler/test_default_settings.py b/tests/compiler/test_default_settings.py new file mode 100644 index 0000000000..ca05170b61 --- /dev/null +++ b/tests/compiler/test_default_settings.py @@ -0,0 +1,27 @@ +from vyper.codegen import core +from vyper.compiler.phases import CompilerData +from vyper.compiler.settings import OptimizationLevel, _is_debug_mode + + +def test_default_settings(): + source_code = "" + compiler_data = CompilerData(source_code) + _ = compiler_data.vyper_module # force settings to be computed + + assert compiler_data.settings.optimize == OptimizationLevel.GAS + + +def test_default_opt_level(): + assert OptimizationLevel.default() == OptimizationLevel.GAS + + +def test_codegen_opt_level(): + assert core._opt_level == 
OptimizationLevel.GAS + assert core._opt_gas() is True + assert core._opt_none() is False + assert core._opt_codesize() is False + + +def test_debug_mode(pytestconfig): + debug_mode = pytestconfig.getoption("enable_compiler_debug_mode") + assert _is_debug_mode() == debug_mode diff --git a/tests/conftest.py b/tests/conftest.py index 9c9c4191b9..d519ca3100 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from vyper import compiler from vyper.codegen.ir_node import IRnode -from vyper.compiler.settings import OptimizationLevel +from vyper.compiler.settings import OptimizationLevel, _set_debug_mode from vyper.ir import compile_ir, optimizer from .base_conftest import VyperContract, _get_contract, zero_gas_price_strategy @@ -43,6 +43,7 @@ def pytest_addoption(parser): default="gas", help="change optimization mode", ) + parser.addoption("--enable-compiler-debug-mode", action="store_true") @pytest.fixture(scope="module") @@ -51,6 +52,13 @@ def optimize(pytestconfig): return OptimizationLevel.from_string(flag) +@pytest.fixture(scope="session", autouse=True) +def debug(pytestconfig): + debug = pytestconfig.getoption("enable_compiler_debug_mode") + assert isinstance(debug, bool) + _set_debug_mode(debug) + + @pytest.fixture def keccak(): return Web3.keccak diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/functional/semantics/analysis/test_for_loop.py index 8707b4c326..0d61a8f8f8 100644 --- a/tests/functional/semantics/analysis/test_for_loop.py +++ b/tests/functional/semantics/analysis/test_for_loop.py @@ -1,7 +1,12 @@ import pytest from vyper.ast import parse_to_ast -from vyper.exceptions import ImmutableViolation, TypeMismatch +from vyper.exceptions import ( + ArgumentException, + ImmutableViolation, + StateAccessViolation, + TypeMismatch, +) from vyper.semantics.analysis import validate_semantics @@ -59,6 +64,34 @@ def bar(): validate_semantics(vyper_module, {}) +def test_bad_keywords(namespace): + code = """ + +@internal +def bar(n: uint256): + x: uint256 = 0 + for i in range(n, boundddd=10): + x += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ArgumentException): + validate_semantics(vyper_module, {}) + + +def test_bad_bound(namespace): + code = """ + +@internal +def bar(n: uint256): + x: uint256 = 0 + for i in range(n, bound=n): + x += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(StateAccessViolation): + validate_semantics(vyper_module, {}) + + def test_modify_iterator_function_call(namespace): code = """ diff --git a/tests/parser/exceptions/test_invalid_reference.py b/tests/parser/exceptions/test_invalid_reference.py index 3aec6028e4..fe315e5cbf 100644 --- a/tests/parser/exceptions/test_invalid_reference.py +++ b/tests/parser/exceptions/test_invalid_reference.py @@ -37,6 +37,24 @@ def foo(): def foo(): int128 = 5 """, + """ +a: public(constant(uint256)) = 1 + +@external +def foo(): + b: uint256 = self.a + """, + """ +a: public(immutable(uint256)) + +@external +def __init__(): + a = 123 + +@external +def foo(): + b: uint256 = self.a + """, ] diff --git a/tests/parser/features/iteration/test_for_range.py b/tests/parser/features/iteration/test_for_range.py index 30f4bb87e3..395dd28231 100644 --- a/tests/parser/features/iteration/test_for_range.py +++ b/tests/parser/features/iteration/test_for_range.py @@ -14,6 +14,23 @@ def repeat(z: int128) -> int128: assert c.repeat(9) == 54 +def test_range_bound(get_contract, assert_tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for 
i in range(n, bound=6): + x += i + return x + """ + c = get_contract(code) + for n in range(7): + assert c.repeat(n) == sum(range(n)) + + # check codegen inserts assertion for n greater than bound + assert_tx_failed(lambda: c.repeat(7)) + + def test_digit_reverser(get_contract_with_gas_estimation): digit_reverser = """ @external diff --git a/tests/parser/functions/test_ecrecover.py b/tests/parser/functions/test_ecrecover.py index 77e9655b3e..40c9a6a936 100644 --- a/tests/parser/functions/test_ecrecover.py +++ b/tests/parser/functions/test_ecrecover.py @@ -40,3 +40,21 @@ def test_ecrecover_uints2() -> address: assert c.test_ecrecover_uints2() == local_account.address print("Passed ecrecover test") + + +def test_invalid_signature(get_contract): + code = """ +dummies: HashMap[address, HashMap[address, uint256]] + +@external +def test_ecrecover(hash: bytes32, v: uint8, r: uint256) -> address: + # read from hashmap to put garbage in 0 memory location + s: uint256 = self.dummies[msg.sender][msg.sender] + return ecrecover(hash, v, r, s) + """ + c = get_contract(code) + hash_ = bytes(i for i in range(32)) + v = 0 # invalid v! ecrecover precompile will not write to output buffer + r = 0 + # note web3.py decoding of 0x000..00 address is None. + assert c.test_ecrecover(hash_, v, r) is None diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 3064ee308e..6229b47921 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -2,6 +2,7 @@ import pytest from hypothesis import given, settings +from vyper.compiler.settings import OptimizationLevel from vyper.exceptions import ArgumentException, TypeMismatch _fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)] @@ -33,12 +34,15 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: @pytest.mark.parametrize("literal_start", (True, False)) @pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) -@settings(max_examples=25, deadline=None) +@settings(max_examples=100, deadline=None) +@pytest.mark.fuzzing def test_slice_immutable( get_contract, assert_compile_failed, assert_tx_failed, + opt_level, bytesdata, start, literal_start, @@ -64,7 +68,7 @@ def do_splice() -> Bytes[{length_bound}]: """ def _get_contract(): - return get_contract(code, bytesdata, start, length) + return get_contract(code, bytesdata, start, length, override_opt_level=opt_level) if ( (start + length > length_bound and literal_start and literal_length) @@ -84,12 +88,15 @@ def _get_contract(): @pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) @pytest.mark.parametrize("literal_start", (True, False)) @pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) -@settings(max_examples=25, deadline=None) +@settings(max_examples=100, deadline=None) +@pytest.mark.fuzzing def test_slice_bytes( get_contract, assert_compile_failed, assert_tx_failed, + opt_level, location, bytesdata, start, @@ -133,7 +140,7 @@ def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Byt """ def _get_contract(): - return get_contract(code, bytesdata) + return get_contract(code, bytesdata, override_opt_level=opt_level) data_length = len(bytesdata) if 
location == "literal" else length_bound if ( diff --git a/tests/parser/globals/test_getters.py b/tests/parser/globals/test_getters.py index 59c91cbeef..5eac074ef6 100644 --- a/tests/parser/globals/test_getters.py +++ b/tests/parser/globals/test_getters.py @@ -35,6 +35,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants): c: public(constant(uint256)) = 1 d: public(immutable(uint256)) e: public(immutable(uint256[2])) +f: public(constant(uint256[2])) = [3, 7] @external def __init__(): @@ -68,6 +69,7 @@ def __init__(): assert c.c() == 1 assert c.d() == 1729 assert c.e(0) == 2 + assert [c.f(i) for i in range(2)] == [3, 7] def test_getter_mutability(get_contract): diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py new file mode 100644 index 0000000000..01a83698b7 --- /dev/null +++ b/tests/parser/test_selector_table.py @@ -0,0 +1,198 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +import vyper.utils as utils +from vyper.codegen.jumptable_utils import ( + generate_dense_jumptable_info, + generate_sparse_jumptable_buckets, +) +from vyper.compiler.settings import OptimizationLevel + + +@given( + n_methods=st.integers(min_value=1, max_value=100), + seed=st.integers(min_value=0, max_value=2**64 - 1), +) +@pytest.mark.fuzzing +@settings(max_examples=10, deadline=None) +def test_sparse_jumptable_probe_depth(n_methods, seed): + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + _, buckets = generate_sparse_jumptable_buckets(sigs) + bucket_sizes = [len(bucket) for bucket in buckets.values()] + + # generally bucket sizes should be bounded at around 4, but + # just test that they don't get really out of hand + assert max(bucket_sizes) <= 8 + + # generally mean bucket size should be around 1.6, here just + # test they don't get really out of hand + assert sum(bucket_sizes) / len(bucket_sizes) <= 4 + + +@given( + n_methods=st.integers(min_value=4, max_value=100), + seed=st.integers(min_value=0, max_value=2**64 - 1), +) +@pytest.mark.fuzzing +@settings(max_examples=10, deadline=None) +def test_dense_jumptable_bucket_size(n_methods, seed): + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + n = len(sigs) + buckets = generate_dense_jumptable_info(sigs) + n_buckets = len(buckets) + + # generally should be around 14 buckets per 100 methods, here + # we test they don't get really out of hand + assert n_buckets / n < 0.4 or n < 10 + + +@pytest.mark.parametrize("opt_level", list(OptimizationLevel)) +# dense selector table packing boundaries at 256 and 65336 +@pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336]) +@settings(max_examples=5, deadline=None) +@given( + seed=st.integers(min_value=0, max_value=2**64 - 1), + max_default_args=st.integers(min_value=0, max_value=4), + default_fn_mutability=st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]), +) +@pytest.mark.fuzzing +def test_selector_table_fuzz( + max_calldata_bytes, + seed, + max_default_args, + opt_level, + default_fn_mutability, + w3, + get_contract, + assert_tx_failed, + get_logs, +): + def abi_sig(calldata_words, i, n_default_args): + args = [] if not calldata_words else [f"uint256[{calldata_words}]"] + args.extend(["uint256"] * n_default_args) + argstr = ",".join(args) + return f"foo{seed + i}({argstr})" + + def generate_func_def(mutability, calldata_words, i, n_default_args): + arglist = [] if not calldata_words else [f"x: uint256[{calldata_words}]"] + for j in range(n_default_args): + arglist.append(f"x{j}: 
uint256 = 0") + args = ", ".join(arglist) + _log_return = f"log _Return({i})" if mutability == "@payable" else "" + + return f""" +@external +{mutability} +def foo{seed + i}({args}) -> uint256: + {_log_return} + return {i} + """ + + @given( + methods=st.lists( + st.tuples( + st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), + st.integers(min_value=0, max_value=max_calldata_bytes // 32), + # n bytes to strip from calldata + st.integers(min_value=1, max_value=4), + # n default args + st.integers(min_value=0, max_value=max_default_args), + ), + min_size=1, + max_size=100, + ) + ) + @settings(max_examples=25) + def _test(methods): + func_defs = "\n".join( + generate_func_def(m, s, i, d) for i, (m, s, _, d) in enumerate(methods) + ) + + if default_fn_mutability == "": + default_fn_code = "" + elif default_fn_mutability in ("@nonpayable", "@payable"): + default_fn_code = f""" +@external +{default_fn_mutability} +def __default__(): + log CalledDefault() + """ + else: + # can't log from pure/view functions, just test that it returns + default_fn_code = """ +@external +def __default__(): + pass + """ + + code = f""" +event CalledDefault: + pass + +event _Return: + val: uint256 + +{func_defs} + +{default_fn_code} + """ + + c = get_contract(code, override_opt_level=opt_level) + + for i, (mutability, n_calldata_words, n_strip_bytes, n_default_args) in enumerate(methods): + funcname = f"foo{seed + i}" + func = getattr(c, funcname) + + for j in range(n_default_args + 1): + args = [[1] * n_calldata_words] if n_calldata_words else [] + args.extend([1] * j) + + # check the function returns as expected + assert func(*args) == i + + method_id = utils.method_id(abi_sig(n_calldata_words, i, j)) + + argsdata = b"\x00" * (n_calldata_words * 32 + j * 32) + + # do payable check + if mutability == "@payable": + tx = func(*args, transact={"value": 1}) + (event,) = get_logs(tx, c, "_Return") + assert event.args.val == i + else: + hexstr = (method_id + argsdata).hex() + txdata = {"to": c.address, "data": hexstr, "value": 1} + assert_tx_failed(lambda: w3.eth.send_transaction(txdata)) + + # now do calldatasize check + # strip some bytes + calldata = (method_id + argsdata)[:-n_strip_bytes] + hexstr = calldata.hex() + tx_params = {"to": c.address, "data": hexstr} + if n_calldata_words == 0 and j == 0: + # no args, hit default function + if default_fn_mutability == "": + assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + elif default_fn_mutability == "@payable": + # we should be able to send eth to it + tx_params["value"] = 1 + tx = w3.eth.send_transaction(tx_params) + logs = get_logs(tx, c, "CalledDefault") + assert len(logs) == 1 + else: + tx = w3.eth.send_transaction(tx_params) + + # note: can't emit logs from view/pure functions, + # so the logging is not tested. 
+ if default_fn_mutability == "@nonpayable": + logs = get_logs(tx, c, "CalledDefault") + assert len(logs) == 1 + + # check default function reverts + tx_params["value"] = 1 + assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + else: + assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + + _test() diff --git a/tox.ini b/tox.ini index 9b63630f58..c949354dfe 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - py{310,311}-{core,no-opt} + py{310,311} lint mypy docs @@ -8,9 +8,7 @@ envlist = [testenv] usedevelop = True commands = - core: pytest -m "not fuzzing" --showlocals {posargs:tests/} - no-opt: pytest -m "not fuzzing" --showlocals --optimize none {posargs:tests/} - codesize: pytest -m "not fuzzing" --showlocals --optimize codesize {posargs:tests/} + pytest -m "not fuzzing" --showlocals {posargs:tests/} basepython = py310: python3.10 py311: python3.11 diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 753f2687cd..5471b971a4 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -49,7 +49,6 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: # the base return statement is an `Attribute` node, e.g. `self.` # for each input type we wrap it in a `Subscript` to access a specific member return_stmt = vy_ast.Attribute(value=vy_ast.Name(id="self"), attr=func_type.name) - return_stmt._metadata["type"] = node._metadata["type"] for i, type_ in enumerate(input_types): if not isinstance(annotation, vy_ast.Subscript): diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 77806d734c..ca9979b2a3 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -72,8 +72,8 @@ function_def: [decorators] function_sig ":" body _EVENT_DECL: "event" event_member: NAME ":" type indexed_event_arg: NAME ":" "indexed" "(" type ")" -event_body: _NEWLINE _INDENT ((event_member | indexed_event_arg) _NEWLINE)+ _DEDENT // Events which use no args use a pass statement instead +event_body: _NEWLINE _INDENT (((event_member | indexed_event_arg ) _NEWLINE)+ | _PASS _NEWLINE) _DEDENT event_def: _EVENT_DECL NAME ":" ( event_body | _PASS ) // Enums diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 7c907b4d08..2497928035 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -339,7 +339,7 @@ def __hash__(self): def __eq__(self, other): if not isinstance(other, type(self)): return False - if other.node_id != self.node_id: + if getattr(other, "node_id", None) != getattr(self, "node_id", None): return False for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__): if getattr(self, field_name, None) != getattr(other, field_name, None): diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 35153af9d5..7e677b3b92 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -119,12 +119,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: validate_version_pragma(compiler_version, start) settings.compiler_version = compiler_version - if string.startswith("#pragma "): - pragma = string.removeprefix("#pragma").strip() + if contents.startswith("pragma "): + pragma = contents.removeprefix("pragma ").strip() if pragma.startswith("version "): if settings.compiler_version is not None: raise StructureException("pragma version specified twice!", start) - compiler_version = pragma.removeprefix("version ".strip()) + compiler_version = pragma.removeprefix("version ").strip() validate_version_pragma(compiler_version, start) settings.compiler_version = 
compiler_version diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 105baa47d6..783c45ca1f 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -770,29 +770,19 @@ def infer_arg_types(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - placeholder_node = IRnode.from_list( - context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY - ) + input_buf = context.new_internal_variable(get_type_for_exact_size(128)) + output_buf = MemoryPositions.FREE_VAR_SPACE return IRnode.from_list( [ "seq", - ["mstore", placeholder_node, args[0]], - ["mstore", ["add", placeholder_node, 32], args[1]], - ["mstore", ["add", placeholder_node, 64], args[2]], - ["mstore", ["add", placeholder_node, 96], args[3]], - [ - "pop", - [ - "staticcall", - ["gas"], - 1, - placeholder_node, - 128, - MemoryPositions.FREE_VAR_SPACE, - 32, - ], - ], - ["mload", MemoryPositions.FREE_VAR_SPACE], + # clear output memory first, ecrecover can return 0 bytes + ["mstore", output_buf, 0], + ["mstore", input_buf, args[0]], + ["mstore", input_buf + 32, args[1]], + ["mstore", input_buf + 64, args[2]], + ["mstore", input_buf + 96, args[3]], + ["staticcall", "gas", 1, input_buf, 128, output_buf, 32], + ["mload", output_buf], ], typ=AddressT(), ) @@ -1640,7 +1630,9 @@ def _create_ir(value, buf, length, salt=None, checked=True): if not checked: return ret - return clamp_nonzero(ret) + ret = clamp_nonzero(ret) + ret.set_error_msg(f"{create_op} failed") + return ret # calculate the gas used by create for a given number of bytes @@ -1836,7 +1828,10 @@ def _build_create_IR(self, expr, args, context, value, salt): ir = ["seq"] # make sure there is actually code at the target - ir.append(["assert", codesize]) + check_codesize = ["assert", codesize] + ir.append( + IRnode.from_list(check_codesize, error_msg="empty target (create_copy_of)") + ) # store the preamble at msize + 22 (zero padding) preamble, preamble_len = _create_preamble(codesize) @@ -1926,7 +1921,12 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar # (code_ofst == (extcodesize target) would be empty # initcode, which we disallow for hygiene reasons - # same as `create_copy_of` on an empty target). - ir.append(["assert", ["sgt", codesize, 0]]) + check_codesize = ["assert", ["sgt", codesize, 0]] + ir.append( + IRnode.from_list( + check_codesize, error_msg="empty target (create_from_blueprint)" + ) + ) # copy the target code into memory. 
# layout starting from mem_ofst: diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 55e0fc82b2..9c96d55040 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -11,7 +11,12 @@ import vyper.codegen.ir_node as ir_node from vyper.cli import vyper_json from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path -from vyper.compiler.settings import VYPER_TRACEBACK_LIMIT, OptimizationLevel, Settings +from vyper.compiler.settings import ( + VYPER_TRACEBACK_LIMIT, + OptimizationLevel, + Settings, + _set_debug_mode, +) from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS from vyper.typing import ContractCodes, ContractPath, OutputFormats @@ -105,7 +110,12 @@ def _parse_args(argv): dest="evm_version", ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") - parser.add_argument("--optimize", help="Optimization flag", choices=["gas", "codesize", "none"]) + parser.add_argument( + "--optimize", + help="Optimization flag (defaults to 'gas')", + choices=["gas", "codesize", "none"], + ) + parser.add_argument("--debug", help="Compile in debug mode", action="store_true") parser.add_argument( "--no-bytecode-metadata", help="Do not add metadata to bytecode", action="store_true" ) @@ -151,6 +161,9 @@ def _parse_args(argv): output_formats = tuple(uniq(args.format.split(","))) + if args.debug: + _set_debug_mode(True) + if args.no_optimize and args.optimize: raise ValueError("Cannot use `--no-optimize` and `--optimize` at the same time!") @@ -165,7 +178,7 @@ def _parse_args(argv): settings.evm_version = args.evm_version if args.verbose: - print(f"using `{settings}`", file=sys.stderr) + print(f"cli specified: `{settings}`", file=sys.stderr) compiled = compile_files( args.input_files, diff --git a/vyper/cli/vyper_ir.py b/vyper/cli/vyper_ir.py index 6831f39473..1f90badcaa 100755 --- a/vyper/cli/vyper_ir.py +++ b/vyper/cli/vyper_ir.py @@ -55,7 +55,7 @@ def compile_to_ir(input_file, output_formats, show_gas_estimates=False): compiler_data["asm"] = asm if "bytecode" in output_formats: - (bytecode, _srcmap) = compile_ir.assembly_to_evm(asm) + bytecode, _ = compile_ir.assembly_to_evm(asm) compiler_data["bytecode"] = "0x" + bytecode.hex() return compiler_data diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index f47f88ac85..e1d3ea12b4 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -539,6 +539,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check): # an array index, and the clamp will throw an error. 
# NOTE: there are optimization rules for this when ix or bound is literal ix = clamp("lt", ix, bound) + ix.set_error_msg(f"{parent.typ} bounds check") if parent.encoding == Encoding.ABI: if parent.location == STORAGE: @@ -1032,7 +1033,6 @@ def eval_seq(ir_node): return None -# TODO move return checks to vyper/semantics/validation def is_return_from_function(node): if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in ( "raw_revert", @@ -1044,6 +1044,8 @@ def is_return_from_function(node): return False +# TODO this is almost certainly duplicated with check_terminus_node +# in vyper/semantics/analysis/local.py def check_single_exit(fn_node): _check_return_body(fn_node, fn_node.body) for node in fn_node.get_descendants(vy_ast.If): diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 08bebbb4a5..94617bef35 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1 @@ -from .common import generate_ir_for_function # noqa +from .common import FuncIR, generate_ir_for_function # noqa diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index fd65b12265..3fd5ce0b29 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -4,7 +4,7 @@ import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.core import check_single_exit, getpos +from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.global_context import GlobalContext @@ -63,12 +63,32 @@ def internal_function_label(self, is_ctor_context: bool = False) -> str: return self.ir_identifier + suffix +class FuncIR: + pass + + +@dataclass +class EntryPointInfo: + func_t: ContractFunctionT + min_calldatasize: int # the min calldata required for this entry point + ir_node: IRnode # the ir for this entry point + + +@dataclass +class ExternalFuncIR(FuncIR): + entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points + common_ir: IRnode # the "common" code for the function + + +@dataclass +class InternalFuncIR(FuncIR): + func_ir: IRnode # the code for the function + + +# TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, - global_ctx: GlobalContext, - skip_nonpayable_check: bool, - is_ctor_context: bool = False, -) -> IRnode: + code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False +) -> FuncIR: """ Parse a function and produce IR code for the function, includes: - Signature method if statement @@ -82,6 +102,7 @@ def generate_ir_for_function( func_t._ir_info = _FuncIRInfo(func_t) # Validate return statements. + # XXX: This should really be in semantics pass. 
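 # (context: check_single_exit in vyper/codegen/core.py verifies each
 #  branch of the body terminates exactly once; per the TODO added there,
 #  it likely duplicates check_terminus_node in
 #  vyper/semantics/analysis/local.py.)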
check_single_exit(code) callees = func_t.called_functions @@ -106,19 +127,23 @@ def generate_ir_for_function( ) if func_t.is_internal: - assert skip_nonpayable_check is False - o = generate_ir_for_internal_function(code, func_t, context) + ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) + func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore else: - if func_t.is_payable: - assert skip_nonpayable_check is False # nonsense - o = generate_ir_for_external_function(code, func_t, context, skip_nonpayable_check) - - o.source_pos = getpos(code) + kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) + entry_points = { + k: EntryPointInfo(func_t, mincalldatasize, ir_node) + for k, (mincalldatasize, ir_node) in kwarg_handlers.items() + } + ret = ExternalFuncIR(entry_points, common) + # note: this ignores the cost of traversing selector table + func_t._ir_info.gas_estimate = ret.common_ir.gas frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY frame_info = FrameInfo(allocate_start, frame_size, context.vars) + # XXX: when can this happen? if func_t._ir_info.frame_info is None: func_t._ir_info.set_frame_info(frame_info) else: @@ -128,9 +153,7 @@ def generate_ir_for_function( # adjust gas estimate to include cost of mem expansion # frame_size of external function includes all private functions called # (note: internal functions do not need to adjust gas estimate since - # it is already accounted for by the caller.) - o.add_gas_estimate += calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - - func_t._ir_info.gas_estimate = o.gas + mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore + ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - return o + return ret diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 207356860b..32236e9aad 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -1,6 +1,3 @@ -from typing import Any, List - -import vyper.utils as util from vyper.codegen.abi_encoder import abi_encoding_matches_vyper from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp @@ -15,7 +12,7 @@ # register function args with the local calling context. # also allocate the ones that live in memory (i.e. kwargs) -def _register_function_args(func_t: ContractFunctionT, context: Context) -> List[IRnode]: +def _register_function_args(func_t: ContractFunctionT, context: Context) -> list[IRnode]: ret = [] # the type of the calldata base_args_t = TupleT(tuple(arg.typ for arg in func_t.positional_args)) @@ -52,13 +49,9 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> List return ret -def _annotated_method_id(abi_sig): - method_id = util.method_id_int(abi_sig) - annotation = f"{hex(method_id)}: {abi_sig}" - return IRnode(method_id, annotation=annotation) - - -def _generate_kwarg_handlers(func_t: ContractFunctionT, context: Context) -> List[Any]: +def _generate_kwarg_handlers( + func_t: ContractFunctionT, context: Context +) -> dict[str, tuple[int, IRnode]]: # generate kwarg handlers. 
# since they might come in thru calldata or be default, # allocate them in memory and then fill it in based on calldata or default, @@ -75,7 +68,6 @@ def handler_for(calldata_kwargs, default_kwargs): calldata_args_t = TupleT(list(arg.typ for arg in calldata_args)) abi_sig = func_t.abi_signature_for_kwargs(calldata_kwargs) - method_id = _annotated_method_id(abi_sig) calldata_kwargs_ofst = IRnode( 4, location=CALLDATA, typ=calldata_args_t, encoding=Encoding.ABI @@ -88,11 +80,6 @@ def handler_for(calldata_kwargs, default_kwargs): args_abi_t = calldata_args_t.abi_type calldata_min_size = args_abi_t.min_size() + 4 - # note we don't need the check if calldata_min_size == 4, - # because the global calldatasize check ensures that already. - if calldata_min_size > 4: - ret.append(["assert", ["ge", "calldatasize", calldata_min_size]]) - # TODO optimize make_setter by using # TupleT(list(arg.typ for arg in calldata_kwargs + default_kwargs)) # (must ensure memory area is contiguous) @@ -123,11 +110,10 @@ def handler_for(calldata_kwargs, default_kwargs): ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) - method_id_check = ["eq", "_calldata_method_id", method_id] - ret = ["if", method_id_check, ret] - return ret + # return something we can turn into ExternalFuncIR + return abi_sig, calldata_min_size, ret - ret = ["seq"] + ret = {} keyword_args = func_t.keyword_args @@ -139,9 +125,12 @@ def handler_for(calldata_kwargs, default_kwargs): calldata_kwargs = keyword_args[:i] default_kwargs = keyword_args[i:] - ret.append(handler_for(calldata_kwargs, default_kwargs)) + sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) + ret[sig] = calldata_min_size, ir_node + + sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret.append(handler_for(keyword_args, [])) + ret[sig] = calldata_min_size, ir_node return ret @@ -149,7 +138,7 @@ def handler_for(calldata_kwargs, default_kwargs): # TODO it would be nice if this returned a data structure which were # amenable to generating a jump table instead of the linear search for # method_id we have now. 
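 # (after this change the handlers are returned as a mapping of
 #  abi_sig -> (calldata_min_size, ir_node) instead of a linear chain of
 #  method_id checks, which is what lets vyper/codegen/module.py build the
 #  O(1) selector tables.)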
-def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_check): +def generate_ir_for_external_function(code, func_t, context): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, @@ -174,14 +163,6 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che # generate the main body of the function body += handle_base_args - if not func_t.is_payable and not skip_nonpayable_check: - # if the contract contains payable functions, but this is not one of them - # add an assertion that the value of the call is zero - nonpayable_check = IRnode.from_list( - ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" - ) - body.append(nonpayable_check) - body += nonreentrant_pre body += [parse_body(code.body, context, ensure_terminated=True)] @@ -201,22 +182,10 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che if context.return_type is not None: exit_sequence_args += ["ret_ofst", "ret_len"] # wrap the exit in a labeled block - exit = ["label", func_t._ir_info.exit_sequence_label, exit_sequence_args, exit_sequence] + exit_ = ["label", func_t._ir_info.exit_sequence_label, exit_sequence_args, exit_sequence] # the ir which comprises the main body of the function, # besides any kwarg handling - func_common_ir = ["seq", body, exit] - - if func_t.is_fallback or func_t.is_constructor: - ret = ["seq"] - # add a goto to make the function entry look like other functions - # (for zksync interpreter) - ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) - ret.append(func_common_ir) - else: - ret = kwarg_handlers - # sneak the base code into the kwarg handler - # TODO rethink this / make it clearer - ret[-1][-1].append(func_common_ir) + func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return IRnode.from_list(ret, source_pos=getpos(code)) + return kwarg_handlers, func_common_ir diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 0895e5f02d..6cb0a07281 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -148,6 +148,13 @@ def _check(condition, err): self.valency = 1 self._gas = 5 + elif isinstance(self.value, bytes): + # a literal bytes value, probably inside a "data" node. + _check(len(self.args) == 0, "bytes can't have arguments") + + self.valency = 0 + self._gas = 0 + elif isinstance(self.value, str): # Opcodes and pseudo-opcodes (e.g. clamp) if self.value.upper() in get_ir_opcodes(): @@ -264,8 +271,11 @@ def _check(condition, err): self.valency = 0 self._gas = sum([arg.gas for arg in self.args]) elif self.value == "label": - if not self.args[1].value == "var_list": - raise CodegenPanic(f"2nd argument to label must be var_list, {self}") + _check( + self.args[1].value == "var_list", + f"2nd argument to label must be var_list, {self}", + ) + _check(len(args) == 3, f"label should have 3 args but has {len(args)}, {self}") self.valency = 0 self._gas = 1 + sum(t.gas for t in self.args) elif self.value == "unique_symbol": @@ -330,6 +340,14 @@ def is_complex_ir(self): and self.value.lower() not in do_not_cache ) + # set an error message and push down into all children. + # useful for overriding an error message generated by a helper + # function with a more specific error message. 
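+    # (usage sketch, taken from codegen/core.py in this diff:
+    #     ix = clamp("lt", ix, bound)
+    #     ix.set_error_msg(f"{parent.typ} bounds check")
+    #  the specific message overrides the generic clamp error on
+    #  every child node.)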
+ def set_error_msg(self, error_msg: str) -> None: + self.error_msg = error_msg + for arg in self.args: + arg.set_error_msg(error_msg) + # get the unique symbols contained in this node, which provides # sanity check invariants for the optimizer. # cache because it's a perf hotspot. note that this (and other cached diff --git a/vyper/codegen/jumptable_utils.py b/vyper/codegen/jumptable_utils.py new file mode 100644 index 0000000000..6987ce90bd --- /dev/null +++ b/vyper/codegen/jumptable_utils.py @@ -0,0 +1,195 @@ +# helper module which implements jumptable for function selection +import math +from dataclasses import dataclass + +from vyper.utils import method_id_int + + +@dataclass +class Signature: + method_id: int + payable: bool + + +# bucket for dense function +@dataclass +class Bucket: + bucket_id: int + magic: int + method_ids: list[int] + + @property + def image(self): + return _image_of([s for s in self.method_ids], self.magic) + + @property + # return method ids, sorted by by their image + def method_ids_image_order(self): + return [x[1] for x in sorted(zip(self.image, self.method_ids))] + + @property + def bucket_size(self): + return len(self.method_ids) + + +BITS_MAGIC = 24 # a constant which produced good results, see _bench_dense() + + +def _image_of(xs, magic): + bits_shift = BITS_MAGIC + + # take the upper bits from the multiplication for more entropy + # can we do better using primes of some sort? + return [((x * magic) >> bits_shift) % len(xs) for x in xs] + + +class _Failure(Exception): + pass + + +def find_magic_for(xs): + for m in range(2**16): + test = _image_of(xs, m) + if len(test) == len(set(test)): + return m + + raise _Failure(f"Could not find hash for {xs}") + + +def _mk_buckets(method_ids, n_buckets): + buckets = {} + for x in method_ids: + t = x % n_buckets + buckets.setdefault(t, []) + buckets[t].append(x) + return buckets + + +# two layer method for generating perfect hash +# first get "reasonably good" distribution by using +# method_id % len(method_ids) +# second, get the magic for the bucket. +def _dense_jumptable_info(method_ids, n_buckets): + buckets = _mk_buckets(method_ids, n_buckets) + + ret = {} + for bucket_id, method_ids in buckets.items(): + magic = find_magic_for(method_ids) + ret[bucket_id] = Bucket(bucket_id, magic, method_ids) + + return ret + + +START_BUCKET_SIZE = 5 + + +# this is expensive! for 80 methods, costs about 350ms and probably +# linear in # of methods. +# see _bench_perfect() +# note the buckets are NOT in order! +def generate_dense_jumptable_info(signatures): + method_ids = [method_id_int(sig) for sig in signatures] + n = len(signatures) + # start at bucket size of 5 and try to improve (generally + # speaking we want as few buckets as possible) + n_buckets = (n // START_BUCKET_SIZE) + 1 + ret = None + tried_exhaustive = False + while n_buckets > 0: + try: + # print(f"trying {n_buckets} (bucket size {n // n_buckets})") + ret = _dense_jumptable_info(method_ids, n_buckets) + except _Failure: + if ret is not None: + break + + # we have not tried exhaustive search. try really hard + # to find a valid jumptable at the cost of performance + if not tried_exhaustive: + # print("failed with guess! trying exhaustive search.") + n_buckets = n + tried_exhaustive = True + continue + else: + raise RuntimeError(f"Could not generate jumptable! {signatures}") + n_buckets -= 1 + + return ret + + +# note the buckets are NOT in order! 
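+# (sketch of the search below: for n method ids, every bucket count in
+#  [floor(0.85*n), ceil(1.15*n)] is tried with `method_id % n_buckets`,
+#  keeping the smallest count whose largest bucket is minimal.)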
+def generate_sparse_jumptable_buckets(signatures): + method_ids = [method_id_int(sig) for sig in signatures] + n = len(signatures) + + # search a range of buckets to try to minimize bucket size + # (doing the range search improves worst worst bucket size from 9 to 4, + # see _bench_sparse) + lo = max(1, math.floor(n * 0.85)) + hi = max(1, math.ceil(n * 1.15)) + stats = {} + for i in range(lo, hi + 1): + buckets = _mk_buckets(method_ids, i) + + stats[i] = buckets + + min_max_bucket_size = hi + 1 # smallest max_bucket_size + # find the smallest i which gives us the smallest max_bucket_size + for i, buckets in stats.items(): + max_bucket_size = max(len(bucket) for bucket in buckets.values()) + if max_bucket_size < min_max_bucket_size: + min_max_bucket_size = max_bucket_size + ret = i, buckets + + assert ret is not None + return ret + + +# benchmark for quality of buckets +def _bench_dense(N=1_000, n_methods=100): + import random + + stats = [] + for i in range(N): + seed = random.randint(0, 2**64 - 1) + # "large" contracts in prod hit about ~50 methods, test with + # double the limit + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + + xs = generate_dense_jumptable_info(sigs) + print(f"found. n buckets {len(xs)}") + stats.append(xs) + + def mean(xs): + return sum(xs) / len(xs) + + avg_n_buckets = mean([len(jt) for jt in stats]) + # usually around ~14 buckets per 100 sigs + # N=10, time=3.6s + print(f"average N buckets: {avg_n_buckets}") + + +def _bench_sparse(N=10_000, n_methods=80): + import random + + stats = [] + for _ in range(N): + seed = random.randint(0, 2**64 - 1) + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + _, buckets = generate_sparse_jumptable_buckets(sigs) + + bucket_sizes = [len(bucket) for bucket in buckets.values()] + worst_bucket_size = max(bucket_sizes) + mean_bucket_size = sum(bucket_sizes) / len(bucket_sizes) + stats.append((worst_bucket_size, mean_bucket_size)) + + # N=10_000, time=9s + # range 0.85*n - 1.15*n + # worst worst bucket size: 4 + # avg worst bucket size: 3.0018 + # worst mean bucket size: 2.0 + # avg mean bucket size: 1.579112583664968 + print("worst worst bucket size:", max(x[0] for x in stats)) + print("avg worst bucket size:", sum(x[0] for x in stats) / len(stats)) + print("worst mean bucket size:", max(x[1] for x in stats)) + print("avg mean bucket size:", sum(x[1] for x in stats) / len(stats)) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index b98e4d0f86..ebe7f92cf2 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -1,12 +1,15 @@ -# a contract.vy -- all functions and constructor +# a compilation unit -- all functions and constructor from typing import Any, List +from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic +from vyper.utils import method_id_int def _topsort_helper(functions, lookup): @@ -47,92 +50,349 @@ def _is_payable(func_ast): return func_ast._metadata["type"].is_payable -# codegen for all runtime functions + callvalue/calldata checks + method selector routines -def _runtime_ir(runtime_functions, global_ctx): - # categorize the runtime functions because we will organize the runtime - # code into the following sections: - # payable functions, nonpayable functions, fallback function, 
internal_functions - internal_functions = [f for f in runtime_functions if _is_internal(f)] +def _annotated_method_id(abi_sig): + method_id = method_id_int(abi_sig) + annotation = f"{hex(method_id)}: {abi_sig}" + return IRnode(method_id, annotation=annotation) - external_functions = [f for f in runtime_functions if not _is_internal(f)] - default_function = next((f for f in external_functions if _is_fallback(f)), None) - # functions that need to go exposed in the selector section - regular_functions = [f for f in external_functions if not _is_fallback(f)] - payables = [f for f in regular_functions if _is_payable(f)] - nonpayables = [f for f in regular_functions if not _is_payable(f)] +def label_for_entry_point(abi_sig, entry_point): + method_id = method_id_int(abi_sig) + return f"{entry_point.func_t._ir_info.ir_identifier}{method_id}" - # create a map of the IR functions since they might live in both - # runtime and deploy code (if init function calls them) - internal_functions_ir: list[IRnode] = [] - for func_ast in internal_functions: - func_ir = generate_ir_for_function(func_ast, global_ctx, False) - internal_functions_ir.append(func_ir) +# adapt whatever generate_ir_for_function gives us into an IR node +def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): + func_t = func_ast._metadata["type"] + assert func_t.is_fallback or func_t.is_constructor + + ret = ["seq"] + if not func_t.is_payable: + callvalue_check = ["assert", ["iszero", "callvalue"]] + ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) + + func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + assert len(func_ir.entry_points) == 1 + + # add a goto to make the function entry look like other functions + # (for zksync interpreter) + ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) + ret.append(func_ir.common_ir) + + return IRnode.from_list(ret) + + +def _ir_for_internal_function(func_ast, *args, **kwargs): + return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + + +def _generate_external_entry_points(external_functions, global_ctx): + entry_points = {} # map from ABI sigs to ir code + sig_of = {} # reverse map from method ids to abi sig + + for code in external_functions: + func_ir = generate_ir_for_function(code, global_ctx) + for abi_sig, entry_point in func_ir.entry_points.items(): + assert abi_sig not in entry_points + entry_points[abi_sig] = entry_point + sig_of[method_id_int(abi_sig)] = abi_sig + + # stick function common body into final entry point to save a jump + ir_node = IRnode.from_list(["seq", entry_point.ir_node, func_ir.common_ir]) + entry_point.ir_node = ir_node + + return entry_points, sig_of + + +# codegen for all runtime functions + callvalue/calldata checks, +# with O(1) jumptable for selector table. +# uses two level strategy: uses `method_id % n_buckets` to descend +# into a bucket (of about 8-10 items), and then uses perfect hash +# to select the final function. +# costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) +def _selector_section_dense(external_functions, global_ctx): + function_irs = [] - # for some reason, somebody may want to deploy a contract with no - # external functions, or more likely, a "pure data" contract which - # contains immutables if len(external_functions) == 0: - # TODO: prune internal functions in this case? dead code eliminator - # might not eliminate them, since internal function jumpdest is at the - # first instruction in the contract. 
- runtime = ["seq"] + internal_functions_ir - return runtime + return IRnode.from_list(["seq"]) - # note: if the user does not provide one, the default fallback function - # reverts anyway. so it does not hurt to batch the payable check. - default_is_nonpayable = default_function is None or not _is_payable(default_function) + entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) - # when a contract has a nonpayable default function, - # we can do a single check for all nonpayable functions - batch_payable_check = len(nonpayables) > 0 and default_is_nonpayable - skip_nonpayable_check = batch_payable_check + # generate the label so the jumptable works + for abi_sig, entry_point in entry_points.items(): + label = label_for_entry_point(abi_sig, entry_point) + ir_node = ["label", label, ["var_list"], entry_point.ir_node] + function_irs.append(IRnode.from_list(ir_node)) - selector_section = ["seq"] + jumptable_info = jumptable_utils.generate_dense_jumptable_info(entry_points.keys()) + n_buckets = len(jumptable_info) + + # bucket magic <2 bytes> | bucket location <2 bytes> | bucket size <1 byte> + # TODO: can make it smaller if the largest bucket magic <= 255 + SZ_BUCKET_HEADER = 5 - for func_ast in payables: - func_ir = generate_ir_for_function(func_ast, global_ctx, False) - selector_section.append(func_ir) + selector_section = ["seq"] - if batch_payable_check: - nonpayable_check = IRnode.from_list( - ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" + bucket_id = ["mod", "_calldata_method_id", n_buckets] + bucket_hdr_location = [ + "add", + ["symbol", "BUCKET_HEADERS"], + ["mul", bucket_id, SZ_BUCKET_HEADER], + ] + # get bucket header + dst = 32 - SZ_BUCKET_HEADER + assert dst >= 0 + + if _is_debug_mode(): + selector_section.append(["assert", ["eq", "msize", 0]]) + + selector_section.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER]) + + # figure out the minimum number of bytes we can use to encode + # min_calldatasize in function info + largest_mincalldatasize = max(f.min_calldatasize for f in entry_points.values()) + FN_METADATA_BYTES = (largest_mincalldatasize.bit_length() + 7) // 8 + + func_info_size = 4 + 2 + FN_METADATA_BYTES + # grab function info. + # method id <4 bytes> | label <2 bytes> | func info <1-3 bytes> + # func info (1-3 bytes, packed) for: expected calldatasize, is_nonpayable bit + # NOTE: might be able to improve codesize if we use variable # of bytes + # per bucket + + hdr_info = IRnode.from_list(["mload", 0]) + with hdr_info.cache_when_complex("hdr_info") as (b1, hdr_info): + bucket_location = ["and", 0xFFFF, shr(8, hdr_info)] + bucket_magic = shr(24, hdr_info) + bucket_size = ["and", 0xFF, hdr_info] + # ((method_id * bucket_magic) >> BITS_MAGIC) % bucket_size + func_id = [ + "mod", + shr(jumptable_utils.BITS_MAGIC, ["mul", bucket_magic, "_calldata_method_id"]), + bucket_size, + ] + func_info_location = ["add", bucket_location, ["mul", func_id, func_info_size]] + dst = 32 - func_info_size + assert func_info_size >= SZ_BUCKET_HEADER # otherwise mload will have dirty bytes + assert dst >= 0 + selector_section.append(b1.resolve(["codecopy", dst, func_info_location, func_info_size])) + + func_info = IRnode.from_list(["mload", 0]) + fn_metadata_mask = 2 ** (FN_METADATA_BYTES * 8) - 1 + calldatasize_mask = fn_metadata_mask - 1 # ex. 
0xFFFE + with func_info.cache_when_complex("func_info") as (b1, func_info): + x = ["seq"] + + # expected calldatasize always satisfies (x - 4) % 32 == 0 + # the lower 5 bits are always 0b00100, so we can use those + # bits for other purposes. + is_nonpayable = ["and", 1, func_info] + expected_calldatasize = ["and", calldatasize_mask, func_info] + + label_bits_ofst = FN_METADATA_BYTES * 8 + function_label = ["and", 0xFFFF, shr(label_bits_ofst, func_info)] + method_id_bits_ofst = (FN_METADATA_BYTES + 2) * 8 + function_method_id = shr(method_id_bits_ofst, func_info) + + # check method id is right, if not then fallback. + # need to check calldatasize >= 4 in case there are + # trailing 0s in the method id. + calldatasize_valid = ["gt", "calldatasize", 3] + method_id_correct = ["eq", function_method_id, "_calldata_method_id"] + should_fallback = ["iszero", ["and", calldatasize_valid, method_id_correct]] + x.append(["if", should_fallback, ["goto", "fallback"]]) + + # assert callvalue == 0 if nonpayable + bad_callvalue = ["mul", is_nonpayable, "callvalue"] + # assert calldatasize at least minimum for the abi type + bad_calldatasize = ["lt", "calldatasize", expected_calldatasize] + failed_entry_conditions = ["or", bad_callvalue, bad_calldatasize] + check_entry_conditions = IRnode.from_list( + ["assert", ["iszero", failed_entry_conditions]], + error_msg="bad calldatasize or callvalue", ) - selector_section.append(nonpayable_check) + x.append(check_entry_conditions) + x.append(["jump", function_label]) + selector_section.append(b1.resolve(x)) + + bucket_headers = ["data", "BUCKET_HEADERS"] + + for bucket_id, bucket in sorted(jumptable_info.items()): + bucket_headers.append(bucket.magic.to_bytes(2, "big")) + bucket_headers.append(["symbol", f"bucket_{bucket_id}"]) + # note: buckets are usually ~10 items. to_bytes would + # fail if the int is too big. + bucket_headers.append(bucket.bucket_size.to_bytes(1, "big")) + + selector_section.append(bucket_headers) + + for bucket_id, bucket in jumptable_info.items(): + function_infos = ["data", f"bucket_{bucket_id}"] + # sort function infos by their image. 
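+        # (image order matters: at runtime func_id is computed as
+        #  ((method_id * magic) >> BITS_MAGIC) % bucket_size, i.e. the
+        #  image, and used directly to index into this info array.)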
+
+    bucket_headers = ["data", "BUCKET_HEADERS"]
+
+    for bucket_id, bucket in sorted(jumptable_info.items()):
+        bucket_headers.append(bucket.magic.to_bytes(2, "big"))
+        bucket_headers.append(["symbol", f"bucket_{bucket_id}"])
+        # note: buckets are usually ~10 items. to_bytes would
+        # fail if the int is too big.
+        bucket_headers.append(bucket.bucket_size.to_bytes(1, "big"))
+
+    selector_section.append(bucket_headers)
+
+    for bucket_id, bucket in jumptable_info.items():
+        function_infos = ["data", f"bucket_{bucket_id}"]
+        # sort function infos by their image.
+        for method_id in bucket.method_ids_image_order:
+            abi_sig = sig_of[method_id]
+            entry_point = entry_points[abi_sig]
+
+            method_id_bytes = method_id.to_bytes(4, "big")
+            symbol = ["symbol", label_for_entry_point(abi_sig, entry_point)]
+            func_metadata_int = entry_point.min_calldatasize | int(
+                not entry_point.func_t.is_payable
+            )
+            func_metadata = func_metadata_int.to_bytes(FN_METADATA_BYTES, "big")

-    for func_ast in nonpayables:
-        func_ir = generate_ir_for_function(func_ast, global_ctx, skip_nonpayable_check)
-        selector_section.append(func_ir)
+            function_infos.extend([method_id_bytes, symbol, func_metadata])

-    if default_function:
-        fallback_ir = generate_ir_for_function(
-            default_function, global_ctx, skip_nonpayable_check=False
-        )
-    else:
-        fallback_ir = IRnode.from_list(
-            ["revert", 0, 0], annotation="Default function", error_msg="fallback function"
-        )
+        selector_section.append(function_infos)

-    # ensure the external jumptable section gets closed out
-    # (for basic block hygiene and also for zksync interpreter)
-    # NOTE: this jump gets optimized out in assembly since the
-    # fallback label is the immediate next instruction,
-    close_selector_section = ["goto", "fallback"]
+    ret = ["seq", ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section]]

-    global_calldatasize_check = ["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]
+    ret.extend(function_irs)

-    runtime = [
-        "seq",
-        global_calldatasize_check,
-        ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section],
-        close_selector_section,
-        ["label", "fallback", ["var_list"], fallback_ir],
-    ]
+    return ret

-    runtime.extend(internal_functions_ir)
-    return runtime
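For intuition, the two-level lookup the dense table performs at runtime is equivalent to this pure-Python model (exposition only; the per-bucket `magic` constants come from jumptable_utils.generate_dense_jumptable_info, and `bits_magic` stands in for jumptable_utils.BITS_MAGIC):

    def dense_lookup(method_id, n_buckets, headers, bits_magic, func_info_size):
        # level 1: a cheap modulus picks the bucket header out of the data section
        magic, location, size = headers[method_id % n_buckets]
        # level 2: a multiplicative hash picks the slot within the bucket
        func_id = ((method_id * magic) >> bits_magic) % size
        return location + func_id * func_info_size  # codecopy source offset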


+# codegen for all runtime functions + callvalue/calldata checks,
+# with O(1) jumptable for selector table.
+# uses two level strategy: uses `method_id % n_methods` to calculate
+# a bucket, and then descends into linear search from there.
+# costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5)
+# function and 24 bytes of code (+ ~23 bytes of global overhead)
+def _selector_section_sparse(external_functions, global_ctx):
+    ret = ["seq"]
+
+    if len(external_functions) == 0:
+        return ret
+
+    entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx)
+
+    n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys())
+
+    # 2 bytes for bucket location
+    SZ_BUCKET_HEADER = 2
+
+    if n_buckets > 1:
+        bucket_id = ["mod", "_calldata_method_id", n_buckets]
+        bucket_hdr_location = [
+            "add",
+            ["symbol", "selector_buckets"],
+            ["mul", bucket_id, SZ_BUCKET_HEADER],
+        ]
+        # get bucket header
+        dst = 32 - SZ_BUCKET_HEADER
+        assert dst >= 0
+
+        if _is_debug_mode():
+            ret.append(["assert", ["eq", "msize", 0]])
+
+        ret.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER])
+
+        jumpdest = IRnode.from_list(["mload", 0])
+        # don't particularly like using `jump` here since it can cause
+        # issues for other backends; consider changing `goto` to allow
+        # dynamic jumps, or adding some kind of jumptable instruction
+        ret.append(["jump", jumpdest])
+
+        jumptable_data = ["data", "selector_buckets"]
+        for i in range(n_buckets):
+            if i in buckets:
+                bucket_label = f"selector_bucket_{i}"
+                jumptable_data.append(["symbol", bucket_label])
+            else:
+                # empty bucket
+                jumptable_data.append(["symbol", "fallback"])
+
+        ret.append(jumptable_data)
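The table itself is just n_buckets 2-byte code offsets, one per residue class of the method id. A toy version of the layout (the real bucketing, including the search for a good n_buckets, lives in jumptable_utils.generate_sparse_jumptable_buckets):

    from collections import defaultdict

    def toy_sparse_table(method_ids, n_buckets):
        buckets = defaultdict(list)
        for mid in method_ids:
            buckets[mid % n_buckets].append(mid)
        # row i holds the label jumped to for bucket i; empty rows alias
        # the fallback so unknown selectors fall through immediately
        table = [
            f"selector_bucket_{i}" if i in buckets else "fallback"
            for i in range(n_buckets)
        ]
        return buckets, table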
+ method_id_check = ["and", ["ge", "calldatasize", 4], method_id_check] + handle_bucket.append(["if", method_id_check, dispatch]) + + # close out the bucket with a goto fallback so we don't keep searching + handle_bucket.append(["goto", "fallback"]) + + ret.append(handle_bucket) + + ret = ["seq", ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), ret]] + + return ret + + +# codegen for all runtime functions + callvalue/calldata checks, +# O(n) linear search for the method id +# mainly keep this in for backends which cannot handle the indirect jump +# in selector_section_dense and selector_section_sparse +def _selector_section_linear(external_functions, global_ctx): + ret = ["seq"] + if len(external_functions) == 0: + return ret + + ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) + + entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + + dispatcher = ["seq"] + + for sig, entry_point in entry_points.items(): + func_t = entry_point.func_t + expected_calldatasize = entry_point.min_calldatasize + + dispatch = ["seq"] # code to dispatch into the function + + if not func_t.is_payable: + callvalue_check = ["assert", ["iszero", "callvalue"]] + dispatch.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) + + good_calldatasize = ["ge", "calldatasize", expected_calldatasize] + calldatasize_check = ["assert", good_calldatasize] + dispatch.append(IRnode.from_list(calldatasize_check, error_msg="calldatasize check")) + + dispatch.append(entry_point.ir_node) + + method_id_check = ["eq", "_calldata_method_id", _annotated_method_id(sig)] + dispatcher.append(["if", method_id_check, dispatch]) + + ret.append(["with", "_calldata_method_id", shr(224, ["calldataload", 0]), dispatcher]) + + return ret # take a GlobalContext, and generate the runtime and deploy IR @@ -143,15 +403,47 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) - runtime = _runtime_ir(runtime_functions, global_ctx) + internal_functions = [f for f in runtime_functions if _is_internal(f)] + + external_functions = [ + f for f in runtime_functions if not _is_internal(f) and not _is_fallback(f) + ] + default_function = next((f for f in runtime_functions if _is_fallback(f)), None) + + internal_functions_ir: list[IRnode] = [] + + # compile internal functions first so we have the function info + for func_ast in internal_functions: + func_ir = _ir_for_internal_function(func_ast, global_ctx, False) + internal_functions_ir.append(IRnode.from_list(func_ir)) + + if core._opt_none(): + selector_section = _selector_section_linear(external_functions, global_ctx) + # dense vs sparse global overhead is amortized after about 4 methods. + # (--debug will force dense selector table anyway if _opt_codesize is selected.) 


 # take a GlobalContext, and generate the runtime and deploy IR
@@ -143,15 +403,47 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
     runtime_functions = [f for f in function_defs if not _is_constructor(f)]
     init_function = next((f for f in function_defs if _is_constructor(f)), None)

-    runtime = _runtime_ir(runtime_functions, global_ctx)
+    internal_functions = [f for f in runtime_functions if _is_internal(f)]
+
+    external_functions = [
+        f for f in runtime_functions if not _is_internal(f) and not _is_fallback(f)
+    ]
+    default_function = next((f for f in runtime_functions if _is_fallback(f)), None)
+
+    internal_functions_ir: list[IRnode] = []
+
+    # compile internal functions first so we have the function info
+    for func_ast in internal_functions:
+        func_ir = _ir_for_internal_function(func_ast, global_ctx, False)
+        internal_functions_ir.append(IRnode.from_list(func_ir))
+
+    if core._opt_none():
+        selector_section = _selector_section_linear(external_functions, global_ctx)
+    # dense vs sparse global overhead is amortized after about 4 methods.
+    # (--debug will force dense selector table anyway if _opt_codesize is selected.)
+    elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()):
+        selector_section = _selector_section_dense(external_functions, global_ctx)
+    else:
+        selector_section = _selector_section_sparse(external_functions, global_ctx)
+
+    if default_function:
+        fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx)
+    else:
+        fallback_ir = IRnode.from_list(
+            ["revert", 0, 0], annotation="Default function", error_msg="fallback function"
+        )
+
+    runtime = ["seq", selector_section]
+    runtime.append(["goto", "fallback"])
+    runtime.append(["label", "fallback", ["var_list"], fallback_ir])
+
+    runtime.extend(internal_functions_ir)

     deploy_code: List[Any] = ["seq"]
     immutables_len = global_ctx.immutable_section_bytes
     if init_function:
         # TODO might be cleaner to separate this into an _init_ir helper func
-        init_func_ir = generate_ir_for_function(
-            init_function, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
-        )
+        init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True)

         # pass the amount of memory allocated for the init function
         # so that deployment does not clobber while preparing immutables
@@ -184,12 +476,10 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
         for f in internal_functions:
             init_func_t = init_function._metadata["type"]
             if f.name not in init_func_t.recursive_calls:
-                # unreachable
+                # unreachable code, delete it
                 continue

-            func_ir = generate_ir_for_function(
-                f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
-            )
+            func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True)
             deploy_code.append(func_ir)

     else:
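Driving the three selector-section strategies end to end looks roughly like the test harness earlier in this diff (sketch; assumes compile_code accepts a Settings object, and `source` is any contract source):

    from vyper import compiler
    from vyper.compiler.settings import OptimizationLevel, Settings

    source = """
    @external
    def foo() -> uint256:
        return 1
    """

    # none -> linear search; codesize (w/ >4 externals or debug) -> dense
    # table; otherwise -> sparse table
    settings = Settings(optimize=OptimizationLevel.CODESIZE)
    out = compiler.compile_code(source, settings=settings, output_formats=["asm"])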
diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py
index 91d45f4916..86ea1813ea 100644
--- a/vyper/codegen/stmt.py
+++ b/vyper/codegen/stmt.py
@@ -258,11 +258,17 @@ def _parse_For_range(self):
         arg0 = self.stmt.iter.args[0]
         num_of_args = len(self.stmt.iter.args)

+        kwargs = {
+            s.arg: Expr.parse_value_expr(s.value, self.context)
+            for s in self.stmt.iter.keywords or []
+        }
+
         # Type 1 for, e.g. for i in range(10): ...
         if num_of_args == 1:
-            arg0_val = self._get_range_const_value(arg0)
+            n = Expr.parse_value_expr(arg0, self.context)
             start = IRnode.from_list(0, typ=iter_typ)
-            rounds = arg0_val
+            rounds = n
+            rounds_bound = kwargs.get("bound", rounds)

         # Type 2 for, e.g. for i in range(100, 110): ...
         elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal:
@@ -270,6 +276,7 @@ def _parse_For_range(self):
             arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
             start = IRnode.from_list(arg0_val, typ=iter_typ)
             rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ)
+            rounds_bound = rounds

         # Type 3 for, e.g. for i in range(x, x + 10): ...
         else:
@@ -278,9 +285,10 @@ def _parse_For_range(self):
             start = Expr.parse_value_expr(arg0, self.context)
             _, hi = start.typ.int_bounds
             start = clamp("le", start, hi + 1 - rounds)
+            rounds_bound = rounds

-        r = rounds if isinstance(rounds, int) else rounds.value
-        if r < 1:
+        bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value
+        if bound < 1:
             return

         varname = self.stmt.target.id
@@ -294,7 +302,10 @@ def _parse_For_range(self):
             loop_body.append(["mstore", iptr, i])
             loop_body.append(parse_body(self.stmt.body, self.context))

-        ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds, loop_body])
+        # NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound.
+        # if we ever want to remove that, we need to manually add the assertion
+        # where it makes sense.
+        ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds_bound, loop_body])
         del self.context.forvars[varname]
         return ir_node
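What this buys at the language level: the loop count may now be a runtime expression, as long as a compile-time `bound=` is supplied (Vyper source sketch; per the NOTE above, the repeat IR asserts n <= bound at runtime):

    @external
    def triangle(n: uint256) -> uint256:
        total: uint256 = 0
        for i in range(n, bound=16):  # reverts at runtime if n > 16
            total += i
        return total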
diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py
index 63d92d9a47..69fcbf1f1f 100644
--- a/vyper/compiler/output.py
+++ b/vyper/compiler/output.py
@@ -300,9 +300,13 @@ def _build_opcodes(bytecode: bytes) -> str:

     while bytecode_sequence:
         op = bytecode_sequence.popleft()
-        opcode_output.append(opcode_map[op])
+        opcode_output.append(opcode_map.get(op, f"VERBATIM_{hex(op)}"))
         if "PUSH" in opcode_output[-1] and opcode_output[-1] != "PUSH0":
             push_len = int(opcode_map[op][4:])
+            # we can have push_len > len(bytecode_sequence) when there is data
+            # (instead of code) at end of contract
+            # CMC 2023-07-13 maybe just strip known data segments?
+            push_len = min(push_len, len(bytecode_sequence))
             push_values = [hex(bytecode_sequence.popleft())[2:] for i in range(push_len)]
             opcode_output.append(f"0x{''.join(push_values).upper()}")
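The edge case being fixed, in miniature: with data (e.g. the new selector tables) appended after code, disassembly can hit a PUSH opcode whose immediate bytes are cut off by the end of the bytecode. A sketch of the clamp:

    from collections import deque

    seq = deque(bytes([0x61, 0xAA]))   # PUSH2, but only one byte remains
    op = seq.popleft()
    push_len = op - 0x60 + 1           # PUSH1..PUSH32 occupy 0x60..0x7F
    push_len = min(push_len, len(seq)) # the clamp added above
    assert [hex(seq.popleft()) for _ in range(push_len)] == ["0xaa"]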
diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py
index 4e1bd9e6c3..526d2f3253 100644
--- a/vyper/compiler/phases.py
+++ b/vyper/compiler/phases.py
@@ -263,7 +263,6 @@ def generate_folded_ast(
     vyper_module_folded = copy.deepcopy(vyper_module)
     vy_ast.folding.fold(vyper_module_folded)
     validate_semantics(vyper_module_folded, interface_codes)
-    vy_ast.expansion.expand_annotated_ast(vyper_module_folded)
     symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides)

     return vyper_module_folded, symbol_tables
diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py
index bb5e9cdc25..d2c88a8592 100644
--- a/vyper/compiler/settings.py
+++ b/vyper/compiler/settings.py
@@ -42,3 +42,16 @@ class Settings:
     compiler_version: Optional[str] = None
     optimize: Optional[OptimizationLevel] = None
     evm_version: Optional[str] = None
+
+
+_DEBUG = False
+
+
+def _is_debug_mode():
+    global _DEBUG
+    return _DEBUG
+
+
+def _set_debug_mode(dbg: bool = False) -> None:
+    global _DEBUG
+    _DEBUG = dbg
diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py
index a9064a44fa..bba3b34515 100644
--- a/vyper/ir/compile_ir.py
+++ b/vyper/ir/compile_ir.py
@@ -158,11 +158,20 @@ def _add_postambles(asm_ops):
         to_append.extend(_revert_string)

     if len(to_append) > 0:
+        # insert the postambles *before* runtime code
+        # so the data section of the runtime code can't bork the postambles.
+        runtime = None
+        if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader):
+            runtime = asm_ops.pop()
+
         # for some reason there might not be a STOP at the end of asm_ops.
         # (generally vyper programs will have it but raw IR might not).
         asm_ops.append("STOP")
         asm_ops.extend(to_append)

+        if runtime:
+            asm_ops.append(runtime)
+
     # need to do this recursively since every sublist is basically
     # treated as its own program (there are no global labels.)
     for t in asm_ops:
@@ -213,6 +222,9 @@ def compile_to_assembly(code, optimize=OptimizationLevel.GAS):

     res = _compile_to_assembly(code)
     _add_postambles(res)
+
+    _relocate_segments(res)
+
     if optimize != OptimizationLevel.NONE:
         _optimize_assembly(res)
     return res
@@ -401,9 +413,8 @@ def _height_of(witharg):
         )
         # stack: i, rounds, rounds_bound
         # assert rounds <= rounds_bound
-        # TODO this runtime assertion should never fail for
+        # TODO this runtime assertion shouldn't fail for
         # internally generated repeats.
-        # maybe drop it or jump to 0xFE
         o.extend(["DUP2", "GT"] + _assert_false())

         # stack: i, rounds
@@ -500,14 +511,14 @@ def _height_of(witharg):
     assert isinstance(memsize, int), "non-int memsize"
     assert isinstance(padding, int), "non-int padding"

-    begincode = mksymbol("runtime_begin")
+    runtime_begin = mksymbol("runtime_begin")

     subcode = _compile_to_assembly(ir)

     o = []

     # COPY the code to memory for deploy
-    o.extend(["_sym_subcode_size", begincode, "_mem_deploy_start", "CODECOPY"])
+    o.extend(["_sym_subcode_size", runtime_begin, "_mem_deploy_start", "CODECOPY"])

     # calculate the len of runtime code
     o.extend(["_OFST", "_sym_subcode_size", padding])  # stack: len
@@ -517,10 +528,9 @@ def _height_of(witharg):
     # since the asm data structures are very primitive, to make sure
     # assembly_to_evm is able to calculate data offsets correctly,
     # we pass the memsize via magic opcodes to the subcode
-    subcode = [f"_DEPLOY_MEM_OFST_{memsize}"] + subcode
+    subcode = [_RuntimeHeader(runtime_begin, memsize)] + subcode

     # append the runtime code after the ctor code
-    o.extend([begincode, "BLANK"])
     # `append(...)` call here is intentional.
     # each sublist is essentially its own program with its
     # own symbols.
@@ -661,16 +671,36 @@ def _height_of(witharg):
             height,
         )

+    elif code.value == "data":
+        data_node = [_DataHeader("_sym_" + code.args[0].value)]
+
+        for c in code.args[1:]:
+            if isinstance(c.value, int):
+                assert 0 <= c < 256, f"invalid data byte {c}"
+                data_node.append(c.value)
+            elif isinstance(c.value, bytes):
+                data_node.append(c.value)
+            elif isinstance(c, IRnode):
+                assert c.value == "symbol"
+                data_node.extend(
+                    _compile_to_assembly(c, withargs, existing_labels, break_dest, height)
+                )
+            else:
+                raise ValueError(f"Invalid data: {type(c)} {c}")
+
+        # intentionally return a sublist.
+        return [data_node]
+
     # jump to a symbol, and push variable # of arguments onto stack
     elif code.value == "goto":
         o = []
         for i, c in enumerate(reversed(code.args[1:])):
             o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i))
-        o.extend(["_sym_" + str(code.args[0]), "JUMP"])
+        o.extend(["_sym_" + code.args[0].value, "JUMP"])
         return o

     # push a literal symbol
     elif code.value == "symbol":
-        return ["_sym_" + str(code.args[0])]
+        return ["_sym_" + code.args[0].value]

     # set a symbol as a location.
     elif code.value == "label":
         label_name = code.args[0].value
@@ -728,8 +758,8 @@ def _height_of(witharg):
     # inject debug opcode.
     elif code.value == "pc_debugger":
         return mkdebug(pc_debugger=True, source_pos=code.source_pos)
-    else:
-        raise Exception("Weird code element: " + repr(code))
+    else:  # pragma: no cover
+        raise ValueError(f"Weird code element: {type(code)} {code}")


def note_line_num(line_number_map, item, pos):
@@ -764,11 +794,8 @@ def note_breakpoint(line_number_map, item, pos):


 def _prune_unreachable_code(assembly):
-    # In converting IR to assembly we sometimes end up with unreachable
-    # instructions - POPing to clear the stack or STOPing execution at the
-    # end of a function that has already returned or reverted. This should
-    # be addressed in the IR, but for now we do a final sanity check here
-    # to avoid unnecessary bytecode bloat.
+    # delete code between terminal ops and JUMPDESTs as those are
+    # unreachable
     changed = False
     i = 0
     while i < len(assembly) - 2:
@@ -777,7 +804,7 @@ def _prune_unreachable_code(assembly):
             instr = assembly[i][-1]

         if assembly[i] in _TERMINAL_OPS and not (
-            is_symbol(assembly[i + 1]) and assembly[i + 2] in ("JUMPDEST", "BLANK")
+            is_symbol(assembly[i + 1]) or isinstance(assembly[i + 1], list)
         ):
             changed = True
             del assembly[i + 1]
@@ -889,6 +916,14 @@ def _merge_iszero(assembly):
     return changed


+# a symbol _sym_x in assembly can either mean to push _sym_x to the stack,
+# or it can precede a location in code which we want to add to the symbol map.
+# this helper function tells us whether the preceding symbol should be added
+# to the symbol map.
+def is_symbol_map_indicator(asm_node):
+    return asm_node == "JUMPDEST"
+
+
def _prune_unused_jumpdests(assembly):
    changed = False

@@ -896,9 +931,17 @@ def _prune_unused_jumpdests(assembly):

     # find all used jumpdests
     for i in range(len(assembly) - 1):
-        if is_symbol(assembly[i]) and assembly[i + 1] != "JUMPDEST":
+        if is_symbol(assembly[i]) and not is_symbol_map_indicator(assembly[i + 1]):
             used_jumpdests.add(assembly[i])

+    for item in assembly:
+        if isinstance(item, list) and isinstance(item[0], _DataHeader):
+            # add symbols used in data sections as they are likely
+            # used for a jumptable.
+            for t in item:
+                if is_symbol(t):
+                    used_jumpdests.add(t)
+
     # delete jumpdests that aren't used
     i = 0
     while i < len(assembly) - 2:
@@ -937,7 +980,7 @@ def _stack_peephole_opts(assembly):
 # optimize assembly, in place
 def _optimize_assembly(assembly):
     for x in assembly:
-        if isinstance(x, list):
+        if isinstance(x, list) and isinstance(x[0], _RuntimeHeader):
             _optimize_assembly(x)

     for _ in range(1024):
@@ -970,7 +1013,93 @@ def adjust_pc_maps(pc_maps, ofst):
     return ret


+SYMBOL_SIZE = 2  # size of a PUSH instruction for a code symbol
+
+
+def _data_to_evm(assembly, symbol_map):
+    ret = bytearray()
+    assert isinstance(assembly[0], _DataHeader)
+    for item in assembly[1:]:
+        if is_symbol(item):
+            symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big")
+            ret.extend(symbol)
+        elif isinstance(item, int):
+            ret.append(item)
+        elif isinstance(item, bytes):
+            ret.extend(item)
+        else:
+            raise ValueError(f"invalid data {type(item)} {item}")
+
+    return ret
+
+
+# predict the length an assembly [data] node will occupy in bytecode
+def _length_of_data(assembly):
+    ret = 0
+    assert isinstance(assembly[0], _DataHeader)
+    for item in assembly[1:]:
+        if is_symbol(item):
+            ret += SYMBOL_SIZE
+        elif isinstance(item, int):
+            assert 0 <= item < 256, f"invalid data byte {item}"
+            ret += 1
+        elif isinstance(item, bytes):
+            ret += len(item)
+        else:
+            raise ValueError(f"invalid data {type(item)} {item}")
+
+    return ret
+
+
+class _RuntimeHeader:
+    def __init__(self, label, ctor_mem_size):
+        self.label = label
+        self.ctor_mem_size = ctor_mem_size
+
+    def __repr__(self):
+        return f"<RUNTIME {self.label} mem @{self.ctor_mem_size}>"
+
+
+class _DataHeader:
+    def __init__(self, label):
+        self.label = label
+
+    def __repr__(self):
+        return f"DATA {self.label}"
+
+
+def _relocate_segments(assembly):
+    # relocate all data segments to the end, otherwise data could be
+    # interpreted as PUSH instructions and mangle otherwise valid jumpdests
+    # relocate all runtime segments to the end as well
+    data_segments = []
+    non_data_segments = []
+    code_segments = []
+    for t in assembly:
+        if isinstance(t, list):
+            if isinstance(t[0], _DataHeader):
+                data_segments.append(t)
+            else:
+                _relocate_segments(t)  # recurse
+                assert isinstance(t[0], _RuntimeHeader)
+                code_segments.append(t)
+        else:
+            non_data_segments.append(t)
+    assembly.clear()
+    assembly.extend(non_data_segments)
+    assembly.extend(code_segments)
+    assembly.extend(data_segments)
+
+
+# TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps
 def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
+    bytecode, source_maps, _ = assembly_to_evm_with_symbol_map(
+        assembly, pc_ofst=pc_ofst, insert_vyper_signature=insert_vyper_signature
+    )
+    return bytecode, source_maps
+
+
+def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_vyper_signature=False):
     """
     Assembles assembly into EVM
@@ -999,8 +1128,6 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
         bytecode_suffix += b"\xa1\x65vyper\x83" + bytes(list(version_tuple))
         bytecode_suffix += len(bytecode_suffix).to_bytes(2, "big")

-    CODE_OFST_SIZE = 2  # size of a PUSH instruction for a code symbol
-
     # to optimize the size of deploy code - we want to use the smallest
     # PUSH instruction possible which can support all memory symbols
     # (and also works with linear pass symbol resolution)
@@ -1009,13 +1136,13 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
     mem_ofst_size, ctor_mem_size = None, None
     max_mem_ofst = 0
     for i, item in enumerate(assembly):
-        if isinstance(item, list):
+        if isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
             assert runtime_code is None, "Multiple subcodes"
-            runtime_code, runtime_map = assembly_to_evm(item)

-            assert item[0].startswith("_DEPLOY_MEM_OFST_")
             assert ctor_mem_size is None
-            ctor_mem_size = int(item[0][len("_DEPLOY_MEM_OFST_") :])
+            ctor_mem_size = item[0].ctor_mem_size
+
+            runtime_code, runtime_map = assembly_to_evm(item[1:])

             runtime_code_start, runtime_code_end = _runtime_code_offsets(
                 ctor_mem_size, len(runtime_code)
@@ -1053,14 +1180,14 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):

         # update pc
         if is_symbol(item):
-            if assembly[i + 1] == "JUMPDEST" or assembly[i + 1] == "BLANK":
+            if is_symbol_map_indicator(assembly[i + 1]):
                 # Don't increment pc as the symbol itself doesn't go into code
                 if item in symbol_map:
                     raise CompilerPanic(f"duplicate jumpdest {item}")
                 symbol_map[item] = pc
             else:
-                pc += CODE_OFST_SIZE + 1  # PUSH2 highbits lowbits
+                pc += SYMBOL_SIZE + 1  # PUSH2 highbits lowbits
         elif is_mem_sym(item):
             # PUSH item
             pc += mem_ofst_size + 1
@@ -1070,19 +1197,16 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
             # [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar)
             # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar)
             pc -= 1
-        elif item == "BLANK":
-            pc += 0
-        elif isinstance(item, str) and item.startswith("_DEPLOY_MEM_OFST_"):
-            # _DEPLOY_MEM_OFST is assembly magic which will
-            # get removed during final assembly-to-bytecode
-            pc += 0
-        elif isinstance(item, list):
+        elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
+            symbol_map[item[0].label] = pc
             # add source map for all items in the runtime map
             t = adjust_pc_maps(runtime_map, pc)
             for key in line_number_map:
                 line_number_map[key].update(t[key])
             pc += len(runtime_code)
-
+        elif isinstance(item, list) and isinstance(item[0], _DataHeader):
+            symbol_map[item[0].label] = pc
+            pc += _length_of_data(item)
         else:
             pc += 1
@@ -1094,13 +1218,9 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
     if runtime_code is not None:
         symbol_map["_sym_subcode_size"] = len(runtime_code)

-    # (NOTE CMC 2022-06-17 this way of generating bytecode did not
-    # seem to be a perf hotspot. if it is, may want to use bytearray()
-    # instead).
-
-    # TODO refactor into two functions, create posmap and assemble
+    # TODO refactor into two functions, create symbol_map and assemble

-    o = b""
+    ret = bytearray()

     # now that all symbols have been resolved, generate bytecode
     # using the symbol map
@@ -1110,47 +1230,47 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False):
             to_skip -= 1
             continue

-        if item in ("DEBUG", "BLANK"):
+        if item in ("DEBUG",):
             continue  # skippable opcodes

-        elif isinstance(item, str) and item.startswith("_DEPLOY_MEM_OFST_"):
-            continue
-
         elif is_symbol(item):
-            if assembly[i + 1] != "JUMPDEST" and assembly[i + 1] != "BLANK":
-                bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=CODE_OFST_SIZE))
-                o += bytecode
+            # push a symbol to stack
+            if not is_symbol_map_indicator(assembly[i + 1]):
+                bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=SYMBOL_SIZE))
+                ret.extend(bytecode)

         elif is_mem_sym(item):
             bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=mem_ofst_size))
-            o += bytecode
+            ret.extend(bytecode)

         elif is_ofst(item):
             # _OFST _sym_foo 32
             ofst = symbol_map[assembly[i + 1]] + assembly[i + 2]
-            n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else CODE_OFST_SIZE
+            n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else SYMBOL_SIZE
             bytecode, _ = assembly_to_evm(PUSH_N(ofst, n))
-            o += bytecode
+            ret.extend(bytecode)
             to_skip = 2

         elif isinstance(item, int):
-            o += bytes([item])
+            ret.append(item)
         elif isinstance(item, str) and item.upper() in get_opcodes():
-            o += bytes([get_opcodes()[item.upper()][0]])
+            ret.append(get_opcodes()[item.upper()][0])
         elif item[:4] == "PUSH":
-            o += bytes([PUSH_OFFSET + int(item[4:])])
+            ret.append(PUSH_OFFSET + int(item[4:]))
         elif item[:3] == "DUP":
-            o += bytes([DUP_OFFSET + int(item[3:])])
+            ret.append(DUP_OFFSET + int(item[3:]))
         elif item[:4] == "SWAP":
-            o += bytes([SWAP_OFFSET + int(item[4:])])
-        elif isinstance(item, list):
-            o += runtime_code
-        else:
-            # Should never reach because, assembly is create in _compile_to_assembly.
-            raise Exception("Weird symbol in assembly: " + str(item))  # pragma: no cover
+            ret.append(SWAP_OFFSET + int(item[4:]))
+        elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
+            ret.extend(runtime_code)
+        elif isinstance(item, list) and isinstance(item[0], _DataHeader):
+            ret.extend(_data_to_evm(item, symbol_map))
+        else:  # pragma: no cover
+            # unreachable
+            raise ValueError(f"Weird symbol in assembly: {type(item)} {item}")

-    o += bytecode_suffix
+    ret.extend(bytecode_suffix)

     line_number_map["breakpoints"] = list(line_number_map["breakpoints"])
     line_number_map["pc_breakpoints"] = list(line_number_map["pc_breakpoints"])
-    return o, line_number_map
+    return bytes(ret), line_number_map, symbol_map
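The effect of _relocate_segments on a toy assembly list (illustration; `DataHeader` here is a stand-in for the private _DataHeader class):

    class DataHeader:  # stand-in for compile_ir._DataHeader
        def __init__(self, label):
            self.label = label

    def relocate(assembly):
        # bubble data segments to the end so their bytes can never be decoded
        # as PUSH immediates that swallow a following JUMPDEST
        is_data = lambda t: isinstance(t, list) and isinstance(t[0], DataHeader)
        return [t for t in assembly if not is_data(t)] + [t for t in assembly if is_data(t)]

    asm = ["PUSH1", 0, [DataHeader("_sym_bucket_0"), 1, 2], "JUMPDEST", "STOP"]
    assert relocate(asm)[2] == "JUMPDEST"  # the data segment now sits after all code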
- raise Exception("Weird symbol in assembly: " + str(item)) # pragma: no cover + ret.append(SWAP_OFFSET + int(item[4:])) + elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + ret.extend(runtime_code) + elif isinstance(item, list) and isinstance(item[0], _DataHeader): + ret.extend(_data_to_evm(item, symbol_map)) + else: # pragma: no cover + # unreachable + raise ValueError(f"Weird symbol in assembly: {type(item)} {item}") - o += bytecode_suffix + ret.extend(bytecode_suffix) line_number_map["breakpoints"] = list(line_number_map["breakpoints"]) line_number_map["pc_breakpoints"] = list(line_number_map["pc_breakpoints"]) - return o, line_number_map + return bytes(ret), line_number_map, symbol_map diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 40e02e79c7..08c2168381 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -473,6 +473,8 @@ def finalize(val, args): if value == "seq": changed |= _merge_memzero(argz) changed |= _merge_calldataload(argz) + changed |= _merge_dload(argz) + changed |= _rewrite_mstore_dload(argz) changed |= _merge_mload(argz) changed |= _remove_empty_seqs(argz) @@ -645,6 +647,18 @@ def _merge_dload(argz): return _merge_load(argz, "dload", "dloadbytes") +def _rewrite_mstore_dload(argz): + changed = False + for i, arg in enumerate(argz): + if arg.value == "mstore" and arg.args[1].value == "dload": + dst = arg.args[0] + src = arg.args[1].args[0] + len_ = 32 + argz[i] = IRnode.from_list(["dloadbytes", dst, src, len_], source_pos=arg.source_pos) + changed = True + return changed + + def _merge_mload(argz): if not version_check(begin="cancun"): return False diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 5977a87812..9e987d1cd0 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,3 +1,5 @@ +import vyper.ast as vy_ast + from .. import types # break a dependency cycle. 
diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py
index 5977a87812..9e987d1cd0 100644
--- a/vyper/semantics/analysis/__init__.py
+++ b/vyper/semantics/analysis/__init__.py
@@ -1,3 +1,5 @@
+import vyper.ast as vy_ast
+
 from .. import types  # break a dependency cycle.
 from ..namespace import get_namespace
 from .local import validate_functions
@@ -11,4 +13,5 @@ def validate_semantics(vyper_ast, interface_codes):

     with namespace.enter_scope():
         add_module_namespace(vyper_ast, interface_codes)
+        vy_ast.expansion.expand_annotated_ast(vyper_ast)
         validate_functions(vyper_ast)
diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py
index c58c65e8a0..bb32b81df2 100644
--- a/vyper/semantics/analysis/local.py
+++ b/vyper/semantics/analysis/local.py
@@ -371,16 +371,30 @@ def visit_For(self, node):
                 raise IteratorException(
                     "Cannot iterate over the result of a function call", node.iter
                 )
-            validate_call_args(node.iter, (1, 2))
+            validate_call_args(node.iter, (1, 2), kwargs=["bound"])

             args = node.iter.args
+            kwargs = {s.arg: s.value for s in node.iter.keywords or []}
             if len(args) == 1:
                 # range(CONSTANT)
-                if not isinstance(args[0], vy_ast.Num):
-                    raise StateAccessViolation("Value must be a literal", node)
-                if args[0].value <= 0:
-                    raise StructureException("For loop must have at least 1 iteration", args[0])
-                type_list = get_possible_types_from_node(args[0])
+                n = args[0]
+                bound = kwargs.pop("bound", None)
+                #validate_expected_type(n, IntegerT.any())
+
+                if bound is None:
+                    if not isinstance(n, vy_ast.Num):
+                        raise StateAccessViolation("Value must be a literal", n)
+                    if n.value <= 0:
+                        raise StructureException("For loop must have at least 1 iteration", args[0])
+                    type_list = get_possible_types_from_node(n)
+
+                else:
+                    if not isinstance(bound, vy_ast.Num):
+                        raise StateAccessViolation("bound must be a literal", bound)
+                    if bound.value <= 0:
+                        raise StructureException("bound must be at least 1", args[0])
+                    type_list = get_common_types(n, bound)
+
             else:
                 type_list = get_common_types(*args)
                 if not isinstance(args[0], vy_ast.Constant):
@@ -498,6 +512,10 @@ def visit_For(self, node):
             for a in node.iter.args:
                 self.expr_visitor.visit(a, typ)

+            for a in node.iter.keywords:
+                if a.arg == "bound":
+                    self.expr_visitor.visit(a.value, typ)
+
             # success -- do not enter error handling section
             return
diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py
index f252c84373..afa6b56838 100644
--- a/vyper/semantics/analysis/utils.py
+++ b/vyper/semantics/analysis/utils.py
@@ -180,24 +180,30 @@ def _find_fn(self, node):
         raise StructureException("Cannot determine type of this object", node)

     def types_from_Attribute(self, node):
+        is_self_reference = node.get("value.id") == "self"
         # variable attribute, e.g. `foo.bar`
         t = self.get_exact_type_from_node(node.value, include_type_exprs=True)
         name = node.attr
+
+        def _raise_invalid_reference(name, node):
+            raise InvalidReference(
+                f"'{name}' is not a storage variable, it should not be prepended with self", node
+            )
+
         try:
             s = t.get_member(name, node)
             if isinstance(s, VyperType):
                 # ex. foo.bar(). bar() is a ContractFunctionT
                 return [s]
+            if is_self_reference and (s.is_constant or s.is_immutable):
+                _raise_invalid_reference(name, node)
             # general case. s is a VarInfo, e.g. self.foo
             return [s.typ]
         except UnknownAttribute:
-            if node.get("value.id") != "self":
+            if not is_self_reference:
                 raise
             if name in self.namespace:
-                raise InvalidReference(
-                    f"'{name}' is not a storage variable, it should not be prepended with self",
-                    node,
-                ) from None
+                _raise_invalid_reference(name, node)

             suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4)
             raise UndeclaredDefinition(
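For reference, the new `_raise_invalid_reference` path in the try-block above also catches constants and immutables accessed through `self`, e.g. (Vyper source sketch):

    SUPPLY: constant(uint256) = 100

    @external
    def total() -> uint256:
        return self.SUPPLY  # InvalidReference: 'SUPPLY' is not a storage variable...
        # correct form: return SUPPLY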