diff --git a/tests/conftest.py b/tests/conftest.py index b3554f5493..507cde48bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.evm.opcodes import version_check -from vyper.exceptions import StructureException +from vyper.exceptions import TransientStorageException from vyper.ir import compile_ir, optimizer from vyper.utils import ERC5202_PREFIX @@ -327,14 +327,13 @@ def _get_contract( *args, override_opt_level=None, input_bundle=None, - has_transient_storage=False, **kwargs, ): settings = Settings() settings.optimize = override_opt_level or optimize - if has_transient_storage and not version_check(begin="cancun"): - with pytest.raises(StructureException): + if not version_check(begin="cancun"): + with pytest.raises(TransientStorageException): compiler.compile_code(source_code) return diff --git a/tests/functional/codegen/features/test_transient.py b/tests/functional/codegen/features/test_transient.py index 503a01d909..9817f136ce 100644 --- a/tests/functional/codegen/features/test_transient.py +++ b/tests/functional/codegen/features/test_transient.py @@ -2,7 +2,7 @@ from vyper.compiler import compile_code from vyper.evm.opcodes import version_check -from vyper.exceptions import StructureException, VyperException +from vyper.exceptions import TransientStorageException, VyperException def test_transient_blocked(get_contract): @@ -10,7 +10,7 @@ def test_transient_blocked(get_contract): code = """ my_map: transient(HashMap[address, uint256]) """ - get_contract(code, has_transient_storage=True) + get_contract(code) def test_transient_compiles(): @@ -76,7 +76,7 @@ def foo(a: {typ}) -> {typ}: return self.bar """ - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.foo(value) == value @@ -100,7 +100,7 @@ def a1() -> uint256: return self.A """ - c = get_contract(code, val, has_transient_storage=True) + c = get_contract(code, val) if version_check(begin="cancun"): assert c.a() == val assert c.a1() == 0 @@ -122,7 +122,7 @@ def foo(_a: uint256, _b: address, _c: String[64]) -> (uint256, address, String[6 values = (3, "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", "Hello world") if version_check(begin="cancun"): - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) assert c.foo(*values) == list(values) else: # multiple errors @@ -152,7 +152,7 @@ def foo(_a: uint256, _b: uint256, _c: address, _d: int256) -> MyStruct: """ values = (100, 42, "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE", -(2**200)) - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.foo(*values) == values @@ -174,7 +174,7 @@ def foo(a: uint256) -> MyStruct: return self.my_struct """ - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.foo(1) == (2,) @@ -190,7 +190,7 @@ def foo(_a: uint256, _b: uint256, _c: uint256) -> uint256[3]: """ values = (100, 42, 23230) - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.foo(*values) == list(values) @@ -211,7 +211,7 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256: """ values = (100, 42, 23230) - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.get_my_list(*values) == list(values) assert c.get_idx_two(*values) == values[2] @@ -234,7 +234,7 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256: values = (100, 42, 23230) expected_values = [[100, 42, 23230], [42, 100, 23230], [23230, 42, 100]] - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.get_my_list(*values) == expected_values assert c.get_idx_two(*values) == expected_values[2][2] @@ -285,7 +285,7 @@ def get_idx_two(x: int128, y: int128, z: int128) -> int128: [[146, 123, 148], [-146, -123, -148], [-2993, -1517, -2701]], ] - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.get_my_list(*values) == expected_values assert c.get_idx_two(*values) == expected_values[2][2][2] @@ -310,7 +310,7 @@ def bar(x: uint256) -> uint256: return self.val """ - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.bar(n) == n + 2 @@ -335,7 +335,7 @@ def b(): self.d = self.x """ - c = get_contract(code, has_transient_storage=True) + c = get_contract(code) if version_check(begin="cancun"): assert c.d() == 2 diff --git a/vyper/exceptions.py b/vyper/exceptions.py index ced249f247..7a1d9ad3f2 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -327,6 +327,10 @@ class EvmVersionException(VyperException): """Invalid action for the active EVM ruleset.""" +class TransientStorageException(EvmVersionException): + """Transient storage is not supported for the active EVM ruleset.""" + + class StorageLayoutException(VyperException): """Invalid slot for the storage layout overrides""" diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index f4b7db129f..c86c779366 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -24,6 +24,7 @@ ModuleNotFound, StateAccessViolation, StructureException, + TransientStorageException, UndeclaredDefinition, VyperException, tag_exceptions, @@ -608,7 +609,9 @@ def visit_VariableDecl(self, node): type_ = type_from_annotation(node.annotation, data_loc) if node.is_transient and not version_check(begin="cancun"): - raise StructureException("`transient` is not available pre-cancun", node.annotation) + raise TransientStorageException( + "`transient` is not available pre-cancun", node.annotation + ) var_info = VarInfo( type_,