From a8dd405e0f0c0f70aebe3107fc133959341d5d8f Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 2 Nov 2023 18:33:31 +0800 Subject: [PATCH] folding fixes --- .../test_external_contract_calls.py | 2 +- .../parser/features/iteration/test_for_in_list.py | 3 +-- tests/parser/functions/test_bitwise.py | 4 ++-- tests/parser/functions/test_interfaces.py | 2 +- tests/parser/functions/test_minmax_value.py | 6 +++--- tests/parser/types/numbers/test_signed_ints.py | 2 +- vyper/semantics/analysis/pre_typecheck.py | 14 ++++---------- vyper/semantics/utils.py | 4 ++++ 8 files changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/parser/features/external_contracts/test_external_contract_calls.py b/tests/parser/features/external_contracts/test_external_contract_calls.py index 12fcde2f4f..935c7b74fc 100644 --- a/tests/parser/features/external_contracts/test_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_external_contract_calls.py @@ -380,7 +380,7 @@ def test_int128_too_long(get_contract, assert_tx_failed): contract_1 = """ @external def foo() -> int256: - return (2**255)-1 + return max_value(int256) """ c = get_contract(contract_1) diff --git a/tests/parser/features/iteration/test_for_in_list.py b/tests/parser/features/iteration/test_for_in_list.py index ebde55d228..b634735d31 100644 --- a/tests/parser/features/iteration/test_for_in_list.py +++ b/tests/parser/features/iteration/test_for_in_list.py @@ -8,7 +8,6 @@ InvalidType, IteratorException, NamespaceCollision, - OverflowException, StateAccessViolation, StructureException, TypeMismatch, @@ -773,7 +772,7 @@ def test_for() -> int128: a = i return a """, - OverflowException, + InvalidType, ), ( """ diff --git a/tests/parser/functions/test_bitwise.py b/tests/parser/functions/test_bitwise.py index 3c916a064e..718610ad86 100644 --- a/tests/parser/functions/test_bitwise.py +++ b/tests/parser/functions/test_bitwise.py @@ -1,7 +1,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import InvalidLiteral, InvalidOperation, InvalidType, TypeMismatch +from vyper.exceptions import InvalidLiteral, InvalidOperation, TypeMismatch from vyper.utils import unsigned_to_signed code = """ @@ -153,7 +153,7 @@ def foo() -> uint256: def foo() -> uint256: return 2 << -1 """, - InvalidType, + InvalidLiteral, ), ] diff --git a/tests/parser/functions/test_interfaces.py b/tests/parser/functions/test_interfaces.py index c16e188cfd..3da9b16346 100644 --- a/tests/parser/functions/test_interfaces.py +++ b/tests/parser/functions/test_interfaces.py @@ -406,7 +406,7 @@ def ok() -> {typ}: @external def should_fail() -> int256: - return -2**255 # OOB for all int/uint types with less than 256 bits + return min_value(int256) # OOB for all int/uint types with less than 256 bits """ code = f""" diff --git a/tests/parser/functions/test_minmax_value.py b/tests/parser/functions/test_minmax_value.py index a82db870f5..520f12db51 100644 --- a/tests/parser/functions/test_minmax_value.py +++ b/tests/parser/functions/test_minmax_value.py @@ -1,6 +1,6 @@ import pytest -from vyper.exceptions import OverflowException +from vyper.exceptions import InvalidType, OverflowException from vyper.semantics.types import DecimalT, IntegerT @@ -35,8 +35,8 @@ def foo(): a: {typ} = min_value({typ}) - 1 """ - assert_compile_failed(lambda: get_contract(upper), OverflowException) - assert_compile_failed(lambda: get_contract(lower), OverflowException) + assert_compile_failed(lambda: get_contract(upper), (InvalidType, OverflowException)) + assert_compile_failed(lambda: get_contract(lower), (InvalidType, OverflowException)) @pytest.mark.parametrize("typ", [DecimalT()]) diff --git a/tests/parser/types/numbers/test_signed_ints.py b/tests/parser/types/numbers/test_signed_ints.py index e7f3427f05..dfe0eaa20f 100644 --- a/tests/parser/types/numbers/test_signed_ints.py +++ b/tests/parser/types/numbers/test_signed_ints.py @@ -201,7 +201,7 @@ def num_sub() -> {typ}: return 1-2**{typ.bits} """ - assert_compile_failed(lambda: get_contract(code), OverflowException) + assert_compile_failed(lambda: get_contract(code), (InvalidType, OverflowException)) ARITHMETIC_OPS = { diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index 92843ac5d2..c10a70cf7c 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -131,6 +131,7 @@ def visit_For(self, node): self.visit(node.target) def visit_If(self, node): + self.visit(node.test) for n in node.body: self.visit(n) for n in node.orelse: @@ -155,9 +156,10 @@ def visit_keyword(self, node): def visit_Attribute(self, node): self.visit(node.value) + value_node = get_folded_value(node.value) - if isinstance(value_node, vy_ast.Dict): - for k, v in zip(value_node.keys, value_node.values): + if isinstance(value_node, vy_ast.Call) and isinstance(value_node.args[0], vy_ast.Dict): + for k, v in zip(value_node.args[0].keys, value_node.args[0].values): if k.id == node.attr: node._metadata["folded_value"] = v return @@ -169,10 +171,6 @@ def visit_BinOp(self, node): left = get_folded_value(node.left) right = get_folded_value(node.right) if isinstance(left, type(right)) and isinstance(left, (vy_ast.Int, vy_ast.Decimal)): - if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not ( - 0 <= right.value <= 256 - ): - return node._metadata["folded_value"] = node.evaluate(left, right) def visit_BoolOp(self, node): @@ -251,10 +249,6 @@ def _subscriptable_helper(self, node): for e in node.elements: self.visit(e) - values = [get_folded_value(e) for e in node.elements] - if None not in values: - node._metadata["folded_value"] = type(node).from_node(node, elts=values) - def visit_List(self, node): self._subscriptable_helper(node) diff --git a/vyper/semantics/utils.py b/vyper/semantics/utils.py index 5caa49af03..3106f756b8 100644 --- a/vyper/semantics/utils.py +++ b/vyper/semantics/utils.py @@ -6,6 +6,10 @@ def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]: if isinstance(node, vy_ast.Constant): return node + elif isinstance(node, vy_ast.List): + values = [get_folded_value(e) for e in node.elements] + if None not in values: + return type(node).from_node(node, elts=values) elif isinstance(node, vy_ast.Index): return get_folded_value(node.value)