Skip to content

Commit

Permalink
Merge pull request #2 from cyberthirst/tests/transient
Browse files Browse the repository at this point in the history
add complex transient storage tests
  • Loading branch information
tserg authored Mar 29, 2024
2 parents 9b4c209 + 8ee1061 commit 2948300
Showing 1 changed file with 175 additions and 19 deletions.
194 changes: 175 additions & 19 deletions tests/functional/codegen/features/test_transient.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,35 @@ def foo(_a: uint256, _b: uint256, _c: address, _d: int256) -> MyStruct:
assert c.foo(*values) == values


def test_complex_struct_transient(get_contract):
code = """
struct MyStruct:
a: address
b: MyStruct2
c: DynArray[DynArray[uint256, 3], 3]
struct MyStruct2:
a: DynArray[uint256, 3]
my_struct: public(transient(MyStruct))
@external
def foo(_a: address, _b: MyStruct2, _c: DynArray[DynArray[uint256, 3], 3]) -> MyStruct:
self.my_struct = MyStruct(
a=_a,
b=_b,
c=_c,
)
return self.my_struct
"""
values = ("0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", ([1],), [[3,4], [5,6]])

c = get_contract(code)
assert c.foo(*values) == values
assert c.my_struct() == (None, ([],), [])
assert c.foo(*values) == values


def test_complex_transient_modifiable(get_contract):
code = """
struct MyStruct:
Expand Down Expand Up @@ -200,6 +229,53 @@ def foo(_a: uint256, _b: uint256, _c: uint256) -> uint256[3]:
assert c.foo(*values) == list(values)


def test_hashmap_transient(get_contract):
code = """
my_map: public(transient(HashMap[uint256, uint256]))
@external
def foo(k: uint256, v: uint256) -> uint256:
self.my_map[k] = v
return self.my_map[k]
"""
c = get_contract(code)
for v in range(5):
for k in range(5):
assert c.foo(k, v) == v
assert c.my_map(k) == 0


def test_complex_hashmap_transient(get_contract):
code = """
struct MyStruct:
a: uint256
b: DynArray[uint256, 3]
my_map: public(transient(HashMap[uint256, MyStruct]))
my_res: public(HashMap[uint256, MyStruct])
@external
def do_side_effects():
a: DynArray[uint256, 3] = [1, 2, 3]
s: MyStruct = MyStruct(a=100, b=a)
for i: uint256 in range(2):
for j: uint256 in range(3):
s.b[j] = i + j
s.a = i
self.my_map[i] = s
self.my_res[i] = self.my_map[i]
"""
c = get_contract(code)
c.do_side_effects(transact={})
for i in range(2):
assert c.my_res(i)[0] == i
assert c.my_map(i)[0] == 0
for j in range(3):
print(c.my_res(i)[1])
assert c.my_res(i)[1][j] == i + j
assert c.my_map(i)[1] == []


def test_dynarray_transient(get_contract):
code = """
my_list: public(transient(DynArray[uint256, 3]))
Expand Down Expand Up @@ -248,11 +324,7 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256:


def test_nested_dynarray_transient(get_contract):
code = """
my_list: public(transient(DynArray[DynArray[DynArray[int128, 3], 3], 3]))
@external
def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[int128, 3], 3], 3]:
set_list = """
self.my_list = [
[[x, y, z], [y, z, x], [z, y, x]],
[
Expand All @@ -265,25 +337,29 @@ def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[i
[z * (-2), y * (-3), x * (-4)],
[z * (-y), y * (-x), x * (-z)],
],
]
]
"""
code = f"""
interface Iface:
def my_list(x: uint256, y: uint256, z: uint256) -> int128: view
my_list: public(transient(DynArray[DynArray[DynArray[int128, 3], 3], 3]))
@external
def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[int128, 3], 3], 3]:
{set_list}
return self.my_list
@external
def get_idx_two(x: int128, y: int128, z: int128) -> int128:
self.my_list = [
[[x, y, z], [y, z, x], [z, y, x]],
[
[x * 1000 + y, y * 1000 + z, z * 1000 + x],
[- (x * 1000 + y), - (y * 1000 + z), - (z * 1000 + x)],
[- (x * 1000) + y, - (y * 1000) + z, - (z * 1000) + x],
],
[
[z * 2, y * 3, x * 4],
[z * (-2), y * (-3), x * (-4)],
[z * (-y), y * (-x), x * (-z)],
],
]
{set_list}
return self.my_list[2][2][2]
@external
def get_idx_two_using_getter(x: int128, y: int128, z: int128) -> int128:
{set_list}
#return self.my_list[2][2][2]
return staticcall Iface(self).my_list(2, 2, 2)
"""
values = (37, 41, 73)
expected_values = [
Expand All @@ -299,6 +375,9 @@ def get_idx_two(x: int128, y: int128, z: int128) -> int128:
assert c.get_idx_two(*values) == expected_values[2][2][2]
with pytest.raises(TransactionFailed):
c.my_list(0, 0, 0)
assert c.get_idx_two_using_getter(*values) == expected_values[2][2][2]
with pytest.raises(TransactionFailed):
c.my_list(0, 0, 0)


@pytest.mark.parametrize("n", range(5))
Expand Down Expand Up @@ -383,3 +462,80 @@ def bar(i: uint256, a: address) -> uint256:
value = 333
assert c2.bar(value, c1.address) == value
assert c1.get_x() == 0


def test_modules_transient(get_contract, make_input_bundle):
lib1 = """
counter: transient(uint256)
"""
lib2 = """
import lib1
uses: lib1
counter: transient(uint256)
counter2: public(uint256)
@internal
def foo():
lib1.counter += 1
"""
main = """
import lib2
import lib1
initializes: lib2[lib1 := lib1]
initializes: lib1
@external
def foo() -> (uint256, uint256):
lib1.counter = 2
lib2.foo()
lib2.counter = 10
return lib1.counter, lib2.counter
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == [3, 10]


def test_complex_modules_transient(get_contract, make_input_bundle):
lib1 = """
l: transient(uint256[3])
"""
lib2 = """
import lib1
uses: lib1
struct MyStruct:
a: uint256
b: uint256
s: transient(MyStruct)
@internal
def foo():
self.s = MyStruct(a=lib1.l[0], b=lib1.l[1])
"""
main = """
import lib2
import lib1
initializes: lib2[lib1 := lib1]
initializes: lib1
my_map: HashMap[uint256, uint256]
@external
def foo() -> (uint256[3], uint256, uint256, uint256):
lib1.l = [1, 2, 3]
lib2.foo()
self.my_map[0] = 42
return lib1.l, lib2.s.a, lib2.s.b, self.my_map[0]
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == [[1, 2, 3], 1, 2, 42]

0 comments on commit 2948300

Please sign in to comment.