From 98f502baea6385fe25dbf94a70fb4eddc9f02f56 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Mon, 20 Nov 2023 23:59:23 +0800 Subject: [PATCH 01/18] feat: remove `vyper-serve` (#3666) moving it out into a separate project --- vyper/__main__.py | 8 +-- vyper/cli/vyper_serve.py | 127 --------------------------------------- 2 files changed, 3 insertions(+), 132 deletions(-) delete mode 100755 vyper/cli/vyper_serve.py diff --git a/vyper/__main__.py b/vyper/__main__.py index 371975c301..c5bda47bea 100644 --- a/vyper/__main__.py +++ b/vyper/__main__.py @@ -2,10 +2,10 @@ # -*- coding: UTF-8 -*- import sys -from vyper.cli import vyper_compile, vyper_ir, vyper_serve +from vyper.cli import vyper_compile, vyper_ir if __name__ == "__main__": - allowed_subcommands = ("--vyper-compile", "--vyper-ir", "--vyper-serve") + allowed_subcommands = ("--vyper-compile", "--vyper-ir") if len(sys.argv) <= 1 or sys.argv[1] not in allowed_subcommands: # default (no args, no switch in first arg): run vyper_compile @@ -13,9 +13,7 @@ else: # pop switch and forward args to subcommand subcommand = sys.argv.pop(1) - if subcommand == "--vyper-serve": - vyper_serve._parse_cli_args() - elif subcommand == "--vyper-ir": + if subcommand == "--vyper-ir": vyper_ir._parse_cli_args() else: vyper_compile._parse_cli_args() diff --git a/vyper/cli/vyper_serve.py b/vyper/cli/vyper_serve.py deleted file mode 100755 index 9771dc922d..0000000000 --- a/vyper/cli/vyper_serve.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import json -import sys -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn - -import vyper -from vyper.codegen import ir_node -from vyper.evm.opcodes import DEFAULT_EVM_VERSION -from vyper.exceptions import VyperException - - -def _parse_cli_args(): - return _parse_args(sys.argv[1:]) - - -def _parse_args(argv): - parser = argparse.ArgumentParser(description="Serve Vyper compiler as an HTTP Service") - parser.add_argument( - "--version", action="version", version=f"{vyper.__version__}+commit{vyper.__commit__}" - ) - parser.add_argument( - "-b", - help="Address to bind JSON server on, default: localhost:8000", - default="localhost:8000", - dest="bind_address", - ) - - args = parser.parse_args(argv) - - if ":" in args.bind_address: - ir_node.VYPER_COLOR_OUTPUT = False - runserver(*args.bind_address.split(":")) - else: - print('Provide bind address in "{address}:{port}" format') - - -class VyperRequestHandler(BaseHTTPRequestHandler): - def send_404(self): - self.send_response(404) - self.end_headers() - return - - def send_cors_all(self): - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-type") - - def do_OPTIONS(self): - self.send_response(200) - self.send_cors_all() - self.end_headers() - - def do_GET(self): - if self.path == "/": - self.send_response(200) - self.send_cors_all() - self.end_headers() - self.wfile.write(f"Vyper Compiler. Version: {vyper.__version__}\n".encode()) - else: - self.send_404() - - return - - def do_POST(self): - if self.path == "/compile": - content_len = int(self.headers.get("content-length")) - post_body = self.rfile.read(content_len) - data = json.loads(post_body) - - response, status_code = self._compile(data) - - self.send_response(status_code) - self.send_header("Content-type", "application/json") - self.send_cors_all() - self.end_headers() - self.wfile.write(json.dumps(response).encode()) - - else: - self.send_404() - - return - - def _compile(self, data): - code = data.get("code") - if not code: - return {"status": "failed", "message": 'No "code" key supplied'}, 400 - if not isinstance(code, str): - return {"status": "failed", "message": '"code" must be a non-empty string'}, 400 - - try: - code = data["code"] - out_dict = vyper.compile_code( - code, - list(vyper.compiler.OUTPUT_FORMATS.keys()), - evm_version=data.get("evm_version", DEFAULT_EVM_VERSION), - ) - out_dict["ir"] = str(out_dict["ir"]) - out_dict["ir_runtime"] = str(out_dict["ir_runtime"]) - except VyperException as e: - return ( - {"status": "failed", "message": str(e), "column": e.col_offset, "line": e.lineno}, - 400, - ) - except SyntaxError as e: - return ( - {"status": "failed", "message": str(e), "column": e.offset, "line": e.lineno}, - 400, - ) - - out_dict.update({"status": "success"}) - - return out_dict, 200 - - -class VyperHTTPServer(ThreadingMixIn, HTTPServer): - """Handle requests in a separate thread.""" - - pass - - -def runserver(host="", port=8000): - server_address = (host, int(port)) - httpd = VyperHTTPServer(server_address, VyperRequestHandler) - print(f"Listening on http://{host}:{port}") - httpd.serve_forever() From 28b1121e6ca8042d10a68a3d91df016bc7b83c5f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 21 Nov 2023 08:36:13 -0500 Subject: [PATCH 02/18] perf: lazy eval of f-strings in IRnode ctor (#3602) 25% of IR generation is in IRnode.__repr__ due to the references to self in the f-strings for panic messages. this commit switches to using `assert`, which accomplishes the same thing, but lazily evaluating the error messages (and the code is slightly less pretty) --- vyper/codegen/ir_node.py | 95 ++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 53 deletions(-) diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index ad4aa76437..e17ef47c8f 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -202,27 +202,23 @@ def __init__( self.encoding = encoding self.as_hex = AS_HEX_DEFAULT - def _check(condition, err): - if not condition: - raise CompilerPanic(str(err)) - - _check(self.value is not None, "None is not allowed as IRnode value") + assert self.value is not None, "None is not allowed as IRnode value" # Determine this node's valency (1 if it pushes a value on the stack, # 0 otherwise) and checks to make sure the number and valencies of # children are correct. Also, find an upper bound on gas consumption # Numbers if isinstance(self.value, int): - _check(len(self.args) == 0, "int can't have arguments") + assert len(self.args) == 0, "int can't have arguments" # integers must be in the range (MIN_INT256, MAX_UINT256) - _check(-(2**255) <= self.value < 2**256, "out of range") + assert -(2**255) <= self.value < 2**256, "out of range" self.valency = 1 self._gas = 5 elif isinstance(self.value, bytes): # a literal bytes value, probably inside a "data" node. - _check(len(self.args) == 0, "bytes can't have arguments") + assert len(self.args) == 0, "bytes can't have arguments" self.valency = 0 self._gas = 0 @@ -232,10 +228,9 @@ def _check(condition, err): if self.value.upper() in get_ir_opcodes(): _, ins, outs, gas = get_ir_opcodes()[self.value.upper()] self.valency = outs - _check( - len(self.args) == ins, - f"Number of arguments mismatched: {self.value} {self.args}", - ) + assert ( + len(self.args) == ins + ), f"Number of arguments mismatched: {self.value} {self.args}" # We add 2 per stack height at push time and take it back # at pop time; this makes `break` easier to handle self._gas = gas + 2 * (outs - ins) @@ -244,10 +239,10 @@ def _check(condition, err): # consumed for internal functions, therefore we whitelist this as a zero valency # allowed argument. zero_valency_whitelist = {"pass", "pop"} - _check( - arg.valency == 1 or arg.value in zero_valency_whitelist, - f"invalid argument to `{self.value}`: {arg}", - ) + assert ( + arg.valency == 1 or arg.value in zero_valency_whitelist + ), f"invalid argument to `{self.value}`: {arg}" + self._gas += arg.gas # Dynamic gas cost: 8 gas for each byte of logging data if self.value.upper()[0:3] == "LOG" and isinstance(self.args[1].value, int): @@ -275,30 +270,27 @@ def _check(condition, err): self._gas = self.args[0].gas + max(self.args[1].gas, self.args[2].gas) + 3 if len(self.args) == 2: self._gas = self.args[0].gas + self.args[1].gas + 17 - _check( - self.args[0].valency > 0, - f"zerovalent argument as a test to an if statement: {self.args[0]}", - ) - _check(len(self.args) in (2, 3), "if statement can only have 2 or 3 arguments") + assert ( + self.args[0].valency > 0 + ), f"zerovalent argument as a test to an if statement: {self.args[0]}" + assert len(self.args) in (2, 3), "if statement can only have 2 or 3 arguments" self.valency = self.args[1].valency # With statements: with elif self.value == "with": - _check(len(self.args) == 3, self) - _check( - len(self.args[0].args) == 0 and isinstance(self.args[0].value, str), - f"first argument to with statement must be a variable name: {self.args[0]}", - ) - _check( - self.args[1].valency == 1 or self.args[1].value == "pass", - f"zerovalent argument to with statement: {self.args[1]}", - ) + assert len(self.args) == 3, self + assert len(self.args[0].args) == 0 and isinstance( + self.args[0].value, str + ), f"first argument to with statement must be a variable name: {self.args[0]}" + assert ( + self.args[1].valency == 1 or self.args[1].value == "pass" + ), f"zerovalent argument to with statement: {self.args[1]}" self.valency = self.args[2].valency self._gas = sum([arg.gas for arg in self.args]) + 5 # Repeat statements: repeat elif self.value == "repeat": - _check( - len(self.args) == 5, "repeat(index_name, startval, rounds, rounds_bound, body)" - ) + assert ( + len(self.args) == 5 + ), "repeat(index_name, startval, rounds, rounds_bound, body)" counter_ptr = self.args[0] start = self.args[1] @@ -306,13 +298,12 @@ def _check(condition, err): repeat_bound = self.args[3] body = self.args[4] - _check( - isinstance(repeat_bound.value, int) and repeat_bound.value > 0, - f"repeat bound must be a compile-time positive integer: {self.args[2]}", - ) - _check(repeat_count.valency == 1, repeat_count) - _check(counter_ptr.valency == 1, counter_ptr) - _check(start.valency == 1, start) + assert ( + isinstance(repeat_bound.value, int) and repeat_bound.value > 0 + ), f"repeat bound must be a compile-time positive integer: {self.args[2]}" + assert repeat_count.valency == 1, repeat_count + assert counter_ptr.valency == 1, counter_ptr + assert start.valency == 1, start self.valency = 0 @@ -335,19 +326,17 @@ def _check(condition, err): # then JUMP to my_label. elif self.value in ("goto", "exit_to"): for arg in self.args: - _check( - arg.valency == 1 or arg.value == "pass", - f"zerovalent argument to goto {arg}", - ) + assert ( + arg.valency == 1 or arg.value == "pass" + ), f"zerovalent argument to goto {arg}" self.valency = 0 self._gas = sum([arg.gas for arg in self.args]) elif self.value == "label": - _check( - self.args[1].value == "var_list", - f"2nd argument to label must be var_list, {self}", - ) - _check(len(args) == 3, f"label should have 3 args but has {len(args)}, {self}") + assert ( + self.args[1].value == "var_list" + ), f"2nd argument to label must be var_list, {self}" + assert len(args) == 3, f"label should have 3 args but has {len(args)}, {self}" self.valency = 0 self._gas = 1 + sum(t.gas for t in self.args) elif self.value == "unique_symbol": @@ -371,14 +360,14 @@ def _check(condition, err): # Multi statements: multi ... elif self.value == "multi": for arg in self.args: - _check( - arg.valency > 0, f"Multi expects all children to not be zerovalent: {arg}" - ) + assert ( + arg.valency > 0 + ), f"Multi expects all children to not be zerovalent: {arg}" self.valency = sum([arg.valency for arg in self.args]) self._gas = sum([arg.gas for arg in self.args]) elif self.value == "deploy": self.valency = 0 - _check(len(self.args) == 3, f"`deploy` should have three args {self}") + assert len(self.args) == 3, f"`deploy` should have three args {self}" self._gas = NullAttractor() # unknown # Stack variables else: From b16ab914fc6126894e19172ba08df0193653edab Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 21 Nov 2023 09:16:07 -0500 Subject: [PATCH 03/18] docs: add script to help working on the compiler (#3674) --- README.md | 17 +++++++++++++++++ docs/contributing.rst | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bad929956d..33c4557cc8 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,23 @@ make dev-init python setup.py test ``` +## Developing (working on the compiler) + +A useful script to have in your PATH is something like the following: +```bash +$ cat ~/.local/bin/vyc +#!/usr/bin/env bash +PYTHONPATH=. python vyper/cli/vyper_compile.py "$@" +``` + +To run a python performance profile (to find compiler perf hotspots): +```bash +PYTHONPATH=. python -m cProfile -s tottime vyper/cli/vyper_compile.py "$@" +``` + +To get a call graph from a python profile, https://stackoverflow.com/a/23164271/ is helpful. + + # Contributing * See Issues tab, and feel free to submit your own issues * Add PRs if you discover a solution to an existing issue diff --git a/docs/contributing.rst b/docs/contributing.rst index 6dc57b26c3..221600f930 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -75,4 +75,4 @@ If you are making a larger change, please consult first with the `Vyper (Smart C Although we do CI testing, please make sure that the tests pass for supported Python version and ensure that it builds locally before submitting a pull request. -Thank you for your help! ​ +Thank you for your help! From aa1ea21a79e577227e13b9756a8c26107c5b3674 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 23 Nov 2023 03:02:47 +0800 Subject: [PATCH 04/18] refactor: builtin functions inherit from `VyperType` (#3559) for consistency, have builtin functions inherit from `VyperType`. --- vyper/builtins/_signatures.py | 25 +++++----- vyper/builtins/functions.py | 91 +++++++++++++++++------------------ 2 files changed, 56 insertions(+), 60 deletions(-) diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 2802421129..a5949dfd85 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,5 +1,5 @@ import functools -from typing import Dict +from typing import Any, Optional from vyper.ast import nodes as vy_ast from vyper.ast.validation import validate_call_args @@ -74,12 +74,14 @@ def decorator_fn(self, node, context): return decorator_fn -class BuiltinFunction(VyperType): +class BuiltinFunctionT(VyperType): _has_varargs = False - _kwargs: Dict[str, KwargSettings] = {} + _inputs: list[tuple[str, Any]] = [] + _kwargs: dict[str, KwargSettings] = {} + _return_type: Optional[VyperType] = None # helper function to deal with TYPE_DEFINITIONs - def _validate_single(self, arg, expected_type): + def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: # TODO using "TYPE_DEFINITION" is a kludge in derived classes, # refactor me. if expected_type == "TYPE_DEFINITION": @@ -89,15 +91,15 @@ def _validate_single(self, arg, expected_type): else: validate_expected_type(arg, expected_type) - def _validate_arg_types(self, node): + def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates - expect_num_args = num_args + expect_num_args: Any = num_args if self._has_varargs: # note special meaning for -1 in validate_call_args API expect_num_args = (num_args, -1) - validate_call_args(node, expect_num_args, self._kwargs) + validate_call_args(node, expect_num_args, list(self._kwargs.keys())) for arg, (_, expected) in zip(node.args, self._inputs): self._validate_single(arg, expected) @@ -118,13 +120,12 @@ def _validate_arg_types(self, node): # ensures the type can be inferred exactly. get_exact_type_from_node(arg) - def fetch_call_return(self, node): + def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) - if self._return_type: - return self._return_type + return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node: vy_ast.Call) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] @@ -136,7 +137,7 @@ def infer_arg_types(self, node): ret.extend(get_exact_type_from_node(arg) for arg in varargs) return ret - def infer_kwarg_types(self, node): + def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]: return {i.arg: self._kwargs[i.arg].typ for i in node.keywords} def __repr__(self): diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 001939638b..b2d817ec5c 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -98,14 +98,14 @@ ) from ._convert import convert -from ._signatures import BuiltinFunction, process_inputs +from ._signatures import BuiltinFunctionT, process_inputs SHA256_ADDRESS = 2 SHA256_BASE_GAS = 60 SHA256_PER_WORD_GAS = 12 -class FoldedFunction(BuiltinFunction): +class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded # Since foldable builtin functions are not folded before semantics validation, @@ -113,7 +113,7 @@ class FoldedFunction(BuiltinFunction): _kwargable = True -class TypenameFoldedFunction(FoldedFunction): +class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. @@ -132,7 +132,7 @@ def infer_arg_types(self, node): return [input_typedef] -class Floor(BuiltinFunction): +class Floor(BuiltinFunctionT): _id = "floor" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -162,7 +162,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Ceil(BuiltinFunction): +class Ceil(BuiltinFunctionT): _id = "ceil" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -192,7 +192,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Convert(BuiltinFunction): +class Convert(BuiltinFunctionT): _id = "convert" def fetch_call_return(self, node): @@ -285,14 +285,13 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: # note: this and a lot of other builtins could be refactored to accept any uint type -class Slice(BuiltinFunction): +class Slice(BuiltinFunctionT): _id = "slice" _inputs = [ ("b", (BYTES32_T, BytesT.any(), StringT.any())), ("start", UINT256_T), ("length", UINT256_T), ] - _return_type = None def fetch_call_return(self, node): arg_type, _, _ = self.infer_arg_types(node) @@ -457,7 +456,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(b3.resolve(ret))) -class Len(BuiltinFunction): +class Len(BuiltinFunctionT): _id = "len" _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T @@ -488,7 +487,7 @@ def build_IR(self, node, context): return get_bytearray_length(arg) -class Concat(BuiltinFunction): +class Concat(BuiltinFunctionT): _id = "concat" def fetch_call_return(self, node): @@ -593,7 +592,7 @@ def build_IR(self, expr, context): ) -class Keccak256(BuiltinFunction): +class Keccak256(BuiltinFunctionT): _id = "keccak256" # TODO allow any BytesM_T _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] @@ -641,7 +640,7 @@ def _make_sha256_call(inp_start, inp_len, out_start, out_len): ] -class Sha256(BuiltinFunction): +class Sha256(BuiltinFunctionT): _id = "sha256" _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T @@ -713,7 +712,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class MethodID(FoldedFunction): +class MethodID(FoldedFunctionT): _id = "method_id" def evaluate(self, node): @@ -753,7 +752,7 @@ def infer_kwarg_types(self, node): return BytesT(4) -class ECRecover(BuiltinFunction): +class ECRecover(BuiltinFunctionT): _id = "ecrecover" _inputs = [ ("hash", BYTES32_T), @@ -788,7 +787,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class _ECArith(BuiltinFunction): +class _ECArith(BuiltinFunctionT): @process_inputs def build_IR(self, expr, _args, kwargs, context): args_tuple = ir_tuple_from_args(_args) @@ -847,14 +846,13 @@ def _storage_element_getter(index): return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T) -class Extract32(BuiltinFunction): +class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] # "TYPE_DEFINITION" is a placeholder value for a type definition string, and # will be replaced by a `TYPE_T` object in `infer_kwarg_types` # (note that it is ignored in _validate_arg_types) _kwargs = {"output_type": KwargSettings("TYPE_DEFINITION", BYTES32_T)} - _return_type = None def fetch_call_return(self, node): self._validate_arg_types(node) @@ -959,7 +957,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(clamp_basetype(o), typ=ret_type) -class AsWeiValue(BuiltinFunction): +class AsWeiValue(BuiltinFunctionT): _id = "as_wei_value" _inputs = [("value", (IntegerT.any(), DecimalT())), ("unit", StringT.any())] _return_type = UINT256_T @@ -1058,7 +1056,7 @@ def build_IR(self, expr, args, kwargs, context): empty_value = IRnode.from_list(0, typ=BYTES32_T) -class RawCall(BuiltinFunction): +class RawCall(BuiltinFunctionT): _id = "raw_call" _inputs = [("to", AddressT()), ("data", BytesT.any())] _kwargs = { @@ -1069,7 +1067,6 @@ class RawCall(BuiltinFunction): "is_static_call": KwargSettings(BoolT(), False, require_literal=True), "revert_on_failure": KwargSettings(BoolT(), True, require_literal=True), } - _return_type = None def fetch_call_return(self, node): self._validate_arg_types(node) @@ -1215,12 +1212,11 @@ def build_IR(self, expr, args, kwargs, context): raise CompilerPanic("unreachable!") -class Send(BuiltinFunction): +class Send(BuiltinFunctionT): _id = "send" _inputs = [("to", AddressT()), ("value", UINT256_T)] # default gas stipend is 0 _kwargs = {"gas": KwargSettings(UINT256_T, 0)} - _return_type = None @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -1232,10 +1228,9 @@ def build_IR(self, expr, args, kwargs, context): ) -class SelfDestruct(BuiltinFunction): +class SelfDestruct(BuiltinFunctionT): _id = "selfdestruct" _inputs = [("to", AddressT())] - _return_type = None _is_terminus = True _warned = False @@ -1251,7 +1246,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BlockHash(BuiltinFunction): +class BlockHash(BuiltinFunctionT): _id = "blockhash" _inputs = [("block_num", UINT256_T)] _return_type = BYTES32_T @@ -1264,7 +1259,7 @@ def build_IR(self, expr, args, kwargs, contact): ) -class RawRevert(BuiltinFunction): +class RawRevert(BuiltinFunctionT): _id = "raw_revert" _inputs = [("data", BytesT.any())] _return_type = None @@ -1286,7 +1281,7 @@ def build_IR(self, expr, args, kwargs, context): return b.resolve(IRnode.from_list(["revert", data, len_])) -class RawLog(BuiltinFunction): +class RawLog(BuiltinFunctionT): _id = "raw_log" _inputs = [("topics", DArrayT(BYTES32_T, 4)), ("data", (BYTES32_T, BytesT.any()))] @@ -1337,7 +1332,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BitwiseAnd(BuiltinFunction): +class BitwiseAnd(BuiltinFunctionT): _id = "bitwise_and" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1363,7 +1358,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["and", args[0], args[1]], typ=UINT256_T) -class BitwiseOr(BuiltinFunction): +class BitwiseOr(BuiltinFunctionT): _id = "bitwise_or" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1389,7 +1384,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["or", args[0], args[1]], typ=UINT256_T) -class BitwiseXor(BuiltinFunction): +class BitwiseXor(BuiltinFunctionT): _id = "bitwise_xor" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1415,7 +1410,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["xor", args[0], args[1]], typ=UINT256_T) -class BitwiseNot(BuiltinFunction): +class BitwiseNot(BuiltinFunctionT): _id = "bitwise_not" _inputs = [("x", UINT256_T)] _return_type = UINT256_T @@ -1442,7 +1437,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["not", args[0]], typ=UINT256_T) -class Shift(BuiltinFunction): +class Shift(BuiltinFunctionT): _id = "shift" _inputs = [("x", (UINT256_T, INT256_T)), ("_shift_bits", IntegerT.any())] _return_type = UINT256_T @@ -1496,7 +1491,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(IRnode.from_list(ret, typ=argty))) -class _AddMulMod(BuiltinFunction): +class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T @@ -1537,7 +1532,7 @@ class MulMod(_AddMulMod): _opcode = "mulmod" -class PowMod256(BuiltinFunction): +class PowMod256(BuiltinFunctionT): _id = "pow_mod256" _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T @@ -1560,7 +1555,7 @@ def build_IR(self, expr, context): return IRnode.from_list(["exp", left, right], typ=left.typ) -class Abs(BuiltinFunction): +class Abs(BuiltinFunctionT): _id = "abs" _inputs = [("value", INT256_T)] _return_type = INT256_T @@ -1711,7 +1706,7 @@ def _create_preamble(codesize): return ["or", bytes_to_int(evm), shl(shl_bits, codesize)], evm_len -class _CreateBase(BuiltinFunction): +class _CreateBase(BuiltinFunctionT): _kwargs = { "value": KwargSettings(UINT256_T, zero_value), "salt": KwargSettings(BYTES32_T, empty_value), @@ -1940,7 +1935,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar return b1.resolve(b2.resolve(ir)) -class _UnsafeMath(BuiltinFunction): +class _UnsafeMath(BuiltinFunctionT): # TODO add unsafe math for `decimal`s _inputs = [("a", IntegerT.any()), ("b", IntegerT.any())] @@ -2006,7 +2001,7 @@ class UnsafeDiv(_UnsafeMath): op = "div" -class _MinMax(BuiltinFunction): +class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] def evaluate(self, node): @@ -2080,7 +2075,7 @@ class Max(_MinMax): _opcode = "gt" -class Uint2Str(BuiltinFunction): +class Uint2Str(BuiltinFunctionT): _id = "uint2str" _inputs = [("x", IntegerT.unsigneds())] @@ -2152,7 +2147,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, location=MEMORY, typ=return_t)) -class Sqrt(BuiltinFunction): +class Sqrt(BuiltinFunctionT): _id = "sqrt" _inputs = [("d", DecimalT())] _return_type = DecimalT() @@ -2208,7 +2203,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class ISqrt(BuiltinFunction): +class ISqrt(BuiltinFunctionT): _id = "isqrt" _inputs = [("d", UINT256_T)] _return_type = UINT256_T @@ -2258,7 +2253,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, typ=UINT256_T)) -class Empty(TypenameFoldedFunction): +class Empty(TypenameFoldedFunctionT): _id = "empty" def fetch_call_return(self, node): @@ -2273,7 +2268,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode("~empty", typ=output_type) -class Breakpoint(BuiltinFunction): +class Breakpoint(BuiltinFunctionT): _id = "breakpoint" _inputs: list = [] @@ -2291,7 +2286,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list("breakpoint", annotation="breakpoint()") -class Print(BuiltinFunction): +class Print(BuiltinFunctionT): _id = "print" _inputs: list = [] _has_varargs = True @@ -2369,7 +2364,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, annotation="print:" + sig) -class ABIEncode(BuiltinFunction): +class ABIEncode(BuiltinFunctionT): _id = "_abi_encode" # TODO prettier to rename this to abi.encode # signature: *, ensure_tuple= -> Bytes[] # explanation of ensure_tuple: @@ -2486,7 +2481,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, location=MEMORY, typ=buf_t) -class ABIDecode(BuiltinFunction): +class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} @@ -2573,7 +2568,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class _MinMaxValue(TypenameFoldedFunction): +class _MinMaxValue(TypenameFoldedFunctionT): def evaluate(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2607,7 +2602,7 @@ def _eval(self, type_): return type_.ast_bounds[1] -class Epsilon(TypenameFoldedFunction): +class Epsilon(TypenameFoldedFunctionT): _id = "epsilon" def evaluate(self, node): From b334218f855ae94285afe271a770f1f29d20b7df Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 22 Nov 2023 22:57:30 -0500 Subject: [PATCH 05/18] docs: add warnings at the top of all example token contracts (#3676) discourage people from using them in production --- examples/crowdfund.vy | 6 +++++- examples/tokens/ERC1155ownable.vy | 7 ++++++- examples/tokens/ERC20.vy | 6 +++++- examples/tokens/ERC4626.vy | 7 +++++++ examples/tokens/ERC721.vy | 6 +++++- examples/wallet/wallet.vy | 7 +++++-- 6 files changed, 33 insertions(+), 6 deletions(-) diff --git a/examples/crowdfund.vy b/examples/crowdfund.vy index 56b34308f1..6d07e15bc4 100644 --- a/examples/crowdfund.vy +++ b/examples/crowdfund.vy @@ -1,4 +1,8 @@ -# Setup private variables (only callable from within the contract) +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# example of a crowd funding contract funders: HashMap[address, uint256] beneficiary: address diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index f1070b8f89..30057582e8 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -1,8 +1,13 @@ +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + # @version >=0.3.4 """ -@dev Implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) +@dev example implementation of ERC-1155 non-fungible token standard ownable, with approval, OPENSEA compatible (name, symbol) @author Dr. Pixel (github: @Doc-Pixel) """ + ############### imports ############### from vyper.interfaces import ERC165 diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 4c1d334691..c3809dbb60 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -1,4 +1,8 @@ -# @dev Implementation of ERC-20 token standard. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# @dev example implementation of an ERC20 token # @author Takayuki Jimba (@yudetamago) # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-20.md diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index a9cbcc86c8..0a0a698bf0 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -1,4 +1,11 @@ # NOTE: Copied from https://github.com/fubuloubu/ERC4626/blob/1a10b051928b11eeaad15d80397ed36603c2a49b/contracts/VyperVault.vy + +# example implementation of an ERC4626 vault + +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + from vyper.interfaces import ERC20 from vyper.interfaces import ERC4626 diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index 5125040399..152b94b046 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -1,4 +1,8 @@ -# @dev Implementation of ERC-721 non-fungible token standard. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# @dev example implementation of ERC-721 non-fungible token standard. # @author Ryuya Nakamura (@nrryuya) # Modified from: https://github.com/vyperlang/vyper/blob/de74722bf2d8718cca46902be165f9fe0e3641dd/examples/tokens/ERC721.vy diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index 5fd5229136..e2515d9e62 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -1,5 +1,8 @@ -# An example of how you can do a wallet in Vyper. -# Warning: NOT AUDITED. Do not use to store substantial quantities of funds. +########################################################################### +## THIS IS EXAMPLE CODE, NOT MEANT TO BE USED IN PRODUCTION! CAVEAT EMPTOR! +########################################################################### + +# An example of how you can implement a wallet in Vyper. # A list of the owners addresses (there are a maximum of 5 owners) owners: public(address[5]) From 9a982bd37a8b5a48f9a30939ec57e37ed01a72e0 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 29 Nov 2023 00:54:19 +0900 Subject: [PATCH 06/18] docs: typo in on_chain_market_maker.vy (#3677) --- examples/market_maker/on_chain_market_maker.vy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/market_maker/on_chain_market_maker.vy b/examples/market_maker/on_chain_market_maker.vy index be9c62b945..d385d2e0c6 100644 --- a/examples/market_maker/on_chain_market_maker.vy +++ b/examples/market_maker/on_chain_market_maker.vy @@ -9,7 +9,7 @@ invariant: public(uint256) token_address: ERC20 owner: public(address) -# Sets the on chain market maker with its owner, intial token quantity, +# Sets the on chain market maker with its owner, initial token quantity, # and initial ether quantity @external @payable From cbac5aba53f87b388e08f169481d6b5c29002c27 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Fri, 1 Dec 2023 21:41:57 +0200 Subject: [PATCH 07/18] feat: implement new IR for vyper (venom IR) (#3659) this commit implements a new IR for the vyper compiler. most of the implementation is self-contained in the `./vyper/venom/` directory. Venom IR is LLVM-"inspired", although we do not use LLVM on account of: 1) not wanting to introduce a large external dependency 2) no EVM backend exists for LLVM, so we would have to write one ourselves. see prior work at https://github.com/etclabscore/evm_llvm. fundamentally, LLVM is architected to target register machines; an EVM backend could conceivably be implmented, but it would always feel "bolted" on. 3) integration with LLVM would invariably be very complex 4) one advantage of using LLVM is getting multiple backends "for free", but in our case, none of the backends we are interested in (particularly EVM) have LLVM implementations. that being said, Venom is close enough to LLVM that it would seem fairly straightforward to pass "in-and-out" of LLVM, converting to LLVM to take advantage of its optimization passes and/or analysis utilities, and then converting back to Venom for final EVM emission, if that becomes desirable down the line. it could even provided as an "extra" -- if LLVM is installed on the system and enabled for the build, pass to LLVM for extra optimization, but otherwise the compiler being self-contained. for more details about the design and architecture of Venom IR, see `./vyper/venom/README.md`. note that this commit specifically focuses on the architecture, design and implementation of Venom. that is, more focus was spent on architecting the Venom compiler itself. the Vyper frontend does not emit Venom natively yet, Venom emission is implemented as a translation step from the current s-expr based IR to Venom. the translation is not feature-complete, and may have bugs. that being said, vyper compilation via Venom is experimentally available by passing the `--experimental-codegen` flag to vyper on the CLI. incrementally refactoring the codegen to use Venom instead of the earlier s-expr IR will be the next area of focus of development. --------- Co-authored-by: Charles Cooper --- .../compiler/venom/test_duplicate_operands.py | 28 + .../compiler/venom/test_multi_entry_block.py | 96 ++ .../venom/test_stack_at_external_return.py | 5 + vyper/cli/vyper_compile.py | 8 + vyper/codegen/function_definitions/common.py | 4 + .../function_definitions/internal_function.py | 4 +- vyper/codegen/ir_node.py | 16 + vyper/codegen/return_.py | 4 +- vyper/codegen/self_call.py | 2 + vyper/compiler/__init__.py | 2 + vyper/compiler/phases.py | 28 +- vyper/ir/compile_ir.py | 80 +- vyper/ir/optimizer.py | 4 + vyper/semantics/types/function.py | 2 +- vyper/utils.py | 62 +- vyper/venom/README.md | 162 +++ vyper/venom/__init__.py | 56 ++ vyper/venom/analysis.py | 191 ++++ vyper/venom/basicblock.py | 345 +++++++ vyper/venom/bb_optimizer.py | 73 ++ vyper/venom/function.py | 170 ++++ vyper/venom/ir_node_to_venom.py | 943 ++++++++++++++++++ vyper/venom/passes/base_pass.py | 21 + vyper/venom/passes/constant_propagation.py | 13 + vyper/venom/passes/dft.py | 54 + vyper/venom/passes/normalization.py | 90 ++ vyper/venom/stack_model.py | 100 ++ vyper/venom/venom_to_assembly.py | 461 +++++++++ 28 files changed, 2994 insertions(+), 30 deletions(-) create mode 100644 tests/compiler/venom/test_duplicate_operands.py create mode 100644 tests/compiler/venom/test_multi_entry_block.py create mode 100644 tests/compiler/venom/test_stack_at_external_return.py create mode 100644 vyper/venom/README.md create mode 100644 vyper/venom/__init__.py create mode 100644 vyper/venom/analysis.py create mode 100644 vyper/venom/basicblock.py create mode 100644 vyper/venom/bb_optimizer.py create mode 100644 vyper/venom/function.py create mode 100644 vyper/venom/ir_node_to_venom.py create mode 100644 vyper/venom/passes/base_pass.py create mode 100644 vyper/venom/passes/constant_propagation.py create mode 100644 vyper/venom/passes/dft.py create mode 100644 vyper/venom/passes/normalization.py create mode 100644 vyper/venom/stack_model.py create mode 100644 vyper/venom/venom_to_assembly.py diff --git a/tests/compiler/venom/test_duplicate_operands.py b/tests/compiler/venom/test_duplicate_operands.py new file mode 100644 index 0000000000..505f01e31b --- /dev/null +++ b/tests/compiler/venom/test_duplicate_operands.py @@ -0,0 +1,28 @@ +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRFunction + + +def test_duplicate_operands(): + """ + Test the duplicate operands code generation. + The venom code: + + %1 = 10 + %2 = add %1, %1 + %3 = mul %1, %2 + stop + + Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] + """ + ctx = IRFunction() + + op = ctx.append_instruction("store", [IRLiteral(10)]) + sum = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("mul", [sum, op]) + ctx.append_instruction("stop", [], False) + + asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) + + assert asm == ["PUSH1", 10, "DUP1", "DUP1", "DUP1", "ADD", "MUL", "STOP", "REVERT"] diff --git a/tests/compiler/venom/test_multi_entry_block.py b/tests/compiler/venom/test_multi_entry_block.py new file mode 100644 index 0000000000..bb57fa1065 --- /dev/null +++ b/tests/compiler/venom/test_multi_entry_block.py @@ -0,0 +1,96 @@ +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRLiteral +from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel +from vyper.venom.passes.normalization import NormalizationPass + + +def test_multi_entry_block_1(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + + +# more complicated one +def test_multi_entry_block_2(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + block_2_label = IRLabel("block_2", ctx) + + op = ctx.append_instruction("store", [IRLiteral(10)]) + acc = ctx.append_instruction("add", [op, op]) + ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + ctx.append_instruction("jnz", [acc, target_label, finish_label], False) + + block_2 = IRBasicBlock(block_2_label, ctx) + ctx.append_basic_block(block_2) + acc = ctx.append_instruction("add", [acc, op]) + op = ctx.append_instruction("store", [IRLiteral(10)]) + ctx.append_instruction("mstore", [acc, op], False) + # switch the order of the labels, for fun + ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + ctx.append_instruction("mul", [acc, acc]) + ctx.append_instruction("jmp", [finish_label], False) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + ctx.append_instruction("stop", [], False) + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/tests/compiler/venom/test_stack_at_external_return.py b/tests/compiler/venom/test_stack_at_external_return.py new file mode 100644 index 0000000000..be9fa66e9a --- /dev/null +++ b/tests/compiler/venom/test_stack_at_external_return.py @@ -0,0 +1,5 @@ +def test_stack_at_external_return(): + """ + TODO: USE BOA DO GENERATE THIS TEST + """ + pass diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 82eba63f32..ca1792384e 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -141,6 +141,11 @@ def _parse_args(argv): "-p", help="Set the root path for contract imports", default=".", dest="root_folder" ) parser.add_argument("-o", help="Set the output path", dest="output_path") + parser.add_argument( + "--experimental-codegen", + help="The compiler use the new IR codegen. This is an experimental feature.", + action="store_true", + ) args = parser.parse_args(argv) @@ -188,6 +193,7 @@ def _parse_args(argv): settings, args.storage_layout, args.no_bytecode_metadata, + args.experimental_codegen, ) if args.output_path: @@ -225,6 +231,7 @@ def compile_files( settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> dict: root_path = Path(root_folder).resolve() if not root_path.exists(): @@ -275,6 +282,7 @@ def compile_files( storage_layout_override=storage_layout_override, show_gas_estimates=show_gas_estimates, no_bytecode_metadata=no_bytecode_metadata, + experimental_codegen=experimental_codegen, ) ret[file_path] = output diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 1d24b6c6dd..c48f1256c3 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -162,5 +162,9 @@ def generate_ir_for_function( # (note: internal functions do not need to adjust gas estimate since mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore + ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore + ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + else: + ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore return ret diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index 228191e3ca..cf01dbdab4 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -68,4 +68,6 @@ def generate_ir_for_internal_function( ["seq"] + nonreentrant_post + [["exit_to", "return_pc"]], ] - return IRnode.from_list(["seq", body, cleanup_routine]) + ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + ir_node.passthrough_metadata["func_t"] = func_t + return ir_node diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index e17ef47c8f..ce26066968 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -171,6 +171,10 @@ class IRnode: valency: int args: List["IRnode"] value: Union[str, int] + is_self_call: bool + passthrough_metadata: dict[str, Any] + func_ir: Any + common_ir: Any def __init__( self, @@ -184,6 +188,8 @@ def __init__( mutable: bool = True, add_gas_estimate: int = 0, encoding: Encoding = Encoding.VYPER, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, ): if args is None: args = [] @@ -201,6 +207,10 @@ def __init__( self.add_gas_estimate = add_gas_estimate self.encoding = encoding self.as_hex = AS_HEX_DEFAULT + self.is_self_call = is_self_call + self.passthrough_metadata = passthrough_metadata or {} + self.func_ir = None + self.common_ir = None assert self.value is not None, "None is not allowed as IRnode value" @@ -585,6 +595,8 @@ def from_list( error_msg: Optional[str] = None, mutable: bool = True, add_gas_estimate: int = 0, + is_self_call: bool = False, + passthrough_metadata: dict[str, Any] = None, encoding: Encoding = Encoding.VYPER, ) -> "IRnode": if isinstance(typ, str): @@ -617,6 +629,8 @@ def from_list( source_pos=source_pos, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) else: return cls( @@ -630,4 +644,6 @@ def from_list( add_gas_estimate=add_gas_estimate, encoding=encoding, error_msg=error_msg, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index 56bea2b8da..41fa11ab56 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -40,7 +40,9 @@ def finalize(fill_return_buffer): cleanup_loops = "cleanup_repeat" if context.forvars else "seq" # NOTE: because stack analysis is incomplete, cleanup_repeat must # come after fill_return_buffer otherwise the stack will break - return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit]) + jump_to_exit_ir = IRnode.from_list(jump_to_exit) + jump_to_exit_ir.passthrough_metadata["func_t"] = func_t + return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit_ir]) if context.return_type is None: if context.is_internal: diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index c320e6889c..f03f2eb9c8 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -121,4 +121,6 @@ def ir_for_self_call(stmt_expr, context): add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True + o.passthrough_metadata["func_t"] = func_t + o.passthrough_metadata["args_ir"] = args_ir return o diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 62ea05b243..61d7a7c229 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -55,6 +55,7 @@ def compile_code( no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, exc_handler: Optional[Callable] = None, + experimental_codegen: bool = False, ) -> dict: """ Generate consumable compiler output(s) from a single contract source code. @@ -104,6 +105,7 @@ def compile_code( storage_layout_override, show_gas_estimates, no_bytecode_metadata, + experimental_codegen, ) ret = {} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index bfbb336d54..4e32812fee 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -16,6 +16,7 @@ from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT from vyper.typing import StorageLayout +from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") @@ -60,6 +61,7 @@ def __init__( storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, + experimental_codegen: bool = False, ) -> None: """ Initialization method. @@ -78,14 +80,18 @@ def __init__( Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False + experimental_codegen: bool, optional + Use experimental codegen. Defaults to False """ + # to force experimental codegen, uncomment: + # experimental_codegen = True self.contract_path = contract_path self.source_code = source_code self.source_id = source_id self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - + self.experimental_codegen = experimental_codegen self.settings = settings or Settings() self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) @@ -160,7 +166,11 @@ def global_ctx(self) -> GlobalContext: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - return generate_ir_nodes(self.global_ctx, self.settings.optimize) + nodes = generate_ir_nodes(self.global_ctx, self.settings.optimize) + if self.experimental_codegen: + return [generate_ir(nodes[0]), generate_ir(nodes[1])] + else: + return nodes @property def ir_nodes(self) -> IRnode: @@ -183,11 +193,21 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - return generate_assembly(self.ir_nodes, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_nodes, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_nodes, self.settings.optimize) @cached_property def assembly_runtime(self) -> list: - return generate_assembly(self.ir_runtime, self.settings.optimize) + if self.experimental_codegen: + return generate_assembly_experimental( + self.ir_runtime, self.settings.optimize # type: ignore + ) + else: + return generate_assembly(self.ir_runtime, self.settings.optimize) @cached_property def bytecode(self) -> bytes: diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 1c4dc1ef7c..1d3df8becb 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -9,6 +9,7 @@ from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic +from vyper.ir.optimizer import COMMUTATIVE_OPS from vyper.utils import MemoryPositions from vyper.version import version_tuple @@ -164,7 +165,7 @@ def _add_postambles(asm_ops): # insert the postambles *before* runtime code # so the data section of the runtime code can't bork the postambles. runtime = None - if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader): + if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], RuntimeHeader): runtime = asm_ops.pop() # for some reason there might not be a STOP at the end of asm_ops. @@ -229,7 +230,7 @@ def compile_to_assembly(code, optimize=OptimizationLevel.GAS): _relocate_segments(res) if optimize != OptimizationLevel.NONE: - _optimize_assembly(res) + optimize_assembly(res) return res @@ -531,7 +532,7 @@ def _height_of(witharg): # since the asm data structures are very primitive, to make sure # assembly_to_evm is able to calculate data offsets correctly, # we pass the memsize via magic opcodes to the subcode - subcode = [_RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode + subcode = [RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode # append the runtime code after the ctor code # `append(...)` call here is intentional. @@ -675,7 +676,7 @@ def _height_of(witharg): ) elif code.value == "data": - data_node = [_DataHeader("_sym_" + code.args[0].value)] + data_node = [DataHeader("_sym_" + code.args[0].value)] for c in code.args[1:]: if isinstance(c.value, int): @@ -837,6 +838,31 @@ def _prune_inefficient_jumps(assembly): return changed +def _optimize_inefficient_jumps(assembly): + # optimize sequences `_sym_common JUMPI _sym_x JUMP _sym_common JUMPDEST` + # to `ISZERO _sym_x JUMPI _sym_common JUMPDEST` + changed = False + i = 0 + while i < len(assembly) - 6: + if ( + is_symbol(assembly[i]) + and assembly[i + 1] == "JUMPI" + and is_symbol(assembly[i + 2]) + and assembly[i + 3] == "JUMP" + and assembly[i] == assembly[i + 4] + and assembly[i + 5] == "JUMPDEST" + ): + changed = True + assembly[i] = "ISZERO" + assembly[i + 1] = assembly[i + 2] + assembly[i + 2] = "JUMPI" + del assembly[i + 3 : i + 4] + else: + i += 1 + + return changed + + def _merge_jumpdests(assembly): # When we have multiple JUMPDESTs in a row, or when a JUMPDEST # is immediately followed by another JUMP, we can skip the @@ -938,7 +964,7 @@ def _prune_unused_jumpdests(assembly): used_jumpdests.add(assembly[i]) for item in assembly: - if isinstance(item, list) and isinstance(item[0], _DataHeader): + if isinstance(item, list) and isinstance(item[0], DataHeader): # add symbols used in data sections as they are likely # used for a jumptable. for t in item: @@ -961,6 +987,12 @@ def _stack_peephole_opts(assembly): changed = False i = 0 while i < len(assembly) - 2: + if assembly[i : i + 3] == ["DUP1", "SWAP2", "SWAP1"]: + changed = True + del assembly[i + 2] + assembly[i] = "SWAP1" + assembly[i + 1] = "DUP2" + continue # usually generated by with statements that return their input like # (with x (...x)) if assembly[i : i + 3] == ["DUP1", "SWAP1", "POP"]: @@ -975,16 +1007,22 @@ def _stack_peephole_opts(assembly): changed = True del assembly[i] continue + if assembly[i : i + 2] == ["SWAP1", "SWAP1"]: + changed = True + del assembly[i : i + 2] + if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: + changed = True + del assembly[i] i += 1 return changed # optimize assembly, in place -def _optimize_assembly(assembly): +def optimize_assembly(assembly): for x in assembly: - if isinstance(x, list) and isinstance(x[0], _RuntimeHeader): - _optimize_assembly(x) + if isinstance(x, list) and isinstance(x[0], RuntimeHeader): + optimize_assembly(x) for _ in range(1024): changed = False @@ -993,6 +1031,7 @@ def _optimize_assembly(assembly): changed |= _merge_iszero(assembly) changed |= _merge_jumpdests(assembly) changed |= _prune_inefficient_jumps(assembly) + changed |= _optimize_inefficient_jumps(assembly) changed |= _prune_unused_jumpdests(assembly) changed |= _stack_peephole_opts(assembly) @@ -1021,7 +1060,7 @@ def adjust_pc_maps(pc_maps, ofst): def _data_to_evm(assembly, symbol_map): ret = bytearray() - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big") @@ -1039,7 +1078,7 @@ def _data_to_evm(assembly, symbol_map): # predict what length of an assembly [data] node will be in bytecode def _length_of_data(assembly): ret = 0 - assert isinstance(assembly[0], _DataHeader) + assert isinstance(assembly[0], DataHeader) for item in assembly[1:]: if is_symbol(item): ret += SYMBOL_SIZE @@ -1055,7 +1094,7 @@ def _length_of_data(assembly): @dataclass -class _RuntimeHeader: +class RuntimeHeader: label: str ctor_mem_size: int immutables_len: int @@ -1065,7 +1104,7 @@ def __repr__(self): @dataclass -class _DataHeader: +class DataHeader: label: str def __repr__(self): @@ -1081,11 +1120,11 @@ def _relocate_segments(assembly): code_segments = [] for t in assembly: if isinstance(t, list): - if isinstance(t[0], _DataHeader): + if isinstance(t[0], DataHeader): data_segments.append(t) else: _relocate_segments(t) # recurse - assert isinstance(t[0], _RuntimeHeader) + assert isinstance(t[0], RuntimeHeader) code_segments.append(t) else: non_data_segments.append(t) @@ -1134,7 +1173,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat mem_ofst_size, ctor_mem_size = None, None max_mem_ofst = 0 for i, item in enumerate(assembly): - if isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + if isinstance(item, list) and isinstance(item[0], RuntimeHeader): assert runtime_code is None, "Multiple subcodes" assert ctor_mem_size is None @@ -1184,6 +1223,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat if is_symbol_map_indicator(assembly[i + 1]): # Don't increment pc as the symbol itself doesn't go into code if item in symbol_map: + print(assembly) raise CompilerPanic(f"duplicate jumpdest {item}") symbol_map[item] = pc @@ -1198,7 +1238,7 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat # [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar) # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) pc -= 1 - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): # we are in initcode symbol_map[item[0].label] = pc # add source map for all items in the runtime map @@ -1209,10 +1249,10 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat pc += len(runtime_code) # grab lengths of data sections from the runtime for t in item: - if isinstance(t, list) and isinstance(t[0], _DataHeader): + if isinstance(t, list) and isinstance(t[0], DataHeader): data_section_lengths.append(_length_of_data(t)) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): symbol_map[item[0].label] = pc pc += _length_of_data(item) else: @@ -1285,9 +1325,9 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat ret.append(DUP_OFFSET + int(item[3:])) elif item[:4] == "SWAP": ret.append(SWAP_OFFSET + int(item[4:])) - elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader): + elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): ret.extend(runtime_code) - elif isinstance(item, list) and isinstance(item[0], _DataHeader): + elif isinstance(item, list) and isinstance(item[0], DataHeader): ret.extend(_data_to_evm(item, symbol_map)) else: # pragma: no cover # unreachable diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py index 8df4bbac2d..79e02f041d 100644 --- a/vyper/ir/optimizer.py +++ b/vyper/ir/optimizer.py @@ -440,6 +440,8 @@ def _optimize(node: IRnode, parent: Optional[IRnode]) -> Tuple[bool, IRnode]: error_msg = node.error_msg annotation = node.annotation add_gas_estimate = node.add_gas_estimate + is_self_call = node.is_self_call + passthrough_metadata = node.passthrough_metadata changed = False @@ -462,6 +464,8 @@ def finalize(val, args): error_msg=error_msg, annotation=annotation, add_gas_estimate=add_gas_estimate, + is_self_call=is_self_call, + passthrough_metadata=passthrough_metadata, ) if should_check_symbols: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 77b9efb13d..140f73f095 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -93,7 +93,7 @@ def __init__( self.nonreentrant = nonreentrant # a list of internal functions this function calls - self.called_functions = OrderedSet() + self.called_functions = OrderedSet[ContractFunctionT]() # to be populated during codegen self._ir_info: Any = None diff --git a/vyper/utils.py b/vyper/utils.py index 3d9d9cb416..0a2e1f831f 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -6,12 +6,14 @@ import time import traceback import warnings -from typing import List, Union +from typing import Generic, List, TypeVar, Union from vyper.exceptions import DecimalOverrideException, InvalidLiteral +_T = TypeVar("_T") -class OrderedSet(dict): + +class OrderedSet(Generic[_T], dict[_T, None]): """ a minimal "ordered set" class. this is needed in some places because, while dict guarantees you can recover insertion order @@ -20,9 +22,41 @@ class OrderedSet(dict): functionality as needed. """ - def add(self, item): + def __init__(self, iterable=None): + super().__init__() + if iterable is not None: + for item in iterable: + self.add(item) + + def __repr__(self): + keys = ", ".join(repr(k) for k in self.keys()) + return f"{{{keys}}}" + + def get(self, *args, **kwargs): + raise RuntimeError("can't call get() on OrderedSet!") + + def add(self, item: _T) -> None: self[item] = None + def remove(self, item: _T) -> None: + del self[item] + + def difference(self, other): + ret = self.copy() + for k in other.keys(): + if k in ret: + ret.remove(k) + return ret + + def union(self, other): + return self | other + + def __or__(self, other): + return self.__class__(super().__or__(other)) + + def copy(self): + return self.__class__(super().copy()) + class DecimalContextOverride(decimal.Context): def __setattr__(self, name, value): @@ -436,3 +470,25 @@ def annotate_source_code( cleanup_lines += [""] * (num_lines - len(cleanup_lines)) return "\n".join(cleanup_lines) + + +def ir_pass(func): + """ + Decorator for IR passes. This decorator will run the pass repeatedly until + no more changes are made. + """ + + def wrapper(*args, **kwargs): + count = 0 + + while True: + changes = func(*args, **kwargs) or 0 + if isinstance(changes, list) or isinstance(changes, set): + changes = len(changes) + count += changes + if changes == 0: + break + + return count + + return wrapper diff --git a/vyper/venom/README.md b/vyper/venom/README.md new file mode 100644 index 0000000000..a81f6c0582 --- /dev/null +++ b/vyper/venom/README.md @@ -0,0 +1,162 @@ +## Venom - An Intermediate representation language for Vyper + +### Introduction + +Venom serves as the next-gen intermediate representation language specifically tailored for use with the Vyper smart contract compiler. Drawing inspiration from LLVM IR, Venom has been adapted to be simpler, and to be architected towards emitting code for stack-based virtual machines. Designed with a Single Static Assignment (SSA) form, Venom allows for sophisticated analysis and optimizations, while accommodating the idiosyncrasies of the EVM architecture. + +### Venom Form + +In Venom, values are denoted as strings commencing with the `'%'` character, referred to as variables. Variables can only be assigned to at declaration (they remain immutable post-assignment). Constants are represented as decimal numbers (hexadecimal may be added in the future). + +Reserved words include all the instruction opcodes and `'IRFunction'`, `'param'`, `'dbname'` and `'db'`. + +Any content following the `';'` character until the line end is treated as a comment. + +For instance, an example of incrementing a variable by one is as follows: + +```llvm +%sum = add %x, 1 ; Add one to x +``` + +Each instruction is identified by its opcode and a list of input operands. In cases where an instruction produces a result, it is stored in a new variable, as indicated on the left side of the assignment character. + +Code is organized into non-branching instruction blocks, known as _"Basic Blocks"_. Each basic block is defined by a label and contains its set of instructions. The final instruction of a basic block should either be a terminating instruction or a jump (conditional or unconditional) to other block(s). + +Basic blocks are grouped into _functions_ that are named and dictate the first block to execute. + +Venom employs two scopes: global and function level. + +### Example code + +```llvm +IRFunction: global + +global: + %1 = calldataload 0 + %2 = shr 224, %1 + jmp label %selector_bucket_0 + +selector_bucket_0: + %3 = xor %2, 1579456981 + %4 = iszero %3 + jnz label %1, label %2, %4 + +1: IN=[selector_bucket_0] OUT=[9] + jmp label %fallback + +2: + %5 = callvalue + %6 = calldatasize + %7 = lt %6, 164 + %8 = or %5, %7 + %9 = iszero %8 + assert %9 + stop + +fallback: + revert 0, 0 +``` + +### Grammar + +Below is a (not-so-complete) grammar to describe the text format of Venom IR: + +```llvm +program ::= function_declaration* + +function_declaration ::= "IRFunction:" identifier input_list? output_list? "=>" block + +input_list ::= "IN=" "[" (identifier ("," identifier)*)? "]" +output_list ::= "OUT=" "[" (identifier ("," identifier)*)? "]" + +block ::= label ":" input_list? output_list? "=>{" operation* "}" + +operation ::= "%" identifier "=" opcode operand ("," operand)* + | opcode operand ("," operand)* + +opcode ::= "calldataload" | "shr" | "shl" | "and" | "add" | "codecopy" | "mload" | "jmp" | "xor" | "iszero" | "jnz" | "label" | "lt" | "or" | "assert" | "callvalue" | "calldatasize" | "alloca" | "calldatacopy" | "invoke" | "gt" | ... + +operand ::= "%" identifier | label | integer | "label" "%" identifier +label ::= "%" identifier + +identifier ::= [a-zA-Z_][a-zA-Z0-9_]* +integer ::= [0-9]+ +``` + +## Implementation + +In the current implementation the compiler was extended to incorporate a new pass responsible for translating the original s-expr based IR into Venom. Subsequently, the generated Venom code undergoes processing by the actual Venom compiler, ultimately converting it to assembly code. That final assembly code is then passed to the original assembler of Vyper to produce the executable bytecode. + +Currently there is no implementation of the text format (that is, there is no front-end), although this is planned. At this time, Venom IR can only be constructed programmatically. + +## Architecture + +The Venom implementation is composed of several distinct passes that iteratively transform and optimize the Venom IR code until it reaches the assembly emitter, which produces the stack-based EVM assembly. The compiler is designed to be more-or-less pluggable, so passes can be written without too much knowledge of or dependency on other passes. + +These passes encompass generic transformations that streamline the code (such as dead code elimination and normalization), as well as those generating supplementary information about the code, like liveness analysis and control-flow graph (CFG) construction. Some passes may rely on the output of others, requiring a specific execution order. For instance, the code emitter expects the execution of a normalization pass preceding it, and this normalization pass, in turn, requires the augmentation of the Venom IR with code flow information. + +The primary categorization of pass types are: + +- Transformation passes +- Analysis/augmentation passes +- Optimization passes + +## Currently implemented passes + +The Venom compiler currently implements the following passes. + +### Control Flow Graph calculation + +The compiler generates a fundamental data structure known as the Control Flow Graph (CFG). This graph illustrates the interconnections between basic blocks, serving as a foundational data structure upon which many subsequent passes depend. + +### Data Flow Graph calculation + +To enable the compiler to analyze the movement of data through the code during execution, a specialized graph, the Dataflow Graph (DFG), is generated. The compiler inspects the code, determining where each variable is defined (in one location) and all the places where it is utilized. + +### Dataflow Transformation + +This pass depends on the DFG construction, and reorders variable declarations to try to reduce stack traffic during instruction selection. + +### Liveness analysis + +This pass conducts a dataflow analysis, utilizing information from previous passes to identify variables that are live at each instruction in the Venom IR code. A variable is deemed live at a particular instruction if it holds a value necessary for future operations. Variables only alive for their assignment instructions are identified here and then eliminated by the dead code elimination pass. + +### Dead code elimination + +This pass eliminates all basic blocks that are not reachable from any other basic block, leveraging the CFG. + +### Normalization + +A Venom program may feature basic blocks with multiple CFG inputs and outputs. This currently can occur when multiple blocks conditionally direct control to the same target basic block. We define a Venom IR as "normalized" when it contains no basic blocks that have multiple inputs and outputs. The normalization pass is responsible for converting any Venom IR program to its normalized form. EVM assembly emission operates solely on normalized Venom programs, because the stack layout is not well defined for non-normalized basic blocks. + +### Code emission + +This final pass of the compiler aims to emit EVM assembly recognized by Vyper's assembler. It calcluates the desired stack layout for every basic block, schedules items on the stack and selects instructions. It ensures that deploy code, runtime code, and data segments are arranged according to the assembler's expectations. + +## Future planned passes + +A number of passes that are planned to be implemented, or are implemented for immediately after the initial PR merge are below. + +### Constant folding + +### Instruction combination + +### Dead store elimination + +### Scalar evolution + +### Loop invariant code motion + +### Loop unrolling + +### Code sinking + +### Expression reassociation + +### Stack to mem + +### Mem to stack + +### Function inlining + +### Load-store elimination diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py new file mode 100644 index 0000000000..5a09f8378e --- /dev/null +++ b/vyper/venom/__init__.py @@ -0,0 +1,56 @@ +# maybe rename this `main.py` or `venom.py` +# (can have an `__init__.py` which exposes the API). + +from typing import Optional + +from vyper.codegen.ir_node import IRnode +from vyper.compiler.settings import OptimizationLevel +from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness +from vyper.venom.bb_optimizer import ( + ir_pass_optimize_empty_blocks, + ir_pass_optimize_unused_variables, + ir_pass_remove_unreachable_blocks, +) +from vyper.venom.function import IRFunction +from vyper.venom.ir_node_to_venom import convert_ir_basicblock +from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation +from vyper.venom.passes.dft import DFTPass +from vyper.venom.venom_to_assembly import VenomCompiler + + +def generate_assembly_experimental( + ctx: IRFunction, optimize: Optional[OptimizationLevel] = None +) -> list[str]: + compiler = VenomCompiler(ctx) + return compiler.generate_evm(optimize is OptimizationLevel.NONE) + + +def generate_ir(ir: IRnode, optimize: Optional[OptimizationLevel] = None) -> IRFunction: + # Convert "old" IR to "new" IR + ctx = convert_ir_basicblock(ir) + + # Run passes on "new" IR + # TODO: Add support for optimization levels + while True: + changes = 0 + + changes += ir_pass_optimize_empty_blocks(ctx) + changes += ir_pass_remove_unreachable_blocks(ctx) + + calculate_liveness(ctx) + + changes += ir_pass_optimize_unused_variables(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + changes += ir_pass_constant_propagation(ctx) + changes += DFTPass.run_pass(ctx) + + calculate_cfg(ctx) + calculate_liveness(ctx) + + if changes == 0: + break + + return ctx diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py new file mode 100644 index 0000000000..5980e21028 --- /dev/null +++ b/vyper/venom/analysis.py @@ -0,0 +1,191 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.basicblock import ( + BB_TERMINATORS, + CFG_ALTERING_OPS, + IRBasicBlock, + IRInstruction, + IRVariable, +) +from vyper.venom.function import IRFunction + + +def calculate_cfg(ctx: IRFunction) -> None: + """ + Calculate (cfg) inputs for each basic block. + """ + for bb in ctx.basic_blocks: + bb.cfg_in = OrderedSet() + bb.cfg_out = OrderedSet() + bb.out_vars = OrderedSet() + + # TODO: This is a hack to support the old IR format where `deploy` is + # an instruction. in the future we should have two entry points, one + # for the initcode and one for the runtime code. + deploy_bb = None + after_deploy_bb = None + for i, bb in enumerate(ctx.basic_blocks): + if bb.instructions[0].opcode == "deploy": + deploy_bb = bb + after_deploy_bb = ctx.basic_blocks[i + 1] + break + + if deploy_bb is not None: + assert after_deploy_bb is not None, "No block after deploy block" + entry_block = after_deploy_bb + has_constructor = ctx.basic_blocks[0].instructions[0].opcode != "deploy" + if has_constructor: + deploy_bb.add_cfg_in(ctx.basic_blocks[0]) + entry_block.add_cfg_in(deploy_bb) + else: + entry_block = ctx.basic_blocks[0] + + # TODO: Special case for the jump table of selector buckets and fallback. + # this will be cleaner when we introduce an "indirect jump" instruction + # for the selector table (which includes all possible targets). it will + # also clean up the code for normalization because it will not have to + # handle this case specially. + for bb in ctx.basic_blocks: + if "selector_bucket_" in bb.label.value or bb.label.value == "fallback": + bb.add_cfg_in(entry_block) + + for bb in ctx.basic_blocks: + assert len(bb.instructions) > 0, "Basic block should not be empty" + last_inst = bb.instructions[-1] + assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" + + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_OPS: + ops = inst.get_label_operands() + for op in ops: + ctx.get_basic_block(op.value).add_cfg_in(bb) + + # Fill in the "out" set for each basic block + for bb in ctx.basic_blocks: + for in_bb in bb.cfg_in: + in_bb.add_cfg_out(bb) + + +def _reset_liveness(ctx: IRFunction) -> None: + for bb in ctx.basic_blocks: + for inst in bb.instructions: + inst.liveness = OrderedSet() + + +def _calculate_liveness_bb(bb: IRBasicBlock) -> None: + """ + Compute liveness of each instruction in the basic block. + """ + liveness = bb.out_vars.copy() + for instruction in reversed(bb.instructions): + ops = instruction.get_inputs() + + for op in ops: + if op in liveness: + instruction.dup_requirements.add(op) + + liveness = liveness.union(OrderedSet.fromkeys(ops)) + out = instruction.get_outputs()[0] if len(instruction.get_outputs()) > 0 else None + if out in liveness: + liveness.remove(out) + instruction.liveness = liveness + + +def _calculate_liveness_r(bb: IRBasicBlock, visited: dict) -> None: + assert isinstance(visited, dict) + for out_bb in bb.cfg_out: + if visited.get(bb) == out_bb: + continue + visited[bb] = out_bb + + # recurse + _calculate_liveness_r(out_bb, visited) + + target_vars = input_vars_from(bb, out_bb) + + # the output stack layout for bb. it produces a stack layout + # which works for all possible cfg_outs from the bb. + bb.out_vars = bb.out_vars.union(target_vars) + + _calculate_liveness_bb(bb) + + +def calculate_liveness(ctx: IRFunction) -> None: + _reset_liveness(ctx) + _calculate_liveness_r(ctx.basic_blocks[0], dict()) + + +# calculate the input variables into self from source +def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: + liveness = target.instructions[0].liveness.copy() + assert isinstance(liveness, OrderedSet) + + for inst in target.instructions: + if inst.opcode == "phi": + # we arbitrarily choose one of the arguments to be in the + # live variables set (dependent on how we traversed into this + # basic block). the argument will be replaced by the destination + # operand during instruction selection. + # for instance, `%56 = phi %label1 %12 %label2 %14` + # will arbitrarily choose either %12 or %14 to be in the liveness + # set, and then during instruction selection, after this instruction, + # %12 will be replaced by %56 in the liveness set + source1, source2 = inst.operands[0], inst.operands[2] + phi1, phi2 = inst.operands[1], inst.operands[3] + if source.label == source1: + liveness.add(phi1) + if phi2 in liveness: + liveness.remove(phi2) + elif source.label == source2: + liveness.add(phi2) + if phi1 in liveness: + liveness.remove(phi1) + else: + # bad path into this phi node + raise CompilerPanic(f"unreachable: {inst}") + + return liveness + + +# DataFlow Graph +# this could be refactored into its own file, but it's only used here +# for now +class DFG: + _dfg_inputs: dict[IRVariable, list[IRInstruction]] + _dfg_outputs: dict[IRVariable, IRInstruction] + + def __init__(self): + self._dfg_inputs = dict() + self._dfg_outputs = dict() + + # return uses of a given variable + def get_uses(self, op: IRVariable) -> list[IRInstruction]: + return self._dfg_inputs.get(op, []) + + # the instruction which produces this variable. + def get_producing_instruction(self, op: IRVariable) -> IRInstruction: + return self._dfg_outputs[op] + + @classmethod + def build_dfg(cls, ctx: IRFunction) -> "DFG": + dfg = cls() + + # Build DFG + + # %15 = add %13 %14 + # %16 = iszero %15 + # dfg_outputs of %15 is (%15 = add %13 %14) + # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] + for bb in ctx.basic_blocks: + for inst in bb.instructions: + operands = inst.get_inputs() + res = inst.get_outputs() + + for op in operands: + inputs = dfg._dfg_inputs.setdefault(op, []) + inputs.append(inst) + + for op in res: # type: ignore + dfg._dfg_outputs[op] = inst + + return dfg diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py new file mode 100644 index 0000000000..b95d7416ca --- /dev/null +++ b/vyper/venom/basicblock.py @@ -0,0 +1,345 @@ +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Iterator, Optional + +from vyper.utils import OrderedSet + +# instructions which can terminate a basic block +BB_TERMINATORS = frozenset(["jmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) + +VOLATILE_INSTRUCTIONS = frozenset( + [ + "param", + "alloca", + "call", + "staticcall", + "invoke", + "sload", + "sstore", + "iload", + "istore", + "assert", + "mstore", + "mload", + "calldatacopy", + "codecopy", + "dloadbytes", + "dload", + "return", + "ret", + "jmp", + "jnz", + ] +) + +CFG_ALTERING_OPS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) + + +if TYPE_CHECKING: + from vyper.venom.function import IRFunction + + +class IRDebugInfo: + """ + IRDebugInfo represents debug information in IR, used to annotate IR instructions + with source code information when printing IR. + """ + + line_no: int + src: str + + def __init__(self, line_no: int, src: str) -> None: + self.line_no = line_no + self.src = src + + def __repr__(self) -> str: + src = self.src if self.src else "" + return f"\t# line {self.line_no}: {src}".expandtabs(20) + + +class IROperand: + """ + IROperand represents an operand in IR. An operand is anything that can + be an argument to an IRInstruction + """ + + value: Any + + +class IRValue(IROperand): + """ + IRValue represents a value in IR. A value is anything that can be + operated by non-control flow instructions. That is, IRValues can be + IRVariables or IRLiterals. + """ + + pass + + +class IRLiteral(IRValue): + """ + IRLiteral represents a literal in IR + """ + + value: int + + def __init__(self, value: int) -> None: + assert isinstance(value, str) or isinstance(value, int), "value must be an int" + self.value = value + + def __repr__(self) -> str: + return str(self.value) + + +class MemType(Enum): + OPERAND_STACK = auto() + MEMORY = auto() + + +class IRVariable(IRValue): + """ + IRVariable represents a variable in IR. A variable is a string that starts with a %. + """ + + value: str + offset: int = 0 + + # some variables can be in memory for conversion from legacy IR to venom + mem_type: MemType = MemType.OPERAND_STACK + mem_addr: Optional[int] = None + + def __init__( + self, value: str, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: int = None + ) -> None: + assert isinstance(value, str) + self.value = value + self.offset = 0 + self.mem_type = mem_type + self.mem_addr = mem_addr + + def __repr__(self) -> str: + return self.value + + +class IRLabel(IROperand): + """ + IRLabel represents a label in IR. A label is a string that starts with a %. + """ + + # is_symbol is used to indicate if the label came from upstream + # (like a function name, try to preserve it in optimization passes) + is_symbol: bool = False + value: str + + def __init__(self, value: str, is_symbol: bool = False) -> None: + assert isinstance(value, str), "value must be an str" + self.value = value + self.is_symbol = is_symbol + + def __repr__(self) -> str: + return self.value + + +class IRInstruction: + """ + IRInstruction represents an instruction in IR. Each instruction has an opcode, + operands, and return value. For example, the following IR instruction: + %1 = add %0, 1 + has opcode "add", operands ["%0", "1"], and return value "%1". + + Convention: the rightmost value is the top of the stack. + """ + + opcode: str + volatile: bool + operands: list[IROperand] + output: Optional[IROperand] + # set of live variables at this instruction + liveness: OrderedSet[IRVariable] + dup_requirements: OrderedSet[IRVariable] + parent: Optional["IRBasicBlock"] + fence_id: int + annotation: Optional[str] + + def __init__( + self, + opcode: str, + operands: list[IROperand] | Iterator[IROperand], + output: Optional[IROperand] = None, + ): + assert isinstance(opcode, str), "opcode must be an str" + assert isinstance(operands, list | Iterator), "operands must be a list" + self.opcode = opcode + self.volatile = opcode in VOLATILE_INSTRUCTIONS + self.operands = [op for op in operands] # in case we get an iterator + self.output = output + self.liveness = OrderedSet() + self.dup_requirements = OrderedSet() + self.parent = None + self.fence_id = -1 + self.annotation = None + + def get_label_operands(self) -> list[IRLabel]: + """ + Get all labels in instruction. + """ + return [op for op in self.operands if isinstance(op, IRLabel)] + + def get_non_label_operands(self) -> list[IROperand]: + """ + Get input operands for instruction which are not labels + """ + return [op for op in self.operands if not isinstance(op, IRLabel)] + + def get_inputs(self) -> list[IRVariable]: + """ + Get all input operands for instruction. + """ + return [op for op in self.operands if isinstance(op, IRVariable)] + + def get_outputs(self) -> list[IROperand]: + """ + Get the output item for an instruction. + (Currently all instructions output at most one item, but write + it as a list to be generic for the future) + """ + return [self.output] if self.output else [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + replacements are represented using a dict: "key" is replaced by "value". + """ + for i, operand in enumerate(self.operands): + if operand in replacements: + self.operands[i] = replacements[operand] + + def __repr__(self) -> str: + s = "" + if self.output: + s += f"{self.output} = " + opcode = f"{self.opcode} " if self.opcode != "store" else "" + s += opcode + operands = ", ".join( + [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in self.operands] + ) + s += operands + + if self.annotation: + s += f" <{self.annotation}>" + + # if self.liveness: + # return f"{s: <30} # {self.liveness}" + + return s + + +class IRBasicBlock: + """ + IRBasicBlock represents a basic block in IR. Each basic block has a label and + a list of instructions, while belonging to a function. + + The following IR code: + %1 = add %0, 1 + %2 = mul %1, 2 + is represented as: + bb = IRBasicBlock("bb", function) + bb.append_instruction(IRInstruction("add", ["%0", "1"], "%1")) + bb.append_instruction(IRInstruction("mul", ["%1", "2"], "%2")) + + The label of a basic block is used to refer to it from other basic blocks + in order to branch to it. + + The parent of a basic block is the function it belongs to. + + The instructions of a basic block are executed sequentially, and the last + instruction of a basic block is always a terminator instruction, which is + used to branch to other basic blocks. + """ + + label: IRLabel + parent: "IRFunction" + instructions: list[IRInstruction] + # basic blocks which can jump to this basic block + cfg_in: OrderedSet["IRBasicBlock"] + # basic blocks which this basic block can jump to + cfg_out: OrderedSet["IRBasicBlock"] + # stack items which this basic block produces + out_vars: OrderedSet[IRVariable] + + def __init__(self, label: IRLabel, parent: "IRFunction") -> None: + assert isinstance(label, IRLabel), "label must be an IRLabel" + self.label = label + self.parent = parent + self.instructions = [] + self.cfg_in = OrderedSet() + self.cfg_out = OrderedSet() + self.out_vars = OrderedSet() + + def add_cfg_in(self, bb: "IRBasicBlock") -> None: + self.cfg_in.add(bb) + + def remove_cfg_in(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_in + self.cfg_in.remove(bb) + + def add_cfg_out(self, bb: "IRBasicBlock") -> None: + # malformed: jnz condition label1 label1 + # (we could handle but it makes a lot of code easier + # if we have this assumption) + self.cfg_out.add(bb) + + def remove_cfg_out(self, bb: "IRBasicBlock") -> None: + assert bb in self.cfg_out + self.cfg_out.remove(bb) + + @property + def is_reachable(self) -> bool: + return len(self.cfg_in) > 0 + + def append_instruction(self, instruction: IRInstruction) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.append(instruction) + + def insert_instruction(self, instruction: IRInstruction, index: int) -> None: + assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" + instruction.parent = self + self.instructions.insert(index, instruction) + + def clear_instructions(self) -> None: + self.instructions = [] + + def replace_operands(self, replacements: dict) -> None: + """ + Update operands with replacements. + """ + for instruction in self.instructions: + instruction.replace_operands(replacements) + + @property + def is_terminated(self) -> bool: + """ + Check if the basic block is terminal, i.e. the last instruction is a terminator. + """ + # it's ok to return False here, since we use this to check + # if we can/need to append instructions to the basic block. + if len(self.instructions) == 0: + return False + return self.instructions[-1].opcode in BB_TERMINATORS + + def copy(self): + bb = IRBasicBlock(self.label, self.parent) + bb.instructions = self.instructions.copy() + bb.cfg_in = self.cfg_in.copy() + bb.cfg_out = self.cfg_out.copy() + bb.out_vars = self.out_vars.copy() + return bb + + def __repr__(self) -> str: + s = ( + f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" + f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars} \n" + ) + for instruction in self.instructions: + s += f" {instruction}\n" + return s diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py new file mode 100644 index 0000000000..620ee66d15 --- /dev/null +++ b/vyper/venom/bb_optimizer.py @@ -0,0 +1,73 @@ +from vyper.utils import ir_pass +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRInstruction, IRLabel +from vyper.venom.function import IRFunction + + +def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]: + """ + Remove unused variables. + """ + removeList = set() + for bb in ctx.basic_blocks: + for i, inst in enumerate(bb.instructions[:-1]): + if inst.volatile: + continue + if inst.output and inst.output not in bb.instructions[i + 1].liveness: + removeList.add(inst) + + bb.instructions = [inst for inst in bb.instructions if inst not in removeList] + + return removeList + + +def _optimize_empty_basicblocks(ctx: IRFunction) -> int: + """ + Remove empty basic blocks. + """ + count = 0 + i = 0 + while i < len(ctx.basic_blocks): + bb = ctx.basic_blocks[i] + i += 1 + if len(bb.instructions) > 0: + continue + + replaced_label = bb.label + replacement_label = ctx.basic_blocks[i].label if i < len(ctx.basic_blocks) else None + if replacement_label is None: + continue + + # Try to preserve symbol labels + if replaced_label.is_symbol: + replaced_label, replacement_label = replacement_label, replaced_label + ctx.basic_blocks[i].label = replacement_label + + for bb2 in ctx.basic_blocks: + for inst in bb2.instructions: + for op in inst.operands: + if isinstance(op, IRLabel) and op.value == replaced_label.value: + op.value = replacement_label.value + + ctx.basic_blocks.remove(bb) + i -= 1 + count += 1 + + return count + + +@ir_pass +def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: + changes = _optimize_empty_basicblocks(ctx) + calculate_cfg(ctx) + return changes + + +@ir_pass +def ir_pass_remove_unreachable_blocks(ctx: IRFunction) -> int: + return ctx.remove_unreachable_blocks() + + +@ir_pass +def ir_pass_optimize_unused_variables(ctx: IRFunction) -> int: + return len(_optimize_unused_variables(ctx)) diff --git a/vyper/venom/function.py b/vyper/venom/function.py new file mode 100644 index 0000000000..c14ad77345 --- /dev/null +++ b/vyper/venom/function.py @@ -0,0 +1,170 @@ +from typing import Optional + +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IROperand, + IRVariable, + MemType, +) + +GLOBAL_LABEL = IRLabel("global") + + +class IRFunction: + """ + Function that contains basic blocks. + """ + + name: IRLabel # symbol name + args: list + basic_blocks: list[IRBasicBlock] + data_segment: list[IRInstruction] + last_label: int + last_variable: int + + def __init__(self, name: IRLabel = None) -> None: + if name is None: + name = GLOBAL_LABEL + self.name = name + self.args = [] + self.basic_blocks = [] + self.data_segment = [] + self.last_label = 0 + self.last_variable = 0 + + self.append_basic_block(IRBasicBlock(name, self)) + + def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + """ + Append basic block to function. + """ + assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" + self.basic_blocks.append(bb) + + # TODO add sanity check somewhere that basic blocks have unique labels + + return self.basic_blocks[-1] + + def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: + """ + Get basic block by label. + If label is None, return the last basic block. + """ + if label is None: + return self.basic_blocks[-1] + for bb in self.basic_blocks: + if bb.label.value == label: + return bb + raise AssertionError(f"Basic block '{label}' not found") + + def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + """ + Get basic block after label. + """ + for i, bb in enumerate(self.basic_blocks[:-1]): + if bb.label.value == label.value: + return self.basic_blocks[i + 1] + raise AssertionError(f"Basic block after '{label}' not found") + + def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: + """ + Get basic blocks that contain label. + """ + return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] + + def get_next_label(self) -> IRLabel: + self.last_label += 1 + return IRLabel(f"{self.last_label}") + + def get_next_variable( + self, mem_type: MemType = MemType.OPERAND_STACK, mem_addr: Optional[int] = None + ) -> IRVariable: + self.last_variable += 1 + return IRVariable(f"%{self.last_variable}", mem_type, mem_addr) + + def get_last_variable(self) -> str: + return f"%{self.last_variable}" + + def remove_unreachable_blocks(self) -> int: + removed = 0 + new_basic_blocks = [] + for bb in self.basic_blocks: + if not bb.is_reachable and bb.label.value != "global": + removed += 1 + else: + new_basic_blocks.append(bb) + self.basic_blocks = new_basic_blocks + return removed + + def append_instruction( + self, opcode: str, args: list[IROperand], do_ret: bool = True + ) -> Optional[IRVariable]: + """ + Append instruction to last basic block. + """ + ret = self.get_next_variable() if do_ret else None + inst = IRInstruction(opcode, args, ret) # type: ignore + self.get_basic_block().append_instruction(inst) + return ret + + def append_data(self, opcode: str, args: list[IROperand]) -> None: + """ + Append data + """ + self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + + @property + def normalized(self) -> bool: + """ + Check if function is normalized. A function is normalized if in the + CFG, no basic block simultaneously has multiple inputs and outputs. + That is, a basic block can be jumped to *from* multiple blocks, or it + can jump *to* multiple blocks, but it cannot simultaneously do both. + Having a normalized CFG makes calculation of stack layout easier when + emitting assembly. + """ + for bb in self.basic_blocks: + # Ignore if there are no multiple predecessors + if len(bb.cfg_in) <= 1: + continue + + # Check if there is a conditional jump at the end + # of one of the predecessors + # + # TODO: this check could be: + # `if len(in_bb.cfg_out) > 1: return False` + # but the cfg is currently not calculated "correctly" for + # certain special instructions (deploy instruction and + # selector table indirect jumps). + for in_bb in bb.cfg_in: + jump_inst = in_bb.instructions[-1] + if jump_inst.opcode != "jnz": + continue + if jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRLabel): + continue + + # The function is not normalized + return False + + # The function is normalized + return True + + def copy(self): + new = IRFunction(self.name) + new.basic_blocks = self.basic_blocks.copy() + new.data_segment = self.data_segment.copy() + new.last_label = self.last_label + new.last_variable = self.last_variable + return new + + def __repr__(self) -> str: + str = f"IRFunction: {self.name}\n" + for bb in self.basic_blocks: + str += f"{bb}\n" + if len(self.data_segment) > 0: + str += "Data segment:\n" + for inst in self.data_segment: + str += f"{inst}\n" + return str diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py new file mode 100644 index 0000000000..19bd5c8b73 --- /dev/null +++ b/vyper/venom/ir_node_to_venom.py @@ -0,0 +1,943 @@ +from typing import Optional + +from vyper.codegen.context import VariableRecord +from vyper.codegen.ir_node import IRnode +from vyper.evm.opcodes import get_opcodes +from vyper.exceptions import CompilerPanic +from vyper.ir.compile_ir import is_mem_sym, is_symbol +from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction + +_BINARY_IR_INSTRUCTIONS = frozenset( + [ + "eq", + "gt", + "lt", + "slt", + "sgt", + "shr", + "shl", + "or", + "xor", + "and", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "sha3", + "sha3_64", + "signextend", + ] +) + +# Instuctions that are mapped to their inverse +INVERSE_MAPPED_IR_INSTRUCTIONS = {"ne": "eq", "le": "gt", "sle": "sgt", "ge": "lt", "sge": "slt"} + +# Instructions that have a direct EVM opcode equivalent and can +# be passed through to the EVM assembly without special handling +PASS_THROUGH_INSTRUCTIONS = [ + "chainid", + "basefee", + "timestamp", + "caller", + "selfbalance", + "calldatasize", + "callvalue", + "address", + "origin", + "codesize", + "gas", + "gasprice", + "gaslimit", + "returndatasize", + "coinbase", + "number", + "iszero", + "ceil32", + "calldataload", + "extcodesize", + "extcodehash", + "balance", +] + +SymbolTable = dict[str, IROperand] + + +def _get_symbols_common(a: dict, b: dict) -> dict: + ret = {} + # preserves the ordering in `a` + for k in a.keys(): + if k not in b: + continue + if a[k] == b[k]: + continue + ret[k] = a[k], b[k] + return ret + + +def convert_ir_basicblock(ir: IRnode) -> IRFunction: + global_function = IRFunction() + _convert_ir_basicblock(global_function, ir, {}, OrderedSet(), {}) + + for i, bb in enumerate(global_function.basic_blocks): + if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: + bb.append_instruction(IRInstruction("jmp", [global_function.basic_blocks[i + 1].label])) + + revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) + revert_bb = global_function.append_basic_block(revert_bb) + revert_bb.append_instruction(IRInstruction("revert", [IRLiteral(0), IRLiteral(0)])) + + return global_function + + +def _convert_binary_op( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], + swap: bool = False, +) -> IRVariable: + ir_args = ir.args[::-1] if swap else ir.args + arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) + args = [arg_1, arg_0] + + ret = ctx.get_next_variable() + + inst = IRInstruction(ir.value, args, ret) # type: ignore + ctx.get_basic_block().append_instruction(inst) + return ret + + +def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + + label = ctx.get_next_label() + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + + +def _new_block(ctx: IRFunction) -> IRBasicBlock: + bb = IRBasicBlock(ctx.get_next_label(), ctx) + bb = ctx.append_basic_block(bb) + return bb + + +def _handle_self_call( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + func_t = ir.passthrough_metadata.get("func_t", None) + args_ir = ir.passthrough_metadata["args_ir"] + goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] + target_label = goto_ir.args[0].value # goto + return_buf = goto_ir.args[1] # return buffer + ret_args = [IRLabel(target_label)] # type: ignore + + for arg in args_ir: + if arg.is_literal: + sym = symbols.get(f"&{arg.value}", None) + if sym is None: + ret = _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ret_args.append(ret) + else: + ret_args.append(sym) # type: ignore + else: + ret = _convert_ir_basicblock( + ctx, arg._optimized, symbols, variables, allocated_variables + ) + if arg.location and arg.location.load_op == "calldataload": + ret = ctx.append_instruction(arg.location.load_op, [ret]) + ret_args.append(ret) + + if return_buf.is_literal: + ret_args.append(IRLiteral(return_buf.value)) # type: ignore + + do_ret = func_t.return_type is not None + invoke_ret = ctx.append_instruction("invoke", ret_args, do_ret) # type: ignore + allocated_variables["return_buffer"] = invoke_ret # type: ignore + return invoke_ret + + +def _handle_internal_func( + ctx: IRFunction, ir: IRnode, func_t: ContractFunctionT, symbols: SymbolTable +) -> IRnode: + bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore + bb = ctx.append_basic_block(bb) + + old_ir_mempos = 0 + old_ir_mempos += 64 + + for arg in func_t.arguments: + new_var = ctx.get_next_variable() + + alloca_inst = IRInstruction("param", [], new_var) + alloca_inst.annotation = arg.name + bb.append_instruction(alloca_inst) + symbols[f"&{old_ir_mempos}"] = new_var + old_ir_mempos += 32 # arg.typ.memory_bytes_required + + # return buffer + if func_t.return_type is not None: + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_buffer" + symbols["return_buffer"] = new_var + + # return address + new_var = ctx.get_next_variable() + alloca_inst = IRInstruction("param", [], new_var) + bb.append_instruction(alloca_inst) + alloca_inst.annotation = "return_pc" + symbols["return_pc"] = new_var + + return ir.args[0].args[2] + + +def _convert_ir_simple_node( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> Optional[IRVariable]: + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args + ] + return ctx.append_instruction(ir.value, args) # type: ignore + + +_break_target: Optional[IRBasicBlock] = None +_continue_target: Optional[IRBasicBlock] = None + + +def _get_variable_from_address( + variables: OrderedSet[VariableRecord], addr: int +) -> Optional[VariableRecord]: + assert isinstance(addr, int), "non-int address" + for var in variables.keys(): + if var.location.name != "memory": + continue + if addr >= var.pos and addr < var.pos + var.size: # type: ignore + return var + return None + + +def _get_return_for_stack_operand( + ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable +) -> IRInstruction: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) # type: ignore + else: + sym = symbols.get(ret_ir.value, None) + if sym is None: + # FIXME: needs real allocations + new_var = ctx.append_instruction("alloca", [IRLiteral(32), IRLiteral(0)]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) # type: ignore + else: + new_var = ret_ir + return IRInstruction("return", [last_ir, new_var]) # type: ignore + + +def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): + assert isinstance(variables, OrderedSet) + global _break_target, _continue_target + + frame_info = ir.passthrough_metadata.get("frame_info", None) + if frame_info is not None: + local_vars = OrderedSet[VariableRecord](frame_info.frame_vars.values()) + variables |= local_vars + + assert isinstance(variables, OrderedSet) + + if ir.value in _BINARY_IR_INSTRUCTIONS: + return _convert_binary_op( + ctx, ir, symbols, variables, allocated_variables, ir.value in ["sha3_64"] + ) + + elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + org_value = ir.value + ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] + new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) + ir.value = org_value + return ctx.append_instruction("iszero", [new_var]) + + elif ir.value in PASS_THROUGH_INSTRUCTIONS: + return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) + + elif ir.value in ["pass", "stop", "return"]: + pass + elif ir.value == "deploy": + memsize = ir.args[0].value + ir_runtime = ir.args[1] + padding = ir.args[2].value + assert isinstance(memsize, int), "non-int memsize" + assert isinstance(padding, int), "non-int padding" + + runtimeLabel = ctx.get_next_label() + + inst = IRInstruction("deploy", [IRLiteral(memsize), runtimeLabel, IRLiteral(padding)]) + ctx.get_basic_block().append_instruction(inst) + + bb = IRBasicBlock(runtimeLabel, ctx) + ctx.append_basic_block(bb) + + _convert_ir_basicblock(ctx, ir_runtime, symbols, variables, allocated_variables) + elif ir.value == "seq": + func_t = ir.passthrough_metadata.get("func_t", None) + if ir.is_self_call: + return _handle_self_call(ctx, ir, symbols, variables, allocated_variables) + elif func_t is not None: + symbols = {} + allocated_variables = {} + variables = OrderedSet( + {v: True for v in ir.passthrough_metadata["frame_info"].frame_vars.values()} + ) + if func_t.is_internal: + ir = _handle_internal_func(ctx, ir, func_t, symbols) + # fallthrough + + ret = None + for ir_node in ir.args: # NOTE: skip the last one + ret = _convert_ir_basicblock(ctx, ir_node, symbols, variables, allocated_variables) + + return ret + elif ir.value in ["staticcall", "call"]: # external call + idx = 0 + gas = _convert_ir_basicblock(ctx, ir.args[idx], symbols, variables, allocated_variables) + address = _convert_ir_basicblock( + ctx, ir.args[idx + 1], symbols, variables, allocated_variables + ) + + value = None + if ir.value == "call": + value = _convert_ir_basicblock( + ctx, ir.args[idx + 2], symbols, variables, allocated_variables + ) + else: + idx -= 1 + + argsOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 3], symbols, variables, allocated_variables + ) + argsSize = _convert_ir_basicblock( + ctx, ir.args[idx + 4], symbols, variables, allocated_variables + ) + retOffset = _convert_ir_basicblock( + ctx, ir.args[idx + 5], symbols, variables, allocated_variables + ) + retSize = _convert_ir_basicblock( + ctx, ir.args[idx + 6], symbols, variables, allocated_variables + ) + + if isinstance(argsOffset, IRLiteral): + offset = int(argsOffset.value) + addr = offset - 32 + 4 if offset > 0 else 0 + argsOffsetVar = symbols.get(f"&{addr}", None) + if argsOffsetVar is None: + argsOffsetVar = argsOffset + elif isinstance(argsOffsetVar, IRVariable): + argsOffsetVar.mem_type = MemType.MEMORY + argsOffsetVar.mem_addr = addr + argsOffsetVar.offset = 32 - 4 if offset > 0 else 0 + else: # pragma: nocover + raise CompilerPanic("unreachable") + else: + argsOffsetVar = argsOffset + + retOffsetValue = int(retOffset.value) if retOffset else 0 + retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) + symbols[f"&{retOffsetValue}"] = retVar + + if ir.value == "call": + args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] + return ctx.append_instruction(ir.value, args) + else: + args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] + return ctx.append_instruction(ir.value, args) + elif ir.value == "if": + cond = ir.args[0] + current_bb = ctx.get_basic_block() + + # convert the condition + cont_ret = _convert_ir_basicblock(ctx, cond, symbols, variables, allocated_variables) + + else_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(else_block) + + # convert "else" + else_ret_val = None + else_syms = symbols.copy() + if len(ir.args) == 3: + else_ret_val = _convert_ir_basicblock( + ctx, ir.args[2], else_syms, variables, allocated_variables.copy() + ) + if isinstance(else_ret_val, IRLiteral): + assert isinstance(else_ret_val.value, int) # help mypy + else_ret_val = ctx.append_instruction("store", [IRLiteral(else_ret_val.value)]) + after_else_syms = else_syms.copy() + + # convert "then" + then_block = IRBasicBlock(ctx.get_next_label(), ctx) + ctx.append_basic_block(then_block) + + then_ret_val = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + if isinstance(then_ret_val, IRLiteral): + then_ret_val = ctx.append_instruction("store", [IRLiteral(then_ret_val.value)]) + + inst = IRInstruction("jnz", [cont_ret, then_block.label, else_block.label]) + current_bb.append_instruction(inst) + + after_then_syms = symbols.copy() + + # exit bb + exit_label = ctx.get_next_label() + bb = IRBasicBlock(exit_label, ctx) + bb = ctx.append_basic_block(bb) + + if_ret = None + if then_ret_val is not None and else_ret_val is not None: + if_ret = ctx.get_next_variable() + bb.append_instruction( + IRInstruction( + "phi", [then_block.label, then_ret_val, else_block.label, else_ret_val], if_ret + ) + ) + + common_symbols = _get_symbols_common(after_then_syms, after_else_syms) + for sym, val in common_symbols.items(): + ret = ctx.get_next_variable() + old_var = symbols.get(sym, None) + symbols[sym] = ret + if old_var is not None: + for idx, var_rec in allocated_variables.items(): # type: ignore + if var_rec.value == old_var.value: + allocated_variables[idx] = ret # type: ignore + bb.append_instruction( + IRInstruction("phi", [then_block.label, val[0], else_block.label, val[1]], ret) + ) + + if not else_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + else_block.append_instruction(exit_inst) + + if not then_block.is_terminated: + exit_inst = IRInstruction("jmp", [bb.label]) + then_block.append_instruction(exit_inst) + + return if_ret + + elif ir.value == "with": + ret = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) # initialization + + # Handle with nesting with same symbol + with_symbols = symbols.copy() + + sym = ir.args[0] + if isinstance(ret, IRLiteral): + new_var = ctx.append_instruction("store", [ret]) # type: ignore + with_symbols[sym.value] = new_var + else: + with_symbols[sym.value] = ret # type: ignore + + return _convert_ir_basicblock( + ctx, ir.args[2], with_symbols, variables, allocated_variables + ) # body + elif ir.value == "goto": + _append_jmp(ctx, IRLabel(ir.args[0].value)) + elif ir.value == "jump": + arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + inst = IRInstruction("jmp", [arg_1]) + ctx.get_basic_block().append_instruction(inst) + _new_block(ctx) + elif ir.value == "set": + sym = ir.args[0] + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + new_var = ctx.append_instruction("store", [arg_1]) # type: ignore + symbols[sym.value] = new_var + + elif ir.value == "calldatacopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_v = arg_0 + var = ( + _get_variable_from_address(variables, int(arg_0.value)) + if isinstance(arg_0, IRLiteral) + else None + ) + if var is not None: + if allocated_variables.get(var.name, None) is None: + new_v = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] # type: ignore + ) + allocated_variables[var.name] = new_v # type: ignore + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + symbols[f"&{var.pos}"] = new_v # type: ignore + else: + ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + + return new_v + elif ir.value == "codecopy": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + ctx.append_instruction("codecopy", [size, arg_1, arg_0], False) # type: ignore + elif ir.value == "symbol": + return IRLabel(ir.args[0].value, True) + elif ir.value == "data": + label = IRLabel(ir.args[0].value) + ctx.append_data("dbname", [label]) + for c in ir.args[1:]: + if isinstance(c, int): + assert 0 <= c <= 255, "data with invalid size" + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, bytes): + ctx.append_data("db", [c]) # type: ignore + elif isinstance(c, IRnode): + data = _convert_ir_basicblock(ctx, c, symbols, variables, allocated_variables) + ctx.append_data("db", [data]) # type: ignore + elif ir.value == "assert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + current_bb = ctx.get_basic_block() + inst = IRInstruction("assert", [arg_0]) # type: ignore + current_bb.append_instruction(inst) + elif ir.value == "label": + label = IRLabel(ir.args[0].value, True) + if not ctx.get_basic_block().is_terminated: + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + bb = IRBasicBlock(label, ctx) + ctx.append_basic_block(bb) + _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + elif ir.value == "exit_to": + func_t = ir.passthrough_metadata.get("func_t", None) + assert func_t is not None, "exit_to without func_t" + + if func_t.is_external: + # Hardcoded contructor special case + if func_t.name == "__init__": + label = IRLabel(ir.args[0].value, True) + inst = IRInstruction("jmp", [label]) + ctx.get_basic_block().append_instruction(inst) + return None + if func_t.return_type is None: + inst = IRInstruction("stop", []) + ctx.get_basic_block().append_instruction(inst) + return None + else: + last_ir = None + ret_var = ir.args[1] + deleted = None + if ret_var.is_literal and symbols.get(f"&{ret_var.value}", None) is not None: + deleted = symbols[f"&{ret_var.value}"] + del symbols[f"&{ret_var.value}"] + for arg in ir.args[2:]: + last_ir = _convert_ir_basicblock( + ctx, arg, symbols, variables, allocated_variables + ) + if deleted is not None: + symbols[f"&{ret_var.value}"] = deleted + + ret_ir = _convert_ir_basicblock( + ctx, ret_var, symbols, variables, allocated_variables + ) + + var = ( + _get_variable_from_address(variables, int(ret_ir.value)) + if isinstance(ret_ir, IRLiteral) + else None + ) + if var is not None: + allocated_var = allocated_variables.get(var.name, None) + assert allocated_var is not None, "unallocated variable" + new_var = symbols.get(f"&{ret_ir.value}", allocated_var) # type: ignore + + if var.size and int(var.size) > 32: + offset = int(ret_ir.value) - var.pos # type: ignore + if offset > 0: + ptr_var = ctx.append_instruction( + "add", [IRLiteral(var.pos), IRLiteral(offset)] + ) + else: + ptr_var = allocated_var + inst = IRInstruction("return", [last_ir, ptr_var]) + else: + inst = _get_return_for_stack_operand(ctx, symbols, new_var, last_ir) + else: + if isinstance(ret_ir, IRLiteral): + sym = symbols.get(f"&{ret_ir.value}", None) + if sym is None: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if func_t.return_type.memory_bytes_required > 32: + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) + ctx.append_instruction("mstore", [sym, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + else: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + if last_ir and int(last_ir.value) > 32: + inst = IRInstruction("return", [last_ir, ret_ir]) + else: + ret_buf = IRLiteral(128) # TODO: need allocator + new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_buf]) + ctx.append_instruction("mstore", [ret_ir, new_var], False) + inst = IRInstruction("return", [last_ir, new_var]) + + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + + if func_t.is_internal: + assert ir.args[1].value == "return_pc", "return_pc not found" + if func_t.return_type is None: + inst = IRInstruction("ret", [symbols["return_pc"]]) + else: + if func_t.return_type.memory_bytes_required > 32: + inst = IRInstruction("ret", [symbols["return_buffer"], symbols["return_pc"]]) + else: + ret_by_value = ctx.append_instruction("mload", [symbols["return_buffer"]]) + inst = IRInstruction("ret", [ret_by_value, symbols["return_pc"]]) + + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "revert": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction("revert", [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + + elif ir.value == "dload": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src = ctx.append_instruction("add", [arg_0, IRLabel("code_end")]) + + ctx.append_instruction( + "dloadbytes", [IRLiteral(32), src, IRLiteral(MemoryPositions.FREE_VAR_SPACE)], False + ) + return ctx.append_instruction("mload", [IRLiteral(MemoryPositions.FREE_VAR_SPACE)]) + elif ir.value == "dloadbytes": + dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + src_offset = _convert_ir_basicblock( + ctx, ir.args[1], symbols, variables, allocated_variables + ) + len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + src = ctx.append_instruction("add", [src_offset, IRLabel("code_end")]) + + inst = IRInstruction("dloadbytes", [len_, src, dst]) + ctx.get_basic_block().append_instruction(inst) + return None + elif ir.value == "mload": + sym_ir = ir.args[0] + var = ( + _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None + ) + if var is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mload", [ptr_var]) + else: + if sym_ir.is_literal: + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + new_var = ctx.append_instruction("store", [sym_ir]) + symbols[f"&{sym_ir.value}"] = new_var + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + return sym + + sym = symbols.get(f"&{sym_ir.value}", None) + assert sym is not None, "unallocated variable" + return sym + else: + if sym_ir.is_literal: + new_var = symbols.get(f"&{sym_ir.value}", None) + if new_var is not None: + return ctx.append_instruction("mload", [new_var]) + else: + return ctx.append_instruction("mload", [IRLiteral(sym_ir.value)]) + else: + new_var = _convert_ir_basicblock( + ctx, sym_ir, symbols, variables, allocated_variables + ) + # + # Old IR gets it's return value as a reference in the stack + # New IR gets it's return value in stack in case of 32 bytes or less + # So here we detect ahead of time if this mload leads a self call and + # and we skip the mload + # + if sym_ir.is_self_call: + return new_var + return ctx.append_instruction("mload", [new_var]) + + elif ir.value == "mstore": + sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + + var = None + if isinstance(sym_ir, IRLiteral): + var = _get_variable_from_address(variables, int(sym_ir.value)) + + if var is not None and var.size is not None: + if var.size and var.size > 32: + if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = ctx.append_instruction( + "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + ) + + offset = int(sym_ir.value) - var.pos + if offset > 0: + ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + else: + ptr_var = allocated_variables[var.name] + + return ctx.append_instruction("mstore", [arg_1, ptr_var], False) + else: + if isinstance(sym_ir, IRLiteral): + new_var = ctx.append_instruction("store", [arg_1]) + symbols[f"&{sym_ir.value}"] = new_var + # if allocated_variables.get(var.name, None) is None: + allocated_variables[var.name] = new_var + return new_var + else: + if not isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + return None + + sym = symbols.get(f"&{sym_ir.value}", None) + if sym is None: + inst = IRInstruction("mstore", [arg_1, sym_ir]) + ctx.get_basic_block().append_instruction(inst) + if arg_1 and not isinstance(sym_ir, IRLiteral): + symbols[f"&{sym_ir.value}"] = arg_1 + return None + + if isinstance(sym_ir, IRLiteral): + inst = IRInstruction("mstore", [arg_1, sym]) + ctx.get_basic_block().append_instruction(inst) + return None + else: + symbols[sym_ir.value] = arg_1 + return arg_1 + + elif ir.value in ["sload", "iload"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + return ctx.append_instruction(ir.value, [arg_0]) + elif ir.value in ["sstore", "istore"]: + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + inst = IRInstruction(ir.value, [arg_1, arg_0]) + ctx.get_basic_block().append_instruction(inst) + elif ir.value == "unique_symbol": + sym = ir.args[0] + new_var = ctx.get_next_variable() + symbols[f"&{sym.value}"] = new_var + return new_var + elif ir.value == "repeat": + # + # repeat(sym, start, end, bound, body) + # 1) entry block ] + # 2) init counter block ] -> same block + # 3) condition block (exit block, body block) + # 4) body block + # 5) increment block + # 6) exit block + # TODO: Add the extra bounds check after clarify + def emit_body_block(): + global _break_target, _continue_target + old_targets = _break_target, _continue_target + _break_target, _continue_target = exit_block, increment_block + _convert_ir_basicblock(ctx, body, symbols, variables, allocated_variables) + _break_target, _continue_target = old_targets + + sym = ir.args[0] + start = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + end = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + # "bound" is not used + _ = _convert_ir_basicblock(ctx, ir.args[3], symbols, variables, allocated_variables) + body = ir.args[4] + + entry_block = ctx.get_basic_block() + cond_block = IRBasicBlock(ctx.get_next_label(), ctx) + body_block = IRBasicBlock(ctx.get_next_label(), ctx) + jump_up_block = IRBasicBlock(ctx.get_next_label(), ctx) + increment_block = IRBasicBlock(ctx.get_next_label(), ctx) + exit_block = IRBasicBlock(ctx.get_next_label(), ctx) + + counter_var = ctx.get_next_variable() + counter_inc_var = ctx.get_next_variable() + ret = ctx.get_next_variable() + + inst = IRInstruction("store", [start], counter_var) + ctx.get_basic_block().append_instruction(inst) + symbols[sym.value] = counter_var + inst = IRInstruction("jmp", [cond_block.label]) + ctx.get_basic_block().append_instruction(inst) + + symbols[sym.value] = ret + cond_block.append_instruction( + IRInstruction( + "phi", [entry_block.label, counter_var, increment_block.label, counter_inc_var], ret + ) + ) + + xor_ret = ctx.get_next_variable() + cont_ret = ctx.get_next_variable() + inst = IRInstruction("xor", [ret, end], xor_ret) + cond_block.append_instruction(inst) + cond_block.append_instruction(IRInstruction("iszero", [xor_ret], cont_ret)) + ctx.append_basic_block(cond_block) + + # Do a dry run to get the symbols needing phi nodes + start_syms = symbols.copy() + ctx.append_basic_block(body_block) + emit_body_block() + end_syms = symbols.copy() + diff_syms = _get_symbols_common(start_syms, end_syms) + + replacements = {} + for sym, val in diff_syms.items(): + new_var = ctx.get_next_variable() + symbols[sym] = new_var + replacements[val[0]] = new_var + replacements[val[1]] = new_var + cond_block.insert_instruction( + IRInstruction( + "phi", [entry_block.label, val[0], increment_block.label, val[1]], new_var + ), + 1, + ) + + body_block.replace_operands(replacements) + + body_end = ctx.get_basic_block() + if not body_end.is_terminated: + body_end.append_instruction(IRInstruction("jmp", [jump_up_block.label])) + + jump_cond = IRInstruction("jmp", [increment_block.label]) + jump_up_block.append_instruction(jump_cond) + ctx.append_basic_block(jump_up_block) + + increment_block.append_instruction( + IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var) + ) + increment_block.append_instruction(IRInstruction("jmp", [cond_block.label])) + ctx.append_basic_block(increment_block) + + ctx.append_basic_block(exit_block) + + inst = IRInstruction("jnz", [cont_ret, exit_block.label, body_block.label]) + cond_block.append_instruction(inst) + elif ir.value == "break": + assert _break_target is not None, "Break with no break target" + inst = IRInstruction("jmp", [_break_target.label]) + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + elif ir.value == "continue": + assert _continue_target is not None, "Continue with no contrinue target" + inst = IRInstruction("jmp", [_continue_target.label]) + ctx.get_basic_block().append_instruction(inst) + ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + elif ir.value == "gas": + return ctx.append_instruction("gas", []) + elif ir.value == "returndatasize": + return ctx.append_instruction("returndatasize", []) + elif ir.value == "returndatacopy": + assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) + + new_var = ctx.append_instruction("returndatacopy", [arg_1, size]) + + symbols[f"&{arg_0.value}"] = new_var + return new_var + elif ir.value == "selfdestruct": + arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) + ctx.append_instruction("selfdestruct", [arg_0], False) + elif isinstance(ir.value, str) and ir.value.startswith("log"): + args = [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + for arg in ir.args + ] + inst = IRInstruction(ir.value, reversed(args)) + ctx.get_basic_block().append_instruction(inst) + elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): + _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) + elif isinstance(ir.value, str) and ir.value in symbols: + return symbols[ir.value] + elif ir.is_literal: + return IRLiteral(ir.value) + else: + raise Exception(f"Unknown IR node: {ir}") + + return None + + +def _convert_ir_opcode( + ctx: IRFunction, + ir: IRnode, + symbols: SymbolTable, + variables: OrderedSet, + allocated_variables: dict[str, IRVariable], +) -> None: + opcode = ir.value.upper() # type: ignore + inst_args = [] + for arg in ir.args: + if isinstance(arg, IRnode): + inst_args.append( + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + ) + instruction = IRInstruction(opcode, inst_args) # type: ignore + ctx.get_basic_block().append_instruction(instruction) + + +def _data_ofst_of(sym, ofst, height_): + # e.g. _OFST _sym_foo 32 + assert is_symbol(sym) or is_mem_sym(sym) + if isinstance(ofst.value, int): + # resolve at compile time using magic _OFST op + return ["_OFST", sym, ofst.value] + else: + # if we can't resolve at compile time, resolve at runtime + # ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_) + return ofst + [sym, "ADD"] diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py new file mode 100644 index 0000000000..11da80ac66 --- /dev/null +++ b/vyper/venom/passes/base_pass.py @@ -0,0 +1,21 @@ +class IRPass: + """ + Decorator for IR passes. This decorator will run the pass repeatedly + until no more changes are made. + """ + + @classmethod + def run_pass(cls, *args, **kwargs): + t = cls() + count = 0 + + while True: + changes_count = t._run_pass(*args, **kwargs) or 0 + count += changes_count + if changes_count == 0: + break + + return count + + def _run_pass(self, *args, **kwargs): + raise NotImplementedError(f"Not implemented! {self.__class__}.run_pass()") diff --git a/vyper/venom/passes/constant_propagation.py b/vyper/venom/passes/constant_propagation.py new file mode 100644 index 0000000000..94b556124e --- /dev/null +++ b/vyper/venom/passes/constant_propagation.py @@ -0,0 +1,13 @@ +from vyper.utils import ir_pass +from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.function import IRFunction + + +def _process_basic_block(ctx: IRFunction, bb: IRBasicBlock): + pass + + +@ir_pass +def ir_pass_constant_propagation(ctx: IRFunction): + for bb in ctx.basic_blocks: + _process_basic_block(ctx, bb) diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py new file mode 100644 index 0000000000..26994bd27f --- /dev/null +++ b/vyper/venom/passes/dft.py @@ -0,0 +1,54 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis import DFG +from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +# DataFlow Transformation +class DFTPass(IRPass): + def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction): + if inst in self.visited_instructions: + return + self.visited_instructions.add(inst) + + if inst.opcode == "phi": + # phi instructions stay at the beginning of the basic block + # and no input processing is needed + bb.instructions.append(inst) + return + + for op in inst.get_inputs(): + target = self.dfg.get_producing_instruction(op) + if target.parent != inst.parent or target.fence_id != inst.fence_id: + # don't reorder across basic block or fence boundaries + continue + self._process_instruction_r(bb, target) + + bb.instructions.append(inst) + + def _process_basic_block(self, bb: IRBasicBlock) -> None: + self.ctx.append_basic_block(bb) + + instructions = bb.instructions + bb.instructions = [] + + for inst in instructions: + inst.fence_id = self.fence_id + if inst.volatile: + self.fence_id += 1 + + for inst in instructions: + self._process_instruction_r(bb, inst) + + def _run_pass(self, ctx: IRFunction) -> None: + self.ctx = ctx + self.dfg = DFG.build_dfg(ctx) + self.fence_id = 0 + self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() + + basic_blocks = ctx.basic_blocks + ctx.basic_blocks = [] + + for bb in basic_blocks: + self._process_basic_block(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py new file mode 100644 index 0000000000..9ee1012f91 --- /dev/null +++ b/vyper/venom/passes/normalization.py @@ -0,0 +1,90 @@ +from vyper.exceptions import CompilerPanic +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable +from vyper.venom.function import IRFunction +from vyper.venom.passes.base_pass import IRPass + + +class NormalizationPass(IRPass): + """ + This pass splits basic blocks when there are multiple conditional predecessors. + The code generator expect a normalized CFG, that has the property that + each basic block has at most one conditional predecessor. + """ + + changes = 0 + + def _split_basic_block(self, bb: IRBasicBlock) -> None: + # Iterate over the predecessors of the basic block + for in_bb in list(bb.cfg_in): + jump_inst = in_bb.instructions[-1] + assert bb in in_bb.cfg_out + + # Handle static and dynamic branching + if jump_inst.opcode == "jnz": + self._split_for_static_branch(bb, in_bb) + elif jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRVariable): + self._split_for_dynamic_branch(bb, in_bb) + else: + continue + + self.changes += 1 + + def _split_for_static_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + jump_inst = in_bb.instructions[-1] + for i, op in enumerate(jump_inst.operands): + if op == bb.label: + edge = i + break + else: + # none of the edges points to this bb + raise CompilerPanic("bad CFG") + + assert edge in (1, 2) # the arguments which can be labels + + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Redirect the original conditional jump to the intermediary basic block + jump_inst.operands[edge] = split_bb.label + + def _split_for_dynamic_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: + split_bb = self._insert_split_basicblock(bb, in_bb) + + # Update any affected labels in the data segment + # TODO: this DESTROYS the cfg! refactor so the translation of the + # selector table produces indirect jumps properly. + for inst in self.ctx.data_segment: + if inst.opcode == "db" and inst.operands[0] == bb.label: + inst.operands[0] = split_bb.label + + def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRBasicBlock: + # Create an intermediary basic block and append it + source = in_bb.label.value + target = bb.label.value + split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) + split_bb.append_instruction(IRInstruction("jmp", [bb.label])) + self.ctx.append_basic_block(split_bb) + + # Rewire the CFG + # TODO: this is cursed code, it is necessary instead of just running + # calculate_cfg() because split_for_dynamic_branch destroys the CFG! + # ideally, remove this rewiring and just re-run calculate_cfg(). + split_bb.add_cfg_in(in_bb) + split_bb.add_cfg_out(bb) + in_bb.remove_cfg_out(bb) + in_bb.add_cfg_out(split_bb) + bb.remove_cfg_in(in_bb) + bb.add_cfg_in(split_bb) + return split_bb + + def _run_pass(self, ctx: IRFunction) -> int: + self.ctx = ctx + self.changes = 0 + + for bb in ctx.basic_blocks: + if len(bb.cfg_in) > 1: + self._split_basic_block(bb) + + # Sanity check + assert ctx.normalized, "Normalization pass failed" + + return self.changes diff --git a/vyper/venom/stack_model.py b/vyper/venom/stack_model.py new file mode 100644 index 0000000000..66c62b74d2 --- /dev/null +++ b/vyper/venom/stack_model.py @@ -0,0 +1,100 @@ +from vyper.venom.basicblock import IROperand, IRVariable + + +class StackModel: + NOT_IN_STACK = object() + _stack: list[IROperand] + + def __init__(self): + self._stack = [] + + def copy(self): + new = StackModel() + new._stack = self._stack.copy() + return new + + @property + def height(self) -> int: + """ + Returns the height of the stack map. + """ + return len(self._stack) + + def push(self, op: IROperand) -> None: + """ + Pushes an operand onto the stack map. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack.append(op) + + def pop(self, num: int = 1) -> None: + del self._stack[len(self._stack) - num :] + + def get_depth(self, op: IROperand) -> int: + """ + Returns the depth of the first matching operand in the stack map. + If the operand is not in the stack map, returns NOT_IN_STACK. + """ + assert isinstance(op, IROperand), f"{type(op)}: {op}" + + for i, stack_op in enumerate(reversed(self._stack)): + if stack_op.value == op.value: + return -i + + return StackModel.NOT_IN_STACK # type: ignore + + def get_phi_depth(self, phi1: IRVariable, phi2: IRVariable) -> int: + """ + Returns the depth of the first matching phi variable in the stack map. + If the none of the phi operands are in the stack, returns NOT_IN_STACK. + Asserts that exactly one of phi1 and phi2 is found. + """ + assert isinstance(phi1, IRVariable) + assert isinstance(phi2, IRVariable) + + ret = StackModel.NOT_IN_STACK + for i, stack_item in enumerate(reversed(self._stack)): + if stack_item in (phi1, phi2): + assert ( + ret is StackModel.NOT_IN_STACK + ), f"phi argument is not unique! {phi1}, {phi2}, {self._stack}" + ret = -i + + return ret # type: ignore + + def peek(self, depth: int) -> IROperand: + """ + Returns the top of the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot peek non-in-stack depth" + return self._stack[depth - 1] + + def poke(self, depth: int, op: IROperand) -> None: + """ + Pokes an operand at the given depth in the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot poke non-in-stack depth" + assert depth <= 0, "Bad depth" + assert isinstance(op, IROperand), f"{type(op)}: {op}" + self._stack[depth - 1] = op + + def dup(self, depth: int) -> None: + """ + Duplicates the operand at the given depth in the stack map. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot dup non-existent operand" + assert depth <= 0, "Cannot dup positive depth" + self._stack.append(self.peek(depth)) + + def swap(self, depth: int) -> None: + """ + Swaps the operand at the given depth in the stack map with the top of the stack. + """ + assert depth is not StackModel.NOT_IN_STACK, "Cannot swap non-existent operand" + assert depth < 0, "Cannot swap positive depth" + top = self._stack[-1] + self._stack[-1] = self._stack[depth - 1] + self._stack[depth - 1] = top + + def __repr__(self) -> str: + return f"" diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py new file mode 100644 index 0000000000..f6ec45440a --- /dev/null +++ b/vyper/venom/venom_to_assembly.py @@ -0,0 +1,461 @@ +from typing import Any + +from vyper.ir.compile_ir import PUSH, DataHeader, RuntimeHeader, optimize_assembly +from vyper.utils import MemoryPositions, OrderedSet +from vyper.venom.analysis import calculate_cfg, calculate_liveness, input_vars_from +from vyper.venom.basicblock import ( + IRBasicBlock, + IRInstruction, + IRLabel, + IRLiteral, + IROperand, + IRVariable, + MemType, +) +from vyper.venom.function import IRFunction +from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.stack_model import StackModel + +# instructions which map one-to-one from venom to EVM +_ONE_TO_ONE_INSTRUCTIONS = frozenset( + [ + "revert", + "coinbase", + "calldatasize", + "calldatacopy", + "calldataload", + "gas", + "gasprice", + "gaslimit", + "address", + "origin", + "number", + "extcodesize", + "extcodehash", + "returndatasize", + "returndatacopy", + "callvalue", + "selfbalance", + "sload", + "sstore", + "mload", + "mstore", + "timestamp", + "caller", + "selfdestruct", + "signextend", + "stop", + "shr", + "shl", + "and", + "xor", + "or", + "add", + "sub", + "mul", + "div", + "mod", + "exp", + "eq", + "iszero", + "lg", + "lt", + "slt", + "sgt", + "log0", + "log1", + "log2", + "log3", + "log4", + ] +) + + +# TODO: "assembly" gets into the recursion due to how the original +# IR was structured recursively in regards with the deploy instruction. +# There, recursing into the deploy instruction was by design, and +# made it easier to make the assembly generated "recursive" (i.e. +# instructions being lists of instructions). We don't have this restriction +# anymore, so we can probably refactor this to be iterative in coordination +# with the assembler. My suggestion is to let this be for now, and we can +# refactor it later when we are finished phasing out the old IR. +class VenomCompiler: + ctx: IRFunction + label_counter = 0 + visited_instructions: OrderedSet # {IRInstruction} + visited_basicblocks: OrderedSet # {IRBasicBlock} + + def __init__(self, ctx: IRFunction): + self.ctx = ctx + self.label_counter = 0 + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + + def generate_evm(self, no_optimize: bool = False) -> list[str]: + self.visited_instructions = OrderedSet() + self.visited_basicblocks = OrderedSet() + self.label_counter = 0 + + stack = StackModel() + asm: list[str] = [] + + # Before emitting the assembly, we need to make sure that the + # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) + # so it should not be called after calling NormalizationPass.run_pass(). + # Liveness is then computed for the normalized IR, and we can proceed to + # assembly generation. + # This is a side-effect of how dynamic jumps are temporarily being used + # to support the O(1) dispatcher. -> look into calculate_cfg() + calculate_cfg(self.ctx) + NormalizationPass.run_pass(self.ctx) + calculate_liveness(self.ctx) + + assert self.ctx.normalized, "Non-normalized CFG!" + + self._generate_evm_for_basicblock_r(asm, self.ctx.basic_blocks[0], stack) + + # Append postambles + revert_postamble = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] + runtime = None + if isinstance(asm[-1], list) and isinstance(asm[-1][0], RuntimeHeader): + runtime = asm.pop() + + asm.extend(revert_postamble) + if runtime: + runtime.extend(revert_postamble) + asm.append(runtime) + + # Append data segment + data_segments: dict[Any, list[Any]] = dict() + for inst in self.ctx.data_segment: + if inst.opcode == "dbname": + label = inst.operands[0].value + data_segments[label] = [DataHeader(f"_sym_{label}")] + elif inst.opcode == "db": + data_segments[label].append(f"_sym_{inst.operands[0].value}") + + extent_point = asm if not isinstance(asm[-1], list) else asm[-1] + extent_point.extend([data_segments[label] for label in data_segments]) # type: ignore + + if no_optimize is False: + optimize_assembly(asm) + + return asm + + def _stack_reorder( + self, assembly: list, stack: StackModel, _stack_ops: OrderedSet[IRVariable] + ) -> None: + # make a list so we can index it + stack_ops = [x for x in _stack_ops.keys()] + stack_ops_count = len(_stack_ops) + + for i in range(stack_ops_count): + op = stack_ops[i] + final_stack_depth = -(stack_ops_count - i - 1) + depth = stack.get_depth(op) # type: ignore + + if depth == final_stack_depth: + continue + + self.swap(assembly, stack, depth) + self.swap(assembly, stack, final_stack_depth) + + def _emit_input_operands( + self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel + ) -> None: + # PRE: we already have all the items on the stack that have + # been scheduled to be killed. now it's just a matter of emitting + # SWAPs, DUPs and PUSHes until we match the `ops` argument + + # dumb heuristic: if the top of stack is not wanted here, swap + # it with something that is wanted + if ops and stack.height > 0 and stack.peek(0) not in ops: + for op in ops: + if isinstance(op, IRVariable) and op not in inst.dup_requirements: + self.swap_op(assembly, stack, op) + break + + emitted_ops = OrderedSet[IROperand]() + for op in ops: + if isinstance(op, IRLabel): + # invoke emits the actual instruction itself so we don't need to emit it here + # but we need to add it to the stack map + if inst.opcode != "invoke": + assembly.append(f"_sym_{op.value}") + stack.push(op) + continue + + if isinstance(op, IRLiteral): + assembly.extend([*PUSH(op.value)]) + stack.push(op) + continue + + if op in inst.dup_requirements: + self.dup_op(assembly, stack, op) + + if op in emitted_ops: + self.dup_op(assembly, stack, op) + + # REVIEW: this seems like it can be reordered across volatile + # boundaries (which includes memory fences). maybe just + # remove it entirely at this point + if isinstance(op, IRVariable) and op.mem_type == MemType.MEMORY: + assembly.extend([*PUSH(op.mem_addr)]) + assembly.append("MLOAD") + + emitted_ops.add(op) + + def _generate_evm_for_basicblock_r( + self, asm: list, basicblock: IRBasicBlock, stack: StackModel + ) -> None: + if basicblock in self.visited_basicblocks: + return + self.visited_basicblocks.add(basicblock) + + # assembly entry point into the block + asm.append(f"_sym_{basicblock.label}") + asm.append("JUMPDEST") + + self.clean_stack_from_cfg_in(asm, basicblock, stack) + + for inst in basicblock.instructions: + asm = self._generate_evm_for_instruction(asm, inst, stack) + + for bb in basicblock.cfg_out: + self._generate_evm_for_basicblock_r(asm, bb, stack.copy()) + + # pop values from stack at entry to bb + # note this produces the same result(!) no matter which basic block + # we enter from in the CFG. + def clean_stack_from_cfg_in( + self, asm: list, basicblock: IRBasicBlock, stack: StackModel + ) -> None: + if len(basicblock.cfg_in) == 0: + return + + to_pop = OrderedSet[IRVariable]() + for in_bb in basicblock.cfg_in: + # inputs is the input variables we need from in_bb + inputs = input_vars_from(in_bb, basicblock) + + # layout is the output stack layout for in_bb (which works + # for all possible cfg_outs from the in_bb). + layout = in_bb.out_vars + + # pop all the stack items which in_bb produced which we don't need. + to_pop |= layout.difference(inputs) + + for var in to_pop: + depth = stack.get_depth(var) + # don't pop phantom phi inputs + if depth is StackModel.NOT_IN_STACK: + continue + + if depth != 0: + stack.swap(depth) + self.pop(asm, stack) + + def _generate_evm_for_instruction( + self, assembly: list, inst: IRInstruction, stack: StackModel + ) -> list[str]: + opcode = inst.opcode + + # + # generate EVM for op + # + + # Step 1: Apply instruction special stack manipulations + + if opcode in ["jmp", "jnz", "invoke"]: + operands = inst.get_non_label_operands() + elif opcode == "alloca": + operands = inst.operands[1:2] + elif opcode == "iload": + operands = [] + elif opcode == "istore": + operands = inst.operands[0:1] + else: + operands = inst.operands + + if opcode == "phi": + ret = inst.get_outputs()[0] + phi1, phi2 = inst.get_inputs() + depth = stack.get_phi_depth(phi1, phi2) + # collapse the arguments to the phi node in the stack. + # example, for `%56 = %label1 %13 %label2 %14`, we will + # find an instance of %13 *or* %14 in the stack and replace it with %56. + to_be_replaced = stack.peek(depth) + if to_be_replaced in inst.dup_requirements: + # %13/%14 is still live(!), so we make a copy of it + self.dup(assembly, stack, depth) + stack.poke(0, ret) + else: + stack.poke(depth, ret) + return assembly + + # Step 2: Emit instruction's input operands + self._emit_input_operands(assembly, inst, operands, stack) + + # Step 3: Reorder stack + if opcode in ["jnz", "jmp"]: + # prepare stack for jump into another basic block + assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) + b = next(iter(inst.parent.cfg_out)) + target_stack = input_vars_from(inst.parent, b) + # TODO optimize stack reordering at entry and exit from basic blocks + self._stack_reorder(assembly, stack, target_stack) + + # final step to get the inputs to this instruction ordered + # correctly on the stack + self._stack_reorder(assembly, stack, OrderedSet(operands)) + + # some instructions (i.e. invoke) need to do stack manipulations + # with the stack model containing the return value(s), so we fiddle + # with the stack model beforehand. + + # Step 4: Push instruction's return value to stack + stack.pop(len(operands)) + if inst.output is not None: + stack.push(inst.output) + + # Step 5: Emit the EVM instruction(s) + if opcode in _ONE_TO_ONE_INSTRUCTIONS: + assembly.append(opcode.upper()) + elif opcode == "alloca": + pass + elif opcode == "param": + pass + elif opcode == "store": + pass + elif opcode == "dbname": + pass + elif opcode in ["codecopy", "dloadbytes"]: + assembly.append("CODECOPY") + elif opcode == "jnz": + # jump if not zero + if_nonzero_label = inst.operands[1] + if_zero_label = inst.operands[2] + assembly.append(f"_sym_{if_nonzero_label.value}") + assembly.append("JUMPI") + + # make sure the if_zero_label will be optimized out + # assert if_zero_label == next(iter(inst.parent.cfg_out)).label + + assembly.append(f"_sym_{if_zero_label.value}") + assembly.append("JUMP") + + elif opcode == "jmp": + if isinstance(inst.operands[0], IRLabel): + assembly.append(f"_sym_{inst.operands[0].value}") + assembly.append("JUMP") + else: + assembly.append("JUMP") + elif opcode == "gt": + assembly.append("GT") + elif opcode == "lt": + assembly.append("LT") + elif opcode == "invoke": + target = inst.operands[0] + assert isinstance(target, IRLabel), "invoke target must be a label" + assembly.extend( + [ + f"_sym_label_ret_{self.label_counter}", + f"_sym_{target.value}", + "JUMP", + f"_sym_label_ret_{self.label_counter}", + "JUMPDEST", + ] + ) + self.label_counter += 1 + if stack.height > 0 and stack.peek(0) in inst.dup_requirements: + self.pop(assembly, stack) + elif opcode == "call": + assembly.append("CALL") + elif opcode == "staticcall": + assembly.append("STATICCALL") + elif opcode == "ret": + assembly.append("JUMP") + elif opcode == "return": + assembly.append("RETURN") + elif opcode == "phi": + pass + elif opcode == "sha3": + assembly.append("SHA3") + elif opcode == "sha3_64": + assembly.extend( + [ + *PUSH(MemoryPositions.FREE_VAR_SPACE2), + "MSTORE", + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "MSTORE", + *PUSH(64), + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "SHA3", + ] + ) + elif opcode == "ceil32": + assembly.extend([*PUSH(31), "ADD", *PUSH(31), "NOT", "AND"]) + elif opcode == "assert": + assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) + elif opcode == "deploy": + memsize = inst.operands[0].value + padding = inst.operands[2].value + # TODO: fix this by removing deploy opcode altogether me move emition to ir translation + while assembly[-1] != "JUMPDEST": + assembly.pop() + assembly.extend( + ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] + ) + assembly.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len + assembly.extend(["_mem_deploy_start"]) # stack: len mem_ofst + assembly.extend(["RETURN"]) + assembly.append([RuntimeHeader("_sym_runtime_begin", memsize, padding)]) # type: ignore + assembly = assembly[-1] + elif opcode == "iload": + loc = inst.operands[0].value + assembly.extend(["_OFST", "_mem_deploy_end", loc, "MLOAD"]) + elif opcode == "istore": + loc = inst.operands[1].value + assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + else: + raise Exception(f"Unknown opcode: {opcode}") + + # Step 6: Emit instructions output operands (if any) + if inst.output is not None: + assert isinstance(inst.output, IRVariable), "Return value must be a variable" + if inst.output.mem_type == MemType.MEMORY: + assembly.extend([*PUSH(inst.output.mem_addr)]) + + return assembly + + def pop(self, assembly, stack, num=1): + stack.pop(num) + assembly.extend(["POP"] * num) + + def swap(self, assembly, stack, depth): + if depth == 0: + return + stack.swap(depth) + assembly.append(_evm_swap_for(depth)) + + def dup(self, assembly, stack, depth): + stack.dup(depth) + assembly.append(_evm_dup_for(depth)) + + def swap_op(self, assembly, stack, op): + self.swap(assembly, stack, stack.get_depth(op)) + + def dup_op(self, assembly, stack, op): + self.dup(assembly, stack, stack.get_depth(op)) + + +def _evm_swap_for(depth: int) -> str: + swap_idx = -depth + assert 1 <= swap_idx <= 16, "Unsupported swap depth" + return f"SWAP{swap_idx}" + + +def _evm_dup_for(depth: int) -> str: + dup_idx = 1 - depth + assert 1 <= dup_idx <= 16, "Unsupported dup depth" + return f"DUP{dup_idx}" From 21a47b614d1bd1e989195adedb1f5b709f5fbfee Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 7 Dec 2023 09:36:47 -0500 Subject: [PATCH 08/18] chore: move venom tests to `tests/unit/compiler` (#3684) the `tests/compiler` directory was moved to `tests/unit/` in 4dd47e302fc538c but this seems to have been missed in a merge during work on venom (cbac5aba53f87b) --- tests/functional/codegen/integration/test_crowdfund.py | 5 ++++- tests/{ => unit}/compiler/venom/test_duplicate_operands.py | 0 tests/{ => unit}/compiler/venom/test_multi_entry_block.py | 0 .../compiler/venom/test_stack_at_external_return.py | 0 4 files changed, 4 insertions(+), 1 deletion(-) rename tests/{ => unit}/compiler/venom/test_duplicate_operands.py (100%) rename tests/{ => unit}/compiler/venom/test_multi_entry_block.py (100%) rename tests/{ => unit}/compiler/venom/test_stack_at_external_return.py (100%) diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 47c63dc015..2083e62610 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -63,10 +63,13 @@ def refund(): """ a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] + c = get_contract_with_gas_estimation_for_constants(crowdfund, *[a1, 50, 60]) + start_timestamp = w3.eth.get_block(w3.eth.block_number).timestamp + c.participate(transact={"value": 5}) assert c.timelimit() == 60 - assert c.deadline() - c.block_timestamp() == 59 + assert c.deadline() - start_timestamp == 60 assert not c.expired() assert not c.reached() c.participate(transact={"value": 49}) diff --git a/tests/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py similarity index 100% rename from tests/compiler/venom/test_duplicate_operands.py rename to tests/unit/compiler/venom/test_duplicate_operands.py diff --git a/tests/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py similarity index 100% rename from tests/compiler/venom/test_multi_entry_block.py rename to tests/unit/compiler/venom/test_multi_entry_block.py diff --git a/tests/compiler/venom/test_stack_at_external_return.py b/tests/unit/compiler/venom/test_stack_at_external_return.py similarity index 100% rename from tests/compiler/venom/test_stack_at_external_return.py rename to tests/unit/compiler/venom/test_stack_at_external_return.py From 7c74aa2618c8051db88acfac3bd71a3017c524cb Mon Sep 17 00:00:00 2001 From: Franfran <51274081+iFrostizz@users.noreply.github.com> Date: Sat, 9 Dec 2023 14:20:46 +0100 Subject: [PATCH 09/18] fix: add compile-time check for negative uint2str input (#3671) --- .../builtins/codegen/test_uint2str.py | 25 +++++++++++++++++++ vyper/builtins/functions.py | 5 +++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/functional/builtins/codegen/test_uint2str.py b/tests/functional/builtins/codegen/test_uint2str.py index 9d2b7fe3f5..d9edea154b 100644 --- a/tests/functional/builtins/codegen/test_uint2str.py +++ b/tests/functional/builtins/codegen/test_uint2str.py @@ -2,6 +2,9 @@ import pytest +from vyper.compiler import compile_code +from vyper.exceptions import InvalidType, OverflowException + VALID_BITS = list(range(8, 257, 8)) @@ -37,3 +40,25 @@ def foo(x: uint{bits}) -> uint256: """ c = get_contract(code) assert c.foo(2**bits - 1) == 0, bits + + +def test_bignum_throws(): + code = """ +@external +def test(): + a: String[78] = uint2str(2**256) + pass + """ + with pytest.raises(OverflowException): + compile_code(code) + + +def test_int_fails(): + code = """ +@external +def test(): + a: String[78] = uint2str(-1) + pass + """ + with pytest.raises(InvalidType): + compile_code(code) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index b2d817ec5c..22931508a6 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2090,7 +2090,10 @@ def evaluate(self, node): if not isinstance(node.args[0], vy_ast.Int): raise UnfoldableNode - value = str(node.args[0].value) + value = node.args[0].value + if value < 0: + raise InvalidType("Only unsigned ints allowed", node) + value = str(value) return vy_ast.Str.from_node(node, value=value) def infer_arg_types(self, node): From 10564dcc37756f3d3684b7a91fd8f4325a38c4d8 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Tue, 12 Dec 2023 19:47:44 +0200 Subject: [PATCH 10/18] refactor: improve `IRBasicBlock` builder API clean up `append_instruction` api so it does a bit of magic on its arguments and figures out whether or not to allocate a stack variable. remove `append_instruction()` from IRFunction - automatically appending to the last basic block could be a bit error prone depending on which order basic blocks are added to the CFG. --------- Co-authored-by: Charles Cooper --- .../compiler/venom/test_duplicate_operands.py | 11 +- .../compiler/venom/test_multi_entry_block.py | 53 +-- vyper/venom/analysis.py | 4 +- vyper/venom/basicblock.py | 86 ++++- vyper/venom/function.py | 11 - vyper/venom/ir_node_to_venom.py | 325 ++++++++---------- vyper/venom/passes/normalization.py | 4 +- vyper/venom/venom_to_assembly.py | 11 +- 8 files changed, 260 insertions(+), 245 deletions(-) diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 505f01e31b..a51992df67 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -1,6 +1,5 @@ from vyper.compiler.settings import OptimizationLevel from vyper.venom import generate_assembly_experimental -from vyper.venom.basicblock import IRLiteral from vyper.venom.function import IRFunction @@ -17,11 +16,11 @@ def test_duplicate_operands(): Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] """ ctx = IRFunction() - - op = ctx.append_instruction("store", [IRLiteral(10)]) - sum = ctx.append_instruction("add", [op, op]) - ctx.append_instruction("mul", [sum, op]) - ctx.append_instruction("stop", [], False) + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + sum = bb.append_instruction("add", op, op) + bb.append_instruction("mul", sum, op) + bb.append_instruction("stop") asm = generate_assembly_experimental(ctx, OptimizationLevel.CODESIZE) diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index bb57fa1065..6e7e6995d6 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -1,5 +1,4 @@ from vyper.venom.analysis import calculate_cfg -from vyper.venom.basicblock import IRLiteral from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel from vyper.venom.passes.normalization import NormalizationPass @@ -11,25 +10,26 @@ def test_multi_entry_block_1(): target_label = IRLabel("target") block_1_label = IRLabel("block_1", ctx) - op = ctx.append_instruction("store", [IRLiteral(10)]) - acc = ctx.append_instruction("add", [op, op]) - ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + acc = bb.append_instruction("add", op, op) + bb.append_instruction("jnz", acc, finish_label, block_1_label) block_1 = IRBasicBlock(block_1_label, ctx) ctx.append_basic_block(block_1) - acc = ctx.append_instruction("add", [acc, op]) - op = ctx.append_instruction("store", [IRLiteral(10)]) - ctx.append_instruction("mstore", [acc, op], False) - ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + acc = block_1.append_instruction("add", acc, op) + op = block_1.append_instruction("store", 10) + block_1.append_instruction("mstore", acc, op) + block_1.append_instruction("jnz", acc, finish_label, target_label) target_bb = IRBasicBlock(target_label, ctx) ctx.append_basic_block(target_bb) - ctx.append_instruction("mul", [acc, acc]) - ctx.append_instruction("jmp", [finish_label], False) + target_bb.append_instruction("mul", acc, acc) + target_bb.append_instruction("jmp", finish_label) finish_bb = IRBasicBlock(finish_label, ctx) ctx.append_basic_block(finish_bb) - ctx.append_instruction("stop", [], False) + finish_bb.append_instruction("stop") calculate_cfg(ctx) assert not ctx.normalized, "CFG should not be normalized" @@ -54,33 +54,34 @@ def test_multi_entry_block_2(): block_1_label = IRLabel("block_1", ctx) block_2_label = IRLabel("block_2", ctx) - op = ctx.append_instruction("store", [IRLiteral(10)]) - acc = ctx.append_instruction("add", [op, op]) - ctx.append_instruction("jnz", [acc, finish_label, block_1_label], False) + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + acc = bb.append_instruction("add", op, op) + bb.append_instruction("jnz", acc, finish_label, block_1_label) block_1 = IRBasicBlock(block_1_label, ctx) ctx.append_basic_block(block_1) - acc = ctx.append_instruction("add", [acc, op]) - op = ctx.append_instruction("store", [IRLiteral(10)]) - ctx.append_instruction("mstore", [acc, op], False) - ctx.append_instruction("jnz", [acc, target_label, finish_label], False) + acc = block_1.append_instruction("add", acc, op) + op = block_1.append_instruction("store", 10) + block_1.append_instruction("mstore", acc, op) + block_1.append_instruction("jnz", acc, target_label, finish_label) block_2 = IRBasicBlock(block_2_label, ctx) ctx.append_basic_block(block_2) - acc = ctx.append_instruction("add", [acc, op]) - op = ctx.append_instruction("store", [IRLiteral(10)]) - ctx.append_instruction("mstore", [acc, op], False) - # switch the order of the labels, for fun - ctx.append_instruction("jnz", [acc, finish_label, target_label], False) + acc = block_2.append_instruction("add", acc, op) + op = block_2.append_instruction("store", 10) + block_2.append_instruction("mstore", acc, op) + # switch the order of the labels, for fun and profit + block_2.append_instruction("jnz", acc, finish_label, target_label) target_bb = IRBasicBlock(target_label, ctx) ctx.append_basic_block(target_bb) - ctx.append_instruction("mul", [acc, acc]) - ctx.append_instruction("jmp", [finish_label], False) + target_bb.append_instruction("mul", acc, acc) + target_bb.append_instruction("jmp", finish_label) finish_bb = IRBasicBlock(finish_label, ctx) ctx.append_basic_block(finish_bb) - ctx.append_instruction("stop", [], False) + finish_bb.append_instruction("stop") calculate_cfg(ctx) assert not ctx.normalized, "CFG should not be normalized" diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 5980e21028..1a82ca85d0 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -2,7 +2,7 @@ from vyper.utils import OrderedSet from vyper.venom.basicblock import ( BB_TERMINATORS, - CFG_ALTERING_OPS, + CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRInstruction, IRVariable, @@ -55,7 +55,7 @@ def calculate_cfg(ctx: IRFunction) -> None: assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" for inst in bb.instructions: - if inst.opcode in CFG_ALTERING_OPS: + if inst.opcode in CFG_ALTERING_INSTRUCTIONS: ops = inst.get_label_operands() for op in ops: ctx.get_basic_block(op.value).add_cfg_in(bb) diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index b95d7416ca..6f1c1c8ab3 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator, Optional, Union from vyper.utils import OrderedSet @@ -31,8 +31,31 @@ ] ) -CFG_ALTERING_OPS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) +NO_OUTPUT_INSTRUCTIONS = frozenset( + [ + "deploy", + "mstore", + "sstore", + "dstore", + "istore", + "dloadbytes", + "calldatacopy", + "codecopy", + "return", + "ret", + "revert", + "assert", + "selfdestruct", + "stop", + "invalid", + "invoke", + "jmp", + "jnz", + "log", + ] +) +CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -40,8 +63,8 @@ class IRDebugInfo: """ - IRDebugInfo represents debug information in IR, used to annotate IR instructions - with source code information when printing IR. + IRDebugInfo represents debug information in IR, used to annotate IR + instructions with source code information when printing IR. """ line_no: int @@ -83,7 +106,7 @@ class IRLiteral(IRValue): value: int def __init__(self, value: int) -> None: - assert isinstance(value, str) or isinstance(value, int), "value must be an int" + assert isinstance(value, int), "value must be an int" self.value = value def __repr__(self) -> str: @@ -170,7 +193,7 @@ def __init__( assert isinstance(operands, list | Iterator), "operands must be a list" self.opcode = opcode self.volatile = opcode in VOLATILE_INSTRUCTIONS - self.operands = [op for op in operands] # in case we get an iterator + self.operands = list(operands) # in case we get an iterator self.output = output self.liveness = OrderedSet() self.dup_requirements = OrderedSet() @@ -233,6 +256,14 @@ def __repr__(self) -> str: return s +def _ir_operand_from_value(val: Any) -> IROperand: + if isinstance(val, IROperand): + return val + + assert isinstance(val, int) + return IRLiteral(val) + + class IRBasicBlock: """ IRBasicBlock represents a basic block in IR. Each basic block has a label and @@ -243,8 +274,8 @@ class IRBasicBlock: %2 = mul %1, 2 is represented as: bb = IRBasicBlock("bb", function) - bb.append_instruction(IRInstruction("add", ["%0", "1"], "%1")) - bb.append_instruction(IRInstruction("mul", ["%1", "2"], "%2")) + r1 = bb.append_instruction("add", "%0", "1") + r2 = bb.append_instruction("mul", r1, "2") The label of a basic block is used to refer to it from other basic blocks in order to branch to it. @@ -296,10 +327,41 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: def is_reachable(self) -> bool: return len(self.cfg_in) > 0 - def append_instruction(self, instruction: IRInstruction) -> None: - assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" - instruction.parent = self - self.instructions.append(instruction) + def append_instruction(self, opcode: str, *args: Union[IROperand, int]) -> Optional[IRVariable]: + """ + Append an instruction to the basic block + + Returns the output variable if the instruction supports one + """ + ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None + + # Wrap raw integers in IRLiterals + inst_args = [_ir_operand_from_value(arg) for arg in args] + + inst = IRInstruction(opcode, inst_args, ret) + inst.parent = self + self.instructions.append(inst) + return ret + + def append_invoke_instruction( + self, args: list[IROperand | int], returns: bool + ) -> Optional[IRVariable]: + """ + Append an instruction to the basic block + + Returns the output variable if the instruction supports one + """ + ret = None + if returns: + ret = self.parent.get_next_variable() + + # Wrap raw integers in IRLiterals + inst_args = [_ir_operand_from_value(arg) for arg in args] + + inst = IRInstruction("invoke", inst_args, ret) + inst.parent = self + self.instructions.append(inst) + return ret def insert_instruction(self, instruction: IRInstruction, index: int) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" diff --git a/vyper/venom/function.py b/vyper/venom/function.py index c14ad77345..e16b2ad6e6 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -98,17 +98,6 @@ def remove_unreachable_blocks(self) -> int: self.basic_blocks = new_basic_blocks return removed - def append_instruction( - self, opcode: str, args: list[IROperand], do_ret: bool = True - ) -> Optional[IRVariable]: - """ - Append instruction to last basic block. - """ - ret = self.get_next_variable() if do_ret else None - inst = IRInstruction(opcode, args, ret) # type: ignore - self.get_basic_block().append_instruction(inst) - return ret - def append_data(self, opcode: str, args: list[IROperand]) -> None: """ Append data diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 19bd5c8b73..e2ce28a8f9 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -72,7 +72,7 @@ "balance", ] -SymbolTable = dict[str, IROperand] +SymbolTable = dict[str, Optional[IROperand]] def _get_symbols_common(a: dict, b: dict) -> dict: @@ -93,11 +93,11 @@ def convert_ir_basicblock(ir: IRnode) -> IRFunction: for i, bb in enumerate(global_function.basic_blocks): if not bb.is_terminated and i < len(global_function.basic_blocks) - 1: - bb.append_instruction(IRInstruction("jmp", [global_function.basic_blocks[i + 1].label])) + bb.append_instruction("jmp", global_function.basic_blocks[i + 1].label) revert_bb = IRBasicBlock(IRLabel("__revert"), global_function) revert_bb = global_function.append_basic_block(revert_bb) - revert_bb.append_instruction(IRInstruction("revert", [IRLiteral(0), IRLiteral(0)])) + revert_bb.append_instruction("revert", 0, 0) return global_function @@ -109,22 +109,16 @@ def _convert_binary_op( variables: OrderedSet, allocated_variables: dict[str, IRVariable], swap: bool = False, -) -> IRVariable: +) -> Optional[IRVariable]: ir_args = ir.args[::-1] if swap else ir.args arg_0 = _convert_ir_basicblock(ctx, ir_args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir_args[1], symbols, variables, allocated_variables) - args = [arg_1, arg_0] - - ret = ctx.get_next_variable() - inst = IRInstruction(ir.value, args, ret) # type: ignore - ctx.get_basic_block().append_instruction(inst) - return ret + return ctx.get_basic_block().append_instruction(str(ir.value), arg_1, arg_0) def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: - inst = IRInstruction("jmp", [label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", label) label = ctx.get_next_label() bb = IRBasicBlock(label, ctx) @@ -149,7 +143,7 @@ def _handle_self_call( goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] target_label = goto_ir.args[0].value # goto return_buf = goto_ir.args[1] # return buffer - ret_args = [IRLabel(target_label)] # type: ignore + ret_args: list[IROperand] = [IRLabel(target_label)] # type: ignore for arg in args_ir: if arg.is_literal: @@ -164,16 +158,23 @@ def _handle_self_call( ctx, arg._optimized, symbols, variables, allocated_variables ) if arg.location and arg.location.load_op == "calldataload": - ret = ctx.append_instruction(arg.location.load_op, [ret]) + bb = ctx.get_basic_block() + ret = bb.append_instruction(arg.location.load_op, ret) ret_args.append(ret) if return_buf.is_literal: - ret_args.append(IRLiteral(return_buf.value)) # type: ignore + ret_args.append(return_buf.value) # type: ignore + + bb = ctx.get_basic_block() do_ret = func_t.return_type is not None - invoke_ret = ctx.append_instruction("invoke", ret_args, do_ret) # type: ignore - allocated_variables["return_buffer"] = invoke_ret # type: ignore - return invoke_ret + if do_ret: + invoke_ret = bb.append_invoke_instruction(ret_args, returns=True) # type: ignore + allocated_variables["return_buffer"] = invoke_ret # type: ignore + return invoke_ret + else: + bb.append_invoke_instruction(ret_args, returns=False) # type: ignore + return None def _handle_internal_func( @@ -186,28 +187,18 @@ def _handle_internal_func( old_ir_mempos += 64 for arg in func_t.arguments: - new_var = ctx.get_next_variable() - - alloca_inst = IRInstruction("param", [], new_var) - alloca_inst.annotation = arg.name - bb.append_instruction(alloca_inst) - symbols[f"&{old_ir_mempos}"] = new_var + symbols[f"&{old_ir_mempos}"] = bb.append_instruction("param") + bb.instructions[-1].annotation = arg.name old_ir_mempos += 32 # arg.typ.memory_bytes_required # return buffer if func_t.return_type is not None: - new_var = ctx.get_next_variable() - alloca_inst = IRInstruction("param", [], new_var) - bb.append_instruction(alloca_inst) - alloca_inst.annotation = "return_buffer" - symbols["return_buffer"] = new_var + symbols["return_buffer"] = bb.append_instruction("param") + bb.instructions[-1].annotation = "return_buffer" # return address - new_var = ctx.get_next_variable() - alloca_inst = IRInstruction("param", [], new_var) - bb.append_instruction(alloca_inst) - alloca_inst.annotation = "return_pc" - symbols["return_pc"] = new_var + symbols["return_pc"] = bb.append_instruction("param") + bb.instructions[-1].annotation = "return_pc" return ir.args[0].args[2] @@ -222,7 +213,7 @@ def _convert_ir_simple_node( args = [ _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) for arg in ir.args ] - return ctx.append_instruction(ir.value, args) # type: ignore + return ctx.get_basic_block().append_instruction(ir.value, *args) # type: ignore _break_target: Optional[IRBasicBlock] = None @@ -241,22 +232,22 @@ def _get_variable_from_address( return None -def _get_return_for_stack_operand( - ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable -) -> IRInstruction: +def _append_return_for_stack_operand( + bb: IRBasicBlock, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable +) -> None: if isinstance(ret_ir, IRLiteral): sym = symbols.get(f"&{ret_ir.value}", None) - new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) - ctx.append_instruction("mstore", [sym, new_var], False) # type: ignore + new_var = bb.append_instruction("alloca", 32, ret_ir) + bb.append_instruction("mstore", sym, new_var) # type: ignore else: sym = symbols.get(ret_ir.value, None) if sym is None: # FIXME: needs real allocations - new_var = ctx.append_instruction("alloca", [IRLiteral(32), IRLiteral(0)]) - ctx.append_instruction("mstore", [ret_ir, new_var], False) # type: ignore + new_var = bb.append_instruction("alloca", 32, 0) + bb.append_instruction("mstore", ret_ir, new_var) # type: ignore else: new_var = ret_ir - return IRInstruction("return", [last_ir, new_var]) # type: ignore + bb.append_instruction("return", last_ir, new_var) # type: ignore def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): @@ -280,7 +271,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] new_var = _convert_binary_op(ctx, ir, symbols, variables, allocated_variables) ir.value = org_value - return ctx.append_instruction("iszero", [new_var]) + return ctx.get_basic_block().append_instruction("iszero", new_var) elif ir.value in PASS_THROUGH_INSTRUCTIONS: return _convert_ir_simple_node(ctx, ir, symbols, variables, allocated_variables) @@ -296,8 +287,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): runtimeLabel = ctx.get_next_label() - inst = IRInstruction("deploy", [IRLiteral(memsize), runtimeLabel, IRLiteral(padding)]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("deploy", memsize, runtimeLabel, padding) bb = IRBasicBlock(runtimeLabel, ctx) ctx.append_basic_block(bb) @@ -369,12 +359,14 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): retVar = ctx.get_next_variable(MemType.MEMORY, retOffsetValue) symbols[f"&{retOffsetValue}"] = retVar + bb = ctx.get_basic_block() + if ir.value == "call": args = [retSize, retOffset, argsSize, argsOffsetVar, value, address, gas] - return ctx.append_instruction(ir.value, args) + return bb.append_instruction(ir.value, *args) else: args = [retSize, retOffset, argsSize, argsOffsetVar, address, gas] - return ctx.append_instruction(ir.value, args) + return bb.append_instruction(ir.value, *args) elif ir.value == "if": cond = ir.args[0] current_bb = ctx.get_basic_block() @@ -394,7 +386,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ) if isinstance(else_ret_val, IRLiteral): assert isinstance(else_ret_val.value, int) # help mypy - else_ret_val = ctx.append_instruction("store", [IRLiteral(else_ret_val.value)]) + else_ret_val = ctx.get_basic_block().append_instruction("store", else_ret_val) after_else_syms = else_syms.copy() # convert "then" @@ -405,10 +397,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ctx, ir.args[1], symbols, variables, allocated_variables ) if isinstance(then_ret_val, IRLiteral): - then_ret_val = ctx.append_instruction("store", [IRLiteral(then_ret_val.value)]) + then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) - inst = IRInstruction("jnz", [cont_ret, then_block.label, else_block.label]) - current_bb.append_instruction(inst) + current_bb.append_instruction("jnz", cont_ret, then_block.label, else_block.label) after_then_syms = symbols.copy() @@ -419,33 +410,25 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if_ret = None if then_ret_val is not None and else_ret_val is not None: - if_ret = ctx.get_next_variable() - bb.append_instruction( - IRInstruction( - "phi", [then_block.label, then_ret_val, else_block.label, else_ret_val], if_ret - ) + if_ret = bb.append_instruction( + "phi", then_block.label, then_ret_val, else_block.label, else_ret_val ) common_symbols = _get_symbols_common(after_then_syms, after_else_syms) for sym, val in common_symbols.items(): - ret = ctx.get_next_variable() + ret = bb.append_instruction("phi", then_block.label, val[0], else_block.label, val[1]) old_var = symbols.get(sym, None) symbols[sym] = ret if old_var is not None: for idx, var_rec in allocated_variables.items(): # type: ignore if var_rec.value == old_var.value: allocated_variables[idx] = ret # type: ignore - bb.append_instruction( - IRInstruction("phi", [then_block.label, val[0], else_block.label, val[1]], ret) - ) if not else_block.is_terminated: - exit_inst = IRInstruction("jmp", [bb.label]) - else_block.append_instruction(exit_inst) + else_block.append_instruction("jmp", bb.label) if not then_block.is_terminated: - exit_inst = IRInstruction("jmp", [bb.label]) - then_block.append_instruction(exit_inst) + then_block.append_instruction("jmp", bb.label) return if_ret @@ -459,7 +442,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): sym = ir.args[0] if isinstance(ret, IRLiteral): - new_var = ctx.append_instruction("store", [ret]) # type: ignore + new_var = ctx.get_basic_block().append_instruction("store", ret) # type: ignore with_symbols[sym.value] = new_var else: with_symbols[sym.value] = ret # type: ignore @@ -471,13 +454,12 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): _append_jmp(ctx, IRLabel(ir.args[0].value)) elif ir.value == "jump": arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - inst = IRInstruction("jmp", [arg_1]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", arg_1) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - new_var = ctx.append_instruction("store", [arg_1]) # type: ignore + new_var = ctx.get_basic_block().append_instruction("store", arg_1) # type: ignore symbols[sym.value] = new_var elif ir.value == "calldatacopy": @@ -491,16 +473,15 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if isinstance(arg_0, IRLiteral) else None ) + bb = ctx.get_basic_block() if var is not None: if allocated_variables.get(var.name, None) is None: - new_v = ctx.append_instruction( - "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] # type: ignore - ) + new_v = bb.append_instruction("alloca", var.size, var.pos) # type: ignore allocated_variables[var.name] = new_v # type: ignore - ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore symbols[f"&{var.pos}"] = new_v # type: ignore else: - ctx.append_instruction("calldatacopy", [size, arg_1, new_v], False) # type: ignore + bb.append_instruction("calldatacopy", size, arg_1, new_v) # type: ignore return new_v elif ir.value == "codecopy": @@ -508,7 +489,7 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - ctx.append_instruction("codecopy", [size, arg_1, arg_0], False) # type: ignore + ctx.get_basic_block().append_instruction("codecopy", size, arg_1, arg_0) # type: ignore elif ir.value == "symbol": return IRLabel(ir.args[0].value, True) elif ir.value == "data": @@ -526,13 +507,12 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif ir.value == "assert": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) current_bb = ctx.get_basic_block() - inst = IRInstruction("assert", [arg_0]) # type: ignore - current_bb.append_instruction(inst) + current_bb.append_instruction("assert", arg_0) elif ir.value == "label": label = IRLabel(ir.args[0].value, True) - if not ctx.get_basic_block().is_terminated: - inst = IRInstruction("jmp", [label]) - ctx.get_basic_block().append_instruction(inst) + bb = ctx.get_basic_block() + if not bb.is_terminated: + bb.append_instruction("jmp", label) bb = IRBasicBlock(label, ctx) ctx.append_basic_block(bb) _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) @@ -542,14 +522,13 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if func_t.is_external: # Hardcoded contructor special case + bb = ctx.get_basic_block() if func_t.name == "__init__": label = IRLabel(ir.args[0].value, True) - inst = IRInstruction("jmp", [label]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("jmp", label) return None if func_t.return_type is None: - inst = IRInstruction("stop", []) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("stop") return None else: last_ir = None @@ -569,6 +548,8 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ctx, ret_var, symbols, variables, allocated_variables ) + bb = ctx.get_basic_block() + var = ( _get_variable_from_address(variables, int(ret_ir.value)) if isinstance(ret_ir, IRLiteral) @@ -582,101 +563,96 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if var.size and int(var.size) > 32: offset = int(ret_ir.value) - var.pos # type: ignore if offset > 0: - ptr_var = ctx.append_instruction( - "add", [IRLiteral(var.pos), IRLiteral(offset)] - ) + ptr_var = bb.append_instruction("add", var.pos, offset) else: ptr_var = allocated_var - inst = IRInstruction("return", [last_ir, ptr_var]) + bb.append_instruction("return", last_ir, ptr_var) else: - inst = _get_return_for_stack_operand(ctx, symbols, new_var, last_ir) + _append_return_for_stack_operand(ctx, symbols, new_var, last_ir) else: if isinstance(ret_ir, IRLiteral): sym = symbols.get(f"&{ret_ir.value}", None) if sym is None: - inst = IRInstruction("return", [last_ir, ret_ir]) + bb.append_instruction("return", last_ir, ret_ir) else: if func_t.return_type.memory_bytes_required > 32: - new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_ir]) - ctx.append_instruction("mstore", [sym, new_var], False) - inst = IRInstruction("return", [last_ir, new_var]) + new_var = bb.append_instruction("alloca", 32, ret_ir) + bb.append_instruction("mstore", sym, new_var) + bb.append_instruction("return", last_ir, new_var) else: - inst = IRInstruction("return", [last_ir, ret_ir]) + bb.append_instruction("return", last_ir, ret_ir) else: if last_ir and int(last_ir.value) > 32: - inst = IRInstruction("return", [last_ir, ret_ir]) + bb.append_instruction("return", last_ir, ret_ir) else: - ret_buf = IRLiteral(128) # TODO: need allocator - new_var = ctx.append_instruction("alloca", [IRLiteral(32), ret_buf]) - ctx.append_instruction("mstore", [ret_ir, new_var], False) - inst = IRInstruction("return", [last_ir, new_var]) + ret_buf = 128 # TODO: need allocator + new_var = bb.append_instruction("alloca", 32, ret_buf) + bb.append_instruction("mstore", ret_ir, new_var) + bb.append_instruction("return", last_ir, new_var) - ctx.get_basic_block().append_instruction(inst) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + bb = ctx.get_basic_block() if func_t.is_internal: assert ir.args[1].value == "return_pc", "return_pc not found" if func_t.return_type is None: - inst = IRInstruction("ret", [symbols["return_pc"]]) + bb.append_instruction("ret", symbols["return_pc"]) else: if func_t.return_type.memory_bytes_required > 32: - inst = IRInstruction("ret", [symbols["return_buffer"], symbols["return_pc"]]) + bb.append_instruction("ret", symbols["return_buffer"], symbols["return_pc"]) else: - ret_by_value = ctx.append_instruction("mload", [symbols["return_buffer"]]) - inst = IRInstruction("ret", [ret_by_value, symbols["return_pc"]]) - - ctx.get_basic_block().append_instruction(inst) + ret_by_value = bb.append_instruction("mload", symbols["return_buffer"]) + bb.append_instruction("ret", ret_by_value, symbols["return_pc"]) elif ir.value == "revert": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - inst = IRInstruction("revert", [arg_1, arg_0]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("revert", arg_1, arg_0) elif ir.value == "dload": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - src = ctx.append_instruction("add", [arg_0, IRLabel("code_end")]) + bb = ctx.get_basic_block() + src = bb.append_instruction("add", arg_0, IRLabel("code_end")) + + bb.append_instruction("dloadbytes", 32, src, MemoryPositions.FREE_VAR_SPACE) + return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) - ctx.append_instruction( - "dloadbytes", [IRLiteral(32), src, IRLiteral(MemoryPositions.FREE_VAR_SPACE)], False - ) - return ctx.append_instruction("mload", [IRLiteral(MemoryPositions.FREE_VAR_SPACE)]) elif ir.value == "dloadbytes": dst = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) src_offset = _convert_ir_basicblock( ctx, ir.args[1], symbols, variables, allocated_variables ) len_ = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - - src = ctx.append_instruction("add", [src_offset, IRLabel("code_end")]) - - inst = IRInstruction("dloadbytes", [len_, src, dst]) - ctx.get_basic_block().append_instruction(inst) + bb = ctx.get_basic_block() + src = bb.append_instruction("add", src_offset, IRLabel("code_end")) + bb.append_instruction("dloadbytes", len_, src, dst) return None + elif ir.value == "mload": sym_ir = ir.args[0] var = ( _get_variable_from_address(variables, int(sym_ir.value)) if sym_ir.is_literal else None ) + bb = ctx.get_basic_block() if var is not None: if var.size and var.size > 32: if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = ctx.append_instruction( - "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + allocated_variables[var.name] = bb.append_instruction( + "alloca", var.size, var.pos ) offset = int(sym_ir.value) - var.pos if offset > 0: - ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + ptr_var = bb.append_instruction("add", var.pos, offset) else: ptr_var = allocated_variables[var.name] - return ctx.append_instruction("mload", [ptr_var]) + return bb.append_instruction("mload", ptr_var) else: if sym_ir.is_literal: sym = symbols.get(f"&{sym_ir.value}", None) if sym is None: - new_var = ctx.append_instruction("store", [sym_ir]) + new_var = bb.append_instruction("store", sym_ir) symbols[f"&{sym_ir.value}"] = new_var if allocated_variables.get(var.name, None) is None: allocated_variables[var.name] = new_var @@ -691,9 +667,9 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if sym_ir.is_literal: new_var = symbols.get(f"&{sym_ir.value}", None) if new_var is not None: - return ctx.append_instruction("mload", [new_var]) + return bb.append_instruction("mload", new_var) else: - return ctx.append_instruction("mload", [IRLiteral(sym_ir.value)]) + return bb.append_instruction("mload", sym_ir.value) else: new_var = _convert_ir_basicblock( ctx, sym_ir, symbols, variables, allocated_variables @@ -706,12 +682,14 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): # if sym_ir.is_self_call: return new_var - return ctx.append_instruction("mload", [new_var]) + return bb.append_instruction("mload", new_var) elif ir.value == "mstore": sym_ir = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) + bb = ctx.get_basic_block() + var = None if isinstance(sym_ir, IRLiteral): var = _get_variable_from_address(variables, int(sym_ir.value)) @@ -719,41 +697,38 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): if var is not None and var.size is not None: if var.size and var.size > 32: if allocated_variables.get(var.name, None) is None: - allocated_variables[var.name] = ctx.append_instruction( - "alloca", [IRLiteral(var.size), IRLiteral(var.pos)] + allocated_variables[var.name] = bb.append_instruction( + "alloca", var.size, var.pos ) offset = int(sym_ir.value) - var.pos if offset > 0: - ptr_var = ctx.append_instruction("add", [IRLiteral(var.pos), IRLiteral(offset)]) + ptr_var = bb.append_instruction("add", var.pos, offset) else: ptr_var = allocated_variables[var.name] - return ctx.append_instruction("mstore", [arg_1, ptr_var], False) + bb.append_instruction("mstore", arg_1, ptr_var) else: if isinstance(sym_ir, IRLiteral): - new_var = ctx.append_instruction("store", [arg_1]) + new_var = bb.append_instruction("store", arg_1) symbols[f"&{sym_ir.value}"] = new_var # if allocated_variables.get(var.name, None) is None: allocated_variables[var.name] = new_var return new_var else: if not isinstance(sym_ir, IRLiteral): - inst = IRInstruction("mstore", [arg_1, sym_ir]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("mstore", arg_1, sym_ir) return None sym = symbols.get(f"&{sym_ir.value}", None) if sym is None: - inst = IRInstruction("mstore", [arg_1, sym_ir]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("mstore", arg_1, sym_ir) if arg_1 and not isinstance(sym_ir, IRLiteral): symbols[f"&{sym_ir.value}"] = arg_1 return None if isinstance(sym_ir, IRLiteral): - inst = IRInstruction("mstore", [arg_1, sym]) - ctx.get_basic_block().append_instruction(inst) + bb.append_instruction("mstore", arg_1, sym) return None else: symbols[sym_ir.value] = arg_1 @@ -761,12 +736,11 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): elif ir.value in ["sload", "iload"]: arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - return ctx.append_instruction(ir.value, [arg_0]) + return ctx.get_basic_block().append_instruction(ir.value, arg_0) elif ir.value in ["sstore", "istore"]: arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) - inst = IRInstruction(ir.value, [arg_1, arg_0]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction(ir.value, arg_1, arg_0) elif ir.value == "unique_symbol": sym = ir.args[0] new_var = ctx.get_next_variable() @@ -803,28 +777,19 @@ def emit_body_block(): increment_block = IRBasicBlock(ctx.get_next_label(), ctx) exit_block = IRBasicBlock(ctx.get_next_label(), ctx) - counter_var = ctx.get_next_variable() counter_inc_var = ctx.get_next_variable() - ret = ctx.get_next_variable() - inst = IRInstruction("store", [start], counter_var) - ctx.get_basic_block().append_instruction(inst) + counter_var = ctx.get_basic_block().append_instruction("store", start) symbols[sym.value] = counter_var - inst = IRInstruction("jmp", [cond_block.label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", cond_block.label) - symbols[sym.value] = ret - cond_block.append_instruction( - IRInstruction( - "phi", [entry_block.label, counter_var, increment_block.label, counter_inc_var], ret - ) + ret = cond_block.append_instruction( + "phi", entry_block.label, counter_var, increment_block.label, counter_inc_var ) + symbols[sym.value] = ret - xor_ret = ctx.get_next_variable() - cont_ret = ctx.get_next_variable() - inst = IRInstruction("xor", [ret, end], xor_ret) - cond_block.append_instruction(inst) - cond_block.append_instruction(IRInstruction("iszero", [xor_ret], cont_ret)) + xor_ret = cond_block.append_instruction("xor", ret, end) + cont_ret = cond_block.append_instruction("iszero", xor_ret) ctx.append_basic_block(cond_block) # Do a dry run to get the symbols needing phi nodes @@ -851,56 +816,55 @@ def emit_body_block(): body_end = ctx.get_basic_block() if not body_end.is_terminated: - body_end.append_instruction(IRInstruction("jmp", [jump_up_block.label])) + body_end.append_instruction("jmp", jump_up_block.label) - jump_cond = IRInstruction("jmp", [increment_block.label]) - jump_up_block.append_instruction(jump_cond) + jump_up_block.append_instruction("jmp", increment_block.label) ctx.append_basic_block(jump_up_block) - increment_block.append_instruction( - IRInstruction("add", [ret, IRLiteral(1)], counter_inc_var) - ) - increment_block.append_instruction(IRInstruction("jmp", [cond_block.label])) + increment_block.append_instruction(IRInstruction("add", ret, 1)) + increment_block.insert_instruction[-1].output = counter_inc_var + + increment_block.append_instruction("jmp", cond_block.label) ctx.append_basic_block(increment_block) ctx.append_basic_block(exit_block) - inst = IRInstruction("jnz", [cont_ret, exit_block.label, body_block.label]) - cond_block.append_instruction(inst) + cond_block.append_instruction("jnz", cont_ret, exit_block.label, body_block.label) elif ir.value == "break": assert _break_target is not None, "Break with no break target" - inst = IRInstruction("jmp", [_break_target.label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", _break_target.label) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) elif ir.value == "continue": assert _continue_target is not None, "Continue with no contrinue target" - inst = IRInstruction("jmp", [_continue_target.label]) - ctx.get_basic_block().append_instruction(inst) + ctx.get_basic_block().append_instruction("jmp", _continue_target.label) ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) elif ir.value == "gas": - return ctx.append_instruction("gas", []) + return ctx.get_basic_block().append_instruction("gas") elif ir.value == "returndatasize": - return ctx.append_instruction("returndatasize", []) + return ctx.get_basic_block().append_instruction("returndatasize") elif ir.value == "returndatacopy": assert len(ir.args) == 3, "returndatacopy with wrong number of arguments" arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) arg_1 = _convert_ir_basicblock(ctx, ir.args[1], symbols, variables, allocated_variables) size = _convert_ir_basicblock(ctx, ir.args[2], symbols, variables, allocated_variables) - new_var = ctx.append_instruction("returndatacopy", [arg_1, size]) + new_var = ctx.get_basic_block().append_instruction("returndatacopy", arg_1, size) symbols[f"&{arg_0.value}"] = new_var return new_var elif ir.value == "selfdestruct": arg_0 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - ctx.append_instruction("selfdestruct", [arg_0], False) + ctx.get_basic_block().append_instruction("selfdestruct", arg_0) elif isinstance(ir.value, str) and ir.value.startswith("log"): - args = [ - _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) - for arg in ir.args - ] - inst = IRInstruction(ir.value, reversed(args)) - ctx.get_basic_block().append_instruction(inst) + args = reversed( + [ + _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) + for arg in ir.args + ] + ) + topic_count = int(ir.value[3:]) + assert topic_count >= 0 and topic_count <= 4, "invalid topic count" + ctx.get_basic_block().append_instruction("log", topic_count, *args) elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): _convert_ir_opcode(ctx, ir, symbols, variables, allocated_variables) elif isinstance(ir.value, str) and ir.value in symbols: @@ -927,8 +891,7 @@ def _convert_ir_opcode( inst_args.append( _convert_ir_basicblock(ctx, arg, symbols, variables, allocated_variables) ) - instruction = IRInstruction(opcode, inst_args) # type: ignore - ctx.get_basic_block().append_instruction(instruction) + ctx.get_basic_block().append_instruction(opcode, *inst_args) def _data_ofst_of(sym, ofst, height_): diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 9ee1012f91..90dd60e881 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -1,5 +1,5 @@ from vyper.exceptions import CompilerPanic -from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -61,7 +61,7 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB source = in_bb.label.value target = bb.label.value split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) - split_bb.append_instruction(IRInstruction("jmp", [bb.label])) + split_bb.append_instruction("jmp", bb.label) self.ctx.append_basic_block(split_bb) # Rewire the CFG diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index f6ec45440a..8760e9aa63 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -62,11 +62,6 @@ "lt", "slt", "sgt", - "log0", - "log1", - "log2", - "log3", - "log4", ] ) @@ -274,6 +269,10 @@ def _generate_evm_for_instruction( operands = [] elif opcode == "istore": operands = inst.operands[0:1] + elif opcode == "log": + log_topic_count = inst.operands[0].value + assert log_topic_count in [0, 1, 2, 3, 4], "Invalid topic count" + operands = inst.operands[1:] else: operands = inst.operands @@ -417,6 +416,8 @@ def _generate_evm_for_instruction( elif opcode == "istore": loc = inst.operands[1].value assembly.extend(["_OFST", "_mem_deploy_end", loc, "MSTORE"]) + elif opcode == "log": + assembly.extend([f"LOG{log_topic_count}"]) else: raise Exception(f"Unknown opcode: {opcode}") From 0b1f3e143c4f432c469b61fbe1566cb46cfcfca1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 13 Dec 2023 18:16:34 -0500 Subject: [PATCH 11/18] fix: remove .keyword from Call AST node (#3689) for some reason, there is a slot named "keyword" on the Call AST node, which is never used (and doesn't exist in the python AST!). this commit removes it for hygienic purposes. --- vyper/ast/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 2497928035..69bd1fed53 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1254,7 +1254,7 @@ def _op(self, left, right): class Call(ExprNode): - __slots__ = ("func", "args", "keywords", "keyword") + __slots__ = ("func", "args", "keywords") class keyword(VyperNode): From 919080e0b74c908d986f5cee121a2bf2379cb2dc Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 15 Dec 2023 21:31:31 -0500 Subject: [PATCH 12/18] chore: test all output formats (#3683) right now only certain output formats are tested in the main compiler test harness, namely bytecode, abi, metadata and some natspec outputs. in the past, there have been issues where output formats get broken but don't get detected until release testing or even after release. this commit adds hooks in `get_contract()` and `deploy_blueprint_for()` to generate all output formats, which will help detect broken output formats sooner. --- tests/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 216fb32b0d..22f8544beb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -279,8 +279,8 @@ def _get_contract( settings.optimize = override_opt_level or optimize out = compiler.compile_code( source_code, - # test that metadata and natspecs get generated - output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], + # test that all output formats can get generated + output_formats=list(compiler.OUTPUT_FORMATS.keys()), settings=settings, input_bundle=input_bundle, show_gas_estimates=True, # Enable gas estimates for testing @@ -352,7 +352,7 @@ def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwar settings.optimize = optimize out = compiler.compile_code( source_code, - output_formats=["abi", "bytecode", "metadata", "userdoc", "devdoc"], + output_formats=list(compiler.OUTPUT_FORMATS.keys()), settings=settings, show_gas_estimates=True, # Enable gas estimates for testing ) From c6f457a73db40e4b113497883bd330e0dcec28d1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 16 Dec 2023 11:16:14 -0500 Subject: [PATCH 13/18] feat: implement "stateless" modules (#3663) this commit implements support for "stateless" modules in vyper. this is the first major step in implementing vyper's module system. it redesigns the language's import system, allows calling internal functions from imported modules, allows for limited use of types from imported modules, and introduces support for `.vyi` interface files. note that the following features are left for future, follow-up work: - modules with variables (constants, immutables or storage variables) - full support for imported events (in that they do not get exported in the ABI) - a system for exporting imported functions in the external interface of a contract this commit first and foremost changes how imports are handled in vyper. previously, an imported file was assumed to be an interface file. some very limited validation was performed in `InterfaceT.from_ast`, but not fully typechecked, and no code was generated for it. now, when a file is imported, it is checked whether it is 1. a `.vy` file 2. a `.vyi` file 3. a `.json` file the `.json` pathway remains more or less unchanged. the `.vyi` pathway is new, but it is fairly straightforward and is basically a "simple" path through the `.vy` pathway which piggy-backs off the `.vy` analysis to produce an `InterfaceT` object. the `.vy` pathway now does full typechecking and analysis of the imported module. some changes were made to support this: - a new ImportGraph data structure tracks the position in the import graph and detects (and bands) import cycles - InputBundles now implement a `_normalize_path()` method. this method normalizes the path so that source IDs are stable no matter how a file is accessed in the filesystem (i.e., no matter what the search path was at the time `load_file()` was called). - CompilerInput now has a distinction between `resolved_path` and `path` (the original path that was asked for). this allows us to maintain UX considerations (showing unresolved paths etc) while still having a 1:1:1 correspondence between source id, filepath and filesystem. these changes were needed in order to stabilize notions like "which file are we looking at?" no matter the way the file was accessed or how it was imported. this is important so that types imported transitively can resolve as expected no matter how they are imported - for instance, `x.SomeType` and `a.x.SomeType` resolving to the same type. the other changes needed to support code generation and analysis for imported functions were fairly simple, and mostly involved generalizing the analysis/code generation to type-based dispatch instead of AST-based dispatch. other changes to the language and compiler API include: - import restrictions are more relaxed - `import x` is allowed now (previously, `import x as x` was required) - change function labels in IR function labels are changed to disambiguate functions of the same name (but whose parent module are different). this was done by computing a unique function_id for every function and using that function_id when constructing its IR identifier. - add compile_from_file_input which accepts a FileInput instead of a string. this is now the new preferred entry point into the compiler. its usage simplifies several internal APIs which expect to have `source_id` and `path` in addition to the raw source code. - change compile_code api from contract_name= to contract_path= additional changes to internal APIs and passes include: - remove `remove_unused_statements()` the "unused statements" are now important to keep around for imports! in general, it is planned to remove both the AST expansion and constant folding passes as copying around the AST results in both performance and correctness problems - abstract out a common exception rewriting pattern. instead of raising `exception.with_annotation(node)` -- just catch-all in the parent implementation and then don't have to worry about it at the exception site. - rename "type" metadata key on most top-level declarators to more specific names (e.g. "func_type", "getter_type", etc). - remove dead package pytest-rerunfailures use of `--reruns` was removed in c913b2db0881a6 - refactor: move `parse_*` functions, remove vyper.ast.annotation move `parse_*` functions into new module vyper.ast.parse and merge in vyper.ast.annotation - rename the old `GlobalContext` class to `ModuleT` - refactor: move InterfaceT into `vyper/semantics/types/module.py` it makes more sense here since it is closely coupled with `ModuleT`. --- setup.py | 1 - tests/conftest.py | 10 +- .../codegen/test_call_graph_stability.py | 2 +- .../{builtins => }/codegen/test_interfaces.py | 129 ++++-- .../codegen/test_selector_table_stability.py | 2 +- .../codegen/test_stateless_modules.py | 335 ++++++++++++++ tests/functional/grammar/test_grammar.py | 2 +- tests/functional/syntax/test_interfaces.py | 9 +- tests/unit/ast/nodes/test_hex.py | 4 +- .../ast/test_annotate_and_optimize_ast.py | 3 +- tests/unit/ast/test_ast_dict.py | 7 +- tests/unit/ast/test_parser.py | 2 +- .../test_storage_layout.py | 0 .../test_storage_layout_overrides.py | 2 +- .../cli/vyper_compile/test_compile_files.py | 182 ++++---- .../unit/cli/vyper_json/test_compile_json.py | 111 +++-- tests/unit/cli/vyper_json/test_get_inputs.py | 5 +- .../cli/vyper_json/test_output_selection.py | 52 ++- .../vyper_json/test_parse_args_vyperjson.py | 4 +- tests/unit/compiler/asm/test_asm_optimizer.py | 61 ++- tests/unit/compiler/test_input_bundle.py | 141 ++++-- .../semantics/analysis/test_array_index.py | 20 +- .../analysis/test_cyclic_function_calls.py | 28 +- .../unit/semantics/analysis/test_for_loop.py | 32 +- tests/unit/semantics/test_storage_slots.py | 2 +- tox.ini | 2 +- vyper/__init__.py | 2 +- vyper/ast/__init__.py | 3 +- vyper/ast/__init__.pyi | 2 +- vyper/ast/expansion.py | 51 +- vyper/ast/grammar.lark | 9 +- vyper/ast/natspec.py | 2 +- vyper/ast/nodes.py | 21 +- vyper/ast/nodes.pyi | 13 +- vyper/ast/{annotation.py => parse.py} | 128 +++++- vyper/ast/utils.py | 61 +-- vyper/builtins/_utils.py | 9 +- vyper/builtins/functions.py | 4 +- .../interfaces/{ERC165.vy => ERC165.vyi} | 2 +- .../interfaces/{ERC20.vy => ERC20.vyi} | 24 +- .../{ERC20Detailed.vy => ERC20Detailed.vyi} | 6 +- .../interfaces/{ERC4626.vy => ERC4626.vyi} | 32 +- .../interfaces/{ERC721.vy => ERC721.vyi} | 43 +- vyper/cli/vyper_compile.py | 6 +- vyper/cli/vyper_json.py | 31 +- vyper/codegen/context.py | 8 +- vyper/codegen/expr.py | 49 +- vyper/codegen/function_definitions/common.py | 19 +- vyper/codegen/global_context.py | 32 -- vyper/codegen/module.py | 116 +++-- vyper/codegen/self_call.py | 11 +- vyper/codegen/stmt.py | 57 +-- vyper/compiler/__init__.py | 45 +- vyper/compiler/input_bundle.py | 111 +++-- vyper/compiler/output.py | 22 +- vyper/compiler/phases.py | 102 ++-- vyper/exceptions.py | 20 +- vyper/semantics/analysis/__init__.py | 17 +- vyper/semantics/analysis/base.py | 39 +- vyper/semantics/analysis/common.py | 21 +- vyper/semantics/analysis/data_positions.py | 4 +- vyper/semantics/analysis/import_graph.py | 37 ++ vyper/semantics/analysis/local.py | 45 +- vyper/semantics/analysis/module.py | 411 +++++++++++------ vyper/semantics/analysis/utils.py | 21 +- vyper/semantics/namespace.py | 2 +- vyper/semantics/types/__init__.py | 3 +- vyper/semantics/types/base.py | 12 +- vyper/semantics/types/bytestrings.py | 10 +- vyper/semantics/types/function.py | 435 +++++++++++------- vyper/semantics/types/module.py | 332 +++++++++++++ vyper/semantics/types/subscriptable.py | 36 +- vyper/semantics/types/user.py | 268 ++--------- vyper/semantics/types/utils.py | 52 ++- vyper/utils.py | 9 +- 75 files changed, 2546 insertions(+), 1397 deletions(-) rename tests/functional/{builtins => }/codegen/test_interfaces.py (84%) create mode 100644 tests/functional/codegen/test_stateless_modules.py rename tests/unit/cli/{outputs => storage_layout}/test_storage_layout.py (100%) rename tests/unit/cli/{outputs => storage_layout}/test_storage_layout_overrides.py (98%) rename vyper/ast/{annotation.py => parse.py} (68%) rename vyper/builtins/interfaces/{ERC165.vy => ERC165.vyi} (88%) rename vyper/builtins/interfaces/{ERC20.vy => ERC20.vyi} (68%) rename vyper/builtins/interfaces/{ERC20Detailed.vy => ERC20Detailed.vyi} (93%) rename vyper/builtins/interfaces/{ERC4626.vy => ERC4626.vyi} (90%) rename vyper/builtins/interfaces/{ERC721.vy => ERC721.vyi} (61%) delete mode 100644 vyper/codegen/global_context.py create mode 100644 vyper/semantics/analysis/import_graph.py create mode 100644 vyper/semantics/types/module.py diff --git a/setup.py b/setup.py index 40efb436c5..431c50b74b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ "pytest-instafail>=0.4,<1.0", "pytest-xdist>=2.5,<3.0", "pytest-split>=0.7.0,<1.0", - "pytest-rerunfailures>=10.2,<11", "eth-tester[py-evm]>=0.9.0b1,<0.10", "py-evm>=0.7.0a1,<0.8", "web3==6.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index 22f8544beb..925a025a4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, _set_debug_mode from vyper.ir import compile_ir, optimizer @@ -103,6 +103,12 @@ def fn(sources_dict): return fn +# for tests which just need an input bundle, doesn't matter what it is +@pytest.fixture +def dummy_input_bundle(): + return InputBundle([]) + + # TODO: remove me, this is just string.encode("utf-8").ljust() # only used in test_logging.py. @pytest.fixture @@ -255,9 +261,11 @@ def ir_compiler(ir, *args, **kwargs): ir = IRnode.from_list(ir) if optimize != OptimizationLevel.NONE: ir = optimizer.optimize(ir) + bytecode, _ = compile_ir.assembly_to_evm( compile_ir.compile_to_assembly(ir, optimize=optimize) ) + abi = kwargs.get("abi") or [] c = w3.eth.contract(abi=abi, bytecode=bytecode) deploy_transaction = c.constructor() diff --git a/tests/functional/codegen/test_call_graph_stability.py b/tests/functional/codegen/test_call_graph_stability.py index 4c85c330f3..2d8ad59791 100644 --- a/tests/functional/codegen/test_call_graph_stability.py +++ b/tests/functional/codegen/test_call_graph_stability.py @@ -55,7 +55,7 @@ def foo(): # check the .called_functions data structure on foo() directly foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0] - foo_t = foo._metadata["type"] + foo_t = foo._metadata["func_type"] assert [f.name for f in foo_t.called_functions] == func_names # now for sanity, ensure the order that the function definitions appear diff --git a/tests/functional/builtins/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py similarity index 84% rename from tests/functional/builtins/codegen/test_interfaces.py rename to tests/functional/codegen/test_interfaces.py index 8cb0124f29..3544f4a965 100644 --- a/tests/functional/builtins/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -6,9 +6,9 @@ from vyper.compiler import compile_code from vyper.exceptions import ( ArgumentException, + DuplicateImport, InterfaceViolation, NamespaceCollision, - StructureException, ) @@ -31,7 +31,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): out = compile_code(code, output_formats=["interface"]) out = out["interface"] - code_pass = "\n".join(code.split("\n")[:-2] + [" pass"]) # replace with a pass statement. + code_pass = "\n".join(code.split("\n")[:-2] + [" ..."]) # replace with a pass statement. assert code_pass.strip() == out.strip() @@ -60,7 +60,7 @@ def allowance(_owner: address, _spender: address) -> (uint256, uint256): view def test(_owner: address): nonpayable """ - out = compile_code(code, contract_name="One.vy", output_formats=["external_interface"])[ + out = compile_code(code, contract_path="One.vy", output_formats=["external_interface"])[ "external_interface" ] @@ -85,14 +85,14 @@ def test_external_interface_parsing(make_input_bundle, assert_compile_failed): interface_code = """ @external def foo() -> uint256: - pass + ... @external def bar() -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) code = """ import a as FooBarInterface @@ -121,9 +121,8 @@ def foo() -> uint256: """ - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) + with pytest.raises(InterfaceViolation): + compile_code(not_implemented_code, input_bundle=input_bundle) def test_missing_event(make_input_bundle, assert_compile_failed): @@ -132,7 +131,7 @@ def test_missing_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -156,7 +155,7 @@ def test_malformed_event(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -183,7 +182,7 @@ def test_malformed_events_indexed(make_input_bundle, assert_compile_failed): a: uint256 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -211,7 +210,7 @@ def test_malformed_events_indexed2(make_input_bundle, assert_compile_failed): a: indexed(uint256) """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) not_implemented_code = """ import a as FooBarInterface @@ -234,13 +233,13 @@ def bar() -> uint256: VALID_IMPORT_CODE = [ # import statement, import path without suffix - ("import a as Foo", "a.vy"), - ("import b.a as Foo", "b/a.vy"), - ("import Foo as Foo", "Foo.vy"), - ("from a import Foo", "a/Foo.vy"), - ("from b.a import Foo", "b/a/Foo.vy"), - ("from .a import Foo", "./a/Foo.vy"), - ("from ..a import Foo", "../a/Foo.vy"), + ("import a as Foo", "a.vyi"), + ("import b.a as Foo", "b/a.vyi"), + ("import Foo as Foo", "Foo.vyi"), + ("from a import Foo", "a/Foo.vyi"), + ("from b.a import Foo", "b/a/Foo.vyi"), + ("from .a import Foo", "./a/Foo.vyi"), + ("from ..a import Foo", "../a/Foo.vyi"), ] @@ -252,11 +251,12 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): BAD_IMPORT_CODE = [ - ("import a", StructureException), # must alias absolute imports - ("import a as A\nimport a as A", NamespaceCollision), + ("import a as A\nimport a as A", DuplicateImport), + ("import a as A\nimport a as a", DuplicateImport), + ("from . import a\nimport a as a", DuplicateImport), + ("import a as a\nfrom . import a", DuplicateImport), ("from b import a\nfrom . import a", NamespaceCollision), - ("from . import a\nimport a as a", NamespaceCollision), - ("import a as a\nfrom . import a", NamespaceCollision), + ("import a\nimport c as a", NamespaceCollision), ] @@ -264,34 +264,50 @@ def test_extract_file_interface_imports(code, filename, make_input_bundle): def test_extract_file_interface_imports_raises( code, exception_type, assert_compile_failed, make_input_bundle ): - input_bundle = make_input_bundle({"a.vy": "", "b/a.vy": ""}) # dummy - assert_compile_failed(lambda: compile_code(code, input_bundle=input_bundle), exception_type) + input_bundle = make_input_bundle({"a.vyi": "", "b/a.vyi": "", "c.vyi": ""}) + with pytest.raises(exception_type): + compile_code(code, input_bundle=input_bundle) def test_external_call_to_interface(w3, get_contract, make_input_bundle): + token_interface = """ +@view +@external +def balanceOf(addr: address) -> uint256: + ... + +@external +def transfer(to: address, amount: uint256): + ... + """ + token_code = """ +import itoken as IToken + +implements: IToken + balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256): - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256): + self.balanceOf[to] += amount """ - input_bundle = make_input_bundle({"one.vy": token_code}) + input_bundle = make_input_bundle({"token.vy": token_code, "itoken.vyi": token_interface}) code = """ -import one as TokenCode +import itoken as IToken interface EPI: def test() -> uint256: view -token_address: TokenCode +token_address: IToken @external def __init__(_token_address: address): - self.token_address = TokenCode(_token_address) + self.token_address = IToken(_token_address) @external @@ -299,14 +315,15 @@ def test(): self.token_address.transfer(msg.sender, 1000) """ - erc20 = get_contract(token_code) - test_c = get_contract(code, *[erc20.address], input_bundle=input_bundle) + token = get_contract(token_code, input_bundle=input_bundle) + + test_c = get_contract(code, *[token.address], input_bundle=input_bundle) sender = w3.eth.accounts[0] - assert erc20.balanceOf(sender) == 0 + assert token.balanceOf(sender) == 0 test_c.test(transact={}) - assert erc20.balanceOf(sender) == 1000 + assert token.balanceOf(sender) == 1000 @pytest.mark.parametrize( @@ -320,26 +337,36 @@ def test(): ], ) def test_external_call_to_interface_kwarg(get_contract, kwarg, typ, expected, make_input_bundle): - code_a = f""" + interface_code = f""" +@external +@view +def foo(_max: {typ} = {kwarg}) -> {typ}: + ... + """ + code1 = f""" +import one as IContract + +implements: IContract + @external @view def foo(_max: {typ} = {kwarg}) -> {typ}: return _max """ - input_bundle = make_input_bundle({"one.vy": code_a}) + input_bundle = make_input_bundle({"one.vyi": interface_code}) - code_b = f""" -import one as ContractA + code2 = f""" +import one as IContract @external @view def bar(a_address: address) -> {typ}: - return ContractA(a_address).foo() + return IContract(a_address).foo() """ - contract_a = get_contract(code_a) - contract_b = get_contract(code_b, *[contract_a.address], input_bundle=input_bundle) + contract_a = get_contract(code1, input_bundle=input_bundle) + contract_b = get_contract(code2, *[contract_a.address], input_bundle=input_bundle) assert contract_b.bar(contract_a.address) == expected @@ -349,8 +376,8 @@ def test_external_call_to_builtin_interface(w3, get_contract): balanceOf: public(HashMap[address, uint256]) @external -def transfer(to: address, _value: uint256) -> bool: - self.balanceOf[to] += _value +def transfer(to: address, amount: uint256) -> bool: + self.balanceOf[to] += amount return True """ @@ -510,14 +537,14 @@ def returns_Bytes3() -> Bytes[3]: """ should_not_compile = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface @external def foo(x: BadJSONInterface) -> Bytes[2]: return slice(x.returns_Bytes3(), 0, 2) """ code = """ -import BadJSONInterface as BadJSONInterface +import BadJSONInterface foo: BadJSONInterface @@ -578,10 +605,10 @@ def balanceOf(owner: address) -> uint256: @external @view def balanceOf(owner: address) -> uint256: - pass + ... """ - input_bundle = make_input_bundle({"balanceof.vy": interface_code}) + input_bundle = make_input_bundle({"balanceof.vyi": interface_code}) c = get_contract(code, input_bundle=input_bundle) @@ -592,7 +619,7 @@ def test_simple_implements(make_input_bundle): interface_code = """ @external def foo() -> uint256: - pass + ... """ code = """ @@ -605,7 +632,7 @@ def foo() -> uint256: return 1 """ - input_bundle = make_input_bundle({"a.vy": interface_code}) + input_bundle = make_input_bundle({"a.vyi": interface_code}) assert compile_code(code, input_bundle=input_bundle) is not None diff --git a/tests/functional/codegen/test_selector_table_stability.py b/tests/functional/codegen/test_selector_table_stability.py index 3302ff5009..27f82416d6 100644 --- a/tests/functional/codegen/test_selector_table_stability.py +++ b/tests/functional/codegen/test_selector_table_stability.py @@ -14,7 +14,7 @@ def test_dense_jumptable_stability(): # test that the selector table data is stable across different runs # (tox should provide different PYTHONHASHSEEDs). - expected_asm = """{ DATA _sym_BUCKET_HEADERS b'\\x0bB' _sym_bucket_0 b'\\n' b'+\\x8d' _sym_bucket_1 b'\\x0c' b'\\x00\\x85' _sym_bucket_2 b'\\x08' } { DATA _sym_bucket_1 b'\\xd8\\xee\\xa1\\xe8' _sym_external_foo6___3639517672 b'\\x05' b'\\xd2\\x9e\\xe0\\xf9' _sym_external_foo0___3533627641 b'\\x05' b'\\x05\\xf1\\xe0_' _sym_external_foo2___99737695 b'\\x05' b'\\x91\\t\\xb4{' _sym_external_foo23___2433332347 b'\\x05' b'np3\\x7f' _sym_external_foo11___1852846975 b'\\x05' b'&\\xf5\\x96\\xf9' _sym_external_foo13___653629177 b'\\x05' b'\\x04ga\\xeb' _sym_external_foo14___73884139 b'\\x05' b'\\x89\\x06\\xad\\xc6' _sym_external_foo17___2298916294 b'\\x05' b'\\xe4%\\xac\\xd1' _sym_external_foo4___3827674321 b'\\x05' b'yj\\x01\\xac' _sym_external_foo7___2036990380 b'\\x05' b'\\xf1\\xe6K\\xe5' _sym_external_foo29___4058401765 b'\\x05' b'\\xd2\\x89X\\xb8' _sym_external_foo3___3532216504 b'\\x05' } { DATA _sym_bucket_2 b'\\x06p\\xffj' _sym_external_foo25___108068714 b'\\x05' b'\\x964\\x99I' _sym_external_foo24___2520029513 b'\\x05' b's\\x81\\xe7\\xc1' _sym_external_foo10___1937893313 b'\\x05' b'\\x85\\xad\\xc11' _sym_external_foo28___2242756913 b'\\x05' b'\\xfa"\\xb1\\xed' _sym_external_foo5___4196577773 b'\\x05' b'A\\xe7[\\x05' _sym_external_foo22___1105681157 b'\\x05' b'\\xd3\\x89U\\xe8' _sym_external_foo1___3548993000 b'\\x05' b'hL\\xf8\\xf3' _sym_external_foo20___1749874931 b'\\x05' } { DATA _sym_bucket_0 b'\\xee\\xd9\\x1d\\xe3' _sym_external_foo9___4007206371 b'\\x05' b'a\\xbc\\x1ch' _sym_external_foo16___1639717992 b'\\x05' b'\\xd3*\\xa7\\x0c' _sym_external_foo21___3542787852 b'\\x05' b'\\x18iG\\xd9' _sym_external_foo19___409552857 b'\\x05' b'\\n\\xf1\\xf9\\x7f' _sym_external_foo18___183630207 b'\\x05' b')\\xda\\xd7`' _sym_external_foo27___702207840 b'\\x05' b'2\\xf6\\xaa\\xda' _sym_external_foo12___855026394 b'\\x05' b'\\xbe\\xb5\\x05\\xf5' _sym_external_foo15___3199534581 b'\\x05' b'\\xfc\\xa7_\\xe6' _sym_external_foo8___4238827494 b'\\x05' b'\\x1b\\x12C8' _sym_external_foo26___454181688 b'\\x05' } }""" # noqa: E501 + expected_asm = """{ DATA _sym_BUCKET_HEADERS b\'\\x0bB\' _sym_bucket_0 b\'\\n\' b\'+\\x8d\' _sym_bucket_1 b\'\\x0c\' b\'\\x00\\x85\' _sym_bucket_2 b\'\\x08\' } { DATA _sym_bucket_1 b\'\\xd8\\xee\\xa1\\xe8\' _sym_external 6 foo6()3639517672 b\'\\x05\' b\'\\xd2\\x9e\\xe0\\xf9\' _sym_external 0 foo0()3533627641 b\'\\x05\' b\'\\x05\\xf1\\xe0_\' _sym_external 2 foo2()99737695 b\'\\x05\' b\'\\x91\\t\\xb4{\' _sym_external 23 foo23()2433332347 b\'\\x05\' b\'np3\\x7f\' _sym_external 11 foo11()1852846975 b\'\\x05\' b\'&\\xf5\\x96\\xf9\' _sym_external 13 foo13()653629177 b\'\\x05\' b\'\\x04ga\\xeb\' _sym_external 14 foo14()73884139 b\'\\x05\' b\'\\x89\\x06\\xad\\xc6\' _sym_external 17 foo17()2298916294 b\'\\x05\' b\'\\xe4%\\xac\\xd1\' _sym_external 4 foo4()3827674321 b\'\\x05\' b\'yj\\x01\\xac\' _sym_external 7 foo7()2036990380 b\'\\x05\' b\'\\xf1\\xe6K\\xe5\' _sym_external 29 foo29()4058401765 b\'\\x05\' b\'\\xd2\\x89X\\xb8\' _sym_external 3 foo3()3532216504 b\'\\x05\' } { DATA _sym_bucket_2 b\'\\x06p\\xffj\' _sym_external 25 foo25()108068714 b\'\\x05\' b\'\\x964\\x99I\' _sym_external 24 foo24()2520029513 b\'\\x05\' b\'s\\x81\\xe7\\xc1\' _sym_external 10 foo10()1937893313 b\'\\x05\' b\'\\x85\\xad\\xc11\' _sym_external 28 foo28()2242756913 b\'\\x05\' b\'\\xfa"\\xb1\\xed\' _sym_external 5 foo5()4196577773 b\'\\x05\' b\'A\\xe7[\\x05\' _sym_external 22 foo22()1105681157 b\'\\x05\' b\'\\xd3\\x89U\\xe8\' _sym_external 1 foo1()3548993000 b\'\\x05\' b\'hL\\xf8\\xf3\' _sym_external 20 foo20()1749874931 b\'\\x05\' } { DATA _sym_bucket_0 b\'\\xee\\xd9\\x1d\\xe3\' _sym_external 9 foo9()4007206371 b\'\\x05\' b\'a\\xbc\\x1ch\' _sym_external 16 foo16()1639717992 b\'\\x05\' b\'\\xd3*\\xa7\\x0c\' _sym_external 21 foo21()3542787852 b\'\\x05\' b\'\\x18iG\\xd9\' _sym_external 19 foo19()409552857 b\'\\x05\' b\'\\n\\xf1\\xf9\\x7f\' _sym_external 18 foo18()183630207 b\'\\x05\' b\')\\xda\\xd7`\' _sym_external 27 foo27()702207840 b\'\\x05\' b\'2\\xf6\\xaa\\xda\' _sym_external 12 foo12()855026394 b\'\\x05\' b\'\\xbe\\xb5\\x05\\xf5\' _sym_external 15 foo15()3199534581 b\'\\x05\' b\'\\xfc\\xa7_\\xe6\' _sym_external 8 foo8()4238827494 b\'\\x05\' b\'\\x1b\\x12C8\' _sym_external 26 foo26()454181688 b\'\\x05\' } }""" # noqa: E501 assert expected_asm in output["asm"] diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/test_stateless_modules.py new file mode 100644 index 0000000000..8e634e5868 --- /dev/null +++ b/tests/functional/codegen/test_stateless_modules.py @@ -0,0 +1,335 @@ +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +from vyper import compiler +from vyper.exceptions import ( + CallViolation, + DuplicateImport, + ImportCycle, + StructureException, + TypeMismatch, +) + +# test modules which have no variables - "libraries" + + +def test_simple_library(get_contract, make_input_bundle, w3): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + main = """ +import library + +@external +def bar() -> uint256: + return library.foo() - 1 + """ + input_bundle = make_input_bundle({"library.vy": library_source}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == w3.eth.block_number + + +# is this the best place for this? +def test_import_cycle(make_input_bundle): + code_a = "import b\n" + code_b = "import a\n" + + input_bundle = make_input_bundle({"a.vy": code_a, "b.vy": code_b}) + + with pytest.raises(ImportCycle): + compiler.compile_code(code_a, input_bundle=input_bundle) + + +# test we can have a function in the library with the same name as +# in the main contract +def test_library_function_same_name(get_contract, make_input_bundle): + library = """ +@internal +def foo() -> uint256: + return 10 + """ + + main = """ +import library + +@internal +def foo() -> uint256: + return 100 + +@external +def self_foo() -> uint256: + return self.foo() + +@external +def library_foo() -> uint256: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library}) + + c = get_contract(main, input_bundle=input_bundle) + + assert c.self_foo() == 100 + assert c.library_foo() == 10 + + +def test_transitive_import(get_contract, make_input_bundle): + a = """ +@internal +def foo() -> uint256: + return 1 + """ + b = """ +import a + +@internal +def bar() -> uint256: + return a.foo() + 1 + """ + c = """ +import b + +@external +def baz() -> uint256: + return b.bar() + 1 + """ + # more complicated call graph, with `a` imported twice. + d = """ +import b +import a + +@external +def qux() -> uint256: + s: uint256 = a.foo() + return s + b.bar() + 1 + """ + input_bundle = make_input_bundle({"a.vy": a, "b.vy": b, "c.vy": c, "d.vy": d}) + + contract = get_contract(c, input_bundle=input_bundle) + assert contract.baz() == 3 + contract = get_contract(d, input_bundle=input_bundle) + assert contract.qux() == 4 + + +def test_cannot_call_library_external_functions(make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(CallViolation): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_external_functions_not_in_abi(get_contract, make_input_bundle): + library_source = """ +@external +def foo(): + pass + """ + contract_source = """ +import library + +@external +def bar(): + pass + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + assert not hasattr(c, "foo") + + +def test_library_structs(get_contract, make_input_bundle): + library_source = """ +struct SomeStruct: + x: uint256 + +@internal +def foo() -> SomeStruct: + return SomeStruct({x: 1}) + """ + contract_source = """ +import library + +@external +def bar(s: library.SomeStruct): + pass + +@external +def baz() -> library.SomeStruct: + return library.SomeStruct({x: 2}) + +@external +def qux() -> library.SomeStruct: + return library.foo() + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + c = get_contract(contract_source, input_bundle=input_bundle) + + assert c.bar((1,)) == [] + + assert c.baz() == (2,) + assert c.qux() == (1,) + + +# test calls to library functions in statement position +def test_library_statement_calls(get_contract, make_input_bundle, assert_tx_failed): + library_source = """ +from vyper.interfaces import ERC20 +@internal +def check_adds_to_ten(x: uint256, y: uint256): + assert x + y == 10 + """ + contract_source = """ +import library + +counter: public(uint256) + +@external +def foo(x: uint256): + library.check_adds_to_ten(3, x) + self.counter = x + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + + c = get_contract(contract_source, input_bundle=input_bundle) + + c.foo(7, transact={}) + + assert c.counter() == 7 + + assert_tx_failed(lambda: c.foo(8)) + + +def test_library_is_typechecked(make_input_bundle): + library_source = """ +@internal +def foo(): + asdlkfjasdflkajsdf + """ + contract_source = """ +import library + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(StructureException): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_library_is_typechecked2(make_input_bundle): + # check that we typecheck against imported function signatures + library_source = """ +@internal +def foo() -> uint256: + return 1 + """ + contract_source = """ +import library + +@external +def foo() -> bytes32: + return library.foo() + """ + + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(TypeMismatch): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_reject_duplicate_imports(make_input_bundle): + library_source = """ + """ + + contract_source = """ +import library +import library as library2 + """ + input_bundle = make_input_bundle({"library.vy": library_source, "contract.vy": contract_source}) + with pytest.raises(DuplicateImport): + compiler.compile_code(contract_source, input_bundle=input_bundle) + + +def test_nested_module_access(get_contract, make_input_bundle): + lib1 = """ +import lib2 + +@internal +def lib2_foo() -> uint256: + return lib2.foo() + """ + lib2 = """ +@internal +def foo() -> uint256: + return 1337 + """ + + main = """ +import lib1 +import lib2 + +@external +def lib1_foo() -> uint256: + return lib1.lib2_foo() + +@external +def lib2_foo() -> uint256: + return lib1.lib2.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.lib1_foo() == c.lib2_foo() == 1337 + + +_int_127 = st.integers(min_value=0, max_value=127) +_bytes_128 = st.binary(min_size=0, max_size=128) + + +def test_slice_builtin(get_contract, make_input_bundle): + lib = """ +@internal +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + + main = """ +import lib +@external +def lib_slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return lib.slice_input(x, start, length) + +@external +def slice_input(x: Bytes[128], start: uint256, length: uint256) -> Bytes[128]: + return slice(x, start, length) + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + # use an inner test so that we can cache the result of get_contract() + @given(start=_int_127, length=_int_127, bytesdata=_bytes_128) + @settings(max_examples=100) + def _test(bytesdata, start, length): + # surjectively map start into allowable range + if start > len(bytesdata): + start = start % (len(bytesdata) or 1) + # surjectively map length into allowable range + if length > (len(bytesdata) - start): + length = length % ((len(bytesdata) - start) or 1) + main_result = c.slice_input(bytesdata, start, length) + library_result = c.lib_slice_input(bytesdata, start, length) + assert main_result == library_result == bytesdata[start : start + length] + + _test() diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index aa0286cfa5..7dd8c35929 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -92,7 +92,7 @@ def from_grammar() -> st.SearchStrategy[str]: # Avoid examples with *only* single or double quote docstrings -# because they trigger a trivial compiler bug +# because they trigger a trivial parser bug SINGLE_QUOTE_DOCSTRING = re.compile(r"^'''.*'''$") DOUBLE_QUOTE_DOCSTRING = re.compile(r'^""".*"""$') diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 9100389dbd..a672ed7b88 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -376,17 +376,12 @@ def test_interfaces_success(good_code): def test_imports_and_implements_within_interface(make_input_bundle): interface_code = """ -from vyper.interfaces import ERC20 -import foo.bar as Baz - -implements: Baz - @external def foobar(): - pass + ... """ - input_bundle = make_input_bundle({"foo.vy": interface_code}) + input_bundle = make_input_bundle({"foo.vyi": interface_code}) code = """ import foo as Foo diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 47483c493c..d413340083 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -37,9 +37,9 @@ def foo(): @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code): +def test_invalid_checksum(code, dummy_input_bundle): vyper_module = vy_ast.parse_to_ast(code) with pytest.raises(InvalidLiteral): vy_ast.validation.validate_literal_nodes(vyper_module) - semantics.validate_semantics(vyper_module, {}) + semantics.validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 68a07178bb..16ce6fe631 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -1,7 +1,6 @@ import ast as python_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse +from vyper.ast.parse import annotate_python_ast, pre_parse class AssertionVisitor(python_ast.NodeVisitor): diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 1f60c9ac8b..dc49f72561 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,7 +1,8 @@ import json from vyper import compiler -from vyper.ast.utils import ast_to_dict, dict_to_ast, parse_to_ast +from vyper.ast.parse import parse_to_ast +from vyper.ast.utils import ast_to_dict, dict_to_ast def get_node_ids(ast_struct, ids=None): @@ -40,7 +41,7 @@ def test_basic_ast(): code = """ a: int128 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][0] == { "annotation": { "ast_type": "Name", @@ -89,7 +90,7 @@ def foo() -> uint256: view def foo() -> uint256: return 1 """ - dict_out = compiler.compile_code(code, output_formats=["ast_dict"]) + dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0) assert dict_out["ast_dict"]["ast"]["body"][1] == { "col_offset": 0, "annotation": { diff --git a/tests/unit/ast/test_parser.py b/tests/unit/ast/test_parser.py index c47bf40bfa..e0bfcbc2ef 100644 --- a/tests/unit/ast/test_parser.py +++ b/tests/unit/ast/test_parser.py @@ -1,4 +1,4 @@ -from vyper.ast.utils import parse_to_ast +from vyper.ast.parse import parse_to_ast def test_ast_equal(): diff --git a/tests/unit/cli/outputs/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py similarity index 100% rename from tests/unit/cli/outputs/test_storage_layout.py rename to tests/unit/cli/storage_layout/test_storage_layout.py diff --git a/tests/unit/cli/outputs/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py similarity index 98% rename from tests/unit/cli/outputs/test_storage_layout_overrides.py rename to tests/unit/cli/storage_layout/test_storage_layout_overrides.py index 94e0faeb37..f4c11b7ae6 100644 --- a/tests/unit/cli/outputs/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -103,7 +103,7 @@ def test_overflow(): storage_layout_override = {"x": {"slot": 2**256 - 1, "type": "uint256[2]"}} with pytest.raises( - StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}\n" + StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}" ): compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a16efa777..f6e3a51a4b 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -30,93 +30,100 @@ def test_invalid_root_path(): compile_files([], [], root_folder="path/that/does/not/exist") -FOO_CODE = """ -{} - -struct FooStruct: - foo_: uint256 +CONTRACT_CODE = """ +{import_stmt} @external -def foo() -> FooStruct: - return FooStruct({{foo_: 13}}) +def foo() -> {alias}.FooStruct: + return {alias}.FooStruct({{foo_: 13}}) @external -def bar(a: address) -> FooStruct: - return {}(a).bar() +def bar(a: address) -> {alias}.FooStruct: + return {alias}(a).bar() """ -BAR_CODE = """ +INTERFACE_CODE = """ struct FooStruct: foo_: uint256 + +@external +def foo() -> FooStruct: + ... + @external def bar() -> FooStruct: - return FooStruct({foo_: 13}) + ... """ SAME_FOLDER_IMPORT_STMT = [ - ("import Bar as Bar", "Bar"), - ("import contracts.Bar as Bar", "Bar"), - ("from . import Bar", "Bar"), - ("from contracts import Bar", "Bar"), - ("from ..contracts import Bar", "Bar"), - ("from . import Bar as FooBar", "FooBar"), - ("from contracts import Bar as FooBar", "FooBar"), - ("from ..contracts import Bar as FooBar", "FooBar"), + ("import IFoo as IFoo", "IFoo"), + ("import contracts.IFoo as IFoo", "IFoo"), + ("from . import IFoo", "IFoo"), + ("from contracts import IFoo", "IFoo"), + ("from ..contracts import IFoo", "IFoo"), + ("from . import IFoo as FooBar", "FooBar"), + ("from contracts import IFoo as FooBar", "FooBar"), + ("from ..contracts import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) def test_import_same_folder(import_stmt, alias, tmp_path, make_file): foo = "contracts/foo.vy" - make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("contracts/Bar.vy", BAR_CODE) + make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) SUBFOLDER_IMPORT_STMT = [ - ("import other.Bar as Bar", "Bar"), - ("import contracts.other.Bar as Bar", "Bar"), - ("from other import Bar", "Bar"), - ("from contracts.other import Bar", "Bar"), - ("from .other import Bar", "Bar"), - ("from ..contracts.other import Bar", "Bar"), - ("from other import Bar as FooBar", "FooBar"), - ("from contracts.other import Bar as FooBar", "FooBar"), - ("from .other import Bar as FooBar", "FooBar"), - ("from ..contracts.other import Bar as FooBar", "FooBar"), + ("import other.IFoo as IFoo", "IFoo"), + ("import contracts.other.IFoo as IFoo", "IFoo"), + ("from other import IFoo", "IFoo"), + ("from contracts.other import IFoo", "IFoo"), + ("from .other import IFoo", "IFoo"), + ("from ..contracts.other import IFoo", "IFoo"), + ("from other import IFoo as FooBar", "FooBar"), + ("from contracts.other import IFoo as FooBar", "FooBar"), + ("from .other import IFoo as FooBar", "FooBar"), + ("from ..contracts.other import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) def test_import_subfolder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", (FOO_CODE.format(import_stmt, alias))) - make_file("contracts/other/Bar.vy", BAR_CODE) + foo = make_file( + "contracts/foo.vy", (CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + ) + make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) OTHER_FOLDER_IMPORT_STMT = [ - ("import interfaces.Bar as Bar", "Bar"), - ("from interfaces import Bar", "Bar"), - ("from ..interfaces import Bar", "Bar"), - ("from interfaces import Bar as FooBar", "FooBar"), - ("from ..interfaces import Bar as FooBar", "FooBar"), + ("import interfaces.IFoo as IFoo", "IFoo"), + ("from interfaces import IFoo", "IFoo"), + ("from ..interfaces import IFoo", "IFoo"), + ("from interfaces import IFoo as FooBar", "FooBar"), + ("from ..interfaces import IFoo as FooBar", "FooBar"), ] @pytest.mark.parametrize("import_stmt, alias", OTHER_FOLDER_IMPORT_STMT) def test_import_other_folder(import_stmt, alias, tmp_path, make_file): - foo = make_file("contracts/foo.vy", FOO_CODE.format(import_stmt, alias)) - make_file("interfaces/Bar.vy", BAR_CODE) + foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) + make_file("interfaces/IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) def test_import_parent_folder(tmp_path, make_file): - foo = make_file("contracts/baz/foo.vy", FOO_CODE.format("from ... import Bar", "Bar")) - make_file("Bar.vy", BAR_CODE) + foo = make_file( + "contracts/baz/foo.vy", + CONTRACT_CODE.format(import_stmt="from ... import IFoo", alias="IFoo"), + ) + make_file("IFoo.vyi", INTERFACE_CODE) assert compile_files([foo], ["combined_json"], root_folder=tmp_path) @@ -125,62 +132,60 @@ def test_import_parent_folder(tmp_path, make_file): META_IMPORT_STMT = [ - "import Meta as Meta", - "import contracts.Meta as Meta", - "from . import Meta", - "from contracts import Meta", + "import ISelf as ISelf", + "import contracts.ISelf as ISelf", + "from . import ISelf", + "from contracts import ISelf", ] @pytest.mark.parametrize("import_stmt", META_IMPORT_STMT) def test_import_self_interface(import_stmt, tmp_path, make_file): - # a contract can access its derived interface by importing itself - code = f""" -{import_stmt} - + interface_code = """ struct FooStruct: foo_: uint256 @external def know_thyself(a: address) -> FooStruct: - return Meta(a).be_known() + ... @external def be_known() -> FooStruct: - return FooStruct({{foo_: 42}}) + ... """ - meta = make_file("contracts/Meta.vy", code) - - assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + code = f""" +{import_stmt} +@external +def know_thyself(a: address) -> ISelf.FooStruct: + return ISelf(a).be_known() -DERIVED_IMPORT_STMT_BAZ = ["import Foo as Foo", "from . import Foo"] +@external +def be_known() -> ISelf.FooStruct: + return ISelf.FooStruct({{foo_: 42}}) + """ + make_file("contracts/ISelf.vyi", interface_code) + meta = make_file("contracts/Self.vy", code) -DERIVED_IMPORT_STMT_FOO = ["import Bar as Bar", "from . import Bar"] + assert compile_files([meta], ["combined_json"], root_folder=tmp_path) -@pytest.mark.parametrize("import_stmt_baz", DERIVED_IMPORT_STMT_BAZ) -@pytest.mark.parametrize("import_stmt_foo", DERIVED_IMPORT_STMT_FOO) -def test_derived_interface_imports(import_stmt_baz, import_stmt_foo, tmp_path, make_file): - # contracts-as-interfaces should be able to contain import statements +# implement IFoo in another contract for fun +@pytest.mark.parametrize("import_stmt_foo,alias", SAME_FOLDER_IMPORT_STMT) +def test_another_interface_implementation(import_stmt_foo, alias, tmp_path, make_file): baz_code = f""" -{import_stmt_baz} - -struct FooStruct: - foo_: uint256 +{import_stmt_foo} @external -def foo(a: address) -> FooStruct: - return Foo(a).foo() +def foo(a: address) -> {alias}.FooStruct: + return {alias}(a).foo() @external -def bar(_foo: address, _bar: address) -> FooStruct: - return Foo(_foo).bar(_bar) +def bar(_foo: address) -> {alias}.FooStruct: + return {alias}(_foo).bar() """ - - make_file("Foo.vy", FOO_CODE.format(import_stmt_foo, "Bar")) - make_file("Bar.vy", BAR_CODE) - baz = make_file("Baz.vy", baz_code) + make_file("contracts/IFoo.vyi", INTERFACE_CODE) + baz = make_file("contracts/Baz.vy", baz_code) assert compile_files([baz], ["combined_json"], root_folder=tmp_path) @@ -207,15 +212,36 @@ def test_local_namespace(make_file, tmp_path): make_file(filename, code) paths.append(filename) - for file_name in ("foo.vy", "bar.vy"): - make_file(file_name, BAR_CODE) + for file_name in ("foo.vyi", "bar.vyi"): + make_file(file_name, INTERFACE_CODE) assert compile_files(paths, ["combined_json"], root_folder=tmp_path) def test_compile_outside_root_path(tmp_path, make_file): # absolute paths relative to "." - foo = make_file("foo.vy", FOO_CODE.format("import bar as Bar", "Bar")) - bar = make_file("bar.vy", BAR_CODE) + make_file("ifoo.vyi", INTERFACE_CODE) + foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) + + assert compile_files([foo], ["combined_json"], root_folder=".") + + +def test_import_library(tmp_path, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + + make_file("lib.vy", library_source) + contract_file = make_file("contract.vy", contract_source) - assert compile_files([foo, bar], ["combined_json"], root_folder=".") + assert compile_files([contract_file], ["combined_json"], root_folder=tmp_path) is not None diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 732762d72b..a50946ba21 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -1,30 +1,55 @@ import json +from pathlib import PurePath import pytest import vyper -from vyper.cli.vyper_json import compile_from_input_dict, compile_json, exc_handler_to_dict -from vyper.compiler import OUTPUT_FORMATS, compile_code +from vyper.cli.vyper_json import ( + compile_from_input_dict, + compile_json, + exc_handler_to_dict, + get_inputs, +) +from vyper.compiler import OUTPUT_FORMATS, compile_code, compile_from_file_input +from vyper.compiler.input_bundle import JSONInputBundle from vyper.exceptions import InvalidType, JSONError, SyntaxException FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar + +import contracts.library as library @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) @external def baz() -> uint256: - return self.balance + return self.balance + library.foo() """ BAR_CODE = """ +import contracts.ibar as IBar + +implements: IBar + @external def bar(a: uint256) -> bool: return True """ +BAR_VYI = """ +@external +def bar(a: uint256) -> bool: + ... +""" + +LIBRARY_CODE = """ +@internal +def foo() -> uint256: + return block.number + 1 +""" + BAD_SYNTAX_CODE = """ def bar()>: """ @@ -52,6 +77,7 @@ def input_json(): "language": "Vyper", "sources": { "contracts/foo.vy": {"content": FOO_CODE}, + "contracts/library.vy": {"content": LIBRARY_CODE}, "contracts/bar.vy": {"content": BAR_CODE}, }, "interfaces": {"contracts/ibar.json": {"abi": BAR_ABI}}, @@ -59,6 +85,14 @@ def input_json(): } +@pytest.fixture(scope="function") +def input_bundle(input_json): + # CMC 2023-12-11 maybe input_json -> JSONInputBundle should be a helper + # function in `vyper_json.py`. + sources = get_inputs(input_json) + return JSONInputBundle(sources, search_paths=[PurePath(".")]) + + # test string and dict inputs both work def test_string_input(input_json): assert compile_json(input_json) == compile_json(json.dumps(input_json)) @@ -77,29 +111,39 @@ def test_keyerror_becomes_jsonerror(input_json): compile_json(input_json) -def test_compile_json(input_json, make_input_bundle): - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) +def test_compile_json(input_json, input_bundle): + foo_input = input_bundle.load_file("contracts/foo.vy") + foo = compile_from_file_input( + foo_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle + ) - foo = compile_code( - FOO_CODE, - source_id=0, - contract_name="contracts/foo.vy", - output_formats=OUTPUT_FORMATS, - input_bundle=input_bundle, + library_input = input_bundle.load_file("contracts/library.vy") + library = compile_from_file_input( + library_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - bar = compile_code( - BAR_CODE, source_id=1, contract_name="contracts/bar.vy", output_formats=OUTPUT_FORMATS + + bar_input = input_bundle.load_file("contracts/bar.vy") + bar = compile_from_file_input( + bar_input, output_formats=OUTPUT_FORMATS, input_bundle=input_bundle ) - compile_code_results = {"contracts/bar.vy": bar, "contracts/foo.vy": foo} + compile_code_results = { + "contracts/bar.vy": bar, + "contracts/library.vy": library, + "contracts/foo.vy": foo, + } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" - for source_id, contract_name in enumerate(["foo", "bar"]): + for source_id, contract_name in [(0, "foo"), (2, "library"), (3, "bar")]: path = f"contracts/{contract_name}.vy" data = compile_code_results[path] assert output_json["sources"][path] == {"id": source_id, "ast": data["ast_dict"]["ast"]} @@ -123,13 +167,28 @@ def test_compile_json(input_json, make_input_bundle): } -def test_different_outputs(make_input_bundle, input_json): +def test_compilation_targets(input_json): + output_json = compile_json(input_json) + assert list(output_json["contracts"].keys()) == [ + "contracts/foo.vy", + "contracts/library.vy", + "contracts/bar.vy", + ] + + # omit library.vy + input_json["settings"]["outputSelection"] = {"contracts/foo.vy": "*", "contracts/bar.vy": "*"} + output_json = compile_json(input_json) + + assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + + +def test_different_outputs(input_bundle, input_json): input_json["settings"]["outputSelection"] = { "contracts/bar.vy": "*", "contracts/foo.vy": ["evm.methodIdentifiers"], } output_json = compile_json(input_json) - assert list(output_json["contracts"].keys()) == ["contracts/foo.vy", "contracts/bar.vy"] + assert list(output_json["contracts"].keys()) == ["contracts/bar.vy", "contracts/foo.vy"] assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"] assert output_json["compiler"] == f"vyper-{vyper.__version__}" @@ -143,10 +202,9 @@ def test_different_outputs(make_input_bundle, input_json): assert sorted(foo.keys()) == ["evm"] # check method_identifiers - input_bundle = make_input_bundle({"contracts/bar.vy": BAR_CODE}) method_identifiers = compile_code( FOO_CODE, - contract_name="contracts/foo.vy", + contract_path="contracts/foo.vy", output_formats=["method_identifiers"], input_bundle=input_bundle, )["method_identifiers"] @@ -204,11 +262,12 @@ def get(filename, contractname): return result["contracts"][filename][contractname]["evm"]["deployedBytecode"]["sourceMap"] assert get("contracts/foo.vy", "foo").startswith("-1:-1:0") - assert get("contracts/bar.vy", "bar").startswith("-1:-1:1") + assert get("contracts/library.vy", "library").startswith("-1:-1:2") + assert get("contracts/bar.vy", "bar").startswith("-1:-1:3") def test_relative_import_paths(input_json): - input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": """from ... import foo"""} - input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""} - input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""} + input_json["sources"]["contracts/potato/baz/baz.vy"] = {"content": "from ... import foo"} + input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": "from . import baz"} + input_json["sources"]["contracts/potato/footato.vy"] = {"content": "from baz import baz"} compile_from_input_dict(input_json) diff --git a/tests/unit/cli/vyper_json/test_get_inputs.py b/tests/unit/cli/vyper_json/test_get_inputs.py index 6e323a91bd..c91cc750f2 100644 --- a/tests/unit/cli/vyper_json/test_get_inputs.py +++ b/tests/unit/cli/vyper_json/test_get_inputs.py @@ -2,7 +2,7 @@ import pytest -from vyper.cli.vyper_json import get_compilation_targets, get_inputs +from vyper.cli.vyper_json import get_inputs from vyper.exceptions import JSONError from vyper.utils import keccak256 @@ -122,9 +122,6 @@ def test_interfaces_output(): "interface.folder/bar2.vy": {"content": BAR_CODE}, }, } - targets = get_compilation_targets(input_json) - assert targets == [PurePath("foo.vy")] - result = get_inputs(input_json) assert result == { PurePath("foo.vy"): {"content": FOO_CODE}, diff --git a/tests/unit/cli/vyper_json/test_output_selection.py b/tests/unit/cli/vyper_json/test_output_selection.py index 78ad7404f2..5383190a66 100644 --- a/tests/unit/cli/vyper_json/test_output_selection.py +++ b/tests/unit/cli/vyper_json/test_output_selection.py @@ -8,53 +8,61 @@ def test_no_outputs(): with pytest.raises(KeyError): - get_output_formats({}, {}) + get_output_formats({}) def test_invalid_output(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "foobar"]}}, + } with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) def test_unknown_contract(): - input_json = {"settings": {"outputSelection": {"bar.vy": ["abi"]}}} - targets = [PurePath("foo.vy")] + input_json = {"sources": {}, "settings": {"outputSelection": {"bar.vy": ["abi"]}}} with pytest.raises(JSONError): - get_output_formats(input_json, targets) + get_output_formats(input_json) @pytest.mark.parametrize("output", TRANSLATE_MAP.items()) def test_translate_map(output): - input_json = {"settings": {"outputSelection": {"foo.vy": [output[0]]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): [output[1]]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": [output[0]]}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): [output[1]]} def test_star(): - input_json = {"settings": {"outputSelection": {"*": ["*"]}}} - targets = [PurePath("foo.vy"), PurePath("bar.vy")] + input_json = { + "sources": {"foo.vy": "", "bar.vy": ""}, + "settings": {"outputSelection": {"*": ["*"]}}, + } expected = sorted(set(TRANSLATE_MAP.values())) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected, PurePath("bar.vy"): expected} def test_evm(): - input_json = {"settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}} - targets = [PurePath("foo.vy")] + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": ["abi", "evm"]}}, + } expected = ["abi"] + sorted(v for k, v in TRANSLATE_MAP.items() if k.startswith("evm")) - result = get_output_formats(input_json, targets) + result = get_output_formats(input_json) assert result == {PurePath("foo.vy"): expected} def test_solc_style(): - input_json = {"settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["abi", "ir_dict"]} + input_json = { + "sources": {"foo.vy": ""}, + "settings": {"outputSelection": {"foo.vy": {"": ["abi"], "foo.vy": ["ir"]}}}, + } + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["abi", "ir_dict"]} def test_metadata(): - input_json = {"settings": {"outputSelection": {"*": ["metadata"]}}} - targets = [PurePath("foo.vy")] - assert get_output_formats(input_json, targets) == {PurePath("foo.vy"): ["metadata"]} + input_json = {"sources": {"foo.vy": ""}, "settings": {"outputSelection": {"*": ["metadata"]}}} + assert get_output_formats(input_json) == {PurePath("foo.vy"): ["metadata"]} diff --git a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py index 3b0f700c7e..6b509dd3ef 100644 --- a/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py +++ b/tests/unit/cli/vyper_json/test_parse_args_vyperjson.py @@ -9,11 +9,11 @@ from vyper.exceptions import JSONError FOO_CODE = """ -import contracts.bar as Bar +import contracts.ibar as IBar @external def foo(a: address) -> bool: - return Bar(a).bar(1) + return IBar(a).bar(1) """ BAR_CODE = """ diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 47b70a8c70..44b823757c 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -1,5 +1,6 @@ import pytest +from vyper.compiler import compile_code from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings @@ -71,33 +72,61 @@ def __init__(): ] +# check dead code eliminator works on unreachable functions @pytest.mark.parametrize("code", codes) def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime - ctor_only_label = "_sym_internal_ctor_only___" - runtime_only_label = "_sym_internal_runtime_only___" + # get the labels + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] + + ctor_only = "ctor_only()" + runtime_only = "runtime_only()" # qux reachable from unoptimized initcode, foo not reachable. - assert ctor_only_label + "_deploy" in initcode_asm - assert runtime_only_label + "_deploy" not in initcode_asm + assert any(ctor_only in instr for instr in initcode_asm) + assert all(runtime_only not in instr for instr in initcode_asm) # all labels should be in unoptimized runtime asm - for s in (ctor_only_label, runtime_only_label): - assert s + "_runtime" in runtime_asm + for s in (ctor_only, runtime_only): + assert any(s in instr for instr in runtime_asm) c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS)) - initcode_asm = [i for i in c.assembly if not isinstance(i, list)] - runtime_asm = c.assembly_runtime + initcode_asm = [i for i in c.assembly if isinstance(i, str)] + runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] # ctor only label should not be in runtime code - for instr in runtime_asm: - if isinstance(instr, str): - assert not instr.startswith(ctor_only_label), instr + assert all(ctor_only not in instr for instr in runtime_asm) # runtime only label should not be in initcode asm - for instr in initcode_asm: - if isinstance(instr, str): - assert not instr.startswith(runtime_only_label), instr + assert all(runtime_only not in instr for instr in initcode_asm) + + +def test_library_code_eliminator(make_input_bundle): + library = """ +@internal +def unused1(): + pass + +@internal +def unused2(): + self.unused1() + +@internal +def some_function(): + pass + """ + code = """ +import library + +@external +def foo(): + library.some_function() + """ + input_bundle = make_input_bundle({"library.vy": library}) + res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) + asm = res["asm"] + assert "some_function()" in asm + assert "unused1()" not in asm + assert "unused2()" not in asm diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index c49c81219b..e26555b169 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -1,4 +1,6 @@ +import contextlib import json +import os from pathlib import Path, PurePath import pytest @@ -12,19 +14,19 @@ def input_bundle(tmp_path): return FilesystemInputBundle([tmp_path]) -def test_load_file(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") +def test_load_file(make_file, input_bundle): + filepath = make_file("foo.vy", "contents") file = input_bundle.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_context_manager(make_file, tmp_path): ib = FilesystemInputBundle([]) - make_file("foo.vy", "contents") + filepath = make_file("foo.vy", "contents") with pytest.raises(FileNotFoundError): # no search path given @@ -34,7 +36,7 @@ def test_search_path_context_manager(make_file, tmp_path): file = ib.load_file(Path("foo.vy")) assert isinstance(file, FileInput) - assert file == FileInput(0, tmp_path / Path("foo.vy"), "contents") + assert file == FileInput(0, Path("foo.vy"), filepath, "contents") def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bundle): @@ -43,59 +45,85 @@ def test_search_path_precedence(make_file, tmp_path, tmp_path_factory, input_bun tmpdir = tmp_path_factory.mktemp("some_directory") tmpdir2 = tmp_path_factory.mktemp("some_other_directory") + filepaths = [] for i, directory in enumerate([tmp_path, tmpdir, tmpdir2]): - with (directory / "foo.vy").open("w") as f: + path = directory / "foo.vy" + with path.open("w") as f: f.write(f"contents {i}") + filepaths.append(path) ib = FilesystemInputBundle([tmp_path, tmpdir, tmpdir2]) file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(0, tmpdir2 / "foo.vy", "contents 2") + assert file == FileInput(0, "foo.vy", filepaths[2], "contents 2") with ib.search_path(tmpdir): file = ib.load_file("foo.vy") assert isinstance(file, FileInput) - assert file == FileInput(1, tmpdir / "foo.vy", "contents 1") + assert file == FileInput(1, "foo.vy", filepaths[1], "contents 1") # special rules for handling json files def test_load_abi(make_file, input_bundle, tmp_path): contents = json.dumps("some string") - make_file("foo.json", contents) + path = make_file("foo.json", contents) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", path, "some string") # suffix doesn't matter - make_file("foo.txt", contents) - + path = make_file("foo.txt", contents) file = input_bundle.load_file("foo.txt") assert isinstance(file, ABIInput) - assert file == ABIInput(1, tmp_path / "foo.txt", "some string") + assert file == ABIInput(1, "foo.txt", path, "some string") + + +@contextlib.contextmanager +def working_directory(directory): + tmp = os.getcwd() + try: + os.chdir(directory) + yield + finally: + os.chdir(tmp) # check that unique paths give unique source ids def test_source_id_file_input(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") - make_file("bar.vy", "contents 2") + foopath = make_file("foo.vy", "contents") + barpath = make_file("bar.vy", "contents 2") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") file2 = input_bundle.load_file("bar.vy") # source id increments assert file2.source_id == 1 - assert file2 == FileInput(1, tmp_path / "bar.vy", "contents 2") + assert file2 == FileInput(1, "bar.vy", barpath, "contents 2") file3 = input_bundle.load_file("foo.vy") assert file3.source_id == 0 - assert file3 == FileInput(0, tmp_path / "foo.vy", "contents") + assert file3 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.vy") + assert file4.source_id == 0 + assert file4 == FileInput(0, "foo.vy", foopath, "contents") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.vy") + assert file5.source_id == 0 + assert file5 == FileInput(0, Path(tmp_path.stem) / "foo.vy", foopath, "contents") # check that unique paths give unique source ids @@ -103,37 +131,51 @@ def test_source_id_json_input(make_file, input_bundle, tmp_path): contents = json.dumps("some string") contents2 = json.dumps(["some list"]) - make_file("foo.json", contents) + foopath = make_file("foo.json", contents) - make_file("bar.json", contents2) + barpath = make_file("bar.json", contents2) file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, tmp_path / "foo.json", "some string") + assert file == ABIInput(0, "foo.json", foopath, "some string") file2 = input_bundle.load_file("bar.json") assert isinstance(file2, ABIInput) - assert file2 == ABIInput(1, tmp_path / "bar.json", ["some list"]) + assert file2 == ABIInput(1, "bar.json", barpath, ["some list"]) file3 = input_bundle.load_file("foo.json") - assert isinstance(file3, ABIInput) - assert file3 == ABIInput(0, tmp_path / "foo.json", "some string") + assert file3.source_id == 0 + assert file3 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable across different search paths + with working_directory(tmp_path): + with input_bundle.search_path(Path(".")): + file4 = input_bundle.load_file("foo.json") + assert file4.source_id == 0 + assert file4 == ABIInput(0, "foo.json", foopath, "some string") + + # test source id is stable even when requested filename is different + with working_directory(tmp_path.parent): + with input_bundle.search_path(Path(".")): + file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.json") + assert file5.source_id == 0 + assert file5 == ABIInput(0, Path(tmp_path.stem) / "foo.json", foopath, "some string") # test some pathological case where the file changes underneath def test_mutating_file_source_id(make_file, input_bundle, tmp_path): - make_file("foo.vy", "contents") + foopath = make_file("foo.vy", "contents") file = input_bundle.load_file("foo.vy") assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "contents") + assert file == FileInput(0, "foo.vy", foopath, "contents") - make_file("foo.vy", "new contents") + foopath = make_file("foo.vy", "new contents") file = input_bundle.load_file("foo.vy") # source id hasn't changed, even though contents have assert file.source_id == 0 - assert file == FileInput(0, tmp_path / "foo.vy", "new contents") + assert file == FileInput(0, "foo.vy", foopath, "new contents") # test the os.normpath behavior of symlink @@ -147,10 +189,12 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): dir2.mkdir() symlink.symlink_to(dir2, target_is_directory=True) - with (tmp_path / "foo.vy").open("w") as f: - f.write("contents of the upper directory") + outer_path = tmp_path / "foo.vy" + with outer_path.open("w") as f: + f.write("contents of the outer directory") - with (dir1 / "foo.vy").open("w") as f: + inner_path = dir1 / "foo.vy" + with inner_path.open("w") as f: f.write("contents of the inner directory") # symlink rules would be: @@ -159,9 +203,10 @@ def test_load_file_symlink(make_file, input_bundle, tmp_path, tmp_path_factory): # base/first/foo.vy # normpath would be base/symlink/../foo.vy => # base/foo.vy - file = input_bundle.load_file(symlink / ".." / "foo.vy") + to_load = symlink / ".." / "foo.vy" + file = input_bundle.load_file(to_load) - assert file == FileInput(0, tmp_path / "foo.vy", "contents of the upper directory") + assert file == FileInput(0, to_load, outer_path.resolve(), "contents of the outer directory") def test_json_input_bundle_basic(): @@ -169,40 +214,42 @@ def test_json_input_bundle_basic(): input_bundle = JSONInputBundle(files, [PurePath(".")]) file = input_bundle.load_file(PurePath("foo.vy")) - assert file == FileInput(0, PurePath("foo.vy"), "some text") + assert file == FileInput(0, PurePath("foo.vy"), PurePath("foo.vy"), "some text") def test_json_input_bundle_normpath(): - files = {PurePath("foo/../bar.vy"): {"content": "some text"}} + contents = "some text" + files = {PurePath("foo/../bar.vy"): {"content": contents}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - expected = FileInput(0, PurePath("bar.vy"), "some text") + barpath = PurePath("bar.vy") + + expected = FileInput(0, barpath, barpath, contents) file = input_bundle.load_file(PurePath("bar.vy")) assert file == expected file = input_bundle.load_file(PurePath("baz/../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("baz/../bar.vy"), barpath, contents) file = input_bundle.load_file(PurePath("./bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("./bar.vy"), barpath, contents) with input_bundle.search_path(PurePath("foo")): file = input_bundle.load_file(PurePath("../bar.vy")) - assert file == expected + assert file == FileInput(0, PurePath("../bar.vy"), barpath, contents) def test_json_input_abi(): some_abi = ["some abi"] some_abi_str = json.dumps(some_abi) - files = { - PurePath("foo.json"): {"abi": some_abi}, - PurePath("bar.txt"): {"content": some_abi_str}, - } + foopath = PurePath("foo.json") + barpath = PurePath("bar.txt") + files = {foopath: {"abi": some_abi}, barpath: {"content": some_abi_str}} input_bundle = JSONInputBundle(files, [PurePath(".")]) - file = input_bundle.load_file(PurePath("foo.json")) - assert file == ABIInput(0, PurePath("foo.json"), some_abi) + file = input_bundle.load_file(foopath) + assert file == ABIInput(0, foopath, foopath, some_abi) - file = input_bundle.load_file(PurePath("bar.txt")) - assert file == ABIInput(1, PurePath("bar.txt"), some_abi) + file = input_bundle.load_file(barpath) + assert file == ABIInput(1, barpath, barpath, some_abi) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index 27c0634cf8..5ea373fc19 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value): +def test_type_mismatch(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -23,11 +23,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value): +def test_invalid_literal(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -38,11 +38,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidType): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value): +def test_out_of_bounds(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -53,11 +53,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value): +def test_undeclared_definition(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -68,11 +68,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value): +def test_invalid_reference(namespace, value, dummy_input_bundle): code = f""" a: uint256[3] @@ -83,4 +83,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 2a09bd5ed5..c31146b16f 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -3,22 +3,20 @@ from vyper.ast import parse_to_ast from vyper.exceptions import CallViolation, StructureException from vyper.semantics.analysis import validate_semantics -from vyper.semantics.analysis.module import ModuleAnalyzer -def test_self_function_call(namespace): +def test_self_function_call(dummy_input_bundle): code = """ @internal def foo(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_cyclic_function_call(namespace): +def test_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -29,12 +27,11 @@ def bar(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_multi_cyclic_function_call(namespace): +def test_multi_cyclic_function_call(dummy_input_bundle): code = """ @internal def foo(): @@ -53,12 +50,11 @@ def potato(): self.foo() """ vyper_module = parse_to_ast(code) - with namespace.enter_scope(): - with pytest.raises(CallViolation): - ModuleAnalyzer(vyper_module, {}, namespace) + with pytest.raises(CallViolation): + validate_semantics(vyper_module, dummy_input_bundle) -def test_global_ann_assign_callable_no_crash(): +def test_global_ann_assign_callable_no_crash(dummy_input_bundle): code = """ balanceOf: public(HashMap[address, uint256]) @@ -68,5 +64,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - validate_semantics(vyper_module, {}) - assert excinfo.value.message == "Value is not callable" + validate_semantics(vyper_module, dummy_input_bundle) + assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index 0d61a8f8f8..e2c0f555af 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis import validate_semantics -def test_modify_iterator_function_outside_loop(namespace): +def test_modify_iterator_function_outside_loop(dummy_input_bundle): code = """ a: uint256[3] @@ -26,10 +26,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_pass_memory_var_to_other_function(namespace): +def test_pass_memory_var_to_other_function(dummy_input_bundle): code = """ @internal @@ -46,10 +46,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator(namespace): +def test_modify_iterator(dummy_input_bundle): code = """ a: uint256[3] @@ -61,10 +61,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_keywords(namespace): +def test_bad_keywords(dummy_input_bundle): code = """ @internal @@ -75,10 +75,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_bad_bound(namespace): +def test_bad_bound(dummy_input_bundle): code = """ @internal @@ -89,10 +89,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StateAccessViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_function_call(namespace): +def test_modify_iterator_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -108,10 +108,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) -def test_modify_iterator_recursive_function_call(namespace): +def test_modify_iterator_recursive_function_call(dummy_input_bundle): code = """ a: uint256[3] @@ -131,7 +131,7 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) iterator_inference_codes = [ @@ -169,7 +169,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(namespace, code): +def test_iterator_type_inference_checker(code, dummy_input_bundle): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - validate_semantics(vyper_module, {}) + validate_semantics(vyper_module, dummy_input_bundle) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index d390fe9a39..002ee38cd2 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -110,6 +110,6 @@ def test_allocator_overflow(get_contract): """ with pytest.raises( StorageLayoutException, - match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}\n", + match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}", ): get_contract(code) diff --git a/tox.ini b/tox.ini index c949354dfe..f9d4c3b60b 100644 --- a/tox.ini +++ b/tox.ini @@ -53,4 +53,4 @@ commands = basepython = python3 extras = lint commands = - mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --disallow-incomplete-defs -p vyper + mypy --install-types --non-interactive --follow-imports=silent --ignore-missing-imports --implicit-optional -p vyper diff --git a/vyper/__init__.py b/vyper/__init__.py index 482d5c3a60..5bb6469757 100644 --- a/vyper/__init__.py +++ b/vyper/__init__.py @@ -1,6 +1,6 @@ from pathlib import Path as _Path -from vyper.compiler import compile_code # noqa: F401 +from vyper.compiler import compile_code, compile_from_file_input try: from importlib.metadata import PackageNotFoundError # type: ignore diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py index e5b81f1e7f..4b46801153 100644 --- a/vyper/ast/__init__.py +++ b/vyper/ast/__init__.py @@ -6,7 +6,8 @@ from . import nodes, validation from .natspec import parse_natspec from .nodes import compare_nodes -from .utils import ast_to_dict, parse_to_ast, parse_to_ast_with_settings +from .utils import ast_to_dict +from .parse import parse_to_ast, parse_to_ast_with_settings # adds vyper.ast.nodes classes into the local namespace for name, obj in ( diff --git a/vyper/ast/__init__.pyi b/vyper/ast/__init__.pyi index d349e804d6..eac8ffdef5 100644 --- a/vyper/ast/__init__.pyi +++ b/vyper/ast/__init__.pyi @@ -4,5 +4,5 @@ from typing import Any, Optional, Union from . import expansion, folding, nodes, validation from .natspec import parse_natspec as parse_natspec from .nodes import * +from .parse import parse_to_ast as parse_to_ast from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 5471b971a4..1536f39165 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -5,22 +5,9 @@ from vyper.semantics.types.function import ContractFunctionT -def expand_annotated_ast(vyper_module: vy_ast.Module) -> None: - """ - Perform expansion / simplification operations on an annotated Vyper AST. - - This pass uses annotated type information to modify the AST, simplifying - logic and expanding subtrees to reduce the compexity during codegen. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node that has been type-checked and annotated. - """ - generate_public_variable_getters(vyper_module) - remove_unused_statements(vyper_module) - - +# TODO: remove this function. it causes correctness/performance problems +# because of copying and mutating the AST - getter generation should be handled +# during code generation. def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ Create getter functions for public variables. @@ -32,7 +19,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: """ for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_public": True}): - func_type = node._metadata["func_type"] + func_type = node._metadata["getter_type"] input_types, return_type = node._metadata["type"].getter_signature input_nodes = [] @@ -86,31 +73,11 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: returns=return_node, ) - with vyper_module.namespace(): - func_type = ContractFunctionT.from_FunctionDef(expanded) - - expanded._metadata["type"] = func_type - return_node.set_parent(expanded) + # update pointers vyper_module.add_to_body(expanded) + return_node.set_parent(expanded) + with vyper_module.namespace(): + func_type = ContractFunctionT.from_FunctionDef(expanded) -def remove_unused_statements(vyper_module: vy_ast.Module) -> None: - """ - Remove statement nodes that are unused after type checking. - - Once type checking is complete, we can remove now-meaningless statements to - simplify the AST prior to IR generation. - - Arguments - --------- - vyper_module : Module - Top-level Vyper AST node. - """ - - # constant declarations - values were substituted within the AST during folding - for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}): - vyper_module.remove_from_body(node) - - # `implements: interface` statements - validated during type checking - for node in vyper_module.get_children(vy_ast.ImplementsDecl): - vyper_module.remove_from_body(node) + expanded._metadata["func_type"] = func_type diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index ca9979b2a3..15367ce94a 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -89,7 +89,8 @@ tuple_def: "(" ( NAME | array_def | dyn_array_def | tuple_def ) ( "," ( NAME | a // NOTE: Map takes a basic type and maps to another type (can be non-basic, including maps) _MAP: "HashMap" map_def: _MAP "[" ( NAME | array_def ) "," type "]" -type: ( NAME | array_def | tuple_def | map_def | dyn_array_def ) +imported_type: NAME "." NAME +type: ( NAME | imported_type | array_def | tuple_def | map_def | dyn_array_def ) // Structs can be composed of 1+ basic types or other custom_types _STRUCT_DECL: "struct" @@ -291,7 +292,7 @@ special_builtins: empty | abi_decode // Adapted from: https://docs.python.org/3/reference/grammar.html // Adapted by: Erez Shinan NAME: /[a-zA-Z_]\w*/ -COMMENT: /#[^\n]*/ +COMMENT: /#[^\n\r]*/ _NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+ @@ -312,8 +313,10 @@ _number: DEC_NUMBER BOOL.2: "True" | "False" +ELLIPSIS: "..." + // TODO: Remove Docstring from here, and add to first part of body -?literal: ( _number | STRING | DOCSTRING | BOOL ) +?literal: ( _number | STRING | DOCSTRING | BOOL | ELLIPSIS) %ignore /[\t \f]+/ // WS %ignore /\\[\t \f]*\r?\n/ // LINE_CONT diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py index c25fc423f8..41905b178a 100644 --- a/vyper/ast/natspec.py +++ b/vyper/ast/natspec.py @@ -43,7 +43,7 @@ def parse_natspec(vyper_module_folded: vy_ast.Module) -> Tuple[dict, dict]: for node in [i for i in vyper_module_folded.body if i.get("doc_string.value")]: docstring = node.doc_string.value - func_type = node._metadata["type"] + func_type = node._metadata["func_type"] if func_type.visibility != FunctionVisibility.EXTERNAL: continue diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 69bd1fed53..3bccc5f141 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -589,7 +589,8 @@ def __contains__(self, obj): class Module(TopLevel): - __slots__ = () + # metadata + __slots__ = ("path", "resolved_path", "source_id") def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: """ @@ -897,12 +898,16 @@ def validate(self): raise InvalidLiteral("Cannot have an empty tuple", self) -class Dict(ExprNode): - __slots__ = ("keys", "values") +class NameConstant(Constant): + __slots__ = () -class NameConstant(Constant): - __slots__ = ("value",) +class Ellipsis(Constant): + __slots__ = () + + +class Dict(ExprNode): + __slots__ = ("keys", "values") class Name(ExprNode): @@ -1407,7 +1412,7 @@ class Pass(Stmt): __slots__ = () -class _Import(Stmt): +class _ImportStmt(Stmt): __slots__ = ("name", "alias") def __init__(self, *args, **kwargs): @@ -1419,11 +1424,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -class Import(_Import): +class Import(_ImportStmt): __slots__ = () -class ImportFrom(_Import): +class ImportFrom(_ImportStmt): __slots__ = ("level", "module") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 47c9af8526..05784aed0f 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -2,9 +2,9 @@ import ast as python_ast from typing import Any, Optional, Sequence, Type, Union from .natspec import parse_natspec as parse_natspec +from .parse import parse_to_ast as parse_to_ast +from .parse import parse_to_ast_with_settings as parse_to_ast_with_settings from .utils import ast_to_dict as ast_to_dict -from .utils import parse_to_ast as parse_to_ast -from .utils import parse_to_ast_with_settings as parse_to_ast_with_settings NODE_BASE_ATTRIBUTES: Any NODE_SRC_ATTRIBUTES: Any @@ -59,6 +59,8 @@ class TopLevel(VyperNode): def __contains__(self, obj: Any) -> bool: ... class Module(TopLevel): + path: str = ... + resolved_path: str = ... def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ... def add_to_body(self, node: VyperNode) -> None: ... def remove_from_body(self, node: VyperNode) -> None: ... @@ -121,6 +123,9 @@ class Bytes(Constant): @property def s(self): ... +class NameConstant(Constant): ... +class Ellipsis(Constant): ... + class List(VyperNode): elements: list = ... @@ -131,8 +136,6 @@ class Dict(VyperNode): keys: list = ... values: list = ... -class NameConstant(Constant): ... - class Name(VyperNode): id: str = ... _type: str = ... @@ -188,7 +191,7 @@ class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... keywords: list = ... - func: Name = ... + func: VyperNode = ... class keyword(VyperNode): ... diff --git a/vyper/ast/annotation.py b/vyper/ast/parse.py similarity index 68% rename from vyper/ast/annotation.py rename to vyper/ast/parse.py index 9c7b1e063f..a2f2542179 100644 --- a/vyper/ast/annotation.py +++ b/vyper/ast/parse.py @@ -1,14 +1,114 @@ import ast as python_ast import tokenize from decimal import Decimal -from typing import Optional, cast +from typing import Any, Dict, List, Optional, Union, cast import asttokens -from vyper.exceptions import CompilerPanic, SyntaxException +from vyper.ast import nodes as vy_ast +from vyper.ast.pre_parser import pre_parse +from vyper.compiler.settings import Settings +from vyper.exceptions import CompilerPanic, ParserException, SyntaxException from vyper.typing import ModificationOffsets +def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: + _settings, ast = parse_to_ast_with_settings(*args, **kwargs) + return ast + + +def parse_to_ast_with_settings( + source_code: str, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, + add_fn_node: Optional[str] = None, +) -> tuple[Settings, vy_ast.Module]: + """ + Parses a Vyper source string and generates basic Vyper AST nodes. + + Parameters + ---------- + source_code : str + The Vyper source code to parse. + source_id : int, optional + Source id to use in the `src` member of each node. + contract_name: str, optional + Name of contract. + add_fn_node: str, optional + If not None, adds a dummy Python AST FunctionDef wrapper node. + source_id: int, optional + The source ID generated for this source code. + Corresponds to FileInput.source_id + module_path: str, optional + The path of the source code + Corresponds to FileInput.path + resolved_path: str, optional + The resolved path of the source code + Corresponds to FileInput.resolved_path + + Returns + ------- + list + Untyped, unoptimized Vyper AST nodes. + """ + if "\x00" in source_code: + raise ParserException("No null bytes (\\x00) allowed in the source code.") + settings, class_types, reformatted_code = pre_parse(source_code) + try: + py_ast = python_ast.parse(reformatted_code) + except SyntaxError as e: + # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors + raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e + + # Add dummy function node to ensure local variables are treated as `AnnAssign` + # instead of state variables (`VariableDecl`) + if add_fn_node: + fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) + fn_node.body = py_ast.body + fn_node.args = python_ast.arguments(defaults=[]) + py_ast.body = [fn_node] + + annotate_python_ast( + py_ast, + source_code, + class_types, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + + # Convert to Vyper AST. + module = vy_ast.get_node(py_ast) + assert isinstance(module, vy_ast.Module) # mypy hint + return settings, module + + +def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: + """ + Converts a Vyper AST node, or list of nodes, into a dictionary suitable for + output to the user. + """ + if isinstance(ast_struct, vy_ast.VyperNode): + return ast_struct.to_dict() + + if isinstance(ast_struct, list): + return [i.to_dict() for i in ast_struct] + + raise CompilerPanic(f'Unknown Vyper AST node provided: "{type(ast_struct)}".') + + +def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: + """ + Converts an AST dict, or list of dicts, into Vyper AST node objects. + """ + if isinstance(ast_struct, dict): + return vy_ast.get_node(ast_struct) + if isinstance(ast_struct, list): + return [vy_ast.get_node(i) for i in ast_struct] + raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets @@ -19,11 +119,13 @@ def __init__( modification_offsets: Optional[ModificationOffsets], tokens: asttokens.ASTTokens, source_id: int, - contract_name: Optional[str], + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ): self._tokens = tokens self._source_id = source_id - self._contract_name = contract_name + self._module_path = module_path + self._resolved_path = resolved_path self._source_code: str = source_code self.counter: int = 0 self._modification_offsets = {} @@ -83,7 +185,9 @@ def _visit_docstring(self, node): return node def visit_Module(self, node): - node.name = self._contract_name + node.path = self._module_path + node.resolved_path = self._resolved_path + node.source_id = self._source_id return self._visit_docstring(node) def visit_FunctionDef(self, node): @@ -163,6 +267,8 @@ def visit_Constant(self, node): node.ast_type = "Str" elif isinstance(node.value, bytes): node.ast_type = "Bytes" + elif isinstance(node.value, Ellipsis.__class__): + node.ast_type = "Ellipsis" else: raise SyntaxException( "Invalid syntax (unsupported Python Constant AST node).", @@ -250,7 +356,8 @@ def annotate_python_ast( source_code: str, modification_offsets: Optional[ModificationOffsets] = None, source_id: int = 0, - contract_name: Optional[str] = None, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, ) -> python_ast.AST: """ Annotate and optimize a Python AST in preparation conversion to a Vyper AST. @@ -270,7 +377,14 @@ def annotate_python_ast( """ tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor(source_code, modification_offsets, tokens, source_id, contract_name) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) visitor.visit(parsed_ast) return parsed_ast diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py index 4e669385ab..4c2e5394c9 100644 --- a/vyper/ast/utils.py +++ b/vyper/ast/utils.py @@ -1,64 +1,7 @@ -import ast as python_ast -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union from vyper.ast import nodes as vy_ast -from vyper.ast.annotation import annotate_python_ast -from vyper.ast.pre_parser import pre_parse -from vyper.compiler.settings import Settings -from vyper.exceptions import CompilerPanic, ParserException, SyntaxException - - -def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module: - return parse_to_ast_with_settings(*args, **kwargs)[1] - - -def parse_to_ast_with_settings( - source_code: str, - source_id: int = 0, - contract_name: Optional[str] = None, - add_fn_node: Optional[str] = None, -) -> tuple[Settings, vy_ast.Module]: - """ - Parses a Vyper source string and generates basic Vyper AST nodes. - - Parameters - ---------- - source_code : str - The Vyper source code to parse. - source_id : int, optional - Source id to use in the `src` member of each node. - contract_name: str, optional - Name of contract. - add_fn_node: str, optional - If not None, adds a dummy Python AST FunctionDef wrapper node. - - Returns - ------- - list - Untyped, unoptimized Vyper AST nodes. - """ - if "\x00" in source_code: - raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) - try: - py_ast = python_ast.parse(reformatted_code) - except SyntaxError as e: - # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors - raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e - - # Add dummy function node to ensure local variables are treated as `AnnAssign` - # instead of state variables (`VariableDecl`) - if add_fn_node: - fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], []) - fn_node.body = py_ast.body - fn_node.args = python_ast.arguments(defaults=[]) - py_ast.body = [fn_node] - annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name) - - # Convert to Vyper AST. - module = vy_ast.get_node(py_ast) - assert isinstance(module, vy_ast.Module) # mypy hint - return settings, module +from vyper.exceptions import CompilerPanic def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]: diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py index afc0987b6d..72b05f15e3 100644 --- a/vyper/builtins/_utils.py +++ b/vyper/builtins/_utils.py @@ -1,10 +1,10 @@ from vyper.ast import parse_to_ast from vyper.codegen.context import Context -from vyper.codegen.global_context import GlobalContext from vyper.codegen.stmt import parse_body from vyper.semantics.analysis.local import FunctionNodeVisitor from vyper.semantics.namespace import Namespace, override_global_namespace from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability +from vyper.semantics.types.module import ModuleT def _strip_source_pos(ir_node): @@ -22,15 +22,16 @@ def generate_inline_function(code, variables, variables_2, memory_allocator): # Initialise a placeholder `FunctionDef` AST node and corresponding # `ContractFunctionT` type to rely on the annotation visitors in semantics # module. - ast_code.body[0]._metadata["type"] = ContractFunctionT( + ast_code.body[0]._metadata["func_type"] = ContractFunctionT( "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE ) # The FunctionNodeVisitor's constructor performs semantic checks # annotate the AST as side effects. - FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace) + analyzer.analyze() new_context = Context( - vars_=variables, global_ctx=GlobalContext(), memory_allocator=memory_allocator + vars_=variables, module_ctx=ModuleT(ast_code), memory_allocator=memory_allocator ) generated_ir = parse_body(ast_code.body[0].body, new_context) # strip source position info from the generated_ir since diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 22931508a6..d50a31767d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2499,9 +2499,9 @@ def infer_arg_types(self, node): validate_call_args(node, 2, ["unwrap_tuple"]) data_type = get_exact_type_from_node(node.args[0]) - output_typedef = TYPE_T(type_from_annotation(node.args[1])) + output_type = type_from_annotation(node.args[1]) - return [data_type, output_typedef] + return [data_type, TYPE_T(output_type)] @process_inputs def build_IR(self, expr, args, kwargs, context): diff --git a/vyper/builtins/interfaces/ERC165.vy b/vyper/builtins/interfaces/ERC165.vyi similarity index 88% rename from vyper/builtins/interfaces/ERC165.vy rename to vyper/builtins/interfaces/ERC165.vyi index a4ca451abd..441130f77c 100644 --- a/vyper/builtins/interfaces/ERC165.vy +++ b/vyper/builtins/interfaces/ERC165.vyi @@ -1,4 +1,4 @@ @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20.vy b/vyper/builtins/interfaces/ERC20.vyi similarity index 68% rename from vyper/builtins/interfaces/ERC20.vy rename to vyper/builtins/interfaces/ERC20.vyi index 065ca97a9b..ee533ab326 100644 --- a/vyper/builtins/interfaces/ERC20.vy +++ b/vyper/builtins/interfaces/ERC20.vyi @@ -1,38 +1,38 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _value: uint256 + sender: indexed(address) + recipient: indexed(address) + value: uint256 event Approval: - _owner: indexed(address) - _spender: indexed(address) - _value: uint256 + owner: indexed(address) + spender: indexed(address) + value: uint256 # Functions @view @external def totalSupply() -> uint256: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def allowance(_owner: address, _spender: address) -> uint256: - pass + ... @external def transfer(_to: address, _value: uint256) -> bool: - pass + ... @external def transferFrom(_from: address, _to: address, _value: uint256) -> bool: - pass + ... @external def approve(_spender: address, _value: uint256) -> bool: - pass + ... diff --git a/vyper/builtins/interfaces/ERC20Detailed.vy b/vyper/builtins/interfaces/ERC20Detailed.vyi similarity index 93% rename from vyper/builtins/interfaces/ERC20Detailed.vy rename to vyper/builtins/interfaces/ERC20Detailed.vyi index 7c4f546d45..0be1c6f153 100644 --- a/vyper/builtins/interfaces/ERC20Detailed.vy +++ b/vyper/builtins/interfaces/ERC20Detailed.vyi @@ -5,14 +5,14 @@ @view @external def name() -> String[1]: - pass + ... @view @external def symbol() -> String[1]: - pass + ... @view @external def decimals() -> uint8: - pass + ... diff --git a/vyper/builtins/interfaces/ERC4626.vy b/vyper/builtins/interfaces/ERC4626.vyi similarity index 90% rename from vyper/builtins/interfaces/ERC4626.vy rename to vyper/builtins/interfaces/ERC4626.vyi index 05865406cf..6d9e4c6ef7 100644 --- a/vyper/builtins/interfaces/ERC4626.vy +++ b/vyper/builtins/interfaces/ERC4626.vyi @@ -16,75 +16,75 @@ event Withdraw: @view @external def asset() -> address: - pass + ... @view @external def totalAssets() -> uint256: - pass + ... @view @external def convertToShares(assetAmount: uint256) -> uint256: - pass + ... @view @external def convertToAssets(shareAmount: uint256) -> uint256: - pass + ... @view @external def maxDeposit(owner: address) -> uint256: - pass + ... @view @external def previewDeposit(assets: uint256) -> uint256: - pass + ... @external def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxMint(owner: address) -> uint256: - pass + ... @view @external def previewMint(shares: uint256) -> uint256: - pass + ... @external def mint(shares: uint256, receiver: address=msg.sender) -> uint256: - pass + ... @view @external def maxWithdraw(owner: address) -> uint256: - pass + ... @view @external def previewWithdraw(assets: uint256) -> uint256: - pass + ... @external def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... @view @external def maxRedeem(owner: address) -> uint256: - pass + ... @view @external def previewRedeem(shares: uint256) -> uint256: - pass + ... @external def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: - pass + ... diff --git a/vyper/builtins/interfaces/ERC721.vy b/vyper/builtins/interfaces/ERC721.vyi similarity index 61% rename from vyper/builtins/interfaces/ERC721.vy rename to vyper/builtins/interfaces/ERC721.vyi index 464c0e255b..b8dcfd3c5f 100644 --- a/vyper/builtins/interfaces/ERC721.vy +++ b/vyper/builtins/interfaces/ERC721.vyi @@ -1,67 +1,62 @@ # Events event Transfer: - _from: indexed(address) - _to: indexed(address) - _tokenId: indexed(uint256) + sender: indexed(address) + recipient: indexed(address) + token_id: indexed(uint256) event Approval: - _owner: indexed(address) - _approved: indexed(address) - _tokenId: indexed(uint256) + owner: indexed(address) + approved: indexed(address) + token_id: indexed(uint256) event ApprovalForAll: - _owner: indexed(address) - _operator: indexed(address) - _approved: bool + owner: indexed(address) + operator: indexed(address) + approved: bool # Functions @view @external def supportsInterface(interface_id: bytes4) -> bool: - pass + ... @view @external def balanceOf(_owner: address) -> uint256: - pass + ... @view @external def ownerOf(_tokenId: uint256) -> address: - pass + ... @view @external def getApproved(_tokenId: uint256) -> address: - pass + ... @view @external def isApprovedForAll(_owner: address, _operator: address) -> bool: - pass + ... @external @payable def transferFrom(_from: address, _to: address, _tokenId: uint256): - pass + ... @external @payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256): - pass - -@external -@payable -def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024]): - pass +def safeTransferFrom(_from: address, _to: address, _tokenId: uint256, _data: Bytes[1024] = b""): + ... @external @payable def approve(_approved: address, _tokenId: uint256): - pass + ... @external def setApprovalForAll(_operator: address, _approved: bool): - pass + ... diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ca1792384e..4f88812fa0 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -271,10 +271,8 @@ def compile_files( with open(storage_file_path) as sfh: storage_layout_override = json.load(sfh) - output = vyper.compile_code( - file.source_code, - contract_name=str(file.path), - source_id=file.source_id, + output = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=final_formats, exc_handler=exc_handler, diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 2720f20d23..63da2e0643 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.opcodes import EVM_VERSIONS from vyper.exceptions import JSONError -from vyper.utils import keccak256 +from vyper.utils import OrderedSet, keccak256 TRANSLATE_MAP = { "abi": "abi", @@ -151,13 +151,6 @@ def get_evm_version(input_dict: dict) -> Optional[str]: return evm_version -def get_compilation_targets(input_dict: dict) -> list[PurePath]: - # TODO: once we have modules, add optional "compilation_targets" key - # which specifies which sources we actually want to compile. - - return [PurePath(p) for p in input_dict["sources"].keys()] - - def get_inputs(input_dict: dict) -> dict[PurePath, Any]: ret = {} seen = {} @@ -218,14 +211,14 @@ def get_inputs(input_dict: dict) -> dict[PurePath, Any]: # get unique output formats for each contract, given the input_dict # NOTE: would maybe be nice to raise on duplicated output formats -def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePath, list[str]]: +def get_output_formats(input_dict: dict) -> dict[PurePath, list[str]]: output_formats: dict[PurePath, list[str]] = {} for path, outputs in input_dict["settings"]["outputSelection"].items(): if isinstance(outputs, dict): # if outputs are given in solc json format, collapse them into a single list - outputs = set(x for i in outputs.values() for x in i) + outputs = OrderedSet(x for i in outputs.values() for x in i) else: - outputs = set(outputs) + outputs = OrderedSet(outputs) for key in [i for i in ("evm", "evm.bytecode", "evm.deployedBytecode") if i in outputs]: outputs.remove(key) @@ -239,13 +232,13 @@ def get_output_formats(input_dict: dict, targets: list[PurePath]) -> dict[PurePa except KeyError as e: raise JSONError(f"Invalid outputSelection - {e}") - outputs = sorted(set(outputs)) + outputs = sorted(list(outputs)) if path == "*": - output_paths = targets + output_paths = [PurePath(path) for path in input_dict["sources"].keys()] else: output_paths = [PurePath(path)] - if output_paths[0] not in targets: + if str(output_paths[0]) not in input_dict["sources"]: raise JSONError(f"outputSelection references unknown contract '{output_paths[0]}'") for output_path in output_paths: @@ -281,9 +274,9 @@ def compile_from_input_dict( no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) - compilation_targets = get_compilation_targets(input_dict) sources = get_inputs(input_dict) - output_formats = get_output_formats(input_dict, compilation_targets) + output_formats = get_output_formats(input_dict) + compilation_targets = list(output_formats.keys()) input_bundle = JSONInputBundle(sources, search_paths=[Path(root_folder)]) @@ -295,12 +288,10 @@ def compile_from_input_dict( # use load_file to get a unique source_id file = input_bundle.load_file(contract_path) assert isinstance(file, FileInput) # mypy hint - data = vyper.compile_code( - file.source_code, - contract_name=str(file.path), + data = vyper.compile_from_file_input( + file, input_bundle=input_bundle, output_formats=output_formats[contract_path], - source_id=file.source_id, settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 5b79f293bd..dea30faabc 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -48,7 +48,7 @@ def __repr__(self): class Context: def __init__( self, - global_ctx, + module_ctx, memory_allocator, vars_=None, forvars=None, @@ -60,7 +60,7 @@ def __init__( self.vars = vars_ or {} # Global variables, in the form (name, storage location, type) - self.globals = global_ctx.variables + self.globals = module_ctx.variables # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} @@ -75,8 +75,8 @@ def __init__( # Whether we are currently parsing a range expression self.in_range_expr = False - # store global context - self.global_ctx = global_ctx + # store module context + self.module_ctx = module_ctx # full function type self.func_t = func_t diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index dc0e98786f..5870e64e98 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -47,8 +47,10 @@ StringT, StructT, TupleT, + is_type_t, ) from vyper.semantics.types.bytestrings import _BytestringT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T from vyper.utils import ( DECIMAL_DIVISOR, @@ -79,7 +81,7 @@ def __init__(self, node, context): self.ir_node = fn() if self.ir_node is None: - raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.", node) + raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.source_pos = getpos(self.expr) @@ -662,39 +664,38 @@ def parse_Call(self): if function_name in DISPATCH_TABLE: return DISPATCH_TABLE[function_name].build_IR(self.expr, self.context) - # Struct constructors do not need `self` prefix. - elif isinstance(self.expr._metadata["type"], StructT): - args = self.expr.args - if len(args) == 1 and isinstance(args[0], vy_ast.Dict): - return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) + func_type = self.expr.func._metadata["type"] - # Interface assignment. Bar(
). - elif isinstance(self.expr._metadata["type"], InterfaceT): - (arg0,) = self.expr.args - arg_ir = Expr(arg0, self.context).ir_node + # Struct constructor + if is_type_t(func_type, StructT): + args = self.expr.args + if len(args) == 1 and isinstance(args[0], vy_ast.Dict): + return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"]) - assert arg_ir.typ == AddressT() - arg_ir.typ = self.expr._metadata["type"] + # Interface constructor. Bar(
). + if is_type_t(func_type, InterfaceT): + (arg0,) = self.expr.args + arg_ir = Expr(arg0, self.context).ir_node - return arg_ir + assert arg_ir.typ == AddressT() + arg_ir.typ = self.expr._metadata["type"] - elif isinstance(self.expr.func, vy_ast.Attribute) and self.expr.func.attr == "pop": + return arg_ir + + if isinstance(func_type, MemberFunctionT) and self.expr.func.attr == "pop": # TODO consider moving this to builtins darray = Expr(self.expr.func.value, self.context).ir_node assert len(self.expr.args) == 0 assert isinstance(darray.typ, DArrayT) return pop_dyn_array(darray, return_popped_item=True) - elif ( - # TODO use expr.func.type.is_internal once - # type annotations are consistently available - isinstance(self.expr.func, vy_ast.Attribute) - and isinstance(self.expr.func.value, vy_ast.Name) - and self.expr.func.value.id == "self" - ): - return self_call.ir_for_self_call(self.expr, self.context) - else: - return external_call.ir_for_external_call(self.expr, self.context) + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.expr, self.context) + else: + return external_call.ir_for_external_call(self.expr, self.context) + + raise CompilerPanic("unreachable", self.expr) def parse_List(self): typ = self.expr._metadata["type"] diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index c48f1256c3..454ba9c8cd 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -7,13 +7,13 @@ from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator from vyper.exceptions import CompilerPanic from vyper.semantics.types import VyperType from vyper.semantics.types.function import ContractFunctionT -from vyper.utils import MemoryPositions, calc_mem_gas, mkalphanum +from vyper.semantics.types.module import ModuleT +from vyper.utils import MemoryPositions, calc_mem_gas @dataclass @@ -44,7 +44,14 @@ def exit_sequence_label(self) -> str: @cached_property def ir_identifier(self) -> str: argz = ",".join([str(argtyp) for argtyp in self.func_t.argument_types]) - return mkalphanum(f"{self.visibility} {self.func_t.name} ({argz})") + + name = self.func_t.name + function_id = self.func_t._function_id + assert function_id is not None + + # include module id in the ir identifier to disambiguate functions + # with the same name but which come from different modules + return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: if self.frame_info is not None: @@ -94,7 +101,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False + code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -103,7 +110,7 @@ def generate_ir_for_function( - Clamping and copying of arguments - Function body """ - func_t = code._metadata["type"] + func_t = code._metadata["func_type"] # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) @@ -126,7 +133,7 @@ def generate_ir_for_function( context = Context( vars_=None, - global_ctx=global_ctx, + module_ctx=module_ctx, memory_allocator=memory_allocator, constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant, func_t=func_t, diff --git a/vyper/codegen/global_context.py b/vyper/codegen/global_context.py deleted file mode 100644 index 1f6783f6f8..0000000000 --- a/vyper/codegen/global_context.py +++ /dev/null @@ -1,32 +0,0 @@ -from functools import cached_property -from typing import Optional - -from vyper import ast as vy_ast - - -# Datatype to store all global context information. -# TODO: rename me to ModuleT -class GlobalContext: - def __init__(self, module: Optional[vy_ast.Module] = None): - self._module = module - - @cached_property - def functions(self): - return self._module.get_children(vy_ast.FunctionDef) - - @cached_property - def variables(self): - # variables that this module defines, ex. - # `x: uint256` is a private storage variable named x - if self._module is None: # TODO: make self._module never be None - return None - variable_decls = self._module.get_children(vy_ast.VariableDecl) - return {s.target.id: s.target._metadata["varinfo"] for s in variable_decls} - - @property - def immutables(self): - return [t for t in self.variables.values() if t.is_immutable] - - @cached_property - def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index bfdafa8ba9..ef861e3953 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -5,49 +5,67 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr from vyper.codegen.function_definitions import generate_ir_for_function -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic -from vyper.utils import method_id_int +from vyper.semantics.types.module import ModuleT +from vyper.utils import OrderedSet, method_id_int -def _topsort_helper(functions, lookup): - # single pass to get a global topological sort of functions (so that each - # function comes after each of its callees). may have duplicates, which get - # filtered out in _topsort() +def _topsort(functions): + # single pass to get a global topological sort of functions (so that each + # function comes after each of its callees). + ret = OrderedSet() + for func_ast in functions: + fn_t = func_ast._metadata["func_type"] + + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t.ast_def) + + ret.add(func_ast) + + # create globally unique IDs for each function + for idx, func in enumerate(ret): + func._metadata["func_type"]._function_id = idx + + return list(ret) + - ret = [] +# calculate globally reachable functions to see which +# ones should make it into the final bytecode. +# TODO: in the future, this should get obsolesced by IR dead code eliminator. +def _globally_reachable_functions(functions): + ret = OrderedSet() for f in functions: - # called_functions is a list of ContractFunctions, need to map - # back to FunctionDefs. - callees = [lookup[t.name] for t in f._metadata["type"].called_functions] - ret.extend(_topsort_helper(callees, lookup)) - ret.append(f) + fn_t = f._metadata["func_type"] - return ret + if not fn_t.is_external: + continue + for reachable_t in fn_t.reachable_internal_functions: + assert reachable_t.ast_def is not None + ret.add(reachable_t) -def _topsort(functions): - lookup = {f.name: f for f in functions} - # strip duplicates - return list(dict.fromkeys(_topsort_helper(functions, lookup))) + ret.add(fn_t) + + return ret def _is_constructor(func_ast): - return func_ast._metadata["type"].is_constructor + return func_ast._metadata["func_type"].is_constructor def _is_fallback(func_ast): - return func_ast._metadata["type"].is_fallback + return func_ast._metadata["func_type"].is_fallback def _is_internal(func_ast): - return func_ast._metadata["type"].is_internal + return func_ast._metadata["func_type"].is_internal def _is_payable(func_ast): - return func_ast._metadata["type"].is_payable + return func_ast._metadata["func_type"].is_payable def _annotated_method_id(abi_sig): @@ -63,7 +81,7 @@ def label_for_entry_point(abi_sig, entry_point): # adapt whatever generate_ir_for_function gives us into an IR node def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): - func_t = func_ast._metadata["type"] + func_t = func_ast._metadata["func_type"] assert func_t.is_fallback or func_t.is_constructor ret = ["seq"] @@ -86,12 +104,12 @@ def _ir_for_internal_function(func_ast, *args, **kwargs): return generate_ir_for_function(func_ast, *args, **kwargs).func_ir -def _generate_external_entry_points(external_functions, global_ctx): +def _generate_external_entry_points(external_functions, module_ctx): entry_points = {} # map from ABI sigs to ir code sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, global_ctx) + func_ir = generate_ir_for_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -113,13 +131,13 @@ def _generate_external_entry_points(external_functions, global_ctx): # into a bucket (of about 8-10 items), and then uses perfect hash # to select the final function. # costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) -def _selector_section_dense(external_functions, global_ctx): +def _selector_section_dense(external_functions, module_ctx): function_irs = [] if len(external_functions) == 0: return IRnode.from_list(["seq"]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) # generate the label so the jumptable works for abi_sig, entry_point in entry_points.items(): @@ -264,13 +282,13 @@ def _selector_section_dense(external_functions, global_ctx): # a bucket, and then descends into linear search from there. # costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) # function and 24 bytes of code (+ ~23 bytes of global overhead) -def _selector_section_sparse(external_functions, global_ctx): +def _selector_section_sparse(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) @@ -367,14 +385,14 @@ def _selector_section_sparse(external_functions, global_ctx): # O(n) linear search for the method id # mainly keep this in for backends which cannot handle the indirect jump # in selector_section_dense and selector_section_sparse -def _selector_section_linear(external_functions, global_ctx): +def _selector_section_linear(external_functions, module_ctx): ret = ["seq"] if len(external_functions) == 0: return ret ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) - entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) dispatcher = ["seq"] @@ -402,10 +420,11 @@ def _selector_section_linear(external_functions, global_ctx): return ret -# take a GlobalContext, and generate the runtime and deploy IR -def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: +# take a ModuleT, and generate the runtime and deploy IR +def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees - function_defs = _topsort(global_ctx.functions) + function_defs = _topsort(module_ctx.function_defs) + reachable = _globally_reachable_functions(module_ctx.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) @@ -421,20 +440,26 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: # compile internal functions first so we have the function info for func_ast in internal_functions: - func_ir = _ir_for_internal_function(func_ast, global_ctx, False) - internal_functions_ir.append(IRnode.from_list(func_ir)) + # compile it so that _ir_info is populated (whether or not it makes + # it into the final IR artifact) + func_ir = _ir_for_internal_function(func_ast, module_ctx, False) + + # only include it in the IR if it is reachable from an external + # function. + if func_ast._metadata["func_type"] in reachable: + internal_functions_ir.append(IRnode.from_list(func_ir)) if core._opt_none(): - selector_section = _selector_section_linear(external_functions, global_ctx) + selector_section = _selector_section_linear(external_functions, module_ctx) # dense vs sparse global overhead is amortized after about 4 methods. # (--debug will force dense selector table anyway if _opt_codesize is selected.) elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): - selector_section = _selector_section_dense(external_functions, global_ctx) + selector_section = _selector_section_dense(external_functions, module_ctx) else: - selector_section = _selector_section_sparse(external_functions, global_ctx) + selector_section = _selector_section_sparse(external_functions, module_ctx) if default_function: - fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx) + fallback_ir = _ir_for_fallback_or_ctor(default_function, module_ctx) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -447,29 +472,30 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = global_ctx.immutable_section_bytes + immutables_len = module_ctx.immutable_section_bytes if init_function: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` + init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] internal_functions = [f for f in runtime_functions if _is_internal(f)] for f in internal_functions: - init_func_t = init_function._metadata["type"] - if f.name not in init_func_t.recursive_calls: + func_t = f._metadata["func_type"] + if func_t not in init_func_t.reachable_internal_functions: # unreachable code, delete it continue - func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True) + func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables # note: (deploy mem_ofst, code, extra_padding) - init_mem_used = init_function._metadata["type"]._ir_info.frame_info.mem_used + init_mem_used = init_function._metadata["func_type"]._ir_info.frame_info.mem_used # force msize to be initialized past the end of immutables section # so that builtins which use `msize` for "dynamic" memory diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f03f2eb9c8..f53e4a81b4 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -4,15 +4,6 @@ from vyper.exceptions import StateAccessViolation from vyper.semantics.types.subscriptable import TupleT -_label_counter = 0 - - -# TODO a more general way of doing this -def _generate_label(name: str) -> str: - global _label_counter - _label_counter += 1 - return f"label{_label_counter}" - def _align_kwargs(func_t, args_ir): """ @@ -63,7 +54,7 @@ def ir_for_self_call(stmt_expr, context): # note: internal_function_label asserts `func_t.is_internal`. _label = func_t._ir_info.internal_function_label(context.is_ctor_context) - return_label = _generate_label(f"{_label}_call") + return_label = _freshname(f"{_label}_call") # allocate space for the return buffer # TODO allocate in stmt and/or expr.py diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 254cad32e6..cc7a603b7c 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -26,6 +26,7 @@ from vyper.evm.address_space import MEMORY, STORAGE from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure from vyper.semantics.types import DArrayT, MemberFunctionT +from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -117,44 +118,32 @@ def parse_Log(self): return events.ir_node_for_log(self.stmt, event, topic_ir, data_ir, self.context) def parse_Call(self): - # TODO use expr.func.type.is_internal once type annotations - # are consistently available. - is_self_function = ( - (isinstance(self.stmt.func, vy_ast.Attribute)) - and isinstance(self.stmt.func.value, vy_ast.Name) - and self.stmt.func.value.id == "self" - ) - if isinstance(self.stmt.func, vy_ast.Name): funcname = self.stmt.func.id return STMT_DISPATCH_TABLE[funcname].build_IR(self.stmt, self.context) - elif isinstance(self.stmt.func, vy_ast.Attribute) and self.stmt.func.attr in ( - "append", - "pop", - ): - func_type = self.stmt.func._metadata["type"] - if isinstance(func_type, MemberFunctionT): - darray = Expr(self.stmt.func.value, self.context).ir_node - args = [Expr(x, self.context).ir_node for x in self.stmt.args] - if self.stmt.func.attr == "append": - # sanity checks - assert len(args) == 1 - arg = args[0] - assert isinstance(darray.typ, DArrayT) - check_assign( - dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) - ) - - return append_dyn_array(darray, arg) - else: - assert len(args) == 0 - return pop_dyn_array(darray, return_popped_item=False) - - if is_self_function: - return self_call.ir_for_self_call(self.stmt, self.context) - else: - return external_call.ir_for_external_call(self.stmt, self.context) + func_type = self.stmt.func._metadata["type"] + + if isinstance(func_type, MemberFunctionT) and self.stmt.func.attr in ("append", "pop"): + darray = Expr(self.stmt.func.value, self.context).ir_node + args = [Expr(x, self.context).ir_node for x in self.stmt.args] + if self.stmt.func.attr == "append": + (arg,) = args + assert isinstance(darray.typ, DArrayT) + check_assign( + dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) + ) + + return append_dyn_array(darray, arg) + else: + assert len(args) == 0 + return pop_dyn_array(darray, return_popped_item=False) + + if isinstance(func_type, ContractFunctionT): + if func_type.is_internal: + return self_call.ir_for_self_call(self.stmt, self.context) + else: + return external_call.ir_for_external_call(self.stmt, self.context) def _assert_reason(self, test_expr, msg): # from parse_Raise: None passed as the assert condition diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 61d7a7c229..026c8369c5 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -5,7 +5,7 @@ import vyper.ast as vy_ast # break an import cycle import vyper.codegen.core as codegen import vyper.compiler.output as output -from vyper.compiler.input_bundle import InputBundle, PathLike +from vyper.compiler.input_bundle import FileInput, InputBundle, PathLike from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version @@ -44,10 +44,8 @@ UNKNOWN_CONTRACT_NAME = "" -def compile_code( - contract_source: str, - contract_name: str = UNKNOWN_CONTRACT_NAME, - source_id: int = 0, +def compile_from_file_input( + file_input: FileInput, input_bundle: InputBundle = None, settings: Settings = None, output_formats: Optional[OutputFormats] = None, @@ -58,6 +56,8 @@ def compile_code( experimental_codegen: bool = False, ) -> dict: """ + Main entry point into the compiler. + Generate consumable compiler output(s) from a single contract source code. Basically, a wrapper around CompilerData which munges the output data into the requested output formats. @@ -72,6 +72,8 @@ def compile_code( evm_version: str, optional The target EVM ruleset to compile for. If not given, defaults to the latest implemented ruleset. + source_id: int, optional + source_id to tag AST nodes with. -1 if not provided. settings: Settings, optional Compiler settings. show_gas_estimates: bool, optional @@ -96,11 +98,11 @@ def compile_code( # make IR output the same between runs codegen.reset_names() + # TODO: maybe at this point we might as well just pass a `FileInput` + # directly to `CompilerData`. compiler_data = CompilerData( - contract_source, + file_input, input_bundle, - Path(contract_name), - source_id, settings, storage_layout_override, show_gas_estimates, @@ -118,8 +120,33 @@ def compile_code( ret[output_format] = formatter(compiler_data) except Exception as exc: if exc_handler is not None: - exc_handler(contract_name, exc) + exc_handler(str(file_input.path), exc) else: raise exc return ret + + +def compile_code( + source_code: str, + contract_path: str | PathLike = UNKNOWN_CONTRACT_NAME, + source_id: int = -1, + resolved_path: PathLike | None = None, + *args, + **kwargs, +): + # this function could be renamed to compile_from_string + """ + Do the same thing as compile_from_file_input but takes a string for source + code. This was previously the main entry point into the compiler + # (`compile_from_file_input()` is newer) + """ + if isinstance(contract_path, str): + contract_path = Path(contract_path) + file_input = FileInput( + source_id=source_id, + source_code=source_code, + path=contract_path, + resolved_path=resolved_path or contract_path, # type: ignore + ) + return compile_from_file_input(file_input, *args, **kwargs) diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index 1e41c3f137..27170f0a56 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -15,15 +15,11 @@ class CompilerInput: # an input to the compiler, basically an abstraction for file contents source_id: int - path: PathLike + path: PathLike # the path that was asked for - @staticmethod - def from_string(source_id: int, path: PathLike, file_contents: str) -> "CompilerInput": - try: - s = json.loads(file_contents) - return ABIInput(source_id, path, s) - except (ValueError, TypeError): - return FileInput(source_id, path, file_contents) + # resolved_path is the real path that was resolved to. + # mainly handy for debugging at this point + resolved_path: PathLike @dataclass @@ -40,13 +36,16 @@ class ABIInput(CompilerInput): abi: Any # something that json.load() returns -class _NotFound(Exception): - pass +def try_parse_abi(file_input: FileInput) -> CompilerInput: + try: + s = json.loads(file_input.source_code) + return ABIInput(file_input.source_id, file_input.path, file_input.resolved_path, s) + except (ValueError, TypeError): + return file_input -# wrap os.path.normpath, but return the same type as the input -def _normpath(path): - return path.__class__(os.path.normpath(path)) +class _NotFound(Exception): + pass # an "input bundle" to the compiler, representing the files which are @@ -60,20 +59,31 @@ class InputBundle: # a list of search paths search_paths: list[PathLike] + _cache: Any + def __init__(self, search_paths): self.search_paths = search_paths self._source_id_counter = 0 self._source_ids: dict[PathLike, int] = {} - def _load_from_path(self, path): + # this is a little bit cursed, but it allows consumers to cache data that + # share the same lifetime as this input bundle. + self._cache = lambda: None + + def _normalize_path(self, path): + raise NotImplementedError(f"not implemented! {self.__class__}._normalize_path()") + + def _load_from_path(self, resolved_path, path): raise NotImplementedError(f"not implemented! {self.__class__}._load_from_path()") - def _generate_source_id(self, path: PathLike) -> int: - if path not in self._source_ids: - self._source_ids[path] = self._source_id_counter + def _generate_source_id(self, resolved_path: PathLike) -> int: + # Note: it is possible for a file to get in here more than once, + # e.g. by symlink + if resolved_path not in self._source_ids: + self._source_ids[resolved_path] = self._source_id_counter self._source_id_counter += 1 - return self._source_ids[path] + return self._source_ids[resolved_path] def load_file(self, path: PathLike | str) -> CompilerInput: # search path precedence @@ -84,12 +94,9 @@ def load_file(self, path: PathLike | str) -> CompilerInput: # Path("/a") / Path("/b") => Path("/b") to_try = sp / path - # normalize the path with os.path.normpath, to break down - # things like "foo/bar/../x.vy" => "foo/x.vy", with all - # the caveats around symlinks that os.path.normpath comes with. - to_try = _normpath(to_try) try: - res = self._load_from_path(to_try) + to_try = self._normalize_path(to_try) + res = self._load_from_path(to_try, path) break except _NotFound: tried.append(to_try) @@ -104,7 +111,7 @@ def load_file(self, path: PathLike | str) -> CompilerInput: # try to parse from json, so that return types are consistent # across FilesystemInputBundle and JSONInputBundle. if isinstance(res, FileInput): - return CompilerInput.from_string(res.source_id, res.path, res.source_code) + res = try_parse_abi(res) return res @@ -126,20 +133,45 @@ def search_path(self, path: Optional[PathLike]) -> Iterator[None]: finally: self.search_paths.pop() + # temporarily modify the top of the search path (within the + # scope of the context manager) with highest precedence to something else + @contextlib.contextmanager + def poke_search_path(self, path: PathLike) -> Iterator[None]: + tmp = self.search_paths[-1] + self.search_paths[-1] = path + try: + yield + finally: + self.search_paths[-1] = tmp + # regular input. takes a search path(s), and `load_file()` will search all # search paths for the file and read it from the filesystem class FilesystemInputBundle(InputBundle): - def _load_from_path(self, path: Path) -> CompilerInput: + def _normalize_path(self, path: Path) -> Path: + # normalize the path with os.path.normpath, to break down + # things like "foo/bar/../x.vy" => "foo/x.vy", with all + # the caveats around symlinks that os.path.normpath comes with. try: - with path.open() as f: - code = f.read() - except FileNotFoundError: + return path.resolve(strict=True) + except (FileNotFoundError, NotADirectoryError): raise _NotFound(path) - source_id = super()._generate_source_id(path) + def _load_from_path(self, resolved_path: Path, original_path: Path) -> CompilerInput: + try: + with resolved_path.open() as f: + code = f.read() + except (FileNotFoundError, NotADirectoryError): + raise _NotFound(resolved_path) + + source_id = super()._generate_source_id(resolved_path) + + return FileInput(source_id, original_path, resolved_path, code) - return FileInput(source_id, path, code) + +# wrap os.path.normpath, but return the same type as the input +def _normpath(path): + return path.__class__(os.path.normpath(path)) # fake filesystem for JSON inputs. takes a base path, and `load_file()` @@ -156,25 +188,28 @@ def __init__(self, input_json, search_paths): # should be checked by caller assert path not in self.input_json - self.input_json[_normpath(path)] = item + self.input_json[path] = item + + def _normalize_path(self, path: PurePath) -> PurePath: + return _normpath(path) - def _load_from_path(self, path: PurePath) -> CompilerInput: + def _load_from_path(self, resolved_path: PurePath, original_path: PurePath) -> CompilerInput: try: - value = self.input_json[path] + value = self.input_json[resolved_path] except KeyError: - raise _NotFound(path) + raise _NotFound(resolved_path) - source_id = super()._generate_source_id(path) + source_id = super()._generate_source_id(resolved_path) if "content" in value: - return FileInput(source_id, path, value["content"]) + return FileInput(source_id, original_path, resolved_path, value["content"]) if "abi" in value: - return ABIInput(source_id, path, value["abi"]) + return ABIInput(source_id, original_path, resolved_path, value["abi"]) # TODO: ethPM support # if isinstance(contents, dict) and "contractTypes" in contents: # unreachable, based on how JSONInputBundle is constructed in # the codebase. - raise JSONError(f"Unexpected type in file: '{path}'") # pragma: nocover + raise JSONError(f"Unexpected type in file: '{resolved_path}'") # pragma: nocover diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index e47f300ba9..6d1e7ef70f 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict, deque +from pathlib import PurePath import asttokens @@ -33,8 +34,8 @@ def build_userdoc(compiler_data: CompilerData) -> dict: def build_external_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"] - stem = compiler_data.contract_path.stem + interface = compiler_data.vyper_module_folded._metadata["type"].interface + stem = PurePath(compiler_data.contract_path).stem # capitalize words separated by '_' # ex: test_interface.vy -> TestInterface name = "".join([x.capitalize() for x in stem.split("_")]) @@ -52,7 +53,7 @@ def build_external_interface_output(compiler_data: CompilerData) -> str: def build_interface_output(compiler_data: CompilerData) -> str: - interface = compiler_data.vyper_module_folded._metadata["type"] + interface = compiler_data.vyper_module_folded._metadata["type"].interface out = "" if interface.events: @@ -70,7 +71,7 @@ def build_interface_output(compiler_data: CompilerData) -> str: out = f"{out}@{func.mutability.value}\n" args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments]) return_value = f" -> {func.return_type}" if func.return_type is not None else "" - out = f"{out}@external\ndef {func.name}({args}){return_value}:\n pass\n\n" + out = f"{out}@external\ndef {func.name}({args}){return_value}:\n ...\n\n" return out @@ -154,14 +155,19 @@ def _to_dict(func_t): def build_method_identifiers_output(compiler_data: CompilerData) -> dict: - interface = compiler_data.vyper_module_folded._metadata["type"] - functions = interface.functions.values() + module_t = compiler_data.vyper_module_folded._metadata["type"] + functions = module_t.function_defs - return {k: hex(v) for func in functions for k, v in func.method_ids.items()} + return { + k: hex(v) for func in functions for k, v in func._metadata["func_type"].method_ids.items() + } def build_abi_output(compiler_data: CompilerData) -> list: - abi = compiler_data.vyper_module_folded._metadata["type"].to_toplevel_abi_dict() + module_t = compiler_data.vyper_module_folded._metadata["type"] + _ = compiler_data.ir_runtime # ensure _ir_info is generated + + abi = module_t.interface.to_toplevel_abi_dict() if compiler_data.show_gas_estimates: # Add gas estimates for each function to ABI gas_estimates = build_gas_estimates(compiler_data.function_signatures) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4e32812fee..edffa9a85e 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -7,18 +7,18 @@ from vyper import ast as vy_ast from vyper.codegen import module from vyper.codegen.core import anchor_opt_level -from vyper.codegen.global_context import GlobalContext from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle +from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer from vyper.semantics import set_data_positions, validate_semantics from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout from vyper.venom import generate_assembly_experimental, generate_ir -DEFAULT_CONTRACT_NAME = PurePath("VyperContract.vy") +DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") class CompilerData: @@ -35,7 +35,7 @@ class CompilerData: Top-level Vyper AST node vyper_module_folded : vy_ast.Module Folded Vyper AST - global_ctx : GlobalContext + global_ctx : ModuleT Sorted, contextualized representation of the Vyper AST ir_nodes : IRnode IR used to generate deployment bytecode @@ -53,10 +53,8 @@ class CompilerData: def __init__( self, - source_code: str, + file_input: FileInput | str, input_bundle: InputBundle = None, - contract_path: Path | PurePath = DEFAULT_CONTRACT_NAME, - source_id: int = 0, settings: Settings = None, storage_layout: StorageLayout = None, show_gas_estimates: bool = False, @@ -68,12 +66,10 @@ def __init__( Arguments --------- - source_code: str - Vyper source code. - contract_path: Path, optional - The name of the contract being compiled. - source_id: int, optional - ID number used to identify this contract in the source map. + file_input: FileInput | str + A FileInput or string representing the input to the compiler. + FileInput is preferred, but `str` is accepted as a convenience + method (and also for backwards compatibility reasons) settings: Settings Set optimization mode. show_gas_estimates: bool, optional @@ -85,9 +81,15 @@ def __init__( """ # to force experimental codegen, uncomment: # experimental_codegen = True - self.contract_path = contract_path - self.source_code = source_code - self.source_id = source_id + + if isinstance(file_input, str): + file_input = FileInput( + source_code=file_input, + source_id=-1, + path=DEFAULT_CONTRACT_PATH, + resolved_path=DEFAULT_CONTRACT_PATH, + ) + self.file_input = file_input self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata @@ -97,10 +99,26 @@ def __init__( _ = self._generate_ast # force settings to be calculated + @cached_property + def source_code(self): + return self.file_input.source_code + + @cached_property + def source_id(self): + return self.file_input.source_id + + @cached_property + def contract_path(self): + return self.file_input.path + @cached_property def _generate_ast(self): - contract_name = str(self.contract_path) - settings, ast = generate_ast(self.source_code, self.source_id, contract_name) + settings, ast = vy_ast.parse_to_ast_with_settings( + self.source_code, + self.source_id, + module_path=str(self.contract_path), + resolved_path=str(self.file_input.resolved_path), + ) # validate the compiler settings # XXX: this is a bit ugly, clean up later @@ -141,12 +159,12 @@ def vyper_module_unfolded(self) -> vy_ast.Module: # This phase is intended to generate an AST for tooling use, and is not # used in the compilation process. - return generate_unfolded_ast(self.contract_path, self.vyper_module, self.input_bundle) + return generate_unfolded_ast(self.vyper_module, self.input_bundle) @cached_property def _folded_module(self): return generate_folded_ast( - self.contract_path, self.vyper_module, self.input_bundle, self.storage_layout_override + self.vyper_module, self.input_bundle, self.storage_layout_override ) @property @@ -160,8 +178,8 @@ def storage_layout(self) -> StorageLayout: return storage_layout @property - def global_ctx(self) -> GlobalContext: - return GlobalContext(self.vyper_module_folded) + def global_ctx(self) -> ModuleT: + return self.vyper_module_folded._metadata["type"] @cached_property def _ir_output(self): @@ -189,7 +207,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: _ = self._ir_output fs = self.vyper_module_folded.get_children(vy_ast.FunctionDef) - return {f.name: f._metadata["type"] for f in fs} + return {f.name: f._metadata["func_type"] for f in fs} @cached_property def assembly(self) -> list: @@ -230,37 +248,12 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_ast( - source_code: str, source_id: int, contract_name: str -) -> tuple[Settings, vy_ast.Module]: - """ - Generate a Vyper AST from source code. - - Arguments - --------- - source_code : str - Vyper source code. - source_id : int - ID number used to identify this contract in the source map. - contract_name: str - Name of the contract. - - Returns - ------- - vy_ast.Module - Top-level Vyper AST node - """ - return vy_ast.parse_to_ast_with_settings(source_code, source_id, contract_name) - - # destructive -- mutates module in place! -def generate_unfolded_ast( - contract_path: Path | PurePath, vyper_module: vy_ast.Module, input_bundle: InputBundle -) -> vy_ast.Module: +def generate_unfolded_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: vy_ast.validation.validate_literal_nodes(vyper_module) vy_ast.folding.replace_builtin_functions(vyper_module) - with input_bundle.search_path(contract_path.parent): + with input_bundle.search_path(Path(vyper_module.resolved_path).parent): # note: validate_semantics does type inference on the AST validate_semantics(vyper_module, input_bundle) @@ -268,7 +261,6 @@ def generate_unfolded_ast( def generate_folded_ast( - contract_path: Path, vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, @@ -294,7 +286,7 @@ def generate_folded_ast( vyper_module_folded = copy.deepcopy(vyper_module) vy_ast.folding.fold(vyper_module_folded) - with input_bundle.search_path(contract_path.parent): + with input_bundle.search_path(Path(vyper_module.resolved_path).parent): validate_semantics(vyper_module_folded, input_bundle) symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) @@ -302,9 +294,7 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes( - global_ctx: GlobalContext, optimize: OptimizationLevel -) -> tuple[IRnode, IRnode]: +def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. @@ -315,7 +305,7 @@ def generate_ir_nodes( Arguments --------- - global_ctx : GlobalContext + global_ctx: ModuleT Contextualized Vyper AST Returns diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 3bde20356e..993c0a85eb 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -49,6 +49,7 @@ def __init__(self, message="Error Message not found.", *items): self.message = message self.lineno = None self.col_offset = None + self.annotations = None if len(items) == 1 and isinstance(items[0], tuple) and isinstance(items[0][0], int): # support older exceptions that don't annotate - remove this in the future! @@ -79,7 +80,7 @@ def __str__(self): from vyper import ast as vy_ast from vyper.utils import annotate_source_code - if not hasattr(self, "annotations"): + if not self.annotations: if self.lineno is not None and self.col_offset is not None: return f"line {self.lineno}:{self.col_offset} {self.message}" else: @@ -105,8 +106,9 @@ def __str__(self): if isinstance(node, vy_ast.VyperNode): module_node = node.get_ancestor(vy_ast.Module) - if module_node.get("name") not in (None, ""): - node_msg = f'{node_msg}contract "{module_node.name}:{node.lineno}", ' + + if module_node.get("path") not in (None, ""): + node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", ' fn_node = node.get_ancestor(vy_ast.FunctionDef) if fn_node: @@ -229,6 +231,18 @@ class CallViolation(VyperException): """Illegal function call.""" +class ImportCycle(VyperException): + """An import cycle""" + + +class DuplicateImport(VyperException): + """A module was imported twice from the same module""" + + +class ModuleNotFound(VyperException): + """Module was not found""" + + class ImmutableViolation(VyperException): """Modifying an immutable variable, constant, or definition.""" diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 7db230167e..7b52a68e92 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -1,17 +1,4 @@ -import vyper.ast as vy_ast - from .. import types # break a dependency cycle. -from ..namespace import get_namespace -from .local import validate_functions -from .module import add_module_namespace -from .utils import _ExprAnalyser - - -def validate_semantics(vyper_ast, input_bundle): - # validate semantics and annotate AST with type/semantics information - namespace = get_namespace() +from .module import validate_semantics - with namespace.enter_scope(): - add_module_namespace(vyper_ast, input_bundle) - vy_ast.expansion.expand_annotated_ast(vyper_ast) - validate_functions(vyper_ast) +__all__ = ["validate_semantics"] diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 449e6ca338..4d1b1cdbab 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,8 +1,9 @@ import enum from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from vyper import ast as vy_ast +from vyper.compiler.input_bundle import InputBundle from vyper.exceptions import ( CompilerPanic, ImmutableViolation, @@ -12,6 +13,9 @@ from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +if TYPE_CHECKING: + from vyper.semantics.types.module import InterfaceT, ModuleT + class _StringEnum(enum.Enum): @staticmethod @@ -145,6 +149,35 @@ def __repr__(self): return f"" +# base class for things that are the "result" of analysis +class AnalysisResult: + pass + + +@dataclass +class ModuleInfo(AnalysisResult): + module_t: "ModuleT" + + @property + def module_node(self): + return self.module_t._module + + # duck type, conform to interface of VarInfo and ExprInfo + @property + def typ(self): + return self.module_t + + +@dataclass +class ImportInfo(AnalysisResult): + typ: Union[ModuleInfo, "InterfaceT"] + alias: str # the name in the namespace + qualified_module_name: str # for error messages + # source_id: int + input_bundle: InputBundle + node: vy_ast.VyperNode + + @dataclass class VarInfo: """ @@ -212,6 +245,10 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": is_immutable=var_info.is_immutable, ) + @classmethod + def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": + return cls(module_info.module_t) + def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 507eb0a570..9d35aef2bd 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -1,6 +1,17 @@ +import contextlib from typing import Tuple -from vyper.exceptions import StructureException +from vyper.exceptions import StructureException, VyperException + + +@contextlib.contextmanager +def tag_exceptions(node): + try: + yield + except VyperException as e: + if not e.annotations and not e.lineno: + raise e.with_annotation(node) from None + raise e from None class VyperNodeVisitorBase: @@ -16,9 +27,11 @@ def visit(self, node, *args): # node types with a shared parent for class_ in node.__class__.mro(): ast_type = class_.__name__ - visitor_fn = getattr(self, f"visit_{ast_type}", None) - if visitor_fn: - return visitor_fn(node, *args) + + with tag_exceptions(node): + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + return visitor_fn(node, *args) node_type = type(node).__name__ raise StructureException( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 87ec45c40d..88679a4b09 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -79,7 +79,7 @@ def set_storage_slots_with_overrides( # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["type"] + type_ = node._metadata["func_type"] # Ignore functions without non-reentrant if type_.nonreentrant is None: @@ -165,7 +165,7 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: ret: Dict[str, Dict] = {} for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["type"] + type_ = node._metadata["func_type"] if type_.nonreentrant is None: continue diff --git a/vyper/semantics/analysis/import_graph.py b/vyper/semantics/analysis/import_graph.py new file mode 100644 index 0000000000..e406878194 --- /dev/null +++ b/vyper/semantics/analysis/import_graph.py @@ -0,0 +1,37 @@ +import contextlib +from dataclasses import dataclass, field +from typing import Iterator + +from vyper import ast as vy_ast +from vyper.exceptions import CompilerPanic, ImportCycle + +""" +data structure for collecting import statements and validating the +import graph +""" + + +@dataclass +class ImportGraph: + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + def push_path(self, module_ast: vy_ast.Module) -> None: + if module_ast in self._path: + cycle = self._path + [module_ast] + raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) + + self._path.append(module_ast) + + def pop_path(self, expected: vy_ast.Module) -> None: + popped = self._path.pop() + if expected != popped: + raise CompilerPanic("unreachable") + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 647f01c299..974c14f261 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -55,14 +55,15 @@ def validate_functions(vy_module: vy_ast.Module) -> None: - """Analyzes a vyper ast and validates the function-level namespaces.""" + """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() namespace = get_namespace() for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - FunctionNodeVisitor(vy_module, node, namespace) + analyzer = FunctionNodeVisitor(vy_module, node, namespace) + analyzer.analyze() except VyperException as e: err_list.append(e) @@ -185,26 +186,31 @@ def __init__( self.vyper_module = vyper_module self.fn_node = fn_node self.namespace = namespace - self.func = fn_node._metadata["type"] + self.func = fn_node._metadata["func_type"] self.expr_visitor = _ExprVisitor(self.func) + def analyze(self): # allow internal function params to be mutable location, is_immutable = ( (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) ) for arg in self.func.arguments: - namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable) + self.namespace[arg.name] = VarInfo( + arg.typ, location=location, is_immutable=is_immutable + ) - for node in fn_node.body: + for node in self.fn_node.body: self.visit(node) + if self.func.return_type: - if not check_for_terminus(fn_node.body): + if not check_for_terminus(self.fn_node.body): raise FunctionDeclarationException( - f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node + f"Missing or unmatched return statements in function '{self.fn_node.name}'", + self.fn_node, ) # visit default args - assert self.func.n_keyword_args == len(fn_node.args.defaults) + assert self.func.n_keyword_args == len(self.fn_node.args.defaults) for kwarg in self.func.keyword_args: self.expr_visitor.visit(kwarg.default_value, kwarg.typ) @@ -224,10 +230,7 @@ def visit_AnnAssign(self, node): typ = type_from_annotation(node.annotation, DataLocation.MEMORY) validate_expected_type(node.value, typ) - try: - self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) self.expr_visitor.visit(node.target, typ) self.expr_visitor.visit(node.value, typ) @@ -290,6 +293,13 @@ def visit_Continue(self, node): raise StructureException("`continue` must be enclosed in a `for` loop", node) def visit_Expr(self, node): + if isinstance(node.value, vy_ast.Ellipsis): + raise StructureException( + "`...` is not allowed in `.vy` files! " + "Did you mean to import me as a `.vyi` file?", + node, + ) + if not isinstance(node.value, vy_ast.Call): raise StructureException("Expressions without assignment are disallowed", node) @@ -433,6 +443,7 @@ def visit_For(self, node): # Check if `iter` is a storage variable. get_descendants` is used to check for # nested `self` (e.g. structs) + # NOTE: this analysis will be borked once stateful modules are allowed! iter_is_storage_var = ( isinstance(node.iter, vy_ast.Attribute) and len(node.iter.get_descendants(vy_ast.Name, {"id": "self"})) > 0 @@ -453,8 +464,11 @@ def visit_For(self, node): call_node, ) - for name in self.namespace["self"].typ.members[fn_name].recursive_calls: + for reachable_t in ( + self.namespace["self"].typ.members[fn_name].reachable_internal_functions + ): # check for indirect modification + name = reachable_t.name fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] if _check_iterator_modification(node.iter, fn_node): raise ImmutableViolation( @@ -472,10 +486,7 @@ def visit_For(self, node): # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): - try: - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) try: with NodeMetadata.enter_typechecker_speculation(): diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 239438f35b..7aa661aec3 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,6 +1,6 @@ import os from pathlib import Path, PurePath -from typing import Optional +from typing import Any, Optional import vyper.builtins.interfaces from vyper import ast as vy_ast @@ -8,9 +8,11 @@ from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, + DuplicateImport, ExceptionList, InvalidLiteral, InvalidType, + ModuleNotFound, NamespaceCollision, StateAccessViolation, StructureException, @@ -18,128 +20,200 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase -from vyper.semantics.analysis.utils import check_constant, validate_expected_type +from vyper.semantics.analysis.import_graph import ImportGraph +from vyper.semantics.analysis.local import validate_functions +from vyper.semantics.analysis.utils import ( + check_constant, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation -from vyper.semantics.namespace import Namespace, get_namespace +from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation -def add_module_namespace(vy_module: vy_ast.Module, input_bundle: InputBundle) -> None: +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) + + +def validate_semantics_r( + module_ast: vy_ast.Module, + input_bundle: InputBundle, + import_graph: ImportGraph, + is_interface: bool, +) -> ModuleT: """ Analyze a Vyper module AST node, add all module-level objects to the - namespace and validate top-level correctness + namespace, type-check/validate semantics and annotate with type and analysis info """ - + # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - ModuleAnalyzer(vy_module, input_bundle, namespace) + with namespace.enter_scope(), import_graph.enter_path(module_ast): + analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) + ret = analyzer.analyze() + + vy_ast.expansion.generate_public_variable_getters(module_ast) + + # if this is an interface, the function is already validated + # in `ContractFunction.from_vyi()` + if not is_interface: + validate_functions(module_ast) + + return ret + + +# compute reachable set and validate the call graph (detect cycles) +def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT] = None) -> None: + path = path or [] + + path.append(fn_t) + root = path[0] -def _find_cyclic_call(fn_names: list, self_members: dict) -> Optional[list]: - if fn_names[-1] not in self_members: - return None - internal_calls = self_members[fn_names[-1]].internal_calls - for name in internal_calls: - if name in fn_names: - return fn_names + [name] - sequence = _find_cyclic_call(fn_names + [name], self_members) - if sequence: - return sequence - return None + for g in fn_t.called_functions: + if g == root: + message = " -> ".join([f.name for f in path]) + raise CallViolation(f"Contract contains cyclic function call: {message}") + + _compute_reachable_set(g, path=path) + + for h in g.reachable_internal_functions: + assert h != fn_t # sanity check + + fn_t.reachable_internal_functions.add(h) + + fn_t.reachable_internal_functions.add(g) + + path.pop() class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, module_node: vy_ast.Module, input_bundle: InputBundle, namespace: Namespace + self, + module_node: vy_ast.Module, + input_bundle: InputBundle, + namespace: Namespace, + import_graph: ImportGraph, + is_interface: bool = False, ) -> None: self.ast = module_node self.input_bundle = input_bundle self.namespace = namespace + self._import_graph = import_graph + self.is_interface = is_interface - # TODO: Move computation out of constructor - module_nodes = module_node.body.copy() - while module_nodes: - count = len(module_nodes) + # keep track of imported modules to prevent duplicate imports + self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} + + self.module_t: Optional[ModuleT] = None + + # ast cache, hitchhike onto the input_bundle object + if not hasattr(self.input_bundle._cache, "_ast_of"): + self.input_bundle._cache._ast_of: dict[int, vy_ast.Module] = {} # type: ignore + + def analyze(self) -> ModuleT: + # generate a `ModuleT` from the top-level node + # note: also validates unique method ids + if "type" in self.ast._metadata: + assert isinstance(self.ast._metadata["type"], ModuleT) + # we don't need to analyse again, skip out + self.module_t = self.ast._metadata["type"] + return self.module_t + + to_visit = self.ast.body.copy() + + # handle imports linearly + # (do this instead of handling in the next block so that + # `self._imported_modules` does not end up with garbage in it after + # exception swallowing). + import_stmts = self.ast.get_children((vy_ast.Import, vy_ast.ImportFrom)) + for node in import_stmts: + self.visit(node) + to_visit.remove(node) + + # keep trying to process all the nodes until we finish or can + # no longer progress. this makes it so we don't need to + # calculate a dependency tree between top-level items. + while len(to_visit) > 0: + count = len(to_visit) err_list = ExceptionList() - for node in list(module_nodes): + for node in to_visit.copy(): try: self.visit(node) - module_nodes.remove(node) - except (InvalidLiteral, InvalidType, VariableDeclarationException): + to_visit.remove(node) + except (InvalidLiteral, InvalidType, VariableDeclarationException) as e: # these exceptions cannot be caused by another statement not yet being # parsed, so we raise them immediately - raise + raise e from None except VyperException as e: err_list.append(e) # Only raise if no nodes were successfully processed. This allows module # level logic to parse regardless of the ordering of code elements. - if count == len(module_nodes): + if count == len(to_visit): err_list.raise_if_not_empty() - # generate an `InterfaceT` from the top-level node - used for building the ABI - # note: also validates unique method ids - interface = InterfaceT.from_ast(module_node) - module_node._metadata["type"] = interface - self.interface = interface # this is useful downstream + self.module_t = ModuleT(self.ast) + self.ast._metadata["type"] = self.module_t # attach namespace to the module for downstream use. _ns = Namespace() # note that we don't just copy the namespace because # there are constructor issues. - _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore - module_node._metadata["namespace"] = _ns + _ns.update({k: self.namespace[k] for k in self.namespace._scopes[-1]}) # type: ignore + self.ast._metadata["namespace"] = _ns + + self.analyze_call_graph() - self_members = namespace["self"].typ.members + return self.module_t + def analyze_call_graph(self): # get list of internal function calls made by each function - function_defs = self.ast.get_children(vy_ast.FunctionDef) - function_names = set(node.name for node in function_defs) - for node in function_defs: - calls_to_self = set( - i.func.attr for i in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}) - ) - # anything that is not a function call will get semantically checked later - calls_to_self = calls_to_self.intersection(function_names) - self_members[node.name].internal_calls = calls_to_self - - for fn_name in sorted(function_names): - if fn_name not in self_members: - # the referenced function does not exist - this is an issue, but we'll report - # it later when parsing the function so we can give more meaningful output - continue - - # check for circular function calls - sequence = _find_cyclic_call([fn_name], self_members) - if sequence is not None: - nodes = [] - for i in range(len(sequence) - 1): - fn_node = self.ast.get_children(vy_ast.FunctionDef, {"name": sequence[i]})[0] - call_node = fn_node.get_descendants( - vy_ast.Attribute, {"value.id": "self", "attr": sequence[i + 1]} - )[0] - nodes.append(call_node) - - raise CallViolation("Contract contains cyclic function call", *nodes) - - # get complete list of functions that are reachable from this function - function_set = set(i for i in self_members[fn_name].internal_calls if i in self_members) - while True: - expanded = set(x for i in function_set for x in self_members[i].internal_calls) - expanded |= function_set - if expanded == function_set: - break - function_set = expanded - - self_members[fn_name].recursive_calls = function_set + function_defs = self.module_t.function_defs + + for func in function_defs: + fn_t = func._metadata["func_type"] + + function_calls = func.get_descendants(vy_ast.Call) + + for call in function_calls: + try: + call_t = get_exact_type_from_node(call.func) + except VyperException: + # either there is a problem getting the call type. this is + # an issue, but it will be handled properly later. right now + # we just want to be able to construct the call graph. + continue + + if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + fn_t.called_functions.add(call_t) + + for func in function_defs: + fn_t = func._metadata["func_type"] + + # compute reachable set and validate the call graph + _compute_reachable_set(fn_t) + + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self.input_bundle._cache._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_and_fold_ast(file) + + return ast_of[file.source_id] def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) + if not isinstance(type_, InterfaceT): raise StructureException("Invalid interface name", node.annotation) @@ -153,8 +227,9 @@ def visit_VariableDecl(self, node): if node.is_public: # generate function type and add to metadata # we need this when building the public getter - node._metadata["func_type"] = ContractFunctionT.getter_from_VariableDecl(node) + node._metadata["getter_type"] = ContractFunctionT.getter_from_VariableDecl(node) + # TODO: move this check to local analysis if node.is_immutable: # mutability is checked automatically preventing assignment # outside of the constructor, here we just check a value is assigned, @@ -213,22 +288,18 @@ def _finalize(): self.namespace["self"].typ.add_member(name, var_info) node.target._metadata["type"] = type_ except NamespaceCollision: + # rewrite the error message to be slightly more helpful raise NamespaceCollision( f"Value '{name}' has already been declared", node ) from None - except VyperException as exc: - raise exc.with_annotation(node) from None def _validate_self_namespace(): # block globals if storage variable already exists - try: - if name in self.namespace["self"].typ.members: - raise NamespaceCollision( - f"Value '{name}' has already been declared", node - ) from None - self.namespace[name] = var_info - except VyperException as exc: - raise exc.with_annotation(node) from None + if name in self.namespace["self"].typ.members: + raise NamespaceCollision( + f"Value '{name}' has already been declared", node + ) from None + self.namespace[name] = var_info if node.is_constant: if not node.value: @@ -251,41 +322,50 @@ def _validate_self_namespace(): _validate_self_namespace() return _finalize() - try: - self.namespace.validate_assignment(name) - except NamespaceCollision as exc: - raise exc.with_annotation(node) from None + self.namespace.validate_assignment(name) return _finalize() def visit_EnumDef(self, node): obj = EnumT.from_EnumDef(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace[node.name] = obj def visit_EventDef(self, node): obj = EventT.from_EventDef(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + node._metadata["event_type"] = obj + self.namespace[node.name] = obj def visit_FunctionDef(self, node): - func = ContractFunctionT.from_FunctionDef(node) + if self.is_interface: + func_t = ContractFunctionT.from_vyi(node) + if not func_t.is_external: + # TODO test me! + raise StructureException( + "Internal functions in `.vyi` files are not allowed!", node + ) + else: + func_t = ContractFunctionT.from_FunctionDef(node) - try: - self.namespace["self"].typ.add_member(func.name, func) - node._metadata["type"] = func - except VyperException as exc: - raise exc.with_annotation(node) from None + self.namespace["self"].typ.add_member(func_t.name, func_t) + node._metadata["func_type"] = func_t def visit_Import(self, node): - if not node.alias: - raise StructureException("Import requires an accompanying `as` statement", node) # import x.y[name] as y[alias] - self._add_import(node, 0, node.name, node.alias) + + alias = node.alias + + if alias is None: + alias = node.name + + # don't handle things like `import x.y` + if "." in alias: + suggested_alias = node.name[node.name.rfind(".") :] + suggestion = f"hint: try `import {node.name} as {suggested_alias}`" + raise StructureException( + f"import requires an accompanying `as` statement ({suggestion})", node + ) + + self._add_import(node, 0, node.name, alias) def visit_ImportFrom(self, node): # from m.n[module] import x[name] as y[alias] @@ -299,42 +379,87 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_ast(node) - try: - self.namespace[node.name] = obj - except VyperException as exc: - raise exc.with_annotation(node) from None + obj = InterfaceT.from_InterfaceDef(node) + self.namespace[node.name] = obj def visit_StructDef(self, node): - struct_t = StructT.from_ast_def(node) - try: - self.namespace[node.name] = struct_t - except VyperException as exc: - raise exc.with_annotation(node) from None + struct_t = StructT.from_StructDef(node) + node._metadata["struct_type"] = struct_t + self.namespace[node.name] = struct_t def _add_import( self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str ) -> None: - type_ = self._load_import(level, qualified_module_name) - - try: - self.namespace[alias] = type_ - except VyperException as exc: - raise exc.with_annotation(node) from None + module_info = self._load_import(node, level, qualified_module_name, alias) + node._metadata["import_info"] = ImportInfo( + module_info, alias, qualified_module_name, self.input_bundle, node + ) + self.namespace[alias] = module_info - # load an InterfaceT from an import. + # load an InterfaceT or ModuleInfo from an import. # raises FileNotFoundError - def _load_import(self, level: int, module_str: str) -> InterfaceT: + def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: + # the directory this (currently being analyzed) module is in + self_search_path = Path(self.ast.resolved_path).parent + + with self.input_bundle.poke_search_path(self_search_path): + return self._load_import_helper(node, level, module_str, alias) + + def _load_import_helper( + self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str + ) -> Any: if _is_builtin(module_str): return _load_builtin_import(level, module_str) path = _import_to_path(level, module_str) + # this could conceivably be in the ImportGraph but no need at this point + if path in self._imported_modules: + previous_import_stmt = self._imported_modules[path] + raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + + self._imported_modules[path] = node + + err = None + + try: + path_vy = path.with_suffix(".vy") + file = self.input_bundle.load_file(path_vy) + assert isinstance(file, FileInput) # mypy hint + + module_ast = self._ast_from_file(file) + + with override_global_namespace(Namespace()): + module_t = validate_semantics_r( + module_ast, + self.input_bundle, + import_graph=self._import_graph, + is_interface=False, + ) + + return ModuleInfo(module_t) + + except FileNotFoundError as e: + # escape `e` from the block scope, it can make things + # easier to debug. + err = e + try: - file = self.input_bundle.load_file(path.with_suffix(".vy")) + file = self.input_bundle.load_file(path.with_suffix(".vyi")) assert isinstance(file, FileInput) # mypy hint - interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=str(file.path)) - return InterfaceT.from_ast(interface_ast) + module_ast = self._ast_from_file(file) + + with override_global_namespace(Namespace()): + validate_semantics_r( + module_ast, + self.input_bundle, + import_graph=self._import_graph, + is_interface=True, + ) + module_t = module_ast._metadata["type"] + + return module_t.interface + except FileNotFoundError: pass @@ -343,7 +468,24 @@ def _load_import(self, level: int, module_str: str) -> InterfaceT: assert isinstance(file, ABIInput) # mypy hint return InterfaceT.from_json_abi(str(file.path), file.abi) except FileNotFoundError: - raise ModuleNotFoundError(module_str) + pass + + # copy search_paths, makes debugging a bit easier + search_paths = self.input_bundle.search_paths.copy() # noqa: F841 + raise ModuleNotFound(module_str, node) from err + + +def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: + ret = vy_ast.parse_to_ast( + file.source_code, + source_id=file.source_id, + module_path=str(file.path), + resolved_path=str(file.resolved_path), + ) + vy_ast.validation.validate_literal_nodes(ret) + vy_ast.folding.fold(ret) + + return ret # convert an import to a path (without suffix) @@ -385,7 +527,7 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: remapped_module = remapped_module.removeprefix("vyper.interfaces") remapped_module = vyper.builtins.interfaces.__package__ + remapped_module - path = _import_to_path(level, remapped_module).with_suffix(".vy") + path = _import_to_path(level, remapped_module).with_suffix(".vyi") try: file = input_bundle.load_file(path) @@ -394,5 +536,8 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT: raise ModuleNotFoundError(f"Not a builtin: {module_str}") from None # TODO: it might be good to cache this computation - interface_ast = vy_ast.parse_to_ast(file.source_code, contract_name=module_str) - return InterfaceT.from_ast(interface_ast) + interface_ast = _parse_and_fold_ast(file) + + with override_global_namespace(Namespace()): + module_t = validate_semantics(interface_ast, input_bundle, is_interface=True) + return module_t.interface diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index afa6b56838..1785afd92d 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -66,8 +66,15 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: # if it's a Name, we have varinfo for it if isinstance(node, vy_ast.Name): - varinfo = self.namespace[node.id] - return ExprInfo.from_varinfo(varinfo) + info = self.namespace[node.id] + + if isinstance(info, VarInfo): + return ExprInfo.from_varinfo(info) + + if isinstance(info, ModuleInfo): + return ExprInfo.from_moduleinfo(info) + + raise CompilerPanic("unreachable!", node) if isinstance(node, vy_ast.Attribute): # if it's an Attr, we check the parent exprinfo and @@ -192,16 +199,17 @@ def _raise_invalid_reference(name, node): try: s = t.get_member(name, node) - if isinstance(s, VyperType): + + if isinstance(s, (VyperType, TYPE_T)): # ex. foo.bar(). bar() is a ContractFunctionT return [s] if is_self_reference and (s.is_constant or s.is_immutable): _raise_invalid_reference(name, node) # general case. s is a VarInfo, e.g. self.foo return [s.typ] - except UnknownAttribute: + except UnknownAttribute as e: if not is_self_reference: - raise + raise e from None if name in self.namespace: _raise_invalid_reference(name, node) @@ -364,6 +372,7 @@ def types_from_Name(self, node): return [TYPE_T(t)] return [t.typ] + except VyperException as exc: raise exc.with_annotation(node) from None diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 613ac0c03b..4df2511a29 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -95,7 +95,7 @@ def validate_assignment(self, attr): def get_namespace(): """ - Get the active namespace object. + Get the global namespace object. """ global _namespace try: diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index ad470718c8..1fef6a706e 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -2,9 +2,10 @@ from .base import TYPE_T, KwargSettings, VyperType, is_type_t from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT +from .module import InterfaceT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT -from .user import EnumT, EventT, InterfaceT, StructT +from .user import EnumT, EventT, StructT def _get_primitive_types(): diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index c5af5c2a39..d22d9bfff9 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -44,6 +44,13 @@ class VyperType: A tuple of invalid `DataLocation`s for this type _is_prim_word: bool, optional This is a word type like uint256, int8, bytesM or address + _supports_external_calls: bool, optional + Whether or not this type supports external calls. Currently + limited to `InterfaceT`s + _attribute_in_annotation: bool, optional + Whether or not this type can be attributed in a type + annotation, like IFoo.SomeType. Currently limited to + `InterfaceT`s. """ _id: str @@ -58,6 +65,9 @@ class VyperType: _as_array: bool = False # rename to something like can_be_array_member _as_hashmap_key: bool = False + _supports_external_calls: bool = False + _attribute_in_annotation: bool = False + size_in_bytes = 32 # default; override for larger types def __init__(self, members: Optional[Dict] = None) -> None: @@ -261,7 +271,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional["VyperType"]: VyperType, optional Type generated as a result of the call. """ - raise StructureException("Value is not callable", node) + raise StructureException(f"{self} is not callable", node) @classmethod def get_subscripted_type(self, node: vy_ast.Index) -> None: diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 09130626aa..e3c381ac69 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -132,7 +132,15 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": raise UnexpectedValue("Node id does not match type name") length = get_index_value(node.slice) # type: ignore - # return cls._type(length, location, is_constant, is_public, is_immutable) + + if length is None: + raise StructureException( + f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node + ) + + # TODO: pass None to constructor after we redo length inference on bytestrings + length = length or 0 + return cls(length) @classmethod diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 140f73f095..ec30ac85d6 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -17,7 +17,11 @@ StructureException, ) from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot -from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type +from vyper.semantics.analysis.utils import ( + check_kwargable, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -44,6 +48,7 @@ class KeywordArg(_FunctionArg): ast_source: Optional[vy_ast.VyperNode] = None +# TODO: refactor this into FunctionT (from an ast) and ABIFunctionT (from json) class ContractFunctionT(VyperType): """ Contract function type. @@ -81,6 +86,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, nonreentrant: Optional[str] = None, + ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -92,11 +98,18 @@ def __init__( self.mutability = state_mutability self.nonreentrant = nonreentrant - # a list of internal functions this function calls - self.called_functions = OrderedSet[ContractFunctionT]() + self.ast_def = ast_def + + # a list of internal functions this function calls. + # to be populated during analysis + self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() + + # recursively reachable from this function + self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # to be populated during codegen self._ir_info: Any = None + self._function_id: Optional[int] = None @cached_property def call_site_kwargs(self): @@ -126,7 +139,7 @@ def __hash__(self): return hash(id(self)) @classmethod - def from_abi(cls, abi: Dict) -> "ContractFunctionT": + def from_abi(cls, abi: dict) -> "ContractFunctionT": """ Generate a `ContractFunctionT` object from an ABI interface. @@ -157,190 +170,174 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionT": ) @classmethod - def from_FunctionDef( - cls, node: vy_ast.FunctionDef, is_interface: Optional[bool] = False - ) -> "ContractFunctionT": + def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": """ - Generate a `ContractFunctionT` object from a `FunctionDef` node. + Generate a `ContractFunctionT` object from a `FunctionDef` inside + of an `InterfaceDef` Arguments --------- - node : FunctionDef + funcdef: FunctionDef Vyper ast node to generate the function definition from. - is_interface: bool, optional - Boolean indicating if the function definition is part of an interface. Returns ------- ContractFunctionT """ - kwargs: Dict[str, Any] = {} - if is_interface: - # FunctionDef with stateMutability in body (Interface defintions) - if ( - len(node.body) == 1 - and isinstance(node.body[0], vy_ast.Expr) - and isinstance(node.body[0].value, vy_ast.Name) - and StateMutability.is_valid_value(node.body[0].value.id) - ): - # Interfaces are always public - kwargs["function_visibility"] = FunctionVisibility.EXTERNAL - kwargs["state_mutability"] = StateMutability(node.body[0].value.id) - elif len(node.body) == 1 and node.body[0].get("value.id") in ("constant", "modifying"): - if node.body[0].value.id == "constant": - expected = "view or pure" - else: - expected = "payable or nonpayable" - raise StructureException( - f"State mutability should be set to {expected}", node.body[0] - ) + # FunctionDef with stateMutability in body (Interface defintions) + body = funcdef.body + if ( + len(body) == 1 + and isinstance(body[0], vy_ast.Expr) + and isinstance(body[0].value, vy_ast.Name) + and StateMutability.is_valid_value(body[0].value.id) + ): + # Interfaces are always public + function_visibility = FunctionVisibility.EXTERNAL + state_mutability = StateMutability(body[0].value.id) + # handle errors + elif len(body) == 1 and body[0].get("value.id") in ("constant", "modifying"): + if body[0].value.id == "constant": + expected = "view or pure" else: - raise StructureException( - "Body must only contain state mutability label", node.body[0] - ) - + expected = "payable or nonpayable" + raise StructureException(f"State mutability should be set to {expected}", body[0]) else: - # FunctionDef with decorators (normal functions) - for decorator in node.decorator_list: - if isinstance(decorator, vy_ast.Call): - if "nonreentrant" in kwargs: - raise StructureException( - "nonreentrant decorator is already set with key: " - f"{kwargs['nonreentrant']}", - node, - ) + raise StructureException("Body must only contain state mutability label", body[0]) - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + if funcdef.name == "__init__": + raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) - if node.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" - raise FunctionDeclarationException(msg, decorator) - - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) - - kwargs["nonreentrant"] = nonreentrant_key - - elif isinstance(decorator, vy_ast.Name): - if FunctionVisibility.is_valid_value(decorator.id): - if "function_visibility" in kwargs: - raise FunctionDeclarationException( - f"Visibility is already set to: {kwargs['function_visibility']}", - node, - ) - kwargs["function_visibility"] = FunctionVisibility(decorator.id) - - elif StateMutability.is_valid_value(decorator.id): - if "state_mutability" in kwargs: - raise FunctionDeclarationException( - f"Mutability is already set to: {kwargs['state_mutability']}", node - ) - kwargs["state_mutability"] = StateMutability(decorator.id) - - else: - if decorator.id == "constant": - warnings.warn( - "'@constant' decorator has been removed (see VIP2040). " - "Use `@view` instead.", - DeprecationWarning, - ) - raise FunctionDeclarationException( - f"Unknown decorator: {decorator.id}", decorator - ) + if funcdef.name == "__default__": + raise FunctionDeclarationException( + "Default functions cannot appear in interfaces", funcdef + ) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) + + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=None, + ast_def=funcdef, + ) + + @classmethod + def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": + """ + Generate a `ContractFunctionT` object from a `FunctionDef` inside + of an interface (`.vyi`) file + + Arguments + --------- + funcdef: FunctionDef + Vyper ast node to generate the function definition from. + + Returns + ------- + ContractFunctionT + """ + function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + + if nonreentrant_key is not None: + raise FunctionDeclarationException( + "nonreentrant key not allowed in interfaces", funcdef + ) + + if funcdef.name == "__init__": + raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) - else: - raise StructureException("Bad decorator syntax", decorator) + if funcdef.name == "__default__": + raise FunctionDeclarationException( + "Default functions cannot appear in interfaces", funcdef + ) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) - if "function_visibility" not in kwargs: + if len(funcdef.body) != 1 or not isinstance(funcdef.body[0].get("value"), vy_ast.Ellipsis): raise FunctionDeclarationException( - f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", node + "function body in an interface can only be ...!", funcdef ) - if node.name == "__default__": - if kwargs["function_visibility"] != FunctionVisibility.EXTERNAL: + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=nonreentrant_key, + ast_def=funcdef, + ) + + @classmethod + def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": + """ + Generate a `ContractFunctionT` object from a `FunctionDef` node. + + Arguments + --------- + funcdef: FunctionDef + Vyper ast node to generate the function definition from. + + Returns + ------- + ContractFunctionT + """ + function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + + positional_args, keyword_args = _parse_args(funcdef) + + return_type = _parse_return_type(funcdef) + + # validate default and init functions + if funcdef.name == "__default__": + if function_visibility != FunctionVisibility.EXTERNAL: raise FunctionDeclarationException( - "Default function must be marked as `@external`", node + "Default function must be marked as `@external`", funcdef ) - if node.args.args: + if funcdef.args.args: raise FunctionDeclarationException( - "Default function may not receive any arguments", node.args.args[0] + "Default function may not receive any arguments", funcdef.args.args[0] ) - if "state_mutability" not in kwargs: - # Assume nonpayable if not set at all (cannot accept Ether, but can modify state) - kwargs["state_mutability"] = StateMutability.NONPAYABLE - - if kwargs["state_mutability"] == StateMutability.PURE and "nonreentrant" in kwargs: - raise StructureException("Cannot use reentrancy guard on pure functions", node) - - if node.name == "__init__": + if funcdef.name == "__init__": if ( - kwargs["state_mutability"] in (StateMutability.PURE, StateMutability.VIEW) - or kwargs["function_visibility"] == FunctionVisibility.INTERNAL + state_mutability in (StateMutability.PURE, StateMutability.VIEW) + or function_visibility == FunctionVisibility.INTERNAL ): raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", node + "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef ) - - # call arguments - if node.args.defaults: + if return_type is not None: raise FunctionDeclarationException( - "Constructor may not use default arguments", node.args.defaults[0] + "Constructor may not have a return type", funcdef.returns ) - argnames = set() # for checking uniqueness - n_total_args = len(node.args.args) - n_positional_args = n_total_args - len(node.args.defaults) - - positional_args: list[PositionalArg] = [] - keyword_args: list[KeywordArg] = [] - - for i, arg in enumerate(node.args.args): - argname = arg.arg - if argname in ("gas", "value", "skip_contract_check", "default_return_value"): - raise ArgumentException( - f"Cannot use '{argname}' as a variable name in a function input", arg + # call arguments + if funcdef.args.defaults: + raise FunctionDeclarationException( + "Constructor may not use default arguments", funcdef.args.defaults[0] ) - if argname in argnames: - raise ArgumentException(f"Function contains multiple inputs named {argname}", arg) - - if arg.annotation is None: - raise ArgumentException(f"Function argument '{argname}' is missing a type", arg) - - type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA) - - if i < n_positional_args: - positional_args.append(PositionalArg(argname, type_, ast_source=arg)) - else: - value = node.args.defaults[i - n_positional_args] - if not check_kwargable(value): - raise StateAccessViolation( - "Value must be literal or environment variable", value - ) - validate_expected_type(value, type_) - keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) - - argnames.add(argname) - # return types - if node.returns is None: - return_type = None - elif node.name == "__init__": - raise FunctionDeclarationException( - "Constructor may not have a return type", node.returns - ) - elif isinstance(node.returns, (vy_ast.Name, vy_ast.Subscript, vy_ast.Tuple)): - # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE - return_type = type_from_annotation(node.returns, DataLocation.MEMORY) - else: - raise InvalidType("Function return value must be a type name or tuple", node.returns) - - return cls(node.name, positional_args, keyword_args, return_type, **kwargs) + return cls( + funcdef.name, + positional_args, + keyword_args, + return_type, + function_visibility, + state_mutability, + nonreentrant=nonreentrant_key, + ast_def=funcdef, + ) def set_reentrancy_key_position(self, position: StorageSlot) -> None: if hasattr(self, "reentrancy_key_position"): @@ -383,6 +380,7 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio return_type, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.VIEW, + ast_def=node, ) @property @@ -489,8 +487,12 @@ def method_ids(self) -> Dict[str, int]: return method_ids def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: - if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL: - raise CallViolation("Cannot call external functions via 'self'", node) + # mypy hint - right now, the only way a ContractFunctionT can be + # called is via `Attribute`, e.x. self.foo() or library.bar() + assert isinstance(node.func, vy_ast.Attribute) + parent_t = get_exact_type_from_node(node.func.value) + if not parent_t._supports_external_calls and self.visibility == FunctionVisibility.EXTERNAL: + raise CallViolation("Cannot call external functions via 'self' or via library", node) kwarg_keys = [] # for external calls, include gas and value as optional kwargs @@ -584,6 +586,125 @@ def abi_signature_for_kwargs(self, kwargs: list[KeywordArg]) -> str: return self.name + "(" + ",".join([arg.typ.abi_type.selector_name() for arg in args]) + ")" +def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: + # return types + if funcdef.returns is None: + return None + # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE + return type_from_annotation(funcdef.returns, DataLocation.MEMORY) + + +def _parse_decorators( + funcdef: vy_ast.FunctionDef, +) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: + function_visibility = None + state_mutability = None + nonreentrant_key = None + + for decorator in funcdef.decorator_list: + if isinstance(decorator, vy_ast.Call): + if nonreentrant_key is not None: + raise StructureException( + "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", + funcdef, + ) + + if decorator.get("func.id") != "nonreentrant": + raise StructureException("Decorator is not callable", decorator) + if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): + raise StructureException( + "@nonreentrant name must be given as a single string literal", decorator + ) + + if funcdef.name == "__init__": + msg = "Nonreentrant decorator disallowed on `__init__`" + raise FunctionDeclarationException(msg, decorator) + + nonreentrant_key = decorator.args[0].value + validate_identifier(nonreentrant_key, decorator.args[0]) + + elif isinstance(decorator, vy_ast.Name): + if FunctionVisibility.is_valid_value(decorator.id): + if function_visibility is not None: + raise FunctionDeclarationException( + f"Visibility is already set to: {function_visibility}", funcdef + ) + function_visibility = FunctionVisibility(decorator.id) + + elif StateMutability.is_valid_value(decorator.id): + if state_mutability is not None: + raise FunctionDeclarationException( + f"Mutability is already set to: {state_mutability}", funcdef + ) + state_mutability = StateMutability(decorator.id) + + else: + if decorator.id == "constant": + warnings.warn( + "'@constant' decorator has been removed (see VIP2040). " + "Use `@view` instead.", + DeprecationWarning, + ) + raise FunctionDeclarationException(f"Unknown decorator: {decorator.id}", decorator) + + else: + raise StructureException("Bad decorator syntax", decorator) + + if function_visibility is None: + raise FunctionDeclarationException( + f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", funcdef + ) + + if state_mutability is None: + # default to nonpayable + state_mutability = StateMutability.NONPAYABLE + + if state_mutability == StateMutability.PURE and nonreentrant_key is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + + # assert function_visibility is not None # mypy + # assert state_mutability is not None # mypy + return function_visibility, state_mutability, nonreentrant_key + + +def _parse_args( + funcdef: vy_ast.FunctionDef, is_interface: bool = False +) -> tuple[list[PositionalArg], list[KeywordArg]]: + argnames = set() # for checking uniqueness + n_total_args = len(funcdef.args.args) + n_positional_args = n_total_args - len(funcdef.args.defaults) + + positional_args = [] + keyword_args = [] + + for i, arg in enumerate(funcdef.args.args): + argname = arg.arg + if argname in ("gas", "value", "skip_contract_check", "default_return_value"): + raise ArgumentException( + f"Cannot use '{argname}' as a variable name in a function input", arg + ) + if argname in argnames: + raise ArgumentException(f"Function contains multiple inputs named {argname}", arg) + + if arg.annotation is None: + raise ArgumentException(f"Function argument '{argname}' is missing a type", arg) + + type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA) + + if i < n_positional_args: + positional_args.append(PositionalArg(argname, type_, ast_source=arg)) + else: + value = funcdef.args.defaults[i - n_positional_args] + if not check_kwargable(value): + raise StateAccessViolation("Value must be literal or environment variable", value) + validate_expected_type(value, type_) + keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) + + argnames.add(argname) + + return positional_args, keyword_args + + class MemberFunctionT(VyperType): """ Member function type definition. diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py new file mode 100644 index 0000000000..4622482951 --- /dev/null +++ b/vyper/semantics/types/module.py @@ -0,0 +1,332 @@ +from functools import cached_property +from typing import Optional + +from vyper import ast as vy_ast +from vyper.abi_types import ABI_Address, ABIType +from vyper.ast.validation import validate_call_args +from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.namespace import get_namespace +from vyper.semantics.types.base import TYPE_T, VyperType +from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.primitives import AddressT +from vyper.semantics.types.user import EventT, StructT, _UserType + + +class InterfaceT(_UserType): + _type_members = {"address": AddressT()} + _is_prim_word = True + _as_array = True + _as_hashmap_key = True + _supports_external_calls = True + _attribute_in_annotation = True + + def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> None: + validate_unique_method_ids(list(functions.values())) + + members = functions | events | structs + + # sanity check: by construction, there should be no duplicates. + assert len(members) == len(functions) + len(events) + len(structs) + + super().__init__(functions) + + self._helper = VyperType(events | structs) + self._id = _id + self.functions = functions + self.events = events + self.structs = structs + + def get_type_member(self, attr, node): + # get an event or struct from this interface + return TYPE_T(self._helper.get_member(attr, node)) + + @property + def getter_signature(self): + return (), AddressT() + + @property + def abi_type(self) -> ABIType: + return ABI_Address() + + def __repr__(self): + return f"interface {self._id}" + + # when using the type itself (not an instance) in the call position + def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": + self._ctor_arg_types(node) + return self + + def _ctor_arg_types(self, node): + validate_call_args(node, 1) + validate_expected_type(node.args[0], AddressT()) + return [AddressT()] + + def _ctor_kwarg_types(self, node): + return {} + + # TODO x.validate_implements(other) + def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: + namespace = get_namespace() + unimplemented = [] + + def _is_function_implemented(fn_name, fn_type): + vyper_self = namespace["self"].typ + if fn_name not in vyper_self.members: + return False + s = vyper_self.members[fn_name] + if isinstance(s, ContractFunctionT): + to_compare = vyper_self.members[fn_name] + # this is kludgy, rework order of passes in ModuleNodeVisitor + elif isinstance(s, VarInfo) and s.is_public: + to_compare = s.decl_node._metadata["getter_type"] + else: + return False + + return to_compare.implements(fn_type) + + # check for missing functions + for name, type_ in self.functions.items(): + if not isinstance(type_, ContractFunctionT): + # ex. address + continue + + if not _is_function_implemented(name, type_): + unimplemented.append(name) + + # check for missing events + for name, event in self.events.items(): + if name not in namespace: + unimplemented.append(name) + continue + + if not isinstance(namespace[name], EventT): + unimplemented.append(f"{name} is not an event!") + if ( + namespace[name].event_id != event.event_id + or namespace[name].indexed != event.indexed + ): + unimplemented.append(f"{name} is not implemented! (should be {event})") + + if len(unimplemented) > 0: + # TODO: improve the error message for cases where the + # mismatch is small (like mutability, or just one argument + # is off, etc). + missing_str = ", ".join(sorted(unimplemented)) + raise InterfaceViolation( + f"Contract does not implement all interface functions or events: {missing_str}", + node, + ) + + def to_toplevel_abi_dict(self) -> list[dict]: + abi = [] + for event in self.events.values(): + abi += event.to_toplevel_abi_dict() + for func in self.functions.values(): + abi += func.to_toplevel_abi_dict() + return abi + + # helper function which performs namespace collision checking + @classmethod + def _from_lists( + cls, + name: str, + function_list: list[tuple[str, ContractFunctionT]], + event_list: list[tuple[str, EventT]], + struct_list: list[tuple[str, StructT]], + ) -> "InterfaceT": + functions = {} + events = {} + structs = {} + + seen_items: dict = {} + + for name, function in function_list: + if name in seen_items: + raise NamespaceCollision(f"multiple functions named '{name}'!", function.ast_def) + functions[name] = function + seen_items[name] = function + + for name, event in event_list: + if name in seen_items: + raise NamespaceCollision( + f"multiple functions or events named '{name}'!", event.decl_node + ) + events[name] = event + seen_items[name] = event + + for name, struct in struct_list: + if name in seen_items: + raise NamespaceCollision( + f"multiple functions or events named '{name}'!", event.decl_node + ) + structs[name] = struct + seen_items[name] = struct + + return cls(name, functions, events, structs) + + @classmethod + def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": + """ + Generate an `InterfaceT` object from an ABI. + + Arguments + --------- + name : str + The name of the interface + abi : dict + Contract ABI + + Returns + ------- + InterfaceT + primitive interface type + """ + functions: list = [] + events: list = [] + + for item in [i for i in abi if i.get("type") == "function"]: + functions.append((item["name"], ContractFunctionT.from_abi(item))) + for item in [i for i in abi if i.get("type") == "event"]: + events.append((item["name"], EventT.from_abi(item))) + + structs: list = [] # no structs in json ABI (as of yet) + return cls._from_lists(name, functions, events, structs) + + @classmethod + def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": + """ + Generate an `InterfaceT` object from a Vyper ast node. + + Arguments + --------- + module_t: ModuleT + Vyper module type + Returns + ------- + InterfaceT + primitive interface type + """ + funcs = [] + + for node in module_t.function_defs: + func_t = node._metadata["func_type"] + if not func_t.is_external: + continue + funcs.append((node.name, func_t)) + + # add getters for public variables since they aren't yet in the AST + for node in module_t._module.get_children(vy_ast.VariableDecl): + if not node.is_public: + continue + getter = node._metadata["getter_type"] + funcs.append((node.target.id, getter)) + + events = [(node.name, node._metadata["event_type"]) for node in module_t.event_defs] + + structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs] + + return cls._from_lists(module_t._id, funcs, events, structs) + + @classmethod + def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": + functions = [] + for node in node.body: + if not isinstance(node, vy_ast.FunctionDef): + raise StructureException("Interfaces can only contain function definitions", node) + if len(node.decorator_list) > 0: + raise StructureException( + "Function definition in interface cannot be decorated", node.decorator_list[0] + ) + functions.append((node.name, ContractFunctionT.from_InterfaceDef(node))) + + # no structs or events in InterfaceDefs + events: list = [] + structs: list = [] + + return cls._from_lists(node.name, functions, events, structs) + + +# Datatype to store all module information. +class ModuleT(VyperType): + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): + super().__init__() + + self._module = module + + self._id = name or module.path + + # compute the interface, note this has the side effect of checking + # for function collisions + self._helper = self.interface + + for f in self.function_defs: + # note: this checks for collisions + self.add_member(f.name, f._metadata["func_type"]) + + for e in self.event_defs: + # add the type of the event so it can be used in call position + self.add_member(e.name, TYPE_T(e._metadata["event_type"])) # type: ignore + + for s in self.struct_defs: + # add the type of the struct so it can be used in call position + self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for v in self.variable_decls: + self.add_member(v.target.id, v.target._metadata["varinfo"]) + + for i in self.import_stmts: + import_info = i._metadata["import_info"] + self.add_member(import_info.alias, import_info.typ) + + # __eq__ is very strict on ModuleT - object equality! this is because we + # don't want to reason about where a module came from (i.e. input bundle, + # search path, symlinked vs normalized path, etc.) + def __eq__(self, other): + return self is other + + def __hash__(self): + return hash(id(self)) + + def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": + return self._helper.get_member(key, node) + + # this is a property, because the function set changes after AST expansion + @property + def function_defs(self): + return self._module.get_children(vy_ast.FunctionDef) + + @property + def event_defs(self): + return self._module.get_children(vy_ast.EventDef) + + @property + def struct_defs(self): + return self._module.get_children(vy_ast.StructDef) + + @property + def import_stmts(self): + return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) + + @property + def variable_decls(self): + return self._module.get_children(vy_ast.VariableDecl) + + @cached_property + def variables(self): + # variables that this module defines, ex. + # `x: uint256` is a private storage variable named x + return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + + @cached_property + def immutables(self): + return [t for t in self.variables.values() if t.is_immutable] + + @cached_property + def immutable_section_bytes(self): + return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + + @cached_property + def interface(self): + return InterfaceT.from_ModuleT(self) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 6a2d3aae73..46dffbdec4 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from vyper import ast as vy_ast from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType @@ -68,7 +68,7 @@ def get_subscripted_type(self, node): return self.value_type @classmethod - def from_annotation(cls, node: Union[vy_ast.Name, vy_ast.Call, vy_ast.Subscript]) -> "HashMapT": + def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT": if ( not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index) @@ -274,24 +274,32 @@ def compare_type(self, other): @classmethod def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": + # common error message, different ast locations + err_msg = "DynArray must be defined with base type and max length, e.g. DynArray[bool, 5]" + + if not isinstance(node, vy_ast.Subscript): + raise StructureException(err_msg, node) + if ( - not isinstance(node, vy_ast.Subscript) - or not isinstance(node.slice, vy_ast.Index) + not isinstance(node.slice, vy_ast.Index) or not isinstance(node.slice.value, vy_ast.Tuple) - or not isinstance(node.slice.value.elements[1], vy_ast.Int) or len(node.slice.value.elements) != 2 ): - raise StructureException( - "DynArray must be defined with base type and max length, e.g. DynArray[bool, 5]", - node, - ) + raise StructureException(err_msg, node.slice) + + length_node = node.slice.value.elements[1] + + if not isinstance(length_node, vy_ast.Int): + raise StructureException(err_msg, length_node) - value_type = type_from_annotation(node.slice.value.elements[0]) + length = length_node.value + + value_node = node.slice.value.elements[0] + value_type = type_from_annotation(value_node) if not value_type._as_darray: - raise StructureException(f"Arrays of {value_type} are not allowed", node) + raise StructureException(f"Arrays of {value_type} are not allowed", value_node) - max_length = node.slice.value.elements[1].value - return cls(value_type, max_length) + return cls(value_type, length) class TupleT(VyperType): @@ -333,7 +341,7 @@ def tuple_items(self): return list(enumerate(self.member_types)) @classmethod - def from_annotation(cls, node: vy_ast.Tuple) -> VyperType: + def from_annotation(cls, node: vy_ast.Tuple) -> "TupleT": values = node.elements types = tuple(type_from_annotation(v) for v in values) return cls(types) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ce82731c34..ef7e1d0eb4 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -1,27 +1,22 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional from vyper import ast as vy_ast -from vyper.abi_types import ABI_Address, ABI_GIntM, ABI_Tuple, ABIType +from vyper.abi_types import ABI_GIntM, ABI_Tuple, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import ( EnumDeclarationException, EventDeclarationException, - InterfaceViolation, InvalidAttribute, NamespaceCollision, StructureException, UnknownAttribute, VariableDeclarationException, ) -from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.data_locations import DataLocation -from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import VyperType -from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.subscriptable import HashMapT from vyper.semantics.types.utils import type_from_abi, type_from_annotation from vyper.utils import keccak256 @@ -29,12 +24,19 @@ # user defined type class _UserType(VyperType): + def __init__(self, members=None): + super().__init__(members=members) + def __eq__(self, other): return self is other - # TODO: revisit this once user types can be imported via modules def compare_type(self, other): - return super().compare_type(other) and self._id == other._id + # object exact comparison is a bit tricky here since we have + # to be careful to construct any given user type exactly + # only one time. however, the alternative requires reasoning + # about both the name and source (module or json abi) of + # the type. + return self is other def __hash__(self): return hash(id(self)) @@ -52,7 +54,8 @@ def __init__(self, name: str, members: dict) -> None: if len(members.keys()) > 256: raise EnumDeclarationException("Enums are limited to 256 members!") - super().__init__() + super().__init__(members=None) + self._id = name self._enum_members = members @@ -112,7 +115,7 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": ------- Enum """ - members: Dict = {} + members: dict = {} if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): raise EnumDeclarationException("Enum must have members", base_node) @@ -135,7 +138,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: # TODO return None - def to_toplevel_abi_dict(self) -> List[Dict]: + def to_toplevel_abi_dict(self) -> list[dict]: # TODO return [] @@ -160,13 +163,21 @@ class EventT(_UserType): _invalid_locations = tuple(iter(DataLocation)) # not instantiable in any location - def __init__(self, name: str, arguments: dict, indexed: list) -> None: + def __init__( + self, + name: str, + arguments: dict, + indexed: list, + decl_node: Optional[vy_ast.VyperNode] = None, + ) -> None: super().__init__(members=arguments) self.name = name self.indexed = indexed assert len(self.indexed) == len(self.arguments) self.event_id = int(keccak256(self.signature.encode()).hex(), 16) + self.decl_node = decl_node + # backward compatible @property def arguments(self): @@ -187,7 +198,7 @@ def signature(self): return f"{self.name}({','.join(v.canonical_abi_type for v in self.arguments.values())})" @classmethod - def from_abi(cls, abi: Dict) -> "EventT": + def from_abi(cls, abi: dict) -> "EventT": """ Generate an `Event` object from an ABI interface. @@ -201,7 +212,7 @@ def from_abi(cls, abi: Dict) -> "EventT": Event object. """ members: dict = {} - indexed: List = [i["indexed"] for i in abi["inputs"]] + indexed: list = [i["indexed"] for i in abi["inputs"]] for item in abi["inputs"]: members[item["name"]] = type_from_abi(item) return cls(abi["name"], members, indexed) @@ -219,11 +230,11 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": ------- Event """ - members: Dict = {} - indexed: List = [] + members: dict = {} + indexed: list = [] if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): - return EventT(base_node.name, members, indexed) + return cls(base_node.name, members, indexed, base_node) for node in base_node.body: if not isinstance(node, vy_ast.AnnAssign): @@ -252,14 +263,14 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": members[member_name] = type_from_annotation(annotation) - return cls(base_node.name, members, indexed) + return cls(base_node.name, members, indexed, base_node) def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): validate_expected_type(arg, expected) - def to_toplevel_abi_dict(self) -> List[Dict]: + def to_toplevel_abi_dict(self) -> list[dict]: return [ { "name": self.name, @@ -273,215 +284,6 @@ def to_toplevel_abi_dict(self) -> List[Dict]: ] -class InterfaceT(_UserType): - _type_members = {"address": AddressT()} - _is_prim_word = True - _as_array = True - _as_hashmap_key = True - - def __init__(self, _id: str, members: dict, events: dict) -> None: - validate_unique_method_ids(list(members.values())) # explicit list cast for mypy - super().__init__(members) - - self._id = _id - self.events = events - - @property - def getter_signature(self): - return (), AddressT() - - @property - def abi_type(self) -> ABIType: - return ABI_Address() - - def __repr__(self): - return f"{self._id}" - - # when using the type itself (not an instance) in the call position - # maybe rename to _ctor_call_return - def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": - self._ctor_arg_types(node) - - return self - - def _ctor_arg_types(self, node): - validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] - - def _ctor_kwarg_types(self, node): - return {} - - # TODO x.validate_implements(other) - def validate_implements(self, node: vy_ast.ImplementsDecl) -> None: - namespace = get_namespace() - unimplemented = [] - - def _is_function_implemented(fn_name, fn_type): - vyper_self = namespace["self"].typ - if fn_name not in vyper_self.members: - return False - s = vyper_self.members[fn_name] - if isinstance(s, ContractFunctionT): - to_compare = vyper_self.members[fn_name] - # this is kludgy, rework order of passes in ModuleNodeVisitor - elif isinstance(s, VarInfo) and s.is_public: - to_compare = s.decl_node._metadata["func_type"] - else: - return False - - return to_compare.implements(fn_type) - - # check for missing functions - for name, type_ in self.members.items(): - if not isinstance(type_, ContractFunctionT): - # ex. address - continue - - if not _is_function_implemented(name, type_): - unimplemented.append(name) - - # check for missing events - for name, event in self.events.items(): - if name not in namespace: - unimplemented.append(name) - continue - - if not isinstance(namespace[name], EventT): - unimplemented.append(f"{name} is not an event!") - if ( - namespace[name].event_id != event.event_id - or namespace[name].indexed != event.indexed - ): - unimplemented.append(f"{name} is not implemented! (should be {event})") - - if len(unimplemented) > 0: - # TODO: improve the error message for cases where the - # mismatch is small (like mutability, or just one argument - # is off, etc). - missing_str = ", ".join(sorted(unimplemented)) - raise InterfaceViolation( - f"Contract does not implement all interface functions or events: {missing_str}", - node, - ) - - def to_toplevel_abi_dict(self) -> List[Dict]: - abi = [] - for event in self.events.values(): - abi += event.to_toplevel_abi_dict() - for func in self.functions.values(): - abi += func.to_toplevel_abi_dict() - return abi - - @property - def functions(self): - return {k: v for (k, v) in self.members.items() if isinstance(v, ContractFunctionT)} - - @classmethod - def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT": - """ - Generate an `InterfaceT` object from an ABI. - - Arguments - --------- - name : str - The name of the interface - abi : dict - Contract ABI - - Returns - ------- - InterfaceT - primitive interface type - """ - members: Dict = {} - events: Dict = {} - - names = [i["name"] for i in abi if i.get("type") in ("event", "function")] - collisions = set(i for i in names if names.count(i) > 1) - if collisions: - collision_list = ", ".join(sorted(collisions)) - raise NamespaceCollision( - f"ABI '{name}' has multiple functions or events " - f"with the same name: {collision_list}" - ) - - for item in [i for i in abi if i.get("type") == "function"]: - members[item["name"]] = ContractFunctionT.from_abi(item) - for item in [i for i in abi if i.get("type") == "event"]: - events[item["name"]] = EventT.from_abi(item) - - return cls(name, members, events) - - # TODO: split me into from_InterfaceDef and from_Module - @classmethod - def from_ast(cls, node: Union[vy_ast.InterfaceDef, vy_ast.Module]) -> "InterfaceT": - """ - Generate an `InterfaceT` object from a Vyper ast node. - - Arguments - --------- - node : InterfaceDef | Module - Vyper ast node defining the interface - Returns - ------- - InterfaceT - primitive interface type - """ - if isinstance(node, vy_ast.Module): - members, events = _get_module_definitions(node) - elif isinstance(node, vy_ast.InterfaceDef): - members = _get_class_functions(node) - events = {} - else: - raise StructureException("Invalid syntax for interface definition", node) - - return cls(node.name, members, events) - - -def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[Dict, Dict]: - functions: Dict = {} - events: Dict = {} - for node in base_node.get_children(vy_ast.FunctionDef): - if "external" in [i.id for i in node.decorator_list if isinstance(i, vy_ast.Name)]: - func = ContractFunctionT.from_FunctionDef(node) - functions[node.name] = func - for node in base_node.get_children(vy_ast.VariableDecl, {"is_public": True}): - name = node.target.id - if name in functions: - raise NamespaceCollision( - f"Interface contains multiple functions named '{name}'", base_node - ) - functions[name] = ContractFunctionT.getter_from_VariableDecl(node) - for node in base_node.get_children(vy_ast.EventDef): - name = node.name - if name in functions or name in events: - raise NamespaceCollision( - f"Interface contains multiple objects named '{name}'", base_node - ) - events[name] = EventT.from_EventDef(node) - - return functions, events - - -def _get_class_functions(base_node: vy_ast.InterfaceDef) -> Dict[str, ContractFunctionT]: - functions = {} - for node in base_node.body: - if not isinstance(node, vy_ast.FunctionDef): - raise StructureException("Interfaces can only contain function definitions", node) - if node.name in functions: - raise NamespaceCollision( - f"Interface contains multiple functions named '{node.name}'", node - ) - if len(node.decorator_list) > 0: - raise StructureException( - "Function definition in interface cannot be decorated", node.decorator_list[0] - ) - functions[node.name] = ContractFunctionT.from_FunctionDef(node, is_interface=True) - - return functions - - class StructT(_UserType): _as_array = True @@ -516,7 +318,7 @@ def member_types(self): return self.members @classmethod - def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT": + def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": """ Generate a `StructT` object from a Vyper ast node. @@ -531,7 +333,7 @@ def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT": """ struct_name = base_node.name - members: Dict[str, VyperType] = {} + members: dict[str, VyperType] = {} for node in base_node.body: if not isinstance(node, vy_ast.AnnAssign): raise StructureException( @@ -605,4 +407,4 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": f"Struct declaration does not define all fields: {', '.join(list(members))}", node ) - return StructT(self._id, self.member_types) + return self diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 1187080ca9..8d68a9fa01 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -6,12 +6,13 @@ InstantiationException, InvalidType, StructureException, + UndeclaredDefinition, UnknownType, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace -from vyper.semantics.types.base import VyperType +from vyper.semantics.types.base import TYPE_T, VyperType # TODO maybe this should be merged with .types/base.py @@ -75,7 +76,7 @@ def type_from_annotation( Arguments --------- - node : VyperNode + node: VyperNode Vyper ast node from the `annotation` member of a `VariableDecl` or `AnnAssign` node. Returns @@ -95,12 +96,6 @@ def type_from_annotation( def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: namespace = get_namespace() - def _failwith(type_name): - suggestions_str = get_levenshtein_error_suggestions(type_name, namespace, 0.3) - raise UnknownType( - f"No builtin or user-defined type named '{type_name}'. {suggestions_str}", node - ) from None - if isinstance(node, vy_ast.Tuple): tuple_t = namespace["$TupleT"] return tuple_t.from_annotation(node) @@ -116,11 +111,43 @@ def _failwith(type_name): return type_ctor.from_annotation(node) + # prepare a common error message + err_msg = f"'{node.node_source_code}' is not a type!" + + if isinstance(node, vy_ast.Attribute): + # ex. SomeModule.SomeStruct + + # sanity check - we only allow modules/interfaces to be + # imported as `Name`s currently. + if not isinstance(node.value, vy_ast.Name): + raise InvalidType(err_msg, node) + + try: + module_or_interface = namespace[node.value.id] # type: ignore + except UndeclaredDefinition: + raise InvalidType(err_msg, node) from None + + interface = module_or_interface + if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo + interface = module_or_interface.module_t.interface + + if not interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = interface.get_type_member(node.attr, node) + assert isinstance(type_t, TYPE_T) # sanity check + return type_t.typedef + if not isinstance(node, vy_ast.Name): # maybe handle this somewhere upstream in ast validation - raise InvalidType(f"'{node.node_source_code}' is not a type", node) - if node.id not in namespace: - _failwith(node.node_source_code) + raise InvalidType(err_msg, node) + + if node.id not in namespace: # type: ignore + suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) + raise UnknownType( + f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}", + node, + ) from None typ_ = namespace[node.id] if hasattr(typ_, "from_annotation"): @@ -138,7 +165,7 @@ def get_index_value(node: vy_ast.Index) -> int: Arguments --------- - node : vy_ast.Index + node: vy_ast.Index Vyper ast node from the `slice` member of a Subscript node. Must be an `Index` object (Vyper does not support `Slice` or `ExtSlice`). @@ -146,6 +173,7 @@ def get_index_value(node: vy_ast.Index) -> int: ------- int Literal integer value. + In the future, will return `None` if the subscript is an Ellipsis """ # this is imported to improve error messages # TODO: revisit this! diff --git a/vyper/utils.py b/vyper/utils.py index 0a2e1f831f..6816db9bae 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -51,6 +51,10 @@ def difference(self, other): def union(self, other): return self | other + def update(self, other): + for item in other: + self.add(item) + def __or__(self, other): return self.__class__(super().__or__(other)) @@ -162,11 +166,6 @@ def method_id(method_str: str) -> bytes: return keccak256(bytes(method_str, "utf-8"))[:4] -# map a string to only-alphanumeric chars -def mkalphanum(s): - return "".join([c if c.isalnum() else "_" for c in s]) - - def round_towards_zero(d: decimal.Decimal) -> int: # TODO double check if this can just be int(d) # (but either way keep this util function bc it's easier at a glance From 0cbc94d01d7be616329a9f70df15733818b4590c Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 16 Dec 2023 13:08:34 -0500 Subject: [PATCH 14/18] feat: add short options `-v` and `-O` to the CLI (#3695) this commit adds `-v` and `-O` as aliases for `--verbose` and `--optimize`, respectively. --- vyper/cli/vyper_compile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 4f88812fa0..ec4681a814 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -111,6 +111,7 @@ def _parse_args(argv): ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") parser.add_argument( + "-O", "--optimize", help="Optimization flag (defaults to 'gas')", choices=["gas", "codesize", "none"], @@ -125,6 +126,7 @@ def _parse_args(argv): type=int, ) parser.add_argument( + "-v", "--verbose", help="Turn on compiler verbose output. " "Currently an alias for --traceback-limit but " From b0ea5b6f1c8cd8d09db6f37e9857f9b3837fb386 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sat, 16 Dec 2023 13:42:31 -0500 Subject: [PATCH 15/18] feat: search path resolution for cli (#3694) the current behavior is that the current directory does *not* get into the search path when `-p` is specified, which is annoying. (one would expect `vyper some/directory/some/file.vy` to compile no matter what `-p` is specified as). this commit also handles the addition of multiple search paths specified on the CLI, and adds a long `--path` option as an alternative to `-p`. --- .../cli/vyper_compile/test_compile_files.py | 36 ++++++++++++------- tests/unit/compiler/test_input_bundle.py | 13 +------ tests/utils.py | 12 +++++++ vyper/cli/vyper_compile.py | 19 ++++++---- 4 files changed, 49 insertions(+), 31 deletions(-) create mode 100644 tests/utils.py diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index f6e3a51a4b..2a65d66835 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -2,6 +2,7 @@ import pytest +from tests.utils import working_directory from vyper.cli.vyper_compile import compile_files @@ -19,7 +20,7 @@ def test_combined_json_keys(tmp_path, make_file): "userdoc", "devdoc", } - compile_data = compile_files(["bar.vy"], ["combined_json"], root_folder=tmp_path) + compile_data = compile_files(["bar.vy"], ["combined_json"], paths=[tmp_path]) assert set(compile_data.keys()) == {Path("bar.vy"), "version"} assert set(compile_data[Path("bar.vy")].keys()) == combined_keys @@ -27,7 +28,7 @@ def test_combined_json_keys(tmp_path, make_file): def test_invalid_root_path(): with pytest.raises(FileNotFoundError): - compile_files([], [], root_folder="path/that/does/not/exist") + compile_files([], [], paths=["path/that/does/not/exist"]) CONTRACT_CODE = """ @@ -74,7 +75,7 @@ def test_import_same_folder(import_stmt, alias, tmp_path, make_file): make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) make_file("contracts/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) SUBFOLDER_IMPORT_STMT = [ @@ -98,7 +99,7 @@ def test_import_subfolder(import_stmt, alias, tmp_path, make_file): ) make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) OTHER_FOLDER_IMPORT_STMT = [ @@ -115,7 +116,7 @@ def test_import_other_folder(import_stmt, alias, tmp_path, make_file): foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) make_file("interfaces/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) def test_import_parent_folder(tmp_path, make_file): @@ -125,10 +126,21 @@ def test_import_parent_folder(tmp_path, make_file): ) make_file("IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], root_folder=tmp_path) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) # perform relative import outside of base folder - compile_files([foo], ["combined_json"], root_folder=tmp_path / "contracts") + compile_files([foo], ["combined_json"], paths=[tmp_path / "contracts"]) + + +def test_import_search_paths(tmp_path, make_file): + with working_directory(tmp_path): + contract_code = CONTRACT_CODE.format(import_stmt="from utils import IFoo", alias="IFoo") + contract_filename = "dir1/baz/foo.vy" + interface_filename = "dir2/utils/IFoo.vyi" + make_file(interface_filename, INTERFACE_CODE) + make_file(contract_filename, contract_code) + + assert compile_files([contract_filename], ["combined_json"], paths=["dir2"]) META_IMPORT_STMT = [ @@ -167,7 +179,7 @@ def be_known() -> ISelf.FooStruct: make_file("contracts/ISelf.vyi", interface_code) meta = make_file("contracts/Self.vy", code) - assert compile_files([meta], ["combined_json"], root_folder=tmp_path) + assert compile_files([meta], ["combined_json"], paths=[tmp_path]) # implement IFoo in another contract for fun @@ -187,7 +199,7 @@ def bar(_foo: address) -> {alias}.FooStruct: make_file("contracts/IFoo.vyi", INTERFACE_CODE) baz = make_file("contracts/Baz.vy", baz_code) - assert compile_files([baz], ["combined_json"], root_folder=tmp_path) + assert compile_files([baz], ["combined_json"], paths=[tmp_path]) def test_local_namespace(make_file, tmp_path): @@ -215,7 +227,7 @@ def test_local_namespace(make_file, tmp_path): for file_name in ("foo.vyi", "bar.vyi"): make_file(file_name, INTERFACE_CODE) - assert compile_files(paths, ["combined_json"], root_folder=tmp_path) + assert compile_files(paths, ["combined_json"], paths=[tmp_path]) def test_compile_outside_root_path(tmp_path, make_file): @@ -223,7 +235,7 @@ def test_compile_outside_root_path(tmp_path, make_file): make_file("ifoo.vyi", INTERFACE_CODE) foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) - assert compile_files([foo], ["combined_json"], root_folder=".") + assert compile_files([foo], ["combined_json"], paths=None) def test_import_library(tmp_path, make_file): @@ -244,4 +256,4 @@ def foo() -> uint256: make_file("lib.vy", library_source) contract_file = make_file("contract.vy", contract_source) - assert compile_files([contract_file], ["combined_json"], root_folder=tmp_path) is not None + assert compile_files([contract_file], ["combined_json"], paths=[tmp_path]) is not None diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index e26555b169..621b529722 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -1,10 +1,9 @@ -import contextlib import json -import os from pathlib import Path, PurePath import pytest +from tests.utils import working_directory from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, JSONInputBundle @@ -83,16 +82,6 @@ def test_load_abi(make_file, input_bundle, tmp_path): assert file == ABIInput(1, "foo.txt", path, "some string") -@contextlib.contextmanager -def working_directory(directory): - tmp = os.getcwd() - try: - os.chdir(directory) - yield - finally: - os.chdir(tmp) - - # check that unique paths give unique source ids def test_source_id_file_input(make_file, input_bundle, tmp_path): foopath = make_file("foo.vy", "contents") diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..0c89c39ff3 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,12 @@ +import contextlib +import os + + +@contextlib.contextmanager +def working_directory(directory): + tmp = os.getcwd() + try: + os.chdir(directory) + yield + finally: + os.chdir(tmp) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index ec4681a814..25f1180098 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -140,7 +140,7 @@ def _parse_args(argv): ) parser.add_argument("--hex-ir", action="store_true") parser.add_argument( - "-p", help="Set the root path for contract imports", default=".", dest="root_folder" + "--path", "-p", help="Set the root path for contract imports", action="append", dest="paths" ) parser.add_argument("-o", help="Set the output path", dest="output_path") parser.add_argument( @@ -190,7 +190,7 @@ def _parse_args(argv): compiled = compile_files( args.input_files, output_formats, - args.root_folder, + args.paths, args.show_gas_estimates, settings, args.storage_layout, @@ -228,18 +228,23 @@ def exc_handler(contract_path: ContractPath, exception: Exception) -> None: def compile_files( input_files: list[str], output_formats: OutputFormats, - root_folder: str = ".", + paths: list[str] = None, show_gas_estimates: bool = False, settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, experimental_codegen: bool = False, ) -> dict: - root_path = Path(root_folder).resolve() - if not root_path.exists(): - raise FileNotFoundError(f"Invalid root path - '{root_path.as_posix()}' does not exist") + paths = paths or [] - input_bundle = FilesystemInputBundle([root_path]) + # lowest precedence search path is always `.` + search_paths = [Path(".")] + + for p in paths: + path = Path(p).resolve(strict=True) + search_paths.append(path) + + input_bundle = FilesystemInputBundle(search_paths) show_version = False if "combined_json" in output_formats: From 5a67b68b4ba20d050e9a4af913823cbbf0007539 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Wed, 20 Dec 2023 16:12:56 +0200 Subject: [PATCH 16/18] fix: type annotation of helper function (#3702) Fixed the signature of _append_return_for_stack_operand() to take the context not the basic block --- vyper/venom/ir_node_to_venom.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index e2ce28a8f9..0aaf6aba03 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -233,8 +233,9 @@ def _get_variable_from_address( def _append_return_for_stack_operand( - bb: IRBasicBlock, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable + ctx: IRFunction, symbols: SymbolTable, ret_ir: IRVariable, last_ir: IRVariable ) -> None: + bb = ctx.get_basic_block() if isinstance(ret_ir, IRLiteral): sym = symbols.get(f"&{ret_ir.value}", None) new_var = bb.append_instruction("alloca", 32, ret_ir) From 91659266c55ac564d1ed7784a189f5b59b868ced Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 14:41:53 -0500 Subject: [PATCH 17/18] chore: improve exception handling in IR generation (#3705) QOL improvement - improve unannotated exceptions that happen during IR generation to include source info. --- vyper/codegen/expr.py | 6 +++++- vyper/codegen/stmt.py | 11 +++++++++-- vyper/exceptions.py | 20 ++++++++++++++++++-- vyper/semantics/analysis/common.py | 13 +------------ 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 5870e64e98..d5ca5aceee 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -26,6 +26,7 @@ from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import ( + CodegenPanic, CompilerPanic, EvmVersionException, StructureException, @@ -33,6 +34,7 @@ TypeMismatch, UnimplementedException, VyperException, + tag_exceptions, ) from vyper.semantics.types import ( AddressT, @@ -79,7 +81,9 @@ def __init__(self, node, context): if fn is None: raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}", node) - self.ir_node = fn() + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + self.ir_node = fn() + if self.ir_node is None: raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index cc7a603b7c..601597771c 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -24,7 +24,13 @@ from vyper.codegen.expr import Expr from vyper.codegen.return_ import make_return_stmt from vyper.evm.address_space import MEMORY, STORAGE -from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure +from vyper.exceptions import ( + CodegenPanic, + CompilerPanic, + StructureException, + TypeCheckFailure, + tag_exceptions, +) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -39,7 +45,8 @@ def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}") with context.internal_memory_scope(): - self.ir_node = fn() + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + self.ir_node = fn() if self.ir_node is None: raise TypeCheckFailure("Statement node did not produce IR") diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 993c0a85eb..4846b1c3b1 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -1,3 +1,4 @@ +import contextlib import copy import textwrap import types @@ -322,8 +323,9 @@ class VyperInternalException(_BaseVyperException): def __str__(self): return ( - f"{self.message}\n\nThis is an unhandled internal compiler error. " - "Please create an issue on Github to notify the developers.\n" + f"{super().__str__()}\n\n" + "This is an unhandled internal compiler error. " + "Please create an issue on Github to notify the developers!\n" "https://github.com/vyperlang/vyper/issues/new?template=bug.md" ) @@ -354,3 +356,17 @@ class TypeCheckFailure(VyperInternalException): class InvalidABIType(VyperInternalException): """An internal routine constructed an invalid ABI type""" + + +@contextlib.contextmanager +def tag_exceptions( + node, fallback_exception_type=CompilerPanic, fallback_message="unhandled exception" +): + try: + yield + except _BaseVyperException as e: + if not e.annotations and not e.lineno: + raise e.with_annotation(node) from None + raise e from None + except Exception as e: + raise fallback_exception_type(fallback_message, node) from e diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 9d35aef2bd..198cffca5d 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -1,17 +1,6 @@ -import contextlib from typing import Tuple -from vyper.exceptions import StructureException, VyperException - - -@contextlib.contextmanager -def tag_exceptions(node): - try: - yield - except VyperException as e: - if not e.annotations and not e.lineno: - raise e.with_annotation(node) from None - raise e from None +from vyper.exceptions import StructureException, tag_exceptions class VyperNodeVisitorBase: From 3116e88c886efaf0ea4157852c6c90485357cee7 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Thu, 21 Dec 2023 02:41:42 +0200 Subject: [PATCH 18/18] feat: add new target-constrained jump instruction (#3687) this commit adds a new "djmp" instruction which allows jumping to one of multiple jump targets. it has been added in both the s-expr IR and venom IR. this removes the workarounds that we had to implement in the normalization pass and the cfg calculations. --------- Co-authored-by: Charles Cooper --- tests/unit/ast/test_pre_parser.py | 3 + .../compiler/venom/test_multi_entry_block.py | 41 +++++++++++ vyper/cli/vyper_compile.py | 7 +- vyper/codegen/core.py | 2 + vyper/codegen/module.py | 18 ++--- vyper/compiler/__init__.py | 2 - vyper/compiler/output.py | 3 + vyper/compiler/phases.py | 67 ++++++++--------- vyper/compiler/settings.py | 1 + vyper/ir/compile_ir.py | 7 ++ vyper/utils.py | 1 + vyper/venom/analysis.py | 9 --- vyper/venom/basicblock.py | 17 ++++- vyper/venom/function.py | 12 +--- vyper/venom/ir_node_to_venom.py | 9 +-- vyper/venom/passes/normalization.py | 71 ++++++------------- vyper/venom/venom_to_assembly.py | 15 ++-- 17 files changed, 158 insertions(+), 127 deletions(-) diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 3d072674f6..682c13ca84 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -184,6 +184,9 @@ def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_ve # None is sentinel here meaning that nothing changed compiler_data_settings = pre_parse_settings + # cannot be set via pragma, don't check + compiler_data_settings.experimental_codegen = False + assert compiler_data.settings == compiler_data_settings diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 6e7e6995d6..104697432b 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -95,3 +95,44 @@ def test_multi_entry_block_2(): assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" + + +def test_multi_entry_block_with_dynamic_jump(): + ctx = IRFunction() + + finish_label = IRLabel("finish") + target_label = IRLabel("target") + block_1_label = IRLabel("block_1", ctx) + + bb = ctx.get_basic_block() + op = bb.append_instruction("store", 10) + acc = bb.append_instruction("add", op, op) + bb.append_instruction("djmp", acc, finish_label, block_1_label) + + block_1 = IRBasicBlock(block_1_label, ctx) + ctx.append_basic_block(block_1) + acc = block_1.append_instruction("add", acc, op) + op = block_1.append_instruction("store", 10) + block_1.append_instruction("mstore", acc, op) + block_1.append_instruction("jnz", acc, finish_label, target_label) + + target_bb = IRBasicBlock(target_label, ctx) + ctx.append_basic_block(target_bb) + target_bb.append_instruction("mul", acc, acc) + target_bb.append_instruction("jmp", finish_label) + + finish_bb = IRBasicBlock(finish_label, ctx) + ctx.append_basic_block(finish_bb) + finish_bb.append_instruction("stop") + + calculate_cfg(ctx) + assert not ctx.normalized, "CFG should not be normalized" + + NormalizationPass.run_pass(ctx) + assert ctx.normalized, "CFG should be normalized" + + finish_bb = ctx.get_basic_block(finish_label.value) + cfg_in = list(finish_bb.cfg_in.keys()) + assert cfg_in[0].label.value == "target", "Should contain target" + assert cfg_in[1].label.value == "finish_split_global", "Should contain finish_split_global" + assert cfg_in[2].label.value == "finish_split_block_1", "Should contain finish_split_block_1" diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 25f1180098..3063a289ab 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -147,6 +147,7 @@ def _parse_args(argv): "--experimental-codegen", help="The compiler use the new IR codegen. This is an experimental feature.", action="store_true", + dest="experimental_codegen", ) args = parser.parse_args(argv) @@ -184,6 +185,9 @@ def _parse_args(argv): if args.evm_version: settings.evm_version = args.evm_version + if args.experimental_codegen: + settings.experimental_codegen = args.experimental_codegen + if args.verbose: print(f"cli specified: `{settings}`", file=sys.stderr) @@ -195,7 +199,6 @@ def _parse_args(argv): settings, args.storage_layout, args.no_bytecode_metadata, - args.experimental_codegen, ) if args.output_path: @@ -233,7 +236,6 @@ def compile_files( settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, - experimental_codegen: bool = False, ) -> dict: paths = paths or [] @@ -287,7 +289,6 @@ def compile_files( storage_layout_override=storage_layout_override, show_gas_estimates=show_gas_estimates, no_bytecode_metadata=no_bytecode_metadata, - experimental_codegen=experimental_codegen, ) ret[file_path] = output diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index e1d3ea12b4..503e0e2f3b 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -892,6 +892,8 @@ def make_setter(left, right): _opt_level = OptimizationLevel.GAS +# FIXME: this is to get around the fact that we don't have a +# proper context object in the IR generation phase. @contextlib.contextmanager def anchor_opt_level(new_level: OptimizationLevel) -> Generator: """ diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index ef861e3953..98395a6a0c 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -311,21 +311,23 @@ def _selector_section_sparse(external_functions, module_ctx): ret.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER]) - jumpdest = IRnode.from_list(["mload", 0]) - # don't particularly like using `jump` here since it can cause - # issues for other backends, consider changing `goto` to allow - # dynamic jumps, or adding some kind of jumptable instruction - ret.append(["jump", jumpdest]) + jump_targets = [] - jumptable_data = ["data", "selector_buckets"] for i in range(n_buckets): if i in buckets: bucket_label = f"selector_bucket_{i}" - jumptable_data.append(["symbol", bucket_label]) + jump_targets.append(bucket_label) else: # empty bucket - jumptable_data.append(["symbol", "fallback"]) + jump_targets.append("fallback") + + jumptable_data = ["data", "selector_buckets"] + jumptable_data.extend(["symbol", label] for label in jump_targets) + + jumpdest = IRnode.from_list(["mload", 0]) + jump_instr = IRnode.from_list(["djump", jumpdest, *jump_targets]) + ret.append(jump_instr) ret.append(jumptable_data) for bucket_id, bucket in buckets.items(): diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 026c8369c5..c87814ba15 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -53,7 +53,6 @@ def compile_from_file_input( no_bytecode_metadata: bool = False, show_gas_estimates: bool = False, exc_handler: Optional[Callable] = None, - experimental_codegen: bool = False, ) -> dict: """ Main entry point into the compiler. @@ -107,7 +106,6 @@ def compile_from_file_input( storage_layout_override, show_gas_estimates, no_bytecode_metadata, - experimental_codegen, ) ret = {} diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 6d1e7ef70f..dc2a43720e 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -89,6 +89,9 @@ def build_ir_runtime_output(compiler_data: CompilerData) -> IRnode: def _ir_to_dict(ir_node): + # Currently only supported with IRnode and not VenomIR + if not isinstance(ir_node, IRnode): + return args = ir_node.args if len(args) > 0 or ir_node.value == "seq": return {ir_node.value: [_ir_to_dict(x) for x in args]} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index edffa9a85e..199bbbc3e5 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -21,6 +21,26 @@ DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") +def _merge_one(lhs, rhs, helpstr): + if lhs is not None and rhs is not None and lhs != rhs: + raise StructureException( + f"compiler settings indicate {helpstr} {lhs}, " f"but source pragma indicates {rhs}." + ) + return lhs if rhs is None else rhs + + +# TODO: does this belong as a method under Settings? +def _merge_settings(cli: Settings, pragma: Settings): + ret = Settings() + ret.evm_version = _merge_one(cli.evm_version, pragma.evm_version, "evm version") + ret.optimize = _merge_one(cli.optimize, pragma.optimize, "optimize") + ret.experimental_codegen = _merge_one( + cli.experimental_codegen, pragma.experimental_codegen, "experimental codegen" + ) + + return ret + + class CompilerData: """ Object for fetching and storing compiler data for a Vyper contract. @@ -59,7 +79,6 @@ def __init__( storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, - experimental_codegen: bool = False, ) -> None: """ Initialization method. @@ -76,11 +95,9 @@ def __init__( Show gas estimates for abi and ir output modes no_bytecode_metadata: bool, optional Do not add metadata to bytecode. Defaults to False - experimental_codegen: bool, optional - Use experimental codegen. Defaults to False """ # to force experimental codegen, uncomment: - # experimental_codegen = True + # settings.experimental_codegen = True if isinstance(file_input, str): file_input = FileInput( @@ -93,7 +110,6 @@ def __init__( self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - self.experimental_codegen = experimental_codegen self.settings = settings or Settings() self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) @@ -120,32 +136,13 @@ def _generate_ast(self): resolved_path=str(self.file_input.resolved_path), ) - # validate the compiler settings - # XXX: this is a bit ugly, clean up later - if settings.evm_version is not None: - if ( - self.settings.evm_version is not None - and self.settings.evm_version != settings.evm_version - ): - raise StructureException( - f"compiler settings indicate evm version {self.settings.evm_version}, " - f"but source pragma indicates {settings.evm_version}." - ) - - self.settings.evm_version = settings.evm_version - - if settings.optimize is not None: - if self.settings.optimize is not None and self.settings.optimize != settings.optimize: - raise StructureException( - f"compiler options indicate optimization mode {self.settings.optimize}, " - f"but source pragma indicates {settings.optimize}." - ) - self.settings.optimize = settings.optimize - - # ensure defaults + self.settings = _merge_settings(self.settings, settings) if self.settings.optimize is None: self.settings.optimize = OptimizationLevel.default() + if self.settings.experimental_codegen is None: + self.settings.experimental_codegen = False + # note self.settings.compiler_version is erased here as it is # not used after pre-parsing return ast @@ -184,8 +181,10 @@ def global_ctx(self) -> ModuleT: @cached_property def _ir_output(self): # fetch both deployment and runtime IR - nodes = generate_ir_nodes(self.global_ctx, self.settings.optimize) - if self.experimental_codegen: + nodes = generate_ir_nodes( + self.global_ctx, self.settings.optimize, self.settings.experimental_codegen + ) + if self.settings.experimental_codegen: return [generate_ir(nodes[0]), generate_ir(nodes[1])] else: return nodes @@ -211,7 +210,7 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: @cached_property def assembly(self) -> list: - if self.experimental_codegen: + if self.settings.experimental_codegen: return generate_assembly_experimental( self.ir_nodes, self.settings.optimize # type: ignore ) @@ -220,7 +219,7 @@ def assembly(self) -> list: @cached_property def assembly_runtime(self) -> list: - if self.experimental_codegen: + if self.settings.experimental_codegen: return generate_assembly_experimental( self.ir_runtime, self.settings.optimize # type: ignore ) @@ -294,7 +293,9 @@ def generate_folded_ast( return vyper_module_folded, symbol_tables -def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: +def generate_ir_nodes( + global_ctx: ModuleT, optimize: OptimizationLevel, experimental_codegen: bool +) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index d2c88a8592..51c8d64e41 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -42,6 +42,7 @@ class Settings: compiler_version: Optional[str] = None optimize: Optional[OptimizationLevel] = None evm_version: Optional[str] = None + experimental_codegen: Optional[bool] = None _DEBUG = False diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 1d3df8becb..8ce8c887f1 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -702,6 +702,13 @@ def _height_of(witharg): o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) o.extend(["_sym_" + code.args[0].value, "JUMP"]) return o + elif code.value == "djump": + o = [] + # "djump" compiles to a raw EVM jump instruction + jump_target = code.args[0] + o.extend(_compile_to_assembly(jump_target, withargs, existing_labels, break_dest, height)) + o.append("JUMP") + return o # push a literal symbol elif code.value == "symbol": return ["_sym_" + code.args[0].value] diff --git a/vyper/utils.py b/vyper/utils.py index 6816db9bae..a778a4e31b 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -331,6 +331,7 @@ class SizeLimits: "with", "label", "goto", + "djump", # "dynamic jump", i.e. constrained, multi-destination jump "~extcode", "~selfcode", "~calldata", diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 1a82ca85d0..6dfc3c3d7c 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -40,15 +40,6 @@ def calculate_cfg(ctx: IRFunction) -> None: else: entry_block = ctx.basic_blocks[0] - # TODO: Special case for the jump table of selector buckets and fallback. - # this will be cleaner when we introduce an "indirect jump" instruction - # for the selector table (which includes all possible targets). it will - # also clean up the code for normalization because it will not have to - # handle this case specially. - for bb in ctx.basic_blocks: - if "selector_bucket_" in bb.label.value or bb.label.value == "fallback": - bb.add_cfg_in(entry_block) - for bb in ctx.basic_blocks: assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 6f1c1c8ab3..9afaa5e6fd 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -4,7 +4,7 @@ from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "deploy", "stop"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -50,12 +50,15 @@ "invalid", "invoke", "jmp", + "djmp", "jnz", "log", ] ) -CFG_ALTERING_INSTRUCTIONS = frozenset(["jmp", "jnz", "call", "staticcall", "invoke", "deploy"]) +CFG_ALTERING_INSTRUCTIONS = frozenset( + ["jmp", "djmp", "jnz", "call", "staticcall", "invoke", "deploy"] +) if TYPE_CHECKING: from vyper.venom.function import IRFunction @@ -236,6 +239,16 @@ def replace_operands(self, replacements: dict) -> None: if operand in replacements: self.operands[i] = replacements[operand] + def replace_label_operands(self, replacements: dict) -> None: + """ + Update label operands with replacements. + replacements are represented using a dict: "key" is replaced by "value". + """ + replacements = {k.value: v for k, v in replacements.items()} + for i, operand in enumerate(self.operands): + if isinstance(operand, IRLabel) and operand.value in replacements: + self.operands[i] = replacements[operand.value] + def __repr__(self) -> str: s = "" if self.output: diff --git a/vyper/venom/function.py b/vyper/venom/function.py index e16b2ad6e6..665fa0c6c2 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -125,17 +125,11 @@ def normalized(self) -> bool: # TODO: this check could be: # `if len(in_bb.cfg_out) > 1: return False` # but the cfg is currently not calculated "correctly" for - # certain special instructions (deploy instruction and - # selector table indirect jumps). + # the special deploy instruction. for in_bb in bb.cfg_in: jump_inst = in_bb.instructions[-1] - if jump_inst.opcode != "jnz": - continue - if jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRLabel): - continue - - # The function is not normalized - return False + if jump_inst.opcode in ("jnz", "djmp"): + return False # The function is normalized return True diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 0aaf6aba03..9f5c23df0b 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -166,7 +166,6 @@ def _handle_self_call( ret_args.append(return_buf.value) # type: ignore bb = ctx.get_basic_block() - do_ret = func_t.return_type is not None if do_ret: invoke_ret = bb.append_invoke_instruction(ret_args, returns=True) # type: ignore @@ -453,9 +452,11 @@ def _convert_ir_basicblock(ctx, ir, symbols, variables, allocated_variables): ) # body elif ir.value == "goto": _append_jmp(ctx, IRLabel(ir.args[0].value)) - elif ir.value == "jump": - arg_1 = _convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables) - ctx.get_basic_block().append_instruction("jmp", arg_1) + elif ir.value == "djump": + args = [_convert_ir_basicblock(ctx, ir.args[0], symbols, variables, allocated_variables)] + for target in ir.args[1:]: + args.append(IRLabel(target.value)) + ctx.get_basic_block().append_instruction("djmp", *args) _new_block(ctx) elif ir.value == "set": sym = ir.args[0] diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 90dd60e881..43e8d47235 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -1,5 +1,5 @@ -from vyper.exceptions import CompilerPanic -from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable +from vyper.venom.analysis import calculate_cfg +from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -19,72 +19,43 @@ def _split_basic_block(self, bb: IRBasicBlock) -> None: jump_inst = in_bb.instructions[-1] assert bb in in_bb.cfg_out - # Handle static and dynamic branching - if jump_inst.opcode == "jnz": - self._split_for_static_branch(bb, in_bb) - elif jump_inst.opcode == "jmp" and isinstance(jump_inst.operands[0], IRVariable): - self._split_for_dynamic_branch(bb, in_bb) - else: - continue - - self.changes += 1 - - def _split_for_static_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: - jump_inst = in_bb.instructions[-1] - for i, op in enumerate(jump_inst.operands): - if op == bb.label: - edge = i + # Handle branching + if jump_inst.opcode in ("jnz", "djmp"): + self._insert_split_basicblock(bb, in_bb) + self.changes += 1 break - else: - # none of the edges points to this bb - raise CompilerPanic("bad CFG") - - assert edge in (1, 2) # the arguments which can be labels - - split_bb = self._insert_split_basicblock(bb, in_bb) - - # Redirect the original conditional jump to the intermediary basic block - jump_inst.operands[edge] = split_bb.label - - def _split_for_dynamic_branch(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> None: - split_bb = self._insert_split_basicblock(bb, in_bb) - - # Update any affected labels in the data segment - # TODO: this DESTROYS the cfg! refactor so the translation of the - # selector table produces indirect jumps properly. - for inst in self.ctx.data_segment: - if inst.opcode == "db" and inst.operands[0] == bb.label: - inst.operands[0] = split_bb.label def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRBasicBlock: # Create an intermediary basic block and append it source = in_bb.label.value target = bb.label.value - split_bb = IRBasicBlock(IRLabel(f"{target}_split_{source}"), self.ctx) + + split_label = IRLabel(f"{target}_split_{source}") + in_terminal = in_bb.instructions[-1] + in_terminal.replace_label_operands({bb.label: split_label}) + + split_bb = IRBasicBlock(split_label, self.ctx) split_bb.append_instruction("jmp", bb.label) self.ctx.append_basic_block(split_bb) - # Rewire the CFG - # TODO: this is cursed code, it is necessary instead of just running - # calculate_cfg() because split_for_dynamic_branch destroys the CFG! - # ideally, remove this rewiring and just re-run calculate_cfg(). - split_bb.add_cfg_in(in_bb) - split_bb.add_cfg_out(bb) - in_bb.remove_cfg_out(bb) - in_bb.add_cfg_out(split_bb) - bb.remove_cfg_in(in_bb) - bb.add_cfg_in(split_bb) + # Update the labels in the data segment + for inst in self.ctx.data_segment: + if inst.opcode == "db" and inst.operands[0] == bb.label: + inst.operands[0] = split_bb.label + return split_bb def _run_pass(self, ctx: IRFunction) -> int: self.ctx = ctx self.changes = 0 + # Split blocks that need splitting for bb in ctx.basic_blocks: if len(bb.cfg_in) > 1: self._split_basic_block(bb) - # Sanity check - assert ctx.normalized, "Normalization pass failed" + # If we made changes, recalculate the cfg + if self.changes > 0: + calculate_cfg(ctx) return self.changes diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 8760e9aa63..0c32c3b816 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -261,7 +261,7 @@ def _generate_evm_for_instruction( # Step 1: Apply instruction special stack manipulations - if opcode in ["jmp", "jnz", "invoke"]: + if opcode in ["jmp", "djmp", "jnz", "invoke"]: operands = inst.get_non_label_operands() elif opcode == "alloca": operands = inst.operands[1:2] @@ -296,7 +296,7 @@ def _generate_evm_for_instruction( self._emit_input_operands(assembly, inst, operands, stack) # Step 3: Reorder stack - if opcode in ["jnz", "jmp"]: + if opcode in ["jnz", "djmp", "jmp"]: # prepare stack for jump into another basic block assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) b = next(iter(inst.parent.cfg_out)) @@ -344,11 +344,12 @@ def _generate_evm_for_instruction( assembly.append("JUMP") elif opcode == "jmp": - if isinstance(inst.operands[0], IRLabel): - assembly.append(f"_sym_{inst.operands[0].value}") - assembly.append("JUMP") - else: - assembly.append("JUMP") + assert isinstance(inst.operands[0], IRLabel) + assembly.append(f"_sym_{inst.operands[0].value}") + assembly.append("JUMP") + elif opcode == "djmp": + assert isinstance(inst.operands[0], IRVariable) + assembly.append("JUMP") elif opcode == "gt": assembly.append("GT") elif opcode == "lt":