Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Mar 23, 2024
1 parent 4820822 commit d8ec500
Showing 1 changed file with 314 additions and 1 deletion.
315 changes: 314 additions & 1 deletion tests/functional/codegen/features/test_transient.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest

from vyper.compiler import compile_code
from vyper.evm.opcodes import version_check
from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version, version_check
from vyper.exceptions import StructureException

POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]}


def test_transient_blocked(evm_version):
# test transient is blocked on pre-cancun and compiles post-cancun
Expand Down Expand Up @@ -55,3 +57,314 @@ def setter(k: address, v: uint256):

assert "TLOAD" in t
assert "TSTORE" in t


@pytest.mark.parametrize(
"typ,value",
[
("uint256", 42),
("int256", -(2**200)),
("int128", -(2**126)),
("address", "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE"),
("bytes32", b"deadbeef" * 4),
("bool", True),
("String[10]", "Vyper hiss"),
("Bytes[10]", b"Vyper hiss"),
],
)
def test_value_storage_retrieval(typ, value, get_contract):
code = f"""
bar: transient({typ})
@external
def foo(a: {typ}) -> {typ}:
self.bar = a
return self.bar
"""

if version_check(begin="cancun"):
c = get_contract(code)
assert c.foo(value) == value
else:
with pytest.raises(StructureException):
compile_code(code)


@pytest.mark.parametrize("val", [0, 1, 2**256 - 1])
def test_usage_in_constructor(get_contract, val):
code = """
A: transient(uint256)
a: public(uint256)
@deploy
def __init__(_a: uint256):
self.A = _a
self.a = self.A
@external
@view
def a1() -> uint256:
return self.A
"""

if version_check(begin="cancun"):
c = get_contract(code, val)
assert c.a() == val
assert c.a1() == 0
else:
with pytest.raises(StructureException):
compile_code(code)


def test_multiple_transient_values(get_contract):
code = """
a: transient(uint256)
b: transient(address)
c: transient(String[64])
@external
def foo(_a: uint256, _b: address, _c: String[64]) -> (uint256, address, String[64]):
self.a = _a
self.b = _b
self.c = _c
return self.a, self.b, self.c
"""
values = (3, "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", "Hello world")

if version_check(begin="cancun"):
c = get_contract(code)
assert c.foo(*values) == list(values)
else:
with pytest.raises(StructureException):
compile_code(code)


def test_struct_transient(get_contract):
code = """
struct MyStruct:
a: uint256
b: uint256
c: address
d: int256
my_struct: transient(MyStruct)
@external
def foo(_a: uint256, _b: uint256, _c: address, _d: int256) -> MyStruct:
self.my_struct = MyStruct(
a=_a,
b=_b,
c=_c,
d=_d
)
return self.my_struct
"""
values = (100, 42, "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", -(2**200))

if version_check(begin="cancun"):
c = get_contract(code)
assert c.foo(*values) == values
else:
with pytest.raises(StructureException):
compile_code(code)


def test_complex_transient_modifiable(get_contract):
code = """
struct MyStruct:
a: uint256
my_struct: transient(MyStruct)
@external
def foo(a: uint256) -> MyStruct:
self.my_struct = MyStruct(a=a)
# struct members are modifiable after initialization
self.my_struct.a += 1
return self.my_struct
"""

if version_check(begin="cancun"):
c = get_contract(code)
assert c.foo(1) == (2,)
else:
with pytest.raises(StructureException):
compile_code(code)


def test_list_transient(get_contract):
code = """
my_list: transient(uint256[3])
@external
def foo(_a: uint256, _b: uint256, _c: uint256) -> uint256[3]:
self.my_list = [_a, _b, _c]
return self.my_list
"""
values = (100, 42, 23230)

if version_check(begin="cancun"):
c = get_contract(code)
assert c.foo(*values) == list(values)
else:
with pytest.raises(StructureException):
compile_code(code)


def test_dynarray_transient(get_contract):
code = """
my_list: transient(DynArray[uint256, 3])
@external
def get_my_list(_a: uint256, _b: uint256, _c: uint256) -> DynArray[uint256, 3]:
self.my_list = [_a, _b, _c]
return self.my_list
@external
def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256:
self.my_list = [_a, _b, _c]
return self.my_list[2]
"""
values = (100, 42, 23230)

