diff --git a/FUNDING.yml b/FUNDING.yml index 81e82160d0..efb9eb01b7 100644 --- a/FUNDING.yml +++ b/FUNDING.yml @@ -1 +1 @@ -custom: https://gitcoin.co/grants/200/vyper-smart-contract-language-2 +custom: https://etherscan.io/address/0x70CCBE10F980d80b7eBaab7D2E3A73e87D67B775 diff --git a/docs/release-notes.rst b/docs/release-notes.rst index da86c5c0ce..3db11dc451 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -14,17 +14,13 @@ Release Notes for advisory links: :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g -.. - v0.3.10 ("Black Adder") - *********************** - -v0.3.10rc1 -********** +v0.3.10 ("Black Adder") +*********************** -Date released: 2023-09-06 +Date released: 2023-10-04 ========================= -v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). +v0.3.10 is a performance focused release that additionally ships numerous bugfixes. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_). Breaking changes: ----------------- @@ -32,6 +28,7 @@ Breaking changes: - add runtime code layout to initcode (`#3584 `_) - drop evm versions through istanbul (`#3470 `_) - remove vyper signature from runtime (`#3471 `_) +- only allow valid identifiers to be nonreentrant keys (`#3605 `_) Non-breaking changes and improvements: -------------------------------------- @@ -46,12 +43,15 @@ Notable fixes: - fix ``ecrecover()`` behavior when signature is invalid (`GHSA-f5x6-7qgp-jhf3 `_, `#3586 `_) - fix: order of evaluation for some builtins (`#3583 `_, `#3587 `_) +- fix: memory allocation in certain builtins using ``msize`` (`#3610 `_) +- fix: ``_abi_decode()`` input validation in certain complex expressions (`#3626 `_) - fix: pycryptodome for arm builds (`#3485 `_) - let params of internal functions be mutable (`#3473 `_) - typechecking of folded builtins in (`#3490 `_) - update tload/tstore opcodes per latest 1153 EIP spec (`#3484 `_) - fix: raw_call type when max_outsize=0 is set (`#3572 `_) - fix: implements check for indexed event arguments (`#3570 `_) +- fix: type-checking for ``_abi_decode()`` arguments (`#3626 `_) Other docs updates, chores and fixes: ------------------------------------- diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst index d2c5d48d96..3861bf4380 100644 --- a/docs/structure-of-a-contract.rst +++ b/docs/structure-of-a-contract.rst @@ -17,7 +17,7 @@ Vyper supports several source code directives to control compiler modes and help Version Pragma -------------- -The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers _`. +The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers `_. As of 0.3.10, the recommended way to specify the version pragma is as follows: @@ -25,6 +25,10 @@ As of 0.3.10, the recommended way to specify the version pragma is as follows: #pragma version ^0.3.0 +.. note:: + + Both pragma directive versions ``#pragma`` and ``# pragma`` are supported. + The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version: .. code-block:: python diff --git a/tests/ast/test_pre_parser.py b/tests/ast/test_pre_parser.py index 5427532c16..3d072674f6 100644 --- a/tests/ast/test_pre_parser.py +++ b/tests/ast/test_pre_parser.py @@ -1,8 +1,9 @@ import pytest from vyper.ast.pre_parser import pre_parse, validate_version_pragma +from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings -from vyper.exceptions import VersionException +from vyper.exceptions import StructureException, VersionException SRC_LINE = (1, 0) # Dummy source line COMPILER_VERSION = "0.1.1" @@ -96,43 +97,50 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): """ """, Settings(), + Settings(optimize=OptimizationLevel.GAS), ), ( """ #pragma optimize codesize """, Settings(optimize=OptimizationLevel.CODESIZE), + None, ), ( """ #pragma optimize none """, Settings(optimize=OptimizationLevel.NONE), + None, ), ( """ #pragma optimize gas """, Settings(optimize=OptimizationLevel.GAS), + None, ), ( """ #pragma version 0.3.10 """, Settings(compiler_version="0.3.10"), + Settings(optimize=OptimizationLevel.GAS), ), ( """ #pragma evm-version shanghai """, Settings(evm_version="shanghai"), + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), ), ( """ #pragma optimize codesize #pragma evm-version shanghai """, - Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), + Settings(evm_version="shanghai", optimize=OptimizationLevel.CODESIZE), + None, ), ( """ @@ -140,6 +148,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma evm-version shanghai """, Settings(evm_version="shanghai", compiler_version="0.3.10"), + Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS), ), ( """ @@ -147,6 +156,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma optimize gas """, Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS), + Settings(optimize=OptimizationLevel.GAS), ), ( """ @@ -155,11 +165,59 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): #pragma optimize gas """, Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS, evm_version="shanghai"), + Settings(optimize=OptimizationLevel.GAS, evm_version="shanghai"), ), ] -@pytest.mark.parametrize("code, expected_pragmas", pragma_examples) -def parse_pragmas(code, expected_pragmas): - pragmas, _, _ = pre_parse(code) - assert pragmas == expected_pragmas +@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) +def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): + mock_version("0.3.10") + settings, _, _ = pre_parse(code) + + assert settings == pre_parse_settings + + compiler_data = CompilerData(code) + + # check what happens after CompilerData constructor + if compiler_data_settings is None: + # None is sentinel here meaning that nothing changed + compiler_data_settings = pre_parse_settings + + assert compiler_data.settings == compiler_data_settings + + +invalid_pragmas = [ + # evm-versionnn + """ +# pragma evm-versionnn cancun + """, + # bad fork name + """ +# pragma evm-version cancunn + """, + # oppptimize + """ +# pragma oppptimize codesize + """, + # ggas + """ +# pragma optimize ggas + """, + # double specified + """ +# pragma optimize gas +# pragma optimize codesize + """, + # double specified + """ +# pragma evm-version cancun +# pragma evm-version shanghai + """, +] + + +@pytest.mark.parametrize("code", invalid_pragmas) +def test_invalid_pragma(code): + with pytest.raises(StructureException): + pre_parse(code) diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py index 1466166501..cb46ba238d 100644 --- a/tests/compiler/ir/test_optimize_ir.py +++ b/tests/compiler/ir/test_optimize_ir.py @@ -1,9 +1,13 @@ import pytest from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version from vyper.exceptions import StaticAssertionException from vyper.ir import optimizer +POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]} + + optimize_list = [ (["eq", 1, 2], [0]), (["lt", 1, 2], [1]), @@ -272,3 +276,106 @@ def test_operator_set_values(): assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"} assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"} assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"} + + +mload_merge_list = [ + # copy "backward" with no overlap between src and dst buffers, + # OK to become mcopy + ( + ["seq", ["mstore", 32, ["mload", 128]], ["mstore", 64, ["mload", 160]]], + ["mcopy", 32, 128, 64], + ), + # copy with overlap "backwards", OK to become mcopy + (["seq", ["mstore", 32, ["mload", 64]], ["mstore", 64, ["mload", 96]]], ["mcopy", 32, 64, 64]), + # "stationary" overlap (i.e. a no-op mcopy), OK to become mcopy + (["seq", ["mstore", 32, ["mload", 32]], ["mstore", 64, ["mload", 64]]], ["mcopy", 32, 32, 64]), + # copy "forward" with no overlap, OK to become mcopy + (["seq", ["mstore", 64, ["mload", 0]], ["mstore", 96, ["mload", 32]]], ["mcopy", 64, 0, 64]), + # copy "forwards" with overlap by one word, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 32]], ["mstore", 96, ["mload", 64]]], None), + # check "forward" overlap by one byte, must NOT become mcopy + (["seq", ["mstore", 64, ["mload", 1]], ["mstore", 96, ["mload", 33]]], None), + # check "forward" overlap by one byte again, must NOT become mcopy + (["seq", ["mstore", 63, ["mload", 0]], ["mstore", 95, ["mload", 32]]], None), + # copy 3 words with partial overlap "forwards", partially becomes mcopy + # (2 words are mcopied and 1 word is mload/mstored + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mstore", 160, ["mload", 96]]], + ), + # copy 4 words with partial overlap "forwards", becomes 2 mcopies of 2 words each + ( + [ + "seq", + ["mstore", 96, ["mload", 32]], + ["mstore", 128, ["mload", 64]], + ["mstore", 160, ["mload", 96]], + ["mstore", 192, ["mload", 128]], + ], + ["seq", ["mcopy", 96, 32, 64], ["mcopy", 160, 96, 64]], + ), + # copy 4 words with 1 byte of overlap, must NOT become mcopy + ( + [ + "seq", + ["mstore", 96, ["mload", 33]], + ["mstore", 128, ["mload", 65]], + ["mstore", 160, ["mload", 97]], + ["mstore", 192, ["mload", 129]], + ], + None, + ), + # Ensure only sequential mstore + mload sequences are optimized + ( + [ + "seq", + ["mstore", 0, ["mload", 32]], + ["sstore", 0, ["calldataload", 4]], + ["mstore", 32, ["mload", 64]], + ], + None, + ), + # not-word aligned optimizations (not overlap) + (["seq", ["mstore", 0, ["mload", 1]], ["mstore", 32, ["mload", 33]]], ["mcopy", 0, 1, 64]), + # not-word aligned optimizations (overlap) + (["seq", ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], None), + # not-word aligned optimizations (overlap and not-overlap) + ( + [ + "seq", + ["mstore", 0, ["mload", 1]], + ["mstore", 32, ["mload", 33]], + ["mstore", 1, ["mload", 0]], + ["mstore", 33, ["mload", 32]], + ], + ["seq", ["mcopy", 0, 1, 64], ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], + ), + # overflow test + ( + [ + "seq", + ["mstore", 2**256 - 1 - 31 - 32, ["mload", 0]], + ["mstore", 2**256 - 1 - 31, ["mload", 32]], + ], + ["mcopy", 2**256 - 1 - 31 - 32, 0, 64], + ), +] + + +@pytest.mark.parametrize("ir", mload_merge_list) +@pytest.mark.parametrize("evm_version", list(POST_CANCUN.keys())) +def test_mload_merge(ir, evm_version): + with anchor_evm_version(evm_version): + optimized = optimizer.optimize(IRnode.from_list(ir[0])) + if ir[1] is None: + # no-op, assert optimizer does nothing + expected = IRnode.from_list(ir[0]) + else: + expected = IRnode.from_list(ir[1]) + + assert optimized == expected diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py index 583a927b41..cd26659a5c 100644 --- a/tests/parser/features/test_assignment.py +++ b/tests/parser/features/test_assignment.py @@ -442,3 +442,63 @@ def bug(p: Point) -> Point: """ c = get_contract(code) assert c.bug((1, 2)) == (2, 1) + + +mload_merge_codes = [ + ( + """ +@external +def foo() -> uint256[4]: + # copy "backwards" + xs: uint256[4] = [1, 2, 3, 4] + +# dst < src + xs[0] = xs[1] + xs[1] = xs[2] + xs[2] = xs[3] + + return xs + """, + [2, 3, 4, 4], + ), + ( + """ +@external +def foo() -> uint256[4]: + # copy "forwards" + xs: uint256[4] = [1, 2, 3, 4] + +# src < dst + xs[1] = xs[0] + xs[2] = xs[1] + xs[3] = xs[2] + + return xs + """, + [1, 1, 1, 1], + ), + ( + """ +@external +def foo() -> uint256[5]: + # partial "forward" copy + xs: uint256[5] = [1, 2, 3, 4, 5] + +# src < dst + xs[2] = xs[0] + xs[3] = xs[1] + xs[4] = xs[2] + + return xs + """, + [1, 2, 1, 2, 1], + ), +] + + +# functional test that mload merging does not occur when source and dest +# buffers overlap. (note: mload merging only applies after cancun) +@pytest.mark.parametrize("code,expected_result", mload_merge_codes) +def test_mcopy_overlap(get_contract, code, expected_result): + c = get_contract(code) + assert c.foo() == expected_result diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 0ead889787..9d96efea5e 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -111,7 +111,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: validate_version_pragma(compiler_version, start) settings.compiler_version = compiler_version - if pragma.startswith("optimize "): + elif pragma.startswith("optimize "): if settings.optimize is not None: raise StructureException("pragma optimize specified twice!", start) try: @@ -119,7 +119,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: settings.optimize = OptimizationLevel.from_string(mode) except ValueError: raise StructureException(f"Invalid optimization mode `{mode}`", start) - if pragma.startswith("evm-version "): + elif pragma.startswith("evm-version "): if settings.evm_version is not None: raise StructureException("pragma evm-version specified twice!", start) evm_version = pragma.removeprefix("evm-version").strip() @@ -127,6 +127,9 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: raise StructureException("Invalid evm version: `{evm_version}`", start) settings.evm_version = evm_version + else: + raise StructureException(f"Unknown pragma `{pragma.split()[0]}`") + if typ == NAME and string in ("class", "yield"): raise SyntaxException( f"The `{string}` keyword is not allowed. ", code, start[0], start[1] diff --git a/vyper/codegen/memory_allocator.py b/vyper/codegen/memory_allocator.py index 582d4b9c54..b5e1212917 100644 --- a/vyper/codegen/memory_allocator.py +++ b/vyper/codegen/memory_allocator.py @@ -1,6 +1,6 @@ from typing import List -from vyper.exceptions import CompilerPanic +from vyper.exceptions import CompilerPanic, MemoryAllocationException from vyper.utils import MemoryPositions @@ -46,6 +46,8 @@ class MemoryAllocator: next_mem: int + _ALLOCATION_LIMIT: int = 2**64 + def __init__(self, start_position: int = MemoryPositions.RESERVED_MEMORY): """ Initializer. @@ -110,6 +112,14 @@ def _expand_memory(self, size: int) -> int: before_value = self.next_mem self.next_mem += size self.size_of_mem = max(self.size_of_mem, self.next_mem) + + if self.size_of_mem >= self._ALLOCATION_LIMIT: + # this should not be caught + raise MemoryAllocationException( + f"Tried to allocate {self.size_of_mem} bytes! " + f"(limit is {self._ALLOCATION_LIMIT} (2**64) bytes)" + ) + return before_value def deallocate_memory(self, pos: int, size: int) -> None: diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index c2951986c8..254cad32e6 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -91,17 +91,15 @@ def parse_Assign(self): return IRnode.from_list(ret) def parse_If(self): + with self.context.block_scope(): + test_expr = Expr.parse_value_expr(self.stmt.test, self.context) + body = ["if", test_expr, parse_body(self.stmt.body, self.context)] + if self.stmt.orelse: with self.context.block_scope(): - add_on = [parse_body(self.stmt.orelse, self.context)] - else: - add_on = [] + body.extend([parse_body(self.stmt.orelse, self.context)]) - with self.context.block_scope(): - test_expr = Expr.parse_value_expr(self.stmt.test, self.context) - body = ["if", test_expr, parse_body(self.stmt.body, self.context)] + add_on - ir_node = IRnode.from_list(body) - return ir_node + return IRnode.from_list(body) def parse_Log(self): event = self.stmt._metadata["type"] diff --git a/vyper/compiler/README.md b/vyper/compiler/README.md index d6b55fdd82..eb70750a2b 100644 --- a/vyper/compiler/README.md +++ b/vyper/compiler/README.md @@ -51,11 +51,9 @@ for specific implementation details. [`vyper.compiler.compile_codes`](__init__.py) is the main user-facing function for generating compiler output from Vyper source. The process is as follows: -1. The `@evm_wrapper` decorator sets the target EVM version in -[`opcodes.py`](../evm/opcodes.py). -2. A [`CompilerData`](phases.py) object is created for each contract to be compiled. +1. A [`CompilerData`](phases.py) object is created for each contract to be compiled. This object uses `@property` methods to trigger phases of the compiler as required. -3. Functions in [`output.py`](output.py) generate the requested outputs from the +2. Functions in [`output.py`](output.py) generate the requested outputs from the compiler data. ## Design diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0b3c0d8191..b1c4201361 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -120,17 +120,17 @@ def compile_codes( # make IR output the same between runs codegen.reset_names() - with anchor_evm_version(settings.evm_version): - compiler_data = CompilerData( - source_code, - contract_name, - interfaces, - source_id, - settings, - storage_layout_override, - show_gas_estimates, - no_bytecode_metadata, - ) + compiler_data = CompilerData( + source_code, + contract_name, + interfaces, + source_id, + settings, + storage_layout_override, + show_gas_estimates, + no_bytecode_metadata, + ) + with anchor_evm_version(compiler_data.settings.evm_version): for output_format in output_formats[contract_name]: if output_format not in OUTPUT_FORMATS: raise ValueError(f"Unsupported format type {repr(output_format)}") diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 935bb12cae..eb2d269714 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -89,9 +89,12 @@ def __init__( self.no_bytecode_metadata = no_bytecode_metadata self.settings = settings or Settings() + _ = self._generate_ast # force settings to be calculated + @cached_property def _generate_ast(self): settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name) + # validate the compiler settings # XXX: this is a bit ugly, clean up later if settings.evm_version is not None: @@ -118,6 +121,8 @@ def _generate_ast(self): if self.settings.optimize is None: self.settings.optimize = OptimizationLevel.default() + # note self.settings.compiler_version is erased here as it is + # not used after pre-parsing return ast @cached_property diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3346501de4..97abd6e4fc 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -269,6 +269,10 @@ class StorageLayoutException(VyperException): """Invalid slot for the storage layout overrides""" +class MemoryAllocationException(VyperException): + """Tried to allocate too much memory""" + + class JSONError(Exception): """Invalid compiler input JSON.""" diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 08c2168381..8df4bbac2d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -662,10 +662,10 @@ def _rewrite_mstore_dload(argz): def _merge_mload(argz): if not version_check(begin="cancun"): return False - return _merge_load(argz, "mload", "mcopy") + return _merge_load(argz, "mload", "mcopy", allow_overlap=False) -def _merge_load(argz, _LOAD, _COPY): +def _merge_load(argz, _LOAD, _COPY, allow_overlap=True): # look for sequential operations copying from X to Y # and merge them into a single copy operation changed = False @@ -689,9 +689,14 @@ def _merge_load(argz, _LOAD, _COPY): initial_dst_offset = dst_offset initial_src_offset = src_offset idx = i + + # dst and src overlap, discontinue the optimization + has_overlap = initial_src_offset < initial_dst_offset < src_offset + 32 + if ( initial_dst_offset + total_length == dst_offset and initial_src_offset + total_length == src_offset + and (allow_overlap or not has_overlap) ): mstore_nodes.append(ir_node) total_length += 32 diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 9b3f24d456..785fb5f399 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -121,18 +121,6 @@ def _check_iterator_modification( return None -def _validate_revert_reason(msg_node: vy_ast.VyperNode) -> None: - if msg_node: - if isinstance(msg_node, vy_ast.Str): - if not msg_node.value.strip(): - raise StructureException("Reason string cannot be empty", msg_node) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - validate_expected_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - - # helpers def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": @@ -182,6 +170,7 @@ def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: def _validate_self_reference(node: vy_ast.Name) -> None: + # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): raise StateAccessViolation("not allowed to query self in pure functions", node) @@ -214,6 +203,7 @@ def __init__( f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node ) + # visit default args assert self.func.n_keyword_args == len(fn_node.args.defaults) for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @@ -242,9 +232,22 @@ def visit_AnnAssign(self, node): self.expr_visitor.visit(node.target, typ) self.expr_visitor.visit(node.value, typ) + def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Str): + if not msg_node.value.strip(): + raise StructureException("Reason string cannot be empty", msg_node) + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): + try: + validate_expected_type(msg_node, StringT(1024)) + except TypeMismatch as e: + raise InvalidType("revert reason must fit within String[1024]") from e + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + def visit_Assert(self, node): if node.msg: - _validate_revert_reason(node.msg) + self._validate_revert_reason(node.msg) try: validate_expected_type(node.test, BoolT()) @@ -252,7 +255,8 @@ def visit_Assert(self, node): raise InvalidType("Assertion test value must be a boolean", node.test) self.expr_visitor.visit(node.test, BoolT()) - def visit_Assign(self, node): + # repeated code for Assign and AugAssign + def _assign_helper(self, node): if isinstance(node.value, vy_ast.Tuple): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) @@ -268,17 +272,11 @@ def visit_Assign(self, node): self.expr_visitor.visit(node.value, target.typ) self.expr_visitor.visit(node.target, target.typ) - def visit_AugAssign(self, node): - if isinstance(node.value, vy_ast.Tuple): - raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) - - lhs_info = get_expr_info(node.target) - - validate_expected_type(node.value, lhs_info.typ) - lhs_info.validate_modification(node, self.func.mutability) + def visit_Assign(self, node): + self._assign_helper(node) - self.expr_visitor.visit(node.value, lhs_info.typ) - self.expr_visitor.visit(node.target, lhs_info.typ) + def visit_AugAssign(self, node): + self._assign_helper(node) def visit_Break(self, node): for_node = node.get_ancestor(vy_ast.For) @@ -286,6 +284,7 @@ def visit_Break(self, node): raise StructureException("`break` must be enclosed in a `for` loop", node) def visit_Continue(self, node): + # TODO: use context/state instead of ast search for_node = node.get_ancestor(vy_ast.For) if for_node is None: raise StructureException("`continue` must be enclosed in a `for` loop", node) @@ -476,12 +475,12 @@ def visit_For(self, node): for_loop_exceptions = [] iter_name = node.target.id - for typ in type_list: + for possible_target_type in type_list: # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): try: - self.namespace[iter_name] = VarInfo(typ, is_constant=True) + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) except VyperException as exc: raise exc.with_annotation(node) from None @@ -492,20 +491,22 @@ def visit_For(self, node): except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - self.expr_visitor.visit(node.target, typ) + self.expr_visitor.visit(node.target, possible_target_type) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - typ = get_exact_type_from_node(node.iter) - self.expr_visitor.visit(node.iter, typ) + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) if isinstance(node.iter, vy_ast.List): len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(typ, len_)) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": for a in node.iter.args: - self.expr_visitor.visit(a, typ) + self.expr_visitor.visit(a, possible_target_type) for a in node.iter.keywords: if a.arg == "bound": - self.expr_visitor.visit(a.value, typ) + self.expr_visitor.visit(a.value, possible_target_type) # success -- do not enter error handling section return @@ -555,7 +556,7 @@ def visit_Log(self, node): def visit_Raise(self, node): if node.exc: - _validate_revert_reason(node.exc) + self._validate_revert_reason(node.exc) def visit_Return(self, node): values = node.value @@ -606,10 +607,15 @@ def visit(self, node, typ): def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) - if self.func and self.func.mutability is not StateMutability.PAYABLE: + # CMC 2023-10-19 TODO generalize this to mutability check on every node. + # something like, + # if self.func.mutability < expr_info.mutability: + # raise ... + + if self.func.mutability != StateMutability.PAYABLE: _validate_msg_value_access(node) - if self.func and self.func.mutability == StateMutability.PURE: + if self.func.mutability == StateMutability.PURE: _validate_pure_access(node, typ) value_type = get_exact_type_from_node(node.value) @@ -624,8 +630,8 @@ def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None: rtyp = typ if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): rtyp = get_possible_types_from_node(node.right).pop() - else: - validate_expected_type(node.right, rtyp) + + validate_expected_type(node.right, rtyp) self.visit(node.right, rtyp) @@ -645,7 +651,6 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if isinstance(call_type, ContractFunctionT): # function calls if call_type.is_internal: - assert self.func is not None self.func.called_functions.add(call_type) for arg, typ in zip(node.args, call_type.argument_types): self.visit(arg, typ) @@ -687,27 +692,27 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # membership in list literal - `x in [a, b, c]` + # needle: ltyp, haystack: rtyp if isinstance(node.right, vy_ast.List): - cmp_typ = get_common_types(node.left, *node.right.elements).pop() - ltyp = cmp_typ + ltyp = get_common_types(node.left, *node.right.elements).pop() rlen = len(node.right.elements) - rtyp = SArrayT(cmp_typ, rlen) + rtyp = SArrayT(ltyp, rlen) validate_expected_type(node.right, rtyp) - self.visit(node.right, rtyp) else: - cmp_typ = get_exact_type_from_node(node.right) - self.visit(node.right, cmp_typ) - if isinstance(cmp_typ, EnumT): + rtyp = get_exact_type_from_node(node.right) + if isinstance(rtyp, EnumT): # enum membership - `some_enum in other_enum` - ltyp = cmp_typ + ltyp = rtyp else: # array membership - `x in my_list_variable` - assert isinstance(cmp_typ, (SArrayT, DArrayT)) - ltyp = cmp_typ.value_type + assert isinstance(rtyp, (SArrayT, DArrayT)) + ltyp = rtyp.value_type validate_expected_type(node.left, ltyp) + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) else: # ex. a < b @@ -768,7 +773,8 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: base_type = get_exact_type_from_node(node.value) # get the correct type for the index, it might - # not be base_type.key_type + # not be exactly base_type.key_type + # note: index_type is validated in types_from_Subscript index_types = get_possible_types_from_node(node.slice.value) index_type = index_types.pop()