Skip to content

Commit

Permalink
apply bts suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Mar 26, 2024
1 parent 82ed420 commit 0d05989
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,10 @@ def fn(exception=TransactionFailed, exc_text=None):


@pytest.fixture(autouse=True)
def check_transient_marker(request):
if request.node.get_closest_marker("transient") and not version_check(begin="cancun"):
def check_transient_storage_marker(request):
if request.node.get_closest_marker("uses_transient_storage") and not version_check(
begin="cancun"
):
request.node.add_marker(
pytest.mark.xfail(
reason="transient storage", raises=TransientStorageException, strict=True
Expand Down
51 changes: 28 additions & 23 deletions tests/functional/codegen/features/test_transient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

from vyper.compiler import compile_code
from vyper.evm.opcodes import version_check
from vyper.exceptions import VyperException
from vyper.exceptions import TransientStorageException, VyperException


# with eth-tester, each call happens in an isolated transaction and so we need to
# test get/set within a single contract call. (we should remove this restriction
# in the future by migrating away from eth-tester).
def test_transient_compiles():
if not version_check(begin="cancun"):
pytest.skip("transient storage will not compile, pre-cancun")
Expand Down Expand Up @@ -46,7 +49,7 @@ def setter(k: address, v: uint256):
assert "TSTORE" in t


@pytest.mark.transient
@pytest.mark.uses_transient_storage
@pytest.mark.parametrize(
"typ,value,zero",
[
Expand Down Expand Up @@ -78,7 +81,7 @@ def foo(a: {typ}) -> {typ}:
assert c.bar() == zero


@pytest.mark.transient
@pytest.mark.uses_transient_storage
@pytest.mark.parametrize("val", [0, 1, 2**256 - 1])
def test_usage_in_constructor(get_contract, val):
code = """
Expand All @@ -103,6 +106,7 @@ def a1() -> uint256:
assert c.a1() == 0


@pytest.mark.uses_transient_storage
def test_multiple_transient_values(get_contract):
code = """
a: public(transient(uint256))
Expand All @@ -116,22 +120,23 @@ def foo(_a: uint256, _b: address, _c: String[64]) -> (uint256, address, String[6
self.c = _c
return self.a, self.b, self.c
"""
values = (3, "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", "Hello world")

if version_check(begin="cancun"):
c = get_contract(code)
assert c.foo(*values) == list(values)
assert c.a() == 0
assert c.b() is None
assert c.c() == ""
assert c.foo(*values) == list(values)
else:
# multiple errors
with pytest.raises(VyperException):
compile_code(code)
try:
compile_code(code)
except VyperException as e:
assert e.message.count("TransientStorageException") == 3
raise TransientStorageException()

values = (3, "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", "Hello world")
c = get_contract(code)
assert c.foo(*values) == list(values)
assert c.a() == 0
assert c.b() is None
assert c.c() == ""
assert c.foo(*values) == list(values)


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_struct_transient(get_contract):
code = """
struct MyStruct:
Expand Down Expand Up @@ -160,7 +165,7 @@ def foo(_a: uint256, _b: uint256, _c: address, _d: int256) -> MyStruct:
assert c.foo(*values) == values


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_complex_transient_modifiable(get_contract):
code = """
struct MyStruct:
Expand All @@ -184,7 +189,7 @@ def foo(a: uint256) -> MyStruct:
assert c.foo(1) == (2,)


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_list_transient(get_contract):
code = """
my_list: public(transient(uint256[3]))
Expand All @@ -203,7 +208,7 @@ def foo(_a: uint256, _b: uint256, _c: uint256) -> uint256[3]:
assert c.foo(*values) == list(values)


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_dynarray_transient(get_contract):
code = """
my_list: public(transient(DynArray[uint256, 3]))
Expand All @@ -229,7 +234,7 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256:
c.my_list(0)


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_nested_dynarray_transient_2(get_contract):
code = """
my_list: public(transient(DynArray[DynArray[uint256, 3], 3]))
Expand All @@ -252,7 +257,7 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256:
assert c.get_idx_two(*values) == expected_values[2][2]


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_nested_dynarray_transient(get_contract):
code = """
my_list: public(transient(DynArray[DynArray[DynArray[int128, 3], 3], 3]))
Expand Down Expand Up @@ -307,7 +312,7 @@ def get_idx_two(x: int128, y: int128, z: int128) -> int128:
c.my_list(0, 0, 0)


@pytest.mark.transient
@pytest.mark.uses_transient_storage
@pytest.mark.parametrize("n", range(5))
def test_internal_function_with_transient(get_contract, n):
code = """
Expand All @@ -333,7 +338,7 @@ def bar(x: uint256) -> uint256:
assert c.bar(n) == n + 2


@pytest.mark.transient
@pytest.mark.uses_transient_storage
def test_nested_internal_function_transient(get_contract):
code = """
d: public(uint256)
Expand Down

0 comments on commit 0d05989

Please sign in to comment.