diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/compiler/asm/test_asm_optimizer.py new file mode 100644 index 0000000000..524b8df064 --- /dev/null +++ b/tests/compiler/asm/test_asm_optimizer.py @@ -0,0 +1,49 @@ +from vyper.compiler.phases import CompilerData + + +def test_dead_code_eliminator(): + code = """ +s: uint256 + +@internal +def foo(): + self.s = 1 + +@internal +def qux(): + self.s = 2 + +@external +def bar(): + self.foo() + +@external +def __init__(): + self.qux() + """ + + c = CompilerData(code, no_optimize=True) + initcode_asm = [i for i in c.assembly if not isinstance(i, list)] + runtime_asm = c.assembly_runtime + + foo_label = "_sym_internal_foo___" + qux_label = "_sym_internal_qux___" + + # all the labels should be in all the unoptimized asms + for s in (foo_label, qux_label): + assert s in initcode_asm + assert s in runtime_asm + + c = CompilerData(code, no_optimize=False) + initcode_asm = [i for i in c.assembly if not isinstance(i, list)] + runtime_asm = c.assembly_runtime + + # qux should not be in runtime code + for instr in runtime_asm: + if isinstance(instr, str): + assert not instr.startswith(qux_label), instr + + # foo should not be in initcode asm + for instr in initcode_asm: + if isinstance(instr, str): + assert not instr.startswith(foo_label), instr diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/functional/semantics/analysis/test_for_loop.py index 71e38d253c..8707b4c326 100644 --- a/tests/functional/semantics/analysis/test_for_loop.py +++ b/tests/functional/semantics/analysis/test_for_loop.py @@ -108,14 +108,14 @@ def main(): for j in range(3): x: uint256 = j y: uint16 = j - """, # issue 3212 + """, # GH issue 3212 """ @external def foo(): for i in [1]: a:uint256 = i b:uint16 = i - """, # issue 3374 + """, # GH issue 3374 """ @external def foo(): @@ -123,7 +123,7 @@ def foo(): for j in [1]: a:uint256 = i b:uint16 = i - """, # issue 3374 + """, # GH issue 3374 """ @external def foo(): @@ -131,7 +131,7 @@ def foo(): for j in [1,2,3]: b:uint256 = j + i c:uint16 = i - """, # issue 3374 + """, # GH issue 3374 ] diff --git a/tests/parser/features/test_comparison.py b/tests/parser/features/test_comparison.py index 1c2f287c10..5a86ffb4b8 100644 --- a/tests/parser/features/test_comparison.py +++ b/tests/parser/features/test_comparison.py @@ -4,7 +4,7 @@ def test_3034_verbatim(get_contract): - # test issue #3034 exactly + # test GH issue 3034 exactly code = """ @view @external diff --git a/tests/parser/features/test_immutable.py b/tests/parser/features/test_immutable.py index bb01b3fc07..488943f784 100644 --- a/tests/parser/features/test_immutable.py +++ b/tests/parser/features/test_immutable.py @@ -239,3 +239,90 @@ def get_immutable() -> uint256: c = get_contract(code, n) assert c.get_immutable() == n + 2 + + +# GH issue 3292 +def test_internal_functions_called_by_ctor_location(get_contract): + code = """ +d: uint256 +x: immutable(uint256) + +@external +def __init__(): + self.d = 1 + x = 2 + self.a() + +@external +def test() -> uint256: + return self.d + +@internal +def a(): + self.d = x + """ + c = get_contract(code) + assert c.test() == 2 + + +# GH issue 3292, extended to nested internal functions +def test_nested_internal_function_immutables(get_contract): + code = """ +d: public(uint256) +x: public(immutable(uint256)) + +@external +def __init__(): + self.d = 1 + x = 2 + self.a() + +@internal +def a(): + self.b() + +@internal +def b(): + self.d = x + """ + c = get_contract(code) + assert c.x() == 2 + assert c.d() == 2 + + +# GH issue 3292, test immutable read from both ctor and runtime +def test_immutable_read_ctor_and_runtime(get_contract): + code = """ +d: public(uint256) +x: public(immutable(uint256)) + +@external +def __init__(): + self.d = 1 + x = 2 + self.a() + +@internal +def a(): + self.d = x + +@external +def thrash(): + self.d += 5 + +@external +def fix(): + self.a() + """ + c = get_contract(code) + assert c.x() == 2 + assert c.d() == 2 + + c.thrash(transact={}) + + assert c.x() == 2 + assert c.d() == 2 + 5 + + c.fix(transact={}) + assert c.x() == 2 + assert c.d() == 2 diff --git a/tests/parser/features/test_init.py b/tests/parser/features/test_init.py index feeabe311a..83bcbc95ea 100644 --- a/tests/parser/features/test_init.py +++ b/tests/parser/features/test_init.py @@ -53,3 +53,29 @@ def baz() -> uint8: n = 256 assert_compile_failed(lambda: get_contract(code, n)) + + +# GH issue 3206 +def test_nested_internal_call_from_ctor(get_contract): + code = """ +x: uint256 + +@external +def __init__(): + self.a() + +@internal +def a(): + self.x += 1 + self.b() + +@internal +def b(): + self.x += 2 + +@external +def test() -> uint256: + return self.x + """ + c = get_contract(code) + assert c.test() == 3 diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 696b81d124..34c409e16c 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -54,6 +54,7 @@ def __init__( forvars=None, constancy=Constancy.Mutable, sig=None, + is_ctor_context=False, ): # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} @@ -92,6 +93,9 @@ def __init__( self._internal_var_iter = 0 self._scope_id_iter = 0 + # either the constructor, or called from the constructor + self.is_ctor_context = is_ctor_context + def is_constant(self): return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 6da3d9501b..9ed80b86f9 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -184,7 +184,7 @@ def parse_Name(self): ofst = varinfo.position.offset - if self.context.sig.is_init_func: + if self.context.is_ctor_context: mutable = True location = IMMUTABLES else: diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 6dece865fa..cd467a152e 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -18,6 +18,7 @@ def generate_ir_for_function( sigs: Dict[str, Dict[str, FunctionSignature]], # all signatures in all namespaces global_ctx: GlobalContext, skip_nonpayable_check: bool, + is_ctor_context: bool = False, ) -> IRnode: """ Parse a function and produce IR code for the function, includes: @@ -51,6 +52,7 @@ def generate_ir_for_function( memory_allocator=memory_allocator, constancy=Constancy.Constant if sig.mutability in ("view", "pure") else Constancy.Mutable, sig=sig, + is_ctor_context=is_ctor_context, ) if sig.internal: @@ -65,13 +67,19 @@ def generate_ir_for_function( frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY - sig.set_frame_info(FrameInfo(allocate_start, frame_size, context.vars)) + frame_info = FrameInfo(allocate_start, frame_size, context.vars) + + if sig.frame_info is None: + sig.set_frame_info(frame_info) + else: + assert frame_info == sig.frame_info if not sig.internal: # adjust gas estimate to include cost of mem expansion # frame_size of external function includes all private functions called # (note: internal functions do not need to adjust gas estimate since # it is already accounted for by the caller.) + assert sig.frame_info is not None # mypy hint o.add_gas_estimate += calc_mem_gas(sig.frame_info.mem_used) sig.gas_estimate = o.gas diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 71f9ed552d..bdf8c067f7 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -67,19 +67,21 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx): # create a map of the IR functions since they might live in both # runtime and deploy code (if init function calls them) - internal_functions_map: Dict[str, IRnode] = {} + internal_functions_ir: list[IRnode] = [] for func_ast in internal_functions: func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, False) - internal_functions_map[func_ast.name] = func_ir + internal_functions_ir.append(func_ir) # for some reason, somebody may want to deploy a contract with no # external functions, or more likely, a "pure data" contract which # contains immutables if len(external_functions) == 0: - # TODO: prune internal functions in this case? - runtime = ["seq"] + list(internal_functions_map.values()) - return runtime, internal_functions_map + # TODO: prune internal functions in this case? dead code eliminator + # might not eliminate them, since internal function jumpdest is at the + # first instruction in the contract. + runtime = ["seq"] + internal_functions_ir + return runtime # note: if the user does not provide one, the default fallback function # reverts anyway. so it does not hurt to batch the payable check. @@ -125,10 +127,10 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx): ["label", "fallback", ["var_list"], fallback_ir], ] - # TODO: prune unreachable functions? - runtime.extend(internal_functions_map.values()) + # note: dead code eliminator will clean dead functions + runtime.extend(internal_functions_ir) - return runtime, internal_functions_map + return runtime # take a GlobalContext, which is basically @@ -159,12 +161,15 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F runtime_functions = [f for f in function_defs if not _is_init_func(f)] init_function = next((f for f in function_defs if _is_init_func(f)), None) - runtime, internal_functions = _runtime_ir(runtime_functions, all_sigs, global_ctx) + runtime = _runtime_ir(runtime_functions, all_sigs, global_ctx) deploy_code: List[Any] = ["seq"] immutables_len = global_ctx.immutable_section_bytes if init_function: - init_func_ir = generate_ir_for_function(init_function, all_sigs, global_ctx, False) + # TODO might be cleaner to separate this into an _init_ir helper func + init_func_ir = generate_ir_for_function( + init_function, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True + ) deploy_code.append(init_func_ir) # pass the amount of memory allocated for the init function @@ -174,8 +179,13 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F deploy_code.append(["deploy", init_mem_used, runtime, immutables_len]) # internal functions come after everything else - for f in init_function._metadata["type"].called_functions: - deploy_code.append(internal_functions[f.name]) + internal_functions = [f for f in runtime_functions if _is_internal(f)] + for f in internal_functions: + func_ir = generate_ir_for_function( + f, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True + ) + # note: we depend on dead code eliminator to clean dead function defs + deploy_code.append(func_ir) else: if immutables_len != 0: