From 1ac8362df0c9d501af241df63bc3798cab7f5f9c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 May 2023 09:49:12 -0400 Subject: [PATCH 001/201] feat: use new push0 opcode (#3361) per the shanghai fork - during codegen, use the new `push0` opcode instead of the `push1 0` sequence --- setup.py | 4 ++-- tests/compiler/ir/test_compile_ir.py | 2 +- tests/compiler/test_opcodes.py | 31 ++++++++++++++++------------ tests/compiler/test_sha3_32.py | 3 +++ vyper/compiler/output.py | 4 ++-- vyper/evm/opcodes.py | 4 +++- vyper/ir/compile_ir.py | 28 ++++++++++--------------- 7 files changed, 40 insertions(+), 36 deletions(-) diff --git a/setup.py b/setup.py index 0966a8e31a..05cb52259d 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,8 @@ "pytest-xdist>=2.5,<3.0", "pytest-split>=0.7.0,<1.0", "pytest-rerunfailures>=10.2,<11", - "eth-tester[py-evm]>=0.8.0b3,<0.9", - "py-evm>=0.6.1a2,<0.7", + "eth-tester[py-evm]>=0.9.0b1,<0.10", + "py-evm>=0.7.0a1,<0.8", "web3==6.0.0", "tox>=3.15,<4.0", "lark==1.1.2", diff --git a/tests/compiler/ir/test_compile_ir.py b/tests/compiler/ir/test_compile_ir.py index 91007da33a..706c31e0f2 100644 --- a/tests/compiler/ir/test_compile_ir.py +++ b/tests/compiler/ir/test_compile_ir.py @@ -68,4 +68,4 @@ def test_pc_debugger(): debugger_ir = ["seq", ["mstore", 0, 32], ["pc_debugger"]] ir_nodes = IRnode.from_list(debugger_ir) _, line_number_map = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir_nodes)) - assert line_number_map["pc_breakpoints"][0] == 5 + assert line_number_map["pc_breakpoints"][0] == 4 diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py index 67ea10c311..f36fcfac6f 100644 --- a/tests/compiler/test_opcodes.py +++ b/tests/compiler/test_opcodes.py @@ -8,9 +8,11 @@ @pytest.fixture(params=list(opcodes.EVM_VERSIONS)) def evm_version(request): default = opcodes.active_evm_version - opcodes.active_evm_version = opcodes.EVM_VERSIONS[request.param] - yield request.param - opcodes.active_evm_version = default + try: + opcodes.active_evm_version = opcodes.EVM_VERSIONS[request.param] + yield request.param + finally: + opcodes.active_evm_version = default def test_opcodes(): @@ -42,17 +44,20 @@ def test_version_check(evm_version): def test_get_opcodes(evm_version): - op = opcodes.get_opcodes() - if evm_version in ("paris", "berlin"): - assert "CHAINID" in op - assert op["SLOAD"][-1] == 2100 + ops = opcodes.get_opcodes() + if evm_version in ("paris", "berlin", "shanghai"): + assert "CHAINID" in ops + assert ops["SLOAD"][-1] == 2100 + if evm_version in ("shanghai",): + assert "PUSH0" in ops elif evm_version == "istanbul": - assert "CHAINID" in op - assert op["SLOAD"][-1] == 800 + assert "CHAINID" in ops + assert ops["SLOAD"][-1] == 800 else: - assert "CHAINID" not in op - assert op["SLOAD"][-1] == 200 + assert "CHAINID" not in ops + assert ops["SLOAD"][-1] == 200 + if evm_version in ("byzantium", "atlantis"): - assert "CREATE2" not in op + assert "CREATE2" not in ops else: - assert op["CREATE2"][-1] == 32000 + assert ops["CREATE2"][-1] == 32000 diff --git a/tests/compiler/test_sha3_32.py b/tests/compiler/test_sha3_32.py index 9fbdf6f000..e1cbf9c843 100644 --- a/tests/compiler/test_sha3_32.py +++ b/tests/compiler/test_sha3_32.py @@ -1,9 +1,12 @@ from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import version_check from vyper.ir import compile_ir, optimizer def test_sha3_32(): ir = ["sha3_32", 0] evm = ["PUSH1", 0, "PUSH1", 0, "MSTORE", "PUSH1", 32, "PUSH1", 0, "SHA3"] + if version_check(begin="shanghai"): + evm = ["PUSH0", "PUSH0", "MSTORE", "PUSH1", 32, "PUSH0", "SHA3"] assert compile_ir.compile_to_assembly(IRnode.from_list(ir)) == evm assert compile_ir.compile_to_assembly(optimizer.optimize(IRnode.from_list(ir))) == evm diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index e30f021c6b..f061bd8e18 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -208,7 +208,7 @@ def _build_asm(asm_list): else: output_string += str(node) + " " - if isinstance(node, str) and node.startswith("PUSH"): + if isinstance(node, str) and node.startswith("PUSH") and node != "PUSH0": assert in_push == 0 in_push = int(node[4:]) output_string += "0x" @@ -303,7 +303,7 @@ def _build_opcodes(bytecode: bytes) -> str: while bytecode_sequence: op = bytecode_sequence.popleft() opcode_output.append(opcode_map[op]) - if "PUSH" in opcode_output[-1]: + if "PUSH" in opcode_output[-1] and opcode_output[-1] != "PUSH0": push_len = int(opcode_map[op][4:]) push_values = [hex(bytecode_sequence.popleft())[2:] for i in range(push_len)] opcode_output.append(f"0x{''.join(push_values).upper()}") diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index b9f1e77ca8..76529da14e 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -25,11 +25,12 @@ "istanbul": 2, "berlin": 3, "paris": 4, + "shanghai": 5, # ETC Forks "atlantis": 0, "agharta": 1, } -DEFAULT_EVM_VERSION: str = "paris" +DEFAULT_EVM_VERSION: str = "shanghai" # opcode as hex value @@ -102,6 +103,7 @@ "MSIZE": (0x59, 0, 1, 2), "GAS": (0x5A, 0, 1, 2), "JUMPDEST": (0x5B, 0, 0, 1), + "PUSH0": (0x5F, 0, 1, 2), "PUSH1": (0x60, 0, 1, 3), "PUSH2": (0x61, 0, 1, 3), "PUSH3": (0x62, 0, 1, 3), diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index c24b3a67a2..57ea4ca7e7 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -3,7 +3,7 @@ import math from vyper.codegen.ir_node import IRnode -from vyper.evm.opcodes import get_opcodes +from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic from vyper.utils import MemoryPositions from vyper.version import version_tuple @@ -23,7 +23,8 @@ def num_to_bytearray(x): def PUSH(x): bs = num_to_bytearray(x) - if len(bs) == 0: + # starting in shanghai, can do push0 directly with no immediates + if len(bs) == 0 and not version_check(begin="shanghai"): bs = [0] return [f"PUSH{len(bs)}"] + bs @@ -149,7 +150,7 @@ def _add_postambles(asm_ops): global _revert_label - _revert_string = [_revert_label, "JUMPDEST", "PUSH1", 0, "DUP1", "REVERT"] + _revert_string = [_revert_label, "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] if _revert_label in asm_ops: # shared failure block @@ -555,13 +556,10 @@ def _height_of(witharg): o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) o.extend( [ - "PUSH1", - MemoryPositions.FREE_VAR_SPACE, + *PUSH(MemoryPositions.FREE_VAR_SPACE), "MSTORE", - "PUSH1", - 32, - "PUSH1", - MemoryPositions.FREE_VAR_SPACE, + *PUSH(32), + *PUSH(MemoryPositions.FREE_VAR_SPACE), "SHA3", ] ) @@ -572,16 +570,12 @@ def _height_of(witharg): o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) o.extend( [ - "PUSH1", - MemoryPositions.FREE_VAR_SPACE2, + *PUSH(MemoryPositions.FREE_VAR_SPACE2), "MSTORE", - "PUSH1", - MemoryPositions.FREE_VAR_SPACE, + *PUSH(MemoryPositions.FREE_VAR_SPACE), "MSTORE", - "PUSH1", - 64, - "PUSH1", - MemoryPositions.FREE_VAR_SPACE, + *PUSH(64), + *PUSH(MemoryPositions.FREE_VAR_SPACE), "SHA3", ] ) From 8aaaa9dbf64d537f57ae73d0db7f48682cecc5c2 Mon Sep 17 00:00:00 2001 From: Kelvin Fan Date: Thu, 18 May 2023 07:08:00 -0700 Subject: [PATCH 002/201] feat: build for aarch64 (#2687) Use `universal2` as target arch in pyinstaller To also build official binary releases for aarch64 machines. --------- Co-authored-by: Charles Cooper --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index daa1c2bfc9..645b800e79 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ freeze: clean init echo Generating binary... export OS="$$(uname -s | tr A-Z a-z)" && \ export VERSION="$$(PYTHONPATH=. python vyper/cli/vyper_compile.py --version)" && \ - pyinstaller --clean --onefile vyper/cli/vyper_compile.py --name "vyper.$${VERSION}.$${OS}" --add-data vyper:vyper + pyinstaller --target-architecture=universal2 --clean --onefile vyper/cli/vyper_compile.py --name "vyper.$${VERSION}.$${OS}" --add-data vyper:vyper clean: clean-build clean-docs clean-pyc clean-test From a8382f53b70f185fc3035dbfea561ed0737d0463 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 May 2023 10:18:33 -0400 Subject: [PATCH 003/201] fix: uninitialized immutable values (#3409) immutable variables can be read before assignment in constructor code, and their memory location is accessed, but that memory might not yet be initialized. prior to this commit, its value is not necessarily `empty(type)` since memory could have been written to ephemerally. in particular, `create_copy_of` (and its sister, `create_from_blueprint`) use `msize` to determine a starting location for where to copy the target bytecode into memory, while the immutables section start is determined using the static memory allocator. in case that `msize` is still less than the immutables section end, `create_copy_of` can write to the immutables section, thereby resulting in reads from the immutables section to return garbage. this commit fixes the issue by issuing an `iload - 32` before executing any initcode, which forces `msize` to be initialized past the end of the immutables section (and therefore, accessing an immutable before it is initialized in the constructor will produce the "expected" `empty()` value for the immutable). note that a corresponding `mload` is not required for runtime code, because vyper requires all memory variables to be instantiated at the declaration site, so there is no way that msize can produce a pointer to an uninitialized memory variable. --- tests/parser/features/test_immutable.py | 50 +++++++++++++++++++++++++ vyper/codegen/module.py | 21 ++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/parser/features/test_immutable.py b/tests/parser/features/test_immutable.py index 488943f784..7300d0f2d9 100644 --- a/tests/parser/features/test_immutable.py +++ b/tests/parser/features/test_immutable.py @@ -241,6 +241,56 @@ def get_immutable() -> uint256: assert c.get_immutable() == n + 2 +# GH issue 3101 +def test_immutables_initialized(get_contract): + dummy_code = """ +@external +def foo() -> uint256: + return 1 + """ + dummy_contract = get_contract(dummy_code) + + code = """ +a: public(immutable(uint256)) +b: public(uint256) + +@payable +@external +def __init__(to_copy: address): + c: address = create_copy_of(to_copy) + self.b = a + a = 12 + """ + c = get_contract(code, dummy_contract.address) + + assert c.b() == 0 + + +# GH issue 3101, take 2 +def test_immutables_initialized2(get_contract, get_contract_from_ir): + dummy_contract = get_contract_from_ir( + ["deploy", 0, ["seq"] + ["invalid"] * 600, 0], no_optimize=True + ) + + # rekt because immutables section extends past allocated memory + code = """ +a0: immutable(uint256[10]) +a: public(immutable(uint256)) +b: public(uint256) + +@payable +@external +def __init__(to_copy: address): + c: address = create_copy_of(to_copy) + self.b = a + a = 12 + a0 = empty(uint256[10]) + """ + c = get_contract(code, dummy_contract.address) + + assert c.b() == 0 + + # GH issue 3292 def test_internal_functions_called_by_ctor_location(get_contract): code = """ diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 320cf43b55..5d05c27e0b 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -153,12 +153,31 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: init_func_ir = generate_ir_for_function( init_function, global_ctx, skip_nonpayable_check=False, is_ctor_context=True ) - deploy_code.append(init_func_ir) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables # note: (deploy mem_ofst, code, extra_padding) init_mem_used = init_function._metadata["type"]._ir_info.frame_info.mem_used + + # force msize to be initialized past the end of immutables section + # so that builtins which use `msize` for "dynamic" memory + # allocation do not clobber uninitialized immutables. + # cf. GH issue 3101. + # note mload/iload X touches bytes from X to X+32, and msize rounds up + # to the nearest 32, so `iload`ing `immutables_len - 32` guarantees + # that `msize` will refer to a memory location of at least + # ` + immutables_len` (where == + # `_mem_deploy_end` as defined in the assembler). + # note: + # mload 32 => msize == 64 + # mload 33 => msize == 96 + # assumption in general: (mload X) => msize == ceil32(X + 32) + # see py-evm extend_memory: after_size = ceil32(start_position + size) + if immutables_len > 0: + deploy_code.append(["iload", max(0, immutables_len - 32)]) + + deploy_code.append(init_func_ir) + deploy_code.append(["deploy", init_mem_used, runtime, immutables_len]) # internal functions come after everything else From 6ee74f5af58507029192732f828e13c431c273a3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 19 May 2023 01:05:49 +0800 Subject: [PATCH 004/201] fix: complex arguments to builtin functions (#3167) prior to this commit, some builtin functions including ceil, would panic if their arguments were function calls (or otherwise determined to be complex expressions by `is_complex_ir`). this commit fixes the relevant builtin functions by using `cache_when_complex` where appropriate. --------- Co-authored-by: Charles Cooper --- tests/conftest.py | 35 ++++ tests/parser/functions/test_addmod.py | 57 ++++++ tests/parser/functions/test_as_wei_value.py | 31 ++++ tests/parser/functions/test_ceil.py | 34 ++++ tests/parser/functions/test_ec.py | 62 +++++++ tests/parser/functions/test_floor.py | 34 ++++ tests/parser/functions/test_mulmod.py | 75 ++++++++ .../types/numbers/test_unsigned_ints.py | 43 ----- vyper/builtins/functions.py | 167 ++++++++++-------- 9 files changed, 422 insertions(+), 116 deletions(-) create mode 100644 tests/parser/functions/test_addmod.py create mode 100644 tests/parser/functions/test_as_wei_value.py create mode 100644 tests/parser/functions/test_mulmod.py diff --git a/tests/conftest.py b/tests/conftest.py index e1d0996767..1cc9e4e72e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,3 +192,38 @@ def _f(_addr, _salt, _initcode): return keccak(prefix + addr + salt + keccak(initcode))[12:] return _f + + +@pytest.fixture +def side_effects_contract(get_contract): + def generate(ret_type): + """ + Generates a Vyper contract with an external `foo()` function, which + returns the specified return value of the specified return type, for + testing side effects using the `assert_side_effects_invoked` fixture. + """ + code = f""" +counter: public(uint256) + +@external +def foo(s: {ret_type}) -> {ret_type}: + self.counter += 1 + return s + """ + contract = get_contract(code) + return contract + + return generate + + +@pytest.fixture +def assert_side_effects_invoked(): + def assert_side_effects_invoked(side_effects_contract, side_effects_trigger, n=1): + start_value = side_effects_contract.counter() + + side_effects_trigger() + + end_value = side_effects_contract.counter() + assert end_value == start_value + n + + return assert_side_effects_invoked diff --git a/tests/parser/functions/test_addmod.py b/tests/parser/functions/test_addmod.py new file mode 100644 index 0000000000..67a7e9b101 --- /dev/null +++ b/tests/parser/functions/test_addmod.py @@ -0,0 +1,57 @@ +def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation): + uint256_code = """ +@external +def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: + return uint256_addmod(x, y, z) + """ + + c = get_contract_with_gas_estimation(uint256_code) + + assert c._uint256_addmod(1, 2, 2) == 1 + assert c._uint256_addmod(32, 2, 32) == 2 + assert c._uint256_addmod((2**256) - 1, 0, 2) == 1 + assert c._uint256_addmod(2**255, 2**255, 6) == 4 + assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) + + +def test_uint256_addmod_ext_call( + w3, side_effects_contract, assert_side_effects_invoked, get_contract +): + code = """ +@external +def foo(f: Foo) -> uint256: + return uint256_addmod(32, 2, f.foo(32)) + +interface Foo: + def foo(x: uint256) -> uint256: payable + """ + + c1 = side_effects_contract("uint256") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 2 + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_uint256_addmod_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return uint256_addmod(self.a(), self.b(), self.c()) + +@internal +def a() -> uint256: + return 32 + +@internal +def b() -> uint256: + return 2 + +@internal +def c() -> uint256: + return 32 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 2 diff --git a/tests/parser/functions/test_as_wei_value.py b/tests/parser/functions/test_as_wei_value.py new file mode 100644 index 0000000000..bab0aed616 --- /dev/null +++ b/tests/parser/functions/test_as_wei_value.py @@ -0,0 +1,31 @@ +def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> uint256: + return as_wei_value(a.foo(7), "ether") + +interface Foo: + def foo(x: uint8) -> uint8: nonpayable + """ + + c1 = side_effects_contract("uint8") + c2 = get_contract(code) + + assert c2.foo(c1.address) == w3.to_wei(7, "ether") + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_internal_call(w3, get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return as_wei_value(self.bar(), "ether") + +@internal +def bar() -> uint8: + return 7 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == w3.to_wei(7, "ether") diff --git a/tests/parser/functions/test_ceil.py b/tests/parser/functions/test_ceil.py index a9bcf62da2..daa9cb7c1b 100644 --- a/tests/parser/functions/test_ceil.py +++ b/tests/parser/functions/test_ceil.py @@ -104,3 +104,37 @@ def ceil_param(p: decimal) -> int256: assert c.fou() == -3 assert c.ceil_param(Decimal("-0.5")) == 0 assert c.ceil_param(Decimal("-7777777.7777777")) == -7777777 + + +def test_ceil_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> int256: + return ceil(a.foo(2.5)) + +interface Foo: + def foo(x: decimal) -> decimal: payable + """ + + c1 = side_effects_contract("decimal") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 3 + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_ceil_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> int256: + return ceil(self.bar()) + +@internal +def bar() -> decimal: + return 2.5 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 3 diff --git a/tests/parser/functions/test_ec.py b/tests/parser/functions/test_ec.py index be0f6f7ed2..9ce37d0721 100644 --- a/tests/parser/functions/test_ec.py +++ b/tests/parser/functions/test_ec.py @@ -45,6 +45,37 @@ def _ecadd3(x: uint256[2], y: uint256[2]) -> uint256[2]: assert c._ecadd3(G1, negative_G1) == [0, 0] +def test_ecadd_internal_call(get_contract_with_gas_estimation): + code = """ +@internal +def a() -> uint256[2]: + return [1, 2] + +@external +def foo() -> uint256[2]: + return ecadd([1, 2], self.a()) + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == G1_times_two + + +def test_ecadd_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +interface Foo: + def foo(x: uint256[2]) -> uint256[2]: payable + +@external +def foo(a: Foo) -> uint256[2]: + return ecadd([1, 2], a.foo([1, 2])) + """ + c1 = side_effects_contract("uint256[2]") + c2 = get_contract(code) + + assert c2.foo(c1.address) == G1_times_two + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + def test_ecmul(get_contract_with_gas_estimation): ecmuller = """ x3: uint256[2] @@ -74,3 +105,34 @@ def _ecmul3(x: uint256[2], y: uint256) -> uint256[2]: assert c._ecmul(G1, 3) == G1_times_three assert c._ecmul(G1, curve_order - 1) == negative_G1 assert c._ecmul(G1, curve_order) == [0, 0] + + +def test_ecmul_internal_call(get_contract_with_gas_estimation): + code = """ +@internal +def a() -> uint256: + return 3 + +@external +def foo() -> uint256[2]: + return ecmul([1, 2], self.a()) + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == G1_times_three + + +def test_ecmul_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +interface Foo: + def foo(x: uint256) -> uint256: payable + +@external +def foo(a: Foo) -> uint256[2]: + return ecmul([1, 2], a.foo(3)) + """ + c1 = side_effects_contract("uint256") + c2 = get_contract(code) + + assert c2.foo(c1.address) == G1_times_three + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) diff --git a/tests/parser/functions/test_floor.py b/tests/parser/functions/test_floor.py index dc53545ac3..d2fd993785 100644 --- a/tests/parser/functions/test_floor.py +++ b/tests/parser/functions/test_floor.py @@ -108,3 +108,37 @@ def floor_param(p: decimal) -> int256: assert c.fou() == -4 assert c.floor_param(Decimal("-5.6")) == -6 assert c.floor_param(Decimal("-0.0000000001")) == -1 + + +def test_floor_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> int256: + return floor(a.foo(2.5)) + +interface Foo: + def foo(x: decimal) -> decimal: nonpayable + """ + + c1 = side_effects_contract("decimal") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 2 + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_floor_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> int256: + return floor(self.bar()) + +@internal +def bar() -> decimal: + return 2.5 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 2 diff --git a/tests/parser/functions/test_mulmod.py b/tests/parser/functions/test_mulmod.py new file mode 100644 index 0000000000..1ea7a3f8e8 --- /dev/null +++ b/tests/parser/functions/test_mulmod.py @@ -0,0 +1,75 @@ +def test_uint256_mulmod(assert_tx_failed, get_contract_with_gas_estimation): + uint256_code = """ +@external +def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: + return uint256_mulmod(x, y, z) + """ + + c = get_contract_with_gas_estimation(uint256_code) + + assert c._uint256_mulmod(3, 1, 2) == 1 + assert c._uint256_mulmod(200, 3, 601) == 600 + assert c._uint256_mulmod(2**255, 1, 3) == 2 + assert c._uint256_mulmod(2**255, 2, 6) == 4 + assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) + + +def test_uint256_mulmod_complex(get_contract_with_gas_estimation): + modexper = """ +@external +def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: + o: uint256 = 1 + for i in range(256): + o = uint256_mulmod(o, o, modulus) + if exponent & shift(1, 255 - i) != 0: + o = uint256_mulmod(o, base, modulus) + return o + """ + + c = get_contract_with_gas_estimation(modexper) + assert c.exponential(3, 5, 100) == 43 + assert c.exponential(2, 997, 997) == 2 + + +def test_uint256_mulmod_ext_call( + w3, side_effects_contract, assert_side_effects_invoked, get_contract +): + code = """ +@external +def foo(f: Foo) -> uint256: + return uint256_mulmod(200, 3, f.foo(601)) + +interface Foo: + def foo(x: uint256) -> uint256: nonpayable + """ + + c1 = side_effects_contract("uint256") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 600 + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_uint256_mulmod_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return uint256_mulmod(self.a(), self.b(), self.c()) + +@internal +def a() -> uint256: + return 200 + +@internal +def b() -> uint256: + return 3 + +@internal +def c() -> uint256: + return 601 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 600 diff --git a/tests/parser/types/numbers/test_unsigned_ints.py b/tests/parser/types/numbers/test_unsigned_ints.py index 82c0f8484c..683684e6be 100644 --- a/tests/parser/types/numbers/test_unsigned_ints.py +++ b/tests/parser/types/numbers/test_unsigned_ints.py @@ -195,49 +195,6 @@ def foo(x: {typ}, y: {typ}) -> bool: assert c.foo(x, y) is expected -# TODO move to tests/parser/functions/test_mulmod.py and test_addmod.py -def test_uint256_mod(assert_tx_failed, get_contract_with_gas_estimation): - uint256_code = """ -@external -def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: - return uint256_addmod(x, y, z) - -@external -def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: - return uint256_mulmod(x, y, z) - """ - - c = get_contract_with_gas_estimation(uint256_code) - - assert c._uint256_addmod(1, 2, 2) == 1 - assert c._uint256_addmod(32, 2, 32) == 2 - assert c._uint256_addmod((2**256) - 1, 0, 2) == 1 - assert c._uint256_addmod(2**255, 2**255, 6) == 4 - assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) - assert c._uint256_mulmod(3, 1, 2) == 1 - assert c._uint256_mulmod(200, 3, 601) == 600 - assert c._uint256_mulmod(2**255, 1, 3) == 2 - assert c._uint256_mulmod(2**255, 2, 6) == 4 - assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) - - -def test_uint256_modmul(get_contract_with_gas_estimation): - modexper = """ -@external -def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: - o: uint256 = 1 - for i in range(256): - o = uint256_mulmod(o, o, modulus) - if exponent & (1 << (255 - i)) != 0: - o = uint256_mulmod(o, base, modulus) - return o - """ - - c = get_contract_with_gas_estimation(modexper) - assert c.exponential(3, 5, 100) == 43 - assert c.exponential(2, 997, 997) == 2 - - @pytest.mark.parametrize("typ", types) def test_uint_literal(get_contract, assert_compile_failed, typ): lo, hi = typ.ast_bounds diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index bfe90bb669..915f10ede3 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -148,15 +148,18 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - return IRnode.from_list( - [ - "if", - ["slt", args[0], 0], - ["sdiv", ["sub", args[0], DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], - ["sdiv", args[0], DECIMAL_DIVISOR], - ], - typ=INT256_T, - ) + arg = args[0] + with arg.cache_when_complex("arg") as (b1, arg): + ret = IRnode.from_list( + [ + "if", + ["slt", arg, 0], + ["sdiv", ["sub", arg, DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], + ["sdiv", arg, DECIMAL_DIVISOR], + ], + typ=INT256_T, + ) + return b1.resolve(ret) class Ceil(BuiltinFunction): @@ -175,15 +178,18 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - return IRnode.from_list( - [ - "if", - ["slt", args[0], 0], - ["sdiv", args[0], DECIMAL_DIVISOR], - ["sdiv", ["add", args[0], DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], - ], - typ=INT256_T, - ) + arg = args[0] + with arg.cache_when_complex("arg") as (b1, arg): + ret = IRnode.from_list( + [ + "if", + ["slt", arg, 0], + ["sdiv", arg, DECIMAL_DIVISOR], + ["sdiv", ["add", arg, DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], + ], + typ=INT256_T, + ) + return b1.resolve(ret) class Convert(BuiltinFunction): @@ -800,20 +806,25 @@ def build_IR(self, expr, args, kwargs, context): placeholder_node = IRnode.from_list( context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY ) - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(args[0], 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(args[0], 1)], - ["mstore", ["add", placeholder_node, 64], _getelem(args[1], 0)], - ["mstore", ["add", placeholder_node, 96], _getelem(args[1], 1)], - ["assert", ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64]], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return o + + with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): + o = IRnode.from_list( + [ + "seq", + ["mstore", placeholder_node, _getelem(a, 0)], + ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], + ["mstore", ["add", placeholder_node, 64], _getelem(b, 0)], + ["mstore", ["add", placeholder_node, 96], _getelem(b, 1)], + [ + "assert", + ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64], + ], + placeholder_node, + ], + typ=SArrayT(UINT256_T, 2), + location=MEMORY, + ) + return b2.resolve(b1.resolve(o)) class ECMul(BuiltinFunction): @@ -826,19 +837,24 @@ def build_IR(self, expr, args, kwargs, context): placeholder_node = IRnode.from_list( context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY ) - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(args[0], 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(args[0], 1)], - ["mstore", ["add", placeholder_node, 64], args[1]], - ["assert", ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64]], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return o + + with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): + o = IRnode.from_list( + [ + "seq", + ["mstore", placeholder_node, _getelem(a, 0)], + ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], + ["mstore", ["add", placeholder_node, 64], b], + [ + "assert", + ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64], + ], + placeholder_node, + ], + typ=SArrayT(UINT256_T, 2), + location=MEMORY, + ) + return b2.resolve(b1.resolve(o)) def _generic_element_getter(op): @@ -1030,34 +1046,35 @@ def build_IR(self, expr, args, kwargs, context): value = args[0] denom_divisor = self.get_denomination(expr) - if value.typ in (UINT256_T, UINT8_T): - sub = [ - "with", - "ans", - ["mul", value, denom_divisor], - [ - "seq", + with value.cache_when_complex("value") as (b1, value): + if value.typ in (UINT256_T, UINT8_T): + sub = [ + "with", + "ans", + ["mul", value, denom_divisor], [ - "assert", - ["or", ["eq", ["div", "ans", value], denom_divisor], ["iszero", value]], + "seq", + [ + "assert", + ["or", ["eq", ["div", "ans", value], denom_divisor], ["iszero", value]], + ], + "ans", ], - "ans", - ], - ] - elif value.typ == INT128_T: - # signed types do not require bounds checks because the - # largest possible converted value will not overflow 2**256 - sub = ["seq", ["assert", ["sgt", value, -1]], ["mul", value, denom_divisor]] - elif value.typ == DecimalT(): - sub = [ - "seq", - ["assert", ["sgt", value, -1]], - ["div", ["mul", value, denom_divisor], DECIMAL_DIVISOR], - ] - else: - raise CompilerPanic(f"Unexpected type: {value.typ}") + ] + elif value.typ == INT128_T: + # signed types do not require bounds checks because the + # largest possible converted value will not overflow 2**256 + sub = ["seq", ["assert", ["sgt", value, -1]], ["mul", value, denom_divisor]] + elif value.typ == DecimalT(): + sub = [ + "seq", + ["assert", ["sgt", value, -1]], + ["div", ["mul", value, denom_divisor], DECIMAL_DIVISOR], + ] + else: + raise CompilerPanic(f"Unexpected type: {value.typ}") - return IRnode.from_list(sub, typ=UINT256_T) + return IRnode.from_list(b1.resolve(sub), typ=UINT256_T) zero_value = IRnode.from_list(0, typ=UINT256_T) @@ -1516,9 +1533,13 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - return IRnode.from_list( - ["seq", ["assert", args[2]], [self._opcode, args[0], args[1], args[2]]], typ=UINT256_T - ) + c = args[2] + + with c.cache_when_complex("c") as (b1, c): + ret = IRnode.from_list( + ["seq", ["assert", c], [self._opcode, args[0], args[1], c]], typ=UINT256_T + ) + return b1.resolve(ret) class AddMod(_AddMulMod): From 27b6b893caad86c461c7980dd777abe42f716e19 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 May 2023 14:24:28 -0400 Subject: [PATCH 005/201] fix: calculate `active_evm_version` from `DEFAULT_EVM_VERSION` (#3427) per 1ac8362df0c, the `DEFAULT_EVM_VERSION` is updated to shanghai, while `active_evm_version` still points to paris. so entry points into the compiler which don't use the @evm_wrapper wrapper might continue using paris. this commit fixes the issue by calculating active_evm_version from `DEFAULT_EVM_VERSION` so that only one value needs to be updated going forward. --- vyper/evm/opcodes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 76529da14e..7ff56df772 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -3,8 +3,6 @@ from vyper.exceptions import CompilerPanic from vyper.typing import OpcodeGasCost, OpcodeMap, OpcodeRulesetMap, OpcodeRulesetValue, OpcodeValue -active_evm_version: int = 4 - # EVM version rules work as follows: # 1. Fork rules go from oldest (lowest value) to newest (highest value). # 2. Fork versions aren't actually tied to anything. They are not a part of our @@ -17,7 +15,7 @@ # 6. Yes, this will probably have to be rethought if there's ever conflicting support # between multiple chains for a specific feature. Let's hope not. # 7. We support at a maximum 3 hard forks (for any given chain). -EVM_VERSIONS: Dict[str, int] = { +EVM_VERSIONS: dict[str, int] = { # ETH Forks "byzantium": 0, "constantinople": 1, @@ -31,6 +29,7 @@ "agharta": 1, } DEFAULT_EVM_VERSION: str = "shanghai" +active_evm_version: int = EVM_VERSIONS[DEFAULT_EVM_VERSION] # opcode as hex value From 5b9bca24e3ac82654ce78ba66d5932ab609b28a8 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 May 2023 18:07:40 -0400 Subject: [PATCH 006/201] chore: fix badges in README (#3428) microbadger appears defunct, switch to shields.io switch lgtm to codeql --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f17e693bf5..af987ffd4f 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,15 @@ -[![Build Status](https://github.com/vyperlang/vyper/workflows/Test/badge.svg)](https://github.com/vyperlang/vyper/actions) +[![Build Status](https://github.com/vyperlang/vyper/workflows/Test/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/test.yml) [![Documentation Status](https://readthedocs.org/projects/vyper/badge/?version=latest)](http://vyper.readthedocs.io/en/latest/?badge=latest "ReadTheDocs") [![Discord](https://img.shields.io/discord/969926564286459934.svg?label=%23vyper)](https://discord.gg/6tw7PTM7C2) [![PyPI](https://badge.fury.io/py/vyper.svg)](https://pypi.org/project/vyper "PyPI") -[![Docker](https://images.microbadger.com/badges/version/vyperlang/vyper.svg)](https://hub.docker.com/r/vyperlang/vyper "DockerHub") +[![Docker](https://img.shields.io/docker/cloud/build/vyperlang/vyper)](https://hub.docker.com/r/vyperlang/vyper "DockerHub") [![Coverage Status](https://codecov.io/gh/vyperlang/vyper/branch/master/graph/badge.svg)](https://codecov.io/gh/vyperlang/vyper "Codecov") -[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/vyperlang/vyper.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/vyperlang/vyper/context:python) +[![Language grade: Python](https://github.com/vyperlang/vyper/workflows/CodeQL/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/codeql.yml) # Getting Started See [Installing Vyper](http://vyper.readthedocs.io/en/latest/installing-vyper.html) to install vyper. From ed0a654aa2f1069874c2c6d21b8932737aee3f6f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 May 2023 18:27:38 -0400 Subject: [PATCH 007/201] transient storage keyword (#3373) experimentally add support for transient storage via a new `transient` keyword, which works like `immutable` or `constant`, ex.: ```vyper my_transient_variable: transient(uint256) ``` this feature is considered experimental until py-evm adds support (giving us the ability to actually test it). so this commit leaves the default evm version as "shanghai" for now. it blocks the feature on pre-cancun EVM versions, so users can't use it by accident - the only way to use it is to explicitly enable it via `--evm-version=cancun`. --- tests/compiler/test_opcodes.py | 7 ++- tests/parser/ast_utils/test_ast_dict.py | 1 + .../features/decorators/test_nonreentrant.py | 2 + tests/parser/features/test_transient.py | 61 +++++++++++++++++++ vyper/ast/nodes.py | 18 ++++-- vyper/cli/vyper_compile.py | 3 +- vyper/codegen/context.py | 1 + vyper/codegen/core.py | 8 +-- vyper/codegen/expr.py | 6 +- vyper/codegen/function_definitions/utils.py | 10 ++- vyper/evm/address_space.py | 1 + vyper/evm/opcodes.py | 3 + vyper/semantics/analysis/base.py | 1 + vyper/semantics/analysis/module.py | 9 +++ vyper/semantics/data_locations.py | 2 + vyper/semantics/namespace.py | 1 + 16 files changed, 118 insertions(+), 16 deletions(-) create mode 100644 tests/parser/features/test_transient.py diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py index f36fcfac6f..3c595dee44 100644 --- a/tests/compiler/test_opcodes.py +++ b/tests/compiler/test_opcodes.py @@ -45,11 +45,14 @@ def test_version_check(evm_version): def test_get_opcodes(evm_version): ops = opcodes.get_opcodes() - if evm_version in ("paris", "berlin", "shanghai"): + if evm_version in ("paris", "berlin", "shanghai", "cancun"): assert "CHAINID" in ops assert ops["SLOAD"][-1] == 2100 - if evm_version in ("shanghai",): + if evm_version in ("shanghai", "cancun"): assert "PUSH0" in ops + if evm_version in ("cancun",): + assert "TLOAD" in ops + assert "TSTORE" in ops elif evm_version == "istanbul": assert "CHAINID" in ops assert ops["SLOAD"][-1] == 800 diff --git a/tests/parser/ast_utils/test_ast_dict.py b/tests/parser/ast_utils/test_ast_dict.py index 214af50f9f..f483d0cbe8 100644 --- a/tests/parser/ast_utils/test_ast_dict.py +++ b/tests/parser/ast_utils/test_ast_dict.py @@ -73,6 +73,7 @@ def test_basic_ast(): "is_constant": False, "is_immutable": False, "is_public": False, + "is_transient": False, } diff --git a/tests/parser/features/decorators/test_nonreentrant.py b/tests/parser/features/decorators/test_nonreentrant.py index 0577313b88..ac73b35bec 100644 --- a/tests/parser/features/decorators/test_nonreentrant.py +++ b/tests/parser/features/decorators/test_nonreentrant.py @@ -3,6 +3,8 @@ from vyper.exceptions import FunctionDeclarationException +# TODO test functions in this module across all evm versions +# once we have cancun support. def test_nonreentrant_decorator(get_contract, assert_tx_failed): calling_contract_code = """ interface SpecialContract: diff --git a/tests/parser/features/test_transient.py b/tests/parser/features/test_transient.py new file mode 100644 index 0000000000..53354beca8 --- /dev/null +++ b/tests/parser/features/test_transient.py @@ -0,0 +1,61 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.evm.opcodes import EVM_VERSIONS +from vyper.exceptions import StructureException + +post_cancun = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]} + + +@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS.keys())) +def test_transient_blocked(evm_version): + # test transient is blocked on pre-cancun and compiles post-cancun + code = """ +my_map: transient(HashMap[address, uint256]) + """ + if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["cancun"]: + assert compile_code(code, evm_version=evm_version) is not None + else: + with pytest.raises(StructureException): + compile_code(code, evm_version=evm_version) + + +@pytest.mark.parametrize("evm_version", list(post_cancun.keys())) +def test_transient_compiles(evm_version): + # test transient keyword at least generates TLOAD/TSTORE opcodes + getter_code = """ +my_map: public(transient(HashMap[address, uint256])) + """ + t = compile_code(getter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]) + t = t["opcodes_runtime"].split(" ") + + assert "TLOAD" in t + assert "TSTORE" not in t + + setter_code = """ +my_map: transient(HashMap[address, uint256]) + +@external +def setter(k: address, v: uint256): + self.my_map[k] = v + """ + t = compile_code(setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]) + t = t["opcodes_runtime"].split(" ") + + assert "TLOAD" not in t + assert "TSTORE" in t + + getter_setter_code = """ +my_map: public(transient(HashMap[address, uint256])) + +@external +def setter(k: address, v: uint256): + self.my_map[k] = v + """ + t = compile_code( + getter_setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"] + ) + t = t["opcodes_runtime"].split(" ") + + assert "TLOAD" in t + assert "TSTORE" in t diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 03f2d713c1..7c907b4d08 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1344,7 +1344,15 @@ class VariableDecl(VyperNode): If true, indicates that the variable is an immutable variable. """ - __slots__ = ("target", "annotation", "value", "is_constant", "is_public", "is_immutable") + __slots__ = ( + "target", + "annotation", + "value", + "is_constant", + "is_public", + "is_immutable", + "is_transient", + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1352,6 +1360,7 @@ def __init__(self, *args, **kwargs): self.is_constant = False self.is_public = False self.is_immutable = False + self.is_transient = False def _check_args(annotation, call_name): # do the same thing as `validate_call_args` @@ -1369,9 +1378,10 @@ def _check_args(annotation, call_name): # unwrap one layer self.annotation = self.annotation.args[0] - if self.annotation.get("func.id") in ("immutable", "constant"): - _check_args(self.annotation, self.annotation.func.id) - setattr(self, f"is_{self.annotation.func.id}", True) + func_id = self.annotation.get("func.id") + if func_id in ("immutable", "constant", "transient"): + _check_args(self.annotation, func_id) + setattr(self, f"is_{func_id}", True) # unwrap one layer self.annotation = self.annotation.args[0] diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 9ab884a6d0..4dfc87639a 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -101,7 +101,8 @@ def _parse_args(argv): ) parser.add_argument( "--evm-version", - help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION})", + help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION}). " + " note: cancun support is EXPERIMENTAL", choices=list(EVM_VERSIONS), default=DEFAULT_EVM_VERSION, dest="evm_version", diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 6e8d02c9b3..e4b41adbc0 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -28,6 +28,7 @@ class VariableRecord: defined_at: Any = None is_internal: bool = False is_immutable: bool = False + is_transient: bool = False data_offset: Optional[int] = None def __hash__(self): diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index a9a91ec9d8..06140f3f52 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,6 +1,6 @@ from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE +from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( @@ -562,10 +562,10 @@ def _get_element_ptr_mapping(parent, key): key = unwrap_location(key) # TODO when is key None? - if key is None or parent.location != STORAGE: - raise TypeCheckFailure(f"bad dereference on mapping {parent}[{key}]") + if key is None or parent.location not in (STORAGE, TRANSIENT): + raise TypeCheckFailure("bad dereference on mapping {parent}[{key}]") - return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=STORAGE) + return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=parent.location) # Take a value representing a memory or storage location, and descend down to diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 4a18a16e1b..ac7290836b 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -23,7 +23,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE +from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import ( CompilerPanic, @@ -259,10 +259,12 @@ def parse_Attribute(self): # self.x: global attribute elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": varinfo = self.context.globals[self.expr.attr] + location = TRANSIENT if varinfo.is_transient else STORAGE + ret = IRnode.from_list( varinfo.position.position, typ=varinfo.typ, - location=STORAGE, + location=location, annotation="self." + self.expr.attr, ) ret._referenced_variables = {varinfo} diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py index 7129388c58..f524ec6e88 100644 --- a/vyper/codegen/function_definitions/utils.py +++ b/vyper/codegen/function_definitions/utils.py @@ -8,6 +8,10 @@ def get_nonreentrant_lock(func_type): nkey = func_type.reentrancy_key_position.position + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + if version_check(begin="berlin"): # any nonzero values would work here (see pricing as of net gas # metering); these values are chosen so that downgrading to the @@ -16,12 +20,12 @@ def get_nonreentrant_lock(func_type): else: final_value, temp_value = 0, 1 - check_notset = ["assert", ["ne", temp_value, ["sload", nkey]]] + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] if func_type.mutability == StateMutability.VIEW: return [check_notset], [["seq"]] else: - pre = ["seq", check_notset, ["sstore", nkey, temp_value]] - post = ["sstore", nkey, final_value] + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] return [pre], [post] diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index 855e98b5c8..85a75c3c23 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -48,6 +48,7 @@ def byte_addressable(self) -> bool: MEMORY = AddrSpace("memory", 32, "mload", "mstore") STORAGE = AddrSpace("storage", 1, "sload", "sstore") +TRANSIENT = AddrSpace("transient", 1, "tload", "tstore") CALLDATA = AddrSpace("calldata", 32, "calldataload") # immutables address space: "immutables" section of memory # which is read-write in deploy code but then gets turned into diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 7ff56df772..c447fd863c 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -24,6 +24,7 @@ "berlin": 3, "paris": 4, "shanghai": 5, + "cancun": 6, # ETC Forks "atlantis": 0, "agharta": 1, @@ -184,6 +185,8 @@ "INVALID": (0xFE, 0, 0, 0), "DEBUG": (0xA5, 1, 0, 0), "BREAKPOINT": (0xA6, 0, 0, 0), + "TLOAD": (0xB3, 1, 1, 100), + "TSTORE": (0xB4, 2, 0, 100), } PSEUDO_OPCODES: OpcodeMap = { diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 5065131f29..449e6ca338 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -162,6 +162,7 @@ class VarInfo: is_constant: bool = False is_public: bool = False is_immutable: bool = False + is_transient: bool = False is_local_var: bool = False decl_node: Optional[vy_ast.VyperNode] = None diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 188005e365..cb8e93ff28 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -4,6 +4,7 @@ import vyper.builtins.interfaces from vyper import ast as vy_ast +from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, CompilerPanic, @@ -189,10 +190,17 @@ def visit_VariableDecl(self, node): if node.is_immutable else DataLocation.UNSET if node.is_constant + # XXX: needed if we want separate transient allocator + # else DataLocation.TRANSIENT + # if node.is_transient else DataLocation.STORAGE ) type_ = type_from_annotation(node.annotation, data_loc) + + if node.is_transient and not version_check(begin="cancun"): + raise StructureException("`transient` is not available pre-cancun", node.annotation) + var_info = VarInfo( type_, decl_node=node, @@ -200,6 +208,7 @@ def visit_VariableDecl(self, node): is_constant=node.is_constant, is_public=node.is_public, is_immutable=node.is_immutable, + is_transient=node.is_transient, ) node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 0ec374e42f..2f259b1766 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -7,3 +7,5 @@ class DataLocation(enum.Enum): STORAGE = 2 CALLDATA = 3 CODE = 4 + # XXX: needed for separate transient storage allocator + # TRANSIENT = 5 diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index d760f66972..82a5d5cf3e 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -176,6 +176,7 @@ def validate_identifier(attr): "nonpayable", "constant", "immutable", + "transient", "internal", "payable", "nonreentrant", From 903727006c1e5ebef99fa9fd5d51d62bd33d72a9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 19 May 2023 10:17:11 -0400 Subject: [PATCH 008/201] Merge pull request from GHSA-vxmm-cwh2-q762 on <=0.3.7, the batch payable check was broken. this was fixed due to the removal of the global calldatasize check in 02339dfda0. this commit adds a test to prevent regression --- .../features/decorators/test_payable.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/parser/features/decorators/test_payable.py b/tests/parser/features/decorators/test_payable.py index 906ae330c0..55c60236f4 100644 --- a/tests/parser/features/decorators/test_payable.py +++ b/tests/parser/features/decorators/test_payable.py @@ -372,3 +372,24 @@ def __default__(): assert_tx_failed( lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) ) + + +def test_batch_nonpayable(get_contract, w3, assert_tx_failed): + code = """ +@external +def foo() -> bool: + return True + +@external +def __default__(): + pass + """ + + c = get_contract(code) + w3.eth.send_transaction({"to": c.address, "value": 0, "data": "0x12345678"}) + data = bytes([1, 2, 3, 4]) + for i in range(5): + calldata = "0x" + data[:i].hex() + assert_tx_failed( + lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": calldata}) + ) From 32c9a3d70e066d9b4c31adb0a11c33ec1ee640bd Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 19 May 2023 23:04:47 +0800 Subject: [PATCH 009/201] chore: fix a comment (#3431) fix comment on TYPE_T --------- Co-authored-by: Charles Cooper --- vyper/semantics/types/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 8a174566eb..af955f6071 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -314,7 +314,9 @@ def __init__(self, typ, default, require_literal=False): self.require_literal = require_literal -# A type type. Only used internally for builtins +# A type type. Used internally for types which can live in expression +# position, ex. constructors (events, interfaces and structs), and also +# certain builtins which take types as parameters class TYPE_T: def __init__(self, typedef): self.typedef = typedef From 11e1ae9f8547632c4ecbed8565dccc082f12fd8f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 19 May 2023 14:36:28 -0400 Subject: [PATCH 010/201] chore: cache pip in setup-python (#3436) installing dependencies takes about 1min per job. pip caching should speed it up. note that in theory, this caches "correctly" in that it doesn't cache the dependency graph or install directories, just the wheels. so if upstream packages are updated, they should get reinstalled. --- .github/workflows/build.yml | 2 ++ .github/workflows/era-tester.yml | 1 + .github/workflows/publish.yml | 2 +- .github/workflows/test.yml | 6 ++++++ 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a3e9a195f6..43586c262a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,6 +34,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Generate Binary run: >- @@ -61,6 +62,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Generate Binary run: >- diff --git a/.github/workflows/era-tester.yml b/.github/workflows/era-tester.yml index 6c15e6af07..8c7e355d26 100644 --- a/.github/workflows/era-tester.yml +++ b/.github/workflows/era-tester.yml @@ -38,6 +38,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version[0] }} + cache: "pip" - name: Get cache id: get-cache diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5a8d989038..e6e5f2a6f9 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -18,7 +18,7 @@ jobs: - name: Python uses: actions/setup-python@v4 with: - python-version: '3.x' + python-version: "3.11" - name: Install dependencies run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f90ff706ec..94e8c7c8f6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Install Dependencies run: pip install .[lint] @@ -46,6 +47,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Install Tox run: pip install tox @@ -63,6 +65,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Install Tox run: pip install tox @@ -88,6 +91,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version[0] }} + cache: "pip" - name: Install Tox run: pip install tox @@ -130,6 +134,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Install Tox run: pip install tox @@ -171,6 +176,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.11" + cache: "pip" - name: Install Tox run: pip install tox From 8a28372f6d9f9e63dfa4c7ffcbf7fad4f8169117 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 19 May 2023 16:04:48 -0400 Subject: [PATCH 011/201] chore: make `FuncIRInfo` generation private (#3437) this moves generation of `func_t._ir_info` to be closer to where it is used (and where FuncIRInfo is defined!). since FuncIRInfo is no longer imported anywhere, it can be changed to a private member of the function_definitions/common.py module. --- vyper/codegen/function_definitions/__init__.py | 2 +- vyper/codegen/function_definitions/common.py | 5 ++++- vyper/codegen/module.py | 7 +------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index b677a14579..08bebbb4a5 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1 @@ -from .common import FuncIRInfo, generate_ir_for_function # noqa +from .common import generate_ir_for_function # noqa diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 45b97831aa..fd65b12265 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -28,7 +28,7 @@ def mem_used(self): @dataclass -class FuncIRInfo: +class _FuncIRInfo: func_t: ContractFunctionT gas_estimate: Optional[int] = None frame_info: Optional[FrameInfo] = None @@ -78,6 +78,9 @@ def generate_ir_for_function( """ func_t = code._metadata["type"] + # generate _FuncIRInfo + func_t._ir_info = _FuncIRInfo(func_t) + # Validate return statements. check_single_exit(code) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 5d05c27e0b..9bc589d82f 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,7 +4,7 @@ from vyper import ast as vy_ast from vyper.codegen.core import shr -from vyper.codegen.function_definitions import FuncIRInfo, generate_ir_for_function +from vyper.codegen.function_definitions import generate_ir_for_function from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.exceptions import CompilerPanic @@ -136,11 +136,6 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: init_function: Optional[vy_ast.FunctionDef] = None - # generate all FuncIRInfos - for f in function_defs: - func_t = f._metadata["type"] - func_t._ir_info = FuncIRInfo(func_t) - runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) From 870ad491c86be30c00c2404d85bae8368a7cf1d1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 19 May 2023 16:33:07 -0400 Subject: [PATCH 012/201] ci: deploy to ghcr on push (#3435) publish and tag docker images continuously to ghcr.io. adds custom tagging so we can retain every commit. it's technically possible to do this on docker hub, but in order to have custom tags, you need to set up a regular user and log in/push via that user. the authentication is much cleaner in github actions for ghcr. (note docker hub pulls for releases are still staying the same, this is just an alternative form of retention going forward). --- .github/workflows/ghcr.yml | 73 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 .github/workflows/ghcr.yml diff --git a/.github/workflows/ghcr.yml b/.github/workflows/ghcr.yml new file mode 100644 index 0000000000..d227d6caf0 --- /dev/null +++ b/.github/workflows/ghcr.yml @@ -0,0 +1,73 @@ +name: Deploy docker image to ghcr + +# Deploy docker image to ghcr on pushes to master and all releases/tags. +# Note releases to docker hub are managed separately in another process +# (github sends webhooks to docker hub which triggers the build there). +# This workflow is an alternative form of retention for docker images +# which also allows us to tag and retain every single commit to master. + +on: + push: + tags: + - '*' + branches: + - master + release: + types: [released] + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + deploy-ghcr: + + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + # need to fetch unshallow so that setuptools_scm can infer the version + fetch-depth: 0 + + - uses: actions/setup-python@v4 + name: Install python + with: + python-version: "3.11" + cache: "pip" + + - name: Generate vyper/version.py + run: | + pip install . + echo "VYPER_VERSION=$(PYTHONPATH=. python vyper/cli/vyper_compile.py --version)" >> "$GITHUB_ENV" + + - name: Docker meta + id: meta + uses: docker/metadata-action@v4 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + flavor: | + latest=true + tags: | + type=ref,event=branch + type=ref,event=tag + type=raw,value=${{ env.VYPER_VERSION }} + + - name: Login to ghcr.io + uses: docker/login-action@v2 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push + uses: docker/build-push-action@v4 + with: + context: . + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} From 95bf73f493dc8458a1d6981493275379197a4bdf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 20 May 2023 13:55:48 -0400 Subject: [PATCH 013/201] chore: fix up ghcr tags (#3438) use `latest` for latest release, `dev` for continuous, and tag dev builds with `-dev`. also remove the 'master' tag since that's redundant with `dev`. --- .github/workflows/ghcr.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ghcr.yml b/.github/workflows/ghcr.yml index d227d6caf0..4bc2885bdb 100644 --- a/.github/workflows/ghcr.yml +++ b/.github/workflows/ghcr.yml @@ -45,17 +45,21 @@ jobs: pip install . echo "VYPER_VERSION=$(PYTHONPATH=. python vyper/cli/vyper_compile.py --version)" >> "$GITHUB_ENV" + - name: generate tag suffix + if: ${{ github.event_name != 'release' }} + run: echo "VERSION_SUFFIX=-dev" >> "$GITHUB_ENV" + - name: Docker meta id: meta uses: docker/metadata-action@v4 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - flavor: | - latest=true tags: | - type=ref,event=branch type=ref,event=tag - type=raw,value=${{ env.VYPER_VERSION }} + type=raw,value=${{ env.VYPER_VERSION }}${{ env.VERSION_SUFFIX }} + type=raw,value=dev,enable=${{ github.ref == 'refs/heads/master' }} + type=raw,value=latest,enable=${{ github.event_name == 'release' }} + - name: Login to ghcr.io uses: docker/login-action@v2 From 0ed5c23ca43cd299f4e62c262df31038806ac164 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 23 May 2023 10:26:06 -0400 Subject: [PATCH 014/201] chore: add v0.3.8 release notes (#3439) --- docs/release-notes.rst | 107 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 89a528dc49..cf4d8d42f9 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -3,6 +3,113 @@ Release Notes ############# +.. + vim regexes: + first convert all single backticks to double backticks: + :'<,'>s/`/``/g + to convert links to nice rst links: + :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/pull\/)(\d+)/(`#\2 <\1\2>`_)/g + ex. in: https://github.com/vyperlang/vyper/pull/3373 + ex. out: (`#3373 `_) + for advisory links: + :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g + +v0.3.8 +****** + +Date released: 2023-05-23 + +Non-breaking changes and improvements: + +- ``transient`` storage keyword (`#3373 `_) +- ternary operators (`#3398 `_) +- ``raw_revert()`` builtin (`#3136 `_) +- shift operators (`#3019 `_) +- make ``send()`` gas stipend configurable (`#3158 `_) +- use new ``push0`` opcode (`#3361 `_) +- python 3.11 support (`#3129 `_) +- drop support for python 3.8 and 3.9 (`#3325 `_) +- build for ``aarch64`` (`#2687 `_) + +Major refactoring PRs: + +- refactor front-end type system (`#2974 `_) +- merge front-end and codegen type systems (`#3182 `_) +- simplify ``GlobalContext`` (`#3209 `_) +- remove ``FunctionSignature`` (`#3390 `_) + +Notable fixes: + +- assignment when rhs is complex type and references lhs (`#3410 `_) +- uninitialized immutable values (`#3409 `_) +- success value when mixing ``max_outsize=0`` and ``revert_on_failure=False`` (`GHSA-w9g2-3w7p-72g9 `_) +- block certain kinds of storage allocator overflows (`GHSA-mgv8-gggw-mrg6 `_) +- store-before-load when a dynarray appears on both sides of an assignment (`GHSA-3p37-3636-q8wv `_) +- bounds check for loops of the form ``for i in range(x, x+N)`` (`GHSA-6r8q-pfpv-7cgj `_) +- alignment of call-site posargs and kwargs for internal functions (`GHSA-ph9x-4vc9-m39g `_) +- batch nonpayable check for default functions calldatasize < 4 (`#3104 `_, `#3408 `_, cf. `GHSA-vxmm-cwh2-q762 `_) + +Other docs updates, chores and fixes: + +- call graph stability (`#3370 `_) +- fix ``vyper-serve`` output (`#3338 `_) +- add ``custom:`` natspec tags (`#3403 `_) +- add missing pc maps to ``vyper_json`` output (`#3333 `_) +- fix constructor context for internal functions (`#3388 `_) +- add deprecation warning for ``selfdestruct`` usage (`#3372 `_) +- add bytecode metadata option to vyper-json (`#3117 `_) +- fix compiler panic when a ``break`` is outside of a loop (`#3177 `_) +- fix complex arguments to builtin functions (`#3167 `_) +- add support for all types in ABI imports (`#3154 `_) +- disable uadd operator (`#3174 `_) +- block bitwise ops on decimals (`#3219 `_) +- raise ``UNREACHABLE`` (`#3194 `_) +- allow enum as mapping key (`#3256 `_) +- block boolean ``not`` operator on numeric types (`#3231 `_) +- enforce that loop's iterators are valid names (`#3242 `_) +- fix typechecker hotspot (`#3318 `_) +- rewrite typechecker journal to handle nested commits (`#3375 `_) +- fix missing pc map for empty functions (`#3202 `_) +- guard against iterating over empty list in for loop (`#3197 `_) +- skip enum members during constant folding (`#3235 `_) +- bitwise ``not`` constant folding (`#3222 `_) +- allow accessing members of constant address (`#3261 `_) +- guard against decorators in interface (`#3266 `_) +- fix bounds for decimals in some builtins (`#3283 `_) +- length of literal empty bytestrings (`#3276 `_) +- block ``empty()`` for HashMaps (`#3303 `_) +- fix type inference for empty lists (`#3377 `_) +- disallow logging from ``pure``, ``view`` functions (`#3424 `_) +- improve optimizer rules for comparison operators (`#3412 `_) +- deploy to ghcr on push (`#3435 `_) +- add note on return value bounds in interfaces (`#3205 `_) +- index ``id`` param in ``URI`` event of ``ERC1155ownable`` (`#3203 `_) +- add missing ``asset`` function to ``ERC4626`` built-in interface (`#3295 `_) +- clarify ``skip_contract_check=True`` can result in undefined behavior (`#3386 `_) +- add ``custom`` NatSpec tag to docs (`#3404 `_) +- fix ``uint256_addmod`` doc (`#3300 `_) +- document optional kwargs for external calls (`#3122 `_) +- remove ``slice()`` length documentation caveats (`#3152 `_) +- fix docs of ``blockhash`` to reflect revert behaviour (`#3168 `_) +- improvements to compiler error messages (`#3121 `_, `#3134 `_, `#3312 `_, `#3304 `_, `#3240 `_, `#3264 `_, `#3343 `_, `#3307 `_, `#3313 `_ and `#3215 `_) + +These are really just the highlights, as many other bugfixes, docs updates and refactoring (over 150 pull requests!) made it into this release! For the full list, please see the `changelog `_. Special thanks to contributions from @tserg, @trocher, @z80dev, @emc415 and @benber86 in this release! + +New Contributors: + +- @omahs made their first contribution in (`#3128 `_) +- @ObiajuluM made their first contribution in (`#3124 `_) +- @trocher made their first contribution in (`#3134 `_) +- @ozmium22 made their first contribution in (`#3149 `_) +- @ToonVanHove made their first contribution in (`#3168 `_) +- @emc415 made their first contribution in (`#3158 `_) +- @lgtm-com made their first contribution in (`#3147 `_) +- @tdurieux made their first contribution in (`#3224 `_) +- @victor-ego made their first contribution in (`#3263 `_) +- @miohtama made their first contribution in (`#3257 `_) +- @kelvinfan001 made their first contribution in (`#2687 `_) + + v0.3.7 ****** From 036f153683e0d55b890305eb4c77680a0872fcba Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 23 May 2023 10:48:21 -0400 Subject: [PATCH 015/201] chore: clean up 0.3.7 release notes formatting (#3444) --- docs/release-notes.rst | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index cf4d8d42f9..3e7bc02587 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -115,29 +115,29 @@ v0.3.7 Date released: 2022-09-26 -## Breaking changes: +Breaking changes: - chore: drop python 3.7 support (`#3071 `_) - fix: relax check for statically sized calldata (`#3090 `_) -## Non-breaking changes and improvements: +Non-breaking changes and improvements: -- fix: assert description in Crowdfund.finalize() (`#3058 `_) +- fix: assert description in ``Crowdfund.finalize()`` (`#3058 `_) - fix: change mutability of example ERC721 interface (`#3076 `_) - chore: improve error message for non-checksummed address literal (`#3065 `_) -- feat: isqrt built-in (`#3074 `_) (`#3069 `_) -- feat: add `block.prevrandao` as alias for `block.difficulty` (`#3085 `_) -- feat: epsilon builtin (`#3057 `_) +- feat: ``isqrt()`` builtin (`#3074 `_) (`#3069 `_) +- feat: add ``block.prevrandao`` as alias for ``block.difficulty`` (`#3085 `_) +- feat: ``epsilon()`` builtin (`#3057 `_) - feat: extend ecrecover signature to accept additional parameter types (`#3084 `_) - feat: allow constant and immutable variables to be declared public (`#3024 `_) - feat: optionally disable metadata in bytecode (`#3107 `_) -## Bugfixes: +Bugfixes: - fix: empty nested dynamic arrays (`#3061 `_) - fix: foldable builtin default args in imports (`#3079 `_) (`#3077 `_) -## Additional changes and improvements: +Additional changes and improvements: - doc: update broken links in SECURITY.md (`#3095 `_) - chore: update discord link in docs (`#3031 `_) @@ -147,7 +147,7 @@ Date released: 2022-09-26 - chore: migrate lark grammar (`#3082 `_) - chore: loosen and upgrade semantic version (`#3106 `_) -# New Contributors +New Contributors - @emilianobonassi made their first contribution in `#3107 `_ - @unparalleled-js made their first contribution in `#3106 `_ From 71c8e55b7ca6b5cef02411c006c8cdc3f0b0a8e1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 23 May 2023 17:48:25 -0400 Subject: [PATCH 016/201] chore: build for old ubuntus (#3453) python3.11 uses a new libc which is not compatible with ubuntu 20.04. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 43586c262a..f2b63e9967 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,7 +20,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-20.04, macos-latest] steps: - uses: actions/checkout@v2 From 510125e0fce389fcc2b9993691696eb0836345b6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 23 May 2023 17:53:31 -0400 Subject: [PATCH 017/201] fix: initcode codesize regression (#3450) this commit fixes a regression in c202c4e3ec8. the commit message states that we rely on the dead code eliminator to prune unused internal functions in the initcode, but the dead code eliminator does not prune dead code in all cases, including nested internal functions and loops. this commit reintroduces the reachability analysis in `vyper/codegen/module.py` as a stopgap until the dead code eliminator is more robust. --- tests/compiler/asm/test_asm_optimizer.py | 83 +++++++++++++++---- .../parser/functions/test_create_functions.py | 24 ++++-- vyper/codegen/module.py | 7 +- vyper/ir/compile_ir.py | 13 ++- 4 files changed, 100 insertions(+), 27 deletions(-) diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/compiler/asm/test_asm_optimizer.py index b82d568ff8..f4a245e168 100644 --- a/tests/compiler/asm/test_asm_optimizer.py +++ b/tests/compiler/asm/test_asm_optimizer.py @@ -1,49 +1,102 @@ -from vyper.compiler.phases import CompilerData +import pytest +from vyper.compiler.phases import CompilerData -def test_dead_code_eliminator(): - code = """ +codes = [ + """ s: uint256 @internal -def foo(): +def ctor_only(): self.s = 1 @internal -def qux(): +def runtime_only(): self.s = 2 +@external +def bar(): + self.runtime_only() + +@external +def __init__(): + self.ctor_only() + """, + # code with nested function in it + """ +s: uint256 + +@internal +def runtime_only(): + self.s = 1 + +@internal +def foo(): + self.runtime_only() + +@internal +def ctor_only(): + self.s += 1 + @external def bar(): self.foo() @external def __init__(): - self.qux() + self.ctor_only() + """, + # code with loop in it, these are harder for dead code eliminator """ +s: uint256 + +@internal +def ctor_only(): + self.s = 1 + +@internal +def runtime_only(): + for i in range(10): + self.s += 1 +@external +def bar(): + self.runtime_only() + +@external +def __init__(): + self.ctor_only() + """, +] + + +@pytest.mark.parametrize("code", codes) +def test_dead_code_eliminator(code): c = CompilerData(code, no_optimize=True) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime - foo_label = "_sym_internal_foo___" - qux_label = "_sym_internal_qux___" + ctor_only_label = "_sym_internal_ctor_only___" + runtime_only_label = "_sym_internal_runtime_only___" + + # qux reachable from unoptimized initcode, foo not reachable. + assert ctor_only_label + "_deploy" in initcode_asm + assert runtime_only_label + "_deploy" not in initcode_asm - # all the labels should be in all the unoptimized asms - for s in (foo_label, qux_label): - assert s + "_deploy" in initcode_asm + # all labels should be in unoptimized runtime asm + for s in (ctor_only_label, runtime_only_label): assert s + "_runtime" in runtime_asm c = CompilerData(code, no_optimize=False) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime - # qux should not be in runtime code + # ctor only label should not be in runtime code for instr in runtime_asm: if isinstance(instr, str): - assert not instr.startswith(qux_label), instr + assert not instr.startswith(ctor_only_label), instr - # foo should not be in initcode asm + # runtime only label should not be in initcode asm for instr in initcode_asm: if isinstance(instr, str): - assert not instr.startswith(foo_label), instr + assert not instr.startswith(runtime_only_label), instr diff --git a/tests/parser/functions/test_create_functions.py b/tests/parser/functions/test_create_functions.py index 857173df7e..64e0823146 100644 --- a/tests/parser/functions/test_create_functions.py +++ b/tests/parser/functions/test_create_functions.py @@ -3,6 +3,8 @@ from eth.codecs import abi from hexbytes import HexBytes +import vyper.ir.compile_ir as compile_ir +from vyper.codegen.ir_node import IRnode from vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256 @@ -224,15 +226,23 @@ def test(code_ofst: uint256) -> address: return create_from_blueprint(BLUEPRINT, code_offset=code_ofst) """ - # use a bunch of JUMPDEST + STOP instructions as blueprint code - # (as any STOP instruction returns valid code, split up with - # jumpdests as optimization fence) initcode_len = 100 - f = get_contract_from_ir(["deploy", 0, ["seq"] + ["jumpdest", "stop"] * (initcode_len // 2), 0]) - blueprint_code = w3.eth.get_code(f.address) - print(blueprint_code) - d = get_contract(deployer_code, f.address) + # deploy a blueprint contract whose contained initcode contains only + # zeroes (so no matter which offset, create_from_blueprint will + # return empty code) + ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0]) + bytecode, _ = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir, no_optimize=True)) + # manually deploy the bytecode + c = w3.eth.contract(abi=[], bytecode=bytecode) + deploy_transaction = c.constructor() + tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0} + tx_hash = deploy_transaction.transact(tx_info) + blueprint_address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] + blueprint_code = w3.eth.get_code(blueprint_address) + print("BLUEPRINT CODE:", blueprint_code) + + d = get_contract(deployer_code, blueprint_address) # deploy with code_ofst=0 fine d.test(0) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 9bc589d82f..2fece47a9e 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -123,7 +123,6 @@ def _runtime_ir(runtime_functions, global_ctx): ["label", "fallback", ["var_list"], fallback_ir], ] - # note: dead code eliminator will clean dead functions runtime.extend(internal_functions_ir) return runtime @@ -178,10 +177,14 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # internal functions come after everything else internal_functions = [f for f in runtime_functions if _is_internal(f)] for f in internal_functions: + init_func_t = init_function._metadata["type"] + if f.name not in init_func_t.recursive_calls: + # unreachable + continue + func_ir = generate_ir_for_function( f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True ) - # note: we depend on dead code eliminator to clean dead function defs deploy_code.append(func_ir) else: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 57ea4ca7e7..b2a58fa8c9 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -758,6 +758,9 @@ def note_breakpoint(line_number_map, item, pos): line_number_map["breakpoints"].add(item.lineno + 1) +_TERMINAL_OPS = ("JUMP", "RETURN", "REVERT", "STOP", "INVALID") + + def _prune_unreachable_code(assembly): # In converting IR to assembly we sometimes end up with unreachable # instructions - POPing to clear the stack or STOPing execution at the @@ -766,9 +769,13 @@ def _prune_unreachable_code(assembly): # to avoid unnecessary bytecode bloat. changed = False i = 0 - while i < len(assembly) - 1: - if assembly[i] in ("JUMP", "RETURN", "REVERT", "STOP") and not ( - is_symbol(assembly[i + 1]) or assembly[i + 1] == "JUMPDEST" + while i < len(assembly) - 2: + instr = assembly[i] + if isinstance(instr, list): + instr = assembly[i][-1] + + if assembly[i] in _TERMINAL_OPS and not ( + is_symbol(assembly[i + 1]) and assembly[i + 2] in ("JUMPDEST", "BLANK") ): changed = True del assembly[i + 1] From b3c7efa38eef63d9d979ad663a9ecfd77e343987 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 23 May 2023 18:31:18 -0400 Subject: [PATCH 018/201] add v0.3.9 release notes (#3452) --- docs/release-notes.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 3e7bc02587..06bb29d839 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -14,6 +14,18 @@ Release Notes for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g +v0.3.9 +****** + +Date released: 2023-05-23 + +This is a patch release fix for v0.3.8. @bout3fiddy discovered a codesize regression for blueprint contracts in v0.3.8 which is fixed in this release. + +Fixes: + +- initcode codesize blowup (`#3450 `_) + + v0.3.8 ****** From 19073dadb3c0d661f4bddfa7de7852d85835484e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 24 May 2023 10:08:08 -0400 Subject: [PATCH 019/201] chore: update v0.3.9 release date (#3455) --- docs/release-notes.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 06bb29d839..a8bd309bf1 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -17,7 +17,7 @@ Release Notes v0.3.9 ****** -Date released: 2023-05-23 +Date released: 2023-05-24 This is a patch release fix for v0.3.8. @bout3fiddy discovered a codesize regression for blueprint contracts in v0.3.8 which is fixed in this release. From 070b0cfba0bf32da66b8640740d54d9e2a850833 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 25 May 2023 10:29:27 -0400 Subject: [PATCH 020/201] fix: add error message for send() builtin (#3459) --- vyper/builtins/functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 915f10ede3..af965afe0a 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1247,7 +1247,9 @@ def build_IR(self, expr, args, kwargs, context): to, value = args gas = kwargs["gas"] context.check_is_not_constant("send ether", expr) - return IRnode.from_list(["assert", ["call", gas, to, value, 0, 0, 0, 0]]) + return IRnode.from_list( + ["assert", ["call", gas, to, value, 0, 0, 0, 0]], error_msg="send failed" + ) class SelfDestruct(BuiltinFunction): From 056dfd84878bf3892a42eea33415d792ac0d29c1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 25 May 2023 22:37:35 -0400 Subject: [PATCH 021/201] fix: `source_map_full` output in `vyper-json` (#3464) source_map_full output fails with a KeyError when user requests evm.sourceMapFull but not evm.sourceMap. this commit fixes. --- vyper/cli/vyper_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index aa6cf1c2f5..2fbf58aec4 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -429,7 +429,7 @@ def format_to_output_dict(compiler_data: Dict) -> Dict: if "source_map" in data: evm["sourceMap"] = data["source_map"]["pc_pos_map_compressed"] if "source_map_full" in data: - evm["sourceMapFull"] = data["source_map"] + evm["sourceMapFull"] = data["source_map_full"] return output_dict From 64733b9d15935ecd2bfcfdfbb9606d5ab500d70c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 26 May 2023 11:36:30 -0400 Subject: [PATCH 022/201] fix: add error message for nonpayable check (#3466) --- vyper/codegen/function_definitions/external_function.py | 5 ++++- vyper/codegen/module.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 6104a86c16..312cb75cf8 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -200,7 +200,10 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che if not func_t.is_payable and not skip_nonpayable_check: # if the contract contains payable functions, but this is not one of them # add an assertion that the value of the call is zero - body += [["assert", ["iszero", "callvalue"]]] + nonpayable_check = IRnode.from_list( + ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" + ) + body.append(nonpayable_check) body += nonreentrant_pre diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 2fece47a9e..2d498460be 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -97,7 +97,10 @@ def _runtime_ir(runtime_functions, global_ctx): selector_section.append(func_ir) if batch_payable_check: - selector_section.append(["assert", ["iszero", "callvalue"]]) + nonpayable_check = IRnode.from_list( + ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" + ) + selector_section.append(nonpayable_check) for func_ast in nonpayables: func_ir = generate_ir_for_function(func_ast, global_ctx, skip_nonpayable_check) From 33c247151cfed13999289c08f3c35d70fa83394d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 28 May 2023 10:33:50 -0400 Subject: [PATCH 023/201] fix: add back global calldatasize check (#3463) prevent a performance regression for sending eth to contracts with a payable default function by (mostly) reverting the changes introduced in 9ecb97b4b6f and 02339dfda0. the `skip_nonpayable_check=False` for the default function is introduced to address GHSA-vxmm-cwh2-q762 (which 02339dfda0 inadvertently fixed, and a test for which was added in 903727006c). --- .../function_definitions/external_function.py | 25 +------------------ vyper/codegen/module.py | 7 +++++- 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 312cb75cf8..207356860b 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -89,8 +89,7 @@ def handler_for(calldata_kwargs, default_kwargs): calldata_min_size = args_abi_t.min_size() + 4 # note we don't need the check if calldata_min_size == 4, - # because the selector checks later in this routine ensure - # that calldatasize >= 4. + # because the global calldatasize check ensures that already. if calldata_min_size > 4: ret.append(["assert", ["ge", "calldatasize", calldata_min_size]]) @@ -125,28 +124,6 @@ def handler_for(calldata_kwargs, default_kwargs): ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) method_id_check = ["eq", "_calldata_method_id", method_id] - - # if there is a function whose selector is 0 or has trailing 0s, it - # might not be distinguished from the case where insufficient calldata - # is supplied, b/c calldataload loads 0s past the end of physical - # calldata (cf. yellow paper). - # since the expected behavior of supplying insufficient calldata - # is to trigger the fallback fn, we add to the selector check that - # calldatasize >= 4, which distinguishes any selector with trailing - # 0 bytes from the fallback function "selector" (equiv. to "all - # selectors not in the selector table"). - # - # note that the inclusion of this check means that, we are always - # guaranteed that the calldata is at least 4 bytes - either we have - # the explicit `calldatasize >= 4` condition in the selector check, - # or there are no trailing zeroes in the selector, (so the selector - # is impossible to match without calldatasize being at least 4). - method_id_bytes = util.method_id(abi_sig) - assert len(method_id_bytes) == 4 - has_trailing_zeroes = method_id_bytes.endswith(b"\x00") - if has_trailing_zeroes: - method_id_check = ["and", ["ge", "calldatasize", 4], method_id_check] - ret = ["if", method_id_check, ret] return ret diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 2d498460be..64d5a8b70c 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -107,7 +107,9 @@ def _runtime_ir(runtime_functions, global_ctx): selector_section.append(func_ir) if default_function: - fallback_ir = generate_ir_for_function(default_function, global_ctx, skip_nonpayable_check) + fallback_ir = generate_ir_for_function( + default_function, global_ctx, skip_nonpayable_check=False + ) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -119,8 +121,11 @@ def _runtime_ir(runtime_functions, global_ctx): # fallback label is the immediate next instruction, close_selector_section = ["goto", "fallback"] + global_calldatasize_check = ["if", ["lt", "calldatasize", 4], ["goto", "fallback"]] + runtime = [ "seq", + global_calldatasize_check, ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section], close_selector_section, ["label", "fallback", ["var_list"], fallback_ir], From 5c2892b2b4f6cdbc039b0f70ecd0e7058fed521c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 28 May 2023 10:36:16 -0400 Subject: [PATCH 024/201] chore: update v0.3.9 release notes (#3458) --- docs/release-notes.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index a8bd309bf1..6d1d35b1e2 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -17,13 +17,14 @@ Release Notes v0.3.9 ****** -Date released: 2023-05-24 +Date released: 2023-05-28 -This is a patch release fix for v0.3.8. @bout3fiddy discovered a codesize regression for blueprint contracts in v0.3.8 which is fixed in this release. +This is a patch release fix for v0.3.8. @bout3fiddy discovered a codesize regression for blueprint contracts in v0.3.8 which is fixed in this release. @bout3fiddy also discovered a runtime performance (gas) regression for default functions in v0.3.8 which is fixed in this release. Fixes: - initcode codesize blowup (`#3450 `_) +- add back global calldatasize check for contracts with default fn (`#3463 `_) v0.3.8 From 66b9670555cbb57a78b9113e1c2ad343111df1b3 Mon Sep 17 00:00:00 2001 From: ControlCplusControlV <44706811+ControlCplusControlV@users.noreply.github.com> Date: Mon, 29 May 2023 08:50:34 -0600 Subject: [PATCH 025/201] docs: add name to 0.3.9 release (#3468) --------- Co-authored-by: Charles Cooper --- docs/release-notes.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 6d1d35b1e2..22d89614db 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -14,10 +14,10 @@ Release Notes for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g -v0.3.9 +v0.3.9 ("Common Adder") ****** -Date released: 2023-05-28 +Date released: 2023-05-29 This is a patch release fix for v0.3.8. @bout3fiddy discovered a codesize regression for blueprint contracts in v0.3.8 which is fixed in this release. @bout3fiddy also discovered a runtime performance (gas) regression for default functions in v0.3.8 which is fixed in this release. From e97e9f2a2fd7c6c062574997b622a41539926930 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 30 May 2023 15:28:44 -0400 Subject: [PATCH 026/201] ci: auto-upload release assets on release event (#3469) note: used download-artifact to upload the artifacts since we don't want to learn how to write a windows script in the windows-build workflow also: * removed the ghcr build on tag events (since it's redundant with the release event and we end up with orphaned images, since two images get tagged with ex. v0.3.9) * rename some workflows for clarity --- .github/workflows/build.yml | 37 +++++++++++++++++++++++++++----- .github/workflows/era-tester.yml | 2 +- .github/workflows/ghcr.yml | 2 -- .github/workflows/publish.yml | 4 ++-- .github/workflows/test.yml | 2 +- 5 files changed, 36 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f2b63e9967..f891ff7e1d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,4 @@ -name: Artifacts +name: Build and release artifacts on: workflow_dispatch: @@ -6,10 +6,10 @@ on: tag: default: '' push: - tags: - - '*' branches: - master + release: + types: [released] defaults: run: @@ -23,7 +23,7 @@ jobs: os: [ubuntu-20.04, macos-latest] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: # grab the commit passed in via `tag`, if any ref: ${{ github.event.inputs.tag }} @@ -51,7 +51,7 @@ jobs: runs-on: windows-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: # grab the commit passed in via `tag`, if any ref: ${{ github.event.inputs.tag }} @@ -74,3 +74,30 @@ jobs: uses: actions/upload-artifact@v3 with: path: dist/vyper.* + + publish-release-assets: + needs: [windows-build, unix-build] + if: ${{ github.event_name == 'release' }} + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions/download-artifact@v3 + with: + path: artifacts/ + + - name: Upload assets + # fun - artifacts are downloaded into "artifact/". + working-directory: artifacts/artifact + run: | + set -Eeuxo pipefail + for BIN_NAME in $(ls) + do + curl -L \ + --no-progress-meter \ + -X POST \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}"\ + -H "Content-Type: application/octet-stream" \ + "https://uploads.github.com/repos/${{ github.repository }}/releases/${{ github.event.release.id }}/assets?name=${BIN_NAME}" \ + --data-binary "@${BIN_NAME}" + done diff --git a/.github/workflows/era-tester.yml b/.github/workflows/era-tester.yml index 8c7e355d26..a693d2c97d 100644 --- a/.github/workflows/era-tester.yml +++ b/.github/workflows/era-tester.yml @@ -1,4 +1,4 @@ -name: era compiler tester +name: Era compiler tester # run the matter labs compiler test to integrate their test cases # this is intended as a diagnostic / spot check to check that we diff --git a/.github/workflows/ghcr.yml b/.github/workflows/ghcr.yml index 4bc2885bdb..a35a22e278 100644 --- a/.github/workflows/ghcr.yml +++ b/.github/workflows/ghcr.yml @@ -8,8 +8,6 @@ name: Deploy docker image to ghcr on: push: - tags: - - '*' branches: - master release: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e6e5f2a6f9..44c6978295 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,7 +1,7 @@ -# This workflows will upload a Python Package using Twine when a release is created +# This workflow will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries -name: Publish +name: Publish to PyPI on: release: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 94e8c7c8f6..42e0524b13 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: Test +name: Run test suite on: [push, pull_request] From 981fcdabd66c5823c5a2f00e02b57413a7dd4ad3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 1 Jun 2023 00:09:33 +0800 Subject: [PATCH 027/201] fix: type inference for ternary operator literals (#3460) --- tests/parser/syntax/test_ternary.py | 5 +++++ vyper/semantics/analysis/annotation.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/parser/syntax/test_ternary.py b/tests/parser/syntax/test_ternary.py index 11c06051d0..325be3e43b 100644 --- a/tests/parser/syntax/test_ternary.py +++ b/tests/parser/syntax/test_ternary.py @@ -10,6 +10,11 @@ def foo(a: uint256, b: uint256) -> uint256: return a if a > b else b """, + """ +@external +def foo(): + a: bool = (True if True else True) or True + """, # different locations: """ b: uint256 diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index e501be5fdb..3ea0319b54 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -271,7 +271,7 @@ def visit_UnaryOp(self, node, type_): def visit_IfExp(self, node, type_): if type_ is None: ts = get_common_types(node.body, node.orelse) - if len(type_) == 1: + if len(ts) == 1: type_ = ts.pop() node._metadata["type"] = type_ From 7f18aeee59abbf3f4657edc94a8b354731cce19b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 31 May 2023 13:00:13 -0400 Subject: [PATCH 028/201] docs: shanghai is default compilation target (#3474) in v0.3.9 release notes --- docs/release-notes.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 22d89614db..dcdbcda74a 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -44,6 +44,8 @@ Non-breaking changes and improvements: - drop support for python 3.8 and 3.9 (`#3325 `_) - build for ``aarch64`` (`#2687 `_) +Note that with the addition of ``push0`` opcode, ``shanghai`` is now the default compilation target for vyper. When deploying to a chain which does not support ``shanghai``, it is recommended to set ``--evm-version`` to ``paris``, otherwise it could result in hard-to-debug errors. + Major refactoring PRs: - refactor front-end type system (`#2974 `_) From 07f3cb091a30d06d45e626f49e8189260153aaa6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 1 Jun 2023 15:46:37 -0400 Subject: [PATCH 029/201] chore: drop evm versions through istanbul (#3470) drop pre-istanbul versions. per VIP-3365, we could drop through paris, but since this is the first time starting to enforce this policy, we don't want to drop too many versions at once. --- docs/compiling-a-contract.rst | 29 +++++---- tests/cli/vyper_compile/test_compile_files.py | 26 -------- .../test_compile_from_input_dict.py | 9 --- tests/cli/vyper_json/test_get_settings.py | 14 ++++- tests/compiler/test_opcodes.py | 33 +++++----- tests/parser/functions/test_bitwise.py | 13 +--- tests/parser/syntax/test_chainid.py | 16 ++--- tests/parser/syntax/test_codehash.py | 7 --- vyper/cli/vyper_compile.py | 2 +- vyper/cli/vyper_json.py | 10 ++- vyper/codegen/arithmetic.py | 11 +--- vyper/codegen/core.py | 14 +---- vyper/codegen/expr.py | 5 -- vyper/codegen/external_call.py | 3 +- vyper/evm/opcodes.py | 63 +++++++------------ vyper/ir/optimizer.py | 9 +-- 16 files changed, 95 insertions(+), 169 deletions(-) diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 4a03347536..36d0c8cb74 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -140,24 +140,29 @@ Target Options The following is a list of supported EVM versions, and changes in the compiler introduced with each version. Backward compatibility is not guaranteed between each version. -.. py:attribute:: byzantium +.. py:attribute:: istanbul - - The oldest EVM version supported by Vyper. + - The ``CHAINID`` opcode is accessible via ``chain.id`` + - The ``SELFBALANCE`` opcode is used for calls to ``self.balance`` + - Gas estimates changed for ``SLOAD`` and ``BALANCE`` -.. py:attribute:: constantinople +.. py:attribute:: berlin + - Gas estimates changed for ``EXTCODESIZE``, ``EXTCODECOPY``, ``EXTCODEHASH``, ``SLOAD``, ``SSTORE``, ``CALL``, ``CALLCODE``, ``DELEGATECALL`` and ``STATICCALL`` + - Functions marked with ``@nonreentrant`` are protected with different values (3 and 2) than contracts targeting pre-berlin. + - ``BASEFEE`` is accessible via ``block.basefee`` - - The ``EXTCODEHASH`` opcode is accessible via ``address.codehash`` - - ``shift`` makes use of ``SHL``/``SHR`` opcodes. +.. py:attribute:: paris + - ``block.difficulty`` is deprecated in favor of its new alias, ``block.prevrandao``. -.. py:attribute:: petersburg +.. py:attribute:: shanghai + - The ``PUSH0`` opcode is automatically generated by the compiler instead of ``PUSH1 0`` - - The compiler behaves the same way as with constantinople. +.. py:attribute:: cancun (experimental) + + - The ``transient`` keyword allows declaration of variables which live in transient storage + - Functions marked with ``@nonreentrant`` are protected with TLOAD/TSTORE instead of SLOAD/SSTORE -.. py:attribute:: istanbul (default) - - The ``CHAINID`` opcode is accessible via ``chain.id`` - - The ``SELFBALANCE`` opcode is used for calls to ``self.balance`` - - Gas estimates changed for ``SLOAD`` and ``BALANCE`` Compiler Input and Output JSON Description @@ -204,7 +209,7 @@ The following example describes the expected input format of ``vyper-json``. Com }, // Optional "settings": { - "evmVersion": "istanbul", // EVM version to compile for. Can be byzantium, constantinople, petersburg or istanbul. + "evmVersion": "shanghai", // EVM version to compile for. Can be istanbul, berlin, paris, shanghai (default) or cancun (experimental!). // optional, whether or not optimizations are turned on // defaults to true "optimize": true, diff --git a/tests/cli/vyper_compile/test_compile_files.py b/tests/cli/vyper_compile/test_compile_files.py index 796976ae0e..31cf622658 100644 --- a/tests/cli/vyper_compile/test_compile_files.py +++ b/tests/cli/vyper_compile/test_compile_files.py @@ -28,29 +28,3 @@ def test_combined_json_keys(tmp_path): def test_invalid_root_path(): with pytest.raises(FileNotFoundError): compile_files([], [], root_folder="path/that/does/not/exist") - - -def test_evm_versions(tmp_path): - # should compile differently because of SELFBALANCE - code = """ -@external -def foo() -> uint256: - return self.balance -""" - - bar_path = tmp_path.joinpath("bar.vy") - with bar_path.open("w") as fp: - fp.write(code) - - byzantium_bytecode = compile_files( - [bar_path], output_formats=["bytecode"], evm_version="byzantium" - )[str(bar_path)]["bytecode"] - istanbul_bytecode = compile_files( - [bar_path], output_formats=["bytecode"], evm_version="istanbul" - )[str(bar_path)]["bytecode"] - - assert byzantium_bytecode != istanbul_bytecode - - # SELFBALANCE opcode is 0x47 - assert "47" not in byzantium_bytecode - assert "47" in istanbul_bytecode diff --git a/tests/cli/vyper_json/test_compile_from_input_dict.py b/tests/cli/vyper_json/test_compile_from_input_dict.py index a5a31a522b..a6d0a23100 100644 --- a/tests/cli/vyper_json/test_compile_from_input_dict.py +++ b/tests/cli/vyper_json/test_compile_from_input_dict.py @@ -130,12 +130,3 @@ def test_relative_import_paths(): input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} compile_from_input_dict(input_json) - - -def test_evm_version(): - # should compile differently because of SELFBALANCE - input_json = deepcopy(INPUT_JSON) - input_json["settings"]["evmVersion"] = "byzantium" - compiled = compile_from_input_dict(input_json) - input_json["settings"]["evmVersion"] = "istanbul" - assert compiled != compile_from_input_dict(input_json) diff --git a/tests/cli/vyper_json/test_get_settings.py b/tests/cli/vyper_json/test_get_settings.py index ca60d2cf5a..7530e85ef8 100644 --- a/tests/cli/vyper_json/test_get_settings.py +++ b/tests/cli/vyper_json/test_get_settings.py @@ -12,13 +12,23 @@ def test_unknown_evm(): get_evm_version({"settings": {"evmVersion": "foo"}}) -@pytest.mark.parametrize("evm_version", ["homestead", "tangerineWhistle", "spuriousDragon"]) +@pytest.mark.parametrize( + "evm_version", + [ + "homestead", + "tangerineWhistle", + "spuriousDragon", + "byzantium", + "constantinople", + "petersburg", + ], +) def test_early_evm(evm_version): with pytest.raises(JSONError): get_evm_version({"settings": {"evmVersion": evm_version}}) -@pytest.mark.parametrize("evm_version", ["byzantium", "constantinople", "petersburg"]) +@pytest.mark.parametrize("evm_version", ["istanbul", "berlin", "paris", "shanghai", "cancun"]) def test_valid_evm(evm_version): assert evm_version == get_evm_version({"settings": {"evmVersion": evm_version}}) diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py index 3c595dee44..b9841b92f0 100644 --- a/tests/compiler/test_opcodes.py +++ b/tests/compiler/test_opcodes.py @@ -37,30 +37,27 @@ def test_version_check(evm_version): assert opcodes.version_check(begin=evm_version) assert opcodes.version_check(end=evm_version) assert opcodes.version_check(begin=evm_version, end=evm_version) - if evm_version not in ("byzantium", "atlantis"): - assert not opcodes.version_check(end="byzantium") + if evm_version not in ("istanbul"): + assert not opcodes.version_check(end="istanbul") istanbul_check = opcodes.version_check(begin="istanbul") assert istanbul_check == (opcodes.EVM_VERSIONS[evm_version] >= opcodes.EVM_VERSIONS["istanbul"]) def test_get_opcodes(evm_version): ops = opcodes.get_opcodes() - if evm_version in ("paris", "berlin", "shanghai", "cancun"): - assert "CHAINID" in ops + + assert "CHAINID" in ops + assert ops["CREATE2"][-1] == 32000 + + if evm_version in ("london", "berlin", "paris", "shanghai", "cancun"): assert ops["SLOAD"][-1] == 2100 - if evm_version in ("shanghai", "cancun"): - assert "PUSH0" in ops - if evm_version in ("cancun",): - assert "TLOAD" in ops - assert "TSTORE" in ops - elif evm_version == "istanbul": - assert "CHAINID" in ops - assert ops["SLOAD"][-1] == 800 else: - assert "CHAINID" not in ops - assert ops["SLOAD"][-1] == 200 + assert evm_version == "istanbul" + assert ops["SLOAD"][-1] == 800 - if evm_version in ("byzantium", "atlantis"): - assert "CREATE2" not in ops - else: - assert ops["CREATE2"][-1] == 32000 + if evm_version in ("shanghai", "cancun"): + assert "PUSH0" in ops + + if evm_version in ("cancun",): + assert "TLOAD" in ops + assert "TSTORE" in ops diff --git a/tests/parser/functions/test_bitwise.py b/tests/parser/functions/test_bitwise.py index 800803907a..3e18bd292c 100644 --- a/tests/parser/functions/test_bitwise.py +++ b/tests/parser/functions/test_bitwise.py @@ -35,12 +35,8 @@ def _shr(x: uint256, y: uint256) -> uint256: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) def test_bitwise_opcodes(evm_version): opcodes = compile_code(code, ["opcodes"], evm_version=evm_version)["opcodes"] - if evm_version in ("byzantium", "atlantis"): - assert "SHL" not in opcodes - assert "SHR" not in opcodes - else: - assert "SHL" in opcodes - assert "SHR" in opcodes + assert "SHL" in opcodes + assert "SHR" in opcodes @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -59,10 +55,7 @@ def test_test_bitwise(get_contract_with_gas_estimation, evm_version): assert c._shl(t, s) == (t << s) % (2**256) -POST_BYZANTIUM = [k for (k, v) in EVM_VERSIONS.items() if v > 0] - - -@pytest.mark.parametrize("evm_version", POST_BYZANTIUM) +@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS.keys())) def test_signed_shift(get_contract_with_gas_estimation, evm_version): code = """ @external diff --git a/tests/parser/syntax/test_chainid.py b/tests/parser/syntax/test_chainid.py index eb2ed36325..be960f2fc5 100644 --- a/tests/parser/syntax/test_chainid.py +++ b/tests/parser/syntax/test_chainid.py @@ -2,7 +2,7 @@ from vyper import compiler from vyper.evm.opcodes import EVM_VERSIONS -from vyper.exceptions import EvmVersionException, InvalidType, TypeMismatch +from vyper.exceptions import InvalidType, TypeMismatch @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -13,11 +13,7 @@ def foo(): a: uint256 = chain.id """ - if EVM_VERSIONS[evm_version] < 2: - with pytest.raises(EvmVersionException): - compiler.compile_code(code, evm_version=evm_version) - else: - compiler.compile_code(code, evm_version=evm_version) + assert compiler.compile_code(code, evm_version=evm_version) is not None fail_list = [ @@ -71,10 +67,10 @@ def foo(inp: Bytes[10]) -> Bytes[3]: def test_chain_fail(bad_code): if isinstance(bad_code, tuple): with pytest.raises(bad_code[1]): - compiler.compile_code(bad_code[0], evm_version="istanbul") + compiler.compile_code(bad_code[0]) else: with pytest.raises(TypeMismatch): - compiler.compile_code(bad_code, evm_version="istanbul") + compiler.compile_code(bad_code) valid_list = [ @@ -95,7 +91,7 @@ def check_chain_id(c: uint256) -> bool: @pytest.mark.parametrize("good_code", valid_list) def test_chain_success(good_code): - assert compiler.compile_code(good_code, evm_version="istanbul") is not None + assert compiler.compile_code(good_code) is not None def test_chainid_operation(get_contract_with_gas_estimation): @@ -105,5 +101,5 @@ def test_chainid_operation(get_contract_with_gas_estimation): def get_chain_id() -> uint256: return chain.id """ - c = get_contract_with_gas_estimation(code, evm_version="istanbul") + c = get_contract_with_gas_estimation(code) assert c.get_chain_id() == 131277322940537 # Default value of py-evm diff --git a/tests/parser/syntax/test_codehash.py b/tests/parser/syntax/test_codehash.py index 8c1e430d32..e4b6d90d8d 100644 --- a/tests/parser/syntax/test_codehash.py +++ b/tests/parser/syntax/test_codehash.py @@ -2,7 +2,6 @@ from vyper.compiler import compile_code from vyper.evm.opcodes import EVM_VERSIONS -from vyper.exceptions import EvmVersionException from vyper.utils import keccak256 @@ -32,12 +31,6 @@ def foo3() -> bytes32: def foo4() -> bytes32: return self.a.codehash """ - - if evm_version in ("byzantium", "atlantis"): - with pytest.raises(EvmVersionException): - compile_code(code, evm_version=evm_version) - return - compiled = compile_code( code, ["bytecode_runtime"], evm_version=evm_version, no_optimize=no_optimize ) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 4dfc87639a..f5e113116d 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -102,7 +102,7 @@ def _parse_args(argv): parser.add_argument( "--evm-version", help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION}). " - " note: cancun support is EXPERIMENTAL", + "note: cancun support is EXPERIMENTAL", choices=list(EVM_VERSIONS), default=DEFAULT_EVM_VERSION, dest="evm_version", diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 2fbf58aec4..9778848aa2 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -149,8 +149,14 @@ def get_evm_version(input_dict: Dict) -> str: return DEFAULT_EVM_VERSION evm_version = input_dict["settings"].get("evmVersion", DEFAULT_EVM_VERSION) - if evm_version in ("homestead", "tangerineWhistle", "spuriousDragon"): - raise JSONError("Vyper does not support pre-byzantium EVM versions") + if evm_version in ( + "homestead", + "tangerineWhistle", + "spuriousDragon", + "byzantium", + "constantinople", + ): + raise JSONError("Vyper does not support pre-istanbul EVM versions") if evm_version not in EVM_VERSIONS: raise JSONError(f"Unknown EVM version - '{evm_version}'") diff --git a/vyper/codegen/arithmetic.py b/vyper/codegen/arithmetic.py index eb2df20922..f14069384a 100644 --- a/vyper/codegen/arithmetic.py +++ b/vyper/codegen/arithmetic.py @@ -10,7 +10,6 @@ is_numeric_type, ) from vyper.codegen.ir_node import IRnode -from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, TypeCheckFailure, UnimplementedException @@ -243,10 +242,7 @@ def safe_mul(x, y): # in the above sdiv check, if (r==-1 and l==-2**255), # -2**255 / -1 will return -2**255. # need to check: not (r == -1 and l == -2**255) - if version_check(begin="constantinople"): - upper_bound = ["shl", 255, 1] - else: - upper_bound = -(2**255) + upper_bound = ["shl", 255, 1] check_x = ["ne", x, upper_bound] check_y = ["ne", ["not", y], 0] @@ -301,10 +297,7 @@ def safe_div(x, y): with res.cache_when_complex("res") as (b1, res): # TODO: refactor this condition / push some things into the optimizer if typ.is_signed and typ.bits == 256: - if version_check(begin="constantinople"): - upper_bound = ["shl", 255, 1] - else: - upper_bound = -(2**255) + upper_bound = ["shl", 255, 1] if not x.is_literal and not y.is_literal: ok = ["or", ["ne", y, ["not", 0]], ["ne", x, upper_bound]] diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 06140f3f52..58d9db9889 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,7 +1,6 @@ from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT -from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( AddressT, @@ -997,23 +996,16 @@ def zero_pad(bytez_placeholder): # convenience rewrites for shr/sar/shl def shr(bits, x): - if version_check(begin="constantinople"): - return ["shr", bits, x] - return ["div", x, ["exp", 2, bits]] + return ["shr", bits, x] # convenience rewrites for shr/sar/shl def shl(bits, x): - if version_check(begin="constantinople"): - return ["shl", bits, x] - return ["mul", x, ["exp", 2, bits]] + return ["shl", bits, x] def sar(bits, x): - if version_check(begin="constantinople"): - return ["sar", bits, x] - - raise NotImplementedError("no SAR emulation for pre-constantinople EVM") + return ["sar", bits, x] def clamp_bytestring(ir_node): diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index ac7290836b..d637a454bc 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -242,10 +242,6 @@ def parse_Attribute(self): # x.codehash: keccak of address x elif self.expr.attr == "codehash": addr = Expr.parse_value_expr(self.expr.value, self.context) - if not version_check(begin="constantinople"): - raise EvmVersionException( - "address.codehash is unavailable prior to constantinople ruleset", self.expr - ) if addr.typ == AddressT(): return IRnode.from_list(["extcodehash", addr], typ=BYTES32_T) # x.code: codecopy/extcodecopy of address x @@ -401,7 +397,6 @@ def parse_BinOp(self): # TODO implement me. promote_signed_int(op(right, left), bits) return op = shr if not left.typ.is_signed else sar - # note: sar NotImplementedError for pre-constantinople return IRnode.from_list(op(right, left), typ=new_typ) # enums can only do bit ops, not arithmetic. diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index c4eb182eb1..ba89f3cace 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -63,8 +63,7 @@ def _pack_arguments(fn_type, args, context): # 32 bytes | args # 0x..00 | args # the reason for the left padding is just so the alignment is easier. - # if we were only targeting constantinople, we could align - # to buf (and also keep code size small) by using + # XXX: we could align to buf (and also keep code size small) by using # (mstore buf (shl signature.method_id 224)) pack_args = ["seq"] pack_args.append(["mstore", buf, util.method_id_int(abi_signature)]) diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index c447fd863c..00e0986939 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -7,28 +7,13 @@ # 1. Fork rules go from oldest (lowest value) to newest (highest value). # 2. Fork versions aren't actually tied to anything. They are not a part of our # official API. *DO NOT USE THE VALUES FOR ANYTHING IMPORTANT* besides versioning. -# 3. When support for an older version is dropped, the numbers should *not* change to -# reflect it (i.e. dropping support for version 0 removes version 0 entirely). -# 4. There can be multiple aliases to the same version number (but not the reverse). -# 5. When supporting multiple chains, if a chain gets a fix first, it increments the -# number first. -# 6. Yes, this will probably have to be rethought if there's ever conflicting support -# between multiple chains for a specific feature. Let's hope not. -# 7. We support at a maximum 3 hard forks (for any given chain). -EVM_VERSIONS: dict[str, int] = { - # ETH Forks - "byzantium": 0, - "constantinople": 1, - "petersburg": 1, - "istanbul": 2, - "berlin": 3, - "paris": 4, - "shanghai": 5, - "cancun": 6, - # ETC Forks - "atlantis": 0, - "agharta": 1, -} +# 3. Per VIP-3365, we support mainnet fork choice rules up to 1 year old +# (and may optionally have forward support for experimental/unreleased +# fork choice rules) +_evm_versions = ("istanbul", "berlin", "london", "paris", "shanghai", "cancun") +EVM_VERSIONS: dict[str, int] = dict((v, i) for i, v in enumerate(_evm_versions)) + + DEFAULT_EVM_VERSION: str = "shanghai" active_evm_version: int = EVM_VERSIONS[DEFAULT_EVM_VERSION] @@ -36,7 +21,7 @@ # opcode as hex value # number of values removed from stack # number of values added to stack -# gas cost (byzantium, constantinople, istanbul, berlin) +# gas cost (istanbul, berlin, paris, shanghai, cancun) OPCODES: OpcodeMap = { "STOP": (0x00, 0, 0, 0), "ADD": (0x01, 2, 1, 3), @@ -61,12 +46,12 @@ "XOR": (0x18, 2, 1, 3), "NOT": (0x19, 1, 1, 3), "BYTE": (0x1A, 2, 1, 3), - "SHL": (0x1B, 2, 1, (None, 3)), - "SHR": (0x1C, 2, 1, (None, 3)), - "SAR": (0x1D, 2, 1, (None, 3)), + "SHL": (0x1B, 2, 1, 3), + "SHR": (0x1C, 2, 1, 3), + "SAR": (0x1D, 2, 1, 3), "SHA3": (0x20, 2, 1, 30), "ADDRESS": (0x30, 0, 1, 2), - "BALANCE": (0x31, 1, 1, (400, 400, 700)), + "BALANCE": (0x31, 1, 1, 700), "ORIGIN": (0x32, 0, 1, 2), "CALLER": (0x33, 0, 1, 2), "CALLVALUE": (0x34, 0, 1, 2), @@ -76,11 +61,11 @@ "CODESIZE": (0x38, 0, 1, 2), "CODECOPY": (0x39, 3, 0, 3), "GASPRICE": (0x3A, 0, 1, 2), - "EXTCODESIZE": (0x3B, 1, 1, (700, 700, 700, 2600)), - "EXTCODECOPY": (0x3C, 4, 0, (700, 700, 700, 2600)), + "EXTCODESIZE": (0x3B, 1, 1, (700, 2600)), + "EXTCODECOPY": (0x3C, 4, 0, (700, 2600)), "RETURNDATASIZE": (0x3D, 0, 1, 2), "RETURNDATACOPY": (0x3E, 3, 0, 3), - "EXTCODEHASH": (0x3F, 1, 1, (None, 400, 700, 2600)), + "EXTCODEHASH": (0x3F, 1, 1, (700, 2600)), "BLOCKHASH": (0x40, 1, 1, 20), "COINBASE": (0x41, 0, 1, 2), "TIMESTAMP": (0x42, 0, 1, 2), @@ -88,14 +73,14 @@ "DIFFICULTY": (0x44, 0, 1, 2), "PREVRANDAO": (0x44, 0, 1, 2), "GASLIMIT": (0x45, 0, 1, 2), - "CHAINID": (0x46, 0, 1, (None, None, 2)), - "SELFBALANCE": (0x47, 0, 1, (None, None, 5)), - "BASEFEE": (0x48, 0, 1, (None, None, None, 2)), + "CHAINID": (0x46, 0, 1, 2), + "SELFBALANCE": (0x47, 0, 1, 5), + "BASEFEE": (0x48, 0, 1, (None, 2)), "POP": (0x50, 1, 0, 2), "MLOAD": (0x51, 1, 1, 3), "MSTORE": (0x52, 2, 0, 3), "MSTORE8": (0x53, 2, 0, 3), - "SLOAD": (0x54, 1, 1, (200, 200, 800, 2100)), + "SLOAD": (0x54, 1, 1, (800, 2100)), "SSTORE": (0x55, 2, 0, 20000), "JUMP": (0x56, 1, 0, 8), "JUMPI": (0x57, 2, 0, 10), @@ -174,13 +159,13 @@ "LOG3": (0xA3, 5, 0, 1500), "LOG4": (0xA4, 6, 0, 1875), "CREATE": (0xF0, 3, 1, 32000), - "CALL": (0xF1, 7, 1, (700, 700, 700, 2100)), - "CALLCODE": (0xF2, 7, 1, (700, 700, 700, 2100)), + "CALL": (0xF1, 7, 1, (700, 2100)), + "CALLCODE": (0xF2, 7, 1, (700, 2100)), "RETURN": (0xF3, 2, 0, 0), - "DELEGATECALL": (0xF4, 6, 1, (700, 700, 700, 2100)), - "CREATE2": (0xF5, 4, 1, (None, 32000)), + "DELEGATECALL": (0xF4, 6, 1, (700, 2100)), + "CREATE2": (0xF5, 4, 1, 32000), "SELFDESTRUCT": (0xFF, 1, 0, 25000), - "STATICCALL": (0xFA, 6, 1, (700, 700, 700, 2100)), + "STATICCALL": (0xFA, 6, 1, (700, 2100)), "REVERT": (0xFD, 2, 0, 0), "INVALID": (0xFE, 0, 0, 0), "DEBUG": (0xA5, 1, 0, 0), diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index fb10b515cc..b13c6f79f8 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple, Union from vyper.codegen.ir_node import IRnode -from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StaticAssertionException from vyper.utils import ( ceil32, @@ -340,19 +339,17 @@ def _conservative_eq(x, y): if binop == "mod": return finalize("and", [args[0], _int(args[1]) - 1]) - if binop == "div" and version_check(begin="constantinople"): + if binop == "div": # x / 2**n == x >> n # recall shr/shl have unintuitive arg order return finalize("shr", [int_log2(_int(args[1])), args[0]]) # note: no rule for sdiv since it rounds differently from sar - if binop == "mul" and version_check(begin="constantinople"): + if binop == "mul": # x * 2**n == x << n return finalize("shl", [int_log2(_int(args[1])), args[0]]) - # reachable but only before constantinople - if version_check(begin="constantinople"): # pragma: no cover - raise CompilerPanic("unreachable") + raise CompilerPanic("unreachable") # pragma: no cover ## # COMPARISONS From 795d9f5a389b4c4e05869be06ae016e47a2f0173 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 1 Jun 2023 17:26:15 -0400 Subject: [PATCH 030/201] docs: fix borked formatting for rulesets (#3476) --- docs/compiling-a-contract.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 36d0c8cb74..6295226bca 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -147,14 +147,17 @@ The following is a list of supported EVM versions, and changes in the compiler i - Gas estimates changed for ``SLOAD`` and ``BALANCE`` .. py:attribute:: berlin + - Gas estimates changed for ``EXTCODESIZE``, ``EXTCODECOPY``, ``EXTCODEHASH``, ``SLOAD``, ``SSTORE``, ``CALL``, ``CALLCODE``, ``DELEGATECALL`` and ``STATICCALL`` - Functions marked with ``@nonreentrant`` are protected with different values (3 and 2) than contracts targeting pre-berlin. - ``BASEFEE`` is accessible via ``block.basefee`` .. py:attribute:: paris + - ``block.difficulty`` is deprecated in favor of its new alias, ``block.prevrandao``. -.. py:attribute:: shanghai +.. py:attribute:: shanghai (default) + - The ``PUSH0`` opcode is automatically generated by the compiler instead of ``PUSH1 0`` .. py:attribute:: cancun (experimental) From 0e201316e63284d52bbe1410fdfaafe2a378616a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 1 Jun 2023 23:55:42 -0400 Subject: [PATCH 031/201] chore: fix era compiler tester workflow (#3477) auto-detect vyper version only run full llvm matrix on push to master --- .github/workflows/era-tester.yml | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/.github/workflows/era-tester.yml b/.github/workflows/era-tester.yml index a693d2c97d..8a2a3e50ce 100644 --- a/.github/workflows/era-tester.yml +++ b/.github/workflows/era-tester.yml @@ -85,14 +85,27 @@ jobs: **/era-compiler-tester key: ${{ runner.os }}-${{ env.ERA_HASH }}-${{ env.ERA_VYPER_HASH }} - - name: Install Vyper + - name: Build Vyper run: | + set -Eeuxo pipefail pip install . + echo "VYPER_VERSION=$(vyper --version | cut -f1 -d'+')" >> $GITHUB_ENV + + - name: Install Vyper + run: | mkdir era-compiler-tester/vyper-bin - echo $(which vyper) - cp $(which vyper) era-compiler-tester/vyper-bin/vyper-0.3.8 + cp $(which vyper) era-compiler-tester/vyper-bin/vyper-${{ env.VYPER_VERSION }} + + - name: Run tester (fast) + # Run era tester with no LLVM optimizations + if: ${{ github.ref != 'refs/heads/master' }} + run: | + cd era-compiler-tester + cargo run --release --bin compiler-tester -- -v --path=tests/vyper/ --mode="M0B0 ${{ env.VYPER_VERSION }}" - - name: Run tester + - name: Run tester (slow) + # Run era tester across the LLVM optimization matrix + if: ${{ github.ref == 'refs/heads/master' }} run: | cd era-compiler-tester - cargo run --release --bin compiler-tester -- -v --path='tests/vyper/' --mode='M*B* 0.3.8' + cargo run --release --bin compiler-tester -- -v --path=tests/vyper/ --mode="M*B* ${{ env.VYPER_VERSION }}" From a71d604513c1cf711b188ca3826325ffb58e35a0 Mon Sep 17 00:00:00 2001 From: antazoey Date: Fri, 2 Jun 2023 09:27:23 -0600 Subject: [PATCH 032/201] chore: rm redundant var declaration (#3478) it was apparently there as a type hint but it is no longer needed --- vyper/codegen/module.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 64d5a8b70c..b98e4d0f86 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -1,8 +1,7 @@ # a contract.vy -- all functions and constructor -from typing import Any, List, Optional +from typing import Any, List -from vyper import ast as vy_ast from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function from vyper.codegen.global_context import GlobalContext @@ -141,8 +140,6 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees function_defs = _topsort(global_ctx.functions) - init_function: Optional[vy_ast.FunctionDef] = None - runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) From f0f9377748a6089dc8a39db692c8ad4c51a11f40 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 8 Jun 2023 13:30:14 -0700 Subject: [PATCH 033/201] chore: update tload/tstore opcodes per latest 1153 (#3484) --- vyper/evm/opcodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 00e0986939..7550d047b5 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -170,8 +170,8 @@ "INVALID": (0xFE, 0, 0, 0), "DEBUG": (0xA5, 1, 0, 0), "BREAKPOINT": (0xA6, 0, 0, 0), - "TLOAD": (0xB3, 1, 1, 100), - "TSTORE": (0xB4, 2, 0, 100), + "TLOAD": (0x5C, 1, 1, 100), + "TSTORE": (0x5D, 2, 0, 100), } PSEUDO_OPCODES: OpcodeMap = { From c90ab2fb26d66ad2121bf967dabe738ee8eaf21e Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 15 Jun 2023 00:57:27 +0800 Subject: [PATCH 034/201] feat: let params of internal functions be mutable (#3473) params to internal functions are never in calldata, so we don't need to have any write restrictions. --- tests/parser/features/test_assignment.py | 115 +++++++++++++++++- .../function_definitions/internal_function.py | 4 +- vyper/semantics/analysis/local.py | 9 +- 3 files changed, 121 insertions(+), 7 deletions(-) diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 0dd63a0d09..29ec820484 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -39,7 +39,118 @@ def augmod(x: int128, y: int128) -> int128: print("Passed aug-assignment test") -def test_invalid_assign(assert_compile_failed, get_contract_with_gas_estimation): +@pytest.mark.parametrize( + "typ,in_val,out_val", + [ + ("uint256", 77, 123), + ("uint256[3]", [1, 2, 3], [4, 5, 6]), + ("DynArray[uint256, 3]", [1, 2, 3], [4, 5, 6]), + ("Bytes[5]", b"vyper", b"conda"), + ], +) +def test_internal_assign(get_contract_with_gas_estimation, typ, in_val, out_val): + code = f""" +@internal +def foo(x: {typ}) -> {typ}: + x = {out_val} + return x + +@external +def bar(x: {typ}) -> {typ}: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar(in_val) == out_val + + +def test_internal_assign_struct(get_contract_with_gas_estimation): + code = """ +enum Bar: + BAD + BAK + BAZ + +struct Foo: + a: uint256 + b: DynArray[Bar, 3] + c: String[5] + +@internal +def foo(x: Foo) -> Foo: + x = Foo({a: 789, b: [Bar.BAZ, Bar.BAK, Bar.BAD], c: \"conda\"}) + return x + +@external +def bar(x: Foo) -> Foo: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar((123, [1, 2, 4], "vyper")) == (789, [4, 2, 1], "conda") + + +def test_internal_assign_struct_member(get_contract_with_gas_estimation): + code = """ +enum Bar: + BAD + BAK + BAZ + +struct Foo: + a: uint256 + b: DynArray[Bar, 3] + c: String[5] + +@internal +def foo(x: Foo) -> Foo: + x.a = 789 + x.b.pop() + return x + +@external +def bar(x: Foo) -> Foo: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar((123, [1, 2, 4], "vyper")) == (789, [1, 2], "vyper") + + +def test_internal_augassign(get_contract_with_gas_estimation): + code = """ +@internal +def foo(x: int128) -> int128: + x += 77 + return x + +@external +def bar(x: int128) -> int128: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar(123) == 200 + + +@pytest.mark.parametrize("typ", ["DynArray[uint256, 3]", "uint256[3]"]) +def test_internal_augassign_arrays(get_contract_with_gas_estimation, typ): + code = f""" +@internal +def foo(x: {typ}) -> {typ}: + x[1] += 77 + return x + +@external +def bar(x: {typ}) -> {typ}: + return self.foo(x) + """ + c = get_contract_with_gas_estimation(code) + + assert c.bar([1, 2, 3]) == [1, 79, 3] + + +def test_invalid_external_assign(assert_compile_failed, get_contract_with_gas_estimation): code = """ @external def foo(x: int128): @@ -48,7 +159,7 @@ def foo(x: int128): assert_compile_failed(lambda: get_contract_with_gas_estimation(code), ImmutableViolation) -def test_invalid_augassign(assert_compile_failed, get_contract_with_gas_estimation): +def test_invalid_external_augassign(assert_compile_failed, get_contract_with_gas_estimation): code = """ @external def foo(x: int128): diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 17479c4c07..228191e3ca 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -41,8 +41,8 @@ def generate_ir_for_internal_function( for arg in func_t.arguments: # allocate a variable for every arg, setting mutability - # to False to comply with vyper semantics, function arguments are immutable - context.new_variable(arg.name, arg.typ, is_mutable=False) + # to True to allow internal function arguments to be mutable + context.new_variable(arg.name, arg.typ, is_mutable=True) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 790cee52d6..c99b582ad3 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -173,10 +173,13 @@ def __init__( self.func = fn_node._metadata["type"] self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace) self.expr_visitor = _LocalExpressionVisitor() + + # allow internal function params to be mutable + location, is_immutable = ( + (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) + ) for arg in self.func.arguments: - namespace[arg.name] = VarInfo( - arg.typ, location=DataLocation.CALLDATA, is_immutable=True - ) + namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable) for node in fn_node.body: self.visit(node) From 2704ff0140c3f08372720c95c92e0b0071211726 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 29 Jun 2023 17:02:45 -0700 Subject: [PATCH 035/201] fix: pycryptodome on arm (#3485) --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f891ff7e1d..b4be1043c1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,7 +38,7 @@ jobs: - name: Generate Binary run: >- - pip install . && + pip install --no-binary pycryptodome . && pip install pyinstaller && make freeze From c1f0bd5a87e2f7fa10dec41639e8e37f9692e62c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 29 Jun 2023 17:51:37 -0700 Subject: [PATCH 036/201] chore: polish some codegen issues (#3488) fix some issues with well-formedness of generated IR (which were getting hidden by the `rewrite_return_sequences` routine). this shouldn't affect correctness of current vyper programs, but may help programs which consume vyper IR directly. * fix push_label_to_stack - use dedicated symbol instruction * remove busted return_pc forwarding in external functions * remove weird `_sym_` prefix in `exit_to` --- vyper/codegen/ir_node.py | 5 ----- vyper/codegen/return_.py | 5 +++-- vyper/codegen/self_call.py | 4 ++-- vyper/ir/compile_ir.py | 8 ++++---- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index d36a18ec66..f7698fbabb 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -38,11 +38,6 @@ def __repr__(self) -> str: __mul__ = __add__ -def push_label_to_stack(labelname: str) -> str: - # items prefixed with `_sym_` are ignored until asm phase - return "_sym_" + labelname - - class Encoding(Enum): # vyper encoding, default for memory variables VYPER = auto() diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index b8468f3eb1..56bea2b8da 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -21,7 +21,7 @@ def make_return_stmt(ir_val: IRnode, stmt: Any, context: Context) -> Optional[IRnode]: func_t = context.func_t - jump_to_exit = ["exit_to", f"_sym_{func_t._ir_info.exit_sequence_label}"] + jump_to_exit = ["exit_to", func_t._ir_info.exit_sequence_label] if context.return_type is None: if stmt.value is not None: @@ -43,7 +43,8 @@ def finalize(fill_return_buffer): return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit]) if context.return_type is None: - jump_to_exit += ["return_pc"] + if context.is_internal: + jump_to_exit += ["return_pc"] return finalize(["seq"]) if context.is_internal: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index 311576194b..c320e6889c 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -1,5 +1,5 @@ from vyper.codegen.core import _freshname, eval_once_check, make_setter -from vyper.codegen.ir_node import IRnode, push_label_to_stack +from vyper.codegen.ir_node import IRnode from vyper.evm.address_space import MEMORY from vyper.exceptions import StateAccessViolation from vyper.semantics.types.subscriptable import TupleT @@ -104,7 +104,7 @@ def ir_for_self_call(stmt_expr, context): if return_buffer is not None: goto_op += [return_buffer] # pass return label to subroutine - goto_op += [push_label_to_stack(return_label)] + goto_op.append(["symbol", return_label]) call_sequence = ["seq"] call_sequence.append(eval_once_check(_freshname(stmt_expr.node_source_code))) diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index b2a58fa8c9..9d7ef4691f 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -118,14 +118,14 @@ def _rewrite_return_sequences(ir_node, label_params=None): args[0].value = "pass" else: # handle jump to cleanup - assert is_symbol(args[0].value) ir_node.value = "seq" _t = ["seq"] if "return_buffer" in label_params: _t.append(["pop", "pass"]) - dest = args[0].value[5:] # `_sym_foo` -> `foo` + dest = args[0].value + # works for both internal and external exit_to more_args = ["pass" if t.value == "return_pc" else t for t in args[1:]] _t.append(["goto", dest] + more_args) ir_node.args = IRnode.from_list(_t, source_pos=ir_node.source_pos).args @@ -667,8 +667,8 @@ def _height_of(witharg): o.extend(["_sym_" + str(code.args[0]), "JUMP"]) return o # push a literal symbol - elif isinstance(code.value, str) and is_symbol(code.value): - return [code.value] + elif code.value == "symbol": + return ["_sym_" + str(code.args[0])] # set a symbol as a location. elif code.value == "label": label_name = code.args[0].value From ae608368f6a3ea6fd7cb16e685d95018ad0efcd0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 29 Jun 2023 22:58:40 -0700 Subject: [PATCH 037/201] chore: add __new__ to Namespace (#3489) this makes it picklable, otherwise it fails with `_scope` not being available during `__setitem__` --- vyper/semantics/analysis/module.py | 2 +- vyper/semantics/namespace.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index cb8e93ff28..d916dcf119 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -98,7 +98,7 @@ def __init__( _ns = Namespace() # note that we don't just copy the namespace because # there are constructor issues. - _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) + _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore module_node._metadata["namespace"] = _ns # check for collisions between 4byte function selectors diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 82a5d5cf3e..b88bc3d817 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -20,9 +20,13 @@ class Namespace(dict): List of sets containing the key names for each scope """ + def __new__(cls, *args, **kwargs): + self = super().__new__(cls, *args, **kwargs) + self._scopes = [] + return self + def __init__(self): super().__init__() - self._scopes = [] # NOTE cyclic imports! # TODO: break this cycle by providing an `init_vyper_namespace` in 3rd module from vyper.builtins.functions import get_builtin_functions From bc723d2645aeec94bec1d83cdcbb8b41f7f807d3 Mon Sep 17 00:00:00 2001 From: trocher Date: Fri, 30 Jun 2023 08:01:14 +0200 Subject: [PATCH 038/201] fix: improve error message for conflicting methods IDs (#3491) Before this commit: `Methods have conflicting IDs: ` Now: `Methods produce colliding method ID '0x2e1a7d4d': OwnerTransferV7b711143(uint256), withdraw(uint256)` for the following contract: ``` @external def OwnerTransferV7b711143(a : uint256) : pass @external def withdraw(a : uint256): pass ``` Co-authored-by: Tanguy Rocher --- tests/signatures/test_method_id_conflicts.py | 7 +++++++ vyper/semantics/analysis/utils.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/signatures/test_method_id_conflicts.py b/tests/signatures/test_method_id_conflicts.py index 262348c12a..35c10300b4 100644 --- a/tests/signatures/test_method_id_conflicts.py +++ b/tests/signatures/test_method_id_conflicts.py @@ -67,6 +67,13 @@ def gfah(): pass @view def eexo(): pass """, + """ +# check collision with ID = 0x00000000 +wycpnbqcyf:public(uint256) + +@external +def randallsRevenge_ilxaotc(): pass + """, ] diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 26f3fd1827..f16b0c8c33 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -24,7 +24,7 @@ from vyper.semantics.types.bytestrings import BytesT, StringT from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT -from vyper.utils import checksum_encode +from vyper.utils import checksum_encode, int_to_fourbytes def _validate_op(node, types_list, validation_fn_name): @@ -593,8 +593,13 @@ def validate_unique_method_ids(functions: List) -> None: seen = set() for method_id in method_ids: if method_id in seen: - collision_str = ", ".join(i.name for i in functions if method_id in i.method_ids) - raise StructureException(f"Methods have conflicting IDs: {collision_str}") + collision_str = ", ".join( + x for i in functions for x in i.method_ids.keys() if i.method_ids[x] == method_id + ) + collision_hex = int_to_fourbytes(method_id).hex() + raise StructureException( + f"Methods produce colliding method ID `0x{collision_hex}`: {collision_str}" + ) seen.add(method_id) From 9e363e89fb9a67984b29590da3f821c767622c8c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 30 Jun 2023 15:31:40 -0700 Subject: [PATCH 039/201] chore: remove vyper signature from runtime (#3471) it's going at the end of initcode instead, which is cheaper but still possible to pick up the vyper version by looking at the create tx. --- vyper/compiler/output.py | 4 +--- vyper/compiler/phases.py | 17 +++++------------ vyper/ir/compile_ir.py | 12 +++--------- 3 files changed, 9 insertions(+), 24 deletions(-) diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index f061bd8e18..63d92d9a47 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -218,9 +218,7 @@ def _build_asm(asm_list): def build_source_map_output(compiler_data: CompilerData) -> OrderedDict: _, line_number_map = compile_ir.assembly_to_evm( - compiler_data.assembly_runtime, - insert_vyper_signature=True, - disable_bytecode_metadata=compiler_data.no_bytecode_metadata, + compiler_data.assembly_runtime, insert_vyper_signature=False ) # Sort line_number_map out = OrderedDict() diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5156aa1bbd..c759f6e272 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -150,15 +150,12 @@ def assembly_runtime(self) -> list: @cached_property def bytecode(self) -> bytes: - return generate_bytecode( - self.assembly, is_runtime=False, no_bytecode_metadata=self.no_bytecode_metadata - ) + insert_vyper_signature = not self.no_bytecode_metadata + return generate_bytecode(self.assembly, insert_vyper_signature=insert_vyper_signature) @cached_property def bytecode_runtime(self) -> bytes: - return generate_bytecode( - self.assembly_runtime, is_runtime=True, no_bytecode_metadata=self.no_bytecode_metadata - ) + return generate_bytecode(self.assembly_runtime, insert_vyper_signature=False) @cached_property def blueprint_bytecode(self) -> bytes: @@ -295,9 +292,7 @@ def _find_nested_opcode(assembly, key): return any(_find_nested_opcode(x, key) for x in sublists) -def generate_bytecode( - assembly: list, is_runtime: bool = False, no_bytecode_metadata: bool = False -) -> bytes: +def generate_bytecode(assembly: list, insert_vyper_signature: bool) -> bytes: """ Generate bytecode from assembly instructions. @@ -311,6 +306,4 @@ def generate_bytecode( bytes Final compiled bytecode. """ - return compile_ir.assembly_to_evm( - assembly, insert_vyper_signature=is_runtime, disable_bytecode_metadata=no_bytecode_metadata - )[0] + return compile_ir.assembly_to_evm(assembly, insert_vyper_signature=insert_vyper_signature)[0] diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 9d7ef4691f..5a35b8f932 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -968,9 +968,7 @@ def adjust_pc_maps(pc_maps, ofst): return ret -def assembly_to_evm( - assembly, pc_ofst=0, insert_vyper_signature=False, disable_bytecode_metadata=False -): +def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): """ Assembles assembly into EVM @@ -994,7 +992,7 @@ def assembly_to_evm( runtime_code, runtime_code_start, runtime_code_end = None, None, None bytecode_suffix = b"" - if (not disable_bytecode_metadata) and insert_vyper_signature: + if insert_vyper_signature: # CBOR encoded: {"vyper": [major,minor,patch]} bytecode_suffix += b"\xa1\x65vyper\x83" + bytes(list(version_tuple)) bytecode_suffix += len(bytecode_suffix).to_bytes(2, "big") @@ -1011,11 +1009,7 @@ def assembly_to_evm( for i, item in enumerate(assembly): if isinstance(item, list): assert runtime_code is None, "Multiple subcodes" - runtime_code, runtime_map = assembly_to_evm( - item, - insert_vyper_signature=True, - disable_bytecode_metadata=disable_bytecode_metadata, - ) + runtime_code, runtime_map = assembly_to_evm(item) assert item[0].startswith("_DEPLOY_MEM_OFST_") assert ctor_mem_size is None From 3d01e947276481d49390aaf3dce6d09a216ea004 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 30 Jun 2023 15:32:41 -0700 Subject: [PATCH 040/201] chore: add test for complex storage assignment (#3472) add a test for complex make_setter when location is storage to prevent future regressions --- tests/parser/features/test_assignment.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 29ec820484..e550f60541 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -367,17 +367,33 @@ def foo(): """ assert_compile_failed(lambda: get_contract_with_gas_estimation(code), InvalidType) - -def test_assign_rhs_lhs_overlap(get_contract): # GH issue 2418 - code = """ + + +overlap_codes = [ + """ @external def bug(xs: uint256[2]) -> uint256[2]: # Initial value ys: uint256[2] = xs ys = [ys[1], ys[0]] return ys + """, """ +foo: uint256[2] +@external +def bug(xs: uint256[2]) -> uint256[2]: + # Initial value + self.foo = xs + self.foo = [self.foo[1], self.foo[0]] + return self.foo + """, + # TODO add transient tests when it's available +] + + +@pytest.mark.parametrize("code", overlap_codes) +def test_assign_rhs_lhs_overlap(get_contract, code): c = get_contract(code) assert c.bug([1, 2]) == [2, 1] From 29b02dd5b88a951c3791affc2062cca81701745d Mon Sep 17 00:00:00 2001 From: trocher Date: Sat, 1 Jul 2023 22:53:00 +0200 Subject: [PATCH 041/201] fix: typechecking of folded builtins (#3490) some builtins would allow decimals during typechecking and then panic during codegen --------- Co-authored-by: Tanguy Rocher --- tests/parser/syntax/test_addmulmod.py | 27 +++++++++++++++++++++++++++ vyper/builtins/functions.py | 14 +++++++------- 2 files changed, 34 insertions(+), 7 deletions(-) create mode 100644 tests/parser/syntax/test_addmulmod.py diff --git a/tests/parser/syntax/test_addmulmod.py b/tests/parser/syntax/test_addmulmod.py new file mode 100644 index 0000000000..ddff4d3e01 --- /dev/null +++ b/tests/parser/syntax/test_addmulmod.py @@ -0,0 +1,27 @@ +import pytest + +from vyper.exceptions import InvalidType + +fail_list = [ + ( # bad AST nodes given as arguments + """ +@external +def foo() -> uint256: + return uint256_addmod(1.1, 1.2, 3.0) + """, + InvalidType, + ), + ( # bad AST nodes given as arguments + """ +@external +def foo() -> uint256: + return uint256_mulmod(1.1, 1.2, 3.0) + """, + InvalidType, + ), +] + + +@pytest.mark.parametrize("code,exc", fail_list) +def test_add_mod_fail(assert_compile_failed, get_contract, code, exc): + assert_compile_failed(lambda: get_contract(code), exc) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index af965afe0a..90214554b0 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1370,7 +1370,7 @@ def evaluate(self, node): validate_call_args(node, 2) for arg in node.args: - if not isinstance(arg, vy_ast.Num): + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode if arg.value < 0 or arg.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) @@ -1396,7 +1396,7 @@ def evaluate(self, node): validate_call_args(node, 2) for arg in node.args: - if not isinstance(arg, vy_ast.Num): + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode if arg.value < 0 or arg.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) @@ -1422,7 +1422,7 @@ def evaluate(self, node): validate_call_args(node, 2) for arg in node.args: - if not isinstance(arg, vy_ast.Num): + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode if arg.value < 0 or arg.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) @@ -1447,7 +1447,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Num): + if not isinstance(node.args[0], vy_ast.Int): raise UnfoldableNode value = node.args[0].value @@ -1474,7 +1474,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 2) - if [i for i in node.args if not isinstance(i, vy_ast.Num)]: + if [i for i in node.args if not isinstance(i, vy_ast.Int)]: raise UnfoldableNode value, shift = [i.value for i in node.args] if value < 0 or value >= 2**256: @@ -1522,10 +1522,10 @@ class _AddMulMod(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 3) - if isinstance(node.args[2], vy_ast.Num) and node.args[2].value == 0: + if isinstance(node.args[2], vy_ast.Int) and node.args[2].value == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) for arg in node.args: - if not isinstance(arg, vy_ast.Num): + if not isinstance(arg, vy_ast.Int): raise UnfoldableNode if arg.value < 0 or arg.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) From 593c9b86cfea23f624655d5847ef36ae00d7ccdc Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 11 Jul 2023 08:35:39 -0400 Subject: [PATCH 042/201] feat: add optimization mode to vyper compiler (#3493) this commit adds the `--optimize` flag to the vyper cli, and as an option in vyper json. it is to be used separately from the `--no-optimize` flag. this commit does not actually change codegen, just adds the flag and threads it through the codebase so it is available once we want to start differentiating between the two modes, and sets up the test harness to test both modes. it also makes the `optimize` and `evm-version` available as source code pragmas, and adds an additional syntax for specifying the compiler version (`#pragma version X.Y.Z`). if the CLI / JSON options conflict with the source code pragmas, an exception is raised. this commit also: * bumps mypy - it was needed to bump to 0.940 to handle match/case, and discovered we could bump all the way to 0.98* without breaking anything * removes evm_version from bitwise op tests - it was probably important when we supported pre-constantinople targets, which we don't anymore --- .github/workflows/test.yml | 4 +- docs/compiling-a-contract.rst | 31 +++++-- docs/structure-of-a-contract.rst | 39 ++++++++- setup.py | 2 +- tests/ast/test_pre_parser.py | 85 +++++++++++++++++-- tests/base_conftest.py | 25 +++--- tests/cli/vyper_json/test_get_settings.py | 5 -- tests/compiler/asm/test_asm_optimizer.py | 5 +- tests/compiler/test_pre_parser.py | 61 ++++++++++++- tests/conftest.py | 31 ++++--- tests/examples/factory/test_factory.py | 5 +- tests/grammar/test_grammar.py | 3 +- tests/parser/features/test_immutable.py | 4 +- tests/parser/features/test_transient.py | 15 ++-- tests/parser/functions/test_bitwise.py | 21 ++--- .../parser/functions/test_create_functions.py | 5 +- .../test_annotate_and_optimize_ast.py | 2 +- tests/parser/syntax/test_address_code.py | 6 +- tests/parser/syntax/test_chainid.py | 4 +- tests/parser/syntax/test_codehash.py | 8 +- tests/parser/syntax/test_self_balance.py | 4 +- tests/parser/types/test_dynamic_array.py | 5 +- tox.ini | 3 +- vyper/ast/__init__.py | 2 +- vyper/ast/nodes.pyi | 1 + vyper/ast/pre_parser.py | 57 ++++++++++--- vyper/ast/utils.py | 17 ++-- vyper/cli/vyper_compile.py | 35 +++++--- vyper/cli/vyper_json.py | 34 ++++++-- vyper/compiler/__init__.py | 73 ++++++++-------- vyper/compiler/phases.py | 66 ++++++++++---- vyper/compiler/settings.py | 30 +++++++ vyper/evm/opcodes.py | 24 +++--- vyper/ir/compile_ir.py | 5 +- 34 files changed, 524 insertions(+), 193 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 42e0524b13..b6399b3ae9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,8 +79,8 @@ jobs: strategy: matrix: python-version: [["3.10", "310"], ["3.11", "311"]] - # run in default (optimized) and --no-optimize mode - flag: ["core", "no-opt"] + # run in modes: --optimize [gas, none, codesize] + flag: ["core", "no-opt", "codesize"] name: py${{ matrix.python-version[1] }}-${{ matrix.flag }} diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 6295226bca..208771a5a9 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -99,6 +99,11 @@ See :ref:`searching_for_imports` for more information on Vyper's import system. Online Compilers ================ +Try VyperLang! +----------------- + +`Try VyperLang! `_ is a JupterHub instance hosted by the Vyper team as a sandbox for developing and testing contracts in Vyper. It requires github for login, and supports deployment via the browser. + Remix IDE --------- @@ -109,22 +114,33 @@ Remix IDE While the Vyper version of the Remix IDE compiler is updated on a regular basis, it might be a bit behind the latest version found in the master branch of the repository. Make sure the byte code matches the output from your local compiler. +.. _evm-version: + Setting the Target EVM Version ============================== -When you compile your contract code, you can specify the Ethereum Virtual Machine version to compile for, to avoid particular features or behaviours. +When you compile your contract code, you can specify the target Ethereum Virtual Machine version to compile for, to access or avoid particular features. You can specify the version either with a source code pragma or as a compiler option. It is recommended to use the compiler option when you want flexibility (for instance, ease of deploying across different chains), and the source code pragma when you want bytecode reproducibility (for instance, when verifying code on a block explorer). + +.. note:: + If the evm version specified by the compiler options conflicts with the source code pragma, an exception will be raised and compilation will not continue. + +For instance, the adding the following pragma to a contract indicates that it should be compiled for the "shanghai" fork of the EVM. + +.. code-block:: python + + #pragma evm-version shanghai .. warning:: - Compiling for the wrong EVM version can result in wrong, strange and failing behaviour. Please ensure, especially if running a private chain, that you use matching EVM versions. + Compiling for the wrong EVM version can result in wrong, strange, or failing behavior. Please ensure, especially if running a private chain, that you use matching EVM versions. -When compiling via ``vyper``, include the ``--evm-version`` flag: +When compiling via the ``vyper`` CLI, you can specify the EVM version option using the ``--evm-version`` flag: :: $ vyper --evm-version [VERSION] -When using the JSON interface, include the ``"evmVersion"`` key within the ``"settings"`` field: +When using the JSON interface, you can include the ``"evmVersion"`` key within the ``"settings"`` field: .. code-block:: javascript @@ -213,9 +229,10 @@ The following example describes the expected input format of ``vyper-json``. Com // Optional "settings": { "evmVersion": "shanghai", // EVM version to compile for. Can be istanbul, berlin, paris, shanghai (default) or cancun (experimental!). - // optional, whether or not optimizations are turned on - // defaults to true - "optimize": true, + // optional, optimization mode + // defaults to "gas". can be one of "gas", "codesize", "none", + // false and true (the last two are for backwards compatibility). + "optimize": "gas", // optional, whether or not the bytecode should include Vyper's signature // defaults to true "bytecodeMetadata": true, diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index 8eb2c1da78..c7abb3e645 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -9,16 +9,47 @@ This section provides a quick overview of the types of data present within a con .. _structure-versions: -Version Pragma +Pragmas ============== -Vyper supports a version pragma to ensure that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. +Vyper supports several source code directives to control compiler modes and help with build reproducibility. + +Version Pragma +-------------- + +The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. + +As of 0.3.10, the recommended way to specify the version pragma is as follows: .. code-block:: python - # @version ^0.2.0 + #pragma version ^0.3.0 + +The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version: + +.. code-block:: python + + # @version ^0.3.0 + + +In the above examples, the contract will only compile with Vyper versions ``0.3.x``. + +Optimization Mode +----------------- + +The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For instance, the following contract will be compiled in a way which tries to minimize codesize: + +.. code-block:: python + + #pragma optimize codesize + +The optimization mode can also be set as a compiler option. If the compiler option conflicts with the source code pragma, an exception will be raised and compilation will not continue. + +EVM Version +----------------- + +The EVM version can be set with the ``evm-version`` pragma, which is documented in :ref:`evm-version`. -In the above example, the contract only compiles with Vyper versions ``0.2.x``. .. _structure-state-variables: diff --git a/setup.py b/setup.py index 05cb52259d..36a138aacd 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ "flake8-bugbear==20.1.4", "flake8-use-fstring==1.1", "isort==5.9.3", - "mypy==0.910", + "mypy==0.982", ], "docs": ["recommonmark", "sphinx>=6.0,<7.0", "sphinx_rtd_theme>=1.2,<1.3"], "dev": ["ipython", "pre-commit", "pyinstaller", "twine"], diff --git a/tests/ast/test_pre_parser.py b/tests/ast/test_pre_parser.py index 8501bb8749..150ee55edf 100644 --- a/tests/ast/test_pre_parser.py +++ b/tests/ast/test_pre_parser.py @@ -1,6 +1,7 @@ import pytest -from vyper.ast.pre_parser import validate_version_pragma +from vyper.ast.pre_parser import pre_parse, validate_version_pragma +from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import VersionException SRC_LINE = (1, 0) # Dummy source line @@ -51,14 +52,14 @@ def set_version(version): @pytest.mark.parametrize("file_version", valid_versions) def test_valid_version_pragma(file_version, mock_version): mock_version(COMPILER_VERSION) - validate_version_pragma(f" @version {file_version}", (SRC_LINE)) + validate_version_pragma(f"{file_version}", (SRC_LINE)) @pytest.mark.parametrize("file_version", invalid_versions) def test_invalid_version_pragma(file_version, mock_version): mock_version(COMPILER_VERSION) with pytest.raises(VersionException): - validate_version_pragma(f" @version {file_version}", (SRC_LINE)) + validate_version_pragma(f"{file_version}", (SRC_LINE)) prerelease_valid_versions = [ @@ -98,11 +99,85 @@ def test_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("file_version", prerelease_valid_versions) def test_prerelease_valid_version_pragma(file_version, mock_version): mock_version(PRERELEASE_COMPILER_VERSION) - validate_version_pragma(f" @version {file_version}", (SRC_LINE)) + validate_version_pragma(file_version, (SRC_LINE)) @pytest.mark.parametrize("file_version", prerelease_invalid_versions) def test_prerelease_invalid_version_pragma(file_version, mock_version): mock_version(PRERELEASE_COMPILER_VERSION) with pytest.raises(VersionException): - validate_version_pragma(f" @version {file_version}", (SRC_LINE)) + validate_version_pragma(file_version, (SRC_LINE)) + + +pragma_examples = [ + ( + """ + """, + Settings(), + ), + ( + """ + #pragma optimize codesize + """, + Settings(optimize=OptimizationLevel.CODESIZE), + ), + ( + """ + #pragma optimize none + """, + Settings(optimize=OptimizationLevel.NONE), + ), + ( + """ + #pragma optimize gas + """, + Settings(optimize=OptimizationLevel.GAS), + ), + ( + """ + #pragma version 0.3.10 + """, + Settings(compiler_version="0.3.10"), + ), + ( + """ + #pragma evm-version shanghai + """, + Settings(evm_version="shanghai"), + ), + ( + """ + #pragma optimize codesize + #pragma evm-version shanghai + """, + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), + ), + ( + """ + #pragma version 0.3.10 + #pragma evm-version shanghai + """, + Settings(evm_version="shanghai", compiler_version="0.3.10"), + ), + ( + """ + #pragma version 0.3.10 + #pragma optimize gas + """, + Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS), + ), + ( + """ + #pragma version 0.3.10 + #pragma evm-version shanghai + #pragma optimize gas + """, + Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS, evm_version="shanghai"), + ), +] + + +@pytest.mark.parametrize("code, expected_pragmas", pragma_examples) +def parse_pragmas(code, expected_pragmas): + pragmas, _, _ = pre_parse(code) + assert pragmas == expected_pragmas diff --git a/tests/base_conftest.py b/tests/base_conftest.py index 29809a074d..a78562e982 100644 --- a/tests/base_conftest.py +++ b/tests/base_conftest.py @@ -12,6 +12,7 @@ from vyper import compiler from vyper.ast.grammar import parse_vyper_source +from vyper.compiler.settings import Settings class VyperMethod: @@ -111,14 +112,16 @@ def w3(tester): return w3 -def _get_contract(w3, source_code, no_optimize, *args, **kwargs): +def _get_contract(w3, source_code, optimize, *args, **kwargs): + settings = Settings() + settings.evm_version = kwargs.pop("evm_version", None) + settings.optimize = optimize out = compiler.compile_code( source_code, # test that metadata gets generated ["abi", "bytecode", "metadata"], + settings=settings, interface_codes=kwargs.pop("interface_codes", None), - no_optimize=no_optimize, - evm_version=kwargs.pop("evm_version", None), show_gas_estimates=True, # Enable gas estimates for testing ) parse_vyper_source(source_code) # Test grammar. @@ -135,13 +138,15 @@ def _get_contract(w3, source_code, no_optimize, *args, **kwargs): return w3.eth.contract(address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract) -def _deploy_blueprint_for(w3, source_code, no_optimize, initcode_prefix=b"", **kwargs): +def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs): + settings = Settings() + settings.evm_version = kwargs.pop("evm_version", None) + settings.optimize = optimize out = compiler.compile_code( source_code, ["abi", "bytecode"], interface_codes=kwargs.pop("interface_codes", None), - no_optimize=no_optimize, - evm_version=kwargs.pop("evm_version", None), + settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) parse_vyper_source(source_code) # Test grammar. @@ -173,17 +178,17 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, no_optimize): +def deploy_blueprint_for(w3, optimize): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, no_optimize, *args, **kwargs) + return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs) return deploy_blueprint_for @pytest.fixture(scope="module") -def get_contract(w3, no_optimize): +def get_contract(w3, optimize): def get_contract(source_code, *args, **kwargs): - return _get_contract(w3, source_code, no_optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, *args, **kwargs) return get_contract diff --git a/tests/cli/vyper_json/test_get_settings.py b/tests/cli/vyper_json/test_get_settings.py index 7530e85ef8..bbe5dab113 100644 --- a/tests/cli/vyper_json/test_get_settings.py +++ b/tests/cli/vyper_json/test_get_settings.py @@ -3,7 +3,6 @@ import pytest from vyper.cli.vyper_json import get_evm_version -from vyper.evm.opcodes import DEFAULT_EVM_VERSION from vyper.exceptions import JSONError @@ -31,7 +30,3 @@ def test_early_evm(evm_version): @pytest.mark.parametrize("evm_version", ["istanbul", "berlin", "paris", "shanghai", "cancun"]) def test_valid_evm(evm_version): assert evm_version == get_evm_version({"settings": {"evmVersion": evm_version}}) - - -def test_default_evm(): - assert get_evm_version({}) == DEFAULT_EVM_VERSION diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/compiler/asm/test_asm_optimizer.py index f4a245e168..47b70a8c70 100644 --- a/tests/compiler/asm/test_asm_optimizer.py +++ b/tests/compiler/asm/test_asm_optimizer.py @@ -1,6 +1,7 @@ import pytest from vyper.compiler.phases import CompilerData +from vyper.compiler.settings import OptimizationLevel, Settings codes = [ """ @@ -72,7 +73,7 @@ def __init__(): @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): - c = CompilerData(code, no_optimize=True) + c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime @@ -87,7 +88,7 @@ def test_dead_code_eliminator(code): for s in (ctor_only_label, runtime_only_label): assert s + "_runtime" in runtime_asm - c = CompilerData(code, no_optimize=False) + c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) initcode_asm = [i for i in c.assembly if not isinstance(i, list)] runtime_asm = c.assembly_runtime diff --git a/tests/compiler/test_pre_parser.py b/tests/compiler/test_pre_parser.py index 4b747bb7d1..1761e74bad 100644 --- a/tests/compiler/test_pre_parser.py +++ b/tests/compiler/test_pre_parser.py @@ -1,6 +1,8 @@ -from pytest import raises +import pytest -from vyper.exceptions import SyntaxException +from vyper.compiler import compile_code +from vyper.compiler.settings import OptimizationLevel, Settings +from vyper.exceptions import StructureException, SyntaxException def test_semicolon_prohibited(get_contract): @@ -10,7 +12,7 @@ def test() -> int128: return a + b """ - with raises(SyntaxException): + with pytest.raises(SyntaxException): get_contract(code) @@ -70,6 +72,57 @@ def test(): assert get_contract(code) +def test_version_pragma2(get_contract): + # new, `#pragma` way of doing things + from vyper import __version__ + + installed_version = ".".join(__version__.split(".")[:3]) + + code = f""" +#pragma version {installed_version} + +@external +def test(): + pass + """ + assert get_contract(code) + + +def test_evm_version_check(assert_compile_failed): + code = """ +#pragma evm-version berlin + """ + assert compile_code(code, settings=Settings(evm_version=None)) is not None + assert compile_code(code, settings=Settings(evm_version="berlin")) is not None + # should fail if compile options indicate different evm version + # from source pragma + with pytest.raises(StructureException): + compile_code(code, settings=Settings(evm_version="shanghai")) + + +def test_optimization_mode_check(): + code = """ +#pragma optimize codesize + """ + assert compile_code(code, settings=Settings(optimize=None)) + # should fail if compile options indicate different optimization mode + # from source pragma + with pytest.raises(StructureException): + compile_code(code, settings=Settings(optimize=OptimizationLevel.GAS)) + with pytest.raises(StructureException): + compile_code(code, settings=Settings(optimize=OptimizationLevel.NONE)) + + +def test_optimization_mode_check_none(): + code = """ +#pragma optimize none + """ + assert compile_code(code, settings=Settings(optimize=None)) + # "none" conflicts with "gas" + with pytest.raises(StructureException): + compile_code(code, settings=Settings(optimize=OptimizationLevel.GAS)) + + def test_version_empty_version(assert_compile_failed, get_contract): code = """ #@version @@ -110,5 +163,5 @@ def foo(): convert( """ - with raises(SyntaxException): + with pytest.raises(SyntaxException): get_contract(code) diff --git a/tests/conftest.py b/tests/conftest.py index 1cc9e4e72e..9c9c4191b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from vyper import compiler from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel from vyper.ir import compile_ir, optimizer from .base_conftest import VyperContract, _get_contract, zero_gas_price_strategy @@ -36,12 +37,18 @@ def set_evm_verbose_logging(): def pytest_addoption(parser): - parser.addoption("--no-optimize", action="store_true", help="disable asm and IR optimizations") + parser.addoption( + "--optimize", + choices=["codesize", "gas", "none"], + default="gas", + help="change optimization mode", + ) @pytest.fixture(scope="module") -def no_optimize(pytestconfig): - return pytestconfig.getoption("no_optimize") +def optimize(pytestconfig): + flag = pytestconfig.getoption("optimize") + return OptimizationLevel.from_string(flag) @pytest.fixture @@ -58,13 +65,13 @@ def bytes_helper(str, length): @pytest.fixture -def get_contract_from_ir(w3, no_optimize): +def get_contract_from_ir(w3, optimize): def ir_compiler(ir, *args, **kwargs): ir = IRnode.from_list(ir) - if not no_optimize: + if optimize != OptimizationLevel.NONE: ir = optimizer.optimize(ir) bytecode, _ = compile_ir.assembly_to_evm( - compile_ir.compile_to_assembly(ir, no_optimize=no_optimize) + compile_ir.compile_to_assembly(ir, optimize=optimize) ) abi = kwargs.get("abi") or [] c = w3.eth.contract(abi=abi, bytecode=bytecode) @@ -80,7 +87,7 @@ def ir_compiler(ir, *args, **kwargs): @pytest.fixture(scope="module") -def get_contract_module(no_optimize): +def get_contract_module(optimize): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. @@ -93,7 +100,7 @@ def get_contract_module(no_optimize): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, no_optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, *args, **kwargs) return get_contract_module @@ -138,9 +145,9 @@ def set_decorator_to_contract_function(w3, tester, contract, source_code, func): @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, no_optimize): +def get_contract_with_gas_estimation(tester, w3, optimize): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, no_optimize, *args, **kwargs) + contract = _get_contract(w3, source_code, optimize, *args, **kwargs) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -150,9 +157,9 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, no_optimize): +def get_contract_with_gas_estimation_for_constants(w3, optimize): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, no_optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, *args, **kwargs) return get_contract_with_gas_estimation_for_constants diff --git a/tests/examples/factory/test_factory.py b/tests/examples/factory/test_factory.py index 15becc05f1..0c5cf61b04 100644 --- a/tests/examples/factory/test_factory.py +++ b/tests/examples/factory/test_factory.py @@ -2,6 +2,7 @@ from eth_utils import keccak import vyper +from vyper.compiler.settings import Settings @pytest.fixture @@ -30,12 +31,12 @@ def create_exchange(token, factory): @pytest.fixture -def factory(get_contract, no_optimize): +def factory(get_contract, optimize): with open("examples/factory/Exchange.vy") as f: code = f.read() exchange_interface = vyper.compile_code( - code, output_formats=["bytecode_runtime"], no_optimize=no_optimize + code, output_formats=["bytecode_runtime"], settings=Settings(optimize=optimize) ) exchange_deployed_bytecode = exchange_interface["bytecode_runtime"] diff --git a/tests/grammar/test_grammar.py b/tests/grammar/test_grammar.py index 7e220b58ae..d665ca2544 100644 --- a/tests/grammar/test_grammar.py +++ b/tests/grammar/test_grammar.py @@ -106,5 +106,6 @@ def has_no_docstrings(c): @hypothesis.settings(deadline=400, max_examples=500, suppress_health_check=(HealthCheck.too_slow,)) def test_grammar_bruteforce(code): if utf8_encodable(code): - tree = parse_to_ast(pre_parse(code + "\n")[1]) + _, _, reformatted_code = pre_parse(code + "\n") + tree = parse_to_ast(reformatted_code) assert isinstance(tree, Module) diff --git a/tests/parser/features/test_immutable.py b/tests/parser/features/test_immutable.py index 7300d0f2d9..47f7fc748e 100644 --- a/tests/parser/features/test_immutable.py +++ b/tests/parser/features/test_immutable.py @@ -1,5 +1,7 @@ import pytest +from vyper.compiler.settings import OptimizationLevel + @pytest.mark.parametrize( "typ,value", @@ -269,7 +271,7 @@ def __init__(to_copy: address): # GH issue 3101, take 2 def test_immutables_initialized2(get_contract, get_contract_from_ir): dummy_contract = get_contract_from_ir( - ["deploy", 0, ["seq"] + ["invalid"] * 600, 0], no_optimize=True + ["deploy", 0, ["seq"] + ["invalid"] * 600, 0], optimize=OptimizationLevel.NONE ) # rekt because immutables section extends past allocated memory diff --git a/tests/parser/features/test_transient.py b/tests/parser/features/test_transient.py index 53354beca8..718f5ae314 100644 --- a/tests/parser/features/test_transient.py +++ b/tests/parser/features/test_transient.py @@ -1,6 +1,7 @@ import pytest from vyper.compiler import compile_code +from vyper.compiler.settings import Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import StructureException @@ -13,20 +14,22 @@ def test_transient_blocked(evm_version): code = """ my_map: transient(HashMap[address, uint256]) """ + settings = Settings(evm_version=evm_version) if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["cancun"]: - assert compile_code(code, evm_version=evm_version) is not None + assert compile_code(code, settings=settings) is not None else: with pytest.raises(StructureException): - compile_code(code, evm_version=evm_version) + compile_code(code, settings=settings) @pytest.mark.parametrize("evm_version", list(post_cancun.keys())) def test_transient_compiles(evm_version): # test transient keyword at least generates TLOAD/TSTORE opcodes + settings = Settings(evm_version=evm_version) getter_code = """ my_map: public(transient(HashMap[address, uint256])) """ - t = compile_code(getter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]) + t = compile_code(getter_code, settings=settings, output_formats=["opcodes_runtime"]) t = t["opcodes_runtime"].split(" ") assert "TLOAD" in t @@ -39,7 +42,7 @@ def test_transient_compiles(evm_version): def setter(k: address, v: uint256): self.my_map[k] = v """ - t = compile_code(setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]) + t = compile_code(setter_code, settings=settings, output_formats=["opcodes_runtime"]) t = t["opcodes_runtime"].split(" ") assert "TLOAD" not in t @@ -52,9 +55,7 @@ def setter(k: address, v: uint256): def setter(k: address, v: uint256): self.my_map[k] = v """ - t = compile_code( - getter_setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"] - ) + t = compile_code(getter_setter_code, settings=settings, output_formats=["opcodes_runtime"]) t = t["opcodes_runtime"].split(" ") assert "TLOAD" in t diff --git a/tests/parser/functions/test_bitwise.py b/tests/parser/functions/test_bitwise.py index 3e18bd292c..3ba74034ac 100644 --- a/tests/parser/functions/test_bitwise.py +++ b/tests/parser/functions/test_bitwise.py @@ -1,7 +1,6 @@ import pytest from vyper.compiler import compile_code -from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import InvalidLiteral, InvalidOperation, TypeMismatch from vyper.utils import unsigned_to_signed @@ -32,16 +31,14 @@ def _shr(x: uint256, y: uint256) -> uint256: """ -@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) -def test_bitwise_opcodes(evm_version): - opcodes = compile_code(code, ["opcodes"], evm_version=evm_version)["opcodes"] +def test_bitwise_opcodes(): + opcodes = compile_code(code, ["opcodes"])["opcodes"] assert "SHL" in opcodes assert "SHR" in opcodes -@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) -def test_test_bitwise(get_contract_with_gas_estimation, evm_version): - c = get_contract_with_gas_estimation(code, evm_version=evm_version) +def test_test_bitwise(get_contract_with_gas_estimation): + c = get_contract_with_gas_estimation(code) x = 126416208461208640982146408124 y = 7128468721412412459 assert c._bitwise_and(x, y) == (x & y) @@ -55,8 +52,7 @@ def test_test_bitwise(get_contract_with_gas_estimation, evm_version): assert c._shl(t, s) == (t << s) % (2**256) -@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS.keys())) -def test_signed_shift(get_contract_with_gas_estimation, evm_version): +def test_signed_shift(get_contract_with_gas_estimation): code = """ @external def _sar(x: int256, y: uint256) -> int256: @@ -66,7 +62,7 @@ def _sar(x: int256, y: uint256) -> int256: def _shl(x: int256, y: uint256) -> int256: return x << y """ - c = get_contract_with_gas_estimation(code, evm_version=evm_version) + c = get_contract_with_gas_estimation(code) x = 126416208461208640982146408124 y = 7128468721412412459 cases = [x, y, -x, -y] @@ -97,8 +93,7 @@ def baz(a: uint256, b: uint256, c: uint256) -> (uint256, uint256): assert tuple(c.baz(1, 6, 14)) == (1 + 8 | ~6 & 14 * 2, (1 + 8 | ~6) & 14 * 2) == (25, 24) -@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) -def test_literals(get_contract, evm_version): +def test_literals(get_contract): code = """ @external def _shr(x: uint256) -> uint256: @@ -109,7 +104,7 @@ def _shl(x: uint256) -> uint256: return x << 3 """ - c = get_contract(code, evm_version=evm_version) + c = get_contract(code) assert c._shr(80) == 10 assert c._shl(80) == 640 diff --git a/tests/parser/functions/test_create_functions.py b/tests/parser/functions/test_create_functions.py index 64e0823146..876d50b27d 100644 --- a/tests/parser/functions/test_create_functions.py +++ b/tests/parser/functions/test_create_functions.py @@ -5,6 +5,7 @@ import vyper.ir.compile_ir as compile_ir from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel from vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256 @@ -232,7 +233,9 @@ def test(code_ofst: uint256) -> address: # zeroes (so no matter which offset, create_from_blueprint will # return empty code) ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0]) - bytecode, _ = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir, no_optimize=True)) + bytecode, _ = compile_ir.assembly_to_evm( + compile_ir.compile_to_assembly(ir, optimize=OptimizationLevel.NONE) + ) # manually deploy the bytecode c = w3.eth.contract(abi=[], bytecode=bytecode) deploy_transaction = c.constructor() diff --git a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py b/tests/parser/parser_utils/test_annotate_and_optimize_ast.py index 6f2246c6c0..68a07178bb 100644 --- a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py +++ b/tests/parser/parser_utils/test_annotate_and_optimize_ast.py @@ -29,7 +29,7 @@ def foo() -> int128: def get_contract_info(source_code): - class_types, reformatted_code = pre_parse(source_code) + _, class_types, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) annotate_python_ast(py_ast, reformatted_code, class_types) diff --git a/tests/parser/syntax/test_address_code.py b/tests/parser/syntax/test_address_code.py index 25fe1be0b4..70ba5cbbf7 100644 --- a/tests/parser/syntax/test_address_code.py +++ b/tests/parser/syntax/test_address_code.py @@ -5,6 +5,7 @@ from web3 import Web3 from vyper import compiler +from vyper.compiler.settings import Settings from vyper.exceptions import NamespaceCollision, StructureException, VyperException # For reproducibility, use precompiled data of `hello: public(uint256)` using vyper 0.3.1 @@ -161,7 +162,7 @@ def test_address_code_compile_success(code: str): compiler.compile_code(code) -def test_address_code_self_success(get_contract, no_optimize: bool): +def test_address_code_self_success(get_contract, optimize): code = """ code_deployment: public(Bytes[32]) @@ -174,8 +175,9 @@ def code_runtime() -> Bytes[32]: return slice(self.code, 0, 32) """ contract = get_contract(code) + settings = Settings(optimize=optimize) code_compiled = compiler.compile_code( - code, output_formats=["bytecode", "bytecode_runtime"], no_optimize=no_optimize + code, output_formats=["bytecode", "bytecode_runtime"], settings=settings ) assert contract.code_deployment() == bytes.fromhex(code_compiled["bytecode"][2:])[:32] assert contract.code_runtime() == bytes.fromhex(code_compiled["bytecode_runtime"][2:])[:32] diff --git a/tests/parser/syntax/test_chainid.py b/tests/parser/syntax/test_chainid.py index be960f2fc5..2b6e08cbc4 100644 --- a/tests/parser/syntax/test_chainid.py +++ b/tests/parser/syntax/test_chainid.py @@ -1,6 +1,7 @@ import pytest from vyper import compiler +from vyper.compiler.settings import Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import InvalidType, TypeMismatch @@ -12,8 +13,9 @@ def test_evm_version(evm_version): def foo(): a: uint256 = chain.id """ + settings = Settings(evm_version=evm_version) - assert compiler.compile_code(code, evm_version=evm_version) is not None + assert compiler.compile_code(code, settings=settings) is not None fail_list = [ diff --git a/tests/parser/syntax/test_codehash.py b/tests/parser/syntax/test_codehash.py index e4b6d90d8d..5074d14636 100644 --- a/tests/parser/syntax/test_codehash.py +++ b/tests/parser/syntax/test_codehash.py @@ -1,12 +1,13 @@ import pytest from vyper.compiler import compile_code +from vyper.compiler.settings import Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.utils import keccak256 @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) -def test_get_extcodehash(get_contract, evm_version, no_optimize): +def test_get_extcodehash(get_contract, evm_version, optimize): code = """ a: address @@ -31,9 +32,8 @@ def foo3() -> bytes32: def foo4() -> bytes32: return self.a.codehash """ - compiled = compile_code( - code, ["bytecode_runtime"], evm_version=evm_version, no_optimize=no_optimize - ) + settings = Settings(evm_version=evm_version, optimize=optimize) + compiled = compile_code(code, ["bytecode_runtime"], settings=settings) bytecode = bytes.fromhex(compiled["bytecode_runtime"][2:]) hash_ = keccak256(bytecode) diff --git a/tests/parser/syntax/test_self_balance.py b/tests/parser/syntax/test_self_balance.py index 949cdde324..63db58e347 100644 --- a/tests/parser/syntax/test_self_balance.py +++ b/tests/parser/syntax/test_self_balance.py @@ -1,6 +1,7 @@ import pytest from vyper import compiler +from vyper.compiler.settings import Settings from vyper.evm.opcodes import EVM_VERSIONS @@ -18,7 +19,8 @@ def get_balance() -> uint256: def __default__(): pass """ - opcodes = compiler.compile_code(code, ["opcodes"], evm_version=evm_version)["opcodes"] + settings = Settings(evm_version=evm_version) + opcodes = compiler.compile_code(code, ["opcodes"], settings=settings)["opcodes"] if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["istanbul"]: assert "SELFBALANCE" in opcodes else: diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index cb55c42870..cbae183fe4 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -2,6 +2,7 @@ import pytest +from vyper.compiler.settings import OptimizationLevel from vyper.exceptions import ( ArgumentException, ArrayIndexException, @@ -1543,7 +1544,7 @@ def bar(x: int128) -> DynArray[int128, 3]: assert c.bar(7) == [7, 14] -def test_nested_struct_of_lists(get_contract, assert_compile_failed, no_optimize): +def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize): code = """ struct nestedFoo: a1: DynArray[DynArray[DynArray[uint256, 2], 2], 2] @@ -1585,7 +1586,7 @@ def bar2() -> uint256: newFoo.b1[0][1][0].a1[0][0][0] """ - if no_optimize: + if optimize == OptimizationLevel.NONE: # fails at assembly stage with too many stack variables assert_compile_failed(lambda: get_contract(code), Exception) else: diff --git a/tox.ini b/tox.ini index 5ddd01d7d4..9b63630f58 100644 --- a/tox.ini +++ b/tox.ini @@ -9,7 +9,8 @@ envlist = usedevelop = True commands = core: pytest -m "not fuzzing" --showlocals {posargs:tests/} - no-opt: pytest -m "not fuzzing" --showlocals --no-optimize {posargs:tests/} + no-opt: pytest -m "not fuzzing" --showlocals --optimize none {posargs:tests/} + codesize: pytest -m "not fuzzing" --showlocals --optimize codesize {posargs:tests/} basepython = py310: python3.10 py311: python3.11 diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index 5695ceab7c..e5b81f1e7f 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -6,7 +6,7 @@ from . import nodes, validation from .natspec import parse_natspec from .nodes import compare_nodes -from .utils import ast_to_dict, parse_to_ast +from .utils import ast_to_dict, parse_to_ast, parse_to_ast_with_settings # adds vyper.ast.nodes classes into the local namespace for name, obj in ( diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 3d83ae7506..0d59a2fa63 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -4,6 +4,7 @@ from typing import Any, Optional, Sequence, Type, Union from .natspec import parse_natspec as parse_natspec from .utils import ast_to_dict as ast_to_dict from .utils import parse_to_ast as parse_to_ast +from .utils import parse_to_ast_with_settings as parse_to_ast_with_settings NODE_BASE_ATTRIBUTES: Any NODE_SRC_ATTRIBUTES: Any diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index f29150a5d3..35153af9d5 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,11 +1,15 @@ import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize -from typing import Tuple from semantic_version import NpmSpec, Version -from vyper.exceptions import SyntaxException, VersionException +from vyper.compiler.settings import OptimizationLevel, Settings + +# seems a bit early to be importing this but we want it to validate the +# evm-version pragma +from vyper.evm.opcodes import EVM_VERSIONS +from vyper.exceptions import StructureException, SyntaxException, VersionException from vyper.typing import ModificationOffsets, ParserPosition VERSION_ALPHA_RE = re.compile(r"(?<=\d)a(?=\d)") # 0.1.0a17 @@ -33,10 +37,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: # NOTE: should be `x.y.z.*` installed_version = ".".join(__version__.split(".")[:3]) - version_arr = version_str.split("@version") - - raw_file_version = version_arr[1].strip() - strict_file_version = _convert_version_str(raw_file_version) + strict_file_version = _convert_version_str(version_str) strict_compiler_version = Version(_convert_version_str(installed_version)) if len(strict_file_version) == 0: @@ -46,14 +47,14 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: npm_spec = NpmSpec(strict_file_version) except ValueError: raise VersionException( - f'Version specification "{raw_file_version}" is not a valid NPM semantic ' + f'Version specification "{version_str}" is not a valid NPM semantic ' f"version specification", start, ) if not npm_spec.match(strict_compiler_version): raise VersionException( - f'Version specification "{raw_file_version}" is not compatible ' + f'Version specification "{version_str}" is not compatible ' f'with compiler version "{installed_version}"', start, ) @@ -66,7 +67,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> Tuple[ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -93,6 +94,7 @@ def pre_parse(code: str) -> Tuple[ModificationOffsets, str]: """ result = [] modification_offsets: ModificationOffsets = {} + settings = Settings() try: code_bytes = code.encode("utf-8") @@ -108,8 +110,39 @@ def pre_parse(code: str) -> Tuple[ModificationOffsets, str]: end = token.end line = token.line - if typ == COMMENT and "@version" in string: - validate_version_pragma(string[1:], start) + if typ == COMMENT: + contents = string[1:].strip() + if contents.startswith("@version"): + if settings.compiler_version is not None: + raise StructureException("compiler version specified twice!", start) + compiler_version = contents.removeprefix("@version ").strip() + validate_version_pragma(compiler_version, start) + settings.compiler_version = compiler_version + + if string.startswith("#pragma "): + pragma = string.removeprefix("#pragma").strip() + if pragma.startswith("version "): + if settings.compiler_version is not None: + raise StructureException("pragma version specified twice!", start) + compiler_version = pragma.removeprefix("version ".strip()) + validate_version_pragma(compiler_version, start) + settings.compiler_version = compiler_version + + if pragma.startswith("optimize "): + if settings.optimize is not None: + raise StructureException("pragma optimize specified twice!", start) + try: + mode = pragma.removeprefix("optimize").strip() + settings.optimize = OptimizationLevel.from_string(mode) + except ValueError: + raise StructureException(f"Invalid optimization mode `{mode}`", start) + if pragma.startswith("evm-version "): + if settings.evm_version is not None: + raise StructureException("pragma evm-version specified twice!", start) + evm_version = pragma.removeprefix("evm-version").strip() + if evm_version not in EVM_VERSIONS: + raise StructureException("Invalid evm version: `{evm_version}`", start) + settings.evm_version = evm_version if typ == NAME and string in ("class", "yield"): raise SyntaxException( @@ -130,4 +163,4 @@ def pre_parse(code: str) -> Tuple[ModificationOffsets, str]: except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - return modification_offsets, untokenize(result).decode("utf-8") + return settings, modification_offsets, untokenize(result).decode("utf-8") diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py index fc8aad227c..4e669385ab 100644 --- a/vyper/ast/utils.py +++ b/vyper/ast/utils.py @@ -1,18 +1,23 @@ import ast as python_ast -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from vyper.ast import nodes as vy_ast from vyper.ast.annotation import annotate_python_ast from vyper.ast.pre_parser import pre_parse +from vyper.compiler.settings import Settings from vyper.exceptions import CompilerPanic, ParserException, SyntaxException -def parse_to_ast( +def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: + return parse_to_ast_with_settings(*args, **kwargs)[1] + + +def parse_to_ast_with_settings( source_code: str, source_id: int = 0, contract_name: Optional[str] = None, add_fn_node: Optional[str] = None, -) -> vy_ast.Module: +) -> tuple[Settings, vy_ast.Module]: """ Parses a Vyper source string and generates basic Vyper AST nodes. @@ -34,7 +39,7 @@ def parse_to_ast( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - class_types, reformatted_code = pre_parse(source_code) + settings, class_types, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) except SyntaxError as e: @@ -51,7 +56,9 @@ def parse_to_ast( annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name) # Convert to Vyper AST. - return vy_ast.get_node(py_ast) # type: ignore + module = vy_ast.get_node(py_ast) + assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index f5e113116d..71e78dd666 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -5,13 +5,13 @@ import warnings from collections import OrderedDict from pathlib import Path -from typing import Dict, Iterable, Iterator, Set, TypeVar +from typing import Dict, Iterable, Iterator, Optional, Set, TypeVar import vyper import vyper.codegen.ir_node as ir_node from vyper.cli import vyper_json from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path -from vyper.compiler.settings import VYPER_TRACEBACK_LIMIT +from vyper.compiler.settings import VYPER_TRACEBACK_LIMIT, OptimizationLevel, Settings from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS from vyper.typing import ContractCodes, ContractPath, OutputFormats @@ -37,8 +37,6 @@ ir - Intermediate representation in list format ir_json - Intermediate representation in JSON format hex-ir - Output IR and assembly constants in hex instead of decimal -no-optimize - Do not optimize (don't use this for production code) -no-bytecode-metadata - Do not add metadata to bytecode """ combined_json_outputs = [ @@ -104,10 +102,10 @@ def _parse_args(argv): help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION}). " "note: cancun support is EXPERIMENTAL", choices=list(EVM_VERSIONS), - default=DEFAULT_EVM_VERSION, dest="evm_version", ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") + parser.add_argument("--optimize", help="Optimization flag", choices=["gas", "codesize"]) parser.add_argument( "--no-bytecode-metadata", help="Do not add metadata to bytecode", action="store_true" ) @@ -153,13 +151,28 @@ def _parse_args(argv): output_formats = tuple(uniq(args.format.split(","))) + if args.no_optimize and args.optimize: + raise ValueError("Cannot use `--no-optimize` and `--optimize` at the same time!") + + settings = Settings() + + if args.no_optimize: + settings.optimize = OptimizationLevel.NONE + elif args.optimize is not None: + settings.optimize = OptimizationLevel.from_string(args.optimize) + + if args.evm_version: + settings.evm_version = args.evm_version + + if args.verbose: + print(f"using `{settings}`", file=sys.stderr) + compiled = compile_files( args.input_files, output_formats, args.root_folder, args.show_gas_estimates, - args.evm_version, - args.no_optimize, + settings, args.storage_layout, args.no_bytecode_metadata, ) @@ -253,9 +266,8 @@ def compile_files( output_formats: OutputFormats, root_folder: str = ".", show_gas_estimates: bool = False, - evm_version: str = DEFAULT_EVM_VERSION, - no_optimize: bool = False, - storage_layout: Iterable[str] = None, + settings: Optional[Settings] = None, + storage_layout: Optional[Iterable[str]] = None, no_bytecode_metadata: bool = False, ) -> OrderedDict: root_path = Path(root_folder).resolve() @@ -296,8 +308,7 @@ def compile_files( final_formats, exc_handler=exc_handler, interface_codes=get_interface_codes(root_path, contract_sources), - evm_version=evm_version, - no_optimize=no_optimize, + settings=settings, storage_layouts=storage_layouts, show_gas_estimates=show_gas_estimates, no_bytecode_metadata=no_bytecode_metadata, diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 9778848aa2..4a1c91550e 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -5,11 +5,12 @@ import sys import warnings from pathlib import Path -from typing import Any, Callable, Dict, Hashable, List, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union import vyper from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path -from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS +from vyper.compiler.settings import OptimizationLevel, Settings +from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError from vyper.typing import ContractCodes, ContractPath from vyper.utils import keccak256 @@ -144,11 +145,15 @@ def _standardize_path(path_str: str) -> str: return path.as_posix() -def get_evm_version(input_dict: Dict) -> str: +def get_evm_version(input_dict: Dict) -> Optional[str]: if "settings" not in input_dict: - return DEFAULT_EVM_VERSION + return None + + # TODO: move this validation somewhere it can be reused more easily + evm_version = input_dict["settings"].get("evmVersion") + if evm_version is None: + return None - evm_version = input_dict["settings"].get("evmVersion", DEFAULT_EVM_VERSION) if evm_version in ( "homestead", "tangerineWhistle", @@ -360,7 +365,21 @@ def compile_from_input_dict( raise JSONError(f"Invalid language '{input_dict['language']}' - Only Vyper is supported.") evm_version = get_evm_version(input_dict) - no_optimize = not input_dict["settings"].get("optimize", True) + + optimize = input_dict["settings"].get("optimize") + if isinstance(optimize, bool): + # bool optimization level for backwards compatibility + warnings.warn( + "optimize: is deprecated! please use one of 'gas', 'codesize', 'none'." + ) + optimize = OptimizationLevel.default() if optimize else OptimizationLevel.NONE + elif isinstance(optimize, str): + optimize = OptimizationLevel.from_string(optimize) + else: + assert optimize is None + + settings = Settings(evm_version=evm_version, optimize=optimize) + no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) contract_sources: ContractCodes = get_input_dict_contracts(input_dict) @@ -383,8 +402,7 @@ def compile_from_input_dict( output_formats[contract_path], interface_codes=interface_codes, initial_id=id_, - no_optimize=no_optimize, - evm_version=evm_version, + settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) except Exception as exc: diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 7be45ce832..0b3c0d8191 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -5,7 +5,8 @@ import vyper.codegen.core as codegen import vyper.compiler.output as output from vyper.compiler.phases import CompilerData -from vyper.evm.opcodes import DEFAULT_EVM_VERSION, evm_wrapper +from vyper.compiler.settings import Settings +from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version from vyper.typing import ( ContractCodes, ContractPath, @@ -46,15 +47,14 @@ } -@evm_wrapper def compile_codes( contract_sources: ContractCodes, output_formats: Union[OutputDict, OutputFormats, None] = None, exc_handler: Union[Callable, None] = None, interface_codes: Union[InterfaceDict, InterfaceImports, None] = None, initial_id: int = 0, - no_optimize: bool = False, - storage_layouts: Dict[ContractPath, StorageLayout] = None, + settings: Settings = None, + storage_layouts: Optional[dict[ContractPath, Optional[StorageLayout]]] = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, ) -> OrderedDict: @@ -73,11 +73,8 @@ def compile_codes( two arguments - the name of the contract, and the exception that was raised initial_id: int, optional The lowest source ID value to be used when generating the source map. - evm_version: str, optional - The target EVM ruleset to compile for. If not given, defaults to the latest - implemented ruleset. - no_optimize: bool, optional - Turn off optimizations. Defaults to False + settings: Settings, optional + Compiler settings show_gas_estimates: bool, optional Show gas estimates for abi and ir output modes interface_codes: Dict, optional @@ -98,6 +95,7 @@ def compile_codes( Dict Compiler output as `{'contract name': {'output key': "output data"}}` """ + settings = settings or Settings() if output_formats is None: output_formats = ("bytecode",) @@ -121,27 +119,30 @@ def compile_codes( # make IR output the same between runs codegen.reset_names() - compiler_data = CompilerData( - source_code, - contract_name, - interfaces, - source_id, - no_optimize, - storage_layout_override, - show_gas_estimates, - no_bytecode_metadata, - ) - for output_format in output_formats[contract_name]: - if output_format not in OUTPUT_FORMATS: - raise ValueError(f"Unsupported format type {repr(output_format)}") - try: - out.setdefault(contract_name, {}) - out[contract_name][output_format] = OUTPUT_FORMATS[output_format](compiler_data) - except Exception as exc: - if exc_handler is not None: - exc_handler(contract_name, exc) - else: - raise exc + + with anchor_evm_version(settings.evm_version): + compiler_data = CompilerData( + source_code, + contract_name, + interfaces, + source_id, + settings, + storage_layout_override, + show_gas_estimates, + no_bytecode_metadata, + ) + for output_format in output_formats[contract_name]: + if output_format not in OUTPUT_FORMATS: + raise ValueError(f"Unsupported format type {repr(output_format)}") + try: + out.setdefault(contract_name, {}) + formatter = OUTPUT_FORMATS[output_format] + out[contract_name][output_format] = formatter(compiler_data) + except Exception as exc: + if exc_handler is not None: + exc_handler(contract_name, exc) + else: + raise exc return out @@ -153,9 +154,8 @@ def compile_code( contract_source: str, output_formats: Optional[OutputFormats] = None, interface_codes: Optional[InterfaceImports] = None, - evm_version: str = DEFAULT_EVM_VERSION, - no_optimize: bool = False, - storage_layout_override: StorageLayout = None, + settings: Settings = None, + storage_layout_override: Optional[StorageLayout] = None, show_gas_estimates: bool = False, ) -> dict: """ @@ -171,8 +171,8 @@ def compile_code( evm_version: str, optional The target EVM ruleset to compile for. If not given, defaults to the latest implemented ruleset. - no_optimize: bool, optional - Turn off optimizations. Defaults to False + settings: Settings, optional + Compiler settings. show_gas_estimates: bool, optional Show gas estimates for abi and ir output modes interface_codes: Dict, optional @@ -194,8 +194,7 @@ def compile_code( contract_sources, output_formats, interface_codes=interface_codes, - evm_version=evm_version, - no_optimize=no_optimize, + settings=settings, storage_layouts=storage_layouts, show_gas_estimates=show_gas_estimates, )[UNKNOWN_CONTRACT_NAME] diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index c759f6e272..99465809bd 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -7,6 +7,8 @@ from vyper.codegen import module from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel, Settings +from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT @@ -49,7 +51,7 @@ def __init__( contract_name: str = "VyperContract", interface_codes: Optional[InterfaceImports] = None, source_id: int = 0, - no_optimize: bool = False, + settings: Settings = None, storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, @@ -69,8 +71,8 @@ def __init__( * JSON interfaces are given as lists, vyper interfaces as strings source_id : int, optional ID number used to identify this contract in the source map. - no_optimize: bool, optional - Turn off optimizations. Defaults to False + settings: Settings + Set optimization mode. show_gas_estimates: bool, optional Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional @@ -80,14 +82,45 @@ def __init__( self.source_code = source_code self.interface_codes = interface_codes self.source_id = source_id - self.no_optimize = no_optimize self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata + self.settings = settings or Settings() @cached_property - def vyper_module(self) -> vy_ast.Module: - return generate_ast(self.source_code, self.source_id, self.contract_name) + def _generate_ast(self): + settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name) + # validate the compiler settings + # XXX: this is a bit ugly, clean up later + if settings.evm_version is not None: + if ( + self.settings.evm_version is not None + and self.settings.evm_version != settings.evm_version + ): + raise StructureException( + f"compiler settings indicate evm version {self.settings.evm_version}, " + f"but source pragma indicates {settings.evm_version}." + ) + + self.settings.evm_version = settings.evm_version + + if settings.optimize is not None: + if self.settings.optimize is not None and self.settings.optimize != settings.optimize: + raise StructureException( + f"compiler options indicate optimization mode {self.settings.optimize}, " + f"but source pragma indicates {settings.optimize}." + ) + self.settings.optimize = settings.optimize + + # ensure defaults + if self.settings.optimize is None: + self.settings.optimize = OptimizationLevel.default() + + return ast + + @cached_property + def vyper_module(self): + return self._generate_ast @cached_property def vyper_module_unfolded(self) -> vy_ast.Module: @@ -119,7 +152,7 @@ def global_ctx(self) -> GlobalContext: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes(self.global_ctx, self.no_optimize) + return generate_ir_nodes(self.global_ctx, self.settings.optimize) @property def ir_nodes(self) -> IRnode: @@ -142,11 +175,11 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - return generate_assembly(self.ir_nodes, self.no_optimize) + return generate_assembly(self.ir_nodes, self.settings.optimize) @cached_property def assembly_runtime(self) -> list: - return generate_assembly(self.ir_runtime, self.no_optimize) + return generate_assembly(self.ir_runtime, self.settings.optimize) @cached_property def bytecode(self) -> bytes: @@ -169,7 +202,9 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_ast(source_code: str, source_id: int, contract_name: str) -> vy_ast.Module: +def generate_ast( + source_code: str, source_id: int, contract_name: str +) -> tuple[Settings, vy_ast.Module]: """ Generate a Vyper AST from source code. @@ -187,7 +222,7 @@ def generate_ast(source_code: str, source_id: int, contract_name: str) -> vy_ast vy_ast.Module Top-level Vyper AST node """ - return vy_ast.parse_to_ast(source_code, source_id, contract_name) + return vy_ast.parse_to_ast_with_settings(source_code, source_id, contract_name) def generate_unfolded_ast( @@ -233,7 +268,7 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes(global_ctx: GlobalContext, no_optimize: bool) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: GlobalContext, optimize: bool) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. @@ -254,13 +289,13 @@ def generate_ir_nodes(global_ctx: GlobalContext, no_optimize: bool) -> tuple[IRn IR to generate runtime bytecode """ ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) - if not no_optimize: + if optimize != OptimizationLevel.NONE: ir_nodes = optimizer.optimize(ir_nodes) ir_runtime = optimizer.optimize(ir_runtime) return ir_nodes, ir_runtime -def generate_assembly(ir_nodes: IRnode, no_optimize: bool = False) -> list: +def generate_assembly(ir_nodes: IRnode, optimize: Optional[OptimizationLevel] = None) -> list: """ Generate assembly instructions from IR. @@ -274,7 +309,8 @@ def generate_assembly(ir_nodes: IRnode, no_optimize: bool = False) -> list: list List of assembly instructions. """ - assembly = compile_ir.compile_to_assembly(ir_nodes, no_optimize=no_optimize) + optimize = optimize or OptimizationLevel.default() + assembly = compile_ir.compile_to_assembly(ir_nodes, optimize=optimize) if _find_nested_opcode(assembly, "DEBUG"): warnings.warn( diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index 09ced0dcb8..bb5e9cdc25 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -1,4 +1,6 @@ import os +from dataclasses import dataclass +from enum import Enum from typing import Optional VYPER_COLOR_OUTPUT = os.environ.get("VYPER_COLOR_OUTPUT", "0") == "1" @@ -12,3 +14,31 @@ VYPER_TRACEBACK_LIMIT = int(_tb_limit_str) else: VYPER_TRACEBACK_LIMIT = None + + +class OptimizationLevel(Enum): + NONE = 1 + GAS = 2 + CODESIZE = 3 + + @classmethod + def from_string(cls, val): + match val: + case "none": + return cls.NONE + case "gas": + return cls.GAS + case "codesize": + return cls.CODESIZE + raise ValueError(f"unrecognized optimization level: {val}") + + @classmethod + def default(cls): + return cls.GAS + + +@dataclass +class Settings: + compiler_version: Optional[str] = None + optimize: Optional[OptimizationLevel] = None + evm_version: Optional[str] = None diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 7550d047b5..4fec13e897 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -1,4 +1,5 @@ -from typing import Dict, Optional +import contextlib +from typing import Dict, Generator, Optional from vyper.exceptions import CompilerPanic from vyper.typing import OpcodeGasCost, OpcodeMap, OpcodeRulesetMap, OpcodeRulesetValue, OpcodeValue @@ -206,17 +207,16 @@ IR_OPCODES: OpcodeMap = {**OPCODES, **PSEUDO_OPCODES} -def evm_wrapper(fn, *args, **kwargs): - def _wrapper(*args, **kwargs): - global active_evm_version - evm_version = kwargs.pop("evm_version", None) or DEFAULT_EVM_VERSION - active_evm_version = EVM_VERSIONS[evm_version] - try: - return fn(*args, **kwargs) - finally: - active_evm_version = EVM_VERSIONS[DEFAULT_EVM_VERSION] - - return _wrapper +@contextlib.contextmanager +def anchor_evm_version(evm_version: Optional[str] = None) -> Generator: + global active_evm_version + anchor = active_evm_version + evm_version = evm_version or DEFAULT_EVM_VERSION + active_evm_version = EVM_VERSIONS[evm_version] + try: + yield + finally: + active_evm_version = anchor def _gas(value: OpcodeValue, idx: int) -> Optional[OpcodeRulesetValue]: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 5a35b8f932..15a68a5079 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -3,6 +3,7 @@ import math from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic from vyper.utils import MemoryPositions @@ -201,7 +202,7 @@ def apply_line_no_wrapper(*args, **kwargs): @apply_line_numbers -def compile_to_assembly(code, no_optimize=False): +def compile_to_assembly(code, optimize=OptimizationLevel.GAS): global _revert_label _revert_label = mksymbol("revert") @@ -212,7 +213,7 @@ def compile_to_assembly(code, no_optimize=False): res = _compile_to_assembly(code) _add_postambles(res) - if not no_optimize: + if optimize != OptimizationLevel.NONE: _optimize_assembly(res) return res From 5dc3ac7ec700d85886eda3d53a03abcf5c7efc9c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 15 Jul 2023 11:20:53 -0400 Subject: [PATCH 043/201] feat: improve batch copy performance (#3483) per cancun, eip-5656, this commit adds the use of mcopy for memory copies. it also - adds heuristics to use loops vs unrolled loops for batch copies. - adds helper functions `vyper.codegen.core._opt_[gas,codesize,none]()` to detect optimization mode during codegen - adds `--optimize none` to CLI options, with the intent of phasing out `--no-optimize` if the ergonomics are better. --- .github/workflows/era-tester.yml | 4 +- setup.cfg | 1 - tests/compiler/test_opcodes.py | 7 +- tests/parser/functions/test_slice.py | 89 ++++++++-------- tests/parser/types/test_dynamic_array.py | 12 +-- vyper/cli/vyper_compile.py | 2 +- vyper/codegen/core.py | 128 ++++++++++++++++++++--- vyper/codegen/ir_node.py | 16 +-- vyper/compiler/phases.py | 8 +- vyper/evm/opcodes.py | 5 +- vyper/ir/compile_ir.py | 1 + vyper/ir/optimizer.py | 44 +++++--- vyper/utils.py | 3 +- 13 files changed, 221 insertions(+), 99 deletions(-) diff --git a/.github/workflows/era-tester.yml b/.github/workflows/era-tester.yml index 8a2a3e50ce..187b5c03a2 100644 --- a/.github/workflows/era-tester.yml +++ b/.github/workflows/era-tester.yml @@ -101,11 +101,11 @@ jobs: if: ${{ github.ref != 'refs/heads/master' }} run: | cd era-compiler-tester - cargo run --release --bin compiler-tester -- -v --path=tests/vyper/ --mode="M0B0 ${{ env.VYPER_VERSION }}" + cargo run --release --bin compiler-tester -- --path=tests/vyper/ --mode="M0B0 ${{ env.VYPER_VERSION }}" - name: Run tester (slow) # Run era tester across the LLVM optimization matrix if: ${{ github.ref == 'refs/heads/master' }} run: | cd era-compiler-tester - cargo run --release --bin compiler-tester -- -v --path=tests/vyper/ --mode="M*B* ${{ env.VYPER_VERSION }}" + cargo run --release --bin compiler-tester -- --path=tests/vyper/ --mode="M*B* ${{ env.VYPER_VERSION }}" diff --git a/setup.cfg b/setup.cfg index d18ffe2ac7..dd4a32a3ac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,6 @@ addopts = -n auto --cov-report html --cov-report xml --cov=vyper - --hypothesis-show-statistics python_files = test_*.py testpaths = tests markers = diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py index b9841b92f0..20f45ced6b 100644 --- a/tests/compiler/test_opcodes.py +++ b/tests/compiler/test_opcodes.py @@ -59,5 +59,8 @@ def test_get_opcodes(evm_version): assert "PUSH0" in ops if evm_version in ("cancun",): - assert "TLOAD" in ops - assert "TSTORE" in ops + for op in ("TLOAD", "TSTORE", "MCOPY"): + assert op in ops + else: + for op in ("TLOAD", "TSTORE", "MCOPY"): + assert op not in ops diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 11d834bf42..f1b642b28d 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -1,4 +1,6 @@ +import hypothesis.strategies as st import pytest +from hypothesis import given, settings from vyper.exceptions import ArgumentException @@ -9,14 +11,6 @@ def _generate_bytes(length): return bytes(list(range(length))) -# good numbers to try -_fun_numbers = [0, 1, 5, 31, 32, 33, 64, 99, 100, 101] - - -# [b"", b"\x01", b"\x02"...] -_bytes_examples = [_generate_bytes(i) for i in _fun_numbers if i <= 100] - - def test_basic_slice(get_contract_with_gas_estimation): code = """ @external @@ -31,12 +25,16 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: assert x == b"klmnopqrst", x -@pytest.mark.parametrize("bytesdata", _bytes_examples) -@pytest.mark.parametrize("start", _fun_numbers) +# note: optimization boundaries at 32, 64 and 320 depending on mode +_draw_1024 = st.integers(min_value=0, max_value=1024) +_draw_1024_1 = st.integers(min_value=1, max_value=1024) +_bytes_1024 = st.binary(min_size=0, max_size=1024) + + @pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("length", _fun_numbers) @pytest.mark.parametrize("literal_length", (True, False)) -@pytest.mark.fuzzing +@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) +@settings(max_examples=25, deadline=None) def test_slice_immutable( get_contract, assert_compile_failed, @@ -46,47 +44,48 @@ def test_slice_immutable( literal_start, length, literal_length, + length_bound, ): _start = start if literal_start else "start" _length = length if literal_length else "length" code = f""" -IMMUTABLE_BYTES: immutable(Bytes[100]) -IMMUTABLE_SLICE: immutable(Bytes[100]) +IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) +IMMUTABLE_SLICE: immutable(Bytes[{length_bound}]) @external -def __init__(inp: Bytes[100], start: uint256, length: uint256): +def __init__(inp: Bytes[{length_bound}], start: uint256, length: uint256): IMMUTABLE_BYTES = inp IMMUTABLE_SLICE = slice(IMMUTABLE_BYTES, {_start}, {_length}) @external -def do_splice() -> Bytes[100]: +def do_splice() -> Bytes[{length_bound}]: return IMMUTABLE_SLICE """ + def _get_contract(): + return get_contract(code, bytesdata, start, length) + if ( - (start + length > 100 and literal_start and literal_length) - or (literal_length and length > 100) - or (literal_start and start > 100) + (start + length > length_bound and literal_start and literal_length) + or (literal_length and length > length_bound) + or (literal_start and start > length_bound) or (literal_length and length < 1) ): - assert_compile_failed( - lambda: get_contract(code, bytesdata, start, length), ArgumentException - ) - elif start + length > len(bytesdata): - assert_tx_failed(lambda: get_contract(code, bytesdata, start, length)) + assert_compile_failed(lambda: _get_contract(), ArgumentException) + elif start + length > len(bytesdata) or (len(bytesdata) > length_bound): + # deploy fail + assert_tx_failed(lambda: _get_contract()) else: - c = get_contract(code, bytesdata, start, length) + c = _get_contract() assert c.do_splice() == bytesdata[start : start + length] @pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) -@pytest.mark.parametrize("bytesdata", _bytes_examples) -@pytest.mark.parametrize("start", _fun_numbers) @pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("length", _fun_numbers) @pytest.mark.parametrize("literal_length", (True, False)) -@pytest.mark.fuzzing +@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) +@settings(max_examples=25, deadline=None) def test_slice_bytes( get_contract, assert_compile_failed, @@ -97,9 +96,10 @@ def test_slice_bytes( literal_start, length, literal_length, + length_bound, ): if location == "memory": - spliced_code = "foo: Bytes[100] = inp" + spliced_code = f"foo: Bytes[{length_bound}] = inp" foo = "foo" elif location == "storage": spliced_code = "self.foo = inp" @@ -120,31 +120,38 @@ def test_slice_bytes( _length = length if literal_length else "length" code = f""" -foo: Bytes[100] -IMMUTABLE_BYTES: immutable(Bytes[100]) +foo: Bytes[{length_bound}] +IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) @external -def __init__(foo: Bytes[100]): +def __init__(foo: Bytes[{length_bound}]): IMMUTABLE_BYTES = foo @external -def do_slice(inp: Bytes[100], start: uint256, length: uint256) -> Bytes[100]: +def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]: {spliced_code} return slice({foo}, {_start}, {_length}) """ - length_bound = len(bytesdata) if location == "literal" else 100 + def _get_contract(): + return get_contract(code, bytesdata) + + data_length = len(bytesdata) if location == "literal" else length_bound if ( - (start + length > length_bound and literal_start and literal_length) - or (literal_length and length > length_bound) - or (literal_start and start > length_bound) + (start + length > data_length and literal_start and literal_length) + or (literal_length and length > data_length) + or (location == "literal" and len(bytesdata) > length_bound) + or (literal_start and start > data_length) or (literal_length and length < 1) ): - assert_compile_failed(lambda: get_contract(code, bytesdata), ArgumentException) + assert_compile_failed(lambda: _get_contract(), ArgumentException) + elif len(bytesdata) > data_length: + # deploy fail + assert_tx_failed(lambda: _get_contract()) elif start + length > len(bytesdata): - c = get_contract(code, bytesdata) + c = _get_contract() assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) else: - c = get_contract(code, bytesdata) + c = _get_contract() assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py index cbae183fe4..9231d1979f 100644 --- a/tests/parser/types/test_dynamic_array.py +++ b/tests/parser/types/test_dynamic_array.py @@ -2,7 +2,6 @@ import pytest -from vyper.compiler.settings import OptimizationLevel from vyper.exceptions import ( ArgumentException, ArrayIndexException, @@ -1585,14 +1584,9 @@ def bar2() -> uint256: newFoo.b1[1][0][0].a1[0][1][1] + \\ newFoo.b1[0][1][0].a1[0][0][0] """ - - if optimize == OptimizationLevel.NONE: - # fails at assembly stage with too many stack variables - assert_compile_failed(lambda: get_contract(code), Exception) - else: - c = get_contract(code) - assert c.bar() == [[[3, 7], [7, 3]], [[7, 3], [0, 0]]] - assert c.bar2() == 0 + c = get_contract(code) + assert c.bar() == [[[3, 7], [7, 3]], [[7, 3], [0, 0]]] + assert c.bar2() == 0 def test_tuple_of_lists(get_contract): diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 71e78dd666..55e0fc82b2 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -105,7 +105,7 @@ def _parse_args(argv): dest="evm_version", ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") - parser.add_argument("--optimize", help="Optimization flag", choices=["gas", "codesize"]) + parser.add_argument("--optimize", help="Optimization flag", choices=["gas", "codesize", "none"]) parser.add_argument( "--no-bytecode-metadata", help="Do not add metadata to bytecode", action="store_true" ) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 58d9db9889..5b16938e99 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,6 +1,11 @@ +import contextlib +from typing import Generator + from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode +from vyper.compiler.settings import OptimizationLevel from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( AddressT, @@ -19,13 +24,7 @@ from vyper.semantics.types.shortcuts import BYTES32_T, INT256_T, UINT256_T from vyper.semantics.types.subscriptable import SArrayT from vyper.semantics.types.user import EnumT -from vyper.utils import ( - GAS_CALLDATACOPY_WORD, - GAS_CODECOPY_WORD, - GAS_IDENTITY, - GAS_IDENTITYWORD, - ceil32, -) +from vyper.utils import GAS_COPY_WORD, GAS_IDENTITY, GAS_IDENTITYWORD, ceil32 DYNAMIC_ARRAY_OVERHEAD = 1 @@ -90,12 +89,16 @@ def _identity_gas_bound(num_bytes): return GAS_IDENTITY + GAS_IDENTITYWORD * (ceil32(num_bytes) // 32) +def _mcopy_gas_bound(num_bytes): + return GAS_COPY_WORD * ceil32(num_bytes) // 32 + + def _calldatacopy_gas_bound(num_bytes): - return GAS_CALLDATACOPY_WORD * ceil32(num_bytes) // 32 + return GAS_COPY_WORD * ceil32(num_bytes) // 32 def _codecopy_gas_bound(num_bytes): - return GAS_CODECOPY_WORD * ceil32(num_bytes) // 32 + return GAS_COPY_WORD * ceil32(num_bytes) // 32 # Copy byte array word-for-word (including layout) @@ -258,7 +261,6 @@ def copy_bytes(dst, src, length, length_bound): assert src.is_pointer and dst.is_pointer # fast code for common case where num bytes is small - # TODO expand this for more cases where num words is less than ~8 if length_bound <= 32: copy_op = STORE(dst, LOAD(src)) ret = IRnode.from_list(copy_op, annotation=annotation) @@ -268,8 +270,12 @@ def copy_bytes(dst, src, length, length_bound): # special cases: batch copy to memory # TODO: iloadbytes if src.location == MEMORY: - copy_op = ["staticcall", "gas", 4, src, length, dst, length] - gas_bound = _identity_gas_bound(length_bound) + if version_check(begin="cancun"): + copy_op = ["mcopy", dst, src, length] + gas_bound = _mcopy_gas_bound(length_bound) + else: + copy_op = ["staticcall", "gas", 4, src, length, dst, length] + gas_bound = _identity_gas_bound(length_bound) elif src.location == CALLDATA: copy_op = ["calldatacopy", dst, src, length] gas_bound = _calldatacopy_gas_bound(length_bound) @@ -876,6 +882,38 @@ def make_setter(left, right): return _complex_make_setter(left, right) +_opt_level = OptimizationLevel.GAS + + +@contextlib.contextmanager +def anchor_opt_level(new_level: OptimizationLevel) -> Generator: + """ + Set the global optimization level variable for the duration of this + context manager. + """ + assert isinstance(new_level, OptimizationLevel) + + global _opt_level + try: + tmp = _opt_level + _opt_level = new_level + yield + finally: + _opt_level = tmp + + +def _opt_codesize(): + return _opt_level == OptimizationLevel.CODESIZE + + +def _opt_gas(): + return _opt_level == OptimizationLevel.GAS + + +def _opt_none(): + return _opt_level == OptimizationLevel.NONE + + def _complex_make_setter(left, right): if right.value == "~empty" and left.location == MEMORY: # optimized memzero @@ -891,11 +929,69 @@ def _complex_make_setter(left, right): assert is_tuple_like(left.typ) keys = left.typ.tuple_keys() - # if len(keyz) == 0: - # return IRnode.from_list(["pass"]) + if left.is_pointer and right.is_pointer and right.encoding == Encoding.VYPER: + # both left and right are pointers, see if we want to batch copy + # instead of unrolling the loop. + assert left.encoding == Encoding.VYPER + len_ = left.typ.memory_bytes_required + + has_storage = STORAGE in (left.location, right.location) + if has_storage: + if _opt_codesize(): + # assuming PUSH2, a single sstore(dst (sload src)) is 8 bytes, + # sstore(add (dst ofst), (sload (add (src ofst)))) is 16 bytes, + # whereas loop overhead is 16-17 bytes. + base_cost = 3 + if left._optimized.is_literal: + # code size is smaller since add is performed at compile-time + base_cost += 1 + if right._optimized.is_literal: + base_cost += 1 + # the formula is a heuristic, but it works. + # (CMC 2023-07-14 could get more detailed for PUSH1 vs + # PUSH2 etc but not worried about that too much now, + # it's probably better to add a proper unroll rule in the + # optimizer.) + should_batch_copy = len_ >= 32 * base_cost + elif _opt_gas(): + # kind of arbitrary, but cut off when code used > ~160 bytes + should_batch_copy = len_ >= 32 * 10 + else: + assert _opt_none() + # don't care, just generate the most readable version + should_batch_copy = True + else: + # find a cutoff for memory copy where identity is cheaper + # than unrolled mloads/mstores + # if MCOPY is available, mcopy is *always* better (except in + # the 1 word case, but that is already handled by copy_bytes). + if right.location == MEMORY and _opt_gas() and not version_check(begin="cancun"): + # cost for 0th word - (mstore dst (mload src)) + base_unroll_cost = 12 + nth_word_cost = base_unroll_cost + if not left._optimized.is_literal: + # (mstore (add N dst) (mload src)) + nth_word_cost += 6 + if not right._optimized.is_literal: + # (mstore dst (mload (add N src))) + nth_word_cost += 6 + + identity_base_cost = 115 # staticcall 4 gas dst len src len + + n_words = ceil32(len_) // 32 + should_batch_copy = ( + base_unroll_cost + (nth_word_cost * (n_words - 1)) >= identity_base_cost + ) + + # calldata to memory, code to memory, cancun, or codesize - + # batch copy is always better. + else: + should_batch_copy = True + + if should_batch_copy: + return copy_bytes(left, right, len_, len_) - # general case - # TODO use copy_bytes when the generated code is above a certain size + # general case, unroll with left.cache_when_complex("_L") as (b1, left), right.cache_when_complex("_R") as (b2, right): for k in keys: l_i = get_element_ptr(left, k, array_bounds_check=False) diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index f7698fbabb..0895e5f02d 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -49,10 +49,7 @@ class Encoding(Enum): # this creates a magical block which maps to IR `with` class _WithBuilder: def __init__(self, ir_node, name, should_inline=False): - # TODO figure out how to fix this circular import - from vyper.ir.optimizer import optimize - - if should_inline and optimize(ir_node).is_complex_ir: + if should_inline and ir_node._optimized.is_complex_ir: # this can only mean trouble raise CompilerPanic("trying to inline a complex IR node") @@ -366,6 +363,13 @@ def is_pointer(self): # eventually return self.location is not None + @property # probably could be cached_property but be paranoid + def _optimized(self): + # TODO figure out how to fix this circular import + from vyper.ir.optimizer import optimize + + return optimize(self) + # This function is slightly confusing but abstracts a common pattern: # when an IR value needs to be computed once and then cached as an # IR value (if it is expensive, or more importantly if its computation @@ -382,13 +386,11 @@ def is_pointer(self): # return builder.resolve(ret) # ``` def cache_when_complex(self, name): - from vyper.ir.optimizer import optimize - # for caching purposes, see if the ir_node will be optimized # because a non-literal expr could turn into a literal, # (e.g. `(add 1 2)`) # TODO this could really be moved into optimizer.py - should_inline = not optimize(self).is_complex_ir + should_inline = not self._optimized.is_complex_ir return _WithBuilder(self, name, should_inline) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 99465809bd..4e1bd9e6c3 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -5,6 +5,7 @@ from vyper import ast as vy_ast from vyper.codegen import module +from vyper.codegen.core import anchor_opt_level from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel, Settings @@ -268,7 +269,9 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes(global_ctx: GlobalContext, optimize: bool) -> tuple[IRnode, IRnode]: +def generate_ir_nodes( + global_ctx: GlobalContext, optimize: OptimizationLevel +) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. @@ -288,7 +291,8 @@ def generate_ir_nodes(global_ctx: GlobalContext, optimize: bool) -> tuple[IRnode IR to generate deployment bytecode IR to generate runtime bytecode """ - ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) + with anchor_opt_level(optimize): + ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) if optimize != OptimizationLevel.NONE: ir_nodes = optimizer.optimize(ir_nodes) ir_runtime = optimizer.optimize(ir_runtime) diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 4fec13e897..767d634c89 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -89,6 +89,7 @@ "MSIZE": (0x59, 0, 1, 2), "GAS": (0x5A, 0, 1, 2), "JUMPDEST": (0x5B, 0, 0, 1), + "MCOPY": (0x5E, 3, 0, (None, None, None, None, None, 3)), "PUSH0": (0x5F, 0, 1, 2), "PUSH1": (0x60, 0, 1, 3), "PUSH2": (0x61, 0, 1, 3), @@ -171,8 +172,8 @@ "INVALID": (0xFE, 0, 0, 0), "DEBUG": (0xA5, 1, 0, 0), "BREAKPOINT": (0xA6, 0, 0, 0), - "TLOAD": (0x5C, 1, 1, 100), - "TSTORE": (0x5D, 2, 0, 100), + "TLOAD": (0x5C, 1, 1, (None, None, None, None, None, 100)), + "TSTORE": (0x5D, 2, 0, (None, None, None, None, None, 100)), } PSEUDO_OPCODES: OpcodeMap = { diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 15a68a5079..a9064a44fa 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -297,6 +297,7 @@ def _height_of(witharg): return o # batch copy from data section of the currently executing code to memory + # (probably should have named this dcopy but oh well) elif code.value == "dloadbytes": dst = code.args[0] src = code.args[1] diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index b13c6f79f8..40e02e79c7 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -2,6 +2,7 @@ from typing import List, Optional, Tuple, Union from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StaticAssertionException from vyper.utils import ( ceil32, @@ -472,6 +473,7 @@ def finalize(val, args): if value == "seq": changed |= _merge_memzero(argz) changed |= _merge_calldataload(argz) + changed |= _merge_mload(argz) changed |= _remove_empty_seqs(argz) # (seq x) => (x) for cleanliness and @@ -636,12 +638,26 @@ def _remove_empty_seqs(argz): def _merge_calldataload(argz): - # look for sequential operations copying from calldata to memory - # and merge them into a single calldatacopy operation + return _merge_load(argz, "calldataload", "calldatacopy") + + +def _merge_dload(argz): + return _merge_load(argz, "dload", "dloadbytes") + + +def _merge_mload(argz): + if not version_check(begin="cancun"): + return False + return _merge_load(argz, "mload", "mcopy") + + +def _merge_load(argz, _LOAD, _COPY): + # look for sequential operations copying from X to Y + # and merge them into a single copy operation changed = False mstore_nodes: List = [] - initial_mem_offset = 0 - initial_calldata_offset = 0 + initial_dst_offset = 0 + initial_src_offset = 0 total_length = 0 idx = None for i, ir_node in enumerate(argz): @@ -649,19 +665,19 @@ def _merge_calldataload(argz): if ( ir_node.value == "mstore" and isinstance(ir_node.args[0].value, int) - and ir_node.args[1].value == "calldataload" + and ir_node.args[1].value == _LOAD and isinstance(ir_node.args[1].args[0].value, int) ): # mstore of a zero value - mem_offset = ir_node.args[0].value - calldata_offset = ir_node.args[1].args[0].value + dst_offset = ir_node.args[0].value + src_offset = ir_node.args[1].args[0].value if not mstore_nodes: - initial_mem_offset = mem_offset - initial_calldata_offset = calldata_offset + initial_dst_offset = dst_offset + initial_src_offset = src_offset idx = i if ( - initial_mem_offset + total_length == mem_offset - and initial_calldata_offset + total_length == calldata_offset + initial_dst_offset + total_length == dst_offset + and initial_src_offset + total_length == src_offset ): mstore_nodes.append(ir_node) total_length += 32 @@ -676,7 +692,7 @@ def _merge_calldataload(argz): if len(mstore_nodes) > 1: changed = True new_ir = IRnode.from_list( - ["calldatacopy", initial_mem_offset, initial_calldata_offset, total_length], + [_COPY, initial_dst_offset, initial_src_offset, total_length], source_pos=mstore_nodes[0].source_pos, ) # replace first copy operation with optimized node and remove the rest @@ -684,8 +700,8 @@ def _merge_calldataload(argz): # note: del xs[k:l] deletes l - k items del argz[idx + 1 : idx + len(mstore_nodes)] - initial_mem_offset = 0 - initial_calldata_offset = 0 + initial_dst_offset = 0 + initial_src_offset = 0 total_length = 0 mstore_nodes.clear() diff --git a/vyper/utils.py b/vyper/utils.py index 2440117d0c..3d9d9cb416 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -196,8 +196,7 @@ def calc_mem_gas(memsize): # Specific gas usage GAS_IDENTITY = 15 GAS_IDENTITYWORD = 3 -GAS_CODECOPY_WORD = 3 -GAS_CALLDATACOPY_WORD = 3 +GAS_COPY_WORD = 3 # i.e., W_copy from YP # A decimal value can store multiples of 1/DECIMAL_DIVISOR MAX_DECIMAL_PLACES = 10 From 91d6e240f770414e1fbfd8648a166e9d2dba1698 Mon Sep 17 00:00:00 2001 From: trocher Date: Sun, 16 Jul 2023 18:24:44 +0200 Subject: [PATCH 044/201] fix displaying of ArgumentException (#3500) The change is pretty self-explanatory. ```vyper @internal @view def bar(): pass @external def foo(): self.bar(12) ``` Was failing to compile with: `vyper.exceptions.ArgumentException: Invalid argument count for call to 'bar': expected 0 to 0, got 1` And now fail to compile with: `vyper.exceptions.ArgumentException: Invalid argument count for call to 'bar': expected 0, got 1` Co-authored-by: Tanguy Rocher --- vyper/ast/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/ast/validation.py b/vyper/ast/validation.py index 7742d60c01..36a6a0484c 100644 --- a/vyper/ast/validation.py +++ b/vyper/ast/validation.py @@ -48,7 +48,7 @@ def validate_call_args( arg_count = (arg_count[0], 2**64) if arg_count[0] == arg_count[1]: - arg_count == arg_count[0] + arg_count = arg_count[0] if isinstance(node.func, vy_ast.Attribute): msg = f" for call to '{node.func.attr}'" From 9e3b9a2b8ae55aa83b5450080f750be15f819de7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 16 Jul 2023 20:14:35 -0400 Subject: [PATCH 045/201] feat: optimize dynarray and bytearray copies (#3499) include the length word in the batch copy instead of issuing a separate store instruction. brings CurveStableSwapMetaNG.vy down by 315 bytes (~1.5%) and VaultV3.vy by 45 bytes (0.25%) in both `--optimize codesize` and `--optimize gas` modes. --- tests/parser/functions/test_slice.py | 4 +-- vyper/codegen/core.py | 46 ++++++++++++++++------------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index f1b642b28d..3064ee308e 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -2,7 +2,7 @@ import pytest from hypothesis import given, settings -from vyper.exceptions import ArgumentException +from vyper.exceptions import ArgumentException, TypeMismatch _fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)] @@ -143,7 +143,7 @@ def _get_contract(): or (literal_start and start > data_length) or (literal_length and length < 1) ): - assert_compile_failed(lambda: _get_contract(), ArgumentException) + assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) elif len(bytesdata) > data_length: # deploy fail assert_tx_failed(lambda: _get_contract()) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 5b16938e99..f47f88ac85 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -110,25 +110,33 @@ def make_byte_array_copier(dst, src): _check_assign_bytes(dst, src) # TODO: remove this branch, copy_bytes and get_bytearray_length should handle - if src.value == "~empty": + if src.value == "~empty" or src.typ.maxlen == 0: # set length word to 0. return STORE(dst, 0) with src.cache_when_complex("src") as (b1, src): - with get_bytearray_length(src).cache_when_complex("len") as (b2, len_): - max_bytes = src.typ.maxlen + has_storage = STORAGE in (src.location, dst.location) + is_memory_copy = dst.location == src.location == MEMORY + batch_uses_identity = is_memory_copy and not version_check(begin="cancun") + if src.typ.maxlen <= 32 and (has_storage or batch_uses_identity): + # it's cheaper to run two load/stores instead of copy_bytes ret = ["seq"] - - dst_ = bytes_data_ptr(dst) - src_ = bytes_data_ptr(src) - - ret.append(copy_bytes(dst_, src_, len_, max_bytes)) - - # store length + # store length word + len_ = get_bytearray_length(src) ret.append(STORE(dst, len_)) - return b1.resolve(b2.resolve(ret)) + # store the single data word. + dst_data_ptr = bytes_data_ptr(dst) + src_data_ptr = bytes_data_ptr(src) + ret.append(STORE(dst_data_ptr, LOAD(src_data_ptr))) + return b1.resolve(ret) + + # batch copy the bytearray (including length word) using copy_bytes + len_ = add_ofst(get_bytearray_length(src), 32) + max_bytes = src.typ.maxlen + 32 + ret = copy_bytes(dst, src, len_, max_bytes) + return b1.resolve(ret) def bytes_data_ptr(ptr): @@ -213,19 +221,17 @@ def _dynarray_make_setter(dst, src): loop_body.annotation = f"{dst}[i] = {src}[i]" ret.append(["repeat", i, 0, count, src.typ.count, loop_body]) + # write the length word after data is copied + ret.append(STORE(dst, count)) else: element_size = src.typ.value_type.memory_bytes_required - # number of elements * size of element in bytes - n_bytes = _mul(count, element_size) - max_bytes = src.typ.count * element_size - - src_ = dynarray_data_ptr(src) - dst_ = dynarray_data_ptr(dst) - ret.append(copy_bytes(dst_, src_, n_bytes, max_bytes)) + # number of elements * size of element in bytes + length word + n_bytes = add_ofst(_mul(count, element_size), 32) + max_bytes = 32 + src.typ.count * element_size - # write the length word after data is copied - ret.append(STORE(dst, count)) + # batch copy the entire dynarray, including length word + ret.append(copy_bytes(dst, src, n_bytes, max_bytes)) return b1.resolve(b2.resolve(ret)) From cfba51719e10923cc93e40f6bca9a9d1d0d4a328 Mon Sep 17 00:00:00 2001 From: antazoey Date: Wed, 19 Jul 2023 11:27:10 -0500 Subject: [PATCH 046/201] fix: `tests` being imported in editable mode (#3510) the `tests` package was being imported when vyper installed in editable mode. this commit fixes by restricting the packages being exported in `setup.py`. --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 36a138aacd..bbf6e60f55 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import re import subprocess -from setuptools import find_packages, setup +from setuptools import setup extras_require = { "test": [ @@ -88,7 +88,7 @@ def _global_version(version): license="Apache License 2.0", keywords="ethereum evm smart contract language", include_package_data=True, - packages=find_packages(exclude=("tests", "docs")), + packages=["vyper"], python_requires=">=3.10,<4", py_modules=["vyper"], install_requires=[ From 6bd81dea55d0d0ef71b192bf30331a48a409f1d4 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 19 Jul 2023 10:15:57 -0700 Subject: [PATCH 047/201] chore: relax pragma parsing (#3511) allow `# pragma ...` in addition to `#pragma ...` also fix a small bug in version parsing (it only affected the error message formatting, not the parsed version) --- vyper/ast/pre_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 35153af9d5..7e677b3b92 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -119,12 +119,12 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: validate_version_pragma(compiler_version, start) settings.compiler_version = compiler_version - if string.startswith("#pragma "): - pragma = string.removeprefix("#pragma").strip() + if contents.startswith("pragma "): + pragma = contents.removeprefix("pragma ").strip() if pragma.startswith("version "): if settings.compiler_version is not None: raise StructureException("pragma version specified twice!", start) - compiler_version = pragma.removeprefix("version ".strip()) + compiler_version = pragma.removeprefix("version ").strip() validate_version_pragma(compiler_version, start) settings.compiler_version = compiler_version From f928a0ff64bd3355f6410e460f6a710000e5f9d7 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 21 Jul 2023 23:32:06 +0800 Subject: [PATCH 048/201] chore: improve error message for invalid references to constants and immutables (#3529) --- .../exceptions/test_invalid_reference.py | 18 ++++++++++++++++++ vyper/semantics/analysis/utils.py | 16 +++++++++++----- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/parser/exceptions/test_invalid_reference.py b/tests/parser/exceptions/test_invalid_reference.py index 3aec6028e4..fe315e5cbf 100644 --- a/tests/parser/exceptions/test_invalid_reference.py +++ b/tests/parser/exceptions/test_invalid_reference.py @@ -37,6 +37,24 @@ def foo(): def foo(): int128 = 5 """, + """ +a: public(constant(uint256)) = 1 + +@external +def foo(): + b: uint256 = self.a + """, + """ +a: public(immutable(uint256)) + +@external +def __init__(): + a = 123 + +@external +def foo(): + b: uint256 = self.a + """, ] diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f16b0c8c33..4f911764e0 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -180,24 +180,30 @@ def _find_fn(self, node): raise StructureException("Cannot determine type of this object", node) def types_from_Attribute(self, node): + is_self_reference = node.get("value.id") == "self" # variable attribute, e.g. `foo.bar` t = self.get_exact_type_from_node(node.value, include_type_exprs=True) name = node.attr + + def _raise_invalid_reference(name, node): + raise InvalidReference( + f"'{name}' is not a storage variable, it should not be prepended with self", node + ) + try: s = t.get_member(name, node) if isinstance(s, VyperType): # ex. foo.bar(). bar() is a ContractFunctionT return [s] + if is_self_reference and (s.is_constant or s.is_immutable): + _raise_invalid_reference(name, node) # general case. s is a VarInfo, e.g. self.foo return [s.typ] except UnknownAttribute: - if node.get("value.id") != "self": + if not is_self_reference: raise if name in self.namespace: - raise InvalidReference( - f"'{name}' is not a storage variable, it should not be prepended with self", - node, - ) from None + _raise_invalid_reference(name, node) suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4) raise UndeclaredDefinition( From 299352ef5ee3b9be2b334120091a7e80e10d2022 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 23 Jul 2023 15:11:29 -0700 Subject: [PATCH 049/201] feat: optimize dload/mstore sequences (#3525) merge_dload was defined in 5dc3ac7, but was not applied. this commit also adds a rewrite rule for single dload/mstore patterns, any `(mstore dst (dload src))` (which compiles to a `codecopy` followed by an `mload` and `mstore`) is rewritten to a `dloadbytes` (which compiles directly to a `codecopy`). this rule saves 4 bytes / ~10 gas per rewrite. for instance, it shaves 25 bytes off `VaultV3.vy`, 50 bytes off `CurveTricryptoOptimizedWETH.vy` and 75 bytes off `CurveStableSwapMetaNG.vy` (basically 0.1%-0.3%). --- vyper/ir/optimizer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 40e02e79c7..08c2168381 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -473,6 +473,8 @@ def finalize(val, args): if value == "seq": changed |= _merge_memzero(argz) changed |= _merge_calldataload(argz) + changed |= _merge_dload(argz) + changed |= _rewrite_mstore_dload(argz) changed |= _merge_mload(argz) changed |= _remove_empty_seqs(argz) @@ -645,6 +647,18 @@ def _merge_dload(argz): return _merge_load(argz, "dload", "dloadbytes") +def _rewrite_mstore_dload(argz): + changed = False + for i, arg in enumerate(argz): + if arg.value == "mstore" and arg.args[1].value == "dload": + dst = arg.args[0] + src = arg.args[1].args[0] + len_ = 32 + argz[i] = IRnode.from_list(["dloadbytes", dst, src, len_], source_pos=arg.source_pos) + changed = True + return changed + + def _merge_mload(argz): if not version_check(begin="cancun"): return False From 4ca1c81aaa4f3e950522a2115aa5fcd7d80c1b27 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 24 Jul 2023 09:23:04 -0700 Subject: [PATCH 050/201] chore: improve some error messages (#3524) fix array bounds check and `create_*` builtin error messages - array bounds checks, previously were something like `clamp lt [mload, 640 ]` - codesize check error message was missing for create builtins - create failure error message was also missing --- vyper/builtins/functions.py | 16 +++++++++++++--- vyper/codegen/core.py | 1 + vyper/codegen/ir_node.py | 8 ++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 90214554b0..e1dcee6b8d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1634,7 +1634,9 @@ def _create_ir(value, buf, length, salt=None, checked=True): if not checked: return ret - return clamp_nonzero(ret) + ret = clamp_nonzero(ret) + ret.set_error_msg(f"{create_op} failed") + return ret # calculate the gas used by create for a given number of bytes @@ -1830,7 +1832,10 @@ def _build_create_IR(self, expr, args, context, value, salt): ir = ["seq"] # make sure there is actually code at the target - ir.append(["assert", codesize]) + check_codesize = ["assert", codesize] + ir.append( + IRnode.from_list(check_codesize, error_msg="empty target (create_copy_of)") + ) # store the preamble at msize + 22 (zero padding) preamble, preamble_len = _create_preamble(codesize) @@ -1920,7 +1925,12 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar # (code_ofst == (extcodesize target) would be empty # initcode, which we disallow for hygiene reasons - # same as `create_copy_of` on an empty target). - ir.append(["assert", ["sgt", codesize, 0]]) + check_codesize = ["assert", ["sgt", codesize, 0]] + ir.append( + IRnode.from_list( + check_codesize, error_msg="empty target (create_from_blueprint)" + ) + ) # copy the target code into memory. # layout starting from mem_ofst: diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index f47f88ac85..47a2c8c8d0 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -539,6 +539,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check): # an array index, and the clamp will throw an error. # NOTE: there are optimization rules for this when ix or bound is literal ix = clamp("lt", ix, bound) + ix.set_error_msg(f"{parent.typ} bounds check") if parent.encoding == Encoding.ABI: if parent.location == STORAGE: diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 0895e5f02d..fa015b293e 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -330,6 +330,14 @@ def is_complex_ir(self): and self.value.lower() not in do_not_cache ) + # set an error message and push down into all children. + # useful for overriding an error message generated by a helper + # function with a more specific error message. + def set_error_msg(self, error_msg: str) -> None: + self.error_msg = error_msg + for arg in self.args: + arg.set_error_msg(error_msg) + # get the unique symbols contained in this node, which provides # sanity check invariants for the optimizer. # cache because it's a perf hotspot. note that this (and other cached From 408929fa31ae01dde4f7566bb7babbc7da5b6620 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 24 Jul 2023 18:55:13 -0700 Subject: [PATCH 051/201] feat: O(1) selector tables (#3496) this commit replaces the existing linear entry point search with an O(1) implementation. there are two methods depending on whether optimizing for code size or gas, hash table with probing and perfect hashing using a two-level technique. the first method divides the selectors into buckets, uses `method_id % n_buckets` as a "guess" to where to enter the selector table and then jumps there and performs the familiar linear search for the selector ("probing"). to avoid too large buckets, the jumptable generator searches a range from ~`n_buckets * 0.85` to `n_buckets * 1.15` to minimize worst-case probe depth; the average worst case for 80-100 methods is 3 items per bucket and the worst worst case is 4 items per bucket (presumably if you get really unlucky), see `_bench_sparse()` in `vyper/codegen/jumptable_utils.py`. the average bucket size is 1.6 methods. the second method uses a perfect hashing technique. finding a single magic which produces a perfect hash is infeasible for large `N` (exponential, and in practice seems to run off a cliff around 10 methods). to "get around" this, the methods are divided into buckets of roughly size 10, and a magic is computed per bucket. several `n_buckets` are tried, trying to minimize `n_buckets`. the code size overhead of each bucket is roughly 5 bytes per bucket, which works out to ~20% per method, see `_bench_dense()` in `vyper/codegen/jumptable_utils.py`. then, the function selector is looked up in two steps - it loads the magic for the bucket given by `method_id % n_buckets`, and then uses the magic to compute the location of the function selector (and associated metadata) in the data section. from there it loads the function metadata, performs the calldatasize, callvalue and method id checks and jumps into the function. there is a gas vs code size tradeoff between the two methods - roughly speaking, the sparse method requires ~69 gas in the best case (~109 gas in the "average" case) and 12-22 bytes of code per method, while the dense method requires ~212 gas across the board, and ~8 bytes of code per method. to accomplish this implementation-wise, the jumptable info is generated in a new helper module, `vyper/codegen/jumptable_utils.py`. some refactoring had to be additionally done to pull the calldatasize, callvalue and method id checks from external function generation out into a new selector section construction step in `vyper/codegen/module.py`. additionally, a new IR "data" directive was added, and an associated assembly directive. the data segments in assembly are moved to the end of the bytecode to ensure that data bytes which happen to look like `PUSH` instructions do not mangle valid bytecode which comes after the data section. --- .github/workflows/test.yml | 15 +- docs/compiling-a-contract.rst | 17 + docs/structure-of-a-contract.rst | 4 +- tests/base_conftest.py | 4 +- .../vyper_json/test_parse_args_vyperjson.py | 4 +- tests/compiler/__init__.py | 2 + tests/compiler/test_default_settings.py | 27 ++ tests/conftest.py | 10 +- tests/parser/functions/test_slice.py | 15 +- tests/parser/test_selector_table.py | 198 ++++++++ tox.ini | 6 +- vyper/ast/grammar.lark | 2 +- vyper/cli/vyper_compile.py | 19 +- vyper/cli/vyper_ir.py | 2 +- vyper/codegen/core.py | 3 +- .../codegen/function_definitions/__init__.py | 2 +- vyper/codegen/function_definitions/common.py | 59 ++- .../function_definitions/external_function.py | 63 +-- vyper/codegen/ir_node.py | 14 +- vyper/codegen/jumptable_utils.py | 195 ++++++++ vyper/codegen/module.py | 442 +++++++++++++++--- vyper/compiler/output.py | 6 +- vyper/compiler/settings.py | 13 + vyper/ir/compile_ir.py | 241 +++++++--- 24 files changed, 1133 insertions(+), 230 deletions(-) create mode 100644 tests/compiler/test_default_settings.py create mode 100644 tests/parser/test_selector_table.py create mode 100644 vyper/codegen/jumptable_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b6399b3ae9..fd78e2fff8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -78,11 +78,18 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [["3.10", "310"], ["3.11", "311"]] + python-version: [["3.11", "311"]] # run in modes: --optimize [gas, none, codesize] - flag: ["core", "no-opt", "codesize"] + opt-mode: ["gas", "none", "codesize"] + debug: [true, false] + # run across other python versions.# we don't really need to run all + # modes across all python versions - one is enough + include: + - python-version: ["3.10", "310"] + opt-mode: gas + debug: false - name: py${{ matrix.python-version[1] }}-${{ matrix.flag }} + name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }} steps: - uses: actions/checkout@v1 @@ -97,7 +104,7 @@ jobs: run: pip install tox - name: Run Tox - run: TOXENV=py${{ matrix.python-version[1] }}-${{ matrix.flag }} tox -r -- --reruns 10 --reruns-delay 1 -r aR tests/ + run: TOXENV=py${{ matrix.python-version[1] }} tox -r -- --optimize ${{ matrix.opt-mode }} ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} --reruns 10 --reruns-delay 1 -r aR tests/ - name: Upload Coverage uses: codecov/codecov-action@v1 diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 208771a5a9..6d1cdf98d7 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -113,6 +113,23 @@ Remix IDE While the Vyper version of the Remix IDE compiler is updated on a regular basis, it might be a bit behind the latest version found in the master branch of the repository. Make sure the byte code matches the output from your local compiler. +.. _optimization-mode: + +Compiler Optimization Modes +=========================== + +The vyper CLI tool accepts an optimization mode ``"none"``, ``"codesize"``, or ``"gas"`` (default). It can be set using the ``--optimize`` flag. For example, invoking ``vyper --optimize codesize MyContract.vy`` will compile the contract, optimizing for code size. As a rough summary of the differences between gas and codesize mode, in gas optimized mode, the compiler will try to generate bytecode which minimizes gas (up to a point), including: + +* using a sparse selector table which optimizes for gas over codesize +* inlining some constants, and +* trying to unroll some loops, especially for data copies. + +In codesize optimized mode, the compiler will try hard to minimize codesize by + +* using a dense selector table +* out-lining code, and +* using more loops for data copies. + .. _evm-version: diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index c7abb3e645..f58ab3b067 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -37,13 +37,13 @@ In the above examples, the contract will only compile with Vyper versions ``0.3. Optimization Mode ----------------- -The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For instance, the following contract will be compiled in a way which tries to minimize codesize: +The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For example, adding the following line to a contract will cause it to try to optimize for codesize: .. code-block:: python #pragma optimize codesize -The optimization mode can also be set as a compiler option. If the compiler option conflicts with the source code pragma, an exception will be raised and compilation will not continue. +The optimization mode can also be set as a compiler option, which is documented in :ref:`optimization-mode`. If the compiler option conflicts with the source code pragma, an exception will be raised and compilation will not continue. EVM Version ----------------- diff --git a/tests/base_conftest.py b/tests/base_conftest.py index a78562e982..81e8dedc36 100644 --- a/tests/base_conftest.py +++ b/tests/base_conftest.py @@ -112,10 +112,10 @@ def w3(tester): return w3 -def _get_contract(w3, source_code, optimize, *args, **kwargs): +def _get_contract(w3, source_code, optimize, *args, override_opt_level=None, **kwargs): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) - settings.optimize = optimize + settings.optimize = override_opt_level or optimize out = compiler.compile_code( source_code, # test that metadata gets generated diff --git a/tests/cli/vyper_json/test_parse_args_vyperjson.py b/tests/cli/vyper_json/test_parse_args_vyperjson.py index 08da5f1888..11e527843a 100644 --- a/tests/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/cli/vyper_json/test_parse_args_vyperjson.py @@ -57,7 +57,7 @@ def test_to_stdout(tmp_path, capfd): _parse_args([path.absolute().as_posix()]) out, _ = capfd.readouterr() output_json = json.loads(out) - assert _no_errors(output_json) + assert _no_errors(output_json), (INPUT_JSON, output_json) assert "contracts/foo.vy" in output_json["sources"] assert "contracts/bar.vy" in output_json["sources"] @@ -71,7 +71,7 @@ def test_to_file(tmp_path): assert output_path.exists() with output_path.open() as fp: output_json = json.load(fp) - assert _no_errors(output_json) + assert _no_errors(output_json), (INPUT_JSON, output_json) assert "contracts/foo.vy" in output_json["sources"] assert "contracts/bar.vy" in output_json["sources"] diff --git a/tests/compiler/__init__.py b/tests/compiler/__init__.py index e69de29bb2..35a11f851b 100644 --- a/tests/compiler/__init__.py +++ b/tests/compiler/__init__.py @@ -0,0 +1,2 @@ +# prevent module name collision between tests/compiler/test_pre_parser.py +# and tests/ast/test_pre_parser.py diff --git a/tests/compiler/test_default_settings.py b/tests/compiler/test_default_settings.py new file mode 100644 index 0000000000..ca05170b61 --- /dev/null +++ b/tests/compiler/test_default_settings.py @@ -0,0 +1,27 @@ +from vyper.codegen import core +from vyper.compiler.phases import CompilerData +from vyper.compiler.settings import OptimizationLevel, _is_debug_mode + + +def test_default_settings(): + source_code = "" + compiler_data = CompilerData(source_code) + _ = compiler_data.vyper_module # force settings to be computed + + assert compiler_data.settings.optimize == OptimizationLevel.GAS + + +def test_default_opt_level(): + assert OptimizationLevel.default() == OptimizationLevel.GAS + + +def test_codegen_opt_level(): + assert core._opt_level == OptimizationLevel.GAS + assert core._opt_gas() is True + assert core._opt_none() is False + assert core._opt_codesize() is False + + +def test_debug_mode(pytestconfig): + debug_mode = pytestconfig.getoption("enable_compiler_debug_mode") + assert _is_debug_mode() == debug_mode diff --git a/tests/conftest.py b/tests/conftest.py index 9c9c4191b9..d519ca3100 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from vyper import compiler from vyper.codegen.ir_node import IRnode -from vyper.compiler.settings import OptimizationLevel +from vyper.compiler.settings import OptimizationLevel, _set_debug_mode from vyper.ir import compile_ir, optimizer from .base_conftest import VyperContract, _get_contract, zero_gas_price_strategy @@ -43,6 +43,7 @@ def pytest_addoption(parser): default="gas", help="change optimization mode", ) + parser.addoption("--enable-compiler-debug-mode", action="store_true") @pytest.fixture(scope="module") @@ -51,6 +52,13 @@ def optimize(pytestconfig): return OptimizationLevel.from_string(flag) +@pytest.fixture(scope="session", autouse=True) +def debug(pytestconfig): + debug = pytestconfig.getoption("enable_compiler_debug_mode") + assert isinstance(debug, bool) + _set_debug_mode(debug) + + @pytest.fixture def keccak(): return Web3.keccak diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 3064ee308e..6229b47921 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -2,6 +2,7 @@ import pytest from hypothesis import given, settings +from vyper.compiler.settings import OptimizationLevel from vyper.exceptions import ArgumentException, TypeMismatch _fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)] @@ -33,12 +34,15 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: @pytest.mark.parametrize("literal_start", (True, False)) @pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) -@settings(max_examples=25, deadline=None) +@settings(max_examples=100, deadline=None) +@pytest.mark.fuzzing def test_slice_immutable( get_contract, assert_compile_failed, assert_tx_failed, + opt_level, bytesdata, start, literal_start, @@ -64,7 +68,7 @@ def do_splice() -> Bytes[{length_bound}]: """ def _get_contract(): - return get_contract(code, bytesdata, start, length) + return get_contract(code, bytesdata, start, length, override_opt_level=opt_level) if ( (start + length > length_bound and literal_start and literal_length) @@ -84,12 +88,15 @@ def _get_contract(): @pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) @pytest.mark.parametrize("literal_start", (True, False)) @pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) -@settings(max_examples=25, deadline=None) +@settings(max_examples=100, deadline=None) +@pytest.mark.fuzzing def test_slice_bytes( get_contract, assert_compile_failed, assert_tx_failed, + opt_level, location, bytesdata, start, @@ -133,7 +140,7 @@ def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Byt """ def _get_contract(): - return get_contract(code, bytesdata) + return get_contract(code, bytesdata, override_opt_level=opt_level) data_length = len(bytesdata) if location == "literal" else length_bound if ( diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py new file mode 100644 index 0000000000..01a83698b7 --- /dev/null +++ b/tests/parser/test_selector_table.py @@ -0,0 +1,198 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +import vyper.utils as utils +from vyper.codegen.jumptable_utils import ( + generate_dense_jumptable_info, + generate_sparse_jumptable_buckets, +) +from vyper.compiler.settings import OptimizationLevel + + +@given( + n_methods=st.integers(min_value=1, max_value=100), + seed=st.integers(min_value=0, max_value=2**64 - 1), +) +@pytest.mark.fuzzing +@settings(max_examples=10, deadline=None) +def test_sparse_jumptable_probe_depth(n_methods, seed): + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + _, buckets = generate_sparse_jumptable_buckets(sigs) + bucket_sizes = [len(bucket) for bucket in buckets.values()] + + # generally bucket sizes should be bounded at around 4, but + # just test that they don't get really out of hand + assert max(bucket_sizes) <= 8 + + # generally mean bucket size should be around 1.6, here just + # test they don't get really out of hand + assert sum(bucket_sizes) / len(bucket_sizes) <= 4 + + +@given( + n_methods=st.integers(min_value=4, max_value=100), + seed=st.integers(min_value=0, max_value=2**64 - 1), +) +@pytest.mark.fuzzing +@settings(max_examples=10, deadline=None) +def test_dense_jumptable_bucket_size(n_methods, seed): + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + n = len(sigs) + buckets = generate_dense_jumptable_info(sigs) + n_buckets = len(buckets) + + # generally should be around 14 buckets per 100 methods, here + # we test they don't get really out of hand + assert n_buckets / n < 0.4 or n < 10 + + +@pytest.mark.parametrize("opt_level", list(OptimizationLevel)) +# dense selector table packing boundaries at 256 and 65336 +@pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336]) +@settings(max_examples=5, deadline=None) +@given( + seed=st.integers(min_value=0, max_value=2**64 - 1), + max_default_args=st.integers(min_value=0, max_value=4), + default_fn_mutability=st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]), +) +@pytest.mark.fuzzing +def test_selector_table_fuzz( + max_calldata_bytes, + seed, + max_default_args, + opt_level, + default_fn_mutability, + w3, + get_contract, + assert_tx_failed, + get_logs, +): + def abi_sig(calldata_words, i, n_default_args): + args = [] if not calldata_words else [f"uint256[{calldata_words}]"] + args.extend(["uint256"] * n_default_args) + argstr = ",".join(args) + return f"foo{seed + i}({argstr})" + + def generate_func_def(mutability, calldata_words, i, n_default_args): + arglist = [] if not calldata_words else [f"x: uint256[{calldata_words}]"] + for j in range(n_default_args): + arglist.append(f"x{j}: uint256 = 0") + args = ", ".join(arglist) + _log_return = f"log _Return({i})" if mutability == "@payable" else "" + + return f""" +@external +{mutability} +def foo{seed + i}({args}) -> uint256: + {_log_return} + return {i} + """ + + @given( + methods=st.lists( + st.tuples( + st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), + st.integers(min_value=0, max_value=max_calldata_bytes // 32), + # n bytes to strip from calldata + st.integers(min_value=1, max_value=4), + # n default args + st.integers(min_value=0, max_value=max_default_args), + ), + min_size=1, + max_size=100, + ) + ) + @settings(max_examples=25) + def _test(methods): + func_defs = "\n".join( + generate_func_def(m, s, i, d) for i, (m, s, _, d) in enumerate(methods) + ) + + if default_fn_mutability == "": + default_fn_code = "" + elif default_fn_mutability in ("@nonpayable", "@payable"): + default_fn_code = f""" +@external +{default_fn_mutability} +def __default__(): + log CalledDefault() + """ + else: + # can't log from pure/view functions, just test that it returns + default_fn_code = """ +@external +def __default__(): + pass + """ + + code = f""" +event CalledDefault: + pass + +event _Return: + val: uint256 + +{func_defs} + +{default_fn_code} + """ + + c = get_contract(code, override_opt_level=opt_level) + + for i, (mutability, n_calldata_words, n_strip_bytes, n_default_args) in enumerate(methods): + funcname = f"foo{seed + i}" + func = getattr(c, funcname) + + for j in range(n_default_args + 1): + args = [[1] * n_calldata_words] if n_calldata_words else [] + args.extend([1] * j) + + # check the function returns as expected + assert func(*args) == i + + method_id = utils.method_id(abi_sig(n_calldata_words, i, j)) + + argsdata = b"\x00" * (n_calldata_words * 32 + j * 32) + + # do payable check + if mutability == "@payable": + tx = func(*args, transact={"value": 1}) + (event,) = get_logs(tx, c, "_Return") + assert event.args.val == i + else: + hexstr = (method_id + argsdata).hex() + txdata = {"to": c.address, "data": hexstr, "value": 1} + assert_tx_failed(lambda: w3.eth.send_transaction(txdata)) + + # now do calldatasize check + # strip some bytes + calldata = (method_id + argsdata)[:-n_strip_bytes] + hexstr = calldata.hex() + tx_params = {"to": c.address, "data": hexstr} + if n_calldata_words == 0 and j == 0: + # no args, hit default function + if default_fn_mutability == "": + assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + elif default_fn_mutability == "@payable": + # we should be able to send eth to it + tx_params["value"] = 1 + tx = w3.eth.send_transaction(tx_params) + logs = get_logs(tx, c, "CalledDefault") + assert len(logs) == 1 + else: + tx = w3.eth.send_transaction(tx_params) + + # note: can't emit logs from view/pure functions, + # so the logging is not tested. + if default_fn_mutability == "@nonpayable": + logs = get_logs(tx, c, "CalledDefault") + assert len(logs) == 1 + + # check default function reverts + tx_params["value"] = 1 + assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + else: + assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + + _test() diff --git a/tox.ini b/tox.ini index 9b63630f58..c949354dfe 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - py{310,311}-{core,no-opt} + py{310,311} lint mypy docs @@ -8,9 +8,7 @@ envlist = [testenv] usedevelop = True commands = - core: pytest -m "not fuzzing" --showlocals {posargs:tests/} - no-opt: pytest -m "not fuzzing" --showlocals --optimize none {posargs:tests/} - codesize: pytest -m "not fuzzing" --showlocals --optimize codesize {posargs:tests/} + pytest -m "not fuzzing" --showlocals {posargs:tests/} basepython = py310: python3.10 py311: python3.11 diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 77806d734c..ca9979b2a3 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -72,8 +72,8 @@ function_def: [decorators] function_sig ":" body _EVENT_DECL: "event" event_member: NAME ":" type indexed_event_arg: NAME ":" "indexed" "(" type ")" -event_body: _NEWLINE _INDENT ((event_member | indexed_event_arg) _NEWLINE)+ _DEDENT // Events which use no args use a pass statement instead +event_body: _NEWLINE _INDENT (((event_member | indexed_event_arg ) _NEWLINE)+ | _PASS _NEWLINE) _DEDENT event_def: _EVENT_DECL NAME ":" ( event_body | _PASS ) // Enums diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 55e0fc82b2..9c96d55040 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -11,7 +11,12 @@ import vyper.codegen.ir_node as ir_node from vyper.cli import vyper_json from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path -from vyper.compiler.settings import VYPER_TRACEBACK_LIMIT, OptimizationLevel, Settings +from vyper.compiler.settings import ( + VYPER_TRACEBACK_LIMIT, + OptimizationLevel, + Settings, + _set_debug_mode, +) from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS from vyper.typing import ContractCodes, ContractPath, OutputFormats @@ -105,7 +110,12 @@ def _parse_args(argv): dest="evm_version", ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") - parser.add_argument("--optimize", help="Optimization flag", choices=["gas", "codesize", "none"]) + parser.add_argument( + "--optimize", + help="Optimization flag (defaults to 'gas')", + choices=["gas", "codesize", "none"], + ) + parser.add_argument("--debug", help="Compile in debug mode", action="store_true") parser.add_argument( "--no-bytecode-metadata", help="Do not add metadata to bytecode", action="store_true" ) @@ -151,6 +161,9 @@ def _parse_args(argv): output_formats = tuple(uniq(args.format.split(","))) + if args.debug: + _set_debug_mode(True) + if args.no_optimize and args.optimize: raise ValueError("Cannot use `--no-optimize` and `--optimize` at the same time!") @@ -165,7 +178,7 @@ def _parse_args(argv): settings.evm_version = args.evm_version if args.verbose: - print(f"using `{settings}`", file=sys.stderr) + print(f"cli specified: `{settings}`", file=sys.stderr) compiled = compile_files( args.input_files, diff --git a/vyper/cli/vyper_ir.py b/vyper/cli/vyper_ir.py index 6831f39473..1f90badcaa 100755 --- a/vyper/cli/vyper_ir.py +++ b/vyper/cli/vyper_ir.py @@ -55,7 +55,7 @@ def compile_to_ir(input_file, output_formats, show_gas_estimates=False): compiler_data["asm"] = asm if "bytecode" in output_formats: - (bytecode, _srcmap) = compile_ir.assembly_to_evm(asm) + bytecode, _ = compile_ir.assembly_to_evm(asm) compiler_data["bytecode"] = "0x" + bytecode.hex() return compiler_data diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 47a2c8c8d0..e1d3ea12b4 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1033,7 +1033,6 @@ def eval_seq(ir_node): return None -# TODO move return checks to vyper/semantics/validation def is_return_from_function(node): if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in ( "raw_revert", @@ -1045,6 +1044,8 @@ def is_return_from_function(node): return False +# TODO this is almost certainly duplicated with check_terminus_node +# in vyper/semantics/analysis/local.py def check_single_exit(fn_node): _check_return_body(fn_node, fn_node.body) for node in fn_node.get_descendants(vy_ast.If): diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 08bebbb4a5..94617bef35 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1 @@ -from .common import generate_ir_for_function # noqa +from .common import FuncIR, generate_ir_for_function # noqa diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index fd65b12265..3fd5ce0b29 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -4,7 +4,7 @@ import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.core import check_single_exit, getpos +from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.global_context import GlobalContext @@ -63,12 +63,32 @@ def internal_function_label(self, is_ctor_context: bool = False) -> str: return self.ir_identifier + suffix +class FuncIR: + pass + + +@dataclass +class EntryPointInfo: + func_t: ContractFunctionT + min_calldatasize: int # the min calldata required for this entry point + ir_node: IRnode # the ir for this entry point + + +@dataclass +class ExternalFuncIR(FuncIR): + entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points + common_ir: IRnode # the "common" code for the function + + +@dataclass +class InternalFuncIR(FuncIR): + func_ir: IRnode # the code for the function + + +# TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, - global_ctx: GlobalContext, - skip_nonpayable_check: bool, - is_ctor_context: bool = False, -) -> IRnode: + code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False +) -> FuncIR: """ Parse a function and produce IR code for the function, includes: - Signature method if statement @@ -82,6 +102,7 @@ def generate_ir_for_function( func_t._ir_info = _FuncIRInfo(func_t) # Validate return statements. + # XXX: This should really be in semantics pass. check_single_exit(code) callees = func_t.called_functions @@ -106,19 +127,23 @@ def generate_ir_for_function( ) if func_t.is_internal: - assert skip_nonpayable_check is False - o = generate_ir_for_internal_function(code, func_t, context) + ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) + func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore else: - if func_t.is_payable: - assert skip_nonpayable_check is False # nonsense - o = generate_ir_for_external_function(code, func_t, context, skip_nonpayable_check) - - o.source_pos = getpos(code) + kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) + entry_points = { + k: EntryPointInfo(func_t, mincalldatasize, ir_node) + for k, (mincalldatasize, ir_node) in kwarg_handlers.items() + } + ret = ExternalFuncIR(entry_points, common) + # note: this ignores the cost of traversing selector table + func_t._ir_info.gas_estimate = ret.common_ir.gas frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY frame_info = FrameInfo(allocate_start, frame_size, context.vars) + # XXX: when can this happen? if func_t._ir_info.frame_info is None: func_t._ir_info.set_frame_info(frame_info) else: @@ -128,9 +153,7 @@ def generate_ir_for_function( # adjust gas estimate to include cost of mem expansion # frame_size of external function includes all private functions called # (note: internal functions do not need to adjust gas estimate since - # it is already accounted for by the caller.) - o.add_gas_estimate += calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - - func_t._ir_info.gas_estimate = o.gas + mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore + ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - return o + return ret diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 207356860b..32236e9aad 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -1,6 +1,3 @@ -from typing import Any, List - -import vyper.utils as util from vyper.codegen.abi_encoder import abi_encoding_matches_vyper from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp @@ -15,7 +12,7 @@ # register function args with the local calling context. # also allocate the ones that live in memory (i.e. kwargs) -def _register_function_args(func_t: ContractFunctionT, context: Context) -> List[IRnode]: +def _register_function_args(func_t: ContractFunctionT, context: Context) -> list[IRnode]: ret = [] # the type of the calldata base_args_t = TupleT(tuple(arg.typ for arg in func_t.positional_args)) @@ -52,13 +49,9 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> List return ret -def _annotated_method_id(abi_sig): - method_id = util.method_id_int(abi_sig) - annotation = f"{hex(method_id)}: {abi_sig}" - return IRnode(method_id, annotation=annotation) - - -def _generate_kwarg_handlers(func_t: ContractFunctionT, context: Context) -> List[Any]: +def _generate_kwarg_handlers( + func_t: ContractFunctionT, context: Context +) -> dict[str, tuple[int, IRnode]]: # generate kwarg handlers. # since they might come in thru calldata or be default, # allocate them in memory and then fill it in based on calldata or default, @@ -75,7 +68,6 @@ def handler_for(calldata_kwargs, default_kwargs): calldata_args_t = TupleT(list(arg.typ for arg in calldata_args)) abi_sig = func_t.abi_signature_for_kwargs(calldata_kwargs) - method_id = _annotated_method_id(abi_sig) calldata_kwargs_ofst = IRnode( 4, location=CALLDATA, typ=calldata_args_t, encoding=Encoding.ABI @@ -88,11 +80,6 @@ def handler_for(calldata_kwargs, default_kwargs): args_abi_t = calldata_args_t.abi_type calldata_min_size = args_abi_t.min_size() + 4 - # note we don't need the check if calldata_min_size == 4, - # because the global calldatasize check ensures that already. - if calldata_min_size > 4: - ret.append(["assert", ["ge", "calldatasize", calldata_min_size]]) - # TODO optimize make_setter by using # TupleT(list(arg.typ for arg in calldata_kwargs + default_kwargs)) # (must ensure memory area is contiguous) @@ -123,11 +110,10 @@ def handler_for(calldata_kwargs, default_kwargs): ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) - method_id_check = ["eq", "_calldata_method_id", method_id] - ret = ["if", method_id_check, ret] - return ret + # return something we can turn into ExternalFuncIR + return abi_sig, calldata_min_size, ret - ret = ["seq"] + ret = {} keyword_args = func_t.keyword_args @@ -139,9 +125,12 @@ def handler_for(calldata_kwargs, default_kwargs): calldata_kwargs = keyword_args[:i] default_kwargs = keyword_args[i:] - ret.append(handler_for(calldata_kwargs, default_kwargs)) + sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) + ret[sig] = calldata_min_size, ir_node + + sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret.append(handler_for(keyword_args, [])) + ret[sig] = calldata_min_size, ir_node return ret @@ -149,7 +138,7 @@ def handler_for(calldata_kwargs, default_kwargs): # TODO it would be nice if this returned a data structure which were # amenable to generating a jump table instead of the linear search for # method_id we have now. -def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_check): +def generate_ir_for_external_function(code, func_t, context): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, @@ -174,14 +163,6 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che # generate the main body of the function body += handle_base_args - if not func_t.is_payable and not skip_nonpayable_check: - # if the contract contains payable functions, but this is not one of them - # add an assertion that the value of the call is zero - nonpayable_check = IRnode.from_list( - ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" - ) - body.append(nonpayable_check) - body += nonreentrant_pre body += [parse_body(code.body, context, ensure_terminated=True)] @@ -201,22 +182,10 @@ def generate_ir_for_external_function(code, func_t, context, skip_nonpayable_che if context.return_type is not None: exit_sequence_args += ["ret_ofst", "ret_len"] # wrap the exit in a labeled block - exit = ["label", func_t._ir_info.exit_sequence_label, exit_sequence_args, exit_sequence] + exit_ = ["label", func_t._ir_info.exit_sequence_label, exit_sequence_args, exit_sequence] # the ir which comprises the main body of the function, # besides any kwarg handling - func_common_ir = ["seq", body, exit] - - if func_t.is_fallback or func_t.is_constructor: - ret = ["seq"] - # add a goto to make the function entry look like other functions - # (for zksync interpreter) - ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) - ret.append(func_common_ir) - else: - ret = kwarg_handlers - # sneak the base code into the kwarg handler - # TODO rethink this / make it clearer - ret[-1][-1].append(func_common_ir) + func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return IRnode.from_list(ret, source_pos=getpos(code)) + return kwarg_handlers, func_common_ir diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index fa015b293e..6cb0a07281 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -148,6 +148,13 @@ def _check(condition, err): self.valency = 1 self._gas = 5 + elif isinstance(self.value, bytes): + # a literal bytes value, probably inside a "data" node. + _check(len(self.args) == 0, "bytes can't have arguments") + + self.valency = 0 + self._gas = 0 + elif isinstance(self.value, str): # Opcodes and pseudo-opcodes (e.g. clamp) if self.value.upper() in get_ir_opcodes(): @@ -264,8 +271,11 @@ def _check(condition, err): self.valency = 0 self._gas = sum([arg.gas for arg in self.args]) elif self.value == "label": - if not self.args[1].value == "var_list": - raise CodegenPanic(f"2nd argument to label must be var_list, {self}") + _check( + self.args[1].value == "var_list", + f"2nd argument to label must be var_list, {self}", + ) + _check(len(args) == 3, f"label should have 3 args but has {len(args)}, {self}") self.valency = 0 self._gas = 1 + sum(t.gas for t in self.args) elif self.value == "unique_symbol": diff --git a/vyper/codegen/jumptable_utils.py b/vyper/codegen/jumptable_utils.py new file mode 100644 index 0000000000..6987ce90bd --- /dev/null +++ b/vyper/codegen/jumptable_utils.py @@ -0,0 +1,195 @@ +# helper module which implements jumptable for function selection +import math +from dataclasses import dataclass + +from vyper.utils import method_id_int + + +@dataclass +class Signature: + method_id: int + payable: bool + + +# bucket for dense function +@dataclass +class Bucket: + bucket_id: int + magic: int + method_ids: list[int] + + @property + def image(self): + return _image_of([s for s in self.method_ids], self.magic) + + @property + # return method ids, sorted by by their image + def method_ids_image_order(self): + return [x[1] for x in sorted(zip(self.image, self.method_ids))] + + @property + def bucket_size(self): + return len(self.method_ids) + + +BITS_MAGIC = 24 # a constant which produced good results, see _bench_dense() + + +def _image_of(xs, magic): + bits_shift = BITS_MAGIC + + # take the upper bits from the multiplication for more entropy + # can we do better using primes of some sort? + return [((x * magic) >> bits_shift) % len(xs) for x in xs] + + +class _Failure(Exception): + pass + + +def find_magic_for(xs): + for m in range(2**16): + test = _image_of(xs, m) + if len(test) == len(set(test)): + return m + + raise _Failure(f"Could not find hash for {xs}") + + +def _mk_buckets(method_ids, n_buckets): + buckets = {} + for x in method_ids: + t = x % n_buckets + buckets.setdefault(t, []) + buckets[t].append(x) + return buckets + + +# two layer method for generating perfect hash +# first get "reasonably good" distribution by using +# method_id % len(method_ids) +# second, get the magic for the bucket. +def _dense_jumptable_info(method_ids, n_buckets): + buckets = _mk_buckets(method_ids, n_buckets) + + ret = {} + for bucket_id, method_ids in buckets.items(): + magic = find_magic_for(method_ids) + ret[bucket_id] = Bucket(bucket_id, magic, method_ids) + + return ret + + +START_BUCKET_SIZE = 5 + + +# this is expensive! for 80 methods, costs about 350ms and probably +# linear in # of methods. +# see _bench_perfect() +# note the buckets are NOT in order! +def generate_dense_jumptable_info(signatures): + method_ids = [method_id_int(sig) for sig in signatures] + n = len(signatures) + # start at bucket size of 5 and try to improve (generally + # speaking we want as few buckets as possible) + n_buckets = (n // START_BUCKET_SIZE) + 1 + ret = None + tried_exhaustive = False + while n_buckets > 0: + try: + # print(f"trying {n_buckets} (bucket size {n // n_buckets})") + ret = _dense_jumptable_info(method_ids, n_buckets) + except _Failure: + if ret is not None: + break + + # we have not tried exhaustive search. try really hard + # to find a valid jumptable at the cost of performance + if not tried_exhaustive: + # print("failed with guess! trying exhaustive search.") + n_buckets = n + tried_exhaustive = True + continue + else: + raise RuntimeError(f"Could not generate jumptable! {signatures}") + n_buckets -= 1 + + return ret + + +# note the buckets are NOT in order! +def generate_sparse_jumptable_buckets(signatures): + method_ids = [method_id_int(sig) for sig in signatures] + n = len(signatures) + + # search a range of buckets to try to minimize bucket size + # (doing the range search improves worst worst bucket size from 9 to 4, + # see _bench_sparse) + lo = max(1, math.floor(n * 0.85)) + hi = max(1, math.ceil(n * 1.15)) + stats = {} + for i in range(lo, hi + 1): + buckets = _mk_buckets(method_ids, i) + + stats[i] = buckets + + min_max_bucket_size = hi + 1 # smallest max_bucket_size + # find the smallest i which gives us the smallest max_bucket_size + for i, buckets in stats.items(): + max_bucket_size = max(len(bucket) for bucket in buckets.values()) + if max_bucket_size < min_max_bucket_size: + min_max_bucket_size = max_bucket_size + ret = i, buckets + + assert ret is not None + return ret + + +# benchmark for quality of buckets +def _bench_dense(N=1_000, n_methods=100): + import random + + stats = [] + for i in range(N): + seed = random.randint(0, 2**64 - 1) + # "large" contracts in prod hit about ~50 methods, test with + # double the limit + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + + xs = generate_dense_jumptable_info(sigs) + print(f"found. n buckets {len(xs)}") + stats.append(xs) + + def mean(xs): + return sum(xs) / len(xs) + + avg_n_buckets = mean([len(jt) for jt in stats]) + # usually around ~14 buckets per 100 sigs + # N=10, time=3.6s + print(f"average N buckets: {avg_n_buckets}") + + +def _bench_sparse(N=10_000, n_methods=80): + import random + + stats = [] + for _ in range(N): + seed = random.randint(0, 2**64 - 1) + sigs = [f"foo{i + seed}()" for i in range(n_methods)] + _, buckets = generate_sparse_jumptable_buckets(sigs) + + bucket_sizes = [len(bucket) for bucket in buckets.values()] + worst_bucket_size = max(bucket_sizes) + mean_bucket_size = sum(bucket_sizes) / len(bucket_sizes) + stats.append((worst_bucket_size, mean_bucket_size)) + + # N=10_000, time=9s + # range 0.85*n - 1.15*n + # worst worst bucket size: 4 + # avg worst bucket size: 3.0018 + # worst mean bucket size: 2.0 + # avg mean bucket size: 1.579112583664968 + print("worst worst bucket size:", max(x[0] for x in stats)) + print("avg worst bucket size:", sum(x[0] for x in stats) / len(stats)) + print("worst mean bucket size:", max(x[1] for x in stats)) + print("avg mean bucket size:", sum(x[1] for x in stats) / len(stats)) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index b98e4d0f86..ebe7f92cf2 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -1,12 +1,15 @@ -# a contract.vy -- all functions and constructor +# a compilation unit -- all functions and constructor from typing import Any, List +from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic +from vyper.utils import method_id_int def _topsort_helper(functions, lookup): @@ -47,92 +50,349 @@ def _is_payable(func_ast): return func_ast._metadata["type"].is_payable -# codegen for all runtime functions + callvalue/calldata checks + method selector routines -def _runtime_ir(runtime_functions, global_ctx): - # categorize the runtime functions because we will organize the runtime - # code into the following sections: - # payable functions, nonpayable functions, fallback function, internal_functions - internal_functions = [f for f in runtime_functions if _is_internal(f)] +def _annotated_method_id(abi_sig): + method_id = method_id_int(abi_sig) + annotation = f"{hex(method_id)}: {abi_sig}" + return IRnode(method_id, annotation=annotation) - external_functions = [f for f in runtime_functions if not _is_internal(f)] - default_function = next((f for f in external_functions if _is_fallback(f)), None) - # functions that need to go exposed in the selector section - regular_functions = [f for f in external_functions if not _is_fallback(f)] - payables = [f for f in regular_functions if _is_payable(f)] - nonpayables = [f for f in regular_functions if not _is_payable(f)] +def label_for_entry_point(abi_sig, entry_point): + method_id = method_id_int(abi_sig) + return f"{entry_point.func_t._ir_info.ir_identifier}{method_id}" - # create a map of the IR functions since they might live in both - # runtime and deploy code (if init function calls them) - internal_functions_ir: list[IRnode] = [] - for func_ast in internal_functions: - func_ir = generate_ir_for_function(func_ast, global_ctx, False) - internal_functions_ir.append(func_ir) +# adapt whatever generate_ir_for_function gives us into an IR node +def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): + func_t = func_ast._metadata["type"] + assert func_t.is_fallback or func_t.is_constructor + + ret = ["seq"] + if not func_t.is_payable: + callvalue_check = ["assert", ["iszero", "callvalue"]] + ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) + + func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + assert len(func_ir.entry_points) == 1 + + # add a goto to make the function entry look like other functions + # (for zksync interpreter) + ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) + ret.append(func_ir.common_ir) + + return IRnode.from_list(ret) + + +def _ir_for_internal_function(func_ast, *args, **kwargs): + return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + + +def _generate_external_entry_points(external_functions, global_ctx): + entry_points = {} # map from ABI sigs to ir code + sig_of = {} # reverse map from method ids to abi sig + + for code in external_functions: + func_ir = generate_ir_for_function(code, global_ctx) + for abi_sig, entry_point in func_ir.entry_points.items(): + assert abi_sig not in entry_points + entry_points[abi_sig] = entry_point + sig_of[method_id_int(abi_sig)] = abi_sig + + # stick function common body into final entry point to save a jump + ir_node = IRnode.from_list(["seq", entry_point.ir_node, func_ir.common_ir]) + entry_point.ir_node = ir_node + + return entry_points, sig_of + + +# codegen for all runtime functions + callvalue/calldata checks, +# with O(1) jumptable for selector table. +# uses two level strategy: uses `method_id % n_buckets` to descend +# into a bucket (of about 8-10 items), and then uses perfect hash +# to select the final function. +# costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) +def _selector_section_dense(external_functions, global_ctx): + function_irs = [] - # for some reason, somebody may want to deploy a contract with no - # external functions, or more likely, a "pure data" contract which - # contains immutables if len(external_functions) == 0: - # TODO: prune internal functions in this case? dead code eliminator - # might not eliminate them, since internal function jumpdest is at the - # first instruction in the contract. - runtime = ["seq"] + internal_functions_ir - return runtime + return IRnode.from_list(["seq"]) - # note: if the user does not provide one, the default fallback function - # reverts anyway. so it does not hurt to batch the payable check. - default_is_nonpayable = default_function is None or not _is_payable(default_function) + entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) - # when a contract has a nonpayable default function, - # we can do a single check for all nonpayable functions - batch_payable_check = len(nonpayables) > 0 and default_is_nonpayable - skip_nonpayable_check = batch_payable_check + # generate the label so the jumptable works + for abi_sig, entry_point in entry_points.items(): + label = label_for_entry_point(abi_sig, entry_point) + ir_node = ["label", label, ["var_list"], entry_point.ir_node] + function_irs.append(IRnode.from_list(ir_node)) - selector_section = ["seq"] + jumptable_info = jumptable_utils.generate_dense_jumptable_info(entry_points.keys()) + n_buckets = len(jumptable_info) + + # bucket magic <2 bytes> | bucket location <2 bytes> | bucket size <1 byte> + # TODO: can make it smaller if the largest bucket magic <= 255 + SZ_BUCKET_HEADER = 5 - for func_ast in payables: - func_ir = generate_ir_for_function(func_ast, global_ctx, False) - selector_section.append(func_ir) + selector_section = ["seq"] - if batch_payable_check: - nonpayable_check = IRnode.from_list( - ["assert", ["iszero", "callvalue"]], error_msg="nonpayable check" + bucket_id = ["mod", "_calldata_method_id", n_buckets] + bucket_hdr_location = [ + "add", + ["symbol", "BUCKET_HEADERS"], + ["mul", bucket_id, SZ_BUCKET_HEADER], + ] + # get bucket header + dst = 32 - SZ_BUCKET_HEADER + assert dst >= 0 + + if _is_debug_mode(): + selector_section.append(["assert", ["eq", "msize", 0]]) + + selector_section.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER]) + + # figure out the minimum number of bytes we can use to encode + # min_calldatasize in function info + largest_mincalldatasize = max(f.min_calldatasize for f in entry_points.values()) + FN_METADATA_BYTES = (largest_mincalldatasize.bit_length() + 7) // 8 + + func_info_size = 4 + 2 + FN_METADATA_BYTES + # grab function info. + # method id <4 bytes> | label <2 bytes> | func info <1-3 bytes> + # func info (1-3 bytes, packed) for: expected calldatasize, is_nonpayable bit + # NOTE: might be able to improve codesize if we use variable # of bytes + # per bucket + + hdr_info = IRnode.from_list(["mload", 0]) + with hdr_info.cache_when_complex("hdr_info") as (b1, hdr_info): + bucket_location = ["and", 0xFFFF, shr(8, hdr_info)] + bucket_magic = shr(24, hdr_info) + bucket_size = ["and", 0xFF, hdr_info] + # ((method_id * bucket_magic) >> BITS_MAGIC) % bucket_size + func_id = [ + "mod", + shr(jumptable_utils.BITS_MAGIC, ["mul", bucket_magic, "_calldata_method_id"]), + bucket_size, + ] + func_info_location = ["add", bucket_location, ["mul", func_id, func_info_size]] + dst = 32 - func_info_size + assert func_info_size >= SZ_BUCKET_HEADER # otherwise mload will have dirty bytes + assert dst >= 0 + selector_section.append(b1.resolve(["codecopy", dst, func_info_location, func_info_size])) + + func_info = IRnode.from_list(["mload", 0]) + fn_metadata_mask = 2 ** (FN_METADATA_BYTES * 8) - 1 + calldatasize_mask = fn_metadata_mask - 1 # ex. 0xFFFE + with func_info.cache_when_complex("func_info") as (b1, func_info): + x = ["seq"] + + # expected calldatasize always satisfies (x - 4) % 32 == 0 + # the lower 5 bits are always 0b00100, so we can use those + # bits for other purposes. + is_nonpayable = ["and", 1, func_info] + expected_calldatasize = ["and", calldatasize_mask, func_info] + + label_bits_ofst = FN_METADATA_BYTES * 8 + function_label = ["and", 0xFFFF, shr(label_bits_ofst, func_info)] + method_id_bits_ofst = (FN_METADATA_BYTES + 2) * 8 + function_method_id = shr(method_id_bits_ofst, func_info) + + # check method id is right, if not then fallback. + # need to check calldatasize >= 4 in case there are + # trailing 0s in the method id. + calldatasize_valid = ["gt", "calldatasize", 3] + method_id_correct = ["eq", function_method_id, "_calldata_method_id"] + should_fallback = ["iszero", ["and", calldatasize_valid, method_id_correct]] + x.append(["if", should_fallback, ["goto", "fallback"]]) + + # assert callvalue == 0 if nonpayable + bad_callvalue = ["mul", is_nonpayable, "callvalue"] + # assert calldatasize at least minimum for the abi type + bad_calldatasize = ["lt", "calldatasize", expected_calldatasize] + failed_entry_conditions = ["or", bad_callvalue, bad_calldatasize] + check_entry_conditions = IRnode.from_list( + ["assert", ["iszero", failed_entry_conditions]], + error_msg="bad calldatasize or callvalue", ) - selector_section.append(nonpayable_check) + x.append(check_entry_conditions) + x.append(["jump", function_label]) + selector_section.append(b1.resolve(x)) + + bucket_headers = ["data", "BUCKET_HEADERS"] + + for bucket_id, bucket in sorted(jumptable_info.items()): + bucket_headers.append(bucket.magic.to_bytes(2, "big")) + bucket_headers.append(["symbol", f"bucket_{bucket_id}"]) + # note: buckets are usually ~10 items. to_bytes would + # fail if the int is too big. + bucket_headers.append(bucket.bucket_size.to_bytes(1, "big")) + + selector_section.append(bucket_headers) + + for bucket_id, bucket in jumptable_info.items(): + function_infos = ["data", f"bucket_{bucket_id}"] + # sort function infos by their image. + for method_id in bucket.method_ids_image_order: + abi_sig = sig_of[method_id] + entry_point = entry_points[abi_sig] + + method_id_bytes = method_id.to_bytes(4, "big") + symbol = ["symbol", label_for_entry_point(abi_sig, entry_point)] + func_metadata_int = entry_point.min_calldatasize | int( + not entry_point.func_t.is_payable + ) + func_metadata = func_metadata_int.to_bytes(FN_METADATA_BYTES, "big") - for func_ast in nonpayables: - func_ir = generate_ir_for_function(func_ast, global_ctx, skip_nonpayable_check) - selector_section.append(func_ir) + function_infos.extend([method_id_bytes, symbol, func_metadata]) - if default_function: - fallback_ir = generate_ir_for_function( - default_function, global_ctx, skip_nonpayable_check=False - ) - else: - fallback_ir = IRnode.from_list( - ["revert", 0, 0], annotation="Default function", error_msg="fallback function" - ) + selector_section.append(function_infos) - # ensure the external jumptable section gets closed out - # (for basic block hygiene and also for zksync interpreter) - # NOTE: this jump gets optimized out in assembly since the - # fallback label is the immediate next instruction, - close_selector_section = ["goto", "fallback"] + ret = ["seq", ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section]] - global_calldatasize_check = ["if", ["lt", "calldatasize", 4], ["goto", "fallback"]] + ret.extend(function_irs) - runtime = [ - "seq", - global_calldatasize_check, - ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section], - close_selector_section, - ["label", "fallback", ["var_list"], fallback_ir], - ] + return ret - runtime.extend(internal_functions_ir) - return runtime +# codegen for all runtime functions + callvalue/calldata checks, +# with O(1) jumptable for selector table. +# uses two level strategy: uses `method_id % n_methods` to calculate +# a bucket, and then descends into linear search from there. +# costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) +# function and 24 bytes of code (+ ~23 bytes of global overhead) +def _selector_section_sparse(external_functions, global_ctx): + ret = ["seq"] + + if len(external_functions) == 0: + return ret + + entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + + n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) + + # 2 bytes for bucket location + SZ_BUCKET_HEADER = 2 + + if n_buckets > 1: + bucket_id = ["mod", "_calldata_method_id", n_buckets] + bucket_hdr_location = [ + "add", + ["symbol", "selector_buckets"], + ["mul", bucket_id, SZ_BUCKET_HEADER], + ] + # get bucket header + dst = 32 - SZ_BUCKET_HEADER + assert dst >= 0 + + if _is_debug_mode(): + ret.append(["assert", ["eq", "msize", 0]]) + + ret.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER]) + + jumpdest = IRnode.from_list(["mload", 0]) + # don't particularly like using `jump` here since it can cause + # issues for other backends, consider changing `goto` to allow + # dynamic jumps, or adding some kind of jumptable instruction + ret.append(["jump", jumpdest]) + + jumptable_data = ["data", "selector_buckets"] + for i in range(n_buckets): + if i in buckets: + bucket_label = f"selector_bucket_{i}" + jumptable_data.append(["symbol", bucket_label]) + else: + # empty bucket + jumptable_data.append(["symbol", "fallback"]) + + ret.append(jumptable_data) + + for bucket_id, bucket in buckets.items(): + bucket_label = f"selector_bucket_{bucket_id}" + ret.append(["label", bucket_label, ["var_list"], ["seq"]]) + + handle_bucket = ["seq"] + + for method_id in bucket: + sig = sig_of[method_id] + entry_point = entry_points[sig] + func_t = entry_point.func_t + expected_calldatasize = entry_point.min_calldatasize + + dispatch = ["seq"] # code to dispatch into the function + skip_callvalue_check = func_t.is_payable + skip_calldatasize_check = expected_calldatasize == 4 + bad_callvalue = [0] if skip_callvalue_check else ["callvalue"] + bad_calldatasize = ( + [0] if skip_calldatasize_check else ["lt", "calldatasize", expected_calldatasize] + ) + + dispatch.append( + IRnode.from_list( + ["assert", ["iszero", ["or", bad_callvalue, bad_calldatasize]]], + error_msg="bad calldatasize or callvalue", + ) + ) + # we could skip a jumpdest per method if we out-lined the entry point + # so the dispatcher looks just like - + # ```(if (eq method_id) + # (goto entry_point_label))``` + # it would another optimization for patterns like + # `if ... (goto)` though. + dispatch.append(entry_point.ir_node) + + method_id_check = ["eq", "_calldata_method_id", _annotated_method_id(sig)] + has_trailing_zeroes = method_id.to_bytes(4, "big").endswith(b"\x00") + if has_trailing_zeroes: + # if the method id check has trailing 0s, we need to include + # a calldatasize check to distinguish from when not enough + # bytes are provided for the method id in calldata. + method_id_check = ["and", ["ge", "calldatasize", 4], method_id_check] + handle_bucket.append(["if", method_id_check, dispatch]) + + # close out the bucket with a goto fallback so we don't keep searching + handle_bucket.append(["goto", "fallback"]) + + ret.append(handle_bucket) + + ret = ["seq", ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), ret]] + + return ret + + +# codegen for all runtime functions + callvalue/calldata checks, +# O(n) linear search for the method id +# mainly keep this in for backends which cannot handle the indirect jump +# in selector_section_dense and selector_section_sparse +def _selector_section_linear(external_functions, global_ctx): + ret = ["seq"] + if len(external_functions) == 0: + return ret + + ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) + + entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + + dispatcher = ["seq"] + + for sig, entry_point in entry_points.items(): + func_t = entry_point.func_t + expected_calldatasize = entry_point.min_calldatasize + + dispatch = ["seq"] # code to dispatch into the function + + if not func_t.is_payable: + callvalue_check = ["assert", ["iszero", "callvalue"]] + dispatch.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) + + good_calldatasize = ["ge", "calldatasize", expected_calldatasize] + calldatasize_check = ["assert", good_calldatasize] + dispatch.append(IRnode.from_list(calldatasize_check, error_msg="calldatasize check")) + + dispatch.append(entry_point.ir_node) + + method_id_check = ["eq", "_calldata_method_id", _annotated_method_id(sig)] + dispatcher.append(["if", method_id_check, dispatch]) + + ret.append(["with", "_calldata_method_id", shr(224, ["calldataload", 0]), dispatcher]) + + return ret # take a GlobalContext, and generate the runtime and deploy IR @@ -143,15 +403,47 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) - runtime = _runtime_ir(runtime_functions, global_ctx) + internal_functions = [f for f in runtime_functions if _is_internal(f)] + + external_functions = [ + f for f in runtime_functions if not _is_internal(f) and not _is_fallback(f) + ] + default_function = next((f for f in runtime_functions if _is_fallback(f)), None) + + internal_functions_ir: list[IRnode] = [] + + # compile internal functions first so we have the function info + for func_ast in internal_functions: + func_ir = _ir_for_internal_function(func_ast, global_ctx, False) + internal_functions_ir.append(IRnode.from_list(func_ir)) + + if core._opt_none(): + selector_section = _selector_section_linear(external_functions, global_ctx) + # dense vs sparse global overhead is amortized after about 4 methods. + # (--debug will force dense selector table anyway if _opt_codesize is selected.) + elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): + selector_section = _selector_section_dense(external_functions, global_ctx) + else: + selector_section = _selector_section_sparse(external_functions, global_ctx) + + if default_function: + fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx) + else: + fallback_ir = IRnode.from_list( + ["revert", 0, 0], annotation="Default function", error_msg="fallback function" + ) + + runtime = ["seq", selector_section] + runtime.append(["goto", "fallback"]) + runtime.append(["label", "fallback", ["var_list"], fallback_ir]) + + runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] immutables_len = global_ctx.immutable_section_bytes if init_function: # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = generate_ir_for_function( - init_function, global_ctx, skip_nonpayable_check=False, is_ctor_context=True - ) + init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables @@ -184,12 +476,10 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: for f in internal_functions: init_func_t = init_function._metadata["type"] if f.name not in init_func_t.recursive_calls: - # unreachable + # unreachable code, delete it continue - func_ir = generate_ir_for_function( - f, global_ctx, skip_nonpayable_check=False, is_ctor_context=True - ) + func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) deploy_code.append(func_ir) else: diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 63d92d9a47..69fcbf1f1f 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -300,9 +300,13 @@ def _build_opcodes(bytecode: bytes) -> str: while bytecode_sequence: op = bytecode_sequence.popleft() - opcode_output.append(opcode_map[op]) + opcode_output.append(opcode_map.get(op, f"VERBATIM_{hex(op)}")) if "PUSH" in opcode_output[-1] and opcode_output[-1] != "PUSH0": push_len = int(opcode_map[op][4:]) + # we can have push_len > len(bytecode_sequence) when there is data + # (instead of code) at end of contract + # CMC 2023-07-13 maybe just strip known data segments? + push_len = min(push_len, len(bytecode_sequence)) push_values = [hex(bytecode_sequence.popleft())[2:] for i in range(push_len)] opcode_output.append(f"0x{''.join(push_values).upper()}") diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index bb5e9cdc25..d2c88a8592 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -42,3 +42,16 @@ class Settings: compiler_version: Optional[str] = None optimize: Optional[OptimizationLevel] = None evm_version: Optional[str] = None + + +_DEBUG = False + + +def _is_debug_mode(): + global _DEBUG + return _DEBUG + + +def _set_debug_mode(dbg: bool = False) -> None: + global _DEBUG + _DEBUG = dbg diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index a9064a44fa..5e29bad0b5 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -158,11 +158,20 @@ def _add_postambles(asm_ops): to_append.extend(_revert_string) if len(to_append) > 0: + # insert the postambles *before* runtime code + # so the data section of the runtime code can't bork the postambles. + runtime = None + if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader): + runtime = asm_ops.pop() + # for some reason there might not be a STOP at the end of asm_ops. # (generally vyper programs will have it but raw IR might not). asm_ops.append("STOP") asm_ops.extend(to_append) + if runtime: + asm_ops.append(runtime) + # need to do this recursively since every sublist is basically # treated as its own program (there are no global labels.) for t in asm_ops: @@ -213,6 +222,9 @@ def compile_to_assembly(code, optimize=OptimizationLevel.GAS): res = _compile_to_assembly(code) _add_postambles(res) + + _relocate_segments(res) + if optimize != OptimizationLevel.NONE: _optimize_assembly(res) return res @@ -500,14 +512,14 @@ def _height_of(witharg): assert isinstance(memsize, int), "non-int memsize" assert isinstance(padding, int), "non-int padding" - begincode = mksymbol("runtime_begin") + runtime_begin = mksymbol("runtime_begin") subcode = _compile_to_assembly(ir) o = [] # COPY the code to memory for deploy - o.extend(["_sym_subcode_size", begincode, "_mem_deploy_start", "CODECOPY"]) + o.extend(["_sym_subcode_size", runtime_begin, "_mem_deploy_start", "CODECOPY"]) # calculate the len of runtime code o.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len @@ -517,10 +529,9 @@ def _height_of(witharg): # since the asm data structures are very primitive, to make sure # assembly_to_evm is able to calculate data offsets correctly, # we pass the memsize via magic opcodes to the subcode - subcode = [f"_DEPLOY_MEM_OFST_{memsize}"] + subcode + subcode = [_RuntimeHeader(runtime_begin, memsize)] + subcode # append the runtime code after the ctor code - o.extend([begincode, "BLANK"]) # `append(...)` call here is intentional. # each sublist is essentially its own program with its # own symbols. @@ -661,16 +672,36 @@ def _height_of(witharg): height, ) + elif code.value == "data": + data_node = [_DataHeader("_sym_" + code.args[0].value)] + + for c in code.args[1:]: + if isinstance(c.value, int): + assert 0 <= c < 256, f"invalid data byte {c}" + data_node.append(c.value) + elif isinstance(c.value, bytes): + data_node.append(c.value) + elif isinstance(c, IRnode): + assert c.value == "symbol" + data_node.extend( + _compile_to_assembly(c, withargs, existing_labels, break_dest, height) + ) + else: + raise ValueError(f"Invalid data: {type(c)} {c}") + + # intentionally return a sublist. + return [data_node] + # jump to a symbol, and push variable # of arguments onto stack elif code.value == "goto": o = [] for i, c in enumerate(reversed(code.args[1:])): o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) - o.extend(["_sym_" + str(code.args[0]), "JUMP"]) + o.extend(["_sym_" + code.args[0].value, "JUMP"]) return o # push a literal symbol elif code.value == "symbol": - return ["_sym_" + str(code.args[0])] + return ["_sym_" + code.args[0].value] # set a symbol as a location. elif code.value == "label": label_name = code.args[0].value @@ -728,8 +759,8 @@ def _height_of(witharg): # inject debug opcode. elif code.value == "pc_debugger": return mkdebug(pc_debugger=True, source_pos=code.source_pos) - else: - raise Exception("Weird code element: " + repr(code)) + else: # pragma: no cover + raise ValueError(f"Weird code element: {type(code)} {code}") def note_line_num(line_number_map, item, pos): @@ -764,11 +795,8 @@ def note_breakpoint(line_number_map, item, pos): def _prune_unreachable_code(assembly): - # In converting IR to assembly we sometimes end up with unreachable - # instructions - POPing to clear the stack or STOPing execution at the - # end of a function that has already returned or reverted. This should - # be addressed in the IR, but for now we do a final sanity check here - # to avoid unnecessary bytecode bloat. + # delete code between terminal ops and JUMPDESTS as those are + # unreachable changed = False i = 0 while i < len(assembly) - 2: @@ -777,7 +805,7 @@ def _prune_unreachable_code(assembly): instr = assembly[i][-1] if assembly[i] in _TERMINAL_OPS and not ( - is_symbol(assembly[i + 1]) and assembly[i + 2] in ("JUMPDEST", "BLANK") + is_symbol(assembly[i + 1]) or isinstance(assembly[i + 1], list) ): changed = True del assembly[i + 1] @@ -889,6 +917,14 @@ def _merge_iszero(assembly): return changed +# a symbol _sym_x in assembly can either mean to push _sym_x to the stack, +# or it can precede a location in code which we want to add to symbol map. +# this helper function tells us if we want to add the previous instruction +# to the symbol map. +def is_symbol_map_indicator(asm_node): + return asm_node == "JUMPDEST" + + def _prune_unused_jumpdests(assembly): changed = False @@ -896,9 +932,17 @@ def _prune_unused_jumpdests(assembly): # find all used jumpdests for i in range(len(assembly) - 1): - if is_symbol(assembly[i]) and assembly[i + 1] != "JUMPDEST": + if is_symbol(assembly[i]) and not is_symbol_map_indicator(assembly[i + 1]): used_jumpdests.add(assembly[i]) + for item in assembly: + if isinstance(item, list) and isinstance(item[0], _DataHeader): + # add symbols used in data sections as they are likely + # used for a jumptable. + for t in item: + if is_symbol(t): + used_jumpdests.add(t) + # delete jumpdests that aren't used i = 0 while i < len(assembly) - 2: @@ -937,7 +981,7 @@ def _stack_peephole_opts(assembly): # optimize assembly, in place def _optimize_assembly(assembly): for x in assembly: - if isinstance(x, list): + if isinstance(x, list) and isinstance(x[0], _RuntimeHeader): _optimize_assembly(x) for _ in range(1024): @@ -970,7 +1014,93 @@ def adjust_pc_maps(pc_maps, ofst): return ret +SYMBOL_SIZE = 2 # size of a PUSH instruction for a code symbol + + +def _data_to_evm(assembly, symbol_map): + ret = bytearray() + assert isinstance(assembly[0], _DataHeader) + for item in assembly[1:]: + if is_symbol(item): + symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big") + ret.extend(symbol) + elif isinstance(item, int): + ret.append(item) + elif isinstance(item, bytes): + ret.extend(item) + else: + raise ValueError(f"invalid data {type(item)} {item}") + + return ret + + +# predict what length of an assembly [data] node will be in bytecode +def _length_of_data(assembly): + ret = 0 + assert isinstance(assembly[0], _DataHeader) + for item in assembly[1:]: + if is_symbol(item): + ret += SYMBOL_SIZE + elif isinstance(item, int): + assert 0 <= item < 256, f"invalid data byte {item}" + ret += 1 + elif isinstance(item, bytes): + ret += len(item) + else: + raise ValueError(f"invalid data {type(item)} {item}") + + return ret + + +class _RuntimeHeader: + def __init__(self, label, ctor_mem_size): + self.label = label + self.ctor_mem_size = ctor_mem_size + + def __repr__(self): + return f"" + + +class _DataHeader: + def __init__(self, label): + self.label = label + + def __repr__(self): + return f"DATA {self.label}" + + +def _relocate_segments(assembly): + # relocate all data segments to the end, otherwise data could be + # interpreted as PUSH instructions and mangle otherwise valid jumpdests + # relocate all runtime segments to the end as well + data_segments = [] + non_data_segments = [] + code_segments = [] + for t in assembly: + if isinstance(t, list): + if isinstance(t[0], _DataHeader): + data_segments.append(t) + else: + _relocate_segments(t) # recurse + assert isinstance(t[0], _RuntimeHeader) + code_segments.append(t) + else: + non_data_segments.append(t) + assembly.clear() + assembly.extend(non_data_segments) + assembly.extend(code_segments) + assembly.extend(data_segments) + + +# TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): + bytecode, source_maps, _ = assembly_to_evm_with_symbol_map( + assembly, pc_ofst=pc_ofst, insert_vyper_signature=insert_vyper_signature + ) + return bytecode, source_maps + + +def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_vyper_signature=False): """ Assembles assembly into EVM @@ -999,8 +1129,6 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): bytecode_suffix += b"\xa1\x65vyper\x83" + bytes(list(version_tuple)) bytecode_suffix += len(bytecode_suffix).to_bytes(2, "big") - CODE_OFST_SIZE = 2 # size of a PUSH instruction for a code symbol - # to optimize the size of deploy code - we want to use the smallest # PUSH instruction possible which can support all memory symbols # (and also works with linear pass symbol resolution) @@ -1009,13 +1137,13 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): mem_ofst_size, ctor_mem_size = None, None max_mem_ofst = 0 for i, item in enumerate(assembly): - if isinstance(item, list): + if isinstance(item, list) and isinstance(item[0], _RuntimeHeader): assert runtime_code is None, "Multiple subcodes" - runtime_code, runtime_map = assembly_to_evm(item) - assert item[0].startswith("_DEPLOY_MEM_OFST_") assert ctor_mem_size is None - ctor_mem_size = int(item[0][len("_DEPLOY_MEM_OFST_") :]) + ctor_mem_size = item[0].ctor_mem_size + + runtime_code, runtime_map = assembly_to_evm(item[1:]) runtime_code_start, runtime_code_end = _runtime_code_offsets( ctor_mem_size, len(runtime_code) @@ -1053,14 +1181,14 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): # update pc if is_symbol(item): - if assembly[i + 1] == "JUMPDEST" or assembly[i + 1] == "BLANK": + if is_symbol_map_indicator(assembly[i + 1]): # Don't increment pc as the symbol itself doesn't go into code if item in symbol_map: raise CompilerPanic(f"duplicate jumpdest {item}") symbol_map[item] = pc else: - pc += CODE_OFST_SIZE + 1 # PUSH2 highbits lowbits + pc += SYMBOL_SIZE + 1 # PUSH2 highbits lowbits elif is_mem_sym(item): # PUSH item pc += mem_ofst_size + 1 @@ -1070,19 +1198,16 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): # [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar) # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) pc -= 1 - elif item == "BLANK": - pc += 0 - elif isinstance(item, str) and item.startswith("_DEPLOY_MEM_OFST_"): - # _DEPLOY_MEM_OFST is assembly magic which will - # get removed during final assembly-to-bytecode - pc += 0 - elif isinstance(item, list): + elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + symbol_map[item[0].label] = pc # add source map for all items in the runtime map t = adjust_pc_maps(runtime_map, pc) for key in line_number_map: line_number_map[key].update(t[key]) pc += len(runtime_code) - + elif isinstance(item, list) and isinstance(item[0], _DataHeader): + symbol_map[item[0].label] = pc + pc += _length_of_data(item) else: pc += 1 @@ -1094,13 +1219,9 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): if runtime_code is not None: symbol_map["_sym_subcode_size"] = len(runtime_code) - # (NOTE CMC 2022-06-17 this way of generating bytecode did not - # seem to be a perf hotspot. if it is, may want to use bytearray() - # instead). - - # TODO refactor into two functions, create posmap and assemble + # TODO refactor into two functions, create symbol_map and assemble - o = b"" + ret = bytearray() # now that all symbols have been resolved, generate bytecode # using the symbol map @@ -1110,47 +1231,47 @@ def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): to_skip -= 1 continue - if item in ("DEBUG", "BLANK"): + if item in ("DEBUG",): continue # skippable opcodes - elif isinstance(item, str) and item.startswith("_DEPLOY_MEM_OFST_"): - continue - elif is_symbol(item): - if assembly[i + 1] != "JUMPDEST" and assembly[i + 1] != "BLANK": - bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=CODE_OFST_SIZE)) - o += bytecode + # push a symbol to stack + if not is_symbol_map_indicator(assembly[i + 1]): + bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=SYMBOL_SIZE)) + ret.extend(bytecode) elif is_mem_sym(item): bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=mem_ofst_size)) - o += bytecode + ret.extend(bytecode) elif is_ofst(item): # _OFST _sym_foo 32 ofst = symbol_map[assembly[i + 1]] + assembly[i + 2] - n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else CODE_OFST_SIZE + n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else SYMBOL_SIZE bytecode, _ = assembly_to_evm(PUSH_N(ofst, n)) - o += bytecode + ret.extend(bytecode) to_skip = 2 elif isinstance(item, int): - o += bytes([item]) + ret.append(item) elif isinstance(item, str) and item.upper() in get_opcodes(): - o += bytes([get_opcodes()[item.upper()][0]]) + ret.append(get_opcodes()[item.upper()][0]) elif item[:4] == "PUSH": - o += bytes([PUSH_OFFSET + int(item[4:])]) + ret.append(PUSH_OFFSET + int(item[4:])) elif item[:3] == "DUP": - o += bytes([DUP_OFFSET + int(item[3:])]) + ret.append(DUP_OFFSET + int(item[3:])) elif item[:4] == "SWAP": - o += bytes([SWAP_OFFSET + int(item[4:])]) - elif isinstance(item, list): - o += runtime_code - else: - # Should never reach because, assembly is create in _compile_to_assembly. - raise Exception("Weird symbol in assembly: " + str(item)) # pragma: no cover + ret.append(SWAP_OFFSET + int(item[4:])) + elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + ret.extend(runtime_code) + elif isinstance(item, list) and isinstance(item[0], _DataHeader): + ret.extend(_data_to_evm(item, symbol_map)) + else: # pragma: no cover + # unreachable + raise ValueError(f"Weird symbol in assembly: {type(item)} {item}") - o += bytecode_suffix + ret.extend(bytecode_suffix) line_number_map["breakpoints"] = list(line_number_map["breakpoints"]) line_number_map["pc_breakpoints"] = list(line_number_map["pc_breakpoints"]) - return o, line_number_map + return bytes(ret), line_number_map, symbol_map From 019a37ab98ff53f04fecfadf602b6cd5ac748f7f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 25 Jul 2023 07:41:12 -0700 Subject: [PATCH 052/201] Merge pull request from GHSA-f5x6-7qgp-jhf3 --- tests/parser/functions/test_ecrecover.py | 18 ++++++++++++++ vyper/builtins/functions.py | 30 ++++++++---------------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/tests/parser/functions/test_ecrecover.py b/tests/parser/functions/test_ecrecover.py index 77e9655b3e..40c9a6a936 100644 --- a/tests/parser/functions/test_ecrecover.py +++ b/tests/parser/functions/test_ecrecover.py @@ -40,3 +40,21 @@ def test_ecrecover_uints2() -> address: assert c.test_ecrecover_uints2() == local_account.address print("Passed ecrecover test") + + +def test_invalid_signature(get_contract): + code = """ +dummies: HashMap[address, HashMap[address, uint256]] + +@external +def test_ecrecover(hash: bytes32, v: uint8, r: uint256) -> address: + # read from hashmap to put garbage in 0 memory location + s: uint256 = self.dummies[msg.sender][msg.sender] + return ecrecover(hash, v, r, s) + """ + c = get_contract(code) + hash_ = bytes(i for i in range(32)) + v = 0 # invalid v! ecrecover precompile will not write to output buffer + r = 0 + # note web3.py decoding of 0x000..00 address is None. + assert c.test_ecrecover(hash_, v, r) is None diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index e1dcee6b8d..685d832c01 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -764,29 +764,19 @@ def infer_arg_types(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - placeholder_node = IRnode.from_list( - context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY - ) + input_buf = context.new_internal_variable(get_type_for_exact_size(128)) + output_buf = MemoryPositions.FREE_VAR_SPACE return IRnode.from_list( [ "seq", - ["mstore", placeholder_node, args[0]], - ["mstore", ["add", placeholder_node, 32], args[1]], - ["mstore", ["add", placeholder_node, 64], args[2]], - ["mstore", ["add", placeholder_node, 96], args[3]], - [ - "pop", - [ - "staticcall", - ["gas"], - 1, - placeholder_node, - 128, - MemoryPositions.FREE_VAR_SPACE, - 32, - ], - ], - ["mload", MemoryPositions.FREE_VAR_SPACE], + # clear output memory first, ecrecover can return 0 bytes + ["mstore", output_buf, 0], + ["mstore", input_buf, args[0]], + ["mstore", input_buf + 32, args[1]], + ["mstore", input_buf + 64, args[2]], + ["mstore", input_buf + 96, args[3]], + ["staticcall", "gas", 1, input_buf, 128, output_buf, 32], + ["mload", output_buf], ], typ=AddressT(), ) From d48438e10722db8fc9e74d8ed434745e3b0d31cf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 25 Jul 2023 12:56:31 -0700 Subject: [PATCH 053/201] feat: implement bound= in ranges (#3537) --- .../semantics/analysis/test_for_loop.py | 35 ++++++++++++++++++- .../features/iteration/test_for_range.py | 17 +++++++++ vyper/codegen/stmt.py | 21 ++++++++--- vyper/ir/compile_ir.py | 3 +- vyper/semantics/analysis/annotation.py | 3 ++ vyper/semantics/analysis/local.py | 27 ++++++++++---- 6 files changed, 91 insertions(+), 15 deletions(-) diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/functional/semantics/analysis/test_for_loop.py index 8707b4c326..0d61a8f8f8 100644 --- a/tests/functional/semantics/analysis/test_for_loop.py +++ b/tests/functional/semantics/analysis/test_for_loop.py @@ -1,7 +1,12 @@ import pytest from vyper.ast import parse_to_ast -from vyper.exceptions import ImmutableViolation, TypeMismatch +from vyper.exceptions import ( + ArgumentException, + ImmutableViolation, + StateAccessViolation, + TypeMismatch, +) from vyper.semantics.analysis import validate_semantics @@ -59,6 +64,34 @@ def bar(): validate_semantics(vyper_module, {}) +def test_bad_keywords(namespace): + code = """ + +@internal +def bar(n: uint256): + x: uint256 = 0 + for i in range(n, boundddd=10): + x += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ArgumentException): + validate_semantics(vyper_module, {}) + + +def test_bad_bound(namespace): + code = """ + +@internal +def bar(n: uint256): + x: uint256 = 0 + for i in range(n, bound=n): + x += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(StateAccessViolation): + validate_semantics(vyper_module, {}) + + def test_modify_iterator_function_call(namespace): code = """ diff --git a/tests/parser/features/iteration/test_for_range.py b/tests/parser/features/iteration/test_for_range.py index 30f4bb87e3..395dd28231 100644 --- a/tests/parser/features/iteration/test_for_range.py +++ b/tests/parser/features/iteration/test_for_range.py @@ -14,6 +14,23 @@ def repeat(z: int128) -> int128: assert c.repeat(9) == 54 +def test_range_bound(get_contract, assert_tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(n, bound=6): + x += i + return x + """ + c = get_contract(code) + for n in range(7): + assert c.repeat(n) == sum(range(n)) + + # check codegen inserts assertion for n greater than bound + assert_tx_failed(lambda: c.repeat(7)) + + def test_digit_reverser(get_contract_with_gas_estimation): digit_reverser = """ @external diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 91d45f4916..86ea1813ea 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -258,11 +258,17 @@ def _parse_For_range(self): arg0 = self.stmt.iter.args[0] num_of_args = len(self.stmt.iter.args) + kwargs = { + s.arg: Expr.parse_value_expr(s.value, self.context) + for s in self.stmt.iter.keywords or [] + } + # Type 1 for, e.g. for i in range(10): ... if num_of_args == 1: - arg0_val = self._get_range_const_value(arg0) + n = Expr.parse_value_expr(arg0, self.context) start = IRnode.from_list(0, typ=iter_typ) - rounds = arg0_val + rounds = n + rounds_bound = kwargs.get("bound", rounds) # Type 2 for, e.g. for i in range(100, 110): ... elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal: @@ -270,6 +276,7 @@ def _parse_For_range(self): arg1_val = self._get_range_const_value(self.stmt.iter.args[1]) start = IRnode.from_list(arg0_val, typ=iter_typ) rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ) + rounds_bound = rounds # Type 3 for, e.g. for i in range(x, x + 10): ... else: @@ -278,9 +285,10 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) _, hi = start.typ.int_bounds start = clamp("le", start, hi + 1 - rounds) + rounds_bound = rounds - r = rounds if isinstance(rounds, int) else rounds.value - if r < 1: + bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value + if bound < 1: return varname = self.stmt.target.id @@ -294,7 +302,10 @@ def _parse_For_range(self): loop_body.append(["mstore", iptr, i]) loop_body.append(parse_body(self.stmt.body, self.context)) - ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds, loop_body]) + # NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound. + # if we ever want to remove that, we need to manually add the assertion + # where it makes sense. + ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds_bound, loop_body]) del self.context.forvars[varname] return ir_node diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 5e29bad0b5..bba3b34515 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -413,9 +413,8 @@ def _height_of(witharg): ) # stack: i, rounds, rounds_bound # assert rounds <= rounds_bound - # TODO this runtime assertion should never fail for + # TODO this runtime assertion shouldn't fail for # internally generated repeats. - # maybe drop it or jump to 0xFE o.extend(["DUP2", "GT"] + _assert_false()) # stack: i, rounds diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index 3ea0319b54..d309f102cd 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -95,6 +95,9 @@ def visit_For(self, node): iter_type = node.target._metadata["type"] for a in node.iter.args: self.expr_visitor.visit(a, iter_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, iter_type) class ExpressionAnnotationVisitor(_AnnotationVisitorBase): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c99b582ad3..c0c05325f2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -346,17 +346,30 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - validate_call_args(node.iter, (1, 2)) + validate_call_args(node.iter, (1, 2), kwargs=["bound"]) args = node.iter.args + kwargs = {s.arg: s.value for s in node.iter.keywords or []} if len(args) == 1: # range(CONSTANT) - if not isinstance(args[0], vy_ast.Num): - raise StateAccessViolation("Value must be a literal", node) - if args[0].value <= 0: - raise StructureException("For loop must have at least 1 iteration", args[0]) - validate_expected_type(args[0], IntegerT.any()) - type_list = get_possible_types_from_node(args[0]) + n = args[0] + bound = kwargs.pop("bound", None) + validate_expected_type(n, IntegerT.any()) + + if bound is None: + if not isinstance(n, vy_ast.Num): + raise StateAccessViolation("Value must be a literal", n) + if n.value <= 0: + raise StructureException("For loop must have at least 1 iteration", args[0]) + type_list = get_possible_types_from_node(n) + + else: + if not isinstance(bound, vy_ast.Num): + raise StateAccessViolation("bound must be a literal", bound) + if bound.value <= 0: + raise StructureException("bound must be at least 1", args[0]) + type_list = get_common_types(n, bound) + else: validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args) From 2f39e69d077fb8ab90bd6fe039372dd4fe5cadde Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 25 Jul 2023 16:33:41 -0700 Subject: [PATCH 054/201] fix: public constant arrays (#3536) public getters for arrays would panic at codegen because type information for the array members was not available. this is because type annotation would occur before getter expansion. this commit moves the type annotation phase to right before getter expansion, so that the generated ast nodes will get annotated. it also fixes a small bug when trying to deepcopy the nodes generated by ast expansion - the generated nodes have no node_id and raise an exception when deepcopy tries to perform `__eq__` between two of the generated FunctionDefs. --- tests/parser/globals/test_getters.py | 2 ++ vyper/ast/expansion.py | 1 - vyper/ast/nodes.py | 2 +- vyper/compiler/phases.py | 1 - vyper/semantics/analysis/__init__.py | 3 +++ 5 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/parser/globals/test_getters.py b/tests/parser/globals/test_getters.py index 59c91cbeef..5eac074ef6 100644 --- a/tests/parser/globals/test_getters.py +++ b/tests/parser/globals/test_getters.py @@ -35,6 +35,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants): c: public(constant(uint256)) = 1 d: public(immutable(uint256)) e: public(immutable(uint256[2])) +f: public(constant(uint256[2])) = [3, 7] @external def __init__(): @@ -68,6 +69,7 @@ def __init__(): assert c.c() == 1 assert c.d() == 1729 assert c.e(0) == 2 + assert [c.f(i) for i in range(2)] == [3, 7] def test_getter_mutability(get_contract): diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 753f2687cd..5471b971a4 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -49,7 +49,6 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: # the base return statement is an `Attribute` node, e.g. `self.` # for each input type we wrap it in a `Subscript` to access a specific member return_stmt = vy_ast.Attribute(value=vy_ast.Name(id="self"), attr=func_type.name) - return_stmt._metadata["type"] = node._metadata["type"] for i, type_ in enumerate(input_types): if not isinstance(annotation, vy_ast.Subscript): diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 7c907b4d08..2497928035 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -339,7 +339,7 @@ def __hash__(self): def __eq__(self, other): if not isinstance(other, type(self)): return False - if other.node_id != self.node_id: + if getattr(other, "node_id", None) != getattr(self, "node_id", None): return False for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__): if getattr(self, field_name, None) != getattr(other, field_name, None): diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4e1bd9e6c3..526d2f3253 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -263,7 +263,6 @@ def generate_folded_ast( vyper_module_folded = copy.deepcopy(vyper_module) vy_ast.folding.fold(vyper_module_folded) validate_semantics(vyper_module_folded, interface_codes) - vy_ast.expansion.expand_annotated_ast(vyper_module_folded) symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) return vyper_module_folded, symbol_tables diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 5977a87812..9e987d1cd0 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,3 +1,5 @@ +import vyper.ast as vy_ast + from .. import types # break a dependency cycle. from ..namespace import get_namespace from .local import validate_functions @@ -11,4 +13,5 @@ def validate_semantics(vyper_ast, interface_codes): with namespace.enter_scope(): add_module_namespace(vyper_ast, interface_codes) + vy_ast.expansion.expand_annotated_ast(vyper_ast) validate_functions(vyper_ast) From 3c3285c2f15a88c84574dcca1958a282d4910e5f Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 26 Jul 2023 23:25:47 +0900 Subject: [PATCH 055/201] docs: fix typo in release-notes.rst (#3538) unitialized -> uninitialized --- docs/release-notes.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index dcdbcda74a..5b6880dfdc 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -774,7 +774,7 @@ The following VIPs were implemented for Beta 13: - Add ``vyper-json`` compilation mode (VIP `#1520 `_) - Environment variables and constants can now be used as default parameters (VIP `#1525 `_) -- Require unitialized memory be set on creation (VIP `#1493 `_) +- Require uninitialized memory be set on creation (VIP `#1493 `_) Some of the bug and stability fixes: From 76f1cc5a8b288696446ac08d9099bf643d132c73 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 27 Jul 2023 10:48:21 -0700 Subject: [PATCH 056/201] chore: add error message for repeat range check (#3542) since d48438e and 3de1415, loops can revert depending on user input. add it to the error map so it's easier for users to debug. --- vyper/codegen/stmt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 86ea1813ea..9dc75b46ba 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -305,7 +305,9 @@ def _parse_For_range(self): # NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound. # if we ever want to remove that, we need to manually add the assertion # where it makes sense. - ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds_bound, loop_body]) + ir_node = IRnode.from_list( + ["repeat", i, start, rounds, rounds_bound, loop_body], error_msg="range() bounds check" + ) del self.context.forvars[varname] return ir_node From cfda16c734ecddc170079817cd96b14e4fe24586 Mon Sep 17 00:00:00 2001 From: Pascal Marco Caversaccio Date: Mon, 31 Jul 2023 16:24:11 +0200 Subject: [PATCH 057/201] Use `0.3.7` as example in `Installing Vyper` (#3543) --- docs/installing-vyper.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installing-vyper.rst b/docs/installing-vyper.rst index 2e2d51bd6e..249182a1c2 100644 --- a/docs/installing-vyper.rst +++ b/docs/installing-vyper.rst @@ -76,7 +76,7 @@ Each tagged version of vyper is uploaded to `pypi Date: Thu, 3 Aug 2023 09:19:35 +0800 Subject: [PATCH 058/201] docs: fix yanked version in release notes (#3545) --- docs/release-notes.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 5b6880dfdc..5dc33a49c6 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -403,6 +403,7 @@ Fixes: v0.2.14 ******* +**THIS RELEASE HAS BEEN PULLED** Date released: 20-07-2021 @@ -414,7 +415,6 @@ Fixes: v0.2.13 ******* -**THIS RELEASE HAS BEEN PULLED** Date released: 06-07-2021 From 855f7349668d0907f968f0f5f41b64730f4dd13f Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Thu, 3 Aug 2023 16:21:42 +0200 Subject: [PATCH 059/201] docs: update release notes / yanked versions (#3547) --- docs/release-notes.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 5dc33a49c6..9a6384697b 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -187,6 +187,7 @@ Bugfixes: v0.3.5 ****** +**THIS RELEASE HAS BEEN PULLED** Date released: 2022-08-05 @@ -415,6 +416,7 @@ Fixes: v0.2.13 ******* +**THIS RELEASE HAS BEEN PULLED** Date released: 06-07-2021 @@ -521,6 +523,7 @@ Fixes: v0.2.6 ****** +**THIS RELEASE HAS BEEN PULLED** Date released: 10-10-2020 From b87889974b9a600624e10ab4c46adfd2c1f930ff Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 5 Aug 2023 23:47:12 +0800 Subject: [PATCH 060/201] docs: epsilon builtin (#3552) --- docs/built-in-functions.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index 74e8560498..84859d66c2 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -573,6 +573,24 @@ Math >>> ExampleContract.foo(3.1337) 4 +.. py:function:: epsilon(typename) -> Any + + Returns the smallest non-zero value for a decimal type. + + * ``typename``: Name of the decimal type (currently only ``decimal``) + + .. code-block:: python + + @external + @view + def foo() -> decimal: + return epsilon(decimal) + + .. code-block:: python + + >>> ExampleContract.foo() + Decimal('1E-10') + .. py:function:: floor(value: decimal) -> int256 Round a decimal down to the nearest integer. From cc2a5cd696f9720683a19a9490119ee7297a4192 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Sun, 6 Aug 2023 17:08:19 +0200 Subject: [PATCH 061/201] docs: note on security advisory in release notes for versions `0.2.15`, `0.2.16`, and `0.3.0` (#3553) * Add note on security advisory in release notes for `0.2.15`, `0.2.16`, and `0.3.0` * Add link to `0.3.1` release --- docs/release-notes.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 9a6384697b..f408c5c0ab 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -336,6 +336,7 @@ Special thanks to @skellet0r for some major features in this release! v0.3.0 ******* +⚠️ A critical security vulnerability has been discovered in this version and we strongly recommend using version `0.3.1 `_ or higher. For more information, please see the Security Advisory `GHSA-5824-cm3x-3c38 `_. Date released: 2021-10-04 @@ -368,6 +369,7 @@ Special thanks to contributions from @skellet0r and @benjyz for this release! v0.2.16 ******* +⚠️ A critical security vulnerability has been discovered in this version and we strongly recommend using version `0.3.1 `_ or higher. For more information, please see the Security Advisory `GHSA-5824-cm3x-3c38 `_. Date released: 2021-08-27 @@ -392,6 +394,7 @@ Special thanks to contributions from @skellet0r, @sambacha and @milancermak for v0.2.15 ******* +⚠️ A critical security vulnerability has been discovered in this version and we strongly recommend using version `0.3.1 `_ or higher. For more information, please see the Security Advisory `GHSA-5824-cm3x-3c38 `_. Date released: 23-07-2021 From 728a27677240fdd55a4144d04b31004f8330847c Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Mon, 7 Aug 2023 06:14:30 +0200 Subject: [PATCH 062/201] docs: add security advisory note for `ecrecover` (#3539) Co-authored-by: Charles Cooper --- docs/built-in-functions.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index 84859d66c2..bfaa8fdd5e 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -379,7 +379,11 @@ Cryptography * ``s``: second 32 bytes of signature * ``v``: final 1 byte of signature - Returns the associated address, or ``0`` on error. + Returns the associated address, or ``empty(address)`` on error. + + .. note:: + + Prior to Vyper ``0.3.10``, the ``ecrecover`` function could return an undefined (possibly nonzero) value for invalid inputs to ``ecrecover``. For more information, please see `GHSA-f5x6-7qgp-jhf3 `_. .. code-block:: python From 43c8d8519a67a2b5da664d85ab6207a789707c1f Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 8 Aug 2023 00:35:35 +0800 Subject: [PATCH 063/201] fix: guard against kwargs for `range` expressions with two arguments (#3551) and slight refactor -- extract `node.iter` expr to `range_` for clarity --------- Co-authored-by: Charles Cooper --- tests/parser/syntax/test_for_range.py | 21 ++++++++++++++++++++- vyper/semantics/analysis/local.py | 14 +++++++++++--- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/parser/syntax/test_for_range.py b/tests/parser/syntax/test_for_range.py index b2a9491058..e6f35c1d2d 100644 --- a/tests/parser/syntax/test_for_range.py +++ b/tests/parser/syntax/test_for_range.py @@ -12,7 +12,26 @@ def foo(): pass """, StructureException, - ) + ), + ( + """ +@external +def bar(): + for i in range(1,2,bound=2): + pass + """, + StructureException, + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x,x+1,bound=2): + pass + """, + StructureException, + ), ] diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c0c05325f2..c10df3b8fd 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -346,10 +346,11 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - validate_call_args(node.iter, (1, 2), kwargs=["bound"]) + range_ = node.iter + validate_call_args(range_, (1, 2), kwargs=["bound"]) - args = node.iter.args - kwargs = {s.arg: s.value for s in node.iter.keywords or []} + args = range_.args + kwargs = {s.arg: s.value for s in range_.keywords or []} if len(args) == 1: # range(CONSTANT) n = args[0] @@ -371,6 +372,13 @@ def visit_For(self, node): type_list = get_common_types(n, bound) else: + if range_.keywords: + raise StructureException( + "Keyword arguments are not supported for `range(N, M)` and" + "`range(x, x + N)` expressions", + range_.keywords[0], + ) + validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args) if not isinstance(args[0], vy_ast.Constant): From f72ad784d9cbf85235ee61d29f6571e1dfc48229 Mon Sep 17 00:00:00 2001 From: mahdiRostami Date: Mon, 21 Aug 2023 18:18:10 +0330 Subject: [PATCH 064/201] docs: add `vyper --version` (#3558) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add “vyper --version” to installation instructions --- docs/installing-vyper.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/installing-vyper.rst b/docs/installing-vyper.rst index 249182a1c2..fb2849708d 100644 --- a/docs/installing-vyper.rst +++ b/docs/installing-vyper.rst @@ -78,7 +78,11 @@ To install a specific version use: pip install vyper==0.3.7 +You can check if Vyper is installed completely or not by typing the following in your terminal/cmd: +:: + + vyper --version nix *** From 158099b9c1a49b5472293c1fb7a4baf3cd015eb5 Mon Sep 17 00:00:00 2001 From: Shmuel Kroizer <69422117+shmuel44@users.noreply.github.com> Date: Sun, 27 Aug 2023 05:59:27 +0300 Subject: [PATCH 065/201] chore: update flake8 from gitlab to github (#3566) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 739e977c96..4b416a4414 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: black name: black -- repo: https://gitlab.com/pycqa/flake8 +- repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 From c28f14f757e17a132cff0236ee0cadb61513aa90 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 31 Aug 2023 12:59:02 -0400 Subject: [PATCH 066/201] chore: fix loop variable typing (#3571) there is an inconsistency between codegen and typechecking types when a loop iterates over a literal list. in this code, although it compiles, during typechecking, `i` is typed as a `uint8` while `[1,2,3]` is annotated with `int8[3]` ``` @external def foo(): for i in [1,2,3]: a: uint8 = i ``` since the iterator type is always correct, this commit is a chore since it fixes the discrepancy, but there is no known way to "abuse" the behavior to get a wrong codegen type. chainsec june 2023 review 5.15 --- vyper/codegen/stmt.py | 4 +--- vyper/semantics/analysis/annotation.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 9dc75b46ba..3ecb0afdc3 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -316,10 +316,8 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - # override with type inferred at typechecking time - # TODO investigate why stmt.target.type != stmt.iter.type.value_type target_type = self.stmt.target._metadata["type"] - iter_list.typ.value_type = target_type + assert target_type == iter_list.typ.value_type # user-supplied name for loop variable varname = self.stmt.target.id diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index d309f102cd..01ca51d7f4 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -85,14 +85,14 @@ def visit_Return(self, node): def visit_For(self, node): if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): self.expr_visitor.visit(node.iter) - # typecheck list literal as static array + + iter_type = node.target._metadata["type"] if isinstance(node.iter, vy_ast.List): - value_type = get_common_types(*node.iter.elements).pop() + # typecheck list literal as static array len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(value_type, len_)) + self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - iter_type = node.target._metadata["type"] for a in node.iter.args: self.expr_visitor.visit(a, iter_type) for a in node.iter.keywords: From 6ea56a6eb40a7225f42765d1bedc386bd2c6166d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 31 Aug 2023 13:41:17 -0400 Subject: [PATCH 067/201] chore: fix dead parameter usages (#3575) `_is_function_implemented` did not use its parameter `fn_name`, it used the captured `name` variable, which happened to be the same as `fn_name`. chainsec june 2023 review 6.2 --- vyper/codegen/expr.py | 6 ++---- vyper/semantics/types/user.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index d637a454bc..fa3b8bb498 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -662,9 +662,7 @@ def parse_Call(self): elif isinstance(self.expr._metadata["type"], StructT): args = self.expr.args if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return Expr.struct_literals( - args[0], function_name, self.context, self.expr._metadata["type"] - ) + return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) # Interface assignment. Bar(
). elif isinstance(self.expr._metadata["type"], InterfaceT): @@ -733,7 +731,7 @@ def parse_IfExp(self): return IRnode.from_list(["if", test, body, orelse], typ=typ, location=location) @staticmethod - def struct_literals(expr, name, context, typ): + def struct_literals(expr, context, typ): member_subs = {} member_typs = {} for key, value in zip(expr.keys, expr.values): diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a603691410..a71f852dbf 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -313,11 +313,11 @@ def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: def _is_function_implemented(fn_name, fn_type): vyper_self = namespace["self"].typ - if name not in vyper_self.members: + if fn_name not in vyper_self.members: return False - s = vyper_self.members[name] + s = vyper_self.members[fn_name] if isinstance(s, ContractFunctionT): - to_compare = vyper_self.members[name] + to_compare = vyper_self.members[fn_name] # this is kludgy, rework order of passes in ModuleNodeVisitor elif isinstance(s, VarInfo) and s.is_public: to_compare = s.decl_node._metadata["func_type"] From 17e730a044c24d4ec99fa766331eccf7a2c1effa Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 31 Aug 2023 14:01:01 -0400 Subject: [PATCH 068/201] chore: add sanity check in parse_BinOp (#3567) add sanity check in parse_BinOp, we can be stricter in the case where it's a shift binop. chainsec june 2023 review 5.4 --- vyper/codegen/expr.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index fa3b8bb498..dc0e98786f 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -368,13 +368,17 @@ def parse_BinOp(self): left = Expr.parse_value_expr(self.expr.left, self.context) right = Expr.parse_value_expr(self.expr.right, self.context) - if not isinstance(self.expr.op, (vy_ast.LShift, vy_ast.RShift)): + is_shift_op = isinstance(self.expr.op, (vy_ast.LShift, vy_ast.RShift)) + + if is_shift_op: + assert is_numeric_type(left.typ) + assert is_numeric_type(right.typ) + else: # Sanity check - ensure that we aren't dealing with different types # This should be unreachable due to the type check pass if left.typ != right.typ: raise TypeCheckFailure(f"unreachable, {left.typ} != {right.typ}", self.expr) - - assert is_numeric_type(left.typ) or is_enum_type(left.typ) + assert is_numeric_type(left.typ) or is_enum_type(left.typ) out_typ = left.typ From 6a819b1db8b3812b8e814de680b589aae6ddd203 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 31 Aug 2023 16:20:21 -0400 Subject: [PATCH 069/201] chore: fix args passed to `validate_call_args` (#3568) `validate_call_args` takes kwargs, the list of valid keywords as an argument and makes sure that when a call is made, the given keywords are valid according to the passed kwargs. however, vyper does not allow kwargs when calling internal functions, so we should actually pass no kwargs to `validate_call_args`. note that this PR does not actually introduce observed changes in compiler behavior, as the later check in `fetch_call_return` correctly validates there are no call site kwargs for internal functions. chainsec june review 5.7 --- vyper/semantics/types/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 10711edc8e..506dae135c 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -488,8 +488,8 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL: raise CallViolation("Cannot call external functions via 'self'", node) + kwarg_keys = [] # for external calls, include gas and value as optional kwargs - kwarg_keys = [arg.name for arg in self.keyword_args] if not self.is_internal: kwarg_keys += list(self.call_site_kwargs.keys()) validate_call_args(node, (self.n_positional_args, self.n_total_args), kwarg_keys) From fa89ca2f6d09a42c0349a8b22eeba281039c85a1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 31 Aug 2023 16:25:39 -0400 Subject: [PATCH 070/201] chore: note `Context.in_assertion` is dead (#3564) the `Context` class has an `in_assertion` flag which, when set, indicates that the context should be constant according to the definition of `is_constant()`. however, this flag is never set during code generation, specifically, it is possible to have a non-constant expression in an assert statement. for example, the following contract compiles: ```vyper x: uint256 @internal def bar() -> uint256: self.x = 1 return self.x @external def foo(): assert self.bar() == 1 ``` chainsec june 2023 review 5.5 --- vyper/codegen/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index e4b41adbc0..5b79f293bd 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -69,6 +69,7 @@ def __init__( self.constancy = constancy # Whether body is currently in an assert statement + # XXX: dead, never set to True self.in_assertion = False # Whether we are currently parsing a range expression From ef1c589f1e3488b26de9edd078a7340cac1298a4 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 31 Aug 2023 16:26:13 -0400 Subject: [PATCH 071/201] refactor: initcode generation (#3574) move internal function generation to after ctor generation. prior to this commit, the existing code relies on the fact that the code generation of runtime internal functions properly sets the frame information of the ctor's callees. if this precondition is not met in the future, the compiler could panic because the memory allocation info will not be available. chainsec june 2023 review 6.2 --- vyper/codegen/module.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index ebe7f92cf2..8caea9ee9b 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -442,6 +442,20 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: deploy_code: List[Any] = ["seq"] immutables_len = global_ctx.immutable_section_bytes if init_function: + # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` + ctor_internal_func_irs = [] + internal_functions = [f for f in runtime_functions if _is_internal(f)] + for f in internal_functions: + init_func_t = init_function._metadata["type"] + if f.name not in init_func_t.recursive_calls: + # unreachable code, delete it + continue + + func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) + ctor_internal_func_irs.append(func_ir) + + # generate init_func_ir after callees to ensure they have analyzed + # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) @@ -468,19 +482,9 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: deploy_code.append(["iload", max(0, immutables_len - 32)]) deploy_code.append(init_func_ir) - deploy_code.append(["deploy", init_mem_used, runtime, immutables_len]) - - # internal functions come after everything else - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - init_func_t = init_function._metadata["type"] - if f.name not in init_func_t.recursive_calls: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) - deploy_code.append(func_ir) + # internal functions come at end of initcode + deploy_code.extend(ctor_internal_func_irs) else: if immutables_len != 0: From a19cdeaf84e4c70aa6517a1535fbe442cd6059f3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 1 Sep 2023 13:25:12 -0400 Subject: [PATCH 072/201] feat: relax restrictions on internal function signatures (#3573) relax the restriction on unique "method ids" for internal methods. the check used to be there to avoid collisions between external method ids and internal "method ids" because the calling convention for internal functions used to involve the method id as part of the signature, but that is no longer the case. so we can safely allow collision between internal "method ids" and external method ids. cf. issue #1687 which was resolved in in 9e8c661494d84fbf. chainsec june 2023 review 5.22 --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> Co-authored-by: trocher --- .../syntax/utils/test_function_names.py | 40 +++++++++++++++++++ tests/signatures/test_method_id_conflicts.py | 20 ---------- vyper/semantics/analysis/module.py | 11 +---- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/tests/parser/syntax/utils/test_function_names.py b/tests/parser/syntax/utils/test_function_names.py index 90e185558c..5489a4f6a0 100644 --- a/tests/parser/syntax/utils/test_function_names.py +++ b/tests/parser/syntax/utils/test_function_names.py @@ -23,6 +23,22 @@ def wei(i: int128) -> int128: temp_var : int128 = i return temp_var1 """, + # collision between getter and external function + """ +foo: public(uint256) + +@external +def foo(): + pass + """, + # collision between getter and external function, reverse order + """ +@external +def foo(): + pass + +foo: public(uint256) + """, ] @@ -77,6 +93,30 @@ def append(): def foo(): self.append() """, + # "method id" collisions between internal functions are allowed + """ +@internal +@view +def gfah(): + pass + +@internal +@view +def eexo(): + pass + """, + # "method id" collisions between internal+external functions are allowed + """ +@internal +@view +def gfah(): + pass + +@external +@view +def eexo(): + pass + """, ] diff --git a/tests/signatures/test_method_id_conflicts.py b/tests/signatures/test_method_id_conflicts.py index 35c10300b4..f3312efeab 100644 --- a/tests/signatures/test_method_id_conflicts.py +++ b/tests/signatures/test_method_id_conflicts.py @@ -48,26 +48,6 @@ def OwnerTransferV7b711143(a: uint256): pass """, """ -# check collision between private method IDs -@internal -@view -def gfah(): pass - -@internal -@view -def eexo(): pass - """, - """ -# check collision between private and public IDs -@internal -@view -def gfah(): pass - -@external -@view -def eexo(): pass - """, - """ # check collision with ID = 0x00000000 wycpnbqcyf:public(uint256) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index d916dcf119..02ae82faac 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -22,11 +22,7 @@ from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import ( - check_constant, - validate_expected_type, - validate_unique_method_ids, -) +from vyper.semantics.analysis.utils import check_constant, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT @@ -90,6 +86,7 @@ def __init__( err_list.raise_if_not_empty() # generate an `InterfaceT` from the top-level node - used for building the ABI + # note: also validates unique method ids interface = InterfaceT.from_ast(module_node) module_node._metadata["type"] = interface self.interface = interface # this is useful downstream @@ -102,11 +99,7 @@ def __init__( module_node._metadata["namespace"] = _ns # check for collisions between 4byte function selectors - # internal functions are intentionally included in this check, to prevent breaking - # changes in in case of a future change to their calling convention self_members = namespace["self"].typ.members - functions = [i for i in self_members.values() if isinstance(i, ContractFunctionT)] - validate_unique_method_ids(functions) # get list of internal function calls made by each function function_defs = self.ast.get_children(vy_ast.FunctionDef) From 572b38c839c65ef032aa58f656194205bf4ecce7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 1 Sep 2023 13:32:45 -0400 Subject: [PATCH 073/201] fix: raw_call type when max_outsize=0 is set (#3572) prior to this commit, when `raw_call` is used with `max_outsize` explicitly set to 0 (`max_outsize=0`) the compiler incorrectly infers that raw_call has no return type ```vyper @external @payable def foo(_target: address): # compiles a: bool = raw_call(_target, method_id("foo()"), revert_on_failure=False) # should have same behavior, but prior to this commit does not compile: b: bool = raw_call(_target, method_id("foo()"), max_outsize=0, revert_on_failure=False) ``` chainsec june 2023 review 5.16 --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> --- tests/parser/functions/test_raw_call.py | 63 +++++++++++++++++++++++++ vyper/builtins/functions.py | 2 +- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/parser/functions/test_raw_call.py b/tests/parser/functions/test_raw_call.py index 95db070ffa..9c6fba79e7 100644 --- a/tests/parser/functions/test_raw_call.py +++ b/tests/parser/functions/test_raw_call.py @@ -1,6 +1,7 @@ import pytest from hexbytes import HexBytes +from vyper import compile_code from vyper.builtins.functions import eip1167_bytecode from vyper.exceptions import ArgumentException, InvalidType, StateAccessViolation @@ -260,6 +261,68 @@ def __default__(): w3.eth.send_transaction({"to": caller.address, "data": sig}) +# check max_outsize=0 does same thing as not setting max_outsize. +# compile to bytecode and compare bytecode directly. +def test_max_outsize_0(): + code1 = """ +@external +def test_raw_call(_target: address): + raw_call(_target, method_id("foo()")) + """ + code2 = """ +@external +def test_raw_call(_target: address): + raw_call(_target, method_id("foo()"), max_outsize=0) + """ + output1 = compile_code(code1, ["bytecode", "bytecode_runtime"]) + output2 = compile_code(code2, ["bytecode", "bytecode_runtime"]) + assert output1 == output2 + + +# check max_outsize=0 does same thing as not setting max_outsize, +# this time with revert_on_failure set to False +def test_max_outsize_0_no_revert_on_failure(): + code1 = """ +@external +def test_raw_call(_target: address) -> bool: + # compile raw_call both ways, with revert_on_failure + a: bool = raw_call(_target, method_id("foo()"), revert_on_failure=False) + return a + """ + # same code, but with max_outsize=0 + code2 = """ +@external +def test_raw_call(_target: address) -> bool: + a: bool = raw_call(_target, method_id("foo()"), max_outsize=0, revert_on_failure=False) + return a + """ + output1 = compile_code(code1, ["bytecode", "bytecode_runtime"]) + output2 = compile_code(code2, ["bytecode", "bytecode_runtime"]) + assert output1 == output2 + + +# test functionality of max_outsize=0 +def test_max_outsize_0_call(get_contract): + target_source = """ +@external +@payable +def bar() -> uint256: + return 123 + """ + + caller_source = """ +@external +@payable +def foo(_addr: address) -> bool: + success: bool = raw_call(_addr, method_id("bar()"), max_outsize=0, revert_on_failure=False) + return success + """ + + target = get_contract(target_source) + caller = get_contract(caller_source) + assert caller.foo(target.address) is True + + def test_static_call_fails_nonpayable(get_contract, assert_tx_failed): target_source = """ baz: int128 diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 685d832c01..e8e001306c 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1093,7 +1093,7 @@ def fetch_call_return(self, node): revert_on_failure = kwargz.get("revert_on_failure") revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True - if outsize is None: + if outsize is None or outsize.value == 0: if revert_on_failure: return None return BoolT() From 2c21eab442f6feeac6bc92b95347f2d3968b09a6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 1 Sep 2023 17:05:44 -0400 Subject: [PATCH 074/201] fix: implements check for indexed event arguments (#3570) prior to this commit, implementing an interface with wrong indexed arguments for an event would pass the implements check. this commit fixes the behavior. chainsec june 2023 review 5.12 --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> --- tests/parser/functions/test_interfaces.py | 60 +++++++++++++++++- tests/parser/syntax/test_interfaces.py | 76 +++++++++++++++++++++++ vyper/builtins/interfaces/ERC721.py | 16 ++--- vyper/semantics/types/user.py | 23 +++++-- 4 files changed, 160 insertions(+), 15 deletions(-) diff --git a/tests/parser/functions/test_interfaces.py b/tests/parser/functions/test_interfaces.py index e43c080d46..c16e188cfd 100644 --- a/tests/parser/functions/test_interfaces.py +++ b/tests/parser/functions/test_interfaces.py @@ -67,7 +67,6 @@ def test_basic_interface_implements(assert_compile_failed): implements: ERC20 - @external def test() -> bool: return True @@ -146,6 +145,7 @@ def bar() -> uint256: ) +# check that event types match def test_malformed_event(assert_compile_failed): interface_code = """ event Foo: @@ -173,6 +173,64 @@ def bar() -> uint256: ) +# check that event non-indexed arg needs to match interface +def test_malformed_events_indexed(assert_compile_failed): + interface_code = """ +event Foo: + a: uint256 + """ + + interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + + not_implemented_code = """ +import a as FooBarInterface + +implements: FooBarInterface + +# a should not be indexed +event Foo: + a: indexed(uint256) + +@external +def bar() -> uint256: + return 1 + """ + + assert_compile_failed( + lambda: compile_code(not_implemented_code, interface_codes=interface_codes), + InterfaceViolation, + ) + + +# check that event indexed arg needs to match interface +def test_malformed_events_indexed2(assert_compile_failed): + interface_code = """ +event Foo: + a: indexed(uint256) + """ + + interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + + not_implemented_code = """ +import a as FooBarInterface + +implements: FooBarInterface + +# a should be indexed +event Foo: + a: uint256 + +@external +def bar() -> uint256: + return 1 + """ + + assert_compile_failed( + lambda: compile_code(not_implemented_code, interface_codes=interface_codes), + InterfaceViolation, + ) + + VALID_IMPORT_CODE = [ # import statement, import path without suffix ("import a as Foo", "a"), diff --git a/tests/parser/syntax/test_interfaces.py b/tests/parser/syntax/test_interfaces.py index acadaff20d..5afb34e6bd 100644 --- a/tests/parser/syntax/test_interfaces.py +++ b/tests/parser/syntax/test_interfaces.py @@ -134,6 +134,82 @@ def f(a: uint256): # visibility is nonpayable instead of view """, InterfaceViolation, ), + ( + # `receiver` of `Transfer` event should be indexed + """ +from vyper.interfaces import ERC20 + +implements: ERC20 + +event Transfer: + sender: indexed(address) + receiver: address + value: uint256 + +event Approval: + owner: indexed(address) + spender: indexed(address) + value: uint256 + +name: public(String[32]) +symbol: public(String[32]) +decimals: public(uint8) +balanceOf: public(HashMap[address, uint256]) +allowance: public(HashMap[address, HashMap[address, uint256]]) +totalSupply: public(uint256) + +@external +def transfer(_to : address, _value : uint256) -> bool: + return True + +@external +def transferFrom(_from : address, _to : address, _value : uint256) -> bool: + return True + +@external +def approve(_spender : address, _value : uint256) -> bool: + return True + """, + InterfaceViolation, + ), + ( + # `value` of `Transfer` event should not be indexed + """ +from vyper.interfaces import ERC20 + +implements: ERC20 + +event Transfer: + sender: indexed(address) + receiver: indexed(address) + value: indexed(uint256) + +event Approval: + owner: indexed(address) + spender: indexed(address) + value: uint256 + +name: public(String[32]) +symbol: public(String[32]) +decimals: public(uint8) +balanceOf: public(HashMap[address, uint256]) +allowance: public(HashMap[address, HashMap[address, uint256]]) +totalSupply: public(uint256) + +@external +def transfer(_to : address, _value : uint256) -> bool: + return True + +@external +def transferFrom(_from : address, _to : address, _value : uint256) -> bool: + return True + +@external +def approve(_spender : address, _value : uint256) -> bool: + return True + """, + InterfaceViolation, + ), ] diff --git a/vyper/builtins/interfaces/ERC721.py b/vyper/builtins/interfaces/ERC721.py index 29ef5f4c26..8dea4e4976 100644 --- a/vyper/builtins/interfaces/ERC721.py +++ b/vyper/builtins/interfaces/ERC721.py @@ -2,18 +2,18 @@ # Events event Transfer: - _from: address - _to: address - _tokenId: uint256 + _from: indexed(address) + _to: indexed(address) + _tokenId: indexed(uint256) event Approval: - _owner: address - _approved: address - _tokenId: uint256 + _owner: indexed(address) + _approved: indexed(address) + _tokenId: indexed(uint256) event ApprovalForAll: - _owner: address - _operator: address + _owner: indexed(address) + _operator: indexed(address) _approved: bool # Functions diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a71f852dbf..ce82731c34 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -164,6 +164,7 @@ def __init__(self, name: str, arguments: dict, indexed: list) -> None: super().__init__(members=arguments) self.name = name self.indexed = indexed + assert len(self.indexed) == len(self.arguments) self.event_id = int(keccak256(self.signature.encode()).hex(), 16) # backward compatible @@ -172,8 +173,13 @@ def arguments(self): return self.members def __repr__(self): - arg_types = ",".join(repr(a) for a in self.arguments.values()) - return f"event {self.name}({arg_types})" + args = [] + for is_indexed, (_, argtype) in zip(self.indexed, self.arguments.items()): + argtype_str = repr(argtype) + if is_indexed: + argtype_str = f"indexed({argtype_str})" + args.append(f"{argtype_str}") + return f"event {self.name}({','.join(args)})" # TODO rename to abi_signature @property @@ -337,12 +343,17 @@ def _is_function_implemented(fn_name, fn_type): # check for missing events for name, event in self.events.items(): + if name not in namespace: + unimplemented.append(name) + continue + + if not isinstance(namespace[name], EventT): + unimplemented.append(f"{name} is not an event!") if ( - name not in namespace - or not isinstance(namespace[name], EventT) - or namespace[name].event_id != event.event_id + namespace[name].event_id != event.event_id + or namespace[name].indexed != event.indexed ): - unimplemented.append(name) + unimplemented.append(f"{name} is not implemented! (should be {event})") if len(unimplemented) > 0: # TODO: improve the error message for cases where the From 78fa8dd8f91ba0cb26277eeffb585c68c83e7daa Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 5 Sep 2023 08:26:37 -0400 Subject: [PATCH 075/201] fix: order of evaluation for some builtins (#3583) ecadd, ecmul, addmod, mulmod in the case that the arguments have side effects, they could be evaluated out of order chainsec june 2023 review 5.1 --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> Co-authored-by: trocher <43437004+trocher@users.noreply.github.com> --- tests/parser/functions/test_addmod.py | 32 ++++++++++ tests/parser/functions/test_ec.py | 40 ++++++++++++ tests/parser/functions/test_mulmod.py | 32 ++++++++++ vyper/builtins/functions.py | 92 ++++++++++++--------------- 4 files changed, 143 insertions(+), 53 deletions(-) diff --git a/tests/parser/functions/test_addmod.py b/tests/parser/functions/test_addmod.py index 67a7e9b101..b3135660bb 100644 --- a/tests/parser/functions/test_addmod.py +++ b/tests/parser/functions/test_addmod.py @@ -55,3 +55,35 @@ def c() -> uint256: c = get_contract_with_gas_estimation(code) assert c.foo() == 2 + + +def test_uint256_addmod_evaluation_order(get_contract_with_gas_estimation): + code = """ +a: uint256 + +@external +def foo1() -> uint256: + self.a = 0 + return uint256_addmod(self.a, 1, self.bar()) + +@external +def foo2() -> uint256: + self.a = 0 + return uint256_addmod(self.a, self.bar(), 3) + +@external +def foo3() -> uint256: + self.a = 0 + return uint256_addmod(1, self.a, self.bar()) + +@internal +def bar() -> uint256: + self.a = 1 + return 2 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo1() == 1 + assert c.foo2() == 2 + assert c.foo3() == 1 diff --git a/tests/parser/functions/test_ec.py b/tests/parser/functions/test_ec.py index 9ce37d0721..e1d9e3d2ee 100644 --- a/tests/parser/functions/test_ec.py +++ b/tests/parser/functions/test_ec.py @@ -76,6 +76,26 @@ def foo(a: Foo) -> uint256[2]: assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) +def test_ecadd_evaluation_order(get_contract_with_gas_estimation): + code = """ +x: uint256[2] + +@internal +def bar() -> uint256[2]: + self.x = ecadd([1, 2], [1, 2]) + return [1, 2] + +@external +def foo() -> bool: + self.x = [1, 2] + a: uint256[2] = ecadd([1, 2], [1, 2]) + b: uint256[2] = ecadd(self.x, self.bar()) + return a[0] == b[0] and a[1] == b[1] + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() is True + + def test_ecmul(get_contract_with_gas_estimation): ecmuller = """ x3: uint256[2] @@ -136,3 +156,23 @@ def foo(a: Foo) -> uint256[2]: assert c2.foo(c1.address) == G1_times_three assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_ecmul_evaluation_order(get_contract_with_gas_estimation): + code = """ +x: uint256[2] + +@internal +def bar() -> uint256: + self.x = ecmul([1, 2], 3) + return 3 + +@external +def foo() -> bool: + self.x = [1, 2] + a: uint256[2] = ecmul([1, 2], 3) + b: uint256[2] = ecmul(self.x, self.bar()) + return a[0] == b[0] and a[1] == b[1] + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() is True diff --git a/tests/parser/functions/test_mulmod.py b/tests/parser/functions/test_mulmod.py index 1ea7a3f8e8..96477897b9 100644 --- a/tests/parser/functions/test_mulmod.py +++ b/tests/parser/functions/test_mulmod.py @@ -73,3 +73,35 @@ def c() -> uint256: c = get_contract_with_gas_estimation(code) assert c.foo() == 600 + + +def test_uint256_mulmod_evaluation_order(get_contract_with_gas_estimation): + code = """ +a: uint256 + +@external +def foo1() -> uint256: + self.a = 1 + return uint256_mulmod(self.a, 2, self.bar()) + +@external +def foo2() -> uint256: + self.a = 1 + return uint256_mulmod(self.bar(), self.a, 2) + +@external +def foo3() -> uint256: + self.a = 1 + return uint256_mulmod(2, self.a, self.bar()) + +@internal +def bar() -> uint256: + self.a = 7 + return 5 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo1() == 2 + assert c.foo2() == 1 + assert c.foo3() == 2 diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index e8e001306c..053ee512dc 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -25,9 +25,9 @@ eval_once_check, eval_seq, get_bytearray_length, - get_element_ptr, get_type_for_exact_size, ir_tuple_from_args, + make_setter, needs_external_call_wrap, promote_signed_int, sar, @@ -782,10 +782,6 @@ def build_IR(self, expr, args, kwargs, context): ) -def _getelem(arg, ind): - return unwrap_location(get_element_ptr(arg, IRnode.from_list(ind, typ=INT128_T))) - - class ECAdd(BuiltinFunction): _id = "ecadd" _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))] @@ -793,28 +789,22 @@ class ECAdd(BuiltinFunction): @process_inputs def build_IR(self, expr, args, kwargs, context): - placeholder_node = IRnode.from_list( - context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY - ) + buf_t = get_type_for_exact_size(128) - with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(a, 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], - ["mstore", ["add", placeholder_node, 64], _getelem(b, 0)], - ["mstore", ["add", placeholder_node, 96], _getelem(b, 1)], - [ - "assert", - ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64], - ], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return b2.resolve(b1.resolve(o)) + buf = context.new_internal_variable(buf_t) + + ret = ["seq"] + + dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY) + ret.append(make_setter(dst0, args[0])) + + dst1 = IRnode.from_list(buf + 64, typ=SArrayT(UINT256_T, 2), location=MEMORY) + ret.append(make_setter(dst1, args[1])) + + ret.append(["assert", ["staticcall", ["gas"], 6, buf, 128, buf, 64]]) + ret.append(buf) + + return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY) class ECMul(BuiltinFunction): @@ -824,27 +814,22 @@ class ECMul(BuiltinFunction): @process_inputs def build_IR(self, expr, args, kwargs, context): - placeholder_node = IRnode.from_list( - context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY - ) + buf_t = get_type_for_exact_size(96) - with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(a, 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], - ["mstore", ["add", placeholder_node, 64], b], - [ - "assert", - ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64], - ], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return b2.resolve(b1.resolve(o)) + buf = context.new_internal_variable(buf_t) + + ret = ["seq"] + + dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY) + ret.append(make_setter(dst0, args[0])) + + dst1 = IRnode.from_list(buf + 64, typ=UINT256_T, location=MEMORY) + ret.append(make_setter(dst1, args[1])) + + ret.append(["assert", ["staticcall", ["gas"], 7, buf, 96, buf, 64]]) + ret.append(buf) + + return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY) def _generic_element_getter(op): @@ -1525,13 +1510,14 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - c = args[2] - - with c.cache_when_complex("c") as (b1, c): - ret = IRnode.from_list( - ["seq", ["assert", c], [self._opcode, args[0], args[1], c]], typ=UINT256_T - ) - return b1.resolve(ret) + x, y, z = args + with x.cache_when_complex("x") as (b1, x): + with y.cache_when_complex("y") as (b2, y): + with z.cache_when_complex("z") as (b3, z): + ret = IRnode.from_list( + ["seq", ["assert", z], [self._opcode, x, y, z]], typ=UINT256_T + ) + return b1.resolve(b2.resolve(b3.resolve(ret))) class AddMod(_AddMulMod): From 8700e6da32189b300896567f30c185338d93fd8f Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Tue, 5 Sep 2023 19:12:34 +0100 Subject: [PATCH 076/201] chore: add `asm` option to cli help (#3585) * Add `asm` option to CLI help * Add missing `enum` in function docstring --- vyper/ast/pre_parser.py | 2 +- vyper/cli/vyper_compile.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 7e677b3b92..788c44ef19 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -72,7 +72,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Re-formats a vyper source string into a python source string and performs some validation. More specifically, - * Translates "interface", "struct" and "event" keywords into python "class" keyword + * Translates "interface", "struct", "enum, and "event" keywords into python "class" keyword * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 9c96d55040..9c97f8c667 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -41,6 +41,7 @@ opcodes_runtime - List of runtime opcodes as a string ir - Intermediate representation in list format ir_json - Intermediate representation in JSON format +asm - Output the EVM assembly of the deployable bytecode hex-ir - Output IR and assembly constants in hex instead of decimal """ From 3900ec0d2970aa8b4fbe64d4d698d4721ad09f21 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 5 Sep 2023 17:44:05 -0400 Subject: [PATCH 077/201] refactor: `ecadd()` and `ecmul()` codegen (#3587) refactor `ecadd()` and `ecmul()` to use `make_setter` and share code, so we don't need to do pointer arithmetic --- vyper/builtins/functions.py | 66 +++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 053ee512dc..3933ab2de5 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -782,54 +782,50 @@ def build_IR(self, expr, args, kwargs, context): ) -class ECAdd(BuiltinFunction): - _id = "ecadd" - _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))] - _return_type = SArrayT(UINT256_T, 2) - +class _ECArith(BuiltinFunction): @process_inputs - def build_IR(self, expr, args, kwargs, context): - buf_t = get_type_for_exact_size(128) + def build_IR(self, expr, _args, kwargs, context): + args_tuple = ir_tuple_from_args(_args) - buf = context.new_internal_variable(buf_t) + args_t = args_tuple.typ + input_buf = IRnode.from_list( + context.new_internal_variable(args_t), typ=args_t, location=MEMORY + ) + ret_t = self._return_type ret = ["seq"] + ret.append(make_setter(input_buf, args_tuple)) - dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY) - ret.append(make_setter(dst0, args[0])) + output_buf = context.new_internal_variable(ret_t) - dst1 = IRnode.from_list(buf + 64, typ=SArrayT(UINT256_T, 2), location=MEMORY) - ret.append(make_setter(dst1, args[1])) + args_ofst = input_buf + args_len = args_t.memory_bytes_required + out_ofst = output_buf + out_len = ret_t.memory_bytes_required - ret.append(["assert", ["staticcall", ["gas"], 6, buf, 128, buf, 64]]) - ret.append(buf) + ret.append( + [ + "assert", + ["staticcall", ["gas"], self._precompile, args_ofst, args_len, out_ofst, out_len], + ] + ) + ret.append(output_buf) - return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY) + return IRnode.from_list(ret, typ=ret_t, location=MEMORY) -class ECMul(BuiltinFunction): - _id = "ecmul" - _inputs = [("point", SArrayT(UINT256_T, 2)), ("scalar", UINT256_T)] +class ECAdd(_ECArith): + _id = "ecadd" + _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))] _return_type = SArrayT(UINT256_T, 2) + _precompile = 0x6 - @process_inputs - def build_IR(self, expr, args, kwargs, context): - buf_t = get_type_for_exact_size(96) - - buf = context.new_internal_variable(buf_t) - - ret = ["seq"] - dst0 = IRnode.from_list(buf, typ=SArrayT(UINT256_T, 2), location=MEMORY) - ret.append(make_setter(dst0, args[0])) - - dst1 = IRnode.from_list(buf + 64, typ=UINT256_T, location=MEMORY) - ret.append(make_setter(dst1, args[1])) - - ret.append(["assert", ["staticcall", ["gas"], 7, buf, 96, buf, 64]]) - ret.append(buf) - - return IRnode.from_list(ret, typ=SArrayT(UINT256_T, 2), location=MEMORY) +class ECMul(_ECArith): + _id = "ecmul" + _inputs = [("point", SArrayT(UINT256_T, 2)), ("scalar", UINT256_T)] + _return_type = SArrayT(UINT256_T, 2) + _precompile = 0x7 def _generic_element_getter(op): From a854929b602f8e40728bdc028c2485ec6da4a3ef Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 5 Sep 2023 18:21:32 -0400 Subject: [PATCH 078/201] fix: `ecrecover()` buffer edge case (#3586) this commit fixes an edge case in `ecrecover()` that was not covered by 019a37ab98ff5. in the case that one of the arguments to ecrecover writes to memory location 0, and the signature is invalid, `ecrecover()` could return the data written by the argument. this commit fixes the issue by allocating fresh memory for the output buffer (which won't be written to by evaluating any of the arguments unless the memory allocator is broken). --- tests/parser/functions/test_ecrecover.py | 27 ++++++++++++++++++++++++ vyper/builtins/functions.py | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/parser/functions/test_ecrecover.py b/tests/parser/functions/test_ecrecover.py index 40c9a6a936..8571948c3d 100644 --- a/tests/parser/functions/test_ecrecover.py +++ b/tests/parser/functions/test_ecrecover.py @@ -58,3 +58,30 @@ def test_ecrecover(hash: bytes32, v: uint8, r: uint256) -> address: r = 0 # note web3.py decoding of 0x000..00 address is None. assert c.test_ecrecover(hash_, v, r) is None + + +# slightly more subtle example: get_v() stomps memory location 0, +# so this tests that the output buffer stays clean during ecrecover() +# builtin execution. +def test_invalid_signature2(get_contract): + code = """ + +owner: immutable(address) + +@external +def __init__(): + owner = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf + +@internal +def get_v() -> uint256: + assert owner == owner # force a dload to write at index 0 of memory + return 21 + +@payable +@external +def test_ecrecover() -> bool: + assert ecrecover(empty(bytes32), self.get_v(), 0, 0) == empty(address) + return True + """ + c = get_contract(code) + assert c.test_ecrecover() is True diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 3933ab2de5..3ec8f69934 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -765,7 +765,7 @@ def infer_arg_types(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): input_buf = context.new_internal_variable(get_type_for_exact_size(128)) - output_buf = MemoryPositions.FREE_VAR_SPACE + output_buf = context.new_internal_variable(get_type_for_exact_size(32)) return IRnode.from_list( [ "seq", From 39a23137cd1babfb24222e3a9e785e047bba0c6d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 5 Sep 2023 18:44:22 -0400 Subject: [PATCH 079/201] fix: metadata journal can rollback incorrectly (#3569) this commit fixes an issue where multiple writes inside of a checkpoint lead to journal corruption on rollback. it ensures a call to `register_update()` when the metadata dict has already been updated inside of a given checkpoint. note this does not change any observed functionality in the compiler because writes to the metadata journal inside for loops only ever happen to be written once, but it prevents a bug in case we ever add multiple writes inside of the same checkpoint. chainsec june review 5.3 --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> --- tests/ast/test_metadata_journal.py | 82 ++++++++++++++++++++++++++++++ vyper/ast/metadata.py | 5 +- 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/ast/test_metadata_journal.py diff --git a/tests/ast/test_metadata_journal.py b/tests/ast/test_metadata_journal.py new file mode 100644 index 0000000000..34830409fc --- /dev/null +++ b/tests/ast/test_metadata_journal.py @@ -0,0 +1,82 @@ +from vyper.ast.metadata import NodeMetadata +from vyper.exceptions import VyperException + + +def test_metadata_journal_basic(): + m = NodeMetadata() + + m["x"] = 1 + assert m["x"] == 1 + + +def test_metadata_journal_commit(): + m = NodeMetadata() + + with m.enter_typechecker_speculation(): + m["x"] = 1 + + assert m["x"] == 1 + + +def test_metadata_journal_exception(): + m = NodeMetadata() + + m["x"] = 1 + try: + with m.enter_typechecker_speculation(): + m["x"] = 2 + m["x"] = 3 + + assert m["x"] == 3 + raise VyperException("dummy exception") + + except VyperException: + pass + + # rollback upon exception + assert m["x"] == 1 + + +def test_metadata_journal_rollback_inner(): + m = NodeMetadata() + + m["x"] = 1 + with m.enter_typechecker_speculation(): + m["x"] = 2 + + try: + with m.enter_typechecker_speculation(): + m["x"] = 3 + m["x"] = 4 # test multiple writes + + assert m["x"] == 4 + raise VyperException("dummy exception") + + except VyperException: + pass + + assert m["x"] == 2 + + +def test_metadata_journal_rollback_outer(): + m = NodeMetadata() + + m["x"] = 1 + try: + with m.enter_typechecker_speculation(): + m["x"] = 2 + + with m.enter_typechecker_speculation(): + m["x"] = 3 + m["x"] = 4 # test multiple writes + + assert m["x"] == 4 + + m["x"] = 5 + + raise VyperException("dummy exception") + + except VyperException: + pass + + assert m["x"] == 1 diff --git a/vyper/ast/metadata.py b/vyper/ast/metadata.py index 30e06e0016..0a419c3732 100644 --- a/vyper/ast/metadata.py +++ b/vyper/ast/metadata.py @@ -17,8 +17,11 @@ def __init__(self): self._node_updates: list[dict[tuple[int, str, Any], NodeMetadata]] = [] def register_update(self, metadata, k): + KEY = (id(metadata), k) + if KEY in self._node_updates[-1]: + return prev = metadata.get(k, self._NOT_FOUND) - self._node_updates[-1][(id(metadata), k)] = (metadata, prev) + self._node_updates[-1][KEY] = (metadata, prev) @contextlib.contextmanager def enter(self): From 96d20425fa2fbebb9e9aeb0402399c745eb80cfe Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 5 Sep 2023 19:05:23 -0400 Subject: [PATCH 080/201] feat: add runtime code layout to initcode (#3584) this commit adds the runtime code layout to the initcode payload (as a suffix), so that the runtime code can be analyzed without source code. this is particularly important for disassemblers, which need demarcations for where the data section starts as distinct from the runtime code segment itself. the layout is: CBOR-encoded list: runtime code length [ for data section in runtime data sections] immutable section length {"vyper": (major, minor, patch)} length of CBOR-encoded list + 2, encoded as two big-endian bytes. note the specific format for the CBOR payload was chosen to avoid changing the last 13 bytes of the signature compared to previous versions of vyper. that is, the last 13 bytes still look like b"\xa1evyper\x83...", this is because, as the last item in a list, its encoding does not change compared to being the only dict in the payload. this commit also changes the meaning of the two footer bytes: they now indicate the length of the entire footer (including the two bytes indicating the footer length). the sole purpose of this is to be more intuitive as the two footer bytes indicate offset-from-the-end where the CBOR-encoded metadata starts, rather than the length of the CBOR payload (without the two length bytes). lastly, this commit renames the internal `insert_vyper_signature=` kwarg to `insert_compiler_metadata=` as the metadata includes more than just the vyper version now. --- setup.py | 1 + tests/compiler/test_bytecode_runtime.py | 133 ++++++++++++++++++++++-- vyper/compiler/output.py | 2 +- vyper/compiler/phases.py | 12 ++- vyper/ir/compile_ir.py | 64 ++++++++---- 5 files changed, 180 insertions(+), 32 deletions(-) diff --git a/setup.py b/setup.py index bbf6e60f55..c251071229 100644 --- a/setup.py +++ b/setup.py @@ -92,6 +92,7 @@ def _global_version(version): python_requires=">=3.10,<4", py_modules=["vyper"], install_requires=[ + "cbor2>=5.4.6,<6", "asttokens>=2.0.5,<3", "pycryptodome>=3.5.1,<4", "semantic-version>=2.10,<3", diff --git a/tests/compiler/test_bytecode_runtime.py b/tests/compiler/test_bytecode_runtime.py index 86eff70a50..9519b03772 100644 --- a/tests/compiler/test_bytecode_runtime.py +++ b/tests/compiler/test_bytecode_runtime.py @@ -1,14 +1,135 @@ -import vyper +import cbor2 +import pytest +import vyper +from vyper.compiler.settings import OptimizationLevel, Settings -def test_bytecode_runtime(): - code = """ +simple_contract_code = """ @external def a() -> bool: return True - """ +""" + +many_functions = """ +@external +def foo1(): + pass + +@external +def foo2(): + pass + +@external +def foo3(): + pass + +@external +def foo4(): + pass + +@external +def foo5(): + pass +""" + +has_immutables = """ +A_GOOD_PRIME: public(immutable(uint256)) + +@external +def __init__(): + A_GOOD_PRIME = 967 +""" + + +def _parse_cbor_metadata(initcode): + metadata_ofst = int.from_bytes(initcode[-2:], "big") + metadata = cbor2.loads(initcode[-metadata_ofst:-2]) + return metadata - out = vyper.compile_code(code, ["bytecode_runtime", "bytecode"]) + +def test_bytecode_runtime(): + out = vyper.compile_code(simple_contract_code, ["bytecode_runtime", "bytecode"]) assert len(out["bytecode"]) > len(out["bytecode_runtime"]) - assert out["bytecode_runtime"][2:] in out["bytecode"][2:] + assert out["bytecode_runtime"].removeprefix("0x") in out["bytecode"].removeprefix("0x") + + +def test_bytecode_signature(): + out = vyper.compile_code(simple_contract_code, ["bytecode_runtime", "bytecode"]) + + runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) + initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) + + metadata = _parse_cbor_metadata(initcode) + runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert runtime_len == len(runtime_code) + assert data_section_lengths == [] + assert immutables_len == 0 + assert compiler == {"vyper": list(vyper.version.version_tuple)} + + +def test_bytecode_signature_dense_jumptable(): + settings = Settings(optimize=OptimizationLevel.CODESIZE) + + out = vyper.compile_code(many_functions, ["bytecode_runtime", "bytecode"], settings=settings) + + runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) + initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) + + metadata = _parse_cbor_metadata(initcode) + runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert runtime_len == len(runtime_code) + assert data_section_lengths == [5, 35] + assert immutables_len == 0 + assert compiler == {"vyper": list(vyper.version.version_tuple)} + + +def test_bytecode_signature_sparse_jumptable(): + settings = Settings(optimize=OptimizationLevel.GAS) + + out = vyper.compile_code(many_functions, ["bytecode_runtime", "bytecode"], settings=settings) + + runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) + initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) + + metadata = _parse_cbor_metadata(initcode) + runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert runtime_len == len(runtime_code) + assert data_section_lengths == [8] + assert immutables_len == 0 + assert compiler == {"vyper": list(vyper.version.version_tuple)} + + +def test_bytecode_signature_immutables(): + out = vyper.compile_code(has_immutables, ["bytecode_runtime", "bytecode"]) + + runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) + initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) + + metadata = _parse_cbor_metadata(initcode) + runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert runtime_len == len(runtime_code) + assert data_section_lengths == [] + assert immutables_len == 32 + assert compiler == {"vyper": list(vyper.version.version_tuple)} + + +# check that deployed bytecode actually matches the cbor metadata +@pytest.mark.parametrize("code", [simple_contract_code, has_immutables, many_functions]) +def test_bytecode_signature_deployed(code, get_contract, w3): + c = get_contract(code) + deployed_code = w3.eth.get_code(c.address) + + initcode = c._classic_contract.bytecode + + metadata = _parse_cbor_metadata(initcode) + runtime_len, data_section_lengths, immutables_len, compiler = metadata + + assert compiler == {"vyper": list(vyper.version.version_tuple)} + + # runtime_len includes data sections but not immutables + assert len(deployed_code) == runtime_len + immutables_len diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 69fcbf1f1f..334c5ba613 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -218,7 +218,7 @@ def _build_asm(asm_list): def build_source_map_output(compiler_data: CompilerData) -> OrderedDict: _, line_number_map = compile_ir.assembly_to_evm( - compiler_data.assembly_runtime, insert_vyper_signature=False + compiler_data.assembly_runtime, insert_compiler_metadata=False ) # Sort line_number_map out = OrderedDict() diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 526d2f3253..a1c7342320 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -184,12 +184,12 @@ def assembly_runtime(self) -> list: @cached_property def bytecode(self) -> bytes: - insert_vyper_signature = not self.no_bytecode_metadata - return generate_bytecode(self.assembly, insert_vyper_signature=insert_vyper_signature) + insert_compiler_metadata = not self.no_bytecode_metadata + return generate_bytecode(self.assembly, insert_compiler_metadata=insert_compiler_metadata) @cached_property def bytecode_runtime(self) -> bytes: - return generate_bytecode(self.assembly_runtime, insert_vyper_signature=False) + return generate_bytecode(self.assembly_runtime, insert_compiler_metadata=False) @cached_property def blueprint_bytecode(self) -> bytes: @@ -331,7 +331,7 @@ def _find_nested_opcode(assembly, key): return any(_find_nested_opcode(x, key) for x in sublists) -def generate_bytecode(assembly: list, insert_vyper_signature: bool) -> bytes: +def generate_bytecode(assembly: list, insert_compiler_metadata: bool) -> bytes: """ Generate bytecode from assembly instructions. @@ -345,4 +345,6 @@ def generate_bytecode(assembly: list, insert_vyper_signature: bool) -> bytes: bytes Final compiled bytecode. """ - return compile_ir.assembly_to_evm(assembly, insert_vyper_signature=insert_vyper_signature)[0] + return compile_ir.assembly_to_evm(assembly, insert_compiler_metadata=insert_compiler_metadata)[ + 0 + ] diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index bba3b34515..7a3e97155b 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -1,6 +1,9 @@ import copy import functools import math +from dataclasses import dataclass + +import cbor2 from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel @@ -507,9 +510,9 @@ def _height_of(witharg): elif code.value == "deploy": memsize = code.args[0].value # used later to calculate _mem_deploy_start ir = code.args[1] - padding = code.args[2].value + immutables_len = code.args[2].value assert isinstance(memsize, int), "non-int memsize" - assert isinstance(padding, int), "non-int padding" + assert isinstance(immutables_len, int), "non-int immutables_len" runtime_begin = mksymbol("runtime_begin") @@ -521,14 +524,14 @@ def _height_of(witharg): o.extend(["_sym_subcode_size", runtime_begin, "_mem_deploy_start", "CODECOPY"]) # calculate the len of runtime code - o.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len + o.extend(["_OFST", "_sym_subcode_size", immutables_len]) # stack: len o.extend(["_mem_deploy_start"]) # stack: len mem_ofst o.extend(["RETURN"]) # since the asm data structures are very primitive, to make sure # assembly_to_evm is able to calculate data offsets correctly, # we pass the memsize via magic opcodes to the subcode - subcode = [_RuntimeHeader(runtime_begin, memsize)] + subcode + subcode = [_RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode # append the runtime code after the ctor code # `append(...)` call here is intentional. @@ -1051,18 +1054,19 @@ def _length_of_data(assembly): return ret +@dataclass class _RuntimeHeader: - def __init__(self, label, ctor_mem_size): - self.label = label - self.ctor_mem_size = ctor_mem_size + label: str + ctor_mem_size: int + immutables_len: int def __repr__(self): - return f"" + return f"" +@dataclass class _DataHeader: - def __init__(self, label): - self.label = label + label: str def __repr__(self): return f"DATA {self.label}" @@ -1092,21 +1096,21 @@ def _relocate_segments(assembly): # TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps -def assembly_to_evm(assembly, pc_ofst=0, insert_vyper_signature=False): +def assembly_to_evm(assembly, pc_ofst=0, insert_compiler_metadata=False): bytecode, source_maps, _ = assembly_to_evm_with_symbol_map( - assembly, pc_ofst=pc_ofst, insert_vyper_signature=insert_vyper_signature + assembly, pc_ofst=pc_ofst, insert_compiler_metadata=insert_compiler_metadata ) return bytecode, source_maps -def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_vyper_signature=False): +def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadata=False): """ Assembles assembly into EVM assembly: list of asm instructions pc_ofst: when constructing the source map, the amount to offset all pcs by (no effect until we add deploy code source map) - insert_vyper_signature: whether to append vyper metadata to output + insert_compiler_metadata: whether to append vyper metadata to output (should be true for runtime code) """ line_number_map = { @@ -1122,12 +1126,6 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_vyper_signature= runtime_code, runtime_code_start, runtime_code_end = None, None, None - bytecode_suffix = b"" - if insert_vyper_signature: - # CBOR encoded: {"vyper": [major,minor,patch]} - bytecode_suffix += b"\xa1\x65vyper\x83" + bytes(list(version_tuple)) - bytecode_suffix += len(bytecode_suffix).to_bytes(2, "big") - # to optimize the size of deploy code - we want to use the smallest # PUSH instruction possible which can support all memory symbols # (and also works with linear pass symbol resolution) @@ -1155,6 +1153,9 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_vyper_signature= if runtime_code_end is not None: mem_ofst_size = calc_mem_ofst_size(runtime_code_end + max_mem_ofst) + data_section_lengths = [] + immutables_len = None + # go through the code, resolving symbolic locations # (i.e. JUMPDEST locations) to actual code locations for i, item in enumerate(assembly): @@ -1198,18 +1199,41 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_vyper_signature= # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) pc -= 1 elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + # we are in initcode symbol_map[item[0].label] = pc # add source map for all items in the runtime map t = adjust_pc_maps(runtime_map, pc) for key in line_number_map: line_number_map[key].update(t[key]) + immutables_len = item[0].immutables_len pc += len(runtime_code) + # grab lengths of data sections from the runtime + for t in item: + if isinstance(t, list) and isinstance(t[0], _DataHeader): + data_section_lengths.append(_length_of_data(t)) + elif isinstance(item, list) and isinstance(item[0], _DataHeader): symbol_map[item[0].label] = pc pc += _length_of_data(item) else: pc += 1 + bytecode_suffix = b"" + if insert_compiler_metadata: + # this will hold true when we are in initcode + assert immutables_len is not None + metadata = ( + len(runtime_code), + data_section_lengths, + immutables_len, + {"vyper": version_tuple}, + ) + bytecode_suffix += cbor2.dumps(metadata) + # append the length of the footer, *including* the length + # of the length bytes themselves. + suffix_len = len(bytecode_suffix) + 2 + bytecode_suffix += suffix_len.to_bytes(2, "big") + pc += len(bytecode_suffix) symbol_map["_sym_code_end"] = pc From 41c9a3a62f94678aa0a64b4db3676d6da963a1e7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Sep 2023 11:15:09 -0400 Subject: [PATCH 081/201] chore: `v0.3.10rc1` release notes (#3534) --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> --- docs/release-notes.rst | 58 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index f408c5c0ab..da86c5c0ce 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -14,8 +14,64 @@ Release Notes for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g +.. + v0.3.10 ("Black Adder") + *********************** + +v0.3.10rc1 +********** + +Date released: 2023-09-06 +========================= + +v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). + +Breaking changes: +----------------- + +- add runtime code layout to initcode (`#3584 `_) +- drop evm versions through istanbul (`#3470 `_) +- remove vyper signature from runtime (`#3471 `_) + +Non-breaking changes and improvements: +-------------------------------------- + +- O(1) selector tables (`#3496 `_) +- implement bound= in ranges (`#3537 `_, `#3551 `_) +- add optimization mode to vyper compiler (`#3493 `_) +- improve batch copy performance (`#3483 `_, `#3499 `_, `#3525 `_) + +Notable fixes: +-------------- + +- fix ``ecrecover()`` behavior when signature is invalid (`GHSA-f5x6-7qgp-jhf3 `_, `#3586 `_) +- fix: order of evaluation for some builtins (`#3583 `_, `#3587 `_) +- fix: pycryptodome for arm builds (`#3485 `_) +- let params of internal functions be mutable (`#3473 `_) +- typechecking of folded builtins in (`#3490 `_) +- update tload/tstore opcodes per latest 1153 EIP spec (`#3484 `_) +- fix: raw_call type when max_outsize=0 is set (`#3572 `_) +- fix: implements check for indexed event arguments (`#3570 `_) + +Other docs updates, chores and fixes: +------------------------------------- + +- relax restrictions on internal function signatures (`#3573 `_) +- note on security advisory in release notes for versions ``0.2.15``, ``0.2.16``, and ``0.3.0`` (`#3553 `_) +- fix: yanked version in release notes (`#3545 `_) +- update release notes on yanked versions (`#3547 `_) +- improve error message for conflicting methods IDs (`#3491 `_) +- document epsilon builtin (`#3552 `_) +- relax version pragma parsing (`#3511 `_) +- fix: issue with finding installed packages in editable mode (`#3510 `_) +- add note on security advisory for ``ecrecover`` in docs (`#3539 `_) +- add ``asm`` option to cli help (`#3585 `_) +- add message to error map for repeat range check (`#3542 `_) +- fix: public constant arrays (`#3536 `_) + + v0.3.9 ("Common Adder") -****** +*********************** Date released: 2023-05-29 From 09f95c5d3921bb193f35f7fff8f653a1f0bb79b6 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Wed, 6 Sep 2023 20:14:48 +0100 Subject: [PATCH 082/201] chore: add `ir_runtime` option to cli help (#3592) --- vyper/cli/vyper_compile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 9c97f8c667..bdd01eebbe 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -41,6 +41,7 @@ opcodes_runtime - List of runtime opcodes as a string ir - Intermediate representation in list format ir_json - Intermediate representation in JSON format +ir_runtime - Intermediate representation of runtime bytecode in list format asm - Output the EVM assembly of the deployable bytecode hex-ir - Output IR and assembly constants in hex instead of decimal """ From 1ed445765d546437febb7a2d3347d29bea6d943d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Sep 2023 15:18:46 -0400 Subject: [PATCH 083/201] chore(ci): fix macos universal2 build (#3590) this was a build regression introduced by the inclusion of the `cbor2` package in 96d20425fa2fb. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b4be1043c1..684955bea1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,7 +38,7 @@ jobs: - name: Generate Binary run: >- - pip install --no-binary pycryptodome . && + pip install --no-binary pycryptodome --no-binary cbor2 . && pip install pyinstaller && make freeze From 294d97c2b853fb67ec7ca5398dfd60808384d4fb Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Sep 2023 16:02:51 -0400 Subject: [PATCH 084/201] fix: version parsing for release candidates (#3593) the npm spec library is buggy and does not handle release candidates correctly. switch to the pypa packaging library which does pep440. note that we do a hack in order to support commonly used npm prefixes: no prefix, and `^` as prefix. going forward in v0.4.x, we will switch to pep440 entirely. --- docs/structure-of-a-contract.rst | 2 +- tests/ast/test_pre_parser.py | 32 ++++++------------------ vyper/ast/pre_parser.py | 43 ++++++++++---------------------- 3 files changed, 21 insertions(+), 56 deletions(-) diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index f58ab3b067..d2c5d48d96 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -17,7 +17,7 @@ Vyper supports several source code directives to control compiler modes and help Version Pragma -------------- -The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. +The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers _`. As of 0.3.10, the recommended way to specify the version pragma is as follows: diff --git a/tests/ast/test_pre_parser.py b/tests/ast/test_pre_parser.py index 150ee55edf..5427532c16 100644 --- a/tests/ast/test_pre_parser.py +++ b/tests/ast/test_pre_parser.py @@ -21,16 +21,9 @@ def set_version(version): "0.1.1", ">0.0.1", "^0.1.0", - "<=1.0.0 >=0.1.0", - "0.1.0 - 1.0.0", - "~0.1.0", - "0.1", - "0", - "*", - "x", - "0.x", - "0.1.x", - "0.2.0 || 0.1.1", + "<=1.0.0,>=0.1.0", + # "0.1.0 - 1.0.0", + "~=0.1.0", ] invalid_versions = [ "0.1.0", @@ -44,7 +37,6 @@ def set_version(version): "1.x", "0.2.x", "0.2.0 || 0.1.3", - "==0.1.1", "abc", ] @@ -70,9 +62,10 @@ def test_invalid_version_pragma(file_version, mock_version): "<0.1.1-rc.1", ">0.1.1a1", ">0.1.1-alpha.1", - "0.1.1a9 - 0.1.1-rc.10", + ">=0.1.1a9,<=0.1.1-rc.10", "<0.1.1b8", "<0.1.1rc1", + "<0.2.0", ] prerelease_invalid_versions = [ ">0.1.1-beta.9", @@ -80,19 +73,8 @@ def test_invalid_version_pragma(file_version, mock_version): "0.1.1b8", "0.1.1rc2", "0.1.1-rc.9 - 0.1.1-rc.10", - "<0.2.0", - pytest.param( - "<0.1.1b1", - marks=pytest.mark.xfail( - reason="https://github.com/rbarrois/python-semanticversion/issues/100" - ), - ), - pytest.param( - "<0.1.1a9", - marks=pytest.mark.xfail( - reason="https://github.com/rbarrois/python-semanticversion/issues/100" - ), - ), + "<0.1.1b1", + "<0.1.1a9", ] diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 788c44ef19..0ead889787 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -2,7 +2,7 @@ import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize -from semantic_version import NpmSpec, Version +from packaging.specifiers import InvalidSpecifier, SpecifierSet from vyper.compiler.settings import OptimizationLevel, Settings @@ -12,21 +12,6 @@ from vyper.exceptions import StructureException, SyntaxException, VersionException from vyper.typing import ModificationOffsets, ParserPosition -VERSION_ALPHA_RE = re.compile(r"(?<=\d)a(?=\d)") # 0.1.0a17 -VERSION_BETA_RE = re.compile(r"(?<=\d)b(?=\d)") # 0.1.0b17 -VERSION_RC_RE = re.compile(r"(?<=\d)rc(?=\d)") # 0.1.0rc17 - - -def _convert_version_str(version_str: str) -> str: - """ - Convert loose version (0.1.0b17) to strict version (0.1.0-beta.17) - """ - version_str = re.sub(VERSION_ALPHA_RE, "-alpha.", version_str) # 0.1.0-alpha.17 - version_str = re.sub(VERSION_BETA_RE, "-beta.", version_str) # 0.1.0-beta.17 - version_str = re.sub(VERSION_RC_RE, "-rc.", version_str) # 0.1.0-rc.17 - - return version_str - def validate_version_pragma(version_str: str, start: ParserPosition) -> None: """ @@ -34,28 +19,26 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: """ from vyper import __version__ - # NOTE: should be `x.y.z.*` - installed_version = ".".join(__version__.split(".")[:3]) - - strict_file_version = _convert_version_str(version_str) - strict_compiler_version = Version(_convert_version_str(installed_version)) - - if len(strict_file_version) == 0: + if len(version_str) == 0: raise VersionException("Version specification cannot be empty", start) + # X.Y.Z or vX.Y.Z => ==X.Y.Z, ==vX.Y.Z + if re.match("[v0-9]", version_str): + version_str = "==" + version_str + # convert npm to pep440 + version_str = re.sub("^\\^", "~=", version_str) + try: - npm_spec = NpmSpec(strict_file_version) - except ValueError: + spec = SpecifierSet(version_str) + except InvalidSpecifier: raise VersionException( - f'Version specification "{version_str}" is not a valid NPM semantic ' - f"version specification", - start, + f'Version specification "{version_str}" is not a valid PEP440 specifier', start ) - if not npm_spec.match(strict_compiler_version): + if not spec.contains(__version__, prereleases=True): raise VersionException( f'Version specification "{version_str}" is not compatible ' - f'with compiler version "{installed_version}"', + f'with compiler version "{__version__}"', start, ) From aca2b4c5e54791943547342fae3c06552db1a3a7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Sep 2023 16:04:14 -0400 Subject: [PATCH 085/201] chore: CI for pre-release (release candidate) actions (#3589) --- .github/workflows/build.yml | 2 +- .github/workflows/publish.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 684955bea1..e81aa236d1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,7 +9,7 @@ on: branches: - master release: - types: [released] + types: [published] # releases and pre-releases (release candidates) defaults: run: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 44c6978295..f268942e7d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -5,7 +5,7 @@ name: Publish to PyPI on: release: - types: [released] + types: [published] # releases and pre-releases (release candidates) jobs: From bb6e69acc3158f0acf16f23637f053d63d226e5b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Sep 2023 16:35:03 -0400 Subject: [PATCH 086/201] chore(ci): build binaries on pull requests (#3591) build binaries on all pull requests, to have better oversight over binary build success --- .github/workflows/build.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e81aa236d1..7243a05408 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,4 @@ -name: Build and release artifacts +name: Build artifacts on: workflow_dispatch: @@ -8,6 +8,7 @@ on: push: branches: - master + pull_request: release: types: [published] # releases and pre-releases (release candidates) @@ -42,6 +43,7 @@ jobs: pip install pyinstaller && make freeze + - name: Upload Artifact uses: actions/upload-artifact@v3 with: @@ -101,3 +103,13 @@ jobs: "https://uploads.github.com/repos/${{ github.repository }}/releases/${{ github.event.release.id }}/assets?name=${BIN_NAME}" \ --data-binary "@${BIN_NAME}" done + + # check build success for pull requests + build-success: + if: always() + runs-on: ubuntu-latest + needs: [windows-build, unix-build] + steps: + - name: check that all builds succeeded + if: ${{ contains(needs.*.result, 'failure') }} + run: exit 1 From 0cb37e3ef96ce374dafec5b1fcb40849fe074c62 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 6 Sep 2023 17:51:33 -0400 Subject: [PATCH 087/201] fix: dependency specification for `packaging` (#3594) setup.py regression introduced in 294d97c2b853fb --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c251071229..c81b9bed4a 100644 --- a/setup.py +++ b/setup.py @@ -95,7 +95,7 @@ def _global_version(version): "cbor2>=5.4.6,<6", "asttokens>=2.0.5,<3", "pycryptodome>=3.5.1,<4", - "semantic-version>=2.10,<3", + "packaging>=23.1,<24", "importlib-metadata", "wheel", ], From 3b310d5292c4d1448e673d7b3adb223f9353260e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 7 Sep 2023 17:45:35 -0400 Subject: [PATCH 088/201] chore(ci): fix binary names in release asset upload (#3597) rename binary during asset upload to properly escape the filename for the github API call. (Github API states: > GitHub renames asset filenames that have special characters, non-alphanumeric characters, and leading or trailing periods. The "List release assets" endpoint lists the renamed filenames. For more information and help, contact GitHub Support. ) --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7243a05408..c8d7f7d6c4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -100,7 +100,7 @@ jobs: -X POST \ -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}"\ -H "Content-Type: application/octet-stream" \ - "https://uploads.github.com/repos/${{ github.repository }}/releases/${{ github.event.release.id }}/assets?name=${BIN_NAME}" \ + "https://uploads.github.com/repos/${{ github.repository }}/releases/${{ github.event.release.id }}/assets?name=${BIN_NAME/+/%2B}" \ --data-binary "@${BIN_NAME}" done From 344fd8f36c7f0cf1e34fd06ec30f34f6c487f340 Mon Sep 17 00:00:00 2001 From: Mikko Ohtamaa Date: Sun, 10 Sep 2023 17:30:01 +0200 Subject: [PATCH 089/201] docs: add README banner about Vyper audit competition (#3599) Add a temporary banner at the top of the README to advertise the audit competition --------- Co-authored-by: Charles Cooper --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index af987ffd4f..bad929956d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +**Vyper compiler security audit competition starts 14th September with $150k worth of bounties.** [See the competition on CodeHawks](https://www.codehawks.com/contests/cll5rujmw0001js08menkj7hc) and find [more details in this blog post](https://mirror.xyz/0xBA41A04A14aeaEec79e2D694B21ba5Ab610982f1/WTZ3l3MLhTz9P4avq6JqipN5d4HJNiUY-d8zT0pfmXg). From 0b740280c1e3c5528a20d47b29831948ddcc6d83 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 15 Sep 2023 18:01:03 -0400 Subject: [PATCH 090/201] fix: only allow valid identifiers to be nonreentrant keys (#3605) disallow invalid identifiers like `" "`, `"123abc"` from being keys for non-reentrant locks. this commit also refactors the `validate_identifiers` helper function to be in the `ast/` subdirectory, and slightly improves the VyperException constructor by allowing None (optional) annotations. --- .../exceptions/test_structure_exception.py | 23 +++- .../features/decorators/test_nonreentrant.py | 4 +- tests/parser/test_call_graph_stability.py | 2 +- tests/parser/types/test_identifier_naming.py | 2 +- vyper/ast/identifiers.py | 111 ++++++++++++++++ vyper/exceptions.py | 4 +- vyper/semantics/namespace.py | 119 +----------------- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/function.py | 6 +- 9 files changed, 147 insertions(+), 126 deletions(-) create mode 100644 vyper/ast/identifiers.py diff --git a/tests/parser/exceptions/test_structure_exception.py b/tests/parser/exceptions/test_structure_exception.py index 08794b75f2..97ac2b139d 100644 --- a/tests/parser/exceptions/test_structure_exception.py +++ b/tests/parser/exceptions/test_structure_exception.py @@ -56,9 +56,26 @@ def double_nonreentrant(): """, """ @external -@nonreentrant("B") -@nonreentrant("C") -def double_nonreentrant(): +@nonreentrant(" ") +def invalid_nonreentrant_key(): + pass + """, + """ +@external +@nonreentrant("") +def invalid_nonreentrant_key(): + pass + """, + """ +@external +@nonreentrant("123") +def invalid_nonreentrant_key(): + pass + """, + """ +@external +@nonreentrant("!123abcd") +def invalid_nonreentrant_key(): pass """, """ diff --git a/tests/parser/features/decorators/test_nonreentrant.py b/tests/parser/features/decorators/test_nonreentrant.py index ac73b35bec..9e74019250 100644 --- a/tests/parser/features/decorators/test_nonreentrant.py +++ b/tests/parser/features/decorators/test_nonreentrant.py @@ -142,7 +142,7 @@ def set_callback(c: address): @external @payable -@nonreentrant('default') +@nonreentrant("lock") def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val _amount: uint256 = msg.value @@ -166,7 +166,7 @@ def unprotected_function(val: String[100], do_callback: bool): @external @payable -@nonreentrant('default') +@nonreentrant("lock") def __default__(): pass """ diff --git a/tests/parser/test_call_graph_stability.py b/tests/parser/test_call_graph_stability.py index b651092d16..a6193610e2 100644 --- a/tests/parser/test_call_graph_stability.py +++ b/tests/parser/test_call_graph_stability.py @@ -6,8 +6,8 @@ from hypothesis import given, settings import vyper.ast as vy_ast +from vyper.ast.identifiers import RESERVED_KEYWORDS from vyper.compiler.phases import CompilerData -from vyper.semantics.namespace import RESERVED_KEYWORDS def _valid_identifier(attr): diff --git a/tests/parser/types/test_identifier_naming.py b/tests/parser/types/test_identifier_naming.py index f4f602f471..5cfc7e8ed7 100755 --- a/tests/parser/types/test_identifier_naming.py +++ b/tests/parser/types/test_identifier_naming.py @@ -1,10 +1,10 @@ import pytest from vyper.ast.folding import BUILTIN_CONSTANTS +from vyper.ast.identifiers import RESERVED_KEYWORDS from vyper.builtins.functions import BUILTIN_FUNCTIONS from vyper.codegen.expr import ENVIRONMENT_VARIABLES from vyper.exceptions import NamespaceCollision, StructureException, SyntaxException -from vyper.semantics.namespace import RESERVED_KEYWORDS from vyper.semantics.types.primitives import AddressT BUILTIN_CONSTANTS = set(BUILTIN_CONSTANTS.keys()) diff --git a/vyper/ast/identifiers.py b/vyper/ast/identifiers.py new file mode 100644 index 0000000000..985b04e5cd --- /dev/null +++ b/vyper/ast/identifiers.py @@ -0,0 +1,111 @@ +import re + +from vyper.exceptions import StructureException + + +def validate_identifier(attr, ast_node=None): + if not re.match("^[_a-zA-Z][a-zA-Z0-9_]*$", attr): + raise StructureException(f"'{attr}' contains invalid character(s)", ast_node) + if attr.lower() in RESERVED_KEYWORDS: + raise StructureException(f"'{attr}' is a reserved keyword", ast_node) + + +# https://docs.python.org/3/reference/lexical_analysis.html#keywords +# note we don't technically need to block all python reserved keywords, +# but do it for hygiene +_PYTHON_RESERVED_KEYWORDS = { + "False", + "None", + "True", + "and", + "as", + "assert", + "async", + "await", + "break", + "class", + "continue", + "def", + "del", + "elif", + "else", + "except", + "finally", + "for", + "from", + "global", + "if", + "import", + "in", + "is", + "lambda", + "nonlocal", + "not", + "or", + "pass", + "raise", + "return", + "try", + "while", + "with", + "yield", +} +_PYTHON_RESERVED_KEYWORDS = {s.lower() for s in _PYTHON_RESERVED_KEYWORDS} + +# Cannot be used for variable or member naming +RESERVED_KEYWORDS = _PYTHON_RESERVED_KEYWORDS | { + # decorators + "public", + "external", + "nonpayable", + "constant", + "immutable", + "transient", + "internal", + "payable", + "nonreentrant", + # "class" keywords + "interface", + "struct", + "event", + "enum", + # EVM operations + "unreachable", + # special functions (no name mangling) + "init", + "_init_", + "___init___", + "____init____", + "default", + "_default_", + "___default___", + "____default____", + # more control flow and special operations + "range", + # more special operations + "indexed", + # denominations + "ether", + "wei", + "finney", + "szabo", + "shannon", + "lovelace", + "ada", + "babbage", + "gwei", + "kwei", + "mwei", + "twei", + "pwei", + # sentinal constant values + # TODO remove when these are removed from the language + "zero_address", + "empty_bytes32", + "max_int128", + "min_int128", + "max_decimal", + "min_decimal", + "max_uint256", + "zero_wei", +} diff --git a/vyper/exceptions.py b/vyper/exceptions.py index aa23614e85..defca7cc53 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -54,7 +54,9 @@ def __init__(self, message="Error Message not found.", *items): # support older exceptions that don't annotate - remove this in the future! self.lineno, self.col_offset = items[0][:2] else: - self.annotations = items + # strip out None sources so that None can be passed as a valid + # annotation (in case it is only available optionally) + self.annotations = [k for k in items if k is not None] def with_annotation(self, *annotations): """ diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index b88bc3d817..613ac0c03b 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -1,12 +1,7 @@ import contextlib -import re - -from vyper.exceptions import ( - CompilerPanic, - NamespaceCollision, - StructureException, - UndeclaredDefinition, -) + +from vyper.ast.identifiers import validate_identifier +from vyper.exceptions import CompilerPanic, NamespaceCollision, UndeclaredDefinition from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions @@ -121,111 +116,3 @@ def override_global_namespace(ns): finally: # unclobber _namespace = tmp - - -def validate_identifier(attr): - if not re.match("^[_a-zA-Z][a-zA-Z0-9_]*$", attr): - raise StructureException(f"'{attr}' contains invalid character(s)") - if attr.lower() in RESERVED_KEYWORDS: - raise StructureException(f"'{attr}' is a reserved keyword") - - -# https://docs.python.org/3/reference/lexical_analysis.html#keywords -# note we don't technically need to block all python reserved keywords, -# but do it for hygiene -_PYTHON_RESERVED_KEYWORDS = { - "False", - "None", - "True", - "and", - "as", - "assert", - "async", - "await", - "break", - "class", - "continue", - "def", - "del", - "elif", - "else", - "except", - "finally", - "for", - "from", - "global", - "if", - "import", - "in", - "is", - "lambda", - "nonlocal", - "not", - "or", - "pass", - "raise", - "return", - "try", - "while", - "with", - "yield", -} -_PYTHON_RESERVED_KEYWORDS = {s.lower() for s in _PYTHON_RESERVED_KEYWORDS} - -# Cannot be used for variable or member naming -RESERVED_KEYWORDS = _PYTHON_RESERVED_KEYWORDS | { - # decorators - "public", - "external", - "nonpayable", - "constant", - "immutable", - "transient", - "internal", - "payable", - "nonreentrant", - # "class" keywords - "interface", - "struct", - "event", - "enum", - # EVM operations - "unreachable", - # special functions (no name mangling) - "init", - "_init_", - "___init___", - "____init____", - "default", - "_default_", - "___default___", - "____default____", - # more control flow and special operations - "range", - # more special operations - "indexed", - # denominations - "ether", - "wei", - "finney", - "szabo", - "shannon", - "lovelace", - "ada", - "babbage", - "gwei", - "kwei", - "mwei", - "twei", - "pwei", - # sentinal constant values - # TODO remove when these are removed from the language - "zero_address", - "empty_bytes32", - "max_int128", - "min_int128", - "max_decimal", - "min_decimal", - "max_uint256", - "zero_wei", -} diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index af955f6071..c5af5c2a39 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -3,6 +3,7 @@ from vyper import ast as vy_ast from vyper.abi_types import ABIType +from vyper.ast.identifiers import validate_identifier from vyper.exceptions import ( CompilerPanic, InvalidLiteral, @@ -12,7 +13,6 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.namespace import validate_identifier # Some fake type with an overridden `compare_type` which accepts any RHS diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 506dae135c..77b9efb13d 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple from vyper import ast as vy_ast +from vyper.ast.identifiers import validate_identifier from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ArgumentException, @@ -220,7 +221,10 @@ def from_FunctionDef( msg = "Nonreentrant decorator disallowed on `__init__`" raise FunctionDeclarationException(msg, decorator) - kwargs["nonreentrant"] = decorator.args[0].value + nonreentrant_key = decorator.args[0].value + validate_identifier(nonreentrant_key, decorator.args[0]) + + kwargs["nonreentrant"] = nonreentrant_key elif isinstance(decorator, vy_ast.Name): if FunctionVisibility.is_valid_value(decorator.id): From 823675a8dc49e8148b7a8c79e86f01dea7115cd9 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 18 Sep 2023 08:16:51 -0700 Subject: [PATCH 091/201] fix: dense selector table when there are empty buckets (#3606) certain combinations of selectors can result in some buckets being empty. in this case, the header section is incomplete. this commit fixes the issue by bailing out of the mkbucket routine when there are empty buckets (thus treating the configurations with empty buckets as invalid) --------- Co-authored-by: Tanguy Rocher --- tests/parser/test_selector_table.py | 431 ++++++++++++++++++++++++++++ vyper/codegen/jumptable_utils.py | 25 +- vyper/codegen/module.py | 8 +- 3 files changed, 458 insertions(+), 6 deletions(-) diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py index 01a83698b7..3ac50707c2 100644 --- a/tests/parser/test_selector_table.py +++ b/tests/parser/test_selector_table.py @@ -10,6 +10,437 @@ from vyper.compiler.settings import OptimizationLevel +def test_dense_selector_table_empty_buckets(get_contract): + # some special combination of selectors which can result in + # some empty bucket being returned from _mk_buckets (that is, + # len(_mk_buckets(..., n_buckets)) != n_buckets + code = """ +@external +def aX61QLPWF()->uint256: + return 1 +@external +def aQHG0P2L1()->uint256: + return 2 +@external +def a2G8ME94W()->uint256: + return 3 +@external +def a0GNA21AY()->uint256: + return 4 +@external +def a4U1XA4T5()->uint256: + return 5 +@external +def aAYLMGOBZ()->uint256: + return 6 +@external +def a0KXRLHKE()->uint256: + return 7 +@external +def aDQS32HTR()->uint256: + return 8 +@external +def aP4K6SA3S()->uint256: + return 9 +@external +def aEB94ZP5S()->uint256: + return 10 +@external +def aTOIMN0IM()->uint256: + return 11 +@external +def aXV2N81OW()->uint256: + return 12 +@external +def a66PP6Y5X()->uint256: + return 13 +@external +def a5MWMTEWN()->uint256: + return 14 +@external +def a5ZFST4Z8()->uint256: + return 15 +@external +def aR13VXULX()->uint256: + return 16 +@external +def aWITH917Y()->uint256: + return 17 +@external +def a59NP6C5O()->uint256: + return 18 +@external +def aJ02590EX()->uint256: + return 19 +@external +def aUAXAAUQ8()->uint256: + return 20 +@external +def aWR1XNC6J()->uint256: + return 21 +@external +def aJABKZOKH()->uint256: + return 22 +@external +def aO1TT0RJT()->uint256: + return 23 +@external +def a41442IOK()->uint256: + return 24 +@external +def aMVXV9FHQ()->uint256: + return 25 +@external +def aNN0KJDZM()->uint256: + return 26 +@external +def aOX965047()->uint256: + return 27 +@external +def a575NX2J3()->uint256: + return 28 +@external +def a16EN8O7W()->uint256: + return 29 +@external +def aSZXLFF7O()->uint256: + return 30 +@external +def aQKQCIPH9()->uint256: + return 31 +@external +def aIP8021DL()->uint256: + return 32 +@external +def aQAV0HSHX()->uint256: + return 33 +@external +def aZVPAD745()->uint256: + return 34 +@external +def aJYBSNST4()->uint256: + return 35 +@external +def aQGWC4NYQ()->uint256: + return 36 +@external +def aFMBB9CXJ()->uint256: + return 37 +@external +def aYWM7ZUH1()->uint256: + return 38 +@external +def aJAZONIX1()->uint256: + return 39 +@external +def aQZ1HJK0H()->uint256: + return 40 +@external +def aKIH9LOUB()->uint256: + return 41 +@external +def aF4ZT80XL()->uint256: + return 42 +@external +def aYQD8UKR5()->uint256: + return 43 +@external +def aP6NCCAI4()->uint256: + return 44 +@external +def aY92U2EAZ()->uint256: + return 45 +@external +def aHMQ49D7P()->uint256: + return 46 +@external +def aMC6YX8VF()->uint256: + return 47 +@external +def a734X6YSI()->uint256: + return 48 +@external +def aRXXPNSMU()->uint256: + return 49 +@external +def aL5XKDTGT()->uint256: + return 50 +@external +def a86V1Y18A()->uint256: + return 51 +@external +def aAUM8PL5J()->uint256: + return 52 +@external +def aBAEC1ERZ()->uint256: + return 53 +@external +def a1U1VA3UE()->uint256: + return 54 +@external +def aC9FGVAHC()->uint256: + return 55 +@external +def aWN81WYJ3()->uint256: + return 56 +@external +def a3KK1Y07J()->uint256: + return 57 +@external +def aAZ6P6OSG()->uint256: + return 58 +@external +def aWP5HCIB3()->uint256: + return 59 +@external +def aVEK161C5()->uint256: + return 60 +@external +def aY0Q3O519()->uint256: + return 61 +@external +def aDHHHFIAE()->uint256: + return 62 +@external +def aGSJBCZKQ()->uint256: + return 63 +@external +def aZQQIUDHY()->uint256: + return 64 +@external +def a12O9QDH5()->uint256: + return 65 +@external +def aRQ1178XR()->uint256: + return 66 +@external +def aDT25C832()->uint256: + return 67 +@external +def aCSB01C4E()->uint256: + return 68 +@external +def aYGBPKZSD()->uint256: + return 69 +@external +def aP24N3EJ8()->uint256: + return 70 +@external +def a531Y9X3C()->uint256: + return 71 +@external +def a4727IKVS()->uint256: + return 72 +@external +def a2EX1L2BS()->uint256: + return 73 +@external +def a6145RN68()->uint256: + return 74 +@external +def aDO1ZNX97()->uint256: + return 75 +@external +def a3R28EU6M()->uint256: + return 76 +@external +def a9BFC867L()->uint256: + return 77 +@external +def aPL1MBGYC()->uint256: + return 78 +@external +def aI6H11O48()->uint256: + return 79 +@external +def aX0248DZY()->uint256: + return 80 +@external +def aE4JBUJN4()->uint256: + return 81 +@external +def aXBDB2ZBO()->uint256: + return 82 +@external +def a7O7MYYHL()->uint256: + return 83 +@external +def aERFF4PB6()->uint256: + return 84 +@external +def aJCUBG6TJ()->uint256: + return 85 +@external +def aQ5ELXM0F()->uint256: + return 86 +@external +def aWDT9UQVV()->uint256: + return 87 +@external +def a7UU40DJK()->uint256: + return 88 +@external +def aH01IT5VS()->uint256: + return 89 +@external +def aSKYTZ0FC()->uint256: + return 90 +@external +def aNX5LYRAW()->uint256: + return 91 +@external +def aUDKAOSGG()->uint256: + return 92 +@external +def aZ86YGAAO()->uint256: + return 93 +@external +def aIHWQGKLO()->uint256: + return 94 +@external +def aKIKFLAR9()->uint256: + return 95 +@external +def aCTPE0KRS()->uint256: + return 96 +@external +def aAD75X00P()->uint256: + return 97 +@external +def aDROUEF2F()->uint256: + return 98 +@external +def a8CDIF6YN()->uint256: + return 99 +@external +def aD2X7TM83()->uint256: + return 100 +@external +def a3W5UUB4L()->uint256: + return 101 +@external +def aG4MOBN4B()->uint256: + return 102 +@external +def aPRS0MSG7()->uint256: + return 103 +@external +def aKN3GHBUR()->uint256: + return 104 +@external +def aGE435RHQ()->uint256: + return 105 +@external +def a4E86BNFE()->uint256: + return 106 +@external +def aYDG928YW()->uint256: + return 107 +@external +def a2HFP5GQE()->uint256: + return 108 +@external +def a5DPMVXKA()->uint256: + return 109 +@external +def a3OFVC3DR()->uint256: + return 110 +@external +def aK8F62DAN()->uint256: + return 111 +@external +def aJS9EY3U6()->uint256: + return 112 +@external +def aWW789JQH()->uint256: + return 113 +@external +def a8AJJN3YR()->uint256: + return 114 +@external +def a4D0MUIDU()->uint256: + return 115 +@external +def a35W41JQR()->uint256: + return 116 +@external +def a07DQOI1E()->uint256: + return 117 +@external +def aFT43YNCT()->uint256: + return 118 +@external +def a0E75I8X3()->uint256: + return 119 +@external +def aT6NXIRO4()->uint256: + return 120 +@external +def aXB2UBAKQ()->uint256: + return 121 +@external +def aHWH55NW6()->uint256: + return 122 +@external +def a7TCFE6C2()->uint256: + return 123 +@external +def a8XYAM81I()->uint256: + return 124 +@external +def aHQTQ4YBY()->uint256: + return 125 +@external +def aGCZEHG6Y()->uint256: + return 126 +@external +def a6LJTKIW0()->uint256: + return 127 +@external +def aBDIXTD9S()->uint256: + return 128 +@external +def aCB83G21P()->uint256: + return 129 +@external +def aZC525N4K()->uint256: + return 130 +@external +def a40LC94U6()->uint256: + return 131 +@external +def a8X9TI93D()->uint256: + return 132 +@external +def aGUG9CD8Y()->uint256: + return 133 +@external +def a0LAERVAY()->uint256: + return 134 +@external +def aXQ0UEX19()->uint256: + return 135 +@external +def aKK9C7NE7()->uint256: + return 136 +@external +def aS2APW8UE()->uint256: + return 137 +@external +def a65NT07MM()->uint256: + return 138 +@external +def aGRMT6ZW5()->uint256: + return 139 +@external +def aILR4U1Z()->uint256: + return 140 + """ + c = get_contract(code) + + assert c.aX61QLPWF() == 1 # will revert if the header section is misaligned + + @given( n_methods=st.integers(min_value=1, max_value=100), seed=st.integers(min_value=0, max_value=2**64 - 1), diff --git a/vyper/codegen/jumptable_utils.py b/vyper/codegen/jumptable_utils.py index 6987ce90bd..6404b75532 100644 --- a/vyper/codegen/jumptable_utils.py +++ b/vyper/codegen/jumptable_utils.py @@ -43,7 +43,11 @@ def _image_of(xs, magic): return [((x * magic) >> bits_shift) % len(xs) for x in xs] -class _Failure(Exception): +class _FindMagicFailure(Exception): + pass + + +class _HasEmptyBuckets(Exception): pass @@ -53,7 +57,7 @@ def find_magic_for(xs): if len(test) == len(set(test)): return m - raise _Failure(f"Could not find hash for {xs}") + raise _FindMagicFailure(f"Could not find hash for {xs}") def _mk_buckets(method_ids, n_buckets): @@ -72,6 +76,11 @@ def _mk_buckets(method_ids, n_buckets): def _dense_jumptable_info(method_ids, n_buckets): buckets = _mk_buckets(method_ids, n_buckets) + # if there are somehow empty buckets, bail out as that can mess up + # the bucket header layout + if len(buckets) != n_buckets: + raise _HasEmptyBuckets() + ret = {} for bucket_id, method_ids in buckets.items(): magic = find_magic_for(method_ids) @@ -98,8 +107,16 @@ def generate_dense_jumptable_info(signatures): while n_buckets > 0: try: # print(f"trying {n_buckets} (bucket size {n // n_buckets})") - ret = _dense_jumptable_info(method_ids, n_buckets) - except _Failure: + solution = _dense_jumptable_info(method_ids, n_buckets) + assert len(solution) == n_buckets + ret = n_buckets, solution + + except _HasEmptyBuckets: + # found a solution which has empty buckets; skip it since + # it will break the bucket layout. + pass + + except _FindMagicFailure: if ret is not None: break diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 8caea9ee9b..6445a5e1e0 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -124,8 +124,12 @@ def _selector_section_dense(external_functions, global_ctx): ir_node = ["label", label, ["var_list"], entry_point.ir_node] function_irs.append(IRnode.from_list(ir_node)) - jumptable_info = jumptable_utils.generate_dense_jumptable_info(entry_points.keys()) - n_buckets = len(jumptable_info) + n_buckets, jumptable_info = jumptable_utils.generate_dense_jumptable_info(entry_points.keys()) + # note: we are guaranteed by jumptable_utils that there are no buckets + # which are empty. sanity check that the bucket ids are well-behaved: + assert n_buckets == len(jumptable_info) + for i, (bucket_id, _) in enumerate(sorted(jumptable_info.items())): + assert i == bucket_id # bucket magic <2 bytes> | bucket location <2 bytes> | bucket size <1 byte> # TODO: can make it smaller if the largest bucket magic <= 255 From ecf3050782ae15e40e27a338db3f29f296e94bfe Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 18 Sep 2023 14:48:29 -0700 Subject: [PATCH 092/201] chore: add tests for selector table stability (#3608) --- tests/parser/test_selector_table_stability.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/parser/test_selector_table_stability.py diff --git a/tests/parser/test_selector_table_stability.py b/tests/parser/test_selector_table_stability.py new file mode 100644 index 0000000000..abc2c17b8f --- /dev/null +++ b/tests/parser/test_selector_table_stability.py @@ -0,0 +1,53 @@ +from vyper.codegen.jumptable_utils import generate_sparse_jumptable_buckets +from vyper.compiler import compile_code +from vyper.compiler.settings import OptimizationLevel, Settings + + +def test_dense_jumptable_stability(): + function_names = [f"foo{i}" for i in range(30)] + + code = "\n".join(f"@external\ndef {name}():\n pass" for name in function_names) + + output = compile_code(code, ["asm"], settings=Settings(optimize=OptimizationLevel.CODESIZE)) + + # test that the selector table data is stable across different runs + # (tox should provide different PYTHONHASHSEEDs). + expected_asm = """{ DATA _sym_BUCKET_HEADERS b'\\x0bB' _sym_bucket_0 b'\\n' b'+\\x8d' _sym_bucket_1 b'\\x0c' b'\\x00\\x85' _sym_bucket_2 b'\\x08' } { DATA _sym_bucket_1 b'\\xd8\\xee\\xa1\\xe8' _sym_external_foo6___3639517672 b'\\x05' b'\\xd2\\x9e\\xe0\\xf9' _sym_external_foo0___3533627641 b'\\x05' b'\\x05\\xf1\\xe0_' _sym_external_foo2___99737695 b'\\x05' b'\\x91\\t\\xb4{' _sym_external_foo23___2433332347 b'\\x05' b'np3\\x7f' _sym_external_foo11___1852846975 b'\\x05' b'&\\xf5\\x96\\xf9' _sym_external_foo13___653629177 b'\\x05' b'\\x04ga\\xeb' _sym_external_foo14___73884139 b'\\x05' b'\\x89\\x06\\xad\\xc6' _sym_external_foo17___2298916294 b'\\x05' b'\\xe4%\\xac\\xd1' _sym_external_foo4___3827674321 b'\\x05' b'yj\\x01\\xac' _sym_external_foo7___2036990380 b'\\x05' b'\\xf1\\xe6K\\xe5' _sym_external_foo29___4058401765 b'\\x05' b'\\xd2\\x89X\\xb8' _sym_external_foo3___3532216504 b'\\x05' } { DATA _sym_bucket_2 b'\\x06p\\xffj' _sym_external_foo25___108068714 b'\\x05' b'\\x964\\x99I' _sym_external_foo24___2520029513 b'\\x05' b's\\x81\\xe7\\xc1' _sym_external_foo10___1937893313 b'\\x05' b'\\x85\\xad\\xc11' _sym_external_foo28___2242756913 b'\\x05' b'\\xfa"\\xb1\\xed' _sym_external_foo5___4196577773 b'\\x05' b'A\\xe7[\\x05' _sym_external_foo22___1105681157 b'\\x05' b'\\xd3\\x89U\\xe8' _sym_external_foo1___3548993000 b'\\x05' b'hL\\xf8\\xf3' _sym_external_foo20___1749874931 b'\\x05' } { DATA _sym_bucket_0 b'\\xee\\xd9\\x1d\\xe3' _sym_external_foo9___4007206371 b'\\x05' b'a\\xbc\\x1ch' _sym_external_foo16___1639717992 b'\\x05' b'\\xd3*\\xa7\\x0c' _sym_external_foo21___3542787852 b'\\x05' b'\\x18iG\\xd9' _sym_external_foo19___409552857 b'\\x05' b'\\n\\xf1\\xf9\\x7f' _sym_external_foo18___183630207 b'\\x05' b')\\xda\\xd7`' _sym_external_foo27___702207840 b'\\x05' b'2\\xf6\\xaa\\xda' _sym_external_foo12___855026394 b'\\x05' b'\\xbe\\xb5\\x05\\xf5' _sym_external_foo15___3199534581 b'\\x05' b'\\xfc\\xa7_\\xe6' _sym_external_foo8___4238827494 b'\\x05' b'\\x1b\\x12C8' _sym_external_foo26___454181688 b'\\x05' } }""" # noqa: E501 + assert expected_asm in output["asm"] + + +def test_sparse_jumptable_stability(): + function_names = [f"foo{i}()" for i in range(30)] + + # sparse jumptable is not as complicated in assembly. + # here just test the data structure is stable + + n_buckets, buckets = generate_sparse_jumptable_buckets(function_names) + assert n_buckets == 33 + + # the buckets sorted by id are what go into the IR, check equality against + # expected: + assert sorted(buckets.items()) == [ + (0, [4238827494, 1639717992]), + (1, [1852846975]), + (2, [1749874931]), + (3, [4007206371]), + (4, [2298916294]), + (7, [2036990380]), + (10, [3639517672, 73884139]), + (12, [3199534581]), + (13, [99737695]), + (14, [3548993000, 4196577773]), + (15, [454181688, 702207840]), + (16, [3533627641]), + (17, [108068714]), + (20, [1105681157]), + (21, [409552857, 3542787852]), + (22, [4058401765]), + (23, [2520029513, 2242756913]), + (24, [855026394, 183630207]), + (25, [3532216504, 653629177]), + (26, [1937893313]), + (28, [2433332347]), + (31, [3827674321]), + ] From 1711569f0852fa487d8677b0e9984b5692dfc4e6 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 20 Sep 2023 10:47:07 +0800 Subject: [PATCH 093/201] chore: always pass era-tester CI (#3415) This PR relaxes the check for era-tester CI so that it always succeeds as a non-blocking CI. --------- Co-authored-by: Charles Cooper --- .github/workflows/era-tester.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/era-tester.yml b/.github/workflows/era-tester.yml index 187b5c03a2..3e0bb3e941 100644 --- a/.github/workflows/era-tester.yml +++ b/.github/workflows/era-tester.yml @@ -98,6 +98,7 @@ jobs: - name: Run tester (fast) # Run era tester with no LLVM optimizations + continue-on-error: true if: ${{ github.ref != 'refs/heads/master' }} run: | cd era-compiler-tester @@ -105,7 +106,12 @@ jobs: - name: Run tester (slow) # Run era tester across the LLVM optimization matrix + continue-on-error: true if: ${{ github.ref == 'refs/heads/master' }} run: | cd era-compiler-tester cargo run --release --bin compiler-tester -- --path=tests/vyper/ --mode="M*B* ${{ env.VYPER_VERSION }}" + + - name: Mark as success + run: | + exit 0 From f224d83a91d7ff5097dafaa715d53c6c1a88f502 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Sep 2023 06:12:09 -0700 Subject: [PATCH 094/201] chore: tighten bounds for setuptools_scm (#3613) there is a regression in 8.0.0 which results in the following invalid code being generated for `vyper/version.py`: ```python from __future__ import annotations __version__ : str = version : str = '0.3.11' __version_tuple__ : 'tuple[int | str, ...]' = \ version_tuple : 'tuple[int | str, ...]' = (0, 3, 11) ``` this commit also removes some bad fuzzer deadlines. --- setup.py | 2 +- tests/ast/nodes/test_evaluate_binop_decimal.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index c81b9bed4a..40efb436c5 100644 --- a/setup.py +++ b/setup.py @@ -99,7 +99,7 @@ def _global_version(version): "importlib-metadata", "wheel", ], - setup_requires=["pytest-runner", "setuptools_scm"], + setup_requires=["pytest-runner", "setuptools_scm>=7.1.0,<8.0.0"], tests_require=extras_require["test"], extras_require=extras_require, entry_points={ diff --git a/tests/ast/nodes/test_evaluate_binop_decimal.py b/tests/ast/nodes/test_evaluate_binop_decimal.py index c6c69626b8..3c8ba0888c 100644 --- a/tests/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/ast/nodes/test_evaluate_binop_decimal.py @@ -13,7 +13,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50, deadline=None) @given(left=st_decimals, right=st_decimals) @example(left=Decimal("0.9999999999"), right=Decimal("0.0000000001")) @example(left=Decimal("0.0000000001"), right=Decimal("0.9999999999")) @@ -52,7 +52,7 @@ def test_binop_pow(): @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50, deadline=None) @given( values=st.lists(st_decimals, min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), From 79303fc4fcba06994ee5c6a7baef57bdb185006c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Sep 2023 07:51:41 -0700 Subject: [PATCH 095/201] fix: memory allocation in certain builtins using `msize` (#3610) in certain builtins which use `msize` to allocate a buffer for their arguments (specifically, `raw_call()`, `create_copy_of()` and `create_from_blueprint()`), corruption of the buffer can occur when `msize` is not properly initialized. (this usually happens when there are no variables which are held in memory in the outer external function). what can happen is that some arguments can be evaluated after `msize` is evaluated, leading to overwriting the memory region for the argument buffer with other arguments. specifically, combined with the condition that `msize` is underinitialized, this can happen with: - the buffer for the initcode of `create_copy_of()` and `create_from_blueprint()` can be overwritten when the `salt=` or `value=` arguments write to memory - the buffer for the `data` argument (when `msg.data` is provided, prompting the use of `msize`) of `raw_call()` can be overwritten when the `to`, `gas=` or `value=` arguments write to memory this commit fixes the issue by using a variant of `cache_when_complex()` to ensure that the relevant arguments are evaluated before `msize` is evaluated. this is a patch for GHSA-c647-pxm2-c52w. summarized changelog: * fix raw_call * test: raw_call with msg.data buffer clean memory * force memory effects in some clean_mem tests * add tests for clean memory in create_* functions * add scope_multi abstraction * refactor raw_call to use scope_multi * add fixes for create_* memory cleanliness * update optimizer tests -- callvalue is now considered constant * move salt back into scope_multi * add a note on reads in cache_when_complex --------- Co-authored-by: Tanguy Rocher --- tests/compiler/ir/test_optimize_ir.py | 8 +- .../parser/functions/test_create_functions.py | 209 ++++++++++++++++++ tests/parser/functions/test_raw_call.py | 158 +++++++++++++ vyper/builtins/functions.py | 63 +++--- vyper/codegen/ir_node.py | 81 ++++++- 5 files changed, 487 insertions(+), 32 deletions(-) diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py index b679e55453..1466166501 100644 --- a/tests/compiler/ir/test_optimize_ir.py +++ b/tests/compiler/ir/test_optimize_ir.py @@ -143,7 +143,9 @@ (["sub", "x", 0], ["x"]), (["sub", "x", "x"], [0]), (["sub", ["sload", 0], ["sload", 0]], None), - (["sub", ["callvalue"], ["callvalue"]], None), + (["sub", ["callvalue"], ["callvalue"]], [0]), + (["sub", ["msize"], ["msize"]], None), + (["sub", ["gas"], ["gas"]], None), (["sub", -1, ["sload", 0]], ["not", ["sload", 0]]), (["mul", "x", 1], ["x"]), (["div", "x", 1], ["x"]), @@ -210,7 +212,9 @@ (["eq", -1, ["add", -(2**255), 2**255 - 1]], [1]), # test compile-time wrapping (["eq", -2, ["add", 2**256 - 1, 2**256 - 1]], [1]), # test compile-time wrapping (["eq", "x", "x"], [1]), - (["eq", "callvalue", "callvalue"], None), + (["eq", "gas", "gas"], None), + (["eq", "msize", "msize"], None), + (["eq", "callvalue", "callvalue"], [1]), (["ne", "x", "x"], [0]), ] diff --git a/tests/parser/functions/test_create_functions.py b/tests/parser/functions/test_create_functions.py index 876d50b27d..fa7729d98e 100644 --- a/tests/parser/functions/test_create_functions.py +++ b/tests/parser/functions/test_create_functions.py @@ -431,3 +431,212 @@ def test2(target: address, salt: bytes32) -> address: # test2 = c.test2(b"\x01", salt) # assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(b"\x01")) # assert_tx_failed(lambda: c.test2(bytecode, salt)) + + +# XXX: these various tests to check the msize allocator for +# create_copy_of and create_from_blueprint depend on calling convention +# and variables writing to memory. think of ways to make more robust to +# changes in calling convention and memory layout +@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"]) +def test_create_from_blueprint_complex_value( + get_contract, deploy_blueprint_for, w3, blueprint_prefix +): + # check msize allocator does not get trampled by value= kwarg + code = """ +var: uint256 + +@external +@payable +def __init__(x: uint256): + self.var = x + +@external +def foo()-> uint256: + return self.var + """ + + prefix_len = len(blueprint_prefix) + + some_constant = b"\00" * 31 + b"\x0c" + + deployer_code = f""" +created_address: public(address) +x: constant(Bytes[32]) = {some_constant} + +@internal +def foo() -> uint256: + g:uint256 = 42 + return 3 + +@external +@payable +def test(target: address): + self.created_address = create_from_blueprint( + target, + x, + code_offset={prefix_len}, + value=self.foo(), + raw_args=True + ) + """ + + foo_contract = get_contract(code, 12) + expected_runtime_code = w3.eth.get_code(foo_contract.address) + + f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix) + + d = get_contract(deployer_code) + + d.test(f.address, transact={"value": 3}) + + test = FooContract(d.created_address()) + assert w3.eth.get_code(test.address) == expected_runtime_code + assert test.foo() == 12 + + +@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"]) +def test_create_from_blueprint_complex_salt_raw_args( + get_contract, deploy_blueprint_for, w3, blueprint_prefix +): + # test msize allocator does not get trampled by salt= kwarg + code = """ +var: uint256 + +@external +@payable +def __init__(x: uint256): + self.var = x + +@external +def foo()-> uint256: + return self.var + """ + + some_constant = b"\00" * 31 + b"\x0c" + prefix_len = len(blueprint_prefix) + + deployer_code = f""" +created_address: public(address) + +x: constant(Bytes[32]) = {some_constant} +salt: constant(bytes32) = keccak256("kebab") + +@internal +def foo() -> bytes32: + g:uint256 = 42 + return salt + +@external +@payable +def test(target: address): + self.created_address = create_from_blueprint( + target, + x, + code_offset={prefix_len}, + salt=self.foo(), + raw_args= True + ) + """ + + foo_contract = get_contract(code, 12) + expected_runtime_code = w3.eth.get_code(foo_contract.address) + + f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix) + + d = get_contract(deployer_code) + + d.test(f.address, transact={}) + + test = FooContract(d.created_address()) + assert w3.eth.get_code(test.address) == expected_runtime_code + assert test.foo() == 12 + + +@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"]) +def test_create_from_blueprint_complex_salt_no_constructor_args( + get_contract, deploy_blueprint_for, w3, blueprint_prefix +): + # test msize allocator does not get trampled by salt= kwarg + code = """ +var: uint256 + +@external +@payable +def __init__(): + self.var = 12 + +@external +def foo()-> uint256: + return self.var + """ + + prefix_len = len(blueprint_prefix) + deployer_code = f""" +created_address: public(address) + +salt: constant(bytes32) = keccak256("kebab") + +@external +@payable +def test(target: address): + self.created_address = create_from_blueprint( + target, + code_offset={prefix_len}, + salt=keccak256(_abi_encode(target)) + ) + """ + + foo_contract = get_contract(code) + expected_runtime_code = w3.eth.get_code(foo_contract.address) + + f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix) + + d = get_contract(deployer_code) + + d.test(f.address, transact={}) + + test = FooContract(d.created_address()) + assert w3.eth.get_code(test.address) == expected_runtime_code + assert test.foo() == 12 + + +def test_create_copy_of_complex_kwargs(get_contract, w3): + # test msize allocator does not get trampled by salt= kwarg + complex_salt = """ +created_address: public(address) + +@external +def test(target: address) -> address: + self.created_address = create_copy_of( + target, + salt=keccak256(_abi_encode(target)) + ) + return self.created_address + + """ + + c = get_contract(complex_salt) + bytecode = w3.eth.get_code(c.address) + c.test(c.address, transact={}) + test1 = c.created_address() + assert w3.eth.get_code(test1) == bytecode + + # test msize allocator does not get trampled by value= kwarg + complex_value = """ +created_address: public(address) + +@external +@payable +def test(target: address) -> address: + value: uint256 = 2 + self.created_address = create_copy_of(target, value = [2,2,2][value]) + return self.created_address + + """ + + c = get_contract(complex_value) + bytecode = w3.eth.get_code(c.address) + + c.test(c.address, transact={"value": 2}) + test1 = c.created_address() + assert w3.eth.get_code(test1) == bytecode diff --git a/tests/parser/functions/test_raw_call.py b/tests/parser/functions/test_raw_call.py index 9c6fba79e7..81efe64a18 100644 --- a/tests/parser/functions/test_raw_call.py +++ b/tests/parser/functions/test_raw_call.py @@ -426,6 +426,164 @@ def baz(_addr: address, should_raise: bool) -> uint256: assert caller.baz(target.address, False) == 3 +# XXX: these test_raw_call_clean_mem* tests depend on variables and +# calling convention writing to memory. think of ways to make more +# robust to changes to calling convention and memory layout. + + +def test_raw_call_msg_data_clean_mem(get_contract): + # test msize uses clean memory and does not get overwritten by + # any raw_call() arguments + code = """ +identity: constant(address) = 0x0000000000000000000000000000000000000004 + +@external +def foo(): + pass + +@internal +@view +def get_address()->address: + a:uint256 = 121 # 0x79 + return identity +@external +def bar(f: uint256, u: uint256) -> Bytes[100]: + # embed an internal call in the calculation of address + a: Bytes[100] = raw_call(self.get_address(), msg.data, max_outsize=100) + return a + """ + + c = get_contract(code) + assert ( + c.bar(1, 2).hex() == "ae42e951" + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002" + ) + + +def test_raw_call_clean_mem2(get_contract): + # test msize uses clean memory and does not get overwritten by + # any raw_call() arguments, another way + code = """ +buf: Bytes[100] + +@external +def bar(f: uint256, g: uint256, h: uint256) -> Bytes[100]: + # embed a memory modifying expression in the calculation of address + self.buf = raw_call( + [0x0000000000000000000000000000000000000004,][f-1], + msg.data, + max_outsize=100 + ) + return self.buf + """ + c = get_contract(code) + + assert ( + c.bar(1, 2, 3).hex() == "9309b76e" + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002" + "0000000000000000000000000000000000000000000000000000000000000003" + ) + + +def test_raw_call_clean_mem3(get_contract): + # test msize uses clean memory and does not get overwritten by + # any raw_call() arguments, and also test order of evaluation for + # scope_multi + code = """ +buf: Bytes[100] +canary: String[32] + +@internal +def bar() -> address: + self.canary = "bar" + return 0x0000000000000000000000000000000000000004 + +@internal +def goo() -> uint256: + self.canary = "goo" + return 0 + +@external +def foo() -> String[32]: + self.buf = raw_call(self.bar(), msg.data, value = self.goo(), max_outsize=100) + return self.canary + """ + c = get_contract(code) + assert c.foo() == "goo" + + +def test_raw_call_clean_mem_kwargs_value(get_contract): + # test msize uses clean memory and does not get overwritten by + # any raw_call() kwargs + code = """ +buf: Bytes[100] + +# add a dummy function to trigger memory expansion in the selector table routine +@external +def foo(): + pass + +@internal +def _value() -> uint256: + x: uint256 = 1 + return x + +@external +def bar(f: uint256) -> Bytes[100]: + # embed a memory modifying expression in the calculation of address + self.buf = raw_call( + 0x0000000000000000000000000000000000000004, + msg.data, + max_outsize=100, + value=self._value() + ) + return self.buf + """ + c = get_contract(code, value=1) + + assert ( + c.bar(13).hex() == "0423a132" + "000000000000000000000000000000000000000000000000000000000000000d" + ) + + +def test_raw_call_clean_mem_kwargs_gas(get_contract): + # test msize uses clean memory and does not get overwritten by + # any raw_call() kwargs + code = """ +buf: Bytes[100] + +# add a dummy function to trigger memory expansion in the selector table routine +@external +def foo(): + pass + +@internal +def _gas() -> uint256: + x: uint256 = msg.gas + return x + +@external +def bar(f: uint256) -> Bytes[100]: + # embed a memory modifying expression in the calculation of address + self.buf = raw_call( + 0x0000000000000000000000000000000000000004, + msg.data, + max_outsize=100, + gas=self._gas() + ) + return self.buf + """ + c = get_contract(code, value=1) + + assert ( + c.bar(15).hex() == "0423a132" + "000000000000000000000000000000000000000000000000000000000000000f" + ) + + uncompilable_code = [ ( """ diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 3ec8f69934..95759372a6 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -21,6 +21,7 @@ clamp_basetype, clamp_nonzero, copy_bytes, + dummy_node_for_type, ensure_in_memory, eval_once_check, eval_seq, @@ -36,7 +37,7 @@ unwrap_location, ) from vyper.codegen.expr import Expr -from vyper.codegen.ir_node import Encoding +from vyper.codegen.ir_node import Encoding, scope_multi from vyper.codegen.keccak256_helper import keccak256_helper from vyper.evm.address_space import MEMORY, STORAGE from vyper.exceptions import ( @@ -1155,14 +1156,17 @@ def build_IR(self, expr, args, kwargs, context): outsize, ] - if delegate_call: - call_op = ["delegatecall", gas, to, *common_call_args] - elif static_call: - call_op = ["staticcall", gas, to, *common_call_args] - else: - call_op = ["call", gas, to, value, *common_call_args] + gas, value = IRnode.from_list(gas), IRnode.from_list(value) + with scope_multi((to, value, gas), ("_to", "_value", "_gas")) as (b1, (to, value, gas)): + if delegate_call: + call_op = ["delegatecall", gas, to, *common_call_args] + elif static_call: + call_op = ["staticcall", gas, to, *common_call_args] + else: + call_op = ["call", gas, to, value, *common_call_args] - call_ir += [call_op] + call_ir += [call_op] + call_ir = b1.resolve(call_ir) # build sequence IR if outsize: @@ -1589,13 +1593,15 @@ def build_IR(self, expr, context): # CREATE* functions +CREATE2_SENTINEL = dummy_node_for_type(BYTES32_T) + # create helper functions # generates CREATE op sequence + zero check for result -def _create_ir(value, buf, length, salt=None, checked=True): +def _create_ir(value, buf, length, salt, checked=True): args = [value, buf, length] create_op = "create" - if salt is not None: + if salt is not CREATE2_SENTINEL: create_op = "create2" args.append(salt) @@ -1713,8 +1719,9 @@ def build_IR(self, expr, args, kwargs, context): context.check_is_not_constant("use {self._id}", expr) should_use_create2 = "salt" in [kwarg.arg for kwarg in expr.keywords] + if not should_use_create2: - kwargs["salt"] = None + kwargs["salt"] = CREATE2_SENTINEL ir_builder = self._build_create_IR(expr, args, context, **kwargs) @@ -1794,13 +1801,16 @@ def _add_gas_estimate(self, args, should_use_create2): def _build_create_IR(self, expr, args, context, value, salt): target = args[0] - with target.cache_when_complex("create_target") as (b1, target): + # something we can pass to scope_multi + with scope_multi( + (target, value, salt), ("create_target", "create_value", "create_salt") + ) as (b1, (target, value, salt)): codesize = IRnode.from_list(["extcodesize", target]) msize = IRnode.from_list(["msize"]) - with codesize.cache_when_complex("target_codesize") as ( + with scope_multi((codesize, msize), ("target_codesize", "mem_ofst")) as ( b2, - codesize, - ), msize.cache_when_complex("mem_ofst") as (b3, mem_ofst): + (codesize, mem_ofst), + ): ir = ["seq"] # make sure there is actually code at the target @@ -1824,7 +1834,7 @@ def _build_create_IR(self, expr, args, context, value, salt): ir.append(_create_ir(value, buf, buf_len, salt)) - return b1.resolve(b2.resolve(b3.resolve(ir))) + return b1.resolve(b2.resolve(ir)) class CreateFromBlueprint(_CreateBase): @@ -1877,17 +1887,18 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar # (since the abi encoder could write to fresh memory). # it would be good to not require the memory copy, but need # to evaluate memory safety. - with target.cache_when_complex("create_target") as (b1, target), argslen.cache_when_complex( - "encoded_args_len" - ) as (b2, encoded_args_len), code_offset.cache_when_complex("code_ofst") as (b3, codeofst): - codesize = IRnode.from_list(["sub", ["extcodesize", target], codeofst]) + with scope_multi( + (target, value, salt, argslen, code_offset), + ("create_target", "create_value", "create_salt", "encoded_args_len", "code_offset"), + ) as (b1, (target, value, salt, encoded_args_len, code_offset)): + codesize = IRnode.from_list(["sub", ["extcodesize", target], code_offset]) # copy code to memory starting from msize. we are clobbering # unused memory so it's safe. msize = IRnode.from_list(["msize"], location=MEMORY) - with codesize.cache_when_complex("target_codesize") as ( - b4, - codesize, - ), msize.cache_when_complex("mem_ofst") as (b5, mem_ofst): + with scope_multi((codesize, msize), ("target_codesize", "mem_ofst")) as ( + b2, + (codesize, mem_ofst), + ): ir = ["seq"] # make sure there is code at the target, and that @@ -1907,7 +1918,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar # copy the target code into memory. # layout starting from mem_ofst: # 00...00 (22 0's) | preamble | bytecode - ir.append(["extcodecopy", target, mem_ofst, codeofst, codesize]) + ir.append(["extcodecopy", target, mem_ofst, code_offset, codesize]) ir.append(copy_bytes(add_ofst(mem_ofst, codesize), argbuf, encoded_args_len, bufsz)) @@ -1922,7 +1933,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar ir.append(_create_ir(value, mem_ofst, length, salt)) - return b1.resolve(b2.resolve(b3.resolve(b4.resolve(b5.resolve(ir))))) + return b1.resolve(b2.resolve(ir)) class _UnsafeMath(BuiltinFunction): diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 6cb0a07281..ad4aa76437 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -1,3 +1,4 @@ +import contextlib import re from enum import Enum, auto from functools import cached_property @@ -46,6 +47,77 @@ class Encoding(Enum): # future: packed +# shortcut for chaining multiple cache_when_complex calls +# CMC 2023-08-10 remove this and scope_together _as soon as_ we have +# real variables in IR (that we can declare without explicit scoping - +# needs liveness analysis). +@contextlib.contextmanager +def scope_multi(ir_nodes, names): + assert len(ir_nodes) == len(names) + + builders = [] + scoped_ir_nodes = [] + + class _MultiBuilder: + def resolve(self, body): + # sanity check that it's initialized properly + assert len(builders) == len(ir_nodes) + ret = body + for b in reversed(builders): + ret = b.resolve(ret) + return ret + + mb = _MultiBuilder() + + with contextlib.ExitStack() as stack: + for arg, name in zip(ir_nodes, names): + b, ir_node = stack.enter_context(arg.cache_when_complex(name)) + + builders.append(b) + scoped_ir_nodes.append(ir_node) + + yield mb, scoped_ir_nodes + + +# create multiple with scopes if any of the items are complex, to force +# ordering of side effects. +@contextlib.contextmanager +def scope_together(ir_nodes, names): + assert len(ir_nodes) == len(names) + + should_scope = any(s._optimized.is_complex_ir for s in ir_nodes) + + class _Builder: + def resolve(self, body): + if not should_scope: + # uses of the variable have already been inlined + return body + + ret = body + # build with scopes from inside-out (hence reversed) + for arg, name in reversed(list(zip(ir_nodes, names))): + ret = ["with", name, arg, ret] + + if isinstance(body, IRnode): + return IRnode.from_list( + ret, typ=body.typ, location=body.location, encoding=body.encoding + ) + else: + return ret + + b = _Builder() + + if should_scope: + ir_vars = tuple( + IRnode.from_list(name, typ=arg.typ, location=arg.location, encoding=arg.encoding) + for (arg, name) in zip(ir_nodes, names) + ) + yield b, ir_vars + else: + # inline them + yield b, ir_nodes + + # this creates a magical block which maps to IR `with` class _WithBuilder: def __init__(self, ir_node, name, should_inline=False): @@ -326,14 +398,15 @@ def _check(condition, err): def gas(self): return self._gas + self.add_gas_estimate - # the IR should be cached. - # TODO make this private. turns out usages are all for the caching - # idiom that cache_when_complex addresses + # the IR should be cached and/or evaluated exactly once @property def is_complex_ir(self): # list of items not to cache. note can add other env variables # which do not change, e.g. calldatasize, coinbase, etc. - do_not_cache = {"~empty", "calldatasize"} + # reads (from memory or storage) should not be cached because + # they can have or be affected by side effects. + do_not_cache = {"~empty", "calldatasize", "callvalue"} + return ( isinstance(self.value, str) and (self.value.lower() in VALID_IR_MACROS or self.value.upper() in get_ir_opcodes()) From 1e9922d9f76bed2083ce2187e5be72912cfe2082 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 25 Sep 2023 09:59:56 -0700 Subject: [PATCH 096/201] chore: add notes to selector table implementation (#3618) and a couple sanity checks --- vyper/codegen/function_definitions/common.py | 7 +++++++ vyper/codegen/module.py | 5 ++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 3fd5ce0b29..1d24b6c6dd 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -73,6 +73,13 @@ class EntryPointInfo: min_calldatasize: int # the min calldata required for this entry point ir_node: IRnode # the ir for this entry point + def __post_init__(self): + # ABI v2 property guaranteed by the spec. + # https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501 + # > Note that for any X, len(enc(X)) is a multiple of 32. + assert self.min_calldatasize >= 4 + assert (self.min_calldatasize - 4) % 32 == 0 + @dataclass class ExternalFuncIR(FuncIR): diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 6445a5e1e0..bfdafa8ba9 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -93,9 +93,12 @@ def _generate_external_entry_points(external_functions, global_ctx): for code in external_functions: func_ir = generate_ir_for_function(code, global_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): + method_id = method_id_int(abi_sig) assert abi_sig not in entry_points + assert method_id not in sig_of + entry_points[abi_sig] = entry_point - sig_of[method_id_int(abi_sig)] = abi_sig + sig_of[method_id] = abi_sig # stick function common body into final entry point to save a jump ir_node = IRnode.from_list(["seq", entry_point.ir_node, func_ir.common_ir]) From 7b9d159b84c0e568378d53e753350e7f691c413a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 25 Sep 2023 20:05:33 -0700 Subject: [PATCH 097/201] docs: mcopy is enabled with cancun target (#3620) --- docs/compiling-a-contract.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 6d1cdf98d7..b529d1efb1 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -197,6 +197,7 @@ The following is a list of supported EVM versions, and changes in the compiler i - The ``transient`` keyword allows declaration of variables which live in transient storage - Functions marked with ``@nonreentrant`` are protected with TLOAD/TSTORE instead of SLOAD/SSTORE + - The ``MCOPY`` opcode will be generated automatically by the compiler for most memory operations. From e5c323afa4f61a2fc7a28bbfb824afd90ec86158 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 26 Sep 2023 11:24:32 -0700 Subject: [PATCH 098/201] chore: fix some documentation inconsistencies (#3624) * raw_args kwarg to create_from_blueprint * fix create_from_blueprint example * clarify "all but one 64th" behavior when forwarding gas left= * remove dead comment * update internal documentation for generate_ir_for_external_function() * update memory layout in create_from_blueprint comments * fix warning for BitwiseNot * fix error message about msg.data * docs: extract32 can output any bytesM type * remove dead variable in comment --- docs/built-in-functions.rst | 13 +++++++++---- vyper/builtins/functions.py | 6 ++---- .../function_definitions/external_function.py | 11 ++++------- vyper/semantics/analysis/local.py | 2 +- vyper/semantics/analysis/module.py | 1 - 5 files changed, 16 insertions(+), 17 deletions(-) diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index bfaa8fdd5e..45cf9ec8c2 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -184,13 +184,14 @@ Vyper has three built-ins for contract creation; all three contract creation bui The implementation of ``create_copy_of`` assumes that the code at ``target`` is smaller than 16MB. While this is much larger than the EIP-170 constraint of 24KB, it is a conservative size limit intended to future-proof deployer contracts in case the EIP-170 constraint is lifted. If the code at ``target`` is larger than 16MB, the behavior of ``create_copy_of`` is undefined. -.. py:function:: create_from_blueprint(target: address, *args, value: uint256 = 0, code_offset=0, [, salt: bytes32]) -> address +.. py:function:: create_from_blueprint(target: address, *args, value: uint256 = 0, raw_args: bool = False, code_offset: int = 0, [, salt: bytes32]) -> address Copy the code of ``target`` into memory and execute it as initcode. In other words, this operation interprets the code at ``target`` not as regular runtime code, but directly as initcode. The ``*args`` are interpreted as constructor arguments, and are ABI-encoded and included when executing the initcode. * ``target``: Address of the blueprint to invoke * ``*args``: Constructor arguments to forward to the initcode. * ``value``: The wei value to send to the new contract address (Optional, default 0) + * ``raw_args``: If ``True``, ``*args`` must be a single ``Bytes[...]`` argument, which will be interpreted as a raw bytes buffer to forward to the create operation (which is useful for instance, if pre- ABI-encoded data is passed in from elsewhere). (Optional, default ``False``) * ``code_offset``: The offset to start the ``EXTCODECOPY`` from (Optional, default 0) * ``salt``: A ``bytes32`` value utilized by the deterministic ``CREATE2`` opcode (Optional, if not supplied, ``CREATE`` is used) @@ -201,7 +202,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui @external def foo(blueprint: address) -> address: arg1: uint256 = 18 - arg2: String = "some string" + arg2: String[32] = "some string" return create_from_blueprint(blueprint, arg1, arg2, code_offset=1) .. note:: @@ -226,7 +227,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui * ``to``: Destination address to call to * ``data``: Data to send to the destination address * ``max_outsize``: Maximum length of the bytes array returned from the call. If the returned call data exceeds this length, only this number of bytes is returned. (Optional, default ``0``) - * ``gas``: The amount of gas to attach to the call. If not set, all remaining gas is forwarded. + * ``gas``: The amount of gas to attach to the call. (Optional, defaults to ``msg.gas``). * ``value``: The wei value to send to the address (Optional, default ``0``) * ``is_delegate_call``: If ``True``, the call will be sent as ``DELEGATECALL`` (Optional, default ``False``) * ``is_static_call``: If ``True``, the call will be sent as ``STATICCALL`` (Optional, default ``False``) @@ -264,6 +265,10 @@ Vyper has three built-ins for contract creation; all three contract creation bui assert success return response + .. note:: + + Regarding "forwarding all gas", note that, while Vyper will provide ``msg.gas`` to the call, in practice, there are some subtleties around forwarding all remaining gas on the EVM which are out of scope of this documentation and could be subject to change. For instance, see the language in EIP-150 around "all but one 64th". + .. py:function:: raw_log(topics: bytes32[4], data: Union[Bytes, bytes32]) -> None Provides low level access to the ``LOG`` opcodes, emitting a log without having to specify an ABI type. @@ -500,7 +505,7 @@ Data Manipulation * ``b``: ``Bytes`` list to extract from * ``start``: Start point to extract from - * ``output_type``: Type of output (``bytes32``, ``integer``, or ``address``). Defaults to ``bytes32``. + * ``output_type``: Type of output (``bytesM``, ``integer``, or ``address``). Defaults to ``bytes32``. Returns a value of the type specified by ``output_type``. diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 95759372a6..a0936712b2 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1418,7 +1418,7 @@ class BitwiseNot(BuiltinFunction): def evaluate(self, node): if not self.__class__._warned: - vyper_warn("`bitwise_not()` is deprecated! Please use the ^ operator instead.") + vyper_warn("`bitwise_not()` is deprecated! Please use the ~ operator instead.") self.__class__._warned = True validate_call_args(node, 1) @@ -1917,9 +1917,8 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar # copy the target code into memory. # layout starting from mem_ofst: - # 00...00 (22 0's) | preamble | bytecode + # | ir.append(["extcodecopy", target, mem_ofst, code_offset, codesize]) - ir.append(copy_bytes(add_ofst(mem_ofst, codesize), argbuf, encoded_args_len, bufsz)) # theoretically, dst = "msize", but just be safe. @@ -2586,7 +2585,6 @@ def evaluate(self, node): if isinstance(input_type, IntegerT): ret = vy_ast.Int.from_node(node, value=val) - # TODO: to change to known_type once #3213 is merged ret._metadata["type"] = input_type return ret diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 32236e9aad..65276469e7 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -135,20 +135,17 @@ def handler_for(calldata_kwargs, default_kwargs): return ret -# TODO it would be nice if this returned a data structure which were -# amenable to generating a jump table instead of the linear search for -# method_id we have now. def generate_ir_for_external_function(code, func_t, context): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, # func_t: ContractFunctionT, # context: Context, - # check_nonpayable: bool, # ) -> IRnode: - """Return the IR for an external function. Includes code to inspect the method_id, - enter the function (nonpayable and reentrancy checks), handle kwargs and exit - the function (clean up reentrancy storage variables) + """ + Return the IR for an external function. Returns IR for the body + of the function, handle kwargs and exit the function. Also returns + metadata required for `module.py` to construct the selector table. """ nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c10df3b8fd..b391b33953 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -150,7 +150,7 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: allowed_builtins = ("slice", "len", "raw_call") if not isinstance(parent, vy_ast.Call) or parent.get("func.id") not in allowed_builtins: raise StructureException( - "msg.data is only allowed inside of the slice or len functions", node + "msg.data is only allowed inside of the slice, len or raw_call functions", node ) if parent.get("func.id") == "slice": ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 02ae82faac..e59422294c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -98,7 +98,6 @@ def __init__( _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore module_node._metadata["namespace"] = _ns - # check for collisions between 4byte function selectors self_members = namespace["self"].typ.members # get list of internal function calls made by each function From d438d927bed3b850fe4768a490f3acde5f51b475 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 26 Sep 2023 15:24:46 -0700 Subject: [PATCH 099/201] fix: `_abi_decode()` validation (#3626) `_abi_decode()` does not validate input when it is nested in certain expressions. the following example gets correctly validated (bounds checked): ```vyper x: uint8 = _abi_decode(slice(msg.data, 4, 32), uint8) ``` however, the following example is not bounds checked: ```vyper @external def abi_decode(x: uint256) -> uint256: a: uint256 = convert( _abi_decode( slice(msg.data, 4, 32), (uint8) ), uint256 ) return a # abi_decode(256) returns: 256 ``` the issue is caused because the `ABIDecode()` builtin tags its output with `encoding=Encoding.ABI`, but this does not result in validation until that itself is passed to `make_setter` (which is called for instance when generating an assignment or return statement). the issue can be triggered by constructing an example where the output of `ABIDecode()` is not internally passed to `make_setter` or other input validating routine. this commit fixes the issue by calling `make_setter` in `ABIDecode()` before returning the output buffer, which causes validation to be performed. note that this causes a performance regression in the common (and majority of) cases where `make_setter` is immediately called on the result of `ABIDecode()` because a redundant memory copy ends up being generated (like in the aforementioned examples: in a plain assignment or return statement). however, fixing this performance regression is left to future work in the optimizer. --- tests/parser/functions/test_abi_decode.py | 28 ++++++++++++ vyper/builtins/functions.py | 52 +++++++++++------------ 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/tests/parser/functions/test_abi_decode.py b/tests/parser/functions/test_abi_decode.py index 2f9b93057d..2216a5bd76 100644 --- a/tests/parser/functions/test_abi_decode.py +++ b/tests/parser/functions/test_abi_decode.py @@ -344,6 +344,34 @@ def abi_decode(x: Bytes[96]) -> (uint256, uint256): assert_tx_failed(lambda: c.abi_decode(input_)) +def test_clamper_nested_uint8(get_contract, assert_tx_failed): + # check that _abi_decode clamps on word-types even when it is in a nested expression + # decode -> validate uint8 -> revert if input >= 256 -> cast back to uint256 + contract = """ +@external +def abi_decode(x: uint256) -> uint256: + a: uint256 = convert(_abi_decode(slice(msg.data, 4, 32), (uint8)), uint256) + return a + """ + c = get_contract(contract) + assert c.abi_decode(255) == 255 + assert_tx_failed(lambda: c.abi_decode(256)) + + +def test_clamper_nested_bytes(get_contract, assert_tx_failed): + # check that _abi_decode clamps dynamic even when it is in a nested expression + # decode -> validate Bytes[20] -> revert if len(input) > 20 -> convert back to -> add 1 + contract = """ +@external +def abi_decode(x: Bytes[96]) -> Bytes[21]: + a: Bytes[21] = concat(b"a", _abi_decode(x, Bytes[20])) + return a + """ + c = get_contract(contract) + assert c.abi_decode(abi.encode("(bytes)", (b"bc",))) == b"abc" + assert_tx_failed(lambda: c.abi_decode(abi.encode("(bytes)", (b"a" * 22,)))) + + @pytest.mark.parametrize( "output_typ,input_", [ diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index a0936712b2..8cdd2a4b8b 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -29,7 +29,6 @@ get_type_for_exact_size, ir_tuple_from_args, make_setter, - needs_external_call_wrap, promote_signed_int, sar, shl, @@ -2367,8 +2366,6 @@ def build_IR(self, expr, args, kwargs, context): class ABIEncode(BuiltinFunction): _id = "_abi_encode" # TODO prettier to rename this to abi.encode # signature: *, ensure_tuple= -> Bytes[] - # (check the signature manually since we have no utility methods - # to handle varargs.) # explanation of ensure_tuple: # default is to force even a single value into a tuple, # e.g. _abi_encode(bytes) -> _abi_encode((bytes,)) @@ -2529,24 +2526,11 @@ def build_IR(self, expr, args, kwargs, context): ) data = ensure_in_memory(data, context) + with data.cache_when_complex("to_decode") as (b1, data): data_ptr = bytes_data_ptr(data) data_len = get_bytearray_length(data) - # Normally, ABI-encoded data assumes the argument is a tuple - # (See comments for `wrap_value_for_external_return`) - # However, we do not want to use `wrap_value_for_external_return` - # technique as used in external call codegen because in order to be - # type-safe we would need an extra memory copy. To avoid a copy, - # we manually add the ABI-dynamic offset so that it is - # re-interpreted in-place. - if ( - unwrap_tuple is True - and needs_external_call_wrap(output_typ) - and output_typ.abi_type.is_dynamic() - ): - data_ptr = add_ofst(data_ptr, 32) - ret = ["seq"] if abi_min_size == abi_size_bound: @@ -2555,18 +2539,30 @@ def build_IR(self, expr, args, kwargs, context): # runtime assert: abi_min_size <= data_len <= abi_size_bound ret.append(clamp2(abi_min_size, data_len, abi_size_bound, signed=False)) - # return pointer to the buffer - ret.append(data_ptr) - - return b1.resolve( - IRnode.from_list( - ret, - typ=output_typ, - location=data.location, - encoding=Encoding.ABI, - annotation=f"abi_decode({output_typ})", - ) + to_decode = IRnode.from_list( + data_ptr, + typ=wrapped_typ, + location=data.location, + encoding=Encoding.ABI, + annotation=f"abi_decode({output_typ})", ) + to_decode.encoding = Encoding.ABI + + # TODO optimization: skip make_setter when we don't need + # input validation + + output_buf = context.new_internal_variable(wrapped_typ) + output = IRnode.from_list(output_buf, typ=wrapped_typ, location=MEMORY) + + # sanity check buffer size for wrapped output type will not buffer overflow + assert wrapped_typ.memory_bytes_required == output_typ.memory_bytes_required + ret.append(make_setter(output, to_decode)) + + ret.append(output) + # finalize. set the type and location for the return buffer. + # (note: unwraps the tuple type if necessary) + ret = IRnode.from_list(ret, typ=output_typ, location=MEMORY) + return b1.resolve(ret) class _MinMaxValue(TypenameFoldedFunction): From 950a97ea0d16db9884ec2f09bc71f1fc52c20bb5 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 26 Sep 2023 16:19:34 -0700 Subject: [PATCH 100/201] fix: type check abi_decode arguments (#3623) currently, the following code will trigger a compiler panic: ```vyper @external def foo(j: uint256) -> bool: s: bool = _abi_decode(j, bool, unwrap_tuple= False) return s ``` the following code will compile, even though it should not typecheck: ```vyper @external def foo(s: String[32]) -> bool: t: bool = _abi_decode(s, bool, unwrap_tuple=False) return t ``` this commit fixes the issue by typechecking the input to `_abi_decode()`. it also adds syntax tests for `_abi_decode()`. --- tests/parser/syntax/test_abi_decode.py | 45 ++++++++++++++++++++++++++ vyper/builtins/functions.py | 2 ++ 2 files changed, 47 insertions(+) create mode 100644 tests/parser/syntax/test_abi_decode.py diff --git a/tests/parser/syntax/test_abi_decode.py b/tests/parser/syntax/test_abi_decode.py new file mode 100644 index 0000000000..f05ff429cd --- /dev/null +++ b/tests/parser/syntax/test_abi_decode.py @@ -0,0 +1,45 @@ +import pytest + +from vyper import compiler +from vyper.exceptions import TypeMismatch + +fail_list = [ + ( + """ +@external +def foo(j: uint256) -> bool: + s: bool = _abi_decode(j, bool, unwrap_tuple= False) + return s + """, + TypeMismatch, + ), + ( + """ +@external +def bar(j: String[32]) -> bool: + s: bool = _abi_decode(j, bool, unwrap_tuple= False) + return s + """, + TypeMismatch, + ), +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_abi_encode_fail(bad_code, exc): + with pytest.raises(exc): + compiler.compile_code(bad_code) + + +valid_list = [ + """ +@external +def foo(x: Bytes[32]) -> uint256: + return _abi_decode(x, uint256) + """ +] + + +@pytest.mark.parametrize("good_code", valid_list) +def test_abi_encode_success(good_code): + assert compiler.compile_code(good_code) is not None diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 8cdd2a4b8b..f07202831d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2490,6 +2490,8 @@ def fetch_call_return(self, node): return output_type.typedef def infer_arg_types(self, node): + self._validate_arg_types(node) + validate_call_args(node, 2, ["unwrap_tuple"]) data_type = get_exact_type_from_node(node.args[0]) From 2bdbd846b09c94f05739e1274e00825912404fe3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 27 Sep 2023 21:06:03 +0800 Subject: [PATCH 101/201] chore: add metadata to vyper-json (#3622) --- tests/cli/vyper_json/test_output_selection.py | 6 ++++++ vyper/cli/vyper_json.py | 2 +- vyper/compiler/output.py | 1 - 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/cli/vyper_json/test_output_selection.py b/tests/cli/vyper_json/test_output_selection.py index c72f06f5a7..3b12e2b54a 100644 --- a/tests/cli/vyper_json/test_output_selection.py +++ b/tests/cli/vyper_json/test_output_selection.py @@ -52,3 +52,9 @@ def test_solc_style(): input_json = {"settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}} sources = {"foo.vy": ""} assert get_input_dict_output_formats(input_json, sources) == {"foo.vy": ["abi", "ir_dict"]} + + +def test_metadata(): + input_json = {"settings": {"outputSelection": {"*": ["metadata"]}}} + sources = {"foo.vy": ""} + assert get_input_dict_output_formats(input_json, sources) == {"foo.vy": ["metadata"]} diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 4a1c91550e..f6d82c3fe0 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -29,7 +29,7 @@ "interface": "interface", "ir": "ir_dict", "ir_runtime": "ir_runtime_dict", - # "metadata": "metadata", # don't include in "*" output for now + "metadata": "metadata", "layout": "layout", "userdoc": "userdoc", } diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 334c5ba613..9ef492c3e2 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -104,7 +104,6 @@ def build_ir_runtime_dict_output(compiler_data: CompilerData) -> dict: def build_metadata_output(compiler_data: CompilerData) -> dict: - warnings.warn("metadata output format is unstable!") sigs = compiler_data.function_signatures def _var_rec_dict(variable_record): From aecd911347af5912a22540ff3dc513273e51c72d Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 28 Sep 2023 14:42:12 -0700 Subject: [PATCH 102/201] fix: metadata output interaction with natspec (#3627) enabling the `-f metadata` output has an interaction with other outputs because the metadata output format mutates some internal data structures in-place. this is because `vars()` returns a reference to the object's `__dict__` as opposed to a copy of it. the behavior can be seen by trying to call the compiler with `-f metadata,devdoc,userdoc`. this issue was revealed in (but not introduced by) 2bdbd846b0, because that commit caused metadata and userdoc to be bundled together by default. this commit fixes the issue by constructing a copy of the object during metadata output formatting. it also modifies the test suite to include more output formats, to test the interactions between these different output formats. in doing so, it was also found that some examples have invalid natspec, which has also been fixed. --- examples/tokens/ERC1155ownable.vy | 2 -- tests/base_conftest.py | 4 ++-- vyper/compiler/output.py | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index 8094225f18..f1070b8f89 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -214,7 +214,6 @@ def mint(receiver: address, id: uint256, amount:uint256): @param receiver the account that will receive the minted token @param id the ID of the token @param amount of tokens for this ID - @param data the data associated with this mint. Usually stays empty """ assert not self.paused, "The contract has been paused" assert self.owner == msg.sender, "Only the contract owner can mint" @@ -232,7 +231,6 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy @param receiver the account that will receive the minted token @param ids array of ids for the tokens @param amounts amounts of tokens for each ID in the ids array - @param data the data associated with this mint. Usually stays empty """ assert not self.paused, "The contract has been paused" assert self.owner == msg.sender, "Only the contract owner can mint" diff --git a/tests/base_conftest.py b/tests/base_conftest.py index 81e8dedc36..1c7c6f3aed 100644 --- a/tests/base_conftest.py +++ b/tests/base_conftest.py @@ -118,8 +118,8 @@ def _get_contract(w3, source_code, optimize, *args, override_opt_level=None, **k settings.optimize = override_opt_level or optimize out = compiler.compile_code( source_code, - # test that metadata gets generated - ["abi", "bytecode", "metadata"], + # test that metadata and natspecs get generated + ["abi", "bytecode", "metadata", "userdoc", "devdoc"], settings=settings, interface_codes=kwargs.pop("interface_codes", None), show_gas_estimates=True, # Enable gas estimates for testing diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 9ef492c3e2..1c38fcff9b 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -107,7 +107,7 @@ def build_metadata_output(compiler_data: CompilerData) -> dict: sigs = compiler_data.function_signatures def _var_rec_dict(variable_record): - ret = vars(variable_record) + ret = vars(variable_record).copy() ret["typ"] = str(ret["typ"]) if ret["data_offset"] is None: del ret["data_offset"] @@ -117,7 +117,7 @@ def _var_rec_dict(variable_record): return ret def _to_dict(func_t): - ret = vars(func_t) + ret = vars(func_t).copy() ret["return_type"] = str(ret["return_type"]) ret["_ir_identifier"] = func_t._ir_info.ir_identifier @@ -133,7 +133,7 @@ def _to_dict(func_t): args = ret[attr] ret[attr] = {arg.name: str(arg.typ) for arg in args} - ret["frame_info"] = vars(func_t._ir_info.frame_info) + ret["frame_info"] = vars(func_t._ir_info.frame_info).copy() del ret["frame_info"]["frame_vars"] # frame_var.pos might be IR, cannot serialize keep_keys = { From 42817806cadaffefed7bf9c8edd64abf439be4de Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 28 Sep 2023 14:55:05 -0700 Subject: [PATCH 103/201] fix: improve test case inputs in selector table fuzz (#3625) this commit improves the fuzz examples for the selector table. the nested `@given` tests too many "dumb" examples (ex. 0, 1, max_value) when `max_examples` is not large enough. the nested `@given` strategy can find falsifying inputs, but it requires the inner `max_examples` to be set much higher, and the shrinking takes much longer. this setting of `max_examples=125` with a single `@given` using the `@composite` strategy in this commit finds the selector table bug (that was fixed in 823675a8dc) after an average of 3 runs. --- tests/parser/test_selector_table.py | 98 +++++++++++++++-------------- 1 file changed, 52 insertions(+), 46 deletions(-) diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py index 3ac50707c2..180c0266bf 100644 --- a/tests/parser/test_selector_table.py +++ b/tests/parser/test_selector_table.py @@ -478,66 +478,72 @@ def test_dense_jumptable_bucket_size(n_methods, seed): assert n_buckets / n < 0.4 or n < 10 +@st.composite +def generate_methods(draw, max_calldata_bytes): + max_default_args = draw(st.integers(min_value=0, max_value=4)) + default_fn_mutability = draw(st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"])) + + return ( + max_default_args, + default_fn_mutability, + draw( + st.lists( + st.tuples( + # function id: + st.integers(min_value=0), + # mutability: + st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), + # n calldata words: + st.integers(min_value=0, max_value=max_calldata_bytes // 32), + # n bytes to strip from calldata + st.integers(min_value=1, max_value=4), + # n default args + st.integers(min_value=0, max_value=max_default_args), + ), + unique_by=lambda x: x[0], + min_size=1, + max_size=100, + ) + ), + ) + + @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) # dense selector table packing boundaries at 256 and 65336 @pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336]) -@settings(max_examples=5, deadline=None) -@given( - seed=st.integers(min_value=0, max_value=2**64 - 1), - max_default_args=st.integers(min_value=0, max_value=4), - default_fn_mutability=st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]), -) @pytest.mark.fuzzing def test_selector_table_fuzz( - max_calldata_bytes, - seed, - max_default_args, - opt_level, - default_fn_mutability, - w3, - get_contract, - assert_tx_failed, - get_logs, + max_calldata_bytes, opt_level, w3, get_contract, assert_tx_failed, get_logs ): - def abi_sig(calldata_words, i, n_default_args): - args = [] if not calldata_words else [f"uint256[{calldata_words}]"] - args.extend(["uint256"] * n_default_args) - argstr = ",".join(args) - return f"foo{seed + i}({argstr})" + def abi_sig(func_id, calldata_words, n_default_args): + params = [] if not calldata_words else [f"uint256[{calldata_words}]"] + params.extend(["uint256"] * n_default_args) + paramstr = ",".join(params) + return f"foo{func_id}({paramstr})" - def generate_func_def(mutability, calldata_words, i, n_default_args): + def generate_func_def(func_id, mutability, calldata_words, n_default_args): arglist = [] if not calldata_words else [f"x: uint256[{calldata_words}]"] for j in range(n_default_args): arglist.append(f"x{j}: uint256 = 0") args = ", ".join(arglist) - _log_return = f"log _Return({i})" if mutability == "@payable" else "" + _log_return = f"log _Return({func_id})" if mutability == "@payable" else "" return f""" @external {mutability} -def foo{seed + i}({args}) -> uint256: +def foo{func_id}({args}) -> uint256: {_log_return} - return {i} + return {func_id} """ - @given( - methods=st.lists( - st.tuples( - st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), - st.integers(min_value=0, max_value=max_calldata_bytes // 32), - # n bytes to strip from calldata - st.integers(min_value=1, max_value=4), - # n default args - st.integers(min_value=0, max_value=max_default_args), - ), - min_size=1, - max_size=100, - ) - ) - @settings(max_examples=25) - def _test(methods): + @given(_input=generate_methods(max_calldata_bytes)) + @settings(max_examples=125, deadline=None) + def _test(_input): + max_default_args, default_fn_mutability, methods = _input + func_defs = "\n".join( - generate_func_def(m, s, i, d) for i, (m, s, _, d) in enumerate(methods) + generate_func_def(func_id, mutability, calldata_words, n_default_args) + for (func_id, mutability, calldata_words, _, n_default_args) in (methods) ) if default_fn_mutability == "": @@ -571,8 +577,8 @@ def __default__(): c = get_contract(code, override_opt_level=opt_level) - for i, (mutability, n_calldata_words, n_strip_bytes, n_default_args) in enumerate(methods): - funcname = f"foo{seed + i}" + for func_id, mutability, n_calldata_words, n_strip_bytes, n_default_args in methods: + funcname = f"foo{func_id}" func = getattr(c, funcname) for j in range(n_default_args + 1): @@ -580,9 +586,9 @@ def __default__(): args.extend([1] * j) # check the function returns as expected - assert func(*args) == i + assert func(*args) == func_id - method_id = utils.method_id(abi_sig(n_calldata_words, i, j)) + method_id = utils.method_id(abi_sig(func_id, n_calldata_words, j)) argsdata = b"\x00" * (n_calldata_words * 32 + j * 32) @@ -590,7 +596,7 @@ def __default__(): if mutability == "@payable": tx = func(*args, transact={"value": 1}) (event,) = get_logs(tx, c, "_Return") - assert event.args.val == i + assert event.args.val == func_id else: hexstr = (method_id + argsdata).hex() txdata = {"to": c.address, "data": hexstr, "value": 1} From 917959e3993ab0592d28bb5326e89a7a3ae0eb58 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 28 Sep 2023 18:25:39 -0700 Subject: [PATCH 104/201] docs: new for loop range syntax: `bound=` (#3540) `for i in range(..., bound=...)` --------- Co-authored-by: El De-dog-lo <3859395+fubuloubu@users.noreply.github.com> --- docs/control-structures.rst | 13 +++++++++++-- tests/parser/features/iteration/test_for_range.py | 4 ++-- vyper/codegen/stmt.py | 4 +++- vyper/ir/compile_ir.py | 2 +- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index fc8a472ff6..873135709a 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -271,16 +271,25 @@ Ranges are created using the ``range`` function. The following examples are vali ``STOP`` is a literal integer greater than zero. ``i`` begins as zero and increments by one until it is equal to ``STOP``. +.. code-block:: python + + for i in range(stop, bound=N): + ... + +Here, ``stop`` can be a variable with integer type, greater than zero. ``N`` must be a compile-time constant. ``i`` begins as zero and increments by one until it is equal to ``stop``. If ``stop`` is larger than ``N``, execution will revert at runtime. In certain cases, you may not have a guarantee that ``stop`` is less than ``N``, but still want to avoid the possibility of runtime reversion. To accomplish this, use the ``bound=`` keyword in combination with ``min(stop, N)`` as the argument to ``range``, like ``range(min(stop, N), bound=N)``. This is helpful for use cases like chunking up operations on larger arrays across multiple transactions. + +Another use of range can be with ``START`` and ``STOP`` bounds. + .. code-block:: python for i in range(START, STOP): ... -``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``. +Here, ``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``. .. code-block:: python for i in range(a, a + N): ... -``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. +``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. If ``a + N`` would overflow, execution will revert. diff --git a/tests/parser/features/iteration/test_for_range.py b/tests/parser/features/iteration/test_for_range.py index 395dd28231..ed6235d992 100644 --- a/tests/parser/features/iteration/test_for_range.py +++ b/tests/parser/features/iteration/test_for_range.py @@ -20,12 +20,12 @@ def test_range_bound(get_contract, assert_tx_failed): def repeat(n: uint256) -> uint256: x: uint256 = 0 for i in range(n, bound=6): - x += i + x += i + 1 return x """ c = get_contract(code) for n in range(7): - assert c.repeat(n) == sum(range(n)) + assert c.repeat(n) == sum(i + 1 for i in range(n)) # check codegen inserts assertion for n greater than bound assert_tx_failed(lambda: c.repeat(7)) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 3ecb0afdc3..c2951986c8 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -302,7 +302,9 @@ def _parse_For_range(self): loop_body.append(["mstore", iptr, i]) loop_body.append(parse_body(self.stmt.body, self.context)) - # NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound. + # NOTE: codegen for `repeat` inserts an assertion that + # (gt rounds_bound rounds). note this also covers the case where + # rounds < 0. # if we ever want to remove that, we need to manually add the assertion # where it makes sense. ir_node = IRnode.from_list( diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 7a3e97155b..1c4dc1ef7c 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -415,7 +415,7 @@ def _height_of(witharg): ) ) # stack: i, rounds, rounds_bound - # assert rounds <= rounds_bound + # assert 0 <= rounds <= rounds_bound (for rounds_bound < 2**255) # TODO this runtime assertion shouldn't fail for # internally generated repeats. o.extend(["DUP2", "GT"] + _assert_false()) From c913b2db0881a6f4e1c70b7929b713a6aab05c62 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 28 Sep 2023 18:32:04 -0700 Subject: [PATCH 105/201] chore: remove deadlines and reruns (#3630) for historical reasons, pytest ran with `--allow-reruns` when tests failed because tests would usually fail due to deadline errors. deadline errors have never indicated any meaningful issue with the compiler, they are just a somewhat unavoidable byproduct of the fact that we are running in a CI environment which has a lot of jitter. this commit changes the hypothesis deadline to `None` for the whole test suite, and removes the `--allow-reruns` parameter in the CI, which should make the test suite much more efficient when there are failures. --- .github/workflows/test.yml | 6 +++--- tests/ast/nodes/test_evaluate_binop_decimal.py | 4 ++-- tests/ast/nodes/test_evaluate_binop_int.py | 8 ++++---- tests/ast/nodes/test_evaluate_boolop.py | 4 ++-- tests/ast/nodes/test_evaluate_compare.py | 8 ++++---- tests/ast/nodes/test_evaluate_subscript.py | 2 +- tests/builtins/folding/test_abs.py | 4 ++-- tests/builtins/folding/test_addmod_mulmod.py | 2 +- tests/builtins/folding/test_bitwise.py | 8 ++++---- tests/builtins/folding/test_floor_ceil.py | 2 +- tests/builtins/folding/test_fold_as_wei_value.py | 4 ++-- tests/builtins/folding/test_keccak_sha.py | 6 +++--- tests/builtins/folding/test_min_max.py | 6 +++--- tests/builtins/folding/test_powmod.py | 2 +- tests/conftest.py | 6 ++++++ tests/fuzzing/test_exponents.py | 4 ++-- tests/grammar/test_grammar.py | 4 ++-- tests/parser/features/test_internal_call.py | 2 +- tests/parser/functions/test_slice.py | 4 ++-- tests/parser/test_call_graph_stability.py | 2 +- tests/parser/test_selector_table.py | 6 +++--- tests/parser/types/numbers/test_isqrt.py | 1 - tests/parser/types/numbers/test_sqrt.py | 2 -- tests/parser/types/test_bytes_zero_padding.py | 1 - 24 files changed, 50 insertions(+), 48 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fd78e2fff8..8d23368eb0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -104,7 +104,7 @@ jobs: run: pip install tox - name: Run Tox - run: TOXENV=py${{ matrix.python-version[1] }} tox -r -- --optimize ${{ matrix.opt-mode }} ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} --reruns 10 --reruns-delay 1 -r aR tests/ + run: TOXENV=py${{ matrix.python-version[1] }} tox -r -- --optimize ${{ matrix.opt-mode }} ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} -r aR tests/ - name: Upload Coverage uses: codecov/codecov-action@v1 @@ -148,12 +148,12 @@ jobs: # fetch test durations # NOTE: if the tests get poorly distributed, run this and commit the resulting `.test_durations` file to the `vyper-test-durations` repo. - # `TOXENV=fuzzing tox -r -- --store-durations --reruns 10 --reruns-delay 1 -r aR tests/` + # `TOXENV=fuzzing tox -r -- --store-durations -r aR tests/` - name: Fetch test-durations run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/5982755ee8459f771f2e8622427c36494646e1dd/test_durations" -o .test_durations - name: Run Tox - run: TOXENV=fuzzing tox -r -- --splits 60 --group ${{ matrix.group }} --splitting-algorithm least_duration --reruns 10 --reruns-delay 1 -r aR tests/ + run: TOXENV=fuzzing tox -r -- --splits 60 --group ${{ matrix.group }} --splitting-algorithm least_duration -r aR tests/ - name: Upload Coverage uses: codecov/codecov-action@v1 diff --git a/tests/ast/nodes/test_evaluate_binop_decimal.py b/tests/ast/nodes/test_evaluate_binop_decimal.py index 3c8ba0888c..5c9956caba 100644 --- a/tests/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/ast/nodes/test_evaluate_binop_decimal.py @@ -13,7 +13,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=None) +@settings(max_examples=50) @given(left=st_decimals, right=st_decimals) @example(left=Decimal("0.9999999999"), right=Decimal("0.0000000001")) @example(left=Decimal("0.0000000001"), right=Decimal("0.9999999999")) @@ -52,7 +52,7 @@ def test_binop_pow(): @pytest.mark.fuzzing -@settings(max_examples=50, deadline=None) +@settings(max_examples=50) @given( values=st.lists(st_decimals, min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), diff --git a/tests/ast/nodes/test_evaluate_binop_int.py b/tests/ast/nodes/test_evaluate_binop_int.py index d632a95461..80c9381c0f 100644 --- a/tests/ast/nodes/test_evaluate_binop_int.py +++ b/tests/ast/nodes/test_evaluate_binop_int.py @@ -9,7 +9,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st_int32, right=st_int32) @example(left=1, right=1) @example(left=1, right=-1) @@ -42,7 +42,7 @@ def foo(a: int128, b: int128) -> int128: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st_uint64, right=st_uint64) @pytest.mark.parametrize("op", "+-*/%") def test_binop_uint256(get_contract, assert_tx_failed, op, left, right): @@ -69,7 +69,7 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.xfail(reason="need to implement safe exponentiation logic") @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st.integers(min_value=2, max_value=245), right=st.integers(min_value=0, max_value=16)) @example(left=0, right=0) @example(left=0, right=1) @@ -89,7 +89,7 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given( values=st.lists(st.integers(min_value=-256, max_value=256), min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), diff --git a/tests/ast/nodes/test_evaluate_boolop.py b/tests/ast/nodes/test_evaluate_boolop.py index 6bd9ecc6cb..8b70537c39 100644 --- a/tests/ast/nodes/test_evaluate_boolop.py +++ b/tests/ast/nodes/test_evaluate_boolop.py @@ -8,7 +8,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(values=st.lists(st.booleans(), min_size=2, max_size=10)) @pytest.mark.parametrize("comparator", ["and", "or"]) def test_boolop_simple(get_contract, values, comparator): @@ -32,7 +32,7 @@ def foo({input_value}) -> bool: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given( values=st.lists(st.booleans(), min_size=2, max_size=10), comparators=st.lists(st.sampled_from(["and", "or"]), min_size=11, max_size=11), diff --git a/tests/ast/nodes/test_evaluate_compare.py b/tests/ast/nodes/test_evaluate_compare.py index 9ff5cea338..07f8e70de6 100644 --- a/tests/ast/nodes/test_evaluate_compare.py +++ b/tests/ast/nodes/test_evaluate_compare.py @@ -8,7 +8,7 @@ # TODO expand to all signed types @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st.integers(), right=st.integers()) @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) def test_compare_eq_signed(get_contract, op, left, right): @@ -28,7 +28,7 @@ def foo(a: int128, b: int128) -> bool: # TODO expand to all unsigned types @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st.integers(min_value=0), right=st.integers(min_value=0)) @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) def test_compare_eq_unsigned(get_contract, op, left, right): @@ -47,7 +47,7 @@ def foo(a: uint128, b: uint128) -> bool: @pytest.mark.fuzzing -@settings(max_examples=20, deadline=1000) +@settings(max_examples=20) @given(left=st.integers(), right=st.lists(st.integers(), min_size=1, max_size=16)) def test_compare_in(left, right, get_contract): source = f""" @@ -76,7 +76,7 @@ def bar(a: int128) -> bool: @pytest.mark.fuzzing -@settings(max_examples=20, deadline=1000) +@settings(max_examples=20) @given(left=st.integers(), right=st.lists(st.integers(), min_size=1, max_size=16)) def test_compare_not_in(left, right, get_contract): source = f""" diff --git a/tests/ast/nodes/test_evaluate_subscript.py b/tests/ast/nodes/test_evaluate_subscript.py index 3c0fa5d16d..ca50a076a5 100644 --- a/tests/ast/nodes/test_evaluate_subscript.py +++ b/tests/ast/nodes/test_evaluate_subscript.py @@ -6,7 +6,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given( idx=st.integers(min_value=0, max_value=9), array=st.lists(st.integers(), min_size=10, max_size=10), diff --git a/tests/builtins/folding/test_abs.py b/tests/builtins/folding/test_abs.py index 58f861ed0c..1c919d7826 100644 --- a/tests/builtins/folding/test_abs.py +++ b/tests/builtins/folding/test_abs.py @@ -8,7 +8,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(a=st.integers(min_value=-(2**255) + 1, max_value=2**255 - 1)) @example(a=0) def test_abs(get_contract, a): @@ -27,7 +27,7 @@ def foo(a: int256) -> int256: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(a=st.integers(min_value=2**255, max_value=2**256 - 1)) def test_abs_upper_bound_folding(get_contract, a): source = f""" diff --git a/tests/builtins/folding/test_addmod_mulmod.py b/tests/builtins/folding/test_addmod_mulmod.py index 0514dea18a..33dcc62984 100644 --- a/tests/builtins/folding/test_addmod_mulmod.py +++ b/tests/builtins/folding/test_addmod_mulmod.py @@ -9,7 +9,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(a=st_uint256, b=st_uint256, c=st_uint256) @pytest.mark.parametrize("fn_name", ["uint256_addmod", "uint256_mulmod"]) def test_modmath(get_contract, a, b, c, fn_name): diff --git a/tests/builtins/folding/test_bitwise.py b/tests/builtins/folding/test_bitwise.py index d28e482589..63e733644f 100644 --- a/tests/builtins/folding/test_bitwise.py +++ b/tests/builtins/folding/test_bitwise.py @@ -14,7 +14,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @pytest.mark.parametrize("op", ["&", "|", "^"]) @given(a=st_uint256, b=st_uint256) def test_bitwise_ops(get_contract, a, b, op): @@ -34,7 +34,7 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @pytest.mark.parametrize("op", ["<<", ">>"]) @given(a=st_uint256, b=st.integers(min_value=0, max_value=256)) def test_bitwise_shift_unsigned(get_contract, a, b, op): @@ -64,7 +64,7 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @pytest.mark.parametrize("op", ["<<", ">>"]) @given(a=st_sint256, b=st.integers(min_value=0, max_value=256)) def test_bitwise_shift_signed(get_contract, a, b, op): @@ -92,7 +92,7 @@ def foo(a: int256, b: uint256) -> int256: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(value=st_uint256) def test_bitwise_not(get_contract, value): source = """ diff --git a/tests/builtins/folding/test_floor_ceil.py b/tests/builtins/folding/test_floor_ceil.py index 763f8fec63..87db23889a 100644 --- a/tests/builtins/folding/test_floor_ceil.py +++ b/tests/builtins/folding/test_floor_ceil.py @@ -13,7 +13,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(value=st_decimals) @example(value=Decimal("0.9999999999")) @example(value=Decimal("0.0000000001")) diff --git a/tests/builtins/folding/test_fold_as_wei_value.py b/tests/builtins/folding/test_fold_as_wei_value.py index 11d23bd3bf..210ab51f0d 100644 --- a/tests/builtins/folding/test_fold_as_wei_value.py +++ b/tests/builtins/folding/test_fold_as_wei_value.py @@ -19,7 +19,7 @@ @pytest.mark.fuzzing -@settings(max_examples=10, deadline=1000) +@settings(max_examples=10) @given(value=st_decimals) @pytest.mark.parametrize("denom", denoms) def test_decimal(get_contract, value, denom): @@ -38,7 +38,7 @@ def foo(a: decimal) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=10, deadline=1000) +@settings(max_examples=10) @given(value=st.integers(min_value=0, max_value=2**128)) @pytest.mark.parametrize("denom", denoms) def test_integer(get_contract, value, denom): diff --git a/tests/builtins/folding/test_keccak_sha.py b/tests/builtins/folding/test_keccak_sha.py index 8e283566de..a2fe460dd1 100644 --- a/tests/builtins/folding/test_keccak_sha.py +++ b/tests/builtins/folding/test_keccak_sha.py @@ -10,7 +10,7 @@ @pytest.mark.fuzzing @given(value=st.text(alphabet=alphabet, min_size=0, max_size=100)) -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @pytest.mark.parametrize("fn_name", ["keccak256", "sha256"]) def test_string(get_contract, value, fn_name): source = f""" @@ -29,7 +29,7 @@ def foo(a: String[100]) -> bytes32: @pytest.mark.fuzzing @given(value=st.binary(min_size=0, max_size=100)) -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @pytest.mark.parametrize("fn_name", ["keccak256", "sha256"]) def test_bytes(get_contract, value, fn_name): source = f""" @@ -48,7 +48,7 @@ def foo(a: Bytes[100]) -> bytes32: @pytest.mark.fuzzing @given(value=st.binary(min_size=1, max_size=100)) -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @pytest.mark.parametrize("fn_name", ["keccak256", "sha256"]) def test_hex(get_contract, value, fn_name): source = f""" diff --git a/tests/builtins/folding/test_min_max.py b/tests/builtins/folding/test_min_max.py index e2d33237ca..309f7519c0 100644 --- a/tests/builtins/folding/test_min_max.py +++ b/tests/builtins/folding/test_min_max.py @@ -18,7 +18,7 @@ @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st_decimals, right=st_decimals) @pytest.mark.parametrize("fn_name", ["min", "max"]) def test_decimal(get_contract, left, right, fn_name): @@ -37,7 +37,7 @@ def foo(a: decimal, b: decimal) -> decimal: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st_int128, right=st_int128) @pytest.mark.parametrize("fn_name", ["min", "max"]) def test_int128(get_contract, left, right, fn_name): @@ -56,7 +56,7 @@ def foo(a: int128, b: int128) -> int128: @pytest.mark.fuzzing -@settings(max_examples=50, deadline=1000) +@settings(max_examples=50) @given(left=st_uint256, right=st_uint256) @pytest.mark.parametrize("fn_name", ["min", "max"]) def test_min_uint256(get_contract, left, right, fn_name): diff --git a/tests/builtins/folding/test_powmod.py b/tests/builtins/folding/test_powmod.py index fdc0e300ab..8667ec93fd 100644 --- a/tests/builtins/folding/test_powmod.py +++ b/tests/builtins/folding/test_powmod.py @@ -9,7 +9,7 @@ @pytest.mark.fuzzing -@settings(max_examples=100, deadline=1000) +@settings(max_examples=100) @given(a=st_uint256, b=st_uint256) def test_powmod_uint256(get_contract, a, b): source = """ diff --git a/tests/conftest.py b/tests/conftest.py index d519ca3100..c9d3f794a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import logging from functools import wraps +import hypothesis import pytest from eth_tester import EthereumTester, PyEVMBackend from eth_utils import setup_DEBUG2_logging @@ -23,6 +24,11 @@ ############ +# disable hypothesis deadline globally +hypothesis.settings.register_profile("ci", deadline=None) +hypothesis.settings.load_profile("ci") + + def set_evm_verbose_logging(): logger = logging.getLogger("eth.vm.computation.Computation") setup_DEBUG2_logging() diff --git a/tests/fuzzing/test_exponents.py b/tests/fuzzing/test_exponents.py index 29c1f198ed..5726e4c1ca 100644 --- a/tests/fuzzing/test_exponents.py +++ b/tests/fuzzing/test_exponents.py @@ -92,7 +92,7 @@ def foo(a: int16) -> int16: @example(a=2**127 - 1) # 256 bits @example(a=2**256 - 1) -@settings(max_examples=200, deadline=1000) +@settings(max_examples=200) def test_max_exp(get_contract, assert_tx_failed, a): code = f""" @external @@ -127,7 +127,7 @@ def foo(b: uint256) -> uint256: @example(a=2**63 - 1) # 128 bits @example(a=2**127 - 1) -@settings(max_examples=200, deadline=1000) +@settings(max_examples=200) def test_max_exp_int128(get_contract, assert_tx_failed, a): code = f""" @external diff --git a/tests/grammar/test_grammar.py b/tests/grammar/test_grammar.py index d665ca2544..aa0286cfa5 100644 --- a/tests/grammar/test_grammar.py +++ b/tests/grammar/test_grammar.py @@ -4,7 +4,7 @@ import hypothesis import hypothesis.strategies as st import pytest -from hypothesis import HealthCheck, assume, given +from hypothesis import assume, given from hypothesis.extra.lark import LarkStrategy from vyper.ast import Module, parse_to_ast @@ -103,7 +103,7 @@ def has_no_docstrings(c): @pytest.mark.fuzzing @given(code=from_grammar().filter(lambda c: utf8_encodable(c))) -@hypothesis.settings(deadline=400, max_examples=500, suppress_health_check=(HealthCheck.too_slow,)) +@hypothesis.settings(max_examples=500) def test_grammar_bruteforce(code): if utf8_encodable(code): _, _, reformatted_code = pre_parse(code + "\n") diff --git a/tests/parser/features/test_internal_call.py b/tests/parser/features/test_internal_call.py index d7a41acbc0..f10d22ec99 100644 --- a/tests/parser/features/test_internal_call.py +++ b/tests/parser/features/test_internal_call.py @@ -669,7 +669,7 @@ def test_internal_call_kwargs(get_contract, typ1, strategy1, typ2, strategy2): # GHSA-ph9x-4vc9-m39g @given(kwarg1=strategy1, default1=strategy1, kwarg2=strategy2, default2=strategy2) - @settings(deadline=None, max_examples=5) # len(cases) * len(cases) * 5 * 5 + @settings(max_examples=5) # len(cases) * len(cases) * 5 * 5 def fuzz(kwarg1, kwarg2, default1, default2): code = f""" @internal diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 6229b47921..3090dafda0 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -36,7 +36,7 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: @pytest.mark.parametrize("literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) -@settings(max_examples=100, deadline=None) +@settings(max_examples=100) @pytest.mark.fuzzing def test_slice_immutable( get_contract, @@ -90,7 +90,7 @@ def _get_contract(): @pytest.mark.parametrize("literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) -@settings(max_examples=100, deadline=None) +@settings(max_examples=100) @pytest.mark.fuzzing def test_slice_bytes( get_contract, diff --git a/tests/parser/test_call_graph_stability.py b/tests/parser/test_call_graph_stability.py index a6193610e2..4c85c330f3 100644 --- a/tests/parser/test_call_graph_stability.py +++ b/tests/parser/test_call_graph_stability.py @@ -15,7 +15,7 @@ def _valid_identifier(attr): # random names for functions -@settings(max_examples=20, deadline=None) +@settings(max_examples=20) @given( st.lists( st.tuples( diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py index 180c0266bf..161cd480fd 100644 --- a/tests/parser/test_selector_table.py +++ b/tests/parser/test_selector_table.py @@ -446,7 +446,7 @@ def aILR4U1Z()->uint256: seed=st.integers(min_value=0, max_value=2**64 - 1), ) @pytest.mark.fuzzing -@settings(max_examples=10, deadline=None) +@settings(max_examples=10) def test_sparse_jumptable_probe_depth(n_methods, seed): sigs = [f"foo{i + seed}()" for i in range(n_methods)] _, buckets = generate_sparse_jumptable_buckets(sigs) @@ -466,7 +466,7 @@ def test_sparse_jumptable_probe_depth(n_methods, seed): seed=st.integers(min_value=0, max_value=2**64 - 1), ) @pytest.mark.fuzzing -@settings(max_examples=10, deadline=None) +@settings(max_examples=10) def test_dense_jumptable_bucket_size(n_methods, seed): sigs = [f"foo{i + seed}()" for i in range(n_methods)] n = len(sigs) @@ -537,7 +537,7 @@ def foo{func_id}({args}) -> uint256: """ @given(_input=generate_methods(max_calldata_bytes)) - @settings(max_examples=125, deadline=None) + @settings(max_examples=125) def _test(_input): max_default_args, default_fn_mutability, methods = _input diff --git a/tests/parser/types/numbers/test_isqrt.py b/tests/parser/types/numbers/test_isqrt.py index ce26d24d06..b734323a6e 100644 --- a/tests/parser/types/numbers/test_isqrt.py +++ b/tests/parser/types/numbers/test_isqrt.py @@ -119,7 +119,6 @@ def test(a: uint256) -> (uint256, uint256, uint256, uint256, uint256, String[100 @hypothesis.example(2704) @hypothesis.example(110889) @hypothesis.example(32239684) -@hypothesis.settings(deadline=1000) def test_isqrt_valid_range(isqrt_contract, value): vyper_isqrt = isqrt_contract.test(value) actual_isqrt = math.isqrt(value) diff --git a/tests/parser/types/numbers/test_sqrt.py b/tests/parser/types/numbers/test_sqrt.py index df1ed0539c..020a79e7ef 100644 --- a/tests/parser/types/numbers/test_sqrt.py +++ b/tests/parser/types/numbers/test_sqrt.py @@ -145,7 +145,6 @@ def test_sqrt_bounds(sqrt_contract, value): ) @hypothesis.example(value=Decimal(SizeLimits.MAX_INT128)) @hypothesis.example(value=Decimal(0)) -@hypothesis.settings(deadline=1000) def test_sqrt_valid_range(sqrt_contract, value): vyper_sqrt = sqrt_contract.test(value) actual_sqrt = decimal_sqrt(value) @@ -158,7 +157,6 @@ def test_sqrt_valid_range(sqrt_contract, value): min_value=Decimal(SizeLimits.MIN_INT128), max_value=Decimal("-1E10"), places=DECIMAL_PLACES ) ) -@hypothesis.settings(deadline=400) @hypothesis.example(value=Decimal(SizeLimits.MIN_INT128)) @hypothesis.example(value=Decimal("-1E10")) def test_sqrt_invalid_range(sqrt_contract, value): diff --git a/tests/parser/types/test_bytes_zero_padding.py b/tests/parser/types/test_bytes_zero_padding.py index ee938fdffb..f9fcf37b25 100644 --- a/tests/parser/types/test_bytes_zero_padding.py +++ b/tests/parser/types/test_bytes_zero_padding.py @@ -26,7 +26,6 @@ def get_count(counter: uint256) -> Bytes[24]: @pytest.mark.fuzzing @hypothesis.given(value=hypothesis.strategies.integers(min_value=0, max_value=2**64)) -@hypothesis.settings(deadline=400) def test_zero_pad_range(little_endian_contract, value): actual_bytes = value.to_bytes(8, byteorder="little") contract_bytes = little_endian_contract.get_count(value) From 8aae7cd6b86c15978bdfa16d5a6e3ca273121107 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 2 Oct 2023 19:20:19 -0700 Subject: [PATCH 106/201] feat: disallow invalid pragmas (#3634) this commit prevents typos (ex. `# pragma evm-versionn ...`) from getting unexpectedly ignored. it also fixes an issue with the `evm_version` pragma getting properly propagated into `CompilerData.settings`. it also fixes a misnamed test function, and adds some unit tests to check that `evm_version` gets properly propagated into `CompilerData.settings`. commit summary: * fix `evm_version` passing into compiler data * fix `compile_code()` * add tests for `CompilerData.settings` * add tests for invalid pragmas --- tests/ast/test_pre_parser.py | 70 ++++++++++++++++++++++++++++++++---- vyper/ast/pre_parser.py | 7 ++-- vyper/compiler/README.md | 6 ++-- vyper/compiler/__init__.py | 22 ++++++------ vyper/compiler/phases.py | 5 +++ 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/tests/ast/test_pre_parser.py b/tests/ast/test_pre_parser.py index 5427532c16..3d072674f6 100644 --- a/tests/ast/test_pre_parser.py +++ b/tests/ast/test_pre_parser.py @@ -1,8 +1,9 @@ import pytest from vyper.ast.pre_parser import pre_parse, validate_version_pragma +from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings -from vyper.exceptions import VersionException +from vyper.exceptions import StructureException, VersionException SRC_LINE = (1, 0) # Dummy source line COMPILER_VERSION = "0.1.1" @@ -96,43 +97,50 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): """ """, Settings(), + Settings(optimize=OptimizationLevel.GAS), ), ( """ #pragma optimize codesize """, Settings(optimize=OptimizationLevel.CODESIZE), + None, ), ( """ #pragma optimize none """, Settings(optimize=OptimizationLevel.NONE), + None, ), ( """ #pragma optimize gas """, Settings(optimize=OptimizationLevel.GAS), + None, ), ( """ #pragma version 0.3.10 """, Settings(compiler_version="0.3.10"), + Settings(optimize=OptimizationLevel.GAS), ), ( """ #pragma evm-version shanghai """, Settings(evm_version="shanghai"), + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), ), ( """ #pragma optimize codesize #pragma evm-version shanghai """, - Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), + Settings(evm_version="shanghai", optimize=OptimizationLevel.CODESIZE), + None, ), ( """ @@ -140,6 +148,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma evm-version shanghai """, Settings(evm_version="shanghai", compiler_version="0.3.10"), + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), ), ( """ @@ -147,6 +156,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma optimize gas """, Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS), + Settings(optimize=OptimizationLevel.GAS), ), ( """ @@ -155,11 +165,59 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma optimize gas """, Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS, evm_version="shanghai"), + Settings(optimize=OptimizationLevel.GAS, evm_version="shanghai"), ), ] -@pytest.mark.parametrize("code, expected_pragmas", pragma_examples) -def parse_pragmas(code, expected_pragmas): - pragmas, _, _ = pre_parse(code) - assert pragmas == expected_pragmas +@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) +def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): + mock_version("0.3.10") + settings, _, _ = pre_parse(code) + + assert settings == pre_parse_settings + + compiler_data = CompilerData(code) + + # check what happens after CompilerData constructor + if compiler_data_settings is None: + # None is sentinel here meaning that nothing changed + compiler_data_settings = pre_parse_settings + + assert compiler_data.settings == compiler_data_settings + + +invalid_pragmas = [ + # evm-versionnn + """ +# pragma evm-versionnn cancun + """, + # bad fork name + """ +# pragma evm-version cancunn + """, + # oppptimize + """ +# pragma oppptimize codesize + """, + # ggas + """ +# pragma optimize ggas + """, + # double specified + """ +# pragma optimize gas +# pragma optimize codesize + """, + # double specified + """ +# pragma evm-version cancun +# pragma evm-version shanghai + """, +] + + +@pytest.mark.parametrize("code", invalid_pragmas) +def test_invalid_pragma(code): + with pytest.raises(StructureException): + pre_parse(code) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 0ead889787..9d96efea5e 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -111,7 +111,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: validate_version_pragma(compiler_version, start) settings.compiler_version = compiler_version - if pragma.startswith("optimize "): + elif pragma.startswith("optimize "): if settings.optimize is not None: raise StructureException("pragma optimize specified twice!", start) try: @@ -119,7 +119,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: settings.optimize = OptimizationLevel.from_string(mode) except ValueError: raise StructureException(f"Invalid optimization mode `{mode}`", start) - if pragma.startswith("evm-version "): + elif pragma.startswith("evm-version "): if settings.evm_version is not None: raise StructureException("pragma evm-version specified twice!", start) evm_version = pragma.removeprefix("evm-version").strip() @@ -127,6 +127,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: raise StructureException("Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version + else: + raise StructureException(f"Unknown pragma `{pragma.split()[0]}`") + if typ == NAME and string in ("class", "yield"): raise SyntaxException( f"The `{string}` keyword is not allowed. ", code, start[0], start[1] diff --git a/vyper/compiler/README.md b/vyper/compiler/README.md index d6b55fdd82..eb70750a2b 100644 --- a/vyper/compiler/README.md +++ b/vyper/compiler/README.md @@ -51,11 +51,9 @@ for specific implementation details. [`vyper.compiler.compile_codes`](__init__.py) is the main user-facing function for generating compiler output from Vyper source. The process is as follows: -1. The `@evm_wrapper` decorator sets the target EVM version in -[`opcodes.py`](../evm/opcodes.py). -2. A [`CompilerData`](phases.py) object is created for each contract to be compiled. +1. A [`CompilerData`](phases.py) object is created for each contract to be compiled. This object uses `@property` methods to trigger phases of the compiler as required. -3. Functions in [`output.py`](output.py) generate the requested outputs from the +2. Functions in [`output.py`](output.py) generate the requested outputs from the compiler data. ## Design diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0b3c0d8191..b1c4201361 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -120,17 +120,17 @@ def compile_codes( # make IR output the same between runs codegen.reset_names() - with anchor_evm_version(settings.evm_version): - compiler_data = CompilerData( - source_code, - contract_name, - interfaces, - source_id, - settings, - storage_layout_override, - show_gas_estimates, - no_bytecode_metadata, - ) + compiler_data = CompilerData( + source_code, + contract_name, + interfaces, + source_id, + settings, + storage_layout_override, + show_gas_estimates, + no_bytecode_metadata, + ) + with anchor_evm_version(compiler_data.settings.evm_version): for output_format in output_formats[contract_name]: if output_format not in OUTPUT_FORMATS: raise ValueError(f"Unsupported format type {repr(output_format)}") diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index a1c7342320..5ddf071caf 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -88,9 +88,12 @@ def __init__( self.no_bytecode_metadata = no_bytecode_metadata self.settings = settings or Settings() + _ = self._generate_ast # force settings to be calculated + @cached_property def _generate_ast(self): settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name) + # validate the compiler settings # XXX: this is a bit ugly, clean up later if settings.evm_version is not None: @@ -117,6 +120,8 @@ def _generate_ast(self): if self.settings.optimize is None: self.settings.optimize = OptimizationLevel.default() + # note self.settings.compiler_version is erased here as it is + # not used after pre-parsing return ast @cached_property From e9c16e40dd11ba21ba817ff0da78eaeab744fd39 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 3 Oct 2023 16:36:58 -0700 Subject: [PATCH 107/201] fix: block `mload` merging when src and dst overlap (#3635) this commit fixes an optimization bug when the target architecture has the `mcopy` instruction (i.e. `cancun` or later). the bug was introduced in 5dc3ac7. specifically, the `merge_mload` step can incorrectly merge `mload`/`mstore` sequences (into `mcopy`) when the source and destination buffers overlap, and the destination buffer is "ahead of" (i.e. greater than) the source buffer. this commit fixes the issue by blocking the optimization in these cases, and adds unit and functional tests demonstrating the correct behavior. --------- Co-authored-by: Robert Chen --- tests/compiler/ir/test_optimize_ir.py | 107 +++++++++++++++++++++++ tests/parser/features/test_assignment.py | 60 +++++++++++++ vyper/ir/optimizer.py | 9 +- 3 files changed, 174 insertions(+), 2 deletions(-) diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py index 1466166501..cb46ba238d 100644 --- a/tests/compiler/ir/test_optimize_ir.py +++ b/tests/compiler/ir/test_optimize_ir.py @@ -1,9 +1,13 @@ import pytest from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version from vyper.exceptions import StaticAssertionException from vyper.ir import optimizer +POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]} + + optimize_list = [ (["eq", 1, 2], [0]), (["lt", 1, 2], [1]), @@ -272,3 +276,106 @@ def test_operator_set_values(): assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"} assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"} assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"} + + +mload_merge_list = [ + # copy "backward" with no overlap between src and dst buffers, + # OK to become mcopy + ( + ["seq", ["mstore", 32, ["mload", 128]], ["mstore", 64, ["mload", 160]]], + ["mcopy", 32, 128, 64], + ), + # copy with overlap "backwards", OK to become mcopy + (["seq", ["mstore", 32, ["mload", 64]], ["mstore", 64, ["mload", 96]]], ["mcopy", 32, 64, 64]), + # "stationary" overlap (i.e. a no-op mcopy), OK to become mcopy + (["seq", ["mstore", 32, ["mload", 32]], ["mstore", 64, ["mload", 64]]], ["mcopy", 32, 32, 64]), + # copy "forward" with no overlap, OK to become mcopy + (["seq", ["mstore", 64, ["mload", 0]], ["mstore", 96, ["mload", 32]]], ["mcopy", 64, 0, 64]), + # copy "forwards" with overlap by one word, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 32]], ["mstore", 96, ["mload", 64]]], None), + # check "forward" overlap by one byte, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 1]], ["mstore", 96, ["mload", 33]]], None), + # check "forward" overlap by one byte again, must NOT become mcopy + (["seq", ["mstore", 63, ["mload", 0]], ["mstore", 95, ["mload", 32]]], None), + # copy 3 words with partial overlap "forwards", partially becomes mcopy + # (2 words are mcopied and 1 word is mload/mstored + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mstore", 160, ["mload", 96]]], + ), + # copy 4 words with partial overlap "forwards", becomes 2 mcopies of 2 words each + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ["mstore", 192, ["mload", 128]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mcopy", 160, 96, 64]], + ), + # copy 4 words with 1 byte of overlap, must NOT become mcopy + ( + [ + "seq", + ["mstore", 96, ["mload", 33]], + ["mstore", 128, ["mload", 65]], + ["mstore", 160, ["mload", 97]], + ["mstore", 192, ["mload", 129]], + ], + None, + ), + # Ensure only sequential mstore + mload sequences are optimized + ( + [ + "seq", + ["mstore", 0, ["mload", 32]], + ["sstore", 0, ["calldataload", 4]], + ["mstore", 32, ["mload", 64]], + ], + None, + ), + # not-word aligned optimizations (not overlap) + (["seq", ["mstore", 0, ["mload", 1]], ["mstore", 32, ["mload", 33]]], ["mcopy", 0, 1, 64]), + # not-word aligned optimizations (overlap) + (["seq", ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], None), + # not-word aligned optimizations (overlap and not-overlap) + ( + [ + "seq", + ["mstore", 0, ["mload", 1]], + ["mstore", 32, ["mload", 33]], + ["mstore", 1, ["mload", 0]], + ["mstore", 33, ["mload", 32]], + ], + ["seq", ["mcopy", 0, 1, 64], ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], + ), + # overflow test + ( + [ + "seq", + ["mstore", 2**256 - 1 - 31 - 32, ["mload", 0]], + ["mstore", 2**256 - 1 - 31, ["mload", 32]], + ], + ["mcopy", 2**256 - 1 - 31 - 32, 0, 64], + ), +] + + +@pytest.mark.parametrize("ir", mload_merge_list) +@pytest.mark.parametrize("evm_version", list(POST_CANCUN.keys())) +def test_mload_merge(ir, evm_version): + with anchor_evm_version(evm_version): + optimized = optimizer.optimize(IRnode.from_list(ir[0])) + if ir[1] is None: + # no-op, assert optimizer does nothing + expected = IRnode.from_list(ir[0]) + else: + expected = IRnode.from_list(ir[1]) + + assert optimized == expected diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index e550f60541..35b008a8ba 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -442,3 +442,63 @@ def bug(p: Point) -> Point: """ c = get_contract(code) assert c.bug((1, 2)) == (2, 1) + + +mload_merge_codes = [ + ( + """ +@external +def foo() -> uint256[4]: + # copy "backwards" + xs: uint256[4] = [1, 2, 3, 4] + +# dst < src + xs[0] = xs[1] + xs[1] = xs[2] + xs[2] = xs[3] + + return xs + """, + [2, 3, 4, 4], + ), + ( + """ +@external +def foo() -> uint256[4]: + # copy "forwards" + xs: uint256[4] = [1, 2, 3, 4] + +# src < dst + xs[1] = xs[0] + xs[2] = xs[1] + xs[3] = xs[2] + + return xs + """, + [1, 1, 1, 1], + ), + ( + """ +@external +def foo() -> uint256[5]: + # partial "forward" copy + xs: uint256[5] = [1, 2, 3, 4, 5] + +# src < dst + xs[2] = xs[0] + xs[3] = xs[1] + xs[4] = xs[2] + + return xs + """, + [1, 2, 1, 2, 1], + ), +] + + +# functional test that mload merging does not occur when source and dest +# buffers overlap. (note: mload merging only applies after cancun) +@pytest.mark.parametrize("code,expected_result", mload_merge_codes) +def test_mcopy_overlap(get_contract, code, expected_result): + c = get_contract(code) + assert c.foo() == expected_result diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 08c2168381..8df4bbac2d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -662,10 +662,10 @@ def _rewrite_mstore_dload(argz): def _merge_mload(argz): if not version_check(begin="cancun"): return False - return _merge_load(argz, "mload", "mcopy") + return _merge_load(argz, "mload", "mcopy", allow_overlap=False) -def _merge_load(argz, _LOAD, _COPY): +def _merge_load(argz, _LOAD, _COPY, allow_overlap=True): # look for sequential operations copying from X to Y # and merge them into a single copy operation changed = False @@ -689,9 +689,14 @@ def _merge_load(argz, _LOAD, _COPY): initial_dst_offset = dst_offset initial_src_offset = src_offset idx = i + + # dst and src overlap, discontinue the optimization + has_overlap = initial_src_offset < initial_dst_offset < src_offset + 32 + if ( initial_dst_offset + total_length == dst_offset and initial_src_offset + total_length == src_offset + and (allow_overlap or not has_overlap) ): mstore_nodes.append(ir_node) total_length += 32 From 9136169468f317a53b4e7448389aa315f90b95ba Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 4 Oct 2023 06:16:44 -0700 Subject: [PATCH 108/201] docs: v0.3.10 release (#3629) --- docs/release-notes.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index da86c5c0ce..64199bc860 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -14,14 +14,10 @@ Release Notes for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g -.. - v0.3.10 ("Black Adder") - *********************** - -v0.3.10rc1 -********** +v0.3.10 ("Black Adder") +*********************** -Date released: 2023-09-06 +Date released: 2023-10-04 ========================= v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). @@ -32,6 +28,7 @@ Breaking changes: - add runtime code layout to initcode (`#3584 `_) - drop evm versions through istanbul (`#3470 `_) - remove vyper signature from runtime (`#3471 `_) +- only allow valid identifiers to be nonreentrant keys (`#3605 `_) Non-breaking changes and improvements: -------------------------------------- @@ -46,12 +43,15 @@ Notable fixes: - fix ``ecrecover()`` behavior when signature is invalid (`GHSA-f5x6-7qgp-jhf3 `_, `#3586 `_) - fix: order of evaluation for some builtins (`#3583 `_, `#3587 `_) +- fix: memory allocation in certain builtins using ``msize`` (`#3610 `_) +- fix: ``_abi_decode()`` input validation in certain complex expressions (`#3626 `_) - fix: pycryptodome for arm builds (`#3485 `_) - let params of internal functions be mutable (`#3473 `_) - typechecking of folded builtins in (`#3490 `_) - update tload/tstore opcodes per latest 1153 EIP spec (`#3484 `_) - fix: raw_call type when max_outsize=0 is set (`#3572 `_) - fix: implements check for indexed event arguments (`#3570 `_) +- fix: type-checking for ``_abi_decode()`` arguments (`#3626 `_) Other docs updates, chores and fixes: ------------------------------------- From b8b4610a46379558367ba60f94ac2813eec056d4 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Wed, 4 Oct 2023 20:25:09 +0200 Subject: [PATCH 109/201] chore: update `FUNDING.yml` to point directly at wallet (#3636) Add link to the Vyper multisig address. --- FUNDING.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FUNDING.yml b/FUNDING.yml index 81e82160d0..efb9eb01b7 100644 --- a/FUNDING.yml +++ b/FUNDING.yml @@ -1 +1 @@ -custom: https://gitcoin.co/grants/200/vyper-smart-contract-language-2 +custom: https://etherscan.io/address/0x70CCBE10F980d80b7eBaab7D2E3A73e87D67B775 From 5ca0cbfa4064a5577347bca8377bc9bea1cebca5 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 4 Oct 2023 12:17:36 -0700 Subject: [PATCH 110/201] docs: fix nit in v0.3.10 release notes (#3638) --- docs/release-notes.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 64199bc860..3db11dc451 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -20,7 +20,7 @@ v0.3.10 ("Black Adder") Date released: 2023-10-04 ========================= -v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). +v0.3.10 is a performance focused release that additionally ships numerous bugfixes. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). Breaking changes: ----------------- From 435754dd1db3e1b0b4608569e003f060e6e9eb40 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Thu, 5 Oct 2023 16:54:29 +0200 Subject: [PATCH 111/201] docs: add note on `pragma` parsing (#3640) * Fix broken link in `structure-of-a-contract.rst` * Add note on `pragma` version parsing --- docs/structure-of-a-contract.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index d2c5d48d96..3861bf4380 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -17,7 +17,7 @@ Vyper supports several source code directives to control compiler modes and help Version Pragma -------------- -The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers _`. +The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers `_. As of 0.3.10, the recommended way to specify the version pragma is as follows: @@ -25,6 +25,10 @@ As of 0.3.10, the recommended way to specify the version pragma is as follows: #pragma version ^0.3.0 +.. note:: + + Both pragma directive versions ``#pragma`` and ``# pragma`` are supported. + The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version: .. code-block:: python From 68da04b2e9e010c2e4da288a80eeeb9c8e076025 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 5 Oct 2023 09:04:18 -0700 Subject: [PATCH 112/201] fix: block memory allocation overflow (#3639) this fixes potential overflow bugs in pointer calculation by blocking memory allocation above a certain size. the size limit is set at `2**64`, which is the size of addressable memory on physical machines. practically, for EVM use cases, we could limit at a much smaller number (like `2**24`), but we want to allow for "exotic" targets which may allow much more addressable memory. --- vyper/codegen/memory_allocator.py | 12 +++++++++++- vyper/exceptions.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/vyper/codegen/memory_allocator.py b/vyper/codegen/memory_allocator.py index 582d4b9c54..b5e1212917 100644 --- a/vyper/codegen/memory_allocator.py +++ b/vyper/codegen/memory_allocator.py @@ -1,6 +1,6 @@ from typing import List -from vyper.exceptions import CompilerPanic +from vyper.exceptions import CompilerPanic, MemoryAllocationException from vyper.utils import MemoryPositions @@ -46,6 +46,8 @@ class MemoryAllocator: next_mem: int + _ALLOCATION_LIMIT: int = 2**64 + def __init__(self, start_position: int = MemoryPositions.RESERVED_MEMORY): """ Initializer. @@ -110,6 +112,14 @@ def _expand_memory(self, size: int) -> int: before_value = self.next_mem self.next_mem += size self.size_of_mem = max(self.size_of_mem, self.next_mem) + + if self.size_of_mem >= self._ALLOCATION_LIMIT: + # this should not be caught + raise MemoryAllocationException( + f"Tried to allocate {self.size_of_mem} bytes! " + f"(limit is {self._ALLOCATION_LIMIT} (2**64) bytes)" + ) + return before_value def deallocate_memory(self, pos: int, size: int) -> None: diff --git a/vyper/exceptions.py b/vyper/exceptions.py index defca7cc53..8b2020285a 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -269,6 +269,10 @@ class StorageLayoutException(VyperException): """Invalid slot for the storage layout overrides""" +class MemoryAllocationException(VyperException): + """Tried to allocate too much memory""" + + class JSONError(Exception): """Invalid compiler input JSON.""" From 74a8e0254461119af9a5d504f326877d7aed4134 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 17 Oct 2023 09:23:40 -0700 Subject: [PATCH 113/201] chore: reorder compilation of branches in stmt.py (#3603) the if and else branches were being allocated out-of-source-order. this commit switches the order of compilation of the if and else branches to be in source order. this is a hygienic fix, right now the only thing that should be affected is the memory allocator (but if more side effects are ever introduced in codegen, the existing code might compile the side effects out of order). --- vyper/codegen/stmt.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index c2951986c8..254cad32e6 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -91,17 +91,15 @@ def parse_Assign(self): return IRnode.from_list(ret) def parse_If(self): + with self.context.block_scope(): + test_expr = Expr.parse_value_expr(self.stmt.test, self.context) + body = ["if", test_expr, parse_body(self.stmt.body, self.context)] + if self.stmt.orelse: with self.context.block_scope(): - add_on = [parse_body(self.stmt.orelse, self.context)] - else: - add_on = [] + body.extend([parse_body(self.stmt.orelse, self.context)]) - with self.context.block_scope(): - test_expr = Expr.parse_value_expr(self.stmt.test, self.context) - body = ["if", test_expr, parse_body(self.stmt.body, self.context)] + add_on - ir_node = IRnode.from_list(body) - return ir_node + return IRnode.from_list(body) def parse_Log(self): event = self.stmt._metadata["type"] From 5482bbcbed22b856bec6e57c06aeb7e0bee9a1ab Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 18 Oct 2023 00:26:56 +0800 Subject: [PATCH 114/201] feat: remove builtin constants (#3350) Builtin constants like `MAX_INT128`, `MIN_INT128`, `MIN_DECIMAL`, `MAX_DECIMAL` and `MAX_UINT256` have been deprecated since v0.3.4 with the introduction of the `min_value` and `max_value` builtins, and further back for `ZERO_ADDRESS` and `EMPTY_BYTES32` with the `empty` builtin. This PR removes them from the language entirely, and will be a breaking change. --- tests/ast/test_folding.py | 43 ------------------- .../features/decorators/test_private.py | 2 +- .../test_external_contract_calls.py | 14 +++--- .../features/iteration/test_for_in_list.py | 2 +- tests/parser/features/test_assignment.py | 4 +- tests/parser/features/test_memory_dealloc.py | 2 +- tests/parser/functions/test_abi_decode.py | 4 +- tests/parser/functions/test_convert.py | 18 -------- .../parser/functions/test_default_function.py | 2 +- tests/parser/functions/test_empty.py | 20 ++++----- tests/parser/functions/test_return_tuple.py | 4 +- tests/parser/integration/test_escrow.py | 4 +- tests/parser/syntax/test_bool.py | 4 +- tests/parser/syntax/test_interfaces.py | 4 +- tests/parser/syntax/test_no_none.py | 12 +++--- tests/parser/syntax/test_tuple_assign.py | 2 +- tests/parser/syntax/test_unbalanced_return.py | 4 +- tests/parser/types/numbers/test_constants.py | 8 ++-- tests/parser/types/test_identifier_naming.py | 8 +--- tests/parser/types/test_string.py | 2 +- vyper/ast/folding.py | 33 -------------- vyper/compiler/phases.py | 1 - 22 files changed, 49 insertions(+), 148 deletions(-) diff --git a/tests/ast/test_folding.py b/tests/ast/test_folding.py index 22d5f58222..62a7140e97 100644 --- a/tests/ast/test_folding.py +++ b/tests/ast/test_folding.py @@ -132,49 +132,6 @@ def test_replace_constant_no(source): assert vy_ast.compare_nodes(unmodified_ast, folded_ast) -builtins_modified = [ - "ZERO_ADDRESS", - "foo = ZERO_ADDRESS", - "foo: int128[ZERO_ADDRESS] = 42", - "foo = [ZERO_ADDRESS]", - "def foo(bar: address = ZERO_ADDRESS): pass", - "def foo(): bar = ZERO_ADDRESS", - "def foo(): return ZERO_ADDRESS", - "log foo(ZERO_ADDRESS)", - "log foo(42, ZERO_ADDRESS)", -] - - -@pytest.mark.parametrize("source", builtins_modified) -def test_replace_builtin_constant(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_builtin_constants(folded_ast) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -builtins_unmodified = [ - "ZERO_ADDRESS = 2", - "ZERO_ADDRESS()", - "def foo(ZERO_ADDRESS: int128 = 42): pass", - "def foo(): ZERO_ADDRESS = 42", - "def ZERO_ADDRESS(): pass", - "log ZERO_ADDRESS(42)", -] - - -@pytest.mark.parametrize("source", builtins_unmodified) -def test_replace_builtin_constant_no(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_builtin_constants(folded_ast) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - userdefined_modified = [ "FOO", "foo = FOO", diff --git a/tests/parser/features/decorators/test_private.py b/tests/parser/features/decorators/test_private.py index 7c92f72af9..51e6d90ee1 100644 --- a/tests/parser/features/decorators/test_private.py +++ b/tests/parser/features/decorators/test_private.py @@ -304,7 +304,7 @@ def test(a: bytes32) -> (bytes32, uint256, int128): b: uint256 = 1 c: int128 = 1 d: int128 = 123 - f: bytes32 = EMPTY_BYTES32 + f: bytes32 = empty(bytes32) f, b, c = self._test(a) assert d == 123 return f, b, c diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index b3cc6f5576..12fcde2f4f 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -775,9 +775,9 @@ def foo() -> (address, Bytes[3], address): view @external def bar(arg1: address) -> (address, Bytes[3], address): - a: address = ZERO_ADDRESS + a: address = empty(address) b: Bytes[3] = b"" - c: address = ZERO_ADDRESS + c: address = empty(address) a, b, c = Foo(arg1).foo() return a, b, c """ @@ -808,9 +808,9 @@ def foo() -> (address, Bytes[3], address): view @external def bar(arg1: address) -> (address, Bytes[3], address): - a: address = ZERO_ADDRESS + a: address = empty(address) b: Bytes[3] = b"" - c: address = ZERO_ADDRESS + c: address = empty(address) a, b, c = Foo(arg1).foo() return a, b, c """ @@ -841,9 +841,9 @@ def foo() -> (address, Bytes[3], address): view @external def bar(arg1: address) -> (address, Bytes[3], address): - a: address = ZERO_ADDRESS + a: address = empty(address) b: Bytes[3] = b"" - c: address = ZERO_ADDRESS + c: address = empty(address) a, b, c = Foo(arg1).foo() return a, b, c """ @@ -1538,7 +1538,7 @@ def out_literals() -> (int128, address, Bytes[10]) : view @external def test(addr: address) -> (int128, address, Bytes[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) c: Bytes[10] = b"" (a, b, c) = Test(addr).out_literals() return a, b,c diff --git a/tests/parser/features/iteration/test_for_in_list.py b/tests/parser/features/iteration/test_for_in_list.py index bfd960a787..fb01cc98eb 100644 --- a/tests/parser/features/iteration/test_for_in_list.py +++ b/tests/parser/features/iteration/test_for_in_list.py @@ -230,7 +230,7 @@ def iterate_return_second() -> address: count += 1 if count == 2: return i - return ZERO_ADDRESS + return empty(address) """ c = get_contract_with_gas_estimation(code) diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 35b008a8ba..cd26659a5c 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -331,7 +331,7 @@ def foo(): @external def foo(): y: int128 = 1 - z: bytes32 = EMPTY_BYTES32 + z: bytes32 = empty(bytes32) z = y """, """ @@ -344,7 +344,7 @@ def foo(): @external def foo(): y: uint256 = 1 - z: bytes32 = EMPTY_BYTES32 + z: bytes32 = empty(bytes32) z = y """, ], diff --git a/tests/parser/features/test_memory_dealloc.py b/tests/parser/features/test_memory_dealloc.py index de82f03296..814bf0d3bb 100644 --- a/tests/parser/features/test_memory_dealloc.py +++ b/tests/parser/features/test_memory_dealloc.py @@ -9,7 +9,7 @@ def sendit(): nonpayable @external def foo(target: address) -> uint256[2]: - log Shimmy(ZERO_ADDRESS, 3) + log Shimmy(empty(address), 3) amount: uint256 = 1 flargen: uint256 = 42 Other(target).sendit() diff --git a/tests/parser/functions/test_abi_decode.py b/tests/parser/functions/test_abi_decode.py index 2216a5bd76..242841e1cf 100644 --- a/tests/parser/functions/test_abi_decode.py +++ b/tests/parser/functions/test_abi_decode.py @@ -25,7 +25,7 @@ def test_abi_decode_complex(get_contract): @external def abi_decode(x: Bytes[160]) -> (address, int128, bool, decimal, bytes32): - a: address = ZERO_ADDRESS + a: address = empty(address) b: int128 = 0 c: bool = False d: decimal = 0.0 @@ -39,7 +39,7 @@ def abi_decode_struct(x: Bytes[544]) -> Human: name: "", pet: Animal({ name: "", - address_: ZERO_ADDRESS, + address_: empty(address), id_: 0, is_furry: False, price: 0.0, diff --git a/tests/parser/functions/test_convert.py b/tests/parser/functions/test_convert.py index eb8449447c..b5ce613235 100644 --- a/tests/parser/functions/test_convert.py +++ b/tests/parser/functions/test_convert.py @@ -534,24 +534,6 @@ def foo(a: {typ}) -> Status: assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch) -# TODO CMC 2022-04-06 I think this test is somewhat unnecessary. -@pytest.mark.parametrize( - "builtin_constant,out_type,out_value", - [("ZERO_ADDRESS", "bool", False), ("msg.sender", "bool", True)], -) -def test_convert_builtin_constant( - get_contract_with_gas_estimation, builtin_constant, out_type, out_value -): - contract = f""" -@external -def convert_builtin_constant() -> {out_type}: - return convert({builtin_constant}, {out_type}) - """ - - c = get_contract_with_gas_estimation(contract) - assert c.convert_builtin_constant() == out_value - - # uint256 conversion is currently valid due to type inference on literals # not quite working yet same_type_conversion_blocked = sorted(TEST_TYPES - {UINT256_T}) diff --git a/tests/parser/functions/test_default_function.py b/tests/parser/functions/test_default_function.py index 4aa0b04a77..4ad68697ac 100644 --- a/tests/parser/functions/test_default_function.py +++ b/tests/parser/functions/test_default_function.py @@ -41,7 +41,7 @@ def test_basic_default_default_param_function(w3, get_logs, get_contract_with_ga @external @payable def fooBar(a: int128 = 12345) -> int128: - log Sent(ZERO_ADDRESS) + log Sent(empty(address)) return a @external diff --git a/tests/parser/functions/test_empty.py b/tests/parser/functions/test_empty.py index c10d03550a..c3627785dc 100644 --- a/tests/parser/functions/test_empty.py +++ b/tests/parser/functions/test_empty.py @@ -87,8 +87,8 @@ def foo(): self.foobar = empty(address) bar = empty(address) - assert self.foobar == ZERO_ADDRESS - assert bar == ZERO_ADDRESS + assert self.foobar == empty(address) + assert bar == empty(address) """, """ @external @@ -214,12 +214,12 @@ def foo(): self.foobar = empty(address[3]) bar = empty(address[3]) - assert self.foobar[0] == ZERO_ADDRESS - assert self.foobar[1] == ZERO_ADDRESS - assert self.foobar[2] == ZERO_ADDRESS - assert bar[0] == ZERO_ADDRESS - assert bar[1] == ZERO_ADDRESS - assert bar[2] == ZERO_ADDRESS + assert self.foobar[0] == empty(address) + assert self.foobar[1] == empty(address) + assert self.foobar[2] == empty(address) + assert bar[0] == empty(address) + assert bar[1] == empty(address) + assert bar[2] == empty(address) """, ], ) @@ -376,14 +376,14 @@ def foo(): assert self.foobar.c == False assert self.foobar.d == 0.0 assert self.foobar.e == 0x0000000000000000000000000000000000000000000000000000000000000000 - assert self.foobar.f == ZERO_ADDRESS + assert self.foobar.f == empty(address) assert bar.a == 0 assert bar.b == 0 assert bar.c == False assert bar.d == 0.0 assert bar.e == 0x0000000000000000000000000000000000000000000000000000000000000000 - assert bar.f == ZERO_ADDRESS + assert bar.f == empty(address) """ c = get_contract_with_gas_estimation(code) diff --git a/tests/parser/functions/test_return_tuple.py b/tests/parser/functions/test_return_tuple.py index 87b7cdcde3..b375839147 100644 --- a/tests/parser/functions/test_return_tuple.py +++ b/tests/parser/functions/test_return_tuple.py @@ -99,7 +99,7 @@ def out_literals() -> (int128, address, Bytes[10]): @external def test() -> (int128, address, Bytes[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) c: Bytes[10] = b"" (a, b, c) = self._out_literals() return a, b, c @@ -138,7 +138,7 @@ def test2() -> (int128, address): @external def test3() -> (address, int128): - x: address = ZERO_ADDRESS + x: address = empty(address) self.a, self.c, x, self.d = self._out_literals() return x, self.a """ diff --git a/tests/parser/integration/test_escrow.py b/tests/parser/integration/test_escrow.py index 2982ff9eae..1578f5a418 100644 --- a/tests/parser/integration/test_escrow.py +++ b/tests/parser/integration/test_escrow.py @@ -9,7 +9,7 @@ def test_arbitration_code(w3, get_contract_with_gas_estimation, assert_tx_failed @external def setup(_seller: address, _arbitrator: address): - if self.buyer == ZERO_ADDRESS: + if self.buyer == empty(address): self.buyer = msg.sender self.seller = _seller self.arbitrator = _arbitrator @@ -43,7 +43,7 @@ def test_arbitration_code_with_init(w3, assert_tx_failed, get_contract_with_gas_ @external @payable def __init__(_seller: address, _arbitrator: address): - if self.buyer == ZERO_ADDRESS: + if self.buyer == empty(address): self.buyer = msg.sender self.seller = _seller self.arbitrator = _arbitrator diff --git a/tests/parser/syntax/test_bool.py b/tests/parser/syntax/test_bool.py index 09f799d91c..48ed37321a 100644 --- a/tests/parser/syntax/test_bool.py +++ b/tests/parser/syntax/test_bool.py @@ -52,7 +52,7 @@ def foo() -> bool: """ @external def foo() -> bool: - a: address = ZERO_ADDRESS + a: address = empty(address) return a == 1 """, ( @@ -137,7 +137,7 @@ def foo() -> bool: """ @external def foo2(a: address) -> bool: - return a != ZERO_ADDRESS + return a != empty(address) """, ] diff --git a/tests/parser/syntax/test_interfaces.py b/tests/parser/syntax/test_interfaces.py index 5afb34e6bd..498f1363d8 100644 --- a/tests/parser/syntax/test_interfaces.py +++ b/tests/parser/syntax/test_interfaces.py @@ -47,7 +47,7 @@ def test(): @external def test(): - a: address(ERC20) = ZERO_ADDRESS + a: address(ERC20) = empty(address) """, InvalidType, ), @@ -306,7 +306,7 @@ def some_func(): nonpayable @external def __init__(): - self.my_interface[self.idx] = MyInterface(ZERO_ADDRESS) + self.my_interface[self.idx] = MyInterface(empty(address)) """, """ interface MyInterface: diff --git a/tests/parser/syntax/test_no_none.py b/tests/parser/syntax/test_no_none.py index 7030a56b18..24c32a46a4 100644 --- a/tests/parser/syntax/test_no_none.py +++ b/tests/parser/syntax/test_no_none.py @@ -30,13 +30,13 @@ def foo(): """ @external def foo(): - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) bar = None """, """ @external def foo(): - bar: address = ZERO_ADDRESS + bar: address = empty(address) bar = None """, """ @@ -104,13 +104,13 @@ def foo(): """ @external def foo(): - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) assert bar is None """, """ @external def foo(): - bar: address = ZERO_ADDRESS + bar: address = empty(address) assert bar is None """, ] @@ -148,13 +148,13 @@ def foo(): """ @external def foo(): - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) assert bar == None """, """ @external def foo(): - bar: address = ZERO_ADDRESS + bar: address = empty(address) assert bar == None """, ] diff --git a/tests/parser/syntax/test_tuple_assign.py b/tests/parser/syntax/test_tuple_assign.py index 115499ce8b..49b63ee614 100644 --- a/tests/parser/syntax/test_tuple_assign.py +++ b/tests/parser/syntax/test_tuple_assign.py @@ -41,7 +41,7 @@ def out_literals() -> (int128, int128, Bytes[10]): @external def test() -> (int128, address, Bytes[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) a, b = self.out_literals() # tuple count mismatch return """, diff --git a/tests/parser/syntax/test_unbalanced_return.py b/tests/parser/syntax/test_unbalanced_return.py index 5337b4b677..d1d9732777 100644 --- a/tests/parser/syntax/test_unbalanced_return.py +++ b/tests/parser/syntax/test_unbalanced_return.py @@ -56,7 +56,7 @@ def valid_address(sender: address) -> bool: """ @internal def valid_address(sender: address) -> bool: - if sender == ZERO_ADDRESS: + if sender == empty(address): selfdestruct(sender) _sender: address = sender else: @@ -144,7 +144,7 @@ def test() -> int128: """ @external def test() -> int128: - x: bytes32 = EMPTY_BYTES32 + x: bytes32 = empty(bytes32) if False: if False: return 0 diff --git a/tests/parser/types/numbers/test_constants.py b/tests/parser/types/numbers/test_constants.py index 0d5e386dad..652c8e8bd9 100644 --- a/tests/parser/types/numbers/test_constants.py +++ b/tests/parser/types/numbers/test_constants.py @@ -12,12 +12,12 @@ def test_builtin_constants(get_contract_with_gas_estimation): code = """ @external def test_zaddress(a: address) -> bool: - return a == ZERO_ADDRESS + return a == empty(address) @external def test_empty_bytes32(a: bytes32) -> bool: - return a == EMPTY_BYTES32 + return a == empty(bytes32) @external @@ -81,12 +81,12 @@ def goo() -> int128: @external def hoo() -> bytes32: - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) return bar @external def joo() -> address: - bar: address = ZERO_ADDRESS + bar: address = empty(address) return bar @external diff --git a/tests/parser/types/test_identifier_naming.py b/tests/parser/types/test_identifier_naming.py index 5cfc7e8ed7..0a93329848 100755 --- a/tests/parser/types/test_identifier_naming.py +++ b/tests/parser/types/test_identifier_naming.py @@ -1,16 +1,12 @@ import pytest -from vyper.ast.folding import BUILTIN_CONSTANTS from vyper.ast.identifiers import RESERVED_KEYWORDS from vyper.builtins.functions import BUILTIN_FUNCTIONS from vyper.codegen.expr import ENVIRONMENT_VARIABLES from vyper.exceptions import NamespaceCollision, StructureException, SyntaxException from vyper.semantics.types.primitives import AddressT -BUILTIN_CONSTANTS = set(BUILTIN_CONSTANTS.keys()) -ALL_RESERVED_KEYWORDS = ( - BUILTIN_CONSTANTS | BUILTIN_FUNCTIONS | RESERVED_KEYWORDS | ENVIRONMENT_VARIABLES -) +ALL_RESERVED_KEYWORDS = BUILTIN_FUNCTIONS | RESERVED_KEYWORDS | ENVIRONMENT_VARIABLES @pytest.mark.parametrize("constant", sorted(ALL_RESERVED_KEYWORDS)) @@ -46,7 +42,7 @@ def test({constant}: int128): SELF_NAMESPACE_MEMBERS = set(AddressT._type_members.keys()) -DISALLOWED_FN_NAMES = SELF_NAMESPACE_MEMBERS | RESERVED_KEYWORDS | BUILTIN_CONSTANTS +DISALLOWED_FN_NAMES = SELF_NAMESPACE_MEMBERS | RESERVED_KEYWORDS ALLOWED_FN_NAMES = ALL_RESERVED_KEYWORDS - DISALLOWED_FN_NAMES diff --git a/tests/parser/types/test_string.py b/tests/parser/types/test_string.py index a5eef66dae..7f1fa71329 100644 --- a/tests/parser/types/test_string.py +++ b/tests/parser/types/test_string.py @@ -139,7 +139,7 @@ def out_literals() -> (int128, address, String[10]) : view @external def test(addr: address) -> (int128, address, String[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) c: String[10] = "" (a, b, c) = Test(addr).out_literals() return a, b,c diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index fbd1dfc2f4..38d58f6fd0 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -1,4 +1,3 @@ -import warnings from typing import Optional, Union from vyper.ast import nodes as vy_ast @@ -6,21 +5,6 @@ from vyper.exceptions import UnfoldableNode, UnknownType from vyper.semantics.types.base import VyperType from vyper.semantics.types.utils import type_from_annotation -from vyper.utils import SizeLimits - -BUILTIN_CONSTANTS = { - "EMPTY_BYTES32": ( - vy_ast.Hex, - "0x0000000000000000000000000000000000000000000000000000000000000000", - "empty(bytes32)", - ), # NOQA: E501 - "ZERO_ADDRESS": (vy_ast.Hex, "0x0000000000000000000000000000000000000000", "empty(address)"), - "MAX_INT128": (vy_ast.Int, 2**127 - 1, "max_value(int128)"), - "MIN_INT128": (vy_ast.Int, -(2**127), "min_value(int128)"), - "MAX_DECIMAL": (vy_ast.Decimal, SizeLimits.MAX_AST_DECIMAL, "max_value(decimal)"), - "MIN_DECIMAL": (vy_ast.Decimal, SizeLimits.MIN_AST_DECIMAL, "min_value(decimal)"), - "MAX_UINT256": (vy_ast.Int, 2**256 - 1, "max_value(uint256)"), -} def fold(vyper_module: vy_ast.Module) -> None: @@ -32,8 +16,6 @@ def fold(vyper_module: vy_ast.Module) -> None: vyper_module : Module Top-level Vyper AST node. """ - replace_builtin_constants(vyper_module) - changed_nodes = 1 while changed_nodes: changed_nodes = 0 @@ -138,21 +120,6 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: return changed_nodes -def replace_builtin_constants(vyper_module: vy_ast.Module) -> None: - """ - Replace references to builtin constants with their literal values. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - for name, (node, value, replacement) in BUILTIN_CONSTANTS.items(): - found = replace_constant(vyper_module, name, node(value=value), True) - if found > 0: - warnings.warn(f"{name} is deprecated. Please use `{replacement}` instead.") - - def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: """ Find user-defined constant assignments, and replace references diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5ddf071caf..72be4396e4 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -235,7 +235,6 @@ def generate_unfolded_ast( vyper_module: vy_ast.Module, interface_codes: Optional[InterfaceImports] ) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) - vy_ast.folding.replace_builtin_constants(vyper_module) vy_ast.folding.replace_builtin_functions(vyper_module) # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, interface_codes) From 3ba14124602b673d45b86bae7ff90a01d782acb5 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:26:40 +0800 Subject: [PATCH 115/201] refactor: merge `annotation.py` and `local.py` (#3456) This commit merges two phases of analysis: typechecking (`vyper/semantics/analysis/local.py`) and annotation of ast nodes with types (`vyper/semantics/analysis/annotation.py`). This is both for consistency with how it is done for `analysis/module.py`, and also because it increases internal consistency, as some typechecking was being done in `annotation.py`. It is also easier to maintain this way, since bugfixes or modifications to typechecking need to only be done in one place, rather than two passes. Lastly, it also probably improves performance, because it collapses two passes into one (and calls `get_*_type_from_node` less often). This commit also fixes a bug with accessing the iterator when the iterator is an empty list literal. --- tests/parser/syntax/test_list.py | 3 +- vyper/ast/nodes.pyi | 5 + vyper/builtins/_signatures.py | 2 +- vyper/builtins/functions.py | 6 + vyper/semantics/analysis/annotation.py | 283 ------------ vyper/semantics/analysis/common.py | 19 +- vyper/semantics/analysis/local.py | 607 ++++++++++++++++--------- vyper/semantics/analysis/utils.py | 13 +- 8 files changed, 420 insertions(+), 518 deletions(-) delete mode 100644 vyper/semantics/analysis/annotation.py diff --git a/tests/parser/syntax/test_list.py b/tests/parser/syntax/test_list.py index 3f81b911c8..db41de5526 100644 --- a/tests/parser/syntax/test_list.py +++ b/tests/parser/syntax/test_list.py @@ -305,8 +305,9 @@ def foo(): """ @external def foo(): + x: DynArray[uint256, 3] = [1, 2, 3] for i in [[], []]: - pass + x = i """, ] diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 0d59a2fa63..47c9af8526 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -142,6 +142,7 @@ class Expr(VyperNode): class UnaryOp(ExprNode): op: VyperNode = ... + operand: VyperNode = ... class USub(VyperNode): ... class Not(VyperNode): ... @@ -165,12 +166,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... + values: list[VyperNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... + left: VyperNode = ... + right: VyperNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -179,6 +183,7 @@ class LtE(VyperNode): ... class Gt(VyperNode): ... class GtE(VyperNode): ... class In(VyperNode): ... +class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d39a4a085f..2802421129 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -74,7 +74,7 @@ def decorator_fn(self, node, context): return decorator_fn -class BuiltinFunction: +class BuiltinFunction(VyperType): _has_varargs = False _kwargs: Dict[str, KwargSettings] = {} diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f07202831d..001939638b 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -475,6 +475,12 @@ def evaluate(self, node): return vy_ast.Int.from_node(node, value=length) + def infer_arg_types(self, node): + self._validate_arg_types(node) + # return a concrete type + typ = get_possible_types_from_node(node.args[0]).pop() + return [typ] + def build_IR(self, node, context): arg = Expr(node.args[0], context).ir_node if arg.value == "~calldata": diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py deleted file mode 100644 index 01ca51d7f4..0000000000 --- a/vyper/semantics/analysis/annotation.py +++ /dev/null @@ -1,283 +0,0 @@ -from vyper import ast as vy_ast -from vyper.exceptions import StructureException, TypeCheckFailure -from vyper.semantics.analysis.utils import ( - get_common_types, - get_exact_type_from_node, - get_possible_types_from_node, -) -from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t -from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT - - -class _AnnotationVisitorBase: - - """ - Annotation visitor base class. - - Annotation visitors apply metadata (such as type information) to vyper AST objects. - Immediately after type checking a statement-level node, that node is passed to - `StatementAnnotationVisitor`. Some expression nodes are then passed onward to - `ExpressionAnnotationVisitor` for additional annotation. - """ - - def visit(self, node, *args): - if isinstance(node, self.ignored_types): - return - # iterate over the MRO until we find a matching visitor function - # this lets us use a single function to broadly target several - # node types with a shared parent - for class_ in node.__class__.mro(): - ast_type = class_.__name__ - visitor_fn = getattr(self, f"visit_{ast_type}", None) - if visitor_fn: - visitor_fn(node, *args) - return - raise StructureException(f"Cannot annotate: {node.ast_type}", node) - - -class StatementAnnotationVisitor(_AnnotationVisitorBase): - ignored_types = (vy_ast.Break, vy_ast.Continue, vy_ast.Pass, vy_ast.Raise) - - def __init__(self, fn_node: vy_ast.FunctionDef, namespace: dict) -> None: - self.func = fn_node._metadata["type"] - self.namespace = namespace - self.expr_visitor = ExpressionAnnotationVisitor(self.func) - - assert self.func.n_keyword_args == len(fn_node.args.defaults) - for kwarg in self.func.keyword_args: - self.expr_visitor.visit(kwarg.default_value, kwarg.typ) - - def visit(self, node): - super().visit(node) - - def visit_AnnAssign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_Assert(self, node): - self.expr_visitor.visit(node.test) - - def visit_Assign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_AugAssign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_Expr(self, node): - self.expr_visitor.visit(node.value) - - def visit_If(self, node): - self.expr_visitor.visit(node.test) - - def visit_Log(self, node): - node._metadata["type"] = self.namespace[node.value.func.id] - self.expr_visitor.visit(node.value) - - def visit_Return(self, node): - if node.value is not None: - self.expr_visitor.visit(node.value, self.func.return_type) - - def visit_For(self, node): - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - self.expr_visitor.visit(node.iter) - - iter_type = node.target._metadata["type"] - if isinstance(node.iter, vy_ast.List): - # typecheck list literal as static array - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) - - -class ExpressionAnnotationVisitor(_AnnotationVisitorBase): - ignored_types = () - - def __init__(self, fn_node: ContractFunctionT): - self.func = fn_node - - def visit(self, node, type_=None): - # the statement visitor sometimes passes type information about expressions - super().visit(node, type_) - - def visit_Attribute(self, node, type_): - type_ = get_exact_type_from_node(node) - node._metadata["type"] = type_ - self.visit(node.value, None) - - def visit_BinOp(self, node, type_): - if type_ is None: - type_ = get_common_types(node.left, node.right) - if len(type_) == 1: - type_ = type_.pop() - node._metadata["type"] = type_ - - self.visit(node.left, type_) - self.visit(node.right, type_) - - def visit_BoolOp(self, node, type_): - for value in node.values: - self.visit(value) - - def visit_Call(self, node, type_): - call_type = get_exact_type_from_node(node.func) - node_type = type_ or call_type.fetch_call_return(node) - node._metadata["type"] = node_type - self.visit(node.func) - - if isinstance(call_type, ContractFunctionT): - # function calls - if call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): - self.visit(arg, typ) - for kwarg in node.keywords: - # We should only see special kwargs - self.visit(kwarg.value, call_type.call_site_kwargs[kwarg.arg].typ) - - elif is_type_t(call_type, EventT): - # events have no kwargs - for arg, typ in zip(node.args, list(call_type.typedef.arguments.values())): - self.visit(arg, typ) - elif is_type_t(call_type, StructT): - # struct ctors - # ctors have no kwargs - for value, arg_type in zip( - node.args[0].values, list(call_type.typedef.members.values()) - ): - self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): - self.visit(arg, arg_type) - else: - # builtin functions - arg_types = call_type.infer_arg_types(node) - for arg, arg_type in zip(node.args, arg_types): - self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) - for kwarg in node.keywords: - self.visit(kwarg.value, kwarg_types[kwarg.arg]) - - def visit_Compare(self, node, type_): - if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): - if isinstance(node.right, vy_ast.List): - type_ = get_common_types(node.left, *node.right.elements).pop() - self.visit(node.left, type_) - rlen = len(node.right.elements) - self.visit(node.right, SArrayT(type_, rlen)) - else: - type_ = get_exact_type_from_node(node.right) - self.visit(node.right, type_) - if isinstance(type_, EnumT): - self.visit(node.left, type_) - else: - # array membership - self.visit(node.left, type_.value_type) - else: - type_ = get_common_types(node.left, node.right).pop() - self.visit(node.left, type_) - self.visit(node.right, type_) - - def visit_Constant(self, node, type_): - if type_ is None: - possible_types = get_possible_types_from_node(node) - if len(possible_types) == 1: - type_ = possible_types.pop() - node._metadata["type"] = type_ - - def visit_Dict(self, node, type_): - node._metadata["type"] = type_ - - def visit_Index(self, node, type_): - self.visit(node.value, type_) - - def visit_List(self, node, type_): - if type_ is None: - type_ = get_possible_types_from_node(node) - # CMC 2022-04-14 this seems sus. try to only annotate - # if get_possible_types only returns 1 type - if len(type_) >= 1: - type_ = type_.pop() - node._metadata["type"] = type_ - for element in node.elements: - self.visit(element, type_.value_type) - - def visit_Name(self, node, type_): - if isinstance(type_, TYPE_T): - node._metadata["type"] = type_ - else: - node._metadata["type"] = get_exact_type_from_node(node) - - def visit_Subscript(self, node, type_): - node._metadata["type"] = type_ - - if isinstance(type_, TYPE_T): - # don't recurse; can't annotate AST children of type definition - return - - if isinstance(node.value, vy_ast.List): - possible_base_types = get_possible_types_from_node(node.value) - - if len(possible_base_types) == 1: - base_type = possible_base_types.pop() - - elif type_ is not None and len(possible_base_types) > 1: - for possible_type in possible_base_types: - if type_.compare_type(possible_type.value_type): - base_type = possible_type - break - else: - # this should have been caught in - # `get_possible_types_from_node` but wasn't. - raise TypeCheckFailure(f"Expected {type_} but it is not a possible type", node) - - else: - base_type = get_exact_type_from_node(node.value) - - # get the correct type for the index, it might - # not be base_type.key_type - index_types = get_possible_types_from_node(node.slice.value) - index_type = index_types.pop() - - self.visit(node.slice, index_type) - self.visit(node.value, base_type) - - def visit_Tuple(self, node, type_): - node._metadata["type"] = type_ - - if isinstance(type_, TYPE_T): - # don't recurse; can't annotate AST children of type definition - return - - for element, subtype in zip(node.elements, type_.member_types): - self.visit(element, subtype) - - def visit_UnaryOp(self, node, type_): - if type_ is None: - type_ = get_possible_types_from_node(node.operand) - if len(type_) == 1: - type_ = type_.pop() - node._metadata["type"] = type_ - self.visit(node.operand, type_) - - def visit_IfExp(self, node, type_): - if type_ is None: - ts = get_common_types(node.body, node.orelse) - if len(ts) == 1: - type_ = ts.pop() - - node._metadata["type"] = type_ - self.visit(node.test, BoolT()) - self.visit(node.body, type_) - self.visit(node.orelse, type_) diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 193d1892e1..507eb0a570 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -10,10 +10,17 @@ class VyperNodeVisitorBase: def visit(self, node, *args): if isinstance(node, self.ignored_types): return + + # iterate over the MRO until we find a matching visitor function + # this lets us use a single function to broadly target several + # node types with a shared parent + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + return visitor_fn(node, *args) + node_type = type(node).__name__ - visitor_fn = getattr(self, f"visit_{node_type}", None) - if visitor_fn is None: - raise StructureException( - f"Unsupported syntax for {self.scope_name} namespace: {node_type}", node - ) - visitor_fn(node, *args) + raise StructureException( + f"Unsupported syntax for {self.scope_name} namespace: {node_type}", node + ) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b391b33953..647f01c299 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -14,11 +14,11 @@ NonPayableViolation, StateAccessViolation, StructureException, + TypeCheckFailure, TypeMismatch, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.annotation import StatementAnnotationVisitor from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( @@ -34,9 +34,11 @@ from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( + TYPE_T, AddressT, BoolT, DArrayT, + EnumT, EventT, HashMapT, IntegerT, @@ -44,6 +46,8 @@ StringT, StructT, TupleT, + VyperType, + _BytestringT, is_type_t, ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability @@ -117,20 +121,8 @@ def _check_iterator_modification( return None -def _validate_revert_reason(msg_node: vy_ast.VyperNode) -> None: - if msg_node: - if isinstance(msg_node, vy_ast.Str): - if not msg_node.value.strip(): - raise StructureException("Reason string cannot be empty", msg_node) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - validate_expected_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - - -def _validate_address_code_attribute(node: vy_ast.Attribute) -> None: - value_type = get_exact_type_from_node(node.value) +# helpers +def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": # Validate `slice(
.code, start, length)` where `length` is constant parent = node.get_ancestor() @@ -139,6 +131,7 @@ def _validate_address_code_attribute(node: vy_ast.Attribute) -> None: ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) if ok_func and ok_args: return + raise StructureException( "(address).code is only allowed inside of a slice function with a constant length", node ) @@ -160,8 +153,30 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: ) +def _validate_msg_value_access(node: vy_ast.Attribute) -> None: + if isinstance(node.value, vy_ast.Name) and node.attr == "value" and node.value.id == "msg": + raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) + + +def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: + env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys()) + if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars: + if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE: + return + + raise StateAccessViolation( + "not allowed to query contract or environment variables in pure functions", node + ) + + +def _validate_self_reference(node: vy_ast.Name) -> None: + # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through + if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): + raise StateAccessViolation("not allowed to query self in pure functions", node) + + class FunctionNodeVisitor(VyperNodeVisitorBase): - ignored_types = (vy_ast.Constant, vy_ast.Pass) + ignored_types = (vy_ast.Pass,) scope_name = "function" def __init__( @@ -171,8 +186,7 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["type"] - self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace) - self.expr_visitor = _LocalExpressionVisitor() + self.expr_visitor = _ExprVisitor(self.func) # allow internal function params to be mutable location, is_immutable = ( @@ -189,44 +203,13 @@ def __init__( f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node ) - if self.func.mutability == StateMutability.PURE: - node_list = fn_node.get_descendants( - vy_ast.Attribute, - { - "value.id": set(CONSTANT_ENVIRONMENT_VARS.keys()).union( - set(MUTABLE_ENVIRONMENT_VARS.keys()) - ) - }, - ) - - # Add references to `self` as standalone address - self_references = fn_node.get_descendants(vy_ast.Name, {"id": "self"}) - standalone_self = [ - n for n in self_references if not isinstance(n.get_ancestor(), vy_ast.Attribute) - ] - node_list.extend(standalone_self) # type: ignore - - for node in node_list: - t = node._metadata.get("type") - if isinstance(t, ContractFunctionT) and t.mutability == StateMutability.PURE: - # allowed - continue - raise StateAccessViolation( - "not allowed to query contract or environment variables in pure functions", - node_list[0], - ) - if self.func.mutability is not StateMutability.PAYABLE: - node_list = fn_node.get_descendants( - vy_ast.Attribute, {"value.id": "msg", "attr": "value"} - ) - if node_list: - raise NonPayableViolation( - "msg.value is not allowed in non-payable functions", node_list[0] - ) + # visit default args + assert self.func.n_keyword_args == len(fn_node.args.defaults) + for kwarg in self.func.keyword_args: + self.expr_visitor.visit(kwarg.default_value, kwarg.typ) def visit(self, node): super().visit(node) - self.annotation_visitor.visit(node) def visit_AnnAssign(self, node): name = node.get("target.id") @@ -238,16 +221,42 @@ def visit_AnnAssign(self, node): "Memory variables must be declared with an initial value", node ) - type_ = type_from_annotation(node.annotation, DataLocation.MEMORY) - validate_expected_type(node.value, type_) + typ = type_from_annotation(node.annotation, DataLocation.MEMORY) + validate_expected_type(node.value, typ) try: - self.namespace[name] = VarInfo(type_, location=DataLocation.MEMORY) + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) except VyperException as exc: raise exc.with_annotation(node) from None - self.expr_visitor.visit(node.value) - def visit_Assign(self, node): + self.expr_visitor.visit(node.target, typ) + self.expr_visitor.visit(node.value, typ) + + def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Str): + if not msg_node.value.strip(): + raise StructureException("Reason string cannot be empty", msg_node) + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): + try: + validate_expected_type(msg_node, StringT(1024)) + except TypeMismatch as e: + raise InvalidType("revert reason must fit within String[1024]") from e + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + + def visit_Assert(self, node): + if node.msg: + self._validate_revert_reason(node.msg) + + try: + validate_expected_type(node.test, BoolT()) + except InvalidType: + raise InvalidType("Assertion test value must be a boolean", node.test) + self.expr_visitor.visit(node.test, BoolT()) + + # repeated code for Assign and AugAssign + def _assign_helper(self, node): if isinstance(node.value, vy_ast.Tuple): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) @@ -260,81 +269,71 @@ def visit_Assign(self, node): validate_expected_type(node.value, target.typ) target.validate_modification(node, self.func.mutability) - self.expr_visitor.visit(node.value) - self.expr_visitor.visit(node.target) + self.expr_visitor.visit(node.value, target.typ) + self.expr_visitor.visit(node.target, target.typ) - def visit_AugAssign(self, node): - if isinstance(node.value, vy_ast.Tuple): - raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) - - lhs_info = get_expr_info(node.target) - - validate_expected_type(node.value, lhs_info.typ) - lhs_info.validate_modification(node, self.func.mutability) - - self.expr_visitor.visit(node.value) - self.expr_visitor.visit(node.target) - - def visit_Raise(self, node): - if node.exc: - _validate_revert_reason(node.exc) - self.expr_visitor.visit(node.exc) + def visit_Assign(self, node): + self._assign_helper(node) - def visit_Assert(self, node): - if node.msg: - _validate_revert_reason(node.msg) - self.expr_visitor.visit(node.msg) + def visit_AugAssign(self, node): + self._assign_helper(node) - try: - validate_expected_type(node.test, BoolT()) - except InvalidType: - raise InvalidType("Assertion test value must be a boolean", node.test) - self.expr_visitor.visit(node.test) + def visit_Break(self, node): + for_node = node.get_ancestor(vy_ast.For) + if for_node is None: + raise StructureException("`break` must be enclosed in a `for` loop", node) def visit_Continue(self, node): + # TODO: use context/state instead of ast search for_node = node.get_ancestor(vy_ast.For) if for_node is None: raise StructureException("`continue` must be enclosed in a `for` loop", node) - def visit_Break(self, node): - for_node = node.get_ancestor(vy_ast.For) - if for_node is None: - raise StructureException("`break` must be enclosed in a `for` loop", node) + def visit_Expr(self, node): + if not isinstance(node.value, vy_ast.Call): + raise StructureException("Expressions without assignment are disallowed", node) - def visit_Return(self, node): - values = node.value - if values is None: - if self.func.return_type: - raise FunctionDeclarationException("Return statement is missing a value", node) - return - elif self.func.return_type is None: - raise FunctionDeclarationException("Function does not return any values", node) + fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): + raise StructureException("To call an event you must use the `log` statement", node) - if isinstance(values, vy_ast.Tuple): - values = values.elements - if not isinstance(self.func.return_type, TupleT): - raise FunctionDeclarationException("Function only returns a single value", node) - if self.func.return_type.length != len(values): - raise FunctionDeclarationException( - f"Incorrect number of return values: " - f"expected {self.func.return_type.length}, got {len(values)}", + if is_type_t(fn_type, StructT): + raise StructureException("Struct creation without assignment is disallowed", node) + + if isinstance(fn_type, ContractFunctionT): + if ( + fn_type.mutability > StateMutability.VIEW + and self.func.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot call a mutating function from a {self.func.mutability.value} function", node, ) - for given, expected in zip(values, self.func.return_type.member_types): - validate_expected_type(given, expected) - else: - validate_expected_type(values, self.func.return_type) - self.expr_visitor.visit(node.value) - def visit_If(self, node): - validate_expected_type(node.test, BoolT()) - self.expr_visitor.visit(node.test) - with self.namespace.enter_scope(): - for n in node.body: - self.visit(n) - with self.namespace.enter_scope(): - for n in node.orelse: - self.visit(n) + if ( + self.func.mutability == StateMutability.PURE + and fn_type.mutability != StateMutability.PURE + ): + raise StateAccessViolation( + "Cannot call non-pure function from a pure function", node + ) + + if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: + # it's a dotted function call like dynarray.pop() + expr_info = get_expr_info(node.value.func.value) + expr_info.validate_modification(node, self.func.mutability) + + # NOTE: fetch_call_return validates call args. + return_value = fn_type.fetch_call_return(node.value) + if ( + return_value + and not isinstance(fn_type, MemberFunctionT) + and not isinstance(fn_type, ContractFunctionT) + ): + raise StructureException( + f"Function '{fn_type._id}' cannot be called without assigning the result", node + ) + self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): @@ -463,19 +462,18 @@ def visit_For(self, node): f"which potentially modifies iterated storage variable '{iter_name}'", call_node, ) - self.expr_visitor.visit(node.iter) if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) for_loop_exceptions = [] iter_name = node.target.id - for type_ in type_list: + for possible_target_type in type_list: # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): try: - self.namespace[iter_name] = VarInfo(type_, is_constant=True) + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) except VyperException as exc: raise exc.with_annotation(node) from None @@ -486,17 +484,27 @@ def visit_For(self, node): except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - # type information is applied directly here because the - # scope is closed prior to the call to - # `StatementAnnotationVisitor` - node.target._metadata["type"] = type_ - - # success -- bail out instead of error handling. + self.expr_visitor.visit(node.target, possible_target_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, possible_target_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, possible_target_type) + + # success -- do not enter error handling section return - # if we have gotten here, there was an error for - # every type tried for the iterator - + # failed to find a good type. bail out if len(set(str(i) for i in for_loop_exceptions)) == 1: # if every attempt at type checking raised the same exception raise for_loop_exceptions[0] @@ -510,56 +518,20 @@ def visit_For(self, node): "but type checking fails with all possible types:", node, *( - (f"Casting '{iter_name}' as {type_}: {exc.message}", exc.annotations[0]) - for type_, exc in zip(type_list, for_loop_exceptions) + (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) + for typ, exc in zip(type_list, for_loop_exceptions) ), ) - def visit_Expr(self, node): - if not isinstance(node.value, vy_ast.Call): - raise StructureException("Expressions without assignment are disallowed", node) - - fn_type = get_exact_type_from_node(node.value.func) - if is_type_t(fn_type, EventT): - raise StructureException("To call an event you must use the `log` statement", node) - - if is_type_t(fn_type, StructT): - raise StructureException("Struct creation without assignment is disallowed", node) - - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - - # NOTE: fetch_call_return validates call args. - return_value = fn_type.fetch_call_return(node.value) - if ( - return_value - and not isinstance(fn_type, MemberFunctionT) - and not isinstance(fn_type, ContractFunctionT) - ): - raise StructureException( - f"Function '{fn_type._id}' cannot be called without assigning the result", node - ) - self.expr_visitor.visit(node.value) + def visit_If(self, node): + validate_expected_type(node.test, BoolT()) + self.expr_visitor.visit(node.test, BoolT()) + with self.namespace.enter_scope(): + for n in node.body: + self.visit(n) + with self.namespace.enter_scope(): + for n in node.orelse: + self.visit(n) def visit_Log(self, node): if not isinstance(node.value, vy_ast.Call): @@ -572,62 +544,249 @@ def visit_Log(self, node): f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node ) f.fetch_call_return(node.value) - self.expr_visitor.visit(node.value) + node._metadata["type"] = f.typedef + self.expr_visitor.visit(node.value, f.typedef) + + def visit_Raise(self, node): + if node.exc: + self._validate_revert_reason(node.exc) + def visit_Return(self, node): + values = node.value + if values is None: + if self.func.return_type: + raise FunctionDeclarationException("Return statement is missing a value", node) + return + elif self.func.return_type is None: + raise FunctionDeclarationException("Function does not return any values", node) -class _LocalExpressionVisitor(VyperNodeVisitorBase): - ignored_types = (vy_ast.Constant, vy_ast.Name) + if isinstance(values, vy_ast.Tuple): + values = values.elements + if not isinstance(self.func.return_type, TupleT): + raise FunctionDeclarationException("Function only returns a single value", node) + if self.func.return_type.length != len(values): + raise FunctionDeclarationException( + f"Incorrect number of return values: " + f"expected {self.func.return_type.length}, got {len(values)}", + node, + ) + for given, expected in zip(values, self.func.return_type.member_types): + validate_expected_type(given, expected) + else: + validate_expected_type(values, self.func.return_type) + self.expr_visitor.visit(node.value, self.func.return_type) + + +class _ExprVisitor(VyperNodeVisitorBase): scope_name = "function" - def visit_Attribute(self, node: vy_ast.Attribute) -> None: - self.visit(node.value) + def __init__(self, fn_node: ContractFunctionT): + self.func = fn_node + + def visit(self, node, typ): + # recurse and typecheck in case we are being fed the wrong type for + # some reason. note that `validate_expected_type` is unnecessary + # for nodes that already call `get_exact_type_from_node` and + # `get_possible_types_from_node` because `validate_expected_type` + # would be calling the same function again. + # CMC 2023-06-27 would be cleanest to call validate_expected_type() + # before recursing but maybe needs some refactoring before that + # can happen. + super().visit(node, typ) + + # annotate + node._metadata["type"] = typ + + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) - _validate_address_code_attribute(node) - - def visit_BinOp(self, node: vy_ast.BinOp) -> None: - self.visit(node.left) - self.visit(node.right) - - def visit_BoolOp(self, node: vy_ast.BoolOp) -> None: - for value in node.values: # type: ignore[attr-defined] - self.visit(value) - - def visit_Call(self, node: vy_ast.Call) -> None: - self.visit(node.func) - for arg in node.args: - self.visit(arg) - for kwarg in node.keywords: - self.visit(kwarg.value) - - def visit_Compare(self, node: vy_ast.Compare) -> None: - self.visit(node.left) # type: ignore[attr-defined] - self.visit(node.right) # type: ignore[attr-defined] - - def visit_Dict(self, node: vy_ast.Dict) -> None: - for key in node.keys: - self.visit(key) + + # CMC 2023-10-19 TODO generalize this to mutability check on every node. + # something like, + # if self.func.mutability < expr_info.mutability: + # raise ... + + if self.func.mutability != StateMutability.PAYABLE: + _validate_msg_value_access(node) + + if self.func.mutability == StateMutability.PURE: + _validate_pure_access(node, typ) + + value_type = get_exact_type_from_node(node.value) + _validate_address_code(node, value_type) + + self.visit(node.value, value_type) + + def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None: + validate_expected_type(node.left, typ) + self.visit(node.left, typ) + + rtyp = typ + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): + rtyp = get_possible_types_from_node(node.right).pop() + + validate_expected_type(node.right, rtyp) + + self.visit(node.right, rtyp) + + def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: + assert typ == BoolT() # sanity check for value in node.values: - self.visit(value) + validate_expected_type(value, BoolT()) + self.visit(value, BoolT()) + + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: + call_type = get_exact_type_from_node(node.func) + # except for builtin functions, `get_exact_type_from_node` + # already calls `validate_expected_type` on the call args + # and kwargs via `call_type.fetch_call_return` + self.visit(node.func, call_type) + + if isinstance(call_type, ContractFunctionT): + # function calls + if call_type.is_internal: + self.func.called_functions.add(call_type) + for arg, typ in zip(node.args, call_type.argument_types): + self.visit(arg, typ) + for kwarg in node.keywords: + # We should only see special kwargs + typ = call_type.call_site_kwargs[kwarg.arg].typ + self.visit(kwarg.value, typ) + + elif is_type_t(call_type, EventT): + # events have no kwargs + expected_types = call_type.typedef.arguments.values() + for arg, typ in zip(node.args, expected_types): + self.visit(arg, typ) + elif is_type_t(call_type, StructT): + # struct ctors + # ctors have no kwargs + expected_types = call_type.typedef.members.values() + for value, arg_type in zip(node.args[0].values, expected_types): + self.visit(value, arg_type) + elif isinstance(call_type, MemberFunctionT): + assert len(node.args) == len(call_type.arg_types) + for arg, arg_type in zip(node.args, call_type.arg_types): + self.visit(arg, arg_type) + else: + # builtin functions + arg_types = call_type.infer_arg_types(node) + # `infer_arg_types` already calls `validate_expected_type` + for arg, arg_type in zip(node.args, arg_types): + self.visit(arg, arg_type) + kwarg_types = call_type.infer_kwarg_types(node) + for kwarg in node.keywords: + self.visit(kwarg.value, kwarg_types[kwarg.arg]) + + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + # membership in list literal - `x in [a, b, c]` + # needle: ltyp, haystack: rtyp + if isinstance(node.right, vy_ast.List): + ltyp = get_common_types(node.left, *node.right.elements).pop() + + rlen = len(node.right.elements) + rtyp = SArrayT(ltyp, rlen) + validate_expected_type(node.right, rtyp) + else: + rtyp = get_exact_type_from_node(node.right) + if isinstance(rtyp, EnumT): + # enum membership - `some_enum in other_enum` + ltyp = rtyp + else: + # array membership - `x in my_list_variable` + assert isinstance(rtyp, (SArrayT, DArrayT)) + ltyp = rtyp.value_type - def visit_Index(self, node: vy_ast.Index) -> None: - self.visit(node.value) + validate_expected_type(node.left, ltyp) - def visit_List(self, node: vy_ast.List) -> None: - for element in node.elements: - self.visit(element) + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) + + else: + # ex. a < b + cmp_typ = get_common_types(node.left, node.right).pop() + if isinstance(cmp_typ, _BytestringT): + # for bytestrings, get_common_types automatically downcasts + # to the smaller common type - that will annotate with the + # wrong type, instead use get_exact_type_from_node (which + # resolves to the right type for bytestrings anyways). + ltyp = get_exact_type_from_node(node.left) + rtyp = get_exact_type_from_node(node.right) + else: + ltyp = rtyp = cmp_typ + validate_expected_type(node.left, ltyp) + validate_expected_type(node.right, rtyp) + + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) + + def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None: + validate_expected_type(node, typ) - def visit_Subscript(self, node: vy_ast.Subscript) -> None: - self.visit(node.value) - self.visit(node.slice) + def visit_Index(self, node: vy_ast.Index, typ: VyperType) -> None: + validate_expected_type(node.value, typ) + self.visit(node.value, typ) - def visit_Tuple(self, node: vy_ast.Tuple) -> None: + def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: + assert isinstance(typ, (SArrayT, DArrayT)) for element in node.elements: - self.visit(element) + validate_expected_type(element, typ.value_type) + self.visit(element, typ.value_type) - def visit_UnaryOp(self, node: vy_ast.UnaryOp) -> None: - self.visit(node.operand) # type: ignore[attr-defined] + def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) + + if not isinstance(typ, TYPE_T): + validate_expected_type(node, typ) + + def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + # don't recurse; can't annotate AST children of type definition + return + + if isinstance(node.value, vy_ast.List): + possible_base_types = get_possible_types_from_node(node.value) + + for possible_type in possible_base_types: + if typ.compare_type(possible_type.value_type): + base_type = possible_type + break + else: + # this should have been caught in + # `get_possible_types_from_node` but wasn't. + raise TypeCheckFailure(f"Expected {typ} but it is not a possible type", node) + + else: + base_type = get_exact_type_from_node(node.value) + + # get the correct type for the index, it might + # not be exactly base_type.key_type + # note: index_type is validated in types_from_Subscript + index_types = get_possible_types_from_node(node.slice.value) + index_type = index_types.pop() + + self.visit(node.slice, index_type) + self.visit(node.value, base_type) + + def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + # don't recurse; can't annotate AST children of type definition + return + + assert isinstance(typ, TupleT) + for element, subtype in zip(node.elements, typ.member_types): + validate_expected_type(element, subtype) + self.visit(element, subtype) - def visit_IfExp(self, node: vy_ast.IfExp) -> None: - self.visit(node.test) - self.visit(node.body) - self.visit(node.orelse) + def visit_UnaryOp(self, node: vy_ast.UnaryOp, typ: VyperType) -> None: + validate_expected_type(node.operand, typ) + self.visit(node.operand, typ) + + def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: + validate_expected_type(node.test, BoolT()) + self.visit(node.test, BoolT()) + validate_expected_type(node.body, typ) + self.visit(node.body, typ) + validate_expected_type(node.orelse, typ) + self.visit(node.orelse, typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4f911764e0..afa6b56838 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -312,10 +312,17 @@ def types_from_Constant(self, node): def types_from_List(self, node): # literal array if _is_empty_list(node): - # empty list literal `[]` ret = [] - # subtype can be anything - for t in types.PRIMITIVE_TYPES.values(): + + if len(node.elements) > 0: + # empty nested list literals `[[], []]` + subtypes = self.get_possible_types_from_node(node.elements[0]) + else: + # empty list literal `[]` + # subtype can be anything + subtypes = types.PRIMITIVE_TYPES.values() + + for t in subtypes: # 1 is minimum possible length for dynarray, # can be assigned to anything if isinstance(t, VyperType): From ed0b1e0ac8ddb47019efcff4b692ff6470fc6a04 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 20 Oct 2023 07:52:04 -0700 Subject: [PATCH 116/201] fix: add missing test for memory allocation overflow (#3650) should have been added in 68da04b2e9e0 but the file was not committed --- tests/parser/features/test_memory_alloc.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/parser/features/test_memory_alloc.py diff --git a/tests/parser/features/test_memory_alloc.py b/tests/parser/features/test_memory_alloc.py new file mode 100644 index 0000000000..ee6d15c67c --- /dev/null +++ b/tests/parser/features/test_memory_alloc.py @@ -0,0 +1,16 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import MemoryAllocationException + + +def test_memory_overflow(): + code = """ +@external +def zzz(x: DynArray[uint256, 2**59]): # 2**64 / 32 bytes per word == 2**59 + y: uint256[7] = [0,0,0,0,0,0,0] + + y[6] = y[5] + """ + with pytest.raises(MemoryAllocationException): + compile_code(code) From b01cd686aa567b32498fefd76bd96b0597c6f099 Mon Sep 17 00:00:00 2001 From: engn33r Date: Mon, 23 Oct 2023 02:00:08 +0000 Subject: [PATCH 117/201] docs: fix link to style guide (#3658) --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 94eb5ec04c..baa8decacc 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ ### Commit message -Commit message for the final, squashed PR. (Optional, but reviewers will appreciate it! Please see [our commit message style guide](../../blob/master/docs/style-guide.rst#best-practices-1) for what we would ideally like to see in a commit message.) +Commit message for the final, squashed PR. (Optional, but reviewers will appreciate it! Please see [our commit message style guide](../../master/docs/style-guide.rst#best-practices-1) for what we would ideally like to see in a commit message.) ### Description for the changelog From 52dc413c684532d5c4d6cdd91e3b058957cfcba0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 1 Nov 2023 15:03:51 -0400 Subject: [PATCH 118/201] docs: retire security@vyperlang.org (#3660) now that private vulnerability reporting is available on github, the security inbox is no longer required (or regularly monitored) cf. https://docs.github.com/en/code-security/security-advisories/working-with-repository-security-advisories/configuring-private-vulnerability-reporting-for-a-repository --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index c7bdad4ee7..0a054b2c93 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -48,7 +48,7 @@ https://github.com/vyperlang/vyper/security/advisories If you think you have found a security vulnerability with a project that has used Vyper, please report the vulnerability to the relevant project's security disclosure program prior -to reporting to us. If one is not available, please email your vulnerability to security@vyperlang.org. +to reporting to us. If one is not available, submit it at https://github.com/vyperlang/vyper/security/advisories. **Please Do Not Log An Issue** mentioning the vulnerability. From 9ce56e7d8b0196a5d51d706a8d2376b98d3e8ad7 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:33:20 +0800 Subject: [PATCH 119/201] chore: fix test for `slice` (#3633) fix some test cases for `slice` and simplify the test logic --------- Co-authored-by: Charles Cooper --- tests/parser/functions/test_slice.py | 88 +++++++++++++++++----------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 3090dafda0..53e092019f 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -32,8 +32,8 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: _bytes_1024 = st.binary(min_size=0, max_size=1024) -@pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("use_literal_start", (True, False)) +@pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) @settings(max_examples=100) @@ -45,13 +45,13 @@ def test_slice_immutable( opt_level, bytesdata, start, - literal_start, + use_literal_start, length, - literal_length, + use_literal_length, length_bound, ): - _start = start if literal_start else "start" - _length = length if literal_length else "length" + _start = start if use_literal_start else "start" + _length = length if use_literal_length else "length" code = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) @@ -71,10 +71,10 @@ def _get_contract(): return get_contract(code, bytesdata, start, length, override_opt_level=opt_level) if ( - (start + length > length_bound and literal_start and literal_length) - or (literal_length and length > length_bound) - or (literal_start and start > length_bound) - or (literal_length and length < 1) + (start + length > length_bound and use_literal_start and use_literal_length) + or (use_literal_length and length > length_bound) + or (use_literal_start and start > length_bound) + or (use_literal_length and length == 0) ): assert_compile_failed(lambda: _get_contract(), ArgumentException) elif start + length > len(bytesdata) or (len(bytesdata) > length_bound): @@ -86,13 +86,13 @@ def _get_contract(): @pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) -@pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("use_literal_start", (True, False)) +@pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) @settings(max_examples=100) @pytest.mark.fuzzing -def test_slice_bytes( +def test_slice_bytes_fuzz( get_contract, assert_compile_failed, assert_tx_failed, @@ -100,18 +100,28 @@ def test_slice_bytes( location, bytesdata, start, - literal_start, + use_literal_start, length, - literal_length, + use_literal_length, length_bound, ): + preamble = "" if location == "memory": spliced_code = f"foo: Bytes[{length_bound}] = inp" foo = "foo" elif location == "storage": + preamble = f""" +foo: Bytes[{length_bound}] + """ spliced_code = "self.foo = inp" foo = "self.foo" elif location == "code": + preamble = f""" +IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) +@external +def __init__(foo: Bytes[{length_bound}]): + IMMUTABLE_BYTES = foo + """ spliced_code = "" foo = "IMMUTABLE_BYTES" elif location == "literal": @@ -123,15 +133,11 @@ def test_slice_bytes( else: raise Exception("unreachable") - _start = start if literal_start else "start" - _length = length if literal_length else "length" + _start = start if use_literal_start else "start" + _length = length if use_literal_length else "length" code = f""" -foo: Bytes[{length_bound}] -IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external -def __init__(foo: Bytes[{length_bound}]): - IMMUTABLE_BYTES = foo +{preamble} @external def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]: @@ -142,24 +148,40 @@ def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Byt def _get_contract(): return get_contract(code, bytesdata, override_opt_level=opt_level) - data_length = len(bytesdata) if location == "literal" else length_bound - if ( - (start + length > data_length and literal_start and literal_length) - or (literal_length and length > data_length) - or (location == "literal" and len(bytesdata) > length_bound) - or (literal_start and start > data_length) - or (literal_length and length < 1) - ): + # length bound is the container size; input_bound is the bound on the input + # (which can be different, if the input is a literal) + input_bound = length_bound + slice_output_too_large = False + + if location == "literal": + input_bound = len(bytesdata) + + # ex.: + # @external + # def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]: + # return slice(b'\x00\x00', 0, length) + output_length = length if use_literal_length else input_bound + slice_output_too_large = output_length > length_bound + + end = start + length + + compile_time_oob = ( + (use_literal_length and (length > input_bound or length == 0)) + or (use_literal_start and start > input_bound) + or (use_literal_start and use_literal_length and start + length > input_bound) + ) + + if compile_time_oob or slice_output_too_large: assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif len(bytesdata) > data_length: + elif location == "code" and len(bytesdata) > length_bound: # deploy fail assert_tx_failed(lambda: _get_contract()) - elif start + length > len(bytesdata): + elif end > len(bytesdata) or len(bytesdata) > length_bound: c = _get_contract() assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) else: c = _get_contract() - assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code + assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code def test_slice_private(get_contract): From 5d10ea0d2a26ab0c58beab4b0b9a4a3d90c9c439 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 6 Nov 2023 15:33:31 -0500 Subject: [PATCH 120/201] refactor: internal handling of imports (#3655) this commit refactors how imports are handled internally. historically, vyper handled imports by using a preprocessing step (`extract_file_interface_imports`) which resolved imports to files and provided them to the compiler pipeline as pure inputs. however, this causes problems once the module system gets more complicated: - it mixes passes. resolving imports and loading the files should essentially be resolved during analysis, but instead they are being resolved before the compiler is even entered into(!) - it produces slightly different code paths into the main compiler entry point which introduces subtle bugs over time from scaffolding differences - relatedly, each entry point into the compiler has to maintain its own mechanism for resolving different kinds of inputs to the compiler (JSON interfaces vs .vy files at the moment). this commit replaces the external scaffolding with an "InputBundle" abstraction which essentially models how the compiler interacts with its inputs (depending on whether it is using the filesystem or JSON inputs). this makes the entry points to the compiler overall simpler, and have more consistent behavior. this commit also: - changes builtin interfaces so they are represented in the codebase as `.vy` files which are imported similarly to how regular (non-builtin) files are imported - simplifies the `compile_files` and `compile_json` pipelines - removes the `compile_codes` API, which was not actually more useful than the `compile_code` API (which is kept). - cleans up tests by introducing a `make_file` and `make_input_bundle` abstraction - simplifies and merges several files in the tests/cli/ directories - adds a test for multiple output selections in the standard json pipeline --- tests/base_conftest.py | 15 +- tests/cli/vyper_compile/test_compile_files.py | 205 +++++++++++- tests/cli/vyper_compile/test_import_paths.py | 260 --------------- tests/cli/vyper_compile/test_parse_args.py | 2 + .../test_compile_from_input_dict.py | 132 -------- tests/cli/vyper_json/test_compile_json.py | 190 +++++++++-- tests/cli/vyper_json/test_get_contracts.py | 71 ---- tests/cli/vyper_json/test_get_inputs.py | 142 ++++++++ tests/cli/vyper_json/test_get_settings.py | 2 - tests/cli/vyper_json/test_interfaces.py | 126 ------- tests/cli/vyper_json/test_output_dict.py | 38 --- tests/cli/vyper_json/test_output_selection.py | 38 +-- .../vyper_json/test_parse_args_vyperjson.py | 3 +- tests/compiler/test_bytecode_runtime.py | 14 +- tests/compiler/test_compile_code.py | 2 +- tests/compiler/test_input_bundle.py | 208 ++++++++++++ tests/compiler/test_opcodes.py | 2 +- tests/compiler/test_source_map.py | 6 +- tests/conftest.py | 29 ++ tests/parser/ast_utils/test_ast_dict.py | 8 +- tests/parser/features/test_init.py | 2 +- tests/parser/functions/test_bitwise.py | 2 +- tests/parser/functions/test_interfaces.py | 201 ++++++------ tests/parser/functions/test_raw_call.py | 8 +- tests/parser/functions/test_return_struct.py | 4 +- tests/parser/syntax/test_codehash.py | 2 +- tests/parser/syntax/test_interfaces.py | 11 +- tests/parser/syntax/test_self_balance.py | 2 +- tests/parser/test_selector_table_stability.py | 4 +- tests/parser/types/numbers/test_constants.py | 2 +- vyper/__init__.py | 2 +- .../interfaces/{ERC165.py => ERC165.vy} | 2 - .../interfaces/{ERC20.py => ERC20.vy} | 2 - vyper/builtins/interfaces/ERC20Detailed.py | 22 -- vyper/builtins/interfaces/ERC20Detailed.vy | 18 + .../interfaces/{ERC4626.py => ERC4626.vy} | 2 - .../interfaces/{ERC721.py => ERC721.vy} | 3 - vyper/builtins/interfaces/__init__.py | 0 vyper/cli/utils.py | 58 ---- vyper/cli/vyper_compile.py | 137 +++----- vyper/cli/vyper_json.py | 307 ++++++------------ vyper/cli/vyper_serve.py | 6 +- vyper/compiler/__init__.py | 177 +++------- vyper/compiler/input_bundle.py | 180 ++++++++++ vyper/compiler/output.py | 5 +- vyper/compiler/phases.py | 52 +-- vyper/semantics/analysis/__init__.py | 4 +- vyper/semantics/analysis/module.py | 155 +++++---- vyper/typing.py | 8 - 49 files changed, 1448 insertions(+), 1423 deletions(-) delete mode 100644 tests/cli/vyper_compile/test_import_paths.py delete mode 100644 tests/cli/vyper_json/test_compile_from_input_dict.py delete mode 100644 tests/cli/vyper_json/test_get_contracts.py create mode 100644 tests/cli/vyper_json/test_get_inputs.py delete mode 100644 tests/cli/vyper_json/test_interfaces.py delete mode 100644 tests/cli/vyper_json/test_output_dict.py create mode 100644 tests/compiler/test_input_bundle.py rename vyper/builtins/interfaces/{ERC165.py => ERC165.vy} (75%) rename vyper/builtins/interfaces/{ERC20.py => ERC20.vy} (96%) delete mode 100644 vyper/builtins/interfaces/ERC20Detailed.py create mode 100644 vyper/builtins/interfaces/ERC20Detailed.vy rename vyper/builtins/interfaces/{ERC4626.py => ERC4626.vy} (98%) rename vyper/builtins/interfaces/{ERC721.py => ERC721.vy} (97%) delete mode 100644 vyper/builtins/interfaces/__init__.py delete mode 100644 vyper/cli/utils.py create mode 100644 vyper/compiler/input_bundle.py diff --git a/tests/base_conftest.py b/tests/base_conftest.py index 1c7c6f3aed..f613ad0f47 100644 --- a/tests/base_conftest.py +++ b/tests/base_conftest.py @@ -112,16 +112,18 @@ def w3(tester): return w3 -def _get_contract(w3, source_code, optimize, *args, override_opt_level=None, **kwargs): +def _get_contract( + w3, source_code, optimize, *args, override_opt_level=None, input_bundle=None, **kwargs +): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) settings.optimize = override_opt_level or optimize out = compiler.compile_code( source_code, # test that metadata and natspecs get generated - ["abi", "bytecode", "metadata", "userdoc", "devdoc"], + output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], settings=settings, - interface_codes=kwargs.pop("interface_codes", None), + input_bundle=input_bundle, show_gas_estimates=True, # Enable gas estimates for testing ) parse_vyper_source(source_code) # Test grammar. @@ -144,8 +146,7 @@ def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwar settings.optimize = optimize out = compiler.compile_code( source_code, - ["abi", "bytecode"], - interface_codes=kwargs.pop("interface_codes", None), + output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) @@ -187,10 +188,10 @@ def deploy_blueprint_for(source_code, *args, **kwargs): @pytest.fixture(scope="module") def get_contract(w3, optimize): - def get_contract(source_code, *args, **kwargs): + def fn(source_code, *args, **kwargs): return _get_contract(w3, source_code, optimize, *args, **kwargs) - return get_contract + return fn @pytest.fixture diff --git a/tests/cli/vyper_compile/test_compile_files.py b/tests/cli/vyper_compile/test_compile_files.py index 31cf622658..2a16efa777 100644 --- a/tests/cli/vyper_compile/test_compile_files.py +++ b/tests/cli/vyper_compile/test_compile_files.py @@ -1,12 +1,12 @@ +from pathlib import Path + import pytest from vyper.cli.vyper_compile import compile_files -def test_combined_json_keys(tmp_path): - bar_path = tmp_path.joinpath("bar.vy") - with bar_path.open("w") as fp: - fp.write("") +def test_combined_json_keys(tmp_path, make_file): + make_file("bar.vy", "") combined_keys = { "bytecode", @@ -19,12 +19,203 @@ def test_combined_json_keys(tmp_path): "userdoc", "devdoc", } - compile_data = compile_files([bar_path], ["combined_json"], root_folder=tmp_path) + compile_data = compile_files(["bar.vy"], ["combined_json"], root_folder=tmp_path) - assert set(compile_data.keys()) == {"bar.vy", "version"} - assert set(compile_data["bar.vy"].keys()) == combined_keys + assert set(compile_data.keys()) == {Path("bar.vy"), "version"} + assert set(compile_data[Path("bar.vy")].keys()) == combined_keys def test_invalid_root_path(): with pytest.raises(FileNotFoundError): compile_files([], [], root_folder="path/that/does/not/exist") + + +FOO_CODE = """ +{} + +struct FooStruct: + foo_: uint256 + +@external +def foo() -> FooStruct: + return FooStruct({{foo_: 13}}) + +@external +def bar(a: address) -> FooStruct: + return {}(a).bar() +""" + +BAR_CODE = """ +struct FooStruct: + foo_: uint256 +@external +def bar() -> FooStruct: + return FooStruct({foo_: 13}) +""" + + +SAME_FOLDER_IMPORT_STMT = [ + ("import Bar as Bar", "Bar"), + ("import contracts.Bar as Bar", "Bar"), + ("from . import Bar", "Bar"), + ("from contracts import Bar", "Bar"), + ("from ..contracts import Bar", "Bar"), + ("from . import Bar as FooBar", "FooBar"), + ("from contracts import Bar as FooBar", "FooBar"), + ("from ..contracts import Bar as FooBar", "FooBar"), +] + + +@pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) +def test_import_same_folder(import_stmt, alias, tmp_path, make_file): + foo = "contracts/foo.vy" + make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) + make_file("contracts/Bar.vy", BAR_CODE) + + assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + + +SUBFOLDER_IMPORT_STMT = [ + ("import other.Bar as Bar", "Bar"), + ("import contracts.other.Bar as Bar", "Bar"), + ("from other import Bar", "Bar"), + ("from contracts.other import Bar", "Bar"), + ("from .other import Bar", "Bar"), + ("from ..contracts.other import Bar", "Bar"), + ("from other import Bar as FooBar", "FooBar"), + ("from contracts.other import Bar as FooBar", "FooBar"), + ("from .other import Bar as FooBar", "FooBar"), + ("from ..contracts.other import Bar as FooBar", "FooBar"), +] + + +@pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) +def test_import_subfolder(import_stmt, alias, tmp_path, make_file): + foo = make_file("contracts/foo.vy", (FOO_CODE.format(import_stmt, alias))) + make_file("contracts/other/Bar.vy", BAR_CODE) + + assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + + +OTHER_FOLDER_IMPORT_STMT = [ + ("import interfaces.Bar as Bar", "Bar"), + ("from interfaces import Bar", "Bar"), + ("from ..interfaces import Bar", "Bar"), + ("from interfaces import Bar as FooBar", "FooBar"), + ("from ..interfaces import Bar as FooBar", "FooBar"), +] + + +@pytest.mark.parametrize("import_stmt, alias", OTHER_FOLDER_IMPORT_STMT) +def test_import_other_folder(import_stmt, alias, tmp_path, make_file): + foo = make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) + make_file("interfaces/Bar.vy", BAR_CODE) + + assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + + +def test_import_parent_folder(tmp_path, make_file): + foo = make_file("contracts/baz/foo.vy", FOO_CODE.format("from ... import Bar", "Bar")) + make_file("Bar.vy", BAR_CODE) + + assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + + # perform relative import outside of base folder + compile_files([foo], ["combined_json"], root_folder=tmp_path / "contracts") + + +META_IMPORT_STMT = [ + "import Meta as Meta", + "import contracts.Meta as Meta", + "from . import Meta", + "from contracts import Meta", +] + + +@pytest.mark.parametrize("import_stmt", META_IMPORT_STMT) +def test_import_self_interface(import_stmt, tmp_path, make_file): + # a contract can access its derived interface by importing itself + code = f""" +{import_stmt} + +struct FooStruct: + foo_: uint256 + +@external +def know_thyself(a: address) -> FooStruct: + return Meta(a).be_known() + +@external +def be_known() -> FooStruct: + return FooStruct({{foo_: 42}}) + """ + meta = make_file("contracts/Meta.vy", code) + + assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + + +DERIVED_IMPORT_STMT_BAZ = ["import Foo as Foo", "from . import Foo"] + +DERIVED_IMPORT_STMT_FOO = ["import Bar as Bar", "from . import Bar"] + + +@pytest.mark.parametrize("import_stmt_baz", DERIVED_IMPORT_STMT_BAZ) +@pytest.mark.parametrize("import_stmt_foo", DERIVED_IMPORT_STMT_FOO) +def test_derived_interface_imports(import_stmt_baz, import_stmt_foo, tmp_path, make_file): + # contracts-as-interfaces should be able to contain import statements + baz_code = f""" +{import_stmt_baz} + +struct FooStruct: + foo_: uint256 + +@external +def foo(a: address) -> FooStruct: + return Foo(a).foo() + +@external +def bar(_foo: address, _bar: address) -> FooStruct: + return Foo(_foo).bar(_bar) + """ + + make_file("Foo.vy", FOO_CODE.format(import_stmt_foo, "Bar")) + make_file("Bar.vy", BAR_CODE) + baz = make_file("Baz.vy", baz_code) + + assert compile_files([baz], ["combined_json"], root_folder=tmp_path) + + +def test_local_namespace(make_file, tmp_path): + # interface code namespaces should be isolated + # all of these contract should be able to compile together + codes = [ + "import foo as FooBar", + "import bar as FooBar", + "import foo as BarFoo", + "import bar as BarFoo", + ] + struct_def = """ +struct FooStruct: + foo_: uint256 + + """ + + paths = [] + for i, code in enumerate(codes): + code += struct_def + filename = f"code{i}.vy" + make_file(filename, code) + paths.append(filename) + + for file_name in ("foo.vy", "bar.vy"): + make_file(file_name, BAR_CODE) + + assert compile_files(paths, ["combined_json"], root_folder=tmp_path) + + +def test_compile_outside_root_path(tmp_path, make_file): + # absolute paths relative to "." + foo = make_file("foo.vy", FOO_CODE.format("import bar as Bar", "Bar")) + bar = make_file("bar.vy", BAR_CODE) + + assert compile_files([foo, bar], ["combined_json"], root_folder=".") diff --git a/tests/cli/vyper_compile/test_import_paths.py b/tests/cli/vyper_compile/test_import_paths.py deleted file mode 100644 index 81f209113f..0000000000 --- a/tests/cli/vyper_compile/test_import_paths.py +++ /dev/null @@ -1,260 +0,0 @@ -import pytest - -from vyper.cli.vyper_compile import compile_files, get_interface_file_path - -FOO_CODE = """ -{} - -struct FooStruct: - foo_: uint256 - -@external -def foo() -> FooStruct: - return FooStruct({{foo_: 13}}) - -@external -def bar(a: address) -> FooStruct: - return {}(a).bar() -""" - -BAR_CODE = """ -struct FooStruct: - foo_: uint256 -@external -def bar() -> FooStruct: - return FooStruct({foo_: 13}) -""" - - -SAME_FOLDER_IMPORT_STMT = [ - ("import Bar as Bar", "Bar"), - ("import contracts.Bar as Bar", "Bar"), - ("from . import Bar", "Bar"), - ("from contracts import Bar", "Bar"), - ("from ..contracts import Bar", "Bar"), - ("from . import Bar as FooBar", "FooBar"), - ("from contracts import Bar as FooBar", "FooBar"), - ("from ..contracts import Bar as FooBar", "FooBar"), -] - - -@pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) -def test_import_same_folder(import_stmt, alias, tmp_path): - tmp_path.joinpath("contracts").mkdir() - - foo_path = tmp_path.joinpath("contracts/foo.vy") - with foo_path.open("w") as fp: - fp.write(FOO_CODE.format(import_stmt, alias)) - - with tmp_path.joinpath("contracts/Bar.vy").open("w") as fp: - fp.write(BAR_CODE) - - assert compile_files([foo_path], ["combined_json"], root_folder=tmp_path) - - -SUBFOLDER_IMPORT_STMT = [ - ("import other.Bar as Bar", "Bar"), - ("import contracts.other.Bar as Bar", "Bar"), - ("from other import Bar", "Bar"), - ("from contracts.other import Bar", "Bar"), - ("from .other import Bar", "Bar"), - ("from ..contracts.other import Bar", "Bar"), - ("from other import Bar as FooBar", "FooBar"), - ("from contracts.other import Bar as FooBar", "FooBar"), - ("from .other import Bar as FooBar", "FooBar"), - ("from ..contracts.other import Bar as FooBar", "FooBar"), -] - - -@pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) -def test_import_subfolder(import_stmt, alias, tmp_path): - tmp_path.joinpath("contracts").mkdir() - - foo_path = tmp_path.joinpath("contracts/foo.vy") - with foo_path.open("w") as fp: - fp.write(FOO_CODE.format(import_stmt, alias)) - - tmp_path.joinpath("contracts/other").mkdir() - with tmp_path.joinpath("contracts/other/Bar.vy").open("w") as fp: - fp.write(BAR_CODE) - - assert compile_files([foo_path], ["combined_json"], root_folder=tmp_path) - - -OTHER_FOLDER_IMPORT_STMT = [ - ("import interfaces.Bar as Bar", "Bar"), - ("from interfaces import Bar", "Bar"), - ("from ..interfaces import Bar", "Bar"), - ("from interfaces import Bar as FooBar", "FooBar"), - ("from ..interfaces import Bar as FooBar", "FooBar"), -] - - -@pytest.mark.parametrize("import_stmt, alias", OTHER_FOLDER_IMPORT_STMT) -def test_import_other_folder(import_stmt, alias, tmp_path): - tmp_path.joinpath("contracts").mkdir() - - foo_path = tmp_path.joinpath("contracts/foo.vy") - with foo_path.open("w") as fp: - fp.write(FOO_CODE.format(import_stmt, alias)) - - tmp_path.joinpath("interfaces").mkdir() - with tmp_path.joinpath("interfaces/Bar.vy").open("w") as fp: - fp.write(BAR_CODE) - - assert compile_files([foo_path], ["combined_json"], root_folder=tmp_path) - - -def test_import_parent_folder(tmp_path, assert_compile_failed): - tmp_path.joinpath("contracts").mkdir() - tmp_path.joinpath("contracts/baz").mkdir() - - foo_path = tmp_path.joinpath("contracts/baz/foo.vy") - with foo_path.open("w") as fp: - fp.write(FOO_CODE.format("from ... import Bar", "Bar")) - - with tmp_path.joinpath("Bar.vy").open("w") as fp: - fp.write(BAR_CODE) - - assert compile_files([foo_path], ["combined_json"], root_folder=tmp_path) - # Cannot perform relative import outside of base folder - with pytest.raises(FileNotFoundError): - compile_files([foo_path], ["combined_json"], root_folder=tmp_path.joinpath("contracts")) - - -META_IMPORT_STMT = [ - "import Meta as Meta", - "import contracts.Meta as Meta", - "from . import Meta", - "from contracts import Meta", -] - - -@pytest.mark.parametrize("import_stmt", META_IMPORT_STMT) -def test_import_self_interface(import_stmt, tmp_path): - # a contract can access its derived interface by importing itself - code = f""" -{import_stmt} - -struct FooStruct: - foo_: uint256 - -@external -def know_thyself(a: address) -> FooStruct: - return Meta(a).be_known() - -@external -def be_known() -> FooStruct: - return FooStruct({{foo_: 42}}) - """ - - tmp_path.joinpath("contracts").mkdir() - - meta_path = tmp_path.joinpath("contracts/Meta.vy") - with meta_path.open("w") as fp: - fp.write(code) - - assert compile_files([meta_path], ["combined_json"], root_folder=tmp_path) - - -DERIVED_IMPORT_STMT_BAZ = ["import Foo as Foo", "from . import Foo"] - -DERIVED_IMPORT_STMT_FOO = ["import Bar as Bar", "from . import Bar"] - - -@pytest.mark.parametrize("import_stmt_baz", DERIVED_IMPORT_STMT_BAZ) -@pytest.mark.parametrize("import_stmt_foo", DERIVED_IMPORT_STMT_FOO) -def test_derived_interface_imports(import_stmt_baz, import_stmt_foo, tmp_path): - # contracts-as-interfaces should be able to contain import statements - baz_code = f""" -{import_stmt_baz} - -struct FooStruct: - foo_: uint256 - -@external -def foo(a: address) -> FooStruct: - return Foo(a).foo() - -@external -def bar(_foo: address, _bar: address) -> FooStruct: - return Foo(_foo).bar(_bar) - """ - - with tmp_path.joinpath("Foo.vy").open("w") as fp: - fp.write(FOO_CODE.format(import_stmt_foo, "Bar")) - - with tmp_path.joinpath("Bar.vy").open("w") as fp: - fp.write(BAR_CODE) - - baz_path = tmp_path.joinpath("Baz.vy") - with baz_path.open("w") as fp: - fp.write(baz_code) - - assert compile_files([baz_path], ["combined_json"], root_folder=tmp_path) - - -def test_local_namespace(tmp_path): - # interface code namespaces should be isolated - # all of these contract should be able to compile together - codes = [ - "import foo as FooBar", - "import bar as FooBar", - "import foo as BarFoo", - "import bar as BarFoo", - ] - struct_def = """ -struct FooStruct: - foo_: uint256 - - """ - - compile_paths = [] - for i, code in enumerate(codes): - code += struct_def - path = tmp_path.joinpath(f"code{i}.vy") - with path.open("w") as fp: - fp.write(code) - compile_paths.append(path) - - for file_name in ("foo.vy", "bar.vy"): - with tmp_path.joinpath(file_name).open("w") as fp: - fp.write(BAR_CODE) - - assert compile_files(compile_paths, ["combined_json"], root_folder=tmp_path) - - -def test_get_interface_file_path(tmp_path): - for file_name in ("foo.vy", "foo.json", "bar.vy", "baz.json", "potato"): - with tmp_path.joinpath(file_name).open("w") as fp: - fp.write("") - - tmp_path.joinpath("interfaces").mkdir() - for file_name in ("interfaces/foo.json", "interfaces/bar"): - with tmp_path.joinpath(file_name).open("w") as fp: - fp.write("") - - base_paths = [tmp_path, tmp_path.joinpath("interfaces")] - assert get_interface_file_path(base_paths, "foo") == tmp_path.joinpath("foo.vy") - assert get_interface_file_path(base_paths, "bar") == tmp_path.joinpath("bar.vy") - assert get_interface_file_path(base_paths, "baz") == tmp_path.joinpath("baz.json") - - base_paths = [tmp_path.joinpath("interfaces"), tmp_path] - assert get_interface_file_path(base_paths, "foo") == tmp_path.joinpath("interfaces/foo.json") - assert get_interface_file_path(base_paths, "bar") == tmp_path.joinpath("bar.vy") - assert get_interface_file_path(base_paths, "baz") == tmp_path.joinpath("baz.json") - - with pytest.raises(Exception): - get_interface_file_path(base_paths, "potato") - - -def test_compile_outside_root_path(tmp_path): - foo_path = tmp_path.joinpath("foo.vy") - with foo_path.open("w") as fp: - fp.write(FOO_CODE.format("import bar as Bar", "Bar")) - - bar_path = tmp_path.joinpath("bar.vy") - with bar_path.open("w") as fp: - fp.write(BAR_CODE) - - assert compile_files([foo_path, bar_path], ["combined_json"], root_folder=".") diff --git a/tests/cli/vyper_compile/test_parse_args.py b/tests/cli/vyper_compile/test_parse_args.py index a676a7836b..0e8c4e9605 100644 --- a/tests/cli/vyper_compile/test_parse_args.py +++ b/tests/cli/vyper_compile/test_parse_args.py @@ -21,7 +21,9 @@ def foo() -> bool: bar_path = chdir_path.joinpath("bar.vy") with bar_path.open("w") as fp: fp.write(code) + _parse_args([str(bar_path)]) # absolute path os.chdir(chdir_path.parent) + _parse_args([str(bar_path)]) # absolute path, subfolder of cwd _parse_args([str(bar_path.relative_to(chdir_path.parent))]) # relative path diff --git a/tests/cli/vyper_json/test_compile_from_input_dict.py b/tests/cli/vyper_json/test_compile_from_input_dict.py deleted file mode 100644 index a6d0a23100..0000000000 --- a/tests/cli/vyper_json/test_compile_from_input_dict.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python3 - -from copy import deepcopy - -import pytest - -import vyper -from vyper.cli.vyper_json import ( - TRANSLATE_MAP, - compile_from_input_dict, - exc_handler_raises, - exc_handler_to_dict, -) -from vyper.exceptions import InvalidType, JSONError, SyntaxException - -FOO_CODE = """ -import contracts.bar as Bar - -@external -def foo(a: address) -> bool: - return Bar(a).bar(1) - -@external -def baz() -> uint256: - return self.balance -""" - -BAR_CODE = """ -@external -def bar(a: uint256) -> bool: - return True -""" - -BAD_SYNTAX_CODE = """ -def bar()>: -""" - -BAD_COMPILER_CODE = """ -@external -def oopsie(a: uint256) -> bool: - return 42 -""" - -BAR_ABI = [ - { - "name": "bar", - "outputs": [{"type": "bool", "name": "out"}], - "inputs": [{"type": "uint256", "name": "a"}], - "stateMutability": "nonpayable", - "type": "function", - "gas": 313, - } -] - -INPUT_JSON = { - "language": "Vyper", - "sources": { - "contracts/foo.vy": {"content": FOO_CODE}, - "contracts/bar.vy": {"content": BAR_CODE}, - }, - "interfaces": {"contracts/bar.json": {"abi": BAR_ABI}}, - "settings": {"outputSelection": {"*": ["*"]}}, -} - - -def test_root_folder_not_exists(): - with pytest.raises(FileNotFoundError): - compile_from_input_dict({}, root_folder="/path/that/does/not/exist") - - -def test_wrong_language(): - with pytest.raises(JSONError): - compile_from_input_dict({"language": "Solidity"}) - - -def test_exc_handler_raises_syntax(): - input_json = deepcopy(INPUT_JSON) - input_json["sources"]["badcode.vy"] = {"content": BAD_SYNTAX_CODE} - with pytest.raises(SyntaxException): - compile_from_input_dict(input_json, exc_handler_raises) - - -def test_exc_handler_to_dict_syntax(): - input_json = deepcopy(INPUT_JSON) - input_json["sources"]["badcode.vy"] = {"content": BAD_SYNTAX_CODE} - result, _ = compile_from_input_dict(input_json, exc_handler_to_dict) - assert "errors" in result - assert len(result["errors"]) == 1 - error = result["errors"][0] - assert error["component"] == "parser" - assert error["type"] == "SyntaxException" - - -def test_exc_handler_raises_compiler(): - input_json = deepcopy(INPUT_JSON) - input_json["sources"]["badcode.vy"] = {"content": BAD_COMPILER_CODE} - with pytest.raises(InvalidType): - compile_from_input_dict(input_json, exc_handler_raises) - - -def test_exc_handler_to_dict_compiler(): - input_json = deepcopy(INPUT_JSON) - input_json["sources"]["badcode.vy"] = {"content": BAD_COMPILER_CODE} - result, _ = compile_from_input_dict(input_json, exc_handler_to_dict) - assert sorted(result.keys()) == ["compiler", "errors"] - assert result["compiler"] == f"vyper-{vyper.__version__}" - assert len(result["errors"]) == 1 - error = result["errors"][0] - assert error["component"] == "compiler" - assert error["type"] == "InvalidType" - - -def test_source_ids_increment(): - input_json = deepcopy(INPUT_JSON) - input_json["settings"]["outputSelection"] = {"*": ["evm.deployedBytecode.sourceMap"]} - result, _ = compile_from_input_dict(input_json) - assert result["contracts/bar.vy"]["source_map"]["pc_pos_map_compressed"].startswith("-1:-1:0") - assert result["contracts/foo.vy"]["source_map"]["pc_pos_map_compressed"].startswith("-1:-1:1") - - -def test_outputs(): - result, _ = compile_from_input_dict(INPUT_JSON) - assert sorted(result.keys()) == ["contracts/bar.vy", "contracts/foo.vy"] - assert sorted(result["contracts/bar.vy"].keys()) == sorted(set(TRANSLATE_MAP.values())) - - -def test_relative_import_paths(): - input_json = deepcopy(INPUT_JSON) - input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": """from ... import foo"""} - input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} - input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} - compile_from_input_dict(input_json) diff --git a/tests/cli/vyper_json/test_compile_json.py b/tests/cli/vyper_json/test_compile_json.py index f03006c4ad..732762d72b 100644 --- a/tests/cli/vyper_json/test_compile_json.py +++ b/tests/cli/vyper_json/test_compile_json.py @@ -1,12 +1,11 @@ -#!/usr/bin/env python3 - import json -from copy import deepcopy import pytest -from vyper.cli.vyper_json import compile_from_input_dict, compile_json -from vyper.exceptions import JSONError +import vyper +from vyper.cli.vyper_json import compile_from_input_dict, compile_json, exc_handler_to_dict +from vyper.compiler import OUTPUT_FORMATS, compile_code +from vyper.exceptions import InvalidType, JSONError, SyntaxException FOO_CODE = """ import contracts.bar as Bar @@ -14,6 +13,10 @@ @external def foo(a: address) -> bool: return Bar(a).bar(1) + +@external +def baz() -> uint256: + return self.balance """ BAR_CODE = """ @@ -22,6 +25,16 @@ def bar(a: uint256) -> bool: return True """ +BAD_SYNTAX_CODE = """ +def bar()>: +""" + +BAD_COMPILER_CODE = """ +@external +def oopsie(a: uint256) -> bool: + return 42 +""" + BAR_ABI = [ { "name": "bar", @@ -29,23 +42,26 @@ def bar(a: uint256) -> bool: "inputs": [{"type": "uint256", "name": "a"}], "stateMutability": "nonpayable", "type": "function", - "gas": 313, } ] -INPUT_JSON = { - "language": "Vyper", - "sources": { - "contracts/foo.vy": {"content": FOO_CODE}, - "contracts/bar.vy": {"content": BAR_CODE}, - }, - "interfaces": {"contracts/bar.json": {"abi": BAR_ABI}}, - "settings": {"outputSelection": {"*": ["*"]}}, -} + +@pytest.fixture(scope="function") +def input_json(): + return { + "language": "Vyper", + "sources": { + "contracts/foo.vy": {"content": FOO_CODE}, + "contracts/bar.vy": {"content": BAR_CODE}, + }, + "interfaces": {"contracts/ibar.json": {"abi": BAR_ABI}}, + "settings": {"outputSelection": {"*": ["*"]}}, + } -def test_input_formats(): - assert compile_json(INPUT_JSON) == compile_json(json.dumps(INPUT_JSON)) +# test string and dict inputs both work +def test_string_input(input_json): + assert compile_json(input_json) == compile_json(json.dumps(input_json)) def test_bad_json(): @@ -53,10 +69,146 @@ def test_bad_json(): compile_json("this probably isn't valid JSON, is it") -def test_keyerror_becomes_jsonerror(): - input_json = deepcopy(INPUT_JSON) +def test_keyerror_becomes_jsonerror(input_json): del input_json["sources"] with pytest.raises(KeyError): compile_from_input_dict(input_json) with pytest.raises(JSONError): compile_json(input_json) + + +def test_compile_json(input_json, make_input_bundle): + input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) + + foo = compile_code( + FOO_CODE, + source_id=0, + contract_name="contracts/foo.vy", + output_formats=OUTPUT_FORMATS, + input_bundle=input_bundle, + ) + bar = compile_code( + BAR_CODE, source_id=1, contract_name="contracts/bar.vy", output_formats=OUTPUT_FORMATS + ) + + compile_code_results = {"contracts/bar.vy": bar, "contracts/foo.vy": foo} + + output_json = compile_json(input_json) + assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + + assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] + assert output_json["compiler"] == f"vyper-{vyper.__version__}" + + for source_id, contract_name in enumerate(["foo", "bar"]): + path = f"contracts/{contract_name}.vy" + data = compile_code_results[path] + assert output_json["sources"][path] == {"id": source_id, "ast": data["ast_dict"]["ast"]} + assert output_json["contracts"][path][contract_name] == { + "abi": data["abi"], + "devdoc": data["devdoc"], + "interface": data["interface"], + "ir": data["ir_dict"], + "userdoc": data["userdoc"], + "metadata": data["metadata"], + "evm": { + "bytecode": {"object": data["bytecode"], "opcodes": data["opcodes"]}, + "deployedBytecode": { + "object": data["bytecode_runtime"], + "opcodes": data["opcodes_runtime"], + "sourceMap": data["source_map"]["pc_pos_map_compressed"], + "sourceMapFull": data["source_map_full"], + }, + "methodIdentifiers": data["method_identifiers"], + }, + } + + +def test_different_outputs(make_input_bundle, input_json): + input_json["settings"]["outputSelection"] = { + "contracts/bar.vy": "*", + "contracts/foo.vy": ["evm.methodIdentifiers"], + } + output_json = compile_json(input_json) + assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + + assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] + assert output_json["compiler"] == f"vyper-{vyper.__version__}" + + contracts = output_json["contracts"] + + foo = contracts["contracts/foo.vy"]["foo"] + bar = contracts["contracts/bar.vy"]["bar"] + assert sorted(bar.keys()) == ["abi", "devdoc", "evm", "interface", "ir", "metadata", "userdoc"] + + assert sorted(foo.keys()) == ["evm"] + + # check method_identifiers + input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) + method_identifiers = compile_code( + FOO_CODE, + contract_name="contracts/foo.vy", + output_formats=["method_identifiers"], + input_bundle=input_bundle, + )["method_identifiers"] + assert foo["evm"]["methodIdentifiers"] == method_identifiers + + +def test_root_folder_not_exists(input_json): + with pytest.raises(FileNotFoundError): + compile_json(input_json, root_folder="/path/that/does/not/exist") + + +def test_wrong_language(): + with pytest.raises(JSONError): + compile_json({"language": "Solidity"}) + + +def test_exc_handler_raises_syntax(input_json): + input_json["sources"]["badcode.vy"] = {"content": BAD_SYNTAX_CODE} + with pytest.raises(SyntaxException): + compile_json(input_json) + + +def test_exc_handler_to_dict_syntax(input_json): + input_json["sources"]["badcode.vy"] = {"content": BAD_SYNTAX_CODE} + result = compile_json(input_json, exc_handler_to_dict) + assert "errors" in result + assert len(result["errors"]) == 1 + error = result["errors"][0] + assert error["component"] == "compiler", error + assert error["type"] == "SyntaxException" + + +def test_exc_handler_raises_compiler(input_json): + input_json["sources"]["badcode.vy"] = {"content": BAD_COMPILER_CODE} + with pytest.raises(InvalidType): + compile_json(input_json) + + +def test_exc_handler_to_dict_compiler(input_json): + input_json["sources"]["badcode.vy"] = {"content": BAD_COMPILER_CODE} + result = compile_json(input_json, exc_handler_to_dict) + assert sorted(result.keys()) == ["compiler", "errors"] + assert result["compiler"] == f"vyper-{vyper.__version__}" + assert len(result["errors"]) == 1 + error = result["errors"][0] + assert error["component"] == "compiler" + assert error["type"] == "InvalidType" + + +def test_source_ids_increment(input_json): + input_json["settings"]["outputSelection"] = {"*": ["evm.deployedBytecode.sourceMap"]} + result = compile_json(input_json) + + def get(filename, contractname): + return result["contracts"][filename][contractname]["evm"]["deployedBytecode"]["sourceMap"] + + assert get("contracts/foo.vy", "foo").startswith("-1:-1:0") + assert get("contracts/bar.vy", "bar").startswith("-1:-1:1") + + +def test_relative_import_paths(input_json): + input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": """from ... import foo"""} + input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} + input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} + compile_from_input_dict(input_json) diff --git a/tests/cli/vyper_json/test_get_contracts.py b/tests/cli/vyper_json/test_get_contracts.py deleted file mode 100644 index 86a5052f72..0000000000 --- a/tests/cli/vyper_json/test_get_contracts.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 - -import pytest - -from vyper.cli.vyper_json import get_input_dict_contracts -from vyper.exceptions import JSONError -from vyper.utils import keccak256 - -FOO_CODE = """ -import contracts.bar as Bar - -@external -def foo(a: address) -> bool: - return Bar(a).bar(1) -""" - -BAR_CODE = """ -@external -def bar(a: uint256) -> bool: - return True -""" - - -def test_no_sources(): - with pytest.raises(KeyError): - get_input_dict_contracts({}) - - -def test_contracts_urls(): - with pytest.raises(JSONError): - get_input_dict_contracts({"sources": {"foo.vy": {"urls": ["https://foo.code.com/"]}}}) - - -def test_contracts_no_content_key(): - with pytest.raises(JSONError): - get_input_dict_contracts({"sources": {"foo.vy": FOO_CODE}}) - - -def test_contracts_keccak(): - hash_ = keccak256(FOO_CODE.encode()).hex() - - input_json = {"sources": {"foo.vy": {"content": FOO_CODE, "keccak256": hash_}}} - get_input_dict_contracts(input_json) - - input_json["sources"]["foo.vy"]["keccak256"] = "0x" + hash_ - get_input_dict_contracts(input_json) - - input_json["sources"]["foo.vy"]["keccak256"] = "0x1234567890" - with pytest.raises(JSONError): - get_input_dict_contracts(input_json) - - -def test_contracts_bad_path(): - input_json = {"sources": {"../foo.vy": {"content": FOO_CODE}}} - with pytest.raises(JSONError): - get_input_dict_contracts(input_json) - - -def test_contract_collision(): - # ./foo.vy and foo.vy will resolve to the same path - input_json = {"sources": {"./foo.vy": {"content": FOO_CODE}, "foo.vy": {"content": FOO_CODE}}} - with pytest.raises(JSONError): - get_input_dict_contracts(input_json) - - -def test_contracts_return_value(): - input_json = { - "sources": {"foo.vy": {"content": FOO_CODE}, "contracts/bar.vy": {"content": BAR_CODE}} - } - result = get_input_dict_contracts(input_json) - assert result == {"foo.vy": FOO_CODE, "contracts/bar.vy": BAR_CODE} diff --git a/tests/cli/vyper_json/test_get_inputs.py b/tests/cli/vyper_json/test_get_inputs.py new file mode 100644 index 0000000000..6e323a91bd --- /dev/null +++ b/tests/cli/vyper_json/test_get_inputs.py @@ -0,0 +1,142 @@ +from pathlib import PurePath + +import pytest + +from vyper.cli.vyper_json import get_compilation_targets, get_inputs +from vyper.exceptions import JSONError +from vyper.utils import keccak256 + +FOO_CODE = """ +import contracts.bar as Bar + +@external +def foo(a: address) -> bool: + return Bar(a).bar(1) +""" + +BAR_CODE = """ +@external +def bar(a: uint256) -> bool: + return True +""" + + +def test_no_sources(): + with pytest.raises(KeyError): + get_inputs({}) + + +def test_contracts_urls(): + with pytest.raises(JSONError): + get_inputs({"sources": {"foo.vy": {"urls": ["https://foo.code.com/"]}}}) + + +def test_contracts_no_content_key(): + with pytest.raises(JSONError): + get_inputs({"sources": {"foo.vy": FOO_CODE}}) + + +def test_contracts_keccak(): + hash_ = keccak256(FOO_CODE.encode()).hex() + + input_json = {"sources": {"foo.vy": {"content": FOO_CODE, "keccak256": hash_}}} + get_inputs(input_json) + + input_json["sources"]["foo.vy"]["keccak256"] = "0x" + hash_ + get_inputs(input_json) + + input_json["sources"]["foo.vy"]["keccak256"] = "0x1234567890" + with pytest.raises(JSONError): + get_inputs(input_json) + + +def test_contracts_outside_pwd(): + input_json = {"sources": {"../foo.vy": {"content": FOO_CODE}}} + get_inputs(input_json) + + +def test_contract_collision(): + # ./foo.vy and foo.vy will resolve to the same path + input_json = {"sources": {"./foo.vy": {"content": FOO_CODE}, "foo.vy": {"content": FOO_CODE}}} + with pytest.raises(JSONError): + get_inputs(input_json) + + +def test_contracts_return_value(): + input_json = { + "sources": {"foo.vy": {"content": FOO_CODE}, "contracts/bar.vy": {"content": BAR_CODE}} + } + result = get_inputs(input_json) + assert result == { + PurePath("foo.vy"): {"content": FOO_CODE}, + PurePath("contracts/bar.vy"): {"content": BAR_CODE}, + } + + +BAR_ABI = [ + { + "name": "bar", + "outputs": [{"type": "bool", "name": "out"}], + "inputs": [{"type": "uint256", "name": "a"}], + "stateMutability": "nonpayable", + "type": "function", + } +] + + +# tests to get interfaces from input dicts + + +def test_interface_collision(): + input_json = { + "sources": {"foo.vy": {"content": FOO_CODE}}, + "interfaces": {"bar.json": {"abi": BAR_ABI}, "bar.vy": {"content": BAR_CODE}}, + } + with pytest.raises(JSONError): + get_inputs(input_json) + + +def test_json_no_abi(): + input_json = { + "sources": {"foo.vy": {"content": FOO_CODE}}, + "interfaces": {"bar.json": {"content": BAR_ABI}}, + } + with pytest.raises(JSONError): + get_inputs(input_json) + + +def test_vy_no_content(): + input_json = { + "sources": {"foo.vy": {"content": FOO_CODE}}, + "interfaces": {"bar.vy": {"abi": BAR_CODE}}, + } + with pytest.raises(JSONError): + get_inputs(input_json) + + +def test_interfaces_output(): + input_json = { + "sources": {"foo.vy": {"content": FOO_CODE}}, + "interfaces": { + "bar.json": {"abi": BAR_ABI}, + "interface.folder/bar2.vy": {"content": BAR_CODE}, + }, + } + targets = get_compilation_targets(input_json) + assert targets == [PurePath("foo.vy")] + + result = get_inputs(input_json) + assert result == { + PurePath("foo.vy"): {"content": FOO_CODE}, + PurePath("bar.json"): {"abi": BAR_ABI}, + PurePath("interface.folder/bar2.vy"): {"content": BAR_CODE}, + } + + +# EIP-2678 -- not currently supported +@pytest.mark.xfail +def test_manifest_output(): + input_json = {"interfaces": {"bar.json": {"contractTypes": {"Bar": {"abi": BAR_ABI}}}}} + result = get_inputs(input_json) + assert isinstance(result, dict) + assert result == {"Bar": {"type": "json", "code": BAR_ABI}} diff --git a/tests/cli/vyper_json/test_get_settings.py b/tests/cli/vyper_json/test_get_settings.py index bbe5dab113..989d4565cd 100644 --- a/tests/cli/vyper_json/test_get_settings.py +++ b/tests/cli/vyper_json/test_get_settings.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - import pytest from vyper.cli.vyper_json import get_evm_version diff --git a/tests/cli/vyper_json/test_interfaces.py b/tests/cli/vyper_json/test_interfaces.py deleted file mode 100644 index 7804ae1c3d..0000000000 --- a/tests/cli/vyper_json/test_interfaces.py +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env python3 - -import pytest - -from vyper.cli.vyper_json import get_input_dict_interfaces, get_interface_codes -from vyper.exceptions import JSONError - -FOO_CODE = """ -import contracts.bar as Bar - -@external -def foo(a: address) -> bool: - return Bar(a).bar(1) -""" - -BAR_CODE = """ -@external -def bar(a: uint256) -> bool: - return True -""" - -BAR_ABI = [ - { - "name": "bar", - "outputs": [{"type": "bool", "name": "out"}], - "inputs": [{"type": "uint256", "name": "a"}], - "stateMutability": "nonpayable", - "type": "function", - "gas": 313, - } -] - - -# get_input_dict_interfaces tests - - -def test_no_interfaces(): - result = get_input_dict_interfaces({}) - assert isinstance(result, dict) - assert not result - - -def test_interface_collision(): - input_json = {"interfaces": {"bar.json": {"abi": BAR_ABI}, "bar.vy": {"content": BAR_CODE}}} - with pytest.raises(JSONError): - get_input_dict_interfaces(input_json) - - -def test_interfaces_wrong_suffix(): - input_json = {"interfaces": {"foo.abi": {"content": FOO_CODE}}} - with pytest.raises(JSONError): - get_input_dict_interfaces(input_json) - - input_json = {"interfaces": {"interface.folder/foo": {"content": FOO_CODE}}} - with pytest.raises(JSONError): - get_input_dict_interfaces(input_json) - - -def test_json_no_abi(): - input_json = {"interfaces": {"bar.json": {"content": BAR_ABI}}} - with pytest.raises(JSONError): - get_input_dict_interfaces(input_json) - - -def test_vy_no_content(): - input_json = {"interfaces": {"bar.vy": {"abi": BAR_CODE}}} - with pytest.raises(JSONError): - get_input_dict_interfaces(input_json) - - -def test_interfaces_output(): - input_json = { - "interfaces": { - "bar.json": {"abi": BAR_ABI}, - "interface.folder/bar2.vy": {"content": BAR_CODE}, - } - } - result = get_input_dict_interfaces(input_json) - assert isinstance(result, dict) - assert result == { - "bar": {"type": "json", "code": BAR_ABI}, - "interface.folder/bar2": {"type": "vyper", "code": BAR_CODE}, - } - - -def test_manifest_output(): - input_json = {"interfaces": {"bar.json": {"contractTypes": {"Bar": {"abi": BAR_ABI}}}}} - result = get_input_dict_interfaces(input_json) - assert isinstance(result, dict) - assert result == {"Bar": {"type": "json", "code": BAR_ABI}} - - -# get_interface_codes tests - - -def test_interface_codes_from_contracts(): - # interface should be generated from contract - assert get_interface_codes( - None, "foo.vy", {"foo.vy": FOO_CODE, "contracts/bar.vy": BAR_CODE}, {} - ) - assert get_interface_codes( - None, "foo/foo.vy", {"foo/foo.vy": FOO_CODE, "contracts/bar.vy": BAR_CODE}, {} - ) - - -def test_interface_codes_from_interfaces(): - # existing interface should be given preference over contract-as-interface - contracts = {"foo.vy": FOO_CODE, "contacts/bar.vy": BAR_CODE} - result = get_interface_codes(None, "foo.vy", contracts, {"contracts/bar": "bar"}) - assert result["Bar"] == "bar" - - -def test_root_path(tmp_path): - tmp_path.joinpath("contracts").mkdir() - with tmp_path.joinpath("contracts/bar.vy").open("w") as fp: - fp.write("bar") - - with pytest.raises(FileNotFoundError): - get_interface_codes(None, "foo.vy", {"foo.vy": FOO_CODE}, {}) - - # interface from file system should take lowest priority - result = get_interface_codes(tmp_path, "foo.vy", {"foo.vy": FOO_CODE}, {}) - assert result["Bar"] == {"code": "bar", "type": "vyper"} - contracts = {"foo.vy": FOO_CODE, "contracts/bar.vy": BAR_CODE} - result = get_interface_codes(None, "foo.vy", contracts, {}) - assert result["Bar"] == {"code": BAR_CODE, "type": "vyper"} diff --git a/tests/cli/vyper_json/test_output_dict.py b/tests/cli/vyper_json/test_output_dict.py deleted file mode 100644 index e2a3466ccf..0000000000 --- a/tests/cli/vyper_json/test_output_dict.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python3 - -import vyper -from vyper.cli.vyper_json import format_to_output_dict -from vyper.compiler import OUTPUT_FORMATS, compile_codes - -FOO_CODE = """ -@external -def foo() -> bool: - return True -""" - - -def test_keys(): - compiler_data = compile_codes({"foo.vy": FOO_CODE}, output_formats=list(OUTPUT_FORMATS.keys())) - output_json = format_to_output_dict(compiler_data) - assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] - assert output_json["compiler"] == f"vyper-{vyper.__version__}" - data = compiler_data["foo.vy"] - assert output_json["sources"]["foo.vy"] == {"id": 0, "ast": data["ast_dict"]["ast"]} - assert output_json["contracts"]["foo.vy"]["foo"] == { - "abi": data["abi"], - "devdoc": data["devdoc"], - "interface": data["interface"], - "ir": data["ir_dict"], - "userdoc": data["userdoc"], - "metadata": data["metadata"], - "evm": { - "bytecode": {"object": data["bytecode"], "opcodes": data["opcodes"]}, - "deployedBytecode": { - "object": data["bytecode_runtime"], - "opcodes": data["opcodes_runtime"], - "sourceMap": data["source_map"]["pc_pos_map_compressed"], - "sourceMapFull": data["source_map_full"], - }, - "methodIdentifiers": data["method_identifiers"], - }, - } diff --git a/tests/cli/vyper_json/test_output_selection.py b/tests/cli/vyper_json/test_output_selection.py index 3b12e2b54a..78ad7404f2 100644 --- a/tests/cli/vyper_json/test_output_selection.py +++ b/tests/cli/vyper_json/test_output_selection.py @@ -1,60 +1,60 @@ -#!/usr/bin/env python3 +from pathlib import PurePath import pytest -from vyper.cli.vyper_json import TRANSLATE_MAP, get_input_dict_output_formats +from vyper.cli.vyper_json import TRANSLATE_MAP, get_output_formats from vyper.exceptions import JSONError def test_no_outputs(): with pytest.raises(KeyError): - get_input_dict_output_formats({}, {}) + get_output_formats({}, {}) def test_invalid_output(): input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}} - sources = {"foo.vy": ""} + targets = [PurePath("foo.vy")] with pytest.raises(JSONError): - get_input_dict_output_formats(input_json, sources) + get_output_formats(input_json, targets) def test_unknown_contract(): input_json = {"settings": {"outputSelection": {"bar.vy": ["abi"]}}} - sources = {"foo.vy": ""} + targets = [PurePath("foo.vy")] with pytest.raises(JSONError): - get_input_dict_output_formats(input_json, sources) + get_output_formats(input_json, targets) @pytest.mark.parametrize("output", TRANSLATE_MAP.items()) def test_translate_map(output): input_json = {"settings": {"outputSelection": {"foo.vy": [output[0]]}}} - sources = {"foo.vy": ""} - assert get_input_dict_output_formats(input_json, sources) == {"foo.vy": [output[1]]} + targets = [PurePath("foo.vy")] + assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): [output[1]]} def test_star(): input_json = {"settings": {"outputSelection": {"*": ["*"]}}} - sources = {"foo.vy": "", "bar.vy": ""} + targets = [PurePath("foo.vy"), PurePath("bar.vy")] expected = sorted(set(TRANSLATE_MAP.values())) - result = get_input_dict_output_formats(input_json, sources) - assert result == {"foo.vy": expected, "bar.vy": expected} + result = get_output_formats(input_json, targets) + assert result == {PurePath("foo.vy"): expected, PurePath("bar.vy"): expected} def test_evm(): input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}} - sources = {"foo.vy": ""} + targets = [PurePath("foo.vy")] expected = ["abi"] + sorted(v for k, v in TRANSLATE_MAP.items() if k.startswith("evm")) - result = get_input_dict_output_formats(input_json, sources) - assert result == {"foo.vy": expected} + result = get_output_formats(input_json, targets) + assert result == {PurePath("foo.vy"): expected} def test_solc_style(): input_json = {"settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}} - sources = {"foo.vy": ""} - assert get_input_dict_output_formats(input_json, sources) == {"foo.vy": ["abi", "ir_dict"]} + targets = [PurePath("foo.vy")] + assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["abi", "ir_dict"]} def test_metadata(): input_json = {"settings": {"outputSelection": {"*": ["metadata"]}}} - sources = {"foo.vy": ""} - assert get_input_dict_output_formats(input_json, sources) == {"foo.vy": ["metadata"]} + targets = [PurePath("foo.vy")] + assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["metadata"]} diff --git a/tests/cli/vyper_json/test_parse_args_vyperjson.py b/tests/cli/vyper_json/test_parse_args_vyperjson.py index 11e527843a..3b0f700c7e 100644 --- a/tests/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/cli/vyper_json/test_parse_args_vyperjson.py @@ -29,7 +29,6 @@ def bar(a: uint256) -> bool: "inputs": [{"type": "uint256", "name": "a"}], "stateMutability": "nonpayable", "type": "function", - "gas": 313, } ] @@ -39,7 +38,7 @@ def bar(a: uint256) -> bool: "contracts/foo.vy": {"content": FOO_CODE}, "contracts/bar.vy": {"content": BAR_CODE}, }, - "interfaces": {"contracts/bar.json": {"abi": BAR_ABI}}, + "interfaces": {"contracts/ibar.json": {"abi": BAR_ABI}}, "settings": {"outputSelection": {"*": ["*"]}}, } diff --git a/tests/compiler/test_bytecode_runtime.py b/tests/compiler/test_bytecode_runtime.py index 9519b03772..613ee4d2b8 100644 --- a/tests/compiler/test_bytecode_runtime.py +++ b/tests/compiler/test_bytecode_runtime.py @@ -48,14 +48,14 @@ def _parse_cbor_metadata(initcode): def test_bytecode_runtime(): - out = vyper.compile_code(simple_contract_code, ["bytecode_runtime", "bytecode"]) + out = vyper.compile_code(simple_contract_code, output_formats=["bytecode_runtime", "bytecode"]) assert len(out["bytecode"]) > len(out["bytecode_runtime"]) assert out["bytecode_runtime"].removeprefix("0x") in out["bytecode"].removeprefix("0x") def test_bytecode_signature(): - out = vyper.compile_code(simple_contract_code, ["bytecode_runtime", "bytecode"]) + out = vyper.compile_code(simple_contract_code, output_formats=["bytecode_runtime", "bytecode"]) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) @@ -72,7 +72,9 @@ def test_bytecode_signature(): def test_bytecode_signature_dense_jumptable(): settings = Settings(optimize=OptimizationLevel.CODESIZE) - out = vyper.compile_code(many_functions, ["bytecode_runtime", "bytecode"], settings=settings) + out = vyper.compile_code( + many_functions, output_formats=["bytecode_runtime", "bytecode"], settings=settings + ) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) @@ -89,7 +91,9 @@ def test_bytecode_signature_dense_jumptable(): def test_bytecode_signature_sparse_jumptable(): settings = Settings(optimize=OptimizationLevel.GAS) - out = vyper.compile_code(many_functions, ["bytecode_runtime", "bytecode"], settings=settings) + out = vyper.compile_code( + many_functions, output_formats=["bytecode_runtime", "bytecode"], settings=settings + ) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) @@ -104,7 +108,7 @@ def test_bytecode_signature_sparse_jumptable(): def test_bytecode_signature_immutables(): - out = vyper.compile_code(has_immutables, ["bytecode_runtime", "bytecode"]) + out = vyper.compile_code(has_immutables, output_formats=["bytecode_runtime", "bytecode"]) runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x")) initcode = bytes.fromhex(out["bytecode"].removeprefix("0x")) diff --git a/tests/compiler/test_compile_code.py b/tests/compiler/test_compile_code.py index cdbf9d1f52..7af133e362 100644 --- a/tests/compiler/test_compile_code.py +++ b/tests/compiler/test_compile_code.py @@ -11,4 +11,4 @@ def a() -> bool: return True """ with pytest.warns(vyper.warnings.ContractSizeLimitWarning): - vyper.compile_code(code, ["bytecode_runtime"]) + vyper.compile_code(code, output_formats=["bytecode_runtime"]) diff --git a/tests/compiler/test_input_bundle.py b/tests/compiler/test_input_bundle.py new file mode 100644 index 0000000000..c49c81219b --- /dev/null +++ b/tests/compiler/test_input_bundle.py @@ -0,0 +1,208 @@ +import json +from pathlib import Path, PurePath + +import pytest + +from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, JSONInputBundle + + +# FilesystemInputBundle which uses same search path as make_file +@pytest.fixture +def input_bundle(tmp_path): + return FilesystemInputBundle([tmp_path]) + + +def test_load_file(make_file, input_bundle, tmp_path): + make_file("foo.vy", "contents") + + file = input_bundle.load_file(Path("foo.vy")) + + assert isinstance(file, FileInput) + assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + + +def test_search_path_context_manager(make_file, tmp_path): + ib = FilesystemInputBundle([]) + + make_file("foo.vy", "contents") + + with pytest.raises(FileNotFoundError): + # no search path given + ib.load_file(Path("foo.vy")) + + with ib.search_path(tmp_path): + file = ib.load_file(Path("foo.vy")) + + assert isinstance(file, FileInput) + assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + + +def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bundle): + # test search path precedence. + # most recent search path is the highest precedence + tmpdir = tmp_path_factory.mktemp("some_directory") + tmpdir2 = tmp_path_factory.mktemp("some_other_directory") + + for i, directory in enumerate([tmp_path, tmpdir, tmpdir2]): + with (directory / "foo.vy").open("w") as f: + f.write(f"contents {i}") + + ib = FilesystemInputBundle([tmp_path, tmpdir, tmpdir2]) + + file = ib.load_file("foo.vy") + + assert isinstance(file, FileInput) + assert file == FileInput(0, tmpdir2 / "foo.vy", "contents 2") + + with ib.search_path(tmpdir): + file = ib.load_file("foo.vy") + + assert isinstance(file, FileInput) + assert file == FileInput(1, tmpdir / "foo.vy", "contents 1") + + +# special rules for handling json files +def test_load_abi(make_file, input_bundle, tmp_path): + contents = json.dumps("some string") + + make_file("foo.json", contents) + + file = input_bundle.load_file("foo.json") + assert isinstance(file, ABIInput) + assert file == ABIInput(0, tmp_path / "foo.json", "some string") + + # suffix doesn't matter + make_file("foo.txt", contents) + + file = input_bundle.load_file("foo.txt") + assert isinstance(file, ABIInput) + assert file == ABIInput(1, tmp_path / "foo.txt", "some string") + + +# check that unique paths give unique source ids +def test_source_id_file_input(make_file, input_bundle, tmp_path): + make_file("foo.vy", "contents") + make_file("bar.vy", "contents 2") + + file = input_bundle.load_file("foo.vy") + assert file.source_id == 0 + assert file == FileInput(0, tmp_path / "foo.vy", "contents") + + file2 = input_bundle.load_file("bar.vy") + # source id increments + assert file2.source_id == 1 + assert file2 == FileInput(1, tmp_path / "bar.vy", "contents 2") + + file3 = input_bundle.load_file("foo.vy") + assert file3.source_id == 0 + assert file3 == FileInput(0, tmp_path / "foo.vy", "contents") + + +# check that unique paths give unique source ids +def test_source_id_json_input(make_file, input_bundle, tmp_path): + contents = json.dumps("some string") + contents2 = json.dumps(["some list"]) + + make_file("foo.json", contents) + + make_file("bar.json", contents2) + + file = input_bundle.load_file("foo.json") + assert isinstance(file, ABIInput) + assert file == ABIInput(0, tmp_path / "foo.json", "some string") + + file2 = input_bundle.load_file("bar.json") + assert isinstance(file2, ABIInput) + assert file2 == ABIInput(1, tmp_path / "bar.json", ["some list"]) + + file3 = input_bundle.load_file("foo.json") + assert isinstance(file3, ABIInput) + assert file3 == ABIInput(0, tmp_path / "foo.json", "some string") + + +# test some pathological case where the file changes underneath +def test_mutating_file_source_id(make_file, input_bundle, tmp_path): + make_file("foo.vy", "contents") + + file = input_bundle.load_file("foo.vy") + assert file.source_id == 0 + assert file == FileInput(0, tmp_path / "foo.vy", "contents") + + make_file("foo.vy", "new contents") + + file = input_bundle.load_file("foo.vy") + # source id hasn't changed, even though contents have + assert file.source_id == 0 + assert file == FileInput(0, tmp_path / "foo.vy", "new contents") + + +# test the os.normpath behavior of symlink +# (slightly pathological, for illustration's sake) +def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): + dir1 = tmp_path / "first" + dir2 = tmp_path / "second" + symlink = tmp_path / "symlink" + + dir1.mkdir() + dir2.mkdir() + symlink.symlink_to(dir2, target_is_directory=True) + + with (tmp_path / "foo.vy").open("w") as f: + f.write("contents of the upper directory") + + with (dir1 / "foo.vy").open("w") as f: + f.write("contents of the inner directory") + + # symlink rules would be: + # base/symlink/../foo.vy => + # base/first/second/../foo.vy => + # base/first/foo.vy + # normpath would be base/symlink/../foo.vy => + # base/foo.vy + file = input_bundle.load_file(symlink / ".." / "foo.vy") + + assert file == FileInput(0, tmp_path / "foo.vy", "contents of the upper directory") + + +def test_json_input_bundle_basic(): + files = {PurePath("foo.vy"): {"content": "some text"}} + input_bundle = JSONInputBundle(files, [PurePath(".")]) + + file = input_bundle.load_file(PurePath("foo.vy")) + assert file == FileInput(0, PurePath("foo.vy"), "some text") + + +def test_json_input_bundle_normpath(): + files = {PurePath("foo/../bar.vy"): {"content": "some text"}} + input_bundle = JSONInputBundle(files, [PurePath(".")]) + + expected = FileInput(0, PurePath("bar.vy"), "some text") + + file = input_bundle.load_file(PurePath("bar.vy")) + assert file == expected + + file = input_bundle.load_file(PurePath("baz/../bar.vy")) + assert file == expected + + file = input_bundle.load_file(PurePath("./bar.vy")) + assert file == expected + + with input_bundle.search_path(PurePath("foo")): + file = input_bundle.load_file(PurePath("../bar.vy")) + assert file == expected + + +def test_json_input_abi(): + some_abi = ["some abi"] + some_abi_str = json.dumps(some_abi) + files = { + PurePath("foo.json"): {"abi": some_abi}, + PurePath("bar.txt"): {"content": some_abi_str}, + } + input_bundle = JSONInputBundle(files, [PurePath(".")]) + + file = input_bundle.load_file(PurePath("foo.json")) + assert file == ABIInput(0, PurePath("foo.json"), some_abi) + + file = input_bundle.load_file(PurePath("bar.txt")) + assert file == ABIInput(1, PurePath("bar.txt"), some_abi) diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py index 20f45ced6b..15d2a617ba 100644 --- a/tests/compiler/test_opcodes.py +++ b/tests/compiler/test_opcodes.py @@ -22,7 +22,7 @@ def a() -> bool: return True """ - out = vyper.compile_code(code, ["opcodes_runtime", "opcodes"]) + out = vyper.compile_code(code, output_formats=["opcodes_runtime", "opcodes"]) assert len(out["opcodes"]) > len(out["opcodes_runtime"]) assert out["opcodes_runtime"] in out["opcodes"] diff --git a/tests/compiler/test_source_map.py b/tests/compiler/test_source_map.py index 886596bb80..c9a152b09c 100644 --- a/tests/compiler/test_source_map.py +++ b/tests/compiler/test_source_map.py @@ -28,7 +28,7 @@ def foo(a: uint256) -> int128: def test_jump_map(): - source_map = compile_code(TEST_CODE, ["source_map"])["source_map"] + source_map = compile_code(TEST_CODE, output_formats=["source_map"])["source_map"] pos_map = source_map["pc_pos_map"] jump_map = source_map["pc_jump_map"] @@ -46,7 +46,7 @@ def test_jump_map(): def test_pos_map_offsets(): - source_map = compile_code(TEST_CODE, ["source_map"])["source_map"] + source_map = compile_code(TEST_CODE, output_formats=["source_map"])["source_map"] expanded = expand_source_map(source_map["pc_pos_map_compressed"]) pc_iter = iter(source_map["pc_pos_map"][i] for i in sorted(source_map["pc_pos_map"])) @@ -76,7 +76,7 @@ def test_error_map(): def update_foo(): self.foo += 1 """ - error_map = compile_code(code, ["source_map"])["source_map"]["error_map"] + error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] assert "safeadd" in list(error_map.values()) assert "fallback function" in list(error_map.values()) diff --git a/tests/conftest.py b/tests/conftest.py index c9d3f794a0..9b10b7c51c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from vyper import compiler from vyper.codegen.ir_node import IRnode +from vyper.compiler.input_bundle import FilesystemInputBundle from vyper.compiler.settings import OptimizationLevel, _set_debug_mode from vyper.ir import compile_ir, optimizer @@ -70,6 +71,34 @@ def keccak(): return Web3.keccak +@pytest.fixture +def make_file(tmp_path): + # writes file_contents to file_name, creating it in the + # tmp_path directory. returns final path. + def fn(file_name, file_contents): + path = tmp_path / file_name + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + f.write(file_contents) + + return path + + return fn + + +# this can either be used for its side effects (to prepare a call +# to get_contract), or the result can be provided directly to +# compile_code / CompilerData. +@pytest.fixture +def make_input_bundle(tmp_path, make_file): + def fn(sources_dict): + for file_name, file_contents in sources_dict.items(): + make_file(file_name, file_contents) + return FilesystemInputBundle([tmp_path]) + + return fn + + @pytest.fixture def bytes_helper(): def bytes_helper(str, length): diff --git a/tests/parser/ast_utils/test_ast_dict.py b/tests/parser/ast_utils/test_ast_dict.py index f483d0cbe8..1f60c9ac8b 100644 --- a/tests/parser/ast_utils/test_ast_dict.py +++ b/tests/parser/ast_utils/test_ast_dict.py @@ -19,7 +19,7 @@ def get_node_ids(ast_struct, ids=None): elif v is None or isinstance(v, (str, int)): continue else: - raise Exception("Unknown ast_struct provided.") + raise Exception(f"Unknown ast_struct provided. {k}, {v}") return ids @@ -30,7 +30,7 @@ def test() -> int128: a: uint256 = 100 return 123 """ - dict_out = compiler.compile_code(code, ["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) node_ids = get_node_ids(dict_out) assert len(node_ids) == len(set(node_ids)) @@ -40,7 +40,7 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, ["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) assert dict_out["ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", @@ -89,7 +89,7 @@ def foo() -> uint256: view def foo() -> uint256: return 1 """ - dict_out = compiler.compile_code(code, ["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) assert dict_out["ast_dict"]["ast"]["body"][1] == { "col_offset": 0, "annotation": { diff --git a/tests/parser/features/test_init.py b/tests/parser/features/test_init.py index 83bcbc95ea..29a466e869 100644 --- a/tests/parser/features/test_init.py +++ b/tests/parser/features/test_init.py @@ -15,7 +15,7 @@ def __init__(a: uint256): assert c.val() == 123 # Make sure the init code does not access calldata - assembly = vyper.compile_code(code, ["asm"])["asm"].split(" ") + assembly = vyper.compile_code(code, output_formats=["asm"])["asm"].split(" ") ir_return_idx_start = assembly.index("{") ir_return_idx_end = assembly.index("}") diff --git a/tests/parser/functions/test_bitwise.py b/tests/parser/functions/test_bitwise.py index 3ba74034ac..1d62a5be79 100644 --- a/tests/parser/functions/test_bitwise.py +++ b/tests/parser/functions/test_bitwise.py @@ -32,7 +32,7 @@ def _shr(x: uint256, y: uint256) -> uint256: def test_bitwise_opcodes(): - opcodes = compile_code(code, ["opcodes"])["opcodes"] + opcodes = compile_code(code, output_formats=["opcodes"])["opcodes"] assert "SHL" in opcodes assert "SHR" in opcodes diff --git a/tests/parser/functions/test_interfaces.py b/tests/parser/functions/test_interfaces.py index c16e188cfd..8cb0124f29 100644 --- a/tests/parser/functions/test_interfaces.py +++ b/tests/parser/functions/test_interfaces.py @@ -1,10 +1,15 @@ +import json from decimal import Decimal import pytest -from vyper.cli.utils import extract_file_interface_imports -from vyper.compiler import compile_code, compile_codes -from vyper.exceptions import ArgumentException, InterfaceViolation, StructureException +from vyper.compiler import compile_code +from vyper.exceptions import ( + ArgumentException, + InterfaceViolation, + NamespaceCollision, + StructureException, +) def test_basic_extract_interface(): @@ -24,7 +29,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): return 1, 2 """ - out = compile_code(code, ["interface"]) + out = compile_code(code, output_formats=["interface"]) out = out["interface"] code_pass = "\n".join(code.split("\n")[:-2] + [" pass"]) # replace with a pass statement. @@ -55,8 +60,9 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): view def test(_owner: address): nonpayable """ - out = compile_codes({"one.vy": code}, ["external_interface"])["one.vy"] - out = out["external_interface"] + out = compile_code(code, contract_name="One.vy", output_formats=["external_interface"])[ + "external_interface" + ] assert interface.strip() == out.strip() @@ -75,7 +81,7 @@ def test() -> bool: assert_compile_failed(lambda: compile_code(code), InterfaceViolation) -def test_external_interface_parsing(assert_compile_failed): +def test_external_interface_parsing(make_input_bundle, assert_compile_failed): interface_code = """ @external def foo() -> uint256: @@ -86,7 +92,7 @@ def bar() -> uint256: pass """ - interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + input_bundle = make_input_bundle({"a.vy": interface_code}) code = """ import a as FooBarInterface @@ -102,7 +108,7 @@ def bar() -> uint256: return 2 """ - assert compile_code(code, interface_codes=interface_codes) + assert compile_code(code, input_bundle=input_bundle) not_implemented_code = """ import a as FooBarInterface @@ -116,18 +122,17 @@ def foo() -> uint256: """ assert_compile_failed( - lambda: compile_code(not_implemented_code, interface_codes=interface_codes), - InterfaceViolation, + lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation ) -def test_missing_event(assert_compile_failed): +def test_missing_event(make_input_bundle, assert_compile_failed): interface_code = """ event Foo: a: uint256 """ - interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + input_bundle = make_input_bundle({"a.vy": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -140,19 +145,18 @@ def bar() -> uint256: """ assert_compile_failed( - lambda: compile_code(not_implemented_code, interface_codes=interface_codes), - InterfaceViolation, + lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation ) # check that event types match -def test_malformed_event(assert_compile_failed): +def test_malformed_event(make_input_bundle, assert_compile_failed): interface_code = """ event Foo: a: uint256 """ - interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + input_bundle = make_input_bundle({"a.vy": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -168,19 +172,18 @@ def bar() -> uint256: """ assert_compile_failed( - lambda: compile_code(not_implemented_code, interface_codes=interface_codes), - InterfaceViolation, + lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation ) # check that event non-indexed arg needs to match interface -def test_malformed_events_indexed(assert_compile_failed): +def test_malformed_events_indexed(make_input_bundle, assert_compile_failed): interface_code = """ event Foo: a: uint256 """ - interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + input_bundle = make_input_bundle({"a.vy": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -197,19 +200,18 @@ def bar() -> uint256: """ assert_compile_failed( - lambda: compile_code(not_implemented_code, interface_codes=interface_codes), - InterfaceViolation, + lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation ) # check that event indexed arg needs to match interface -def test_malformed_events_indexed2(assert_compile_failed): +def test_malformed_events_indexed2(make_input_bundle, assert_compile_failed): interface_code = """ event Foo: a: indexed(uint256) """ - interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + input_bundle = make_input_bundle({"a.vy": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -226,43 +228,47 @@ def bar() -> uint256: """ assert_compile_failed( - lambda: compile_code(not_implemented_code, interface_codes=interface_codes), - InterfaceViolation, + lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation ) VALID_IMPORT_CODE = [ # import statement, import path without suffix - ("import a as Foo", "a"), - ("import b.a as Foo", "b/a"), - ("import Foo as Foo", "Foo"), - ("from a import Foo", "a/Foo"), - ("from b.a import Foo", "b/a/Foo"), - ("from .a import Foo", "./a/Foo"), - ("from ..a import Foo", "../a/Foo"), + ("import a as Foo", "a.vy"), + ("import b.a as Foo", "b/a.vy"), + ("import Foo as Foo", "Foo.vy"), + ("from a import Foo", "a/Foo.vy"), + ("from b.a import Foo", "b/a/Foo.vy"), + ("from .a import Foo", "./a/Foo.vy"), + ("from ..a import Foo", "../a/Foo.vy"), ] -@pytest.mark.parametrize("code", VALID_IMPORT_CODE) -def test_extract_file_interface_imports(code): - assert extract_file_interface_imports(code[0]) == {"Foo": code[1]} +@pytest.mark.parametrize("code,filename", VALID_IMPORT_CODE) +def test_extract_file_interface_imports(code, filename, make_input_bundle): + input_bundle = make_input_bundle({filename: ""}) + + assert compile_code(code, input_bundle=input_bundle) is not None BAD_IMPORT_CODE = [ - "import a", # must alias absolute imports - "import a as A\nimport a as A", # namespace collisions - "from b import a\nfrom a import a", - "from . import a\nimport a as a", - "import a as a\nfrom . import a", + ("import a", StructureException), # must alias absolute imports + ("import a as A\nimport a as A", NamespaceCollision), + ("from b import a\nfrom . import a", NamespaceCollision), + ("from . import a\nimport a as a", NamespaceCollision), + ("import a as a\nfrom . import a", NamespaceCollision), ] -@pytest.mark.parametrize("code", BAD_IMPORT_CODE) -def test_extract_file_interface_imports_raises(code, assert_compile_failed): - assert_compile_failed(lambda: extract_file_interface_imports(code), StructureException) +@pytest.mark.parametrize("code,exception_type", BAD_IMPORT_CODE) +def test_extract_file_interface_imports_raises( + code, exception_type, assert_compile_failed, make_input_bundle +): + input_bundle = make_input_bundle({"a.vy": "", "b/a.vy": ""}) # dummy + assert_compile_failed(lambda: compile_code(code, input_bundle=input_bundle), exception_type) -def test_external_call_to_interface(w3, get_contract): +def test_external_call_to_interface(w3, get_contract, make_input_bundle): token_code = """ balanceOf: public(HashMap[address, uint256]) @@ -271,6 +277,8 @@ def transfer(to: address, _value: uint256): self.balanceOf[to] += _value """ + input_bundle = make_input_bundle({"one.vy": token_code}) + code = """ import one as TokenCode @@ -292,9 +300,7 @@ def test(): """ erc20 = get_contract(token_code) - test_c = get_contract( - code, *[erc20.address], interface_codes={"TokenCode": {"type": "vyper", "code": token_code}} - ) + test_c = get_contract(code, *[erc20.address], input_bundle=input_bundle) sender = w3.eth.accounts[0] assert erc20.balanceOf(sender) == 0 @@ -313,7 +319,7 @@ def test(): ("epsilon(decimal)", "decimal", Decimal("1E-10")), ], ) -def test_external_call_to_interface_kwarg(get_contract, kwarg, typ, expected): +def test_external_call_to_interface_kwarg(get_contract, kwarg, typ, expected, make_input_bundle): code_a = f""" @external @view @@ -321,6 +327,8 @@ def foo(_max: {typ} = {kwarg}) -> {typ}: return _max """ + input_bundle = make_input_bundle({"one.vy": code_a}) + code_b = f""" import one as ContractA @@ -331,11 +339,7 @@ def bar(a_address: address) -> {typ}: """ contract_a = get_contract(code_a) - contract_b = get_contract( - code_b, - *[contract_a.address], - interface_codes={"ContractA": {"type": "vyper", "code": code_a}}, - ) + contract_b = get_contract(code_b, *[contract_a.address], input_bundle=input_bundle) assert contract_b.bar(contract_a.address) == expected @@ -368,9 +372,7 @@ def test(): """ erc20 = get_contract(token_code) - test_c = get_contract( - code, *[erc20.address], interface_codes={"TokenCode": {"type": "vyper", "code": token_code}} - ) + test_c = get_contract(code, *[erc20.address]) sender = w3.eth.accounts[0] assert erc20.balanceOf(sender) == 0 @@ -440,11 +442,7 @@ def test_fail3() -> int256: """ bad_c = get_contract(external_contract) - c = get_contract( - code, - bad_c.address, - interface_codes={"BadCode": {"type": "vyper", "code": external_contract}}, - ) + c = get_contract(code, bad_c.address) assert bad_c.ok() == 1 assert bad_c.should_fail() == -(2**255) @@ -502,7 +500,9 @@ def test_fail2() -> Bytes[3]: # test data returned from external interface gets clamped -def test_json_abi_bytes_clampers(get_contract, assert_tx_failed, assert_compile_failed): +def test_json_abi_bytes_clampers( + get_contract, assert_tx_failed, assert_compile_failed, make_input_bundle +): external_contract = """ @external def returns_Bytes3() -> Bytes[3]: @@ -546,18 +546,15 @@ def test_fail3() -> Bytes[3]: """ bad_c = get_contract(external_contract) - bad_c_interface = { - "BadJSONInterface": { - "type": "json", - "code": compile_code(external_contract, ["abi"])["abi"], - } - } + + bad_json_interface = json.dumps(compile_code(external_contract, output_formats=["abi"])["abi"]) + input_bundle = make_input_bundle({"BadJSONInterface.json": bad_json_interface}) assert_compile_failed( - lambda: get_contract(should_not_compile, interface_codes=bad_c_interface), ArgumentException + lambda: get_contract(should_not_compile, input_bundle=input_bundle), ArgumentException ) - c = get_contract(code, bad_c.address, interface_codes=bad_c_interface) + c = get_contract(code, bad_c.address, input_bundle=input_bundle) assert bad_c.returns_Bytes3() == b"123" assert_tx_failed(lambda: c.test_fail1()) @@ -565,7 +562,7 @@ def test_fail3() -> Bytes[3]: assert_tx_failed(lambda: c.test_fail3()) -def test_units_interface(w3, get_contract): +def test_units_interface(w3, get_contract, make_input_bundle): code = """ import balanceof as BalanceOf @@ -576,49 +573,41 @@ def test_units_interface(w3, get_contract): def balanceOf(owner: address) -> uint256: return as_wei_value(1, "ether") """ + interface_code = """ @external @view def balanceOf(owner: address) -> uint256: pass """ - interface_codes = {"BalanceOf": {"type": "vyper", "code": interface_code}} - c = get_contract(code, interface_codes=interface_codes) + + input_bundle = make_input_bundle({"balanceof.vy": interface_code}) + + c = get_contract(code, input_bundle=input_bundle) assert c.balanceOf(w3.eth.accounts[0]) == w3.to_wei(1, "ether") -def test_local_and_global_interface_namespaces(): +def test_simple_implements(make_input_bundle): interface_code = """ @external def foo() -> uint256: pass """ - global_interface_codes = { - "FooInterface": {"type": "vyper", "code": interface_code}, - "BarInterface": {"type": "vyper", "code": interface_code}, - } - local_interface_codes = { - "FooContract": {"FooInterface": {"type": "vyper", "code": interface_code}}, - "BarContract": {"BarInterface": {"type": "vyper", "code": interface_code}}, - } - code = """ -import a as {0} +import a as FooInterface -implements: {0} +implements: FooInterface @external def foo() -> uint256: return 1 """ - codes = {"FooContract": code.format("FooInterface"), "BarContract": code.format("BarInterface")} + input_bundle = make_input_bundle({"a.vy": interface_code}) - global_compiled = compile_codes(codes, interface_codes=global_interface_codes) - local_compiled = compile_codes(codes, interface_codes=local_interface_codes) - assert global_compiled == local_compiled + assert compile_code(code, input_bundle=input_bundle) is not None def test_self_interface_is_allowed(get_contract): @@ -724,20 +713,28 @@ def convert_v1_abi(abi): @pytest.mark.parametrize("type_str", [i[0] for i in type_str_params]) -def test_json_interface_implements(type_str): +def test_json_interface_implements(type_str, make_input_bundle, make_file): code = interface_test_code.format(type_str) - abi = compile_code(code, ["abi"])["abi"] + abi = compile_code(code, output_formats=["abi"])["abi"] + code = f"import jsonabi as jsonabi\nimplements: jsonabi\n{code}" - compile_code(code, interface_codes={"jsonabi": {"type": "json", "code": abi}}) - compile_code(code, interface_codes={"jsonabi": {"type": "json", "code": convert_v1_abi(abi)}}) + + input_bundle = make_input_bundle({"jsonabi.json": json.dumps(abi)}) + + compile_code(code, input_bundle=input_bundle) + + # !!! overwrite the file + make_file("jsonabi.json", json.dumps(convert_v1_abi(abi))) + + compile_code(code, input_bundle=input_bundle) @pytest.mark.parametrize("type_str,value", type_str_params) -def test_json_interface_calls(get_contract, type_str, value): +def test_json_interface_calls(get_contract, type_str, value, make_input_bundle, make_file): code = interface_test_code.format(type_str) - abi = compile_code(code, ["abi"])["abi"] + abi = compile_code(code, output_formats=["abi"])["abi"] c1 = get_contract(code) code = f""" @@ -748,9 +745,11 @@ def test_json_interface_calls(get_contract, type_str, value): def test_call(a: address, b: {type_str}) -> {type_str}: return jsonabi(a).test_json(b) """ - c2 = get_contract(code, interface_codes={"jsonabi": {"type": "json", "code": abi}}) + input_bundle = make_input_bundle({"jsonabi.json": json.dumps(abi)}) + + c2 = get_contract(code, input_bundle=input_bundle) assert c2.test_call(c1.address, value) == value - c3 = get_contract( - code, interface_codes={"jsonabi": {"type": "json", "code": convert_v1_abi(abi)}} - ) + + make_file("jsonabi.json", json.dumps(convert_v1_abi(abi))) + c3 = get_contract(code, input_bundle=input_bundle) assert c3.test_call(c1.address, value) == value diff --git a/tests/parser/functions/test_raw_call.py b/tests/parser/functions/test_raw_call.py index 81efe64a18..5bb23447e4 100644 --- a/tests/parser/functions/test_raw_call.py +++ b/tests/parser/functions/test_raw_call.py @@ -274,8 +274,8 @@ def test_raw_call(_target: address): def test_raw_call(_target: address): raw_call(_target, method_id("foo()"), max_outsize=0) """ - output1 = compile_code(code1, ["bytecode", "bytecode_runtime"]) - output2 = compile_code(code2, ["bytecode", "bytecode_runtime"]) + output1 = compile_code(code1, output_formats=["bytecode", "bytecode_runtime"]) + output2 = compile_code(code2, output_formats=["bytecode", "bytecode_runtime"]) assert output1 == output2 @@ -296,8 +296,8 @@ def test_raw_call(_target: address) -> bool: a: bool = raw_call(_target, method_id("foo()"), max_outsize=0, revert_on_failure=False) return a """ - output1 = compile_code(code1, ["bytecode", "bytecode_runtime"]) - output2 = compile_code(code2, ["bytecode", "bytecode_runtime"]) + output1 = compile_code(code1, output_formats=["bytecode", "bytecode_runtime"]) + output2 = compile_code(code2, output_formats=["bytecode", "bytecode_runtime"]) assert output1 == output2 diff --git a/tests/parser/functions/test_return_struct.py b/tests/parser/functions/test_return_struct.py index 425caedb75..cdd8342d8a 100644 --- a/tests/parser/functions/test_return_struct.py +++ b/tests/parser/functions/test_return_struct.py @@ -17,7 +17,7 @@ def test() -> Voter: return a """ - out = compile_code(code, ["abi"]) + out = compile_code(code, output_formats=["abi"]) abi = out["abi"][0] assert abi["name"] == "test" @@ -38,7 +38,7 @@ def test() -> Voter: return a """ - out = compile_code(code, ["abi"]) + out = compile_code(code, output_formats=["abi"]) abi = out["abi"][0] assert abi["name"] == "test" diff --git a/tests/parser/syntax/test_codehash.py b/tests/parser/syntax/test_codehash.py index 5074d14636..c2d9a2e274 100644 --- a/tests/parser/syntax/test_codehash.py +++ b/tests/parser/syntax/test_codehash.py @@ -33,7 +33,7 @@ def foo4() -> bytes32: return self.a.codehash """ settings = Settings(evm_version=evm_version, optimize=optimize) - compiled = compile_code(code, ["bytecode_runtime"], settings=settings) + compiled = compile_code(code, output_formats=["bytecode_runtime"], settings=settings) bytecode = bytes.fromhex(compiled["bytecode_runtime"][2:]) hash_ = keccak256(bytecode) diff --git a/tests/parser/syntax/test_interfaces.py b/tests/parser/syntax/test_interfaces.py index 498f1363d8..9100389dbd 100644 --- a/tests/parser/syntax/test_interfaces.py +++ b/tests/parser/syntax/test_interfaces.py @@ -374,7 +374,7 @@ def test_interfaces_success(good_code): assert compiler.compile_code(good_code) is not None -def test_imports_and_implements_within_interface(): +def test_imports_and_implements_within_interface(make_input_bundle): interface_code = """ from vyper.interfaces import ERC20 import foo.bar as Baz @@ -386,6 +386,8 @@ def foobar(): pass """ + input_bundle = make_input_bundle({"foo.vy": interface_code}) + code = """ import foo as Foo @@ -396,9 +398,4 @@ def foobar(): pass """ - assert ( - compiler.compile_code( - code, interface_codes={"Foo": {"type": "vyper", "code": interface_code}} - ) - is not None - ) + assert compiler.compile_code(code, input_bundle=input_bundle) is not None diff --git a/tests/parser/syntax/test_self_balance.py b/tests/parser/syntax/test_self_balance.py index 63db58e347..d22d8a2750 100644 --- a/tests/parser/syntax/test_self_balance.py +++ b/tests/parser/syntax/test_self_balance.py @@ -20,7 +20,7 @@ def __default__(): pass """ settings = Settings(evm_version=evm_version) - opcodes = compiler.compile_code(code, ["opcodes"], settings=settings)["opcodes"] + opcodes = compiler.compile_code(code, output_formats=["opcodes"], settings=settings)["opcodes"] if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["istanbul"]: assert "SELFBALANCE" in opcodes else: diff --git a/tests/parser/test_selector_table_stability.py b/tests/parser/test_selector_table_stability.py index abc2c17b8f..3302ff5009 100644 --- a/tests/parser/test_selector_table_stability.py +++ b/tests/parser/test_selector_table_stability.py @@ -8,7 +8,9 @@ def test_dense_jumptable_stability(): code = "\n".join(f"@external\ndef {name}():\n pass" for name in function_names) - output = compile_code(code, ["asm"], settings=Settings(optimize=OptimizationLevel.CODESIZE)) + output = compile_code( + code, output_formats=["asm"], settings=Settings(optimize=OptimizationLevel.CODESIZE) + ) # test that the selector table data is stable across different runs # (tox should provide different PYTHONHASHSEEDs). diff --git a/tests/parser/types/numbers/test_constants.py b/tests/parser/types/numbers/test_constants.py index 652c8e8bd9..25617651ec 100644 --- a/tests/parser/types/numbers/test_constants.py +++ b/tests/parser/types/numbers/test_constants.py @@ -206,7 +206,7 @@ def test() -> uint256: return ret """ - ir = compile_code(code, ["ir"])["ir"] + ir = compile_code(code, output_formats=["ir"])["ir"] assert search_for_sublist( ir, ["mstore", [MemoryPositions.RESERVED_MEMORY], [2**12 * some_prime]] ) diff --git a/vyper/__init__.py b/vyper/__init__.py index 35237bd044..482d5c3a60 100644 --- a/vyper/__init__.py +++ b/vyper/__init__.py @@ -1,6 +1,6 @@ from pathlib import Path as _Path -from vyper.compiler import compile_code, compile_codes # noqa: F401 +from vyper.compiler import compile_code # noqa: F401 try: from importlib.metadata import PackageNotFoundError # type: ignore diff --git a/vyper/builtins/interfaces/ERC165.py b/vyper/builtins/interfaces/ERC165.vy similarity index 75% rename from vyper/builtins/interfaces/ERC165.py rename to vyper/builtins/interfaces/ERC165.vy index 0a75431f3c..a4ca451abd 100644 --- a/vyper/builtins/interfaces/ERC165.py +++ b/vyper/builtins/interfaces/ERC165.vy @@ -1,6 +1,4 @@ -interface_code = """ @view @external def supportsInterface(interface_id: bytes4) -> bool: pass -""" diff --git a/vyper/builtins/interfaces/ERC20.py b/vyper/builtins/interfaces/ERC20.vy similarity index 96% rename from vyper/builtins/interfaces/ERC20.py rename to vyper/builtins/interfaces/ERC20.vy index a63408672b..065ca97a9b 100644 --- a/vyper/builtins/interfaces/ERC20.py +++ b/vyper/builtins/interfaces/ERC20.vy @@ -1,4 +1,3 @@ -interface_code = """ # Events event Transfer: _from: indexed(address) @@ -37,4 +36,3 @@ def transferFrom(_from: address, _to: address, _value: uint256) -> bool: @external def approve(_spender: address, _value: uint256) -> bool: pass -""" diff --git a/vyper/builtins/interfaces/ERC20Detailed.py b/vyper/builtins/interfaces/ERC20Detailed.py deleted file mode 100644 index 03dd597e8a..0000000000 --- a/vyper/builtins/interfaces/ERC20Detailed.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -NOTE: interface uses `String[1]` where 1 is the lower bound of the string returned by the function. - For end-users this means they can't use `implements: ERC20Detailed` unless their implementation - uses a value n >= 1. Regardless this is fine as one can't do String[0] where n == 0. -""" - -interface_code = """ -@view -@external -def name() -> String[1]: - pass - -@view -@external -def symbol() -> String[1]: - pass - -@view -@external -def decimals() -> uint8: - pass -""" diff --git a/vyper/builtins/interfaces/ERC20Detailed.vy b/vyper/builtins/interfaces/ERC20Detailed.vy new file mode 100644 index 0000000000..7c4f546d45 --- /dev/null +++ b/vyper/builtins/interfaces/ERC20Detailed.vy @@ -0,0 +1,18 @@ +#NOTE: interface uses `String[1]` where 1 is the lower bound of the string returned by the function. +# For end-users this means they can't use `implements: ERC20Detailed` unless their implementation +# uses a value n >= 1. Regardless this is fine as one can't do String[0] where n == 0. + +@view +@external +def name() -> String[1]: + pass + +@view +@external +def symbol() -> String[1]: + pass + +@view +@external +def decimals() -> uint8: + pass diff --git a/vyper/builtins/interfaces/ERC4626.py b/vyper/builtins/interfaces/ERC4626.vy similarity index 98% rename from vyper/builtins/interfaces/ERC4626.py rename to vyper/builtins/interfaces/ERC4626.vy index 21a9ce723a..05865406cf 100644 --- a/vyper/builtins/interfaces/ERC4626.py +++ b/vyper/builtins/interfaces/ERC4626.vy @@ -1,4 +1,3 @@ -interface_code = """ # Events event Deposit: sender: indexed(address) @@ -89,4 +88,3 @@ def previewRedeem(shares: uint256) -> uint256: @external def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: pass -""" diff --git a/vyper/builtins/interfaces/ERC721.py b/vyper/builtins/interfaces/ERC721.vy similarity index 97% rename from vyper/builtins/interfaces/ERC721.py rename to vyper/builtins/interfaces/ERC721.vy index 8dea4e4976..464c0e255b 100644 --- a/vyper/builtins/interfaces/ERC721.py +++ b/vyper/builtins/interfaces/ERC721.vy @@ -1,4 +1,3 @@ -interface_code = """ # Events event Transfer: @@ -66,5 +65,3 @@ def approve(_approved: address, _tokenId: uint256): @external def setApprovalForAll(_operator: address, _approved: bool): pass - -""" diff --git a/vyper/builtins/interfaces/__init__.py b/vyper/builtins/interfaces/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/vyper/cli/utils.py b/vyper/cli/utils.py deleted file mode 100644 index 1110ecdfdd..0000000000 --- a/vyper/cli/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -from pathlib import Path -from typing import Sequence - -from vyper import ast as vy_ast -from vyper.exceptions import StructureException -from vyper.typing import InterfaceImports, SourceCode - - -def get_interface_file_path(base_paths: Sequence, import_path: str) -> Path: - relative_path = Path(import_path) - for path in base_paths: - # Find ABI JSON files - file_path = path.joinpath(relative_path) - suffix = next((i for i in (".vy", ".json") if file_path.with_suffix(i).exists()), None) - if suffix: - return file_path.with_suffix(suffix) - - # Find ethPM Manifest files (`from path.to.Manifest import InterfaceName`) - # NOTE: Use file parent because this assumes that `file_path` - # coincides with an ABI interface file - file_path = file_path.parent - suffix = next((i for i in (".vy", ".json") if file_path.with_suffix(i).exists()), None) - if suffix: - return file_path.with_suffix(suffix) - - raise FileNotFoundError(f" Cannot locate interface '{import_path}{{.vy,.json}}'") - - -def extract_file_interface_imports(code: SourceCode) -> InterfaceImports: - ast_tree = vy_ast.parse_to_ast(code) - - imports_dict: InterfaceImports = {} - for node in ast_tree.get_children((vy_ast.Import, vy_ast.ImportFrom)): - if isinstance(node, vy_ast.Import): # type: ignore - if not node.alias: - raise StructureException("Import requires an accompanying `as` statement", node) - if node.alias in imports_dict: - raise StructureException(f"Interface with alias {node.alias} already exists", node) - imports_dict[node.alias] = node.name.replace(".", "/") - elif isinstance(node, vy_ast.ImportFrom): # type: ignore - level = node.level # type: ignore - module = node.module or "" # type: ignore - if not level and module == "vyper.interfaces": - # uses a builtin interface, so skip adding to imports - continue - - base_path = "" - if level > 1: - base_path = "../" * (level - 1) - elif level == 1: - base_path = "./" - base_path = f"{base_path}{module.replace('.','/')}/" - - if node.name in imports_dict and imports_dict[node.name] != f"{base_path}{node.name}": - raise StructureException(f"Interface with name {node.name} already exists", node) - imports_dict[node.name] = f"{base_path}{node.name}" - - return imports_dict diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index bdd01eebbe..c4f60660cb 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -3,14 +3,13 @@ import json import sys import warnings -from collections import OrderedDict from pathlib import Path -from typing import Dict, Iterable, Iterator, Optional, Set, TypeVar +from typing import Any, Iterable, Iterator, Optional, Set, TypeVar import vyper import vyper.codegen.ir_node as ir_node from vyper.cli import vyper_json -from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path +from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle from vyper.compiler.settings import ( VYPER_TRACEBACK_LIMIT, OptimizationLevel, @@ -18,7 +17,7 @@ _set_debug_mode, ) from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS -from vyper.typing import ContractCodes, ContractPath, OutputFormats +from vyper.typing import ContractPath, OutputFormats T = TypeVar("T") @@ -219,94 +218,20 @@ def exc_handler(contract_path: ContractPath, exception: Exception) -> None: raise exception -def get_interface_codes(root_path: Path, contract_sources: ContractCodes) -> Dict: - interface_codes: Dict = {} - interfaces: Dict = {} - - for file_path, code in contract_sources.items(): - interfaces[file_path] = {} - parent_path = root_path.joinpath(file_path).parent - - interface_codes = extract_file_interface_imports(code) - for interface_name, interface_path in interface_codes.items(): - base_paths = [parent_path] - if not interface_path.startswith(".") and root_path.joinpath(file_path).exists(): - base_paths.append(root_path) - elif interface_path.startswith("../") and len(Path(file_path).parent.parts) < Path( - interface_path - ).parts.count(".."): - raise FileNotFoundError( - f"{file_path} - Cannot perform relative import outside of base folder" - ) - - valid_path = get_interface_file_path(base_paths, interface_path) - with valid_path.open() as fh: - code = fh.read() - if valid_path.suffix == ".json": - contents = json.loads(code.encode()) - - # EthPM Manifest (EIP-2678) - if "contractTypes" in contents: - if ( - interface_name not in contents["contractTypes"] - or "abi" not in contents["contractTypes"][interface_name] - ): - raise ValueError( - f"Could not find interface '{interface_name}'" - f" in manifest '{valid_path}'." - ) - - interfaces[file_path][interface_name] = { - "type": "json", - "code": contents["contractTypes"][interface_name]["abi"], - } - - # ABI JSON file (either `List[ABI]` or `{"abi": List[ABI]}`) - elif isinstance(contents, list) or ( - "abi" in contents and isinstance(contents["abi"], list) - ): - interfaces[file_path][interface_name] = {"type": "json", "code": contents} - - else: - raise ValueError(f"Corrupted file: '{valid_path}'") - - else: - interfaces[file_path][interface_name] = {"type": "vyper", "code": code} - - return interfaces - - def compile_files( - input_files: Iterable[str], + input_files: list[str], output_formats: OutputFormats, root_folder: str = ".", show_gas_estimates: bool = False, settings: Optional[Settings] = None, - storage_layout: Optional[Iterable[str]] = None, + storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, -) -> OrderedDict: +) -> dict: root_path = Path(root_folder).resolve() if not root_path.exists(): raise FileNotFoundError(f"Invalid root path - '{root_path.as_posix()}' does not exist") - contract_sources: ContractCodes = OrderedDict() - for file_name in input_files: - file_path = Path(file_name) - try: - file_str = file_path.resolve().relative_to(root_path).as_posix() - except ValueError: - file_str = file_path.as_posix() - with file_path.open() as fh: - # trailing newline fixes python parsing bug when source ends in a comment - # https://bugs.python.org/issue35107 - contract_sources[file_str] = fh.read() + "\n" - - storage_layouts = OrderedDict() - if storage_layout: - for storage_file_name, contract_name in zip(storage_layout, contract_sources.keys()): - storage_file_path = Path(storage_file_name) - with storage_file_path.open() as sfh: - storage_layouts[contract_name] = json.load(sfh) + input_bundle = FilesystemInputBundle([root_path]) show_version = False if "combined_json" in output_formats: @@ -318,20 +243,44 @@ def compile_files( translate_map = {"abi_python": "abi", "json": "abi", "ast": "ast_dict", "ir_json": "ir_dict"} final_formats = [translate_map.get(i, i) for i in output_formats] - compiler_data = vyper.compile_codes( - contract_sources, - final_formats, - exc_handler=exc_handler, - interface_codes=get_interface_codes(root_path, contract_sources), - settings=settings, - storage_layouts=storage_layouts, - show_gas_estimates=show_gas_estimates, - no_bytecode_metadata=no_bytecode_metadata, - ) + if storage_layout_paths: + if len(storage_layout_paths) != len(input_files): + raise ValueError( + "provided {len(storage_layout_paths)} storage " + "layouts, but {len(input_files)} source files" + ) + + ret: dict[Any, Any] = {} if show_version: - compiler_data["version"] = vyper.__version__ + ret["version"] = vyper.__version__ - return compiler_data + for file_name in input_files: + file_path = Path(file_name) + file = input_bundle.load_file(file_path) + assert isinstance(file, FileInput) # mypy hint + + storage_layout_override = None + if storage_layout_paths: + storage_file_path = storage_layout_paths.pop(0) + with open(storage_file_path) as sfh: + storage_layout_override = json.load(sfh) + + output = vyper.compile_code( + file.source_code, + contract_name=str(file.path), + source_id=file.source_id, + input_bundle=input_bundle, + output_formats=final_formats, + exc_handler=exc_handler, + settings=settings, + storage_layout_override=storage_layout_override, + show_gas_estimates=show_gas_estimates, + no_bytecode_metadata=no_bytecode_metadata, + ) + + ret[file_path] = output + + return ret if __name__ == "__main__": diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index f6d82c3fe0..2720f20d23 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -4,15 +4,14 @@ import json import sys import warnings -from pathlib import Path -from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union +from pathlib import Path, PurePath +from typing import Any, Callable, Hashable, Optional import vyper -from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path +from vyper.compiler.input_bundle import FileInput, JSONInputBundle from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError -from vyper.typing import ContractCodes, ContractPath from vyper.utils import keccak256 TRANSLATE_MAP = { @@ -97,15 +96,15 @@ def _parse_args(argv): print(output_json) -def exc_handler_raises(file_path: Union[str, None], exception: Exception, component: str) -> None: +def exc_handler_raises(file_path: Optional[str], exception: Exception, component: str) -> None: if file_path: print(f"Unhandled exception in '{file_path}':") exception._exc_handler = True # type: ignore raise exception -def exc_handler_to_dict(file_path: Union[str, None], exception: Exception, component: str) -> Dict: - err_dict: Dict = { +def exc_handler_to_dict(file_path: Optional[str], exception: Exception, component: str) -> dict: + err_dict: dict = { "type": type(exception).__name__, "component": component, "severity": "error", @@ -129,23 +128,7 @@ def exc_handler_to_dict(file_path: Union[str, None], exception: Exception, compo return output_json -def _standardize_path(path_str: str) -> str: - try: - path = Path(path_str) - - if path.is_absolute(): - path = path.resolve() - else: - pwd = Path(".").resolve() - path = path.resolve().relative_to(pwd) - - except ValueError: - raise JSONError(f"{path_str} - path exists outside base folder") - - return path.as_posix() - - -def get_evm_version(input_dict: Dict) -> Optional[str]: +def get_evm_version(input_dict: dict) -> Optional[str]: if "settings" not in input_dict: return None @@ -168,76 +151,75 @@ def get_evm_version(input_dict: Dict) -> Optional[str]: return evm_version -def get_input_dict_contracts(input_dict: Dict) -> ContractCodes: - contract_sources: ContractCodes = {} +def get_compilation_targets(input_dict: dict) -> list[PurePath]: + # TODO: once we have modules, add optional "compilation_targets" key + # which specifies which sources we actually want to compile. + + return [PurePath(p) for p in input_dict["sources"].keys()] + + +def get_inputs(input_dict: dict) -> dict[PurePath, Any]: + ret = {} + seen = {} + for path, value in input_dict["sources"].items(): + path = PurePath(path) if "urls" in value: raise JSONError(f"{path} - 'urls' is not a supported field, use 'content' instead") if "content" not in value: raise JSONError(f"{path} missing required field - 'content'") if "keccak256" in value: - hash_ = value["keccak256"].lower() - if hash_.startswith("0x"): - hash_ = hash_[2:] + hash_ = value["keccak256"].lower().removeprefix("0x") if hash_ != keccak256(value["content"].encode("utf-8")).hex(): raise JSONError( f"Calculated keccak of '{path}' does not match keccak given in input JSON" ) - key = _standardize_path(path) - if key in contract_sources: - raise JSONError(f"Contract namespace collision: {key}") - contract_sources[key] = value["content"] - return contract_sources + if path.stem in seen: + raise JSONError(f"Contract namespace collision: {path}") - -def get_input_dict_interfaces(input_dict: Dict) -> Dict: - interface_sources: Dict = {} + # value looks like {"content": } + # this will be interpreted by JSONInputBundle later + ret[path] = value + seen[path.stem] = True for path, value in input_dict.get("interfaces", {}).items(): - key = _standardize_path(path) - - if key.endswith(".json"): - # EthPM Manifest v3 (EIP-2678) - if "contractTypes" in value: - for name, ct in value["contractTypes"].items(): - if name in interface_sources: - raise JSONError(f"Interface namespace collision: {name}") - - interface_sources[name] = {"type": "json", "code": ct["abi"]} - - continue # Skip to next interface - - # ABI JSON file (`{"abi": List[ABI]}`) - elif "abi" in value: - interface = {"type": "json", "code": value["abi"]} - - # ABI JSON file (`List[ABI]`) - elif isinstance(value, list): - interface = {"type": "json", "code": value} - - else: - raise JSONError(f"Interface '{path}' must have 'abi' field") - - elif key.endswith(".vy"): - if "content" not in value: - raise JSONError(f"Interface '{path}' must have 'content' field") - - interface = {"type": "vyper", "code": value["content"]} - + path = PurePath(path) + if path.stem in seen: + raise JSONError(f"Interface namespace collision: {path}") + + if isinstance(value, list): + # backwards compatibility - straight ABI with no "abi" key. + # (should probably just reject these) + value = {"abi": value} + + # some validation + if not isinstance(value, dict): + raise JSONError("invalid interface (must be a dictionary):\n{json.dumps(value)}") + if "content" in value: + if not isinstance(value["content"], str): + raise JSONError(f"invalid 'content' (expected string):\n{json.dumps(value)}") + elif "abi" in value: + if not isinstance(value["abi"], list): + raise JSONError(f"invalid 'abi' (expected list):\n{json.dumps(value)}") else: - raise JSONError(f"Interface '{path}' must have suffix '.vy' or '.json'") - - key = key.rsplit(".", maxsplit=1)[0] - if key in interface_sources: - raise JSONError(f"Interface namespace collision: {key}") + raise JSONError( + "invalid interface (must contain either 'content' or 'abi'):\n{json.dumps(value)}" + ) + if "content" in value and "abi" in value: + raise JSONError( + "invalid interface (found both 'content' and 'abi'):\n{json.dumps(value)}" + ) - interface_sources[key] = interface + ret[path] = value + seen[path.stem] = True - return interface_sources + return ret -def get_input_dict_output_formats(input_dict: Dict, contract_sources: ContractCodes) -> Dict: - output_formats = {} +# get unique output formats for each contract, given the input_dict +# NOTE: would maybe be nice to raise on duplicated output formats +def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePath, list[str]]: + output_formats: dict[PurePath, list[str]] = {} for path, outputs in input_dict["settings"]["outputSelection"].items(): if isinstance(outputs, dict): # if outputs are given in solc json format, collapse them into a single list @@ -248,6 +230,7 @@ def get_input_dict_output_formats(input_dict: Dict, contract_sources: ContractCo for key in [i for i in ("evm", "evm.bytecode", "evm.deployedBytecode") if i in outputs]: outputs.remove(key) outputs.update([i for i in TRANSLATE_MAP if i.startswith(key)]) + if "*" in outputs: outputs = TRANSLATE_MAP.values() else: @@ -259,107 +242,23 @@ def get_input_dict_output_formats(input_dict: Dict, contract_sources: ContractCo outputs = sorted(set(outputs)) if path == "*": - output_keys = list(contract_sources.keys()) + output_paths = targets else: - output_keys = [_standardize_path(path)] - if output_keys[0] not in contract_sources: - raise JSONError(f"outputSelection references unknown contract '{output_keys[0]}'") + output_paths = [PurePath(path)] + if output_paths[0] not in targets: + raise JSONError(f"outputSelection references unknown contract '{output_paths[0]}'") - for key in output_keys: - output_formats[key] = outputs + for output_path in output_paths: + output_formats[output_path] = outputs return output_formats -def get_interface_codes( - root_path: Union[Path, None], - contract_path: ContractPath, - contract_sources: ContractCodes, - interface_sources: Dict, -) -> Dict: - interface_codes: Dict = {} - interfaces: Dict = {} - - code = contract_sources[contract_path] - interface_codes = extract_file_interface_imports(code) - for interface_name, interface_path in interface_codes.items(): - # If we know the interfaces already (e.g. EthPM Manifest file) - if interface_name in interface_sources: - interfaces[interface_name] = interface_sources[interface_name] - continue - - path = Path(contract_path).parent.joinpath(interface_path).as_posix() - keys = [_standardize_path(path)] - if not interface_path.startswith("."): - keys.append(interface_path) - - key = next((i for i in keys if i in interface_sources), None) - if key: - interfaces[interface_name] = interface_sources[key] - continue - - key = next((i + ".vy" for i in keys if i + ".vy" in contract_sources), None) - if key: - interfaces[interface_name] = {"type": "vyper", "code": contract_sources[key]} - continue - - if root_path is None: - raise FileNotFoundError(f"Cannot locate interface '{interface_path}{{.vy,.json}}'") - - parent_path = root_path.joinpath(contract_path).parent - base_paths = [parent_path] - if not interface_path.startswith("."): - base_paths.append(root_path) - elif interface_path.startswith("../") and len(Path(contract_path).parent.parts) < Path( - interface_path - ).parts.count(".."): - raise FileNotFoundError( - f"{contract_path} - Cannot perform relative import outside of base folder" - ) - - valid_path = get_interface_file_path(base_paths, interface_path) - with valid_path.open() as fh: - code = fh.read() - if valid_path.suffix == ".json": - code_dict = json.loads(code.encode()) - # EthPM Manifest v3 (EIP-2678) - if "contractTypes" in code_dict: - if interface_name not in code_dict["contractTypes"]: - raise JSONError(f"'{interface_name}' not found in '{valid_path}'") - - if "abi" not in code_dict["contractTypes"][interface_name]: - raise JSONError(f"Missing abi for '{interface_name}' in '{valid_path}'") - - abi = code_dict["contractTypes"][interface_name]["abi"] - interfaces[interface_name] = {"type": "json", "code": abi} - - # ABI JSON (`{"abi": List[ABI]}`) - elif "abi" in code_dict: - interfaces[interface_name] = {"type": "json", "code": code_dict["abi"]} - - # ABI JSON (`List[ABI]`) - elif isinstance(code_dict, list): - interfaces[interface_name] = {"type": "json", "code": code_dict} - - else: - raise JSONError(f"Unexpected type in file: '{valid_path}'") - - else: - interfaces[interface_name] = {"type": "vyper", "code": code} - - return interfaces - - def compile_from_input_dict( - input_dict: Dict, - exc_handler: Callable = exc_handler_raises, - root_folder: Union[str, None] = None, -) -> Tuple[Dict, Dict]: - root_path = None - if root_folder is not None: - root_path = Path(root_folder).resolve() - if not root_path.exists(): - raise FileNotFoundError(f"Invalid root path - '{root_path.as_posix()}' does not exist") + input_dict: dict, exc_handler: Callable = exc_handler_raises, root_folder: Optional[str] = None +) -> tuple[dict, dict]: + if root_folder is None: + root_folder = "." if input_dict["language"] != "Vyper": raise JSONError(f"Invalid language '{input_dict['language']}' - Only Vyper is supported.") @@ -382,46 +281,50 @@ def compile_from_input_dict( no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) - contract_sources: ContractCodes = get_input_dict_contracts(input_dict) - interface_sources = get_input_dict_interfaces(input_dict) - output_formats = get_input_dict_output_formats(input_dict, contract_sources) + compilation_targets = get_compilation_targets(input_dict) + sources = get_inputs(input_dict) + output_formats = get_output_formats(input_dict, compilation_targets) - compiler_data, warning_data = {}, {} + input_bundle = JSONInputBundle(sources, search_paths=[Path(root_folder)]) + + res, warnings_dict = {}, {} warnings.simplefilter("always") - for id_, contract_path in enumerate(sorted(contract_sources)): + for contract_path in compilation_targets: with warnings.catch_warnings(record=True) as caught_warnings: try: - interface_codes = get_interface_codes( - root_path, contract_path, contract_sources, interface_sources - ) - except Exception as exc: - return exc_handler(contract_path, exc, "parser"), {} - try: - data = vyper.compile_codes( - {contract_path: contract_sources[contract_path]}, - output_formats[contract_path], - interface_codes=interface_codes, - initial_id=id_, + # use load_file to get a unique source_id + file = input_bundle.load_file(contract_path) + assert isinstance(file, FileInput) # mypy hint + data = vyper.compile_code( + file.source_code, + contract_name=str(file.path), + input_bundle=input_bundle, + output_formats=output_formats[contract_path], + source_id=file.source_id, settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) + assert isinstance(data, dict) + data["source_id"] = file.source_id except Exception as exc: return exc_handler(contract_path, exc, "compiler"), {} - compiler_data[contract_path] = data[contract_path] + res[contract_path] = data if caught_warnings: - warning_data[contract_path] = caught_warnings + warnings_dict[contract_path] = caught_warnings - return compiler_data, warning_data + return res, warnings_dict -def format_to_output_dict(compiler_data: Dict) -> Dict: - output_dict: Dict = {"compiler": f"vyper-{vyper.__version__}", "contracts": {}, "sources": {}} - for id_, (path, data) in enumerate(compiler_data.items()): - output_dict["sources"][path] = {"id": id_} +# convert output of compile_input_dict to final output format +def format_to_output_dict(compiler_data: dict) -> dict: + output_dict: dict = {"compiler": f"vyper-{vyper.__version__}", "contracts": {}, "sources": {}} + for path, data in compiler_data.items(): + path = str(path) # Path breaks json serializability + output_dict["sources"][path] = {"id": data["source_id"]} if "ast_dict" in data: output_dict["sources"][path]["ast"] = data["ast_dict"]["ast"] - name = Path(path).stem + name = PurePath(path).stem output_dict["contracts"][path] = {name: {}} output_contracts = output_dict["contracts"][path][name] @@ -459,7 +362,7 @@ def format_to_output_dict(compiler_data: Dict) -> Dict: # https://stackoverflow.com/a/49518779 -def _raise_on_duplicate_keys(ordered_pairs: List[Tuple[Hashable, Any]]) -> Dict: +def _raise_on_duplicate_keys(ordered_pairs: list[tuple[Hashable, Any]]) -> dict: """ Raise JSONError if a duplicate key exists in provided ordered list of pairs, otherwise return a dict. @@ -474,17 +377,15 @@ def _raise_on_duplicate_keys(ordered_pairs: List[Tuple[Hashable, Any]]) -> Dict: def compile_json( - input_json: Union[Dict, str], + input_json: dict | str, exc_handler: Callable = exc_handler_raises, - root_path: Union[str, None] = None, - json_path: Union[str, None] = None, -) -> Dict: + root_folder: Optional[str] = None, + json_path: Optional[str] = None, +) -> dict: try: if isinstance(input_json, str): try: - input_dict: Dict = json.loads( - input_json, object_pairs_hook=_raise_on_duplicate_keys - ) + input_dict = json.loads(input_json, object_pairs_hook=_raise_on_duplicate_keys) except json.decoder.JSONDecodeError as exc: new_exc = JSONError(str(exc), exc.lineno, exc.colno) return exc_handler(json_path, new_exc, "json") @@ -492,7 +393,7 @@ def compile_json( input_dict = input_json try: - compiler_data, warn_data = compile_from_input_dict(input_dict, exc_handler, root_path) + compiler_data, warn_data = compile_from_input_dict(input_dict, exc_handler, root_folder) if "errors" in compiler_data: return compiler_data except KeyError as exc: diff --git a/vyper/cli/vyper_serve.py b/vyper/cli/vyper_serve.py index 401e59e7ba..9771dc922d 100755 --- a/vyper/cli/vyper_serve.py +++ b/vyper/cli/vyper_serve.py @@ -91,11 +91,11 @@ def _compile(self, data): try: code = data["code"] - out_dict = vyper.compile_codes( - {"": code}, + out_dict = vyper.compile_code( + code, list(vyper.compiler.OUTPUT_FORMATS.keys()), evm_version=data.get("evm_version", DEFAULT_EVM_VERSION), - )[""] + ) out_dict["ir"] = str(out_dict["ir"]) out_dict["ir_runtime"] = str(out_dict["ir_runtime"]) except VyperException as e: diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index b1c4201361..62ea05b243 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -1,21 +1,15 @@ from collections import OrderedDict +from pathlib import Path from typing import Any, Callable, Dict, Optional, Sequence, Union import vyper.ast as vy_ast # break an import cycle import vyper.codegen.core as codegen import vyper.compiler.output as output +from vyper.compiler.input_bundle import InputBundle, PathLike from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version -from vyper.typing import ( - ContractCodes, - ContractPath, - InterfaceDict, - InterfaceImports, - OutputDict, - OutputFormats, - StorageLayout, -) +from vyper.typing import ContractPath, OutputFormats, StorageLayout OUTPUT_FORMATS = { # requires vyper_module @@ -47,119 +41,25 @@ } -def compile_codes( - contract_sources: ContractCodes, - output_formats: Union[OutputDict, OutputFormats, None] = None, - exc_handler: Union[Callable, None] = None, - interface_codes: Union[InterfaceDict, InterfaceImports, None] = None, - initial_id: int = 0, - settings: Settings = None, - storage_layouts: Optional[dict[ContractPath, Optional[StorageLayout]]] = None, - show_gas_estimates: bool = False, - no_bytecode_metadata: bool = False, -) -> OrderedDict: - """ - Generate compiler output(s) from one or more contract source codes. - - Arguments - --------- - contract_sources: Dict[str, str] - Vyper source codes to be compiled. Formatted as `{"contract name": "source code"}` - output_formats: List, optional - List of compiler outputs to generate. Possible options are all the keys - in `OUTPUT_FORMATS`. If not given, the deployment bytecode is generated. - exc_handler: Callable, optional - Callable used to handle exceptions if the compilation fails. Should accept - two arguments - the name of the contract, and the exception that was raised - initial_id: int, optional - The lowest source ID value to be used when generating the source map. - settings: Settings, optional - Compiler settings - show_gas_estimates: bool, optional - Show gas estimates for abi and ir output modes - interface_codes: Dict, optional - Interfaces that may be imported by the contracts during compilation. - - * May be a singular dictionary shared across all sources to be compiled, - i.e. `{'interface name': "definition"}` - * or may be organized according to contracts that are being compiled, i.e. - `{'contract name': {'interface name': "definition"}` - - * Interface definitions are formatted as: `{'type': "json/vyper", 'code': "interface code"}` - * JSON interfaces are given as lists, vyper interfaces as strings - no_bytecode_metadata: bool, optional - Do not add metadata to bytecode. Defaults to False - - Returns - ------- - Dict - Compiler output as `{'contract name': {'output key': "output data"}}` - """ - settings = settings or Settings() - - if output_formats is None: - output_formats = ("bytecode",) - if isinstance(output_formats, Sequence): - output_formats = dict((k, output_formats) for k in contract_sources.keys()) - - out: OrderedDict = OrderedDict() - for source_id, contract_name in enumerate(sorted(contract_sources), start=initial_id): - source_code = contract_sources[contract_name] - interfaces: Any = interface_codes - storage_layout_override = None - if storage_layouts and contract_name in storage_layouts: - storage_layout_override = storage_layouts[contract_name] - - if ( - isinstance(interfaces, dict) - and contract_name in interfaces - and isinstance(interfaces[contract_name], dict) - ): - interfaces = interfaces[contract_name] - - # make IR output the same between runs - codegen.reset_names() - - compiler_data = CompilerData( - source_code, - contract_name, - interfaces, - source_id, - settings, - storage_layout_override, - show_gas_estimates, - no_bytecode_metadata, - ) - with anchor_evm_version(compiler_data.settings.evm_version): - for output_format in output_formats[contract_name]: - if output_format not in OUTPUT_FORMATS: - raise ValueError(f"Unsupported format type {repr(output_format)}") - try: - out.setdefault(contract_name, {}) - formatter = OUTPUT_FORMATS[output_format] - out[contract_name][output_format] = formatter(compiler_data) - except Exception as exc: - if exc_handler is not None: - exc_handler(contract_name, exc) - else: - raise exc - - return out - - UNKNOWN_CONTRACT_NAME = "" def compile_code( contract_source: str, - output_formats: Optional[OutputFormats] = None, - interface_codes: Optional[InterfaceImports] = None, + contract_name: str = UNKNOWN_CONTRACT_NAME, + source_id: int = 0, + input_bundle: InputBundle = None, settings: Settings = None, + output_formats: Optional[OutputFormats] = None, storage_layout_override: Optional[StorageLayout] = None, + no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, + exc_handler: Optional[Callable] = None, ) -> dict: """ - Generate compiler output(s) from a single contract source code. + Generate consumable compiler output(s) from a single contract source code. + Basically, a wrapper around CompilerData which munges the output + data into the requested output formats. Arguments --------- @@ -175,11 +75,11 @@ def compile_code( Compiler settings. show_gas_estimates: bool, optional Show gas estimates for abi and ir output modes - interface_codes: Dict, optional - Interfaces that may be imported by the contracts during compilation. - - * Formatted as as `{'interface name': {'type': "json/vyper", 'code': "interface code"}}` - * JSON interfaces are given as lists, vyper interfaces as strings + exc_handler: Callable, optional + Callable used to handle exceptions if the compilation fails. Should accept + two arguments - the name of the contract, and the exception that was raised + no_bytecode_metadata: bool, optional + Do not add metadata to bytecode. Defaults to False Returns ------- @@ -187,14 +87,37 @@ def compile_code( Compiler output as `{'output key': "output data"}` """ - contract_sources = {UNKNOWN_CONTRACT_NAME: contract_source} - storage_layouts = {UNKNOWN_CONTRACT_NAME: storage_layout_override} + settings = settings or Settings() + + if output_formats is None: + output_formats = ("bytecode",) - return compile_codes( - contract_sources, - output_formats, - interface_codes=interface_codes, - settings=settings, - storage_layouts=storage_layouts, - show_gas_estimates=show_gas_estimates, - )[UNKNOWN_CONTRACT_NAME] + # make IR output the same between runs + codegen.reset_names() + + compiler_data = CompilerData( + contract_source, + input_bundle, + Path(contract_name), + source_id, + settings, + storage_layout_override, + show_gas_estimates, + no_bytecode_metadata, + ) + + ret = {} + with anchor_evm_version(compiler_data.settings.evm_version): + for output_format in output_formats: + if output_format not in OUTPUT_FORMATS: + raise ValueError(f"Unsupported format type {repr(output_format)}") + try: + formatter = OUTPUT_FORMATS[output_format] + ret[output_format] = formatter(compiler_data) + except Exception as exc: + if exc_handler is not None: + exc_handler(contract_name, exc) + else: + raise exc + + return ret diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py new file mode 100644 index 0000000000..1e41c3f137 --- /dev/null +++ b/vyper/compiler/input_bundle.py @@ -0,0 +1,180 @@ +import contextlib +import json +import os +from dataclasses import dataclass +from pathlib import Path, PurePath +from typing import Any, Iterator, Optional + +from vyper.exceptions import JSONError + +# a type to make mypy happy +PathLike = Path | PurePath + + +@dataclass +class CompilerInput: + # an input to the compiler, basically an abstraction for file contents + source_id: int + path: PathLike + + @staticmethod + def from_string(source_id: int, path: PathLike, file_contents: str) -> "CompilerInput": + try: + s = json.loads(file_contents) + return ABIInput(source_id, path, s) + except (ValueError, TypeError): + return FileInput(source_id, path, file_contents) + + +@dataclass +class FileInput(CompilerInput): + source_code: str + + +@dataclass +class ABIInput(CompilerInput): + # some json input, which has already been parsed into a dict or list + # this is needed because json inputs present json interfaces as json + # objects, not as strings. this class helps us avoid round-tripping + # back to a string to pretend it's a file. + abi: Any # something that json.load() returns + + +class _NotFound(Exception): + pass + + +# wrap os.path.normpath, but return the same type as the input +def _normpath(path): + return path.__class__(os.path.normpath(path)) + + +# an "input bundle" to the compiler, representing the files which are +# available to the compiler. it is useful because it parametrizes I/O +# operations over different possible input types. you can think of it +# as a virtual filesystem which models the compiler's interactions +# with the outside world. it exposes a "load_file" operation which +# searches for a file from a set of search paths, and also provides +# id generation service to get a unique source id per file. +class InputBundle: + # a list of search paths + search_paths: list[PathLike] + + def __init__(self, search_paths): + self.search_paths = search_paths + self._source_id_counter = 0 + self._source_ids: dict[PathLike, int] = {} + + def _load_from_path(self, path): + raise NotImplementedError(f"not implemented! {self.__class__}._load_from_path()") + + def _generate_source_id(self, path: PathLike) -> int: + if path not in self._source_ids: + self._source_ids[path] = self._source_id_counter + self._source_id_counter += 1 + + return self._source_ids[path] + + def load_file(self, path: PathLike | str) -> CompilerInput: + # search path precedence + tried = [] + for sp in reversed(self.search_paths): + # note from pathlib docs: + # > If the argument is an absolute path, the previous path is ignored. + # Path("/a") / Path("/b") => Path("/b") + to_try = sp / path + + # normalize the path with os.path.normpath, to break down + # things like "foo/bar/../x.vy" => "foo/x.vy", with all + # the caveats around symlinks that os.path.normpath comes with. + to_try = _normpath(to_try) + try: + res = self._load_from_path(to_try) + break + except _NotFound: + tried.append(to_try) + + else: + formatted_search_paths = "\n".join([" " + str(p) for p in tried]) + raise FileNotFoundError( + f"could not find {path} in any of the following locations:\n" + f"{formatted_search_paths}" + ) + + # try to parse from json, so that return types are consistent + # across FilesystemInputBundle and JSONInputBundle. + if isinstance(res, FileInput): + return CompilerInput.from_string(res.source_id, res.path, res.source_code) + + return res + + def add_search_path(self, path: PathLike) -> None: + self.search_paths.append(path) + + # temporarily add something to the search path (within the + # scope of the context manager) with highest precedence. + # if `path` is None, do nothing + @contextlib.contextmanager + def search_path(self, path: Optional[PathLike]) -> Iterator[None]: + if path is None: + yield # convenience, so caller does not have to handle null path + + else: + self.search_paths.append(path) + try: + yield + finally: + self.search_paths.pop() + + +# regular input. takes a search path(s), and `load_file()` will search all +# search paths for the file and read it from the filesystem +class FilesystemInputBundle(InputBundle): + def _load_from_path(self, path: Path) -> CompilerInput: + try: + with path.open() as f: + code = f.read() + except FileNotFoundError: + raise _NotFound(path) + + source_id = super()._generate_source_id(path) + + return FileInput(source_id, path, code) + + +# fake filesystem for JSON inputs. takes a base path, and `load_file()` +# "reads" the file from the JSON input. Note that this input bundle type +# never actually interacts with the filesystem -- it is guaranteed to be pure! +class JSONInputBundle(InputBundle): + input_json: dict[PurePath, Any] + + def __init__(self, input_json, search_paths): + super().__init__(search_paths) + self.input_json = {} + for path, item in input_json.items(): + path = _normpath(path) + + # should be checked by caller + assert path not in self.input_json + self.input_json[_normpath(path)] = item + + def _load_from_path(self, path: PurePath) -> CompilerInput: + try: + value = self.input_json[path] + except KeyError: + raise _NotFound(path) + + source_id = super()._generate_source_id(path) + + if "content" in value: + return FileInput(source_id, path, value["content"]) + + if "abi" in value: + return ABIInput(source_id, path, value["abi"]) + + # TODO: ethPM support + # if isinstance(contents, dict) and "contractTypes" in contents: + + # unreachable, based on how JSONInputBundle is constructed in + # the codebase. + raise JSONError(f"Unexpected type in file: '{path}'") # pragma: nocover diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 1c38fcff9b..e47f300ba9 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -1,6 +1,5 @@ import warnings from collections import OrderedDict, deque -from pathlib import Path import asttokens @@ -17,7 +16,7 @@ def build_ast_dict(compiler_data: CompilerData) -> dict: ast_dict = { - "contract_name": compiler_data.contract_name, + "contract_name": str(compiler_data.contract_path), "ast": ast_to_dict(compiler_data.vyper_module), } return ast_dict @@ -35,7 +34,7 @@ def build_userdoc(compiler_data: CompilerData) -> dict: def build_external_interface_output(compiler_data: CompilerData) -> str: interface = compiler_data.vyper_module_folded._metadata["type"] - stem = Path(compiler_data.contract_name).stem + stem = compiler_data.contract_path.stem # capitalize words separated by '_' # ex: test_interface.vy -> TestInterface name = "".join([x.capitalize() for x in stem.split("_")]) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 72be4396e4..bfbb336d54 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -1,6 +1,7 @@ import copy import warnings from functools import cached_property +from pathlib import Path, PurePath from typing import Optional, Tuple from vyper import ast as vy_ast @@ -8,12 +9,15 @@ from vyper.codegen.core import anchor_opt_level from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode +from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT -from vyper.typing import InterfaceImports, StorageLayout +from vyper.typing import StorageLayout + +DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") class CompilerData: @@ -49,8 +53,8 @@ class CompilerData: def __init__( self, source_code: str, - contract_name: str = "VyperContract", - interface_codes: Optional[InterfaceImports] = None, + input_bundle: InputBundle = None, + contract_path: Path | PurePath = DEFAULT_CONTRACT_NAME, source_id: int = 0, settings: Settings = None, storage_layout: StorageLayout = None, @@ -62,15 +66,11 @@ def __init__( Arguments --------- - source_code : str + source_code: str Vyper source code. - contract_name : str, optional + contract_path: Path, optional The name of the contract being compiled. - interface_codes: Dict, optional - Interfaces that may be imported by the contracts during compilation. - * Formatted as as `{'interface name': {'type': "json/vyper", 'code': "interface code"}}` - * JSON interfaces are given as lists, vyper interfaces as strings - source_id : int, optional + source_id: int, optional ID number used to identify this contract in the source map. settings: Settings Set optimization mode. @@ -79,20 +79,22 @@ def __init__( no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False """ - self.contract_name = contract_name + self.contract_path = contract_path self.source_code = source_code - self.interface_codes = interface_codes self.source_id = source_id self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata + self.settings = settings or Settings() + self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) _ = self._generate_ast # force settings to be calculated @cached_property def _generate_ast(self): - settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name) + contract_name = str(self.contract_path) + settings, ast = generate_ast(self.source_code, self.source_id, contract_name) # validate the compiler settings # XXX: this is a bit ugly, clean up later @@ -133,12 +135,12 @@ def vyper_module_unfolded(self) -> vy_ast.Module: # This phase is intended to generate an AST for tooling use, and is not # used in the compilation process. - return generate_unfolded_ast(self.vyper_module, self.interface_codes) + return generate_unfolded_ast(self.contract_path, self.vyper_module, self.input_bundle) @cached_property def _folded_module(self): return generate_folded_ast( - self.vyper_module, self.interface_codes, self.storage_layout_override + self.contract_path, self.vyper_module, self.input_bundle, self.storage_layout_override ) @property @@ -220,7 +222,7 @@ def generate_ast( Vyper source code. source_id : int ID number used to identify this contract in the source map. - contract_name : str + contract_name: str Name of the contract. Returns @@ -231,20 +233,24 @@ def generate_ast( return vy_ast.parse_to_ast_with_settings(source_code, source_id, contract_name) +# destructive -- mutates module in place! def generate_unfolded_ast( - vyper_module: vy_ast.Module, interface_codes: Optional[InterfaceImports] + contract_path: Path | PurePath, vyper_module: vy_ast.Module, input_bundle: InputBundle ) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) vy_ast.folding.replace_builtin_functions(vyper_module) - # note: validate_semantics does type inference on the AST - validate_semantics(vyper_module, interface_codes) + + with input_bundle.search_path(contract_path.parent): + # note: validate_semantics does type inference on the AST + validate_semantics(vyper_module, input_bundle) return vyper_module def generate_folded_ast( + contract_path: Path, vyper_module: vy_ast.Module, - interface_codes: Optional[InterfaceImports], + input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, ) -> Tuple[vy_ast.Module, StorageLayout]: """ @@ -262,11 +268,15 @@ def generate_folded_ast( StorageLayout Layout of variables in storage """ + vy_ast.validation.validate_literal_nodes(vyper_module) vyper_module_folded = copy.deepcopy(vyper_module) vy_ast.folding.fold(vyper_module_folded) - validate_semantics(vyper_module_folded, interface_codes) + + with input_bundle.search_path(contract_path.parent): + validate_semantics(vyper_module_folded, input_bundle) + symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) return vyper_module_folded, symbol_tables diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 9e987d1cd0..7db230167e 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -7,11 +7,11 @@ from .utils import _ExprAnalyser -def validate_semantics(vyper_ast, interface_codes): +def validate_semantics(vyper_ast, input_bundle): # validate semantics and annotate AST with type/semantics information namespace = get_namespace() with namespace.enter_scope(): - add_module_namespace(vyper_ast, interface_codes) + add_module_namespace(vyper_ast, input_bundle) vy_ast.expansion.expand_annotated_ast(vyper_ast) validate_functions(vyper_ast) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e59422294c..239438f35b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,13 +1,13 @@ -import importlib -import pkgutil -from typing import Optional, Union +import os +from pathlib import Path, PurePath +from typing import Optional import vyper.builtins.interfaces from vyper import ast as vy_ast +from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, - CompilerPanic, ExceptionList, InvalidLiteral, InvalidType, @@ -15,30 +15,27 @@ StateAccessViolation, StructureException, SyntaxException, - UndeclaredDefinition, VariableDeclarationException, VyperException, ) from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase -from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.analysis.utils import check_constant, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.utils import type_from_annotation -from vyper.typing import InterfaceDict -def add_module_namespace(vy_module: vy_ast.Module, interface_codes: InterfaceDict) -> None: +def add_module_namespace(vy_module: vy_ast.Module, input_bundle: InputBundle) -> None: """ Analyze a Vyper module AST node, add all module-level objects to the namespace and validate top-level correctness """ namespace = get_namespace() - ModuleAnalyzer(vy_module, interface_codes, namespace) + ModuleAnalyzer(vy_module, input_bundle, namespace) def _find_cyclic_call(fn_names: list, self_members: dict) -> Optional[list]: @@ -58,10 +55,10 @@ class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, module_node: vy_ast.Module, interface_codes: InterfaceDict, namespace: Namespace + self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace ) -> None: self.ast = module_node - self.interface_codes = interface_codes or {} + self.input_bundle = input_bundle self.namespace = namespace # TODO: Move computation out of constructor @@ -287,17 +284,19 @@ def visit_FunctionDef(self, node): def visit_Import(self, node): if not node.alias: raise StructureException("Import requires an accompanying `as` statement", node) - _add_import(node, node.name, node.alias, node.alias, self.interface_codes, self.namespace) + # import x.y[name] as y[alias] + self._add_import(node, 0, node.name, node.alias) def visit_ImportFrom(self, node): - _add_import( - node, - node.module, - node.name, - node.alias or node.name, - self.interface_codes, - self.namespace, - ) + # from m.n[module] import x[name] as y[alias] + alias = node.alias or node.name + + module = node.module or "" + if module: + module += "." + + qualified_module_name = module + node.name + self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): obj = InterfaceT.from_ast(node) @@ -313,41 +312,87 @@ def visit_StructDef(self, node): except VyperException as exc: raise exc.with_annotation(node) from None + def _add_import( + self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str + ) -> None: + type_ = self._load_import(level, qualified_module_name) + + try: + self.namespace[alias] = type_ + except VyperException as exc: + raise exc.with_annotation(node) from None + + # load an InterfaceT from an import. + # raises FileNotFoundError + def _load_import(self, level: int, module_str: str) -> InterfaceT: + if _is_builtin(module_str): + return _load_builtin_import(level, module_str) + + path = _import_to_path(level, module_str) + + try: + file = self.input_bundle.load_file(path.with_suffix(".vy")) + assert isinstance(file, FileInput) # mypy hint + interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=str(file.path)) + return InterfaceT.from_ast(interface_ast) + except FileNotFoundError: + pass + + try: + file = self.input_bundle.load_file(path.with_suffix(".json")) + assert isinstance(file, ABIInput) # mypy hint + return InterfaceT.from_json_abi(str(file.path), file.abi) + except FileNotFoundError: + raise ModuleNotFoundError(module_str) + + +# convert an import to a path (without suffix) +def _import_to_path(level: int, module_str: str) -> PurePath: + base_path = "" + if level > 1: + base_path = "../" * (level - 1) + elif level == 1: + base_path = "./" + return PurePath(f"{base_path}{module_str.replace('.','/')}/") + + +# can add more, e.g. "vyper.builtins.interfaces", etc. +BUILTIN_PREFIXES = ["vyper.interfaces"] + + +def _is_builtin(module_str): + return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) + + +def _load_builtin_import(level: int, module_str: str) -> InterfaceT: + if not _is_builtin(module_str): + raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None + + builtins_path = vyper.builtins.interfaces.__path__[0] + # hygiene: convert to relpath to avoid leaking user directory info + # (note Path.relative_to cannot handle absolute to relative path + # conversion, so we must use the `os` module). + builtins_path = os.path.relpath(builtins_path) + + search_path = Path(builtins_path).parent.parent.parent + # generate an input bundle just because it knows how to build paths. + input_bundle = FilesystemInputBundle([search_path]) + + # remap builtins directory -- + # vyper/interfaces => vyper/builtins/interfaces + remapped_module = module_str + if remapped_module.startswith("vyper.interfaces"): + remapped_module = remapped_module.removeprefix("vyper.interfaces") + remapped_module = vyper.builtins.interfaces.__package__ + remapped_module -def _add_import( - node: Union[vy_ast.Import, vy_ast.ImportFrom], - module: str, - name: str, - alias: str, - interface_codes: InterfaceDict, - namespace: dict, -) -> None: - if module == "vyper.interfaces": - interface_codes = _get_builtin_interfaces() - if name not in interface_codes: - suggestions_str = get_levenshtein_error_suggestions(name, _get_builtin_interfaces(), 1.0) - raise UndeclaredDefinition(f"Unknown interface: {name}. {suggestions_str}", node) - - if interface_codes[name]["type"] == "vyper": - interface_ast = vy_ast.parse_to_ast(interface_codes[name]["code"], contract_name=name) - type_ = InterfaceT.from_ast(interface_ast) - elif interface_codes[name]["type"] == "json": - type_ = InterfaceT.from_json_abi(name, interface_codes[name]["code"]) # type: ignore - else: - raise CompilerPanic(f"Unknown interface format: {interface_codes[name]['type']}") + path = _import_to_path(level, remapped_module).with_suffix(".vy") try: - namespace[alias] = type_ - except VyperException as exc: - raise exc.with_annotation(node) from None - - -def _get_builtin_interfaces(): - interface_names = [i.name for i in pkgutil.iter_modules(vyper.builtins.interfaces.__path__)] - return { - name: { - "type": "vyper", - "code": importlib.import_module(f"vyper.builtins.interfaces.{name}").interface_code, - } - for name in interface_names - } + file = input_bundle.load_file(path) + assert isinstance(file, FileInput) # mypy hint + except FileNotFoundError: + raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None + + # TODO: it might be good to cache this computation + interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=module_str) + return InterfaceT.from_ast(interface_ast) diff --git a/vyper/typing.py b/vyper/typing.py index 18e201e814..ad3964dff9 100644 --- a/vyper/typing.py +++ b/vyper/typing.py @@ -7,17 +7,9 @@ # Compiler ContractPath = str SourceCode = str -ContractCodes = Dict[ContractPath, SourceCode] OutputFormats = Sequence[str] -OutputDict = Dict[ContractPath, OutputFormats] StorageLayout = Dict -# Interfaces -InterfaceAsName = str -InterfaceImportPath = str -InterfaceImports = Dict[InterfaceAsName, InterfaceImportPath] -InterfaceDict = Dict[ContractPath, InterfaceImports] - # Opcodes OpcodeGasCost = Union[int, Tuple] OpcodeValue = Tuple[Optional[int], int, int, OpcodeGasCost] From 4f7661478d01b4f5e67d8e964c702f8b1973af16 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Tue, 7 Nov 2023 23:22:33 +0100 Subject: [PATCH 121/201] docs: update resources section (#3656) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * docs: update `resources.rst` file * correct 🐍 snekmate branding * We like longer hyphens :) * Add `Foundry-Vyper` to bottom * Add titanoboa reference --- docs/resources.rst | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/docs/resources.rst b/docs/resources.rst index 7f0d0600a9..a3dfa480ed 100644 --- a/docs/resources.rst +++ b/docs/resources.rst @@ -3,45 +3,47 @@ Other resources and learning material ##################################### -Vyper has an active community. You can find third party tutorials, -examples, courses and other learning material. +Vyper has an active community. You can find third-party tutorials, examples, courses, and other learning material. General ------- -- `Ape Academy - Learn how to build vyper projects `__ by ApeWorX -- `More Vyper by Example `__ by Smart Contract Engineer -- `Vyper cheat Sheet `__ -- `Vyper Hub for development `__ -- `Vyper greatest hits smart contract examples `__ +- `Ape Academy – Learn how to build Vyper projects `_ by ApeWorX +- `More Vyper by Example `_ by Smart Contract Engineer +- `Vyper cheat Sheet `_ +- `Vyper Hub for development `_ +- `Vyper greatest hits smart contract examples `_ +- `A curated list of Vyper resources, libraries, tools, and more `_ Frameworks and tooling ---------------------- -- `ApeWorX - The Ethereum development framework for Python Developers, Data Scientists, and Security Professionals `__ -- `Foundry x Vyper - Foundry template to compile Vyper contracts `__ -- `Snekmate - Vyper smart contract building blocks `__ -- `Serpentor - A set of smart contracts tools for governance `__ -- `Smart contract development frameworks and tools for Vyper on Ethreum.org `__ +- `Titanoboa – An experimental Vyper interpreter with pretty tracebacks, forking, debugging features and more `_ +- `ApeWorX – The Ethereum development framework for Python Developers, Data Scientists, and Security Professionals `_ +- `VyperDeployer – A helper smart contract to compile and test Vyper contracts in Foundry `_ +- `🐍 snekmate – Vyper smart contract building blocks `_ +- `Serpentor – A set of smart contracts tools for governance `_ +- `Smart contract development frameworks and tools for Vyper on Ethreum.org `_ Security -------- -- `VyperPunk - learn to secure and hack Vyper smart contracts `__ -- `VyperExamples - Vyper vulnerability examples `__ +- `VyperPunk – learn to secure and hack Vyper smart contracts `_ +- `VyperExamples – Vyper vulnerability examples `_ Conference presentations ------------------------ -- `Vyper Smart Contract Programming Language by Patrick Collins (2022, 30 mins) `__ -- `Python and DeFi by Curve Finance (2022, 15 mins) `__ -- `My experience with Vyper over the years by Benjamin Scherrey (2022, 15 mins) `__ -- `Short introduction to Vyper by Edison Que (3 mins) `__ +- `Vyper Smart Contract Programming Language by Patrick Collins (2022, 30 mins) `_ +- `Python and DeFi by Curve Finance (2022, 15 mins) `_ +- `My experience with Vyper over the years by Benjamin Scherrey (2022, 15 mins) `_ +- `Short introduction to Vyper by Edison Que (3 mins) `_ Unmaintained ------------ These resources have not been updated for a while, but may still offer interesting content. -- `Awesome Vyper curated resources `__ -- `Brownie - Python framework for developing smart contracts (deprecated) `__ +- `Awesome Vyper curated resources `_ +- `Brownie – Python framework for developing smart contracts (deprecated) `_ +- `Foundry x Vyper – Foundry template to compile Vyper contracts `_ From a87fb8c42c74f418b230003bcc66766a6c374f86 Mon Sep 17 00:00:00 2001 From: nfwsncked <75267395+nfwsncked@users.noreply.github.com> Date: Wed, 8 Nov 2023 00:41:08 +0100 Subject: [PATCH 122/201] chore: improve assert descriptions in crowdfund.vy (#3064) --------- Co-authored-by: Charles Cooper Co-authored-by: El De-dog-lo <3859395+fubuloubu@users.noreply.github.com> --- examples/crowdfund.vy | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 3891ad0b74..56b34308f1 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -18,15 +18,15 @@ def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): @external @payable def participate(): - assert block.timestamp < self.deadline, "deadline not met (yet)" + assert block.timestamp < self.deadline, "deadline has expired" self.funders[msg.sender] += msg.value # Enough money was raised! Send funds to the beneficiary @external def finalize(): - assert block.timestamp >= self.deadline, "deadline has passed" - assert self.balance >= self.goal, "the goal has not been reached" + assert block.timestamp >= self.deadline, "deadline has not expired yet" + assert self.balance >= self.goal, "goal has not been reached" selfdestruct(self.beneficiary) From 806dd9075e83eaabbfbaa397c48c9703317b6154 Mon Sep 17 00:00:00 2001 From: Franfran <51274081+iFrostizz@users.noreply.github.com> Date: Wed, 8 Nov 2023 00:50:41 +0100 Subject: [PATCH 123/201] test: add unit tests for internal abi type construction (#3662) specifically test some internal sanity checks/validation --------- Co-authored-by: Charles Cooper --- tests/abi_types/test_invalid_abi_types.py | 26 +++++++++++++++++++++++ vyper/abi_types.py | 16 +++++++------- vyper/exceptions.py | 4 ++++ 3 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 tests/abi_types/test_invalid_abi_types.py diff --git a/tests/abi_types/test_invalid_abi_types.py b/tests/abi_types/test_invalid_abi_types.py new file mode 100644 index 0000000000..c8566e066f --- /dev/null +++ b/tests/abi_types/test_invalid_abi_types.py @@ -0,0 +1,26 @@ +import pytest + +from vyper.abi_types import ( + ABI_Bytes, + ABI_BytesM, + ABI_DynamicArray, + ABI_FixedMxN, + ABI_GIntM, + ABI_String, +) +from vyper.exceptions import InvalidABIType + +cases_invalid_types = [ + (ABI_GIntM, ((0, False), (7, False), (300, True), (300, False))), + (ABI_FixedMxN, ((0, 0, False), (8, 0, False), (256, 81, True), (300, 80, False))), + (ABI_BytesM, ((0,), (33,), (-10,))), + (ABI_Bytes, ((-1,), (-69,))), + (ABI_DynamicArray, ((ABI_GIntM(256, False), -1), (ABI_String(256), -10))), +] + + +@pytest.mark.parametrize("typ,params_variants", cases_invalid_types) +def test_invalid_abi_types(assert_compile_failed, typ, params_variants): + # double parametrization cannot work because the 2nd dimension is variable + for params in params_variants: + assert_compile_failed(lambda: typ(*params), InvalidABIType) diff --git a/vyper/abi_types.py b/vyper/abi_types.py index b272996aed..051f8db19f 100644 --- a/vyper/abi_types.py +++ b/vyper/abi_types.py @@ -1,4 +1,4 @@ -from vyper.exceptions import CompilerPanic +from vyper.exceptions import InvalidABIType from vyper.utils import ceil32 @@ -69,7 +69,7 @@ def __repr__(self): class ABI_GIntM(ABIType): def __init__(self, m_bits, signed): if not (0 < m_bits <= 256 and 0 == m_bits % 8): - raise CompilerPanic("Invalid M provided for GIntM") + raise InvalidABIType("Invalid M provided for GIntM") self.m_bits = m_bits self.signed = signed @@ -117,9 +117,9 @@ def selector_name(self): class ABI_FixedMxN(ABIType): def __init__(self, m_bits, n_places, signed): if not (0 < m_bits <= 256 and 0 == m_bits % 8): - raise CompilerPanic("Invalid M for FixedMxN") + raise InvalidABIType("Invalid M for FixedMxN") if not (0 < n_places and n_places <= 80): - raise CompilerPanic("Invalid N for FixedMxN") + raise InvalidABIType("Invalid N for FixedMxN") self.m_bits = m_bits self.n_places = n_places @@ -142,7 +142,7 @@ def is_complex_type(self): class ABI_BytesM(ABIType): def __init__(self, m_bytes): if not 0 < m_bytes <= 32: - raise CompilerPanic("Invalid M for BytesM") + raise InvalidABIType("Invalid M for BytesM") self.m_bytes = m_bytes @@ -173,7 +173,7 @@ def selector_name(self): class ABI_StaticArray(ABIType): def __init__(self, subtyp, m_elems): if not m_elems >= 0: - raise CompilerPanic("Invalid M") + raise InvalidABIType("Invalid M") self.subtyp = subtyp self.m_elems = m_elems @@ -200,7 +200,7 @@ def is_complex_type(self): class ABI_Bytes(ABIType): def __init__(self, bytes_bound): if not bytes_bound >= 0: - raise CompilerPanic("Negative bytes_bound provided to ABI_Bytes") + raise InvalidABIType("Negative bytes_bound provided to ABI_Bytes") self.bytes_bound = bytes_bound @@ -234,7 +234,7 @@ def selector_name(self): class ABI_DynamicArray(ABIType): def __init__(self, subtyp, elems_bound): if not elems_bound >= 0: - raise CompilerPanic("Negative bound provided to DynamicArray") + raise InvalidABIType("Negative bound provided to DynamicArray") self.subtyp = subtyp self.elems_bound = elems_bound diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 8b2020285a..3bde20356e 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -336,3 +336,7 @@ class UnfoldableNode(VyperInternalException): class TypeCheckFailure(VyperInternalException): """An issue was not caught during type checking that should have been.""" + + +class InvalidABIType(VyperInternalException): + """An internal routine constructed an invalid ABI type""" From c0d8a0dabf488adec0415e621c75c001584fb991 Mon Sep 17 00:00:00 2001 From: engn33r Date: Wed, 8 Nov 2023 12:57:56 +0000 Subject: [PATCH 124/201] chore: remove redundant help text (#3657) `--hex-ir` was redundantly reported in the format option help text (in addition to being listed as a regular option. the latter is correct) --- vyper/cli/vyper_compile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index c4f60660cb..82eba63f32 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -42,7 +42,6 @@ ir_json - Intermediate representation in JSON format ir_runtime - Intermediate representation of runtime bytecode in list format asm - Output the EVM assembly of the deployable bytecode -hex-ir - Output IR and assembly constants in hex instead of decimal """ combined_json_outputs = [ From 4dd47e302fc538ca4fc6fe29d992f0b59456f8e2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 8 Nov 2023 14:38:19 -0500 Subject: [PATCH 125/201] refactor: test directory structure (#3664) refactor: test directory structure - consolidate a bunch of different directories - move files around - merge `conftest.py` and `base_conftest.py` - better organization - unit/ vs functional/ tests - rename old `parser/` and `parser_utils/` directories there is still more refactoring that could be done, and probably some files could be merged / specific tests in files could be moved around, but this commit tried to only touch the directory structure for the sake of reducing merge conflicts in PRs which are still open --- tests/base_conftest.py | 217 ------------- tests/conftest.py | 284 +++++++++++++++--- tests/{compiler/ir => functional}/__init__.py | 0 .../builtins/codegen}/__init__.py | 0 .../builtins/codegen}/test_abi.py | 0 .../builtins/codegen}/test_abi_decode.py | 0 .../builtins/codegen}/test_abi_encode.py | 0 .../builtins/codegen}/test_addmod.py | 0 .../builtins/codegen}/test_as_wei_value.py | 33 ++ .../builtins/codegen}/test_bitwise.py | 0 .../builtins/codegen}/test_ceil.py | 0 .../builtins/codegen}/test_concat.py | 0 .../builtins/codegen}/test_convert.py | 0 .../codegen}/test_create_functions.py | 0 .../builtins/codegen}/test_ec.py | 0 .../builtins/codegen}/test_ecrecover.py | 0 .../builtins/codegen}/test_empty.py | 0 .../builtins/codegen}/test_extract32.py | 0 .../builtins/codegen}/test_floor.py | 0 .../builtins/codegen}/test_interfaces.py | 0 .../builtins/codegen}/test_is_contract.py | 0 .../builtins/codegen}/test_keccak256.py | 0 .../builtins/codegen}/test_length.py | 0 .../builtins/codegen}/test_method_id.py | 0 .../builtins/codegen}/test_minmax.py | 0 .../builtins/codegen}/test_minmax_value.py | 0 .../builtins/codegen}/test_mulmod.py | 0 .../builtins/codegen}/test_raw_call.py | 0 .../builtins/codegen}/test_send.py | 0 .../builtins/codegen}/test_sha256.py | 0 .../builtins/codegen}/test_slice.py | 0 .../builtins/codegen/test_uint2str.py} | 0 .../builtins/codegen}/test_unary.py | 0 .../builtins/codegen}/test_unsafe_math.py | 0 .../builtins/folding/test_abs.py | 0 .../builtins/folding/test_addmod_mulmod.py | 0 .../builtins/folding/test_bitwise.py | 0 .../builtins/folding/test_epsilon.py | 0 .../builtins/folding/test_floor_ceil.py | 0 .../folding/test_fold_as_wei_value.py | 0 .../builtins/folding/test_keccak_sha.py | 0 .../builtins/folding/test_len.py | 0 .../builtins/folding/test_min_max.py | 0 .../builtins/folding/test_powmod.py | 0 .../codegen}/__init__.py | 0 .../test_default_function.py | 0 .../test_default_parameters.py | 0 .../calling_convention}/test_erc20_abi.py | 0 .../test_external_contract_calls.py | 0 ...test_modifiable_external_contract_calls.py | 0 .../calling_convention}/test_return.py | 0 .../calling_convention}/test_return_struct.py | 0 .../calling_convention}/test_return_tuple.py | 0 .../test_self_call_struct.py | 0 .../test_struct_return.py | 0 .../test_tuple_return.py | 0 .../test_block_number.py | 0 .../environment_variables/test_blockhash.py} | 0 .../codegen/environment_variables}/test_tx.py | 0 .../features/decorators/test_nonreentrant.py | 0 .../features/decorators/test_payable.py | 0 .../features/decorators/test_private.py | 0 .../features/decorators/test_public.py | 0 .../codegen}/features/decorators/test_pure.py | 0 .../codegen}/features/decorators/test_view.py | 0 .../codegen}/features/iteration/test_break.py | 0 .../features/iteration/test_continue.py | 0 .../features/iteration/test_for_in_list.py | 0 .../features/iteration/test_for_range.py | 0 .../features/iteration/test_range_in.py | 0 .../codegen}/features/test_address_balance.py | 0 .../codegen}/features/test_assert.py | 0 .../features/test_assert_unreachable.py | 0 .../codegen}/features/test_assignment.py | 0 .../codegen}/features/test_bytes_map_keys.py | 0 .../codegen}/features/test_clampers.py | 0 .../codegen}/features/test_comments.py | 0 .../codegen}/features/test_comparison.py | 0 .../codegen}/features/test_conditionals.py | 0 .../codegen}/features/test_constructor.py | 0 .../codegen}/features/test_gas.py | 0 .../codegen}/features/test_immutable.py | 0 .../codegen}/features/test_init.py | 0 .../codegen}/features/test_internal_call.py | 0 .../codegen}/features/test_logging.py | 0 .../features/test_logging_bytes_extended.py | 0 .../features/test_logging_from_call.py | 0 .../codegen}/features/test_memory_alloc.py | 0 .../codegen}/features/test_memory_dealloc.py | 0 .../codegen}/features/test_packing.py | 0 .../codegen}/features/test_reverting.py | 0 .../features/test_short_circuiting.py | 0 .../codegen}/features/test_string_map_keys.py | 0 .../codegen}/features/test_ternary.py | 0 .../codegen}/features/test_transient.py | 0 .../codegen}/integration/test_basics.py | 0 .../codegen}/integration/test_crowdfund.py | 1 + .../codegen}/integration/test_escrow.py | 0 .../storage_variables}/test_getters.py | 0 .../storage_variables}/test_setters.py | 0 .../test_storage_variable.py} | 0 .../codegen}/test_call_graph_stability.py | 0 .../codegen}/test_selector_table.py | 0 .../codegen}/test_selector_table_stability.py | 0 .../codegen}/types/numbers/test_constants.py | 0 .../codegen}/types/numbers/test_decimals.py | 0 .../codegen/types/numbers}/test_division.py | 0 .../codegen/types/numbers}/test_exponents.py | 0 .../codegen}/types/numbers/test_isqrt.py | 0 .../codegen/types/numbers}/test_modulo.py | 0 .../types/numbers/test_signed_ints.py | 0 .../codegen}/types/numbers/test_sqrt.py | 0 .../types/numbers/test_unsigned_ints.py | 0 .../codegen}/types/test_bytes.py | 0 .../codegen}/types/test_bytes_literal.py | 0 .../codegen}/types/test_bytes_zero_padding.py | 0 .../codegen}/types/test_dynamic_array.py | 0 .../codegen}/types/test_enum.py | 0 .../codegen}/types/test_identifier_naming.py | 0 .../codegen}/types/test_lists.py | 0 .../codegen}/types/test_node_types.py | 0 .../codegen}/types/test_string.py | 0 .../codegen}/types/test_string_literal.py | 0 .../examples/auctions/test_blind_auction.py | 0 .../auctions/test_simple_open_auction.py | 0 .../examples/company/test_company.py | 0 tests/{ => functional}/examples/conftest.py | 0 .../crowdfund/test_crowdfund_example.py | 0 .../examples/factory/test_factory.py | 0 .../test_on_chain_market_maker.py | 0 .../name_registry/test_name_registry.py | 0 .../test_safe_remote_purchase.py | 0 .../examples/storage/test_advanced_storage.py | 0 .../examples/storage/test_storage.py | 0 .../examples/tokens/test_erc1155.py | 0 .../examples/tokens/test_erc20.py | 0 .../examples/tokens/test_erc4626.py | 0 .../examples/tokens/test_erc721.py | 0 .../examples/voting/test_ballot.py | 0 .../examples/wallet/test_wallet.py | 0 .../{ => functional}/grammar/test_grammar.py | 0 .../{parser => functional}/syntax/__init__.py | 0 .../exceptions/test_argument_exception.py | 0 .../syntax}/exceptions/test_call_violation.py | 0 .../exceptions/test_constancy_exception.py | 0 .../test_function_declaration_exception.py | 0 .../test_instantiation_exception.py | 0 .../test_invalid_literal_exception.py | 0 .../exceptions/test_invalid_payable.py | 0 .../exceptions/test_invalid_reference.py | 0 .../exceptions/test_invalid_type_exception.py | 0 .../exceptions/test_namespace_collision.py | 0 .../exceptions/test_overflow_exception.py | 0 .../exceptions/test_structure_exception.py | 0 .../exceptions/test_syntax_exception.py | 0 .../test_type_mismatch_exception.py | 0 .../exceptions/test_undeclared_definition.py | 0 .../test_variable_declaration_exception.py | 0 .../exceptions/test_vyper_exception_pos.py | 0 .../syntax/names}/test_event_names.py | 0 .../syntax/names}/test_function_names.py | 0 .../syntax/names}/test_variable_names.py | 0 .../test_invalid_function_decorators.py | 0 .../signatures/test_method_id_conflicts.py | 0 .../syntax/test_abi_decode.py | 0 .../syntax/test_abi_encode.py | 0 .../syntax/test_addmulmod.py | 0 .../syntax/test_address_code.py | 0 .../syntax/test_ann_assign.py | 0 .../syntax/test_as_uint256.py | 0 .../syntax/test_as_wei_value.py | 0 .../syntax/test_block.py | 0 .../syntax/test_blockscope.py | 0 .../syntax/test_bool.py | 0 .../syntax/test_bool_ops.py | 0 .../syntax/test_bytes.py | 0 .../syntax/test_chainid.py | 0 .../syntax/test_code_size.py | 0 .../syntax/test_codehash.py | 0 .../syntax/test_concat.py | 0 .../syntax/test_conditionals.py | 0 .../syntax/test_constants.py | 0 .../syntax/test_create_with_code_of.py | 0 .../syntax/test_dynamic_array.py | 0 .../syntax/test_enum.py | 0 .../syntax/test_extract32.py | 0 .../syntax/test_for_range.py | 0 .../syntax/test_functions_call.py | 0 .../syntax/test_immutables.py | 0 .../syntax/test_interfaces.py | 0 .../syntax/test_invalids.py | 0 .../syntax/test_keccak256.py | 0 .../{parser => functional}/syntax/test_len.py | 0 .../syntax/test_list.py | 0 .../syntax/test_logging.py | 0 .../syntax/test_minmax.py | 0 .../syntax/test_minmax_value.py | 0 .../syntax/test_msg_data.py | 0 .../syntax/test_nested_list.py | 0 .../syntax/test_no_none.py | 0 .../syntax/test_print.py | 0 .../syntax/test_public.py | 0 .../syntax/test_raw_call.py | 0 .../syntax/test_return_tuple.py | 0 .../syntax/test_self_balance.py | 0 .../syntax/test_selfdestruct.py | 0 .../syntax/test_send.py | 0 .../syntax/test_slice.py | 0 .../syntax/test_string.py | 0 .../syntax/test_structs.py | 0 .../syntax/test_ternary.py | 0 .../syntax/test_tuple_assign.py | 0 .../syntax/test_unbalanced_return.py | 0 tests/parser/functions/test_as_wei_value.py | 31 -- tests/unit/__init__.py | 0 .../abi_types/test_invalid_abi_types.py | 0 tests/{ => unit}/ast/nodes/test_binary.py | 0 .../ast/nodes/test_compare_nodes.py | 0 .../ast/nodes/test_evaluate_binop_decimal.py | 0 .../ast/nodes/test_evaluate_binop_int.py | 0 .../ast/nodes/test_evaluate_boolop.py | 0 .../ast/nodes/test_evaluate_compare.py | 0 .../ast/nodes/test_evaluate_subscript.py | 0 .../ast/nodes/test_evaluate_unaryop.py | 0 tests/{ => unit}/ast/nodes/test_from_node.py | 0 .../{ => unit}/ast/nodes/test_get_children.py | 0 .../ast/nodes/test_get_descendants.py | 0 tests/{ => unit}/ast/nodes/test_hex.py | 0 .../ast/nodes/test_replace_in_tree.py | 0 .../ast}/test_annotate_and_optimize_ast.py | 0 .../ast_utils => unit/ast}/test_ast_dict.py | 0 tests/{ => unit}/ast/test_folding.py | 0 tests/{ => unit}/ast/test_metadata_journal.py | 0 tests/{ => unit}/ast/test_natspec.py | 0 .../test_ast.py => unit/ast/test_parser.py} | 0 tests/{ => unit}/ast/test_pre_parser.py | 0 .../ast/test_source_annotation.py} | 0 .../cli/outputs/test_storage_layout.py | 0 .../outputs/test_storage_layout_overrides.py | 0 .../cli/vyper_compile/test_compile_files.py | 0 .../cli/vyper_compile/test_parse_args.py | 0 .../cli/vyper_json/test_compile_json.py | 0 .../cli/vyper_json/test_get_inputs.py | 0 .../cli/vyper_json/test_get_settings.py | 0 .../cli/vyper_json/test_output_selection.py | 0 .../vyper_json/test_parse_args_vyperjson.py | 0 tests/{ => unit}/compiler/__init__.py | 0 .../compiler/asm/test_asm_optimizer.py | 0 tests/unit/compiler/ir/__init__.py | 0 .../compiler/ir}/test_calldatacopy.py | 0 .../{ => unit}/compiler/ir/test_compile_ir.py | 0 .../compiler/ir/test_optimize_ir.py | 0 tests/{ => unit}/compiler/ir/test_repeat.py | 0 tests/{ => unit}/compiler/ir/test_with.py | 0 .../compiler/test_bytecode_runtime.py | 0 .../{ => unit}/compiler/test_compile_code.py | 0 .../compiler/test_default_settings.py | 0 .../{ => unit}/compiler/test_input_bundle.py | 0 tests/{ => unit}/compiler/test_opcodes.py | 0 tests/{ => unit}/compiler/test_pre_parser.py | 0 tests/{ => unit}/compiler/test_sha3_32.py | 0 tests/{ => unit}/compiler/test_source_map.py | 0 .../semantics/analysis/test_array_index.py | 0 .../analysis/test_cyclic_function_calls.py | 0 .../semantics/analysis/test_for_loop.py | 0 .../analysis/test_potential_types.py | 0 .../semantics/conftest.py | 0 .../semantics/test_namespace.py | 0 .../semantics}/test_storage_slots.py | 0 .../semantics/types/test_event.py | 0 .../semantics/types/test_pure_types.py | 0 .../semantics/types/test_size_in_bytes.py | 0 .../semantics/types/test_type_from_abi.py | 0 .../types/test_type_from_annotation.py | 0 274 files changed, 282 insertions(+), 284 deletions(-) delete mode 100644 tests/base_conftest.py rename tests/{compiler/ir => functional}/__init__.py (100%) rename tests/{parser/exceptions => functional/builtins/codegen}/__init__.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_abi.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_abi_decode.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_abi_encode.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_addmod.py (100%) rename tests/{parser/types/value => functional/builtins/codegen}/test_as_wei_value.py (75%) rename tests/{parser/functions => functional/builtins/codegen}/test_bitwise.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_ceil.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_concat.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_convert.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_create_functions.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_ec.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_ecrecover.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_empty.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_extract32.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_floor.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_interfaces.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_is_contract.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_keccak256.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_length.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_method_id.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_minmax.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_minmax_value.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_mulmod.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_raw_call.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_send.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_sha256.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_slice.py (100%) rename tests/{parser/functions/test_mkstr.py => functional/builtins/codegen/test_uint2str.py} (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_unary.py (100%) rename tests/{parser/functions => functional/builtins/codegen}/test_unsafe_math.py (100%) rename tests/{ => functional}/builtins/folding/test_abs.py (100%) rename tests/{ => functional}/builtins/folding/test_addmod_mulmod.py (100%) rename tests/{ => functional}/builtins/folding/test_bitwise.py (100%) rename tests/{ => functional}/builtins/folding/test_epsilon.py (100%) rename tests/{ => functional}/builtins/folding/test_floor_ceil.py (100%) rename tests/{ => functional}/builtins/folding/test_fold_as_wei_value.py (100%) rename tests/{ => functional}/builtins/folding/test_keccak_sha.py (100%) rename tests/{ => functional}/builtins/folding/test_len.py (100%) rename tests/{ => functional}/builtins/folding/test_min_max.py (100%) rename tests/{ => functional}/builtins/folding/test_powmod.py (100%) rename tests/{parser/functions => functional/codegen}/__init__.py (100%) rename tests/{parser/functions => functional/codegen/calling_convention}/test_default_function.py (100%) rename tests/{parser/functions => functional/codegen/calling_convention}/test_default_parameters.py (100%) rename tests/{parser/features/external_contracts => functional/codegen/calling_convention}/test_erc20_abi.py (100%) rename tests/{parser/features/external_contracts => functional/codegen/calling_convention}/test_external_contract_calls.py (100%) rename tests/{parser/features/external_contracts => functional/codegen/calling_convention}/test_modifiable_external_contract_calls.py (100%) rename tests/{parser/functions => functional/codegen/calling_convention}/test_return.py (100%) rename tests/{parser/functions => functional/codegen/calling_convention}/test_return_struct.py (100%) rename tests/{parser/functions => functional/codegen/calling_convention}/test_return_tuple.py (100%) rename tests/{parser/features/external_contracts => functional/codegen/calling_convention}/test_self_call_struct.py (100%) rename tests/functional/codegen/{ => calling_convention}/test_struct_return.py (100%) rename tests/functional/codegen/{ => calling_convention}/test_tuple_return.py (100%) rename tests/{parser/functions => functional/codegen/environment_variables}/test_block_number.py (100%) rename tests/{parser/functions/test_block.py => functional/codegen/environment_variables/test_blockhash.py} (100%) rename tests/{parser/functions => functional/codegen/environment_variables}/test_tx.py (100%) rename tests/{parser => functional/codegen}/features/decorators/test_nonreentrant.py (100%) rename tests/{parser => functional/codegen}/features/decorators/test_payable.py (100%) rename tests/{parser => functional/codegen}/features/decorators/test_private.py (100%) rename tests/{parser => functional/codegen}/features/decorators/test_public.py (100%) rename tests/{parser => functional/codegen}/features/decorators/test_pure.py (100%) rename tests/{parser => functional/codegen}/features/decorators/test_view.py (100%) rename tests/{parser => functional/codegen}/features/iteration/test_break.py (100%) rename tests/{parser => functional/codegen}/features/iteration/test_continue.py (100%) rename tests/{parser => functional/codegen}/features/iteration/test_for_in_list.py (100%) rename tests/{parser => functional/codegen}/features/iteration/test_for_range.py (100%) rename tests/{parser => functional/codegen}/features/iteration/test_range_in.py (100%) rename tests/{parser => functional/codegen}/features/test_address_balance.py (100%) rename tests/{parser => functional/codegen}/features/test_assert.py (100%) rename tests/{parser => functional/codegen}/features/test_assert_unreachable.py (100%) rename tests/{parser => functional/codegen}/features/test_assignment.py (100%) rename tests/{parser => functional/codegen}/features/test_bytes_map_keys.py (100%) rename tests/{parser => functional/codegen}/features/test_clampers.py (100%) rename tests/{parser => functional/codegen}/features/test_comments.py (100%) rename tests/{parser => functional/codegen}/features/test_comparison.py (100%) rename tests/{parser => functional/codegen}/features/test_conditionals.py (100%) rename tests/{parser => functional/codegen}/features/test_constructor.py (100%) rename tests/{parser => functional/codegen}/features/test_gas.py (100%) rename tests/{parser => functional/codegen}/features/test_immutable.py (100%) rename tests/{parser => functional/codegen}/features/test_init.py (100%) rename tests/{parser => functional/codegen}/features/test_internal_call.py (100%) rename tests/{parser => functional/codegen}/features/test_logging.py (100%) rename tests/{parser => functional/codegen}/features/test_logging_bytes_extended.py (100%) rename tests/{parser => functional/codegen}/features/test_logging_from_call.py (100%) rename tests/{parser => functional/codegen}/features/test_memory_alloc.py (100%) rename tests/{parser => functional/codegen}/features/test_memory_dealloc.py (100%) rename tests/{parser => functional/codegen}/features/test_packing.py (100%) rename tests/{parser => functional/codegen}/features/test_reverting.py (100%) rename tests/{parser => functional/codegen}/features/test_short_circuiting.py (100%) rename tests/{parser => functional/codegen}/features/test_string_map_keys.py (100%) rename tests/{parser => functional/codegen}/features/test_ternary.py (100%) rename tests/{parser => functional/codegen}/features/test_transient.py (100%) rename tests/{parser => functional/codegen}/integration/test_basics.py (100%) rename tests/{parser => functional/codegen}/integration/test_crowdfund.py (98%) rename tests/{parser => functional/codegen}/integration/test_escrow.py (100%) rename tests/{parser/globals => functional/codegen/storage_variables}/test_getters.py (100%) rename tests/{parser/globals => functional/codegen/storage_variables}/test_setters.py (100%) rename tests/{parser/globals/test_globals.py => functional/codegen/storage_variables/test_storage_variable.py} (100%) rename tests/{parser => functional/codegen}/test_call_graph_stability.py (100%) rename tests/{parser => functional/codegen}/test_selector_table.py (100%) rename tests/{parser => functional/codegen}/test_selector_table_stability.py (100%) rename tests/{parser => functional/codegen}/types/numbers/test_constants.py (100%) rename tests/{parser => functional/codegen}/types/numbers/test_decimals.py (100%) rename tests/{parser/features/arithmetic => functional/codegen/types/numbers}/test_division.py (100%) rename tests/{fuzzing => functional/codegen/types/numbers}/test_exponents.py (100%) rename tests/{parser => functional/codegen}/types/numbers/test_isqrt.py (100%) rename tests/{parser/features/arithmetic => functional/codegen/types/numbers}/test_modulo.py (100%) rename tests/{parser => functional/codegen}/types/numbers/test_signed_ints.py (100%) rename tests/{parser => functional/codegen}/types/numbers/test_sqrt.py (100%) rename tests/{parser => functional/codegen}/types/numbers/test_unsigned_ints.py (100%) rename tests/{parser => functional/codegen}/types/test_bytes.py (100%) rename tests/{parser => functional/codegen}/types/test_bytes_literal.py (100%) rename tests/{parser => functional/codegen}/types/test_bytes_zero_padding.py (100%) rename tests/{parser => functional/codegen}/types/test_dynamic_array.py (100%) rename tests/{parser => functional/codegen}/types/test_enum.py (100%) rename tests/{parser => functional/codegen}/types/test_identifier_naming.py (100%) mode change 100755 => 100644 rename tests/{parser => functional/codegen}/types/test_lists.py (100%) rename tests/{parser => functional/codegen}/types/test_node_types.py (100%) rename tests/{parser => functional/codegen}/types/test_string.py (100%) rename tests/{parser => functional/codegen}/types/test_string_literal.py (100%) rename tests/{ => functional}/examples/auctions/test_blind_auction.py (100%) rename tests/{ => functional}/examples/auctions/test_simple_open_auction.py (100%) rename tests/{ => functional}/examples/company/test_company.py (100%) rename tests/{ => functional}/examples/conftest.py (100%) rename tests/{ => functional}/examples/crowdfund/test_crowdfund_example.py (100%) rename tests/{ => functional}/examples/factory/test_factory.py (100%) rename tests/{ => functional}/examples/market_maker/test_on_chain_market_maker.py (100%) rename tests/{ => functional}/examples/name_registry/test_name_registry.py (100%) rename tests/{ => functional}/examples/safe_remote_purchase/test_safe_remote_purchase.py (100%) rename tests/{ => functional}/examples/storage/test_advanced_storage.py (100%) rename tests/{ => functional}/examples/storage/test_storage.py (100%) rename tests/{ => functional}/examples/tokens/test_erc1155.py (100%) rename tests/{ => functional}/examples/tokens/test_erc20.py (100%) rename tests/{ => functional}/examples/tokens/test_erc4626.py (100%) rename tests/{ => functional}/examples/tokens/test_erc721.py (100%) rename tests/{ => functional}/examples/voting/test_ballot.py (100%) rename tests/{ => functional}/examples/wallet/test_wallet.py (100%) rename tests/{ => functional}/grammar/test_grammar.py (100%) rename tests/{parser => functional}/syntax/__init__.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_argument_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_call_violation.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_constancy_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_function_declaration_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_instantiation_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_invalid_literal_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_invalid_payable.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_invalid_reference.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_invalid_type_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_namespace_collision.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_overflow_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_structure_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_syntax_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_type_mismatch_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_undeclared_definition.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_variable_declaration_exception.py (100%) rename tests/{parser => functional/syntax}/exceptions/test_vyper_exception_pos.py (100%) rename tests/{parser/syntax/utils => functional/syntax/names}/test_event_names.py (100%) rename tests/{parser/syntax/utils => functional/syntax/names}/test_function_names.py (100%) rename tests/{parser/syntax/utils => functional/syntax/names}/test_variable_names.py (100%) rename tests/{ => functional/syntax}/signatures/test_invalid_function_decorators.py (100%) rename tests/{ => functional/syntax}/signatures/test_method_id_conflicts.py (100%) rename tests/{parser => functional}/syntax/test_abi_decode.py (100%) rename tests/{parser => functional}/syntax/test_abi_encode.py (100%) rename tests/{parser => functional}/syntax/test_addmulmod.py (100%) rename tests/{parser => functional}/syntax/test_address_code.py (100%) rename tests/{parser => functional}/syntax/test_ann_assign.py (100%) rename tests/{parser => functional}/syntax/test_as_uint256.py (100%) rename tests/{parser => functional}/syntax/test_as_wei_value.py (100%) rename tests/{parser => functional}/syntax/test_block.py (100%) rename tests/{parser => functional}/syntax/test_blockscope.py (100%) rename tests/{parser => functional}/syntax/test_bool.py (100%) rename tests/{parser => functional}/syntax/test_bool_ops.py (100%) rename tests/{parser => functional}/syntax/test_bytes.py (100%) rename tests/{parser => functional}/syntax/test_chainid.py (100%) rename tests/{parser => functional}/syntax/test_code_size.py (100%) rename tests/{parser => functional}/syntax/test_codehash.py (100%) rename tests/{parser => functional}/syntax/test_concat.py (100%) rename tests/{parser => functional}/syntax/test_conditionals.py (100%) rename tests/{parser => functional}/syntax/test_constants.py (100%) rename tests/{parser => functional}/syntax/test_create_with_code_of.py (100%) rename tests/{parser => functional}/syntax/test_dynamic_array.py (100%) rename tests/{parser => functional}/syntax/test_enum.py (100%) rename tests/{parser => functional}/syntax/test_extract32.py (100%) rename tests/{parser => functional}/syntax/test_for_range.py (100%) rename tests/{parser => functional}/syntax/test_functions_call.py (100%) rename tests/{parser => functional}/syntax/test_immutables.py (100%) rename tests/{parser => functional}/syntax/test_interfaces.py (100%) rename tests/{parser => functional}/syntax/test_invalids.py (100%) rename tests/{parser => functional}/syntax/test_keccak256.py (100%) rename tests/{parser => functional}/syntax/test_len.py (100%) rename tests/{parser => functional}/syntax/test_list.py (100%) rename tests/{parser => functional}/syntax/test_logging.py (100%) rename tests/{parser => functional}/syntax/test_minmax.py (100%) rename tests/{parser => functional}/syntax/test_minmax_value.py (100%) rename tests/{parser => functional}/syntax/test_msg_data.py (100%) rename tests/{parser => functional}/syntax/test_nested_list.py (100%) rename tests/{parser => functional}/syntax/test_no_none.py (100%) rename tests/{parser => functional}/syntax/test_print.py (100%) rename tests/{parser => functional}/syntax/test_public.py (100%) rename tests/{parser => functional}/syntax/test_raw_call.py (100%) rename tests/{parser => functional}/syntax/test_return_tuple.py (100%) rename tests/{parser => functional}/syntax/test_self_balance.py (100%) rename tests/{parser => functional}/syntax/test_selfdestruct.py (100%) rename tests/{parser => functional}/syntax/test_send.py (100%) rename tests/{parser => functional}/syntax/test_slice.py (100%) rename tests/{parser => functional}/syntax/test_string.py (100%) rename tests/{parser => functional}/syntax/test_structs.py (100%) rename tests/{parser => functional}/syntax/test_ternary.py (100%) rename tests/{parser => functional}/syntax/test_tuple_assign.py (100%) rename tests/{parser => functional}/syntax/test_unbalanced_return.py (100%) delete mode 100644 tests/parser/functions/test_as_wei_value.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/abi_types/test_invalid_abi_types.py (100%) rename tests/{ => unit}/ast/nodes/test_binary.py (100%) rename tests/{ => unit}/ast/nodes/test_compare_nodes.py (100%) rename tests/{ => unit}/ast/nodes/test_evaluate_binop_decimal.py (100%) rename tests/{ => unit}/ast/nodes/test_evaluate_binop_int.py (100%) rename tests/{ => unit}/ast/nodes/test_evaluate_boolop.py (100%) rename tests/{ => unit}/ast/nodes/test_evaluate_compare.py (100%) rename tests/{ => unit}/ast/nodes/test_evaluate_subscript.py (100%) rename tests/{ => unit}/ast/nodes/test_evaluate_unaryop.py (100%) rename tests/{ => unit}/ast/nodes/test_from_node.py (100%) rename tests/{ => unit}/ast/nodes/test_get_children.py (100%) rename tests/{ => unit}/ast/nodes/test_get_descendants.py (100%) rename tests/{ => unit}/ast/nodes/test_hex.py (100%) rename tests/{ => unit}/ast/nodes/test_replace_in_tree.py (100%) rename tests/{parser/parser_utils => unit/ast}/test_annotate_and_optimize_ast.py (100%) rename tests/{parser/ast_utils => unit/ast}/test_ast_dict.py (100%) rename tests/{ => unit}/ast/test_folding.py (100%) rename tests/{ => unit}/ast/test_metadata_journal.py (100%) rename tests/{ => unit}/ast/test_natspec.py (100%) rename tests/{parser/ast_utils/test_ast.py => unit/ast/test_parser.py} (100%) rename tests/{ => unit}/ast/test_pre_parser.py (100%) rename tests/{test_utils.py => unit/ast/test_source_annotation.py} (100%) rename tests/{ => unit}/cli/outputs/test_storage_layout.py (100%) rename tests/{ => unit}/cli/outputs/test_storage_layout_overrides.py (100%) rename tests/{ => unit}/cli/vyper_compile/test_compile_files.py (100%) rename tests/{ => unit}/cli/vyper_compile/test_parse_args.py (100%) rename tests/{ => unit}/cli/vyper_json/test_compile_json.py (100%) rename tests/{ => unit}/cli/vyper_json/test_get_inputs.py (100%) rename tests/{ => unit}/cli/vyper_json/test_get_settings.py (100%) rename tests/{ => unit}/cli/vyper_json/test_output_selection.py (100%) rename tests/{ => unit}/cli/vyper_json/test_parse_args_vyperjson.py (100%) rename tests/{ => unit}/compiler/__init__.py (100%) rename tests/{ => unit}/compiler/asm/test_asm_optimizer.py (100%) create mode 100644 tests/unit/compiler/ir/__init__.py rename tests/{compiler => unit/compiler/ir}/test_calldatacopy.py (100%) rename tests/{ => unit}/compiler/ir/test_compile_ir.py (100%) rename tests/{ => unit}/compiler/ir/test_optimize_ir.py (100%) rename tests/{ => unit}/compiler/ir/test_repeat.py (100%) rename tests/{ => unit}/compiler/ir/test_with.py (100%) rename tests/{ => unit}/compiler/test_bytecode_runtime.py (100%) rename tests/{ => unit}/compiler/test_compile_code.py (100%) rename tests/{ => unit}/compiler/test_default_settings.py (100%) rename tests/{ => unit}/compiler/test_input_bundle.py (100%) rename tests/{ => unit}/compiler/test_opcodes.py (100%) rename tests/{ => unit}/compiler/test_pre_parser.py (100%) rename tests/{ => unit}/compiler/test_sha3_32.py (100%) rename tests/{ => unit}/compiler/test_source_map.py (100%) rename tests/{functional => unit}/semantics/analysis/test_array_index.py (100%) rename tests/{functional => unit}/semantics/analysis/test_cyclic_function_calls.py (100%) rename tests/{functional => unit}/semantics/analysis/test_for_loop.py (100%) rename tests/{functional => unit}/semantics/analysis/test_potential_types.py (100%) rename tests/{functional => unit}/semantics/conftest.py (100%) rename tests/{functional => unit}/semantics/test_namespace.py (100%) rename tests/{functional => unit/semantics}/test_storage_slots.py (100%) rename tests/{functional => unit}/semantics/types/test_event.py (100%) rename tests/{functional => unit}/semantics/types/test_pure_types.py (100%) rename tests/{functional => unit}/semantics/types/test_size_in_bytes.py (100%) rename tests/{functional => unit}/semantics/types/test_type_from_abi.py (100%) rename tests/{functional => unit}/semantics/types/test_type_from_annotation.py (100%) diff --git a/tests/base_conftest.py b/tests/base_conftest.py deleted file mode 100644 index f613ad0f47..0000000000 --- a/tests/base_conftest.py +++ /dev/null @@ -1,217 +0,0 @@ -import json - -import pytest -import web3.exceptions -from eth_tester import EthereumTester, PyEVMBackend -from eth_tester.exceptions import TransactionFailed -from eth_utils.toolz import compose -from hexbytes import HexBytes -from web3 import Web3 -from web3.contract import Contract -from web3.providers.eth_tester import EthereumTesterProvider - -from vyper import compiler -from vyper.ast.grammar import parse_vyper_source -from vyper.compiler.settings import Settings - - -class VyperMethod: - ALLOWED_MODIFIERS = {"call", "estimateGas", "transact", "buildTransaction"} - - def __init__(self, function, normalizers=None): - self._function = function - self._function._return_data_normalizers = normalizers - - def __call__(self, *args, **kwargs): - return self.__prepared_function(*args, **kwargs) - - def __prepared_function(self, *args, **kwargs): - if not kwargs: - modifier, modifier_dict = "call", {} - fn_abi = [ - x - for x in self._function.contract_abi - if x.get("name") == self._function.function_identifier - ].pop() - # To make tests faster just supply some high gas value. - modifier_dict.update({"gas": fn_abi.get("gas", 0) + 500000}) - elif len(kwargs) == 1: - modifier, modifier_dict = kwargs.popitem() - if modifier not in self.ALLOWED_MODIFIERS: - raise TypeError(f"The only allowed keyword arguments are: {self.ALLOWED_MODIFIERS}") - else: - raise TypeError(f"Use up to one keyword argument, one of: {self.ALLOWED_MODIFIERS}") - return getattr(self._function(*args), modifier)(modifier_dict) - - -class VyperContract: - """ - An alternative Contract Factory which invokes all methods as `call()`, - unless you add a keyword argument. The keyword argument assigns the prep method. - This call - > contract.withdraw(amount, transact={'from': eth.accounts[1], 'gas': 100000, ...}) - is equivalent to this call in the classic contract: - > contract.functions.withdraw(amount).transact({'from': eth.accounts[1], 'gas': 100000, ...}) - """ - - def __init__(self, classic_contract, method_class=VyperMethod): - classic_contract._return_data_normalizers += CONCISE_NORMALIZERS - self._classic_contract = classic_contract - self.address = self._classic_contract.address - protected_fn_names = [fn for fn in dir(self) if not fn.endswith("__")] - - try: - fn_names = [fn["name"] for fn in self._classic_contract.functions._functions] - except web3.exceptions.NoABIFunctionsFound: - fn_names = [] - - for fn_name in fn_names: - # Override namespace collisions - if fn_name in protected_fn_names: - raise AttributeError(f"{fn_name} is protected!") - else: - _classic_method = getattr(self._classic_contract.functions, fn_name) - _concise_method = method_class( - _classic_method, self._classic_contract._return_data_normalizers - ) - setattr(self, fn_name, _concise_method) - - @classmethod - def factory(cls, *args, **kwargs): - return compose(cls, Contract.factory(*args, **kwargs)) - - -def _none_addr(datatype, data): - if datatype == "address" and int(data, base=16) == 0: - return (datatype, None) - else: - return (datatype, data) - - -CONCISE_NORMALIZERS = (_none_addr,) - - -@pytest.fixture(scope="module") -def tester(): - # set absurdly high gas limit so that london basefee never adjusts - # (note: 2**63 - 1 is max that evm allows) - custom_genesis = PyEVMBackend._generate_genesis_params(overrides={"gas_limit": 10**10}) - custom_genesis["base_fee_per_gas"] = 0 - backend = PyEVMBackend(genesis_parameters=custom_genesis) - return EthereumTester(backend=backend) - - -def zero_gas_price_strategy(web3, transaction_params=None): - return 0 # zero gas price makes testing simpler. - - -@pytest.fixture(scope="module") -def w3(tester): - w3 = Web3(EthereumTesterProvider(tester)) - w3.eth.set_gas_price_strategy(zero_gas_price_strategy) - return w3 - - -def _get_contract( - w3, source_code, optimize, *args, override_opt_level=None, input_bundle=None, **kwargs -): - settings = Settings() - settings.evm_version = kwargs.pop("evm_version", None) - settings.optimize = override_opt_level or optimize - out = compiler.compile_code( - source_code, - # test that metadata and natspecs get generated - output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], - settings=settings, - input_bundle=input_bundle, - show_gas_estimates=True, # Enable gas estimates for testing - ) - parse_vyper_source(source_code) # Test grammar. - json.dumps(out["metadata"]) # test metadata is json serializable - abi = out["abi"] - bytecode = out["bytecode"] - value = kwargs.pop("value_in_eth", 0) * 10**18 # Handle deploying with an eth value. - c = w3.eth.contract(abi=abi, bytecode=bytecode) - deploy_transaction = c.constructor(*args) - tx_info = {"from": w3.eth.accounts[0], "value": value, "gasPrice": 0} - tx_info.update(kwargs) - tx_hash = deploy_transaction.transact(tx_info) - address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] - return w3.eth.contract(address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract) - - -def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs): - settings = Settings() - settings.evm_version = kwargs.pop("evm_version", None) - settings.optimize = optimize - out = compiler.compile_code( - source_code, - output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], - settings=settings, - show_gas_estimates=True, # Enable gas estimates for testing - ) - parse_vyper_source(source_code) # Test grammar. - abi = out["abi"] - bytecode = HexBytes(initcode_prefix) + HexBytes(out["bytecode"]) - bytecode_len = len(bytecode) - bytecode_len_hex = hex(bytecode_len)[2:].rjust(4, "0") - # prepend a quick deploy preamble - deploy_preamble = HexBytes("61" + bytecode_len_hex + "3d81600a3d39f3") - deploy_bytecode = HexBytes(deploy_preamble) + bytecode - - deployer_abi = [] # just a constructor - c = w3.eth.contract(abi=deployer_abi, bytecode=deploy_bytecode) - deploy_transaction = c.constructor() - tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0} - - tx_hash = deploy_transaction.transact(tx_info) - address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] - - # sanity check - assert w3.eth.get_code(address) == bytecode, (w3.eth.get_code(address), bytecode) - - def factory(address): - return w3.eth.contract( - address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract - ) - - return w3.eth.contract(address, bytecode=deploy_bytecode), factory - - -@pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize): - def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs) - - return deploy_blueprint_for - - -@pytest.fixture(scope="module") -def get_contract(w3, optimize): - def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) - - return fn - - -@pytest.fixture -def get_logs(w3): - def get_logs(tx_hash, c, event_name): - tx_receipt = w3.eth.get_transaction_receipt(tx_hash) - return c._classic_contract.events[event_name]().process_receipt(tx_receipt) - - return get_logs - - -@pytest.fixture(scope="module") -def assert_tx_failed(tester): - def assert_tx_failed(function_to_test, exception=TransactionFailed, exc_text=None): - snapshot_id = tester.take_snapshot() - with pytest.raises(exception) as excinfo: - function_to_test() - tester.revert_to_snapshot(snapshot_id) - if exc_text: - # TODO test equality - assert exc_text in str(excinfo.value), (exc_text, excinfo.value) - - return assert_tx_failed diff --git a/tests/conftest.py b/tests/conftest.py index 9b10b7c51c..216fb32b0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,24 +1,28 @@ +import json import logging from functools import wraps import hypothesis import pytest +import web3.exceptions from eth_tester import EthereumTester, PyEVMBackend +from eth_tester.exceptions import TransactionFailed from eth_utils import setup_DEBUG2_logging +from eth_utils.toolz import compose from hexbytes import HexBytes from web3 import Web3 +from web3.contract import Contract from web3.providers.eth_tester import EthereumTesterProvider from vyper import compiler +from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode from vyper.compiler.input_bundle import FilesystemInputBundle -from vyper.compiler.settings import OptimizationLevel, _set_debug_mode +from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.ir import compile_ir, optimizer -from .base_conftest import VyperContract, _get_contract, zero_gas_price_strategy - -# Import the base_conftest fixtures -pytest_plugins = ["tests.base_conftest", "tests.fixtures.memorymock"] +# Import the base fixtures +pytest_plugins = ["tests.fixtures.memorymock"] ############ # PATCHING # @@ -99,6 +103,8 @@ def fn(sources_dict): return fn +# TODO: remove me, this is just string.encode("utf-8").ljust() +# only used in test_logging.py. @pytest.fixture def bytes_helper(): def bytes_helper(str, length): @@ -107,45 +113,35 @@ def bytes_helper(str, length): return bytes_helper -@pytest.fixture -def get_contract_from_ir(w3, optimize): - def ir_compiler(ir, *args, **kwargs): - ir = IRnode.from_list(ir) - if optimize != OptimizationLevel.NONE: - ir = optimizer.optimize(ir) - bytecode, _ = compile_ir.assembly_to_evm( - compile_ir.compile_to_assembly(ir, optimize=optimize) - ) - abi = kwargs.get("abi") or [] - c = w3.eth.contract(abi=abi, bytecode=bytecode) - deploy_transaction = c.constructor() - tx_hash = deploy_transaction.transact() - address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] - contract = w3.eth.contract( - address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract - ) - return contract +def _none_addr(datatype, data): + if datatype == "address" and int(data, base=16) == 0: + return (datatype, None) + else: + return (datatype, data) - return ir_compiler + +CONCISE_NORMALIZERS = (_none_addr,) @pytest.fixture(scope="module") -def get_contract_module(optimize): - """ - This fixture is used for Hypothesis tests to ensure that - the same contract is called over multiple runs of the test. - """ - custom_genesis = PyEVMBackend._generate_genesis_params(overrides={"gas_limit": 4500000}) +def tester(): + # set absurdly high gas limit so that london basefee never adjusts + # (note: 2**63 - 1 is max that evm allows) + custom_genesis = PyEVMBackend._generate_genesis_params(overrides={"gas_limit": 10**10}) custom_genesis["base_fee_per_gas"] = 0 backend = PyEVMBackend(genesis_parameters=custom_genesis) - tester = EthereumTester(backend=backend) - w3 = Web3(EthereumTesterProvider(tester)) - w3.eth.set_gas_price_strategy(zero_gas_price_strategy) + return EthereumTester(backend=backend) - def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) - return get_contract_module +def zero_gas_price_strategy(web3, transaction_params=None): + return 0 # zero gas price makes testing simpler. + + +@pytest.fixture(scope="module") +def w3(tester): + w3 = Web3(EthereumTesterProvider(tester)) + w3.eth.set_gas_price_strategy(zero_gas_price_strategy) + return w3 def get_compiler_gas_estimate(code, func): @@ -187,6 +183,130 @@ def set_decorator_to_contract_function(w3, tester, contract, source_code, func): setattr(contract, func, func_with_decorator) +class VyperMethod: + ALLOWED_MODIFIERS = {"call", "estimateGas", "transact", "buildTransaction"} + + def __init__(self, function, normalizers=None): + self._function = function + self._function._return_data_normalizers = normalizers + + def __call__(self, *args, **kwargs): + return self.__prepared_function(*args, **kwargs) + + def __prepared_function(self, *args, **kwargs): + if not kwargs: + modifier, modifier_dict = "call", {} + fn_abi = [ + x + for x in self._function.contract_abi + if x.get("name") == self._function.function_identifier + ].pop() + # To make tests faster just supply some high gas value. + modifier_dict.update({"gas": fn_abi.get("gas", 0) + 500000}) + elif len(kwargs) == 1: + modifier, modifier_dict = kwargs.popitem() + if modifier not in self.ALLOWED_MODIFIERS: + raise TypeError(f"The only allowed keyword arguments are: {self.ALLOWED_MODIFIERS}") + else: + raise TypeError(f"Use up to one keyword argument, one of: {self.ALLOWED_MODIFIERS}") + return getattr(self._function(*args), modifier)(modifier_dict) + + +class VyperContract: + """ + An alternative Contract Factory which invokes all methods as `call()`, + unless you add a keyword argument. The keyword argument assigns the prep method. + This call + > contract.withdraw(amount, transact={'from': eth.accounts[1], 'gas': 100000, ...}) + is equivalent to this call in the classic contract: + > contract.functions.withdraw(amount).transact({'from': eth.accounts[1], 'gas': 100000, ...}) + """ + + def __init__(self, classic_contract, method_class=VyperMethod): + classic_contract._return_data_normalizers += CONCISE_NORMALIZERS + self._classic_contract = classic_contract + self.address = self._classic_contract.address + protected_fn_names = [fn for fn in dir(self) if not fn.endswith("__")] + + try: + fn_names = [fn["name"] for fn in self._classic_contract.functions._functions] + except web3.exceptions.NoABIFunctionsFound: + fn_names = [] + + for fn_name in fn_names: + # Override namespace collisions + if fn_name in protected_fn_names: + raise AttributeError(f"{fn_name} is protected!") + else: + _classic_method = getattr(self._classic_contract.functions, fn_name) + _concise_method = method_class( + _classic_method, self._classic_contract._return_data_normalizers + ) + setattr(self, fn_name, _concise_method) + + @classmethod + def factory(cls, *args, **kwargs): + return compose(cls, Contract.factory(*args, **kwargs)) + + +@pytest.fixture +def get_contract_from_ir(w3, optimize): + def ir_compiler(ir, *args, **kwargs): + ir = IRnode.from_list(ir) + if optimize != OptimizationLevel.NONE: + ir = optimizer.optimize(ir) + bytecode, _ = compile_ir.assembly_to_evm( + compile_ir.compile_to_assembly(ir, optimize=optimize) + ) + abi = kwargs.get("abi") or [] + c = w3.eth.contract(abi=abi, bytecode=bytecode) + deploy_transaction = c.constructor() + tx_hash = deploy_transaction.transact() + address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] + contract = w3.eth.contract( + address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract + ) + return contract + + return ir_compiler + + +def _get_contract( + w3, source_code, optimize, *args, override_opt_level=None, input_bundle=None, **kwargs +): + settings = Settings() + settings.evm_version = kwargs.pop("evm_version", None) + settings.optimize = override_opt_level or optimize + out = compiler.compile_code( + source_code, + # test that metadata and natspecs get generated + output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], + settings=settings, + input_bundle=input_bundle, + show_gas_estimates=True, # Enable gas estimates for testing + ) + parse_vyper_source(source_code) # Test grammar. + json.dumps(out["metadata"]) # test metadata is json serializable + abi = out["abi"] + bytecode = out["bytecode"] + value = kwargs.pop("value_in_eth", 0) * 10**18 # Handle deploying with an eth value. + c = w3.eth.contract(abi=abi, bytecode=bytecode) + deploy_transaction = c.constructor(*args) + tx_info = {"from": w3.eth.accounts[0], "value": value, "gasPrice": 0} + tx_info.update(kwargs) + tx_hash = deploy_transaction.transact(tx_info) + address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] + return w3.eth.contract(address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract) + + +@pytest.fixture(scope="module") +def get_contract(w3, optimize): + def fn(source_code, *args, **kwargs): + return _get_contract(w3, source_code, optimize, *args, **kwargs) + + return fn + + @pytest.fixture def get_contract_with_gas_estimation(tester, w3, optimize): def get_contract_with_gas_estimation(source_code, *args, **kwargs): @@ -207,6 +327,73 @@ def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs) return get_contract_with_gas_estimation_for_constants +@pytest.fixture(scope="module") +def get_contract_module(optimize): + """ + This fixture is used for Hypothesis tests to ensure that + the same contract is called over multiple runs of the test. + """ + custom_genesis = PyEVMBackend._generate_genesis_params(overrides={"gas_limit": 4500000}) + custom_genesis["base_fee_per_gas"] = 0 + backend = PyEVMBackend(genesis_parameters=custom_genesis) + tester = EthereumTester(backend=backend) + w3 = Web3(EthereumTesterProvider(tester)) + w3.eth.set_gas_price_strategy(zero_gas_price_strategy) + + def get_contract_module(source_code, *args, **kwargs): + return _get_contract(w3, source_code, optimize, *args, **kwargs) + + return get_contract_module + + +def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs): + settings = Settings() + settings.evm_version = kwargs.pop("evm_version", None) + settings.optimize = optimize + out = compiler.compile_code( + source_code, + output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], + settings=settings, + show_gas_estimates=True, # Enable gas estimates for testing + ) + parse_vyper_source(source_code) # Test grammar. + abi = out["abi"] + bytecode = HexBytes(initcode_prefix) + HexBytes(out["bytecode"]) + bytecode_len = len(bytecode) + bytecode_len_hex = hex(bytecode_len)[2:].rjust(4, "0") + # prepend a quick deploy preamble + deploy_preamble = HexBytes("61" + bytecode_len_hex + "3d81600a3d39f3") + deploy_bytecode = HexBytes(deploy_preamble) + bytecode + + deployer_abi = [] # just a constructor + c = w3.eth.contract(abi=deployer_abi, bytecode=deploy_bytecode) + deploy_transaction = c.constructor() + tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0} + + tx_hash = deploy_transaction.transact(tx_info) + address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] + + # sanity check + assert w3.eth.get_code(address) == bytecode, (w3.eth.get_code(address), bytecode) + + def factory(address): + return w3.eth.contract( + address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract + ) + + return w3.eth.contract(address, bytecode=deploy_bytecode), factory + + +@pytest.fixture(scope="module") +def deploy_blueprint_for(w3, optimize): + def deploy_blueprint_for(source_code, *args, **kwargs): + return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs) + + return deploy_blueprint_for + + +# TODO: this should not be a fixture. +# remove me and replace all uses with `with pytest.raises`. @pytest.fixture def assert_compile_failed(): def assert_compile_failed(function_to_test, exception=Exception): @@ -216,6 +403,7 @@ def assert_compile_failed(function_to_test, exception=Exception): return assert_compile_failed +# TODO this should not be a fixture @pytest.fixture def search_for_sublist(): def search_for_sublist(ir, sublist): @@ -277,3 +465,27 @@ def assert_side_effects_invoked(side_effects_contract, side_effects_trigger, n=1 assert end_value == start_value + n return assert_side_effects_invoked + + +@pytest.fixture +def get_logs(w3): + def get_logs(tx_hash, c, event_name): + tx_receipt = w3.eth.get_transaction_receipt(tx_hash) + return c._classic_contract.events[event_name]().process_receipt(tx_receipt) + + return get_logs + + +# TODO replace me with function like `with anchor_state()` +@pytest.fixture(scope="module") +def assert_tx_failed(tester): + def assert_tx_failed(function_to_test, exception=TransactionFailed, exc_text=None): + snapshot_id = tester.take_snapshot() + with pytest.raises(exception) as excinfo: + function_to_test() + tester.revert_to_snapshot(snapshot_id) + if exc_text: + # TODO test equality + assert exc_text in str(excinfo.value), (exc_text, excinfo.value) + + return assert_tx_failed diff --git a/tests/compiler/ir/__init__.py b/tests/functional/__init__.py similarity index 100% rename from tests/compiler/ir/__init__.py rename to tests/functional/__init__.py diff --git a/tests/parser/exceptions/__init__.py b/tests/functional/builtins/codegen/__init__.py similarity index 100% rename from tests/parser/exceptions/__init__.py rename to tests/functional/builtins/codegen/__init__.py diff --git a/tests/parser/functions/test_abi.py b/tests/functional/builtins/codegen/test_abi.py similarity index 100% rename from tests/parser/functions/test_abi.py rename to tests/functional/builtins/codegen/test_abi.py diff --git a/tests/parser/functions/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py similarity index 100% rename from tests/parser/functions/test_abi_decode.py rename to tests/functional/builtins/codegen/test_abi_decode.py diff --git a/tests/parser/functions/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py similarity index 100% rename from tests/parser/functions/test_abi_encode.py rename to tests/functional/builtins/codegen/test_abi_encode.py diff --git a/tests/parser/functions/test_addmod.py b/tests/functional/builtins/codegen/test_addmod.py similarity index 100% rename from tests/parser/functions/test_addmod.py rename to tests/functional/builtins/codegen/test_addmod.py diff --git a/tests/parser/types/value/test_as_wei_value.py b/tests/functional/builtins/codegen/test_as_wei_value.py similarity index 75% rename from tests/parser/types/value/test_as_wei_value.py rename to tests/functional/builtins/codegen/test_as_wei_value.py index 249ac4b2ff..cc27507e7c 100644 --- a/tests/parser/types/value/test_as_wei_value.py +++ b/tests/functional/builtins/codegen/test_as_wei_value.py @@ -91,3 +91,36 @@ def foo(a: {data_type}) -> uint256: c = get_contract(code) assert c.foo(0) == 0 + + +def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> uint256: + return as_wei_value(a.foo(7), "ether") + +interface Foo: + def foo(x: uint8) -> uint8: nonpayable + """ + + c1 = side_effects_contract("uint8") + c2 = get_contract(code) + + assert c2.foo(c1.address) == w3.to_wei(7, "ether") + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_internal_call(w3, get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return as_wei_value(self.bar(), "ether") + +@internal +def bar() -> uint8: + return 7 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == w3.to_wei(7, "ether") diff --git a/tests/parser/functions/test_bitwise.py b/tests/functional/builtins/codegen/test_bitwise.py similarity index 100% rename from tests/parser/functions/test_bitwise.py rename to tests/functional/builtins/codegen/test_bitwise.py diff --git a/tests/parser/functions/test_ceil.py b/tests/functional/builtins/codegen/test_ceil.py similarity index 100% rename from tests/parser/functions/test_ceil.py rename to tests/functional/builtins/codegen/test_ceil.py diff --git a/tests/parser/functions/test_concat.py b/tests/functional/builtins/codegen/test_concat.py similarity index 100% rename from tests/parser/functions/test_concat.py rename to tests/functional/builtins/codegen/test_concat.py diff --git a/tests/parser/functions/test_convert.py b/tests/functional/builtins/codegen/test_convert.py similarity index 100% rename from tests/parser/functions/test_convert.py rename to tests/functional/builtins/codegen/test_convert.py diff --git a/tests/parser/functions/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py similarity index 100% rename from tests/parser/functions/test_create_functions.py rename to tests/functional/builtins/codegen/test_create_functions.py diff --git a/tests/parser/functions/test_ec.py b/tests/functional/builtins/codegen/test_ec.py similarity index 100% rename from tests/parser/functions/test_ec.py rename to tests/functional/builtins/codegen/test_ec.py diff --git a/tests/parser/functions/test_ecrecover.py b/tests/functional/builtins/codegen/test_ecrecover.py similarity index 100% rename from tests/parser/functions/test_ecrecover.py rename to tests/functional/builtins/codegen/test_ecrecover.py diff --git a/tests/parser/functions/test_empty.py b/tests/functional/builtins/codegen/test_empty.py similarity index 100% rename from tests/parser/functions/test_empty.py rename to tests/functional/builtins/codegen/test_empty.py diff --git a/tests/parser/functions/test_extract32.py b/tests/functional/builtins/codegen/test_extract32.py similarity index 100% rename from tests/parser/functions/test_extract32.py rename to tests/functional/builtins/codegen/test_extract32.py diff --git a/tests/parser/functions/test_floor.py b/tests/functional/builtins/codegen/test_floor.py similarity index 100% rename from tests/parser/functions/test_floor.py rename to tests/functional/builtins/codegen/test_floor.py diff --git a/tests/parser/functions/test_interfaces.py b/tests/functional/builtins/codegen/test_interfaces.py similarity index 100% rename from tests/parser/functions/test_interfaces.py rename to tests/functional/builtins/codegen/test_interfaces.py diff --git a/tests/parser/functions/test_is_contract.py b/tests/functional/builtins/codegen/test_is_contract.py similarity index 100% rename from tests/parser/functions/test_is_contract.py rename to tests/functional/builtins/codegen/test_is_contract.py diff --git a/tests/parser/functions/test_keccak256.py b/tests/functional/builtins/codegen/test_keccak256.py similarity index 100% rename from tests/parser/functions/test_keccak256.py rename to tests/functional/builtins/codegen/test_keccak256.py diff --git a/tests/parser/functions/test_length.py b/tests/functional/builtins/codegen/test_length.py similarity index 100% rename from tests/parser/functions/test_length.py rename to tests/functional/builtins/codegen/test_length.py diff --git a/tests/parser/functions/test_method_id.py b/tests/functional/builtins/codegen/test_method_id.py similarity index 100% rename from tests/parser/functions/test_method_id.py rename to tests/functional/builtins/codegen/test_method_id.py diff --git a/tests/parser/functions/test_minmax.py b/tests/functional/builtins/codegen/test_minmax.py similarity index 100% rename from tests/parser/functions/test_minmax.py rename to tests/functional/builtins/codegen/test_minmax.py diff --git a/tests/parser/functions/test_minmax_value.py b/tests/functional/builtins/codegen/test_minmax_value.py similarity index 100% rename from tests/parser/functions/test_minmax_value.py rename to tests/functional/builtins/codegen/test_minmax_value.py diff --git a/tests/parser/functions/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py similarity index 100% rename from tests/parser/functions/test_mulmod.py rename to tests/functional/builtins/codegen/test_mulmod.py diff --git a/tests/parser/functions/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py similarity index 100% rename from tests/parser/functions/test_raw_call.py rename to tests/functional/builtins/codegen/test_raw_call.py diff --git a/tests/parser/functions/test_send.py b/tests/functional/builtins/codegen/test_send.py similarity index 100% rename from tests/parser/functions/test_send.py rename to tests/functional/builtins/codegen/test_send.py diff --git a/tests/parser/functions/test_sha256.py b/tests/functional/builtins/codegen/test_sha256.py similarity index 100% rename from tests/parser/functions/test_sha256.py rename to tests/functional/builtins/codegen/test_sha256.py diff --git a/tests/parser/functions/test_slice.py b/tests/functional/builtins/codegen/test_slice.py similarity index 100% rename from tests/parser/functions/test_slice.py rename to tests/functional/builtins/codegen/test_slice.py diff --git a/tests/parser/functions/test_mkstr.py b/tests/functional/builtins/codegen/test_uint2str.py similarity index 100% rename from tests/parser/functions/test_mkstr.py rename to tests/functional/builtins/codegen/test_uint2str.py diff --git a/tests/parser/functions/test_unary.py b/tests/functional/builtins/codegen/test_unary.py similarity index 100% rename from tests/parser/functions/test_unary.py rename to tests/functional/builtins/codegen/test_unary.py diff --git a/tests/parser/functions/test_unsafe_math.py b/tests/functional/builtins/codegen/test_unsafe_math.py similarity index 100% rename from tests/parser/functions/test_unsafe_math.py rename to tests/functional/builtins/codegen/test_unsafe_math.py diff --git a/tests/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py similarity index 100% rename from tests/builtins/folding/test_abs.py rename to tests/functional/builtins/folding/test_abs.py diff --git a/tests/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py similarity index 100% rename from tests/builtins/folding/test_addmod_mulmod.py rename to tests/functional/builtins/folding/test_addmod_mulmod.py diff --git a/tests/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py similarity index 100% rename from tests/builtins/folding/test_bitwise.py rename to tests/functional/builtins/folding/test_bitwise.py diff --git a/tests/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py similarity index 100% rename from tests/builtins/folding/test_epsilon.py rename to tests/functional/builtins/folding/test_epsilon.py diff --git a/tests/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py similarity index 100% rename from tests/builtins/folding/test_floor_ceil.py rename to tests/functional/builtins/folding/test_floor_ceil.py diff --git a/tests/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py similarity index 100% rename from tests/builtins/folding/test_fold_as_wei_value.py rename to tests/functional/builtins/folding/test_fold_as_wei_value.py diff --git a/tests/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py similarity index 100% rename from tests/builtins/folding/test_keccak_sha.py rename to tests/functional/builtins/folding/test_keccak_sha.py diff --git a/tests/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py similarity index 100% rename from tests/builtins/folding/test_len.py rename to tests/functional/builtins/folding/test_len.py diff --git a/tests/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py similarity index 100% rename from tests/builtins/folding/test_min_max.py rename to tests/functional/builtins/folding/test_min_max.py diff --git a/tests/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py similarity index 100% rename from tests/builtins/folding/test_powmod.py rename to tests/functional/builtins/folding/test_powmod.py diff --git a/tests/parser/functions/__init__.py b/tests/functional/codegen/__init__.py similarity index 100% rename from tests/parser/functions/__init__.py rename to tests/functional/codegen/__init__.py diff --git a/tests/parser/functions/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py similarity index 100% rename from tests/parser/functions/test_default_function.py rename to tests/functional/codegen/calling_convention/test_default_function.py diff --git a/tests/parser/functions/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py similarity index 100% rename from tests/parser/functions/test_default_parameters.py rename to tests/functional/codegen/calling_convention/test_default_parameters.py diff --git a/tests/parser/features/external_contracts/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py similarity index 100% rename from tests/parser/features/external_contracts/test_erc20_abi.py rename to tests/functional/codegen/calling_convention/test_erc20_abi.py diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py similarity index 100% rename from tests/parser/features/external_contracts/test_external_contract_calls.py rename to tests/functional/codegen/calling_convention/test_external_contract_calls.py diff --git a/tests/parser/features/external_contracts/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py similarity index 100% rename from tests/parser/features/external_contracts/test_modifiable_external_contract_calls.py rename to tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py diff --git a/tests/parser/functions/test_return.py b/tests/functional/codegen/calling_convention/test_return.py similarity index 100% rename from tests/parser/functions/test_return.py rename to tests/functional/codegen/calling_convention/test_return.py diff --git a/tests/parser/functions/test_return_struct.py b/tests/functional/codegen/calling_convention/test_return_struct.py similarity index 100% rename from tests/parser/functions/test_return_struct.py rename to tests/functional/codegen/calling_convention/test_return_struct.py diff --git a/tests/parser/functions/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py similarity index 100% rename from tests/parser/functions/test_return_tuple.py rename to tests/functional/codegen/calling_convention/test_return_tuple.py diff --git a/tests/parser/features/external_contracts/test_self_call_struct.py b/tests/functional/codegen/calling_convention/test_self_call_struct.py similarity index 100% rename from tests/parser/features/external_contracts/test_self_call_struct.py rename to tests/functional/codegen/calling_convention/test_self_call_struct.py diff --git a/tests/functional/codegen/test_struct_return.py b/tests/functional/codegen/calling_convention/test_struct_return.py similarity index 100% rename from tests/functional/codegen/test_struct_return.py rename to tests/functional/codegen/calling_convention/test_struct_return.py diff --git a/tests/functional/codegen/test_tuple_return.py b/tests/functional/codegen/calling_convention/test_tuple_return.py similarity index 100% rename from tests/functional/codegen/test_tuple_return.py rename to tests/functional/codegen/calling_convention/test_tuple_return.py diff --git a/tests/parser/functions/test_block_number.py b/tests/functional/codegen/environment_variables/test_block_number.py similarity index 100% rename from tests/parser/functions/test_block_number.py rename to tests/functional/codegen/environment_variables/test_block_number.py diff --git a/tests/parser/functions/test_block.py b/tests/functional/codegen/environment_variables/test_blockhash.py similarity index 100% rename from tests/parser/functions/test_block.py rename to tests/functional/codegen/environment_variables/test_blockhash.py diff --git a/tests/parser/functions/test_tx.py b/tests/functional/codegen/environment_variables/test_tx.py similarity index 100% rename from tests/parser/functions/test_tx.py rename to tests/functional/codegen/environment_variables/test_tx.py diff --git a/tests/parser/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py similarity index 100% rename from tests/parser/features/decorators/test_nonreentrant.py rename to tests/functional/codegen/features/decorators/test_nonreentrant.py diff --git a/tests/parser/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py similarity index 100% rename from tests/parser/features/decorators/test_payable.py rename to tests/functional/codegen/features/decorators/test_payable.py diff --git a/tests/parser/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py similarity index 100% rename from tests/parser/features/decorators/test_private.py rename to tests/functional/codegen/features/decorators/test_private.py diff --git a/tests/parser/features/decorators/test_public.py b/tests/functional/codegen/features/decorators/test_public.py similarity index 100% rename from tests/parser/features/decorators/test_public.py rename to tests/functional/codegen/features/decorators/test_public.py diff --git a/tests/parser/features/decorators/test_pure.py b/tests/functional/codegen/features/decorators/test_pure.py similarity index 100% rename from tests/parser/features/decorators/test_pure.py rename to tests/functional/codegen/features/decorators/test_pure.py diff --git a/tests/parser/features/decorators/test_view.py b/tests/functional/codegen/features/decorators/test_view.py similarity index 100% rename from tests/parser/features/decorators/test_view.py rename to tests/functional/codegen/features/decorators/test_view.py diff --git a/tests/parser/features/iteration/test_break.py b/tests/functional/codegen/features/iteration/test_break.py similarity index 100% rename from tests/parser/features/iteration/test_break.py rename to tests/functional/codegen/features/iteration/test_break.py diff --git a/tests/parser/features/iteration/test_continue.py b/tests/functional/codegen/features/iteration/test_continue.py similarity index 100% rename from tests/parser/features/iteration/test_continue.py rename to tests/functional/codegen/features/iteration/test_continue.py diff --git a/tests/parser/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py similarity index 100% rename from tests/parser/features/iteration/test_for_in_list.py rename to tests/functional/codegen/features/iteration/test_for_in_list.py diff --git a/tests/parser/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py similarity index 100% rename from tests/parser/features/iteration/test_for_range.py rename to tests/functional/codegen/features/iteration/test_for_range.py diff --git a/tests/parser/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py similarity index 100% rename from tests/parser/features/iteration/test_range_in.py rename to tests/functional/codegen/features/iteration/test_range_in.py diff --git a/tests/parser/features/test_address_balance.py b/tests/functional/codegen/features/test_address_balance.py similarity index 100% rename from tests/parser/features/test_address_balance.py rename to tests/functional/codegen/features/test_address_balance.py diff --git a/tests/parser/features/test_assert.py b/tests/functional/codegen/features/test_assert.py similarity index 100% rename from tests/parser/features/test_assert.py rename to tests/functional/codegen/features/test_assert.py diff --git a/tests/parser/features/test_assert_unreachable.py b/tests/functional/codegen/features/test_assert_unreachable.py similarity index 100% rename from tests/parser/features/test_assert_unreachable.py rename to tests/functional/codegen/features/test_assert_unreachable.py diff --git a/tests/parser/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py similarity index 100% rename from tests/parser/features/test_assignment.py rename to tests/functional/codegen/features/test_assignment.py diff --git a/tests/parser/features/test_bytes_map_keys.py b/tests/functional/codegen/features/test_bytes_map_keys.py similarity index 100% rename from tests/parser/features/test_bytes_map_keys.py rename to tests/functional/codegen/features/test_bytes_map_keys.py diff --git a/tests/parser/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py similarity index 100% rename from tests/parser/features/test_clampers.py rename to tests/functional/codegen/features/test_clampers.py diff --git a/tests/parser/features/test_comments.py b/tests/functional/codegen/features/test_comments.py similarity index 100% rename from tests/parser/features/test_comments.py rename to tests/functional/codegen/features/test_comments.py diff --git a/tests/parser/features/test_comparison.py b/tests/functional/codegen/features/test_comparison.py similarity index 100% rename from tests/parser/features/test_comparison.py rename to tests/functional/codegen/features/test_comparison.py diff --git a/tests/parser/features/test_conditionals.py b/tests/functional/codegen/features/test_conditionals.py similarity index 100% rename from tests/parser/features/test_conditionals.py rename to tests/functional/codegen/features/test_conditionals.py diff --git a/tests/parser/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py similarity index 100% rename from tests/parser/features/test_constructor.py rename to tests/functional/codegen/features/test_constructor.py diff --git a/tests/parser/features/test_gas.py b/tests/functional/codegen/features/test_gas.py similarity index 100% rename from tests/parser/features/test_gas.py rename to tests/functional/codegen/features/test_gas.py diff --git a/tests/parser/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py similarity index 100% rename from tests/parser/features/test_immutable.py rename to tests/functional/codegen/features/test_immutable.py diff --git a/tests/parser/features/test_init.py b/tests/functional/codegen/features/test_init.py similarity index 100% rename from tests/parser/features/test_init.py rename to tests/functional/codegen/features/test_init.py diff --git a/tests/parser/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py similarity index 100% rename from tests/parser/features/test_internal_call.py rename to tests/functional/codegen/features/test_internal_call.py diff --git a/tests/parser/features/test_logging.py b/tests/functional/codegen/features/test_logging.py similarity index 100% rename from tests/parser/features/test_logging.py rename to tests/functional/codegen/features/test_logging.py diff --git a/tests/parser/features/test_logging_bytes_extended.py b/tests/functional/codegen/features/test_logging_bytes_extended.py similarity index 100% rename from tests/parser/features/test_logging_bytes_extended.py rename to tests/functional/codegen/features/test_logging_bytes_extended.py diff --git a/tests/parser/features/test_logging_from_call.py b/tests/functional/codegen/features/test_logging_from_call.py similarity index 100% rename from tests/parser/features/test_logging_from_call.py rename to tests/functional/codegen/features/test_logging_from_call.py diff --git a/tests/parser/features/test_memory_alloc.py b/tests/functional/codegen/features/test_memory_alloc.py similarity index 100% rename from tests/parser/features/test_memory_alloc.py rename to tests/functional/codegen/features/test_memory_alloc.py diff --git a/tests/parser/features/test_memory_dealloc.py b/tests/functional/codegen/features/test_memory_dealloc.py similarity index 100% rename from tests/parser/features/test_memory_dealloc.py rename to tests/functional/codegen/features/test_memory_dealloc.py diff --git a/tests/parser/features/test_packing.py b/tests/functional/codegen/features/test_packing.py similarity index 100% rename from tests/parser/features/test_packing.py rename to tests/functional/codegen/features/test_packing.py diff --git a/tests/parser/features/test_reverting.py b/tests/functional/codegen/features/test_reverting.py similarity index 100% rename from tests/parser/features/test_reverting.py rename to tests/functional/codegen/features/test_reverting.py diff --git a/tests/parser/features/test_short_circuiting.py b/tests/functional/codegen/features/test_short_circuiting.py similarity index 100% rename from tests/parser/features/test_short_circuiting.py rename to tests/functional/codegen/features/test_short_circuiting.py diff --git a/tests/parser/features/test_string_map_keys.py b/tests/functional/codegen/features/test_string_map_keys.py similarity index 100% rename from tests/parser/features/test_string_map_keys.py rename to tests/functional/codegen/features/test_string_map_keys.py diff --git a/tests/parser/features/test_ternary.py b/tests/functional/codegen/features/test_ternary.py similarity index 100% rename from tests/parser/features/test_ternary.py rename to tests/functional/codegen/features/test_ternary.py diff --git a/tests/parser/features/test_transient.py b/tests/functional/codegen/features/test_transient.py similarity index 100% rename from tests/parser/features/test_transient.py rename to tests/functional/codegen/features/test_transient.py diff --git a/tests/parser/integration/test_basics.py b/tests/functional/codegen/integration/test_basics.py similarity index 100% rename from tests/parser/integration/test_basics.py rename to tests/functional/codegen/integration/test_basics.py diff --git a/tests/parser/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py similarity index 98% rename from tests/parser/integration/test_crowdfund.py rename to tests/functional/codegen/integration/test_crowdfund.py index c45a60d9c7..47c63dc015 100644 --- a/tests/parser/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -1,3 +1,4 @@ +# TODO: check, this is probably redundant with examples/test_crowdfund.py def test_crowdfund(w3, tester, get_contract_with_gas_estimation_for_constants): crowdfund = """ diff --git a/tests/parser/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py similarity index 100% rename from tests/parser/integration/test_escrow.py rename to tests/functional/codegen/integration/test_escrow.py diff --git a/tests/parser/globals/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py similarity index 100% rename from tests/parser/globals/test_getters.py rename to tests/functional/codegen/storage_variables/test_getters.py diff --git a/tests/parser/globals/test_setters.py b/tests/functional/codegen/storage_variables/test_setters.py similarity index 100% rename from tests/parser/globals/test_setters.py rename to tests/functional/codegen/storage_variables/test_setters.py diff --git a/tests/parser/globals/test_globals.py b/tests/functional/codegen/storage_variables/test_storage_variable.py similarity index 100% rename from tests/parser/globals/test_globals.py rename to tests/functional/codegen/storage_variables/test_storage_variable.py diff --git a/tests/parser/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py similarity index 100% rename from tests/parser/test_call_graph_stability.py rename to tests/functional/codegen/test_call_graph_stability.py diff --git a/tests/parser/test_selector_table.py b/tests/functional/codegen/test_selector_table.py similarity index 100% rename from tests/parser/test_selector_table.py rename to tests/functional/codegen/test_selector_table.py diff --git a/tests/parser/test_selector_table_stability.py b/tests/functional/codegen/test_selector_table_stability.py similarity index 100% rename from tests/parser/test_selector_table_stability.py rename to tests/functional/codegen/test_selector_table_stability.py diff --git a/tests/parser/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py similarity index 100% rename from tests/parser/types/numbers/test_constants.py rename to tests/functional/codegen/types/numbers/test_constants.py diff --git a/tests/parser/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py similarity index 100% rename from tests/parser/types/numbers/test_decimals.py rename to tests/functional/codegen/types/numbers/test_decimals.py diff --git a/tests/parser/features/arithmetic/test_division.py b/tests/functional/codegen/types/numbers/test_division.py similarity index 100% rename from tests/parser/features/arithmetic/test_division.py rename to tests/functional/codegen/types/numbers/test_division.py diff --git a/tests/fuzzing/test_exponents.py b/tests/functional/codegen/types/numbers/test_exponents.py similarity index 100% rename from tests/fuzzing/test_exponents.py rename to tests/functional/codegen/types/numbers/test_exponents.py diff --git a/tests/parser/types/numbers/test_isqrt.py b/tests/functional/codegen/types/numbers/test_isqrt.py similarity index 100% rename from tests/parser/types/numbers/test_isqrt.py rename to tests/functional/codegen/types/numbers/test_isqrt.py diff --git a/tests/parser/features/arithmetic/test_modulo.py b/tests/functional/codegen/types/numbers/test_modulo.py similarity index 100% rename from tests/parser/features/arithmetic/test_modulo.py rename to tests/functional/codegen/types/numbers/test_modulo.py diff --git a/tests/parser/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py similarity index 100% rename from tests/parser/types/numbers/test_signed_ints.py rename to tests/functional/codegen/types/numbers/test_signed_ints.py diff --git a/tests/parser/types/numbers/test_sqrt.py b/tests/functional/codegen/types/numbers/test_sqrt.py similarity index 100% rename from tests/parser/types/numbers/test_sqrt.py rename to tests/functional/codegen/types/numbers/test_sqrt.py diff --git a/tests/parser/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py similarity index 100% rename from tests/parser/types/numbers/test_unsigned_ints.py rename to tests/functional/codegen/types/numbers/test_unsigned_ints.py diff --git a/tests/parser/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py similarity index 100% rename from tests/parser/types/test_bytes.py rename to tests/functional/codegen/types/test_bytes.py diff --git a/tests/parser/types/test_bytes_literal.py b/tests/functional/codegen/types/test_bytes_literal.py similarity index 100% rename from tests/parser/types/test_bytes_literal.py rename to tests/functional/codegen/types/test_bytes_literal.py diff --git a/tests/parser/types/test_bytes_zero_padding.py b/tests/functional/codegen/types/test_bytes_zero_padding.py similarity index 100% rename from tests/parser/types/test_bytes_zero_padding.py rename to tests/functional/codegen/types/test_bytes_zero_padding.py diff --git a/tests/parser/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py similarity index 100% rename from tests/parser/types/test_dynamic_array.py rename to tests/functional/codegen/types/test_dynamic_array.py diff --git a/tests/parser/types/test_enum.py b/tests/functional/codegen/types/test_enum.py similarity index 100% rename from tests/parser/types/test_enum.py rename to tests/functional/codegen/types/test_enum.py diff --git a/tests/parser/types/test_identifier_naming.py b/tests/functional/codegen/types/test_identifier_naming.py old mode 100755 new mode 100644 similarity index 100% rename from tests/parser/types/test_identifier_naming.py rename to tests/functional/codegen/types/test_identifier_naming.py diff --git a/tests/parser/types/test_lists.py b/tests/functional/codegen/types/test_lists.py similarity index 100% rename from tests/parser/types/test_lists.py rename to tests/functional/codegen/types/test_lists.py diff --git a/tests/parser/types/test_node_types.py b/tests/functional/codegen/types/test_node_types.py similarity index 100% rename from tests/parser/types/test_node_types.py rename to tests/functional/codegen/types/test_node_types.py diff --git a/tests/parser/types/test_string.py b/tests/functional/codegen/types/test_string.py similarity index 100% rename from tests/parser/types/test_string.py rename to tests/functional/codegen/types/test_string.py diff --git a/tests/parser/types/test_string_literal.py b/tests/functional/codegen/types/test_string_literal.py similarity index 100% rename from tests/parser/types/test_string_literal.py rename to tests/functional/codegen/types/test_string_literal.py diff --git a/tests/examples/auctions/test_blind_auction.py b/tests/functional/examples/auctions/test_blind_auction.py similarity index 100% rename from tests/examples/auctions/test_blind_auction.py rename to tests/functional/examples/auctions/test_blind_auction.py diff --git a/tests/examples/auctions/test_simple_open_auction.py b/tests/functional/examples/auctions/test_simple_open_auction.py similarity index 100% rename from tests/examples/auctions/test_simple_open_auction.py rename to tests/functional/examples/auctions/test_simple_open_auction.py diff --git a/tests/examples/company/test_company.py b/tests/functional/examples/company/test_company.py similarity index 100% rename from tests/examples/company/test_company.py rename to tests/functional/examples/company/test_company.py diff --git a/tests/examples/conftest.py b/tests/functional/examples/conftest.py similarity index 100% rename from tests/examples/conftest.py rename to tests/functional/examples/conftest.py diff --git a/tests/examples/crowdfund/test_crowdfund_example.py b/tests/functional/examples/crowdfund/test_crowdfund_example.py similarity index 100% rename from tests/examples/crowdfund/test_crowdfund_example.py rename to tests/functional/examples/crowdfund/test_crowdfund_example.py diff --git a/tests/examples/factory/test_factory.py b/tests/functional/examples/factory/test_factory.py similarity index 100% rename from tests/examples/factory/test_factory.py rename to tests/functional/examples/factory/test_factory.py diff --git a/tests/examples/market_maker/test_on_chain_market_maker.py b/tests/functional/examples/market_maker/test_on_chain_market_maker.py similarity index 100% rename from tests/examples/market_maker/test_on_chain_market_maker.py rename to tests/functional/examples/market_maker/test_on_chain_market_maker.py diff --git a/tests/examples/name_registry/test_name_registry.py b/tests/functional/examples/name_registry/test_name_registry.py similarity index 100% rename from tests/examples/name_registry/test_name_registry.py rename to tests/functional/examples/name_registry/test_name_registry.py diff --git a/tests/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py similarity index 100% rename from tests/examples/safe_remote_purchase/test_safe_remote_purchase.py rename to tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py diff --git a/tests/examples/storage/test_advanced_storage.py b/tests/functional/examples/storage/test_advanced_storage.py similarity index 100% rename from tests/examples/storage/test_advanced_storage.py rename to tests/functional/examples/storage/test_advanced_storage.py diff --git a/tests/examples/storage/test_storage.py b/tests/functional/examples/storage/test_storage.py similarity index 100% rename from tests/examples/storage/test_storage.py rename to tests/functional/examples/storage/test_storage.py diff --git a/tests/examples/tokens/test_erc1155.py b/tests/functional/examples/tokens/test_erc1155.py similarity index 100% rename from tests/examples/tokens/test_erc1155.py rename to tests/functional/examples/tokens/test_erc1155.py diff --git a/tests/examples/tokens/test_erc20.py b/tests/functional/examples/tokens/test_erc20.py similarity index 100% rename from tests/examples/tokens/test_erc20.py rename to tests/functional/examples/tokens/test_erc20.py diff --git a/tests/examples/tokens/test_erc4626.py b/tests/functional/examples/tokens/test_erc4626.py similarity index 100% rename from tests/examples/tokens/test_erc4626.py rename to tests/functional/examples/tokens/test_erc4626.py diff --git a/tests/examples/tokens/test_erc721.py b/tests/functional/examples/tokens/test_erc721.py similarity index 100% rename from tests/examples/tokens/test_erc721.py rename to tests/functional/examples/tokens/test_erc721.py diff --git a/tests/examples/voting/test_ballot.py b/tests/functional/examples/voting/test_ballot.py similarity index 100% rename from tests/examples/voting/test_ballot.py rename to tests/functional/examples/voting/test_ballot.py diff --git a/tests/examples/wallet/test_wallet.py b/tests/functional/examples/wallet/test_wallet.py similarity index 100% rename from tests/examples/wallet/test_wallet.py rename to tests/functional/examples/wallet/test_wallet.py diff --git a/tests/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py similarity index 100% rename from tests/grammar/test_grammar.py rename to tests/functional/grammar/test_grammar.py diff --git a/tests/parser/syntax/__init__.py b/tests/functional/syntax/__init__.py similarity index 100% rename from tests/parser/syntax/__init__.py rename to tests/functional/syntax/__init__.py diff --git a/tests/parser/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py similarity index 100% rename from tests/parser/exceptions/test_argument_exception.py rename to tests/functional/syntax/exceptions/test_argument_exception.py diff --git a/tests/parser/exceptions/test_call_violation.py b/tests/functional/syntax/exceptions/test_call_violation.py similarity index 100% rename from tests/parser/exceptions/test_call_violation.py rename to tests/functional/syntax/exceptions/test_call_violation.py diff --git a/tests/parser/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py similarity index 100% rename from tests/parser/exceptions/test_constancy_exception.py rename to tests/functional/syntax/exceptions/test_constancy_exception.py diff --git a/tests/parser/exceptions/test_function_declaration_exception.py b/tests/functional/syntax/exceptions/test_function_declaration_exception.py similarity index 100% rename from tests/parser/exceptions/test_function_declaration_exception.py rename to tests/functional/syntax/exceptions/test_function_declaration_exception.py diff --git a/tests/parser/exceptions/test_instantiation_exception.py b/tests/functional/syntax/exceptions/test_instantiation_exception.py similarity index 100% rename from tests/parser/exceptions/test_instantiation_exception.py rename to tests/functional/syntax/exceptions/test_instantiation_exception.py diff --git a/tests/parser/exceptions/test_invalid_literal_exception.py b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py similarity index 100% rename from tests/parser/exceptions/test_invalid_literal_exception.py rename to tests/functional/syntax/exceptions/test_invalid_literal_exception.py diff --git a/tests/parser/exceptions/test_invalid_payable.py b/tests/functional/syntax/exceptions/test_invalid_payable.py similarity index 100% rename from tests/parser/exceptions/test_invalid_payable.py rename to tests/functional/syntax/exceptions/test_invalid_payable.py diff --git a/tests/parser/exceptions/test_invalid_reference.py b/tests/functional/syntax/exceptions/test_invalid_reference.py similarity index 100% rename from tests/parser/exceptions/test_invalid_reference.py rename to tests/functional/syntax/exceptions/test_invalid_reference.py diff --git a/tests/parser/exceptions/test_invalid_type_exception.py b/tests/functional/syntax/exceptions/test_invalid_type_exception.py similarity index 100% rename from tests/parser/exceptions/test_invalid_type_exception.py rename to tests/functional/syntax/exceptions/test_invalid_type_exception.py diff --git a/tests/parser/exceptions/test_namespace_collision.py b/tests/functional/syntax/exceptions/test_namespace_collision.py similarity index 100% rename from tests/parser/exceptions/test_namespace_collision.py rename to tests/functional/syntax/exceptions/test_namespace_collision.py diff --git a/tests/parser/exceptions/test_overflow_exception.py b/tests/functional/syntax/exceptions/test_overflow_exception.py similarity index 100% rename from tests/parser/exceptions/test_overflow_exception.py rename to tests/functional/syntax/exceptions/test_overflow_exception.py diff --git a/tests/parser/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py similarity index 100% rename from tests/parser/exceptions/test_structure_exception.py rename to tests/functional/syntax/exceptions/test_structure_exception.py diff --git a/tests/parser/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py similarity index 100% rename from tests/parser/exceptions/test_syntax_exception.py rename to tests/functional/syntax/exceptions/test_syntax_exception.py diff --git a/tests/parser/exceptions/test_type_mismatch_exception.py b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py similarity index 100% rename from tests/parser/exceptions/test_type_mismatch_exception.py rename to tests/functional/syntax/exceptions/test_type_mismatch_exception.py diff --git a/tests/parser/exceptions/test_undeclared_definition.py b/tests/functional/syntax/exceptions/test_undeclared_definition.py similarity index 100% rename from tests/parser/exceptions/test_undeclared_definition.py rename to tests/functional/syntax/exceptions/test_undeclared_definition.py diff --git a/tests/parser/exceptions/test_variable_declaration_exception.py b/tests/functional/syntax/exceptions/test_variable_declaration_exception.py similarity index 100% rename from tests/parser/exceptions/test_variable_declaration_exception.py rename to tests/functional/syntax/exceptions/test_variable_declaration_exception.py diff --git a/tests/parser/exceptions/test_vyper_exception_pos.py b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py similarity index 100% rename from tests/parser/exceptions/test_vyper_exception_pos.py rename to tests/functional/syntax/exceptions/test_vyper_exception_pos.py diff --git a/tests/parser/syntax/utils/test_event_names.py b/tests/functional/syntax/names/test_event_names.py similarity index 100% rename from tests/parser/syntax/utils/test_event_names.py rename to tests/functional/syntax/names/test_event_names.py diff --git a/tests/parser/syntax/utils/test_function_names.py b/tests/functional/syntax/names/test_function_names.py similarity index 100% rename from tests/parser/syntax/utils/test_function_names.py rename to tests/functional/syntax/names/test_function_names.py diff --git a/tests/parser/syntax/utils/test_variable_names.py b/tests/functional/syntax/names/test_variable_names.py similarity index 100% rename from tests/parser/syntax/utils/test_variable_names.py rename to tests/functional/syntax/names/test_variable_names.py diff --git a/tests/signatures/test_invalid_function_decorators.py b/tests/functional/syntax/signatures/test_invalid_function_decorators.py similarity index 100% rename from tests/signatures/test_invalid_function_decorators.py rename to tests/functional/syntax/signatures/test_invalid_function_decorators.py diff --git a/tests/signatures/test_method_id_conflicts.py b/tests/functional/syntax/signatures/test_method_id_conflicts.py similarity index 100% rename from tests/signatures/test_method_id_conflicts.py rename to tests/functional/syntax/signatures/test_method_id_conflicts.py diff --git a/tests/parser/syntax/test_abi_decode.py b/tests/functional/syntax/test_abi_decode.py similarity index 100% rename from tests/parser/syntax/test_abi_decode.py rename to tests/functional/syntax/test_abi_decode.py diff --git a/tests/parser/syntax/test_abi_encode.py b/tests/functional/syntax/test_abi_encode.py similarity index 100% rename from tests/parser/syntax/test_abi_encode.py rename to tests/functional/syntax/test_abi_encode.py diff --git a/tests/parser/syntax/test_addmulmod.py b/tests/functional/syntax/test_addmulmod.py similarity index 100% rename from tests/parser/syntax/test_addmulmod.py rename to tests/functional/syntax/test_addmulmod.py diff --git a/tests/parser/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py similarity index 100% rename from tests/parser/syntax/test_address_code.py rename to tests/functional/syntax/test_address_code.py diff --git a/tests/parser/syntax/test_ann_assign.py b/tests/functional/syntax/test_ann_assign.py similarity index 100% rename from tests/parser/syntax/test_ann_assign.py rename to tests/functional/syntax/test_ann_assign.py diff --git a/tests/parser/syntax/test_as_uint256.py b/tests/functional/syntax/test_as_uint256.py similarity index 100% rename from tests/parser/syntax/test_as_uint256.py rename to tests/functional/syntax/test_as_uint256.py diff --git a/tests/parser/syntax/test_as_wei_value.py b/tests/functional/syntax/test_as_wei_value.py similarity index 100% rename from tests/parser/syntax/test_as_wei_value.py rename to tests/functional/syntax/test_as_wei_value.py diff --git a/tests/parser/syntax/test_block.py b/tests/functional/syntax/test_block.py similarity index 100% rename from tests/parser/syntax/test_block.py rename to tests/functional/syntax/test_block.py diff --git a/tests/parser/syntax/test_blockscope.py b/tests/functional/syntax/test_blockscope.py similarity index 100% rename from tests/parser/syntax/test_blockscope.py rename to tests/functional/syntax/test_blockscope.py diff --git a/tests/parser/syntax/test_bool.py b/tests/functional/syntax/test_bool.py similarity index 100% rename from tests/parser/syntax/test_bool.py rename to tests/functional/syntax/test_bool.py diff --git a/tests/parser/syntax/test_bool_ops.py b/tests/functional/syntax/test_bool_ops.py similarity index 100% rename from tests/parser/syntax/test_bool_ops.py rename to tests/functional/syntax/test_bool_ops.py diff --git a/tests/parser/syntax/test_bytes.py b/tests/functional/syntax/test_bytes.py similarity index 100% rename from tests/parser/syntax/test_bytes.py rename to tests/functional/syntax/test_bytes.py diff --git a/tests/parser/syntax/test_chainid.py b/tests/functional/syntax/test_chainid.py similarity index 100% rename from tests/parser/syntax/test_chainid.py rename to tests/functional/syntax/test_chainid.py diff --git a/tests/parser/syntax/test_code_size.py b/tests/functional/syntax/test_code_size.py similarity index 100% rename from tests/parser/syntax/test_code_size.py rename to tests/functional/syntax/test_code_size.py diff --git a/tests/parser/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py similarity index 100% rename from tests/parser/syntax/test_codehash.py rename to tests/functional/syntax/test_codehash.py diff --git a/tests/parser/syntax/test_concat.py b/tests/functional/syntax/test_concat.py similarity index 100% rename from tests/parser/syntax/test_concat.py rename to tests/functional/syntax/test_concat.py diff --git a/tests/parser/syntax/test_conditionals.py b/tests/functional/syntax/test_conditionals.py similarity index 100% rename from tests/parser/syntax/test_conditionals.py rename to tests/functional/syntax/test_conditionals.py diff --git a/tests/parser/syntax/test_constants.py b/tests/functional/syntax/test_constants.py similarity index 100% rename from tests/parser/syntax/test_constants.py rename to tests/functional/syntax/test_constants.py diff --git a/tests/parser/syntax/test_create_with_code_of.py b/tests/functional/syntax/test_create_with_code_of.py similarity index 100% rename from tests/parser/syntax/test_create_with_code_of.py rename to tests/functional/syntax/test_create_with_code_of.py diff --git a/tests/parser/syntax/test_dynamic_array.py b/tests/functional/syntax/test_dynamic_array.py similarity index 100% rename from tests/parser/syntax/test_dynamic_array.py rename to tests/functional/syntax/test_dynamic_array.py diff --git a/tests/parser/syntax/test_enum.py b/tests/functional/syntax/test_enum.py similarity index 100% rename from tests/parser/syntax/test_enum.py rename to tests/functional/syntax/test_enum.py diff --git a/tests/parser/syntax/test_extract32.py b/tests/functional/syntax/test_extract32.py similarity index 100% rename from tests/parser/syntax/test_extract32.py rename to tests/functional/syntax/test_extract32.py diff --git a/tests/parser/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py similarity index 100% rename from tests/parser/syntax/test_for_range.py rename to tests/functional/syntax/test_for_range.py diff --git a/tests/parser/syntax/test_functions_call.py b/tests/functional/syntax/test_functions_call.py similarity index 100% rename from tests/parser/syntax/test_functions_call.py rename to tests/functional/syntax/test_functions_call.py diff --git a/tests/parser/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py similarity index 100% rename from tests/parser/syntax/test_immutables.py rename to tests/functional/syntax/test_immutables.py diff --git a/tests/parser/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py similarity index 100% rename from tests/parser/syntax/test_interfaces.py rename to tests/functional/syntax/test_interfaces.py diff --git a/tests/parser/syntax/test_invalids.py b/tests/functional/syntax/test_invalids.py similarity index 100% rename from tests/parser/syntax/test_invalids.py rename to tests/functional/syntax/test_invalids.py diff --git a/tests/parser/syntax/test_keccak256.py b/tests/functional/syntax/test_keccak256.py similarity index 100% rename from tests/parser/syntax/test_keccak256.py rename to tests/functional/syntax/test_keccak256.py diff --git a/tests/parser/syntax/test_len.py b/tests/functional/syntax/test_len.py similarity index 100% rename from tests/parser/syntax/test_len.py rename to tests/functional/syntax/test_len.py diff --git a/tests/parser/syntax/test_list.py b/tests/functional/syntax/test_list.py similarity index 100% rename from tests/parser/syntax/test_list.py rename to tests/functional/syntax/test_list.py diff --git a/tests/parser/syntax/test_logging.py b/tests/functional/syntax/test_logging.py similarity index 100% rename from tests/parser/syntax/test_logging.py rename to tests/functional/syntax/test_logging.py diff --git a/tests/parser/syntax/test_minmax.py b/tests/functional/syntax/test_minmax.py similarity index 100% rename from tests/parser/syntax/test_minmax.py rename to tests/functional/syntax/test_minmax.py diff --git a/tests/parser/syntax/test_minmax_value.py b/tests/functional/syntax/test_minmax_value.py similarity index 100% rename from tests/parser/syntax/test_minmax_value.py rename to tests/functional/syntax/test_minmax_value.py diff --git a/tests/parser/syntax/test_msg_data.py b/tests/functional/syntax/test_msg_data.py similarity index 100% rename from tests/parser/syntax/test_msg_data.py rename to tests/functional/syntax/test_msg_data.py diff --git a/tests/parser/syntax/test_nested_list.py b/tests/functional/syntax/test_nested_list.py similarity index 100% rename from tests/parser/syntax/test_nested_list.py rename to tests/functional/syntax/test_nested_list.py diff --git a/tests/parser/syntax/test_no_none.py b/tests/functional/syntax/test_no_none.py similarity index 100% rename from tests/parser/syntax/test_no_none.py rename to tests/functional/syntax/test_no_none.py diff --git a/tests/parser/syntax/test_print.py b/tests/functional/syntax/test_print.py similarity index 100% rename from tests/parser/syntax/test_print.py rename to tests/functional/syntax/test_print.py diff --git a/tests/parser/syntax/test_public.py b/tests/functional/syntax/test_public.py similarity index 100% rename from tests/parser/syntax/test_public.py rename to tests/functional/syntax/test_public.py diff --git a/tests/parser/syntax/test_raw_call.py b/tests/functional/syntax/test_raw_call.py similarity index 100% rename from tests/parser/syntax/test_raw_call.py rename to tests/functional/syntax/test_raw_call.py diff --git a/tests/parser/syntax/test_return_tuple.py b/tests/functional/syntax/test_return_tuple.py similarity index 100% rename from tests/parser/syntax/test_return_tuple.py rename to tests/functional/syntax/test_return_tuple.py diff --git a/tests/parser/syntax/test_self_balance.py b/tests/functional/syntax/test_self_balance.py similarity index 100% rename from tests/parser/syntax/test_self_balance.py rename to tests/functional/syntax/test_self_balance.py diff --git a/tests/parser/syntax/test_selfdestruct.py b/tests/functional/syntax/test_selfdestruct.py similarity index 100% rename from tests/parser/syntax/test_selfdestruct.py rename to tests/functional/syntax/test_selfdestruct.py diff --git a/tests/parser/syntax/test_send.py b/tests/functional/syntax/test_send.py similarity index 100% rename from tests/parser/syntax/test_send.py rename to tests/functional/syntax/test_send.py diff --git a/tests/parser/syntax/test_slice.py b/tests/functional/syntax/test_slice.py similarity index 100% rename from tests/parser/syntax/test_slice.py rename to tests/functional/syntax/test_slice.py diff --git a/tests/parser/syntax/test_string.py b/tests/functional/syntax/test_string.py similarity index 100% rename from tests/parser/syntax/test_string.py rename to tests/functional/syntax/test_string.py diff --git a/tests/parser/syntax/test_structs.py b/tests/functional/syntax/test_structs.py similarity index 100% rename from tests/parser/syntax/test_structs.py rename to tests/functional/syntax/test_structs.py diff --git a/tests/parser/syntax/test_ternary.py b/tests/functional/syntax/test_ternary.py similarity index 100% rename from tests/parser/syntax/test_ternary.py rename to tests/functional/syntax/test_ternary.py diff --git a/tests/parser/syntax/test_tuple_assign.py b/tests/functional/syntax/test_tuple_assign.py similarity index 100% rename from tests/parser/syntax/test_tuple_assign.py rename to tests/functional/syntax/test_tuple_assign.py diff --git a/tests/parser/syntax/test_unbalanced_return.py b/tests/functional/syntax/test_unbalanced_return.py similarity index 100% rename from tests/parser/syntax/test_unbalanced_return.py rename to tests/functional/syntax/test_unbalanced_return.py diff --git a/tests/parser/functions/test_as_wei_value.py b/tests/parser/functions/test_as_wei_value.py deleted file mode 100644 index bab0aed616..0000000000 --- a/tests/parser/functions/test_as_wei_value.py +++ /dev/null @@ -1,31 +0,0 @@ -def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): - code = """ -@external -def foo(a: Foo) -> uint256: - return as_wei_value(a.foo(7), "ether") - -interface Foo: - def foo(x: uint8) -> uint8: nonpayable - """ - - c1 = side_effects_contract("uint8") - c2 = get_contract(code) - - assert c2.foo(c1.address) == w3.to_wei(7, "ether") - assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) - - -def test_internal_call(w3, get_contract_with_gas_estimation): - code = """ -@external -def foo() -> uint256: - return as_wei_value(self.bar(), "ether") - -@internal -def bar() -> uint8: - return 7 - """ - - c = get_contract_with_gas_estimation(code) - - assert c.foo() == w3.to_wei(7, "ether") diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/abi_types/test_invalid_abi_types.py b/tests/unit/abi_types/test_invalid_abi_types.py similarity index 100% rename from tests/abi_types/test_invalid_abi_types.py rename to tests/unit/abi_types/test_invalid_abi_types.py diff --git a/tests/ast/nodes/test_binary.py b/tests/unit/ast/nodes/test_binary.py similarity index 100% rename from tests/ast/nodes/test_binary.py rename to tests/unit/ast/nodes/test_binary.py diff --git a/tests/ast/nodes/test_compare_nodes.py b/tests/unit/ast/nodes/test_compare_nodes.py similarity index 100% rename from tests/ast/nodes/test_compare_nodes.py rename to tests/unit/ast/nodes/test_compare_nodes.py diff --git a/tests/ast/nodes/test_evaluate_binop_decimal.py b/tests/unit/ast/nodes/test_evaluate_binop_decimal.py similarity index 100% rename from tests/ast/nodes/test_evaluate_binop_decimal.py rename to tests/unit/ast/nodes/test_evaluate_binop_decimal.py diff --git a/tests/ast/nodes/test_evaluate_binop_int.py b/tests/unit/ast/nodes/test_evaluate_binop_int.py similarity index 100% rename from tests/ast/nodes/test_evaluate_binop_int.py rename to tests/unit/ast/nodes/test_evaluate_binop_int.py diff --git a/tests/ast/nodes/test_evaluate_boolop.py b/tests/unit/ast/nodes/test_evaluate_boolop.py similarity index 100% rename from tests/ast/nodes/test_evaluate_boolop.py rename to tests/unit/ast/nodes/test_evaluate_boolop.py diff --git a/tests/ast/nodes/test_evaluate_compare.py b/tests/unit/ast/nodes/test_evaluate_compare.py similarity index 100% rename from tests/ast/nodes/test_evaluate_compare.py rename to tests/unit/ast/nodes/test_evaluate_compare.py diff --git a/tests/ast/nodes/test_evaluate_subscript.py b/tests/unit/ast/nodes/test_evaluate_subscript.py similarity index 100% rename from tests/ast/nodes/test_evaluate_subscript.py rename to tests/unit/ast/nodes/test_evaluate_subscript.py diff --git a/tests/ast/nodes/test_evaluate_unaryop.py b/tests/unit/ast/nodes/test_evaluate_unaryop.py similarity index 100% rename from tests/ast/nodes/test_evaluate_unaryop.py rename to tests/unit/ast/nodes/test_evaluate_unaryop.py diff --git a/tests/ast/nodes/test_from_node.py b/tests/unit/ast/nodes/test_from_node.py similarity index 100% rename from tests/ast/nodes/test_from_node.py rename to tests/unit/ast/nodes/test_from_node.py diff --git a/tests/ast/nodes/test_get_children.py b/tests/unit/ast/nodes/test_get_children.py similarity index 100% rename from tests/ast/nodes/test_get_children.py rename to tests/unit/ast/nodes/test_get_children.py diff --git a/tests/ast/nodes/test_get_descendants.py b/tests/unit/ast/nodes/test_get_descendants.py similarity index 100% rename from tests/ast/nodes/test_get_descendants.py rename to tests/unit/ast/nodes/test_get_descendants.py diff --git a/tests/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py similarity index 100% rename from tests/ast/nodes/test_hex.py rename to tests/unit/ast/nodes/test_hex.py diff --git a/tests/ast/nodes/test_replace_in_tree.py b/tests/unit/ast/nodes/test_replace_in_tree.py similarity index 100% rename from tests/ast/nodes/test_replace_in_tree.py rename to tests/unit/ast/nodes/test_replace_in_tree.py diff --git a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py similarity index 100% rename from tests/parser/parser_utils/test_annotate_and_optimize_ast.py rename to tests/unit/ast/test_annotate_and_optimize_ast.py diff --git a/tests/parser/ast_utils/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py similarity index 100% rename from tests/parser/ast_utils/test_ast_dict.py rename to tests/unit/ast/test_ast_dict.py diff --git a/tests/ast/test_folding.py b/tests/unit/ast/test_folding.py similarity index 100% rename from tests/ast/test_folding.py rename to tests/unit/ast/test_folding.py diff --git a/tests/ast/test_metadata_journal.py b/tests/unit/ast/test_metadata_journal.py similarity index 100% rename from tests/ast/test_metadata_journal.py rename to tests/unit/ast/test_metadata_journal.py diff --git a/tests/ast/test_natspec.py b/tests/unit/ast/test_natspec.py similarity index 100% rename from tests/ast/test_natspec.py rename to tests/unit/ast/test_natspec.py diff --git a/tests/parser/ast_utils/test_ast.py b/tests/unit/ast/test_parser.py similarity index 100% rename from tests/parser/ast_utils/test_ast.py rename to tests/unit/ast/test_parser.py diff --git a/tests/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py similarity index 100% rename from tests/ast/test_pre_parser.py rename to tests/unit/ast/test_pre_parser.py diff --git a/tests/test_utils.py b/tests/unit/ast/test_source_annotation.py similarity index 100% rename from tests/test_utils.py rename to tests/unit/ast/test_source_annotation.py diff --git a/tests/cli/outputs/test_storage_layout.py b/tests/unit/cli/outputs/test_storage_layout.py similarity index 100% rename from tests/cli/outputs/test_storage_layout.py rename to tests/unit/cli/outputs/test_storage_layout.py diff --git a/tests/cli/outputs/test_storage_layout_overrides.py b/tests/unit/cli/outputs/test_storage_layout_overrides.py similarity index 100% rename from tests/cli/outputs/test_storage_layout_overrides.py rename to tests/unit/cli/outputs/test_storage_layout_overrides.py diff --git a/tests/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py similarity index 100% rename from tests/cli/vyper_compile/test_compile_files.py rename to tests/unit/cli/vyper_compile/test_compile_files.py diff --git a/tests/cli/vyper_compile/test_parse_args.py b/tests/unit/cli/vyper_compile/test_parse_args.py similarity index 100% rename from tests/cli/vyper_compile/test_parse_args.py rename to tests/unit/cli/vyper_compile/test_parse_args.py diff --git a/tests/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py similarity index 100% rename from tests/cli/vyper_json/test_compile_json.py rename to tests/unit/cli/vyper_json/test_compile_json.py diff --git a/tests/cli/vyper_json/test_get_inputs.py b/tests/unit/cli/vyper_json/test_get_inputs.py similarity index 100% rename from tests/cli/vyper_json/test_get_inputs.py rename to tests/unit/cli/vyper_json/test_get_inputs.py diff --git a/tests/cli/vyper_json/test_get_settings.py b/tests/unit/cli/vyper_json/test_get_settings.py similarity index 100% rename from tests/cli/vyper_json/test_get_settings.py rename to tests/unit/cli/vyper_json/test_get_settings.py diff --git a/tests/cli/vyper_json/test_output_selection.py b/tests/unit/cli/vyper_json/test_output_selection.py similarity index 100% rename from tests/cli/vyper_json/test_output_selection.py rename to tests/unit/cli/vyper_json/test_output_selection.py diff --git a/tests/cli/vyper_json/test_parse_args_vyperjson.py b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py similarity index 100% rename from tests/cli/vyper_json/test_parse_args_vyperjson.py rename to tests/unit/cli/vyper_json/test_parse_args_vyperjson.py diff --git a/tests/compiler/__init__.py b/tests/unit/compiler/__init__.py similarity index 100% rename from tests/compiler/__init__.py rename to tests/unit/compiler/__init__.py diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py similarity index 100% rename from tests/compiler/asm/test_asm_optimizer.py rename to tests/unit/compiler/asm/test_asm_optimizer.py diff --git a/tests/unit/compiler/ir/__init__.py b/tests/unit/compiler/ir/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/compiler/test_calldatacopy.py b/tests/unit/compiler/ir/test_calldatacopy.py similarity index 100% rename from tests/compiler/test_calldatacopy.py rename to tests/unit/compiler/ir/test_calldatacopy.py diff --git a/tests/compiler/ir/test_compile_ir.py b/tests/unit/compiler/ir/test_compile_ir.py similarity index 100% rename from tests/compiler/ir/test_compile_ir.py rename to tests/unit/compiler/ir/test_compile_ir.py diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/unit/compiler/ir/test_optimize_ir.py similarity index 100% rename from tests/compiler/ir/test_optimize_ir.py rename to tests/unit/compiler/ir/test_optimize_ir.py diff --git a/tests/compiler/ir/test_repeat.py b/tests/unit/compiler/ir/test_repeat.py similarity index 100% rename from tests/compiler/ir/test_repeat.py rename to tests/unit/compiler/ir/test_repeat.py diff --git a/tests/compiler/ir/test_with.py b/tests/unit/compiler/ir/test_with.py similarity index 100% rename from tests/compiler/ir/test_with.py rename to tests/unit/compiler/ir/test_with.py diff --git a/tests/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py similarity index 100% rename from tests/compiler/test_bytecode_runtime.py rename to tests/unit/compiler/test_bytecode_runtime.py diff --git a/tests/compiler/test_compile_code.py b/tests/unit/compiler/test_compile_code.py similarity index 100% rename from tests/compiler/test_compile_code.py rename to tests/unit/compiler/test_compile_code.py diff --git a/tests/compiler/test_default_settings.py b/tests/unit/compiler/test_default_settings.py similarity index 100% rename from tests/compiler/test_default_settings.py rename to tests/unit/compiler/test_default_settings.py diff --git a/tests/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py similarity index 100% rename from tests/compiler/test_input_bundle.py rename to tests/unit/compiler/test_input_bundle.py diff --git a/tests/compiler/test_opcodes.py b/tests/unit/compiler/test_opcodes.py similarity index 100% rename from tests/compiler/test_opcodes.py rename to tests/unit/compiler/test_opcodes.py diff --git a/tests/compiler/test_pre_parser.py b/tests/unit/compiler/test_pre_parser.py similarity index 100% rename from tests/compiler/test_pre_parser.py rename to tests/unit/compiler/test_pre_parser.py diff --git a/tests/compiler/test_sha3_32.py b/tests/unit/compiler/test_sha3_32.py similarity index 100% rename from tests/compiler/test_sha3_32.py rename to tests/unit/compiler/test_sha3_32.py diff --git a/tests/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py similarity index 100% rename from tests/compiler/test_source_map.py rename to tests/unit/compiler/test_source_map.py diff --git a/tests/functional/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py similarity index 100% rename from tests/functional/semantics/analysis/test_array_index.py rename to tests/unit/semantics/analysis/test_array_index.py diff --git a/tests/functional/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py similarity index 100% rename from tests/functional/semantics/analysis/test_cyclic_function_calls.py rename to tests/unit/semantics/analysis/test_cyclic_function_calls.py diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py similarity index 100% rename from tests/functional/semantics/analysis/test_for_loop.py rename to tests/unit/semantics/analysis/test_for_loop.py diff --git a/tests/functional/semantics/analysis/test_potential_types.py b/tests/unit/semantics/analysis/test_potential_types.py similarity index 100% rename from tests/functional/semantics/analysis/test_potential_types.py rename to tests/unit/semantics/analysis/test_potential_types.py diff --git a/tests/functional/semantics/conftest.py b/tests/unit/semantics/conftest.py similarity index 100% rename from tests/functional/semantics/conftest.py rename to tests/unit/semantics/conftest.py diff --git a/tests/functional/semantics/test_namespace.py b/tests/unit/semantics/test_namespace.py similarity index 100% rename from tests/functional/semantics/test_namespace.py rename to tests/unit/semantics/test_namespace.py diff --git a/tests/functional/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py similarity index 100% rename from tests/functional/test_storage_slots.py rename to tests/unit/semantics/test_storage_slots.py diff --git a/tests/functional/semantics/types/test_event.py b/tests/unit/semantics/types/test_event.py similarity index 100% rename from tests/functional/semantics/types/test_event.py rename to tests/unit/semantics/types/test_event.py diff --git a/tests/functional/semantics/types/test_pure_types.py b/tests/unit/semantics/types/test_pure_types.py similarity index 100% rename from tests/functional/semantics/types/test_pure_types.py rename to tests/unit/semantics/types/test_pure_types.py diff --git a/tests/functional/semantics/types/test_size_in_bytes.py b/tests/unit/semantics/types/test_size_in_bytes.py similarity index 100% rename from tests/functional/semantics/types/test_size_in_bytes.py rename to tests/unit/semantics/types/test_size_in_bytes.py diff --git a/tests/functional/semantics/types/test_type_from_abi.py b/tests/unit/semantics/types/test_type_from_abi.py similarity index 100% rename from tests/functional/semantics/types/test_type_from_abi.py rename to tests/unit/semantics/types/test_type_from_abi.py diff --git a/tests/functional/semantics/types/test_type_from_annotation.py b/tests/unit/semantics/types/test_type_from_annotation.py similarity index 100% rename from tests/functional/semantics/types/test_type_from_annotation.py rename to tests/unit/semantics/types/test_type_from_annotation.py From 98f502baea6385fe25dbf94a70fb4eddc9f02f56 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 20 Nov 2023 23:59:23 +0800 Subject: [PATCH 126/201] feat: remove `vyper-serve` (#3666) moving it out into a separate project --- vyper/__main__.py | 8 +-- vyper/cli/vyper_serve.py | 127 --------------------------------------- 2 files changed, 3 insertions(+), 132 deletions(-) delete mode 100755 vyper/cli/vyper_serve.py diff --git a/vyper/__main__.py b/vyper/__main__.py index 371975c301..c5bda47bea 100644 --- a/vyper/__main__.py +++ b/vyper/__main__.py @@ -2,10 +2,10 @@ # -*- coding: UTF-8 -*- import sys -from vyper.cli import vyper_compile, vyper_ir, vyper_serve +from vyper.cli import vyper_compile, vyper_ir if __name__ == "__main__": - allowed_subcommands = ("--vyper-compile", "--vyper-ir", "--vyper-serve") + allowed_subcommands = ("--vyper-compile", "--vyper-ir") if len(sys.argv) <= 1 or sys.argv[1] not in allowed_subcommands: # default (no args, no switch in first arg): run vyper_compile @@ -13,9 +13,7 @@ else: # pop switch and forward args to subcommand subcommand = sys.argv.pop(1) - if subcommand == "--vyper-serve": - vyper_serve._parse_cli_args() - elif subcommand == "--vyper-ir": + if subcommand == "--vyper-ir": vyper_ir._parse_cli_args() else: vyper_compile._parse_cli_args() diff --git a/vyper/cli/vyper_serve.py b/vyper/cli/vyper_serve.py deleted file mode 100755 index 9771dc922d..0000000000 --- a/vyper/cli/vyper_serve.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import json -import sys -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn - -import vyper -from vyper.codegen import ir_node -from vyper.evm.opcodes import DEFAULT_EVM_VERSION -from vyper.exceptions import VyperException - - -def _parse_cli_args(): - return _parse_args(sys.argv[1:]) - - -def _parse_args(argv): - parser = argparse.ArgumentParser(description="Serve Vyper compiler as an HTTP Service") - parser.add_argument( - "--version", action="version", version=f"{vyper.__version__}+commit{vyper.__commit__}" - ) - parser.add_argument( - "-b", - help="Address to bind JSON server on, default: localhost:8000", - default="localhost:8000", - dest="bind_address", - ) - - args = parser.parse_args(argv) - - if ":" in args.bind_address: - ir_node.VYPER_COLOR_OUTPUT = False - runserver(*args.bind_address.split(":")) - else: - print('Provide bind address in "{address}:{port}" format') - - -class VyperRequestHandler(BaseHTTPRequestHandler): - def send_404(self): - self.send_response(404) - self.end_headers() - return - - def send_cors_all(self): - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-type") - - def do_OPTIONS(self): - self.send_response(200) - self.send_cors_all() - self.end_headers() - - def do_GET(self): - if self.path == "/": - self.send_response(200) - self.send_cors_all() - self.end_headers() - self.wfile.write(f"Vyper Compiler. Version: {vyper.__version__}\n".encode()) - else: - self.send_404() - - return - - def do_POST(self): - if self.path == "/compile": - content_len = int(self.headers.get("content-length")) - post_body = self.rfile.read(content_len) - data = json.loads(post_body) - - response, status_code = self._compile(data) - - self.send_response(status_code) - self.send_header("Content-type", "application/json") - self.send_cors_all() - self.end_headers() - self.wfile.write(json.dumps(response).encode()) - - else: - self.send_404() - - return - - def _compile(self, data): - code = data.get("code") - if not code: - return {"status": "failed", "message": 'No "code" key supplied'}, 400 - if not isinstance(code, str): - return {"status": "failed", "message": '"code" must be a non-empty string'}, 400 - - try: - code = data["code"] - out_dict = vyper.compile_code( - code, - list(vyper.compiler.OUTPUT_FORMATS.keys()), - evm_version=data.get("evm_version", DEFAULT_EVM_VERSION), - ) - out_dict["ir"] = str(out_dict["ir"]) - out_dict["ir_runtime"] = str(out_dict["ir_runtime"]) - except VyperException as e: - return ( - {"status": "failed", "message": str(e), "column": e.col_offset, "line": e.lineno}, - 400, - ) - except SyntaxError as e: - return ( - {"status": "failed", "message": str(e), "column": e.offset, "line": e.lineno}, - 400, - ) - - out_dict.update({"status": "success"}) - - return out_dict, 200 - - -class VyperHTTPServer(ThreadingMixIn, HTTPServer): - """Handle requests in a separate thread.""" - - pass - - -def runserver(host="", port=8000): - server_address = (host, int(port)) - httpd = VyperHTTPServer(server_address, VyperRequestHandler) - print(f"Listening on http://{host}:{port}") - httpd.serve_forever() From 28b1121e6ca8042d10a68a3d91df016bc7b83c5f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 21 Nov 2023 08:36:13 -0500 Subject: [PATCH 127/201] perf: lazy eval of f-strings in IRnode ctor (#3602) 25% of IR generation is in IRnode.__repr__ due to the references to self in the f-strings for panic messages. this commit switches to using `assert`, which accomplishes the same thing, but lazily evaluating the error messages (and the code is slightly less pretty) --- vyper/codegen/ir_node.py | 95 ++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 53 deletions(-) diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index ad4aa76437..e17ef47c8f 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -202,27 +202,23 @@ def __init__( self.encoding = encoding self.as_hex = AS_HEX_DEFAULT - def _check(condition, err): - if not condition: - raise CompilerPanic(str(err)) - - _check(self.value is not None, "None is not allowed as IRnode value") + assert self.value is not None, "None is not allowed as IRnode value" # Determine this node's valency (1 if it pushes a value on the stack, # 0 otherwise) and checks to make sure the number and valencies of # children are correct. Also, find an upper bound on gas consumption # Numbers if isinstance(self.value, int): - _check(len(self.args) == 0, "int can't have arguments") + assert len(self.args) == 0, "int can't have arguments" # integers must be in the range (MIN_INT256, MAX_UINT256) - _check(-(2**255) <= self.value < 2**256, "out of range") + assert -(2**255) <= self.value < 2**256, "out of range" self.valency = 1 self._gas = 5 elif isinstance(self.value, bytes): # a literal bytes value, probably inside a "data" node. - _check(len(self.args) == 0, "bytes can't have arguments") + assert len(self.args) == 0, "bytes can't have arguments" self.valency = 0 self._gas = 0 @@ -232,10 +228,9 @@ def _check(condition, err): if self.value.upper() in get_ir_opcodes(): _, ins, outs, gas = get_ir_opcodes()[self.value.upper()] self.valency = outs - _check( - len(self.args) == ins, - f"Number of arguments mismatched: {self.value} {self.args}", - ) + assert ( + len(self.args) == ins + ), f"Number of arguments mismatched: {self.value} {self.args}" # We add 2 per stack height at push time and take it back # at pop time; this makes `break` easier to handle self._gas = gas + 2 * (outs - ins) @@ -244,10 +239,10 @@ def _check(condition, err): # consumed for internal functions, therefore we whitelist this as a zero valency # allowed argument. zero_valency_whitelist = {"pass", "pop"} - _check( - arg.valency == 1 or arg.value in zero_valency_whitelist, - f"invalid argument to `{self.value}`: {arg}", - ) + assert ( + arg.valency == 1 or arg.value in zero_valency_whitelist + ), f"invalid argument to `{self.value}`: {arg}" + self._gas += arg.gas # Dynamic gas cost: 8 gas for each byte of logging data if self.value.upper()[0:3] == "LOG" and isinstance(self.args[1].value, int): @@ -275,30 +270,27 @@ def _check(condition, err): self._gas = self.args[0].gas + max(self.args[1].gas, self.args[2].gas) + 3 if len(self.args) == 2: self._gas = self.args[0].gas + self.args[1].gas + 17 - _check( - self.args[0].valency > 0, - f"zerovalent argument as a test to an if statement: {self.args[0]}", - ) - _check(len(self.args) in (2, 3), "if statement can only have 2 or 3 arguments") + assert ( + self.args[0].valency > 0 + ), f"zerovalent argument as a test to an if statement: {self.args[0]}" + assert len(self.args) in (2, 3), "if statement can only have 2 or 3 arguments" self.valency = self.args[1].valency # With statements: with elif self.value == "with": - _check(len(self.args) == 3, self) - _check( - len(self.args[0].args) == 0 and isinstance(self.args[0].value, str), - f"first argument to with statement must be a variable name: {self.args[0]}", - ) - _check( - self.args[1].valency == 1 or self.args[1].value == "pass", - f"zerovalent argument to with statement: {self.args[1]}", - ) + assert len(self.args) == 3, self + assert len(self.args[0].args) == 0 and isinstance( + self.args[0].value, str + ), f"first argument to with statement must be a variable name: {self.args[0]}" + assert ( + self.args[1].valency == 1 or self.args[1].value == "pass" + ), f"zerovalent argument to with statement: {self.args[1]}" self.valency = self.args[2].valency self._gas = sum([arg.gas for arg in self.args]) + 5 # Repeat statements: repeat elif self.value == "repeat": - _check( - len(self.args) == 5, "repeat(index_name, startval, rounds, rounds_bound, body)" - ) + assert ( + len(self.args) == 5 + ), "repeat(index_name, startval, rounds, rounds_bound, body)" counter_ptr = self.args[0] start = self.args[1] @@ -306,13 +298,12 @@ def _check(condition, err): repeat_bound = self.args[3] body = self.args[4] - _check( - isinstance(repeat_bound.value, int) and repeat_bound.value > 0, - f"repeat bound must be a compile-time positive integer: {self.args[2]}", - ) - _check(repeat_count.valency == 1, repeat_count) - _check(counter_ptr.valency == 1, counter_ptr) - _check(start.valency == 1, start) + assert ( + isinstance(repeat_bound.value, int) and repeat_bound.value > 0 + ), f"repeat bound must be a compile-time positive integer: {self.args[2]}" + assert repeat_count.valency == 1, repeat_count + assert counter_ptr.valency == 1, counter_ptr + assert start.valency == 1, start self.valency = 0 @@ -335,19 +326,17 @@ def _check(condition, err): # then JUMP to my_label. elif self.value in ("goto", "exit_to"): for arg in self.args: - _check( - arg.valency == 1 or arg.value == "pass", - f"zerovalent argument to goto {arg}", - ) + assert ( + arg.valency == 1 or arg.value == "pass" + ), f"zerovalent argument to goto {arg}" self.valency = 0 self._gas = sum([arg.gas for arg in self.args]) elif self.value == "label": - _check( - self.args[1].value == "var_list", - f"2nd argument to label must be var_list, {self}", - ) - _check(len(args) == 3, f"label should have 3 args but has {len(args)}, {self}") + assert ( + self.args[1].value == "var_list" + ), f"2nd argument to label must be var_list, {self}" + assert len(args) == 3, f"label should have 3 args but has {len(args)}, {self}" self.valency = 0 self._gas = 1 + sum(t.gas for t in self.args) elif self.value == "unique_symbol": @@ -371,14 +360,14 @@ def _check(condition, err): # Multi statements: multi ... elif self.value == "multi": for arg in self.args: - _check( - arg.valency > 0, f"Multi expects all children to not be zerovalent: {arg}" - ) + assert ( + arg.valency > 0 + ), f"Multi expects all children to not be zerovalent: {arg}" self.valency = sum([arg.valency for arg in self.args]) self._gas = sum([arg.gas for arg in self.args]) elif self.value == "deploy": self.valency = 0 - _check(len(self.args) == 3, f"`deploy` should have three args {self}") + assert len(self.args) == 3, f"`deploy` should have three args {self}" self._gas = NullAttractor() # unknown # Stack variables else: From b16ab914fc6126894e19172ba08df0193653edab Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 21 Nov 2023 09:16:07 -0500 Subject: [PATCH 128/201] docs: add script to help working on the compiler (#3674) --- README.md | 17 +++++++++++++++++ docs/contributing.rst | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bad929956d..33c4557cc8 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,23 @@ make dev-init python setup.py test ``` +## Developing (working on the compiler) + +A useful script to have in your PATH is something like the following: +```bash +$ cat ~/.local/bin/vyc +#!/usr/bin/env bash +PYTHONPATH=. python vyper/cli/vyper_compile.py "$@" +``` + +To run a python performance profile (to find compiler perf hotspots): +```bash +PYTHONPATH=. python -m cProfile -s tottime vyper/cli/vyper_compile.py "$@" +``` + +To get a call graph from a python profile, https://stackoverflow.com/a/23164271/ is helpful. + + # Contributing * See Issues tab, and feel free to submit your own issues * Add PRs if you discover a solution to an existing issue diff --git a/docs/contributing.rst b/docs/contributing.rst index 6dc57b26c3..221600f930 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -75,4 +75,4 @@ If you are making a larger change, please consult first with the `Vyper (Smart C Although we do CI testing, please make sure that the tests pass for supported Python version and ensure that it builds locally before submitting a pull request. -Thank you for your help! ​ +Thank you for your help! From aa1ea21a79e577227e13b9756a8c26107c5b3674 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 23 Nov 2023 03:02:47 +0800 Subject: [PATCH 129/201] refactor: builtin functions inherit from `VyperType` (#3559) for consistency, have builtin functions inherit from `VyperType`. --- vyper/builtins/_signatures.py | 25 +++++----- vyper/builtins/functions.py | 91 +++++++++++++++++------------------ 2 files changed, 56 insertions(+), 60 deletions(-) diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 2802421129..a5949dfd85 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,5 +1,5 @@ import functools -from typing import Dict +from typing import Any, Optional from vyper.ast import nodes as vy_ast from vyper.ast.validation import validate_call_args @@ -74,12 +74,14 @@ def decorator_fn(self, node, context): return decorator_fn -class BuiltinFunction(VyperType): +class BuiltinFunctionT(VyperType): _has_varargs = False - _kwargs: Dict[str, KwargSettings] = {} + _inputs: list[tuple[str, Any]] = [] + _kwargs: dict[str, KwargSettings] = {} + _return_type: Optional[VyperType] = None # helper function to deal with TYPE_DEFINITIONs - def _validate_single(self, arg, expected_type): + def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: # TODO using "TYPE_DEFINITION" is a kludge in derived classes, # refactor me. if expected_type == "TYPE_DEFINITION": @@ -89,15 +91,15 @@ def _validate_single(self, arg, expected_type): else: validate_expected_type(arg, expected_type) - def _validate_arg_types(self, node): + def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates - expect_num_args = num_args + expect_num_args: Any = num_args if self._has_varargs: # note special meaning for -1 in validate_call_args API expect_num_args = (num_args, -1) - validate_call_args(node, expect_num_args, self._kwargs) + validate_call_args(node, expect_num_args, list(self._kwargs.keys())) for arg, (_, expected) in zip(node.args, self._inputs): self._validate_single(arg, expected) @@ -118,13 +120,12 @@ def _validate_arg_types(self, node): # ensures the type can be inferred exactly. get_exact_type_from_node(arg) - def fetch_call_return(self, node): + def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) - if self._return_type: - return self._return_type + return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] @@ -136,7 +137,7 @@ def infer_arg_types(self, node): ret.extend(get_exact_type_from_node(arg) for arg in varargs) return ret - def infer_kwarg_types(self, node): + def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]: return {i.arg: self._kwargs[i.arg].typ for i in node.keywords} def __repr__(self): diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 001939638b..b2d817ec5c 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -98,14 +98,14 @@ ) from ._convert import convert -from ._signatures import BuiltinFunction, process_inputs +from ._signatures import BuiltinFunctionT, process_inputs SHA256_ADDRESS = 2 SHA256_BASE_GAS = 60 SHA256_PER_WORD_GAS = 12 -class FoldedFunction(BuiltinFunction): +class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded # Since foldable builtin functions are not folded before semantics validation, @@ -113,7 +113,7 @@ class FoldedFunction(BuiltinFunction): _kwargable = True -class TypenameFoldedFunction(FoldedFunction): +class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. @@ -132,7 +132,7 @@ def infer_arg_types(self, node): return [input_typedef] -class Floor(BuiltinFunction): +class Floor(BuiltinFunctionT): _id = "floor" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -162,7 +162,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Ceil(BuiltinFunction): +class Ceil(BuiltinFunctionT): _id = "ceil" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -192,7 +192,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Convert(BuiltinFunction): +class Convert(BuiltinFunctionT): _id = "convert" def fetch_call_return(self, node): @@ -285,14 +285,13 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: # note: this and a lot of other builtins could be refactored to accept any uint type -class Slice(BuiltinFunction): +class Slice(BuiltinFunctionT): _id = "slice" _inputs = [ ("b", (BYTES32_T, BytesT.any(), StringT.any())), ("start", UINT256_T), ("length", UINT256_T), ] - _return_type = None def fetch_call_return(self, node): arg_type, _, _ = self.infer_arg_types(node) @@ -457,7 +456,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(b3.resolve(ret))) -class Len(BuiltinFunction): +class Len(BuiltinFunctionT): _id = "len" _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T @@ -488,7 +487,7 @@ def build_IR(self, node, context): return get_bytearray_length(arg) -class Concat(BuiltinFunction): +class Concat(BuiltinFunctionT): _id = "concat" def fetch_call_return(self, node): @@ -593,7 +592,7 @@ def build_IR(self, expr, context): ) -class Keccak256(BuiltinFunction): +class Keccak256(BuiltinFunctionT): _id = "keccak256" # TODO allow any BytesM_T _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] @@ -641,7 +640,7 @@ def _make_sha256_call(inp_start, inp_len, out_start, out_len): ] -class Sha256(BuiltinFunction): +class Sha256(BuiltinFunctionT): _id = "sha256" _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T @@ -713,7 +712,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class MethodID(FoldedFunction): +class MethodID(FoldedFunctionT): _id = "method_id" def evaluate(self, node): @@ -753,7 +752,7 @@ def infer_kwarg_types(self, node): return BytesT(4) -class ECRecover(BuiltinFunction): +class ECRecover(BuiltinFunctionT): _id = "ecrecover" _inputs = [ ("hash", BYTES32_T), @@ -788,7 +787,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class _ECArith(BuiltinFunction): +class _ECArith(BuiltinFunctionT): @process_inputs def build_IR(self, expr, _args, kwargs, context): args_tuple = ir_tuple_from_args(_args) @@ -847,14 +846,13 @@ def _storage_element_getter(index): return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T) -class Extract32(BuiltinFunction): +class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] # "TYPE_DEFINITION" is a placeholder value for a type definition string, and # will be replaced by a `TYPE_T` object in `infer_kwarg_types` # (note that it is ignored in _validate_arg_types) _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} - _return_type = None def fetch_call_return(self, node): self._validate_arg_types(node) @@ -959,7 +957,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(clamp_basetype(o), typ=ret_type) -class AsWeiValue(BuiltinFunction): +class AsWeiValue(BuiltinFunctionT): _id = "as_wei_value" _inputs = [("value", (IntegerT.any(), DecimalT())), ("unit", StringT.any())] _return_type = UINT256_T @@ -1058,7 +1056,7 @@ def build_IR(self, expr, args, kwargs, context): empty_value = IRnode.from_list(0, typ=BYTES32_T) -class RawCall(BuiltinFunction): +class RawCall(BuiltinFunctionT): _id = "raw_call" _inputs = [("to", AddressT()), ("data", BytesT.any())] _kwargs = { @@ -1069,7 +1067,6 @@ class RawCall(BuiltinFunction): "is_static_call": KwargSettings(BoolT(), False, require_literal=True), "revert_on_failure": KwargSettings(BoolT(), True, require_literal=True), } - _return_type = None def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1215,12 +1212,11 @@ def build_IR(self, expr, args, kwargs, context): raise CompilerPanic("unreachable!") -class Send(BuiltinFunction): +class Send(BuiltinFunctionT): _id = "send" _inputs = [("to", AddressT()), ("value", UINT256_T)] # default gas stipend is 0 _kwargs = {"gas": KwargSettings(UINT256_T, 0)} - _return_type = None @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -1232,10 +1228,9 @@ def build_IR(self, expr, args, kwargs, context): ) -class SelfDestruct(BuiltinFunction): +class SelfDestruct(BuiltinFunctionT): _id = "selfdestruct" _inputs = [("to", AddressT())] - _return_type = None _is_terminus = True _warned = False @@ -1251,7 +1246,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BlockHash(BuiltinFunction): +class BlockHash(BuiltinFunctionT): _id = "blockhash" _inputs = [("block_num", UINT256_T)] _return_type = BYTES32_T @@ -1264,7 +1259,7 @@ def build_IR(self, expr, args, kwargs, contact): ) -class RawRevert(BuiltinFunction): +class RawRevert(BuiltinFunctionT): _id = "raw_revert" _inputs = [("data", BytesT.any())] _return_type = None @@ -1286,7 +1281,7 @@ def build_IR(self, expr, args, kwargs, context): return b.resolve(IRnode.from_list(["revert", data, len_])) -class RawLog(BuiltinFunction): +class RawLog(BuiltinFunctionT): _id = "raw_log" _inputs = [("topics", DArrayT(BYTES32_T, 4)), ("data", (BYTES32_T, BytesT.any()))] @@ -1337,7 +1332,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BitwiseAnd(BuiltinFunction): +class BitwiseAnd(BuiltinFunctionT): _id = "bitwise_and" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1363,7 +1358,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["and", args[0], args[1]], typ=UINT256_T) -class BitwiseOr(BuiltinFunction): +class BitwiseOr(BuiltinFunctionT): _id = "bitwise_or" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1389,7 +1384,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["or", args[0], args[1]], typ=UINT256_T) -class BitwiseXor(BuiltinFunction): +class BitwiseXor(BuiltinFunctionT): _id = "bitwise_xor" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1415,7 +1410,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["xor", args[0], args[1]], typ=UINT256_T) -class BitwiseNot(BuiltinFunction): +class BitwiseNot(BuiltinFunctionT): _id = "bitwise_not" _inputs = [("x", UINT256_T)] _return_type = UINT256_T @@ -1442,7 +1437,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["not", args[0]], typ=UINT256_T) -class Shift(BuiltinFunction): +class Shift(BuiltinFunctionT): _id = "shift" _inputs = [("x", (UINT256_T, INT256_T)), ("_shift_bits", IntegerT.any())] _return_type = UINT256_T @@ -1496,7 +1491,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(IRnode.from_list(ret, typ=argty))) -class _AddMulMod(BuiltinFunction): +class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T @@ -1537,7 +1532,7 @@ class MulMod(_AddMulMod): _opcode = "mulmod" -class PowMod256(BuiltinFunction): +class PowMod256(BuiltinFunctionT): _id = "pow_mod256" _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T @@ -1560,7 +1555,7 @@ def build_IR(self, expr, context): return IRnode.from_list(["exp", left, right], typ=left.typ) -class Abs(BuiltinFunction): +class Abs(BuiltinFunctionT): _id = "abs" _inputs = [("value", INT256_T)] _return_type = INT256_T @@ -1711,7 +1706,7 @@ def _create_preamble(codesize): return ["or", bytes_to_int(evm), shl(shl_bits, codesize)], evm_len -class _CreateBase(BuiltinFunction): +class _CreateBase(BuiltinFunctionT): _kwargs = { "value": KwargSettings(UINT256_T, zero_value), "salt": KwargSettings(BYTES32_T, empty_value), @@ -1940,7 +1935,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar return b1.resolve(b2.resolve(ir)) -class _UnsafeMath(BuiltinFunction): +class _UnsafeMath(BuiltinFunctionT): # TODO add unsafe math for `decimal`s _inputs = [("a", IntegerT.any()), ("b", IntegerT.any())] @@ -2006,7 +2001,7 @@ class UnsafeDiv(_UnsafeMath): op = "div" -class _MinMax(BuiltinFunction): +class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] def evaluate(self, node): @@ -2080,7 +2075,7 @@ class Max(_MinMax): _opcode = "gt" -class Uint2Str(BuiltinFunction): +class Uint2Str(BuiltinFunctionT): _id = "uint2str" _inputs = [("x", IntegerT.unsigneds())] @@ -2152,7 +2147,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, location=MEMORY, typ=return_t)) -class Sqrt(BuiltinFunction): +class Sqrt(BuiltinFunctionT): _id = "sqrt" _inputs = [("d", DecimalT())] _return_type = DecimalT() @@ -2208,7 +2203,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class ISqrt(BuiltinFunction): +class ISqrt(BuiltinFunctionT): _id = "isqrt" _inputs = [("d", UINT256_T)] _return_type = UINT256_T @@ -2258,7 +2253,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, typ=UINT256_T)) -class Empty(TypenameFoldedFunction): +class Empty(TypenameFoldedFunctionT): _id = "empty" def fetch_call_return(self, node): @@ -2273,7 +2268,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode("~empty", typ=output_type) -class Breakpoint(BuiltinFunction): +class Breakpoint(BuiltinFunctionT): _id = "breakpoint" _inputs: list = [] @@ -2291,7 +2286,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list("breakpoint", annotation="breakpoint()") -class Print(BuiltinFunction): +class Print(BuiltinFunctionT): _id = "print" _inputs: list = [] _has_varargs = True @@ -2369,7 +2364,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, annotation="print:" + sig) -class ABIEncode(BuiltinFunction): +class ABIEncode(BuiltinFunctionT): _id = "_abi_encode" # TODO prettier to rename this to abi.encode # signature: *, ensure_tuple= -> Bytes[] # explanation of ensure_tuple: @@ -2486,7 +2481,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, location=MEMORY, typ=buf_t) -class ABIDecode(BuiltinFunction): +class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} @@ -2573,7 +2568,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class _MinMaxValue(TypenameFoldedFunction): +class _MinMaxValue(TypenameFoldedFunctionT): def evaluate(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2607,7 +2602,7 @@ def _eval(self, type_): return type_.ast_bounds[1] -class Epsilon(TypenameFoldedFunction): +class Epsilon(TypenameFoldedFunctionT): _id = "epsilon" def evaluate(self, node): From b334218f855ae94285afe271a770f1f29d20b7df Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 22 Nov 2023 22:57:30 -0500 Subject: [PATCH 130/201] docs: add warnings at the top of all example token contracts (#3676) discourage people from using them in production --- examples/crowdfund.vy | 6 +++++- examples/tokens/ERC1155ownable.vy | 7 ++++++- examples/tokens/ERC20.vy | 6 +++++- examples/tokens/ERC4626.vy | 7 +++++++ examples/tokens/ERC721.vy | 6 +++++- examples/wallet/wallet.vy | 7 +++++-- 6 files changed, 33 insertions(+), 6 deletions(-) diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 56b34308f1..6d07e15bc4 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,4 +1,8 @@ -# Setup private variables (only callable from within the contract) +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# example of a crowd funding contract funders: HashMap[address, uint256] beneficiary: address diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index f1070b8f89..30057582e8 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,13 @@ +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + # @version >=0.3.4 """ -@dev Implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) +@dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) """ + ############### imports ############### from vyper.interfaces import ERC165 diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 4c1d334691..c3809dbb60 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,4 +1,8 @@ -# @dev Implementation of ERC-20 token standard. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# @dev example implementation of an ERC20 token # @author Takayuki Jimba (@yudetamago) # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-20.md diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index a9cbcc86c8..0a0a698bf0 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,4 +1,11 @@ # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy + +# example implementation of an ERC4626 vault + +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + from vyper.interfaces import ERC20 from vyper.interfaces import ERC4626 diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index 5125040399..152b94b046 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,4 +1,8 @@ -# @dev Implementation of ERC-721 non-fungible token standard. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# @dev example implementation of ERC-721 non-fungible token standard. # @author Ryuya Nakamura (@nrryuya) # Modified from: https://github.com/vyperlang/vyper/blob/de74722bf2d8718cca46902be165f9fe0e3641dd/examples/tokens/ERC721.vy diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 5fd5229136..e2515d9e62 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,5 +1,8 @@ -# An example of how you can do a wallet in Vyper. -# Warning: NOT AUDITED. Do not use to store substantial quantities of funds. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# An example of how you can implement a wallet in Vyper. # A list of the owners addresses (there are a maximum of 5 owners) owners: public(address[5]) From 9a982bd37a8b5a48f9a30939ec57e37ed01a72e0 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 29 Nov 2023 00:54:19 +0900 Subject: [PATCH 131/201] docs: typo in on_chain_market_maker.vy (#3677) --- examples/market_maker/on_chain_market_maker.vy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index be9c62b945..d385d2e0c6 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -9,7 +9,7 @@ invariant: public(uint256) token_address: ERC20 owner: public(address) -# Sets the on chain market maker with its owner, intial token quantity, +# Sets the on chain market maker with its owner, initial token quantity, # and initial ether quantity @external @payable From cbac5aba53f87b388e08f169481d6b5c29002c27 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Fri, 1 Dec 2023 21:41:57 +0200 Subject: [PATCH 132/201] feat: implement new IR for vyper (venom IR) (#3659) this commit implements a new IR for the vyper compiler. most of the implementation is self-contained in the `./vyper/venom/` directory. Venom IR is LLVM-"inspired", although we do not use LLVM on account of: 1) not wanting to introduce a large external dependency 2) no EVM backend exists for LLVM, so we would have to write one ourselves. see prior work at https://github.com/etclabscore/evm_llvm. fundamentally, LLVM is architected to target register machines; an EVM backend could conceivably be implmented, but it would always feel "bolted" on. 3) integration with LLVM would invariably be very complex 4) one advantage of using LLVM is getting multiple backends "for free", but in our case, none of the backends we are interested in (particularly EVM) have LLVM implementations. that being said, Venom is close enough to LLVM that it would seem fairly straightforward to pass "in-and-out" of LLVM, converting to LLVM to take advantage of its optimization passes and/or analysis utilities, and then converting back to Venom for final EVM emission, if that becomes desirable down the line. it could even provided as an "extra" -- if LLVM is installed on the system and enabled for the build, pass to LLVM for extra optimization, but otherwise the compiler being self-contained. for more details about the design and architecture of Venom IR, see `./vyper/venom/README.md`. note that this commit specifically focuses on the architecture, design and implementation of Venom. that is, more focus was spent on architecting the Venom compiler itself. the Vyper frontend does not emit Venom natively yet, Venom emission is implemented as a translation step from the current s-expr based IR to Venom. the translation is not feature-complete, and may have bugs. that being said, vyper compilation via Venom is experimentally available by passing the `--experimental-codegen` flag to vyper on the CLI. incrementally refactoring the codegen to use Venom instead of the earlier s-expr IR will be the next area of focus of development. --------- Co-authored-by: Charles Cooper --- .../compiler/venom/test_duplicate_operands.py | 28 + .../compiler/venom/test_multi_entry_block.py | 96 ++ .../venom/test_stack_at_external_return.py | 5 + vyper/cli/vyper_compile.py | 8 + vyper/codegen/function_definitions/common.py | 4 + .../function_definitions/internal_function.py | 4 +- vyper/codegen/ir_node.py | 16 + vyper/codegen/return_.py | 4 +- vyper/codegen/self_call.py | 2 + vyper/compiler/__init__.py | 2 + vyper/compiler/phases.py | 28 +- vyper/ir/compile_ir.py | 80 +- vyper/ir/optimizer.py | 4 + vyper/semantics/types/function.py | 2 +- vyper/utils.py | 62 +- vyper/venom/README.md | 162 +++ vyper/venom/__init__.py | 56 ++ vyper/venom/analysis.py | 191 ++++ vyper/venom/basicblock.py | 345 +++++++ vyper/venom/bb_optimizer.py | 73 ++ vyper/venom/function.py | 170 ++++ vyper/venom/ir_node_to_venom.py | 943 ++++++++++++++++++ vyper/venom/passes/base_pass.py | 21 + vyper/venom/passes/constant_propagation.py | 13 + vyper/venom/passes/dft.py | 54 + vyper/venom/passes/normalization.py | 90 ++ vyper/venom/stack_model.py | 100 ++ vyper/venom/venom_to_assembly.py | 461 +++++++++ 28 files changed, 2994 insertions(+), 30 deletions(-) create mode 100644 tests/compiler/venom/test_duplicate_operands.py create mode 100644 tests/compiler/venom/test_multi_entry_block.py create mode 100644 tests/compiler/venom/test_stack_at_external_return.py create mode 100644 vyper/venom/README.md create mode 100644 vyper/venom/__init__.py create mode 100644 vyper/venom/analysis.py create mode 100644 vyper/venom/basicblock.py create mode 100644 vyper/venom/bb_optimizer.py create mode 100644 vyper/venom/function.py create mode 100644 vyper/venom/ir_node_to_venom.py create mode 100644 vyper/venom/passes/base_pass.py create mode 100644 vyper/venom/passes/constant_propagation.py create mode 100644 vyper/venom/passes/dft.py create mode 100644 vyper/venom/passes/normalization.py create mode 100644 vyper/venom/stack_model.py create mode 100644 vyper/venom/venom_to_assembly.py diff --git a/tests/compiler/venom/test_duplicate_operands.py b/tests/compiler/venom/test_duplicate_operands.py new file mode 100644 index 0000000000..505f01e31b --- /dev/null +++ b/tests/compiler/venom/test_duplicate_operands.py @@ -0,0 +1,28 @@ +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRFunction + + +def test_duplicate_operands(): + """ + Test the duplicate operands code generation. + The venom code: + + %1 = 10 + %2 = add %1, %1 + %3 = mul %1, %2 + stop + + Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] + """ + ctx = IRFunction() + + op = ctx.append_instruction("store", [IRLiteral(10)]) + sum = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("mul", [sum, op]) + ctx.append_instruction("stop", [], False) + + asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) + + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] diff --git a/tests/compiler/venom/test_multi_entry_block.py b/tests/compiler/venom/test_multi_entry_block.py new file mode 100644 index 0000000000..bb57fa1065 --- /dev/null +++ b/tests/compiler/venom/test_multi_entry_block.py @@ -0,0 +1,96 @@ +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel +from vyper.venom.passes.normalization import NormalizationPass + + +def test_multi_entry_block_1(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + + +# more complicated one +def test_multi_entry_block_2(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + block_2_label = IRLabel("block_2", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, target_label, finish_label], False) + + block_2 = IRBasicBlock(block_2_label, ctx) + ctx.append_basic_block(block_2) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + # switch the order of the labels, for fun + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/tests/compiler/venom/test_stack_at_external_return.py b/tests/compiler/venom/test_stack_at_external_return.py new file mode 100644 index 0000000000..be9fa66e9a --- /dev/null +++ b/tests/compiler/venom/test_stack_at_external_return.py @@ -0,0 +1,5 @@ +def test_stack_at_external_return(): + """ + TODO: USE BOA DO GENERATE THIS TEST + """ + pass diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 82eba63f32..ca1792384e 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -141,6 +141,11 @@ def _parse_args(argv): "-p", help="Set the root path for contract imports", default=".", dest="root_folder" ) parser.add_argument("-o", help="Set the output path", dest="output_path") + parser.add_argument( + "--experimental-codegen", + help="The compiler use the new IR codegen. This is an experimental feature.", + action="store_true", + ) args = parser.parse_args(argv) @@ -188,6 +193,7 @@ def _parse_args(argv): settings, args.storage_layout, args.no_bytecode_metadata, + args.experimental_codegen, ) if args.output_path: @@ -225,6 +231,7 @@ def compile_files( settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> dict: root_path = Path(root_folder).resolve() if not root_path.exists(): @@ -275,6 +282,7 @@ def compile_files( storage_layout_override=storage_layout_override, show_gas_estimates=show_gas_estimates, no_bytecode_metadata=no_bytecode_metadata, + experimental_codegen=experimental_codegen, ) ret[file_path] = output diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 1d24b6c6dd..c48f1256c3 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -162,5 +162,9 @@ def generate_ir_for_function( # (note: internal functions do not need to adjust gas estimate since mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore + ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore + ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + else: + ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore return ret diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 228191e3ca..cf01dbdab4 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -68,4 +68,6 @@ def generate_ir_for_internal_function( ["seq"] + nonreentrant_post + [["exit_to", "return_pc"]], ] - return IRnode.from_list(["seq", body, cleanup_routine]) + ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + ir_node.passthrough_metadata["func_t"] = func_t + return ir_node diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index e17ef47c8f..ce26066968 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -171,6 +171,10 @@ class IRnode: valency: int args: List["IRnode"] value: Union[str, int] + is_self_call: bool + passthrough_metadata: dict[str, Any] + func_ir: Any + common_ir: Any def __init__( self, @@ -184,6 +188,8 @@ def __init__( mutable: bool = True, add_gas_estimate: int = 0, encoding: Encoding = Encoding.VYPER, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, ): if args is None: args = [] @@ -201,6 +207,10 @@ def __init__( self.add_gas_estimate = add_gas_estimate self.encoding = encoding self.as_hex = AS_HEX_DEFAULT + self.is_self_call = is_self_call + self.passthrough_metadata = passthrough_metadata or {} + self.func_ir = None + self.common_ir = None assert self.value is not None, "None is not allowed as IRnode value" @@ -585,6 +595,8 @@ def from_list( error_msg: Optional[str] = None, mutable: bool = True, add_gas_estimate: int = 0, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, encoding: Encoding = Encoding.VYPER, ) -> "IRnode": if isinstance(typ, str): @@ -617,6 +629,8 @@ def from_list( source_pos=source_pos, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) else: return cls( @@ -630,4 +644,6 @@ def from_list( add_gas_estimate=add_gas_estimate, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index 56bea2b8da..41fa11ab56 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -40,7 +40,9 @@ def finalize(fill_return_buffer): cleanup_loops = "cleanup_repeat" if context.forvars else "seq" # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break - return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit]) + jump_to_exit_ir = IRnode.from_list(jump_to_exit) + jump_to_exit_ir.passthrough_metadata["func_t"] = func_t + return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: if context.is_internal: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index c320e6889c..f03f2eb9c8 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -121,4 +121,6 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True + o.passthrough_metadata["func_t"] = func_t + o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 62ea05b243..61d7a7c229 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -55,6 +55,7 @@ def compile_code( no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, exc_handler: Optional[Callable] = None, + experimental_codegen: bool = False, ) -> dict: """ Generate consumable compiler output(s) from a single contract source code. @@ -104,6 +105,7 @@ def compile_code( storage_layout_override, show_gas_estimates, no_bytecode_metadata, + experimental_codegen, ) ret = {} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index bfbb336d54..4e32812fee 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -16,6 +16,7 @@ from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT from vyper.typing import StorageLayout +from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") @@ -60,6 +61,7 @@ def __init__( storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> None: """ Initialization method. @@ -78,14 +80,18 @@ def __init__( Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool, optional + Use experimental codegen. Defaults to False """ + # to force experimental codegen, uncomment: + # experimental_codegen = True self.contract_path = contract_path self.source_code = source_code self.source_id = source_id self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - + self.experimental_codegen = experimental_codegen self.settings = settings or Settings() self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) @@ -160,7 +166,11 @@ def global_ctx(self) -> GlobalContext: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes(self.global_ctx, self.settings.optimize) + nodes = generate_ir_nodes(self.global_ctx, self.settings.optimize) + if self.experimental_codegen: + return [generate_ir(nodes[0]), generate_ir(nodes[1])] + else: + return nodes @property def ir_nodes(self) -> IRnode: @@ -183,11 +193,21 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - return generate_assembly(self.ir_nodes, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_nodes, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_nodes, self.settings.optimize) @cached_property def assembly_runtime(self) -> list: - return generate_assembly(self.ir_runtime, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_runtime, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_runtime, self.settings.optimize) @cached_property def bytecode(self) -> bytes: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 1c4dc1ef7c..1d3df8becb 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -9,6 +9,7 @@ from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic +from vyper.ir.optimizer import COMMUTATIVE_OPS from vyper.utils import MemoryPositions from vyper.version import version_tuple @@ -164,7 +165,7 @@ def _add_postambles(asm_ops): # insert the postambles *before* runtime code # so the data section of the runtime code can't bork the postambles. runtime = None - if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader): + if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], RuntimeHeader): runtime = asm_ops.pop() # for some reason there might not be a STOP at the end of asm_ops. @@ -229,7 +230,7 @@ def compile_to_assembly(code, optimize=OptimizationLevel.GAS): _relocate_segments(res) if optimize != OptimizationLevel.NONE: - _optimize_assembly(res) + optimize_assembly(res) return res @@ -531,7 +532,7 @@ def _height_of(witharg): # since the asm data structures are very primitive, to make sure # assembly_to_evm is able to calculate data offsets correctly, # we pass the memsize via magic opcodes to the subcode - subcode = [_RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode + subcode = [RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode # append the runtime code after the ctor code # `append(...)` call here is intentional. @@ -675,7 +676,7 @@ def _height_of(witharg): ) elif code.value == "data": - data_node = [_DataHeader("_sym_" + code.args[0].value)] + data_node = [DataHeader("_sym_" + code.args[0].value)] for c in code.args[1:]: if isinstance(c.value, int): @@ -837,6 +838,31 @@ def _prune_inefficient_jumps(assembly): return changed +def _optimize_inefficient_jumps(assembly): + # optimize sequences `_sym_common JUMPI _sym_x JUMP _sym_common JUMPDEST` + # to `ISZERO _sym_x JUMPI _sym_common JUMPDEST` + changed = False + i = 0 + while i < len(assembly) - 6: + if ( + is_symbol(assembly[i]) + and assembly[i + 1] == "JUMPI" + and is_symbol(assembly[i + 2]) + and assembly[i + 3] == "JUMP" + and assembly[i] == assembly[i + 4] + and assembly[i + 5] == "JUMPDEST" + ): + changed = True + assembly[i] = "ISZERO" + assembly[i + 1] = assembly[i + 2] + assembly[i + 2] = "JUMPI" + del assembly[i + 3 : i + 4] + else: + i += 1 + + return changed + + def _merge_jumpdests(assembly): # When we have multiple JUMPDESTs in a row, or when a JUMPDEST # is immediately followed by another JUMP, we can skip the @@ -938,7 +964,7 @@ def _prune_unused_jumpdests(assembly): used_jumpdests.add(assembly[i]) for item in assembly: - if isinstance(item, list) and isinstance(item[0], _DataHeader): + if isinstance(item, list) and isinstance(item[0], DataHeader): # add symbols used in data sections as they are likely # used for a jumptable. for t in item: @@ -961,6 +987,12 @@ def _stack_peephole_opts(assembly): changed = False i = 0 while i < len(assembly) - 2: + if assembly[i : i + 3] == ["DUP1", "SWAP2", "SWAP1"]: + changed = True + del assembly[i + 2] + assembly[i] = "SWAP1" + assembly[i + 1] = "DUP2" + continue # usually generated by with statements that return their input like # (with x (...x)) if assembly[i : i + 3] == ["DUP1", "SWAP1", "POP"]: @@ -975,16 +1007,22 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue + if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + changed = True + del assembly[i : i + 2] + if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: + changed = True + del assembly[i] i += 1 return changed # optimize assembly, in place -def _optimize_assembly(assembly): +def optimize_assembly(assembly): for x in assembly: - if isinstance(x, list) and isinstance(x[0], _RuntimeHeader): - _optimize_assembly(x) + if isinstance(x, list) and isinstance(x[0], RuntimeHeader): + optimize_assembly(x) for _ in range(1024): changed = False @@ -993,6 +1031,7 @@ def _optimize_assembly(assembly): changed |= _merge_iszero(assembly) changed |= _merge_jumpdests(assembly) changed |= _prune_inefficient_jumps(assembly) + changed |= _optimize_inefficient_jumps(assembly) changed |= _prune_unused_jumpdests(assembly) changed |= _stack_peephole_opts(assembly) @@ -1021,7 +1060,7 @@ def adjust_pc_maps(pc_maps, ofst): def _data_to_evm(assembly, symbol_map): ret = bytearray() - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big") @@ -1039,7 +1078,7 @@ def _data_to_evm(assembly, symbol_map): # predict what length of an assembly [data] node will be in bytecode def _length_of_data(assembly): ret = 0 - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): ret += SYMBOL_SIZE @@ -1055,7 +1094,7 @@ def _length_of_data(assembly): @dataclass -class _RuntimeHeader: +class RuntimeHeader: label: str ctor_mem_size: int immutables_len: int @@ -1065,7 +1104,7 @@ def __repr__(self): @dataclass -class _DataHeader: +class DataHeader: label: str def __repr__(self): @@ -1081,11 +1120,11 @@ def _relocate_segments(assembly): code_segments = [] for t in assembly: if isinstance(t, list): - if isinstance(t[0], _DataHeader): + if isinstance(t[0], DataHeader): data_segments.append(t) else: _relocate_segments(t) # recurse - assert isinstance(t[0], _RuntimeHeader) + assert isinstance(t[0], RuntimeHeader) code_segments.append(t) else: non_data_segments.append(t) @@ -1134,7 +1173,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat mem_ofst_size, ctor_mem_size = None, None max_mem_ofst = 0 for i, item in enumerate(assembly): - if isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + if isinstance(item, list) and isinstance(item[0], RuntimeHeader): assert runtime_code is None, "Multiple subcodes" assert ctor_mem_size is None @@ -1184,6 +1223,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat if is_symbol_map_indicator(assembly[i + 1]): # Don't increment pc as the symbol itself doesn't go into code if item in symbol_map: + print(assembly) raise CompilerPanic(f"duplicate jumpdest {item}") symbol_map[item] = pc @@ -1198,7 +1238,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat # [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar) # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) pc -= 1 - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): # we are in initcode symbol_map[item[0].label] = pc # add source map for all items in the runtime map @@ -1209,10 +1249,10 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat pc += len(runtime_code) # grab lengths of data sections from the runtime for t in item: - if isinstance(t, list) and isinstance(t[0], _DataHeader): + if isinstance(t, list) and isinstance(t[0], DataHeader): data_section_lengths.append(_length_of_data(t)) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): symbol_map[item[0].label] = pc pc += _length_of_data(item) else: @@ -1285,9 +1325,9 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat ret.append(DUP_OFFSET + int(item[3:])) elif item[:4] == "SWAP": ret.append(SWAP_OFFSET + int(item[4:])) - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): ret.extend(runtime_code) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): ret.extend(_data_to_evm(item, symbol_map)) else: # pragma: no cover # unreachable diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 8df4bbac2d..79e02f041d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -440,6 +440,8 @@ def _optimize(node: IRnode, parent: Optional[IRnode]) -> Tuple[bool, IRnode]: error_msg = node.error_msg annotation = node.annotation add_gas_estimate = node.add_gas_estimate + is_self_call = node.is_self_call + passthrough_metadata = node.passthrough_metadata changed = False @@ -462,6 +464,8 @@ def finalize(val, args): error_msg=error_msg, annotation=annotation, add_gas_estimate=add_gas_estimate, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) if should_check_symbols: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 77b9efb13d..140f73f095 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -93,7 +93,7 @@ def __init__( self.nonreentrant = nonreentrant # a list of internal functions this function calls - self.called_functions = OrderedSet() + self.called_functions = OrderedSet[ContractFunctionT]() # to be populated during codegen self._ir_info: Any = None diff --git a/vyper/utils.py b/vyper/utils.py index 3d9d9cb416..0a2e1f831f 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -6,12 +6,14 @@ import time import traceback import warnings -from typing import List, Union +from typing import Generic, List, TypeVar, Union from vyper.exceptions import DecimalOverrideException, InvalidLiteral +_T = TypeVar("_T") -class OrderedSet(dict): + +class OrderedSet(Generic[_T], dict[_T, None]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -20,9 +22,41 @@ class OrderedSet(dict): functionality as needed. """ - def add(self, item): + def __init__(self, iterable=None): + super().__init__() + if iterable is not None: + for item in iterable: + self.add(item) + + def __repr__(self): + keys = ", ".join(repr(k) for k in self.keys()) + return f"{{{keys}}}" + + def get(self, *args, **kwargs): + raise RuntimeError("can't call get() on OrderedSet!") + + def add(self, item: _T) -> None: self[item] = None + def remove(self, item: _T) -> None: + del self[item] + + def difference(self, other): + ret = self.copy() + for k in other.keys(): + if k in ret: + ret.remove(k) + return ret + + def union(self, other): + return self | other + + def __or__(self, other): + return self.__class__(super().__or__(other)) + + def copy(self): + return self.__class__(super().copy()) + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): @@ -436,3 +470,25 @@ def annotate_source_code( cleanup_lines += [""] * (num_lines - len(cleanup_lines)) return "\n".join(cleanup_lines) + + +def ir_pass(func): + """ + Decorator for IR passes. This decorator will run the pass repeatedly until + no more changes are made. + """ + + def wrapper(*args, **kwargs): + count = 0 + + while True: + changes = func(*args, **kwargs) or 0 + if isinstance(changes, list) or isinstance(changes, set): + changes = len(changes) + count += changes + if changes == 0: + break + + return count + + return wrapper diff --git a/vyper/venom/README.md b/vyper/venom/README.md new file mode 100644 index 0000000000..a81f6c0582 --- /dev/null +++ b/vyper/venom/README.md @@ -0,0 +1,162 @@ +## Venom - An Intermediate representation language for Vyper + +### Introduction + +Venom serves as the next-gen intermediate representation language specifically tailored for use with the Vyper smart contract compiler. Drawing inspiration from LLVM IR, Venom has been adapted to be simpler, and to be architected towards emitting code for stack-based virtual machines. Designed with a Single Static Assignment (SSA) form, Venom allows for sophisticated analysis and optimizations, while accommodating the idiosyncrasies of the EVM architecture. + +### Venom Form + +In Venom, values are denoted as strings commencing with the `'%'` character, referred to as variables. Variables can only be assigned to at declaration (they remain immutable post-assignment). Constants are represented as decimal numbers (hexadecimal may be added in the future). + +Reserved words include all the instruction opcodes and `'IRFunction'`, `'param'`, `'dbname'` and `'db'`. + +Any content following the `';'` character until the line end is treated as a comment. + +For instance, an example of incrementing a variable by one is as follows: + +```llvm +%sum = add %x, 1 ; Add one to x +``` + +Each instruction is identified by its opcode and a list of input operands. In cases where an instruction produces a result, it is stored in a new variable, as indicated on the left side of the assignment character. + +Code is organized into non-branching instruction blocks, known as _"Basic Blocks"_. Each basic block is defined by a label and contains its set of instructions. The final instruction of a basic block should either be a terminating instruction or a jump (conditional or unconditional) to other block(s). + +Basic blocks are grouped into _functions_ that are named and dictate the first block to execute. + +Venom employs two scopes: global and function level. + +### Example code + +```llvm +IRFunction: global + +global: + %1 = calldataload 0 + %2 = shr 224, %1 + jmp label %selector_bucket_0 + +selector_bucket_0: + %3 = xor %2, 1579456981 + %4 = iszero %3 + jnz label %1, label %2, %4 + +1: IN=[selector_bucket_0] OUT=[9] + jmp label %fallback + +2: + %5 = callvalue + %6 = calldatasize + %7 = lt %6, 164 + %8 = or %5, %7 + %9 = iszero %8 + assert %9 + stop + +fallback: + revert 0, 0 +``` + +### Grammar + +Below is a (not-so-complete) grammar to describe the text format of Venom IR: + +```llvm +program ::= function_declaration* + +function_declaration ::= "IRFunction:" identifier input_list? output_list? "=>" block + +input_list ::= "IN=" "[" (identifier ("," identifier)*)? "]" +output_list ::= "OUT=" "[" (identifier ("," identifier)*)? "]" + +block ::= label ":" input_list? output_list? "=>{" operation* "}" + +operation ::= "%" identifier "=" opcode operand ("," operand)* + | opcode operand ("," operand)* + +opcode ::= "calldataload" | "shr" | "shl" | "and" | "add" | "codecopy" | "mload" | "jmp" | "xor" | "iszero" | "jnz" | "label" | "lt" | "or" | "assert" | "callvalue" | "calldatasize" | "alloca" | "calldatacopy" | "invoke" | "gt" | ... + +operand ::= "%" identifier | label | integer | "label" "%" identifier +label ::= "%" identifier + +identifier ::= [a-zA-Z_][a-zA-Z0-9_]* +integer ::= [0-9]+ +``` + +## Implementation + +In the current implementation the compiler was extended to incorporate a new pass responsible for translating the original s-expr based IR into Venom. Subsequently, the generated Venom code undergoes processing by the actual Venom compiler, ultimately converting it to assembly code. That final assembly code is then passed to the original assembler of Vyper to produce the executable bytecode. + +Currently there is no implementation of the text format (that is, there is no front-end), although this is planned. At this time, Venom IR can only be constructed programmatically. + +## Architecture + +The Venom implementation is composed of several distinct passes that iteratively transform and optimize the Venom IR code until it reaches the assembly emitter, which produces the stack-based EVM assembly. The compiler is designed to be more-or-less pluggable, so passes can be written without too much knowledge of or dependency on other passes. + +These passes encompass generic transformations that streamline the code (such as dead code elimination and normalization), as well as those generating supplementary information about the code, like liveness analysis and control-flow graph (CFG) construction. Some passes may rely on the output of others, requiring a specific execution order. For instance, the code emitter expects the execution of a normalization pass preceding it, and this normalization pass, in turn, requires the augmentation of the Venom IR with code flow information. + +The primary categorization of pass types are: + +- Transformation passes +- Analysis/augmentation passes +- Optimization passes + +## Currently implemented passes + +The Venom compiler currently implements the following passes. + +### Control Flow Graph calculation + +The compiler generates a fundamental data structure known as the Control Flow Graph (CFG). This graph illustrates the interconnections between basic blocks, serving as a foundational data structure upon which many subsequent passes depend. + +### Data Flow Graph calculation + +To enable the compiler to analyze the movement of data through the code during execution, a specialized graph, the Dataflow Graph (DFG), is generated. The compiler inspects the code, determining where each variable is defined (in one location) and all the places where it is utilized. + +### Dataflow Transformation + +This pass depends on the DFG construction, and reorders variable declarations to try to reduce stack traffic during instruction selection. + +### Liveness analysis + +This pass conducts a dataflow analysis, utilizing information from previous passes to identify variables that are live at each instruction in the Venom IR code. A variable is deemed live at a particular instruction if it holds a value necessary for future operations. Variables only alive for their assignment instructions are identified here and then eliminated by the dead code elimination pass. + +### Dead code elimination + +This pass eliminates all basic blocks that are not reachable from any other basic block, leveraging the CFG. + +### Normalization + +A Venom program may feature basic blocks with multiple CFG inputs and outputs. This currently can occur when multiple blocks conditionally direct control to the same target basic block. We define a Venom IR as "normalized" when it contains no basic blocks that have multiple inputs and outputs. The normalization pass is responsible for converting any Venom IR program to its normalized form. EVM assembly emission operates solely on normalized Venom programs, because the stack layout is not well defined for non-normalized basic blocks. + +### Code emission + +This final pass of the compiler aims to emit EVM assembly recognized by Vyper's assembler. It calcluates the desired stack layout for every basic block, schedules items on the stack and selects instructions. It ensures that deploy code, runtime code, and data segments are arranged according to the assembler's expectations. + +## Future planned passes + +A number of passes that are planned to be implemented, or are implemented for immediately after the initial PR merge are below. + +### Constant folding + +### Instruction combination + +### Dead store elimination + +### Scalar evolution + +### Loop invariant code motion + +### Loop unrolling + +### Code sinking + +### Expression reassociation + +### Stack to mem + +### Mem to stack + +### Function inlining + +### Load-store elimination diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py new file mode 100644 index 0000000000..5a09f8378e --- /dev/null +++ b/vyper/venom/__init__.py @@ -0,0 +1,56 @@ +# maybe rename this `main.py` or `venom.py` +# (can have an `__init__.py` which exposes the API). + +from typing import Optional + +from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel +from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness +from vyper.venom.bb_optimizer import ( + ir_pass_optimize_empty_blocks, + ir_pass_optimize_unused_variables, + ir_pass_remove_unreachable_blocks, +) +from vyper.venom.function import IRFunction +from vyper.venom.ir_node_to_venom import convert_ir_basicblock +from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation +from vyper.venom.passes.dft import DFTPass +from vyper.venom.venom_to_assembly import VenomCompiler + + +def generate_assembly_experimental( + ctx: IRFunction, optimize: Optional[OptimizationLevel] = None +) -> list[str]: + compiler = VenomCompiler(ctx) + return compiler.generate_evm(optimize is OptimizationLevel.NONE) + + +def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRFunction: + # Convert "old" IR to "new" IR + ctx = convert_ir_basicblock(ir) + + # Run passes on "new" IR + # TODO: Add support for optimization levels + while True: + changes = 0 + + changes += ir_pass_optimize_empty_blocks(ctx) + changes += ir_pass_remove_unreachable_blocks(ctx) + + calculate_liveness(ctx) + + changes += ir_pass_optimize_unused_variables(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + changes += ir_pass_constant_propagation(ctx) + changes += DFTPass.run_pass(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + if changes == 0: + break + + return ctx diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py new file mode 100644 index 0000000000..5980e21028 --- /dev/null +++ b/vyper/venom/analysis.py @@ -0,0 +1,191 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import ( + BB_TERMINATORS, + CFG_ALTERING_OPS, + IRBasicBlock, + IRInstruction, + IRVariable, +) +from vyper.venom.function import IRFunction + + +def calculate_cfg(ctx: IRFunction) -> None: + """ + Calculate (cfg) inputs for each basic block. + """ + for bb in ctx.basic_blocks: + bb.cfg_in = OrderedSet() + bb.cfg_out = OrderedSet() + bb.out_vars = OrderedSet() + + # TODO: This is a hack to support the old IR format where `deploy` is + # an instruction. in the future we should have two entry points, one + # for the initcode and one for the runtime code. + deploy_bb = None + after_deploy_bb = None + for i, bb in enumerate(ctx.basic_blocks): + if bb.instructions[0].opcode == "deploy": + deploy_bb = bb + after_deploy_bb = ctx.basic_blocks[i + 1] + break + + if deploy_bb is not None: + assert after_deploy_bb is not None, "No block after deploy block" + entry_block = after_deploy_bb + has_constructor = ctx.basic_blocks[0].instructions[0].opcode != "deploy" + if has_constructor: + deploy_bb.add_cfg_in(ctx.basic_blocks[0]) + entry_block.add_cfg_in(deploy_bb) + else: + entry_block = ctx.basic_blocks[0] + + # TODO: Special case for the jump table of selector buckets and fallback. + # this will be cleaner when we introduce an "indirect jump" instruction + # for the selector table (which includes all possible targets). it will + # also clean up the code for normalization because it will not have to + # handle this case specially. + for bb in ctx.basic_blocks: + if "selector_bucket_" in bb.label.value or bb.label.value == "fallback": + bb.add_cfg_in(entry_block) + + for bb in ctx.basic_blocks: + assert len(bb.instructions) > 0, "Basic block should not be empty" + last_inst = bb.instructions[-1] + assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" + + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_OPS: + ops = inst.get_label_operands() + for op in ops: + ctx.get_basic_block(op.value).add_cfg_in(bb) + + # Fill in the "out" set for each basic block + for bb in ctx.basic_blocks: + for in_bb in bb.cfg_in: + in_bb.add_cfg_out(bb) + + +def _reset_liveness(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + for inst in bb.instructions: + inst.liveness = OrderedSet() + + +def _calculate_liveness_bb(bb: IRBasicBlock) -> None: + """ + Compute liveness of each instruction in the basic block. + """ + liveness = bb.out_vars.copy() + for instruction in reversed(bb.instructions): + ops = instruction.get_inputs() + + for op in ops: + if op in liveness: + instruction.dup_requirements.add(op) + + liveness = liveness.union(OrderedSet.fromkeys(ops)) + out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None + if out in liveness: + liveness.remove(out) + instruction.liveness = liveness + + +def _calculate_liveness_r(bb: IRBasicBlock, visited: dict) -> None: + assert isinstance(visited, dict) + for out_bb in bb.cfg_out: + if visited.get(bb) == out_bb: + continue + visited[bb] = out_bb + + # recurse + _calculate_liveness_r(out_bb, visited) + + target_vars = input_vars_from(bb, out_bb) + + # the output stack layout for bb. it produces a stack layout + # which works for all possible cfg_outs from the bb. + bb.out_vars = bb.out_vars.union(target_vars) + + _calculate_liveness_bb(bb) + + +def calculate_liveness(ctx: IRFunction) -> None: + _reset_liveness(ctx) + _calculate_liveness_r(ctx.basic_blocks[0], dict()) + + +# calculate the input variables into self from source +def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: + liveness = target.instructions[0].liveness.copy() + assert isinstance(liveness, OrderedSet) + + for inst in target.instructions: + if inst.opcode == "phi": + # we arbitrarily choose one of the arguments to be in the + # live variables set (dependent on how we traversed into this + # basic block). the argument will be replaced by the destination + # operand during instruction selection. + # for instance, `%56 = phi %label1 %12 %label2 %14` + # will arbitrarily choose either %12 or %14 to be in the liveness + # set, and then during instruction selection, after this instruction, + # %12 will be replaced by %56 in the liveness set + source1, source2 = inst.operands[0], inst.operands[2] + phi1, phi2 = inst.operands[1], inst.operands[3] + if source.label == source1: + liveness.add(phi1) + if phi2 in liveness: + liveness.remove(phi2) + elif source.label == source2: + liveness.add(phi2) + if phi1 in liveness: + liveness.remove(phi1) + else: + # bad path into this phi node + raise CompilerPanic(f"unreachable: {inst}") + + return liveness + + +# DataFlow Graph +# this could be refactored into its own file, but it's only used here +# for now +class DFG: + _dfg_inputs: dict[IRVariable, list[IRInstruction]] + _dfg_outputs: dict[IRVariable, IRInstruction] + + def __init__(self): + self._dfg_inputs = dict() + self._dfg_outputs = dict() + + # return uses of a given variable + def get_uses(self, op: IRVariable) -> list[IRInstruction]: + return self._dfg_inputs.get(op, []) + + # the instruction which produces this variable. + def get_producing_instruction(self, op: IRVariable) -> IRInstruction: + return self._dfg_outputs[op] + + @classmethod + def build_dfg(cls, ctx: IRFunction) -> "DFG": + dfg = cls() + + # Build DFG + + # %15 = add %13 %14 + # %16 = iszero %15 + # dfg_outputs of %15 is (%15 = add %13 %14) + # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] + for bb in ctx.basic_blocks: + for inst in bb.instructions: + operands = inst.get_inputs() + res = inst.get_outputs() + + for op in operands: + inputs = dfg._dfg_inputs.setdefault(op, []) + inputs.append(inst) + + for op in res: # type: ignore + dfg._dfg_outputs[op] = inst + + return dfg diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py new file mode 100644 index 0000000000..b95d7416ca --- /dev/null +++ b/vyper/venom/basicblock.py @@ -0,0 +1,345 @@ +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Iterator, Optional + +from vyper.utils import OrderedSet + +# instructions which can terminate a basic block +BB_TERMINATORS = frozenset(["jmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) + +VOLATILE_INSTRUCTIONS = frozenset( + [ + "param", + "alloca", + "call", + "staticcall", + "invoke", + "sload", + "sstore", + "iload", + "istore", + "assert", + "mstore", + "mload", + "calldatacopy", + "codecopy", + "dloadbytes", + "dload", + "return", + "ret", + "jmp", + "jnz", + ] +) + +CFG_ALTERING_OPS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) + + +if TYPE_CHECKING: + from vyper.venom.function import IRFunction + + +class IRDebugInfo: + """ + IRDebugInfo represents debug information in IR, used to annotate IR instructions + with source code information when printing IR. + """ + + line_no: int + src: str + + def __init__(self, line_no: int, src: str) -> None: + self.line_no = line_no + self.src = src + + def __repr__(self) -> str: + src = self.src if self.src else "" + return f"\t# line {self.line_no}: {src}".expandtabs(20) + + +class IROperand: + """ + IROperand represents an operand in IR. An operand is anything that can + be an argument to an IRInstruction + """ + + value: Any + + +class IRValue(IROperand): + """ + IRValue represents a value in IR. A value is anything that can be + operated by non-control flow instructions. That is, IRValues can be + IRVariables or IRLiterals. + """ + + pass + + +class IRLiteral(IRValue): + """ + IRLiteral represents a literal in IR + """ + + value: int + + def __init__(self, value: int) -> None: + assert isinstance(value, str) or isinstance(value, int), "value must be an int" + self.value = value + + def __repr__(self) -> str: + return str(self.value) + + +class MemType(Enum): + OPERAND_STACK = auto() + MEMORY = auto() + + +class IRVariable(IRValue): + """ + IRVariable represents a variable in IR. A variable is a string that starts with a %. + """ + + value: str + offset: int = 0 + + # some variables can be in memory for conversion from legacy IR to venom + mem_type: MemType = MemType.OPERAND_STACK + mem_addr: Optional[int] = None + + def __init__( + self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None + ) -> None: + assert isinstance(value, str) + self.value = value + self.offset = 0 + self.mem_type = mem_type + self.mem_addr = mem_addr + + def __repr__(self) -> str: + return self.value + + +class IRLabel(IROperand): + """ + IRLabel represents a label in IR. A label is a string that starts with a %. + """ + + # is_symbol is used to indicate if the label came from upstream + # (like a function name, try to preserve it in optimization passes) + is_symbol: bool = False + value: str + + def __init__(self, value: str, is_symbol: bool = False) -> None: + assert isinstance(value, str), "value must be an str" + self.value = value + self.is_symbol = is_symbol + + def __repr__(self) -> str: + return self.value + + +class IRInstruction: + """ + IRInstruction represents an instruction in IR. Each instruction has an opcode, + operands, and return value. For example, the following IR instruction: + %1 = add %0, 1 + has opcode "add", operands ["%0", "1"], and return value "%1". + + Convention: the rightmost value is the top of the stack. + """ + + opcode: str + volatile: bool + operands: list[IROperand] + output: Optional[IROperand] + # set of live variables at this instruction + liveness: OrderedSet[IRVariable] + dup_requirements: OrderedSet[IRVariable] + parent: Optional["IRBasicBlock"] + fence_id: int + annotation: Optional[str] + + def __init__( + self, + opcode: str, + operands: list[IROperand] | Iterator[IROperand], + output: Optional[IROperand] = None, + ): + assert isinstance(opcode, str), "opcode must be an str" + assert isinstance(operands, list | Iterator), "operands must be a list" + self.opcode = opcode + self.volatile = opcode in VOLATILE_INSTRUCTIONS + self.operands = [op for op in operands] # in case we get an iterator + self.output = output + self.liveness = OrderedSet() + self.dup_requirements = OrderedSet() + self.parent = None + self.fence_id = -1 + self.annotation = None + + def get_label_operands(self) -> list[IRLabel]: + """ + Get all labels in instruction. + """ + return [op for op in self.operands if isinstance(op, IRLabel)] + + def get_non_label_operands(self) -> list[IROperand]: + """ + Get input operands for instruction which are not labels + """ + return [op for op in self.operands if not isinstance(op, IRLabel)] + + def get_inputs(self) -> list[IRVariable]: + """ + Get all input operands for instruction. + """ + return [op for op in self.operands if isinstance(op, IRVariable)] + + def get_outputs(self) -> list[IROperand]: + """ + Get the output item for an instruction. + (Currently all instructions output at most one item, but write + it as a list to be generic for the future) + """ + return [self.output] if self.output else [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + replacements are represented using a dict: "key" is replaced by "value". + """ + for i, operand in enumerate(self.operands): + if operand in replacements: + self.operands[i] = replacements[operand] + + def __repr__(self) -> str: + s = "" + if self.output: + s += f"{self.output} = " + opcode = f"{self.opcode} " if self.opcode != "store" else "" + s += opcode + operands = ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + ) + s += operands + + if self.annotation: + s += f" <{self.annotation}>" + + # if self.liveness: + # return f"{s: <30} # {self.liveness}" + + return s + + +class IRBasicBlock: + """ + IRBasicBlock represents a basic block in IR. Each basic block has a label and + a list of instructions, while belonging to a function. + + The following IR code: + %1 = add %0, 1 + %2 = mul %1, 2 + is represented as: + bb = IRBasicBlock("bb", function) + bb.append_instruction(IRInstruction("add", ["%0", "1"], "%1")) + bb.append_instruction(IRInstruction("mul", ["%1", "2"], "%2")) + + The label of a basic block is used to refer to it from other basic blocks + in order to branch to it. + + The parent of a basic block is the function it belongs to. + + The instructions of a basic block are executed sequentially, and the last + instruction of a basic block is always a terminator instruction, which is + used to branch to other basic blocks. + """ + + label: IRLabel + parent: "IRFunction" + instructions: list[IRInstruction] + # basic blocks which can jump to this basic block + cfg_in: OrderedSet["IRBasicBlock"] + # basic blocks which this basic block can jump to + cfg_out: OrderedSet["IRBasicBlock"] + # stack items which this basic block produces + out_vars: OrderedSet[IRVariable] + + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: + assert isinstance(label, IRLabel), "label must be an IRLabel" + self.label = label + self.parent = parent + self.instructions = [] + self.cfg_in = OrderedSet() + self.cfg_out = OrderedSet() + self.out_vars = OrderedSet() + + def add_cfg_in(self, bb: "IRBasicBlock") -> None: + self.cfg_in.add(bb) + + def remove_cfg_in(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_in + self.cfg_in.remove(bb) + + def add_cfg_out(self, bb: "IRBasicBlock") -> None: + # malformed: jnz condition label1 label1 + # (we could handle but it makes a lot of code easier + # if we have this assumption) + self.cfg_out.add(bb) + + def remove_cfg_out(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_out + self.cfg_out.remove(bb) + + @property + def is_reachable(self) -> bool: + return len(self.cfg_in) > 0 + + def append_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.append(instruction) + + def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.insert(index, instruction) + + def clear_instructions(self) -> None: + self.instructions = [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + """ + for instruction in self.instructions: + instruction.replace_operands(replacements) + + @property + def is_terminated(self) -> bool: + """ + Check if the basic block is terminal, i.e. the last instruction is a terminator. + """ + # it's ok to return False here, since we use this to check + # if we can/need to append instructions to the basic block. + if len(self.instructions) == 0: + return False + return self.instructions[-1].opcode in BB_TERMINATORS + + def copy(self): + bb = IRBasicBlock(self.label, self.parent) + bb.instructions = self.instructions.copy() + bb.cfg_in = self.cfg_in.copy() + bb.cfg_out = self.cfg_out.copy() + bb.out_vars = self.out_vars.copy() + return bb + + def __repr__(self) -> str: + s = ( + f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" + f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars} \n" + ) + for instruction in self.instructions: + s += f" {instruction}\n" + return s diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py new file mode 100644 index 0000000000..620ee66d15 --- /dev/null +++ b/vyper/venom/bb_optimizer.py @@ -0,0 +1,73 @@ +from vyper.utils import ir_pass +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRInstruction, IRLabel +from vyper.venom.function import IRFunction + + +def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]: + """ + Remove unused variables. + """ + removeList = set() + for bb in ctx.basic_blocks: + for i, inst in enumerate(bb.instructions[:-1]): + if inst.volatile: + continue + if inst.output and inst.output not in bb.instructions[i + 1].liveness: + removeList.add(inst) + + bb.instructions = [inst for inst in bb.instructions if inst not in removeList] + + return removeList + + +def _optimize_empty_basicblocks(ctx: IRFunction) -> int: + """ + Remove empty basic blocks. + """ + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if len(bb.instructions) > 0: + continue + + replaced_label = bb.label + replacement_label = ctx.basic_blocks[i].label if i < len(ctx.basic_blocks) else None + if replacement_label is None: + continue + + # Try to preserve symbol labels + if replaced_label.is_symbol: + replaced_label, replacement_label = replacement_label, replaced_label + ctx.basic_blocks[i].label = replacement_label + + for bb2 in ctx.basic_blocks: + for inst in bb2.instructions: + for op in inst.operands: + if isinstance(op, IRLabel) and op.value == replaced_label.value: + op.value = replacement_label.value + + ctx.basic_blocks.remove(bb) + i -= 1 + count += 1 + + return count + + +@ir_pass +def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: + changes = _optimize_empty_basicblocks(ctx) + calculate_cfg(ctx) + return changes + + +@ir_pass +def ir_pass_remove_unreachable_blocks(ctx: IRFunction) -> int: + return ctx.remove_unreachable_blocks() + + +@ir_pass +def ir_pass_optimize_unused_variables(ctx: IRFunction) -> int: + return len(_optimize_unused_variables(ctx)) diff --git a/vyper/venom/function.py b/vyper/venom/function.py new file mode 100644 index 0000000000..c14ad77345 --- /dev/null +++ b/vyper/venom/function.py @@ -0,0 +1,170 @@ +from typing import Optional + +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IROperand, + IRVariable, + MemType, +) + +GLOBAL_LABEL = IRLabel("global") + + +class IRFunction: + """ + Function that contains basic blocks. + """ + + name: IRLabel # symbol name + args: list + basic_blocks: list[IRBasicBlock] + data_segment: list[IRInstruction] + last_label: int + last_variable: int + + def __init__(self, name: IRLabel = None) -> None: + if name is None: + name = GLOBAL_LABEL + self.name = name + self.args = [] + self.basic_blocks = [] + self.data_segment = [] + self.last_label = 0 + self.last_variable = 0 + + self.append_basic_block(IRBasicBlock(name, self)) + + def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + """ + Append basic block to function. + """ + assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" + self.basic_blocks.append(bb) + + # TODO add sanity check somewhere that basic blocks have unique labels + + return self.basic_blocks[-1] + + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: + """ + Get basic block by label. + If label is None, return the last basic block. + """ + if label is None: + return self.basic_blocks[-1] + for bb in self.basic_blocks: + if bb.label.value == label: + return bb + raise AssertionError(f"Basic block '{label}' not found") + + def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + """ + Get basic block after label. + """ + for i, bb in enumerate(self.basic_blocks[:-1]): + if bb.label.value == label.value: + return self.basic_blocks[i + 1] + raise AssertionError(f"Basic block after '{label}' not found") + + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: + """ + Get basic blocks that contain label. + """ + return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] + + def get_next_label(self) -> IRLabel: + self.last_label += 1 + return IRLabel(f"{self.last_label}") + + def get_next_variable( + self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None + ) -> IRVariable: + self.last_variable += 1 + return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + + def get_last_variable(self) -> str: + return f"%{self.last_variable}" + + def remove_unreachable_blocks(self) -> int: + removed = 0 + new_basic_blocks = [] + for bb in self.basic_blocks: + if not bb.is_reachable and bb.label.value != "global": + removed += 1 + else: + new_basic_blocks.append(bb) + self.basic_blocks = new_basic_blocks + return removed + + def append_instruction( + self, opcode: str, args: list[IROperand], do_ret: bool = True + ) -> Optional[IRVariable]: + """ + Append instruction to last basic block. + """ + ret = self.get_next_variable() if do_ret else None + inst = IRInstruction(opcode, args, ret) # type: ignore + self.get_basic_block().append_instruction(inst) + return ret + + def append_data(self, opcode: str, args: list[IROperand]) -> None: + """ + Append data + """ + self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + + @property + def normalized(self) -> bool: + """ + Check if function is normalized. A function is normalized if in the + CFG, no basic block simultaneously has multiple inputs and outputs. + That is, a basic block can be jumped to *from* multiple blocks, or it + can jump *to* multiple blocks, but it cannot simultaneously do both. + Having a normalized CFG makes calculation of stack layout easier when + emitting assembly. + """ + for bb in self.basic_blocks: + # Ignore if there are no multiple predecessors + if len(bb.cfg_in) <= 1: + continue + + # Check if there is a conditional jump at the end + # of one of the predecessors + # + # TODO: this check could be: + # `if len(in_bb.cfg_out) > 1: return False` + # but the cfg is currently not calculated "correctly" for + # certain special instructions (deploy instruction and + # selector table indirect jumps). + for in_bb in bb.cfg_in: + jump_inst = in_bb.instructions[-1] + if jump_inst.opcode != "jnz": + continue + if jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRLabel): + continue + + # The function is not normalized + return False + + # The function is normalized + return True + + def copy(self): + new = IRFunction(self.name) + new.basic_blocks = self.basic_blocks.copy() + new.data_segment = self.data_segment.copy() + new.last_label = self.last_label + new.last_variable = self.last_variable + return new + + def __repr__(self) -> str: + str = f"IRFunction: {self.name}\n" + for bb in self.basic_blocks: + str += f"{bb}\n" + if len(self.data_segment) > 0: + str += "Data segment:\n" + for inst in self.data_segment: + str += f"{inst}\n" + return str diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py new file mode 100644 index 0000000000..19bd5c8b73 --- /dev/null +++ b/vyper/venom/ir_node_to_venom.py @@ -0,0 +1,943 @@ +from typing import Optional + +from vyper.codegen.context import VariableRecord +from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import get_opcodes +from vyper.exceptions import CompilerPanic +from vyper.ir.compile_ir import is_mem_sym, is_symbol +from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction + +_BINARY_IR_INSTRUCTIONS = frozenset( + [ + "eq", + "gt", + "lt", + "slt", + "sgt", + "shr", + "shl", + "or", + "xor", + "and", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "sha3", + "sha3_64", + "signextend", + ] +) + +# Instuctions that are mapped to their inverse +INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} + +# Instructions that have a direct EVM opcode equivalent and can +# be passed through to the EVM assembly without special handling +PASS_THROUGH_INSTRUCTIONS = [ + "chainid", + "basefee", + "timestamp", + "caller", + "selfbalance", + "calldatasize", + "callvalue", + "address", + "origin", + "codesize", + "gas", + "gasprice", + "gaslimit", + "returndatasize", + "coinbase", + "number", + "iszero", + "ceil32", + "calldataload", + "extcodesize", + "extcodehash", + "balance", +] + +SymbolTable = dict[str, IROperand] + + +def _get_symbols_common(a: dict, b: dict) -> dict: + ret = {} + # preserves the ordering in `a` + for k in a.keys(): + if k not in b: + continue + if a[k] == b[k]: + continue + ret[k] = a[k], b[k] + return ret + + +def convert_ir_basicblock(ir: IRnode) -> IRFunction: + global_function = IRFunction() + _convert_ir_basicblock(global_function, ir, {}, OrderedSet(), {}) + + for i, bb in enumerate(global_function.basic_blocks): + if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: + bb.append_instruction(IRInstruction("jmp", [global_function.basic_blocks[i + 1].label])) + + revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) + revert_bb = global_function.append_basic_block(revert_bb) + revert_bb.append_instruction(IRInstruction("revert", [IRLiteral(0), IRLiteral(0)])) + + return global_function + + +def _convert_binary_op( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], + swap: bool = False, +) -> IRVariable: + ir_args = ir.args[::-1] if swap else ir.args + arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) + args = [arg_1, arg_0] + + ret = ctx.get_next_variable() + + inst = IRInstruction(ir.value, args, ret) # type: ignore + ctx.get_basic_block().append_instruction(inst) + return ret + + +def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + + label = ctx.get_next_label() + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + + +def _new_block(ctx: IRFunction) -> IRBasicBlock: + bb = IRBasicBlock(ctx.get_next_label(), ctx) + bb = ctx.append_basic_block(bb) + return bb + + +def _handle_self_call( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + func_t = ir.passthrough_metadata.get("func_t", None) + args_ir = ir.passthrough_metadata["args_ir"] + goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] + target_label = goto_ir.args[0].value # goto + return_buf = goto_ir.args[1] # return buffer + ret_args = [IRLabel(target_label)] # type: ignore + + for arg in args_ir: + if arg.is_literal: + sym = symbols.get(f"&{arg.value}", None) + if sym is None: + ret = _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ret_args.append(ret) + else: + ret_args.append(sym) # type: ignore + else: + ret = _convert_ir_basicblock( + ctx, arg._optimized, symbols, variables, allocated_variables + ) + if arg.location and arg.location.load_op == "calldataload": + ret = ctx.append_instruction(arg.location.load_op, [ret]) + ret_args.append(ret) + + if return_buf.is_literal: + ret_args.append(IRLiteral(return_buf.value)) # type: ignore + + do_ret = func_t.return_type is not None + invoke_ret = ctx.append_instruction("invoke", ret_args, do_ret) # type: ignore + allocated_variables["return_buffer"] = invoke_ret # type: ignore + return invoke_ret + + +def _handle_internal_func( + ctx: IRFunction, ir: IRnode, func_t: ContractFunctionT, symbols: SymbolTable +) -> IRnode: + bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore + bb = ctx.append_basic_block(bb) + + old_ir_mempos = 0 + old_ir_mempos += 64 + + for arg in func_t.arguments: + new_var = ctx.get_next_variable() + + alloca_inst = IRInstruction("param", [], new_var) + alloca_inst.annotation = arg.name + bb.append_instruction(alloca_inst) + symbols[f"&{old_ir_mempos}"] = new_var + old_ir_mempos += 32 # arg.typ.memory_bytes_required + + # return buffer + if func_t.return_type is not None: + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_buffer" + symbols["return_buffer"] = new_var + + # return address + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_pc" + symbols["return_pc"] = new_var + + return ir.args[0].args[2] + + +def _convert_ir_simple_node( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args + ] + return ctx.append_instruction(ir.value, args) # type: ignore + + +_break_target: Optional[IRBasicBlock] = None +_continue_target: Optional[IRBasicBlock] = None + + +def _get_variable_from_address( + variables: OrderedSet[VariableRecord], addr: int +) -> Optional[VariableRecord]: + assert isinstance(addr, int), "non-int address" + for var in variables.keys(): + if var.location.name != "memory": + continue + if addr >= var.pos and addr < var.pos + var.size: # type: ignore + return var + return None + + +def _get_return_for_stack_operand( + ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable +) -> IRInstruction: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) # type: ignore + else: + sym = symbols.get(ret_ir.value, None) + if sym is None: + # FIXME: needs real allocations + new_var = ctx.append_instruction("alloca", [IRLiteral(32), IRLiteral(0)]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) # type: ignore + else: + new_var = ret_ir + return IRInstruction("return", [last_ir, new_var]) # type: ignore + + +def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): + assert isinstance(variables, OrderedSet) + global _break_target, _continue_target + + frame_info = ir.passthrough_metadata.get("frame_info", None) + if frame_info is not None: + local_vars = OrderedSet[VariableRecord](frame_info.frame_vars.values()) + variables |= local_vars + + assert isinstance(variables, OrderedSet) + + if ir.value in _BINARY_IR_INSTRUCTIONS: + return _convert_binary_op( + ctx, ir, symbols, variables, allocated_variables, ir.value in ["sha3_64"] + ) + + elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + org_value = ir.value + ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] + new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) + ir.value = org_value + return ctx.append_instruction("iszero", [new_var]) + + elif ir.value in PASS_THROUGH_INSTRUCTIONS: + return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) + + elif ir.value in ["pass", "stop", "return"]: + pass + elif ir.value == "deploy": + memsize = ir.args[0].value + ir_runtime = ir.args[1] + padding = ir.args[2].value + assert isinstance(memsize, int), "non-int memsize" + assert isinstance(padding, int), "non-int padding" + + runtimeLabel = ctx.get_next_label() + + inst = IRInstruction("deploy", [IRLiteral(memsize), runtimeLabel, IRLiteral(padding)]) + ctx.get_basic_block().append_instruction(inst) + + bb = IRBasicBlock(runtimeLabel, ctx) + ctx.append_basic_block(bb) + + _convert_ir_basicblock(ctx, ir_runtime, symbols, variables, allocated_variables) + elif ir.value == "seq": + func_t = ir.passthrough_metadata.get("func_t", None) + if ir.is_self_call: + return _handle_self_call(ctx, ir, symbols, variables, allocated_variables) + elif func_t is not None: + symbols = {} + allocated_variables = {} + variables = OrderedSet( + {v: True for v in ir.passthrough_metadata["frame_info"].frame_vars.values()} + ) + if func_t.is_internal: + ir = _handle_internal_func(ctx, ir, func_t, symbols) + # fallthrough + + ret = None + for ir_node in ir.args: # NOTE: skip the last one + ret = _convert_ir_basicblock(ctx, ir_node, symbols, variables, allocated_variables) + + return ret + elif ir.value in ["staticcall", "call"]: # external call + idx = 0 + gas = _convert_ir_basicblock(ctx, ir.args[idx], symbols, variables, allocated_variables) + address = _convert_ir_basicblock( + ctx, ir.args[idx + 1], symbols, variables, allocated_variables + ) + + value = None + if ir.value == "call": + value = _convert_ir_basicblock( + ctx, ir.args[idx + 2], symbols, variables, allocated_variables + ) + else: + idx -= 1 + + argsOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 3], symbols, variables, allocated_variables + ) + argsSize = _convert_ir_basicblock( + ctx, ir.args[idx + 4], symbols, variables, allocated_variables + ) + retOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 5], symbols, variables, allocated_variables + ) + retSize = _convert_ir_basicblock( + ctx, ir.args[idx + 6], symbols, variables, allocated_variables + ) + + if isinstance(argsOffset, IRLiteral): + offset = int(argsOffset.value) + addr = offset - 32 + 4 if offset > 0 else 0 + argsOffsetVar = symbols.get(f"&{addr}", None) + if argsOffsetVar is None: + argsOffsetVar = argsOffset + elif isinstance(argsOffsetVar, IRVariable): + argsOffsetVar.mem_type = MemType.MEMORY + argsOffsetVar.mem_addr = addr + argsOffsetVar.offset = 32 - 4 if offset > 0 else 0 + else: # pragma: nocover + raise CompilerPanic("unreachable") + else: + argsOffsetVar = argsOffset + + retOffsetValue = int(retOffset.value) if retOffset else 0 + retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) + symbols[f"&{retOffsetValue}"] = retVar + + if ir.value == "call": + args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] + return ctx.append_instruction(ir.value, args) + else: + args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] + return ctx.append_instruction(ir.value, args) + elif ir.value == "if": + cond = ir.args[0] + current_bb = ctx.get_basic_block() + + # convert the condition + cont_ret = _convert_ir_basicblock(ctx, cond, symbols, variables, allocated_variables) + + else_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(else_block) + + # convert "else" + else_ret_val = None + else_syms = symbols.copy() + if len(ir.args) == 3: + else_ret_val = _convert_ir_basicblock( + ctx, ir.args[2], else_syms, variables, allocated_variables.copy() + ) + if isinstance(else_ret_val, IRLiteral): + assert isinstance(else_ret_val.value, int) # help mypy + else_ret_val = ctx.append_instruction("store", [IRLiteral(else_ret_val.value)]) + after_else_syms = else_syms.copy() + + # convert "then" + then_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(then_block) + + then_ret_val = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + if isinstance(then_ret_val, IRLiteral): + then_ret_val = ctx.append_instruction("store", [IRLiteral(then_ret_val.value)]) + + inst = IRInstruction("jnz", [cont_ret, then_block.label, else_block.label]) + current_bb.append_instruction(inst) + + after_then_syms = symbols.copy() + + # exit bb + exit_label = ctx.get_next_label() + bb = IRBasicBlock(exit_label, ctx) + bb = ctx.append_basic_block(bb) + + if_ret = None + if then_ret_val is not None and else_ret_val is not None: + if_ret = ctx.get_next_variable() + bb.append_instruction( + IRInstruction( + "phi", [then_block.label, then_ret_val, else_block.label, else_ret_val], if_ret + ) + ) + + common_symbols = _get_symbols_common(after_then_syms, after_else_syms) + for sym, val in common_symbols.items(): + ret = ctx.get_next_variable() + old_var = symbols.get(sym, None) + symbols[sym] = ret + if old_var is not None: + for idx, var_rec in allocated_variables.items(): # type: ignore + if var_rec.value == old_var.value: + allocated_variables[idx] = ret # type: ignore + bb.append_instruction( + IRInstruction("phi", [then_block.label, val[0], else_block.label, val[1]], ret) + ) + + if not else_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + else_block.append_instruction(exit_inst) + + if not then_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + then_block.append_instruction(exit_inst) + + return if_ret + + elif ir.value == "with": + ret = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) # initialization + + # Handle with nesting with same symbol + with_symbols = symbols.copy() + + sym = ir.args[0] + if isinstance(ret, IRLiteral): + new_var = ctx.append_instruction("store", [ret]) # type: ignore + with_symbols[sym.value] = new_var + else: + with_symbols[sym.value] = ret # type: ignore + + return _convert_ir_basicblock( + ctx, ir.args[2], with_symbols, variables, allocated_variables + ) # body + elif ir.value == "goto": + _append_jmp(ctx, IRLabel(ir.args[0].value)) + elif ir.value == "jump": + arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + inst = IRInstruction("jmp", [arg_1]) + ctx.get_basic_block().append_instruction(inst) + _new_block(ctx) + elif ir.value == "set": + sym = ir.args[0] + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + new_var = ctx.append_instruction("store", [arg_1]) # type: ignore + symbols[sym.value] = new_var + + elif ir.value == "calldatacopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_v = arg_0 + var = ( + _get_variable_from_address(variables, int(arg_0.value)) + if isinstance(arg_0, IRLiteral) + else None + ) + if var is not None: + if allocated_variables.get(var.name, None) is None: + new_v = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] # type: ignore + ) + allocated_variables[var.name] = new_v # type: ignore + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + symbols[f"&{var.pos}"] = new_v # type: ignore + else: + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + + return new_v + elif ir.value == "codecopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + ctx.append_instruction("codecopy", [size, arg_1, arg_0], False) # type: ignore + elif ir.value == "symbol": + return IRLabel(ir.args[0].value, True) + elif ir.value == "data": + label = IRLabel(ir.args[0].value) + ctx.append_data("dbname", [label]) + for c in ir.args[1:]: + if isinstance(c, int): + assert 0 <= c <= 255, "data with invalid size" + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, bytes): + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, IRnode): + data = _convert_ir_basicblock(ctx, c, symbols, variables, allocated_variables) + ctx.append_data("db", [data]) # type: ignore + elif ir.value == "assert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + current_bb = ctx.get_basic_block() + inst = IRInstruction("assert", [arg_0]) # type: ignore + current_bb.append_instruction(inst) + elif ir.value == "label": + label = IRLabel(ir.args[0].value, True) + if not ctx.get_basic_block().is_terminated: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + elif ir.value == "exit_to": + func_t = ir.passthrough_metadata.get("func_t", None) + assert func_t is not None, "exit_to without func_t" + + if func_t.is_external: + # Hardcoded contructor special case + if func_t.name == "__init__": + label = IRLabel(ir.args[0].value, True) + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + return None + if func_t.return_type is None: + inst = IRInstruction("stop", []) + ctx.get_basic_block().append_instruction(inst) + return None + else: + last_ir = None + ret_var = ir.args[1] + deleted = None + if ret_var.is_literal and symbols.get(f"&{ret_var.value}", None) is not None: + deleted = symbols[f"&{ret_var.value}"] + del symbols[f"&{ret_var.value}"] + for arg in ir.args[2:]: + last_ir = _convert_ir_basicblock( + ctx, arg, symbols, variables, allocated_variables + ) + if deleted is not None: + symbols[f"&{ret_var.value}"] = deleted + + ret_ir = _convert_ir_basicblock( + ctx, ret_var, symbols, variables, allocated_variables + ) + + var = ( + _get_variable_from_address(variables, int(ret_ir.value)) + if isinstance(ret_ir, IRLiteral) + else None + ) + if var is not None: + allocated_var = allocated_variables.get(var.name, None) + assert allocated_var is not None, "unallocated variable" + new_var = symbols.get(f"&{ret_ir.value}", allocated_var) # type: ignore + + if var.size and int(var.size) > 32: + offset = int(ret_ir.value) - var.pos # type: ignore + if offset > 0: + ptr_var = ctx.append_instruction( + "add", [IRLiteral(var.pos), IRLiteral(offset)] + ) + else: + ptr_var = allocated_var + inst = IRInstruction("return", [last_ir, ptr_var]) + else: + inst = _get_return_for_stack_operand(ctx, symbols, new_var, last_ir) + else: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + if sym is None: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if func_t.return_type.memory_bytes_required > 32: + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + else: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if last_ir and int(last_ir.value) > 32: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + ret_buf = IRLiteral(128) # TODO: need allocator + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_buf]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + + if func_t.is_internal: + assert ir.args[1].value == "return_pc", "return_pc not found" + if func_t.return_type is None: + inst = IRInstruction("ret", [symbols["return_pc"]]) + else: + if func_t.return_type.memory_bytes_required > 32: + inst = IRInstruction("ret", [symbols["return_buffer"], symbols["return_pc"]]) + else: + ret_by_value = ctx.append_instruction("mload", [symbols["return_buffer"]]) + inst = IRInstruction("ret", [ret_by_value, symbols["return_pc"]]) + + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "revert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction("revert", [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "dload": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src = ctx.append_instruction("add", [arg_0, IRLabel("code_end")]) + + ctx.append_instruction( + "dloadbytes", [IRLiteral(32), src, IRLiteral(MemoryPositions.FREE_VAR_SPACE)], False + ) + return ctx.append_instruction("mload", [IRLiteral(MemoryPositions.FREE_VAR_SPACE)]) + elif ir.value == "dloadbytes": + dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src_offset = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + src = ctx.append_instruction("add", [src_offset, IRLabel("code_end")]) + + inst = IRInstruction("dloadbytes", [len_, src, dst]) + ctx.get_basic_block().append_instruction(inst) + return None + elif ir.value == "mload": + sym_ir = ir.args[0] + var = ( + _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None + ) + if var is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mload", [ptr_var]) + else: + if sym_ir.is_literal: + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + new_var = ctx.append_instruction("store", [sym_ir]) + symbols[f"&{sym_ir.value}"] = new_var + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + return sym + + sym = symbols.get(f"&{sym_ir.value}", None) + assert sym is not None, "unallocated variable" + return sym + else: + if sym_ir.is_literal: + new_var = symbols.get(f"&{sym_ir.value}", None) + if new_var is not None: + return ctx.append_instruction("mload", [new_var]) + else: + return ctx.append_instruction("mload", [IRLiteral(sym_ir.value)]) + else: + new_var = _convert_ir_basicblock( + ctx, sym_ir, symbols, variables, allocated_variables + ) + # + # Old IR gets it's return value as a reference in the stack + # New IR gets it's return value in stack in case of 32 bytes or less + # So here we detect ahead of time if this mload leads a self call and + # and we skip the mload + # + if sym_ir.is_self_call: + return new_var + return ctx.append_instruction("mload", [new_var]) + + elif ir.value == "mstore": + sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + + var = None + if isinstance(sym_ir, IRLiteral): + var = _get_variable_from_address(variables, int(sym_ir.value)) + + if var is not None and var.size is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mstore", [arg_1, ptr_var], False) + else: + if isinstance(sym_ir, IRLiteral): + new_var = ctx.append_instruction("store", [arg_1]) + symbols[f"&{sym_ir.value}"] = new_var + # if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + if not isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + return None + + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + if arg_1 and not isinstance(sym_ir, IRLiteral): + symbols[f"&{sym_ir.value}"] = arg_1 + return None + + if isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym]) + ctx.get_basic_block().append_instruction(inst) + return None + else: + symbols[sym_ir.value] = arg_1 + return arg_1 + + elif ir.value in ["sload", "iload"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + return ctx.append_instruction(ir.value, [arg_0]) + elif ir.value in ["sstore", "istore"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction(ir.value, [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + elif ir.value == "unique_symbol": + sym = ir.args[0] + new_var = ctx.get_next_variable() + symbols[f"&{sym.value}"] = new_var + return new_var + elif ir.value == "repeat": + # + # repeat(sym, start, end, bound, body) + # 1) entry block ] + # 2) init counter block ] -> same block + # 3) condition block (exit block, body block) + # 4) body block + # 5) increment block + # 6) exit block + # TODO: Add the extra bounds check after clarify + def emit_body_block(): + global _break_target, _continue_target + old_targets = _break_target, _continue_target + _break_target, _continue_target = exit_block, increment_block + _convert_ir_basicblock(ctx, body, symbols, variables, allocated_variables) + _break_target, _continue_target = old_targets + + sym = ir.args[0] + start = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + end = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + # "bound" is not used + _ = _convert_ir_basicblock(ctx, ir.args[3], symbols, variables, allocated_variables) + body = ir.args[4] + + entry_block = ctx.get_basic_block() + cond_block = IRBasicBlock(ctx.get_next_label(), ctx) + body_block = IRBasicBlock(ctx.get_next_label(), ctx) + jump_up_block = IRBasicBlock(ctx.get_next_label(), ctx) + increment_block = IRBasicBlock(ctx.get_next_label(), ctx) + exit_block = IRBasicBlock(ctx.get_next_label(), ctx) + + counter_var = ctx.get_next_variable() + counter_inc_var = ctx.get_next_variable() + ret = ctx.get_next_variable() + + inst = IRInstruction("store", [start], counter_var) + ctx.get_basic_block().append_instruction(inst) + symbols[sym.value] = counter_var + inst = IRInstruction("jmp", [cond_block.label]) + ctx.get_basic_block().append_instruction(inst) + + symbols[sym.value] = ret + cond_block.append_instruction( + IRInstruction( + "phi", [entry_block.label, counter_var, increment_block.label, counter_inc_var], ret + ) + ) + + xor_ret = ctx.get_next_variable() + cont_ret = ctx.get_next_variable() + inst = IRInstruction("xor", [ret, end], xor_ret) + cond_block.append_instruction(inst) + cond_block.append_instruction(IRInstruction("iszero", [xor_ret], cont_ret)) + ctx.append_basic_block(cond_block) + + # Do a dry run to get the symbols needing phi nodes + start_syms = symbols.copy() + ctx.append_basic_block(body_block) + emit_body_block() + end_syms = symbols.copy() + diff_syms = _get_symbols_common(start_syms, end_syms) + + replacements = {} + for sym, val in diff_syms.items(): + new_var = ctx.get_next_variable() + symbols[sym] = new_var + replacements[val[0]] = new_var + replacements[val[1]] = new_var + cond_block.insert_instruction( + IRInstruction( + "phi", [entry_block.label, val[0], increment_block.label, val[1]], new_var + ), + 1, + ) + + body_block.replace_operands(replacements) + + body_end = ctx.get_basic_block() + if not body_end.is_terminated: + body_end.append_instruction(IRInstruction("jmp", [jump_up_block.label])) + + jump_cond = IRInstruction("jmp", [increment_block.label]) + jump_up_block.append_instruction(jump_cond) + ctx.append_basic_block(jump_up_block) + + increment_block.append_instruction( + IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var) + ) + increment_block.append_instruction(IRInstruction("jmp", [cond_block.label])) + ctx.append_basic_block(increment_block) + + ctx.append_basic_block(exit_block) + + inst = IRInstruction("jnz", [cont_ret, exit_block.label, body_block.label]) + cond_block.append_instruction(inst) + elif ir.value == "break": + assert _break_target is not None, "Break with no break target" + inst = IRInstruction("jmp", [_break_target.label]) + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + elif ir.value == "continue": + assert _continue_target is not None, "Continue with no contrinue target" + inst = IRInstruction("jmp", [_continue_target.label]) + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + elif ir.value == "gas": + return ctx.append_instruction("gas", []) + elif ir.value == "returndatasize": + return ctx.append_instruction("returndatasize", []) + elif ir.value == "returndatacopy": + assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_var = ctx.append_instruction("returndatacopy", [arg_1, size]) + + symbols[f"&{arg_0.value}"] = new_var + return new_var + elif ir.value == "selfdestruct": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + ctx.append_instruction("selfdestruct", [arg_0], False) + elif isinstance(ir.value, str) and ir.value.startswith("log"): + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + for arg in ir.args + ] + inst = IRInstruction(ir.value, reversed(args)) + ctx.get_basic_block().append_instruction(inst) + elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): + _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) + elif isinstance(ir.value, str) and ir.value in symbols: + return symbols[ir.value] + elif ir.is_literal: + return IRLiteral(ir.value) + else: + raise Exception(f"Unknown IR node: {ir}") + + return None + + +def _convert_ir_opcode( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> None: + opcode = ir.value.upper() # type: ignore + inst_args = [] + for arg in ir.args: + if isinstance(arg, IRnode): + inst_args.append( + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ) + instruction = IRInstruction(opcode, inst_args) # type: ignore + ctx.get_basic_block().append_instruction(instruction) + + +def _data_ofst_of(sym, ofst, height_): + # e.g. _OFST _sym_foo 32 + assert is_symbol(sym) or is_mem_sym(sym) + if isinstance(ofst.value, int): + # resolve at compile time using magic _OFST op + return ["_OFST", sym, ofst.value] + else: + # if we can't resolve at compile time, resolve at runtime + # ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_) + return ofst + [sym, "ADD"] diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py new file mode 100644 index 0000000000..11da80ac66 --- /dev/null +++ b/vyper/venom/passes/base_pass.py @@ -0,0 +1,21 @@ +class IRPass: + """ + Decorator for IR passes. This decorator will run the pass repeatedly + until no more changes are made. + """ + + @classmethod + def run_pass(cls, *args, **kwargs): + t = cls() + count = 0 + + while True: + changes_count = t._run_pass(*args, **kwargs) or 0 + count += changes_count + if changes_count == 0: + break + + return count + + def _run_pass(self, *args, **kwargs): + raise NotImplementedError(f"Not implemented! {self.__class__}.run_pass()") diff --git a/vyper/venom/passes/constant_propagation.py b/vyper/venom/passes/constant_propagation.py new file mode 100644 index 0000000000..94b556124e --- /dev/null +++ b/vyper/venom/passes/constant_propagation.py @@ -0,0 +1,13 @@ +from vyper.utils import ir_pass +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction + + +def _process_basic_block(ctx: IRFunction, bb: IRBasicBlock): + pass + + +@ir_pass +def ir_pass_constant_propagation(ctx: IRFunction): + for bb in ctx.basic_blocks: + _process_basic_block(ctx, bb) diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py new file mode 100644 index 0000000000..26994bd27f --- /dev/null +++ b/vyper/venom/passes/dft.py @@ -0,0 +1,54 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis import DFG +from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +# DataFlow Transformation +class DFTPass(IRPass): + def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): + if inst in self.visited_instructions: + return + self.visited_instructions.add(inst) + + if inst.opcode == "phi": + # phi instructions stay at the beginning of the basic block + # and no input processing is needed + bb.instructions.append(inst) + return + + for op in inst.get_inputs(): + target = self.dfg.get_producing_instruction(op) + if target.parent != inst.parent or target.fence_id != inst.fence_id: + # don't reorder across basic block or fence boundaries + continue + self._process_instruction_r(bb, target) + + bb.instructions.append(inst) + + def _process_basic_block(self, bb: IRBasicBlock) -> None: + self.ctx.append_basic_block(bb) + + instructions = bb.instructions + bb.instructions = [] + + for inst in instructions: + inst.fence_id = self.fence_id + if inst.volatile: + self.fence_id += 1 + + for inst in instructions: + self._process_instruction_r(bb, inst) + + def _run_pass(self, ctx: IRFunction) -> None: + self.ctx = ctx + self.dfg = DFG.build_dfg(ctx) + self.fence_id = 0 + self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() + + basic_blocks = ctx.basic_blocks + ctx.basic_blocks = [] + + for bb in basic_blocks: + self._process_basic_block(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py new file mode 100644 index 0000000000..9ee1012f91 --- /dev/null +++ b/vyper/venom/passes/normalization.py @@ -0,0 +1,90 @@ +from vyper.exceptions import CompilerPanic +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class NormalizationPass(IRPass): + """ + This pass splits basic blocks when there are multiple conditional predecessors. + The code generator expect a normalized CFG, that has the property that + each basic block has at most one conditional predecessor. + """ + + changes = 0 + + def _split_basic_block(self, bb: IRBasicBlock) -> None: + # Iterate over the predecessors of the basic block + for in_bb in list(bb.cfg_in): + jump_inst = in_bb.instructions[-1] + assert bb in in_bb.cfg_out + + # Handle static and dynamic branching + if jump_inst.opcode == "jnz": + self._split_for_static_branch(bb, in_bb) + elif jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRVariable): + self._split_for_dynamic_branch(bb, in_bb) + else: + continue + + self.changes += 1 + + def _split_for_static_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + jump_inst = in_bb.instructions[-1] + for i, op in enumerate(jump_inst.operands): + if op == bb.label: + edge = i + break + else: + # none of the edges points to this bb + raise CompilerPanic("bad CFG") + + assert edge in (1, 2) # the arguments which can be labels + + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Redirect the original conditional jump to the intermediary basic block + jump_inst.operands[edge] = split_bb.label + + def _split_for_dynamic_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Update any affected labels in the data segment + # TODO: this DESTROYS the cfg! refactor so the translation of the + # selector table produces indirect jumps properly. + for inst in self.ctx.data_segment: + if inst.opcode == "db" and inst.operands[0] == bb.label: + inst.operands[0] = split_bb.label + + def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRBasicBlock: + # Create an intermediary basic block and append it + source = in_bb.label.value + target = bb.label.value + split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) + split_bb.append_instruction(IRInstruction("jmp", [bb.label])) + self.ctx.append_basic_block(split_bb) + + # Rewire the CFG + # TODO: this is cursed code, it is necessary instead of just running + # calculate_cfg() because split_for_dynamic_branch destroys the CFG! + # ideally, remove this rewiring and just re-run calculate_cfg(). + split_bb.add_cfg_in(in_bb) + split_bb.add_cfg_out(bb) + in_bb.remove_cfg_out(bb) + in_bb.add_cfg_out(split_bb) + bb.remove_cfg_in(in_bb) + bb.add_cfg_in(split_bb) + return split_bb + + def _run_pass(self, ctx: IRFunction) -> int: + self.ctx = ctx + self.changes = 0 + + for bb in ctx.basic_blocks: + if len(bb.cfg_in) > 1: + self._split_basic_block(bb) + + # Sanity check + assert ctx.normalized, "Normalization pass failed" + + return self.changes diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py new file mode 100644 index 0000000000..66c62b74d2 --- /dev/null +++ b/vyper/venom/stack_model.py @@ -0,0 +1,100 @@ +from vyper.venom.basicblock import IROperand, IRVariable + + +class StackModel: + NOT_IN_STACK = object() + _stack: list[IROperand] + + def __init__(self): + self._stack = [] + + def copy(self): + new = StackModel() + new._stack = self._stack.copy() + return new + + @property + def height(self) -> int: + """ + Returns the height of the stack map. + """ + return len(self._stack) + + def push(self, op: IROperand) -> None: + """ + Pushes an operand onto the stack map. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack.append(op) + + def pop(self, num: int = 1) -> None: + del self._stack[len(self._stack) - num :] + + def get_depth(self, op: IROperand) -> int: + """ + Returns the depth of the first matching operand in the stack map. + If the operand is not in the stack map, returns NOT_IN_STACK. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + + for i, stack_op in enumerate(reversed(self._stack)): + if stack_op.value == op.value: + return -i + + return StackModel.NOT_IN_STACK # type: ignore + + def get_phi_depth(self, phi1: IRVariable, phi2: IRVariable) -> int: + """ + Returns the depth of the first matching phi variable in the stack map. + If the none of the phi operands are in the stack, returns NOT_IN_STACK. + Asserts that exactly one of phi1 and phi2 is found. + """ + assert isinstance(phi1, IRVariable) + assert isinstance(phi2, IRVariable) + + ret = StackModel.NOT_IN_STACK + for i, stack_item in enumerate(reversed(self._stack)): + if stack_item in (phi1, phi2): + assert ( + ret is StackModel.NOT_IN_STACK + ), f"phi argument is not unique! {phi1}, {phi2}, {self._stack}" + ret = -i + + return ret # type: ignore + + def peek(self, depth: int) -> IROperand: + """ + Returns the top of the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot peek non-in-stack depth" + return self._stack[depth - 1] + + def poke(self, depth: int, op: IROperand) -> None: + """ + Pokes an operand at the given depth in the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot poke non-in-stack depth" + assert depth <= 0, "Bad depth" + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack[depth - 1] = op + + def dup(self, depth: int) -> None: + """ + Duplicates the operand at the given depth in the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot dup non-existent operand" + assert depth <= 0, "Cannot dup positive depth" + self._stack.append(self.peek(depth)) + + def swap(self, depth: int) -> None: + """ + Swaps the operand at the given depth in the stack map with the top of the stack. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot swap non-existent operand" + assert depth < 0, "Cannot swap positive depth" + top = self._stack[-1] + self._stack[-1] = self._stack[depth - 1] + self._stack[depth - 1] = top + + def __repr__(self) -> str: + return f"" diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py new file mode 100644 index 0000000000..f6ec45440a --- /dev/null +++ b/vyper/venom/venom_to_assembly.py @@ -0,0 +1,461 @@ +from typing import Any + +from vyper.ir.compile_ir import PUSH, DataHeader, RuntimeHeader, optimize_assembly +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.analysis import calculate_cfg, calculate_liveness, input_vars_from +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.stack_model import StackModel + +# instructions which map one-to-one from venom to EVM +_ONE_TO_ONE_INSTRUCTIONS = frozenset( + [ + "revert", + "coinbase", + "calldatasize", + "calldatacopy", + "calldataload", + "gas", + "gasprice", + "gaslimit", + "address", + "origin", + "number", + "extcodesize", + "extcodehash", + "returndatasize", + "returndatacopy", + "callvalue", + "selfbalance", + "sload", + "sstore", + "mload", + "mstore", + "timestamp", + "caller", + "selfdestruct", + "signextend", + "stop", + "shr", + "shl", + "and", + "xor", + "or", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "eq", + "iszero", + "lg", + "lt", + "slt", + "sgt", + "log0", + "log1", + "log2", + "log3", + "log4", + ] +) + + +# TODO: "assembly" gets into the recursion due to how the original +# IR was structured recursively in regards with the deploy instruction. +# There, recursing into the deploy instruction was by design, and +# made it easier to make the assembly generated "recursive" (i.e. +# instructions being lists of instructions). We don't have this restriction +# anymore, so we can probably refactor this to be iterative in coordination +# with the assembler. My suggestion is to let this be for now, and we can +# refactor it later when we are finished phasing out the old IR. +class VenomCompiler: + ctx: IRFunction + label_counter = 0 + visited_instructions: OrderedSet # {IRInstruction} + visited_basicblocks: OrderedSet # {IRBasicBlock} + + def __init__(self, ctx: IRFunction): + self.ctx = ctx + self.label_counter = 0 + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + + def generate_evm(self, no_optimize: bool = False) -> list[str]: + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + self.label_counter = 0 + + stack = StackModel() + asm: list[str] = [] + + # Before emitting the assembly, we need to make sure that the + # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) + # so it should not be called after calling NormalizationPass.run_pass(). + # Liveness is then computed for the normalized IR, and we can proceed to + # assembly generation. + # This is a side-effect of how dynamic jumps are temporarily being used + # to support the O(1) dispatcher. -> look into calculate_cfg() + calculate_cfg(self.ctx) + NormalizationPass.run_pass(self.ctx) + calculate_liveness(self.ctx) + + assert self.ctx.normalized, "Non-normalized CFG!" + + self._generate_evm_for_basicblock_r(asm, self.ctx.basic_blocks[0], stack) + + # Append postambles + revert_postamble = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] + runtime = None + if isinstance(asm[-1], list) and isinstance(asm[-1][0], RuntimeHeader): + runtime = asm.pop() + + asm.extend(revert_postamble) + if runtime: + runtime.extend(revert_postamble) + asm.append(runtime) + + # Append data segment + data_segments: dict[Any, list[Any]] = dict() + for inst in self.ctx.data_segment: + if inst.opcode == "dbname": + label = inst.operands[0].value + data_segments[label] = [DataHeader(f"_sym_{label}")] + elif inst.opcode == "db": + data_segments[label].append(f"_sym_{inst.operands[0].value}") + + extent_point = asm if not isinstance(asm[-1], list) else asm[-1] + extent_point.extend([data_segments[label] for label in data_segments]) # type: ignore + + if no_optimize is False: + optimize_assembly(asm) + + return asm + + def _stack_reorder( + self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] + ) -> None: + # make a list so we can index it + stack_ops = [x for x in _stack_ops.keys()] + stack_ops_count = len(_stack_ops) + + for i in range(stack_ops_count): + op = stack_ops[i] + final_stack_depth = -(stack_ops_count - i - 1) + depth = stack.get_depth(op) # type: ignore + + if depth == final_stack_depth: + continue + + self.swap(assembly, stack, depth) + self.swap(assembly, stack, final_stack_depth) + + def _emit_input_operands( + self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel + ) -> None: + # PRE: we already have all the items on the stack that have + # been scheduled to be killed. now it's just a matter of emitting + # SWAPs, DUPs and PUSHes until we match the `ops` argument + + # dumb heuristic: if the top of stack is not wanted here, swap + # it with something that is wanted + if ops and stack.height > 0 and stack.peek(0) not in ops: + for op in ops: + if isinstance(op, IRVariable) and op not in inst.dup_requirements: + self.swap_op(assembly, stack, op) + break + + emitted_ops = OrderedSet[IROperand]() + for op in ops: + if isinstance(op, IRLabel): + # invoke emits the actual instruction itself so we don't need to emit it here + # but we need to add it to the stack map + if inst.opcode != "invoke": + assembly.append(f"_sym_{op.value}") + stack.push(op) + continue + + if isinstance(op, IRLiteral): + assembly.extend([*PUSH(op.value)]) + stack.push(op) + continue + + if op in inst.dup_requirements: + self.dup_op(assembly, stack, op) + + if op in emitted_ops: + self.dup_op(assembly, stack, op) + + # REVIEW: this seems like it can be reordered across volatile + # boundaries (which includes memory fences). maybe just + # remove it entirely at this point + if isinstance(op, IRVariable) and op.mem_type == MemType.MEMORY: + assembly.extend([*PUSH(op.mem_addr)]) + assembly.append("MLOAD") + + emitted_ops.add(op) + + def _generate_evm_for_basicblock_r( + self, asm: list, basicblock: IRBasicBlock, stack: StackModel + ) -> None: + if basicblock in self.visited_basicblocks: + return + self.visited_basicblocks.add(basicblock) + + # assembly entry point into the block + asm.append(f"_sym_{basicblock.label}") + asm.append("JUMPDEST") + + self.clean_stack_from_cfg_in(asm, basicblock, stack) + + for inst in basicblock.instructions: + asm = self._generate_evm_for_instruction(asm, inst, stack) + + for bb in basicblock.cfg_out: + self._generate_evm_for_basicblock_r(asm, bb, stack.copy()) + + # pop values from stack at entry to bb + # note this produces the same result(!) no matter which basic block + # we enter from in the CFG. + def clean_stack_from_cfg_in( + self, asm: list, basicblock: IRBasicBlock, stack: StackModel + ) -> None: + if len(basicblock.cfg_in) == 0: + return + + to_pop = OrderedSet[IRVariable]() + for in_bb in basicblock.cfg_in: + # inputs is the input variables we need from in_bb + inputs = input_vars_from(in_bb, basicblock) + + # layout is the output stack layout for in_bb (which works + # for all possible cfg_outs from the in_bb). + layout = in_bb.out_vars + + # pop all the stack items which in_bb produced which we don't need. + to_pop |= layout.difference(inputs) + + for var in to_pop: + depth = stack.get_depth(var) + # don't pop phantom phi inputs + if depth is StackModel.NOT_IN_STACK: + continue + + if depth != 0: + stack.swap(depth) + self.pop(asm, stack) + + def _generate_evm_for_instruction( + self, assembly: list, inst: IRInstruction, stack: StackModel + ) -> list[str]: + opcode = inst.opcode + + # + # generate EVM for op + # + + # Step 1: Apply instruction special stack manipulations + + if opcode in ["jmp", "jnz", "invoke"]: + operands = inst.get_non_label_operands() + elif opcode == "alloca": + operands = inst.operands[1:2] + elif opcode == "iload": + operands = [] + elif opcode == "istore": + operands = inst.operands[0:1] + else: + operands = inst.operands + + if opcode == "phi": + ret = inst.get_outputs()[0] + phi1, phi2 = inst.get_inputs() + depth = stack.get_phi_depth(phi1, phi2) + # collapse the arguments to the phi node in the stack. + # example, for `%56 = %label1 %13 %label2 %14`, we will + # find an instance of %13 *or* %14 in the stack and replace it with %56. + to_be_replaced = stack.peek(depth) + if to_be_replaced in inst.dup_requirements: + # %13/%14 is still live(!), so we make a copy of it + self.dup(assembly, stack, depth) + stack.poke(0, ret) + else: + stack.poke(depth, ret) + return assembly + + # Step 2: Emit instruction's input operands + self._emit_input_operands(assembly, inst, operands, stack) + + # Step 3: Reorder stack + if opcode in ["jnz", "jmp"]: + # prepare stack for jump into another basic block + assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) + b = next(iter(inst.parent.cfg_out)) + target_stack = input_vars_from(inst.parent, b) + # TODO optimize stack reordering at entry and exit from basic blocks + self._stack_reorder(assembly, stack, target_stack) + + # final step to get the inputs to this instruction ordered + # correctly on the stack + self._stack_reorder(assembly, stack, OrderedSet(operands)) + + # some instructions (i.e. invoke) need to do stack manipulations + # with the stack model containing the return value(s), so we fiddle + # with the stack model beforehand. + + # Step 4: Push instruction's return value to stack + stack.pop(len(operands)) + if inst.output is not None: + stack.push(inst.output) + + # Step 5: Emit the EVM instruction(s) + if opcode in _ONE_TO_ONE_INSTRUCTIONS: + assembly.append(opcode.upper()) + elif opcode == "alloca": + pass + elif opcode == "param": + pass + elif opcode == "store": + pass + elif opcode == "dbname": + pass + elif opcode in ["codecopy", "dloadbytes"]: + assembly.append("CODECOPY") + elif opcode == "jnz": + # jump if not zero + if_nonzero_label = inst.operands[1] + if_zero_label = inst.operands[2] + assembly.append(f"_sym_{if_nonzero_label.value}") + assembly.append("JUMPI") + + # make sure the if_zero_label will be optimized out + # assert if_zero_label == next(iter(inst.parent.cfg_out)).label + + assembly.append(f"_sym_{if_zero_label.value}") + assembly.append("JUMP") + + elif opcode == "jmp": + if isinstance(inst.operands[0], IRLabel): + assembly.append(f"_sym_{inst.operands[0].value}") + assembly.append("JUMP") + else: + assembly.append("JUMP") + elif opcode == "gt": + assembly.append("GT") + elif opcode == "lt": + assembly.append("LT") + elif opcode == "invoke": + target = inst.operands[0] + assert isinstance(target, IRLabel), "invoke target must be a label" + assembly.extend( + [ + f"_sym_label_ret_{self.label_counter}", + f"_sym_{target.value}", + "JUMP", + f"_sym_label_ret_{self.label_counter}", + "JUMPDEST", + ] + ) + self.label_counter += 1 + if stack.height > 0 and stack.peek(0) in inst.dup_requirements: + self.pop(assembly, stack) + elif opcode == "call": + assembly.append("CALL") + elif opcode == "staticcall": + assembly.append("STATICCALL") + elif opcode == "ret": + assembly.append("JUMP") + elif opcode == "return": + assembly.append("RETURN") + elif opcode == "phi": + pass + elif opcode == "sha3": + assembly.append("SHA3") + elif opcode == "sha3_64": + assembly.extend( + [ + *PUSH(MemoryPositions.FREE_VAR_SPACE2), + "MSTORE", + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "MSTORE", + *PUSH(64), + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "SHA3", + ] + ) + elif opcode == "ceil32": + assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) + elif opcode == "assert": + assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) + elif opcode == "deploy": + memsize = inst.operands[0].value + padding = inst.operands[2].value + # TODO: fix this by removing deploy opcode altogether me move emition to ir translation + while assembly[-1] != "JUMPDEST": + assembly.pop() + assembly.extend( + ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] + ) + assembly.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len + assembly.extend(["_mem_deploy_start"]) # stack: len mem_ofst + assembly.extend(["RETURN"]) + assembly.append([RuntimeHeader("_sym_runtime_begin", memsize, padding)]) # type: ignore + assembly = assembly[-1] + elif opcode == "iload": + loc = inst.operands[0].value + assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"]) + elif opcode == "istore": + loc = inst.operands[1].value + assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + else: + raise Exception(f"Unknown opcode: {opcode}") + + # Step 6: Emit instructions output operands (if any) + if inst.output is not None: + assert isinstance(inst.output, IRVariable), "Return value must be a variable" + if inst.output.mem_type == MemType.MEMORY: + assembly.extend([*PUSH(inst.output.mem_addr)]) + + return assembly + + def pop(self, assembly, stack, num=1): + stack.pop(num) + assembly.extend(["POP"] * num) + + def swap(self, assembly, stack, depth): + if depth == 0: + return + stack.swap(depth) + assembly.append(_evm_swap_for(depth)) + + def dup(self, assembly, stack, depth): + stack.dup(depth) + assembly.append(_evm_dup_for(depth)) + + def swap_op(self, assembly, stack, op): + self.swap(assembly, stack, stack.get_depth(op)) + + def dup_op(self, assembly, stack, op): + self.dup(assembly, stack, stack.get_depth(op)) + + +def _evm_swap_for(depth: int) -> str: + swap_idx = -depth + assert 1 <= swap_idx <= 16, "Unsupported swap depth" + return f"SWAP{swap_idx}" + + +def _evm_dup_for(depth: int) -> str: + dup_idx = 1 - depth + assert 1 <= dup_idx <= 16, "Unsupported dup depth" + return f"DUP{dup_idx}" From 21a47b614d1bd1e989195adedb1f5b709f5fbfee Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 7 Dec 2023 09:36:47 -0500 Subject: [PATCH 133/201] chore: move venom tests to `tests/unit/compiler` (#3684) the `tests/compiler` directory was moved to `tests/unit/` in 4dd47e302fc538c but this seems to have been missed in a merge during work on venom (cbac5aba53f87b) --- tests/functional/codegen/integration/test_crowdfund.py | 5 ++++- tests/{ => unit}/compiler/venom/test_duplicate_operands.py | 0 tests/{ => unit}/compiler/venom/test_multi_entry_block.py | 0 .../compiler/venom/test_stack_at_external_return.py | 0 4 files changed, 4 insertions(+), 1 deletion(-) rename tests/{ => unit}/compiler/venom/test_duplicate_operands.py (100%) rename tests/{ => unit}/compiler/venom/test_multi_entry_block.py (100%) rename tests/{ => unit}/compiler/venom/test_stack_at_external_return.py (100%) diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 47c63dc015..2083e62610 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -63,10 +63,13 @@ def refund(): """ a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] + c = get_contract_with_gas_estimation_for_constants(crowdfund, *[a1, 50, 60]) + start_timestamp = w3.eth.get_block(w3.eth.block_number).timestamp + c.participate(transact={"value": 5}) assert c.timelimit() == 60 - assert c.deadline() - c.block_timestamp() == 59 + assert c.deadline() - start_timestamp == 60 assert not c.expired() assert not c.reached() c.participate(transact={"value": 49}) diff --git a/tests/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py similarity index 100% rename from tests/compiler/venom/test_duplicate_operands.py rename to tests/unit/compiler/venom/test_duplicate_operands.py diff --git a/tests/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py similarity index 100% rename from tests/compiler/venom/test_multi_entry_block.py rename to tests/unit/compiler/venom/test_multi_entry_block.py diff --git a/tests/compiler/venom/test_stack_at_external_return.py b/tests/unit/compiler/venom/test_stack_at_external_return.py similarity index 100% rename from tests/compiler/venom/test_stack_at_external_return.py rename to tests/unit/compiler/venom/test_stack_at_external_return.py From 7c74aa2618c8051db88acfac3bd71a3017c524cb Mon Sep 17 00:00:00 2001 From: Franfran <51274081+iFrostizz@users.noreply.github.com> Date: Sat, 9 Dec 2023 14:20:46 +0100 Subject: [PATCH 134/201] fix: add compile-time check for negative uint2str input (#3671) --- .../builtins/codegen/test_uint2str.py | 25 +++++++++++++++++++ vyper/builtins/functions.py | 5 +++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/functional/builtins/codegen/test_uint2str.py b/tests/functional/builtins/codegen/test_uint2str.py index 9d2b7fe3f5..d9edea154b 100644 --- a/tests/functional/builtins/codegen/test_uint2str.py +++ b/tests/functional/builtins/codegen/test_uint2str.py @@ -2,6 +2,9 @@ import pytest +from vyper.compiler import compile_code +from vyper.exceptions import InvalidType, OverflowException + VALID_BITS = list(range(8, 257, 8)) @@ -37,3 +40,25 @@ def foo(x: uint{bits}) -> uint256: """ c = get_contract(code) assert c.foo(2**bits - 1) == 0, bits + + +def test_bignum_throws(): + code = """ +@external +def test(): + a: String[78] = uint2str(2**256) + pass + """ + with pytest.raises(OverflowException): + compile_code(code) + + +def test_int_fails(): + code = """ +@external +def test(): + a: String[78] = uint2str(-1) + pass + """ + with pytest.raises(InvalidType): + compile_code(code) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index b2d817ec5c..22931508a6 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2090,7 +2090,10 @@ def evaluate(self, node): if not isinstance(node.args[0], vy_ast.Int): raise UnfoldableNode - value = str(node.args[0].value) + value = node.args[0].value + if value < 0: + raise InvalidType("Only unsigned ints allowed", node) + value = str(value) return vy_ast.Str.from_node(node, value=value) def infer_arg_types(self, node): From 10564dcc37756f3d3684b7a91fd8f4325a38c4d8 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Tue, 12 Dec 2023 19:47:44 +0200 Subject: [PATCH 135/201] refactor: improve `IRBasicBlock` builder API clean up `append_instruction` api so it does a bit of magic on its arguments and figures out whether or not to allocate a stack variable. remove `append_instruction()` from IRFunction - automatically appending to the last basic block could be a bit error prone depending on which order basic blocks are added to the CFG. --------- Co-authored-by: Charles Cooper --- .../compiler/venom/test_duplicate_operands.py | 11 +- .../compiler/venom/test_multi_entry_block.py | 53 +-- vyper/venom/analysis.py | 4 +- vyper/venom/basicblock.py | 86 ++++- vyper/venom/function.py | 11 - vyper/venom/ir_node_to_venom.py | 325 ++++++++---------- vyper/venom/passes/normalization.py | 4 +- vyper/venom/venom_to_assembly.py | 11 +- 8 files changed, 260 insertions(+), 245 deletions(-) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 505f01e31b..a51992df67 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -1,6 +1,5 @@ from vyper.compiler.settings import OptimizationLevel from vyper.venom import generate_assembly_experimental -from vyper.venom.basicblock import IRLiteral from vyper.venom.function import IRFunction @@ -17,11 +16,11 @@ def test_duplicate_operands(): Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] """ ctx = IRFunction() - - op = ctx.append_instruction("store", [IRLiteral(10)]) - sum = ctx.append_instruction("add", [op, op]) - ctx.append_instruction("mul", [sum, op]) - ctx.append_instruction("stop", [], False) + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + sum = bb.append_instruction("add", op, op) + bb.append_instruction("mul", sum, op) + bb.append_instruction("stop") asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index bb57fa1065..6e7e6995d6 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -1,5 +1,4 @@ from vyper.venom.analysis import calculate_cfg -from vyper.venom.basicblock import IRLiteral from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel from vyper.venom.passes.normalization import NormalizationPass @@ -11,25 +10,26 @@ def test_multi_entry_block_1(): target_label = IRLabel("target") block_1_label = IRLabel("block_1", ctx) - op = ctx.append_instruction("store", [IRLiteral(10)]) - acc = ctx.append_instruction("add", [op, op]) - ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + acc = bb.append_instruction("add", op, op) + bb.append_instruction("jnz", acc, finish_label, block_1_label) block_1 = IRBasicBlock(block_1_label, ctx) ctx.append_basic_block(block_1) - acc = ctx.append_instruction("add", [acc, op]) - op = ctx.append_instruction("store", [IRLiteral(10)]) - ctx.append_instruction("mstore", [acc, op], False) - ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + acc = block_1.append_instruction("add", acc, op) + op = block_1.append_instruction("store", 10) + block_1.append_instruction("mstore", acc, op) + block_1.append_instruction("jnz", acc, finish_label, target_label) target_bb = IRBasicBlock(target_label, ctx) ctx.append_basic_block(target_bb) - ctx.append_instruction("mul", [acc, acc]) - ctx.append_instruction("jmp", [finish_label], False) + target_bb.append_instruction("mul", acc, acc) + target_bb.append_instruction("jmp", finish_label) finish_bb = IRBasicBlock(finish_label, ctx) ctx.append_basic_block(finish_bb) - ctx.append_instruction("stop", [], False) + finish_bb.append_instruction("stop") calculate_cfg(ctx) assert not ctx.normalized, "CFG should not be normalized" @@ -54,33 +54,34 @@ def test_multi_entry_block_2(): block_1_label = IRLabel("block_1", ctx) block_2_label = IRLabel("block_2", ctx) - op = ctx.append_instruction("store", [IRLiteral(10)]) - acc = ctx.append_instruction("add", [op, op]) - ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + acc = bb.append_instruction("add", op, op) + bb.append_instruction("jnz", acc, finish_label, block_1_label) block_1 = IRBasicBlock(block_1_label, ctx) ctx.append_basic_block(block_1) - acc = ctx.append_instruction("add", [acc, op]) - op = ctx.append_instruction("store", [IRLiteral(10)]) - ctx.append_instruction("mstore", [acc, op], False) - ctx.append_instruction("jnz", [acc, target_label, finish_label], False) + acc = block_1.append_instruction("add", acc, op) + op = block_1.append_instruction("store", 10) + block_1.append_instruction("mstore", acc, op) + block_1.append_instruction("jnz", acc, target_label, finish_label) block_2 = IRBasicBlock(block_2_label, ctx) ctx.append_basic_block(block_2) - acc = ctx.append_instruction("add", [acc, op]) - op = ctx.append_instruction("store", [IRLiteral(10)]) - ctx.append_instruction("mstore", [acc, op], False) - # switch the order of the labels, for fun - ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + acc = block_2.append_instruction("add", acc, op) + op = block_2.append_instruction("store", 10) + block_2.append_instruction("mstore", acc, op) + # switch the order of the labels, for fun and profit + block_2.append_instruction("jnz", acc, finish_label, target_label) target_bb = IRBasicBlock(target_label, ctx) ctx.append_basic_block(target_bb) - ctx.append_instruction("mul", [acc, acc]) - ctx.append_instruction("jmp", [finish_label], False) + target_bb.append_instruction("mul", acc, acc) + target_bb.append_instruction("jmp", finish_label) finish_bb = IRBasicBlock(finish_label, ctx) ctx.append_basic_block(finish_bb) - ctx.append_instruction("stop", [], False) + finish_bb.append_instruction("stop") calculate_cfg(ctx) assert not ctx.normalized, "CFG should not be normalized" diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 5980e21028..1a82ca85d0 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -2,7 +2,7 @@ from vyper.utils import OrderedSet from vyper.venom.basicblock import ( BB_TERMINATORS, - CFG_ALTERING_OPS, + CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRInstruction, IRVariable, @@ -55,7 +55,7 @@ def calculate_cfg(ctx: IRFunction) -> None: assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" for inst in bb.instructions: - if inst.opcode in CFG_ALTERING_OPS: + if inst.opcode in CFG_ALTERING_INSTRUCTIONS: ops = inst.get_label_operands() for op in ops: ctx.get_basic_block(op.value).add_cfg_in(bb) diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index b95d7416ca..6f1c1c8ab3 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator, Optional, Union from vyper.utils import OrderedSet @@ -31,8 +31,31 @@ ] ) -CFG_ALTERING_OPS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) +NO_OUTPUT_INSTRUCTIONS = frozenset( + [ + "deploy", + "mstore", + "sstore", + "dstore", + "istore", + "dloadbytes", + "calldatacopy", + "codecopy", + "return", + "ret", + "revert", + "assert", + "selfdestruct", + "stop", + "invalid", + "invoke", + "jmp", + "jnz", + "log", + ] +) +CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -40,8 +63,8 @@ class IRDebugInfo: """ - IRDebugInfo represents debug information in IR, used to annotate IR instructions - with source code information when printing IR. + IRDebugInfo represents debug information in IR, used to annotate IR + instructions with source code information when printing IR. """ line_no: int @@ -83,7 +106,7 @@ class IRLiteral(IRValue): value: int def __init__(self, value: int) -> None: - assert isinstance(value, str) or isinstance(value, int), "value must be an int" + assert isinstance(value, int), "value must be an int" self.value = value def __repr__(self) -> str: @@ -170,7 +193,7 @@ def __init__( assert isinstance(operands, list | Iterator), "operands must be a list" self.opcode = opcode self.volatile = opcode in VOLATILE_INSTRUCTIONS - self.operands = [op for op in operands] # in case we get an iterator + self.operands = list(operands) # in case we get an iterator self.output = output self.liveness = OrderedSet() self.dup_requirements = OrderedSet() @@ -233,6 +256,14 @@ def __repr__(self) -> str: return s +def _ir_operand_from_value(val: Any) -> IROperand: + if isinstance(val, IROperand): + return val + + assert isinstance(val, int) + return IRLiteral(val) + + class IRBasicBlock: """ IRBasicBlock represents a basic block in IR. Each basic block has a label and @@ -243,8 +274,8 @@ class IRBasicBlock: %2 = mul %1, 2 is represented as: bb = IRBasicBlock("bb", function) - bb.append_instruction(IRInstruction("add", ["%0", "1"], "%1")) - bb.append_instruction(IRInstruction("mul", ["%1", "2"], "%2")) + r1 = bb.append_instruction("add", "%0", "1") + r2 = bb.append_instruction("mul", r1, "2") The label of a basic block is used to refer to it from other basic blocks in order to branch to it. @@ -296,10 +327,41 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: def is_reachable(self) -> bool: return len(self.cfg_in) > 0 - def append_instruction(self, instruction: IRInstruction) -> None: - assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" - instruction.parent = self - self.instructions.append(instruction) + def append_instruction(self, opcode: str, *args: Union[IROperand, int]) -> Optional[IRVariable]: + """ + Append an instruction to the basic block + + Returns the output variable if the instruction supports one + """ + ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None + + # Wrap raw integers in IRLiterals + inst_args = [_ir_operand_from_value(arg) for arg in args] + + inst = IRInstruction(opcode, inst_args, ret) + inst.parent = self + self.instructions.append(inst) + return ret + + def append_invoke_instruction( + self, args: list[IROperand | int], returns: bool + ) -> Optional[IRVariable]: + """ + Append an instruction to the basic block + + Returns the output variable if the instruction supports one + """ + ret = None + if returns: + ret = self.parent.get_next_variable() + + # Wrap raw integers in IRLiterals + inst_args = [_ir_operand_from_value(arg) for arg in args] + + inst = IRInstruction("invoke", inst_args, ret) + inst.parent = self + self.instructions.append(inst) + return ret def insert_instruction(self, instruction: IRInstruction, index: int) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" diff --git a/vyper/venom/function.py b/vyper/venom/function.py index c14ad77345..e16b2ad6e6 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -98,17 +98,6 @@ def remove_unreachable_blocks(self) -> int: self.basic_blocks = new_basic_blocks return removed - def append_instruction( - self, opcode: str, args: list[IROperand], do_ret: bool = True - ) -> Optional[IRVariable]: - """ - Append instruction to last basic block. - """ - ret = self.get_next_variable() if do_ret else None - inst = IRInstruction(opcode, args, ret) # type: ignore - self.get_basic_block().append_instruction(inst) - return ret - def append_data(self, opcode: str, args: list[IROperand]) -> None: """ Append data diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 19bd5c8b73..e2ce28a8f9 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -72,7 +72,7 @@ "balance", ] -SymbolTable = dict[str, IROperand] +SymbolTable = dict[str, Optional[IROperand]] def _get_symbols_common(a: dict, b: dict) -> dict: @@ -93,11 +93,11 @@ def convert_ir_basicblock(ir: IRnode) -> IRFunction: for i, bb in enumerate(global_function.basic_blocks): if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: - bb.append_instruction(IRInstruction("jmp", [global_function.basic_blocks[i + 1].label])) + bb.append_instruction("jmp", global_function.basic_blocks[i + 1].label) revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) revert_bb = global_function.append_basic_block(revert_bb) - revert_bb.append_instruction(IRInstruction("revert", [IRLiteral(0), IRLiteral(0)])) + revert_bb.append_instruction("revert", 0, 0) return global_function @@ -109,22 +109,16 @@ def _convert_binary_op( variables: OrderedSet, allocated_variables: dict[str, IRVariable], swap: bool = False, -) -> IRVariable: +) -> Optional[IRVariable]: ir_args = ir.args[::-1] if swap else ir.args arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) - args = [arg_1, arg_0] - - ret = ctx.get_next_variable() - inst = IRInstruction(ir.value, args, ret) # type: ignore - ctx.get_basic_block().append_instruction(inst) - return ret + return ctx.get_basic_block().append_instruction(str(ir.value), arg_1, arg_0) def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: - inst = IRInstruction("jmp", [label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", label) label = ctx.get_next_label() bb = IRBasicBlock(label, ctx) @@ -149,7 +143,7 @@ def _handle_self_call( goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] target_label = goto_ir.args[0].value # goto return_buf = goto_ir.args[1] # return buffer - ret_args = [IRLabel(target_label)] # type: ignore + ret_args: list[IROperand] = [IRLabel(target_label)] # type: ignore for arg in args_ir: if arg.is_literal: @@ -164,16 +158,23 @@ def _handle_self_call( ctx, arg._optimized, symbols, variables, allocated_variables ) if arg.location and arg.location.load_op == "calldataload": - ret = ctx.append_instruction(arg.location.load_op, [ret]) + bb = ctx.get_basic_block() + ret = bb.append_instruction(arg.location.load_op, ret) ret_args.append(ret) if return_buf.is_literal: - ret_args.append(IRLiteral(return_buf.value)) # type: ignore + ret_args.append(return_buf.value) # type: ignore + + bb = ctx.get_basic_block() do_ret = func_t.return_type is not None - invoke_ret = ctx.append_instruction("invoke", ret_args, do_ret) # type: ignore - allocated_variables["return_buffer"] = invoke_ret # type: ignore - return invoke_ret + if do_ret: + invoke_ret = bb.append_invoke_instruction(ret_args, returns=True) # type: ignore + allocated_variables["return_buffer"] = invoke_ret # type: ignore + return invoke_ret + else: + bb.append_invoke_instruction(ret_args, returns=False) # type: ignore + return None def _handle_internal_func( @@ -186,28 +187,18 @@ def _handle_internal_func( old_ir_mempos += 64 for arg in func_t.arguments: - new_var = ctx.get_next_variable() - - alloca_inst = IRInstruction("param", [], new_var) - alloca_inst.annotation = arg.name - bb.append_instruction(alloca_inst) - symbols[f"&{old_ir_mempos}"] = new_var + symbols[f"&{old_ir_mempos}"] = bb.append_instruction("param") + bb.instructions[-1].annotation = arg.name old_ir_mempos += 32 # arg.typ.memory_bytes_required # return buffer if func_t.return_type is not None: - new_var = ctx.get_next_variable() - alloca_inst = IRInstruction("param", [], new_var) - bb.append_instruction(alloca_inst) - alloca_inst.annotation = "return_buffer" - symbols["return_buffer"] = new_var + symbols["return_buffer"] = bb.append_instruction("param") + bb.instructions[-1].annotation = "return_buffer" # return address - new_var = ctx.get_next_variable() - alloca_inst = IRInstruction("param", [], new_var) - bb.append_instruction(alloca_inst) - alloca_inst.annotation = "return_pc" - symbols["return_pc"] = new_var + symbols["return_pc"] = bb.append_instruction("param") + bb.instructions[-1].annotation = "return_pc" return ir.args[0].args[2] @@ -222,7 +213,7 @@ def _convert_ir_simple_node( args = [ _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args ] - return ctx.append_instruction(ir.value, args) # type: ignore + return ctx.get_basic_block().append_instruction(ir.value, *args) # type: ignore _break_target: Optional[IRBasicBlock] = None @@ -241,22 +232,22 @@ def _get_variable_from_address( return None -def _get_return_for_stack_operand( - ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable -) -> IRInstruction: +def _append_return_for_stack_operand( + bb: IRBasicBlock, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable +) -> None: if isinstance(ret_ir, IRLiteral): sym = symbols.get(f"&{ret_ir.value}", None) - new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) - ctx.append_instruction("mstore", [sym, new_var], False) # type: ignore + new_var = bb.append_instruction("alloca", 32, ret_ir) + bb.append_instruction("mstore", sym, new_var) # type: ignore else: sym = symbols.get(ret_ir.value, None) if sym is None: # FIXME: needs real allocations - new_var = ctx.append_instruction("alloca", [IRLiteral(32), IRLiteral(0)]) - ctx.append_instruction("mstore", [ret_ir, new_var], False) # type: ignore + new_var = bb.append_instruction("alloca", 32, 0) + bb.append_instruction("mstore", ret_ir, new_var) # type: ignore else: new_var = ret_ir - return IRInstruction("return", [last_ir, new_var]) # type: ignore + bb.append_instruction("return", last_ir, new_var) # type: ignore def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): @@ -280,7 +271,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) ir.value = org_value - return ctx.append_instruction("iszero", [new_var]) + return ctx.get_basic_block().append_instruction("iszero", new_var) elif ir.value in PASS_THROUGH_INSTRUCTIONS: return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) @@ -296,8 +287,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): runtimeLabel = ctx.get_next_label() - inst = IRInstruction("deploy", [IRLiteral(memsize), runtimeLabel, IRLiteral(padding)]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("deploy", memsize, runtimeLabel, padding) bb = IRBasicBlock(runtimeLabel, ctx) ctx.append_basic_block(bb) @@ -369,12 +359,14 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) symbols[f"&{retOffsetValue}"] = retVar + bb = ctx.get_basic_block() + if ir.value == "call": args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] - return ctx.append_instruction(ir.value, args) + return bb.append_instruction(ir.value, *args) else: args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] - return ctx.append_instruction(ir.value, args) + return bb.append_instruction(ir.value, *args) elif ir.value == "if": cond = ir.args[0] current_bb = ctx.get_basic_block() @@ -394,7 +386,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ) if isinstance(else_ret_val, IRLiteral): assert isinstance(else_ret_val.value, int) # help mypy - else_ret_val = ctx.append_instruction("store", [IRLiteral(else_ret_val.value)]) + else_ret_val = ctx.get_basic_block().append_instruction("store", else_ret_val) after_else_syms = else_syms.copy() # convert "then" @@ -405,10 +397,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ctx, ir.args[1], symbols, variables, allocated_variables ) if isinstance(then_ret_val, IRLiteral): - then_ret_val = ctx.append_instruction("store", [IRLiteral(then_ret_val.value)]) + then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) - inst = IRInstruction("jnz", [cont_ret, then_block.label, else_block.label]) - current_bb.append_instruction(inst) + current_bb.append_instruction("jnz", cont_ret, then_block.label, else_block.label) after_then_syms = symbols.copy() @@ -419,33 +410,25 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if_ret = None if then_ret_val is not None and else_ret_val is not None: - if_ret = ctx.get_next_variable() - bb.append_instruction( - IRInstruction( - "phi", [then_block.label, then_ret_val, else_block.label, else_ret_val], if_ret - ) + if_ret = bb.append_instruction( + "phi", then_block.label, then_ret_val, else_block.label, else_ret_val ) common_symbols = _get_symbols_common(after_then_syms, after_else_syms) for sym, val in common_symbols.items(): - ret = ctx.get_next_variable() + ret = bb.append_instruction("phi", then_block.label, val[0], else_block.label, val[1]) old_var = symbols.get(sym, None) symbols[sym] = ret if old_var is not None: for idx, var_rec in allocated_variables.items(): # type: ignore if var_rec.value == old_var.value: allocated_variables[idx] = ret # type: ignore - bb.append_instruction( - IRInstruction("phi", [then_block.label, val[0], else_block.label, val[1]], ret) - ) if not else_block.is_terminated: - exit_inst = IRInstruction("jmp", [bb.label]) - else_block.append_instruction(exit_inst) + else_block.append_instruction("jmp", bb.label) if not then_block.is_terminated: - exit_inst = IRInstruction("jmp", [bb.label]) - then_block.append_instruction(exit_inst) + then_block.append_instruction("jmp", bb.label) return if_ret @@ -459,7 +442,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): sym = ir.args[0] if isinstance(ret, IRLiteral): - new_var = ctx.append_instruction("store", [ret]) # type: ignore + new_var = ctx.get_basic_block().append_instruction("store", ret) # type: ignore with_symbols[sym.value] = new_var else: with_symbols[sym.value] = ret # type: ignore @@ -471,13 +454,12 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): _append_jmp(ctx, IRLabel(ir.args[0].value)) elif ir.value == "jump": arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - inst = IRInstruction("jmp", [arg_1]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", arg_1) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - new_var = ctx.append_instruction("store", [arg_1]) # type: ignore + new_var = ctx.get_basic_block().append_instruction("store", arg_1) # type: ignore symbols[sym.value] = new_var elif ir.value == "calldatacopy": @@ -491,16 +473,15 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if isinstance(arg_0, IRLiteral) else None ) + bb = ctx.get_basic_block() if var is not None: if allocated_variables.get(var.name, None) is None: - new_v = ctx.append_instruction( - "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] # type: ignore - ) + new_v = bb.append_instruction("alloca", var.size, var.pos) # type: ignore allocated_variables[var.name] = new_v # type: ignore - ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore symbols[f"&{var.pos}"] = new_v # type: ignore else: - ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore return new_v elif ir.value == "codecopy": @@ -508,7 +489,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - ctx.append_instruction("codecopy", [size, arg_1, arg_0], False) # type: ignore + ctx.get_basic_block().append_instruction("codecopy", size, arg_1, arg_0) # type: ignore elif ir.value == "symbol": return IRLabel(ir.args[0].value, True) elif ir.value == "data": @@ -526,13 +507,12 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif ir.value == "assert": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) current_bb = ctx.get_basic_block() - inst = IRInstruction("assert", [arg_0]) # type: ignore - current_bb.append_instruction(inst) + current_bb.append_instruction("assert", arg_0) elif ir.value == "label": label = IRLabel(ir.args[0].value, True) - if not ctx.get_basic_block().is_terminated: - inst = IRInstruction("jmp", [label]) - ctx.get_basic_block().append_instruction(inst) + bb = ctx.get_basic_block() + if not bb.is_terminated: + bb.append_instruction("jmp", label) bb = IRBasicBlock(label, ctx) ctx.append_basic_block(bb) _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) @@ -542,14 +522,13 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if func_t.is_external: # Hardcoded contructor special case + bb = ctx.get_basic_block() if func_t.name == "__init__": label = IRLabel(ir.args[0].value, True) - inst = IRInstruction("jmp", [label]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("jmp", label) return None if func_t.return_type is None: - inst = IRInstruction("stop", []) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("stop") return None else: last_ir = None @@ -569,6 +548,8 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ctx, ret_var, symbols, variables, allocated_variables ) + bb = ctx.get_basic_block() + var = ( _get_variable_from_address(variables, int(ret_ir.value)) if isinstance(ret_ir, IRLiteral) @@ -582,101 +563,96 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if var.size and int(var.size) > 32: offset = int(ret_ir.value) - var.pos # type: ignore if offset > 0: - ptr_var = ctx.append_instruction( - "add", [IRLiteral(var.pos), IRLiteral(offset)] - ) + ptr_var = bb.append_instruction("add", var.pos, offset) else: ptr_var = allocated_var - inst = IRInstruction("return", [last_ir, ptr_var]) + bb.append_instruction("return", last_ir, ptr_var) else: - inst = _get_return_for_stack_operand(ctx, symbols, new_var, last_ir) + _append_return_for_stack_operand(ctx, symbols, new_var, last_ir) else: if isinstance(ret_ir, IRLiteral): sym = symbols.get(f"&{ret_ir.value}", None) if sym is None: - inst = IRInstruction("return", [last_ir, ret_ir]) + bb.append_instruction("return", last_ir, ret_ir) else: if func_t.return_type.memory_bytes_required > 32: - new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) - ctx.append_instruction("mstore", [sym, new_var], False) - inst = IRInstruction("return", [last_ir, new_var]) + new_var = bb.append_instruction("alloca", 32, ret_ir) + bb.append_instruction("mstore", sym, new_var) + bb.append_instruction("return", last_ir, new_var) else: - inst = IRInstruction("return", [last_ir, ret_ir]) + bb.append_instruction("return", last_ir, ret_ir) else: if last_ir and int(last_ir.value) > 32: - inst = IRInstruction("return", [last_ir, ret_ir]) + bb.append_instruction("return", last_ir, ret_ir) else: - ret_buf = IRLiteral(128) # TODO: need allocator - new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_buf]) - ctx.append_instruction("mstore", [ret_ir, new_var], False) - inst = IRInstruction("return", [last_ir, new_var]) + ret_buf = 128 # TODO: need allocator + new_var = bb.append_instruction("alloca", 32, ret_buf) + bb.append_instruction("mstore", ret_ir, new_var) + bb.append_instruction("return", last_ir, new_var) - ctx.get_basic_block().append_instruction(inst) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + bb = ctx.get_basic_block() if func_t.is_internal: assert ir.args[1].value == "return_pc", "return_pc not found" if func_t.return_type is None: - inst = IRInstruction("ret", [symbols["return_pc"]]) + bb.append_instruction("ret", symbols["return_pc"]) else: if func_t.return_type.memory_bytes_required > 32: - inst = IRInstruction("ret", [symbols["return_buffer"], symbols["return_pc"]]) + bb.append_instruction("ret", symbols["return_buffer"], symbols["return_pc"]) else: - ret_by_value = ctx.append_instruction("mload", [symbols["return_buffer"]]) - inst = IRInstruction("ret", [ret_by_value, symbols["return_pc"]]) - - ctx.get_basic_block().append_instruction(inst) + ret_by_value = bb.append_instruction("mload", symbols["return_buffer"]) + bb.append_instruction("ret", ret_by_value, symbols["return_pc"]) elif ir.value == "revert": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - inst = IRInstruction("revert", [arg_1, arg_0]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("revert", arg_1, arg_0) elif ir.value == "dload": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - src = ctx.append_instruction("add", [arg_0, IRLabel("code_end")]) + bb = ctx.get_basic_block() + src = bb.append_instruction("add", arg_0, IRLabel("code_end")) + + bb.append_instruction("dloadbytes", 32, src, MemoryPositions.FREE_VAR_SPACE) + return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) - ctx.append_instruction( - "dloadbytes", [IRLiteral(32), src, IRLiteral(MemoryPositions.FREE_VAR_SPACE)], False - ) - return ctx.append_instruction("mload", [IRLiteral(MemoryPositions.FREE_VAR_SPACE)]) elif ir.value == "dloadbytes": dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) src_offset = _convert_ir_basicblock( ctx, ir.args[1], symbols, variables, allocated_variables ) len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - - src = ctx.append_instruction("add", [src_offset, IRLabel("code_end")]) - - inst = IRInstruction("dloadbytes", [len_, src, dst]) - ctx.get_basic_block().append_instruction(inst) + bb = ctx.get_basic_block() + src = bb.append_instruction("add", src_offset, IRLabel("code_end")) + bb.append_instruction("dloadbytes", len_, src, dst) return None + elif ir.value == "mload": sym_ir = ir.args[0] var = ( _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None ) + bb = ctx.get_basic_block() if var is not None: if var.size and var.size > 32: if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = ctx.append_instruction( - "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + allocated_variables[var.name] = bb.append_instruction( + "alloca", var.size, var.pos ) offset = int(sym_ir.value) - var.pos if offset > 0: - ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + ptr_var = bb.append_instruction("add", var.pos, offset) else: ptr_var = allocated_variables[var.name] - return ctx.append_instruction("mload", [ptr_var]) + return bb.append_instruction("mload", ptr_var) else: if sym_ir.is_literal: sym = symbols.get(f"&{sym_ir.value}", None) if sym is None: - new_var = ctx.append_instruction("store", [sym_ir]) + new_var = bb.append_instruction("store", sym_ir) symbols[f"&{sym_ir.value}"] = new_var if allocated_variables.get(var.name, None) is None: allocated_variables[var.name] = new_var @@ -691,9 +667,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if sym_ir.is_literal: new_var = symbols.get(f"&{sym_ir.value}", None) if new_var is not None: - return ctx.append_instruction("mload", [new_var]) + return bb.append_instruction("mload", new_var) else: - return ctx.append_instruction("mload", [IRLiteral(sym_ir.value)]) + return bb.append_instruction("mload", sym_ir.value) else: new_var = _convert_ir_basicblock( ctx, sym_ir, symbols, variables, allocated_variables @@ -706,12 +682,14 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): # if sym_ir.is_self_call: return new_var - return ctx.append_instruction("mload", [new_var]) + return bb.append_instruction("mload", new_var) elif ir.value == "mstore": sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + bb = ctx.get_basic_block() + var = None if isinstance(sym_ir, IRLiteral): var = _get_variable_from_address(variables, int(sym_ir.value)) @@ -719,41 +697,38 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if var is not None and var.size is not None: if var.size and var.size > 32: if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = ctx.append_instruction( - "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + allocated_variables[var.name] = bb.append_instruction( + "alloca", var.size, var.pos ) offset = int(sym_ir.value) - var.pos if offset > 0: - ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + ptr_var = bb.append_instruction("add", var.pos, offset) else: ptr_var = allocated_variables[var.name] - return ctx.append_instruction("mstore", [arg_1, ptr_var], False) + bb.append_instruction("mstore", arg_1, ptr_var) else: if isinstance(sym_ir, IRLiteral): - new_var = ctx.append_instruction("store", [arg_1]) + new_var = bb.append_instruction("store", arg_1) symbols[f"&{sym_ir.value}"] = new_var # if allocated_variables.get(var.name, None) is None: allocated_variables[var.name] = new_var return new_var else: if not isinstance(sym_ir, IRLiteral): - inst = IRInstruction("mstore", [arg_1, sym_ir]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("mstore", arg_1, sym_ir) return None sym = symbols.get(f"&{sym_ir.value}", None) if sym is None: - inst = IRInstruction("mstore", [arg_1, sym_ir]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("mstore", arg_1, sym_ir) if arg_1 and not isinstance(sym_ir, IRLiteral): symbols[f"&{sym_ir.value}"] = arg_1 return None if isinstance(sym_ir, IRLiteral): - inst = IRInstruction("mstore", [arg_1, sym]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("mstore", arg_1, sym) return None else: symbols[sym_ir.value] = arg_1 @@ -761,12 +736,11 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif ir.value in ["sload", "iload"]: arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - return ctx.append_instruction(ir.value, [arg_0]) + return ctx.get_basic_block().append_instruction(ir.value, arg_0) elif ir.value in ["sstore", "istore"]: arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - inst = IRInstruction(ir.value, [arg_1, arg_0]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) elif ir.value == "unique_symbol": sym = ir.args[0] new_var = ctx.get_next_variable() @@ -803,28 +777,19 @@ def emit_body_block(): increment_block = IRBasicBlock(ctx.get_next_label(), ctx) exit_block = IRBasicBlock(ctx.get_next_label(), ctx) - counter_var = ctx.get_next_variable() counter_inc_var = ctx.get_next_variable() - ret = ctx.get_next_variable() - inst = IRInstruction("store", [start], counter_var) - ctx.get_basic_block().append_instruction(inst) + counter_var = ctx.get_basic_block().append_instruction("store", start) symbols[sym.value] = counter_var - inst = IRInstruction("jmp", [cond_block.label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", cond_block.label) - symbols[sym.value] = ret - cond_block.append_instruction( - IRInstruction( - "phi", [entry_block.label, counter_var, increment_block.label, counter_inc_var], ret - ) + ret = cond_block.append_instruction( + "phi", entry_block.label, counter_var, increment_block.label, counter_inc_var ) + symbols[sym.value] = ret - xor_ret = ctx.get_next_variable() - cont_ret = ctx.get_next_variable() - inst = IRInstruction("xor", [ret, end], xor_ret) - cond_block.append_instruction(inst) - cond_block.append_instruction(IRInstruction("iszero", [xor_ret], cont_ret)) + xor_ret = cond_block.append_instruction("xor", ret, end) + cont_ret = cond_block.append_instruction("iszero", xor_ret) ctx.append_basic_block(cond_block) # Do a dry run to get the symbols needing phi nodes @@ -851,56 +816,55 @@ def emit_body_block(): body_end = ctx.get_basic_block() if not body_end.is_terminated: - body_end.append_instruction(IRInstruction("jmp", [jump_up_block.label])) + body_end.append_instruction("jmp", jump_up_block.label) - jump_cond = IRInstruction("jmp", [increment_block.label]) - jump_up_block.append_instruction(jump_cond) + jump_up_block.append_instruction("jmp", increment_block.label) ctx.append_basic_block(jump_up_block) - increment_block.append_instruction( - IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var) - ) - increment_block.append_instruction(IRInstruction("jmp", [cond_block.label])) + increment_block.append_instruction(IRInstruction("add", ret, 1)) + increment_block.insert_instruction[-1].output = counter_inc_var + + increment_block.append_instruction("jmp", cond_block.label) ctx.append_basic_block(increment_block) ctx.append_basic_block(exit_block) - inst = IRInstruction("jnz", [cont_ret, exit_block.label, body_block.label]) - cond_block.append_instruction(inst) + cond_block.append_instruction("jnz", cont_ret, exit_block.label, body_block.label) elif ir.value == "break": assert _break_target is not None, "Break with no break target" - inst = IRInstruction("jmp", [_break_target.label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", _break_target.label) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) elif ir.value == "continue": assert _continue_target is not None, "Continue with no contrinue target" - inst = IRInstruction("jmp", [_continue_target.label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", _continue_target.label) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) elif ir.value == "gas": - return ctx.append_instruction("gas", []) + return ctx.get_basic_block().append_instruction("gas") elif ir.value == "returndatasize": - return ctx.append_instruction("returndatasize", []) + return ctx.get_basic_block().append_instruction("returndatasize") elif ir.value == "returndatacopy": assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - new_var = ctx.append_instruction("returndatacopy", [arg_1, size]) + new_var = ctx.get_basic_block().append_instruction("returndatacopy", arg_1, size) symbols[f"&{arg_0.value}"] = new_var return new_var elif ir.value == "selfdestruct": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - ctx.append_instruction("selfdestruct", [arg_0], False) + ctx.get_basic_block().append_instruction("selfdestruct", arg_0) elif isinstance(ir.value, str) and ir.value.startswith("log"): - args = [ - _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) - for arg in ir.args - ] - inst = IRInstruction(ir.value, reversed(args)) - ctx.get_basic_block().append_instruction(inst) + args = reversed( + [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + for arg in ir.args + ] + ) + topic_count = int(ir.value[3:]) + assert topic_count >= 0 and topic_count <= 4, "invalid topic count" + ctx.get_basic_block().append_instruction("log", topic_count, *args) elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) elif isinstance(ir.value, str) and ir.value in symbols: @@ -927,8 +891,7 @@ def _convert_ir_opcode( inst_args.append( _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) ) - instruction = IRInstruction(opcode, inst_args) # type: ignore - ctx.get_basic_block().append_instruction(instruction) + ctx.get_basic_block().append_instruction(opcode, *inst_args) def _data_ofst_of(sym, ofst, height_): diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 9ee1012f91..90dd60e881 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -1,5 +1,5 @@ from vyper.exceptions import CompilerPanic -from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -61,7 +61,7 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB source = in_bb.label.value target = bb.label.value split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) - split_bb.append_instruction(IRInstruction("jmp", [bb.label])) + split_bb.append_instruction("jmp", bb.label) self.ctx.append_basic_block(split_bb) # Rewire the CFG diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index f6ec45440a..8760e9aa63 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -62,11 +62,6 @@ "lt", "slt", "sgt", - "log0", - "log1", - "log2", - "log3", - "log4", ] ) @@ -274,6 +269,10 @@ def _generate_evm_for_instruction( operands = [] elif opcode == "istore": operands = inst.operands[0:1] + elif opcode == "log": + log_topic_count = inst.operands[0].value + assert log_topic_count in [0, 1, 2, 3, 4], "Invalid topic count" + operands = inst.operands[1:] else: operands = inst.operands @@ -417,6 +416,8 @@ def _generate_evm_for_instruction( elif opcode == "istore": loc = inst.operands[1].value assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + elif opcode == "log": + assembly.extend([f"LOG{log_topic_count}"]) else: raise Exception(f"Unknown opcode: {opcode}") From 0b1f3e143c4f432c469b61fbe1566cb46cfcfca1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 13 Dec 2023 18:16:34 -0500 Subject: [PATCH 136/201] fix: remove .keyword from Call AST node (#3689) for some reason, there is a slot named "keyword" on the Call AST node, which is never used (and doesn't exist in the python AST!). this commit removes it for hygienic purposes. --- vyper/ast/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 2497928035..69bd1fed53 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1254,7 +1254,7 @@ def _op(self, left, right): class Call(ExprNode): - __slots__ = ("func", "args", "keywords", "keyword") + __slots__ = ("func", "args", "keywords") class keyword(VyperNode): From 919080e0b74c908d986f5cee121a2bf2379cb2dc Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 15 Dec 2023 21:31:31 -0500 Subject: [PATCH 137/201] chore: test all output formats (#3683) right now only certain output formats are tested in the main compiler test harness, namely bytecode, abi, metadata and some natspec outputs. in the past, there have been issues where output formats get broken but don't get detected until release testing or even after release. this commit adds hooks in `get_contract()` and `deploy_blueprint_for()` to generate all output formats, which will help detect broken output formats sooner. --- tests/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 216fb32b0d..22f8544beb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -279,8 +279,8 @@ def _get_contract( settings.optimize = override_opt_level or optimize out = compiler.compile_code( source_code, - # test that metadata and natspecs get generated - output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], + # test that all output formats can get generated + output_formats=list(compiler.OUTPUT_FORMATS.keys()), settings=settings, input_bundle=input_bundle, show_gas_estimates=True, # Enable gas estimates for testing @@ -352,7 +352,7 @@ def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwar settings.optimize = optimize out = compiler.compile_code( source_code, - output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], + output_formats=list(compiler.OUTPUT_FORMATS.keys()), settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) From c6f457a73db40e4b113497883bd330e0dcec28d1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 16 Dec 2023 11:16:14 -0500 Subject: [PATCH 138/201] feat: implement "stateless" modules (#3663) this commit implements support for "stateless" modules in vyper. this is the first major step in implementing vyper's module system. it redesigns the language's import system, allows calling internal functions from imported modules, allows for limited use of types from imported modules, and introduces support for `.vyi` interface files. note that the following features are left for future, follow-up work: - modules with variables (constants, immutables or storage variables) - full support for imported events (in that they do not get exported in the ABI) - a system for exporting imported functions in the external interface of a contract this commit first and foremost changes how imports are handled in vyper. previously, an imported file was assumed to be an interface file. some very limited validation was performed in `InterfaceT.from_ast`, but not fully typechecked, and no code was generated for it. now, when a file is imported, it is checked whether it is 1. a `.vy` file 2. a `.vyi` file 3. a `.json` file the `.json` pathway remains more or less unchanged. the `.vyi` pathway is new, but it is fairly straightforward and is basically a "simple" path through the `.vy` pathway which piggy-backs off the `.vy` analysis to produce an `InterfaceT` object. the `.vy` pathway now does full typechecking and analysis of the imported module. some changes were made to support this: - a new ImportGraph data structure tracks the position in the import graph and detects (and bands) import cycles - InputBundles now implement a `_normalize_path()` method. this method normalizes the path so that source IDs are stable no matter how a file is accessed in the filesystem (i.e., no matter what the search path was at the time `load_file()` was called). - CompilerInput now has a distinction between `resolved_path` and `path` (the original path that was asked for). this allows us to maintain UX considerations (showing unresolved paths etc) while still having a 1:1:1 correspondence between source id, filepath and filesystem. these changes were needed in order to stabilize notions like "which file are we looking at?" no matter the way the file was accessed or how it was imported. this is important so that types imported transitively can resolve as expected no matter how they are imported - for instance, `x.SomeType` and `a.x.SomeType` resolving to the same type. the other changes needed to support code generation and analysis for imported functions were fairly simple, and mostly involved generalizing the analysis/code generation to type-based dispatch instead of AST-based dispatch. other changes to the language and compiler API include: - import restrictions are more relaxed - `import x` is allowed now (previously, `import x as x` was required) - change function labels in IR function labels are changed to disambiguate functions of the same name (but whose parent module are different). this was done by computing a unique function_id for every function and using that function_id when constructing its IR identifier. - add compile_from_file_input which accepts a FileInput instead of a string. this is now the new preferred entry point into the compiler. its usage simplifies several internal APIs which expect to have `source_id` and `path` in addition to the raw source code. - change compile_code api from contract_name= to contract_path= additional changes to internal APIs and passes include: - remove `remove_unused_statements()` the "unused statements" are now important to keep around for imports! in general, it is planned to remove both the AST expansion and constant folding passes as copying around the AST results in both performance and correctness problems - abstract out a common exception rewriting pattern. instead of raising `exception.with_annotation(node)` -- just catch-all in the parent implementation and then don't have to worry about it at the exception site. - rename "type" metadata key on most top-level declarators to more specific names (e.g. "func_type", "getter_type", etc). - remove dead package pytest-rerunfailures use of `--reruns` was removed in c913b2db0881a6 - refactor: move `parse_*` functions, remove vyper.ast.annotation move `parse_*` functions into new module vyper.ast.parse and merge in vyper.ast.annotation - rename the old `GlobalContext` class to `ModuleT` - refactor: move InterfaceT into `vyper/semantics/types/module.py` it makes more sense here since it is closely coupled with `ModuleT`. --- setup.py | 1 - tests/conftest.py | 10 +- .../codegen/test_call_graph_stability.py | 2 +- .../{builtins => }/codegen/test_interfaces.py | 129 ++++-- .../codegen/test_selector_table_stability.py | 2 +- .../codegen/test_stateless_modules.py | 335 ++++++++++++++ tests/functional/grammar/test_grammar.py | 2 +- tests/functional/syntax/test_interfaces.py | 9 +- tests/unit/ast/nodes/test_hex.py | 4 +- .../ast/test_annotate_and_optimize_ast.py | 3 +- tests/unit/ast/test_ast_dict.py | 7 +- tests/unit/ast/test_parser.py | 2 +- .../test_storage_layout.py | 0 .../test_storage_layout_overrides.py | 2 +- .../cli/vyper_compile/test_compile_files.py | 182 ++++---- .../unit/cli/vyper_json/test_compile_json.py | 111 +++-- tests/unit/cli/vyper_json/test_get_inputs.py | 5 +- .../cli/vyper_json/test_output_selection.py | 52 ++- .../vyper_json/test_parse_args_vyperjson.py | 4 +- tests/unit/compiler/asm/test_asm_optimizer.py | 61 ++- tests/unit/compiler/test_input_bundle.py | 141 ++++-- .../semantics/analysis/test_array_index.py | 20 +- .../analysis/test_cyclic_function_calls.py | 28 +- .../unit/semantics/analysis/test_for_loop.py | 32 +- tests/unit/semantics/test_storage_slots.py | 2 +- tox.ini | 2 +- vyper/__init__.py | 2 +- vyper/ast/__init__.py | 3 +- vyper/ast/__init__.pyi | 2 +- vyper/ast/expansion.py | 51 +- vyper/ast/grammar.lark | 9 +- vyper/ast/natspec.py | 2 +- vyper/ast/nodes.py | 21 +- vyper/ast/nodes.pyi | 13 +- vyper/ast/{annotation.py => parse.py} | 128 +++++- vyper/ast/utils.py | 61 +-- vyper/builtins/_utils.py | 9 +- vyper/builtins/functions.py | 4 +- .../interfaces/{ERC165.vy => ERC165.vyi} | 2 +- .../interfaces/{ERC20.vy => ERC20.vyi} | 24 +- .../{ERC20Detailed.vy => ERC20Detailed.vyi} | 6 +- .../interfaces/{ERC4626.vy => ERC4626.vyi} | 32 +- .../interfaces/{ERC721.vy => ERC721.vyi} | 43 +- vyper/cli/vyper_compile.py | 6 +- vyper/cli/vyper_json.py | 31 +- vyper/codegen/context.py | 8 +- vyper/codegen/expr.py | 49 +- vyper/codegen/function_definitions/common.py | 19 +- vyper/codegen/global_context.py | 32 -- vyper/codegen/module.py | 116 +++-- vyper/codegen/self_call.py | 11 +- vyper/codegen/stmt.py | 57 +-- vyper/compiler/__init__.py | 45 +- vyper/compiler/input_bundle.py | 111 +++-- vyper/compiler/output.py | 22 +- vyper/compiler/phases.py | 102 ++-- vyper/exceptions.py | 20 +- vyper/semantics/analysis/__init__.py | 17 +- vyper/semantics/analysis/base.py | 39 +- vyper/semantics/analysis/common.py | 21 +- vyper/semantics/analysis/data_positions.py | 4 +- vyper/semantics/analysis/import_graph.py | 37 ++ vyper/semantics/analysis/local.py | 45 +- vyper/semantics/analysis/module.py | 411 +++++++++++------ vyper/semantics/analysis/utils.py | 21 +- vyper/semantics/namespace.py | 2 +- vyper/semantics/types/__init__.py | 3 +- vyper/semantics/types/base.py | 12 +- vyper/semantics/types/bytestrings.py | 10 +- vyper/semantics/types/function.py | 435 +++++++++++------- vyper/semantics/types/module.py | 332 +++++++++++++ vyper/semantics/types/subscriptable.py | 36 +- vyper/semantics/types/user.py | 268 ++--------- vyper/semantics/types/utils.py | 52 ++- vyper/utils.py | 9 +- 75 files changed, 2546 insertions(+), 1397 deletions(-) rename tests/functional/{builtins => }/codegen/test_interfaces.py (84%) create mode 100644 tests/functional/codegen/test_stateless_modules.py rename tests/unit/cli/{outputs => storage_layout}/test_storage_layout.py (100%) rename tests/unit/cli/{outputs => storage_layout}/test_storage_layout_overrides.py (98%) rename vyper/ast/{annotation.py => parse.py} (68%) rename vyper/builtins/interfaces/{ERC165.vy => ERC165.vyi} (88%) rename vyper/builtins/interfaces/{ERC20.vy => ERC20.vyi} (68%) rename vyper/builtins/interfaces/{ERC20Detailed.vy => ERC20Detailed.vyi} (93%) rename vyper/builtins/interfaces/{ERC4626.vy => ERC4626.vyi} (90%) rename vyper/builtins/interfaces/{ERC721.vy => ERC721.vyi} (61%) delete mode 100644 vyper/codegen/global_context.py create mode 100644 vyper/semantics/analysis/import_graph.py create mode 100644 vyper/semantics/types/module.py diff --git a/setup.py b/setup.py index 40efb436c5..431c50b74b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ "pytest-instafail>=0.4,<1.0", "pytest-xdist>=2.5,<3.0", "pytest-split>=0.7.0,<1.0", - "pytest-rerunfailures>=10.2,<11", "eth-tester[py-evm]>=0.9.0b1,<0.10", "py-evm>=0.7.0a1,<0.8", "web3==6.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index 22f8544beb..925a025a4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.ir import compile_ir, optimizer @@ -103,6 +103,12 @@ def fn(sources_dict): return fn +# for tests which just need an input bundle, doesn't matter what it is +@pytest.fixture +def dummy_input_bundle(): + return InputBundle([]) + + # TODO: remove me, this is just string.encode("utf-8").ljust() # only used in test_logging.py. @pytest.fixture @@ -255,9 +261,11 @@ def ir_compiler(ir, *args, **kwargs): ir = IRnode.from_list(ir) if optimize != OptimizationLevel.NONE: ir = optimizer.optimize(ir) + bytecode, _ = compile_ir.assembly_to_evm( compile_ir.compile_to_assembly(ir, optimize=optimize) ) + abi = kwargs.get("abi") or [] c = w3.eth.contract(abi=abi, bytecode=bytecode) deploy_transaction = c.constructor() diff --git a/tests/functional/codegen/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py index 4c85c330f3..2d8ad59791 100644 --- a/tests/functional/codegen/test_call_graph_stability.py +++ b/tests/functional/codegen/test_call_graph_stability.py @@ -55,7 +55,7 @@ def foo(): # check the .called_functions data structure on foo() directly foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] - foo_t = foo._metadata["type"] + foo_t = foo._metadata["func_type"] assert [f.name for f in foo_t.called_functions] == func_names # now for sanity, ensure the order that the function definitions appear diff --git a/tests/functional/builtins/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py similarity index 84% rename from tests/functional/builtins/codegen/test_interfaces.py rename to tests/functional/codegen/test_interfaces.py index 8cb0124f29..3544f4a965 100644 --- a/tests/functional/builtins/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -6,9 +6,9 @@ from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, + DuplicateImport, InterfaceViolation, NamespaceCollision, - StructureException, ) @@ -31,7 +31,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): out = compile_code(code, output_formats=["interface"]) out = out["interface"] - code_pass = "\n".join(code.split("\n")[:-2] + [" pass"]) # replace with a pass statement. + code_pass = "\n".join(code.split("\n")[:-2] + [" ..."]) # replace with a pass statement. assert code_pass.strip() == out.strip() @@ -60,7 +60,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): view def test(_owner: address): nonpayable """ - out = compile_code(code, contract_name="One.vy", output_formats=["external_interface"])[ + out = compile_code(code, contract_path="One.vy", output_formats=["external_interface"])[ "external_interface" ] @@ -85,14 +85,14 @@ def test_external_interface_parsing(make_input_bundle, assert_compile_failed): interface_code = """ @external def foo() -> uint256: - pass + ... @external def bar() -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) code = """ import a as FooBarInterface @@ -121,9 +121,8 @@ def foo() -> uint256: """ - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) + with pytest.raises(InterfaceViolation): + compile_code(not_implemented_code, input_bundle=input_bundle) def test_missing_event(make_input_bundle, assert_compile_failed): @@ -132,7 +131,7 @@ def test_missing_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -156,7 +155,7 @@ def test_malformed_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -183,7 +182,7 @@ def test_malformed_events_indexed(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -211,7 +210,7 @@ def test_malformed_events_indexed2(make_input_bundle, assert_compile_failed): a: indexed(uint256) """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -234,13 +233,13 @@ def bar() -> uint256: VALID_IMPORT_CODE = [ # import statement, import path without suffix - ("import a as Foo", "a.vy"), - ("import b.a as Foo", "b/a.vy"), - ("import Foo as Foo", "Foo.vy"), - ("from a import Foo", "a/Foo.vy"), - ("from b.a import Foo", "b/a/Foo.vy"), - ("from .a import Foo", "./a/Foo.vy"), - ("from ..a import Foo", "../a/Foo.vy"), + ("import a as Foo", "a.vyi"), + ("import b.a as Foo", "b/a.vyi"), + ("import Foo as Foo", "Foo.vyi"), + ("from a import Foo", "a/Foo.vyi"), + ("from b.a import Foo", "b/a/Foo.vyi"), + ("from .a import Foo", "./a/Foo.vyi"), + ("from ..a import Foo", "../a/Foo.vyi"), ] @@ -252,11 +251,12 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): BAD_IMPORT_CODE = [ - ("import a", StructureException), # must alias absolute imports - ("import a as A\nimport a as A", NamespaceCollision), + ("import a as A\nimport a as A", DuplicateImport), + ("import a as A\nimport a as a", DuplicateImport), + ("from . import a\nimport a as a", DuplicateImport), + ("import a as a\nfrom . import a", DuplicateImport), ("from b import a\nfrom . import a", NamespaceCollision), - ("from . import a\nimport a as a", NamespaceCollision), - ("import a as a\nfrom . import a", NamespaceCollision), + ("import a\nimport c as a", NamespaceCollision), ] @@ -264,34 +264,50 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): def test_extract_file_interface_imports_raises( code, exception_type, assert_compile_failed, make_input_bundle ): - input_bundle = make_input_bundle({"a.vy": "", "b/a.vy": ""}) # dummy - assert_compile_failed(lambda: compile_code(code, input_bundle=input_bundle), exception_type) + input_bundle = make_input_bundle({"a.vyi": "", "b/a.vyi": "", "c.vyi": ""}) + with pytest.raises(exception_type): + compile_code(code, input_bundle=input_bundle) def test_external_call_to_interface(w3, get_contract, make_input_bundle): + token_interface = """ +@view +@external +def balanceOf(addr: address) -> uint256: + ... + +@external +def transfer(to: address, amount: uint256): + ... + """ + token_code = """ +import itoken as IToken + +implements: IToken + balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256): - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256): + self.balanceOf[to] += amount """ - input_bundle = make_input_bundle({"one.vy": token_code}) + input_bundle = make_input_bundle({"token.vy": token_code, "itoken.vyi": token_interface}) code = """ -import one as TokenCode +import itoken as IToken interface EPI: def test() -> uint256: view -token_address: TokenCode +token_address: IToken @external def __init__(_token_address: address): - self.token_address = TokenCode(_token_address) + self.token_address = IToken(_token_address) @external @@ -299,14 +315,15 @@ def test(): self.token_address.transfer(msg.sender, 1000) """ - erc20 = get_contract(token_code) - test_c = get_contract(code, *[erc20.address], input_bundle=input_bundle) + token = get_contract(token_code, input_bundle=input_bundle) + + test_c = get_contract(code, *[token.address], input_bundle=input_bundle) sender = w3.eth.accounts[0] - assert erc20.balanceOf(sender) == 0 + assert token.balanceOf(sender) == 0 test_c.test(transact={}) - assert erc20.balanceOf(sender) == 1000 + assert token.balanceOf(sender) == 1000 @pytest.mark.parametrize( @@ -320,26 +337,36 @@ def test(): ], ) def test_external_call_to_interface_kwarg(get_contract, kwarg, typ, expected, make_input_bundle): - code_a = f""" + interface_code = f""" +@external +@view +def foo(_max: {typ} = {kwarg}) -> {typ}: + ... + """ + code1 = f""" +import one as IContract + +implements: IContract + @external @view def foo(_max: {typ} = {kwarg}) -> {typ}: return _max """ - input_bundle = make_input_bundle({"one.vy": code_a}) + input_bundle = make_input_bundle({"one.vyi": interface_code}) - code_b = f""" -import one as ContractA + code2 = f""" +import one as IContract @external @view def bar(a_address: address) -> {typ}: - return ContractA(a_address).foo() + return IContract(a_address).foo() """ - contract_a = get_contract(code_a) - contract_b = get_contract(code_b, *[contract_a.address], input_bundle=input_bundle) + contract_a = get_contract(code1, input_bundle=input_bundle) + contract_b = get_contract(code2, *[contract_a.address], input_bundle=input_bundle) assert contract_b.bar(contract_a.address) == expected @@ -349,8 +376,8 @@ def test_external_call_to_builtin_interface(w3, get_contract): balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256) -> bool: - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256) -> bool: + self.balanceOf[to] += amount return True """ @@ -510,14 +537,14 @@ def returns_Bytes3() -> Bytes[3]: """ should_not_compile = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface @external def foo(x: BadJSONInterface) -> Bytes[2]: return slice(x.returns_Bytes3(), 0, 2) """ code = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface foo: BadJSONInterface @@ -578,10 +605,10 @@ def balanceOf(owner: address) -> uint256: @external @view def balanceOf(owner: address) -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"balanceof.vy": interface_code}) + input_bundle = make_input_bundle({"balanceof.vyi": interface_code}) c = get_contract(code, input_bundle=input_bundle) @@ -592,7 +619,7 @@ def test_simple_implements(make_input_bundle): interface_code = """ @external def foo() -> uint256: - pass + ... """ code = """ @@ -605,7 +632,7 @@ def foo() -> uint256: return 1 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) assert compile_code(code, input_bundle=input_bundle) is not None diff --git a/tests/functional/codegen/test_selector_table_stability.py b/tests/functional/codegen/test_selector_table_stability.py index 3302ff5009..27f82416d6 100644 --- a/tests/functional/codegen/test_selector_table_stability.py +++ b/tests/functional/codegen/test_selector_table_stability.py @@ -14,7 +14,7 @@ def test_dense_jumptable_stability(): # test that the selector table data is stable across different runs # (tox should provide different PYTHONHASHSEEDs). - expected_asm = """{ DATA _sym_BUCKET_HEADERS b'\\x0bB' _sym_bucket_0 b'\\n' b'+\\x8d' _sym_bucket_1 b'\\x0c' b'\\x00\\x85' _sym_bucket_2 b'\\x08' } { DATA _sym_bucket_1 b'\\xd8\\xee\\xa1\\xe8' _sym_external_foo6___3639517672 b'\\x05' b'\\xd2\\x9e\\xe0\\xf9' _sym_external_foo0___3533627641 b'\\x05' b'\\x05\\xf1\\xe0_' _sym_external_foo2___99737695 b'\\x05' b'\\x91\\t\\xb4{' _sym_external_foo23___2433332347 b'\\x05' b'np3\\x7f' _sym_external_foo11___1852846975 b'\\x05' b'&\\xf5\\x96\\xf9' _sym_external_foo13___653629177 b'\\x05' b'\\x04ga\\xeb' _sym_external_foo14___73884139 b'\\x05' b'\\x89\\x06\\xad\\xc6' _sym_external_foo17___2298916294 b'\\x05' b'\\xe4%\\xac\\xd1' _sym_external_foo4___3827674321 b'\\x05' b'yj\\x01\\xac' _sym_external_foo7___2036990380 b'\\x05' b'\\xf1\\xe6K\\xe5' _sym_external_foo29___4058401765 b'\\x05' b'\\xd2\\x89X\\xb8' _sym_external_foo3___3532216504 b'\\x05' } { DATA _sym_bucket_2 b'\\x06p\\xffj' _sym_external_foo25___108068714 b'\\x05' b'\\x964\\x99I' _sym_external_foo24___2520029513 b'\\x05' b's\\x81\\xe7\\xc1' _sym_external_foo10___1937893313 b'\\x05' b'\\x85\\xad\\xc11' _sym_external_foo28___2242756913 b'\\x05' b'\\xfa"\\xb1\\xed' _sym_external_foo5___4196577773 b'\\x05' b'A\\xe7[\\x05' _sym_external_foo22___1105681157 b'\\x05' b'\\xd3\\x89U\\xe8' _sym_external_foo1___3548993000 b'\\x05' b'hL\\xf8\\xf3' _sym_external_foo20___1749874931 b'\\x05' } { DATA _sym_bucket_0 b'\\xee\\xd9\\x1d\\xe3' _sym_external_foo9___4007206371 b'\\x05' b'a\\xbc\\x1ch' _sym_external_foo16___1639717992 b'\\x05' b'\\xd3*\\xa7\\x0c' _sym_external_foo21___3542787852 b'\\x05' b'\\x18iG\\xd9' _sym_external_foo19___409552857 b'\\x05' b'\\n\\xf1\\xf9\\x7f' _sym_external_foo18___183630207 b'\\x05' b')\\xda\\xd7`' _sym_external_foo27___702207840 b'\\x05' b'2\\xf6\\xaa\\xda' _sym_external_foo12___855026394 b'\\x05' b'\\xbe\\xb5\\x05\\xf5' _sym_external_foo15___3199534581 b'\\x05' b'\\xfc\\xa7_\\xe6' _sym_external_foo8___4238827494 b'\\x05' b'\\x1b\\x12C8' _sym_external_foo26___454181688 b'\\x05' } }""" # noqa: E501 + expected_asm = """{ DATA _sym_BUCKET_HEADERS b\'\\x0bB\' _sym_bucket_0 b\'\\n\' b\'+\\x8d\' _sym_bucket_1 b\'\\x0c\' b\'\\x00\\x85\' _sym_bucket_2 b\'\\x08\' } { DATA _sym_bucket_1 b\'\\xd8\\xee\\xa1\\xe8\' _sym_external 6 foo6()3639517672 b\'\\x05\' b\'\\xd2\\x9e\\xe0\\xf9\' _sym_external 0 foo0()3533627641 b\'\\x05\' b\'\\x05\\xf1\\xe0_\' _sym_external 2 foo2()99737695 b\'\\x05\' b\'\\x91\\t\\xb4{\' _sym_external 23 foo23()2433332347 b\'\\x05\' b\'np3\\x7f\' _sym_external 11 foo11()1852846975 b\'\\x05\' b\'&\\xf5\\x96\\xf9\' _sym_external 13 foo13()653629177 b\'\\x05\' b\'\\x04ga\\xeb\' _sym_external 14 foo14()73884139 b\'\\x05\' b\'\\x89\\x06\\xad\\xc6\' _sym_external 17 foo17()2298916294 b\'\\x05\' b\'\\xe4%\\xac\\xd1\' _sym_external 4 foo4()3827674321 b\'\\x05\' b\'yj\\x01\\xac\' _sym_external 7 foo7()2036990380 b\'\\x05\' b\'\\xf1\\xe6K\\xe5\' _sym_external 29 foo29()4058401765 b\'\\x05\' b\'\\xd2\\x89X\\xb8\' _sym_external 3 foo3()3532216504 b\'\\x05\' } { DATA _sym_bucket_2 b\'\\x06p\\xffj\' _sym_external 25 foo25()108068714 b\'\\x05\' b\'\\x964\\x99I\' _sym_external 24 foo24()2520029513 b\'\\x05\' b\'s\\x81\\xe7\\xc1\' _sym_external 10 foo10()1937893313 b\'\\x05\' b\'\\x85\\xad\\xc11\' _sym_external 28 foo28()2242756913 b\'\\x05\' b\'\\xfa"\\xb1\\xed\' _sym_external 5 foo5()4196577773 b\'\\x05\' b\'A\\xe7[\\x05\' _sym_external 22 foo22()1105681157 b\'\\x05\' b\'\\xd3\\x89U\\xe8\' _sym_external 1 foo1()3548993000 b\'\\x05\' b\'hL\\xf8\\xf3\' _sym_external 20 foo20()1749874931 b\'\\x05\' } { DATA _sym_bucket_0 b\'\\xee\\xd9\\x1d\\xe3\' _sym_external 9 foo9()4007206371 b\'\\x05\' b\'a\\xbc\\x1ch\' _sym_external 16 foo16()1639717992 b\'\\x05\' b\'\\xd3*\\xa7\\x0c\' _sym_external 21 foo21()3542787852 b\'\\x05\' b\'\\x18iG\\xd9\' _sym_external 19 foo19()409552857 b\'\\x05\' b\'\\n\\xf1\\xf9\\x7f\' _sym_external 18 foo18()183630207 b\'\\x05\' b\')\\xda\\xd7`\' _sym_external 27 foo27()702207840 b\'\\x05\' b\'2\\xf6\\xaa\\xda\' _sym_external 12 foo12()855026394 b\'\\x05\' b\'\\xbe\\xb5\\x05\\xf5\' _sym_external 15 foo15()3199534581 b\'\\x05\' b\'\\xfc\\xa7_\\xe6\' _sym_external 8 foo8()4238827494 b\'\\x05\' b\'\\x1b\\x12C8\' _sym_external 26 foo26()454181688 b\'\\x05\' } }""" # noqa: E501 assert expected_asm in output["asm"] diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/test_stateless_modules.py new file mode 100644 index 0000000000..8e634e5868 --- /dev/null +++ b/tests/functional/codegen/test_stateless_modules.py @@ -0,0 +1,335 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +from vyper import compiler +from vyper.exceptions import ( + CallViolation, + DuplicateImport, + ImportCycle, + StructureException, + TypeMismatch, +) + +# test modules which have no variables - "libraries" + + +def test_simple_library(get_contract, make_input_bundle, w3): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + main = """ +import library + +@external +def bar() -> uint256: + return library.foo() - 1 + """ + input_bundle = make_input_bundle({"library.vy": library_source}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == w3.eth.block_number + + +# is this the best place for this? +def test_import_cycle(make_input_bundle): + code_a = "import b\n" + code_b = "import a\n" + + input_bundle = make_input_bundle({"a.vy": code_a, "b.vy": code_b}) + + with pytest.raises(ImportCycle): + compiler.compile_code(code_a, input_bundle=input_bundle) + + +# test we can have a function in the library with the same name as +# in the main contract +def test_library_function_same_name(get_contract, make_input_bundle): + library = """ +@internal +def foo() -> uint256: + return 10 + """ + + main = """ +import library + +@internal +def foo() -> uint256: + return 100 + +@external +def self_foo() -> uint256: + return self.foo() + +@external +def library_foo() -> uint256: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.self_foo() == 100 + assert c.library_foo() == 10 + + +def test_transitive_import(get_contract, make_input_bundle): + a = """ +@internal +def foo() -> uint256: + return 1 + """ + b = """ +import a + +@internal +def bar() -> uint256: + return a.foo() + 1 + """ + c = """ +import b + +@external +def baz() -> uint256: + return b.bar() + 1 + """ + # more complicated call graph, with `a` imported twice. + d = """ +import b +import a + +@external +def qux() -> uint256: + s: uint256 = a.foo() + return s + b.bar() + 1 + """ + input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c, "d.vy": d}) + + contract = get_contract(c, input_bundle=input_bundle) + assert contract.baz() == 3 + contract = get_contract(d, input_bundle=input_bundle) + assert contract.qux() == 4 + + +def test_cannot_call_library_external_functions(make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(CallViolation): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_external_functions_not_in_abi(get_contract, make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + pass + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + assert not hasattr(c, "foo") + + +def test_library_structs(get_contract, make_input_bundle): + library_source = """ +struct SomeStruct: + x: uint256 + +@internal +def foo() -> SomeStruct: + return SomeStruct({x: 1}) + """ + contract_source = """ +import library + +@external +def bar(s: library.SomeStruct): + pass + +@external +def baz() -> library.SomeStruct: + return library.SomeStruct({x: 2}) + +@external +def qux() -> library.SomeStruct: + return library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + + assert c.bar((1,)) == [] + + assert c.baz() == (2,) + assert c.qux() == (1,) + + +# test calls to library functions in statement position +def test_library_statement_calls(get_contract, make_input_bundle, assert_tx_failed): + library_source = """ +from vyper.interfaces import ERC20 +@internal +def check_adds_to_ten(x: uint256, y: uint256): + assert x + y == 10 + """ + contract_source = """ +import library + +counter: public(uint256) + +@external +def foo(x: uint256): + library.check_adds_to_ten(3, x) + self.counter = x + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + + c = get_contract(contract_source, input_bundle=input_bundle) + + c.foo(7, transact={}) + + assert c.counter() == 7 + + assert_tx_failed(lambda: c.foo(8)) + + +def test_library_is_typechecked(make_input_bundle): + library_source = """ +@internal +def foo(): + asdlkfjasdflkajsdf + """ + contract_source = """ +import library + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(StructureException): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_is_typechecked2(make_input_bundle): + # check that we typecheck against imported function signatures + library_source = """ +@internal +def foo() -> uint256: + return 1 + """ + contract_source = """ +import library + +@external +def foo() -> bytes32: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(TypeMismatch): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_reject_duplicate_imports(make_input_bundle): + library_source = """ + """ + + contract_source = """ +import library +import library as library2 + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(DuplicateImport): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_nested_module_access(get_contract, make_input_bundle): + lib1 = """ +import lib2 + +@internal +def lib2_foo() -> uint256: + return lib2.foo() + """ + lib2 = """ +@internal +def foo() -> uint256: + return 1337 + """ + + main = """ +import lib1 +import lib2 + +@external +def lib1_foo() -> uint256: + return lib1.lib2_foo() + +@external +def lib2_foo() -> uint256: + return lib1.lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.lib1_foo() == c.lib2_foo() == 1337 + + +_int_127 = st.integers(min_value=0, max_value=127) +_bytes_128 = st.binary(min_size=0, max_size=128) + + +def test_slice_builtin(get_contract, make_input_bundle): + lib = """ +@internal +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + + main = """ +import lib +@external +def lib_slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return lib.slice_input(x, start, length) + +@external +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + # use an inner test so that we can cache the result of get_contract() + @given(start=_int_127, length=_int_127, bytesdata=_bytes_128) + @settings(max_examples=100) + def _test(bytesdata, start, length): + # surjectively map start into allowable range + if start > len(bytesdata): + start = start % (len(bytesdata) or 1) + # surjectively map length into allowable range + if length > (len(bytesdata) - start): + length = length % ((len(bytesdata) - start) or 1) + main_result = c.slice_input(bytesdata, start, length) + library_result = c.lib_slice_input(bytesdata, start, length) + assert main_result == library_result == bytesdata[start : start + length] + + _test() diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index aa0286cfa5..7dd8c35929 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -92,7 +92,7 @@ def from_grammar() -> st.SearchStrategy[str]: # Avoid examples with *only* single or double quote docstrings -# because they trigger a trivial compiler bug +# because they trigger a trivial parser bug SINGLE_QUOTE_DOCSTRING = re.compile(r"^'''.*'''$") DOUBLE_QUOTE_DOCSTRING = re.compile(r'^""".*"""$') diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 9100389dbd..a672ed7b88 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -376,17 +376,12 @@ def test_interfaces_success(good_code): def test_imports_and_implements_within_interface(make_input_bundle): interface_code = """ -from vyper.interfaces import ERC20 -import foo.bar as Baz - -implements: Baz - @external def foobar(): - pass + ... """ - input_bundle = make_input_bundle({"foo.vy": interface_code}) + input_bundle = make_input_bundle({"foo.vyi": interface_code}) code = """ import foo as Foo diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 47483c493c..d413340083 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -37,9 +37,9 @@ def foo(): @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code): +def test_invalid_checksum(code, dummy_input_bundle): vyper_module = vy_ast.parse_to_ast(code) with pytest.raises(InvalidLiteral): vy_ast.validation.validate_literal_nodes(vyper_module) - semantics.validate_semantics(vyper_module, {}) + semantics.validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 68a07178bb..16ce6fe631 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -1,7 +1,6 @@ import ast as python_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse +from vyper.ast.parse import annotate_python_ast, pre_parse class AssertionVisitor(python_ast.NodeVisitor): diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 1f60c9ac8b..dc49f72561 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,7 +1,8 @@ import json from vyper import compiler -from vyper.ast.utils import ast_to_dict, dict_to_ast, parse_to_ast +from vyper.ast.parse import parse_to_ast +from vyper.ast.utils import ast_to_dict, dict_to_ast def get_node_ids(ast_struct, ids=None): @@ -40,7 +41,7 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", @@ -89,7 +90,7 @@ def foo() -> uint256: view def foo() -> uint256: return 1 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][1] == { "col_offset": 0, "annotation": { diff --git a/tests/unit/ast/test_parser.py b/tests/unit/ast/test_parser.py index c47bf40bfa..e0bfcbc2ef 100644 --- a/tests/unit/ast/test_parser.py +++ b/tests/unit/ast/test_parser.py @@ -1,4 +1,4 @@ -from vyper.ast.utils import parse_to_ast +from vyper.ast.parse import parse_to_ast def test_ast_equal(): diff --git a/tests/unit/cli/outputs/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py similarity index 100% rename from tests/unit/cli/outputs/test_storage_layout.py rename to tests/unit/cli/storage_layout/test_storage_layout.py diff --git a/tests/unit/cli/outputs/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py similarity index 98% rename from tests/unit/cli/outputs/test_storage_layout_overrides.py rename to tests/unit/cli/storage_layout/test_storage_layout_overrides.py index 94e0faeb37..f4c11b7ae6 100644 --- a/tests/unit/cli/outputs/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -103,7 +103,7 @@ def test_overflow(): storage_layout_override = {"x": {"slot": 2**256 - 1, "type": "uint256[2]"}} with pytest.raises( - StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}\n" + StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}" ): compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a16efa777..f6e3a51a4b 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -30,93 +30,100 @@ def test_invalid_root_path(): compile_files([], [], root_folder="path/that/does/not/exist") -FOO_CODE = """ -{} - -struct FooStruct: - foo_: uint256 +CONTRACT_CODE = """ +{import_stmt} @external -def foo() -> FooStruct: - return FooStruct({{foo_: 13}}) +def foo() -> {alias}.FooStruct: + return {alias}.FooStruct({{foo_: 13}}) @external -def bar(a: address) -> FooStruct: - return {}(a).bar() +def bar(a: address) -> {alias}.FooStruct: + return {alias}(a).bar() """ -BAR_CODE = """ +INTERFACE_CODE = """ struct FooStruct: foo_: uint256 + +@external +def foo() -> FooStruct: + ... + @external def bar() -> FooStruct: - return FooStruct({foo_: 13}) + ... """ SAME_FOLDER_IMPORT_STMT = [ - ("import Bar as Bar", "Bar"), - ("import contracts.Bar as Bar", "Bar"), - ("from . import Bar", "Bar"), - ("from contracts import Bar", "Bar"), - ("from ..contracts import Bar", "Bar"), - ("from . import Bar as FooBar", "FooBar"), - ("from contracts import Bar as FooBar", "FooBar"), - ("from ..contracts import Bar as FooBar", "FooBar"), + ("import IFoo as IFoo", "IFoo"), + ("import contracts.IFoo as IFoo", "IFoo"), + ("from . import IFoo", "IFoo"), + ("from contracts import IFoo", "IFoo"), + ("from ..contracts import IFoo", "IFoo"), + ("from . import IFoo as FooBar", "FooBar"), + ("from contracts import IFoo as FooBar", "FooBar"), + ("from ..contracts import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) def test_import_same_folder(import_stmt, alias, tmp_path, make_file): foo = "contracts/foo.vy" - make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("contracts/Bar.vy", BAR_CODE) + make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) SUBFOLDER_IMPORT_STMT = [ - ("import other.Bar as Bar", "Bar"), - ("import contracts.other.Bar as Bar", "Bar"), - ("from other import Bar", "Bar"), - ("from contracts.other import Bar", "Bar"), - ("from .other import Bar", "Bar"), - ("from ..contracts.other import Bar", "Bar"), - ("from other import Bar as FooBar", "FooBar"), - ("from contracts.other import Bar as FooBar", "FooBar"), - ("from .other import Bar as FooBar", "FooBar"), - ("from ..contracts.other import Bar as FooBar", "FooBar"), + ("import other.IFoo as IFoo", "IFoo"), + ("import contracts.other.IFoo as IFoo", "IFoo"), + ("from other import IFoo", "IFoo"), + ("from contracts.other import IFoo", "IFoo"), + ("from .other import IFoo", "IFoo"), + ("from ..contracts.other import IFoo", "IFoo"), + ("from other import IFoo as FooBar", "FooBar"), + ("from contracts.other import IFoo as FooBar", "FooBar"), + ("from .other import IFoo as FooBar", "FooBar"), + ("from ..contracts.other import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) def test_import_subfolder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", (FOO_CODE.format(import_stmt, alias))) - make_file("contracts/other/Bar.vy", BAR_CODE) + foo = make_file( + "contracts/foo.vy", (CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + ) + make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) OTHER_FOLDER_IMPORT_STMT = [ - ("import interfaces.Bar as Bar", "Bar"), - ("from interfaces import Bar", "Bar"), - ("from ..interfaces import Bar", "Bar"), - ("from interfaces import Bar as FooBar", "FooBar"), - ("from ..interfaces import Bar as FooBar", "FooBar"), + ("import interfaces.IFoo as IFoo", "IFoo"), + ("from interfaces import IFoo", "IFoo"), + ("from ..interfaces import IFoo", "IFoo"), + ("from interfaces import IFoo as FooBar", "FooBar"), + ("from ..interfaces import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", OTHER_FOLDER_IMPORT_STMT) def test_import_other_folder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("interfaces/Bar.vy", BAR_CODE) + foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("interfaces/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) def test_import_parent_folder(tmp_path, make_file): - foo = make_file("contracts/baz/foo.vy", FOO_CODE.format("from ... import Bar", "Bar")) - make_file("Bar.vy", BAR_CODE) + foo = make_file( + "contracts/baz/foo.vy", + CONTRACT_CODE.format(import_stmt="from ... import IFoo", alias="IFoo"), + ) + make_file("IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) @@ -125,62 +132,60 @@ def test_import_parent_folder(tmp_path, make_file): META_IMPORT_STMT = [ - "import Meta as Meta", - "import contracts.Meta as Meta", - "from . import Meta", - "from contracts import Meta", + "import ISelf as ISelf", + "import contracts.ISelf as ISelf", + "from . import ISelf", + "from contracts import ISelf", ] @pytest.mark.parametrize("import_stmt", META_IMPORT_STMT) def test_import_self_interface(import_stmt, tmp_path, make_file): - # a contract can access its derived interface by importing itself - code = f""" -{import_stmt} - + interface_code = """ struct FooStruct: foo_: uint256 @external def know_thyself(a: address) -> FooStruct: - return Meta(a).be_known() + ... @external def be_known() -> FooStruct: - return FooStruct({{foo_: 42}}) + ... """ - meta = make_file("contracts/Meta.vy", code) - - assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + code = f""" +{import_stmt} +@external +def know_thyself(a: address) -> ISelf.FooStruct: + return ISelf(a).be_known() -DERIVED_IMPORT_STMT_BAZ = ["import Foo as Foo", "from . import Foo"] +@external +def be_known() -> ISelf.FooStruct: + return ISelf.FooStruct({{foo_: 42}}) + """ + make_file("contracts/ISelf.vyi", interface_code) + meta = make_file("contracts/Self.vy", code) -DERIVED_IMPORT_STMT_FOO = ["import Bar as Bar", "from . import Bar"] + assert compile_files([meta], ["combined_json"], root_folder=tmp_path) -@pytest.mark.parametrize("import_stmt_baz", DERIVED_IMPORT_STMT_BAZ) -@pytest.mark.parametrize("import_stmt_foo", DERIVED_IMPORT_STMT_FOO) -def test_derived_interface_imports(import_stmt_baz, import_stmt_foo, tmp_path, make_file): - # contracts-as-interfaces should be able to contain import statements +# implement IFoo in another contract for fun +@pytest.mark.parametrize("import_stmt_foo,alias", SAME_FOLDER_IMPORT_STMT) +def test_another_interface_implementation(import_stmt_foo, alias, tmp_path, make_file): baz_code = f""" -{import_stmt_baz} - -struct FooStruct: - foo_: uint256 +{import_stmt_foo} @external -def foo(a: address) -> FooStruct: - return Foo(a).foo() +def foo(a: address) -> {alias}.FooStruct: + return {alias}(a).foo() @external -def bar(_foo: address, _bar: address) -> FooStruct: - return Foo(_foo).bar(_bar) +def bar(_foo: address) -> {alias}.FooStruct: + return {alias}(_foo).bar() """ - - make_file("Foo.vy", FOO_CODE.format(import_stmt_foo, "Bar")) - make_file("Bar.vy", BAR_CODE) - baz = make_file("Baz.vy", baz_code) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) + baz = make_file("contracts/Baz.vy", baz_code) assert compile_files([baz], ["combined_json"], root_folder=tmp_path) @@ -207,15 +212,36 @@ def test_local_namespace(make_file, tmp_path): make_file(filename, code) paths.append(filename) - for file_name in ("foo.vy", "bar.vy"): - make_file(file_name, BAR_CODE) + for file_name in ("foo.vyi", "bar.vyi"): + make_file(file_name, INTERFACE_CODE) assert compile_files(paths, ["combined_json"], root_folder=tmp_path) def test_compile_outside_root_path(tmp_path, make_file): # absolute paths relative to "." - foo = make_file("foo.vy", FOO_CODE.format("import bar as Bar", "Bar")) - bar = make_file("bar.vy", BAR_CODE) + make_file("ifoo.vyi", INTERFACE_CODE) + foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) + + assert compile_files([foo], ["combined_json"], root_folder=".") + + +def test_import_library(tmp_path, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + + make_file("lib.vy", library_source) + contract_file = make_file("contract.vy", contract_source) - assert compile_files([foo, bar], ["combined_json"], root_folder=".") + assert compile_files([contract_file], ["combined_json"], root_folder=tmp_path) is not None diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 732762d72b..a50946ba21 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -1,30 +1,55 @@ import json +from pathlib import PurePath import pytest import vyper -from vyper.cli.vyper_json import compile_from_input_dict, compile_json, exc_handler_to_dict -from vyper.compiler import OUTPUT_FORMATS, compile_code +from vyper.cli.vyper_json import ( + compile_from_input_dict, + compile_json, + exc_handler_to_dict, + get_inputs, +) +from vyper.compiler import OUTPUT_FORMATS, compile_code, compile_from_file_input +from vyper.compiler.input_bundle import JSONInputBundle from vyper.exceptions import InvalidType, JSONError, SyntaxException FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar + +import contracts.library as library @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) @external def baz() -> uint256: - return self.balance + return self.balance + library.foo() """ BAR_CODE = """ +import contracts.ibar as IBar + +implements: IBar + @external def bar(a: uint256) -> bool: return True """ +BAR_VYI = """ +@external +def bar(a: uint256) -> bool: + ... +""" + +LIBRARY_CODE = """ +@internal +def foo() -> uint256: + return block.number + 1 +""" + BAD_SYNTAX_CODE = """ def bar()>: """ @@ -52,6 +77,7 @@ def input_json(): "language": "Vyper", "sources": { "contracts/foo.vy": {"content": FOO_CODE}, + "contracts/library.vy": {"content": LIBRARY_CODE}, "contracts/bar.vy": {"content": BAR_CODE}, }, "interfaces": {"contracts/ibar.json": {"abi": BAR_ABI}}, @@ -59,6 +85,14 @@ def input_json(): } +@pytest.fixture(scope="function") +def input_bundle(input_json): + # CMC 2023-12-11 maybe input_json -> JSONInputBundle should be a helper + # function in `vyper_json.py`. + sources = get_inputs(input_json) + return JSONInputBundle(sources, search_paths=[PurePath(".")]) + + # test string and dict inputs both work def test_string_input(input_json): assert compile_json(input_json) == compile_json(json.dumps(input_json)) @@ -77,29 +111,39 @@ def test_keyerror_becomes_jsonerror(input_json): compile_json(input_json) -def test_compile_json(input_json, make_input_bundle): - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) +def test_compile_json(input_json, input_bundle): + foo_input = input_bundle.load_file("contracts/foo.vy") + foo = compile_from_file_input( + foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + ) - foo = compile_code( - FOO_CODE, - source_id=0, - contract_name="contracts/foo.vy", - output_formats=OUTPUT_FORMATS, - input_bundle=input_bundle, + library_input = input_bundle.load_file("contracts/library.vy") + library = compile_from_file_input( + library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - bar = compile_code( - BAR_CODE, source_id=1, contract_name="contracts/bar.vy", output_formats=OUTPUT_FORMATS + + bar_input = input_bundle.load_file("contracts/bar.vy") + bar = compile_from_file_input( + bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - compile_code_results = {"contracts/bar.vy": bar, "contracts/foo.vy": foo} + compile_code_results = { + "contracts/bar.vy": bar, + "contracts/library.vy": library, + "contracts/foo.vy": foo, + } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" - for source_id, contract_name in enumerate(["foo", "bar"]): + for source_id, contract_name in [(0, "foo"), (2, "library"), (3, "bar")]: path = f"contracts/{contract_name}.vy" data = compile_code_results[path] assert output_json["sources"][path] == {"id": source_id, "ast": data["ast_dict"]["ast"]} @@ -123,13 +167,28 @@ def test_compile_json(input_json, make_input_bundle): } -def test_different_outputs(make_input_bundle, input_json): +def test_compilation_targets(input_json): + output_json = compile_json(input_json) + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] + + # omit library.vy + input_json["settings"]["outputSelection"] = {"contracts/foo.vy": "*", "contracts/bar.vy": "*"} + output_json = compile_json(input_json) + + assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + + +def test_different_outputs(input_bundle, input_json): input_json["settings"]["outputSelection"] = { "contracts/bar.vy": "*", "contracts/foo.vy": ["evm.methodIdentifiers"], } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == ["contracts/bar.vy", "contracts/foo.vy"] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" @@ -143,10 +202,9 @@ def test_different_outputs(make_input_bundle, input_json): assert sorted(foo.keys()) == ["evm"] # check method_identifiers - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) method_identifiers = compile_code( FOO_CODE, - contract_name="contracts/foo.vy", + contract_path="contracts/foo.vy", output_formats=["method_identifiers"], input_bundle=input_bundle, )["method_identifiers"] @@ -204,11 +262,12 @@ def get(filename, contractname): return result["contracts"][filename][contractname]["evm"]["deployedBytecode"]["sourceMap"] assert get("contracts/foo.vy", "foo").startswith("-1:-1:0") - assert get("contracts/bar.vy", "bar").startswith("-1:-1:1") + assert get("contracts/library.vy", "library").startswith("-1:-1:2") + assert get("contracts/bar.vy", "bar").startswith("-1:-1:3") def test_relative_import_paths(input_json): - input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": """from ... import foo"""} - input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} - input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} + input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": "from ... import foo"} + input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": "from . import baz"} + input_json["sources"]["contracts/potato/footato.vy"] = {"content": "from baz import baz"} compile_from_input_dict(input_json) diff --git a/tests/unit/cli/vyper_json/test_get_inputs.py b/tests/unit/cli/vyper_json/test_get_inputs.py index 6e323a91bd..c91cc750f2 100644 --- a/tests/unit/cli/vyper_json/test_get_inputs.py +++ b/tests/unit/cli/vyper_json/test_get_inputs.py @@ -2,7 +2,7 @@ import pytest -from vyper.cli.vyper_json import get_compilation_targets, get_inputs +from vyper.cli.vyper_json import get_inputs from vyper.exceptions import JSONError from vyper.utils import keccak256 @@ -122,9 +122,6 @@ def test_interfaces_output(): "interface.folder/bar2.vy": {"content": BAR_CODE}, }, } - targets = get_compilation_targets(input_json) - assert targets == [PurePath("foo.vy")] - result = get_inputs(input_json) assert result == { PurePath("foo.vy"): {"content": FOO_CODE}, diff --git a/tests/unit/cli/vyper_json/test_output_selection.py b/tests/unit/cli/vyper_json/test_output_selection.py index 78ad7404f2..5383190a66 100644 --- a/tests/unit/cli/vyper_json/test_output_selection.py +++ b/tests/unit/cli/vyper_json/test_output_selection.py @@ -8,53 +8,61 @@ def test_no_outputs(): with pytest.raises(KeyError): - get_output_formats({}, {}) + get_output_formats({}) def test_invalid_output(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}, + } with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) def test_unknown_contract(): - input_json = {"settings": {"outputSelection": {"bar.vy": ["abi"]}}} - targets = [PurePath("foo.vy")] + input_json = {"sources": {}, "settings": {"outputSelection": {"bar.vy": ["abi"]}}} with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) @pytest.mark.parametrize("output", TRANSLATE_MAP.items()) def test_translate_map(output): - input_json = {"settings": {"outputSelection": {"foo.vy": [output[0]]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): [output[1]]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": [output[0]]}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): [output[1]]} def test_star(): - input_json = {"settings": {"outputSelection": {"*": ["*"]}}} - targets = [PurePath("foo.vy"), PurePath("bar.vy")] + input_json = { + "sources": {"foo.vy": "", "bar.vy": ""}, + "settings": {"outputSelection": {"*": ["*"]}}, + } expected = sorted(set(TRANSLATE_MAP.values())) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected, PurePath("bar.vy"): expected} def test_evm(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}, + } expected = ["abi"] + sorted(v for k, v in TRANSLATE_MAP.items() if k.startswith("evm")) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected} def test_solc_style(): - input_json = {"settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["abi", "ir_dict"]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["abi", "ir_dict"]} def test_metadata(): - input_json = {"settings": {"outputSelection": {"*": ["metadata"]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["metadata"]} + input_json = {"sources": {"foo.vy": ""}, "settings": {"outputSelection": {"*": ["metadata"]}}} + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["metadata"]} diff --git a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py index 3b0f700c7e..6b509dd3ef 100644 --- a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py @@ -9,11 +9,11 @@ from vyper.exceptions import JSONError FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) """ BAR_CODE = """ diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 47b70a8c70..44b823757c 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -1,5 +1,6 @@ import pytest +from vyper.compiler import compile_code from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings @@ -71,33 +72,61 @@ def __init__(): ] +# check dead code eliminator works on unreachable functions @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime - ctor_only_label = "_sym_internal_ctor_only___" - runtime_only_label = "_sym_internal_runtime_only___" + # get the labels + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] + + ctor_only = "ctor_only()" + runtime_only = "runtime_only()" # qux reachable from unoptimized initcode, foo not reachable. - assert ctor_only_label + "_deploy" in initcode_asm - assert runtime_only_label + "_deploy" not in initcode_asm + assert any(ctor_only in instr for instr in initcode_asm) + assert all(runtime_only not in instr for instr in initcode_asm) # all labels should be in unoptimized runtime asm - for s in (ctor_only_label, runtime_only_label): - assert s + "_runtime" in runtime_asm + for s in (ctor_only, runtime_only): + assert any(s in instr for instr in runtime_asm) c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] # ctor only label should not be in runtime code - for instr in runtime_asm: - if isinstance(instr, str): - assert not instr.startswith(ctor_only_label), instr + assert all(ctor_only not in instr for instr in runtime_asm) # runtime only label should not be in initcode asm - for instr in initcode_asm: - if isinstance(instr, str): - assert not instr.startswith(runtime_only_label), instr + assert all(runtime_only not in instr for instr in initcode_asm) + + +def test_library_code_eliminator(make_input_bundle): + library = """ +@internal +def unused1(): + pass + +@internal +def unused2(): + self.unused1() + +@internal +def some_function(): + pass + """ + code = """ +import library + +@external +def foo(): + library.some_function() + """ + input_bundle = make_input_bundle({"library.vy": library}) + res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) + asm = res["asm"] + assert "some_function()" in asm + assert "unused1()" not in asm + assert "unused2()" not in asm diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index c49c81219b..e26555b169 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -1,4 +1,6 @@ +import contextlib import json +import os from pathlib import Path, PurePath import pytest @@ -12,19 +14,19 @@ def input_bundle(tmp_path): return FilesystemInputBundle([tmp_path]) -def test_load_file(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") +def test_load_file(make_file, input_bundle): + filepath = make_file("foo.vy", "contents") file = input_bundle.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_context_manager(make_file, tmp_path): ib = FilesystemInputBundle([]) - make_file("foo.vy", "contents") + filepath = make_file("foo.vy", "contents") with pytest.raises(FileNotFoundError): # no search path given @@ -34,7 +36,7 @@ def test_search_path_context_manager(make_file, tmp_path): file = ib.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bundle): @@ -43,59 +45,85 @@ def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bun tmpdir = tmp_path_factory.mktemp("some_directory") tmpdir2 = tmp_path_factory.mktemp("some_other_directory") + filepaths = [] for i, directory in enumerate([tmp_path, tmpdir, tmpdir2]): - with (directory / "foo.vy").open("w") as f: + path = directory / "foo.vy" + with path.open("w") as f: f.write(f"contents {i}") + filepaths.append(path) ib = FilesystemInputBundle([tmp_path, tmpdir, tmpdir2]) file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(0, tmpdir2 / "foo.vy", "contents 2") + assert file == FileInput(0, "foo.vy", filepaths[2], "contents 2") with ib.search_path(tmpdir): file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(1, tmpdir / "foo.vy", "contents 1") + assert file == FileInput(1, "foo.vy", filepaths[1], "contents 1") # special rules for handling json files def test_load_abi(make_file, input_bundle, tmp_path): contents = json.dumps("some string") - make_file("foo.json", contents) + path = make_file("foo.json", contents) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", path, "some string") # suffix doesn't matter - make_file("foo.txt", contents) - + path = make_file("foo.txt", contents) file = input_bundle.load_file("foo.txt") assert isinstance(file, ABIInput) - assert file == ABIInput(1, tmp_path / "foo.txt", "some string") + assert file == ABIInput(1, "foo.txt", path, "some string") + + +@contextlib.contextmanager +def working_directory(directory): + tmp = os.getcwd() + try: + os.chdir(directory) + yield + finally: + os.chdir(tmp) # check that unique paths give unique source ids def test_source_id_file_input(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") - make_file("bar.vy", "contents 2") + foopath = make_file("foo.vy", "contents") + barpath = make_file("bar.vy", "contents 2") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") file2 = input_bundle.load_file("bar.vy") # source id increments assert file2.source_id == 1 - assert file2 == FileInput(1, tmp_path / "bar.vy", "contents 2") + assert file2 == FileInput(1, "bar.vy", barpath, "contents 2") file3 = input_bundle.load_file("foo.vy") assert file3.source_id == 0 - assert file3 == FileInput(0, tmp_path / "foo.vy", "contents") + assert file3 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.vy") + assert file4.source_id == 0 + assert file4 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.vy") + assert file5.source_id == 0 + assert file5 == FileInput(0, Path(tmp_path.stem) / "foo.vy", foopath, "contents") # check that unique paths give unique source ids @@ -103,37 +131,51 @@ def test_source_id_json_input(make_file, input_bundle, tmp_path): contents = json.dumps("some string") contents2 = json.dumps(["some list"]) - make_file("foo.json", contents) + foopath = make_file("foo.json", contents) - make_file("bar.json", contents2) + barpath = make_file("bar.json", contents2) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", foopath, "some string") file2 = input_bundle.load_file("bar.json") assert isinstance(file2, ABIInput) - assert file2 == ABIInput(1, tmp_path / "bar.json", ["some list"]) + assert file2 == ABIInput(1, "bar.json", barpath, ["some list"]) file3 = input_bundle.load_file("foo.json") - assert isinstance(file3, ABIInput) - assert file3 == ABIInput(0, tmp_path / "foo.json", "some string") + assert file3.source_id == 0 + assert file3 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.json") + assert file4.source_id == 0 + assert file4 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.json") + assert file5.source_id == 0 + assert file5 == ABIInput(0, Path(tmp_path.stem) / "foo.json", foopath, "some string") # test some pathological case where the file changes underneath def test_mutating_file_source_id(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") + foopath = make_file("foo.vy", "contents") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") - make_file("foo.vy", "new contents") + foopath = make_file("foo.vy", "new contents") file = input_bundle.load_file("foo.vy") # source id hasn't changed, even though contents have assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "new contents") + assert file == FileInput(0, "foo.vy", foopath, "new contents") # test the os.normpath behavior of symlink @@ -147,10 +189,12 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): dir2.mkdir() symlink.symlink_to(dir2, target_is_directory=True) - with (tmp_path / "foo.vy").open("w") as f: - f.write("contents of the upper directory") + outer_path = tmp_path / "foo.vy" + with outer_path.open("w") as f: + f.write("contents of the outer directory") - with (dir1 / "foo.vy").open("w") as f: + inner_path = dir1 / "foo.vy" + with inner_path.open("w") as f: f.write("contents of the inner directory") # symlink rules would be: @@ -159,9 +203,10 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): # base/first/foo.vy # normpath would be base/symlink/../foo.vy => # base/foo.vy - file = input_bundle.load_file(symlink / ".." / "foo.vy") + to_load = symlink / ".." / "foo.vy" + file = input_bundle.load_file(to_load) - assert file == FileInput(0, tmp_path / "foo.vy", "contents of the upper directory") + assert file == FileInput(0, to_load, outer_path.resolve(), "contents of the outer directory") def test_json_input_bundle_basic(): @@ -169,40 +214,42 @@ def test_json_input_bundle_basic(): input_bundle = JSONInputBundle(files, [PurePath(".")]) file = input_bundle.load_file(PurePath("foo.vy")) - assert file == FileInput(0, PurePath("foo.vy"), "some text") + assert file == FileInput(0, PurePath("foo.vy"), PurePath("foo.vy"), "some text") def test_json_input_bundle_normpath(): - files = {PurePath("foo/../bar.vy"): {"content": "some text"}} + contents = "some text" + files = {PurePath("foo/../bar.vy"): {"content": contents}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - expected = FileInput(0, PurePath("bar.vy"), "some text") + barpath = PurePath("bar.vy") + + expected = FileInput(0, barpath, barpath, contents) file = input_bundle.load_file(PurePath("bar.vy")) assert file == expected file = input_bundle.load_file(PurePath("baz/../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("baz/../bar.vy"), barpath, contents) file = input_bundle.load_file(PurePath("./bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("./bar.vy"), barpath, contents) with input_bundle.search_path(PurePath("foo")): file = input_bundle.load_file(PurePath("../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("../bar.vy"), barpath, contents) def test_json_input_abi(): some_abi = ["some abi"] some_abi_str = json.dumps(some_abi) - files = { - PurePath("foo.json"): {"abi": some_abi}, - PurePath("bar.txt"): {"content": some_abi_str}, - } + foopath = PurePath("foo.json") + barpath = PurePath("bar.txt") + files = {foopath: {"abi": some_abi}, barpath: {"content": some_abi_str}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - file = input_bundle.load_file(PurePath("foo.json")) - assert file == ABIInput(0, PurePath("foo.json"), some_abi) + file = input_bundle.load_file(foopath) + assert file == ABIInput(0, foopath, foopath, some_abi) - file = input_bundle.load_file(PurePath("bar.txt")) - assert file == ABIInput(1, PurePath("bar.txt"), some_abi) + file = input_bundle.load_file(barpath) + assert file == ABIInput(1, barpath, barpath, some_abi) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index 27c0634cf8..5ea373fc19 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value): +def test_type_mismatch(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -23,11 +23,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value): +def test_invalid_literal(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -38,11 +38,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidType): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value): +def test_out_of_bounds(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -53,11 +53,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value): +def test_undeclared_definition(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -68,11 +68,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value): +def test_invalid_reference(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -83,4 +83,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 2a09bd5ed5..c31146b16f 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -3,22 +3,20 @@ from vyper.ast import parse_to_ast from vyper.exceptions import CallViolation, StructureException from vyper.semantics.analysis import validate_semantics -from vyper.semantics.analysis.module import ModuleAnalyzer -def test_self_function_call(namespace): +def test_self_function_call(dummy_input_bundle): code = """ @internal def foo(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_cyclic_function_call(namespace): +def test_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -29,12 +27,11 @@ def bar(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_multi_cyclic_function_call(namespace): +def test_multi_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -53,12 +50,11 @@ def potato(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_global_ann_assign_callable_no_crash(): +def test_global_ann_assign_callable_no_crash(dummy_input_bundle): code = """ balanceOf: public(HashMap[address, uint256]) @@ -68,5 +64,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - validate_semantics(vyper_module, {}) - assert excinfo.value.message == "Value is not callable" + validate_semantics(vyper_module, dummy_input_bundle) + assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 0d61a8f8f8..e2c0f555af 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis import validate_semantics -def test_modify_iterator_function_outside_loop(namespace): +def test_modify_iterator_function_outside_loop(dummy_input_bundle): code = """ a: uint256[3] @@ -26,10 +26,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_pass_memory_var_to_other_function(namespace): +def test_pass_memory_var_to_other_function(dummy_input_bundle): code = """ @internal @@ -46,10 +46,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator(namespace): +def test_modify_iterator(dummy_input_bundle): code = """ a: uint256[3] @@ -61,10 +61,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_keywords(namespace): +def test_bad_keywords(dummy_input_bundle): code = """ @internal @@ -75,10 +75,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_bound(namespace): +def test_bad_bound(dummy_input_bundle): code = """ @internal @@ -89,10 +89,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StateAccessViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_function_call(namespace): +def test_modify_iterator_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -108,10 +108,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_recursive_function_call(namespace): +def test_modify_iterator_recursive_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -131,7 +131,7 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) iterator_inference_codes = [ @@ -169,7 +169,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(namespace, code): +def test_iterator_type_inference_checker(code, dummy_input_bundle): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index d390fe9a39..002ee38cd2 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}\n", + match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/tox.ini b/tox.ini index c949354dfe..f9d4c3b60b 100644 --- a/tox.ini +++ b/tox.ini @@ -53,4 +53,4 @@ commands = basepython = python3 extras = lint commands = - mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --disallow-incomplete-defs -p vyper + mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --implicit-optional -p vyper diff --git a/vyper/__init__.py b/vyper/__init__.py index 482d5c3a60..5bb6469757 100644 --- a/vyper/__init__.py +++ b/vyper/__init__.py @@ -1,6 +1,6 @@ from pathlib import Path as _Path -from vyper.compiler import compile_code # noqa: F401 +from vyper.compiler import compile_code, compile_from_file_input try: from importlib.metadata import PackageNotFoundError # type: ignore diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index e5b81f1e7f..4b46801153 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -6,7 +6,8 @@ from . import nodes, validation from .natspec import parse_natspec from .nodes import compare_nodes -from .utils import ast_to_dict, parse_to_ast, parse_to_ast_with_settings +from .utils import ast_to_dict +from .parse import parse_to_ast, parse_to_ast_with_settings # adds vyper.ast.nodes classes into the local namespace for name, obj in ( diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index d349e804d6..eac8ffdef5 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -4,5 +4,5 @@ from typing import Any, Optional, Union from . import expansion, folding, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * +from .parse import parse_to_ast as parse_to_ast from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 5471b971a4..1536f39165 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -5,22 +5,9 @@ from vyper.semantics.types.function import ContractFunctionT -def expand_annotated_ast(vyper_module: vy_ast.Module) -> None: - """ - Perform expansion / simplification operations on an annotated Vyper AST. - - This pass uses annotated type information to modify the AST, simplifying - logic and expanding subtrees to reduce the compexity during codegen. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node that has been type-checked and annotated. - """ - generate_public_variable_getters(vyper_module) - remove_unused_statements(vyper_module) - - +# TODO: remove this function. it causes correctness/performance problems +# because of copying and mutating the AST - getter generation should be handled +# during code generation. def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ Create getter functions for public variables. @@ -32,7 +19,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_public": True}): - func_type = node._metadata["func_type"] + func_type = node._metadata["getter_type"] input_types, return_type = node._metadata["type"].getter_signature input_nodes = [] @@ -86,31 +73,11 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: returns=return_node, ) - with vyper_module.namespace(): - func_type = ContractFunctionT.from_FunctionDef(expanded) - - expanded._metadata["type"] = func_type - return_node.set_parent(expanded) + # update pointers vyper_module.add_to_body(expanded) + return_node.set_parent(expanded) + with vyper_module.namespace(): + func_type = ContractFunctionT.from_FunctionDef(expanded) -def remove_unused_statements(vyper_module: vy_ast.Module) -> None: - """ - Remove statement nodes that are unused after type checking. - - Once type checking is complete, we can remove now-meaningless statements to - simplify the AST prior to IR generation. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - - # constant declarations - values were substituted within the AST during folding - for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}): - vyper_module.remove_from_body(node) - - # `implements: interface` statements - validated during type checking - for node in vyper_module.get_children(vy_ast.ImplementsDecl): - vyper_module.remove_from_body(node) + expanded._metadata["func_type"] = func_type diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index ca9979b2a3..15367ce94a 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -89,7 +89,8 @@ tuple_def: "(" ( NAME | array_def | dyn_array_def | tuple_def ) ( "," ( NAME | a // NOTE: Map takes a basic type and maps to another type (can be non-basic, including maps) _MAP: "HashMap" map_def: _MAP "[" ( NAME | array_def ) "," type "]" -type: ( NAME | array_def | tuple_def | map_def | dyn_array_def ) +imported_type: NAME "." NAME +type: ( NAME | imported_type | array_def | tuple_def | map_def | dyn_array_def ) // Structs can be composed of 1+ basic types or other custom_types _STRUCT_DECL: "struct" @@ -291,7 +292,7 @@ special_builtins: empty | abi_decode // Adapted from: https://docs.python.org/3/reference/grammar.html // Adapted by: Erez Shinan NAME: /[a-zA-Z_]\w*/ -COMMENT: /#[^\n]*/ +COMMENT: /#[^\n\r]*/ _NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ @@ -312,8 +313,10 @@ _number: DEC_NUMBER BOOL.2: "True" | "False" +ELLIPSIS: "..." + // TODO: Remove Docstring from here, and add to first part of body -?literal: ( _number | STRING | DOCSTRING | BOOL ) +?literal: ( _number | STRING | DOCSTRING | BOOL | ELLIPSIS) %ignore /[\t \f]+/ // WS %ignore /\\[\t \f]*\r?\n/ // LINE_CONT diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py index c25fc423f8..41905b178a 100644 --- a/vyper/ast/natspec.py +++ b/vyper/ast/natspec.py @@ -43,7 +43,7 @@ def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: for node in [i for i in vyper_module_folded.body if i.get("doc_string.value")]: docstring = node.doc_string.value - func_type = node._metadata["type"] + func_type = node._metadata["func_type"] if func_type.visibility != FunctionVisibility.EXTERNAL: continue diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 69bd1fed53..3bccc5f141 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -589,7 +589,8 @@ def __contains__(self, obj): class Module(TopLevel): - __slots__ = () + # metadata + __slots__ = ("path", "resolved_path", "source_id") def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: """ @@ -897,12 +898,16 @@ def validate(self): raise InvalidLiteral("Cannot have an empty tuple", self) -class Dict(ExprNode): - __slots__ = ("keys", "values") +class NameConstant(Constant): + __slots__ = () -class NameConstant(Constant): - __slots__ = ("value",) +class Ellipsis(Constant): + __slots__ = () + + +class Dict(ExprNode): + __slots__ = ("keys", "values") class Name(ExprNode): @@ -1407,7 +1412,7 @@ class Pass(Stmt): __slots__ = () -class _Import(Stmt): +class _ImportStmt(Stmt): __slots__ = ("name", "alias") def __init__(self, *args, **kwargs): @@ -1419,11 +1424,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -class Import(_Import): +class Import(_ImportStmt): __slots__ = () -class ImportFrom(_Import): +class ImportFrom(_ImportStmt): __slots__ = ("level", "module") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47c9af8526..05784aed0f 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -2,9 +2,9 @@ import ast as python_ast from typing import Any, Optional, Sequence, Type, Union from .natspec import parse_natspec as parse_natspec +from .parse import parse_to_ast as parse_to_ast +from .parse import parse_to_ast_with_settings as parse_to_ast_with_settings from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast -from .utils import parse_to_ast_with_settings as parse_to_ast_with_settings NODE_BASE_ATTRIBUTES: Any NODE_SRC_ATTRIBUTES: Any @@ -59,6 +59,8 @@ class TopLevel(VyperNode): def __contains__(self, obj: Any) -> bool: ... class Module(TopLevel): + path: str = ... + resolved_path: str = ... def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ... def add_to_body(self, node: VyperNode) -> None: ... def remove_from_body(self, node: VyperNode) -> None: ... @@ -121,6 +123,9 @@ class Bytes(Constant): @property def s(self): ... +class NameConstant(Constant): ... +class Ellipsis(Constant): ... + class List(VyperNode): elements: list = ... @@ -131,8 +136,6 @@ class Dict(VyperNode): keys: list = ... values: list = ... -class NameConstant(Constant): ... - class Name(VyperNode): id: str = ... _type: str = ... @@ -188,7 +191,7 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: Name = ... + func: VyperNode = ... class keyword(VyperNode): ... diff --git a/vyper/ast/annotation.py b/vyper/ast/parse.py similarity index 68% rename from vyper/ast/annotation.py rename to vyper/ast/parse.py index 9c7b1e063f..a2f2542179 100644 --- a/vyper/ast/annotation.py +++ b/vyper/ast/parse.py @@ -1,14 +1,114 @@ import ast as python_ast import tokenize from decimal import Decimal -from typing import Optional, cast +from typing import Any, Dict, List, Optional, Union, cast import asttokens -from vyper.exceptions import CompilerPanic, SyntaxException +from vyper.ast import nodes as vy_ast +from vyper.ast.pre_parser import pre_parse +from vyper.compiler.settings import Settings +from vyper.exceptions import CompilerPanic, ParserException, SyntaxException from vyper.typing import ModificationOffsets +def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: + _settings, ast = parse_to_ast_with_settings(*args, **kwargs) + return ast + + +def parse_to_ast_with_settings( + source_code: str, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, + add_fn_node: Optional[str] = None, +) -> tuple[Settings, vy_ast.Module]: + """ + Parses a Vyper source string and generates basic Vyper AST nodes. + + Parameters + ---------- + source_code : str + The Vyper source code to parse. + source_id : int, optional + Source id to use in the `src` member of each node. + contract_name: str, optional + Name of contract. + add_fn_node: str, optional + If not None, adds a dummy Python AST FunctionDef wrapper node. + source_id: int, optional + The source ID generated for this source code. + Corresponds to FileInput.source_id + module_path: str, optional + The path of the source code + Corresponds to FileInput.path + resolved_path: str, optional + The resolved path of the source code + Corresponds to FileInput.resolved_path + + Returns + ------- + list + Untyped, unoptimized Vyper AST nodes. + """ + if "\x00" in source_code: + raise ParserException("No null bytes (\\x00) allowed in the source code.") + settings, class_types, reformatted_code = pre_parse(source_code) + try: + py_ast = python_ast.parse(reformatted_code) + except SyntaxError as e: + # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors + raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e + + # Add dummy function node to ensure local variables are treated as `AnnAssign` + # instead of state variables (`VariableDecl`) + if add_fn_node: + fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) + fn_node.body = py_ast.body + fn_node.args = python_ast.arguments(defaults=[]) + py_ast.body = [fn_node] + + annotate_python_ast( + py_ast, + source_code, + class_types, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + + # Convert to Vyper AST. + module = vy_ast.get_node(py_ast) + assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module + + +def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: + """ + Converts a Vyper AST node, or list of nodes, into a dictionary suitable for + output to the user. + """ + if isinstance(ast_struct, vy_ast.VyperNode): + return ast_struct.to_dict() + + if isinstance(ast_struct, list): + return [i.to_dict() for i in ast_struct] + + raise CompilerPanic(f'Unknown Vyper AST node provided: "{type(ast_struct)}".') + + +def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: + """ + Converts an AST dict, or list of dicts, into Vyper AST node objects. + """ + if isinstance(ast_struct, dict): + return vy_ast.get_node(ast_struct) + if isinstance(ast_struct, list): + return [vy_ast.get_node(i) for i in ast_struct] + raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets @@ -19,11 +119,13 @@ def __init__( modification_offsets: Optional[ModificationOffsets], tokens: asttokens.ASTTokens, source_id: int, - contract_name: Optional[str], + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ): self._tokens = tokens self._source_id = source_id - self._contract_name = contract_name + self._module_path = module_path + self._resolved_path = resolved_path self._source_code: str = source_code self.counter: int = 0 self._modification_offsets = {} @@ -83,7 +185,9 @@ def _visit_docstring(self, node): return node def visit_Module(self, node): - node.name = self._contract_name + node.path = self._module_path + node.resolved_path = self._resolved_path + node.source_id = self._source_id return self._visit_docstring(node) def visit_FunctionDef(self, node): @@ -163,6 +267,8 @@ def visit_Constant(self, node): node.ast_type = "Str" elif isinstance(node.value, bytes): node.ast_type = "Bytes" + elif isinstance(node.value, Ellipsis.__class__): + node.ast_type = "Ellipsis" else: raise SyntaxException( "Invalid syntax (unsupported Python Constant AST node).", @@ -250,7 +356,8 @@ def annotate_python_ast( source_code: str, modification_offsets: Optional[ModificationOffsets] = None, source_id: int = 0, - contract_name: Optional[str] = None, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ) -> python_ast.AST: """ Annotate and optimize a Python AST in preparation conversion to a Vyper AST. @@ -270,7 +377,14 @@ def annotate_python_ast( """ tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor(source_code, modification_offsets, tokens, source_id, contract_name) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) visitor.visit(parsed_ast) return parsed_ast diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py index 4e669385ab..4c2e5394c9 100644 --- a/vyper/ast/utils.py +++ b/vyper/ast/utils.py @@ -1,64 +1,7 @@ -import ast as python_ast -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union from vyper.ast import nodes as vy_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse -from vyper.compiler.settings import Settings -from vyper.exceptions import CompilerPanic, ParserException, SyntaxException - - -def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: - return parse_to_ast_with_settings(*args, **kwargs)[1] - - -def parse_to_ast_with_settings( - source_code: str, - source_id: int = 0, - contract_name: Optional[str] = None, - add_fn_node: Optional[str] = None, -) -> tuple[Settings, vy_ast.Module]: - """ - Parses a Vyper source string and generates basic Vyper AST nodes. - - Parameters - ---------- - source_code : str - The Vyper source code to parse. - source_id : int, optional - Source id to use in the `src` member of each node. - contract_name: str, optional - Name of contract. - add_fn_node: str, optional - If not None, adds a dummy Python AST FunctionDef wrapper node. - - Returns - ------- - list - Untyped, unoptimized Vyper AST nodes. - """ - if "\x00" in source_code: - raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) - try: - py_ast = python_ast.parse(reformatted_code) - except SyntaxError as e: - # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors - raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e - - # Add dummy function node to ensure local variables are treated as `AnnAssign` - # instead of state variables (`VariableDecl`) - if add_fn_node: - fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) - fn_node.body = py_ast.body - fn_node.args = python_ast.arguments(defaults=[]) - py_ast.body = [fn_node] - annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name) - - # Convert to Vyper AST. - module = vy_ast.get_node(py_ast) - assert isinstance(module, vy_ast.Module) # mypy hint - return settings, module +from vyper.exceptions import CompilerPanic def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index afc0987b6d..72b05f15e3 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,10 +1,10 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context -from vyper.codegen.global_context import GlobalContext from vyper.codegen.stmt import parse_body from vyper.semantics.analysis.local import FunctionNodeVisitor from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability +from vyper.semantics.types.module import ModuleT def _strip_source_pos(ir_node): @@ -22,15 +22,16 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): # Initialise a placeholder `FunctionDef` AST node and corresponding # `ContractFunctionT` type to rely on the annotation visitors in semantics # module. - ast_code.body[0]._metadata["type"] = ContractFunctionT( + ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) # The FunctionNodeVisitor's constructor performs semantic checks # annotate the AST as side effects. - FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer.analyze() new_context = Context( - vars_=variables, global_ctx=GlobalContext(), memory_allocator=memory_allocator + vars_=variables, module_ctx=ModuleT(ast_code), memory_allocator=memory_allocator ) generated_ir = parse_body(ast_code.body[0].body, new_context) # strip source position info from the generated_ir since diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 22931508a6..d50a31767d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2499,9 +2499,9 @@ def infer_arg_types(self, node): validate_call_args(node, 2, ["unwrap_tuple"]) data_type = get_exact_type_from_node(node.args[0]) - output_typedef = TYPE_T(type_from_annotation(node.args[1])) + output_type = type_from_annotation(node.args[1]) - return [data_type, output_typedef] + return [data_type, TYPE_T(output_type)] @process_inputs def build_IR(self, expr, args, kwargs, context): diff --git a/vyper/builtins/interfaces/ERC165.vy b/vyper/builtins/interfaces/ERC165.vyi similarity index 88% rename from vyper/builtins/interfaces/ERC165.vy rename to vyper/builtins/interfaces/ERC165.vyi index a4ca451abd..441130f77c 100644 --- a/vyper/builtins/interfaces/ERC165.vy +++ b/vyper/builtins/interfaces/ERC165.vyi @@ -1,4 +1,4 @@ @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20.vy b/vyper/builtins/interfaces/ERC20.vyi similarity index 68% rename from vyper/builtins/interfaces/ERC20.vy rename to vyper/builtins/interfaces/ERC20.vyi index 065ca97a9b..ee533ab326 100644 --- a/vyper/builtins/interfaces/ERC20.vy +++ b/vyper/builtins/interfaces/ERC20.vyi @@ -1,38 +1,38 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _value: uint256 + sender: indexed(address) + recipient: indexed(address) + value: uint256 event Approval: - _owner: indexed(address) - _spender: indexed(address) - _value: uint256 + owner: indexed(address) + spender: indexed(address) + value: uint256 # Functions @view @external def totalSupply() -> uint256: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def allowance(_owner: address, _spender: address) -> uint256: - pass + ... @external def transfer(_to: address, _value: uint256) -> bool: - pass + ... @external def transferFrom(_from: address, _to: address, _value: uint256) -> bool: - pass + ... @external def approve(_spender: address, _value: uint256) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20Detailed.vy b/vyper/builtins/interfaces/ERC20Detailed.vyi similarity index 93% rename from vyper/builtins/interfaces/ERC20Detailed.vy rename to vyper/builtins/interfaces/ERC20Detailed.vyi index 7c4f546d45..0be1c6f153 100644 --- a/vyper/builtins/interfaces/ERC20Detailed.vy +++ b/vyper/builtins/interfaces/ERC20Detailed.vyi @@ -5,14 +5,14 @@ @view @external def name() -> String[1]: - pass + ... @view @external def symbol() -> String[1]: - pass + ... @view @external def decimals() -> uint8: - pass + ... diff --git a/vyper/builtins/interfaces/ERC4626.vy b/vyper/builtins/interfaces/ERC4626.vyi similarity index 90% rename from vyper/builtins/interfaces/ERC4626.vy rename to vyper/builtins/interfaces/ERC4626.vyi index 05865406cf..6d9e4c6ef7 100644 --- a/vyper/builtins/interfaces/ERC4626.vy +++ b/vyper/builtins/interfaces/ERC4626.vyi @@ -16,75 +16,75 @@ event Withdraw: @view @external def asset() -> address: - pass + ... @view @external def totalAssets() -> uint256: - pass + ... @view @external def convertToShares(assetAmount: uint256) -> uint256: - pass + ... @view @external def convertToAssets(shareAmount: uint256) -> uint256: - pass + ... @view @external def maxDeposit(owner: address) -> uint256: - pass + ... @view @external def previewDeposit(assets: uint256) -> uint256: - pass + ... @external def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxMint(owner: address) -> uint256: - pass + ... @view @external def previewMint(shares: uint256) -> uint256: - pass + ... @external def mint(shares: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxWithdraw(owner: address) -> uint256: - pass + ... @view @external def previewWithdraw(assets: uint256) -> uint256: - pass + ... @external def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... @view @external def maxRedeem(owner: address) -> uint256: - pass + ... @view @external def previewRedeem(shares: uint256) -> uint256: - pass + ... @external def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... diff --git a/vyper/builtins/interfaces/ERC721.vy b/vyper/builtins/interfaces/ERC721.vyi similarity index 61% rename from vyper/builtins/interfaces/ERC721.vy rename to vyper/builtins/interfaces/ERC721.vyi index 464c0e255b..b8dcfd3c5f 100644 --- a/vyper/builtins/interfaces/ERC721.vy +++ b/vyper/builtins/interfaces/ERC721.vyi @@ -1,67 +1,62 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _tokenId: indexed(uint256) + sender: indexed(address) + recipient: indexed(address) + token_id: indexed(uint256) event Approval: - _owner: indexed(address) - _approved: indexed(address) - _tokenId: indexed(uint256) + owner: indexed(address) + approved: indexed(address) + token_id: indexed(uint256) event ApprovalForAll: - _owner: indexed(address) - _operator: indexed(address) - _approved: bool + owner: indexed(address) + operator: indexed(address) + approved: bool # Functions @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def ownerOf(_tokenId: uint256) -> address: - pass + ... @view @external def getApproved(_tokenId: uint256) -> address: - pass + ... @view @external def isApprovedForAll(_owner: address, _operator: address) -> bool: - pass + ... @external @payable def transferFrom(_from: address, _to: address, _tokenId: uint256): - pass + ... @external @payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256): - pass - -@external -@payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024]): - pass +def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024] = b""): + ... @external @payable def approve(_approved: address, _tokenId: uint256): - pass + ... @external def setApprovalForAll(_operator: address, _approved: bool): - pass + ... diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ca1792384e..4f88812fa0 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -271,10 +271,8 @@ def compile_files( with open(storage_file_path) as sfh: storage_layout_override = json.load(sfh) - output = vyper.compile_code( - file.source_code, - contract_name=str(file.path), - source_id=file.source_id, + output = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=final_formats, exc_handler=exc_handler, diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 2720f20d23..63da2e0643 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError -from vyper.utils import keccak256 +from vyper.utils import OrderedSet, keccak256 TRANSLATE_MAP = { "abi": "abi", @@ -151,13 +151,6 @@ def get_evm_version(input_dict: dict) -> Optional[str]: return evm_version -def get_compilation_targets(input_dict: dict) -> list[PurePath]: - # TODO: once we have modules, add optional "compilation_targets" key - # which specifies which sources we actually want to compile. - - return [PurePath(p) for p in input_dict["sources"].keys()] - - def get_inputs(input_dict: dict) -> dict[PurePath, Any]: ret = {} seen = {} @@ -218,14 +211,14 @@ def get_inputs(input_dict: dict) -> dict[PurePath, Any]: # get unique output formats for each contract, given the input_dict # NOTE: would maybe be nice to raise on duplicated output formats -def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePath, list[str]]: +def get_output_formats(input_dict: dict) -> dict[PurePath, list[str]]: output_formats: dict[PurePath, list[str]] = {} for path, outputs in input_dict["settings"]["outputSelection"].items(): if isinstance(outputs, dict): # if outputs are given in solc json format, collapse them into a single list - outputs = set(x for i in outputs.values() for x in i) + outputs = OrderedSet(x for i in outputs.values() for x in i) else: - outputs = set(outputs) + outputs = OrderedSet(outputs) for key in [i for i in ("evm", "evm.bytecode", "evm.deployedBytecode") if i in outputs]: outputs.remove(key) @@ -239,13 +232,13 @@ def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePa except KeyError as e: raise JSONError(f"Invalid outputSelection - {e}") - outputs = sorted(set(outputs)) + outputs = sorted(list(outputs)) if path == "*": - output_paths = targets + output_paths = [PurePath(path) for path in input_dict["sources"].keys()] else: output_paths = [PurePath(path)] - if output_paths[0] not in targets: + if str(output_paths[0]) not in input_dict["sources"]: raise JSONError(f"outputSelection references unknown contract '{output_paths[0]}'") for output_path in output_paths: @@ -281,9 +274,9 @@ def compile_from_input_dict( no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) - compilation_targets = get_compilation_targets(input_dict) sources = get_inputs(input_dict) - output_formats = get_output_formats(input_dict, compilation_targets) + output_formats = get_output_formats(input_dict) + compilation_targets = list(output_formats.keys()) input_bundle = JSONInputBundle(sources, search_paths=[Path(root_folder)]) @@ -295,12 +288,10 @@ def compile_from_input_dict( # use load_file to get a unique source_id file = input_bundle.load_file(contract_path) assert isinstance(file, FileInput) # mypy hint - data = vyper.compile_code( - file.source_code, - contract_name=str(file.path), + data = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=output_formats[contract_path], - source_id=file.source_id, settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 5b79f293bd..dea30faabc 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -48,7 +48,7 @@ def __repr__(self): class Context: def __init__( self, - global_ctx, + module_ctx, memory_allocator, vars_=None, forvars=None, @@ -60,7 +60,7 @@ def __init__( self.vars = vars_ or {} # Global variables, in the form (name, storage location, type) - self.globals = global_ctx.variables + self.globals = module_ctx.variables # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} @@ -75,8 +75,8 @@ def __init__( # Whether we are currently parsing a range expression self.in_range_expr = False - # store global context - self.global_ctx = global_ctx + # store module context + self.module_ctx = module_ctx # full function type self.func_t = func_t diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index dc0e98786f..5870e64e98 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -47,8 +47,10 @@ StringT, StructT, TupleT, + is_type_t, ) from vyper.semantics.types.bytestrings import _BytestringT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T from vyper.utils import ( DECIMAL_DIVISOR, @@ -79,7 +81,7 @@ def __init__(self, node, context): self.ir_node = fn() if self.ir_node is None: - raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.", node) + raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.source_pos = getpos(self.expr) @@ -662,39 +664,38 @@ def parse_Call(self): if function_name in DISPATCH_TABLE: return DISPATCH_TABLE[function_name].build_IR(self.expr, self.context) - # Struct constructors do not need `self` prefix. - elif isinstance(self.expr._metadata["type"], StructT): - args = self.expr.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) + func_type = self.expr.func._metadata["type"] - # Interface assignment. Bar(
). - elif isinstance(self.expr._metadata["type"], InterfaceT): - (arg0,) = self.expr.args - arg_ir = Expr(arg0, self.context).ir_node + # Struct constructor + if is_type_t(func_type, StructT): + args = self.expr.args + if len(args) == 1 and isinstance(args[0], vy_ast.Dict): + return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) - assert arg_ir.typ == AddressT() - arg_ir.typ = self.expr._metadata["type"] + # Interface constructor. Bar(
). + if is_type_t(func_type, InterfaceT): + (arg0,) = self.expr.args + arg_ir = Expr(arg0, self.context).ir_node - return arg_ir + assert arg_ir.typ == AddressT() + arg_ir.typ = self.expr._metadata["type"] - elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": + return arg_ir + + if isinstance(func_type, MemberFunctionT) and self.expr.func.attr == "pop": # TODO consider moving this to builtins darray = Expr(self.expr.func.value, self.context).ir_node assert len(self.expr.args) == 0 assert isinstance(darray.typ, DArrayT) return pop_dyn_array(darray, return_popped_item=True) - elif ( - # TODO use expr.func.type.is_internal once - # type annotations are consistently available - isinstance(self.expr.func, vy_ast.Attribute) - and isinstance(self.expr.func.value, vy_ast.Name) - and self.expr.func.value.id == "self" - ): - return self_call.ir_for_self_call(self.expr, self.context) - else: - return external_call.ir_for_external_call(self.expr, self.context) + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.expr, self.context) + else: + return external_call.ir_for_external_call(self.expr, self.context) + + raise CompilerPanic("unreachable", self.expr) def parse_List(self): typ = self.expr._metadata["type"] diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index c48f1256c3..454ba9c8cd 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -7,13 +7,13 @@ from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator from vyper.exceptions import CompilerPanic from vyper.semantics.types import VyperType from vyper.semantics.types.function import ContractFunctionT -from vyper.utils import MemoryPositions, calc_mem_gas, mkalphanum +from vyper.semantics.types.module import ModuleT +from vyper.utils import MemoryPositions, calc_mem_gas @dataclass @@ -44,7 +44,14 @@ def exit_sequence_label(self) -> str: @cached_property def ir_identifier(self) -> str: argz = ",".join([str(argtyp) for argtyp in self.func_t.argument_types]) - return mkalphanum(f"{self.visibility} {self.func_t.name} ({argz})") + + name = self.func_t.name + function_id = self.func_t._function_id + assert function_id is not None + + # include module id in the ir identifier to disambiguate functions + # with the same name but which come from different modules + return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: if self.frame_info is not None: @@ -94,7 +101,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False + code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -103,7 +110,7 @@ def generate_ir_for_function( - Clamping and copying of arguments - Function body """ - func_t = code._metadata["type"] + func_t = code._metadata["func_type"] # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) @@ -126,7 +133,7 @@ def generate_ir_for_function( context = Context( vars_=None, - global_ctx=global_ctx, + module_ctx=module_ctx, memory_allocator=memory_allocator, constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant, func_t=func_t, diff --git a/vyper/codegen/global_context.py b/vyper/codegen/global_context.py deleted file mode 100644 index 1f6783f6f8..0000000000 --- a/vyper/codegen/global_context.py +++ /dev/null @@ -1,32 +0,0 @@ -from functools import cached_property -from typing import Optional - -from vyper import ast as vy_ast - - -# Datatype to store all global context information. -# TODO: rename me to ModuleT -class GlobalContext: - def __init__(self, module: Optional[vy_ast.Module] = None): - self._module = module - - @cached_property - def functions(self): - return self._module.get_children(vy_ast.FunctionDef) - - @cached_property - def variables(self): - # variables that this module defines, ex. - # `x: uint256` is a private storage variable named x - if self._module is None: # TODO: make self._module never be None - return None - variable_decls = self._module.get_children(vy_ast.VariableDecl) - return {s.target.id: s.target._metadata["varinfo"] for s in variable_decls} - - @property - def immutables(self): - return [t for t in self.variables.values() if t.is_immutable] - - @cached_property - def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index bfdafa8ba9..ef861e3953 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -5,49 +5,67 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic -from vyper.utils import method_id_int +from vyper.semantics.types.module import ModuleT +from vyper.utils import OrderedSet, method_id_int -def _topsort_helper(functions, lookup): - # single pass to get a global topological sort of functions (so that each - # function comes after each of its callees). may have duplicates, which get - # filtered out in _topsort() +def _topsort(functions): + # single pass to get a global topological sort of functions (so that each + # function comes after each of its callees). + ret = OrderedSet() + for func_ast in functions: + fn_t = func_ast._metadata["func_type"] + + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t.ast_def) + + ret.add(func_ast) + + # create globally unique IDs for each function + for idx, func in enumerate(ret): + func._metadata["func_type"]._function_id = idx + + return list(ret) + - ret = [] +# calculate globally reachable functions to see which +# ones should make it into the final bytecode. +# TODO: in the future, this should get obsolesced by IR dead code eliminator. +def _globally_reachable_functions(functions): + ret = OrderedSet() for f in functions: - # called_functions is a list of ContractFunctions, need to map - # back to FunctionDefs. - callees = [lookup[t.name] for t in f._metadata["type"].called_functions] - ret.extend(_topsort_helper(callees, lookup)) - ret.append(f) + fn_t = f._metadata["func_type"] - return ret + if not fn_t.is_external: + continue + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t) -def _topsort(functions): - lookup = {f.name: f for f in functions} - # strip duplicates - return list(dict.fromkeys(_topsort_helper(functions, lookup))) + ret.add(fn_t) + + return ret def _is_constructor(func_ast): - return func_ast._metadata["type"].is_constructor + return func_ast._metadata["func_type"].is_constructor def _is_fallback(func_ast): - return func_ast._metadata["type"].is_fallback + return func_ast._metadata["func_type"].is_fallback def _is_internal(func_ast): - return func_ast._metadata["type"].is_internal + return func_ast._metadata["func_type"].is_internal def _is_payable(func_ast): - return func_ast._metadata["type"].is_payable + return func_ast._metadata["func_type"].is_payable def _annotated_method_id(abi_sig): @@ -63,7 +81,7 @@ def label_for_entry_point(abi_sig, entry_point): # adapt whatever generate_ir_for_function gives us into an IR node def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): - func_t = func_ast._metadata["type"] + func_t = func_ast._metadata["func_type"] assert func_t.is_fallback or func_t.is_constructor ret = ["seq"] @@ -86,12 +104,12 @@ def _ir_for_internal_function(func_ast, *args, **kwargs): return generate_ir_for_function(func_ast, *args, **kwargs).func_ir -def _generate_external_entry_points(external_functions, global_ctx): +def _generate_external_entry_points(external_functions, module_ctx): entry_points = {} # map from ABI sigs to ir code sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, global_ctx) + func_ir = generate_ir_for_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -113,13 +131,13 @@ def _generate_external_entry_points(external_functions, global_ctx): # into a bucket (of about 8-10 items), and then uses perfect hash # to select the final function. # costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) -def _selector_section_dense(external_functions, global_ctx): +def _selector_section_dense(external_functions, module_ctx): function_irs = [] if len(external_functions) == 0: return IRnode.from_list(["seq"]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) # generate the label so the jumptable works for abi_sig, entry_point in entry_points.items(): @@ -264,13 +282,13 @@ def _selector_section_dense(external_functions, global_ctx): # a bucket, and then descends into linear search from there. # costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) # function and 24 bytes of code (+ ~23 bytes of global overhead) -def _selector_section_sparse(external_functions, global_ctx): +def _selector_section_sparse(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) @@ -367,14 +385,14 @@ def _selector_section_sparse(external_functions, global_ctx): # O(n) linear search for the method id # mainly keep this in for backends which cannot handle the indirect jump # in selector_section_dense and selector_section_sparse -def _selector_section_linear(external_functions, global_ctx): +def _selector_section_linear(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) dispatcher = ["seq"] @@ -402,10 +420,11 @@ def _selector_section_linear(external_functions, global_ctx): return ret -# take a GlobalContext, and generate the runtime and deploy IR -def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: +# take a ModuleT, and generate the runtime and deploy IR +def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees - function_defs = _topsort(global_ctx.functions) + function_defs = _topsort(module_ctx.function_defs) + reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) @@ -421,20 +440,26 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # compile internal functions first so we have the function info for func_ast in internal_functions: - func_ir = _ir_for_internal_function(func_ast, global_ctx, False) - internal_functions_ir.append(IRnode.from_list(func_ir)) + # compile it so that _ir_info is populated (whether or not it makes + # it into the final IR artifact) + func_ir = _ir_for_internal_function(func_ast, module_ctx, False) + + # only include it in the IR if it is reachable from an external + # function. + if func_ast._metadata["func_type"] in reachable: + internal_functions_ir.append(IRnode.from_list(func_ir)) if core._opt_none(): - selector_section = _selector_section_linear(external_functions, global_ctx) + selector_section = _selector_section_linear(external_functions, module_ctx) # dense vs sparse global overhead is amortized after about 4 methods. # (--debug will force dense selector table anyway if _opt_codesize is selected.) elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): - selector_section = _selector_section_dense(external_functions, global_ctx) + selector_section = _selector_section_dense(external_functions, module_ctx) else: - selector_section = _selector_section_sparse(external_functions, global_ctx) + selector_section = _selector_section_sparse(external_functions, module_ctx) if default_function: - fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx) + fallback_ir = _ir_for_fallback_or_ctor(default_function, module_ctx) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -447,29 +472,30 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = global_ctx.immutable_section_bytes + immutables_len = module_ctx.immutable_section_bytes if init_function: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` + init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] internal_functions = [f for f in runtime_functions if _is_internal(f)] for f in internal_functions: - init_func_t = init_function._metadata["type"] - if f.name not in init_func_t.recursive_calls: + func_t = f._metadata["func_type"] + if func_t not in init_func_t.reachable_internal_functions: # unreachable code, delete it continue - func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) + func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables # note: (deploy mem_ofst, code, extra_padding) - init_mem_used = init_function._metadata["type"]._ir_info.frame_info.mem_used + init_mem_used = init_function._metadata["func_type"]._ir_info.frame_info.mem_used # force msize to be initialized past the end of immutables section # so that builtins which use `msize` for "dynamic" memory diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f03f2eb9c8..f53e4a81b4 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -4,15 +4,6 @@ from vyper.exceptions import StateAccessViolation from vyper.semantics.types.subscriptable import TupleT -_label_counter = 0 - - -# TODO a more general way of doing this -def _generate_label(name: str) -> str: - global _label_counter - _label_counter += 1 - return f"label{_label_counter}" - def _align_kwargs(func_t, args_ir): """ @@ -63,7 +54,7 @@ def ir_for_self_call(stmt_expr, context): # note: internal_function_label asserts `func_t.is_internal`. _label = func_t._ir_info.internal_function_label(context.is_ctor_context) - return_label = _generate_label(f"{_label}_call") + return_label = _freshname(f"{_label}_call") # allocate space for the return buffer # TODO allocate in stmt and/or expr.py diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 254cad32e6..cc7a603b7c 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -26,6 +26,7 @@ from vyper.evm.address_space import MEMORY, STORAGE from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure from vyper.semantics.types import DArrayT, MemberFunctionT +from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -117,44 +118,32 @@ def parse_Log(self): return events.ir_node_for_log(self.stmt, event, topic_ir, data_ir, self.context) def parse_Call(self): - # TODO use expr.func.type.is_internal once type annotations - # are consistently available. - is_self_function = ( - (isinstance(self.stmt.func, vy_ast.Attribute)) - and isinstance(self.stmt.func.value, vy_ast.Name) - and self.stmt.func.value.id == "self" - ) - if isinstance(self.stmt.func, vy_ast.Name): funcname = self.stmt.func.id return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context) - elif isinstance(self.stmt.func, vy_ast.Attribute) and self.stmt.func.attr in ( - "append", - "pop", - ): - func_type = self.stmt.func._metadata["type"] - if isinstance(func_type, MemberFunctionT): - darray = Expr(self.stmt.func.value, self.context).ir_node - args = [Expr(x, self.context).ir_node for x in self.stmt.args] - if self.stmt.func.attr == "append": - # sanity checks - assert len(args) == 1 - arg = args[0] - assert isinstance(darray.typ, DArrayT) - check_assign( - dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) - ) - - return append_dyn_array(darray, arg) - else: - assert len(args) == 0 - return pop_dyn_array(darray, return_popped_item=False) - - if is_self_function: - return self_call.ir_for_self_call(self.stmt, self.context) - else: - return external_call.ir_for_external_call(self.stmt, self.context) + func_type = self.stmt.func._metadata["type"] + + if isinstance(func_type, MemberFunctionT) and self.stmt.func.attr in ("append", "pop"): + darray = Expr(self.stmt.func.value, self.context).ir_node + args = [Expr(x, self.context).ir_node for x in self.stmt.args] + if self.stmt.func.attr == "append": + (arg,) = args + assert isinstance(darray.typ, DArrayT) + check_assign( + dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) + ) + + return append_dyn_array(darray, arg) + else: + assert len(args) == 0 + return pop_dyn_array(darray, return_popped_item=False) + + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.stmt, self.context) + else: + return external_call.ir_for_external_call(self.stmt, self.context) def _assert_reason(self, test_expr, msg): # from parse_Raise: None passed as the assert condition diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 61d7a7c229..026c8369c5 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -5,7 +5,7 @@ import vyper.ast as vy_ast # break an import cycle import vyper.codegen.core as codegen import vyper.compiler.output as output -from vyper.compiler.input_bundle import InputBundle, PathLike +from vyper.compiler.input_bundle import FileInput, InputBundle, PathLike from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version @@ -44,10 +44,8 @@ UNKNOWN_CONTRACT_NAME = "" -def compile_code( - contract_source: str, - contract_name: str = UNKNOWN_CONTRACT_NAME, - source_id: int = 0, +def compile_from_file_input( + file_input: FileInput, input_bundle: InputBundle = None, settings: Settings = None, output_formats: Optional[OutputFormats] = None, @@ -58,6 +56,8 @@ def compile_code( experimental_codegen: bool = False, ) -> dict: """ + Main entry point into the compiler. + Generate consumable compiler output(s) from a single contract source code. Basically, a wrapper around CompilerData which munges the output data into the requested output formats. @@ -72,6 +72,8 @@ def compile_code( evm_version: str, optional The target EVM ruleset to compile for. If not given, defaults to the latest implemented ruleset. + source_id: int, optional + source_id to tag AST nodes with. -1 if not provided. settings: Settings, optional Compiler settings. show_gas_estimates: bool, optional @@ -96,11 +98,11 @@ def compile_code( # make IR output the same between runs codegen.reset_names() + # TODO: maybe at this point we might as well just pass a `FileInput` + # directly to `CompilerData`. compiler_data = CompilerData( - contract_source, + file_input, input_bundle, - Path(contract_name), - source_id, settings, storage_layout_override, show_gas_estimates, @@ -118,8 +120,33 @@ def compile_code( ret[output_format] = formatter(compiler_data) except Exception as exc: if exc_handler is not None: - exc_handler(contract_name, exc) + exc_handler(str(file_input.path), exc) else: raise exc return ret + + +def compile_code( + source_code: str, + contract_path: str | PathLike = UNKNOWN_CONTRACT_NAME, + source_id: int = -1, + resolved_path: PathLike | None = None, + *args, + **kwargs, +): + # this function could be renamed to compile_from_string + """ + Do the same thing as compile_from_file_input but takes a string for source + code. This was previously the main entry point into the compiler + # (`compile_from_file_input()` is newer) + """ + if isinstance(contract_path, str): + contract_path = Path(contract_path) + file_input = FileInput( + source_id=source_id, + source_code=source_code, + path=contract_path, + resolved_path=resolved_path or contract_path, # type: ignore + ) + return compile_from_file_input(file_input, *args, **kwargs) diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index 1e41c3f137..27170f0a56 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -15,15 +15,11 @@ class CompilerInput: # an input to the compiler, basically an abstraction for file contents source_id: int - path: PathLike + path: PathLike # the path that was asked for - @staticmethod - def from_string(source_id: int, path: PathLike, file_contents: str) -> "CompilerInput": - try: - s = json.loads(file_contents) - return ABIInput(source_id, path, s) - except (ValueError, TypeError): - return FileInput(source_id, path, file_contents) + # resolved_path is the real path that was resolved to. + # mainly handy for debugging at this point + resolved_path: PathLike @dataclass @@ -40,13 +36,16 @@ class ABIInput(CompilerInput): abi: Any # something that json.load() returns -class _NotFound(Exception): - pass +def try_parse_abi(file_input: FileInput) -> CompilerInput: + try: + s = json.loads(file_input.source_code) + return ABIInput(file_input.source_id, file_input.path, file_input.resolved_path, s) + except (ValueError, TypeError): + return file_input -# wrap os.path.normpath, but return the same type as the input -def _normpath(path): - return path.__class__(os.path.normpath(path)) +class _NotFound(Exception): + pass # an "input bundle" to the compiler, representing the files which are @@ -60,20 +59,31 @@ class InputBundle: # a list of search paths search_paths: list[PathLike] + _cache: Any + def __init__(self, search_paths): self.search_paths = search_paths self._source_id_counter = 0 self._source_ids: dict[PathLike, int] = {} - def _load_from_path(self, path): + # this is a little bit cursed, but it allows consumers to cache data that + # share the same lifetime as this input bundle. + self._cache = lambda: None + + def _normalize_path(self, path): + raise NotImplementedError(f"not implemented! {self.__class__}._normalize_path()") + + def _load_from_path(self, resolved_path, path): raise NotImplementedError(f"not implemented! {self.__class__}._load_from_path()") - def _generate_source_id(self, path: PathLike) -> int: - if path not in self._source_ids: - self._source_ids[path] = self._source_id_counter + def _generate_source_id(self, resolved_path: PathLike) -> int: + # Note: it is possible for a file to get in here more than once, + # e.g. by symlink + if resolved_path not in self._source_ids: + self._source_ids[resolved_path] = self._source_id_counter self._source_id_counter += 1 - return self._source_ids[path] + return self._source_ids[resolved_path] def load_file(self, path: PathLike | str) -> CompilerInput: # search path precedence @@ -84,12 +94,9 @@ def load_file(self, path: PathLike | str) -> CompilerInput: # Path("/a") / Path("/b") => Path("/b") to_try = sp / path - # normalize the path with os.path.normpath, to break down - # things like "foo/bar/../x.vy" => "foo/x.vy", with all - # the caveats around symlinks that os.path.normpath comes with. - to_try = _normpath(to_try) try: - res = self._load_from_path(to_try) + to_try = self._normalize_path(to_try) + res = self._load_from_path(to_try, path) break except _NotFound: tried.append(to_try) @@ -104,7 +111,7 @@ def load_file(self, path: PathLike | str) -> CompilerInput: # try to parse from json, so that return types are consistent # across FilesystemInputBundle and JSONInputBundle. if isinstance(res, FileInput): - return CompilerInput.from_string(res.source_id, res.path, res.source_code) + res = try_parse_abi(res) return res @@ -126,20 +133,45 @@ def search_path(self, path: Optional[PathLike]) -> Iterator[None]: finally: self.search_paths.pop() + # temporarily modify the top of the search path (within the + # scope of the context manager) with highest precedence to something else + @contextlib.contextmanager + def poke_search_path(self, path: PathLike) -> Iterator[None]: + tmp = self.search_paths[-1] + self.search_paths[-1] = path + try: + yield + finally: + self.search_paths[-1] = tmp + # regular input. takes a search path(s), and `load_file()` will search all # search paths for the file and read it from the filesystem class FilesystemInputBundle(InputBundle): - def _load_from_path(self, path: Path) -> CompilerInput: + def _normalize_path(self, path: Path) -> Path: + # normalize the path with os.path.normpath, to break down + # things like "foo/bar/../x.vy" => "foo/x.vy", with all + # the caveats around symlinks that os.path.normpath comes with. try: - with path.open() as f: - code = f.read() - except FileNotFoundError: + return path.resolve(strict=True) + except (FileNotFoundError, NotADirectoryError): raise _NotFound(path) - source_id = super()._generate_source_id(path) + def _load_from_path(self, resolved_path: Path, original_path: Path) -> CompilerInput: + try: + with resolved_path.open() as f: + code = f.read() + except (FileNotFoundError, NotADirectoryError): + raise _NotFound(resolved_path) + + source_id = super()._generate_source_id(resolved_path) + + return FileInput(source_id, original_path, resolved_path, code) - return FileInput(source_id, path, code) + +# wrap os.path.normpath, but return the same type as the input +def _normpath(path): + return path.__class__(os.path.normpath(path)) # fake filesystem for JSON inputs. takes a base path, and `load_file()` @@ -156,25 +188,28 @@ def __init__(self, input_json, search_paths): # should be checked by caller assert path not in self.input_json - self.input_json[_normpath(path)] = item + self.input_json[path] = item + + def _normalize_path(self, path: PurePath) -> PurePath: + return _normpath(path) - def _load_from_path(self, path: PurePath) -> CompilerInput: + def _load_from_path(self, resolved_path: PurePath, original_path: PurePath) -> CompilerInput: try: - value = self.input_json[path] + value = self.input_json[resolved_path] except KeyError: - raise _NotFound(path) + raise _NotFound(resolved_path) - source_id = super()._generate_source_id(path) + source_id = super()._generate_source_id(resolved_path) if "content" in value: - return FileInput(source_id, path, value["content"]) + return FileInput(source_id, original_path, resolved_path, value["content"]) if "abi" in value: - return ABIInput(source_id, path, value["abi"]) + return ABIInput(source_id, original_path, resolved_path, value["abi"]) # TODO: ethPM support # if isinstance(contents, dict) and "contractTypes" in contents: # unreachable, based on how JSONInputBundle is constructed in # the codebase. - raise JSONError(f"Unexpected type in file: '{path}'") # pragma: nocover + raise JSONError(f"Unexpected type in file: '{resolved_path}'") # pragma: nocover diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index e47f300ba9..6d1e7ef70f 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict, deque +from pathlib import PurePath import asttokens @@ -33,8 +34,8 @@ def build_userdoc(compiler_data: CompilerData) -> dict: def build_external_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"] - stem = compiler_data.contract_path.stem + interface = compiler_data.vyper_module_folded._metadata["type"].interface + stem = PurePath(compiler_data.contract_path).stem # capitalize words separated by '_' # ex: test_interface.vy -> TestInterface name = "".join([x.capitalize() for x in stem.split("_")]) @@ -52,7 +53,7 @@ def build_external_interface_output(compiler_data: CompilerData) -> str: def build_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"] + interface = compiler_data.vyper_module_folded._metadata["type"].interface out = "" if interface.events: @@ -70,7 +71,7 @@ def build_interface_output(compiler_data: CompilerData) -> str: out = f"{out}@{func.mutability.value}\n" args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments]) return_value = f" -> {func.return_type}" if func.return_type is not None else "" - out = f"{out}@external\ndef {func.name}({args}){return_value}:\n pass\n\n" + out = f"{out}@external\ndef {func.name}({args}){return_value}:\n ...\n\n" return out @@ -154,14 +155,19 @@ def _to_dict(func_t): def build_method_identifiers_output(compiler_data: CompilerData) -> dict: - interface = compiler_data.vyper_module_folded._metadata["type"] - functions = interface.functions.values() + module_t = compiler_data.vyper_module_folded._metadata["type"] + functions = module_t.function_defs - return {k: hex(v) for func in functions for k, v in func.method_ids.items()} + return { + k: hex(v) for func in functions for k, v in func._metadata["func_type"].method_ids.items() + } def build_abi_output(compiler_data: CompilerData) -> list: - abi = compiler_data.vyper_module_folded._metadata["type"].to_toplevel_abi_dict() + module_t = compiler_data.vyper_module_folded._metadata["type"] + _ = compiler_data.ir_runtime # ensure _ir_info is generated + + abi = module_t.interface.to_toplevel_abi_dict() if compiler_data.show_gas_estimates: # Add gas estimates for each function to ABI gas_estimates = build_gas_estimates(compiler_data.function_signatures) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4e32812fee..edffa9a85e 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -7,18 +7,18 @@ from vyper import ast as vy_ast from vyper.codegen import module from vyper.codegen.core import anchor_opt_level -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle +from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout from vyper.venom import generate_assembly_experimental, generate_ir -DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") +DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") class CompilerData: @@ -35,7 +35,7 @@ class CompilerData: Top-level Vyper AST node vyper_module_folded : vy_ast.Module Folded Vyper AST - global_ctx : GlobalContext + global_ctx : ModuleT Sorted, contextualized representation of the Vyper AST ir_nodes : IRnode IR used to generate deployment bytecode @@ -53,10 +53,8 @@ class CompilerData: def __init__( self, - source_code: str, + file_input: FileInput | str, input_bundle: InputBundle = None, - contract_path: Path | PurePath = DEFAULT_CONTRACT_NAME, - source_id: int = 0, settings: Settings = None, storage_layout: StorageLayout = None, show_gas_estimates: bool = False, @@ -68,12 +66,10 @@ def __init__( Arguments --------- - source_code: str - Vyper source code. - contract_path: Path, optional - The name of the contract being compiled. - source_id: int, optional - ID number used to identify this contract in the source map. + file_input: FileInput | str + A FileInput or string representing the input to the compiler. + FileInput is preferred, but `str` is accepted as a convenience + method (and also for backwards compatibility reasons) settings: Settings Set optimization mode. show_gas_estimates: bool, optional @@ -85,9 +81,15 @@ def __init__( """ # to force experimental codegen, uncomment: # experimental_codegen = True - self.contract_path = contract_path - self.source_code = source_code - self.source_id = source_id + + if isinstance(file_input, str): + file_input = FileInput( + source_code=file_input, + source_id=-1, + path=DEFAULT_CONTRACT_PATH, + resolved_path=DEFAULT_CONTRACT_PATH, + ) + self.file_input = file_input self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata @@ -97,10 +99,26 @@ def __init__( _ = self._generate_ast # force settings to be calculated + @cached_property + def source_code(self): + return self.file_input.source_code + + @cached_property + def source_id(self): + return self.file_input.source_id + + @cached_property + def contract_path(self): + return self.file_input.path + @cached_property def _generate_ast(self): - contract_name = str(self.contract_path) - settings, ast = generate_ast(self.source_code, self.source_id, contract_name) + settings, ast = vy_ast.parse_to_ast_with_settings( + self.source_code, + self.source_id, + module_path=str(self.contract_path), + resolved_path=str(self.file_input.resolved_path), + ) # validate the compiler settings # XXX: this is a bit ugly, clean up later @@ -141,12 +159,12 @@ def vyper_module_unfolded(self) -> vy_ast.Module: # This phase is intended to generate an AST for tooling use, and is not # used in the compilation process. - return generate_unfolded_ast(self.contract_path, self.vyper_module, self.input_bundle) + return generate_unfolded_ast(self.vyper_module, self.input_bundle) @cached_property def _folded_module(self): return generate_folded_ast( - self.contract_path, self.vyper_module, self.input_bundle, self.storage_layout_override + self.vyper_module, self.input_bundle, self.storage_layout_override ) @property @@ -160,8 +178,8 @@ def storage_layout(self) -> StorageLayout: return storage_layout @property - def global_ctx(self) -> GlobalContext: - return GlobalContext(self.vyper_module_folded) + def global_ctx(self) -> ModuleT: + return self.vyper_module_folded._metadata["type"] @cached_property def _ir_output(self): @@ -189,7 +207,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: _ = self._ir_output fs = self.vyper_module_folded.get_children(vy_ast.FunctionDef) - return {f.name: f._metadata["type"] for f in fs} + return {f.name: f._metadata["func_type"] for f in fs} @cached_property def assembly(self) -> list: @@ -230,37 +248,12 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_ast( - source_code: str, source_id: int, contract_name: str -) -> tuple[Settings, vy_ast.Module]: - """ - Generate a Vyper AST from source code. - - Arguments - --------- - source_code : str - Vyper source code. - source_id : int - ID number used to identify this contract in the source map. - contract_name: str - Name of the contract. - - Returns - ------- - vy_ast.Module - Top-level Vyper AST node - """ - return vy_ast.parse_to_ast_with_settings(source_code, source_id, contract_name) - - # destructive -- mutates module in place! -def generate_unfolded_ast( - contract_path: Path | PurePath, vyper_module: vy_ast.Module, input_bundle: InputBundle -) -> vy_ast.Module: +def generate_unfolded_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) vy_ast.folding.replace_builtin_functions(vyper_module) - with input_bundle.search_path(contract_path.parent): + with input_bundle.search_path(Path(vyper_module.resolved_path).parent): # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) @@ -268,7 +261,6 @@ def generate_unfolded_ast( def generate_folded_ast( - contract_path: Path, vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, @@ -294,7 +286,7 @@ def generate_folded_ast( vyper_module_folded = copy.deepcopy(vyper_module) vy_ast.folding.fold(vyper_module_folded) - with input_bundle.search_path(contract_path.parent): + with input_bundle.search_path(Path(vyper_module.resolved_path).parent): validate_semantics(vyper_module_folded, input_bundle) symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) @@ -302,9 +294,7 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes( - global_ctx: GlobalContext, optimize: OptimizationLevel -) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. @@ -315,7 +305,7 @@ def generate_ir_nodes( Arguments --------- - global_ctx : GlobalContext + global_ctx: ModuleT Contextualized Vyper AST Returns diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3bde20356e..993c0a85eb 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -49,6 +49,7 @@ def __init__(self, message="Error Message not found.", *items): self.message = message self.lineno = None self.col_offset = None + self.annotations = None if len(items) == 1 and isinstance(items[0], tuple) and isinstance(items[0][0], int): # support older exceptions that don't annotate - remove this in the future! @@ -79,7 +80,7 @@ def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code - if not hasattr(self, "annotations"): + if not self.annotations: if self.lineno is not None and self.col_offset is not None: return f"line {self.lineno}:{self.col_offset} {self.message}" else: @@ -105,8 +106,9 @@ def __str__(self): if isinstance(node, vy_ast.VyperNode): module_node = node.get_ancestor(vy_ast.Module) - if module_node.get("name") not in (None, ""): - node_msg = f'{node_msg}contract "{module_node.name}:{node.lineno}", ' + + if module_node.get("path") not in (None, ""): + node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", ' fn_node = node.get_ancestor(vy_ast.FunctionDef) if fn_node: @@ -229,6 +231,18 @@ class CallViolation(VyperException): """Illegal function call.""" +class ImportCycle(VyperException): + """An import cycle""" + + +class DuplicateImport(VyperException): + """A module was imported twice from the same module""" + + +class ModuleNotFound(VyperException): + """Module was not found""" + + class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7db230167e..7b52a68e92 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,17 +1,4 @@ -import vyper.ast as vy_ast - from .. import types # break a dependency cycle. -from ..namespace import get_namespace -from .local import validate_functions -from .module import add_module_namespace -from .utils import _ExprAnalyser - - -def validate_semantics(vyper_ast, input_bundle): - # validate semantics and annotate AST with type/semantics information - namespace = get_namespace() +from .module import validate_semantics - with namespace.enter_scope(): - add_module_namespace(vyper_ast, input_bundle) - vy_ast.expansion.expand_annotated_ast(vyper_ast) - validate_functions(vyper_ast) +__all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 449e6ca338..4d1b1cdbab 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,8 +1,9 @@ import enum from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from vyper import ast as vy_ast +from vyper.compiler.input_bundle import InputBundle from vyper.exceptions import ( CompilerPanic, ImmutableViolation, @@ -12,6 +13,9 @@ from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +if TYPE_CHECKING: + from vyper.semantics.types.module import InterfaceT, ModuleT + class _StringEnum(enum.Enum): @staticmethod @@ -145,6 +149,35 @@ def __repr__(self): return f"" +# base class for things that are the "result" of analysis +class AnalysisResult: + pass + + +@dataclass +class ModuleInfo(AnalysisResult): + module_t: "ModuleT" + + @property + def module_node(self): + return self.module_t._module + + # duck type, conform to interface of VarInfo and ExprInfo + @property + def typ(self): + return self.module_t + + +@dataclass +class ImportInfo(AnalysisResult): + typ: Union[ModuleInfo, "InterfaceT"] + alias: str # the name in the namespace + qualified_module_name: str # for error messages + # source_id: int + input_bundle: InputBundle + node: vy_ast.VyperNode + + @dataclass class VarInfo: """ @@ -212,6 +245,10 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": is_immutable=var_info.is_immutable, ) + @classmethod + def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": + return cls(module_info.module_t) + def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 507eb0a570..9d35aef2bd 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -1,6 +1,17 @@ +import contextlib from typing import Tuple -from vyper.exceptions import StructureException +from vyper.exceptions import StructureException, VyperException + + +@contextlib.contextmanager +def tag_exceptions(node): + try: + yield + except VyperException as e: + if not e.annotations and not e.lineno: + raise e.with_annotation(node) from None + raise e from None class VyperNodeVisitorBase: @@ -16,9 +27,11 @@ def visit(self, node, *args): # node types with a shared parent for class_ in node.__class__.mro(): ast_type = class_.__name__ - visitor_fn = getattr(self, f"visit_{ast_type}", None) - if visitor_fn: - return visitor_fn(node, *args) + + with tag_exceptions(node): + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + return visitor_fn(node, *args) node_type = type(node).__name__ raise StructureException( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 87ec45c40d..88679a4b09 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -79,7 +79,7 @@ def set_storage_slots_with_overrides( # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["type"] + type_ = node._metadata["func_type"] # Ignore functions without non-reentrant if type_.nonreentrant is None: @@ -165,7 +165,7 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: ret: Dict[str, Dict] = {} for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["type"] + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue diff --git a/vyper/semantics/analysis/import_graph.py b/vyper/semantics/analysis/import_graph.py new file mode 100644 index 0000000000..e406878194 --- /dev/null +++ b/vyper/semantics/analysis/import_graph.py @@ -0,0 +1,37 @@ +import contextlib +from dataclasses import dataclass, field +from typing import Iterator + +from vyper import ast as vy_ast +from vyper.exceptions import CompilerPanic, ImportCycle + +""" +data structure for collecting import statements and validating the +import graph +""" + + +@dataclass +class ImportGraph: + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + def push_path(self, module_ast: vy_ast.Module) -> None: + if module_ast in self._path: + cycle = self._path + [module_ast] + raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) + + self._path.append(module_ast) + + def pop_path(self, expected: vy_ast.Module) -> None: + popped = self._path.pop() + if expected != popped: + raise CompilerPanic("unreachable") + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 647f01c299..974c14f261 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -55,14 +55,15 @@ def validate_functions(vy_module: vy_ast.Module) -> None: - """Analyzes a vyper ast and validates the function-level namespaces.""" + """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() namespace = get_namespace() for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - FunctionNodeVisitor(vy_module, node, namespace) + analyzer = FunctionNodeVisitor(vy_module, node, namespace) + analyzer.analyze() except VyperException as e: err_list.append(e) @@ -185,26 +186,31 @@ def __init__( self.vyper_module = vyper_module self.fn_node = fn_node self.namespace = namespace - self.func = fn_node._metadata["type"] + self.func = fn_node._metadata["func_type"] self.expr_visitor = _ExprVisitor(self.func) + def analyze(self): # allow internal function params to be mutable location, is_immutable = ( (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) ) for arg in self.func.arguments: - namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable) + self.namespace[arg.name] = VarInfo( + arg.typ, location=location, is_immutable=is_immutable + ) - for node in fn_node.body: + for node in self.fn_node.body: self.visit(node) + if self.func.return_type: - if not check_for_terminus(fn_node.body): + if not check_for_terminus(self.fn_node.body): raise FunctionDeclarationException( - f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node + f"Missing or unmatched return statements in function '{self.fn_node.name}'", + self.fn_node, ) # visit default args - assert self.func.n_keyword_args == len(fn_node.args.defaults) + assert self.func.n_keyword_args == len(self.fn_node.args.defaults) for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @@ -224,10 +230,7 @@ def visit_AnnAssign(self, node): typ = type_from_annotation(node.annotation, DataLocation.MEMORY) validate_expected_type(node.value, typ) - try: - self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) self.expr_visitor.visit(node.target, typ) self.expr_visitor.visit(node.value, typ) @@ -290,6 +293,13 @@ def visit_Continue(self, node): raise StructureException("`continue` must be enclosed in a `for` loop", node) def visit_Expr(self, node): + if isinstance(node.value, vy_ast.Ellipsis): + raise StructureException( + "`...` is not allowed in `.vy` files! " + "Did you mean to import me as a `.vyi` file?", + node, + ) + if not isinstance(node.value, vy_ast.Call): raise StructureException("Expressions without assignment are disallowed", node) @@ -433,6 +443,7 @@ def visit_For(self, node): # Check if `iter` is a storage variable. get_descendants` is used to check for # nested `self` (e.g. structs) + # NOTE: this analysis will be borked once stateful modules are allowed! iter_is_storage_var = ( isinstance(node.iter, vy_ast.Attribute) and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 @@ -453,8 +464,11 @@ def visit_For(self, node): call_node, ) - for name in self.namespace["self"].typ.members[fn_name].recursive_calls: + for reachable_t in ( + self.namespace["self"].typ.members[fn_name].reachable_internal_functions + ): # check for indirect modification + name = reachable_t.name fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] if _check_iterator_modification(node.iter, fn_node): raise ImmutableViolation( @@ -472,10 +486,7 @@ def visit_For(self, node): # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): - try: - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) try: with NodeMetadata.enter_typechecker_speculation(): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 239438f35b..7aa661aec3 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,6 +1,6 @@ import os from pathlib import Path, PurePath -from typing import Optional +from typing import Any, Optional import vyper.builtins.interfaces from vyper import ast as vy_ast @@ -8,9 +8,11 @@ from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, + DuplicateImport, ExceptionList, InvalidLiteral, InvalidType, + ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, @@ -18,128 +20,200 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase -from vyper.semantics.analysis.utils import check_constant, validate_expected_type +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.local import validate_functions +from vyper.semantics.analysis.utils import ( + check_constant, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation -from vyper.semantics.namespace import Namespace, get_namespace +from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation -def add_module_namespace(vy_module: vy_ast.Module, input_bundle: InputBundle) -> None: +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + +def validate_semantics_r( + module_ast: vy_ast.Module, + input_bundle: InputBundle, + import_graph: ImportGraph, + is_interface: bool, +) -> ModuleT: """ Analyze a Vyper module AST node, add all module-level objects to the - namespace and validate top-level correctness + namespace, type-check/validate semantics and annotate with type and analysis info """ - + # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - ModuleAnalyzer(vy_module, input_bundle, namespace) + with namespace.enter_scope(), import_graph.enter_path(module_ast): + analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) + ret = analyzer.analyze() + + vy_ast.expansion.generate_public_variable_getters(module_ast) + + # if this is an interface, the function is already validated + # in `ContractFunction.from_vyi()` + if not is_interface: + validate_functions(module_ast) + + return ret + + +# compute reachable set and validate the call graph (detect cycles) +def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT] = None) -> None: + path = path or [] + + path.append(fn_t) + root = path[0] -def _find_cyclic_call(fn_names: list, self_members: dict) -> Optional[list]: - if fn_names[-1] not in self_members: - return None - internal_calls = self_members[fn_names[-1]].internal_calls - for name in internal_calls: - if name in fn_names: - return fn_names + [name] - sequence = _find_cyclic_call(fn_names + [name], self_members) - if sequence: - return sequence - return None + for g in fn_t.called_functions: + if g == root: + message = " -> ".join([f.name for f in path]) + raise CallViolation(f"Contract contains cyclic function call: {message}") + + _compute_reachable_set(g, path=path) + + for h in g.reachable_internal_functions: + assert h != fn_t # sanity check + + fn_t.reachable_internal_functions.add(h) + + fn_t.reachable_internal_functions.add(g) + + path.pop() class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace + self, + module_node: vy_ast.Module, + input_bundle: InputBundle, + namespace: Namespace, + import_graph: ImportGraph, + is_interface: bool = False, ) -> None: self.ast = module_node self.input_bundle = input_bundle self.namespace = namespace + self._import_graph = import_graph + self.is_interface = is_interface - # TODO: Move computation out of constructor - module_nodes = module_node.body.copy() - while module_nodes: - count = len(module_nodes) + # keep track of imported modules to prevent duplicate imports + self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} + + self.module_t: Optional[ModuleT] = None + + # ast cache, hitchhike onto the input_bundle object + if not hasattr(self.input_bundle._cache, "_ast_of"): + self.input_bundle._cache._ast_of: dict[int, vy_ast.Module] = {} # type: ignore + + def analyze(self) -> ModuleT: + # generate a `ModuleT` from the top-level node + # note: also validates unique method ids + if "type" in self.ast._metadata: + assert isinstance(self.ast._metadata["type"], ModuleT) + # we don't need to analyse again, skip out + self.module_t = self.ast._metadata["type"] + return self.module_t + + to_visit = self.ast.body.copy() + + # handle imports linearly + # (do this instead of handling in the next block so that + # `self._imported_modules` does not end up with garbage in it after + # exception swallowing). + import_stmts = self.ast.get_children((vy_ast.Import, vy_ast.ImportFrom)) + for node in import_stmts: + self.visit(node) + to_visit.remove(node) + + # keep trying to process all the nodes until we finish or can + # no longer progress. this makes it so we don't need to + # calculate a dependency tree between top-level items. + while len(to_visit) > 0: + count = len(to_visit) err_list = ExceptionList() - for node in list(module_nodes): + for node in to_visit.copy(): try: self.visit(node) - module_nodes.remove(node) - except (InvalidLiteral, InvalidType, VariableDeclarationException): + to_visit.remove(node) + except (InvalidLiteral, InvalidType, VariableDeclarationException) as e: # these exceptions cannot be caused by another statement not yet being # parsed, so we raise them immediately - raise + raise e from None except VyperException as e: err_list.append(e) # Only raise if no nodes were successfully processed. This allows module # level logic to parse regardless of the ordering of code elements. - if count == len(module_nodes): + if count == len(to_visit): err_list.raise_if_not_empty() - # generate an `InterfaceT` from the top-level node - used for building the ABI - # note: also validates unique method ids - interface = InterfaceT.from_ast(module_node) - module_node._metadata["type"] = interface - self.interface = interface # this is useful downstream + self.module_t = ModuleT(self.ast) + self.ast._metadata["type"] = self.module_t # attach namespace to the module for downstream use. _ns = Namespace() # note that we don't just copy the namespace because # there are constructor issues. - _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore - module_node._metadata["namespace"] = _ns + _ns.update({k: self.namespace[k] for k in self.namespace._scopes[-1]}) # type: ignore + self.ast._metadata["namespace"] = _ns + + self.analyze_call_graph() - self_members = namespace["self"].typ.members + return self.module_t + def analyze_call_graph(self): # get list of internal function calls made by each function - function_defs = self.ast.get_children(vy_ast.FunctionDef) - function_names = set(node.name for node in function_defs) - for node in function_defs: - calls_to_self = set( - i.func.attr for i in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}) - ) - # anything that is not a function call will get semantically checked later - calls_to_self = calls_to_self.intersection(function_names) - self_members[node.name].internal_calls = calls_to_self - - for fn_name in sorted(function_names): - if fn_name not in self_members: - # the referenced function does not exist - this is an issue, but we'll report - # it later when parsing the function so we can give more meaningful output - continue - - # check for circular function calls - sequence = _find_cyclic_call([fn_name], self_members) - if sequence is not None: - nodes = [] - for i in range(len(sequence) - 1): - fn_node = self.ast.get_children(vy_ast.FunctionDef, {"name": sequence[i]})[0] - call_node = fn_node.get_descendants( - vy_ast.Attribute, {"value.id": "self", "attr": sequence[i + 1]} - )[0] - nodes.append(call_node) - - raise CallViolation("Contract contains cyclic function call", *nodes) - - # get complete list of functions that are reachable from this function - function_set = set(i for i in self_members[fn_name].internal_calls if i in self_members) - while True: - expanded = set(x for i in function_set for x in self_members[i].internal_calls) - expanded |= function_set - if expanded == function_set: - break - function_set = expanded - - self_members[fn_name].recursive_calls = function_set + function_defs = self.module_t.function_defs + + for func in function_defs: + fn_t = func._metadata["func_type"] + + function_calls = func.get_descendants(vy_ast.Call) + + for call in function_calls: + try: + call_t = get_exact_type_from_node(call.func) + except VyperException: + # either there is a problem getting the call type. this is + # an issue, but it will be handled properly later. right now + # we just want to be able to construct the call graph. + continue + + if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + fn_t.called_functions.add(call_t) + + for func in function_defs: + fn_t = func._metadata["func_type"] + + # compute reachable set and validate the call graph + _compute_reachable_set(fn_t) + + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self.input_bundle._cache._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_and_fold_ast(file) + + return ast_of[file.source_id] def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) + if not isinstance(type_, InterfaceT): raise StructureException("Invalid interface name", node.annotation) @@ -153,8 +227,9 @@ def visit_VariableDecl(self, node): if node.is_public: # generate function type and add to metadata # we need this when building the public getter - node._metadata["func_type"] = ContractFunctionT.getter_from_VariableDecl(node) + node._metadata["getter_type"] = ContractFunctionT.getter_from_VariableDecl(node) + # TODO: move this check to local analysis if node.is_immutable: # mutability is checked automatically preventing assignment # outside of the constructor, here we just check a value is assigned, @@ -213,22 +288,18 @@ def _finalize(): self.namespace["self"].typ.add_member(name, var_info) node.target._metadata["type"] = type_ except NamespaceCollision: + # rewrite the error message to be slightly more helpful raise NamespaceCollision( f"Value '{name}' has already been declared", node ) from None - except VyperException as exc: - raise exc.with_annotation(node) from None def _validate_self_namespace(): # block globals if storage variable already exists - try: - if name in self.namespace["self"].typ.members: - raise NamespaceCollision( - f"Value '{name}' has already been declared", node - ) from None - self.namespace[name] = var_info - except VyperException as exc: - raise exc.with_annotation(node) from None + if name in self.namespace["self"].typ.members: + raise NamespaceCollision( + f"Value '{name}' has already been declared", node + ) from None + self.namespace[name] = var_info if node.is_constant: if not node.value: @@ -251,41 +322,50 @@ def _validate_self_namespace(): _validate_self_namespace() return _finalize() - try: - self.namespace.validate_assignment(name) - except NamespaceCollision as exc: - raise exc.with_annotation(node) from None + self.namespace.validate_assignment(name) return _finalize() def visit_EnumDef(self, node): obj = EnumT.from_EnumDef(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[node.name] = obj def visit_EventDef(self, node): obj = EventT.from_EventDef(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + node._metadata["event_type"] = obj + self.namespace[node.name] = obj def visit_FunctionDef(self, node): - func = ContractFunctionT.from_FunctionDef(node) + if self.is_interface: + func_t = ContractFunctionT.from_vyi(node) + if not func_t.is_external: + # TODO test me! + raise StructureException( + "Internal functions in `.vyi` files are not allowed!", node + ) + else: + func_t = ContractFunctionT.from_FunctionDef(node) - try: - self.namespace["self"].typ.add_member(func.name, func) - node._metadata["type"] = func - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace["self"].typ.add_member(func_t.name, func_t) + node._metadata["func_type"] = func_t def visit_Import(self, node): - if not node.alias: - raise StructureException("Import requires an accompanying `as` statement", node) # import x.y[name] as y[alias] - self._add_import(node, 0, node.name, node.alias) + + alias = node.alias + + if alias is None: + alias = node.name + + # don't handle things like `import x.y` + if "." in alias: + suggested_alias = node.name[node.name.rfind(".") :] + suggestion = f"hint: try `import {node.name} as {suggested_alias}`" + raise StructureException( + f"import requires an accompanying `as` statement ({suggestion})", node + ) + + self._add_import(node, 0, node.name, alias) def visit_ImportFrom(self, node): # from m.n[module] import x[name] as y[alias] @@ -299,42 +379,87 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_ast(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + obj = InterfaceT.from_InterfaceDef(node) + self.namespace[node.name] = obj def visit_StructDef(self, node): - struct_t = StructT.from_ast_def(node) - try: - self.namespace[node.name] = struct_t - except VyperException as exc: - raise exc.with_annotation(node) from None + struct_t = StructT.from_StructDef(node) + node._metadata["struct_type"] = struct_t + self.namespace[node.name] = struct_t def _add_import( self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str ) -> None: - type_ = self._load_import(level, qualified_module_name) - - try: - self.namespace[alias] = type_ - except VyperException as exc: - raise exc.with_annotation(node) from None + module_info = self._load_import(node, level, qualified_module_name, alias) + node._metadata["import_info"] = ImportInfo( + module_info, alias, qualified_module_name, self.input_bundle, node + ) + self.namespace[alias] = module_info - # load an InterfaceT from an import. + # load an InterfaceT or ModuleInfo from an import. # raises FileNotFoundError - def _load_import(self, level: int, module_str: str) -> InterfaceT: + def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: + # the directory this (currently being analyzed) module is in + self_search_path = Path(self.ast.resolved_path).parent + + with self.input_bundle.poke_search_path(self_search_path): + return self._load_import_helper(node, level, module_str, alias) + + def _load_import_helper( + self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str + ) -> Any: if _is_builtin(module_str): return _load_builtin_import(level, module_str) path = _import_to_path(level, module_str) + # this could conceivably be in the ImportGraph but no need at this point + if path in self._imported_modules: + previous_import_stmt = self._imported_modules[path] + raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + + self._imported_modules[path] = node + + err = None + + try: + path_vy = path.with_suffix(".vy") + file = self.input_bundle.load_file(path_vy) + assert isinstance(file, FileInput) # mypy hint + + module_ast = self._ast_from_file(file) + + with override_global_namespace(Namespace()): + module_t = validate_semantics_r( + module_ast, + self.input_bundle, + import_graph=self._import_graph, + is_interface=False, + ) + + return ModuleInfo(module_t) + + except FileNotFoundError as e: + # escape `e` from the block scope, it can make things + # easier to debug. + err = e + try: - file = self.input_bundle.load_file(path.with_suffix(".vy")) + file = self.input_bundle.load_file(path.with_suffix(".vyi")) assert isinstance(file, FileInput) # mypy hint - interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=str(file.path)) - return InterfaceT.from_ast(interface_ast) + module_ast = self._ast_from_file(file) + + with override_global_namespace(Namespace()): + validate_semantics_r( + module_ast, + self.input_bundle, + import_graph=self._import_graph, + is_interface=True, + ) + module_t = module_ast._metadata["type"] + + return module_t.interface + except FileNotFoundError: pass @@ -343,7 +468,24 @@ def _load_import(self, level: int, module_str: str) -> InterfaceT: assert isinstance(file, ABIInput) # mypy hint return InterfaceT.from_json_abi(str(file.path), file.abi) except FileNotFoundError: - raise ModuleNotFoundError(module_str) + pass + + # copy search_paths, makes debugging a bit easier + search_paths = self.input_bundle.search_paths.copy() # noqa: F841 + raise ModuleNotFound(module_str, node) from err + + +def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: + ret = vy_ast.parse_to_ast( + file.source_code, + source_id=file.source_id, + module_path=str(file.path), + resolved_path=str(file.resolved_path), + ) + vy_ast.validation.validate_literal_nodes(ret) + vy_ast.folding.fold(ret) + + return ret # convert an import to a path (without suffix) @@ -385,7 +527,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: remapped_module = remapped_module.removeprefix("vyper.interfaces") remapped_module = vyper.builtins.interfaces.__package__ + remapped_module - path = _import_to_path(level, remapped_module).with_suffix(".vy") + path = _import_to_path(level, remapped_module).with_suffix(".vyi") try: file = input_bundle.load_file(path) @@ -394,5 +536,8 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None # TODO: it might be good to cache this computation - interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=module_str) - return InterfaceT.from_ast(interface_ast) + interface_ast = _parse_and_fold_ast(file) + + with override_global_namespace(Namespace()): + module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index afa6b56838..1785afd92d 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -66,8 +66,15 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): - varinfo = self.namespace[node.id] - return ExprInfo.from_varinfo(varinfo) + info = self.namespace[node.id] + + if isinstance(info, VarInfo): + return ExprInfo.from_varinfo(info) + + if isinstance(info, ModuleInfo): + return ExprInfo.from_moduleinfo(info) + + raise CompilerPanic("unreachable!", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and @@ -192,16 +199,17 @@ def _raise_invalid_reference(name, node): try: s = t.get_member(name, node) - if isinstance(s, VyperType): + + if isinstance(s, (VyperType, TYPE_T)): # ex. foo.bar(). bar() is a ContractFunctionT return [s] if is_self_reference and (s.is_constant or s.is_immutable): _raise_invalid_reference(name, node) # general case. s is a VarInfo, e.g. self.foo return [s.typ] - except UnknownAttribute: + except UnknownAttribute as e: if not is_self_reference: - raise + raise e from None if name in self.namespace: _raise_invalid_reference(name, node) @@ -364,6 +372,7 @@ def types_from_Name(self, node): return [TYPE_T(t)] return [t.typ] + except VyperException as exc: raise exc.with_annotation(node) from None diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 613ac0c03b..4df2511a29 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -95,7 +95,7 @@ def validate_assignment(self, attr): def get_namespace(): """ - Get the active namespace object. + Get the global namespace object. """ global _namespace try: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index ad470718c8..1fef6a706e 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -2,9 +2,10 @@ from .base import TYPE_T, KwargSettings, VyperType, is_type_t from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT +from .module import InterfaceT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT -from .user import EnumT, EventT, InterfaceT, StructT +from .user import EnumT, EventT, StructT def _get_primitive_types(): diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index c5af5c2a39..d22d9bfff9 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -44,6 +44,13 @@ class VyperType: A tuple of invalid `DataLocation`s for this type _is_prim_word: bool, optional This is a word type like uint256, int8, bytesM or address + _supports_external_calls: bool, optional + Whether or not this type supports external calls. Currently + limited to `InterfaceT`s + _attribute_in_annotation: bool, optional + Whether or not this type can be attributed in a type + annotation, like IFoo.SomeType. Currently limited to + `InterfaceT`s. """ _id: str @@ -58,6 +65,9 @@ class VyperType: _as_array: bool = False # rename to something like can_be_array_member _as_hashmap_key: bool = False + _supports_external_calls: bool = False + _attribute_in_annotation: bool = False + size_in_bytes = 32 # default; override for larger types def __init__(self, members: Optional[Dict] = None) -> None: @@ -261,7 +271,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional["VyperType"]: VyperType, optional Type generated as a result of the call. """ - raise StructureException("Value is not callable", node) + raise StructureException(f"{self} is not callable", node) @classmethod def get_subscripted_type(self, node: vy_ast.Index) -> None: diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 09130626aa..e3c381ac69 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -132,7 +132,15 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": raise UnexpectedValue("Node id does not match type name") length = get_index_value(node.slice) # type: ignore - # return cls._type(length, location, is_constant, is_public, is_immutable) + + if length is None: + raise StructureException( + f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node + ) + + # TODO: pass None to constructor after we redo length inference on bytestrings + length = length or 0 + return cls(length) @classmethod diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 140f73f095..ec30ac85d6 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -17,7 +17,11 @@ StructureException, ) from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot -from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type +from vyper.semantics.analysis.utils import ( + check_kwargable, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -44,6 +48,7 @@ class KeywordArg(_FunctionArg): ast_source: Optional[vy_ast.VyperNode] = None +# TODO: refactor this into FunctionT (from an ast) and ABIFunctionT (from json) class ContractFunctionT(VyperType): """ Contract function type. @@ -81,6 +86,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, nonreentrant: Optional[str] = None, + ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -92,11 +98,18 @@ def __init__( self.mutability = state_mutability self.nonreentrant = nonreentrant - # a list of internal functions this function calls - self.called_functions = OrderedSet[ContractFunctionT]() + self.ast_def = ast_def + + # a list of internal functions this function calls. + # to be populated during analysis + self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() + + # recursively reachable from this function + self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # to be populated during codegen self._ir_info: Any = None + self._function_id: Optional[int] = None @cached_property def call_site_kwargs(self): @@ -126,7 +139,7 @@ def __hash__(self): return hash(id(self)) @classmethod - def from_abi(cls, abi: Dict) -> "ContractFunctionT": + def from_abi(cls, abi: dict) -> "ContractFunctionT": """ Generate a `ContractFunctionT` object from an ABI interface. @@ -157,190 +170,174 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionT": ) @classmethod - def from_FunctionDef( - cls, node: vy_ast.FunctionDef, is_interface: Optional[bool] = False - ) -> "ContractFunctionT": + def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": """ - Generate a `ContractFunctionT` object from a `FunctionDef` node. + Generate a `ContractFunctionT` object from a `FunctionDef` inside + of an `InterfaceDef` Arguments --------- - node : FunctionDef + funcdef: FunctionDef Vyper ast node to generate the function definition from. - is_interface: bool, optional - Boolean indicating if the function definition is part of an interface. Returns ------- ContractFunctionT """ - kwargs: Dict[str, Any] = {} - if is_interface: - # FunctionDef with stateMutability in body (Interface defintions) - if ( - len(node.body) == 1 - and isinstance(node.body[0], vy_ast.Expr) - and isinstance(node.body[0].value, vy_ast.Name) - and StateMutability.is_valid_value(node.body[0].value.id) - ): - # Interfaces are always public - kwargs["function_visibility"] = FunctionVisibility.EXTERNAL - kwargs["state_mutability"] = StateMutability(node.body[0].value.id) - elif len(node.body) == 1 and node.body[0].get("value.id") in ("constant", "modifying"): - if node.body[0].value.id == "constant": - expected = "view or pure" - else: - expected = "payable or nonpayable" - raise StructureException( - f"State mutability should be set to {expected}", node.body[0] - ) + # FunctionDef with stateMutability in body (Interface defintions) + body = funcdef.body + if ( + len(body) == 1 + and isinstance(body[0], vy_ast.Expr) + and isinstance(body[0].value, vy_ast.Name) + and StateMutability.is_valid_value(body[0].value.id) + ): + # Interfaces are always public + function_visibility = FunctionVisibility.EXTERNAL + state_mutability = StateMutability(body[0].value.id) + # handle errors + elif len(body) == 1 and body[0].get("value.id") in ("constant", "modifying"): + if body[0].value.id == "constant": + expected = "view or pure" else: - raise StructureException( - "Body must only contain state mutability label", node.body[0] - ) - + expected = "payable or nonpayable" + raise StructureException(f"State mutability should be set to {expected}", body[0]) else: - # FunctionDef with decorators (normal functions) - for decorator in node.decorator_list: - if isinstance(decorator, vy_ast.Call): - if "nonreentrant" in kwargs: - raise StructureException( - "nonreentrant decorator is already set with key: " - f"{kwargs['nonreentrant']}", - node, - ) + raise StructureException("Body must only contain state mutability label", body[0]) - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + if funcdef.name == "__init__": + raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) - if node.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" - raise FunctionDeclarationException(msg, decorator) - - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) - - kwargs["nonreentrant"] = nonreentrant_key - - elif isinstance(decorator, vy_ast.Name): - if FunctionVisibility.is_valid_value(decorator.id): - if "function_visibility" in kwargs: - raise FunctionDeclarationException( - f"Visibility is already set to: {kwargs['function_visibility']}", - node, - ) - kwargs["function_visibility"] = FunctionVisibility(decorator.id) - - elif StateMutability.is_valid_value(decorator.id): - if "state_mutability" in kwargs: - raise FunctionDeclarationException( - f"Mutability is already set to: {kwargs['state_mutability']}", node - ) - kwargs["state_mutability"] = StateMutability(decorator.id) - - else: - if decorator.id == "constant": - warnings.warn( - "'@constant' decorator has been removed (see VIP2040). " - "Use `@view` instead.", - DeprecationWarning, - ) - raise FunctionDeclarationException( - f"Unknown decorator: {decorator.id}", decorator - ) + if funcdef.name == "__default__": + raise FunctionDeclarationException( + "Default functions cannot appear in interfaces", funcdef + ) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) + + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=None, + ast_def=funcdef, + ) + + @classmethod + def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": + """ + Generate a `ContractFunctionT` object from a `FunctionDef` inside + of an interface (`.vyi`) file + + Arguments + --------- + funcdef: FunctionDef + Vyper ast node to generate the function definition from. + + Returns + ------- + ContractFunctionT + """ + function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + + if nonreentrant_key is not None: + raise FunctionDeclarationException( + "nonreentrant key not allowed in interfaces", funcdef + ) + + if funcdef.name == "__init__": + raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) - else: - raise StructureException("Bad decorator syntax", decorator) + if funcdef.name == "__default__": + raise FunctionDeclarationException( + "Default functions cannot appear in interfaces", funcdef + ) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) - if "function_visibility" not in kwargs: + if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", node + "function body in an interface can only be ...!", funcdef ) - if node.name == "__default__": - if kwargs["function_visibility"] != FunctionVisibility.EXTERNAL: + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=nonreentrant_key, + ast_def=funcdef, + ) + + @classmethod + def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": + """ + Generate a `ContractFunctionT` object from a `FunctionDef` node. + + Arguments + --------- + funcdef: FunctionDef + Vyper ast node to generate the function definition from. + + Returns + ------- + ContractFunctionT + """ + function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) + + # validate default and init functions + if funcdef.name == "__default__": + if function_visibility != FunctionVisibility.EXTERNAL: raise FunctionDeclarationException( - "Default function must be marked as `@external`", node + "Default function must be marked as `@external`", funcdef ) - if node.args.args: + if funcdef.args.args: raise FunctionDeclarationException( - "Default function may not receive any arguments", node.args.args[0] + "Default function may not receive any arguments", funcdef.args.args[0] ) - if "state_mutability" not in kwargs: - # Assume nonpayable if not set at all (cannot accept Ether, but can modify state) - kwargs["state_mutability"] = StateMutability.NONPAYABLE - - if kwargs["state_mutability"] == StateMutability.PURE and "nonreentrant" in kwargs: - raise StructureException("Cannot use reentrancy guard on pure functions", node) - - if node.name == "__init__": + if funcdef.name == "__init__": if ( - kwargs["state_mutability"] in (StateMutability.PURE, StateMutability.VIEW) - or kwargs["function_visibility"] == FunctionVisibility.INTERNAL + state_mutability in (StateMutability.PURE, StateMutability.VIEW) + or function_visibility == FunctionVisibility.INTERNAL ): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", node + "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef ) - - # call arguments - if node.args.defaults: + if return_type is not None: raise FunctionDeclarationException( - "Constructor may not use default arguments", node.args.defaults[0] + "Constructor may not have a return type", funcdef.returns ) - argnames = set() # for checking uniqueness - n_total_args = len(node.args.args) - n_positional_args = n_total_args - len(node.args.defaults) - - positional_args: list[PositionalArg] = [] - keyword_args: list[KeywordArg] = [] - - for i, arg in enumerate(node.args.args): - argname = arg.arg - if argname in ("gas", "value", "skip_contract_check", "default_return_value"): - raise ArgumentException( - f"Cannot use '{argname}' as a variable name in a function input", arg + # call arguments + if funcdef.args.defaults: + raise FunctionDeclarationException( + "Constructor may not use default arguments", funcdef.args.defaults[0] ) - if argname in argnames: - raise ArgumentException(f"Function contains multiple inputs named {argname}", arg) - - if arg.annotation is None: - raise ArgumentException(f"Function argument '{argname}' is missing a type", arg) - - type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA) - - if i < n_positional_args: - positional_args.append(PositionalArg(argname, type_, ast_source=arg)) - else: - value = node.args.defaults[i - n_positional_args] - if not check_kwargable(value): - raise StateAccessViolation( - "Value must be literal or environment variable", value - ) - validate_expected_type(value, type_) - keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) - - argnames.add(argname) - # return types - if node.returns is None: - return_type = None - elif node.name == "__init__": - raise FunctionDeclarationException( - "Constructor may not have a return type", node.returns - ) - elif isinstance(node.returns, (vy_ast.Name, vy_ast.Subscript, vy_ast.Tuple)): - # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE - return_type = type_from_annotation(node.returns, DataLocation.MEMORY) - else: - raise InvalidType("Function return value must be a type name or tuple", node.returns) - - return cls(node.name, positional_args, keyword_args, return_type, **kwargs) + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=nonreentrant_key, + ast_def=funcdef, + ) def set_reentrancy_key_position(self, position: StorageSlot) -> None: if hasattr(self, "reentrancy_key_position"): @@ -383,6 +380,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio return_type, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, + ast_def=node, ) @property @@ -489,8 +487,12 @@ def method_ids(self) -> Dict[str, int]: return method_ids def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: - if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL: - raise CallViolation("Cannot call external functions via 'self'", node) + # mypy hint - right now, the only way a ContractFunctionT can be + # called is via `Attribute`, e.x. self.foo() or library.bar() + assert isinstance(node.func, vy_ast.Attribute) + parent_t = get_exact_type_from_node(node.func.value) + if not parent_t._supports_external_calls and self.visibility == FunctionVisibility.EXTERNAL: + raise CallViolation("Cannot call external functions via 'self' or via library", node) kwarg_keys = [] # for external calls, include gas and value as optional kwargs @@ -584,6 +586,125 @@ def abi_signature_for_kwargs(self, kwargs: list[KeywordArg]) -> str: return self.name + "(" + ",".join([arg.typ.abi_type.selector_name() for arg in args]) + ")" +def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: + # return types + if funcdef.returns is None: + return None + # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE + return type_from_annotation(funcdef.returns, DataLocation.MEMORY) + + +def _parse_decorators( + funcdef: vy_ast.FunctionDef, +) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: + function_visibility = None + state_mutability = None + nonreentrant_key = None + + for decorator in funcdef.decorator_list: + if isinstance(decorator, vy_ast.Call): + if nonreentrant_key is not None: + raise StructureException( + "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", + funcdef, + ) + + if decorator.get("func.id") != "nonreentrant": + raise StructureException("Decorator is not callable", decorator) + if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): + raise StructureException( + "@nonreentrant name must be given as a single string literal", decorator + ) + + if funcdef.name == "__init__": + msg = "Nonreentrant decorator disallowed on `__init__`" + raise FunctionDeclarationException(msg, decorator) + + nonreentrant_key = decorator.args[0].value + validate_identifier(nonreentrant_key, decorator.args[0]) + + elif isinstance(decorator, vy_ast.Name): + if FunctionVisibility.is_valid_value(decorator.id): + if function_visibility is not None: + raise FunctionDeclarationException( + f"Visibility is already set to: {function_visibility}", funcdef + ) + function_visibility = FunctionVisibility(decorator.id) + + elif StateMutability.is_valid_value(decorator.id): + if state_mutability is not None: + raise FunctionDeclarationException( + f"Mutability is already set to: {state_mutability}", funcdef + ) + state_mutability = StateMutability(decorator.id) + + else: + if decorator.id == "constant": + warnings.warn( + "'@constant' decorator has been removed (see VIP2040). " + "Use `@view` instead.", + DeprecationWarning, + ) + raise FunctionDeclarationException(f"Unknown decorator: {decorator.id}", decorator) + + else: + raise StructureException("Bad decorator syntax", decorator) + + if function_visibility is None: + raise FunctionDeclarationException( + f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", funcdef + ) + + if state_mutability is None: + # default to nonpayable + state_mutability = StateMutability.NONPAYABLE + + if state_mutability == StateMutability.PURE and nonreentrant_key is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + + # assert function_visibility is not None # mypy + # assert state_mutability is not None # mypy + return function_visibility, state_mutability, nonreentrant_key + + +def _parse_args( + funcdef: vy_ast.FunctionDef, is_interface: bool = False +) -> tuple[list[PositionalArg], list[KeywordArg]]: + argnames = set() # for checking uniqueness + n_total_args = len(funcdef.args.args) + n_positional_args = n_total_args - len(funcdef.args.defaults) + + positional_args = [] + keyword_args = [] + + for i, arg in enumerate(funcdef.args.args): + argname = arg.arg + if argname in ("gas", "value", "skip_contract_check", "default_return_value"): + raise ArgumentException( + f"Cannot use '{argname}' as a variable name in a function input", arg + ) + if argname in argnames: + raise ArgumentException(f"Function contains multiple inputs named {argname}", arg) + + if arg.annotation is None: + raise ArgumentException(f"Function argument '{argname}' is missing a type", arg) + + type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA) + + if i < n_positional_args: + positional_args.append(PositionalArg(argname, type_, ast_source=arg)) + else: + value = funcdef.args.defaults[i - n_positional_args] + if not check_kwargable(value): + raise StateAccessViolation("Value must be literal or environment variable", value) + validate_expected_type(value, type_) + keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) + + argnames.add(argname) + + return positional_args, keyword_args + + class MemberFunctionT(VyperType): """ Member function type definition. diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py new file mode 100644 index 0000000000..4622482951 --- /dev/null +++ b/vyper/semantics/types/module.py @@ -0,0 +1,332 @@ +from functools import cached_property +from typing import Optional + +from vyper import ast as vy_ast +from vyper.abi_types import ABI_Address, ABIType +from vyper.ast.validation import validate_call_args +from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.namespace import get_namespace +from vyper.semantics.types.base import TYPE_T, VyperType +from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.primitives import AddressT +from vyper.semantics.types.user import EventT, StructT, _UserType + + +class InterfaceT(_UserType): + _type_members = {"address": AddressT()} + _is_prim_word = True + _as_array = True + _as_hashmap_key = True + _supports_external_calls = True + _attribute_in_annotation = True + + def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> None: + validate_unique_method_ids(list(functions.values())) + + members = functions | events | structs + + # sanity check: by construction, there should be no duplicates. + assert len(members) == len(functions) + len(events) + len(structs) + + super().__init__(functions) + + self._helper = VyperType(events | structs) + self._id = _id + self.functions = functions + self.events = events + self.structs = structs + + def get_type_member(self, attr, node): + # get an event or struct from this interface + return TYPE_T(self._helper.get_member(attr, node)) + + @property + def getter_signature(self): + return (), AddressT() + + @property + def abi_type(self) -> ABIType: + return ABI_Address() + + def __repr__(self): + return f"interface {self._id}" + + # when using the type itself (not an instance) in the call position + def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": + self._ctor_arg_types(node) + return self + + def _ctor_arg_types(self, node): + validate_call_args(node, 1) + validate_expected_type(node.args[0], AddressT()) + return [AddressT()] + + def _ctor_kwarg_types(self, node): + return {} + + # TODO x.validate_implements(other) + def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: + namespace = get_namespace() + unimplemented = [] + + def _is_function_implemented(fn_name, fn_type): + vyper_self = namespace["self"].typ + if fn_name not in vyper_self.members: + return False + s = vyper_self.members[fn_name] + if isinstance(s, ContractFunctionT): + to_compare = vyper_self.members[fn_name] + # this is kludgy, rework order of passes in ModuleNodeVisitor + elif isinstance(s, VarInfo) and s.is_public: + to_compare = s.decl_node._metadata["getter_type"] + else: + return False + + return to_compare.implements(fn_type) + + # check for missing functions + for name, type_ in self.functions.items(): + if not isinstance(type_, ContractFunctionT): + # ex. address + continue + + if not _is_function_implemented(name, type_): + unimplemented.append(name) + + # check for missing events + for name, event in self.events.items(): + if name not in namespace: + unimplemented.append(name) + continue + + if not isinstance(namespace[name], EventT): + unimplemented.append(f"{name} is not an event!") + if ( + namespace[name].event_id != event.event_id + or namespace[name].indexed != event.indexed + ): + unimplemented.append(f"{name} is not implemented! (should be {event})") + + if len(unimplemented) > 0: + # TODO: improve the error message for cases where the + # mismatch is small (like mutability, or just one argument + # is off, etc). + missing_str = ", ".join(sorted(unimplemented)) + raise InterfaceViolation( + f"Contract does not implement all interface functions or events: {missing_str}", + node, + ) + + def to_toplevel_abi_dict(self) -> list[dict]: + abi = [] + for event in self.events.values(): + abi += event.to_toplevel_abi_dict() + for func in self.functions.values(): + abi += func.to_toplevel_abi_dict() + return abi + + # helper function which performs namespace collision checking + @classmethod + def _from_lists( + cls, + name: str, + function_list: list[tuple[str, ContractFunctionT]], + event_list: list[tuple[str, EventT]], + struct_list: list[tuple[str, StructT]], + ) -> "InterfaceT": + functions = {} + events = {} + structs = {} + + seen_items: dict = {} + + for name, function in function_list: + if name in seen_items: + raise NamespaceCollision(f"multiple functions named '{name}'!", function.ast_def) + functions[name] = function + seen_items[name] = function + + for name, event in event_list: + if name in seen_items: + raise NamespaceCollision( + f"multiple functions or events named '{name}'!", event.decl_node + ) + events[name] = event + seen_items[name] = event + + for name, struct in struct_list: + if name in seen_items: + raise NamespaceCollision( + f"multiple functions or events named '{name}'!", event.decl_node + ) + structs[name] = struct + seen_items[name] = struct + + return cls(name, functions, events, structs) + + @classmethod + def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": + """ + Generate an `InterfaceT` object from an ABI. + + Arguments + --------- + name : str + The name of the interface + abi : dict + Contract ABI + + Returns + ------- + InterfaceT + primitive interface type + """ + functions: list = [] + events: list = [] + + for item in [i for i in abi if i.get("type") == "function"]: + functions.append((item["name"], ContractFunctionT.from_abi(item))) + for item in [i for i in abi if i.get("type") == "event"]: + events.append((item["name"], EventT.from_abi(item))) + + structs: list = [] # no structs in json ABI (as of yet) + return cls._from_lists(name, functions, events, structs) + + @classmethod + def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": + """ + Generate an `InterfaceT` object from a Vyper ast node. + + Arguments + --------- + module_t: ModuleT + Vyper module type + Returns + ------- + InterfaceT + primitive interface type + """ + funcs = [] + + for node in module_t.function_defs: + func_t = node._metadata["func_type"] + if not func_t.is_external: + continue + funcs.append((node.name, func_t)) + + # add getters for public variables since they aren't yet in the AST + for node in module_t._module.get_children(vy_ast.VariableDecl): + if not node.is_public: + continue + getter = node._metadata["getter_type"] + funcs.append((node.target.id, getter)) + + events = [(node.name, node._metadata["event_type"]) for node in module_t.event_defs] + + structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs] + + return cls._from_lists(module_t._id, funcs, events, structs) + + @classmethod + def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": + functions = [] + for node in node.body: + if not isinstance(node, vy_ast.FunctionDef): + raise StructureException("Interfaces can only contain function definitions", node) + if len(node.decorator_list) > 0: + raise StructureException( + "Function definition in interface cannot be decorated", node.decorator_list[0] + ) + functions.append((node.name, ContractFunctionT.from_InterfaceDef(node))) + + # no structs or events in InterfaceDefs + events: list = [] + structs: list = [] + + return cls._from_lists(node.name, functions, events, structs) + + +# Datatype to store all module information. +class ModuleT(VyperType): + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): + super().__init__() + + self._module = module + + self._id = name or module.path + + # compute the interface, note this has the side effect of checking + # for function collisions + self._helper = self.interface + + for f in self.function_defs: + # note: this checks for collisions + self.add_member(f.name, f._metadata["func_type"]) + + for e in self.event_defs: + # add the type of the event so it can be used in call position + self.add_member(e.name, TYPE_T(e._metadata["event_type"])) # type: ignore + + for s in self.struct_defs: + # add the type of the struct so it can be used in call position + self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for v in self.variable_decls: + self.add_member(v.target.id, v.target._metadata["varinfo"]) + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + self.add_member(import_info.alias, import_info.typ) + + # __eq__ is very strict on ModuleT - object equality! this is because we + # don't want to reason about where a module came from (i.e. input bundle, + # search path, symlinked vs normalized path, etc.) + def __eq__(self, other): + return self is other + + def __hash__(self): + return hash(id(self)) + + def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": + return self._helper.get_member(key, node) + + # this is a property, because the function set changes after AST expansion + @property + def function_defs(self): + return self._module.get_children(vy_ast.FunctionDef) + + @property + def event_defs(self): + return self._module.get_children(vy_ast.EventDef) + + @property + def struct_defs(self): + return self._module.get_children(vy_ast.StructDef) + + @property + def import_stmts(self): + return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + + @property + def variable_decls(self): + return self._module.get_children(vy_ast.VariableDecl) + + @cached_property + def variables(self): + # variables that this module defines, ex. + # `x: uint256` is a private storage variable named x + return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + + @cached_property + def immutables(self): + return [t for t in self.variables.values() if t.is_immutable] + + @cached_property + def immutable_section_bytes(self): + return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + + @cached_property + def interface(self): + return InterfaceT.from_ModuleT(self) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 6a2d3aae73..46dffbdec4 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from vyper import ast as vy_ast from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType @@ -68,7 +68,7 @@ def get_subscripted_type(self, node): return self.value_type @classmethod - def from_annotation(cls, node: Union[vy_ast.Name, vy_ast.Call, vy_ast.Subscript]) -> "HashMapT": + def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT": if ( not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index) @@ -274,24 +274,32 @@ def compare_type(self, other): @classmethod def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": + # common error message, different ast locations + err_msg = "DynArray must be defined with base type and max length, e.g. DynArray[bool, 5]" + + if not isinstance(node, vy_ast.Subscript): + raise StructureException(err_msg, node) + if ( - not isinstance(node, vy_ast.Subscript) - or not isinstance(node.slice, vy_ast.Index) + not isinstance(node.slice, vy_ast.Index) or not isinstance(node.slice.value, vy_ast.Tuple) - or not isinstance(node.slice.value.elements[1], vy_ast.Int) or len(node.slice.value.elements) != 2 ): - raise StructureException( - "DynArray must be defined with base type and max length, e.g. DynArray[bool, 5]", - node, - ) + raise StructureException(err_msg, node.slice) + + length_node = node.slice.value.elements[1] + + if not isinstance(length_node, vy_ast.Int): + raise StructureException(err_msg, length_node) - value_type = type_from_annotation(node.slice.value.elements[0]) + length = length_node.value + + value_node = node.slice.value.elements[0] + value_type = type_from_annotation(value_node) if not value_type._as_darray: - raise StructureException(f"Arrays of {value_type} are not allowed", node) + raise StructureException(f"Arrays of {value_type} are not allowed", value_node) - max_length = node.slice.value.elements[1].value - return cls(value_type, max_length) + return cls(value_type, length) class TupleT(VyperType): @@ -333,7 +341,7 @@ def tuple_items(self): return list(enumerate(self.member_types)) @classmethod - def from_annotation(cls, node: vy_ast.Tuple) -> VyperType: + def from_annotation(cls, node: vy_ast.Tuple) -> "TupleT": values = node.elements types = tuple(type_from_annotation(v) for v in values) return cls(types) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ce82731c34..ef7e1d0eb4 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -1,27 +1,22 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional from vyper import ast as vy_ast -from vyper.abi_types import ABI_Address, ABI_GIntM, ABI_Tuple, ABIType +from vyper.abi_types import ABI_GIntM, ABI_Tuple, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import ( EnumDeclarationException, EventDeclarationException, - InterfaceViolation, InvalidAttribute, NamespaceCollision, StructureException, UnknownAttribute, VariableDeclarationException, ) -from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.data_locations import DataLocation -from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import VyperType -from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.subscriptable import HashMapT from vyper.semantics.types.utils import type_from_abi, type_from_annotation from vyper.utils import keccak256 @@ -29,12 +24,19 @@ # user defined type class _UserType(VyperType): + def __init__(self, members=None): + super().__init__(members=members) + def __eq__(self, other): return self is other - # TODO: revisit this once user types can be imported via modules def compare_type(self, other): - return super().compare_type(other) and self._id == other._id + # object exact comparison is a bit tricky here since we have + # to be careful to construct any given user type exactly + # only one time. however, the alternative requires reasoning + # about both the name and source (module or json abi) of + # the type. + return self is other def __hash__(self): return hash(id(self)) @@ -52,7 +54,8 @@ def __init__(self, name: str, members: dict) -> None: if len(members.keys()) > 256: raise EnumDeclarationException("Enums are limited to 256 members!") - super().__init__() + super().__init__(members=None) + self._id = name self._enum_members = members @@ -112,7 +115,7 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": ------- Enum """ - members: Dict = {} + members: dict = {} if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): raise EnumDeclarationException("Enum must have members", base_node) @@ -135,7 +138,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: # TODO return None - def to_toplevel_abi_dict(self) -> List[Dict]: + def to_toplevel_abi_dict(self) -> list[dict]: # TODO return [] @@ -160,13 +163,21 @@ class EventT(_UserType): _invalid_locations = tuple(iter(DataLocation)) # not instantiable in any location - def __init__(self, name: str, arguments: dict, indexed: list) -> None: + def __init__( + self, + name: str, + arguments: dict, + indexed: list, + decl_node: Optional[vy_ast.VyperNode] = None, + ) -> None: super().__init__(members=arguments) self.name = name self.indexed = indexed assert len(self.indexed) == len(self.arguments) self.event_id = int(keccak256(self.signature.encode()).hex(), 16) + self.decl_node = decl_node + # backward compatible @property def arguments(self): @@ -187,7 +198,7 @@ def signature(self): return f"{self.name}({','.join(v.canonical_abi_type for v in self.arguments.values())})" @classmethod - def from_abi(cls, abi: Dict) -> "EventT": + def from_abi(cls, abi: dict) -> "EventT": """ Generate an `Event` object from an ABI interface. @@ -201,7 +212,7 @@ def from_abi(cls, abi: Dict) -> "EventT": Event object. """ members: dict = {} - indexed: List = [i["indexed"] for i in abi["inputs"]] + indexed: list = [i["indexed"] for i in abi["inputs"]] for item in abi["inputs"]: members[item["name"]] = type_from_abi(item) return cls(abi["name"], members, indexed) @@ -219,11 +230,11 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": ------- Event """ - members: Dict = {} - indexed: List = [] + members: dict = {} + indexed: list = [] if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): - return EventT(base_node.name, members, indexed) + return cls(base_node.name, members, indexed, base_node) for node in base_node.body: if not isinstance(node, vy_ast.AnnAssign): @@ -252,14 +263,14 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": members[member_name] = type_from_annotation(annotation) - return cls(base_node.name, members, indexed) + return cls(base_node.name, members, indexed, base_node) def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): validate_expected_type(arg, expected) - def to_toplevel_abi_dict(self) -> List[Dict]: + def to_toplevel_abi_dict(self) -> list[dict]: return [ { "name": self.name, @@ -273,215 +284,6 @@ def to_toplevel_abi_dict(self) -> List[Dict]: ] -class InterfaceT(_UserType): - _type_members = {"address": AddressT()} - _is_prim_word = True - _as_array = True - _as_hashmap_key = True - - def __init__(self, _id: str, members: dict, events: dict) -> None: - validate_unique_method_ids(list(members.values())) # explicit list cast for mypy - super().__init__(members) - - self._id = _id - self.events = events - - @property - def getter_signature(self): - return (), AddressT() - - @property - def abi_type(self) -> ABIType: - return ABI_Address() - - def __repr__(self): - return f"{self._id}" - - # when using the type itself (not an instance) in the call position - # maybe rename to _ctor_call_return - def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": - self._ctor_arg_types(node) - - return self - - def _ctor_arg_types(self, node): - validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] - - def _ctor_kwarg_types(self, node): - return {} - - # TODO x.validate_implements(other) - def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: - namespace = get_namespace() - unimplemented = [] - - def _is_function_implemented(fn_name, fn_type): - vyper_self = namespace["self"].typ - if fn_name not in vyper_self.members: - return False - s = vyper_self.members[fn_name] - if isinstance(s, ContractFunctionT): - to_compare = vyper_self.members[fn_name] - # this is kludgy, rework order of passes in ModuleNodeVisitor - elif isinstance(s, VarInfo) and s.is_public: - to_compare = s.decl_node._metadata["func_type"] - else: - return False - - return to_compare.implements(fn_type) - - # check for missing functions - for name, type_ in self.members.items(): - if not isinstance(type_, ContractFunctionT): - # ex. address - continue - - if not _is_function_implemented(name, type_): - unimplemented.append(name) - - # check for missing events - for name, event in self.events.items(): - if name not in namespace: - unimplemented.append(name) - continue - - if not isinstance(namespace[name], EventT): - unimplemented.append(f"{name} is not an event!") - if ( - namespace[name].event_id != event.event_id - or namespace[name].indexed != event.indexed - ): - unimplemented.append(f"{name} is not implemented! (should be {event})") - - if len(unimplemented) > 0: - # TODO: improve the error message for cases where the - # mismatch is small (like mutability, or just one argument - # is off, etc). - missing_str = ", ".join(sorted(unimplemented)) - raise InterfaceViolation( - f"Contract does not implement all interface functions or events: {missing_str}", - node, - ) - - def to_toplevel_abi_dict(self) -> List[Dict]: - abi = [] - for event in self.events.values(): - abi += event.to_toplevel_abi_dict() - for func in self.functions.values(): - abi += func.to_toplevel_abi_dict() - return abi - - @property - def functions(self): - return {k: v for (k, v) in self.members.items() if isinstance(v, ContractFunctionT)} - - @classmethod - def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": - """ - Generate an `InterfaceT` object from an ABI. - - Arguments - --------- - name : str - The name of the interface - abi : dict - Contract ABI - - Returns - ------- - InterfaceT - primitive interface type - """ - members: Dict = {} - events: Dict = {} - - names = [i["name"] for i in abi if i.get("type") in ("event", "function")] - collisions = set(i for i in names if names.count(i) > 1) - if collisions: - collision_list = ", ".join(sorted(collisions)) - raise NamespaceCollision( - f"ABI '{name}' has multiple functions or events " - f"with the same name: {collision_list}" - ) - - for item in [i for i in abi if i.get("type") == "function"]: - members[item["name"]] = ContractFunctionT.from_abi(item) - for item in [i for i in abi if i.get("type") == "event"]: - events[item["name"]] = EventT.from_abi(item) - - return cls(name, members, events) - - # TODO: split me into from_InterfaceDef and from_Module - @classmethod - def from_ast(cls, node: Union[vy_ast.InterfaceDef, vy_ast.Module]) -> "InterfaceT": - """ - Generate an `InterfaceT` object from a Vyper ast node. - - Arguments - --------- - node : InterfaceDef | Module - Vyper ast node defining the interface - Returns - ------- - InterfaceT - primitive interface type - """ - if isinstance(node, vy_ast.Module): - members, events = _get_module_definitions(node) - elif isinstance(node, vy_ast.InterfaceDef): - members = _get_class_functions(node) - events = {} - else: - raise StructureException("Invalid syntax for interface definition", node) - - return cls(node.name, members, events) - - -def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[Dict, Dict]: - functions: Dict = {} - events: Dict = {} - for node in base_node.get_children(vy_ast.FunctionDef): - if "external" in [i.id for i in node.decorator_list if isinstance(i, vy_ast.Name)]: - func = ContractFunctionT.from_FunctionDef(node) - functions[node.name] = func - for node in base_node.get_children(vy_ast.VariableDecl, {"is_public": True}): - name = node.target.id - if name in functions: - raise NamespaceCollision( - f"Interface contains multiple functions named '{name}'", base_node - ) - functions[name] = ContractFunctionT.getter_from_VariableDecl(node) - for node in base_node.get_children(vy_ast.EventDef): - name = node.name - if name in functions or name in events: - raise NamespaceCollision( - f"Interface contains multiple objects named '{name}'", base_node - ) - events[name] = EventT.from_EventDef(node) - - return functions, events - - -def _get_class_functions(base_node: vy_ast.InterfaceDef) -> Dict[str, ContractFunctionT]: - functions = {} - for node in base_node.body: - if not isinstance(node, vy_ast.FunctionDef): - raise StructureException("Interfaces can only contain function definitions", node) - if node.name in functions: - raise NamespaceCollision( - f"Interface contains multiple functions named '{node.name}'", node - ) - if len(node.decorator_list) > 0: - raise StructureException( - "Function definition in interface cannot be decorated", node.decorator_list[0] - ) - functions[node.name] = ContractFunctionT.from_FunctionDef(node, is_interface=True) - - return functions - - class StructT(_UserType): _as_array = True @@ -516,7 +318,7 @@ def member_types(self): return self.members @classmethod - def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT": + def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": """ Generate a `StructT` object from a Vyper ast node. @@ -531,7 +333,7 @@ def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT": """ struct_name = base_node.name - members: Dict[str, VyperType] = {} + members: dict[str, VyperType] = {} for node in base_node.body: if not isinstance(node, vy_ast.AnnAssign): raise StructureException( @@ -605,4 +407,4 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": f"Struct declaration does not define all fields: {', '.join(list(members))}", node ) - return StructT(self._id, self.member_types) + return self diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 1187080ca9..8d68a9fa01 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -6,12 +6,13 @@ InstantiationException, InvalidType, StructureException, + UndeclaredDefinition, UnknownType, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace -from vyper.semantics.types.base import VyperType +from vyper.semantics.types.base import TYPE_T, VyperType # TODO maybe this should be merged with .types/base.py @@ -75,7 +76,7 @@ def type_from_annotation( Arguments --------- - node : VyperNode + node: VyperNode Vyper ast node from the `annotation` member of a `VariableDecl` or `AnnAssign` node. Returns @@ -95,12 +96,6 @@ def type_from_annotation( def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: namespace = get_namespace() - def _failwith(type_name): - suggestions_str = get_levenshtein_error_suggestions(type_name, namespace, 0.3) - raise UnknownType( - f"No builtin or user-defined type named '{type_name}'. {suggestions_str}", node - ) from None - if isinstance(node, vy_ast.Tuple): tuple_t = namespace["$TupleT"] return tuple_t.from_annotation(node) @@ -116,11 +111,43 @@ def _failwith(type_name): return type_ctor.from_annotation(node) + # prepare a common error message + err_msg = f"'{node.node_source_code}' is not a type!" + + if isinstance(node, vy_ast.Attribute): + # ex. SomeModule.SomeStruct + + # sanity check - we only allow modules/interfaces to be + # imported as `Name`s currently. + if not isinstance(node.value, vy_ast.Name): + raise InvalidType(err_msg, node) + + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + + interface = module_or_interface + if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo + interface = module_or_interface.module_t.interface + + if not interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = interface.get_type_member(node.attr, node) + assert isinstance(type_t, TYPE_T) # sanity check + return type_t.typedef + if not isinstance(node, vy_ast.Name): # maybe handle this somewhere upstream in ast validation - raise InvalidType(f"'{node.node_source_code}' is not a type", node) - if node.id not in namespace: - _failwith(node.node_source_code) + raise InvalidType(err_msg, node) + + if node.id not in namespace: # type: ignore + suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) + raise UnknownType( + f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}", + node, + ) from None typ_ = namespace[node.id] if hasattr(typ_, "from_annotation"): @@ -138,7 +165,7 @@ def get_index_value(node: vy_ast.Index) -> int: Arguments --------- - node : vy_ast.Index + node: vy_ast.Index Vyper ast node from the `slice` member of a Subscript node. Must be an `Index` object (Vyper does not support `Slice` or `ExtSlice`). @@ -146,6 +173,7 @@ def get_index_value(node: vy_ast.Index) -> int: ------- int Literal integer value. + In the future, will return `None` if the subscript is an Ellipsis """ # this is imported to improve error messages # TODO: revisit this! diff --git a/vyper/utils.py b/vyper/utils.py index 0a2e1f831f..6816db9bae 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -51,6 +51,10 @@ def difference(self, other): def union(self, other): return self | other + def update(self, other): + for item in other: + self.add(item) + def __or__(self, other): return self.__class__(super().__or__(other)) @@ -162,11 +166,6 @@ def method_id(method_str: str) -> bytes: return keccak256(bytes(method_str, "utf-8"))[:4] -# map a string to only-alphanumeric chars -def mkalphanum(s): - return "".join([c if c.isalnum() else "_" for c in s]) - - def round_towards_zero(d: decimal.Decimal) -> int: # TODO double check if this can just be int(d) # (but either way keep this util function bc it's easier at a glance From 0cbc94d01d7be616329a9f70df15733818b4590c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 16 Dec 2023 13:08:34 -0500 Subject: [PATCH 139/201] feat: add short options `-v` and `-O` to the CLI (#3695) this commit adds `-v` and `-O` as aliases for `--verbose` and `--optimize`, respectively. --- vyper/cli/vyper_compile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 4f88812fa0..ec4681a814 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -111,6 +111,7 @@ def _parse_args(argv): ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") parser.add_argument( + "-O", "--optimize", help="Optimization flag (defaults to 'gas')", choices=["gas", "codesize", "none"], @@ -125,6 +126,7 @@ def _parse_args(argv): type=int, ) parser.add_argument( + "-v", "--verbose", help="Turn on compiler verbose output. " "Currently an alias for --traceback-limit but " From b0ea5b6f1c8cd8d09db6f37e9857f9b3837fb386 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 16 Dec 2023 13:42:31 -0500 Subject: [PATCH 140/201] feat: search path resolution for cli (#3694) the current behavior is that the current directory does *not* get into the search path when `-p` is specified, which is annoying. (one would expect `vyper some/directory/some/file.vy` to compile no matter what `-p` is specified as). this commit also handles the addition of multiple search paths specified on the CLI, and adds a long `--path` option as an alternative to `-p`. --- .../cli/vyper_compile/test_compile_files.py | 36 ++++++++++++------- tests/unit/compiler/test_input_bundle.py | 13 +------ tests/utils.py | 12 +++++++ vyper/cli/vyper_compile.py | 19 ++++++---- 4 files changed, 49 insertions(+), 31 deletions(-) create mode 100644 tests/utils.py diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index f6e3a51a4b..2a65d66835 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -2,6 +2,7 @@ import pytest +from tests.utils import working_directory from vyper.cli.vyper_compile import compile_files @@ -19,7 +20,7 @@ def test_combined_json_keys(tmp_path, make_file): "userdoc", "devdoc", } - compile_data = compile_files(["bar.vy"], ["combined_json"], root_folder=tmp_path) + compile_data = compile_files(["bar.vy"], ["combined_json"], paths=[tmp_path]) assert set(compile_data.keys()) == {Path("bar.vy"), "version"} assert set(compile_data[Path("bar.vy")].keys()) == combined_keys @@ -27,7 +28,7 @@ def test_combined_json_keys(tmp_path, make_file): def test_invalid_root_path(): with pytest.raises(FileNotFoundError): - compile_files([], [], root_folder="path/that/does/not/exist") + compile_files([], [], paths=["path/that/does/not/exist"]) CONTRACT_CODE = """ @@ -74,7 +75,7 @@ def test_import_same_folder(import_stmt, alias, tmp_path, make_file): make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) make_file("contracts/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) SUBFOLDER_IMPORT_STMT = [ @@ -98,7 +99,7 @@ def test_import_subfolder(import_stmt, alias, tmp_path, make_file): ) make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) OTHER_FOLDER_IMPORT_STMT = [ @@ -115,7 +116,7 @@ def test_import_other_folder(import_stmt, alias, tmp_path, make_file): foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) make_file("interfaces/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) def test_import_parent_folder(tmp_path, make_file): @@ -125,10 +126,21 @@ def test_import_parent_folder(tmp_path, make_file): ) make_file("IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) # perform relative import outside of base folder - compile_files([foo], ["combined_json"], root_folder=tmp_path / "contracts") + compile_files([foo], ["combined_json"], paths=[tmp_path / "contracts"]) + + +def test_import_search_paths(tmp_path, make_file): + with working_directory(tmp_path): + contract_code = CONTRACT_CODE.format(import_stmt="from utils import IFoo", alias="IFoo") + contract_filename = "dir1/baz/foo.vy" + interface_filename = "dir2/utils/IFoo.vyi" + make_file(interface_filename, INTERFACE_CODE) + make_file(contract_filename, contract_code) + + assert compile_files([contract_filename], ["combined_json"], paths=["dir2"]) META_IMPORT_STMT = [ @@ -167,7 +179,7 @@ def be_known() -> ISelf.FooStruct: make_file("contracts/ISelf.vyi", interface_code) meta = make_file("contracts/Self.vy", code) - assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + assert compile_files([meta], ["combined_json"], paths=[tmp_path]) # implement IFoo in another contract for fun @@ -187,7 +199,7 @@ def bar(_foo: address) -> {alias}.FooStruct: make_file("contracts/IFoo.vyi", INTERFACE_CODE) baz = make_file("contracts/Baz.vy", baz_code) - assert compile_files([baz], ["combined_json"], root_folder=tmp_path) + assert compile_files([baz], ["combined_json"], paths=[tmp_path]) def test_local_namespace(make_file, tmp_path): @@ -215,7 +227,7 @@ def test_local_namespace(make_file, tmp_path): for file_name in ("foo.vyi", "bar.vyi"): make_file(file_name, INTERFACE_CODE) - assert compile_files(paths, ["combined_json"], root_folder=tmp_path) + assert compile_files(paths, ["combined_json"], paths=[tmp_path]) def test_compile_outside_root_path(tmp_path, make_file): @@ -223,7 +235,7 @@ def test_compile_outside_root_path(tmp_path, make_file): make_file("ifoo.vyi", INTERFACE_CODE) foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) - assert compile_files([foo], ["combined_json"], root_folder=".") + assert compile_files([foo], ["combined_json"], paths=None) def test_import_library(tmp_path, make_file): @@ -244,4 +256,4 @@ def foo() -> uint256: make_file("lib.vy", library_source) contract_file = make_file("contract.vy", contract_source) - assert compile_files([contract_file], ["combined_json"], root_folder=tmp_path) is not None + assert compile_files([contract_file], ["combined_json"], paths=[tmp_path]) is not None diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index e26555b169..621b529722 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -1,10 +1,9 @@ -import contextlib import json -import os from pathlib import Path, PurePath import pytest +from tests.utils import working_directory from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, JSONInputBundle @@ -83,16 +82,6 @@ def test_load_abi(make_file, input_bundle, tmp_path): assert file == ABIInput(1, "foo.txt", path, "some string") -@contextlib.contextmanager -def working_directory(directory): - tmp = os.getcwd() - try: - os.chdir(directory) - yield - finally: - os.chdir(tmp) - - # check that unique paths give unique source ids def test_source_id_file_input(make_file, input_bundle, tmp_path): foopath = make_file("foo.vy", "contents") diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..0c89c39ff3 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,12 @@ +import contextlib +import os + + +@contextlib.contextmanager +def working_directory(directory): + tmp = os.getcwd() + try: + os.chdir(directory) + yield + finally: + os.chdir(tmp) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ec4681a814..25f1180098 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -140,7 +140,7 @@ def _parse_args(argv): ) parser.add_argument("--hex-ir", action="store_true") parser.add_argument( - "-p", help="Set the root path for contract imports", default=".", dest="root_folder" + "--path", "-p", help="Set the root path for contract imports", action="append", dest="paths" ) parser.add_argument("-o", help="Set the output path", dest="output_path") parser.add_argument( @@ -190,7 +190,7 @@ def _parse_args(argv): compiled = compile_files( args.input_files, output_formats, - args.root_folder, + args.paths, args.show_gas_estimates, settings, args.storage_layout, @@ -228,18 +228,23 @@ def exc_handler(contract_path: ContractPath, exception: Exception) -> None: def compile_files( input_files: list[str], output_formats: OutputFormats, - root_folder: str = ".", + paths: list[str] = None, show_gas_estimates: bool = False, settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, experimental_codegen: bool = False, ) -> dict: - root_path = Path(root_folder).resolve() - if not root_path.exists(): - raise FileNotFoundError(f"Invalid root path - '{root_path.as_posix()}' does not exist") + paths = paths or [] - input_bundle = FilesystemInputBundle([root_path]) + # lowest precedence search path is always `.` + search_paths = [Path(".")] + + for p in paths: + path = Path(p).resolve(strict=True) + search_paths.append(path) + + input_bundle = FilesystemInputBundle(search_paths) show_version = False if "combined_json" in output_formats: From 5a67b68b4ba20d050e9a4af913823cbbf0007539 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Wed, 20 Dec 2023 16:12:56 +0200 Subject: [PATCH 141/201] fix: type annotation of helper function (#3702) Fixed the signature of _append_return_for_stack_operand() to take the context not the basic block --- vyper/venom/ir_node_to_venom.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index e2ce28a8f9..0aaf6aba03 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -233,8 +233,9 @@ def _get_variable_from_address( def _append_return_for_stack_operand( - bb: IRBasicBlock, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable + ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable ) -> None: + bb = ctx.get_basic_block() if isinstance(ret_ir, IRLiteral): sym = symbols.get(f"&{ret_ir.value}", None) new_var = bb.append_instruction("alloca", 32, ret_ir) From 91659266c55ac564d1ed7784a189f5b59b868ced Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 14:41:53 -0500 Subject: [PATCH 142/201] chore: improve exception handling in IR generation (#3705) QOL improvement - improve unannotated exceptions that happen during IR generation to include source info. --- vyper/codegen/expr.py | 6 +++++- vyper/codegen/stmt.py | 11 +++++++++-- vyper/exceptions.py | 20 ++++++++++++++++++-- vyper/semantics/analysis/common.py | 13 +------------ 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 5870e64e98..d5ca5aceee 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -26,6 +26,7 @@ from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import ( + CodegenPanic, CompilerPanic, EvmVersionException, StructureException, @@ -33,6 +34,7 @@ TypeMismatch, UnimplementedException, VyperException, + tag_exceptions, ) from vyper.semantics.types import ( AddressT, @@ -79,7 +81,9 @@ def __init__(self, node, context): if fn is None: raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}", node) - self.ir_node = fn() + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + self.ir_node = fn() + if self.ir_node is None: raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index cc7a603b7c..601597771c 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -24,7 +24,13 @@ from vyper.codegen.expr import Expr from vyper.codegen.return_ import make_return_stmt from vyper.evm.address_space import MEMORY, STORAGE -from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure +from vyper.exceptions import ( + CodegenPanic, + CompilerPanic, + StructureException, + TypeCheckFailure, + tag_exceptions, +) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -39,7 +45,8 @@ def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}") with context.internal_memory_scope(): - self.ir_node = fn() + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + self.ir_node = fn() if self.ir_node is None: raise TypeCheckFailure("Statement node did not produce IR") diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 993c0a85eb..4846b1c3b1 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -1,3 +1,4 @@ +import contextlib import copy import textwrap import types @@ -322,8 +323,9 @@ class VyperInternalException(_BaseVyperException): def __str__(self): return ( - f"{self.message}\n\nThis is an unhandled internal compiler error. " - "Please create an issue on Github to notify the developers.\n" + f"{super().__str__()}\n\n" + "This is an unhandled internal compiler error. " + "Please create an issue on Github to notify the developers!\n" "https://github.com/vyperlang/vyper/issues/new?template=bug.md" ) @@ -354,3 +356,17 @@ class TypeCheckFailure(VyperInternalException): class InvalidABIType(VyperInternalException): """An internal routine constructed an invalid ABI type""" + + +@contextlib.contextmanager +def tag_exceptions( + node, fallback_exception_type=CompilerPanic, fallback_message="unhandled exception" +): + try: + yield + except _BaseVyperException as e: + if not e.annotations and not e.lineno: + raise e.with_annotation(node) from None + raise e from None + except Exception as e: + raise fallback_exception_type(fallback_message, node) from e diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 9d35aef2bd..198cffca5d 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -1,17 +1,6 @@ -import contextlib from typing import Tuple -from vyper.exceptions import StructureException, VyperException - - -@contextlib.contextmanager -def tag_exceptions(node): - try: - yield - except VyperException as e: - if not e.annotations and not e.lineno: - raise e.with_annotation(node) from None - raise e from None +from vyper.exceptions import StructureException, tag_exceptions class VyperNodeVisitorBase: From 3116e88c886efaf0ea4157852c6c90485357cee7 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Thu, 21 Dec 2023 02:41:42 +0200 Subject: [PATCH 143/201] feat: add new target-constrained jump instruction (#3687) this commit adds a new "djmp" instruction which allows jumping to one of multiple jump targets. it has been added in both the s-expr IR and venom IR. this removes the workarounds that we had to implement in the normalization pass and the cfg calculations. --------- Co-authored-by: Charles Cooper --- tests/unit/ast/test_pre_parser.py | 3 + .../compiler/venom/test_multi_entry_block.py | 41 +++++++++++ vyper/cli/vyper_compile.py | 7 +- vyper/codegen/core.py | 2 + vyper/codegen/module.py | 18 ++--- vyper/compiler/__init__.py | 2 - vyper/compiler/output.py | 3 + vyper/compiler/phases.py | 67 ++++++++--------- vyper/compiler/settings.py | 1 + vyper/ir/compile_ir.py | 7 ++ vyper/utils.py | 1 + vyper/venom/analysis.py | 9 --- vyper/venom/basicblock.py | 17 ++++- vyper/venom/function.py | 12 +--- vyper/venom/ir_node_to_venom.py | 9 +-- vyper/venom/passes/normalization.py | 71 ++++++------------- vyper/venom/venom_to_assembly.py | 15 ++-- 17 files changed, 158 insertions(+), 127 deletions(-) diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 3d072674f6..682c13ca84 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -184,6 +184,9 @@ def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_ve # None is sentinel here meaning that nothing changed compiler_data_settings = pre_parse_settings + # cannot be set via pragma, don't check + compiler_data_settings.experimental_codegen = False + assert compiler_data.settings == compiler_data_settings diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 6e7e6995d6..104697432b 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -95,3 +95,44 @@ def test_multi_entry_block_2(): assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + + +def test_multi_entry_block_with_dynamic_jump(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + acc = bb.append_instruction("add", op, op) + bb.append_instruction("djmp", acc, finish_label, block_1_label) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = block_1.append_instruction("add", acc, op) + op = block_1.append_instruction("store", 10) + block_1.append_instruction("mstore", acc, op) + block_1.append_instruction("jnz", acc, finish_label, target_label) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + target_bb.append_instruction("mul", acc, acc) + target_bb.append_instruction("jmp", finish_label) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + finish_bb.append_instruction("stop") + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 25f1180098..3063a289ab 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -147,6 +147,7 @@ def _parse_args(argv): "--experimental-codegen", help="The compiler use the new IR codegen. This is an experimental feature.", action="store_true", + dest="experimental_codegen", ) args = parser.parse_args(argv) @@ -184,6 +185,9 @@ def _parse_args(argv): if args.evm_version: settings.evm_version = args.evm_version + if args.experimental_codegen: + settings.experimental_codegen = args.experimental_codegen + if args.verbose: print(f"cli specified: `{settings}`", file=sys.stderr) @@ -195,7 +199,6 @@ def _parse_args(argv): settings, args.storage_layout, args.no_bytecode_metadata, - args.experimental_codegen, ) if args.output_path: @@ -233,7 +236,6 @@ def compile_files( settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, - experimental_codegen: bool = False, ) -> dict: paths = paths or [] @@ -287,7 +289,6 @@ def compile_files( storage_layout_override=storage_layout_override, show_gas_estimates=show_gas_estimates, no_bytecode_metadata=no_bytecode_metadata, - experimental_codegen=experimental_codegen, ) ret[file_path] = output diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index e1d3ea12b4..503e0e2f3b 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -892,6 +892,8 @@ def make_setter(left, right): _opt_level = OptimizationLevel.GAS +# FIXME: this is to get around the fact that we don't have a +# proper context object in the IR generation phase. @contextlib.contextmanager def anchor_opt_level(new_level: OptimizationLevel) -> Generator: """ diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index ef861e3953..98395a6a0c 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -311,21 +311,23 @@ def _selector_section_sparse(external_functions, module_ctx): ret.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER]) - jumpdest = IRnode.from_list(["mload", 0]) - # don't particularly like using `jump` here since it can cause - # issues for other backends, consider changing `goto` to allow - # dynamic jumps, or adding some kind of jumptable instruction - ret.append(["jump", jumpdest]) + jump_targets = [] - jumptable_data = ["data", "selector_buckets"] for i in range(n_buckets): if i in buckets: bucket_label = f"selector_bucket_{i}" - jumptable_data.append(["symbol", bucket_label]) + jump_targets.append(bucket_label) else: # empty bucket - jumptable_data.append(["symbol", "fallback"]) + jump_targets.append("fallback") + + jumptable_data = ["data", "selector_buckets"] + jumptable_data.extend(["symbol", label] for label in jump_targets) + + jumpdest = IRnode.from_list(["mload", 0]) + jump_instr = IRnode.from_list(["djump", jumpdest, *jump_targets]) + ret.append(jump_instr) ret.append(jumptable_data) for bucket_id, bucket in buckets.items(): diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 026c8369c5..c87814ba15 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -53,7 +53,6 @@ def compile_from_file_input( no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, exc_handler: Optional[Callable] = None, - experimental_codegen: bool = False, ) -> dict: """ Main entry point into the compiler. @@ -107,7 +106,6 @@ def compile_from_file_input( storage_layout_override, show_gas_estimates, no_bytecode_metadata, - experimental_codegen, ) ret = {} diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 6d1e7ef70f..dc2a43720e 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -89,6 +89,9 @@ def build_ir_runtime_output(compiler_data: CompilerData) -> IRnode: def _ir_to_dict(ir_node): + # Currently only supported with IRnode and not VenomIR + if not isinstance(ir_node, IRnode): + return args = ir_node.args if len(args) > 0 or ir_node.value == "seq": return {ir_node.value: [_ir_to_dict(x) for x in args]} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index edffa9a85e..199bbbc3e5 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -21,6 +21,26 @@ DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") +def _merge_one(lhs, rhs, helpstr): + if lhs is not None and rhs is not None and lhs != rhs: + raise StructureException( + f"compiler settings indicate {helpstr} {lhs}, " f"but source pragma indicates {rhs}." + ) + return lhs if rhs is None else rhs + + +# TODO: does this belong as a method under Settings? +def _merge_settings(cli: Settings, pragma: Settings): + ret = Settings() + ret.evm_version = _merge_one(cli.evm_version, pragma.evm_version, "evm version") + ret.optimize = _merge_one(cli.optimize, pragma.optimize, "optimize") + ret.experimental_codegen = _merge_one( + cli.experimental_codegen, pragma.experimental_codegen, "experimental codegen" + ) + + return ret + + class CompilerData: """ Object for fetching and storing compiler data for a Vyper contract. @@ -59,7 +79,6 @@ def __init__( storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, - experimental_codegen: bool = False, ) -> None: """ Initialization method. @@ -76,11 +95,9 @@ def __init__( Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False - experimental_codegen: bool, optional - Use experimental codegen. Defaults to False """ # to force experimental codegen, uncomment: - # experimental_codegen = True + # settings.experimental_codegen = True if isinstance(file_input, str): file_input = FileInput( @@ -93,7 +110,6 @@ def __init__( self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - self.experimental_codegen = experimental_codegen self.settings = settings or Settings() self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) @@ -120,32 +136,13 @@ def _generate_ast(self): resolved_path=str(self.file_input.resolved_path), ) - # validate the compiler settings - # XXX: this is a bit ugly, clean up later - if settings.evm_version is not None: - if ( - self.settings.evm_version is not None - and self.settings.evm_version != settings.evm_version - ): - raise StructureException( - f"compiler settings indicate evm version {self.settings.evm_version}, " - f"but source pragma indicates {settings.evm_version}." - ) - - self.settings.evm_version = settings.evm_version - - if settings.optimize is not None: - if self.settings.optimize is not None and self.settings.optimize != settings.optimize: - raise StructureException( - f"compiler options indicate optimization mode {self.settings.optimize}, " - f"but source pragma indicates {settings.optimize}." - ) - self.settings.optimize = settings.optimize - - # ensure defaults + self.settings = _merge_settings(self.settings, settings) if self.settings.optimize is None: self.settings.optimize = OptimizationLevel.default() + if self.settings.experimental_codegen is None: + self.settings.experimental_codegen = False + # note self.settings.compiler_version is erased here as it is # not used after pre-parsing return ast @@ -184,8 +181,10 @@ def global_ctx(self) -> ModuleT: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - nodes = generate_ir_nodes(self.global_ctx, self.settings.optimize) - if self.experimental_codegen: + nodes = generate_ir_nodes( + self.global_ctx, self.settings.optimize, self.settings.experimental_codegen + ) + if self.settings.experimental_codegen: return [generate_ir(nodes[0]), generate_ir(nodes[1])] else: return nodes @@ -211,7 +210,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - if self.experimental_codegen: + if self.settings.experimental_codegen: return generate_assembly_experimental( self.ir_nodes, self.settings.optimize # type: ignore ) @@ -220,7 +219,7 @@ def assembly(self) -> list: @cached_property def assembly_runtime(self) -> list: - if self.experimental_codegen: + if self.settings.experimental_codegen: return generate_assembly_experimental( self.ir_runtime, self.settings.optimize # type: ignore ) @@ -294,7 +293,9 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: +def generate_ir_nodes( + global_ctx: ModuleT, optimize: OptimizationLevel, experimental_codegen: bool +) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index d2c88a8592..51c8d64e41 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -42,6 +42,7 @@ class Settings: compiler_version: Optional[str] = None optimize: Optional[OptimizationLevel] = None evm_version: Optional[str] = None + experimental_codegen: Optional[bool] = None _DEBUG = False diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 1d3df8becb..8ce8c887f1 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -702,6 +702,13 @@ def _height_of(witharg): o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) o.extend(["_sym_" + code.args[0].value, "JUMP"]) return o + elif code.value == "djump": + o = [] + # "djump" compiles to a raw EVM jump instruction + jump_target = code.args[0] + o.extend(_compile_to_assembly(jump_target, withargs, existing_labels, break_dest, height)) + o.append("JUMP") + return o # push a literal symbol elif code.value == "symbol": return ["_sym_" + code.args[0].value] diff --git a/vyper/utils.py b/vyper/utils.py index 6816db9bae..a778a4e31b 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -331,6 +331,7 @@ class SizeLimits: "with", "label", "goto", + "djump", # "dynamic jump", i.e. constrained, multi-destination jump "~extcode", "~selfcode", "~calldata", diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 1a82ca85d0..6dfc3c3d7c 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -40,15 +40,6 @@ def calculate_cfg(ctx: IRFunction) -> None: else: entry_block = ctx.basic_blocks[0] - # TODO: Special case for the jump table of selector buckets and fallback. - # this will be cleaner when we introduce an "indirect jump" instruction - # for the selector table (which includes all possible targets). it will - # also clean up the code for normalization because it will not have to - # handle this case specially. - for bb in ctx.basic_blocks: - if "selector_bucket_" in bb.label.value or bb.label.value == "fallback": - bb.add_cfg_in(entry_block) - for bb in ctx.basic_blocks: assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 6f1c1c8ab3..9afaa5e6fd 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -4,7 +4,7 @@ from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -50,12 +50,15 @@ "invalid", "invoke", "jmp", + "djmp", "jnz", "log", ] ) -CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) +CFG_ALTERING_INSTRUCTIONS = frozenset( + ["jmp", "djmp", "jnz", "call", "staticcall", "invoke", "deploy"] +) if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -236,6 +239,16 @@ def replace_operands(self, replacements: dict) -> None: if operand in replacements: self.operands[i] = replacements[operand] + def replace_label_operands(self, replacements: dict) -> None: + """ + Update label operands with replacements. + replacements are represented using a dict: "key" is replaced by "value". + """ + replacements = {k.value: v for k, v in replacements.items()} + for i, operand in enumerate(self.operands): + if isinstance(operand, IRLabel) and operand.value in replacements: + self.operands[i] = replacements[operand.value] + def __repr__(self) -> str: s = "" if self.output: diff --git a/vyper/venom/function.py b/vyper/venom/function.py index e16b2ad6e6..665fa0c6c2 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -125,17 +125,11 @@ def normalized(self) -> bool: # TODO: this check could be: # `if len(in_bb.cfg_out) > 1: return False` # but the cfg is currently not calculated "correctly" for - # certain special instructions (deploy instruction and - # selector table indirect jumps). + # the special deploy instruction. for in_bb in bb.cfg_in: jump_inst = in_bb.instructions[-1] - if jump_inst.opcode != "jnz": - continue - if jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRLabel): - continue - - # The function is not normalized - return False + if jump_inst.opcode in ("jnz", "djmp"): + return False # The function is normalized return True diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 0aaf6aba03..9f5c23df0b 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -166,7 +166,6 @@ def _handle_self_call( ret_args.append(return_buf.value) # type: ignore bb = ctx.get_basic_block() - do_ret = func_t.return_type is not None if do_ret: invoke_ret = bb.append_invoke_instruction(ret_args, returns=True) # type: ignore @@ -453,9 +452,11 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ) # body elif ir.value == "goto": _append_jmp(ctx, IRLabel(ir.args[0].value)) - elif ir.value == "jump": - arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction("jmp", arg_1) + elif ir.value == "djump": + args = [_convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables)] + for target in ir.args[1:]: + args.append(IRLabel(target.value)) + ctx.get_basic_block().append_instruction("djmp", *args) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 90dd60e881..43e8d47235 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -1,5 +1,5 @@ -from vyper.exceptions import CompilerPanic -from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -19,72 +19,43 @@ def _split_basic_block(self, bb: IRBasicBlock) -> None: jump_inst = in_bb.instructions[-1] assert bb in in_bb.cfg_out - # Handle static and dynamic branching - if jump_inst.opcode == "jnz": - self._split_for_static_branch(bb, in_bb) - elif jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRVariable): - self._split_for_dynamic_branch(bb, in_bb) - else: - continue - - self.changes += 1 - - def _split_for_static_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: - jump_inst = in_bb.instructions[-1] - for i, op in enumerate(jump_inst.operands): - if op == bb.label: - edge = i + # Handle branching + if jump_inst.opcode in ("jnz", "djmp"): + self._insert_split_basicblock(bb, in_bb) + self.changes += 1 break - else: - # none of the edges points to this bb - raise CompilerPanic("bad CFG") - - assert edge in (1, 2) # the arguments which can be labels - - split_bb = self._insert_split_basicblock(bb, in_bb) - - # Redirect the original conditional jump to the intermediary basic block - jump_inst.operands[edge] = split_bb.label - - def _split_for_dynamic_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: - split_bb = self._insert_split_basicblock(bb, in_bb) - - # Update any affected labels in the data segment - # TODO: this DESTROYS the cfg! refactor so the translation of the - # selector table produces indirect jumps properly. - for inst in self.ctx.data_segment: - if inst.opcode == "db" and inst.operands[0] == bb.label: - inst.operands[0] = split_bb.label def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRBasicBlock: # Create an intermediary basic block and append it source = in_bb.label.value target = bb.label.value - split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) + + split_label = IRLabel(f"{target}_split_{source}") + in_terminal = in_bb.instructions[-1] + in_terminal.replace_label_operands({bb.label: split_label}) + + split_bb = IRBasicBlock(split_label, self.ctx) split_bb.append_instruction("jmp", bb.label) self.ctx.append_basic_block(split_bb) - # Rewire the CFG - # TODO: this is cursed code, it is necessary instead of just running - # calculate_cfg() because split_for_dynamic_branch destroys the CFG! - # ideally, remove this rewiring and just re-run calculate_cfg(). - split_bb.add_cfg_in(in_bb) - split_bb.add_cfg_out(bb) - in_bb.remove_cfg_out(bb) - in_bb.add_cfg_out(split_bb) - bb.remove_cfg_in(in_bb) - bb.add_cfg_in(split_bb) + # Update the labels in the data segment + for inst in self.ctx.data_segment: + if inst.opcode == "db" and inst.operands[0] == bb.label: + inst.operands[0] = split_bb.label + return split_bb def _run_pass(self, ctx: IRFunction) -> int: self.ctx = ctx self.changes = 0 + # Split blocks that need splitting for bb in ctx.basic_blocks: if len(bb.cfg_in) > 1: self._split_basic_block(bb) - # Sanity check - assert ctx.normalized, "Normalization pass failed" + # If we made changes, recalculate the cfg + if self.changes > 0: + calculate_cfg(ctx) return self.changes diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 8760e9aa63..0c32c3b816 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -261,7 +261,7 @@ def _generate_evm_for_instruction( # Step 1: Apply instruction special stack manipulations - if opcode in ["jmp", "jnz", "invoke"]: + if opcode in ["jmp", "djmp", "jnz", "invoke"]: operands = inst.get_non_label_operands() elif opcode == "alloca": operands = inst.operands[1:2] @@ -296,7 +296,7 @@ def _generate_evm_for_instruction( self._emit_input_operands(assembly, inst, operands, stack) # Step 3: Reorder stack - if opcode in ["jnz", "jmp"]: + if opcode in ["jnz", "djmp", "jmp"]: # prepare stack for jump into another basic block assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) b = next(iter(inst.parent.cfg_out)) @@ -344,11 +344,12 @@ def _generate_evm_for_instruction( assembly.append("JUMP") elif opcode == "jmp": - if isinstance(inst.operands[0], IRLabel): - assembly.append(f"_sym_{inst.operands[0].value}") - assembly.append("JUMP") - else: - assembly.append("JUMP") + assert isinstance(inst.operands[0], IRLabel) + assembly.append(f"_sym_{inst.operands[0].value}") + assembly.append("JUMP") + elif opcode == "djmp": + assert isinstance(inst.operands[0], IRVariable) + assembly.append("JUMP") elif opcode == "gt": assembly.append("GT") elif opcode == "lt": From 8958bffc7755d8d99300803be9a07245fb345593 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Thu, 21 Dec 2023 18:18:05 +0100 Subject: [PATCH 144/201] chore: update lint dependencies (#3704) - Updated dependencies - Fixed new discovered issues - Most issues were caused by B023: https://docs.astral.sh/ruff/rules/function-uses-loop-variable/ - The issue with using loop variables is explained further here: https://docs.python-guide.org/writing/gotchas/#late-binding-closures --- .pre-commit-config.yaml | 14 +++++++++---- setup.py | 12 +++++------ .../builtins/codegen/test_extract32.py | 2 +- .../test_default_function.py | 2 +- .../features/decorators/test_payable.py | 6 ++++-- .../codegen/features/test_clampers.py | 11 ++++++---- .../functional/codegen/test_selector_table.py | 8 ++++---- .../codegen/types/numbers/test_signed_ints.py | 20 ++++++++++--------- .../types/numbers/test_unsigned_ints.py | 20 ++++++++++--------- .../codegen/types/test_node_types.py | 2 +- tests/functional/syntax/test_address_code.py | 1 - tests/functional/syntax/test_immutables.py | 2 +- tests/functional/syntax/test_no_none.py | 16 +++++++++++---- .../unit/abi_types/test_invalid_abi_types.py | 2 +- vyper/cli/vyper_json.py | 3 ++- vyper/compiler/phases.py | 3 ++- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/function.py | 1 + vyper/semantics/types/module.py | 15 ++++++++------ vyper/semantics/types/subscriptable.py | 2 +- vyper/utils.py | 4 +++- 21 files changed, 89 insertions(+), 59 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4b416a4414..b943b5d31d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,27 +1,33 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.9.3 + rev: 5.13.2 hooks: - id: isort name: isort - repo: https://github.com/psf/black - rev: 21.9b0 + rev: 23.12.0 hooks: - id: black name: black - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 6.1.0 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v1.7.1 hooks: - id: mypy additional_dependencies: - "types-setuptools" + args: # settings from tox.ini + - --install-types + - --non-interactive + - --follow-imports=silent + - --ignore-missing-imports + - --implicit-optional default_language_version: python: python3.10 diff --git a/setup.py b/setup.py index 431c50b74b..f5d643ad88 100644 --- a/setup.py +++ b/setup.py @@ -22,12 +22,12 @@ "eth-stdlib==0.2.6", ], "lint": [ - "black==23.3.0", - "flake8==3.9.2", - "flake8-bugbear==20.1.4", - "flake8-use-fstring==1.1", - "isort==5.9.3", - "mypy==0.982", + "black==23.12.0", + "flake8==6.1.0", + "flake8-bugbear==23.12.2", + "flake8-use-fstring==1.4", + "isort==5.13.2", + "mypy==1.5", ], "docs": ["recommonmark", "sphinx>=6.0,<7.0", "sphinx_rtd_theme>=1.2,<1.3"], "dev": ["ipython", "pre-commit", "pyinstaller", "twine"], diff --git a/tests/functional/builtins/codegen/test_extract32.py b/tests/functional/builtins/codegen/test_extract32.py index c1a333ae32..6e4ee09abc 100644 --- a/tests/functional/builtins/codegen/test_extract32.py +++ b/tests/functional/builtins/codegen/test_extract32.py @@ -36,7 +36,7 @@ def extrakt32_storage(index: uint256, inp: Bytes[100]) -> bytes32: for S, i in test_cases: expected_result = S[i : i + 32] if 0 <= i <= len(S) - 32 else None if expected_result is None: - assert_tx_failed(lambda: c.extrakt32(S, i)) + assert_tx_failed(lambda p=(S, i): c.extrakt32(*p)) else: assert c.extrakt32(S, i) == expected_result assert c.extrakt32_mem(S, i) == expected_result diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index 4ad68697ac..f7eef21af7 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -143,7 +143,7 @@ def _call_with_bytes(hexstr): for i in range(4, 36): # match the full 4 selector bytes, but revert due to malformed (short) calldata - assert_tx_failed(lambda: _call_with_bytes("0x" + "00" * i)) + assert_tx_failed(lambda p="0x" + "00" * i: _call_with_bytes(p)) def test_another_zero_method_id(w3, get_logs, get_contract, assert_tx_failed): diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index 55c60236f4..4858a7df0d 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -352,7 +352,7 @@ def __default__(): """ c = get_contract(code) - w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}), + w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) def test_nonpayable_default_func_invalid_calldata(get_contract, w3, assert_tx_failed): @@ -391,5 +391,7 @@ def __default__(): for i in range(5): calldata = "0x" + data[:i].hex() assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": calldata}) + lambda data=calldata: w3.eth.send_transaction( + {"to": c.address, "value": 100, "data": data} + ) ) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index ad7ea32b1e..08ad349c09 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -118,8 +118,11 @@ def foo(s: bytes{n}) -> bytes{n}: c = get_contract(code, evm_version=evm_version) for v in values: # munge for `_make_tx` - v = int.from_bytes(v, byteorder="big") - assert_tx_failed(lambda: _make_tx(w3, c.address, f"foo(bytes{n})", [v])) + assert_tx_failed( + lambda val=int.from_bytes(v, byteorder="big"): _make_tx( + w3, c.address, f"foo(bytes{n})", [val] + ) + ) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -153,7 +156,7 @@ def foo(s: int{bits}) -> int{bits}: c = get_contract(code, evm_version=evm_version) for v in values: - assert_tx_failed(lambda: _make_tx(w3, c.address, f"foo(int{bits})", [v])) + assert_tx_failed(lambda val=v: _make_tx(w3, c.address, f"foo(int{bits})", [val])) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -250,7 +253,7 @@ def foo(s: uint{bits}) -> uint{bits}: """ c = get_contract(code, evm_version=evm_version) for v in values: - assert_tx_failed(lambda: _make_tx(w3, c.address, f"foo(uint{bits})", [v])) + assert_tx_failed(lambda val=v: _make_tx(w3, c.address, f"foo(uint{bits})", [val])) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) diff --git a/tests/functional/codegen/test_selector_table.py b/tests/functional/codegen/test_selector_table.py index 161cd480fd..abea81ced4 100644 --- a/tests/functional/codegen/test_selector_table.py +++ b/tests/functional/codegen/test_selector_table.py @@ -600,7 +600,7 @@ def __default__(): else: hexstr = (method_id + argsdata).hex() txdata = {"to": c.address, "data": hexstr, "value": 1} - assert_tx_failed(lambda: w3.eth.send_transaction(txdata)) + assert_tx_failed(lambda d=txdata: w3.eth.send_transaction(d)) # now do calldatasize check # strip some bytes @@ -610,7 +610,7 @@ def __default__(): if n_calldata_words == 0 and j == 0: # no args, hit default function if default_fn_mutability == "": - assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) elif default_fn_mutability == "@payable": # we should be able to send eth to it tx_params["value"] = 1 @@ -628,8 +628,8 @@ def __default__(): # check default function reverts tx_params["value"] = 1 - assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) else: - assert_tx_failed(lambda: w3.eth.send_transaction(tx_params)) + assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) _test() diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 281aab429c..3e44beb826 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -116,7 +116,7 @@ def foo(x: {typ}) -> {typ}: test_cases = [0, 1, 3, 4, 126, 127, -1, lo, hi] for x in test_cases: if x * 2 >= typ.bits or x < 0: # out of bounds - assert_tx_failed(lambda: c.foo(x)) + assert_tx_failed(lambda p=x: c.foo(p)) else: assert c.foo(x) == 4**x @@ -304,15 +304,17 @@ def foo() -> {typ}: assert get_contract(code_3).foo(y) == expected assert get_contract(code_4).foo() == expected elif div_by_zero: - assert_tx_failed(lambda: c.foo(x, y)) - assert_compile_failed(lambda: get_contract(code_2), ZeroDivisionException) - assert_tx_failed(lambda: get_contract(code_3).foo(y)) - assert_compile_failed(lambda: get_contract(code_4), ZeroDivisionException) + assert_tx_failed(lambda p=(x, y): c.foo(*p)) + assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) + assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) else: - assert_tx_failed(lambda: c.foo(x, y)) - assert_tx_failed(lambda: get_contract(code_2).foo(x)) - assert_tx_failed(lambda: get_contract(code_3).foo(y)) - assert_compile_failed(lambda: get_contract(code_4), (InvalidType, OverflowException)) + assert_tx_failed(lambda p=(x, y): c.foo(*p)) + assert_tx_failed(lambda p=x, code=code_2: get_contract(code).foo(p)) + assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + assert_compile_failed( + lambda code=code_4: get_contract(code), (InvalidType, OverflowException) + ) COMPARISON_OPS = { diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 683684e6be..6c8d114f29 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -148,15 +148,17 @@ def foo() -> {typ}: assert get_contract(code_3).foo(y) == expected assert get_contract(code_4).foo() == expected elif div_by_zero: - assert_tx_failed(lambda: c.foo(x, y)) - assert_compile_failed(lambda: get_contract(code_2), ZeroDivisionException) - assert_tx_failed(lambda: get_contract(code_3).foo(y)) - assert_compile_failed(lambda: get_contract(code_4), ZeroDivisionException) + assert_tx_failed(lambda p=(x, y): c.foo(*p)) + assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) + assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) else: - assert_tx_failed(lambda: c.foo(x, y)) - assert_tx_failed(lambda: get_contract(code_2).foo(x)) - assert_tx_failed(lambda: get_contract(code_3).foo(y)) - assert_compile_failed(lambda: get_contract(code_4), (InvalidType, OverflowException)) + assert_tx_failed(lambda p=(x, y): c.foo(*p)) + assert_tx_failed(lambda code=code_2, p=x: get_contract(code).foo(p)) + assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + assert_compile_failed( + lambda code=code_4: get_contract(code), (InvalidType, OverflowException) + ) COMPARISON_OPS = { @@ -213,7 +215,7 @@ def test() -> {typ}: assert c.test() == val for val in bad_cases: - assert_compile_failed(lambda: get_contract(code_template.format(typ=typ, val=val))) + assert_compile_failed(lambda v=val: get_contract(code_template.format(typ=typ, val=v))) @pytest.mark.parametrize("typ", types) diff --git a/tests/functional/codegen/types/test_node_types.py b/tests/functional/codegen/types/test_node_types.py index b6561ae8eb..8a2b1681d7 100644 --- a/tests/functional/codegen/types/test_node_types.py +++ b/tests/functional/codegen/types/test_node_types.py @@ -63,5 +63,5 @@ def test_type_storage_sizes(): assert struct_.storage_size_in_words == 2 # Don't allow unknown types. - with raises(Exception): + with raises(AttributeError): _ = int.storage_size_in_words diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index 70ba5cbbf7..fa6ed20117 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -125,7 +125,6 @@ def test_address_code_compile_error( ): with pytest.raises(error_type) as excinfo: compiler.compile_code(bad_code) - assert type(excinfo.value) == error_type assert excinfo.value.message == error_message diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index ab38f6b56d..1027d9fe66 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -63,7 +63,7 @@ def __init__(_value: uint256): @pytest.mark.parametrize("bad_code", fail_list) def test_compilation_fails_with_exception(bad_code): - with pytest.raises(Exception): + with pytest.raises(VyperException): compile_code(bad_code) diff --git a/tests/functional/syntax/test_no_none.py b/tests/functional/syntax/test_no_none.py index 24c32a46a4..085ce395ab 100644 --- a/tests/functional/syntax/test_no_none.py +++ b/tests/functional/syntax/test_no_none.py @@ -72,7 +72,9 @@ def foo(): ] for contract in contracts: - assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), InvalidLiteral) + assert_compile_failed( + lambda c=contract: get_contract_with_gas_estimation(c), InvalidLiteral + ) def test_no_is_none(assert_compile_failed, get_contract_with_gas_estimation): @@ -116,7 +118,9 @@ def foo(): ] for contract in contracts: - assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), SyntaxException) + assert_compile_failed( + lambda c=contract: get_contract_with_gas_estimation(c), SyntaxException + ) def test_no_eq_none(assert_compile_failed, get_contract_with_gas_estimation): @@ -160,7 +164,9 @@ def foo(): ] for contract in contracts: - assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), InvalidLiteral) + assert_compile_failed( + lambda c=contract: get_contract_with_gas_estimation(c), InvalidLiteral + ) def test_struct_none(assert_compile_failed, get_contract_with_gas_estimation): @@ -195,4 +201,6 @@ def foo(): ] for contract in contracts: - assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), InvalidLiteral) + assert_compile_failed( + lambda c=contract: get_contract_with_gas_estimation(c), InvalidLiteral + ) diff --git a/tests/unit/abi_types/test_invalid_abi_types.py b/tests/unit/abi_types/test_invalid_abi_types.py index c8566e066f..1a8a7db884 100644 --- a/tests/unit/abi_types/test_invalid_abi_types.py +++ b/tests/unit/abi_types/test_invalid_abi_types.py @@ -23,4 +23,4 @@ def test_invalid_abi_types(assert_compile_failed, typ, params_variants): # double parametrization cannot work because the 2nd dimension is variable for params in params_variants: - assert_compile_failed(lambda: typ(*params), InvalidABIType) + assert_compile_failed(lambda p=params: typ(*p), InvalidABIType) diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 63da2e0643..032d7ebe64 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -262,7 +262,8 @@ def compile_from_input_dict( if isinstance(optimize, bool): # bool optimization level for backwards compatibility warnings.warn( - "optimize: is deprecated! please use one of 'gas', 'codesize', 'none'." + "optimize: is deprecated! please use one of 'gas', 'codesize', 'none'.", + stacklevel=2, ) optimize = OptimizationLevel.default() if optimize else OptimizationLevel.NONE elif isinstance(optimize, str): diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 199bbbc3e5..b9b2df6ae8 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -343,7 +343,8 @@ def generate_assembly(ir_nodes: IRnode, optimize: Optional[OptimizationLevel] = if _find_nested_opcode(assembly, "DEBUG"): warnings.warn( "This code contains DEBUG opcodes! The DEBUG opcode will only work in " - "a supported EVM! It will FAIL on all other nodes!" + "a supported EVM! It will FAIL on all other nodes!", + stacklevel=2, ) return assembly diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index d22d9bfff9..6ecfe78be3 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -92,7 +92,7 @@ def __hash__(self): def __eq__(self, other): return ( - type(self) == type(other) and self._get_equality_attrs() == other._get_equality_attrs() + type(self) is type(other) and self._get_equality_attrs() == other._get_equality_attrs() ) def __lt__(self, other): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index ec30ac85d6..34206546fd 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -644,6 +644,7 @@ def _parse_decorators( "'@constant' decorator has been removed (see VIP2040). " "Use `@view` instead.", DeprecationWarning, + stacklevel=2, ) raise FunctionDeclarationException(f"Unknown decorator: {decorator.id}", decorator) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 4622482951..b0d7800011 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -232,14 +232,17 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": @classmethod def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": functions = [] - for node in node.body: - if not isinstance(node, vy_ast.FunctionDef): - raise StructureException("Interfaces can only contain function definitions", node) - if len(node.decorator_list) > 0: + for func_ast in node.body: + if not isinstance(func_ast, vy_ast.FunctionDef): raise StructureException( - "Function definition in interface cannot be decorated", node.decorator_list[0] + "Interfaces can only contain function definitions", func_ast ) - functions.append((node.name, ContractFunctionT.from_InterfaceDef(node))) + if len(func_ast.decorator_list) > 0: + raise StructureException( + "Function definition in interface cannot be decorated", + func_ast.decorator_list[0], + ) + functions.append((func_ast.name, ContractFunctionT.from_InterfaceDef(func_ast))) # no structs or events in InterfaceDefs events: list = [] diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 46dffbdec4..0c8e9fddd8 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -112,7 +112,7 @@ def __init__(self, value_type: VyperType, length: int): raise InvalidType("Array length is invalid") if length >= 2**64: - warnings.warn("Use of large arrays can be unsafe!") + warnings.warn("Use of large arrays can be unsafe!", stacklevel=2) super().__init__(UINT256_T, value_type) self.length = length diff --git a/vyper/utils.py b/vyper/utils.py index a778a4e31b..2349731b97 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -70,7 +70,9 @@ def __setattr__(self, name, value): raise DecimalOverrideException("Overriding decimal precision disabled") elif value > 78: # not sure it's incorrect, might not be end of the world - warnings.warn("Changing decimals precision could have unintended side effects!") + warnings.warn( + "Changing decimals precision could have unintended side effects!", stacklevel=2 + ) # else: no-op, is ok super().__setattr__(name, value) From 88c09a218d64e2f158499ac1f0c1c71e4b4e6b86 Mon Sep 17 00:00:00 2001 From: Alberto Date: Sat, 23 Dec 2023 17:09:50 +0100 Subject: [PATCH 145/201] feat: replace `enum` with `flag` keyword (#3697) per title, replace `enum` with `flag` as it more closely models https://docs.python.org/3/library/enum.html#enum.IntFlag than regular enums. allow `enum` for now (for backwards compatibility) but convert to `flag` internally and issue a warning --- docs/types.rst | 22 ++++---- .../builtins/codegen/test_convert.py | 8 +-- .../codegen/features/test_assignment.py | 4 +- .../codegen/features/test_clampers.py | 8 +-- .../codegen/types/test_dynamic_array.py | 10 ++-- .../types/{test_enum.py => test_flag.py} | 28 +++++----- tests/functional/syntax/test_dynamic_array.py | 4 +- .../syntax/{test_enum.py => test_flag.py} | 52 +++++++++---------- tests/functional/syntax/test_public.py | 4 +- vyper/ast/folding.py | 2 +- vyper/ast/grammar.lark | 10 +++- vyper/ast/identifiers.py | 1 + vyper/ast/nodes.py | 21 +++++++- vyper/ast/nodes.pyi | 2 +- vyper/ast/pre_parser.py | 5 +- vyper/builtins/_convert.py | 10 ++-- vyper/codegen/core.py | 10 ++-- vyper/codegen/expr.py | 12 ++--- vyper/exceptions.py | 4 +- vyper/semantics/analysis/local.py | 4 +- vyper/semantics/analysis/module.py | 6 +-- vyper/semantics/analysis/utils.py | 4 +- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/user.py | 14 ++--- 24 files changed, 137 insertions(+), 110 deletions(-) rename tests/functional/codegen/types/{test_enum.py => test_flag.py} (93%) rename tests/functional/syntax/{test_enum.py => test_flag.py} (74%) diff --git a/docs/types.rst b/docs/types.rst index d669e6946d..0ad13967e9 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -376,22 +376,22 @@ On the ABI level the Fixed-size bytes array is annotated as ``string``. example_str: String[100] = "Test String" -Enums +Flags ----- -**Keyword:** ``enum`` +**Keyword:** ``flag`` -Enums are custom defined types. An enum must have at least one member, and can hold up to a maximum of 256 members. +Flags are custom defined types. A flag must have at least one member, and can hold up to a maximum of 256 members. The members are represented by ``uint256`` values in the form of 2\ :sup:`n` where ``n`` is the index of the member in the range ``0 <= n <= 255``. .. code-block:: python - # Defining an enum with two members - enum Roles: + # Defining a flag with two members + flag Roles: ADMIN USER - # Declaring an enum variable + # Declaring a flag variable role: Roles = Roles.ADMIN # Returning a member @@ -426,13 +426,13 @@ Operator Description ``~x`` Bitwise not ============= ====================== -Enum members can be combined using the above bitwise operators. While enum members have values that are power of two, enum member combinations may not. +Flag members can be combined using the above bitwise operators. While flag members have values that are power of two, flag member combinations may not. -The ``in`` and ``not in`` operators can be used in conjunction with enum member combinations to check for membership. +The ``in`` and ``not in`` operators can be used in conjunction with flag member combinations to check for membership. .. code-block:: python - enum Roles: + flag Roles: MANAGER ADMIN USER @@ -447,7 +447,7 @@ The ``in`` and ``not in`` operators can be used in conjunction with enum member def bar(a: Roles) -> bool: return a not in (Roles.MANAGER | Roles.USER) -Note that ``in`` is not the same as strict equality (``==``). ``in`` checks that *any* of the flags on two enum objects are simultaneously set, while ``==`` checks that two enum objects are bit-for-bit equal. +Note that ``in`` is not the same as strict equality (``==``). ``in`` checks that *any* of the flags on two flag objects are simultaneously set, while ``==`` checks that two flag objects are bit-for-bit equal. The following code uses bitwise operations to add and revoke permissions from a given ``Roles`` object. @@ -488,7 +488,7 @@ Fixed-size Lists Fixed-size lists hold a finite number of elements which belong to a specified type. -Lists can be declared with ``_name: _ValueType[_Integer]``, except ``Bytes[N]``, ``String[N]`` and enums. +Lists can be declared with ``_name: _ValueType[_Integer]``, except ``Bytes[N]``, ``String[N]`` and flags. .. code-block:: python diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index b5ce613235..99dae4a932 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -486,10 +486,10 @@ def test_memory_variable_convert(x: {i_typ}) -> {o_typ}: @pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"]) @pytest.mark.parametrize("val", [1, 2, 2**128, 2**256 - 1, 2**256 - 2]) -def test_enum_conversion(get_contract_with_gas_estimation, assert_compile_failed, val, typ): +def test_flag_conversion(get_contract_with_gas_estimation, assert_compile_failed, val, typ): roles = "\n ".join([f"ROLE_{i}" for i in range(256)]) contract = f""" -enum Roles: +flag Roles: {roles} @external @@ -510,11 +510,11 @@ def bar(a: uint256) -> Roles: @pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"]) @pytest.mark.parametrize("val", [1, 2, 3, 4, 2**128, 2**256 - 1, 2**256 - 2]) -def test_enum_conversion_2( +def test_flag_conversion_2( get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ ): contract = f""" -enum Status: +flag Status: STARTED PAUSED STOPPED diff --git a/tests/functional/codegen/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py index cd26659a5c..9af7058250 100644 --- a/tests/functional/codegen/features/test_assignment.py +++ b/tests/functional/codegen/features/test_assignment.py @@ -66,7 +66,7 @@ def bar(x: {typ}) -> {typ}: def test_internal_assign_struct(get_contract_with_gas_estimation): code = """ -enum Bar: +flag Bar: BAD BAK BAZ @@ -92,7 +92,7 @@ def bar(x: Foo) -> Foo: def test_internal_assign_struct_member(get_contract_with_gas_estimation): code = """ -enum Bar: +flag Bar: BAD BAK BAZ diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 08ad349c09..263f10a89c 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -187,9 +187,9 @@ def foo(s: bool) -> bool: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [0] + [2**i for i in range(5)]) -def test_enum_clamper_passing(w3, get_contract, value, evm_version): +def test_flag_clamper_passing(w3, get_contract, value, evm_version): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -207,9 +207,9 @@ def foo(s: Roles) -> Roles: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2**i for i in range(5, 256)]) -def test_enum_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_flag_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 9231d1979f..d793a56d6e 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -102,7 +102,7 @@ def foo6() -> DynArray[DynArray[String[32], 2], 2]: def test_list_output_tester_code(get_contract_with_gas_estimation): list_output_tester_code = """ -enum Foobar: +flag Foobar: FOO BAR @@ -1247,13 +1247,13 @@ def test_append_pop_complex(get_contract, assert_tx_failed, code_template, check """ code = struct_def + "\n" + code elif subtype == "DynArray[Foobar, 3]": - enum_def = """ -enum Foobar: + flag_def = """ +flag Foobar: FOO BAR BAZ """ - code = enum_def + "\n" + code + code = flag_def + "\n" + code test_data = [2 ** (i - 1) for i in test_data] c = get_contract(code) @@ -1292,7 +1292,7 @@ def foo() -> (uint256, DynArray[uint256, 3], DynArray[uint256, 2]): def test_list_of_structs_arg(get_contract): code = """ -enum Foobar: +flag Foobar: FOO BAR diff --git a/tests/functional/codegen/types/test_enum.py b/tests/functional/codegen/types/test_flag.py similarity index 93% rename from tests/functional/codegen/types/test_enum.py rename to tests/functional/codegen/types/test_flag.py index c66efff566..03c22134ed 100644 --- a/tests/functional/codegen/types/test_enum.py +++ b/tests/functional/codegen/types/test_flag.py @@ -1,6 +1,6 @@ def test_values_should_be_increasing_ints(get_contract): code = """ -enum Action: +flag Action: BUY SELL CANCEL @@ -26,9 +26,9 @@ def cancel() -> Action: assert c.cancel() == 4 -def test_enum_storage(get_contract): +def test_flag_storage(get_contract): code = """ -enum Actions: +flag Actions: BUY SELL CANCEL @@ -49,7 +49,7 @@ def set_and_get(a: Actions) -> Actions: def test_eq_neq(get_contract): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -76,7 +76,7 @@ def is_not_boss(a: Roles) -> bool: def test_bitwise(get_contract, assert_tx_failed): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -147,7 +147,7 @@ def binv_arg(a: Roles) -> Roles: def test_augassign_storage(get_contract, w3, assert_tx_failed): code = """ -enum Roles: +flag Roles: ADMIN MINTER @@ -214,9 +214,9 @@ def checkMinter(minter: address): assert_tx_failed(lambda: c.checkMinter(admin_address)) -def test_in_enum(get_contract_with_gas_estimation): +def test_in_flag(get_contract_with_gas_estimation): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -259,9 +259,9 @@ def baz(a: Roles) -> bool: assert c.baz(0b01000) is False # Roles.MANAGER should fail -def test_struct_with_enum(get_contract_with_gas_estimation): +def test_struct_with_flag(get_contract_with_gas_estimation): code = """ -enum Foobar: +flag Foobar: FOO BAR @@ -270,17 +270,17 @@ def test_struct_with_enum(get_contract_with_gas_estimation): b: Foobar @external -def get_enum_from_struct() -> Foobar: +def get_flag_from_struct() -> Foobar: f: Foo = Foo({a: 1, b: Foobar.BAR}) return f.b """ c = get_contract_with_gas_estimation(code) - assert c.get_enum_from_struct() == 2 + assert c.get_flag_from_struct() == 2 -def test_mapping_with_enum(get_contract_with_gas_estimation): +def test_mapping_with_flag(get_contract_with_gas_estimation): code = """ -enum Foobar: +flag Foobar: FOO BAR diff --git a/tests/functional/syntax/test_dynamic_array.py b/tests/functional/syntax/test_dynamic_array.py index 0c23bf67da..99a01a17c8 100644 --- a/tests/functional/syntax/test_dynamic_array.py +++ b/tests/functional/syntax/test_dynamic_array.py @@ -34,12 +34,12 @@ def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): valid_list = [ """ -enum Foo: +flag Foo: FE FI bar: DynArray[Foo, 10] - """, # dynamic arrays of enums are allowed, but not static arrays + """, # dynamic arrays of flags are allowed, but not static arrays """ bar: DynArray[Bytes[30], 10] """, # dynamic arrays of bytestrings are allowed, but not static arrays diff --git a/tests/functional/syntax/test_enum.py b/tests/functional/syntax/test_flag.py similarity index 74% rename from tests/functional/syntax/test_enum.py rename to tests/functional/syntax/test_flag.py index 9bb74fb675..22309502b7 100644 --- a/tests/functional/syntax/test_enum.py +++ b/tests/functional/syntax/test_flag.py @@ -2,7 +2,7 @@ from vyper import compiler from vyper.exceptions import ( - EnumDeclarationException, + FlagDeclarationException, InvalidOperation, NamespaceCollision, StructureException, @@ -16,7 +16,7 @@ event Action: pass -enum Action: +flag Action: BUY SELL """, @@ -24,23 +24,23 @@ ), ( """ -enum Action: +flag Action: pass """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Action: +flag Action: BUY BUY """, - EnumDeclarationException, + FlagDeclarationException, ), - ("enum Foo:\n" + "\n".join([f" member{i}" for i in range(257)]), EnumDeclarationException), + ("flag Foo:\n" + "\n".join([f" member{i}" for i in range(257)]), FlagDeclarationException), ( """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -53,20 +53,20 @@ def foo(x: Roles) -> bool: ), ( """ -enum Roles: +flag Roles: USER STAFF ADMIN @external def foo(x: Roles) -> Roles: - return x.USER # can't dereference on enum instance + return x.USER # can't dereference on flag instance """, StructureException, ), ( """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -79,28 +79,28 @@ def foo(x: Roles) -> bool: ), ( """ -enum Functions: +flag Functions: def foo():nonpayable """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Numbers: +flag Numbers: a:constant(uint256) = a """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Numbers: +flag Numbers: 12 """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Roles: +flag Roles: ADMIN USER @@ -112,9 +112,9 @@ def foo() -> Roles: ), ( """ -enum A: +flag A: a -enum B: +flag B: a b @@ -135,12 +135,12 @@ def test_fail_cases(bad_code): valid_list = [ """ -enum Action: +flag Action: BUY SELL """, """ -enum Action: +flag Action: BUY SELL @external @@ -148,7 +148,7 @@ def run() -> Action: return Action.BUY """, """ -enum Action: +flag Action: BUY SELL @@ -163,16 +163,16 @@ def run() -> Order: amount: 10**18 }) """, - "enum Foo:\n" + "\n".join([f" member{i}" for i in range(256)]), + "flag Foo:\n" + "\n".join([f" member{i}" for i in range(256)]), """ a: constant(uint256) = 1 -enum A: +flag A: a """, ] @pytest.mark.parametrize("good_code", valid_list) -def test_enum_success(good_code): +def test_flag_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 68575ebd41..71bff753f4 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -30,9 +30,9 @@ def foo() -> int128: x: public(HashMap[uint256, Foo]) """, - # expansion of public user-defined enum + # expansion of public user-defined flag """ -enum Foo: +flag Foo: BAR x: public(HashMap[uint256, Foo]) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 38d58f6fd0..087708a356 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -246,7 +246,7 @@ def replace_constant( continue # do not replace enum members - if node.get_ancestor(vy_ast.EnumDef): + if node.get_ancestor(vy_ast.FlagDef): continue try: diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 15367ce94a..7889473b19 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -10,7 +10,8 @@ module: ( DOCSTRING | interface_def | constant_def | variable_def - | enum_def + | enum_def // TODO deprecate at some point in favor of flag + | flag_def | event_def | function_def | immutable_def @@ -76,12 +77,19 @@ indexed_event_arg: NAME ":" "indexed" "(" type ")" event_body: _NEWLINE _INDENT (((event_member | indexed_event_arg ) _NEWLINE)+ | _PASS _NEWLINE) _DEDENT event_def: _EVENT_DECL NAME ":" ( event_body | _PASS ) +// TODO deprecate in favor of flag // Enums _ENUM_DECL: "enum" enum_member: NAME enum_body: _NEWLINE _INDENT (enum_member _NEWLINE)+ _DEDENT enum_def: _ENUM_DECL NAME ":" enum_body +// Flags +_FLAG_DECL: "flag" +flag_member: NAME +flag_body: _NEWLINE _INDENT (flag_member _NEWLINE)+ _DEDENT +flag_def: _FLAG_DECL NAME ":" flag_body + // Types array_def: (NAME | array_def | dyn_array_def) "[" _expr "]" dyn_array_def: "DynArray" "[" (NAME | array_def | dyn_array_def) "," _expr "]" diff --git a/vyper/ast/identifiers.py b/vyper/ast/identifiers.py index 985b04e5cd..7d42727066 100644 --- a/vyper/ast/identifiers.py +++ b/vyper/ast/identifiers.py @@ -69,6 +69,7 @@ def validate_identifier(attr, ast_node=None): "struct", "event", "enum", + "flag" # EVM operations "unreachable", # special functions (no name mangling) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 3bccc5f141..dba9f2a22d 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -4,6 +4,7 @@ import decimal import operator import sys +import warnings from typing import Any, Optional, Union from vyper.ast.metadata import NodeMetadata @@ -18,6 +19,7 @@ SyntaxException, TypeMismatch, UnfoldableNode, + VyperException, ZeroDivisionException, ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code @@ -78,6 +80,11 @@ def get_node( else: ast_struct["ast_type"] = "VariableDecl" + enum_warn = False + if ast_struct["ast_type"] == "EnumDef": + enum_warn = True + ast_struct["ast_type"] = "FlagDef" + vy_class = getattr(sys.modules[__name__], ast_struct["ast_type"], None) if not vy_class: if ast_struct["ast_type"] == "Delete": @@ -92,7 +99,17 @@ def get_node( ast_struct, ) - return vy_class(parent=parent, **ast_struct) + node = vy_class(parent=parent, **ast_struct) + + # TODO: Putting this after node creation to pretty print, remove after enum deprecation + if enum_warn: + # TODO: hack to pretty print, logic should be factored out of exception + pretty_printed_node = str(VyperException("", node)) + warnings.warn( + f"enum will be deprecated in a future release, use flag instead. {pretty_printed_node}", + stacklevel=2, + ) + return node def compare_nodes(left_node: "VyperNode", right_node: "VyperNode") -> bool: @@ -725,7 +742,7 @@ class Log(Stmt): __slots__ = ("value",) -class EnumDef(TopLevel): +class FlagDef(TopLevel): __slots__ = ("name", "body") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 05784aed0f..47856b6021 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -81,7 +81,7 @@ class Return(VyperNode): ... class Log(VyperNode): value: VyperNode = ... -class EnumDef(VyperNode): +class FlagDef(VyperNode): body: list = ... name: str = ... diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 9d96efea5e..b949a242bb 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -44,7 +44,8 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: # compound statements that are replaced with `class` -VYPER_CLASS_TYPES = {"enum", "event", "interface", "struct"} +# TODO remove enum in favor of flag +VYPER_CLASS_TYPES = {"flag", "enum", "event", "interface", "struct"} # simple statements or expressions that are replaced with `yield` VYPER_EXPRESSION_TYPES = {"log"} @@ -55,7 +56,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Re-formats a vyper source string into a python source string and performs some validation. More specifically, - * Translates "interface", "struct", "enum, and "event" keywords into python "class" keyword + * Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index e09f5f3174..998cbbc9f6 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -14,7 +14,7 @@ int_clamp, is_bytes_m_type, is_decimal_type, - is_enum_type, + is_flag_type, is_integer_type, sar, shl, @@ -35,7 +35,7 @@ BytesM_T, BytesT, DecimalT, - EnumT, + FlagT, IntegerT, StringT, ) @@ -277,7 +277,7 @@ def to_bool(expr, arg, out_typ): return IRnode.from_list(["iszero", ["iszero", arg]], typ=out_typ) -@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BoolT, EnumT, BytesT) +@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BoolT, FlagT, BytesT) def to_int(expr, arg, out_typ): return _to_int(expr, arg, out_typ) @@ -305,7 +305,7 @@ def _to_int(expr, arg, out_typ): elif is_decimal_type(arg.typ): arg = _fixed_to_int(arg, out_typ) - elif is_enum_type(arg.typ): + elif is_flag_type(arg.typ): if out_typ != UINT256_T: _FAIL(arg.typ, out_typ, expr) # pretend enum is uint256 @@ -468,7 +468,7 @@ def convert(expr, context): ret = to_bool(arg_ast, arg, out_typ) elif out_typ == AddressT(): ret = to_address(arg_ast, arg, out_typ) - elif is_enum_type(out_typ): + elif is_flag_type(out_typ): ret = to_enum(arg_ast, arg, out_typ) elif is_integer_type(out_typ): ret = to_int(arg_ast, arg, out_typ) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 503e0e2f3b..c16de3c55a 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -23,7 +23,7 @@ ) from vyper.semantics.types.shortcuts import BYTES32_T, INT256_T, UINT256_T from vyper.semantics.types.subscriptable import SArrayT -from vyper.semantics.types.user import EnumT +from vyper.semantics.types.user import FlagT from vyper.utils import GAS_COPY_WORD, GAS_IDENTITY, GAS_IDENTITYWORD, ceil32 DYNAMIC_ARRAY_OVERHEAD = 1 @@ -45,8 +45,8 @@ def is_decimal_type(typ): return isinstance(typ, DecimalT) -def is_enum_type(typ): - return isinstance(typ, EnumT) +def is_flag_type(typ): + return isinstance(typ, FlagT) def is_tuple_like(typ): @@ -829,7 +829,7 @@ def needs_clamp(t, encoding): raise CompilerPanic("unreachable") # pragma: notest if isinstance(t, (_BytestringT, DArrayT)): return True - if isinstance(t, EnumT): + if isinstance(t, FlagT): return len(t._enum_members) < 256 if isinstance(t, SArrayT): return needs_clamp(t.value_type, encoding) @@ -1138,7 +1138,7 @@ def clamp_basetype(ir_node): # copy of the input ir_node = unwrap_location(ir_node) - if isinstance(t, EnumT): + if isinstance(t, FlagT): bits = len(t._enum_members) # assert x >> bits == 0 ret = int_clamp(ir_node, bits, signed=False) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index d5ca5aceee..693d5c2aad 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -12,7 +12,7 @@ getpos, is_array_like, is_bytes_m_type, - is_enum_type, + is_flag_type, is_numeric_type, is_tuple_like, pop_dyn_array, @@ -42,7 +42,7 @@ BytesT, DArrayT, DecimalT, - EnumT, + FlagT, HashMapT, InterfaceT, SArrayT, @@ -209,7 +209,7 @@ def parse_Attribute(self): # MyEnum.foo if ( - isinstance(typ, EnumT) + isinstance(typ, FlagT) and isinstance(self.expr.value, vy_ast.Name) and typ.name == self.expr.value.id ): @@ -384,7 +384,7 @@ def parse_BinOp(self): # This should be unreachable due to the type check pass if left.typ != right.typ: raise TypeCheckFailure(f"unreachable, {left.typ} != {right.typ}", self.expr) - assert is_numeric_type(left.typ) or is_enum_type(left.typ) + assert is_numeric_type(left.typ) or is_flag_type(left.typ) out_typ = left.typ @@ -516,7 +516,7 @@ def parse_Compare(self): if is_array_like(right.typ): return self.build_in_comparator() else: - assert isinstance(right.typ, EnumT), right.typ + assert isinstance(right.typ, FlagT), right.typ intersection = ["and", left, right] if isinstance(self.expr.op, vy_ast.In): return IRnode.from_list(["iszero", ["iszero", intersection]], typ=BoolT()) @@ -633,7 +633,7 @@ def parse_UnaryOp(self): return IRnode.from_list(["iszero", operand], typ=BoolT()) if isinstance(self.expr.op, vy_ast.Invert): - if isinstance(operand.typ, EnumT): + if isinstance(operand.typ, FlagT): n_members = len(operand.typ._enum_members) # use (xor 0b11..1 operand) to flip all the bits in # `operand`. `mask` could be a very large constant and diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 4846b1c3b1..8f72d9afc9 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -176,8 +176,8 @@ class FunctionDeclarationException(VyperException): """Invalid function declaration.""" -class EnumDeclarationException(VyperException): - """Invalid enum declaration.""" +class FlagDeclarationException(VyperException): + """Invalid flag declaration.""" class EventDeclarationException(VyperException): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 974c14f261..2a84f69ad4 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -38,8 +38,8 @@ AddressT, BoolT, DArrayT, - EnumT, EventT, + FlagT, HashMapT, IntegerT, SArrayT, @@ -700,7 +700,7 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: validate_expected_type(node.right, rtyp) else: rtyp = get_exact_type_from_node(node.right) - if isinstance(rtyp, EnumT): + if isinstance(rtyp, FlagT): # enum membership - `some_enum in other_enum` ltyp = rtyp else: diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 7aa661aec3..fb536b7ab7 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -31,7 +31,7 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace -from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT +from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation @@ -326,8 +326,8 @@ def _validate_self_namespace(): return _finalize() - def visit_EnumDef(self, node): - obj = EnumT.from_EnumDef(node) + def visit_FlagDef(self, node): + obj = FlagT.from_FlagDef(node) self.namespace[node.name] = obj def visit_EventDef(self, node): diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 1785afd92d..20ebb0f093 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -248,13 +248,13 @@ def types_from_Compare(self, node): # comparisons, e.g. `x < y` # TODO fixme circular import - from vyper.semantics.types.user import EnumT + from vyper.semantics.types.user import FlagT if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # x in y left = self.get_possible_types_from_node(node.left) right = self.get_possible_types_from_node(node.right) - if any(isinstance(t, EnumT) for t in left): + if any(isinstance(t, FlagT) for t in left): types_list = get_common_types(node.left, node.right) _validate_op(node, types_list, "validate_comparator") return [BoolT()] diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 1fef6a706e..880857ccb8 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -5,7 +5,7 @@ from .module import InterfaceT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT -from .user import EnumT, EventT, StructT +from .user import EventT, FlagT, StructT def _get_primitive_types(): diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ef7e1d0eb4..a4e782349d 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -5,8 +5,8 @@ from vyper.abi_types import ABI_GIntM, ABI_Tuple, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import ( - EnumDeclarationException, EventDeclarationException, + FlagDeclarationException, InvalidAttribute, NamespaceCollision, StructureException, @@ -43,7 +43,7 @@ def __hash__(self): # note: enum behaves a lot like uint256, or uints in general. -class EnumT(_UserType): +class FlagT(_UserType): # this is a carveout because currently we allow dynamic arrays of # enums, but not static arrays of enums _as_darray = True @@ -52,7 +52,7 @@ class EnumT(_UserType): def __init__(self, name: str, members: dict) -> None: if len(members.keys()) > 256: - raise EnumDeclarationException("Enums are limited to 256 members!") + raise FlagDeclarationException("Enums are limited to 256 members!") super().__init__(members=None) @@ -103,7 +103,7 @@ def validate_comparator(self, node): # return f"{self.name}({','.join(v.canonical_abi_type for v in self.arguments)})" @classmethod - def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": + def from_FlagDef(cls, base_node: vy_ast.FlagDef) -> "FlagT": """ Generate an `Enum` object from a Vyper ast node. @@ -118,15 +118,15 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": members: dict = {} if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): - raise EnumDeclarationException("Enum must have members", base_node) + raise FlagDeclarationException("Enum must have members", base_node) for i, node in enumerate(base_node.body): if not isinstance(node, vy_ast.Expr) or not isinstance(node.value, vy_ast.Name): - raise EnumDeclarationException("Invalid syntax for enum member", node) + raise FlagDeclarationException("Invalid syntax for enum member", node) member_name = node.value.id if member_name in members: - raise EnumDeclarationException( + raise FlagDeclarationException( f"Enum member '{member_name}' has already been declared", node.value ) From 2e4187377858c28df7efae757309f0593286cb70 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Sat, 23 Dec 2023 18:42:19 +0100 Subject: [PATCH 146/201] refactor: make `assert_tx_failed` a contextmanager (#3706) rename `assert_tx_failed` to `tx_failed` and change it into a context manager which has a similar API to `pytest.raises()`. --------- Co-authored-by: Charles Cooper --- docs/testing-contracts-ethtester.rst | 4 +- tests/conftest.py | 28 +-- .../builtins/codegen/test_abi_decode.py | 25 +- .../builtins/codegen/test_addmod.py | 5 +- .../builtins/codegen/test_as_wei_value.py | 16 +- .../builtins/codegen/test_convert.py | 13 +- .../builtins/codegen/test_create_functions.py | 59 ++--- .../builtins/codegen/test_extract32.py | 19 +- .../builtins/codegen/test_minmax.py | 2 +- .../builtins/codegen/test_mulmod.py | 5 +- .../builtins/codegen/test_raw_call.py | 17 +- .../functional/builtins/codegen/test_send.py | 8 +- .../functional/builtins/codegen/test_slice.py | 13 +- .../functional/builtins/codegen/test_unary.py | 5 +- tests/functional/builtins/folding/test_abs.py | 7 +- .../test_default_function.py | 19 +- .../test_default_parameters.py | 5 +- .../calling_convention/test_erc20_abi.py | 18 +- .../test_external_contract_calls.py | 150 +++++++----- ...test_modifiable_external_contract_calls.py | 21 +- .../calling_convention/test_return_tuple.py | 7 +- .../environment_variables/test_blockhash.py | 10 +- .../features/decorators/test_nonreentrant.py | 18 +- .../features/decorators/test_payable.py | 23 +- .../features/decorators/test_private.py | 2 +- .../features/iteration/test_for_range.py | 10 +- .../features/iteration/test_range_in.py | 10 +- .../codegen/features/test_assert.py | 38 +-- .../features/test_assert_unreachable.py | 24 +- .../codegen/features/test_clampers.py | 92 ++++---- .../functional/codegen/features/test_init.py | 5 +- .../codegen/features/test_logging.py | 81 ++++--- .../codegen/features/test_reverting.py | 27 +-- .../codegen/integration/test_escrow.py | 10 +- tests/functional/codegen/test_interfaces.py | 32 ++- .../functional/codegen/test_selector_table.py | 16 +- .../codegen/test_stateless_modules.py | 5 +- .../codegen/types/numbers/test_constants.py | 15 +- .../codegen/types/numbers/test_decimals.py | 33 ++- .../codegen/types/numbers/test_exponents.py | 31 ++- .../codegen/types/numbers/test_modulo.py | 5 +- .../codegen/types/numbers/test_signed_ints.py | 68 ++++-- .../types/numbers/test_unsigned_ints.py | 28 ++- tests/functional/codegen/types/test_bytes.py | 5 +- .../codegen/types/test_dynamic_array.py | 51 ++-- tests/functional/codegen/types/test_flag.py | 34 ++- tests/functional/codegen/types/test_lists.py | 33 ++- tests/functional/codegen/types/test_string.py | 14 +- .../examples/auctions/test_blind_auction.py | 38 ++- .../auctions/test_simple_open_auction.py | 19 +- .../examples/company/test_company.py | 49 ++-- .../crowdfund/test_crowdfund_example.py | 8 +- .../test_on_chain_market_maker.py | 25 +- .../name_registry/test_name_registry.py | 5 +- .../test_safe_remote_purchase.py | 34 ++- .../examples/storage/test_advanced_storage.py | 22 +- .../examples/tokens/test_erc1155.py | 223 +++++++++--------- .../functional/examples/tokens/test_erc20.py | 108 ++++++--- .../functional/examples/tokens/test_erc721.py | 103 ++++---- .../functional/examples/voting/test_ballot.py | 29 ++- .../functional/examples/wallet/test_wallet.py | 22 +- .../ast/nodes/test_evaluate_binop_decimal.py | 10 +- .../unit/ast/nodes/test_evaluate_binop_int.py | 15 +- 63 files changed, 1051 insertions(+), 825 deletions(-) diff --git a/docs/testing-contracts-ethtester.rst b/docs/testing-contracts-ethtester.rst index 992cdc312a..1b7e9e3263 100644 --- a/docs/testing-contracts-ethtester.rst +++ b/docs/testing-contracts-ethtester.rst @@ -55,9 +55,9 @@ To test events and failed transactions we expand our simple storage contract to Next, we take a look at the two fixtures that will allow us to read the event logs and to check for failed transactions. -.. literalinclude:: ../tests/base_conftest.py +.. literalinclude:: ../tests/conftest.py :language: python - :pyobject: assert_tx_failed + :pyobject: tx_failed The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction. diff --git a/tests/conftest.py b/tests/conftest.py index 925a025a4a..51b4b4459a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import json import logging +from contextlib import contextmanager from functools import wraps import hypothesis @@ -411,23 +412,6 @@ def assert_compile_failed(function_to_test, exception=Exception): return assert_compile_failed -# TODO this should not be a fixture -@pytest.fixture -def search_for_sublist(): - def search_for_sublist(ir, sublist): - _list = ir.to_list() if hasattr(ir, "to_list") else ir - if _list == sublist: - return True - if isinstance(_list, list): - for i in _list: - ret = search_for_sublist(i, sublist) - if ret is True: - return ret - return False - - return search_for_sublist - - @pytest.fixture def create2_address_of(keccak): def _f(_addr, _salt, _initcode): @@ -484,16 +468,16 @@ def get_logs(tx_hash, c, event_name): return get_logs -# TODO replace me with function like `with anchor_state()` @pytest.fixture(scope="module") -def assert_tx_failed(tester): - def assert_tx_failed(function_to_test, exception=TransactionFailed, exc_text=None): +def tx_failed(tester): + @contextmanager + def fn(exception=TransactionFailed, exc_text=None): snapshot_id = tester.take_snapshot() with pytest.raises(exception) as excinfo: - function_to_test() + yield excinfo tester.revert_to_snapshot(snapshot_id) if exc_text: # TODO test equality assert exc_text in str(excinfo.value), (exc_text, excinfo.value) - return assert_tx_failed + return fn diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 242841e1cf..69bfef63ea 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -331,7 +331,7 @@ def abi_decode(x: Bytes[32]) -> uint256: b"\x01" * 96, # Length of byte array is beyond size bound of output type ], ) -def test_clamper(get_contract, assert_tx_failed, input_): +def test_clamper(get_contract, tx_failed, input_): contract = """ @external def abi_decode(x: Bytes[96]) -> (uint256, uint256): @@ -341,10 +341,11 @@ def abi_decode(x: Bytes[96]) -> (uint256, uint256): return a, b """ c = get_contract(contract) - assert_tx_failed(lambda: c.abi_decode(input_)) + with tx_failed(): + c.abi_decode(input_) -def test_clamper_nested_uint8(get_contract, assert_tx_failed): +def test_clamper_nested_uint8(get_contract, tx_failed): # check that _abi_decode clamps on word-types even when it is in a nested expression # decode -> validate uint8 -> revert if input >= 256 -> cast back to uint256 contract = """ @@ -355,10 +356,11 @@ def abi_decode(x: uint256) -> uint256: """ c = get_contract(contract) assert c.abi_decode(255) == 255 - assert_tx_failed(lambda: c.abi_decode(256)) + with tx_failed(): + c.abi_decode(256) -def test_clamper_nested_bytes(get_contract, assert_tx_failed): +def test_clamper_nested_bytes(get_contract, tx_failed): # check that _abi_decode clamps dynamic even when it is in a nested expression # decode -> validate Bytes[20] -> revert if len(input) > 20 -> convert back to -> add 1 contract = """ @@ -369,7 +371,8 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]: """ c = get_contract(contract) assert c.abi_decode(abi.encode("(bytes)", (b"bc",))) == b"abc" - assert_tx_failed(lambda: c.abi_decode(abi.encode("(bytes)", (b"a" * 22,)))) + with tx_failed(): + c.abi_decode(abi.encode("(bytes)", (b"a" * 22,))) @pytest.mark.parametrize( @@ -381,7 +384,7 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]: ("Bytes[5]", b"\x01" * 192), ], ) -def test_clamper_dynamic(get_contract, assert_tx_failed, output_typ, input_): +def test_clamper_dynamic(get_contract, tx_failed, output_typ, input_): contract = f""" @external def abi_decode(x: Bytes[192]) -> {output_typ}: @@ -390,7 +393,8 @@ def abi_decode(x: Bytes[192]) -> {output_typ}: return a """ c = get_contract(contract) - assert_tx_failed(lambda: c.abi_decode(input_)) + with tx_failed(): + c.abi_decode(input_) @pytest.mark.parametrize( @@ -422,7 +426,7 @@ def abi_decode(x: Bytes[160]) -> uint256: ("Bytes[5]", "address", b"\x01" * 128), ], ) -def test_clamper_dynamic_tuple(get_contract, assert_tx_failed, output_typ1, output_typ2, input_): +def test_clamper_dynamic_tuple(get_contract, tx_failed, output_typ1, output_typ2, input_): contract = f""" @external def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}): @@ -432,7 +436,8 @@ def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}): return a, b """ c = get_contract(contract) - assert_tx_failed(lambda: c.abi_decode(input_)) + with tx_failed(): + c.abi_decode(input_) FAIL_LIST = [ diff --git a/tests/functional/builtins/codegen/test_addmod.py b/tests/functional/builtins/codegen/test_addmod.py index b3135660bb..00745c0cdb 100644 --- a/tests/functional/builtins/codegen/test_addmod.py +++ b/tests/functional/builtins/codegen/test_addmod.py @@ -1,4 +1,4 @@ -def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation): +def test_uint256_addmod(tx_failed, get_contract_with_gas_estimation): uint256_code = """ @external def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: @@ -11,7 +11,8 @@ def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: assert c._uint256_addmod(32, 2, 32) == 2 assert c._uint256_addmod((2**256) - 1, 0, 2) == 1 assert c._uint256_addmod(2**255, 2**255, 6) == 4 - assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) + with tx_failed(): + c._uint256_addmod(1, 2, 0) def test_uint256_addmod_ext_call( diff --git a/tests/functional/builtins/codegen/test_as_wei_value.py b/tests/functional/builtins/codegen/test_as_wei_value.py index cc27507e7c..522684fa05 100644 --- a/tests/functional/builtins/codegen/test_as_wei_value.py +++ b/tests/functional/builtins/codegen/test_as_wei_value.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) -def test_wei_uint256(get_contract, assert_tx_failed, denom, multiplier): +def test_wei_uint256(get_contract, tx_failed, denom, multiplier): code = f""" @external def foo(a: uint256) -> uint256: @@ -36,11 +36,12 @@ def foo(a: uint256) -> uint256: assert c.foo(value) == value * (10**multiplier) value = (2**256 - 1) // (10 ** (multiplier - 1)) - assert_tx_failed(lambda: c.foo(value)) + with tx_failed(): + c.foo(value) @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) -def test_wei_int128(get_contract, assert_tx_failed, denom, multiplier): +def test_wei_int128(get_contract, tx_failed, denom, multiplier): code = f""" @external def foo(a: int128) -> uint256: @@ -54,7 +55,7 @@ def foo(a: int128) -> uint256: @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) -def test_wei_decimal(get_contract, assert_tx_failed, denom, multiplier): +def test_wei_decimal(get_contract, tx_failed, denom, multiplier): code = f""" @external def foo(a: decimal) -> uint256: @@ -69,7 +70,7 @@ def foo(a: decimal) -> uint256: @pytest.mark.parametrize("value", (-1, -(2**127))) @pytest.mark.parametrize("data_type", ["decimal", "int128"]) -def test_negative_value_reverts(get_contract, assert_tx_failed, value, data_type): +def test_negative_value_reverts(get_contract, tx_failed, value, data_type): code = f""" @external def foo(a: {data_type}) -> uint256: @@ -77,12 +78,13 @@ def foo(a: {data_type}) -> uint256: """ c = get_contract(code) - assert_tx_failed(lambda: c.foo(value)) + with tx_failed(): + c.foo(value) @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) @pytest.mark.parametrize("data_type", ["decimal", "int128", "uint256"]) -def test_zero_value(get_contract, assert_tx_failed, denom, multiplier, data_type): +def test_zero_value(get_contract, tx_failed, denom, multiplier, data_type): code = f""" @external def foo(a: {data_type}) -> uint256: diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index 99dae4a932..559e1448ef 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -511,7 +511,7 @@ def bar(a: uint256) -> Roles: @pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"]) @pytest.mark.parametrize("val", [1, 2, 3, 4, 2**128, 2**256 - 1, 2**256 - 2]) def test_flag_conversion_2( - get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ + get_contract_with_gas_estimation, assert_compile_failed, tx_failed, val, typ ): contract = f""" flag Status: @@ -529,7 +529,8 @@ def foo(a: {typ}) -> Status: if lo <= val <= hi: assert c.foo(val) == val else: - assert_tx_failed(lambda: c.foo(val)) + with tx_failed(): + c.foo(val) else: assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch) @@ -608,7 +609,7 @@ def foo() -> {t_bytes}: @pytest.mark.parametrize("i_typ,o_typ,val", generate_reverting_cases()) @pytest.mark.fuzzing def test_conversion_failures( - get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, i_typ, o_typ, val + get_contract_with_gas_estimation, assert_compile_failed, tx_failed, i_typ, o_typ, val ): """ Test multiple contracts and check for a specific exception. @@ -650,7 +651,8 @@ def foo(): """ c2 = get_contract_with_gas_estimation(contract_2) - assert_tx_failed(lambda: c2.foo()) + with tx_failed(): + c2.foo() contract_3 = f""" @external @@ -659,4 +661,5 @@ def foo(bar: {i_typ}) -> {o_typ}: """ c3 = get_contract_with_gas_estimation(contract_3) - assert_tx_failed(lambda: c3.foo(val)) + with tx_failed(): + c3.foo(val) diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index fa7729d98e..afa729ac8a 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -77,7 +77,7 @@ def test2() -> Bytes[100]: assert c.test2() == b"hello world!" -def test_minimal_proxy_exception(w3, get_contract, assert_tx_failed): +def test_minimal_proxy_exception(w3, get_contract, tx_failed): code = """ interface SubContract: @@ -111,7 +111,8 @@ def test2(a: uint256) -> Bytes[100]: c.test(transact={}) assert c.test2(1) == b"hello world!" - assert_tx_failed(lambda: c.test2(0)) + with tx_failed(): + c.test2(0) GAS_SENT = 30000 tx_hash = c.test2(0, transact={"gas": GAS_SENT}) @@ -122,9 +123,7 @@ def test2(a: uint256) -> Bytes[100]: assert receipt["gasUsed"] < GAS_SENT -def test_create_minimal_proxy_to_create2( - get_contract, create2_address_of, keccak, assert_tx_failed -): +def test_create_minimal_proxy_to_create2(get_contract, create2_address_of, keccak, tx_failed): code = """ main: address @@ -143,20 +142,15 @@ def test(_salt: bytes32) -> address: c.test(salt, transact={}) # revert on collision - assert_tx_failed(lambda: c.test(salt, transact={})) + with tx_failed(): + c.test(salt, transact={}) # test blueprints with various prefixes - 0xfe would block calls to the blueprint # contract, and 0xfe7100 is ERC5202 magic @pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"]) def test_create_from_blueprint( - get_contract, - deploy_blueprint_for, - w3, - keccak, - create2_address_of, - assert_tx_failed, - blueprint_prefix, + get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, tx_failed, blueprint_prefix ): code = """ @external @@ -193,7 +187,8 @@ def test2(target: address, salt: bytes32): # extcodesize check zero_address = "0x" + "00" * 20 - assert_tx_failed(lambda: d.test(zero_address)) + with tx_failed(): + d.test(zero_address) # now same thing but with create2 salt = keccak(b"vyper") @@ -209,11 +204,12 @@ def test2(target: address, salt: bytes32): assert HexBytes(test.address) == create2_address_of(d.address, salt, initcode) # can't collide addresses - assert_tx_failed(lambda: d.test2(f.address, salt)) + with tx_failed(): + d.test2(f.address, salt) def test_create_from_blueprint_bad_code_offset( - get_contract, get_contract_from_ir, deploy_blueprint_for, w3, assert_tx_failed + get_contract, get_contract_from_ir, deploy_blueprint_for, w3, tx_failed ): deployer_code = """ BLUEPRINT: immutable(address) @@ -254,15 +250,17 @@ def test(code_ofst: uint256) -> address: d.test(initcode_len - 1) # code_offset=len(blueprint) NOT fine! would EXTCODECOPY empty initcode - assert_tx_failed(lambda: d.test(initcode_len)) + with tx_failed(): + d.test(initcode_len) # code_offset=EIP_170_LIMIT definitely not fine! - assert_tx_failed(lambda: d.test(EIP_170_LIMIT)) + with tx_failed(): + d.test(EIP_170_LIMIT) # test create_from_blueprint with args def test_create_from_blueprint_args( - get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, assert_tx_failed + get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, tx_failed ): code = """ struct Bar: @@ -332,7 +330,8 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): assert test.bar() == BAR # extcodesize check - assert_tx_failed(lambda: d.test("0x" + "00" * 20, FOO, BAR)) + with tx_failed(): + d.test("0x" + "00" * 20, FOO, BAR) # now same thing but with create2 salt = keccak(b"vyper") @@ -359,9 +358,11 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): assert test.bar() == BAR # can't collide addresses - assert_tx_failed(lambda: d.test2(f.address, FOO, BAR, salt)) + with tx_failed(): + d.test2(f.address, FOO, BAR, salt) # ditto - with raw_args - assert_tx_failed(lambda: d.test4(f.address, encoded_args, salt)) + with tx_failed(): + d.test4(f.address, encoded_args, salt) # but creating a contract with different args is ok FOO = "bar" @@ -375,10 +376,11 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): BAR = ("",) sig = keccak("should_fail(address,string,(string))".encode()).hex()[:10] encoded = abi.encode("(address,string,(string))", (f.address, FOO, BAR)).hex() - assert_tx_failed(lambda: w3.eth.send_transaction({"to": d.address, "data": f"{sig}{encoded}"})) + with tx_failed(): + w3.eth.send_transaction({"to": d.address, "data": f"{sig}{encoded}"}) -def test_create_copy_of(get_contract, w3, keccak, create2_address_of, assert_tx_failed): +def test_create_copy_of(get_contract, w3, keccak, create2_address_of, tx_failed): code = """ created_address: public(address) @internal @@ -412,7 +414,8 @@ def test2(target: address, salt: bytes32) -> address: assert w3.eth.get_code(test1) == bytecode # extcodesize check - assert_tx_failed(lambda: c.test("0x" + "00" * 20)) + with tx_failed(): + c.test("0x" + "00" * 20) # test1 = c.test(b"\x01") # assert w3.eth.get_code(test1) == b"\x01" @@ -425,12 +428,14 @@ def test2(target: address, salt: bytes32) -> address: assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(bytecode)) # can't create2 where contract already exists - assert_tx_failed(lambda: c.test2(c.address, salt, transact={})) + with tx_failed(): + c.test2(c.address, salt, transact={}) # test single byte contract # test2 = c.test2(b"\x01", salt) # assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(b"\x01")) - # assert_tx_failed(lambda: c.test2(bytecode, salt)) + # with tx_failed(): + # c.test2(bytecode, salt) # XXX: these various tests to check the msize allocator for diff --git a/tests/functional/builtins/codegen/test_extract32.py b/tests/functional/builtins/codegen/test_extract32.py index 6e4ee09abc..a95b57b5ab 100644 --- a/tests/functional/builtins/codegen/test_extract32.py +++ b/tests/functional/builtins/codegen/test_extract32.py @@ -1,4 +1,4 @@ -def test_extract32_extraction(assert_tx_failed, get_contract_with_gas_estimation): +def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation): extract32_code = """ y: Bytes[100] @external @@ -34,18 +34,19 @@ def extrakt32_storage(index: uint256, inp: Bytes[100]) -> bytes32: ) for S, i in test_cases: - expected_result = S[i : i + 32] if 0 <= i <= len(S) - 32 else None - if expected_result is None: - assert_tx_failed(lambda p=(S, i): c.extrakt32(*p)) - else: + if 0 <= i <= len(S) - 32: + expected_result = S[i : i + 32] assert c.extrakt32(S, i) == expected_result assert c.extrakt32_mem(S, i) == expected_result assert c.extrakt32_storage(i, S) == expected_result + else: + with tx_failed(): + c.extrakt32(S, i) print("Passed bytes32 extraction test") -def test_extract32_code(assert_tx_failed, get_contract_with_gas_estimation): +def test_extract32_code(tx_failed, get_contract_with_gas_estimation): extract32_code = """ @external def foo(inp: Bytes[32]) -> int128: @@ -72,7 +73,8 @@ def foq(inp: Bytes[32]) -> address: assert c.foo(b"\x00" * 30 + b"\x01\x01") == 257 assert c.bar(b"\x00" * 30 + b"\x01\x01") == 257 - assert_tx_failed(lambda: c.foo(b"\x80" + b"\x00" * 30)) + with tx_failed(): + c.foo(b"\x80" + b"\x00" * 30) assert c.bar(b"\x80" + b"\x00" * 31) == 2**255 @@ -80,6 +82,7 @@ def foq(inp: Bytes[32]) -> address: assert c.fop(b"crow" * 8) == b"crow" * 8 assert c.foq(b"\x00" * 12 + b"3" * 20) == "0x" + "3" * 40 - assert_tx_failed(lambda: c.foq(b"crow" * 8)) + with tx_failed(): + c.foq(b"crow" * 8) print("Passed extract32 test") diff --git a/tests/functional/builtins/codegen/test_minmax.py b/tests/functional/builtins/codegen/test_minmax.py index da939d605a..f86504522f 100644 --- a/tests/functional/builtins/codegen/test_minmax.py +++ b/tests/functional/builtins/codegen/test_minmax.py @@ -198,7 +198,7 @@ def foo() -> uint256: def test_minmax_var_uint256_negative_int128( - get_contract_with_gas_estimation, assert_tx_failed, assert_compile_failed + get_contract_with_gas_estimation, tx_failed, assert_compile_failed ): from vyper.exceptions import TypeMismatch diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index 96477897b9..ba82ebd5b8 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -1,4 +1,4 @@ -def test_uint256_mulmod(assert_tx_failed, get_contract_with_gas_estimation): +def test_uint256_mulmod(tx_failed, get_contract_with_gas_estimation): uint256_code = """ @external def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: @@ -11,7 +11,8 @@ def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: assert c._uint256_mulmod(200, 3, 601) == 600 assert c._uint256_mulmod(2**255, 1, 3) == 2 assert c._uint256_mulmod(2**255, 2, 6) == 4 - assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) + with tx_failed(): + c._uint256_mulmod(2, 2, 0) def test_uint256_mulmod_complex(get_contract_with_gas_estimation): diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index 5bb23447e4..4d37176cf8 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -91,7 +91,7 @@ def create_and_return_proxy(inp: address) -> address: # print(f'Gas consumed: {(chain.head_state.receipts[-1].gas_used - chain.head_state.receipts[-2].gas_used - chain.last_tx.intrinsic_gas_used)}') # noqa: E501 -def test_multiple_levels2(assert_tx_failed, get_contract_with_gas_estimation): +def test_multiple_levels2(tx_failed, get_contract_with_gas_estimation): inner_code = """ @external def returnten() -> int128: @@ -114,7 +114,8 @@ def create_and_return_proxy(inp: address) -> address: c2 = get_contract_with_gas_estimation(outer_code) - assert_tx_failed(lambda: c2.create_and_call_returnten(c.address)) + with tx_failed(): + c2.create_and_call_returnten(c.address) print("Passed minimal proxy exception test") @@ -171,7 +172,7 @@ def set(i: int128, owner: address): assert outer_contract.owners(1) == a1 -def test_gas(get_contract, assert_tx_failed): +def test_gas(get_contract, tx_failed): inner_code = """ bar: bytes32 @@ -202,7 +203,8 @@ def foo_call(_addr: address): # manually specifying an insufficient amount should fail outer_contract = get_contract(outer_code.format(", gas=15000")) - assert_tx_failed(lambda: outer_contract.foo_call(inner_contract.address)) + with tx_failed(): + outer_contract.foo_call(inner_contract.address) def test_static_call(get_contract): @@ -323,7 +325,7 @@ def foo(_addr: address) -> bool: assert caller.foo(target.address) is True -def test_static_call_fails_nonpayable(get_contract, assert_tx_failed): +def test_static_call_fails_nonpayable(get_contract, tx_failed): target_source = """ baz: int128 @@ -349,10 +351,11 @@ def foo(_addr: address) -> int128: target = get_contract(target_source) caller = get_contract(caller_source) - assert_tx_failed(lambda: caller.foo(target.address)) + with tx_failed(): + caller.foo(target.address) -def test_checkable_raw_call(get_contract, assert_tx_failed): +def test_checkable_raw_call(get_contract, tx_failed): target_source = """ baz: int128 @external diff --git a/tests/functional/builtins/codegen/test_send.py b/tests/functional/builtins/codegen/test_send.py index 199f708cb4..36f8979556 100644 --- a/tests/functional/builtins/codegen/test_send.py +++ b/tests/functional/builtins/codegen/test_send.py @@ -1,4 +1,4 @@ -def test_send(assert_tx_failed, get_contract): +def test_send(tx_failed, get_contract): send_test = """ @external def foo(): @@ -9,9 +9,11 @@ def fop(): send(msg.sender, 10) """ c = get_contract(send_test, value=10) - assert_tx_failed(lambda: c.foo(transact={})) + with tx_failed(): + c.foo(transact={}) c.fop(transact={}) - assert_tx_failed(lambda: c.fop(transact={})) + with tx_failed(): + c.fop(transact={}) def test_default_gas(get_contract, w3): diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 53e092019f..a15a3eeb35 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -41,7 +41,7 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: def test_slice_immutable( get_contract, assert_compile_failed, - assert_tx_failed, + tx_failed, opt_level, bytesdata, start, @@ -79,7 +79,8 @@ def _get_contract(): assert_compile_failed(lambda: _get_contract(), ArgumentException) elif start + length > len(bytesdata) or (len(bytesdata) > length_bound): # deploy fail - assert_tx_failed(lambda: _get_contract()) + with tx_failed(): + _get_contract() else: c = _get_contract() assert c.do_splice() == bytesdata[start : start + length] @@ -95,7 +96,7 @@ def _get_contract(): def test_slice_bytes_fuzz( get_contract, assert_compile_failed, - assert_tx_failed, + tx_failed, opt_level, location, bytesdata, @@ -175,10 +176,12 @@ def _get_contract(): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) elif location == "code" and len(bytesdata) > length_bound: # deploy fail - assert_tx_failed(lambda: _get_contract()) + with tx_failed(): + _get_contract() elif end > len(bytesdata) or len(bytesdata) > length_bound: c = _get_contract() - assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) + with tx_failed(): + c.do_slice(bytesdata, start, length) else: c = _get_contract() assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code diff --git a/tests/functional/builtins/codegen/test_unary.py b/tests/functional/builtins/codegen/test_unary.py index da3823edfe..33f79be233 100644 --- a/tests/functional/builtins/codegen/test_unary.py +++ b/tests/functional/builtins/codegen/test_unary.py @@ -13,14 +13,15 @@ def negate(a: uint256) -> uint256: assert_compile_failed(lambda: get_contract(code), exception=InvalidOperation) -def test_unary_sub_int128_fail(get_contract, assert_tx_failed): +def test_unary_sub_int128_fail(get_contract, tx_failed): code = """@external def negate(a: int128) -> int128: return -(a) """ c = get_contract(code) # This test should revert on overflow condition - assert_tx_failed(lambda: c.negate(-(2**127))) + with tx_failed(): + c.negate(-(2**127)) @pytest.mark.parametrize("val", [-(2**127) + 1, 0, 2**127 - 1]) diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index 1c919d7826..a91a4f1ad3 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -39,7 +39,7 @@ def foo(a: int256) -> int256: get_contract(source) -def test_abs_lower_bound(get_contract, assert_tx_failed): +def test_abs_lower_bound(get_contract, tx_failed): source = """ @external def foo(a: int256) -> int256: @@ -47,10 +47,11 @@ def foo(a: int256) -> int256: """ contract = get_contract(source) - assert_tx_failed(lambda: contract.foo(-(2**255))) + with tx_failed(): + contract.foo(-(2**255)) -def test_abs_lower_bound_folded(get_contract, assert_tx_failed): +def test_abs_lower_bound_folded(get_contract, tx_failed): source = """ @external def foo() -> int256: diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index f7eef21af7..cf55607877 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -1,4 +1,4 @@ -def test_throw_on_sending(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_throw_on_sending(w3, tx_failed, get_contract_with_gas_estimation): code = """ x: public(int128) @@ -10,9 +10,8 @@ def __init__(): assert c.x() == 123 assert w3.eth.get_balance(c.address) == 0 - assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "value": w3.to_wei(0.1, "ether")}) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": w3.to_wei(0.1, "ether")}) assert w3.eth.get_balance(c.address) == 0 @@ -56,7 +55,7 @@ def __default__(): assert w3.eth.get_balance(c.address) == w3.to_wei(0.1, "ether") -def test_basic_default_not_payable(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_basic_default_not_payable(w3, tx_failed, get_contract_with_gas_estimation): code = """ event Sent: sender: indexed(address) @@ -67,7 +66,8 @@ def __default__(): """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: w3.eth.send_transaction({"to": c.address, "value": 10**17})) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": 10**17}) def test_multi_arg_default(assert_compile_failed, get_contract_with_gas_estimation): @@ -100,7 +100,7 @@ def __default__(): assert_compile_failed(lambda: get_contract_with_gas_estimation(code)) -def test_zero_method_id(w3, get_logs, get_contract, assert_tx_failed): +def test_zero_method_id(w3, get_logs, get_contract, tx_failed): # test a method with 0x00000000 selector, # expects at least 36 bytes of calldata. code = """ @@ -143,10 +143,11 @@ def _call_with_bytes(hexstr): for i in range(4, 36): # match the full 4 selector bytes, but revert due to malformed (short) calldata - assert_tx_failed(lambda p="0x" + "00" * i: _call_with_bytes(p)) + with tx_failed(): + _call_with_bytes(f"0x{'00' * i}") -def test_another_zero_method_id(w3, get_logs, get_contract, assert_tx_failed): +def test_another_zero_method_id(w3, get_logs, get_contract, tx_failed): # test another zero method id but which only expects 4 bytes of calldata code = """ event Sent: diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index a90f5e6624..03f5d9fca2 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -150,7 +150,7 @@ def foo(a: int128[3] = [1, 2, 3]) -> int128[3]: assert c.foo() == [1, 2, 3] -def test_default_param_clamp(get_contract, monkeypatch, assert_tx_failed): +def test_default_param_clamp(get_contract, monkeypatch, tx_failed): code = """ @external def bar(a: int128, b: int128 = -1) -> (int128, int128): # noqa: E501 @@ -168,7 +168,8 @@ def validate_value(cls, value): monkeypatch.setattr("eth_abi.encoding.NumberEncoder.validate_value", validate_value) assert c.bar(200, 2**127 - 1) == [200, 2**127 - 1] - assert_tx_failed(lambda: c.bar(200, 2**127)) + with tx_failed(): + c.bar(200, 2**127) def test_default_param_private(get_contract): diff --git a/tests/functional/codegen/calling_convention/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py index 4a09ce68fa..b9dc5c663f 100644 --- a/tests/functional/codegen/calling_convention/test_erc20_abi.py +++ b/tests/functional/codegen/calling_convention/test_erc20_abi.py @@ -81,7 +81,7 @@ def test_initial_state(w3, erc20_caller): assert erc20_caller.decimals() == TOKEN_DECIMALS -def test_call_transfer(w3, erc20, erc20_caller, assert_tx_failed): +def test_call_transfer(w3, erc20, erc20_caller, tx_failed): # Basic transfer. erc20.transfer(erc20_caller.address, 10, transact={}) assert erc20.balanceOf(erc20_caller.address) == 10 @@ -90,13 +90,12 @@ def test_call_transfer(w3, erc20, erc20_caller, assert_tx_failed): assert erc20.balanceOf(w3.eth.accounts[1]) == 10 # more than allowed - assert_tx_failed(lambda: erc20_caller.transfer(w3.eth.accounts[1], TOKEN_TOTAL_SUPPLY)) + with tx_failed(): + erc20_caller.transfer(w3.eth.accounts[1], TOKEN_TOTAL_SUPPLY) # Negative transfer value. - assert_tx_failed( - function_to_test=lambda: erc20_caller.transfer(w3.eth.accounts[1], -1), - exception=ValidationError, - ) + with tx_failed(ValidationError): + erc20_caller.transfer(w3.eth.accounts[1], -1) def test_caller_approve_allowance(w3, erc20, erc20_caller): @@ -105,11 +104,10 @@ def test_caller_approve_allowance(w3, erc20, erc20_caller): assert erc20_caller.allowance(w3.eth.accounts[0], erc20_caller.address) == 10 -def test_caller_tranfer_from(w3, erc20, erc20_caller, assert_tx_failed): +def test_caller_tranfer_from(w3, erc20, erc20_caller, tx_failed): # Cannot transfer tokens that are unavailable - assert_tx_failed( - lambda: erc20_caller.transferFrom(w3.eth.accounts[0], erc20_caller.address, 10) - ) + with tx_failed(): + erc20_caller.transferFrom(w3.eth.accounts[0], erc20_caller.address, 10) assert erc20.balanceOf(erc20_caller.address) == 0 assert erc20.approve(erc20_caller.address, 10, transact={}) erc20_caller.transferFrom(w3.eth.accounts[0], erc20_caller.address, 5, transact={}) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 12fcde2f4f..0360396f03 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -3,6 +3,7 @@ import pytest from eth.codecs import abi +from vyper import compile_code from vyper.exceptions import ( ArgumentException, InvalidType, @@ -94,7 +95,7 @@ def get_array(arg1: address) -> Bytes[3]: assert c2.get_array(c.address) == b"dog" -def test_bytes_too_long(get_contract, assert_tx_failed): +def test_bytes_too_long(get_contract, tx_failed): contract_1 = """ @external def array() -> Bytes[4]: @@ -113,13 +114,14 @@ def get_array(arg1: address) -> Bytes[3]: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.get_array(c.address)) + with tx_failed(): + c2.get_array(c.address) @pytest.mark.parametrize( "revert_string", ["Mayday, mayday!", "A very long revert string" + "." * 512] ) -def test_revert_propagation(get_contract, assert_tx_failed, revert_string): +def test_revert_propagation(get_contract, tx_failed, revert_string): raiser = f""" @external def run(): @@ -135,7 +137,8 @@ def run(raiser: address): """ c1 = get_contract(raiser) c2 = get_contract(caller) - assert_tx_failed(lambda: c2.run(c1.address), exc_text=revert_string) + with tx_failed(exc_text=revert_string): + c2.run(c1.address) @pytest.mark.parametrize("a,b", [(3, 3), (4, 3), (3, 4), (32, 32), (33, 33), (64, 64)]) @@ -169,7 +172,7 @@ def get_array(arg1: address) -> (Bytes[{a}], int128, Bytes[{b}]): @pytest.mark.parametrize("a,b", [(18, 7), (18, 18), (19, 6), (64, 6), (7, 19)]) @pytest.mark.parametrize("c,d", [(19, 7), (64, 64)]) -def test_tuple_with_bytes_too_long(get_contract, assert_tx_failed, a, c, b, d): +def test_tuple_with_bytes_too_long(get_contract, tx_failed, a, c, b, d): contract_1 = f""" @external def array() -> (Bytes[{c}], int128, Bytes[{d}]): @@ -193,10 +196,11 @@ def get_array(arg1: address) -> (Bytes[{a}], int128, Bytes[{b}]): c2 = get_contract(contract_2) assert c.array() == [b"nineteen characters", 255, b"seven!!"] - assert_tx_failed(lambda: c2.get_array(c.address)) + with tx_failed(): + c2.get_array(c.address) -def test_tuple_with_bytes_too_long_two(get_contract, assert_tx_failed): +def test_tuple_with_bytes_too_long_two(get_contract, tx_failed): contract_1 = """ @external def array() -> (Bytes[30], int128, Bytes[30]): @@ -220,7 +224,8 @@ def get_array(arg1: address) -> (Bytes[30], int128, Bytes[3]): c2 = get_contract(contract_2) assert c.array() == [b"nineteen characters", 255, b"seven!!"] - assert_tx_failed(lambda: c2.get_array(c.address)) + with tx_failed(): + c2.get_array(c.address) @pytest.mark.parametrize("length", [8, 256]) @@ -246,7 +251,7 @@ def bar(arg1: address) -> uint8: assert c2.bar(c.address) == 255 -def test_uint8_too_long(get_contract, assert_tx_failed): +def test_uint8_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> uint256: @@ -265,7 +270,8 @@ def bar(arg1: address) -> uint8: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(8, 8), (8, 256), (256, 8), (256, 256)]) @@ -298,7 +304,7 @@ def bar(arg1: address) -> (uint{a}, Bytes[3], uint{b}): @pytest.mark.parametrize("a,b", [(8, 256), (256, 8), (256, 256)]) -def test_tuple_with_uint8_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_uint8_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (uint{a}, Bytes[3], uint{b}): @@ -322,11 +328,12 @@ def bar(arg1: address) -> (uint8, Bytes[3], uint8): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**a)-1}"), b"dog", int(f"{(2**b)-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(8, 256), (256, 8)]) -def test_tuple_with_uint8_too_long_two(get_contract, assert_tx_failed, a, b): +def test_tuple_with_uint8_too_long_two(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (uint{b}, Bytes[3], uint{a}): @@ -350,7 +357,8 @@ def bar(arg1: address) -> (uint{a}, Bytes[3], uint{b}): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**b)-1}"), b"dog", int(f"{(2**a)-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("length", [128, 256]) @@ -376,7 +384,7 @@ def bar(arg1: address) -> int128: assert c2.bar(c.address) == 1 -def test_int128_too_long(get_contract, assert_tx_failed): +def test_int128_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> int256: @@ -395,7 +403,8 @@ def bar(arg1: address) -> int128: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(128, 128), (128, 256), (256, 128), (256, 256)]) @@ -428,7 +437,7 @@ def bar(arg1: address) -> (int{a}, Bytes[3], int{b}): @pytest.mark.parametrize("a,b", [(128, 256), (256, 128), (256, 256)]) -def test_tuple_with_int128_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_int128_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (int{a}, Bytes[3], int{b}): @@ -452,11 +461,12 @@ def bar(arg1: address) -> (int128, Bytes[3], int128): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**(a-1))-1}"), b"dog", int(f"{(2**(b-1))-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(128, 256), (256, 128)]) -def test_tuple_with_int128_too_long_two(get_contract, assert_tx_failed, a, b): +def test_tuple_with_int128_too_long_two(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (int{b}, Bytes[3], int{a}): @@ -480,7 +490,8 @@ def bar(arg1: address) -> (int{a}, Bytes[3], int{b}): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**(b-1))-1}"), b"dog", int(f"{(2**(a-1))-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("type", ["uint8", "uint256", "int128", "int256"]) @@ -506,7 +517,7 @@ def bar(arg1: address) -> decimal: assert c2.bar(c.address) == Decimal("1e-10") -def test_decimal_too_long(get_contract, assert_tx_failed): +def test_decimal_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> uint256: @@ -525,7 +536,8 @@ def bar(arg1: address) -> decimal: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a", ["uint8", "uint256", "int128", "int256"]) @@ -559,7 +571,7 @@ def bar(arg1: address) -> (decimal, Bytes[3], decimal): @pytest.mark.parametrize("a,b", [(8, 256), (256, 8), (256, 256)]) -def test_tuple_with_decimal_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_decimal_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (uint{a}, Bytes[3], uint{b}): @@ -583,7 +595,8 @@ def bar(arg1: address) -> (decimal, Bytes[3], decimal): c2 = get_contract(contract_2) assert c.foo() == [2 ** (a - 1), b"dog", 2 ** (b - 1)] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("type", ["uint8", "uint256", "int128", "int256"]) @@ -609,7 +622,7 @@ def bar(arg1: address) -> bool: assert c2.bar(c.address) is True -def test_bool_too_long(get_contract, assert_tx_failed): +def test_bool_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> uint256: @@ -628,7 +641,8 @@ def bar(arg1: address) -> bool: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a", ["uint8", "uint256", "int128", "int256"]) @@ -662,7 +676,7 @@ def bar(arg1: address) -> (bool, Bytes[3], bool): @pytest.mark.parametrize("a", ["uint8", "uint256", "int128", "int256"]) @pytest.mark.parametrize("b", ["uint8", "uint256", "int128", "int256"]) -def test_tuple_with_bool_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_bool_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> ({a}, Bytes[3], {b}): @@ -686,7 +700,8 @@ def bar(arg1: address) -> (bool, Bytes[3], bool): c2 = get_contract(contract_2) assert c.foo() == [1, b"dog", 2] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("type", ["uint8", "int128", "uint256", "int256"]) @@ -736,7 +751,7 @@ def bar(arg1: address) -> address: @pytest.mark.parametrize("type", ["uint256", "int256"]) -def test_address_too_long(get_contract, assert_tx_failed, type): +def test_address_too_long(get_contract, tx_failed, type): contract_1 = f""" @external def foo() -> {type}: @@ -755,7 +770,8 @@ def bar(arg1: address) -> address: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a", ["uint8", "int128", "uint256", "int256"]) @@ -826,7 +842,7 @@ def bar(arg1: address) -> (address, Bytes[3], address): @pytest.mark.parametrize("a", ["uint256", "int256"]) @pytest.mark.parametrize("b", ["uint256", "int256"]) -def test_tuple_with_address_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_address_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> ({a}, Bytes[3], {b}): @@ -850,7 +866,8 @@ def bar(arg1: address) -> (address, Bytes[3], address): c2 = get_contract(contract_2) assert c.foo() == [(2**160) - 1, b"dog", 2**160] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) def test_external_contract_call_state_change(get_contract): @@ -1095,7 +1112,7 @@ def _expr(x: address) -> int128: assert c2._expr(c2.address) == 1 -def test_invalid_nonexistent_contract_call(w3, assert_tx_failed, get_contract): +def test_invalid_nonexistent_contract_call(w3, tx_failed, get_contract): contract_1 = """ @external def bar() -> int128: @@ -1115,11 +1132,13 @@ def foo(x: address) -> int128: c2 = get_contract(contract_2) assert c2.foo(c1.address) == 1 - assert_tx_failed(lambda: c2.foo(w3.eth.accounts[0])) - assert_tx_failed(lambda: c2.foo(w3.eth.accounts[3])) + with tx_failed(): + c2.foo(w3.eth.accounts[0]) + with tx_failed(): + c2.foo(w3.eth.accounts[3]) -def test_invalid_contract_reference_declaration(assert_tx_failed, get_contract): +def test_invalid_contract_reference_declaration(tx_failed, get_contract): contract = """ interface Bar: get_magic_number: 1 @@ -1130,19 +1149,21 @@ def test_invalid_contract_reference_declaration(assert_tx_failed, get_contract): def __init__(): pass """ - assert_tx_failed(lambda: get_contract(contract), exception=StructureException) + with tx_failed(exception=StructureException): + get_contract(contract) -def test_invalid_contract_reference_call(assert_tx_failed, get_contract): +def test_invalid_contract_reference_call(tx_failed, get_contract): contract = """ @external def bar(arg1: address, arg2: int128) -> int128: return Foo(arg1).foo(arg2) """ - assert_tx_failed(lambda: get_contract(contract), exception=UndeclaredDefinition) + with pytest.raises(UndeclaredDefinition): + compile_code(contract) -def test_invalid_contract_reference_return_type(assert_tx_failed, get_contract): +def test_invalid_contract_reference_return_type(tx_failed, get_contract): contract = """ interface Foo: def foo(arg2: int128) -> invalid: view @@ -1151,7 +1172,8 @@ def foo(arg2: int128) -> invalid: view def bar(arg1: address, arg2: int128) -> int128: return Foo(arg1).foo(arg2) """ - assert_tx_failed(lambda: get_contract(contract), exception=UnknownType) + with pytest.raises(UnknownType): + compile_code(contract) def test_external_contract_call_declaration_expr(get_contract): @@ -1378,7 +1400,7 @@ def get_lucky(amount_to_send: uint256) -> int128: assert w3.eth.get_balance(c2.address) == 250 -def test_external_call_with_gas(assert_tx_failed, get_contract_with_gas_estimation): +def test_external_call_with_gas(tx_failed, get_contract_with_gas_estimation): contract_1 = """ @external def get_lucky() -> int128: @@ -1406,7 +1428,8 @@ def get_lucky(gas_amount: uint256) -> int128: c2.set_contract(c1.address, transact={}) assert c2.get_lucky(1000) == 656598 - assert_tx_failed(lambda: c2.get_lucky(50)) # too little gas. + with tx_failed(): + c2.get_lucky(50) # too little gas. def test_skip_contract_check(get_contract_with_gas_estimation): @@ -2240,7 +2263,7 @@ def get_array(arg1: address) -> int128[3]: assert c2.get_array(c.address) == [0, 0, 0] -def test_returndatasize_too_short(get_contract, assert_tx_failed): +def test_returndatasize_too_short(get_contract, tx_failed): contract_1 = """ @external def bar(a: int128) -> int128: @@ -2256,10 +2279,11 @@ def foo(_addr: address): """ c1 = get_contract(contract_1) c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.foo(c1.address)) + with tx_failed(): + c2.foo(c1.address) -def test_returndatasize_empty(get_contract, assert_tx_failed): +def test_returndatasize_empty(get_contract, tx_failed): contract_1 = """ @external def bar(a: int128): @@ -2275,7 +2299,8 @@ def foo(_addr: address) -> int128: """ c1 = get_contract(contract_1) c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.foo(c1.address)) + with tx_failed(): + c2.foo(c1.address) def test_returndatasize_too_long(get_contract): @@ -2299,7 +2324,7 @@ def foo(_addr: address) -> int128: assert c2.foo(c1.address) == 456 -def test_no_returndata(get_contract, assert_tx_failed): +def test_no_returndata(get_contract, tx_failed): contract_1 = """ @external def bar(a: int128) -> int128: @@ -2321,10 +2346,11 @@ def foo(_addr: address, _addr2: address) -> int128: c2 = get_contract(contract_2) assert c2.foo(c1.address, c1.address) == 123 - assert_tx_failed(lambda: c2.foo(c1.address, "0x1234567890123456789012345678901234567890")) + with tx_failed(): + c2.foo(c1.address, "0x1234567890123456789012345678901234567890") -def test_default_override(get_contract, assert_tx_failed): +def test_default_override(get_contract, tx_failed): bad_erc20_code = """ @external def transfer(receiver: address, amount: uint256): @@ -2358,17 +2384,20 @@ def transferBorked(erc20: ERC20, receiver: address, amount: uint256): c = get_contract(code) # demonstrate transfer failing - assert_tx_failed(lambda: c.transferBorked(bad_erc20.address, c.address, 0)) + with tx_failed(): + c.transferBorked(bad_erc20.address, c.address, 0) # would fail without default_return_value assert c.safeTransfer(bad_erc20.address, c.address, 0) == 7 # check that `default_return_value` does not stomp valid returndata. negative_contract = get_contract(negative_transfer_code) - assert_tx_failed(lambda: c.safeTransfer(negative_contract.address, c.address, 0)) + with tx_failed(): + c.safeTransfer(negative_contract.address, c.address, 0) # default_return_value should fail on EOAs (addresses with no code) random_address = "0x0000000000000000000000000000000000001234" - assert_tx_failed(lambda: c.safeTransfer(random_address, c.address, 1)) + with tx_failed(): + c.safeTransfer(random_address, c.address, 1) # in this case, the extcodesize check runs after the token contract # selfdestructs. however, extcodesize still returns nonzero until @@ -2378,7 +2407,7 @@ def transferBorked(erc20: ERC20, receiver: address, amount: uint256): assert c.safeTransfer(self_destructing_contract.address, c.address, 0) == 7 -def test_default_override2(get_contract, assert_tx_failed): +def test_default_override2(get_contract, tx_failed): bad_code_1 = """ @external def return_64_bytes() -> bool: @@ -2407,7 +2436,8 @@ def bar(foo: Foo): c = get_contract(code) # fails due to returndatasize being nonzero but also lt 64 - assert_tx_failed(lambda: c.bar(bad_1.address)) + with tx_failed(): + c.bar(bad_1.address) c.bar(bad_2.address) @@ -2456,7 +2486,7 @@ def do_stuff(f: Foo) -> uint256: @pytest.mark.parametrize("typ,val", [("address", TEST_ADDR)]) -def test_calldata_clamp(w3, get_contract, assert_tx_failed, keccak, typ, val): +def test_calldata_clamp(w3, get_contract, tx_failed, keccak, typ, val): code = f""" @external def foo(a: {typ}): @@ -2469,7 +2499,8 @@ def foo(a: {typ}): # Static size is short by 1 byte malformed = data[:-2] - assert_tx_failed(lambda: w3.eth.send_transaction({"to": c1.address, "data": malformed})) + with tx_failed(): + w3.eth.send_transaction({"to": c1.address, "data": malformed}) # Static size is exact w3.eth.send_transaction({"to": c1.address, "data": data}) @@ -2479,7 +2510,7 @@ def foo(a: {typ}): @pytest.mark.parametrize("typ,val", [("address", ([TEST_ADDR] * 3, "vyper"))]) -def test_dynamic_calldata_clamp(w3, get_contract, assert_tx_failed, keccak, typ, val): +def test_dynamic_calldata_clamp(w3, get_contract, tx_failed, keccak, typ, val): code = f""" @external def foo(a: DynArray[{typ}, 3], b: String[5]): @@ -2493,7 +2524,8 @@ def foo(a: DynArray[{typ}, 3], b: String[5]): # Dynamic size is short by 1 byte malformed = data[:264] - assert_tx_failed(lambda: w3.eth.send_transaction({"to": c1.address, "data": malformed})) + with tx_failed(): + w3.eth.send_transaction({"to": c1.address, "data": malformed}) # Dynamic size is at least minimum (132 bytes * 2 + 2 (for 0x) = 266) valid = data[:266] diff --git a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py index 4c321442f4..e6b2402016 100644 --- a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py @@ -1,7 +1,7 @@ from vyper.exceptions import StructureException, SyntaxException, UnknownType -def test_external_contract_call_declaration_expr(get_contract, assert_tx_failed): +def test_external_contract_call_declaration_expr(get_contract, tx_failed): contract_1 = """ lucky: public(int128) @@ -39,11 +39,12 @@ def static_set_lucky(_lucky: int128): c2.modifiable_set_lucky(7, transact={}) assert c1.lucky() == 7 # Fails attempting a state change after a call to a static address - assert_tx_failed(lambda: c2.static_set_lucky(5, transact={})) + with tx_failed(): + c2.static_set_lucky(5, transact={}) assert c1.lucky() == 7 -def test_external_contract_call_declaration_stmt(get_contract, assert_tx_failed): +def test_external_contract_call_declaration_stmt(get_contract, tx_failed): contract_1 = """ lucky: public(int128) @@ -83,11 +84,12 @@ def static_set_lucky(_lucky: int128): c2.modifiable_set_lucky(7, transact={}) assert c1.lucky() == 7 # Fails attempting a state change after a call to a static address - assert_tx_failed(lambda: c2.static_set_lucky(5, transact={})) + with tx_failed(): + c2.static_set_lucky(5, transact={}) assert c1.lucky() == 7 -def test_multiple_contract_state_changes(get_contract, assert_tx_failed): +def test_multiple_contract_state_changes(get_contract, tx_failed): contract_1 = """ lucky: public(int128) @@ -161,9 +163,12 @@ def static_modifiable_set_lucky(_lucky: int128): assert c1.lucky() == 0 c3.modifiable_modifiable_set_lucky(7, transact={}) assert c1.lucky() == 7 - assert_tx_failed(lambda: c3.modifiable_static_set_lucky(6, transact={})) - assert_tx_failed(lambda: c3.static_modifiable_set_lucky(6, transact={})) - assert_tx_failed(lambda: c3.static_static_set_lucky(6, transact={})) + with tx_failed(): + c3.modifiable_static_set_lucky(6, transact={}) + with tx_failed(): + c3.static_modifiable_set_lucky(6, transact={}) + with tx_failed(): + c3.static_static_set_lucky(6, transact={}) assert c1.lucky() == 7 diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index b375839147..266555ead6 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -1,5 +1,6 @@ import pytest +from vyper import compile_code from vyper.exceptions import TypeMismatch pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -152,11 +153,11 @@ def test3() -> (address, int128): assert c.test3() == [c.out_literals()[2], 1] -def test_tuple_return_typecheck(assert_tx_failed, get_contract_with_gas_estimation): +def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation): code = """ @external def getTimeAndBalance() -> (bool, address): return block.timestamp, self.balance """ - - assert_tx_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) + with pytest.raises(TypeMismatch): + compile_code(code) diff --git a/tests/functional/codegen/environment_variables/test_blockhash.py b/tests/functional/codegen/environment_variables/test_blockhash.py index b92c17a561..68db053b12 100644 --- a/tests/functional/codegen/environment_variables/test_blockhash.py +++ b/tests/functional/codegen/environment_variables/test_blockhash.py @@ -23,7 +23,7 @@ def foo() -> bytes32: assert_compile_failed(lambda: get_contract_with_gas_estimation(code)) -def test_too_old_blockhash(assert_tx_failed, get_contract_with_gas_estimation, w3): +def test_too_old_blockhash(tx_failed, get_contract_with_gas_estimation, w3): w3.testing.mine(257) code = """ @external @@ -31,14 +31,16 @@ def get_50_blockhash() -> bytes32: return blockhash(block.number - 257) """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.get_50_blockhash()) + with tx_failed(): + c.get_50_blockhash() -def test_non_existing_blockhash(assert_tx_failed, get_contract_with_gas_estimation): +def test_non_existing_blockhash(tx_failed, get_contract_with_gas_estimation): code = """ @external def get_future_blockhash() -> bytes32: return blockhash(block.number + 1) """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.get_future_blockhash()) + with tx_failed(): + c.get_future_blockhash() diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 9e74019250..9329605678 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -5,7 +5,7 @@ # TODO test functions in this module across all evm versions # once we have cancun support. -def test_nonreentrant_decorator(get_contract, assert_tx_failed): +def test_nonreentrant_decorator(get_contract, tx_failed): calling_contract_code = """ interface SpecialContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable @@ -98,20 +98,23 @@ def unprotected_function(val: String[100], do_callback: bool): assert reentrant_contract.special_value() == "some value" assert reentrant_contract.protected_view_fn() == "some value" - assert_tx_failed(lambda: reentrant_contract.protected_function("zzz value", True, transact={})) + with tx_failed(): + reentrant_contract.protected_function("zzz value", True, transact={}) reentrant_contract.protected_function2("another value", False, transact={}) assert reentrant_contract.special_value() == "another value" - assert_tx_failed(lambda: reentrant_contract.protected_function2("zzz value", True, transact={})) + with tx_failed(): + reentrant_contract.protected_function2("zzz value", True, transact={}) reentrant_contract.protected_function3("another value", False, transact={}) assert reentrant_contract.special_value() == "another value" - assert_tx_failed(lambda: reentrant_contract.protected_function3("zzz value", True, transact={})) + with tx_failed(): + reentrant_contract.protected_function3("zzz value", True, transact={}) -def test_nonreentrant_decorator_for_default(w3, get_contract, assert_tx_failed): +def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed): calling_contract_code = """ @external def send_funds(_amount: uint256): @@ -196,9 +199,8 @@ def __default__(): assert w3.eth.get_balance(calling_contract.address) == 2000 # Test protected function with callback to default. - assert_tx_failed( - lambda: reentrant_contract.protected_function("zzz value", True, transact={"value": 1000}) - ) + with tx_failed(): + reentrant_contract.protected_function("zzz value", True, transact={"value": 1000}) def test_disallow_on_init_function(get_contract): diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index 4858a7df0d..ced58e1af0 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -177,14 +177,13 @@ def baz() -> bool: @pytest.mark.parametrize("code", nonpayable_code) -def test_nonpayable_runtime_assertion(w3, keccak, assert_tx_failed, get_contract, code): +def test_nonpayable_runtime_assertion(w3, keccak, tx_failed, get_contract, code): c = get_contract(code) c.foo(transact={"value": 0}) sig = keccak("foo()".encode()).hex()[:10] - assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "data": sig, "value": 10**18}) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "data": sig, "value": 10**18}) payable_code = [ @@ -355,7 +354,7 @@ def __default__(): w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) -def test_nonpayable_default_func_invalid_calldata(get_contract, w3, assert_tx_failed): +def test_nonpayable_default_func_invalid_calldata(get_contract, w3, tx_failed): code = """ @external @payable @@ -369,12 +368,11 @@ def __default__(): c = get_contract(code) w3.eth.send_transaction({"to": c.address, "value": 0, "data": "0x12345678"}) - assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) -def test_batch_nonpayable(get_contract, w3, assert_tx_failed): +def test_batch_nonpayable(get_contract, w3, tx_failed): code = """ @external def foo() -> bool: @@ -390,8 +388,5 @@ def __default__(): data = bytes([1, 2, 3, 4]) for i in range(5): calldata = "0x" + data[:i].hex() - assert_tx_failed( - lambda data=calldata: w3.eth.send_transaction( - {"to": c.address, "value": 100, "data": data} - ) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": 100, "data": calldata}) diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 51e6d90ee1..39ea1bb9ae 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -449,7 +449,7 @@ def whoami() -> address: assert logged_addr == addr, "oh no" -def test_nested_static_params_only(get_contract, assert_tx_failed): +def test_nested_static_params_only(get_contract, tx_failed): code1 = """ @internal @view diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index ed6235d992..96b83ae691 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -14,7 +14,7 @@ def repeat(z: int128) -> int128: assert c.repeat(9) == 54 -def test_range_bound(get_contract, assert_tx_failed): +def test_range_bound(get_contract, tx_failed): code = """ @external def repeat(n: uint256) -> uint256: @@ -28,7 +28,8 @@ def repeat(n: uint256) -> uint256: assert c.repeat(n) == sum(i + 1 for i in range(n)) # check codegen inserts assertion for n greater than bound - assert_tx_failed(lambda: c.repeat(7)) + with tx_failed(): + c.repeat(7) def test_digit_reverser(get_contract_with_gas_estimation): @@ -172,7 +173,7 @@ def test(): @pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"]) -def test_for_range_oob_check(get_contract, assert_tx_failed, typ): +def test_for_range_oob_check(get_contract, tx_failed, typ): code = f""" @external def test(): @@ -181,7 +182,8 @@ def test(): pass """ c = get_contract(code) - assert_tx_failed(lambda: c.test()) + with tx_failed(): + c.test() @pytest.mark.parametrize("typ", ["int128", "uint256"]) diff --git a/tests/functional/codegen/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py index 062cd389a0..7540049778 100644 --- a/tests/functional/codegen/features/iteration/test_range_in.py +++ b/tests/functional/codegen/features/iteration/test_range_in.py @@ -110,7 +110,7 @@ def testin() -> bool: assert_compile_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) -def test_ownership(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_ownership(w3, tx_failed, get_contract_with_gas_estimation): code = """ owners: address[2] @@ -135,7 +135,8 @@ def is_owner() -> bool: assert c.is_owner(call={"from": a1}) is False # no one else is. # only an owner may set another owner. - assert_tx_failed(lambda: c.set_owner(1, a1, call={"from": a1})) + with tx_failed(): + c.set_owner(1, a1, call={"from": a1}) c.set_owner(1, a1, transact={}) assert c.is_owner(call={"from": a1}) is True @@ -145,7 +146,7 @@ def is_owner() -> bool: assert c.is_owner() is False -def test_in_fails_when_types_dont_match(get_contract_with_gas_estimation, assert_tx_failed): +def test_in_fails_when_types_dont_match(get_contract_with_gas_estimation, tx_failed): code = """ @external def testin(x: address) -> bool: @@ -154,4 +155,5 @@ def testin(x: address) -> bool: return True return False """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(code) diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index 842b32d815..af189e6dca 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -3,12 +3,12 @@ # web3 returns f"execution reverted: {err_str}" -# TODO move exception string parsing logic into assert_tx_failed +# TODO move exception string parsing logic into tx_failed def _fixup_err_str(s): return s.replace("execution reverted: ", "") -def test_assert_refund(w3, get_contract_with_gas_estimation, assert_tx_failed): +def test_assert_refund(w3, get_contract_with_gas_estimation, tx_failed): code = """ @external def foo(): @@ -26,7 +26,7 @@ def foo(): assert tx_receipt["gasUsed"] < gas_sent -def test_assert_reason(w3, get_contract_with_gas_estimation, assert_tx_failed, memory_mocker): +def test_assert_reason(w3, get_contract_with_gas_estimation, tx_failed, memory_mocker): code = """ @external def test(a: int128) -> int128: @@ -132,7 +132,7 @@ def test_valid_assertions(get_contract, code): get_contract(code) -def test_assert_staticcall(get_contract, assert_tx_failed, memory_mocker): +def test_assert_staticcall(get_contract, tx_failed, memory_mocker): foreign_code = """ state: uint256 @external @@ -151,10 +151,11 @@ def test(): c1 = get_contract(foreign_code) c2 = get_contract(code, *[c1.address]) # static call prohibits state change - assert_tx_failed(lambda: c2.test()) + with tx_failed(): + c2.test() -def test_assert_in_for_loop(get_contract, assert_tx_failed, memory_mocker): +def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: @@ -166,12 +167,15 @@ def test(x: uint256[3]) -> bool: c = get_contract(code) c.test([1, 2, 3]) - assert_tx_failed(lambda: c.test([5, 1, 3])) - assert_tx_failed(lambda: c.test([1, 5, 3])) - assert_tx_failed(lambda: c.test([1, 3, 5])) + with tx_failed(): + c.test([5, 1, 3]) + with tx_failed(): + c.test([1, 5, 3]) + with tx_failed(): + c.test([1, 3, 5]) -def test_assert_with_reason_in_for_loop(get_contract, assert_tx_failed, memory_mocker): +def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: @@ -183,12 +187,15 @@ def test(x: uint256[3]) -> bool: c = get_contract(code) c.test([1, 2, 3]) - assert_tx_failed(lambda: c.test([5, 1, 3])) - assert_tx_failed(lambda: c.test([1, 5, 3])) - assert_tx_failed(lambda: c.test([1, 3, 5])) + with tx_failed(): + c.test([5, 1, 3]) + with tx_failed(): + c.test([1, 5, 3]) + with tx_failed(): + c.test([1, 3, 5]) -def test_assert_reason_revert_length(w3, get_contract, assert_tx_failed, memory_mocker): +def test_assert_reason_revert_length(w3, get_contract, tx_failed, memory_mocker): code = """ @external def test() -> int128: @@ -196,4 +203,5 @@ def test() -> int128: return 1 """ c = get_contract(code) - assert_tx_failed(lambda: c.test(), exc_text="oops") + with tx_failed(exc_text="oops"): + c.test() diff --git a/tests/functional/codegen/features/test_assert_unreachable.py b/tests/functional/codegen/features/test_assert_unreachable.py index 90ed31a22e..4db00bce7c 100644 --- a/tests/functional/codegen/features/test_assert_unreachable.py +++ b/tests/functional/codegen/features/test_assert_unreachable.py @@ -15,7 +15,7 @@ def foo(): assert tx_receipt["gasUsed"] == gas_sent # Drains all gains sent -def test_basic_unreachable(w3, get_contract, assert_tx_failed): +def test_basic_unreachable(w3, get_contract, tx_failed): code = """ @external def foo(val: int128) -> bool: @@ -28,12 +28,15 @@ def foo(val: int128) -> bool: assert c.foo(2) is True - assert_tx_failed(lambda: c.foo(1), exc_text="Invalid opcode 0xfe") - assert_tx_failed(lambda: c.foo(-1), exc_text="Invalid opcode 0xfe") - assert_tx_failed(lambda: c.foo(-2), exc_text="Invalid opcode 0xfe") + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(1) + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(-1) + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(-2) -def test_basic_call_unreachable(w3, get_contract, assert_tx_failed): +def test_basic_call_unreachable(w3, get_contract, tx_failed): code = """ @view @@ -51,11 +54,13 @@ def foo(val: int128) -> int128: assert c.foo(33) == -123 - assert_tx_failed(lambda: c.foo(1), exc_text="Invalid opcode 0xfe") - assert_tx_failed(lambda: c.foo(-1), exc_text="Invalid opcode 0xfe") + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(1) + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(-1) -def test_raise_unreachable(w3, get_contract, assert_tx_failed): +def test_raise_unreachable(w3, get_contract, tx_failed): code = """ @external def foo(): @@ -64,4 +69,5 @@ def foo(): c = get_contract(code) - assert_tx_failed(lambda: c.foo(), exc_text="Invalid opcode 0xfe") + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo() diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 263f10a89c..6db8570fc7 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -33,7 +33,7 @@ def _make_invalid_dynarray_tx(w3, address, signature, data): w3.eth.send_transaction({"to": address, "data": f"0x{sig}{data}"}) -def test_bytes_clamper(assert_tx_failed, get_contract_with_gas_estimation): +def test_bytes_clamper(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ @external def foo(s: Bytes[3]) -> Bytes[3]: @@ -43,10 +43,11 @@ def foo(s: Bytes[3]) -> Bytes[3]: c = get_contract_with_gas_estimation(clamper_test_code) assert c.foo(b"ca") == b"ca" assert c.foo(b"cat") == b"cat" - assert_tx_failed(lambda: c.foo(b"cate")) + with tx_failed(): + c.foo(b"cate") -def test_bytes_clamper_multiple_slots(assert_tx_failed, get_contract_with_gas_estimation): +def test_bytes_clamper_multiple_slots(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ @external def foo(s: Bytes[40]) -> Bytes[40]: @@ -58,10 +59,11 @@ def foo(s: Bytes[40]) -> Bytes[40]: assert c.foo(data[:30]) == data[:30] assert c.foo(data) == data - assert_tx_failed(lambda: c.foo(data + b"!")) + with tx_failed(): + c.foo(data + b"!") -def test_bytes_clamper_on_init(assert_tx_failed, get_contract_with_gas_estimation): +def test_bytes_clamper_on_init(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ foo: Bytes[3] @@ -77,7 +79,8 @@ def get_foo() -> Bytes[3]: c = get_contract_with_gas_estimation(clamper_test_code, *[b"cat"]) assert c.get_foo() == b"cat" - assert_tx_failed(lambda: get_contract_with_gas_estimation(clamper_test_code, *[b"cats"])) + with tx_failed(): + get_contract_with_gas_estimation(clamper_test_code, *[b"cats"]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -99,7 +102,7 @@ def foo(s: bytes{n}) -> bytes{n}: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("n", list(range(1, 32))) # bytes32 always passes -def test_bytes_m_clamper_failing(w3, get_contract, assert_tx_failed, n, evm_version): +def test_bytes_m_clamper_failing(w3, get_contract, tx_failed, n, evm_version): values = [] values.append(b"\x00" * n + b"\x80") # just one bit set values.append(b"\xff" * n + b"\x80") # n*8 + 1 bits set @@ -118,11 +121,9 @@ def foo(s: bytes{n}) -> bytes{n}: c = get_contract(code, evm_version=evm_version) for v in values: # munge for `_make_tx` - assert_tx_failed( - lambda val=int.from_bytes(v, byteorder="big"): _make_tx( - w3, c.address, f"foo(bytes{n})", [val] - ) - ) + with tx_failed(): + int_value = int.from_bytes(v, byteorder="big") + _make_tx(w3, c.address, f"foo(bytes{n})", [int_value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -144,7 +145,7 @@ def foo(s: int{bits}) -> int{bits}: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("n", list(range(31))) # int256 does not clamp -def test_sint_clamper_failing(w3, assert_tx_failed, get_contract, n, evm_version): +def test_sint_clamper_failing(w3, tx_failed, get_contract, n, evm_version): bits = 8 * (n + 1) lo, hi = int_bounds(True, bits) values = [-(2**255), 2**255 - 1, lo - 1, hi + 1] @@ -156,7 +157,8 @@ def foo(s: int{bits}) -> int{bits}: c = get_contract(code, evm_version=evm_version) for v in values: - assert_tx_failed(lambda val=v: _make_tx(w3, c.address, f"foo(int{bits})", [val])) + with tx_failed(): + _make_tx(w3, c.address, f"foo(int{bits})", [v]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -174,7 +176,7 @@ def foo(s: bool) -> bool: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2, 3, 4, 8, 16, 2**256 - 1]) -def test_bool_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_bool_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ @external def foo(s: bool) -> bool: @@ -182,7 +184,8 @@ def foo(s: bool) -> bool: """ c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(bool)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(bool)", [value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -207,7 +210,7 @@ def foo(s: Roles) -> Roles: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2**i for i in range(5, 256)]) -def test_flag_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_flag_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ flag Roles: USER @@ -222,7 +225,8 @@ def foo(s: Roles) -> Roles: """ c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(uint256)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(uint256)", [value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -243,7 +247,7 @@ def foo(s: uint{bits}) -> uint{bits}: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("n", list(range(31))) # uint256 has no failing cases -def test_uint_clamper_failing(w3, assert_tx_failed, get_contract, evm_version, n): +def test_uint_clamper_failing(w3, tx_failed, get_contract, evm_version, n): bits = 8 * (n + 1) values = [-1, -(2**255), 2**bits] code = f""" @@ -253,7 +257,8 @@ def foo(s: uint{bits}) -> uint{bits}: """ c = get_contract(code, evm_version=evm_version) for v in values: - assert_tx_failed(lambda val=v: _make_tx(w3, c.address, f"foo(uint{bits})", [val])) + with tx_failed(): + _make_tx(w3, c.address, f"foo(uint{bits})", [v]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -284,7 +289,7 @@ def foo(s: address) -> address: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2**160, 2**256 - 1]) -def test_address_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_address_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ @external def foo(s: address) -> address: @@ -292,7 +297,8 @@ def foo(s: address) -> address: """ c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(address)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(address)", [value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -337,7 +343,7 @@ def foo(s: decimal) -> decimal: -187072209578355573530071658587684226515959365500929, # - (2 ** 127 - 1e-10) ], ) -def test_decimal_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_decimal_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ @external def foo(s: decimal) -> decimal: @@ -346,7 +352,8 @@ def foo(s: decimal) -> decimal: c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(fixed168x10)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(fixed168x10)", [value]) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -366,7 +373,7 @@ def foo(a: uint256, b: int128[5], c: uint256) -> int128[5]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(5)) -def test_int128_array_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_array_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): # ensure the invalid value is detected at all locations in the array code = """ @external @@ -378,7 +385,8 @@ def foo(b: int128[5]) -> int128[5]: values[idx] = bad_value c = get_contract(code) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[5])", values)) + with tx_failed(): + _make_tx(w3, c.address, "foo(int128[5])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -397,7 +405,7 @@ def foo(a: uint256, b: int128[10], c: uint256) -> int128[10]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(10)) -def test_int128_array_looped_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_array_looped_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: int128[10]) -> int128[10]: @@ -408,7 +416,8 @@ def foo(b: int128[10]) -> int128[10]: values[idx] = bad_value c = get_contract(code) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[10])", values)) + with tx_failed(): + _make_tx(w3, c.address, "foo(int128[10])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -427,7 +436,7 @@ def foo(a: uint256, b: int128[6][3][1][8], c: uint256) -> int128[6][3][1][8]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(12)) -def test_multidimension_array_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_multidimension_array_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: int128[6][1][2]) -> int128[6][1][2]: @@ -438,7 +447,8 @@ def foo(b: int128[6][1][2]) -> int128[6][1][2]: values[idx] = bad_value c = get_contract(code) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[6][1][2]])", values)) + with tx_failed(): + _make_tx(w3, c.address, "foo(int128[6][1][2]])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -458,7 +468,7 @@ def foo(a: uint256, b: DynArray[int128, 5], c: uint256) -> DynArray[int128, 5]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(5)) -def test_int128_dynarray_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_dynarray_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): # ensure the invalid value is detected at all locations in the array code = """ @external @@ -473,7 +483,8 @@ def foo(b: int128[5]) -> int128[5]: c = get_contract(code) data = _make_dynarray_data(32, 5, values) - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -491,7 +502,7 @@ def foo(a: uint256, b: DynArray[int128, 10], c: uint256) -> DynArray[int128, 10] @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(10)) -def test_int128_dynarray_looped_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_dynarray_looped_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: @@ -505,7 +516,8 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: data = _make_dynarray_data(32, 10, values) signature = "foo(int128[])" - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -527,9 +539,7 @@ def foo( @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(4)) -def test_multidimension_dynarray_clamper_failing( - w3, assert_tx_failed, get_contract, bad_value, idx -): +def test_multidimension_dynarray_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: DynArray[DynArray[int128, 2], 2]) -> DynArray[DynArray[int128, 2], 2]: @@ -549,7 +559,8 @@ def foo(b: DynArray[DynArray[int128, 2], 2]) -> DynArray[DynArray[int128, 2], 2] signature = "foo(int128[][])" c = get_contract(code) - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -570,7 +581,7 @@ def foo( @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(10)) -def test_dynarray_list_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_dynarray_list_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): # ensure the invalid value is detected at all locations in the array code = """ @external @@ -588,4 +599,5 @@ def foo(b: DynArray[int128[5], 2]) -> DynArray[int128[5], 2]: c = get_contract(code) signature = "foo(int128[5][])" - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index 29a466e869..fc765f8ab3 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -24,7 +24,7 @@ def __init__(a: uint256): assert "CALLDATALOAD" not in assembly[:ir_return_idx_start] + assembly[ir_return_idx_end:] -def test_init_calls_internal(get_contract, assert_compile_failed, assert_tx_failed): +def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): code = """ foo: public(uint8) @internal @@ -46,7 +46,8 @@ def baz() -> uint8: n = 6 c = get_contract(code, n) assert c.foo() == n * 7 - assert_tx_failed(lambda: c.baz()) + with tx_failed(): + c.baz() n = 255 assert_compile_failed(lambda: get_contract(code, n)) diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 84311c41f5..ba09be1991 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -3,6 +3,7 @@ import pytest from eth.codecs import abi +from vyper import compile_code from vyper.exceptions import ( ArgumentException, EventDeclarationException, @@ -193,7 +194,7 @@ def bar(): def test_event_logging_cannot_have_more_than_three_topics( - assert_tx_failed, get_contract_with_gas_estimation + tx_failed, get_contract_with_gas_estimation ): loggy_code = """ event MyLog: @@ -203,9 +204,8 @@ def test_event_logging_cannot_have_more_than_three_topics( arg4: indexed(int128) """ - assert_tx_failed( - lambda: get_contract_with_gas_estimation(loggy_code), EventDeclarationException - ) + with pytest.raises(EventDeclarationException): + compile_code(loggy_code) def test_event_logging_with_data(w3, tester, keccak, get_logs, get_contract_with_gas_estimation): @@ -555,7 +555,7 @@ def foo(): assert args.arg2 == {"x": 1, "y": b"abc", "z": {"t": "house", "w": Decimal("13.5")}} -def test_fails_when_input_is_the_wrong_type(assert_tx_failed, get_contract_with_gas_estimation): +def test_fails_when_input_is_the_wrong_type(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(int128) @@ -565,10 +565,11 @@ def foo_(): log MyLog(b'yo') """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), InvalidType) + with tx_failed(InvalidType): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_topic_is_the_wrong_size(assert_tx_failed, get_contract_with_gas_estimation): +def test_fails_when_topic_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(Bytes[3]) @@ -579,12 +580,11 @@ def foo(): log MyLog(b'bars') """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), InvalidType) + with tx_failed(InvalidType): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_input_topic_is_the_wrong_size( - assert_tx_failed, get_contract_with_gas_estimation -): +def test_fails_when_input_topic_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(Bytes[3]) @@ -594,10 +594,11 @@ def foo(arg1: Bytes[4]): log MyLog(arg1) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_data_is_the_wrong_size(assert_tx_failed, get_contract_with_gas_estimation): +def test_fails_when_data_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: Bytes[3] @@ -607,12 +608,11 @@ def foo(): log MyLog(b'bars') """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), InvalidType) + with tx_failed(InvalidType): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_input_data_is_the_wrong_size( - assert_tx_failed, get_contract_with_gas_estimation -): +def test_fails_when_input_data_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: Bytes[3] @@ -622,7 +622,8 @@ def foo(arg1: Bytes[4]): log MyLog(arg1) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) def test_topic_over_32_bytes(get_contract_with_gas_estimation): @@ -637,7 +638,7 @@ def foo(): get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_over_three_topics(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_over_three_topics(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(int128) @@ -650,12 +651,11 @@ def __init__(): log MyLog(1, 2, 3, 4) """ - assert_tx_failed( - lambda: get_contract_with_gas_estimation(loggy_code), EventDeclarationException - ) + with tx_failed(EventDeclarationException): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_duplicate_log_names(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_duplicate_log_names(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: pass event MyLog: pass @@ -665,12 +665,11 @@ def foo(): log MyLog() """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), NamespaceCollision) + with tx_failed(NamespaceCollision): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_when_log_is_undeclared( - assert_tx_failed, get_contract_with_gas_estimation -): +def test_logging_fails_with_when_log_is_undeclared(tx_failed, get_contract_with_gas_estimation): loggy_code = """ @external @@ -678,10 +677,11 @@ def foo(): log MyLog() """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), UndeclaredDefinition) + with tx_failed(UndeclaredDefinition): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_topic_type_mismatch(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_topic_type_mismatch(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(int128) @@ -691,10 +691,11 @@ def foo(): log MyLog(self) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_data_type_mismatch(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_data_type_mismatch(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: Bytes[3] @@ -704,11 +705,12 @@ def foo(): log MyLog(self) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) def test_logging_fails_when_number_of_arguments_is_greater_than_declaration( - assert_tx_failed, get_contract_with_gas_estimation + tx_failed, get_contract_with_gas_estimation ): loggy_code = """ event MyLog: @@ -718,11 +720,12 @@ def test_logging_fails_when_number_of_arguments_is_greater_than_declaration( def foo(): log MyLog(1, 2) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), ArgumentException) + with tx_failed(ArgumentException): + get_contract_with_gas_estimation(loggy_code) def test_logging_fails_when_number_of_arguments_is_less_than_declaration( - assert_tx_failed, get_contract_with_gas_estimation + tx_failed, get_contract_with_gas_estimation ): loggy_code = """ event MyLog: @@ -733,7 +736,8 @@ def test_logging_fails_when_number_of_arguments_is_less_than_declaration( def foo(): log MyLog(1) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), ArgumentException) + with tx_failed(ArgumentException): + get_contract_with_gas_estimation(loggy_code) def test_loggy_code(w3, tester, get_contract_with_gas_estimation): @@ -962,7 +966,7 @@ def set_list(): ] -def test_logging_fails_when_input_is_too_big(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_when_input_is_too_big(tx_failed, get_contract_with_gas_estimation): code = """ event Bar: _value: indexed(Bytes[32]) @@ -971,7 +975,8 @@ def test_logging_fails_when_input_is_too_big(assert_tx_failed, get_contract_with def foo(inp: Bytes[33]): log Bar(inp) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(code) def test_2nd_var_list_packing(get_logs, get_contract_with_gas_estimation): diff --git a/tests/functional/codegen/features/test_reverting.py b/tests/functional/codegen/features/test_reverting.py index 2cdc727015..f24886ce96 100644 --- a/tests/functional/codegen/features/test_reverting.py +++ b/tests/functional/codegen/features/test_reverting.py @@ -7,7 +7,7 @@ pytestmark = pytest.mark.usefixtures("memory_mocker") -def test_revert_reason(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_revert_reason(w3, tx_failed, get_contract_with_gas_estimation): reverty_code = """ @external def foo(): @@ -17,14 +17,11 @@ def foo(): revert_bytes = method_id("NoFives()") - assert_tx_failed( - lambda: get_contract_with_gas_estimation(reverty_code).foo(transact={}), - TransactionFailed, - exc_text=f"execution reverted: {revert_bytes}", - ) + with tx_failed(TransactionFailed, exc_text=f"execution reverted: {revert_bytes}"): + get_contract_with_gas_estimation(reverty_code).foo(transact={}) -def test_revert_reason_typed(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_revert_reason_typed(w3, tx_failed, get_contract_with_gas_estimation): reverty_code = """ @external def foo(): @@ -35,14 +32,11 @@ def foo(): revert_bytes = method_id("NoFives(uint256)") + abi.encode("(uint256)", (5,)) - assert_tx_failed( - lambda: get_contract_with_gas_estimation(reverty_code).foo(transact={}), - TransactionFailed, - exc_text=f"execution reverted: {revert_bytes}", - ) + with tx_failed(TransactionFailed, exc_text=f"execution reverted: {revert_bytes}"): + get_contract_with_gas_estimation(reverty_code).foo(transact={}) -def test_revert_reason_typed_no_variable(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_revert_reason_typed_no_variable(w3, tx_failed, get_contract_with_gas_estimation): reverty_code = """ @external def foo(): @@ -52,8 +46,5 @@ def foo(): revert_bytes = method_id("NoFives(uint256)") + abi.encode("(uint256)", (5,)) - assert_tx_failed( - lambda: get_contract_with_gas_estimation(reverty_code).foo(transact={}), - TransactionFailed, - exc_text=f"execution reverted: {revert_bytes}", - ) + with tx_failed(TransactionFailed, exc_text=f"execution reverted: {revert_bytes}"): + get_contract_with_gas_estimation(reverty_code).foo(transact={}) diff --git a/tests/functional/codegen/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py index 1578f5a418..70e7cb4594 100644 --- a/tests/functional/codegen/integration/test_escrow.py +++ b/tests/functional/codegen/integration/test_escrow.py @@ -1,7 +1,7 @@ # from ethereum.tools import tester -def test_arbitration_code(w3, get_contract_with_gas_estimation, assert_tx_failed): +def test_arbitration_code(w3, get_contract_with_gas_estimation, tx_failed): arbitration_code = """ buyer: address seller: address @@ -28,13 +28,14 @@ def refund(): a0, a1, a2 = w3.eth.accounts[:3] c = get_contract_with_gas_estimation(arbitration_code, value=1) c.setup(a1, a2, transact={}) - assert_tx_failed(lambda: c.finalize(transact={"from": a1})) + with tx_failed(): + c.finalize(transact={"from": a1}) c.finalize(transact={}) print("Passed escrow test") -def test_arbitration_code_with_init(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_arbitration_code_with_init(w3, tx_failed, get_contract_with_gas_estimation): arbitration_code_with_init = """ buyer: address seller: address @@ -60,7 +61,8 @@ def refund(): """ a0, a1, a2 = w3.eth.accounts[:3] c = get_contract_with_gas_estimation(arbitration_code_with_init, *[a1, a2], value=1) - assert_tx_failed(lambda: c.finalize(transact={"from": a1})) + with tx_failed(): + c.finalize(transact={"from": a1}) c.finalize(transact={"from": a0}) print("Passed escrow test with initializer") diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 3544f4a965..65d2df9038 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -427,7 +427,7 @@ def test(addr: address): # test data returned from external interface gets clamped @pytest.mark.parametrize("typ", ("int128", "uint8")) -def test_external_interface_int_clampers(get_contract, assert_tx_failed, typ): +def test_external_interface_int_clampers(get_contract, tx_failed, typ): external_contract = f""" @external def ok() -> {typ}: @@ -474,13 +474,16 @@ def test_fail3() -> int256: assert bad_c.should_fail() == -(2**255) assert c.test_ok() == 1 - assert_tx_failed(lambda: c.test_fail()) - assert_tx_failed(lambda: c.test_fail2()) - assert_tx_failed(lambda: c.test_fail3()) + with tx_failed(): + c.test_fail() + with tx_failed(): + c.test_fail2() + with tx_failed(): + c.test_fail3() # test data returned from external interface gets clamped -def test_external_interface_bytes_clampers(get_contract, assert_tx_failed): +def test_external_interface_bytes_clampers(get_contract, tx_failed): external_contract = """ @external def ok() -> Bytes[2]: @@ -522,14 +525,14 @@ def test_fail2() -> Bytes[3]: assert bad_c.should_fail() == b"123" assert c.test_ok() == b"12" - assert_tx_failed(lambda: c.test_fail1()) - assert_tx_failed(lambda: c.test_fail2()) + with tx_failed(): + c.test_fail1() + with tx_failed(): + c.test_fail2() # test data returned from external interface gets clamped -def test_json_abi_bytes_clampers( - get_contract, assert_tx_failed, assert_compile_failed, make_input_bundle -): +def test_json_abi_bytes_clampers(get_contract, tx_failed, assert_compile_failed, make_input_bundle): external_contract = """ @external def returns_Bytes3() -> Bytes[3]: @@ -584,9 +587,12 @@ def test_fail3() -> Bytes[3]: c = get_contract(code, bad_c.address, input_bundle=input_bundle) assert bad_c.returns_Bytes3() == b"123" - assert_tx_failed(lambda: c.test_fail1()) - assert_tx_failed(lambda: c.test_fail2()) - assert_tx_failed(lambda: c.test_fail3()) + with tx_failed(): + c.test_fail1() + with tx_failed(): + c.test_fail2() + with tx_failed(): + c.test_fail3() def test_units_interface(w3, get_contract, make_input_bundle): diff --git a/tests/functional/codegen/test_selector_table.py b/tests/functional/codegen/test_selector_table.py index abea81ced4..94233977c9 100644 --- a/tests/functional/codegen/test_selector_table.py +++ b/tests/functional/codegen/test_selector_table.py @@ -512,9 +512,7 @@ def generate_methods(draw, max_calldata_bytes): # dense selector table packing boundaries at 256 and 65336 @pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336]) @pytest.mark.fuzzing -def test_selector_table_fuzz( - max_calldata_bytes, opt_level, w3, get_contract, assert_tx_failed, get_logs -): +def test_selector_table_fuzz(max_calldata_bytes, opt_level, w3, get_contract, tx_failed, get_logs): def abi_sig(func_id, calldata_words, n_default_args): params = [] if not calldata_words else [f"uint256[{calldata_words}]"] params.extend(["uint256"] * n_default_args) @@ -600,7 +598,8 @@ def __default__(): else: hexstr = (method_id + argsdata).hex() txdata = {"to": c.address, "data": hexstr, "value": 1} - assert_tx_failed(lambda d=txdata: w3.eth.send_transaction(d)) + with tx_failed(): + w3.eth.send_transaction(txdata) # now do calldatasize check # strip some bytes @@ -610,7 +609,8 @@ def __default__(): if n_calldata_words == 0 and j == 0: # no args, hit default function if default_fn_mutability == "": - assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) + with tx_failed(): + w3.eth.send_transaction(tx_params) elif default_fn_mutability == "@payable": # we should be able to send eth to it tx_params["value"] = 1 @@ -628,8 +628,10 @@ def __default__(): # check default function reverts tx_params["value"] = 1 - assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) + with tx_failed(): + w3.eth.send_transaction(tx_params) else: - assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) + with tx_failed(): + w3.eth.send_transaction(tx_params) _test() diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/test_stateless_modules.py index 8e634e5868..2abc164689 100644 --- a/tests/functional/codegen/test_stateless_modules.py +++ b/tests/functional/codegen/test_stateless_modules.py @@ -186,7 +186,7 @@ def qux() -> library.SomeStruct: # test calls to library functions in statement position -def test_library_statement_calls(get_contract, make_input_bundle, assert_tx_failed): +def test_library_statement_calls(get_contract, make_input_bundle, tx_failed): library_source = """ from vyper.interfaces import ERC20 @internal @@ -211,7 +211,8 @@ def foo(x: uint256): assert c.counter() == 7 - assert_tx_failed(lambda: c.foo(8)) + with tx_failed(): + c.foo(8) def test_library_is_typechecked(make_input_bundle): diff --git a/tests/functional/codegen/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py index 25617651ec..8244bc5487 100644 --- a/tests/functional/codegen/types/numbers/test_constants.py +++ b/tests/functional/codegen/types/numbers/test_constants.py @@ -8,6 +8,13 @@ from vyper.utils import MemoryPositions +def search_for_sublist(ir, sublist): + _list = ir.to_list() if hasattr(ir, "to_list") else ir + if _list == sublist: + return True + return isinstance(_list, list) and any(search_for_sublist(i, sublist) for i in _list) + + def test_builtin_constants(get_contract_with_gas_estimation): code = """ @external @@ -192,7 +199,7 @@ def test() -> Bytes[100]: assert c.test() == test_str -def test_constant_folds(search_for_sublist): +def test_constant_folds(): some_prime = 10013677 code = f""" SOME_CONSTANT: constant(uint256) = 11 + 1 @@ -205,11 +212,9 @@ def test() -> uint256: ret: uint256 = 2**SOME_CONSTANT * SOME_PRIME return ret """ - ir = compile_code(code, output_formats=["ir"])["ir"] - assert search_for_sublist( - ir, ["mstore", [MemoryPositions.RESERVED_MEMORY], [2**12 * some_prime]] - ) + search = ["mstore", [MemoryPositions.RESERVED_MEMORY], [2**12 * some_prime]] + assert search_for_sublist(ir, search) def test_constant_lists(get_contract): diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 1418eab063..25dc1f1a1e 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -156,7 +156,7 @@ def iarg() -> uint256: print("Passed fractional multiplication test") -def test_mul_overflow(assert_tx_failed, get_contract_with_gas_estimation): +def test_mul_overflow(tx_failed, get_contract_with_gas_estimation): mul_code = """ @external @@ -170,12 +170,14 @@ def _num_mul(x: decimal, y: decimal) -> decimal: x = Decimal("85070591730234615865843651857942052864") y = Decimal("136112946768375385385349842973") - assert_tx_failed(lambda: c._num_mul(x, y)) + with tx_failed(): + c._num_mul(x, y) x = SizeLimits.MAX_AST_DECIMAL y = 1 + DECIMAL_EPSILON - assert_tx_failed(lambda: c._num_mul(x, y)) + with tx_failed(): + c._num_mul(x, y) assert c._num_mul(x, Decimal(1)) == x @@ -186,7 +188,7 @@ def _num_mul(x: decimal, y: decimal) -> decimal: # division failure modes(!) -def test_div_overflow(get_contract, assert_tx_failed): +def test_div_overflow(get_contract, tx_failed): code = """ @external def foo(x: decimal, y: decimal) -> decimal: @@ -198,32 +200,39 @@ def foo(x: decimal, y: decimal) -> decimal: x = SizeLimits.MIN_AST_DECIMAL y = -DECIMAL_EPSILON - assert_tx_failed(lambda: c.foo(x, y)) - assert_tx_failed(lambda: c.foo(x, Decimal(0))) - assert_tx_failed(lambda: c.foo(y, Decimal(0))) + with tx_failed(): + c.foo(x, y) + with tx_failed(): + c.foo(x, Decimal(0)) + with tx_failed(): + c.foo(y, Decimal(0)) y = Decimal(1) - DECIMAL_EPSILON # 0.999999999 - assert_tx_failed(lambda: c.foo(x, y)) + with tx_failed(): + c.foo(x, y) y = Decimal(-1) - assert_tx_failed(lambda: c.foo(x, y)) + with tx_failed(): + c.foo(x, y) assert c.foo(x, Decimal(1)) == x assert c.foo(x, 1 + DECIMAL_EPSILON) == quantize(x / (1 + DECIMAL_EPSILON)) x = SizeLimits.MAX_AST_DECIMAL - assert_tx_failed(lambda: c.foo(x, DECIMAL_EPSILON)) + with tx_failed(): + c.foo(x, DECIMAL_EPSILON) y = Decimal(1) - DECIMAL_EPSILON - assert_tx_failed(lambda: c.foo(x, y)) + with tx_failed(): + c.foo(x, y) assert c.foo(x, Decimal(1)) == x assert c.foo(x, 1 + DECIMAL_EPSILON) == quantize(x / (1 + DECIMAL_EPSILON)) -def test_decimal_min_max_literals(assert_tx_failed, get_contract_with_gas_estimation): +def test_decimal_min_max_literals(tx_failed, get_contract_with_gas_estimation): code = """ @external def maximum(): diff --git a/tests/functional/codegen/types/numbers/test_exponents.py b/tests/functional/codegen/types/numbers/test_exponents.py index 5726e4c1ca..e958436efb 100644 --- a/tests/functional/codegen/types/numbers/test_exponents.py +++ b/tests/functional/codegen/types/numbers/test_exponents.py @@ -7,7 +7,7 @@ @pytest.mark.fuzzing @pytest.mark.parametrize("power", range(2, 255)) -def test_exp_uint256(get_contract, assert_tx_failed, power): +def test_exp_uint256(get_contract, tx_failed, power): code = f""" @external def foo(a: uint256) -> uint256: @@ -20,12 +20,13 @@ def foo(a: uint256) -> uint256: c = get_contract(code) c.foo(max_base) - assert_tx_failed(lambda: c.foo(max_base + 1)) + with tx_failed(): + c.foo(max_base + 1) @pytest.mark.fuzzing @pytest.mark.parametrize("power", range(2, 127)) -def test_exp_int128(get_contract, assert_tx_failed, power): +def test_exp_int128(get_contract, tx_failed, power): code = f""" @external def foo(a: int128) -> int128: @@ -44,13 +45,15 @@ def foo(a: int128) -> int128: c.foo(max_base) c.foo(min_base) - assert_tx_failed(lambda: c.foo(max_base + 1)) - assert_tx_failed(lambda: c.foo(min_base - 1)) + with tx_failed(): + c.foo(max_base + 1) + with tx_failed(): + c.foo(min_base - 1) @pytest.mark.fuzzing @pytest.mark.parametrize("power", range(2, 15)) -def test_exp_int16(get_contract, assert_tx_failed, power): +def test_exp_int16(get_contract, tx_failed, power): code = f""" @external def foo(a: int16) -> int16: @@ -69,8 +72,10 @@ def foo(a: int16) -> int16: c.foo(max_base) c.foo(min_base) - assert_tx_failed(lambda: c.foo(max_base + 1)) - assert_tx_failed(lambda: c.foo(min_base - 1)) + with tx_failed(): + c.foo(max_base + 1) + with tx_failed(): + c.foo(min_base - 1) @pytest.mark.fuzzing @@ -93,7 +98,7 @@ def foo(a: int16) -> int16: # 256 bits @example(a=2**256 - 1) @settings(max_examples=200) -def test_max_exp(get_contract, assert_tx_failed, a): +def test_max_exp(get_contract, tx_failed, a): code = f""" @external def foo(b: uint256) -> uint256: @@ -108,7 +113,8 @@ def foo(b: uint256) -> uint256: assert a ** (max_power + 1) >= 2**256 c.foo(max_power) - assert_tx_failed(lambda: c.foo(max_power + 1)) + with tx_failed(): + c.foo(max_power + 1) @pytest.mark.fuzzing @@ -128,7 +134,7 @@ def foo(b: uint256) -> uint256: # 128 bits @example(a=2**127 - 1) @settings(max_examples=200) -def test_max_exp_int128(get_contract, assert_tx_failed, a): +def test_max_exp_int128(get_contract, tx_failed, a): code = f""" @external def foo(b: int128) -> int128: @@ -143,4 +149,5 @@ def foo(b: int128) -> int128: assert not -(2**127) <= a ** (max_power + 1) < 2**127 c.foo(max_power) - assert_tx_failed(lambda: c.foo(max_power + 1)) + with tx_failed(): + c.foo(max_power + 1) diff --git a/tests/functional/codegen/types/numbers/test_modulo.py b/tests/functional/codegen/types/numbers/test_modulo.py index 018a406baa..465426cd1d 100644 --- a/tests/functional/codegen/types/numbers/test_modulo.py +++ b/tests/functional/codegen/types/numbers/test_modulo.py @@ -31,14 +31,15 @@ def num_modulo_decimal() -> decimal: assert c.num_modulo_decimal() == Decimal(".5") -def test_modulo_with_input_of_zero(assert_tx_failed, get_contract_with_gas_estimation): +def test_modulo_with_input_of_zero(tx_failed, get_contract_with_gas_estimation): code = """ @external def foo(a: decimal, b: decimal) -> decimal: return a % b """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.foo(Decimal("1"), Decimal("0"))) + with tx_failed(): + c.foo(Decimal("1"), Decimal("0")) def test_literals_vs_evm(get_contract): diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 3e44beb826..52de5b649f 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("typ", types) -def test_exponent_base_zero(get_contract, assert_tx_failed, typ): +def test_exponent_base_zero(get_contract, tx_failed, typ): code = f""" @external def foo(x: {typ}) -> {typ}: @@ -25,12 +25,14 @@ def foo(x: {typ}) -> {typ}: assert c.foo(1) == 0 assert c.foo(hi) == 0 - assert_tx_failed(lambda: c.foo(-1)) - assert_tx_failed(lambda: c.foo(lo)) # note: lo < 0 + with tx_failed(): + c.foo(-1) + with tx_failed(): + c.foo(lo) # note: lo < 0 @pytest.mark.parametrize("typ", types) -def test_exponent_base_one(get_contract, assert_tx_failed, typ): +def test_exponent_base_one(get_contract, tx_failed, typ): code = f""" @external def foo(x: {typ}) -> {typ}: @@ -43,8 +45,10 @@ def foo(x: {typ}) -> {typ}: assert c.foo(1) == 1 assert c.foo(hi) == 1 - assert_tx_failed(lambda: c.foo(-1)) - assert_tx_failed(lambda: c.foo(lo)) + with tx_failed(): + c.foo(-1) + with tx_failed(): + c.foo(lo) def test_exponent_base_minus_one(get_contract): @@ -63,7 +67,7 @@ def foo(x: int256) -> int256: # TODO: make this test pass @pytest.mark.parametrize("base", (0, 1)) -def test_exponent_negative_power(get_contract, assert_tx_failed, base): +def test_exponent_negative_power(get_contract, tx_failed, base): # #2985 code = f""" @external @@ -73,7 +77,8 @@ def bar() -> int16: """ c = get_contract(code) # known bug: 2985 - assert_tx_failed(lambda: c.bar()) + with tx_failed(): + c.bar() def test_exponent_min_int16(get_contract): @@ -103,7 +108,7 @@ def foo() -> int256: @pytest.mark.parametrize("typ", types) -def test_exponent(get_contract, assert_tx_failed, typ): +def test_exponent(get_contract, tx_failed, typ): code = f""" @external def foo(x: {typ}) -> {typ}: @@ -116,7 +121,8 @@ def foo(x: {typ}) -> {typ}: test_cases = [0, 1, 3, 4, 126, 127, -1, lo, hi] for x in test_cases: if x * 2 >= typ.bits or x < 0: # out of bounds - assert_tx_failed(lambda p=x: c.foo(p)) + with tx_failed(): + c.foo(x) else: assert c.foo(x) == 4**x @@ -145,7 +151,7 @@ def negative_four() -> {typ}: @pytest.mark.parametrize("typ", types) -def test_num_bound(assert_tx_failed, get_contract_with_gas_estimation, typ): +def test_num_bound(tx_failed, get_contract_with_gas_estimation, typ): lo, hi = typ.ast_bounds num_bound_code = f""" @@ -180,16 +186,22 @@ def _num_min() -> {typ}: assert c._num_sub(lo, 0) == lo assert c._num_add(hi - 1, 1) == hi assert c._num_sub(lo + 1, 1) == lo - assert_tx_failed(lambda: c._num_add(hi, 1)) - assert_tx_failed(lambda: c._num_sub(lo, 1)) - assert_tx_failed(lambda: c._num_add(hi - 1, 2)) - assert_tx_failed(lambda: c._num_sub(lo + 1, 2)) + with tx_failed(): + c._num_add(hi, 1) + with tx_failed(): + c._num_sub(lo, 1) + with tx_failed(): + c._num_add(hi - 1, 2) + with tx_failed(): + c._num_sub(lo + 1, 2) assert c._num_max() == hi assert c._num_min() == lo - assert_tx_failed(lambda: c._num_add3(hi, 1, -1)) + with tx_failed(): + c._num_add3(hi, 1, -1) assert c._num_add3(hi, -1, 1) == hi - 1 + 1 - assert_tx_failed(lambda: c._num_add3(lo, -1, 1)) + with tx_failed(): + c._num_add3(lo, -1, 1) assert c._num_add3(lo, 1, -1) == lo + 1 - 1 @@ -219,7 +231,7 @@ def num_sub() -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, assert_tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): # both variables code_1 = f""" @external @@ -304,14 +316,19 @@ def foo() -> {typ}: assert get_contract(code_3).foo(y) == expected assert get_contract(code_4).foo() == expected elif div_by_zero: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) + with tx_failed(): + c.foo(x, y) assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + with tx_failed(): + get_contract(code_3).foo(y) assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) else: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) - assert_tx_failed(lambda p=x, code=code_2: get_contract(code).foo(p)) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + with tx_failed(): + c.foo(x, y) + with tx_failed(): + get_contract(code_2).foo(x) + with tx_failed(): + get_contract(code_3).foo(y) assert_compile_failed( lambda code=code_4: get_contract(code), (InvalidType, OverflowException) ) @@ -372,7 +389,7 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_negation(get_contract, assert_tx_failed, typ): +def test_negation(get_contract, tx_failed, typ): code = f""" @external def foo(a: {typ}) -> {typ}: @@ -390,7 +407,8 @@ def foo(a: {typ}) -> {typ}: assert c.foo(2) == -2 assert c.foo(-2) == 2 - assert_tx_failed(lambda: c.foo(lo)) + with tx_failed(): + c.foo(lo) @pytest.mark.parametrize("typ", types) diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 6c8d114f29..8982065b5d 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -85,7 +85,7 @@ def foo(x: {typ}) -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, assert_tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): # both variables code_1 = f""" @external @@ -148,17 +148,23 @@ def foo() -> {typ}: assert get_contract(code_3).foo(y) == expected assert get_contract(code_4).foo() == expected elif div_by_zero: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) - assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) - assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) + with tx_failed(): + c.foo(x, y) + with pytest.raises(ZeroDivisionException): + get_contract(code_2) + with tx_failed(): + get_contract(code_3).foo(y) + with pytest.raises(ZeroDivisionException): + get_contract(code_4) else: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) - assert_tx_failed(lambda code=code_2, p=x: get_contract(code).foo(p)) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) - assert_compile_failed( - lambda code=code_4: get_contract(code), (InvalidType, OverflowException) - ) + with tx_failed(): + c.foo(x, y) + with tx_failed(): + get_contract(code_2).foo(x) + with tx_failed(): + get_contract(code_3).foo(y) + with pytest.raises((InvalidType, OverflowException)): + get_contract(code_4) COMPARISON_OPS = { diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 01ec75d5c1..1ee9b8d835 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -3,7 +3,7 @@ from vyper.exceptions import InvalidType, TypeMismatch -def test_test_bytes(get_contract_with_gas_estimation, assert_tx_failed): +def test_test_bytes(get_contract_with_gas_estimation, tx_failed): test_bytes = """ @external def foo(x: Bytes[100]) -> Bytes[100]: @@ -21,7 +21,8 @@ def foo(x: Bytes[100]) -> Bytes[100]: print("Passed max-length bytes test") # test for greater than 100 bytes, should raise exception - assert_tx_failed(lambda: c.foo(b"\x35" * 101)) + with tx_failed(): + c.foo(b"\x35" * 101) print("Passed input-too-long test") diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index d793a56d6e..4ef6874ae9 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -759,27 +759,30 @@ def test_multi4_2() -> DynArray[DynArray[DynArray[DynArray[uint256, 2], 2], 2], assert c.test_multi4_2() == nest4 -def test_uint256_accessor(get_contract_with_gas_estimation, assert_tx_failed): +def test_uint256_accessor(get_contract_with_gas_estimation, tx_failed): code = """ @external def bounds_check_uint256(xs: DynArray[uint256, 3], ix: uint256) -> uint256: return xs[ix] """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.bounds_check_uint256([], 0)) + with tx_failed(): + c.bounds_check_uint256([], 0) assert c.bounds_check_uint256([1], 0) == 1 - assert_tx_failed(lambda: c.bounds_check_uint256([1], 1)) + with tx_failed(): + c.bounds_check_uint256([1], 1) assert c.bounds_check_uint256([1, 2, 3], 0) == 1 assert c.bounds_check_uint256([1, 2, 3], 2) == 3 - assert_tx_failed(lambda: c.bounds_check_uint256([1, 2, 3], 3)) + with tx_failed(): + c.bounds_check_uint256([1, 2, 3], 3) # TODO do bounds checks for nested darrays @pytest.mark.parametrize("list_", ([], [11], [11, 12], [11, 12, 13])) -def test_dynarray_len(get_contract_with_gas_estimation, assert_tx_failed, list_): +def test_dynarray_len(get_contract_with_gas_estimation, tx_failed, list_): code = """ @external def darray_len(xs: DynArray[uint256, 3]) -> uint256: @@ -790,7 +793,7 @@ def darray_len(xs: DynArray[uint256, 3]) -> uint256: assert c.darray_len(list_) == len(list_) -def test_dynarray_too_large(get_contract_with_gas_estimation, assert_tx_failed): +def test_dynarray_too_large(get_contract_with_gas_estimation, tx_failed): code = """ @external def darray_len(xs: DynArray[uint256, 3]) -> uint256: @@ -798,10 +801,11 @@ def darray_len(xs: DynArray[uint256, 3]) -> uint256: """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.darray_len([1, 2, 3, 4])) + with tx_failed(): + c.darray_len([1, 2, 3, 4]) -def test_int128_accessor(get_contract_with_gas_estimation, assert_tx_failed): +def test_int128_accessor(get_contract_with_gas_estimation, tx_failed): code = """ @external def bounds_check_int128(ix: int128) -> uint256: @@ -811,8 +815,10 @@ def bounds_check_int128(ix: int128) -> uint256: c = get_contract_with_gas_estimation(code) assert c.bounds_check_int128(0) == 1 assert c.bounds_check_int128(2) == 3 - assert_tx_failed(lambda: c.bounds_check_int128(3)) - assert_tx_failed(lambda: c.bounds_check_int128(-1)) + with tx_failed(): + c.bounds_check_int128(3) + with tx_failed(): + c.bounds_check_int128(-1) def test_index_exception(get_contract_with_gas_estimation, assert_compile_failed): @@ -1164,12 +1170,13 @@ def test_invalid_append_pop(get_contract, assert_compile_failed, code, exception @pytest.mark.parametrize("code,check_result", append_pop_tests) # TODO change this to fuzz random data @pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) -def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_data): +def test_append_pop(get_contract, tx_failed, code, check_result, test_data): c = get_contract(code) expected_result = check_result(test_data) if expected_result is None: # None is sentinel to indicate txn should revert - assert_tx_failed(lambda: c.foo(test_data)) + with tx_failed(): + c.foo(test_data) else: assert c.foo(test_data) == expected_result @@ -1234,7 +1241,7 @@ def foo(x: {typ}) -> {typ}: ["uint256[3]", "DynArray[uint256,3]", "DynArray[uint8, 4]", "Foo", "DynArray[Foobar, 3]"], ) # TODO change this to fuzz random data -def test_append_pop_complex(get_contract, assert_tx_failed, code_template, check_result, subtype): +def test_append_pop_complex(get_contract, tx_failed, code_template, check_result, subtype): code = code_template.format(typ=subtype) test_data = [1, 2, 3] if subtype == "Foo": @@ -1260,7 +1267,8 @@ def test_append_pop_complex(get_contract, assert_tx_failed, code_template, check expected_result = check_result(test_data) if expected_result is None: # None is sentinel to indicate txn should revert - assert_tx_failed(lambda: c.foo(test_data)) + with tx_failed(): + c.foo(test_data) else: assert c.foo(test_data) == expected_result @@ -1330,7 +1338,7 @@ def bar(_baz: DynArray[Foo, 3]) -> String[96]: assert c.bar(c_input) == "Hello world!!!!" -def test_list_of_structs_lists_with_nested_lists(get_contract, assert_tx_failed): +def test_list_of_structs_lists_with_nested_lists(get_contract, tx_failed): code = """ struct Bar: a: DynArray[uint8[2], 2] @@ -1351,7 +1359,8 @@ def foo(x: uint8) -> uint8: """ c = get_contract(code) assert c.foo(17) == 98 - assert_tx_failed(lambda: c.foo(241)) + with tx_failed(): + c.foo(241) def test_list_of_nested_struct_arrays(get_contract): @@ -1622,7 +1631,7 @@ def bar() -> uint256: assert c.bar() == 58 -def test_constant_list(get_contract, assert_tx_failed): +def test_constant_list(get_contract, tx_failed): some_good_primes = [5.0, 11.0, 17.0, 29.0, 37.0, 41.0] code = f""" MY_LIST: constant(DynArray[decimal, 6]) = {some_good_primes} @@ -1634,7 +1643,8 @@ def ix(i: uint256) -> decimal: for i, p in enumerate(some_good_primes): assert c.ix(i) == p # assert oob - assert_tx_failed(lambda: c.ix(len(some_good_primes) + 1)) + with tx_failed(): + c.ix(len(some_good_primes) + 1) def test_public_dynarray(get_contract): @@ -1831,7 +1841,8 @@ def should_revert() -> DynArray[String[65], 2]: @pytest.mark.parametrize("code", dynarray_length_no_clobber_cases) -def test_dynarray_length_no_clobber(get_contract, assert_tx_failed, code): +def test_dynarray_length_no_clobber(get_contract, tx_failed, code): # check that length is not clobbered before dynarray data copy happens c = get_contract(code) - assert_tx_failed(lambda: c.should_revert()) + with tx_failed(): + c.should_revert() diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index 03c22134ed..5da6d57558 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -74,7 +74,7 @@ def is_not_boss(a: Roles) -> bool: assert c.is_not_boss(2**4) is False -def test_bitwise(get_contract, assert_tx_failed): +def test_bitwise(get_contract, tx_failed): code = """ flag Roles: USER @@ -134,18 +134,25 @@ def binv_arg(a: Roles) -> Roles: assert c.binv_arg(0b00000) == 0b11111 # LHS is out of bound - assert_tx_failed(lambda: c.bor_arg(32, 3)) - assert_tx_failed(lambda: c.band_arg(32, 3)) - assert_tx_failed(lambda: c.bxor_arg(32, 3)) - assert_tx_failed(lambda: c.binv_arg(32)) + with tx_failed(): + c.bor_arg(32, 3) + with tx_failed(): + c.band_arg(32, 3) + with tx_failed(): + c.bxor_arg(32, 3) + with tx_failed(): + c.binv_arg(32) # RHS - assert_tx_failed(lambda: c.bor_arg(3, 32)) - assert_tx_failed(lambda: c.band_arg(3, 32)) - assert_tx_failed(lambda: c.bxor_arg(3, 32)) + with tx_failed(): + c.bor_arg(3, 32) + with tx_failed(): + c.band_arg(3, 32) + with tx_failed(): + c.bxor_arg(3, 32) -def test_augassign_storage(get_contract, w3, assert_tx_failed): +def test_augassign_storage(get_contract, w3, tx_failed): code = """ flag Roles: ADMIN @@ -190,7 +197,8 @@ def checkMinter(minter: address): assert c.roles(minter_address) == 0b10 # admin is not a minter - assert_tx_failed(lambda: c.checkMinter(admin_address)) + with tx_failed(): + c.checkMinter(admin_address) c.addMinter(admin_address, transact={}) @@ -201,7 +209,8 @@ def checkMinter(minter: address): # revoke minter c.revokeMinter(admin_address, transact={}) assert c.roles(admin_address) == 0b01 - assert_tx_failed(lambda: c.checkMinter(admin_address)) + with tx_failed(): + c.checkMinter(admin_address) # flip minter c.flipMinter(admin_address, transact={}) @@ -211,7 +220,8 @@ def checkMinter(minter: address): # flip minter c.flipMinter(admin_address, transact={}) assert c.roles(admin_address) == 0b01 - assert_tx_failed(lambda: c.checkMinter(admin_address)) + with tx_failed(): + c.checkMinter(admin_address) def test_in_flag(get_contract_with_gas_estimation): diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 832b679e5e..657c4ba0b8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -353,7 +353,7 @@ def test_multi4() -> uint256[2][2][2][2]: @pytest.mark.parametrize("type_", ["uint8", "uint256"]) -def test_unsigned_accessors(get_contract_with_gas_estimation, assert_tx_failed, type_): +def test_unsigned_accessors(get_contract_with_gas_estimation, tx_failed, type_): code = f""" @external def bounds_check(ix: {type_}) -> uint256: @@ -363,11 +363,12 @@ def bounds_check(ix: {type_}) -> uint256: c = get_contract_with_gas_estimation(code) assert c.bounds_check(0) == 1 assert c.bounds_check(2) == 3 - assert_tx_failed(lambda: c.bounds_check(3)) + with tx_failed(): + c.bounds_check(3) @pytest.mark.parametrize("type_", ["int128", "int256"]) -def test_signed_accessors(get_contract_with_gas_estimation, assert_tx_failed, type_): +def test_signed_accessors(get_contract_with_gas_estimation, tx_failed, type_): code = f""" @external def bounds_check(ix: {type_}) -> uint256: @@ -377,8 +378,10 @@ def bounds_check(ix: {type_}) -> uint256: c = get_contract_with_gas_estimation(code) assert c.bounds_check(0) == 1 assert c.bounds_check(2) == 3 - assert_tx_failed(lambda: c.bounds_check(3)) - assert_tx_failed(lambda: c.bounds_check(-1)) + with tx_failed(): + c.bounds_check(3) + with tx_failed(): + c.bounds_check(-1) def test_list_check_heterogeneous_types(get_contract_with_gas_estimation, assert_compile_failed): @@ -662,7 +665,7 @@ def foo(x: Bar[2][2][2]) -> uint256: ("bool", [True, False, True, False, True, False]), ], ) -def test_constant_list(get_contract, assert_tx_failed, type, value): +def test_constant_list(get_contract, tx_failed, type, value): code = f""" MY_LIST: constant({type}[{len(value)}]) = {value} @external @@ -673,7 +676,8 @@ def ix(i: uint256) -> {type}: for i, p in enumerate(value): assert c.ix(i) == p # assert oob - assert_tx_failed(lambda: c.ix(len(value) + 1)) + with tx_failed(): + c.ix(len(value) + 1) def test_nested_constant_list_accessor(get_contract): @@ -728,7 +732,7 @@ def foo(i: uint256) -> {return_type}: assert_compile_failed(lambda: get_contract(code), TypeMismatch) -def test_constant_list_address(get_contract, assert_tx_failed): +def test_constant_list_address(get_contract, tx_failed): some_good_address = [ "0x0000000000000000000000000000000000012345", "0x0000000000000000000000000000000000023456", @@ -754,10 +758,11 @@ def ix(i: uint256) -> address: for i, p in enumerate(some_good_address): assert c.ix(i) == p # assert oob - assert_tx_failed(lambda: c.ix(len(some_good_address) + 1)) + with tx_failed(): + c.ix(len(some_good_address) + 1) -def test_list_index_complex_expr(get_contract, assert_tx_failed): +def test_list_index_complex_expr(get_contract, tx_failed): # test subscripts where the index is not a literal code = """ @external @@ -771,7 +776,8 @@ def foo(xs: uint256[257], i: uint8) -> uint256: assert c.foo(xs, ix) == xs[ix + 1] # safemath should fail for uint8: 255 + 1. - assert_tx_failed(lambda: c.foo(xs, 255)) + with tx_failed(): + c.foo(xs, 255) @pytest.mark.parametrize( @@ -793,7 +799,7 @@ def foo(xs: uint256[257], i: uint8) -> uint256: ("bool", [[True, False], [True, False], [True, False]]), ], ) -def test_constant_nested_list(get_contract, assert_tx_failed, type, value): +def test_constant_nested_list(get_contract, tx_failed, type, value): code = f""" MY_LIST: constant({type}[{len(value[0])}][{len(value)}]) = {value} @external @@ -805,7 +811,8 @@ def ix(i: uint256, j: uint256) -> {type}: for j, q in enumerate(p): assert c.ix(i, j) == q # assert oob - assert_tx_failed(lambda: c.ix(len(value) + 1, len(value[0]) + 1)) + with tx_failed(): + c.ix(len(value) + 1, len(value[0]) + 1) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 7f1fa71329..9d50f8df38 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -61,7 +61,7 @@ def get(k: String[34]) -> int128: assert c.get("a" * 34) == 6789 -def test_string_slice(get_contract_with_gas_estimation, assert_tx_failed): +def test_string_slice(get_contract_with_gas_estimation, tx_failed): test_slice4 = """ @external def foo(inp: String[10], start: uint256, _len: uint256) -> String[10]: @@ -76,10 +76,14 @@ def foo(inp: String[10], start: uint256, _len: uint256) -> String[10]: assert c.foo("badminton", 1, 0) == "" assert c.foo("badminton", 9, 0) == "" - assert_tx_failed(lambda: c.foo("badminton", 0, 10)) - assert_tx_failed(lambda: c.foo("badminton", 1, 9)) - assert_tx_failed(lambda: c.foo("badminton", 9, 1)) - assert_tx_failed(lambda: c.foo("badminton", 10, 0)) + with tx_failed(): + c.foo("badminton", 0, 10) + with tx_failed(): + c.foo("badminton", 1, 9) + with tx_failed(): + c.foo("badminton", 9, 1) + with tx_failed(): + c.foo("badminton", 10, 0) def test_private_string(get_contract_with_gas_estimation): diff --git a/tests/functional/examples/auctions/test_blind_auction.py b/tests/functional/examples/auctions/test_blind_auction.py index d814ab0cad..dcd4e0bf8b 100644 --- a/tests/functional/examples/auctions/test_blind_auction.py +++ b/tests/functional/examples/auctions/test_blind_auction.py @@ -33,15 +33,15 @@ def test_initial_state(w3, tester, auction_contract): assert auction_contract.highestBidder() is None -def test_late_bid(w3, auction_contract, assert_tx_failed): +def test_late_bid(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # Move time forward past bidding end w3.testing.mine(BIDDING_TIME + TEST_INCREMENT) # Try to bid after bidding has ended - assert_tx_failed( - lambda: auction_contract.bid( + with tx_failed(): + auction_contract.bid( w3.keccak( b"".join( [ @@ -53,10 +53,9 @@ def test_late_bid(w3, auction_contract, assert_tx_failed): ), transact={"value": 200, "from": k1}, ) - ) -def test_too_many_bids(w3, auction_contract, assert_tx_failed): +def test_too_many_bids(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # First 128 bids should be able to be placed successfully @@ -75,8 +74,8 @@ def test_too_many_bids(w3, auction_contract, assert_tx_failed): ) # 129th bid should fail - assert_tx_failed( - lambda: auction_contract.bid( + with tx_failed(): + auction_contract.bid( w3.keccak( b"".join( [ @@ -88,10 +87,9 @@ def test_too_many_bids(w3, auction_contract, assert_tx_failed): ), transact={"value": 128, "from": k1}, ) - ) -def test_early_reval(w3, auction_contract, assert_tx_failed): +def test_early_reval(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # k1 places 1 real bid @@ -119,11 +117,10 @@ def test_early_reval(w3, auction_contract, assert_tx_failed): _values[0] = 100 _fakes[0] = False _secrets[0] = (8675309).to_bytes(32, byteorder="big") - assert_tx_failed( - lambda: auction_contract.reveal( + with tx_failed(): + auction_contract.reveal( _numBids, _values, _fakes, _secrets, transact={"value": 0, "from": k1} ) - ) # Check highest bidder is still empty assert auction_contract.highestBidder() is None @@ -131,7 +128,7 @@ def test_early_reval(w3, auction_contract, assert_tx_failed): assert auction_contract.highestBid() == 0 -def test_late_reveal(w3, auction_contract, assert_tx_failed): +def test_late_reveal(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # k1 places 1 real bid @@ -159,11 +156,10 @@ def test_late_reveal(w3, auction_contract, assert_tx_failed): _values[0] = 100 _fakes[0] = False _secrets[0] = (8675309).to_bytes(32, byteorder="big") - assert_tx_failed( - lambda: auction_contract.reveal( + with tx_failed(): + auction_contract.reveal( _numBids, _values, _fakes, _secrets, transact={"value": 0, "from": k1} ) - ) # Check highest bidder is still empty assert auction_contract.highestBidder() is None @@ -171,14 +167,15 @@ def test_late_reveal(w3, auction_contract, assert_tx_failed): assert auction_contract.highestBid() == 0 -def test_early_end(w3, auction_contract, assert_tx_failed): +def test_early_end(w3, auction_contract, tx_failed): k0 = w3.eth.accounts[0] # Should not be able to end auction before reveal time has ended - assert_tx_failed(lambda: auction_contract.auctionEnd(transact={"value": 0, "from": k0})) + with tx_failed(): + auction_contract.auctionEnd(transact={"value": 0, "from": k0}) -def test_double_end(w3, auction_contract, assert_tx_failed): +def test_double_end(w3, auction_contract, tx_failed): k0 = w3.eth.accounts[0] # Move time forward past bidding and reveal end @@ -188,7 +185,8 @@ def test_double_end(w3, auction_contract, assert_tx_failed): auction_contract.auctionEnd(transact={"value": 0, "from": k0}) # Should not be able to end auction twice - assert_tx_failed(lambda: auction_contract.auctionEnd(transact={"value": 0, "from": k0})) + with tx_failed(): + auction_contract.auctionEnd(transact={"value": 0, "from": k0}) def test_blind_auction(w3, auction_contract): diff --git a/tests/functional/examples/auctions/test_simple_open_auction.py b/tests/functional/examples/auctions/test_simple_open_auction.py index cf0bb8cc20..c80b44d976 100644 --- a/tests/functional/examples/auctions/test_simple_open_auction.py +++ b/tests/functional/examples/auctions/test_simple_open_auction.py @@ -33,17 +33,19 @@ def test_initial_state(w3, tester, auction_contract, auction_start): assert auction_contract.auctionEnd() >= tester.get_block_by_number("latest")["timestamp"] -def test_bid(w3, tester, auction_contract, assert_tx_failed): +def test_bid(w3, tester, auction_contract, tx_failed): k1, k2, k3, k4, k5 = w3.eth.accounts[:5] # Bidder cannot bid 0 - assert_tx_failed(lambda: auction_contract.bid(transact={"value": 0, "from": k1})) + with tx_failed(): + auction_contract.bid(transact={"value": 0, "from": k1}) # Bidder can bid auction_contract.bid(transact={"value": 1, "from": k1}) # Check that highest bidder and highest bid have changed accordingly assert auction_contract.highestBidder() == k1 assert auction_contract.highestBid() == 1 # Bidder bid cannot equal current highest bid - assert_tx_failed(lambda: auction_contract.bid(transact={"value": 1, "from": k1})) + with tx_failed(): + auction_contract.bid(transact={"value": 1, "from": k1}) # Higher bid can replace current highest bid auction_contract.bid(transact={"value": 2, "from": k2}) # Check that highest bidder and highest bid have changed accordingly @@ -72,10 +74,11 @@ def test_bid(w3, tester, auction_contract, assert_tx_failed): assert auction_contract.pendingReturns(k1) == 0 -def test_end_auction(w3, tester, auction_contract, assert_tx_failed): +def test_end_auction(w3, tester, auction_contract, tx_failed): k1, k2, k3, k4, k5 = w3.eth.accounts[:5] # Fails if auction end time has not been reached - assert_tx_failed(lambda: auction_contract.endAuction()) + with tx_failed(): + auction_contract.endAuction() auction_contract.bid(transact={"value": 1 * 10**10, "from": k2}) # Move block timestamp foreward to reach auction end time # tester.time_travel(tester.get_block_by_number('latest')['timestamp'] + EXPIRY) @@ -86,6 +89,8 @@ def test_end_auction(w3, tester, auction_contract, assert_tx_failed): # Beneficiary receives the highest bid assert balance_after_end == balance_before_end + 1 * 10**10 # Bidder cannot bid after auction end time has been reached - assert_tx_failed(lambda: auction_contract.bid(transact={"value": 10, "from": k1})) + with tx_failed(): + auction_contract.bid(transact={"value": 10, "from": k1}) # Auction cannot be ended twice - assert_tx_failed(lambda: auction_contract.endAuction()) + with tx_failed(): + auction_contract.endAuction() diff --git a/tests/functional/examples/company/test_company.py b/tests/functional/examples/company/test_company.py index 71141b8bb5..5933a14e86 100644 --- a/tests/functional/examples/company/test_company.py +++ b/tests/functional/examples/company/test_company.py @@ -9,7 +9,7 @@ def c(w3, get_contract): return contract -def test_overbuy(w3, c, assert_tx_failed): +def test_overbuy(w3, c, tx_failed): # If all the stock has been bought, no one can buy more a1, a2 = w3.eth.accounts[1:3] test_shares = int(c.totalShares() / 2) @@ -19,15 +19,19 @@ def test_overbuy(w3, c, assert_tx_failed): assert c.stockAvailable() == 0 assert c.getHolding(a1) == (test_shares * 2) one_stock = c.price() - assert_tx_failed(lambda: c.buyStock(transact={"from": a1, "value": one_stock})) - assert_tx_failed(lambda: c.buyStock(transact={"from": a2, "value": one_stock})) + with tx_failed(): + c.buyStock(transact={"from": a1, "value": one_stock}) + with tx_failed(): + c.buyStock(transact={"from": a2, "value": one_stock}) -def test_sell_without_stock(w3, c, assert_tx_failed): +def test_sell_without_stock(w3, c, tx_failed): a1, a2 = w3.eth.accounts[1:3] # If you don't have any stock, you can't sell - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a2})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) + with tx_failed(): + c.sellStock(1, transact={"from": a2}) # But if you do, you can! test_shares = int(c.totalShares()) test_value = int(test_shares * c.price()) @@ -35,48 +39,57 @@ def test_sell_without_stock(w3, c, assert_tx_failed): assert c.getHolding(a1) == test_shares c.sellStock(test_shares, transact={"from": a1}) # But only until you run out - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) -def test_oversell(w3, c, assert_tx_failed): +def test_oversell(w3, c, tx_failed): a0, a1, a2 = w3.eth.accounts[:3] # You can't sell more than you own test_shares = int(c.totalShares()) test_value = int(test_shares * c.price()) c.buyStock(transact={"from": a1, "value": test_value}) - assert_tx_failed(lambda: c.sellStock(test_shares + 1, transact={"from": a1})) + with tx_failed(): + c.sellStock(test_shares + 1, transact={"from": a1}) -def test_transfer(w3, c, assert_tx_failed): +def test_transfer(w3, c, tx_failed): # If you don't have any stock, you can't transfer a1, a2 = w3.eth.accounts[1:3] - assert_tx_failed(lambda: c.transferStock(a2, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.transferStock(a1, 1, transact={"from": a2})) + with tx_failed(): + c.transferStock(a2, 1, transact={"from": a1}) + with tx_failed(): + c.transferStock(a1, 1, transact={"from": a2}) # If you transfer, you don't have the stock anymore test_shares = int(c.totalShares()) test_value = int(test_shares * c.price()) c.buyStock(transact={"from": a1, "value": test_value}) assert c.getHolding(a1) == test_shares c.transferStock(a2, test_shares, transact={"from": a1}) - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) # But the other person does c.sellStock(test_shares, transact={"from": a2}) -def test_paybill(w3, c, assert_tx_failed): +def test_paybill(w3, c, tx_failed): a0, a1, a2, a3 = w3.eth.accounts[:4] # Only the company can authorize payments - assert_tx_failed(lambda: c.payBill(a2, 1, transact={"from": a1})) + with tx_failed(): + c.payBill(a2, 1, transact={"from": a1}) # A company can only pay someone if it has the money - assert_tx_failed(lambda: c.payBill(a2, 1, transact={"from": a0})) + with tx_failed(): + c.payBill(a2, 1, transact={"from": a0}) # If it has the money, it can pay someone test_value = int(c.totalShares() * c.price()) c.buyStock(transact={"from": a1, "value": test_value}) c.payBill(a2, test_value, transact={"from": a0}) # Until it runs out of money - assert_tx_failed(lambda: c.payBill(a3, 1, transact={"from": a0})) + with tx_failed(): + c.payBill(a3, 1, transact={"from": a0}) # Then no stockholders can sell their stock either - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) def test_valuation(w3, c): diff --git a/tests/functional/examples/crowdfund/test_crowdfund_example.py b/tests/functional/examples/crowdfund/test_crowdfund_example.py index 9a08d9241c..e75a88bf48 100644 --- a/tests/functional/examples/crowdfund/test_crowdfund_example.py +++ b/tests/functional/examples/crowdfund/test_crowdfund_example.py @@ -27,7 +27,7 @@ def test_crowdfund_example(c, w3): assert post_bal - pre_bal == 54 -def test_crowdfund_example2(c, w3, assert_tx_failed): +def test_crowdfund_example2(c, w3, tx_failed): a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] c.participate(transact={"value": 1, "from": a3}) c.participate(transact={"value": 2, "from": a4}) @@ -39,9 +39,11 @@ def test_crowdfund_example2(c, w3, assert_tx_failed): # assert c.expired() # assert not c.reached() pre_bals = [w3.eth.get_balance(x) for x in [a3, a4, a5, a6]] - assert_tx_failed(lambda: c.refund(transact={"from": a0})) + with tx_failed(): + c.refund(transact={"from": a0}) c.refund(transact={"from": a3}) - assert_tx_failed(lambda: c.refund(transact={"from": a3})) + with tx_failed(): + c.refund(transact={"from": a3}) c.refund(transact={"from": a4}) c.refund(transact={"from": a5}) c.refund(transact={"from": a6}) diff --git a/tests/functional/examples/market_maker/test_on_chain_market_maker.py b/tests/functional/examples/market_maker/test_on_chain_market_maker.py index db9700da3b..235a0ea66f 100644 --- a/tests/functional/examples/market_maker/test_on_chain_market_maker.py +++ b/tests/functional/examples/market_maker/test_on_chain_market_maker.py @@ -31,25 +31,21 @@ def test_initial_state(market_maker): assert market_maker.owner() is None -def test_initiate(w3, market_maker, erc20, assert_tx_failed): +def test_initiate(w3, market_maker, erc20, tx_failed): a0 = w3.eth.accounts[0] - erc20.approve(market_maker.address, w3.to_wei(2, "ether"), transact={}) - market_maker.initiate( - erc20.address, w3.to_wei(1, "ether"), transact={"value": w3.to_wei(2, "ether")} - ) - assert market_maker.totalEthQty() == w3.to_wei(2, "ether") - assert market_maker.totalTokenQty() == w3.to_wei(1, "ether") + ether, ethers = w3.to_wei(1, "ether"), w3.to_wei(2, "ether") + erc20.approve(market_maker.address, ethers, transact={}) + market_maker.initiate(erc20.address, ether, transact={"value": ethers}) + assert market_maker.totalEthQty() == ethers + assert market_maker.totalTokenQty() == ether assert market_maker.invariant() == 2 * 10**36 assert market_maker.owner() == a0 assert erc20.name() == TOKEN_NAME assert erc20.decimals() == TOKEN_DECIMALS # Initiate cannot be called twice - assert_tx_failed( - lambda: market_maker.initiate( - erc20.address, w3.to_wei(1, "ether"), transact={"value": w3.to_wei(2, "ether")} - ) - ) # noqa: E501 + with tx_failed(): + market_maker.initiate(erc20.address, ether, transact={"value": ethers}) def test_eth_to_tokens(w3, market_maker, erc20): @@ -95,7 +91,7 @@ def test_tokens_to_eth(w3, market_maker, erc20): assert market_maker.totalEthQty() == w3.to_wei(1, "ether") -def test_owner_withdraw(w3, market_maker, erc20, assert_tx_failed): +def test_owner_withdraw(w3, market_maker, erc20, tx_failed): a0, a1 = w3.eth.accounts[:2] a0_balance_before = w3.eth.get_balance(a0) # Approve 2 eth transfers. @@ -110,7 +106,8 @@ def test_owner_withdraw(w3, market_maker, erc20, assert_tx_failed): assert erc20.balanceOf(a0) == TOKEN_TOTAL_SUPPLY - w3.to_wei(1, "ether") # Only owner can call ownerWithdraw - assert_tx_failed(lambda: market_maker.ownerWithdraw(transact={"from": a1})) + with tx_failed(): + market_maker.ownerWithdraw(transact={"from": a1}) market_maker.ownerWithdraw(transact={}) assert w3.eth.get_balance(a0) == a0_balance_before # Eth balance restored. assert erc20.balanceOf(a0) == TOKEN_TOTAL_SUPPLY # Tokens returned to a0. diff --git a/tests/functional/examples/name_registry/test_name_registry.py b/tests/functional/examples/name_registry/test_name_registry.py index 26f5844484..a2e92a7c52 100644 --- a/tests/functional/examples/name_registry/test_name_registry.py +++ b/tests/functional/examples/name_registry/test_name_registry.py @@ -1,8 +1,9 @@ -def test_name_registry(w3, get_contract, assert_tx_failed): +def test_name_registry(w3, get_contract, tx_failed): a0, a1 = w3.eth.accounts[:2] with open("examples/name_registry/name_registry.vy") as f: code = f.read() c = get_contract(code) c.register(b"jacques", a0, transact={}) assert c.lookup(b"jacques") == a0 - assert_tx_failed(lambda: c.register(b"jacques", a1)) + with tx_failed(): + c.register(b"jacques", a1) diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index 9a806ed885..2cc5dd8d4a 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -31,9 +31,10 @@ def get_balance(): return get_balance -def test_initial_state(w3, assert_tx_failed, get_contract, get_balance, contract_code): +def test_initial_state(w3, tx_failed, get_contract, get_balance, contract_code): # Inital deposit has to be divisible by two - assert_tx_failed(lambda: get_contract(contract_code, value=13)) + with tx_failed(): + get_contract(contract_code, value=13) # Seller puts item up for sale a0_pre_bal, a1_pre_bal = get_balance() c = get_contract(contract_code, value_in_eth=2) @@ -47,30 +48,34 @@ def test_initial_state(w3, assert_tx_failed, get_contract, get_balance, contract assert get_balance() == ((a0_pre_bal - w3.to_wei(2, "ether")), a1_pre_bal) -def test_abort(w3, assert_tx_failed, get_balance, get_contract, contract_code): +def test_abort(w3, tx_failed, get_balance, get_contract, contract_code): a0, a1, a2 = w3.eth.accounts[:3] a0_pre_bal, a1_pre_bal = get_balance() c = get_contract(contract_code, value=w3.to_wei(2, "ether")) assert c.value() == w3.to_wei(1, "ether") # Only sender can trigger refund - assert_tx_failed(lambda: c.abort(transact={"from": a2})) + with tx_failed(): + c.abort(transact={"from": a2}) # Refund works correctly c.abort(transact={"from": a0}) assert get_balance() == (a0_pre_bal, a1_pre_bal) # Purchase in process, no refund possible c = get_contract(contract_code, value=2) c.purchase(transact={"value": 2, "from": a1}) - assert_tx_failed(lambda: c.abort(transact={"from": a0})) + with tx_failed(): + c.abort(transact={"from": a0}) -def test_purchase(w3, get_contract, assert_tx_failed, get_balance, contract_code): +def test_purchase(w3, get_contract, tx_failed, get_balance, contract_code): a0, a1, a2, a3 = w3.eth.accounts[:4] init_bal_a0, init_bal_a1 = get_balance() c = get_contract(contract_code, value=2) # Purchase for too low/high price - assert_tx_failed(lambda: c.purchase(transact={"value": 1, "from": a1})) - assert_tx_failed(lambda: c.purchase(transact={"value": 3, "from": a1})) + with tx_failed(): + c.purchase(transact={"value": 1, "from": a1}) + with tx_failed(): + c.purchase(transact={"value": 3, "from": a1}) # Purchase for the correct price c.purchase(transact={"value": 2, "from": a1}) # Check if buyer is set correctly @@ -80,26 +85,29 @@ def test_purchase(w3, get_contract, assert_tx_failed, get_balance, contract_code # Check balances, both deposits should have been deducted assert get_balance() == (init_bal_a0 - 2, init_bal_a1 - 2) # Allow nobody else to purchase - assert_tx_failed(lambda: c.purchase(transact={"value": 2, "from": a3})) + with tx_failed(): + c.purchase(transact={"value": 2, "from": a3}) -def test_received(w3, get_contract, assert_tx_failed, get_balance, contract_code): +def test_received(w3, get_contract, tx_failed, get_balance, contract_code): a0, a1 = w3.eth.accounts[:2] init_bal_a0, init_bal_a1 = get_balance() c = get_contract(contract_code, value=2) # Can only be called after purchase - assert_tx_failed(lambda: c.received(transact={"from": a1})) + with tx_failed(): + c.received(transact={"from": a1}) # Purchase completed c.purchase(transact={"value": 2, "from": a1}) # Check that e.g. sender cannot trigger received - assert_tx_failed(lambda: c.received(transact={"from": a0})) + with tx_failed(): + c.received(transact={"from": a0}) # Check if buyer can call receive c.received(transact={"from": a1}) # Final check if everything worked. 1 value has been transferred assert get_balance() == (init_bal_a0 + 1, init_bal_a1 - 1) -def test_received_reentrancy(w3, get_contract, assert_tx_failed, get_balance, contract_code): +def test_received_reentrancy(w3, get_contract, tx_failed, get_balance, contract_code): buyer_contract_code = """ interface PurchaseContract: diff --git a/tests/functional/examples/storage/test_advanced_storage.py b/tests/functional/examples/storage/test_advanced_storage.py index 13ffce4f82..313d1a7e5c 100644 --- a/tests/functional/examples/storage/test_advanced_storage.py +++ b/tests/functional/examples/storage/test_advanced_storage.py @@ -18,32 +18,30 @@ def test_initial_state(adv_storage_contract): assert adv_storage_contract.storedData() == INITIAL_VALUE -def test_failed_transactions(w3, adv_storage_contract, assert_tx_failed): +def test_failed_transactions(w3, adv_storage_contract, tx_failed): k1 = w3.eth.accounts[1] # Try to set the storage to a negative amount - assert_tx_failed(lambda: adv_storage_contract.set(-10, transact={"from": k1})) + with tx_failed(): + adv_storage_contract.set(-10, transact={"from": k1}) # Lock the contract by storing more than 100. Then try to change the value adv_storage_contract.set(150, transact={"from": k1}) - assert_tx_failed(lambda: adv_storage_contract.set(10, transact={"from": k1})) + with tx_failed(): + adv_storage_contract.set(10, transact={"from": k1}) # Reset the contract and try to change the value adv_storage_contract.reset(transact={"from": k1}) adv_storage_contract.set(10, transact={"from": k1}) assert adv_storage_contract.storedData() == 10 - # Assert a different exception (ValidationError for non matching argument type) - assert_tx_failed( - lambda: adv_storage_contract.set("foo", transact={"from": k1}), ValidationError - ) + # Assert a different exception (ValidationError for non-matching argument type) + with tx_failed(ValidationError): + adv_storage_contract.set("foo", transact={"from": k1}) # Assert a different exception that contains specific text - assert_tx_failed( - lambda: adv_storage_contract.set(1, 2, transact={"from": k1}), - ValidationError, - "invocation failed due to improper number of arguments", - ) + with tx_failed(ValidationError, "invocation failed due to improper number of arguments"): + adv_storage_contract.set(1, 2, transact={"from": k1}) def test_events(w3, adv_storage_contract, get_logs): diff --git a/tests/functional/examples/tokens/test_erc1155.py b/tests/functional/examples/tokens/test_erc1155.py index abebd024b6..5dc314c037 100644 --- a/tests/functional/examples/tokens/test_erc1155.py +++ b/tests/functional/examples/tokens/test_erc1155.py @@ -29,7 +29,7 @@ @pytest.fixture -def erc1155(get_contract, w3, assert_tx_failed): +def erc1155(get_contract, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] with open("examples/tokens/ERC1155ownable.vy") as f: code = f.read() @@ -41,18 +41,20 @@ def erc1155(get_contract, w3, assert_tx_failed): assert c.balanceOf(a1, 1) == 1 assert c.balanceOf(a1, 2) == 1 assert c.balanceOf(a1, 3) == 1 - assert_tx_failed( - lambda: c.mintBatch(ZERO_ADDRESS, mintBatch, minBatchSetOf10, transact={"from": owner}) - ) - assert_tx_failed(lambda: c.mintBatch(a1, [1, 2, 3], [1, 1], transact={"from": owner})) + with tx_failed(): + c.mintBatch(ZERO_ADDRESS, mintBatch, minBatchSetOf10, transact={"from": owner}) + with tx_failed(): + c.mintBatch(a1, [1, 2, 3], [1, 1], transact={"from": owner}) c.mint(a1, 21, 1, transact={"from": owner}) c.mint(a1, 22, 1, transact={"from": owner}) c.mint(a1, 23, 1, transact={"from": owner}) c.mint(a1, 24, 1, transact={"from": owner}) - assert_tx_failed(lambda: c.mint(a1, 24, 1, transact={"from": a3})) - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, 24, 1, transact={"from": owner})) + with tx_failed(): + c.mint(a1, 24, 1, transact={"from": a3}) + with tx_failed(): + c.mint(ZERO_ADDRESS, 24, 1, transact={"from": owner}) assert c.balanceOf(a1, 21) == 1 assert c.balanceOf(a1, 22) == 1 @@ -80,69 +82,76 @@ def test_initial_state(erc1155): assert erc1155.supportsInterface(ERC1155_INTERFACE_ID_METADATA) -def test_pause(erc1155, w3, assert_tx_failed): +def test_pause(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # check the pause status, pause, check, unpause, check, with owner and non-owner w3.eth.accounts # this test will check all the function that should not work when paused. assert not erc1155.paused() # try to pause the contract from a non owner account - assert_tx_failed(lambda: erc1155.pause(transact={"from": a1})) + with tx_failed(): + erc1155.pause(transact={"from": a1}) # now pause the contract and check status erc1155.pause(transact={"from": owner}) assert erc1155.paused() # try pausing a paused contract - assert_tx_failed(lambda: erc1155.pause()) + with tx_failed(): + erc1155.pause() # try functions that should not work when paused - assert_tx_failed(lambda: erc1155.setURI(NEW_CONTRACT_URI)) + with tx_failed(): + erc1155.setURI(NEW_CONTRACT_URI) # test burn and burnbatch - assert_tx_failed(lambda: erc1155.burn(21, 1)) - assert_tx_failed(lambda: erc1155.burnBatch([21, 22], [1, 1])) + with tx_failed(): + erc1155.burn(21, 1) + with tx_failed(): + erc1155.burnBatch([21, 22], [1, 1]) # check mint and mintbatch - assert_tx_failed(lambda: erc1155.mint(a1, 21, 1, transact={"from": owner})) - assert_tx_failed( - lambda: erc1155.mintBatch(a1, mintBatch, minBatchSetOf10, transact={"from": owner}) - ) + with tx_failed(): + erc1155.mint(a1, 21, 1, transact={"from": owner}) + with tx_failed(): + erc1155.mintBatch(a1, mintBatch, minBatchSetOf10, transact={"from": owner}) # check safetransferfrom and safebatchtransferfrom - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # check ownership functions - assert_tx_failed(lambda: erc1155.transferOwnership(a1)) - assert_tx_failed(lambda: erc1155.renounceOwnership()) + with tx_failed(): + erc1155.transferOwnership(a1) + with tx_failed(): + erc1155.renounceOwnership() # check approval functions - assert_tx_failed(lambda: erc1155.setApprovalForAll(owner, a5, True)) + with tx_failed(): + erc1155.setApprovalForAll(owner, a5, True) # try and unpause as non-owner - assert_tx_failed(lambda: erc1155.unpause(transact={"from": a1})) + with tx_failed(): + erc1155.unpause(transact={"from": a1}) erc1155.unpause(transact={"from": owner}) assert not erc1155.paused() # try un pausing an unpaused contract - assert_tx_failed(lambda: erc1155.unpause()) + with tx_failed(): + erc1155.unpause() -def test_contractURI(erc1155, w3, assert_tx_failed): +def test_contractURI(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # change contract URI and restore. assert erc1155.contractURI() == CONTRACT_METADATA_URI - assert_tx_failed( - lambda: erc1155.setContractURI(NEW_CONTRACT_METADATA_URI, transact={"from": a1}) - ) + with tx_failed(): + erc1155.setContractURI(NEW_CONTRACT_METADATA_URI, transact={"from": a1}) erc1155.setContractURI(NEW_CONTRACT_METADATA_URI, transact={"from": owner}) assert erc1155.contractURI() == NEW_CONTRACT_METADATA_URI assert erc1155.contractURI() != CONTRACT_METADATA_URI @@ -150,10 +159,11 @@ def test_contractURI(erc1155, w3, assert_tx_failed): assert erc1155.contractURI() != NEW_CONTRACT_METADATA_URI assert erc1155.contractURI() == CONTRACT_METADATA_URI - assert_tx_failed(lambda: erc1155.setContractURI(CONTRACT_METADATA_URI)) + with tx_failed(): + erc1155.setContractURI(CONTRACT_METADATA_URI) -def test_URI(erc1155, w3, assert_tx_failed): +def test_URI(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # change contract URI and restore. assert erc1155.uri(0) == CONTRACT_URI @@ -164,7 +174,8 @@ def test_URI(erc1155, w3, assert_tx_failed): assert erc1155.uri(0) != NEW_CONTRACT_URI assert erc1155.uri(0) == CONTRACT_URI - assert_tx_failed(lambda: erc1155.setURI(CONTRACT_URI)) + with tx_failed(): + erc1155.setURI(CONTRACT_URI) # set contract to dynamic URI erc1155.toggleDynUri(True, transact={"from": owner}) @@ -172,49 +183,41 @@ def test_URI(erc1155, w3, assert_tx_failed): assert erc1155.uri(0) == CONTRACT_DYNURI + str(0) + ".json" -def test_safeTransferFrom_balanceOf_single(erc1155, w3, assert_tx_failed): +def test_safeTransferFrom_balanceOf_single(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] assert erc1155.balanceOf(a1, 24) == 1 # transfer by non-owner - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a2}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a2}) # transfer to zero address - assert_tx_failed( - lambda: erc1155.safeTransferFrom( - a1, ZERO_ADDRESS, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1} - ) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, ZERO_ADDRESS, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) # transfer to self - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a1, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a1, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) # transfer more than owned - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 24, 500, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 24, 500, DUMMY_BYTES32_DATA, transact={"from": a1}) # transfer item not owned / not existing - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 500, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 500, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) assert erc1155.balanceOf(a2, 21) == 1 # try to transfer item again - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) assert erc1155.balanceOf(a1, 21) == 0 # TODO: mint 20 NFTs [1:20] and check the balance for each -def test_mintBatch_balanceOf(erc1155, w3, assert_tx_failed): # test_mint_batch +def test_mintBatch_balanceOf(erc1155, w3, tx_failed): # test_mint_batch owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # Use the mint three fixture to mint the tokens. # this test checks the balances of this test @@ -222,7 +225,7 @@ def test_mintBatch_balanceOf(erc1155, w3, assert_tx_failed): # test_mint_batch assert erc1155.balanceOf(a1, i) == 1 -def test_safeBatchTransferFrom_balanceOf_batch(erc1155, w3, assert_tx_failed): # test_mint_batch +def test_safeBatchTransferFrom_balanceOf_batch(erc1155, w3, tx_failed): # test_mint_batch owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # check a1 balances for NFTs 21-24 @@ -231,67 +234,58 @@ def test_safeBatchTransferFrom_balanceOf_batch(erc1155, w3, assert_tx_failed): assert erc1155.balanceOf(a1, 23) == 1 assert erc1155.balanceOf(a1, 23) == 1 - # try to transfer item from non item owner account - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + # try to transfer item from non-item owner account + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a2} ) - ) # try to transfer item to zero address - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, ZERO_ADDRESS, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # try to transfer item to self - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a1, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # try to transfer more items than we own - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 125, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # mismatched item and amounts - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # try to transfer nonexisting item - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 500], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) assert erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) # try to transfer again, our balances are zero now, should fail - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) - assert_tx_failed( - lambda: erc1155.balanceOfBatch([a2, a2, a2], [21, 22], transact={"from": owner}) - == [1, 1, 1] - ) + with tx_failed(): + erc1155.balanceOfBatch([a2, a2, a2], [21, 22], transact={"from": owner}) assert erc1155.balanceOfBatch([a2, a2, a2], [21, 22, 23]) == [1, 1, 1] assert erc1155.balanceOf(a1, 21) == 0 -def test_mint_one_burn_one(erc1155, w3, assert_tx_failed): +def test_mint_one_burn_one(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # check the balance from an owner and non-owner account @@ -301,20 +295,23 @@ def test_mint_one_burn_one(erc1155, w3, assert_tx_failed): assert erc1155.balanceOf(owner, 25) == 1 # try and burn an item we don't control - assert_tx_failed(lambda: erc1155.burn(25, 1, transact={"from": a3})) + with tx_failed(): + erc1155.burn(25, 1, transact={"from": a3}) # burn an item that contains something we don't own - assert_tx_failed(lambda: erc1155.burn(595, 1, transact={"from": a1})) + with tx_failed(): + erc1155.burn(595, 1, transact={"from": a1}) # burn ah item passing a higher amount than we own - assert_tx_failed(lambda: erc1155.burn(25, 500, transact={"from": a1})) + with tx_failed(): + erc1155.burn(25, 500, transact={"from": a1}) erc1155.burn(25, 1, transact={"from": owner}) assert erc1155.balanceOf(owner, 25) == 0 -def test_mint_batch_burn_batch(erc1155, w3, assert_tx_failed): +def test_mint_batch_burn_batch(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # mint NFTs 11-20 @@ -322,16 +319,20 @@ def test_mint_batch_burn_batch(erc1155, w3, assert_tx_failed): assert erc1155.balanceOfBatch([a3, a3, a3], [11, 12, 13]) == [1, 1, 1] # try and burn a batch we don't control - assert_tx_failed(lambda: erc1155.burnBatch([11, 12], [1, 1])) + with tx_failed(): + erc1155.burnBatch([11, 12], [1, 1]) # ids and amounts array length not matching - assert_tx_failed(lambda: erc1155.burnBatch([1, 2, 3], [1, 1], transact={"from": a1})) + with tx_failed(): + erc1155.burnBatch([1, 2, 3], [1, 1], transact={"from": a1}) # burn a batch that contains something we don't own - assert_tx_failed(lambda: erc1155.burnBatch([2, 3, 595], [1, 1, 1], transact={"from": a1})) + with tx_failed(): + erc1155.burnBatch([2, 3, 595], [1, 1, 1], transact={"from": a1}) # burn a batch passing a higher amount than we own - assert_tx_failed(lambda: erc1155.burnBatch([1, 2, 3], [1, 500, 1], transact={"from": a1})) + with tx_failed(): + erc1155.burnBatch([1, 2, 3], [1, 500, 1], transact={"from": a1}) # burn existing erc1155.burnBatch([11, 12], [1, 1], transact={"from": a3}) @@ -339,18 +340,21 @@ def test_mint_batch_burn_batch(erc1155, w3, assert_tx_failed): assert erc1155.balanceOfBatch([a3, a3, a3], [11, 12, 13]) == [0, 0, 1] # burn again, should revert - assert_tx_failed(lambda: erc1155.burnBatch([11, 12], [1, 1], transact={"from": a3})) + with tx_failed(): + erc1155.burnBatch([11, 12], [1, 1], transact={"from": a3}) - assert lambda: erc1155.balanceOfBatch([a3, a3, a3], [1, 2, 3]) == [0, 0, 1] + assert erc1155.balanceOfBatch([a3, a3, a3], [1, 2, 3]) == [0, 0, 0] -def test_approval_functions(erc1155, w3, assert_tx_failed): # test_mint_batch +def test_approval_functions(erc1155, w3, tx_failed): # test_mint_batch owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # self-approval by the owner - assert_tx_failed(lambda: erc1155.setApprovalForAll(a5, a5, True, transact={"from": a5})) + with tx_failed(): + erc1155.setApprovalForAll(a5, a5, True, transact={"from": a5}) # let's approve and operator for somebody else's account - assert_tx_failed(lambda: erc1155.setApprovalForAll(owner, a5, True, transact={"from": a3})) + with tx_failed(): + erc1155.setApprovalForAll(owner, a5, True, transact={"from": a3}) # set approval correctly erc1155.setApprovalForAll(owner, a5, True) @@ -362,7 +366,7 @@ def test_approval_functions(erc1155, w3, assert_tx_failed): # test_mint_batch erc1155.setApprovalForAll(owner, a5, False) -def test_max_batch_size_violation(erc1155, w3, assert_tx_failed): +def test_max_batch_size_violation(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] TOTAL_BAD_BATCH = 200 ids = [] @@ -371,27 +375,29 @@ def test_max_batch_size_violation(erc1155, w3, assert_tx_failed): ids.append(i) amounts.append(1) - assert_tx_failed(lambda: erc1155.mintBatch(a1, ids, amounts, transact={"from": owner})) + with tx_failed(): + erc1155.mintBatch(a1, ids, amounts, transact={"from": owner}) # Transferring back and forth -def test_ownership_functions(erc1155, w3, assert_tx_failed, tester): +def test_ownership_functions(erc1155, w3, tx_failed, tester): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] print(owner, a1, a2) print("___owner___", erc1155.owner()) # change owner from account 0 to account 1 and back assert erc1155.owner() == owner - assert_tx_failed(lambda: erc1155.transferOwnership(a1, transact={"from": a2})) + with tx_failed(): + erc1155.transferOwnership(a1, transact={"from": a2}) # try to transfer ownership to current owner - assert_tx_failed(lambda: erc1155.transferOwnership(owner)) + with tx_failed(): + erc1155.transferOwnership(owner) # try to transfer ownership to ZERO ADDRESS - assert_tx_failed( - lambda: erc1155.transferOwnership("0x0000000000000000000000000000000000000000") - ) + with tx_failed(): + erc1155.transferOwnership("0x0000000000000000000000000000000000000000") # Transfer ownership to account 1 erc1155.transferOwnership(a1, transact={"from": owner}) @@ -399,11 +405,12 @@ def test_ownership_functions(erc1155, w3, assert_tx_failed, tester): assert erc1155.owner() == a1 -def test_renounce_ownership(erc1155, w3, assert_tx_failed): +def test_renounce_ownership(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] assert erc1155.owner() == owner # try to transfer ownership from non-owner account - assert_tx_failed(lambda: erc1155.renounceOwnership(transact={"from": a2})) + with tx_failed(): + erc1155.renounceOwnership(transact={"from": a2}) erc1155.renounceOwnership(transact={"from": owner}) diff --git a/tests/functional/examples/tokens/test_erc20.py b/tests/functional/examples/tokens/test_erc20.py index cba7769bae..ce507f75f8 100644 --- a/tests/functional/examples/tokens/test_erc20.py +++ b/tests/functional/examples/tokens/test_erc20.py @@ -61,7 +61,7 @@ def test_initial_state(c, w3): assert c.allowance(a2, a3) == 0 -def test_mint_and_burn(c, w3, assert_tx_failed): +def test_mint_and_burn(c, w3, tx_failed): minter, a1, a2 = w3.eth.accounts[0:3] # Test scenario were mints 2 to a1, burns twice (check balance consistency) @@ -70,23 +70,30 @@ def test_mint_and_burn(c, w3, assert_tx_failed): assert c.balanceOf(a1) == 2 c.burn(2, transact={"from": a1}) assert c.balanceOf(a1) == 0 - assert_tx_failed(lambda: c.burn(2, transact={"from": a1})) + with tx_failed(): + c.burn(2, transact={"from": a1}) assert c.balanceOf(a1) == 0 # Test scenario were mintes 0 to a2, burns (check balance consistency, false burn) c.mint(a2, 0, transact={"from": minter}) assert c.balanceOf(a2) == 0 - assert_tx_failed(lambda: c.burn(2, transact={"from": a2})) + with tx_failed(): + c.burn(2, transact={"from": a2}) # Check that a1 cannot burn after depleting their balance - assert_tx_failed(lambda: c.burn(1, transact={"from": a1})) + with tx_failed(): + c.burn(1, transact={"from": a1}) # Check that a1, a2 cannot mint - assert_tx_failed(lambda: c.mint(a1, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.mint(a2, 1, transact={"from": a2})) + with tx_failed(): + c.mint(a1, 1, transact={"from": a1}) + with tx_failed(): + c.mint(a2, 1, transact={"from": a2}) # Check that mint to ZERO_ADDRESS failed - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, 1, transact={"from": minter})) + with tx_failed(): + c.mint(ZERO_ADDRESS, 1, transact={"from": a1}) + with tx_failed(): + c.mint(ZERO_ADDRESS, 1, transact={"from": minter}) -def test_totalSupply(c, w3, assert_tx_failed): +def test_totalSupply(c, w3, tx_failed): # Test total supply initially, after mint, between two burns, and after failed burn minter, a1 = w3.eth.accounts[0:2] assert c.totalSupply() == 0 @@ -96,40 +103,49 @@ def test_totalSupply(c, w3, assert_tx_failed): assert c.totalSupply() == 1 c.burn(1, transact={"from": a1}) assert c.totalSupply() == 0 - assert_tx_failed(lambda: c.burn(1, transact={"from": a1})) + with tx_failed(): + c.burn(1, transact={"from": a1}) assert c.totalSupply() == 0 # Test that 0-valued mint can't affect supply c.mint(a1, 0, transact={"from": minter}) assert c.totalSupply() == 0 -def test_transfer(c, w3, assert_tx_failed): +def test_transfer(c, w3, tx_failed): minter, a1, a2 = w3.eth.accounts[0:3] - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) c.mint(a1, 2, transact={"from": minter}) c.burn(1, transact={"from": a1}) c.transfer(a2, 1, transact={"from": a1}) - assert_tx_failed(lambda: c.burn(1, transact={"from": a1})) + with tx_failed(): + c.burn(1, transact={"from": a1}) c.burn(1, transact={"from": a2}) - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) # Ensure transfer fails with insufficient balance - assert_tx_failed(lambda: c.transfer(a1, 1, transact={"from": a2})) + with tx_failed(): + c.transfer(a1, 1, transact={"from": a2}) # Ensure 0-transfer always succeeds c.transfer(a1, 0, transact={"from": a2}) -def test_maxInts(c, w3, assert_tx_failed): +def test_maxInts(c, w3, tx_failed): minter, a1, a2 = w3.eth.accounts[0:3] c.mint(a1, MAX_UINT256, transact={"from": minter}) assert c.balanceOf(a1) == MAX_UINT256 - assert_tx_failed(lambda: c.mint(a1, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.mint(a1, MAX_UINT256, transact={"from": a1})) + with tx_failed(): + c.mint(a1, 1, transact={"from": a1}) + with tx_failed(): + c.mint(a1, MAX_UINT256, transact={"from": a1}) # Check that totalSupply cannot overflow, even when mint to other account - assert_tx_failed(lambda: c.mint(a2, 1, transact={"from": minter})) + with tx_failed(): + c.mint(a2, 1, transact={"from": minter}) # Check that corresponding mint is allowed after burn c.burn(1, transact={"from": a1}) c.mint(a2, 1, transact={"from": minter}) - assert_tx_failed(lambda: c.mint(a2, 1, transact={"from": minter})) + with tx_failed(): + c.mint(a2, 1, transact={"from": minter}) c.transfer(a1, 1, transact={"from": a2}) # Assert that after obtaining max number of tokens, a1 can transfer those but no more assert c.balanceOf(a1) == MAX_UINT256 @@ -150,21 +166,24 @@ def test_maxInts(c, w3, assert_tx_failed): assert c.balanceOf(a1) == 0 -def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): +def test_transferFrom_and_Allowance(c, w3, tx_failed): minter, a1, a2, a3 = w3.eth.accounts[0:4] - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) c.mint(a1, 1, transact={"from": minter}) c.mint(a2, 1, transact={"from": minter}) c.burn(1, transact={"from": a1}) # This should fail; no allowance or balance (0 always succeeds) - assert_tx_failed(lambda: c.transferFrom(a1, a3, 1, transact={"from": a2})) + with tx_failed(): + c.transferFrom(a1, a3, 1, transact={"from": a2}) c.transferFrom(a1, a3, 0, transact={"from": a2}) # Correct call to approval should update allowance (but not for reverse pair) c.approve(a2, 1, transact={"from": a1}) assert c.allowance(a1, a2) == 1 assert c.allowance(a2, a1) == 0 # transferFrom should succeed when allowed, fail with wrong sender - assert_tx_failed(lambda: c.transferFrom(a1, a3, 1, transact={"from": a3})) + with tx_failed(): + c.transferFrom(a1, a3, 1, transact={"from": a3}) assert c.balanceOf(a2) == 1 c.approve(a1, 1, transact={"from": a2}) c.transferFrom(a2, a3, 1, transact={"from": a1}) @@ -173,7 +192,8 @@ def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): # transferFrom with no funds should fail despite approval c.approve(a1, 1, transact={"from": a2}) assert c.allowance(a2, a1) == 1 - assert_tx_failed(lambda: c.transferFrom(a2, a3, 1, transact={"from": a1})) + with tx_failed(): + c.transferFrom(a2, a3, 1, transact={"from": a1}) # 0-approve should not change balance or allow transferFrom to change balance c.mint(a2, 1, transact={"from": minter}) assert c.allowance(a2, a1) == 1 @@ -181,7 +201,8 @@ def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): assert c.allowance(a2, a1) == 0 c.approve(a1, 0, transact={"from": a2}) assert c.allowance(a2, a1) == 0 - assert_tx_failed(lambda: c.transferFrom(a2, a3, 1, transact={"from": a1})) + with tx_failed(): + c.transferFrom(a2, a3, 1, transact={"from": a1}) # Test that if non-zero approval exists, 0-approval is NOT required to proceed # a non-conformant implementation is described in countermeasures at # https://docs.google.com/document/d/1YLPtQxZu1UAvO9cZ1O2RPXBbT0mooh4DYKjA_jp-RLM/edit#heading=h.m9fhqynw2xvt @@ -198,21 +219,24 @@ def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): assert c.allowance(a2, a1) == 5 -def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): +def test_burnFrom_and_Allowance(c, w3, tx_failed): minter, a1, a2, a3 = w3.eth.accounts[0:4] - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) c.mint(a1, 1, transact={"from": minter}) c.mint(a2, 1, transact={"from": minter}) c.burn(1, transact={"from": a1}) # This should fail; no allowance or balance (0 always succeeds) - assert_tx_failed(lambda: c.burnFrom(a1, 1, transact={"from": a2})) + with tx_failed(): + c.burnFrom(a1, 1, transact={"from": a2}) c.burnFrom(a1, 0, transact={"from": a2}) # Correct call to approval should update allowance (but not for reverse pair) c.approve(a2, 1, transact={"from": a1}) assert c.allowance(a1, a2) == 1 assert c.allowance(a2, a1) == 0 # transferFrom should succeed when allowed, fail with wrong sender - assert_tx_failed(lambda: c.burnFrom(a2, 1, transact={"from": a3})) + with tx_failed(): + c.burnFrom(a2, 1, transact={"from": a3}) assert c.balanceOf(a2) == 1 c.approve(a1, 1, transact={"from": a2}) c.burnFrom(a2, 1, transact={"from": a1}) @@ -221,7 +245,8 @@ def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): # transferFrom with no funds should fail despite approval c.approve(a1, 1, transact={"from": a2}) assert c.allowance(a2, a1) == 1 - assert_tx_failed(lambda: c.burnFrom(a2, 1, transact={"from": a1})) + with tx_failed(): + c.burnFrom(a2, 1, transact={"from": a1}) # 0-approve should not change balance or allow transferFrom to change balance c.mint(a2, 1, transact={"from": minter}) assert c.allowance(a2, a1) == 1 @@ -229,7 +254,8 @@ def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): assert c.allowance(a2, a1) == 0 c.approve(a1, 0, transact={"from": a2}) assert c.allowance(a2, a1) == 0 - assert_tx_failed(lambda: c.burnFrom(a2, 1, transact={"from": a1})) + with tx_failed(): + c.burnFrom(a2, 1, transact={"from": a1}) # Test that if non-zero approval exists, 0-approval is NOT required to proceed # a non-conformant implementation is described in countermeasures at # https://docs.google.com/document/d/1YLPtQxZu1UAvO9cZ1O2RPXBbT0mooh4DYKjA_jp-RLM/edit#heading=h.m9fhqynw2xvt @@ -245,7 +271,8 @@ def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): c.approve(a1, 5, transact={"from": a2}) assert c.allowance(a2, a1) == 5 # Check that burnFrom to ZERO_ADDRESS failed - assert_tx_failed(lambda: c.burnFrom(ZERO_ADDRESS, 0, transact={"from": a1})) + with tx_failed(): + c.burnFrom(ZERO_ADDRESS, 0, transact={"from": a1}) def test_raw_logs(c, w3, get_log_args): @@ -307,33 +334,36 @@ def test_raw_logs(c, w3, get_log_args): assert args.value == 0 -def test_bad_transfer(c_bad, w3, assert_tx_failed): +def test_bad_transfer(c_bad, w3, tx_failed): # Ensure transfer fails if it would otherwise overflow balance when totalSupply is corrupted minter, a1, a2 = w3.eth.accounts[0:3] c_bad.mint(a1, MAX_UINT256, transact={"from": minter}) c_bad.mint(a2, 1, transact={"from": minter}) - assert_tx_failed(lambda: c_bad.transfer(a1, 1, transact={"from": a2})) + with tx_failed(): + c_bad.transfer(a1, 1, transact={"from": a2}) c_bad.transfer(a2, MAX_UINT256 - 1, transact={"from": a1}) assert c_bad.balanceOf(a1) == 1 assert c_bad.balanceOf(a2) == MAX_UINT256 -def test_bad_burn(c_bad, w3, assert_tx_failed): +def test_bad_burn(c_bad, w3, tx_failed): # Ensure burn fails if it would otherwise underflow balance when totalSupply is corrupted minter, a1 = w3.eth.accounts[0:2] assert c_bad.balanceOf(a1) == 0 c_bad.mint(a1, 2, transact={"from": minter}) assert c_bad.balanceOf(a1) == 2 - assert_tx_failed(lambda: c_bad.burn(3, transact={"from": a1})) + with tx_failed(): + c_bad.burn(3, transact={"from": a1}) -def test_bad_transferFrom(c_bad, w3, assert_tx_failed): +def test_bad_transferFrom(c_bad, w3, tx_failed): # Ensure transferFrom fails if it would otherwise overflow balance when totalSupply is corrupted minter, a1, a2 = w3.eth.accounts[0:3] c_bad.mint(a1, MAX_UINT256, transact={"from": minter}) c_bad.mint(a2, 1, transact={"from": minter}) c_bad.approve(a1, 1, transact={"from": a2}) - assert_tx_failed(lambda: c_bad.transferFrom(a2, a1, 1, transact={"from": a1})) + with tx_failed(): + c_bad.transferFrom(a2, a1, 1, transact={"from": a1}) c_bad.approve(a2, MAX_UINT256 - 1, transact={"from": a1}) assert c_bad.allowance(a1, a2) == MAX_UINT256 - 1 c_bad.transferFrom(a1, a2, MAX_UINT256 - 1, transact={"from": a2}) diff --git a/tests/functional/examples/tokens/test_erc721.py b/tests/functional/examples/tokens/test_erc721.py index ab3c6368c5..c881149baa 100644 --- a/tests/functional/examples/tokens/test_erc721.py +++ b/tests/functional/examples/tokens/test_erc721.py @@ -40,16 +40,18 @@ def test_erc165(w3, c): assert c.supportsInterface(ERC721_SIG) -def test_balanceOf(c, w3, assert_tx_failed): +def test_balanceOf(c, w3, tx_failed): someone = w3.eth.accounts[1] assert c.balanceOf(someone) == 3 - assert_tx_failed(lambda: c.balanceOf(ZERO_ADDRESS)) + with tx_failed(): + c.balanceOf(ZERO_ADDRESS) -def test_ownerOf(c, w3, assert_tx_failed): +def test_ownerOf(c, w3, tx_failed): someone = w3.eth.accounts[1] assert c.ownerOf(SOMEONE_TOKEN_IDS[0]) == someone - assert_tx_failed(lambda: c.ownerOf(INVALID_TOKEN_ID)) + with tx_failed(): + c.ownerOf(INVALID_TOKEN_ID) def test_getApproved(c, w3): @@ -72,32 +74,24 @@ def test_isApprovedForAll(c, w3): assert c.isApprovedForAll(someone, operator) == 1 -def test_transferFrom_by_owner(c, w3, assert_tx_failed, get_logs): +def test_transferFrom_by_owner(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # transfer from zero address - assert_tx_failed( - lambda: c.transferFrom( - ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.transferFrom(ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer to zero address - assert_tx_failed( - lambda: c.transferFrom( - someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.transferFrom(someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer token without ownership - assert_tx_failed( - lambda: c.transferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.transferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) # transfer invalid token - assert_tx_failed( - lambda: c.transferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.transferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) # transfer by owner tx_hash = c.transferFrom(someone, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) @@ -152,32 +146,24 @@ def test_transferFrom_by_operator(c, w3, get_logs): assert c.balanceOf(operator) == 2 -def test_safeTransferFrom_by_owner(c, w3, assert_tx_failed, get_logs): +def test_safeTransferFrom_by_owner(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # transfer from zero address - assert_tx_failed( - lambda: c.safeTransferFrom( - ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.safeTransferFrom(ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer to zero address - assert_tx_failed( - lambda: c.safeTransferFrom( - someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.safeTransferFrom(someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer token without ownership - assert_tx_failed( - lambda: c.safeTransferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.safeTransferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) # transfer invalid token - assert_tx_failed( - lambda: c.safeTransferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.safeTransferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) # transfer by owner tx_hash = c.safeTransferFrom( @@ -238,15 +224,12 @@ def test_safeTransferFrom_by_operator(c, w3, get_logs): assert c.balanceOf(operator) == 2 -def test_safeTransferFrom_to_contract(c, w3, assert_tx_failed, get_logs, get_contract): +def test_safeTransferFrom_to_contract(c, w3, tx_failed, get_logs, get_contract): someone = w3.eth.accounts[1] # Can't transfer to a contract that doesn't implement the receiver code - assert_tx_failed( - lambda: c.safeTransferFrom( - someone, c.address, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) # noqa: E501 + with tx_failed(): + c.safeTransferFrom(someone, c.address, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # Only to an address that implements that function receiver = get_contract( @@ -277,17 +260,20 @@ def onERC721Received( assert c.balanceOf(receiver.address) == 1 -def test_approve(c, w3, assert_tx_failed, get_logs): +def test_approve(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # approve myself - assert_tx_failed(lambda: c.approve(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone})) + with tx_failed(): + c.approve(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # approve token without ownership - assert_tx_failed(lambda: c.approve(operator, OPERATOR_TOKEN_ID, transact={"from": someone})) + with tx_failed(): + c.approve(operator, OPERATOR_TOKEN_ID, transact={"from": someone}) # approve invalid token - assert_tx_failed(lambda: c.approve(operator, INVALID_TOKEN_ID, transact={"from": someone})) + with tx_failed(): + c.approve(operator, INVALID_TOKEN_ID, transact={"from": someone}) tx_hash = c.approve(operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) logs = get_logs(tx_hash, c, "Approval") @@ -299,12 +285,13 @@ def test_approve(c, w3, assert_tx_failed, get_logs): assert args.tokenId == SOMEONE_TOKEN_IDS[0] -def test_setApprovalForAll(c, w3, assert_tx_failed, get_logs): +def test_setApprovalForAll(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] approved = True # setApprovalForAll myself - assert_tx_failed(lambda: c.setApprovalForAll(someone, approved, transact={"from": someone})) + with tx_failed(): + c.setApprovalForAll(someone, approved, transact={"from": someone}) tx_hash = c.setApprovalForAll(operator, approved, transact={"from": someone}) logs = get_logs(tx_hash, c, "ApprovalForAll") @@ -316,14 +303,16 @@ def test_setApprovalForAll(c, w3, assert_tx_failed, get_logs): assert args.approved == approved -def test_mint(c, w3, assert_tx_failed, get_logs): +def test_mint(c, w3, tx_failed, get_logs): minter, someone = w3.eth.accounts[:2] # mint by non-minter - assert_tx_failed(lambda: c.mint(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone})) + with tx_failed(): + c.mint(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # mint to zero address - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": minter})) + with tx_failed(): + c.mint(ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": minter}) # mint by minter tx_hash = c.mint(someone, NEW_TOKEN_ID, transact={"from": minter}) @@ -338,11 +327,12 @@ def test_mint(c, w3, assert_tx_failed, get_logs): assert c.balanceOf(someone) == 4 -def test_burn(c, w3, assert_tx_failed, get_logs): +def test_burn(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # burn token without ownership - assert_tx_failed(lambda: c.burn(SOMEONE_TOKEN_IDS[0], transact={"from": operator})) + with tx_failed(): + c.burn(SOMEONE_TOKEN_IDS[0], transact={"from": operator}) # burn token by owner tx_hash = c.burn(SOMEONE_TOKEN_IDS[0], transact={"from": someone}) @@ -353,5 +343,6 @@ def test_burn(c, w3, assert_tx_failed, get_logs): assert args.sender == someone assert args.receiver == ZERO_ADDRESS assert args.tokenId == SOMEONE_TOKEN_IDS[0] - assert_tx_failed(lambda: c.ownerOf(SOMEONE_TOKEN_IDS[0])) + with tx_failed(): + c.ownerOf(SOMEONE_TOKEN_IDS[0]) assert c.balanceOf(someone) == 2 diff --git a/tests/functional/examples/voting/test_ballot.py b/tests/functional/examples/voting/test_ballot.py index 4207fe6e4e..9c3a09fc83 100644 --- a/tests/functional/examples/voting/test_ballot.py +++ b/tests/functional/examples/voting/test_ballot.py @@ -33,7 +33,7 @@ def test_initial_state(w3, c): assert c.voters(z0)[0] == 0 # Voter.weight -def test_give_the_right_to_vote(w3, c, assert_tx_failed): +def test_give_the_right_to_vote(w3, c, tx_failed): a0, a1, a2, a3, a4, a5 = w3.eth.accounts[:6] c.giveRightToVote(a1, transact={}) # Check voter given right has weight of 1 @@ -56,7 +56,8 @@ def test_give_the_right_to_vote(w3, c, assert_tx_failed): # Check voter_acount is now 6 assert c.voterCount() == 6 # Check chairperson cannot give the right to vote twice to the same voter - assert_tx_failed(lambda: c.giveRightToVote(a5, transact={})) + with tx_failed(): + c.giveRightToVote(a5, transact={}) # Check voters weight didn't change assert c.voters(a5)[0] == 1 # Voter.weight @@ -127,7 +128,7 @@ def test_forward_weight(w3, c): assert c.voters(a9)[0] == 10 # Voter.weight -def test_block_short_cycle(w3, c, assert_tx_failed): +def test_block_short_cycle(w3, c, tx_failed): a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = w3.eth.accounts[:10] c.giveRightToVote(a0, transact={}) c.giveRightToVote(a1, transact={}) @@ -141,7 +142,8 @@ def test_block_short_cycle(w3, c, assert_tx_failed): c.delegate(a3, transact={"from": a2}) c.delegate(a4, transact={"from": a3}) # would create a length 5 cycle: - assert_tx_failed(lambda: c.delegate(a0, transact={"from": a4})) + with tx_failed(): + c.delegate(a0, transact={"from": a4}) c.delegate(a5, transact={"from": a4}) # can't detect length 6 cycle, so this works: @@ -150,7 +152,7 @@ def test_block_short_cycle(w3, c, assert_tx_failed): # but this is something the frontend should prevent for user friendliness -def test_delegate(w3, c, assert_tx_failed): +def test_delegate(w3, c, tx_failed): a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] c.giveRightToVote(a0, transact={}) c.giveRightToVote(a1, transact={}) @@ -167,9 +169,11 @@ def test_delegate(w3, c, assert_tx_failed): # Delegate's weight is 2 assert c.voters(a0)[0] == 2 # Voter.weight # Voter cannot delegate twice - assert_tx_failed(lambda: c.delegate(a2, transact={"from": a1})) + with tx_failed(): + c.delegate(a2, transact={"from": a1}) # Voter cannot delegate to themselves - assert_tx_failed(lambda: c.delegate(a2, transact={"from": a2})) + with tx_failed(): + c.delegate(a2, transact={"from": a2}) # Voter CAN delegate to someone who hasn't been granted right to vote # Exercise: prevent that c.delegate(a6, transact={"from": a2}) @@ -180,7 +184,7 @@ def test_delegate(w3, c, assert_tx_failed): assert c.voters(a0)[0] == 3 # Voter.weight -def test_vote(w3, c, assert_tx_failed): +def test_vote(w3, c, tx_failed): a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = w3.eth.accounts[:10] c.giveRightToVote(a0, transact={}) c.giveRightToVote(a1, transact={}) @@ -197,9 +201,11 @@ def test_vote(w3, c, assert_tx_failed): # Vote count changes based on voters weight assert c.proposals(0)[1] == 3 # Proposal.voteCount # Voter cannot vote twice - assert_tx_failed(lambda: c.vote(0)) + with tx_failed(): + c.vote(0) # Voter cannot vote if they've delegated - assert_tx_failed(lambda: c.vote(0, transact={"from": a1})) + with tx_failed(): + c.vote(0, transact={"from": a1}) # Several voters can vote c.vote(1, transact={"from": a4}) c.vote(1, transact={"from": a2}) @@ -207,7 +213,8 @@ def test_vote(w3, c, assert_tx_failed): c.vote(1, transact={"from": a6}) assert c.proposals(1)[1] == 4 # Proposal.voteCount # Can't vote on a non-proposal - assert_tx_failed(lambda: c.vote(2, transact={"from": a7})) + with tx_failed(): + c.vote(2, transact={"from": a7}) def test_winning_proposal(w3, c): diff --git a/tests/functional/examples/wallet/test_wallet.py b/tests/functional/examples/wallet/test_wallet.py index 71f1e5f331..b9db5acee3 100644 --- a/tests/functional/examples/wallet/test_wallet.py +++ b/tests/functional/examples/wallet/test_wallet.py @@ -29,7 +29,7 @@ def _sign(seq, to, value, data, key): return _sign -def test_approve(w3, c, tester, assert_tx_failed, sign): +def test_approve(w3, c, tester, tx_failed, sign): a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] k0, k1, k2, k3, k4, k5, k6, k7 = tester.backend.account_keys[:8] @@ -45,24 +45,20 @@ def pack_and_sign(seq, *args): c.approve(0, "0x" + to.hex(), value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if only 2 signatures are given sigs = pack_and_sign(1, k1, 0, k3, 0, 0) - assert_tx_failed( - lambda: c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if an invalid signature is given sigs = pack_and_sign(1, k1, 0, k7, 0, k5) - assert_tx_failed( - lambda: c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if transaction number is incorrect (the first argument should be 1) sigs = pack_and_sign(0, k1, 0, k3, 0, k5) - assert_tx_failed( - lambda: c.approve(0, to_address, value, data, sigs, transact={"value": value, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(0, to_address, value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if not enough value is sent sigs = pack_and_sign(1, k1, 0, k3, 0, k5) - assert_tx_failed( - lambda: c.approve(1, to_address, value, data, sigs, transact={"value": 0, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(1, to_address, value, data, sigs, transact={"value": 0, "from": a1}) sigs = pack_and_sign(1, k1, 0, k3, 0, k5) # this call should succeed diff --git a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py b/tests/unit/ast/nodes/test_evaluate_binop_decimal.py index 5c9956caba..44b82e321d 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/unit/ast/nodes/test_evaluate_binop_decimal.py @@ -20,7 +20,7 @@ @example(left=Decimal("0.9999999999"), right=Decimal("0.9999999999")) @example(left=Decimal("0.0000000001"), right=Decimal("0.0000000001")) @pytest.mark.parametrize("op", "+-*/%") -def test_binop_decimal(get_contract, assert_tx_failed, op, left, right): +def test_binop_decimal(get_contract, tx_failed, op, left, right): source = f""" @external def foo(a: decimal, b: decimal) -> decimal: @@ -39,7 +39,8 @@ def foo(a: decimal, b: decimal) -> decimal: if is_valid: assert contract.foo(left, right) == new_node.value else: - assert_tx_failed(lambda: contract.foo(left, right)) + with tx_failed(): + contract.foo(left, right) def test_binop_pow(): @@ -57,7 +58,7 @@ def test_binop_pow(): values=st.lists(st_decimals, min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), ) -def test_nested(get_contract, assert_tx_failed, values, ops): +def test_nested(get_contract, tx_failed, values, ops): variables = "abcdefghij" input_value = ",".join(f"{i}: decimal" for i in variables[: len(values)]) return_value = " ".join(f"{a} {b}" for a, b in zip(variables[: len(values)], ops)) @@ -83,4 +84,5 @@ def foo({input_value}) -> decimal: if is_valid: assert contract.foo(*values) == expected else: - assert_tx_failed(lambda: contract.foo(*values)) + with tx_failed(): + contract.foo(*values) diff --git a/tests/unit/ast/nodes/test_evaluate_binop_int.py b/tests/unit/ast/nodes/test_evaluate_binop_int.py index 80c9381c0f..405d557f7d 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_int.py +++ b/tests/unit/ast/nodes/test_evaluate_binop_int.py @@ -16,7 +16,7 @@ @example(left=-1, right=1) @example(left=-1, right=-1) @pytest.mark.parametrize("op", "+-*/%") -def test_binop_int128(get_contract, assert_tx_failed, op, left, right): +def test_binop_int128(get_contract, tx_failed, op, left, right): source = f""" @external def foo(a: int128, b: int128) -> int128: @@ -35,7 +35,8 @@ def foo(a: int128, b: int128) -> int128: if is_valid: assert contract.foo(left, right) == new_node.value else: - assert_tx_failed(lambda: contract.foo(left, right)) + with tx_failed(): + contract.foo(left, right) st_uint64 = st.integers(min_value=0, max_value=2**64) @@ -45,7 +46,7 @@ def foo(a: int128, b: int128) -> int128: @settings(max_examples=50) @given(left=st_uint64, right=st_uint64) @pytest.mark.parametrize("op", "+-*/%") -def test_binop_uint256(get_contract, assert_tx_failed, op, left, right): +def test_binop_uint256(get_contract, tx_failed, op, left, right): source = f""" @external def foo(a: uint256, b: uint256) -> uint256: @@ -64,7 +65,8 @@ def foo(a: uint256, b: uint256) -> uint256: if is_valid: assert contract.foo(left, right) == new_node.value else: - assert_tx_failed(lambda: contract.foo(left, right)) + with tx_failed(): + contract.foo(left, right) @pytest.mark.xfail(reason="need to implement safe exponentiation logic") @@ -94,7 +96,7 @@ def foo(a: uint256, b: uint256) -> uint256: values=st.lists(st.integers(min_value=-256, max_value=256), min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), ) -def test_binop_nested(get_contract, assert_tx_failed, values, ops): +def test_binop_nested(get_contract, tx_failed, values, ops): variables = "abcdefghij" input_value = ",".join(f"{i}: int128" for i in variables[: len(values)]) return_value = " ".join(f"{a} {b}" for a, b in zip(variables[: len(values)], ops)) @@ -122,4 +124,5 @@ def foo({input_value}) -> int128: if is_valid: assert contract.foo(*values) == expected else: - assert_tx_failed(lambda: contract.foo(*values)) + with tx_failed(): + contract.foo(*values) From 5319cfbe14951e007ccdb323257e5ada869b35d5 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Sun, 24 Dec 2023 17:10:45 +0100 Subject: [PATCH 147/201] feat: allow `range(x, y, bound=N)` (#3679) - allow range where both start and end arguments are variables, so long as a bound is supplied - ban range expressions of the form `range(x, x + N)` since the new form is cleaner and supersedes it. - also do a bit of refactoring of the codegen for range --------- Co-authored-by: Charles Cooper --- docs/control-structures.rst | 8 +- .../features/iteration/test_for_in_list.py | 19 +- .../features/iteration/test_for_range.py | 116 ++++++++++- .../codegen/integration/test_crowdfund.py | 4 +- .../test_invalid_literal_exception.py | 7 - tests/functional/syntax/test_for_range.py | 197 +++++++++++++++++- vyper/codegen/ir_node.py | 8 +- vyper/codegen/stmt.py | 67 +++--- vyper/exceptions.py | 2 +- vyper/semantics/analysis/local.py | 109 ++++------ 10 files changed, 390 insertions(+), 147 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index 873135709a..2f890bcb2f 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -287,9 +287,11 @@ Another use of range can be with ``START`` and ``STOP`` bounds. Here, ``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``. +Finally, it is possible to use ``range`` with runtime `start` and `stop` values as long as a constant `bound` value is provided. +In this case, Vyper checks at runtime that `end - start <= bound`. +``N`` must be a compile-time constant. + .. code-block:: python - for i in range(a, a + N): + for i in range(start, end, bound=N): ... - -``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. If ``a + N`` would overflow, execution will revert. diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index fb01cc98eb..bc1a12ae9e 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -1,3 +1,4 @@ +import re from decimal import Decimal import pytest @@ -700,13 +701,16 @@ def foo(): """, StateAccessViolation, ), - """ + ( + """ @external def foo(): a: int128 = 6 for i in range(a,a-3): pass """, + StateAccessViolation, + ), # invalid argument length ( """ @@ -789,10 +793,13 @@ def test_for() -> int128: ), ] +BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] +for_code_regex = re.compile(r"for .+ in (.*):") +bad_code_names = [ + f"{i} {for_code_regex.search(code).group(1)}" for i, (code, _) in enumerate(BAD_CODE) +] + -@pytest.mark.parametrize("code", BAD_CODE) -def test_bad_code(assert_compile_failed, get_contract, code): - err = StructureException - if not isinstance(code, str): - code, err = code +@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) +def test_bad_code(assert_compile_failed, get_contract, code, err): assert_compile_failed(lambda: get_contract(code), err) diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index 96b83ae691..e946447285 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -32,6 +32,102 @@ def repeat(n: uint256) -> uint256: c.repeat(7) +def test_range_bound_constant_end(get_contract, tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(n, 7, bound=6): + x += i + 1 + return x + """ + c = get_contract(code) + for n in range(1, 5): + assert c.repeat(n) == sum(i + 1 for i in range(n, 7)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(8) + # check assertion for `start + bound <= end` + with tx_failed(): + c.repeat(0) + + +def test_range_bound_two_args(get_contract, tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(1, n, bound=6): + x += i + 1 + return x + """ + c = get_contract(code) + for n in range(1, 8): + assert c.repeat(n) == sum(i + 1 for i in range(1, n)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(0) + + # check codegen inserts assertion for `start + bound <= end` + with tx_failed(): + c.repeat(8) + + +def test_range_bound_two_runtime_args(get_contract, tx_failed): + code = """ +@external +def repeat(start: uint256, end: uint256) -> uint256: + x: uint256 = 0 + for i in range(start, end, bound=6): + x += i + return x + """ + c = get_contract(code) + for n in range(0, 7): + assert c.repeat(0, n) == sum(range(0, n)) + assert c.repeat(n, n * 2) == sum(range(n, n * 2)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(1, 0) + with tx_failed(): + c.repeat(7, 0) + with tx_failed(): + c.repeat(8, 7) + + # check codegen inserts assertion for `start + bound <= end` + with tx_failed(): + c.repeat(0, 7) + with tx_failed(): + c.repeat(14, 21) + + +def test_range_overflow(get_contract, tx_failed): + code = """ +@external +def get_last(start: uint256, end: uint256) -> uint256: + x: uint256 = 0 + for i in range(start, end, bound=6): + x = i + return x + """ + c = get_contract(code) + UINT_MAX = 2**256 - 1 + assert c.get_last(UINT_MAX, UINT_MAX) == 0 # initial value of x + + for n in range(1, 6): + assert c.get_last(UINT_MAX - n, UINT_MAX) == UINT_MAX - 1 + + # check for `start + bound <= end`, overflow cases + for n in range(1, 7): + with tx_failed(): + c.get_last(UINT_MAX - n, 0) + with tx_failed(): + c.get_last(UINT_MAX, UINT_MAX - n) + + def test_digit_reverser(get_contract_with_gas_estimation): digit_reverser = """ @external @@ -89,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101): + for i in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -146,26 +242,28 @@ def foo(a: {typ}) -> {typ}: assert c.foo(100) == 31337 -# test that we can get to the upper range of an integer @pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"]) def test_for_range_edge(get_contract, typ): + """ + Check that we can get to the upper range of an integer. + Note that to avoid overflow in the bounds check for range(), + we need to calculate i+1 inside the loop. + """ code = f""" @external def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x, x + 1): - if i == max_value({typ}): + for i in range(x - 1, x, bound=1): + if i + 1 == max_value({typ}): found = True - assert found found = False x = max_value({typ}) - 1 - for i in range(x, x + 2): - if i == max_value({typ}): + for i in range(x - 1, x + 1, bound=2): + if i + 1 == max_value({typ}): found = True - assert found """ c = get_contract(code) @@ -178,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x+2): + for i in range(x, x + 2, bound=2): pass """ c = get_contract(code) diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 2083e62610..671d424d60 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30): + for i in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30): + for i in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py index 1f4f112252..a0cf10ad02 100644 --- a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py @@ -18,13 +18,6 @@ def foo(): """, """ @external -def foo(x: int128): - y: int128 = 7 - for i in range(x, x + y): - pass - """, - """ -@external def foo(): x: String[100] = "these bytes are nо gооd because the o's are from the Russian alphabet" """, diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index e6f35c1d2d..7c7f9c476d 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -1,7 +1,9 @@ +import re + import pytest from vyper import compiler -from vyper.exceptions import StructureException +from vyper.exceptions import ArgumentException, StateAccessViolation, StructureException fail_list = [ ( @@ -12,33 +14,191 @@ def foo(): pass """, StructureException, + "Invalid syntax for loop iterator", + "a[1]", + ), + ( + """ +@external +def foo(): + x: uint256 = 100 + for _ in range(10, bound=x): + pass + """, + StateAccessViolation, + "Bound must be a literal", + "x", + ), + ( + """ +@external +def foo(): + for _ in range(10, 20, bound=5): + pass + """, + StructureException, + "Please remove the `bound=` kwarg when using range with constants", + "5", + ), + ( + """ +@external +def foo(): + for _ in range(10, 20, bound=0): + pass + """, + StructureException, + "Bound must be at least 1", + "0", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x,x+1,bound=2,extra=3): + pass + """, + ArgumentException, + "Invalid keyword argument 'extra'", + "extra=3", ), ( """ @external def bar(): - for i in range(1,2,bound=2): + for i in range(0): pass """, StructureException, + "End must be greater than start", + "0", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(0, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def repeat(n: uint256) -> uint256: + for i in range(0, n * 10): + pass + return n + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "n * 10", ), ( """ @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2): + for i in range(0, x + 1): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x + 1", + ), + ( + """ +@external +def bar(): + for i in range(2, 1): pass """, StructureException, + "End must be greater than start", + "1", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def foo(): + x: int128 = 5 + for i in range(x, x + 10): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def repeat(n: uint256) -> uint256: + for i in range(n, 6): + pass + return x + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "n", + ), + ( + """ +@external +def foo(x: int128): + y: int128 = 7 + for i in range(x, x + y): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", ), ] +for_code_regex = re.compile(r"for .+ in (.*):") +fail_test_names = [ + ( + f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] + f" raises {type(err).__name__}" + ) + for i, (code, err, msg, src) in enumerate(fail_list) +] -@pytest.mark.parametrize("bad_code", fail_list) -def test_range_fail(bad_code): - with pytest.raises(bad_code[1]): - compiler.compile_code(bad_code[0]) + +@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names) +def test_range_fail(bad_code, error_type, message, source_code): + with pytest.raises(error_type) as exc_info: + compiler.compile_code(bad_code) + assert message == exc_info.value.message + assert source_code == exc_info.value.args[1].node_source_code valid_list = [ @@ -58,7 +218,21 @@ def foo(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i in range(1, x, bound=4): + pass + """, + """ +@external +def foo(): + x: int128 = 5 + for i in range(x, bound=4): + pass + """, + """ +@external +def foo(): + x: int128 = 5 + for i in range(0, x, bound=4): pass """, """ @@ -72,7 +246,12 @@ def kick_foos(): """, ] +valid_test_names = [ + f"{i} {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] + for i, code in enumerate(valid_list) +] + -@pytest.mark.parametrize("good_code", valid_list) +@pytest.mark.parametrize("good_code", valid_list, ids=valid_test_names) def test_range_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index ce26066968..45d93f3067 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -444,11 +444,15 @@ def unique_symbols(self): return ret @property - def is_literal(self): + def is_literal(self) -> bool: return isinstance(self.value, int) or self.value == "multi" + def int_value(self) -> int: + assert isinstance(self.value, int) + return self.value + @property - def is_pointer(self): + def is_pointer(self) -> bool: # not used yet but should help refactor/clarify downstream code # eventually return self.location is not None diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 601597771c..18e5c3d494 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -225,15 +225,6 @@ def parse_Raise(self): else: return IRnode.from_list(["revert", 0, 0], error_msg="user raise") - def _check_valid_range_constant(self, arg_ast_node): - with self.context.range_scope(): - arg_expr = Expr.parse_value_expr(arg_ast_node, self.context) - return arg_expr - - def _get_range_const_value(self, arg_ast_node): - arg_expr = self._check_valid_range_constant(arg_ast_node) - return arg_expr.value - def parse_For(self): with self.context.block_scope(): if self.stmt.get("iter.func.id") == "range": @@ -249,41 +240,37 @@ def _parse_For_range(self): iter_typ = INT256_T # Get arg0 - arg0 = self.stmt.iter.args[0] - num_of_args = len(self.stmt.iter.args) - - kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) - for s in self.stmt.iter.keywords or [] - } - - # Type 1 for, e.g. for i in range(10): ... - if num_of_args == 1: - n = Expr.parse_value_expr(arg0, self.context) - start = IRnode.from_list(0, typ=iter_typ) - rounds = n - rounds_bound = kwargs.get("bound", rounds) - - # Type 2 for, e.g. for i in range(100, 110): ... - elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal: - arg0_val = self._get_range_const_value(arg0) - arg1_val = self._get_range_const_value(self.stmt.iter.args[1]) - start = IRnode.from_list(arg0_val, typ=iter_typ) - rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ) - rounds_bound = rounds + for_iter: vy_ast.Call = self.stmt.iter + args_len = len(for_iter.args) + if args_len == 1: + arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + elif args_len == 2: + arg0, arg1 = for_iter.args + else: # pragma: nocover + raise TypeCheckFailure("unreachable: bad # of arguments to range()") - # Type 3 for, e.g. for i in range(x, x + 10): ... - else: - arg1 = self.stmt.iter.args[1] - rounds = self._get_range_const_value(arg1.right) + with self.context.range_scope(): start = Expr.parse_value_expr(arg0, self.context) - _, hi = start.typ.int_bounds - start = clamp("le", start, hi + 1 - rounds) + end = Expr.parse_value_expr(arg1, self.context) + kwargs = { + s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + } + + if "bound" in kwargs: + with end.cache_when_complex("end") as (b1, end): + # note: the check for rounds<=rounds_bound happens in asm + # generation for `repeat`. + clamped_start = clamp("le", start, end) + rounds = b1.resolve(IRnode.from_list(["sub", end, clamped_start])) + rounds_bound = kwargs.pop("bound").int_value() + else: + rounds = end.int_value() - start.int_value() rounds_bound = rounds - bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value - if bound < 1: - return + assert len(kwargs) == 0 # sanity check stray keywords + + if rounds_bound < 1: # pragma: nocover + raise TypeCheckFailure("unreachable: unchecked 0 bound") varname = self.stmt.target.id i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 8f72d9afc9..8921814188 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -41,7 +41,7 @@ def __init__(self, message="Error Message not found.", *items): Error message to display with the exception. *items : VyperNode | Tuple[str, VyperNode], optional Vyper ast node(s), or tuple of (description, node) indicating where - the exception occured. Source annotations are generated in the order + the exception occurred. Source annotations are generated in the order the nodes are given. A single tuple of (lineno, col_offset) is also understood to support diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 2a84f69ad4..a3ebf85fa2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -7,7 +7,6 @@ ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidLiteral, InvalidOperation, InvalidType, IteratorException, @@ -355,71 +354,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - range_ = node.iter - validate_call_args(range_, (1, 2), kwargs=["bound"]) - - args = range_.args - kwargs = {s.arg: s.value for s in range_.keywords or []} - if len(args) == 1: - # range(CONSTANT) - n = args[0] - bound = kwargs.pop("bound", None) - validate_expected_type(n, IntegerT.any()) - - if bound is None: - if not isinstance(n, vy_ast.Num): - raise StateAccessViolation("Value must be a literal", n) - if n.value <= 0: - raise StructureException("For loop must have at least 1 iteration", args[0]) - type_list = get_possible_types_from_node(n) - - else: - if not isinstance(bound, vy_ast.Num): - raise StateAccessViolation("bound must be a literal", bound) - if bound.value <= 0: - raise StructureException("bound must be at least 1", args[0]) - type_list = get_common_types(n, bound) - - else: - if range_.keywords: - raise StructureException( - "Keyword arguments are not supported for `range(N, M)` and" - "`range(x, x + N)` expressions", - range_.keywords[0], - ) - - validate_expected_type(args[0], IntegerT.any()) - type_list = get_common_types(*args) - if not isinstance(args[0], vy_ast.Constant): - # range(x, x + CONSTANT) - if not isinstance(args[1], vy_ast.BinOp) or not isinstance( - args[1].op, vy_ast.Add - ): - raise StructureException( - "Second element must be the first element plus a literal value", args[0] - ) - if not vy_ast.compare_nodes(args[0], args[1].left): - raise StructureException( - "First and second variable must be the same", args[1].left - ) - if not isinstance(args[1].right, vy_ast.Int): - raise InvalidLiteral("Literal must be an integer", args[1].right) - if args[1].right.value < 1: - raise StructureException( - f"For loop has invalid number of iterations ({args[1].right.value})," - " the value must be greater than zero", - args[1].right, - ) - else: - # range(CONSTANT, CONSTANT) - if not isinstance(args[1], vy_ast.Int): - raise InvalidType("Value must be a literal integer", args[1]) - validate_expected_type(args[1], IntegerT.any()) - if args[0].value >= args[1].value: - raise StructureException("Second value must be > first value", args[1]) - - if not type_list: - raise TypeMismatch("Iterator values are of different types", node.iter) + type_list = _analyse_range_call(node.iter) else: # iteration over a variable or literal list @@ -490,8 +425,8 @@ def visit_For(self, node): try: with NodeMetadata.enter_typechecker_speculation(): - for n in node.body: - self.visit(n) + for stmt in node.body: + self.visit(stmt) except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: @@ -801,3 +736,41 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.body, typ) validate_expected_type(node.orelse, typ) self.visit(node.orelse, typ) + + +def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: + """ + Check that the arguments to a range() call are valid. + :param node: call to range() + :return: None + """ + validate_call_args(node, (1, 2), kwargs=["bound"]) + kwargs = {s.arg: s.value for s in node.keywords or []} + start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args + + all_args = (start, end, *kwargs.values()) + for arg1 in all_args: + validate_expected_type(arg1, IntegerT.any()) + + type_list = get_common_types(*all_args) + if not type_list: + raise TypeMismatch("Iterator values are of different types", node) + + if "bound" in kwargs: + bound = kwargs["bound"] + if not isinstance(bound, vy_ast.Num): + raise StateAccessViolation("Bound must be a literal", bound) + if bound.value <= 0: + raise StructureException("Bound must be at least 1", bound) + if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num): + error = "Please remove the `bound=` kwarg when using range with constants" + raise StructureException(error, bound) + else: + for arg in (start, end): + if not isinstance(arg, vy_ast.Num): + error = "Value must be a literal integer, unless a bound is specified" + raise StateAccessViolation(error, arg) + if end.value <= start.value: + raise StructureException("End must be greater than start", end) + + return type_list From 2df916ca453ddda4cf08878c4fdaaa55b963686c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 25 Dec 2023 09:21:38 -0500 Subject: [PATCH 148/201] feat: improve panics in IR generation (#3708) * feat: improve panics in IR generation this QOL commit improves on 91659266c55a by passing through the `__traceback__` field when the exception is modified (instead of using `__cause__` - cf. PEP-3134 regarding the difference) and improves error messages when an IRnode is not returned properly. using `__traceback__` generally results in a better experience because the immediate cause of the exception is displayed when running `vyper -v` instead of needing to scroll up through the exception chain (if the exception chain is reproduced correctly at all in the first place). --------- Co-authored-by: Harry Kalogirou --- vyper/codegen/expr.py | 30 ++++++++++++++---------------- vyper/codegen/stmt.py | 14 ++++++-------- vyper/exceptions.py | 13 ++++++++----- 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 693d5c2aad..4c7c3afaed 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -73,19 +73,17 @@ def __init__(self, node, context): self.context = context if isinstance(node, IRnode): - # TODO this seems bad + # this is a kludge for parse_AugAssign to pass in IRnodes + # directly. + # TODO fixme! self.ir_node = node return - fn = getattr(self, f"parse_{type(node).__name__}", None) - if fn is None: - raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}", node) - - with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" + with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + fn = getattr(self, fn_name) self.ir_node = fn() - - if self.ir_node is None: - raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) + assert isinstance(self.ir_node, IRnode), self.ir_node self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.source_pos = getpos(self.expr) @@ -362,9 +360,9 @@ def parse_Subscript(self): index = self.expr.slice.value.n # note: this check should also happen in get_element_ptr if not 0 <= index < len(sub.typ.member_types): - return + raise TypeCheckFailure("unreachable") else: - return + raise TypeCheckFailure("unreachable") ir_node = get_element_ptr(sub, index) ir_node.mutable = sub.mutable @@ -399,13 +397,13 @@ def parse_BinOp(self): new_typ = left.typ if new_typ.bits != 256: # TODO implement me. ["and", 2**bits - 1, shl(right, left)] - return + raise TypeCheckFailure("unreachable") return IRnode.from_list(shl(right, left), typ=new_typ) if isinstance(self.expr.op, vy_ast.RShift): new_typ = left.typ if new_typ.bits != 256: # TODO implement me. promote_signed_int(op(right, left), bits) - return + raise TypeCheckFailure("unreachable") op = shr if not left.typ.is_signed else sar return IRnode.from_list(op(right, left), typ=new_typ) @@ -448,7 +446,7 @@ def build_in_comparator(self): elif isinstance(self.expr.op, vy_ast.NotIn): found, not_found = 0, 1 else: # pragma: no cover - return + raise TypeCheckFailure("unreachable") i = IRnode.from_list(self.context.fresh_varname("in_ix"), typ=UINT256_T) @@ -510,7 +508,7 @@ def parse_Compare(self): right = Expr.parse_value_expr(self.expr.right, self.context) if right.value is None: - return + raise TypeCheckFailure("unreachable") if isinstance(self.expr.op, (vy_ast.In, vy_ast.NotIn)): if is_array_like(right.typ): @@ -562,7 +560,7 @@ def parse_Compare(self): elif left.typ._is_prim_word and right.typ._is_prim_word: if op not in ("eq", "ne"): - return + raise TypeCheckFailure("unreachable") else: # kludge to block behavior in #2638 # TODO actually implement equality for complex types diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 18e5c3d494..bc29a79734 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -40,16 +40,14 @@ class Stmt: def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: self.stmt = node self.context = context - fn = getattr(self, f"parse_{type(node).__name__}", None) - if fn is None: - raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}") - with context.internal_memory_scope(): - with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" + with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + fn = getattr(self, fn_name) + with context.internal_memory_scope(): self.ir_node = fn() - if self.ir_node is None: - raise TypeCheckFailure("Statement node did not produce IR") + assert isinstance(self.ir_node, IRnode), self.ir_node self.ir_node.annotation = self.stmt.get("node_source_code") self.ir_node.source_pos = getpos(self.stmt) @@ -347,7 +345,7 @@ def parse_AugAssign(self): # because of this check, we do not need to check for # make_setter references lhs<->rhs as in parse_Assign - # single word load/stores are atomic. - return + raise TypeCheckFailure("unreachable") with target.cache_when_complex("_loc") as (b, target): rhs = Expr.parse_value_expr( diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 8921814188..f216069eab 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -359,14 +359,17 @@ class InvalidABIType(VyperInternalException): @contextlib.contextmanager -def tag_exceptions( - node, fallback_exception_type=CompilerPanic, fallback_message="unhandled exception" -): +def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): try: yield except _BaseVyperException as e: if not e.annotations and not e.lineno: - raise e.with_annotation(node) from None + tb = e.__traceback__ + raise e.with_annotation(node).with_traceback(tb) raise e from None except Exception as e: - raise fallback_exception_type(fallback_message, node) from e + tb = e.__traceback__ + fallback_message = "unhandled exception" + if note: + fallback_message += f", {note}" + raise fallback_exception_type(fallback_message, node).with_traceback(tb) From bf26f83ec21788d9fb591879fdd6af2ddfacb050 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Wed, 27 Dec 2023 17:24:47 +0100 Subject: [PATCH 149/201] docs: `address.codehash` for empty account (#3711) --- docs/types.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/types.rst b/docs/types.rst index 0ad13967e9..a8be721b1a 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -299,7 +299,7 @@ Members Member Type Description =============== =========== ========================================================================== ``balance`` ``uint256`` Balance of an address -``codehash`` ``bytes32`` Keccak of code at an address, ``EMPTY_BYTES32`` if no contract is deployed +``codehash`` ``bytes32`` Keccak of code at an address, ``0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470`` if no contract is deployed (see `EIP-1052 `_) ``codesize`` ``uint256`` Size of code deployed at an address, in bytes ``is_contract`` ``bool`` Boolean indicating if a contract is deployed at an address ``code`` ``Bytes`` Contract bytecode From 87db3c1421acab63e385a132406722ee5b255685 Mon Sep 17 00:00:00 2001 From: Eike Caldeweyher <46899008+f3rmion@users.noreply.github.com> Date: Fri, 29 Dec 2023 19:04:26 +0100 Subject: [PATCH 150/201] docs: indexed arguments for events are limited (#3715) The number of indexed arguments is limited by the EVM to a maximum of four. Include definition of topics and data. --- docs/event-logging.rst | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/event-logging.rst b/docs/event-logging.rst index c6e20954f8..904b179e70 100644 --- a/docs/event-logging.rst +++ b/docs/event-logging.rst @@ -66,11 +66,19 @@ Let's look at an event declaration in more detail. receiver: indexed(address) value: uint256 +The EVM currently has five opcodes for emitting event logs: ``LOG0``, ``LOG1``, ``LOG2``, ``LOG3``, and ``LOG4``. +These opcodes can be used to create log records, where each log record consists of both **topics** and **data**. +Topics are 32-byte ''words'' that are used to describe what is happening in an event. +While topics are searchable, data is not. +Event data is however not limited, which means that you can include large or complicated data like arrays or strings. +Different opcodes (``LOG0`` through ``LOG4``) allow for different numbers of topics. +For instance, ``LOG1`` includes one topic, ``LOG2`` includes two topics, and so on. Event declarations look similar to struct declarations, containing one or more arguments that are passed to the event. Typical events will contain two kinds of arguments: - * **Indexed** arguments, which can be searched for by listeners. Each indexed argument is identified by the ``indexed`` keyword. Here, each indexed argument is an address. You can have any number of indexed arguments, but indexed arguments are not passed directly to listeners, although some of this information (such as the sender) may be available in the listener's `results` object. - * **Value** arguments, which are passed through to listeners. You can have any number of value arguments and they can have arbitrary names, but each is limited by the EVM to be no more than 32 bytes. + * **Indexed** arguments (topics), which can be searched for by listeners. Each indexed argument is identified by the ``indexed`` keyword. Here, each indexed argument is an address. You can have up to four indexed arguments (``LOG4``), but indexed arguments are not passed directly to listeners, although some of this information (such as the sender) may be available in the listener's `results` object. + * **Value** arguments (data), which are passed through to listeners. You can have any number of value arguments and they can have arbitrary names, but each is limited by the EVM to be no more than 32 bytes. +Note that the first topic of a log record consists of the signature of the name of the event that occurred, including the types of its parameters. It is also possible to create an event with no arguments. In this case, use the ``pass`` statement: .. code-block:: python From 56c4c9dbc09d6310bf132cfde3fdbe1431189a9b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 1 Jan 2024 06:19:47 +0800 Subject: [PATCH 151/201] refactor: reimplement AST folding (#3669) this commit reimplements AST folding. fundamentally, it changes AST folding from a mutating pass to be an annotation pass. this brings several benefits: - typechecking is easier, because folding does not have to reason at all about types. type checking happens on both the folded and unfolded nodes, so intermediate values are type-checked. - correctness in general is easier, because the AST is not mutated. there is also some incidental performance benefit, although that is not necessarily the focus here. - the vyper frontend is now nearly mutation-free. only the getter AST expansion pass remains. note that we cannot push folding past the typechecking stage entirely, because some type checking operations depend on having folded values (e.g., `range()` expressions, or type expressions with integer parameters). the approach taken in this commit is to change constant folding to be annotating, rather than mutating. this way, type-checking can operate on the original AST (and check for the folded values where needed). intermediate values are also type-checked, so expressions like `x: uint128 = 2**128 + 1 - 1` are caught by the typechecker. summary of changes: - `evaluate()` is renamed to `_try_fold()`. a new utility function called `get_folded_value()` caches folded values and is threaded through the codebase. - `pre_typecheck` is added, which extracts `constant` variables and runs `get_folded_value()` on all nodes. - a new `Modifiability` enum replaces the old (confusing) `is_constant` and `is_immutable` attributes on ExprInfo. - `ExprInfo.is_transient` is removed, and handled by adding `TRANSIENT` to the `DataLocation` enum. - the old `check_literal` and `check_kwargable` utility functions are replaced with a more general (and more correct) `check_modifiability` function - several utility functions (ex. `_validate_numeric_bounds()`) related to ad-hoc type-checking (which would happen during constant folding) are removed. - `CompilerData.vyper_module_folded` is renamed to `annotated_vyper_module` - the AST output options are now `ast` and `annotated_ast`. - `None` literals are now banned in AST validation instead of during analysis. --------- Co-authored-by: Charles Cooper --- .../builtins/codegen/test_keccak256.py | 31 ++ .../builtins/codegen/test_sha256.py | 30 ++ .../functional/builtins/codegen/test_unary.py | 7 +- tests/functional/builtins/folding/test_abs.py | 10 +- .../builtins/folding/test_addmod_mulmod.py | 2 +- .../builtins/folding/test_bitwise.py | 11 +- .../builtins/folding/test_epsilon.py | 2 +- .../builtins/folding/test_floor_ceil.py | 2 +- .../folding/test_fold_as_wei_value.py | 4 +- .../builtins/folding/test_keccak_sha.py | 6 +- tests/functional/builtins/folding/test_len.py | 6 +- .../builtins/folding/test_min_max.py | 6 +- .../builtins/folding/test_powmod.py | 2 +- .../test_default_parameters.py | 42 +++ .../test_external_contract_calls.py | 2 +- .../codegen/test_call_graph_stability.py | 2 +- tests/functional/codegen/test_interfaces.py | 2 +- .../codegen/types/numbers/test_constants.py | 4 +- .../codegen/types/numbers/test_decimals.py | 43 ++- .../codegen/types/numbers/test_signed_ints.py | 40 ++- .../types/numbers/test_unsigned_ints.py | 30 +- .../codegen/types/test_dynamic_array.py | 20 +- tests/functional/codegen/types/test_lists.py | 10 +- .../exceptions/test_argument_exception.py | 6 +- tests/functional/syntax/test_abi_decode.py | 4 +- tests/functional/syntax/test_abs.py | 40 +++ tests/functional/syntax/test_addmulmod.py | 22 ++ tests/functional/syntax/test_as_wei_value.py | 72 ++++- tests/functional/syntax/test_ceil.py | 19 ++ tests/functional/syntax/test_dynamic_array.py | 17 +- tests/functional/syntax/test_epsilon.py | 20 ++ tests/functional/syntax/test_floor.py | 19 ++ tests/functional/syntax/test_for_range.py | 56 +++- tests/functional/syntax/test_len.py | 22 +- tests/functional/syntax/test_method_id.py | 50 +++ tests/functional/syntax/test_minmax.py | 43 ++- tests/functional/syntax/test_minmax_value.py | 28 +- tests/functional/syntax/test_powmod.py | 39 +++ tests/functional/syntax/test_raw_call.py | 20 +- tests/functional/syntax/test_ternary.py | 4 +- tests/functional/syntax/test_uint2str.py | 19 ++ tests/functional/syntax/test_unary.py | 21 ++ ..._decimal.py => test_fold_binop_decimal.py} | 8 +- ...te_binop_int.py => test_fold_binop_int.py} | 10 +- ...evaluate_boolop.py => test_fold_boolop.py} | 6 +- ...aluate_compare.py => test_fold_compare.py} | 10 +- ...te_subscript.py => test_fold_subscript.py} | 2 +- ...aluate_unaryop.py => test_fold_unaryop.py} | 6 +- tests/unit/ast/nodes/test_replace_in_tree.py | 70 ----- tests/unit/ast/test_ast_dict.py | 6 +- tests/unit/ast/test_folding.py | 272 ---------------- tests/unit/ast/test_natspec.py | 2 +- vyper/ast/README.md | 21 -- vyper/ast/__init__.py | 2 +- vyper/ast/__init__.pyi | 2 +- vyper/ast/folding.py | 263 ---------------- vyper/ast/natspec.py | 10 +- vyper/ast/nodes.py | 265 ++++++++++------ vyper/ast/nodes.pyi | 13 +- vyper/ast/parse.py | 1 + vyper/ast/validation.py | 11 +- vyper/builtins/_signatures.py | 25 +- vyper/builtins/functions.py | 297 +++++++++--------- vyper/cli/vyper_compile.py | 11 +- vyper/codegen/expr.py | 17 +- vyper/compiler/README.md | 2 - vyper/compiler/__init__.py | 2 + vyper/compiler/output.py | 20 +- vyper/compiler/phases.py | 61 ++-- vyper/semantics/README.md | 29 +- vyper/semantics/analysis/base.py | 87 +++-- vyper/semantics/analysis/local.py | 81 +++-- vyper/semantics/analysis/module.py | 46 +-- vyper/semantics/analysis/pre_typecheck.py | 94 ++++++ vyper/semantics/analysis/utils.py | 73 ++--- vyper/semantics/data_locations.py | 3 +- vyper/semantics/environment.py | 4 +- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/function.py | 13 +- vyper/semantics/types/subscriptable.py | 2 + vyper/semantics/types/utils.py | 10 +- 81 files changed, 1464 insertions(+), 1230 deletions(-) create mode 100644 tests/functional/syntax/test_abs.py create mode 100644 tests/functional/syntax/test_ceil.py create mode 100644 tests/functional/syntax/test_epsilon.py create mode 100644 tests/functional/syntax/test_floor.py create mode 100644 tests/functional/syntax/test_method_id.py create mode 100644 tests/functional/syntax/test_powmod.py create mode 100644 tests/functional/syntax/test_uint2str.py create mode 100644 tests/functional/syntax/test_unary.py rename tests/unit/ast/nodes/{test_evaluate_binop_decimal.py => test_fold_binop_decimal.py} (93%) rename tests/unit/ast/nodes/{test_evaluate_binop_int.py => test_fold_binop_int.py} (93%) rename tests/unit/ast/nodes/{test_evaluate_boolop.py => test_fold_boolop.py} (92%) rename tests/unit/ast/nodes/{test_evaluate_compare.py => test_fold_compare.py} (94%) rename tests/unit/ast/nodes/{test_evaluate_subscript.py => test_fold_subscript.py} (93%) rename tests/unit/ast/nodes/{test_evaluate_unaryop.py => test_fold_unaryop.py} (86%) delete mode 100644 tests/unit/ast/nodes/test_replace_in_tree.py delete mode 100644 tests/unit/ast/test_folding.py delete mode 100644 vyper/ast/folding.py create mode 100644 vyper/semantics/analysis/pre_typecheck.py diff --git a/tests/functional/builtins/codegen/test_keccak256.py b/tests/functional/builtins/codegen/test_keccak256.py index 90fa8b9e09..3b0b9f2018 100644 --- a/tests/functional/builtins/codegen/test_keccak256.py +++ b/tests/functional/builtins/codegen/test_keccak256.py @@ -1,3 +1,6 @@ +from vyper.utils import hex_to_int + + def test_hash_code(get_contract_with_gas_estimation, keccak): hash_code = """ @external @@ -80,3 +83,31 @@ def try32(inp: bytes32) -> bool: assert c.tryy(b"\x35" * 33) is True print("Passed KECCAK256 hash test") + + +def test_hash_constant_bytes32(get_contract_with_gas_estimation, keccak): + hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(bytes32) = {hex_val} +BAR: constant(bytes32) = keccak256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert "0x" + c.foo().hex() == keccak(hex_to_int(hex_val).to_bytes(32, "big")).hex() + + +def test_hash_constant_string(get_contract_with_gas_estimation, keccak): + str_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(String[66]) = "{str_val}" +BAR: constant(bytes32) = keccak256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert "0x" + c.foo().hex() == keccak(str_val.encode()).hex() diff --git a/tests/functional/builtins/codegen/test_sha256.py b/tests/functional/builtins/codegen/test_sha256.py index 468e684645..8e1b89bd31 100644 --- a/tests/functional/builtins/codegen/test_sha256.py +++ b/tests/functional/builtins/codegen/test_sha256.py @@ -2,6 +2,8 @@ import pytest +from vyper.utils import hex_to_int + pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -77,3 +79,31 @@ def bar() -> bytes32: c.set(test_val, transact={}) assert c.a() == test_val assert c.bar() == hashlib.sha256(test_val).digest() + + +def test_sha256_constant_bytes32(get_contract_with_gas_estimation): + hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(bytes32) = {hex_val} +BAR: constant(bytes32) = sha256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == hashlib.sha256(hex_to_int(hex_val).to_bytes(32, "big")).digest() + + +def test_sha256_constant_string(get_contract_with_gas_estimation): + str_val = "0x1234567890123456789012345678901234567890123456789012345678901234" + code = f""" +FOO: constant(String[66]) = "{str_val}" +BAR: constant(bytes32) = sha256(FOO) +@external +def foo() -> bytes32: + x: bytes32 = BAR + return x + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == hashlib.sha256(str_val.encode()).digest() diff --git a/tests/functional/builtins/codegen/test_unary.py b/tests/functional/builtins/codegen/test_unary.py index 33f79be233..2be5c0d33f 100644 --- a/tests/functional/builtins/codegen/test_unary.py +++ b/tests/functional/builtins/codegen/test_unary.py @@ -69,16 +69,11 @@ def bar() -> decimal: def test_negation_int128(get_contract): code = """ -a: constant(int128) = -2**127 - -@external -def foo() -> int128: - return -2**127 +a: constant(int128) = min_value(int128) @external def bar() -> int128: return -(a+1) """ c = get_contract(code) - assert c.foo() == -(2**127) assert c.bar() == 2**127 - 1 diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index a91a4f1ad3..68131678fa 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -4,7 +4,7 @@ from vyper import ast as vy_ast from vyper.builtins import functions as vy_fn -from vyper.exceptions import OverflowException +from vyper.exceptions import InvalidType @pytest.mark.fuzzing @@ -21,7 +21,7 @@ def foo(a: int256) -> int256: vyper_ast = vy_ast.parse_to_ast(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node) assert contract.foo(a) == new_node.value == abs(a) @@ -35,7 +35,7 @@ def test_abs_upper_bound_folding(get_contract, a): def foo(a: int256) -> int256: return abs({a}) """ - with pytest.raises(OverflowException): + with pytest.raises(InvalidType): get_contract(source) @@ -55,7 +55,7 @@ def test_abs_lower_bound_folded(get_contract, tx_failed): source = """ @external def foo() -> int256: - return abs(-2**255) + return abs(min_value(int256)) """ - with pytest.raises(OverflowException): + with pytest.raises(InvalidType): get_contract(source) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 33dcc62984..1d789f1655 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -24,6 +24,6 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 63e733644f..53a6d333a0 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -13,6 +13,9 @@ st_sint256 = st.integers(min_value=-(2**255), max_value=2**255 - 1) +# TODO: move this file to tests/unit/ast/nodes/test_fold_bitwise.py + + @pytest.mark.fuzzing @settings(max_examples=50) @pytest.mark.parametrize("op", ["&", "|", "^"]) @@ -28,7 +31,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(a, b) == new_node.value @@ -49,7 +52,7 @@ def foo(a: uint256, b: uint256) -> uint256: old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). @@ -79,7 +82,7 @@ def foo(a: int256, b: uint256) -> int256: old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. @@ -104,6 +107,6 @@ def foo(a: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"~{value}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 794648cfce..4f5e9434ec 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -15,6 +15,6 @@ def foo() -> {typ_name}: vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node) assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 87db23889a..04921e504e 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -30,6 +30,6 @@ def foo(a: decimal) -> int256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 210ab51f0d..4287615bab 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -32,7 +32,7 @@ def foo(a: decimal) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue().evaluate(old_node) + new_node = vy_fn.AsWeiValue()._try_fold(old_node) assert contract.foo(value) == new_node.value @@ -51,6 +51,6 @@ def foo(a: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue().evaluate(old_node) + new_node = vy_fn.AsWeiValue()._try_fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index a2fe460dd1..8da420538f 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -22,7 +22,7 @@ def foo(a: String[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -41,7 +41,7 @@ def foo(a: Bytes[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -62,6 +62,6 @@ def foo(a: Bytes[100]) -> bytes32: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index edf33120dd..967f906555 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -17,7 +17,7 @@ def foo(a: String[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len()._try_fold(old_node) assert contract.foo(value) == new_node.value @@ -35,7 +35,7 @@ def foo(a: Bytes[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len()._try_fold(old_node) assert contract.foo(value.encode()) == new_node.value @@ -53,6 +53,6 @@ def foo(a: Bytes[1024]) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len().evaluate(old_node) + new_node = vy_fn.Len()._try_fold(old_node) assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 309f7519c0..36a611fa1b 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(left, right) == new_node.value @@ -50,7 +50,7 @@ def foo(a: int128, b: int128) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(left, right) == new_node.value @@ -69,6 +69,6 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node) + new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index 8667ec93fd..a3c2567f58 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -21,6 +21,6 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256().evaluate(old_node) + new_node = vy_fn.PowMod256()._try_fold(old_node) assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 03f5d9fca2..462748a9c7 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -305,6 +305,48 @@ def foo(a: address = empty(address)): def foo(a: int112 = min_value(int112)): self.A = a """, + """ +struct X: + x: int128 + y: address +BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +@external +def out_literals(a: int128 = BAR.x + 1) -> X: + return BAR + """, + """ +struct X: + x: int128 + y: address +struct Y: + x: X + y: uint256 +BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +FOO: constant(Y) = Y({x: BAR, y: 256}) +@external +def out_literals(a: int128 = FOO.x.x + 1) -> Y: + return FOO + """, + """ +struct Bar: + a: bool + +BAR: constant(Bar) = Bar({a: True}) + +@external +def foo(x: bool = True and not BAR.a): + pass + """, + """ +struct Bar: + a: uint256 + +BAR: constant(Bar) = Bar({ a: 123 }) + +@external +def foo(x: bool = BAR.a + 1 > 456): + pass + """, ] diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 0360396f03..0af4f9f937 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -388,7 +388,7 @@ def test_int128_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> int256: - return (2**255)-1 + return max_value(int256) """ c = get_contract(contract_1) diff --git a/tests/functional/codegen/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py index 2d8ad59791..ca0e6c8c9e 100644 --- a/tests/functional/codegen/test_call_graph_stability.py +++ b/tests/functional/codegen/test_call_graph_stability.py @@ -54,7 +54,7 @@ def foo(): t = CompilerData(code) # check the .called_functions data structure on foo() directly - foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] + foo = t.annotated_vyper_module.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] foo_t = foo._metadata["func_type"] assert [f.name for f in foo_t.called_functions] == func_names diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 65d2df9038..7d363fadc0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -435,7 +435,7 @@ def ok() -> {typ}: @external def should_fail() -> int256: - return -2**255 # OOB for all int/uint types with less than 256 bits + return min_value(int256) """ code = f""" diff --git a/tests/functional/codegen/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py index 8244bc5487..af871983ab 100644 --- a/tests/functional/codegen/types/numbers/test_constants.py +++ b/tests/functional/codegen/types/numbers/test_constants.py @@ -4,7 +4,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch from vyper.utils import MemoryPositions @@ -158,7 +158,7 @@ def test_custom_constants_fail(get_contract, assert_compile_failed, storage_type def foo() -> {return_type}: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) def test_constant_address(get_contract): diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 25dc1f1a1e..fcf71f12f0 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -3,7 +3,13 @@ import pytest -from vyper.exceptions import DecimalOverrideException, InvalidOperation, TypeMismatch +from vyper import compile_code +from vyper.exceptions import ( + DecimalOverrideException, + InvalidOperation, + OverflowException, + TypeMismatch, +) from vyper.utils import DECIMAL_EPSILON, SizeLimits @@ -24,23 +30,25 @@ def test_decimal_override(): @pytest.mark.parametrize("op", ["**", "&", "|", "^"]) -def test_invalid_ops(get_contract, assert_compile_failed, op): +def test_invalid_ops(op): code = f""" @external def foo(x: decimal, y: decimal) -> decimal: return x {op} y """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) @pytest.mark.parametrize("op", ["not"]) -def test_invalid_unary_ops(get_contract, assert_compile_failed, op): +def test_invalid_unary_ops(op): code = f""" @external def foo(x: decimal) -> decimal: return {op} x """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) def quantize(x: Decimal) -> Decimal: @@ -263,11 +271,32 @@ def bar(num: decimal) -> decimal: assert c.bar(Decimal("1e37")) == Decimal("-9e37") # Math lines up -def test_exponents(assert_compile_failed, get_contract): +def test_exponents(): code = """ @external def foo() -> decimal: return 2.2 ** 2.0 """ - assert_compile_failed(lambda: get_contract(code), TypeMismatch) + with pytest.raises(TypeMismatch): + compile_code(code) + + +def test_decimal_nested_intermediate_overflow(): + code = """ +@external +def foo(): + a: decimal = 18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10 + """ + with pytest.raises(OverflowException): + compile_code(code) + + +def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): + code = """ +@external +def foo(): + a: decimal = -18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10 + """ + with pytest.raises(OverflowException): + compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 52de5b649f..a10eaee408 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -4,6 +4,7 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException from vyper.semantics.types import IntegerT from vyper.utils import evm_div, evm_mod @@ -206,17 +207,16 @@ def _num_min() -> {typ}: @pytest.mark.parametrize("typ", types) -def test_overflow_out_of_range(get_contract, assert_compile_failed, typ): +def test_overflow_out_of_range(get_contract, typ): code = f""" @external def num_sub() -> {typ}: return 1-2**{typ.bits} """ - if typ.bits == 256: - assert_compile_failed(lambda: get_contract(code), OverflowException) - else: - assert_compile_failed(lambda: get_contract(code), InvalidType) + exc = OverflowException if typ.bits == 256 else InvalidType + with pytest.raises(exc): + compile_code(code) ARITHMETIC_OPS = { @@ -231,7 +231,7 @@ def num_sub() -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ): # both variables code_1 = f""" @external @@ -318,10 +318,12 @@ def foo() -> {typ}: elif div_by_zero: with tx_failed(): c.foo(x, y) - assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) + with pytest.raises(ZeroDivisionException): + compile_code(code_2) with tx_failed(): get_contract(code_3).foo(y) - assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) + with pytest.raises(ZeroDivisionException): + compile_code(code_4) else: with tx_failed(): c.foo(x, y) @@ -329,9 +331,8 @@ def foo() -> {typ}: get_contract(code_2).foo(x) with tx_failed(): get_contract(code_3).foo(y) - assert_compile_failed( - lambda code=code_4: get_contract(code), (InvalidType, OverflowException) - ) + with pytest.raises((InvalidType, OverflowException)): + compile_code(code_4) COMPARISON_OPS = { @@ -359,7 +360,7 @@ def foo(x: {typ}, y: {typ}) -> bool: fn = COMPARISON_OPS[op] c = get_contract(code_1) - # note: constant folding is tested in tests/ast/folding + # note: constant folding is tested in tests/unit/ast/nodes special_cases = [ lo, lo + 1, @@ -413,10 +414,21 @@ def foo(a: {typ}) -> {typ}: @pytest.mark.parametrize("typ", types) @pytest.mark.parametrize("op", ["not"]) -def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): +def test_invalid_unary_ops(typ, op): code = f""" @external def foo(a: {typ}) -> {typ}: return {op} a """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) + + +def test_binop_nested_intermediate_underflow(): + code = """ +@external +def foo(): + a: int256 = -2**255 * 2 - 10 + 100 + """ + with pytest.raises(InvalidType): + compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 8982065b5d..f10e861689 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -4,9 +4,10 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException from vyper.semantics.types import IntegerT -from vyper.utils import evm_div, evm_mod +from vyper.utils import SizeLimits, evm_div, evm_mod types = sorted(IntegerT.unsigneds()) @@ -85,7 +86,7 @@ def foo(x: {typ}) -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, op, typ): # both variables code_1 = f""" @external @@ -192,7 +193,7 @@ def foo(x: {typ}, y: {typ}) -> bool: lo, hi = typ.ast_bounds - # note: constant folding is tested in tests/ast/folding + # note: folding is tested in tests/unit/ast/nodes special_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 2, hi - 1, hi] xs = special_cases.copy() @@ -204,7 +205,7 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_uint_literal(get_contract, assert_compile_failed, typ): +def test_uint_literal(get_contract, typ): lo, hi = typ.ast_bounds good_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 1, hi] @@ -221,7 +222,13 @@ def test() -> {typ}: assert c.test() == val for val in bad_cases: - assert_compile_failed(lambda v=val: get_contract(code_template.format(typ=typ, val=v))) + exc = ( + InvalidType + if SizeLimits.MIN_INT256 <= val <= SizeLimits.MAX_UINT256 + else OverflowException + ) + with pytest.raises(exc): + compile_code(code_template.format(typ=typ, val=val)) @pytest.mark.parametrize("typ", types) @@ -232,4 +239,15 @@ def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): def foo(a: {typ}) -> {typ}: return {op} a """ - assert_compile_failed(lambda: get_contract(code), InvalidOperation) + with pytest.raises(InvalidOperation): + compile_code(code) + + +def test_binop_nested_intermediate_overflow(): + code = """ +@external +def foo(): + a: uint256 = 2**255 * 2 / 10 + """ + with pytest.raises(OverflowException): + compile_code(code) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 4ef6874ae9..70a68e3206 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -2,6 +2,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ArrayIndexException, @@ -315,6 +316,21 @@ def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: def test_array_negative_accessor(get_contract_with_gas_estimation, assert_compile_failed): + array_constant_negative_accessor = """ +FOO: constant(int128) = -1 +@external +def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: + a: int128[4] = [0, 0, 0, 0] + a[0] = x + a[1] = y + a[2] = z + a[3] = w + return a[-4] * 1000 + a[-3] * 100 + a[-2] * 10 + a[FOO] + """ + + with pytest.raises(ArrayIndexException): + compile_code(array_constant_negative_accessor) + array_negative_accessor = """ @external def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: @@ -1728,7 +1744,7 @@ def test_constant_list_fail(get_contract, assert_compile_failed, storage_type, r def foo() -> DynArray[{return_type}, 3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -1740,7 +1756,7 @@ def test_constant_list_fail_2(get_contract, assert_compile_failed, storage_type, def foo() -> {return_type}: return MY_CONSTANT[0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 657c4ba0b8..b5b9538c20 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -2,7 +2,7 @@ import pytest -from vyper.exceptions import ArrayIndexException, InvalidType, OverflowException, TypeMismatch +from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch def test_list_tester_code(get_contract_with_gas_estimation): @@ -705,7 +705,7 @@ def test_constant_list_fail(get_contract, assert_compile_failed, storage_type, r def foo() -> {return_type}[3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -717,7 +717,7 @@ def test_constant_list_fail_2(get_contract, assert_compile_failed, storage_type, def foo() -> {return_type}: return MY_CONSTANT[0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -824,7 +824,7 @@ def test_constant_nested_list_fail(get_contract, assert_compile_failed, storage_ def foo() -> {return_type}[2][3]: return MY_CONSTANT """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) @@ -838,4 +838,4 @@ def test_constant_nested_list_fail_2( def foo() -> {return_type}: return MY_CONSTANT[0][0] """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index fc06395015..0b7ec21bdb 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -1,13 +1,13 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import ArgumentException fail_list = [ """ @external def foo(): - x = as_wei_value(5, "vader") + x: uint256 = as_wei_value(5, "vader") """, """ @external @@ -95,4 +95,4 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) def test_function_declaration_exception(bad_code): with pytest.raises(ArgumentException): - compiler.compile_code(bad_code) + compile_code(bad_code) diff --git a/tests/functional/syntax/test_abi_decode.py b/tests/functional/syntax/test_abi_decode.py index f05ff429cd..a6665bb84c 100644 --- a/tests/functional/syntax/test_abi_decode.py +++ b/tests/functional/syntax/test_abi_decode.py @@ -26,7 +26,7 @@ def bar(j: String[32]) -> bool: @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_abi_encode_fail(bad_code, exc): +def test_abi_decode_fail(bad_code, exc): with pytest.raises(exc): compiler.compile_code(bad_code) @@ -41,5 +41,5 @@ def foo(x: Bytes[32]) -> uint256: @pytest.mark.parametrize("good_code", valid_list) -def test_abi_encode_success(good_code): +def test_abi_decode_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_abs.py b/tests/functional/syntax/test_abs.py new file mode 100644 index 0000000000..0841ff05d6 --- /dev/null +++ b/tests/functional/syntax/test_abs.py @@ -0,0 +1,40 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + y: int256 = abs( + -57896044618658097711785492504343953926634992332820282019728792003956564819968 + ) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_abs_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(int256) = -3 +BAR: constant(int256) = abs(FOO) + +@external +def foo(): + a: int256 = BAR + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_abs_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_addmulmod.py b/tests/functional/syntax/test_addmulmod.py index ddff4d3e01..17c7b3ab8c 100644 --- a/tests/functional/syntax/test_addmulmod.py +++ b/tests/functional/syntax/test_addmulmod.py @@ -1,5 +1,6 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidType fail_list = [ @@ -25,3 +26,24 @@ def foo() -> uint256: @pytest.mark.parametrize("code,exc", fail_list) def test_add_mod_fail(assert_compile_failed, get_contract, code, exc): assert_compile_failed(lambda: get_contract(code), exc) + + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = 19 +BAX: constant(uint256) = uint256_addmod(FOO, BAR, BAZ) + """, + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = 19 +BAX: constant(uint256) = uint256_mulmod(FOO, BAR, BAZ) + """, +] + + +@pytest.mark.parametrize("code", valid_list) +def test_addmulmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_as_wei_value.py b/tests/functional/syntax/test_as_wei_value.py index a5232a5c9a..056d0348e9 100644 --- a/tests/functional/syntax/test_as_wei_value.py +++ b/tests/functional/syntax/test_as_wei_value.py @@ -1,13 +1,31 @@ import pytest -from vyper.exceptions import ArgumentException, InvalidType, StructureException +from vyper import compile_code +from vyper.exceptions import ( + ArgumentException, + InvalidLiteral, + InvalidType, + OverflowException, + StructureException, + UndeclaredDefinition, +) + +# CMC 2023-12-31 these tests could probably go in builtins/folding/ fail_list = [ ( """ @external def foo(): - x: int128 = as_wei_value(5, szabo) + x: uint256 = as_wei_value(5, szabo) + """, + UndeclaredDefinition, + ), + ( + """ +@external +def foo(): + x: uint256 = as_wei_value(5, "szaboo") """, ArgumentException, ), @@ -28,12 +46,50 @@ def foo(): """, InvalidType, ), + ( + """ +@external +def foo() -> uint256: + return as_wei_value( + 115792089237316195423570985008687907853269984665640564039457584007913129639937, + 'milliether' + ) + """, + OverflowException, + ), + ( + """ +@external +def foo(): + x: uint256 = as_wei_value(-1, "szabo") + """, + InvalidLiteral, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(5, szabo) + """, + UndeclaredDefinition, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(5, "szaboo") + """, + ArgumentException, + ), + ( + """ +FOO: constant(uint256) = as_wei_value(-1, "szabo") + """, + InvalidLiteral, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_as_wei_fail(get_contract_with_gas_estimation, bad_code, exc, assert_compile_failed): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) +def test_as_wei_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) valid_list = [ @@ -59,6 +115,14 @@ def foo() -> uint256: x: address = 0x1234567890123456789012345678901234567890 return x.balance """, + """ +y: constant(String[5]) = "szabo" +x: constant(uint256) = as_wei_value(5, y) + +@external +def foo(): + a: uint256 = x + """, ] diff --git a/tests/functional/syntax/test_ceil.py b/tests/functional/syntax/test_ceil.py new file mode 100644 index 0000000000..41f4175d01 --- /dev/null +++ b/tests/functional/syntax/test_ceil.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +BAR: constant(decimal) = 2.5 +FOO: constant(int256) = ceil(BAR) + +@external +def foo(): + a: int256 = FOO + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_ceil_good(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_dynamic_array.py b/tests/functional/syntax/test_dynamic_array.py index 99a01a17c8..f566a80625 100644 --- a/tests/functional/syntax/test_dynamic_array.py +++ b/tests/functional/syntax/test_dynamic_array.py @@ -1,6 +1,6 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import StructureException fail_list = [ @@ -24,12 +24,21 @@ def foo(): """, StructureException, ), + ( + """ +@external +def foo(): + a: DynArray[uint256, FOO] = [1, 2, 3] + """, + StructureException, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): - assert_compile_failed(lambda: get_contract(bad_code), exc) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) valid_list = [ @@ -48,4 +57,4 @@ def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): @pytest.mark.parametrize("good_code", valid_list) def test_dynarray_pass(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_epsilon.py b/tests/functional/syntax/test_epsilon.py new file mode 100644 index 0000000000..0e80d2b4bf --- /dev/null +++ b/tests/functional/syntax/test_epsilon.py @@ -0,0 +1,20 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +# CMC 2023-12-31 this could probably go in builtins/folding/ +fail_list = [ + ( + """ +FOO: constant(address) = epsilon(address) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) diff --git a/tests/functional/syntax/test_floor.py b/tests/functional/syntax/test_floor.py new file mode 100644 index 0000000000..5c30aecbe1 --- /dev/null +++ b/tests/functional/syntax/test_floor.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +BAR: constant(decimal) = 2.5 +FOO: constant(int256) = floor(BAR) + +@external +def foo(): + a: int256 = FOO + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_floor_good(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 7c7f9c476d..a9c3ad5cab 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -3,7 +3,12 @@ import pytest from vyper import compiler -from vyper.exceptions import ArgumentException, StateAccessViolation, StructureException +from vyper.exceptions import ( + ArgumentException, + StateAccessViolation, + StructureException, + TypeMismatch, +) fail_list = [ ( @@ -20,6 +25,17 @@ def foo(): ( """ @external +def bar(): + for i in range(1,2,bound=0): + pass + """, + StructureException, + "Bound must be at least 1", + "0", + ), + ( + """ +@external def foo(): x: uint256 = 100 for _ in range(10, bound=x): @@ -181,6 +197,44 @@ def foo(x: int128): "Value must be a literal integer, unless a bound is specified", "x", ), + ( + """ +@external +def bar(x: uint256): + for i in range(3, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +FOO: constant(int128) = 3 +BAR: constant(uint256) = 7 + +@external +def foo(): + for i in range(FOO, BAR): + pass + """, + TypeMismatch, + "Iterator values are of different types", + "range(FOO, BAR)", + ), + ( + """ +FOO: constant(int128) = -1 + +@external +def foo(): + for i in range(10, bound=FOO): + pass + """, + StructureException, + "Bound must be at least 1", + "-1", + ), ] for_code_regex = re.compile(r"for .+ in (.*):") diff --git a/tests/functional/syntax/test_len.py b/tests/functional/syntax/test_len.py index bbde7e4897..b8cc61df1d 100644 --- a/tests/functional/syntax/test_len.py +++ b/tests/functional/syntax/test_len.py @@ -1,7 +1,6 @@ import pytest -from pytest import raises -from vyper import compiler +from vyper import compile_code from vyper.exceptions import TypeMismatch fail_list = [ @@ -21,11 +20,11 @@ def foo(inp: int128) -> uint256: @pytest.mark.parametrize("bad_code", fail_list) def test_block_fail(bad_code): if isinstance(bad_code, tuple): - with raises(bad_code[1]): - compiler.compile_code(bad_code[0]) + with pytest.raises(bad_code[1]): + compile_code(bad_code[0]) else: - with raises(TypeMismatch): - compiler.compile_code(bad_code) + with pytest.raises(TypeMismatch): + compile_code(bad_code) valid_list = [ @@ -39,9 +38,18 @@ def foo(inp: Bytes[10]) -> uint256: def foo(inp: String[10]) -> uint256: return len(inp) """, + """ +BAR: constant(String[5]) = "vyper" +FOO: constant(uint256) = len(BAR) + +@external +def foo() -> uint256: + a: uint256 = FOO + return a + """, ] @pytest.mark.parametrize("good_code", valid_list) def test_list_success(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_method_id.py b/tests/functional/syntax/test_method_id.py new file mode 100644 index 0000000000..849c1b0d55 --- /dev/null +++ b/tests/functional/syntax/test_method_id.py @@ -0,0 +1,50 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidLiteral, InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + a: Bytes[4] = method_id("bar ()") + """, + InvalidLiteral, + ), + ( + """ +FOO: constant(Bytes[4]) = method_id(1) + """, + InvalidType, + ), + ( + """ +FOO: constant(Bytes[4]) = method_id("bar ()") + """, + InvalidLiteral, + ), +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_method_id_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(String[5]) = "foo()" +BAR: constant(Bytes[4]) = method_id(FOO) + +@external +def foo(a: Bytes[4] = BAR): + pass + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_method_id_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_minmax.py b/tests/functional/syntax/test_minmax.py index 2ad3d363f1..78ee74635c 100644 --- a/tests/functional/syntax/test_minmax.py +++ b/tests/functional/syntax/test_minmax.py @@ -1,6 +1,7 @@ import pytest -from vyper.exceptions import InvalidType, TypeMismatch +from vyper import compile_code +from vyper.exceptions import InvalidType, OverflowException, TypeMismatch fail_list = [ ( @@ -19,9 +20,45 @@ def foo(): """, TypeMismatch, ), + ( + """ +@external +def foo(): + a: decimal = min(1.0, 18707220957835557353007165858768422651595.9365500928) + """, + OverflowException, + ), ] @pytest.mark.parametrize("bad_code,exc", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(uint256) = 123 +BAR: constant(uint256) = 456 +BAZ: constant(uint256) = min(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """, + """ +FOO: constant(uint256) = 123 +BAR: constant(uint256) = 456 +BAZ: constant(uint256) = max(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """, +] + + +@pytest.mark.parametrize("good_code", valid_list) +def test_block_success(good_code): + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_minmax_value.py b/tests/functional/syntax/test_minmax_value.py index e154cad23f..8cc3370b42 100644 --- a/tests/functional/syntax/test_minmax_value.py +++ b/tests/functional/syntax/test_minmax_value.py @@ -1,21 +1,39 @@ import pytest +from vyper import compile_code from vyper.exceptions import InvalidType fail_list = [ - """ + ( + """ @external def foo(): a: address = min_value(address) """, - """ + InvalidType, + ), + ( + """ @external def foo(): a: address = max_value(address) """, + InvalidType, + ), + ( + """ +FOO: constant(address) = min_value(address) + +@external +def foo(): + a: address = FOO + """, + InvalidType, + ), ] -@pytest.mark.parametrize("bad_code", fail_list) -def test_block_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code): - assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), InvalidType) +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_block_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) diff --git a/tests/functional/syntax/test_powmod.py b/tests/functional/syntax/test_powmod.py new file mode 100644 index 0000000000..12ea23152c --- /dev/null +++ b/tests/functional/syntax/test_powmod.py @@ -0,0 +1,39 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo(): + a: uint256 = pow_mod256(-1, -1) + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("bad_code,exc", fail_list) +def test_powmod_fail(bad_code, exc): + with pytest.raises(exc): + compile_code(bad_code) + + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(uint256) = 5 +BAZ: constant(uint256) = pow_mod256(FOO, BAR) + +@external +def foo(): + a: uint256 = BAZ + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_powmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_raw_call.py b/tests/functional/syntax/test_raw_call.py index b1286e7a8e..c0b38d1d1e 100644 --- a/tests/functional/syntax/test_raw_call.py +++ b/tests/functional/syntax/test_raw_call.py @@ -1,6 +1,6 @@ import pytest -from vyper import compiler +from vyper import compile_code from vyper.exceptions import ArgumentException, InvalidType, SyntaxException, TypeMismatch fail_list = [ @@ -39,7 +39,7 @@ def foo(): @pytest.mark.parametrize("bad_code,exc", fail_list) def test_raw_call_fail(bad_code, exc): with pytest.raises(exc): - compiler.compile_code(bad_code) + compile_code(bad_code) valid_list = [ @@ -90,9 +90,23 @@ def foo(): value=self.balance - self.balances[0] ) """, + # test constants + """ +OUTSIZE: constant(uint256) = 4 +REVERT_ON_FAILURE: constant(bool) = True +@external +def foo(): + x: Bytes[9] = raw_call( + 0x1234567890123456789012345678901234567890, + b"cow", + max_outsize=OUTSIZE, + gas=595757, + revert_on_failure=REVERT_ON_FAILURE + ) + """, ] @pytest.mark.parametrize("good_code", valid_list) def test_raw_call_success(good_code): - assert compiler.compile_code(good_code) is not None + assert compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_ternary.py b/tests/functional/syntax/test_ternary.py index 325be3e43b..6a2bb9c072 100644 --- a/tests/functional/syntax/test_ternary.py +++ b/tests/functional/syntax/test_ternary.py @@ -1,6 +1,6 @@ import pytest -from vyper.compiler import compile_code +from vyper import compile_code from vyper.exceptions import InvalidType, TypeMismatch good_list = [ @@ -82,7 +82,7 @@ def foo() -> uint256: def foo() -> uint256: return 1 if TEST else 2 """, - InvalidType, + TypeMismatch, ), ( # bad test type: variable """ diff --git a/tests/functional/syntax/test_uint2str.py b/tests/functional/syntax/test_uint2str.py new file mode 100644 index 0000000000..9e6dde30cc --- /dev/null +++ b/tests/functional/syntax/test_uint2str.py @@ -0,0 +1,19 @@ +import pytest + +from vyper import compile_code + +valid_list = [ + """ +FOO: constant(uint256) = 3 +BAR: constant(String[78]) = uint2str(FOO) + +@external +def foo(): + a: String[78] = BAR + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_addmulmod_pass(code): + assert compile_code(code) is not None diff --git a/tests/functional/syntax/test_unary.py b/tests/functional/syntax/test_unary.py new file mode 100644 index 0000000000..5942ee15db --- /dev/null +++ b/tests/functional/syntax/test_unary.py @@ -0,0 +1,21 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import InvalidType + +fail_list = [ + ( + """ +@external +def foo() -> int128: + return -2**127 + """, + InvalidType, + ) +] + + +@pytest.mark.parametrize("code,exc", fail_list) +def test_unary_fail(code, exc): + with pytest.raises(exc): + compile_code(code) diff --git a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py similarity index 93% rename from tests/unit/ast/nodes/test_evaluate_binop_decimal.py rename to tests/unit/ast/nodes/test_fold_binop_decimal.py index 44b82e321d..e426a11de9 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: is_valid = False @@ -49,7 +49,7 @@ def test_binop_pow(): old_node = vyper_ast.body[0].value with pytest.raises(TypeMismatch): - old_node.evaluate() + old_node.get_folded_value() @pytest.mark.fuzzing @@ -74,8 +74,8 @@ def foo({input_value}) -> decimal: literal_op = literal_op.rsplit(maxsplit=1)[0] vyper_ast = vy_ast.parse_to_ast(literal_op) try: - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value is_valid = -(2**127) <= expected < 2**127 except (OverflowException, ZeroDivisionException): # for overflow or division/modulus by 0, expect the contract call to revert diff --git a/tests/unit/ast/nodes/test_evaluate_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py similarity index 93% rename from tests/unit/ast/nodes/test_evaluate_binop_int.py rename to tests/unit/ast/nodes/test_fold_binop_int.py index 405d557f7d..904b36c167 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -27,7 +27,7 @@ def foo(a: int128, b: int128) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: is_valid = False @@ -57,7 +57,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value try: - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() is_valid = new_node.value >= 0 except ZeroDivisionException: is_valid = False @@ -85,7 +85,7 @@ def foo(a: uint256, b: uint256) -> uint256: vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -115,8 +115,8 @@ def foo({input_value}) -> int128: vyper_ast = vy_ast.parse_to_ast(literal_op) try: - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value is_valid = True except ZeroDivisionException: is_valid = False diff --git a/tests/unit/ast/nodes/test_evaluate_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py similarity index 92% rename from tests/unit/ast/nodes/test_evaluate_boolop.py rename to tests/unit/ast/nodes/test_fold_boolop.py index 8b70537c39..3c42da0d26 100644 --- a/tests/unit/ast/nodes/test_evaluate_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -26,7 +26,7 @@ def foo({input_value}) -> bool: vyper_ast = vy_ast.parse_to_ast(literal_op) old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(*values) == new_node.value @@ -53,7 +53,7 @@ def foo({input_value}) -> bool: literal_op = literal_op.rsplit(maxsplit=1)[0] vyper_ast = vy_ast.parse_to_ast(literal_op) - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value assert contract.foo(*values) == expected diff --git a/tests/unit/ast/nodes/test_evaluate_compare.py b/tests/unit/ast/nodes/test_fold_compare.py similarity index 94% rename from tests/unit/ast/nodes/test_evaluate_compare.py rename to tests/unit/ast/nodes/test_fold_compare.py index 07f8e70de6..2b7c0f09d7 100644 --- a/tests/unit/ast/nodes/test_evaluate_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -21,7 +21,7 @@ def foo(a: int128, b: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -41,7 +41,7 @@ def foo(a: uint128, b: uint128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -65,7 +65,7 @@ def bar(a: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() # check runtime == fully folded assert contract.foo(left, right) == new_node.value @@ -94,7 +94,7 @@ def bar(a: int128) -> bool: vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() # check runtime == fully folded assert contract.foo(left, right) == new_node.value @@ -109,4 +109,4 @@ def test_compare_type_mismatch(op): vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): - old_node.evaluate() + old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_evaluate_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py similarity index 93% rename from tests/unit/ast/nodes/test_evaluate_subscript.py rename to tests/unit/ast/nodes/test_fold_subscript.py index ca50a076a5..1884abf73b 100644 --- a/tests/unit/ast/nodes/test_evaluate_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -21,6 +21,6 @@ def foo(array: int128[10], idx: uint256) -> int128: vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(array, idx) == new_node.value diff --git a/tests/unit/ast/nodes/test_evaluate_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py similarity index 86% rename from tests/unit/ast/nodes/test_evaluate_unaryop.py rename to tests/unit/ast/nodes/test_fold_unaryop.py index 63d7a0b7ff..ff48adfe71 100644 --- a/tests/unit/ast/nodes/test_evaluate_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -14,7 +14,7 @@ def foo(a: bool) -> bool: vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() + new_node = old_node.get_folded_value() assert contract.foo(bool_cond) == new_node.value @@ -31,7 +31,7 @@ def foo(a: bool) -> bool: literal_op = f"{'not ' * count}{bool_cond}" vyper_ast = vy_ast.parse_to_ast(literal_op) - vy_ast.folding.replace_literal_ops(vyper_ast) - expected = vyper_ast.body[0].value.value + new_node = vyper_ast.body[0].value.get_folded_value() + expected = new_node.value assert contract.foo(bool_cond) == expected diff --git a/tests/unit/ast/nodes/test_replace_in_tree.py b/tests/unit/ast/nodes/test_replace_in_tree.py deleted file mode 100644 index 682e7ce7de..0000000000 --- a/tests/unit/ast/nodes/test_replace_in_tree.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest - -from vyper import ast as vy_ast -from vyper.exceptions import CompilerPanic - - -def test_assumptions(): - # ASTs generated separately from the same source should compare equal - test_tree = vy_ast.parse_to_ast("foo = 42") - expected_tree = vy_ast.parse_to_ast("foo = 42") - assert vy_ast.compare_nodes(test_tree, expected_tree) - - # ASTs generated separately with different source should compare not-equal - test_tree = vy_ast.parse_to_ast("foo = 42") - expected_tree = vy_ast.parse_to_ast("bar = 666") - assert not vy_ast.compare_nodes(test_tree, expected_tree) - - -def test_simple_replacement(): - test_tree = vy_ast.parse_to_ast("foo = 42") - expected_tree = vy_ast.parse_to_ast("bar = 42") - - old_node = test_tree.body[0].target - new_node = vy_ast.parse_to_ast("bar").body[0].value - - test_tree.replace_in_tree(old_node, new_node) - - assert vy_ast.compare_nodes(test_tree, expected_tree) - - -def test_list_replacement_similar_nodes(): - test_tree = vy_ast.parse_to_ast("foo = [1, 1, 1, 1, 1]") - expected_tree = vy_ast.parse_to_ast("foo = [1, 1, 31337, 1, 1]") - - old_node = test_tree.body[0].value.elements[2] - new_node = vy_ast.parse_to_ast("31337").body[0].value - - test_tree.replace_in_tree(old_node, new_node) - - assert vy_ast.compare_nodes(test_tree, expected_tree) - - -def test_parents_children(): - test_tree = vy_ast.parse_to_ast("foo = 42") - - old_node = test_tree.body[0].target - parent = old_node.get_ancestor() - - new_node = vy_ast.parse_to_ast("bar").body[0].value - test_tree.replace_in_tree(old_node, new_node) - - assert old_node.get_ancestor() == new_node.get_ancestor() - - assert old_node not in parent.get_children() - assert new_node in parent.get_children() - - assert old_node not in test_tree.get_descendants() - assert new_node in test_tree.get_descendants() - - -def test_cannot_replace_twice(): - test_tree = vy_ast.parse_to_ast("foo = 42") - old_node = test_tree.body[0].target - - new_node = vy_ast.parse_to_ast("42").body[0].value - - test_tree.replace_in_tree(old_node, new_node) - - with pytest.raises(CompilerPanic): - test_tree.replace_in_tree(old_node, new_node) diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index dc49f72561..20390f3d5e 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -41,8 +41,8 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) - assert dict_out["ast_dict"]["ast"]["body"][0] == { + dict_out = compiler.compile_code(code, output_formats=["annotated_ast_dict"], source_id=0) + assert dict_out["annotated_ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", "col_offset": 3, @@ -69,12 +69,14 @@ def test_basic_ast(): "lineno": 2, "node_id": 2, "src": "1:1:0", + "type": "int128", }, "value": None, "is_constant": False, "is_immutable": False, "is_public": False, "is_transient": False, + "type": "int128", } diff --git a/tests/unit/ast/test_folding.py b/tests/unit/ast/test_folding.py deleted file mode 100644 index 62a7140e97..0000000000 --- a/tests/unit/ast/test_folding.py +++ /dev/null @@ -1,272 +0,0 @@ -import pytest - -from vyper import ast as vy_ast -from vyper.ast import folding -from vyper.exceptions import OverflowException - - -def test_integration(): - test_ast = vy_ast.parse_to_ast("[1+2, 6+7][8-8]") - expected_ast = vy_ast.parse_to_ast("3") - - folding.fold(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_simple(): - test_ast = vy_ast.parse_to_ast("1 + 2") - expected_ast = vy_ast.parse_to_ast("3") - - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_nested(): - test_ast = vy_ast.parse_to_ast("((6 + (2**4)) * 4) / 2") - expected_ast = vy_ast.parse_to_ast("44") - - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_binop_nested_intermediate_overflow(): - test_ast = vy_ast.parse_to_ast("2**255 * 2 / 10") - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_binop_nested_intermediate_underflow(): - test_ast = vy_ast.parse_to_ast("-2**255 * 2 - 10 + 100") - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_decimal_nested_intermediate_overflow(): - test_ast = vy_ast.parse_to_ast( - "18707220957835557353007165858768422651595.9365500927 + 1e-10 - 1e-10" - ) - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_decimal_nested_intermediate_underflow(): - test_ast = vy_ast.parse_to_ast( - "-18707220957835557353007165858768422651595.9365500928 - 1e-10 + 1e-10" - ) - with pytest.raises(OverflowException): - folding.fold(test_ast) - - -def test_replace_literal_ops(): - test_ast = vy_ast.parse_to_ast("[not True, True and False, True or False]") - expected_ast = vy_ast.parse_to_ast("[False, False, True]") - - folding.replace_literal_ops(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_subscripts_simple(): - test_ast = vy_ast.parse_to_ast("[foo, bar, baz][1]") - expected_ast = vy_ast.parse_to_ast("bar") - - folding.replace_subscripts(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -def test_replace_subscripts_nested(): - test_ast = vy_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]") - expected_ast = vy_ast.parse_to_ast("5") - - folding.replace_subscripts(test_ast) - - assert vy_ast.compare_nodes(test_ast, expected_ast) - - -constants_modified = [ - "bar = FOO", - "bar: int128[FOO]", - "[1, 2, FOO]", - "def bar(a: int128 = FOO): pass", - "log bar(FOO)", - "FOO + 1", - "a: int128[FOO / 2]", - "a[FOO - 1] = 44", -] - - -@pytest.mark.parametrize("source", constants_modified) -def test_replace_constant(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -constants_unmodified = [ - "FOO = 42", - "self.FOO = 42", - "bar = FOO()", - "FOO()", - "bar = FOO()", - "bar = self.FOO", - "log FOO(bar)", - "[1, 2, FOO()]", - "FOO[42] = 2", -] - - -@pytest.mark.parametrize("source", constants_unmodified) -def test_replace_constant_no(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), True) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -userdefined_modified = [ - "FOO", - "foo = FOO", - "foo: int128[FOO] = 42", - "foo = [FOO]", - "foo += FOO", - "def foo(bar: int128 = FOO): pass", - "def foo(): bar = FOO", - "def foo(): return FOO", -] - - -@pytest.mark.parametrize("source", userdefined_modified) -def test_replace_userdefined_constant(source): - source = f"FOO: constant(int128) = 42\n{source}" - - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_user_defined_constants(folded_ast) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -userdefined_unmodified = [ - "FOO: constant(int128) = 42", - "FOO = 42", - "FOO += 42", - "FOO()", - "def foo(FOO: int128 = 42): pass", - "def foo(): FOO = 42", - "def FOO(): pass", -] - - -@pytest.mark.parametrize("source", userdefined_unmodified) -def test_replace_userdefined_constant_no(source): - source = f"FOO: constant(int128) = 42\n{source}" - - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_user_defined_constants(folded_ast) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -dummy_address = "0x000000000000000000000000000000000000dEaD" -userdefined_attributes = [("b: uint256 = ADDR.balance", f"b: uint256 = {dummy_address}.balance")] - - -@pytest.mark.parametrize("source", userdefined_attributes) -def test_replace_userdefined_attribute(source): - preamble = f"ADDR: constant(address) = {dummy_address}" - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -userdefined_struct = [("b: Foo = FOO", "b: Foo = Foo({a: 123, b: 456})")] - - -@pytest.mark.parametrize("source", userdefined_struct) -def test_replace_userdefined_struct(source): - preamble = """ -struct Foo: - a: uint256 - b: uint256 - -FOO: constant(Foo) = Foo({a: 123, b: 456}) - """ - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -userdefined_nested_struct = [ - ("b: Foo = FOO", "b: Foo = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789})") -] - - -@pytest.mark.parametrize("source", userdefined_nested_struct) -def test_replace_userdefined_nested_struct(source): - preamble = """ -struct Bar: - b1: uint256 - b2: uint256 - -struct Foo: - f1: Bar - f2: uint256 - -FOO: constant(Foo) = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789}) - """ - l_source = f"{preamble}\n{source[0]}" - r_source = f"{preamble}\n{source[1]}" - - l_ast = vy_ast.parse_to_ast(l_source) - folding.replace_user_defined_constants(l_ast) - - r_ast = vy_ast.parse_to_ast(r_source) - - assert vy_ast.compare_nodes(l_ast, r_ast) - - -builtin_folding_functions = [("ceil(4.2)", "5"), ("floor(4.2)", "4")] - -builtin_folding_sources = [ - "{}", - "foo = {}", - "foo = [{0}, {0}]", - "def foo(): {}", - "def foo(): return {}", - "def foo(bar: {}): pass", -] - - -@pytest.mark.parametrize("source", builtin_folding_sources) -@pytest.mark.parametrize("original,result", builtin_folding_functions) -def test_replace_builtins(source, original, result): - original_ast = vy_ast.parse_to_ast(source.format(original)) - target_ast = vy_ast.parse_to_ast(source.format(result)) - - folding.replace_builtin_functions(original_ast) - - assert vy_ast.compare_nodes(original_ast, target_ast) diff --git a/tests/unit/ast/test_natspec.py b/tests/unit/ast/test_natspec.py index c2133468aa..22167f8694 100644 --- a/tests/unit/ast/test_natspec.py +++ b/tests/unit/ast/test_natspec.py @@ -60,7 +60,7 @@ def doesEat(food: String[30], qty: uint256) -> bool: def parse_natspec(code): - vyper_ast = CompilerData(code).vyper_module_folded + vyper_ast = CompilerData(code).annotated_vyper_module return vy_ast.parse_natspec(vyper_ast) diff --git a/vyper/ast/README.md b/vyper/ast/README.md index 320c69da0c..7400091993 100644 --- a/vyper/ast/README.md +++ b/vyper/ast/README.md @@ -12,8 +12,6 @@ and parsing NatSpec docstrings. * [`annotation.py`](annotation.py): Contains the `AnnotatingVisitor` class, used to annotate and modify the Python AST prior to converting it to a Vyper AST. -* [`folding.py`](folding.py): Functions for evaluating and replacing literal -nodes within the Vyper AST. * [`natspec.py`](natspec.py): Functions for parsing NatSpec docstrings within the source. * [`nodes.py`](nodes.py): Contains the Vyper node classes, and the `get_node` @@ -70,25 +68,6 @@ or parents that match a desired pattern. To learn more about these methods, read their docstrings in the `VyperNode` class in [`nodes.py`](nodes.py). -### Modifying the AST - -[`folding.py`](folding.py) contains the `fold` function, a high-level method called -to evaluating and replacing literal nodes within the AST. Some examples of literal -folding include: - -* arithmetic operations (`3+2` becomes `5`) -* references to literal arrays (`["foo", "bar"][1]` becomes `"bar"`) -* builtin functions applied to literals (`min(1,2)` becomes `1`) - -The process of literal folding includes: - -1. Foldable node classes are evaluated via their `evaluate` method, which attempts -to create a new `Constant` from the content of the given node. -2. Replacement nodes are generated using the `from_node` class method within the new -node class. -3. The modification of the tree is handled by `Module.replace_in_tree`, which locates -the existing node and replaces it with a new one. - ## Design ### `__slots__` diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index 4b46801153..bc08626b59 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -17,4 +17,4 @@ # required to avoid circular dependency -from . import expansion, folding # noqa: E402 +from . import expansion # noqa: E402 diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index eac8ffdef5..5581e82fe2 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -1,7 +1,7 @@ import ast as python_ast from typing import Any, Optional, Union -from . import expansion, folding, nodes, validation +from . import expansion, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * from .parse import parse_to_ast as parse_to_ast diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py deleted file mode 100644 index 087708a356..0000000000 --- a/vyper/ast/folding.py +++ /dev/null @@ -1,263 +0,0 @@ -from typing import Optional, Union - -from vyper.ast import nodes as vy_ast -from vyper.builtins.functions import DISPATCH_TABLE -from vyper.exceptions import UnfoldableNode, UnknownType -from vyper.semantics.types.base import VyperType -from vyper.semantics.types.utils import type_from_annotation - - -def fold(vyper_module: vy_ast.Module) -> None: - """ - Perform literal folding operations on a Vyper AST. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - changed_nodes = 1 - while changed_nodes: - changed_nodes = 0 - changed_nodes += replace_user_defined_constants(vyper_module) - changed_nodes += replace_literal_ops(vyper_module) - changed_nodes += replace_subscripts(vyper_module) - changed_nodes += replace_builtin_functions(vyper_module) - - -def replace_literal_ops(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate operation and comparison nodes within the Vyper AST, - replacing them with Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - node_types = (vy_ast.BoolOp, vy_ast.BinOp, vy_ast.UnaryOp, vy_ast.Compare) - for node in vyper_module.get_descendants(node_types, reverse=True): - try: - new_node = node.evaluate() - except UnfoldableNode: - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_subscripts(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate Subscript nodes within the Vyper AST, replacing them with - Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Subscript, reverse=True): - try: - new_node = node.evaluate() - except UnfoldableNode: - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: - """ - Find and evaluate builtin function calls within the Vyper AST, replacing - them with Constant nodes where possible. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Call, reverse=True): - if not isinstance(node.func, vy_ast.Name): - continue - - name = node.func.id - func = DISPATCH_TABLE.get(name) - if func is None or not hasattr(func, "evaluate"): - continue - try: - new_node = func.evaluate(node) # type: ignore - except UnfoldableNode: - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes - - -def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: - """ - Find user-defined constant assignments, and replace references - to the constants with their literal values. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_children(vy_ast.VariableDecl): - if not isinstance(node.target, vy_ast.Name): - # left-hand-side of assignment is not a variable - continue - if not node.is_constant: - # annotation is not wrapped in `constant(...)` - continue - - # Extract type definition from propagated annotation - type_ = None - try: - type_ = type_from_annotation(node.annotation) - except UnknownType: - # handle user-defined types e.g. structs - it's OK to not - # propagate the type annotation here because user-defined - # types can be unambiguously inferred at typechecking time - pass - - changed_nodes += replace_constant( - vyper_module, node.target.id, node.value, False, type_=type_ - ) - - return changed_nodes - - -# TODO constant folding on log events - - -def _replace(old_node, new_node, type_=None): - if isinstance(new_node, vy_ast.Constant): - new_node = new_node.from_node(old_node, value=new_node.value) - if type_: - new_node._metadata["type"] = type_ - return new_node - elif isinstance(new_node, vy_ast.List): - base_type = type_.value_type if type_ else None - list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements] - new_node = new_node.from_node(old_node, elements=list_values) - if type_: - new_node._metadata["type"] = type_ - return new_node - elif isinstance(new_node, vy_ast.Call): - # Replace `Name` node with `Call` node - keyword = keywords = None - if hasattr(new_node, "keyword"): - keyword = new_node.keyword - if hasattr(new_node, "keywords"): - keywords = new_node.keywords - new_node = new_node.from_node( - old_node, func=new_node.func, args=new_node.args, keyword=keyword, keywords=keywords - ) - return new_node - else: - raise UnfoldableNode - - -def replace_constant( - vyper_module: vy_ast.Module, - id_: str, - replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], - raise_on_error: bool, - type_: Optional[VyperType] = None, -) -> int: - """ - Replace references to a variable name with a literal value. - - Arguments - --------- - vyper_module : Module - Module-level ast node to perform replacement in. - id_ : str - String representing the `.id` attribute of the node(s) to be replaced. - replacement_node : Constant | List | Call - Vyper ast node representing the literal value to be substituted in. - `Call` nodes are for struct constants. - raise_on_error: bool - Boolean indicating if `UnfoldableNode` exception should be raised or ignored. - type_ : VyperType, optional - Type definition to be propagated to type checker. - - Returns - ------- - int - Number of nodes that were replaced. - """ - changed_nodes = 0 - - for node in vyper_module.get_descendants(vy_ast.Name, {"id": id_}, reverse=True): - parent = node.get_ancestor() - - if isinstance(parent, vy_ast.Call) and node == parent.func: - # do not replace calls because splicing a constant into a callable site is - # never valid and it worsens the error message - continue - - # do not replace dictionary keys - if isinstance(parent, vy_ast.Dict) and node in parent.keys: - continue - - if not node.get_ancestor(vy_ast.Index): - # do not replace left-hand side of assignments - assign = node.get_ancestor( - (vy_ast.Assign, vy_ast.AnnAssign, vy_ast.AugAssign, vy_ast.VariableDecl) - ) - - if assign and node in assign.target.get_descendants(include_self=True): - continue - - # do not replace enum members - if node.get_ancestor(vy_ast.FlagDef): - continue - - try: - # note: _replace creates a copy of the replacement_node - new_node = _replace(node, replacement_node, type_=type_) - except UnfoldableNode: - if raise_on_error: - raise - continue - - changed_nodes += 1 - vyper_module.replace_in_tree(node, new_node) - - return changed_nodes diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py index 41905b178a..41a6703b6e 100644 --- a/vyper/ast/natspec.py +++ b/vyper/ast/natspec.py @@ -11,13 +11,13 @@ USERDOCS_FIELDS = ("notice",) -def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: +def parse_natspec(annotated_vyper_module: vy_ast.Module) -> Tuple[dict, dict]: """ Parses NatSpec documentation from a contract. Arguments --------- - vyper_module_folded : Module + annotated_vyper_module: Module Module-level vyper ast node. interface_codes: Dict, optional Dict containing relevant data for any import statements related to @@ -33,15 +33,15 @@ def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: from vyper.semantics.types.function import FunctionVisibility userdoc, devdoc = {}, {} - source: str = vyper_module_folded.full_source_code + source: str = annotated_vyper_module.full_source_code - docstring = vyper_module_folded.get("doc_string.value") + docstring = annotated_vyper_module.get("doc_string.value") if docstring: devdoc.update(_parse_docstring(source, docstring, ("param", "return"))) if "notice" in devdoc: userdoc["notice"] = devdoc.pop("notice") - for node in [i for i in vyper_module_folded.body if i.get("doc_string.value")]: + for node in [i for i in annotated_vyper_module.body if i.get("doc_string.value")]: docstring = node.doc_string.value func_type = node._metadata["func_type"] if func_type.visibility != FunctionVisibility.EXTERNAL: diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index dba9f2a22d..efab5117d4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -11,7 +11,6 @@ from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS from vyper.exceptions import ( ArgumentException, - CompilerPanic, InvalidLiteral, InvalidOperation, OverflowException, @@ -19,6 +18,7 @@ SyntaxException, TypeMismatch, UnfoldableNode, + VariableDeclarationException, VyperException, ZeroDivisionException, ) @@ -210,23 +210,6 @@ def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None: ) -def _validate_numeric_bounds( - node: Union["BinOp", "UnaryOp"], value: Union[decimal.Decimal, int] -) -> None: - if isinstance(value, decimal.Decimal): - # this will change if/when we add more decimal types - lower, upper = SizeLimits.MIN_AST_DECIMAL, SizeLimits.MAX_AST_DECIMAL - elif isinstance(value, int): - lower, upper = SizeLimits.MIN_INT256, SizeLimits.MAX_UINT256 - else: - raise CompilerPanic(f"Unexpected return type from {node._op}: {type(value)}") - if not lower <= value <= upper: - raise OverflowException( - f"Result of {node.op.description} ({value}) is outside bounds of all numeric types", - node, - ) - - class VyperNode: """ Base class for all vyper AST nodes. @@ -246,7 +229,7 @@ class VyperNode: Field names that, if present, must be set to None or a `SyntaxException` is raised. This attribute is used to exclude syntax that is valid in Python but not in Vyper. - _terminus : bool, optional + _is_terminus : bool, optional If `True`, indicates that execution halts upon reaching this node. _translated_fields : Dict, optional Field names that are reassigned if encountered. Used to normalize fields @@ -390,22 +373,67 @@ def description(self): """ return getattr(self, "_description", type(self).__name__) - def evaluate(self) -> "VyperNode": + @property + def is_literal_value(self): + """ + Check if the node is a literal value. + """ + return False + + @property + def has_folded_value(self): + """ + Property method to check if the node has a folded value. + """ + return "folded_value" in self._metadata + + def get_folded_value(self) -> "VyperNode": """ - Attempt to evaluate the content of a node and generate a new node from it. + Attempt to get the folded value, bubbling up UnfoldableNode if the node + is not foldable. + + + The returned value is cached on `_metadata["folded_value"]`. - If a node cannot be evaluated it should raise `UnfoldableNode`. This base - method acts as a catch-all to raise on any inherited classes that do not - implement the method. + For constant/literal nodes, the node should be directly returned + without caching to the metadata. """ - raise UnfoldableNode(f"{type(self)} cannot be evaluated") + if self.is_literal_value: + return self + + if "folded_value" not in self._metadata: + res = self._try_fold() # possibly throws UnfoldableNode + self._set_folded_value(res) + + return self._metadata["folded_value"] + + def _set_folded_value(self, node: "VyperNode") -> None: + # sanity check this is only called once + assert "folded_value" not in self._metadata + + # set the folded node's parent so that get_ancestor works + # this is mainly important for error messages. + node._parent = self._parent + + self._metadata["folded_value"] = node + + def _try_fold(self) -> "VyperNode": + """ + Attempt to constant-fold the content of a node, returning the result of + constant-folding if possible. + + If a node cannot be folded, it should raise `UnfoldableNode`. This + base implementation acts as a catch-all to raise on any inherited + classes that do not implement the method. + """ + raise UnfoldableNode(f"{type(self)} cannot be folded") def validate(self) -> None: """ Validate the content of a node. - Called by `ast.validation.validate_literal_nodes` to verify values within - literal nodes. + Called by `ast.validation.validate_literal_nodes` to verify values + within literal nodes. Returns `None` if the node is valid, raises `InvalidLiteral` or another more expressive exception if the value cannot be valid within a Vyper @@ -609,48 +637,6 @@ class Module(TopLevel): # metadata __slots__ = ("path", "resolved_path", "source_id") - def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: - """ - Perform an in-place substitution of a node within the tree. - - Parameters - ---------- - old_node : VyperNode - Node object to be replaced. - new_node : VyperNode - Node object to replace new_node. - - Returns - ------- - None - """ - parent = old_node._parent - - if old_node not in parent._children: - raise CompilerPanic("Node to be replaced does not exist within parent children") - - is_replaced = False - for key in parent.get_fields(): - obj = getattr(parent, key, None) - if obj == old_node: - if is_replaced: - raise CompilerPanic("Node to be replaced exists as multiple members in parent") - setattr(parent, key, new_node) - is_replaced = True - elif isinstance(obj, list) and obj.count(old_node): - if is_replaced or obj.count(old_node) > 1: - raise CompilerPanic("Node to be replaced exists as multiple members in parent") - obj[obj.index(old_node)] = new_node - is_replaced = True - if not is_replaced: - raise CompilerPanic("Node to be replaced does not exist within parent members") - - parent._children.remove(old_node) - - new_node._parent = parent - new_node._depth = old_node._depth - parent._children.add(new_node) - def add_to_body(self, node: VyperNode) -> None: """ Add a new node to the body of this node. @@ -769,6 +755,10 @@ class Constant(ExprNode): # inherited class for all simple constant node types __slots__ = ("value",) + @property + def is_literal_value(self): + return True + class Num(Constant): # inherited class for all numeric constant node types @@ -862,7 +852,14 @@ def n_bytes(self): """ The number of bytes this hex value represents """ - return self.n_nibbles // 2 + return len(self.bytes_value) + + @property + def bytes_value(self): + """ + This value as bytes + """ + return bytes.fromhex(self.value.removeprefix("0x")) class Str(Constant): @@ -905,19 +902,39 @@ class List(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} + @property + def is_literal_value(self): + return all(e.is_literal_value for e in self.elements) + + def _try_fold(self) -> ExprNode: + elements = [e.get_folded_value() for e in self.elements] + return type(self).from_node(self, elements=elements) + class Tuple(ExprNode): __slots__ = ("elements",) _translated_fields = {"elts": "elements"} + @property + def is_literal_value(self): + return all(e.is_literal_value for e in self.elements) + def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) + def _try_fold(self) -> ExprNode: + elements = [e.get_folded_value() for e in self.elements] + return type(self).from_node(self, elements=elements) + class NameConstant(Constant): __slots__ = () + def validate(self): + if self.value is None: + raise InvalidLiteral("`None` is not a valid vyper value!", self) + class Ellipsis(Constant): __slots__ = () @@ -926,6 +943,14 @@ class Ellipsis(Constant): class Dict(ExprNode): __slots__ = ("keys", "values") + @property + def is_literal_value(self): + return all(v.is_literal_value for v in self.values) + + def _try_fold(self) -> ExprNode: + values = [v.get_folded_value() for v in self.values] + return type(self).from_node(self, values=values) + class Name(ExprNode): __slots__ = ("id",) @@ -934,7 +959,7 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the unary operation. @@ -943,16 +968,17 @@ def evaluate(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - if isinstance(self.op, Not) and not isinstance(self.operand, NameConstant): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, USub) and not isinstance(self.operand, (Int, Decimal)): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, Invert) and not isinstance(self.operand, Int): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + operand = self.operand.get_folded_value() - value = self.op._op(self.operand.value) - _validate_numeric_bounds(self, value) - return type(self.operand).from_node(self, value=value) + if isinstance(self.op, Not) and not isinstance(operand, NameConstant): + raise UnfoldableNode("not a boolean!", self.operand) + if isinstance(self.op, USub) and not isinstance(operand, Num): + raise UnfoldableNode("not a number!", self.operand) + if isinstance(self.op, Invert) and not isinstance(operand, Int): + raise UnfoldableNode("not an int!", self.operand) + + value = self.op._op(operand.value) + return type(operand).from_node(self, value=value) class Operator(VyperNode): @@ -982,7 +1008,7 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the arithmetic operation. @@ -991,20 +1017,19 @@ def evaluate(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value() for i in (self.left, self.right)] if type(left) is not type(right): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if not isinstance(left, (Int, Decimal)): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + raise UnfoldableNode("invalid operation", self) + if not isinstance(left, Num): + raise UnfoldableNode("not a number!", self.left) # this validation is performed to prevent the compiler from hanging # on very large shifts and improve the error message for negative # values. if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", right) + raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) value = self.op._op(left.value, right.value) - _validate_numeric_bounds(self, value) return type(left).from_node(self, value=value) @@ -1132,7 +1157,7 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the boolean operation. @@ -1141,13 +1166,12 @@ def evaluate(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - if next((i for i in self.values if not isinstance(i, NameConstant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [v.get_folded_value() for v in self.values] - values = [i.value for i in self.values] - if None in values: + if any(not isinstance(v, NameConstant) for v in values): raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [v.value for v in values] value = self.op._op(values) return NameConstant.from_node(self, value=value) @@ -1188,7 +1212,7 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the comparison. @@ -1197,7 +1221,7 @@ def evaluate(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value() for i in (self.left, self.right)] if not isinstance(left, Constant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1278,6 +1302,21 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") + # try checking if this is a builtin, which is foldable + def _try_fold(self): + if not isinstance(self.func, Name): + raise UnfoldableNode("not a builtin", self) + + # cursed import cycle! + from vyper.builtins.functions import DISPATCH_TABLE + + func_name = self.func.id + if func_name not in DISPATCH_TABLE: + raise UnfoldableNode("not a builtin", self) + + builtin_t = DISPATCH_TABLE[func_name] + return builtin_t._try_fold(self) + class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1290,7 +1329,7 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") - def evaluate(self) -> ExprNode: + def _try_fold(self) -> ExprNode: """ Attempt to evaluate the subscript. @@ -1302,14 +1341,22 @@ def evaluate(self) -> ExprNode: ExprNode Node representing the result of the evaluation. """ - if not isinstance(self.value, List): + slice_ = self.slice.value.get_folded_value() + value = self.value.get_folded_value() + + if not isinstance(value, List): raise UnfoldableNode("Subscript object is not a literal list") - elements = self.value.elements + + elements = value.elements if len(set([type(i) for i in elements])) > 1: raise UnfoldableNode("List contains multiple node types") - idx = self.slice.get("value.value") - if not isinstance(idx, int) or idx < 0 or idx >= len(elements): - raise UnfoldableNode("Invalid index value") + + if not isinstance(slice_, Int): + raise UnfoldableNode("invalid index type", slice_) + + idx = slice_.value + if idx < 0 or idx >= len(elements): + raise UnfoldableNode("invalid index value") return elements[idx] @@ -1410,6 +1457,24 @@ def _check_args(annotation, call_name): if isinstance(self.annotation, Call): _raise_syntax_exc("Invalid scope for variable declaration", self.annotation) + def _pretty_location(self) -> str: + if self.is_constant: + return "Constant" + if self.is_transient: + return "Transient" + if self.is_immutable: + return "Immutable" + return "Storage" + + def validate(self): + if self.is_constant and self.value is None: + raise VariableDeclarationException("Constant must be declared with a value", self) + + if not self.is_constant and self.value is not None: + raise VariableDeclarationException( + f"{self._pretty_location} variables cannot have an initial value", self.value + ) + class AugAssign(Stmt): __slots__ = ("op", "target", "value") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47856b6021..8bc4a4eb57 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -24,9 +24,15 @@ class VyperNode: def __eq__(self, other: Any) -> Any: ... @property def description(self): ... + @property + def is_literal_value(self): ... + @property + def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def evaluate(self) -> VyperNode: ... + def get_folded_value(self) -> VyperNode: ... + def _try_fold(self) -> VyperNode: ... + def _set_folded_value(self, node: VyperNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... @@ -35,14 +41,14 @@ class VyperNode: node_type: Union[Type[VyperNode], Sequence[Type[VyperNode]], None] = ..., filters: Optional[dict] = ..., reverse: bool = ..., - ) -> Sequence: ... + ) -> list: ... def get_descendants( self, node_type: Union[Type[VyperNode], Sequence[Type[VyperNode]], None] = ..., filters: Optional[dict] = ..., include_self: bool = ..., reverse: bool = ..., - ) -> Sequence: ... + ) -> list: ... def get_ancestor( self, node_type: Union[Type[VyperNode], Sequence[Type[VyperNode]], None] = ... ) -> VyperNode: ... @@ -61,7 +67,6 @@ class TopLevel(VyperNode): class Module(TopLevel): path: str = ... resolved_path: str = ... - def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ... def add_to_body(self, node: VyperNode) -> None: ... def remove_from_body(self, node: VyperNode) -> None: ... def namespace(self) -> Any: ... # context manager diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index a2f2542179..38a9d31695 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -81,6 +81,7 @@ def parse_to_ast_with_settings( # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module diff --git a/vyper/ast/validation.py b/vyper/ast/validation.py index 36a6a0484c..387f7734b9 100644 --- a/vyper/ast/validation.py +++ b/vyper/ast/validation.py @@ -1,11 +1,11 @@ # validation utils for ast -# TODO this really belongs in vyper/semantics/validation/utils from typing import Optional, Union from vyper.ast import nodes as vy_ast from vyper.exceptions import ArgumentException, CompilerPanic, StructureException +# TODO this really belongs in vyper/semantics/validation/utils def validate_call_args( node: vy_ast.Call, arg_count: Union[int, tuple], kwargs: Optional[list] = None ) -> None: @@ -101,14 +101,13 @@ def validate_literal_nodes(vyper_module: vy_ast.Module) -> None: """ Individually validate Vyper AST nodes. - Calls the `validate` method of each node to verify that literal nodes - do not contain invalid values. + Recursively calls the `validate` method of each node to verify that + literal nodes do not contain invalid values. Arguments --------- vyper_module : vy_ast.Module Top level Vyper AST node. """ - for node in vyper_module.get_descendants(): - if hasattr(node, "validate"): - node.validate() + for node in vyper_module.get_descendants(include_self=True): + node.validate() diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index a5949dfd85..aac008ad1e 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,12 +1,17 @@ import functools from typing import Any, Optional -from vyper.ast import nodes as vy_ast +from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode -from vyper.exceptions import CompilerPanic, TypeMismatch -from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type +from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode +from vyper.semantics.analysis.base import Modifiability +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -29,7 +34,7 @@ def process_arg(arg, expected_arg_type, context): def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context): if kwarg_settings.require_literal: - return kwarg_node.value + return kwarg_node.get_folded_value().value return process_arg(kwarg_node, expected_kwarg_type, context) @@ -78,6 +83,7 @@ class BuiltinFunctionT(VyperType): _has_varargs = False _inputs: list[tuple[str, Any]] = [] _kwargs: dict[str, KwargSettings] = {} + _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None # helper function to deal with TYPE_DEFINITIONs @@ -106,8 +112,10 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): - raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) + if kwarg_settings.require_literal and not check_modifiability( + kwarg.value, Modifiability.CONSTANT + ): + raise TypeMismatch("Value must be literal", kwarg.value) self._validate_single(kwarg.value, kwarg_settings.typ) # typecheck varargs. we don't have type info from the signature, @@ -125,7 +133,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: return self._return_type - def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]: + def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] @@ -142,3 +150,6 @@ def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]: def __repr__(self): return f"(builtin) {self._id}" + + def _try_fold(self, node): + raise UnfoldableNode(f"not foldable: {self}", node) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index d50a31767d..c896fc7ef6 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1,7 +1,6 @@ import hashlib import math import operator -from decimal import Decimal from vyper import ast as vy_ast from vyper.abi_types import ABI_Tuple @@ -44,14 +43,13 @@ CompilerPanic, InvalidLiteral, InvalidType, - OverflowException, StateAccessViolation, StructureException, TypeMismatch, UnfoldableNode, ZeroDivisionException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import ( get_common_types, get_exact_type_from_node, @@ -88,7 +86,6 @@ EIP_170_LIMIT, SHA3_PER_WORD, MemoryPositions, - SizeLimits, bytes_to_int, ceil32, fourbytes_to_int, @@ -108,9 +105,7 @@ class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded - # Since foldable builtin functions are not folded before semantics validation, - # this flag is used for `check_kwargable` in semantics validation. - _kwargable = True + _modifiability = Modifiability.CONSTANT class TypenameFoldedFunctionT(FoldedFunctionT): @@ -126,7 +121,7 @@ def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef return type_ - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 1) input_typedef = TYPE_T(type_from_annotation(node.args[0])) return [input_typedef] @@ -138,12 +133,13 @@ class Floor(BuiltinFunctionT): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Decimal): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode - value = math.floor(node.args[0].value) + value = math.floor(value.value) return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -168,12 +164,13 @@ class Ceil(BuiltinFunctionT): # TODO: maybe use int136? _return_type = INT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Decimal): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode - value = math.ceil(node.args[0].value) + value = math.ceil(value.value) return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -202,7 +199,7 @@ def fetch_call_return(self, node): return target_typedef.typedef # TODO: push this down into convert.py for more consistency - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 2) target_type = type_from_annotation(node.args[1]) @@ -337,7 +334,7 @@ def fetch_call_return(self, node): return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `b` b_type = get_possible_types_from_node(node.args[0]).pop() @@ -461,20 +458,19 @@ class Len(BuiltinFunctionT): _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - arg = node.args[0] + arg = node.args[0].get_folded_value() if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)): length = len(arg.value) elif isinstance(arg, vy_ast.Hex): - # 2 characters represent 1 byte and we subtract 1 to ignore the leading `0x` - length = len(arg.value) // 2 - 1 + length = len(arg.bytes_value) else: raise UnfoldableNode return vy_ast.Int.from_node(node, value=length) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type typ = get_possible_types_from_node(node.args[0]).pop() @@ -504,7 +500,7 @@ def fetch_call_return(self, node): return_type.set_length(length) return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if len(node.args) < 2: raise ArgumentException("Invalid argument count: expected at least 2", node) @@ -598,22 +594,22 @@ class Keccak256(BuiltinFunctionT): _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if isinstance(node.args[0], vy_ast.Bytes): - value = node.args[0].value - elif isinstance(node.args[0], vy_ast.Str): - value = node.args[0].value.encode() - elif isinstance(node.args[0], vy_ast.Hex): - length = len(node.args[0].value) // 2 - 1 - value = int(node.args[0].value, 16).to_bytes(length, "big") + value = node.args[0].get_folded_value() + if isinstance(value, vy_ast.Bytes): + value = value.value + elif isinstance(value, vy_ast.Str): + value = value.value.encode() + elif isinstance(value, vy_ast.Hex): + value = value.bytes_value else: raise UnfoldableNode hash_ = f"0x{keccak256(value).hex()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -645,22 +641,22 @@ class Sha256(BuiltinFunctionT): _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if isinstance(node.args[0], vy_ast.Bytes): - value = node.args[0].value - elif isinstance(node.args[0], vy_ast.Str): - value = node.args[0].value.encode() - elif isinstance(node.args[0], vy_ast.Hex): - length = len(node.args[0].value) // 2 - 1 - value = int(node.args[0].value, 16).to_bytes(length, "big") + value = node.args[0].get_folded_value() + if isinstance(value, vy_ast.Bytes): + value = value.value + elif isinstance(value, vy_ast.Str): + value = value.value.encode() + elif isinstance(value, vy_ast.Hex): + value = value.bytes_value else: raise UnfoldableNode hash_ = f"0x{hashlib.sha256(value).hexdigest()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -714,18 +710,20 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" + _inputs = [("value", StringT.any())] + _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1, ["output_type"]) - args = node.args - if not isinstance(args[0], vy_ast.Str): - raise InvalidType("method id must be given as a literal string", args[0]) - if " " in args[0].value: - raise InvalidLiteral("Invalid function signature - no spaces allowed.") + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Str): + raise InvalidType("method id must be given as a literal string", node.args[0]) + if " " in value.value: + raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) - return_type = self.infer_kwarg_types(node) - value = method_id_int(args[0].value) + return_type = self.infer_kwarg_types(node)["output_type"].typedef + value = method_id_int(value.value) if return_type.compare_type(BYTES4_T): return vy_ast.Hex.from_node(node, value=hex(value)) @@ -735,21 +733,22 @@ def evaluate(self, node): def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) - type_ = self.infer_kwarg_types(node) + type_ = self.infer_kwarg_types(node)["output_type"].typedef return type_ + def infer_arg_types(self, node, expected_return_typ=None): + return [self._inputs[0][1]] + def infer_kwarg_types(self, node): if node.keywords: - return_type = type_from_annotation(node.keywords[0].value) - if return_type.compare_type(BYTES4_T): - return BYTES4_T - elif isinstance(return_type, BytesT) and return_type.length == 4: - return BytesT(4) - else: + output_type = type_from_annotation(node.keywords[0].value) + if output_type not in (BytesT(4), BYTES4_T): raise ArgumentException("output_type must be Bytes[4] or bytes4", node.keywords[0]) + else: + # default to `Bytes[4]` + output_type = BytesT(4) - # If `output_type` is not given, default to `Bytes[4]` - return BytesT(4) + return {"output_type": TYPE_T(output_type)} class ECRecover(BuiltinFunctionT): @@ -762,7 +761,7 @@ class ECRecover(BuiltinFunctionT): ] _return_type = AddressT() - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) v_t, r_t, s_t = [get_possible_types_from_node(arg).pop() for arg in node.args[1:]] return [BYTES32_T, v_t, r_t, s_t] @@ -859,7 +858,7 @@ def fetch_call_return(self, node): return_type = self.infer_kwarg_types(node)["output_type"].typedef return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type, UINT256_T] @@ -974,42 +973,37 @@ class AsWeiValue(BuiltinFunctionT): } def get_denomination(self, node): - if not isinstance(node.args[1], vy_ast.Str): + value = node.args[1].get_folded_value() + if not isinstance(value, vy_ast.Str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] ) try: - denom = next(v for k, v in self.wei_denoms.items() if node.args[1].value in k) + denom = next(v for k, v in self.wei_denoms.items() if value.value in k) except StopIteration: - raise ArgumentException( - f"Unknown denomination: {node.args[1].value}", node.args[1] - ) from None + raise ArgumentException(f"Unknown denomination: {value.value}", node.args[1]) from None return denom - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 2) denom = self.get_denomination(node) - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + value = node.args[0].get_folded_value() + if not isinstance(value, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - value = node.args[0].value + value = value.value if value < 0: raise InvalidLiteral("Negative wei value not allowed", node.args[0]) - if isinstance(value, int) and value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) - if isinstance(value, Decimal) and value > SizeLimits.MAX_AST_DECIMAL: - raise InvalidLiteral("Value out of range for decimal", node.args[0]) - return vy_ast.Int.from_node(node, value=int(value * denom)) def fetch_call_return(self, node): self.infer_arg_types(node) return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of abstract type value_type = get_possible_types_from_node(node.args[0]).pop() @@ -1074,8 +1068,14 @@ def fetch_call_return(self, node): kwargz = {i.arg: i.value for i in node.keywords} outsize = kwargz.get("max_outsize") + if outsize is not None: + outsize = outsize.get_folded_value() + revert_on_failure = kwargz.get("revert_on_failure") - revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True + if revert_on_failure is not None: + revert_on_failure = revert_on_failure.get_folded_value().value + else: + revert_on_failure = True if outsize is None or outsize.value == 0: if revert_on_failure: @@ -1093,7 +1093,7 @@ def fetch_call_return(self, node): return return_type return TupleT([BoolT(), return_type]) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `data` data_type = get_possible_types_from_node(node.args[1]).pop() @@ -1268,7 +1268,7 @@ class RawRevert(BuiltinFunctionT): def fetch_call_return(self, node): return None - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) data_type = get_possible_types_from_node(node.args[0]).pop() return [data_type] @@ -1288,7 +1288,7 @@ class RawLog(BuiltinFunctionT): def fetch_call_return(self, node): self.infer_arg_types(node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: @@ -1338,19 +1338,18 @@ class BitwiseAnd(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_and()` is deprecated! Please use the & operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value & node.args[1].value + value = values[0].value & values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1364,19 +1363,18 @@ class BitwiseOr(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_or()` is deprecated! Please use the | operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value | node.args[1].value + value = values[0].value | values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1390,19 +1388,18 @@ class BitwiseXor(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_xor()` is deprecated! Please use the ^ operator instead.") self.__class__._warned = True validate_call_args(node, 2) - for arg in node.args: - if not isinstance(arg, vy_ast.Int): + values = [i.get_folded_value() for i in node.args] + for val in values: + if not isinstance(val, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = node.args[0].value ^ node.args[1].value + value = values[0].value ^ values[1].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1416,18 +1413,17 @@ class BitwiseNot(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`bitwise_not()` is deprecated! Please use the ~ operator instead.") self.__class__._warned = True validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) + value = value.value value = (2**256 - 1) - value return vy_ast.Int.from_node(node, value=value) @@ -1443,17 +1439,16 @@ class Shift(BuiltinFunctionT): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def _try_fold(self, node): if not self.__class__._warned: vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.") self.__class__._warned = True validate_call_args(node, 2) - if [i for i in node.args if not isinstance(i, vy_ast.Int)]: + args = [i.get_folded_value() for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in args): raise UnfoldableNode - value, shift = [i.value for i in node.args] - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) + value, shift = [i.value for i in args] if shift < -256 or shift > 256: # this validation is performed to prevent the compiler from hanging # rather than for correctness because the post-folded constant would @@ -1470,7 +1465,7 @@ def fetch_call_return(self, node): # return type is the type of the first argument return self.infer_arg_types(node)[0] - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of SignedIntegerAbstractType arg_ty = get_possible_types_from_node(node.args[0])[0] @@ -1495,17 +1490,16 @@ class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 3) - if isinstance(node.args[2], vy_ast.Int) and node.args[2].value == 0: + args = [i.get_folded_value() for i in node.args] + if isinstance(args[2], vy_ast.Int) and args[2].value == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) - for arg in node.args: + for arg in args: if not isinstance(arg, vy_ast.Int): raise UnfoldableNode - if arg.value < 0 or arg.value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", arg) - value = self._eval_fn(node.args[0].value, node.args[1].value) % node.args[2].value + value = self._eval_fn(args[0].value, args[1].value) % args[2].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1537,15 +1531,13 @@ class PowMod256(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 2) - if next((i for i in node.args if not isinstance(i, vy_ast.Int)), None): - raise UnfoldableNode - - left, right = node.args - if left.value < 0 or right.value < 0: + values = [i.get_folded_value() for i in node.args] + if any(not isinstance(i, vy_ast.Int) for i in values): raise UnfoldableNode + left, right = values value = pow(left.value, right.value, 2**256) return vy_ast.Int.from_node(node, value=value) @@ -1560,18 +1552,13 @@ class Abs(BuiltinFunctionT): _inputs = [("value", INT256_T)] _return_type = INT256_T - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: - raise OverflowException("Literal is outside of allowable range for int256") - value = abs(value) - if not SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256: - raise OverflowException("Absolute literal value is outside allowable range for int256") - + value = abs(value.value) return vy_ast.Int.from_node(node, value=value) def build_IR(self, expr, context): @@ -1946,7 +1933,7 @@ def fetch_call_return(self, node): return_type = self.infer_arg_types(node).pop() return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) types_list = get_common_types(*node.args, filter_fn=lambda x: isinstance(x, IntegerT)) @@ -2004,34 +1991,26 @@ class UnsafeDiv(_UnsafeMath): class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 2) - if not isinstance(node.args[0], type(node.args[1])): + + left = node.args[0].get_folded_value() + right = node.args[1].get_folded_value() + if not isinstance(left, type(right)): raise UnfoldableNode - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + if not isinstance(left, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - left, right = (i.value for i in node.args) - if isinstance(left, Decimal) and ( - min(left, right) < SizeLimits.MIN_AST_DECIMAL - or max(left, right) > SizeLimits.MAX_AST_DECIMAL - ): - raise InvalidType("Decimal value is outside of allowable range", node) - types_list = get_common_types( - *node.args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) + *(left, right), filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) ) if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - value = self._eval_fn(left, right) - return type(node.args[0]).from_node(node, value=value) + value = self._eval_fn(left.value, right.value) + return type(left).from_node(node, value=value) def fetch_call_return(self, node): - return_type = self.infer_arg_types(node).pop() - return return_type - - def infer_arg_types(self, node): self._validate_arg_types(node) types_list = get_common_types( @@ -2040,8 +2019,13 @@ def infer_arg_types(self, node): if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - type_ = types_list.pop() - return [type_, type_] + return types_list + + def infer_arg_types(self, node, expected_return_typ=None): + types_list = self.fetch_call_return(node) + # type mismatch should have been caught in `fetch_call_return` + assert expected_return_typ in types_list + return [expected_return_typ, expected_return_typ] @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -2085,18 +2069,19 @@ def fetch_call_return(self, node): len_needed = math.ceil(bits * math.log(2) / math.log(10)) return StringT(len_needed) - def evaluate(self, node): + def _try_fold(self, node): validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Int): + value = node.args[0].get_folded_value() + if not isinstance(value, vy_ast.Int): raise UnfoldableNode - value = node.args[0].value + value = value.value if value < 0: raise InvalidType("Only unsigned ints allowed", node) value = str(value) return vy_ast.Str.from_node(node, value=value) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type] @@ -2493,7 +2478,7 @@ def fetch_call_return(self, node): _, output_type = self.infer_arg_types(node) return output_type.typedef - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) validate_call_args(node, 2, ["unwrap_tuple"]) @@ -2572,7 +2557,7 @@ def build_IR(self, expr, args, kwargs, context): class _MinMaxValue(TypenameFoldedFunctionT): - def evaluate(self, node): + def _try_fold(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2590,6 +2575,10 @@ def evaluate(self, node): ret._metadata["type"] = input_type return ret + def infer_arg_types(self, node, expected_return_typ=None): + input_typedef = TYPE_T(type_from_annotation(node.args[0])) + return [input_typedef] + class MinValue(_MinMaxValue): _id = "min_value" @@ -2608,7 +2597,7 @@ def _eval(self, type_): class Epsilon(TypenameFoldedFunctionT): _id = "epsilon" - def evaluate(self, node): + def _try_fold(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 3063a289ab..d6ba9e180a 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -33,7 +33,8 @@ devdoc - Natspec developer documentation combined_json - All of the above format options combined as single JSON output layout - Storage layout of a Vyper contract -ast - AST in JSON format +ast - AST (not yet annotated) in JSON format +annotated_ast - Annotated AST in JSON format interface - Vyper interface of a contract external_interface - External interface of a contract, used for outside contract calls opcodes - List of opcodes as a string @@ -255,7 +256,13 @@ def compile_files( output_formats = combined_json_outputs show_version = True - translate_map = {"abi_python": "abi", "json": "abi", "ast": "ast_dict", "ir_json": "ir_dict"} + translate_map = { + "abi_python": "abi", + "json": "abi", + "ast": "ast_dict", + "annotated_ast": "annotated_ast_dict", + "ir_json": "ir_dict", + } final_formats = [translate_map.get(i, i) for i in output_formats] if storage_layout_paths: diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 4c7c3afaed..577660b883 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -69,9 +69,6 @@ class Expr: # TODO: Once other refactors are made reevaluate all inline imports def __init__(self, node, context): - self.expr = node - self.context = context - if isinstance(node, IRnode): # this is a kludge for parse_AugAssign to pass in IRnodes # directly. @@ -79,6 +76,13 @@ def __init__(self, node, context): self.ir_node = node return + assert isinstance(node, vy_ast.VyperNode) + if node.has_folded_value: + node = node.get_folded_value() + + self.expr = node + self.context = context + fn_name = f"parse_{type(node).__name__}" with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): fn = getattr(self, fn_name) @@ -184,6 +188,13 @@ def parse_Name(self): # TODO: use self.expr._expr_info elif self.expr.id in self.context.globals: varinfo = self.context.globals[self.expr.id] + + if varinfo.is_constant: + # non-struct constants should have already gotten propagated + # during constant folding + assert isinstance(varinfo.typ, StructT) + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + assert varinfo.is_immutable, "not an immutable!" ofst = varinfo.position.offset diff --git a/vyper/compiler/README.md b/vyper/compiler/README.md index eb70750a2b..abb8c6ee91 100644 --- a/vyper/compiler/README.md +++ b/vyper/compiler/README.md @@ -25,8 +25,6 @@ The compilation process includes the following broad phases: 1. In [`vyper.ast`](../ast), the source code is parsed and converted to an abstract syntax tree. -1. In [`vyper.ast.folding`](../ast/folding.py), literal Vyper AST nodes are -evaluated and replaced with the resulting values. 1. The [`GlobalContext`](../codegen/global_context.py) object is generated from the Vyper AST, analyzing and organizing the nodes prior to IR generation. 1. In [`vyper.codegen.module`](../codegen/module.py), the contextualized nodes are diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index c87814ba15..0f7d7a8014 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -14,6 +14,8 @@ OUTPUT_FORMATS = { # requires vyper_module "ast_dict": output.build_ast_dict, + # requires annotated_vyper_module + "annotated_ast_dict": output.build_annotated_ast_dict, "layout": output.build_layout_output, # requires global_ctx "devdoc": output.build_devdoc, diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index dc2a43720e..8ccf6abee1 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -23,18 +23,26 @@ def build_ast_dict(compiler_data: CompilerData) -> dict: return ast_dict +def build_annotated_ast_dict(compiler_data: CompilerData) -> dict: + annotated_ast_dict = { + "contract_name": str(compiler_data.contract_path), + "ast": ast_to_dict(compiler_data.annotated_vyper_module), + } + return annotated_ast_dict + + def build_devdoc(compiler_data: CompilerData) -> dict: - userdoc, devdoc = parse_natspec(compiler_data.vyper_module_folded) + userdoc, devdoc = parse_natspec(compiler_data.annotated_vyper_module) return devdoc def build_userdoc(compiler_data: CompilerData) -> dict: - userdoc, devdoc = parse_natspec(compiler_data.vyper_module_folded) + userdoc, devdoc = parse_natspec(compiler_data.annotated_vyper_module) return userdoc def build_external_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"].interface + interface = compiler_data.annotated_vyper_module._metadata["type"].interface stem = PurePath(compiler_data.contract_path).stem # capitalize words separated by '_' # ex: test_interface.vy -> TestInterface @@ -53,7 +61,7 @@ def build_external_interface_output(compiler_data: CompilerData) -> str: def build_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"].interface + interface = compiler_data.annotated_vyper_module._metadata["type"].interface out = "" if interface.events: @@ -158,7 +166,7 @@ def _to_dict(func_t): def build_method_identifiers_output(compiler_data: CompilerData) -> dict: - module_t = compiler_data.vyper_module_folded._metadata["type"] + module_t = compiler_data.annotated_vyper_module._metadata["type"] functions = module_t.function_defs return { @@ -167,7 +175,7 @@ def build_method_identifiers_output(compiler_data: CompilerData) -> dict: def build_abi_output(compiler_data: CompilerData) -> list: - module_t = compiler_data.vyper_module_folded._metadata["type"] + module_t = compiler_data.annotated_vyper_module._metadata["type"] _ = compiler_data.ir_runtime # ensure _ir_info is generated abi = module_t.interface.to_toplevel_abi_dict() diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index b9b2df6ae8..8cbcfb1da9 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -2,7 +2,7 @@ import warnings from functools import cached_property from pathlib import Path, PurePath -from typing import Optional, Tuple +from typing import Optional from vyper import ast as vy_ast from vyper.codegen import module @@ -53,8 +53,8 @@ class CompilerData: ---------- vyper_module : vy_ast.Module Top-level Vyper AST node - vyper_module_folded : vy_ast.Module - Folded Vyper AST + annotated_vyper_module: vy_ast.Module + Annotated+analysed Vyper AST global_ctx : ModuleT Sorted, contextualized representation of the Vyper AST ir_nodes : IRnode @@ -152,31 +152,24 @@ def vyper_module(self): return self._generate_ast @cached_property - def vyper_module_unfolded(self) -> vy_ast.Module: - # This phase is intended to generate an AST for tooling use, and is not - # used in the compilation process. - - return generate_unfolded_ast(self.vyper_module, self.input_bundle) - - @cached_property - def _folded_module(self): - return generate_folded_ast( + def _annotated_module(self): + return generate_annotated_ast( self.vyper_module, self.input_bundle, self.storage_layout_override ) @property - def vyper_module_folded(self) -> vy_ast.Module: - module, storage_layout = self._folded_module + def annotated_vyper_module(self) -> vy_ast.Module: + module, storage_layout = self._annotated_module return module @property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._folded_module + module, storage_layout = self._annotated_module return storage_layout @property def global_ctx(self) -> ModuleT: - return self.vyper_module_folded._metadata["type"] + return self.annotated_vyper_module._metadata["type"] @cached_property def _ir_output(self): @@ -205,7 +198,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: # ensure codegen is run: _ = self._ir_output - fs = self.vyper_module_folded.get_children(vy_ast.FunctionDef) + fs = self.annotated_vyper_module.get_children(vy_ast.FunctionDef) return {f.name: f._metadata["func_type"] for f in fs} @cached_property @@ -247,25 +240,13 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -# destructive -- mutates module in place! -def generate_unfolded_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: - vy_ast.validation.validate_literal_nodes(vyper_module) - vy_ast.folding.replace_builtin_functions(vyper_module) - - with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: validate_semantics does type inference on the AST - validate_semantics(vyper_module, input_bundle) - - return vyper_module - - -def generate_folded_ast( +def generate_annotated_ast( vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, -) -> Tuple[vy_ast.Module, StorageLayout]: +) -> tuple[vy_ast.Module, StorageLayout]: """ - Perform constant folding operations on the Vyper AST. + Validates and annotates the Vyper AST. Arguments --------- @@ -275,22 +256,18 @@ def generate_folded_ast( Returns ------- vy_ast.Module - Folded Vyper AST + Annotated Vyper AST StorageLayout Layout of variables in storage """ - - vy_ast.validation.validate_literal_nodes(vyper_module) - - vyper_module_folded = copy.deepcopy(vyper_module) - vy_ast.folding.fold(vyper_module_folded) - + vyper_module = copy.deepcopy(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - validate_semantics(vyper_module_folded, input_bundle) + # note: validate_semantics does type inference on the AST + validate_semantics(vyper_module, input_bundle) - symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) + symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - return vyper_module_folded, symbol_tables + return vyper_module, symbol_tables def generate_ir_nodes( diff --git a/vyper/semantics/README.md b/vyper/semantics/README.md index 1d81a0979b..36519bba29 100644 --- a/vyper/semantics/README.md +++ b/vyper/semantics/README.md @@ -25,6 +25,7 @@ Vyper abstract syntax tree (AST). * [`data_positions`](analysis/data_positions.py): Functions for tracking storage variables and allocating storage slots * [`levenhtein_utils.py`](analysis/levenshtein_utils.py): Helper for better error messages * [`local.py`](analysis/local.py): Validates the local namespace of each function within a contract + * [`pre_typecheck.py`](analysis/pre_typecheck.py): Evaluate foldable nodes and populate their metadata with the replacement nodes. * [`module.py`](analysis/module.py): Validates the module namespace of a contract. * [`utils.py`](analysis/utils.py): Functions for comparing and validating types * [`data_locations.py`](data_locations.py): `DataLocation` object for type location information @@ -35,13 +36,23 @@ Vyper abstract syntax tree (AST). The [`analysis`](analysis) subpackage contains the top-level `validate_semantics` function. This function is used to verify and type-check a contract. The process -consists of three steps: +consists of four steps: -1. Preparing the builtin namespace -2. Validating the module-level scope -3. Annotating and validating local scopes +1. Populating the metadata of foldable nodes with their replacement nodes +2. Preparing the builtin namespace +3. Validating the module-level scope +4. Annotating and validating local scopes -### 1. Preparing the builtin namespace +### 1. Populating the metadata of foldable nodes with their replacement nodes + +[`analysis/pre_typecheck.py`](analysis/pre_typecheck.py) populates the metadata of foldable nodes with their replacement nodes. + +This process includes: +1. Foldable node classes and builtin functions are evaluated via their `fold` method, which attempts to create a new `Constant` from the content of the given node. +2. Replacement nodes are generated using the `from_node` class method within the new +node class. + +### 2. Preparing the builtin namespace The [`Namespace`](namespace.py) object represents the namespace for a contract. Builtins are added upon initialization of the object. This includes: @@ -51,9 +62,9 @@ Builtins are added upon initialization of the object. This includes: * Adding builtin functions from the [`functions`](../builtins/functions.py) package * Adding / resetting `self` and `log` -### 2. Validating the Module Scope +### 3. Validating the Module Scope -[`validation/module.py`](validation/module.py) validates the module-level scope +[`analysis/module.py`](analysis/module.py) validates the module-level scope of a contract. This includes: * Generating user-defined types (e.g. structs and interfaces) @@ -61,9 +72,9 @@ of a contract. This includes: and functions * Validating import statements and function signatures -### 3. Annotating and validating the Local Scopes +### 4. Annotating and validating the Local Scopes -[`validation/local.py`](validation/local.py) validates the local scope within each +[`analysis/local.py`](analysis/local.py) validates the local scope within each function in a contract. `FunctionNodeVisitor` is used to iterate over the statement nodes in each function body, annotate them and apply appropriate checks. diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 4d1b1cdbab..bb6d9ad9f7 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -97,6 +97,27 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # specifying a state mutability modifier at all. Do the same here. +# classify the constancy of an expression +# CMC 2023-12-31 note that we now have three ways of classifying mutability in +# the codebase: StateMutability (for functions), Modifiability (for expressions +# and variables) and Constancy (in codegen). context.Constancy can/should +# probably be refactored away though as those kinds of checks should be done +# during analysis. +class Modifiability(enum.IntEnum): + # is writeable/can result in arbitrary state or memory changes + MODIFIABLE = enum.auto() + + # could potentially add more fine-grained here as needed, like + # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. + + # things that are constant within the current message call, including + # block.*, msg.*, tx.* and immutables + RUNTIME_CONSTANT = enum.auto() + + # compile-time / always constant + CONSTANT = enum.auto() + + class DataPosition: _location: DataLocation @@ -182,21 +203,18 @@ class ImportInfo(AnalysisResult): class VarInfo: """ VarInfo are objects that represent the type of a variable, - plus associated metadata like location and constancy attributes + plus associated metadata like location and modifiability attributes Object Attributes ----------------- - is_constant : bool, optional - If `True`, this is a variable defined with the `constant()` modifier + location: DataLocation of this variable + modifiability: Modifiability of this variable """ typ: VyperType location: DataLocation = DataLocation.UNSET - is_constant: bool = False + modifiability: Modifiability = Modifiability.MODIFIABLE is_public: bool = False - is_immutable: bool = False - is_transient: bool = False - is_local_var: bool = False decl_node: Optional[vy_ast.VyperNode] = None def __hash__(self): @@ -211,10 +229,28 @@ def set_position(self, position: DataPosition) -> None: if self.location != position._location: if self.location == DataLocation.UNSET: self.location = position._location + elif self.is_transient and position._location == DataLocation.STORAGE: + # CMC 2023-12-31 - use same allocator for storage and transient + # for now, this should be refactored soon. + pass else: raise CompilerPanic("Incompatible locations") self.position = position + @property + def is_transient(self): + return self.location == DataLocation.TRANSIENT + + @property + def is_immutable(self): + return self.location == DataLocation.CODE + + @property + def is_constant(self): + res = self.location == DataLocation.UNSET + assert res == (self.modifiability == Modifiability.CONSTANT) + return res + @dataclass class ExprInfo: @@ -225,11 +261,10 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None location: DataLocation = DataLocation.UNSET - is_constant: bool = False - is_immutable: bool = False + modifiability: Modifiability = Modifiability.MODIFIABLE def __post_init__(self): - should_match = ("typ", "location", "is_constant", "is_immutable") + should_match = ("typ", "location", "modifiability") if self.var_info is not None: for attr in should_match: if getattr(self.var_info, attr) != getattr(self, attr): @@ -241,8 +276,7 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": var_info.typ, var_info=var_info, location=var_info.location, - is_constant=var_info.is_constant, - is_immutable=var_info.is_immutable, + modifiability=var_info.modifiability, ) @classmethod @@ -253,7 +287,7 @@ def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ - to_copy = ("location", "is_constant", "is_immutable") + to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} return self.__class__(typ=typ, **fields) @@ -277,17 +311,24 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if self.location == DataLocation.CALLDATA: raise ImmutableViolation("Cannot write to calldata", node) - if self.is_constant: + + if self.modifiability == Modifiability.RUNTIME_CONSTANT: + if self.location == DataLocation.CODE: + if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": + raise ImmutableViolation("Immutable value cannot be written to", node) + + # special handling for immutable variables in the ctor + # TODO: we probably want to remove this restriction. + if self.var_info._modification_count: # type: ignore + raise ImmutableViolation( + "Immutable value cannot be modified after assignment", node + ) + self.var_info._modification_count += 1 # type: ignore + else: + raise ImmutableViolation("Environment variable cannot be written to", node) + + if self.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to", node) - if self.is_immutable: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": - raise ImmutableViolation("Immutable value cannot be written to", node) - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore - raise ImmutableViolation( - "Immutable value cannot be modified after assignment", node - ) - self.var_info._modification_count += 1 # type: ignore if isinstance(node, vy_ast.AugAssign): self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a3ebf85fa2..91fb2c21f0 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -18,7 +18,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -186,16 +186,18 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = _ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self.func) def analyze(self): # allow internal function params to be mutable - location, is_immutable = ( - (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) - ) + if self.func.is_internal: + location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) + else: + location, modifiability = (DataLocation.CALLDATA, Modifiability.RUNTIME_CONSTANT) + for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, is_immutable=is_immutable + arg.typ, location=location, modifiability=modifiability ) for node in self.fn_node.body: @@ -358,7 +360,8 @@ def visit_For(self, node): else: # iteration over a variable or literal list - if isinstance(node.iter, vy_ast.List) and len(node.iter.elements) == 0: + iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter + if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) type_list = [ @@ -421,32 +424,35 @@ def visit_For(self, node): # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) + self.namespace[iter_name] = VarInfo( + possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) try: with NodeMetadata.enter_typechecker_speculation(): for stmt in node.body: self.visit(stmt) + + self.expr_visitor.visit(node.target, possible_target_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, possible_target_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, possible_target_type) + except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - # success -- do not enter error handling section return @@ -523,10 +529,10 @@ def visit_Return(self, node): self.expr_visitor.visit(node.value, self.func.return_type) -class _ExprVisitor(VyperNodeVisitorBase): +class ExprVisitor(VyperNodeVisitorBase): scope_name = "function" - def __init__(self, fn_node: ContractFunctionT): + def __init__(self, fn_node: Optional[ContractFunctionT] = None): self.func = fn_node def visit(self, node, typ): @@ -543,6 +549,12 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + # validate and annotate folded value + if node.has_folded_value: + folded_node = node.get_folded_value() + validate_expected_type(folded_node, typ) + folded_node._metadata["type"] = typ + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) @@ -551,10 +563,10 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: # if self.func.mutability < expr_info.mutability: # raise ... - if self.func.mutability != StateMutability.PAYABLE: + if self.func and self.func.mutability != StateMutability.PAYABLE: _validate_msg_value_access(node) - if self.func.mutability == StateMutability.PURE: + if self.func and self.func.mutability == StateMutability.PURE: _validate_pure_access(node, typ) value_type = get_exact_type_from_node(node.value) @@ -589,7 +601,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if isinstance(call_type, ContractFunctionT): # function calls - if call_type.is_internal: + if self.func and call_type.is_internal: self.func.called_functions.add(call_type) for arg, typ in zip(node.args, call_type.argument_types): self.visit(arg, typ) @@ -615,7 +627,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node) + arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) # `infer_arg_types` already calls `validate_expected_type` for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) @@ -680,7 +692,7 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func.mutability == StateMutability.PURE: + if self.func and self.func.mutability == StateMutability.PURE: _validate_self_reference(node) if not isinstance(typ, TYPE_T): @@ -691,7 +703,7 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - if isinstance(node.value, vy_ast.List): + if isinstance(node.value, (vy_ast.List, vy_ast.Subscript)): possible_base_types = get_possible_types_from_node(node.value) for possible_type in possible_base_types: @@ -747,6 +759,7 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args + start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] all_args = (start, end, *kwargs.values()) for arg1 in all_args: @@ -758,6 +771,8 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: if "bound" in kwargs: bound = kwargs["bound"] + if bound.has_folded_value: + bound = bound.get_folded_value() if not isinstance(bound, vy_ast.Num): raise StateAccessViolation("Bound must be a literal", bound) if bound.value <= 0: diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index fb536b7ab7..8e435f870f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -4,6 +4,7 @@ import vyper.builtins.interfaces from vyper import ast as vy_ast +from vyper.ast.validation import validate_literal_nodes from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( @@ -20,12 +21,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.import_graph import ImportGraph -from vyper.semantics.analysis.local import validate_functions +from vyper.semantics.analysis.local import ExprVisitor, validate_functions +from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import ( - check_constant, + check_modifiability, get_exact_type_from_node, validate_expected_type, ) @@ -51,6 +53,10 @@ def validate_semantics_r( Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ + validate_literal_nodes(module_ast) + + pre_typecheck(module_ast) + # validate semantics and annotate AST with type/semantics information namespace = get_namespace() @@ -254,12 +260,19 @@ def visit_VariableDecl(self, node): if node.is_immutable else DataLocation.UNSET if node.is_constant - # XXX: needed if we want separate transient allocator - # else DataLocation.TRANSIENT - # if node.is_transient + else DataLocation.TRANSIENT + if node.is_transient else DataLocation.STORAGE ) + modifiability = ( + Modifiability.RUNTIME_CONSTANT + if node.is_immutable + else Modifiability.CONSTANT + if node.is_constant + else Modifiability.MODIFIABLE + ) + type_ = type_from_annotation(node.annotation, data_loc) if node.is_transient and not version_check(begin="cancun"): @@ -269,10 +282,8 @@ def visit_VariableDecl(self, node): type_, decl_node=node, location=data_loc, - is_constant=node.is_constant, + modifiability=modifiability, is_public=node.is_public, - is_immutable=node.is_immutable, - is_transient=node.is_transient, ) node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ @@ -302,9 +313,11 @@ def _validate_self_namespace(): self.namespace[name] = var_info if node.is_constant: - if not node.value: - raise VariableDeclarationException("Constant must be declared with a value", node) - if not check_constant(node.value): + assert node.value is not None # checked in VariableDecl.validate() + + ExprVisitor().visit(node.value, type_) + + if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) validate_expected_type(node.value, type_) @@ -312,11 +325,7 @@ def _validate_self_namespace(): return _finalize() - if node.value: - var_type = "Immutable" if node.is_immutable else "Storage" - raise VariableDeclarationException( - f"{var_type} variables cannot have an initial value", node.value - ) + assert node.value is None # checked in VariableDecl.validate() if node.is_immutable: _validate_self_namespace() @@ -482,9 +491,6 @@ def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: module_path=str(file.path), resolved_path=str(file.resolved_path), ) - vy_ast.validation.validate_literal_nodes(ret) - vy_ast.folding.fold(ret) - return ret diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py new file mode 100644 index 0000000000..a1302ce9c9 --- /dev/null +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -0,0 +1,94 @@ +from vyper import ast as vy_ast +from vyper.exceptions import UnfoldableNode + + +# try to fold a node, swallowing exceptions. this function is very similar to +# `VyperNode.get_folded_value()` but additionally checks in the constants +# table if the node is a `Name` node. +# +# CMC 2023-12-30 a potential refactor would be to move this function into +# `Name._try_fold` (which would require modifying the signature of _try_fold to +# take an optional constants table as parameter). this would remove the +# need to use this function in conjunction with `get_descendants` since +# `VyperNode._try_fold()` already recurses. it would also remove the need +# for `VyperNode._set_folded_value()`. +def _fold_with_constants(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): + if node.has_folded_value: + return + + if isinstance(node, vy_ast.Name): + # check if it's in constants table + var_name = node.id + + if var_name not in constants: + return + + res = constants[var_name] + node._set_folded_value(res) + return + + try: + # call get_folded_value for its side effects + node.get_folded_value() + except UnfoldableNode: + pass + + +def _get_constants(node: vy_ast.Module) -> dict: + constants: dict[str, vy_ast.VyperNode] = {} + const_var_decls = node.get_children(vy_ast.VariableDecl, {"is_constant": True}) + + while True: + n_processed = 0 + + for c in const_var_decls.copy(): + assert c.value is not None # guaranteed by VariableDecl.validate() + + for n in c.get_descendants(reverse=True): + _fold_with_constants(n, constants) + + try: + val = c.value.get_folded_value() + except UnfoldableNode: + # not foldable, maybe it depends on other constants + # so try again later + continue + + # note that if a constant is redefined, its value will be + # overwritten, but it is okay because the error is handled + # downstream + name = c.target.id + constants[name] = val + + n_processed += 1 + const_var_decls.remove(c) + + if n_processed == 0: + # this condition means that there are some constant vardecls + # whose values are not foldable. this can happen for struct + # and interface constants for instance. these are valid constant + # declarations, but we just can't fold them at this stage. + break + + return constants + + +# perform constant folding on a module AST +def pre_typecheck(node: vy_ast.Module) -> None: + """ + Perform pre-typechecking steps on a Module AST node. + At this point, this is limited to performing constant folding. + """ + constants = _get_constants(node) + + # note: use reverse to get descendants in leaf-first order + for n in node.get_descendants(reverse=True): + # try folding every single node. note this should be done before + # type checking because the typechecker requires literals or + # foldable nodes in type signatures and some other places (e.g. + # certain builtin kwargs). + # + # note we could limit to only folding nodes which are required + # during type checking, but it's easier to just fold everything + # and be done with it! + _fold_with_constants(n, constants) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 20ebb0f093..ba1b02b8d6 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -98,12 +98,9 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # kludge! for validate_modification in local analysis of Assign types = [self.get_expr_info(n) for n in node.elements] location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - is_constant = any((getattr(i, "is_constant", False) for i in types)) - is_immutable = any((getattr(i, "is_immutable", False) for i in types)) + modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - return ExprInfo( - t, location=location, is_constant=is_constant, is_immutable=is_immutable - ) + return ExprInfo(t, location=location, modifiability=modifiability) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): @@ -137,8 +134,7 @@ def get_exact_type_from_node(self, node, include_type_exprs=False): def get_possible_types_from_node(self, node, include_type_exprs=False): """ Find all possible types for a given node. - If the node's metadata contains type information propagated from constant folding, - then that type is returned. + If the node's metadata contains type information, then that type is returned. Arguments --------- @@ -203,10 +199,12 @@ def _raise_invalid_reference(name, node): if isinstance(s, (VyperType, TYPE_T)): # ex. foo.bar(). bar() is a ContractFunctionT return [s] + + # general case. s is a VarInfo, e.g. self.foo if is_self_reference and (s.is_constant or s.is_immutable): _raise_invalid_reference(name, node) - # general case. s is a VarInfo, e.g. self.foo return [s.typ] + except UnknownAttribute as e: if not is_self_reference: raise e from None @@ -282,6 +280,8 @@ def types_from_Call(self, node): var = self.get_exact_type_from_node(node.func, include_type_exprs=True) return_value = var.fetch_call_return(node) if return_value: + if isinstance(return_value, list): + return return_value return [return_value] raise InvalidType(f"{var} did not return a value", node) @@ -378,7 +378,7 @@ def types_from_Name(self, node): def types_from_Subscript(self, node): # index access, e.g. `foo[1]` - if isinstance(node.value, vy_ast.List): + if isinstance(node.value, (vy_ast.List, vy_ast.Subscript)): types_list = self.get_possible_types_from_node(node.value) ret = [] for t in types_list: @@ -625,54 +625,33 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_kwargable(node: vy_ast.VyperNode) -> bool: +def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: """ - Check if the given node can be used as a default arg + Check if the given node is not more modifiable than the given modifiability. """ - if _check_literal(node): + if node.is_literal_value or node.has_folded_value: return True - if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_kwargable(item) for item in node.elements) - if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_kwargable(v) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) - if getattr(call_type, "_kwargable", False): - return True + if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)): + return all(check_modifiability(i, modifiability) for i in (node.left, node.right)) - value_type = get_expr_info(node) - # is_constant here actually means not_assignable, and is to be renamed - return value_type.is_constant + if isinstance(node, vy_ast.BoolOp): + return all(check_modifiability(i, modifiability) for i in node.values) + if isinstance(node, vy_ast.UnaryOp): + return check_modifiability(node.operand, modifiability) -def _check_literal(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal value. - """ - if isinstance(node, vy_ast.Constant): - return True - elif isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(_check_literal(item) for item in node.elements) - return False - - -def check_constant(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal or constant value. - """ - if _check_literal(node): - return True if isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(check_constant(item) for item in node.elements) + return all(check_modifiability(item, modifiability) for item in node.elements) + if isinstance(node, vy_ast.Call): args = node.args if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_constant(v) for v in args[0].values) + return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) - if getattr(call_type, "_kwargable", False): - return True + call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) + return call_type_modifiability >= modifiability - return False + value_type = get_expr_info(node) + return value_type.modifiability >= modifiability diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 2f259b1766..cecea35a60 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -7,5 +7,4 @@ class DataLocation(enum.Enum): STORAGE = 2 CALLDATA = 3 CODE = 4 - # XXX: needed for separate transient storage allocator - # TRANSIENT = 5 + TRANSIENT = 5 diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index ad68f1103e..38bac0a63d 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,6 +1,6 @@ from typing import Dict -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.types import AddressT, BytesT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict: """ result = {} for k, v in CONSTANT_ENVIRONMENT_VARS.items(): - result[k] = VarInfo(v, is_constant=True) + result[k] = VarInfo(v, modifiability=Modifiability.RUNTIME_CONSTANT) return result diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 6ecfe78be3..429ba807e1 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -340,7 +340,7 @@ def fetch_call_return(self, node): return self.typedef._ctor_call_return(node) raise StructureException("Value is not callable", node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if hasattr(self.typedef, "_ctor_arg_types"): return self.typedef._ctor_arg_types(node) raise StructureException("Value is not callable", node) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 34206546fd..7c77560e49 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -16,9 +16,14 @@ StateAccessViolation, StructureException, ) -from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot +from vyper.semantics.analysis.base import ( + FunctionVisibility, + Modifiability, + StateMutability, + StorageSlot, +) from vyper.semantics.analysis.utils import ( - check_kwargable, + check_modifiability, get_exact_type_from_node, validate_expected_type, ) @@ -128,7 +133,7 @@ def __repr__(self): def __str__(self): ret_sig = "" if not self.return_type else f" -> {self.return_type}" args_sig = ",".join([str(t) for t in self.argument_types]) - return f"def {self.name} {args_sig}{ret_sig}:" + return f"def {self.name}({args_sig}){ret_sig}:" # override parent implementation. function type equality does not # make too much sense. @@ -696,7 +701,7 @@ def _parse_args( positional_args.append(PositionalArg(argname, type_, ast_source=arg)) else: value = funcdef.args.defaults[i - n_positional_args] - if not check_kwargable(value): + if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) validate_expected_type(value, type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 0c8e9fddd8..55ffc23b2f 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -288,6 +288,8 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": raise StructureException(err_msg, node.slice) length_node = node.slice.value.elements[1] + if length_node.has_folded_value: + length_node = length_node.get_folded_value() if not isinstance(length_node, vy_ast.Int): raise StructureException(err_msg, length_node) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 8d68a9fa01..eb96375404 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -179,7 +179,11 @@ def get_index_value(node: vy_ast.Index) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - if not isinstance(node.get("value"), vy_ast.Int): + value = node.get("value") + if value.has_folded_value: + value = value.get_folded_value() + + if not isinstance(value, vy_ast.Int): if hasattr(node, "value"): # even though the subscript is an invalid type, first check if it's a valid _something_ # this gives a more accurate error in case of e.g. a typo in a constant variable name @@ -191,7 +195,7 @@ def get_index_value(node: vy_ast.Index) -> int: raise InvalidType("Subscript must be a literal integer", node) - if node.value.value <= 0: + if value.value <= 0: raise ArrayIndexException("Subscript must be greater than 0", node) - return node.value.value + return value.value From 0c82d0bb0da9d696a4baeae18e021bae6b8287eb Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Tue, 2 Jan 2024 18:52:48 +0200 Subject: [PATCH 152/201] feat: remove `deploy` instruction from venom (#3703) this commit removes the `deploy` instruction from venom and replaces it with the possibility to support multiple entry points to a program. this lets us remove special opcode handling during CFG normalization and rely on cfg_ins/cfg_outs directly. --------- Co-authored-by: Charles Cooper --- .../compiler/venom/test_duplicate_operands.py | 6 +- .../compiler/venom/test_multi_entry_block.py | 6 +- vyper/compiler/phases.py | 20 ++-- vyper/venom/__init__.py | 33 +++++-- vyper/venom/analysis.py | 21 ----- vyper/venom/basicblock.py | 9 +- vyper/venom/function.py | 33 +++++-- vyper/venom/ir_node_to_venom.py | 53 ++++++----- vyper/venom/passes/normalization.py | 8 +- vyper/venom/venom_to_assembly.py | 94 +++++++++---------- 10 files changed, 144 insertions(+), 139 deletions(-) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index a51992df67..b96c7f3351 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -18,10 +18,10 @@ def test_duplicate_operands(): ctx = IRFunction() bb = ctx.get_basic_block() op = bb.append_instruction("store", 10) - sum = bb.append_instruction("add", op, op) - bb.append_instruction("mul", sum, op) + sum_ = bb.append_instruction("add", op, op) + bb.append_instruction("mul", sum_, op) bb.append_instruction("stop") - asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) + asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.CODESIZE) assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 104697432b..6d8b074994 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -41,7 +41,7 @@ def test_multi_entry_block_1(): finish_bb = ctx.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in.keys()) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" @@ -93,7 +93,7 @@ def test_multi_entry_block_2(): finish_bb = ctx.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in.keys()) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" @@ -134,5 +134,5 @@ def test_multi_entry_block_with_dynamic_jump(): finish_bb = ctx.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in.keys()) assert cfg_in[0].label.value == "target", "Should contain target" - assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[1].label.value == "finish_split___global", "Should contain finish_split___global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 8cbcfb1da9..850adcfea3 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -174,13 +174,9 @@ def global_ctx(self) -> ModuleT: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - nodes = generate_ir_nodes( + return generate_ir_nodes( self.global_ctx, self.settings.optimize, self.settings.experimental_codegen ) - if self.settings.experimental_codegen: - return [generate_ir(nodes[0]), generate_ir(nodes[1])] - else: - return nodes @property def ir_nodes(self) -> IRnode: @@ -201,11 +197,17 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: fs = self.annotated_vyper_module.get_children(vy_ast.FunctionDef) return {f.name: f._metadata["func_type"] for f in fs} + @cached_property + def venom_functions(self): + return generate_ir(self.ir_nodes, self.settings.optimize) + @cached_property def assembly(self) -> list: if self.settings.experimental_codegen: + deploy_code, runtime_code = self.venom_functions + assert self.settings.optimize is not None # mypy hint return generate_assembly_experimental( - self.ir_nodes, self.settings.optimize # type: ignore + runtime_code, deploy_code=deploy_code, optimize=self.settings.optimize ) else: return generate_assembly(self.ir_nodes, self.settings.optimize) @@ -213,9 +215,9 @@ def assembly(self) -> list: @cached_property def assembly_runtime(self) -> list: if self.settings.experimental_codegen: - return generate_assembly_experimental( - self.ir_runtime, self.settings.optimize # type: ignore - ) + _, runtime_code = self.venom_functions + assert self.settings.optimize is not None # mypy hint + return generate_assembly_experimental(runtime_code, optimize=self.settings.optimize) else: return generate_assembly(self.ir_runtime, self.settings.optimize) diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 5a09f8378e..570aba771a 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -1,7 +1,7 @@ # maybe rename this `main.py` or `venom.py` # (can have an `__init__.py` which exposes the API). -from typing import Optional +from typing import Any, Optional from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel @@ -17,19 +17,26 @@ from vyper.venom.passes.dft import DFTPass from vyper.venom.venom_to_assembly import VenomCompiler +DEFAULT_OPT_LEVEL = OptimizationLevel.default() + def generate_assembly_experimental( - ctx: IRFunction, optimize: Optional[OptimizationLevel] = None + runtime_code: IRFunction, + deploy_code: Optional[IRFunction] = None, + optimize: OptimizationLevel = DEFAULT_OPT_LEVEL, ) -> list[str]: - compiler = VenomCompiler(ctx) - return compiler.generate_evm(optimize is OptimizationLevel.NONE) + # note: VenomCompiler is sensitive to the order of these! + if deploy_code is not None: + functions = [deploy_code, runtime_code] + else: + functions = [runtime_code] + compiler = VenomCompiler(functions) + return compiler.generate_evm(optimize == OptimizationLevel.NONE) -def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRFunction: - # Convert "old" IR to "new" IR - ctx = convert_ir_basicblock(ir) - # Run passes on "new" IR +def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: + # Run passes on Venom IR # TODO: Add support for optimization levels while True: changes = 0 @@ -53,4 +60,12 @@ def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRF if changes == 0: break - return ctx + +def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> tuple[IRFunction, IRFunction]: + # Convert "old" IR to "new" IR + ctx, ctx_runtime = convert_ir_basicblock(ir) + + _run_passes(ctx, optimize) + _run_passes(ctx_runtime, optimize) + + return ctx, ctx_runtime diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 6dfc3c3d7c..eed579463e 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -19,27 +19,6 @@ def calculate_cfg(ctx: IRFunction) -> None: bb.cfg_out = OrderedSet() bb.out_vars = OrderedSet() - # TODO: This is a hack to support the old IR format where `deploy` is - # an instruction. in the future we should have two entry points, one - # for the initcode and one for the runtime code. - deploy_bb = None - after_deploy_bb = None - for i, bb in enumerate(ctx.basic_blocks): - if bb.instructions[0].opcode == "deploy": - deploy_bb = bb - after_deploy_bb = ctx.basic_blocks[i + 1] - break - - if deploy_bb is not None: - assert after_deploy_bb is not None, "No block after deploy block" - entry_block = after_deploy_bb - has_constructor = ctx.basic_blocks[0].instructions[0].opcode != "deploy" - if has_constructor: - deploy_bb.add_cfg_in(ctx.basic_blocks[0]) - entry_block.add_cfg_in(deploy_bb) - else: - entry_block = ctx.basic_blocks[0] - for bb in ctx.basic_blocks: assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 9afaa5e6fd..598b8af7d5 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -4,7 +4,7 @@ from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -33,7 +33,6 @@ NO_OUTPUT_INSTRUCTIONS = frozenset( [ - "deploy", "mstore", "sstore", "dstore", @@ -56,9 +55,7 @@ ] ) -CFG_ALTERING_INSTRUCTIONS = frozenset( - ["jmp", "djmp", "jnz", "call", "staticcall", "invoke", "deploy"] -) +CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "djmp", "jnz", "call", "staticcall", "invoke"]) if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -273,7 +270,7 @@ def _ir_operand_from_value(val: Any) -> IROperand: if isinstance(val, IROperand): return val - assert isinstance(val, int) + assert isinstance(val, int), val return IRLiteral(val) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 665fa0c6c2..9f26fa8ec0 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -9,7 +9,7 @@ MemType, ) -GLOBAL_LABEL = IRLabel("global") +GLOBAL_LABEL = IRLabel("__global") class IRFunction: @@ -18,7 +18,10 @@ class IRFunction: """ name: IRLabel # symbol name + entry_points: list[IRLabel] # entry points args: list + ctor_mem_size: Optional[int] + immutables_len: Optional[int] basic_blocks: list[IRBasicBlock] data_segment: list[IRInstruction] last_label: int @@ -28,14 +31,30 @@ def __init__(self, name: IRLabel = None) -> None: if name is None: name = GLOBAL_LABEL self.name = name + self.entry_points = [] self.args = [] + self.ctor_mem_size = None + self.immutables_len = None self.basic_blocks = [] self.data_segment = [] self.last_label = 0 self.last_variable = 0 + self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) + def add_entry_point(self, label: IRLabel) -> None: + """ + Add entry point. + """ + self.entry_points.append(label) + + def remove_entry_point(self, label: IRLabel) -> None: + """ + Remove entry point. + """ + self.entry_points.remove(label) + def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: """ Append basic block to function. @@ -91,7 +110,7 @@ def remove_unreachable_blocks(self) -> int: removed = 0 new_basic_blocks = [] for bb in self.basic_blocks: - if not bb.is_reachable and bb.label.value != "global": + if not bb.is_reachable and bb.label not in self.entry_points: removed += 1 else: new_basic_blocks.append(bb) @@ -119,16 +138,10 @@ def normalized(self) -> bool: if len(bb.cfg_in) <= 1: continue - # Check if there is a conditional jump at the end + # Check if there is a branching jump at the end # of one of the predecessors - # - # TODO: this check could be: - # `if len(in_bb.cfg_out) > 1: return False` - # but the cfg is currently not calculated "correctly" for - # the special deploy instruction. for in_bb in bb.cfg_in: - jump_inst = in_bb.instructions[-1] - if jump_inst.opcode in ("jnz", "djmp"): + if len(in_bb.cfg_out) > 1: return False # The function is normalized diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 9f5c23df0b..c86d3a3d67 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -87,19 +87,35 @@ def _get_symbols_common(a: dict, b: dict) -> dict: return ret -def convert_ir_basicblock(ir: IRnode) -> IRFunction: - global_function = IRFunction() - _convert_ir_basicblock(global_function, ir, {}, OrderedSet(), {}) +def _findIRnode(ir: IRnode, value: str) -> Optional[IRnode]: + if ir.value == value: + return ir + for arg in ir.args: + if isinstance(arg, IRnode): + ret = _findIRnode(arg, value) + if ret is not None: + return ret + return None + + +def convert_ir_basicblock(ir: IRnode) -> tuple[IRFunction, IRFunction]: + deploy_ir = _findIRnode(ir, "deploy") + assert deploy_ir is not None + + deploy_venom = IRFunction() + _convert_ir_basicblock(deploy_venom, ir, {}, OrderedSet(), {}) + deploy_venom.get_basic_block().append_instruction("stop") - for i, bb in enumerate(global_function.basic_blocks): - if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: - bb.append_instruction("jmp", global_function.basic_blocks[i + 1].label) + runtime_ir = deploy_ir.args[1] + runtime_venom = IRFunction() + _convert_ir_basicblock(runtime_venom, runtime_ir, {}, OrderedSet(), {}) - revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) - revert_bb = global_function.append_basic_block(revert_bb) - revert_bb.append_instruction("revert", 0, 0) + # Connect unterminated blocks to the next with a jump + for i, bb in enumerate(runtime_venom.basic_blocks): + if not bb.is_terminated and i < len(runtime_venom.basic_blocks) - 1: + bb.append_instruction("jmp", runtime_venom.basic_blocks[i + 1].label) - return global_function + return deploy_venom, runtime_venom def _convert_binary_op( @@ -279,20 +295,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif ir.value in ["pass", "stop", "return"]: pass elif ir.value == "deploy": - memsize = ir.args[0].value - ir_runtime = ir.args[1] - padding = ir.args[2].value - assert isinstance(memsize, int), "non-int memsize" - assert isinstance(padding, int), "non-int padding" - - runtimeLabel = ctx.get_next_label() - - ctx.get_basic_block().append_instruction("deploy", memsize, runtimeLabel, padding) - - bb = IRBasicBlock(runtimeLabel, ctx) - ctx.append_basic_block(bb) - - _convert_ir_basicblock(ctx, ir_runtime, symbols, variables, allocated_variables) + ctx.ctor_mem_size = ir.args[0].value + ctx.immutables_len = ir.args[2].value + return None elif ir.value == "seq": func_t = ir.passthrough_metadata.get("func_t", None) if ir.is_self_call: diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 43e8d47235..26699099b2 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -14,13 +14,11 @@ class NormalizationPass(IRPass): changes = 0 def _split_basic_block(self, bb: IRBasicBlock) -> None: - # Iterate over the predecessors of the basic block + # Iterate over the predecessors to this basic block for in_bb in list(bb.cfg_in): - jump_inst = in_bb.instructions[-1] assert bb in in_bb.cfg_out - - # Handle branching - if jump_inst.opcode in ("jnz", "djmp"): + # Handle branching in the predecessor bb + if len(in_bb.cfg_out) > 1: self._insert_split_basicblock(bb, in_bb) self.changes += 1 break diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 0c32c3b816..926f8df8a3 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -65,6 +65,8 @@ ] ) +_REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] + # TODO: "assembly" gets into the recursion due to how the original # IR was structured recursively in regards with the deploy instruction. @@ -75,13 +77,13 @@ # with the assembler. My suggestion is to let this be for now, and we can # refactor it later when we are finished phasing out the old IR. class VenomCompiler: - ctx: IRFunction + ctxs: list[IRFunction] label_counter = 0 visited_instructions: OrderedSet # {IRInstruction} visited_basicblocks: OrderedSet # {IRBasicBlock} - def __init__(self, ctx: IRFunction): - self.ctx = ctx + def __init__(self, ctxs: list[IRFunction]): + self.ctxs = ctxs self.label_counter = 0 self.visited_instructions = OrderedSet() self.visited_basicblocks = OrderedSet() @@ -91,8 +93,8 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: self.visited_basicblocks = OrderedSet() self.label_counter = 0 - stack = StackModel() - asm: list[str] = [] + asm: list[Any] = [] + top_asm = asm # Before emitting the assembly, we need to make sure that the # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) @@ -101,41 +103,49 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: # assembly generation. # This is a side-effect of how dynamic jumps are temporarily being used # to support the O(1) dispatcher. -> look into calculate_cfg() - calculate_cfg(self.ctx) - NormalizationPass.run_pass(self.ctx) - calculate_liveness(self.ctx) - - assert self.ctx.normalized, "Non-normalized CFG!" - - self._generate_evm_for_basicblock_r(asm, self.ctx.basic_blocks[0], stack) - - # Append postambles - revert_postamble = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] - runtime = None - if isinstance(asm[-1], list) and isinstance(asm[-1][0], RuntimeHeader): - runtime = asm.pop() - - asm.extend(revert_postamble) - if runtime: - runtime.extend(revert_postamble) - asm.append(runtime) + for ctx in self.ctxs: + calculate_cfg(ctx) + NormalizationPass.run_pass(ctx) + calculate_liveness(ctx) + + assert ctx.normalized, "Non-normalized CFG!" + + self._generate_evm_for_basicblock_r(asm, ctx.basic_blocks[0], StackModel()) + + # TODO make this property on IRFunction + if ctx.immutables_len is not None and ctx.ctor_mem_size is not None: + while asm[-1] != "JUMPDEST": + asm.pop() + asm.extend( + ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] + ) + asm.extend(["_OFST", "_sym_subcode_size", ctx.immutables_len]) # stack: len + asm.extend(["_mem_deploy_start"]) # stack: len mem_ofst + asm.extend(["RETURN"]) + asm.extend(_REVERT_POSTAMBLE) + runtime_asm = [ + RuntimeHeader("_sym_runtime_begin", ctx.ctor_mem_size, ctx.immutables_len) + ] + asm.append(runtime_asm) + asm = runtime_asm + else: + asm.extend(_REVERT_POSTAMBLE) - # Append data segment - data_segments: dict[Any, list[Any]] = dict() - for inst in self.ctx.data_segment: - if inst.opcode == "dbname": - label = inst.operands[0].value - data_segments[label] = [DataHeader(f"_sym_{label}")] - elif inst.opcode == "db": - data_segments[label].append(f"_sym_{inst.operands[0].value}") + # Append data segment + data_segments: dict = dict() + for inst in ctx.data_segment: + if inst.opcode == "dbname": + label = inst.operands[0].value + data_segments[label] = [DataHeader(f"_sym_{label}")] + elif inst.opcode == "db": + data_segments[label].append(f"_sym_{inst.operands[0].value}") - extent_point = asm if not isinstance(asm[-1], list) else asm[-1] - extent_point.extend([data_segments[label] for label in data_segments]) # type: ignore + asm.extend(list(data_segments.values())) if no_optimize is False: - optimize_assembly(asm) + optimize_assembly(top_asm) - return asm + return top_asm def _stack_reorder( self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] @@ -397,20 +407,6 @@ def _generate_evm_for_instruction( assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) elif opcode == "assert": assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) - elif opcode == "deploy": - memsize = inst.operands[0].value - padding = inst.operands[2].value - # TODO: fix this by removing deploy opcode altogether me move emition to ir translation - while assembly[-1] != "JUMPDEST": - assembly.pop() - assembly.extend( - ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] - ) - assembly.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len - assembly.extend(["_mem_deploy_start"]) # stack: len mem_ofst - assembly.extend(["RETURN"]) - assembly.append([RuntimeHeader("_sym_runtime_begin", memsize, padding)]) # type: ignore - assembly = assembly[-1] elif opcode == "iload": loc = inst.operands[0].value assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"]) From ddfce5273b39a199b194dd74f0f7f741efc03663 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 8 Jan 2024 02:37:01 +0800 Subject: [PATCH 153/201] feat: require type annotations for loop variables (#3596) this commit changes the vyper language to require type annotations for loop variables. that is, before, the following was allowed: ```vyper for i in [1, 2, 3]: pass ``` now, `i` is required to have a type annotation: ```vyper for i: uint256 in [1, 2, 3]: pass ``` this makes the annotation of loop variables consistent with the rest of vyper (it was previously a special case, that loop variables did not need to be annotated). the approach taken in this commit is to add a pre-parsing step which lifts out the type annotation into a separate data structure, and then splices it back in during the post-processing steps in `vyper/ast/parse.py`. this commit also simplifies a lot of analysis regarding for loops. notably, the possible types for the loop variable no longer needs to be iterated over, we can just propagate the type provided by the user. for this reason we also no longer need to use the typechecker speculation machinery for inferring the type of the loop variable. however, the NodeMetadata code is not removed because it might come in handy at a later date. --------- Co-authored-by: Charles Cooper --- examples/auctions/blind_auction.vy | 2 +- examples/tokens/ERC1155ownable.vy | 8 +- examples/voting/ballot.vy | 6 +- examples/wallet/wallet.vy | 4 +- .../functional/builtins/codegen/test_empty.py | 4 +- .../builtins/codegen/test_mulmod.py | 2 +- .../functional/builtins/codegen/test_slice.py | 2 +- .../codegen/features/iteration/test_break.py | 12 +- .../features/iteration/test_continue.py | 10 +- .../features/iteration/test_for_in_list.py | 150 ++++++++++-------- .../features/iteration/test_for_range.py | 56 +++---- .../codegen/features/test_assert.py | 4 +- .../codegen/features/test_internal_call.py | 2 +- .../codegen/integration/test_crowdfund.py | 4 +- .../codegen/types/numbers/test_decimals.py | 2 +- tests/functional/codegen/types/test_bytes.py | 2 +- .../codegen/types/test_bytes_zero_padding.py | 2 +- .../codegen/types/test_dynamic_array.py | 28 ++-- tests/functional/codegen/types/test_lists.py | 4 +- tests/functional/grammar/test_grammar.py | 2 +- .../exceptions/test_argument_exception.py | 4 +- .../exceptions/test_constancy_exception.py | 6 +- tests/functional/syntax/test_blockscope.py | 4 +- tests/functional/syntax/test_constants.py | 2 +- tests/functional/syntax/test_for_range.py | 58 +++---- tests/functional/syntax/test_list.py | 2 +- tests/unit/ast/nodes/test_hex.py | 2 +- .../ast/test_annotate_and_optimize_ast.py | 4 +- tests/unit/ast/test_pre_parser.py | 2 +- tests/unit/compiler/asm/test_asm_optimizer.py | 2 +- tests/unit/compiler/test_source_map.py | 2 +- .../unit/semantics/analysis/test_for_loop.py | 38 ++--- vyper/ast/grammar.lark | 3 +- vyper/ast/nodes.py | 23 ++- vyper/ast/parse.py | 66 +++++++- vyper/ast/pre_parser.py | 91 ++++++++++- vyper/builtins/functions.py | 2 +- vyper/codegen/stmt.py | 30 ++-- vyper/exceptions.py | 4 + vyper/semantics/analysis/local.py | 111 ++++--------- 40 files changed, 432 insertions(+), 330 deletions(-) diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 04f908f6d0..597aed57c7 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -107,7 +107,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets: # Calculate refund for sender refund: uint256 = 0 - for i in range(MAX_BIDS): + for i: int128 in range(MAX_BIDS): # Note that loop may break sooner than 128 iterations if i >= _numBids if (i >= _numBids): break diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index 30057582e8..e105a79133 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -205,7 +205,7 @@ def balanceOfBatch(accounts: DynArray[address, BATCH_SIZE], ids: DynArray[uint25 assert len(accounts) == len(ids), "ERC1155: accounts and ids length mismatch" batchBalances: DynArray[uint256, BATCH_SIZE] = [] j: uint256 = 0 - for i in ids: + for i: uint256 in ids: batchBalances.append(self.balanceOf[accounts[j]][i]) j += 1 return batchBalances @@ -243,7 +243,7 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[receiver][ids[i]] += amounts[i] @@ -277,7 +277,7 @@ def burnBatch(ids: DynArray[uint256, BATCH_SIZE], amounts: DynArray[uint256, BAT assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[msg.sender][ids[i]] -= amounts[i] @@ -333,7 +333,7 @@ def safeBatchTransferFrom(sender: address, receiver: address, ids: DynArray[uint assert sender == msg.sender or self.isApprovedForAll[sender][msg.sender], "Caller is neither owner nor approved operator for this ID" assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break id: uint256 = ids[i] diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 0b568784a9..107716accf 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -54,7 +54,7 @@ def directlyVoted(addr: address) -> bool: def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 - for i in range(2): + for i: int128 in range(2): self.proposals[i] = Proposal({ name: _proposalNames[i], voteCount: 0 @@ -82,7 +82,7 @@ def _forwardWeight(delegate_with_weight_to_forward: address): assert self.voters[delegate_with_weight_to_forward].weight > 0 target: address = self.voters[delegate_with_weight_to_forward].delegate - for i in range(4): + for i: int128 in range(4): if self._delegated(target): target = self.voters[target].delegate # The following effectively detects cycles of length <= 5, @@ -157,7 +157,7 @@ def vote(proposal: int128): def _winningProposal() -> int128: winning_vote_count: int128 = 0 winning_proposal: int128 = 0 - for i in range(2): + for i: int128 in range(2): if self.proposals[i].voteCount > winning_vote_count: winning_vote_count = self.proposals[i].voteCount winning_proposal = i diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index e2515d9e62..231f538ecf 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -14,7 +14,7 @@ seq: public(int128) @external def __init__(_owners: address[5], _threshold: int128): - for i in range(5): + for i: uint256 in range(5): if _owners[i] != empty(address): self.owners[i] = _owners[i] self.threshold = _threshold @@ -47,7 +47,7 @@ def approve(_seq: int128, to: address, _value: uint256, data: Bytes[4096], sigda assert self.seq == _seq # # Iterates through all the owners and verifies that there signatures, # # given as the sigdata argument are correct - for i in range(5): + for i: uint256 in range(5): if sigdata[i][0] != 0: # If an invalid signature is given for an owner then the contract throws assert ecrecover(h2, sigdata[i][0], sigdata[i][1], sigdata[i][2]) == self.owners[i] diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index c3627785dc..896c845da2 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -423,7 +423,7 @@ def test_empty(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: view @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @internal def priv(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: @@ -469,7 +469,7 @@ def test_return_empty(get_contract_with_gas_estimation): @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @external diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index ba82ebd5b8..31de1d9f22 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -20,7 +20,7 @@ def test_uint256_mulmod_complex(get_contract_with_gas_estimation): @external def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: o: uint256 = 1 - for i in range(256): + for i: uint256 in range(256): o = uint256_mulmod(o, o, modulus) if exponent & shift(1, 255 - i) != 0: o = uint256_mulmod(o, base, modulus) diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index a15a3eeb35..80936bbf82 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -17,7 +17,7 @@ def test_basic_slice(get_contract_with_gas_estimation): @external def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: inp: Bytes[50] = inp1 - for i in range(1, 11): + for i: uint256 in range(1, 11): inp = slice(inp, 1, 30 - i * 2) return inp """ diff --git a/tests/functional/codegen/features/iteration/test_break.py b/tests/functional/codegen/features/iteration/test_break.py index 8a08a11cc2..4abde9c617 100644 --- a/tests/functional/codegen/features/iteration/test_break.py +++ b/tests/functional/codegen/features/iteration/test_break.py @@ -11,7 +11,7 @@ def test_break_test(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(400): + for i: int128 in range(400): c = c / 1.2589 if c < 1.0: output = i @@ -35,12 +35,12 @@ def test_break_test_2(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c = c / 10.0 - for i in range(10): + for i: int128 in range(10): c = c / 1.2589 if c < 1.0: output = output + i @@ -63,12 +63,12 @@ def test_break_test_3(get_contract_with_gas_estimation): def foo(n: int128) -> int128: c: decimal = convert(n, decimal) output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c /= 10.0 - for i in range(10): + for i: int128 in range(10): c /= 1.2589 if c < 1.0: output = output + i @@ -108,7 +108,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: break diff --git a/tests/functional/codegen/features/iteration/test_continue.py b/tests/functional/codegen/features/iteration/test_continue.py index 5f4f82a2de..1b2fcab460 100644 --- a/tests/functional/codegen/features/iteration/test_continue.py +++ b/tests/functional/codegen/features/iteration/test_continue.py @@ -7,7 +7,7 @@ def test_continue1(get_contract_with_gas_estimation): code = """ @external def foo() -> bool: - for i in range(2): + for i: uint256 in range(2): continue return False return True @@ -21,7 +21,7 @@ def test_continue2(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += 1 continue x -= 1 @@ -36,7 +36,7 @@ def test_continue3(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += i continue return x @@ -50,7 +50,7 @@ def test_continue4(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): if i % 2 == 0: continue x += 1 @@ -83,7 +83,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: continue diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index bc1a12ae9e..5c7b5c6b1b 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -21,7 +21,7 @@ @external def data() -> int128: s: int128[5] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -33,7 +33,7 @@ def data() -> int128: @external def data() -> int128: s: DynArray[int128, 10] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -53,8 +53,8 @@ def data() -> int128: [S({x:3, y:4}), S({x:5, y:6}), S({x:7, y:8}), S({x:9, y:10})] ] ret: int128 = 0 - for ss in sss: - for s in ss: + for ss: DynArray[S, 10] in sss: + for s: S in ss: ret += s.x + s.y return ret""", sum(range(1, 11)), @@ -64,7 +64,7 @@ def data() -> int128: """ @external def data() -> int128: - for i in [3, 5, 7, 9]: + for i: int128 in [3, 5, 7, 9]: if i > 5: return i return -1""", @@ -76,7 +76,7 @@ def data() -> int128: @external def data() -> String[33]: xs: DynArray[String[33], 3] = ["hello", ",", "world"] - for x in xs: + for x: String[33] in xs: if x == ",": return x return "" @@ -88,7 +88,7 @@ def data() -> String[33]: """ @external def data() -> String[33]: - for x in ["hello", ",", "world"]: + for x: String[33] in ["hello", ",", "world"]: if x == ",": return x return "" @@ -100,7 +100,7 @@ def data() -> String[33]: """ @external def data() -> DynArray[String[33], 2]: - for x in [["hello", "world"], ["goodbye", "world!"]]: + for x: DynArray[String[33], 2] in [["hello", "world"], ["goodbye", "world!"]]: if x[1] == "world": return x return [] @@ -114,8 +114,8 @@ def data() -> DynArray[String[33], 2]: def data() -> int128: ret: int128 = 0 xss: int128[3][3] = [[1,2,3],[4,5,6],[7,8,9]] - for xs in xss: - for x in xs: + for xs: int128[3] in xss: + for x: int128 in xs: ret += x return ret""", sum(range(1, 10)), @@ -130,8 +130,8 @@ def data() -> int128: @external def data() -> int128: ret: int128 = 0 - for ss in [[S({x:1, y:2})]]: - for s in ss: + for ss: S[1] in [[S({x:1, y:2})]]: + for s: S in ss: ret += s.x + s.y return ret""", 1 + 2, @@ -147,7 +147,7 @@ def data() -> address: 0xDCEceAF3fc5C0a63d195d69b1A90011B7B19650D ] count: int128 = 0 - for i in addresses: + for i: address in addresses: count += 1 if count == 2: return i @@ -174,7 +174,7 @@ def set(): @external def data() -> int128: - for i in self.x: + for i: int128 in self.x: if i > 5: return i return -1 @@ -198,7 +198,7 @@ def set(xs: DynArray[int128, 4]): @external def data() -> int128: t: int128 = 0 - for i in self.x: + for i: int128 in self.x: t += i return t """ @@ -227,7 +227,7 @@ def ret(i: int128) -> address: @external def iterate_return_second() -> address: count: int128 = 0 - for i in self.addresses: + for i: address in self.addresses: count += 1 if count == 2: return i @@ -258,7 +258,7 @@ def ret(i: int128) -> decimal: @external def i_return(break_count: int128) -> decimal: count: int128 = 0 - for i in self.readings: + for i: decimal in self.readings: if count == break_count: return i count += 1 @@ -284,7 +284,7 @@ def func(amounts: uint256[3]) -> uint256: total: uint256 = as_wei_value(0, "wei") # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -303,7 +303,7 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: total: uint256 = 0 # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -321,42 +321,42 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in range(4): + for i: int128 in range(4): p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass - for i in range(20): + for i: uint256 in range(20): pass """, # using index variable after loop """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass i: int128 = 100 # create new variable i i = 200 # look up the variable i and check whether it is in forvars @@ -372,25 +372,25 @@ def test_good_code(code, get_contract): RANGE_CONSTANT_CODE = [ ( """ -TREE_FIDDY: constant(int128) = 350 +TREE_FIDDY: constant(uint256) = 350 @external def a() -> uint256: x: uint256 = 0 - for i in range(TREE_FIDDY): + for i: uint256 in range(TREE_FIDDY): x += 1 return x""", 350, ), ( """ -ONE_HUNDRED: constant(int128) = 100 +ONE_HUNDRED: constant(uint256) = 100 @external def a() -> uint256: x: uint256 = 0 - for i in range(1, 1 + ONE_HUNDRED): + for i: uint256 in range(1, 1 + ONE_HUNDRED): x += 1 return x""", 100, @@ -401,9 +401,9 @@ def a() -> uint256: END: constant(int128) = 199 @external -def a() -> uint256: - x: uint256 = 0 - for i in range(START, END): +def a() -> int128: + x: int128 = 0 + for i: int128 in range(START, END): x += 1 return x""", 99, @@ -413,11 +413,23 @@ def a() -> uint256: @external def a() -> int128: x: int128 = 0 - for i in range(-5, -1): + for i: int128 in range(-5, -1): x += i return x""", -14, ), + ( + """ +@external +def a() -> uint256: + a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] + x: uint256 = 0 + for i: uint256 in a[2]: + x += i + return x + """, + 9, + ), ] @@ -436,7 +448,7 @@ def test_range_constant(get_contract, code, result): def data() -> int128: s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -451,7 +463,7 @@ def data() -> int128: def foo(): s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] += 1 """, ImmutableViolation, @@ -468,7 +480,7 @@ def set(): @external def data() -> int128: count: int128 = 0 - for i in self.s: + for i: int128 in self.s: self.s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -493,7 +505,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff(i) i += 1 """, @@ -519,7 +531,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.bar.foo: + for item: uint256 in self.my_array2.bar.foo: self.doStuff(i) i += 1 """, @@ -545,7 +557,7 @@ def doStuff(): @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff() i += 1 """, @@ -556,8 +568,8 @@ def _helper(): """ @external def foo(x: int128): - for i in range(4): - for i in range(5): + for i: int128 in range(4): + for i: int128 in range(5): pass """, NamespaceCollision, @@ -566,8 +578,8 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: - for i in [1,2]: + for i: int128 in [1,2]: + for i: int128 in [1,2]: pass """, NamespaceCollision, @@ -577,7 +589,7 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i = 2 """, ImmutableViolation, @@ -588,7 +600,7 @@ def foo(x: int128): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.pop() """, ImmutableViolation, @@ -599,7 +611,7 @@ def foo(): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.append(x) """, ImmutableViolation, @@ -610,7 +622,7 @@ def foo(): @external def foo(): xs: DynArray[DynArray[uint256, 5], 5] = [[1,2,3]] - for x in xs: + for x: DynArray[uint256, 5] in xs: x.pop() """, ImmutableViolation, @@ -629,7 +641,7 @@ def b(): @external def foo(): - for x in self.array: + for x: uint256 in self.array: self.a() """, ImmutableViolation, @@ -638,7 +650,7 @@ def foo(): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i += 2 """, ImmutableViolation, @@ -648,7 +660,7 @@ def foo(x: int128): """ @external def foo(): - for i in range(-3): + for i: int128 in range(-3): pass """, StructureException, @@ -656,13 +668,13 @@ def foo(): """ @external def foo(): - for i in range(0): + for i: uint256 in range(0): pass """, """ @external def foo(): - for i in []: + for i: uint256 in []: pass """, """ @@ -670,14 +682,14 @@ def foo(): @external def foo(): - for i in FOO: + for i: uint256 in FOO: pass """, ( """ @external def foo(): - for i in range(5,3): + for i: uint256 in range(5,3): pass """, StructureException, @@ -686,7 +698,7 @@ def foo(): """ @external def foo(): - for i in range(5,3,-1): + for i: int128 in range(5,3,-1): pass """, ArgumentException, @@ -696,7 +708,7 @@ def foo(): @external def foo(): a: uint256 = 2 - for i in range(a): + for i: uint256 in range(a): pass """, StateAccessViolation, @@ -706,7 +718,7 @@ def foo(): @external def foo(): a: int128 = 6 - for i in range(a,a-3): + for i: int128 in range(a,a-3): pass """, StateAccessViolation, @@ -716,7 +728,7 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, ArgumentException, @@ -725,7 +737,7 @@ def foo(): """ @external def foo(): - for i in range(0,1,2): + for i: uint256 in range(0,1,2): pass """, ArgumentException, @@ -735,7 +747,7 @@ def foo(): """ @external def foo(): - for i in b"asdf": + for i: Bytes[1] in b"asdf": pass """, InvalidType, @@ -744,7 +756,7 @@ def foo(): """ @external def foo(): - for i in 31337: + for i: uint256 in 31337: pass """, InvalidType, @@ -753,7 +765,7 @@ def foo(): """ @external def foo(): - for i in bar(): + for i: uint256 in bar(): pass """, IteratorException, @@ -762,7 +774,7 @@ def foo(): """ @external def foo(): - for i in self.bar(): + for i: uint256 in self.bar(): pass """, IteratorException, @@ -772,11 +784,11 @@ def foo(): @external def test_for() -> int128: a: int128 = 0 - for i in range(max_value(int128), max_value(int128)+2): + for i: int128 in range(max_value(int128), max_value(int128)+2): a = i return a """, - TypeMismatch, + InvalidType, ), ( """ @@ -784,7 +796,7 @@ def test_for() -> int128: def test_for() -> int128: a: int128 = 0 b: uint256 = 0 - for i in range(5): + for i: int128 in range(5): a = i b = i return a diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index e946447285..c661c46553 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -6,7 +6,7 @@ def test_basic_repeater(get_contract_with_gas_estimation): @external def repeat(z: int128) -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): x = x + z return(x) """ @@ -19,7 +19,7 @@ def test_range_bound(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, bound=6): + for i: uint256 in range(n, bound=6): x += i + 1 return x """ @@ -37,7 +37,7 @@ def test_range_bound_constant_end(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, 7, bound=6): + for i: uint256 in range(n, 7, bound=6): x += i + 1 return x """ @@ -58,7 +58,7 @@ def test_range_bound_two_args(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(1, n, bound=6): + for i: uint256 in range(1, n, bound=6): x += i + 1 return x """ @@ -80,7 +80,7 @@ def test_range_bound_two_runtime_args(get_contract, tx_failed): @external def repeat(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x += i return x """ @@ -109,7 +109,7 @@ def test_range_overflow(get_contract, tx_failed): @external def get_last(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x = i return x """ @@ -134,11 +134,11 @@ def test_digit_reverser(get_contract_with_gas_estimation): def reverse_digits(x: int128) -> int128: dig: int128[6] = [0, 0, 0, 0, 0, 0] z: int128 = x - for i in range(6): + for i: uint256 in range(6): dig[i] = z % 10 z = z / 10 o: int128 = 0 - for i in range(6): + for i: uint256 in range(6): o = o * 10 + dig[i] return o @@ -153,9 +153,9 @@ def test_more_complex_repeater(get_contract_with_gas_estimation): @external def repeat() -> int128: out: int128 = 0 - for i in range(6): + for i: uint256 in range(6): out = out * 10 - for j in range(4): + for j: int128 in range(4): out = out + j return(out) """ @@ -170,7 +170,7 @@ def test_offset_repeater(get_contract_with_gas_estimation, typ): @external def sum() -> {typ}: out: {typ} = 0 - for i in range(80, 121): + for i: {typ} in range(80, 121): out = out + i return out """ @@ -185,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101, bound=101): + for i: {typ} in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -205,7 +205,7 @@ def _bar() -> bool: @external def foo() -> bool: - for i in range(3): + for i: uint256 in range(3): self._bar() return True """ @@ -219,8 +219,8 @@ def test_return_inside_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for j in range(10): + for i: {typ} in range(10): + for j: {typ} in range(10): if j > 5: if i > a: return i @@ -254,14 +254,14 @@ def test_for_range_edge(get_contract, typ): def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x - 1, x, bound=1): + for i: {typ} in range(x - 1, x, bound=1): if i + 1 == max_value({typ}): found = True assert found found = False x = max_value({typ}) - 1 - for i in range(x - 1, x + 1, bound=2): + for i: {typ} in range(x - 1, x + 1, bound=2): if i + 1 == max_value({typ}): found = True assert found @@ -276,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x + 2, bound=2): + for i: {typ} in range(x, x + 2, bound=2): pass """ c = get_contract(code) @@ -289,8 +289,8 @@ def test_return_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -318,8 +318,8 @@ def test_return_void_nested_repeater(get_contract, typ, val): result: {typ} @internal def _final(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -347,8 +347,8 @@ def test_external_nested_repeater(get_contract, typ, val): code = f""" @external def foo(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -368,8 +368,8 @@ def test_external_void_nested_repeater(get_contract, typ, val): result: public({typ}) @external def foo(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -388,8 +388,8 @@ def test_breaks_and_returns_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if a < 2: break return 6 diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index af189e6dca..df379d3f16 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -159,7 +159,7 @@ def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5 return True """ @@ -179,7 +179,7 @@ def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5, "because reasons" return True """ diff --git a/tests/functional/codegen/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py index f10d22ec99..422f53fdeb 100644 --- a/tests/functional/codegen/features/test_internal_call.py +++ b/tests/functional/codegen/features/test_internal_call.py @@ -152,7 +152,7 @@ def _increment(): @external def returnten() -> int128: - for i in range(10): + for i: uint256 in range(10): self._increment() return self.counter """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 671d424d60..891ed5aebe 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index fcf71f12f0..72171dd4b5 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -125,7 +125,7 @@ def test_harder_decimal_test(get_contract_with_gas_estimation): @external def phooey(inp: decimal) -> decimal: x: decimal = 10000.0 - for i in range(4): + for i: uint256 in range(4): x = x * inp return x diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 1ee9b8d835..882629de65 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -268,7 +268,7 @@ def test_zero_padding_with_private(get_contract): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_bytes_zero_padding.py b/tests/functional/codegen/types/test_bytes_zero_padding.py index f9fcf37b25..6597facd1b 100644 --- a/tests/functional/codegen/types/test_bytes_zero_padding.py +++ b/tests/functional/codegen/types/test_bytes_zero_padding.py @@ -10,7 +10,7 @@ def little_endian_contract(get_contract_module): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 70a68e3206..e47eda6042 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -969,7 +969,7 @@ def foo() -> (uint256, uint256, uint256, uint256, uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -981,7 +981,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: some_var: uint256 @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.some_var = x # test that typechecker for append args works self.my_array.append(self.some_var) @@ -994,9 +994,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) - for x in xs: + for x: uint256 in xs: self.my_array.pop() return self.my_array """, @@ -1008,7 +1008,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array, self.my_array.pop() """, @@ -1020,7 +1020,7 @@ def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array.pop(), self.my_array """, @@ -1033,7 +1033,7 @@ def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] i: uint256 = 0 - for x in xs: + for x: uint256 in xs: if i >= len(xs) - 1: break ys.append(x) @@ -1049,7 +1049,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -1061,9 +1061,9 @@ def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() return ys """, @@ -1075,9 +1075,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() ys.pop() # fail return ys @@ -1328,7 +1328,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: DynArray[Foo, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): e: Foobar = _baz[i].z f: uint256 = convert(e, uint256) sum += _baz[i].x * _baz[i].y + f @@ -1397,7 +1397,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: DynArray[Bar, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index b5b9538c20..ee287064e8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -566,7 +566,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: Foo[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _baz[i].x * _baz[i].y return sum """ @@ -608,7 +608,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: Bar[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 7dd8c35929..652102c376 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -106,6 +106,6 @@ def has_no_docstrings(c): @hypothesis.settings(max_examples=500) def test_grammar_bruteforce(code): if utf8_encodable(code): - _, _, reformatted_code = pre_parse(code + "\n") + _, _, _, reformatted_code = pre_parse(code + "\n") tree = parse_to_ast(reformatted_code) assert isinstance(tree, Module) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index 0b7ec21bdb..4240aec8d2 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -80,13 +80,13 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, """ @external def foo(): - for i in range(1, 2, 3, 4): + for i: uint256 in range(1, 2, 3, 4): pass """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 4bd0b4fcb9..7adf9538c7 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -57,7 +57,7 @@ def foo() -> int128: return 5 @external def bar(): - for i in range(self.foo(), self.foo() + 1): + for i: int128 in range(self.foo(), self.foo() + 1): pass""", """ glob: int128 @@ -67,13 +67,13 @@ def foo() -> int128: return 5 @external def bar(): - for i in [1,2,3,4,self.foo()]: + for i: int128 in [1,2,3,4,self.foo()]: pass""", """ @external def foo(): x: int128 = 5 - for i in range(x): + for i: int128 in range(x): pass""", """ f:int128 diff --git a/tests/functional/syntax/test_blockscope.py b/tests/functional/syntax/test_blockscope.py index 942aa3fa68..466b5509ca 100644 --- a/tests/functional/syntax/test_blockscope.py +++ b/tests/functional/syntax/test_blockscope.py @@ -33,7 +33,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a = 1 """, @@ -41,7 +41,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a += 1 """, diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index ffd2f1faa0..7089dee3bb 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -240,7 +240,7 @@ def test1(): @external @view def test(): - for i in range(CONST / 4): + for i: uint256 in range(CONST / 4): pass """, """ diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a9c3ad5cab..66981a90de 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -15,7 +15,7 @@ """ @external def foo(): - for a[1] in range(10): + for a[1]: uint256 in range(10): pass """, StructureException, @@ -26,7 +26,7 @@ def foo(): """ @external def bar(): - for i in range(1,2,bound=0): + for i: uint256 in range(1,2,bound=0): pass """, StructureException, @@ -38,7 +38,7 @@ def bar(): @external def foo(): x: uint256 = 100 - for _ in range(10, bound=x): + for _: uint256 in range(10, bound=x): pass """, StateAccessViolation, @@ -49,7 +49,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=5): + for _: uint256 in range(10, 20, bound=5): pass """, StructureException, @@ -60,7 +60,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=0): + for _: uint256 in range(10, 20, bound=0): pass """, StructureException, @@ -72,7 +72,7 @@ def foo(): @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2,extra=3): + for i: uint256 in range(x,x+1,bound=2,extra=3): pass """, ArgumentException, @@ -83,7 +83,7 @@ def bar(): """ @external def bar(): - for i in range(0): + for i: uint256 in range(0): pass """, StructureException, @@ -95,7 +95,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x): + for i: uint256 in range(x): pass """, StateAccessViolation, @@ -107,7 +107,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(0, x): + for i: uint256 in range(0, x): pass """, StateAccessViolation, @@ -118,7 +118,7 @@ def bar(): """ @external def repeat(n: uint256) -> uint256: - for i in range(0, n * 10): + for i: uint256 in range(0, n * 10): pass return n """, @@ -131,7 +131,7 @@ def repeat(n: uint256) -> uint256: @external def bar(): x:uint256 = 1 - for i in range(0, x + 1): + for i: uint256 in range(0, x + 1): pass """, StateAccessViolation, @@ -142,7 +142,7 @@ def bar(): """ @external def bar(): - for i in range(2, 1): + for i: uint256 in range(2, 1): pass """, StructureException, @@ -154,7 +154,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x, x): + for i: uint256 in range(x, x): pass """, StateAccessViolation, @@ -166,7 +166,7 @@ def bar(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i: int128 in range(x, x + 10): pass """, StateAccessViolation, @@ -177,7 +177,7 @@ def foo(): """ @external def repeat(n: uint256) -> uint256: - for i in range(n, 6): + for i: uint256 in range(n, 6): pass return x """, @@ -190,7 +190,7 @@ def repeat(n: uint256) -> uint256: @external def foo(x: int128): y: int128 = 7 - for i in range(x, x + y): + for i: int128 in range(x, x + y): pass """, StateAccessViolation, @@ -201,7 +201,7 @@ def foo(x: int128): """ @external def bar(x: uint256): - for i in range(3, x): + for i: uint256 in range(3, x): pass """, StateAccessViolation, @@ -215,12 +215,12 @@ def bar(x: uint256): @external def foo(): - for i in range(FOO, BAR): + for i: uint256 in range(FOO, BAR): pass """, TypeMismatch, - "Iterator values are of different types", - "range(FOO, BAR)", + "Given reference has type int128, expected uint256", + "FOO", ), ( """ @@ -228,12 +228,12 @@ def foo(): @external def foo(): - for i in range(10, bound=FOO): + for i: int128 in range(10, bound=FOO): pass """, StructureException, "Bound must be at least 1", - "-1", + "FOO", ), ] @@ -252,41 +252,41 @@ def test_range_fail(bad_code, error_type, message, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) assert message == exc_info.value.message - assert source_code == exc_info.value.args[1].node_source_code + assert source_code == exc_info.value.args[1].get_original_node().node_source_code valid_list = [ """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass """, """ @external def foo(): - for i in range(10, 20): + for i: uint256 in range(10, 20): pass """, """ @external def foo(): x: int128 = 5 - for i in range(1, x, bound=4): + for i: int128 in range(1, x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(x, bound=4): + for i: int128 in range(x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(0, x, bound=4): + for i: int128 in range(0, x, bound=4): pass """, """ @@ -295,7 +295,7 @@ def kick(): nonpayable foos: Foo[3] @external def kick_foos(): - for foo in self.foos: + for foo: Foo in self.foos: foo.kick() """, ] diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index db41de5526..3936f8c220 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -306,7 +306,7 @@ def foo(): @external def foo(): x: DynArray[uint256, 3] = [1, 2, 3] - for i in [[], []]: + for i: DynArray[uint256, 3] in [[], []]: x = i """, ] diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index d413340083..a6bc3147e6 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -24,7 +24,7 @@ def foo(): """ @external def foo(): - for i in [0x6b175474e89094c44da98b954eedeac495271d0F]: + for i: address in [0x6b175474e89094c44da98b954eedeac495271d0F]: pass """, """ diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 16ce6fe631..b202f6d8a3 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,10 +28,10 @@ def foo() -> int128: def get_contract_info(source_code): - _, class_types, reformatted_code = pre_parse(source_code) + _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) - annotate_python_ast(py_ast, reformatted_code, class_types) + annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) return py_ast, reformatted_code diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 682c13ca84..020e83627c 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -173,7 +173,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _ = pre_parse(code) + settings, _, _, _ = pre_parse(code) assert settings == pre_parse_settings diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 44b823757c..b2851e908a 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -58,7 +58,7 @@ def ctor_only(): @internal def runtime_only(): - for i in range(10): + for i: uint256 in range(10): self.s += 1 @external diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index c9a152b09c..5b478dd2aa 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -6,7 +6,7 @@ @internal def _baz(a: int128) -> int128: b: int128 = a - for i in range(2, 5): + for i: int128 in range(2, 5): b *= i if b > 31337: break diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index e2c0f555af..607587cc28 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -22,7 +22,7 @@ def foo(): @internal def bar(): self.foo() - for i in self.a: + for i: uint256 in self.a: pass """ vyper_module = parse_to_ast(code) @@ -42,7 +42,7 @@ def foo(a: uint256[3]) -> uint256[3]: @internal def bar(): a: uint256[3] = [1,2,3] - for i in a: + for i: uint256 in a: self.foo(a) """ vyper_module = parse_to_ast(code) @@ -56,7 +56,7 @@ def test_modify_iterator(dummy_input_bundle): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.a[0] = 1 """ vyper_module = parse_to_ast(code) @@ -70,7 +70,7 @@ def test_bad_keywords(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, boundddd=10): + for i: uint256 in range(n, boundddd=10): x += i """ vyper_module = parse_to_ast(code) @@ -84,7 +84,7 @@ def test_bad_bound(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, bound=n): + for i: uint256 in range(n, bound=n): x += i """ vyper_module = parse_to_ast(code) @@ -103,7 +103,7 @@ def foo(): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.foo() """ vyper_module = parse_to_ast(code) @@ -126,7 +126,7 @@ def bar(): @internal def baz(): - for i in self.a: + for i: uint256 in self.a: self.bar() """ vyper_module = parse_to_ast(code) @@ -138,32 +138,32 @@ def baz(): """ @external def main(): - for j in range(3): + for j: uint256 in range(3): x: uint256 = j y: uint16 = j """, # GH issue 3212 """ @external def foo(): - for i in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1]: - for j in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + for j: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1,2,3]: - for j in [1,2,3]: - b:uint256 = j + i - c:uint16 = i + for i: uint256 in [1,2,3]: + for j: uint256 in [1,2,3]: + b: uint256 = j + i + c: uint16 = i """, # GH issue 3374 ] diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 7889473b19..234e96e552 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -178,8 +178,7 @@ body: _NEWLINE _INDENT ([COMMENT] _NEWLINE | _stmt)+ _DEDENT cond_exec: _expr ":" body default_exec: body if_stmt: "if" cond_exec ("elif" cond_exec)* ["else" ":" default_exec] -// TODO: make this into a variable definition e.g. `for i: uint256 in range(0, 5): ...` -loop_variable: NAME [":" NAME] +loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index efab5117d4..7a8c7443b7 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,15 @@ ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code -NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata") +NODE_BASE_ATTRIBUTES = ( + "_children", + "_depth", + "_parent", + "ast_type", + "node_id", + "_metadata", + "_original_node", +) NODE_SRC_ATTRIBUTES = ( "col_offset", "end_col_offset", @@ -257,6 +265,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): self.set_parent(parent) self._children: set = set() self._metadata: NodeMetadata = NodeMetadata() + self._original_node = None for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset @@ -411,12 +420,16 @@ def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once assert "folded_value" not in self._metadata - # set the folded node's parent so that get_ancestor works - # this is mainly important for error messages. - node._parent = self._parent + # set the "original node" so that exceptions can point to the original + # node and not the folded node + node = copy.copy(node) + node._original_node = self self._metadata["folded_value"] = node + def get_original_node(self) -> "VyperNode": + return self._original_node or self + def _try_fold(self) -> "VyperNode": """ Attempt to constant-fold the content of a node, returning the result of @@ -1546,7 +1559,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "target", "body") + __slots__ = ("target", "iter", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 38a9d31695..b657cf2245 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -54,7 +54,7 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) except SyntaxError as e: @@ -73,11 +73,15 @@ def parse_to_ast_with_settings( py_ast, source_code, class_types, + for_loop_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) + # postcondition: consumed all the for loop annotations + assert len(for_loop_annotations) == 0 + # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint @@ -113,11 +117,13 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets + _loop_var_annotations: dict[int, dict[str, Any]] def __init__( self, source_code: str, - modification_offsets: Optional[ModificationOffsets], + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -127,11 +133,11 @@ def __init__( self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path - self._source_code: str = source_code + self._source_code = source_code + self._modification_offsets = modification_offsets + self._for_loop_annotations = for_loop_annotations + self.counter: int = 0 - self._modification_offsets = {} - if modification_offsets is not None: - self._modification_offsets = modification_offsets def generic_visit(self, node): """ @@ -213,6 +219,47 @@ def visit_ClassDef(self, node): node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] return node + def visit_For(self, node): + """ + Visit a For node, splicing in the loop variable annotation provided by + the pre-parser + """ + raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + + if not raw_annotation: + # a common case for people migrating to 0.4.0, provide a more + # specific error message than "invalid type annotation" + raise SyntaxException( + "missing type annotation\n\n" + "(hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)\n", + self._source_code, + node.lineno, + node.col_offset, + ) + + try: + annotation = python_ast.parse(raw_annotation, mode="eval") + # annotate with token and source code information. `first_token` + # and `last_token` attributes are accessed in `generic_visit`. + tokens = asttokens.ASTTokens(raw_annotation) + tokens.mark_tokens(annotation) + except SyntaxError as e: + raise SyntaxException( + "invalid type annotation", self._source_code, node.lineno, node.col_offset + ) from e + + assert isinstance(annotation, python_ast.Expression) + annotation = annotation.body + + old_target = node.target + new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) + node.target = new_target + + self.generic_visit(node) + + return node + def visit_Expr(self, node): """ Convert the `Yield` node into a Vyper-specific node type. @@ -355,7 +402,8 @@ def visit_UnaryOp(self, node): def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, - modification_offsets: Optional[ModificationOffsets] = None, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -369,6 +417,9 @@ def annotate_python_ast( The AST to be annotated and optimized. source_code : str The originating source code of the AST. + loop_var_annotations: dict, optional + A mapping of line numbers of `For` nodes to the type annotation of the iterator + extracted during pre-parsing. modification_offsets : dict, optional A mapping of class names to their original class types. @@ -381,6 +432,7 @@ def annotate_python_ast( visitor = AnnotatingVisitor( source_code, modification_offsets, + for_loop_annotations, tokens, source_id, module_path=module_path, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index b949a242bb..c7e6f3698f 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,3 +1,4 @@ +import enum import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize @@ -43,6 +44,64 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: ) +class ForParserState(enum.Enum): + NOT_RUNNING = enum.auto() + START_SOON = enum.auto() + RUNNING = enum.auto() + + +# a simple state machine which allows us to handle loop variable annotations +# (which are rejected by the python parser due to pep-526, so we scoop up the +# tokens between `:` and `in` and parse them and add them back in later). +class ForParser: + def __init__(self, code): + self._code = code + self.annotations = {} + self._current_annotation = None + + self._state = ForParserState.NOT_RUNNING + self._current_for_loop = None + + def consume(self, token): + # state machine: we can start slurping tokens soon + if token.type == NAME and token.string == "for": + # note: self._state should be NOT_RUNNING here, but we don't sanity + # check here as that should be an error the parser will handle. + self._state = ForParserState.START_SOON + self._current_for_loop = token.start + + if self._state == ForParserState.NOT_RUNNING: + return False + + # state machine: start slurping tokens + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + + self._current_annotation = [] + return True # do not add ":" to tokens. + + # state machine: end slurping tokens + if token.type == NAME and token.string == "in": + self._state = ForParserState.NOT_RUNNING + self.annotations[self._current_for_loop] = self._current_annotation or [] + self._current_annotation = None + return False + + if self._state != ForParserState.RUNNING: + return False + + # slurp the token + self._current_annotation.append(token) + return True + + # compound statements that are replaced with `class` # TODO remove enum in favor of flag VYPER_CLASS_TYPES = {"flag", "enum", "event", "interface", "struct"} @@ -51,7 +110,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -60,9 +119,11 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"). + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. Parameters ---------- @@ -71,21 +132,25 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Returns ------- - dict - Mapping of offsets where source was modified. + Settings + Compilation settings based on the directives in the source code + ModificationOffsets + A mapping of class names to their original class types. + dict[tuple[int, int], str] + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() + for_parser = ForParser(code) try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) - for i in range(len(token_list)): - token = token_list[i] + for token in token_list: toks = [token] typ = token.type @@ -146,8 +211,18 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - result.extend(toks) + + if not for_parser.consume(token): + result.extend(toks) + except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - return settings, modification_offsets, untokenize(result).decode("utf-8") + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): + v_source = untokenize(v) + # untokenize adds backslashes and whitespace, strip them. + v_source = v_source.replace("\\", "").strip() + for_loop_annotations[k] = v_source + + return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index c896fc7ef6..39d97c4abe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2157,7 +2157,7 @@ def build_IR(self, expr, args, kwargs, context): z = x / 2.0 + 0.5 y: decimal = x - for i in range(256): + for i: uint256 in range(256): if z == y: break y = z diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..a47faefeb1 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -33,7 +33,7 @@ ) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.shortcuts import INT256_T, UINT256_T +from vyper.semantics.types.shortcuts import UINT256_T class Stmt: @@ -231,19 +231,17 @@ def parse_For(self): return self._parse_For_list() def _parse_For_range(self): - # TODO make sure type always gets annotated - if "type" in self.stmt.target._metadata: - iter_typ = self.stmt.target._metadata["type"] - else: - iter_typ = INT256_T + assert "type" in self.stmt.target.target._metadata + target_type = self.stmt.target.target._metadata["type"] # Get arg0 - for_iter: vy_ast.Call = self.stmt.iter - args_len = len(for_iter.args) + range_call: vy_ast.Call = self.stmt.iter + assert isinstance(range_call, vy_ast.Call) + args_len = len(range_call.args) if args_len == 1: - arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + arg0, arg1 = (IRnode.from_list(0, typ=target_type), range_call.args[0]) elif args_len == 2: - arg0, arg1 = for_iter.args + arg0, arg1 = range_call.args else: # pragma: nocover raise TypeCheckFailure("unreachable: bad # of arguments to range()") @@ -251,7 +249,7 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) end = Expr.parse_value_expr(arg1, self.context) kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + s.arg: Expr.parse_value_expr(s.value, self.context) for s in range_call.keywords } if "bound" in kwargs: @@ -270,9 +268,9 @@ def _parse_For_range(self): if rounds_bound < 1: # pragma: nocover raise TypeCheckFailure("unreachable: unchecked 0 bound") - varname = self.stmt.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) - iptr = self.context.new_variable(varname, iter_typ) + varname = self.stmt.target.target.id + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) + iptr = self.context.new_variable(varname, target_type) self.context.forvars[varname] = True @@ -297,11 +295,11 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - target_type = self.stmt.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] assert target_type == iter_list.typ.value_type # user-supplied name for loop variable - varname = self.stmt.target.id + varname = self.stmt.target.target.id loop_var = IRnode.from_list( self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY ) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..51f3fea14c 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -92,6 +92,10 @@ def __str__(self): node = value[1] if isinstance(value, tuple) else value node_msg = "" + if isinstance(node, vy_ast.VyperNode): + # folded AST nodes contain pointers to the original source + node = node.get_original_node() + try: source_annotation = annotate_source_code( # add trailing space because EOF exceptions point one char beyond the length diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91fb2c21f0..169c71269d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,13 +1,11 @@ from typing import Optional from vyper import ast as vy_ast -from vyper.ast.metadata import NodeMetadata from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -40,7 +38,6 @@ EventT, FlagT, HashMapT, - IntegerT, SArrayT, StringT, StructT, @@ -347,8 +344,10 @@ def visit_Expr(self, node): self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): - if isinstance(node.iter, vy_ast.Subscript): - raise StructureException("Cannot iterate over a nested list", node.iter) + if not isinstance(node.target.target, vy_ast.Name): + raise StructureException("Invalid syntax for loop iterator", node.target.target) + + target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -356,7 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - type_list = _analyse_range_call(node.iter) + _validate_range_call(node.iter) else: # iteration over a variable or literal list @@ -364,14 +363,10 @@ def visit_For(self, node): if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) - type_list = [ - i.value_type - for i in get_possible_types_from_node(node.iter) - if isinstance(i, (DArrayT, SArrayT)) - ] - - if not type_list: - raise InvalidType("Not an iterable type", node.iter) + if not any( + isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) + ): + raise InvalidType("Not an iterable type", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): # check for references to the iterated value within the body of the loop @@ -415,65 +410,28 @@ def visit_For(self, node): call_node, ) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for loop iterator", node.target) + target_name = node.target.target.id + with self.namespace.enter_scope(): + self.namespace[target_name] = VarInfo( + target_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) - for_loop_exceptions = [] - iter_name = node.target.id - for possible_target_type in type_list: - # type check the for loop body using each possible type for iterator value + for stmt in node.body: + self.visit(stmt) - with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo( - possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT - ) + self.expr_visitor.visit(node.target.target, target_type) - try: - with NodeMetadata.enter_typechecker_speculation(): - for stmt in node.body: - self.visit(stmt) - - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - - except (TypeMismatch, InvalidOperation) as exc: - for_loop_exceptions.append(exc) - else: - # success -- do not enter error handling section - return - - # failed to find a good type. bail out - if len(set(str(i) for i in for_loop_exceptions)) == 1: - # if every attempt at type checking raised the same exception - raise for_loop_exceptions[0] - - # return an aggregate TypeMismatch that shows all possible exceptions - # depending on which type is used - types_str = [str(i) for i in type_list] - given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - raise TypeMismatch( - f"Iterator value '{iter_name}' may be cast as {given_str}, " - "but type checking fails with all possible types:", - node, - *( - (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - for typ, exc in zip(type_list, for_loop_exceptions) - ), - ) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) + elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + args = node.iter.args + kwargs = [s.value for s in node.iter.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + else: + iter_type = get_exact_type_from_node(node.iter) + self.expr_visitor.visit(node.iter, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -750,25 +708,18 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: +def _validate_range_call(node: vy_ast.Call): """ Check that the arguments to a range() call are valid. :param node: call to range() :return: None """ + assert node.func.get("id") == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] - all_args = (start, end, *kwargs.values()) - for arg1 in all_args: - validate_expected_type(arg1, IntegerT.any()) - - type_list = get_common_types(*all_args) - if not type_list: - raise TypeMismatch("Iterator values are of different types", node) - if "bound" in kwargs: bound = kwargs["bound"] if bound.has_folded_value: @@ -787,5 +738,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: raise StateAccessViolation(error, arg) if end.value <= start.value: raise StructureException("End must be greater than start", end) - - return type_list From a1fd228cb9936c3e4bbca6f3ee3fb4426ef45490 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Mon, 8 Jan 2024 18:12:59 +0200 Subject: [PATCH 154/201] feat: add `bb` and `bb_runtime` output options (#3700) add `bb` and `bb_runtime` output options for dumping venom output. disable this output format in tests for now since many vyper contracts still will not compile to venom. --- tests/conftest.py | 43 +++++++++++++------ .../unit/cli/vyper_json/test_compile_json.py | 11 +++-- vyper/compiler/__init__.py | 4 ++ vyper/compiler/output.py | 8 ++++ vyper/compiler/phases.py | 8 +--- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 51b4b4459a..e673f17b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,14 @@ def pytest_addoption(parser): parser.addoption("--enable-compiler-debug-mode", action="store_true") +@pytest.fixture(scope="module") +def output_formats(): + output_formats = compiler.OUTPUT_FORMATS.copy() + del output_formats["bb"] + del output_formats["bb_runtime"] + return output_formats + + @pytest.fixture(scope="module") def optimize(pytestconfig): flag = pytestconfig.getoption("optimize") @@ -281,7 +289,14 @@ def ir_compiler(ir, *args, **kwargs): def _get_contract( - w3, source_code, optimize, *args, override_opt_level=None, input_bundle=None, **kwargs + w3, + source_code, + optimize, + output_formats, + *args, + override_opt_level=None, + input_bundle=None, + **kwargs, ): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) @@ -289,7 +304,7 @@ def _get_contract( out = compiler.compile_code( source_code, # test that all output formats can get generated - output_formats=list(compiler.OUTPUT_FORMATS.keys()), + output_formats=output_formats, settings=settings, input_bundle=input_bundle, show_gas_estimates=True, # Enable gas estimates for testing @@ -309,17 +324,17 @@ def _get_contract( @pytest.fixture(scope="module") -def get_contract(w3, optimize): +def get_contract(w3, optimize, output_formats): def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return fn @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, optimize): +def get_contract_with_gas_estimation(tester, w3, optimize, output_formats): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, optimize, *args, **kwargs) + contract = _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -329,15 +344,15 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, optimize): +def get_contract_with_gas_estimation_for_constants(w3, optimize, output_formats): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return get_contract_with_gas_estimation_for_constants @pytest.fixture(scope="module") -def get_contract_module(optimize): +def get_contract_module(optimize, output_formats): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. @@ -350,18 +365,18 @@ def get_contract_module(optimize): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return get_contract_module -def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs): +def _deploy_blueprint_for(w3, source_code, optimize, output_formats, initcode_prefix=b"", **kwargs): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) settings.optimize = optimize out = compiler.compile_code( source_code, - output_formats=list(compiler.OUTPUT_FORMATS.keys()), + output_formats=output_formats, settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) @@ -394,9 +409,9 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize): +def deploy_blueprint_for(w3, optimize, output_formats): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs) + return _deploy_blueprint_for(w3, source_code, optimize, output_formats, *args, **kwargs) return deploy_blueprint_for diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index a50946ba21..c805e2b5b1 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -113,18 +113,23 @@ def test_keyerror_becomes_jsonerror(input_json): def test_compile_json(input_json, input_bundle): foo_input = input_bundle.load_file("contracts/foo.vy") + # remove bb and bb_runtime from output formats + # because they require venom (experimental) + output_formats = OUTPUT_FORMATS.copy() + del output_formats["bb"] + del output_formats["bb_runtime"] foo = compile_from_file_input( - foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + foo_input, output_formats=output_formats, input_bundle=input_bundle ) library_input = input_bundle.load_file("contracts/library.vy") library = compile_from_file_input( - library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + library_input, output_formats=output_formats, input_bundle=input_bundle ) bar_input = input_bundle.load_file("contracts/bar.vy") bar = compile_from_file_input( - bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + bar_input, output_formats=output_formats, input_bundle=input_bundle ) compile_code_results = { diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0f7d7a8014..9297f9e3c3 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -23,6 +23,8 @@ # requires ir_node "external_interface": output.build_external_interface_output, "interface": output.build_interface_output, + "bb": output.build_bb_output, + "bb_runtime": output.build_bb_runtime_output, "ir": output.build_ir_output, "ir_runtime": output.build_ir_runtime_output, "ir_dict": output.build_ir_dict_output, @@ -84,6 +86,8 @@ def compile_from_file_input( two arguments - the name of the contract, and the exception that was raised no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool + Use experimental codegen. Defaults to False Returns ------- diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 8ccf6abee1..5e11a20139 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -84,6 +84,14 @@ def build_interface_output(compiler_data: CompilerData) -> str: return out +def build_bb_output(compiler_data: CompilerData) -> IRnode: + return compiler_data.venom_functions[0] + + +def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: + return compiler_data.venom_functions[1] + + def build_ir_output(compiler_data: CompilerData) -> IRnode: if compiler_data.show_gas_estimates: IRnode.repr_show_gas = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 850adcfea3..ba6ccbda20 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -174,9 +174,7 @@ def global_ctx(self) -> ModuleT: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes( - self.global_ctx, self.settings.optimize, self.settings.experimental_codegen - ) + return generate_ir_nodes(self.global_ctx, self.settings.optimize) @property def ir_nodes(self) -> IRnode: @@ -272,9 +270,7 @@ def generate_annotated_ast( return vyper_module, symbol_tables -def generate_ir_nodes( - global_ctx: ModuleT, optimize: OptimizationLevel, experimental_codegen: bool -) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. From 06fa46a53ee2134951ee3cd9a8f46dcceb61f620 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 10 Jan 2024 18:23:45 -0500 Subject: [PATCH 155/201] refactor: constant folding (#3719) refactor constant folding into a visitor class, clean up a couple passes this moves responsibility for knowing how to fold a node off the individual AST node implementations and into the ConstantFolder visitor. by adding a dependency to get_namespace() it also makes constant folding more generic; soon we can rely on more things being in the global namespace at constant folding time. --- tests/functional/builtins/folding/test_abs.py | 7 +- .../builtins/folding/test_addmod_mulmod.py | 7 +- .../builtins/folding/test_bitwise.py | 16 +- .../builtins/folding/test_epsilon.py | 7 +- .../builtins/folding/test_floor_ceil.py | 7 +- .../folding/test_fold_as_wei_value.py | 10 +- .../builtins/folding/test_keccak_sha.py | 15 +- tests/functional/builtins/folding/test_len.py | 15 +- .../builtins/folding/test_min_max.py | 15 +- .../builtins/folding/test_powmod.py | 7 +- tests/functional/grammar/test_grammar.py | 4 +- tests/functional/syntax/test_bool.py | 2 +- .../unit/ast/nodes/test_fold_binop_decimal.py | 13 +- tests/unit/ast/nodes/test_fold_binop_int.py | 15 +- tests/unit/ast/nodes/test_fold_boolop.py | 6 +- tests/unit/ast/nodes/test_fold_compare.py | 12 +- tests/unit/ast/nodes/test_fold_subscript.py | 4 +- tests/unit/ast/nodes/test_fold_unaryop.py | 6 +- tests/utils.py | 9 + vyper/ast/nodes.py | 188 +---------- vyper/ast/nodes.pyi | 1 - vyper/builtins/functions.py | 7 +- vyper/exceptions.py | 2 +- vyper/semantics/analysis/local.py | 3 +- vyper/semantics/analysis/module.py | 9 +- vyper/semantics/analysis/pre_typecheck.py | 298 ++++++++++++------ vyper/semantics/analysis/utils.py | 2 + vyper/semantics/types/base.py | 2 +- vyper/semantics/types/module.py | 16 +- vyper/semantics/types/user.py | 11 + 30 files changed, 337 insertions(+), 379 deletions(-) diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index 68131678fa..c954380def 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -2,8 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType @@ -19,9 +18,9 @@ def foo(a: int256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"abs({a})") + vyper_ast = parse_and_fold(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a) == new_node.value == abs(a) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 1d789f1655..e6a9fc193f 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -2,8 +2,7 @@ from hypothesis import assume, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256 - 1) @@ -22,8 +21,8 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") + vyper_ast = parse_and_fold(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 53a6d333a0..c1ff7674bb 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -29,7 +29,7 @@ def foo(a: uint256, b: uint256) -> uint256: contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") + vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -48,10 +48,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case @@ -78,10 +77,9 @@ def foo(a: int256, b: uint256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. @@ -105,7 +103,7 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"~{value}") + vyper_ast = parse_and_fold(f"~{value}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 4f5e9434ec..7bc2afe757 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("typ_name", ["decimal"]) @@ -13,8 +12,8 @@ def foo() -> {typ_name}: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") + vyper_ast = parse_and_fold(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 04921e504e..9e63c7b099 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -4,8 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_decimals = st.decimals( min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10 @@ -28,8 +27,8 @@ def foo(a: decimal) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 4287615bab..01af646a16 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.builtins import functions as vy_fn from vyper.utils import SizeLimits @@ -30,9 +30,9 @@ def foo(a: decimal) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -49,8 +49,8 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index 8da420538f..3b5f99891f 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~' # NOQA: E501 @@ -20,9 +19,9 @@ def foo(a: String[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") + vyper_ast = parse_and_fold(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -39,9 +38,9 @@ def foo(a: Bytes[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -60,8 +59,8 @@ def foo(a: Bytes[100]) -> bytes32: value = f"0x{value.hex()}" - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index 967f906555..6d59751748 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("length", [0, 1, 32, 33, 64, 65, 1024]) @@ -15,9 +14,9 @@ def foo(a: String[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") + vyper_ast = parse_and_fold(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -33,9 +32,9 @@ def foo(a: Bytes[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") + vyper_ast = parse_and_fold(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value.encode()) == new_node.value @@ -51,8 +50,8 @@ def foo(a: Bytes[1024]) -> uint256: value = f"0x{'00' * length}" - vyper_ast = vy_ast.parse_to_ast(f"len({value})") + vyper_ast = parse_and_fold(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 36a611fa1b..752b64eb04 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.utils import SizeLimits st_decimals = st.decimals( @@ -29,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -48,9 +47,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -67,8 +66,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index a3c2567f58..ad1197e8e3 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256) @@ -19,8 +18,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") + vyper_ast = parse_and_fold(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 652102c376..351793b28e 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -4,7 +4,7 @@ import hypothesis import hypothesis.strategies as st import pytest -from hypothesis import assume, given +from hypothesis import HealthCheck, assume, given from hypothesis.extra.lark import LarkStrategy from vyper.ast import Module, parse_to_ast @@ -103,7 +103,7 @@ def has_no_docstrings(c): @pytest.mark.fuzzing @given(code=from_grammar().filter(lambda c: utf8_encodable(c))) -@hypothesis.settings(max_examples=500) +@hypothesis.settings(max_examples=500, suppress_health_check=[HealthCheck.too_slow]) def test_grammar_bruteforce(code): if utf8_encodable(code): _, _, _, reformatted_code = pre_parse(code + "\n") diff --git a/tests/functional/syntax/test_bool.py b/tests/functional/syntax/test_bool.py index 48ed37321a..5388a92b95 100644 --- a/tests/functional/syntax/test_bool.py +++ b/tests/functional/syntax/test_bool.py @@ -37,7 +37,7 @@ def foo(): def foo() -> bool: return (1 == 2) <= (1 == 1) """, - TypeMismatch, + InvalidOperation, ), """ @external diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index e426a11de9..a75d114f88 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -4,7 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException st_decimals = st.decimals( @@ -28,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -45,11 +45,8 @@ def foo(a: decimal, b: decimal) -> decimal: def test_binop_pow(): # raises because Vyper does not support decimal exponentiation - vyper_ast = vy_ast.parse_to_ast("3.1337 ** 4.2") - old_node = vyper_ast.body[0].value - with pytest.raises(TypeMismatch): - old_node.get_folded_value() + _ = parse_and_fold("3.1337 ** 4.2") @pytest.mark.fuzzing @@ -72,8 +69,8 @@ def foo({input_value}) -> decimal: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = -(2**127) <= expected < 2**127 diff --git a/tests/unit/ast/nodes/test_fold_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py index 904b36c167..d9340927fe 100644 --- a/tests/unit/ast/nodes/test_fold_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -2,7 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import ZeroDivisionException st_int32 = st.integers(min_value=-(2**32), max_value=2**32) @@ -24,9 +24,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -54,9 +54,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = new_node.value >= 0 except ZeroDivisionException: @@ -83,7 +83,7 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") + vyper_ast = parse_and_fold(f"{left} ** {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -112,9 +112,8 @@ def foo({input_value}) -> int128: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) - try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = True diff --git a/tests/unit/ast/nodes/test_fold_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py index 3c42da0d26..082e6f35c3 100644 --- a/tests/unit/ast/nodes/test_fold_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold variables = "abcdefghij" @@ -24,7 +24,7 @@ def foo({input_value}) -> bool: literal_op = f" {comparator} ".join(str(i) for i in values) - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -52,7 +52,7 @@ def foo({input_value}) -> bool: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, comparators)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index 2b7c0f09d7..aab8ac0b2d 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import UnfoldableNode @@ -19,7 +19,7 @@ def foo(a: int128, b: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -39,7 +39,7 @@ def foo(a: uint128, b: uint128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -63,7 +63,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") + vyper_ast = parse_and_fold(f"{left} in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -92,7 +92,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") + vyper_ast = parse_and_fold(f"{left} not in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -106,7 +106,7 @@ def bar(a: int128) -> bool: @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) def test_compare_type_mismatch(op): - vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") + vyper_ast = parse_and_fold(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py index 1884abf73b..3ed26d07b7 100644 --- a/tests/unit/ast/nodes/test_fold_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.fuzzing @@ -19,7 +19,7 @@ def foo(array: int128[10], idx: uint256) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") + vyper_ast = parse_and_fold(f"{array}[{idx}]") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py index ff48adfe71..af72f5f8b0 100644 --- a/tests/unit/ast/nodes/test_fold_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -1,6 +1,6 @@ import pytest -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.parametrize("bool_cond", [True, False]) @@ -12,7 +12,7 @@ def foo(a: bool) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") + vyper_ast = parse_and_fold(f"not {bool_cond}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -30,7 +30,7 @@ def foo(a: bool) -> bool: contract = get_contract(source) literal_op = f"{'not ' * count}{bool_cond}" - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/utils.py b/tests/utils.py index 0c89c39ff3..b8a6b493d8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,9 @@ import contextlib import os +from vyper import ast as vy_ast +from vyper.semantics.analysis.pre_typecheck import pre_typecheck + @contextlib.contextmanager def working_directory(directory): @@ -10,3 +13,9 @@ def working_directory(directory): yield finally: os.chdir(tmp) + + +def parse_and_fold(source_code): + ast = vy_ast.parse_to_ast(source_code) + pre_typecheck(ast) + return ast diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 7a8c7443b7..90365c63d5 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -400,21 +400,11 @@ def get_folded_value(self) -> "VyperNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. - - - The returned value is cached on `_metadata["folded_value"]`. - - For constant/literal nodes, the node should be directly returned - without caching to the metadata. """ - if self.is_literal_value: - return self - - if "folded_value" not in self._metadata: - res = self._try_fold() # possibly throws UnfoldableNode - self._set_folded_value(res) - - return self._metadata["folded_value"] + try: + return self._metadata["folded_value"] + except KeyError: + raise UnfoldableNode("not foldable", self) def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once @@ -422,7 +412,9 @@ def _set_folded_value(self, node: "VyperNode") -> None: # set the "original node" so that exceptions can point to the original # node and not the folded node - node = copy.copy(node) + cls = node.__class__ + # make a fresh copy so that the node metadata is fresh. + node = cls(**{i: getattr(node, i) for i in node.get_fields() if hasattr(node, i)}) node._original_node = self self._metadata["folded_value"] = node @@ -430,17 +422,6 @@ def _set_folded_value(self, node: "VyperNode") -> None: def get_original_node(self) -> "VyperNode": return self._original_node or self - def _try_fold(self) -> "VyperNode": - """ - Attempt to constant-fold the content of a node, returning the result of - constant-folding if possible. - - If a node cannot be folded, it should raise `UnfoldableNode`. This - base implementation acts as a catch-all to raise on any inherited - classes that do not implement the method. - """ - raise UnfoldableNode(f"{type(self)} cannot be folded") - def validate(self) -> None: """ Validate the content of a node. @@ -919,10 +900,6 @@ class List(ExprNode): def is_literal_value(self): return all(e.is_literal_value for e in self.elements) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class Tuple(ExprNode): __slots__ = ("elements",) @@ -936,10 +913,6 @@ def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class NameConstant(Constant): __slots__ = () @@ -960,10 +933,6 @@ class Dict(ExprNode): def is_literal_value(self): return all(v.is_literal_value for v in self.values) - def _try_fold(self) -> ExprNode: - values = [v.get_folded_value() for v in self.values] - return type(self).from_node(self, values=values) - class Name(ExprNode): __slots__ = ("id",) @@ -972,27 +941,6 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the unary operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - operand = self.operand.get_folded_value() - - if isinstance(self.op, Not) and not isinstance(operand, NameConstant): - raise UnfoldableNode("not a boolean!", self.operand) - if isinstance(self.op, USub) and not isinstance(operand, Num): - raise UnfoldableNode("not a number!", self.operand) - if isinstance(self.op, Invert) and not isinstance(operand, Int): - raise UnfoldableNode("not an int!", self.operand) - - value = self.op._op(operand.value) - return type(operand).from_node(self, value=value) - class Operator(VyperNode): pass @@ -1021,30 +969,6 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the arithmetic operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if type(left) is not type(right): - raise UnfoldableNode("invalid operation", self) - if not isinstance(left, Num): - raise UnfoldableNode("not a number!", self.left) - - # this validation is performed to prevent the compiler from hanging - # on very large shifts and improve the error message for negative - # values. - if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) - - value = self.op._op(left.value, right.value) - return type(left).from_node(self, value=value) - class Add(Operator): __slots__ = () @@ -1170,24 +1094,6 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the boolean operation. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - values = [v.get_folded_value() for v in self.values] - - if any(not isinstance(v, NameConstant) for v in values): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - values = [v.value for v in values] - value = self.op._op(values) - return NameConstant.from_node(self, value=value) - class And(Operator): __slots__ = () @@ -1225,40 +1131,6 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the comparison. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if not isinstance(left, Constant): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - # CMC 2022-08-04 we could probably remove these evaluation rules as they - # are taken care of in the IR optimizer now. - if isinstance(self.op, (In, NotIn)): - if not isinstance(right, List): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if next((i for i in right.elements if not isinstance(i, Constant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if len(set([type(i) for i in right.elements])) > 1: - raise UnfoldableNode("List contains multiple literal types") - value = self.op._op(left.value, [i.value for i in right.elements]) - return NameConstant.from_node(self, value=value) - - if not isinstance(left, type(right)): - raise UnfoldableNode("Cannot compare different literal types") - - if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)): - raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self) - - value = self.op._op(left.value, right.value) - return NameConstant.from_node(self, value=value) - class Eq(Operator): __slots__ = () @@ -1315,21 +1187,6 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") - # try checking if this is a builtin, which is foldable - def _try_fold(self): - if not isinstance(self.func, Name): - raise UnfoldableNode("not a builtin", self) - - # cursed import cycle! - from vyper.builtins.functions import DISPATCH_TABLE - - func_name = self.func.id - if func_name not in DISPATCH_TABLE: - raise UnfoldableNode("not a builtin", self) - - builtin_t = DISPATCH_TABLE[func_name] - return builtin_t._try_fold(self) - class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1342,37 +1199,6 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the subscript. - - This method reduces an indexed reference to a literal array into the value - within the array, e.g. `["foo", "bar"][1]` becomes `"bar"` - - Returns - ------- - ExprNode - Node representing the result of the evaluation. - """ - slice_ = self.slice.value.get_folded_value() - value = self.value.get_folded_value() - - if not isinstance(value, List): - raise UnfoldableNode("Subscript object is not a literal list") - - elements = value.elements - if len(set([type(i) for i in elements])) > 1: - raise UnfoldableNode("List contains multiple node types") - - if not isinstance(slice_, Int): - raise UnfoldableNode("invalid index type", slice_) - - idx = slice_.value - if idx < 0 or idx >= len(elements): - raise UnfoldableNode("invalid index value") - - return elements[idx] - class Index(VyperNode): __slots__ = ("value",) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 8bc4a4eb57..4a5bc0d001 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -31,7 +31,6 @@ class VyperNode: @classmethod def get_fields(cls: Any) -> set: ... def get_folded_value(self) -> VyperNode: ... - def _try_fold(self) -> VyperNode: ... def _set_folded_value(self, node: VyperNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 39d97c4abe..4f8101dfbe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -90,6 +90,7 @@ ceil32, fourbytes_to_int, keccak256, + method_id, method_id_int, vyper_warn, ) @@ -723,12 +724,12 @@ def _try_fold(self, node): raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) return_type = self.infer_kwarg_types(node)["output_type"].typedef - value = method_id_int(value.value) + value = method_id(value.value) if return_type.compare_type(BYTES4_T): - return vy_ast.Hex.from_node(node, value=hex(value)) + return vy_ast.Hex.from_node(node, value="0x" + value.hex()) else: - return vy_ast.Bytes.from_node(node, value=value.to_bytes(4, "big")) + return vy_ast.Bytes.from_node(node, value=value) def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 51f3fea14c..04667aaa59 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -373,7 +373,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): raise e from None except Exception as e: tb = e.__traceback__ - fallback_message = "unhandled exception" + fallback_message = f"unhandled exception {e}" if note: fallback_message += f", {note}" raise fallback_exception_type(fallback_message, node).with_traceback(tb) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 169c71269d..cc8ddaf98d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -510,8 +510,7 @@ def visit(self, node, typ): # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() - validate_expected_type(folded_node, typ) - folded_node._metadata["type"] = typ + self.visit(folded_node, typ) def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8e435f870f..4a7e33e848 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -26,11 +26,7 @@ from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions from vyper.semantics.analysis.pre_typecheck import pre_typecheck -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - validate_expected_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT @@ -315,12 +311,11 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) + ExprVisitor().visit(node.value, type_) # performs validate_expected_type if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) - validate_expected_type(node.value, type_) _validate_self_namespace() return _finalize() diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index a1302ce9c9..1c2a5392c3 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,94 +1,210 @@ from vyper import ast as vy_ast -from vyper.exceptions import UnfoldableNode - - -# try to fold a node, swallowing exceptions. this function is very similar to -# `VyperNode.get_folded_value()` but additionally checks in the constants -# table if the node is a `Name` node. -# -# CMC 2023-12-30 a potential refactor would be to move this function into -# `Name._try_fold` (which would require modifying the signature of _try_fold to -# take an optional constants table as parameter). this would remove the -# need to use this function in conjunction with `get_descendants` since -# `VyperNode._try_fold()` already recurses. it would also remove the need -# for `VyperNode._set_folded_value()`. -def _fold_with_constants(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): - if node.has_folded_value: - return - - if isinstance(node, vy_ast.Name): - # check if it's in constants table - var_name = node.id - - if var_name not in constants: - return - - res = constants[var_name] - node._set_folded_value(res) - return - - try: - # call get_folded_value for its side effects - node.get_folded_value() - except UnfoldableNode: - pass - - -def _get_constants(node: vy_ast.Module) -> dict: - constants: dict[str, vy_ast.VyperNode] = {} - const_var_decls = node.get_children(vy_ast.VariableDecl, {"is_constant": True}) - - while True: - n_processed = 0 - - for c in const_var_decls.copy(): - assert c.value is not None # guaranteed by VariableDecl.validate() - - for n in c.get_descendants(reverse=True): - _fold_with_constants(n, constants) - +from vyper.exceptions import InvalidLiteral, UnfoldableNode +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.namespace import get_namespace + + +def pre_typecheck(module_ast: vy_ast.Module): + ConstantFolder(module_ast).run() + + +class ConstantFolder(VyperNodeVisitorBase): + def __init__(self, module_ast): + self._constants = {} + self._module_ast = module_ast + + def run(self): + self._get_constants() + self.visit(self._module_ast) + + def _get_constants(self): + module = self._module_ast + const_var_decls = module.get_children(vy_ast.VariableDecl, {"is_constant": True}) + + while True: + n_processed = 0 + + for c in const_var_decls.copy(): + # visit the entire constant node in case its type annotation + # has unfolded constants in it. + self.visit(c) + + assert c.value is not None # guaranteed by VariableDecl.validate() + try: + val = c.value.get_folded_value() + except UnfoldableNode: + # not foldable, maybe it depends on other constants + # so try again later + continue + + # note that if a constant is redefined, its value will be + # overwritten, but it is okay because the error is handled + # downstream + name = c.target.id + self._constants[name] = val + + n_processed += 1 + const_var_decls.remove(c) + + if n_processed == 0: + # this condition means that there are some constant vardecls + # whose values are not foldable. this can happen for struct + # and interface constants for instance. these are valid constant + # declarations, but we just can't fold them at this stage. + break + + def visit(self, node): + if node.has_folded_value: + return node.get_folded_value() + + for c in node.get_children(): try: - val = c.value.get_folded_value() + self.visit(c) except UnfoldableNode: - # not foldable, maybe it depends on other constants - # so try again later - continue - - # note that if a constant is redefined, its value will be - # overwritten, but it is okay because the error is handled - # downstream - name = c.target.id - constants[name] = val - - n_processed += 1 - const_var_decls.remove(c) - - if n_processed == 0: - # this condition means that there are some constant vardecls - # whose values are not foldable. this can happen for struct - # and interface constants for instance. these are valid constant - # declarations, but we just can't fold them at this stage. - break - - return constants - - -# perform constant folding on a module AST -def pre_typecheck(node: vy_ast.Module) -> None: - """ - Perform pre-typechecking steps on a Module AST node. - At this point, this is limited to performing constant folding. - """ - constants = _get_constants(node) - - # note: use reverse to get descendants in leaf-first order - for n in node.get_descendants(reverse=True): - # try folding every single node. note this should be done before - # type checking because the typechecker requires literals or - # foldable nodes in type signatures and some other places (e.g. - # certain builtin kwargs). - # - # note we could limit to only folding nodes which are required - # during type checking, but it's easier to just fold everything - # and be done with it! - _fold_with_constants(n, constants) + # ignore bubbled up exceptions + pass + + try: + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + folded_value = visitor_fn(node) + node._set_folded_value(folded_value) + return folded_value + except UnfoldableNode: + # ignore bubbled up exceptions + pass + + return node + + def visit_Constant(self, node) -> vy_ast.ExprNode: + return node + + def visit_Name(self, node) -> vy_ast.ExprNode: + try: + return self._constants[node.id] + except KeyError: + raise UnfoldableNode("unknown name", node) + + def visit_UnaryOp(self, node): + operand = node.operand.get_folded_value() + + if isinstance(node.op, vy_ast.Not) and not isinstance(operand, vy_ast.NameConstant): + raise UnfoldableNode("not a boolean!", node.operand) + if isinstance(node.op, vy_ast.USub) and not isinstance(operand, vy_ast.Num): + raise UnfoldableNode("not a number!", node.operand) + if isinstance(node.op, vy_ast.Invert) and not isinstance(operand, vy_ast.Int): + raise UnfoldableNode("not an int!", node.operand) + + value = node.op._op(operand.value) + return type(operand).from_node(node, value=value) + + def visit_BinOp(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if type(left) is not type(right): + raise UnfoldableNode("invalid operation", node) + if not isinstance(left, vy_ast.Num): + raise UnfoldableNode("not a number!", node.left) + + # this validation is performed to prevent the compiler from hanging + # on very large shifts and improve the error message for negative + # values. + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256): + raise InvalidLiteral("Shift bits must be between 0 and 256", node.right) + + value = node.op._op(left.value, right.value) + return type(left).from_node(node, value=value) + + def visit_BoolOp(self, node): + values = [v.get_folded_value() for v in node.values] + + if any(not isinstance(v, vy_ast.NameConstant) for v in values): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + values = [v.value for v in values] + value = node.op._op(values) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_Compare(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if not isinstance(left, vy_ast.Constant): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + # CMC 2022-08-04 we could probably remove these evaluation rules as they + # are taken care of in the IR optimizer now. + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + if not isinstance(right, vy_ast.List): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if next((i for i in right.elements if not isinstance(i, vy_ast.Constant)), None): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if len(set([type(i) for i in right.elements])) > 1: + raise UnfoldableNode("List contains multiple literal types") + value = node.op._op(left.value, [i.value for i in right.elements]) + return vy_ast.NameConstant.from_node(node, value=value) + + if not isinstance(left, type(right)): + raise UnfoldableNode("Cannot compare different literal types") + + # this is maybe just handled in the type checker. + if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(left, vy_ast.Num): + raise UnfoldableNode( + f"Invalid literal types for {node.op.description} comparison", node + ) + + value = node.op._op(left.value, right.value) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_List(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Tuple(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Dict(self, node) -> vy_ast.ExprNode: + values = [v.get_folded_value() for v in node.values] + return type(node).from_node(node, values=values) + + def visit_Call(self, node) -> vy_ast.ExprNode: + if not isinstance(node.func, vy_ast.Name): + raise UnfoldableNode("not a builtin", node) + + namespace = get_namespace() + + func_name = node.func.id + if func_name not in namespace: + raise UnfoldableNode("unknown", node) + + varinfo = namespace[func_name] + if not isinstance(varinfo, VarInfo): + raise UnfoldableNode("unfoldable", node) + + typ = varinfo.typ + # TODO: rename to vyper_type.try_fold_call_expr + if not hasattr(typ, "_try_fold"): + raise UnfoldableNode("unfoldable", node) + return typ._try_fold(node) # type: ignore + + def visit_Subscript(self, node) -> vy_ast.ExprNode: + slice_ = node.slice.value.get_folded_value() + value = node.value.get_folded_value() + + if not isinstance(value, vy_ast.List): + raise UnfoldableNode("Subscript object is not a literal list") + + elements = value.elements + if len(set([type(i) for i in elements])) > 1: + raise UnfoldableNode("List contains multiple node types") + + if not isinstance(slice_, vy_ast.Int): + raise UnfoldableNode("invalid index type", slice_) + + idx = slice_.value + if idx < 0 or idx >= len(elements): + raise UnfoldableNode("invalid index value") + + return elements[idx] diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ba1b02b8d6..359b51b71e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -650,6 +650,8 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) + + # builtins call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) return call_type_modifiability >= modifiability diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 429ba807e1..14949f693f 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -19,7 +19,7 @@ # type of type `type_` class _GenericTypeAcceptor: def __repr__(self): - return repr(self.type_) + return f"GenericTypeAcceptor({self.type_})" def __init__(self, type_): self.type_ = type_ diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b0d7800011..8f1a5cc0dc 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -4,7 +4,12 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args -from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.exceptions import ( + InterfaceViolation, + NamespaceCollision, + StructureException, + UnfoldableNode, +) from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids from vyper.semantics.namespace import get_namespace @@ -53,6 +58,15 @@ def abi_type(self) -> ABIType: def __repr__(self): return f"interface {self._id}" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + arg = node.args[0].get_folded_value() + if not isinstance(arg, vy_ast.Hex): + raise UnfoldableNode("not an address", arg) + + return node + # when using the type itself (not an instance) in the call position def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": self._ctor_arg_types(node) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a4e782349d..8ef9aa8d4a 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -10,6 +10,7 @@ InvalidAttribute, NamespaceCollision, StructureException, + UnfoldableNode, UnknownAttribute, VariableDeclarationException, ) @@ -357,6 +358,16 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": def __repr__(self): return f"{self._id} declaration object" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + args = [arg.get_folded_value() for arg in node.args] + if not isinstance(args[0], vy_ast.Dict): + raise UnfoldableNode("not a dict") + + # it can't be reduced, but this lets upstream code know it's constant + return node + @property def size_in_bytes(self): return sum(i.size_in_bytes for i in self.member_types.values()) From a6dc432db2df1dfae5919b737b4dd1f55ace859b Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 12 Jan 2024 06:42:28 +0800 Subject: [PATCH 156/201] fix: improve diagnostics for invalid for loop annotation (#3721) improves diagnostic messages for invalid for loop annotations by fixing up the source location during `vyper/ast/parse.py`. propagates full AnnAssign node from pre_parse.py to get better location information --------- Co-authored-by: Charles Cooper --- .../features/iteration/test_for_in_list.py | 29 +++++++++ .../exceptions/test_syntax_exception.py | 12 ++++ tests/functional/syntax/test_for_range.py | 12 ++++ vyper/ast/parse.py | 62 ++++++++++++++----- vyper/ast/pre_parser.py | 41 ++++++------ 5 files changed, 120 insertions(+), 36 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 5c7b5c6b1b..7f5658e485 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -11,7 +11,9 @@ NamespaceCollision, StateAccessViolation, StructureException, + SyntaxException, TypeMismatch, + UnknownType, ) BASIC_FOR_LOOP_CODE = [ @@ -803,6 +805,33 @@ def test_for() -> int128: """, TypeMismatch, ), + ( + """ +@external +def foo(): + for i in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: $$$ in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: uint9 in [1, 2, 3]: + pass + """, + UnknownType, + ), ] BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] diff --git a/tests/functional/syntax/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py index 9ab9b6c677..53a9550a7d 100644 --- a/tests/functional/syntax/exceptions/test_syntax_exception.py +++ b/tests/functional/syntax/exceptions/test_syntax_exception.py @@ -86,6 +86,18 @@ def f(a:uint256,/): # test posonlyargs blocked def g(): self.f() """, + """ +@external +def foo(): + for i in range(0, 10): + pass + """, + """ +@external +def foo(): + for i: $$$ in range(0, 10): + pass + """, ] diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 66981a90de..e807e12d41 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -8,6 +8,7 @@ StateAccessViolation, StructureException, TypeMismatch, + UnknownType, ) fail_list = [ @@ -235,6 +236,17 @@ def foo(): "Bound must be at least 1", "FOO", ), + ( + """ +@external +def foo(): + for i: DynArra[uint256, 3] in [1, 2, 3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", + "DynArra", + ), ] for_code_regex = re.compile(r"for .+ in (.*):") diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index b657cf2245..b1b9a8d917 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -1,4 +1,5 @@ import ast as python_ast +import string import tokenize from decimal import Decimal from typing import Any, Dict, List, Optional, Union, cast @@ -150,7 +151,9 @@ def generic_visit(self, node): self.counter += 1 # Decorate every node with source end offsets - start = node.first_token.start if hasattr(node, "first_token") else (None, None) + start = (None, None) + if hasattr(node, "first_token"): + start = node.first_token.start end = (None, None) if hasattr(node, "last_token"): end = node.last_token.end @@ -224,9 +227,9 @@ def visit_For(self, node): Visit a For node, splicing in the loop variable annotation provided by the pre-parser """ - raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset)) - if not raw_annotation: + if not annotation_tokens: # a common case for people migrating to 0.4.0, provide a more # specific error message than "invalid type annotation" raise SyntaxException( @@ -238,25 +241,50 @@ def visit_For(self, node): node.col_offset, ) + self.generic_visit(node) + try: - annotation = python_ast.parse(raw_annotation, mode="eval") - # annotate with token and source code information. `first_token` - # and `last_token` attributes are accessed in `generic_visit`. - tokens = asttokens.ASTTokens(raw_annotation) - tokens.mark_tokens(annotation) + annotation_str = tokenize.untokenize(annotation_tokens).strip(string.whitespace + "\\") + annotation = python_ast.parse(annotation_str) except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset ) from e - assert isinstance(annotation, python_ast.Expression) - annotation = annotation.body - - old_target = node.target - new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) - node.target = new_target + annotation = annotation.body[0] + og_target = node.target + + # annotate with token and source code information. `first_token` + # and `last_token` attributes are accessed in `generic_visit`. + tokens = asttokens.ASTTokens(annotation_str) + tokens.mark_tokens(annotation) + + # decrease line offset by 1 because annotation is on the same line as `For` node + # but the spliced expression also starts at line 1 + adjustment = og_target.first_token.start[0] - 1, og_target.first_token.start[1] + + def _add_pair(x, y): + return x[0] + y[0], x[1] + y[1] + + for n in python_ast.walk(annotation): + # adjust all offsets + if hasattr(n, "first_token"): + n.first_token = n.first_token._replace( + start=_add_pair(n.first_token.start, adjustment), + end=_add_pair(n.first_token.end, adjustment), + startpos=n.first_token.startpos + og_target.first_token.startpos, + endpos=n.first_token.startpos + og_target.first_token.startpos, + ) + if hasattr(n, "last_token"): + n.last_token = n.last_token._replace( + start=_add_pair(n.last_token.start, adjustment), + end=_add_pair(n.last_token.end, adjustment), + startpos=n.last_token.startpos + og_target.first_token.startpos, + endpos=n.last_token.endpos + og_target.first_token.startpos, + ) - self.generic_visit(node) + node.target = annotation + node.target = self.generic_visit(node.target) return node @@ -418,8 +446,8 @@ def annotate_python_ast( source_code : str The originating source code of the AST. loop_var_annotations: dict, optional - A mapping of line numbers of `For` nodes to the type annotation of the iterator - extracted during pre-parsing. + A mapping of line numbers of `For` nodes to the tokens of the type annotation + of the iterator extracted during pre-parsing. modification_offsets : dict, optional A mapping of class names to their original class types. diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index c7e6f3698f..f7d2df208a 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -65,27 +65,32 @@ def __init__(self, code): def consume(self, token): # state machine: we can start slurping tokens soon if token.type == NAME and token.string == "for": - # note: self._state should be NOT_RUNNING here, but we don't sanity - # check here as that should be an error the parser will handle. + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + self._current_annotation = [] + + assert self._state == ForParserState.NOT_RUNNING self._state = ForParserState.START_SOON self._current_for_loop = token.start + return False if self._state == ForParserState.NOT_RUNNING: return False - # state machine: start slurping tokens - if token.type == OP and token.string == ":": - self._state = ForParserState.RUNNING + if self._state == ForParserState.START_SOON: + # state machine: start slurping tokens - # sanity check -- this should never really happen, but if it does, - # try to raise an exception which pinpoints the source. - if self._current_annotation is not None: - raise SyntaxException( - "for loop parse error", self._code, token.start[0], token.start[1] - ) + self._current_annotation.append(token) - self._current_annotation = [] - return True # do not add ":" to tokens. + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + return True # do not add ":" to global tokens. + + return False # add everything before ":" to tokens # state machine: end slurping tokens if token.type == NAME and token.string == "in": @@ -136,8 +141,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: Compilation settings based on the directives in the source code ModificationOffsets A mapping of class names to their original class types. - dict[tuple[int, int], str] - A mapping of line/column offsets of `For` nodes to the annotation of the for loop target + dict[tuple[int, int], list[TokenInfo]] + A mapping of line/column offsets of `For` nodes to the tokens of the annotation of the + for loop target str Reformatted python source string. """ @@ -220,9 +226,6 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: for_loop_annotations = {} for k, v in for_parser.annotations.items(): - v_source = untokenize(v) - # untokenize adds backslashes and whitespace, strip them. - v_source = v_source.replace("\\", "").strip() - for_loop_annotations[k] = v_source + for_loop_annotations[k] = v.copy() return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") From 07ab92f3e287a812a1197182fe6e8b1168016ce3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 12 Jan 2024 12:27:42 -0500 Subject: [PATCH 157/201] refactor: for loop target parsing (#3724) use a parser trick to get always-correct source locations into for-loop target annotations. this fixes an issue with parsing the for loop target (introduced in ddfce5273b3), where the target annotation is not decorated with the correct line/column offsets. a6dc432db2df partially fixed the issue, but does not decorate correctly when the annotation is split across multiple lines. this commit approaches the problem differently, by using `untokenize()` to produce a source code for the annotation during AST massaging which "happens" to have the annotation in exactly the same place as it appeared in the original source. small refactors: - move annotate_python_ast to be earlier in the file. it's somehow easier to navigate to it when it comes right before the definition of AnnotatingVisitor. - revert changes from a6dc432db2df to ast/pre_parser.py; go back to slurping just the annotation as opposed to the target+annotation together. --- tests/functional/syntax/test_for_range.py | 36 ++++- vyper/ast/parse.py | 154 ++++++++++------------ vyper/ast/pre_parser.py | 34 ++--- 3 files changed, 121 insertions(+), 103 deletions(-) diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index e807e12d41..a486d11738 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -247,9 +247,43 @@ def foo(): "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", "DynArra", ), + ( + # test for loop target broken into multiple lines + """ +@external +def foo(): + for i: \\ + \\ + \\ + \\ + \\ + \\ + uint9 in [1,2,3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "uint9", + ), + ( + # test an even more deranged example + """ +@external +def foo(): + for i: \\ + \\ + DynArray[\\ + uint9, 3\\ + ] in [1,2,3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "uint9", + ), ] -for_code_regex = re.compile(r"for .+ in (.*):") +for_code_regex = re.compile(r"for .+ in (.*):", re.DOTALL) fail_test_names = [ ( f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index b1b9a8d917..cc0a47824c 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -1,5 +1,4 @@ import ast as python_ast -import string import tokenize from decimal import Decimal from typing import Any, Dict, List, Optional, Union, cast @@ -115,6 +114,50 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') +def annotate_python_ast( + parsed_ast: python_ast.AST, + source_code: str, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, +) -> python_ast.AST: + """ + Annotate and optimize a Python AST in preparation conversion to a Vyper AST. + + Parameters + ---------- + parsed_ast : AST + The AST to be annotated and optimized. + source_code : str + The originating source code of the AST. + loop_var_annotations: dict + A mapping of line numbers of `For` nodes to the tokens of the type + annotation of the iterator extracted during pre-parsing. + modification_offsets : dict + A mapping of class names to their original class types. + + Returns + ------- + The annotated and optimized AST. + """ + + tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + for_loop_annotations, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + visitor.visit(parsed_ast) + + return parsed_ast + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets @@ -170,6 +213,7 @@ def generic_visit(self, node): if hasattr(node, "last_token"): start_pos = node.first_token.startpos end_pos = node.last_token.endpos + if node.last_token.type == 4: # ignore trailing newline once more end_pos -= 1 @@ -241,52 +285,42 @@ def visit_For(self, node): node.col_offset, ) - self.generic_visit(node) + # some kind of black magic. untokenize preserves the line and column + # offsets, giving us something like `\ + # \ + # \ + # uint8` + # that's not a valid python Expr because it is indented. + # but it's good because the code is indented to exactly the same + # offset as it did in the original source! + # (to best understand this, print out annotation_str and + # self._source_code and compare them side-by-side). + # + # what we do here is add in a dummy target which we will remove + # in a bit, but for now lets us keep the line/col offset, and + # *also* gives us a valid AST. it doesn't matter what the dummy + # target name is, since it gets removed in a few lines. + annotation_str = tokenize.untokenize(annotation_tokens) + annotation_str = "dummy_target:" + annotation_str try: - annotation_str = tokenize.untokenize(annotation_tokens).strip(string.whitespace + "\\") - annotation = python_ast.parse(annotation_str) + fake_node = python_ast.parse(annotation_str).body[0] except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset ) from e - annotation = annotation.body[0] - og_target = node.target - - # annotate with token and source code information. `first_token` - # and `last_token` attributes are accessed in `generic_visit`. - tokens = asttokens.ASTTokens(annotation_str) - tokens.mark_tokens(annotation) - - # decrease line offset by 1 because annotation is on the same line as `For` node - # but the spliced expression also starts at line 1 - adjustment = og_target.first_token.start[0] - 1, og_target.first_token.start[1] - - def _add_pair(x, y): - return x[0] + y[0], x[1] + y[1] - - for n in python_ast.walk(annotation): - # adjust all offsets - if hasattr(n, "first_token"): - n.first_token = n.first_token._replace( - start=_add_pair(n.first_token.start, adjustment), - end=_add_pair(n.first_token.end, adjustment), - startpos=n.first_token.startpos + og_target.first_token.startpos, - endpos=n.first_token.startpos + og_target.first_token.startpos, - ) - if hasattr(n, "last_token"): - n.last_token = n.last_token._replace( - start=_add_pair(n.last_token.start, adjustment), - end=_add_pair(n.last_token.end, adjustment), - startpos=n.last_token.startpos + og_target.first_token.startpos, - endpos=n.last_token.endpos + og_target.first_token.startpos, - ) + # fill in with asttokens info. note we can use `self._tokens` because + # it is indented to exactly the same position where it appeared + # in the original source! + self._tokens.mark_tokens(fake_node) - node.target = annotation - node.target = self.generic_visit(node.target) + # replace the dummy target name with the real target name. + fake_node.target = node.target + # replace the For node target with the new ann_assign + node.target = fake_node - return node + return self.generic_visit(node) def visit_Expr(self, node): """ @@ -425,47 +459,3 @@ def visit_UnaryOp(self, node): return node.operand else: return node - - -def annotate_python_ast( - parsed_ast: python_ast.AST, - source_code: str, - modification_offsets: ModificationOffsets, - for_loop_annotations: dict, - source_id: int = 0, - module_path: Optional[str] = None, - resolved_path: Optional[str] = None, -) -> python_ast.AST: - """ - Annotate and optimize a Python AST in preparation conversion to a Vyper AST. - - Parameters - ---------- - parsed_ast : AST - The AST to be annotated and optimized. - source_code : str - The originating source code of the AST. - loop_var_annotations: dict, optional - A mapping of line numbers of `For` nodes to the tokens of the type annotation - of the iterator extracted during pre-parsing. - modification_offsets : dict, optional - A mapping of class names to their original class types. - - Returns - ------- - The annotated and optimized AST. - """ - - tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor( - source_code, - modification_offsets, - for_loop_annotations, - tokens, - source_id, - module_path=module_path, - resolved_path=resolved_path, - ) - visitor.visit(parsed_ast) - - return parsed_ast diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index f7d2df208a..159dfc0ace 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -65,32 +65,27 @@ def __init__(self, code): def consume(self, token): # state machine: we can start slurping tokens soon if token.type == NAME and token.string == "for": - # sanity check -- this should never really happen, but if it does, - # try to raise an exception which pinpoints the source. - if self._current_annotation is not None: - raise SyntaxException( - "for loop parse error", self._code, token.start[0], token.start[1] - ) - self._current_annotation = [] - - assert self._state == ForParserState.NOT_RUNNING + # note: self._state should be NOT_RUNNING here, but we don't sanity + # check here as that should be an error the parser will handle. self._state = ForParserState.START_SOON self._current_for_loop = token.start - return False if self._state == ForParserState.NOT_RUNNING: return False - if self._state == ForParserState.START_SOON: - # state machine: start slurping tokens + # state machine: start slurping tokens + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING - self._current_annotation.append(token) - - if token.type == OP and token.string == ":": - self._state = ForParserState.RUNNING - return True # do not add ":" to global tokens. + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) - return False # add everything before ":" to tokens + self._current_annotation = [] + return True # do not add ":" to tokens. # state machine: end slurping tokens if token.type == NAME and token.string == "in": @@ -142,8 +137,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: ModificationOffsets A mapping of class names to their original class types. dict[tuple[int, int], list[TokenInfo]] - A mapping of line/column offsets of `For` nodes to the tokens of the annotation of the - for loop target + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ From 5c2177b622bd07df26e53b4d950b6cecfdf9d677 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 12 Jan 2024 15:57:32 -0500 Subject: [PATCH 158/201] fix: allow using interface defs from imported modules (#3725) - add interface defs to ModuleT's exposed `get_type_members()` - slight refactor of ModuleT to have a special helper instead of dispatching to `self.interface` --- .../codegen/modules/test_interface_imports.py | 31 +++++++++++++++++++ .../test_stateless_functions.py} | 0 tests/functional/syntax/test_interfaces.py | 2 +- vyper/ast/nodes.py | 4 +-- vyper/semantics/analysis/module.py | 5 +-- vyper/semantics/types/module.py | 18 ++++++++++- vyper/semantics/types/utils.py | 10 +++--- 7 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 tests/functional/codegen/modules/test_interface_imports.py rename tests/functional/codegen/{test_stateless_modules.py => modules/test_stateless_functions.py} (100%) diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py new file mode 100644 index 0000000000..084ad26e6b --- /dev/null +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -0,0 +1,31 @@ +def test_import_interface_types(make_input_bundle, get_contract): + ifaces = """ +interface IFoo: + def foo() -> uint256: nonpayable + """ + + foo_impl = """ +import ifaces + +implements: ifaces.IFoo + +@external +def foo() -> uint256: + return block.number + """ + + contract = """ +import ifaces + +@external +def test_foo(s: ifaces.IFoo) -> bool: + assert s.foo() == block.number + return True + """ + + input_bundle = make_input_bundle({"ifaces.vy": ifaces}) + + foo = get_contract(foo_impl, input_bundle=input_bundle) + c = get_contract(contract, input_bundle=input_bundle) + + assert c.test_foo(foo.address) is True diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/modules/test_stateless_functions.py similarity index 100% rename from tests/functional/codegen/test_stateless_modules.py rename to tests/functional/codegen/modules/test_stateless_functions.py diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index a672ed7b88..ca96adca91 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -90,7 +90,7 @@ def foo(): nonpayable """ implements: self.x """, - StructureException, + InvalidType, ), ( """ diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 90365c63d5..fa1fb63673 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1372,8 +1372,8 @@ class ImplementsDecl(Stmt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not isinstance(self.annotation, Name): - raise StructureException("not an identifier", self.annotation) + if not isinstance(self.annotation, (Name, Attribute)): + raise StructureException("invalid implements", self.annotation) class If(Stmt): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 4a7e33e848..2972ed2917 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -383,8 +383,9 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_InterfaceDef(node) - self.namespace[node.name] = obj + interface_t = InterfaceT.from_InterfaceDef(node) + node._metadata["interface_type"] = interface_t + self.namespace[node.name] = interface_t def visit_StructDef(self, node): struct_t = StructT.from_StructDef(node) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 8f1a5cc0dc..f2c3d74525 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -39,6 +39,7 @@ def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> No self._helper = VyperType(events | structs) self._id = _id + self._helper._id = _id self.functions = functions self.events = events self.structs = structs @@ -267,6 +268,8 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): + _attribute_in_annotation = True + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -276,7 +279,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): # compute the interface, note this has the side effect of checking # for function collisions - self._helper = self.interface + _ = self.interface + + self._helper = VyperType() + self._helper._id = self._id for f in self.function_defs: # note: this checks for collisions @@ -289,6 +295,12 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for s in self.struct_defs: # add the type of the struct so it can be used in call position self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + self._helper.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for i in self.interface_defs: + # add the type of the interface so it can be used in call position + self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore + self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -322,6 +334,10 @@ def event_defs(self): def struct_defs(self): return self._module.get_children(vy_ast.StructDef) + @property + def interface_defs(self): + return self._module.get_children(vy_ast.InterfaceDef) + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index eb96375404..c82eb73afc 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -127,14 +127,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: except UndeclaredDefinition: raise InvalidType(err_msg, node) from None - interface = module_or_interface if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo - interface = module_or_interface.module_t.interface + module_or_interface = module_or_interface.module_t - if not interface._attribute_in_annotation: + if not isinstance(module_or_interface, VyperType): raise InvalidType(err_msg, node) - type_t = interface.get_type_member(node.attr, node) + if not module_or_interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = module_or_interface.get_type_member(node.attr, node) # type: ignore assert isinstance(type_t, TYPE_T) # sanity check return type_t.typedef From 785f09d3e7a7da4c162148ad5b8f1c07f3d8ab36 Mon Sep 17 00:00:00 2001 From: 0x77 <42128352+0x0077@users.noreply.github.com> Date: Sat, 13 Jan 2024 22:57:00 +0800 Subject: [PATCH 159/201] docs: add Vyper Online Compiler to resources section (#3680) --- docs/resources.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/resources.rst b/docs/resources.rst index a3dfa480ed..7bb3c99df4 100644 --- a/docs/resources.rst +++ b/docs/resources.rst @@ -24,6 +24,7 @@ Frameworks and tooling - `🐍 snekmate – Vyper smart contract building blocks `_ - `Serpentor – A set of smart contracts tools for governance `_ - `Smart contract development frameworks and tools for Vyper on Ethreum.org `_ +- `Vyper Online Compiler - an online platform for compiling and deploying Vyper smart contracts `_ Security -------- From 3f013ecef227dbac2383c1dfefc56de5e2ba8a4a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 13 Jan 2024 11:48:30 -0500 Subject: [PATCH 160/201] feat: add support for constants in imported modules (#3726) - add a case to ConstantFolder, and move the constant folder slightly down in the pipeline to after imports have been resolved. - rename `pre_typecheck` to `constant_fold` (since that is all it does, and it doesn't strictly happen before typechecking anymore). --- .../codegen/modules/test_module_constants.py | 78 +++++++++++++++++++ tests/utils.py | 4 +- vyper/ast/nodes.py | 2 +- vyper/ast/nodes.pyi | 4 +- .../{pre_typecheck.py => constant_folding.py} | 31 +++++++- vyper/semantics/analysis/module.py | 7 +- 6 files changed, 116 insertions(+), 10 deletions(-) create mode 100644 tests/functional/codegen/modules/test_module_constants.py rename vyper/semantics/analysis/{pre_typecheck.py => constant_folding.py} (89%) diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py new file mode 100644 index 0000000000..aafbb69252 --- /dev/null +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -0,0 +1,78 @@ +def test_module_constant(make_input_bundle, get_contract): + mod1 = """ +X: constant(uint256) = 12345 + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 12345 + + +def test_nested_module_constant(make_input_bundle, get_contract): + # test nested module constants + # test at least 3 modules deep to test the `path.reverse()` gizmo + # in ConstantFolder.visit_Attribute() + mod1 = """ +X: constant(uint256) = 12345 + """ + mod2 = """ +import mod1 +X: constant(uint256) = 54321 + """ + mod3 = """ +import mod2 +X: constant(uint256) = 98765 + """ + + contract = """ +import mod1 +import mod2 +import mod3 + +@external +def test_foo() -> bool: + assert mod1.X == 12345 + assert mod2.X == 54321 + assert mod3.X == 98765 + assert mod2.mod1.X == mod1.X + assert mod3.mod2.mod1.X == mod1.X + assert mod3.mod2.X == mod2.X + return True + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1, "mod2.vy": mod2, "mod3.vy": mod3}) + + c = get_contract(contract, input_bundle=input_bundle) + assert c.test_foo() is True + + +def test_import_constant_array(make_input_bundle, get_contract, tx_failed): + mod1 = """ +X: constant(uint256[3]) = [1,2,3] + """ + contract = """ +import mod1 + +@external +def foo(ix: uint256) -> uint256: + return mod1.X[ix] + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo(0) == 1 + assert c.foo(1) == 2 + assert c.foo(2) == 3 + with tx_failed(): + c.foo(3) diff --git a/tests/utils.py b/tests/utils.py index b8a6b493d8..25dad818ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,7 @@ import os from vyper import ast as vy_ast -from vyper.semantics.analysis.pre_typecheck import pre_typecheck +from vyper.semantics.analysis.constant_folding import constant_fold @contextlib.contextmanager @@ -17,5 +17,5 @@ def working_directory(directory): def parse_and_fold(source_code): ast = vy_ast.parse_to_ast(source_code) - pre_typecheck(ast) + constant_fold(ast) return ast diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index fa1fb63673..df419daa25 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -396,7 +396,7 @@ def has_folded_value(self): """ return "folded_value" in self._metadata - def get_folded_value(self) -> "VyperNode": + def get_folded_value(self) -> "ExprNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 4a5bc0d001..7f8c902d45 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -30,8 +30,8 @@ class VyperNode: def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def get_folded_value(self) -> VyperNode: ... - def _set_folded_value(self, node: VyperNode) -> None: ... + def get_folded_value(self) -> ExprNode: ... + def _set_folded_value(self, node: ExprNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/constant_folding.py similarity index 89% rename from vyper/semantics/analysis/pre_typecheck.py rename to vyper/semantics/analysis/constant_folding.py index 1c2a5392c3..b165a6dae9 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -1,11 +1,11 @@ from vyper import ast as vy_ast -from vyper.exceptions import InvalidLiteral, UnfoldableNode +from vyper.exceptions import InvalidLiteral, UnfoldableNode, VyperException from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.namespace import get_namespace -def pre_typecheck(module_ast: vy_ast.Module): +def constant_fold(module_ast: vy_ast.Module): ConstantFolder(module_ast).run() @@ -89,6 +89,33 @@ def visit_Name(self, node) -> vy_ast.ExprNode: except KeyError: raise UnfoldableNode("unknown name", node) + def visit_Attribute(self, node) -> vy_ast.ExprNode: + namespace = get_namespace() + path = [] + value = node.value + while isinstance(value, vy_ast.Attribute): + path.append(value.attr) + value = value.value + + path.reverse() + + if not isinstance(value, vy_ast.Name): + raise UnfoldableNode("not a module", value) + + # not super type-safe but we don't care. just catch AttributeErrors + # and move on + try: + module_t = namespace[value.id].module_t + + for module_name in path: + module_t = module_t.members[module_name].module_t + + varinfo = module_t.get_member(node.attr, node) + + return varinfo.decl_node.value.get_folded_value() + except (VyperException, AttributeError): + raise UnfoldableNode("not a module") + def visit_UnaryOp(self, node): operand = node.operand.get_folded_value() diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 2972ed2917..100819526b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -23,9 +23,9 @@ ) from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace @@ -51,8 +51,6 @@ def validate_semantics_r( """ validate_literal_nodes(module_ast) - pre_typecheck(module_ast) - # validate semantics and annotate AST with type/semantics information namespace = get_namespace() @@ -140,6 +138,9 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + # we can resolve constants after imports are handled. + constant_fold(self.ast) + # keep trying to process all the nodes until we finish or can # no longer progress. this makes it so we don't need to # calculate a dependency tree between top-level items. From 9cf66c9dd12c9a020c6945c100cc1266be262ebe Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 13 Jan 2024 12:14:33 -0500 Subject: [PATCH 161/201] chore: bump sphinx version (#3728) updated requirement by readthedocs --- requirements-docs.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-docs.txt b/requirements-docs.txt index d33eae62af..157d7bcab5 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,3 @@ -sphinx==4.5.0 +sphinx==5.0.0 recommonmark==0.6.0 sphinx_rtd_theme==0.5.2 From af5c49fad8d57d11c9f3cae863e313fdf2ae6db9 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 15 Jan 2024 02:52:39 +0800 Subject: [PATCH 162/201] refactor: remove duplicate terminus checking code (#3541) remove `check_single_exit` and `is_return_from_function` which duplicate functionality in `is_terminus_node`/`check_for_terminus`. additionally rewrite termination checking routine to be simpler, and also fix an outstanding analysis bug where the following program would not be rejected: ```vyper @external def foo(a: bool) -> uint256: if a: return 1 else: return 2 pass # unreachable ``` --------- Co-authored-by: Charles Cooper --- .../codegen/features/test_assert.py | 8 --- .../codegen/features/test_conditionals.py | 1 - .../syntax/test_unbalanced_return.py | 43 +++++++++++---- vyper/ast/nodes.py | 38 ++++++++++++-- vyper/builtins/_signatures.py | 1 + vyper/codegen/core.py | 40 +------------- vyper/codegen/function_definitions/common.py | 5 -- vyper/codegen/stmt.py | 3 +- vyper/semantics/analysis/local.py | 52 ++++++++++--------- 9 files changed, 98 insertions(+), 93 deletions(-) diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index df379d3f16..de9dd17ef6 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -107,14 +107,6 @@ def test(): assert self.ret1() == 1 """, """ -@internal -def valid_address(sender: address) -> bool: - selfdestruct(sender) -@external -def test(): - assert self.valid_address(msg.sender) - """, - """ @external def test(): assert raw_call(msg.sender, b'', max_outsize=1, gas=10, value=1000*1000) == b'' diff --git a/tests/functional/codegen/features/test_conditionals.py b/tests/functional/codegen/features/test_conditionals.py index 15ccc40bdf..3b0e57eeca 100644 --- a/tests/functional/codegen/features/test_conditionals.py +++ b/tests/functional/codegen/features/test_conditionals.py @@ -7,7 +7,6 @@ def foo(i: bool) -> int128: else: assert 2 != 0 return 7 - return 11 """ c = get_contract_with_gas_estimation(conditional_return_code) diff --git a/tests/functional/syntax/test_unbalanced_return.py b/tests/functional/syntax/test_unbalanced_return.py index d1d9732777..d5754f0053 100644 --- a/tests/functional/syntax/test_unbalanced_return.py +++ b/tests/functional/syntax/test_unbalanced_return.py @@ -8,7 +8,7 @@ """ @external def foo() -> int128: - pass + pass # missing return """, FunctionDeclarationException, ), @@ -18,6 +18,7 @@ def foo() -> int128: def foo() -> int128: if False: return 123 + # missing return """, FunctionDeclarationException, ), @@ -27,19 +28,19 @@ def foo() -> int128: def test() -> int128: if 1 == 1 : return 1 - if True: + if True: # unreachable return 0 else: assert msg.sender != msg.sender """, - FunctionDeclarationException, + StructureException, ), ( """ @internal def valid_address(sender: address) -> bool: selfdestruct(sender) - return True + return True # unreachable """, StructureException, ), @@ -48,7 +49,7 @@ def valid_address(sender: address) -> bool: @internal def valid_address(sender: address) -> bool: selfdestruct(sender) - a: address = sender + a: address = sender # unreachable """, StructureException, ), @@ -58,7 +59,7 @@ def valid_address(sender: address) -> bool: def valid_address(sender: address) -> bool: if sender == empty(address): selfdestruct(sender) - _sender: address = sender + _sender: address = sender # unreachable else: return False """, @@ -69,7 +70,7 @@ def valid_address(sender: address) -> bool: @internal def foo() -> bool: raw_revert(b"vyper") - return True + return True # unreachable """, StructureException, ), @@ -78,7 +79,7 @@ def foo() -> bool: @internal def foo() -> bool: raw_revert(b"vyper") - x: uint256 = 3 + x: uint256 = 3 # unreachable """, StructureException, ), @@ -88,12 +89,35 @@ def foo() -> bool: def foo(x: uint256) -> bool: if x == 2: raw_revert(b"vyper") - a: uint256 = 3 + a: uint256 = 3 # unreachable else: return False """, StructureException, ), + ( + """ +@internal +def foo(): + return + return # unreachable + """, + StructureException, + ), + ( + """ +@internal +def foo() -> uint256: + if block.number % 2 == 0: + return 5 + elif block.number % 3 == 0: + return 6 + else: + return 10 + return 0 # unreachable + """, + StructureException, + ), ] @@ -154,7 +178,6 @@ def test() -> int128: else: x = keccak256(x) return 1 - return 1 """, """ @external diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index df419daa25..de15fb9075 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -237,8 +237,6 @@ class VyperNode: Field names that, if present, must be set to None or a `SyntaxException` is raised. This attribute is used to exclude syntax that is valid in Python but not in Vyper. - _is_terminus : bool, optional - If `True`, indicates that execution halts upon reaching this node. _translated_fields : Dict, optional Field names that are reassigned if encountered. Used to normalize fields across different Python versions. @@ -389,6 +387,13 @@ def is_literal_value(self): """ return False + @property + def is_terminus(self): + """ + Check if execution halts upon reaching this node. + """ + return False + @property def has_folded_value(self): """ @@ -711,12 +716,19 @@ class Stmt(VyperNode): class Return(Stmt): __slots__ = ("value",) - _is_terminus = True + + @property + def is_terminus(self): + return True class Expr(Stmt): __slots__ = ("value",) + @property + def is_terminus(self): + return self.value.is_terminus + class Log(Stmt): __slots__ = ("value",) @@ -1187,6 +1199,21 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") + @property + def is_terminus(self): + # cursed import cycle! + from vyper.builtins.functions import get_builtin_functions + + if not isinstance(self.func, Name): + return False + + funcname = self.func.id + builtin_t = get_builtin_functions().get(funcname) + if builtin_t is None: + return False + + return builtin_t._is_terminus + class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1322,7 +1349,10 @@ class AugAssign(Stmt): class Raise(Stmt): __slots__ = ("exc",) _only_empty_fields = ("cause",) - _is_terminus = True + + @property + def is_terminus(self): + return True class Assert(Stmt): diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index aac008ad1e..1a488f39e0 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,6 +85,7 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _is_terminus = False # helper function to deal with TYPE_DEFINITIONs def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c16de3c55a..c3215f8c16 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,12 +1,11 @@ import contextlib from typing import Generator -from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check -from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch +from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( AddressT, BoolT, @@ -1035,43 +1034,6 @@ def eval_seq(ir_node): return None -def is_return_from_function(node): - if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in ( - "raw_revert", - "selfdestruct", - ): - return True - if isinstance(node, (vy_ast.Return, vy_ast.Raise)): - return True - return False - - -# TODO this is almost certainly duplicated with check_terminus_node -# in vyper/semantics/analysis/local.py -def check_single_exit(fn_node): - _check_return_body(fn_node, fn_node.body) - for node in fn_node.get_descendants(vy_ast.If): - _check_return_body(node, node.body) - if node.orelse: - _check_return_body(node, node.orelse) - - -def _check_return_body(node, node_list): - return_count = len([n for n in node_list if is_return_from_function(n)]) - if return_count > 1: - raise StructureException( - "Too too many exit statements (return, raise or selfdestruct).", node - ) - # Check for invalid code after returns. - last_node_pos = len(node_list) - 1 - for idx, n in enumerate(node_list): - if is_return_from_function(n) and idx < last_node_pos: - # is not last statement in body. - raise StructureException( - "Exit statement with succeeding code (that will not execute).", node_list[idx + 1] - ) - - def mzero(dst, nbytes): # calldatacopy from past-the-end gives zero bytes. # cf. YP H.2 (ops section) with CALLDATACOPY spec. diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 454ba9c8cd..5877ff3d13 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -4,7 +4,6 @@ import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode @@ -115,10 +114,6 @@ def generate_ir_for_function( # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) - # Validate return statements. - # XXX: This should really be in semantics pass. - check_single_exit(code) - callees = func_t.called_functions # we start our function frame from the largest callee frame diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index a47faefeb1..7d4938f287 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -15,7 +15,6 @@ get_dyn_array_count, get_element_ptr, getpos, - is_return_from_function, make_byte_array_copier, make_setter, pop_dyn_array, @@ -404,7 +403,7 @@ def parse_stmt(stmt, context): def _is_terminated(code): last_stmt = code[-1] - if is_return_from_function(last_stmt): + if last_stmt.is_terminus: return True if isinstance(last_stmt, vy_ast.If): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cc8ddaf98d..c4af5b1e3a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -66,26 +66,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None: err_list.raise_if_not_empty() -def _is_terminus_node(node: vy_ast.VyperNode) -> bool: - if getattr(node, "_is_terminus", None): - return True - if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call): - func = get_exact_type_from_node(node.value.func) - if getattr(func, "_is_terminus", None): - return True - return False - - -def check_for_terminus(node_list: list) -> bool: - if next((i for i in node_list if _is_terminus_node(i)), None): - return True - for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]: - if not node.orelse or not check_for_terminus(node.orelse): - continue - if not check_for_terminus(node.body): - continue - return True - return False +# finds the terminus node for a list of nodes. +# raises an exception if any nodes are unreachable +def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: + ret = None + + for node in node_list: + if ret is not None: + raise StructureException("Unreachable code!", node) + if node.is_terminus: + ret = node + + if isinstance(node, vy_ast.If): + body_terminates = find_terminating_node(node.body) + + else_terminates = None + if node.orelse is not None: + else_terminates = find_terminating_node(node.orelse) + + if body_terminates is not None and else_terminates is not None: + ret = else_terminates + + return ret def _check_iterator_modification( @@ -201,11 +203,13 @@ def analyze(self): self.visit(node) if self.func.return_type: - if not check_for_terminus(self.fn_node.body): + if not find_terminating_node(self.fn_node.body): raise FunctionDeclarationException( - f"Missing or unmatched return statements in function '{self.fn_node.name}'", - self.fn_node, + f"Missing return statement in function '{self.fn_node.name}'", self.fn_node ) + else: + # call find_terminator for its unreachable code detection side effect + find_terminating_node(self.fn_node.body) # visit default args assert self.func.n_keyword_args == len(self.fn_node.args.defaults) @@ -468,7 +472,7 @@ def visit_Return(self, node): raise FunctionDeclarationException("Return statement is missing a value", node) return elif self.func.return_type is None: - raise FunctionDeclarationException("Function does not return any values", node) + raise FunctionDeclarationException("Function should not return any values", node) if isinstance(values, vy_ast.Tuple): values = values.elements From 88d9c220a7a2c12250aad44774458cc2bf9e418c Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 15 Jan 2024 23:32:00 +0800 Subject: [PATCH 163/201] feat: allow constant interfaces (#3718) This PR adds support for interfaces as constants. - also introduced a mild refactor of `check_modifiability()` --- .../test_default_parameters.py | 32 +++++++++++++++++++ .../codegen/storage_variables/test_getters.py | 5 +++ tests/functional/syntax/test_constants.py | 13 ++++++++ vyper/builtins/_signatures.py | 3 ++ vyper/codegen/expr.py | 6 ++-- vyper/semantics/analysis/utils.py | 10 ++---- vyper/semantics/types/base.py | 5 +++ vyper/semantics/types/module.py | 11 +++++-- vyper/semantics/types/user.py | 6 +++- 9 files changed, 78 insertions(+), 13 deletions(-) diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 462748a9c7..240ccb3bb1 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -111,6 +111,38 @@ def fooBar(a: Bytes[100], b: uint256[2], c: Bytes[6] = b"hello", d: int128[3] = assert c.fooBar(b"booo", [55, 66]) == [b"booo", 66, c_default, d_default] +def test_default_param_interface(get_contract): + code = """ +interface Foo: + def bar(): payable + +FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) + +@external +def bar(a: uint256, b: Foo = Foo(0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7)) -> Foo: + return b + +@external +def baz(a: uint256, b: Foo = Foo(empty(address))) -> Foo: + return b + +@external +def faz(a: uint256, b: Foo = FOO) -> Foo: + return b + """ + c = get_contract(code) + + addr1 = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" + addr2 = "0xF5D4020dCA6a62bB1efFcC9212AAF3c9819E30D7" + + assert c.bar(1) == addr2 + assert c.bar(1, addr1) == addr1 + assert c.baz(1) is None + assert c.baz(1, "0x0000000000000000000000000000000000000000") is None + assert c.faz(1) == addr1 + assert c.faz(1, addr1) == addr1 + + def test_default_param_internal_function(get_contract): code = """ @internal diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index 5eac074ef6..a2d9c6d0bb 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -19,6 +19,9 @@ def foo() -> int128: def test_getter_code(get_contract_with_gas_estimation_for_constants): getter_code = """ +interface V: + def foo(): nonpayable + struct W: a: uint256 b: int128[7] @@ -36,6 +39,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants): d: public(immutable(uint256)) e: public(immutable(uint256[2])) f: public(constant(uint256[2])) = [3, 7] +g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) @external def __init__(): @@ -70,6 +74,7 @@ def __init__(): assert c.d() == 1729 assert c.e(0) == 2 assert [c.f(i) for i in range(2)] == [3, 7] + assert c.g() == "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" def test_getter_mutability(get_contract): diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 7089dee3bb..04e778a00e 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -304,6 +304,19 @@ def deposit(deposit_input: Bytes[2048]): CONST_BAR: constant(Bar) = Bar({c: C, d: D}) """, + """ +interface Foo: + def foo(): nonpayable + +FOO: constant(Foo) = Foo(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) + """, + """ +interface Foo: + def foo(): nonpayable + +FOO: constant(Foo) = Foo(BAR) +BAR: constant(address) = 0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF + """, ] diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 1a488f39e0..d2aefb2fd4 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -129,6 +129,9 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: # ensures the type can be inferred exactly. get_exact_type_from_node(arg) + def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: + return self._modifiability >= modifiability + def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 577660b883..6a97e60ce2 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -190,9 +190,9 @@ def parse_Name(self): varinfo = self.context.globals[self.expr.id] if varinfo.is_constant: - # non-struct constants should have already gotten propagated - # during constant folding - assert isinstance(varinfo.typ, StructT) + # constants other than structs and interfaces should have already gotten + # propagated during constant folding + assert isinstance(varinfo.typ, (InterfaceT, StructT)) return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 359b51b71e..3e818fa246 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -645,15 +645,11 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> return all(check_modifiability(item, modifiability) for item in node.elements) if isinstance(node, vy_ast.Call): - args = node.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return all(check_modifiability(v, modifiability) for v in args[0].values) - call_type = get_exact_type_from_node(node.func) - # builtins - call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) - return call_type_modifiability >= modifiability + # structs and interfaces + if hasattr(call_type, "check_modifiability_for_call"): + return call_type.check_modifiability_for_call(node, modifiability) value_type = get_expr_info(node) return value_type.modifiability >= modifiability diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 14949f693f..b15eca8ab2 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -334,6 +334,11 @@ def __init__(self, typedef): def __repr__(self): return f"type({self.typedef})" + def check_modifiability_for_call(self, node, modifiability): + if hasattr(self.typedef, "_ctor_modifiability_for_call"): + return self.typedef._ctor_modifiability_for_call(node, modifiability) + raise StructureException("Value is not callable", node) + # dispatch into ctor if it's called def fetch_call_return(self, node): if hasattr(self.typedef, "_ctor_call_return"): diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index f2c3d74525..ee1da22a87 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -10,8 +10,12 @@ StructureException, UnfoldableNode, ) -from vyper.semantics.analysis.base import VarInfo -from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.analysis.base import Modifiability, VarInfo +from vyper.semantics.analysis.utils import ( + check_modifiability, + validate_expected_type, + validate_unique_method_ids, +) from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT @@ -81,6 +85,9 @@ def _ctor_arg_types(self, node): def _ctor_kwarg_types(self, node): return {} + def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: + return check_modifiability(node.args[0], modifiability) + # TODO x.validate_implements(other) def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: namespace = get_namespace() diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 8ef9aa8d4a..92a455e3d8 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -14,8 +14,9 @@ UnknownAttribute, VariableDeclarationException, ) +from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT @@ -419,3 +420,6 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": ) return self + + def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: + return all(check_modifiability(v, modifiability) for v in node.args[0].values) From 81c6d8ef8aea440932c51519ee2844a64da0cd90 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 15 Jan 2024 14:14:35 -0800 Subject: [PATCH 164/201] fix: unreachable code analysis inside for loops (#3731) unreachable code analysis did not analyze for loop bodies. fix: in `find_terminating_node()`, recurse into the bodies of for loops. --- .../syntax/test_unbalanced_return.py | 40 +++++++++++++++++++ vyper/ast/nodes.pyi | 6 ++- vyper/semantics/analysis/local.py | 5 +++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/tests/functional/syntax/test_unbalanced_return.py b/tests/functional/syntax/test_unbalanced_return.py index d5754f0053..04835bb0f0 100644 --- a/tests/functional/syntax/test_unbalanced_return.py +++ b/tests/functional/syntax/test_unbalanced_return.py @@ -118,6 +118,39 @@ def foo() -> uint256: """, StructureException, ), + ( + """ +@internal +def foo() -> uint256: + for i: uint256 in range(10): + if i == 11: + return 1 + """, + FunctionDeclarationException, + ), + ( + """ +@internal +def foo() -> uint256: + for i: uint256 in range(9): + if i == 11: + return 1 + if block.number % 2 == 0: + return 1 + """, + FunctionDeclarationException, + ), + ( + """ +@internal +def foo() -> uint256: + for i: uint256 in range(10): + return 1 + pass # unreachable + return 5 + """, + StructureException, + ), ] @@ -187,6 +220,13 @@ def foo() -> int128: else: raw_revert(b"vyper") """, + """ +@external +def foo() -> int128: + for i: uint256 in range(1): + return 1 + return 0 + """, ] diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7f8c902d45..896329c702 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -257,6 +257,10 @@ class IfExp(ExprNode): body: ExprNode = ... orelse: ExprNode = ... -class For(VyperNode): ... +class For(VyperNode): + target: ExprNode + iter: ExprNode + body: list[VyperNode] + class Break(VyperNode): ... class Continue(VyperNode): ... diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c4af5b1e3a..29a93a9eaf 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -74,6 +74,7 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: for node in node_list: if ret is not None: raise StructureException("Unreachable code!", node) + if node.is_terminus: ret = node @@ -87,6 +88,10 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: if body_terminates is not None and else_terminates is not None: ret = else_terminates + if isinstance(node, vy_ast.For): + # call find_terminating_node for its side effects + find_terminating_node(node.body) + return ret From c42b077c1d355e5c12aad903681108cd7164e4de Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Thu, 18 Jan 2024 00:09:34 +0200 Subject: [PATCH 165/201] fix[venom]: liveness analysis in some loops (#3732) fixes a small issue with the calculation of liveness when variables are consumed inside of loops by refactoring the liveness calculation algorithm to an iterative version over the recursive one additional QOL refactoring: - simplify convert_ir_basicblock - rename it to ir_node_to_venom - fix bb well-formedness for deploy ir - don't do deploy IR detection; rely on caller in CompilerData to call for both deploy and runtime IR - remove findIRnode, it's no longer needed - rename _convert_ir_basicblock to _convert_ir_bb - add _convert_ir_bb_list helper to handle arglists --------- Co-authored-by: Charles Cooper --- .../venom/test_convert_basicblock_simple.py | 41 ++++ vyper/compiler/phases.py | 5 +- vyper/venom/__init__.py | 10 +- vyper/venom/analysis.py | 36 +-- vyper/venom/basicblock.py | 4 +- vyper/venom/function.py | 2 +- vyper/venom/ir_node_to_venom.py | 217 ++++++++---------- 7 files changed, 162 insertions(+), 153 deletions(-) create mode 100644 tests/unit/compiler/venom/test_convert_basicblock_simple.py diff --git a/tests/unit/compiler/venom/test_convert_basicblock_simple.py b/tests/unit/compiler/venom/test_convert_basicblock_simple.py new file mode 100644 index 0000000000..fdaa341a81 --- /dev/null +++ b/tests/unit/compiler/venom/test_convert_basicblock_simple.py @@ -0,0 +1,41 @@ +from vyper.codegen.ir_node import IRnode +from vyper.venom.ir_node_to_venom import ir_node_to_venom + + +def test_simple(): + ir = IRnode.from_list(["calldatacopy", 32, 0, ["calldatasize"]]) + ir_node = IRnode.from_list(ir) + venom = ir_node_to_venom(ir_node) + assert venom is not None + + bb = venom.basic_blocks[0] + assert bb.instructions[0].opcode == "calldatasize" + assert bb.instructions[1].opcode == "calldatacopy" + + +def test_simple_2(): + ir = [ + "seq", + [ + "seq", + [ + "mstore", + ["add", 64, 0], + [ + "with", + "x", + ["calldataload", ["add", 4, 0]], + [ + "with", + "ans", + ["add", "x", 1], + ["seq", ["assert", ["ge", "ans", "x"]], "ans"], + ], + ], + ], + ], + 32, + ] + ir_node = IRnode.from_list(ir) + venom = ir_node_to_venom(ir_node) + assert venom is not None diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index ba6ccbda20..5b7decec7b 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -197,7 +197,10 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def venom_functions(self): - return generate_ir(self.ir_nodes, self.settings.optimize) + deploy_ir, runtime_ir = self._ir_output + deploy_venom = generate_ir(deploy_ir, self.settings.optimize) + runtime_venom = generate_ir(runtime_ir, self.settings.optimize) + return deploy_venom, runtime_venom @cached_property def assembly(self) -> list: diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 570aba771a..d1c2d0c342 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -12,7 +12,7 @@ ir_pass_remove_unreachable_blocks, ) from vyper.venom.function import IRFunction -from vyper.venom.ir_node_to_venom import convert_ir_basicblock +from vyper.venom.ir_node_to_venom import ir_node_to_venom from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation from vyper.venom.passes.dft import DFTPass from vyper.venom.venom_to_assembly import VenomCompiler @@ -61,11 +61,9 @@ def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: break -def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> tuple[IRFunction, IRFunction]: +def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> IRFunction: # Convert "old" IR to "new" IR - ctx, ctx_runtime = convert_ir_basicblock(ir) - + ctx = ir_node_to_venom(ir) _run_passes(ctx, optimize) - _run_passes(ctx_runtime, optimize) - return ctx, ctx_runtime + return ctx diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index eed579463e..daebd2560c 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -42,10 +42,12 @@ def _reset_liveness(ctx: IRFunction) -> None: inst.liveness = OrderedSet() -def _calculate_liveness_bb(bb: IRBasicBlock) -> None: +def _calculate_liveness(bb: IRBasicBlock) -> bool: """ Compute liveness of each instruction in the basic block. + Returns True if liveness changed """ + orig_liveness = bb.instructions[0].liveness.copy() liveness = bb.out_vars.copy() for instruction in reversed(bb.instructions): ops = instruction.get_inputs() @@ -60,29 +62,31 @@ def _calculate_liveness_bb(bb: IRBasicBlock) -> None: liveness.remove(out) instruction.liveness = liveness + return orig_liveness != bb.instructions[0].liveness -def _calculate_liveness_r(bb: IRBasicBlock, visited: dict) -> None: - assert isinstance(visited, dict) - for out_bb in bb.cfg_out: - if visited.get(bb) == out_bb: - continue - visited[bb] = out_bb - - # recurse - _calculate_liveness_r(out_bb, visited) +def _calculate_out_vars(bb: IRBasicBlock) -> bool: + """ + Compute out_vars of basic block. + Returns True if out_vars changed + """ + out_vars = bb.out_vars.copy() + for out_bb in bb.cfg_out: target_vars = input_vars_from(bb, out_bb) - - # the output stack layout for bb. it produces a stack layout - # which works for all possible cfg_outs from the bb. bb.out_vars = bb.out_vars.union(target_vars) - - _calculate_liveness_bb(bb) + return out_vars != bb.out_vars def calculate_liveness(ctx: IRFunction) -> None: _reset_liveness(ctx) - _calculate_liveness_r(ctx.basic_blocks[0], dict()) + while True: + changed = False + for bb in ctx.basic_blocks: + changed |= _calculate_out_vars(bb) + changed |= _calculate_liveness(bb) + + if not changed: + break # calculate the input variables into self from source diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 598b8af7d5..f86e9b330c 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -410,8 +410,8 @@ def copy(self): def __repr__(self) -> str: s = ( f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" - f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars} \n" + f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars}\n" ) for instruction in self.instructions: - s += f" {instruction}\n" + s += f" {str(instruction).strip()}\n" return s diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 9f26fa8ec0..771dcf73ce 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -163,4 +163,4 @@ def __repr__(self) -> str: str += "Data segment:\n" for inst in self.data_segment: str += f"{inst}\n" - return str + return str.strip() diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index c86d3a3d67..6b47ac2415 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -87,35 +87,21 @@ def _get_symbols_common(a: dict, b: dict) -> dict: return ret -def _findIRnode(ir: IRnode, value: str) -> Optional[IRnode]: - if ir.value == value: - return ir - for arg in ir.args: - if isinstance(arg, IRnode): - ret = _findIRnode(arg, value) - if ret is not None: - return ret - return None - - -def convert_ir_basicblock(ir: IRnode) -> tuple[IRFunction, IRFunction]: - deploy_ir = _findIRnode(ir, "deploy") - assert deploy_ir is not None - - deploy_venom = IRFunction() - _convert_ir_basicblock(deploy_venom, ir, {}, OrderedSet(), {}) - deploy_venom.get_basic_block().append_instruction("stop") - - runtime_ir = deploy_ir.args[1] - runtime_venom = IRFunction() - _convert_ir_basicblock(runtime_venom, runtime_ir, {}, OrderedSet(), {}) - - # Connect unterminated blocks to the next with a jump - for i, bb in enumerate(runtime_venom.basic_blocks): - if not bb.is_terminated and i < len(runtime_venom.basic_blocks) - 1: - bb.append_instruction("jmp", runtime_venom.basic_blocks[i + 1].label) +# convert IRnode directly to venom +def ir_node_to_venom(ir: IRnode) -> IRFunction: + ctx = IRFunction() + _convert_ir_bb(ctx, ir, {}, OrderedSet(), {}) + + # Patch up basic blocks. Connect unterminated blocks to the next with + # a jump. terminate final basic block with STOP. + for i, bb in enumerate(ctx.basic_blocks): + if not bb.is_terminated: + if i < len(ctx.basic_blocks) - 1: + bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) + else: + bb.append_instruction("stop") - return deploy_venom, runtime_venom + return ctx def _convert_binary_op( @@ -127,10 +113,10 @@ def _convert_binary_op( swap: bool = False, ) -> Optional[IRVariable]: ir_args = ir.args[::-1] if swap else ir.args - arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) + arg_0, arg_1 = _convert_ir_bb_list(ctx, ir_args, symbols, variables, allocated_variables) - return ctx.get_basic_block().append_instruction(str(ir.value), arg_1, arg_0) + assert isinstance(ir.value, str) # mypy hint + return ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: @@ -165,14 +151,12 @@ def _handle_self_call( if arg.is_literal: sym = symbols.get(f"&{arg.value}", None) if sym is None: - ret = _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ret = _convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) ret_args.append(ret) else: ret_args.append(sym) # type: ignore else: - ret = _convert_ir_basicblock( - ctx, arg._optimized, symbols, variables, allocated_variables - ) + ret = _convert_ir_bb(ctx, arg._optimized, symbols, variables, allocated_variables) if arg.location and arg.location.load_op == "calldataload": bb = ctx.get_basic_block() ret = bb.append_instruction(arg.location.load_op, ret) @@ -225,9 +209,7 @@ def _convert_ir_simple_node( variables: OrderedSet, allocated_variables: dict[str, IRVariable], ) -> Optional[IRVariable]: - args = [ - _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args - ] + args = [_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args] return ctx.get_basic_block().append_instruction(ir.value, *args) # type: ignore @@ -266,7 +248,16 @@ def _append_return_for_stack_operand( bb.append_instruction("return", last_ir, new_var) # type: ignore -def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): +def _convert_ir_bb_list(ctx, ir, symbols, variables, allocated_variables): + ret = [] + for ir_node in ir: + venom = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) + ret.append(venom) + return ret + + +def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): + assert isinstance(ir, IRnode) assert isinstance(variables, OrderedSet) global _break_target, _continue_target @@ -314,35 +305,22 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ret = None for ir_node in ir.args: # NOTE: skip the last one - ret = _convert_ir_basicblock(ctx, ir_node, symbols, variables, allocated_variables) + ret = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) return ret elif ir.value in ["staticcall", "call"]: # external call idx = 0 - gas = _convert_ir_basicblock(ctx, ir.args[idx], symbols, variables, allocated_variables) - address = _convert_ir_basicblock( - ctx, ir.args[idx + 1], symbols, variables, allocated_variables - ) + gas = _convert_ir_bb(ctx, ir.args[idx], symbols, variables, allocated_variables) + address = _convert_ir_bb(ctx, ir.args[idx + 1], symbols, variables, allocated_variables) value = None if ir.value == "call": - value = _convert_ir_basicblock( - ctx, ir.args[idx + 2], symbols, variables, allocated_variables - ) + value = _convert_ir_bb(ctx, ir.args[idx + 2], symbols, variables, allocated_variables) else: idx -= 1 - argsOffset = _convert_ir_basicblock( - ctx, ir.args[idx + 3], symbols, variables, allocated_variables - ) - argsSize = _convert_ir_basicblock( - ctx, ir.args[idx + 4], symbols, variables, allocated_variables - ) - retOffset = _convert_ir_basicblock( - ctx, ir.args[idx + 5], symbols, variables, allocated_variables - ) - retSize = _convert_ir_basicblock( - ctx, ir.args[idx + 6], symbols, variables, allocated_variables + argsOffset, argsSize, retOffset, retSize = _convert_ir_bb_list( + ctx, ir.args[idx + 3 : idx + 7], symbols, variables, allocated_variables ) if isinstance(argsOffset, IRLiteral): @@ -374,10 +352,10 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): return bb.append_instruction(ir.value, *args) elif ir.value == "if": cond = ir.args[0] - current_bb = ctx.get_basic_block() # convert the condition - cont_ret = _convert_ir_basicblock(ctx, cond, symbols, variables, allocated_variables) + cont_ret = _convert_ir_bb(ctx, cond, symbols, variables, allocated_variables) + current_bb = ctx.get_basic_block() else_block = IRBasicBlock(ctx.get_next_label(), ctx) ctx.append_basic_block(else_block) @@ -386,42 +364,44 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): else_ret_val = None else_syms = symbols.copy() if len(ir.args) == 3: - else_ret_val = _convert_ir_basicblock( + else_ret_val = _convert_ir_bb( ctx, ir.args[2], else_syms, variables, allocated_variables.copy() ) if isinstance(else_ret_val, IRLiteral): assert isinstance(else_ret_val.value, int) # help mypy else_ret_val = ctx.get_basic_block().append_instruction("store", else_ret_val) after_else_syms = else_syms.copy() + else_block = ctx.get_basic_block() # convert "then" then_block = IRBasicBlock(ctx.get_next_label(), ctx) ctx.append_basic_block(then_block) - then_ret_val = _convert_ir_basicblock( - ctx, ir.args[1], symbols, variables, allocated_variables - ) + then_ret_val = _convert_ir_bb(ctx, ir.args[1], symbols, variables, allocated_variables) if isinstance(then_ret_val, IRLiteral): then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) current_bb.append_instruction("jnz", cont_ret, then_block.label, else_block.label) after_then_syms = symbols.copy() + then_block = ctx.get_basic_block() # exit bb exit_label = ctx.get_next_label() - bb = IRBasicBlock(exit_label, ctx) - bb = ctx.append_basic_block(bb) + exit_bb = IRBasicBlock(exit_label, ctx) + exit_bb = ctx.append_basic_block(exit_bb) if_ret = None if then_ret_val is not None and else_ret_val is not None: - if_ret = bb.append_instruction( + if_ret = exit_bb.append_instruction( "phi", then_block.label, then_ret_val, else_block.label, else_ret_val ) common_symbols = _get_symbols_common(after_then_syms, after_else_syms) for sym, val in common_symbols.items(): - ret = bb.append_instruction("phi", then_block.label, val[0], else_block.label, val[1]) + ret = exit_bb.append_instruction( + "phi", then_block.label, val[0], else_block.label, val[1] + ) old_var = symbols.get(sym, None) symbols[sym] = ret if old_var is not None: @@ -430,15 +410,15 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): allocated_variables[idx] = ret # type: ignore if not else_block.is_terminated: - else_block.append_instruction("jmp", bb.label) + else_block.append_instruction("jmp", exit_bb.label) if not then_block.is_terminated: - then_block.append_instruction("jmp", bb.label) + then_block.append_instruction("jmp", exit_bb.label) return if_ret elif ir.value == "with": - ret = _convert_ir_basicblock( + ret = _convert_ir_bb( ctx, ir.args[1], symbols, variables, allocated_variables ) # initialization @@ -452,27 +432,25 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): else: with_symbols[sym.value] = ret # type: ignore - return _convert_ir_basicblock( - ctx, ir.args[2], with_symbols, variables, allocated_variables - ) # body + return _convert_ir_bb(ctx, ir.args[2], with_symbols, variables, allocated_variables) # body elif ir.value == "goto": _append_jmp(ctx, IRLabel(ir.args[0].value)) elif ir.value == "djump": - args = [_convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables)] + args = [_convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables)] for target in ir.args[1:]: args.append(IRLabel(target.value)) ctx.get_basic_block().append_instruction("djmp", *args) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + arg_1 = _convert_ir_bb(ctx, ir.args[1], symbols, variables, allocated_variables) new_var = ctx.get_basic_block().append_instruction("store", arg_1) # type: ignore symbols[sym.value] = new_var elif ir.value == "calldatacopy": - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + arg_0, arg_1, size = _convert_ir_bb_list( + ctx, ir.args, symbols, variables, allocated_variables + ) new_v = arg_0 var = ( @@ -492,9 +470,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): return new_v elif ir.value == "codecopy": - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + arg_0, arg_1, size = _convert_ir_bb_list( + ctx, ir.args, symbols, variables, allocated_variables + ) ctx.get_basic_block().append_instruction("codecopy", size, arg_1, arg_0) # type: ignore elif ir.value == "symbol": @@ -509,10 +487,10 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif isinstance(c, bytes): ctx.append_data("db", [c]) # type: ignore elif isinstance(c, IRnode): - data = _convert_ir_basicblock(ctx, c, symbols, variables, allocated_variables) + data = _convert_ir_bb(ctx, c, symbols, variables, allocated_variables) ctx.append_data("db", [data]) # type: ignore elif ir.value == "assert": - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) current_bb = ctx.get_basic_block() current_bb.append_instruction("assert", arg_0) elif ir.value == "label": @@ -522,7 +500,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): bb.append_instruction("jmp", label) bb = IRBasicBlock(label, ctx) ctx.append_basic_block(bb) - _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + _convert_ir_bb(ctx, ir.args[2], symbols, variables, allocated_variables) elif ir.value == "exit_to": func_t = ir.passthrough_metadata.get("func_t", None) assert func_t is not None, "exit_to without func_t" @@ -545,15 +523,11 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): deleted = symbols[f"&{ret_var.value}"] del symbols[f"&{ret_var.value}"] for arg in ir.args[2:]: - last_ir = _convert_ir_basicblock( - ctx, arg, symbols, variables, allocated_variables - ) + last_ir = _convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) if deleted is not None: symbols[f"&{ret_var.value}"] = deleted - ret_ir = _convert_ir_basicblock( - ctx, ret_var, symbols, variables, allocated_variables - ) + ret_ir = _convert_ir_bb(ctx, ret_var, symbols, variables, allocated_variables) bb = ctx.get_basic_block() @@ -612,12 +586,11 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): bb.append_instruction("ret", ret_by_value, symbols["return_pc"]) elif ir.value == "revert": - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + arg_0, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) ctx.get_basic_block().append_instruction("revert", arg_1, arg_0) elif ir.value == "dload": - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) bb = ctx.get_basic_block() src = bb.append_instruction("add", arg_0, IRLabel("code_end")) @@ -625,11 +598,10 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) elif ir.value == "dloadbytes": - dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - src_offset = _convert_ir_basicblock( - ctx, ir.args[1], symbols, variables, allocated_variables + dst, src_offset, len_ = _convert_ir_bb_list( + ctx, ir.args, symbols, variables, allocated_variables ) - len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + bb = ctx.get_basic_block() src = bb.append_instruction("add", src_offset, IRLabel("code_end")) bb.append_instruction("dloadbytes", len_, src, dst) @@ -678,9 +650,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): else: return bb.append_instruction("mload", sym_ir.value) else: - new_var = _convert_ir_basicblock( - ctx, sym_ir, symbols, variables, allocated_variables - ) + new_var = _convert_ir_bb(ctx, sym_ir, symbols, variables, allocated_variables) # # Old IR gets it's return value as a reference in the stack # New IR gets it's return value in stack in case of 32 bytes or less @@ -692,8 +662,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): return bb.append_instruction("mload", new_var) elif ir.value == "mstore": - sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + sym_ir, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) bb = ctx.get_basic_block() @@ -742,11 +711,10 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): return arg_1 elif ir.value in ["sload", "iload"]: - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) return ctx.get_basic_block().append_instruction(ir.value, arg_0) elif ir.value in ["sstore", "istore"]: - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + arg_0, arg_1 = _convert_ir_bb_list(ctx, ir.args, symbols, variables, allocated_variables) ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) elif ir.value == "unique_symbol": sym = ir.args[0] @@ -763,18 +731,18 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): # 5) increment block # 6) exit block # TODO: Add the extra bounds check after clarify - def emit_body_block(): + def emit_body_blocks(): global _break_target, _continue_target old_targets = _break_target, _continue_target _break_target, _continue_target = exit_block, increment_block - _convert_ir_basicblock(ctx, body, symbols, variables, allocated_variables) + _convert_ir_bb(ctx, body, symbols, variables, allocated_variables) _break_target, _continue_target = old_targets sym = ir.args[0] - start = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - end = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - # "bound" is not used - _ = _convert_ir_basicblock(ctx, ir.args[3], symbols, variables, allocated_variables) + start, end, _ = _convert_ir_bb_list( + ctx, ir.args[1:4], symbols, variables, allocated_variables + ) + body = ir.args[4] entry_block = ctx.get_basic_block() @@ -799,10 +767,9 @@ def emit_body_block(): cont_ret = cond_block.append_instruction("iszero", xor_ret) ctx.append_basic_block(cond_block) - # Do a dry run to get the symbols needing phi nodes start_syms = symbols.copy() ctx.append_basic_block(body_block) - emit_body_block() + emit_body_blocks() end_syms = symbols.copy() diff_syms = _get_symbols_common(start_syms, end_syms) @@ -828,8 +795,9 @@ def emit_body_block(): jump_up_block.append_instruction("jmp", increment_block.label) ctx.append_basic_block(jump_up_block) - increment_block.append_instruction(IRInstruction("add", ret, 1)) - increment_block.insert_instruction[-1].output = counter_inc_var + increment_block.insert_instruction( + IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var), 0 + ) increment_block.append_instruction("jmp", cond_block.label) ctx.append_basic_block(increment_block) @@ -851,23 +819,20 @@ def emit_body_block(): return ctx.get_basic_block().append_instruction("returndatasize") elif ir.value == "returndatacopy": assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + arg_0, arg_1, size = _convert_ir_bb_list( + ctx, ir.args, symbols, variables, allocated_variables + ) new_var = ctx.get_basic_block().append_instruction("returndatacopy", arg_1, size) symbols[f"&{arg_0.value}"] = new_var return new_var elif ir.value == "selfdestruct": - arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) ctx.get_basic_block().append_instruction("selfdestruct", arg_0) elif isinstance(ir.value, str) and ir.value.startswith("log"): args = reversed( - [ - _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) - for arg in ir.args - ] + [_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args] ) topic_count = int(ir.value[3:]) assert topic_count >= 0 and topic_count <= 4, "invalid topic count" @@ -895,9 +860,7 @@ def _convert_ir_opcode( inst_args = [] for arg in ir.args: if isinstance(arg, IRnode): - inst_args.append( - _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) - ) + inst_args.append(_convert_ir_bb(ctx, arg, symbols, variables, allocated_variables)) ctx.get_basic_block().append_instruction(opcode, *inst_args) From 417774542f04a19c12d3787c6a3e550b64a6bf25 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 Jan 2024 05:28:12 -0800 Subject: [PATCH 166/201] feat: add more venom instructions (#3733) - add support for more IRnode instructions: * not * ceil32 * select * blockhash - improve assert reason for missed cases - fix CFG_ALTERING_INSTRUCTIONS - invoke/call/staticcall should not change the CFG! --- vyper/venom/basicblock.py | 2 +- vyper/venom/ir_node_to_venom.py | 19 ++++++++++++++++--- vyper/venom/venom_to_assembly.py | 1 + 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index f86e9b330c..ed70a5eaa0 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -55,7 +55,7 @@ ] ) -CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "djmp", "jnz", "call", "staticcall", "invoke"]) +CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "djmp", "jnz"]) if TYPE_CHECKING: from vyper.venom.function import IRFunction diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 6b47ac2415..396abaf5f7 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -51,6 +51,7 @@ "chainid", "basefee", "timestamp", + "blockhash", "caller", "selfbalance", "calldatasize", @@ -65,7 +66,7 @@ "coinbase", "number", "iszero", - "ceil32", + "not", "calldataload", "extcodesize", "extcodehash", @@ -252,12 +253,13 @@ def _convert_ir_bb_list(ctx, ir, symbols, variables, allocated_variables): ret = [] for ir_node in ir: venom = _convert_ir_bb(ctx, ir_node, symbols, variables, allocated_variables) + assert venom is not None, ir_node ret.append(venom) return ret def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): - assert isinstance(ir, IRnode) + assert isinstance(ir, IRnode), ir assert isinstance(variables, OrderedSet) global _break_target, _continue_target @@ -631,7 +633,9 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): if sym_ir.is_literal: sym = symbols.get(f"&{sym_ir.value}", None) if sym is None: - new_var = bb.append_instruction("store", sym_ir) + new_var = _convert_ir_bb( + ctx, sym_ir, symbols, variables, allocated_variables + ) symbols[f"&{sym_ir.value}"] = new_var if allocated_variables.get(var.name, None) is None: allocated_variables[var.name] = new_var @@ -709,6 +713,15 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): else: symbols[sym_ir.value] = arg_1 return arg_1 + elif ir.value == "ceil32": + x = ir.args[0] + expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) + return _convert_ir_bb(ctx, expanded, symbols, variables, allocated_variables) + elif ir.value == "select": + # b ^ ((a ^ b) * cond) where cond is 1 or 0 + cond, a, b = ir.args + expanded = IRnode.from_list(["xor", b, ["mul", cond, ["xor", a, b]]]) + return _convert_ir_bb(ctx, expanded, symbols, variables, allocated_variables) elif ir.value in ["sload", "iload"]: arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols, variables, allocated_variables) diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 926f8df8a3..608e100cd1 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -58,6 +58,7 @@ "exp", "eq", "iszero", + "not", "lg", "lt", "slt", From 56b0d4fc3cb110ea9a53f3ca3affb4e758213d22 Mon Sep 17 00:00:00 2001 From: trocher Date: Thu, 18 Jan 2024 15:11:31 +0100 Subject: [PATCH 167/201] fix: `opcodes` and `opcodes_runtime` outputs (#3735) Fixed the opcodes and opcodes_runtime outputs as they would not respectively match the bytecode and bytecode_runtime outputs. The value following a PUSH instruction could be incorrect. For example, when compiling some Vyper code that results in `PUSH2 0x0100`: - The bytecode output would be `610100` - The opcodes output would be `PUSH2 0x10` instead of `PUSH2 0x0100` Note this commit fixes an ambiguity here, as prior to this commit, `610100` (`PUSH2 0x0100`) and `610010` (`PUSH2 0x0010`) would both get formatted as `PUSH2 0x10`. This issue is due to a lack of 0 padding. --- tests/unit/compiler/test_opcodes.py | 15 +++++++++++++++ vyper/compiler/output.py | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/unit/compiler/test_opcodes.py b/tests/unit/compiler/test_opcodes.py index 15d2a617ba..ed64f343c4 100644 --- a/tests/unit/compiler/test_opcodes.py +++ b/tests/unit/compiler/test_opcodes.py @@ -1,6 +1,7 @@ import pytest import vyper +from vyper.compiler.output import _build_opcodes from vyper.evm import opcodes from vyper.exceptions import CompilerPanic @@ -64,3 +65,17 @@ def test_get_opcodes(evm_version): else: for op in ("TLOAD", "TSTORE", "MCOPY"): assert op not in ops + + +def test_build_opcodes(): + assert _build_opcodes(bytes.fromhex("610250")) == "PUSH2 0x0250" + assert _build_opcodes(bytes.fromhex("612500")) == "PUSH2 0x2500" + assert _build_opcodes(bytes.fromhex("610100")) == "PUSH2 0x0100" + assert _build_opcodes(bytes.fromhex("611000")) == "PUSH2 0x1000" + assert _build_opcodes(bytes.fromhex("62010300")) == "PUSH3 0x010300" + assert ( + _build_opcodes( + bytes.fromhex("7f6100000000000000000000000000000000000000000000000000000000000000") + ) + == "PUSH32 0x6100000000000000000000000000000000000000000000000000000000000000" + ) diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 5e11a20139..b9ce39ff08 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -330,7 +330,7 @@ def _build_opcodes(bytecode: bytes) -> str: # (instead of code) at end of contract # CMC 2023-07-13 maybe just strip known data segments? push_len = min(push_len, len(bytecode_sequence)) - push_values = [hex(bytecode_sequence.popleft())[2:] for i in range(push_len)] - opcode_output.append(f"0x{''.join(push_values).upper()}") + push_values = [f"{bytecode_sequence.popleft():0>2X}" for i in range(push_len)] + opcode_output.append(f"0x{''.join(push_values)}") return " ".join(opcode_output) From 55e18f6d128b2da8986adbbcccf1cd59a4b9ad6f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 18 Jan 2024 10:08:03 -0800 Subject: [PATCH 168/201] fix: concat buffer bug (#3738) the `concat()` builtin did not respect the `copy_bytes()` API, it allocated a buffer in some cases which did not have enough padding. patches GHSA-2q8v-3gqq-4f8p --------- Co-authored-by: cyberthirst --- .../builtins/codegen/test_concat.py | 64 +++++++++++++++++++ vyper/builtins/functions.py | 11 ++-- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/tests/functional/builtins/codegen/test_concat.py b/tests/functional/builtins/codegen/test_concat.py index 5558138551..7354515989 100644 --- a/tests/functional/builtins/codegen/test_concat.py +++ b/tests/functional/builtins/codegen/test_concat.py @@ -55,6 +55,70 @@ def krazykonkat(z: Bytes[10]) -> Bytes[25]: print("Passed third concat test") +def test_concat_buffer(get_contract): + # GHSA-2q8v-3gqq-4f8p + code = """ +@internal +def bar() -> uint256: + sss: String[2] = concat("a", "b") + return 1 + + +@external +def foo() -> int256: + a: int256 = -1 + b: uint256 = self.bar() + return a + """ + c = get_contract(code) + assert c.foo() == -1 + + +def test_concat_buffer2(get_contract): + # GHSA-2q8v-3gqq-4f8p + code = """ +i: immutable(int256) + +@external +def __init__(): + i = -1 + s: String[2] = concat("a", "b") + +@external +def foo() -> int256: + return i + """ + c = get_contract(code) + assert c.foo() == -1 + + +def test_concat_buffer3(get_contract): + # GHSA-2q8v-3gqq-4f8p + code = """ +s: String[1] +s2: String[33] +s3: String[34] + +@external +def __init__(): + self.s = "a" + self.s2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" # 33*'a' + +@internal +def bar() -> uint256: + self.s3 = concat(self.s, self.s2) + return 1 + +@external +def foo() -> int256: + i: int256 = -1 + b: uint256 = self.bar() + return i + """ + c = get_contract(code) + assert c.foo() == -1 + + def test_concat_bytes32(get_contract_with_gas_estimation): test_concat_bytes32 = """ @external diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 4f8101dfbe..8ee6f5fd76 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -543,13 +543,12 @@ def build_IR(self, expr, context): else: ret_typ = BytesT(dst_maxlen) + # respect API of copy_bytes + bufsize = dst_maxlen + 32 + buf = context.new_internal_variable(BytesT(bufsize)) + # Node representing the position of the output in memory - dst = IRnode.from_list( - context.new_internal_variable(ret_typ), - typ=ret_typ, - location=MEMORY, - annotation="concat destination", - ) + dst = IRnode.from_list(buf, typ=ret_typ, location=MEMORY, annotation="concat destination") ret = ["seq"] # stack item representing our current offset in the dst buffer From c150fc49ee9375a930d177044559b83cb95f7963 Mon Sep 17 00:00:00 2001 From: Thabokani <149070269+Thabokani@users.noreply.github.com> Date: Tue, 30 Jan 2024 02:53:17 +0800 Subject: [PATCH 169/201] chore(docs): fix typos (#3749) --- examples/tokens/ERC721.vy | 2 +- tests/functional/syntax/exceptions/test_invalid_payable.py | 2 +- vyper/semantics/README.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index 152b94b046..9c86575ea3 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -28,7 +28,7 @@ interface ERC721Receiver: # transfer, the approved address for that NFT (if any) is reset to none. # @param _from Sender of NFT (if address is zero address it indicates token creation). # @param _to Receiver of NFT (if address is zero address it indicates token destruction). -# @param _tokenId The NFT that got transfered. +# @param _tokenId The NFT that got transferred. event Transfer: sender: indexed(address) receiver: indexed(address) diff --git a/tests/functional/syntax/exceptions/test_invalid_payable.py b/tests/functional/syntax/exceptions/test_invalid_payable.py index 4d8142fca2..9d0d942b0b 100644 --- a/tests/functional/syntax/exceptions/test_invalid_payable.py +++ b/tests/functional/syntax/exceptions/test_invalid_payable.py @@ -14,7 +14,7 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) -def test_variable_decleration_exception(bad_code): +def test_variable_declaration_exception(bad_code): with raises(NonPayableViolation): compiler.compile_code(bad_code) diff --git a/vyper/semantics/README.md b/vyper/semantics/README.md index 36519bba29..3b7acf9469 100644 --- a/vyper/semantics/README.md +++ b/vyper/semantics/README.md @@ -206,7 +206,7 @@ function. 2. We call `fetch_call_return` on the function definition object, with the AST node representing the call. This method validates the input arguments, and returns a `BytesM_T` with `m=32`. -3. We validation of the delcaration of `bar` in the same manner as the first +3. We validation of the declaration of `bar` in the same manner as the first example, and compare the generated type to that returned by `sha256`. ### Exceptions From 9002ed2b99e7dd5b69bce245404536cc9a63da48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micka=C3=ABl=20Schoentgen?= Date: Tue, 30 Jan 2024 15:50:20 +0100 Subject: [PATCH 170/201] docs: fix typos (#3747) * docs: Fix typos * fix tests --- SECURITY.md | 2 +- docs/index.rst | 2 +- docs/release-notes.rst | 8 ++++---- docs/style-guide.rst | 4 ++-- docs/versioning.rst | 2 +- examples/factory/Factory.vy | 2 +- tests/functional/builtins/codegen/test_send.py | 6 +++--- .../examples/auctions/test_simple_open_auction.py | 2 +- .../safe_remote_purchase/test_safe_remote_purchase.py | 2 +- .../syntax/exceptions/test_structure_exception.py | 2 +- tests/functional/syntax/test_constants.py | 2 +- tests/unit/ast/test_natspec.py | 2 +- tests/unit/semantics/test_storage_slots.py | 8 ++++---- vyper/ast/README.md | 2 +- vyper/ast/grammar.lark | 2 +- vyper/ast/identifiers.py | 2 +- vyper/ast/metadata.py | 2 +- vyper/codegen/context.py | 4 ++-- vyper/codegen/memory_allocator.py | 2 +- vyper/semantics/types/bytestrings.py | 4 ++-- vyper/semantics/types/function.py | 2 +- vyper/venom/README.md | 2 +- vyper/venom/ir_node_to_venom.py | 4 ++-- 23 files changed, 35 insertions(+), 35 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index 0a054b2c93..24e275f614 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -9,7 +9,7 @@ It is un-audited software, use with caution. ## Audit reports Vyper is constantly changing and improving. -This means the lastest version available may not be audited. +This means the latest version available may not be audited. We try to ensure the highest security code possible, but occasionally things slip through. ### Compiler Audits diff --git a/docs/index.rst b/docs/index.rst index 76ad6fbd7d..69d818cd69 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ Following the principles and goals, Vyper **does not** provide the following fea * **Modifiers**: For example in Solidity you can define a ``function foo() mod1 { ... }``, where ``mod1`` can be defined elsewhere in the code to include a check that is done before execution, a check that is done after execution, some state changes, or possibly other things. Vyper does not have this, because it makes it too easy to write misleading code. ``mod1`` just looks too innocuous for something that could add arbitrary pre-conditions, post-conditions or state changes. Also, it encourages people to write code where the execution jumps around the file, harming auditability. The usual use case for a modifier is something that performs a single check before execution of a program; our recommendation is to simply inline these checks as asserts. * **Class inheritance**: Class inheritance requires people to jump between multiple files to understand what a program is doing, and requires people to understand the rules of precedence in case of conflicts ("Which class's function ``X`` is the one that's actually used?"). Hence, it makes code too complicated to understand which negatively impacts auditability. * **Inline assembly**: Adding inline assembly would make it no longer possible to search for a variable name in order to find all instances where that variable is read or modified. -* **Function overloading**: This can cause lots of confusion on which function is called at any given time. Thus it's easier to write missleading code (``foo("hello")`` logs "hello" but ``foo("hello", "world")`` steals your funds). Another problem with function overloading is that it makes the code much harder to search through as you have to keep track on which call refers to which function. +* **Function overloading**: This can cause lots of confusion on which function is called at any given time. Thus it's easier to write misleading code (``foo("hello")`` logs "hello" but ``foo("hello", "world")`` steals your funds). Another problem with function overloading is that it makes the code much harder to search through as you have to keep track on which call refers to which function. * **Operator overloading**: Operator overloading makes writing misleading code possible. For example ``+`` could be overloaded so that it executes commands that are not visible at a first glance, such as sending funds the user did not want to send. * **Recursive calling**: Recursive calling makes it impossible to set an upper bound on gas limits, opening the door for gas limit attacks. * **Infinite-length loops**: Similar to recursive calling, infinite-length loops make it impossible to set an upper bound on gas limits, opening the door for gas limit attacks. diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 3db11dc451..df0a02a76a 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -600,7 +600,7 @@ Fixes: - Memory corruption issue when performing function calls inside a tuple or another function call (`#2186 `_) - Incorrect function output when using multidimensional arrays (`#2184 `_) -- Reduced ambiguity bewteen ``address`` and ``Bytes[20]`` (`#2191 `_) +- Reduced ambiguity between ``address`` and ``Bytes[20]`` (`#2191 `_) v0.2.5 ****** @@ -684,7 +684,7 @@ Breaking changes: - ``@public`` and ``@private`` function decorators have been renamed to ``@external`` and ``@internal`` (VIP `#2065 `_) - The ``@constant`` decorator has been renamed to ``@view`` (VIP `#2040 `_) - Type units have been removed (VIP `#1881 `_) -- Event declaraion syntax now resembles that of struct declarations (VIP `#1864 `_) +- Event declaration syntax now resembles that of struct declarations (VIP `#1864 `_) - ``log`` is now a statement (VIP `#1864 `_) - Mapping declaration syntax changed to ``HashMap[key_type, value_type]`` (VIP `#1969 `_) - Interfaces are now declared via the ``interface`` keyword instead of ``contract`` (VIP `#1825 `_) @@ -823,7 +823,7 @@ Some of the bug and stability fixes: - Fixed stack valency issues in if and for statements (`#1665 `_) - Prevent overflow when using ``sqrt`` on certain datatypes (`#1679 `_) - Prevent shadowing of internal variables (`#1601 `_) -- Reject unary substraction on unsigned types (`#1638 `_) +- Reject unary subtraction on unsigned types (`#1638 `_) - Disallow ``orelse`` syntax in ``for`` loops (`#1633 `_) - Increased clarity and efficiency of zero-padding (`#1605 `_) @@ -928,7 +928,7 @@ Here is the old changelog: * **2019.03.04**: ``create_with_code_of`` has been renamed to ``create_forwarder_to``. (`#1177 `_) * **2019.02.14**: Assigning a persistent contract address can only be done using the ``bar_contact = ERC20(
)`` syntax. * **2019.02.12**: ERC20 interface has to be imported using ``from vyper.interfaces import ERC20`` to use. -* **2019.01.30**: Byte array literals need to be annoted using ``b""``, strings are represented as `""`. +* **2019.01.30**: Byte array literals need to be annotated using ``b""``, strings are represented as `""`. * **2018.12.12**: Disallow use of ``None``, disallow use of ``del``, implemented ``clear()`` built-in function. * **2018.11.19**: Change mapping syntax to use ``map()``. (`VIP564 `_) * **2018.10.02**: Change the convert style to use types instead of string. (`VIP1026 `_) diff --git a/docs/style-guide.rst b/docs/style-guide.rst index 1b98b770b9..10869076eb 100644 --- a/docs/style-guide.rst +++ b/docs/style-guide.rst @@ -94,7 +94,7 @@ Internal Imports Internal imports are those between two modules inside the same Vyper package. - * Internal imports **may** use either ``import`` or ``from ..`` syntax. The imported value **shoould** be a module, not an object. Importing modules instead of objects avoids circular dependency issues. + * Internal imports **may** use either ``import`` or ``from ..`` syntax. The imported value **should** be a module, not an object. Importing modules instead of objects avoids circular dependency issues. * Internal imports **may** be aliased where it aids readability. * Internal imports **must** use absolute paths. Relative imports cause issues when the module is moved. @@ -250,7 +250,7 @@ Maintainers **may** request a rebase, or choose to squash merge pull requests t Conventional Commits -------------------- -Commit messages **should** adhere to the `Conventional Commits `_ standard. A convetional commit message is structured as follows: +Commit messages **should** adhere to the `Conventional Commits `_ standard. A conventional commit message is structured as follows: :: diff --git a/docs/versioning.rst b/docs/versioning.rst index 174714a6ac..7cbd503a64 100644 --- a/docs/versioning.rst +++ b/docs/versioning.rst @@ -22,7 +22,7 @@ on the type of user, so that users can stay informed about the progress of Vyper +=============+==============================================+ | Developers | Write smart contracts in Vyper | +-------------+----------------------------------------------+ -| Integrators | Integerating Vyper package or CLI into tools | +| Integrators | Integrating Vyper package or CLI into tools | +-------------+----------------------------------------------+ | Auditors | Aware of Vyper features and security issues | +-------------+----------------------------------------------+ diff --git a/examples/factory/Factory.vy b/examples/factory/Factory.vy index d08a5eb7ee..bb60f12331 100644 --- a/examples/factory/Factory.vy +++ b/examples/factory/Factory.vy @@ -21,7 +21,7 @@ def __init__(_exchange_codehash: bytes32): # For example, allowing the deployer of this contract to change this # value allows them to use a new contract if the old one has an issue. # This would trigger a cascade effect across all exchanges that would -# need to be handled appropiately. +# need to be handled appropriately. @external diff --git a/tests/functional/builtins/codegen/test_send.py b/tests/functional/builtins/codegen/test_send.py index 36f8979556..779c67e7e5 100644 --- a/tests/functional/builtins/codegen/test_send.py +++ b/tests/functional/builtins/codegen/test_send.py @@ -47,14 +47,14 @@ def __default__(): sender.test_send(receiver.address, transact={"gas": 100000}) - # no value transfer hapenned, variable was not changed + # no value transfer happened, variable was not changed assert receiver.last_sender() is None assert w3.eth.get_balance(sender.address) == 1 assert w3.eth.get_balance(receiver.address) == 0 sender.test_call(receiver.address, transact={"gas": 100000}) - # value transfer hapenned, variable was changed + # value transfer happened, variable was changed assert receiver.last_sender() == sender.address assert w3.eth.get_balance(sender.address) == 0 assert w3.eth.get_balance(receiver.address) == 1 @@ -88,7 +88,7 @@ def __default__(): sender.test_send_stipend(receiver.address, transact={"gas": 100000}) - # value transfer hapenned, variable was changed + # value transfer happened, variable was changed assert receiver.last_sender() == sender.address assert w3.eth.get_balance(sender.address) == 0 assert w3.eth.get_balance(receiver.address) == 1 diff --git a/tests/functional/examples/auctions/test_simple_open_auction.py b/tests/functional/examples/auctions/test_simple_open_auction.py index c80b44d976..46f34f31cd 100644 --- a/tests/functional/examples/auctions/test_simple_open_auction.py +++ b/tests/functional/examples/auctions/test_simple_open_auction.py @@ -80,7 +80,7 @@ def test_end_auction(w3, tester, auction_contract, tx_failed): with tx_failed(): auction_contract.endAuction() auction_contract.bid(transact={"value": 1 * 10**10, "from": k2}) - # Move block timestamp foreward to reach auction end time + # Move block timestamp forward to reach auction end time # tester.time_travel(tester.get_block_by_number('latest')['timestamp'] + EXPIRY) w3.testing.mine(EXPIRY) balance_before_end = w3.eth.get_balance(k1) diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index 2cc5dd8d4a..e21a113f61 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -32,7 +32,7 @@ def get_balance(): def test_initial_state(w3, tx_failed, get_contract, get_balance, contract_code): - # Inital deposit has to be divisible by two + # Initial deposit has to be divisible by two with tx_failed(): get_contract(contract_code, value=13) # Seller puts item up for sale diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index 97ac2b139d..c6d733fc90 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -139,7 +139,7 @@ def set_lucky(arg1: int128): pass """, """ interface Bar: -# invalud interface declaration (assignment) +# invalid interface declaration (assignment) def set_lucky(arg1: int128): arg1 = 1 arg1 = 3 diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 04e778a00e..987e39809f 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -42,7 +42,7 @@ """, InvalidType, ), - # reserverd keyword + # reserved keyword ( """ wei: constant(uint256) = 1 diff --git a/tests/unit/ast/test_natspec.py b/tests/unit/ast/test_natspec.py index 22167f8694..0b860e562e 100644 --- a/tests/unit/ast/test_natspec.py +++ b/tests/unit/ast/test_natspec.py @@ -186,7 +186,7 @@ def test_partial_natspec(): @external def foo(): ''' - Regular comments preceeding natspec is not allowed + Regular comments preceding natspec is not allowed @notice this is natspec ''' pass diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index 002ee38cd2..ea2b2fe559 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -19,7 +19,7 @@ d: public(int128[4]) foo: public(HashMap[uint256, uint256[3]]) dyn_array: DynArray[uint256, 3] -e: public(String[47]) +e: public(String[48]) f: public(int256[1]) g: public(StructTwo[2]) h: public(int256[1]) @@ -31,7 +31,7 @@ def __init__(): self.b = [7, 8] self.c = b"thisisthirtytwobytesokhowdoyoudo" self.d = [-1, -2, -3, -4] - self.e = "A realllllly long string but we wont use it all" + self.e = "A realllllly long string but we won't use it all" self.f = [33] self.g = [ StructTwo({a: b"hello", b: [-66, 420], c: "another string"}), @@ -65,7 +65,7 @@ def test_storage_slots(get_contract): assert [c.b(i) for i in range(2)] == [7, 8] assert c.c() == b"thisisthirtytwobytesokhowdoyoudo" assert [c.d(i) for i in range(4)] == [-1, -2, -3, -4] - assert c.e() == "A realllllly long string but we wont use it all" + assert c.e() == "A realllllly long string but we won't use it all" assert c.f(0) == 33 assert c.g(0) == (b"hello", [-66, 420], "another string") assert c.g(1) == ( @@ -90,7 +90,7 @@ def test_reentrancy_lock(get_contract): assert [c.b(i) for i in range(2)] == [7, 8] assert c.c() == b"thisisthirtytwobytesokhowdoyoudo" assert [c.d(i) for i in range(4)] == [-1, -2, -3, -4] - assert c.e() == "A realllllly long string but we wont use it all" + assert c.e() == "A realllllly long string but we won't use it all" assert c.f(0) == 33 assert c.g(0) == (b"hello", [-66, 420], "another string") assert c.g(1) == ( diff --git a/vyper/ast/README.md b/vyper/ast/README.md index 7400091993..8ec1273fa8 100644 --- a/vyper/ast/README.md +++ b/vyper/ast/README.md @@ -42,7 +42,7 @@ Conversion between a Python Node and a Vyper Node uses the following rules: * The type of Vyper node is determined from the `ast_type` field of the Python node. * Fields listed in `__slots__` may be included and may have a value. * Fields listed in `_translated_fields` have their key modified prior to being added. -This is used to handle discrepencies in how nodes are structured between different +This is used to handle discrepancies in how nodes are structured between different Python versions. * Fields listed in `_only_empty_fields`, if present within the Python AST, must be `None` or a `SyntaxException` is raised. diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 234e96e552..84429501e1 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -237,7 +237,7 @@ _BITAND: "&" _BITOR: "|" _BITXOR: "^" -// Comparisions +// Comparisons _EQ: "==" _NE: "!=" _LE: "<=" diff --git a/vyper/ast/identifiers.py b/vyper/ast/identifiers.py index 7d42727066..9e2e76f223 100644 --- a/vyper/ast/identifiers.py +++ b/vyper/ast/identifiers.py @@ -99,7 +99,7 @@ def validate_identifier(attr, ast_node=None): "mwei", "twei", "pwei", - # sentinal constant values + # sentinel constant values # TODO remove when these are removed from the language "zero_address", "empty_bytes32", diff --git a/vyper/ast/metadata.py b/vyper/ast/metadata.py index 0a419c3732..b4bf6a53bd 100644 --- a/vyper/ast/metadata.py +++ b/vyper/ast/metadata.py @@ -51,7 +51,7 @@ def _commit_inner(self): outer = self._node_updates[-1] - # register with previous frame in case inner gets commited + # register with previous frame in case inner gets committed # but outer needs to be rolled back for (_, k), (metadata, prev) in inner.items(): if (id(metadata), k) not in outer: diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index dea30faabc..4f644841f4 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -83,7 +83,7 @@ def __init__( # Active scopes self._scopes = set() - # Memory alloctor, keeps track of currently allocated memory. + # Memory allocator, keeps track of currently allocated memory. # Not intended to be accessed directly self.memory_allocator = memory_allocator @@ -101,7 +101,7 @@ def check_is_not_constant(self, err, expr): if self.is_constant(): raise StateAccessViolation(f"Cannot {err} from {self.pp_constancy()}", expr) - # convenience propreties + # convenience properties @property def is_payable(self): return self.func_t.is_payable diff --git a/vyper/codegen/memory_allocator.py b/vyper/codegen/memory_allocator.py index b5e1212917..f31148825c 100644 --- a/vyper/codegen/memory_allocator.py +++ b/vyper/codegen/memory_allocator.py @@ -38,7 +38,7 @@ def partially_allocate(self, size: int) -> int: class MemoryAllocator: """ - Low-level memory alloctor. Used to allocate and de-allocate memory slots. + Low-level memory allocator. Used to allocate and de-allocate memory slots. This object should not be accessed directly. Memory allocation happens via declaring variables within `Context`. diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index e3c381ac69..2f342c613e 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -63,7 +63,7 @@ def validate_literal(self, node: vy_ast.Constant) -> None: if len(node.value) != self.length: # should always be constructed with correct length - # at the point that validate_literal is calle.d + # at the point that validate_literal is called raise CompilerPanic("unreachable") @property @@ -71,7 +71,7 @@ def size_in_bytes(self): # the first slot (32 bytes) stores the actual length, and then we reserve # enough additional slots to store the data if it uses the max available length # because this data type is single-bytes, we make it so it takes the max 32 byte - # boundary as it's size, instead of giving it a size that is not cleanly divisble by 32 + # boundary as it's size, instead of giving it a size that is not cleanly divisible by 32 return 32 + ceil32(self.length) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7c77560e49..2d92370b9d 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -189,7 +189,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - # FunctionDef with stateMutability in body (Interface defintions) + # FunctionDef with stateMutability in body (Interface definitions) body = funcdef.body if ( len(body) == 1 diff --git a/vyper/venom/README.md b/vyper/venom/README.md index a81f6c0582..5d98b22dd6 100644 --- a/vyper/venom/README.md +++ b/vyper/venom/README.md @@ -131,7 +131,7 @@ A Venom program may feature basic blocks with multiple CFG inputs and outputs. T ### Code emission -This final pass of the compiler aims to emit EVM assembly recognized by Vyper's assembler. It calcluates the desired stack layout for every basic block, schedules items on the stack and selects instructions. It ensures that deploy code, runtime code, and data segments are arranged according to the assembler's expectations. +This final pass of the compiler aims to emit EVM assembly recognized by Vyper's assembler. It calculates the desired stack layout for every basic block, schedules items on the stack and selects instructions. It ensures that deploy code, runtime code, and data segments are arranged according to the assembler's expectations. ## Future planned passes diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 396abaf5f7..b3ac3c1ad7 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -42,7 +42,7 @@ ] ) -# Instuctions that are mapped to their inverse +# Instructions that are mapped to their inverse INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} # Instructions that have a direct EVM opcode equivalent and can @@ -508,7 +508,7 @@ def _convert_ir_bb(ctx, ir, symbols, variables, allocated_variables): assert func_t is not None, "exit_to without func_t" if func_t.is_external: - # Hardcoded contructor special case + # Hardcoded constructor special case bb = ctx.get_basic_block() if func_t.name == "__init__": label = IRLabel(ir.args[0].value, True) From 768b3e9baa04264980ea2fc5600e2f2356b69d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micka=C3=ABl=20Schoentgen?= Date: Tue, 30 Jan 2024 15:50:53 +0100 Subject: [PATCH 171/201] docs: upgrade dependencies and fixes (#3745) * docs: Upgrade dependencies * docs: fixes --- docs/_static/css/dark.css | 8 ++++++++ docs/conf.py | 4 ++-- docs/control-structures.rst | 4 ++-- docs/installing-vyper.rst | 2 +- docs/release-notes.rst | 6 +++--- docs/testing-contracts-ethtester.rst | 8 ++++---- docs/types.rst | 5 +++-- requirements-docs.txt | 6 +++--- 8 files changed, 26 insertions(+), 17 deletions(-) diff --git a/docs/_static/css/dark.css b/docs/_static/css/dark.css index cb96b428c8..158f08e0fc 100644 --- a/docs/_static/css/dark.css +++ b/docs/_static/css/dark.css @@ -20,6 +20,10 @@ a:visited { background-color: #2d2d2d !important; } +.descname { + color: inherit !important; +} + .rst-content dl:not(.docutils) dt { color: #aaddff; border-top: solid 3px #525252; @@ -164,6 +168,10 @@ pre { /* table of contents */ +.wy-body-for-nav { + background-color: rgb(26, 28, 29); +} + .wy-nav-content-wrap { background-color: rgba(0, 0, 0, 0.6) !important; } diff --git a/docs/conf.py b/docs/conf.py index 12af82b6e4..5dc1eee8f5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,7 +56,7 @@ # General information about the project. project = "Vyper" -copyright = "2017-2020 CC-BY-4.0 Vyper Team" +copyright = "2017-2024 CC-BY-4.0 Vyper Team" author = "Vyper Team (originally created by Vitalik Buterin)" # The version info for the project you're documenting, acts as replacement for @@ -190,5 +190,5 @@ intersphinx_mapping = { "brownie": ("https://eth-brownie.readthedocs.io/en/stable", None), "pytest": ("https://docs.pytest.org/en/latest/", None), - "python": ("https://docs.python.org/3.8/", None), + "python": ("https://docs.python.org/3.10/", None), } diff --git a/docs/control-structures.rst b/docs/control-structures.rst index 2f890bcb2f..14202cbae7 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -125,7 +125,7 @@ You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can p The ``__default__`` Function --------------------------- +---------------------------- A contract can also have a default function, which is executed on a call to the contract if no other functions match the given function identifier (or if none was supplied at all, such as through someone sending it Eth). It is the same construct as fallback functions `in Solidity `_. @@ -165,7 +165,7 @@ Lastly, although the default function receives no arguments, it can still access * the gas provided (``msg.gas``). The ``__init__`` Function ------------------------ +------------------------- ``__init__`` is a special initialization function that may only be called at the time of deploying a contract. It can be used to set initial values for storage variables. A common use case is to set an ``owner`` variable with the creator the contract: diff --git a/docs/installing-vyper.rst b/docs/installing-vyper.rst index fb2849708d..8eaa93590a 100644 --- a/docs/installing-vyper.rst +++ b/docs/installing-vyper.rst @@ -79,10 +79,10 @@ To install a specific version use: pip install vyper==0.3.7 You can check if Vyper is installed completely or not by typing the following in your terminal/cmd: - :: vyper --version + nix *** diff --git a/docs/release-notes.rst b/docs/release-notes.rst index df0a02a76a..2572c2690d 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -51,7 +51,7 @@ Notable fixes: - update tload/tstore opcodes per latest 1153 EIP spec (`#3484 `_) - fix: raw_call type when max_outsize=0 is set (`#3572 `_) - fix: implements check for indexed event arguments (`#3570 `_) -- fix: type-checking for ``_abi_decode()`` arguments (`#3626 `_) +- fix: type-checking for ``_abi_decode()`` arguments (`#3626 `__) Other docs updates, chores and fixes: ------------------------------------- @@ -164,7 +164,7 @@ Other docs updates, chores and fixes: - fix docs of ``blockhash`` to reflect revert behaviour (`#3168 `_) - improvements to compiler error messages (`#3121 `_, `#3134 `_, `#3312 `_, `#3304 `_, `#3240 `_, `#3264 `_, `#3343 `_, `#3307 `_, `#3313 `_ and `#3215 `_) -These are really just the highlights, as many other bugfixes, docs updates and refactoring (over 150 pull requests!) made it into this release! For the full list, please see the `changelog `_. Special thanks to contributions from @tserg, @trocher, @z80dev, @emc415 and @benber86 in this release! +These are really just the highlights, as many other bugfixes, docs updates and refactoring (over 150 pull requests!) made it into this release! For the full list, please see the `changelog `__. Special thanks to contributions from @tserg, @trocher, @z80dev, @emc415 and @benber86 in this release! New Contributors: @@ -346,7 +346,7 @@ Notable Fixes: * Referencing immutables in constructor (`#2627 `_) * Arrays of interfaces in for loops (`#2699 `_) -Lots of optimizations, refactoring and other fixes made it into this release! For the full list, please see the `changelog `_. +Lots of optimizations, refactoring and other fixes made it into this release! For the full list, please see the `changelog `__. Special thanks to @tserg for typechecker fixes and significant testing of new features! Additional contributors to this release include @abdullathedruid, @hi-ogawa, @skellet0r, @fubuloubu, @onlymaresia, @SwapOperator, @hitsuzen-eth, @Sud0u53r, @davidhq. diff --git a/docs/testing-contracts-ethtester.rst b/docs/testing-contracts-ethtester.rst index 1b7e9e3263..27e67831de 100644 --- a/docs/testing-contracts-ethtester.rst +++ b/docs/testing-contracts-ethtester.rst @@ -16,7 +16,7 @@ Prior to testing, the Vyper specific contract conversion and the blockchain rela Since the testing is done in the pytest framework, you can make use of `pytest.ini, tox.ini and setup.cfg `_ and you can use most IDEs' pytest plugins. -.. literalinclude:: ../tests/base_conftest.py +.. literalinclude:: ../tests/conftest.py :language: python :linenos: @@ -35,7 +35,7 @@ Assume the following simple contract ``storage.vy``. It has a single integer var We create a test file ``test_storage.py`` where we write our tests in pytest style. -.. literalinclude:: ../tests/examples/storage/test_storage.py +.. literalinclude:: ../tests/functional/examples/storage/test_storage.py :linenos: :language: python @@ -61,7 +61,7 @@ Next, we take a look at the two fixtures that will allow us to read the event lo The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction. -.. literalinclude:: ../tests/base_conftest.py +.. literalinclude:: ../tests/conftest.py :language: python :pyobject: get_logs @@ -69,6 +69,6 @@ This fixture will return a tuple with all the logs for a certain event and trans Finally, we create a new file ``test_advanced_storage.py`` where we use the new fixtures to test failed transactions and events. -.. literalinclude:: ../tests/examples/storage/test_advanced_storage.py +.. literalinclude:: ../tests/functional/examples/storage/test_advanced_storage.py :linenos: :language: python diff --git a/docs/types.rst b/docs/types.rst index a8be721b1a..0f5bfe7b04 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -316,7 +316,7 @@ Syntax as follows: ``_address.``, where ``_address`` is of the type ``ad ``_address.code`` requires the usage of :func:`slice ` to explicitly extract a section of contract bytecode. If the extracted section exceeds the bounds of bytecode, this will throw. You can check the size of ``_address.code`` using ``_address.codesize``. M-byte-wide Fixed Size Byte Array ----------------------- +--------------------------------- **Keyword:** ``bytesM`` This is an M-byte-wide byte array that is otherwise similar to dynamically sized byte arrays. On an ABI level, it is annotated as bytesM (e.g., bytes32). @@ -452,6 +452,7 @@ Note that ``in`` is not the same as strict equality (``==``). ``in`` checks that The following code uses bitwise operations to add and revoke permissions from a given ``Roles`` object. .. code-block:: python + @external def add_user(a: Roles) -> Roles: ret: Roles = a @@ -676,4 +677,4 @@ All type conversions in Vyper must be made explicitly using the built-in ``conve * Converting between bytes and int types which have different sizes follows the rule of going through the closest integer type, first. For instance, ``bytes1 -> int16`` is like ``bytes1 -> int8 -> int16`` (signextend, then widen). ``uint8 -> bytes20`` is like ``uint8 -> uint160 -> bytes20`` (rotate left 12 bytes). * Enums can be converted to and from ``uint256`` only. -A small Python reference implementation is maintained as part of Vyper's test suite, it can be found `here `_. The motivation and more detailed discussion of the rules can be found `here `_. +A small Python reference implementation is maintained as part of Vyper's test suite, it can be found `here `__. The motivation and more detailed discussion of the rules can be found `here `__. diff --git a/requirements-docs.txt b/requirements-docs.txt index 157d7bcab5..5906384fc7 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,3 @@ -sphinx==5.0.0 -recommonmark==0.6.0 -sphinx_rtd_theme==0.5.2 +sphinx==7.2.6 +recommonmark==0.7.1 +sphinx_rtd_theme==2.0.0 From a2df08888c318713742c57f71465f32a1c27ed72 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 31 Jan 2024 10:05:40 -0800 Subject: [PATCH 172/201] fix: disallow `value=` passing for delegate and static raw_calls (#3755) --- .../functional/builtins/codegen/test_raw_call.py | 16 ++++++++++++++++ vyper/builtins/functions.py | 9 ++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index 4d37176cf8..b75b5da89b 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -608,6 +608,22 @@ def foo(_addr: address): ( """ @external +def foo(_addr: address): + raw_call(_addr, method_id("foo()"), is_delegate_call=True, value=1) + """, + ArgumentException, + ), + ( + """ +@external +def foo(_addr: address): + raw_call(_addr, method_id("foo()"), is_static_call=True, value=1) + """, + ArgumentException, + ), + ( + """ +@external @view def foo(_addr: address): raw_call(_addr, 256) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 8ee6f5fd76..50ab4dacd8 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1114,13 +1114,16 @@ def build_IR(self, expr, args, kwargs, context): if delegate_call and static_call: raise ArgumentException( - "Call may use one of `is_delegate_call` or `is_static_call`, not both", expr + "Call may use one of `is_delegate_call` or `is_static_call`, not both" ) + + if (delegate_call or static_call) and value.value != 0: + raise ArgumentException("value= may not be passed for static or delegate calls!") + if not static_call and context.is_constant(): raise StateAccessViolation( f"Cannot make modifying calls from {context.pp_constancy()}," - " use `is_static_call=True` to perform this action", - expr, + " use `is_static_call=True` to perform this action" ) if data.value == "~calldata": From db8ac3c29ebae17a123ad526ec4ce69f3734be40 Mon Sep 17 00:00:00 2001 From: ControlCplusControlV <44706811+ControlCplusControlV@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:13:42 -0700 Subject: [PATCH 173/201] feat: rename `vyper.interfaces` to `ethereum.ercs` (#3741) rename the builtin interfaces `vyper.interfaces` to `ethereum.ercs`. --- docs/interfaces.rst | 4 +-- examples/factory/Exchange.vy | 2 +- examples/factory/Factory.vy | 2 +- .../market_maker/on_chain_market_maker.vy | 2 +- examples/tokens/ERC1155ownable.vy | 2 +- examples/tokens/ERC20.vy | 4 +-- examples/tokens/ERC4626.vy | 4 +-- examples/tokens/ERC721.vy | 4 +-- .../test_external_contract_calls.py | 2 +- .../modules/test_stateless_functions.py | 2 +- tests/functional/codegen/test_interfaces.py | 4 +-- .../functional/syntax/test_functions_call.py | 2 +- tests/functional/syntax/test_interfaces.py | 32 +++++++++---------- vyper/semantics/analysis/module.py | 8 ++--- 14 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/interfaces.rst b/docs/interfaces.rst index b4182cced7..ab220272d8 100644 --- a/docs/interfaces.rst +++ b/docs/interfaces.rst @@ -160,11 +160,11 @@ In the above example, the ``my_project`` folder is set as the root path. A contr Built-in Interfaces =================== -Vyper includes common built-in interfaces such as `ERC20 `_ and `ERC721 `_. These are imported from ``vyper.interfaces``: +Vyper includes common built-in interfaces such as `ERC20 `_ and `ERC721 `_. These are imported from ``ethereum.ercs``: .. code-block:: python - from vyper.interfaces import ERC20 + from ethereum.ercs import ERC20 implements: ERC20 diff --git a/examples/factory/Exchange.vy b/examples/factory/Exchange.vy index acdcfd6d4e..77f47984bc 100644 --- a/examples/factory/Exchange.vy +++ b/examples/factory/Exchange.vy @@ -1,4 +1,4 @@ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 interface Factory: diff --git a/examples/factory/Factory.vy b/examples/factory/Factory.vy index bb60f12331..50e7a81bf6 100644 --- a/examples/factory/Factory.vy +++ b/examples/factory/Factory.vy @@ -1,4 +1,4 @@ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 interface Exchange: def token() -> ERC20: view diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index d385d2e0c6..4f9859584c 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -1,4 +1,4 @@ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 totalEthQty: public(uint256) diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index e105a79133..d1e88dcd04 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -9,7 +9,7 @@ """ ############### imports ############### -from vyper.interfaces import ERC165 +from ethereum.ercs import ERC165 ############### variables ############### # maximum items in a batch call. Set to 128, to be determined what the practical limits are. diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index c3809dbb60..77550c3f5a 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -6,8 +6,8 @@ # @author Takayuki Jimba (@yudetamago) # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-20.md -from vyper.interfaces import ERC20 -from vyper.interfaces import ERC20Detailed +from ethereum.ercs import ERC20 +from ethereum.ercs import ERC20Detailed implements: ERC20 implements: ERC20Detailed diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index 0a0a698bf0..73721fdb98 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -6,8 +6,8 @@ ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### -from vyper.interfaces import ERC20 -from vyper.interfaces import ERC4626 +from ethereum.ercs import ERC20 +from ethereum.ercs import ERC4626 implements: ERC20 implements: ERC4626 diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index 9c86575ea3..d3a8d1f13d 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -6,8 +6,8 @@ # @author Ryuya Nakamura (@nrryuya) # Modified from: https://github.com/vyperlang/vyper/blob/de74722bf2d8718cca46902be165f9fe0e3641dd/examples/tokens/ERC721.vy -from vyper.interfaces import ERC165 -from vyper.interfaces import ERC721 +from ethereum.ercs import ERC165 +from ethereum.ercs import ERC721 implements: ERC721 implements: ERC165 diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 0af4f9f937..a7cf4d0ecf 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -2370,7 +2370,7 @@ def transfer(receiver: address, amount: uint256): """ code = """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 @external def safeTransfer(erc20: ERC20, receiver: address, amount: uint256) -> uint256: assert erc20.transfer(receiver, amount, default_return_value=True) diff --git a/tests/functional/codegen/modules/test_stateless_functions.py b/tests/functional/codegen/modules/test_stateless_functions.py index 2abc164689..26c3f338fb 100644 --- a/tests/functional/codegen/modules/test_stateless_functions.py +++ b/tests/functional/codegen/modules/test_stateless_functions.py @@ -188,7 +188,7 @@ def qux() -> library.SomeStruct: # test calls to library functions in statement position def test_library_statement_calls(get_contract, make_input_bundle, tx_failed): library_source = """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 @internal def check_adds_to_ten(x: uint256, y: uint256): assert x + y == 10 diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 7d363fadc0..3344ff113b 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -69,7 +69,7 @@ def test(_owner: address): nonpayable def test_basic_interface_implements(assert_compile_failed): code = """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 implements: ERC20 @@ -382,7 +382,7 @@ def transfer(to: address, amount: uint256) -> bool: """ code = """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 token_address: ERC20 diff --git a/tests/functional/syntax/test_functions_call.py b/tests/functional/syntax/test_functions_call.py index a1a23b6bc2..c585572c63 100644 --- a/tests/functional/syntax/test_functions_call.py +++ b/tests/functional/syntax/test_functions_call.py @@ -52,7 +52,7 @@ def foo(x: int128) -> uint256: return convert(x, uint256) """, """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 interface Factory: def getExchange(token_addr: address) -> address: view diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index ca96adca91..584e497534 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -15,7 +15,7 @@ fail_list = [ ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 a: public(ERC20) @external def test(): @@ -25,7 +25,7 @@ def test(): ), ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 aba: public(ERC20) @external def test(): @@ -35,7 +35,7 @@ def test(): ), ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 a: address(ERC20) # invalid syntax now. """, @@ -43,7 +43,7 @@ def test(): ), ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 @external def test(): @@ -63,7 +63,7 @@ def test(): # may not call normal address ), ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 @external def test(a: address): my_address: address = ERC20() @@ -72,7 +72,7 @@ def test(a: address): ), ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 implements: ERC20 = 1 """, @@ -109,7 +109,7 @@ def foo(): nonpayable ), ( """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 interface A: def f(): view @@ -137,7 +137,7 @@ def f(a: uint256): # visibility is nonpayable instead of view ( # `receiver` of `Transfer` event should be indexed """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 implements: ERC20 @@ -175,7 +175,7 @@ def approve(_spender : address, _value : uint256) -> bool: ( # `value` of `Transfer` event should not be indexed """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 implements: ERC20 @@ -221,14 +221,14 @@ def test_interfaces_fail(bad_code): valid_list = [ """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 b: ERC20 @external def test(input: address): assert self.b.totalSupply() == ERC20(input).totalSupply() """, """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 interface Factory: def getExchange(token_addr: address) -> address: view @@ -253,12 +253,12 @@ def test() -> (bool, Foo): return True, x """ """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 a: public(ERC20) """, """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 a: public(ERC20) @@ -267,7 +267,7 @@ def test() -> address: return self.a.address """, """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 a: public(ERC20) b: address @@ -277,7 +277,7 @@ def test(): self.b = self.a.address """, """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 struct aStruct: my_address: address @@ -291,7 +291,7 @@ def test() -> address: return self.b.my_address """, """ -from vyper.interfaces import ERC20 +from ethereum.ercs import ERC20 a: public(ERC20) @external def test(): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 100819526b..a83c2f3b7d 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -502,7 +502,7 @@ def _import_to_path(level: int, module_str: str) -> PurePath: # can add more, e.g. "vyper.builtins.interfaces", etc. -BUILTIN_PREFIXES = ["vyper.interfaces"] +BUILTIN_PREFIXES = ["ethereum.ercs"] def _is_builtin(module_str): @@ -524,10 +524,10 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: input_bundle = FilesystemInputBundle([search_path]) # remap builtins directory -- - # vyper/interfaces => vyper/builtins/interfaces + # ethereum/ercs => vyper/builtins/interfaces remapped_module = module_str - if remapped_module.startswith("vyper.interfaces"): - remapped_module = remapped_module.removeprefix("vyper.interfaces") + if remapped_module.startswith("ethereum.ercs"): + remapped_module = remapped_module.removeprefix("ethereum.ercs") remapped_module = vyper.builtins.interfaces.__package__ + remapped_module path = _import_to_path(level, remapped_module).with_suffix(".vyi") From f7f67d06a155903582b08378b1a3c1e459908149 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 4 Feb 2024 16:30:05 -0800 Subject: [PATCH 174/201] fix: bad assertion in expr.py (#3758) in addition to structs and interfaces, some builtins can also be skipped by the constant folder. --- tests/functional/syntax/test_constants.py | 7 +++++++ vyper/codegen/expr.py | 3 --- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 987e39809f..5a0bbdb2b5 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -279,6 +279,13 @@ def deposit(deposit_input: Bytes[2048]): CONST_BAR: constant(Foo) = Foo({a: 1, b: 2}) """, """ +CONST_EMPTY: constant(bytes32) = empty(bytes32) + +@internal +def foo() -> bytes32: + return CONST_EMPTY + """, + """ struct Foo: a: uint256 b: uint256 diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 6a97e60ce2..13dae446ef 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -190,9 +190,6 @@ def parse_Name(self): varinfo = self.context.globals[self.expr.id] if varinfo.is_constant: - # constants other than structs and interfaces should have already gotten - # propagated during constant folding - assert isinstance(varinfo.typ, (InterfaceT, StructT)) return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" From 01ec9a1fd1cc99ebfff638c8f126c6ee5f2e5dbd Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 5 Feb 2024 23:08:33 +0800 Subject: [PATCH 175/201] feat: remove Index AST node (#3757) remove Index AST node type, it has been deprecated from the python AST since python3.9. --- vyper/ast/expansion.py | 8 ++--- vyper/ast/nodes.py | 4 --- vyper/ast/nodes.pyi | 5 +-- vyper/ast/parse.py | 18 ----------- vyper/codegen/expr.py | 6 ++-- vyper/semantics/analysis/constant_folding.py | 2 +- vyper/semantics/analysis/local.py | 6 +--- vyper/semantics/analysis/utils.py | 8 ++--- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/bytestrings.py | 2 +- vyper/semantics/types/subscriptable.py | 19 ++++------- vyper/semantics/types/utils.py | 34 +++++++++----------- 12 files changed, 37 insertions(+), 77 deletions(-) diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 1536f39165..6bc4ffb57f 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -43,10 +43,10 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: raise CompilerPanic("Mismatch between node and input type while building getter") if annotation.value.get("id") == "HashMap": # type: ignore # for a HashMap, split the key/value types and use the key type as the next arg - arg, annotation = annotation.slice.value.elements # type: ignore + arg, annotation = annotation.slice.elements # type: ignore elif annotation.value.get("id") == "DynArray": arg = vy_ast.Name(id=type_._id) - annotation = annotation.slice.value.elements[0] # type: ignore + annotation = annotation.slice.elements[0] # type: ignore else: # for other types, build an input arg node from the expected type # and remove the outer `Subscript` from the annotation @@ -55,9 +55,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: input_nodes.append(vy_ast.arg(arg=f"arg{i}", annotation=arg)) # wrap the return statement in a `Subscript` - return_stmt = vy_ast.Subscript( - value=return_stmt, slice=vy_ast.Index(value=vy_ast.Name(id=f"arg{i}")) - ) + return_stmt = vy_ast.Subscript(value=return_stmt, slice=vy_ast.Name(id=f"arg{i}")) # after iterating the input types, the remaining annotation node is our return type return_node = copy.copy(annotation) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index de15fb9075..054145d33b 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1227,10 +1227,6 @@ class Subscript(ExprNode): __slots__ = ("slice", "value") -class Index(VyperNode): - __slots__ = ("value",) - - class Assign(Stmt): """ An assignment. diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 896329c702..f71ed67821 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -204,12 +204,9 @@ class Attribute(VyperNode): value: VyperNode = ... class Subscript(VyperNode): - slice: Index = ... + slice: VyperNode = ... value: VyperNode = ... -class Index(VyperNode): - value: Constant = ... - class Assign(VyperNode): ... class AnnAssign(VyperNode): diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index cc0a47824c..fc99af901b 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -341,24 +341,6 @@ def visit_Expr(self, node): return node - def visit_Subscript(self, node): - """ - Maintain consistency of `Subscript.slice` across python versions. - - Starting from python 3.9, the `Index` node type has been deprecated, - and made impossible to instantiate via regular means. Here we do awful - hacky black magic to create an `Index` node. We need our own parser. - """ - self.generic_visit(node) - - if not isinstance(node.slice, python_ast.Index): - index = python_ast.Constant(value=node.slice, ast_type="Index") - index.__class__ = python_ast.Index - self.generic_visit(index) - node.slice = index - - return node - def visit_Constant(self, node): """ Handle `Constant` when using Python >=3.8 diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 13dae446ef..f4c7948382 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -356,16 +356,16 @@ def parse_Subscript(self): if isinstance(sub.typ, HashMapT): # TODO sanity check we are in a self.my_map[i] situation - index = Expr(self.expr.slice.value, self.context).ir_node + index = Expr(self.expr.slice, self.context).ir_node if isinstance(index.typ, _BytestringT): # we have to hash the key to get a storage location index = keccak256_helper(index, self.context) elif is_array_like(sub.typ): - index = Expr.parse_value_expr(self.expr.slice.value, self.context) + index = Expr.parse_value_expr(self.expr.slice, self.context) elif is_tuple_like(sub.typ): - index = self.expr.slice.value.n + index = self.expr.slice.n # note: this check should also happen in get_element_ptr if not 0 <= index < len(sub.typ.member_types): raise TypeCheckFailure("unreachable") diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index b165a6dae9..bfcc473d09 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -217,7 +217,7 @@ def visit_Call(self, node) -> vy_ast.ExprNode: return typ._try_fold(node) # type: ignore def visit_Subscript(self, node) -> vy_ast.ExprNode: - slice_ = node.slice.value.get_folded_value() + slice_ = node.slice.get_folded_value() value = node.value.get_folded_value() if not isinstance(value, vy_ast.List): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 29a93a9eaf..0520b2995f 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -647,10 +647,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None: validate_expected_type(node, typ) - def visit_Index(self, node: vy_ast.Index, typ: VyperType) -> None: - validate_expected_type(node.value, typ) - self.visit(node.value, typ) - def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: assert isinstance(typ, (SArrayT, DArrayT)) for element in node.elements: @@ -687,7 +683,7 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: # get the correct type for the index, it might # not be exactly base_type.key_type # note: index_type is validated in types_from_Subscript - index_types = get_possible_types_from_node(node.slice.value) + index_types = get_possible_types_from_node(node.slice) index_type = index_types.pop() self.visit(node.slice, index_type) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 3e818fa246..81b6843b2b 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -382,13 +382,13 @@ def types_from_Subscript(self, node): types_list = self.get_possible_types_from_node(node.value) ret = [] for t in types_list: - t.validate_index_type(node.slice.value) - ret.append(t.get_subscripted_type(node.slice.value)) + t.validate_index_type(node.slice) + ret.append(t.get_subscripted_type(node.slice)) return ret t = self.get_exact_type_from_node(node.value) - t.validate_index_type(node.slice.value) - return [t.get_subscripted_type(node.slice.value)] + t.validate_index_type(node.slice) + return [t.get_subscripted_type(node.slice)] def types_from_Tuple(self, node): types_list = [self.get_exact_type_from_node(i) for i in node.elements] diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index b15eca8ab2..b14351a20f 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -274,7 +274,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional["VyperType"]: raise StructureException(f"{self} is not callable", node) @classmethod - def get_subscripted_type(self, node: vy_ast.Index) -> None: + def get_subscripted_type(self, node: vy_ast.VyperNode) -> None: """ Return the type of a subscript expression, e.g. x[1] diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 2f342c613e..96bb1bbf74 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -123,7 +123,7 @@ def compare_type(self, other): @classmethod def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": - if not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index): + if not isinstance(node, vy_ast.Subscript): raise StructureException( f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node ) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 55ffc23b2f..635a1631a2 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -71,9 +71,8 @@ def get_subscripted_type(self, node): def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT": if ( not isinstance(node, vy_ast.Subscript) - or not isinstance(node.slice, vy_ast.Index) - or not isinstance(node.slice.value, vy_ast.Tuple) - or len(node.slice.value.elements) != 2 + or not isinstance(node.slice, vy_ast.Tuple) + or len(node.slice.elements) != 2 ): raise StructureException( ( @@ -83,7 +82,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT": node, ) - k_ast, v_ast = node.slice.value.elements + k_ast, v_ast = node.slice.elements key_type = type_from_annotation(k_ast, DataLocation.STORAGE) if not key_type._as_hashmap_key: raise InvalidType("can only use primitive types as HashMap key!", k_ast) @@ -198,7 +197,7 @@ def compare_type(self, other): @classmethod def from_annotation(cls, node: vy_ast.Subscript) -> "SArrayT": - if not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index): + if not isinstance(node, vy_ast.Subscript): raise StructureException( "Arrays must be defined with base type and length, e.g. bool[5]", node ) @@ -280,14 +279,10 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": if not isinstance(node, vy_ast.Subscript): raise StructureException(err_msg, node) - if ( - not isinstance(node.slice, vy_ast.Index) - or not isinstance(node.slice.value, vy_ast.Tuple) - or len(node.slice.value.elements) != 2 - ): + if not isinstance(node.slice, vy_ast.Tuple) or len(node.slice.elements) != 2: raise StructureException(err_msg, node.slice) - length_node = node.slice.value.elements[1] + length_node = node.slice.elements[1] if length_node.has_folded_value: length_node = length_node.get_folded_value() @@ -296,7 +291,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": length = length_node.value - value_node = node.slice.value.elements[0] + value_node = node.slice.elements[0] value_type = type_from_annotation(value_node) if not value_type._as_darray: raise StructureException(f"Arrays of {value_type} are not allowed", value_node) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index c82eb73afc..5564570536 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -161,15 +161,14 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: return typ_ -def get_index_value(node: vy_ast.Index) -> int: +def get_index_value(node: vy_ast.VyperNode) -> int: """ Return the literal value for a `Subscript` index. Arguments --------- - node: vy_ast.Index - Vyper ast node from the `slice` member of a Subscript node. Must be an - `Index` object (Vyper does not support `Slice` or `ExtSlice`). + node: vy_ast.VyperNode + Vyper ast node from the `slice` member of a Subscript node. Returns ------- @@ -181,23 +180,20 @@ def get_index_value(node: vy_ast.Index) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - value = node.get("value") - if value.has_folded_value: - value = value.get_folded_value() - - if not isinstance(value, vy_ast.Int): - if hasattr(node, "value"): - # even though the subscript is an invalid type, first check if it's a valid _something_ - # this gives a more accurate error in case of e.g. a typo in a constant variable name - try: - get_possible_types_from_node(node.value) - except StructureException: - # StructureException is a very broad error, better to raise InvalidType in this case - pass + if node.has_folded_value: + node = node.get_folded_value() + if not isinstance(node, vy_ast.Int): + # even though the subscript is an invalid type, first check if it's a valid _something_ + # this gives a more accurate error in case of e.g. a typo in a constant variable name + try: + get_possible_types_from_node(node) + except StructureException: + # StructureException is a very broad error, better to raise InvalidType in this case + pass raise InvalidType("Subscript must be a literal integer", node) - if value.value <= 0: + if node.value <= 0: raise ArrayIndexException("Subscript must be greater than 0", node) - return value.value + return node.value From e20885e6058e09826a0232b1823d5e597600d031 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 6 Feb 2024 01:56:54 +0800 Subject: [PATCH 176/201] refactor: `ExprVisitor` type validation (#3739) this commit simplifies the `ExprVisitor` implementation by moving calls to `validate_expected_type` into the generic `visit()` function, instead of having ad-hoc calls to validate_expected_type in the specialized visitor functions. in doing so, some inconsistencies in the generic implementation were found and fixed: - fix validate_expected_type for tuples - introduce a void type for dealing with function calls/statements which don't return anything. --------- Co-authored-by: Charles Cooper --- .../calling_convention/test_return_tuple.py | 4 +- vyper/semantics/analysis/local.py | 82 +++++++------------ vyper/semantics/analysis/utils.py | 41 +++++++--- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/base.py | 16 +++- 5 files changed, 74 insertions(+), 71 deletions(-) diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index 266555ead6..afbcf4027b 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import TypeMismatch +from vyper.exceptions import InvalidType pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -159,5 +159,5 @@ def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation): def getTimeAndBalance() -> (bool, address): return block.timestamp, self.balance """ - with pytest.raises(TypeMismatch): + with pytest.raises(InvalidType): compile_code(code) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0520b2995f..91cc0ebdf8 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -32,6 +32,7 @@ from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( TYPE_T, + VOID_TYPE, AddressT, BoolT, DArrayT, @@ -45,6 +46,7 @@ VyperType, _BytestringT, is_type_t, + map_void, ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability from vyper.semantics.types.utils import type_from_annotation @@ -235,12 +237,13 @@ def visit_AnnAssign(self, node): ) typ = type_from_annotation(node.annotation, DataLocation.MEMORY) - validate_expected_type(node.value, typ) + + # validate the value before adding it to the namespace + self.expr_visitor.visit(node.value, typ) self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) self.expr_visitor.visit(node.target, typ) - self.expr_visitor.visit(node.value, typ) def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: if isinstance(msg_node, vy_ast.Str): @@ -259,10 +262,6 @@ def visit_Assert(self, node): if node.msg: self._validate_revert_reason(node.msg) - try: - validate_expected_type(node.test, BoolT()) - except InvalidType: - raise InvalidType("Assertion test value must be a boolean", node.test) self.expr_visitor.visit(node.test, BoolT()) # repeated code for Assign and AugAssign @@ -276,7 +275,6 @@ def _assign_helper(self, node): "Left-hand side of assignment cannot be a HashMap without a key", node ) - validate_expected_type(node.value, target.typ) target.validate_modification(node, self.func.mutability) self.expr_visitor.visit(node.value, target.typ) @@ -341,16 +339,16 @@ def visit_Expr(self, node): expr_info.validate_modification(node, self.func.mutability) # NOTE: fetch_call_return validates call args. - return_value = fn_type.fetch_call_return(node.value) + return_value = map_void(fn_type.fetch_call_return(node.value)) if ( - return_value + return_value is not VOID_TYPE and not isinstance(fn_type, MemberFunctionT) and not isinstance(fn_type, ContractFunctionT) ): raise StructureException( f"Function '{fn_type._id}' cannot be called without assigning the result", node ) - self.expr_visitor.visit(node.value, fn_type) + self.expr_visitor.visit(node.value, return_value) def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): @@ -443,7 +441,6 @@ def visit_For(self, node): self.expr_visitor.visit(node.iter, iter_type) def visit_If(self, node): - validate_expected_type(node.test, BoolT()) self.expr_visitor.visit(node.test, BoolT()) with self.namespace.enter_scope(): for n in node.body: @@ -462,9 +459,11 @@ def visit_Log(self, node): raise StructureException( f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node ) - f.fetch_call_return(node.value) + t = map_void(f.fetch_call_return(node.value)) + # CMC 2024-02-05 annotate the event type for codegen usage + # TODO: refactor this node._metadata["type"] = f.typedef - self.expr_visitor.visit(node.value, f.typedef) + self.expr_visitor.visit(node.value, t) def visit_Raise(self, node): if node.exc: @@ -489,10 +488,7 @@ def visit_Return(self, node): f"expected {self.func.return_type.length}, got {len(values)}", node, ) - for given, expected in zip(values, self.func.return_type.member_types): - validate_expected_type(given, expected) - else: - validate_expected_type(values, self.func.return_type) + self.expr_visitor.visit(node.value, self.func.return_type) @@ -503,14 +499,11 @@ def __init__(self, fn_node: Optional[ContractFunctionT] = None): self.func = fn_node def visit(self, node, typ): + if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): + validate_expected_type(node, typ) + # recurse and typecheck in case we are being fed the wrong type for - # some reason. note that `validate_expected_type` is unnecessary - # for nodes that already call `get_exact_type_from_node` and - # `get_possible_types_from_node` because `validate_expected_type` - # would be calling the same function again. - # CMC 2023-06-27 would be cleanest to call validate_expected_type() - # before recursing but maybe needs some refactoring before that - # can happen. + # some reason. super().visit(node, typ) # annotate @@ -541,28 +534,21 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: self.visit(node.value, value_type) def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None: - validate_expected_type(node.left, typ) self.visit(node.left, typ) rtyp = typ if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): rtyp = get_possible_types_from_node(node.right).pop() - validate_expected_type(node.right, rtyp) - self.visit(node.right, rtyp) def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: assert typ == BoolT() # sanity check for value in node.values: - validate_expected_type(value, BoolT()) self.visit(value, BoolT()) def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: call_type = get_exact_type_from_node(node.func) - # except for builtin functions, `get_exact_type_from_node` - # already calls `validate_expected_type` on the call args - # and kwargs via `call_type.fetch_call_return` self.visit(node.func, call_type) if isinstance(call_type, ContractFunctionT): @@ -594,7 +580,6 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: else: # builtin functions arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) - # `infer_arg_types` already calls `validate_expected_type` for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) kwarg_types = call_type.infer_kwarg_types(node) @@ -610,7 +595,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: rlen = len(node.right.elements) rtyp = SArrayT(ltyp, rlen) - validate_expected_type(node.right, rtyp) else: rtyp = get_exact_type_from_node(node.right) if isinstance(rtyp, FlagT): @@ -621,8 +605,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: assert isinstance(rtyp, (SArrayT, DArrayT)) ltyp = rtyp.value_type - validate_expected_type(node.left, ltyp) - self.visit(node.left, ltyp) self.visit(node.right, rtyp) @@ -638,28 +620,27 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: rtyp = get_exact_type_from_node(node.right) else: ltyp = rtyp = cmp_typ - validate_expected_type(node.left, ltyp) - validate_expected_type(node.right, rtyp) self.visit(node.left, ltyp) self.visit(node.right, rtyp) def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None: - validate_expected_type(node, typ) + pass + + def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: + self.visit(node.test, BoolT()) + self.visit(node.body, typ) + self.visit(node.orelse, typ) def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: assert isinstance(typ, (SArrayT, DArrayT)) for element in node.elements: - validate_expected_type(element, typ.value_type) self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: if self.func and self.func.mutability == StateMutability.PURE: _validate_self_reference(node) - if not isinstance(typ, TYPE_T): - validate_expected_type(node, typ) - def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): # don't recurse; can't annotate AST children of type definition @@ -694,23 +675,16 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return + # these guarantees should be provided by validate_expected_type assert isinstance(typ, TupleT) - for element, subtype in zip(node.elements, typ.member_types): - validate_expected_type(element, subtype) - self.visit(element, subtype) + assert len(node.elements) == len(typ.member_types) + + for item_ast, item_type in zip(node.elements, typ.member_types): + self.visit(item_ast, item_type) def visit_UnaryOp(self, node: vy_ast.UnaryOp, typ: VyperType) -> None: - validate_expected_type(node.operand, typ) self.visit(node.operand, typ) - def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: - validate_expected_type(node.test, BoolT()) - self.visit(node.test, BoolT()) - validate_expected_type(node.body, typ) - self.visit(node.body, typ) - validate_expected_type(node.orelse, typ) - self.visit(node.orelse, typ) - def _validate_range_call(node: vy_ast.Call): """ diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 81b6843b2b..1c56b16020 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -317,6 +317,17 @@ def types_from_Constant(self, node): ) raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node) + def types_from_IfExp(self, node): + validate_expected_type(node.test, BoolT()) + types_list = get_common_types(node.body, node.orelse) + + if not types_list: + a = get_possible_types_from_node(node.body)[0] + b = get_possible_types_from_node(node.orelse)[0] + raise TypeMismatch(f"Dislike types: {a} and {b}", node) + + return types_list + def types_from_List(self, node): # literal array if _is_empty_list(node): @@ -399,17 +410,6 @@ def types_from_UnaryOp(self, node): types_list = self.get_possible_types_from_node(node.operand) return _validate_op(node, types_list, "validate_numeric_op") - def types_from_IfExp(self, node): - validate_expected_type(node.test, BoolT()) - types_list = get_common_types(node.body, node.orelse) - - if not types_list: - a = get_possible_types_from_node(node.body)[0] - b = get_possible_types_from_node(node.orelse)[0] - raise TypeMismatch(f"Dislike types: {a} and {b}", node) - - return types_list - def _is_empty_list(node): # Checks if a node is a `List` node with an empty list for `elements`, @@ -550,11 +550,26 @@ def validate_expected_type(node, expected_type): ------- None """ - given_types = _ExprAnalyser().get_possible_types_from_node(node) - if not isinstance(expected_type, tuple): expected_type = (expected_type,) + if isinstance(node, vy_ast.Tuple): + possible_tuple_types = [t for t in expected_type if isinstance(t, TupleT)] + for t in possible_tuple_types: + if len(t.member_types) != len(node.elements): + continue + for item_ast, item_type in zip(node.elements, t.member_types): + try: + validate_expected_type(item_ast, item_type) + return + except VyperException: + pass + else: + # fail block + pass + + given_types = _ExprAnalyser().get_possible_types_from_node(node) + if isinstance(node, vy_ast.List): # special case - for literal arrays we individually validate each item for expected in expected_type: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 880857ccb8..a04632b96f 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -1,5 +1,5 @@ from . import primitives, subscriptable, user -from .base import TYPE_T, KwargSettings, VyperType, is_type_t +from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index b14351a20f..d659276ee0 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -25,7 +25,7 @@ def __init__(self, type_): self.type_ = type_ def compare_type(self, other): - return isinstance(other, self.type_) + return isinstance(other, self.type_) or self == other class VyperType: @@ -324,6 +324,20 @@ def __init__(self, typ, default, require_literal=False): self.require_literal = require_literal +class _VoidType(VyperType): + _id = "(void)" + + +# sentinel for function calls which return nothing +VOID_TYPE = _VoidType() + + +def map_void(typ: Optional[VyperType]) -> VyperType: + if typ is None: + return VOID_TYPE + return typ + + # A type type. Used internally for types which can live in expression # position, ex. constructors (events, interfaces and structs), and also # certain builtins which take types as parameters From 4ecd26b3651fd3069cfd65055d654afb8e9d554c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 7 Feb 2024 04:34:16 -0800 Subject: [PATCH 177/201] perf: reimplement `IRnode.__deepcopy__` (#3761) `deepcopy` is a hotspot for compile time. this commit results in a 16% improvement in compile time. --- vyper/codegen/ir_node.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index 45d93f3067..b1a71021c8 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -1,4 +1,5 @@ import contextlib +import copy import re from enum import Enum, auto from functools import cached_property @@ -392,6 +393,14 @@ def __init__( raise CompilerPanic(f"Invalid value for IR AST node: {self.value}") assert isinstance(self.args, list) + # deepcopy is a perf hotspot; it pays to optimize it a little + def __deepcopy__(self, memo): + cls = self.__class__ + ret = cls.__new__(cls) + ret.__dict__ = self.__dict__.copy() + ret.args = [copy.deepcopy(arg) for arg in ret.args] + return ret + # TODO would be nice to rename to `gas_estimate` or `gas_bound` @property def gas(self): From c6b29c7f06a493bbeb62e85baa2592a994b15a5d Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 9 Feb 2024 00:40:49 +0800 Subject: [PATCH 178/201] chore: improve exception for type validation (#3759) change `InvalidType` error to more accurate `TypeMismatch` --- .../builtins/codegen/test_minmax_value.py | 6 ++--- .../builtins/codegen/test_raw_call.py | 4 +-- tests/functional/builtins/folding/test_abs.py | 6 ++--- .../test_default_parameters.py | 12 ++++----- .../calling_convention/test_return_tuple.py | 4 +-- .../features/iteration/test_for_in_list.py | 2 +- .../codegen/features/test_assignment.py | 2 +- .../codegen/features/test_logging.py | 14 +++++----- .../codegen/types/numbers/test_signed_ints.py | 13 +++++++--- .../types/numbers/test_unsigned_ints.py | 11 +++++--- tests/functional/codegen/types/test_bytes.py | 8 +++--- .../codegen/types/test_dynamic_array.py | 3 +-- .../exceptions/test_invalid_type_exception.py | 22 ---------------- .../test_type_mismatch_exception.py | 22 ++++++++++++++++ tests/functional/syntax/test_abi_encode.py | 6 ++--- tests/functional/syntax/test_abs.py | 4 +-- tests/functional/syntax/test_addmulmod.py | 6 ++--- tests/functional/syntax/test_ann_assign.py | 14 +++++----- tests/functional/syntax/test_as_wei_value.py | 4 +-- tests/functional/syntax/test_block.py | 8 +++--- tests/functional/syntax/test_bool.py | 17 +++--------- tests/functional/syntax/test_bytes.py | 12 +++------ tests/functional/syntax/test_chainid.py | 6 ++--- tests/functional/syntax/test_concat.py | 4 +-- tests/functional/syntax/test_constants.py | 8 +++--- tests/functional/syntax/test_extract32.py | 2 +- tests/functional/syntax/test_invalids.py | 19 +++++++------- tests/functional/syntax/test_keccak256.py | 4 +-- tests/functional/syntax/test_list.py | 26 +++++++++---------- tests/functional/syntax/test_logging.py | 15 +++-------- tests/functional/syntax/test_minmax.py | 4 +-- tests/functional/syntax/test_nested_list.py | 8 +++--- tests/functional/syntax/test_powmod.py | 4 +-- tests/functional/syntax/test_selfdestruct.py | 4 +-- tests/functional/syntax/test_send.py | 12 ++++----- tests/functional/syntax/test_slice.py | 4 +-- tests/functional/syntax/test_structs.py | 3 +-- tests/functional/syntax/test_ternary.py | 4 +-- tests/functional/syntax/test_unary.py | 4 +-- .../unit/cli/vyper_json/test_compile_json.py | 6 ++--- .../semantics/analysis/test_array_index.py | 3 +-- vyper/semantics/analysis/utils.py | 3 +-- 42 files changed, 163 insertions(+), 180 deletions(-) diff --git a/tests/functional/builtins/codegen/test_minmax_value.py b/tests/functional/builtins/codegen/test_minmax_value.py index 033381f59f..c5ee5c3584 100644 --- a/tests/functional/builtins/codegen/test_minmax_value.py +++ b/tests/functional/builtins/codegen/test_minmax_value.py @@ -1,6 +1,6 @@ import pytest -from vyper.exceptions import InvalidType, OverflowException +from vyper.exceptions import OverflowException, TypeMismatch from vyper.semantics.types import DecimalT, IntegerT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -39,12 +39,12 @@ def foo(): if typ == UINT256_T: assert_compile_failed(lambda: get_contract(upper), OverflowException) else: - assert_compile_failed(lambda: get_contract(upper), InvalidType) + assert_compile_failed(lambda: get_contract(upper), TypeMismatch) if typ == INT256_T: assert_compile_failed(lambda: get_contract(lower), OverflowException) else: - assert_compile_failed(lambda: get_contract(lower), InvalidType) + assert_compile_failed(lambda: get_contract(lower), TypeMismatch) @pytest.mark.parametrize("typ", [DecimalT()]) diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index b75b5da89b..b30a94502d 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -3,7 +3,7 @@ from vyper import compile_code from vyper.builtins.functions import eip1167_bytecode -from vyper.exceptions import ArgumentException, InvalidType, StateAccessViolation +from vyper.exceptions import ArgumentException, StateAccessViolation, TypeMismatch pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -628,7 +628,7 @@ def foo(_addr: address): def foo(_addr: address): raw_call(_addr, 256) """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index c954380def..bc38b9fb1a 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -3,7 +3,7 @@ from hypothesis import strategies as st from tests.utils import parse_and_fold -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch @pytest.mark.fuzzing @@ -34,7 +34,7 @@ def test_abs_upper_bound_folding(get_contract, a): def foo(a: int256) -> int256: return abs({a}) """ - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): get_contract(source) @@ -56,5 +56,5 @@ def test_abs_lower_bound_folded(get_contract, tx_failed): def foo() -> int256: return abs(min_value(int256)) """ - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): get_contract(source) diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 240ccb3bb1..4153c7188e 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -3,9 +3,9 @@ from vyper.compiler import compile_code from vyper.exceptions import ( InvalidLiteral, - InvalidType, NonPayableViolation, StateAccessViolation, + TypeMismatch, UndeclaredDefinition, ) @@ -404,7 +404,7 @@ def foo(xx: int128, y: int128 = xx): pass @external def foo(a: uint256 = -1): pass """, - InvalidType, + TypeMismatch, ), ( """ @@ -412,7 +412,7 @@ def foo(a: uint256 = -1): pass @external def foo(a: int128 = 170141183460469231731687303715884105728): pass """, - InvalidType, + TypeMismatch, ), ( """ @@ -420,7 +420,7 @@ def foo(a: int128 = 170141183460469231731687303715884105728): pass @external def foo(a: uint256[2] = [13, -42]): pass """, - InvalidType, + TypeMismatch, ), ( """ @@ -428,7 +428,7 @@ def foo(a: uint256[2] = [13, -42]): pass @external def foo(a: int128[2] = [12, 170141183460469231731687303715884105728]): pass """, - InvalidType, + TypeMismatch, ), ( """ @@ -444,7 +444,7 @@ def foo(a: uint256[2] = [12, True]): pass @external def foo(a: uint256[2] = [1, 2, 3]): pass """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index afbcf4027b..266555ead6 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -159,5 +159,5 @@ def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation): def getTimeAndBalance() -> (bool, address): return block.timestamp, self.balance """ - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): compile_code(code) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 7f5658e485..36252701c4 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -790,7 +790,7 @@ def test_for() -> int128: a = i return a """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/codegen/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py index 9af7058250..aebb13eefa 100644 --- a/tests/functional/codegen/features/test_assignment.py +++ b/tests/functional/codegen/features/test_assignment.py @@ -207,7 +207,7 @@ def foo2() -> uint256: x += 1 return x """ - assert_compile_failed(lambda: get_contract_with_gas_estimation(code), InvalidType) + assert_compile_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) def test_invalid_uin256_assignment_calculate_literals(get_contract_with_gas_estimation): diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index ba09be1991..0cb8ad9abc 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -565,7 +565,7 @@ def foo_(): log MyLog(b'yo') """ - with tx_failed(InvalidType): + with tx_failed(TypeMismatch): get_contract_with_gas_estimation(loggy_code) @@ -580,7 +580,7 @@ def foo(): log MyLog(b'bars') """ - with tx_failed(InvalidType): + with tx_failed(TypeMismatch): get_contract_with_gas_estimation(loggy_code) @@ -608,7 +608,7 @@ def foo(): log MyLog(b'bars') """ - with tx_failed(InvalidType): + with tx_failed(TypeMismatch): get_contract_with_gas_estimation(loggy_code) @@ -1241,7 +1241,7 @@ def foo(): def foo(): raw_log([1, 2], b"moo") """, - InvalidType, + TypeMismatch, ), ( """ @@ -1249,7 +1249,7 @@ def foo(): def foo(): raw_log([1, 2], b"moo") """, - InvalidType, + TypeMismatch, ), ( """ @@ -1266,7 +1266,7 @@ def foo(): def foo(): raw_log([b"cow"], b"dog") """, - (StructureException, InvalidType), + (StructureException, TypeMismatch), ), ( """ @@ -1275,7 +1275,7 @@ def foo(): # bytes20 instead of bytes32 raw_log([], 0x1234567890123456789012345678901234567890) """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index a10eaee408..e646a25354 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -5,7 +5,12 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException +from vyper.exceptions import ( + InvalidOperation, + OverflowException, + TypeMismatch, + ZeroDivisionException, +) from vyper.semantics.types import IntegerT from vyper.utils import evm_div, evm_mod @@ -214,7 +219,7 @@ def num_sub() -> {typ}: return 1-2**{typ.bits} """ - exc = OverflowException if typ.bits == 256 else InvalidType + exc = OverflowException if typ.bits == 256 else TypeMismatch with pytest.raises(exc): compile_code(code) @@ -331,7 +336,7 @@ def foo() -> {typ}: get_contract(code_2).foo(x) with tx_failed(): get_contract(code_3).foo(y) - with pytest.raises((InvalidType, OverflowException)): + with pytest.raises((TypeMismatch, OverflowException)): compile_code(code_4) @@ -430,5 +435,5 @@ def test_binop_nested_intermediate_underflow(): def foo(): a: int256 = -2**255 * 2 - 10 + 100 """ - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): compile_code(code) diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index f10e861689..3f3fa32aba 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -5,7 +5,12 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidOperation, InvalidType, OverflowException, ZeroDivisionException +from vyper.exceptions import ( + InvalidOperation, + OverflowException, + TypeMismatch, + ZeroDivisionException, +) from vyper.semantics.types import IntegerT from vyper.utils import SizeLimits, evm_div, evm_mod @@ -164,7 +169,7 @@ def foo() -> {typ}: get_contract(code_2).foo(x) with tx_failed(): get_contract(code_3).foo(y) - with pytest.raises((InvalidType, OverflowException)): + with pytest.raises((TypeMismatch, OverflowException)): get_contract(code_4) @@ -223,7 +228,7 @@ def test() -> {typ}: for val in bad_cases: exc = ( - InvalidType + TypeMismatch if SizeLimits.MIN_INT256 <= val <= SizeLimits.MAX_UINT256 else OverflowException ) diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 882629de65..325f9d7923 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -1,6 +1,6 @@ import pytest -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch def test_test_bytes(get_contract_with_gas_estimation, tx_failed): @@ -310,7 +310,7 @@ def assign(): def assign(): xs: bytes6 = b"abcdef" """, - InvalidType, + TypeMismatch, ), ( """ @@ -318,7 +318,7 @@ def assign(): def assign(): xs: bytes4 = 0xabcdef # bytes3 literal """, - InvalidType, + TypeMismatch, ), ( """ @@ -326,7 +326,7 @@ def assign(): def assign(): xs: bytes4 = 0x1234abcdef # bytes5 literal """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index e47eda6042..d3d945740b 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -7,7 +7,6 @@ ArgumentException, ArrayIndexException, ImmutableViolation, - InvalidType, OverflowException, StateAccessViolation, TypeMismatch, @@ -1124,7 +1123,7 @@ def foo() -> DynArray[{subtyp}, 3]: x.append({lit}) return x """ - assert_compile_failed(lambda: get_contract(code), InvalidType) + assert_compile_failed(lambda: get_contract(code), TypeMismatch) invalid_appends_pops = [ diff --git a/tests/functional/syntax/exceptions/test_invalid_type_exception.py b/tests/functional/syntax/exceptions/test_invalid_type_exception.py index 3f441b8a93..8ce375c58e 100644 --- a/tests/functional/syntax/exceptions/test_invalid_type_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_type_exception.py @@ -22,30 +22,12 @@ def test_unknown_type_exception(bad_code, get_contract, assert_compile_failed): invalid_list = [ - """ -@external -def foo(): - raw_log(b"cow", b"dog") - """, - """ -@external -def foo(): - xs: uint256[1] = [] - """, # Must be a literal string. """ @external def mint(_to: address, _value: uint256): assert msg.sender == self,msg.sender """, - # literal longer than event member - """ -event Foo: - message: String[1] -@external -def foo(): - log Foo("abcd") - """, # Raise reason must be string """ @external @@ -58,10 +40,6 @@ def mint(_to: address, _value: uint256): # Key of mapping must be a base type """ b: HashMap[(int128, decimal), int128] - """, - # Address literal must be checksummed - """ -a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699 """, """ x: String <= 33 diff --git a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py index 057eb94180..514f2df618 100644 --- a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py +++ b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py @@ -25,6 +25,28 @@ def foo(): b: Bytes[1] = b"\x05" x: uint256 = as_wei_value(b, "babbage") """, + """ +@external +def foo(): + raw_log(b"cow", b"dog") + """, + """ +@external +def foo(): + xs: uint256[1] = [] + """, + # literal longer than event member + """ +event Foo: + message: String[1] +@external +def foo(): + log Foo("abcd") + """, + # Address literal must be checksummed + """ +a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699 + """, ] diff --git a/tests/functional/syntax/test_abi_encode.py b/tests/functional/syntax/test_abi_encode.py index 37d15e7e56..5e0175857d 100644 --- a/tests/functional/syntax/test_abi_encode.py +++ b/tests/functional/syntax/test_abi_encode.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -36,7 +36,7 @@ def foo(x: uint256) -> Bytes[36]: def foo(x: uint256) -> Bytes[36]: return _abi_encode(x, method_id=b"abcde") """, - InvalidType, # len(method_id) must be less than 4 + TypeMismatch, # len(method_id) must be less than 4 ), ( """ @@ -44,7 +44,7 @@ def foo(x: uint256) -> Bytes[36]: def foo(x: uint256) -> Bytes[36]: return _abi_encode(x, method_id=0x1234567890) """, - InvalidType, # len(method_id) must be less than 4 + TypeMismatch, # len(method_id) must be less than 4 ), ] diff --git a/tests/functional/syntax/test_abs.py b/tests/functional/syntax/test_abs.py index 0841ff05d6..6e61c4d7d2 100644 --- a/tests/functional/syntax/test_abs.py +++ b/tests/functional/syntax/test_abs.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -12,7 +12,7 @@ def foo(): -57896044618658097711785492504343953926634992332820282019728792003956564819968 ) """, - InvalidType, + TypeMismatch, ) ] diff --git a/tests/functional/syntax/test_addmulmod.py b/tests/functional/syntax/test_addmulmod.py index 17c7b3ab8c..69bd64aaa4 100644 --- a/tests/functional/syntax/test_addmulmod.py +++ b/tests/functional/syntax/test_addmulmod.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch fail_list = [ ( # bad AST nodes given as arguments @@ -10,7 +10,7 @@ def foo() -> uint256: return uint256_addmod(1.1, 1.2, 3.0) """, - InvalidType, + TypeMismatch, ), ( # bad AST nodes given as arguments """ @@ -18,7 +18,7 @@ def foo() -> uint256: def foo() -> uint256: return uint256_mulmod(1.1, 1.2, 3.0) """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/syntax/test_ann_assign.py b/tests/functional/syntax/test_ann_assign.py index b5c1f6a752..7fdb1328c2 100644 --- a/tests/functional/syntax/test_ann_assign.py +++ b/tests/functional/syntax/test_ann_assign.py @@ -4,7 +4,7 @@ from vyper import compiler from vyper.exceptions import ( InvalidAttribute, - InvalidType, + TypeMismatch, UndeclaredDefinition, UnknownAttribute, VariableDeclarationException, @@ -41,7 +41,7 @@ def test(): def test(): a: int128 = 33.33 """, - InvalidType, + TypeMismatch, ), ( """ @@ -50,7 +50,7 @@ def data() -> int128: s: int128[5] = [1, 2, 3, 4, 5, 6] return 235357 """, - InvalidType, + TypeMismatch, ), ( """ @@ -62,7 +62,7 @@ def foo() -> int128: s: S = S({a: 1.2, b: 1}) return s.a """, - InvalidType, + TypeMismatch, ), ( """ @@ -105,7 +105,7 @@ def foo() -> bool: a: uint256 = -1 return True """, - InvalidType, + TypeMismatch, ), ( """ @@ -114,7 +114,7 @@ def foo() -> bool: a: uint256[2] = [13, -42] return True """, - InvalidType, + TypeMismatch, ), ( """ @@ -123,7 +123,7 @@ def foo() -> bool: a: int128 = 170141183460469231731687303715884105728 return True """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/syntax/test_as_wei_value.py b/tests/functional/syntax/test_as_wei_value.py index 056d0348e9..40562530d1 100644 --- a/tests/functional/syntax/test_as_wei_value.py +++ b/tests/functional/syntax/test_as_wei_value.py @@ -4,9 +4,9 @@ from vyper.exceptions import ( ArgumentException, InvalidLiteral, - InvalidType, OverflowException, StructureException, + TypeMismatch, UndeclaredDefinition, ) @@ -44,7 +44,7 @@ def foo() -> int128: def foo(): x: int128 = as_wei_value(0xf5, "szabo") """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_block.py b/tests/functional/syntax/test_block.py index 1e6bfcf0e2..8d8bffb697 100644 --- a/tests/functional/syntax/test_block.py +++ b/tests/functional/syntax/test_block.py @@ -2,7 +2,7 @@ from pytest import raises from vyper import compiler -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -11,7 +11,7 @@ def foo() -> int128[2]: return [3,block.timestamp] """, - InvalidType, + TypeMismatch, ), ( """ @@ -19,7 +19,7 @@ def foo() -> int128[2]: def foo() -> int128[2]: return [block.timestamp - block.timestamp, block.timestamp] """, - InvalidType, + TypeMismatch, ), """ @external @@ -34,7 +34,7 @@ def foo() -> decimal: def foo(): x: Bytes[10] = slice(b"cow", -1, block.timestamp) """, - InvalidType, + TypeMismatch, ), """ @external diff --git a/tests/functional/syntax/test_bool.py b/tests/functional/syntax/test_bool.py index 5388a92b95..fef40406b6 100644 --- a/tests/functional/syntax/test_bool.py +++ b/tests/functional/syntax/test_bool.py @@ -2,18 +2,15 @@ from pytest import raises from vyper import compiler -from vyper.exceptions import InvalidOperation, InvalidType, SyntaxException, TypeMismatch +from vyper.exceptions import InvalidOperation, SyntaxException, TypeMismatch fail_list = [ - ( - """ + """ @external def foo(): x: bool = True x = 5 """, - InvalidType, - ), ( """ @external @@ -22,15 +19,12 @@ def foo(): """, SyntaxException, ), - ( - """ + """ @external def foo(): x: bool = True x = 129 """, - InvalidType, - ), ( """ @external @@ -63,15 +57,12 @@ def foo(a: address) -> bool: """, InvalidOperation, ), - ( - """ + """ @external def test(a: address) -> bool: assert(a) return True """, - TypeMismatch, - ), ] diff --git a/tests/functional/syntax/test_bytes.py b/tests/functional/syntax/test_bytes.py index a7fb7e77ce..0ca3b27fee 100644 --- a/tests/functional/syntax/test_bytes.py +++ b/tests/functional/syntax/test_bytes.py @@ -1,13 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import ( - InvalidOperation, - InvalidType, - StructureException, - SyntaxException, - TypeMismatch, -) +from vyper.exceptions import InvalidOperation, StructureException, SyntaxException, TypeMismatch fail_list = [ ( @@ -64,7 +58,7 @@ def foo() -> Bytes[10]: x = 0x1234567890123456789012345678901234567890 return x """, - InvalidType, + TypeMismatch, ), ( """ @@ -72,7 +66,7 @@ def foo() -> Bytes[10]: def foo() -> Bytes[10]: return "badmintonzz" """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_chainid.py b/tests/functional/syntax/test_chainid.py index 2b6e08cbc4..ff8473f1a2 100644 --- a/tests/functional/syntax/test_chainid.py +++ b/tests/functional/syntax/test_chainid.py @@ -3,7 +3,7 @@ from vyper import compiler from vyper.compiler.settings import Settings from vyper.evm.opcodes import EVM_VERSIONS -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -25,7 +25,7 @@ def foo(): def foo() -> int128[2]: return [3,chain.id] """, - InvalidType, + TypeMismatch, ), """ @external @@ -60,7 +60,7 @@ def add_record(): def foo(inp: Bytes[10]) -> Bytes[3]: return slice(inp, chain.id, -3) """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/syntax/test_concat.py b/tests/functional/syntax/test_concat.py index e128e7c6ae..8431e5ecf2 100644 --- a/tests/functional/syntax/test_concat.py +++ b/tests/functional/syntax/test_concat.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import ArgumentException, InvalidType, TypeMismatch +from vyper.exceptions import ArgumentException, TypeMismatch fail_list = [ ( @@ -18,7 +18,7 @@ def cat(i1: Bytes[10], i2: Bytes[30]) -> Bytes[40]: def cat(i1: Bytes[10], i2: Bytes[30]) -> Bytes[40]: return concat(i1, 5) """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 5a0bbdb2b5..57922f28e2 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -5,11 +5,11 @@ from vyper.exceptions import ( ArgumentException, ImmutableViolation, - InvalidType, NamespaceCollision, StateAccessViolation, StructureException, SyntaxException, + TypeMismatch, VariableDeclarationException, ) @@ -33,14 +33,14 @@ """ VAL: constant(uint256) = "test" """, - InvalidType, + TypeMismatch, ), # invalid range ( """ VAL: constant(uint256) = -1 """, - InvalidType, + TypeMismatch, ), # reserved keyword ( @@ -62,7 +62,7 @@ """ VAL: constant(Bytes[4]) = b"testtest" """, - InvalidType, + TypeMismatch, ), # global with same name ( diff --git a/tests/functional/syntax/test_extract32.py b/tests/functional/syntax/test_extract32.py index b04c8b8742..caec38e5d1 100644 --- a/tests/functional/syntax/test_extract32.py +++ b/tests/functional/syntax/test_extract32.py @@ -34,7 +34,7 @@ def foo(inp: Bytes[32]) -> int128: def foo(inp: Bytes[32]) -> int128: return extract32(inp, -1, output_type=int128) """, - InvalidType, # `start` cannot be negative + TypeMismatch, # `start` cannot be negative ), ( """ diff --git a/tests/functional/syntax/test_invalids.py b/tests/functional/syntax/test_invalids.py index 33478fcff1..dfc74fc75b 100644 --- a/tests/functional/syntax/test_invalids.py +++ b/tests/functional/syntax/test_invalids.py @@ -4,7 +4,6 @@ from vyper.exceptions import ( FunctionDeclarationException, InvalidOperation, - InvalidType, StructureException, TypeMismatch, UndeclaredDefinition, @@ -67,7 +66,7 @@ def foo(): x: int128 = 5 x = 0x1234567890123456789012345678901234567890 """, - InvalidType, + TypeMismatch, ) must_fail( @@ -77,7 +76,7 @@ def foo(): x: int128 = 5 x = 3.5 """, - InvalidType, + TypeMismatch, ) must_succeed( @@ -105,7 +104,7 @@ def foo(): def foo(): self.b = 7.5 """, - InvalidType, + TypeMismatch, ) must_succeed( @@ -133,7 +132,7 @@ def foo(): def foo(): self.b = 7 """, - InvalidType, + TypeMismatch, ) must_succeed( @@ -152,7 +151,7 @@ def foo(): def foo(): x: int128 = self.b[-5] """, - InvalidType, + TypeMismatch, ) must_fail( @@ -162,7 +161,7 @@ def foo(): def foo(): x: int128 = self.b[5.7] """, - InvalidType, + TypeMismatch, ) must_succeed( @@ -181,7 +180,7 @@ def foo(): def foo(): self.b[3] = 5.6 """, - InvalidType, + TypeMismatch, ) must_succeed( @@ -236,7 +235,7 @@ def foo(): def foo(): self.bar = 5 """, - InvalidType, + TypeMismatch, ) must_succeed( @@ -254,7 +253,7 @@ def foo(): def foo() -> address: return [1, 2, 3] """, - InvalidType, + TypeMismatch, ) must_fail( diff --git a/tests/functional/syntax/test_keccak256.py b/tests/functional/syntax/test_keccak256.py index 70d33edcf2..68253c8121 100644 --- a/tests/functional/syntax/test_keccak256.py +++ b/tests/functional/syntax/test_keccak256.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidType, UndeclaredDefinition +from vyper.exceptions import TypeMismatch, UndeclaredDefinition type_fail_list = [ """ @@ -14,7 +14,7 @@ def foo(): @pytest.mark.parametrize("bad_code", type_fail_list) def test_block_type_fail(bad_code): - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): compiler.compile_code(bad_code) diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index 3936f8c220..e55b060542 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidLiteral, InvalidType, StructureException, TypeMismatch +from vyper.exceptions import InvalidLiteral, StructureException, TypeMismatch fail_list = [ ( @@ -11,7 +11,7 @@ def foo(): x: int128[3] = [1, 2, 3] x = 4 """, - InvalidType, + TypeMismatch, ), ( """ @@ -20,7 +20,7 @@ def foo(): x: int128[3] = [1, 2, 3] x = [4, 5, 6, 7] """, - InvalidType, + TypeMismatch, ), ( """ @@ -28,7 +28,7 @@ def foo(): def foo() -> int128[2]: return [3, 5, 7] """, - InvalidType, + TypeMismatch, ), ( """ @@ -36,7 +36,7 @@ def foo() -> int128[2]: def foo() -> int128[2]: return [3] """, - InvalidType, + TypeMismatch, ), ( """ @@ -94,7 +94,7 @@ def foo(): def foo(): self.bar = [] """, - InvalidType, + TypeMismatch, ), ( """ @@ -121,7 +121,7 @@ def foo(): def foo(): self.bar = 5 """, - InvalidType, + TypeMismatch, ), ( """ @@ -130,7 +130,7 @@ def foo(): def foo(): self.bar = [2, 5] """, - InvalidType, + TypeMismatch, ), ( """ @@ -139,7 +139,7 @@ def foo(): def foo(): self.bar = [1, 2, 3, 4] """, - InvalidType, + TypeMismatch, ), ( """ @@ -148,7 +148,7 @@ def foo(): def foo(): self.bar = [1, 2] """, - InvalidType, + TypeMismatch, ), ( """ @@ -157,7 +157,7 @@ def foo(): def foo(): self.b[0] = 7.5 """, - InvalidType, + TypeMismatch, ), ( """ @@ -176,7 +176,7 @@ def foo()->bool[2]: a[0] = 1 return a """, - InvalidType, + TypeMismatch, ), ( """ @@ -186,7 +186,7 @@ def foo()->bool[2]: a[0] = 1 return a """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_logging.py b/tests/functional/syntax/test_logging.py index 2dd21e7a92..edc728bd89 100644 --- a/tests/functional/syntax/test_logging.py +++ b/tests/functional/syntax/test_logging.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidType, StructureException, TypeMismatch +from vyper.exceptions import StructureException, TypeMismatch fail_list = [ """ @@ -23,8 +23,7 @@ def foo(): x: decimal[4] = [0.0, 0.0, 0.0, 0.0] log Bar(x) """, - ( - """ + """ event Test: n: uint256 @@ -32,19 +31,13 @@ def foo(): def test(): log Test(-7) """, - InvalidType, - ), ] @pytest.mark.parametrize("bad_code", fail_list) def test_logging_fail(bad_code): - if isinstance(bad_code, tuple): - with pytest.raises(bad_code[1]): - compiler.compile_code(bad_code[0]) - else: - with pytest.raises(TypeMismatch): - compiler.compile_code(bad_code) + with pytest.raises(TypeMismatch): + compiler.compile_code(bad_code) @pytest.mark.parametrize("mutability", ["@pure", "@view"]) diff --git a/tests/functional/syntax/test_minmax.py b/tests/functional/syntax/test_minmax.py index 78ee74635c..0c4cc287b8 100644 --- a/tests/functional/syntax/test_minmax.py +++ b/tests/functional/syntax/test_minmax.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType, OverflowException, TypeMismatch +from vyper.exceptions import OverflowException, TypeMismatch fail_list = [ ( @@ -10,7 +10,7 @@ def foo(): y: int128 = min(7, 0x1234567890123456789012345678901234567890) """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_nested_list.py b/tests/functional/syntax/test_nested_list.py index a5f01274cf..dbc411b495 100644 --- a/tests/functional/syntax/test_nested_list.py +++ b/tests/functional/syntax/test_nested_list.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidLiteral, InvalidType, TypeMismatch +from vyper.exceptions import InvalidLiteral, TypeMismatch fail_list = [ ( @@ -11,7 +11,7 @@ def foo(): self.bar = [[1, 2], [3, 4, 5], [6, 7, 8]] """, - InvalidType, # casting darray to sarray + TypeMismatch, # casting darray to sarray ), ( """ @@ -28,7 +28,7 @@ def foo(): def foo() -> int128[2]: return [[1,2],[3,4]] """, - InvalidType, + TypeMismatch, ), ( """ @@ -36,7 +36,7 @@ def foo() -> int128[2]: def foo() -> int128[2][2]: return [1,2] """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_powmod.py b/tests/functional/syntax/test_powmod.py index 12ea23152c..da0b552d85 100644 --- a/tests/functional/syntax/test_powmod.py +++ b/tests/functional/syntax/test_powmod.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -10,7 +10,7 @@ def foo(): a: uint256 = pow_mod256(-1, -1) """, - InvalidType, + TypeMismatch, ) ] diff --git a/tests/functional/syntax/test_selfdestruct.py b/tests/functional/syntax/test_selfdestruct.py index 8f80a56ce1..9f55dca56b 100644 --- a/tests/functional/syntax/test_selfdestruct.py +++ b/tests/functional/syntax/test_selfdestruct.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch fail_list = [ """ @@ -14,7 +14,7 @@ def foo(): @pytest.mark.parametrize("bad_code", fail_list) def test_block_fail(bad_code): - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): compiler.compile_code(bad_code) diff --git a/tests/functional/syntax/test_send.py b/tests/functional/syntax/test_send.py index 15ec19f770..ffad1b3792 100644 --- a/tests/functional/syntax/test_send.py +++ b/tests/functional/syntax/test_send.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -10,7 +10,7 @@ def foo(): send(1, 2) """, - InvalidType, + TypeMismatch, ), ( """ @@ -18,7 +18,7 @@ def foo(): def foo(): send(0x1234567890123456789012345678901234567890, 2.5) """, - InvalidType, + TypeMismatch, ), ( """ @@ -26,7 +26,7 @@ def foo(): def foo(): send(0x1234567890123456789012345678901234567890, 0x1234567890123456789012345678901234567890) """, - InvalidType, + TypeMismatch, ), ( """ @@ -65,7 +65,7 @@ def foo(): def foo(): send(0x1234567890123456789012345678901234567890, 5, gas=1.5) """, - InvalidType, + TypeMismatch, ), ( """ @@ -73,7 +73,7 @@ def foo(): def foo(): send(0x1234567890123456789012345678901234567890, 5, gas=-2) """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_slice.py b/tests/functional/syntax/test_slice.py index 8fe162fc2b..6bb666527e 100644 --- a/tests/functional/syntax/test_slice.py +++ b/tests/functional/syntax/test_slice.py @@ -1,7 +1,7 @@ import pytest from vyper import compiler -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -26,7 +26,7 @@ def foo(inp: int128) -> Bytes[3]: def foo(inp: Bytes[10]) -> Bytes[3]: return slice(inp, 4.0, 3) """, - InvalidType, + TypeMismatch, ), ] diff --git a/tests/functional/syntax/test_structs.py b/tests/functional/syntax/test_structs.py index b30f7e6098..4fad35d1d4 100644 --- a/tests/functional/syntax/test_structs.py +++ b/tests/functional/syntax/test_structs.py @@ -3,7 +3,6 @@ from vyper import compiler from vyper.exceptions import ( InstantiationException, - InvalidType, StructureException, TypeMismatch, UnknownAttribute, @@ -254,7 +253,7 @@ def foo(): def foo(): self.mom = Mom({a: self.nom, b: 5.5}) """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/tests/functional/syntax/test_ternary.py b/tests/functional/syntax/test_ternary.py index 6a2bb9c072..c8b7d4e4b7 100644 --- a/tests/functional/syntax/test_ternary.py +++ b/tests/functional/syntax/test_ternary.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType, TypeMismatch +from vyper.exceptions import TypeMismatch good_list = [ # basic test @@ -73,7 +73,7 @@ def test_ternary_good(code): def foo() -> uint256: return 1 if 1 else 2 """, - InvalidType, + TypeMismatch, ), ( # bad test type: constant """ diff --git a/tests/functional/syntax/test_unary.py b/tests/functional/syntax/test_unary.py index 5942ee15db..2b2d5d8006 100644 --- a/tests/functional/syntax/test_unary.py +++ b/tests/functional/syntax/test_unary.py @@ -1,7 +1,7 @@ import pytest from vyper import compile_code -from vyper.exceptions import InvalidType +from vyper.exceptions import TypeMismatch fail_list = [ ( @@ -10,7 +10,7 @@ def foo() -> int128: return -2**127 """, - InvalidType, + TypeMismatch, ) ] diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index c805e2b5b1..0dc1b764e0 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -12,7 +12,7 @@ ) from vyper.compiler import OUTPUT_FORMATS, compile_code, compile_from_file_input from vyper.compiler.input_bundle import JSONInputBundle -from vyper.exceptions import InvalidType, JSONError, SyntaxException +from vyper.exceptions import JSONError, SyntaxException, TypeMismatch FOO_CODE = """ import contracts.ibar as IBar @@ -244,7 +244,7 @@ def test_exc_handler_to_dict_syntax(input_json): def test_exc_handler_raises_compiler(input_json): input_json["sources"]["badcode.vy"] = {"content": BAD_COMPILER_CODE} - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): compile_json(input_json) @@ -256,7 +256,7 @@ def test_exc_handler_to_dict_compiler(input_json): assert len(result["errors"]) == 1 error = result["errors"][0] assert error["component"] == "compiler" - assert error["type"] == "InvalidType" + assert error["type"] == "TypeMismatch" def test_source_ids_increment(input_json): diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index 5ea373fc19..5487a47d97 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -4,7 +4,6 @@ from vyper.exceptions import ( ArrayIndexException, InvalidReference, - InvalidType, TypeMismatch, UndeclaredDefinition, ) @@ -37,7 +36,7 @@ def foo(): self.a[{value}] = 12 """ vyper_module = parse_to_ast(code) - with pytest.raises(InvalidType): + with pytest.raises(TypeMismatch): validate_semantics(vyper_module, dummy_input_bundle) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 1c56b16020..abbf6a68cc 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -609,8 +609,7 @@ def validate_expected_type(node, expected_type): if expected_type[0] == AddressT() and given_types[0] == BytesM_T(20): suggestion_str = f" Did you mean {checksum_encode(node.value)}?" - # CMC 2022-02-14 maybe TypeMismatch would make more sense here - raise InvalidType( + raise TypeMismatch( f"Expected {expected_str} but literal can only be cast as {given_str}.{suggestion_str}", node, ) From 8ccacb3f47f864ec2ff64d5f7ca65625e9df6b2f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 10 Feb 2024 08:39:51 -0800 Subject: [PATCH 179/201] feat[lang]: singleton modules with ownership hierarchy (#3729) this commit implements "singleton modules with ownership hierarchy" as described in https://github.com/vyperlang/vyper/issues/3722. to accomplish this, two new language constructs are added: `UsesDecl` and `InitializesDecl`. these are exposed to the user as `uses:` and `initializes:`. they are also accompanied by new `AnalysisResult` data structures: `UsesInfo` and `InitializesInfo`. `uses` and `initializes` can be thought of as a constraint system on the module system. a `uses: my-module` annotation is required if `my_module`'s state is accessed (read or written), and `initializes: my_module` is required to call `my_module.__init__()`. a module can be `use`d any number of times; it can only be `initialize`d once. a module which has been used (directly, or transitively) by the compilation target (main entry point module), must be `initialize`d exactly once. `initializes:` is also required to declare which modules it has been `initialize`d with. for example, if `mod1` declares it `uses: mod2`, then any `initializes: mod1` statement must declare *which* instance of `mod2` it has been initialized with. although there is only ever a single instance of `mod2`, this user-facing requirement improves readability by forcing the user to be aware of what the state access dependencies are for a given, `initialize`d module. the `NamedExpr` node ("walrus operator") has been added to the AST to support the initializer syntax. (note: the walrus operator is used, because the originally proposed syntax, `mod1[mod2 = mod2]` is rejected by the python parser). a new compiler pass, `vyper/semantics/analysis/global.py` has been added to implement the global initializer constraint, as it cannot be defined recursively (without a global context). since `__init__()` functions can now be called from other `__init__()` functions (which is not allowed for normal `@external` functions!), a new `@deploy` visibility has been added to vyper's visibility system. `@deploy` functions can be called from other `@deploy` functions, and never from `@external` or `@internal` functions. they also have special treatment in the ABI relative to other `@external` functions. `initializes:` is useful since it also serves the purpose of being a storage allocator directive. wherever `initializes:` is placed, is where the module will be placed in storage (and code, transient storage, or any other future storage locations). this commit refactors the storage allocator so that it recurses into child modules whenever it sees an `initializes:` statement. it refactors several data structures surrounding the storage allocator, including removing inheritance on the `DataPosition` data structure (which has also been renamed to `VarOffset`). some utility functions have been added for calculating the size of a given variable, which also get used in codegen (`get_element_ptr()`). additional work/refactoring in this commit: - new analysis machinery for detecting reads/writes for all `ExprInfo`s - dynamic programming on the `get_expr_info()` routine - refactoring of `visit_Expr`, which fixes call mutability analysis - move `StringEnum` back to vyper/utils.py - remove the "TYPE_DEFINITION" kludge in certain builtins, replace with usage of `TYPE_T` - improve `tag_exceptions()` formatting - remove `Context.globals`, as we rely on the results of the front-end analyser now. - remove dead variable: `Context.in_assertion` - refactor `generate_ir_for_function` into `generate_ir_for_external_function` and `generate_ir_for_internal_function` - move `get_nonreentrant_lock` to `function_definitions/common.py` - simplify layout allocation across locations into single function - add `VyperType.get_size_in()` and `VarInfo.get_size()` helper functions so we don't need to do as much switch/case in implementation functions - refactor `codegen/core.py` functions to use `VyperType.get_size()` - fix interfaces access from `.vyi` files --- examples/auctions/blind_auction.vy | 4 +- examples/auctions/simple_open_auction.vy | 4 +- examples/crowdfund.vy | 4 +- examples/factory/Exchange.vy | 4 +- examples/factory/Factory.vy | 4 +- .../market_maker/on_chain_market_maker.vy | 2 + examples/name_registry/name_registry.vy | 1 + .../safe_remote_purchase.vy | 4 +- examples/stock/company.vy | 4 +- examples/storage/advanced_storage.vy | 4 +- examples/storage/storage.vy | 6 +- examples/tokens/ERC1155ownable.vy | 5 +- examples/tokens/ERC20.vy | 4 +- examples/tokens/ERC4626.vy | 4 +- examples/tokens/ERC721.vy | 4 +- examples/voting/ballot.vy | 4 +- examples/wallet/wallet.vy | 4 +- tests/functional/builtins/codegen/test_abi.py | 4 +- .../builtins/codegen/test_abi_decode.py | 2 +- .../builtins/codegen/test_abi_encode.py | 2 +- .../functional/builtins/codegen/test_ceil.py | 4 +- .../builtins/codegen/test_concat.py | 4 +- .../builtins/codegen/test_create_functions.py | 10 +- .../builtins/codegen/test_ecrecover.py | 2 +- .../functional/builtins/codegen/test_floor.py | 4 +- .../builtins/codegen/test_raw_call.py | 2 +- .../functional/builtins/codegen/test_slice.py | 10 +- .../test_default_function.py | 2 +- .../calling_convention/test_erc20_abi.py | 2 +- .../test_external_contract_calls.py | 31 +- ...test_modifiable_external_contract_calls.py | 8 +- .../calling_convention/test_return_tuple.py | 2 +- .../features/decorators/test_payable.py | 4 +- .../features/decorators/test_private.py | 4 +- .../features/iteration/test_range_in.py | 2 +- .../codegen/features/test_bytes_map_keys.py | 12 +- .../codegen/features/test_clampers.py | 2 +- .../codegen/features/test_constructor.py | 22 +- .../codegen/features/test_immutable.py | 51 +- .../functional/codegen/features/test_init.py | 8 +- .../codegen/features/test_logging.py | 4 +- .../codegen/features/test_ternary.py | 2 +- .../codegen/integration/test_crowdfund.py | 4 +- .../codegen/integration/test_escrow.py | 2 +- .../codegen/modules/test_module_constants.py | 20 + .../codegen/modules/test_module_variables.py | 318 +++++ .../codegen/storage_variables/test_getters.py | 4 +- .../test_storage_variable.py | 2 +- tests/functional/codegen/test_interfaces.py | 12 +- tests/functional/codegen/types/test_bytes.py | 2 +- .../codegen/types/test_dynamic_array.py | 4 +- tests/functional/codegen/types/test_flag.py | 2 +- tests/functional/codegen/types/test_string.py | 2 +- .../test_safe_remote_purchase.py | 2 +- .../syntax/exceptions/test_call_violation.py | 9 + .../exceptions/test_constancy_exception.py | 59 +- .../test_function_declaration_exception.py | 10 +- .../test_instantiation_exception.py | 2 +- .../exceptions/test_invalid_reference.py | 2 +- .../exceptions/test_structure_exception.py | 6 +- .../exceptions/test_vyper_exception_pos.py | 2 +- .../syntax/modules/test_deploy_visibility.py | 27 + .../syntax/modules/test_implements.py | 51 + .../syntax/modules/test_initializers.py | 1139 +++++++++++++++++ tests/functional/syntax/test_address_code.py | 4 +- tests/functional/syntax/test_codehash.py | 2 +- tests/functional/syntax/test_constants.py | 4 +- tests/functional/syntax/test_immutables.py | 22 +- tests/functional/syntax/test_init.py | 64 + tests/functional/syntax/test_interfaces.py | 4 +- tests/functional/syntax/test_public.py | 2 +- tests/functional/syntax/test_tuple_assign.py | 2 +- tests/unit/ast/test_ast_dict.py | 10 - .../cli/storage_layout/test_storage_layout.py | 250 +++- tests/unit/compiler/asm/test_asm_optimizer.py | 22 +- tests/unit/compiler/test_bytecode_runtime.py | 2 +- tests/unit/semantics/test_storage_slots.py | 4 +- vyper/ast/__init__.py | 3 +- vyper/ast/grammar.lark | 14 +- vyper/ast/nodes.py | 105 +- vyper/ast/nodes.pyi | 35 +- vyper/ast/parse.py | 4 +- vyper/builtins/_signatures.py | 13 +- vyper/builtins/_utils.py | 6 +- vyper/builtins/functions.py | 18 +- vyper/codegen/context.py | 19 +- vyper/codegen/core.py | 61 +- vyper/codegen/expr.py | 37 +- .../codegen/function_definitions/__init__.py | 5 +- vyper/codegen/function_definitions/common.py | 120 +- .../function_definitions/external_function.py | 49 +- .../function_definitions/internal_function.py | 34 +- vyper/codegen/function_definitions/utils.py | 31 - vyper/codegen/module.py | 31 +- vyper/codegen/stmt.py | 2 +- vyper/compiler/phases.py | 27 +- vyper/evm/address_space.py | 8 - vyper/exceptions.py | 25 +- vyper/semantics/analysis/__init__.py | 2 +- vyper/semantics/analysis/base.py | 286 ++--- vyper/semantics/analysis/constant_folding.py | 2 +- vyper/semantics/analysis/data_positions.py | 221 ++-- vyper/semantics/analysis/global_.py | 80 ++ vyper/semantics/analysis/local.py | 228 +++- vyper/semantics/analysis/module.py | 265 +++- vyper/semantics/analysis/utils.py | 45 +- vyper/semantics/data_locations.py | 16 +- vyper/semantics/types/base.py | 23 +- vyper/semantics/types/function.py | 91 +- vyper/semantics/types/module.py | 94 +- vyper/semantics/types/utils.py | 16 +- vyper/utils.py | 56 +- 112 files changed, 3566 insertions(+), 845 deletions(-) create mode 100644 tests/functional/codegen/modules/test_module_variables.py create mode 100644 tests/functional/syntax/modules/test_deploy_visibility.py create mode 100644 tests/functional/syntax/modules/test_implements.py create mode 100644 tests/functional/syntax/modules/test_initializers.py create mode 100644 tests/functional/syntax/test_init.py delete mode 100644 vyper/codegen/function_definitions/utils.py create mode 100644 vyper/semantics/analysis/global_.py diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 597aed57c7..966565138f 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Blind Auction. Adapted to Vyper from [Solidity by Example](https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst#blind-auction-1) struct Bid: @@ -36,7 +38,7 @@ pendingReturns: HashMap[address, uint256] # Create a blinded auction with `_biddingTime` seconds bidding time and # `_revealTime` seconds reveal time on behalf of the beneficiary address # `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _biddingTime: uint256, _revealTime: uint256): self.beneficiary = _beneficiary self.biddingEnd = block.timestamp + _biddingTime diff --git a/examples/auctions/simple_open_auction.vy b/examples/auctions/simple_open_auction.vy index 6d5ce06f17..499e12af16 100644 --- a/examples/auctions/simple_open_auction.vy +++ b/examples/auctions/simple_open_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Open Auction # Auction params @@ -19,7 +21,7 @@ pendingReturns: public(HashMap[address, uint256]) # Create a simple auction with `_auction_start` and # `_bidding_time` seconds bidding time on behalf of the # beneficiary address `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _auction_start: uint256, _bidding_time: uint256): self.beneficiary = _beneficiary self.auctionStart = _auction_start # auction start time can be in the past, present or future diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 6d07e15bc4..50ec005924 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -11,7 +13,7 @@ goal: public(uint256) timelimit: public(uint256) # Setup global variables -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/examples/factory/Exchange.vy b/examples/factory/Exchange.vy index 77f47984bc..e66c60743a 100644 --- a/examples/factory/Exchange.vy +++ b/examples/factory/Exchange.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 @@ -9,7 +11,7 @@ token: public(ERC20) factory: Factory -@external +@deploy def __init__(_token: ERC20, _factory: Factory): self.token = _token self.factory = _factory diff --git a/examples/factory/Factory.vy b/examples/factory/Factory.vy index 50e7a81bf6..4fec723197 100644 --- a/examples/factory/Factory.vy +++ b/examples/factory/Factory.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 interface Exchange: @@ -11,7 +13,7 @@ exchange_codehash: public(bytes32) exchanges: public(HashMap[ERC20, Exchange]) -@external +@deploy def __init__(_exchange_codehash: bytes32): # Register the exchange code hash during deployment of the factory self.exchange_codehash = _exchange_codehash diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index 4f9859584c..74b1307dc1 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 diff --git a/examples/name_registry/name_registry.vy b/examples/name_registry/name_registry.vy index 7152851dac..937b41856b 100644 --- a/examples/name_registry/name_registry.vy +++ b/examples/name_registry/name_registry.vy @@ -1,3 +1,4 @@ +#pragma version >0.3.10 registry: HashMap[Bytes[100], address] diff --git a/examples/safe_remote_purchase/safe_remote_purchase.vy b/examples/safe_remote_purchase/safe_remote_purchase.vy index edc2163b85..91f0159a2d 100644 --- a/examples/safe_remote_purchase/safe_remote_purchase.vy +++ b/examples/safe_remote_purchase/safe_remote_purchase.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Safe Remote Purchase # Originally from # https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst @@ -19,7 +21,7 @@ buyer: public(address) unlocked: public(bool) ended: public(bool) -@external +@deploy @payable def __init__(): assert (msg.value % 2) == 0 diff --git a/examples/stock/company.vy b/examples/stock/company.vy index 6293e6eea4..355432830d 100644 --- a/examples/stock/company.vy +++ b/examples/stock/company.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Financial events the contract logs event Transfer: @@ -27,7 +29,7 @@ price: public(uint256) holdings: HashMap[address, uint256] # Set up the company. -@external +@deploy def __init__(_company: address, _total_shares: uint256, initial_price: uint256): assert _total_shares > 0 assert initial_price > 0 diff --git a/examples/storage/advanced_storage.vy b/examples/storage/advanced_storage.vy index 2ba50280d7..42a455cbf1 100644 --- a/examples/storage/advanced_storage.vy +++ b/examples/storage/advanced_storage.vy @@ -1,10 +1,12 @@ +#pragma version >0.3.10 + event DataChange: setter: indexed(address) value: int128 storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x diff --git a/examples/storage/storage.vy b/examples/storage/storage.vy index 7d05e4708c..30f570f212 100644 --- a/examples/storage/storage.vy +++ b/examples/storage/storage.vy @@ -1,9 +1,11 @@ +#pragma version >0.3.10 + storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x @external def set(_x: int128): - self.storedData = _x \ No newline at end of file + self.storedData = _x diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index d1e88dcd04..d88d459d64 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,9 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### -# @version >=0.3.4 """ @dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) @@ -122,7 +123,7 @@ interface IERC1155MetadataURI: ############### functions ############### -@external +@deploy def __init__(name: String[128], symbol: String[16], uri: String[MAX_URI_LENGTH], contractUri: String[MAX_URI_LENGTH]): """ @dev contract initialization on deployment diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 77550c3f5a..0e94b32b9d 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -38,7 +40,7 @@ totalSupply: public(uint256) minter: address -@external +@deploy def __init__(_name: String[32], _symbol: String[32], _decimals: uint8, _supply: uint256): init_supply: uint256 = _supply * 10 ** convert(_decimals, uint256) self.name = _name diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index 73721fdb98..699b5edd42 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy # example implementation of an ERC4626 vault @@ -50,7 +52,7 @@ event Withdraw: shares: uint256 -@external +@deploy def __init__(asset: ERC20): self.asset = asset diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index d3a8d1f13d..70dff96051 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -82,7 +84,7 @@ SUPPORTED_INTERFACES: constant(bytes4[2]) = [ 0x80ac58cd, ] -@external +@deploy def __init__(): """ @dev Contract constructor. diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 107716accf..daaf712e0f 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Voting with delegation. # Information about voters @@ -50,7 +52,7 @@ def directlyVoted(addr: address) -> bool: # Setup global variables -@external +@deploy def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 231f538ecf..7e92c7e89c 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -12,7 +14,7 @@ threshold: int128 seq: public(int128) -@external +@deploy def __init__(_owners: address[5], _threshold: int128): for i: uint256 in range(5): if _owners[i] != empty(address): diff --git a/tests/functional/builtins/codegen/test_abi.py b/tests/functional/builtins/codegen/test_abi.py index 4ddfcf50c1..335f728a37 100644 --- a/tests/functional/builtins/codegen/test_abi.py +++ b/tests/functional/builtins/codegen/test_abi.py @@ -8,14 +8,14 @@ """ x: int128 -@external +@deploy def __init__(): self.x = 1 """, """ x: int128 -@external +@deploy def __init__(): pass """, diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 69bfef63ea..96cbbe4c2d 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -224,7 +224,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f4b7d57a04..8709e31470 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -263,7 +263,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_ceil.py b/tests/functional/builtins/codegen/test_ceil.py index daa9cb7c1b..191e2adfef 100644 --- a/tests/functional/builtins/codegen/test_ceil.py +++ b/tests/functional/builtins/codegen/test_ceil.py @@ -6,7 +6,7 @@ def test_ceil(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -53,7 +53,7 @@ def test_ceil_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_concat.py b/tests/functional/builtins/codegen/test_concat.py index 7354515989..37bdaaaf7b 100644 --- a/tests/functional/builtins/codegen/test_concat.py +++ b/tests/functional/builtins/codegen/test_concat.py @@ -79,7 +79,7 @@ def test_concat_buffer2(get_contract): code = """ i: immutable(int256) -@external +@deploy def __init__(): i = -1 s: String[2] = concat("a", "b") @@ -99,7 +99,7 @@ def test_concat_buffer3(get_contract): s2: String[33] s3: String[34] -@external +@deploy def __init__(): self.s = "a" self.s2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" # 33*'a' diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index afa729ac8a..0aa718157c 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -214,7 +214,7 @@ def test_create_from_blueprint_bad_code_offset( deployer_code = """ BLUEPRINT: immutable(address) -@external +@deploy def __init__(blueprint_address: address): BLUEPRINT = blueprint_address @@ -269,7 +269,7 @@ def test_create_from_blueprint_args( FOO: immutable(String[128]) BAR: immutable(Bar) -@external +@deploy def __init__(foo: String[128], bar: Bar): FOO = foo BAR = bar @@ -450,7 +450,7 @@ def test_create_from_blueprint_complex_value( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -507,7 +507,7 @@ def test_create_from_blueprint_complex_salt_raw_args( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -565,7 +565,7 @@ def test_create_from_blueprint_complex_salt_no_constructor_args( code = """ var: uint256 -@external +@deploy @payable def __init__(): self.var = 12 diff --git a/tests/functional/builtins/codegen/test_ecrecover.py b/tests/functional/builtins/codegen/test_ecrecover.py index 8571948c3d..ce24868afe 100644 --- a/tests/functional/builtins/codegen/test_ecrecover.py +++ b/tests/functional/builtins/codegen/test_ecrecover.py @@ -68,7 +68,7 @@ def test_invalid_signature2(get_contract): owner: immutable(address) -@external +@deploy def __init__(): owner = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf diff --git a/tests/functional/builtins/codegen/test_floor.py b/tests/functional/builtins/codegen/test_floor.py index d2fd993785..5caffd5551 100644 --- a/tests/functional/builtins/codegen/test_floor.py +++ b/tests/functional/builtins/codegen/test_floor.py @@ -6,7 +6,7 @@ def test_floor(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -55,7 +55,7 @@ def test_floor_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index b30a94502d..e5201e9bb2 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -137,7 +137,7 @@ def set_owner(i: int128, o: address): owners: public(address[5]) -@external +@deploy def __init__(_owner_setter: address): self.owner_setter_contract = _owner_setter diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 80936bbf82..0c5a8fc485 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -57,7 +57,7 @@ def test_slice_immutable( IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) IMMUTABLE_SLICE: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(inp: Bytes[{length_bound}], start: uint256, length: uint256): IMMUTABLE_BYTES = inp IMMUTABLE_SLICE = slice(IMMUTABLE_BYTES, {_start}, {_length}) @@ -119,7 +119,7 @@ def test_slice_bytes_fuzz( elif location == "code": preamble = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(foo: Bytes[{length_bound}]): IMMUTABLE_BYTES = foo """ @@ -230,7 +230,7 @@ def test_slice_immutable_length_arg(get_contract_with_gas_estimation): code = """ LENGTH: immutable(uint256) -@external +@deploy def __init__(): LENGTH = 5 @@ -314,7 +314,7 @@ def f() -> bytes32: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f @@ -325,7 +325,7 @@ def bar() -> Bytes[{length}]: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index cf55607877..411f38eac9 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -2,7 +2,7 @@ def test_throw_on_sending(w3, tx_failed, get_contract_with_gas_estimation): code = """ x: public(int128) -@external +@deploy def __init__(): self.x = 123 """ diff --git a/tests/functional/codegen/calling_convention/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py index b9dc5c663f..59c4131fb2 100644 --- a/tests/functional/codegen/calling_convention/test_erc20_abi.py +++ b/tests/functional/codegen/calling_convention/test_erc20_abi.py @@ -33,7 +33,7 @@ def allowance(_owner: address, _spender: address) -> uint256: nonpayable token_address: ERC20Contract -@external +@deploy def __init__(token_addr: address): self.token_address = ERC20Contract(token_addr) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index a7cf4d0ecf..8b3f30b5a5 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -41,7 +41,7 @@ def test_complicated_external_contract_calls(get_contract, get_contract_with_gas contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky @@ -898,26 +898,31 @@ def set_lucky(arg1: address, arg2: int128): print("Successfully executed an external contract call state change") -def test_constant_external_contract_call_cannot_change_state( - assert_compile_failed, get_contract_with_gas_estimation -): +def test_constant_external_contract_call_cannot_change_state(): c = """ interface Foo: def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_expr(arg1: address, arg2: int128): +def set_lucky_stmt(arg1: address, arg2: int128): Foo(arg1).set_lucky(arg2) + """ + with pytest.raises(StateAccessViolation): + compile_code(c) + + c2 = """ +interface Foo: + def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_stmt(arg1: address, arg2: int128) -> int128: +def set_lucky_expr(arg1: address, arg2: int128) -> int128: return Foo(arg1).set_lucky(arg2) """ - assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation) - print("Successfully blocked an external contract call from a constant function") + with pytest.raises(StateAccessViolation): + compile_code(c2) def test_external_contract_can_be_changed_based_on_address(get_contract): @@ -968,7 +973,7 @@ def test_external_contract_calls_with_public_globals(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -994,7 +999,7 @@ def test_external_contract_calls_with_multiple_contracts(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -1008,7 +1013,7 @@ def lucky() -> int128: view magic_number: public(int128) -@external +@deploy def __init__(arg1: address): self.magic_number = Foo(arg1).lucky() """ @@ -1020,7 +1025,7 @@ def magic_number() -> int128: view best_number: public(int128) -@external +@deploy def __init__(arg1: address): self.best_number = Bar(arg1).magic_number() """ @@ -1145,7 +1150,7 @@ def test_invalid_contract_reference_declaration(tx_failed, get_contract): best_number: public(int128) -@external +@deploy def __init__(): pass """ diff --git a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py index e6b2402016..aa7130fd6a 100644 --- a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py @@ -20,7 +20,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -64,7 +64,7 @@ def set_lucky(_lucky: int128) -> int128: view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -108,7 +108,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -134,7 +134,7 @@ def static_set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index 266555ead6..74929c9496 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -16,7 +16,7 @@ def test_return_type(get_contract_with_gas_estimation): c: int128 chunk: Chunk -@external +@deploy def __init__(): self.chunk.a = b"hello" self.chunk.b = b"world" diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index ced58e1af0..955501a0e6 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -122,7 +122,7 @@ def bar() -> bool: """, """ # payable init function -@external +@deploy @payable def __init__(): a: int128 = 1 @@ -279,7 +279,7 @@ def baz() -> bool: """, """ # init function -@external +@deploy def __init__(): a: int128 = 1 diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 39ea1bb9ae..193112f02b 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -120,7 +120,7 @@ def test_private_bytes(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[100]) -@external +@deploy def __init__(): self.greeting = b"Hello " @@ -143,7 +143,7 @@ def test_private_statement(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[20]) -@external +@deploy def __init__(): self.greeting = b"Hello " diff --git a/tests/functional/codegen/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py index 7540049778..f381f60b35 100644 --- a/tests/functional/codegen/features/iteration/test_range_in.py +++ b/tests/functional/codegen/features/iteration/test_range_in.py @@ -115,7 +115,7 @@ def test_ownership(w3, tx_failed, get_contract_with_gas_estimation): owners: address[2] -@external +@deploy def __init__(): self.owners[0] = msg.sender diff --git a/tests/functional/codegen/features/test_bytes_map_keys.py b/tests/functional/codegen/features/test_bytes_map_keys.py index 4913182d52..22df767f02 100644 --- a/tests/functional/codegen/features/test_bytes_map_keys.py +++ b/tests/functional/codegen/features/test_bytes_map_keys.py @@ -80,7 +80,7 @@ def test_extended_bytes_key_from_storage(get_contract): code = """ a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"] = 1069 @@ -114,7 +114,7 @@ def test_struct_bytes_key_memory(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -145,7 +145,7 @@ def test_struct_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Foo -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -172,7 +172,7 @@ def test_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Bytes[5] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.b = b"hello" @@ -193,7 +193,7 @@ def test_bytes_key_calldata(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 @@ -215,7 +215,7 @@ def test_struct_bytes_hashmap_as_key_in_other_hashmap(get_contract): bar: public(HashMap[uint256, Thing]) foo: public(HashMap[Bytes[64], uint256]) -@external +@deploy def __init__(): self.foo[b"hello"] = 31337 self.bar[12] = Thing({name: b"hello"}) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 6db8570fc7..c028805c6a 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -67,7 +67,7 @@ def test_bytes_clamper_on_init(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ foo: Bytes[3] -@external +@deploy def __init__(x: Bytes[3]): self.foo = x diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index c9dfcfc5df..9146ace8a6 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -6,7 +6,7 @@ def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ moose: int128 -@external +@deploy def __init__(_moose: int128): self.moose = _moose @@ -26,7 +26,7 @@ def test_constructor_mapping(get_contract_with_gas_estimation): X: constant(bytes4) = 0x01ffc9a7 -@external +@deploy def __init__(): self.foo[X] = True @@ -44,7 +44,7 @@ def test_constructor_advanced_code(get_contract_with_gas_estimation): constructor_advanced_code = """ twox: int128 -@external +@deploy def __init__(x: int128): self.twox = x * 2 @@ -60,7 +60,7 @@ def test_constructor_advanced_code2(get_contract_with_gas_estimation): constructor_advanced_code2 = """ comb: uint256 -@external +@deploy def __init__(x: uint256[2], y: Bytes[3], z: uint256): self.comb = x[0] * 1000 + x[1] * 100 + len(y) * 10 + z @@ -90,7 +90,7 @@ def foo(x: int128) -> int128: def test_large_input_code_2(w3, get_contract_with_gas_estimation): large_input_code_2 = """ -@external +@deploy def __init__(x: int128): y: int128 = x @@ -113,7 +113,7 @@ def test_initialise_array_with_constant_key(get_contract_with_gas_estimation): foo: int16[X] -@external +@deploy def __init__(): self.foo[X-1] = -2 @@ -133,7 +133,7 @@ def test_initialise_dynarray_with_constant_key(get_contract_with_gas_estimation) foo: DynArray[int16, X] -@external +@deploy def __init__(): self.foo = [X - 3, X - 4, X - 5, X - 6] @@ -151,7 +151,7 @@ def test_nested_dynamic_array_constructor_arg(w3, get_contract_with_gas_estimati code = """ foo: uint256 -@external +@deploy def __init__(x: DynArray[DynArray[uint256, 3], 3]): self.foo = x[0][2] + x[1][1] + x[2][0] @@ -167,7 +167,7 @@ def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estima code = """ foo: int128 -@external +@deploy def __init__(x: DynArray[DynArray[DynArray[int128, 3], 3], 3]): self.foo = x[0][1][2] * x[1][1][1] * x[2][1][0] - x[0][0][0] - x[1][1][1] - x[2][2][2] @@ -192,7 +192,7 @@ def test_initialise_nested_dynamic_array(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[uint256, 3], 3] -@external +@deploy def __init__(x: uint256, y: uint256, z: uint256): self.foo = [ [x, y, z], @@ -212,7 +212,7 @@ def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation) code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] -@external +@deploy def __init__(x: int128, y: int128, z: int128): self.foo = [ [[x, y, z], [y, z, x], [z, y, x]], diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 47f7fc748e..d0bc47c238 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -20,7 +20,7 @@ def test_value_storage_retrieval(typ, value, get_contract): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -41,7 +41,7 @@ def test_usage_in_constructor(get_contract, val): a: public(uint256) -@external +@deploy def __init__(_a: uint256): A = _a self.a = A @@ -63,7 +63,7 @@ def test_multiple_immutable_values(get_contract): b: immutable(address) c: immutable(String[64]) -@external +@deploy def __init__(_a: uint256, _b: address, _c: String[64]): a = _a b = _b @@ -89,7 +89,7 @@ def test_struct_immutable(get_contract): my_struct: immutable(MyStruct) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: address, _d: int256): my_struct = MyStruct({ a: _a, @@ -108,11 +108,34 @@ def get_my_struct() -> MyStruct: assert c.get_my_struct() == values +def test_complex_immutable_modifiable(get_contract): + code = """ +struct MyStruct: + a: uint256 + +my_struct: immutable(MyStruct) + +@deploy +def __init__(a: uint256): + my_struct = MyStruct({a: a}) + + # struct members are modifiable after initialization + my_struct.a += 1 + +@view +@external +def get_my_struct() -> MyStruct: + return my_struct + """ + c = get_contract(code, 1) + assert c.get_my_struct() == (2,) + + def test_list_immutable(get_contract): code = """ my_list: immutable(uint256[3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -130,7 +153,7 @@ def test_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[uint256, 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -154,7 +177,7 @@ def test_nested_dynarray_immutable_2(get_contract): code = """ my_list: immutable(DynArray[DynArray[uint256, 3], 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [[_a, _b, _c], [_b, _a, _c], [_c, _b, _a]] @@ -179,7 +202,7 @@ def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) -@external +@deploy def __init__(x: int128, y: int128, z: int128): my_list = [ [[x, y, z], [y, z, x], [z, y, x]], @@ -227,7 +250,7 @@ def foo() -> uint256: counter: uint256 VALUE: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.counter = x self.foo() @@ -257,7 +280,7 @@ def foo() -> uint256: b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -281,7 +304,7 @@ def test_immutables_initialized2(get_contract, get_contract_from_ir): b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -299,7 +322,7 @@ def test_internal_functions_called_by_ctor_location(get_contract): d: uint256 x: immutable(uint256) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -323,7 +346,7 @@ def test_nested_internal_function_immutables(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -348,7 +371,7 @@ def test_immutable_read_ctor_and_runtime(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index fc765f8ab3..84d224f632 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -5,7 +5,7 @@ def test_basic_init_function(get_contract): code = """ val: public(uint256) -@external +@deploy def __init__(a: uint256): self.val = a """ @@ -27,10 +27,12 @@ def __init__(a: uint256): def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): code = """ foo: public(uint8) + @internal def bar(x: uint256) -> uint8: return convert(x, uint8) * 7 -@external + +@deploy def __init__(a: uint256): self.foo = self.bar(a) @@ -61,7 +63,7 @@ def test_nested_internal_call_from_ctor(get_contract): code = """ x: uint256 -@external +@deploy def __init__(): self.a() diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 0cb8ad9abc..8b80811d02 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -646,7 +646,7 @@ def test_logging_fails_with_over_three_topics(tx_failed, get_contract_with_gas_e arg3: indexed(int128) arg4: indexed(int128) -@external +@deploy def __init__(): log MyLog(1, 2, 3, 4) """ @@ -1033,7 +1033,7 @@ def test_mixed_var_list_packing(get_logs, get_contract_with_gas_estimation): x: int128[4] y: int128[2] -@external +@deploy def __init__(): self.y = [1024, 2048] diff --git a/tests/functional/codegen/features/test_ternary.py b/tests/functional/codegen/features/test_ternary.py index c5480286c8..661fdc86c9 100644 --- a/tests/functional/codegen/features/test_ternary.py +++ b/tests/functional/codegen/features/test_ternary.py @@ -195,7 +195,7 @@ def test_ternary_tuple(get_contract, code, test): def test_ternary_immutable(get_contract, test): code = """ IMM: public(immutable(uint256)) -@external +@deploy def __init__(test: bool): IMM = 1 if test else 2 """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 891ed5aebe..1a8b3f7e9f 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -13,7 +13,7 @@ def test_crowdfund(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit @@ -109,7 +109,7 @@ def test_crowdfund2(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/tests/functional/codegen/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py index 70e7cb4594..f86b4aa516 100644 --- a/tests/functional/codegen/integration/test_escrow.py +++ b/tests/functional/codegen/integration/test_escrow.py @@ -41,7 +41,7 @@ def test_arbitration_code_with_init(w3, tx_failed, get_contract_with_gas_estimat seller: address arbitrator: address -@external +@deploy @payable def __init__(_seller: address, _arbitrator: address): if self.buyer == empty(address): diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py index aafbb69252..ebfefb4546 100644 --- a/tests/functional/codegen/modules/test_module_constants.py +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -76,3 +76,23 @@ def foo(ix: uint256) -> uint256: assert c.foo(2) == 3 with tx_failed(): c.foo(3) + + +def test_module_constant_builtin(make_input_bundle, get_contract): + # test empty builtin, which is not (currently) foldable 2024-02-06 + mod1 = """ +X: constant(uint256) = empty(uint256) + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 0 diff --git a/tests/functional/codegen/modules/test_module_variables.py b/tests/functional/codegen/modules/test_module_variables.py new file mode 100644 index 0000000000..6bb1f9072c --- /dev/null +++ b/tests/functional/codegen/modules/test_module_variables.py @@ -0,0 +1,318 @@ +def test_simple_import(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import lib + +initializes: lib + +@external +def increment_counter(): + lib.increment_counter() + +@external +def get_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == 0 + c.increment_counter(transact={}) + assert c.get_counter() == 1 + + +def test_import_namespace(get_contract, make_input_bundle): + # test what happens when things in current and imported modules share names + lib = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: uint256 + +initializes: lib + +@external +def increment_counter(): + self.counter += 1 + +@external +def increment_lib_counter(): + lib.increment_counter() + +@external +def increment_lib_counter2(): + # modify lib.counter directly + lib.counter += 5 + +@external +def get_counter() -> uint256: + return self.counter + +@external +def get_lib_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == c.get_lib_counter() == 0 + + c.increment_counter(transact={}) + assert c.get_counter() == 1 + assert c.get_lib_counter() == 0 + + c.increment_lib_counter(transact={}) + assert c.get_lib_counter() == 1 + assert c.get_counter() == 1 + + c.increment_lib_counter2(transact={}) + assert c.get_lib_counter() == 6 + assert c.get_counter() == 1 + + +def test_init_function_side_effects(get_contract, make_input_bundle): + lib = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: public(uint256) + +MY_IMMUTABLE: public(immutable(uint256)) + +initializes: lib + +@deploy +def __init__(): + self.counter = 1 + MY_IMMUTABLE = 3 + lib.__init__(5) + +@external +def get_lib_counter() -> uint256: + return lib.counter + +@external +def get_lib_immutable() -> uint256: + return lib.MY_IMMUTABLE + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.counter() == 1 + assert c.MY_IMMUTABLE() == 3 + assert c.get_lib_counter() == 5 + assert c.get_lib_immutable() == 10 + + +def test_indirect_variable_uses(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +uses: lib1 + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2[lib1 := lib1] + +@deploy +def __init__(): + lib1.__init__(5) + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 5 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 10 + + +def test_uses_already_initialized(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value * 2 + MY_IMMUTABLE = initial_value * 3 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__(5) + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +uses: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib2.__init__() + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 10 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 15 + + +def test_import_complex_types(get_contract, make_input_bundle): + lib1 = """ +an_array: uint256[3] +a_hashmap: HashMap[address, HashMap[uint256, uint256]] + +@internal +def set_array_value(ix: uint256, new_value: uint256): + self.an_array[ix] = new_value + +@internal +def set_hashmap_value(ix0: address, ix1: uint256, new_value: uint256): + self.a_hashmap[ix0][ix1] = new_value + """ + + contract = """ +import lib + +initializes: lib + +@external +def do_things(): + lib.set_array_value(1, 5) + lib.set_hashmap_value(msg.sender, 6, 100) + +@external +def get_array_value(ix: uint256) -> uint256: + return lib.an_array[ix] + +@external +def get_hashmap_value(ix: uint256) -> uint256: + return lib.a_hashmap[msg.sender][ix] + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + c.do_things(transact={}) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + assert c.get_array_value(1) == 5 + assert c.get_hashmap_value(6) == 100 diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index a2d9c6d0bb..9e72bed075 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -41,7 +41,7 @@ def foo(): nonpayable f: public(constant(uint256[2])) = [3, 7] g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) -@external +@deploy def __init__(): self.x = as_wei_value(7, "wei") self.y[1] = 9 @@ -87,7 +87,7 @@ def test_getter_mutability(get_contract): nyoro: public(constant(uint256)) = 2 kune: public(immutable(uint256)) -@external +@deploy def __init__(): kune = 2 """ diff --git a/tests/functional/codegen/storage_variables/test_storage_variable.py b/tests/functional/codegen/storage_variables/test_storage_variable.py index 4636fa77e0..7a22d35e4b 100644 --- a/tests/functional/codegen/storage_variables/test_storage_variable.py +++ b/tests/functional/codegen/storage_variables/test_storage_variable.py @@ -10,7 +10,7 @@ def test_permanent_variables_test(get_contract_with_gas_estimation): b: int128 var: Var -@external +@deploy def __init__(a: int128, b: int128): self.var.a = a self.var.b = b diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 3344ff113b..85efe904a0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -305,7 +305,7 @@ def test() -> uint256: view token_address: IToken -@external +@deploy def __init__(_token_address: address): self.token_address = IToken(_token_address) @@ -388,7 +388,7 @@ def transfer(to: address, amount: uint256) -> bool: token_address: ERC20 -@external +@deploy def __init__(_token_address: address): self.token_address = ERC20(_token_address) @@ -445,7 +445,7 @@ def should_fail() -> {typ}: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -501,7 +501,7 @@ def should_fail() -> Bytes[2]: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -551,7 +551,7 @@ def foo(x: BadJSONInterface) -> Bytes[2]: foo: BadJSONInterface -@external +@deploy def __init__(addr: BadJSONInterface): self.foo = addr @@ -667,7 +667,7 @@ def foo() -> uint256: view bar_contract: Bar -@external +@deploy def __init__(): self.bar_contract = Bar(self) diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 325f9d7923..99e5835f6e 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -51,7 +51,7 @@ def test_test_bytes3(get_contract_with_gas_estimation): maa: Bytes[60] y: int128 -@external +@deploy def __init__(): self.x = 27 self.y = 37 diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index d3d945740b..fc3223caaf 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1665,7 +1665,7 @@ def ix(i: uint256) -> decimal: def test_public_dynarray(get_contract): code = """ my_list: public(DynArray[uint256, 5]) -@external +@deploy def __init__(): self.my_list = [1,2,3] """ @@ -1678,7 +1678,7 @@ def __init__(): def test_nested_public_dynarray(get_contract): code = """ my_list: public(DynArray[DynArray[uint256, 5], 5]) -@external +@deploy def __init__(): self.my_list = [[1,2,3]] """ diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index 5da6d57558..dd9c867a96 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -160,7 +160,7 @@ def test_augassign_storage(get_contract, w3, tx_failed): roles: public(HashMap[address, Roles]) -@external +@deploy def __init__(): self.roles[msg.sender] = Roles.ADMIN diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 9d50f8df38..9d596eda32 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -90,7 +90,7 @@ def test_private_string(get_contract_with_gas_estimation): private_test_code = """ greeting: public(String[100]) -@external +@deploy def __init__(): self.greeting = "Hello " diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index e21a113f61..f6eb3966d4 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -118,7 +118,7 @@ def unlocked() -> bool: view purchase_contract: PurchaseContract -@external +@deploy def __init__(_purchase_contract: address): self.purchase_contract = PurchaseContract(_purchase_contract) diff --git a/tests/functional/syntax/exceptions/test_call_violation.py b/tests/functional/syntax/exceptions/test_call_violation.py index d310a2b42a..d96df07e74 100644 --- a/tests/functional/syntax/exceptions/test_call_violation.py +++ b/tests/functional/syntax/exceptions/test_call_violation.py @@ -27,6 +27,15 @@ def goo(): def foo(): self.goo() """, + """ +@deploy +def __init__(): + pass + +@internal +def foo(): + self.__init__() + """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 7adf9538c7..6bfb8fee57 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -78,7 +78,7 @@ def foo(): """ f:int128 -@external +@internal def a (x:int128): self.f = 100 @@ -86,6 +86,63 @@ def a (x:int128): @external def b(): self.a(10)""", + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +a:DynArray[uint16,3] +@deploy +def __init__(): + self.a = [1,2,3] +@view +@external +def bar()->DynArray[uint16,3]: + x:uint16 = self.a.pop() + return self.a # return [1,2] + """, + """ +from ethereum.ercs import ERC20 + +token: ERC20 + +@external +@view +def topup(amount: uint256): + assert self.token.transferFrom(msg.sender, self, amount) + """, ], ) def test_statefulness_violations(bad_code): diff --git a/tests/functional/syntax/exceptions/test_function_declaration_exception.py b/tests/functional/syntax/exceptions/test_function_declaration_exception.py index 3fe23e0ec7..878c7f3e29 100644 --- a/tests/functional/syntax/exceptions/test_function_declaration_exception.py +++ b/tests/functional/syntax/exceptions/test_function_declaration_exception.py @@ -34,17 +34,17 @@ def test_func() -> int128: return (1, 2) """, """ -@external +@deploy def __init__(a: int128 = 12): pass """, """ -@external +@deploy def __init__() -> uint256: return 1 """, """ -@external +@deploy def __init__() -> bool: pass """, @@ -58,7 +58,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @pure def __init__(): a = 1 @@ -66,7 +66,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @view def __init__(): a = 1 diff --git a/tests/functional/syntax/exceptions/test_instantiation_exception.py b/tests/functional/syntax/exceptions/test_instantiation_exception.py index 0d641f154a..4dd0bf6e02 100644 --- a/tests/functional/syntax/exceptions/test_instantiation_exception.py +++ b/tests/functional/syntax/exceptions/test_instantiation_exception.py @@ -69,7 +69,7 @@ def foo(): """ b: immutable(HashMap[uint256, uint256]) -@external +@deploy def __init__(): b = empty(HashMap[uint256, uint256]) """, diff --git a/tests/functional/syntax/exceptions/test_invalid_reference.py b/tests/functional/syntax/exceptions/test_invalid_reference.py index fe315e5cbf..7519d1406e 100644 --- a/tests/functional/syntax/exceptions/test_invalid_reference.py +++ b/tests/functional/syntax/exceptions/test_invalid_reference.py @@ -47,7 +47,7 @@ def foo(): """ a: public(immutable(uint256)) -@external +@deploy def __init__(): a = 123 diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index c6d733fc90..afc7a35012 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -94,7 +94,7 @@ def foo(): a: immutable(uint256) n: public(HashMap[uint256, bool][a]) -@external +@deploy def __init__(): a = 3 """, @@ -105,14 +105,14 @@ def __init__(): m1: HashMap[uint8, uint8] m2: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = self.m2 """, """ m1: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = 234 """, diff --git a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py index a261cb0a11..9e0767cb83 100644 --- a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py +++ b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py @@ -22,7 +22,7 @@ def test_multiple_exceptions(get_contract, assert_compile_failed): foo: immutable(uint256) bar: immutable(uint256) -@external +@deploy def __init__(): self.foo = 1 # SyntaxException self.bar = 2 # SyntaxException diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py new file mode 100644 index 0000000000..f51bf9575b --- /dev/null +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -0,0 +1,27 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import CallViolation + + +def test_call_deploy_from_external(make_input_bundle): + lib1 = """ +@deploy +def __init__(): + pass + """ + + main = """ +import lib1 + +@external +def foo(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(CallViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value.message == "Cannot call an @deploy function from an @external function!" diff --git a/tests/functional/syntax/modules/test_implements.py b/tests/functional/syntax/modules/test_implements.py new file mode 100644 index 0000000000..c292e198d9 --- /dev/null +++ b/tests/functional/syntax/modules/test_implements.py @@ -0,0 +1,51 @@ +from vyper.compiler import compile_code + + +def test_implements_from_vyi(make_input_bundle): + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + main = """ +import lib1 + +implements: lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_implements_from_vyi2(make_input_bundle): + # test implements via nested imported vyi file + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + lib2 = """ +import lib1 + """ + main = """ +import lib2 + +implements: lib2.lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py new file mode 100644 index 0000000000..a12f5f57ea --- /dev/null +++ b/tests/functional/syntax/modules/test_initializers.py @@ -0,0 +1,1139 @@ +""" +tests for the uses/initializes checker +main properties to test: +- state usage -- if a module uses state, it must `used` or `initialized` +- conversely, if a module does not touch state, it should not be `used` +- global initializer check: each used module is `initialized` exactly once +""" + +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import ( + BorrowException, + ImmutableViolation, + InitializerException, + StructureException, + UndeclaredDefinition, +) + + +def test_initialize_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib2 +import lib1 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + lib2.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multiple_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +# multiple uses on one line +uses: ( + lib1, + lib2 +) + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multi_line_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +uses: lib1 +uses: lib2 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_uses_attribute(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib2.__init__() + # demonstrate we can call lib1.__init__ through lib2.lib1 + # (not sure this should be allowed, really. + lib2.lib1.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializes_without_init_function(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + pass + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_imported_as_different_names(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 as some_module +import lib2 + +initializes: lib2[m := some_module] +initializes: some_module + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializer_list_module_mismatch(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +something: uint256 + """ + lib3 = """ +import lib1 + +uses: lib1 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib3[lib1 := lib2] # typo -- should be [lib1 := lib1] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + with pytest.raises(StructureException) as e: + assert compile_code(main, input_bundle=input_bundle) is not None + + assert e.value._message == "lib1 is not lib2!" + + +def test_imported_as_different_names_error(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "unknown module `lib1`" + assert e.value._hint == "did you mean `m := lib1`?" + + +def test_global_initializer_constraint(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +# forgot to initialize lib1! + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "module `lib1.vy` is used but never initialized!" + assert e.value._hint == "add `initializes: lib1` to the top level of your main contract" + + +def test_initializer_no_references(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`" + assert e.value._hint == "add `lib1` to its initializer list" + + +def test_missing_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.counter + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_immutable(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 7 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.MY_IMMUTABLE + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_inside_call(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 9 + +@internal +def get_counter() -> uint256: + return MY_IMMUTABLE + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.get_counter() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_hashmap(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + return lib1.counter[1][2] + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +something: uint256 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + lib1.counter[1][2], self.something = Foo(msg.sender).foo() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple_function_call(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + +something: uint256 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +@internal +def write_tuple(): + self.counter[1][2], self.something = Foo(msg.sender).foo() + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! +@internal +def foo(): + lib1.write_tuple() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_function_call(make_input_bundle): + # test missing uses through function call + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.update_counter(lib1.counter + 1) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute_function_call(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.update_counter(new_value) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_uses_skip_import(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 + +@external +def foo(new_value: uint256): + # can access lib1 state through lib2? + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_invalid_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_invalid_uses2(make_input_bundle): + # test a more complicated invalid uses + lib1 = """ +counter: uint256 + +@internal +def foo(addr: address): + # sends value -- modifies ethereum state + to_send_value: uint256 = 100 + raw_call(addr, b"someFunction()", value=to_send_value) + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + lib1.foo(msg.sender) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@external +def foo(): + lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_initializes_uses_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_uses_initializes_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_uses_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 + +random_variable: constant(uint256) = 3 + +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_initializes_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 + +random_variable: constant(uint256) = 3 + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_no_initialize_unused_module(make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + main = """ +import lib1 + +# not needed: `initializes: lib1` + +@external +def do_add(x: uint256, y: uint256) -> uint256: + return lib1.add(x, y) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_no_initialize_unused_module2(make_input_bundle): + # slightly more complicated + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + lib2 = """ +import lib1 + +@internal +@pure +def addmul(x: uint256, y: uint256, z: uint256) -> uint256: + return lib1.add(x, y) * z + """ + main = """ +import lib1 +import lib2 + +@external +def do_addmul(x: uint256, y: uint256) -> uint256: + return lib2.addmul(x, y, 5) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_init_uninitialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_init_uninitialized_function2(make_input_bundle): + # test that we can't call module.__init__() even when we call `uses` + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +uses: lib1 +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_noinit_initialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + pass # missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_noinit_initialized_function2(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +# missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_ownership_decl_errors_not_swallowed(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 +# forgot to import lib2 + +uses: (lib1, lib2) # should get UndeclaredDefinition + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "'lib2' has not been declared. " diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index fa6ed20117..5873eb5af8 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -165,7 +165,7 @@ def test_address_code_self_success(get_contract, optimize): code = """ code_deployment: public(Bytes[32]) -@external +@deploy def __init__(): self.code_deployment = slice(self.code, 0, 32) @@ -186,7 +186,7 @@ def test_address_code_self_runtime_error_deployment(get_contract): code = """ dummy: public(Bytes[1000000]) -@external +@deploy def __init__(): self.dummy = slice(self.code, 0, 1000000) """ diff --git a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index c2d9a2e274..8aada22da7 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -11,7 +11,7 @@ def test_get_extcodehash(get_contract, evm_version, optimize): code = """ a: address -@external +@deploy def __init__(): self.a = self diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 57922f28e2..63abf24485 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -94,7 +94,7 @@ VAL: immutable(uint256) VAL: uint256 -@external +@deploy def __init__(): VAL = 1 """, @@ -106,7 +106,7 @@ def __init__(): VAL: uint256 VAL: immutable(uint256) -@external +@deploy def __init__(): VAL = 1 """, diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index 1027d9fe66..59fb1a69d9 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -8,7 +8,7 @@ """ VALUE: immutable(uint256) -@external +@deploy def __init__(): pass """, @@ -25,7 +25,7 @@ def get_value() -> uint256: """ VALUE: immutable(uint256) = 3 -@external +@deploy def __init__(): pass """, @@ -33,7 +33,7 @@ def __init__(): """ VALUE: immutable(uint256) -@external +@deploy def __init__(): VALUE = 0 @@ -45,7 +45,7 @@ def set_value(_value: uint256): """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 VALUE = VALUE + 1 @@ -54,7 +54,7 @@ def __init__(_value: uint256): """ VALUE: immutable(public(uint256)) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 """, @@ -85,7 +85,7 @@ def test_compilation_simple_usage(typ): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -103,7 +103,7 @@ def get_value() -> {typ}: """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 x: uint256 = VALUE + 1 @@ -121,7 +121,7 @@ def test_compilation_success(good_code): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.imm = x """, @@ -131,7 +131,7 @@ def __init__(x: uint256): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): x = imm @@ -145,7 +145,7 @@ def report(): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): imm = x @@ -163,7 +163,7 @@ def report(): x: immutable(Foo) -@external +@deploy def __init__(): x = Foo({a:1}) diff --git a/tests/functional/syntax/test_init.py b/tests/functional/syntax/test_init.py new file mode 100644 index 0000000000..389b5ad681 --- /dev/null +++ b/tests/functional/syntax/test_init.py @@ -0,0 +1,64 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import FunctionDeclarationException + +good_list = [ + """ +@deploy +def __init__(): + pass + """, + """ +@deploy +@payable +def __init__(): + pass + """, + """ +counter: uint256 +SOME_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + SOME_IMMUTABLE = 5 + self.counter = 1 + """, +] + + +@pytest.mark.parametrize("code", good_list) +def test_good_init_funcs(code): + assert compile_code(code) is not None + + +fail_list = [ + """ +@internal +def __init__(): + pass + """, + """ +@deploy +@view +def __init__(): + pass + """, + """ +@deploy +@pure +def __init__(): + pass + """, + """ +@deploy +def some_function(): # for now, only __init__() functions can be marked @deploy + pass + """, +] + + +@pytest.mark.parametrize("code", fail_list) +def test_bad_init_funcs(code): + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 584e497534..a07ec4e3dc 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -304,7 +304,7 @@ def some_func(): nonpayable my_interface: MyInterface[3] idx: uint256 -@external +@deploy def __init__(): self.my_interface[self.idx] = MyInterface(empty(address)) """, @@ -348,7 +348,7 @@ def foo() -> uint256: view foo: public(immutable(uint256)) -@external +@deploy def __init__(x: uint256): foo = x """, diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 71bff753f4..217fcea998 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -10,7 +10,7 @@ x: public(constant(int128)) = 0 y: public(immutable(int128)) -@external +@deploy def __init__(): y = 0 """, diff --git a/tests/functional/syntax/test_tuple_assign.py b/tests/functional/syntax/test_tuple_assign.py index 49b63ee614..bb23804e30 100644 --- a/tests/functional/syntax/test_tuple_assign.py +++ b/tests/functional/syntax/test_tuple_assign.py @@ -92,7 +92,7 @@ def test(a: bytes32) -> (bytes32, uint256, int128): """ B: immutable(uint256) -@external +@deploy def __init__(b: uint256): B = b diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 20390f3d5e..9fec61cb90 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -109,16 +109,6 @@ def foo() -> uint256: "node_id": 9, "src": "48:15:0", "ast_type": "ImplementsDecl", - "target": { - "col_offset": 0, - "end_col_offset": 10, - "node_id": 10, - "src": "48:10:0", - "ast_type": "Name", - "end_lineno": 5, - "lineno": 5, - "id": "implements", - }, "end_lineno": 5, "lineno": 5, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index 1aa8901881..f0ee25f747 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -56,7 +56,7 @@ def test_storage_and_immutables_layout(): SYMBOL: immutable(String[32]) DECIMALS: immutable(uint8) -@external +@deploy def __init__(): SYMBOL = "VYPR" DECIMALS = 18 @@ -72,3 +72,251 @@ def __init__(): out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected_layout + + +def test_storage_layout_module(make_input_bundle): + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + a_library.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "counter2": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module2(make_input_bundle): + # test module storage layout, but initializes is in a different order + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +initializes: a_library + +counter2: uint256 + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "a_library": {"supply": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_uses(make_input_bundle): + # test module storage layout, with initializes/uses + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +uses: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2[lib1 := a_library] + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + + lib2.__init__(17) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": {"immutable_variable": {"length": 32, "offset": 352, "type": "uint256"}}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 384, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 448, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_nested_initializes(make_input_bundle): + # test module storage layout, with initializes in an imported module + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + lib1.__init__() + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 + +counter2: uint256 + +uses: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + + lib2.__init__(17) + +@external +def foo() -> uint256: + return a_library.supply + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": { + "lib1": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + "immutable_variable": {"length": 32, "offset": 448, "type": "uint256"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": { + "lib1": {"supply": {"slot": 1, "type": "uint256"}}, + "storage_variable": {"slot": 2, "type": "uint256"}, + }, + "counter2": {"slot": 3, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index b2851e908a..ce32249202 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -20,7 +20,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -44,7 +44,7 @@ def ctor_only(): def bar(): self.foo() -@external +@deploy def __init__(): self.ctor_only() """, @@ -65,7 +65,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -73,6 +73,9 @@ def __init__(): # check dead code eliminator works on unreachable functions +# CMC 2024-02-05 this is not really the asm eliminator anymore, +# it happens during function code generation in module.py. so we don't +# need to test this using asm anymore. @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) @@ -88,20 +91,9 @@ def test_dead_code_eliminator(code): assert any(ctor_only in instr for instr in initcode_asm) assert all(runtime_only not in instr for instr in initcode_asm) - # all labels should be in unoptimized runtime asm - for s in (ctor_only, runtime_only): - assert any(s in instr for instr in runtime_asm) - - c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if isinstance(i, str)] - runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] - - # ctor only label should not be in runtime code + assert any(runtime_only in instr for instr in runtime_asm) assert all(ctor_only not in instr for instr in runtime_asm) - # runtime only label should not be in initcode asm - assert all(runtime_only not in instr for instr in initcode_asm) - def test_library_code_eliminator(make_input_bundle): library = """ diff --git a/tests/unit/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py index 613ee4d2b8..64cee3a75c 100644 --- a/tests/unit/compiler/test_bytecode_runtime.py +++ b/tests/unit/compiler/test_bytecode_runtime.py @@ -35,7 +35,7 @@ def foo5(): has_immutables = """ A_GOOD_PRIME: public(immutable(uint256)) -@external +@deploy def __init__(): A_GOOD_PRIME = 967 """ diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index ea2b2fe559..3620ef64b9 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -25,7 +25,7 @@ h: public(int256[1]) -@external +@deploy def __init__(): self.a = StructOne({a: "ok", b: [4,5,6]}) self.b = [7, 8] @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", + match=f"Invalid storage slot, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index bc08626b59..0ae93e9710 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -5,7 +5,7 @@ from . import nodes, validation from .natspec import parse_natspec -from .nodes import compare_nodes +from .nodes import compare_nodes, as_tuple from .utils import ast_to_dict from .parse import parse_to_ast, parse_to_ast_with_settings @@ -15,6 +15,5 @@ ): setattr(sys.modules[__name__], name, obj) - # required to avoid circular dependency from . import expansion # noqa: E402 diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 84429501e1..5ad465a1f1 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -182,13 +182,9 @@ loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body -// ternary operator -ternary: _expr "if" _expr "else" _expr - // Expressions _expr: operation | dict - | ternary get_item: (variable_access | list) "[" _expr "]" get_attr: variable_access "." NAME @@ -214,7 +210,15 @@ dict: "{" "}" | "{" (NAME ":" _expr) ("," (NAME ":" _expr))* [","] "}" // See https://docs.python.org/3/reference/expressions.html#operator-precedence // NOTE: The recursive cycle here helps enforce operator precedence // Precedence goes up the lower down you go -?operation: bool_or +?operation: assignment_expr + +// "walrus" operator +?assignment_expr: ternary + | NAME ":=" assignment_expr + +// ternary operator +?ternary: bool_or + | ternary "if" ternary "else" ternary _AND: "and" _OR: "or" diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 054145d33b..c4bce814a4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -83,8 +83,20 @@ def get_node( if ast_struct["value"] is not None: _raise_syntax_exc("`implements` cannot have a value assigned", ast_struct) ast_struct["ast_type"] = "ImplementsDecl" + + # Replace "uses:" `AnnAssign` nodes with `UsesDecl` + elif getattr(ast_struct["target"], "id", None) == "uses": + if ast_struct["value"] is not None: + _raise_syntax_exc("`uses` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "UsesDecl" + + # Replace "initializes:" `AnnAssign` nodes with `InitializesDecl` + elif getattr(ast_struct["target"], "id", None) == "initializes": + if ast_struct["value"] is not None: + _raise_syntax_exc("`initializes` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "InitializesDecl" + # Replace state and local variable declarations `AnnAssign` with `VariableDecl` - # Parent node is required for context to determine whether replacement should happen. else: ast_struct["ast_type"] = "VariableDecl" @@ -730,6 +742,20 @@ def is_terminus(self): return self.value.is_terminus +class NamedExpr(Stmt): + __slots__ = ("target", "value") + + def validate(self): + # module[dep1 := dep2] + + # XXX: better error messages + if not isinstance(self.target, Name): + raise StructureException("not a Name") + + if not isinstance(self.value, Name): + raise StructureException("not a Name") + + class Log(Stmt): __slots__ = ("value",) @@ -756,6 +782,11 @@ class StructDef(TopLevel): class ExprNode(VyperNode): __slots__ = ("_expr_info",) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._expr_info = None + class Constant(ExprNode): # inherited class for all simple constant node types @@ -1383,17 +1414,13 @@ class ImplementsDecl(Stmt): """ An `implements` declaration. - Excludes `simple` and `value` attributes from Python `AnnAssign` node. - Attributes ---------- - target : Name - Name node for the `implements` keyword annotation : Name Name node for the interface to be implemented """ - __slots__ = ("target", "annotation") + __slots__ = ("annotation",) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1402,6 +1429,72 @@ def __init__(self, *args, **kwargs): raise StructureException("invalid implements", self.annotation) +def as_tuple(node: VyperNode): + """ + Convenience function for some AST nodes which allow either a Tuple + or single elements. Returns a python tuple of AST nodes. + """ + if isinstance(node, Tuple): + return node.elements + else: + return (node,) + + +class UsesDecl(Stmt): + """ + A `uses` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Tuple + The module(s) which this uses + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + items = as_tuple(self.annotation) + for item in items: + if not isinstance(item, (Name, Attribute)): + raise StructureException("invalid uses", item) + + +class InitializesDecl(Stmt): + """ + An `initializes` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Subscript + An imported module which this module initializes + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + module_ref = self.annotation + if isinstance(module_ref, Subscript): + dependencies = as_tuple(module_ref.slice) + module_ref = module_ref.value + + for item in dependencies: + if not isinstance(item, NamedExpr): + raise StructureException( + "invalid dependency (hint: should be [dependency := dependency]", item + ) + if not isinstance(item.target, (Name, Attribute)): + raise StructureException("invalid module", item.target) + if not isinstance(item.value, (Name, Attribute)): + raise StructureException("invalid module", item.target) + + if not isinstance(module_ref, (Name, Attribute)): + raise StructureException("invalid module", module_ref) + + class If(Stmt): __slots__ = ("test", "body", "orelse") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f71ed67821..7f863a8db9 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -101,7 +101,8 @@ class StructDef(VyperNode): body: list = ... name: str = ... -class ExprNode(VyperNode): ... +class ExprNode(VyperNode): + _expr_info: Any = ... class Constant(VyperNode): value: Any = ... @@ -145,19 +146,19 @@ class Name(VyperNode): _type: str = ... class Expr(VyperNode): - value: VyperNode = ... + value: ExprNode = ... class UnaryOp(ExprNode): op: VyperNode = ... - operand: VyperNode = ... + operand: ExprNode = ... class USub(VyperNode): ... class Not(VyperNode): ... class BinOp(ExprNode): - left: VyperNode = ... op: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Add(VyperNode): ... class Sub(VyperNode): ... @@ -173,15 +174,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... - values: list[VyperNode] = ... + values: list[ExprNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... - left: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -195,13 +196,13 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: VyperNode = ... + func: ExprNode = ... class keyword(VyperNode): ... class Attribute(VyperNode): attr: str = ... - value: VyperNode = ... + value: ExprNode = ... class Subscript(VyperNode): slice: VyperNode = ... @@ -224,8 +225,8 @@ class VariableDecl(VyperNode): class AugAssign(VyperNode): op: VyperNode = ... - target: VyperNode = ... - value: VyperNode = ... + target: ExprNode = ... + value: ExprNode = ... class Raise(VyperNode): ... class Assert(VyperNode): ... @@ -245,6 +246,12 @@ class ImplementsDecl(VyperNode): target: Name = ... annotation: Name = ... +class UsesDecl(VyperNode): + annotation: VyperNode = ... + +class InitializesDecl(VyperNode): + annotation: VyperNode = ... + class If(VyperNode): body: list = ... orelse: list = ... @@ -254,6 +261,10 @@ class IfExp(ExprNode): body: ExprNode = ... orelse: ExprNode = ... +class NamedExpr(ExprNode): + target: Name = ... + value: ExprNode = ... + class For(VyperNode): target: ExprNode iter: ExprNode diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index fc99af901b..a10a840da0 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -278,8 +278,8 @@ def visit_For(self, node): # specific error message than "invalid type annotation" raise SyntaxException( "missing type annotation\n\n" - "(hint: did you mean something like " - f"`for {node.target.id}: uint256 in ...`?)\n", + " (hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)", self._source_code, node.lineno, node.col_offset, diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d2aefb2fd4..6e6cf4c662 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,13 +85,16 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _equality_attrs = ("_id",) _is_terminus = False - # helper function to deal with TYPE_DEFINITIONs + @property + def modifiability(self): + return self._modifiability + + # helper function to deal with TYPE_Ts def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: - # TODO using "TYPE_DEFINITION" is a kludge in derived classes, - # refactor me. - if expected_type == "TYPE_DEFINITION": + if TYPE_T.any().compare_type(expected_type): # try to parse the type - call type_from_annotation # for its side effects (will throw if is not a type) type_from_annotation(arg) @@ -130,7 +133,7 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: get_exact_type_from_node(arg) def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: - return self._modifiability >= modifiability + return self._modifiability <= modifiability def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index 72b05f15e3..3fad225b48 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,7 +1,7 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context from vyper.codegen.stmt import parse_body -from vyper.semantics.analysis.local import FunctionNodeVisitor +from vyper.semantics.analysis.local import FunctionAnalyzer from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability from vyper.semantics.types.module import ModuleT @@ -25,9 +25,7 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) - # The FunctionNodeVisitor's constructor performs semantic checks - # annotate the AST as side effects. - analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionAnalyzer(ast_code, ast_code.body[0], namespace) analyzer.analyze() new_context = Context( diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 50ab4dacd8..7575f4d77e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -113,10 +113,7 @@ class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. - - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TypeTypeDefinition` object in `infer_arg_types`. - _inputs = [("typename", "TYPE_DEFINITION")] + _inputs = [("typename", TYPE_T.any())] def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef @@ -711,7 +708,7 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" _inputs = [("value", StringT.any())] - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BytesT(4))} def _try_fold(self, node): validate_call_args(node, 1, ["output_type"]) @@ -848,10 +845,7 @@ def _storage_element_getter(index): class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TYPE_T` object in `infer_kwarg_types` - # (note that it is ignored in _validate_arg_types) - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BYTES32_T)} def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1976,18 +1970,22 @@ def build_IR(self, expr, args, kwargs, context): class UnsafeAdd(_UnsafeMath): + _id = "unsafe_add" op = "add" class UnsafeSub(_UnsafeMath): + _id = "unsafe_sub" op = "sub" class UnsafeMul(_UnsafeMath): + _id = "unsafe_mul" op = "mul" class UnsafeDiv(_UnsafeMath): + _id = "unsafe_div" op = "div" @@ -2474,7 +2472,7 @@ def build_IR(self, expr, args, kwargs, context): class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" - _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] + _inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} def fetch_call_return(self, node): diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 4f644841f4..af01c5b504 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -44,7 +44,7 @@ def __repr__(self): return f"VariableRecord({ret})" -# Contains arguments, variables, etc +# compilation context for a function class Context: def __init__( self, @@ -59,19 +59,12 @@ def __init__( # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} - # Global variables, in the form (name, storage location, type) - self.globals = module_ctx.variables - # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} # Is the function constant? self.constancy = constancy - # Whether body is currently in an assert statement - # XXX: dead, never set to True - self.in_assertion = False - # Whether we are currently parsing a range expression self.in_range_expr = False @@ -87,6 +80,10 @@ def __init__( # Not intended to be accessed directly self.memory_allocator = memory_allocator + # save the starting memory location so we can find out (later) + # how much memory this function uses. + self.starting_memory = memory_allocator.next_mem + # Incremented values, used for internal IDs self._internal_var_iter = 0 self._scope_id_iter = 0 @@ -95,7 +92,7 @@ def __init__( self.is_ctor_context = is_ctor_context def is_constant(self): - return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr + return self.constancy is Constancy.Constant or self.in_range_expr def check_is_not_constant(self, err, expr): if self.is_constant(): @@ -250,9 +247,7 @@ def lookup_var(self, varname): # Pretty print constancy for error messages def pp_constancy(self): - if self.in_assertion: - return "an assertion" - elif self.in_range_expr: + if self.in_range_expr: return "a range expression" elif self.constancy == Constancy.Constant: return "a constant function" diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c3215f8c16..1a090ac316 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -3,9 +3,18 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import ( + CALLDATA, + DATA, + IMMUTABLES, + MEMORY, + STORAGE, + TRANSIENT, + AddrSpace, +) from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch +from vyper.semantics.data_locations import DataLocation from vyper.semantics.types import ( AddressT, BoolT, @@ -100,6 +109,36 @@ def _codecopy_gas_bound(num_bytes): return GAS_COPY_WORD * ceil32(num_bytes) // 32 +def data_location_to_address_space(s: DataLocation, is_ctor_ctx: bool) -> AddrSpace: + if s == DataLocation.MEMORY: + return MEMORY + if s == DataLocation.STORAGE: + return STORAGE + if s == DataLocation.TRANSIENT: + return TRANSIENT + if s == DataLocation.CODE: + if is_ctor_ctx: + return IMMUTABLES + return DATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + +def address_space_to_data_location(s: AddrSpace) -> DataLocation: + if s == MEMORY: + return DataLocation.MEMORY + if s == STORAGE: + return DataLocation.STORAGE + if s == TRANSIENT: + return DataLocation.TRANSIENT + if s in (IMMUTABLES, DATA): + return DataLocation.CODE + if s == CALLDATA: + return DataLocation.CALLDATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + # Copy byte array word-for-word (including layout) # TODO make this a private function def make_byte_array_copier(dst, src): @@ -482,14 +521,10 @@ def _get_element_ptr_tuplelike(parent, key): return _getelemptr_abi_helper(parent, member_t, ofst) - if parent.location.word_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].storage_size_in_words - elif parent.location.byte_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].memory_bytes_required - else: - raise CompilerPanic(f"bad location {parent.location}") # pragma: notest + data_location = address_space_to_data_location(parent.location) + for i in range(index): + t = typ.member_types[attrs[i]] + ofst += t.get_size_in(data_location) return IRnode.from_list( add_ofst(parent, ofst), @@ -550,12 +585,8 @@ def _get_element_ptr_array(parent, key, array_bounds_check): return _getelemptr_abi_helper(parent, subtype, ofst) - if parent.location.word_addressable: - element_size = subtype.storage_size_in_words - elif parent.location.byte_addressable: - element_size = subtype.memory_bytes_required - else: - raise CompilerPanic("unreachable") # pragma: notest + data_location = address_space_to_data_location(parent.location) + element_size = subtype.get_size_in(data_location) ofst = _mul(ix, element_size) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index f4c7948382..335cfefb87 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -6,6 +6,7 @@ from vyper.codegen import external_call, self_call from vyper.codegen.core import ( clamp, + data_location_to_address_space, ensure_in_memory, get_dyn_array_count, get_element_ptr, @@ -23,7 +24,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import MEMORY from vyper.evm.opcodes import version_check from vyper.exceptions import ( CodegenPanic, @@ -185,26 +186,24 @@ def parse_Name(self): ret._referenced_variables = {var} return ret - # TODO: use self.expr._expr_info - elif self.expr.id in self.context.globals: - varinfo = self.context.globals[self.expr.id] - + elif (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" - ofst = varinfo.position.offset + mutable = self.context.is_ctor_context - if self.context.is_ctor_context: - mutable = True - location = IMMUTABLES - else: - mutable = False - location = DATA + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( - ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation=self.expr.id, + mutable=mutable, ) ret._referenced_variables = {varinfo} return ret @@ -265,9 +264,13 @@ def parse_Attribute(self): return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) # self.x: global attribute - elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] - location = TRANSIENT if varinfo.is_transient else STORAGE + elif (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( varinfo.position.position, @@ -700,7 +703,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=True) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 94617bef35..254b4df72c 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1,4 @@ -from .common import FuncIR, generate_ir_for_function # noqa +from .external_function import generate_ir_for_external_function +from .internal_function import generate_ir_for_internal_function + +__all__ = [generate_ir_for_internal_function, generate_ir_for_external_function] # type: ignore diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 5877ff3d13..d017ba7b81 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -2,17 +2,14 @@ from functools import cached_property from typing import Optional -import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function -from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator -from vyper.exceptions import CompilerPanic +from vyper.evm.opcodes import version_check from vyper.semantics.types import VyperType -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, StateMutability from vyper.semantics.types.module import ModuleT -from vyper.utils import MemoryPositions, calc_mem_gas +from vyper.utils import MemoryPositions @dataclass @@ -53,9 +50,11 @@ def ir_identifier(self) -> str: return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: + # XXX: when can this happen? if self.frame_info is not None: - raise CompilerPanic(f"frame_info already set for {self.func_t}!") - self.frame_info = frame_info + assert frame_info == self.frame_info + else: + self.frame_info = frame_info @property # common entry point for external function with kwargs @@ -64,13 +63,15 @@ def external_function_base_entry_label(self) -> str: return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: - assert self.func_t.is_internal, "uh oh, should be internal" - suffix = "_deploy" if is_ctor_context else "_runtime" - return self.ir_identifier + suffix + f = self.func_t + assert f.is_internal or f.is_constructor, "uh oh, should be internal" + if f.is_constructor: + # sanity check - imported init functions only callable from main init + assert is_ctor_context -class FuncIR: - pass + suffix = "_deploy" if is_ctor_context else "_runtime" + return self.ir_identifier + suffix @dataclass @@ -80,7 +81,7 @@ class EntryPointInfo: ir_node: IRnode # the ir for this entry point def __post_init__(self): - # ABI v2 property guaranteed by the spec. + # sanity check ABI v2 properties guaranteed by the spec. # https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501 # > Note that for any X, len(enc(X)) is a multiple of 32. assert self.min_calldatasize >= 4 @@ -88,34 +89,28 @@ def __post_init__(self): @dataclass -class ExternalFuncIR(FuncIR): +class ExternalFuncIR: entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points common_ir: IRnode # the "common" code for the function @dataclass -class InternalFuncIR(FuncIR): +class InternalFuncIR: func_ir: IRnode # the code for the function -# TODO: should split this into external and internal ir generation? -def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False -) -> FuncIR: - """ - Parse a function and produce IR code for the function, includes: - - Signature method if statement - - Argument handling - - Clamping and copying of arguments - - Function body - """ - func_t = code._metadata["func_type"] - - # generate _FuncIRInfo +def init_ir_info(func_t: ContractFunctionT): + # initialize IRInfo on the function func_t._ir_info = _FuncIRInfo(func_t) - callees = func_t.called_functions +def initialize_context( + func_t: ContractFunctionT, module_ctx: ModuleT, is_ctor_context: bool = False +): + init_ir_info(func_t) + + # calculate starting frame + callees = func_t.called_functions # we start our function frame from the largest callee frame max_callee_frame_size = 0 for c_func_t in callees: @@ -126,7 +121,7 @@ def generate_ir_for_function( memory_allocator = MemoryAllocator(allocate_start) - context = Context( + return Context( vars_=None, module_ctx=module_ctx, memory_allocator=memory_allocator, @@ -135,38 +130,41 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal: - ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) - func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore - else: - kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) - entry_points = { - k: EntryPointInfo(func_t, mincalldatasize, ir_node) - for k, (mincalldatasize, ir_node) in kwarg_handlers.items() - } - ret = ExternalFuncIR(entry_points, common) - # note: this ignores the cost of traversing selector table - func_t._ir_info.gas_estimate = ret.common_ir.gas +def tag_frame_info(func_t, context): frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY + frame_start = context.starting_memory - frame_info = FrameInfo(allocate_start, frame_size, context.vars) + frame_info = FrameInfo(frame_start, frame_size, context.vars) + func_t._ir_info.set_frame_info(frame_info) - # XXX: when can this happen? - if func_t._ir_info.frame_info is None: - func_t._ir_info.set_frame_info(frame_info) - else: - assert frame_info == func_t._ir_info.frame_info - - if not func_t.is_internal: - # adjust gas estimate to include cost of mem expansion - # frame_size of external function includes all private functions called - # (note: internal functions do not need to adjust gas estimate since - mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore - ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + return frame_info + + +def get_nonreentrant_lock(func_t): + if not func_t.nonreentrant: + return ["pass"], ["pass"] + + nkey = func_t.reentrancy_key_position.position + + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + + if version_check(begin="berlin"): + # any nonzero values would work here (see pricing as of net gas + # metering); these values are chosen so that downgrading to the + # 0,1 scheme (if it is somehow necessary) is safe. + final_value, temp_value = 3, 2 else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + final_value, temp_value = 0, 1 + + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - return ret + if func_t.mutability == StateMutability.VIEW: + return [check_notset], [["seq"]] + + else: + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] + return [pre], [post] diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 65276469e7..b380eab2ce 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -2,12 +2,19 @@ from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp from vyper.codegen.expr import Expr -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + EntryPointInfo, + ExternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.stmt import parse_body from vyper.evm.address_space import CALLDATA, DATA, MEMORY from vyper.semantics.types import TupleT from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import calc_mem_gas # register function args with the local calling context. @@ -51,7 +58,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list def _generate_kwarg_handlers( func_t: ContractFunctionT, context: Context -) -> dict[str, tuple[int, IRnode]]: +) -> dict[str, EntryPointInfo]: # generate kwarg handlers. # since they might come in thru calldata or be default, # allocate them in memory and then fill it in based on calldata or default, @@ -126,34 +133,54 @@ def handler_for(calldata_kwargs, default_kwargs): default_kwargs = keyword_args[i:] sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) return ret -def generate_ir_for_external_function(code, func_t, context): +def _adjust_gas_estimate(func_t, common_ir): + # adjust gas estimate to include cost of mem expansion + # frame_size of external function includes all private functions called + # (note: internal functions do not need to adjust gas estimate since + frame_info = func_t._ir_info.frame_info + + mem_expansion_cost = calc_mem_gas(frame_info.mem_used) + common_ir.add_gas_estimate += mem_expansion_cost + func_t._ir_info.gas_estimate = common_ir.gas + + # pass metadata through for venom pipeline: + common_ir.passthrough_metadata["func_t"] = func_t + common_ir.passthrough_metadata["frame_info"] = frame_info + + +def generate_ir_for_external_function(code, compilation_target): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, - # func_t: ContractFunctionT, - # context: Context, + # compilation_target: ModuleT, # ) -> IRnode: """ Return the IR for an external function. Returns IR for the body of the function, handle kwargs and exit the function. Also returns metadata required for `module.py` to construct the selector table. """ + func_t = code._metadata["func_type"] + assert func_t.is_external or func_t.is_constructor # sanity check + + context = initialize_context(func_t, compilation_target, func_t.is_constructor) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) # generate handlers for base args and register the variable records handle_base_args = _register_function_args(func_t, context) # generate handlers for kwargs and register the variable records - kwarg_handlers = _generate_kwarg_handlers(func_t, context) + entry_points = _generate_kwarg_handlers(func_t, context) body = ["seq"] # once optional args have been handled, @@ -185,4 +212,8 @@ def generate_ir_for_external_function(code, func_t, context): # besides any kwarg handling func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return kwarg_handlers, func_common_ir + tag_frame_info(func_t, context) + + _adjust_gas_estimate(func_t, func_common_ir) + + return ExternalFuncIR(entry_points, func_common_ir) diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..0cf9850b70 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -1,23 +1,25 @@ from vyper import ast as vy_ast -from vyper.codegen.context import Context -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + InternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import IRnode from vyper.codegen.stmt import parse_body -from vyper.semantics.types.function import ContractFunctionT def generate_ir_for_internal_function( - code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context -) -> IRnode: + code: vy_ast.FunctionDef, module_ctx, is_ctor_context: bool +) -> InternalFuncIR: """ Parse a internal function (FuncDef), and produce full function body. :param func_t: the ContractFunctionT :param code: ast of function - :param context: current calling context + :param compilation_target: current calling context :return: function body in IR """ - # The calling convention is: # Caller fills in argument buffer # Caller provides return address, return buffer on the stack @@ -37,13 +39,19 @@ def generate_ir_for_internal_function( # situation like the following is easy to bork: # x: T[2] = [self.generate_T(), self.generate_T()] - # Get nonreentrant lock + func_t = code._metadata["func_type"] + + # sanity check + assert func_t.is_internal or func_t.is_constructor + + context = initialize_context(func_t, module_ctx, is_ctor_context) for arg in func_t.arguments: # allocate a variable for every arg, setting mutability # to True to allow internal function arguments to be mutable context.new_variable(arg.name, arg.typ, is_mutable=True) + # Get nonreentrant lock nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context) @@ -69,5 +77,13 @@ def generate_ir_for_internal_function( ] ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + + # tag gas estimate and frame info + func_t._ir_info.gas_estimate = ir_node.gas + frame_info = tag_frame_info(func_t, context) + + # pass metadata through for venom pipeline: + ir_node.passthrough_metadata["frame_info"] = frame_info ir_node.passthrough_metadata["func_t"] = func_t - return ir_node + + return InternalFuncIR(ir_node) diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py deleted file mode 100644 index f524ec6e88..0000000000 --- a/vyper/codegen/function_definitions/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -from vyper.evm.opcodes import version_check -from vyper.semantics.types.function import StateMutability - - -def get_nonreentrant_lock(func_type): - if not func_type.nonreentrant: - return ["pass"], ["pass"] - - nkey = func_type.reentrancy_key_position.position - - LOAD, STORE = "sload", "sstore" - if version_check(begin="cancun"): - LOAD, STORE = "tload", "tstore" - - if version_check(begin="berlin"): - # any nonzero values would work here (see pricing as of net gas - # metering); these values are chosen so that downgrading to the - # 0,1 scheme (if it is somehow necessary) is safe. - final_value, temp_value = 3, 2 - else: - final_value, temp_value = 0, 1 - - check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - - if func_type.mutability == StateMutability.VIEW: - return [check_notset], [["seq"]] - - else: - pre = ["seq", check_notset, [STORE, nkey, temp_value]] - post = [STORE, nkey, final_value] - return [pre], [post] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 98395a6a0c..fef4f23949 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,7 +4,10 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr -from vyper.codegen.function_definitions import generate_ir_for_function +from vyper.codegen.function_definitions import ( + generate_ir_for_external_function, + generate_ir_for_internal_function, +) from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic @@ -89,7 +92,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): callvalue_check = ["assert", ["iszero", "callvalue"]] ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) - func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + func_ir = generate_ir_for_external_function(func_ast, *args, **kwargs) assert len(func_ir.entry_points) == 1 # add a goto to make the function entry look like other functions @@ -101,7 +104,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): def _ir_for_internal_function(func_ast, *args, **kwargs): - return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + return generate_ir_for_internal_function(func_ast, *args, **kwargs).func_ir def _generate_external_entry_points(external_functions, module_ctx): @@ -109,7 +112,7 @@ def _generate_external_entry_points(external_functions, module_ctx): sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_external_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -424,12 +427,13 @@ def _selector_section_linear(external_functions, module_ctx): # take a ModuleT, and generate the runtime and deploy IR def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: + # XXX: rename `module_ctx` to `compilation_target` # order functions so that each function comes after all of its callees function_defs = _topsort(module_ctx.function_defs) reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] - init_function = next((f for f in function_defs if _is_constructor(f)), None) + init_function = next((f for f in module_ctx.function_defs if _is_constructor(f)), None) internal_functions = [f for f in runtime_functions if _is_internal(f)] @@ -475,24 +479,21 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: deploy_code: List[Any] = ["seq"] immutables_len = module_ctx.immutable_section_bytes - if init_function: + if init_function is not None: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - func_t = f._metadata["func_type"] - if func_t not in init_func_t.reachable_internal_functions: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) + + reachable_from_ctor = init_func_t.reachable_internal_functions + for func_t in reachable_from_ctor: + fn_ast = func_t.ast_def + func_ir = _ir_for_internal_function(fn_ast, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 7d4938f287..e6baea75f7 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -144,7 +144,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=False) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.stmt, self.context) else: return external_call.ir_for_external_call(self.stmt, self.context) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5b7decec7b..f7eccdf214 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -152,23 +152,18 @@ def vyper_module(self): return self._generate_ast @cached_property - def _annotated_module(self): - return generate_annotated_ast( - self.vyper_module, self.input_bundle, self.storage_layout_override - ) - - @property def annotated_vyper_module(self) -> vy_ast.Module: - module, storage_layout = self._annotated_module - return module + return generate_annotated_ast(self.vyper_module, self.input_bundle) - @property + @cached_property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._annotated_module - return storage_layout + module_ast = self.annotated_vyper_module + return set_data_positions(module_ast, self.storage_layout_override) @property def global_ctx(self) -> ModuleT: + # ensure storage layout is computed + _ = self.storage_layout return self.annotated_vyper_module._metadata["type"] @cached_property @@ -243,11 +238,7 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast( - vyper_module: vy_ast.Module, - input_bundle: InputBundle, - storage_layout_overrides: StorageLayout = None, -) -> tuple[vy_ast.Module, StorageLayout]: +def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: """ Validates and annotates the Vyper AST. @@ -268,9 +259,7 @@ def generate_annotated_ast( # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) - symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - - return vyper_module, symbol_tables + return vyper_module def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index 85a75c3c23..fcbd4bcf63 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -28,14 +28,6 @@ class AddrSpace: # TODO maybe make positional instead of defaulting to None store_op: Optional[str] = None - @property - def word_addressable(self) -> bool: - return self.word_scale == 1 - - @property - def byte_addressable(self) -> bool: - return self.word_scale == 32 - # alternative: # class Memory(AddrSpace): diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 04667aaa59..53ad6f7bb8 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -31,7 +31,7 @@ class _BaseVyperException(Exception): order to display source annotations in the error string. """ - def __init__(self, message="Error Message not found.", *items): + def __init__(self, message="Error Message not found.", *items, hint=None): """ Exception initializer. @@ -47,7 +47,9 @@ def __init__(self, message="Error Message not found.", *items): A single tuple of (lineno, col_offset) is also understood to support the old API, but new exceptions should not use this approach. """ - self.message = message + self._message = message + self._hint = hint + self.lineno = None self.col_offset = None self.annotations = None @@ -77,6 +79,13 @@ def with_annotation(self, *annotations): exc.annotations = annotations return exc + @property + def message(self): + msg = self._message + if self._hint: + msg += f"\n\n (hint: {self._hint})" + return msg + def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code @@ -131,7 +140,7 @@ def __str__(self): annotation_list.append(node_msg) annotation_msg = "\n".join(annotation_list) - return f"{self.message}\n{annotation_msg}" + return f"{self.message}\n\n{annotation_msg}" class VyperException(_BaseVyperException): @@ -252,6 +261,14 @@ class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" +class InitializerException(VyperException): + """An issue with initializing/constructing a module""" + + +class BorrowException(VyperException): + """An issue with borrowing/using a module""" + + class StateAccessViolation(VyperException): """Violating the mutability of a function definition.""" @@ -369,7 +386,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): except _BaseVyperException as e: if not e.annotations and not e.lineno: tb = e.__traceback__ - raise e.with_annotation(node).with_traceback(tb) + raise e.with_annotation(node).with_traceback(tb) from None raise e from None except Exception as e: tb = e.__traceback__ diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7b52a68e92..e23b2d2aa4 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,4 +1,4 @@ from .. import types # break a dependency cycle. -from .module import validate_semantics +from .global_ import validate_semantics __all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index bb6d9ad9f7..2086e5f9da 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,84 +1,29 @@ import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast from vyper.compiler.input_bundle import InputBundle -from vyper.exceptions import ( - CompilerPanic, - ImmutableViolation, - StateAccessViolation, - VyperInternalException, -) +from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: from vyper.semantics.types.module import InterfaceT, ModuleT -class _StringEnum(enum.Enum): - @staticmethod - def auto(): - return enum.auto() +class FunctionVisibility(StringEnum): + EXTERNAL = enum.auto() + INTERNAL = enum.auto() + DEPLOY = enum.auto() - # Must be first, or else won't work, specifies what .value is - def _generate_next_value_(name, start, count, last_values): - return name.lower() - # Override ValueError with our own internal exception - @classmethod - def _missing_(cls, value): - raise VyperInternalException(f"{value} is not a valid {cls.__name__}") - - @classmethod - def is_valid_value(cls, value: str) -> bool: - return value in set(o.value for o in cls) - - @classmethod - def options(cls) -> List["_StringEnum"]: - return list(cls) - - @classmethod - def values(cls) -> List[str]: - return [v.value for v in cls.options()] - - # Comparison operations - def __eq__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - return self is other - - # Python normally does __ne__(other) ==> not self.__eq__(other) - - def __lt__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - options = self.__class__.options() - return options.index(self) < options.index(other) # type: ignore - - def __le__(self, other: object) -> bool: - return self.__eq__(other) or self.__lt__(other) - - def __gt__(self, other: object) -> bool: - return not self.__le__(other) - - def __ge__(self, other: object) -> bool: - return self.__eq__(other) or self.__gt__(other) - - -class FunctionVisibility(_StringEnum): - # TODO: these can just be enum.auto() right? - EXTERNAL = _StringEnum.auto() - INTERNAL = _StringEnum.auto() - - -class StateMutability(_StringEnum): - # TODO: these can just be enum.auto() right? - PURE = _StringEnum.auto() - VIEW = _StringEnum.auto() - NONPAYABLE = _StringEnum.auto() - PAYABLE = _StringEnum.auto() +class StateMutability(StringEnum): + PURE = enum.auto() + VIEW = enum.auto() + NONPAYABLE = enum.auto() + PAYABLE = enum.auto() @classmethod def from_abi(cls, abi_dict: Dict) -> "StateMutability": @@ -103,71 +48,40 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # and variables) and Constancy (in codegen). context.Constancy can/should # probably be refactored away though as those kinds of checks should be done # during analysis. -class Modifiability(enum.IntEnum): - # is writeable/can result in arbitrary state or memory changes - MODIFIABLE = enum.auto() - - # could potentially add more fine-grained here as needed, like - # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. +class Modifiability(StringEnum): + # compile-time / always constant + CONSTANT = enum.auto() # things that are constant within the current message call, including # block.*, msg.*, tx.* and immutables RUNTIME_CONSTANT = enum.auto() - # compile-time / always constant - CONSTANT = enum.auto() - - -class DataPosition: - _location: DataLocation - - -class CalldataOffset(DataPosition): - __slots__ = ("dynamic_offset", "static_offset") - _location = DataLocation.CALLDATA - - def __init__(self, static_offset, dynamic_offset=None): - self.static_offset = static_offset - self.dynamic_offset = dynamic_offset - - def __repr__(self): - if self.dynamic_offset is not None: - return f"" - else: - return f"" - - -class MemoryOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.MEMORY - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" - - -class StorageSlot(DataPosition): - __slots__ = ("position",) - _location = DataLocation.STORAGE + # could potentially add more fine-grained here as needed, like + # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. - def __init__(self, position): - self.position = position + # is writeable/can result in arbitrary state or memory changes + MODIFIABLE = enum.auto() - def __repr__(self): - return f"" + @classmethod + def from_state_mutability(cls, mutability: StateMutability): + if mutability == StateMutability.PURE: + return cls.CONSTANT + if mutability == StateMutability.VIEW: + return cls.RUNTIME_CONSTANT + # sanity check in case more StateMutability levels are added in the future + assert mutability in (StateMutability.PAYABLE, StateMutability.NONPAYABLE) + return cls.MODIFIABLE -class CodeOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.CODE +@dataclass +class VarOffset: + position: int - def __init__(self, offset): - self.offset = offset - def __repr__(self): - return f"" +class ModuleOwnership(StringEnum): + NO_OWNERSHIP = enum.auto() # readable + USES = enum.auto() # writeable + INITIALIZES = enum.auto() # initializes # base class for things that are the "result" of analysis @@ -178,6 +92,9 @@ class AnalysisResult: @dataclass class ModuleInfo(AnalysisResult): module_t: "ModuleT" + alias: str + ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP + ownership_decl: Optional[vy_ast.VyperNode] = None @property def module_node(self): @@ -188,6 +105,16 @@ def module_node(self): def typ(self): return self.module_t + def set_ownership(self, module_ownership: ModuleOwnership, node: Optional[vy_ast.VyperNode]): + if self.ownership != ModuleOwnership.NO_OWNERSHIP: + raise StructureException( + f"ownership already set to `{self.ownership}`", node, self.ownership_decl + ) + self.ownership = module_ownership + + def __hash__(self): + return hash(id(self.module_t)) + @dataclass class ImportInfo(AnalysisResult): @@ -199,6 +126,21 @@ class ImportInfo(AnalysisResult): node: vy_ast.VyperNode +# analysis result of InitializesDecl +@dataclass +class InitializesInfo(AnalysisResult): + module_info: ModuleInfo + dependencies: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + +# analysis result of UsesDecl +@dataclass +class UsesInfo(AnalysisResult): + used_modules: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + @dataclass class VarInfo: """ @@ -221,22 +163,21 @@ def __hash__(self): return hash(id(self)) def __post_init__(self): + self.position = None self._modification_count = 0 - def set_position(self, position: DataPosition) -> None: - if hasattr(self, "position"): + def set_position(self, position: VarOffset) -> None: + if self.position is not None: raise CompilerPanic("Position was already assigned") - if self.location != position._location: - if self.location == DataLocation.UNSET: - self.location = position._location - elif self.is_transient and position._location == DataLocation.STORAGE: - # CMC 2023-12-31 - use same allocator for storage and transient - # for now, this should be refactored soon. - pass - else: - raise CompilerPanic("Incompatible locations") + assert isinstance(position, VarOffset) # sanity check self.position = position + def is_module_variable(self): + return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + + def get_size(self) -> int: + return self.typ.get_size_in(self.location) + @property def is_transient(self): return self.location == DataLocation.TRANSIENT @@ -260,9 +201,13 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None + module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE + # the chain of attribute parents for this expr + attribute_chain: list["ExprInfo"] = field(default_factory=list) + def __post_init__(self): should_match = ("typ", "location", "modifiability") if self.var_info is not None: @@ -270,65 +215,48 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") + self._writes: OrderedSet[VarInfo] = OrderedSet() + self._reads: OrderedSet[VarInfo] = OrderedSet() + + # find exprinfo in the attribute chain which has a varinfo + # e.x. `x` will return varinfo for `x` + # `module.foo` will return varinfo for `module.foo` + # `self.my_struct.x.y` will return varinfo for `self.my_struct` + def get_root_varinfo(self) -> Optional[VarInfo]: + for expr_info in self.attribute_chain + [self]: + if expr_info.var_info is not None: + return expr_info.var_info + return None + @classmethod - def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, + attribute_chain=attribute_chain or [], ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": - return cls(module_info.module_t) + def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + modifiability = Modifiability.RUNTIME_CONSTANT + if module_info.ownership >= ModuleOwnership.USES: + modifiability = Modifiability.MODIFIABLE - def copy_with_type(self, typ: VyperType) -> "ExprInfo": + return cls( + module_info.module_t, + module_info=module_info, + modifiability=modifiability, + attribute_chain=attribute_chain or [], + ) + + def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} + if attribute_chain is not None: + fields["attribute_chain"] = attribute_chain return self.__class__(typ=typ, **fields) - - def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutability) -> None: - """ - Validate an attempt to modify this value. - - Raises if the value is a constant or involves an invalid operation. - - Arguments - --------- - node : Assign | AugAssign | Call - Vyper ast node of the modifying action. - mutability: StateMutability - The mutability of the context (e.g., pure function) we are currently in - """ - if mutability <= StateMutability.VIEW and self.location == DataLocation.STORAGE: - raise StateAccessViolation( - f"Cannot modify storage in a {mutability.value} function", node - ) - - if self.location == DataLocation.CALLDATA: - raise ImmutableViolation("Cannot write to calldata", node) - - if self.modifiability == Modifiability.RUNTIME_CONSTANT: - if self.location == DataLocation.CODE: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": - raise ImmutableViolation("Immutable value cannot be written to", node) - - # special handling for immutable variables in the ctor - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore - raise ImmutableViolation( - "Immutable value cannot be modified after assignment", node - ) - self.var_info._modification_count += 1 # type: ignore - else: - raise ImmutableViolation("Environment variable cannot be written to", node) - - if self.modifiability == Modifiability.CONSTANT: - raise ImmutableViolation("Constant value cannot be written to", node) - - if isinstance(node, vy_ast.AugAssign): - self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index bfcc473d09..3522383167 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -113,7 +113,7 @@ def visit_Attribute(self, node) -> vy_ast.ExprNode: varinfo = module_t.get_member(node.attr, node) return varinfo.decl_node.value.get_folded_value() - except (VyperException, AttributeError): + except (VyperException, AttributeError, KeyError): raise UnfoldableNode("not a module") def visit_UnaryOp(self, node): diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 88679a4b09..604bc6b594 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,11 +1,12 @@ -# TODO this module doesn't really belong in "validation" -from typing import Dict, List +from collections import defaultdict +from typing import Generic, TypeVar from vyper import ast as vy_ast -from vyper.exceptions import StorageLayoutException -from vyper.semantics.analysis.base import CodeOffset, StorageSlot +from vyper.evm.opcodes import version_check +from vyper.exceptions import CompilerPanic, StorageLayoutException +from vyper.semantics.analysis.base import VarOffset +from vyper.semantics.data_locations import DataLocation from vyper.typing import StorageLayout -from vyper.utils import ceil32 def set_data_positions( @@ -20,24 +21,76 @@ def set_data_positions( vyper_module : vy_ast.Module Top-level Vyper AST node that has already been annotated with type data. """ - code_offsets = set_code_offsets(vyper_module) - storage_slots = ( - set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - if storage_layout_overrides is not None - else set_storage_slots(vyper_module) - ) + if storage_layout_overrides is not None: + # extract code layout with no overrides + code_offsets = _allocate_layout_r(vyper_module, immutables_only=True)["code_layout"] + storage_slots = set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) + return {"storage_layout": storage_slots, "code_layout": code_offsets} - return {"storage_layout": storage_slots, "code_layout": code_offsets} + ret = _allocate_layout_r(vyper_module) + assert isinstance(ret, defaultdict) + return dict(ret) # convert back to dict -class StorageAllocator: +_T = TypeVar("_T") +_K = TypeVar("_K") + + +class InsertableOnceDict(Generic[_T, _K], dict[_T, _K]): + def __setitem__(self, k, v): + if k in self: + raise ValueError(f"{k} is already in dict!") + super().__setitem__(k, v) + + +class SimpleAllocator: + def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): + # Allocate storage slots from 0 + # note storage is word-addressable, not byte-addressable + self._slot = starting_slot + self._max_slot = max_slot + + def allocate_slot(self, n, var_name, node=None): + ret = self._slot + if self._slot + n >= self._max_slot: + raise StorageLayoutException( + f"Invalid storage slot, tried to allocate" + f" slots {self._slot} through {self._slot + n}", + node, + ) + self._slot += n + return ret + + +class Allocators: + storage_allocator: SimpleAllocator + transient_storage_allocator: SimpleAllocator + immutables_allocator: SimpleAllocator + + def __init__(self): + self.storage_allocator = SimpleAllocator(max_slot=2**256) + self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) + self.immutables_allocator = SimpleAllocator(max_slot=0x6000) + + def get_allocator(self, location: DataLocation): + if location == DataLocation.STORAGE: + return self.storage_allocator + if location == DataLocation.TRANSIENT: + return self.transient_storage_allocator + if location == DataLocation.CODE: + return self.immutables_allocator + + raise CompilerPanic("unreachable") # pragma: nocover + + +class OverridingStorageAllocator: """ Keep track of which storage slots have been used. If there is a collision of storage slots, this will raise an error and fail to compile """ def __init__(self): - self.occupied_slots: Dict[int, str] = {} + self.occupied_slots: dict[int, str] = {} def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> None: """ @@ -48,7 +101,7 @@ def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> No list_to_check = [x + first_slot for x in range(n_slots)] self._reserve_slots(list_to_check, var_name) - def _reserve_slots(self, slots: List[int], var_name: str) -> None: + def _reserve_slots(self, slots: list[int], var_name: str) -> None: for slot in slots: self._reserve_slot(slot, var_name) @@ -70,12 +123,13 @@ def set_storage_slots_with_overrides( vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout ) -> StorageLayout: """ - Parse module-level Vyper AST to calculate the layout of storage variables. + Set storage layout given a layout override file. Returns the layout as a dict of variable name -> variable info + (Doesn't handle modules, or transient storage) """ - ret: Dict[str, Dict] = {} - reserved_slots = StorageAllocator() + ret: InsertableOnceDict[str, dict] = InsertableOnceDict() + reserved_slots = OverridingStorageAllocator() # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): @@ -90,7 +144,7 @@ def set_storage_slots_with_overrides( # re-entrant key was already identified if variable_name in ret: _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -100,7 +154,7 @@ def set_storage_slots_with_overrides( # from using the same slot reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - type_.set_reentrancy_key_position(StorageSlot(reentrant_slot)) + type_.set_reentrancy_key_position(VarOffset(reentrant_slot)) ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} else: @@ -125,7 +179,7 @@ def set_storage_slots_with_overrides( # Ensure that all required storage slots are reserved, and prevents other variables # from using these slots reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(StorageSlot(var_slot)) + varinfo.set_position(VarOffset(var_slot)) ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} else: @@ -138,105 +192,108 @@ def set_storage_slots_with_overrides( return ret -class SimpleStorageAllocator: - def __init__(self, starting_slot: int = 0): - self._slot = starting_slot +def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: + allocable = (vy_ast.InitializesDecl, vy_ast.VariableDecl) + return [node for node in vyper_module.body if isinstance(node, allocable)] - def allocate_slot(self, n, var_name): - ret = self._slot - if self._slot + n >= 2**256: - raise StorageLayoutException( - f"Invalid storage slot for var {var_name}, tried to allocate" - f" slots {self._slot} through {self._slot + n}" - ) - self._slot += n - return ret +def get_reentrancy_key_location() -> DataLocation: + if version_check(begin="cancun"): + return DataLocation.TRANSIENT + return DataLocation.STORAGE -def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: + +_LAYOUT_KEYS = { + DataLocation.CODE: "code_layout", + DataLocation.TRANSIENT: "transient_storage_layout", + DataLocation.STORAGE: "storage_layout", +} + + +def _allocate_layout_r( + vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False +) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ - # Allocate storage slots from 0 - # note storage is word-addressable, not byte-addressable - allocator = SimpleStorageAllocator() + if allocators is None: + allocators = Allocators() - ret: Dict[str, Dict] = {} + ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) for node in vyper_module.get_children(vy_ast.FunctionDef): + if immutables_only: + break + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue variable_name = f"nonreentrant.{type_.nonreentrant}" + reentrancy_key_location = get_reentrancy_key_location() + layout_key = _LAYOUT_KEYS[reentrancy_key_location] # a nonreentrant key can appear many times in a module but it # only takes one slot. after the first time we see it, do not # increment the storage slot. - if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + if variable_name in ret[layout_key]: + _slot = ret[layout_key][variable_name]["slot"] + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, variable_name) + allocator = allocators.get_allocator(reentrancy_key_location) + slot = allocator.allocate_slot(1, variable_name, node) - type_.set_reentrancy_key_position(StorageSlot(slot)) + type_.set_reentrancy_key_position(VarOffset(slot)) # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[variable_name] = {"type": "nonreentrant lock", "slot": slot} - - for node in vyper_module.get_children(vy_ast.VariableDecl): - # skip non-storage variables - if node.is_constant or node.is_immutable: + ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + module_layout = _allocate_layout_r(module_info.module_node, allocators) + module_alias = module_info.alias + for layout_key in module_layout.keys(): + assert layout_key in _LAYOUT_KEYS.values() + ret[layout_key][module_alias] = module_layout[layout_key] continue + assert isinstance(node, vy_ast.VariableDecl) + # skip non-storage variables varinfo = node.target._metadata["varinfo"] - type_ = varinfo.typ - - # CMC 2021-07-23 note that HashMaps get assigned a slot here. - # I'm not sure if it's safe to avoid allocating that slot - # for HashMaps because downstream code might use the slot - # ID as a salt. - n_slots = type_.storage_size_in_words - slot = allocator.allocate_slot(n_slots, node.target.id) - - varinfo.set_position(StorageSlot(slot)) - - # this could have better typing but leave it untyped until - # we understand the use case better - ret[node.target.id] = {"type": str(type_), "slot": slot} - - return ret - - -def set_calldata_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + if not varinfo.is_module_variable(): + continue + location = varinfo.location + if immutables_only and location != DataLocation.CODE: + continue -def set_memory_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + allocator = allocators.get_allocator(location) + size = varinfo.get_size() + # CMC 2021-07-23 note that HashMaps get assigned a slot here + # using the same allocator (even though there is not really + # any risk of physical overlap) + offset = allocator.allocate_slot(size, node.target.id, node) -def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: - ret = {} - offset = 0 + varinfo.set_position(VarOffset(offset)) - for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): - varinfo = node.target._metadata["varinfo"] + layout_key = _LAYOUT_KEYS[location] type_ = varinfo.typ - varinfo.set_position(CodeOffset(offset)) - - len_ = ceil32(type_.size_in_bytes) - # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "offset": offset, "length": len_} - - offset += len_ + if location == DataLocation.CODE: + item = {"type": str(type_), "length": size, "offset": offset} + elif location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + item = {"type": str(type_), "slot": offset} + else: # pragma: nocover + raise CompilerPanic("unreachable") + ret[layout_key][node.target.id] = item return ret diff --git a/vyper/semantics/analysis/global_.py b/vyper/semantics/analysis/global_.py new file mode 100644 index 0000000000..92cdf35c5d --- /dev/null +++ b/vyper/semantics/analysis/global_.py @@ -0,0 +1,80 @@ +from collections import defaultdict + +from vyper.exceptions import ExceptionList, InitializerException +from vyper.semantics.analysis.base import InitializesInfo, UsesInfo +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.module import validate_module_semantics_r +from vyper.semantics.types.module import ModuleT + + +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + ret = validate_module_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + _validate_global_initializes_constraint(ret) + + return ret + + +def _collect_used_modules_r(module_t): + ret: defaultdict[ModuleT, list[UsesInfo]] = defaultdict(list) + + for uses_decl in module_t.uses_decls: + for used_module in uses_decl._metadata["uses_info"].used_modules: + ret[used_module.module_t].append(uses_decl) + + # recurse + used_modules = _collect_used_modules_r(used_module.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + # also recurse into modules used by initialized modules + for i in module_t.initialized_modules: + used_modules = _collect_used_modules_r(i.module_info.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + return ret + + +def _collect_initialized_modules_r(module_t, seen=None): + seen: dict[ModuleT, InitializesInfo] = seen or {} + + # list of InitializedInfo + initialized_infos = module_t.initialized_modules + + for i in initialized_infos: + initialized_module_t = i.module_info.module_t + if initialized_module_t in seen: + seen_nodes = (i.node, seen[initialized_module_t].node) + raise InitializerException(f"`{i.module_info.alias}` initialized twice!", *seen_nodes) + seen[initialized_module_t] = i + + _collect_initialized_modules_r(initialized_module_t, seen) + + return seen + + +# validate that each module which is `used` in the import graph is +# `initialized`. +def _validate_global_initializes_constraint(module_t: ModuleT): + all_used_modules = _collect_used_modules_r(module_t) + all_initialized_modules = _collect_initialized_modules_r(module_t) + + err_list = ExceptionList() + + for u, uses in all_used_modules.items(): + if u not in all_initialized_modules: + found_module = module_t.find_module_info(u) + if found_module is not None: + hint = f"add `initializes: {found_module.alias}` to the top level of " + hint += "your main contract" + else: + # CMC 2024-02-06 is this actually reachable? + hint = f"ensure `{module_t}` is imported in your main contract!" + err_list.append( + InitializerException( + f"module `{u}` is used but never initialized!", *uses, hint=hint + ) + ) + + err_list.raise_if_not_empty() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91cc0ebdf8..d96215ede0 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,8 +1,11 @@ +# CMC 2024-02-03 TODO: split me into function.py and expr.py + from typing import Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.exceptions import ( + CallViolation, ExceptionList, FunctionDeclarationException, ImmutableViolation, @@ -16,7 +19,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, VarInfo +from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -54,13 +57,12 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" - err_list = ExceptionList() namespace = get_namespace() for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - analyzer = FunctionNodeVisitor(vy_module, node, namespace) + analyzer = FunctionAnalyzer(vy_module, node, namespace) analyzer.analyze() except VyperException as e: err_list.append(e) @@ -181,7 +183,7 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) -class FunctionNodeVisitor(VyperNodeVisitorBase): +class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -192,7 +194,7 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self) def analyze(self): # allow internal function params to be mutable @@ -270,21 +272,94 @@ def _assign_helper(self, node): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) target = get_expr_info(node.target) - if isinstance(target.typ, HashMapT): - raise StructureException( - "Left-hand side of assignment cannot be a HashMap without a key", node - ) - target.validate_modification(node, self.func.mutability) + # check mutability of the function + self._handle_modification(node.target) self.expr_visitor.visit(node.value, target.typ) self.expr_visitor.visit(node.target, target.typ) + def _handle_modification(self, target: vy_ast.ExprNode): + if isinstance(target, vy_ast.Tuple): + for item in target.elements: + self._handle_modification(item) + return + + # check a modification of `target`. validate the modification is + # valid, and log the modification in relevant data structures. + func_t = self.func + info = get_expr_info(target) + + if isinstance(info.typ, HashMapT): + raise StructureException( + "Left-hand side of assignment cannot be a HashMap without a key" + ) + + if ( + info.location in (DataLocation.STORAGE, DataLocation.TRANSIENT) + and func_t.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot modify {info.location} variable in a {func_t.mutability} function" + ) + + if info.location == DataLocation.CALLDATA: + raise ImmutableViolation("Cannot write to calldata") + + if info.modifiability == Modifiability.RUNTIME_CONSTANT: + if info.location == DataLocation.CODE: + if not func_t.is_constructor: + raise ImmutableViolation("Immutable value cannot be written to") + + # handle immutables + if info.var_info is not None: # don't handle complex (struct,array) immutables + # special handling for immutable variables in the ctor + # TODO: maybe we want to remove this restriction. + if info.var_info._modification_count != 0: + raise ImmutableViolation( + "Immutable value cannot be modified after assignment" + ) + info.var_info._modification_count += 1 + else: + raise ImmutableViolation("Environment variable cannot be written to") + + if info.modifiability == Modifiability.CONSTANT: + raise ImmutableViolation("Constant value cannot be written to.") + + var_info = info.get_root_varinfo() + assert var_info is not None + + info._writes.add(var_info) + + def _check_module_use(self, target: vy_ast.ExprNode): + module_infos = [] + for t in get_expr_info(target).attribute_chain: + if t.module_info is not None: + module_infos.append(t.module_info) + + if len(module_infos) == 0: + return + + for module_info in module_infos: + if module_info.ownership < ModuleOwnership.USES: + msg = f"Cannot access `{module_info.alias}` state!" + hint = f"add `uses: {module_info.alias}` or " + hint += f"`initializes: {module_info.alias}` as " + hint += "a top-level statement to your contract" + raise ImmutableViolation(msg, hint=hint) + + # the leftmost- referenced module + root_module_info = module_infos[0] + + # log the access + self.func._used_modules.add(root_module_info) + def visit_Assign(self, node): self._assign_helper(node) def visit_AugAssign(self, node): self._assign_helper(node) + node.target._expr_info.typ.validate_numeric_op(node) def visit_Break(self, node): for_node = node.get_ancestor(vy_ast.For) @@ -309,35 +384,13 @@ def visit_Expr(self, node): raise StructureException("Expressions without assignment are disallowed", node) fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): raise StructureException("To call an event you must use the `log` statement", node) if is_type_t(fn_type, StructT): raise StructureException("Struct creation without assignment is disallowed", node) - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - # NOTE: fetch_call_return validates call args. return_value = map_void(fn_type.fetch_call_return(node.value)) if ( @@ -457,7 +510,7 @@ def visit_Log(self, node): raise StructureException("Value is not an event", node.value) if self.func.mutability <= StateMutability.VIEW: raise StructureException( - f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node + f"Cannot emit logs from {self.func.mutability} functions", node ) t = map_void(f.fetch_call_return(node.value)) # CMC 2024-02-05 annotate the event type for codegen usage @@ -493,10 +546,20 @@ def visit_Return(self, node): class ExprVisitor(VyperNodeVisitorBase): - scope_name = "function" + def __init__(self, function_analyzer: Optional[FunctionAnalyzer] = None): + self.function_analyzer = function_analyzer + + @property + def func(self): + if self.function_analyzer is None: + return None + return self.function_analyzer.func - def __init__(self, fn_node: Optional[ContractFunctionT] = None): - self.func = fn_node + @property + def scope_name(self): + if self.func is not None: + return "function" + return "module" def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): @@ -509,6 +572,24 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + if not isinstance(typ, TYPE_T): + info = get_expr_info(node) # get_expr_info fills in node._expr_info + + # log variable accesses. + # (note writes will get logged as both read+write) + varinfo = info.var_info + if varinfo is not None: + info._reads.add(varinfo) + + if self.func: + variable_accesses = info._writes | info._reads + for s in variable_accesses: + if s.is_module_variable(): + self.function_analyzer._check_module_use(node) + + self.func._variable_writes.update(info._writes) + self.func._variable_reads.update(info._reads) + # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() @@ -547,42 +628,77 @@ def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: for value in node.values: self.visit(value, BoolT()) + def _check_call_mutability(self, call_mutability: StateMutability): + # note: payable can be called from nonpayable functions + ok = ( + call_mutability <= self.func.mutability + or self.func.mutability >= StateMutability.NONPAYABLE + ) + if not ok: + msg = f"Cannot call a {call_mutability} function from a {self.func.mutability} function" + raise StateAccessViolation(msg) + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: - call_type = get_exact_type_from_node(node.func) - self.visit(node.func, call_type) + func_info = get_expr_info(node.func, is_callable=True) + func_type = func_info.typ + self.visit(node.func, func_type) - if isinstance(call_type, ContractFunctionT): + if isinstance(func_type, ContractFunctionT): # function calls - if self.func and call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): + + func_info._writes.update(func_type._variable_writes) + func_info._reads.update(func_type._variable_reads) + + if self.function_analyzer: + if func_type.is_internal: + self.func.called_functions.add(func_type) + + self._check_call_mutability(func_type.mutability) + + # check that if the function accesses state, the defining + # module has been `used` or `initialized`. + for s in func_type._variable_accesses: + if s.is_module_variable(): + self.function_analyzer._check_module_use(node.func) + + if func_type.is_deploy and not self.func.is_deploy: + raise CallViolation( + f"Cannot call an @{func_type.visibility} function from " + f"an @{self.func.visibility} function!", + node, + ) + + for arg, typ in zip(node.args, func_type.argument_types): self.visit(arg, typ) for kwarg in node.keywords: # We should only see special kwargs - typ = call_type.call_site_kwargs[kwarg.arg].typ + typ = func_type.call_site_kwargs[kwarg.arg].typ self.visit(kwarg.value, typ) - elif is_type_t(call_type, EventT): + elif is_type_t(func_type, EventT): # events have no kwargs - expected_types = call_type.typedef.arguments.values() + expected_types = func_type.typedef.arguments.values() # type: ignore for arg, typ in zip(node.args, expected_types): self.visit(arg, typ) - elif is_type_t(call_type, StructT): + elif is_type_t(func_type, StructT): # struct ctors # ctors have no kwargs - expected_types = call_type.typedef.members.values() + expected_types = func_type.typedef.members.values() # type: ignore for value, arg_type in zip(node.args[0].values, expected_types): self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): + elif isinstance(func_type, MemberFunctionT): + if func_type.is_modifying and self.function_analyzer is not None: + # TODO refactor this + self.function_analyzer._handle_modification(node.func) + assert len(node.args) == len(func_type.arg_types) + for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) + arg_types = func_type.infer_arg_types(node, expected_return_typ=typ) # type: ignore for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) + kwarg_types = func_type.infer_kwarg_types(node) # type: ignore for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) @@ -638,8 +754,10 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func and self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func: + # TODO: refactor to use expr_info mutability + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index a83c2f3b7d..e50c3e6d6f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -8,38 +8,50 @@ from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( + BorrowException, CallViolation, DuplicateImport, ExceptionList, + ImmutableViolation, + InitializerException, InvalidLiteral, InvalidType, ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, - SyntaxException, + UndeclaredDefinition, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ( + ImportInfo, + InitializesInfo, + Modifiability, + ModuleInfo, + ModuleOwnership, + UsesInfo, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + get_expr_info, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation +from vyper.utils import OrderedSet -def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: - return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) - - -def validate_semantics_r( +def validate_module_semantics_r( module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: ImportGraph, @@ -49,6 +61,11 @@ def validate_semantics_r( Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ + if "type" in module_ast._metadata: + # we don't need to analyse again, skip out + assert isinstance(module_ast._metadata["type"], ModuleT) + return module_ast._metadata["type"] + validate_literal_nodes(module_ast) # validate semantics and annotate AST with type/semantics information @@ -64,6 +81,8 @@ def validate_semantics_r( # in `ContractFunction.from_vyi()` if not is_interface: validate_functions(module_ast) + analyzer.validate_initialized_modules() + analyzer.validate_used_modules() return ret @@ -121,11 +140,8 @@ def __init__( def analyze(self) -> ModuleT: # generate a `ModuleT` from the top-level node # note: also validates unique method ids - if "type" in self.ast._metadata: - assert isinstance(self.ast._metadata["type"], ModuleT) - # we don't need to analyse again, skip out - self.module_t = self.ast._metadata["type"] - return self.module_t + + assert "type" not in self.ast._metadata to_visit = self.ast.body.copy() @@ -138,6 +154,11 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + ownership_decls = self.ast.get_children((vy_ast.UsesDecl, vy_ast.InitializesDecl)) + for node in ownership_decls: + self.visit(node) + to_visit.remove(node) + # we can resolve constants after imports are handled. constant_fold(self.ast) @@ -179,6 +200,7 @@ def analyze(self) -> ModuleT: def analyze_call_graph(self): # get list of internal function calls made by each function + # CMC 2024-02-03 note: this could be cleaner in analysis/local.py function_defs = self.module_t.function_defs for func in function_defs: @@ -195,7 +217,9 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: @@ -204,6 +228,106 @@ def analyze_call_graph(self): # compute reachable set and validate the call graph _compute_reachable_set(fn_t) + def validate_used_modules(self): + # check all `uses:` modules are actually used + should_use = {} + + module_t = self.ast._metadata["type"] + uses_decls = module_t.uses_decls + for decl in uses_decls: + info = decl._metadata["uses_info"] + for m in info.used_modules: + should_use[m.module_t] = (m, info) + + initialized_modules = {t.module_info.module_t: t for t in module_t.initialized_modules} + + all_used_modules = OrderedSet() + + for f in module_t.functions.values(): + for u in f._used_modules: + all_used_modules.add(u.module_t) + + for used_module in all_used_modules: + if used_module in initialized_modules: + continue + + if used_module in should_use: + del should_use[used_module] + + if len(should_use) > 0: + err_list = ExceptionList() + for used_module_info, uses_info in should_use.values(): + msg = f"`{used_module_info.alias}` is declared as used, but " + msg += f"it is not actually used in {module_t}!" + hint = f"delete `uses: {used_module_info.alias}`" + err_list.append(BorrowException(msg, uses_info.node, hint=hint)) + + err_list.raise_if_not_empty() + + def validate_initialized_modules(self): + # check all `initializes:` modules have `__init__()` called exactly once + module_t = self.ast._metadata["type"] + should_initialize = {t.module_info.module_t: t for t in module_t.initialized_modules} + # don't call `__init__()` for modules which don't have + # `__init__()` function + for m in should_initialize.copy(): + for f in m.functions.values(): + if f.is_constructor: + break + else: + del should_initialize[m] + + init_calls = [] + for f in self.ast.get_children(vy_ast.FunctionDef): + if f._metadata["func_type"].is_constructor: + init_calls = f.get_descendants(vy_ast.Call) + break + + seen_initializers = {} + for call_node in init_calls: + expr_info = call_node.func._expr_info + if expr_info is None: + # this can happen for range() calls; CMC 2024-02-05 try to + # refactor so that range() is properly tagged. + continue + + call_t = call_node.func._expr_info.typ + + if not isinstance(call_t, ContractFunctionT): + continue + + if not call_t.is_constructor: + continue + + # XXX: check this works as expected for nested attributes + initialized_module = call_node.func.value._expr_info.module_info + + if initialized_module.module_t in seen_initializers: + seen_location = seen_initializers[initialized_module.module_t] + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but its __init__() function was already called!" + raise InitializerException(msg, call_node.func, seen_location) + + if initialized_module.module_t not in should_initialize: + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but it is not in initializer list!" + hint = f"add `initializes: {initialized_module.alias}` " + hint += "as a top-level statement to your contract" + raise InitializerException(msg, call_node.func, hint=hint) + + del should_initialize[initialized_module.module_t] + seen_initializers[initialized_module.module_t] = call_node.func + + if len(should_initialize) > 0: + err_list = ExceptionList() + for s in should_initialize.values(): + msg = "not initialized!" + hint = f"add `{s.module_info.alias}.__init__()` to " + hint += "your `__init__()` function" + err_list.append(InitializerException(msg, s.node, hint=hint)) + + err_list.raise_if_not_empty() + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: # cache ast if we have seen it before. # this gives us the additional property of object equality on @@ -218,10 +342,100 @@ def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): - raise StructureException("Invalid interface name", node.annotation) + raise StructureException("not an interface!", node.annotation) type_.validate_implements(node) + def visit_UsesDecl(self, node): + # TODO: check duplicate uses declarations, e.g. + # uses: x + # ... + # uses: x + items = vy_ast.as_tuple(node.annotation) + + used_modules = [] + + for item in items: + module_info = get_expr_info(item).module_info + if module_info is None: + raise StructureException("not a valid module!", item) + + # note: try to refactor - not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.USES, item) + + used_modules.append(module_info) + + node._metadata["uses_info"] = UsesInfo(used_modules, node) + + def visit_InitializesDecl(self, node): + module_ref = node.annotation + dependencies_ast = () + if isinstance(module_ref, vy_ast.Subscript): + dependencies_ast = vy_ast.as_tuple(module_ref.slice) + module_ref = module_ref.value + + # postcondition of InitializesDecl.validates() + assert isinstance(module_ref, (vy_ast.Name, vy_ast.Attribute)) + + module_info = get_expr_info(module_ref).module_info + if module_info is None: + raise StructureException("Not a module!", module_ref) + + used_modules = {i.module_t: i for i in module_info.module_t.used_modules} + + dependencies = [] + for named_expr in dependencies_ast: + assert isinstance(named_expr, vy_ast.NamedExpr) + + rhs_module = get_expr_info(named_expr.value).module_info + + with module_info.module_node.namespace(): + # lhs of the named_expr is evaluated in the namespace of the + # initialized module! + try: + lhs_module = get_expr_info(named_expr.target).module_info + except VyperException as e: + # try to report a common problem - user names the module in + # the current namespace instead of the initialized module + # namespace. + + # search for the module in the initialized module + found_module = module_info.module_t.find_module_info(rhs_module.module_t) + if found_module is not None: + msg = f"unknown module `{named_expr.target.id}`" + hint = f"did you mean `{found_module.alias} := {rhs_module.alias}`?" + raise UndeclaredDefinition(msg, named_expr.target, hint=hint) + + raise e from None + + if lhs_module.module_t != rhs_module.module_t: + raise StructureException( + f"{lhs_module.alias} is not {rhs_module.alias}!", named_expr + ) + dependencies.append(lhs_module) + + if lhs_module.module_t not in used_modules: + raise InitializerException( + f"`{module_info.alias}` is initialized with `{lhs_module.alias}`, " + f"but `{module_info.alias}` does not use `{lhs_module.alias}`!", + named_expr, + ) + + del used_modules[lhs_module.module_t] + + if len(used_modules) > 0: + item = next(iter(used_modules.values())) # just pick one + msg = f"`{module_info.alias}` uses `{item.alias}`, but it is not " + msg += f"initialized with `{item.alias}`" + hint = f"add `{item.alias}` to its initializer list" + raise InitializerException(msg, node, hint=hint) + + # note: try to refactor. not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.INITIALIZES, node) + node._metadata["initializes_info"] = InitializesInfo(module_info, dependencies, node) + def visit_VariableDecl(self, node): name = node.get("target.id") if name is None: @@ -250,7 +464,7 @@ def visit_VariableDecl(self, node): if len(wrong_self_attribute) > 0 else "Immutable definition requires an assignment in the constructor" ) - raise SyntaxException(message, node.node_source_code, node.lineno, node.col_offset) + raise ImmutableViolation(message, node) data_loc = ( DataLocation.CODE @@ -364,11 +578,10 @@ def visit_Import(self, node): # don't handle things like `import x.y` if "." in alias: + msg = "import requires an accompanying `as` statement" suggested_alias = node.name[node.name.rfind(".") :] - suggestion = f"hint: try `import {node.name} as {suggested_alias}`" - raise StructureException( - f"import requires an accompanying `as` statement ({suggestion})", node - ) + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) self._add_import(node, 0, node.name, alias) @@ -436,14 +649,14 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - module_t = validate_semantics_r( + module_t = validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, is_interface=False, ) - return ModuleInfo(module_t) + return ModuleInfo(module_t, alias) except FileNotFoundError as e: # escape `e` from the block scope, it can make things @@ -456,7 +669,7 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - validate_semantics_r( + validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, @@ -481,7 +694,7 @@ def _load_import_helper( raise ModuleNotFound(module_str, node) from err -def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: +def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, @@ -542,5 +755,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: interface_ast = _parse_and_fold_ast(file) with override_global_namespace(Namespace()): - module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + module_t = validate_module_semantics_r( + interface_ast, input_bundle, ImportGraph(), is_interface=True + ) return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index abbf6a68cc..f1f0f48a86 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -61,8 +61,8 @@ class _ExprAnalyser: def __init__(self): self.namespace = get_namespace() - def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: - t = self.get_exact_type_from_node(node) + def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> ExprInfo: + t = self.get_exact_type_from_node(node, include_type_exprs=is_callable) # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): @@ -74,7 +74,10 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: if isinstance(info, ModuleInfo): return ExprInfo.from_moduleinfo(info) - raise CompilerPanic("unreachable!", node) + if isinstance(info, VyperType): + return ExprInfo(TYPE_T(info)) + + raise CompilerPanic(f"unreachable! {info}", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and @@ -82,30 +85,27 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # note: Attribute(expr value, identifier attr) name = node.attr - info = self.get_expr_info(node.value) + info = self.get_expr_info(node.value, is_callable=is_callable) + + attribute_chain = info.attribute_chain + [info] t = info.typ.get_member(name, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t) + return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) - # it's something else, like my_struct.foo - return info.copy_with_type(t) + if isinstance(t, ModuleInfo): + return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) - if isinstance(node, vy_ast.Tuple): - # always use the most restrictive location re: modification - # kludge! for validate_modification in local analysis of Assign - types = [self.get_expr_info(n) for n in node.elements] - location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - - return ExprInfo(t, location=location, modifiability=modifiability) + # it's something else, like my_struct.foo + return info.copy_with_type(t, attribute_chain=attribute_chain) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - return info.copy_with_type(t) + attribute_chain = info.attribute_chain + [info] + return info.copy_with_type(t, attribute_chain=attribute_chain) return ExprInfo(t) @@ -184,6 +184,7 @@ def _find_fn(self, node): def types_from_Attribute(self, node): is_self_reference = node.get("value.id") == "self" + # variable attribute, e.g. `foo.bar` t = self.get_exact_type_from_node(node.value, include_type_exprs=True) name = node.attr @@ -476,8 +477,10 @@ def get_exact_type_from_node(node): return _ExprAnalyser().get_exact_type_from_node(node, include_type_exprs=True) -def get_expr_info(node: vy_ast.VyperNode) -> ExprInfo: - return _ExprAnalyser().get_expr_info(node) +def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo: + if node._expr_info is None: + node._expr_info = _ExprAnalyser().get_expr_info(node, is_callable) + return node._expr_info def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List: @@ -639,7 +642,7 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: +def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. """ @@ -665,5 +668,5 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> if hasattr(call_type, "check_modifiability_for_call"): return call_type.check_modifiability_for_call(node, modifiability) - value_type = get_expr_info(node) - return value_type.modifiability >= modifiability + info = get_expr_info(node) + return info.modifiability <= modifiability diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index cecea35a60..06245aa90d 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -1,10 +1,12 @@ import enum +from vyper.utils import StringEnum -class DataLocation(enum.Enum): - UNSET = 0 - MEMORY = 1 - STORAGE = 2 - CALLDATA = 3 - CODE = 4 - TRANSIENT = 5 + +class DataLocation(StringEnum): + UNSET = enum.auto() + MEMORY = enum.auto() + STORAGE = enum.auto() + CALLDATA = enum.auto() + CODE = enum.auto() + TRANSIENT = enum.auto() diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index d659276ee0..c5e10b52be 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -13,6 +13,7 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.data_locations import DataLocation # Some fake type with an overridden `compare_type` which accepts any RHS @@ -25,7 +26,11 @@ def __init__(self, type_): self.type_ = type_ def compare_type(self, other): - return isinstance(other, self.type_) or self == other + if isinstance(other, self.type_): + return True + # compare two GenericTypeAcceptors -- they are the same if the base + # type is the same + return isinstance(other, self.__class__) and other.type_ == self.type_ class VyperType: @@ -91,6 +96,8 @@ def __hash__(self): return hash(self._get_equality_attrs()) def __eq__(self, other): + if self is other: + return True return ( type(self) is type(other) and self._get_equality_attrs() == other._get_equality_attrs() ) @@ -118,6 +125,16 @@ def abi_type(self) -> ABIType: """ raise CompilerPanic("Method must be implemented by the inherited class") + def get_size_in(self, location: DataLocation): + if location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + return self.storage_size_in_words + if location == DataLocation.MEMORY: + return self.memory_bytes_required + if location == DataLocation.CODE: + return self.memory_bytes_required + + raise CompilerPanic("unreachable: invalid location {location}") # pragma: nocover + @property def memory_bytes_required(self) -> int: # alias for API compatibility with codegen @@ -341,8 +358,10 @@ def map_void(typ: Optional[VyperType]) -> VyperType: # A type type. Used internally for types which can live in expression # position, ex. constructors (events, interfaces and structs), and also # certain builtins which take types as parameters -class TYPE_T: +class TYPE_T(VyperType): def __init__(self, typedef): + super().__init__() + self.typedef = typedef def __repr__(self): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 2d92370b9d..62f9c60585 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -19,8 +19,10 @@ from vyper.semantics.analysis.base import ( FunctionVisibility, Modifiability, + ModuleInfo, StateMutability, - StorageSlot, + VarInfo, + VarOffset, ) from vyper.semantics.analysis.utils import ( check_modifiability, @@ -112,10 +114,27 @@ def __init__( # recursively reachable from this function self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() + # writes to variables from this function + self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + + # reads of variables from this function + self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + + # list of modules used (accessed state) by this function + self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() + # to be populated during codegen self._ir_info: Any = None self._function_id: Optional[int] = None + @property + def _variable_accesses(self): + return self._variable_reads | self._variable_writes + + @property + def modifiability(self): + return Modifiability.from_state_mutability(self.mutability) + @cached_property def call_site_kwargs(self): # special kwargs that are allowed in call site @@ -269,9 +288,11 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - "function body in an interface can only be ...!", funcdef + "function body in an interface can only be `...`!", funcdef ) + assert function_visibility is not None # mypy hint + return cls( funcdef.name, positional_args, @@ -314,13 +335,19 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Default function may not receive any arguments", funcdef.args.args[0] ) + if function_visibility == FunctionVisibility.DEPLOY and funcdef.name != "__init__": + raise FunctionDeclarationException( + "Only constructors can be marked as `@deploy`!", funcdef + ) if funcdef.name == "__init__": - if ( - state_mutability in (StateMutability.PURE, StateMutability.VIEW) - or function_visibility == FunctionVisibility.INTERNAL - ): + if state_mutability in (StateMutability.PURE, StateMutability.VIEW): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef + "Constructor cannot be marked as `@pure` or `@view`", funcdef + ) + if function_visibility != FunctionVisibility.DEPLOY: + raise FunctionDeclarationException( + f"Constructor must be marked as `@deploy`, not `@{function_visibility}`", + funcdef, ) if return_type is not None: raise FunctionDeclarationException( @@ -333,6 +360,9 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Constructor may not use default arguments", funcdef.args.defaults[0] ) + # sanity check + assert function_visibility is not None + return cls( funcdef.name, positional_args, @@ -344,14 +374,11 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ast_def=funcdef, ) - def set_reentrancy_key_position(self, position: StorageSlot) -> None: + def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") if self.nonreentrant is None: raise CompilerPanic(f"No reentrant key {self}") - # sanity check even though implied by the type - if position._location != DataLocation.STORAGE: - raise CompilerPanic("Non-storage reentrant key") self.reentrancy_key_position = position @classmethod @@ -456,6 +483,14 @@ def is_external(self) -> bool: def is_internal(self) -> bool: return self.visibility == FunctionVisibility.INTERNAL + @property + def is_deploy(self) -> bool: + return self.visibility == FunctionVisibility.DEPLOY + + @property + def is_constructor(self) -> bool: + return self.name == "__init__" + @property def is_mutable(self) -> bool: return self.mutability > StateMutability.VIEW @@ -464,10 +499,6 @@ def is_mutable(self) -> bool: def is_payable(self) -> bool: return self.mutability == StateMutability.PAYABLE - @property - def is_constructor(self) -> bool: - return self.name == "__init__" - @property def is_fallback(self) -> bool: return self.name == "__default__" @@ -535,20 +566,14 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: modified_line = re.sub( kwarg_pattern, kwarg.value.node_source_code, node.node_source_code ) - error_suggestion = ( - f"\n(hint: Try removing the kwarg: `{modified_line}`)" - if modified_line != node.node_source_code - else "" - ) - raise ArgumentException( - ( - "Usage of kwarg in Vyper is restricted to " - + ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) - + f". {error_suggestion}" - ), - kwarg, - ) + msg = "Usage of kwarg in Vyper is restricted to " + msg += ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) + + hint = None + if modified_line != node.node_source_code: + hint = f"Try removing the kwarg: `{modified_line}`" + raise ArgumentException(msg, kwarg, hint=hint) return self.return_type @@ -601,7 +626,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: function_visibility = None state_mutability = None nonreentrant_key = None @@ -632,7 +657,9 @@ def _parse_decorators( if FunctionVisibility.is_valid_value(decorator.id): if function_visibility is not None: raise FunctionDeclarationException( - f"Visibility is already set to: {function_visibility}", funcdef + f"Visibility is already set to: {function_visibility}", + decorator, + hint="only one visibility decorator is allowed per function", ) function_visibility = FunctionVisibility(decorator.id) @@ -748,6 +775,10 @@ def __init__( self.return_type = return_type self.is_modifying = is_modifying + @property + def modifiability(self): + return Modifiability.MODIFIABLE if self.is_modifying else Modifiability.RUNTIME_CONSTANT + def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ee1da22a87..86840f4f91 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Optional +from typing import TYPE_CHECKING, Optional from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType @@ -16,12 +16,16 @@ validate_expected_type, validate_unique_method_ids, ) +from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType +if TYPE_CHECKING: + from vyper.semantics.analysis.base import ModuleInfo + class InterfaceT(_UserType): _type_members = {"address": AddressT()} @@ -234,7 +238,7 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for node in module_t.function_defs: func_t = node._metadata["func_type"] - if not func_t.is_external: + if not (func_t.is_external or func_t.is_constructor): continue funcs.append((node.name, func_t)) @@ -276,6 +280,12 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): _attribute_in_annotation = True + _invalid_locations = ( + DataLocation.CALLDATA, + DataLocation.CODE, + DataLocation.MEMORY, + DataLocation.TRANSIENT, + ) def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -307,7 +317,6 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.interface_defs: # add the type of the interface so it can be used in call position self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore - self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -316,6 +325,13 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): import_info = i._metadata["import_info"] self.add_member(import_info.alias, import_info.typ) + if hasattr(import_info.typ, "module_t"): + self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + + for name, interface_t in self.interfaces.items(): + # can access interfaces in type position + self._helper.add_member(name, TYPE_T(interface_t)) + # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle, # search path, symlinked vs normalized path, etc.) @@ -345,27 +361,97 @@ def struct_defs(self): def interface_defs(self): return self._module.get_children(vy_ast.InterfaceDef) + @cached_property + def interfaces(self) -> dict[str, InterfaceT]: + ret = {} + for i in self.interface_defs: + assert i.name not in ret # precondition + ret[i.name] = i._metadata["interface_type"] + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + if isinstance(import_info.typ, InterfaceT): + assert import_info.alias not in ret # precondition + ret[import_info.alias] = import_info.typ + + return ret + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + @cached_property + def imported_modules(self) -> dict[str, "ModuleInfo"]: + ret = {} + for s in self.import_stmts: + info = s._metadata["import_info"] + module_info = info.typ + if isinstance(module_info, InterfaceT): + continue + ret[info.alias] = module_info + return ret + + def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: + for s in self.imported_modules.values(): + if s.module_t == needle: + return s + return None + @property def variable_decls(self): return self._module.get_children(vy_ast.VariableDecl) + @property + def uses_decls(self): + return self._module.get_children(vy_ast.UsesDecl) + + @property + def initializes_decls(self): + return self._module.get_children(vy_ast.InitializesDecl) + + @cached_property + def used_modules(self): + # modules which are written to + ret = [] + for node in self.uses_decls: + for used_module in node._metadata["uses_info"].used_modules: + ret.append(used_module) + return ret + + @property + def initialized_modules(self): + # modules which are initialized to + ret = [] + for node in self.initializes_decls: + info = node._metadata["initializes_info"] + ret.append(info) + return ret + @cached_property def variables(self): # variables that this module defines, ex. # `x: uint256` is a private storage variable named x return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + @cached_property + def functions(self): + return {f.name: f._metadata["func_type"] for f in self.function_defs} + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable] @cached_property def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + ret = 0 + for s in self.immutables: + ret += s.typ.memory_bytes_required + + for initializes_info in self.initialized_modules: + module_t = initializes_info.module_info.module_t + ret += module_t.immutable_section_bytes + + return ret @cached_property def interface(self): diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 5564570536..c6a4531df8 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -117,16 +117,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: if isinstance(node, vy_ast.Attribute): # ex. SomeModule.SomeStruct - # sanity check - we only allow modules/interfaces to be - # imported as `Name`s currently. - if not isinstance(node.value, vy_ast.Name): + if isinstance(node.value, vy_ast.Attribute): + module_or_interface = _type_from_annotation(node.value) + elif isinstance(node.value, vy_ast.Name): + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + else: raise InvalidType(err_msg, node) - try: - module_or_interface = namespace[node.value.id] # type: ignore - except UndeclaredDefinition: - raise InvalidType(err_msg, node) from None - if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo module_or_interface = module_or_interface.module_t diff --git a/vyper/utils.py b/vyper/utils.py index 2349731b97..b2284eaba0 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -1,6 +1,7 @@ import binascii import contextlib import decimal +import enum import functools import sys import time @@ -8,7 +9,7 @@ import warnings from typing import Generic, List, TypeVar, Union -from vyper.exceptions import DecimalOverrideException, InvalidLiteral +from vyper.exceptions import CompilerPanic, DecimalOverrideException, InvalidLiteral _T = TypeVar("_T") @@ -62,6 +63,59 @@ def copy(self): return self.__class__(super().copy()) +class StringEnum(enum.Enum): + # Must be first, or else won't work, specifies what .value is + def _generate_next_value_(name, start, count, last_values): + return name.lower() + + # Override ValueError with our own internal exception + @classmethod + def _missing_(cls, value): + raise CompilerPanic(f"{value} is not a valid {cls.__name__}") + + @classmethod + def is_valid_value(cls, value: str) -> bool: + return value in set(o.value for o in cls) + + @classmethod + def options(cls) -> List["StringEnum"]: + return list(cls) + + @classmethod + def values(cls) -> List[str]: + return [v.value for v in cls.options()] + + # Comparison operations + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + return self is other + + # Python normally does __ne__(other) ==> not self.__eq__(other) + + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + options = self.__class__.options() + return options.index(self) < options.index(other) # type: ignore + + def __le__(self, other: object) -> bool: + return self.__eq__(other) or self.__lt__(other) + + def __gt__(self, other: object) -> bool: + return not self.__le__(other) + + def __ge__(self, other: object) -> bool: + return not self.__lt__(other) + + def __str__(self) -> str: + return self.value + + def __hash__(self) -> int: + # let `dataclass` know that this class is not mutable + return super().__hash__() + + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): if name == "prec": From cc7c19885b217539c0045b9bd26fed2e1fe76e5e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 07:08:25 -0800 Subject: [PATCH 180/201] fix: fuzz test not updated to use TypeMismatch (#3768) this is a regression introduced in c6b29c7f06a; the exception thrown by `validate_expected_type()` was updated to be `TypeMismatch`, but this test was not correspondingly updated. --- tests/functional/builtins/folding/test_bitwise.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index c1ff7674bb..f63ef8484a 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -1,9 +1,9 @@ import pytest -from hypothesis import given, settings +from hypothesis import example, given, settings from hypothesis import strategies as st from tests.utils import parse_and_fold -from vyper.exceptions import InvalidType, OverflowException +from vyper.exceptions import OverflowException, TypeMismatch from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -66,9 +66,10 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing -@settings(max_examples=50) +@settings(max_examples=51) @pytest.mark.parametrize("op", ["<<", ">>"]) @given(a=st_sint256, b=st.integers(min_value=0, max_value=256)) +@example(a=128, b=248) # throws TypeMismatch def test_bitwise_shift_signed(get_contract, a, b, op): source = f""" @external @@ -84,7 +85,7 @@ def foo(a: int256, b: uint256) -> int256: validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. - except (InvalidType, OverflowException): + except (TypeMismatch, OverflowException): # check the wrapped value matches runtime assert op == "<<" assert contract.foo(a, b) == unsigned_to_signed((a << b) % (2**256), 256) From 37ef8f4b54375a458e8b708cf3c41877b5f1655e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 11 Feb 2024 07:08:53 -0800 Subject: [PATCH 181/201] chore: run mypy as part of lint rule in Makefile (#3771) and remove the separate mypy rule. this makes the development workflow a bit faster --- Makefile | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 645b800e79..649b381012 100644 --- a/Makefile +++ b/Makefile @@ -17,11 +17,8 @@ dev-init: test: pytest -mypy: - tox -e mypy - lint: - tox -e lint + tox -e lint,mypy docs: rm -f docs/vyper.rst From 261e3d9349cd8acc202b3c63f16c73ef45035c1b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 14:24:10 -0800 Subject: [PATCH 182/201] fix: `StringEnum._generate_next_value_ signature` (#3770) per the documentation, `_generate_next_value_` should be a staticmethod. reference: https://docs.python.org/3/library/enum.html#enum.Enum._generate_next_value_ --- vyper/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/utils.py b/vyper/utils.py index b2284eaba0..ab4d789aa4 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -65,6 +65,7 @@ def copy(self): class StringEnum(enum.Enum): # Must be first, or else won't work, specifies what .value is + @staticmethod def _generate_next_value_(name, start, count, last_values): return name.lower() From a2eb60c713ee538ace46dde5c8ffbe625c1daa86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micka=C3=ABl=20Schoentgen?= Date: Mon, 12 Feb 2024 23:25:33 +0100 Subject: [PATCH 183/201] docs: adopt a new theme: `shibuya` (#3754) --- .readthedocs.yaml | 13 +- Dockerfile | 2 +- README.md | 12 +- docs/_static/css/dark.css | 215 --------------------------- docs/_static/css/toggle.css | 77 ---------- docs/_static/js/toggle.js | 26 ---- docs/built-in-functions.rst | 166 ++++++++++----------- docs/compiler-exceptions.rst | 14 +- docs/compiling-a-contract.rst | 28 ++-- docs/conf.py | 141 +++--------------- docs/constants-and-vars.rst | 6 +- docs/contributing.rst | 2 +- docs/control-structures.rst | 32 ++-- docs/event-logging.rst | 8 +- docs/index.rst | 2 +- docs/interfaces.rst | 24 +-- docs/logo.svg | 4 + docs/natspec.rst | 10 +- docs/scoping-and-declarations.rst | 32 ++-- docs/statements.rst | 16 +- docs/structure-of-a-contract.rst | 24 +-- docs/testing-contracts-brownie.rst | 9 +- docs/testing-contracts-ethtester.rst | 11 +- docs/types.rst | 20 +-- docs/vyper-by-example.rst | 78 +++++----- docs/vyper-logo-transparent.svg | 11 -- examples/tokens/ERC20.vy | 2 +- requirements-docs.txt | 4 +- tox.ini | 4 +- 29 files changed, 286 insertions(+), 707 deletions(-) delete mode 100644 docs/_static/css/dark.css delete mode 100644 docs/_static/css/toggle.css delete mode 100644 docs/_static/js/toggle.js create mode 100644 docs/logo.svg delete mode 100644 docs/vyper-logo-transparent.svg diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 1ad9000f53..e7f5fa079a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,23 +1,20 @@ -# File: .readthedocs.yaml - version: 2 -# Set the version of Python and other tools you might need build: # TODO: update to `latest` once supported # https://github.com/readthedocs/readthedocs.org/issues/8861 os: ubuntu-22.04 tools: - python: "3.10" + python: "3.11" -# Build from the docs/ directory with Sphinx sphinx: configuration: docs/conf.py -formats: all - +# We can't use "all" because "htmlzip" format is broken for now +formats: + - epub + - pdf -# Optionally declare the Python requirements required to build your docs python: install: - requirements: requirements-docs.txt diff --git a/Dockerfile b/Dockerfile index bc5bb607d6..b4bfa6d3a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG VCS_REF LABEL org.label-schema.build-date=$BUILD_DATE \ org.label-schema.name="Vyper" \ org.label-schema.description="Vyper is an experimental programming language" \ - org.label-schema.url="https://vyper.readthedocs.io/en/latest/" \ + org.label-schema.url="https://docs.vyperlang.org/en/latest/" \ org.label-schema.vcs-ref=$VCS_REF \ org.label-schema.vcs-url="https://github.com/vyperlang/vyper" \ org.label-schema.vendor="Vyper Team" \ diff --git a/README.md b/README.md index 33c4557cc8..b14b7eaaf0 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Build Status](https://github.com/vyperlang/vyper/workflows/Test/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/test.yml) -[![Documentation Status](https://readthedocs.org/projects/vyper/badge/?version=latest)](http://vyper.readthedocs.io/en/latest/?badge=latest "ReadTheDocs") +[![Documentation Status](https://readthedocs.org/projects/vyper/badge/?version=latest)](http://docs.vyperlang.org/en/latest/?badge=latest "ReadTheDocs") [![Discord](https://img.shields.io/discord/969926564286459934.svg?label=%23vyper)](https://discord.gg/6tw7PTM7C2) [![PyPI](https://badge.fury.io/py/vyper.svg)](https://pypi.org/project/vyper "PyPI") @@ -13,9 +13,9 @@ [![Language grade: Python](https://github.com/vyperlang/vyper/workflows/CodeQL/badge.svg)](https://github.com/vyperlang/vyper/actions/workflows/codeql.yml) # Getting Started -See [Installing Vyper](http://vyper.readthedocs.io/en/latest/installing-vyper.html) to install vyper. +See [Installing Vyper](http://docs.vyperlang.org/en/latest/installing-vyper.html) to install vyper. See [Tools and Resources](https://github.com/vyperlang/vyper/wiki/Vyper-tools-and-resources) for an additional list of framework and tools with vyper support. -See [Documentation](http://vyper.readthedocs.io/en/latest/index.html) for the documentation and overall design goals of the Vyper language. +See [Documentation](http://docs.vyperlang.org/en/latest/index.html) for the documentation and overall design goals of the Vyper language. See [Learn.Vyperlang.org](https://learn.vyperlang.org/) for **learning Vyper by building a Pokémon game**. See [try.vyperlang.org](https://try.vyperlang.org/) to use Vyper in a hosted jupyter environment! @@ -23,7 +23,7 @@ See [try.vyperlang.org](https://try.vyperlang.org/) to use Vyper in a hosted jup **Note: Vyper is beta software, use with care** # Installation -See the [Vyper documentation](https://vyper.readthedocs.io/en/latest/installing-vyper.html) +See the [Vyper documentation](https://docs.vyperlang.org/en/latest/installing-vyper.html) for build instructions. # Compiling a contract @@ -47,7 +47,7 @@ be a bit behind the latest version found in the master branch of this repository ## Testing (using pytest) -(Complete [installation steps](https://vyper.readthedocs.io/en/latest/installing-vyper.html) first.) +(Complete [installation steps](https://docs.vyperlang.org/en/latest/installing-vyper.html) first.) ```bash make dev-init @@ -75,4 +75,4 @@ To get a call graph from a python profile, https://stackoverflow.com/a/23164271/ * See Issues tab, and feel free to submit your own issues * Add PRs if you discover a solution to an existing issue * For further discussions and questions, post in [Discussions](https://github.com/vyperlang/vyper/discussions) or talk to us on [Discord](https://discord.gg/6tw7PTM7C2) -* For more information, see [Contributing](http://vyper.readthedocs.io/en/latest/contributing.html) +* For more information, see [Contributing](http://docs.vyperlang.org/en/latest/contributing.html) diff --git a/docs/_static/css/dark.css b/docs/_static/css/dark.css deleted file mode 100644 index 158f08e0fc..0000000000 --- a/docs/_static/css/dark.css +++ /dev/null @@ -1,215 +0,0 @@ -/* links */ - -a, -a:visited { - color: #aaddff; -} - - -/* code directives */ - -.method dt, -.class dt, -.data dt, -.attribute dt, -.function dt, -.classmethod dt, -.exception dt, -.descclassname, -.descname { - background-color: #2d2d2d !important; -} - -.descname { - color: inherit !important; -} - -.rst-content dl:not(.docutils) dt { - color: #aaddff; - border-top: solid 3px #525252; - border-left: solid 3px #525252; -} - -em.property { - color: #888888; -} - - -/* tables */ - -.rst-content table.docutils thead { - color: #ddd; -} - -.rst-content table.docutils td { - border: 0px; -} - -.rst-content table.docutils:not(.field-list) tr:nth-child(2n-1) td { - background-color: #5a5a5a; -} - - -/* inlined code highlights */ - -.xref, -.py-meth, -.rst-content a code { - color: #aaddff !important; - font-weight: normal !important; -} - -.rst-content code { - color: #eee !important; - font-weight: normal !important; -} - -code.literal { - background-color: #2d2d2d !important; - border: 1px solid #6d6d6d !important; -} - -code.docutils.literal.notranslate { - color: #ddd; -} - - -/* code examples */ - -pre { - background: #222; - color: #ddd; - font-size: 150%; - border-color: #333 !important; -} - -.copybutton { - color: #666 !important; - border-color: #333 !important; -} - -.highlight .go, -.highlight .nb, -.highlight .kn { - /* text */ - color: #ddd; - font-weight: normal; -} - -.highlight .o, -.highlight .p { - /* comparators, parentheses */ - color: #bbb; -} - -.highlight .c1 { - /* comments */ - color: #888; -} - -.highlight .bp { - /* self */ - color: #fc3; -} - -.highlight .mf, -.highlight .mi, -.highlight .kc { - /* numbers, booleans */ - color: #c90; -} - -.highlight .gt, -.highlight .nf, -.highlight .fm { - /* functions */ - color: #7cf; -} - -.highlight .nd { - /* decorators */ - color: #f66; -} - -.highlight .k, -.highlight .ow { - /* statements */ - color: #A7F; - font-weight: normal; -} - -.highlight .s2, -.highlight .s1, -.highlight .nt { - /* strings */ - color: #5d6; -} - - -/* notes, warnings, hints */ - -.hint .admonition-title { - background: #2aa87c !important; -} - -.warning .admonition-title { - background: #cc4444 !important; -} - -.admonition-title { - background: #3a7ca8 !important; -} - -.admonition, -.note { - background-color: #2d2d2d !important; -} - - -/* table of contents */ - -.wy-body-for-nav { - background-color: rgb(26, 28, 29); -} - -.wy-nav-content-wrap { - background-color: rgba(0, 0, 0, 0.6) !important; -} - -.sidebar { - background-color: #191919 !important; -} - -.sidebar-title { - background-color: #2b2b2b !important; -} - -.wy-menu-vertical a { - color: #ddd; -} - -.wy-menu-vertical code.docutils.literal.notranslate { - color: #404040; - background: none !important; - border: none !important; -} - -.wy-nav-content { - background: #3c3c3c; - color: #dddddd; -} - -.wy-menu-vertical li.on a, -.wy-menu-vertical li.current>a { - background: #a3a3a3; - border-bottom: 0px !important; - border-top: 0px !important; -} - -.wy-menu-vertical li.current { - background: #b3b3b3; -} - -.toc-backref { - color: grey !important; -} \ No newline at end of file diff --git a/docs/_static/css/toggle.css b/docs/_static/css/toggle.css deleted file mode 100644 index ebbd0658a1..0000000000 --- a/docs/_static/css/toggle.css +++ /dev/null @@ -1,77 +0,0 @@ -input[type=checkbox] { - visibility: hidden; - height: 0; - width: 0; - margin: 0; -} - -.rst-versions .rst-current-version { - padding: 10px; - display: flex; - justify-content: space-between; -} - -.rst-versions .rst-current-version .fa-book, -.rst-versions .rst-current-version .fa-v, -.rst-versions .rst-current-version .fa-caret-down { - height: 24px; - line-height: 24px; - vertical-align: middle; -} - -.rst-versions .rst-current-version .fa-element { - width: 80px; - text-align: center; -} - -.rst-versions .rst-current-version .fa-book { - text-align: left; -} - -.rst-versions .rst-current-version .fa-v { - color: #27AE60; - text-align: right; -} - -label { - margin: 0 auto; - display: inline-block; - justify-content: center; - align-items: right; - border-radius: 100px; - position: relative; - cursor: pointer; - text-indent: -9999px; - width: 50px; - height: 21px; - background: #000; -} - -label:after { - border-radius: 50%; - position: absolute; - content: ''; - background: #fff; - width: 15px; - height: 15px; - top: 3px; - left: 3px; - transition: ease-in-out 200ms; -} - -input:checked+label { - background: #3a7ca8; -} - -input:checked+label:after { - left: calc(100% - 5px); - transform: translateX(-100%); -} - -html.transition, -html.transition *, -html.transition *:before, -html.transition *:after { - transition: ease-in-out 200ms !important; - transition-delay: 0 !important; -} \ No newline at end of file diff --git a/docs/_static/js/toggle.js b/docs/_static/js/toggle.js deleted file mode 100644 index df131042b5..0000000000 --- a/docs/_static/js/toggle.js +++ /dev/null @@ -1,26 +0,0 @@ -document.addEventListener('DOMContentLoaded', function() { - - var checkbox = document.querySelector('input[name=mode]'); - - function toggleCssMode(isDay) { - var mode = (isDay ? "Day" : "Night"); - localStorage.setItem("css-mode", mode); - - var darksheet = $('link[href="_static/css/dark.css"]')[0].sheet; - darksheet.disabled = isDay; - } - - if (localStorage.getItem("css-mode") == "Day") { - toggleCssMode(true); - checkbox.setAttribute('checked', true); - } - - checkbox.addEventListener('change', function() { - document.documentElement.classList.add('transition'); - window.setTimeout(() => { - document.documentElement.classList.remove('transition'); - }, 1000) - toggleCssMode(this.checked); - }) - -}); \ No newline at end of file diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index 45cf9ec8c2..afb64e71ca 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -14,14 +14,14 @@ Bitwise Operations Perform a "bitwise and" operation. Each bit of the output is 1 if the corresponding bit of ``x`` AND of ``y`` is 1, otherwise it is 0. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: uint256) -> uint256: return bitwise_and(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(31337, 8008135) 12353 @@ -34,14 +34,14 @@ Bitwise Operations Return the bitwise complement of ``x`` - the number you get by switching each 1 for a 0 and each 0 for a 1. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256) -> uint256: return bitwise_not(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(0) 115792089237316195423570985008687907853269984665640564039457584007913129639935 @@ -54,14 +54,14 @@ Bitwise Operations Perform a "bitwise or" operation. Each bit of the output is 0 if the corresponding bit of ``x`` AND of ``y`` is 0, otherwise it is 1. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: uint256) -> uint256: return bitwise_or(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(31337, 8008135) 8027119 @@ -74,14 +74,14 @@ Bitwise Operations Perform a "bitwise exclusive or" operation. Each bit of the output is the same as the corresponding bit in ``x`` if that bit in ``y`` is 0, and it is the complement of the bit in ``x`` if that bit in ``y`` is 1. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: uint256) -> uint256: return bitwise_xor(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(31337, 8008135) 8014766 @@ -94,14 +94,14 @@ Bitwise Operations Return ``x`` with the bits shifted ``_shift`` places. A positive ``_shift`` value equals a left shift, a negative value is a right shift. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256, y: int128) -> uint256: return shift(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(2, 8) 512 @@ -144,7 +144,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns the address of the newly created proxy contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. - .. code-block:: python + .. code-block:: vyper @external def foo(target: address) -> address: @@ -173,7 +173,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns the address of the created contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. If there is no code at ``target``, execution will revert. - .. code-block:: python + .. code-block:: vyper @external def foo(target: address) -> address: @@ -197,7 +197,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns the address of the created contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. If ``code_offset >= target.codesize`` (ex. if there is no code at ``target``), execution will revert. - .. code-block:: python + .. code-block:: vyper @external def foo(blueprint: address) -> address: @@ -213,7 +213,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui It is recommended to deploy blueprints with the ERC-5202 preamble ``0xFE7100`` to guard them from being called as regular contracts. This is particularly important for factories where the constructor has side effects (including ``SELFDESTRUCT``!), as those could get executed by *anybody* calling the blueprint contract directly. The ``code_offset=`` kwarg is provided to enable this pattern: - .. code-block:: python + .. code-block:: vyper @external def foo(blueprint: address) -> address: @@ -241,7 +241,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui Returns ``success`` in a tuple with return value if ``revert_on_failure`` is set to ``False``. - .. code-block:: python + .. code-block:: vyper @external @payable @@ -276,7 +276,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui * ``topics``: List of ``bytes32`` log topics. The length of this array determines which opcode is used. * ``data``: Unindexed event data to include in the log. May be given as ``Bytes`` or ``bytes32``. - .. code-block:: python + .. code-block:: vyper @external def foo(_topic: bytes32, _data: Bytes[100]): @@ -288,7 +288,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui * ``data``: Data representing the error message causing the revert. - .. code-block:: python + .. code-block:: vyper @external def foo(_data: Bytes[100]): @@ -308,7 +308,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui This function has been deprecated from version 0.3.8 onwards. The underlying opcode will eventually undergo breaking changes, and its use is not recommended. - .. code-block:: python + .. code-block:: vyper @external def do_the_needful(): @@ -326,7 +326,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui The amount to send is always specified in ``wei``. - .. code-block:: python + .. code-block:: vyper @external def foo(_receiver: address, _amount: uint256, gas: uint256): @@ -339,14 +339,14 @@ Cryptography Take two points on the Alt-BN128 curve and add them together. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256[2], y: uint256[2]) -> uint256[2]: return ecadd(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo([1, 2], [1, 2]) [ @@ -361,14 +361,14 @@ Cryptography * ``point``: Point to be multiplied * ``scalar``: Scalar value - .. code-block:: python + .. code-block:: vyper @external @view def foo(point: uint256[2], scalar: uint256) -> uint256[2]: return ecmul(point, scalar) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo([1, 2], 3) [ @@ -390,7 +390,7 @@ Cryptography Prior to Vyper ``0.3.10``, the ``ecrecover`` function could return an undefined (possibly nonzero) value for invalid inputs to ``ecrecover``. For more information, please see `GHSA-f5x6-7qgp-jhf3 `_. - .. code-block:: python + .. code-block:: vyper @external @view @@ -402,7 +402,7 @@ Cryptography @view def foo(hash: bytes32, v: uint256, r:uint256, s:uint256) -> address: return ecrecover(hash, v, r, s) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo('0x6c9c5e133b8aafb2ea74f524a5263495e7ae5701c7248805f7b511d973dc7055', 28, @@ -417,14 +417,14 @@ Cryptography * ``_value``: Value to hash. Can be a ``String``, ``Bytes``, or ``bytes32``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(_value: Bytes[100]) -> bytes32 return keccak256(_value) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(b"potato") 0x9e159dfcfe557cc1ca6c716e87af98fdcb94cd8c832386d0429b2b7bec02754f @@ -435,14 +435,14 @@ Cryptography * ``_value``: Value to hash. Can be a ``String``, ``Bytes``, or ``bytes32``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(_value: Bytes[100]) -> bytes32 return sha256(_value) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(b"potato") 0xe91c254ad58860a02c788dfb5c1a65d6a8846ab1dc649631c7db16fef4af2dec @@ -456,14 +456,14 @@ Data Manipulation If the input arguments are ``String`` the return type is ``String``. Otherwise the return type is ``Bytes``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: String[5], b: String[5], c: String[5]) -> String[100]: return concat(a, " ", b, " ", c, "!") - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("why","hello","there") "why hello there!" @@ -487,14 +487,14 @@ Data Manipulation Returns the string representation of ``value``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(b: uint256) -> String[78]: return uint2str(b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(420) "420" @@ -509,14 +509,14 @@ Data Manipulation Returns a value of the type specified by ``output_type``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(b: Bytes[32]) -> address: return extract32(b, 0, output_type=address) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("0x0000000000000000000000009f8F72aA9304c8B593d555F12eF6589cC3A579A2") "0x9f8F72aA9304c8B593d555F12eF6589cC3A579A2" @@ -531,14 +531,14 @@ Data Manipulation If the value being sliced is a ``Bytes`` or ``bytes32``, the return type is ``Bytes``. If it is a ``String``, the return type is ``String``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(s: String[32]) -> String[5]: return slice(s, 4, 5) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("why hello! how are you?") "hello" @@ -552,14 +552,14 @@ Math * ``value``: Integer to return the absolute value of - .. code-block:: python + .. code-block:: vyper @external @view def foo(value: int256) -> int256: return abs(value) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(-31337) 31337 @@ -570,14 +570,14 @@ Math * ``value``: Decimal value to round up - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: decimal) -> int256: return ceil(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(3.1337) 4 @@ -588,14 +588,14 @@ Math * ``typename``: Name of the decimal type (currently only ``decimal``) - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> decimal: return epsilon(decimal) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() Decimal('1E-10') @@ -606,14 +606,14 @@ Math * ``value``: Decimal value to round down - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: decimal) -> int256: return floor(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(3.1337) 3 @@ -622,14 +622,14 @@ Math Return the greater value of ``a`` and ``b``. The input values may be any numeric type as long as they are both of the same type. The output value is of the same type as the input values. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256) -> uint256: return max(a, b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(23, 42) 42 @@ -638,14 +638,14 @@ Math Returns the maximum value of the numeric type specified by ``type_`` (e.g., ``int128``, ``uint256``, ``decimal``). - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> int256: return max_value(int256) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() 57896044618658097711785492504343953926634992332820282019728792003956564819967 @@ -654,14 +654,14 @@ Math Returns the lesser value of ``a`` and ``b``. The input values may be any numeric type as long as they are both of the same type. The output value is of the same type as the input values. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256) -> uint256: return min(a, b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(23, 42) 23 @@ -670,14 +670,14 @@ Math Returns the minimum value of the numeric type specified by ``type_`` (e.g., ``int128``, ``uint256``, ``decimal``). - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> int256: return min_value(int256) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() -57896044618658097711785492504343953926634992332820282019728792003956564819968 @@ -688,14 +688,14 @@ Math This method is used to perform exponentiation without overflow checks. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256) -> uint256: return pow_mod256(a, b) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(2, 3) 8 @@ -706,14 +706,14 @@ Math Return the square root of the provided decimal number, using the Babylonian square root algorithm. - .. code-block:: python + .. code-block:: vyper @external @view def foo(d: decimal) -> decimal: return sqrt(d) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(9.0) 3.0 @@ -722,14 +722,14 @@ Math Return the (integer) square root of the provided integer number, using the Babylonian square root algorithm. The rounding mode is to round down to the nearest integer. For instance, ``isqrt(101) == 10``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(x: uint256) -> uint256: return isqrt(x) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(101) 10 @@ -738,14 +738,14 @@ Math Return the modulo of ``(a + b) % c``. Reverts if ``c == 0``. As this built-in function is intended to provides access to the underlying ``ADDMOD`` opcode, all intermediate calculations of this operation are not subject to the ``2 ** 256`` modulo according to the EVM specifications. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256, c: uint256) -> uint256: return uint256_addmod(a, b, c) - .. code-block:: python + .. code-block:: vyper >>> (6 + 13) % 8 3 @@ -756,14 +756,14 @@ Math Return the modulo from ``(a * b) % c``. Reverts if ``c == 0``. As this built-in function is intended to provides access to the underlying ``MULMOD`` opcode, all intermediate calculations of this operation are not subject to the ``2 ** 256`` modulo according to the EVM specifications. - .. code-block:: python + .. code-block:: vyper @external @view def foo(a: uint256, b: uint256, c: uint256) -> uint256: return uint256_mulmod(a, b, c) - .. code-block:: python + .. code-block:: vyper >>> (11 * 2) % 5 2 @@ -774,7 +774,7 @@ Math Add ``x`` and ``y``, without checking for overflow. ``x`` and ``y`` must both be integers of the same type. If the result exceeds the bounds of the input type, it will be wrapped. - .. code-block:: python + .. code-block:: vyper @external @view @@ -787,7 +787,7 @@ Math return unsafe_add(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1, 1) 2 @@ -805,7 +805,7 @@ Math Subtract ``x`` and ``y``, without checking for overflow. ``x`` and ``y`` must both be integers of the same type. If the result underflows the bounds of the input type, it will be wrapped. - .. code-block:: python + .. code-block:: vyper @external @view @@ -818,7 +818,7 @@ Math return unsafe_sub(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(4, 3) 1 @@ -837,7 +837,7 @@ Math Multiply ``x`` and ``y``, without checking for overflow. ``x`` and ``y`` must both be integers of the same type. If the result exceeds the bounds of the input type, it will be wrapped. - .. code-block:: python + .. code-block:: vyper @external @view @@ -850,7 +850,7 @@ Math return unsafe_mul(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1, 1) 1 @@ -872,7 +872,7 @@ Math Divide ``x`` and ``y``, without checking for division-by-zero. ``x`` and ``y`` must both be integers of the same type. If the denominator is zero, the result will (following EVM semantics) be zero. - .. code-block:: python + .. code-block:: vyper @external @view @@ -885,7 +885,7 @@ Math return unsafe_div(x, y) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1, 1) 1 @@ -910,14 +910,14 @@ Utilities * ``_value``: Value for the ether unit. Any numeric type may be used, however the value cannot be negative. * ``unit``: Ether unit name (e.g. ``"wei"``, ``"ether"``, ``"gwei"``, etc.) indicating the denomination of ``_value``. Must be given as a literal string. - .. code-block:: python + .. code-block:: vyper @external @view def foo(s: String[32]) -> uint256: return as_wei_value(1.337, "ether") - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo(1) 1337000000000000000 @@ -930,14 +930,14 @@ Utilities The EVM only provides access to the most recent 256 blocks. This function reverts if the block number is greater than or equal to the current block number or more than 256 blocks behind the current block. - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> bytes32: return blockhash(block.number - 16) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() 0xf3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 @@ -948,7 +948,7 @@ Utilities * ``typename``: Name of the type, except ``HashMap[_KeyType, _ValueType]`` - .. code-block:: python + .. code-block:: vyper @external @view @@ -959,14 +959,14 @@ Utilities Return the length of a given ``Bytes``, ``String`` or ``DynArray[_Type, _Integer]``. - .. code-block:: python + .. code-block:: vyper @external @view def foo(s: String[32]) -> uint256: return len(s) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo("hello") 5 @@ -980,14 +980,14 @@ Utilities Returns a value of the type specified by ``output_type``. - .. code-block:: python + .. code-block:: vyper @external @view def foo() -> Bytes[4]: return method_id('transfer(address,uint256)', output_type=Bytes[4]) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo() 0xa9059cbb @@ -1003,7 +1003,7 @@ Utilities Returns a bytestring whose max length is determined by the arguments. For example, encoding a ``Bytes[32]`` results in a ``Bytes[64]`` (first word is the length of the bytestring variable). - .. code-block:: python + .. code-block:: vyper @external @view @@ -1012,7 +1012,7 @@ Utilities y: Bytes[32] = b"234" return _abi_encode(x, y, method_id=method_id("foo()")) - .. code-block:: python + .. code-block:: vyper >>> ExampleContract.foo().hex() "c2985578" @@ -1033,7 +1033,7 @@ Utilities Returns the decoded value(s), with type as specified by `output_type`. - .. code-block:: python + .. code-block:: vyper @external @view diff --git a/docs/compiler-exceptions.rst b/docs/compiler-exceptions.rst index 395ce448ed..29b8b5c96e 100644 --- a/docs/compiler-exceptions.rst +++ b/docs/compiler-exceptions.rst @@ -58,7 +58,7 @@ of the error within the code: Raises when no valid type can be found for a literal value. - .. code-block:: python + .. code-block:: vyper @external def foo(): @@ -70,7 +70,7 @@ of the error within the code: Raises when using an invalid operator for a given type. - .. code-block:: python + .. code-block:: vyper @external def foo(): @@ -82,7 +82,7 @@ of the error within the code: Raises on an invalid reference to an existing definition. - .. code-block:: python + .. code-block:: vyper baz: int128 @@ -96,7 +96,7 @@ of the error within the code: Raises when using an invalid literal value for the given type. - .. code-block:: python + .. code-block:: vyper @external def foo(): @@ -132,7 +132,7 @@ of the error within the code: Raises when attempting to access ``msg.value`` from within a function that has not been marked as ``@payable``. - .. code-block:: python + .. code-block:: vyper @public def _foo(): @@ -174,7 +174,7 @@ of the error within the code: Raises when attempting to perform an action between two or more objects with known, dislike types. - .. code-block:: python + .. code-block:: vyper @external def foo(: @@ -215,7 +215,7 @@ CompilerPanic .. py:exception:: CompilerPanic - :: + .. code:: shell $ vyper v.vy Error compiling: v.vy diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index b529d1efb1..2b069c2add 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -20,20 +20,20 @@ vyper To compile a contract: -:: +.. code:: shell $ vyper yourFileName.vy Include the ``-f`` flag to specify which output formats to return. Use ``vyper --help`` for a full list of output options. -:: +.. code:: shell $ vyper -f abi,bytecode,bytecode_runtime,ir,asm,source_map,method_identifiers yourFileName.vy The ``-p`` flag allows you to set a root path that is used when searching for interface files to import. If none is given, it will default to the current working directory. See :ref:`searching_for_imports` for more information. -:: +.. code:: shell $ vyper -p yourProject yourProject/yourFileName.vy @@ -45,7 +45,7 @@ Storage Layout To display the default storage layout for a contract: -:: +.. code:: shell $ vyper -f layout yourFileName.vy @@ -53,7 +53,7 @@ This outputs a JSON object detailing the locations for all state variables as de To override the default storage layout for a contract: -:: +.. code:: shell $ vyper --storage-layout-file storageLayout.json yourFileName.vy @@ -69,19 +69,19 @@ vyper-json To compile from JSON supplied via ``stdin``: -:: +.. code:: shell $ vyper-json To compile from a JSON file: -:: +.. code:: shell $ vyper-json yourProject.json By default, the output is sent to ``stdout``. To redirect to a file, use the ``-o`` flag: -:: +.. code:: shell $ vyper-json -o compiled.json @@ -143,7 +143,7 @@ When you compile your contract code, you can specify the target Ethereum Virtual For instance, the adding the following pragma to a contract indicates that it should be compiled for the "shanghai" fork of the EVM. -.. code-block:: python +.. code-block:: vyper #pragma evm-version shanghai @@ -153,13 +153,13 @@ For instance, the adding the following pragma to a contract indicates that it sh When compiling via the ``vyper`` CLI, you can specify the EVM version option using the ``--evm-version`` flag: -:: +.. code:: shell $ vyper --evm-version [VERSION] When using the JSON interface, you can include the ``"evmVersion"`` key within the ``"settings"`` field: -.. code-block:: javascript +.. code-block:: json { "settings": { @@ -200,8 +200,6 @@ The following is a list of supported EVM versions, and changes in the compiler i - The ``MCOPY`` opcode will be generated automatically by the compiler for most memory operations. - - Compiler Input and Output JSON Description ========================================== @@ -216,7 +214,7 @@ Input JSON Description The following example describes the expected input format of ``vyper-json``. Comments are of course not permitted and used here *only for explanatory purposes*. -.. code-block:: javascript +.. code-block:: json { // Required: Source code language. Must be set to "Vyper". @@ -294,7 +292,7 @@ Output JSON Description The following example describes the output format of ``vyper-json``. Comments are of course not permitted and used here *only for explanatory purposes*. -.. code-block:: javascript +.. code-block:: json { // The compiler version used to generate the JSON diff --git a/docs/conf.py b/docs/conf.py index 5dc1eee8f5..99ffe35a63 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,57 +1,12 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# # Vyper documentation build configuration file, created by # sphinx-quickstart on Wed Jul 26 11:18:29 2017. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) -from recommonmark.parser import CommonMarkParser - -# TO DO - Create and Implement Vyper Lexer -# def setup(sphinx): -# sys.path.insert(0, os.path.abspath('./utils')) -# from SolidityLexer import SolidityLexer -# sphinx.add_lexer('Python', SolidityLexer()) - - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. extensions = [ - "sphinx.ext.autodoc", + "sphinx_copybutton", "sphinx.ext.intersphinx", ] -# Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -source_suffix = [".rst", ".md"] - -# The master toctree document. master_doc = "toctree" # General information about the project. @@ -59,68 +14,31 @@ copyright = "2017-2024 CC-BY-4.0 Vyper Team" author = "Vyper Team (originally created by Vitalik Buterin)" -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = "" -# The full version, including alpha/beta/rc tags. -release = "" - # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = "python" - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - +language = "vyper" # -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "sphinx_rtd_theme" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] - -html_css_files = ["css/toggle.css", "css/dark.css"] - -html_js_files = ["js/toggle.js"] - -html_logo = "vyper-logo-transparent.svg" - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# The default sidebars (for documents that don't match any pattern) are -# defined by theme itself. Builtin themes are using these templates by -# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', -# 'searchbox.html']``. -# -# html_sidebars = {} - +html_theme = "shibuya" +html_theme_options = { + "accent_color": "purple", + "twitter_creator": "vyperlang", + "twitter_site": "vyperlang", + "twitter_url": "https://twitter.com/vyperlang", + "github_url": "https://github.com/vyperlang", +} +html_favicon = "logo.svg" +html_logo = "logo.svg" + +# For the "Edit this page ->" link +html_context = { + "source_type": "github", + "source_user": "vyperlang", + "source_repo": "vyper", +} # -- Options for HTMLHelp output ------------------------------------------ @@ -130,21 +48,6 @@ # -- Options for LaTeX output --------------------------------------------- -latex_elements: dict = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). @@ -153,7 +56,7 @@ master_doc, "Vyper.tex", "Vyper Documentation", - "Vyper Team (originally created by Vitalik Buterin)", + author, "manual", ), ] @@ -183,10 +86,6 @@ ), ] -source_parsers = { - ".md": CommonMarkParser, -} - intersphinx_mapping = { "brownie": ("https://eth-brownie.readthedocs.io/en/stable", None), "pytest": ("https://docs.pytest.org/en/latest/", None), diff --git a/docs/constants-and-vars.rst b/docs/constants-and-vars.rst index 7f9c1408c5..00ce7a8ccc 100644 --- a/docs/constants-and-vars.rst +++ b/docs/constants-and-vars.rst @@ -56,7 +56,7 @@ Accessing State Variables ``self`` is used to access a contract's :ref:`state variables`, as shown in the following example: -.. code-block:: python +.. code-block:: vyper state_var: uint256 @@ -76,7 +76,7 @@ Calling Internal Functions ``self`` is also used to call :ref:`internal functions` within a contract: -.. code-block:: python +.. code-block:: vyper @internal def _times_two(amount: uint256) -> uint256: @@ -93,7 +93,7 @@ Custom Constants Custom constants can be defined at a global level in Vyper. To define a constant, make use of the ``constant`` keyword. -.. code-block:: python +.. code-block:: vyper TOTAL_SUPPLY: constant(uint256) = 10000000 total_supply: public(uint256) diff --git a/docs/contributing.rst b/docs/contributing.rst index 221600f930..55b2694424 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -5,7 +5,7 @@ Contributing Help is always appreciated! -To get started, you can try `installing Vyper `_ in order to familiarize +To get started, you can try `installing Vyper `_ in order to familiarize yourself with the components of Vyper and the build process. Also, it may be useful to become well-versed at writing smart-contracts in Vyper. diff --git a/docs/control-structures.rst b/docs/control-structures.rst index 14202cbae7..a0aa927261 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -10,7 +10,7 @@ Functions Functions are executable units of code within a contract. Functions may only be declared within a contract's :ref:`module scope `. -.. code-block:: python +.. code-block:: vyper @external def bid(): @@ -30,7 +30,7 @@ External Functions External functions (marked with the ``@external`` decorator) are a part of the contract interface and may only be called via transactions or from other contracts. -.. code-block:: python +.. code-block:: vyper @external def add_seven(a: int128) -> int128: @@ -52,7 +52,7 @@ Internal Functions Internal functions (marked with the ``@internal`` decorator) are only accessible from other functions within the same contract. They are called via the :ref:`self` object: -.. code-block:: python +.. code-block:: vyper @internal def _times_two(amount: uint256, two: uint256 = 2) -> uint256: @@ -77,7 +77,7 @@ You can optionally declare a function's mutability by using a :ref:`decorator )`` decorator places a lock on a function, and all functions with the same ```` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert. -.. code-block:: python +.. code-block:: vyper @external @nonreentrant("lock") @@ -133,7 +133,7 @@ This function is always named ``__default__``. It must be annotated with ``@exte If the function is annotated as ``@payable``, this function is executed whenever the contract is sent Ether (without data). This is why the default function cannot accept arguments - it is a design decision of Ethereum to make no differentiation between sending ether to a contract or a user address. -.. code-block:: python +.. code-block:: vyper event Payment: amount: uint256 @@ -169,7 +169,7 @@ The ``__init__`` Function ``__init__`` is a special initialization function that may only be called at the time of deploying a contract. It can be used to set initial values for storage variables. A common use case is to set an ``owner`` variable with the creator the contract: -.. code-block:: python +.. code-block:: vyper owner: address @@ -202,7 +202,7 @@ Decorator Description The ``if`` statement is a control flow construct used for conditional execution: -.. code-block:: python +.. code-block:: vyper if CONDITION: ... @@ -213,7 +213,7 @@ Note that unlike Python, Vyper does not allow implicit conversion from non-boole You can also include ``elif`` and ``else`` statements, to add more conditional statements and a body that executes when the conditionals are false: -.. code-block:: python +.. code-block:: vyper if CONDITION: ... @@ -227,7 +227,7 @@ You can also include ``elif`` and ``else`` statements, to add more conditional s The ``for`` statement is a control flow construct used to iterate over a value: -.. code-block:: python +.. code-block:: vyper for i in : ... @@ -239,7 +239,7 @@ Array Iteration You can use ``for`` to iterate through the values of any array variable: -.. code-block:: python +.. code-block:: vyper foo: int128[3] = [4, 23, 42] for i in foo: @@ -249,7 +249,7 @@ In the above, example, the loop executes three times with ``i`` assigned the val You can also iterate over a literal array, as long as a common type can be determined for each item in the array: -.. code-block:: python +.. code-block:: vyper for i in [4, 23, 42]: ... @@ -264,14 +264,14 @@ Range Iteration Ranges are created using the ``range`` function. The following examples are valid uses of ``range``: -.. code-block:: python +.. code-block:: vyper for i in range(STOP): ... ``STOP`` is a literal integer greater than zero. ``i`` begins as zero and increments by one until it is equal to ``STOP``. -.. code-block:: python +.. code-block:: vyper for i in range(stop, bound=N): ... @@ -280,7 +280,7 @@ Here, ``stop`` can be a variable with integer type, greater than zero. ``N`` mus Another use of range can be with ``START`` and ``STOP`` bounds. -.. code-block:: python +.. code-block:: vyper for i in range(START, STOP): ... @@ -291,7 +291,7 @@ Finally, it is possible to use ``range`` with runtime `start` and `stop` values In this case, Vyper checks at runtime that `end - start <= bound`. ``N`` must be a compile-time constant. -.. code-block:: python +.. code-block:: vyper for i in range(start, end, bound=N): ... diff --git a/docs/event-logging.rst b/docs/event-logging.rst index 904b179e70..4f350d6459 100644 --- a/docs/event-logging.rst +++ b/docs/event-logging.rst @@ -10,7 +10,7 @@ Example of Logging This example is taken from the `sample ERC20 contract `_ and shows the basic flow of event logging: -.. code-block:: python +.. code-block:: vyper # Events of the token. event Transfer: @@ -59,7 +59,7 @@ Declaring Events Let's look at an event declaration in more detail. -.. code-block:: python +.. code-block:: vyper event Transfer: sender: indexed(address) @@ -81,7 +81,7 @@ Event declarations look similar to struct declarations, containing one or more a Note that the first topic of a log record consists of the signature of the name of the event that occurred, including the types of its parameters. It is also possible to create an event with no arguments. In this case, use the ``pass`` statement: -.. code-block:: python +.. code-block:: vyper event Foo: pass @@ -92,7 +92,7 @@ Once an event is declared, you can log (send) events. You can send events as man Logging events is done using the ``log`` statement: -.. code-block:: python +.. code-block:: vyper log Transfer(msg.sender, _to, _amount) diff --git a/docs/index.rst b/docs/index.rst index 69d818cd69..8ee48cdb83 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,4 +1,4 @@ -.. image:: vyper-logo-transparent.svg +.. image:: logo.svg :width: 140px :alt: Vyper logo :align: center diff --git a/docs/interfaces.rst b/docs/interfaces.rst index ab220272d8..803b9daf18 100644 --- a/docs/interfaces.rst +++ b/docs/interfaces.rst @@ -12,7 +12,7 @@ Interfaces can be added to contracts either through inline definition, or by imp The ``interface`` keyword is used to define an inline external interface: -.. code-block:: python +.. code-block:: vyper interface FooBar: def calculate() -> uint256: view @@ -20,7 +20,7 @@ The ``interface`` keyword is used to define an inline external interface: The defined interface can then be used to make external calls, given a contract address: -.. code-block:: python +.. code-block:: vyper @external def test(foobar: FooBar): @@ -28,7 +28,7 @@ The defined interface can then be used to make external calls, given a contract The interface name can also be used as a type annotation for storage variables. You then assign an address value to the variable to access that interface. Note that casting an address to an interface is possible, e.g. ``FooBar()``: -.. code-block:: python +.. code-block:: vyper foobar_contract: FooBar @@ -42,7 +42,7 @@ The interface name can also be used as a type annotation for storage variables. Specifying ``payable`` or ``nonpayable`` annotation indicates that the call made to the external contract will be able to alter storage, whereas the ``view`` ``pure`` call will use a ``STATICCALL`` ensuring no storage can be altered during execution. Additionally, ``payable`` allows non-zero value to be sent along with the call. -.. code-block:: python +.. code-block:: vyper interface FooBar: def calculate() -> uint256: pure @@ -70,7 +70,7 @@ Keyword Description The ``default_return_value`` parameter can be used to handle ERC20 tokens affected by the missing return value bug in a way similar to OpenZeppelin's ``safeTransfer`` for Solidity: -.. code-block:: python +.. code-block:: vyper ERC20(USDT).transfer(msg.sender, 1, default_return_value=True) # returns True ERC20(USDT).transfer(msg.sender, 1) # reverts because nothing returned @@ -86,7 +86,7 @@ Interfaces are imported with ``import`` or ``from ... import`` statements. Imported interfaces are written using standard Vyper syntax. The body of each function is ignored when the interface is imported. If you are defining a standalone interface, it is normally specified by using a ``pass`` statement: -.. code-block:: python +.. code-block:: vyper @external def test1(): @@ -98,7 +98,7 @@ Imported interfaces are written using standard Vyper syntax. The body of each fu You can also import a fully implemented contract and Vyper will automatically convert it to an interface. It is even possible for a contract to import itself to gain access to its own interface. -.. code-block:: python +.. code-block:: vyper import greeter as Greeter @@ -118,7 +118,7 @@ Imports via ``import`` With absolute ``import`` statements, you **must** include an alias as a name for the imported package. In the following example, failing to include ``as Foo`` will raise a compile error: -.. code-block:: python +.. code-block:: vyper import contract.foo as Foo @@ -127,7 +127,7 @@ Imports via ``from ... import`` Using ``from`` you can perform both absolute and relative imports. You may optionally include an alias - if you do not, the name of the interface will be the same as the file. -.. code-block:: python +.. code-block:: vyper # without an alias from contract import foo @@ -137,7 +137,7 @@ Using ``from`` you can perform both absolute and relative imports. You may optio Relative imports are possible by prepending dots to the contract name. A single leading dot indicates a relative import starting with the current package. Two leading dots indicate a relative import from the parent of the current package: -.. code-block:: python +.. code-block:: vyper from . import foo from ..interfaces import baz @@ -162,7 +162,7 @@ Built-in Interfaces Vyper includes common built-in interfaces such as `ERC20 `_ and `ERC721 `_. These are imported from ``ethereum.ercs``: -.. code-block:: python +.. code-block:: vyper from ethereum.ercs import ERC20 @@ -175,7 +175,7 @@ Implementing an Interface You can define an interface for your contract with the ``implements`` statement: -.. code-block:: python +.. code-block:: vyper import an_interface as FooBarInterface diff --git a/docs/logo.svg b/docs/logo.svg new file mode 100644 index 0000000000..d2c666074a --- /dev/null +++ b/docs/logo.svg @@ -0,0 +1,4 @@ + + + + diff --git a/docs/natspec.rst b/docs/natspec.rst index a6c2d932e4..90ad5d39b4 100644 --- a/docs/natspec.rst +++ b/docs/natspec.rst @@ -17,7 +17,7 @@ Vyper supports structured documentation for contracts and external functions usi The compiler does not parse docstrings of internal functions. You are welcome to NatSpec in comments for internal functions, however they are not processed or included in the compiler output. -.. code-block:: python +.. code-block:: vyper """ @title A simulator for Bug Bunny, the most famous Rabbit @@ -72,16 +72,16 @@ When parsed by the compiler, documentation such as the one from the above exampl If the above contract is saved as ``carrots.vy`` then you can generate the documentation using: -.. code:: +.. code:: shell - vyper -f userdoc,devdoc carrots.vy + $ vyper -f userdoc,devdoc carrots.vy User Documentation ------------------ The above documentation will produce the following user documentation JSON as output: -.. code-block:: javascript +.. code-block:: json { "methods": { @@ -102,7 +102,7 @@ Developer Documentation Apart from the user documentation file, a developer documentation JSON file should also be produced and should look like this: -.. code-block:: javascript +.. code-block:: json { "author": "Warned Bros", diff --git a/docs/scoping-and-declarations.rst b/docs/scoping-and-declarations.rst index 7165ec6e4d..838720c25b 100644 --- a/docs/scoping-and-declarations.rst +++ b/docs/scoping-and-declarations.rst @@ -8,7 +8,7 @@ Variable Declaration The first time a variable is referenced you must declare its :ref:`type `: -.. code-block:: python +.. code-block:: vyper data: int128 @@ -25,7 +25,7 @@ Declaring Public Variables Storage variables can be marked as ``public`` during declaration: -.. code-block:: python +.. code-block:: vyper data: public(int128) @@ -38,7 +38,7 @@ Declaring Immutable Variables Variables can be marked as ``immutable`` during declaration: -.. code-block:: python +.. code-block:: vyper DATA: immutable(uint256) @@ -55,7 +55,7 @@ Tuple Assignment You cannot directly declare tuple types. However, in certain cases you can use literal tuples during assignment. For example, when a function returns multiple values: -.. code-block:: python +.. code-block:: vyper @internal def foo() -> (int128, int128): @@ -84,13 +84,13 @@ This can be performed when compiling via ``vyper`` by including the ``--storage For example, consider upgrading the following contract: -.. code-block:: python +.. code-block:: vyper # old_contract.vy owner: public(address) balanceOf: public(HashMap[address, uint256]) -.. code-block:: python +.. code-block:: vyper # new_contract.vy owner: public(address) @@ -101,7 +101,7 @@ This would cause an issue when upgrading, as the ``balanceOf`` mapping would be This issue can be avoided by allocating ``balanceOf`` to ``slot1`` using the storage layout overrides. The contract can be compiled with ``vyper new_contract.vy --storage-layout-file new_contract_storage.json`` where ``new_contract_storage.json`` contains the following: -.. code-block:: javascript +.. code-block:: json { "owner": {"type": "address", "slot": 0}, @@ -130,7 +130,7 @@ Accessing Module Scope from Functions Values that are declared in the module scope of a contract, such as storage variables and functions, are accessed via the ``self`` object: -.. code-block:: python +.. code-block:: vyper a: int128 @@ -148,7 +148,7 @@ Name Shadowing It is not permitted for a memory or calldata variable to shadow the name of an immutable or constant value. The following examples will not compile: -.. code-block:: python +.. code-block:: vyper a: constant(bool) = True @@ -157,7 +157,7 @@ It is not permitted for a memory or calldata variable to shadow the name of an i # memory variable cannot have the same name as a constant or immutable variable a: bool = False return a -.. code-block:: python +.. code-block:: vyper a: immutable(bool) @@ -174,7 +174,7 @@ Function Scope Variables that are declared within a function, or given as function input arguments, are visible within the body of that function. For example, the following contract is valid because each declaration of ``a`` only exists within one function's body. -.. code-block:: python +.. code-block:: vyper @external def foo(a: int128): @@ -190,14 +190,14 @@ Variables that are declared within a function, or given as function input argume The following examples will not compile: -.. code-block:: python +.. code-block:: vyper @external def foo(a: int128): # `a` has already been declared as an input argument a: int128 = 21 -.. code-block:: python +.. code-block:: vyper @external def foo(a: int128): @@ -215,7 +215,7 @@ Block Scopes Logical blocks created by ``for`` and ``if`` statements have their own scope. For example, the following contract is valid because ``x`` only exists within the block scopes for each branch of the ``if`` statement: -.. code-block:: python +.. code-block:: vyper @external def foo(a: bool) -> int128: @@ -226,7 +226,7 @@ Logical blocks created by ``for`` and ``if`` statements have their own scope. Fo In a ``for`` statement, the target variable exists within the scope of the loop. For example, the following contract is valid because ``i`` is no longer available upon exiting the loop: -.. code-block:: python +.. code-block:: vyper @external def foo(a: bool) -> int128: @@ -236,7 +236,7 @@ In a ``for`` statement, the target variable exists within the scope of the loop. The following contract fails to compile because ``a`` has not been declared outside of the loop. -.. code-block:: python +.. code-block:: vyper @external def foo(a: bool) -> int128: diff --git a/docs/statements.rst b/docs/statements.rst index 02854adffd..34f15828a1 100644 --- a/docs/statements.rst +++ b/docs/statements.rst @@ -13,7 +13,7 @@ break The ``break`` statement terminates the nearest enclosing ``for`` loop. -.. code-block:: python +.. code-block:: vyper for i in [1, 2, 3, 4, 5]: if i == a: @@ -26,7 +26,7 @@ continue The ``continue`` statement begins the next cycle of the nearest enclosing ``for`` loop. -.. code-block:: python +.. code-block:: vyper for i in [1, 2, 3, 4, 5]: if i != a: @@ -40,7 +40,7 @@ pass ``pass`` is a null operation — when it is executed, nothing happens. It is useful as a placeholder when a statement is required syntactically, but no code needs to be executed: -.. code-block:: python +.. code-block:: vyper # this function does nothing (yet!) @@ -53,7 +53,7 @@ return ``return`` leaves the current function call with the expression list (or None) as a return value. -.. code-block:: python +.. code-block:: vyper return RETURN_VALUE @@ -69,7 +69,7 @@ log The ``log`` statement is used to log an event: -.. code-block:: python +.. code-block:: vyper log MyEvent(...) @@ -89,7 +89,7 @@ raise The ``raise`` statement triggers an exception and reverts the current call. -.. code-block:: python +.. code-block:: vyper raise "something went wrong" @@ -100,7 +100,7 @@ assert The ``assert`` statement makes an assertion about a given condition. If the condition evaluates falsely, the transaction is reverted. -.. code-block:: python +.. code-block:: vyper assert x > 5, "value too low" @@ -108,7 +108,7 @@ The error string is not required. If it is provided, it is limited to 1024 bytes This method's behavior is equivalent to: -.. code-block:: python +.. code-block:: vyper if not cond: raise "reason" diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index 3861bf4380..561f3000dd 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -10,7 +10,7 @@ This section provides a quick overview of the types of data present within a con .. _structure-versions: Pragmas -============== +======= Vyper supports several source code directives to control compiler modes and help with build reproducibility. @@ -21,7 +21,7 @@ The version pragma ensures that a contract is only compiled by the intended comp As of 0.3.10, the recommended way to specify the version pragma is as follows: -.. code-block:: python +.. code-block:: vyper #pragma version ^0.3.0 @@ -31,7 +31,7 @@ As of 0.3.10, the recommended way to specify the version pragma is as follows: The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version: -.. code-block:: python +.. code-block:: vyper # @version ^0.3.0 @@ -43,7 +43,7 @@ Optimization Mode The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For example, adding the following line to a contract will cause it to try to optimize for codesize: -.. code-block:: python +.. code-block:: vyper #pragma optimize codesize @@ -62,13 +62,13 @@ State Variables State variables are values which are permanently stored in contract storage. They are declared outside of the body of any functions, and initially contain the :ref:`default value` for their type. -.. code-block:: python +.. code-block:: vyper storedData: int128 State variables are accessed via the :ref:`self` object. -.. code-block:: python +.. code-block:: vyper self.storedData = 123 @@ -81,7 +81,7 @@ Functions Functions are executable units of code within a contract. -.. code-block:: python +.. code-block:: vyper @external def bid(): @@ -96,7 +96,7 @@ Events Events provide an interface for the EVM's logging facilities. Events may be logged with specially indexed data structures that allow clients, including light clients, to efficiently search for them. -.. code-block:: python +.. code-block:: vyper event Payment: amount: int128 @@ -119,19 +119,19 @@ An interface is a set of function definitions used to enable calls between smart Interfaces can be added to contracts either through inline definition, or by importing them from a separate file. -.. code-block:: python +.. code-block:: vyper interface FooBar: def calculate() -> uint256: view def test1(): nonpayable -.. code-block:: python +.. code-block:: vyper from foo import FooBar Once defined, an interface can then be used to make external calls to a given address: -.. code-block:: python +.. code-block:: vyper @external def test(some_address: address): @@ -144,7 +144,7 @@ Structs A struct is a custom defined type that allows you to group several variables together: -.. code-block:: python +.. code-block:: vyper struct MyStruct: value1: int128 diff --git a/docs/testing-contracts-brownie.rst b/docs/testing-contracts-brownie.rst index bff871d38a..46d8df6ea6 100644 --- a/docs/testing-contracts-brownie.rst +++ b/docs/testing-contracts-brownie.rst @@ -12,7 +12,7 @@ Getting Started In order to use Brownie for testing you must first `initialize a new project `_. Create a new directory for the project, and from within that directory type: -:: +.. code:: shell $ brownie init @@ -24,12 +24,14 @@ Writing a Basic Test Assume the following simple contract ``Storage.vy``. It has a single integer variable and a function to set that value. .. literalinclude:: ../examples/storage/storage.vy - :language: python + :caption: storage.vy + :language: vyper :linenos: We create a test file ``tests/test_storage.py`` where we write our tests in pytest style. .. code-block:: python + :caption: test_storage.py :linenos: import pytest @@ -70,9 +72,10 @@ In this example we are using two fixtures which are provided by Brownie: Testing Events ============== -For the remaining examples, we expand our simple storage contract to include an event and two conditions for a failed transaction: ``AdvancedStorage.vy`` +For the remaining examples, we expand our simple storage contract to include an event and two conditions for a failed transaction: ``advanced_storage.vy`` .. literalinclude:: ../examples/storage/advanced_storage.vy + :caption: advanced_storage.vy :linenos: :language: python diff --git a/docs/testing-contracts-ethtester.rst b/docs/testing-contracts-ethtester.rst index 27e67831de..92522a1eca 100644 --- a/docs/testing-contracts-ethtester.rst +++ b/docs/testing-contracts-ethtester.rst @@ -17,6 +17,7 @@ Prior to testing, the Vyper specific contract conversion and the blockchain rela Since the testing is done in the pytest framework, you can make use of `pytest.ini, tox.ini and setup.cfg `_ and you can use most IDEs' pytest plugins. .. literalinclude:: ../tests/conftest.py + :caption: conftest.py :language: python :linenos: @@ -30,12 +31,14 @@ Writing a Basic Test Assume the following simple contract ``storage.vy``. It has a single integer variable and a function to set that value. .. literalinclude:: ../examples/storage/storage.vy + :caption: storage.vy :linenos: - :language: python + :language: vyper We create a test file ``test_storage.py`` where we write our tests in pytest style. .. literalinclude:: ../tests/functional/examples/storage/test_storage.py + :caption: test_storage.py :linenos: :language: python @@ -50,18 +53,21 @@ Events and Failed Transactions To test events and failed transactions we expand our simple storage contract to include an event and two conditions for a failed transaction: ``advanced_storage.vy`` .. literalinclude:: ../examples/storage/advanced_storage.vy + :caption: advanced_storage.vy :linenos: - :language: python + :language: vyper Next, we take a look at the two fixtures that will allow us to read the event logs and to check for failed transactions. .. literalinclude:: ../tests/conftest.py + :caption: conftest.py :language: python :pyobject: tx_failed The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction. .. literalinclude:: ../tests/conftest.py + :caption: conftest.py :language: python :pyobject: get_logs @@ -70,5 +76,6 @@ This fixture will return a tuple with all the logs for a certain event and trans Finally, we create a new file ``test_advanced_storage.py`` where we use the new fixtures to test failed transactions and events. .. literalinclude:: ../tests/functional/examples/storage/test_advanced_storage.py + :caption: test_advanced_storage.py :linenos: :language: python diff --git a/docs/types.rst b/docs/types.rst index 0f5bfe7b04..38779c2a4b 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -358,7 +358,7 @@ On the ABI level the Fixed-size bytes array is annotated as ``bytes``. Bytes literals may be given as bytes strings. -.. code-block:: python +.. code-block:: vyper bytes_string: Bytes[100] = b"\x01" @@ -372,7 +372,7 @@ Strings Fixed-size strings can hold strings with equal or fewer characters than the maximum length of the string. On the ABI level the Fixed-size bytes array is annotated as ``string``. -.. code-block:: python +.. code-block:: vyper example_str: String[100] = "Test String" @@ -384,7 +384,7 @@ Flags Flags are custom defined types. A flag must have at least one member, and can hold up to a maximum of 256 members. The members are represented by ``uint256`` values in the form of 2\ :sup:`n` where ``n`` is the index of the member in the range ``0 <= n <= 255``. -.. code-block:: python +.. code-block:: vyper # Defining a flag with two members flag Roles: @@ -430,7 +430,7 @@ Flag members can be combined using the above bitwise operators. While flag membe The ``in`` and ``not in`` operators can be used in conjunction with flag member combinations to check for membership. -.. code-block:: python +.. code-block:: vyper flag Roles: MANAGER @@ -491,7 +491,7 @@ Fixed-size lists hold a finite number of elements which belong to a specified ty Lists can be declared with ``_name: _ValueType[_Integer]``, except ``Bytes[N]``, ``String[N]`` and flags. -.. code-block:: python +.. code-block:: vyper # Defining a list exampleList: int128[3] @@ -507,7 +507,7 @@ Multidimensional lists are also possible. The notation for the declaration is re A two dimensional list can be declared with ``_name: _ValueType[inner_size][outer_size]``. Elements can be accessed with ``_name[outer_index][inner_index]``. -.. code-block:: python +.. code-block:: vyper # Defining a list with 2 rows and 5 columns and set all values to 0 exampleList2D: int128[5][2] = empty(int128[5][2]) @@ -531,7 +531,7 @@ Dynamic Arrays Dynamic arrays represent bounded arrays whose length can be modified at runtime, up to a bound specified in the type. They can be declared with ``_name: DynArray[_Type, _Integer]``, where ``_Type`` can be of value type or reference type (except mappings). -.. code-block:: python +.. code-block:: vyper # Defining a list exampleList: DynArray[int128, 3] @@ -558,7 +558,7 @@ Dynamic arrays represent bounded arrays whose length can be modified at runtime, .. note:: To keep code easy to reason about, modifying an array while using it as an iterator is disallowed by the language. For instance, the following usage is not allowed: - .. code-block:: python + .. code-block:: vyper for item in self.my_array: self.my_array[0] = item @@ -580,7 +580,7 @@ Struct types can be used inside mappings and arrays. Structs can contain arrays Struct members can be accessed via ``struct.argname``. -.. code-block:: python +.. code-block:: vyper # Defining a struct struct MyStruct: @@ -610,7 +610,7 @@ Mapping types are declared as ``HashMap[_KeyType, _ValueType]``. .. note:: Mappings are only allowed as state variables. -.. code-block:: python +.. code-block:: vyper # Defining a mapping exampleMapping: HashMap[int128, decimal] diff --git a/docs/vyper-by-example.rst b/docs/vyper-by-example.rst index b07842cd25..61b5e51c41 100644 --- a/docs/vyper-by-example.rst +++ b/docs/vyper-by-example.rst @@ -19,7 +19,7 @@ period ends, a predetermined beneficiary will receive the amount of the highest bid. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :linenos: As you can see, this example only has a constructor, two methods to call, and @@ -29,7 +29,7 @@ need for a basic implementation of an auction smart contract. Let's get started! .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-17 @@ -54,7 +54,7 @@ within the same contract. The ``public`` function additionally creates a Now, the constructor. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 22 :lines: 22-27 @@ -72,7 +72,7 @@ caller as we will soon see. With initial setup out of the way, lets look at how our users can make bids. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 33 :lines: 33-46 @@ -95,7 +95,7 @@ We will send back the previous ``highestBid`` to the previous ``highestBidder`` our new ``highestBid`` and ``highestBidder``. .. literalinclude:: ../examples/auctions/simple_open_auction.vy - :language: python + :language: vyper :lineno-start: 60 :lines: 60-85 @@ -141,13 +141,13 @@ Solidity, this blind auction allows for an auction where there is no time pressu .. _counterpart: https://solidity.readthedocs.io/en/v0.5.0/solidity-by-example.html#id2 .. literalinclude:: ../examples/auctions/blind_auction.vy - :language: python + :language: vyper :linenos: While this blind auction is almost functionally identical to the blind auction implemented in Solidity, the differences in their implementations help illustrate the differences between Solidity and Vyper. .. literalinclude:: ../examples/auctions/blind_auction.vy - :language: python + :language: vyper :lineno-start: 28 :lines: 28-30 @@ -184,14 +184,14 @@ we want to explore one way how an escrow system can be implemented trustlessly. Let's go! .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :linenos: This is also a moderately short contract, however a little more complex in logic. Let's break down this contract bit by bit. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 16 :lines: 16-19 @@ -200,7 +200,7 @@ their respective data types. Remember that the ``public`` function allows the variables to be *readable* by an external caller, but not *writeable*. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 22 :lines: 22-29 @@ -215,7 +215,7 @@ in the contract variable ``self.value`` and saves the contract creator into ``True``. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 31 :lines: 31-36 @@ -231,7 +231,7 @@ contract will call the ``selfdestruct()`` function and refunds the seller and subsequently destroys the contract. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 38 :lines: 38-45 @@ -244,7 +244,7 @@ contract has a balance equal to 4 times the item value and the seller must send the item to the buyer. .. literalinclude:: ../examples/safe_remote_purchase/safe_remote_purchase.vy - :language: python + :language: vyper :lineno-start: 47 :lines: 47-61 @@ -276,14 +276,14 @@ Participants will be refunded their respective contributions if the total funding does not reach its target goal. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :linenos: Most of this code should be relatively straightforward after going through our previous examples. Let's dive right in. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-13 @@ -304,7 +304,7 @@ once the crowdfunding period is over—as determined by the ``deadline`` and of all participants. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 9 :lines: 9-15 @@ -317,7 +317,7 @@ a definitive end time for the crowdfunding period. Now lets take a look at how a person can participate in the crowdfund. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 17 :lines: 17-23 @@ -331,7 +331,7 @@ mapping, ``self.nextFunderIndex`` increments appropriately to properly index each participant. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 25 :lines: 25-31 @@ -352,7 +352,7 @@ crowdfunding campaign isn't successful? We're going to need a way to refund all the participants. .. literalinclude:: ../examples/crowdfund.vy - :language: python + :language: vyper :lineno-start: 33 :lines: 33-42 @@ -374,14 +374,14 @@ determined upon calling the ``winningProposals()`` method, which iterates throug all the proposals and returns the one with the greatest number of votes. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :linenos: As we can see, this is the contract of moderate length which we will dissect section by section. Let’s begin! .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-25 @@ -402,7 +402,7 @@ their respective datatypes. Let’s move onto the constructor. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 53 :lines: 53-62 @@ -421,7 +421,7 @@ their respective index in the original array as its key. Now that the initial setup is done, lets take a look at the functionality. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 66 :lines: 66-75 @@ -437,7 +437,7 @@ voting power, we will set their ``weight`` to ``1`` and we will keep track of th total number of voters by incrementing ``voterCount``. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 120 :lines: 120-135 @@ -452,7 +452,7 @@ the delegate had already voted or increase the delegate’s vote ``weight`` if the delegate has not yet voted. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 139 :lines: 139-151 @@ -472,7 +472,7 @@ costs gas. By having the ``@view`` decorator, we let the EVM know that this is a read-only function and we benefit by saving gas fees. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 153 :lines: 153-170 @@ -484,7 +484,7 @@ respectively by looping through all the proposals. ``winningProposal()`` is an external function allowing access to ``_winningProposal()``. .. literalinclude:: ../examples/voting/ballot.vy - :language: python + :language: vyper :lineno-start: 175 :lines: 175-178 @@ -515,7 +515,7 @@ contract, holds all shares of the company at first but can sell them all. Let's get started. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :linenos: .. note:: Throughout this contract, we use a pattern where ``@external`` functions return data from ``@internal`` functions that have the same name prepended with an underscore. This is because Vyper does not allow calls between external functions within the same contract. The internal function handles the logic, while the external function acts as a getter to allow viewing. @@ -526,7 +526,7 @@ that the contract logs. We then declare our global variables, followed by function definitions. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 3 :lines: 3-27 @@ -537,7 +537,7 @@ represents the wei value of a share and ``holdings`` is a mapping that maps an address to the number of shares the address owns. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 29 :lines: 29-40 @@ -548,7 +548,7 @@ company's address is initialized to hold all shares of the company in the ``holdings`` mapping. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 42 :lines: 42-46 @@ -567,7 +567,7 @@ Now, lets take a look at a method that lets a person buy stock from the company's holding. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 51 :lines: 51-64 @@ -579,7 +579,7 @@ and transferred to the sender's in the ``holdings`` mapping. Now that people can buy shares, how do we check someone's holdings? .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 66 :lines: 66-71 @@ -588,7 +588,7 @@ and returns its corresponding stock holdings by keying into ``self.holdings``. Again, an external function ``getHolding()`` is included to allow access. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 72 :lines: 72-76 @@ -596,7 +596,7 @@ To check the ether balance of the company, we can simply call the getter method ``cash()``. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 78 :lines: 78-95 @@ -609,7 +609,7 @@ ether to complete the sale. If all conditions are met, the holdings are deducted from the seller and given to the company. The ethers are then sent to the seller. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 97 :lines: 97-110 @@ -620,7 +620,7 @@ than ``0`` and ``asserts`` whether the sender has enough stocks to send. If both conditions are satisfied, the transfer is made. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 112 :lines: 112-124 @@ -632,7 +632,7 @@ enough funds to pay the amount. If both conditions satisfy, the contract sends its ether to an address. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 126 :lines: 126-130 @@ -641,7 +641,7 @@ shares the company has sold and the price of each share. Internally, we get this value by calling the ``_debt()`` method. Externally it is accessed via ``debt()``. .. literalinclude:: ../examples/stock/company.vy - :language: python + :language: vyper :lineno-start: 132 :lines: 132-138 diff --git a/docs/vyper-logo-transparent.svg b/docs/vyper-logo-transparent.svg deleted file mode 100644 index 18bf3c25e2..0000000000 --- a/docs/vyper-logo-transparent.svg +++ /dev/null @@ -1,11 +0,0 @@ - diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 0e94b32b9d..a9d41cbf69 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -31,7 +31,7 @@ decimals: public(uint8) # NOTE: By declaring `balanceOf` as public, vyper automatically generates a 'balanceOf()' getter # method to allow access to account balances. # The _KeyType will become a required parameter for the getter and it will return _ValueType. -# See: https://vyper.readthedocs.io/en/v0.1.0-beta.8/types.html?highlight=getter#mappings +# See: https://docs.vyperlang.org/en/v0.1.0-beta.8/types.html?highlight=getter#mappings balanceOf: public(HashMap[address, uint256]) # By declaring `allowance` as public, vyper automatically generates the `allowance()` getter allowance: public(HashMap[address, HashMap[address, uint256]]) diff --git a/requirements-docs.txt b/requirements-docs.txt index 5906384fc7..5c19ca7cfd 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,3 @@ +shibuya==2024.1.17 sphinx==7.2.6 -recommonmark==0.7.1 -sphinx_rtd_theme==2.0.0 +sphinx-copybutton==0.5.2 diff --git a/tox.ini b/tox.ini index f9d4c3b60b..b42a13a0ab 100644 --- a/tox.ini +++ b/tox.ini @@ -19,9 +19,9 @@ whitelist_externals = make [testenv:docs] basepython=python3 deps = + shibuya sphinx - sphinx_rtd_theme - recommonmark + sphinx-copybutton commands = sphinx-build {posargs:-E} -b html docs dist/docs -n -q --color From a8c6ea284e85348d76be91e1ac53d92180fcf7b0 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 16:42:45 -0800 Subject: [PATCH 184/201] chore: improve some error messages (#3775) * fix error message for `implements: module` currently, the compiler will panic when it encounters this case. add a suggestion to rename the interface file to `.vyi`. also catch all invalid types with a compiler panic. * add a helpful hint for imports from `vyper.interfaces` hint to try `ethereum.ercs` --- vyper/semantics/analysis/module.py | 14 ++++++++++++-- vyper/semantics/types/utils.py | 7 +++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..9304eb3ded 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -342,7 +342,14 @@ def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): - raise StructureException("not an interface!", node.annotation) + msg = "Not an interface!" + hint = None + if isinstance(type_, ModuleT): + path = type_._module.path + msg += " (Since vyper v0.4.0, interface files are required" + msg += " to have a .vyi suffix.)" + hint = f"try renaming `{path}` to `{path}i`" + raise StructureException(msg, node.annotation, hint=hint) type_.validate_implements(node) @@ -627,6 +634,9 @@ def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alia def _load_import_helper( self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str ) -> Any: + if module_str.startswith("vyper.interfaces"): + hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" + raise ModuleNotFound(module_str, hint=hint) if _is_builtin(module_str): return _load_builtin_import(level, module_str) @@ -724,7 +734,7 @@ def _is_builtin(module_str): def _load_builtin_import(level: int, module_str: str) -> InterfaceT: if not _is_builtin(module_str): - raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None + raise ModuleNotFoundError(f"Not a builtin: {module_str}") builtins_path = vyper.builtins.interfaces.__path__[0] # hygiene: convert to relpath to avoid leaking user directory info diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index c6a4531df8..96c661021f 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -3,6 +3,7 @@ from vyper import ast as vy_ast from vyper.exceptions import ( ArrayIndexException, + CompilerPanic, InstantiationException, InvalidType, StructureException, @@ -158,6 +159,12 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: # call from_annotation to produce a better error message. typ_.from_annotation(node) + if hasattr(typ_, "module_t"): # it's a ModuleInfo + typ_ = typ_.module_t + + if not isinstance(typ_, VyperType): + raise CompilerPanic("Not a type: {typ_}", node) + return typ_ From a3bc3eb50ea10788a688ea79d74d294cd9a418d6 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 12 Feb 2024 17:07:21 -0800 Subject: [PATCH 185/201] feat: add python `sys.path` to vyper path (#3763) this makes it easier to install vyper packages from pip and import them using a regular python workflow. misc: - improve how paths appear in error messages; try hard to make them relative paths. - add `chdir_tmp_path` fixture which chdirs to the `tmp_path` fixture for the duration of the test. --- tests/conftest.py | 7 ++++ .../syntax/modules/test_initializers.py | 6 ++-- .../cli/vyper_compile/test_compile_files.py | 33 +++++++++++++++++++ vyper/cli/vyper_compile.py | 14 ++++++-- vyper/semantics/analysis/module.py | 14 +++++++- 5 files changed, 67 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e673f17b35..6eb34a3e0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from web3.contract import Contract from web3.providers.eth_tester import EthereumTesterProvider +from tests.utils import working_directory from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode @@ -79,6 +80,12 @@ def debug(pytestconfig): _set_debug_mode(debug) +@pytest.fixture +def chdir_tmp_path(tmp_path): + with working_directory(tmp_path): + yield + + @pytest.fixture def keccak(): return Web3.keccak diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index a12f5f57ea..d0523153c8 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -326,7 +326,7 @@ def foo(): assert e.value._hint == "did you mean `m := lib1`?" -def test_global_initializer_constraint(make_input_bundle): +def test_global_initializer_constraint(make_input_bundle, chdir_tmp_path): lib1 = """ counter: uint256 """ @@ -818,7 +818,7 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint -def test_invalid_uses(make_input_bundle): +def test_invalid_uses(make_input_bundle, chdir_tmp_path): lib1 = """ counter: uint256 """ @@ -848,7 +848,7 @@ def foo(): assert e.value._hint == "delete `uses: lib1`" -def test_invalid_uses2(make_input_bundle): +def test_invalid_uses2(make_input_bundle, chdir_tmp_path): # test a more complicated invalid uses lib1 = """ counter: uint256 diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a65d66835..6adee24db6 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -1,3 +1,5 @@ +import contextlib +import sys from pathlib import Path import pytest @@ -257,3 +259,34 @@ def foo() -> uint256: contract_file = make_file("contract.vy", contract_source) assert compile_files([contract_file], ["combined_json"], paths=[tmp_path]) is not None + + +@contextlib.contextmanager +def mock_sys_path(path): + try: + sys.path.append(path) + yield + finally: + sys.path.pop() + + +def test_import_sys_path(tmp_path_factory, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + tmpdir = tmp_path_factory.mktemp("test-sys-path") + with open(tmpdir / "lib.vy", "w") as f: + f.write(library_source) + + contract_file = make_file("contract.vy", contract_source) + with mock_sys_path(tmpdir): + assert compile_files([contract_file], ["combined_json"]) is not None diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index d6ba9e180a..ac69cf3310 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -238,10 +238,18 @@ def compile_files( storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, ) -> dict: - paths = paths or [] + # lowest precedence search path is always sys path + search_paths = [Path(p) for p in sys.path] + + # python sys path uses opposite resolution order from us + # (first in list is highest precedence; we give highest precedence + # to the last in the list) + search_paths.reverse() - # lowest precedence search path is always `.` - search_paths = [Path(".")] + if Path(".") not in search_paths: + search_paths.append(Path(".")) + + paths = paths or [] for p in paths: path = Path(p).resolve(strict=True) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 9304eb3ded..43b11497ec 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -705,10 +705,22 @@ def _load_import_helper( def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: + module_path = file.resolved_path # for error messages + try: + # try to get a relative path, to simplify the error message + cwd = Path(".") + if module_path.is_absolute(): + cwd = cwd.resolve() + module_path = module_path.relative_to(cwd) + except ValueError: + # we couldn't get a relative path (cf. docs for Path.relative_to), + # use the resolved path given to us by the InputBundle + pass + ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, - module_path=str(file.path), + module_path=str(module_path), resolved_path=str(file.resolved_path), ) return ret From 7bdebbf12798ccda4285653f630c1a6b1d4af5b8 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 06:14:24 -0800 Subject: [PATCH 186/201] fix: iterator modification analysis (#3764) this commit fixes several bugs with analysis of iterator modification in loops. to do so, it refactors the analysis code to track reads/writes more accurately, and uses analysis machinery instead of AST queries to perform the check. it enriches ExprInfo with an `attr` attribute, so this can be used to detect if an ExprInfo is derived from an `Attribute`. ExprInfo could be further enriched with `Subscript` info so that the Attribute/Subscript chain can be reliably recovered just from ExprInfos, especially in the future if other functions rely on being able to recover the attribute chain. this commit also modifies `validate_functions` so that it validates the functions in dependency (call graph traversal) order rather than the order they appear in the AST. refactors: - add `enter_for_loop()` context manager for convenience+clarity - remove `ExprInfo.attribute_chain`, it was too confusing - hide `ContractFunctionT` member variables (`_variable_reads`, `_variable_writes`, `_used_modules`) behind public-facing API - remove `get_root_varinfo()` in favor of a helper `_get_variable_access()` function which detects access on variable sub-members (e.g., structs). --- .../features/iteration/test_for_in_list.py | 56 ++- .../syntax/modules/test_initializers.py | 42 +++ .../unit/semantics/analysis/test_for_loop.py | 105 ++++++ vyper/ast/nodes.pyi | 8 +- vyper/codegen/expr.py | 58 ++-- vyper/semantics/analysis/base.py | 50 ++- vyper/semantics/analysis/local.py | 326 ++++++++++-------- vyper/semantics/analysis/module.py | 2 +- vyper/semantics/analysis/utils.py | 16 +- vyper/semantics/environment.py | 4 +- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/function.py | 42 ++- vyper/semantics/types/primitives.py | 10 + 13 files changed, 505 insertions(+), 216 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 36252701c4..e1bd8f313d 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -3,6 +3,7 @@ import pytest +from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, ImmutableViolation, @@ -841,6 +842,59 @@ def foo(): ] +# TODO: move these to tests/functional/syntax @pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) def test_bad_code(assert_compile_failed, get_contract, code, err): - assert_compile_failed(lambda: get_contract(code), err) + with pytest.raises(err): + compile_code(code) + + +def test_iterator_modification_module_attribute(make_input_bundle): + # test modifying iterator via attribute + lib1 = """ +queue: DynArray[uint256, 5] + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.queue.pop() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_module_function_call(make_input_bundle): + lib1 = """ +queue: DynArray[uint256, 5] + +@internal +def popqueue(): + self.queue.pop() + """ + main = """ +import lib1 + +initializes: lib1 + +@external +def foo(): + for i: uint256 in lib1.queue: + lib1.popqueue() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot modify loop variable `queue`" diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index d0523153c8..d0965ae61d 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -741,6 +741,48 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint +def test_missing_uses_subscript(make_input_bundle): + # test missing uses through nested subscript/attribute access + lib1 = """ +struct Foo: + array: uint256[5] + +foos: Foo[5] + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.foos[0].array[1] = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + def test_missing_uses_nested_attribute_function_call(make_input_bundle): # test missing uses through nested attribute access lib1 = """ diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 607587cc28..c97c9c095e 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -134,6 +134,111 @@ def baz(): validate_semantics(vyper_module, dummy_input_bundle) +def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): + # test the analysis works no matter the order of functions + code = """ +a: uint256[3] + +@internal +def baz(): + for i: uint256 in self.a: + self.bar() + +@internal +def bar(): + self.foo() + +@internal +def foo(): + self.a[0] = 1 + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_through_struct(dummy_input_bundle): + # GH issue 3429 + code = """ +struct A: + iter: DynArray[uint256, 5] + +a: A + +@external +def foo(): + self.a.iter = [1, 2, 3] + for i: uint256 in self.a.iter: + self.a = A({iter: [1, 2, 3, 4]}) + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `a`" + + +def test_modify_iterator_complex_expr(dummy_input_bundle): + # GH issue 3429 + # avoid false positive! + code = """ +a: DynArray[uint256, 5] +b: uint256[10] + +@external +def foo(): + self.a = [1, 2, 3] + for i: uint256 in self.a: + self.b[self.a[1]] = i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_iterator_siblings(dummy_input_bundle): + # test we can modify siblings in an access tree + code = """ +struct Foo: + a: uint256[2] + b: uint256 + +f: Foo + +@external +def foo(): + for i: uint256 in self.f.a: + self.f.b += i + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, dummy_input_bundle) + + +def test_modify_subscript_barrier(dummy_input_bundle): + # test that Subscript nodes are a barrier for analysis + code = """ +struct Foo: + x: uint256[2] + y: uint256 + +struct Bar: + f: Foo[2] + +b: Bar + +@external +def foo(): + for i: uint256 in self.b.f[1].x: + self.b.f[0].y += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ImmutableViolation) as e: + validate_semantics(vyper_module, dummy_input_bundle) + + assert e.value._message == "Cannot modify loop variable `b`" + + iterator_inference_codes = [ """ @external diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7f863a8db9..342c84876a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -200,13 +200,13 @@ class Call(ExprNode): class keyword(VyperNode): ... -class Attribute(VyperNode): +class Attribute(ExprNode): attr: str = ... value: ExprNode = ... -class Subscript(VyperNode): - slice: VyperNode = ... - value: VyperNode = ... +class Subscript(ExprNode): + slice: ExprNode = ... + value: ExprNode = ... class Assign(VyperNode): ... diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 335cfefb87..9c7f11dcb3 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -263,24 +263,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif (varinfo := self.expr._expr_info.var_info) is not None: - if varinfo.is_constant: - return Expr.parse_value_expr(varinfo.decl_node.value, self.context) - - location = data_location_to_address_space( - varinfo.location, self.context.is_ctor_context - ) - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -336,17 +318,37 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global attribute + if (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) + + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation="self." + self.expr.attr, + ) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + # contract type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 2086e5f9da..49b867aae5 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,5 +1,5 @@ import enum -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast @@ -193,6 +193,17 @@ def is_constant(self): return res +@dataclass(frozen=True) +class VarAccess: + variable: VarInfo + attrs: tuple[str, ...] + + def contains(self, other): + # VarAccess("v", ("a")) `contains` VarAccess("v", ("a", "b", "c")) + sub_attrs = other.attrs[: len(self.attrs)] + return self.variable == other.variable and sub_attrs == self.attrs + + @dataclass class ExprInfo: """ @@ -204,9 +215,7 @@ class ExprInfo: module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE - - # the chain of attribute parents for this expr - attribute_chain: list["ExprInfo"] = field(default_factory=list) + attr: Optional[str] = None def __post_init__(self): should_match = ("typ", "location", "modifiability") @@ -215,48 +224,35 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") - self._writes: OrderedSet[VarInfo] = OrderedSet() - self._reads: OrderedSet[VarInfo] = OrderedSet() - - # find exprinfo in the attribute chain which has a varinfo - # e.x. `x` will return varinfo for `x` - # `module.foo` will return varinfo for `module.foo` - # `self.my_struct.x.y` will return varinfo for `self.my_struct` - def get_root_varinfo(self) -> Optional[VarInfo]: - for expr_info in self.attribute_chain + [self]: - if expr_info.var_info is not None: - return expr_info.var_info - return None + self._writes: OrderedSet[VarAccess] = OrderedSet() + self._reads: OrderedSet[VarAccess] = OrderedSet() @classmethod - def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, **kwargs) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, - attribute_chain=attribute_chain or [], + **kwargs, ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + def from_moduleinfo(cls, module_info: ModuleInfo, **kwargs) -> "ExprInfo": modifiability = Modifiability.RUNTIME_CONSTANT if module_info.ownership >= ModuleOwnership.USES: modifiability = Modifiability.MODIFIABLE return cls( - module_info.module_t, - module_info=module_info, - modifiability=modifiability, - attribute_chain=attribute_chain or [], + module_info.module_t, module_info=module_info, modifiability=modifiability, **kwargs ) - def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": + def copy_with_type(self, typ: VyperType, **kwargs) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} - if attribute_chain is not None: - fields["attribute_chain"] = attribute_chain - return self.__class__(typ=typ, **fields) + for t in to_copy: + assert t not in kwargs + return self.__class__(typ=typ, **fields, **kwargs) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..39a1c59290 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,5 +1,6 @@ # CMC 2024-02-03 TODO: split me into function.py and expr.py +import contextlib from typing import Optional from vyper import ast as vy_ast @@ -19,7 +20,13 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo +from vyper.semantics.analysis.base import ( + Modifiability, + ModuleInfo, + ModuleOwnership, + VarAccess, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -58,18 +65,33 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() - namespace = get_namespace() + for node in vy_module.get_children(vy_ast.FunctionDef): - with namespace.enter_scope(): - try: - analyzer = FunctionAnalyzer(vy_module, node, namespace) - analyzer.analyze() - except VyperException as e: - err_list.append(e) + _validate_function_r(vy_module, node, err_list) err_list.raise_if_not_empty() +def _validate_function_r( + vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList +): + func_t = node._metadata["func_type"] + + for call_t in func_t.called_functions: + if isinstance(call_t, ContractFunctionT): + assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy + _validate_function_r(vy_module, call_t.ast_def, err_list) + + namespace = get_namespace() + + try: + with namespace.enter_scope(): + analyzer = FunctionAnalyzer(vy_module, node, namespace) + analyzer.analyze() + except VyperException as e: + err_list.append(e) + + # finds the terminus node for a list of nodes. # raises an exception if any nodes are unreachable def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: @@ -99,36 +121,6 @@ def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: return ret -def _check_iterator_modification( - target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode -) -> Optional[vy_ast.VyperNode]: - similar_nodes = [ - n - for n in search_node.get_descendants(type(target_node)) - if vy_ast.compare_nodes(target_node, n) - ] - - for node in similar_nodes: - # raise if the node is the target of an assignment statement - assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) - # note the use of get_descendants() blocks statements like - # self.my_array[i] = x - if assign_node and node in assign_node.target.get_descendants(include_self=True): - return node - - attr_node = node.get_ancestor(vy_ast.Attribute) - # note the use of get_descendants() blocks statements like - # self.my_array[i].append(x) - if ( - attr_node is not None - and node in attr_node.value.get_descendants(include_self=True) - and attr_node.attr in ("append", "pop", "extend") - ): - return node - - return None - - # helpers def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": @@ -183,6 +175,62 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) +# analyse the variable access for the attribute chain for a node +# e.x. `x` will return varinfo for `x` +# `module.foo` will return VarAccess for `module.foo` +# `self.my_struct.x.y` will return VarAccess for `self.my_struct.x.y` +def _get_variable_access(node: vy_ast.ExprNode) -> Optional[VarAccess]: + attrs: list[str] = [] + info = get_expr_info(node) + + while info.var_info is None: + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + # it's something like a literal + return None + + if isinstance(node, vy_ast.Subscript): + # Subscript is an analysis barrier + # we cannot analyse if `x.y[ix1].z` overlaps with `x.y[ix2].z`. + attrs.clear() + + if (attr := info.attr) is not None: + attrs.append(attr) + + assert isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)) # help mypy + node = node.value + info = get_expr_info(node) + + # ignore `self.` as it interferes with VarAccess comparison across modules + if len(attrs) > 0 and attrs[-1] == "self": + attrs.pop() + attrs.reverse() + + return VarAccess(info.var_info, tuple(attrs)) + + +# get the chain of modules, e.g. +# mod1.mod2.x.y -> [ModuleInfo(mod1), ModuleInfo(mod2)] +# CMC 2024-02-12 note that the Attribute/Subscript traversal in this and +# _get_variable_access() are a bit gross and could probably +# be refactored into data on ExprInfo. +def _get_module_chain(node: vy_ast.ExprNode) -> list[ModuleInfo]: + ret: list[ModuleInfo] = [] + info = get_expr_info(node) + + while True: + if info.module_info is not None: + ret.append(info.module_info) + + if not isinstance(node, (vy_ast.Subscript, vy_ast.Attribute)): + break + + node = node.value + info = get_expr_info(node) + + ret.reverse() + return ret + + class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -196,7 +244,16 @@ def __init__( self.func = fn_node._metadata["func_type"] self.expr_visitor = ExprVisitor(self) + self.loop_variables: list[Optional[VarAccess]] = [] + def analyze(self): + if self.func.analysed: + return + + # mark seen before analysing, if analysis throws an exception which + # gets caught, we don't want to analyse again. + self.func.mark_analysed() + # allow internal function params to be mutable if self.func.is_internal: location, modifiability = (DataLocation.MEMORY, Modifiability.MODIFIABLE) @@ -225,6 +282,14 @@ def analyze(self): for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) + @contextlib.contextmanager + def enter_for_loop(self, varaccess: Optional[VarAccess]): + self.loop_variables.append(varaccess) + try: + yield + finally: + self.loop_variables.pop() + def visit(self, node): super().visit(node) @@ -326,16 +391,13 @@ def _handle_modification(self, target: vy_ast.ExprNode): if info.modifiability == Modifiability.CONSTANT: raise ImmutableViolation("Constant value cannot be written to.") - var_info = info.get_root_varinfo() - assert var_info is not None + var_access = _get_variable_access(target) + assert var_access is not None - info._writes.add(var_info) + info._writes.add(var_access) def _check_module_use(self, target: vy_ast.ExprNode): - module_infos = [] - for t in get_expr_info(target).attribute_chain: - if t.module_info is not None: - module_infos.append(t.module_info) + module_infos = _get_module_chain(target) if len(module_infos) == 0: return @@ -352,7 +414,7 @@ def _check_module_use(self, target: vy_ast.ExprNode): root_module_info = module_infos[0] # log the access - self.func._used_modules.add(root_module_info) + self.func.mark_used_module(root_module_info) def visit_Assign(self, node): self._assign_helper(node) @@ -403,96 +465,68 @@ def visit_Expr(self, node): ) self.expr_visitor.visit(node.value, return_value) + def _analyse_range_iter(self, iter_node, target_type): + # iteration via range() + if iter_node.get("func.id") != "range": + raise IteratorException("Cannot iterate over the result of a function call", iter_node) + _validate_range_call(iter_node) + + args = iter_node.args + kwargs = [s.value for s in iter_node.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + + def _analyse_list_iter(self, iter_node, target_type): + # iteration over a variable or literal list + iter_val = iter_node + if iter_val.has_folded_value: + iter_val = iter_val.get_folded_value() + + if isinstance(iter_val, vy_ast.List): + len_ = len(iter_val.elements) + if len_ == 0: + raise StructureException("For loop must have at least 1 iteration", iter_node) + iter_type = SArrayT(target_type, len_) + else: + try: + iter_type = get_exact_type_from_node(iter_node) + except (InvalidType, StructureException): + raise InvalidType("Not an iterable type", iter_node) + + # CMC 2024-02-09 TODO: use validate_expected_type once we have DArrays + # with generic length. + if not isinstance(iter_type, (DArrayT, SArrayT)): + raise InvalidType("Not an iterable type", iter_node) + + self.expr_visitor.visit(iter_node, iter_type) + + # get the root varinfo from iter_val in case we need to peer + # through folded constants + return _get_variable_access(iter_val) + def visit_For(self, node): if not isinstance(node.target.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target.target) target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) + iter_var = None if isinstance(node.iter, vy_ast.Call): - # iteration via range() - if node.iter.get("func.id") != "range": - raise IteratorException( - "Cannot iterate over the result of a function call", node.iter - ) - _validate_range_call(node.iter) - + self._analyse_range_iter(node.iter, target_type) else: - # iteration over a variable or literal list - iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter - if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: - raise StructureException("For loop must have at least 1 iteration", node.iter) - - if not any( - isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) - ): - raise InvalidType("Not an iterable type", node.iter) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # check for references to the iterated value within the body of the loop - assign = _check_iterator_modification(node.iter, node) - if assign: - raise ImmutableViolation("Cannot modify array during iteration", assign) - - # Check if `iter` is a storage variable. get_descendants` is used to check for - # nested `self` (e.g. structs) - # NOTE: this analysis will be borked once stateful modules are allowed! - iter_is_storage_var = ( - isinstance(node.iter, vy_ast.Attribute) - and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 - ) - - if iter_is_storage_var: - # check if iterated value may be modified by function calls inside the loop - iter_name = node.iter.attr - for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}): - fn_name = call_node.func.attr - - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0] - if _check_iterator_modification(node.iter, fn_node): - # check for direct modification - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it potentially " - f"modifies iterated storage variable '{iter_name}'", - call_node, - ) + iter_var = self._analyse_list_iter(node.iter, target_type) - for reachable_t in ( - self.namespace["self"].typ.members[fn_name].reachable_internal_functions - ): - # check for indirect modification - name = reachable_t.name - fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] - if _check_iterator_modification(node.iter, fn_node): - raise ImmutableViolation( - f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' " - f"which potentially modifies iterated storage variable '{iter_name}'", - call_node, - ) - - target_name = node.target.target.id - with self.namespace.enter_scope(): + with self.namespace.enter_scope(), self.enter_for_loop(iter_var): + target_name = node.target.target.id + # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( target_type, modifiability=Modifiability.RUNTIME_CONSTANT ) + self.expr_visitor.visit(node.target.target, target_type) for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target.target, target_type) - - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) - elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - args = node.iter.args - kwargs = [s.value for s in node.iter.keywords] - for arg in (*args, *kwargs): - self.expr_visitor.visit(arg, target_type) - else: - iter_type = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, iter_type) - def visit_If(self, node): self.expr_visitor.visit(node.test, BoolT()) with self.namespace.enter_scope(): @@ -577,18 +611,32 @@ def visit(self, node, typ): # log variable accesses. # (note writes will get logged as both read+write) - varinfo = info.var_info - if varinfo is not None: - info._reads.add(varinfo) + var_access = _get_variable_access(node) + if var_access is not None: + info._reads.add(var_access) + + if self.function_analyzer: + for s in self.function_analyzer.loop_variables: + if s is None: + continue + + for v in info._writes: + if not v.contains(s): + continue + + msg = "Cannot modify loop variable" + var = s.variable + if var.decl_node is not None: + msg += f" `{var.decl_node.target.id}`" + raise ImmutableViolation(msg, var.decl_node, node) - if self.func: variable_accesses = info._writes | info._reads for s in variable_accesses: - if s.is_module_variable(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node) - self.func._variable_writes.update(info._writes) - self.func._variable_reads.update(info._reads) + self.func.mark_variable_writes(info._writes) + self.func.mark_variable_reads(info._reads) # validate and annotate folded value if node.has_folded_value: @@ -641,24 +689,23 @@ def _check_call_mutability(self, call_mutability: StateMutability): def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: func_info = get_expr_info(node.func, is_callable=True) func_type = func_info.typ - self.visit(node.func, func_type) if isinstance(func_type, ContractFunctionT): # function calls - func_info._writes.update(func_type._variable_writes) - func_info._reads.update(func_type._variable_reads) + if not func_type.from_interface: + for s in func_type.get_variable_writes(): + if s.variable.is_module_variable(): + func_info._writes.add(s) + for s in func_type.get_variable_reads(): + if s.variable.is_module_variable(): + func_info._reads.add(s) if self.function_analyzer: - if func_type.is_internal: - self.func.called_functions.add(func_type) - self._check_call_mutability(func_type.mutability) - # check that if the function accesses state, the defining - # module has been `used` or `initialized`. - for s in func_type._variable_accesses: - if s.is_module_variable(): + for s in func_type.get_variable_accesses(): + if s.variable.is_module_variable(): self.function_analyzer._check_module_use(node.func) if func_type.is_deploy and not self.func.is_deploy: @@ -689,7 +736,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: elif isinstance(func_type, MemberFunctionT): if func_type.is_modifying and self.function_analyzer is not None: # TODO refactor this - self.function_analyzer._handle_modification(node.func) + assert isinstance(node.func, vy_ast.Attribute) # help mypy + self.function_analyzer._handle_modification(node.func.value) assert len(node.args) == len(func_type.arg_types) for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) @@ -702,6 +750,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) + self.visit(node.func, func_type) + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # membership in list literal - `x in [a, b, c]` diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 43b11497ec..10acef59da 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -244,7 +244,7 @@ def validate_used_modules(self): all_used_modules = OrderedSet() for f in module_t.functions.values(): - for u in f._used_modules: + for u in f.get_used_modules(): all_used_modules.add(u.module_t) for used_module in all_used_modules: diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..034cd8c46e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -84,28 +84,24 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex # propagate the parent exprinfo members down into the new expr # note: Attribute(expr value, identifier attr) - name = node.attr info = self.get_expr_info(node.value, is_callable=is_callable) + attr = node.attr - attribute_chain = info.attribute_chain + [info] - - t = info.typ.get_member(name, node) + t = info.typ.get_member(attr, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_varinfo(t, attr=attr) if isinstance(t, ModuleInfo): - return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) + return ExprInfo.from_moduleinfo(t, attr=attr) - # it's something else, like my_struct.foo - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t, attr=attr) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - attribute_chain = info.attribute_chain + [info] - return info.copy_with_type(t, attribute_chain=attribute_chain) + return info.copy_with_type(t) return ExprInfo(t) diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index 38bac0a63d..94a26157af 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -1,7 +1,7 @@ from typing import Dict from vyper.semantics.analysis.base import Modifiability, VarInfo -from vyper.semantics.types import AddressT, BytesT, VyperType +from vyper.semantics.types import AddressT, BytesT, SelfT, VyperType from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T @@ -57,7 +57,7 @@ def get_constant_vars() -> Dict: return result -MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": AddressT} +MUTABLE_ENVIRONMENT_VARS: Dict[str, type] = {"self": SelfT} def get_mutable_vars() -> Dict: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index a04632b96f..59a20dd99f 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -3,7 +3,7 @@ from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT from .module import InterfaceT -from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT +from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..705470a798 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -21,7 +21,7 @@ Modifiability, ModuleInfo, StateMutability, - VarInfo, + VarAccess, VarOffset, ) from vyper.semantics.analysis.utils import ( @@ -92,6 +92,7 @@ def __init__( return_type: Optional[VyperType], function_visibility: FunctionVisibility, state_mutability: StateMutability, + from_interface: bool = False, nonreentrant: Optional[str] = None, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: @@ -104,9 +105,12 @@ def __init__( self.visibility = function_visibility self.mutability = state_mutability self.nonreentrant = nonreentrant + self.from_interface = from_interface self.ast_def = ast_def + self._analysed = False + # a list of internal functions this function calls. # to be populated during analysis self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() @@ -115,10 +119,10 @@ def __init__( self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function - self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + self._variable_writes: OrderedSet[VarAccess] = OrderedSet() # reads of variables from this function - self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + self._variable_reads: OrderedSet[VarAccess] = OrderedSet() # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() @@ -127,10 +131,35 @@ def __init__( self._ir_info: Any = None self._function_id: Optional[int] = None + def mark_analysed(self): + assert not self._analysed + self._analysed = True + @property - def _variable_accesses(self): + def analysed(self): + return self._analysed + + def get_variable_reads(self): + return self._variable_reads + + def get_variable_writes(self): + return self._variable_writes + + def get_variable_accesses(self): return self._variable_reads | self._variable_writes + def get_used_modules(self): + return self._used_modules + + def mark_used_module(self, module_info): + self._used_modules.add(module_info) + + def mark_variable_writes(self, var_infos): + self._variable_writes.update(var_infos) + + def mark_variable_reads(self, var_infos): + self._variable_reads.update(var_infos) + @property def modifiability(self): return Modifiability.from_state_mutability(self.mutability) @@ -189,6 +218,7 @@ def from_abi(cls, abi: dict) -> "ContractFunctionT": positional_args, [], return_type, + from_interface=True, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.from_abi(abi), ) @@ -248,6 +278,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=None, ast_def=funcdef, ) @@ -300,6 +331,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=True, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -370,6 +402,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": return_type, function_visibility, state_mutability, + from_interface=False, nonreentrant=nonreentrant_key, ast_def=funcdef, ) @@ -410,6 +443,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio args, [], return_type, + from_interface=False, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, ast_def=node, diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 07d1a21a94..d383f72ab2 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -340,3 +340,13 @@ def validate_literal(self, node: vy_ast.Constant) -> None: f"address, the correct checksummed form is: {checksum_encode(addr)}", node, ) + + +# type for "self" +# refactoring note: it might be best for this to be a ModuleT actually +class SelfT(AddressT): + _id = "self" + + def compare_type(self, other): + # compares true to AddressT + return isinstance(other, type(self)) or isinstance(self, type(other)) From 199f2b65e43e3d3f055756039ef4a9bce7f6f3cf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 06:56:41 -0800 Subject: [PATCH 187/201] feat[lang]: remove named reentrancy locks (#3769) this commit removes "fine-grained" nonreentrancy locks (i.e., reentrancy locks with names) from vyper. they aren't really used (all known production contracts just use a single global named lock) , and in any case such a use case should better be implemented manually by the user. this simplifies the language and allows moderate simplification to the storage allocator, although some complexity is added because the global restriction has to have special handling (it cannot be handled simply in the recursion into child modules). refactors: - the routine for allocating nonreentrant keys has been refactored into a helper function. --- docs/control-structures.rst | 12 +- .../features/decorators/test_nonreentrant.py | 139 ++++++++++++++---- .../exceptions/test_structure_exception.py | 31 ---- .../test_invalid_function_decorators.py | 15 +- .../cli/storage_layout/test_storage_layout.py | 75 ++++++---- .../test_storage_layout_overrides.py | 34 ++++- tests/unit/semantics/test_storage_slots.py | 11 +- vyper/semantics/analysis/data_positions.py | 83 ++++++----- vyper/semantics/types/function.py | 70 ++++----- 9 files changed, 291 insertions(+), 179 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index a0aa927261..4e18a21bd8 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -100,22 +100,24 @@ Functions marked with ``@pure`` cannot call non-``pure`` functions. Re-entrancy Locks ----------------- -The ``@nonreentrant()`` decorator places a lock on a function, and all functions with the same ```` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert. +The ``@nonreentrant`` decorator places a global nonreentrancy lock on a function. An attempt by an external contract to call back into any other ``@nonreentrant`` function causes the transaction to revert. .. code-block:: vyper @external - @nonreentrant("lock") + @nonreentrant def make_a_call(_addr: address): # this function is protected from re-entrancy ... -You can put the ``@nonreentrant()`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. +You can put the ``@nonreentrant`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. Nonreentrancy locks work by setting a specially allocated storage slot to a ```` value on function entrance, and setting it to an ```` value on function exit. On function entrance, if the storage slot is detected to be the ```` value, execution reverts. You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can put it on a ``view`` function, but it only checks that the function is not in a callback (the storage slot is not in the ```` state), as ``view`` functions can only read the state, not change it. +You can view where the nonreentrant key is physically laid out in storage by using ``vyper`` with the ``-f layout`` option (e.g., ``vyper -f layout foo.vy``). Unless it is overriden, the compiler will allocate it at slot ``0``. + .. note:: A mutable function can protect a ``view`` function from being called back into (which is useful for instance, if a ``view`` function would return inconsistent state during a mutable function), but a ``view`` function cannot protect itself from being called back into. Note that mutable functions can never be called from a ``view`` function because all external calls out from a ``view`` function are protected by the use of the ``STATICCALL`` opcode. @@ -123,6 +125,8 @@ You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can p A nonreentrant lock has an ```` value of 3, and a ```` value of 2. Nonzero values are used to take advantage of net gas metering - as of the Berlin hard fork, the net cost for utilizing a nonreentrant lock is 2300 gas. Prior to v0.3.4, the ```` and ```` values were 0 and 1, respectively. +.. note:: + Prior to 0.4.0, nonreentrancy keys took a "key" argument for fine-grained nonreentrancy control. As of 0.4.0, only a global nonreentrancy lock is available. The ``__default__`` Function ---------------------------- @@ -194,7 +198,7 @@ Decorator Description ``@pure`` Function does not read contract state or environment variables ``@view`` Function does not alter contract state ``@payable`` Function is able to receive Ether -``@nonreentrant()`` Function cannot be called back into during an external call +``@nonreentrant`` Function cannot be called back into during an external call =============================== =========================================================== ``if`` statements diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 9329605678..92a21cd302 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -2,30 +2,103 @@ from vyper.exceptions import FunctionDeclarationException - # TODO test functions in this module across all evm versions # once we have cancun support. + + def test_nonreentrant_decorator(get_contract, tx_failed): - calling_contract_code = """ -interface SpecialContract: + malicious_code = """ +interface ProtectedContract: + def protected_function(callback_address: address): nonpayable + +@external +def do_callback(): + ProtectedContract(msg.sender).protected_function(self) + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_nonreentrant_view_function(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: + def protected_function(): nonpayable + def protected_view_fn() -> uint256: view + +@external +def do_callback() -> uint256: + return ProtectedContract(msg.sender).protected_view_fn() + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +@external +@nonreentrant +@view +def protected_view_fn() -> uint256: + return 10 + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_multi_function_nonreentrant(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable def protected_function(val: String[100], do_callback: bool): nonpayable def special_value() -> String[100]: nonpayable @external def updated(): - SpecialContract(msg.sender).unprotected_function('surprise!', False) + ProtectedContract(msg.sender).unprotected_function('surprise!', False) @external def updated_protected(): # This should fail. - SpecialContract(msg.sender).protected_function('surprise protected!', False) + ProtectedContract(msg.sender).protected_function('surprise protected!', False) """ - reentrant_code = """ + protected_code = """ interface Callback: def updated(): nonpayable def updated_protected(): nonpayable + interface Self: def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable @@ -39,7 +112,7 @@ def set_callback(c: address): self.callback = Callback(c) @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val @@ -50,7 +123,7 @@ def protected_function(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function2(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -60,7 +133,7 @@ def protected_function2(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function3(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -71,7 +144,8 @@ def protected_function3(val: String[100], do_callback: bool) -> uint256: @external -@nonreentrant('protect_special_value') +@nonreentrant +@view def protected_view_fn() -> String[100]: return self.special_value @@ -81,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool): if do_callback: self.callback.updated() - """ - reentrant_contract = get_contract(reentrant_code) - calling_contract = get_contract(calling_contract_code) +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) - reentrant_contract.set_callback(calling_contract.address, transact={}) - assert reentrant_contract.callback() == calling_contract.address + contract.set_callback(malicious.address, transact={}) + assert contract.callback() == malicious.address # Test unprotected function. - reentrant_contract.unprotected_function("some value", True, transact={}) - assert reentrant_contract.special_value() == "surprise!" + contract.unprotected_function("some value", True, transact={}) + assert contract.special_value() == "surprise!" # Test protected function. - reentrant_contract.protected_function("some value", False, transact={}) - assert reentrant_contract.special_value() == "some value" - assert reentrant_contract.protected_view_fn() == "some value" + contract.protected_function("some value", False, transact={}) + assert contract.special_value() == "some value" + assert contract.protected_view_fn() == "some value" with tx_failed(): - reentrant_contract.protected_function("zzz value", True, transact={}) + contract.protected_function("zzz value", True, transact={}) - reentrant_contract.protected_function2("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function2("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function2("zzz value", True, transact={}) + contract.protected_function2("zzz value", True, transact={}) - reentrant_contract.protected_function3("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function3("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function3("zzz value", True, transact={}) + contract.protected_function3("zzz value", True, transact={}) def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed): @@ -145,7 +224,7 @@ def set_callback(c: address): @external @payable -@nonreentrant("lock") +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val _amount: uint256 = msg.value @@ -169,7 +248,7 @@ def unprotected_function(val: String[100], do_callback: bool): @external @payable -@nonreentrant("lock") +@nonreentrant def __default__(): pass """ @@ -209,7 +288,7 @@ def test_disallow_on_init_function(get_contract): code = """ @external -@nonreentrant("lock") +@nonreentrant def __init__(): foo: uint256 = 0 """ diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index afc7a35012..e530487fea 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -44,42 +44,11 @@ def foo() -> int128: return x.codesize() """, """ -@external -@nonreentrant("B") -@nonreentrant("C") -def double_nonreentrant(): - pass - """, - """ struct X: int128[5]: int128[7] """, """ @external -@nonreentrant(" ") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("123") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("!123abcd") -def invalid_nonreentrant_key(): - pass - """, - """ -@external def foo(): true: int128 = 3 """, diff --git a/tests/functional/syntax/signatures/test_invalid_function_decorators.py b/tests/functional/syntax/signatures/test_invalid_function_decorators.py index b3d4219a2d..a7a500efc7 100644 --- a/tests/functional/syntax/signatures/test_invalid_function_decorators.py +++ b/tests/functional/syntax/signatures/test_invalid_function_decorators.py @@ -7,10 +7,23 @@ """ @external @pure -@nonreentrant('lock') +@nonreentrant def nonreentrant_foo() -> uint256: return 1 + """, """ +@external +@nonreentrant +@nonreentrant +def nonreentrant_foo() -> uint256: + return 1 + """, + """ +@external +@nonreentrant("foo") +def nonreentrant_foo() -> uint256: + return 1 + """, ] diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index f0ee25f747..9724dd723c 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -6,18 +6,18 @@ def test_storage_layout(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -28,12 +28,12 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ @@ -41,12 +41,11 @@ def public_foo3(): out = compile_code(code, output_formats=["layout"]) assert out["layout"]["storage_layout"] == { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 0}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 1}, - "foo": {"type": "HashMap[address, uint256]", "slot": 2}, - "arr": {"type": "DynArray[uint256, 3]", "slot": 3}, - "baz": {"type": "Bytes[65]", "slot": 7}, - "bar": {"type": "uint256", "slot": 11}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "foo": {"slot": 1, "type": "HashMap[address, uint256]"}, + "arr": {"slot": 2, "type": "DynArray[uint256, 3]"}, + "baz": {"slot": 6, "type": "Bytes[65]"}, + "bar": {"slot": 10, "type": "uint256"}, } @@ -64,10 +63,13 @@ def __init__(): expected_layout = { "code_layout": { - "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 0, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, + }, + "storage_layout": { + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "name": {"slot": 1, "type": "String[32]"}, }, - "storage_layout": {"name": {"slot": 0, "type": "String[32]"}}, } out = compile_code(code, output_formats=["layout"]) @@ -107,14 +109,15 @@ def __init__(): "code_layout": { "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, "a_library": { - "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "counter2": {"slot": 1, "type": "uint256"}, - "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, }, } @@ -160,9 +163,10 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "a_library": {"supply": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, }, } @@ -171,7 +175,8 @@ def __init__(): def test_storage_layout_module_uses(make_input_bundle): - # test module storage layout, with initializes/uses + # test module storage layout, with initializes/uses and a nonreentrant + # lock lib1 = """ supply: uint256 SYMBOL: immutable(String[32]) @@ -197,6 +202,11 @@ def __init__(s: uint256): @internal def decimals() -> uint8: return lib1.DECIMALS + +@external +@nonreentrant +def foo(): + pass """ code = """ import lib1 as a_library @@ -218,6 +228,11 @@ def __init__(): some_immutable = [1, 2, 3] lib2.__init__(17) + +@external +@nonreentrant +def bar(): + pass """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) @@ -231,10 +246,11 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, - "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, + "a_library": {"supply": {"slot": 4, "type": "uint256"}}, }, } @@ -309,12 +325,13 @@ def foo() -> uint256: }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, "lib2": { - "lib1": {"supply": {"slot": 1, "type": "uint256"}}, - "storage_variable": {"slot": 2, "type": "uint256"}, + "lib1": {"supply": {"slot": 2, "type": "uint256"}}, + "storage_variable": {"slot": 3, "type": "uint256"}, }, - "counter2": {"slot": 3, "type": "uint256"}, + "counter2": {"slot": 4, "type": "uint256"}, }, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f4c11b7ae6..707c94c3fc 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -1,3 +1,5 @@ +import re + import pytest from vyper.compiler import compile_code @@ -28,18 +30,18 @@ def test_storage_layout_for_more_complex(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -48,19 +50,18 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ storage_layout_override = { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 8}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 7}, + "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8}, "foo": {"type": "HashMap[address, uint256]", "slot": 1}, "baz": {"type": "Bytes[65]", "slot": 2}, "bar": {"type": "uint256", "slot": 6}, @@ -110,6 +111,25 @@ def test_overflow(): ) +def test_override_nonreentrant_slot(): + code = """ +@nonreentrant +@external +def foo(): + pass + """ + + storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}} + + exception_regex = re.escape( + f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" + ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + def test_incomplete_overrides(): code = """ name: public(String[64]) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index 3620ef64b9..1dc70fd1ba 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -47,15 +47,9 @@ def __init__(): self.foo[1] = [123, 456, 789] @external -@nonreentrant('lock') +@nonreentrant def with_lock(): pass - - -@external -@nonreentrant('otherlock') -def with_other_lock(): - pass """ @@ -84,7 +78,6 @@ def test_reentrancy_lock(get_contract): # if re-entrancy locks are incorrectly placed within storage, these # calls will either revert or correupt the data that we read later c.with_lock() - c.with_other_lock() assert c.a() == ("ok", [4, 5, 6]) assert [c.b(i) for i in range(2)] == [7, 8] @@ -105,7 +98,7 @@ def test_reentrancy_lock(get_contract): def test_allocator_overflow(get_contract): code = """ -x: uint256 +# --> global nonreentrancy slot allocated here <-- y: uint256[max_value(uint256)] """ with pytest.raises( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 604bc6b594..bb4322c7b2 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -43,10 +43,15 @@ def __setitem__(self, k, v): super().__setitem__(k, v) +# some name that the user cannot assign to a variable +GLOBAL_NONREENTRANT_KEY = "$.nonreentrant_key" + + class SimpleAllocator: def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): # Allocate storage slots from 0 # note storage is word-addressable, not byte-addressable + self._starting_slot = starting_slot self._slot = starting_slot self._max_slot = max_slot @@ -61,12 +66,19 @@ def allocate_slot(self, n, var_name, node=None): self._slot += n return ret + def allocate_global_nonreentrancy_slot(self): + slot = self.allocate_slot(1, GLOBAL_NONREENTRANT_KEY) + assert slot == self._starting_slot + return slot + class Allocators: storage_allocator: SimpleAllocator transient_storage_allocator: SimpleAllocator immutables_allocator: SimpleAllocator + _global_nonreentrancy_key_slot: int + def __init__(self): self.storage_allocator = SimpleAllocator(max_slot=2**256) self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) @@ -82,6 +94,16 @@ def get_allocator(self, location: DataLocation): raise CompilerPanic("unreachable") # pragma: nocover + def allocate_global_nonreentrancy_slot(self): + location = get_reentrancy_key_location() + + allocator = self.get_allocator(location) + slot = allocator.allocate_global_nonreentrancy_slot() + self._global_nonreentrancy_key_slot = slot + + def get_global_nonreentrant_key_slot(self): + return self._global_nonreentrancy_key_slot + class OverridingStorageAllocator: """ @@ -127,7 +149,6 @@ def set_storage_slots_with_overrides( Returns the layout as a dict of variable name -> variable info (Doesn't handle modules, or transient storage) """ - ret: InsertableOnceDict[str, dict] = InsertableOnceDict() reserved_slots = OverridingStorageAllocator() @@ -136,15 +157,13 @@ def set_storage_slots_with_overrides( type_ = node._metadata["func_type"] # Ignore functions without non-reentrant - if type_.nonreentrant is None: + if not type_.nonreentrant: continue - variable_name = f"nonreentrant.{type_.nonreentrant}" + variable_name = GLOBAL_NONREENTRANT_KEY # re-entrant key was already identified if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -210,6 +229,20 @@ def get_reentrancy_key_location() -> DataLocation: } +def _allocate_nonreentrant_keys(vyper_module, allocators): + SLOT = allocators.get_global_nonreentrant_key_slot() + + for node in vyper_module.get_children(vy_ast.FunctionDef): + type_ = node._metadata["func_type"] + if not type_.nonreentrant: + continue + + # a nonreentrant key can appear many times in a module but it + # only takes one slot. after the first time we see it, do not + # increment the storage slot. + type_.set_reentrancy_key_position(VarOffset(SLOT)) + + def _allocate_layout_r( vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False ) -> StorageLayout: @@ -217,42 +250,26 @@ def _allocate_layout_r( Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ + global_ = False if allocators is None: + global_ = True allocators = Allocators() + # always allocate nonreentrancy slot, so that adding or removing + # reentrancy protection from a contract does not change its layout + allocators.allocate_global_nonreentrancy_slot() ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) - for node in vyper_module.get_children(vy_ast.FunctionDef): - if immutables_only: - break - - type_ = node._metadata["func_type"] - if type_.nonreentrant is None: - continue - - variable_name = f"nonreentrant.{type_.nonreentrant}" - reentrancy_key_location = get_reentrancy_key_location() - layout_key = _LAYOUT_KEYS[reentrancy_key_location] - - # a nonreentrant key can appear many times in a module but it - # only takes one slot. after the first time we see it, do not - # increment the storage slot. - if variable_name in ret[layout_key]: - _slot = ret[layout_key][variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) - continue - - # TODO use one byte - or bit - per reentrancy key - # requires either an extra SLOAD or caching the value of the - # location in memory at entrance - allocator = allocators.get_allocator(reentrancy_key_location) - slot = allocator.allocate_slot(1, variable_name, node) - - type_.set_reentrancy_key_position(VarOffset(slot)) + # tag functions with the global nonreentrant key + if not immutables_only: + _allocate_nonreentrant_keys(vyper_module, allocators) + layout_key = _LAYOUT_KEYS[get_reentrancy_key_location()] # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + if global_ and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: + slot = allocators.get_global_nonreentrant_key_slot() + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = {"type": "nonreentrant lock", "slot": slot} for node in _get_allocatable(vyper_module): if isinstance(node, vy_ast.InitializesDecl): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 705470a798..43d553288e 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional, Tuple from vyper import ast as vy_ast -from vyper.ast.identifiers import validate_identifier from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ArgumentException, @@ -78,8 +77,8 @@ class ContractFunctionT(VyperType): enum indicating the external visibility of a function. state_mutability : StateMutability enum indicating the authority a function has to mutate it's own state. - nonreentrant : Optional[str] - Re-entrancy lock name. + nonreentrant : bool + Whether this function is marked `@nonreentrant` or not """ _is_callable = True @@ -93,7 +92,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, from_interface: bool = False, - nonreentrant: Optional[str] = None, + nonreentrant: bool = False, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -107,6 +106,9 @@ def __init__( self.nonreentrant = nonreentrant self.from_interface = from_interface + # sanity check, nonreentrant used to be Optional[str] + assert isinstance(self.nonreentrant, bool) + self.ast_def = ast_def self._analysed = False @@ -279,7 +281,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=None, + nonreentrant=False, ast_def=funcdef, ) @@ -298,12 +300,10 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) - if nonreentrant_key is not None: - raise FunctionDeclarationException( - "nonreentrant key not allowed in interfaces", funcdef - ) + if nonreentrant: + raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef) if funcdef.name == "__init__": raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) @@ -332,7 +332,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) @@ -350,7 +350,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) positional_args, keyword_args = _parse_args(funcdef) @@ -403,15 +403,16 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=False, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") - if self.nonreentrant is None: - raise CompilerPanic(f"No reentrant key {self}") + if not self.nonreentrant: + raise CompilerPanic(f"Not nonreentrant {self}", self.ast_def) + self.reentrancy_key_position = position @classmethod @@ -660,32 +661,30 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, bool]: function_visibility = None state_mutability = None - nonreentrant_key = None + nonreentrant_node = None for decorator in funcdef.decorator_list: if isinstance(decorator, vy_ast.Call): - if nonreentrant_key is not None: - raise StructureException( - "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", - funcdef, - ) - - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + msg = "Decorator is not callable" + hint = None + if decorator.get("func.id") == "nonreentrant": + hint = "use `@nonreentrant` with no arguments. the " + hint += "`@nonreentrant` decorator does not accept any " + hint += "arguments since vyper 0.4.0." + raise StructureException(msg, decorator, hint=hint) + + if decorator.get("id") == "nonreentrant": + if nonreentrant_node is not None: + raise StructureException("nonreentrant decorator is already set", nonreentrant_node) if funcdef.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" + msg = "`@nonreentrant` decorator disallowed on `__init__`" raise FunctionDeclarationException(msg, decorator) - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) + nonreentrant_node = decorator elif isinstance(decorator, vy_ast.Name): if FunctionVisibility.is_valid_value(decorator.id): @@ -726,12 +725,13 @@ def _parse_decorators( # default to nonpayable state_mutability = StateMutability.NONPAYABLE - if state_mutability == StateMutability.PURE and nonreentrant_key is not None: - raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + if state_mutability == StateMutability.PURE and nonreentrant_node is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", nonreentrant_node) # assert function_visibility is not None # mypy # assert state_mutability is not None # mypy - return function_visibility, state_mutability, nonreentrant_key + nonreentrant = nonreentrant_node is not None + return function_visibility, state_mutability, nonreentrant def _parse_args( From d3723783caf3edfd8905fea0b221f99f18eeb27b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 11:21:45 -0800 Subject: [PATCH 188/201] refactor: `get_search_paths()` for vyper cli (#3778) this is for external consumers of the compiler library, which may want to be able to replicate how vyper computes the search path --- vyper/cli/vyper_compile.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ac69cf3310..3a63b88576 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -229,32 +229,37 @@ def exc_handler(contract_path: ContractPath, exception: Exception) -> None: raise exception -def compile_files( - input_files: list[str], - output_formats: OutputFormats, - paths: list[str] = None, - show_gas_estimates: bool = False, - settings: Optional[Settings] = None, - storage_layout_paths: list[str] = None, - no_bytecode_metadata: bool = False, -) -> dict: - # lowest precedence search path is always sys path - search_paths = [Path(p) for p in sys.path] +def get_search_paths(paths: list[str] = None) -> list[Path]: + # given `paths` input, get the full search path, including + # the system search path. + paths = paths or [] - # python sys path uses opposite resolution order from us + # lowest precedence search path is always sys path + # note python sys path uses opposite resolution order from us # (first in list is highest precedence; we give highest precedence # to the last in the list) - search_paths.reverse() + search_paths = [Path(p) for p in reversed(sys.path)] if Path(".") not in search_paths: search_paths.append(Path(".")) - paths = paths or [] - for p in paths: path = Path(p).resolve(strict=True) search_paths.append(path) + return search_paths + + +def compile_files( + input_files: list[str], + output_formats: OutputFormats, + paths: list[str] = None, + show_gas_estimates: bool = False, + settings: Optional[Settings] = None, + storage_layout_paths: list[str] = None, + no_bytecode_metadata: bool = False, +) -> dict: + search_paths = get_search_paths(paths) input_bundle = FilesystemInputBundle(search_paths) show_version = False From 8e5e1c24ababd36cbc83d0690f224a97801675a5 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Tue, 13 Feb 2024 23:01:18 +0100 Subject: [PATCH 189/201] docs: add missing cli flags (#3736) * Add missing compiler flags to docs * Add note on `opcodes` and `opcodes_runtime` wrong output * docs: add missing cli flag `metadata` --- docs/compiling-a-contract.rst | 5 ++++- vyper/cli/vyper_compile.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst index 2b069c2add..36b46497f9 100644 --- a/docs/compiling-a-contract.rst +++ b/docs/compiling-a-contract.rst @@ -29,7 +29,10 @@ Include the ``-f`` flag to specify which output formats to return. Use ``vyper - .. code:: shell - $ vyper -f abi,bytecode,bytecode_runtime,ir,asm,source_map,method_identifiers yourFileName.vy + $ vyper -f abi,abi_python,bytecode,bytecode_runtime,interface,external_interface,ast,ir,ir_json,ir_runtime,hex-ir,asm,opcodes,opcodes_runtime,source_map,method_identifiers,userdoc,devdoc,metadata,combined_json,layout yourFileName.vy + +.. note:: + The ``opcodes`` and ``opcodes_runtime`` output of the compiler has been returning incorrect opcodes since ``0.2.0`` due to a lack of 0 padding (patched via `PR 3735 `_). If you rely on these functions for debugging, please use the latest patched versions. The ``-p`` flag allows you to set a root path that is used when searching for interface files to import. If none is given, it will default to the current working directory. See :ref:`searching_for_imports` for more information. diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 3a63b88576..2ba8a5417c 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -31,6 +31,7 @@ method_identifiers - Dictionary of method signature to method identifier userdoc - Natspec user documentation devdoc - Natspec developer documentation +metadata - Contract metadata (intended for use by tooling developers) combined_json - All of the above format options combined as single JSON output layout - Storage layout of a Vyper contract ast - AST (not yet annotated) in JSON format From 4b4e188ba83d28b5dd6ff66479e7448e5b925030 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 19:26:46 -0800 Subject: [PATCH 190/201] perf: levenshtein optimization (#3780) optimize compile time. `levenshtein` is a hotspot since it is called a lot during type analysis to construct exceptions (which are then caught as part of the validation routines). this commit delays calling `levenshtein` until the last minute, and also adds a mechanism to `VyperException` so that hints can be constructed lazily in general. on a couple test contracts, compilation time comes down 7%. however, as a portion of the time spent in the frontend, compilation time comes down 20-30%. this will become important as projects become larger (that is, many imports but only some functions are actually present in codegen) and compilation time is dominated by the frontend. --- .../syntax/modules/test_initializers.py | 2 +- tests/functional/syntax/test_for_range.py | 39 +++++++++++++++---- vyper/exceptions.py | 13 ++++++- vyper/semantics/analysis/levenshtein_utils.py | 10 ++++- vyper/semantics/analysis/utils.py | 4 +- vyper/semantics/namespace.py | 4 +- vyper/semantics/types/base.py | 4 +- vyper/semantics/types/user.py | 4 +- vyper/semantics/types/utils.py | 5 +-- 9 files changed, 62 insertions(+), 23 deletions(-) diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index d0965ae61d..66a201a33d 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -1178,4 +1178,4 @@ def test_ownership_decl_errors_not_swallowed(make_input_bundle): input_bundle = make_input_bundle({"lib1.vy": lib1}) with pytest.raises(UndeclaredDefinition) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "'lib2' has not been declared. " + assert e.value._message == "'lib2' has not been declared." diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a486d11738..94eed58dd4 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -21,6 +21,7 @@ def foo(): """, StructureException, "Invalid syntax for loop iterator", + None, "a[1]", ), ( @@ -32,6 +33,7 @@ def bar(): """, StructureException, "Bound must be at least 1", + None, "0", ), ( @@ -44,6 +46,7 @@ def foo(): """, StateAccessViolation, "Bound must be a literal", + None, "x", ), ( @@ -55,6 +58,7 @@ def foo(): """, StructureException, "Please remove the `bound=` kwarg when using range with constants", + None, "5", ), ( @@ -66,6 +70,7 @@ def foo(): """, StructureException, "Bound must be at least 1", + None, "0", ), ( @@ -78,6 +83,7 @@ def bar(): """, ArgumentException, "Invalid keyword argument 'extra'", + None, "extra=3", ), ( @@ -89,6 +95,7 @@ def bar(): """, StructureException, "End must be greater than start", + None, "0", ), ( @@ -101,6 +108,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -113,6 +121,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -125,6 +134,7 @@ def repeat(n: uint256) -> uint256: """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "n * 10", ), ( @@ -137,6 +147,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x + 1", ), ( @@ -148,6 +159,7 @@ def bar(): """, StructureException, "End must be greater than start", + None, "1", ), ( @@ -160,6 +172,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -172,6 +185,7 @@ def foo(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -184,6 +198,7 @@ def repeat(n: uint256) -> uint256: """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "n", ), ( @@ -196,6 +211,7 @@ def foo(x: int128): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -207,6 +223,7 @@ def bar(x: uint256): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -221,6 +238,7 @@ def foo(): """, TypeMismatch, "Given reference has type int128, expected uint256", + None, "FOO", ), ( @@ -234,6 +252,7 @@ def foo(): """, StructureException, "Bound must be at least 1", + None, "FOO", ), ( @@ -244,7 +263,8 @@ def foo(): pass """, UnknownType, - "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", + "No builtin or user-defined type named 'DynArra'.", + "Did you mean 'DynArray'?", "DynArra", ), ( @@ -262,7 +282,8 @@ def foo(): pass """, UnknownType, - "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "No builtin or user-defined type named 'uint9'.", + "Did you mean 'uint96', or maybe 'uint8'?", "uint9", ), ( @@ -278,7 +299,8 @@ def foo(): pass """, UnknownType, - "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "No builtin or user-defined type named 'uint9'.", + "Did you mean 'uint96', or maybe 'uint8'?", "uint9", ), ] @@ -289,15 +311,18 @@ def foo(): f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] f" raises {type(err).__name__}" ) - for i, (code, err, msg, src) in enumerate(fail_list) + for i, (code, err, msg, hint, src) in enumerate(fail_list) ] -@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names) -def test_range_fail(bad_code, error_type, message, source_code): +@pytest.mark.parametrize( + "bad_code,error_type,message,hint,source_code", fail_list, ids=fail_test_names +) +def test_range_fail(bad_code, error_type, message, hint, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) - assert message == exc_info.value.message + assert message == exc_info.value._message + assert hint == exc_info.value.hint assert source_code == exc_info.value.args[1].get_original_node().node_source_code diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 53ad6f7bb8..f57cdabe9d 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -79,11 +79,20 @@ def with_annotation(self, *annotations): exc.annotations = annotations return exc + @property + def hint(self): + # some hints are expensive to compute, so we wait until the last + # minute when the formatted message is actually requested to compute + # them. + if callable(self._hint): + return self._hint() + return self._hint + @property def message(self): msg = self._message - if self._hint: - msg += f"\n\n (hint: {self._hint})" + if self.hint: + msg += f"\n\n (hint: {self.hint})" return msg def __str__(self): diff --git a/vyper/semantics/analysis/levenshtein_utils.py b/vyper/semantics/analysis/levenshtein_utils.py index 1d8f87dfbd..fc6e497d43 100644 --- a/vyper/semantics/analysis/levenshtein_utils.py +++ b/vyper/semantics/analysis/levenshtein_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Callable def levenshtein_norm(source: str, target: str) -> float: @@ -73,7 +73,13 @@ def levenshtein(source: str, target: str) -> int: return matrix[len(source)][len(target)] -def get_levenshtein_error_suggestions(key: str, namespace: Dict[str, Any], threshold: float) -> str: +def get_levenshtein_error_suggestions(*args, **kwargs) -> Callable: + return lambda: _get_levenshtein_error_suggestions(*args, **kwargs) + + +def _get_levenshtein_error_suggestions( + key: str, namespace: dict[str, Any], threshold: float +) -> str: """ Generate an error message snippet for the suggested closest values in the provided namespace with the shortest normalized Levenshtein distance from the given key if that distance diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 034cd8c46e..fa4dfcc1d1 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -208,9 +208,9 @@ def _raise_invalid_reference(name, node): if name in self.namespace: _raise_invalid_reference(name, node) - suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4) + hint = get_levenshtein_error_suggestions(name, t.members, 0.4) raise UndeclaredDefinition( - f"Storage variable '{name}' has not been declared. {suggestions_str}", node + f"Storage variable '{name}' has not been declared.", node, hint=hint ) from None def types_from_BinOp(self, node): diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 4df2511a29..d59343edfb 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -45,8 +45,8 @@ def __setitem__(self, attr, obj): def __getitem__(self, key): if key not in self: - suggestions_str = get_levenshtein_error_suggestions(key, self, 0.2) - raise UndeclaredDefinition(f"'{key}' has not been declared. {suggestions_str}") + hint = get_levenshtein_error_suggestions(key, self, 0.2) + raise UndeclaredDefinition(f"'{key}' has not been declared.", hint=hint) return super().__getitem__(key) def __enter__(self): diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index c5e10b52be..37de263319 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -321,8 +321,8 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": if not self.members: raise StructureException(f"{self} instance does not have members", node) - suggestions_str = get_levenshtein_error_suggestions(key, self.members, 0.3) - raise UnknownAttribute(f"{self} has no member '{key}'. {suggestions_str}", node) + hint = get_levenshtein_error_suggestions(key, self.members, 0.3) + raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint) def __repr__(self): return self._id diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 92a455e3d8..0c9b5d70da 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -399,9 +399,9 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": keys = list(self.member_types.keys()) for i, (key, value) in enumerate(zip(node.args[0].keys, node.args[0].values)): if key is None or key.get("id") not in members: - suggestions_str = get_levenshtein_error_suggestions(key.get("id"), members, 1.0) + hint = get_levenshtein_error_suggestions(key.get("id"), members, 1.0) raise UnknownAttribute( - f"Unknown or duplicate struct member. {suggestions_str}", key or value + "Unknown or duplicate struct member.", key or value, hint=hint ) expected_key = keys[i] if key.id != expected_key: diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 96c661021f..0546668900 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -146,10 +146,9 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: raise InvalidType(err_msg, node) if node.id not in namespace: # type: ignore - suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) + hint = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) raise UnknownType( - f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}", - node, + f"No builtin or user-defined type named '{node.node_source_code}'.", node, hint=hint ) from None typ_ = namespace[node.id] From 29205baccd771ab584b1f0182963aca201352b64 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 14 Feb 2024 10:15:46 -0800 Subject: [PATCH 191/201] feat: frontend optimizations (#3781) a couple compilation-time optimizations. bring total time down by 5% (`vyper -f bytecode`), and time in frontend code down by 20% (`vyper -f annotated_ast`). - cache `VyperNode.get_fields()`, it's a hotspot - optimize `get_common_types(), this line is a hotspot; `rejected = [i for i in common_types if i not in common]` --- vyper/ast/nodes.py | 2 ++ vyper/semantics/analysis/utils.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index c4bce814a4..fb5fb73592 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -2,6 +2,7 @@ import contextlib import copy import decimal +import functools import operator import sys import warnings @@ -341,6 +342,7 @@ def from_node(cls, node: "VyperNode", **kwargs) -> "VyperNode": return cls(**ast_struct) @classmethod + @functools.lru_cache(maxsize=None) def get_fields(cls) -> set: """ Return a set of field names for this node. diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index fa4dfcc1d1..abea600d88 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -500,12 +500,14 @@ def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> Li for item in nodes[1:]: new_types = _ExprAnalyser().get_possible_types_from_node(item) - common = [i for i in common_types if _is_type_in_list(i, new_types)] - - rejected = [i for i in common_types if i not in common] - common += [i for i in new_types if _is_type_in_list(i, rejected)] - - common_types = common + tmp = [] + for c in common_types: + for t in new_types: + if t.compare_type(c) or c.compare_type(t): + tmp.append(c) + break + + common_types = tmp if filter_fn is not None: common_types = [i for i in common_types if filter_fn(i)] From 2d2a68224296342b3c888676848d7da9dc0b1cce Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 16 Feb 2024 14:43:48 -0800 Subject: [PATCH 192/201] feat: optimize `VyperNode.__deepcopy__` (#3784) `VyperNode.__deepcopy__` is a hotspot in the frontend. this commit improves time spent in the frontend by 10% --- vyper/ast/nodes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index fb5fb73592..0ebe18ab5d 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -4,6 +4,7 @@ import decimal import functools import operator +import pickle import sys import warnings from typing import Any, Optional, Union @@ -357,6 +358,9 @@ def __hash__(self): values = [getattr(self, i, None) for i in VyperNode.__slots__ if not i.startswith("_")] return hash(tuple(values)) + def __deepcopy__(self, memo): + return pickle.loads(pickle.dumps(self)) + def __eq__(self, other): if not isinstance(other, type(self)): return False @@ -786,7 +790,6 @@ class ExprNode(VyperNode): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._expr_info = None From 1fc819c317021b4acf3fcabb0b831cf946aef2bd Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 17 Feb 2024 07:13:45 +0800 Subject: [PATCH 193/201] chore: deduplicate test files (#3773) This PR merges the cross-contract tests in `test_return_struct.py` / `test_return_tuple.py` with `test_struct_return.py` / `test_tuple_return.py` in `tests/functional/codegen/calling_convention.py` The remaining struct unit tests were moved into a new file: `tests/functional/codegen/types/test_structs.py`. --- .../codegen/calling_convention/test_return.py | 769 ++++++++++++++++++ .../calling_convention/test_return_struct.py | 507 ------------ .../calling_convention/test_return_tuple.py | 163 ---- .../calling_convention/test_tuple_return.py | 45 - .../test_struct.py} | 61 -- 5 files changed, 769 insertions(+), 776 deletions(-) delete mode 100644 tests/functional/codegen/calling_convention/test_return_struct.py delete mode 100644 tests/functional/codegen/calling_convention/test_return_tuple.py delete mode 100644 tests/functional/codegen/calling_convention/test_tuple_return.py rename tests/functional/codegen/{calling_convention/test_struct_return.py => types/test_struct.py} (52%) diff --git a/tests/functional/codegen/calling_convention/test_return.py b/tests/functional/codegen/calling_convention/test_return.py index eaa4f9034a..2db3689da3 100644 --- a/tests/functional/codegen/calling_convention/test_return.py +++ b/tests/functional/codegen/calling_convention/test_return.py @@ -1,3 +1,11 @@ +import pytest + +from vyper import compile_code +from vyper.exceptions import TypeMismatch + +pytestmark = pytest.mark.usefixtures("memory_mocker") + + def test_correct_abi_right_padding(tester, w3, get_contract_with_gas_estimation): selfcall_code_6 = """ @external @@ -31,3 +39,764 @@ def hardtest(arg1: Bytes[64], arg2: Bytes[64]) -> Bytes[128]: assert dyn_section[32 : 32 + len_value] == b"hello" * 15 # second right pad assert assert dyn_section[32 + len_value :] == b"\x00" * (len(dyn_section) - 32 - len_value) + + +def test_return_type(get_contract_with_gas_estimation): + long_string = 35 * "test" + + code = """ +struct Chunk: + a: Bytes[8] + b: Bytes[8] + c: int128 +chunk: Chunk + +@deploy +def __init__(): + self.chunk.a = b"hello" + self.chunk.b = b"world" + self.chunk.c = 5678 + +@external +def out() -> (int128, address): + return 3333, 0x0000000000000000000000000000000000000001 + +@external +def out_literals() -> (int128, address, Bytes[6]): + return 1, 0x0000000000000000000000000000000000000000, b"random" + +@external +def out_bytes_first() -> (Bytes[4], int128): + return b"test", 1234 + +@external +def out_bytes_a(x: int128, y: Bytes[4]) -> (int128, Bytes[4]): + return x, y + +@external +def out_bytes_b(x: int128, y: Bytes[4]) -> (Bytes[4], int128, Bytes[4]): + return y, x, y + +@external +def four() -> (int128, Bytes[8], Bytes[8], int128): + return 1234, b"bytes", b"test", 4321 + +@external +def out_chunk() -> (Bytes[8], int128, Bytes[8]): + return self.chunk.a, self.chunk.c, self.chunk.b + +@external +def out_very_long_bytes() -> (int128, Bytes[1024], int128, address): + return 5555, b"testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest", 6666, 0x0000000000000000000000000000000000001234 # noqa + """ + + c = get_contract_with_gas_estimation(code) + + assert c.out() == [3333, "0x0000000000000000000000000000000000000001"] + assert c.out_literals() == [1, None, b"random"] + assert c.out_bytes_first() == [b"test", 1234] + assert c.out_bytes_a(5555555, b"test") == [5555555, b"test"] + assert c.out_bytes_b(5555555, b"test") == [b"test", 5555555, b"test"] + assert c.four() == [1234, b"bytes", b"test", 4321] + assert c.out_chunk() == [b"hello", 5678, b"world"] + assert c.out_very_long_bytes() == [ + 5555, + long_string.encode(), + 6666, + "0x0000000000000000000000000000000000001234", + ] + + +def test_return_type_signatures(get_contract_with_gas_estimation): + code = """ +@external +def out_literals() -> (int128, address, Bytes[6]): + return 1, 0x0000000000000000000000000000000000000000, b"random" + """ + + c = get_contract_with_gas_estimation(code) + assert c._classic_contract.abi[0]["outputs"] == [ + {"type": "int128", "name": ""}, + {"type": "address", "name": ""}, + {"type": "bytes", "name": ""}, + ] + + +def test_return_tuple_assign(get_contract_with_gas_estimation): + code = """ +@internal +def _out_literals() -> (int128, address, Bytes[10]): + return 1, 0x0000000000000000000000000000000000000000, b"random" + +@external +def out_literals() -> (int128, address, Bytes[10]): + return self._out_literals() + +@external +def test() -> (int128, address, Bytes[10]): + a: int128 = 0 + b: address = empty(address) + c: Bytes[10] = b"" + (a, b, c) = self._out_literals() + return a, b, c + """ + + c = get_contract_with_gas_estimation(code) + + assert c.out_literals() == c.test() == [1, None, b"random"] + + +def test_return_tuple_assign_storage(get_contract_with_gas_estimation): + code = """ +a: int128 +b: address +c: Bytes[20] +d: Bytes[20] + +@internal +def _out_literals() -> (int128, Bytes[20], address, Bytes[20]): + return 1, b"testtesttest", 0x0000000000000000000000000000000000000023, b"random" + +@external +def out_literals() -> (int128, Bytes[20], address, Bytes[20]): + return self._out_literals() + +@external +def test1() -> (int128, Bytes[20], address, Bytes[20]): + self.a, self.c, self.b, self.d = self._out_literals() + return self.a, self.c, self.b, self.d + +@external +def test2() -> (int128, address): + x: int128 = 0 + x, self.c, self.b, self.d = self._out_literals() + return x, self.b + +@external +def test3() -> (address, int128): + x: address = empty(address) + self.a, self.c, x, self.d = self._out_literals() + return x, self.a + """ + + c = get_contract_with_gas_estimation(code) + + addr = "0x" + "00" * 19 + "23" + assert c.out_literals() == [1, b"testtesttest", addr, b"random"] + assert c.out_literals() == c.test1() + assert c.test2() == [1, c.out_literals()[2]] + assert c.test3() == [c.out_literals()[2], 1] + + +@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) +def test_string_inside_tuple(get_contract, string): + code = f""" +@external +def test_return() -> (String[6], uint256): + return "{string}", 42 + """ + c1 = get_contract(code) + + code = """ +interface jsonabi: + def test_return() -> (String[6], uint256): view + +@external +def test_values(a: address) -> (String[6], uint256): + return jsonabi(a).test_return() + """ + + c2 = get_contract(code) + assert c2.test_values(c1.address) == [string, 42] + + +@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) +def test_bytes_inside_tuple(get_contract, string): + code = f""" +@external +def test_return() -> (Bytes[6], uint256): + return b"{string}", 42 + """ + c1 = get_contract(code) + + code = """ +interface jsonabi: + def test_return() -> (Bytes[6], uint256): view + +@external +def test_values(a: address) -> (Bytes[6], uint256): + return jsonabi(a).test_return() + """ + + c2 = get_contract(code) + assert c2.test_values(c1.address) == [bytes(string, "utf-8"), 42] + + +def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation): + code = """ +@external +def getTimeAndBalance() -> (bool, address): + return block.timestamp, self.balance + """ + with pytest.raises(TypeMismatch): + compile_code(code) + + +def test_struct_return_abi(get_contract_with_gas_estimation): + code = """ +struct Voter: + weight: int128 + voted: bool + +@external +def test() -> Voter: + a: Voter = Voter({weight: 123, voted: True}) + return a + """ + + out = compile_code(code, output_formats=["abi"]) + abi = out["abi"][0] + + assert abi["name"] == "test" + + c = get_contract_with_gas_estimation(code) + + assert c.test() == (123, True) + + +def test_single_struct_return_abi(get_contract_with_gas_estimation): + code = """ +struct Voter: + voted: bool + +@external +def test() -> Voter: + a: Voter = Voter({voted: True}) + return a + """ + + out = compile_code(code, output_formats=["abi"]) + abi = out["abi"][0] + + assert abi["name"] == "test" + assert abi["outputs"][0]["type"] == "tuple" + + c = get_contract_with_gas_estimation(code) + + assert c.test() == (True,) + + +def test_struct_return(get_contract_with_gas_estimation): + code = """ +struct Foo: + x: int128 + y: uint256 + +_foo: Foo +_foos: HashMap[int128, Foo] + +@internal +def priv1() -> Foo: + return Foo({x: 1, y: 2}) +@external +def pub1() -> Foo: + return self.priv1() + +@internal +def priv2() -> Foo: + foo: Foo = Foo({x: 0, y: 0}) + foo.x = 3 + foo.y = 4 + return foo +@external +def pub2() -> Foo: + return self.priv2() + +@external +def pub3() -> Foo: + self._foo = Foo({x: 5, y: 6}) + return self._foo + +@external +def pub4() -> Foo: + self._foos[0] = Foo({x: 7, y: 8}) + return self._foos[0] + +@internal +def return_arg(foo: Foo) -> Foo: + return foo +@external +def pub5(foo: Foo) -> Foo: + return self.return_arg(foo) +@external +def pub6() -> Foo: + foo: Foo = Foo({x: 123, y: 456}) + return self.return_arg(foo) + """ + foo = (123, 456) + + c = get_contract_with_gas_estimation(code) + + assert c.pub1() == (1, 2) + assert c.pub2() == (3, 4) + assert c.pub3() == (5, 6) + assert c.pub4() == (7, 8) + assert c.pub5(foo) == foo + assert c.pub6() == foo + + +def test_single_struct_return(get_contract_with_gas_estimation): + code = """ +struct Foo: + x: int128 + +_foo: Foo +_foos: HashMap[int128, Foo] + +@internal +def priv1() -> Foo: + return Foo({x: 1}) +@external +def pub1() -> Foo: + return self.priv1() + +@internal +def priv2() -> Foo: + foo: Foo = Foo({x: 0}) + foo.x = 3 + return foo +@external +def pub2() -> Foo: + return self.priv2() + +@external +def pub3() -> Foo: + self._foo = Foo({x: 5}) + return self._foo + +@external +def pub4() -> Foo: + self._foos[0] = Foo({x: 7}) + return self._foos[0] + +@internal +def return_arg(foo: Foo) -> Foo: + return foo +@external +def pub5(foo: Foo) -> Foo: + return self.return_arg(foo) +@external +def pub6() -> Foo: + foo: Foo = Foo({x: 123}) + return self.return_arg(foo) + """ + foo = (123,) + + c = get_contract_with_gas_estimation(code) + + assert c.pub1() == (1,) + assert c.pub2() == (3,) + assert c.pub3() == (5,) + assert c.pub4() == (7,) + assert c.pub5(foo) == foo + assert c.pub6() == foo + + +def test_self_call_in_return_struct(get_contract): + code = """ +struct Foo: + a: uint256 + b: uint256 + c: uint256 + d: uint256 + e: uint256 + +@internal +def _foo() -> uint256: + a: uint256[10] = [6,7,8,9,10,11,12,13,14,15] + return 3 + +@external +def foo() -> Foo: + return Foo({a:1, b:2, c:self._foo(), d:4, e:5}) + """ + + c = get_contract(code) + + assert c.foo() == (1, 2, 3, 4, 5) + + +def test_self_call_in_return_single_struct(get_contract): + code = """ +struct Foo: + a: uint256 + +@internal +def _foo() -> uint256: + a: uint256[10] = [6,7,8,9,10,11,12,13,14,15] + return 3 + +@external +def foo() -> Foo: + return Foo({a:self._foo()}) + """ + + c = get_contract(code) + + assert c.foo() == (3,) + + +def test_call_in_call(get_contract): + code = """ +struct Foo: + a: uint256 + b: uint256 + c: uint256 + d: uint256 + e: uint256 + +@internal +def _foo(a: uint256, b: uint256, c: uint256) -> Foo: + return Foo({a:1, b:a, c:b, d:c, e:5}) + +@internal +def _foo2() -> uint256: + a: uint256[10] = [6,7,8,9,10,11,12,13,15,16] + return 4 + +@external +def foo() -> Foo: + return self._foo(2, 3, self._foo2()) + """ + + c = get_contract(code) + + assert c.foo() == (1, 2, 3, 4, 5) + + +def test_call_in_call_single_struct(get_contract): + code = """ +struct Foo: + a: uint256 + +@internal +def _foo(a: uint256) -> Foo: + return Foo({a:a}) + +@internal +def _foo2() -> uint256: + a: uint256[10] = [6,7,8,9,10,11,12,13,15,16] + return 4 + +@external +def foo() -> Foo: + return self._foo(self._foo2()) + """ + + c = get_contract(code) + + assert c.foo() == (4,) + + +def test_nested_calls_in_struct_return(get_contract): + code = """ +struct Foo: + a: uint256 + b: uint256 + c: uint256 + d: uint256 + e: uint256 +struct Bar: + a: uint256 + b: uint256 + +@internal +def _bar(a: uint256, b: uint256, c: uint256) -> Bar: + return Bar({a:415, b:3}) + +@internal +def _foo2(a: uint256) -> uint256: + b: uint256[10] = [6,7,8,9,10,11,12,13,14,15] + return 99 + +@internal +def _foo3(a: uint256, b: uint256) -> uint256: + c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] + return 42 + +@internal +def _foo4() -> uint256: + c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] + return 4 + +@external +def foo() -> Foo: + return Foo({ + a:1, + b:2, + c:self._bar(6, 7, self._foo2(self._foo3(9, 11))).b, + d:self._foo4(), + e:5 + }) + """ + + c = get_contract(code) + + assert c.foo() == (1, 2, 3, 4, 5) + + +def test_nested_calls_in_single_struct_return(get_contract): + code = """ +struct Foo: + a: uint256 +struct Bar: + a: uint256 + b: uint256 + +@internal +def _bar(a: uint256, b: uint256, c: uint256) -> Bar: + return Bar({a:415, b:3}) + +@internal +def _foo2(a: uint256) -> uint256: + b: uint256[10] = [6,7,8,9,10,11,12,13,14,15] + return 99 + +@internal +def _foo3(a: uint256, b: uint256) -> uint256: + c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] + return 42 + +@internal +def _foo4() -> uint256: + c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] + return 4 + +@external +def foo() -> Foo: + return Foo({ + a:self._bar(6, self._foo4(), self._foo2(self._foo3(9, 11))).b, + }) + """ + + c = get_contract(code) + + assert c.foo() == (3,) + + +def test_external_call_in_return_struct(get_contract): + code = """ +struct Bar: + a: uint256 + b: uint256 +@view +@external +def bar() -> Bar: + return Bar({a:3, b:4}) + """ + + code2 = """ +struct Foo: + a: uint256 + b: uint256 + c: uint256 + d: uint256 + e: uint256 +struct Bar: + a: uint256 + b: uint256 +interface IBar: + def bar() -> Bar: view + +@external +def foo(addr: address) -> Foo: + return Foo({ + a:1, + b:2, + c:IBar(addr).bar().a, + d:4, + e:5 + }) + """ + + c = get_contract(code) + c2 = get_contract(code2) + + assert c2.foo(c.address) == (1, 2, 3, 4, 5) + + +def test_external_call_in_return_single_struct(get_contract): + code = """ +struct Bar: + a: uint256 +@view +@external +def bar() -> Bar: + return Bar({a:3}) + """ + + code2 = """ +struct Foo: + a: uint256 +struct Bar: + a: uint256 +interface IBar: + def bar() -> Bar: view + +@external +def foo(addr: address) -> Foo: + return Foo({ + a:IBar(addr).bar().a + }) + """ + + c = get_contract(code) + c2 = get_contract(code2) + + assert c2.foo(c.address) == (3,) + + +def test_nested_external_call_in_return_struct(get_contract): + code = """ +struct Bar: + a: uint256 + b: uint256 + +@view +@external +def bar() -> Bar: + return Bar({a:3, b:4}) + +@view +@external +def baz(x: uint256) -> uint256: + return x+1 + """ + + code2 = """ +struct Foo: + a: uint256 + b: uint256 + c: uint256 + d: uint256 + e: uint256 +struct Bar: + a: uint256 + b: uint256 + +interface IBar: + def bar() -> Bar: view + def baz(a: uint256) -> uint256: view + +@external +def foo(addr: address) -> Foo: + return Foo({ + a:1, + b:2, + c:IBar(addr).bar().a, + d:4, + e:IBar(addr).baz(IBar(addr).bar().b) + }) + """ + + c = get_contract(code) + c2 = get_contract(code2) + + assert c2.foo(c.address) == (1, 2, 3, 4, 5) + + +def test_nested_external_call_in_return_single_struct(get_contract): + code = """ +struct Bar: + a: uint256 + +@view +@external +def bar() -> Bar: + return Bar({a:3}) + +@view +@external +def baz(x: uint256) -> uint256: + return x+1 + """ + + code2 = """ +struct Foo: + a: uint256 +struct Bar: + a: uint256 + +interface IBar: + def bar() -> Bar: view + def baz(a: uint256) -> uint256: view + +@external +def foo(addr: address) -> Foo: + return Foo({ + a:IBar(addr).baz(IBar(addr).bar().a) + }) + """ + + c = get_contract(code) + c2 = get_contract(code2) + + assert c2.foo(c.address) == (4,) + + +@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) +def test_string_inside_struct(get_contract, string): + code = f""" +struct Person: + name: String[6] + age: uint256 + +@external +def test_return() -> Person: + return Person({{ name:"{string}", age:42 }}) + """ + c1 = get_contract(code) + + code = """ +struct Person: + name: String[6] + age: uint256 + +interface jsonabi: + def test_return() -> Person: view + +@external +def test_values(a: address) -> Person: + return jsonabi(a).test_return() + """ + + c2 = get_contract(code) + assert c2.test_values(c1.address) == (string, 42) + + +@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) +def test_string_inside_single_struct(get_contract, string): + code = f""" +struct Person: + name: String[6] + +@external +def test_return() -> Person: + return Person({{ name:"{string}"}}) + """ + c1 = get_contract(code) + + code = """ +struct Person: + name: String[6] + +interface jsonabi: + def test_return() -> Person: view + +@external +def test_values(a: address) -> Person: + return jsonabi(a).test_return() + """ + + c2 = get_contract(code) + assert c2.test_values(c1.address) == (string,) diff --git a/tests/functional/codegen/calling_convention/test_return_struct.py b/tests/functional/codegen/calling_convention/test_return_struct.py deleted file mode 100644 index cdd8342d8a..0000000000 --- a/tests/functional/codegen/calling_convention/test_return_struct.py +++ /dev/null @@ -1,507 +0,0 @@ -import pytest - -from vyper.compiler import compile_code - -pytestmark = pytest.mark.usefixtures("memory_mocker") - - -def test_struct_return_abi(get_contract_with_gas_estimation): - code = """ -struct Voter: - weight: int128 - voted: bool - -@external -def test() -> Voter: - a: Voter = Voter({weight: 123, voted: True}) - return a - """ - - out = compile_code(code, output_formats=["abi"]) - abi = out["abi"][0] - - assert abi["name"] == "test" - - c = get_contract_with_gas_estimation(code) - - assert c.test() == (123, True) - - -def test_single_struct_return_abi(get_contract_with_gas_estimation): - code = """ -struct Voter: - voted: bool - -@external -def test() -> Voter: - a: Voter = Voter({voted: True}) - return a - """ - - out = compile_code(code, output_formats=["abi"]) - abi = out["abi"][0] - - assert abi["name"] == "test" - assert abi["outputs"][0]["type"] == "tuple" - - c = get_contract_with_gas_estimation(code) - - assert c.test() == (True,) - - -def test_struct_return(get_contract_with_gas_estimation): - code = """ -struct Foo: - x: int128 - y: uint256 - -_foo: Foo -_foos: HashMap[int128, Foo] - -@internal -def priv1() -> Foo: - return Foo({x: 1, y: 2}) -@external -def pub1() -> Foo: - return self.priv1() - -@internal -def priv2() -> Foo: - foo: Foo = Foo({x: 0, y: 0}) - foo.x = 3 - foo.y = 4 - return foo -@external -def pub2() -> Foo: - return self.priv2() - -@external -def pub3() -> Foo: - self._foo = Foo({x: 5, y: 6}) - return self._foo - -@external -def pub4() -> Foo: - self._foos[0] = Foo({x: 7, y: 8}) - return self._foos[0] - -@internal -def return_arg(foo: Foo) -> Foo: - return foo -@external -def pub5(foo: Foo) -> Foo: - return self.return_arg(foo) -@external -def pub6() -> Foo: - foo: Foo = Foo({x: 123, y: 456}) - return self.return_arg(foo) - """ - foo = (123, 456) - - c = get_contract_with_gas_estimation(code) - - assert c.pub1() == (1, 2) - assert c.pub2() == (3, 4) - assert c.pub3() == (5, 6) - assert c.pub4() == (7, 8) - assert c.pub5(foo) == foo - assert c.pub6() == foo - - -def test_single_struct_return(get_contract_with_gas_estimation): - code = """ -struct Foo: - x: int128 - -_foo: Foo -_foos: HashMap[int128, Foo] - -@internal -def priv1() -> Foo: - return Foo({x: 1}) -@external -def pub1() -> Foo: - return self.priv1() - -@internal -def priv2() -> Foo: - foo: Foo = Foo({x: 0}) - foo.x = 3 - return foo -@external -def pub2() -> Foo: - return self.priv2() - -@external -def pub3() -> Foo: - self._foo = Foo({x: 5}) - return self._foo - -@external -def pub4() -> Foo: - self._foos[0] = Foo({x: 7}) - return self._foos[0] - -@internal -def return_arg(foo: Foo) -> Foo: - return foo -@external -def pub5(foo: Foo) -> Foo: - return self.return_arg(foo) -@external -def pub6() -> Foo: - foo: Foo = Foo({x: 123}) - return self.return_arg(foo) - """ - foo = (123,) - - c = get_contract_with_gas_estimation(code) - - assert c.pub1() == (1,) - assert c.pub2() == (3,) - assert c.pub3() == (5,) - assert c.pub4() == (7,) - assert c.pub5(foo) == foo - assert c.pub6() == foo - - -def test_self_call_in_return_struct(get_contract): - code = """ -struct Foo: - a: uint256 - b: uint256 - c: uint256 - d: uint256 - e: uint256 - -@internal -def _foo() -> uint256: - a: uint256[10] = [6,7,8,9,10,11,12,13,14,15] - return 3 - -@external -def foo() -> Foo: - return Foo({a:1, b:2, c:self._foo(), d:4, e:5}) - """ - - c = get_contract(code) - - assert c.foo() == (1, 2, 3, 4, 5) - - -def test_self_call_in_return_single_struct(get_contract): - code = """ -struct Foo: - a: uint256 - -@internal -def _foo() -> uint256: - a: uint256[10] = [6,7,8,9,10,11,12,13,14,15] - return 3 - -@external -def foo() -> Foo: - return Foo({a:self._foo()}) - """ - - c = get_contract(code) - - assert c.foo() == (3,) - - -def test_call_in_call(get_contract): - code = """ -struct Foo: - a: uint256 - b: uint256 - c: uint256 - d: uint256 - e: uint256 - -@internal -def _foo(a: uint256, b: uint256, c: uint256) -> Foo: - return Foo({a:1, b:a, c:b, d:c, e:5}) - -@internal -def _foo2() -> uint256: - a: uint256[10] = [6,7,8,9,10,11,12,13,15,16] - return 4 - -@external -def foo() -> Foo: - return self._foo(2, 3, self._foo2()) - """ - - c = get_contract(code) - - assert c.foo() == (1, 2, 3, 4, 5) - - -def test_call_in_call_single_struct(get_contract): - code = """ -struct Foo: - a: uint256 - -@internal -def _foo(a: uint256) -> Foo: - return Foo({a:a}) - -@internal -def _foo2() -> uint256: - a: uint256[10] = [6,7,8,9,10,11,12,13,15,16] - return 4 - -@external -def foo() -> Foo: - return self._foo(self._foo2()) - """ - - c = get_contract(code) - - assert c.foo() == (4,) - - -def test_nested_calls_in_struct_return(get_contract): - code = """ -struct Foo: - a: uint256 - b: uint256 - c: uint256 - d: uint256 - e: uint256 -struct Bar: - a: uint256 - b: uint256 - -@internal -def _bar(a: uint256, b: uint256, c: uint256) -> Bar: - return Bar({a:415, b:3}) - -@internal -def _foo2(a: uint256) -> uint256: - b: uint256[10] = [6,7,8,9,10,11,12,13,14,15] - return 99 - -@internal -def _foo3(a: uint256, b: uint256) -> uint256: - c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] - return 42 - -@internal -def _foo4() -> uint256: - c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] - return 4 - -@external -def foo() -> Foo: - return Foo({ - a:1, - b:2, - c:self._bar(6, 7, self._foo2(self._foo3(9, 11))).b, - d:self._foo4(), - e:5 - }) - """ - - c = get_contract(code) - - assert c.foo() == (1, 2, 3, 4, 5) - - -def test_nested_calls_in_single_struct_return(get_contract): - code = """ -struct Foo: - a: uint256 -struct Bar: - a: uint256 - b: uint256 - -@internal -def _bar(a: uint256, b: uint256, c: uint256) -> Bar: - return Bar({a:415, b:3}) - -@internal -def _foo2(a: uint256) -> uint256: - b: uint256[10] = [6,7,8,9,10,11,12,13,14,15] - return 99 - -@internal -def _foo3(a: uint256, b: uint256) -> uint256: - c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] - return 42 - -@internal -def _foo4() -> uint256: - c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] - return 4 - -@external -def foo() -> Foo: - return Foo({ - a:self._bar(6, self._foo4(), self._foo2(self._foo3(9, 11))).b, - }) - """ - - c = get_contract(code) - - assert c.foo() == (3,) - - -def test_external_call_in_return_struct(get_contract): - code = """ -struct Bar: - a: uint256 - b: uint256 -@view -@external -def bar() -> Bar: - return Bar({a:3, b:4}) - """ - - code2 = """ -struct Foo: - a: uint256 - b: uint256 - c: uint256 - d: uint256 - e: uint256 -struct Bar: - a: uint256 - b: uint256 -interface IBar: - def bar() -> Bar: view - -@external -def foo(addr: address) -> Foo: - return Foo({ - a:1, - b:2, - c:IBar(addr).bar().a, - d:4, - e:5 - }) - """ - - c = get_contract(code) - c2 = get_contract(code2) - - assert c2.foo(c.address) == (1, 2, 3, 4, 5) - - -def test_external_call_in_return_single_struct(get_contract): - code = """ -struct Bar: - a: uint256 -@view -@external -def bar() -> Bar: - return Bar({a:3}) - """ - - code2 = """ -struct Foo: - a: uint256 -struct Bar: - a: uint256 -interface IBar: - def bar() -> Bar: view - -@external -def foo(addr: address) -> Foo: - return Foo({ - a:IBar(addr).bar().a - }) - """ - - c = get_contract(code) - c2 = get_contract(code2) - - assert c2.foo(c.address) == (3,) - - -def test_nested_external_call_in_return_struct(get_contract): - code = """ -struct Bar: - a: uint256 - b: uint256 - -@view -@external -def bar() -> Bar: - return Bar({a:3, b:4}) - -@view -@external -def baz(x: uint256) -> uint256: - return x+1 - """ - - code2 = """ -struct Foo: - a: uint256 - b: uint256 - c: uint256 - d: uint256 - e: uint256 -struct Bar: - a: uint256 - b: uint256 - -interface IBar: - def bar() -> Bar: view - def baz(a: uint256) -> uint256: view - -@external -def foo(addr: address) -> Foo: - return Foo({ - a:1, - b:2, - c:IBar(addr).bar().a, - d:4, - e:IBar(addr).baz(IBar(addr).bar().b) - }) - """ - - c = get_contract(code) - c2 = get_contract(code2) - - assert c2.foo(c.address) == (1, 2, 3, 4, 5) - - -def test_nested_external_call_in_return_single_struct(get_contract): - code = """ -struct Bar: - a: uint256 - -@view -@external -def bar() -> Bar: - return Bar({a:3}) - -@view -@external -def baz(x: uint256) -> uint256: - return x+1 - """ - - code2 = """ -struct Foo: - a: uint256 -struct Bar: - a: uint256 - -interface IBar: - def bar() -> Bar: view - def baz(a: uint256) -> uint256: view - -@external -def foo(addr: address) -> Foo: - return Foo({ - a:IBar(addr).baz(IBar(addr).bar().a) - }) - """ - - c = get_contract(code) - c2 = get_contract(code2) - - assert c2.foo(c.address) == (4,) diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py deleted file mode 100644 index 74929c9496..0000000000 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ /dev/null @@ -1,163 +0,0 @@ -import pytest - -from vyper import compile_code -from vyper.exceptions import TypeMismatch - -pytestmark = pytest.mark.usefixtures("memory_mocker") - - -def test_return_type(get_contract_with_gas_estimation): - long_string = 35 * "test" - - code = """ -struct Chunk: - a: Bytes[8] - b: Bytes[8] - c: int128 -chunk: Chunk - -@deploy -def __init__(): - self.chunk.a = b"hello" - self.chunk.b = b"world" - self.chunk.c = 5678 - -@external -def out() -> (int128, address): - return 3333, 0x0000000000000000000000000000000000000001 - -@external -def out_literals() -> (int128, address, Bytes[6]): - return 1, 0x0000000000000000000000000000000000000000, b"random" - -@external -def out_bytes_first() -> (Bytes[4], int128): - return b"test", 1234 - -@external -def out_bytes_a(x: int128, y: Bytes[4]) -> (int128, Bytes[4]): - return x, y - -@external -def out_bytes_b(x: int128, y: Bytes[4]) -> (Bytes[4], int128, Bytes[4]): - return y, x, y - -@external -def four() -> (int128, Bytes[8], Bytes[8], int128): - return 1234, b"bytes", b"test", 4321 - -@external -def out_chunk() -> (Bytes[8], int128, Bytes[8]): - return self.chunk.a, self.chunk.c, self.chunk.b - -@external -def out_very_long_bytes() -> (int128, Bytes[1024], int128, address): - return 5555, b"testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest", 6666, 0x0000000000000000000000000000000000001234 # noqa - """ - - c = get_contract_with_gas_estimation(code) - - assert c.out() == [3333, "0x0000000000000000000000000000000000000001"] - assert c.out_literals() == [1, None, b"random"] - assert c.out_bytes_first() == [b"test", 1234] - assert c.out_bytes_a(5555555, b"test") == [5555555, b"test"] - assert c.out_bytes_b(5555555, b"test") == [b"test", 5555555, b"test"] - assert c.four() == [1234, b"bytes", b"test", 4321] - assert c.out_chunk() == [b"hello", 5678, b"world"] - assert c.out_very_long_bytes() == [ - 5555, - long_string.encode(), - 6666, - "0x0000000000000000000000000000000000001234", - ] - - -def test_return_type_signatures(get_contract_with_gas_estimation): - code = """ -@external -def out_literals() -> (int128, address, Bytes[6]): - return 1, 0x0000000000000000000000000000000000000000, b"random" - """ - - c = get_contract_with_gas_estimation(code) - assert c._classic_contract.abi[0]["outputs"] == [ - {"type": "int128", "name": ""}, - {"type": "address", "name": ""}, - {"type": "bytes", "name": ""}, - ] - - -def test_return_tuple_assign(get_contract_with_gas_estimation): - code = """ -@internal -def _out_literals() -> (int128, address, Bytes[10]): - return 1, 0x0000000000000000000000000000000000000000, b"random" - -@external -def out_literals() -> (int128, address, Bytes[10]): - return self._out_literals() - -@external -def test() -> (int128, address, Bytes[10]): - a: int128 = 0 - b: address = empty(address) - c: Bytes[10] = b"" - (a, b, c) = self._out_literals() - return a, b, c - """ - - c = get_contract_with_gas_estimation(code) - - assert c.out_literals() == c.test() == [1, None, b"random"] - - -def test_return_tuple_assign_storage(get_contract_with_gas_estimation): - code = """ -a: int128 -b: address -c: Bytes[20] -d: Bytes[20] - -@internal -def _out_literals() -> (int128, Bytes[20], address, Bytes[20]): - return 1, b"testtesttest", 0x0000000000000000000000000000000000000023, b"random" - -@external -def out_literals() -> (int128, Bytes[20], address, Bytes[20]): - return self._out_literals() - -@external -def test1() -> (int128, Bytes[20], address, Bytes[20]): - self.a, self.c, self.b, self.d = self._out_literals() - return self.a, self.c, self.b, self.d - -@external -def test2() -> (int128, address): - x: int128 = 0 - x, self.c, self.b, self.d = self._out_literals() - return x, self.b - -@external -def test3() -> (address, int128): - x: address = empty(address) - self.a, self.c, x, self.d = self._out_literals() - return x, self.a - """ - - c = get_contract_with_gas_estimation(code) - - addr = "0x" + "00" * 19 + "23" - assert c.out_literals() == [1, b"testtesttest", addr, b"random"] - assert c.out_literals() == c.test1() - assert c.test2() == [1, c.out_literals()[2]] - assert c.test3() == [c.out_literals()[2], 1] - - -def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation): - code = """ -@external -def getTimeAndBalance() -> (bool, address): - return block.timestamp, self.balance - """ - with pytest.raises(TypeMismatch): - compile_code(code) diff --git a/tests/functional/codegen/calling_convention/test_tuple_return.py b/tests/functional/codegen/calling_convention/test_tuple_return.py deleted file mode 100644 index 670076cc24..0000000000 --- a/tests/functional/codegen/calling_convention/test_tuple_return.py +++ /dev/null @@ -1,45 +0,0 @@ -import pytest - - -@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) -def test_string_inside_tuple(get_contract, string): - code = f""" -@external -def test_return() -> (String[6], uint256): - return "{string}", 42 - """ - c1 = get_contract(code) - - code = """ -interface jsonabi: - def test_return() -> (String[6], uint256): view - -@external -def test_values(a: address) -> (String[6], uint256): - return jsonabi(a).test_return() - """ - - c2 = get_contract(code) - assert c2.test_values(c1.address) == [string, 42] - - -@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) -def test_bytes_inside_tuple(get_contract, string): - code = f""" -@external -def test_return() -> (Bytes[6], uint256): - return b"{string}", 42 - """ - c1 = get_contract(code) - - code = """ -interface jsonabi: - def test_return() -> (Bytes[6], uint256): view - -@external -def test_values(a: address) -> (Bytes[6], uint256): - return jsonabi(a).test_return() - """ - - c2 = get_contract(code) - assert c2.test_values(c1.address) == [bytes(string, "utf-8"), 42] diff --git a/tests/functional/codegen/calling_convention/test_struct_return.py b/tests/functional/codegen/types/test_struct.py similarity index 52% rename from tests/functional/codegen/calling_convention/test_struct_return.py rename to tests/functional/codegen/types/test_struct.py index d34d5128b6..0a6132200d 100644 --- a/tests/functional/codegen/calling_convention/test_struct_return.py +++ b/tests/functional/codegen/types/test_struct.py @@ -1,6 +1,3 @@ -import pytest - - def test_nested_struct(get_contract): code = """ struct Animal: @@ -51,61 +48,3 @@ def modify_nested_single_struct(_human: Human) -> Human: c = get_contract(code) assert c.modify_nested_single_struct({"animal": {"fur": "wool"}}) == (("wool is great",),) - - -@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) -def test_string_inside_struct(get_contract, string): - code = f""" -struct Person: - name: String[6] - age: uint256 - -@external -def test_return() -> Person: - return Person({{ name:"{string}", age:42 }}) - """ - c1 = get_contract(code) - - code = """ -struct Person: - name: String[6] - age: uint256 - -interface jsonabi: - def test_return() -> Person: view - -@external -def test_values(a: address) -> Person: - return jsonabi(a).test_return() - """ - - c2 = get_contract(code) - assert c2.test_values(c1.address) == (string, 42) - - -@pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) -def test_string_inside_single_struct(get_contract, string): - code = f""" -struct Person: - name: String[6] - -@external -def test_return() -> Person: - return Person({{ name:"{string}"}}) - """ - c1 = get_contract(code) - - code = """ -struct Person: - name: String[6] - -interface jsonabi: - def test_return() -> Person: view - -@external -def test_values(a: address) -> Person: - return jsonabi(a).test_return() - """ - - c2 = get_contract(code) - assert c2.test_values(c1.address) == (string,) From 4177314808e43d4f92ec8d44998f733d1261a903 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 19 Feb 2024 15:53:23 -0800 Subject: [PATCH 194/201] feat: more frontend optimizations (#3785) - optimize `VyperNode.get_descendants()` and `get_children()` - get rid of `sort_nodes()`, we can guarantee ordering the old fashioned way (topsort) - optimize `VyperNode.__hash__()` and `VyperNode.__init__()` - optimize `IntegerT.compare_type()` optimizes front-end compilation time by another 25% --- vyper/ast/nodes.py | 83 +++++++++++++++-------------- vyper/semantics/analysis/utils.py | 1 + vyper/semantics/types/primitives.py | 16 ++++-- 3 files changed, 56 insertions(+), 44 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 0ebe18ab5d..3e15a28512 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -34,6 +34,7 @@ "node_id", "_metadata", "_original_node", + "_cache_descendants", ) NODE_SRC_ATTRIBUTES = ( "col_offset", @@ -211,15 +212,17 @@ def _node_filter(node, filters): return True -def _sort_nodes(node_iterable): - # sorting function for VyperNode.get_children +def _apply_filters(node_iter, node_type, filters, reverse): + ret = node_iter + if node_type is not None: + ret = (i for i in ret if isinstance(i, node_type)) + if filters is not None: + ret = (i for i in ret if _node_filter(i, filters)) - def sortkey(key): - return float("inf") if key is None else key - - return sorted( - node_iterable, key=lambda k: (sortkey(k.lineno), sortkey(k.col_offset), k.node_id) - ) + ret = list(ret) + if reverse: + ret.reverse() + return ret def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None: @@ -257,10 +260,13 @@ class VyperNode: """ __slots__ = NODE_BASE_ATTRIBUTES + NODE_SRC_ATTRIBUTES + + _public_slots = [i for i in __slots__ if not i.startswith("_")] _only_empty_fields: tuple = () _translated_fields: dict = {} def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): + # this function is performance-sensitive """ AST node initializer method. @@ -275,21 +281,19 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): Dictionary of fields to be included within the node. """ self.set_parent(parent) - self._children: set = set() + self._children: list = [] self._metadata: NodeMetadata = NodeMetadata() self._original_node = None + self._cache_descendants = None for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset - value = kwargs.get(field_name) - if kwargs.get(field_name) is None: + value = kwargs.pop(field_name, None) + if value is None: value = getattr(parent, field_name, None) setattr(self, field_name, value) for field_name, value in kwargs.items(): - if field_name in NODE_SRC_ATTRIBUTES: - continue - if field_name in self._translated_fields: field_name = self._translated_fields[field_name] @@ -309,7 +313,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): # add to children of parent last to ensure an accurate hash is generated if parent is not None: - parent._children.add(self) + parent._children.append(self) # set parent, can be useful when inserting copied nodes into the AST def set_parent(self, parent: "VyperNode"): @@ -338,7 +342,7 @@ def from_node(cls, node: "VyperNode", **kwargs) -> "VyperNode": ------- Vyper node instance """ - ast_struct = {i: getattr(node, i) for i in VyperNode.__slots__ if not i.startswith("_")} + ast_struct = {i: getattr(node, i) for i in VyperNode._public_slots} ast_struct.update(ast_type=cls.__name__, **kwargs) return cls(**ast_struct) @@ -355,10 +359,11 @@ def get_fields(cls) -> set: return set(i for i in slot_fields if not i.startswith("_")) def __hash__(self): - values = [getattr(self, i, None) for i in VyperNode.__slots__ if not i.startswith("_")] + values = [getattr(self, i, None) for i in VyperNode._public_slots] return hash(tuple(values)) def __deepcopy__(self, memo): + # default implementation of deepcopy is a hotspot return pickle.loads(pickle.dumps(self)) def __eq__(self, other): @@ -537,14 +542,7 @@ def get_children( list Child nodes matching the filter conditions. """ - children = _sort_nodes(self._children) - if node_type is not None: - children = [i for i in children if isinstance(i, node_type)] - if reverse: - children.reverse() - if filters is None: - return children - return [i for i in children if _node_filter(i, filters)] + return _apply_filters(iter(self._children), node_type, filters, reverse) def get_descendants( self, @@ -553,6 +551,7 @@ def get_descendants( include_self: bool = False, reverse: bool = False, ) -> list: + # this function is performance-sensitive """ Return a list of descendant nodes of this node which match the given filter(s). @@ -589,19 +588,25 @@ def get_descendants( list Descendant nodes matching the filter conditions. """ - children = self.get_children(node_type, filters) - for node in self.get_children(): - children.extend(node.get_descendants(node_type, filters)) - if ( - include_self - and (not node_type or isinstance(self, node_type)) - and _node_filter(self, filters) - ): - children.append(self) - result = _sort_nodes(children) - if reverse: - result.reverse() - return result + ret = self._get_descendants(include_self) + return _apply_filters(ret, node_type, filters, reverse) + + def _get_descendants(self, include_self=True): + # get descendants in topsort order + if self._cache_descendants is None: + ret = [self] + for node in self._children: + ret.extend(node._get_descendants()) + + self._cache_descendants = ret + + ret = iter(self._cache_descendants) + + if not include_self: + s = next(ret) # pop + assert s is self + + return ret def get(self, field_str: str) -> Any: """ @@ -669,7 +674,7 @@ def add_to_body(self, node: VyperNode) -> None: self.body.append(node) node._depth = self._depth + 1 node._parent = self - self._children.add(node) + self._children.append(node) def remove_from_body(self, node: VyperNode) -> None: """ diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index abea600d88..21ca7a8d3f 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -480,6 +480,7 @@ def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo: def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List: + # this function is a performance hotspot """ Return a list of common possible types between one or more nodes. diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index d383f72ab2..d11a9595a3 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -251,11 +251,17 @@ def abi_type(self) -> ABIType: return ABI_GIntM(self.bits, self.is_signed) def compare_type(self, other: VyperType) -> bool: - if not super().compare_type(other): - return False - assert isinstance(other, IntegerT) # mypy - - return self.is_signed == other.is_signed and self.bits == other.bits + # this function is performance sensitive + # originally: + # if not super().compare_type(other): + # return False + # return self.is_signed == other.is_signed and self.bits == other.bits + + return ( # noqa: E721 + self.__class__ == other.__class__ + and self.is_signed == other.is_signed # type: ignore + and self.bits == other.bits # type: ignore + ) # helper function for readability. From 0752760807aa82aa4706e8f8df8c14fb08f8a678 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 20 Feb 2024 08:16:40 +0800 Subject: [PATCH 195/201] feat[lang]: use keyword arguments for struct instantiation (#3777) This commit changes struct instantiation from taking a single dict argument to taking keyword arguments. For backwards compatibility (and to ease users into v0.4.0), the old syntax is still supported, but a warning will be emitted. --------- Co-authored-by: Charles Cooper --- examples/auctions/blind_auction.vy | 8 +- examples/voting/ballot.vy | 8 +- tests/functional/builtins/codegen/test_abi.py | 4 +- .../builtins/codegen/test_abi_decode.py | 24 ++-- .../builtins/codegen/test_abi_encode.py | 26 ++--- .../functional/builtins/codegen/test_empty.py | 37 +++--- .../test_default_parameters.py | 14 +-- .../test_external_contract_calls.py | 40 +++---- .../codegen/calling_convention/test_return.py | 108 +++++++++--------- .../test_self_call_struct.py | 4 +- .../features/decorators/test_private.py | 6 +- .../features/iteration/test_for_in_list.py | 6 +- .../codegen/features/test_assignment.py | 4 +- .../codegen/features/test_bytes_map_keys.py | 8 +- .../codegen/features/test_immutable.py | 14 +-- .../codegen/features/test_internal_call.py | 10 +- .../codegen/features/test_logging.py | 2 +- .../codegen/features/test_packing.py | 2 +- .../codegen/integration/test_crowdfund.py | 2 +- .../modules/test_stateless_functions.py | 4 +- .../codegen/storage_variables/test_setters.py | 32 +++--- tests/functional/codegen/types/test_bytes.py | 8 +- .../codegen/types/test_dynamic_array.py | 47 ++++---- tests/functional/codegen/types/test_flag.py | 2 +- .../exceptions/test_syntax_exception.py | 5 + .../test_variable_declaration_exception.py | 5 - tests/functional/syntax/test_ann_assign.py | 8 +- tests/functional/syntax/test_block.py | 6 +- tests/functional/syntax/test_constants.py | 12 +- tests/functional/syntax/test_flag.py | 8 +- tests/functional/syntax/test_immutables.py | 2 +- tests/functional/syntax/test_invalids.py | 2 +- tests/functional/syntax/test_no_none.py | 6 +- tests/functional/syntax/test_structs.py | 105 +++++++++-------- .../cli/vyper_compile/test_compile_files.py | 4 +- .../unit/semantics/analysis/test_for_loop.py | 2 +- tests/unit/semantics/test_storage_slots.py | 14 +-- vyper/ast/parse.py | 24 ++++ vyper/codegen/expr.py | 22 ++-- vyper/semantics/analysis/local.py | 4 +- vyper/semantics/types/user.py | 36 +++--- 41 files changed, 357 insertions(+), 328 deletions(-) diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 966565138f..143206ccb4 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -69,10 +69,10 @@ def bid(_blindedBid: bytes32): assert numBids < MAX_BIDS # Add bid to mapping of all bids - self.bids[msg.sender][numBids] = Bid({ - blindedBid: _blindedBid, - deposit: msg.value - }) + self.bids[msg.sender][numBids] = Bid( + blindedBid=_blindedBid, + deposit=msg.value + ) self.bidCounts[msg.sender] += 1 diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index daaf712e0f..9016ae38c6 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -57,10 +57,10 @@ def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 for i: int128 in range(2): - self.proposals[i] = Proposal({ - name: _proposalNames[i], - voteCount: 0 - }) + self.proposals[i] = Proposal( + name=_proposalNames[i], + voteCount=0 + ) self.int128Proposals += 1 # Give a `voter` the right to vote on this ballot. diff --git a/tests/functional/builtins/codegen/test_abi.py b/tests/functional/builtins/codegen/test_abi.py index 335f728a37..403ad6fc9a 100644 --- a/tests/functional/builtins/codegen/test_abi.py +++ b/tests/functional/builtins/codegen/test_abi.py @@ -112,8 +112,8 @@ def test_nested_struct(type, abi_type): @external def getStructList() -> {type}: return [ - NestedStruct({{t: MyStruct({{a: msg.sender, b: block.prevhash}}), foo: 1}}), - NestedStruct({{t: MyStruct({{a: msg.sender, b: block.prevhash}}), foo: 2}}) + NestedStruct(t=MyStruct(a=msg.sender, b=block.prevhash), foo=1), + NestedStruct(t=MyStruct(a=msg.sender, b=block.prevhash), foo=2) ] """ diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 96cbbe4c2d..dbbf195373 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -35,18 +35,18 @@ def abi_decode(x: Bytes[160]) -> (address, int128, bool, decimal, bytes32): @external def abi_decode_struct(x: Bytes[544]) -> Human: - human: Human = Human({ - name: "", - pet: Animal({ - name: "", - address_: empty(address), - id_: 0, - is_furry: False, - price: 0.0, - data: [0, 0, 0], - metadata: 0x0000000000000000000000000000000000000000000000000000000000000000 - }) - }) + human: Human = Human( + name = "", + pet = Animal( + name = "", + address_ = empty(address), + id_ = 0, + is_furry = False, + price = 0.0, + data = [0, 0, 0], + metadata = 0x0000000000000000000000000000000000000000000000000000000000000000 + ) + ) human = _abi_decode(x, Human) return human """ diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index 8709e31470..f818b04359 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -34,18 +34,18 @@ def abi_encode( ensure_tuple: bool, include_method_id: bool ) -> Bytes[548]: - human: Human = Human({ - name: name, - pet: Animal({ - name: pet_name, - address_: pet_address, - id_: pet_id, - is_furry: pet_is_furry, - price: pet_price, - data: pet_data, - metadata: pet_metadata - }), - }) + human: Human = Human( + name = name, + pet = Animal( + name = pet_name, + address_ = pet_address, + id_ = pet_id, + is_furry = pet_is_furry, + price = pet_price, + data = pet_data, + metadata = pet_metadata + ), + ) if ensure_tuple: if not include_method_id: return _abi_encode(human) # default ensure_tuple=True @@ -128,7 +128,7 @@ def test_abi_encode_length_failing(get_contract, assert_compile_failed, type, va @internal def foo(): - x: WrappedBytes = WrappedBytes({{bs: {value}}}) + x: WrappedBytes = WrappedBytes(bs={value}) y: {type}[96] = _abi_encode(x, ensure_tuple=True) # should be Bytes[128] """ diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index 896c845da2..a6c9c6441b 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -351,22 +351,22 @@ def test_empty_struct(get_contract_with_gas_estimation): @external def foo(): - self.foobar = FOOBAR({ - a: 1, - b: 2, - c: True, - d: 3.0, - e: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, - f: msg.sender - }) - bar: FOOBAR = FOOBAR({ - a: 1, - b: 2, - c: True, - d: 3.0, - e: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, - f: msg.sender - }) + self.foobar = FOOBAR( + a=1, + b=2, + c=True, + d=3.0, + e=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + f=msg.sender + ) + bar: FOOBAR = FOOBAR( + a=1, + b=2, + c=True, + d=3.0, + e=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + f=msg.sender + ) self.foobar = empty(FOOBAR) bar = empty(FOOBAR) @@ -575,10 +575,7 @@ def test_map_clear_struct(get_contract_with_gas_estimation): @external def set(): - self.structmap[123] = X({ - a: 333, - b: 444 - }) + self.structmap[123] = X(a=333, b=444) @external def get() -> (int128, int128): diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index 4153c7188e..e4db72ab8d 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -309,7 +309,7 @@ def foo(a: uint256 = 2**8): pass b: uint256 @external -def foo(bar: Bar = Bar({a: msg.sender, b: 1})): pass +def foo(bar: Bar = Bar(a=msg.sender, b=1)): pass """, """ struct Baz: @@ -321,7 +321,7 @@ def foo(bar: Bar = Bar({a: msg.sender, b: 1})): pass b: Baz @external -def foo(bar: Bar = Bar({a: msg.sender, b: Baz({c: block.coinbase, d: -10})})): pass +def foo(bar: Bar = Bar(a=msg.sender, b=Baz(c=block.coinbase, d=-10))): pass """, """ A: public(address) @@ -341,7 +341,7 @@ def foo(a: int112 = min_value(int112)): struct X: x: int128 y: address -BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +BAR: constant(X) = X(x=1, y=0x0000000000000000000000000000000000012345) @external def out_literals(a: int128 = BAR.x + 1) -> X: return BAR @@ -353,8 +353,8 @@ def out_literals(a: int128 = BAR.x + 1) -> X: struct Y: x: X y: uint256 -BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) -FOO: constant(Y) = Y({x: BAR, y: 256}) +BAR: constant(X) = X(x=1, y=0x0000000000000000000000000000000000012345) +FOO: constant(Y) = Y(x=BAR, y=256) @external def out_literals(a: int128 = FOO.x.x + 1) -> Y: return FOO @@ -363,7 +363,7 @@ def out_literals(a: int128 = FOO.x.x + 1) -> Y: struct Bar: a: bool -BAR: constant(Bar) = Bar({a: True}) +BAR: constant(Bar) = Bar(a=True) @external def foo(x: bool = True and not BAR.a): @@ -373,7 +373,7 @@ def foo(x: bool = True and not BAR.a): struct Bar: a: uint256 -BAR: constant(Bar) = Bar({ a: 123 }) +BAR: constant(Bar) = Bar(a=123) @external def foo(x: bool = BAR.a + 1 > 456): diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 8b3f30b5a5..bc9ac94bd5 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -1586,7 +1586,7 @@ def test_struct_return_external_contract_call_1(get_contract_with_gas_estimation y: address @external def out_literals() -> X: - return X({x: 1, y: 0x0000000000000000000000000000000000012345}) + return X(x=1, y=0x0000000000000000000000000000000000012345) """ contract_2 = """ @@ -1618,7 +1618,7 @@ def test_struct_return_external_contract_call_2(get_contract_with_gas_estimation z: Bytes[{ln}] @external def get_struct_x() -> X: - return X({{x: {i}, y: "{s}", z: b"{s}"}}) + return X(x={i}, y="{s}", z=b"{s}") """ contract_2 = f""" @@ -1648,7 +1648,7 @@ def test_struct_return_external_contract_call_3(get_contract_with_gas_estimation x: int128 @external def out_literals() -> X: - return X({x: 1}) + return X(x=1) """ contract_2 = """ @@ -1676,7 +1676,7 @@ def test_constant_struct_return_external_contract_call_1(get_contract_with_gas_e x: int128 y: address -BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +BAR: constant(X) = X(x=1, y=0x0000000000000000000000000000000000012345) @external def out_literals() -> X: @@ -1713,7 +1713,7 @@ def test_constant_struct_return_external_contract_call_2( y: String[{ln}] z: Bytes[{ln}] -BAR: constant(X) = X({{x: {i}, y: "{s}", z: b"{s}"}}) +BAR: constant(X) = X(x={i}, y="{s}", z=b"{s}") @external def get_struct_x() -> X: @@ -1746,7 +1746,7 @@ def test_constant_struct_return_external_contract_call_3(get_contract_with_gas_e struct X: x: int128 -BAR: constant(X) = X({x: 1}) +BAR: constant(X) = X(x=1) @external def out_literals() -> X: @@ -1778,7 +1778,7 @@ def test_constant_struct_member_return_external_contract_call_1(get_contract_wit x: int128 y: address -BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345}) +BAR: constant(X) = X(x=1, y=0x0000000000000000000000000000000000012345) @external def get_y() -> address: @@ -1811,7 +1811,7 @@ def test_constant_struct_member_return_external_contract_call_2( y: String[{ln}] z: Bytes[{ln}] -BAR: constant(X) = X({{x: {i}, y: "{s}", z: b"{s}"}}) +BAR: constant(X) = X(x={i}, y="{s}", z=b"{s}") @external def get_y() -> String[{ln}]: @@ -1840,7 +1840,7 @@ def test_constant_struct_member_return_external_contract_call_3(get_contract_wit struct X: x: int128 -BAR: constant(X) = X({x: 1}) +BAR: constant(X) = X(x=1) @external def get_x() -> int128: @@ -1874,7 +1874,7 @@ def test_constant_nested_struct_return_external_contract_call_1(get_contract_wit a: X b: uint256 -BAR: constant(A) = A({a: X({x: 1, y: 0x0000000000000000000000000000000000012345}), b: 777}) +BAR: constant(A) = A(a=X(x=1, y=0x0000000000000000000000000000000000012345), b=777) @external def out_literals() -> A: @@ -1919,7 +1919,7 @@ def test_constant_nested_struct_return_external_contract_call_2( a: X b: uint256 -BAR: constant(A) = A({{a: X({{x: {i}, y: "{s}", z: b"{s}"}}), b: 777}}) +BAR: constant(A) = A(a=X(x={i}, y="{s}", z=b"{s}"), b=777) @external def get_struct_a() -> A: @@ -1966,7 +1966,7 @@ def test_constant_nested_struct_return_external_contract_call_3(get_contract_wit c: A d: bool -BAR: constant(C) = C({c: A({a: X({x: 1, y: -1}), b: 777}), d: True}) +BAR: constant(C) = C(c=A(a=X(x=1, y=-1), b=777), d=True) @external def out_literals() -> C: @@ -2013,7 +2013,7 @@ def test_constant_nested_struct_member_return_external_contract_call_1( a: X b: uint256 -BAR: constant(A) = A({a: X({x: 1, y: 0x0000000000000000000000000000000000012345}), b: 777}) +BAR: constant(A) = A(a=X(x=1, y=0x0000000000000000000000000000000000012345), b=777) @external def get_y() -> address: @@ -2051,7 +2051,7 @@ def test_constant_nested_struct_member_return_external_contract_call_2( b: uint256 c: bool -BAR: constant(A) = A({{a: X({{x: {i}, y: "{s}", z: b"{s}"}}), b: 777, c: True}}) +BAR: constant(A) = A(a=X(x={i}, y="{s}", z=b"{s}"), b=777, c=True) @external def get_y() -> String[{ln}]: @@ -2091,7 +2091,7 @@ def test_constant_nested_struct_member_return_external_contract_call_3( c: A d: bool -BAR: constant(C) = C({c: A({a: X({x: 1, y: -1}), b: 777}), d: True}) +BAR: constant(C) = C(c=A(a=X(x=1, y=-1), b=777), d=True) @external def get_y() -> int128: @@ -2148,7 +2148,7 @@ def foo(x: X) -> Bytes[6]: nonpayable @external def bar(addr: address) -> Bytes[6]: - _X: X = X({x: 1, y: b"hello"}) + _X: X = X(x=1, y=b"hello") return Foo(addr).foo(_X) """ @@ -2180,7 +2180,7 @@ def foo(x: X) -> String[6]: nonpayable @external def bar(addr: address) -> String[6]: - _X: X = X({x: 1, y: "hello"}) + _X: X = X(x=1, y="hello") return Foo(addr).foo(_X) """ @@ -2208,7 +2208,7 @@ def foo(b: Bytes[6]) -> Bytes[6]: nonpayable @external def bar(addr: address) -> Bytes[6]: - _X: X = X({x: 1, y: b"hello"}) + _X: X = X(x=1, y=b"hello") return Foo(addr).foo(_X.y) """ @@ -2236,7 +2236,7 @@ def foo(b: String[6]) -> String[6]: nonpayable @external def bar(addr: address) -> String[6]: - _X: X = X({x: 1, y: "hello"}) + _X: X = X(x=1, y="hello") return Foo(addr).foo(_X.y) """ @@ -2433,7 +2433,7 @@ def return_64_bytes(): def return_64_bytes() -> BoolPair: nonpayable @external def bar(foo: Foo): - t: BoolPair = foo.return_64_bytes(default_return_value=BoolPair({x: True, y:True})) + t: BoolPair = foo.return_64_bytes(default_return_value=BoolPair(x=True, y=True)) assert t.x and t.y """ bad_1 = get_contract(bad_code_1) diff --git a/tests/functional/codegen/calling_convention/test_return.py b/tests/functional/codegen/calling_convention/test_return.py index 2db3689da3..ebc600956e 100644 --- a/tests/functional/codegen/calling_convention/test_return.py +++ b/tests/functional/codegen/calling_convention/test_return.py @@ -250,7 +250,7 @@ def test_struct_return_abi(get_contract_with_gas_estimation): @external def test() -> Voter: - a: Voter = Voter({weight: 123, voted: True}) + a: Voter = Voter(weight=123, voted=True) return a """ @@ -271,7 +271,7 @@ def test_single_struct_return_abi(get_contract_with_gas_estimation): @external def test() -> Voter: - a: Voter = Voter({voted: True}) + a: Voter = Voter(voted=True) return a """ @@ -297,14 +297,14 @@ def test_struct_return(get_contract_with_gas_estimation): @internal def priv1() -> Foo: - return Foo({x: 1, y: 2}) + return Foo(x= 1, y=2) @external def pub1() -> Foo: return self.priv1() @internal def priv2() -> Foo: - foo: Foo = Foo({x: 0, y: 0}) + foo: Foo = Foo(x= 0, y=0) foo.x = 3 foo.y = 4 return foo @@ -314,12 +314,12 @@ def pub2() -> Foo: @external def pub3() -> Foo: - self._foo = Foo({x: 5, y: 6}) + self._foo = Foo(x= 5, y=6) return self._foo @external def pub4() -> Foo: - self._foos[0] = Foo({x: 7, y: 8}) + self._foos[0] = Foo(x= 7, y=8) return self._foos[0] @internal @@ -330,7 +330,7 @@ def pub5(foo: Foo) -> Foo: return self.return_arg(foo) @external def pub6() -> Foo: - foo: Foo = Foo({x: 123, y: 456}) + foo: Foo = Foo(x= 123, y=456) return self.return_arg(foo) """ foo = (123, 456) @@ -355,14 +355,14 @@ def test_single_struct_return(get_contract_with_gas_estimation): @internal def priv1() -> Foo: - return Foo({x: 1}) + return Foo(x=1) @external def pub1() -> Foo: return self.priv1() @internal def priv2() -> Foo: - foo: Foo = Foo({x: 0}) + foo: Foo = Foo(x=0) foo.x = 3 return foo @external @@ -371,12 +371,12 @@ def pub2() -> Foo: @external def pub3() -> Foo: - self._foo = Foo({x: 5}) + self._foo = Foo(x=5) return self._foo @external def pub4() -> Foo: - self._foos[0] = Foo({x: 7}) + self._foos[0] = Foo(x=7) return self._foos[0] @internal @@ -387,7 +387,7 @@ def pub5(foo: Foo) -> Foo: return self.return_arg(foo) @external def pub6() -> Foo: - foo: Foo = Foo({x: 123}) + foo: Foo = Foo(x=123) return self.return_arg(foo) """ foo = (123,) @@ -418,7 +418,7 @@ def _foo() -> uint256: @external def foo() -> Foo: - return Foo({a:1, b:2, c:self._foo(), d:4, e:5}) + return Foo(a=1, b=2, c=self._foo(), d=4, e=5) """ c = get_contract(code) @@ -438,7 +438,7 @@ def _foo() -> uint256: @external def foo() -> Foo: - return Foo({a:self._foo()}) + return Foo(a=self._foo()) """ c = get_contract(code) @@ -457,7 +457,7 @@ def test_call_in_call(get_contract): @internal def _foo(a: uint256, b: uint256, c: uint256) -> Foo: - return Foo({a:1, b:a, c:b, d:c, e:5}) + return Foo(a=1, b=a, c=b, d=c, e=5) @internal def _foo2() -> uint256: @@ -481,7 +481,7 @@ def test_call_in_call_single_struct(get_contract): @internal def _foo(a: uint256) -> Foo: - return Foo({a:a}) + return Foo(a=a) @internal def _foo2() -> uint256: @@ -512,7 +512,7 @@ def test_nested_calls_in_struct_return(get_contract): @internal def _bar(a: uint256, b: uint256, c: uint256) -> Bar: - return Bar({a:415, b:3}) + return Bar(a=415, b=3) @internal def _foo2(a: uint256) -> uint256: @@ -531,13 +531,13 @@ def _foo4() -> uint256: @external def foo() -> Foo: - return Foo({ - a:1, - b:2, - c:self._bar(6, 7, self._foo2(self._foo3(9, 11))).b, - d:self._foo4(), - e:5 - }) + return Foo( + a=1, + b=2, + c=self._bar(6, 7, self._foo2(self._foo3(9, 11))).b, + d=self._foo4(), + e=5 + ) """ c = get_contract(code) @@ -555,7 +555,7 @@ def test_nested_calls_in_single_struct_return(get_contract): @internal def _bar(a: uint256, b: uint256, c: uint256) -> Bar: - return Bar({a:415, b:3}) + return Bar(a=415, b=3) @internal def _foo2(a: uint256) -> uint256: @@ -574,9 +574,9 @@ def _foo4() -> uint256: @external def foo() -> Foo: - return Foo({ - a:self._bar(6, self._foo4(), self._foo2(self._foo3(9, 11))).b, - }) + return Foo( + a=self._bar(6, self._foo4(), self._foo2(self._foo3(9, 11))).b, + ) """ c = get_contract(code) @@ -592,7 +592,7 @@ def test_external_call_in_return_struct(get_contract): @view @external def bar() -> Bar: - return Bar({a:3, b:4}) + return Bar(a=3, b=4) """ code2 = """ @@ -610,13 +610,13 @@ def bar() -> Bar: view @external def foo(addr: address) -> Foo: - return Foo({ - a:1, - b:2, - c:IBar(addr).bar().a, - d:4, - e:5 - }) + return Foo( + a=1, + b=2, + c=IBar(addr).bar().a, + d=4, + e=5 + ) """ c = get_contract(code) @@ -632,7 +632,7 @@ def test_external_call_in_return_single_struct(get_contract): @view @external def bar() -> Bar: - return Bar({a:3}) + return Bar(a=3) """ code2 = """ @@ -645,9 +645,9 @@ def bar() -> Bar: view @external def foo(addr: address) -> Foo: - return Foo({ - a:IBar(addr).bar().a - }) + return Foo( + a=IBar(addr).bar().a + ) """ c = get_contract(code) @@ -665,7 +665,7 @@ def test_nested_external_call_in_return_struct(get_contract): @view @external def bar() -> Bar: - return Bar({a:3, b:4}) + return Bar(a=3, b=4) @view @external @@ -690,13 +690,13 @@ def baz(a: uint256) -> uint256: view @external def foo(addr: address) -> Foo: - return Foo({ - a:1, - b:2, - c:IBar(addr).bar().a, - d:4, - e:IBar(addr).baz(IBar(addr).bar().b) - }) + return Foo( + a=1, + b=2, + c=IBar(addr).bar().a, + d=4, + e=IBar(addr).baz(IBar(addr).bar().b) + ) """ c = get_contract(code) @@ -713,7 +713,7 @@ def test_nested_external_call_in_return_single_struct(get_contract): @view @external def bar() -> Bar: - return Bar({a:3}) + return Bar(a=3) @view @external @@ -733,9 +733,9 @@ def baz(a: uint256) -> uint256: view @external def foo(addr: address) -> Foo: - return Foo({ - a:IBar(addr).baz(IBar(addr).bar().a) - }) + return Foo( + a=IBar(addr).baz(IBar(addr).bar().a) + ) """ c = get_contract(code) @@ -753,7 +753,7 @@ def test_string_inside_struct(get_contract, string): @external def test_return() -> Person: - return Person({{ name:"{string}", age:42 }}) + return Person(name="{string}", age=42) """ c1 = get_contract(code) @@ -782,7 +782,7 @@ def test_string_inside_single_struct(get_contract, string): @external def test_return() -> Person: - return Person({{ name:"{string}"}}) + return Person(name="{string}") """ c1 = get_contract(code) diff --git a/tests/functional/codegen/calling_convention/test_self_call_struct.py b/tests/functional/codegen/calling_convention/test_self_call_struct.py index 98aeba6915..f3ec96f1c0 100644 --- a/tests/functional/codegen/calling_convention/test_self_call_struct.py +++ b/tests/functional/codegen/calling_convention/test_self_call_struct.py @@ -10,7 +10,7 @@ def test_call_to_self_struct(w3, get_contract): @internal @view def get_my_struct(_e1: decimal, _e2: uint256) -> MyStruct: - return MyStruct({e1: _e1, e2: _e2}) + return MyStruct(e1=_e1, e2=_e2) @external @view @@ -42,7 +42,7 @@ def test_call_to_self_struct_2(get_contract): @internal @view def get_my_struct(_e1: decimal) -> MyStruct: - return MyStruct({e1: _e1}) + return MyStruct(e1=_e1) @external @view diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 193112f02b..15243ae3f3 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -594,7 +594,7 @@ def foo(a: int128) -> (int128, int128): @internal def _foo(_one: uint8) ->A: - return A({one: _one}) + return A(one=_one) @external def foo() -> A: @@ -611,7 +611,7 @@ def foo() -> A: @internal def _foo(_many: uint256[4], _one: uint256) -> A: - return A({many: _many, one: _one}) + return A(many=_many, one=_one) @external def foo() -> A: @@ -628,7 +628,7 @@ def foo() -> A: @internal def _foo(_many: uint256[4], _one: uint256) -> A: - return A({many: _many, one: _one}) + return A(many=_many, one=_one) @external def foo() -> (uint256[4], uint256): diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index e1bd8f313d..5d9462e7af 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -52,8 +52,8 @@ def data() -> int128: @external def data() -> int128: sss: DynArray[DynArray[S, 10], 10] = [ - [S({x:1, y:2})], - [S({x:3, y:4}), S({x:5, y:6}), S({x:7, y:8}), S({x:9, y:10})] + [S(x=1, y=2)], + [S(x=3, y=4), S(x=5, y=6), S(x=7, y=8), S(x=9, y=10)] ] ret: int128 = 0 for ss: DynArray[S, 10] in sss: @@ -133,7 +133,7 @@ def data() -> int128: @external def data() -> int128: ret: int128 = 0 - for ss: S[1] in [[S({x:1, y:2})]]: + for ss: S[1] in [[S(x=1, y=2)]]: for s: S in ss: ret += s.x + s.y return ret""", diff --git a/tests/functional/codegen/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py index aebb13eefa..a276eb7ea9 100644 --- a/tests/functional/codegen/features/test_assignment.py +++ b/tests/functional/codegen/features/test_assignment.py @@ -78,7 +78,7 @@ def test_internal_assign_struct(get_contract_with_gas_estimation): @internal def foo(x: Foo) -> Foo: - x = Foo({a: 789, b: [Bar.BAZ, Bar.BAK, Bar.BAD], c: \"conda\"}) + x = Foo(a=789, b=[Bar.BAZ, Bar.BAK, Bar.BAD], c=\"conda\") return x @external @@ -437,7 +437,7 @@ def test_assign_rhs_lhs_overlap_struct(get_contract): @external def bug(p: Point) -> Point: t: Point = p - t = Point({x: t.y, y: t.x}) + t = Point(x=t.y, y=t.x) return t """ c = get_contract(code) diff --git a/tests/functional/codegen/features/test_bytes_map_keys.py b/tests/functional/codegen/features/test_bytes_map_keys.py index 22df767f02..c70ffb26ce 100644 --- a/tests/functional/codegen/features/test_bytes_map_keys.py +++ b/tests/functional/codegen/features/test_bytes_map_keys.py @@ -121,12 +121,12 @@ def __init__(): @external def get_one() -> int128: - b: Foo = Foo({one: b"hello", two: b"potato"}) + b: Foo = Foo(one=b"hello", two=b"potato") return self.a[b.one] @external def get_two() -> int128: - b: Foo = Foo({one: b"hello", two: b"potato"}) + b: Foo = Foo(one=b"hello", two=b"potato") return self.a[b.two] """ @@ -149,7 +149,7 @@ def test_struct_bytes_key_storage(get_contract): def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 - self.b = Foo({one: b"hello", two: b"potato"}) + self.b = Foo(one=b"hello", two=b"potato") @external def get_one() -> int128: @@ -218,7 +218,7 @@ def test_struct_bytes_hashmap_as_key_in_other_hashmap(get_contract): @deploy def __init__(): self.foo[b"hello"] = 31337 - self.bar[12] = Thing({name: b"hello"}) + self.bar[12] = Thing(name=b"hello") @external def do_the_thing(_index: uint256) -> uint256: diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index d0bc47c238..49ff54b353 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -91,12 +91,12 @@ def test_struct_immutable(get_contract): @deploy def __init__(_a: uint256, _b: uint256, _c: address, _d: int256): - my_struct = MyStruct({ - a: _a, - b: _b, - c: _c, - d: _d - }) + my_struct = MyStruct( + a=_a, + b=_b, + c=_c, + d=_d + ) @view @external @@ -117,7 +117,7 @@ def test_complex_immutable_modifiable(get_contract): @deploy def __init__(a: uint256): - my_struct = MyStruct({a: a}) + my_struct = MyStruct(a=a) # struct members are modifiable after initialization my_struct.a += 1 diff --git a/tests/functional/codegen/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py index 422f53fdeb..8320922e8e 100644 --- a/tests/functional/codegen/features/test_internal_call.py +++ b/tests/functional/codegen/features/test_internal_call.py @@ -550,7 +550,7 @@ def test_struct_return_1(get_contract_with_gas_estimation, i, ln, s): @internal def get_struct_x() -> X: - return X({{x: {i}, y: "{s}", z: b"{s}"}}) + return X(x={i}, y="{s}", z=b"{s}") @external def test() -> (int128, String[{ln}], Bytes[{ln}]): @@ -575,7 +575,7 @@ def _foo(x: X) -> Bytes[6]: @external def bar() -> Bytes[6]: - _X: X = X({x: 1, y: b"hello"}) + _X: X = X(x=1, y=b"hello") return self._foo(_X) """ @@ -596,7 +596,7 @@ def _foo(x: X) -> String[6]: @external def bar() -> String[6]: - _X: X = X({x: 1, y: "hello"}) + _X: X = X(x=1, y="hello") return self._foo(_X) """ @@ -617,7 +617,7 @@ def _foo(s: Bytes[6]) -> Bytes[6]: @external def bar() -> Bytes[6]: - _X: X = X({x: 1, y: b"hello"}) + _X: X = X(x=1, y=b"hello") return self._foo(_X.y) """ @@ -638,7 +638,7 @@ def _foo(s: String[6]) -> String[6]: @external def bar() -> String[6]: - _X: X = X({x: 1, y: "hello"}) + _X: X = X(x=1, y="hello") return self._foo(_X.y) """ diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 8b80811d02..cf64d271a9 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -493,7 +493,7 @@ def test_event_logging_with_multiple_logs_topics_and_data( @external def foo(): log MyLog(1, b'bar') - log YourLog(self, MyStruct({x: 1, y: b'abc', z: SmallStruct({t: 'house', w: 13.5})})) + log YourLog(self, MyStruct(x=1, y=b'abc', z=SmallStruct(t='house', w=13.5))) """ c = get_contract_with_gas_estimation(loggy_code) diff --git a/tests/functional/codegen/features/test_packing.py b/tests/functional/codegen/features/test_packing.py index bb3ccecbd8..3a18b5e88b 100644 --- a/tests/functional/codegen/features/test_packing.py +++ b/tests/functional/codegen/features/test_packing.py @@ -30,7 +30,7 @@ def foo() -> int128: def fop() -> int128: _x: int128 = 0 _y: int128[5] = [0, 0, 0, 0, 0] - _z: Z = Z({foo: [0, 0, 0], bar: [Bar({a: 0, b: 0}), Bar({a: 0, b: 0})]}) + _z: Z = Z(foo=[0, 0, 0], bar=[Bar(a=0, b=0), Bar(a=0, b=0)]) _a: int128 = 0 _x = 1 _y[0] = 2 diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 1a8b3f7e9f..b07ce5dc22 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -121,7 +121,7 @@ def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): def participate(): assert block.timestamp < self.deadline nfi: int128 = self.nextFunderIndex - self.funders[nfi] = Funder({sender: msg.sender, value: msg.value}) + self.funders[nfi] = Funder(sender=msg.sender, value=msg.value) self.nextFunderIndex = nfi + 1 @external diff --git a/tests/functional/codegen/modules/test_stateless_functions.py b/tests/functional/codegen/modules/test_stateless_functions.py index 26c3f338fb..722b287d98 100644 --- a/tests/functional/codegen/modules/test_stateless_functions.py +++ b/tests/functional/codegen/modules/test_stateless_functions.py @@ -159,7 +159,7 @@ def test_library_structs(get_contract, make_input_bundle): @internal def foo() -> SomeStruct: - return SomeStruct({x: 1}) + return SomeStruct(x=1) """ contract_source = """ import library @@ -170,7 +170,7 @@ def bar(s: library.SomeStruct): @external def baz() -> library.SomeStruct: - return library.SomeStruct({x: 2}) + return library.SomeStruct(x=2) @external def qux() -> library.SomeStruct: diff --git a/tests/functional/codegen/storage_variables/test_setters.py b/tests/functional/codegen/storage_variables/test_setters.py index 119157977a..cf4138b939 100644 --- a/tests/functional/codegen/storage_variables/test_setters.py +++ b/tests/functional/codegen/storage_variables/test_setters.py @@ -91,16 +91,16 @@ def test_multi_setter_struct_test(get_contract_with_gas_estimation): @external def foo() -> int128: foo0: int128 = 1 - self.dog[0] = Dog({foo: foo0, bar: 2}) - self.dog[1] = Dog({foo: 3, bar: 4}) - self.dog[2] = Dog({foo: 5, bar: 6}) + self.dog[0] = Dog(foo=foo0, bar=2) + self.dog[1] = Dog(foo=3, bar=4) + self.dog[2] = Dog(foo=5, bar=6) return self.dog[0].foo + self.dog[0].bar * 10 + self.dog[1].foo * 100 + \ self.dog[1].bar * 1000 + self.dog[2].foo * 10000 + self.dog[2].bar * 100000 @external def fop() -> int128: - self.z = [Z({foo: [1, 2, 3], bar: [Bar({a: 4, b: 5}), Bar({a: 2, b: 3})]}), - Z({foo: [6, 7, 8], bar: [Bar({a: 9, b: 1}), Bar({a: 7, b: 8})]})] + self.z = [Z(foo=[1, 2, 3], bar=[Bar(a=4, b=5), Bar(a=2, b=3)]), + Z(foo=[6, 7, 8], bar=[Bar(a=9, b=1), Bar(a=7, b=8)])] return self.z[0].foo[0] + self.z[0].foo[1] * 10 + self.z[0].foo[2] * 100 + \ self.z[0].bar[0].a * 1000 + \ self.z[0].bar[0].b * 10000 + \ @@ -116,15 +116,15 @@ def fop() -> int128: @external def goo() -> int128: - god: Goo[3] = [Goo({foo: 1, bar: 2}), Goo({foo: 3, bar: 4}), Goo({foo: 5, bar: 6})] + god: Goo[3] = [Goo(foo=1, bar=2), Goo(foo=3, bar=4), Goo(foo=5, bar=6)] return god[0].foo + god[0].bar * 10 + god[1].foo * 100 + \ god[1].bar * 1000 + god[2].foo * 10000 + god[2].bar * 100000 @external def gop() -> int128: zed: Zed[2] = [ - Zed({foo: [1, 2, 3], bar: [Bar({a: 4, b: 5}), Bar({a: 2, b: 3})]}), - Zed({foo: [6, 7, 8], bar: [Bar({a: 9, b: 1}), Bar({a: 7, b: 8})]}) + Zed(foo=[1, 2, 3], bar=[Bar(a=4, b=5), Bar(a=2, b=3)]), + Zed(foo=[6, 7, 8], bar=[Bar(a=9, b=1), Bar(a=7, b=8)]) ] return zed[0].foo[0] + zed[0].foo[1] * 10 + \ zed[0].foo[2] * 100 + \ @@ -157,7 +157,7 @@ def test_struct_assignment_order(get_contract, assert_compile_failed): @external @view def test2() -> uint256: - foo: Foo = Foo({b: 2, a: 297}) + foo: Foo = Foo(b=2, a=297) return foo.a """ assert_compile_failed(lambda: get_contract(code), InvalidAttribute) @@ -193,25 +193,25 @@ def test_composite_setter_test(get_contract_with_gas_estimation): @external def foo() -> int128: - self.mom = Mom({a: [C({c: 1}), C({c: 2}), C({c: 3})], b: 4}) - non: C = C({c: 5}) + self.mom = Mom(a=[C(c=1), C(c=2), C(c=3)], b=4) + non: C = C(c=5) self.mom.a[0] = non - non = C({c: 6}) + non = C(c=6) self.mom.a[2] = non return self.mom.a[0].c + self.mom.a[1].c * 10 + self.mom.a[2].c * 100 + self.mom.b * 1000 @external def fop() -> int128: - popp: Mom = Mom({a: [C({c: 1}), C({c: 2}), C({c: 3})], b: 4}) - self.qoq = C({c: 5}) + popp: Mom = Mom(a=[C(c=1), C(c=2), C(c=3)], b=4) + self.qoq = C(c=5) popp.a[0] = self.qoq - self.qoq = C({c: 6}) + self.qoq = C(c=6) popp.a[2] = self.qoq return popp.a[0].c + popp.a[1].c * 10 + popp.a[2].c * 100 + popp.b * 1000 @external def foq() -> int128: - popp: Mom = Mom({a: [C({c: 1}), C({c: 2}), C({c: 3})], b: 4}) + popp: Mom = Mom(a=[C(c=1), C(c=2), C(c=3)], b=4) popp.a[0] = empty(C) popp.a[2] = empty(C) return popp.a[0].c + popp.a[1].c * 10 + popp.a[2].c * 100 + popp.b * 1000 diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 99e5835f6e..f8ae65cc54 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -132,7 +132,7 @@ def test_test_bytes5(get_contract_with_gas_estimation): @external def foo(inp1: Bytes[40], inp2: Bytes[45]): - self.g = G({a: inp1, b: inp2}) + self.g = G(a=inp1, b=inp2) @external def check1() -> Bytes[50]: @@ -144,17 +144,17 @@ def check2() -> Bytes[50]: @external def bar(inp1: Bytes[40], inp2: Bytes[45]) -> Bytes[50]: - h: H = H({a: inp1, b: inp2}) + h: H = H(a=inp1, b=inp2) return h.a @external def bat(inp1: Bytes[40], inp2: Bytes[45]) -> Bytes[50]: - h: H = H({a: inp1, b: inp2}) + h: H = H(a=inp1, b=inp2) return h.b @external def quz(inp1: Bytes[40], inp2: Bytes[45]): - h: H = H({a: inp1, b: inp2}) + h: H = H(a=inp1, b=inp2) self.g.a = h.a self.g.b = h.b """ diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index fc3223caaf..b55f07639b 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1362,12 +1362,12 @@ def test_list_of_structs_lists_with_nested_lists(get_contract, tx_failed): def foo(x: uint8) -> uint8: b: DynArray[Bar[2], 2] = [ [ - Bar({a: [[x, x + 1], [x + 2, x + 3]]}), - Bar({a: [[x + 4, x +5], [x + 6, x + 7]]}) + Bar(a=[[x, x + 1], [x + 2, x + 3]]), + Bar(a=[[x + 4, x +5], [x + 6, x + 7]]) ], [ - Bar({a: [[x + 8, x + 9], [x + 10, x + 11]]}), - Bar({a: [[x + 12, x + 13], [x + 14, x + 15]]}) + Bar(a=[[x + 8, x + 9], [x + 10, x + 11]]), + Bar(a=[[x + 12, x + 13], [x + 14, x + 15]]) ], ] return b[0][0].a[0][0] + b[0][1].a[1][1] + b[1][0].a[0][1] + b[1][1].a[1][0] @@ -1503,11 +1503,11 @@ def _foo3() -> DynArray[DynArray[DynArray[uint256, 2], 2], 2]: @external def bar() -> DynArray[DynArray[DynArray[uint256, 2], 2], 2]: - foo: Foo = Foo({ - a1: self._foo(), - a2: self._foo2(), - a3: self._foo3(), - }) + foo: Foo = Foo( + a1=self._foo(), + a2=self._foo2(), + a3=self._foo3(), + ) return foo.a3 """ c = get_contract(code) @@ -1524,12 +1524,12 @@ def test_struct_of_lists_2(get_contract): @internal def _foo(x: int128) -> Foo: - f: Foo = Foo({ - b: b"hello", - da: [x, x * 2], - sa: [x + 1, x + 2, x + 3, x + 4, x + 5], - some_int: x - 1 - }) + f: Foo = Foo( + b=b"hello", + da=[x, x * 2], + sa=[x + 1, x + 2, x + 3, x + 4, x + 5], + some_int=x - 1 + ) return f @external @@ -1550,12 +1550,11 @@ def test_struct_of_lists_3(get_contract): @internal def _foo(x: int128) -> Foo: - f: Foo = Foo({ - a: [x, x * 2], - b: [0x0000000000000000000000000000000000000012], - c: [False, True, False] - - }) + f: Foo = Foo( + a=[x, x * 2], + b=[0x0000000000000000000000000000000000000012], + c=[False, True, False] + ) return f @external @@ -1577,15 +1576,15 @@ def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize): @internal def _foo() -> nestedFoo: - return nestedFoo({a1: [ + return nestedFoo(a1=[ [[3, 7], [7, 3]], [[7, 3], [3, 7]], - ]}) + ]) @internal def _foo2() -> Foo: _nF1: nestedFoo = self._foo() - return Foo({b1: [[[_nF1, _nF1], [_nF1, _nF1]], [[_nF1, _nF1], [_nF1, _nF1]]]}) + return Foo(b1=[[[_nF1, _nF1], [_nF1, _nF1]], [[_nF1, _nF1], [_nF1, _nF1]]]) @internal def _foo3(f: Foo) -> Foo: diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index dd9c867a96..68aab1968f 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -281,7 +281,7 @@ def test_struct_with_flag(get_contract_with_gas_estimation): @external def get_flag_from_struct() -> Foobar: - f: Foo = Foo({a: 1, b: Foobar.BAR}) + f: Foo = Foo(a=1, b=Foobar.BAR) return f.b """ c = get_contract_with_gas_estimation(code) diff --git a/tests/functional/syntax/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py index 53a9550a7d..80f499ac89 100644 --- a/tests/functional/syntax/exceptions/test_syntax_exception.py +++ b/tests/functional/syntax/exceptions/test_syntax_exception.py @@ -98,6 +98,11 @@ def foo(): for i: $$$ in range(0, 10): pass """, + """ +struct S: + x: int128 +s: S = S(x=int128, 1) + """, ] diff --git a/tests/functional/syntax/exceptions/test_variable_declaration_exception.py b/tests/functional/syntax/exceptions/test_variable_declaration_exception.py index f34c9a33c4..42c48dbe32 100644 --- a/tests/functional/syntax/exceptions/test_variable_declaration_exception.py +++ b/tests/functional/syntax/exceptions/test_variable_declaration_exception.py @@ -13,11 +13,6 @@ def foo() -> int128: """ struct S: x: int128 -s: S = S({x: int128}, 1) - """, - """ -struct S: - x: int128 s: S = S() """, """ diff --git a/tests/functional/syntax/test_ann_assign.py b/tests/functional/syntax/test_ann_assign.py index 7fdb1328c2..23ebeb9560 100644 --- a/tests/functional/syntax/test_ann_assign.py +++ b/tests/functional/syntax/test_ann_assign.py @@ -59,7 +59,7 @@ def data() -> int128: b: decimal @external def foo() -> int128: - s: S = S({a: 1.2, b: 1}) + s: S = S(a=1.2, b=1) return s.a """, TypeMismatch, @@ -71,7 +71,7 @@ def foo() -> int128: b: decimal @external def foo() -> int128: - s: S = S({a: 1}) + s: S = S(a=1) """, VariableDeclarationException, ), @@ -82,7 +82,7 @@ def foo() -> int128: b: decimal @external def foo() -> int128: - s: S = S({b: 1.2, a: 1}) + s: S = S(b=1.2, a=1) """, InvalidAttribute, ), @@ -93,7 +93,7 @@ def foo() -> int128: b: decimal @external def foo() -> int128: - s: S = S({a: 1, b: 1.2, c: 1, d: 33, e: 55}) + s: S = S(a=1, b=1.2, c=1, d=33, e=55) return s.a """, UnknownAttribute, diff --git a/tests/functional/syntax/test_block.py b/tests/functional/syntax/test_block.py index 8d8bffb697..aea39aa9c7 100644 --- a/tests/functional/syntax/test_block.py +++ b/tests/functional/syntax/test_block.py @@ -63,8 +63,8 @@ def add_record(): y: int128 @external def add_record(): - a: X = X({x: block.timestamp}) - b: Y = Y({y: 5}) + a: X = X(x=block.timestamp) + b: Y = Y(y=5) a.x = b.y """, """ @@ -123,7 +123,7 @@ def foo(): x: uint256 @external def add_record(): - a: X = X({x: block.timestamp}) + a: X = X(x=block.timestamp) a.x = block.gaslimit a.x = block.basefee a.x = 5 diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 63abf24485..5eb9eefe25 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -135,7 +135,7 @@ def test(VAL: uint256): a: uint256 b: uint256 -CONST_BAR: constant(Foo) = Foo({a: 1, b: block.number}) +CONST_BAR: constant(Foo) = Foo(a=1, b=block.number) """, StateAccessViolation, ), @@ -163,7 +163,7 @@ def foo() -> uint256: struct Foo: a : uint256 -x: constant(Foo) = Foo({a: 1}) +x: constant(Foo) = Foo(a=1) @external def hello() : @@ -276,7 +276,7 @@ def deposit(deposit_input: Bytes[2048]): a: uint256 b: uint256 -CONST_BAR: constant(Foo) = Foo({a: 1, b: 2}) +CONST_BAR: constant(Foo) = Foo(a=1, b=2) """, """ CONST_EMPTY: constant(bytes32) = empty(bytes32) @@ -293,7 +293,7 @@ def foo() -> bytes32: A: constant(uint256) = 1 B: constant(uint256) = 2 -CONST_BAR: constant(Foo) = Foo({a: A, b: B}) +CONST_BAR: constant(Foo) = Foo(a=A, b=B) """, """ struct Foo: @@ -306,10 +306,10 @@ def foo() -> bytes32: A: constant(uint256) = 1 B: constant(uint256) = 2 -C: constant(Foo) = Foo({a: A, b: B}) +C: constant(Foo) = Foo(a=A, b=B) D: constant(int128) = -1 -CONST_BAR: constant(Bar) = Bar({c: C, d: D}) +CONST_BAR: constant(Bar) = Bar(c=C, d=D) """, """ interface Foo: diff --git a/tests/functional/syntax/test_flag.py b/tests/functional/syntax/test_flag.py index 22309502b7..7732ccc39f 100644 --- a/tests/functional/syntax/test_flag.py +++ b/tests/functional/syntax/test_flag.py @@ -158,10 +158,10 @@ def run() -> Action: @external def run() -> Order: - return Order({ - action: Action.BUY, - amount: 10**18 - }) + return Order( + action=Action.BUY, + amount=10**18 + ) """, "flag Foo:\n" + "\n".join([f" member{i}" for i in range(256)]), """ diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index 59fb1a69d9..7e5903a6a1 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -165,7 +165,7 @@ def report(): @deploy def __init__(): - x = Foo({a:1}) + x = Foo(a=1) @external def hello() : diff --git a/tests/functional/syntax/test_invalids.py b/tests/functional/syntax/test_invalids.py index dfc74fc75b..f4e60902ef 100644 --- a/tests/functional/syntax/test_invalids.py +++ b/tests/functional/syntax/test_invalids.py @@ -357,7 +357,7 @@ def a(): @external def a(): - x: int128 = StructX({y: 1}) + x: int128 = StructX(y=1) """, UnknownAttribute, ) diff --git a/tests/functional/syntax/test_no_none.py b/tests/functional/syntax/test_no_none.py index 085ce395ab..ebe32816bd 100644 --- a/tests/functional/syntax/test_no_none.py +++ b/tests/functional/syntax/test_no_none.py @@ -178,7 +178,7 @@ def test_struct_none(assert_compile_failed, get_contract_with_gas_estimation): @external def foo(): - mom: Mom = Mom({a: None, b: 0}) + mom: Mom = Mom(a=None, b=0) """, """ struct Mom: @@ -187,7 +187,7 @@ def foo(): @external def foo(): - mom: Mom = Mom({a: 0, b: None}) + mom: Mom = Mom(a=0, b=None) """, """ struct Mom: @@ -196,7 +196,7 @@ def foo(): @external def foo(): - mom: Mom = Mom({a: None, b: None}) + mom: Mom = Mom(a=None, b=None) """, ] diff --git a/tests/functional/syntax/test_structs.py b/tests/functional/syntax/test_structs.py index 4fad35d1d4..c2cc11324f 100644 --- a/tests/functional/syntax/test_structs.py +++ b/tests/functional/syntax/test_structs.py @@ -1,17 +1,19 @@ +import warnings + import pytest from vyper import compiler from vyper.exceptions import ( InstantiationException, StructureException, + SyntaxException, TypeMismatch, UnknownAttribute, VariableDeclarationException, ) fail_list = [ - ( - """ + """ struct A: x: int128 a: A @@ -19,8 +21,6 @@ def foo(): self.a = A(1) """, - VariableDeclarationException, - ), ( """ struct A: @@ -28,24 +28,20 @@ def foo(): a: A @external def foo(): - self.a = A({x: 1, y: 2}) + self.a = A(x=1, y=2) """, UnknownAttribute, ), - ( - """ + """ struct A: x: int128 y: int128 a: A @external def foo(): - self.a = A({x: 1}) + self.a = A(x=1) """, - VariableDeclarationException, - ), - ( - """ + """ struct A: x: int128 struct B: @@ -56,10 +52,7 @@ def foo(): def foo(): self.a = A(self.b) """, - VariableDeclarationException, - ), - ( - """ + """ struct A: x: int128 a: A @@ -68,10 +61,7 @@ def foo(): def foo(): self.a = A(self.b) """, - VariableDeclarationException, - ), - ( - """ + """ struct A: x: int128 y: int128 @@ -80,10 +70,7 @@ def foo(): def foo(): self.a = A({x: 1}) """, - VariableDeclarationException, - ), - ( - """ + """ struct C: c: int128 struct Mom: @@ -98,8 +85,6 @@ def foo(): def foo(): self.nom = Nom(self.mom) """, - VariableDeclarationException, - ), """ struct C1: c: int128 @@ -178,6 +163,15 @@ def foo(): def foo(): self.nom = Nom(self.mom) """, + """ +struct Foo: + a: uint256 + b: uint256 + +@external +def foo(i: uint256, j: uint256): + f: Foo = Foo(i, b=j) + """, ( """ struct Mom: @@ -251,7 +245,7 @@ def foo(): nom: C[3] @external def foo(): - self.mom = Mom({a: self.nom, b: 5.5}) + self.mom = Mom(a=self.nom, b=5.5) """, TypeMismatch, ), @@ -268,7 +262,7 @@ def foo(): nom: C2[3] @external def foo(): - self.mom = Mom({a: self.nom, b: 5}) + self.mom = Mom(a=self.nom, b=5) """, TypeMismatch, ), @@ -285,7 +279,7 @@ def foo(): nom: C[3] @external def foo(): - self.mom = Mom({a: self.nom, b: self.nom}) + self.mom = Mom(a=self.nom, b=self.nom) """, TypeMismatch, ), @@ -329,7 +323,7 @@ def foo(): nom: C2[3] @external def foo(): - self.mom = Mom({a: self.nom, b: 5}) + self.mom = Mom(a=self.nom, b=5) """, TypeMismatch, ), @@ -342,9 +336,9 @@ def foo(): bar: int128[3] @external def foo(): - self.bar = Bar({0: 5, 1: 7, 2: 9}) + self.bar = Bar(0=5, 1=7, 2=9) """, - UnknownAttribute, + SyntaxException, ), ( """ @@ -355,7 +349,7 @@ def foo(): bar: int128[3] @external def foo(): - self.bar = Bar({a: 5, b: 7, c: 9}) + self.bar = Bar(a=5, b=7, c=9) """, TypeMismatch, ), @@ -366,7 +360,7 @@ def foo(): dog: int128 @external def foo() -> int128: - f: Farm = Farm({cow: 5, dog: 7}) + f: Farm = Farm(cow=5, dog=7) return f """, TypeMismatch, @@ -390,7 +384,7 @@ def foo(): b: B @external def foo(): - self.b = B({foo: 1, foo: 2}) + self.b = B(foo=1, foo=2) """, UnknownAttribute, ), @@ -425,7 +419,7 @@ def foo(): @external def foo(): - Foo({a: 1}) + Foo(a=1) """, StructureException, ), @@ -459,7 +453,7 @@ def test_block_fail(bad_code): a: A @external def foo(): - self.a = A({x: 1}) + self.a = A(x=1) """, """ struct C: @@ -482,7 +476,7 @@ def foo(): nom: C[3] @external def foo(): - mom: Mom = Mom({a:[C({c:0}), C({c:0}), C({c:0})], b: 0}) + mom: Mom = Mom(a=[C(c=0), C(c=0), C(c=0)], b=0) mom.a = self.nom """, """ @@ -495,7 +489,7 @@ def foo(): nom: C[3] @external def foo(): - self.mom = Mom({a: self.nom, b: 5}) + self.mom = Mom(a=self.nom, b=5) """, """ struct C: @@ -507,7 +501,7 @@ def foo(): nom: C[3] @external def foo(): - self.mom = Mom({a: self.nom, b: 5}) + self.mom = Mom(a=self.nom, b=5) """, """ struct C: @@ -518,8 +512,8 @@ def foo(): mom: Mom @external def foo(): - nom: C[3] = [C({c:0}), C({c:0}), C({c:0})] - self.mom = Mom({a: nom, b: 5}) + nom: C[3] = [C(c=0), C(c=0), C(c=0)] + self.mom = Mom(a=nom, b=5) """, """ struct B: @@ -548,7 +542,7 @@ def foo(): d: bool @external def get_y() -> int128: - return C({c: A({a: X({x: 1, y: -1}), b: 777}), d: True}).c.a.y - 10 + return C(c=A(a=X(x=1, y=-1), b=777), d=True).c.a.y - 10 """, """ struct X: @@ -560,7 +554,7 @@ def get_y() -> int128: struct C: c: A d: bool -FOO: constant(C) = C({c: A({a: X({x: 1, y: -1}), b: 777}), d: True}) +FOO: constant(C) = C(c=A(a=X(x=1, y=-1), b=777), d=True) @external def get_y() -> int128: return FOO.c.a.y - 10 @@ -572,7 +566,7 @@ def get_y() -> int128: @external def foo(): - bar: C = C({a: 1, b: block.timestamp}) + bar: C = C(a=1, b=block.timestamp) """, ] @@ -580,3 +574,24 @@ def foo(): @pytest.mark.parametrize("good_code", valid_list) def test_block_success(good_code): assert compiler.compile_code(good_code) is not None + + +def test_old_constructor_syntax(): + # backwards compatibility for vyper <0.4.0 + code = """ +struct A: + x: int128 +a: A +@external +def foo(): + self.a = A({x: 1}) + """ + with warnings.catch_warnings(record=True) as w: + assert compiler.compile_code(code) is not None + + expected = "Instantiating a struct using a dictionary is deprecated " + expected += "as of v0.4.0 and will be disallowed in a future release. " + expected += "Use kwargs instead e.g. Foo(a=1, b=2)" + + assert len(w) == 1 + assert str(w[0].message) == expected diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 6adee24db6..bc99b07a8e 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -38,7 +38,7 @@ def test_invalid_root_path(): @external def foo() -> {alias}.FooStruct: - return {alias}.FooStruct({{foo_: 13}}) + return {alias}.FooStruct(foo_=13) @external def bar(a: address) -> {alias}.FooStruct: @@ -176,7 +176,7 @@ def know_thyself(a: address) -> ISelf.FooStruct: @external def be_known() -> ISelf.FooStruct: - return ISelf.FooStruct({{foo_: 42}}) + return ISelf.FooStruct(foo_=42) """ make_file("contracts/ISelf.vyi", interface_code) meta = make_file("contracts/Self.vy", code) diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index c97c9c095e..bc4ae1a2f7 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -171,7 +171,7 @@ def test_modify_iterator_through_struct(dummy_input_bundle): def foo(): self.a.iter = [1, 2, 3] for i: uint256 in self.a.iter: - self.a = A({iter: [1, 2, 3, 4]}) + self.a = A(iter=[1, 2, 3, 4]) """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index 1dc70fd1ba..7cbe71cf29 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -27,19 +27,19 @@ @deploy def __init__(): - self.a = StructOne({a: "ok", b: [4,5,6]}) + self.a = StructOne(a="ok", b=[4,5,6]) self.b = [7, 8] self.c = b"thisisthirtytwobytesokhowdoyoudo" self.d = [-1, -2, -3, -4] self.e = "A realllllly long string but we won't use it all" self.f = [33] self.g = [ - StructTwo({a: b"hello", b: [-66, 420], c: "another string"}), - StructTwo({ - a: b"gbye", - b: [1337, 888], - c: "whatifthisstringtakesuptheentirelengthwouldthatbesobadidothinkso" - }) + StructTwo(a=b"hello", b=[-66, 420], c="another string"), + StructTwo( + a=b"gbye", + b=[1337, 888], + c="whatifthisstringtakesuptheentirelengthwouldthatbesobadidothinkso" + ) ] self.dyn_array = [1, 2, 3] self.h = [123456789] diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index a10a840da0..4f162ac18f 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -1,5 +1,6 @@ import ast as python_ast import tokenize +import warnings from decimal import Decimal from typing import Any, Dict, List, Optional, Union, cast @@ -341,6 +342,29 @@ def visit_Expr(self, node): return node + def visit_Call(self, node): + # Convert structs declared as `Dict` node for vyper < 0.4.0 to kwargs + if len(node.args) == 1 and isinstance(node.args[0], python_ast.Dict): + msg = "Instantiating a struct using a dictionary is deprecated " + msg += "as of v0.4.0 and will be disallowed in a future release. " + msg += "Use kwargs instead e.g. Foo(a=1, b=2)" + warnings.warn(msg, stacklevel=2) + + dict_ = node.args[0] + kw_list = [] + + assert len(dict_.keys) == len(dict_.values) + for key, value in zip(dict_.keys, dict_.values): + replacement_kw_node = python_ast.keyword(key.id, value) + kw_list.append(replacement_kw_node) + + node.args = [] + node.keywords = kw_list + + self.generic_visit(node) + + return node + def visit_Constant(self, node): """ Handle `Constant` when using Python >=3.8 diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 9c7f11dcb3..13e1309e6d 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -683,9 +683,7 @@ def parse_Call(self): # Struct constructor if is_type_t(func_type, StructT): - args = self.expr.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) + return self.handle_struct_literal() # Interface constructor. Bar(
). if is_type_t(func_type, InterfaceT): @@ -750,17 +748,15 @@ def parse_IfExp(self): location = body.location return IRnode.from_list(["if", test, body, orelse], typ=typ, location=location) - @staticmethod - def struct_literals(expr, context, typ): + def handle_struct_literal(self): + expr = self.expr + typ = expr._metadata["type"] member_subs = {} - member_typs = {} - for key, value in zip(expr.keys, expr.values): - assert isinstance(key, vy_ast.Name) - assert key.id not in member_subs - - sub = Expr(value, context).ir_node - member_subs[key.id] = sub - member_typs[key.id] = sub.typ + for kwarg in expr.keywords: + assert kwarg.arg not in member_subs + + sub = Expr(kwarg.value, self.context).ir_node + member_subs[kwarg.arg] = sub return IRnode.from_list( ["multi"] + [member_subs[key] for key in member_subs.keys()], typ=typ diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 39a1c59290..cc08b3b95a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -731,8 +731,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: # struct ctors # ctors have no kwargs expected_types = func_type.typedef.members.values() # type: ignore - for value, arg_type in zip(node.args[0].values, expected_types): - self.visit(value, arg_type) + for kwarg, arg_type in zip(node.keywords, expected_types): + self.visit(kwarg.value, arg_type) elif isinstance(func_type, MemberFunctionT): if func_type.is_modifying and self.function_analyzer is not None: # TODO refactor this diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 0c9b5d70da..f09a97d0b9 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -381,38 +381,36 @@ def to_abi_arg(self, name: str = "") -> dict: components = [t.to_abi_arg(name=k) for k, t in self.member_types.items()] return {"name": name, "type": "tuple", "components": components} - # TODO breaking change: use kwargs instead of dict - # when using the type itself (not an instance) in the call position - # maybe rename to _ctor_call_return def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": - validate_call_args(node, 1) - if not isinstance(node.args[0], vy_ast.Dict): + if len(node.args) > 0: raise VariableDeclarationException( - "Struct values must be declared via dictionary", node.args[0] + "Struct values must be declared as kwargs e.g. Foo(a=1, b=2)", node.args[0] ) if next((i for i in self.member_types.values() if isinstance(i, HashMapT)), False): raise VariableDeclarationException( "Struct contains a mapping and so cannot be declared as a literal", node ) + # manually validate kwargs for better error messages instead of + # relying on `validate_call_args` members = self.member_types.copy() keys = list(self.member_types.keys()) - for i, (key, value) in enumerate(zip(node.args[0].keys, node.args[0].values)): - if key is None or key.get("id") not in members: - hint = get_levenshtein_error_suggestions(key.get("id"), members, 1.0) - raise UnknownAttribute( - "Unknown or duplicate struct member.", key or value, hint=hint - ) - expected_key = keys[i] - if key.id != expected_key: + for i, kwarg in enumerate(node.keywords): + # x=5 => kwarg(arg="x", value=Int(5)) + argname = kwarg.arg + if argname not in members: + hint = get_levenshtein_error_suggestions(argname, members, 1.0) + raise UnknownAttribute("Unknown or duplicate struct member.", kwarg, hint=hint) + expected = keys[i] + if argname != expected: raise InvalidAttribute( "Struct keys are required to be in order, but got " - f"`{key.id}` instead of `{expected_key}`. (Reminder: the " + f"`{argname}` instead of `{expected}`. (Reminder: the " f"keys in this struct are {list(self.member_types.items())})", - key, + kwarg, ) - - validate_expected_type(value, members.pop(key.id)) + expected_type = members.pop(argname) + validate_expected_type(kwarg.value, expected_type) if members: raise VariableDeclarationException( @@ -422,4 +420,4 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": return self def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: - return all(check_modifiability(v, modifiability) for v in node.args[0].values) + return all(check_modifiability(k.value, modifiability) for k in node.keywords) From e50b67a66e30e174c0b710cb023db59688b2284e Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 20 Feb 2024 21:28:16 +0800 Subject: [PATCH 196/201] docs: new struct instantiation syntax (#3792) * fix formatting in some tests * update docs --- docs/types.rst | 2 +- .../builtins/codegen/test_abi_decode.py | 18 +++++++++--------- .../builtins/codegen/test_abi_encode.py | 18 +++++++++--------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/types.rst b/docs/types.rst index 38779c2a4b..8bc7b7d3e1 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -588,7 +588,7 @@ Struct members can be accessed via ``struct.argname``. value2: decimal # Declaring a struct variable - exampleStruct: MyStruct = MyStruct({value1: 1, value2: 2.0}) + exampleStruct: MyStruct = MyStruct(value1=1, value2=2.0) # Accessing a value exampleStruct.value1 = 1 diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index dbbf195373..9dd9691aa5 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -36,15 +36,15 @@ def abi_decode(x: Bytes[160]) -> (address, int128, bool, decimal, bytes32): @external def abi_decode_struct(x: Bytes[544]) -> Human: human: Human = Human( - name = "", - pet = Animal( - name = "", - address_ = empty(address), - id_ = 0, - is_furry = False, - price = 0.0, - data = [0, 0, 0], - metadata = 0x0000000000000000000000000000000000000000000000000000000000000000 + name="", + pet=Animal( + name="", + address_=empty(address), + id_=0, + is_furry=False, + price=0.0, + data=[0, 0, 0], + metadata=0x0000000000000000000000000000000000000000000000000000000000000000 ) ) human = _abi_decode(x, Human) diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f818b04359..2078cf65f3 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -35,15 +35,15 @@ def abi_encode( include_method_id: bool ) -> Bytes[548]: human: Human = Human( - name = name, - pet = Animal( - name = pet_name, - address_ = pet_address, - id_ = pet_id, - is_furry = pet_is_furry, - price = pet_price, - data = pet_data, - metadata = pet_metadata + name=name, + pet=Animal( + name=pet_name, + address_=pet_address, + id_=pet_id, + is_furry=pet_is_furry, + price=pet_price, + data=pet_data, + metadata=pet_metadata ), ) if ensure_tuple: From d8d98c27f12f140782e60a624557deac545bdf08 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 20 Feb 2024 05:34:16 -0800 Subject: [PATCH 197/201] fix: assembly dead code eliminator (#3791) it was not aggressive enough, and could leave some instructions which mangle the assembly so it can't be turned into bytecode * use a tighter loop * fix an off-by-one * use optimize=gas in a test --- .../compiler/venom/test_duplicate_operands.py | 4 +-- vyper/ir/compile_ir.py | 31 ++++++++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index b96c7f3351..437185cc72 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -22,6 +22,6 @@ def test_duplicate_operands(): bb.append_instruction("mul", sum_, op) bb.append_instruction("stop") - asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.CODESIZE) + asm = generate_assembly_experimental(ctx, optimize=OptimizationLevel.GAS) - assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP"] diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 8ce8c887f1..8b09ae454f 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -809,18 +809,26 @@ def _prune_unreachable_code(assembly): # unreachable changed = False i = 0 - while i < len(assembly) - 2: - instr = assembly[i] - if isinstance(instr, list): - instr = assembly[i][-1] + while i < len(assembly) - 1: + if assembly[i] in _TERMINAL_OPS: + # find the next jumpdest or sublist + for j in range(i + 1, len(assembly)): + next_is_jumpdest = ( + j < len(assembly) - 1 + and is_symbol(assembly[j]) + and assembly[j + 1] == "JUMPDEST" + ) + next_is_list = isinstance(assembly[j], list) + if next_is_jumpdest or next_is_list: + break + else: + # fixup an off-by-one if we made it to the end of the assembly + # without finding an jumpdest or sublist + j = len(assembly) + changed = j > i + 1 + del assembly[i + 1 : j] - if assembly[i] in _TERMINAL_OPS and not ( - is_symbol(assembly[i + 1]) or isinstance(assembly[i + 1], list) - ): - changed = True - del assembly[i + 1] - else: - i += 1 + i += 1 return changed @@ -1230,7 +1238,6 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat if is_symbol_map_indicator(assembly[i + 1]): # Don't increment pc as the symbol itself doesn't go into code if item in symbol_map: - print(assembly) raise CompilerPanic(f"duplicate jumpdest {item}") symbol_map[item] = pc From d5e8bd8756d89a618e53da34077d8ff013c6003b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 20 Feb 2024 09:10:12 -0800 Subject: [PATCH 198/201] chore: add color to mypy output (#3793) also remove a stray `type: ignore` comment --- tox.ini | 1 + vyper/semantics/analysis/constant_folding.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index b42a13a0ab..2f71a85431 100644 --- a/tox.ini +++ b/tox.ini @@ -52,5 +52,6 @@ commands = [testenv:mypy] basepython = python3 extras = lint +passenv = TERM commands = mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --implicit-optional -p vyper diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index 3522383167..6e4166dc52 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -214,7 +214,7 @@ def visit_Call(self, node) -> vy_ast.ExprNode: # TODO: rename to vyper_type.try_fold_call_expr if not hasattr(typ, "_try_fold"): raise UnfoldableNode("unfoldable", node) - return typ._try_fold(node) # type: ignore + return typ._try_fold(node) def visit_Subscript(self, node) -> vy_ast.ExprNode: slice_ = node.slice.get_folded_value() From bc57775c325befdbe526e0c7b120317c8318e83b Mon Sep 17 00:00:00 2001 From: trocher Date: Tue, 20 Feb 2024 18:14:18 +0100 Subject: [PATCH 199/201] chore: replace occurences of 'enum' by 'flag' (#3794) --- docs/resources.rst | 2 +- docs/types.rst | 2 +- vyper/builtins/_convert.py | 10 +++++----- vyper/codegen/core.py | 4 ++-- vyper/codegen/expr.py | 12 ++++++------ vyper/semantics/README.md | 2 +- vyper/semantics/analysis/local.py | 2 +- vyper/semantics/analysis/utils.py | 2 +- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/user.py | 26 +++++++++++++------------- 10 files changed, 32 insertions(+), 32 deletions(-) diff --git a/docs/resources.rst b/docs/resources.rst index 7bb3c99df4..c2b0e3e427 100644 --- a/docs/resources.rst +++ b/docs/resources.rst @@ -23,7 +23,7 @@ Frameworks and tooling - `VyperDeployer – A helper smart contract to compile and test Vyper contracts in Foundry `_ - `🐍 snekmate – Vyper smart contract building blocks `_ - `Serpentor – A set of smart contracts tools for governance `_ -- `Smart contract development frameworks and tools for Vyper on Ethreum.org `_ +- `Smart contract development frameworks and tools for Vyper on Ethereum.org `_ - `Vyper Online Compiler - an online platform for compiling and deploying Vyper smart contracts `_ Security diff --git a/docs/types.rst b/docs/types.rst index 8bc7b7d3e1..5038e78240 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -675,6 +675,6 @@ All type conversions in Vyper must be made explicitly using the built-in ``conve * Narrowing conversions (e.g., ``int256 -> int128``) check that the input is in bounds for the output type. * Converting between bytes and int types results in sign-extension if the output type is signed. For instance, converting ``0xff`` (``bytes1``) to ``int8`` returns ``-1``. * Converting between bytes and int types which have different sizes follows the rule of going through the closest integer type, first. For instance, ``bytes1 -> int16`` is like ``bytes1 -> int8 -> int16`` (signextend, then widen). ``uint8 -> bytes20`` is like ``uint8 -> uint160 -> bytes20`` (rotate left 12 bytes). -* Enums can be converted to and from ``uint256`` only. +* Flags can be converted to and from ``uint256`` only. A small Python reference implementation is maintained as part of Vyper's test suite, it can be found `here `__. The motivation and more detailed discussion of the rules can be found `here `__. diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index 998cbbc9f6..adc2c233b8 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -308,7 +308,7 @@ def _to_int(expr, arg, out_typ): elif is_flag_type(arg.typ): if out_typ != UINT256_T: _FAIL(arg.typ, out_typ, expr) - # pretend enum is uint256 + # pretend flag is uint256 arg = IRnode.from_list(arg, typ=UINT256_T) # use int_to_int rules arg = _int_to_int(arg, out_typ) @@ -442,12 +442,12 @@ def to_bytes(expr, arg, out_typ): @_input_types(IntegerT) -def to_enum(expr, arg, out_typ): +def to_flag(expr, arg, out_typ): if arg.typ != UINT256_T: _FAIL(arg.typ, out_typ, expr) - if len(out_typ._enum_members) < 256: - arg = int_clamp(arg, bits=len(out_typ._enum_members), signed=False) + if len(out_typ._flag_members) < 256: + arg = int_clamp(arg, bits=len(out_typ._flag_members), signed=False) return IRnode.from_list(arg, typ=out_typ) @@ -469,7 +469,7 @@ def convert(expr, context): elif out_typ == AddressT(): ret = to_address(arg_ast, arg, out_typ) elif is_flag_type(out_typ): - ret = to_enum(arg_ast, arg, out_typ) + ret = to_flag(arg_ast, arg, out_typ) elif is_integer_type(out_typ): ret = to_int(arg_ast, arg, out_typ) elif is_bytes_m_type(out_typ): diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 1a090ac316..8a186fe683 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -860,7 +860,7 @@ def needs_clamp(t, encoding): if isinstance(t, (_BytestringT, DArrayT)): return True if isinstance(t, FlagT): - return len(t._enum_members) < 256 + return len(t._flag_members) < 256 if isinstance(t, SArrayT): return needs_clamp(t.value_type, encoding) if is_tuple_like(t): @@ -1132,7 +1132,7 @@ def clamp_basetype(ir_node): ir_node = unwrap_location(ir_node) if isinstance(t, FlagT): - bits = len(t._enum_members) + bits = len(t._flag_members) # assert x >> bits == 0 ret = int_clamp(ir_node, bits, signed=False) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 13e1309e6d..6a444181c2 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -212,15 +212,15 @@ def parse_Name(self): def parse_Attribute(self): typ = self.expr._metadata["type"] - # MyEnum.foo + # MyFlag.foo if ( isinstance(typ, FlagT) and isinstance(self.expr.value, vy_ast.Name) and typ.name == self.expr.value.id ): # 0, 1, 2, .. 255 - enum_id = typ._enum_members[self.expr.attr] - value = 2**enum_id # 0 => 0001, 1 => 0010, 2 => 0100, etc. + flag_id = typ._flag_members[self.expr.attr] + value = 2**flag_id # 0 => 0001, 1 => 0010, 2 => 0100, etc. return IRnode.from_list(value, typ=typ) # x.balance: balance of address x @@ -420,7 +420,7 @@ def parse_BinOp(self): op = shr if not left.typ.is_signed else sar return IRnode.from_list(op(right, left), typ=new_typ) - # enums can only do bit ops, not arithmetic. + # flags can only do bit ops, not arithmetic. assert is_numeric_type(left.typ) with left.cache_when_complex("x") as (b1, x), right.cache_when_complex("y") as (b2, y): @@ -645,10 +645,10 @@ def parse_UnaryOp(self): if isinstance(self.expr.op, vy_ast.Invert): if isinstance(operand.typ, FlagT): - n_members = len(operand.typ._enum_members) + n_members = len(operand.typ._flag_members) # use (xor 0b11..1 operand) to flip all the bits in # `operand`. `mask` could be a very large constant and - # hurt codesize, but most user enums will likely have few + # hurt codesize, but most user flags will likely have few # enough members that the mask will not be large. mask = (2**n_members) - 1 return IRnode.from_list(["xor", mask, operand], typ=operand.typ) diff --git a/vyper/semantics/README.md b/vyper/semantics/README.md index 3b7acf9469..7a8a384c6d 100644 --- a/vyper/semantics/README.md +++ b/vyper/semantics/README.md @@ -16,7 +16,7 @@ Vyper abstract syntax tree (AST). * [`primitives.py`](types/primitives.py): Address, boolean, fixed length byte, integer and decimal types * [`shortcuts.py`](types/shortcuts.py): Helper constants for commonly used types * [`subscriptable.py`](types/subscriptable.py): Mapping, array and tuple types - * [`user.py`](types/user.py): Enum, event, interface and struct types + * [`user.py`](types/user.py): Flag, event, interface and struct types * [`utils.py`](types/utils.py): Functions for generating and fetching type objects * [`analysis/`](analysis): Subpackage for type checking and syntax verification logic * [`annotation.py`](analysis/annotation.py): Annotates statements and expressions with the appropriate type information diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cc08b3b95a..0ce966d1c6 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -764,7 +764,7 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: else: rtyp = get_exact_type_from_node(node.right) if isinstance(rtyp, FlagT): - # enum membership - `some_enum in other_enum` + # flag membership - `some_flag in other_flag` ltyp = rtyp else: # array membership - `x in my_list_variable` diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 21ca7a8d3f..f102b1f13b 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -376,7 +376,7 @@ def types_from_Name(self, node): # when this is a type, we want to lower it if isinstance(t, VyperType): # TYPE_T is used to handle cases where a type can occur in call or - # attribute conditions, like Enum.foo or MyStruct({...}) + # attribute conditions, like Flag.foo or MyStruct({...}) return [TYPE_T(t)] return [t.typ] diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 37de263319..4213535af7 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -389,7 +389,7 @@ def infer_kwarg_types(self, node): raise StructureException("Value is not callable", node) # dispatch into get_type_member if it's dereferenced, ex. - # MyEnum.FOO + # MyFlag.FOO def get_member(self, key, node): if hasattr(self.typedef, "get_type_member"): return self.typedef.get_type_member(key, node) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index f09a97d0b9..68150f924b 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -44,23 +44,23 @@ def __hash__(self): return hash(id(self)) -# note: enum behaves a lot like uint256, or uints in general. +# note: flag behaves a lot like uint256, or uints in general. class FlagT(_UserType): # this is a carveout because currently we allow dynamic arrays of - # enums, but not static arrays of enums + # flags, but not static arrays of flags _as_darray = True _is_prim_word = True _as_hashmap_key = True def __init__(self, name: str, members: dict) -> None: if len(members.keys()) > 256: - raise FlagDeclarationException("Enums are limited to 256 members!") + raise FlagDeclarationException("Flags are limited to 256 members!") super().__init__(members=None) self._id = name - self._enum_members = members + self._flag_members = members # use a VyperType for convenient access to the `get_member` function # also conveniently checks well-formedness of the members namespace @@ -74,8 +74,8 @@ def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": return self def __repr__(self): - arg_types = ",".join(repr(a) for a in self._enum_members) - return f"enum {self.name}({arg_types})" + arg_types = ",".join(repr(a) for a in self._flag_members) + return f"flag {self.name}({arg_types})" @property def abi_type(self): @@ -107,29 +107,29 @@ def validate_comparator(self, node): @classmethod def from_FlagDef(cls, base_node: vy_ast.FlagDef) -> "FlagT": """ - Generate an `Enum` object from a Vyper ast node. + Generate an `Flag` object from a Vyper ast node. Arguments --------- - base_node : EnumDef - Vyper ast node defining the enum + base_node : FlagDef + Vyper ast node defining the flag Returns ------- - Enum + Flag """ members: dict = {} if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): - raise FlagDeclarationException("Enum must have members", base_node) + raise FlagDeclarationException("Flag must have members", base_node) for i, node in enumerate(base_node.body): if not isinstance(node, vy_ast.Expr) or not isinstance(node.value, vy_ast.Name): - raise FlagDeclarationException("Invalid syntax for enum member", node) + raise FlagDeclarationException("Invalid syntax for flag member", node) member_name = node.value.id if member_name in members: raise FlagDeclarationException( - f"Enum member '{member_name}' has already been declared", node.value + f"Flag member '{member_name}' has already been declared", node.value ) members[member_name] = i From 1ca243bb25666d3f4801a308af1d4db2872bbc21 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 20 Feb 2024 16:19:29 -0800 Subject: [PATCH 200/201] feat[lang]: introduce floordiv operator (#2937) introduce floordiv operator to increase type safety of numeric operations (and to look more like python). floordiv is banned for decimals; "regular" div is banned for integers. we could maybe loosen this restriction in the future, e.g. int1 / int2 -> decimal, but for now just segregate into decimal and integer division operations. --------- Co-authored-by: tserg <8017125+tserg@users.noreply.github.com> --- .../market_maker/on_chain_market_maker.vy | 6 +-- .../safe_remote_purchase.vy | 2 +- examples/stock/company.vy | 2 +- examples/tokens/ERC4626.vy | 4 +- .../features/iteration/test_for_range.py | 2 +- .../codegen/features/test_assignment.py | 6 +-- .../codegen/types/numbers/test_decimals.py | 14 +++++- .../codegen/types/numbers/test_signed_ints.py | 29 +++++++++-- .../types/numbers/test_unsigned_ints.py | 29 +++++++++-- tests/functional/syntax/test_block.py | 4 +- tests/functional/syntax/test_constants.py | 2 +- tests/functional/syntax/test_public.py | 2 +- tests/unit/ast/nodes/test_fold_binop_int.py | 6 +-- .../analysis/test_potential_types.py | 12 ++++- vyper/ast/grammar.lark | 2 + vyper/ast/nodes.py | 44 +++++++++++------ vyper/ast/nodes.pyi | 1 + vyper/codegen/expr.py | 2 +- vyper/semantics/analysis/utils.py | 2 +- vyper/semantics/types/base.py | 2 +- vyper/semantics/types/primitives.py | 49 +++++++++++++++++-- 21 files changed, 173 insertions(+), 49 deletions(-) diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index 74b1307dc1..e1865d00c0 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -29,10 +29,10 @@ def initiate(token_addr: address, token_quantity: uint256): @external @payable def ethToTokens(): - fee: uint256 = msg.value / 500 + fee: uint256 = msg.value // 500 eth_in_purchase: uint256 = msg.value - fee new_total_eth: uint256 = self.totalEthQty + eth_in_purchase - new_total_tokens: uint256 = self.invariant / new_total_eth + new_total_tokens: uint256 = self.invariant // new_total_eth self.token_address.transfer(msg.sender, self.totalTokenQty - new_total_tokens) self.totalEthQty = new_total_eth self.totalTokenQty = new_total_tokens @@ -42,7 +42,7 @@ def ethToTokens(): def tokensToEth(sell_quantity: uint256): self.token_address.transferFrom(msg.sender, self, sell_quantity) new_total_tokens: uint256 = self.totalTokenQty + sell_quantity - new_total_eth: uint256 = self.invariant / new_total_tokens + new_total_eth: uint256 = self.invariant // new_total_tokens eth_to_send: uint256 = self.totalEthQty - new_total_eth send(msg.sender, eth_to_send) self.totalEthQty = new_total_eth diff --git a/examples/safe_remote_purchase/safe_remote_purchase.vy b/examples/safe_remote_purchase/safe_remote_purchase.vy index 91f0159a2d..5d94430de3 100644 --- a/examples/safe_remote_purchase/safe_remote_purchase.vy +++ b/examples/safe_remote_purchase/safe_remote_purchase.vy @@ -25,7 +25,7 @@ ended: public(bool) @payable def __init__(): assert (msg.value % 2) == 0 - self.value = msg.value / 2 # The seller initializes the contract by + self.value = msg.value // 2 # The seller initializes the contract by # posting a safety deposit of 2*value of the item up for sale. self.seller = msg.sender self.unlocked = True diff --git a/examples/stock/company.vy b/examples/stock/company.vy index 355432830d..7739959e92 100644 --- a/examples/stock/company.vy +++ b/examples/stock/company.vy @@ -53,7 +53,7 @@ def stockAvailable() -> uint256: def buyStock(): # Note: full amount is given to company (no fractional shares), # so be sure to send exact amount to buy shares - buy_order: uint256 = msg.value / self.price # rounds down + buy_order: uint256 = msg.value // self.price # rounds down # Check that there are enough shares to buy. assert self._stockAvailable() >= buy_order diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index 699b5edd42..acfaaab694 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -114,7 +114,7 @@ def _convertToAssets(shareAmount: uint256) -> uint256: # NOTE: `shareAmount = 0` is extremely rare case, not optimizing for it # NOTE: `totalAssets = 0` is extremely rare case, not optimizing for it - return shareAmount * self.asset.balanceOf(self) / totalSupply + return shareAmount * self.asset.balanceOf(self) // totalSupply @view @@ -132,7 +132,7 @@ def _convertToShares(assetAmount: uint256) -> uint256: return assetAmount # 1:1 price # NOTE: `assetAmount = 0` is extremely rare case, not optimizing for it - return assetAmount * totalSupply / totalAssets + return assetAmount * totalSupply // totalAssets @view diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index c661c46553..eedce46829 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -136,7 +136,7 @@ def reverse_digits(x: int128) -> int128: z: int128 = x for i: uint256 in range(6): dig[i] = z % 10 - z = z / 10 + z = z // 10 o: int128 = 0 for i: uint256 in range(6): o = o * 10 + dig[i] diff --git a/tests/functional/codegen/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py index a276eb7ea9..8f3270e4bc 100644 --- a/tests/functional/codegen/features/test_assignment.py +++ b/tests/functional/codegen/features/test_assignment.py @@ -197,7 +197,7 @@ def foo3(y: uint256) -> uint256: assert c.foo3(11) == 12 -def test_invalid_uin256_assignment(assert_compile_failed, get_contract_with_gas_estimation): +def test_invalid_uint256_assignment(assert_compile_failed, get_contract_with_gas_estimation): code = """ storx: uint256 @@ -210,14 +210,14 @@ def foo2() -> uint256: assert_compile_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) -def test_invalid_uin256_assignment_calculate_literals(get_contract_with_gas_estimation): +def test_invalid_uint256_assignment_calculate_literals(get_contract_with_gas_estimation): code = """ storx: uint256 @external def foo2() -> uint256: x: uint256 = 0 - x = 3 * 4 / 2 + 1 - 2 + x = 3 * 4 // 2 + 1 - 2 return x """ c = get_contract_with_gas_estimation(code) diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 72171dd4b5..425440fd4b 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -29,7 +29,7 @@ def test_decimal_override(): ) -@pytest.mark.parametrize("op", ["**", "&", "|", "^"]) +@pytest.mark.parametrize("op", ["//", "**", "&", "|", "^"]) def test_invalid_ops(op): code = f""" @external @@ -300,3 +300,15 @@ def foo(): """ with pytest.raises(OverflowException): compile_code(code) + + +def test_invalid_floordiv(): + code = """ +@external +def foo(): + a: decimal = 5.0 // 9.0 + """ + with pytest.raises(InvalidOperation) as e: + compile_code(code) + + assert e.value._hint == "did you mean `5.0 / 9.0`?" diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index e646a25354..e063f981ec 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -228,7 +228,7 @@ def num_sub() -> {typ}: "+": operator.add, "-": operator.sub, "*": operator.mul, - "/": evm_div, + "//": evm_div, "%": evm_mod, } @@ -263,7 +263,7 @@ def foo() -> {typ}: """ lo, hi = typ.ast_bounds - fns = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": evm_div, "%": evm_mod} + fns = {"+": operator.add, "-": operator.sub, "*": operator.mul, "//": evm_div, "%": evm_mod} fn = fns[op] c = get_contract(code_1) @@ -307,7 +307,7 @@ def foo() -> {typ}: in_bounds = lo <= expected <= hi # safediv and safemod disallow divisor == 0 - div_by_zero = y == 0 and op in ("/", "%") + div_by_zero = y == 0 and op in ("//", "%") ok = in_bounds and not div_by_zero @@ -417,6 +417,17 @@ def foo(a: {typ}) -> {typ}: c.foo(lo) +@pytest.mark.parametrize("typ", types) +@pytest.mark.parametrize("op", ["/"]) +def test_invalid_ops(get_contract, assert_compile_failed, typ, op): + code = f""" +@external +def foo(x: {typ}, y: {typ}) -> {typ}: + return x {op} y + """ + assert_compile_failed(lambda: get_contract(code), InvalidOperation) + + @pytest.mark.parametrize("typ", types) @pytest.mark.parametrize("op", ["not"]) def test_invalid_unary_ops(typ, op): @@ -437,3 +448,15 @@ def foo(): """ with pytest.raises(TypeMismatch): compile_code(code) + + +def test_invalid_div(): + code = """ +@external +def foo(): + a: int256 = -5 / 9 + """ + with pytest.raises(InvalidOperation) as e: + compile_code(code) + + assert e.value._hint == "did you mean `-5 // 9`?" diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 3f3fa32aba..42619a8bd5 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -83,7 +83,7 @@ def foo(x: {typ}) -> {typ}: "+": operator.add, "-": operator.sub, "*": operator.mul, - "/": evm_div, + "//": evm_div, "%": evm_mod, } @@ -140,7 +140,7 @@ def foo() -> {typ}: in_bounds = lo <= expected <= hi # safediv and safemod disallow divisor == 0 - div_by_zero = y == 0 and op in ("/", "%") + div_by_zero = y == 0 and op in ("//", "%") ok = in_bounds and not div_by_zero @@ -236,6 +236,17 @@ def test() -> {typ}: compile_code(code_template.format(typ=typ, val=val)) +@pytest.mark.parametrize("typ", types) +@pytest.mark.parametrize("op", ["/"]) +def test_invalid_ops(get_contract, assert_compile_failed, typ, op): + code = f""" +@external +def foo(x: {typ}, y: {typ}) -> {typ}: + return x {op} y + """ + assert_compile_failed(lambda: get_contract(code), InvalidOperation) + + @pytest.mark.parametrize("typ", types) @pytest.mark.parametrize("op", ["not", "-"]) def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op): @@ -252,7 +263,19 @@ def test_binop_nested_intermediate_overflow(): code = """ @external def foo(): - a: uint256 = 2**255 * 2 / 10 + a: uint256 = 2**255 * 2 // 10 """ with pytest.raises(OverflowException): compile_code(code) + + +def test_invalid_div(): + code = """ +@external +def foo(): + a: uint256 = 5 / 9 + """ + with pytest.raises(InvalidOperation) as e: + compile_code(code) + + assert e.value._hint == "did you mean `5 // 9`?" diff --git a/tests/functional/syntax/test_block.py b/tests/functional/syntax/test_block.py index aea39aa9c7..1cfdc87a5c 100644 --- a/tests/functional/syntax/test_block.py +++ b/tests/functional/syntax/test_block.py @@ -26,7 +26,7 @@ def foo() -> int128[2]: def foo() -> decimal: x: int128 = as_wei_value(5, "finney") y: int128 = block.timestamp + 50 - return x / y + return x // y """, ( """ @@ -106,7 +106,7 @@ def add_record(): def foo() -> uint256: x: uint256 = as_wei_value(5, "finney") y: uint256 = block.timestamp + 50 - block.timestamp - return x / y + return x // y """, """ @external diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 5eb9eefe25..db2accf359 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -240,7 +240,7 @@ def test1(): @external @view def test(): - for i: uint256 in range(CONST / 4): + for i: uint256 in range(CONST // 4): pass """, """ diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 217fcea998..636408875c 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -21,7 +21,7 @@ def __init__(): @external def foo() -> int128: - return self.x / self.y / self.z + return self.x // self.y // self.z """, # expansion of public user-defined struct """ diff --git a/tests/unit/ast/nodes/test_fold_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py index d9340927fe..9f7b2a71ff 100644 --- a/tests/unit/ast/nodes/test_fold_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -15,7 +15,7 @@ @example(left=1, right=-1) @example(left=-1, right=1) @example(left=-1, right=-1) -@pytest.mark.parametrize("op", "+-*/%") +@pytest.mark.parametrize("op", ["+", "-", "*", "//", "%"]) def test_binop_int128(get_contract, tx_failed, op, left, right): source = f""" @external @@ -45,7 +45,7 @@ def foo(a: int128, b: int128) -> int128: @pytest.mark.fuzzing @settings(max_examples=50) @given(left=st_uint64, right=st_uint64) -@pytest.mark.parametrize("op", "+-*/%") +@pytest.mark.parametrize("op", ["+", "-", "*", "//", "%"]) def test_binop_uint256(get_contract, tx_failed, op, left, right): source = f""" @external @@ -94,7 +94,7 @@ def foo(a: uint256, b: uint256) -> uint256: @settings(max_examples=50) @given( values=st.lists(st.integers(min_value=-256, max_value=256), min_size=2, max_size=10), - ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), + ops=st.lists(st.sampled_from(["+", "-", "*", "//", "%"]), min_size=11, max_size=11), ) def test_binop_nested(get_contract, tx_failed, values, ops): variables = "abcdefghij" diff --git a/tests/unit/semantics/analysis/test_potential_types.py b/tests/unit/semantics/analysis/test_potential_types.py index 74cdc9ae0f..dabb242c96 100644 --- a/tests/unit/semantics/analysis/test_potential_types.py +++ b/tests/unit/semantics/analysis/test_potential_types.py @@ -58,9 +58,17 @@ def test_attribute_not_member_type(build_node, namespace): get_possible_types_from_node(node) +@pytest.mark.parametrize("op", ["+", "-", "*", "//", "%"]) +@pytest.mark.parametrize("left,right", INTEGER_LITERALS) +def test_binop_ints(build_node, namespace, op, left, right): + node = build_node(f"{left}{op}{right}") + with namespace.enter_scope(): + get_possible_types_from_node(node) + + @pytest.mark.parametrize("op", "+-*/%") -@pytest.mark.parametrize("left,right", INTEGER_LITERALS + DECIMAL_LITERALS) -def test_binop(build_node, namespace, op, left, right): +@pytest.mark.parametrize("left,right", DECIMAL_LITERALS) +def test_binop_decimal(build_node, namespace, op, left, right): node = build_node(f"{left}{op}{right}") with namespace.enter_scope(): get_possible_types_from_node(node) diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 5ad465a1f1..15dbbda280 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -140,6 +140,7 @@ assign: (variable_access | multiple_assign | "(" multiple_assign ")" ) "=" _expr | "-" -> sub | "*" -> mul | "/" -> div + | "//" -> floordiv | "%" -> mod | "**" -> pow | "<<" -> shl @@ -274,6 +275,7 @@ _IN: "in" ?product: unary | product "*" unary -> mul | product "/" unary -> div + | product "//" unary -> floordiv | product "%" unary -> mod ?unary: power | "+" power -> uadd diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 3e15a28512..c38d934dd3 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,7 @@ VyperException, ZeroDivisionException, ) -from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code +from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code, evm_div NODE_BASE_ATTRIBUTES = ( "_children", @@ -1056,7 +1056,7 @@ def _op(self, left, right): class Div(Operator): __slots__ = () - _description = "division" + _description = "decimal division" _pretty = "/" def _op(self, left, right): @@ -1065,20 +1065,32 @@ def _op(self, left, right): if not right: raise ZeroDivisionException("Division by zero") - if isinstance(left, decimal.Decimal): - value = left / right - if value < 0: - # the EVM always truncates toward zero - value = -(-left / right) - # ensure that the result is truncated to MAX_DECIMAL_PLACES - return value.quantize( - decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN - ) - else: - value = left // right - if value < 0: - return -(-left // right) - return value + if not isinstance(left, decimal.Decimal): + raise UnfoldableNode("Cannot use `/` on non-decimals (did you mean `//`?)") + + value = left / right + if value < 0: + # the EVM always truncates toward zero + value = -(-left / right) + # ensure that the result is truncated to MAX_DECIMAL_PLACES + return value.quantize(decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN) + + +class FloorDiv(VyperNode): + __slots__ = () + _description = "integer division" + _pretty = "//" + + def _op(self, left, right): + # evaluate the operation using true division or floor division + assert type(left) is type(right) + if not right: + raise ZeroDivisionException("Division by zero") + + if not isinstance(left, int): + raise UnfoldableNode("Cannot use `//` on non-integers (did you mean `/`?)") + + return evm_div(left, right) class Mod(Operator): diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 342c84876a..45be123a3e 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -164,6 +164,7 @@ class Add(VyperNode): ... class Sub(VyperNode): ... class Mult(VyperNode): ... class Div(VyperNode): ... +class FloorDiv(VyperNode): ... class Mod(VyperNode): ... class Pow(VyperNode): ... class LShift(VyperNode): ... diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 6a444181c2..fc8dc9393c 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -430,7 +430,7 @@ def parse_BinOp(self): ret = arithmetic.safe_sub(x, y) elif isinstance(self.expr.op, vy_ast.Mult): ret = arithmetic.safe_mul(x, y) - elif isinstance(self.expr.op, vy_ast.Div): + elif isinstance(self.expr.op, (vy_ast.Div, vy_ast.FloorDiv)): ret = arithmetic.safe_div(x, y) elif isinstance(self.expr.op, vy_ast.Mod): ret = arithmetic.safe_mod(x, y) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f102b1f13b..12e50da22b 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -225,7 +225,7 @@ def types_from_BinOp(self, node): types_list = get_common_types(node.left, node.right) if ( - isinstance(node.op, (vy_ast.Div, vy_ast.Mod)) + isinstance(node.op, (vy_ast.Div, vy_ast.FloorDiv, vy_ast.Mod)) and isinstance(node.right, vy_ast.Num) and not node.right.value ): diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 4213535af7..38e082b298 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -183,7 +183,7 @@ def _raise_invalid_op( # TODO maybe make these AST classes inherit from "HasOperator" node: Union[vy_ast.UnaryOp, vy_ast.BinOp, vy_ast.AugAssign, vy_ast.Compare, vy_ast.BoolOp], ) -> None: - raise InvalidOperation(f"Cannot perform {node.op.description} on {self}", node) + raise InvalidOperation(f"Cannot perform {node.op.description} on {self}", node.op) def validate_comparator(self, node: vy_ast.Compare) -> None: """ diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index d11a9595a3..66efabd1db 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -6,7 +6,13 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABI_Bool, ABI_BytesM, ABI_FixedMxN, ABI_GIntM, ABIType -from vyper.exceptions import CompilerPanic, InvalidLiteral, InvalidOperation, OverflowException +from vyper.exceptions import ( + CompilerPanic, + InvalidLiteral, + InvalidOperation, + OverflowException, + VyperException, +) from vyper.utils import checksum_encode, int_bounds, is_checksum_encoded from .base import VyperType @@ -195,6 +201,24 @@ def validate_comparator(self, node: vy_ast.Compare) -> None: return +def _add_div_hint(node, e): + if isinstance(node.op, vy_ast.Div): + suggested = vy_ast.FloorDiv._pretty + elif isinstance(node.op, vy_ast.FloorDiv): + suggested = vy_ast.Div._pretty + else: + return e + + if isinstance(node, vy_ast.BinOp): + e._hint = f"did you mean `{node.left.node_source_code} " + e._hint += f"{suggested} {node.right.node_source_code}`?" + elif isinstance(node, vy_ast.AugAssign): + e._hint = f"did you mean `{node.target.node_source_code} " + e._hint += f"{suggested}= {node.value.node_source_code}`?" + + return e + + class IntegerT(NumericT): """ General integer type. All signed and unsigned ints from uint8 thru int256 @@ -228,11 +252,17 @@ def ast_bounds(self) -> Tuple[int, int]: @cached_property def _invalid_ops(self): - invalid_ops = (vy_ast.Not,) + invalid_ops = (vy_ast.Not, vy_ast.Div) if not self.is_signed: return invalid_ops + (vy_ast.USub,) return invalid_ops + def validate_numeric_op(self, node) -> None: + try: + super().validate_numeric_op(node) + except VyperException as e: + raise _add_div_hint(node, e) from None + @classmethod # TODO maybe cache these three classmethods def signeds(cls) -> Tuple["IntegerT", ...]: @@ -281,13 +311,26 @@ class DecimalT(NumericT): _decimal_places = 10 # TODO generalize _id = "decimal" _is_signed = True - _invalid_ops = (vy_ast.Pow, vy_ast.BitAnd, vy_ast.BitOr, vy_ast.BitXor, vy_ast.Not) + _invalid_ops = ( + vy_ast.Pow, + vy_ast.FloorDiv, + vy_ast.BitAnd, + vy_ast.BitOr, + vy_ast.BitXor, + vy_ast.Not, + ) _valid_literal = (vy_ast.Decimal,) _equality_attrs = ("_bits", "_decimal_places") ast_type = Decimal + def validate_numeric_op(self, node) -> None: + try: + super().validate_numeric_op(node) + except VyperException as e: + raise _add_div_hint(node, e) from None + @cached_property def abi_type(self) -> ABIType: return ABI_FixedMxN(self._bits, self._decimal_places, self._is_signed) From 015cf81408cbff22f8bc60fdc506f8e53bfdcca8 Mon Sep 17 00:00:00 2001 From: fiddyresearch <11488427+bout3fiddy@users.noreply.github.com> Date: Wed, 21 Feb 2024 01:45:07 +0100 Subject: [PATCH 201/201] feat: change default `code_offset` in `create_from_blueprint` (#3454) change default `code_offset` to 3 for `create_from_blueprint`, per ERC5202 standard --------- Co-authored-by: Charles Cooper --- docs/built-in-functions.rst | 10 ++- tests/conftest.py | 5 +- .../builtins/codegen/test_create_functions.py | 68 +++++++++++++++++-- vyper/builtins/functions.py | 2 +- vyper/compiler/phases.py | 4 +- vyper/utils.py | 1 + 6 files changed, 78 insertions(+), 12 deletions(-) diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index afb64e71ca..f2f6632906 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -184,7 +184,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui The implementation of ``create_copy_of`` assumes that the code at ``target`` is smaller than 16MB. While this is much larger than the EIP-170 constraint of 24KB, it is a conservative size limit intended to future-proof deployer contracts in case the EIP-170 constraint is lifted. If the code at ``target`` is larger than 16MB, the behavior of ``create_copy_of`` is undefined. -.. py:function:: create_from_blueprint(target: address, *args, value: uint256 = 0, raw_args: bool = False, code_offset: int = 0, [, salt: bytes32]) -> address +.. py:function:: create_from_blueprint(target: address, *args, value: uint256 = 0, raw_args: bool = False, code_offset: int = 3, [, salt: bytes32]) -> address Copy the code of ``target`` into memory and execute it as initcode. In other words, this operation interprets the code at ``target`` not as regular runtime code, but directly as initcode. The ``*args`` are interpreted as constructor arguments, and are ABI-encoded and included when executing the initcode. @@ -192,7 +192,7 @@ Vyper has three built-ins for contract creation; all three contract creation bui * ``*args``: Constructor arguments to forward to the initcode. * ``value``: The wei value to send to the new contract address (Optional, default 0) * ``raw_args``: If ``True``, ``*args`` must be a single ``Bytes[...]`` argument, which will be interpreted as a raw bytes buffer to forward to the create operation (which is useful for instance, if pre- ABI-encoded data is passed in from elsewhere). (Optional, default ``False``) - * ``code_offset``: The offset to start the ``EXTCODECOPY`` from (Optional, default 0) + * ``code_offset``: The offset to start the ``EXTCODECOPY`` from (Optional, default 3) * ``salt``: A ``bytes32`` value utilized by the deterministic ``CREATE2`` opcode (Optional, if not supplied, ``CREATE`` is used) Returns the address of the created contract. If the create operation fails (for instance, in the case of a ``CREATE2`` collision), execution will revert. If ``code_offset >= target.codesize`` (ex. if there is no code at ``target``), execution will revert. @@ -209,9 +209,13 @@ Vyper has three built-ins for contract creation; all three contract creation bui To properly deploy a blueprint contract, special deploy bytecode must be used. The output of ``vyper -f blueprint_bytecode`` will produce bytecode which deploys an ERC-5202 compatible blueprint. +.. note:: + + Prior to Vyper version ``0.4.0``, the ``code_offset`` parameter defaulted to ``0``. + .. warning:: - It is recommended to deploy blueprints with the ERC-5202 preamble ``0xFE7100`` to guard them from being called as regular contracts. This is particularly important for factories where the constructor has side effects (including ``SELFDESTRUCT``!), as those could get executed by *anybody* calling the blueprint contract directly. The ``code_offset=`` kwarg is provided to enable this pattern: + It is recommended to deploy blueprints with an `ERC-5202 `_ preamble like ``0xFE7100`` to guard them from being called as regular contracts. This is particularly important for factories where the constructor has side effects (including ``SELFDESTRUCT``!), as those could get executed by *anybody* calling the blueprint contract directly. The ``code_offset=`` kwarg is provided (and defaults to the ERC-5202 default of 3) to enable this pattern: .. code-block:: vyper diff --git a/tests/conftest.py b/tests/conftest.py index 6eb34a3e0a..201f723efa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.ir import compile_ir, optimizer +from vyper.utils import ERC5202_PREFIX # Import the base fixtures pytest_plugins = ["tests.fixtures.memorymock"] @@ -377,7 +378,9 @@ def get_contract_module(source_code, *args, **kwargs): return get_contract_module -def _deploy_blueprint_for(w3, source_code, optimize, output_formats, initcode_prefix=b"", **kwargs): +def _deploy_blueprint_for( + w3, source_code, optimize, output_formats, initcode_prefix=ERC5202_PREFIX, **kwargs +): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) settings.optimize = optimize diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index 0aa718157c..75b10e47b6 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -6,7 +6,7 @@ import vyper.ir.compile_ir as compile_ir from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256 +from vyper.utils import EIP_170_LIMIT, ERC5202_PREFIX, checksum_encode, keccak256 # initcode used by create_minimal_proxy_to @@ -148,7 +148,7 @@ def test(_salt: bytes32) -> address: # test blueprints with various prefixes - 0xfe would block calls to the blueprint # contract, and 0xfe7100 is ERC5202 magic -@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"]) +@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", ERC5202_PREFIX]) def test_create_from_blueprint( get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, tx_failed, blueprint_prefix ): @@ -208,6 +208,66 @@ def test2(target: address, salt: bytes32): d.test2(f.address, salt) +# test blueprints with 0xfe7100 prefix, which is the EIP 5202 standard. +# code offset by default should be 3 here. +def test_create_from_blueprint_default_offset( + get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, tx_failed +): + code = """ +@external +def foo() -> uint256: + return 123 + """ + + deployer_code = """ +created_address: public(address) + +@external +def test(target: address): + self.created_address = create_from_blueprint(target) + +@external +def test2(target: address, salt: bytes32): + self.created_address = create_from_blueprint(target, salt=salt) + """ + + # deploy a foo so we can compare its bytecode with factory deployed version + foo_contract = get_contract(code) + expected_runtime_code = w3.eth.get_code(foo_contract.address) + + f, FooContract = deploy_blueprint_for(code) + + d = get_contract(deployer_code) + + d.test(f.address, transact={}) + + test = FooContract(d.created_address()) + assert w3.eth.get_code(test.address) == expected_runtime_code + assert test.foo() == 123 + + # extcodesize check + zero_address = "0x" + "00" * 20 + with tx_failed(): + d.test(zero_address) + + # now same thing but with create2 + salt = keccak(b"vyper") + d.test2(f.address, salt, transact={}) + + test = FooContract(d.created_address()) + assert w3.eth.get_code(test.address) == expected_runtime_code + assert test.foo() == 123 + + # check if the create2 address matches our offchain calculation + initcode = w3.eth.get_code(f.address) + initcode = initcode[len(ERC5202_PREFIX) :] # strip the prefix + assert HexBytes(test.address) == create2_address_of(d.address, salt, initcode) + + # can't collide addresses + with tx_failed(): + d.test2(f.address, salt) + + def test_create_from_blueprint_bad_code_offset( get_contract, get_contract_from_ir, deploy_blueprint_for, w3, tx_failed ): @@ -238,8 +298,6 @@ def test(code_ofst: uint256) -> address: tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0} tx_hash = deploy_transaction.transact(tx_info) blueprint_address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"] - blueprint_code = w3.eth.get_code(blueprint_address) - print("BLUEPRINT CODE:", blueprint_code) d = get_contract(deployer_code, blueprint_address) @@ -320,7 +378,7 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): d = get_contract(deployer_code) - initcode = w3.eth.get_code(f.address) + initcode = w3.eth.get_code(f.address)[3:] d.test(f.address, FOO, BAR, transact={}) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7575f4d77e..de0158aba4 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1828,7 +1828,7 @@ class CreateFromBlueprint(_CreateBase): "value": KwargSettings(UINT256_T, zero_value), "salt": KwargSettings(BYTES32_T, empty_value), "raw_args": KwargSettings(BoolT(), False, require_literal=True), - "code_offset": KwargSettings(UINT256_T, zero_value), + "code_offset": KwargSettings(UINT256_T, IRnode.from_list(3, typ=UINT256_T)), } _has_varargs = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index f7eccdf214..af94011633 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -16,6 +16,7 @@ from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout +from vyper.utils import ERC5202_PREFIX from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") @@ -228,8 +229,7 @@ def bytecode_runtime(self) -> bytes: @cached_property def blueprint_bytecode(self) -> bytes: - blueprint_preamble = b"\xFE\x71\x00" # ERC5202 preamble - blueprint_bytecode = blueprint_preamble + self.bytecode + blueprint_bytecode = ERC5202_PREFIX + self.bytecode # the length of the deployed code in bytes len_bytes = len(blueprint_bytecode).to_bytes(2, "big") diff --git a/vyper/utils.py b/vyper/utils.py index ab4d789aa4..26869a6def 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -398,6 +398,7 @@ class SizeLimits: EIP_170_LIMIT = 0x6000 # 24kb +ERC5202_PREFIX = b"\xFE\x71\x00" # default prefix from ERC-5202 SHA3_BASE = 30 SHA3_PER_WORD = 6