if version_check(begin="cancun"):
c = get_contract(code)
assert c.get_my_list(*values) == list(values)
assert c.get_idx_two(*values) == values[2]
else:
with pytest.raises(StructureException):
compile_code(code)


def test_nested_dynarray_transient_2(get_contract):
code = """
my_list: transient(DynArray[DynArray[uint256, 3], 3])
@external
def get_my_list(_a: uint256, _b: uint256, _c: uint256) -> DynArray[DynArray[uint256, 3], 3]:
self.my_list = [[_a, _b, _c], [_b, _a, _c], [_c, _b, _a]]
return self.my_list
@external
def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256:
self.my_list = [[_a, _b, _c], [_b, _a, _c], [_c, _b, _a]]
return self.my_list[2][2]
"""
values = (100, 42, 23230)
expected_values = [[100, 42, 23230], [42, 100, 23230], [23230, 42, 100]]

if version_check(begin="cancun"):
c = get_contract(code)
assert c.get_my_list(*values) == expected_values
assert c.get_idx_two(*values) == expected_values[2][2]
else:
with pytest.raises(StructureException):
compile_code(code)


def test_nested_dynarray_transient(get_contract):
code = """
my_list: transient(DynArray[DynArray[DynArray[int128, 3], 3], 3])
@external
def get_my_list(x: int128, y: int128, z: int128) -> DynArray[DynArray[DynArray[int128, 3], 3], 3]:
self.my_list = [
[[x, y, z], [y, z, x], [z, y, x]],
[
[x * 1000 + y, y * 1000 + z, z * 1000 + x],
[- (x * 1000 + y), - (y * 1000 + z), - (z * 1000 + x)],
[- (x * 1000) + y, - (y * 1000) + z, - (z * 1000) + x],
],
[
[z * 2, y * 3, x * 4],
[z * (-2), y * (-3), x * (-4)],
[z * (-y), y * (-x), x * (-z)],
],
]
return self.my_list
@external
def get_idx_two(x: int128, y: int128, z: int128) -> int128:
self.my_list = [
[[x, y, z], [y, z, x], [z, y, x]],
[
[x * 1000 + y, y * 1000 + z, z * 1000 + x],
[- (x * 1000 + y), - (y * 1000 + z), - (z * 1000 + x)],
[- (x * 1000) + y, - (y * 1000) + z, - (z * 1000) + x],
],
[
[z * 2, y * 3, x * 4],
[z * (-2), y * (-3), x * (-4)],
[z * (-y), y * (-x), x * (-z)],
],
]
return self.my_list[2][2][2]
"""
values = (37, 41, 73)
expected_values = [
[[37, 41, 73], [41, 73, 37], [73, 41, 37]],
[[37041, 41073, 73037], [-37041, -41073, -73037], [-36959, -40927, -72963]],
[[146, 123, 148], [-146, -123, -148], [-2993, -1517, -2701]],
]

if version_check(begin="cancun"):
c = get_contract(code)
assert c.get_my_list(*values) == expected_values
assert c.get_idx_two(*values) == expected_values[2][2][2]
else:
with pytest.raises(StructureException):
compile_code(code)


@pytest.mark.parametrize("n", range(5))
def test_internal_function_with_transient(get_contract, n):
code = """
@internal
def foo() -> uint256:
self.counter += 1
return self.counter
counter: uint256
val: transient(uint256)
@external
def bar(x: uint256) -> uint256:
self.counter = x
self.foo()
self.val = self.foo()
return self.val
"""

if version_check(begin="cancun"):
c = get_contract(code)
assert c.bar(n) == n + 2
else:
with pytest.raises(StructureException):
compile_code(code)


def test_nested_internal_function_transient(get_contract):
code = """
d: public(uint256)
x: public(transient(uint256))
@deploy
def __init__():
self.d = 1
self.x = 2
self.a()
@internal
def a():
self.b()
@internal
def b():
self.d = self.x
"""
c = get_contract(code)
assert c.d() == 2

0 comments on commit d8ec500

Please sign in to comment.