From ddfce5273b39a199b194dd74f0f7f741efc03663 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 8 Jan 2024 02:37:01 +0800 Subject: [PATCH 1/3] feat: require type annotations for loop variables (#3596) this commit changes the vyper language to require type annotations for loop variables. that is, before, the following was allowed: ```vyper for i in [1, 2, 3]: pass ``` now, `i` is required to have a type annotation: ```vyper for i: uint256 in [1, 2, 3]: pass ``` this makes the annotation of loop variables consistent with the rest of vyper (it was previously a special case, that loop variables did not need to be annotated). the approach taken in this commit is to add a pre-parsing step which lifts out the type annotation into a separate data structure, and then splices it back in during the post-processing steps in `vyper/ast/parse.py`. this commit also simplifies a lot of analysis regarding for loops. notably, the possible types for the loop variable no longer needs to be iterated over, we can just propagate the type provided by the user. for this reason we also no longer need to use the typechecker speculation machinery for inferring the type of the loop variable. however, the NodeMetadata code is not removed because it might come in handy at a later date. --------- Co-authored-by: Charles Cooper --- examples/auctions/blind_auction.vy | 2 +- examples/tokens/ERC1155ownable.vy | 8 +- examples/voting/ballot.vy | 6 +- examples/wallet/wallet.vy | 4 +- .../functional/builtins/codegen/test_empty.py | 4 +- .../builtins/codegen/test_mulmod.py | 2 +- .../functional/builtins/codegen/test_slice.py | 2 +- .../codegen/features/iteration/test_break.py | 12 +- .../features/iteration/test_continue.py | 10 +- .../features/iteration/test_for_in_list.py | 150 ++++++++++-------- .../features/iteration/test_for_range.py | 56 +++---- .../codegen/features/test_assert.py | 4 +- .../codegen/features/test_internal_call.py | 2 +- .../codegen/integration/test_crowdfund.py | 4 +- .../codegen/types/numbers/test_decimals.py | 2 +- tests/functional/codegen/types/test_bytes.py | 2 +- .../codegen/types/test_bytes_zero_padding.py | 2 +- .../codegen/types/test_dynamic_array.py | 28 ++-- tests/functional/codegen/types/test_lists.py | 4 +- tests/functional/grammar/test_grammar.py | 2 +- .../exceptions/test_argument_exception.py | 4 +- .../exceptions/test_constancy_exception.py | 6 +- tests/functional/syntax/test_blockscope.py | 4 +- tests/functional/syntax/test_constants.py | 2 +- tests/functional/syntax/test_for_range.py | 58 +++---- tests/functional/syntax/test_list.py | 2 +- tests/unit/ast/nodes/test_hex.py | 2 +- .../ast/test_annotate_and_optimize_ast.py | 4 +- tests/unit/ast/test_pre_parser.py | 2 +- tests/unit/compiler/asm/test_asm_optimizer.py | 2 +- tests/unit/compiler/test_source_map.py | 2 +- .../unit/semantics/analysis/test_for_loop.py | 38 ++--- vyper/ast/grammar.lark | 3 +- vyper/ast/nodes.py | 23 ++- vyper/ast/parse.py | 66 +++++++- vyper/ast/pre_parser.py | 91 ++++++++++- vyper/builtins/functions.py | 2 +- vyper/codegen/stmt.py | 30 ++-- vyper/exceptions.py | 4 + vyper/semantics/analysis/local.py | 111 ++++--------- 40 files changed, 432 insertions(+), 330 deletions(-) diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 04f908f6d0..597aed57c7 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -107,7 +107,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets: # Calculate refund for sender refund: uint256 = 0 - for i in range(MAX_BIDS): + for i: int128 in range(MAX_BIDS): # Note that loop may break sooner than 128 iterations if i >= _numBids if (i >= _numBids): break diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index 30057582e8..e105a79133 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -205,7 +205,7 @@ def balanceOfBatch(accounts: DynArray[address, BATCH_SIZE], ids: DynArray[uint25 assert len(accounts) == len(ids), "ERC1155: accounts and ids length mismatch" batchBalances: DynArray[uint256, BATCH_SIZE] = [] j: uint256 = 0 - for i in ids: + for i: uint256 in ids: batchBalances.append(self.balanceOf[accounts[j]][i]) j += 1 return batchBalances @@ -243,7 +243,7 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[receiver][ids[i]] += amounts[i] @@ -277,7 +277,7 @@ def burnBatch(ids: DynArray[uint256, BATCH_SIZE], amounts: DynArray[uint256, BAT assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[msg.sender][ids[i]] -= amounts[i] @@ -333,7 +333,7 @@ def safeBatchTransferFrom(sender: address, receiver: address, ids: DynArray[uint assert sender == msg.sender or self.isApprovedForAll[sender][msg.sender], "Caller is neither owner nor approved operator for this ID" assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break id: uint256 = ids[i] diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 0b568784a9..107716accf 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -54,7 +54,7 @@ def directlyVoted(addr: address) -> bool: def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 - for i in range(2): + for i: int128 in range(2): self.proposals[i] = Proposal({ name: _proposalNames[i], voteCount: 0 @@ -82,7 +82,7 @@ def _forwardWeight(delegate_with_weight_to_forward: address): assert self.voters[delegate_with_weight_to_forward].weight > 0 target: address = self.voters[delegate_with_weight_to_forward].delegate - for i in range(4): + for i: int128 in range(4): if self._delegated(target): target = self.voters[target].delegate # The following effectively detects cycles of length <= 5, @@ -157,7 +157,7 @@ def vote(proposal: int128): def _winningProposal() -> int128: winning_vote_count: int128 = 0 winning_proposal: int128 = 0 - for i in range(2): + for i: int128 in range(2): if self.proposals[i].voteCount > winning_vote_count: winning_vote_count = self.proposals[i].voteCount winning_proposal = i diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index e2515d9e62..231f538ecf 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -14,7 +14,7 @@ seq: public(int128) @external def __init__(_owners: address[5], _threshold: int128): - for i in range(5): + for i: uint256 in range(5): if _owners[i] != empty(address): self.owners[i] = _owners[i] self.threshold = _threshold @@ -47,7 +47,7 @@ def approve(_seq: int128, to: address, _value: uint256, data: Bytes[4096], sigda assert self.seq == _seq # # Iterates through all the owners and verifies that there signatures, # # given as the sigdata argument are correct - for i in range(5): + for i: uint256 in range(5): if sigdata[i][0] != 0: # If an invalid signature is given for an owner then the contract throws assert ecrecover(h2, sigdata[i][0], sigdata[i][1], sigdata[i][2]) == self.owners[i] diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index c3627785dc..896c845da2 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -423,7 +423,7 @@ def test_empty(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: view @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @internal def priv(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: @@ -469,7 +469,7 @@ def test_return_empty(get_contract_with_gas_estimation): @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @external diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index ba82ebd5b8..31de1d9f22 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -20,7 +20,7 @@ def test_uint256_mulmod_complex(get_contract_with_gas_estimation): @external def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: o: uint256 = 1 - for i in range(256): + for i: uint256 in range(256): o = uint256_mulmod(o, o, modulus) if exponent & shift(1, 255 - i) != 0: o = uint256_mulmod(o, base, modulus) diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index a15a3eeb35..80936bbf82 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -17,7 +17,7 @@ def test_basic_slice(get_contract_with_gas_estimation): @external def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: inp: Bytes[50] = inp1 - for i in range(1, 11): + for i: uint256 in range(1, 11): inp = slice(inp, 1, 30 - i * 2) return inp """ diff --git a/tests/functional/codegen/features/iteration/test_break.py b/tests/functional/codegen/features/iteration/test_break.py index 8a08a11cc2..4abde9c617 100644 --- a/tests/functional/codegen/features/iteration/test_break.py +++ b/tests/functional/codegen/features/iteration/test_break.py @@ -11,7 +11,7 @@ def test_break_test(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(400): + for i: int128 in range(400): c = c / 1.2589 if c < 1.0: output = i @@ -35,12 +35,12 @@ def test_break_test_2(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c = c / 10.0 - for i in range(10): + for i: int128 in range(10): c = c / 1.2589 if c < 1.0: output = output + i @@ -63,12 +63,12 @@ def test_break_test_3(get_contract_with_gas_estimation): def foo(n: int128) -> int128: c: decimal = convert(n, decimal) output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c /= 10.0 - for i in range(10): + for i: int128 in range(10): c /= 1.2589 if c < 1.0: output = output + i @@ -108,7 +108,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: break diff --git a/tests/functional/codegen/features/iteration/test_continue.py b/tests/functional/codegen/features/iteration/test_continue.py index 5f4f82a2de..1b2fcab460 100644 --- a/tests/functional/codegen/features/iteration/test_continue.py +++ b/tests/functional/codegen/features/iteration/test_continue.py @@ -7,7 +7,7 @@ def test_continue1(get_contract_with_gas_estimation): code = """ @external def foo() -> bool: - for i in range(2): + for i: uint256 in range(2): continue return False return True @@ -21,7 +21,7 @@ def test_continue2(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += 1 continue x -= 1 @@ -36,7 +36,7 @@ def test_continue3(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += i continue return x @@ -50,7 +50,7 @@ def test_continue4(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): if i % 2 == 0: continue x += 1 @@ -83,7 +83,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: continue diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index bc1a12ae9e..5c7b5c6b1b 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -21,7 +21,7 @@ @external def data() -> int128: s: int128[5] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -33,7 +33,7 @@ def data() -> int128: @external def data() -> int128: s: DynArray[int128, 10] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -53,8 +53,8 @@ def data() -> int128: [S({x:3, y:4}), S({x:5, y:6}), S({x:7, y:8}), S({x:9, y:10})] ] ret: int128 = 0 - for ss in sss: - for s in ss: + for ss: DynArray[S, 10] in sss: + for s: S in ss: ret += s.x + s.y return ret""", sum(range(1, 11)), @@ -64,7 +64,7 @@ def data() -> int128: """ @external def data() -> int128: - for i in [3, 5, 7, 9]: + for i: int128 in [3, 5, 7, 9]: if i > 5: return i return -1""", @@ -76,7 +76,7 @@ def data() -> int128: @external def data() -> String[33]: xs: DynArray[String[33], 3] = ["hello", ",", "world"] - for x in xs: + for x: String[33] in xs: if x == ",": return x return "" @@ -88,7 +88,7 @@ def data() -> String[33]: """ @external def data() -> String[33]: - for x in ["hello", ",", "world"]: + for x: String[33] in ["hello", ",", "world"]: if x == ",": return x return "" @@ -100,7 +100,7 @@ def data() -> String[33]: """ @external def data() -> DynArray[String[33], 2]: - for x in [["hello", "world"], ["goodbye", "world!"]]: + for x: DynArray[String[33], 2] in [["hello", "world"], ["goodbye", "world!"]]: if x[1] == "world": return x return [] @@ -114,8 +114,8 @@ def data() -> DynArray[String[33], 2]: def data() -> int128: ret: int128 = 0 xss: int128[3][3] = [[1,2,3],[4,5,6],[7,8,9]] - for xs in xss: - for x in xs: + for xs: int128[3] in xss: + for x: int128 in xs: ret += x return ret""", sum(range(1, 10)), @@ -130,8 +130,8 @@ def data() -> int128: @external def data() -> int128: ret: int128 = 0 - for ss in [[S({x:1, y:2})]]: - for s in ss: + for ss: S[1] in [[S({x:1, y:2})]]: + for s: S in ss: ret += s.x + s.y return ret""", 1 + 2, @@ -147,7 +147,7 @@ def data() -> address: 0xDCEceAF3fc5C0a63d195d69b1A90011B7B19650D ] count: int128 = 0 - for i in addresses: + for i: address in addresses: count += 1 if count == 2: return i @@ -174,7 +174,7 @@ def set(): @external def data() -> int128: - for i in self.x: + for i: int128 in self.x: if i > 5: return i return -1 @@ -198,7 +198,7 @@ def set(xs: DynArray[int128, 4]): @external def data() -> int128: t: int128 = 0 - for i in self.x: + for i: int128 in self.x: t += i return t """ @@ -227,7 +227,7 @@ def ret(i: int128) -> address: @external def iterate_return_second() -> address: count: int128 = 0 - for i in self.addresses: + for i: address in self.addresses: count += 1 if count == 2: return i @@ -258,7 +258,7 @@ def ret(i: int128) -> decimal: @external def i_return(break_count: int128) -> decimal: count: int128 = 0 - for i in self.readings: + for i: decimal in self.readings: if count == break_count: return i count += 1 @@ -284,7 +284,7 @@ def func(amounts: uint256[3]) -> uint256: total: uint256 = as_wei_value(0, "wei") # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -303,7 +303,7 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: total: uint256 = 0 # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -321,42 +321,42 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in range(4): + for i: int128 in range(4): p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass - for i in range(20): + for i: uint256 in range(20): pass """, # using index variable after loop """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass i: int128 = 100 # create new variable i i = 200 # look up the variable i and check whether it is in forvars @@ -372,25 +372,25 @@ def test_good_code(code, get_contract): RANGE_CONSTANT_CODE = [ ( """ -TREE_FIDDY: constant(int128) = 350 +TREE_FIDDY: constant(uint256) = 350 @external def a() -> uint256: x: uint256 = 0 - for i in range(TREE_FIDDY): + for i: uint256 in range(TREE_FIDDY): x += 1 return x""", 350, ), ( """ -ONE_HUNDRED: constant(int128) = 100 +ONE_HUNDRED: constant(uint256) = 100 @external def a() -> uint256: x: uint256 = 0 - for i in range(1, 1 + ONE_HUNDRED): + for i: uint256 in range(1, 1 + ONE_HUNDRED): x += 1 return x""", 100, @@ -401,9 +401,9 @@ def a() -> uint256: END: constant(int128) = 199 @external -def a() -> uint256: - x: uint256 = 0 - for i in range(START, END): +def a() -> int128: + x: int128 = 0 + for i: int128 in range(START, END): x += 1 return x""", 99, @@ -413,11 +413,23 @@ def a() -> uint256: @external def a() -> int128: x: int128 = 0 - for i in range(-5, -1): + for i: int128 in range(-5, -1): x += i return x""", -14, ), + ( + """ +@external +def a() -> uint256: + a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] + x: uint256 = 0 + for i: uint256 in a[2]: + x += i + return x + """, + 9, + ), ] @@ -436,7 +448,7 @@ def test_range_constant(get_contract, code, result): def data() -> int128: s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -451,7 +463,7 @@ def data() -> int128: def foo(): s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] += 1 """, ImmutableViolation, @@ -468,7 +480,7 @@ def set(): @external def data() -> int128: count: int128 = 0 - for i in self.s: + for i: int128 in self.s: self.s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -493,7 +505,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff(i) i += 1 """, @@ -519,7 +531,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.bar.foo: + for item: uint256 in self.my_array2.bar.foo: self.doStuff(i) i += 1 """, @@ -545,7 +557,7 @@ def doStuff(): @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff() i += 1 """, @@ -556,8 +568,8 @@ def _helper(): """ @external def foo(x: int128): - for i in range(4): - for i in range(5): + for i: int128 in range(4): + for i: int128 in range(5): pass """, NamespaceCollision, @@ -566,8 +578,8 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: - for i in [1,2]: + for i: int128 in [1,2]: + for i: int128 in [1,2]: pass """, NamespaceCollision, @@ -577,7 +589,7 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i = 2 """, ImmutableViolation, @@ -588,7 +600,7 @@ def foo(x: int128): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.pop() """, ImmutableViolation, @@ -599,7 +611,7 @@ def foo(): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.append(x) """, ImmutableViolation, @@ -610,7 +622,7 @@ def foo(): @external def foo(): xs: DynArray[DynArray[uint256, 5], 5] = [[1,2,3]] - for x in xs: + for x: DynArray[uint256, 5] in xs: x.pop() """, ImmutableViolation, @@ -629,7 +641,7 @@ def b(): @external def foo(): - for x in self.array: + for x: uint256 in self.array: self.a() """, ImmutableViolation, @@ -638,7 +650,7 @@ def foo(): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i += 2 """, ImmutableViolation, @@ -648,7 +660,7 @@ def foo(x: int128): """ @external def foo(): - for i in range(-3): + for i: int128 in range(-3): pass """, StructureException, @@ -656,13 +668,13 @@ def foo(): """ @external def foo(): - for i in range(0): + for i: uint256 in range(0): pass """, """ @external def foo(): - for i in []: + for i: uint256 in []: pass """, """ @@ -670,14 +682,14 @@ def foo(): @external def foo(): - for i in FOO: + for i: uint256 in FOO: pass """, ( """ @external def foo(): - for i in range(5,3): + for i: uint256 in range(5,3): pass """, StructureException, @@ -686,7 +698,7 @@ def foo(): """ @external def foo(): - for i in range(5,3,-1): + for i: int128 in range(5,3,-1): pass """, ArgumentException, @@ -696,7 +708,7 @@ def foo(): @external def foo(): a: uint256 = 2 - for i in range(a): + for i: uint256 in range(a): pass """, StateAccessViolation, @@ -706,7 +718,7 @@ def foo(): @external def foo(): a: int128 = 6 - for i in range(a,a-3): + for i: int128 in range(a,a-3): pass """, StateAccessViolation, @@ -716,7 +728,7 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, ArgumentException, @@ -725,7 +737,7 @@ def foo(): """ @external def foo(): - for i in range(0,1,2): + for i: uint256 in range(0,1,2): pass """, ArgumentException, @@ -735,7 +747,7 @@ def foo(): """ @external def foo(): - for i in b"asdf": + for i: Bytes[1] in b"asdf": pass """, InvalidType, @@ -744,7 +756,7 @@ def foo(): """ @external def foo(): - for i in 31337: + for i: uint256 in 31337: pass """, InvalidType, @@ -753,7 +765,7 @@ def foo(): """ @external def foo(): - for i in bar(): + for i: uint256 in bar(): pass """, IteratorException, @@ -762,7 +774,7 @@ def foo(): """ @external def foo(): - for i in self.bar(): + for i: uint256 in self.bar(): pass """, IteratorException, @@ -772,11 +784,11 @@ def foo(): @external def test_for() -> int128: a: int128 = 0 - for i in range(max_value(int128), max_value(int128)+2): + for i: int128 in range(max_value(int128), max_value(int128)+2): a = i return a """, - TypeMismatch, + InvalidType, ), ( """ @@ -784,7 +796,7 @@ def test_for() -> int128: def test_for() -> int128: a: int128 = 0 b: uint256 = 0 - for i in range(5): + for i: int128 in range(5): a = i b = i return a diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index e946447285..c661c46553 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -6,7 +6,7 @@ def test_basic_repeater(get_contract_with_gas_estimation): @external def repeat(z: int128) -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): x = x + z return(x) """ @@ -19,7 +19,7 @@ def test_range_bound(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, bound=6): + for i: uint256 in range(n, bound=6): x += i + 1 return x """ @@ -37,7 +37,7 @@ def test_range_bound_constant_end(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, 7, bound=6): + for i: uint256 in range(n, 7, bound=6): x += i + 1 return x """ @@ -58,7 +58,7 @@ def test_range_bound_two_args(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(1, n, bound=6): + for i: uint256 in range(1, n, bound=6): x += i + 1 return x """ @@ -80,7 +80,7 @@ def test_range_bound_two_runtime_args(get_contract, tx_failed): @external def repeat(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x += i return x """ @@ -109,7 +109,7 @@ def test_range_overflow(get_contract, tx_failed): @external def get_last(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x = i return x """ @@ -134,11 +134,11 @@ def test_digit_reverser(get_contract_with_gas_estimation): def reverse_digits(x: int128) -> int128: dig: int128[6] = [0, 0, 0, 0, 0, 0] z: int128 = x - for i in range(6): + for i: uint256 in range(6): dig[i] = z % 10 z = z / 10 o: int128 = 0 - for i in range(6): + for i: uint256 in range(6): o = o * 10 + dig[i] return o @@ -153,9 +153,9 @@ def test_more_complex_repeater(get_contract_with_gas_estimation): @external def repeat() -> int128: out: int128 = 0 - for i in range(6): + for i: uint256 in range(6): out = out * 10 - for j in range(4): + for j: int128 in range(4): out = out + j return(out) """ @@ -170,7 +170,7 @@ def test_offset_repeater(get_contract_with_gas_estimation, typ): @external def sum() -> {typ}: out: {typ} = 0 - for i in range(80, 121): + for i: {typ} in range(80, 121): out = out + i return out """ @@ -185,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101, bound=101): + for i: {typ} in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -205,7 +205,7 @@ def _bar() -> bool: @external def foo() -> bool: - for i in range(3): + for i: uint256 in range(3): self._bar() return True """ @@ -219,8 +219,8 @@ def test_return_inside_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for j in range(10): + for i: {typ} in range(10): + for j: {typ} in range(10): if j > 5: if i > a: return i @@ -254,14 +254,14 @@ def test_for_range_edge(get_contract, typ): def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x - 1, x, bound=1): + for i: {typ} in range(x - 1, x, bound=1): if i + 1 == max_value({typ}): found = True assert found found = False x = max_value({typ}) - 1 - for i in range(x - 1, x + 1, bound=2): + for i: {typ} in range(x - 1, x + 1, bound=2): if i + 1 == max_value({typ}): found = True assert found @@ -276,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x + 2, bound=2): + for i: {typ} in range(x, x + 2, bound=2): pass """ c = get_contract(code) @@ -289,8 +289,8 @@ def test_return_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -318,8 +318,8 @@ def test_return_void_nested_repeater(get_contract, typ, val): result: {typ} @internal def _final(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -347,8 +347,8 @@ def test_external_nested_repeater(get_contract, typ, val): code = f""" @external def foo(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -368,8 +368,8 @@ def test_external_void_nested_repeater(get_contract, typ, val): result: public({typ}) @external def foo(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -388,8 +388,8 @@ def test_breaks_and_returns_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if a < 2: break return 6 diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index af189e6dca..df379d3f16 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -159,7 +159,7 @@ def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5 return True """ @@ -179,7 +179,7 @@ def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5, "because reasons" return True """ diff --git a/tests/functional/codegen/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py index f10d22ec99..422f53fdeb 100644 --- a/tests/functional/codegen/features/test_internal_call.py +++ b/tests/functional/codegen/features/test_internal_call.py @@ -152,7 +152,7 @@ def _increment(): @external def returnten() -> int128: - for i in range(10): + for i: uint256 in range(10): self._increment() return self.counter """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 671d424d60..891ed5aebe 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index fcf71f12f0..72171dd4b5 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -125,7 +125,7 @@ def test_harder_decimal_test(get_contract_with_gas_estimation): @external def phooey(inp: decimal) -> decimal: x: decimal = 10000.0 - for i in range(4): + for i: uint256 in range(4): x = x * inp return x diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 1ee9b8d835..882629de65 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -268,7 +268,7 @@ def test_zero_padding_with_private(get_contract): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_bytes_zero_padding.py b/tests/functional/codegen/types/test_bytes_zero_padding.py index f9fcf37b25..6597facd1b 100644 --- a/tests/functional/codegen/types/test_bytes_zero_padding.py +++ b/tests/functional/codegen/types/test_bytes_zero_padding.py @@ -10,7 +10,7 @@ def little_endian_contract(get_contract_module): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 70a68e3206..e47eda6042 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -969,7 +969,7 @@ def foo() -> (uint256, uint256, uint256, uint256, uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -981,7 +981,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: some_var: uint256 @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.some_var = x # test that typechecker for append args works self.my_array.append(self.some_var) @@ -994,9 +994,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) - for x in xs: + for x: uint256 in xs: self.my_array.pop() return self.my_array """, @@ -1008,7 +1008,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array, self.my_array.pop() """, @@ -1020,7 +1020,7 @@ def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array.pop(), self.my_array """, @@ -1033,7 +1033,7 @@ def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] i: uint256 = 0 - for x in xs: + for x: uint256 in xs: if i >= len(xs) - 1: break ys.append(x) @@ -1049,7 +1049,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -1061,9 +1061,9 @@ def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() return ys """, @@ -1075,9 +1075,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() ys.pop() # fail return ys @@ -1328,7 +1328,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: DynArray[Foo, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): e: Foobar = _baz[i].z f: uint256 = convert(e, uint256) sum += _baz[i].x * _baz[i].y + f @@ -1397,7 +1397,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: DynArray[Bar, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index b5b9538c20..ee287064e8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -566,7 +566,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: Foo[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _baz[i].x * _baz[i].y return sum """ @@ -608,7 +608,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: Bar[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 7dd8c35929..652102c376 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -106,6 +106,6 @@ def has_no_docstrings(c): @hypothesis.settings(max_examples=500) def test_grammar_bruteforce(code): if utf8_encodable(code): - _, _, reformatted_code = pre_parse(code + "\n") + _, _, _, reformatted_code = pre_parse(code + "\n") tree = parse_to_ast(reformatted_code) assert isinstance(tree, Module) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index 0b7ec21bdb..4240aec8d2 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -80,13 +80,13 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, """ @external def foo(): - for i in range(1, 2, 3, 4): + for i: uint256 in range(1, 2, 3, 4): pass """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 4bd0b4fcb9..7adf9538c7 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -57,7 +57,7 @@ def foo() -> int128: return 5 @external def bar(): - for i in range(self.foo(), self.foo() + 1): + for i: int128 in range(self.foo(), self.foo() + 1): pass""", """ glob: int128 @@ -67,13 +67,13 @@ def foo() -> int128: return 5 @external def bar(): - for i in [1,2,3,4,self.foo()]: + for i: int128 in [1,2,3,4,self.foo()]: pass""", """ @external def foo(): x: int128 = 5 - for i in range(x): + for i: int128 in range(x): pass""", """ f:int128 diff --git a/tests/functional/syntax/test_blockscope.py b/tests/functional/syntax/test_blockscope.py index 942aa3fa68..466b5509ca 100644 --- a/tests/functional/syntax/test_blockscope.py +++ b/tests/functional/syntax/test_blockscope.py @@ -33,7 +33,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a = 1 """, @@ -41,7 +41,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a += 1 """, diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index ffd2f1faa0..7089dee3bb 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -240,7 +240,7 @@ def test1(): @external @view def test(): - for i in range(CONST / 4): + for i: uint256 in range(CONST / 4): pass """, """ diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a9c3ad5cab..66981a90de 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -15,7 +15,7 @@ """ @external def foo(): - for a[1] in range(10): + for a[1]: uint256 in range(10): pass """, StructureException, @@ -26,7 +26,7 @@ def foo(): """ @external def bar(): - for i in range(1,2,bound=0): + for i: uint256 in range(1,2,bound=0): pass """, StructureException, @@ -38,7 +38,7 @@ def bar(): @external def foo(): x: uint256 = 100 - for _ in range(10, bound=x): + for _: uint256 in range(10, bound=x): pass """, StateAccessViolation, @@ -49,7 +49,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=5): + for _: uint256 in range(10, 20, bound=5): pass """, StructureException, @@ -60,7 +60,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=0): + for _: uint256 in range(10, 20, bound=0): pass """, StructureException, @@ -72,7 +72,7 @@ def foo(): @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2,extra=3): + for i: uint256 in range(x,x+1,bound=2,extra=3): pass """, ArgumentException, @@ -83,7 +83,7 @@ def bar(): """ @external def bar(): - for i in range(0): + for i: uint256 in range(0): pass """, StructureException, @@ -95,7 +95,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x): + for i: uint256 in range(x): pass """, StateAccessViolation, @@ -107,7 +107,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(0, x): + for i: uint256 in range(0, x): pass """, StateAccessViolation, @@ -118,7 +118,7 @@ def bar(): """ @external def repeat(n: uint256) -> uint256: - for i in range(0, n * 10): + for i: uint256 in range(0, n * 10): pass return n """, @@ -131,7 +131,7 @@ def repeat(n: uint256) -> uint256: @external def bar(): x:uint256 = 1 - for i in range(0, x + 1): + for i: uint256 in range(0, x + 1): pass """, StateAccessViolation, @@ -142,7 +142,7 @@ def bar(): """ @external def bar(): - for i in range(2, 1): + for i: uint256 in range(2, 1): pass """, StructureException, @@ -154,7 +154,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x, x): + for i: uint256 in range(x, x): pass """, StateAccessViolation, @@ -166,7 +166,7 @@ def bar(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i: int128 in range(x, x + 10): pass """, StateAccessViolation, @@ -177,7 +177,7 @@ def foo(): """ @external def repeat(n: uint256) -> uint256: - for i in range(n, 6): + for i: uint256 in range(n, 6): pass return x """, @@ -190,7 +190,7 @@ def repeat(n: uint256) -> uint256: @external def foo(x: int128): y: int128 = 7 - for i in range(x, x + y): + for i: int128 in range(x, x + y): pass """, StateAccessViolation, @@ -201,7 +201,7 @@ def foo(x: int128): """ @external def bar(x: uint256): - for i in range(3, x): + for i: uint256 in range(3, x): pass """, StateAccessViolation, @@ -215,12 +215,12 @@ def bar(x: uint256): @external def foo(): - for i in range(FOO, BAR): + for i: uint256 in range(FOO, BAR): pass """, TypeMismatch, - "Iterator values are of different types", - "range(FOO, BAR)", + "Given reference has type int128, expected uint256", + "FOO", ), ( """ @@ -228,12 +228,12 @@ def foo(): @external def foo(): - for i in range(10, bound=FOO): + for i: int128 in range(10, bound=FOO): pass """, StructureException, "Bound must be at least 1", - "-1", + "FOO", ), ] @@ -252,41 +252,41 @@ def test_range_fail(bad_code, error_type, message, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) assert message == exc_info.value.message - assert source_code == exc_info.value.args[1].node_source_code + assert source_code == exc_info.value.args[1].get_original_node().node_source_code valid_list = [ """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass """, """ @external def foo(): - for i in range(10, 20): + for i: uint256 in range(10, 20): pass """, """ @external def foo(): x: int128 = 5 - for i in range(1, x, bound=4): + for i: int128 in range(1, x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(x, bound=4): + for i: int128 in range(x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(0, x, bound=4): + for i: int128 in range(0, x, bound=4): pass """, """ @@ -295,7 +295,7 @@ def kick(): nonpayable foos: Foo[3] @external def kick_foos(): - for foo in self.foos: + for foo: Foo in self.foos: foo.kick() """, ] diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index db41de5526..3936f8c220 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -306,7 +306,7 @@ def foo(): @external def foo(): x: DynArray[uint256, 3] = [1, 2, 3] - for i in [[], []]: + for i: DynArray[uint256, 3] in [[], []]: x = i """, ] diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index d413340083..a6bc3147e6 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -24,7 +24,7 @@ def foo(): """ @external def foo(): - for i in [0x6b175474e89094c44da98b954eedeac495271d0F]: + for i: address in [0x6b175474e89094c44da98b954eedeac495271d0F]: pass """, """ diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 16ce6fe631..b202f6d8a3 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,10 +28,10 @@ def foo() -> int128: def get_contract_info(source_code): - _, class_types, reformatted_code = pre_parse(source_code) + _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) - annotate_python_ast(py_ast, reformatted_code, class_types) + annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) return py_ast, reformatted_code diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 682c13ca84..020e83627c 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -173,7 +173,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _ = pre_parse(code) + settings, _, _, _ = pre_parse(code) assert settings == pre_parse_settings diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 44b823757c..b2851e908a 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -58,7 +58,7 @@ def ctor_only(): @internal def runtime_only(): - for i in range(10): + for i: uint256 in range(10): self.s += 1 @external diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index c9a152b09c..5b478dd2aa 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -6,7 +6,7 @@ @internal def _baz(a: int128) -> int128: b: int128 = a - for i in range(2, 5): + for i: int128 in range(2, 5): b *= i if b > 31337: break diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index e2c0f555af..607587cc28 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -22,7 +22,7 @@ def foo(): @internal def bar(): self.foo() - for i in self.a: + for i: uint256 in self.a: pass """ vyper_module = parse_to_ast(code) @@ -42,7 +42,7 @@ def foo(a: uint256[3]) -> uint256[3]: @internal def bar(): a: uint256[3] = [1,2,3] - for i in a: + for i: uint256 in a: self.foo(a) """ vyper_module = parse_to_ast(code) @@ -56,7 +56,7 @@ def test_modify_iterator(dummy_input_bundle): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.a[0] = 1 """ vyper_module = parse_to_ast(code) @@ -70,7 +70,7 @@ def test_bad_keywords(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, boundddd=10): + for i: uint256 in range(n, boundddd=10): x += i """ vyper_module = parse_to_ast(code) @@ -84,7 +84,7 @@ def test_bad_bound(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, bound=n): + for i: uint256 in range(n, bound=n): x += i """ vyper_module = parse_to_ast(code) @@ -103,7 +103,7 @@ def foo(): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.foo() """ vyper_module = parse_to_ast(code) @@ -126,7 +126,7 @@ def bar(): @internal def baz(): - for i in self.a: + for i: uint256 in self.a: self.bar() """ vyper_module = parse_to_ast(code) @@ -138,32 +138,32 @@ def baz(): """ @external def main(): - for j in range(3): + for j: uint256 in range(3): x: uint256 = j y: uint16 = j """, # GH issue 3212 """ @external def foo(): - for i in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1]: - for j in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + for j: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1,2,3]: - for j in [1,2,3]: - b:uint256 = j + i - c:uint16 = i + for i: uint256 in [1,2,3]: + for j: uint256 in [1,2,3]: + b: uint256 = j + i + c: uint16 = i """, # GH issue 3374 ] diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 7889473b19..234e96e552 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -178,8 +178,7 @@ body: _NEWLINE _INDENT ([COMMENT] _NEWLINE | _stmt)+ _DEDENT cond_exec: _expr ":" body default_exec: body if_stmt: "if" cond_exec ("elif" cond_exec)* ["else" ":" default_exec] -// TODO: make this into a variable definition e.g. `for i: uint256 in range(0, 5): ...` -loop_variable: NAME [":" NAME] +loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index efab5117d4..7a8c7443b7 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,15 @@ ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code -NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata") +NODE_BASE_ATTRIBUTES = ( + "_children", + "_depth", + "_parent", + "ast_type", + "node_id", + "_metadata", + "_original_node", +) NODE_SRC_ATTRIBUTES = ( "col_offset", "end_col_offset", @@ -257,6 +265,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): self.set_parent(parent) self._children: set = set() self._metadata: NodeMetadata = NodeMetadata() + self._original_node = None for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset @@ -411,12 +420,16 @@ def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once assert "folded_value" not in self._metadata - # set the folded node's parent so that get_ancestor works - # this is mainly important for error messages. - node._parent = self._parent + # set the "original node" so that exceptions can point to the original + # node and not the folded node + node = copy.copy(node) + node._original_node = self self._metadata["folded_value"] = node + def get_original_node(self) -> "VyperNode": + return self._original_node or self + def _try_fold(self) -> "VyperNode": """ Attempt to constant-fold the content of a node, returning the result of @@ -1546,7 +1559,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "target", "body") + __slots__ = ("target", "iter", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 38a9d31695..b657cf2245 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -54,7 +54,7 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) except SyntaxError as e: @@ -73,11 +73,15 @@ def parse_to_ast_with_settings( py_ast, source_code, class_types, + for_loop_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) + # postcondition: consumed all the for loop annotations + assert len(for_loop_annotations) == 0 + # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint @@ -113,11 +117,13 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets + _loop_var_annotations: dict[int, dict[str, Any]] def __init__( self, source_code: str, - modification_offsets: Optional[ModificationOffsets], + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -127,11 +133,11 @@ def __init__( self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path - self._source_code: str = source_code + self._source_code = source_code + self._modification_offsets = modification_offsets + self._for_loop_annotations = for_loop_annotations + self.counter: int = 0 - self._modification_offsets = {} - if modification_offsets is not None: - self._modification_offsets = modification_offsets def generic_visit(self, node): """ @@ -213,6 +219,47 @@ def visit_ClassDef(self, node): node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] return node + def visit_For(self, node): + """ + Visit a For node, splicing in the loop variable annotation provided by + the pre-parser + """ + raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + + if not raw_annotation: + # a common case for people migrating to 0.4.0, provide a more + # specific error message than "invalid type annotation" + raise SyntaxException( + "missing type annotation\n\n" + "(hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)\n", + self._source_code, + node.lineno, + node.col_offset, + ) + + try: + annotation = python_ast.parse(raw_annotation, mode="eval") + # annotate with token and source code information. `first_token` + # and `last_token` attributes are accessed in `generic_visit`. + tokens = asttokens.ASTTokens(raw_annotation) + tokens.mark_tokens(annotation) + except SyntaxError as e: + raise SyntaxException( + "invalid type annotation", self._source_code, node.lineno, node.col_offset + ) from e + + assert isinstance(annotation, python_ast.Expression) + annotation = annotation.body + + old_target = node.target + new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) + node.target = new_target + + self.generic_visit(node) + + return node + def visit_Expr(self, node): """ Convert the `Yield` node into a Vyper-specific node type. @@ -355,7 +402,8 @@ def visit_UnaryOp(self, node): def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, - modification_offsets: Optional[ModificationOffsets] = None, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -369,6 +417,9 @@ def annotate_python_ast( The AST to be annotated and optimized. source_code : str The originating source code of the AST. + loop_var_annotations: dict, optional + A mapping of line numbers of `For` nodes to the type annotation of the iterator + extracted during pre-parsing. modification_offsets : dict, optional A mapping of class names to their original class types. @@ -381,6 +432,7 @@ def annotate_python_ast( visitor = AnnotatingVisitor( source_code, modification_offsets, + for_loop_annotations, tokens, source_id, module_path=module_path, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index b949a242bb..c7e6f3698f 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,3 +1,4 @@ +import enum import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize @@ -43,6 +44,64 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: ) +class ForParserState(enum.Enum): + NOT_RUNNING = enum.auto() + START_SOON = enum.auto() + RUNNING = enum.auto() + + +# a simple state machine which allows us to handle loop variable annotations +# (which are rejected by the python parser due to pep-526, so we scoop up the +# tokens between `:` and `in` and parse them and add them back in later). +class ForParser: + def __init__(self, code): + self._code = code + self.annotations = {} + self._current_annotation = None + + self._state = ForParserState.NOT_RUNNING + self._current_for_loop = None + + def consume(self, token): + # state machine: we can start slurping tokens soon + if token.type == NAME and token.string == "for": + # note: self._state should be NOT_RUNNING here, but we don't sanity + # check here as that should be an error the parser will handle. + self._state = ForParserState.START_SOON + self._current_for_loop = token.start + + if self._state == ForParserState.NOT_RUNNING: + return False + + # state machine: start slurping tokens + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + + self._current_annotation = [] + return True # do not add ":" to tokens. + + # state machine: end slurping tokens + if token.type == NAME and token.string == "in": + self._state = ForParserState.NOT_RUNNING + self.annotations[self._current_for_loop] = self._current_annotation or [] + self._current_annotation = None + return False + + if self._state != ForParserState.RUNNING: + return False + + # slurp the token + self._current_annotation.append(token) + return True + + # compound statements that are replaced with `class` # TODO remove enum in favor of flag VYPER_CLASS_TYPES = {"flag", "enum", "event", "interface", "struct"} @@ -51,7 +110,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -60,9 +119,11 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"). + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. Parameters ---------- @@ -71,21 +132,25 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Returns ------- - dict - Mapping of offsets where source was modified. + Settings + Compilation settings based on the directives in the source code + ModificationOffsets + A mapping of class names to their original class types. + dict[tuple[int, int], str] + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() + for_parser = ForParser(code) try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) - for i in range(len(token_list)): - token = token_list[i] + for token in token_list: toks = [token] typ = token.type @@ -146,8 +211,18 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - result.extend(toks) + + if not for_parser.consume(token): + result.extend(toks) + except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - return settings, modification_offsets, untokenize(result).decode("utf-8") + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): + v_source = untokenize(v) + # untokenize adds backslashes and whitespace, strip them. + v_source = v_source.replace("\\", "").strip() + for_loop_annotations[k] = v_source + + return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index c896fc7ef6..39d97c4abe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2157,7 +2157,7 @@ def build_IR(self, expr, args, kwargs, context): z = x / 2.0 + 0.5 y: decimal = x - for i in range(256): + for i: uint256 in range(256): if z == y: break y = z diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..a47faefeb1 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -33,7 +33,7 @@ ) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.shortcuts import INT256_T, UINT256_T +from vyper.semantics.types.shortcuts import UINT256_T class Stmt: @@ -231,19 +231,17 @@ def parse_For(self): return self._parse_For_list() def _parse_For_range(self): - # TODO make sure type always gets annotated - if "type" in self.stmt.target._metadata: - iter_typ = self.stmt.target._metadata["type"] - else: - iter_typ = INT256_T + assert "type" in self.stmt.target.target._metadata + target_type = self.stmt.target.target._metadata["type"] # Get arg0 - for_iter: vy_ast.Call = self.stmt.iter - args_len = len(for_iter.args) + range_call: vy_ast.Call = self.stmt.iter + assert isinstance(range_call, vy_ast.Call) + args_len = len(range_call.args) if args_len == 1: - arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + arg0, arg1 = (IRnode.from_list(0, typ=target_type), range_call.args[0]) elif args_len == 2: - arg0, arg1 = for_iter.args + arg0, arg1 = range_call.args else: # pragma: nocover raise TypeCheckFailure("unreachable: bad # of arguments to range()") @@ -251,7 +249,7 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) end = Expr.parse_value_expr(arg1, self.context) kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + s.arg: Expr.parse_value_expr(s.value, self.context) for s in range_call.keywords } if "bound" in kwargs: @@ -270,9 +268,9 @@ def _parse_For_range(self): if rounds_bound < 1: # pragma: nocover raise TypeCheckFailure("unreachable: unchecked 0 bound") - varname = self.stmt.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) - iptr = self.context.new_variable(varname, iter_typ) + varname = self.stmt.target.target.id + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) + iptr = self.context.new_variable(varname, target_type) self.context.forvars[varname] = True @@ -297,11 +295,11 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - target_type = self.stmt.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] assert target_type == iter_list.typ.value_type # user-supplied name for loop variable - varname = self.stmt.target.id + varname = self.stmt.target.target.id loop_var = IRnode.from_list( self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY ) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..51f3fea14c 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -92,6 +92,10 @@ def __str__(self): node = value[1] if isinstance(value, tuple) else value node_msg = "" + if isinstance(node, vy_ast.VyperNode): + # folded AST nodes contain pointers to the original source + node = node.get_original_node() + try: source_annotation = annotate_source_code( # add trailing space because EOF exceptions point one char beyond the length diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91fb2c21f0..169c71269d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,13 +1,11 @@ from typing import Optional from vyper import ast as vy_ast -from vyper.ast.metadata import NodeMetadata from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -40,7 +38,6 @@ EventT, FlagT, HashMapT, - IntegerT, SArrayT, StringT, StructT, @@ -347,8 +344,10 @@ def visit_Expr(self, node): self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): - if isinstance(node.iter, vy_ast.Subscript): - raise StructureException("Cannot iterate over a nested list", node.iter) + if not isinstance(node.target.target, vy_ast.Name): + raise StructureException("Invalid syntax for loop iterator", node.target.target) + + target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -356,7 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - type_list = _analyse_range_call(node.iter) + _validate_range_call(node.iter) else: # iteration over a variable or literal list @@ -364,14 +363,10 @@ def visit_For(self, node): if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) - type_list = [ - i.value_type - for i in get_possible_types_from_node(node.iter) - if isinstance(i, (DArrayT, SArrayT)) - ] - - if not type_list: - raise InvalidType("Not an iterable type", node.iter) + if not any( + isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) + ): + raise InvalidType("Not an iterable type", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): # check for references to the iterated value within the body of the loop @@ -415,65 +410,28 @@ def visit_For(self, node): call_node, ) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for loop iterator", node.target) + target_name = node.target.target.id + with self.namespace.enter_scope(): + self.namespace[target_name] = VarInfo( + target_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) - for_loop_exceptions = [] - iter_name = node.target.id - for possible_target_type in type_list: - # type check the for loop body using each possible type for iterator value + for stmt in node.body: + self.visit(stmt) - with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo( - possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT - ) + self.expr_visitor.visit(node.target.target, target_type) - try: - with NodeMetadata.enter_typechecker_speculation(): - for stmt in node.body: - self.visit(stmt) - - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - - except (TypeMismatch, InvalidOperation) as exc: - for_loop_exceptions.append(exc) - else: - # success -- do not enter error handling section - return - - # failed to find a good type. bail out - if len(set(str(i) for i in for_loop_exceptions)) == 1: - # if every attempt at type checking raised the same exception - raise for_loop_exceptions[0] - - # return an aggregate TypeMismatch that shows all possible exceptions - # depending on which type is used - types_str = [str(i) for i in type_list] - given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - raise TypeMismatch( - f"Iterator value '{iter_name}' may be cast as {given_str}, " - "but type checking fails with all possible types:", - node, - *( - (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - for typ, exc in zip(type_list, for_loop_exceptions) - ), - ) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) + elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + args = node.iter.args + kwargs = [s.value for s in node.iter.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + else: + iter_type = get_exact_type_from_node(node.iter) + self.expr_visitor.visit(node.iter, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -750,25 +708,18 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: +def _validate_range_call(node: vy_ast.Call): """ Check that the arguments to a range() call are valid. :param node: call to range() :return: None """ + assert node.func.get("id") == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] - all_args = (start, end, *kwargs.values()) - for arg1 in all_args: - validate_expected_type(arg1, IntegerT.any()) - - type_list = get_common_types(*all_args) - if not type_list: - raise TypeMismatch("Iterator values are of different types", node) - if "bound" in kwargs: bound = kwargs["bound"] if bound.has_folded_value: @@ -787,5 +738,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: raise StateAccessViolation(error, arg) if end.value <= start.value: raise StructureException("End must be greater than start", end) - - return type_list From a1fd228cb9936c3e4bbca6f3ee3fb4426ef45490 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Mon, 8 Jan 2024 18:12:59 +0200 Subject: [PATCH 2/3] feat: add `bb` and `bb_runtime` output options (#3700) add `bb` and `bb_runtime` output options for dumping venom output. disable this output format in tests for now since many vyper contracts still will not compile to venom. --- tests/conftest.py | 43 +++++++++++++------ .../unit/cli/vyper_json/test_compile_json.py | 11 +++-- vyper/compiler/__init__.py | 4 ++ vyper/compiler/output.py | 8 ++++ vyper/compiler/phases.py | 8 +--- 5 files changed, 51 insertions(+), 23 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 51b4b4459a..e673f17b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,14 @@ def pytest_addoption(parser): parser.addoption("--enable-compiler-debug-mode", action="store_true") +@pytest.fixture(scope="module") +def output_formats(): + output_formats = compiler.OUTPUT_FORMATS.copy() + del output_formats["bb"] + del output_formats["bb_runtime"] + return output_formats + + @pytest.fixture(scope="module") def optimize(pytestconfig): flag = pytestconfig.getoption("optimize") @@ -281,7 +289,14 @@ def ir_compiler(ir, *args, **kwargs): def _get_contract( - w3, source_code, optimize, *args, override_opt_level=None, input_bundle=None, **kwargs + w3, + source_code, + optimize, + output_formats, + *args, + override_opt_level=None, + input_bundle=None, + **kwargs, ): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) @@ -289,7 +304,7 @@ def _get_contract( out = compiler.compile_code( source_code, # test that all output formats can get generated - output_formats=list(compiler.OUTPUT_FORMATS.keys()), + output_formats=output_formats, settings=settings, input_bundle=input_bundle, show_gas_estimates=True, # Enable gas estimates for testing @@ -309,17 +324,17 @@ def _get_contract( @pytest.fixture(scope="module") -def get_contract(w3, optimize): +def get_contract(w3, optimize, output_formats): def fn(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return fn @pytest.fixture -def get_contract_with_gas_estimation(tester, w3, optimize): +def get_contract_with_gas_estimation(tester, w3, optimize, output_formats): def get_contract_with_gas_estimation(source_code, *args, **kwargs): - contract = _get_contract(w3, source_code, optimize, *args, **kwargs) + contract = _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) for abi_ in contract._classic_contract.functions.abi: if abi_["type"] == "function": set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"]) @@ -329,15 +344,15 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs): @pytest.fixture -def get_contract_with_gas_estimation_for_constants(w3, optimize): +def get_contract_with_gas_estimation_for_constants(w3, optimize, output_formats): def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return get_contract_with_gas_estimation_for_constants @pytest.fixture(scope="module") -def get_contract_module(optimize): +def get_contract_module(optimize, output_formats): """ This fixture is used for Hypothesis tests to ensure that the same contract is called over multiple runs of the test. @@ -350,18 +365,18 @@ def get_contract_module(optimize): w3.eth.set_gas_price_strategy(zero_gas_price_strategy) def get_contract_module(source_code, *args, **kwargs): - return _get_contract(w3, source_code, optimize, *args, **kwargs) + return _get_contract(w3, source_code, optimize, output_formats, *args, **kwargs) return get_contract_module -def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs): +def _deploy_blueprint_for(w3, source_code, optimize, output_formats, initcode_prefix=b"", **kwargs): settings = Settings() settings.evm_version = kwargs.pop("evm_version", None) settings.optimize = optimize out = compiler.compile_code( source_code, - output_formats=list(compiler.OUTPUT_FORMATS.keys()), + output_formats=output_formats, settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) @@ -394,9 +409,9 @@ def factory(address): @pytest.fixture(scope="module") -def deploy_blueprint_for(w3, optimize): +def deploy_blueprint_for(w3, optimize, output_formats): def deploy_blueprint_for(source_code, *args, **kwargs): - return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs) + return _deploy_blueprint_for(w3, source_code, optimize, output_formats, *args, **kwargs) return deploy_blueprint_for diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index a50946ba21..c805e2b5b1 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -113,18 +113,23 @@ def test_keyerror_becomes_jsonerror(input_json): def test_compile_json(input_json, input_bundle): foo_input = input_bundle.load_file("contracts/foo.vy") + # remove bb and bb_runtime from output formats + # because they require venom (experimental) + output_formats = OUTPUT_FORMATS.copy() + del output_formats["bb"] + del output_formats["bb_runtime"] foo = compile_from_file_input( - foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + foo_input, output_formats=output_formats, input_bundle=input_bundle ) library_input = input_bundle.load_file("contracts/library.vy") library = compile_from_file_input( - library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + library_input, output_formats=output_formats, input_bundle=input_bundle ) bar_input = input_bundle.load_file("contracts/bar.vy") bar = compile_from_file_input( - bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + bar_input, output_formats=output_formats, input_bundle=input_bundle ) compile_code_results = { diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 0f7d7a8014..9297f9e3c3 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -23,6 +23,8 @@ # requires ir_node "external_interface": output.build_external_interface_output, "interface": output.build_interface_output, + "bb": output.build_bb_output, + "bb_runtime": output.build_bb_runtime_output, "ir": output.build_ir_output, "ir_runtime": output.build_ir_runtime_output, "ir_dict": output.build_ir_dict_output, @@ -84,6 +86,8 @@ def compile_from_file_input( two arguments - the name of the contract, and the exception that was raised no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool + Use experimental codegen. Defaults to False Returns ------- diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 8ccf6abee1..5e11a20139 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -84,6 +84,14 @@ def build_interface_output(compiler_data: CompilerData) -> str: return out +def build_bb_output(compiler_data: CompilerData) -> IRnode: + return compiler_data.venom_functions[0] + + +def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: + return compiler_data.venom_functions[1] + + def build_ir_output(compiler_data: CompilerData) -> IRnode: if compiler_data.show_gas_estimates: IRnode.repr_show_gas = True diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 850adcfea3..ba6ccbda20 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -174,9 +174,7 @@ def global_ctx(self) -> ModuleT: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes( - self.global_ctx, self.settings.optimize, self.settings.experimental_codegen - ) + return generate_ir_nodes(self.global_ctx, self.settings.optimize) @property def ir_nodes(self) -> IRnode: @@ -272,9 +270,7 @@ def generate_annotated_ast( return vyper_module, symbol_tables -def generate_ir_nodes( - global_ctx: ModuleT, optimize: OptimizationLevel, experimental_codegen: bool -) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. From 06fa46a53ee2134951ee3cd9a8f46dcceb61f620 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 10 Jan 2024 18:23:45 -0500 Subject: [PATCH 3/3] refactor: constant folding (#3719) refactor constant folding into a visitor class, clean up a couple passes this moves responsibility for knowing how to fold a node off the individual AST node implementations and into the ConstantFolder visitor. by adding a dependency to get_namespace() it also makes constant folding more generic; soon we can rely on more things being in the global namespace at constant folding time. --- tests/functional/builtins/folding/test_abs.py | 7 +- .../builtins/folding/test_addmod_mulmod.py | 7 +- .../builtins/folding/test_bitwise.py | 16 +- .../builtins/folding/test_epsilon.py | 7 +- .../builtins/folding/test_floor_ceil.py | 7 +- .../folding/test_fold_as_wei_value.py | 10 +- .../builtins/folding/test_keccak_sha.py | 15 +- tests/functional/builtins/folding/test_len.py | 15 +- .../builtins/folding/test_min_max.py | 15 +- .../builtins/folding/test_powmod.py | 7 +- tests/functional/grammar/test_grammar.py | 4 +- tests/functional/syntax/test_bool.py | 2 +- .../unit/ast/nodes/test_fold_binop_decimal.py | 13 +- tests/unit/ast/nodes/test_fold_binop_int.py | 15 +- tests/unit/ast/nodes/test_fold_boolop.py | 6 +- tests/unit/ast/nodes/test_fold_compare.py | 12 +- tests/unit/ast/nodes/test_fold_subscript.py | 4 +- tests/unit/ast/nodes/test_fold_unaryop.py | 6 +- tests/utils.py | 9 + vyper/ast/nodes.py | 188 +---------- vyper/ast/nodes.pyi | 1 - vyper/builtins/functions.py | 7 +- vyper/exceptions.py | 2 +- vyper/semantics/analysis/local.py | 3 +- vyper/semantics/analysis/module.py | 9 +- vyper/semantics/analysis/pre_typecheck.py | 298 ++++++++++++------ vyper/semantics/analysis/utils.py | 2 + vyper/semantics/types/base.py | 2 +- vyper/semantics/types/module.py | 16 +- vyper/semantics/types/user.py | 11 + 30 files changed, 337 insertions(+), 379 deletions(-) diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index 68131678fa..c954380def 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -2,8 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType @@ -19,9 +18,9 @@ def foo(a: int256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"abs({a})") + vyper_ast = parse_and_fold(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a) == new_node.value == abs(a) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 1d789f1655..e6a9fc193f 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -2,8 +2,7 @@ from hypothesis import assume, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256 - 1) @@ -22,8 +21,8 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") + vyper_ast = parse_and_fold(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 53a6d333a0..c1ff7674bb 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -29,7 +29,7 @@ def foo(a: uint256, b: uint256) -> uint256: contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") + vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -48,10 +48,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case @@ -78,10 +77,9 @@ def foo(a: int256, b: uint256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. @@ -105,7 +103,7 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"~{value}") + vyper_ast = parse_and_fold(f"~{value}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 4f5e9434ec..7bc2afe757 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("typ_name", ["decimal"]) @@ -13,8 +12,8 @@ def foo() -> {typ_name}: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") + vyper_ast = parse_and_fold(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 04921e504e..9e63c7b099 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -4,8 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_decimals = st.decimals( min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10 @@ -28,8 +27,8 @@ def foo(a: decimal) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 4287615bab..01af646a16 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.builtins import functions as vy_fn from vyper.utils import SizeLimits @@ -30,9 +30,9 @@ def foo(a: decimal) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -49,8 +49,8 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index 8da420538f..3b5f99891f 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~' # NOQA: E501 @@ -20,9 +19,9 @@ def foo(a: String[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") + vyper_ast = parse_and_fold(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -39,9 +38,9 @@ def foo(a: Bytes[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -60,8 +59,8 @@ def foo(a: Bytes[100]) -> bytes32: value = f"0x{value.hex()}" - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index 967f906555..6d59751748 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("length", [0, 1, 32, 33, 64, 65, 1024]) @@ -15,9 +14,9 @@ def foo(a: String[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") + vyper_ast = parse_and_fold(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -33,9 +32,9 @@ def foo(a: Bytes[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") + vyper_ast = parse_and_fold(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value.encode()) == new_node.value @@ -51,8 +50,8 @@ def foo(a: Bytes[1024]) -> uint256: value = f"0x{'00' * length}" - vyper_ast = vy_ast.parse_to_ast(f"len({value})") + vyper_ast = parse_and_fold(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 36a611fa1b..752b64eb04 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.utils import SizeLimits st_decimals = st.decimals( @@ -29,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -48,9 +47,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -67,8 +66,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index a3c2567f58..ad1197e8e3 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256) @@ -19,8 +18,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") + vyper_ast = parse_and_fold(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 652102c376..351793b28e 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -4,7 +4,7 @@ import hypothesis import hypothesis.strategies as st import pytest -from hypothesis import assume, given +from hypothesis import HealthCheck, assume, given from hypothesis.extra.lark import LarkStrategy from vyper.ast import Module, parse_to_ast @@ -103,7 +103,7 @@ def has_no_docstrings(c): @pytest.mark.fuzzing @given(code=from_grammar().filter(lambda c: utf8_encodable(c))) -@hypothesis.settings(max_examples=500) +@hypothesis.settings(max_examples=500, suppress_health_check=[HealthCheck.too_slow]) def test_grammar_bruteforce(code): if utf8_encodable(code): _, _, _, reformatted_code = pre_parse(code + "\n") diff --git a/tests/functional/syntax/test_bool.py b/tests/functional/syntax/test_bool.py index 48ed37321a..5388a92b95 100644 --- a/tests/functional/syntax/test_bool.py +++ b/tests/functional/syntax/test_bool.py @@ -37,7 +37,7 @@ def foo(): def foo() -> bool: return (1 == 2) <= (1 == 1) """, - TypeMismatch, + InvalidOperation, ), """ @external diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index e426a11de9..a75d114f88 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -4,7 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException st_decimals = st.decimals( @@ -28,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -45,11 +45,8 @@ def foo(a: decimal, b: decimal) -> decimal: def test_binop_pow(): # raises because Vyper does not support decimal exponentiation - vyper_ast = vy_ast.parse_to_ast("3.1337 ** 4.2") - old_node = vyper_ast.body[0].value - with pytest.raises(TypeMismatch): - old_node.get_folded_value() + _ = parse_and_fold("3.1337 ** 4.2") @pytest.mark.fuzzing @@ -72,8 +69,8 @@ def foo({input_value}) -> decimal: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = -(2**127) <= expected < 2**127 diff --git a/tests/unit/ast/nodes/test_fold_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py index 904b36c167..d9340927fe 100644 --- a/tests/unit/ast/nodes/test_fold_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -2,7 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import ZeroDivisionException st_int32 = st.integers(min_value=-(2**32), max_value=2**32) @@ -24,9 +24,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -54,9 +54,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = new_node.value >= 0 except ZeroDivisionException: @@ -83,7 +83,7 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") + vyper_ast = parse_and_fold(f"{left} ** {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -112,9 +112,8 @@ def foo({input_value}) -> int128: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) - try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = True diff --git a/tests/unit/ast/nodes/test_fold_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py index 3c42da0d26..082e6f35c3 100644 --- a/tests/unit/ast/nodes/test_fold_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold variables = "abcdefghij" @@ -24,7 +24,7 @@ def foo({input_value}) -> bool: literal_op = f" {comparator} ".join(str(i) for i in values) - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -52,7 +52,7 @@ def foo({input_value}) -> bool: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, comparators)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index 2b7c0f09d7..aab8ac0b2d 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import UnfoldableNode @@ -19,7 +19,7 @@ def foo(a: int128, b: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -39,7 +39,7 @@ def foo(a: uint128, b: uint128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -63,7 +63,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") + vyper_ast = parse_and_fold(f"{left} in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -92,7 +92,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") + vyper_ast = parse_and_fold(f"{left} not in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -106,7 +106,7 @@ def bar(a: int128) -> bool: @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) def test_compare_type_mismatch(op): - vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") + vyper_ast = parse_and_fold(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py index 1884abf73b..3ed26d07b7 100644 --- a/tests/unit/ast/nodes/test_fold_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.fuzzing @@ -19,7 +19,7 @@ def foo(array: int128[10], idx: uint256) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") + vyper_ast = parse_and_fold(f"{array}[{idx}]") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py index ff48adfe71..af72f5f8b0 100644 --- a/tests/unit/ast/nodes/test_fold_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -1,6 +1,6 @@ import pytest -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.parametrize("bool_cond", [True, False]) @@ -12,7 +12,7 @@ def foo(a: bool) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") + vyper_ast = parse_and_fold(f"not {bool_cond}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -30,7 +30,7 @@ def foo(a: bool) -> bool: contract = get_contract(source) literal_op = f"{'not ' * count}{bool_cond}" - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/utils.py b/tests/utils.py index 0c89c39ff3..b8a6b493d8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,9 @@ import contextlib import os +from vyper import ast as vy_ast +from vyper.semantics.analysis.pre_typecheck import pre_typecheck + @contextlib.contextmanager def working_directory(directory): @@ -10,3 +13,9 @@ def working_directory(directory): yield finally: os.chdir(tmp) + + +def parse_and_fold(source_code): + ast = vy_ast.parse_to_ast(source_code) + pre_typecheck(ast) + return ast diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 7a8c7443b7..90365c63d5 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -400,21 +400,11 @@ def get_folded_value(self) -> "VyperNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. - - - The returned value is cached on `_metadata["folded_value"]`. - - For constant/literal nodes, the node should be directly returned - without caching to the metadata. """ - if self.is_literal_value: - return self - - if "folded_value" not in self._metadata: - res = self._try_fold() # possibly throws UnfoldableNode - self._set_folded_value(res) - - return self._metadata["folded_value"] + try: + return self._metadata["folded_value"] + except KeyError: + raise UnfoldableNode("not foldable", self) def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once @@ -422,7 +412,9 @@ def _set_folded_value(self, node: "VyperNode") -> None: # set the "original node" so that exceptions can point to the original # node and not the folded node - node = copy.copy(node) + cls = node.__class__ + # make a fresh copy so that the node metadata is fresh. + node = cls(**{i: getattr(node, i) for i in node.get_fields() if hasattr(node, i)}) node._original_node = self self._metadata["folded_value"] = node @@ -430,17 +422,6 @@ def _set_folded_value(self, node: "VyperNode") -> None: def get_original_node(self) -> "VyperNode": return self._original_node or self - def _try_fold(self) -> "VyperNode": - """ - Attempt to constant-fold the content of a node, returning the result of - constant-folding if possible. - - If a node cannot be folded, it should raise `UnfoldableNode`. This - base implementation acts as a catch-all to raise on any inherited - classes that do not implement the method. - """ - raise UnfoldableNode(f"{type(self)} cannot be folded") - def validate(self) -> None: """ Validate the content of a node. @@ -919,10 +900,6 @@ class List(ExprNode): def is_literal_value(self): return all(e.is_literal_value for e in self.elements) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class Tuple(ExprNode): __slots__ = ("elements",) @@ -936,10 +913,6 @@ def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class NameConstant(Constant): __slots__ = () @@ -960,10 +933,6 @@ class Dict(ExprNode): def is_literal_value(self): return all(v.is_literal_value for v in self.values) - def _try_fold(self) -> ExprNode: - values = [v.get_folded_value() for v in self.values] - return type(self).from_node(self, values=values) - class Name(ExprNode): __slots__ = ("id",) @@ -972,27 +941,6 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the unary operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - operand = self.operand.get_folded_value() - - if isinstance(self.op, Not) and not isinstance(operand, NameConstant): - raise UnfoldableNode("not a boolean!", self.operand) - if isinstance(self.op, USub) and not isinstance(operand, Num): - raise UnfoldableNode("not a number!", self.operand) - if isinstance(self.op, Invert) and not isinstance(operand, Int): - raise UnfoldableNode("not an int!", self.operand) - - value = self.op._op(operand.value) - return type(operand).from_node(self, value=value) - class Operator(VyperNode): pass @@ -1021,30 +969,6 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the arithmetic operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if type(left) is not type(right): - raise UnfoldableNode("invalid operation", self) - if not isinstance(left, Num): - raise UnfoldableNode("not a number!", self.left) - - # this validation is performed to prevent the compiler from hanging - # on very large shifts and improve the error message for negative - # values. - if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) - - value = self.op._op(left.value, right.value) - return type(left).from_node(self, value=value) - class Add(Operator): __slots__ = () @@ -1170,24 +1094,6 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the boolean operation. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - values = [v.get_folded_value() for v in self.values] - - if any(not isinstance(v, NameConstant) for v in values): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - values = [v.value for v in values] - value = self.op._op(values) - return NameConstant.from_node(self, value=value) - class And(Operator): __slots__ = () @@ -1225,40 +1131,6 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the comparison. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if not isinstance(left, Constant): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - # CMC 2022-08-04 we could probably remove these evaluation rules as they - # are taken care of in the IR optimizer now. - if isinstance(self.op, (In, NotIn)): - if not isinstance(right, List): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if next((i for i in right.elements if not isinstance(i, Constant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if len(set([type(i) for i in right.elements])) > 1: - raise UnfoldableNode("List contains multiple literal types") - value = self.op._op(left.value, [i.value for i in right.elements]) - return NameConstant.from_node(self, value=value) - - if not isinstance(left, type(right)): - raise UnfoldableNode("Cannot compare different literal types") - - if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)): - raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self) - - value = self.op._op(left.value, right.value) - return NameConstant.from_node(self, value=value) - class Eq(Operator): __slots__ = () @@ -1315,21 +1187,6 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") - # try checking if this is a builtin, which is foldable - def _try_fold(self): - if not isinstance(self.func, Name): - raise UnfoldableNode("not a builtin", self) - - # cursed import cycle! - from vyper.builtins.functions import DISPATCH_TABLE - - func_name = self.func.id - if func_name not in DISPATCH_TABLE: - raise UnfoldableNode("not a builtin", self) - - builtin_t = DISPATCH_TABLE[func_name] - return builtin_t._try_fold(self) - class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1342,37 +1199,6 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the subscript. - - This method reduces an indexed reference to a literal array into the value - within the array, e.g. `["foo", "bar"][1]` becomes `"bar"` - - Returns - ------- - ExprNode - Node representing the result of the evaluation. - """ - slice_ = self.slice.value.get_folded_value() - value = self.value.get_folded_value() - - if not isinstance(value, List): - raise UnfoldableNode("Subscript object is not a literal list") - - elements = value.elements - if len(set([type(i) for i in elements])) > 1: - raise UnfoldableNode("List contains multiple node types") - - if not isinstance(slice_, Int): - raise UnfoldableNode("invalid index type", slice_) - - idx = slice_.value - if idx < 0 or idx >= len(elements): - raise UnfoldableNode("invalid index value") - - return elements[idx] - class Index(VyperNode): __slots__ = ("value",) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 8bc4a4eb57..4a5bc0d001 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -31,7 +31,6 @@ class VyperNode: @classmethod def get_fields(cls: Any) -> set: ... def get_folded_value(self) -> VyperNode: ... - def _try_fold(self) -> VyperNode: ... def _set_folded_value(self, node: VyperNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 39d97c4abe..4f8101dfbe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -90,6 +90,7 @@ ceil32, fourbytes_to_int, keccak256, + method_id, method_id_int, vyper_warn, ) @@ -723,12 +724,12 @@ def _try_fold(self, node): raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) return_type = self.infer_kwarg_types(node)["output_type"].typedef - value = method_id_int(value.value) + value = method_id(value.value) if return_type.compare_type(BYTES4_T): - return vy_ast.Hex.from_node(node, value=hex(value)) + return vy_ast.Hex.from_node(node, value="0x" + value.hex()) else: - return vy_ast.Bytes.from_node(node, value=value.to_bytes(4, "big")) + return vy_ast.Bytes.from_node(node, value=value) def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 51f3fea14c..04667aaa59 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -373,7 +373,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): raise e from None except Exception as e: tb = e.__traceback__ - fallback_message = "unhandled exception" + fallback_message = f"unhandled exception {e}" if note: fallback_message += f", {note}" raise fallback_exception_type(fallback_message, node).with_traceback(tb) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 169c71269d..cc8ddaf98d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -510,8 +510,7 @@ def visit(self, node, typ): # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() - validate_expected_type(folded_node, typ) - folded_node._metadata["type"] = typ + self.visit(folded_node, typ) def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8e435f870f..4a7e33e848 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -26,11 +26,7 @@ from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions from vyper.semantics.analysis.pre_typecheck import pre_typecheck -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - validate_expected_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT @@ -315,12 +311,11 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) + ExprVisitor().visit(node.value, type_) # performs validate_expected_type if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) - validate_expected_type(node.value, type_) _validate_self_namespace() return _finalize() diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index a1302ce9c9..1c2a5392c3 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,94 +1,210 @@ from vyper import ast as vy_ast -from vyper.exceptions import UnfoldableNode - - -# try to fold a node, swallowing exceptions. this function is very similar to -# `VyperNode.get_folded_value()` but additionally checks in the constants -# table if the node is a `Name` node. -# -# CMC 2023-12-30 a potential refactor would be to move this function into -# `Name._try_fold` (which would require modifying the signature of _try_fold to -# take an optional constants table as parameter). this would remove the -# need to use this function in conjunction with `get_descendants` since -# `VyperNode._try_fold()` already recurses. it would also remove the need -# for `VyperNode._set_folded_value()`. -def _fold_with_constants(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): - if node.has_folded_value: - return - - if isinstance(node, vy_ast.Name): - # check if it's in constants table - var_name = node.id - - if var_name not in constants: - return - - res = constants[var_name] - node._set_folded_value(res) - return - - try: - # call get_folded_value for its side effects - node.get_folded_value() - except UnfoldableNode: - pass - - -def _get_constants(node: vy_ast.Module) -> dict: - constants: dict[str, vy_ast.VyperNode] = {} - const_var_decls = node.get_children(vy_ast.VariableDecl, {"is_constant": True}) - - while True: - n_processed = 0 - - for c in const_var_decls.copy(): - assert c.value is not None # guaranteed by VariableDecl.validate() - - for n in c.get_descendants(reverse=True): - _fold_with_constants(n, constants) - +from vyper.exceptions import InvalidLiteral, UnfoldableNode +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.namespace import get_namespace + + +def pre_typecheck(module_ast: vy_ast.Module): + ConstantFolder(module_ast).run() + + +class ConstantFolder(VyperNodeVisitorBase): + def __init__(self, module_ast): + self._constants = {} + self._module_ast = module_ast + + def run(self): + self._get_constants() + self.visit(self._module_ast) + + def _get_constants(self): + module = self._module_ast + const_var_decls = module.get_children(vy_ast.VariableDecl, {"is_constant": True}) + + while True: + n_processed = 0 + + for c in const_var_decls.copy(): + # visit the entire constant node in case its type annotation + # has unfolded constants in it. + self.visit(c) + + assert c.value is not None # guaranteed by VariableDecl.validate() + try: + val = c.value.get_folded_value() + except UnfoldableNode: + # not foldable, maybe it depends on other constants + # so try again later + continue + + # note that if a constant is redefined, its value will be + # overwritten, but it is okay because the error is handled + # downstream + name = c.target.id + self._constants[name] = val + + n_processed += 1 + const_var_decls.remove(c) + + if n_processed == 0: + # this condition means that there are some constant vardecls + # whose values are not foldable. this can happen for struct + # and interface constants for instance. these are valid constant + # declarations, but we just can't fold them at this stage. + break + + def visit(self, node): + if node.has_folded_value: + return node.get_folded_value() + + for c in node.get_children(): try: - val = c.value.get_folded_value() + self.visit(c) except UnfoldableNode: - # not foldable, maybe it depends on other constants - # so try again later - continue - - # note that if a constant is redefined, its value will be - # overwritten, but it is okay because the error is handled - # downstream - name = c.target.id - constants[name] = val - - n_processed += 1 - const_var_decls.remove(c) - - if n_processed == 0: - # this condition means that there are some constant vardecls - # whose values are not foldable. this can happen for struct - # and interface constants for instance. these are valid constant - # declarations, but we just can't fold them at this stage. - break - - return constants - - -# perform constant folding on a module AST -def pre_typecheck(node: vy_ast.Module) -> None: - """ - Perform pre-typechecking steps on a Module AST node. - At this point, this is limited to performing constant folding. - """ - constants = _get_constants(node) - - # note: use reverse to get descendants in leaf-first order - for n in node.get_descendants(reverse=True): - # try folding every single node. note this should be done before - # type checking because the typechecker requires literals or - # foldable nodes in type signatures and some other places (e.g. - # certain builtin kwargs). - # - # note we could limit to only folding nodes which are required - # during type checking, but it's easier to just fold everything - # and be done with it! - _fold_with_constants(n, constants) + # ignore bubbled up exceptions + pass + + try: + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + folded_value = visitor_fn(node) + node._set_folded_value(folded_value) + return folded_value + except UnfoldableNode: + # ignore bubbled up exceptions + pass + + return node + + def visit_Constant(self, node) -> vy_ast.ExprNode: + return node + + def visit_Name(self, node) -> vy_ast.ExprNode: + try: + return self._constants[node.id] + except KeyError: + raise UnfoldableNode("unknown name", node) + + def visit_UnaryOp(self, node): + operand = node.operand.get_folded_value() + + if isinstance(node.op, vy_ast.Not) and not isinstance(operand, vy_ast.NameConstant): + raise UnfoldableNode("not a boolean!", node.operand) + if isinstance(node.op, vy_ast.USub) and not isinstance(operand, vy_ast.Num): + raise UnfoldableNode("not a number!", node.operand) + if isinstance(node.op, vy_ast.Invert) and not isinstance(operand, vy_ast.Int): + raise UnfoldableNode("not an int!", node.operand) + + value = node.op._op(operand.value) + return type(operand).from_node(node, value=value) + + def visit_BinOp(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if type(left) is not type(right): + raise UnfoldableNode("invalid operation", node) + if not isinstance(left, vy_ast.Num): + raise UnfoldableNode("not a number!", node.left) + + # this validation is performed to prevent the compiler from hanging + # on very large shifts and improve the error message for negative + # values. + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256): + raise InvalidLiteral("Shift bits must be between 0 and 256", node.right) + + value = node.op._op(left.value, right.value) + return type(left).from_node(node, value=value) + + def visit_BoolOp(self, node): + values = [v.get_folded_value() for v in node.values] + + if any(not isinstance(v, vy_ast.NameConstant) for v in values): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + values = [v.value for v in values] + value = node.op._op(values) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_Compare(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if not isinstance(left, vy_ast.Constant): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + # CMC 2022-08-04 we could probably remove these evaluation rules as they + # are taken care of in the IR optimizer now. + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + if not isinstance(right, vy_ast.List): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if next((i for i in right.elements if not isinstance(i, vy_ast.Constant)), None): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if len(set([type(i) for i in right.elements])) > 1: + raise UnfoldableNode("List contains multiple literal types") + value = node.op._op(left.value, [i.value for i in right.elements]) + return vy_ast.NameConstant.from_node(node, value=value) + + if not isinstance(left, type(right)): + raise UnfoldableNode("Cannot compare different literal types") + + # this is maybe just handled in the type checker. + if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(left, vy_ast.Num): + raise UnfoldableNode( + f"Invalid literal types for {node.op.description} comparison", node + ) + + value = node.op._op(left.value, right.value) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_List(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Tuple(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Dict(self, node) -> vy_ast.ExprNode: + values = [v.get_folded_value() for v in node.values] + return type(node).from_node(node, values=values) + + def visit_Call(self, node) -> vy_ast.ExprNode: + if not isinstance(node.func, vy_ast.Name): + raise UnfoldableNode("not a builtin", node) + + namespace = get_namespace() + + func_name = node.func.id + if func_name not in namespace: + raise UnfoldableNode("unknown", node) + + varinfo = namespace[func_name] + if not isinstance(varinfo, VarInfo): + raise UnfoldableNode("unfoldable", node) + + typ = varinfo.typ + # TODO: rename to vyper_type.try_fold_call_expr + if not hasattr(typ, "_try_fold"): + raise UnfoldableNode("unfoldable", node) + return typ._try_fold(node) # type: ignore + + def visit_Subscript(self, node) -> vy_ast.ExprNode: + slice_ = node.slice.value.get_folded_value() + value = node.value.get_folded_value() + + if not isinstance(value, vy_ast.List): + raise UnfoldableNode("Subscript object is not a literal list") + + elements = value.elements + if len(set([type(i) for i in elements])) > 1: + raise UnfoldableNode("List contains multiple node types") + + if not isinstance(slice_, vy_ast.Int): + raise UnfoldableNode("invalid index type", slice_) + + idx = slice_.value + if idx < 0 or idx >= len(elements): + raise UnfoldableNode("invalid index value") + + return elements[idx] diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ba1b02b8d6..359b51b71e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -650,6 +650,8 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) + + # builtins call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) return call_type_modifiability >= modifiability diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 429ba807e1..14949f693f 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -19,7 +19,7 @@ # type of type `type_` class _GenericTypeAcceptor: def __repr__(self): - return repr(self.type_) + return f"GenericTypeAcceptor({self.type_})" def __init__(self, type_): self.type_ = type_ diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b0d7800011..8f1a5cc0dc 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -4,7 +4,12 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args -from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.exceptions import ( + InterfaceViolation, + NamespaceCollision, + StructureException, + UnfoldableNode, +) from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids from vyper.semantics.namespace import get_namespace @@ -53,6 +58,15 @@ def abi_type(self) -> ABIType: def __repr__(self): return f"interface {self._id}" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + arg = node.args[0].get_folded_value() + if not isinstance(arg, vy_ast.Hex): + raise UnfoldableNode("not an address", arg) + + return node + # when using the type itself (not an instance) in the call position def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": self._ctor_arg_types(node) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a4e782349d..8ef9aa8d4a 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -10,6 +10,7 @@ InvalidAttribute, NamespaceCollision, StructureException, + UnfoldableNode, UnknownAttribute, VariableDeclarationException, ) @@ -357,6 +358,16 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": def __repr__(self): return f"{self._id} declaration object" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + args = [arg.get_folded_value() for arg in node.args] + if not isinstance(args[0], vy_ast.Dict): + raise UnfoldableNode("not a dict") + + # it can't be reduced, but this lets upstream code know it's constant + return node + @property def size_in_bytes(self): return sum(i.size_in_bytes for i in self.member_types.values())