Skip to content

Commit

Permalink
fix wip
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Nov 9, 2023
1 parent c1a9bbc commit 38e94d2
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 23 deletions.
35 changes: 21 additions & 14 deletions tests/parser/functions/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def test_fail3() -> Bytes[3]:
# should revert - returns_Bytes3 is inferred to have return type Bytes[2]
# (because test_fail3 comes after test_fail1)
return self.foo.returns_Bytes3()
"""

bad_c = get_contract(external_contract)
Expand All @@ -559,7 +558,9 @@ def test_fail3() -> Bytes[3]:

assert_tx_failed(lambda: c.test_fail1())
assert_tx_failed(lambda: c.test_fail2())
assert_tx_failed(lambda: c.test_fail3())


# assert_tx_failed(lambda: c.test_fail3())


def test_units_interface(w3, get_contract, make_input_bundle):
Expand Down Expand Up @@ -764,10 +765,12 @@ def test_json(a: {0}) -> (uint256, {0}):


@pytest.mark.parametrize("type_str,value", type_str_params)
def test_json_interface_calls_tuple_return(get_contract, type_str, value):
def test_json_interface_calls_tuple_return(
get_contract, type_str, value, make_input_bundle, make_file
):
code = interface_tuple_return_test_code.format(type_str)

abi = compile_code(code, ["abi"])["abi"]
abi = compile_code(code, output_formats=["abi"])["abi"]
c1 = get_contract(code)

code = f"""
Expand All @@ -778,20 +781,23 @@ def test_json_interface_calls_tuple_return(get_contract, type_str, value):
def test_call(a: address, b: {type_str}) -> (uint256, {type_str}):
return jsonabi(a).test_json(b)
"""
c2 = get_contract(code, interface_codes={"jsonabi": {"type": "json", "code": abi}})
input_bundle = make_input_bundle({"jsonabi.json": json.dumps(abi)})
c2 = get_contract(code, input_bundle=input_bundle)
assert c2.test_call(c1.address, value) == [1, value]
c3 = get_contract(
code, interface_codes={"jsonabi": {"type": "json", "code": convert_v1_abi(abi)}}
)

make_file("jsonabi.json", json.dumps(convert_v1_abi(abi)))
c3 = get_contract(code, input_bundle=input_bundle)
assert c3.test_call(c1.address, value) == [1, value]


@pytest.mark.parametrize("typ,length,value", [("Bytes", 4, b"newp"), ("String", 6, "potato")])
def test_json_interface_calls_bytestring_widening(get_contract, typ, length, value):
def test_json_interface_calls_bytestring_widening(
get_contract, typ, length, value, make_input_bundle, make_file
):
type_str = f"{typ}[{length}]"
code = interface_test_code.format(type_str)

abi = compile_code(code, ["abi"])["abi"]
abi = compile_code(code, output_formats=["abi"])["abi"]
c1 = get_contract(code)

widened_typ1_str = f"{typ}[{length + 1}]"
Expand All @@ -806,9 +812,10 @@ def test_call(a: address, b: {type_str}) -> ({widened_typ1_str}, {widened_typ2_s
y: {widened_typ2_str} = jsonabi(a).test_json(b)
return x, y
"""
c2 = get_contract(code, interface_codes={"jsonabi": {"type": "json", "code": abi}})
input_bundle = make_input_bundle({"jsonabi.json": json.dumps(abi)})
c2 = get_contract(code, input_bundle=input_bundle)
assert c2.test_call(c1.address, value) == [value, value]
c3 = get_contract(
code, interface_codes={"jsonabi": {"type": "json", "code": convert_v1_abi(abi)}}
)

make_file("jsonabi.json", json.dumps(convert_v1_abi(abi)))
c3 = get_contract(code, input_bundle=input_bundle)
assert c3.test_call(c1.address, value) == [value, value]
10 changes: 10 additions & 0 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import Optional

from vyper import ast as vy_ast
Expand Down Expand Up @@ -642,6 +643,15 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
self.visit(node.func, call_type)

if isinstance(call_type, ContractFunctionT):
# if the function call is an expression (if it is a statement,
# the function type is passed down), and its return type consists
# of at least one bytestring (which is initialized as zero-length
# in `type_from_abi`), overwrite the return type with a concrete type.
if not isinstance(typ, ContractFunctionT) and call_type.returns_abi_bytestring:
call_type_copy = copy.copy(call_type)
call_type_copy.return_type = typ
self.visit(node.func, call_type_copy)

# function calls
if call_type.is_internal:
self.func.called_functions.add(call_type)
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def _is_empty_list(node):

def _is_type_in_list(obj, types_list):
# check if a type object is in a list of types
return any(i.compare_type(obj) for i in types_list)
return any(obj.compare_type(i) for i in types_list)


# NOTE: dead fn
Expand Down
19 changes: 11 additions & 8 deletions vyper/semantics/types/bytestrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,23 @@ def compare_type(self, other):
if not super().compare_type(other):
return False

# when comparing two literals, invert the comparison so that the
# larger type is derived during annotation of the smaller type for widening
if self._is_literal and other._is_literal:
return self._length <= other._length

# if both are non-literals, ensure the current length fits within the other
if self._length and other._length:
# when comparing two literals, or two bytestrings of non-zero lengths,
# ensure the current length fits within the other
if (self._is_literal and other._is_literal) or (self._length and other._length):
return self._length >= other._length

# relax typechecking if length has not been set for other type
# (e.g. JSON ABI import) so that it can be updated in annotation phase
# (e.g. JSON ABI import, `address.code`) so that it can be updated in
# annotation phase
if self._length:
return True

# if both are non-literals and zero length, then the bytestring length
# cannot be derived and it is likely to be a syntax error, so we defer
# the syntax error to be handled downstream for better error messages
if self._length == other._length == 0:
return True

return other.compare_type(self)

@classmethod
Expand Down
9 changes: 9 additions & 0 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import KwargSettings, VyperType
from vyper.semantics.types.bytestrings import _BytestringT
from vyper.semantics.types.primitives import BoolT
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.subscriptable import TupleT
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
function_visibility: FunctionVisibility,
state_mutability: StateMutability,
nonreentrant: Optional[str] = None,
returns_abi_bytestring: Optional[bool] = False,
) -> None:
super().__init__()

Expand All @@ -91,6 +93,7 @@ def __init__(
self.visibility = function_visibility
self.mutability = state_mutability
self.nonreentrant = nonreentrant
self.returns_abi_bytestring = returns_abi_bytestring

# a list of internal functions this function calls
self.called_functions = OrderedSet()
Expand Down Expand Up @@ -139,21 +142,27 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionT":
-------
ContractFunctionT object.
"""
returns_abi_bytestring = False
positional_args = []
for item in abi["inputs"]:
positional_args.append(PositionalArg(item["name"], type_from_abi(item)))
return_type = None
if len(abi["outputs"]) == 1:
return_type = type_from_abi(abi["outputs"][0])
if isinstance(return_type, _BytestringT):
returns_abi_bytestring = True
elif len(abi["outputs"]) > 1:
return_type = TupleT(tuple(type_from_abi(i) for i in abi["outputs"]))
if any([i for i in return_type.member_types if isinstance(i, _BytestringT)]):
returns_abi_bytestring = True
return cls(
abi["name"],
positional_args,
[],
return_type,
function_visibility=FunctionVisibility.EXTERNAL,
state_mutability=StateMutability.from_abi(abi),
returns_abi_bytestring=returns_abi_bytestring,
)

@classmethod
Expand Down

0 comments on commit 38e94d2

Please sign in to comment.