diff --git a/README.md b/README.md index bad929956d..33c4557cc8 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,23 @@ make dev-init python setup.py test ``` +## Developing (working on the compiler) + +A useful script to have in your PATH is something like the following: +```bash +$ cat ~/.local/bin/vyc +#!/usr/bin/env bash +PYTHONPATH=. python vyper/cli/vyper_compile.py "$@" +``` + +To run a python performance profile (to find compiler perf hotspots): +```bash +PYTHONPATH=. python -m cProfile -s tottime vyper/cli/vyper_compile.py "$@" +``` + +To get a call graph from a python profile, https://stackoverflow.com/a/23164271/ is helpful. + + # Contributing * See Issues tab, and feel free to submit your own issues * Add PRs if you discover a solution to an existing issue diff --git a/docs/contributing.rst b/docs/contributing.rst index 6dc57b26c3..221600f930 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -75,4 +75,4 @@ If you are making a larger change, please consult first with the `Vyper (Smart C Although we do CI testing, please make sure that the tests pass for supported Python version and ensure that it builds locally before submitting a pull request. -Thank you for your help! ​ +Thank you for your help! diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 56b34308f1..6d07e15bc4 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,4 +1,8 @@ -# Setup private variables (only callable from within the contract) +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# example of a crowd funding contract funders: HashMap[address, uint256] beneficiary: address diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index be9c62b945..d385d2e0c6 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -9,7 +9,7 @@ invariant: public(uint256) token_address: ERC20 owner: public(address) -# Sets the on chain market maker with its owner, intial token quantity, +# Sets the on chain market maker with its owner, initial token quantity, # and initial ether quantity @external @payable diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index f1070b8f89..30057582e8 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,13 @@ +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + # @version >=0.3.4 """ -@dev Implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) +@dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) """ + ############### imports ############### from vyper.interfaces import ERC165 diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 4c1d334691..c3809dbb60 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,4 +1,8 @@ -# @dev Implementation of ERC-20 token standard. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! 
CAVEAT EMPTOR! +########################################################################### + +# @dev example implementation of an ERC20 token # @author Takayuki Jimba (@yudetamago) # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-20.md diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index a9cbcc86c8..0a0a698bf0 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,4 +1,11 @@ # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy + +# example implementation of an ERC4626 vault + +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + from vyper.interfaces import ERC20 from vyper.interfaces import ERC4626 diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index 5125040399..152b94b046 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,4 +1,8 @@ -# @dev Implementation of ERC-721 non-fungible token standard. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# @dev example implementation of ERC-721 non-fungible token standard. # @author Ryuya Nakamura (@nrryuya) # Modified from: https://github.com/vyperlang/vyper/blob/de74722bf2d8718cca46902be165f9fe0e3641dd/examples/tokens/ERC721.vy diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 5fd5229136..e2515d9e62 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,5 +1,8 @@ -# An example of how you can do a wallet in Vyper. -# Warning: NOT AUDITED. Do not use to store substantial quantities of funds. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# An example of how you can implement a wallet in Vyper. # A list of the owners addresses (there are a maximum of 5 owners) owners: public(address[5]) diff --git a/tests/compiler/venom/test_duplicate_operands.py b/tests/compiler/venom/test_duplicate_operands.py new file mode 100644 index 0000000000..505f01e31b --- /dev/null +++ b/tests/compiler/venom/test_duplicate_operands.py @@ -0,0 +1,28 @@ +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRFunction + + +def test_duplicate_operands(): + """ + Test the duplicate operands code generation. 
+ The venom code: + + %1 = 10 + %2 = add %1, %1 + %3 = mul %1, %2 + stop + + Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] + """ + ctx = IRFunction() + + op = ctx.append_instruction("store", [IRLiteral(10)]) + sum = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("mul", [sum, op]) + ctx.append_instruction("stop", [], False) + + asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) + + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] diff --git a/tests/compiler/venom/test_multi_entry_block.py b/tests/compiler/venom/test_multi_entry_block.py new file mode 100644 index 0000000000..bb57fa1065 --- /dev/null +++ b/tests/compiler/venom/test_multi_entry_block.py @@ -0,0 +1,96 @@ +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel +from vyper.venom.passes.normalization import NormalizationPass + + +def test_multi_entry_block_1(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + + +# more complicated one +def test_multi_entry_block_2(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + block_2_label = IRLabel("block_2", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, target_label, finish_label], False) + + block_2 = IRBasicBlock(block_2_label, ctx) + ctx.append_basic_block(block_2) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + # switch the order of the labels, for fun + ctx.append_instruction("jnz", [acc, 
finish_label, target_label], False)
+
+    target_bb = IRBasicBlock(target_label, ctx)
+    ctx.append_basic_block(target_bb)
+    ctx.append_instruction("mul", [acc, acc])
+    ctx.append_instruction("jmp", [finish_label], False)
+
+    finish_bb = IRBasicBlock(finish_label, ctx)
+    ctx.append_basic_block(finish_bb)
+    ctx.append_instruction("stop", [], False)
+
+    calculate_cfg(ctx)
+    assert not ctx.normalized, "CFG should not be normalized"
+
+    NormalizationPass.run_pass(ctx)
+
+    assert ctx.normalized, "CFG should be normalized"
+
+    finish_bb = ctx.get_basic_block(finish_label.value)
+    cfg_in = list(finish_bb.cfg_in.keys())
+    assert cfg_in[0].label.value == "target", "Should contain target"
+    assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global"
+    assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1"
diff --git a/tests/compiler/venom/test_stack_at_external_return.py b/tests/compiler/venom/test_stack_at_external_return.py
new file mode 100644
index 0000000000..be9fa66e9a
--- /dev/null
+++ b/tests/compiler/venom/test_stack_at_external_return.py
@@ -0,0 +1,5 @@
+def test_stack_at_external_return():
+    """
+    TODO: USE BOA TO GENERATE THIS TEST
+    """
+    pass
diff --git a/vyper/__main__.py b/vyper/__main__.py
index 371975c301..c5bda47bea 100644
--- a/vyper/__main__.py
+++ b/vyper/__main__.py
@@ -2,10 +2,10 @@
 # -*- coding: UTF-8 -*-
 import sys
 
-from vyper.cli import vyper_compile, vyper_ir, vyper_serve
+from vyper.cli import vyper_compile, vyper_ir
 
 if __name__ == "__main__":
-    allowed_subcommands = ("--vyper-compile", "--vyper-ir", "--vyper-serve")
+    allowed_subcommands = ("--vyper-compile", "--vyper-ir")
 
     if len(sys.argv) <= 1 or sys.argv[1] not in allowed_subcommands:
         # default (no args, no switch in first arg): run vyper_compile
@@ -13,9 +13,7 @@
     else:
         # pop switch and forward args to subcommand
         subcommand = sys.argv.pop(1)
-        if subcommand == "--vyper-serve":
-            vyper_serve._parse_cli_args()
-        elif subcommand == "--vyper-ir":
+        if subcommand == "--vyper-ir":
             vyper_ir._parse_cli_args()
         else:
             vyper_compile._parse_cli_args()
diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py
index 2802421129..a5949dfd85 100644
--- a/vyper/builtins/_signatures.py
+++ b/vyper/builtins/_signatures.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Dict
+from typing import Any, Optional
 
 from vyper.ast import nodes as vy_ast
 from vyper.ast.validation import validate_call_args
@@ -74,12 +74,14 @@ def decorator_fn(self, node, context):
     return decorator_fn
 
 
-class BuiltinFunction(VyperType):
+class BuiltinFunctionT(VyperType):
     _has_varargs = False
-    _kwargs: Dict[str, KwargSettings] = {}
+    _inputs: list[tuple[str, Any]] = []
+    _kwargs: dict[str, KwargSettings] = {}
+    _return_type: Optional[VyperType] = None
 
     # helper function to deal with TYPE_DEFINITIONs
-    def _validate_single(self, arg, expected_type):
+    def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None:
         # TODO using "TYPE_DEFINITION" is a kludge in derived classes,
         # refactor me.
if expected_type == "TYPE_DEFINITION": @@ -89,15 +91,15 @@ def _validate_single(self, arg, expected_type): else: validate_expected_type(arg, expected_type) - def _validate_arg_types(self, node): + def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates - expect_num_args = num_args + expect_num_args: Any = num_args if self._has_varargs: # note special meaning for -1 in validate_call_args API expect_num_args = (num_args, -1) - validate_call_args(node, expect_num_args, self._kwargs) + validate_call_args(node, expect_num_args, list(self._kwargs.keys())) for arg, (_, expected) in zip(node.args, self._inputs): self._validate_single(arg, expected) @@ -118,13 +120,12 @@ def _validate_arg_types(self, node): # ensures the type can be inferred exactly. get_exact_type_from_node(arg) - def fetch_call_return(self, node): + def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) - if self._return_type: - return self._return_type + return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] @@ -136,7 +137,7 @@ def infer_arg_types(self, node): ret.extend(get_exact_type_from_node(arg) for arg in varargs) return ret - def infer_kwarg_types(self, node): + def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]: return {i.arg: self._kwargs[i.arg].typ for i in node.keywords} def __repr__(self): diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7f9fa55bc7..c6237b681f 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -98,14 +98,14 @@ ) from ._convert import convert -from ._signatures import BuiltinFunction, process_inputs +from ._signatures import BuiltinFunctionT, process_inputs SHA256_ADDRESS = 2 SHA256_BASE_GAS = 60 SHA256_PER_WORD_GAS = 12 -class FoldedFunction(BuiltinFunction): +class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded # Since foldable builtin functions are not folded before semantics validation, @@ -113,7 +113,7 @@ class FoldedFunction(BuiltinFunction): _kwargable = True -class TypenameFoldedFunction(FoldedFunction): +class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. @@ -132,7 +132,7 @@ def infer_arg_types(self, node): return [input_typedef] -class Floor(BuiltinFunction): +class Floor(BuiltinFunctionT): _id = "floor" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -162,7 +162,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Ceil(BuiltinFunction): +class Ceil(BuiltinFunctionT): _id = "ceil" _inputs = [("value", DecimalT())] # TODO: maybe use int136? 
@@ -192,7 +192,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Convert(BuiltinFunction): +class Convert(BuiltinFunctionT): _id = "convert" def fetch_call_return(self, node): @@ -285,14 +285,13 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: # note: this and a lot of other builtins could be refactored to accept any uint type -class Slice(BuiltinFunction): +class Slice(BuiltinFunctionT): _id = "slice" _inputs = [ ("b", (BYTES32_T, BytesT.any(), StringT.any())), ("start", UINT256_T), ("length", UINT256_T), ] - _return_type = None def fetch_call_return(self, node): arg_type, _, _ = self.infer_arg_types(node) @@ -453,7 +452,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(b3.resolve(ret))) -class Len(BuiltinFunction): +class Len(BuiltinFunctionT): _id = "len" _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T @@ -484,7 +483,7 @@ def build_IR(self, node, context): return get_bytearray_length(arg) -class Concat(BuiltinFunction): +class Concat(BuiltinFunctionT): _id = "concat" def fetch_call_return(self, node): @@ -588,7 +587,7 @@ def build_IR(self, expr, context): ) -class Keccak256(BuiltinFunction): +class Keccak256(BuiltinFunctionT): _id = "keccak256" # TODO allow any BytesM_T _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] @@ -636,7 +635,7 @@ def _make_sha256_call(inp_start, inp_len, out_start, out_len): ] -class Sha256(BuiltinFunction): +class Sha256(BuiltinFunctionT): _id = "sha256" _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T @@ -708,7 +707,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class MethodID(FoldedFunction): +class MethodID(FoldedFunctionT): _id = "method_id" def evaluate(self, node): @@ -748,7 +747,7 @@ def infer_kwarg_types(self, node): return BytesT(4) -class ECRecover(BuiltinFunction): +class ECRecover(BuiltinFunctionT): _id = "ecrecover" _inputs = [ ("hash", BYTES32_T), @@ -783,7 +782,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class _ECArith(BuiltinFunction): +class _ECArith(BuiltinFunctionT): @process_inputs def build_IR(self, expr, _args, kwargs, context): args_tuple = ir_tuple_from_args(_args) @@ -842,14 +841,13 @@ def _storage_element_getter(index): return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T) -class Extract32(BuiltinFunction): +class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] # "TYPE_DEFINITION" is a placeholder value for a type definition string, and # will be replaced by a `TYPE_T` object in `infer_kwarg_types` # (note that it is ignored in _validate_arg_types) _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} - _return_type = None def fetch_call_return(self, node): self._validate_arg_types(node) @@ -954,7 +952,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(clamp_basetype(o), typ=ret_type) -class AsWeiValue(BuiltinFunction): +class AsWeiValue(BuiltinFunctionT): _id = "as_wei_value" _inputs = [("value", (IntegerT.any(), DecimalT())), ("unit", StringT.any())] _return_type = UINT256_T @@ -1053,7 +1051,7 @@ def build_IR(self, expr, args, kwargs, context): empty_value = IRnode.from_list(0, typ=BYTES32_T) -class RawCall(BuiltinFunction): +class RawCall(BuiltinFunctionT): _id = "raw_call" _inputs = [("to", AddressT()), ("data", BytesT.any())] _kwargs = { @@ -1064,7 +1062,6 @@ class 
RawCall(BuiltinFunction): "is_static_call": KwargSettings(BoolT(), False, require_literal=True), "revert_on_failure": KwargSettings(BoolT(), True, require_literal=True), } - _return_type = None def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1209,12 +1206,11 @@ def build_IR(self, expr, args, kwargs, context): raise CompilerPanic("unreachable!") -class Send(BuiltinFunction): +class Send(BuiltinFunctionT): _id = "send" _inputs = [("to", AddressT()), ("value", UINT256_T)] # default gas stipend is 0 _kwargs = {"gas": KwargSettings(UINT256_T, 0)} - _return_type = None @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -1226,10 +1222,9 @@ def build_IR(self, expr, args, kwargs, context): ) -class SelfDestruct(BuiltinFunction): +class SelfDestruct(BuiltinFunctionT): _id = "selfdestruct" _inputs = [("to", AddressT())] - _return_type = None _is_terminus = True _warned = False @@ -1245,7 +1240,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BlockHash(BuiltinFunction): +class BlockHash(BuiltinFunctionT): _id = "blockhash" _inputs = [("block_num", UINT256_T)] _return_type = BYTES32_T @@ -1258,7 +1253,7 @@ def build_IR(self, expr, args, kwargs, contact): ) -class RawRevert(BuiltinFunction): +class RawRevert(BuiltinFunctionT): _id = "raw_revert" _inputs = [("data", BytesT.any())] _return_type = None @@ -1280,7 +1275,7 @@ def build_IR(self, expr, args, kwargs, context): return b.resolve(IRnode.from_list(["revert", data, len_])) -class RawLog(BuiltinFunction): +class RawLog(BuiltinFunctionT): _id = "raw_log" _inputs = [("topics", DArrayT(BYTES32_T, 4)), ("data", (BYTES32_T, BytesT.any()))] @@ -1331,7 +1326,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BitwiseAnd(BuiltinFunction): +class BitwiseAnd(BuiltinFunctionT): _id = "bitwise_and" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1357,7 +1352,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["and", args[0], args[1]], typ=UINT256_T) -class BitwiseOr(BuiltinFunction): +class BitwiseOr(BuiltinFunctionT): _id = "bitwise_or" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1383,7 +1378,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["or", args[0], args[1]], typ=UINT256_T) -class BitwiseXor(BuiltinFunction): +class BitwiseXor(BuiltinFunctionT): _id = "bitwise_xor" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1409,7 +1404,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["xor", args[0], args[1]], typ=UINT256_T) -class BitwiseNot(BuiltinFunction): +class BitwiseNot(BuiltinFunctionT): _id = "bitwise_not" _inputs = [("x", UINT256_T)] _return_type = UINT256_T @@ -1436,7 +1431,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["not", args[0]], typ=UINT256_T) -class Shift(BuiltinFunction): +class Shift(BuiltinFunctionT): _id = "shift" _inputs = [("x", (UINT256_T, INT256_T)), ("_shift_bits", IntegerT.any())] _return_type = UINT256_T @@ -1490,7 +1485,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(IRnode.from_list(ret, typ=argty))) -class _AddMulMod(BuiltinFunction): +class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T @@ -1531,7 +1526,7 @@ class MulMod(_AddMulMod): _opcode = "mulmod" -class PowMod256(BuiltinFunction): +class PowMod256(BuiltinFunctionT): _id = "pow_mod256" _inputs = [("a", 
UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T @@ -1554,7 +1549,7 @@ def build_IR(self, expr, context): return IRnode.from_list(["exp", left, right], typ=left.typ) -class Abs(BuiltinFunction): +class Abs(BuiltinFunctionT): _id = "abs" _inputs = [("value", INT256_T)] _return_type = INT256_T @@ -1705,7 +1700,7 @@ def _create_preamble(codesize): return ["or", bytes_to_int(evm), shl(shl_bits, codesize)], evm_len -class _CreateBase(BuiltinFunction): +class _CreateBase(BuiltinFunctionT): _kwargs = { "value": KwargSettings(UINT256_T, zero_value), "salt": KwargSettings(BYTES32_T, empty_value), @@ -1934,7 +1929,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar return b1.resolve(b2.resolve(ir)) -class _UnsafeMath(BuiltinFunction): +class _UnsafeMath(BuiltinFunctionT): # TODO add unsafe math for `decimal`s _inputs = [("a", IntegerT.any()), ("b", IntegerT.any())] @@ -2000,7 +1995,7 @@ class UnsafeDiv(_UnsafeMath): op = "div" -class _MinMax(BuiltinFunction): +class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] def evaluate(self, node): @@ -2074,7 +2069,7 @@ class Max(_MinMax): _opcode = "gt" -class Uint2Str(BuiltinFunction): +class Uint2Str(BuiltinFunctionT): _id = "uint2str" _inputs = [("x", IntegerT.unsigneds())] @@ -2146,7 +2141,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, location=MEMORY, typ=return_t)) -class Sqrt(BuiltinFunction): +class Sqrt(BuiltinFunctionT): _id = "sqrt" _inputs = [("d", DecimalT())] _return_type = DecimalT() @@ -2202,7 +2197,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class ISqrt(BuiltinFunction): +class ISqrt(BuiltinFunctionT): _id = "isqrt" _inputs = [("d", UINT256_T)] _return_type = UINT256_T @@ -2252,7 +2247,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, typ=UINT256_T)) -class Empty(TypenameFoldedFunction): +class Empty(TypenameFoldedFunctionT): _id = "empty" def fetch_call_return(self, node): @@ -2267,7 +2262,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode("~empty", typ=output_type) -class Breakpoint(BuiltinFunction): +class Breakpoint(BuiltinFunctionT): _id = "breakpoint" _inputs: list = [] @@ -2285,7 +2280,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list("breakpoint", annotation="breakpoint()") -class Print(BuiltinFunction): +class Print(BuiltinFunctionT): _id = "print" _inputs: list = [] _has_varargs = True @@ -2363,7 +2358,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, annotation="print:" + sig) -class ABIEncode(BuiltinFunction): +class ABIEncode(BuiltinFunctionT): _id = "_abi_encode" # TODO prettier to rename this to abi.encode # signature: *, ensure_tuple= -> Bytes[] # explanation of ensure_tuple: @@ -2478,7 +2473,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, location=MEMORY, typ=buf_t) -class ABIDecode(BuiltinFunction): +class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} @@ -2565,7 +2560,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class _MinMaxValue(TypenameFoldedFunction): +class _MinMaxValue(TypenameFoldedFunctionT): def evaluate(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2599,7 +2594,7 @@ 
def _eval(self, type_):
         return type_.ast_bounds[1]
 
 
-class Epsilon(TypenameFoldedFunction):
+class Epsilon(TypenameFoldedFunctionT):
     _id = "epsilon"
 
     def evaluate(self, node):
diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py
index 82eba63f32..ca1792384e 100755
--- a/vyper/cli/vyper_compile.py
+++ b/vyper/cli/vyper_compile.py
@@ -141,6 +141,11 @@ def _parse_args(argv):
         "-p", help="Set the root path for contract imports", default=".", dest="root_folder"
     )
     parser.add_argument("-o", help="Set the output path", dest="output_path")
+    parser.add_argument(
+        "--experimental-codegen",
+        help="Use the new IR codegen. This is an experimental feature.",
+        action="store_true",
+    )
 
     args = parser.parse_args(argv)
 
@@ -188,6 +193,7 @@ def _parse_args(argv):
         settings,
         args.storage_layout,
         args.no_bytecode_metadata,
+        args.experimental_codegen,
     )
 
     if args.output_path:
@@ -225,6 +231,7 @@ def compile_files(
     settings: Optional[Settings] = None,
     storage_layout_paths: list[str] = None,
     no_bytecode_metadata: bool = False,
+    experimental_codegen: bool = False,
 ) -> dict:
     root_path = Path(root_folder).resolve()
     if not root_path.exists():
@@ -275,6 +282,7 @@ def compile_files(
             storage_layout_override=storage_layout_override,
             show_gas_estimates=show_gas_estimates,
             no_bytecode_metadata=no_bytecode_metadata,
+            experimental_codegen=experimental_codegen,
         )
 
         ret[file_path] = output
diff --git a/vyper/cli/vyper_serve.py b/vyper/cli/vyper_serve.py
deleted file mode 100755
index 9771dc922d..0000000000
--- a/vyper/cli/vyper_serve.py
+++ /dev/null
@@ -1,127 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import json
-import sys
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from socketserver import ThreadingMixIn
-
-import vyper
-from vyper.codegen import ir_node
-from vyper.evm.opcodes import DEFAULT_EVM_VERSION
-from vyper.exceptions import VyperException
-
-
-def _parse_cli_args():
-    return _parse_args(sys.argv[1:])
-
-
-def _parse_args(argv):
-    parser = argparse.ArgumentParser(description="Serve Vyper compiler as an HTTP Service")
-    parser.add_argument(
-        "--version", action="version", version=f"{vyper.__version__}+commit{vyper.__commit__}"
-    )
-    parser.add_argument(
-        "-b",
-        help="Address to bind JSON server on, default: localhost:8000",
-        default="localhost:8000",
-        dest="bind_address",
-    )
-
-    args = parser.parse_args(argv)
-
-    if ":" in args.bind_address:
-        ir_node.VYPER_COLOR_OUTPUT = False
-        runserver(*args.bind_address.split(":"))
-    else:
-        print('Provide bind address in "{address}:{port}" format')
-
-
-class VyperRequestHandler(BaseHTTPRequestHandler):
-    def send_404(self):
-        self.send_response(404)
-        self.end_headers()
-        return
-
-    def send_cors_all(self):
-        self.send_header("Access-Control-Allow-Origin", "*")
-        self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-type")
-
-    def do_OPTIONS(self):
-        self.send_response(200)
-        self.send_cors_all()
-        self.end_headers()
-
-    def do_GET(self):
-        if self.path == "/":
-            self.send_response(200)
-            self.send_cors_all()
-            self.end_headers()
-            self.wfile.write(f"Vyper Compiler. 
Version: {vyper.__version__}\n".encode()) - else: - self.send_404() - - return - - def do_POST(self): - if self.path == "/compile": - content_len = int(self.headers.get("content-length")) - post_body = self.rfile.read(content_len) - data = json.loads(post_body) - - response, status_code = self._compile(data) - - self.send_response(status_code) - self.send_header("Content-type", "application/json") - self.send_cors_all() - self.end_headers() - self.wfile.write(json.dumps(response).encode()) - - else: - self.send_404() - - return - - def _compile(self, data): - code = data.get("code") - if not code: - return {"status": "failed", "message": 'No "code" key supplied'}, 400 - if not isinstance(code, str): - return {"status": "failed", "message": '"code" must be a non-empty string'}, 400 - - try: - code = data["code"] - out_dict = vyper.compile_code( - code, - list(vyper.compiler.OUTPUT_FORMATS.keys()), - evm_version=data.get("evm_version", DEFAULT_EVM_VERSION), - ) - out_dict["ir"] = str(out_dict["ir"]) - out_dict["ir_runtime"] = str(out_dict["ir_runtime"]) - except VyperException as e: - return ( - {"status": "failed", "message": str(e), "column": e.col_offset, "line": e.lineno}, - 400, - ) - except SyntaxError as e: - return ( - {"status": "failed", "message": str(e), "column": e.offset, "line": e.lineno}, - 400, - ) - - out_dict.update({"status": "success"}) - - return out_dict, 200 - - -class VyperHTTPServer(ThreadingMixIn, HTTPServer): - """Handle requests in a separate thread.""" - - pass - - -def runserver(host="", port=8000): - server_address = (host, int(port)) - httpd = VyperHTTPServer(server_address, VyperRequestHandler) - print(f"Listening on http://{host}:{port}") - httpd.serve_forever() diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 1d24b6c6dd..c48f1256c3 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -162,5 +162,9 @@ def generate_ir_for_function( # (note: internal functions do not need to adjust gas estimate since mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore + ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore + ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + else: + ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore return ret diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 228191e3ca..cf01dbdab4 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -68,4 +68,6 @@ def generate_ir_for_internal_function( ["seq"] + nonreentrant_post + [["exit_to", "return_pc"]], ] - return IRnode.from_list(["seq", body, cleanup_routine]) + ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + ir_node.passthrough_metadata["func_t"] = func_t + return ir_node diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index ad4aa76437..ce26066968 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -171,6 +171,10 @@ class IRnode: valency: int args: List["IRnode"] value: Union[str, int] + is_self_call: bool + passthrough_metadata: dict[str, Any] + func_ir: Any + common_ir: Any def __init__( self, @@ -184,6 +188,8 @@ def __init__( mutable: bool = True, add_gas_estimate: int = 0, encoding: Encoding = 
Encoding.VYPER, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, ): if args is None: args = [] @@ -201,28 +207,28 @@ def __init__( self.add_gas_estimate = add_gas_estimate self.encoding = encoding self.as_hex = AS_HEX_DEFAULT + self.is_self_call = is_self_call + self.passthrough_metadata = passthrough_metadata or {} + self.func_ir = None + self.common_ir = None - def _check(condition, err): - if not condition: - raise CompilerPanic(str(err)) - - _check(self.value is not None, "None is not allowed as IRnode value") + assert self.value is not None, "None is not allowed as IRnode value" # Determine this node's valency (1 if it pushes a value on the stack, # 0 otherwise) and checks to make sure the number and valencies of # children are correct. Also, find an upper bound on gas consumption # Numbers if isinstance(self.value, int): - _check(len(self.args) == 0, "int can't have arguments") + assert len(self.args) == 0, "int can't have arguments" # integers must be in the range (MIN_INT256, MAX_UINT256) - _check(-(2**255) <= self.value < 2**256, "out of range") + assert -(2**255) <= self.value < 2**256, "out of range" self.valency = 1 self._gas = 5 elif isinstance(self.value, bytes): # a literal bytes value, probably inside a "data" node. - _check(len(self.args) == 0, "bytes can't have arguments") + assert len(self.args) == 0, "bytes can't have arguments" self.valency = 0 self._gas = 0 @@ -232,10 +238,9 @@ def _check(condition, err): if self.value.upper() in get_ir_opcodes(): _, ins, outs, gas = get_ir_opcodes()[self.value.upper()] self.valency = outs - _check( - len(self.args) == ins, - f"Number of arguments mismatched: {self.value} {self.args}", - ) + assert ( + len(self.args) == ins + ), f"Number of arguments mismatched: {self.value} {self.args}" # We add 2 per stack height at push time and take it back # at pop time; this makes `break` easier to handle self._gas = gas + 2 * (outs - ins) @@ -244,10 +249,10 @@ def _check(condition, err): # consumed for internal functions, therefore we whitelist this as a zero valency # allowed argument. 
zero_valency_whitelist = {"pass", "pop"} - _check( - arg.valency == 1 or arg.value in zero_valency_whitelist, - f"invalid argument to `{self.value}`: {arg}", - ) + assert ( + arg.valency == 1 or arg.value in zero_valency_whitelist + ), f"invalid argument to `{self.value}`: {arg}" + self._gas += arg.gas # Dynamic gas cost: 8 gas for each byte of logging data if self.value.upper()[0:3] == "LOG" and isinstance(self.args[1].value, int): @@ -275,30 +280,27 @@ def _check(condition, err): self._gas = self.args[0].gas + max(self.args[1].gas, self.args[2].gas) + 3 if len(self.args) == 2: self._gas = self.args[0].gas + self.args[1].gas + 17 - _check( - self.args[0].valency > 0, - f"zerovalent argument as a test to an if statement: {self.args[0]}", - ) - _check(len(self.args) in (2, 3), "if statement can only have 2 or 3 arguments") + assert ( + self.args[0].valency > 0 + ), f"zerovalent argument as a test to an if statement: {self.args[0]}" + assert len(self.args) in (2, 3), "if statement can only have 2 or 3 arguments" self.valency = self.args[1].valency # With statements: with elif self.value == "with": - _check(len(self.args) == 3, self) - _check( - len(self.args[0].args) == 0 and isinstance(self.args[0].value, str), - f"first argument to with statement must be a variable name: {self.args[0]}", - ) - _check( - self.args[1].valency == 1 or self.args[1].value == "pass", - f"zerovalent argument to with statement: {self.args[1]}", - ) + assert len(self.args) == 3, self + assert len(self.args[0].args) == 0 and isinstance( + self.args[0].value, str + ), f"first argument to with statement must be a variable name: {self.args[0]}" + assert ( + self.args[1].valency == 1 or self.args[1].value == "pass" + ), f"zerovalent argument to with statement: {self.args[1]}" self.valency = self.args[2].valency self._gas = sum([arg.gas for arg in self.args]) + 5 # Repeat statements: repeat elif self.value == "repeat": - _check( - len(self.args) == 5, "repeat(index_name, startval, rounds, rounds_bound, body)" - ) + assert ( + len(self.args) == 5 + ), "repeat(index_name, startval, rounds, rounds_bound, body)" counter_ptr = self.args[0] start = self.args[1] @@ -306,13 +308,12 @@ def _check(condition, err): repeat_bound = self.args[3] body = self.args[4] - _check( - isinstance(repeat_bound.value, int) and repeat_bound.value > 0, - f"repeat bound must be a compile-time positive integer: {self.args[2]}", - ) - _check(repeat_count.valency == 1, repeat_count) - _check(counter_ptr.valency == 1, counter_ptr) - _check(start.valency == 1, start) + assert ( + isinstance(repeat_bound.value, int) and repeat_bound.value > 0 + ), f"repeat bound must be a compile-time positive integer: {self.args[2]}" + assert repeat_count.valency == 1, repeat_count + assert counter_ptr.valency == 1, counter_ptr + assert start.valency == 1, start self.valency = 0 @@ -335,19 +336,17 @@ def _check(condition, err): # then JUMP to my_label. 
elif self.value in ("goto", "exit_to"): for arg in self.args: - _check( - arg.valency == 1 or arg.value == "pass", - f"zerovalent argument to goto {arg}", - ) + assert ( + arg.valency == 1 or arg.value == "pass" + ), f"zerovalent argument to goto {arg}" self.valency = 0 self._gas = sum([arg.gas for arg in self.args]) elif self.value == "label": - _check( - self.args[1].value == "var_list", - f"2nd argument to label must be var_list, {self}", - ) - _check(len(args) == 3, f"label should have 3 args but has {len(args)}, {self}") + assert ( + self.args[1].value == "var_list" + ), f"2nd argument to label must be var_list, {self}" + assert len(args) == 3, f"label should have 3 args but has {len(args)}, {self}" self.valency = 0 self._gas = 1 + sum(t.gas for t in self.args) elif self.value == "unique_symbol": @@ -371,14 +370,14 @@ def _check(condition, err): # Multi statements: multi ... elif self.value == "multi": for arg in self.args: - _check( - arg.valency > 0, f"Multi expects all children to not be zerovalent: {arg}" - ) + assert ( + arg.valency > 0 + ), f"Multi expects all children to not be zerovalent: {arg}" self.valency = sum([arg.valency for arg in self.args]) self._gas = sum([arg.gas for arg in self.args]) elif self.value == "deploy": self.valency = 0 - _check(len(self.args) == 3, f"`deploy` should have three args {self}") + assert len(self.args) == 3, f"`deploy` should have three args {self}" self._gas = NullAttractor() # unknown # Stack variables else: @@ -596,6 +595,8 @@ def from_list( error_msg: Optional[str] = None, mutable: bool = True, add_gas_estimate: int = 0, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, encoding: Encoding = Encoding.VYPER, ) -> "IRnode": if isinstance(typ, str): @@ -628,6 +629,8 @@ def from_list( source_pos=source_pos, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) else: return cls( @@ -641,4 +644,6 @@ def from_list( add_gas_estimate=add_gas_estimate, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index 56bea2b8da..41fa11ab56 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -40,7 +40,9 @@ def finalize(fill_return_buffer): cleanup_loops = "cleanup_repeat" if context.forvars else "seq" # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break - return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit]) + jump_to_exit_ir = IRnode.from_list(jump_to_exit) + jump_to_exit_ir.passthrough_metadata["func_t"] = func_t + return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: if context.is_internal: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index c320e6889c..f03f2eb9c8 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -121,4 +121,6 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True + o.passthrough_metadata["func_t"] = func_t + o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 62ea05b243..61d7a7c229 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -55,6 +55,7 @@ def compile_code( no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, 
exc_handler: Optional[Callable] = None, + experimental_codegen: bool = False, ) -> dict: """ Generate consumable compiler output(s) from a single contract source code. @@ -104,6 +105,7 @@ def compile_code( storage_layout_override, show_gas_estimates, no_bytecode_metadata, + experimental_codegen, ) ret = {} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index bfbb336d54..4e32812fee 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -16,6 +16,7 @@ from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT from vyper.typing import StorageLayout +from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") @@ -60,6 +61,7 @@ def __init__( storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> None: """ Initialization method. @@ -78,14 +80,18 @@ def __init__( Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool, optional + Use experimental codegen. Defaults to False """ + # to force experimental codegen, uncomment: + # experimental_codegen = True self.contract_path = contract_path self.source_code = source_code self.source_id = source_id self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - + self.experimental_codegen = experimental_codegen self.settings = settings or Settings() self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) @@ -160,7 +166,11 @@ def global_ctx(self) -> GlobalContext: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes(self.global_ctx, self.settings.optimize) + nodes = generate_ir_nodes(self.global_ctx, self.settings.optimize) + if self.experimental_codegen: + return [generate_ir(nodes[0]), generate_ir(nodes[1])] + else: + return nodes @property def ir_nodes(self) -> IRnode: @@ -183,11 +193,21 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - return generate_assembly(self.ir_nodes, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_nodes, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_nodes, self.settings.optimize) @cached_property def assembly_runtime(self) -> list: - return generate_assembly(self.ir_runtime, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_runtime, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_runtime, self.settings.optimize) @cached_property def bytecode(self) -> bytes: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 1c4dc1ef7c..1d3df8becb 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -9,6 +9,7 @@ from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic +from vyper.ir.optimizer import COMMUTATIVE_OPS from vyper.utils import MemoryPositions from vyper.version import version_tuple @@ -164,7 +165,7 @@ def _add_postambles(asm_ops): # insert the postambles *before* runtime code # so the data section of the runtime code can't bork the 
postambles. runtime = None - if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader): + if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], RuntimeHeader): runtime = asm_ops.pop() # for some reason there might not be a STOP at the end of asm_ops. @@ -229,7 +230,7 @@ def compile_to_assembly(code, optimize=OptimizationLevel.GAS): _relocate_segments(res) if optimize != OptimizationLevel.NONE: - _optimize_assembly(res) + optimize_assembly(res) return res @@ -531,7 +532,7 @@ def _height_of(witharg): # since the asm data structures are very primitive, to make sure # assembly_to_evm is able to calculate data offsets correctly, # we pass the memsize via magic opcodes to the subcode - subcode = [_RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode + subcode = [RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode # append the runtime code after the ctor code # `append(...)` call here is intentional. @@ -675,7 +676,7 @@ def _height_of(witharg): ) elif code.value == "data": - data_node = [_DataHeader("_sym_" + code.args[0].value)] + data_node = [DataHeader("_sym_" + code.args[0].value)] for c in code.args[1:]: if isinstance(c.value, int): @@ -837,6 +838,31 @@ def _prune_inefficient_jumps(assembly): return changed +def _optimize_inefficient_jumps(assembly): + # optimize sequences `_sym_common JUMPI _sym_x JUMP _sym_common JUMPDEST` + # to `ISZERO _sym_x JUMPI _sym_common JUMPDEST` + changed = False + i = 0 + while i < len(assembly) - 6: + if ( + is_symbol(assembly[i]) + and assembly[i + 1] == "JUMPI" + and is_symbol(assembly[i + 2]) + and assembly[i + 3] == "JUMP" + and assembly[i] == assembly[i + 4] + and assembly[i + 5] == "JUMPDEST" + ): + changed = True + assembly[i] = "ISZERO" + assembly[i + 1] = assembly[i + 2] + assembly[i + 2] = "JUMPI" + del assembly[i + 3 : i + 4] + else: + i += 1 + + return changed + + def _merge_jumpdests(assembly): # When we have multiple JUMPDESTs in a row, or when a JUMPDEST # is immediately followed by another JUMP, we can skip the @@ -938,7 +964,7 @@ def _prune_unused_jumpdests(assembly): used_jumpdests.add(assembly[i]) for item in assembly: - if isinstance(item, list) and isinstance(item[0], _DataHeader): + if isinstance(item, list) and isinstance(item[0], DataHeader): # add symbols used in data sections as they are likely # used for a jumptable. 
for t in item: @@ -961,6 +987,12 @@ def _stack_peephole_opts(assembly): changed = False i = 0 while i < len(assembly) - 2: + if assembly[i : i + 3] == ["DUP1", "SWAP2", "SWAP1"]: + changed = True + del assembly[i + 2] + assembly[i] = "SWAP1" + assembly[i + 1] = "DUP2" + continue # usually generated by with statements that return their input like # (with x (...x)) if assembly[i : i + 3] == ["DUP1", "SWAP1", "POP"]: @@ -975,16 +1007,22 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue + if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + changed = True + del assembly[i : i + 2] + if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: + changed = True + del assembly[i] i += 1 return changed # optimize assembly, in place -def _optimize_assembly(assembly): +def optimize_assembly(assembly): for x in assembly: - if isinstance(x, list) and isinstance(x[0], _RuntimeHeader): - _optimize_assembly(x) + if isinstance(x, list) and isinstance(x[0], RuntimeHeader): + optimize_assembly(x) for _ in range(1024): changed = False @@ -993,6 +1031,7 @@ def _optimize_assembly(assembly): changed |= _merge_iszero(assembly) changed |= _merge_jumpdests(assembly) changed |= _prune_inefficient_jumps(assembly) + changed |= _optimize_inefficient_jumps(assembly) changed |= _prune_unused_jumpdests(assembly) changed |= _stack_peephole_opts(assembly) @@ -1021,7 +1060,7 @@ def adjust_pc_maps(pc_maps, ofst): def _data_to_evm(assembly, symbol_map): ret = bytearray() - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big") @@ -1039,7 +1078,7 @@ def _data_to_evm(assembly, symbol_map): # predict what length of an assembly [data] node will be in bytecode def _length_of_data(assembly): ret = 0 - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): ret += SYMBOL_SIZE @@ -1055,7 +1094,7 @@ def _length_of_data(assembly): @dataclass -class _RuntimeHeader: +class RuntimeHeader: label: str ctor_mem_size: int immutables_len: int @@ -1065,7 +1104,7 @@ def __repr__(self): @dataclass -class _DataHeader: +class DataHeader: label: str def __repr__(self): @@ -1081,11 +1120,11 @@ def _relocate_segments(assembly): code_segments = [] for t in assembly: if isinstance(t, list): - if isinstance(t[0], _DataHeader): + if isinstance(t[0], DataHeader): data_segments.append(t) else: _relocate_segments(t) # recurse - assert isinstance(t[0], _RuntimeHeader) + assert isinstance(t[0], RuntimeHeader) code_segments.append(t) else: non_data_segments.append(t) @@ -1134,7 +1173,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat mem_ofst_size, ctor_mem_size = None, None max_mem_ofst = 0 for i, item in enumerate(assembly): - if isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + if isinstance(item, list) and isinstance(item[0], RuntimeHeader): assert runtime_code is None, "Multiple subcodes" assert ctor_mem_size is None @@ -1184,6 +1223,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat if is_symbol_map_indicator(assembly[i + 1]): # Don't increment pc as the symbol itself doesn't go into code if item in symbol_map: + print(assembly) raise CompilerPanic(f"duplicate jumpdest {item}") symbol_map[item] = pc @@ -1198,7 +1238,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat # [_OFST, _sym_foo, bar] 
-> PUSH2 (foo+bar) # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) pc -= 1 - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): # we are in initcode symbol_map[item[0].label] = pc # add source map for all items in the runtime map @@ -1209,10 +1249,10 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat pc += len(runtime_code) # grab lengths of data sections from the runtime for t in item: - if isinstance(t, list) and isinstance(t[0], _DataHeader): + if isinstance(t, list) and isinstance(t[0], DataHeader): data_section_lengths.append(_length_of_data(t)) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): symbol_map[item[0].label] = pc pc += _length_of_data(item) else: @@ -1285,9 +1325,9 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat ret.append(DUP_OFFSET + int(item[3:])) elif item[:4] == "SWAP": ret.append(SWAP_OFFSET + int(item[4:])) - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): ret.extend(runtime_code) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): ret.extend(_data_to_evm(item, symbol_map)) else: # pragma: no cover # unreachable diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 8df4bbac2d..79e02f041d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -440,6 +440,8 @@ def _optimize(node: IRnode, parent: Optional[IRnode]) -> Tuple[bool, IRnode]: error_msg = node.error_msg annotation = node.annotation add_gas_estimate = node.add_gas_estimate + is_self_call = node.is_self_call + passthrough_metadata = node.passthrough_metadata changed = False @@ -462,6 +464,8 @@ def finalize(val, args): error_msg=error_msg, annotation=annotation, add_gas_estimate=add_gas_estimate, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) if should_check_symbols: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 207e437dde..a856b1184f 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -95,7 +95,7 @@ def __init__( self.is_from_abi = is_from_abi # a list of internal functions this function calls - self.called_functions = OrderedSet() + self.called_functions = OrderedSet[ContractFunctionT]() # to be populated during codegen self._ir_info: Any = None diff --git a/vyper/utils.py b/vyper/utils.py index 3d9d9cb416..0a2e1f831f 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -6,12 +6,14 @@ import time import traceback import warnings -from typing import List, Union +from typing import Generic, List, TypeVar, Union from vyper.exceptions import DecimalOverrideException, InvalidLiteral +_T = TypeVar("_T") -class OrderedSet(dict): + +class OrderedSet(Generic[_T], dict[_T, None]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -20,9 +22,41 @@ class OrderedSet(dict): functionality as needed. 
""" - def add(self, item): + def __init__(self, iterable=None): + super().__init__() + if iterable is not None: + for item in iterable: + self.add(item) + + def __repr__(self): + keys = ", ".join(repr(k) for k in self.keys()) + return f"{{{keys}}}" + + def get(self, *args, **kwargs): + raise RuntimeError("can't call get() on OrderedSet!") + + def add(self, item: _T) -> None: self[item] = None + def remove(self, item: _T) -> None: + del self[item] + + def difference(self, other): + ret = self.copy() + for k in other.keys(): + if k in ret: + ret.remove(k) + return ret + + def union(self, other): + return self | other + + def __or__(self, other): + return self.__class__(super().__or__(other)) + + def copy(self): + return self.__class__(super().copy()) + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): @@ -436,3 +470,25 @@ def annotate_source_code( cleanup_lines += [""] * (num_lines - len(cleanup_lines)) return "\n".join(cleanup_lines) + + +def ir_pass(func): + """ + Decorator for IR passes. This decorator will run the pass repeatedly until + no more changes are made. + """ + + def wrapper(*args, **kwargs): + count = 0 + + while True: + changes = func(*args, **kwargs) or 0 + if isinstance(changes, list) or isinstance(changes, set): + changes = len(changes) + count += changes + if changes == 0: + break + + return count + + return wrapper diff --git a/vyper/venom/README.md b/vyper/venom/README.md new file mode 100644 index 0000000000..a81f6c0582 --- /dev/null +++ b/vyper/venom/README.md @@ -0,0 +1,162 @@ +## Venom - An Intermediate representation language for Vyper + +### Introduction + +Venom serves as the next-gen intermediate representation language specifically tailored for use with the Vyper smart contract compiler. Drawing inspiration from LLVM IR, Venom has been adapted to be simpler, and to be architected towards emitting code for stack-based virtual machines. Designed with a Single Static Assignment (SSA) form, Venom allows for sophisticated analysis and optimizations, while accommodating the idiosyncrasies of the EVM architecture. + +### Venom Form + +In Venom, values are denoted as strings commencing with the `'%'` character, referred to as variables. Variables can only be assigned to at declaration (they remain immutable post-assignment). Constants are represented as decimal numbers (hexadecimal may be added in the future). + +Reserved words include all the instruction opcodes and `'IRFunction'`, `'param'`, `'dbname'` and `'db'`. + +Any content following the `';'` character until the line end is treated as a comment. + +For instance, an example of incrementing a variable by one is as follows: + +```llvm +%sum = add %x, 1 ; Add one to x +``` + +Each instruction is identified by its opcode and a list of input operands. In cases where an instruction produces a result, it is stored in a new variable, as indicated on the left side of the assignment character. + +Code is organized into non-branching instruction blocks, known as _"Basic Blocks"_. Each basic block is defined by a label and contains its set of instructions. The final instruction of a basic block should either be a terminating instruction or a jump (conditional or unconditional) to other block(s). + +Basic blocks are grouped into _functions_ that are named and dictate the first block to execute. + +Venom employs two scopes: global and function level. 
+
+### Example code
+
+```llvm
+IRFunction: global
+
+global:
+    %1 = calldataload 0
+    %2 = shr 224, %1
+    jmp label %selector_bucket_0
+
+selector_bucket_0:
+    %3 = xor %2, 1579456981
+    %4 = iszero %3
+    jnz label %1, label %2, %4
+
+1: IN=[selector_bucket_0] OUT=[9]
+    jmp label %fallback
+
+2:
+    %5 = callvalue
+    %6 = calldatasize
+    %7 = lt %6, 164
+    %8 = or %5, %7
+    %9 = iszero %8
+    assert %9
+    stop
+
+fallback:
+    revert 0, 0
+```
+
+### Grammar
+
+Below is a (not-so-complete) grammar to describe the text format of Venom IR:
+
+```llvm
+program ::= function_declaration*
+
+function_declaration ::= "IRFunction:" identifier input_list? output_list? "=>" block
+
+input_list ::= "IN=" "[" (identifier ("," identifier)*)? "]"
+output_list ::= "OUT=" "[" (identifier ("," identifier)*)? "]"
+
+block ::= label ":" input_list? output_list? "=>{" operation* "}"
+
+operation ::= "%" identifier "=" opcode operand ("," operand)*
+            | opcode operand ("," operand)*
+
+opcode ::= "calldataload" | "shr" | "shl" | "and" | "add" | "codecopy" | "mload" | "jmp" | "xor" | "iszero" | "jnz" | "label" | "lt" | "or" | "assert" | "callvalue" | "calldatasize" | "alloca" | "calldatacopy" | "invoke" | "gt" | ...
+
+operand ::= "%" identifier | label | integer | "label" "%" identifier
+label ::= "%" identifier
+
+identifier ::= [a-zA-Z_][a-zA-Z0-9_]*
+integer ::= [0-9]+
+```
+
+## Implementation
+
+In the current implementation, the compiler has been extended with a new pass responsible for translating the original s-expr based IR into Venom. Subsequently, the generated Venom code undergoes processing by the actual Venom compiler, ultimately converting it to assembly code. That final assembly code is then passed to the original assembler of Vyper to produce the executable bytecode.
+
+Currently there is no implementation of the text format (that is, there is no front-end), although this is planned. At this time, Venom IR can only be constructed programmatically.
+
+## Architecture
+
+The Venom implementation is composed of several distinct passes that iteratively transform and optimize the Venom IR code until it reaches the assembly emitter, which produces the stack-based EVM assembly. The compiler is designed to be more-or-less pluggable, so passes can be written without too much knowledge of or dependency on other passes.
+
+These passes encompass generic transformations that streamline the code (such as dead code elimination and normalization), as well as those generating supplementary information about the code, like liveness analysis and control-flow graph (CFG) construction. Some passes may rely on the output of others, requiring a specific execution order. For instance, the code emitter expects the execution of a normalization pass preceding it, and this normalization pass, in turn, requires the augmentation of the Venom IR with control-flow information.
+
+The primary categories of passes are:
+
+- Transformation passes
+- Analysis/augmentation passes
+- Optimization passes
+
+## Currently implemented passes
+
+The Venom compiler currently implements the following passes.
+
+### Control Flow Graph calculation
+
+The compiler generates a fundamental data structure known as the Control Flow Graph (CFG). This graph illustrates the interconnections between basic blocks, serving as a foundational data structure upon which many subsequent passes depend.
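+
+As a rough illustration of the shape of this data structure, here is a minimal sketch (assuming an `IRFunction` named `ctx` has already been constructed) of computing and inspecting the graph:
+
+```python
+from vyper.venom.analysis import calculate_cfg
+
+calculate_cfg(ctx)  # populates cfg_in / cfg_out on every basic block
+
+for bb in ctx.basic_blocks:
+    preds = [pred.label for pred in bb.cfg_in]   # blocks which can jump here
+    succs = [succ.label for succ in bb.cfg_out]  # blocks this block can jump to
+    print(f"{bb.label}: IN={preds} OUT={succs}")
+```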
+
+### Data Flow Graph calculation
+
+To enable the compiler to analyze the movement of data through the code during execution, a specialized graph, the Dataflow Graph (DFG), is generated. The compiler inspects the code, determining where each variable is defined (in one location) and all the places where it is utilized.
+
+### Dataflow Transformation
+
+This pass depends on the DFG construction, and reorders variable declarations to try to reduce stack traffic during instruction selection.
+
+### Liveness analysis
+
+This pass conducts a dataflow analysis, utilizing information from previous passes to identify variables that are live at each instruction in the Venom IR code. A variable is deemed live at a particular instruction if it holds a value necessary for future operations. Variables that are live only at their own assignment instruction are identified here and then eliminated by the dead code elimination pass.
+
+### Dead code elimination
+
+This pass eliminates all basic blocks that are not reachable from any other basic block, leveraging the CFG.
+
+### Normalization
+
+A Venom program may feature basic blocks with multiple CFG inputs and outputs. This currently can occur when multiple blocks conditionally direct control to the same target basic block. We define a Venom IR as "normalized" when it contains no basic blocks that have multiple inputs and outputs. The normalization pass is responsible for converting any Venom IR program to its normalized form. EVM assembly emission operates solely on normalized Venom programs, because the stack layout is not well defined for non-normalized basic blocks.
+
+### Code emission
+
+This final pass of the compiler emits EVM assembly recognized by Vyper's assembler. It calculates the desired stack layout for every basic block, schedules items on the stack and selects instructions. It ensures that deploy code, runtime code, and data segments are arranged according to the assembler's expectations.
+
+## Future planned passes
+
+A number of passes are planned to be implemented, or are slated for immediately after the initial PR merge; they are listed below.
+
+### Constant folding
+
+### Instruction combination
+
+### Dead store elimination
+
+### Scalar evolution
+
+### Loop invariant code motion
+
+### Loop unrolling
+
+### Code sinking
+
+### Expression reassociation
+
+### Stack to mem
+
+### Mem to stack
+
+### Function inlining
+
+### Load-store elimination
diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py
new file mode 100644
index 0000000000..5a09f8378e
--- /dev/null
+++ b/vyper/venom/__init__.py
@@ -0,0 +1,56 @@
+# maybe rename this `main.py` or `venom.py`
+# (can have an `__init__.py` which exposes the API).
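+#
+# rough sketch of the intended entry points (assuming an `IRnode` called `ir`
+# coming out of the legacy pipeline; `generate_ir` and
+# `generate_assembly_experimental` are the functions defined below):
+#
+#     ctx = generate_ir(ir, OptimizationLevel.GAS)
+#     asm = generate_assembly_experimental(ctx)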
+ +from typing import Optional + +from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel +from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness +from vyper.venom.bb_optimizer import ( + ir_pass_optimize_empty_blocks, + ir_pass_optimize_unused_variables, + ir_pass_remove_unreachable_blocks, +) +from vyper.venom.function import IRFunction +from vyper.venom.ir_node_to_venom import convert_ir_basicblock +from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation +from vyper.venom.passes.dft import DFTPass +from vyper.venom.venom_to_assembly import VenomCompiler + + +def generate_assembly_experimental( + ctx: IRFunction, optimize: Optional[OptimizationLevel] = None +) -> list[str]: + compiler = VenomCompiler(ctx) + return compiler.generate_evm(optimize is OptimizationLevel.NONE) + + +def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRFunction: + # Convert "old" IR to "new" IR + ctx = convert_ir_basicblock(ir) + + # Run passes on "new" IR + # TODO: Add support for optimization levels + while True: + changes = 0 + + changes += ir_pass_optimize_empty_blocks(ctx) + changes += ir_pass_remove_unreachable_blocks(ctx) + + calculate_liveness(ctx) + + changes += ir_pass_optimize_unused_variables(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + changes += ir_pass_constant_propagation(ctx) + changes += DFTPass.run_pass(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + if changes == 0: + break + + return ctx diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py new file mode 100644 index 0000000000..5980e21028 --- /dev/null +++ b/vyper/venom/analysis.py @@ -0,0 +1,191 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import ( + BB_TERMINATORS, + CFG_ALTERING_OPS, + IRBasicBlock, + IRInstruction, + IRVariable, +) +from vyper.venom.function import IRFunction + + +def calculate_cfg(ctx: IRFunction) -> None: + """ + Calculate (cfg) inputs for each basic block. + """ + for bb in ctx.basic_blocks: + bb.cfg_in = OrderedSet() + bb.cfg_out = OrderedSet() + bb.out_vars = OrderedSet() + + # TODO: This is a hack to support the old IR format where `deploy` is + # an instruction. in the future we should have two entry points, one + # for the initcode and one for the runtime code. + deploy_bb = None + after_deploy_bb = None + for i, bb in enumerate(ctx.basic_blocks): + if bb.instructions[0].opcode == "deploy": + deploy_bb = bb + after_deploy_bb = ctx.basic_blocks[i + 1] + break + + if deploy_bb is not None: + assert after_deploy_bb is not None, "No block after deploy block" + entry_block = after_deploy_bb + has_constructor = ctx.basic_blocks[0].instructions[0].opcode != "deploy" + if has_constructor: + deploy_bb.add_cfg_in(ctx.basic_blocks[0]) + entry_block.add_cfg_in(deploy_bb) + else: + entry_block = ctx.basic_blocks[0] + + # TODO: Special case for the jump table of selector buckets and fallback. + # this will be cleaner when we introduce an "indirect jump" instruction + # for the selector table (which includes all possible targets). it will + # also clean up the code for normalization because it will not have to + # handle this case specially. 
+ for bb in ctx.basic_blocks: + if "selector_bucket_" in bb.label.value or bb.label.value == "fallback": + bb.add_cfg_in(entry_block) + + for bb in ctx.basic_blocks: + assert len(bb.instructions) > 0, "Basic block should not be empty" + last_inst = bb.instructions[-1] + assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" + + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_OPS: + ops = inst.get_label_operands() + for op in ops: + ctx.get_basic_block(op.value).add_cfg_in(bb) + + # Fill in the "out" set for each basic block + for bb in ctx.basic_blocks: + for in_bb in bb.cfg_in: + in_bb.add_cfg_out(bb) + + +def _reset_liveness(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + for inst in bb.instructions: + inst.liveness = OrderedSet() + + +def _calculate_liveness_bb(bb: IRBasicBlock) -> None: + """ + Compute liveness of each instruction in the basic block. + """ + liveness = bb.out_vars.copy() + for instruction in reversed(bb.instructions): + ops = instruction.get_inputs() + + for op in ops: + if op in liveness: + instruction.dup_requirements.add(op) + + liveness = liveness.union(OrderedSet.fromkeys(ops)) + out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None + if out in liveness: + liveness.remove(out) + instruction.liveness = liveness + + +def _calculate_liveness_r(bb: IRBasicBlock, visited: dict) -> None: + assert isinstance(visited, dict) + for out_bb in bb.cfg_out: + if visited.get(bb) == out_bb: + continue + visited[bb] = out_bb + + # recurse + _calculate_liveness_r(out_bb, visited) + + target_vars = input_vars_from(bb, out_bb) + + # the output stack layout for bb. it produces a stack layout + # which works for all possible cfg_outs from the bb. + bb.out_vars = bb.out_vars.union(target_vars) + + _calculate_liveness_bb(bb) + + +def calculate_liveness(ctx: IRFunction) -> None: + _reset_liveness(ctx) + _calculate_liveness_r(ctx.basic_blocks[0], dict()) + + +# calculate the input variables into self from source +def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: + liveness = target.instructions[0].liveness.copy() + assert isinstance(liveness, OrderedSet) + + for inst in target.instructions: + if inst.opcode == "phi": + # we arbitrarily choose one of the arguments to be in the + # live variables set (dependent on how we traversed into this + # basic block). the argument will be replaced by the destination + # operand during instruction selection. 
+ # for instance, `%56 = phi %label1 %12 %label2 %14` + # will arbitrarily choose either %12 or %14 to be in the liveness + # set, and then during instruction selection, after this instruction, + # %12 will be replaced by %56 in the liveness set + source1, source2 = inst.operands[0], inst.operands[2] + phi1, phi2 = inst.operands[1], inst.operands[3] + if source.label == source1: + liveness.add(phi1) + if phi2 in liveness: + liveness.remove(phi2) + elif source.label == source2: + liveness.add(phi2) + if phi1 in liveness: + liveness.remove(phi1) + else: + # bad path into this phi node + raise CompilerPanic(f"unreachable: {inst}") + + return liveness + + +# DataFlow Graph +# this could be refactored into its own file, but it's only used here +# for now +class DFG: + _dfg_inputs: dict[IRVariable, list[IRInstruction]] + _dfg_outputs: dict[IRVariable, IRInstruction] + + def __init__(self): + self._dfg_inputs = dict() + self._dfg_outputs = dict() + + # return uses of a given variable + def get_uses(self, op: IRVariable) -> list[IRInstruction]: + return self._dfg_inputs.get(op, []) + + # the instruction which produces this variable. + def get_producing_instruction(self, op: IRVariable) -> IRInstruction: + return self._dfg_outputs[op] + + @classmethod + def build_dfg(cls, ctx: IRFunction) -> "DFG": + dfg = cls() + + # Build DFG + + # %15 = add %13 %14 + # %16 = iszero %15 + # dfg_outputs of %15 is (%15 = add %13 %14) + # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] + for bb in ctx.basic_blocks: + for inst in bb.instructions: + operands = inst.get_inputs() + res = inst.get_outputs() + + for op in operands: + inputs = dfg._dfg_inputs.setdefault(op, []) + inputs.append(inst) + + for op in res: # type: ignore + dfg._dfg_outputs[op] = inst + + return dfg diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py new file mode 100644 index 0000000000..b95d7416ca --- /dev/null +++ b/vyper/venom/basicblock.py @@ -0,0 +1,345 @@ +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Iterator, Optional + +from vyper.utils import OrderedSet + +# instructions which can terminate a basic block +BB_TERMINATORS = frozenset(["jmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) + +VOLATILE_INSTRUCTIONS = frozenset( + [ + "param", + "alloca", + "call", + "staticcall", + "invoke", + "sload", + "sstore", + "iload", + "istore", + "assert", + "mstore", + "mload", + "calldatacopy", + "codecopy", + "dloadbytes", + "dload", + "return", + "ret", + "jmp", + "jnz", + ] +) + +CFG_ALTERING_OPS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) + + +if TYPE_CHECKING: + from vyper.venom.function import IRFunction + + +class IRDebugInfo: + """ + IRDebugInfo represents debug information in IR, used to annotate IR instructions + with source code information when printing IR. + """ + + line_no: int + src: str + + def __init__(self, line_no: int, src: str) -> None: + self.line_no = line_no + self.src = src + + def __repr__(self) -> str: + src = self.src if self.src else "" + return f"\t# line {self.line_no}: {src}".expandtabs(20) + + +class IROperand: + """ + IROperand represents an operand in IR. An operand is anything that can + be an argument to an IRInstruction + """ + + value: Any + + +class IRValue(IROperand): + """ + IRValue represents a value in IR. A value is anything that can be + operated by non-control flow instructions. That is, IRValues can be + IRVariables or IRLiterals. 
+ """ + + pass + + +class IRLiteral(IRValue): + """ + IRLiteral represents a literal in IR + """ + + value: int + + def __init__(self, value: int) -> None: + assert isinstance(value, str) or isinstance(value, int), "value must be an int" + self.value = value + + def __repr__(self) -> str: + return str(self.value) + + +class MemType(Enum): + OPERAND_STACK = auto() + MEMORY = auto() + + +class IRVariable(IRValue): + """ + IRVariable represents a variable in IR. A variable is a string that starts with a %. + """ + + value: str + offset: int = 0 + + # some variables can be in memory for conversion from legacy IR to venom + mem_type: MemType = MemType.OPERAND_STACK + mem_addr: Optional[int] = None + + def __init__( + self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None + ) -> None: + assert isinstance(value, str) + self.value = value + self.offset = 0 + self.mem_type = mem_type + self.mem_addr = mem_addr + + def __repr__(self) -> str: + return self.value + + +class IRLabel(IROperand): + """ + IRLabel represents a label in IR. A label is a string that starts with a %. + """ + + # is_symbol is used to indicate if the label came from upstream + # (like a function name, try to preserve it in optimization passes) + is_symbol: bool = False + value: str + + def __init__(self, value: str, is_symbol: bool = False) -> None: + assert isinstance(value, str), "value must be an str" + self.value = value + self.is_symbol = is_symbol + + def __repr__(self) -> str: + return self.value + + +class IRInstruction: + """ + IRInstruction represents an instruction in IR. Each instruction has an opcode, + operands, and return value. For example, the following IR instruction: + %1 = add %0, 1 + has opcode "add", operands ["%0", "1"], and return value "%1". + + Convention: the rightmost value is the top of the stack. + """ + + opcode: str + volatile: bool + operands: list[IROperand] + output: Optional[IROperand] + # set of live variables at this instruction + liveness: OrderedSet[IRVariable] + dup_requirements: OrderedSet[IRVariable] + parent: Optional["IRBasicBlock"] + fence_id: int + annotation: Optional[str] + + def __init__( + self, + opcode: str, + operands: list[IROperand] | Iterator[IROperand], + output: Optional[IROperand] = None, + ): + assert isinstance(opcode, str), "opcode must be an str" + assert isinstance(operands, list | Iterator), "operands must be a list" + self.opcode = opcode + self.volatile = opcode in VOLATILE_INSTRUCTIONS + self.operands = [op for op in operands] # in case we get an iterator + self.output = output + self.liveness = OrderedSet() + self.dup_requirements = OrderedSet() + self.parent = None + self.fence_id = -1 + self.annotation = None + + def get_label_operands(self) -> list[IRLabel]: + """ + Get all labels in instruction. + """ + return [op for op in self.operands if isinstance(op, IRLabel)] + + def get_non_label_operands(self) -> list[IROperand]: + """ + Get input operands for instruction which are not labels + """ + return [op for op in self.operands if not isinstance(op, IRLabel)] + + def get_inputs(self) -> list[IRVariable]: + """ + Get all input operands for instruction. + """ + return [op for op in self.operands if isinstance(op, IRVariable)] + + def get_outputs(self) -> list[IROperand]: + """ + Get the output item for an instruction. 
+ (Currently all instructions output at most one item, but write + it as a list to be generic for the future) + """ + return [self.output] if self.output else [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + replacements are represented using a dict: "key" is replaced by "value". + """ + for i, operand in enumerate(self.operands): + if operand in replacements: + self.operands[i] = replacements[operand] + + def __repr__(self) -> str: + s = "" + if self.output: + s += f"{self.output} = " + opcode = f"{self.opcode} " if self.opcode != "store" else "" + s += opcode + operands = ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + ) + s += operands + + if self.annotation: + s += f" <{self.annotation}>" + + # if self.liveness: + # return f"{s: <30} # {self.liveness}" + + return s + + +class IRBasicBlock: + """ + IRBasicBlock represents a basic block in IR. Each basic block has a label and + a list of instructions, while belonging to a function. + + The following IR code: + %1 = add %0, 1 + %2 = mul %1, 2 + is represented as: + bb = IRBasicBlock("bb", function) + bb.append_instruction(IRInstruction("add", ["%0", "1"], "%1")) + bb.append_instruction(IRInstruction("mul", ["%1", "2"], "%2")) + + The label of a basic block is used to refer to it from other basic blocks + in order to branch to it. + + The parent of a basic block is the function it belongs to. + + The instructions of a basic block are executed sequentially, and the last + instruction of a basic block is always a terminator instruction, which is + used to branch to other basic blocks. + """ + + label: IRLabel + parent: "IRFunction" + instructions: list[IRInstruction] + # basic blocks which can jump to this basic block + cfg_in: OrderedSet["IRBasicBlock"] + # basic blocks which this basic block can jump to + cfg_out: OrderedSet["IRBasicBlock"] + # stack items which this basic block produces + out_vars: OrderedSet[IRVariable] + + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: + assert isinstance(label, IRLabel), "label must be an IRLabel" + self.label = label + self.parent = parent + self.instructions = [] + self.cfg_in = OrderedSet() + self.cfg_out = OrderedSet() + self.out_vars = OrderedSet() + + def add_cfg_in(self, bb: "IRBasicBlock") -> None: + self.cfg_in.add(bb) + + def remove_cfg_in(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_in + self.cfg_in.remove(bb) + + def add_cfg_out(self, bb: "IRBasicBlock") -> None: + # malformed: jnz condition label1 label1 + # (we could handle but it makes a lot of code easier + # if we have this assumption) + self.cfg_out.add(bb) + + def remove_cfg_out(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_out + self.cfg_out.remove(bb) + + @property + def is_reachable(self) -> bool: + return len(self.cfg_in) > 0 + + def append_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.append(instruction) + + def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.insert(index, instruction) + + def clear_instructions(self) -> None: + self.instructions = [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. 
+ """ + for instruction in self.instructions: + instruction.replace_operands(replacements) + + @property + def is_terminated(self) -> bool: + """ + Check if the basic block is terminal, i.e. the last instruction is a terminator. + """ + # it's ok to return False here, since we use this to check + # if we can/need to append instructions to the basic block. + if len(self.instructions) == 0: + return False + return self.instructions[-1].opcode in BB_TERMINATORS + + def copy(self): + bb = IRBasicBlock(self.label, self.parent) + bb.instructions = self.instructions.copy() + bb.cfg_in = self.cfg_in.copy() + bb.cfg_out = self.cfg_out.copy() + bb.out_vars = self.out_vars.copy() + return bb + + def __repr__(self) -> str: + s = ( + f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" + f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars} \n" + ) + for instruction in self.instructions: + s += f" {instruction}\n" + return s diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py new file mode 100644 index 0000000000..620ee66d15 --- /dev/null +++ b/vyper/venom/bb_optimizer.py @@ -0,0 +1,73 @@ +from vyper.utils import ir_pass +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRInstruction, IRLabel +from vyper.venom.function import IRFunction + + +def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]: + """ + Remove unused variables. + """ + removeList = set() + for bb in ctx.basic_blocks: + for i, inst in enumerate(bb.instructions[:-1]): + if inst.volatile: + continue + if inst.output and inst.output not in bb.instructions[i + 1].liveness: + removeList.add(inst) + + bb.instructions = [inst for inst in bb.instructions if inst not in removeList] + + return removeList + + +def _optimize_empty_basicblocks(ctx: IRFunction) -> int: + """ + Remove empty basic blocks. + """ + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if len(bb.instructions) > 0: + continue + + replaced_label = bb.label + replacement_label = ctx.basic_blocks[i].label if i < len(ctx.basic_blocks) else None + if replacement_label is None: + continue + + # Try to preserve symbol labels + if replaced_label.is_symbol: + replaced_label, replacement_label = replacement_label, replaced_label + ctx.basic_blocks[i].label = replacement_label + + for bb2 in ctx.basic_blocks: + for inst in bb2.instructions: + for op in inst.operands: + if isinstance(op, IRLabel) and op.value == replaced_label.value: + op.value = replacement_label.value + + ctx.basic_blocks.remove(bb) + i -= 1 + count += 1 + + return count + + +@ir_pass +def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: + changes = _optimize_empty_basicblocks(ctx) + calculate_cfg(ctx) + return changes + + +@ir_pass +def ir_pass_remove_unreachable_blocks(ctx: IRFunction) -> int: + return ctx.remove_unreachable_blocks() + + +@ir_pass +def ir_pass_optimize_unused_variables(ctx: IRFunction) -> int: + return len(_optimize_unused_variables(ctx)) diff --git a/vyper/venom/function.py b/vyper/venom/function.py new file mode 100644 index 0000000000..c14ad77345 --- /dev/null +++ b/vyper/venom/function.py @@ -0,0 +1,170 @@ +from typing import Optional + +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IROperand, + IRVariable, + MemType, +) + +GLOBAL_LABEL = IRLabel("global") + + +class IRFunction: + """ + Function that contains basic blocks. 
+ """ + + name: IRLabel # symbol name + args: list + basic_blocks: list[IRBasicBlock] + data_segment: list[IRInstruction] + last_label: int + last_variable: int + + def __init__(self, name: IRLabel = None) -> None: + if name is None: + name = GLOBAL_LABEL + self.name = name + self.args = [] + self.basic_blocks = [] + self.data_segment = [] + self.last_label = 0 + self.last_variable = 0 + + self.append_basic_block(IRBasicBlock(name, self)) + + def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + """ + Append basic block to function. + """ + assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" + self.basic_blocks.append(bb) + + # TODO add sanity check somewhere that basic blocks have unique labels + + return self.basic_blocks[-1] + + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: + """ + Get basic block by label. + If label is None, return the last basic block. + """ + if label is None: + return self.basic_blocks[-1] + for bb in self.basic_blocks: + if bb.label.value == label: + return bb + raise AssertionError(f"Basic block '{label}' not found") + + def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + """ + Get basic block after label. + """ + for i, bb in enumerate(self.basic_blocks[:-1]): + if bb.label.value == label.value: + return self.basic_blocks[i + 1] + raise AssertionError(f"Basic block after '{label}' not found") + + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: + """ + Get basic blocks that contain label. + """ + return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] + + def get_next_label(self) -> IRLabel: + self.last_label += 1 + return IRLabel(f"{self.last_label}") + + def get_next_variable( + self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None + ) -> IRVariable: + self.last_variable += 1 + return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + + def get_last_variable(self) -> str: + return f"%{self.last_variable}" + + def remove_unreachable_blocks(self) -> int: + removed = 0 + new_basic_blocks = [] + for bb in self.basic_blocks: + if not bb.is_reachable and bb.label.value != "global": + removed += 1 + else: + new_basic_blocks.append(bb) + self.basic_blocks = new_basic_blocks + return removed + + def append_instruction( + self, opcode: str, args: list[IROperand], do_ret: bool = True + ) -> Optional[IRVariable]: + """ + Append instruction to last basic block. + """ + ret = self.get_next_variable() if do_ret else None + inst = IRInstruction(opcode, args, ret) # type: ignore + self.get_basic_block().append_instruction(inst) + return ret + + def append_data(self, opcode: str, args: list[IROperand]) -> None: + """ + Append data + """ + self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + + @property + def normalized(self) -> bool: + """ + Check if function is normalized. A function is normalized if in the + CFG, no basic block simultaneously has multiple inputs and outputs. + That is, a basic block can be jumped to *from* multiple blocks, or it + can jump *to* multiple blocks, but it cannot simultaneously do both. + Having a normalized CFG makes calculation of stack layout easier when + emitting assembly. 
+ """ + for bb in self.basic_blocks: + # Ignore if there are no multiple predecessors + if len(bb.cfg_in) <= 1: + continue + + # Check if there is a conditional jump at the end + # of one of the predecessors + # + # TODO: this check could be: + # `if len(in_bb.cfg_out) > 1: return False` + # but the cfg is currently not calculated "correctly" for + # certain special instructions (deploy instruction and + # selector table indirect jumps). + for in_bb in bb.cfg_in: + jump_inst = in_bb.instructions[-1] + if jump_inst.opcode != "jnz": + continue + if jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRLabel): + continue + + # The function is not normalized + return False + + # The function is normalized + return True + + def copy(self): + new = IRFunction(self.name) + new.basic_blocks = self.basic_blocks.copy() + new.data_segment = self.data_segment.copy() + new.last_label = self.last_label + new.last_variable = self.last_variable + return new + + def __repr__(self) -> str: + str = f"IRFunction: {self.name}\n" + for bb in self.basic_blocks: + str += f"{bb}\n" + if len(self.data_segment) > 0: + str += "Data segment:\n" + for inst in self.data_segment: + str += f"{inst}\n" + return str diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py new file mode 100644 index 0000000000..19bd5c8b73 --- /dev/null +++ b/vyper/venom/ir_node_to_venom.py @@ -0,0 +1,943 @@ +from typing import Optional + +from vyper.codegen.context import VariableRecord +from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import get_opcodes +from vyper.exceptions import CompilerPanic +from vyper.ir.compile_ir import is_mem_sym, is_symbol +from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction + +_BINARY_IR_INSTRUCTIONS = frozenset( + [ + "eq", + "gt", + "lt", + "slt", + "sgt", + "shr", + "shl", + "or", + "xor", + "and", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "sha3", + "sha3_64", + "signextend", + ] +) + +# Instuctions that are mapped to their inverse +INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} + +# Instructions that have a direct EVM opcode equivalent and can +# be passed through to the EVM assembly without special handling +PASS_THROUGH_INSTRUCTIONS = [ + "chainid", + "basefee", + "timestamp", + "caller", + "selfbalance", + "calldatasize", + "callvalue", + "address", + "origin", + "codesize", + "gas", + "gasprice", + "gaslimit", + "returndatasize", + "coinbase", + "number", + "iszero", + "ceil32", + "calldataload", + "extcodesize", + "extcodehash", + "balance", +] + +SymbolTable = dict[str, IROperand] + + +def _get_symbols_common(a: dict, b: dict) -> dict: + ret = {} + # preserves the ordering in `a` + for k in a.keys(): + if k not in b: + continue + if a[k] == b[k]: + continue + ret[k] = a[k], b[k] + return ret + + +def convert_ir_basicblock(ir: IRnode) -> IRFunction: + global_function = IRFunction() + _convert_ir_basicblock(global_function, ir, {}, OrderedSet(), {}) + + for i, bb in enumerate(global_function.basic_blocks): + if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: + bb.append_instruction(IRInstruction("jmp", [global_function.basic_blocks[i + 1].label])) + + revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) + revert_bb = 
global_function.append_basic_block(revert_bb) + revert_bb.append_instruction(IRInstruction("revert", [IRLiteral(0), IRLiteral(0)])) + + return global_function + + +def _convert_binary_op( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], + swap: bool = False, +) -> IRVariable: + ir_args = ir.args[::-1] if swap else ir.args + arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) + args = [arg_1, arg_0] + + ret = ctx.get_next_variable() + + inst = IRInstruction(ir.value, args, ret) # type: ignore + ctx.get_basic_block().append_instruction(inst) + return ret + + +def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + + label = ctx.get_next_label() + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + + +def _new_block(ctx: IRFunction) -> IRBasicBlock: + bb = IRBasicBlock(ctx.get_next_label(), ctx) + bb = ctx.append_basic_block(bb) + return bb + + +def _handle_self_call( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + func_t = ir.passthrough_metadata.get("func_t", None) + args_ir = ir.passthrough_metadata["args_ir"] + goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] + target_label = goto_ir.args[0].value # goto + return_buf = goto_ir.args[1] # return buffer + ret_args = [IRLabel(target_label)] # type: ignore + + for arg in args_ir: + if arg.is_literal: + sym = symbols.get(f"&{arg.value}", None) + if sym is None: + ret = _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ret_args.append(ret) + else: + ret_args.append(sym) # type: ignore + else: + ret = _convert_ir_basicblock( + ctx, arg._optimized, symbols, variables, allocated_variables + ) + if arg.location and arg.location.load_op == "calldataload": + ret = ctx.append_instruction(arg.location.load_op, [ret]) + ret_args.append(ret) + + if return_buf.is_literal: + ret_args.append(IRLiteral(return_buf.value)) # type: ignore + + do_ret = func_t.return_type is not None + invoke_ret = ctx.append_instruction("invoke", ret_args, do_ret) # type: ignore + allocated_variables["return_buffer"] = invoke_ret # type: ignore + return invoke_ret + + +def _handle_internal_func( + ctx: IRFunction, ir: IRnode, func_t: ContractFunctionT, symbols: SymbolTable +) -> IRnode: + bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore + bb = ctx.append_basic_block(bb) + + old_ir_mempos = 0 + old_ir_mempos += 64 + + for arg in func_t.arguments: + new_var = ctx.get_next_variable() + + alloca_inst = IRInstruction("param", [], new_var) + alloca_inst.annotation = arg.name + bb.append_instruction(alloca_inst) + symbols[f"&{old_ir_mempos}"] = new_var + old_ir_mempos += 32 # arg.typ.memory_bytes_required + + # return buffer + if func_t.return_type is not None: + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_buffer" + symbols["return_buffer"] = new_var + + # return address + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_pc" + symbols["return_pc"] = new_var + + return ir.args[0].args[2] + + 
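+# Note: _handle_internal_func above emits a prologue of `param` instructions
+# for the callee. Illustratively (with hypothetical variable numbering), an
+# internal function with one argument and a return value starts out as:
+#
+#     my_func:
+#         %1 = param  <argname>
+#         %2 = param  <return_buffer>
+#         %3 = param  <return_pc>
+
+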
+def _convert_ir_simple_node( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args + ] + return ctx.append_instruction(ir.value, args) # type: ignore + + +_break_target: Optional[IRBasicBlock] = None +_continue_target: Optional[IRBasicBlock] = None + + +def _get_variable_from_address( + variables: OrderedSet[VariableRecord], addr: int +) -> Optional[VariableRecord]: + assert isinstance(addr, int), "non-int address" + for var in variables.keys(): + if var.location.name != "memory": + continue + if addr >= var.pos and addr < var.pos + var.size: # type: ignore + return var + return None + + +def _get_return_for_stack_operand( + ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable +) -> IRInstruction: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) # type: ignore + else: + sym = symbols.get(ret_ir.value, None) + if sym is None: + # FIXME: needs real allocations + new_var = ctx.append_instruction("alloca", [IRLiteral(32), IRLiteral(0)]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) # type: ignore + else: + new_var = ret_ir + return IRInstruction("return", [last_ir, new_var]) # type: ignore + + +def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): + assert isinstance(variables, OrderedSet) + global _break_target, _continue_target + + frame_info = ir.passthrough_metadata.get("frame_info", None) + if frame_info is not None: + local_vars = OrderedSet[VariableRecord](frame_info.frame_vars.values()) + variables |= local_vars + + assert isinstance(variables, OrderedSet) + + if ir.value in _BINARY_IR_INSTRUCTIONS: + return _convert_binary_op( + ctx, ir, symbols, variables, allocated_variables, ir.value in ["sha3_64"] + ) + + elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + org_value = ir.value + ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] + new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) + ir.value = org_value + return ctx.append_instruction("iszero", [new_var]) + + elif ir.value in PASS_THROUGH_INSTRUCTIONS: + return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) + + elif ir.value in ["pass", "stop", "return"]: + pass + elif ir.value == "deploy": + memsize = ir.args[0].value + ir_runtime = ir.args[1] + padding = ir.args[2].value + assert isinstance(memsize, int), "non-int memsize" + assert isinstance(padding, int), "non-int padding" + + runtimeLabel = ctx.get_next_label() + + inst = IRInstruction("deploy", [IRLiteral(memsize), runtimeLabel, IRLiteral(padding)]) + ctx.get_basic_block().append_instruction(inst) + + bb = IRBasicBlock(runtimeLabel, ctx) + ctx.append_basic_block(bb) + + _convert_ir_basicblock(ctx, ir_runtime, symbols, variables, allocated_variables) + elif ir.value == "seq": + func_t = ir.passthrough_metadata.get("func_t", None) + if ir.is_self_call: + return _handle_self_call(ctx, ir, symbols, variables, allocated_variables) + elif func_t is not None: + symbols = {} + allocated_variables = {} + variables = OrderedSet( + {v: True for v in ir.passthrough_metadata["frame_info"].frame_vars.values()} + ) + if func_t.is_internal: + ir = _handle_internal_func(ctx, ir, func_t, symbols) + # fallthrough 
+ + ret = None + for ir_node in ir.args: # NOTE: skip the last one + ret = _convert_ir_basicblock(ctx, ir_node, symbols, variables, allocated_variables) + + return ret + elif ir.value in ["staticcall", "call"]: # external call + idx = 0 + gas = _convert_ir_basicblock(ctx, ir.args[idx], symbols, variables, allocated_variables) + address = _convert_ir_basicblock( + ctx, ir.args[idx + 1], symbols, variables, allocated_variables + ) + + value = None + if ir.value == "call": + value = _convert_ir_basicblock( + ctx, ir.args[idx + 2], symbols, variables, allocated_variables + ) + else: + idx -= 1 + + argsOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 3], symbols, variables, allocated_variables + ) + argsSize = _convert_ir_basicblock( + ctx, ir.args[idx + 4], symbols, variables, allocated_variables + ) + retOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 5], symbols, variables, allocated_variables + ) + retSize = _convert_ir_basicblock( + ctx, ir.args[idx + 6], symbols, variables, allocated_variables + ) + + if isinstance(argsOffset, IRLiteral): + offset = int(argsOffset.value) + addr = offset - 32 + 4 if offset > 0 else 0 + argsOffsetVar = symbols.get(f"&{addr}", None) + if argsOffsetVar is None: + argsOffsetVar = argsOffset + elif isinstance(argsOffsetVar, IRVariable): + argsOffsetVar.mem_type = MemType.MEMORY + argsOffsetVar.mem_addr = addr + argsOffsetVar.offset = 32 - 4 if offset > 0 else 0 + else: # pragma: nocover + raise CompilerPanic("unreachable") + else: + argsOffsetVar = argsOffset + + retOffsetValue = int(retOffset.value) if retOffset else 0 + retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) + symbols[f"&{retOffsetValue}"] = retVar + + if ir.value == "call": + args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] + return ctx.append_instruction(ir.value, args) + else: + args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] + return ctx.append_instruction(ir.value, args) + elif ir.value == "if": + cond = ir.args[0] + current_bb = ctx.get_basic_block() + + # convert the condition + cont_ret = _convert_ir_basicblock(ctx, cond, symbols, variables, allocated_variables) + + else_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(else_block) + + # convert "else" + else_ret_val = None + else_syms = symbols.copy() + if len(ir.args) == 3: + else_ret_val = _convert_ir_basicblock( + ctx, ir.args[2], else_syms, variables, allocated_variables.copy() + ) + if isinstance(else_ret_val, IRLiteral): + assert isinstance(else_ret_val.value, int) # help mypy + else_ret_val = ctx.append_instruction("store", [IRLiteral(else_ret_val.value)]) + after_else_syms = else_syms.copy() + + # convert "then" + then_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(then_block) + + then_ret_val = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + if isinstance(then_ret_val, IRLiteral): + then_ret_val = ctx.append_instruction("store", [IRLiteral(then_ret_val.value)]) + + inst = IRInstruction("jnz", [cont_ret, then_block.label, else_block.label]) + current_bb.append_instruction(inst) + + after_then_syms = symbols.copy() + + # exit bb + exit_label = ctx.get_next_label() + bb = IRBasicBlock(exit_label, ctx) + bb = ctx.append_basic_block(bb) + + if_ret = None + if then_ret_val is not None and else_ret_val is not None: + if_ret = ctx.get_next_variable() + bb.append_instruction( + IRInstruction( + "phi", [then_block.label, then_ret_val, else_block.label, else_ret_val], if_ret + ) + ) + + 
common_symbols = _get_symbols_common(after_then_syms, after_else_syms) + for sym, val in common_symbols.items(): + ret = ctx.get_next_variable() + old_var = symbols.get(sym, None) + symbols[sym] = ret + if old_var is not None: + for idx, var_rec in allocated_variables.items(): # type: ignore + if var_rec.value == old_var.value: + allocated_variables[idx] = ret # type: ignore + bb.append_instruction( + IRInstruction("phi", [then_block.label, val[0], else_block.label, val[1]], ret) + ) + + if not else_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + else_block.append_instruction(exit_inst) + + if not then_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + then_block.append_instruction(exit_inst) + + return if_ret + + elif ir.value == "with": + ret = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) # initialization + + # Handle with nesting with same symbol + with_symbols = symbols.copy() + + sym = ir.args[0] + if isinstance(ret, IRLiteral): + new_var = ctx.append_instruction("store", [ret]) # type: ignore + with_symbols[sym.value] = new_var + else: + with_symbols[sym.value] = ret # type: ignore + + return _convert_ir_basicblock( + ctx, ir.args[2], with_symbols, variables, allocated_variables + ) # body + elif ir.value == "goto": + _append_jmp(ctx, IRLabel(ir.args[0].value)) + elif ir.value == "jump": + arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + inst = IRInstruction("jmp", [arg_1]) + ctx.get_basic_block().append_instruction(inst) + _new_block(ctx) + elif ir.value == "set": + sym = ir.args[0] + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + new_var = ctx.append_instruction("store", [arg_1]) # type: ignore + symbols[sym.value] = new_var + + elif ir.value == "calldatacopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_v = arg_0 + var = ( + _get_variable_from_address(variables, int(arg_0.value)) + if isinstance(arg_0, IRLiteral) + else None + ) + if var is not None: + if allocated_variables.get(var.name, None) is None: + new_v = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] # type: ignore + ) + allocated_variables[var.name] = new_v # type: ignore + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + symbols[f"&{var.pos}"] = new_v # type: ignore + else: + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + + return new_v + elif ir.value == "codecopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + ctx.append_instruction("codecopy", [size, arg_1, arg_0], False) # type: ignore + elif ir.value == "symbol": + return IRLabel(ir.args[0].value, True) + elif ir.value == "data": + label = IRLabel(ir.args[0].value) + ctx.append_data("dbname", [label]) + for c in ir.args[1:]: + if isinstance(c, int): + assert 0 <= c <= 255, "data with invalid size" + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, bytes): + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, 
IRnode): + data = _convert_ir_basicblock(ctx, c, symbols, variables, allocated_variables) + ctx.append_data("db", [data]) # type: ignore + elif ir.value == "assert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + current_bb = ctx.get_basic_block() + inst = IRInstruction("assert", [arg_0]) # type: ignore + current_bb.append_instruction(inst) + elif ir.value == "label": + label = IRLabel(ir.args[0].value, True) + if not ctx.get_basic_block().is_terminated: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + elif ir.value == "exit_to": + func_t = ir.passthrough_metadata.get("func_t", None) + assert func_t is not None, "exit_to without func_t" + + if func_t.is_external: + # Hardcoded contructor special case + if func_t.name == "__init__": + label = IRLabel(ir.args[0].value, True) + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + return None + if func_t.return_type is None: + inst = IRInstruction("stop", []) + ctx.get_basic_block().append_instruction(inst) + return None + else: + last_ir = None + ret_var = ir.args[1] + deleted = None + if ret_var.is_literal and symbols.get(f"&{ret_var.value}", None) is not None: + deleted = symbols[f"&{ret_var.value}"] + del symbols[f"&{ret_var.value}"] + for arg in ir.args[2:]: + last_ir = _convert_ir_basicblock( + ctx, arg, symbols, variables, allocated_variables + ) + if deleted is not None: + symbols[f"&{ret_var.value}"] = deleted + + ret_ir = _convert_ir_basicblock( + ctx, ret_var, symbols, variables, allocated_variables + ) + + var = ( + _get_variable_from_address(variables, int(ret_ir.value)) + if isinstance(ret_ir, IRLiteral) + else None + ) + if var is not None: + allocated_var = allocated_variables.get(var.name, None) + assert allocated_var is not None, "unallocated variable" + new_var = symbols.get(f"&{ret_ir.value}", allocated_var) # type: ignore + + if var.size and int(var.size) > 32: + offset = int(ret_ir.value) - var.pos # type: ignore + if offset > 0: + ptr_var = ctx.append_instruction( + "add", [IRLiteral(var.pos), IRLiteral(offset)] + ) + else: + ptr_var = allocated_var + inst = IRInstruction("return", [last_ir, ptr_var]) + else: + inst = _get_return_for_stack_operand(ctx, symbols, new_var, last_ir) + else: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + if sym is None: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if func_t.return_type.memory_bytes_required > 32: + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + else: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if last_ir and int(last_ir.value) > 32: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + ret_buf = IRLiteral(128) # TODO: need allocator + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_buf]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + + if func_t.is_internal: + assert ir.args[1].value == "return_pc", "return_pc not found" + if func_t.return_type is None: + inst = IRInstruction("ret", [symbols["return_pc"]]) + 
else: + if func_t.return_type.memory_bytes_required > 32: + inst = IRInstruction("ret", [symbols["return_buffer"], symbols["return_pc"]]) + else: + ret_by_value = ctx.append_instruction("mload", [symbols["return_buffer"]]) + inst = IRInstruction("ret", [ret_by_value, symbols["return_pc"]]) + + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "revert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction("revert", [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "dload": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src = ctx.append_instruction("add", [arg_0, IRLabel("code_end")]) + + ctx.append_instruction( + "dloadbytes", [IRLiteral(32), src, IRLiteral(MemoryPositions.FREE_VAR_SPACE)], False + ) + return ctx.append_instruction("mload", [IRLiteral(MemoryPositions.FREE_VAR_SPACE)]) + elif ir.value == "dloadbytes": + dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src_offset = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + src = ctx.append_instruction("add", [src_offset, IRLabel("code_end")]) + + inst = IRInstruction("dloadbytes", [len_, src, dst]) + ctx.get_basic_block().append_instruction(inst) + return None + elif ir.value == "mload": + sym_ir = ir.args[0] + var = ( + _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None + ) + if var is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mload", [ptr_var]) + else: + if sym_ir.is_literal: + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + new_var = ctx.append_instruction("store", [sym_ir]) + symbols[f"&{sym_ir.value}"] = new_var + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + return sym + + sym = symbols.get(f"&{sym_ir.value}", None) + assert sym is not None, "unallocated variable" + return sym + else: + if sym_ir.is_literal: + new_var = symbols.get(f"&{sym_ir.value}", None) + if new_var is not None: + return ctx.append_instruction("mload", [new_var]) + else: + return ctx.append_instruction("mload", [IRLiteral(sym_ir.value)]) + else: + new_var = _convert_ir_basicblock( + ctx, sym_ir, symbols, variables, allocated_variables + ) + # + # Old IR gets it's return value as a reference in the stack + # New IR gets it's return value in stack in case of 32 bytes or less + # So here we detect ahead of time if this mload leads a self call and + # and we skip the mload + # + if sym_ir.is_self_call: + return new_var + return ctx.append_instruction("mload", [new_var]) + + elif ir.value == "mstore": + sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + + var = None + if isinstance(sym_ir, IRLiteral): + 
var = _get_variable_from_address(variables, int(sym_ir.value)) + + if var is not None and var.size is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mstore", [arg_1, ptr_var], False) + else: + if isinstance(sym_ir, IRLiteral): + new_var = ctx.append_instruction("store", [arg_1]) + symbols[f"&{sym_ir.value}"] = new_var + # if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + if not isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + return None + + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + if arg_1 and not isinstance(sym_ir, IRLiteral): + symbols[f"&{sym_ir.value}"] = arg_1 + return None + + if isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym]) + ctx.get_basic_block().append_instruction(inst) + return None + else: + symbols[sym_ir.value] = arg_1 + return arg_1 + + elif ir.value in ["sload", "iload"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + return ctx.append_instruction(ir.value, [arg_0]) + elif ir.value in ["sstore", "istore"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction(ir.value, [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + elif ir.value == "unique_symbol": + sym = ir.args[0] + new_var = ctx.get_next_variable() + symbols[f"&{sym.value}"] = new_var + return new_var + elif ir.value == "repeat": + # + # repeat(sym, start, end, bound, body) + # 1) entry block ] + # 2) init counter block ] -> same block + # 3) condition block (exit block, body block) + # 4) body block + # 5) increment block + # 6) exit block + # TODO: Add the extra bounds check after clarify + def emit_body_block(): + global _break_target, _continue_target + old_targets = _break_target, _continue_target + _break_target, _continue_target = exit_block, increment_block + _convert_ir_basicblock(ctx, body, symbols, variables, allocated_variables) + _break_target, _continue_target = old_targets + + sym = ir.args[0] + start = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + end = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + # "bound" is not used + _ = _convert_ir_basicblock(ctx, ir.args[3], symbols, variables, allocated_variables) + body = ir.args[4] + + entry_block = ctx.get_basic_block() + cond_block = IRBasicBlock(ctx.get_next_label(), ctx) + body_block = IRBasicBlock(ctx.get_next_label(), ctx) + jump_up_block = IRBasicBlock(ctx.get_next_label(), ctx) + increment_block = IRBasicBlock(ctx.get_next_label(), ctx) + exit_block = IRBasicBlock(ctx.get_next_label(), ctx) + + counter_var = ctx.get_next_variable() + counter_inc_var = ctx.get_next_variable() + ret = ctx.get_next_variable() + + inst = IRInstruction("store", [start], counter_var) + 
+        ctx.get_basic_block().append_instruction(inst)
+        symbols[sym.value] = counter_var
+        inst = IRInstruction("jmp", [cond_block.label])
+        ctx.get_basic_block().append_instruction(inst)
+
+        symbols[sym.value] = ret
+        cond_block.append_instruction(
+            IRInstruction(
+                "phi", [entry_block.label, counter_var, increment_block.label, counter_inc_var], ret
+            )
+        )
+
+        xor_ret = ctx.get_next_variable()
+        cont_ret = ctx.get_next_variable()
+        inst = IRInstruction("xor", [ret, end], xor_ret)
+        cond_block.append_instruction(inst)
+        cond_block.append_instruction(IRInstruction("iszero", [xor_ret], cont_ret))
+        ctx.append_basic_block(cond_block)
+
+        # Do a dry run to get the symbols needing phi nodes
+        start_syms = symbols.copy()
+        ctx.append_basic_block(body_block)
+        emit_body_block()
+        end_syms = symbols.copy()
+        diff_syms = _get_symbols_common(start_syms, end_syms)
+
+        replacements = {}
+        for sym, val in diff_syms.items():
+            new_var = ctx.get_next_variable()
+            symbols[sym] = new_var
+            replacements[val[0]] = new_var
+            replacements[val[1]] = new_var
+            cond_block.insert_instruction(
+                IRInstruction(
+                    "phi", [entry_block.label, val[0], increment_block.label, val[1]], new_var
+                ),
+                1,
+            )
+
+        body_block.replace_operands(replacements)
+
+        body_end = ctx.get_basic_block()
+        if not body_end.is_terminated:
+            body_end.append_instruction(IRInstruction("jmp", [jump_up_block.label]))
+
+        jump_cond = IRInstruction("jmp", [increment_block.label])
+        jump_up_block.append_instruction(jump_cond)
+        ctx.append_basic_block(jump_up_block)
+
+        increment_block.append_instruction(
+            IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var)
+        )
+        increment_block.append_instruction(IRInstruction("jmp", [cond_block.label]))
+        ctx.append_basic_block(increment_block)
+
+        ctx.append_basic_block(exit_block)
+
+        inst = IRInstruction("jnz", [cont_ret, exit_block.label, body_block.label])
+        cond_block.append_instruction(inst)
+    elif ir.value == "break":
+        assert _break_target is not None, "Break with no break target"
+        inst = IRInstruction("jmp", [_break_target.label])
+        ctx.get_basic_block().append_instruction(inst)
+        ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx))
+    elif ir.value == "continue":
+        assert _continue_target is not None, "Continue with no continue target"
+        inst = IRInstruction("jmp", [_continue_target.label])
+        ctx.get_basic_block().append_instruction(inst)
+        ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx))
+    elif ir.value == "gas":
+        return ctx.append_instruction("gas", [])
+    elif ir.value == "returndatasize":
+        return ctx.append_instruction("returndatasize", [])
+    elif ir.value == "returndatacopy":
+        assert len(ir.args) == 3, "returndatacopy with wrong number of arguments"
+        arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables)
+        arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables)
+        size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables)
+
+        new_var = ctx.append_instruction("returndatacopy", [arg_1, size])
+
+        symbols[f"&{arg_0.value}"] = new_var
+        return new_var
+    elif ir.value == "selfdestruct":
+        arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables)
+        ctx.append_instruction("selfdestruct", [arg_0], False)
+    elif isinstance(ir.value, str) and ir.value.startswith("log"):
+        args = [
+            _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables)
+            for arg in ir.args
+        ]
+        inst = IRInstruction(ir.value, reversed(args))
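+        # note: operands are emitted in list order, so reversing the old-IR
+        # args (as with "revert" and "sstore" above) puts the first argument
+        # on top of the stack, where the EVM LOG* opcodes expect it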
+        ctx.get_basic_block().append_instruction(inst)
+    elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes():
+        _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables)
+    elif isinstance(ir.value, str) and ir.value in symbols:
+        return symbols[ir.value]
+    elif ir.is_literal:
+        return IRLiteral(ir.value)
+    else:
+        raise Exception(f"Unknown IR node: {ir}")
+
+    return None
+
+
+def _convert_ir_opcode(
+    ctx: IRFunction,
+    ir: IRnode,
+    symbols: SymbolTable,
+    variables: OrderedSet,
+    allocated_variables: dict[str, IRVariable],
+) -> None:
+    opcode = ir.value.upper()  # type: ignore
+    inst_args = []
+    for arg in ir.args:
+        if isinstance(arg, IRnode):
+            inst_args.append(
+                _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables)
+            )
+    instruction = IRInstruction(opcode, inst_args)  # type: ignore
+    ctx.get_basic_block().append_instruction(instruction)
+
+
+def _data_ofst_of(sym, ofst, height_):
+    # e.g. _OFST _sym_foo 32
+    assert is_symbol(sym) or is_mem_sym(sym)
+    if isinstance(ofst.value, int):
+        # resolve at compile time using magic _OFST op
+        return ["_OFST", sym, ofst.value]
+    else:
+        # if we can't resolve at compile time, resolve at runtime
+        # ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_)
+        return ofst + [sym, "ADD"]
diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py
new file mode 100644
index 0000000000..11da80ac66
--- /dev/null
+++ b/vyper/venom/passes/base_pass.py
@@ -0,0 +1,21 @@
+class IRPass:
+    """
+    Base class for IR passes. `run_pass` runs the pass repeatedly
+    until it makes no more changes.
+    """
+
+    @classmethod
+    def run_pass(cls, *args, **kwargs):
+        t = cls()
+        count = 0
+
+        while True:
+            changes_count = t._run_pass(*args, **kwargs) or 0
+            count += changes_count
+            if changes_count == 0:
+                break
+
+        return count
+
+    def _run_pass(self, *args, **kwargs):
+        raise NotImplementedError(f"Not implemented! {self.__class__}._run_pass()")
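A minimal sketch of how a pass plugs into this base class (hypothetical example, not part of the diff; it assumes only the `IRPass` interface above plus the `basic_blocks`/`instructions` containers used throughout this patch):

```python
from vyper.venom.function import IRFunction
from vyper.venom.passes.base_pass import IRPass


class RemoveNopsPass(IRPass):
    # hypothetical pass: delete "nop" instructions until none remain
    def _run_pass(self, ctx: IRFunction) -> int:
        changes = 0
        for bb in ctx.basic_blocks:
            before = len(bb.instructions)
            bb.instructions = [inst for inst in bb.instructions if inst.opcode != "nop"]
            changes += before - len(bb.instructions)
        # IRPass.run_pass keeps calling _run_pass until this returns 0
        return changes


# usage: total_removed = RemoveNopsPass.run_pass(ctx)
```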
diff --git a/vyper/venom/passes/constant_propagation.py b/vyper/venom/passes/constant_propagation.py
new file mode 100644
index 0000000000..94b556124e
--- /dev/null
+++ b/vyper/venom/passes/constant_propagation.py
@@ -0,0 +1,13 @@
+from vyper.utils import ir_pass
+from vyper.venom.basicblock import IRBasicBlock
+from vyper.venom.function import IRFunction
+
+
+def _process_basic_block(ctx: IRFunction, bb: IRBasicBlock):
+    pass
+
+
+@ir_pass
+def ir_pass_constant_propagation(ctx: IRFunction):
+    for bb in ctx.basic_blocks:
+        _process_basic_block(ctx, bb)
diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py
new file mode 100644
index 0000000000..26994bd27f
--- /dev/null
+++ b/vyper/venom/passes/dft.py
@@ -0,0 +1,54 @@
+from vyper.utils import OrderedSet
+from vyper.venom.analysis import DFG
+from vyper.venom.basicblock import IRBasicBlock, IRInstruction
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRPass
+
+
+# DataFlow Transformation
+class DFTPass(IRPass):
+    def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction):
+        if inst in self.visited_instructions:
+            return
+        self.visited_instructions.add(inst)
+
+        if inst.opcode == "phi":
+            # phi instructions stay at the beginning of the basic block
+            # and no input processing is needed
+            bb.instructions.append(inst)
+            return
+
+        for op in inst.get_inputs():
+            target = self.dfg.get_producing_instruction(op)
+            if target.parent != inst.parent or target.fence_id != inst.fence_id:
+                # don't reorder across basic block or fence boundaries
+                continue
+            self._process_instruction_r(bb, target)
+
+        bb.instructions.append(inst)
+
+    def _process_basic_block(self, bb: IRBasicBlock) -> None:
+        self.ctx.append_basic_block(bb)
+
+        instructions = bb.instructions
+        bb.instructions = []
+
+        for inst in instructions:
+            inst.fence_id = self.fence_id
+            if inst.volatile:
+                self.fence_id += 1
+
+        for inst in instructions:
+            self._process_instruction_r(bb, inst)
+
+    def _run_pass(self, ctx: IRFunction) -> None:
+        self.ctx = ctx
+        self.dfg = DFG.build_dfg(ctx)
+        self.fence_id = 0
+        self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()
+
+        basic_blocks = ctx.basic_blocks
+        ctx.basic_blocks = []
+
+        for bb in basic_blocks:
+            self._process_basic_block(bb)
diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py
new file mode 100644
index 0000000000..9ee1012f91
--- /dev/null
+++ b/vyper/venom/passes/normalization.py
@@ -0,0 +1,90 @@
+from vyper.exceptions import CompilerPanic
+from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRPass
+
+
+class NormalizationPass(IRPass):
+    """
+    This pass splits basic blocks when there are multiple conditional predecessors.
+    The code generator expects a normalized CFG, which has the property that
+    each basic block has at most one conditional predecessor.
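+
+    For example (illustrative): if predecessor @a ends in
+    "jnz %cond, @b, @join" and @join also has another predecessor, the pass
+    inserts a new block @join_split_a containing a single "jmp @join" and
+    retargets that edge, so @join no longer has a conditional predecessor.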
+ """ + + changes = 0 + + def _split_basic_block(self, bb: IRBasicBlock) -> None: + # Iterate over the predecessors of the basic block + for in_bb in list(bb.cfg_in): + jump_inst = in_bb.instructions[-1] + assert bb in in_bb.cfg_out + + # Handle static and dynamic branching + if jump_inst.opcode == "jnz": + self._split_for_static_branch(bb, in_bb) + elif jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRVariable): + self._split_for_dynamic_branch(bb, in_bb) + else: + continue + + self.changes += 1 + + def _split_for_static_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + jump_inst = in_bb.instructions[-1] + for i, op in enumerate(jump_inst.operands): + if op == bb.label: + edge = i + break + else: + # none of the edges points to this bb + raise CompilerPanic("bad CFG") + + assert edge in (1, 2) # the arguments which can be labels + + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Redirect the original conditional jump to the intermediary basic block + jump_inst.operands[edge] = split_bb.label + + def _split_for_dynamic_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Update any affected labels in the data segment + # TODO: this DESTROYS the cfg! refactor so the translation of the + # selector table produces indirect jumps properly. + for inst in self.ctx.data_segment: + if inst.opcode == "db" and inst.operands[0] == bb.label: + inst.operands[0] = split_bb.label + + def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRBasicBlock: + # Create an intermediary basic block and append it + source = in_bb.label.value + target = bb.label.value + split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) + split_bb.append_instruction(IRInstruction("jmp", [bb.label])) + self.ctx.append_basic_block(split_bb) + + # Rewire the CFG + # TODO: this is cursed code, it is necessary instead of just running + # calculate_cfg() because split_for_dynamic_branch destroys the CFG! + # ideally, remove this rewiring and just re-run calculate_cfg(). + split_bb.add_cfg_in(in_bb) + split_bb.add_cfg_out(bb) + in_bb.remove_cfg_out(bb) + in_bb.add_cfg_out(split_bb) + bb.remove_cfg_in(in_bb) + bb.add_cfg_in(split_bb) + return split_bb + + def _run_pass(self, ctx: IRFunction) -> int: + self.ctx = ctx + self.changes = 0 + + for bb in ctx.basic_blocks: + if len(bb.cfg_in) > 1: + self._split_basic_block(bb) + + # Sanity check + assert ctx.normalized, "Normalization pass failed" + + return self.changes diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py new file mode 100644 index 0000000000..66c62b74d2 --- /dev/null +++ b/vyper/venom/stack_model.py @@ -0,0 +1,100 @@ +from vyper.venom.basicblock import IROperand, IRVariable + + +class StackModel: + NOT_IN_STACK = object() + _stack: list[IROperand] + + def __init__(self): + self._stack = [] + + def copy(self): + new = StackModel() + new._stack = self._stack.copy() + return new + + @property + def height(self) -> int: + """ + Returns the height of the stack map. + """ + return len(self._stack) + + def push(self, op: IROperand) -> None: + """ + Pushes an operand onto the stack map. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack.append(op) + + def pop(self, num: int = 1) -> None: + del self._stack[len(self._stack) - num :] + + def get_depth(self, op: IROperand) -> int: + """ + Returns the depth of the first matching operand in the stack map. 
+        If the operand is not in the stack map, returns NOT_IN_STACK.
+        """
+        assert isinstance(op, IROperand), f"{type(op)}: {op}"
+
+        for i, stack_op in enumerate(reversed(self._stack)):
+            if stack_op.value == op.value:
+                return -i
+
+        return StackModel.NOT_IN_STACK  # type: ignore
+
+    def get_phi_depth(self, phi1: IRVariable, phi2: IRVariable) -> int:
+        """
+        Returns the depth of the first matching phi variable in the stack map.
+        If none of the phi operands is in the stack, returns NOT_IN_STACK.
+        Asserts that exactly one of phi1 and phi2 is found.
+        """
+        assert isinstance(phi1, IRVariable)
+        assert isinstance(phi2, IRVariable)
+
+        ret = StackModel.NOT_IN_STACK
+        for i, stack_item in enumerate(reversed(self._stack)):
+            if stack_item in (phi1, phi2):
+                assert (
+                    ret is StackModel.NOT_IN_STACK
+                ), f"phi argument is not unique! {phi1}, {phi2}, {self._stack}"
+                ret = -i
+
+        return ret  # type: ignore
+
+    def peek(self, depth: int) -> IROperand:
+        """
+        Returns the operand at the given depth in the stack map (0 is the
+        top of the stack).
+        """
+        assert depth is not StackModel.NOT_IN_STACK, "Cannot peek non-in-stack depth"
+        return self._stack[depth - 1]
+
+    def poke(self, depth: int, op: IROperand) -> None:
+        """
+        Pokes an operand at the given depth in the stack map.
+        """
+        assert depth is not StackModel.NOT_IN_STACK, "Cannot poke non-in-stack depth"
+        assert depth <= 0, "Bad depth"
+        assert isinstance(op, IROperand), f"{type(op)}: {op}"
+        self._stack[depth - 1] = op
+
+    def dup(self, depth: int) -> None:
+        """
+        Duplicates the operand at the given depth in the stack map.
+        """
+        assert depth is not StackModel.NOT_IN_STACK, "Cannot dup non-existent operand"
+        assert depth <= 0, "Cannot dup positive depth"
+        self._stack.append(self.peek(depth))
+
+    def swap(self, depth: int) -> None:
+        """
+        Swaps the operand at the given depth in the stack map with the top of the stack.
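+        For example, with a stack of [%a, %b, %c] (%c on top), swap(-2)
+        exchanges %a and %c, leaving [%c, %b, %a].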
+ """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot swap non-existent operand" + assert depth < 0, "Cannot swap positive depth" + top = self._stack[-1] + self._stack[-1] = self._stack[depth - 1] + self._stack[depth - 1] = top + + def __repr__(self) -> str: + return f"" diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py new file mode 100644 index 0000000000..f6ec45440a --- /dev/null +++ b/vyper/venom/venom_to_assembly.py @@ -0,0 +1,461 @@ +from typing import Any + +from vyper.ir.compile_ir import PUSH, DataHeader, RuntimeHeader, optimize_assembly +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.analysis import calculate_cfg, calculate_liveness, input_vars_from +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.stack_model import StackModel + +# instructions which map one-to-one from venom to EVM +_ONE_TO_ONE_INSTRUCTIONS = frozenset( + [ + "revert", + "coinbase", + "calldatasize", + "calldatacopy", + "calldataload", + "gas", + "gasprice", + "gaslimit", + "address", + "origin", + "number", + "extcodesize", + "extcodehash", + "returndatasize", + "returndatacopy", + "callvalue", + "selfbalance", + "sload", + "sstore", + "mload", + "mstore", + "timestamp", + "caller", + "selfdestruct", + "signextend", + "stop", + "shr", + "shl", + "and", + "xor", + "or", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "eq", + "iszero", + "lg", + "lt", + "slt", + "sgt", + "log0", + "log1", + "log2", + "log3", + "log4", + ] +) + + +# TODO: "assembly" gets into the recursion due to how the original +# IR was structured recursively in regards with the deploy instruction. +# There, recursing into the deploy instruction was by design, and +# made it easier to make the assembly generated "recursive" (i.e. +# instructions being lists of instructions). We don't have this restriction +# anymore, so we can probably refactor this to be iterative in coordination +# with the assembler. My suggestion is to let this be for now, and we can +# refactor it later when we are finished phasing out the old IR. +class VenomCompiler: + ctx: IRFunction + label_counter = 0 + visited_instructions: OrderedSet # {IRInstruction} + visited_basicblocks: OrderedSet # {IRBasicBlock} + + def __init__(self, ctx: IRFunction): + self.ctx = ctx + self.label_counter = 0 + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + + def generate_evm(self, no_optimize: bool = False) -> list[str]: + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + self.label_counter = 0 + + stack = StackModel() + asm: list[str] = [] + + # Before emitting the assembly, we need to make sure that the + # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) + # so it should not be called after calling NormalizationPass.run_pass(). + # Liveness is then computed for the normalized IR, and we can proceed to + # assembly generation. + # This is a side-effect of how dynamic jumps are temporarily being used + # to support the O(1) dispatcher. -> look into calculate_cfg() + calculate_cfg(self.ctx) + NormalizationPass.run_pass(self.ctx) + calculate_liveness(self.ctx) + + assert self.ctx.normalized, "Non-normalized CFG!" 
+        self._generate_evm_for_basicblock_r(asm, self.ctx.basic_blocks[0], stack)
+
+        # Append postambles
+        revert_postamble = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"]
+        runtime = None
+        if isinstance(asm[-1], list) and isinstance(asm[-1][0], RuntimeHeader):
+            runtime = asm.pop()
+
+        asm.extend(revert_postamble)
+        if runtime:
+            runtime.extend(revert_postamble)
+            asm.append(runtime)
+
+        # Append data segment
+        data_segments: dict[Any, list[Any]] = dict()
+        for inst in self.ctx.data_segment:
+            if inst.opcode == "dbname":
+                label = inst.operands[0].value
+                data_segments[label] = [DataHeader(f"_sym_{label}")]
+            elif inst.opcode == "db":
+                data_segments[label].append(f"_sym_{inst.operands[0].value}")
+
+        extent_point = asm if not isinstance(asm[-1], list) else asm[-1]
+        extent_point.extend([data_segments[label] for label in data_segments])  # type: ignore
+
+        if no_optimize is False:
+            optimize_assembly(asm)
+
+        return asm
+
+    def _stack_reorder(
+        self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable]
+    ) -> None:
+        # make a list so we can index it
+        stack_ops = [x for x in _stack_ops.keys()]
+        stack_ops_count = len(_stack_ops)
+
+        for i in range(stack_ops_count):
+            op = stack_ops[i]
+            final_stack_depth = -(stack_ops_count - i - 1)
+            depth = stack.get_depth(op)  # type: ignore
+
+            if depth == final_stack_depth:
+                continue
+
+            self.swap(assembly, stack, depth)
+            self.swap(assembly, stack, final_stack_depth)
+
+    def _emit_input_operands(
+        self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel
+    ) -> None:
+        # PRE: we already have all the items on the stack that have
+        # been scheduled to be killed. now it's just a matter of emitting
+        # SWAPs, DUPs and PUSHes until we match the `ops` argument
+
+        # dumb heuristic: if the top of stack is not wanted here, swap
+        # it with something that is wanted
+        if ops and stack.height > 0 and stack.peek(0) not in ops:
+            for op in ops:
+                if isinstance(op, IRVariable) and op not in inst.dup_requirements:
+                    self.swap_op(assembly, stack, op)
+                    break
+
+        emitted_ops = OrderedSet[IROperand]()
+        for op in ops:
+            if isinstance(op, IRLabel):
+                # invoke emits the actual instruction itself so we don't need
+                # to emit it here, but we need to add it to the stack map
+                if inst.opcode != "invoke":
+                    assembly.append(f"_sym_{op.value}")
+                stack.push(op)
+                continue
+
+            if isinstance(op, IRLiteral):
+                assembly.extend([*PUSH(op.value)])
+                stack.push(op)
+                continue
+
+            if op in inst.dup_requirements:
+                self.dup_op(assembly, stack, op)
+
+            if op in emitted_ops:
+                self.dup_op(assembly, stack, op)
+
+            # REVIEW: this seems like it can be reordered across volatile
+            # boundaries (which includes memory fences). maybe just
+            # remove it entirely at this point
+            if isinstance(op, IRVariable) and op.mem_type == MemType.MEMORY:
+                assembly.extend([*PUSH(op.mem_addr)])
+                assembly.append("MLOAD")
+
+            emitted_ops.add(op)
+
+    def _generate_evm_for_basicblock_r(
+        self, asm: list, basicblock: IRBasicBlock, stack: StackModel
+    ) -> None:
+        if basicblock in self.visited_basicblocks:
+            return
+        self.visited_basicblocks.add(basicblock)
+
+        # assembly entry point into the block
+        asm.append(f"_sym_{basicblock.label}")
+        asm.append("JUMPDEST")
+
+        self.clean_stack_from_cfg_in(asm, basicblock, stack)
+
+        for inst in basicblock.instructions:
+            asm = self._generate_evm_for_instruction(asm, inst, stack)
+
+        for bb in basicblock.cfg_out:
+            self._generate_evm_for_basicblock_r(asm, bb, stack.copy())
+
+    # pop values from stack at entry to bb
+    # note this produces the same result(!) no matter which basic block
+    # we enter from in the CFG.
+    def clean_stack_from_cfg_in(
+        self, asm: list, basicblock: IRBasicBlock, stack: StackModel
+    ) -> None:
+        if len(basicblock.cfg_in) == 0:
+            return
+
+        to_pop = OrderedSet[IRVariable]()
+        for in_bb in basicblock.cfg_in:
+            # inputs is the set of input variables we need from in_bb
+            inputs = input_vars_from(in_bb, basicblock)
+
+            # layout is the output stack layout for in_bb (which works
+            # for all possible cfg_outs from the in_bb).
+            layout = in_bb.out_vars
+
+            # pop all the stack items which in_bb produced which we don't need.
+            to_pop |= layout.difference(inputs)
+
+        for var in to_pop:
+            depth = stack.get_depth(var)
+            # don't pop phantom phi inputs
+            if depth is StackModel.NOT_IN_STACK:
+                continue
+
+            if depth != 0:
+                stack.swap(depth)
+            self.pop(asm, stack)
+
+    def _generate_evm_for_instruction(
+        self, assembly: list, inst: IRInstruction, stack: StackModel
+    ) -> list[str]:
+        opcode = inst.opcode
+
+        #
+        # generate EVM for op
+        #
+
+        # Step 1: Apply instruction special stack manipulations
+
+        if opcode in ["jmp", "jnz", "invoke"]:
+            operands = inst.get_non_label_operands()
+        elif opcode == "alloca":
+            operands = inst.operands[1:2]
+        elif opcode == "iload":
+            operands = []
+        elif opcode == "istore":
+            operands = inst.operands[0:1]
+        else:
+            operands = inst.operands
+
+        if opcode == "phi":
+            ret = inst.get_outputs()[0]
+            phi1, phi2 = inst.get_inputs()
+            depth = stack.get_phi_depth(phi1, phi2)
+            # collapse the arguments to the phi node in the stack.
+            # example: for `%56 = %label1 %13 %label2 %14`, we will
+            # find an instance of %13 *or* %14 in the stack and replace it with %56.
+            to_be_replaced = stack.peek(depth)
+            if to_be_replaced in inst.dup_requirements:
+                # %13/%14 is still live(!), so we make a copy of it
+                self.dup(assembly, stack, depth)
+                stack.poke(0, ret)
+            else:
+                stack.poke(depth, ret)
+            return assembly
+
+        # Step 2: Emit instruction's input operands
+        self._emit_input_operands(assembly, inst, operands, stack)
+
+        # Step 3: Reorder stack
+        if opcode in ["jnz", "jmp"]:
+            # prepare stack for jump into another basic block
+            assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet)
+            b = next(iter(inst.parent.cfg_out))
+            target_stack = input_vars_from(inst.parent, b)
+            # TODO optimize stack reordering at entry and exit from basic blocks
+            self._stack_reorder(assembly, stack, target_stack)
+
+        # final step to get the inputs to this instruction ordered
+        # correctly on the stack
+        self._stack_reorder(assembly, stack, OrderedSet(operands))
+
+        # some instructions (e.g. invoke) need to do stack manipulations
+        # with the stack model containing the return value(s), so we fiddle
+        # with the stack model beforehand.
+
+        # Step 4: Push instruction's return value to stack
+        stack.pop(len(operands))
+        if inst.output is not None:
+            stack.push(inst.output)
+
+        # Step 5: Emit the EVM instruction(s)
+        if opcode in _ONE_TO_ONE_INSTRUCTIONS:
+            assembly.append(opcode.upper())
+        elif opcode == "alloca":
+            pass
+        elif opcode == "param":
+            pass
+        elif opcode == "store":
+            pass
+        elif opcode == "dbname":
+            pass
+        elif opcode in ["codecopy", "dloadbytes"]:
+            assembly.append("CODECOPY")
+        elif opcode == "jnz":
+            # jump if not zero
+            if_nonzero_label = inst.operands[1]
+            if_zero_label = inst.operands[2]
+            assembly.append(f"_sym_{if_nonzero_label.value}")
+            assembly.append("JUMPI")
+
+            # make sure the if_zero_label will be optimized out
+            # assert if_zero_label == next(iter(inst.parent.cfg_out)).label
+
+            assembly.append(f"_sym_{if_zero_label.value}")
+            assembly.append("JUMP")
+
+        elif opcode == "jmp":
+            if isinstance(inst.operands[0], IRLabel):
+                assembly.append(f"_sym_{inst.operands[0].value}")
+                assembly.append("JUMP")
+            else:
+                assembly.append("JUMP")
+        elif opcode == "gt":
+            assembly.append("GT")
+        elif opcode == "lt":
+            assembly.append("LT")
+        elif opcode == "invoke":
+            target = inst.operands[0]
+            assert isinstance(target, IRLabel), "invoke target must be a label"
+            assembly.extend(
+                [
+                    f"_sym_label_ret_{self.label_counter}",
+                    f"_sym_{target.value}",
+                    "JUMP",
+                    f"_sym_label_ret_{self.label_counter}",
+                    "JUMPDEST",
+                ]
+            )
+            self.label_counter += 1
+            if stack.height > 0 and stack.peek(0) in inst.dup_requirements:
+                self.pop(assembly, stack)
+        elif opcode == "call":
+            assembly.append("CALL")
+        elif opcode == "staticcall":
+            assembly.append("STATICCALL")
+        elif opcode == "ret":
+            assembly.append("JUMP")
+        elif opcode == "return":
+            assembly.append("RETURN")
+        elif opcode == "phi":
+            pass
+        elif opcode == "sha3":
+            assembly.append("SHA3")
+        elif opcode == "sha3_64":
+            assembly.extend(
+                [
+                    *PUSH(MemoryPositions.FREE_VAR_SPACE2),
+                    "MSTORE",
+                    *PUSH(MemoryPositions.FREE_VAR_SPACE),
+                    "MSTORE",
+                    *PUSH(64),
+                    *PUSH(MemoryPositions.FREE_VAR_SPACE),
+                    "SHA3",
+                ]
+            )
+        elif opcode == "ceil32":
+            assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"])
+        elif opcode == "assert":
+            assembly.extend(["ISZERO", "_sym___revert", "JUMPI"])
+        elif opcode == "deploy":
+            memsize = inst.operands[0].value
+            padding = inst.operands[2].value
+            # TODO: fix this by removing the deploy opcode altogether and
+            # moving the emission to the IR translation
+            while assembly[-1] != "JUMPDEST":
+                assembly.pop()
+            assembly.extend(
+                ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"]
+            )
+            assembly.extend(["_OFST", "_sym_subcode_size", padding])  # stack: len
+            assembly.extend(["_mem_deploy_start"])  # stack: len mem_ofst
+            assembly.extend(["RETURN"])
+            assembly.append([RuntimeHeader("_sym_runtime_begin", memsize, padding)])  # type: ignore
+            assembly = assembly[-1]
+        elif opcode == "iload":
+            loc = inst.operands[0].value
+            assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"])
+        elif opcode == "istore":
+            loc = inst.operands[1].value
+            assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"])
+        else:
+            raise Exception(f"Unknown opcode: {opcode}")
+
+        # Step 6: Emit instruction's output operands (if any)
+        if inst.output is not None:
+            assert isinstance(inst.output, IRVariable), "Return value must be a variable"
+            if inst.output.mem_type == MemType.MEMORY:
+                assembly.extend([*PUSH(inst.output.mem_addr)])
+
+        return assembly
+
+    def pop(self, assembly, stack, num=1):
+        stack.pop(num)
+        assembly.extend(["POP"] * num)
+
+    def swap(self, assembly, stack, depth):
+        if depth == 0:
+            return
+        stack.swap(depth)
+        assembly.append(_evm_swap_for(depth))
+
+    def dup(self, assembly, stack, depth):
+        stack.dup(depth)
+        assembly.append(_evm_dup_for(depth))
+
+    def swap_op(self, assembly, stack, op):
+        self.swap(assembly, stack, stack.get_depth(op))
+
+    def dup_op(self, assembly, stack, op):
+        self.dup(assembly, stack, stack.get_depth(op))
+
+
+def _evm_swap_for(depth: int) -> str:
+    swap_idx = -depth
+    assert 1 <= swap_idx <= 16, "Unsupported swap depth"
+    return f"SWAP{swap_idx}"
+
+
+def _evm_dup_for(depth: int) -> str:
+    dup_idx = 1 - depth
+    assert 1 <= dup_idx <= 16, "Unsupported dup depth"
+    return f"DUP{dup_idx}"
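To see how the pieces fit together end-to-end, here is a minimal driver sketch (hypothetical code, not part of the diff; it assumes `IRFunction()` is default-constructible with an entry basic block, and otherwise uses only APIs defined above):

```python
# Hypothetical driver for the Venom pipeline introduced in this diff.
from vyper.venom.basicblock import IRInstruction, IRLiteral
from vyper.venom.function import IRFunction
from vyper.venom.passes.dft import DFTPass
from vyper.venom.venom_to_assembly import VenomCompiler

ctx = IRFunction()  # assumption: starts with an entry basic block
val = ctx.append_instruction("store", [IRLiteral(42)])  # %1 = 42
# venom mstore takes [value, ptr]: store %1 at memory offset 64
ctx.append_instruction("mstore", [val, IRLiteral(64)], False)
ctx.get_basic_block().append_instruction(IRInstruction("stop", []))

DFTPass.run_pass(ctx)  # schedule instructions within volatile fences
asm = VenomCompiler(ctx).generate_evm(no_optimize=False)
print(asm)
```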