Skip to content

Commit

Permalink
folding fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Nov 2, 2023
1 parent 4e99c48 commit a8dd405
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_int128_too_long(get_contract, assert_tx_failed):
contract_1 = """
@external
def foo() -> int256:
return (2**255)-1
return max_value(int256)
"""

c = get_contract(contract_1)
Expand Down
3 changes: 1 addition & 2 deletions tests/parser/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
InvalidType,
IteratorException,
NamespaceCollision,
OverflowException,
StateAccessViolation,
StructureException,
TypeMismatch,
Expand Down Expand Up @@ -773,7 +772,7 @@ def test_for() -> int128:
a = i
return a
""",
OverflowException,
InvalidType,
),
(
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/parser/functions/test_bitwise.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import InvalidLiteral, InvalidOperation, InvalidType, TypeMismatch
from vyper.exceptions import InvalidLiteral, InvalidOperation, TypeMismatch
from vyper.utils import unsigned_to_signed

code = """
Expand Down Expand Up @@ -153,7 +153,7 @@ def foo() -> uint256:
def foo() -> uint256:
return 2 << -1
""",
InvalidType,
InvalidLiteral,
),
]

Expand Down
2 changes: 1 addition & 1 deletion tests/parser/functions/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def ok() -> {typ}:
@external
def should_fail() -> int256:
return -2**255 # OOB for all int/uint types with less than 256 bits
return min_value(int256) # OOB for all int/uint types with less than 256 bits
"""

code = f"""
Expand Down
6 changes: 3 additions & 3 deletions tests/parser/functions/test_minmax_value.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from vyper.exceptions import OverflowException
from vyper.exceptions import InvalidType, OverflowException
from vyper.semantics.types import DecimalT, IntegerT


Expand Down Expand Up @@ -35,8 +35,8 @@ def foo():
a: {typ} = min_value({typ}) - 1
"""

assert_compile_failed(lambda: get_contract(upper), OverflowException)
assert_compile_failed(lambda: get_contract(lower), OverflowException)
assert_compile_failed(lambda: get_contract(upper), (InvalidType, OverflowException))
assert_compile_failed(lambda: get_contract(lower), (InvalidType, OverflowException))


@pytest.mark.parametrize("typ", [DecimalT()])
Expand Down
2 changes: 1 addition & 1 deletion tests/parser/types/numbers/test_signed_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def num_sub() -> {typ}:
return 1-2**{typ.bits}
"""

assert_compile_failed(lambda: get_contract(code), OverflowException)
assert_compile_failed(lambda: get_contract(code), (InvalidType, OverflowException))


ARITHMETIC_OPS = {
Expand Down
14 changes: 4 additions & 10 deletions vyper/semantics/analysis/pre_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def visit_For(self, node):
self.visit(node.target)

def visit_If(self, node):
self.visit(node.test)
for n in node.body:
self.visit(n)
for n in node.orelse:
Expand All @@ -155,9 +156,10 @@ def visit_keyword(self, node):

def visit_Attribute(self, node):
self.visit(node.value)

value_node = get_folded_value(node.value)
if isinstance(value_node, vy_ast.Dict):
for k, v in zip(value_node.keys, value_node.values):
if isinstance(value_node, vy_ast.Call) and isinstance(value_node.args[0], vy_ast.Dict):
for k, v in zip(value_node.args[0].keys, value_node.args[0].values):
if k.id == node.attr:
node._metadata["folded_value"] = v
return
Expand All @@ -169,10 +171,6 @@ def visit_BinOp(self, node):
left = get_folded_value(node.left)
right = get_folded_value(node.right)
if isinstance(left, type(right)) and isinstance(left, (vy_ast.Int, vy_ast.Decimal)):
if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (
0 <= right.value <= 256
):
return
node._metadata["folded_value"] = node.evaluate(left, right)

def visit_BoolOp(self, node):
Expand Down Expand Up @@ -251,10 +249,6 @@ def _subscriptable_helper(self, node):
for e in node.elements:
self.visit(e)

values = [get_folded_value(e) for e in node.elements]
if None not in values:
node._metadata["folded_value"] = type(node).from_node(node, elts=values)

def visit_List(self, node):
self._subscriptable_helper(node)

Expand Down
4 changes: 4 additions & 0 deletions vyper/semantics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]:
if isinstance(node, vy_ast.Constant):
return node
elif isinstance(node, vy_ast.List):
values = [get_folded_value(e) for e in node.elements]
if None not in values:
return type(node).from_node(node, elts=values)
elif isinstance(node, vy_ast.Index):
return get_folded_value(node.value)

Expand Down

0 comments on commit a8dd405

Please sign in to comment.