From 8aae7cd6b86c15978bdfa16d5a6e3ca273121107 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 2 Oct 2023 19:20:19 -0700 Subject: [PATCH 01/10] feat: disallow invalid pragmas (#3634) this commit prevents typos (ex. `# pragma evm-versionn ...`) from getting unexpectedly ignored. it also fixes an issue with the `evm_version` pragma getting properly propagated into `CompilerData.settings`. it also fixes a misnamed test function, and adds some unit tests to check that `evm_version` gets properly propagated into `CompilerData.settings`. commit summary: * fix `evm_version` passing into compiler data * fix `compile_code()` * add tests for `CompilerData.settings` * add tests for invalid pragmas --- tests/ast/test_pre_parser.py | 70 ++++++++++++++++++++++++++++++++---- vyper/ast/pre_parser.py | 7 ++-- vyper/compiler/README.md | 6 ++-- vyper/compiler/__init__.py | 22 ++++++------ vyper/compiler/phases.py | 5 +++ 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/tests/ast/test_pre_parser.py b/tests/ast/test_pre_parser.py index 5427532c16..3d072674f6 100644 --- a/tests/ast/test_pre_parser.py +++ b/tests/ast/test_pre_parser.py @@ -1,8 +1,9 @@ import pytest from vyper.ast.pre_parser import pre_parse, validate_version_pragma +from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings -from vyper.exceptions import VersionException +from vyper.exceptions import StructureException, VersionException SRC_LINE = (1, 0) # Dummy source line COMPILER_VERSION = "0.1.1" @@ -96,43 +97,50 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): """ """, Settings(), + Settings(optimize=OptimizationLevel.GAS), ), ( """ #pragma optimize codesize """, Settings(optimize=OptimizationLevel.CODESIZE), + None, ), ( """ #pragma optimize none """, Settings(optimize=OptimizationLevel.NONE), + None, ), ( """ #pragma optimize gas """, Settings(optimize=OptimizationLevel.GAS), + None, ), ( """ #pragma version 0.3.10 """, Settings(compiler_version="0.3.10"), + Settings(optimize=OptimizationLevel.GAS), ), ( """ #pragma evm-version shanghai """, Settings(evm_version="shanghai"), + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), ), ( """ #pragma optimize codesize #pragma evm-version shanghai """, - Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), + Settings(evm_version="shanghai", optimize=OptimizationLevel.CODESIZE), + None, ), ( """ @@ -140,6 +148,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma evm-version shanghai """, Settings(evm_version="shanghai", compiler_version="0.3.10"), + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), ), ( """ @@ -147,6 +156,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma optimize gas """, Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS), + Settings(optimize=OptimizationLevel.GAS), ), ( """ @@ -155,11 +165,59 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma optimize gas """, Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS, evm_version="shanghai"), + Settings(optimize=OptimizationLevel.GAS, evm_version="shanghai"), ), ] -@pytest.mark.parametrize("code, expected_pragmas", pragma_examples) -def parse_pragmas(code, expected_pragmas): - pragmas, _, _ = pre_parse(code) - assert pragmas == expected_pragmas +@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) +def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): + mock_version("0.3.10") + settings, _, _ = pre_parse(code) + + assert settings == pre_parse_settings + + compiler_data = CompilerData(code) + + # check what happens after CompilerData constructor + if compiler_data_settings is None: + # None is sentinel here meaning that nothing changed + compiler_data_settings = pre_parse_settings + + assert compiler_data.settings == compiler_data_settings + + +invalid_pragmas = [ + # evm-versionnn + """ +# pragma evm-versionnn cancun + """, + # bad fork name + """ +# pragma evm-version cancunn + """, + # oppptimize + """ +# pragma oppptimize codesize + """, + # ggas + """ +# pragma optimize ggas + """, + # double specified + """ +# pragma optimize gas +# pragma optimize codesize + """, + # double specified + """ +# pragma evm-version cancun +# pragma evm-version shanghai + """, +] + + +@pytest.mark.parametrize("code", invalid_pragmas) +def test_invalid_pragma(code): + with pytest.raises(StructureException): + pre_parse(code) diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 0ead889787..9d96efea5e 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -111,7 +111,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: validate_version_pragma(compiler_version, start) settings.compiler_version = compiler_version - if pragma.startswith("optimize "): + elif pragma.startswith("optimize "): if settings.optimize is not None: raise StructureException("pragma optimize specified twice!", start) try: @@ -119,7 +119,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: settings.optimize = OptimizationLevel.from_string(mode) except ValueError: raise StructureException(f"Invalid optimization mode `{mode}`", start) - if pragma.startswith("evm-version "): + elif pragma.startswith("evm-version "): if settings.evm_version is not None: raise StructureException("pragma evm-version specified twice!", start) evm_version = pragma.removeprefix("evm-version").strip() @@ -127,6 +127,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: raise StructureException("Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version + else: + raise StructureException(f"Unknown pragma `{pragma.split()[0]}`") + if typ == NAME and string in ("class", "yield"): raise SyntaxException( f"The `{string}` keyword is not allowed. ", code, start[0], start[1] diff --git a/vyper/compiler/README.md b/vyper/compiler/README.md index d6b55fdd82..eb70750a2b 100644 --- a/vyper/compiler/README.md +++ b/vyper/compiler/README.md @@ -51,11 +51,9 @@ for specific implementation details. [`vyper.compiler.compile_codes`](__init__.py) is the main user-facing function for generating compiler output from Vyper source. The process is as follows: -1. The `@evm_wrapper` decorator sets the target EVM version in -[`opcodes.py`](../evm/opcodes.py). -2. A [`CompilerData`](phases.py) object is created for each contract to be compiled. +1. A [`CompilerData`](phases.py) object is created for each contract to be compiled. This object uses `@property` methods to trigger phases of the compiler as required. -3. Functions in [`output.py`](output.py) generate the requested outputs from the +2. Functions in [`output.py`](output.py) generate the requested outputs from the compiler data. ## Design diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0b3c0d8191..b1c4201361 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -120,17 +120,17 @@ def compile_codes( # make IR output the same between runs codegen.reset_names() - with anchor_evm_version(settings.evm_version): - compiler_data = CompilerData( - source_code, - contract_name, - interfaces, - source_id, - settings, - storage_layout_override, - show_gas_estimates, - no_bytecode_metadata, - ) + compiler_data = CompilerData( + source_code, + contract_name, + interfaces, + source_id, + settings, + storage_layout_override, + show_gas_estimates, + no_bytecode_metadata, + ) + with anchor_evm_version(compiler_data.settings.evm_version): for output_format in output_formats[contract_name]: if output_format not in OUTPUT_FORMATS: raise ValueError(f"Unsupported format type {repr(output_format)}") diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index a1c7342320..5ddf071caf 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -88,9 +88,12 @@ def __init__( self.no_bytecode_metadata = no_bytecode_metadata self.settings = settings or Settings() + _ = self._generate_ast # force settings to be calculated + @cached_property def _generate_ast(self): settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name) + # validate the compiler settings # XXX: this is a bit ugly, clean up later if settings.evm_version is not None: @@ -117,6 +120,8 @@ def _generate_ast(self): if self.settings.optimize is None: self.settings.optimize = OptimizationLevel.default() + # note self.settings.compiler_version is erased here as it is + # not used after pre-parsing return ast @cached_property From e9c16e40dd11ba21ba817ff0da78eaeab744fd39 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 3 Oct 2023 16:36:58 -0700 Subject: [PATCH 02/10] fix: block `mload` merging when src and dst overlap (#3635) this commit fixes an optimization bug when the target architecture has the `mcopy` instruction (i.e. `cancun` or later). the bug was introduced in 5dc3ac7. specifically, the `merge_mload` step can incorrectly merge `mload`/`mstore` sequences (into `mcopy`) when the source and destination buffers overlap, and the destination buffer is "ahead of" (i.e. greater than) the source buffer. this commit fixes the issue by blocking the optimization in these cases, and adds unit and functional tests demonstrating the correct behavior. --------- Co-authored-by: Robert Chen --- tests/compiler/ir/test_optimize_ir.py | 107 +++++++++++++++++++++++ tests/parser/features/test_assignment.py | 60 +++++++++++++ vyper/ir/optimizer.py | 9 +- 3 files changed, 174 insertions(+), 2 deletions(-) diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py index 1466166501..cb46ba238d 100644 --- a/tests/compiler/ir/test_optimize_ir.py +++ b/tests/compiler/ir/test_optimize_ir.py @@ -1,9 +1,13 @@ import pytest from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version from vyper.exceptions import StaticAssertionException from vyper.ir import optimizer +POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]} + + optimize_list = [ (["eq", 1, 2], [0]), (["lt", 1, 2], [1]), @@ -272,3 +276,106 @@ def test_operator_set_values(): assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"} assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"} assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"} + + +mload_merge_list = [ + # copy "backward" with no overlap between src and dst buffers, + # OK to become mcopy + ( + ["seq", ["mstore", 32, ["mload", 128]], ["mstore", 64, ["mload", 160]]], + ["mcopy", 32, 128, 64], + ), + # copy with overlap "backwards", OK to become mcopy + (["seq", ["mstore", 32, ["mload", 64]], ["mstore", 64, ["mload", 96]]], ["mcopy", 32, 64, 64]), + # "stationary" overlap (i.e. a no-op mcopy), OK to become mcopy + (["seq", ["mstore", 32, ["mload", 32]], ["mstore", 64, ["mload", 64]]], ["mcopy", 32, 32, 64]), + # copy "forward" with no overlap, OK to become mcopy + (["seq", ["mstore", 64, ["mload", 0]], ["mstore", 96, ["mload", 32]]], ["mcopy", 64, 0, 64]), + # copy "forwards" with overlap by one word, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 32]], ["mstore", 96, ["mload", 64]]], None), + # check "forward" overlap by one byte, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 1]], ["mstore", 96, ["mload", 33]]], None), + # check "forward" overlap by one byte again, must NOT become mcopy + (["seq", ["mstore", 63, ["mload", 0]], ["mstore", 95, ["mload", 32]]], None), + # copy 3 words with partial overlap "forwards", partially becomes mcopy + # (2 words are mcopied and 1 word is mload/mstored + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mstore", 160, ["mload", 96]]], + ), + # copy 4 words with partial overlap "forwards", becomes 2 mcopies of 2 words each + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ["mstore", 192, ["mload", 128]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mcopy", 160, 96, 64]], + ), + # copy 4 words with 1 byte of overlap, must NOT become mcopy + ( + [ + "seq", + ["mstore", 96, ["mload", 33]], + ["mstore", 128, ["mload", 65]], + ["mstore", 160, ["mload", 97]], + ["mstore", 192, ["mload", 129]], + ], + None, + ), + # Ensure only sequential mstore + mload sequences are optimized + ( + [ + "seq", + ["mstore", 0, ["mload", 32]], + ["sstore", 0, ["calldataload", 4]], + ["mstore", 32, ["mload", 64]], + ], + None, + ), + # not-word aligned optimizations (not overlap) + (["seq", ["mstore", 0, ["mload", 1]], ["mstore", 32, ["mload", 33]]], ["mcopy", 0, 1, 64]), + # not-word aligned optimizations (overlap) + (["seq", ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], None), + # not-word aligned optimizations (overlap and not-overlap) + ( + [ + "seq", + ["mstore", 0, ["mload", 1]], + ["mstore", 32, ["mload", 33]], + ["mstore", 1, ["mload", 0]], + ["mstore", 33, ["mload", 32]], + ], + ["seq", ["mcopy", 0, 1, 64], ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], + ), + # overflow test + ( + [ + "seq", + ["mstore", 2**256 - 1 - 31 - 32, ["mload", 0]], + ["mstore", 2**256 - 1 - 31, ["mload", 32]], + ], + ["mcopy", 2**256 - 1 - 31 - 32, 0, 64], + ), +] + + +@pytest.mark.parametrize("ir", mload_merge_list) +@pytest.mark.parametrize("evm_version", list(POST_CANCUN.keys())) +def test_mload_merge(ir, evm_version): + with anchor_evm_version(evm_version): + optimized = optimizer.optimize(IRnode.from_list(ir[0])) + if ir[1] is None: + # no-op, assert optimizer does nothing + expected = IRnode.from_list(ir[0]) + else: + expected = IRnode.from_list(ir[1]) + + assert optimized == expected diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index e550f60541..35b008a8ba 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -442,3 +442,63 @@ def bug(p: Point) -> Point: """ c = get_contract(code) assert c.bug((1, 2)) == (2, 1) + + +mload_merge_codes = [ + ( + """ +@external +def foo() -> uint256[4]: + # copy "backwards" + xs: uint256[4] = [1, 2, 3, 4] + +# dst < src + xs[0] = xs[1] + xs[1] = xs[2] + xs[2] = xs[3] + + return xs + """, + [2, 3, 4, 4], + ), + ( + """ +@external +def foo() -> uint256[4]: + # copy "forwards" + xs: uint256[4] = [1, 2, 3, 4] + +# src < dst + xs[1] = xs[0] + xs[2] = xs[1] + xs[3] = xs[2] + + return xs + """, + [1, 1, 1, 1], + ), + ( + """ +@external +def foo() -> uint256[5]: + # partial "forward" copy + xs: uint256[5] = [1, 2, 3, 4, 5] + +# src < dst + xs[2] = xs[0] + xs[3] = xs[1] + xs[4] = xs[2] + + return xs + """, + [1, 2, 1, 2, 1], + ), +] + + +# functional test that mload merging does not occur when source and dest +# buffers overlap. (note: mload merging only applies after cancun) +@pytest.mark.parametrize("code,expected_result", mload_merge_codes) +def test_mcopy_overlap(get_contract, code, expected_result): + c = get_contract(code) + assert c.foo() == expected_result diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 08c2168381..8df4bbac2d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -662,10 +662,10 @@ def _rewrite_mstore_dload(argz): def _merge_mload(argz): if not version_check(begin="cancun"): return False - return _merge_load(argz, "mload", "mcopy") + return _merge_load(argz, "mload", "mcopy", allow_overlap=False) -def _merge_load(argz, _LOAD, _COPY): +def _merge_load(argz, _LOAD, _COPY, allow_overlap=True): # look for sequential operations copying from X to Y # and merge them into a single copy operation changed = False @@ -689,9 +689,14 @@ def _merge_load(argz, _LOAD, _COPY): initial_dst_offset = dst_offset initial_src_offset = src_offset idx = i + + # dst and src overlap, discontinue the optimization + has_overlap = initial_src_offset < initial_dst_offset < src_offset + 32 + if ( initial_dst_offset + total_length == dst_offset and initial_src_offset + total_length == src_offset + and (allow_overlap or not has_overlap) ): mstore_nodes.append(ir_node) total_length += 32 From 9136169468f317a53b4e7448389aa315f90b95ba Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 4 Oct 2023 06:16:44 -0700 Subject: [PATCH 03/10] docs: v0.3.10 release (#3629) --- docs/release-notes.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index da86c5c0ce..64199bc860 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -14,14 +14,10 @@ Release Notes for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g -.. - v0.3.10 ("Black Adder") - *********************** - -v0.3.10rc1 -********** +v0.3.10 ("Black Adder") +*********************** -Date released: 2023-09-06 +Date released: 2023-10-04 ========================= v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). @@ -32,6 +28,7 @@ Breaking changes: - add runtime code layout to initcode (`#3584 `_) - drop evm versions through istanbul (`#3470 `_) - remove vyper signature from runtime (`#3471 `_) +- only allow valid identifiers to be nonreentrant keys (`#3605 `_) Non-breaking changes and improvements: -------------------------------------- @@ -46,12 +43,15 @@ Notable fixes: - fix ``ecrecover()`` behavior when signature is invalid (`GHSA-f5x6-7qgp-jhf3 `_, `#3586 `_) - fix: order of evaluation for some builtins (`#3583 `_, `#3587 `_) +- fix: memory allocation in certain builtins using ``msize`` (`#3610 `_) +- fix: ``_abi_decode()`` input validation in certain complex expressions (`#3626 `_) - fix: pycryptodome for arm builds (`#3485 `_) - let params of internal functions be mutable (`#3473 `_) - typechecking of folded builtins in (`#3490 `_) - update tload/tstore opcodes per latest 1153 EIP spec (`#3484 `_) - fix: raw_call type when max_outsize=0 is set (`#3572 `_) - fix: implements check for indexed event arguments (`#3570 `_) +- fix: type-checking for ``_abi_decode()`` arguments (`#3626 `_) Other docs updates, chores and fixes: ------------------------------------- From b8b4610a46379558367ba60f94ac2813eec056d4 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Wed, 4 Oct 2023 20:25:09 +0200 Subject: [PATCH 04/10] chore: update `FUNDING.yml` to point directly at wallet (#3636) Add link to the Vyper multisig address. --- FUNDING.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FUNDING.yml b/FUNDING.yml index 81e82160d0..efb9eb01b7 100644 --- a/FUNDING.yml +++ b/FUNDING.yml @@ -1 +1 @@ -custom: https://gitcoin.co/grants/200/vyper-smart-contract-language-2 +custom: https://etherscan.io/address/0x70CCBE10F980d80b7eBaab7D2E3A73e87D67B775 From 5ca0cbfa4064a5577347bca8377bc9bea1cebca5 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 4 Oct 2023 12:17:36 -0700 Subject: [PATCH 05/10] docs: fix nit in v0.3.10 release notes (#3638) --- docs/release-notes.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes.rst b/docs/release-notes.rst index 64199bc860..3db11dc451 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -20,7 +20,7 @@ v0.3.10 ("Black Adder") Date released: 2023-10-04 ========================= -v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). +v0.3.10 is a performance focused release that additionally ships numerous bugfixes. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). Breaking changes: ----------------- From 435754dd1db3e1b0b4608569e003f060e6e9eb40 Mon Sep 17 00:00:00 2001 From: sudo rm -rf --no-preserve-root / Date: Thu, 5 Oct 2023 16:54:29 +0200 Subject: [PATCH 06/10] docs: add note on `pragma` parsing (#3640) * Fix broken link in `structure-of-a-contract.rst` * Add note on `pragma` version parsing --- docs/structure-of-a-contract.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index d2c5d48d96..3861bf4380 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -17,7 +17,7 @@ Vyper supports several source code directives to control compiler modes and help Version Pragma -------------- -The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers _`. +The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers `_. As of 0.3.10, the recommended way to specify the version pragma is as follows: @@ -25,6 +25,10 @@ As of 0.3.10, the recommended way to specify the version pragma is as follows: #pragma version ^0.3.0 +.. note:: + + Both pragma directive versions ``#pragma`` and ``# pragma`` are supported. + The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version: .. code-block:: python From 68da04b2e9e010c2e4da288a80eeeb9c8e076025 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 5 Oct 2023 09:04:18 -0700 Subject: [PATCH 07/10] fix: block memory allocation overflow (#3639) this fixes potential overflow bugs in pointer calculation by blocking memory allocation above a certain size. the size limit is set at `2**64`, which is the size of addressable memory on physical machines. practically, for EVM use cases, we could limit at a much smaller number (like `2**24`), but we want to allow for "exotic" targets which may allow much more addressable memory. --- vyper/codegen/memory_allocator.py | 12 +++++++++++- vyper/exceptions.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/vyper/codegen/memory_allocator.py b/vyper/codegen/memory_allocator.py index 582d4b9c54..b5e1212917 100644 --- a/vyper/codegen/memory_allocator.py +++ b/vyper/codegen/memory_allocator.py @@ -1,6 +1,6 @@ from typing import List -from vyper.exceptions import CompilerPanic +from vyper.exceptions import CompilerPanic, MemoryAllocationException from vyper.utils import MemoryPositions @@ -46,6 +46,8 @@ class MemoryAllocator: next_mem: int + _ALLOCATION_LIMIT: int = 2**64 + def __init__(self, start_position: int = MemoryPositions.RESERVED_MEMORY): """ Initializer. @@ -110,6 +112,14 @@ def _expand_memory(self, size: int) -> int: before_value = self.next_mem self.next_mem += size self.size_of_mem = max(self.size_of_mem, self.next_mem) + + if self.size_of_mem >= self._ALLOCATION_LIMIT: + # this should not be caught + raise MemoryAllocationException( + f"Tried to allocate {self.size_of_mem} bytes! " + f"(limit is {self._ALLOCATION_LIMIT} (2**64) bytes)" + ) + return before_value def deallocate_memory(self, pos: int, size: int) -> None: diff --git a/vyper/exceptions.py b/vyper/exceptions.py index defca7cc53..8b2020285a 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -269,6 +269,10 @@ class StorageLayoutException(VyperException): """Invalid slot for the storage layout overrides""" +class MemoryAllocationException(VyperException): + """Tried to allocate too much memory""" + + class JSONError(Exception): """Invalid compiler input JSON.""" From 74a8e0254461119af9a5d504f326877d7aed4134 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 17 Oct 2023 09:23:40 -0700 Subject: [PATCH 08/10] chore: reorder compilation of branches in stmt.py (#3603) the if and else branches were being allocated out-of-source-order. this commit switches the order of compilation of the if and else branches to be in source order. this is a hygienic fix, right now the only thing that should be affected is the memory allocator (but if more side effects are ever introduced in codegen, the existing code might compile the side effects out of order). --- vyper/codegen/stmt.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index c2951986c8..254cad32e6 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -91,17 +91,15 @@ def parse_Assign(self): return IRnode.from_list(ret) def parse_If(self): + with self.context.block_scope(): + test_expr = Expr.parse_value_expr(self.stmt.test, self.context) + body = ["if", test_expr, parse_body(self.stmt.body, self.context)] + if self.stmt.orelse: with self.context.block_scope(): - add_on = [parse_body(self.stmt.orelse, self.context)] - else: - add_on = [] + body.extend([parse_body(self.stmt.orelse, self.context)]) - with self.context.block_scope(): - test_expr = Expr.parse_value_expr(self.stmt.test, self.context) - body = ["if", test_expr, parse_body(self.stmt.body, self.context)] + add_on - ir_node = IRnode.from_list(body) - return ir_node + return IRnode.from_list(body) def parse_Log(self): event = self.stmt._metadata["type"] From 5482bbcbed22b856bec6e57c06aeb7e0bee9a1ab Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 18 Oct 2023 00:26:56 +0800 Subject: [PATCH 09/10] feat: remove builtin constants (#3350) Builtin constants like `MAX_INT128`, `MIN_INT128`, `MIN_DECIMAL`, `MAX_DECIMAL` and `MAX_UINT256` have been deprecated since v0.3.4 with the introduction of the `min_value` and `max_value` builtins, and further back for `ZERO_ADDRESS` and `EMPTY_BYTES32` with the `empty` builtin. This PR removes them from the language entirely, and will be a breaking change. --- tests/ast/test_folding.py | 43 ------------------- .../features/decorators/test_private.py | 2 +- .../test_external_contract_calls.py | 14 +++--- .../features/iteration/test_for_in_list.py | 2 +- tests/parser/features/test_assignment.py | 4 +- tests/parser/features/test_memory_dealloc.py | 2 +- tests/parser/functions/test_abi_decode.py | 4 +- tests/parser/functions/test_convert.py | 18 -------- .../parser/functions/test_default_function.py | 2 +- tests/parser/functions/test_empty.py | 20 ++++----- tests/parser/functions/test_return_tuple.py | 4 +- tests/parser/integration/test_escrow.py | 4 +- tests/parser/syntax/test_bool.py | 4 +- tests/parser/syntax/test_interfaces.py | 4 +- tests/parser/syntax/test_no_none.py | 12 +++--- tests/parser/syntax/test_tuple_assign.py | 2 +- tests/parser/syntax/test_unbalanced_return.py | 4 +- tests/parser/types/numbers/test_constants.py | 8 ++-- tests/parser/types/test_identifier_naming.py | 8 +--- tests/parser/types/test_string.py | 2 +- vyper/ast/folding.py | 33 -------------- vyper/compiler/phases.py | 1 - 22 files changed, 49 insertions(+), 148 deletions(-) diff --git a/tests/ast/test_folding.py b/tests/ast/test_folding.py index 22d5f58222..62a7140e97 100644 --- a/tests/ast/test_folding.py +++ b/tests/ast/test_folding.py @@ -132,49 +132,6 @@ def test_replace_constant_no(source): assert vy_ast.compare_nodes(unmodified_ast, folded_ast) -builtins_modified = [ - "ZERO_ADDRESS", - "foo = ZERO_ADDRESS", - "foo: int128[ZERO_ADDRESS] = 42", - "foo = [ZERO_ADDRESS]", - "def foo(bar: address = ZERO_ADDRESS): pass", - "def foo(): bar = ZERO_ADDRESS", - "def foo(): return ZERO_ADDRESS", - "log foo(ZERO_ADDRESS)", - "log foo(42, ZERO_ADDRESS)", -] - - -@pytest.mark.parametrize("source", builtins_modified) -def test_replace_builtin_constant(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_builtin_constants(folded_ast) - - assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) - - -builtins_unmodified = [ - "ZERO_ADDRESS = 2", - "ZERO_ADDRESS()", - "def foo(ZERO_ADDRESS: int128 = 42): pass", - "def foo(): ZERO_ADDRESS = 42", - "def ZERO_ADDRESS(): pass", - "log ZERO_ADDRESS(42)", -] - - -@pytest.mark.parametrize("source", builtins_unmodified) -def test_replace_builtin_constant_no(source): - unmodified_ast = vy_ast.parse_to_ast(source) - folded_ast = vy_ast.parse_to_ast(source) - - folding.replace_builtin_constants(folded_ast) - - assert vy_ast.compare_nodes(unmodified_ast, folded_ast) - - userdefined_modified = [ "FOO", "foo = FOO", diff --git a/tests/parser/features/decorators/test_private.py b/tests/parser/features/decorators/test_private.py index 7c92f72af9..51e6d90ee1 100644 --- a/tests/parser/features/decorators/test_private.py +++ b/tests/parser/features/decorators/test_private.py @@ -304,7 +304,7 @@ def test(a: bytes32) -> (bytes32, uint256, int128): b: uint256 = 1 c: int128 = 1 d: int128 = 123 - f: bytes32 = EMPTY_BYTES32 + f: bytes32 = empty(bytes32) f, b, c = self._test(a) assert d == 123 return f, b, c diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index b3cc6f5576..12fcde2f4f 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -775,9 +775,9 @@ def foo() -> (address, Bytes[3], address): view @external def bar(arg1: address) -> (address, Bytes[3], address): - a: address = ZERO_ADDRESS + a: address = empty(address) b: Bytes[3] = b"" - c: address = ZERO_ADDRESS + c: address = empty(address) a, b, c = Foo(arg1).foo() return a, b, c """ @@ -808,9 +808,9 @@ def foo() -> (address, Bytes[3], address): view @external def bar(arg1: address) -> (address, Bytes[3], address): - a: address = ZERO_ADDRESS + a: address = empty(address) b: Bytes[3] = b"" - c: address = ZERO_ADDRESS + c: address = empty(address) a, b, c = Foo(arg1).foo() return a, b, c """ @@ -841,9 +841,9 @@ def foo() -> (address, Bytes[3], address): view @external def bar(arg1: address) -> (address, Bytes[3], address): - a: address = ZERO_ADDRESS + a: address = empty(address) b: Bytes[3] = b"" - c: address = ZERO_ADDRESS + c: address = empty(address) a, b, c = Foo(arg1).foo() return a, b, c """ @@ -1538,7 +1538,7 @@ def out_literals() -> (int128, address, Bytes[10]) : view @external def test(addr: address) -> (int128, address, Bytes[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) c: Bytes[10] = b"" (a, b, c) = Test(addr).out_literals() return a, b,c diff --git a/tests/parser/features/iteration/test_for_in_list.py b/tests/parser/features/iteration/test_for_in_list.py index bfd960a787..fb01cc98eb 100644 --- a/tests/parser/features/iteration/test_for_in_list.py +++ b/tests/parser/features/iteration/test_for_in_list.py @@ -230,7 +230,7 @@ def iterate_return_second() -> address: count += 1 if count == 2: return i - return ZERO_ADDRESS + return empty(address) """ c = get_contract_with_gas_estimation(code) diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 35b008a8ba..cd26659a5c 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -331,7 +331,7 @@ def foo(): @external def foo(): y: int128 = 1 - z: bytes32 = EMPTY_BYTES32 + z: bytes32 = empty(bytes32) z = y """, """ @@ -344,7 +344,7 @@ def foo(): @external def foo(): y: uint256 = 1 - z: bytes32 = EMPTY_BYTES32 + z: bytes32 = empty(bytes32) z = y """, ], diff --git a/tests/parser/features/test_memory_dealloc.py b/tests/parser/features/test_memory_dealloc.py index de82f03296..814bf0d3bb 100644 --- a/tests/parser/features/test_memory_dealloc.py +++ b/tests/parser/features/test_memory_dealloc.py @@ -9,7 +9,7 @@ def sendit(): nonpayable @external def foo(target: address) -> uint256[2]: - log Shimmy(ZERO_ADDRESS, 3) + log Shimmy(empty(address), 3) amount: uint256 = 1 flargen: uint256 = 42 Other(target).sendit() diff --git a/tests/parser/functions/test_abi_decode.py b/tests/parser/functions/test_abi_decode.py index 2216a5bd76..242841e1cf 100644 --- a/tests/parser/functions/test_abi_decode.py +++ b/tests/parser/functions/test_abi_decode.py @@ -25,7 +25,7 @@ def test_abi_decode_complex(get_contract): @external def abi_decode(x: Bytes[160]) -> (address, int128, bool, decimal, bytes32): - a: address = ZERO_ADDRESS + a: address = empty(address) b: int128 = 0 c: bool = False d: decimal = 0.0 @@ -39,7 +39,7 @@ def abi_decode_struct(x: Bytes[544]) -> Human: name: "", pet: Animal({ name: "", - address_: ZERO_ADDRESS, + address_: empty(address), id_: 0, is_furry: False, price: 0.0, diff --git a/tests/parser/functions/test_convert.py b/tests/parser/functions/test_convert.py index eb8449447c..b5ce613235 100644 --- a/tests/parser/functions/test_convert.py +++ b/tests/parser/functions/test_convert.py @@ -534,24 +534,6 @@ def foo(a: {typ}) -> Status: assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch) -# TODO CMC 2022-04-06 I think this test is somewhat unnecessary. -@pytest.mark.parametrize( - "builtin_constant,out_type,out_value", - [("ZERO_ADDRESS", "bool", False), ("msg.sender", "bool", True)], -) -def test_convert_builtin_constant( - get_contract_with_gas_estimation, builtin_constant, out_type, out_value -): - contract = f""" -@external -def convert_builtin_constant() -> {out_type}: - return convert({builtin_constant}, {out_type}) - """ - - c = get_contract_with_gas_estimation(contract) - assert c.convert_builtin_constant() == out_value - - # uint256 conversion is currently valid due to type inference on literals # not quite working yet same_type_conversion_blocked = sorted(TEST_TYPES - {UINT256_T}) diff --git a/tests/parser/functions/test_default_function.py b/tests/parser/functions/test_default_function.py index 4aa0b04a77..4ad68697ac 100644 --- a/tests/parser/functions/test_default_function.py +++ b/tests/parser/functions/test_default_function.py @@ -41,7 +41,7 @@ def test_basic_default_default_param_function(w3, get_logs, get_contract_with_ga @external @payable def fooBar(a: int128 = 12345) -> int128: - log Sent(ZERO_ADDRESS) + log Sent(empty(address)) return a @external diff --git a/tests/parser/functions/test_empty.py b/tests/parser/functions/test_empty.py index c10d03550a..c3627785dc 100644 --- a/tests/parser/functions/test_empty.py +++ b/tests/parser/functions/test_empty.py @@ -87,8 +87,8 @@ def foo(): self.foobar = empty(address) bar = empty(address) - assert self.foobar == ZERO_ADDRESS - assert bar == ZERO_ADDRESS + assert self.foobar == empty(address) + assert bar == empty(address) """, """ @external @@ -214,12 +214,12 @@ def foo(): self.foobar = empty(address[3]) bar = empty(address[3]) - assert self.foobar[0] == ZERO_ADDRESS - assert self.foobar[1] == ZERO_ADDRESS - assert self.foobar[2] == ZERO_ADDRESS - assert bar[0] == ZERO_ADDRESS - assert bar[1] == ZERO_ADDRESS - assert bar[2] == ZERO_ADDRESS + assert self.foobar[0] == empty(address) + assert self.foobar[1] == empty(address) + assert self.foobar[2] == empty(address) + assert bar[0] == empty(address) + assert bar[1] == empty(address) + assert bar[2] == empty(address) """, ], ) @@ -376,14 +376,14 @@ def foo(): assert self.foobar.c == False assert self.foobar.d == 0.0 assert self.foobar.e == 0x0000000000000000000000000000000000000000000000000000000000000000 - assert self.foobar.f == ZERO_ADDRESS + assert self.foobar.f == empty(address) assert bar.a == 0 assert bar.b == 0 assert bar.c == False assert bar.d == 0.0 assert bar.e == 0x0000000000000000000000000000000000000000000000000000000000000000 - assert bar.f == ZERO_ADDRESS + assert bar.f == empty(address) """ c = get_contract_with_gas_estimation(code) diff --git a/tests/parser/functions/test_return_tuple.py b/tests/parser/functions/test_return_tuple.py index 87b7cdcde3..b375839147 100644 --- a/tests/parser/functions/test_return_tuple.py +++ b/tests/parser/functions/test_return_tuple.py @@ -99,7 +99,7 @@ def out_literals() -> (int128, address, Bytes[10]): @external def test() -> (int128, address, Bytes[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) c: Bytes[10] = b"" (a, b, c) = self._out_literals() return a, b, c @@ -138,7 +138,7 @@ def test2() -> (int128, address): @external def test3() -> (address, int128): - x: address = ZERO_ADDRESS + x: address = empty(address) self.a, self.c, x, self.d = self._out_literals() return x, self.a """ diff --git a/tests/parser/integration/test_escrow.py b/tests/parser/integration/test_escrow.py index 2982ff9eae..1578f5a418 100644 --- a/tests/parser/integration/test_escrow.py +++ b/tests/parser/integration/test_escrow.py @@ -9,7 +9,7 @@ def test_arbitration_code(w3, get_contract_with_gas_estimation, assert_tx_failed @external def setup(_seller: address, _arbitrator: address): - if self.buyer == ZERO_ADDRESS: + if self.buyer == empty(address): self.buyer = msg.sender self.seller = _seller self.arbitrator = _arbitrator @@ -43,7 +43,7 @@ def test_arbitration_code_with_init(w3, assert_tx_failed, get_contract_with_gas_ @external @payable def __init__(_seller: address, _arbitrator: address): - if self.buyer == ZERO_ADDRESS: + if self.buyer == empty(address): self.buyer = msg.sender self.seller = _seller self.arbitrator = _arbitrator diff --git a/tests/parser/syntax/test_bool.py b/tests/parser/syntax/test_bool.py index 09f799d91c..48ed37321a 100644 --- a/tests/parser/syntax/test_bool.py +++ b/tests/parser/syntax/test_bool.py @@ -52,7 +52,7 @@ def foo() -> bool: """ @external def foo() -> bool: - a: address = ZERO_ADDRESS + a: address = empty(address) return a == 1 """, ( @@ -137,7 +137,7 @@ def foo() -> bool: """ @external def foo2(a: address) -> bool: - return a != ZERO_ADDRESS + return a != empty(address) """, ] diff --git a/tests/parser/syntax/test_interfaces.py b/tests/parser/syntax/test_interfaces.py index 5afb34e6bd..498f1363d8 100644 --- a/tests/parser/syntax/test_interfaces.py +++ b/tests/parser/syntax/test_interfaces.py @@ -47,7 +47,7 @@ def test(): @external def test(): - a: address(ERC20) = ZERO_ADDRESS + a: address(ERC20) = empty(address) """, InvalidType, ), @@ -306,7 +306,7 @@ def some_func(): nonpayable @external def __init__(): - self.my_interface[self.idx] = MyInterface(ZERO_ADDRESS) + self.my_interface[self.idx] = MyInterface(empty(address)) """, """ interface MyInterface: diff --git a/tests/parser/syntax/test_no_none.py b/tests/parser/syntax/test_no_none.py index 7030a56b18..24c32a46a4 100644 --- a/tests/parser/syntax/test_no_none.py +++ b/tests/parser/syntax/test_no_none.py @@ -30,13 +30,13 @@ def foo(): """ @external def foo(): - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) bar = None """, """ @external def foo(): - bar: address = ZERO_ADDRESS + bar: address = empty(address) bar = None """, """ @@ -104,13 +104,13 @@ def foo(): """ @external def foo(): - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) assert bar is None """, """ @external def foo(): - bar: address = ZERO_ADDRESS + bar: address = empty(address) assert bar is None """, ] @@ -148,13 +148,13 @@ def foo(): """ @external def foo(): - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) assert bar == None """, """ @external def foo(): - bar: address = ZERO_ADDRESS + bar: address = empty(address) assert bar == None """, ] diff --git a/tests/parser/syntax/test_tuple_assign.py b/tests/parser/syntax/test_tuple_assign.py index 115499ce8b..49b63ee614 100644 --- a/tests/parser/syntax/test_tuple_assign.py +++ b/tests/parser/syntax/test_tuple_assign.py @@ -41,7 +41,7 @@ def out_literals() -> (int128, int128, Bytes[10]): @external def test() -> (int128, address, Bytes[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) a, b = self.out_literals() # tuple count mismatch return """, diff --git a/tests/parser/syntax/test_unbalanced_return.py b/tests/parser/syntax/test_unbalanced_return.py index 5337b4b677..d1d9732777 100644 --- a/tests/parser/syntax/test_unbalanced_return.py +++ b/tests/parser/syntax/test_unbalanced_return.py @@ -56,7 +56,7 @@ def valid_address(sender: address) -> bool: """ @internal def valid_address(sender: address) -> bool: - if sender == ZERO_ADDRESS: + if sender == empty(address): selfdestruct(sender) _sender: address = sender else: @@ -144,7 +144,7 @@ def test() -> int128: """ @external def test() -> int128: - x: bytes32 = EMPTY_BYTES32 + x: bytes32 = empty(bytes32) if False: if False: return 0 diff --git a/tests/parser/types/numbers/test_constants.py b/tests/parser/types/numbers/test_constants.py index 0d5e386dad..652c8e8bd9 100644 --- a/tests/parser/types/numbers/test_constants.py +++ b/tests/parser/types/numbers/test_constants.py @@ -12,12 +12,12 @@ def test_builtin_constants(get_contract_with_gas_estimation): code = """ @external def test_zaddress(a: address) -> bool: - return a == ZERO_ADDRESS + return a == empty(address) @external def test_empty_bytes32(a: bytes32) -> bool: - return a == EMPTY_BYTES32 + return a == empty(bytes32) @external @@ -81,12 +81,12 @@ def goo() -> int128: @external def hoo() -> bytes32: - bar: bytes32 = EMPTY_BYTES32 + bar: bytes32 = empty(bytes32) return bar @external def joo() -> address: - bar: address = ZERO_ADDRESS + bar: address = empty(address) return bar @external diff --git a/tests/parser/types/test_identifier_naming.py b/tests/parser/types/test_identifier_naming.py index 5cfc7e8ed7..0a93329848 100755 --- a/tests/parser/types/test_identifier_naming.py +++ b/tests/parser/types/test_identifier_naming.py @@ -1,16 +1,12 @@ import pytest -from vyper.ast.folding import BUILTIN_CONSTANTS from vyper.ast.identifiers import RESERVED_KEYWORDS from vyper.builtins.functions import BUILTIN_FUNCTIONS from vyper.codegen.expr import ENVIRONMENT_VARIABLES from vyper.exceptions import NamespaceCollision, StructureException, SyntaxException from vyper.semantics.types.primitives import AddressT -BUILTIN_CONSTANTS = set(BUILTIN_CONSTANTS.keys()) -ALL_RESERVED_KEYWORDS = ( - BUILTIN_CONSTANTS | BUILTIN_FUNCTIONS | RESERVED_KEYWORDS | ENVIRONMENT_VARIABLES -) +ALL_RESERVED_KEYWORDS = BUILTIN_FUNCTIONS | RESERVED_KEYWORDS | ENVIRONMENT_VARIABLES @pytest.mark.parametrize("constant", sorted(ALL_RESERVED_KEYWORDS)) @@ -46,7 +42,7 @@ def test({constant}: int128): SELF_NAMESPACE_MEMBERS = set(AddressT._type_members.keys()) -DISALLOWED_FN_NAMES = SELF_NAMESPACE_MEMBERS | RESERVED_KEYWORDS | BUILTIN_CONSTANTS +DISALLOWED_FN_NAMES = SELF_NAMESPACE_MEMBERS | RESERVED_KEYWORDS ALLOWED_FN_NAMES = ALL_RESERVED_KEYWORDS - DISALLOWED_FN_NAMES diff --git a/tests/parser/types/test_string.py b/tests/parser/types/test_string.py index a5eef66dae..7f1fa71329 100644 --- a/tests/parser/types/test_string.py +++ b/tests/parser/types/test_string.py @@ -139,7 +139,7 @@ def out_literals() -> (int128, address, String[10]) : view @external def test(addr: address) -> (int128, address, String[10]): a: int128 = 0 - b: address = ZERO_ADDRESS + b: address = empty(address) c: String[10] = "" (a, b, c) = Test(addr).out_literals() return a, b,c diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index fbd1dfc2f4..38d58f6fd0 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -1,4 +1,3 @@ -import warnings from typing import Optional, Union from vyper.ast import nodes as vy_ast @@ -6,21 +5,6 @@ from vyper.exceptions import UnfoldableNode, UnknownType from vyper.semantics.types.base import VyperType from vyper.semantics.types.utils import type_from_annotation -from vyper.utils import SizeLimits - -BUILTIN_CONSTANTS = { - "EMPTY_BYTES32": ( - vy_ast.Hex, - "0x0000000000000000000000000000000000000000000000000000000000000000", - "empty(bytes32)", - ), # NOQA: E501 - "ZERO_ADDRESS": (vy_ast.Hex, "0x0000000000000000000000000000000000000000", "empty(address)"), - "MAX_INT128": (vy_ast.Int, 2**127 - 1, "max_value(int128)"), - "MIN_INT128": (vy_ast.Int, -(2**127), "min_value(int128)"), - "MAX_DECIMAL": (vy_ast.Decimal, SizeLimits.MAX_AST_DECIMAL, "max_value(decimal)"), - "MIN_DECIMAL": (vy_ast.Decimal, SizeLimits.MIN_AST_DECIMAL, "min_value(decimal)"), - "MAX_UINT256": (vy_ast.Int, 2**256 - 1, "max_value(uint256)"), -} def fold(vyper_module: vy_ast.Module) -> None: @@ -32,8 +16,6 @@ def fold(vyper_module: vy_ast.Module) -> None: vyper_module : Module Top-level Vyper AST node. """ - replace_builtin_constants(vyper_module) - changed_nodes = 1 while changed_nodes: changed_nodes = 0 @@ -138,21 +120,6 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: return changed_nodes -def replace_builtin_constants(vyper_module: vy_ast.Module) -> None: - """ - Replace references to builtin constants with their literal values. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - for name, (node, value, replacement) in BUILTIN_CONSTANTS.items(): - found = replace_constant(vyper_module, name, node(value=value), True) - if found > 0: - warnings.warn(f"{name} is deprecated. Please use `{replacement}` instead.") - - def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: """ Find user-defined constant assignments, and replace references diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5ddf071caf..72be4396e4 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -235,7 +235,6 @@ def generate_unfolded_ast( vyper_module: vy_ast.Module, interface_codes: Optional[InterfaceImports] ) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) - vy_ast.folding.replace_builtin_constants(vyper_module) vy_ast.folding.replace_builtin_functions(vyper_module) # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, interface_codes) From 3ba14124602b673d45b86bae7ff90a01d782acb5 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:26:40 +0800 Subject: [PATCH 10/10] refactor: merge `annotation.py` and `local.py` (#3456) This commit merges two phases of analysis: typechecking (`vyper/semantics/analysis/local.py`) and annotation of ast nodes with types (`vyper/semantics/analysis/annotation.py`). This is both for consistency with how it is done for `analysis/module.py`, and also because it increases internal consistency, as some typechecking was being done in `annotation.py`. It is also easier to maintain this way, since bugfixes or modifications to typechecking need to only be done in one place, rather than two passes. Lastly, it also probably improves performance, because it collapses two passes into one (and calls `get_*_type_from_node` less often). This commit also fixes a bug with accessing the iterator when the iterator is an empty list literal. --- tests/parser/syntax/test_list.py | 3 +- vyper/ast/nodes.pyi | 5 + vyper/builtins/_signatures.py | 2 +- vyper/builtins/functions.py | 6 + vyper/semantics/analysis/annotation.py | 283 ------------ vyper/semantics/analysis/common.py | 19 +- vyper/semantics/analysis/local.py | 607 ++++++++++++++++--------- vyper/semantics/analysis/utils.py | 13 +- 8 files changed, 420 insertions(+), 518 deletions(-) delete mode 100644 vyper/semantics/analysis/annotation.py diff --git a/tests/parser/syntax/test_list.py b/tests/parser/syntax/test_list.py index 3f81b911c8..db41de5526 100644 --- a/tests/parser/syntax/test_list.py +++ b/tests/parser/syntax/test_list.py @@ -305,8 +305,9 @@ def foo(): """ @external def foo(): + x: DynArray[uint256, 3] = [1, 2, 3] for i in [[], []]: - pass + x = i """, ] diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 0d59a2fa63..47c9af8526 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -142,6 +142,7 @@ class Expr(VyperNode): class UnaryOp(ExprNode): op: VyperNode = ... + operand: VyperNode = ... class USub(VyperNode): ... class Not(VyperNode): ... @@ -165,12 +166,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... + values: list[VyperNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... + left: VyperNode = ... + right: VyperNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -179,6 +183,7 @@ class LtE(VyperNode): ... class Gt(VyperNode): ... class GtE(VyperNode): ... class In(VyperNode): ... +class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d39a4a085f..2802421129 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -74,7 +74,7 @@ def decorator_fn(self, node, context): return decorator_fn -class BuiltinFunction: +class BuiltinFunction(VyperType): _has_varargs = False _kwargs: Dict[str, KwargSettings] = {} diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f07202831d..001939638b 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -475,6 +475,12 @@ def evaluate(self, node): return vy_ast.Int.from_node(node, value=length) + def infer_arg_types(self, node): + self._validate_arg_types(node) + # return a concrete type + typ = get_possible_types_from_node(node.args[0]).pop() + return [typ] + def build_IR(self, node, context): arg = Expr(node.args[0], context).ir_node if arg.value == "~calldata": diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py deleted file mode 100644 index 01ca51d7f4..0000000000 --- a/vyper/semantics/analysis/annotation.py +++ /dev/null @@ -1,283 +0,0 @@ -from vyper import ast as vy_ast -from vyper.exceptions import StructureException, TypeCheckFailure -from vyper.semantics.analysis.utils import ( - get_common_types, - get_exact_type_from_node, - get_possible_types_from_node, -) -from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t -from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT - - -class _AnnotationVisitorBase: - - """ - Annotation visitor base class. - - Annotation visitors apply metadata (such as type information) to vyper AST objects. - Immediately after type checking a statement-level node, that node is passed to - `StatementAnnotationVisitor`. Some expression nodes are then passed onward to - `ExpressionAnnotationVisitor` for additional annotation. - """ - - def visit(self, node, *args): - if isinstance(node, self.ignored_types): - return - # iterate over the MRO until we find a matching visitor function - # this lets us use a single function to broadly target several - # node types with a shared parent - for class_ in node.__class__.mro(): - ast_type = class_.__name__ - visitor_fn = getattr(self, f"visit_{ast_type}", None) - if visitor_fn: - visitor_fn(node, *args) - return - raise StructureException(f"Cannot annotate: {node.ast_type}", node) - - -class StatementAnnotationVisitor(_AnnotationVisitorBase): - ignored_types = (vy_ast.Break, vy_ast.Continue, vy_ast.Pass, vy_ast.Raise) - - def __init__(self, fn_node: vy_ast.FunctionDef, namespace: dict) -> None: - self.func = fn_node._metadata["type"] - self.namespace = namespace - self.expr_visitor = ExpressionAnnotationVisitor(self.func) - - assert self.func.n_keyword_args == len(fn_node.args.defaults) - for kwarg in self.func.keyword_args: - self.expr_visitor.visit(kwarg.default_value, kwarg.typ) - - def visit(self, node): - super().visit(node) - - def visit_AnnAssign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_Assert(self, node): - self.expr_visitor.visit(node.test) - - def visit_Assign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_AugAssign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_Expr(self, node): - self.expr_visitor.visit(node.value) - - def visit_If(self, node): - self.expr_visitor.visit(node.test) - - def visit_Log(self, node): - node._metadata["type"] = self.namespace[node.value.func.id] - self.expr_visitor.visit(node.value) - - def visit_Return(self, node): - if node.value is not None: - self.expr_visitor.visit(node.value, self.func.return_type) - - def visit_For(self, node): - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - self.expr_visitor.visit(node.iter) - - iter_type = node.target._metadata["type"] - if isinstance(node.iter, vy_ast.List): - # typecheck list literal as static array - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) - - -class ExpressionAnnotationVisitor(_AnnotationVisitorBase): - ignored_types = () - - def __init__(self, fn_node: ContractFunctionT): - self.func = fn_node - - def visit(self, node, type_=None): - # the statement visitor sometimes passes type information about expressions - super().visit(node, type_) - - def visit_Attribute(self, node, type_): - type_ = get_exact_type_from_node(node) - node._metadata["type"] = type_ - self.visit(node.value, None) - - def visit_BinOp(self, node, type_): - if type_ is None: - type_ = get_common_types(node.left, node.right) - if len(type_) == 1: - type_ = type_.pop() - node._metadata["type"] = type_ - - self.visit(node.left, type_) - self.visit(node.right, type_) - - def visit_BoolOp(self, node, type_): - for value in node.values: - self.visit(value) - - def visit_Call(self, node, type_): - call_type = get_exact_type_from_node(node.func) - node_type = type_ or call_type.fetch_call_return(node) - node._metadata["type"] = node_type - self.visit(node.func) - - if isinstance(call_type, ContractFunctionT): - # function calls - if call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): - self.visit(arg, typ) - for kwarg in node.keywords: - # We should only see special kwargs - self.visit(kwarg.value, call_type.call_site_kwargs[kwarg.arg].typ) - - elif is_type_t(call_type, EventT): - # events have no kwargs - for arg, typ in zip(node.args, list(call_type.typedef.arguments.values())): - self.visit(arg, typ) - elif is_type_t(call_type, StructT): - # struct ctors - # ctors have no kwargs - for value, arg_type in zip( - node.args[0].values, list(call_type.typedef.members.values()) - ): - self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): - self.visit(arg, arg_type) - else: - # builtin functions - arg_types = call_type.infer_arg_types(node) - for arg, arg_type in zip(node.args, arg_types): - self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) - for kwarg in node.keywords: - self.visit(kwarg.value, kwarg_types[kwarg.arg]) - - def visit_Compare(self, node, type_): - if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): - if isinstance(node.right, vy_ast.List): - type_ = get_common_types(node.left, *node.right.elements).pop() - self.visit(node.left, type_) - rlen = len(node.right.elements) - self.visit(node.right, SArrayT(type_, rlen)) - else: - type_ = get_exact_type_from_node(node.right) - self.visit(node.right, type_) - if isinstance(type_, EnumT): - self.visit(node.left, type_) - else: - # array membership - self.visit(node.left, type_.value_type) - else: - type_ = get_common_types(node.left, node.right).pop() - self.visit(node.left, type_) - self.visit(node.right, type_) - - def visit_Constant(self, node, type_): - if type_ is None: - possible_types = get_possible_types_from_node(node) - if len(possible_types) == 1: - type_ = possible_types.pop() - node._metadata["type"] = type_ - - def visit_Dict(self, node, type_): - node._metadata["type"] = type_ - - def visit_Index(self, node, type_): - self.visit(node.value, type_) - - def visit_List(self, node, type_): - if type_ is None: - type_ = get_possible_types_from_node(node) - # CMC 2022-04-14 this seems sus. try to only annotate - # if get_possible_types only returns 1 type - if len(type_) >= 1: - type_ = type_.pop() - node._metadata["type"] = type_ - for element in node.elements: - self.visit(element, type_.value_type) - - def visit_Name(self, node, type_): - if isinstance(type_, TYPE_T): - node._metadata["type"] = type_ - else: - node._metadata["type"] = get_exact_type_from_node(node) - - def visit_Subscript(self, node, type_): - node._metadata["type"] = type_ - - if isinstance(type_, TYPE_T): - # don't recurse; can't annotate AST children of type definition - return - - if isinstance(node.value, vy_ast.List): - possible_base_types = get_possible_types_from_node(node.value) - - if len(possible_base_types) == 1: - base_type = possible_base_types.pop() - - elif type_ is not None and len(possible_base_types) > 1: - for possible_type in possible_base_types: - if type_.compare_type(possible_type.value_type): - base_type = possible_type - break - else: - # this should have been caught in - # `get_possible_types_from_node` but wasn't. - raise TypeCheckFailure(f"Expected {type_} but it is not a possible type", node) - - else: - base_type = get_exact_type_from_node(node.value) - - # get the correct type for the index, it might - # not be base_type.key_type - index_types = get_possible_types_from_node(node.slice.value) - index_type = index_types.pop() - - self.visit(node.slice, index_type) - self.visit(node.value, base_type) - - def visit_Tuple(self, node, type_): - node._metadata["type"] = type_ - - if isinstance(type_, TYPE_T): - # don't recurse; can't annotate AST children of type definition - return - - for element, subtype in zip(node.elements, type_.member_types): - self.visit(element, subtype) - - def visit_UnaryOp(self, node, type_): - if type_ is None: - type_ = get_possible_types_from_node(node.operand) - if len(type_) == 1: - type_ = type_.pop() - node._metadata["type"] = type_ - self.visit(node.operand, type_) - - def visit_IfExp(self, node, type_): - if type_ is None: - ts = get_common_types(node.body, node.orelse) - if len(ts) == 1: - type_ = ts.pop() - - node._metadata["type"] = type_ - self.visit(node.test, BoolT()) - self.visit(node.body, type_) - self.visit(node.orelse, type_) diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 193d1892e1..507eb0a570 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -10,10 +10,17 @@ class VyperNodeVisitorBase: def visit(self, node, *args): if isinstance(node, self.ignored_types): return + + # iterate over the MRO until we find a matching visitor function + # this lets us use a single function to broadly target several + # node types with a shared parent + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + return visitor_fn(node, *args) + node_type = type(node).__name__ - visitor_fn = getattr(self, f"visit_{node_type}", None) - if visitor_fn is None: - raise StructureException( - f"Unsupported syntax for {self.scope_name} namespace: {node_type}", node - ) - visitor_fn(node, *args) + raise StructureException( + f"Unsupported syntax for {self.scope_name} namespace: {node_type}", node + ) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b391b33953..647f01c299 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -14,11 +14,11 @@ NonPayableViolation, StateAccessViolation, StructureException, + TypeCheckFailure, TypeMismatch, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.annotation import StatementAnnotationVisitor from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( @@ -34,9 +34,11 @@ from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( + TYPE_T, AddressT, BoolT, DArrayT, + EnumT, EventT, HashMapT, IntegerT, @@ -44,6 +46,8 @@ StringT, StructT, TupleT, + VyperType, + _BytestringT, is_type_t, ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability @@ -117,20 +121,8 @@ def _check_iterator_modification( return None -def _validate_revert_reason(msg_node: vy_ast.VyperNode) -> None: - if msg_node: - if isinstance(msg_node, vy_ast.Str): - if not msg_node.value.strip(): - raise StructureException("Reason string cannot be empty", msg_node) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - validate_expected_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - - -def _validate_address_code_attribute(node: vy_ast.Attribute) -> None: - value_type = get_exact_type_from_node(node.value) +# helpers +def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": # Validate `slice(
.code, start, length)` where `length` is constant parent = node.get_ancestor() @@ -139,6 +131,7 @@ def _validate_address_code_attribute(node: vy_ast.Attribute) -> None: ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) if ok_func and ok_args: return + raise StructureException( "(address).code is only allowed inside of a slice function with a constant length", node ) @@ -160,8 +153,30 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: ) +def _validate_msg_value_access(node: vy_ast.Attribute) -> None: + if isinstance(node.value, vy_ast.Name) and node.attr == "value" and node.value.id == "msg": + raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) + + +def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: + env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys()) + if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars: + if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE: + return + + raise StateAccessViolation( + "not allowed to query contract or environment variables in pure functions", node + ) + + +def _validate_self_reference(node: vy_ast.Name) -> None: + # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through + if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): + raise StateAccessViolation("not allowed to query self in pure functions", node) + + class FunctionNodeVisitor(VyperNodeVisitorBase): - ignored_types = (vy_ast.Constant, vy_ast.Pass) + ignored_types = (vy_ast.Pass,) scope_name = "function" def __init__( @@ -171,8 +186,7 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["type"] - self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace) - self.expr_visitor = _LocalExpressionVisitor() + self.expr_visitor = _ExprVisitor(self.func) # allow internal function params to be mutable location, is_immutable = ( @@ -189,44 +203,13 @@ def __init__( f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node ) - if self.func.mutability == StateMutability.PURE: - node_list = fn_node.get_descendants( - vy_ast.Attribute, - { - "value.id": set(CONSTANT_ENVIRONMENT_VARS.keys()).union( - set(MUTABLE_ENVIRONMENT_VARS.keys()) - ) - }, - ) - - # Add references to `self` as standalone address - self_references = fn_node.get_descendants(vy_ast.Name, {"id": "self"}) - standalone_self = [ - n for n in self_references if not isinstance(n.get_ancestor(), vy_ast.Attribute) - ] - node_list.extend(standalone_self) # type: ignore - - for node in node_list: - t = node._metadata.get("type") - if isinstance(t, ContractFunctionT) and t.mutability == StateMutability.PURE: - # allowed - continue - raise StateAccessViolation( - "not allowed to query contract or environment variables in pure functions", - node_list[0], - ) - if self.func.mutability is not StateMutability.PAYABLE: - node_list = fn_node.get_descendants( - vy_ast.Attribute, {"value.id": "msg", "attr": "value"} - ) - if node_list: - raise NonPayableViolation( - "msg.value is not allowed in non-payable functions", node_list[0] - ) + # visit default args + assert self.func.n_keyword_args == len(fn_node.args.defaults) + for kwarg in self.func.keyword_args: + self.expr_visitor.visit(kwarg.default_value, kwarg.typ) def visit(self, node): super().visit(node) - self.annotation_visitor.visit(node) def visit_AnnAssign(self, node): name = node.get("target.id") @@ -238,16 +221,42 @@ def visit_AnnAssign(self, node): "Memory variables must be declared with an initial value", node ) - type_ = type_from_annotation(node.annotation, DataLocation.MEMORY) - validate_expected_type(node.value, type_) + typ = type_from_annotation(node.annotation, DataLocation.MEMORY) + validate_expected_type(node.value, typ) try: - self.namespace[name] = VarInfo(type_, location=DataLocation.MEMORY) + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) except VyperException as exc: raise exc.with_annotation(node) from None - self.expr_visitor.visit(node.value) - def visit_Assign(self, node): + self.expr_visitor.visit(node.target, typ) + self.expr_visitor.visit(node.value, typ) + + def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Str): + if not msg_node.value.strip(): + raise StructureException("Reason string cannot be empty", msg_node) + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): + try: + validate_expected_type(msg_node, StringT(1024)) + except TypeMismatch as e: + raise InvalidType("revert reason must fit within String[1024]") from e + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + + def visit_Assert(self, node): + if node.msg: + self._validate_revert_reason(node.msg) + + try: + validate_expected_type(node.test, BoolT()) + except InvalidType: + raise InvalidType("Assertion test value must be a boolean", node.test) + self.expr_visitor.visit(node.test, BoolT()) + + # repeated code for Assign and AugAssign + def _assign_helper(self, node): if isinstance(node.value, vy_ast.Tuple): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) @@ -260,81 +269,71 @@ def visit_Assign(self, node): validate_expected_type(node.value, target.typ) target.validate_modification(node, self.func.mutability) - self.expr_visitor.visit(node.value) - self.expr_visitor.visit(node.target) + self.expr_visitor.visit(node.value, target.typ) + self.expr_visitor.visit(node.target, target.typ) - def visit_AugAssign(self, node): - if isinstance(node.value, vy_ast.Tuple): - raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) - - lhs_info = get_expr_info(node.target) - - validate_expected_type(node.value, lhs_info.typ) - lhs_info.validate_modification(node, self.func.mutability) - - self.expr_visitor.visit(node.value) - self.expr_visitor.visit(node.target) - - def visit_Raise(self, node): - if node.exc: - _validate_revert_reason(node.exc) - self.expr_visitor.visit(node.exc) + def visit_Assign(self, node): + self._assign_helper(node) - def visit_Assert(self, node): - if node.msg: - _validate_revert_reason(node.msg) - self.expr_visitor.visit(node.msg) + def visit_AugAssign(self, node): + self._assign_helper(node) - try: - validate_expected_type(node.test, BoolT()) - except InvalidType: - raise InvalidType("Assertion test value must be a boolean", node.test) - self.expr_visitor.visit(node.test) + def visit_Break(self, node): + for_node = node.get_ancestor(vy_ast.For) + if for_node is None: + raise StructureException("`break` must be enclosed in a `for` loop", node) def visit_Continue(self, node): + # TODO: use context/state instead of ast search for_node = node.get_ancestor(vy_ast.For) if for_node is None: raise StructureException("`continue` must be enclosed in a `for` loop", node) - def visit_Break(self, node): - for_node = node.get_ancestor(vy_ast.For) - if for_node is None: - raise StructureException("`break` must be enclosed in a `for` loop", node) + def visit_Expr(self, node): + if not isinstance(node.value, vy_ast.Call): + raise StructureException("Expressions without assignment are disallowed", node) - def visit_Return(self, node): - values = node.value - if values is None: - if self.func.return_type: - raise FunctionDeclarationException("Return statement is missing a value", node) - return - elif self.func.return_type is None: - raise FunctionDeclarationException("Function does not return any values", node) + fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): + raise StructureException("To call an event you must use the `log` statement", node) - if isinstance(values, vy_ast.Tuple): - values = values.elements - if not isinstance(self.func.return_type, TupleT): - raise FunctionDeclarationException("Function only returns a single value", node) - if self.func.return_type.length != len(values): - raise FunctionDeclarationException( - f"Incorrect number of return values: " - f"expected {self.func.return_type.length}, got {len(values)}", + if is_type_t(fn_type, StructT): + raise StructureException("Struct creation without assignment is disallowed", node) + + if isinstance(fn_type, ContractFunctionT): + if ( + fn_type.mutability > StateMutability.VIEW + and self.func.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot call a mutating function from a {self.func.mutability.value} function", node, ) - for given, expected in zip(values, self.func.return_type.member_types): - validate_expected_type(given, expected) - else: - validate_expected_type(values, self.func.return_type) - self.expr_visitor.visit(node.value) - def visit_If(self, node): - validate_expected_type(node.test, BoolT()) - self.expr_visitor.visit(node.test) - with self.namespace.enter_scope(): - for n in node.body: - self.visit(n) - with self.namespace.enter_scope(): - for n in node.orelse: - self.visit(n) + if ( + self.func.mutability == StateMutability.PURE + and fn_type.mutability != StateMutability.PURE + ): + raise StateAccessViolation( + "Cannot call non-pure function from a pure function", node + ) + + if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: + # it's a dotted function call like dynarray.pop() + expr_info = get_expr_info(node.value.func.value) + expr_info.validate_modification(node, self.func.mutability) + + # NOTE: fetch_call_return validates call args. + return_value = fn_type.fetch_call_return(node.value) + if ( + return_value + and not isinstance(fn_type, MemberFunctionT) + and not isinstance(fn_type, ContractFunctionT) + ): + raise StructureException( + f"Function '{fn_type._id}' cannot be called without assigning the result", node + ) + self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): @@ -463,19 +462,18 @@ def visit_For(self, node): f"which potentially modifies iterated storage variable '{iter_name}'", call_node, ) - self.expr_visitor.visit(node.iter) if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) for_loop_exceptions = [] iter_name = node.target.id - for type_ in type_list: + for possible_target_type in type_list: # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): try: - self.namespace[iter_name] = VarInfo(type_, is_constant=True) + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) except VyperException as exc: raise exc.with_annotation(node) from None @@ -486,17 +484,27 @@ def visit_For(self, node): except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - # type information is applied directly here because the - # scope is closed prior to the call to - # `StatementAnnotationVisitor` - node.target._metadata["type"] = type_ - - # success -- bail out instead of error handling. + self.expr_visitor.visit(node.target, possible_target_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, possible_target_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, possible_target_type) + + # success -- do not enter error handling section return - # if we have gotten here, there was an error for - # every type tried for the iterator - + # failed to find a good type. bail out if len(set(str(i) for i in for_loop_exceptions)) == 1: # if every attempt at type checking raised the same exception raise for_loop_exceptions[0] @@ -510,56 +518,20 @@ def visit_For(self, node): "but type checking fails with all possible types:", node, *( - (f"Casting '{iter_name}' as {type_}: {exc.message}", exc.annotations[0]) - for type_, exc in zip(type_list, for_loop_exceptions) + (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) + for typ, exc in zip(type_list, for_loop_exceptions) ), ) - def visit_Expr(self, node): - if not isinstance(node.value, vy_ast.Call): - raise StructureException("Expressions without assignment are disallowed", node) - - fn_type = get_exact_type_from_node(node.value.func) - if is_type_t(fn_type, EventT): - raise StructureException("To call an event you must use the `log` statement", node) - - if is_type_t(fn_type, StructT): - raise StructureException("Struct creation without assignment is disallowed", node) - - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - - # NOTE: fetch_call_return validates call args. - return_value = fn_type.fetch_call_return(node.value) - if ( - return_value - and not isinstance(fn_type, MemberFunctionT) - and not isinstance(fn_type, ContractFunctionT) - ): - raise StructureException( - f"Function '{fn_type._id}' cannot be called without assigning the result", node - ) - self.expr_visitor.visit(node.value) + def visit_If(self, node): + validate_expected_type(node.test, BoolT()) + self.expr_visitor.visit(node.test, BoolT()) + with self.namespace.enter_scope(): + for n in node.body: + self.visit(n) + with self.namespace.enter_scope(): + for n in node.orelse: + self.visit(n) def visit_Log(self, node): if not isinstance(node.value, vy_ast.Call): @@ -572,62 +544,249 @@ def visit_Log(self, node): f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node ) f.fetch_call_return(node.value) - self.expr_visitor.visit(node.value) + node._metadata["type"] = f.typedef + self.expr_visitor.visit(node.value, f.typedef) + + def visit_Raise(self, node): + if node.exc: + self._validate_revert_reason(node.exc) + def visit_Return(self, node): + values = node.value + if values is None: + if self.func.return_type: + raise FunctionDeclarationException("Return statement is missing a value", node) + return + elif self.func.return_type is None: + raise FunctionDeclarationException("Function does not return any values", node) -class _LocalExpressionVisitor(VyperNodeVisitorBase): - ignored_types = (vy_ast.Constant, vy_ast.Name) + if isinstance(values, vy_ast.Tuple): + values = values.elements + if not isinstance(self.func.return_type, TupleT): + raise FunctionDeclarationException("Function only returns a single value", node) + if self.func.return_type.length != len(values): + raise FunctionDeclarationException( + f"Incorrect number of return values: " + f"expected {self.func.return_type.length}, got {len(values)}", + node, + ) + for given, expected in zip(values, self.func.return_type.member_types): + validate_expected_type(given, expected) + else: + validate_expected_type(values, self.func.return_type) + self.expr_visitor.visit(node.value, self.func.return_type) + + +class _ExprVisitor(VyperNodeVisitorBase): scope_name = "function" - def visit_Attribute(self, node: vy_ast.Attribute) -> None: - self.visit(node.value) + def __init__(self, fn_node: ContractFunctionT): + self.func = fn_node + + def visit(self, node, typ): + # recurse and typecheck in case we are being fed the wrong type for + # some reason. note that `validate_expected_type` is unnecessary + # for nodes that already call `get_exact_type_from_node` and + # `get_possible_types_from_node` because `validate_expected_type` + # would be calling the same function again. + # CMC 2023-06-27 would be cleanest to call validate_expected_type() + # before recursing but maybe needs some refactoring before that + # can happen. + super().visit(node, typ) + + # annotate + node._metadata["type"] = typ + + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) - _validate_address_code_attribute(node) - - def visit_BinOp(self, node: vy_ast.BinOp) -> None: - self.visit(node.left) - self.visit(node.right) - - def visit_BoolOp(self, node: vy_ast.BoolOp) -> None: - for value in node.values: # type: ignore[attr-defined] - self.visit(value) - - def visit_Call(self, node: vy_ast.Call) -> None: - self.visit(node.func) - for arg in node.args: - self.visit(arg) - for kwarg in node.keywords: - self.visit(kwarg.value) - - def visit_Compare(self, node: vy_ast.Compare) -> None: - self.visit(node.left) # type: ignore[attr-defined] - self.visit(node.right) # type: ignore[attr-defined] - - def visit_Dict(self, node: vy_ast.Dict) -> None: - for key in node.keys: - self.visit(key) + + # CMC 2023-10-19 TODO generalize this to mutability check on every node. + # something like, + # if self.func.mutability < expr_info.mutability: + # raise ... + + if self.func.mutability != StateMutability.PAYABLE: + _validate_msg_value_access(node) + + if self.func.mutability == StateMutability.PURE: + _validate_pure_access(node, typ) + + value_type = get_exact_type_from_node(node.value) + _validate_address_code(node, value_type) + + self.visit(node.value, value_type) + + def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None: + validate_expected_type(node.left, typ) + self.visit(node.left, typ) + + rtyp = typ + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): + rtyp = get_possible_types_from_node(node.right).pop() + + validate_expected_type(node.right, rtyp) + + self.visit(node.right, rtyp) + + def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: + assert typ == BoolT() # sanity check for value in node.values: - self.visit(value) + validate_expected_type(value, BoolT()) + self.visit(value, BoolT()) + + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: + call_type = get_exact_type_from_node(node.func) + # except for builtin functions, `get_exact_type_from_node` + # already calls `validate_expected_type` on the call args + # and kwargs via `call_type.fetch_call_return` + self.visit(node.func, call_type) + + if isinstance(call_type, ContractFunctionT): + # function calls + if call_type.is_internal: + self.func.called_functions.add(call_type) + for arg, typ in zip(node.args, call_type.argument_types): + self.visit(arg, typ) + for kwarg in node.keywords: + # We should only see special kwargs + typ = call_type.call_site_kwargs[kwarg.arg].typ + self.visit(kwarg.value, typ) + + elif is_type_t(call_type, EventT): + # events have no kwargs + expected_types = call_type.typedef.arguments.values() + for arg, typ in zip(node.args, expected_types): + self.visit(arg, typ) + elif is_type_t(call_type, StructT): + # struct ctors + # ctors have no kwargs + expected_types = call_type.typedef.members.values() + for value, arg_type in zip(node.args[0].values, expected_types): + self.visit(value, arg_type) + elif isinstance(call_type, MemberFunctionT): + assert len(node.args) == len(call_type.arg_types) + for arg, arg_type in zip(node.args, call_type.arg_types): + self.visit(arg, arg_type) + else: + # builtin functions + arg_types = call_type.infer_arg_types(node) + # `infer_arg_types` already calls `validate_expected_type` + for arg, arg_type in zip(node.args, arg_types): + self.visit(arg, arg_type) + kwarg_types = call_type.infer_kwarg_types(node) + for kwarg in node.keywords: + self.visit(kwarg.value, kwarg_types[kwarg.arg]) + + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + # membership in list literal - `x in [a, b, c]` + # needle: ltyp, haystack: rtyp + if isinstance(node.right, vy_ast.List): + ltyp = get_common_types(node.left, *node.right.elements).pop() + + rlen = len(node.right.elements) + rtyp = SArrayT(ltyp, rlen) + validate_expected_type(node.right, rtyp) + else: + rtyp = get_exact_type_from_node(node.right) + if isinstance(rtyp, EnumT): + # enum membership - `some_enum in other_enum` + ltyp = rtyp + else: + # array membership - `x in my_list_variable` + assert isinstance(rtyp, (SArrayT, DArrayT)) + ltyp = rtyp.value_type - def visit_Index(self, node: vy_ast.Index) -> None: - self.visit(node.value) + validate_expected_type(node.left, ltyp) - def visit_List(self, node: vy_ast.List) -> None: - for element in node.elements: - self.visit(element) + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) + + else: + # ex. a < b + cmp_typ = get_common_types(node.left, node.right).pop() + if isinstance(cmp_typ, _BytestringT): + # for bytestrings, get_common_types automatically downcasts + # to the smaller common type - that will annotate with the + # wrong type, instead use get_exact_type_from_node (which + # resolves to the right type for bytestrings anyways). + ltyp = get_exact_type_from_node(node.left) + rtyp = get_exact_type_from_node(node.right) + else: + ltyp = rtyp = cmp_typ + validate_expected_type(node.left, ltyp) + validate_expected_type(node.right, rtyp) + + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) + + def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None: + validate_expected_type(node, typ) - def visit_Subscript(self, node: vy_ast.Subscript) -> None: - self.visit(node.value) - self.visit(node.slice) + def visit_Index(self, node: vy_ast.Index, typ: VyperType) -> None: + validate_expected_type(node.value, typ) + self.visit(node.value, typ) - def visit_Tuple(self, node: vy_ast.Tuple) -> None: + def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: + assert isinstance(typ, (SArrayT, DArrayT)) for element in node.elements: - self.visit(element) + validate_expected_type(element, typ.value_type) + self.visit(element, typ.value_type) - def visit_UnaryOp(self, node: vy_ast.UnaryOp) -> None: - self.visit(node.operand) # type: ignore[attr-defined] + def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) + + if not isinstance(typ, TYPE_T): + validate_expected_type(node, typ) + + def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + # don't recurse; can't annotate AST children of type definition + return + + if isinstance(node.value, vy_ast.List): + possible_base_types = get_possible_types_from_node(node.value) + + for possible_type in possible_base_types: + if typ.compare_type(possible_type.value_type): + base_type = possible_type + break + else: + # this should have been caught in + # `get_possible_types_from_node` but wasn't. + raise TypeCheckFailure(f"Expected {typ} but it is not a possible type", node) + + else: + base_type = get_exact_type_from_node(node.value) + + # get the correct type for the index, it might + # not be exactly base_type.key_type + # note: index_type is validated in types_from_Subscript + index_types = get_possible_types_from_node(node.slice.value) + index_type = index_types.pop() + + self.visit(node.slice, index_type) + self.visit(node.value, base_type) + + def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + # don't recurse; can't annotate AST children of type definition + return + + assert isinstance(typ, TupleT) + for element, subtype in zip(node.elements, typ.member_types): + validate_expected_type(element, subtype) + self.visit(element, subtype) - def visit_IfExp(self, node: vy_ast.IfExp) -> None: - self.visit(node.test) - self.visit(node.body) - self.visit(node.orelse) + def visit_UnaryOp(self, node: vy_ast.UnaryOp, typ: VyperType) -> None: + validate_expected_type(node.operand, typ) + self.visit(node.operand, typ) + + def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: + validate_expected_type(node.test, BoolT()) + self.visit(node.test, BoolT()) + validate_expected_type(node.body, typ) + self.visit(node.body, typ) + validate_expected_type(node.orelse, typ) + self.visit(node.orelse, typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4f911764e0..afa6b56838 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -312,10 +312,17 @@ def types_from_Constant(self, node): def types_from_List(self, node): # literal array if _is_empty_list(node): - # empty list literal `[]` ret = [] - # subtype can be anything - for t in types.PRIMITIVE_TYPES.values(): + + if len(node.elements) > 0: + # empty nested list literals `[[], []]` + subtypes = self.get_possible_types_from_node(node.elements[0]) + else: + # empty list literal `[]` + # subtype can be anything + subtypes = types.PRIMITIVE_TYPES.values() + + for t in subtypes: # 1 is minimum possible length for dynarray, # can be assigned to anything if isinstance(t, VyperType):