From 8ccacb3f47f864ec2ff64d5f7ca65625e9df6b2f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 10 Feb 2024 08:39:51 -0800 Subject: [PATCH] feat[lang]: singleton modules with ownership hierarchy (#3729) this commit implements "singleton modules with ownership hierarchy" as described in https://github.com/vyperlang/vyper/issues/3722. to accomplish this, two new language constructs are added: `UsesDecl` and `InitializesDecl`. these are exposed to the user as `uses:` and `initializes:`. they are also accompanied by new `AnalysisResult` data structures: `UsesInfo` and `InitializesInfo`. `uses` and `initializes` can be thought of as a constraint system on the module system. a `uses: my-module` annotation is required if `my_module`'s state is accessed (read or written), and `initializes: my_module` is required to call `my_module.__init__()`. a module can be `use`d any number of times; it can only be `initialize`d once. a module which has been used (directly, or transitively) by the compilation target (main entry point module), must be `initialize`d exactly once. `initializes:` is also required to declare which modules it has been `initialize`d with. for example, if `mod1` declares it `uses: mod2`, then any `initializes: mod1` statement must declare *which* instance of `mod2` it has been initialized with. although there is only ever a single instance of `mod2`, this user-facing requirement improves readability by forcing the user to be aware of what the state access dependencies are for a given, `initialize`d module. the `NamedExpr` node ("walrus operator") has been added to the AST to support the initializer syntax. (note: the walrus operator is used, because the originally proposed syntax, `mod1[mod2 = mod2]` is rejected by the python parser). a new compiler pass, `vyper/semantics/analysis/global.py` has been added to implement the global initializer constraint, as it cannot be defined recursively (without a global context). since `__init__()` functions can now be called from other `__init__()` functions (which is not allowed for normal `@external` functions!), a new `@deploy` visibility has been added to vyper's visibility system. `@deploy` functions can be called from other `@deploy` functions, and never from `@external` or `@internal` functions. they also have special treatment in the ABI relative to other `@external` functions. `initializes:` is useful since it also serves the purpose of being a storage allocator directive. wherever `initializes:` is placed, is where the module will be placed in storage (and code, transient storage, or any other future storage locations). this commit refactors the storage allocator so that it recurses into child modules whenever it sees an `initializes:` statement. it refactors several data structures surrounding the storage allocator, including removing inheritance on the `DataPosition` data structure (which has also been renamed to `VarOffset`). some utility functions have been added for calculating the size of a given variable, which also get used in codegen (`get_element_ptr()`). additional work/refactoring in this commit: - new analysis machinery for detecting reads/writes for all `ExprInfo`s - dynamic programming on the `get_expr_info()` routine - refactoring of `visit_Expr`, which fixes call mutability analysis - move `StringEnum` back to vyper/utils.py - remove the "TYPE_DEFINITION" kludge in certain builtins, replace with usage of `TYPE_T` - improve `tag_exceptions()` formatting - remove `Context.globals`, as we rely on the results of the front-end analyser now. - remove dead variable: `Context.in_assertion` - refactor `generate_ir_for_function` into `generate_ir_for_external_function` and `generate_ir_for_internal_function` - move `get_nonreentrant_lock` to `function_definitions/common.py` - simplify layout allocation across locations into single function - add `VyperType.get_size_in()` and `VarInfo.get_size()` helper functions so we don't need to do as much switch/case in implementation functions - refactor `codegen/core.py` functions to use `VyperType.get_size()` - fix interfaces access from `.vyi` files --- examples/auctions/blind_auction.vy | 4 +- examples/auctions/simple_open_auction.vy | 4 +- examples/crowdfund.vy | 4 +- examples/factory/Exchange.vy | 4 +- examples/factory/Factory.vy | 4 +- .../market_maker/on_chain_market_maker.vy | 2 + examples/name_registry/name_registry.vy | 1 + .../safe_remote_purchase.vy | 4 +- examples/stock/company.vy | 4 +- examples/storage/advanced_storage.vy | 4 +- examples/storage/storage.vy | 6 +- examples/tokens/ERC1155ownable.vy | 5 +- examples/tokens/ERC20.vy | 4 +- examples/tokens/ERC4626.vy | 4 +- examples/tokens/ERC721.vy | 4 +- examples/voting/ballot.vy | 4 +- examples/wallet/wallet.vy | 4 +- tests/functional/builtins/codegen/test_abi.py | 4 +- .../builtins/codegen/test_abi_decode.py | 2 +- .../builtins/codegen/test_abi_encode.py | 2 +- .../functional/builtins/codegen/test_ceil.py | 4 +- .../builtins/codegen/test_concat.py | 4 +- .../builtins/codegen/test_create_functions.py | 10 +- .../builtins/codegen/test_ecrecover.py | 2 +- .../functional/builtins/codegen/test_floor.py | 4 +- .../builtins/codegen/test_raw_call.py | 2 +- .../functional/builtins/codegen/test_slice.py | 10 +- .../test_default_function.py | 2 +- .../calling_convention/test_erc20_abi.py | 2 +- .../test_external_contract_calls.py | 31 +- ...test_modifiable_external_contract_calls.py | 8 +- .../calling_convention/test_return_tuple.py | 2 +- .../features/decorators/test_payable.py | 4 +- .../features/decorators/test_private.py | 4 +- .../features/iteration/test_range_in.py | 2 +- .../codegen/features/test_bytes_map_keys.py | 12 +- .../codegen/features/test_clampers.py | 2 +- .../codegen/features/test_constructor.py | 22 +- .../codegen/features/test_immutable.py | 51 +- .../functional/codegen/features/test_init.py | 8 +- .../codegen/features/test_logging.py | 4 +- .../codegen/features/test_ternary.py | 2 +- .../codegen/integration/test_crowdfund.py | 4 +- .../codegen/integration/test_escrow.py | 2 +- .../codegen/modules/test_module_constants.py | 20 + .../codegen/modules/test_module_variables.py | 318 +++++ .../codegen/storage_variables/test_getters.py | 4 +- .../test_storage_variable.py | 2 +- tests/functional/codegen/test_interfaces.py | 12 +- tests/functional/codegen/types/test_bytes.py | 2 +- .../codegen/types/test_dynamic_array.py | 4 +- tests/functional/codegen/types/test_flag.py | 2 +- tests/functional/codegen/types/test_string.py | 2 +- .../test_safe_remote_purchase.py | 2 +- .../syntax/exceptions/test_call_violation.py | 9 + .../exceptions/test_constancy_exception.py | 59 +- .../test_function_declaration_exception.py | 10 +- .../test_instantiation_exception.py | 2 +- .../exceptions/test_invalid_reference.py | 2 +- .../exceptions/test_structure_exception.py | 6 +- .../exceptions/test_vyper_exception_pos.py | 2 +- .../syntax/modules/test_deploy_visibility.py | 27 + .../syntax/modules/test_implements.py | 51 + .../syntax/modules/test_initializers.py | 1139 +++++++++++++++++ tests/functional/syntax/test_address_code.py | 4 +- tests/functional/syntax/test_codehash.py | 2 +- tests/functional/syntax/test_constants.py | 4 +- tests/functional/syntax/test_immutables.py | 22 +- tests/functional/syntax/test_init.py | 64 + tests/functional/syntax/test_interfaces.py | 4 +- tests/functional/syntax/test_public.py | 2 +- tests/functional/syntax/test_tuple_assign.py | 2 +- tests/unit/ast/test_ast_dict.py | 10 - .../cli/storage_layout/test_storage_layout.py | 250 +++- tests/unit/compiler/asm/test_asm_optimizer.py | 22 +- tests/unit/compiler/test_bytecode_runtime.py | 2 +- tests/unit/semantics/test_storage_slots.py | 4 +- vyper/ast/__init__.py | 3 +- vyper/ast/grammar.lark | 14 +- vyper/ast/nodes.py | 105 +- vyper/ast/nodes.pyi | 35 +- vyper/ast/parse.py | 4 +- vyper/builtins/_signatures.py | 13 +- vyper/builtins/_utils.py | 6 +- vyper/builtins/functions.py | 18 +- vyper/codegen/context.py | 19 +- vyper/codegen/core.py | 61 +- vyper/codegen/expr.py | 37 +- .../codegen/function_definitions/__init__.py | 5 +- vyper/codegen/function_definitions/common.py | 120 +- .../function_definitions/external_function.py | 49 +- .../function_definitions/internal_function.py | 34 +- vyper/codegen/function_definitions/utils.py | 31 - vyper/codegen/module.py | 31 +- vyper/codegen/stmt.py | 2 +- vyper/compiler/phases.py | 27 +- vyper/evm/address_space.py | 8 - vyper/exceptions.py | 25 +- vyper/semantics/analysis/__init__.py | 2 +- vyper/semantics/analysis/base.py | 286 ++--- vyper/semantics/analysis/constant_folding.py | 2 +- vyper/semantics/analysis/data_positions.py | 221 ++-- vyper/semantics/analysis/global_.py | 80 ++ vyper/semantics/analysis/local.py | 228 +++- vyper/semantics/analysis/module.py | 265 +++- vyper/semantics/analysis/utils.py | 45 +- vyper/semantics/data_locations.py | 16 +- vyper/semantics/types/base.py | 23 +- vyper/semantics/types/function.py | 91 +- vyper/semantics/types/module.py | 94 +- vyper/semantics/types/utils.py | 16 +- vyper/utils.py | 56 +- 112 files changed, 3566 insertions(+), 845 deletions(-) create mode 100644 tests/functional/codegen/modules/test_module_variables.py create mode 100644 tests/functional/syntax/modules/test_deploy_visibility.py create mode 100644 tests/functional/syntax/modules/test_implements.py create mode 100644 tests/functional/syntax/modules/test_initializers.py create mode 100644 tests/functional/syntax/test_init.py delete mode 100644 vyper/codegen/function_definitions/utils.py create mode 100644 vyper/semantics/analysis/global_.py diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 597aed57c7..966565138f 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Blind Auction. Adapted to Vyper from [Solidity by Example](https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst#blind-auction-1) struct Bid: @@ -36,7 +38,7 @@ pendingReturns: HashMap[address, uint256] # Create a blinded auction with `_biddingTime` seconds bidding time and # `_revealTime` seconds reveal time on behalf of the beneficiary address # `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _biddingTime: uint256, _revealTime: uint256): self.beneficiary = _beneficiary self.biddingEnd = block.timestamp + _biddingTime diff --git a/examples/auctions/simple_open_auction.vy b/examples/auctions/simple_open_auction.vy index 6d5ce06f17..499e12af16 100644 --- a/examples/auctions/simple_open_auction.vy +++ b/examples/auctions/simple_open_auction.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Open Auction # Auction params @@ -19,7 +21,7 @@ pendingReturns: public(HashMap[address, uint256]) # Create a simple auction with `_auction_start` and # `_bidding_time` seconds bidding time on behalf of the # beneficiary address `_beneficiary`. -@external +@deploy def __init__(_beneficiary: address, _auction_start: uint256, _bidding_time: uint256): self.beneficiary = _beneficiary self.auctionStart = _auction_start # auction start time can be in the past, present or future diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 6d07e15bc4..50ec005924 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -11,7 +13,7 @@ goal: public(uint256) timelimit: public(uint256) # Setup global variables -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/examples/factory/Exchange.vy b/examples/factory/Exchange.vy index 77f47984bc..e66c60743a 100644 --- a/examples/factory/Exchange.vy +++ b/examples/factory/Exchange.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 @@ -9,7 +11,7 @@ token: public(ERC20) factory: Factory -@external +@deploy def __init__(_token: ERC20, _factory: Factory): self.token = _token self.factory = _factory diff --git a/examples/factory/Factory.vy b/examples/factory/Factory.vy index 50e7a81bf6..4fec723197 100644 --- a/examples/factory/Factory.vy +++ b/examples/factory/Factory.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 interface Exchange: @@ -11,7 +13,7 @@ exchange_codehash: public(bytes32) exchanges: public(HashMap[ERC20, Exchange]) -@external +@deploy def __init__(_exchange_codehash: bytes32): # Register the exchange code hash during deployment of the factory self.exchange_codehash = _exchange_codehash diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index 4f9859584c..74b1307dc1 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + from ethereum.ercs import ERC20 diff --git a/examples/name_registry/name_registry.vy b/examples/name_registry/name_registry.vy index 7152851dac..937b41856b 100644 --- a/examples/name_registry/name_registry.vy +++ b/examples/name_registry/name_registry.vy @@ -1,3 +1,4 @@ +#pragma version >0.3.10 registry: HashMap[Bytes[100], address] diff --git a/examples/safe_remote_purchase/safe_remote_purchase.vy b/examples/safe_remote_purchase/safe_remote_purchase.vy index edc2163b85..91f0159a2d 100644 --- a/examples/safe_remote_purchase/safe_remote_purchase.vy +++ b/examples/safe_remote_purchase/safe_remote_purchase.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Safe Remote Purchase # Originally from # https://github.com/ethereum/solidity/blob/develop/docs/solidity-by-example.rst @@ -19,7 +21,7 @@ buyer: public(address) unlocked: public(bool) ended: public(bool) -@external +@deploy @payable def __init__(): assert (msg.value % 2) == 0 diff --git a/examples/stock/company.vy b/examples/stock/company.vy index 6293e6eea4..355432830d 100644 --- a/examples/stock/company.vy +++ b/examples/stock/company.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Financial events the contract logs event Transfer: @@ -27,7 +29,7 @@ price: public(uint256) holdings: HashMap[address, uint256] # Set up the company. -@external +@deploy def __init__(_company: address, _total_shares: uint256, initial_price: uint256): assert _total_shares > 0 assert initial_price > 0 diff --git a/examples/storage/advanced_storage.vy b/examples/storage/advanced_storage.vy index 2ba50280d7..42a455cbf1 100644 --- a/examples/storage/advanced_storage.vy +++ b/examples/storage/advanced_storage.vy @@ -1,10 +1,12 @@ +#pragma version >0.3.10 + event DataChange: setter: indexed(address) value: int128 storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x diff --git a/examples/storage/storage.vy b/examples/storage/storage.vy index 7d05e4708c..30f570f212 100644 --- a/examples/storage/storage.vy +++ b/examples/storage/storage.vy @@ -1,9 +1,11 @@ +#pragma version >0.3.10 + storedData: public(int128) -@external +@deploy def __init__(_x: int128): self.storedData = _x @external def set(_x: int128): - self.storedData = _x \ No newline at end of file + self.storedData = _x diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index d1e88dcd04..d88d459d64 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,9 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### -# @version >=0.3.4 """ @dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) @@ -122,7 +123,7 @@ interface IERC1155MetadataURI: ############### functions ############### -@external +@deploy def __init__(name: String[128], symbol: String[16], uri: String[MAX_URI_LENGTH], contractUri: String[MAX_URI_LENGTH]): """ @dev contract initialization on deployment diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 77550c3f5a..0e94b32b9d 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -38,7 +40,7 @@ totalSupply: public(uint256) minter: address -@external +@deploy def __init__(_name: String[32], _symbol: String[32], _decimals: uint8, _supply: uint256): init_supply: uint256 = _supply * 10 ** convert(_decimals, uint256) self.name = _name diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index 73721fdb98..699b5edd42 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy # example implementation of an ERC4626 vault @@ -50,7 +52,7 @@ event Withdraw: shares: uint256 -@external +@deploy def __init__(asset: ERC20): self.asset = asset diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index d3a8d1f13d..70dff96051 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -82,7 +84,7 @@ SUPPORTED_INTERFACES: constant(bytes4[2]) = [ 0x80ac58cd, ] -@external +@deploy def __init__(): """ @dev Contract constructor. diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 107716accf..daaf712e0f 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + # Voting with delegation. # Information about voters @@ -50,7 +52,7 @@ def directlyVoted(addr: address) -> bool: # Setup global variables -@external +@deploy def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 231f538ecf..7e92c7e89c 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,3 +1,5 @@ +#pragma version >0.3.10 + ########################################################################### ## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! ########################################################################### @@ -12,7 +14,7 @@ threshold: int128 seq: public(int128) -@external +@deploy def __init__(_owners: address[5], _threshold: int128): for i: uint256 in range(5): if _owners[i] != empty(address): diff --git a/tests/functional/builtins/codegen/test_abi.py b/tests/functional/builtins/codegen/test_abi.py index 4ddfcf50c1..335f728a37 100644 --- a/tests/functional/builtins/codegen/test_abi.py +++ b/tests/functional/builtins/codegen/test_abi.py @@ -8,14 +8,14 @@ """ x: int128 -@external +@deploy def __init__(): self.x = 1 """, """ x: int128 -@external +@deploy def __init__(): pass """, diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 69bfef63ea..96cbbe4c2d 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -224,7 +224,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index f4b7d57a04..8709e31470 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -263,7 +263,7 @@ def test_side_effects_evaluation(get_contract): contract_1 = """ counter: uint256 -@external +@deploy def __init__(): self.counter = 0 diff --git a/tests/functional/builtins/codegen/test_ceil.py b/tests/functional/builtins/codegen/test_ceil.py index daa9cb7c1b..191e2adfef 100644 --- a/tests/functional/builtins/codegen/test_ceil.py +++ b/tests/functional/builtins/codegen/test_ceil.py @@ -6,7 +6,7 @@ def test_ceil(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -53,7 +53,7 @@ def test_ceil_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_concat.py b/tests/functional/builtins/codegen/test_concat.py index 7354515989..37bdaaaf7b 100644 --- a/tests/functional/builtins/codegen/test_concat.py +++ b/tests/functional/builtins/codegen/test_concat.py @@ -79,7 +79,7 @@ def test_concat_buffer2(get_contract): code = """ i: immutable(int256) -@external +@deploy def __init__(): i = -1 s: String[2] = concat("a", "b") @@ -99,7 +99,7 @@ def test_concat_buffer3(get_contract): s2: String[33] s3: String[34] -@external +@deploy def __init__(): self.s = "a" self.s2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" # 33*'a' diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index afa729ac8a..0aa718157c 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -214,7 +214,7 @@ def test_create_from_blueprint_bad_code_offset( deployer_code = """ BLUEPRINT: immutable(address) -@external +@deploy def __init__(blueprint_address: address): BLUEPRINT = blueprint_address @@ -269,7 +269,7 @@ def test_create_from_blueprint_args( FOO: immutable(String[128]) BAR: immutable(Bar) -@external +@deploy def __init__(foo: String[128], bar: Bar): FOO = foo BAR = bar @@ -450,7 +450,7 @@ def test_create_from_blueprint_complex_value( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -507,7 +507,7 @@ def test_create_from_blueprint_complex_salt_raw_args( code = """ var: uint256 -@external +@deploy @payable def __init__(x: uint256): self.var = x @@ -565,7 +565,7 @@ def test_create_from_blueprint_complex_salt_no_constructor_args( code = """ var: uint256 -@external +@deploy @payable def __init__(): self.var = 12 diff --git a/tests/functional/builtins/codegen/test_ecrecover.py b/tests/functional/builtins/codegen/test_ecrecover.py index 8571948c3d..ce24868afe 100644 --- a/tests/functional/builtins/codegen/test_ecrecover.py +++ b/tests/functional/builtins/codegen/test_ecrecover.py @@ -68,7 +68,7 @@ def test_invalid_signature2(get_contract): owner: immutable(address) -@external +@deploy def __init__(): owner = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf diff --git a/tests/functional/builtins/codegen/test_floor.py b/tests/functional/builtins/codegen/test_floor.py index d2fd993785..5caffd5551 100644 --- a/tests/functional/builtins/codegen/test_floor.py +++ b/tests/functional/builtins/codegen/test_floor.py @@ -6,7 +6,7 @@ def test_floor(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = 504.0000000001 @@ -55,7 +55,7 @@ def test_floor_negative(get_contract_with_gas_estimation): code = """ x: decimal -@external +@deploy def __init__(): self.x = -504.0000000001 diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index b30a94502d..e5201e9bb2 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -137,7 +137,7 @@ def set_owner(i: int128, o: address): owners: public(address[5]) -@external +@deploy def __init__(_owner_setter: address): self.owner_setter_contract = _owner_setter diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 80936bbf82..0c5a8fc485 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -57,7 +57,7 @@ def test_slice_immutable( IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) IMMUTABLE_SLICE: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(inp: Bytes[{length_bound}], start: uint256, length: uint256): IMMUTABLE_BYTES = inp IMMUTABLE_SLICE = slice(IMMUTABLE_BYTES, {_start}, {_length}) @@ -119,7 +119,7 @@ def test_slice_bytes_fuzz( elif location == "code": preamble = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external +@deploy def __init__(foo: Bytes[{length_bound}]): IMMUTABLE_BYTES = foo """ @@ -230,7 +230,7 @@ def test_slice_immutable_length_arg(get_contract_with_gas_estimation): code = """ LENGTH: immutable(uint256) -@external +@deploy def __init__(): LENGTH = 5 @@ -314,7 +314,7 @@ def f() -> bytes32: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f @@ -325,7 +325,7 @@ def bar() -> Bytes[{length}]: """ foo: bytes32 -@external +@deploy def __init__(): self.foo = 0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index cf55607877..411f38eac9 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -2,7 +2,7 @@ def test_throw_on_sending(w3, tx_failed, get_contract_with_gas_estimation): code = """ x: public(int128) -@external +@deploy def __init__(): self.x = 123 """ diff --git a/tests/functional/codegen/calling_convention/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py index b9dc5c663f..59c4131fb2 100644 --- a/tests/functional/codegen/calling_convention/test_erc20_abi.py +++ b/tests/functional/codegen/calling_convention/test_erc20_abi.py @@ -33,7 +33,7 @@ def allowance(_owner: address, _spender: address) -> uint256: nonpayable token_address: ERC20Contract -@external +@deploy def __init__(token_addr: address): self.token_address = ERC20Contract(token_addr) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index a7cf4d0ecf..8b3f30b5a5 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -41,7 +41,7 @@ def test_complicated_external_contract_calls(get_contract, get_contract_with_gas contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky @@ -898,26 +898,31 @@ def set_lucky(arg1: address, arg2: int128): print("Successfully executed an external contract call state change") -def test_constant_external_contract_call_cannot_change_state( - assert_compile_failed, get_contract_with_gas_estimation -): +def test_constant_external_contract_call_cannot_change_state(): c = """ interface Foo: def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_expr(arg1: address, arg2: int128): +def set_lucky_stmt(arg1: address, arg2: int128): Foo(arg1).set_lucky(arg2) + """ + with pytest.raises(StateAccessViolation): + compile_code(c) + + c2 = """ +interface Foo: + def set_lucky(_lucky: int128) -> int128: nonpayable @external @view -def set_lucky_stmt(arg1: address, arg2: int128) -> int128: +def set_lucky_expr(arg1: address, arg2: int128) -> int128: return Foo(arg1).set_lucky(arg2) """ - assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation) - print("Successfully blocked an external contract call from a constant function") + with pytest.raises(StateAccessViolation): + compile_code(c2) def test_external_contract_can_be_changed_based_on_address(get_contract): @@ -968,7 +973,7 @@ def test_external_contract_calls_with_public_globals(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -994,7 +999,7 @@ def test_external_contract_calls_with_multiple_contracts(get_contract): contract_1 = """ lucky: public(int128) -@external +@deploy def __init__(_lucky: int128): self.lucky = _lucky """ @@ -1008,7 +1013,7 @@ def lucky() -> int128: view magic_number: public(int128) -@external +@deploy def __init__(arg1: address): self.magic_number = Foo(arg1).lucky() """ @@ -1020,7 +1025,7 @@ def magic_number() -> int128: view best_number: public(int128) -@external +@deploy def __init__(arg1: address): self.best_number = Bar(arg1).magic_number() """ @@ -1145,7 +1150,7 @@ def test_invalid_contract_reference_declaration(tx_failed, get_contract): best_number: public(int128) -@external +@deploy def __init__(): pass """ diff --git a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py index e6b2402016..aa7130fd6a 100644 --- a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py @@ -20,7 +20,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -64,7 +64,7 @@ def set_lucky(_lucky: int128) -> int128: view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -108,7 +108,7 @@ def set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) @@ -134,7 +134,7 @@ def static_set_lucky(_lucky: int128): view modifiable_bar_contract: ModBar static_bar_contract: ConstBar -@external +@deploy def __init__(contract_address: address): self.modifiable_bar_contract = ModBar(contract_address) self.static_bar_contract = ConstBar(contract_address) diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index 266555ead6..74929c9496 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -16,7 +16,7 @@ def test_return_type(get_contract_with_gas_estimation): c: int128 chunk: Chunk -@external +@deploy def __init__(): self.chunk.a = b"hello" self.chunk.b = b"world" diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index ced58e1af0..955501a0e6 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -122,7 +122,7 @@ def bar() -> bool: """, """ # payable init function -@external +@deploy @payable def __init__(): a: int128 = 1 @@ -279,7 +279,7 @@ def baz() -> bool: """, """ # init function -@external +@deploy def __init__(): a: int128 = 1 diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 39ea1bb9ae..193112f02b 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -120,7 +120,7 @@ def test_private_bytes(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[100]) -@external +@deploy def __init__(): self.greeting = b"Hello " @@ -143,7 +143,7 @@ def test_private_statement(get_contract_with_gas_estimation): private_test_code = """ greeting: public(Bytes[20]) -@external +@deploy def __init__(): self.greeting = b"Hello " diff --git a/tests/functional/codegen/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py index 7540049778..f381f60b35 100644 --- a/tests/functional/codegen/features/iteration/test_range_in.py +++ b/tests/functional/codegen/features/iteration/test_range_in.py @@ -115,7 +115,7 @@ def test_ownership(w3, tx_failed, get_contract_with_gas_estimation): owners: address[2] -@external +@deploy def __init__(): self.owners[0] = msg.sender diff --git a/tests/functional/codegen/features/test_bytes_map_keys.py b/tests/functional/codegen/features/test_bytes_map_keys.py index 4913182d52..22df767f02 100644 --- a/tests/functional/codegen/features/test_bytes_map_keys.py +++ b/tests/functional/codegen/features/test_bytes_map_keys.py @@ -80,7 +80,7 @@ def test_extended_bytes_key_from_storage(get_contract): code = """ a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"] = 1069 @@ -114,7 +114,7 @@ def test_struct_bytes_key_memory(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -145,7 +145,7 @@ def test_struct_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Foo -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.a[b"potato"] = 31337 @@ -172,7 +172,7 @@ def test_bytes_key_storage(get_contract): a: HashMap[Bytes[100000], int128] b: Bytes[5] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 self.b = b"hello" @@ -193,7 +193,7 @@ def test_bytes_key_calldata(get_contract): a: HashMap[Bytes[100000], int128] -@external +@deploy def __init__(): self.a[b"hello"] = 1069 @@ -215,7 +215,7 @@ def test_struct_bytes_hashmap_as_key_in_other_hashmap(get_contract): bar: public(HashMap[uint256, Thing]) foo: public(HashMap[Bytes[64], uint256]) -@external +@deploy def __init__(): self.foo[b"hello"] = 31337 self.bar[12] = Thing({name: b"hello"}) diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 6db8570fc7..c028805c6a 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -67,7 +67,7 @@ def test_bytes_clamper_on_init(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ foo: Bytes[3] -@external +@deploy def __init__(x: Bytes[3]): self.foo = x diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index c9dfcfc5df..9146ace8a6 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -6,7 +6,7 @@ def test_init_argument_test(get_contract_with_gas_estimation): init_argument_test = """ moose: int128 -@external +@deploy def __init__(_moose: int128): self.moose = _moose @@ -26,7 +26,7 @@ def test_constructor_mapping(get_contract_with_gas_estimation): X: constant(bytes4) = 0x01ffc9a7 -@external +@deploy def __init__(): self.foo[X] = True @@ -44,7 +44,7 @@ def test_constructor_advanced_code(get_contract_with_gas_estimation): constructor_advanced_code = """ twox: int128 -@external +@deploy def __init__(x: int128): self.twox = x * 2 @@ -60,7 +60,7 @@ def test_constructor_advanced_code2(get_contract_with_gas_estimation): constructor_advanced_code2 = """ comb: uint256 -@external +@deploy def __init__(x: uint256[2], y: Bytes[3], z: uint256): self.comb = x[0] * 1000 + x[1] * 100 + len(y) * 10 + z @@ -90,7 +90,7 @@ def foo(x: int128) -> int128: def test_large_input_code_2(w3, get_contract_with_gas_estimation): large_input_code_2 = """ -@external +@deploy def __init__(x: int128): y: int128 = x @@ -113,7 +113,7 @@ def test_initialise_array_with_constant_key(get_contract_with_gas_estimation): foo: int16[X] -@external +@deploy def __init__(): self.foo[X-1] = -2 @@ -133,7 +133,7 @@ def test_initialise_dynarray_with_constant_key(get_contract_with_gas_estimation) foo: DynArray[int16, X] -@external +@deploy def __init__(): self.foo = [X - 3, X - 4, X - 5, X - 6] @@ -151,7 +151,7 @@ def test_nested_dynamic_array_constructor_arg(w3, get_contract_with_gas_estimati code = """ foo: uint256 -@external +@deploy def __init__(x: DynArray[DynArray[uint256, 3], 3]): self.foo = x[0][2] + x[1][1] + x[2][0] @@ -167,7 +167,7 @@ def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estima code = """ foo: int128 -@external +@deploy def __init__(x: DynArray[DynArray[DynArray[int128, 3], 3], 3]): self.foo = x[0][1][2] * x[1][1][1] * x[2][1][0] - x[0][0][0] - x[1][1][1] - x[2][2][2] @@ -192,7 +192,7 @@ def test_initialise_nested_dynamic_array(w3, get_contract_with_gas_estimation): code = """ foo: DynArray[DynArray[uint256, 3], 3] -@external +@deploy def __init__(x: uint256, y: uint256, z: uint256): self.foo = [ [x, y, z], @@ -212,7 +212,7 @@ def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation) code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] -@external +@deploy def __init__(x: int128, y: int128, z: int128): self.foo = [ [[x, y, z], [y, z, x], [z, y, x]], diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 47f7fc748e..d0bc47c238 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -20,7 +20,7 @@ def test_value_storage_retrieval(typ, value, get_contract): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -41,7 +41,7 @@ def test_usage_in_constructor(get_contract, val): a: public(uint256) -@external +@deploy def __init__(_a: uint256): A = _a self.a = A @@ -63,7 +63,7 @@ def test_multiple_immutable_values(get_contract): b: immutable(address) c: immutable(String[64]) -@external +@deploy def __init__(_a: uint256, _b: address, _c: String[64]): a = _a b = _b @@ -89,7 +89,7 @@ def test_struct_immutable(get_contract): my_struct: immutable(MyStruct) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: address, _d: int256): my_struct = MyStruct({ a: _a, @@ -108,11 +108,34 @@ def get_my_struct() -> MyStruct: assert c.get_my_struct() == values +def test_complex_immutable_modifiable(get_contract): + code = """ +struct MyStruct: + a: uint256 + +my_struct: immutable(MyStruct) + +@deploy +def __init__(a: uint256): + my_struct = MyStruct({a: a}) + + # struct members are modifiable after initialization + my_struct.a += 1 + +@view +@external +def get_my_struct() -> MyStruct: + return my_struct + """ + c = get_contract(code, 1) + assert c.get_my_struct() == (2,) + + def test_list_immutable(get_contract): code = """ my_list: immutable(uint256[3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -130,7 +153,7 @@ def test_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[uint256, 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [_a, _b, _c] @@ -154,7 +177,7 @@ def test_nested_dynarray_immutable_2(get_contract): code = """ my_list: immutable(DynArray[DynArray[uint256, 3], 3]) -@external +@deploy def __init__(_a: uint256, _b: uint256, _c: uint256): my_list = [[_a, _b, _c], [_b, _a, _c], [_c, _b, _a]] @@ -179,7 +202,7 @@ def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) -@external +@deploy def __init__(x: int128, y: int128, z: int128): my_list = [ [[x, y, z], [y, z, x], [z, y, x]], @@ -227,7 +250,7 @@ def foo() -> uint256: counter: uint256 VALUE: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.counter = x self.foo() @@ -257,7 +280,7 @@ def foo() -> uint256: b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -281,7 +304,7 @@ def test_immutables_initialized2(get_contract, get_contract_from_ir): b: public(uint256) @payable -@external +@deploy def __init__(to_copy: address): c: address = create_copy_of(to_copy) self.b = a @@ -299,7 +322,7 @@ def test_internal_functions_called_by_ctor_location(get_contract): d: uint256 x: immutable(uint256) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -323,7 +346,7 @@ def test_nested_internal_function_immutables(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 @@ -348,7 +371,7 @@ def test_immutable_read_ctor_and_runtime(get_contract): d: public(uint256) x: public(immutable(uint256)) -@external +@deploy def __init__(): self.d = 1 x = 2 diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index fc765f8ab3..84d224f632 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -5,7 +5,7 @@ def test_basic_init_function(get_contract): code = """ val: public(uint256) -@external +@deploy def __init__(a: uint256): self.val = a """ @@ -27,10 +27,12 @@ def __init__(a: uint256): def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): code = """ foo: public(uint8) + @internal def bar(x: uint256) -> uint8: return convert(x, uint8) * 7 -@external + +@deploy def __init__(a: uint256): self.foo = self.bar(a) @@ -61,7 +63,7 @@ def test_nested_internal_call_from_ctor(get_contract): code = """ x: uint256 -@external +@deploy def __init__(): self.a() diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 0cb8ad9abc..8b80811d02 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -646,7 +646,7 @@ def test_logging_fails_with_over_three_topics(tx_failed, get_contract_with_gas_e arg3: indexed(int128) arg4: indexed(int128) -@external +@deploy def __init__(): log MyLog(1, 2, 3, 4) """ @@ -1033,7 +1033,7 @@ def test_mixed_var_list_packing(get_logs, get_contract_with_gas_estimation): x: int128[4] y: int128[2] -@external +@deploy def __init__(): self.y = [1024, 2048] diff --git a/tests/functional/codegen/features/test_ternary.py b/tests/functional/codegen/features/test_ternary.py index c5480286c8..661fdc86c9 100644 --- a/tests/functional/codegen/features/test_ternary.py +++ b/tests/functional/codegen/features/test_ternary.py @@ -195,7 +195,7 @@ def test_ternary_tuple(get_contract, code, test): def test_ternary_immutable(get_contract, test): code = """ IMM: public(immutable(uint256)) -@external +@deploy def __init__(test: bool): IMM = 1 if test else 2 """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 891ed5aebe..1a8b3f7e9f 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -13,7 +13,7 @@ def test_crowdfund(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit @@ -109,7 +109,7 @@ def test_crowdfund2(w3, tester, get_contract_with_gas_estimation_for_constants): refundIndex: int128 timelimit: public(uint256) -@external +@deploy def __init__(_beneficiary: address, _goal: uint256, _timelimit: uint256): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit diff --git a/tests/functional/codegen/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py index 70e7cb4594..f86b4aa516 100644 --- a/tests/functional/codegen/integration/test_escrow.py +++ b/tests/functional/codegen/integration/test_escrow.py @@ -41,7 +41,7 @@ def test_arbitration_code_with_init(w3, tx_failed, get_contract_with_gas_estimat seller: address arbitrator: address -@external +@deploy @payable def __init__(_seller: address, _arbitrator: address): if self.buyer == empty(address): diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py index aafbb69252..ebfefb4546 100644 --- a/tests/functional/codegen/modules/test_module_constants.py +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -76,3 +76,23 @@ def foo(ix: uint256) -> uint256: assert c.foo(2) == 3 with tx_failed(): c.foo(3) + + +def test_module_constant_builtin(make_input_bundle, get_contract): + # test empty builtin, which is not (currently) foldable 2024-02-06 + mod1 = """ +X: constant(uint256) = empty(uint256) + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 0 diff --git a/tests/functional/codegen/modules/test_module_variables.py b/tests/functional/codegen/modules/test_module_variables.py new file mode 100644 index 0000000000..6bb1f9072c --- /dev/null +++ b/tests/functional/codegen/modules/test_module_variables.py @@ -0,0 +1,318 @@ +def test_simple_import(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import lib + +initializes: lib + +@external +def increment_counter(): + lib.increment_counter() + +@external +def get_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == 0 + c.increment_counter(transact={}) + assert c.get_counter() == 1 + + +def test_import_namespace(get_contract, make_input_bundle): + # test what happens when things in current and imported modules share names + lib = """ +counter: uint256 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: uint256 + +initializes: lib + +@external +def increment_counter(): + self.counter += 1 + +@external +def increment_lib_counter(): + lib.increment_counter() + +@external +def increment_lib_counter2(): + # modify lib.counter directly + lib.counter += 5 + +@external +def get_counter() -> uint256: + return self.counter + +@external +def get_lib_counter() -> uint256: + return lib.counter + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_counter() == c.get_lib_counter() == 0 + + c.increment_counter(transact={}) + assert c.get_counter() == 1 + assert c.get_lib_counter() == 0 + + c.increment_lib_counter(transact={}) + assert c.get_lib_counter() == 1 + assert c.get_counter() == 1 + + c.increment_lib_counter2(transact={}) + assert c.get_lib_counter() == 6 + assert c.get_counter() == 1 + + +def test_init_function_side_effects(get_contract, make_input_bundle): + lib = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + + contract = """ +import library as lib + +counter: public(uint256) + +MY_IMMUTABLE: public(immutable(uint256)) + +initializes: lib + +@deploy +def __init__(): + self.counter = 1 + MY_IMMUTABLE = 3 + lib.__init__(5) + +@external +def get_lib_counter() -> uint256: + return lib.counter + +@external +def get_lib_immutable() -> uint256: + return lib.MY_IMMUTABLE + """ + + input_bundle = make_input_bundle({"library.vy": lib}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.counter() == 1 + assert c.MY_IMMUTABLE() == 3 + assert c.get_lib_counter() == 5 + assert c.get_lib_immutable() == 10 + + +def test_indirect_variable_uses(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 + +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value + MY_IMMUTABLE = initial_value * 2 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +uses: lib1 + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2[lib1 := lib1] + +@deploy +def __init__(): + lib1.__init__(5) + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 5 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 10 + + +def test_uses_already_initialized(get_contract, make_input_bundle): + lib1 = """ +counter: uint256 +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(initial_value: uint256): + self.counter = initial_value * 2 + MY_IMMUTABLE = initial_value * 3 + +@internal +def increment_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__(5) + +@internal +def get_lib1_counter() -> uint256: + return lib1.counter + +@internal +def get_lib1_my_immutable() -> uint256: + return lib1.MY_IMMUTABLE + """ + + contract = """ +import lib1 +import lib2 + +uses: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib2.__init__() + +@external +def get_storage_via_lib1() -> uint256: + return lib1.counter + +@external +def get_immutable_via_lib1() -> uint256: + return lib1.MY_IMMUTABLE + +@external +def get_storage_via_lib2() -> uint256: + return lib2.get_lib1_counter() + +@external +def get_immutable_via_lib2() -> uint256: + return lib2.get_lib1_my_immutable() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_storage_via_lib1() == c.get_storage_via_lib2() == 10 + assert c.get_immutable_via_lib1() == c.get_immutable_via_lib2() == 15 + + +def test_import_complex_types(get_contract, make_input_bundle): + lib1 = """ +an_array: uint256[3] +a_hashmap: HashMap[address, HashMap[uint256, uint256]] + +@internal +def set_array_value(ix: uint256, new_value: uint256): + self.an_array[ix] = new_value + +@internal +def set_hashmap_value(ix0: address, ix1: uint256, new_value: uint256): + self.a_hashmap[ix0][ix1] = new_value + """ + + contract = """ +import lib + +initializes: lib + +@external +def do_things(): + lib.set_array_value(1, 5) + lib.set_hashmap_value(msg.sender, 6, 100) + +@external +def get_array_value(ix: uint256) -> uint256: + return lib.an_array[ix] + +@external +def get_hashmap_value(ix: uint256) -> uint256: + return lib.a_hashmap[msg.sender][ix] + """ + + input_bundle = make_input_bundle({"lib.vy": lib1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + c.do_things(transact={}) + + assert c.get_array_value(0) == 0 + assert c.get_hashmap_value(0) == 0 + assert c.get_array_value(1) == 5 + assert c.get_hashmap_value(6) == 100 diff --git a/tests/functional/codegen/storage_variables/test_getters.py b/tests/functional/codegen/storage_variables/test_getters.py index a2d9c6d0bb..9e72bed075 100644 --- a/tests/functional/codegen/storage_variables/test_getters.py +++ b/tests/functional/codegen/storage_variables/test_getters.py @@ -41,7 +41,7 @@ def foo(): nonpayable f: public(constant(uint256[2])) = [3, 7] g: public(constant(V)) = V(0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF) -@external +@deploy def __init__(): self.x = as_wei_value(7, "wei") self.y[1] = 9 @@ -87,7 +87,7 @@ def test_getter_mutability(get_contract): nyoro: public(constant(uint256)) = 2 kune: public(immutable(uint256)) -@external +@deploy def __init__(): kune = 2 """ diff --git a/tests/functional/codegen/storage_variables/test_storage_variable.py b/tests/functional/codegen/storage_variables/test_storage_variable.py index 4636fa77e0..7a22d35e4b 100644 --- a/tests/functional/codegen/storage_variables/test_storage_variable.py +++ b/tests/functional/codegen/storage_variables/test_storage_variable.py @@ -10,7 +10,7 @@ def test_permanent_variables_test(get_contract_with_gas_estimation): b: int128 var: Var -@external +@deploy def __init__(a: int128, b: int128): self.var.a = a self.var.b = b diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 3344ff113b..85efe904a0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -305,7 +305,7 @@ def test() -> uint256: view token_address: IToken -@external +@deploy def __init__(_token_address: address): self.token_address = IToken(_token_address) @@ -388,7 +388,7 @@ def transfer(to: address, amount: uint256) -> bool: token_address: ERC20 -@external +@deploy def __init__(_token_address: address): self.token_address = ERC20(_token_address) @@ -445,7 +445,7 @@ def should_fail() -> {typ}: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -501,7 +501,7 @@ def should_fail() -> Bytes[2]: view foo: BadContract -@external +@deploy def __init__(addr: BadContract): self.foo = addr @@ -551,7 +551,7 @@ def foo(x: BadJSONInterface) -> Bytes[2]: foo: BadJSONInterface -@external +@deploy def __init__(addr: BadJSONInterface): self.foo = addr @@ -667,7 +667,7 @@ def foo() -> uint256: view bar_contract: Bar -@external +@deploy def __init__(): self.bar_contract = Bar(self) diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 325f9d7923..99e5835f6e 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -51,7 +51,7 @@ def test_test_bytes3(get_contract_with_gas_estimation): maa: Bytes[60] y: int128 -@external +@deploy def __init__(): self.x = 27 self.y = 37 diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index d3d945740b..fc3223caaf 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -1665,7 +1665,7 @@ def ix(i: uint256) -> decimal: def test_public_dynarray(get_contract): code = """ my_list: public(DynArray[uint256, 5]) -@external +@deploy def __init__(): self.my_list = [1,2,3] """ @@ -1678,7 +1678,7 @@ def __init__(): def test_nested_public_dynarray(get_contract): code = """ my_list: public(DynArray[DynArray[uint256, 5], 5]) -@external +@deploy def __init__(): self.my_list = [[1,2,3]] """ diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index 5da6d57558..dd9c867a96 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -160,7 +160,7 @@ def test_augassign_storage(get_contract, w3, tx_failed): roles: public(HashMap[address, Roles]) -@external +@deploy def __init__(): self.roles[msg.sender] = Roles.ADMIN diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 9d50f8df38..9d596eda32 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -90,7 +90,7 @@ def test_private_string(get_contract_with_gas_estimation): private_test_code = """ greeting: public(String[100]) -@external +@deploy def __init__(): self.greeting = "Hello " diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index e21a113f61..f6eb3966d4 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -118,7 +118,7 @@ def unlocked() -> bool: view purchase_contract: PurchaseContract -@external +@deploy def __init__(_purchase_contract: address): self.purchase_contract = PurchaseContract(_purchase_contract) diff --git a/tests/functional/syntax/exceptions/test_call_violation.py b/tests/functional/syntax/exceptions/test_call_violation.py index d310a2b42a..d96df07e74 100644 --- a/tests/functional/syntax/exceptions/test_call_violation.py +++ b/tests/functional/syntax/exceptions/test_call_violation.py @@ -27,6 +27,15 @@ def goo(): def foo(): self.goo() """, + """ +@deploy +def __init__(): + pass + +@internal +def foo(): + self.__init__() + """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 7adf9538c7..6bfb8fee57 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -78,7 +78,7 @@ def foo(): """ f:int128 -@external +@internal def a (x:int128): self.f = 100 @@ -86,6 +86,63 @@ def a (x:int128): @external def b(): self.a(10)""", + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: view +@external +@pure +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + x:uint16 = a.bar() + """, + """ +interface A: + def bar() -> uint16: nonpayable +@external +@view +def test(to:address): + a:A = A(to) + a.bar() + """, + """ +a:DynArray[uint16,3] +@deploy +def __init__(): + self.a = [1,2,3] +@view +@external +def bar()->DynArray[uint16,3]: + x:uint16 = self.a.pop() + return self.a # return [1,2] + """, + """ +from ethereum.ercs import ERC20 + +token: ERC20 + +@external +@view +def topup(amount: uint256): + assert self.token.transferFrom(msg.sender, self, amount) + """, ], ) def test_statefulness_violations(bad_code): diff --git a/tests/functional/syntax/exceptions/test_function_declaration_exception.py b/tests/functional/syntax/exceptions/test_function_declaration_exception.py index 3fe23e0ec7..878c7f3e29 100644 --- a/tests/functional/syntax/exceptions/test_function_declaration_exception.py +++ b/tests/functional/syntax/exceptions/test_function_declaration_exception.py @@ -34,17 +34,17 @@ def test_func() -> int128: return (1, 2) """, """ -@external +@deploy def __init__(a: int128 = 12): pass """, """ -@external +@deploy def __init__() -> uint256: return 1 """, """ -@external +@deploy def __init__() -> bool: pass """, @@ -58,7 +58,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @pure def __init__(): a = 1 @@ -66,7 +66,7 @@ def __init__(): """ a: immutable(uint256) -@external +@deploy @view def __init__(): a = 1 diff --git a/tests/functional/syntax/exceptions/test_instantiation_exception.py b/tests/functional/syntax/exceptions/test_instantiation_exception.py index 0d641f154a..4dd0bf6e02 100644 --- a/tests/functional/syntax/exceptions/test_instantiation_exception.py +++ b/tests/functional/syntax/exceptions/test_instantiation_exception.py @@ -69,7 +69,7 @@ def foo(): """ b: immutable(HashMap[uint256, uint256]) -@external +@deploy def __init__(): b = empty(HashMap[uint256, uint256]) """, diff --git a/tests/functional/syntax/exceptions/test_invalid_reference.py b/tests/functional/syntax/exceptions/test_invalid_reference.py index fe315e5cbf..7519d1406e 100644 --- a/tests/functional/syntax/exceptions/test_invalid_reference.py +++ b/tests/functional/syntax/exceptions/test_invalid_reference.py @@ -47,7 +47,7 @@ def foo(): """ a: public(immutable(uint256)) -@external +@deploy def __init__(): a = 123 diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index c6d733fc90..afc7a35012 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -94,7 +94,7 @@ def foo(): a: immutable(uint256) n: public(HashMap[uint256, bool][a]) -@external +@deploy def __init__(): a = 3 """, @@ -105,14 +105,14 @@ def __init__(): m1: HashMap[uint8, uint8] m2: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = self.m2 """, """ m1: HashMap[uint8, uint8] -@external +@deploy def __init__(): self.m1 = 234 """, diff --git a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py index a261cb0a11..9e0767cb83 100644 --- a/tests/functional/syntax/exceptions/test_vyper_exception_pos.py +++ b/tests/functional/syntax/exceptions/test_vyper_exception_pos.py @@ -22,7 +22,7 @@ def test_multiple_exceptions(get_contract, assert_compile_failed): foo: immutable(uint256) bar: immutable(uint256) -@external +@deploy def __init__(): self.foo = 1 # SyntaxException self.bar = 2 # SyntaxException diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py new file mode 100644 index 0000000000..f51bf9575b --- /dev/null +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -0,0 +1,27 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import CallViolation + + +def test_call_deploy_from_external(make_input_bundle): + lib1 = """ +@deploy +def __init__(): + pass + """ + + main = """ +import lib1 + +@external +def foo(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(CallViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value.message == "Cannot call an @deploy function from an @external function!" diff --git a/tests/functional/syntax/modules/test_implements.py b/tests/functional/syntax/modules/test_implements.py new file mode 100644 index 0000000000..c292e198d9 --- /dev/null +++ b/tests/functional/syntax/modules/test_implements.py @@ -0,0 +1,51 @@ +from vyper.compiler import compile_code + + +def test_implements_from_vyi(make_input_bundle): + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + main = """ +import lib1 + +implements: lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_implements_from_vyi2(make_input_bundle): + # test implements via nested imported vyi file + vyi = """ +@external +def foo(): + ... + """ + lib1 = """ +import some_interface + """ + lib2 = """ +import lib1 + """ + main = """ +import lib2 + +implements: lib2.lib1.some_interface + +@external +def foo(): # implementation + pass + """ + input_bundle = make_input_bundle({"some_interface.vyi": vyi, "lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py new file mode 100644 index 0000000000..a12f5f57ea --- /dev/null +++ b/tests/functional/syntax/modules/test_initializers.py @@ -0,0 +1,1139 @@ +""" +tests for the uses/initializes checker +main properties to test: +- state usage -- if a module uses state, it must `used` or `initialized` +- conversely, if a module does not touch state, it should not be `used` +- global initializer check: each used module is `initialized` exactly once +""" + +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import ( + BorrowException, + ImmutableViolation, + InitializerException, + StructureException, + UndeclaredDefinition, +) + + +def test_initialize_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib2 +import lib1 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + lib2.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multiple_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +# multiple uses on one line +uses: ( + lib1, + lib2 +) + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_multi_line_uses(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +totalSupply: uint256 + """ + lib3 = """ +import lib1 +import lib2 + +uses: lib1 +uses: lib2 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + x: uint256 = lib2.totalSupply + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib2 +initializes: lib3[ + lib1 := lib1, + lib2 := lib2 +] + +@deploy +def __init__(): + lib1.__init__() + lib3.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initialize_uses_attribute(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@deploy +def __init__(): + pass + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + lib2.__init__() + # demonstrate we can call lib1.__init__ through lib2.lib1 + # (not sure this should be allowed, really. + lib2.lib1.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializes_without_init_function(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@deploy +def __init__(): + pass + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_imported_as_different_names(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 as some_module +import lib2 + +initializes: lib2[m := some_module] +initializes: some_module + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_initializer_list_module_mismatch(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +something: uint256 + """ + lib3 = """ +import lib1 + +uses: lib1 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 +import lib3 + +initializes: lib1 +initializes: lib3[lib1 := lib2] # typo -- should be [lib1 := lib1] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3}) + + with pytest.raises(StructureException) as e: + assert compile_code(main, input_bundle=input_bundle) is not None + + assert e.value._message == "lib1 is not lib2!" + + +def test_imported_as_different_names_error(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 as m + +uses: m + +counter: uint256 + +@internal +def foo(): + m.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "unknown module `lib1`" + assert e.value._hint == "did you mean `m := lib1`?" + + +def test_global_initializer_constraint(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +# forgot to initialize lib1! + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "module `lib1.vy` is used but never initialized!" + assert e.value._hint == "add `initializes: lib1` to the top level of your main contract" + + +def test_initializer_no_references(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`" + assert e.value._hint == "add `lib1` to its initializer list" + + +def test_missing_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.counter += 1 + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.counter + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_immutable(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 7 + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.MY_IMMUTABLE + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_read_inside_call(make_input_bundle): + lib1 = """ +MY_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + MY_IMMUTABLE = 9 + +@internal +def get_counter() -> uint256: + return MY_IMMUTABLE + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo() -> uint256: + return lib1.get_counter() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_hashmap(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + return lib1.counter[1][2] + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + """ + lib2 = """ +import lib1 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +something: uint256 + +# forgot `uses: lib1`! + +@internal +def foo() -> uint256: + lib1.counter[1][2], self.something = Foo(msg.sender).foo() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_for_tuple_function_call(make_input_bundle): + lib1 = """ +counter: HashMap[uint256, HashMap[uint256, uint256]] + +something: uint256 + +interface Foo: + def foo() -> (uint256, uint256): nonpayable + +@internal +def write_tuple(): + self.counter[1][2], self.something = Foo(msg.sender).foo() + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! +@internal +def foo(): + lib1.write_tuple() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2 + +@deploy +def __init__(): + lib1.counter = 100 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_function_call(make_input_bundle): + # test missing uses through function call + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +# forgot `uses: lib1`! + +counter: uint256 + +@internal +def foo(): + lib1.update_counter(lib1.counter + 1) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_missing_uses_nested_attribute_function_call(make_input_bundle): + # test missing uses through nested attribute access + lib1 = """ +counter: uint256 + +@internal +def update_counter(new_value: uint256): + self.counter = new_value + """ + lib2 = """ +import lib1 + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 + +# did not `use` or `initialize` lib2! + +@external +def foo(new_value: uint256): + # cannot access lib1 state through lib2 + lib2.lib1.update_counter(new_value) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib2` state!" + + expected_hint = "add `uses: lib2` or `initializes: lib2` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_uses_skip_import(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2 + +@external +def foo(new_value: uint256): + # can access lib1 state through lib2? + lib2.lib1.counter = new_value + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "Cannot access `lib1` state!" + + expected_hint = "add `uses: lib1` or `initializes: lib1` as a " + expected_hint += "top-level statement to your contract" + assert e.value._hint == expected_hint + + +def test_invalid_uses(make_input_bundle): + lib1 = """ +counter: uint256 + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_invalid_uses2(make_input_bundle): + # test a more complicated invalid uses + lib1 = """ +counter: uint256 + +@internal +def foo(addr: address): + # sends value -- modifies ethereum state + to_send_value: uint256 = 100 + raw_call(addr, b"someFunction()", value=to_send_value) + """ + lib2 = """ +import lib1 + +uses: lib1 # not necessary! + +counter: uint256 + +@internal +def foo(): + lib1.foo(msg.sender) + """ + main = """ +import lib1 +import lib2 + +initializes: lib2[lib1 := lib1] +initializes: lib1 + +@external +def foo(): + lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(BorrowException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "`lib1` is declared as used, but it is not actually used in lib2.vy!" + assert e.value._hint == "delete `uses: lib1`" + + +def test_initializes_uses_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_uses_initializes_conflict(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_uses_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +uses: lib1 + +random_variable: constant(uint256) = 3 + +uses: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `uses`" + + +def test_initializes_twice(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 + +initializes: lib1 + +random_variable: constant(uint256) = 3 + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "ownership already set to `initializes`" + + +def test_no_initialize_unused_module(make_input_bundle): + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + main = """ +import lib1 + +# not needed: `initializes: lib1` + +@external +def do_add(x: uint256, y: uint256) -> uint256: + return lib1.add(x, y) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_no_initialize_unused_module2(make_input_bundle): + # slightly more complicated + lib1 = """ +counter: uint256 + +@internal +def set_counter(new_value: uint256): + self.counter = new_value + +@internal +@pure +def add(x: uint256, y: uint256) -> uint256: + return x + y + """ + lib2 = """ +import lib1 + +@internal +@pure +def addmul(x: uint256, y: uint256, z: uint256) -> uint256: + return lib1.add(x, y) * z + """ + main = """ +import lib1 +import lib2 + +@external +def do_addmul(x: uint256, y: uint256) -> uint256: + return lib2.addmul(x, y, 5) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + assert compile_code(main, input_bundle=input_bundle) is not None + + +def test_init_uninitialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_init_uninitialized_function2(make_input_bundle): + # test that we can't call module.__init__() even when we call `uses` + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + pass + """ + main = """ +import lib1 + +uses: lib1 +# missing `initializes: lib1`! + +@deploy +def __init__(): + lib1.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "tried to initialize `lib1`, but it is not in initializer list!" + assert e.value._hint == "add `initializes: lib1` as a top-level statement to your contract" + + +def test_noinit_initialized_function(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + pass # missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_noinit_initialized_function2(make_input_bundle): + lib1 = """ +counter: uint256 + +@deploy +def __init__(): + self.counter = 5 + """ + main = """ +import lib1 + +initializes: lib1 + +# missing `lib1.__init__()`! + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InitializerException) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "not initialized!" + assert e.value._hint == "add `lib1.__init__()` to your `__init__()` function" + + +def test_ownership_decl_errors_not_swallowed(make_input_bundle): + lib1 = """ +counter: uint256 + """ + main = """ +import lib1 +# forgot to import lib2 + +uses: (lib1, lib2) # should get UndeclaredDefinition + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(UndeclaredDefinition) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "'lib2' has not been declared. " diff --git a/tests/functional/syntax/test_address_code.py b/tests/functional/syntax/test_address_code.py index fa6ed20117..5873eb5af8 100644 --- a/tests/functional/syntax/test_address_code.py +++ b/tests/functional/syntax/test_address_code.py @@ -165,7 +165,7 @@ def test_address_code_self_success(get_contract, optimize): code = """ code_deployment: public(Bytes[32]) -@external +@deploy def __init__(): self.code_deployment = slice(self.code, 0, 32) @@ -186,7 +186,7 @@ def test_address_code_self_runtime_error_deployment(get_contract): code = """ dummy: public(Bytes[1000000]) -@external +@deploy def __init__(): self.dummy = slice(self.code, 0, 1000000) """ diff --git a/tests/functional/syntax/test_codehash.py b/tests/functional/syntax/test_codehash.py index c2d9a2e274..8aada22da7 100644 --- a/tests/functional/syntax/test_codehash.py +++ b/tests/functional/syntax/test_codehash.py @@ -11,7 +11,7 @@ def test_get_extcodehash(get_contract, evm_version, optimize): code = """ a: address -@external +@deploy def __init__(): self.a = self diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index 57922f28e2..63abf24485 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -94,7 +94,7 @@ VAL: immutable(uint256) VAL: uint256 -@external +@deploy def __init__(): VAL = 1 """, @@ -106,7 +106,7 @@ def __init__(): VAL: uint256 VAL: immutable(uint256) -@external +@deploy def __init__(): VAL = 1 """, diff --git a/tests/functional/syntax/test_immutables.py b/tests/functional/syntax/test_immutables.py index 1027d9fe66..59fb1a69d9 100644 --- a/tests/functional/syntax/test_immutables.py +++ b/tests/functional/syntax/test_immutables.py @@ -8,7 +8,7 @@ """ VALUE: immutable(uint256) -@external +@deploy def __init__(): pass """, @@ -25,7 +25,7 @@ def get_value() -> uint256: """ VALUE: immutable(uint256) = 3 -@external +@deploy def __init__(): pass """, @@ -33,7 +33,7 @@ def __init__(): """ VALUE: immutable(uint256) -@external +@deploy def __init__(): VALUE = 0 @@ -45,7 +45,7 @@ def set_value(_value: uint256): """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 VALUE = VALUE + 1 @@ -54,7 +54,7 @@ def __init__(_value: uint256): """ VALUE: immutable(public(uint256)) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 """, @@ -85,7 +85,7 @@ def test_compilation_simple_usage(typ): code = f""" VALUE: immutable({typ}) -@external +@deploy def __init__(_value: {typ}): VALUE = _value @@ -103,7 +103,7 @@ def get_value() -> {typ}: """ VALUE: immutable(uint256) -@external +@deploy def __init__(_value: uint256): VALUE = _value * 3 x: uint256 = VALUE + 1 @@ -121,7 +121,7 @@ def test_compilation_success(good_code): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): self.imm = x """, @@ -131,7 +131,7 @@ def __init__(x: uint256): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): x = imm @@ -145,7 +145,7 @@ def report(): """ imm: immutable(uint256) -@external +@deploy def __init__(x: uint256): imm = x @@ -163,7 +163,7 @@ def report(): x: immutable(Foo) -@external +@deploy def __init__(): x = Foo({a:1}) diff --git a/tests/functional/syntax/test_init.py b/tests/functional/syntax/test_init.py new file mode 100644 index 0000000000..389b5ad681 --- /dev/null +++ b/tests/functional/syntax/test_init.py @@ -0,0 +1,64 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.exceptions import FunctionDeclarationException + +good_list = [ + """ +@deploy +def __init__(): + pass + """, + """ +@deploy +@payable +def __init__(): + pass + """, + """ +counter: uint256 +SOME_IMMUTABLE: immutable(uint256) + +@deploy +def __init__(): + SOME_IMMUTABLE = 5 + self.counter = 1 + """, +] + + +@pytest.mark.parametrize("code", good_list) +def test_good_init_funcs(code): + assert compile_code(code) is not None + + +fail_list = [ + """ +@internal +def __init__(): + pass + """, + """ +@deploy +@view +def __init__(): + pass + """, + """ +@deploy +@pure +def __init__(): + pass + """, + """ +@deploy +def some_function(): # for now, only __init__() functions can be marked @deploy + pass + """, +] + + +@pytest.mark.parametrize("code", fail_list) +def test_bad_init_funcs(code): + with pytest.raises(FunctionDeclarationException): + compile_code(code) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 584e497534..a07ec4e3dc 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -304,7 +304,7 @@ def some_func(): nonpayable my_interface: MyInterface[3] idx: uint256 -@external +@deploy def __init__(): self.my_interface[self.idx] = MyInterface(empty(address)) """, @@ -348,7 +348,7 @@ def foo() -> uint256: view foo: public(immutable(uint256)) -@external +@deploy def __init__(x: uint256): foo = x """, diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 71bff753f4..217fcea998 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -10,7 +10,7 @@ x: public(constant(int128)) = 0 y: public(immutable(int128)) -@external +@deploy def __init__(): y = 0 """, diff --git a/tests/functional/syntax/test_tuple_assign.py b/tests/functional/syntax/test_tuple_assign.py index 49b63ee614..bb23804e30 100644 --- a/tests/functional/syntax/test_tuple_assign.py +++ b/tests/functional/syntax/test_tuple_assign.py @@ -92,7 +92,7 @@ def test(a: bytes32) -> (bytes32, uint256, int128): """ B: immutable(uint256) -@external +@deploy def __init__(b: uint256): B = b diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 20390f3d5e..9fec61cb90 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -109,16 +109,6 @@ def foo() -> uint256: "node_id": 9, "src": "48:15:0", "ast_type": "ImplementsDecl", - "target": { - "col_offset": 0, - "end_col_offset": 10, - "node_id": 10, - "src": "48:10:0", - "ast_type": "Name", - "end_lineno": 5, - "lineno": 5, - "id": "implements", - }, "end_lineno": 5, "lineno": 5, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index 1aa8901881..f0ee25f747 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -56,7 +56,7 @@ def test_storage_and_immutables_layout(): SYMBOL: immutable(String[32]) DECIMALS: immutable(uint8) -@external +@deploy def __init__(): SYMBOL = "VYPR" DECIMALS = 18 @@ -72,3 +72,251 @@ def __init__(): out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected_layout + + +def test_storage_layout_module(make_input_bundle): + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + a_library.__init__() + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "counter2": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module2(make_input_bundle): + # test module storage layout, but initializes is in a different order + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + code = """ +import lib1 as a_library + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +initializes: a_library + +counter2: uint256 + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "a_library": {"supply": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_uses(make_input_bundle): + # test module storage layout, with initializes/uses + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +uses: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2[lib1 := a_library] + +counter2: uint256 + +initializes: a_library + +@deploy +def __init__(): + a_library.__init__() + some_immutable = [1, 2, 3] + + lib2.__init__(17) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": {"immutable_variable": {"length": 32, "offset": 352, "type": "uint256"}}, + "a_library": { + "SYMBOL": {"length": 64, "offset": 384, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 448, "type": "uint8"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout + + +def test_storage_layout_module_nested_initializes(make_input_bundle): + # test module storage layout, with initializes in an imported module + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + lib1.__init__() + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 + +counter2: uint256 + +uses: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + + lib2.__init__(17) + +@external +def foo() -> uint256: + return a_library.supply + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + expected_layout = { + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": { + "lib1": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + "immutable_variable": {"length": 32, "offset": 448, "type": "uint256"}, + }, + }, + "storage_layout": { + "counter": {"slot": 0, "type": "uint256"}, + "lib2": { + "lib1": {"supply": {"slot": 1, "type": "uint256"}}, + "storage_variable": {"slot": 2, "type": "uint256"}, + }, + "counter2": {"slot": 3, "type": "uint256"}, + }, + } + + out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) + assert out["layout"] == expected_layout diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index b2851e908a..ce32249202 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -20,7 +20,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -44,7 +44,7 @@ def ctor_only(): def bar(): self.foo() -@external +@deploy def __init__(): self.ctor_only() """, @@ -65,7 +65,7 @@ def runtime_only(): def bar(): self.runtime_only() -@external +@deploy def __init__(): self.ctor_only() """, @@ -73,6 +73,9 @@ def __init__(): # check dead code eliminator works on unreachable functions +# CMC 2024-02-05 this is not really the asm eliminator anymore, +# it happens during function code generation in module.py. so we don't +# need to test this using asm anymore. @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) @@ -88,20 +91,9 @@ def test_dead_code_eliminator(code): assert any(ctor_only in instr for instr in initcode_asm) assert all(runtime_only not in instr for instr in initcode_asm) - # all labels should be in unoptimized runtime asm - for s in (ctor_only, runtime_only): - assert any(s in instr for instr in runtime_asm) - - c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if isinstance(i, str)] - runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] - - # ctor only label should not be in runtime code + assert any(runtime_only in instr for instr in runtime_asm) assert all(ctor_only not in instr for instr in runtime_asm) - # runtime only label should not be in initcode asm - assert all(runtime_only not in instr for instr in initcode_asm) - def test_library_code_eliminator(make_input_bundle): library = """ diff --git a/tests/unit/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py index 613ee4d2b8..64cee3a75c 100644 --- a/tests/unit/compiler/test_bytecode_runtime.py +++ b/tests/unit/compiler/test_bytecode_runtime.py @@ -35,7 +35,7 @@ def foo5(): has_immutables = """ A_GOOD_PRIME: public(immutable(uint256)) -@external +@deploy def __init__(): A_GOOD_PRIME = 967 """ diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index ea2b2fe559..3620ef64b9 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -25,7 +25,7 @@ h: public(int256[1]) -@external +@deploy def __init__(): self.a = StructOne({a: "ok", b: [4,5,6]}) self.b = [7, 8] @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", + match=f"Invalid storage slot, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index bc08626b59..0ae93e9710 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -5,7 +5,7 @@ from . import nodes, validation from .natspec import parse_natspec -from .nodes import compare_nodes +from .nodes import compare_nodes, as_tuple from .utils import ast_to_dict from .parse import parse_to_ast, parse_to_ast_with_settings @@ -15,6 +15,5 @@ ): setattr(sys.modules[__name__], name, obj) - # required to avoid circular dependency from . import expansion # noqa: E402 diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 84429501e1..5ad465a1f1 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -182,13 +182,9 @@ loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body -// ternary operator -ternary: _expr "if" _expr "else" _expr - // Expressions _expr: operation | dict - | ternary get_item: (variable_access | list) "[" _expr "]" get_attr: variable_access "." NAME @@ -214,7 +210,15 @@ dict: "{" "}" | "{" (NAME ":" _expr) ("," (NAME ":" _expr))* [","] "}" // See https://docs.python.org/3/reference/expressions.html#operator-precedence // NOTE: The recursive cycle here helps enforce operator precedence // Precedence goes up the lower down you go -?operation: bool_or +?operation: assignment_expr + +// "walrus" operator +?assignment_expr: ternary + | NAME ":=" assignment_expr + +// ternary operator +?ternary: bool_or + | ternary "if" ternary "else" ternary _AND: "and" _OR: "or" diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 054145d33b..c4bce814a4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -83,8 +83,20 @@ def get_node( if ast_struct["value"] is not None: _raise_syntax_exc("`implements` cannot have a value assigned", ast_struct) ast_struct["ast_type"] = "ImplementsDecl" + + # Replace "uses:" `AnnAssign` nodes with `UsesDecl` + elif getattr(ast_struct["target"], "id", None) == "uses": + if ast_struct["value"] is not None: + _raise_syntax_exc("`uses` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "UsesDecl" + + # Replace "initializes:" `AnnAssign` nodes with `InitializesDecl` + elif getattr(ast_struct["target"], "id", None) == "initializes": + if ast_struct["value"] is not None: + _raise_syntax_exc("`initializes` cannot have a value assigned", ast_struct) + ast_struct["ast_type"] = "InitializesDecl" + # Replace state and local variable declarations `AnnAssign` with `VariableDecl` - # Parent node is required for context to determine whether replacement should happen. else: ast_struct["ast_type"] = "VariableDecl" @@ -730,6 +742,20 @@ def is_terminus(self): return self.value.is_terminus +class NamedExpr(Stmt): + __slots__ = ("target", "value") + + def validate(self): + # module[dep1 := dep2] + + # XXX: better error messages + if not isinstance(self.target, Name): + raise StructureException("not a Name") + + if not isinstance(self.value, Name): + raise StructureException("not a Name") + + class Log(Stmt): __slots__ = ("value",) @@ -756,6 +782,11 @@ class StructDef(TopLevel): class ExprNode(VyperNode): __slots__ = ("_expr_info",) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._expr_info = None + class Constant(ExprNode): # inherited class for all simple constant node types @@ -1383,17 +1414,13 @@ class ImplementsDecl(Stmt): """ An `implements` declaration. - Excludes `simple` and `value` attributes from Python `AnnAssign` node. - Attributes ---------- - target : Name - Name node for the `implements` keyword annotation : Name Name node for the interface to be implemented """ - __slots__ = ("target", "annotation") + __slots__ = ("annotation",) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1402,6 +1429,72 @@ def __init__(self, *args, **kwargs): raise StructureException("invalid implements", self.annotation) +def as_tuple(node: VyperNode): + """ + Convenience function for some AST nodes which allow either a Tuple + or single elements. Returns a python tuple of AST nodes. + """ + if isinstance(node, Tuple): + return node.elements + else: + return (node,) + + +class UsesDecl(Stmt): + """ + A `uses` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Tuple + The module(s) which this uses + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + items = as_tuple(self.annotation) + for item in items: + if not isinstance(item, (Name, Attribute)): + raise StructureException("invalid uses", item) + + +class InitializesDecl(Stmt): + """ + An `initializes` declaration. + + Attributes + ---------- + annotation : Name | Attribute | Subscript + An imported module which this module initializes + """ + + __slots__ = ("annotation",) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + module_ref = self.annotation + if isinstance(module_ref, Subscript): + dependencies = as_tuple(module_ref.slice) + module_ref = module_ref.value + + for item in dependencies: + if not isinstance(item, NamedExpr): + raise StructureException( + "invalid dependency (hint: should be [dependency := dependency]", item + ) + if not isinstance(item.target, (Name, Attribute)): + raise StructureException("invalid module", item.target) + if not isinstance(item.value, (Name, Attribute)): + raise StructureException("invalid module", item.target) + + if not isinstance(module_ref, (Name, Attribute)): + raise StructureException("invalid module", module_ref) + + class If(Stmt): __slots__ = ("test", "body", "orelse") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index f71ed67821..7f863a8db9 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -101,7 +101,8 @@ class StructDef(VyperNode): body: list = ... name: str = ... -class ExprNode(VyperNode): ... +class ExprNode(VyperNode): + _expr_info: Any = ... class Constant(VyperNode): value: Any = ... @@ -145,19 +146,19 @@ class Name(VyperNode): _type: str = ... class Expr(VyperNode): - value: VyperNode = ... + value: ExprNode = ... class UnaryOp(ExprNode): op: VyperNode = ... - operand: VyperNode = ... + operand: ExprNode = ... class USub(VyperNode): ... class Not(VyperNode): ... class BinOp(ExprNode): - left: VyperNode = ... op: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Add(VyperNode): ... class Sub(VyperNode): ... @@ -173,15 +174,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... - values: list[VyperNode] = ... + values: list[ExprNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... - left: VyperNode = ... - right: VyperNode = ... + left: ExprNode = ... + right: ExprNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -195,13 +196,13 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: VyperNode = ... + func: ExprNode = ... class keyword(VyperNode): ... class Attribute(VyperNode): attr: str = ... - value: VyperNode = ... + value: ExprNode = ... class Subscript(VyperNode): slice: VyperNode = ... @@ -224,8 +225,8 @@ class VariableDecl(VyperNode): class AugAssign(VyperNode): op: VyperNode = ... - target: VyperNode = ... - value: VyperNode = ... + target: ExprNode = ... + value: ExprNode = ... class Raise(VyperNode): ... class Assert(VyperNode): ... @@ -245,6 +246,12 @@ class ImplementsDecl(VyperNode): target: Name = ... annotation: Name = ... +class UsesDecl(VyperNode): + annotation: VyperNode = ... + +class InitializesDecl(VyperNode): + annotation: VyperNode = ... + class If(VyperNode): body: list = ... orelse: list = ... @@ -254,6 +261,10 @@ class IfExp(ExprNode): body: ExprNode = ... orelse: ExprNode = ... +class NamedExpr(ExprNode): + target: Name = ... + value: ExprNode = ... + class For(VyperNode): target: ExprNode iter: ExprNode diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index fc99af901b..a10a840da0 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -278,8 +278,8 @@ def visit_For(self, node): # specific error message than "invalid type annotation" raise SyntaxException( "missing type annotation\n\n" - "(hint: did you mean something like " - f"`for {node.target.id}: uint256 in ...`?)\n", + " (hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)", self._source_code, node.lineno, node.col_offset, diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d2aefb2fd4..6e6cf4c662 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,13 +85,16 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _equality_attrs = ("_id",) _is_terminus = False - # helper function to deal with TYPE_DEFINITIONs + @property + def modifiability(self): + return self._modifiability + + # helper function to deal with TYPE_Ts def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: - # TODO using "TYPE_DEFINITION" is a kludge in derived classes, - # refactor me. - if expected_type == "TYPE_DEFINITION": + if TYPE_T.any().compare_type(expected_type): # try to parse the type - call type_from_annotation # for its side effects (will throw if is not a type) type_from_annotation(arg) @@ -130,7 +133,7 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: get_exact_type_from_node(arg) def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: - return self._modifiability >= modifiability + return self._modifiability <= modifiability def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index 72b05f15e3..3fad225b48 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,7 +1,7 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context from vyper.codegen.stmt import parse_body -from vyper.semantics.analysis.local import FunctionNodeVisitor +from vyper.semantics.analysis.local import FunctionAnalyzer from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability from vyper.semantics.types.module import ModuleT @@ -25,9 +25,7 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) - # The FunctionNodeVisitor's constructor performs semantic checks - # annotate the AST as side effects. - analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionAnalyzer(ast_code, ast_code.body[0], namespace) analyzer.analyze() new_context = Context( diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 50ab4dacd8..7575f4d77e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -113,10 +113,7 @@ class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. - - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TypeTypeDefinition` object in `infer_arg_types`. - _inputs = [("typename", "TYPE_DEFINITION")] + _inputs = [("typename", TYPE_T.any())] def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef @@ -711,7 +708,7 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunctionT): _id = "method_id" _inputs = [("value", StringT.any())] - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BytesT(4))} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BytesT(4))} def _try_fold(self, node): validate_call_args(node, 1, ["output_type"]) @@ -848,10 +845,7 @@ def _storage_element_getter(index): class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] - # "TYPE_DEFINITION" is a placeholder value for a type definition string, and - # will be replaced by a `TYPE_T` object in `infer_kwarg_types` - # (note that it is ignored in _validate_arg_types) - _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} + _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BYTES32_T)} def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1976,18 +1970,22 @@ def build_IR(self, expr, args, kwargs, context): class UnsafeAdd(_UnsafeMath): + _id = "unsafe_add" op = "add" class UnsafeSub(_UnsafeMath): + _id = "unsafe_sub" op = "sub" class UnsafeMul(_UnsafeMath): + _id = "unsafe_mul" op = "mul" class UnsafeDiv(_UnsafeMath): + _id = "unsafe_div" op = "div" @@ -2474,7 +2472,7 @@ def build_IR(self, expr, args, kwargs, context): class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" - _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] + _inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} def fetch_call_return(self, node): diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 4f644841f4..af01c5b504 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -44,7 +44,7 @@ def __repr__(self): return f"VariableRecord({ret})" -# Contains arguments, variables, etc +# compilation context for a function class Context: def __init__( self, @@ -59,19 +59,12 @@ def __init__( # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} - # Global variables, in the form (name, storage location, type) - self.globals = module_ctx.variables - # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} # Is the function constant? self.constancy = constancy - # Whether body is currently in an assert statement - # XXX: dead, never set to True - self.in_assertion = False - # Whether we are currently parsing a range expression self.in_range_expr = False @@ -87,6 +80,10 @@ def __init__( # Not intended to be accessed directly self.memory_allocator = memory_allocator + # save the starting memory location so we can find out (later) + # how much memory this function uses. + self.starting_memory = memory_allocator.next_mem + # Incremented values, used for internal IDs self._internal_var_iter = 0 self._scope_id_iter = 0 @@ -95,7 +92,7 @@ def __init__( self.is_ctor_context = is_ctor_context def is_constant(self): - return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr + return self.constancy is Constancy.Constant or self.in_range_expr def check_is_not_constant(self, err, expr): if self.is_constant(): @@ -250,9 +247,7 @@ def lookup_var(self, varname): # Pretty print constancy for error messages def pp_constancy(self): - if self.in_assertion: - return "an assertion" - elif self.in_range_expr: + if self.in_range_expr: return "a range expression" elif self.constancy == Constancy.Constant: return "a constant function" diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c3215f8c16..1a090ac316 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -3,9 +3,18 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import ( + CALLDATA, + DATA, + IMMUTABLES, + MEMORY, + STORAGE, + TRANSIENT, + AddrSpace, +) from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch +from vyper.semantics.data_locations import DataLocation from vyper.semantics.types import ( AddressT, BoolT, @@ -100,6 +109,36 @@ def _codecopy_gas_bound(num_bytes): return GAS_COPY_WORD * ceil32(num_bytes) // 32 +def data_location_to_address_space(s: DataLocation, is_ctor_ctx: bool) -> AddrSpace: + if s == DataLocation.MEMORY: + return MEMORY + if s == DataLocation.STORAGE: + return STORAGE + if s == DataLocation.TRANSIENT: + return TRANSIENT + if s == DataLocation.CODE: + if is_ctor_ctx: + return IMMUTABLES + return DATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + +def address_space_to_data_location(s: AddrSpace) -> DataLocation: + if s == MEMORY: + return DataLocation.MEMORY + if s == STORAGE: + return DataLocation.STORAGE + if s == TRANSIENT: + return DataLocation.TRANSIENT + if s in (IMMUTABLES, DATA): + return DataLocation.CODE + if s == CALLDATA: + return DataLocation.CALLDATA + + raise CompilerPanic("unreachable!") # pragma: nocover + + # Copy byte array word-for-word (including layout) # TODO make this a private function def make_byte_array_copier(dst, src): @@ -482,14 +521,10 @@ def _get_element_ptr_tuplelike(parent, key): return _getelemptr_abi_helper(parent, member_t, ofst) - if parent.location.word_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].storage_size_in_words - elif parent.location.byte_addressable: - for i in range(index): - ofst += typ.member_types[attrs[i]].memory_bytes_required - else: - raise CompilerPanic(f"bad location {parent.location}") # pragma: notest + data_location = address_space_to_data_location(parent.location) + for i in range(index): + t = typ.member_types[attrs[i]] + ofst += t.get_size_in(data_location) return IRnode.from_list( add_ofst(parent, ofst), @@ -550,12 +585,8 @@ def _get_element_ptr_array(parent, key, array_bounds_check): return _getelemptr_abi_helper(parent, subtype, ofst) - if parent.location.word_addressable: - element_size = subtype.storage_size_in_words - elif parent.location.byte_addressable: - element_size = subtype.memory_bytes_required - else: - raise CompilerPanic("unreachable") # pragma: notest + data_location = address_space_to_data_location(parent.location) + element_size = subtype.get_size_in(data_location) ofst = _mul(ix, element_size) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index f4c7948382..335cfefb87 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -6,6 +6,7 @@ from vyper.codegen import external_call, self_call from vyper.codegen.core import ( clamp, + data_location_to_address_space, ensure_in_memory, get_dyn_array_count, get_element_ptr, @@ -23,7 +24,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import MEMORY from vyper.evm.opcodes import version_check from vyper.exceptions import ( CodegenPanic, @@ -185,26 +186,24 @@ def parse_Name(self): ret._referenced_variables = {var} return ret - # TODO: use self.expr._expr_info - elif self.expr.id in self.context.globals: - varinfo = self.context.globals[self.expr.id] - + elif (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: return Expr.parse_value_expr(varinfo.decl_node.value, self.context) assert varinfo.is_immutable, "not an immutable!" - ofst = varinfo.position.offset + mutable = self.context.is_ctor_context - if self.context.is_ctor_context: - mutable = True - location = IMMUTABLES - else: - mutable = False - location = DATA + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( - ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation=self.expr.id, + mutable=mutable, ) ret._referenced_variables = {varinfo} return ret @@ -265,9 +264,13 @@ def parse_Attribute(self): return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) # self.x: global attribute - elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] - location = TRANSIENT if varinfo.is_transient else STORAGE + elif (varinfo := self.expr._expr_info.var_info) is not None: + if varinfo.is_constant: + return Expr.parse_value_expr(varinfo.decl_node.value, self.context) + + location = data_location_to_address_space( + varinfo.location, self.context.is_ctor_context + ) ret = IRnode.from_list( varinfo.position.position, @@ -700,7 +703,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=True) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 94617bef35..254b4df72c 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1,4 @@ -from .common import FuncIR, generate_ir_for_function # noqa +from .external_function import generate_ir_for_external_function +from .internal_function import generate_ir_for_internal_function + +__all__ = [generate_ir_for_internal_function, generate_ir_for_external_function] # type: ignore diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 5877ff3d13..d017ba7b81 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -2,17 +2,14 @@ from functools import cached_property from typing import Optional -import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function -from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator -from vyper.exceptions import CompilerPanic +from vyper.evm.opcodes import version_check from vyper.semantics.types import VyperType -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, StateMutability from vyper.semantics.types.module import ModuleT -from vyper.utils import MemoryPositions, calc_mem_gas +from vyper.utils import MemoryPositions @dataclass @@ -53,9 +50,11 @@ def ir_identifier(self) -> str: return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: + # XXX: when can this happen? if self.frame_info is not None: - raise CompilerPanic(f"frame_info already set for {self.func_t}!") - self.frame_info = frame_info + assert frame_info == self.frame_info + else: + self.frame_info = frame_info @property # common entry point for external function with kwargs @@ -64,13 +63,15 @@ def external_function_base_entry_label(self) -> str: return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: - assert self.func_t.is_internal, "uh oh, should be internal" - suffix = "_deploy" if is_ctor_context else "_runtime" - return self.ir_identifier + suffix + f = self.func_t + assert f.is_internal or f.is_constructor, "uh oh, should be internal" + if f.is_constructor: + # sanity check - imported init functions only callable from main init + assert is_ctor_context -class FuncIR: - pass + suffix = "_deploy" if is_ctor_context else "_runtime" + return self.ir_identifier + suffix @dataclass @@ -80,7 +81,7 @@ class EntryPointInfo: ir_node: IRnode # the ir for this entry point def __post_init__(self): - # ABI v2 property guaranteed by the spec. + # sanity check ABI v2 properties guaranteed by the spec. # https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501 # > Note that for any X, len(enc(X)) is a multiple of 32. assert self.min_calldatasize >= 4 @@ -88,34 +89,28 @@ def __post_init__(self): @dataclass -class ExternalFuncIR(FuncIR): +class ExternalFuncIR: entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points common_ir: IRnode # the "common" code for the function @dataclass -class InternalFuncIR(FuncIR): +class InternalFuncIR: func_ir: IRnode # the code for the function -# TODO: should split this into external and internal ir generation? -def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False -) -> FuncIR: - """ - Parse a function and produce IR code for the function, includes: - - Signature method if statement - - Argument handling - - Clamping and copying of arguments - - Function body - """ - func_t = code._metadata["func_type"] - - # generate _FuncIRInfo +def init_ir_info(func_t: ContractFunctionT): + # initialize IRInfo on the function func_t._ir_info = _FuncIRInfo(func_t) - callees = func_t.called_functions +def initialize_context( + func_t: ContractFunctionT, module_ctx: ModuleT, is_ctor_context: bool = False +): + init_ir_info(func_t) + + # calculate starting frame + callees = func_t.called_functions # we start our function frame from the largest callee frame max_callee_frame_size = 0 for c_func_t in callees: @@ -126,7 +121,7 @@ def generate_ir_for_function( memory_allocator = MemoryAllocator(allocate_start) - context = Context( + return Context( vars_=None, module_ctx=module_ctx, memory_allocator=memory_allocator, @@ -135,38 +130,41 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal: - ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) - func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore - else: - kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) - entry_points = { - k: EntryPointInfo(func_t, mincalldatasize, ir_node) - for k, (mincalldatasize, ir_node) in kwarg_handlers.items() - } - ret = ExternalFuncIR(entry_points, common) - # note: this ignores the cost of traversing selector table - func_t._ir_info.gas_estimate = ret.common_ir.gas +def tag_frame_info(func_t, context): frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY + frame_start = context.starting_memory - frame_info = FrameInfo(allocate_start, frame_size, context.vars) + frame_info = FrameInfo(frame_start, frame_size, context.vars) + func_t._ir_info.set_frame_info(frame_info) - # XXX: when can this happen? - if func_t._ir_info.frame_info is None: - func_t._ir_info.set_frame_info(frame_info) - else: - assert frame_info == func_t._ir_info.frame_info - - if not func_t.is_internal: - # adjust gas estimate to include cost of mem expansion - # frame_size of external function includes all private functions called - # (note: internal functions do not need to adjust gas estimate since - mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore - ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + return frame_info + + +def get_nonreentrant_lock(func_t): + if not func_t.nonreentrant: + return ["pass"], ["pass"] + + nkey = func_t.reentrancy_key_position.position + + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + + if version_check(begin="berlin"): + # any nonzero values would work here (see pricing as of net gas + # metering); these values are chosen so that downgrading to the + # 0,1 scheme (if it is somehow necessary) is safe. + final_value, temp_value = 3, 2 else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + final_value, temp_value = 0, 1 + + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - return ret + if func_t.mutability == StateMutability.VIEW: + return [check_notset], [["seq"]] + + else: + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] + return [pre], [post] diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 65276469e7..b380eab2ce 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -2,12 +2,19 @@ from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp from vyper.codegen.expr import Expr -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + EntryPointInfo, + ExternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.stmt import parse_body from vyper.evm.address_space import CALLDATA, DATA, MEMORY from vyper.semantics.types import TupleT from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import calc_mem_gas # register function args with the local calling context. @@ -51,7 +58,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list def _generate_kwarg_handlers( func_t: ContractFunctionT, context: Context -) -> dict[str, tuple[int, IRnode]]: +) -> dict[str, EntryPointInfo]: # generate kwarg handlers. # since they might come in thru calldata or be default, # allocate them in memory and then fill it in based on calldata or default, @@ -126,34 +133,54 @@ def handler_for(calldata_kwargs, default_kwargs): default_kwargs = keyword_args[i:] sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) return ret -def generate_ir_for_external_function(code, func_t, context): +def _adjust_gas_estimate(func_t, common_ir): + # adjust gas estimate to include cost of mem expansion + # frame_size of external function includes all private functions called + # (note: internal functions do not need to adjust gas estimate since + frame_info = func_t._ir_info.frame_info + + mem_expansion_cost = calc_mem_gas(frame_info.mem_used) + common_ir.add_gas_estimate += mem_expansion_cost + func_t._ir_info.gas_estimate = common_ir.gas + + # pass metadata through for venom pipeline: + common_ir.passthrough_metadata["func_t"] = func_t + common_ir.passthrough_metadata["frame_info"] = frame_info + + +def generate_ir_for_external_function(code, compilation_target): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, - # func_t: ContractFunctionT, - # context: Context, + # compilation_target: ModuleT, # ) -> IRnode: """ Return the IR for an external function. Returns IR for the body of the function, handle kwargs and exit the function. Also returns metadata required for `module.py` to construct the selector table. """ + func_t = code._metadata["func_type"] + assert func_t.is_external or func_t.is_constructor # sanity check + + context = initialize_context(func_t, compilation_target, func_t.is_constructor) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) # generate handlers for base args and register the variable records handle_base_args = _register_function_args(func_t, context) # generate handlers for kwargs and register the variable records - kwarg_handlers = _generate_kwarg_handlers(func_t, context) + entry_points = _generate_kwarg_handlers(func_t, context) body = ["seq"] # once optional args have been handled, @@ -185,4 +212,8 @@ def generate_ir_for_external_function(code, func_t, context): # besides any kwarg handling func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return kwarg_handlers, func_common_ir + tag_frame_info(func_t, context) + + _adjust_gas_estimate(func_t, func_common_ir) + + return ExternalFuncIR(entry_points, func_common_ir) diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..0cf9850b70 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -1,23 +1,25 @@ from vyper import ast as vy_ast -from vyper.codegen.context import Context -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + InternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import IRnode from vyper.codegen.stmt import parse_body -from vyper.semantics.types.function import ContractFunctionT def generate_ir_for_internal_function( - code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context -) -> IRnode: + code: vy_ast.FunctionDef, module_ctx, is_ctor_context: bool +) -> InternalFuncIR: """ Parse a internal function (FuncDef), and produce full function body. :param func_t: the ContractFunctionT :param code: ast of function - :param context: current calling context + :param compilation_target: current calling context :return: function body in IR """ - # The calling convention is: # Caller fills in argument buffer # Caller provides return address, return buffer on the stack @@ -37,13 +39,19 @@ def generate_ir_for_internal_function( # situation like the following is easy to bork: # x: T[2] = [self.generate_T(), self.generate_T()] - # Get nonreentrant lock + func_t = code._metadata["func_type"] + + # sanity check + assert func_t.is_internal or func_t.is_constructor + + context = initialize_context(func_t, module_ctx, is_ctor_context) for arg in func_t.arguments: # allocate a variable for every arg, setting mutability # to True to allow internal function arguments to be mutable context.new_variable(arg.name, arg.typ, is_mutable=True) + # Get nonreentrant lock nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context) @@ -69,5 +77,13 @@ def generate_ir_for_internal_function( ] ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + + # tag gas estimate and frame info + func_t._ir_info.gas_estimate = ir_node.gas + frame_info = tag_frame_info(func_t, context) + + # pass metadata through for venom pipeline: + ir_node.passthrough_metadata["frame_info"] = frame_info ir_node.passthrough_metadata["func_t"] = func_t - return ir_node + + return InternalFuncIR(ir_node) diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py deleted file mode 100644 index f524ec6e88..0000000000 --- a/vyper/codegen/function_definitions/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -from vyper.evm.opcodes import version_check -from vyper.semantics.types.function import StateMutability - - -def get_nonreentrant_lock(func_type): - if not func_type.nonreentrant: - return ["pass"], ["pass"] - - nkey = func_type.reentrancy_key_position.position - - LOAD, STORE = "sload", "sstore" - if version_check(begin="cancun"): - LOAD, STORE = "tload", "tstore" - - if version_check(begin="berlin"): - # any nonzero values would work here (see pricing as of net gas - # metering); these values are chosen so that downgrading to the - # 0,1 scheme (if it is somehow necessary) is safe. - final_value, temp_value = 3, 2 - else: - final_value, temp_value = 0, 1 - - check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - - if func_type.mutability == StateMutability.VIEW: - return [check_notset], [["seq"]] - - else: - pre = ["seq", check_notset, [STORE, nkey, temp_value]] - post = [STORE, nkey, final_value] - return [pre], [post] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 98395a6a0c..fef4f23949 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,7 +4,10 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr -from vyper.codegen.function_definitions import generate_ir_for_function +from vyper.codegen.function_definitions import ( + generate_ir_for_external_function, + generate_ir_for_internal_function, +) from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic @@ -89,7 +92,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): callvalue_check = ["assert", ["iszero", "callvalue"]] ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) - func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + func_ir = generate_ir_for_external_function(func_ast, *args, **kwargs) assert len(func_ir.entry_points) == 1 # add a goto to make the function entry look like other functions @@ -101,7 +104,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): def _ir_for_internal_function(func_ast, *args, **kwargs): - return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + return generate_ir_for_internal_function(func_ast, *args, **kwargs).func_ir def _generate_external_entry_points(external_functions, module_ctx): @@ -109,7 +112,7 @@ def _generate_external_entry_points(external_functions, module_ctx): sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_external_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -424,12 +427,13 @@ def _selector_section_linear(external_functions, module_ctx): # take a ModuleT, and generate the runtime and deploy IR def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: + # XXX: rename `module_ctx` to `compilation_target` # order functions so that each function comes after all of its callees function_defs = _topsort(module_ctx.function_defs) reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] - init_function = next((f for f in function_defs if _is_constructor(f)), None) + init_function = next((f for f in module_ctx.function_defs if _is_constructor(f)), None) internal_functions = [f for f in runtime_functions if _is_internal(f)] @@ -475,24 +479,21 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: deploy_code: List[Any] = ["seq"] immutables_len = module_ctx.immutable_section_bytes - if init_function: + if init_function is not None: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - func_t = f._metadata["func_type"] - if func_t not in init_func_t.reachable_internal_functions: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) + + reachable_from_ctor = init_func_t.reachable_internal_functions + for func_t in reachable_from_ctor: + fn_ast = func_t.ast_def + func_ir = _ir_for_internal_function(fn_ast, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 7d4938f287..e6baea75f7 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -144,7 +144,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=False) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.stmt, self.context) else: return external_call.ir_for_external_call(self.stmt, self.context) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 5b7decec7b..f7eccdf214 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -152,23 +152,18 @@ def vyper_module(self): return self._generate_ast @cached_property - def _annotated_module(self): - return generate_annotated_ast( - self.vyper_module, self.input_bundle, self.storage_layout_override - ) - - @property def annotated_vyper_module(self) -> vy_ast.Module: - module, storage_layout = self._annotated_module - return module + return generate_annotated_ast(self.vyper_module, self.input_bundle) - @property + @cached_property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._annotated_module - return storage_layout + module_ast = self.annotated_vyper_module + return set_data_positions(module_ast, self.storage_layout_override) @property def global_ctx(self) -> ModuleT: + # ensure storage layout is computed + _ = self.storage_layout return self.annotated_vyper_module._metadata["type"] @cached_property @@ -243,11 +238,7 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast( - vyper_module: vy_ast.Module, - input_bundle: InputBundle, - storage_layout_overrides: StorageLayout = None, -) -> tuple[vy_ast.Module, StorageLayout]: +def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: """ Validates and annotates the Vyper AST. @@ -268,9 +259,7 @@ def generate_annotated_ast( # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) - symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - - return vyper_module, symbol_tables + return vyper_module def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index 85a75c3c23..fcbd4bcf63 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -28,14 +28,6 @@ class AddrSpace: # TODO maybe make positional instead of defaulting to None store_op: Optional[str] = None - @property - def word_addressable(self) -> bool: - return self.word_scale == 1 - - @property - def byte_addressable(self) -> bool: - return self.word_scale == 32 - # alternative: # class Memory(AddrSpace): diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 04667aaa59..53ad6f7bb8 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -31,7 +31,7 @@ class _BaseVyperException(Exception): order to display source annotations in the error string. """ - def __init__(self, message="Error Message not found.", *items): + def __init__(self, message="Error Message not found.", *items, hint=None): """ Exception initializer. @@ -47,7 +47,9 @@ def __init__(self, message="Error Message not found.", *items): A single tuple of (lineno, col_offset) is also understood to support the old API, but new exceptions should not use this approach. """ - self.message = message + self._message = message + self._hint = hint + self.lineno = None self.col_offset = None self.annotations = None @@ -77,6 +79,13 @@ def with_annotation(self, *annotations): exc.annotations = annotations return exc + @property + def message(self): + msg = self._message + if self._hint: + msg += f"\n\n (hint: {self._hint})" + return msg + def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code @@ -131,7 +140,7 @@ def __str__(self): annotation_list.append(node_msg) annotation_msg = "\n".join(annotation_list) - return f"{self.message}\n{annotation_msg}" + return f"{self.message}\n\n{annotation_msg}" class VyperException(_BaseVyperException): @@ -252,6 +261,14 @@ class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" +class InitializerException(VyperException): + """An issue with initializing/constructing a module""" + + +class BorrowException(VyperException): + """An issue with borrowing/using a module""" + + class StateAccessViolation(VyperException): """Violating the mutability of a function definition.""" @@ -369,7 +386,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): except _BaseVyperException as e: if not e.annotations and not e.lineno: tb = e.__traceback__ - raise e.with_annotation(node).with_traceback(tb) + raise e.with_annotation(node).with_traceback(tb) from None raise e from None except Exception as e: tb = e.__traceback__ diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7b52a68e92..e23b2d2aa4 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,4 +1,4 @@ from .. import types # break a dependency cycle. -from .module import validate_semantics +from .global_ import validate_semantics __all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index bb6d9ad9f7..2086e5f9da 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,84 +1,29 @@ import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Optional, Union from vyper import ast as vy_ast from vyper.compiler.input_bundle import InputBundle -from vyper.exceptions import ( - CompilerPanic, - ImmutableViolation, - StateAccessViolation, - VyperInternalException, -) +from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.utils import OrderedSet, StringEnum if TYPE_CHECKING: from vyper.semantics.types.module import InterfaceT, ModuleT -class _StringEnum(enum.Enum): - @staticmethod - def auto(): - return enum.auto() +class FunctionVisibility(StringEnum): + EXTERNAL = enum.auto() + INTERNAL = enum.auto() + DEPLOY = enum.auto() - # Must be first, or else won't work, specifies what .value is - def _generate_next_value_(name, start, count, last_values): - return name.lower() - # Override ValueError with our own internal exception - @classmethod - def _missing_(cls, value): - raise VyperInternalException(f"{value} is not a valid {cls.__name__}") - - @classmethod - def is_valid_value(cls, value: str) -> bool: - return value in set(o.value for o in cls) - - @classmethod - def options(cls) -> List["_StringEnum"]: - return list(cls) - - @classmethod - def values(cls) -> List[str]: - return [v.value for v in cls.options()] - - # Comparison operations - def __eq__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - return self is other - - # Python normally does __ne__(other) ==> not self.__eq__(other) - - def __lt__(self, other: object) -> bool: - if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") - options = self.__class__.options() - return options.index(self) < options.index(other) # type: ignore - - def __le__(self, other: object) -> bool: - return self.__eq__(other) or self.__lt__(other) - - def __gt__(self, other: object) -> bool: - return not self.__le__(other) - - def __ge__(self, other: object) -> bool: - return self.__eq__(other) or self.__gt__(other) - - -class FunctionVisibility(_StringEnum): - # TODO: these can just be enum.auto() right? - EXTERNAL = _StringEnum.auto() - INTERNAL = _StringEnum.auto() - - -class StateMutability(_StringEnum): - # TODO: these can just be enum.auto() right? - PURE = _StringEnum.auto() - VIEW = _StringEnum.auto() - NONPAYABLE = _StringEnum.auto() - PAYABLE = _StringEnum.auto() +class StateMutability(StringEnum): + PURE = enum.auto() + VIEW = enum.auto() + NONPAYABLE = enum.auto() + PAYABLE = enum.auto() @classmethod def from_abi(cls, abi_dict: Dict) -> "StateMutability": @@ -103,71 +48,40 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # and variables) and Constancy (in codegen). context.Constancy can/should # probably be refactored away though as those kinds of checks should be done # during analysis. -class Modifiability(enum.IntEnum): - # is writeable/can result in arbitrary state or memory changes - MODIFIABLE = enum.auto() - - # could potentially add more fine-grained here as needed, like - # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. +class Modifiability(StringEnum): + # compile-time / always constant + CONSTANT = enum.auto() # things that are constant within the current message call, including # block.*, msg.*, tx.* and immutables RUNTIME_CONSTANT = enum.auto() - # compile-time / always constant - CONSTANT = enum.auto() - - -class DataPosition: - _location: DataLocation - - -class CalldataOffset(DataPosition): - __slots__ = ("dynamic_offset", "static_offset") - _location = DataLocation.CALLDATA - - def __init__(self, static_offset, dynamic_offset=None): - self.static_offset = static_offset - self.dynamic_offset = dynamic_offset - - def __repr__(self): - if self.dynamic_offset is not None: - return f"" - else: - return f"" - - -class MemoryOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.MEMORY - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" - - -class StorageSlot(DataPosition): - __slots__ = ("position",) - _location = DataLocation.STORAGE + # could potentially add more fine-grained here as needed, like + # CONSTANT_AFTER_DEPLOY, TX_CONSTANT, BLOCK_CONSTANT, etc. - def __init__(self, position): - self.position = position + # is writeable/can result in arbitrary state or memory changes + MODIFIABLE = enum.auto() - def __repr__(self): - return f"" + @classmethod + def from_state_mutability(cls, mutability: StateMutability): + if mutability == StateMutability.PURE: + return cls.CONSTANT + if mutability == StateMutability.VIEW: + return cls.RUNTIME_CONSTANT + # sanity check in case more StateMutability levels are added in the future + assert mutability in (StateMutability.PAYABLE, StateMutability.NONPAYABLE) + return cls.MODIFIABLE -class CodeOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.CODE +@dataclass +class VarOffset: + position: int - def __init__(self, offset): - self.offset = offset - def __repr__(self): - return f"" +class ModuleOwnership(StringEnum): + NO_OWNERSHIP = enum.auto() # readable + USES = enum.auto() # writeable + INITIALIZES = enum.auto() # initializes # base class for things that are the "result" of analysis @@ -178,6 +92,9 @@ class AnalysisResult: @dataclass class ModuleInfo(AnalysisResult): module_t: "ModuleT" + alias: str + ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP + ownership_decl: Optional[vy_ast.VyperNode] = None @property def module_node(self): @@ -188,6 +105,16 @@ def module_node(self): def typ(self): return self.module_t + def set_ownership(self, module_ownership: ModuleOwnership, node: Optional[vy_ast.VyperNode]): + if self.ownership != ModuleOwnership.NO_OWNERSHIP: + raise StructureException( + f"ownership already set to `{self.ownership}`", node, self.ownership_decl + ) + self.ownership = module_ownership + + def __hash__(self): + return hash(id(self.module_t)) + @dataclass class ImportInfo(AnalysisResult): @@ -199,6 +126,21 @@ class ImportInfo(AnalysisResult): node: vy_ast.VyperNode +# analysis result of InitializesDecl +@dataclass +class InitializesInfo(AnalysisResult): + module_info: ModuleInfo + dependencies: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + +# analysis result of UsesDecl +@dataclass +class UsesInfo(AnalysisResult): + used_modules: list[ModuleInfo] + node: Optional[vy_ast.VyperNode] = None + + @dataclass class VarInfo: """ @@ -221,22 +163,21 @@ def __hash__(self): return hash(id(self)) def __post_init__(self): + self.position = None self._modification_count = 0 - def set_position(self, position: DataPosition) -> None: - if hasattr(self, "position"): + def set_position(self, position: VarOffset) -> None: + if self.position is not None: raise CompilerPanic("Position was already assigned") - if self.location != position._location: - if self.location == DataLocation.UNSET: - self.location = position._location - elif self.is_transient and position._location == DataLocation.STORAGE: - # CMC 2023-12-31 - use same allocator for storage and transient - # for now, this should be refactored soon. - pass - else: - raise CompilerPanic("Incompatible locations") + assert isinstance(position, VarOffset) # sanity check self.position = position + def is_module_variable(self): + return self.location not in (DataLocation.UNSET, DataLocation.MEMORY) + + def get_size(self) -> int: + return self.typ.get_size_in(self.location) + @property def is_transient(self): return self.location == DataLocation.TRANSIENT @@ -260,9 +201,13 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None + module_info: Optional[ModuleInfo] = None location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE + # the chain of attribute parents for this expr + attribute_chain: list["ExprInfo"] = field(default_factory=list) + def __post_init__(self): should_match = ("typ", "location", "modifiability") if self.var_info is not None: @@ -270,65 +215,48 @@ def __post_init__(self): if getattr(self.var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") + self._writes: OrderedSet[VarInfo] = OrderedSet() + self._reads: OrderedSet[VarInfo] = OrderedSet() + + # find exprinfo in the attribute chain which has a varinfo + # e.x. `x` will return varinfo for `x` + # `module.foo` will return varinfo for `module.foo` + # `self.my_struct.x.y` will return varinfo for `self.my_struct` + def get_root_varinfo(self) -> Optional[VarInfo]: + for expr_info in self.attribute_chain + [self]: + if expr_info.var_info is not None: + return expr_info.var_info + return None + @classmethod - def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": + def from_varinfo(cls, var_info: VarInfo, attribute_chain=None) -> "ExprInfo": return cls( var_info.typ, var_info=var_info, location=var_info.location, modifiability=var_info.modifiability, + attribute_chain=attribute_chain or [], ) @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": - return cls(module_info.module_t) + def from_moduleinfo(cls, module_info: ModuleInfo, attribute_chain=None) -> "ExprInfo": + modifiability = Modifiability.RUNTIME_CONSTANT + if module_info.ownership >= ModuleOwnership.USES: + modifiability = Modifiability.MODIFIABLE - def copy_with_type(self, typ: VyperType) -> "ExprInfo": + return cls( + module_info.module_t, + module_info=module_info, + modifiability=modifiability, + attribute_chain=attribute_chain or [], + ) + + def copy_with_type(self, typ: VyperType, attribute_chain=None) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ to_copy = ("location", "modifiability") fields = {k: getattr(self, k) for k in to_copy} + if attribute_chain is not None: + fields["attribute_chain"] = attribute_chain return self.__class__(typ=typ, **fields) - - def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutability) -> None: - """ - Validate an attempt to modify this value. - - Raises if the value is a constant or involves an invalid operation. - - Arguments - --------- - node : Assign | AugAssign | Call - Vyper ast node of the modifying action. - mutability: StateMutability - The mutability of the context (e.g., pure function) we are currently in - """ - if mutability <= StateMutability.VIEW and self.location == DataLocation.STORAGE: - raise StateAccessViolation( - f"Cannot modify storage in a {mutability.value} function", node - ) - - if self.location == DataLocation.CALLDATA: - raise ImmutableViolation("Cannot write to calldata", node) - - if self.modifiability == Modifiability.RUNTIME_CONSTANT: - if self.location == DataLocation.CODE: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": - raise ImmutableViolation("Immutable value cannot be written to", node) - - # special handling for immutable variables in the ctor - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore - raise ImmutableViolation( - "Immutable value cannot be modified after assignment", node - ) - self.var_info._modification_count += 1 # type: ignore - else: - raise ImmutableViolation("Environment variable cannot be written to", node) - - if self.modifiability == Modifiability.CONSTANT: - raise ImmutableViolation("Constant value cannot be written to", node) - - if isinstance(node, vy_ast.AugAssign): - self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/constant_folding.py b/vyper/semantics/analysis/constant_folding.py index bfcc473d09..3522383167 100644 --- a/vyper/semantics/analysis/constant_folding.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -113,7 +113,7 @@ def visit_Attribute(self, node) -> vy_ast.ExprNode: varinfo = module_t.get_member(node.attr, node) return varinfo.decl_node.value.get_folded_value() - except (VyperException, AttributeError): + except (VyperException, AttributeError, KeyError): raise UnfoldableNode("not a module") def visit_UnaryOp(self, node): diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 88679a4b09..604bc6b594 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,11 +1,12 @@ -# TODO this module doesn't really belong in "validation" -from typing import Dict, List +from collections import defaultdict +from typing import Generic, TypeVar from vyper import ast as vy_ast -from vyper.exceptions import StorageLayoutException -from vyper.semantics.analysis.base import CodeOffset, StorageSlot +from vyper.evm.opcodes import version_check +from vyper.exceptions import CompilerPanic, StorageLayoutException +from vyper.semantics.analysis.base import VarOffset +from vyper.semantics.data_locations import DataLocation from vyper.typing import StorageLayout -from vyper.utils import ceil32 def set_data_positions( @@ -20,24 +21,76 @@ def set_data_positions( vyper_module : vy_ast.Module Top-level Vyper AST node that has already been annotated with type data. """ - code_offsets = set_code_offsets(vyper_module) - storage_slots = ( - set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - if storage_layout_overrides is not None - else set_storage_slots(vyper_module) - ) + if storage_layout_overrides is not None: + # extract code layout with no overrides + code_offsets = _allocate_layout_r(vyper_module, immutables_only=True)["code_layout"] + storage_slots = set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) + return {"storage_layout": storage_slots, "code_layout": code_offsets} - return {"storage_layout": storage_slots, "code_layout": code_offsets} + ret = _allocate_layout_r(vyper_module) + assert isinstance(ret, defaultdict) + return dict(ret) # convert back to dict -class StorageAllocator: +_T = TypeVar("_T") +_K = TypeVar("_K") + + +class InsertableOnceDict(Generic[_T, _K], dict[_T, _K]): + def __setitem__(self, k, v): + if k in self: + raise ValueError(f"{k} is already in dict!") + super().__setitem__(k, v) + + +class SimpleAllocator: + def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): + # Allocate storage slots from 0 + # note storage is word-addressable, not byte-addressable + self._slot = starting_slot + self._max_slot = max_slot + + def allocate_slot(self, n, var_name, node=None): + ret = self._slot + if self._slot + n >= self._max_slot: + raise StorageLayoutException( + f"Invalid storage slot, tried to allocate" + f" slots {self._slot} through {self._slot + n}", + node, + ) + self._slot += n + return ret + + +class Allocators: + storage_allocator: SimpleAllocator + transient_storage_allocator: SimpleAllocator + immutables_allocator: SimpleAllocator + + def __init__(self): + self.storage_allocator = SimpleAllocator(max_slot=2**256) + self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) + self.immutables_allocator = SimpleAllocator(max_slot=0x6000) + + def get_allocator(self, location: DataLocation): + if location == DataLocation.STORAGE: + return self.storage_allocator + if location == DataLocation.TRANSIENT: + return self.transient_storage_allocator + if location == DataLocation.CODE: + return self.immutables_allocator + + raise CompilerPanic("unreachable") # pragma: nocover + + +class OverridingStorageAllocator: """ Keep track of which storage slots have been used. If there is a collision of storage slots, this will raise an error and fail to compile """ def __init__(self): - self.occupied_slots: Dict[int, str] = {} + self.occupied_slots: dict[int, str] = {} def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> None: """ @@ -48,7 +101,7 @@ def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> No list_to_check = [x + first_slot for x in range(n_slots)] self._reserve_slots(list_to_check, var_name) - def _reserve_slots(self, slots: List[int], var_name: str) -> None: + def _reserve_slots(self, slots: list[int], var_name: str) -> None: for slot in slots: self._reserve_slot(slot, var_name) @@ -70,12 +123,13 @@ def set_storage_slots_with_overrides( vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout ) -> StorageLayout: """ - Parse module-level Vyper AST to calculate the layout of storage variables. + Set storage layout given a layout override file. Returns the layout as a dict of variable name -> variable info + (Doesn't handle modules, or transient storage) """ - ret: Dict[str, Dict] = {} - reserved_slots = StorageAllocator() + ret: InsertableOnceDict[str, dict] = InsertableOnceDict() + reserved_slots = OverridingStorageAllocator() # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): @@ -90,7 +144,7 @@ def set_storage_slots_with_overrides( # re-entrant key was already identified if variable_name in ret: _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -100,7 +154,7 @@ def set_storage_slots_with_overrides( # from using the same slot reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - type_.set_reentrancy_key_position(StorageSlot(reentrant_slot)) + type_.set_reentrancy_key_position(VarOffset(reentrant_slot)) ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} else: @@ -125,7 +179,7 @@ def set_storage_slots_with_overrides( # Ensure that all required storage slots are reserved, and prevents other variables # from using these slots reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(StorageSlot(var_slot)) + varinfo.set_position(VarOffset(var_slot)) ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} else: @@ -138,105 +192,108 @@ def set_storage_slots_with_overrides( return ret -class SimpleStorageAllocator: - def __init__(self, starting_slot: int = 0): - self._slot = starting_slot +def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: + allocable = (vy_ast.InitializesDecl, vy_ast.VariableDecl) + return [node for node in vyper_module.body if isinstance(node, allocable)] - def allocate_slot(self, n, var_name): - ret = self._slot - if self._slot + n >= 2**256: - raise StorageLayoutException( - f"Invalid storage slot for var {var_name}, tried to allocate" - f" slots {self._slot} through {self._slot + n}" - ) - self._slot += n - return ret +def get_reentrancy_key_location() -> DataLocation: + if version_check(begin="cancun"): + return DataLocation.TRANSIENT + return DataLocation.STORAGE -def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: + +_LAYOUT_KEYS = { + DataLocation.CODE: "code_layout", + DataLocation.TRANSIENT: "transient_storage_layout", + DataLocation.STORAGE: "storage_layout", +} + + +def _allocate_layout_r( + vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False +) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ - # Allocate storage slots from 0 - # note storage is word-addressable, not byte-addressable - allocator = SimpleStorageAllocator() + if allocators is None: + allocators = Allocators() - ret: Dict[str, Dict] = {} + ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) for node in vyper_module.get_children(vy_ast.FunctionDef): + if immutables_only: + break + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue variable_name = f"nonreentrant.{type_.nonreentrant}" + reentrancy_key_location = get_reentrancy_key_location() + layout_key = _LAYOUT_KEYS[reentrancy_key_location] # a nonreentrant key can appear many times in a module but it # only takes one slot. after the first time we see it, do not # increment the storage slot. - if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) + if variable_name in ret[layout_key]: + _slot = ret[layout_key][variable_name]["slot"] + type_.set_reentrancy_key_position(VarOffset(_slot)) continue # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, variable_name) + allocator = allocators.get_allocator(reentrancy_key_location) + slot = allocator.allocate_slot(1, variable_name, node) - type_.set_reentrancy_key_position(StorageSlot(slot)) + type_.set_reentrancy_key_position(VarOffset(slot)) # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[variable_name] = {"type": "nonreentrant lock", "slot": slot} - - for node in vyper_module.get_children(vy_ast.VariableDecl): - # skip non-storage variables - if node.is_constant or node.is_immutable: + ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + module_layout = _allocate_layout_r(module_info.module_node, allocators) + module_alias = module_info.alias + for layout_key in module_layout.keys(): + assert layout_key in _LAYOUT_KEYS.values() + ret[layout_key][module_alias] = module_layout[layout_key] continue + assert isinstance(node, vy_ast.VariableDecl) + # skip non-storage variables varinfo = node.target._metadata["varinfo"] - type_ = varinfo.typ - - # CMC 2021-07-23 note that HashMaps get assigned a slot here. - # I'm not sure if it's safe to avoid allocating that slot - # for HashMaps because downstream code might use the slot - # ID as a salt. - n_slots = type_.storage_size_in_words - slot = allocator.allocate_slot(n_slots, node.target.id) - - varinfo.set_position(StorageSlot(slot)) - - # this could have better typing but leave it untyped until - # we understand the use case better - ret[node.target.id] = {"type": str(type_), "slot": slot} - - return ret - - -def set_calldata_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + if not varinfo.is_module_variable(): + continue + location = varinfo.location + if immutables_only and location != DataLocation.CODE: + continue -def set_memory_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + allocator = allocators.get_allocator(location) + size = varinfo.get_size() + # CMC 2021-07-23 note that HashMaps get assigned a slot here + # using the same allocator (even though there is not really + # any risk of physical overlap) + offset = allocator.allocate_slot(size, node.target.id, node) -def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: - ret = {} - offset = 0 + varinfo.set_position(VarOffset(offset)) - for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): - varinfo = node.target._metadata["varinfo"] + layout_key = _LAYOUT_KEYS[location] type_ = varinfo.typ - varinfo.set_position(CodeOffset(offset)) - - len_ = ceil32(type_.size_in_bytes) - # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "offset": offset, "length": len_} - - offset += len_ + if location == DataLocation.CODE: + item = {"type": str(type_), "length": size, "offset": offset} + elif location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + item = {"type": str(type_), "slot": offset} + else: # pragma: nocover + raise CompilerPanic("unreachable") + ret[layout_key][node.target.id] = item return ret diff --git a/vyper/semantics/analysis/global_.py b/vyper/semantics/analysis/global_.py new file mode 100644 index 0000000000..92cdf35c5d --- /dev/null +++ b/vyper/semantics/analysis/global_.py @@ -0,0 +1,80 @@ +from collections import defaultdict + +from vyper.exceptions import ExceptionList, InitializerException +from vyper.semantics.analysis.base import InitializesInfo, UsesInfo +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.module import validate_module_semantics_r +from vyper.semantics.types.module import ModuleT + + +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + ret = validate_module_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + _validate_global_initializes_constraint(ret) + + return ret + + +def _collect_used_modules_r(module_t): + ret: defaultdict[ModuleT, list[UsesInfo]] = defaultdict(list) + + for uses_decl in module_t.uses_decls: + for used_module in uses_decl._metadata["uses_info"].used_modules: + ret[used_module.module_t].append(uses_decl) + + # recurse + used_modules = _collect_used_modules_r(used_module.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + # also recurse into modules used by initialized modules + for i in module_t.initialized_modules: + used_modules = _collect_used_modules_r(i.module_info.module_t) + for k, v in used_modules.items(): + ret[k].extend(v) + + return ret + + +def _collect_initialized_modules_r(module_t, seen=None): + seen: dict[ModuleT, InitializesInfo] = seen or {} + + # list of InitializedInfo + initialized_infos = module_t.initialized_modules + + for i in initialized_infos: + initialized_module_t = i.module_info.module_t + if initialized_module_t in seen: + seen_nodes = (i.node, seen[initialized_module_t].node) + raise InitializerException(f"`{i.module_info.alias}` initialized twice!", *seen_nodes) + seen[initialized_module_t] = i + + _collect_initialized_modules_r(initialized_module_t, seen) + + return seen + + +# validate that each module which is `used` in the import graph is +# `initialized`. +def _validate_global_initializes_constraint(module_t: ModuleT): + all_used_modules = _collect_used_modules_r(module_t) + all_initialized_modules = _collect_initialized_modules_r(module_t) + + err_list = ExceptionList() + + for u, uses in all_used_modules.items(): + if u not in all_initialized_modules: + found_module = module_t.find_module_info(u) + if found_module is not None: + hint = f"add `initializes: {found_module.alias}` to the top level of " + hint += "your main contract" + else: + # CMC 2024-02-06 is this actually reachable? + hint = f"ensure `{module_t}` is imported in your main contract!" + err_list.append( + InitializerException( + f"module `{u}` is used but never initialized!", *uses, hint=hint + ) + ) + + err_list.raise_if_not_empty() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91cc0ebdf8..d96215ede0 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,8 +1,11 @@ +# CMC 2024-02-03 TODO: split me into function.py and expr.py + from typing import Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.exceptions import ( + CallViolation, ExceptionList, FunctionDeclarationException, ImmutableViolation, @@ -16,7 +19,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import Modifiability, VarInfo +from vyper.semantics.analysis.base import Modifiability, ModuleOwnership, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( get_common_types, @@ -54,13 +57,12 @@ def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" - err_list = ExceptionList() namespace = get_namespace() for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - analyzer = FunctionNodeVisitor(vy_module, node, namespace) + analyzer = FunctionAnalyzer(vy_module, node, namespace) analyzer.analyze() except VyperException as e: err_list.append(e) @@ -181,7 +183,7 @@ def _validate_self_reference(node: vy_ast.Name) -> None: raise StateAccessViolation("not allowed to query self in pure functions", node) -class FunctionNodeVisitor(VyperNodeVisitorBase): +class FunctionAnalyzer(VyperNodeVisitorBase): ignored_types = (vy_ast.Pass,) scope_name = "function" @@ -192,7 +194,7 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["func_type"] - self.expr_visitor = ExprVisitor(self.func) + self.expr_visitor = ExprVisitor(self) def analyze(self): # allow internal function params to be mutable @@ -270,21 +272,94 @@ def _assign_helper(self, node): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) target = get_expr_info(node.target) - if isinstance(target.typ, HashMapT): - raise StructureException( - "Left-hand side of assignment cannot be a HashMap without a key", node - ) - target.validate_modification(node, self.func.mutability) + # check mutability of the function + self._handle_modification(node.target) self.expr_visitor.visit(node.value, target.typ) self.expr_visitor.visit(node.target, target.typ) + def _handle_modification(self, target: vy_ast.ExprNode): + if isinstance(target, vy_ast.Tuple): + for item in target.elements: + self._handle_modification(item) + return + + # check a modification of `target`. validate the modification is + # valid, and log the modification in relevant data structures. + func_t = self.func + info = get_expr_info(target) + + if isinstance(info.typ, HashMapT): + raise StructureException( + "Left-hand side of assignment cannot be a HashMap without a key" + ) + + if ( + info.location in (DataLocation.STORAGE, DataLocation.TRANSIENT) + and func_t.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot modify {info.location} variable in a {func_t.mutability} function" + ) + + if info.location == DataLocation.CALLDATA: + raise ImmutableViolation("Cannot write to calldata") + + if info.modifiability == Modifiability.RUNTIME_CONSTANT: + if info.location == DataLocation.CODE: + if not func_t.is_constructor: + raise ImmutableViolation("Immutable value cannot be written to") + + # handle immutables + if info.var_info is not None: # don't handle complex (struct,array) immutables + # special handling for immutable variables in the ctor + # TODO: maybe we want to remove this restriction. + if info.var_info._modification_count != 0: + raise ImmutableViolation( + "Immutable value cannot be modified after assignment" + ) + info.var_info._modification_count += 1 + else: + raise ImmutableViolation("Environment variable cannot be written to") + + if info.modifiability == Modifiability.CONSTANT: + raise ImmutableViolation("Constant value cannot be written to.") + + var_info = info.get_root_varinfo() + assert var_info is not None + + info._writes.add(var_info) + + def _check_module_use(self, target: vy_ast.ExprNode): + module_infos = [] + for t in get_expr_info(target).attribute_chain: + if t.module_info is not None: + module_infos.append(t.module_info) + + if len(module_infos) == 0: + return + + for module_info in module_infos: + if module_info.ownership < ModuleOwnership.USES: + msg = f"Cannot access `{module_info.alias}` state!" + hint = f"add `uses: {module_info.alias}` or " + hint += f"`initializes: {module_info.alias}` as " + hint += "a top-level statement to your contract" + raise ImmutableViolation(msg, hint=hint) + + # the leftmost- referenced module + root_module_info = module_infos[0] + + # log the access + self.func._used_modules.add(root_module_info) + def visit_Assign(self, node): self._assign_helper(node) def visit_AugAssign(self, node): self._assign_helper(node) + node.target._expr_info.typ.validate_numeric_op(node) def visit_Break(self, node): for_node = node.get_ancestor(vy_ast.For) @@ -309,35 +384,13 @@ def visit_Expr(self, node): raise StructureException("Expressions without assignment are disallowed", node) fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): raise StructureException("To call an event you must use the `log` statement", node) if is_type_t(fn_type, StructT): raise StructureException("Struct creation without assignment is disallowed", node) - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - # NOTE: fetch_call_return validates call args. return_value = map_void(fn_type.fetch_call_return(node.value)) if ( @@ -457,7 +510,7 @@ def visit_Log(self, node): raise StructureException("Value is not an event", node.value) if self.func.mutability <= StateMutability.VIEW: raise StructureException( - f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node + f"Cannot emit logs from {self.func.mutability} functions", node ) t = map_void(f.fetch_call_return(node.value)) # CMC 2024-02-05 annotate the event type for codegen usage @@ -493,10 +546,20 @@ def visit_Return(self, node): class ExprVisitor(VyperNodeVisitorBase): - scope_name = "function" + def __init__(self, function_analyzer: Optional[FunctionAnalyzer] = None): + self.function_analyzer = function_analyzer + + @property + def func(self): + if self.function_analyzer is None: + return None + return self.function_analyzer.func - def __init__(self, fn_node: Optional[ContractFunctionT] = None): - self.func = fn_node + @property + def scope_name(self): + if self.func is not None: + return "function" + return "module" def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): @@ -509,6 +572,24 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + if not isinstance(typ, TYPE_T): + info = get_expr_info(node) # get_expr_info fills in node._expr_info + + # log variable accesses. + # (note writes will get logged as both read+write) + varinfo = info.var_info + if varinfo is not None: + info._reads.add(varinfo) + + if self.func: + variable_accesses = info._writes | info._reads + for s in variable_accesses: + if s.is_module_variable(): + self.function_analyzer._check_module_use(node) + + self.func._variable_writes.update(info._writes) + self.func._variable_reads.update(info._reads) + # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() @@ -547,42 +628,77 @@ def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: for value in node.values: self.visit(value, BoolT()) + def _check_call_mutability(self, call_mutability: StateMutability): + # note: payable can be called from nonpayable functions + ok = ( + call_mutability <= self.func.mutability + or self.func.mutability >= StateMutability.NONPAYABLE + ) + if not ok: + msg = f"Cannot call a {call_mutability} function from a {self.func.mutability} function" + raise StateAccessViolation(msg) + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: - call_type = get_exact_type_from_node(node.func) - self.visit(node.func, call_type) + func_info = get_expr_info(node.func, is_callable=True) + func_type = func_info.typ + self.visit(node.func, func_type) - if isinstance(call_type, ContractFunctionT): + if isinstance(func_type, ContractFunctionT): # function calls - if self.func and call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): + + func_info._writes.update(func_type._variable_writes) + func_info._reads.update(func_type._variable_reads) + + if self.function_analyzer: + if func_type.is_internal: + self.func.called_functions.add(func_type) + + self._check_call_mutability(func_type.mutability) + + # check that if the function accesses state, the defining + # module has been `used` or `initialized`. + for s in func_type._variable_accesses: + if s.is_module_variable(): + self.function_analyzer._check_module_use(node.func) + + if func_type.is_deploy and not self.func.is_deploy: + raise CallViolation( + f"Cannot call an @{func_type.visibility} function from " + f"an @{self.func.visibility} function!", + node, + ) + + for arg, typ in zip(node.args, func_type.argument_types): self.visit(arg, typ) for kwarg in node.keywords: # We should only see special kwargs - typ = call_type.call_site_kwargs[kwarg.arg].typ + typ = func_type.call_site_kwargs[kwarg.arg].typ self.visit(kwarg.value, typ) - elif is_type_t(call_type, EventT): + elif is_type_t(func_type, EventT): # events have no kwargs - expected_types = call_type.typedef.arguments.values() + expected_types = func_type.typedef.arguments.values() # type: ignore for arg, typ in zip(node.args, expected_types): self.visit(arg, typ) - elif is_type_t(call_type, StructT): + elif is_type_t(func_type, StructT): # struct ctors # ctors have no kwargs - expected_types = call_type.typedef.members.values() + expected_types = func_type.typedef.members.values() # type: ignore for value, arg_type in zip(node.args[0].values, expected_types): self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): + elif isinstance(func_type, MemberFunctionT): + if func_type.is_modifying and self.function_analyzer is not None: + # TODO refactor this + self.function_analyzer._handle_modification(node.func) + assert len(node.args) == len(func_type.arg_types) + for arg, arg_type in zip(node.args, func_type.arg_types): self.visit(arg, arg_type) else: # builtin functions - arg_types = call_type.infer_arg_types(node, expected_return_typ=typ) + arg_types = func_type.infer_arg_types(node, expected_return_typ=typ) # type: ignore for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) + kwarg_types = func_type.infer_kwarg_types(node) # type: ignore for kwarg in node.keywords: self.visit(kwarg.value, kwarg_types[kwarg.arg]) @@ -638,8 +754,10 @@ def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: self.visit(element, typ.value_type) def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: - if self.func and self.func.mutability == StateMutability.PURE: - _validate_self_reference(node) + if self.func: + # TODO: refactor to use expr_info mutability + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: if isinstance(typ, TYPE_T): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index a83c2f3b7d..e50c3e6d6f 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -8,38 +8,50 @@ from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle from vyper.evm.opcodes import version_check from vyper.exceptions import ( + BorrowException, CallViolation, DuplicateImport, ExceptionList, + ImmutableViolation, + InitializerException, InvalidLiteral, InvalidType, ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, - SyntaxException, + UndeclaredDefinition, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ( + ImportInfo, + InitializesInfo, + Modifiability, + ModuleInfo, + ModuleOwnership, + UsesInfo, + VarInfo, +) from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node +from vyper.semantics.analysis.utils import ( + check_modifiability, + get_exact_type_from_node, + get_expr_info, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation +from vyper.utils import OrderedSet -def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: - return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) - - -def validate_semantics_r( +def validate_module_semantics_r( module_ast: vy_ast.Module, input_bundle: InputBundle, import_graph: ImportGraph, @@ -49,6 +61,11 @@ def validate_semantics_r( Analyze a Vyper module AST node, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ + if "type" in module_ast._metadata: + # we don't need to analyse again, skip out + assert isinstance(module_ast._metadata["type"], ModuleT) + return module_ast._metadata["type"] + validate_literal_nodes(module_ast) # validate semantics and annotate AST with type/semantics information @@ -64,6 +81,8 @@ def validate_semantics_r( # in `ContractFunction.from_vyi()` if not is_interface: validate_functions(module_ast) + analyzer.validate_initialized_modules() + analyzer.validate_used_modules() return ret @@ -121,11 +140,8 @@ def __init__( def analyze(self) -> ModuleT: # generate a `ModuleT` from the top-level node # note: also validates unique method ids - if "type" in self.ast._metadata: - assert isinstance(self.ast._metadata["type"], ModuleT) - # we don't need to analyse again, skip out - self.module_t = self.ast._metadata["type"] - return self.module_t + + assert "type" not in self.ast._metadata to_visit = self.ast.body.copy() @@ -138,6 +154,11 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + ownership_decls = self.ast.get_children((vy_ast.UsesDecl, vy_ast.InitializesDecl)) + for node in ownership_decls: + self.visit(node) + to_visit.remove(node) + # we can resolve constants after imports are handled. constant_fold(self.ast) @@ -179,6 +200,7 @@ def analyze(self) -> ModuleT: def analyze_call_graph(self): # get list of internal function calls made by each function + # CMC 2024-02-03 note: this could be cleaner in analysis/local.py function_defs = self.module_t.function_defs for func in function_defs: @@ -195,7 +217,9 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: @@ -204,6 +228,106 @@ def analyze_call_graph(self): # compute reachable set and validate the call graph _compute_reachable_set(fn_t) + def validate_used_modules(self): + # check all `uses:` modules are actually used + should_use = {} + + module_t = self.ast._metadata["type"] + uses_decls = module_t.uses_decls + for decl in uses_decls: + info = decl._metadata["uses_info"] + for m in info.used_modules: + should_use[m.module_t] = (m, info) + + initialized_modules = {t.module_info.module_t: t for t in module_t.initialized_modules} + + all_used_modules = OrderedSet() + + for f in module_t.functions.values(): + for u in f._used_modules: + all_used_modules.add(u.module_t) + + for used_module in all_used_modules: + if used_module in initialized_modules: + continue + + if used_module in should_use: + del should_use[used_module] + + if len(should_use) > 0: + err_list = ExceptionList() + for used_module_info, uses_info in should_use.values(): + msg = f"`{used_module_info.alias}` is declared as used, but " + msg += f"it is not actually used in {module_t}!" + hint = f"delete `uses: {used_module_info.alias}`" + err_list.append(BorrowException(msg, uses_info.node, hint=hint)) + + err_list.raise_if_not_empty() + + def validate_initialized_modules(self): + # check all `initializes:` modules have `__init__()` called exactly once + module_t = self.ast._metadata["type"] + should_initialize = {t.module_info.module_t: t for t in module_t.initialized_modules} + # don't call `__init__()` for modules which don't have + # `__init__()` function + for m in should_initialize.copy(): + for f in m.functions.values(): + if f.is_constructor: + break + else: + del should_initialize[m] + + init_calls = [] + for f in self.ast.get_children(vy_ast.FunctionDef): + if f._metadata["func_type"].is_constructor: + init_calls = f.get_descendants(vy_ast.Call) + break + + seen_initializers = {} + for call_node in init_calls: + expr_info = call_node.func._expr_info + if expr_info is None: + # this can happen for range() calls; CMC 2024-02-05 try to + # refactor so that range() is properly tagged. + continue + + call_t = call_node.func._expr_info.typ + + if not isinstance(call_t, ContractFunctionT): + continue + + if not call_t.is_constructor: + continue + + # XXX: check this works as expected for nested attributes + initialized_module = call_node.func.value._expr_info.module_info + + if initialized_module.module_t in seen_initializers: + seen_location = seen_initializers[initialized_module.module_t] + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but its __init__() function was already called!" + raise InitializerException(msg, call_node.func, seen_location) + + if initialized_module.module_t not in should_initialize: + msg = f"tried to initialize `{initialized_module.alias}`, " + msg += "but it is not in initializer list!" + hint = f"add `initializes: {initialized_module.alias}` " + hint += "as a top-level statement to your contract" + raise InitializerException(msg, call_node.func, hint=hint) + + del should_initialize[initialized_module.module_t] + seen_initializers[initialized_module.module_t] = call_node.func + + if len(should_initialize) > 0: + err_list = ExceptionList() + for s in should_initialize.values(): + msg = "not initialized!" + hint = f"add `{s.module_info.alias}.__init__()` to " + hint += "your `__init__()` function" + err_list.append(InitializerException(msg, s.node, hint=hint)) + + err_list.raise_if_not_empty() + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: # cache ast if we have seen it before. # this gives us the additional property of object equality on @@ -218,10 +342,100 @@ def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) if not isinstance(type_, InterfaceT): - raise StructureException("Invalid interface name", node.annotation) + raise StructureException("not an interface!", node.annotation) type_.validate_implements(node) + def visit_UsesDecl(self, node): + # TODO: check duplicate uses declarations, e.g. + # uses: x + # ... + # uses: x + items = vy_ast.as_tuple(node.annotation) + + used_modules = [] + + for item in items: + module_info = get_expr_info(item).module_info + if module_info is None: + raise StructureException("not a valid module!", item) + + # note: try to refactor - not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.USES, item) + + used_modules.append(module_info) + + node._metadata["uses_info"] = UsesInfo(used_modules, node) + + def visit_InitializesDecl(self, node): + module_ref = node.annotation + dependencies_ast = () + if isinstance(module_ref, vy_ast.Subscript): + dependencies_ast = vy_ast.as_tuple(module_ref.slice) + module_ref = module_ref.value + + # postcondition of InitializesDecl.validates() + assert isinstance(module_ref, (vy_ast.Name, vy_ast.Attribute)) + + module_info = get_expr_info(module_ref).module_info + if module_info is None: + raise StructureException("Not a module!", module_ref) + + used_modules = {i.module_t: i for i in module_info.module_t.used_modules} + + dependencies = [] + for named_expr in dependencies_ast: + assert isinstance(named_expr, vy_ast.NamedExpr) + + rhs_module = get_expr_info(named_expr.value).module_info + + with module_info.module_node.namespace(): + # lhs of the named_expr is evaluated in the namespace of the + # initialized module! + try: + lhs_module = get_expr_info(named_expr.target).module_info + except VyperException as e: + # try to report a common problem - user names the module in + # the current namespace instead of the initialized module + # namespace. + + # search for the module in the initialized module + found_module = module_info.module_t.find_module_info(rhs_module.module_t) + if found_module is not None: + msg = f"unknown module `{named_expr.target.id}`" + hint = f"did you mean `{found_module.alias} := {rhs_module.alias}`?" + raise UndeclaredDefinition(msg, named_expr.target, hint=hint) + + raise e from None + + if lhs_module.module_t != rhs_module.module_t: + raise StructureException( + f"{lhs_module.alias} is not {rhs_module.alias}!", named_expr + ) + dependencies.append(lhs_module) + + if lhs_module.module_t not in used_modules: + raise InitializerException( + f"`{module_info.alias}` is initialized with `{lhs_module.alias}`, " + f"but `{module_info.alias}` does not use `{lhs_module.alias}`!", + named_expr, + ) + + del used_modules[lhs_module.module_t] + + if len(used_modules) > 0: + item = next(iter(used_modules.values())) # just pick one + msg = f"`{module_info.alias}` uses `{item.alias}`, but it is not " + msg += f"initialized with `{item.alias}`" + hint = f"add `{item.alias}` to its initializer list" + raise InitializerException(msg, node, hint=hint) + + # note: try to refactor. not a huge fan of mutating the + # ModuleInfo after it's constructed + module_info.set_ownership(ModuleOwnership.INITIALIZES, node) + node._metadata["initializes_info"] = InitializesInfo(module_info, dependencies, node) + def visit_VariableDecl(self, node): name = node.get("target.id") if name is None: @@ -250,7 +464,7 @@ def visit_VariableDecl(self, node): if len(wrong_self_attribute) > 0 else "Immutable definition requires an assignment in the constructor" ) - raise SyntaxException(message, node.node_source_code, node.lineno, node.col_offset) + raise ImmutableViolation(message, node) data_loc = ( DataLocation.CODE @@ -364,11 +578,10 @@ def visit_Import(self, node): # don't handle things like `import x.y` if "." in alias: + msg = "import requires an accompanying `as` statement" suggested_alias = node.name[node.name.rfind(".") :] - suggestion = f"hint: try `import {node.name} as {suggested_alias}`" - raise StructureException( - f"import requires an accompanying `as` statement ({suggestion})", node - ) + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) self._add_import(node, 0, node.name, alias) @@ -436,14 +649,14 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - module_t = validate_semantics_r( + module_t = validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, is_interface=False, ) - return ModuleInfo(module_t) + return ModuleInfo(module_t, alias) except FileNotFoundError as e: # escape `e` from the block scope, it can make things @@ -456,7 +669,7 @@ def _load_import_helper( module_ast = self._ast_from_file(file) with override_global_namespace(Namespace()): - validate_semantics_r( + validate_module_semantics_r( module_ast, self.input_bundle, import_graph=self._import_graph, @@ -481,7 +694,7 @@ def _load_import_helper( raise ModuleNotFound(module_str, node) from err -def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: +def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, @@ -542,5 +755,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: interface_ast = _parse_and_fold_ast(file) with override_global_namespace(Namespace()): - module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + module_t = validate_module_semantics_r( + interface_ast, input_bundle, ImportGraph(), is_interface=True + ) return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index abbf6a68cc..f1f0f48a86 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -61,8 +61,8 @@ class _ExprAnalyser: def __init__(self): self.namespace = get_namespace() - def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: - t = self.get_exact_type_from_node(node) + def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> ExprInfo: + t = self.get_exact_type_from_node(node, include_type_exprs=is_callable) # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): @@ -74,7 +74,10 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: if isinstance(info, ModuleInfo): return ExprInfo.from_moduleinfo(info) - raise CompilerPanic("unreachable!", node) + if isinstance(info, VyperType): + return ExprInfo(TYPE_T(info)) + + raise CompilerPanic(f"unreachable! {info}", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and @@ -82,30 +85,27 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # note: Attribute(expr value, identifier attr) name = node.attr - info = self.get_expr_info(node.value) + info = self.get_expr_info(node.value, is_callable=is_callable) + + attribute_chain = info.attribute_chain + [info] t = info.typ.get_member(name, node) # it's a top-level variable if isinstance(t, VarInfo): - return ExprInfo.from_varinfo(t) + return ExprInfo.from_varinfo(t, attribute_chain=attribute_chain) - # it's something else, like my_struct.foo - return info.copy_with_type(t) + if isinstance(t, ModuleInfo): + return ExprInfo.from_moduleinfo(t, attribute_chain=attribute_chain) - if isinstance(node, vy_ast.Tuple): - # always use the most restrictive location re: modification - # kludge! for validate_modification in local analysis of Assign - types = [self.get_expr_info(n) for n in node.elements] - location = sorted((i.location for i in types), key=lambda k: k.value)[-1] - modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1] - - return ExprInfo(t, location=location, modifiability=modifiability) + # it's something else, like my_struct.foo + return info.copy_with_type(t, attribute_chain=attribute_chain) # If it's a Subscript, propagate the subscriptable varinfo if isinstance(node, vy_ast.Subscript): info = self.get_expr_info(node.value) - return info.copy_with_type(t) + attribute_chain = info.attribute_chain + [info] + return info.copy_with_type(t, attribute_chain=attribute_chain) return ExprInfo(t) @@ -184,6 +184,7 @@ def _find_fn(self, node): def types_from_Attribute(self, node): is_self_reference = node.get("value.id") == "self" + # variable attribute, e.g. `foo.bar` t = self.get_exact_type_from_node(node.value, include_type_exprs=True) name = node.attr @@ -476,8 +477,10 @@ def get_exact_type_from_node(node): return _ExprAnalyser().get_exact_type_from_node(node, include_type_exprs=True) -def get_expr_info(node: vy_ast.VyperNode) -> ExprInfo: - return _ExprAnalyser().get_expr_info(node) +def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo: + if node._expr_info is None: + node._expr_info = _ExprAnalyser().get_expr_info(node, is_callable) + return node._expr_info def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List: @@ -639,7 +642,7 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: +def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. """ @@ -665,5 +668,5 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> if hasattr(call_type, "check_modifiability_for_call"): return call_type.check_modifiability_for_call(node, modifiability) - value_type = get_expr_info(node) - return value_type.modifiability >= modifiability + info = get_expr_info(node) + return info.modifiability <= modifiability diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index cecea35a60..06245aa90d 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -1,10 +1,12 @@ import enum +from vyper.utils import StringEnum -class DataLocation(enum.Enum): - UNSET = 0 - MEMORY = 1 - STORAGE = 2 - CALLDATA = 3 - CODE = 4 - TRANSIENT = 5 + +class DataLocation(StringEnum): + UNSET = enum.auto() + MEMORY = enum.auto() + STORAGE = enum.auto() + CALLDATA = enum.auto() + CODE = enum.auto() + TRANSIENT = enum.auto() diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index d659276ee0..c5e10b52be 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -13,6 +13,7 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.data_locations import DataLocation # Some fake type with an overridden `compare_type` which accepts any RHS @@ -25,7 +26,11 @@ def __init__(self, type_): self.type_ = type_ def compare_type(self, other): - return isinstance(other, self.type_) or self == other + if isinstance(other, self.type_): + return True + # compare two GenericTypeAcceptors -- they are the same if the base + # type is the same + return isinstance(other, self.__class__) and other.type_ == self.type_ class VyperType: @@ -91,6 +96,8 @@ def __hash__(self): return hash(self._get_equality_attrs()) def __eq__(self, other): + if self is other: + return True return ( type(self) is type(other) and self._get_equality_attrs() == other._get_equality_attrs() ) @@ -118,6 +125,16 @@ def abi_type(self) -> ABIType: """ raise CompilerPanic("Method must be implemented by the inherited class") + def get_size_in(self, location: DataLocation): + if location in (DataLocation.STORAGE, DataLocation.TRANSIENT): + return self.storage_size_in_words + if location == DataLocation.MEMORY: + return self.memory_bytes_required + if location == DataLocation.CODE: + return self.memory_bytes_required + + raise CompilerPanic("unreachable: invalid location {location}") # pragma: nocover + @property def memory_bytes_required(self) -> int: # alias for API compatibility with codegen @@ -341,8 +358,10 @@ def map_void(typ: Optional[VyperType]) -> VyperType: # A type type. Used internally for types which can live in expression # position, ex. constructors (events, interfaces and structs), and also # certain builtins which take types as parameters -class TYPE_T: +class TYPE_T(VyperType): def __init__(self, typedef): + super().__init__() + self.typedef = typedef def __repr__(self): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 2d92370b9d..62f9c60585 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -19,8 +19,10 @@ from vyper.semantics.analysis.base import ( FunctionVisibility, Modifiability, + ModuleInfo, StateMutability, - StorageSlot, + VarInfo, + VarOffset, ) from vyper.semantics.analysis.utils import ( check_modifiability, @@ -112,10 +114,27 @@ def __init__( # recursively reachable from this function self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() + # writes to variables from this function + self._variable_writes: OrderedSet[VarInfo] = OrderedSet() + + # reads of variables from this function + self._variable_reads: OrderedSet[VarInfo] = OrderedSet() + + # list of modules used (accessed state) by this function + self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() + # to be populated during codegen self._ir_info: Any = None self._function_id: Optional[int] = None + @property + def _variable_accesses(self): + return self._variable_reads | self._variable_writes + + @property + def modifiability(self): + return Modifiability.from_state_mutability(self.mutability) + @cached_property def call_site_kwargs(self): # special kwargs that are allowed in call site @@ -269,9 +288,11 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - "function body in an interface can only be ...!", funcdef + "function body in an interface can only be `...`!", funcdef ) + assert function_visibility is not None # mypy hint + return cls( funcdef.name, positional_args, @@ -314,13 +335,19 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Default function may not receive any arguments", funcdef.args.args[0] ) + if function_visibility == FunctionVisibility.DEPLOY and funcdef.name != "__init__": + raise FunctionDeclarationException( + "Only constructors can be marked as `@deploy`!", funcdef + ) if funcdef.name == "__init__": - if ( - state_mutability in (StateMutability.PURE, StateMutability.VIEW) - or function_visibility == FunctionVisibility.INTERNAL - ): + if state_mutability in (StateMutability.PURE, StateMutability.VIEW): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef + "Constructor cannot be marked as `@pure` or `@view`", funcdef + ) + if function_visibility != FunctionVisibility.DEPLOY: + raise FunctionDeclarationException( + f"Constructor must be marked as `@deploy`, not `@{function_visibility}`", + funcdef, ) if return_type is not None: raise FunctionDeclarationException( @@ -333,6 +360,9 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Constructor may not use default arguments", funcdef.args.defaults[0] ) + # sanity check + assert function_visibility is not None + return cls( funcdef.name, positional_args, @@ -344,14 +374,11 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ast_def=funcdef, ) - def set_reentrancy_key_position(self, position: StorageSlot) -> None: + def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") if self.nonreentrant is None: raise CompilerPanic(f"No reentrant key {self}") - # sanity check even though implied by the type - if position._location != DataLocation.STORAGE: - raise CompilerPanic("Non-storage reentrant key") self.reentrancy_key_position = position @classmethod @@ -456,6 +483,14 @@ def is_external(self) -> bool: def is_internal(self) -> bool: return self.visibility == FunctionVisibility.INTERNAL + @property + def is_deploy(self) -> bool: + return self.visibility == FunctionVisibility.DEPLOY + + @property + def is_constructor(self) -> bool: + return self.name == "__init__" + @property def is_mutable(self) -> bool: return self.mutability > StateMutability.VIEW @@ -464,10 +499,6 @@ def is_mutable(self) -> bool: def is_payable(self) -> bool: return self.mutability == StateMutability.PAYABLE - @property - def is_constructor(self) -> bool: - return self.name == "__init__" - @property def is_fallback(self) -> bool: return self.name == "__default__" @@ -535,20 +566,14 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: modified_line = re.sub( kwarg_pattern, kwarg.value.node_source_code, node.node_source_code ) - error_suggestion = ( - f"\n(hint: Try removing the kwarg: `{modified_line}`)" - if modified_line != node.node_source_code - else "" - ) - raise ArgumentException( - ( - "Usage of kwarg in Vyper is restricted to " - + ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) - + f". {error_suggestion}" - ), - kwarg, - ) + msg = "Usage of kwarg in Vyper is restricted to " + msg += ", ".join([f"{k}=" for k in self.call_site_kwargs.keys()]) + + hint = None + if modified_line != node.node_source_code: + hint = f"Try removing the kwarg: `{modified_line}`" + raise ArgumentException(msg, kwarg, hint=hint) return self.return_type @@ -601,7 +626,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: function_visibility = None state_mutability = None nonreentrant_key = None @@ -632,7 +657,9 @@ def _parse_decorators( if FunctionVisibility.is_valid_value(decorator.id): if function_visibility is not None: raise FunctionDeclarationException( - f"Visibility is already set to: {function_visibility}", funcdef + f"Visibility is already set to: {function_visibility}", + decorator, + hint="only one visibility decorator is allowed per function", ) function_visibility = FunctionVisibility(decorator.id) @@ -748,6 +775,10 @@ def __init__( self.return_type = return_type self.is_modifying = is_modifying + @property + def modifiability(self): + return Modifiability.MODIFIABLE if self.is_modifying else Modifiability.RUNTIME_CONSTANT + def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ee1da22a87..86840f4f91 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Optional +from typing import TYPE_CHECKING, Optional from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType @@ -16,12 +16,16 @@ validate_expected_type, validate_unique_method_ids, ) +from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType +if TYPE_CHECKING: + from vyper.semantics.analysis.base import ModuleInfo + class InterfaceT(_UserType): _type_members = {"address": AddressT()} @@ -234,7 +238,7 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for node in module_t.function_defs: func_t = node._metadata["func_type"] - if not func_t.is_external: + if not (func_t.is_external or func_t.is_constructor): continue funcs.append((node.name, func_t)) @@ -276,6 +280,12 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): _attribute_in_annotation = True + _invalid_locations = ( + DataLocation.CALLDATA, + DataLocation.CODE, + DataLocation.MEMORY, + DataLocation.TRANSIENT, + ) def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -307,7 +317,6 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.interface_defs: # add the type of the interface so it can be used in call position self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore - self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -316,6 +325,13 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): import_info = i._metadata["import_info"] self.add_member(import_info.alias, import_info.typ) + if hasattr(import_info.typ, "module_t"): + self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + + for name, interface_t in self.interfaces.items(): + # can access interfaces in type position + self._helper.add_member(name, TYPE_T(interface_t)) + # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle, # search path, symlinked vs normalized path, etc.) @@ -345,27 +361,97 @@ def struct_defs(self): def interface_defs(self): return self._module.get_children(vy_ast.InterfaceDef) + @cached_property + def interfaces(self) -> dict[str, InterfaceT]: + ret = {} + for i in self.interface_defs: + assert i.name not in ret # precondition + ret[i.name] = i._metadata["interface_type"] + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + if isinstance(import_info.typ, InterfaceT): + assert import_info.alias not in ret # precondition + ret[import_info.alias] = import_info.typ + + return ret + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + @cached_property + def imported_modules(self) -> dict[str, "ModuleInfo"]: + ret = {} + for s in self.import_stmts: + info = s._metadata["import_info"] + module_info = info.typ + if isinstance(module_info, InterfaceT): + continue + ret[info.alias] = module_info + return ret + + def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: + for s in self.imported_modules.values(): + if s.module_t == needle: + return s + return None + @property def variable_decls(self): return self._module.get_children(vy_ast.VariableDecl) + @property + def uses_decls(self): + return self._module.get_children(vy_ast.UsesDecl) + + @property + def initializes_decls(self): + return self._module.get_children(vy_ast.InitializesDecl) + + @cached_property + def used_modules(self): + # modules which are written to + ret = [] + for node in self.uses_decls: + for used_module in node._metadata["uses_info"].used_modules: + ret.append(used_module) + return ret + + @property + def initialized_modules(self): + # modules which are initialized to + ret = [] + for node in self.initializes_decls: + info = node._metadata["initializes_info"] + ret.append(info) + return ret + @cached_property def variables(self): # variables that this module defines, ex. # `x: uint256` is a private storage variable named x return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + @cached_property + def functions(self): + return {f.name: f._metadata["func_type"] for f in self.function_defs} + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable] @cached_property def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + ret = 0 + for s in self.immutables: + ret += s.typ.memory_bytes_required + + for initializes_info in self.initialized_modules: + module_t = initializes_info.module_info.module_t + ret += module_t.immutable_section_bytes + + return ret @cached_property def interface(self): diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 5564570536..c6a4531df8 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -117,16 +117,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: if isinstance(node, vy_ast.Attribute): # ex. SomeModule.SomeStruct - # sanity check - we only allow modules/interfaces to be - # imported as `Name`s currently. - if not isinstance(node.value, vy_ast.Name): + if isinstance(node.value, vy_ast.Attribute): + module_or_interface = _type_from_annotation(node.value) + elif isinstance(node.value, vy_ast.Name): + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + else: raise InvalidType(err_msg, node) - try: - module_or_interface = namespace[node.value.id] # type: ignore - except UndeclaredDefinition: - raise InvalidType(err_msg, node) from None - if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo module_or_interface = module_or_interface.module_t diff --git a/vyper/utils.py b/vyper/utils.py index 2349731b97..b2284eaba0 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -1,6 +1,7 @@ import binascii import contextlib import decimal +import enum import functools import sys import time @@ -8,7 +9,7 @@ import warnings from typing import Generic, List, TypeVar, Union -from vyper.exceptions import DecimalOverrideException, InvalidLiteral +from vyper.exceptions import CompilerPanic, DecimalOverrideException, InvalidLiteral _T = TypeVar("_T") @@ -62,6 +63,59 @@ def copy(self): return self.__class__(super().copy()) +class StringEnum(enum.Enum): + # Must be first, or else won't work, specifies what .value is + def _generate_next_value_(name, start, count, last_values): + return name.lower() + + # Override ValueError with our own internal exception + @classmethod + def _missing_(cls, value): + raise CompilerPanic(f"{value} is not a valid {cls.__name__}") + + @classmethod + def is_valid_value(cls, value: str) -> bool: + return value in set(o.value for o in cls) + + @classmethod + def options(cls) -> List["StringEnum"]: + return list(cls) + + @classmethod + def values(cls) -> List[str]: + return [v.value for v in cls.options()] + + # Comparison operations + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + return self is other + + # Python normally does __ne__(other) ==> not self.__eq__(other) + + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + raise CompilerPanic(f"bad comparison: ({type(other)}, {type(self)})") + options = self.__class__.options() + return options.index(self) < options.index(other) # type: ignore + + def __le__(self, other: object) -> bool: + return self.__eq__(other) or self.__lt__(other) + + def __gt__(self, other: object) -> bool: + return not self.__le__(other) + + def __ge__(self, other: object) -> bool: + return not self.__lt__(other) + + def __str__(self) -> str: + return self.value + + def __hash__(self) -> int: + # let `dataclass` know that this class is not mutable + return super().__hash__() + + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): if name == "prec":