From 3af5390001e59ba767378047add0df5e26193d9f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 7 May 2024 16:22:20 -0400 Subject: [PATCH 1/3] refactor[test]: change fixture scope in examples (#3995) roughly 5x performance increase per CPU in the `tests/functional/examples/` directory (testing locally: 27s -> 7s) --- tests/evm_backends/pyevm_env.py | 5 +++- tests/evm_backends/revm_env.py | 3 +++ .../examples/auctions/test_blind_auction.py | 2 +- .../auctions/test_simple_open_auction.py | 4 ++-- .../examples/company/test_company.py | 2 +- .../crowdfund/test_crowdfund_example.py | 2 +- .../examples/factory/test_factory.py | 6 ++--- .../test_on_chain_market_maker.py | 17 +++++++------- .../test_safe_remote_purchase.py | 4 ++-- .../examples/storage/test_advanced_storage.py | 2 +- .../examples/storage/test_storage.py | 2 +- .../examples/tokens/test_erc1155.py | 2 +- .../functional/examples/tokens/test_erc20.py | 4 ++-- .../examples/tokens/test_erc4626.py | 4 ++-- .../functional/examples/tokens/test_erc721.py | 2 +- .../functional/examples/voting/test_ballot.py | 2 +- .../functional/examples/wallet/test_wallet.py | 23 ++++++++----------- 17 files changed, 44 insertions(+), 42 deletions(-) diff --git a/tests/evm_backends/pyevm_env.py b/tests/evm_backends/pyevm_env.py index 6638308ff9..6c510278a7 100644 --- a/tests/evm_backends/pyevm_env.py +++ b/tests/evm_backends/pyevm_env.py @@ -1,3 +1,4 @@ +import copy import logging from contextlib import contextmanager from typing import Optional @@ -65,7 +66,7 @@ def _state(self) -> StateAPI: def _vm(self) -> VirtualMachineAPI: return self._chain.get_vm() - @cached_property + @property def _context(self) -> ExecutionContext: context = self._state.execution_context assert isinstance(context, ExecutionContext) # help mypy @@ -74,10 +75,12 @@ def _context(self) -> ExecutionContext: @contextmanager def anchor(self): snapshot_id = self._state.snapshot() + ctx = copy.copy(self._state.execution_context) try: yield finally: self._state.revert(snapshot_id) + self._state.execution_context = ctx def get_balance(self, address: str) -> int: return self._state.get_balance(_addr(address)) diff --git a/tests/evm_backends/revm_env.py b/tests/evm_backends/revm_env.py index c23a74e158..5c8b8aba08 100644 --- a/tests/evm_backends/revm_env.py +++ b/tests/evm_backends/revm_env.py @@ -31,6 +31,7 @@ def __init__( @contextmanager def anchor(self): snapshot_id = self._evm.snapshot() + block = BlockEnv(number=self._evm.env.block.number, timestamp=self._evm.env.block.timestamp) try: yield finally: @@ -40,6 +41,8 @@ def anchor(self): # snapshot_id is reverted by the transaction already. # revm updates are needed to make the journal more robust. pass + self._evm.set_block_env(block) + # self._evm.set_tx_env(tx) def get_balance(self, address: str) -> int: return self._evm.get_balance(address) diff --git a/tests/functional/examples/auctions/test_blind_auction.py b/tests/functional/examples/auctions/test_blind_auction.py index 06f0656f1d..eda84e1217 100644 --- a/tests/functional/examples/auctions/test_blind_auction.py +++ b/tests/functional/examples/auctions/test_blind_auction.py @@ -9,7 +9,7 @@ TEST_INCREMENT = 1 -@pytest.fixture +@pytest.fixture(scope="module") def auction_contract(env, get_contract): with open("examples/auctions/blind_auction.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/auctions/test_simple_open_auction.py b/tests/functional/examples/auctions/test_simple_open_auction.py index 430294fa79..68b208a9b8 100644 --- a/tests/functional/examples/auctions/test_simple_open_auction.py +++ b/tests/functional/examples/auctions/test_simple_open_auction.py @@ -5,12 +5,12 @@ EXPIRY = 16 -@pytest.fixture +@pytest.fixture(scope="module") def auction_start(env): return env.timestamp + 1 -@pytest.fixture +@pytest.fixture(scope="module") def auction_contract(env, get_contract, auction_start): with open("examples/auctions/simple_open_auction.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/company/test_company.py b/tests/functional/examples/company/test_company.py index 35b4951471..e302735d7c 100644 --- a/tests/functional/examples/company/test_company.py +++ b/tests/functional/examples/company/test_company.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture +@pytest.fixture(scope="module") def c(env, get_contract): with open("examples/stock/company.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/crowdfund/test_crowdfund_example.py b/tests/functional/examples/crowdfund/test_crowdfund_example.py index ff0d85d61e..510dd80c82 100644 --- a/tests/functional/examples/crowdfund/test_crowdfund_example.py +++ b/tests/functional/examples/crowdfund/test_crowdfund_example.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture +@pytest.fixture(scope="module") def c(env, get_contract): with open("examples/crowdfund.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/factory/test_factory.py b/tests/functional/examples/factory/test_factory.py index ecfc0bf557..5964d70478 100644 --- a/tests/functional/examples/factory/test_factory.py +++ b/tests/functional/examples/factory/test_factory.py @@ -4,7 +4,7 @@ import vyper -@pytest.fixture +@pytest.fixture(scope="module") def create_token(get_contract): with open("examples/tokens/ERC20.vy") as f: code = f.read() @@ -15,7 +15,7 @@ def create_token(): return create_token -@pytest.fixture +@pytest.fixture(scope="module") def create_exchange(env, get_contract): with open("examples/factory/Exchange.vy") as f: code = f.read() @@ -29,7 +29,7 @@ def create_exchange(token, factory): return create_exchange -@pytest.fixture +@pytest.fixture(scope="module") def factory(get_contract): with open("examples/factory/Exchange.vy") as f: code = f.read() diff --git a/tests/functional/examples/market_maker/test_on_chain_market_maker.py b/tests/functional/examples/market_maker/test_on_chain_market_maker.py index 071afce5d6..9dddc37ceb 100644 --- a/tests/functional/examples/market_maker/test_on_chain_market_maker.py +++ b/tests/functional/examples/market_maker/test_on_chain_market_maker.py @@ -3,14 +3,6 @@ from tests.utils import ZERO_ADDRESS - -@pytest.fixture -def market_maker(get_contract): - with open("examples/market_maker/on_chain_market_maker.vy") as f: - contract_code = f.read() - return get_contract(contract_code) - - TOKEN_NAME = "Vypercoin" TOKEN_SYMBOL = "FANG" TOKEN_DECIMALS = 18 @@ -18,7 +10,7 @@ def market_maker(get_contract): TOKEN_TOTAL_SUPPLY = TOKEN_INITIAL_SUPPLY * (10**TOKEN_DECIMALS) -@pytest.fixture +@pytest.fixture(scope="module") def erc20(get_contract): with open("examples/tokens/ERC20.vy") as f: contract_code = f.read() @@ -27,6 +19,13 @@ def erc20(get_contract): ) +@pytest.fixture(scope="module") +def market_maker(get_contract, erc20): + with open("examples/market_maker/on_chain_market_maker.vy") as f: + contract_code = f.read() + return get_contract(contract_code) + + def test_initial_state(market_maker): assert market_maker.totalEthQty() == 0 assert market_maker.totalTokenQty() == 0 diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index bb89375530..c4cfdc29eb 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -15,14 +15,14 @@ from eth_utils import to_wei -@pytest.fixture +@pytest.fixture(scope="module") def contract_code(get_contract): with open("examples/safe_remote_purchase/safe_remote_purchase.vy") as f: contract_code = f.read() return contract_code -@pytest.fixture +@pytest.fixture(scope="module") def get_balance(env): def get_balance(): a0, a1 = env.accounts[:2] diff --git a/tests/functional/examples/storage/test_advanced_storage.py b/tests/functional/examples/storage/test_advanced_storage.py index 51e5a1729e..4a41cb415c 100644 --- a/tests/functional/examples/storage/test_advanced_storage.py +++ b/tests/functional/examples/storage/test_advanced_storage.py @@ -4,7 +4,7 @@ INITIAL_VALUE = 4 -@pytest.fixture +@pytest.fixture(scope="module") def adv_storage_contract(get_contract): with open("examples/storage/advanced_storage.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/storage/test_storage.py b/tests/functional/examples/storage/test_storage.py index cdb71c5810..631bdc4dbe 100644 --- a/tests/functional/examples/storage/test_storage.py +++ b/tests/functional/examples/storage/test_storage.py @@ -3,7 +3,7 @@ INITIAL_VALUE = 4 -@pytest.fixture +@pytest.fixture(scope="module") def storage_contract(get_contract): with open("examples/storage/storage.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/tokens/test_erc1155.py b/tests/functional/examples/tokens/test_erc1155.py index afbfa8d56d..0a51c115bb 100644 --- a/tests/functional/examples/tokens/test_erc1155.py +++ b/tests/functional/examples/tokens/test_erc1155.py @@ -29,7 +29,7 @@ mintConflictBatch = [1, 2, 3] -@pytest.fixture +@pytest.fixture(scope="module") def erc1155(get_contract, env, tx_failed): owner, a1, a2, a3, a4, a5 = env.accounts[0:6] with open("examples/tokens/ERC1155ownable.vy") as f: diff --git a/tests/functional/examples/tokens/test_erc20.py b/tests/functional/examples/tokens/test_erc20.py index aef43768cb..b3dc2fe238 100644 --- a/tests/functional/examples/tokens/test_erc20.py +++ b/tests/functional/examples/tokens/test_erc20.py @@ -13,14 +13,14 @@ TOKEN_INITIAL_SUPPLY = 0 -@pytest.fixture +@pytest.fixture(scope="module") def c(get_contract): with open("examples/tokens/ERC20.vy") as f: code = f.read() return get_contract(code, *[TOKEN_NAME, TOKEN_SYMBOL, TOKEN_DECIMALS, TOKEN_INITIAL_SUPPLY]) -@pytest.fixture +@pytest.fixture(scope="module") def c_bad(get_contract): # Bad contract is used for overflow checks on totalSupply corrupted with open("examples/tokens/ERC20.vy") as f: diff --git a/tests/functional/examples/tokens/test_erc4626.py b/tests/functional/examples/tokens/test_erc4626.py index f0fb79efae..f6ff71f51a 100644 --- a/tests/functional/examples/tokens/test_erc4626.py +++ b/tests/functional/examples/tokens/test_erc4626.py @@ -7,7 +7,7 @@ TOKEN_INITIAL_SUPPLY = 0 -@pytest.fixture +@pytest.fixture(scope="module") def token(get_contract): with open("examples/tokens/ERC20.vy") as f: return get_contract( @@ -15,7 +15,7 @@ def token(get_contract): ) -@pytest.fixture +@pytest.fixture(scope="module") def vault(get_contract, token): with open("examples/tokens/ERC4626.vy") as f: return get_contract(f.read(), token.address) diff --git a/tests/functional/examples/tokens/test_erc721.py b/tests/functional/examples/tokens/test_erc721.py index 1ed26f64dc..3c1c5e71f9 100644 --- a/tests/functional/examples/tokens/test_erc721.py +++ b/tests/functional/examples/tokens/test_erc721.py @@ -11,7 +11,7 @@ ERC721_SIG = "0x80ac58cd" -@pytest.fixture +@pytest.fixture(scope="module") def c(get_contract, env): with open("examples/tokens/ERC721.vy") as f: code = f.read() diff --git a/tests/functional/examples/voting/test_ballot.py b/tests/functional/examples/voting/test_ballot.py index 2135feff72..9c82c5156b 100644 --- a/tests/functional/examples/voting/test_ballot.py +++ b/tests/functional/examples/voting/test_ballot.py @@ -6,7 +6,7 @@ PROPOSAL_2_NAME = b"Trump" + b"\x00" * 27 -@pytest.fixture +@pytest.fixture(scope="module") def c(get_contract): with open("examples/voting/ballot.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/wallet/test_wallet.py b/tests/functional/examples/wallet/test_wallet.py index c639974a31..6dfb838d8a 100644 --- a/tests/functional/examples/wallet/test_wallet.py +++ b/tests/functional/examples/wallet/test_wallet.py @@ -5,9 +5,10 @@ from eth_utils import is_same_address, to_bytes, to_checksum_address, to_int from tests.utils import ZERO_ADDRESS +from vyper.utils import keccak256 -@pytest.fixture +@pytest.fixture(scope="module") def c(env, get_contract): a0, a1, a2, a3, a4, a5, a6 = env.accounts[:7] with open("examples/wallet/wallet.vy") as f: @@ -19,20 +20,16 @@ def c(env, get_contract): return c -@pytest.fixture -def sign(keccak): - def _sign(seq, to, value, data, key): - keys = KeyAPI() - comb = seq.to_bytes(32, "big") + b"\x00" * 12 + to + value.to_bytes(32, "big") + data - h1 = keccak(comb) - h2 = keccak(b"\x19Ethereum Signed Message:\n32" + h1) - sig = keys.ecdsa_sign(h2, key) - return [28 if sig.v == 1 else 27, sig.r, sig.s] +def sign(seq, to, value, data, key): + keys = KeyAPI() + comb = seq.to_bytes(32, "big") + b"\x00" * 12 + to + value.to_bytes(32, "big") + data + h1 = keccak256(comb) + h2 = keccak256(b"\x19Ethereum Signed Message:\n32" + h1) + sig = keys.ecdsa_sign(h2, key) + return [28 if sig.v == 1 else 27, sig.r, sig.s] - return _sign - -def test_approve(env, c, tx_failed, sign): +def test_approve(env, c, tx_failed): a0, a1, a2, a3, a4, a5, a6 = env.accounts[:7] k0, k1, k2, k3, k4, k5, k6, k7 = env._keys[:8] env.set_balance(a1, 10**18) From 75c75c5631222dd1b98c23f8cfeedc080e47a9e3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 8 May 2024 08:37:15 -0400 Subject: [PATCH 2/3] feat[tool]: archive format (#3891) this commit adds several output formats to aid with build reproducibility and source code verification: - `-f archive` - `-f solc_json` - `-f integrity` - `--base64` `-f archive` creates a "vyper archive" using the zipfile format. it emits the metadata associated with the build (settings, search path, compiler version, integrity hash) in the `MANIFEST/` folder inside the archive. `--base64` is only usable with `-f archive` and produces a base64-encoded archive (which is easier to copy-paste). both the base64 and binary versions of the archive round-trip. that is, if you provide an archive directly to the vyper compiler (e.g. `vyper contract.zip` or `vyper contract.zip.b64`), it should produce exactly the same output as running `vyper contract.vy` on the local machine with the same settings+environment that produced the archive. `-f solc_json` is for compatibility with standard json that a lot of verifiers and tooling currently uses. it uses the same "output bundle" machinery as `-f archive`, but spits out in "standard-json" format (consumable by `--standard-json`). both of these use an `OutputBundle` abstraction, which abstracts collecting the inputs to the build. these include - settings (whatever is on the Settings object) - search path - compiler version - integrity hash importantly, `OutputBundle` recovers and anonymizes search paths used during compilation. this is done to minimize leaking of user information in archives. however, it comes with a tradeoff -- because of how the anonymization works, it is possible to have a build where search paths are not recoverable (specifically, if a module "escapes" its package with too many `..`, the resulting anonymized path will be bogus). several methods were tried to prevent this, but in the end this method was chosen, which prioritizes minimizing leakage over handling edge cases. `-f integrity` produces an "integrity hash", which is basically a checksum over the source file inputs. it is intended to let consumers of the compiler know when any input in the dependency tree has changed and recompilation is necessary. it is conservative by design; it works by recursively hashing source text as opposed to AST or any other semantic representation of source code. it can also be used by tooling as a check to determine if the source tree in an archive is the same as expected. this would likely be an additional check in addition to bytecode comparison, since there could differences in source code (e.g. comments) which affect the integrity hash but not the bytecode. the integrity hash computation currently depends on all frontend analysis to complete. in theory, since it only depends on source code, it could be refactored into another preliminary pass in the compiler, whose sole job is to resolve (and hash) imports. however, it would be additional maintenance work. we could revisit if the performance of this method becomes reported as an issue (note: current numbers are that this method operates at roughly 2500 lloc per second). currently, there are two places where build reproducibility might fail - in checking the integrity hash of an archive or during archive construction itself (if there is a compile-time failure, this could happen for example if the user is trying to send a reproduction of an error). it was decided that the most user-friendly thing to do is to emit a warning in these cases, rather than adding additional compilation flags that control whether to bail out or not. the tentative canonical suffix for vyper archive (the zipfile version) is `.vyz`, although this is subject to change (several alternatives were also considered, including `.den` - as in "a den of vypers"!). --- tests/conftest.py | 9 +- .../cli/vyper_compile/test_compile_files.py | 182 ++++++++++-- .../unit/cli/vyper_json/test_compile_json.py | 5 - tests/unit/compiler/test_input_bundle.py | 21 +- tests/unit/compiler/test_pre_parser.py | 10 +- vyper/__init__.py | 3 + vyper/cli/compile_archive.py | 70 +++++ vyper/cli/vyper_compile.py | 81 +++++- vyper/cli/vyper_json.py | 27 +- vyper/compiler/__init__.py | 17 +- vyper/compiler/input_bundle.py | 69 ++++- vyper/compiler/output.py | 42 +++ vyper/compiler/output_bundle.py | 260 ++++++++++++++++++ vyper/compiler/phases.py | 76 ++--- vyper/compiler/settings.py | 73 +++++ vyper/exceptions.py | 4 + vyper/semantics/analysis/base.py | 8 +- vyper/semantics/analysis/module.py | 6 +- vyper/semantics/types/module.py | 37 ++- vyper/utils.py | 6 + 20 files changed, 869 insertions(+), 137 deletions(-) create mode 100644 vyper/cli/compile_archive.py create mode 100644 vyper/compiler/output_bundle.py diff --git a/tests/conftest.py b/tests/conftest.py index f330ca2911..59bb76c493 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,10 +56,11 @@ def pytest_addoption(parser): @pytest.fixture(scope="module") def output_formats(): output_formats = compiler.OUTPUT_FORMATS.copy() - del output_formats["bb"] - del output_formats["bb_runtime"] - del output_formats["cfg"] - del output_formats["cfg_runtime"] + + to_drop = ("bb", "bb_runtime", "cfg", "cfg_runtime", "archive", "archive_b64", "solc_json") + for s in to_drop: + del output_formats[s] + return output_formats diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index c697f2bc98..6467ff6dc9 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -1,14 +1,19 @@ import contextlib import sys +import zipfile from pathlib import Path import pytest -from tests.utils import working_directory from vyper.cli.vyper_compile import compile_files +from vyper.cli.vyper_json import compile_json +from vyper.compiler.input_bundle import FilesystemInputBundle +from vyper.compiler.output_bundle import OutputBundle +from vyper.compiler.phases import CompilerData +from vyper.utils import sha256sum -def test_combined_json_keys(tmp_path, make_file): +def test_combined_json_keys(chdir_tmp_path, make_file): make_file("bar.vy", "") combined_keys = { @@ -22,7 +27,7 @@ def test_combined_json_keys(tmp_path, make_file): "userdoc", "devdoc", } - compile_data = compile_files(["bar.vy"], ["combined_json"], paths=[tmp_path]) + compile_data = compile_files(["bar.vy"], ["combined_json"]) assert set(compile_data.keys()) == {Path("bar.vy"), "version"} assert set(compile_data[Path("bar.vy")].keys()) == combined_keys @@ -72,12 +77,12 @@ def bar() -> FooStruct: @pytest.mark.parametrize("import_stmt,alias", SAME_FOLDER_IMPORT_STMT) -def test_import_same_folder(import_stmt, alias, tmp_path, make_file): +def test_import_same_folder(import_stmt, alias, chdir_tmp_path, make_file): foo = "contracts/foo.vy" make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) make_file("contracts/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], paths=[tmp_path]) + assert compile_files([foo], ["combined_json"]) is not None SUBFOLDER_IMPORT_STMT = [ @@ -95,13 +100,13 @@ def test_import_same_folder(import_stmt, alias, tmp_path, make_file): @pytest.mark.parametrize("import_stmt, alias", SUBFOLDER_IMPORT_STMT) -def test_import_subfolder(import_stmt, alias, tmp_path, make_file): +def test_import_subfolder(import_stmt, alias, chdir_tmp_path, make_file): foo = make_file( "contracts/foo.vy", (CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) ) make_file("contracts/other/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], paths=[tmp_path]) + assert compile_files([foo], ["combined_json"]) is not None OTHER_FOLDER_IMPORT_STMT = [ @@ -118,7 +123,7 @@ def test_import_other_folder(import_stmt, alias, tmp_path, make_file): foo = make_file("contracts/foo.vy", CONTRACT_CODE.format(import_stmt=import_stmt, alias=alias)) make_file("interfaces/IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], paths=[tmp_path]) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) is not None def test_import_parent_folder(tmp_path, make_file): @@ -128,21 +133,20 @@ def test_import_parent_folder(tmp_path, make_file): ) make_file("IFoo.vyi", INTERFACE_CODE) - assert compile_files([foo], ["combined_json"], paths=[tmp_path]) + assert compile_files([foo], ["combined_json"], paths=[tmp_path]) is not None # perform relative import outside of base folder compile_files([foo], ["combined_json"], paths=[tmp_path / "contracts"]) -def test_import_search_paths(tmp_path, make_file): - with working_directory(tmp_path): - contract_code = CONTRACT_CODE.format(import_stmt="from utils import IFoo", alias="IFoo") - contract_filename = "dir1/baz/foo.vy" - interface_filename = "dir2/utils/IFoo.vyi" - make_file(interface_filename, INTERFACE_CODE) - make_file(contract_filename, contract_code) +def test_import_search_paths(chdir_tmp_path, make_file): + contract_code = CONTRACT_CODE.format(import_stmt="from utils import IFoo", alias="IFoo") + contract_filename = "dir1/baz/foo.vy" + interface_filename = "dir2/utils/IFoo.vyi" + make_file(interface_filename, INTERFACE_CODE) + make_file(contract_filename, contract_code) - assert compile_files([contract_filename], ["combined_json"], paths=["dir2"]) + assert compile_files([contract_filename], ["combined_json"], paths=["dir2"]) is not None META_IMPORT_STMT = [ @@ -181,7 +185,7 @@ def be_known() -> ISelf.FooStruct: make_file("contracts/ISelf.vyi", interface_code) meta = make_file("contracts/Self.vy", code) - assert compile_files([meta], ["combined_json"], paths=[tmp_path]) + assert compile_files([meta], ["combined_json"], paths=[tmp_path]) is not None # implement IFoo in another contract for fun @@ -201,10 +205,10 @@ def bar(_foo: address) -> {alias}.FooStruct: make_file("contracts/IFoo.vyi", INTERFACE_CODE) baz = make_file("contracts/Baz.vy", baz_code) - assert compile_files([baz], ["combined_json"], paths=[tmp_path]) + assert compile_files([baz], ["combined_json"], paths=[tmp_path]) is not None -def test_local_namespace(make_file, tmp_path): +def test_local_namespace(make_file, chdir_tmp_path): # interface code namespaces should be isolated # all of these contract should be able to compile together codes = [ @@ -229,7 +233,7 @@ def test_local_namespace(make_file, tmp_path): for file_name in ("foo.vyi", "bar.vyi"): make_file(file_name, INTERFACE_CODE) - assert compile_files(paths, ["combined_json"], paths=[tmp_path]) + assert compile_files(paths, ["combined_json"]) is not None def test_compile_outside_root_path(tmp_path, make_file): @@ -237,7 +241,7 @@ def test_compile_outside_root_path(tmp_path, make_file): make_file("ifoo.vyi", INTERFACE_CODE) foo = make_file("foo.vy", CONTRACT_CODE.format(import_stmt="import ifoo as IFoo", alias="IFoo")) - assert compile_files([foo], ["combined_json"], paths=None) + assert compile_files([foo], ["combined_json"], paths=None) is not None def test_import_library(tmp_path, make_file): @@ -270,23 +274,153 @@ def mock_sys_path(path): sys.path.pop() -def test_import_sys_path(tmp_path_factory, make_file): +@pytest.fixture +def input_files(tmp_path_factory, make_file, chdir_tmp_path): library_source = """ @internal def foo() -> uint256: return block.number + 1 + """ + json_source = """ +[ + { + "stateMutability": "nonpayable", + "type": "function", + "name": "test_json", + "inputs": [ { "name": "", "type": "uint256" } ], + "outputs": [ { "name": "", "type": "uint256" } ] + } +] """ contract_source = """ import lib +import jsonabi @external def foo() -> uint256: return lib.foo() + +@external +def bar(x: uint256) -> uint256: + return extcall jsonabi(msg.sender).test_json(x) """ - tmpdir = tmp_path_factory.mktemp("test-sys-path") + tmpdir = tmp_path_factory.mktemp("fake-package") with open(tmpdir / "lib.vy", "w") as f: f.write(library_source) + with open(tmpdir / "jsonabi.json", "w") as f: + f.write(json_source) contract_file = make_file("contract.vy", contract_source) + + return (tmpdir, tmpdir / "lib.vy", tmpdir / "jsonabi.json", contract_file) + + +def test_import_sys_path(input_files): + tmpdir, _, _, contract_file = input_files with mock_sys_path(tmpdir): assert compile_files([contract_file], ["combined_json"]) is not None + + +def test_archive_output(input_files): + tmpdir, _, _, contract_file = input_files + search_paths = [".", tmpdir] + + s = compile_files([contract_file], ["archive"], paths=search_paths) + archive_bytes = s[contract_file]["archive"] + + archive_path = Path("foo.zip") + with archive_path.open("wb") as f: + f.write(archive_bytes) + + assert zipfile.is_zipfile(archive_path) + + # compare compiling the two input bundles + out = compile_files([contract_file], ["integrity", "bytecode"], paths=search_paths) + out2 = compile_files([archive_path], ["integrity", "bytecode"]) + assert out[contract_file] == out2[archive_path] + + +def test_archive_b64_output(input_files): + tmpdir, _, _, contract_file = input_files + search_paths = [".", tmpdir] + + out = compile_files( + [contract_file], ["archive_b64", "integrity", "bytecode"], paths=search_paths + ) + + archive_b64 = out[contract_file].pop("archive_b64") + + archive_path = Path("foo.zip.b64") + with archive_path.open("w") as f: + f.write(archive_b64) + + # compare compiling the two input bundles + out2 = compile_files([archive_path], ["integrity", "bytecode"]) + assert out[contract_file] == out2[archive_path] + + +def test_solc_json_output(input_files): + tmpdir, _, _, contract_file = input_files + search_paths = [".", tmpdir] + + out = compile_files([contract_file], ["solc_json"], paths=search_paths) + + json_input = out[contract_file]["solc_json"] + + # check that round-tripping solc_json thru standard json produces + # the same as compiling directly + json_out = compile_json(json_input)["contracts"]["contract.vy"] + json_out_bytecode = json_out["contract"]["evm"]["bytecode"]["object"] + + out2 = compile_files([contract_file], ["integrity", "bytecode"], paths=search_paths) + + assert out2[contract_file]["bytecode"] == json_out_bytecode + + +# maybe this belongs in tests/unit/compiler? +def test_integrity_sum(input_files): + tmpdir, library_file, jsonabi_file, contract_file = input_files + search_paths = [".", tmpdir] + + out = compile_files([contract_file], ["integrity"], paths=search_paths) + + with library_file.open() as f, contract_file.open() as g, jsonabi_file.open() as h: + library_contents = f.read() + contract_contents = g.read() + jsonabi_contents = h.read() + + contract_hash = sha256sum(contract_contents) + library_hash = sha256sum(library_contents) + jsonabi_hash = sha256sum(jsonabi_contents) + expected = sha256sum(contract_hash + sha256sum(library_hash) + jsonabi_hash) + assert out[contract_file]["integrity"] == expected + + +# does this belong in tests/unit/compiler? +def test_archive_search_path(tmp_path_factory, make_file, chdir_tmp_path): + lib1 = """ +x: uint256 + """ + lib2 = """ +y: uint256 + """ + dir1 = tmp_path_factory.mktemp("dir1") + dir2 = tmp_path_factory.mktemp("dir2") + make_file(dir1 / "lib.vy", lib1) + make_file(dir2 / "lib.vy", lib2) + + main = """ +import lib + """ + pwd = Path(".") + make_file(pwd / "main.vy", main) + for search_paths in ([pwd, dir1, dir2], [pwd, dir2, dir1]): + input_bundle = FilesystemInputBundle(search_paths) + file_input = input_bundle.load_file("main.vy") + + # construct CompilerData manually + compiler_data = CompilerData(file_input, input_bundle) + output_bundle = OutputBundle(compiler_data) + + used_dir = search_paths[-1].stem # either dir1 or dir2 + assert output_bundle.used_search_paths == [".", "0/" + used_dir] diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 82c332d185..f4c93c08bf 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -227,11 +227,6 @@ def test_different_outputs(input_bundle, input_json): assert foo["evm"]["methodIdentifiers"] == method_identifiers -def test_root_folder_not_exists(input_json): - with pytest.raises(FileNotFoundError): - compile_json(input_json, root_folder="/path/that/does/not/exist") - - def test_wrong_language(): with pytest.raises(JSONError): compile_json({"language": "Solidity"}) diff --git a/tests/unit/compiler/test_input_bundle.py b/tests/unit/compiler/test_input_bundle.py index 621b529722..74fd04f16e 100644 --- a/tests/unit/compiler/test_input_bundle.py +++ b/tests/unit/compiler/test_input_bundle.py @@ -73,13 +73,13 @@ def test_load_abi(make_file, input_bundle, tmp_path): file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, "foo.json", path, "some string") + assert file == ABIInput(0, "foo.json", path, contents, "some string") # suffix doesn't matter path = make_file("foo.txt", contents) file = input_bundle.load_file("foo.txt") assert isinstance(file, ABIInput) - assert file == ABIInput(1, "foo.txt", path, "some string") + assert file == ABIInput(1, "foo.txt", path, contents, "some string") # check that unique paths give unique source ids @@ -126,29 +126,31 @@ def test_source_id_json_input(make_file, input_bundle, tmp_path): file = input_bundle.load_file("foo.json") assert isinstance(file, ABIInput) - assert file == ABIInput(0, "foo.json", foopath, "some string") + assert file == ABIInput(0, "foo.json", foopath, contents, "some string") file2 = input_bundle.load_file("bar.json") assert isinstance(file2, ABIInput) - assert file2 == ABIInput(1, "bar.json", barpath, ["some list"]) + assert file2 == ABIInput(1, "bar.json", barpath, contents2, ["some list"]) file3 = input_bundle.load_file("foo.json") assert file3.source_id == 0 - assert file3 == ABIInput(0, "foo.json", foopath, "some string") + assert file3 == ABIInput(0, "foo.json", foopath, contents, "some string") # test source id is stable across different search paths with working_directory(tmp_path): with input_bundle.search_path(Path(".")): file4 = input_bundle.load_file("foo.json") assert file4.source_id == 0 - assert file4 == ABIInput(0, "foo.json", foopath, "some string") + assert file4 == ABIInput(0, "foo.json", foopath, contents, "some string") # test source id is stable even when requested filename is different with working_directory(tmp_path.parent): with input_bundle.search_path(Path(".")): file5 = input_bundle.load_file(Path(tmp_path.stem) / "foo.json") assert file5.source_id == 0 - assert file5 == ABIInput(0, Path(tmp_path.stem) / "foo.json", foopath, "some string") + assert file5 == ABIInput( + 0, Path(tmp_path.stem) / "foo.json", foopath, contents, "some string" + ) # test some pathological case where the file changes underneath @@ -238,7 +240,8 @@ def test_json_input_abi(): input_bundle = JSONInputBundle(files, [PurePath(".")]) file = input_bundle.load_file(foopath) - assert file == ABIInput(0, foopath, foopath, some_abi) + abi_contents = json.dumps({"abi": some_abi}) + assert file == ABIInput(0, foopath, foopath, abi_contents, some_abi) file = input_bundle.load_file(barpath) - assert file == ABIInput(1, barpath, barpath, some_abi) + assert file == ABIInput(1, barpath, barpath, some_abi_str, some_abi) diff --git a/tests/unit/compiler/test_pre_parser.py b/tests/unit/compiler/test_pre_parser.py index 128b6b16eb..f867937046 100644 --- a/tests/unit/compiler/test_pre_parser.py +++ b/tests/unit/compiler/test_pre_parser.py @@ -2,7 +2,7 @@ from vyper.compiler import compile_code from vyper.compiler.settings import OptimizationLevel, Settings -from vyper.exceptions import StructureException, SyntaxException +from vyper.exceptions import SyntaxException def test_semicolon_prohibited(get_contract): @@ -96,7 +96,7 @@ def test_evm_version_check(assert_compile_failed): assert compile_code(code, settings=Settings(evm_version="london")) is not None # should fail if compile options indicate different evm version # from source pragma - with pytest.raises(StructureException): + with pytest.raises(ValueError): compile_code(code, settings=Settings(evm_version="shanghai")) @@ -107,9 +107,9 @@ def test_optimization_mode_check(): assert compile_code(code, settings=Settings(optimize=None)) # should fail if compile options indicate different optimization mode # from source pragma - with pytest.raises(StructureException): + with pytest.raises(ValueError): compile_code(code, settings=Settings(optimize=OptimizationLevel.GAS)) - with pytest.raises(StructureException): + with pytest.raises(ValueError): compile_code(code, settings=Settings(optimize=OptimizationLevel.NONE)) @@ -119,7 +119,7 @@ def test_optimization_mode_check_none(): """ assert compile_code(code, settings=Settings(optimize=None)) # "none" conflicts with "gas" - with pytest.raises(StructureException): + with pytest.raises(ValueError): compile_code(code, settings=Settings(optimize=OptimizationLevel.GAS)) diff --git a/vyper/__init__.py b/vyper/__init__.py index 5bb6469757..5e36cbb69d 100644 --- a/vyper/__init__.py +++ b/vyper/__init__.py @@ -21,3 +21,6 @@ __version__ = _version(__name__) except PackageNotFoundError: from vyper.version import version as __version__ + +# pep440 version with commit hash +__long_version__ = f"{__version__}+commit.{__commit__}" diff --git a/vyper/cli/compile_archive.py b/vyper/cli/compile_archive.py new file mode 100644 index 0000000000..1b52343c1c --- /dev/null +++ b/vyper/cli/compile_archive.py @@ -0,0 +1,70 @@ +# not an entry point! +# utility functions to handle compiling from a "vyper archive" + +import base64 +import binascii +import io +import json +import zipfile +from pathlib import PurePath + +from vyper.compiler import compile_from_file_input +from vyper.compiler.input_bundle import FileInput, ZipInputBundle +from vyper.compiler.settings import Settings, merge_settings +from vyper.exceptions import BadArchive + + +class NotZipInput(Exception): + pass + + +def compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata): + with open(file_name, "rb") as f: + bcontents = f.read() + + try: + buf = io.BytesIO(bcontents) + archive = zipfile.ZipFile(buf, mode="r") + except zipfile.BadZipFile as e1: + try: + # `validate=False` - tools like base64 can generate newlines + # for readability. validate=False does the "correct" thing and + # simply ignores these + bcontents = base64.b64decode(bcontents, validate=False) + buf = io.BytesIO(bcontents) + archive = zipfile.ZipFile(buf, mode="r") + except (zipfile.BadZipFile, binascii.Error): + raise NotZipInput() from e1 + + fcontents = archive.read("MANIFEST/compilation_targets").decode("utf-8") + compilation_targets = fcontents.splitlines() + + if len(compilation_targets) != 1: + raise BadArchive("Multiple compilation targets not supported!") + + input_bundle = ZipInputBundle(archive) + + mainpath = PurePath(compilation_targets[0]) + file = input_bundle.load_file(mainpath) + assert isinstance(file, FileInput) # mypy hint + + settings = settings or Settings() + + archive_settings_txt = archive.read("MANIFEST/settings.json").decode("utf-8") + archive_settings = Settings.from_dict(json.loads(archive_settings_txt)) + + integrity = archive.read("MANIFEST/integrity").decode("utf-8").strip() + + settings = merge_settings( + settings, archive_settings, lhs_source="command line", rhs_source="archive settings" + ) + + # TODO: validate integrity sum (probably in CompilerData) + return compile_from_file_input( + file, + input_bundle=input_bundle, + output_formats=output_formats, + integrity_sum=integrity, + settings=settings, + no_bytecode_metadata=no_bytecode_metadata, + ) diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index d390e3bb8a..bb2cfa34b8 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse import json +import os import sys import warnings from pathlib import Path @@ -10,6 +11,7 @@ import vyper.codegen.ir_node as ir_node import vyper.evm.opcodes as evm from vyper.cli import vyper_json +from vyper.cli.compile_archive import NotZipInput, compile_from_zip from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle from vyper.compiler.settings import VYPER_TRACEBACK_LIMIT, OptimizationLevel, Settings from vyper.typing import ContractPath, OutputFormats @@ -39,6 +41,8 @@ ir_json - Intermediate representation in JSON format ir_runtime - Intermediate representation of runtime bytecode in list format asm - Output the EVM assembly of the deployable bytecode +archive - Output the build as an archive file +solc_json - Output the build in solc json format """ combined_json_outputs = [ @@ -64,6 +68,20 @@ def _cli_helper(f, output_formats, compiled): print(json.dumps(compiled), file=f) return + if output_formats == ("archive",): + for contract_data in compiled.values(): + assert list(contract_data.keys()) == ["archive"] + out = contract_data["archive"] + if f.isatty() and isinstance(out, bytes): + raise RuntimeError( + "won't write raw bytes to a tty! (if you want to base64" + " encode the archive, you can try `-f archive` in" + " conjunction with `--base64`)" + ) + else: + f.write(out) + return + for contract_data in compiled.values(): for data in contract_data.values(): if isinstance(data, (list, dict)): @@ -85,9 +103,7 @@ def _parse_args(argv): formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument("input_files", help="Vyper sourcecode to compile", nargs="+") - parser.add_argument( - "--version", action="version", version=f"{vyper.__version__}+commit.{vyper.__commit__}" - ) + parser.add_argument("--version", action="version", version=vyper.__long_version__) parser.add_argument( "--show-gas-estimates", help="Show gas estimates in abi and ir output mode.", @@ -108,6 +124,11 @@ def _parse_args(argv): dest="evm_version", ) parser.add_argument("--no-optimize", help="Do not optimize", action="store_true") + parser.add_argument( + "--base64", + help="Base64 encode the output (only valid in conjunction with `-f archive`", + action="store_true", + ) parser.add_argument( "-O", "--optimize", @@ -150,6 +171,9 @@ def _parse_args(argv): dest="experimental_codegen", ) parser.add_argument("--enable-decimals", help="Enable decimals", action="store_true") + parser.add_argument( + "--disable-sys-path", help="Disable the use of sys.path", action="store_true" + ) args = parser.parse_args(argv) @@ -170,6 +194,12 @@ def _parse_args(argv): output_formats = tuple(uniq(args.format.split(","))) + if args.base64 and output_formats != ("archive",): + raise ValueError("Cannot use `--base64` except with `-f archive`") + + if args.base64: + output_formats = ("archive_b64",) + if args.no_optimize and args.optimize: raise ValueError("Cannot use `--no-optimize` and `--optimize` at the same time!") @@ -195,22 +225,30 @@ def _parse_args(argv): if args.verbose: print(f"cli specified: `{settings}`", file=sys.stderr) + include_sys_path = not args.disable_sys_path + compiled = compile_files( args.input_files, output_formats, args.paths, + include_sys_path, args.show_gas_estimates, settings, args.storage_layout, args.no_bytecode_metadata, ) + mode = "w" + if output_formats == ("archive",): + mode = "wb" + if args.output_path: - with open(args.output_path, "w") as f: + with open(args.output_path, mode) as f: _cli_helper(f, output_formats, compiled) else: - f = sys.stdout - _cli_helper(f, output_formats, compiled) + # https://stackoverflow.com/a/54073813 + with os.fdopen(sys.stdout.fileno(), mode, closefd=False) as f: + _cli_helper(f, output_formats, compiled) def uniq(seq: Iterable[T]) -> Iterator[T]: @@ -232,7 +270,7 @@ def exc_handler(contract_path: ContractPath, exception: Exception) -> None: raise exception -def get_search_paths(paths: list[str] = None) -> list[Path]: +def get_search_paths(paths: list[str] = None, include_sys_path=True) -> list[Path]: # given `paths` input, get the full search path, including # the system search path. paths = paths or [] @@ -241,7 +279,9 @@ def get_search_paths(paths: list[str] = None) -> list[Path]: # note python sys path uses opposite resolution order from us # (first in list is highest precedence; we give highest precedence # to the last in the list) - search_paths = [Path(p) for p in reversed(sys.path)] + search_paths = [] + if include_sys_path: + search_paths = [Path(p) for p in reversed(sys.path)] if Path(".") not in search_paths: search_paths.append(Path(".")) @@ -257,12 +297,13 @@ def compile_files( input_files: list[str], output_formats: OutputFormats, paths: list[str] = None, + include_sys_path: bool = True, show_gas_estimates: bool = False, settings: Optional[Settings] = None, storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, ) -> dict: - search_paths = get_search_paths(paths) + search_paths = get_search_paths(paths, include_sys_path) input_bundle = FilesystemInputBundle(search_paths) show_version = False @@ -272,6 +313,11 @@ def compile_files( output_formats = combined_json_outputs show_version = True + # formats which can only be requested as a single output format + for c in ("solc_json", "archive"): + if c in output_formats and len(output_formats) > 1: + raise ValueError(f"If using {c} it must be the only output format requested") + translate_map = { "abi_python": "abi", "json": "abi", @@ -294,6 +340,23 @@ def compile_files( for file_name in input_files: file_path = Path(file_name) + + try: + # try to compile in zipfile mode if it's a zip file, falling back + # to regular mode if it's not. + # we allow this instead of requiring a different mode (like + # `--zip`) so that verifier pipelines do not need a different + # workflow for archive files and single-file contracts. + output = compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata) + ret[file_path] = output + continue + except NotZipInput: + pass + + # note compile_from_zip also reads the file contents, so this + # is slightly inefficient (and also maybe allows for some very + # rare, strange race conditions if the file changes in between + # the two reads). file = input_bundle.load_file(file_path) assert isinstance(file, FileInput) # mypy hint diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index 71b0c6a1b3..42b017fb94 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -58,12 +58,6 @@ def _parse_args(argv): default=None, dest="output_file", ) - parser.add_argument( - "-p", - help="Set a base import path. Vyper searches here if a file is not found in the JSON.", - default=None, - dest="root_folder", - ) parser.add_argument("--pretty-json", help="Output JSON in pretty format.", action="store_true") parser.add_argument( "--traceback", @@ -82,7 +76,7 @@ def _parse_args(argv): exc_handler = exc_handler_raises if args.traceback else exc_handler_to_dict output_json = json.dumps( - compile_json(input_json, exc_handler, args.root_folder, json_path), + compile_json(input_json, exc_handler, json_path), indent=2 if args.pretty_json else None, sort_keys=True, default=str, @@ -250,12 +244,14 @@ def get_output_formats(input_dict: dict) -> dict[PurePath, list[str]]: return output_formats +def get_search_paths(input_dict: dict) -> list[PurePath]: + ret = input_dict["settings"].get("search_paths", ".") + return [PurePath(p) for p in ret] + + def compile_from_input_dict( - input_dict: dict, exc_handler: Callable = exc_handler_raises, root_folder: Optional[str] = None + input_dict: dict, exc_handler: Callable = exc_handler_raises ) -> tuple[dict, dict]: - if root_folder is None: - root_folder = "." - if input_dict["language"] != "Vyper": raise JSONError(f"Invalid language '{input_dict['language']}' - Only Vyper is supported.") @@ -281,11 +277,14 @@ def compile_from_input_dict( no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True) + integrity = input_dict.get("integrity") + sources = get_inputs(input_dict) output_formats = get_output_formats(input_dict) compilation_targets = list(output_formats.keys()) + search_paths = get_search_paths(input_dict) - input_bundle = JSONInputBundle(sources, search_paths=[Path(root_folder)]) + input_bundle = JSONInputBundle(sources, search_paths=search_paths) res, warnings_dict = {}, {} warnings.simplefilter("always") @@ -299,6 +298,7 @@ def compile_from_input_dict( file, input_bundle=input_bundle, output_formats=output_formats[contract_path], + integrity_sum=integrity, settings=settings, no_bytecode_metadata=no_bytecode_metadata, ) @@ -381,7 +381,6 @@ def _raise_on_duplicate_keys(ordered_pairs: list[tuple[Hashable, Any]]) -> dict: def compile_json( input_json: dict | str, exc_handler: Callable = exc_handler_raises, - root_folder: Optional[str] = None, json_path: Optional[str] = None, ) -> dict: try: @@ -395,7 +394,7 @@ def compile_json( input_dict = input_json try: - compiler_data, warn_data = compile_from_input_dict(input_dict, exc_handler, root_folder) + compiler_data, warn_data = compile_from_input_dict(input_dict, exc_handler) if "errors" in compiler_data: return compiler_data except KeyError as exc: diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index 47e2054bd8..26439d2918 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -14,9 +14,12 @@ # requires annotated_vyper_module "annotated_ast_dict": output.build_annotated_ast_dict, "layout": output.build_layout_output, - # requires global_ctx "devdoc": output.build_devdoc, "userdoc": output.build_userdoc, + "archive": output.build_archive, + "archive_b64": output.build_archive_b64, + "integrity": output.build_integrity, + "solc_json": output.build_solc_json, # requires ir_node "external_interface": output.build_external_interface_output, "interface": output.build_interface_output, @@ -51,6 +54,7 @@ def compile_from_file_input( file_input: FileInput, input_bundle: InputBundle = None, settings: Settings = None, + integrity_sum: str = None, output_formats: Optional[OutputFormats] = None, storage_layout_override: Optional[StorageLayout] = None, no_bytecode_metadata: bool = False, @@ -106,10 +110,11 @@ def compile_from_file_input( compiler_data = CompilerData( file_input, input_bundle, - settings, - storage_layout_override, - show_gas_estimates, - no_bytecode_metadata, + settings=settings, + integrity_sum=integrity_sum, + storage_layout=storage_layout_override, + show_gas_estimates=show_gas_estimates, + no_bytecode_metadata=no_bytecode_metadata, ) ret = {} @@ -147,7 +152,7 @@ def compile_code( contract_path = Path(contract_path) file_input = FileInput( source_id=source_id, - source_code=source_code, + contents=source_code, path=contract_path, resolved_path=resolved_path or contract_path, # type: ignore ) diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index 4fe16a4bf1..51f1779119 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -1,10 +1,10 @@ import contextlib import json import os -from dataclasses import dataclass +from dataclasses import asdict, dataclass, field from functools import cached_property from pathlib import Path, PurePath -from typing import Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator, Optional from vyper.exceptions import JSONError from vyper.utils import sha256sum @@ -12,40 +12,47 @@ # a type to make mypy happy PathLike = Path | PurePath +if TYPE_CHECKING: + from zipfile import ZipFile -@dataclass + +@dataclass(frozen=True) class CompilerInput: # an input to the compiler, basically an abstraction for file contents + source_id: int path: PathLike # the path that was asked for # resolved_path is the real path that was resolved to. # mainly handy for debugging at this point resolved_path: PathLike + contents: str + @cached_property + def sha256sum(self): + return sha256sum(self.contents) -@dataclass -class FileInput(CompilerInput): - source_code: str +@dataclass(frozen=True) +class FileInput(CompilerInput): @cached_property - def sha256sum(self): - return sha256sum(self.source_code) + def source_code(self): + return self.contents -@dataclass +@dataclass(frozen=True, unsafe_hash=True) class ABIInput(CompilerInput): # some json input, which has already been parsed into a dict or list # this is needed because json inputs present json interfaces as json # objects, not as strings. this class helps us avoid round-tripping # back to a string to pretend it's a file. - abi: Any # something that json.load() returns + abi: Any = field(hash=False) # something that json.load() returns def try_parse_abi(file_input: FileInput) -> CompilerInput: try: s = json.loads(file_input.source_code) - return ABIInput(file_input.source_id, file_input.path, file_input.resolved_path, s) + return ABIInput(**asdict(file_input), abi=s) except (ValueError, TypeError): return file_input @@ -185,9 +192,10 @@ def _normpath(path): return path.__class__(os.path.normpath(path)) -# fake filesystem for JSON inputs. takes a base path, and `load_file()` -# "reads" the file from the JSON input. Note that this input bundle type -# never actually interacts with the filesystem -- it is guaranteed to be pure! +# fake filesystem for "standard JSON" (aka solc-style) inputs. takes search +# paths, and `load_file()` "reads" the file from the JSON input. Note that this +# input bundle type never actually interacts with the filesystem -- it is +# guaranteed to be pure! class JSONInputBundle(InputBundle): input_json: dict[PurePath, Any] @@ -216,7 +224,9 @@ def _load_from_path(self, resolved_path: PurePath, original_path: PurePath) -> C return FileInput(source_id, original_path, resolved_path, value["content"]) if "abi" in value: - return ABIInput(source_id, original_path, resolved_path, value["abi"]) + return ABIInput( + source_id, original_path, resolved_path, json.dumps(value), value["abi"] + ) # TODO: ethPM support # if isinstance(contents, dict) and "contractTypes" in contents: @@ -224,3 +234,32 @@ def _load_from_path(self, resolved_path: PurePath, original_path: PurePath) -> C # unreachable, based on how JSONInputBundle is constructed in # the codebase. raise JSONError(f"Unexpected type in file: '{resolved_path}'") # pragma: nocover + + +# input bundle for vyper archives. similar to JSONInputBundle, but takes +# a zipfile as input. +class ZipInputBundle(InputBundle): + def __init__(self, archive: "ZipFile"): + assert archive.testzip() is None + self.archive = archive + + sp_str = archive.read("MANIFEST/searchpaths").decode("utf-8") + search_paths = [PurePath(p) for p in sp_str.splitlines()] + + super().__init__(search_paths) + + def _normalize_path(self, path: PurePath) -> PurePath: + return _normpath(path) + + def _load_from_path(self, resolved_path: PurePath, original_path: PurePath) -> CompilerInput: + # zipfile.BadZipFile: File is not a zip file + + try: + value = self.archive.read(str(resolved_path)).decode("utf-8") + except KeyError: + # zipfile literally raises KeyError if the file is not there + raise _NotFound(resolved_path) + + source_id = super()._generate_source_id(resolved_path) + + return FileInput(source_id, original_path, resolved_path, value) diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 9b3bd147ef..c9b138ba64 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -1,15 +1,19 @@ +import base64 import warnings from collections import deque from pathlib import PurePath from vyper.ast import ast_to_dict from vyper.codegen.ir_node import IRnode +from vyper.compiler.output_bundle import SolcJSONWriter, VyperArchiveWriter from vyper.compiler.phases import CompilerData from vyper.compiler.utils import build_gas_estimates from vyper.evm import opcodes +from vyper.exceptions import VyperException from vyper.ir import compile_ir from vyper.semantics.types.function import FunctionVisibility, StateMutability from vyper.typing import StorageLayout +from vyper.utils import vyper_warn from vyper.warnings import ContractSizeLimitWarning @@ -37,6 +41,44 @@ def build_userdoc(compiler_data: CompilerData) -> dict: return compiler_data.natspec.userdoc +def build_solc_json(compiler_data: CompilerData) -> str: + # request bytecode to ensure the input compiles through all the + # compilation passes, emit warnings if there are any issues + # (this allows use cases like sending a bug reproduction while + # still alerting the user in the common case that they didn't + # mean to have a bug) + try: + _ = compiler_data.bytecode + except VyperException as e: + vyper_warn( + f"Exceptions encountered during code generation (but producing output anyway): {e}" + ) + writer = SolcJSONWriter(compiler_data) + writer.write() + return writer.output() + + +def build_archive(compiler_data: CompilerData) -> bytes: + # ditto + try: + _ = compiler_data.bytecode + except VyperException as e: + vyper_warn( + f"Exceptions encountered during code generation (but producing archive anyway): {e}" + ) + writer = VyperArchiveWriter(compiler_data) + writer.write() + return writer.output() + + +def build_archive_b64(compiler_data: CompilerData) -> str: + return base64.b64encode(build_archive(compiler_data)).decode("ascii") + + +def build_integrity(compiler_data: CompilerData) -> str: + return compiler_data.compilation_target._metadata["type"].integrity_sum + + def build_external_interface_output(compiler_data: CompilerData) -> str: interface = compiler_data.annotated_vyper_module._metadata["type"].interface stem = PurePath(compiler_data.contract_path).stem diff --git a/vyper/compiler/output_bundle.py b/vyper/compiler/output_bundle.py new file mode 100644 index 0000000000..13e74922a8 --- /dev/null +++ b/vyper/compiler/output_bundle.py @@ -0,0 +1,260 @@ +import importlib +import io +import json +import os +import zipfile +from dataclasses import dataclass +from functools import cached_property +from pathlib import PurePath +from typing import Optional + +from vyper.compiler.input_bundle import CompilerInput, _NotFound +from vyper.compiler.phases import CompilerData +from vyper.compiler.settings import Settings +from vyper.exceptions import CompilerPanic +from vyper.semantics.analysis.module import _is_builtin +from vyper.utils import get_long_version + +# data structures and routines for constructing "output bundles", +# basically reproducible builds of a vyper contract, with varying +# formats. note this is similar but not exactly analogous to +# `input_bundle.py` -- the output bundle defined here contains more +# information. + + +def _anonymize(p: str): + segments = [] + # replace ../../../a/b with 0/1/2/a/b + for i, s in enumerate(PurePath(p).parts): + if s == "..": + segments.append(str(i)) + else: + segments.append(s) + return str(PurePath(*segments)) + + +# data structure containing things that should be in an output bundle, +# which is some container containing the information required to +# reproduce a build +@dataclass +class OutputBundle: + def __init__(self, compiler_data: CompilerData): + self.compiler_data = compiler_data + + @cached_property + def compilation_target(self): + return self.compiler_data.compilation_target._metadata["type"] + + @cached_property + def _imports(self): + return self.compilation_target.reachable_imports + + @cached_property + def compiler_inputs(self) -> dict[str, CompilerInput]: + inputs: list[CompilerInput] = [ + t.compiler_input for t in self._imports if not _is_builtin(t.qualified_module_name) + ] + inputs.append(self.compiler_data.file_input) + + sources = {} + for c in inputs: + path = os.path.relpath(str(c.resolved_path)) + # note: there should be a 1:1 correspondence between + # resolved_path and source_id, but for clarity use resolved_path + # since it corresponds more directly to search path semantics. + sources[_anonymize(path)] = c + + return sources + + @cached_property + def compilation_target_path(self): + p = self.compiler_data.file_input.resolved_path + p = os.path.relpath(str(p)) + return _anonymize(p) + + @cached_property + def used_search_paths(self) -> list[str]: + # report back which search paths were "actually used" in this + # compilation run. this is useful mainly for aesthetic purposes, + # because we don't need to see `/usr/lib/python` in the search path + # if it is not used. + # that being said, we are overly conservative. that is, we might + # put search paths which are not actually used in the output. + + input_bundle = self.compiler_data.input_bundle + + search_paths = [] + for sp in input_bundle.search_paths: + try: + search_paths.append(input_bundle._normalize_path(sp)) + except _NotFound: + # invalid / nonexistent path + pass + + # preserve order of original search paths + tmp = {sp: 0 for sp in search_paths} + + for c in self.compiler_inputs.values(): + ok = False + # recover the search path that was used for this CompilerInput. + # note that it is not sufficient to thread the "search path that + # was used" into CompilerInput because search_paths are modified + # during compilation (so a search path which does not exist in + # the original search_paths set could be used for a given file). + for sp in reversed(search_paths): + if c.resolved_path.is_relative_to(sp): + # don't break here. if there are more than 1 search path + # which could possibly match, we add all them to the + # output. + tmp[sp] += 1 + ok = True + + # this shouldn't happen unless a file escapes its package, + # *or* if we have a bug + if not ok: + raise CompilerPanic(f"Invalid path: {c.resolved_path}") + + sps = [sp for sp, count in tmp.items() if count > 0] + assert len(sps) > 0 + + return [_anonymize(os.path.relpath(sp)) for sp in sps] + + +class OutputBundleWriter: + def __init__(self, compiler_data: CompilerData): + self.compiler_data = compiler_data + + @cached_property + def bundle(self): + return OutputBundle(self.compiler_data) + + def write_sources(self, sources: dict[str, CompilerInput]): + raise NotImplementedError(f"write_sources: {self.__class__}") + + def write_search_paths(self, search_paths: list[str]): + raise NotImplementedError(f"write_search_paths: {self.__class__}") + + def write_settings(self, settings: Optional[Settings]): + raise NotImplementedError(f"write_settings: {self.__class__}") + + def write_integrity(self, integrity_sum: str): + raise NotImplementedError(f"write_integrity: {self.__class__}") + + def write_compilation_target(self, targets: list[str]): + raise NotImplementedError(f"write_compilation_target: {self.__class__}") + + def write_compiler_version(self, version: str): + raise NotImplementedError(f"write_compiler_version: {self.__class__}") + + def output(self): + raise NotImplementedError(f"output: {self.__class__}") + + def write(self): + long_version = get_long_version() + self.write_version(f"v{long_version}") + self.write_compilation_target([self.bundle.compilation_target_path]) + self.write_search_paths(self.bundle.used_search_paths) + self.write_settings(self.compiler_data.original_settings) + self.write_integrity(self.bundle.compilation_target.integrity_sum) + self.write_sources(self.bundle.compiler_inputs) + + +class SolcJSONWriter(OutputBundleWriter): + def __init__(self, compiler_data): + super().__init__(compiler_data) + + self._output = {"language": "Vyper", "sources": {}, "settings": {"outputSelection": {}}} + + def write_sources(self, sources: dict[str, CompilerInput]): + out = {} + for path, c in sources.items(): + out[path] = {"content": c.contents, "sha256sum": c.sha256sum} + + self._output["sources"].update(out) + + def write_search_paths(self, search_paths: list[str]): + self._output["settings"]["search_paths"] = search_paths + + def write_settings(self, settings: Optional[Settings]): + if settings is not None: + s = settings.as_dict() + if "evm_version" in s: + s["evmVersion"] = s.pop("evm_version") + if "experimental_codegen" in s: + s["experimentalCodegen"] = s.pop("experimental_codegen") + + self._output["settings"].update(s) + + def write_integrity(self, integrity_sum: str): + self._output["integrity"] = integrity_sum + + def write_compilation_target(self, targets: list[str]): + for target in targets: + self._output["settings"]["outputSelection"][target] = "*" + + def write_version(self, version): + self._output["compiler_version"] = version + + def output(self): + return self._output + + +def _get_compression_method(): + # try to find a compression library, if none are available then + # fall back to ZIP_STORED + # (note: these should all be on all modern systems and in particular + # they should be in the build environment for our build artifacts, + # but write the graceful fallback anyway because hygiene). + try: + importlib.import_module("zlib") + return zipfile.ZIP_DEFLATED + except ImportError: + pass + + # fallback + return zipfile.ZIP_STORED + + +class VyperArchiveWriter(OutputBundleWriter): + def __init__(self, compiler_data: CompilerData): + super().__init__(compiler_data) + + self._buf = io.BytesIO() + method = _get_compression_method() + self.archive = zipfile.ZipFile(self._buf, mode="w", compression=method, compresslevel=9) + + def __del__(self): + # manually order the destruction of child objects. + # cf. https://bugs.python.org/issue37773 + # https://github.com/python/cpython/issues/81954 + del self.archive + del self._buf + + def write_sources(self, sources: dict[str, CompilerInput]): + for path, c in sources.items(): + self.archive.writestr(_anonymize(path), c.contents) + + def write_search_paths(self, search_paths: list[str]): + self.archive.writestr("MANIFEST/searchpaths", "\n".join(search_paths)) + + def write_settings(self, settings: Optional[Settings]): + if settings is not None: + self.archive.writestr("MANIFEST/settings.json", json.dumps(settings.as_dict())) + self.archive.writestr("MANIFEST/cli_settings.txt", settings.as_cli()) + else: + self.archive.writestr("MANIFEST/settings.json", json.dumps(None)) + self.archive.writestr("MANIFEST/cli_settings.txt", "") + + def write_integrity(self, integrity_sum: str): + self.archive.writestr("MANIFEST/integrity", integrity_sum) + + def write_compilation_target(self, targets: list[str]): + self.archive.writestr("MANIFEST/compilation_targets", "\n".join(targets)) + + def write_version(self, version: str): + self.archive.writestr("MANIFEST/compiler_version", version) + + def output(self): + assert self.archive.testzip() is None + self.archive.close() + return self._buf.getvalue() diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index e1ee91df72..0de8e87c1a 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -9,40 +9,18 @@ from vyper.codegen import module from vyper.codegen.ir_node import IRnode from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle, InputBundle -from vyper.compiler.settings import OptimizationLevel, Settings, anchor_settings -from vyper.exceptions import StructureException +from vyper.compiler.settings import OptimizationLevel, Settings, anchor_settings, merge_settings from vyper.ir import compile_ir, optimizer from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout -from vyper.utils import ERC5202_PREFIX +from vyper.utils import ERC5202_PREFIX, vyper_warn from vyper.venom import generate_assembly_experimental, generate_ir DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") -def _merge_one(lhs, rhs, helpstr): - if lhs is not None and rhs is not None and lhs != rhs: - raise StructureException( - f"compiler settings indicate {helpstr} {lhs}, " f"but source pragma indicates {rhs}." - ) - return lhs if rhs is None else rhs - - -# TODO: does this belong as a method under Settings? -def _merge_settings(cli: Settings, pragma: Settings): - ret = Settings() - ret.evm_version = _merge_one(cli.evm_version, pragma.evm_version, "evm version") - ret.optimize = _merge_one(cli.optimize, pragma.optimize, "optimize") - ret.experimental_codegen = _merge_one( - cli.experimental_codegen, pragma.experimental_codegen, "experimental codegen" - ) - ret.enable_decimals = _merge_one(cli.enable_decimals, pragma.enable_decimals, "enable-decimals") - - return ret - - class CompilerData: """ Object for fetching and storing compiler data for a Vyper contract. @@ -78,6 +56,7 @@ def __init__( file_input: FileInput | str, input_bundle: InputBundle = None, settings: Settings = None, + integrity_sum: str = None, storage_layout: StorageLayout = None, show_gas_estimates: bool = False, no_bytecode_metadata: bool = False, @@ -101,7 +80,7 @@ def __init__( if isinstance(file_input, str): file_input = FileInput( - source_code=file_input, + contents=file_input, source_id=-1, path=DEFAULT_CONTRACT_PATH, resolved_path=DEFAULT_CONTRACT_PATH, @@ -110,10 +89,9 @@ def __init__( self.storage_layout_override = storage_layout self.show_gas_estimates = show_gas_estimates self.no_bytecode_metadata = no_bytecode_metadata - self.settings = settings or Settings() + self.original_settings = settings self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")]) - - _ = self._generate_ast # force settings to be calculated + self.expected_integrity_sum = integrity_sum @cached_property def source_code(self): @@ -136,20 +114,32 @@ def _generate_ast(self): resolved_path=str(self.file_input.resolved_path), ) - self.settings = _merge_settings(self.settings, settings) - if self.settings.optimize is None: - self.settings.optimize = OptimizationLevel.default() + if self.original_settings: + og_settings = self.original_settings + settings = merge_settings(og_settings, settings) + assert self.original_settings == og_settings # be paranoid + else: + # merge with empty Settings(), doesn't do much but it does + # remove the compiler version + settings = merge_settings(Settings(), settings) - if self.settings.experimental_codegen is None: - self.settings.experimental_codegen = False + if settings.optimize is None: + settings.optimize = OptimizationLevel.default() - # note self.settings.compiler_version is erased here as it is - # not used after pre-parsing - return ast + if settings.experimental_codegen is None: + settings.experimental_codegen = False + + return settings, ast + + @cached_property + def settings(self): + settings, _ = self._generate_ast + return settings @cached_property def vyper_module(self): - return self._generate_ast + _, ast = self._generate_ast + return ast @cached_property def _annotate(self) -> tuple[natspec.NatspecOutput, vy_ast.Module]: @@ -172,6 +162,18 @@ def compilation_target(self): required for a compilation target. """ module_t = self.annotated_vyper_module._metadata["type"] + + expected = self.expected_integrity_sum + + if expected is not None and module_t.integrity_sum != expected: + # warn for now. strict/relaxed mode was considered but it costs + # interface and testing complexity to add another feature flag. + vyper_warn( + f"Mismatched integrity sum! Expected {expected}" + f" but got {module_t.integrity_sum}." + " (This likely indicates a corrupted archive)" + ) + validate_compilation_target(module_t) return self.annotated_vyper_module diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index 0e232472ea..7c20e03906 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -1,4 +1,5 @@ import contextlib +import dataclasses import os from dataclasses import dataclass from enum import Enum @@ -17,6 +18,7 @@ VYPER_TRACEBACK_LIMIT = None +# TODO: use StringEnum (requires refactoring vyper.utils to avoid import cycle) class OptimizationLevel(Enum): NONE = 1 GAS = 2 @@ -37,6 +39,9 @@ def from_string(cls, val): def default(cls): return cls.GAS + def __str__(self): + return self._name_.lower() + DEFAULT_ENABLE_DECIMALS = False @@ -50,12 +55,80 @@ class Settings: debug: Optional[bool] = None enable_decimals: Optional[bool] = None + def __post_init__(self): + # sanity check inputs + if self.optimize is not None: + assert isinstance(self.optimize, OptimizationLevel) + if self.experimental_codegen is not None: + assert isinstance(self.experimental_codegen, bool) + if self.debug is not None: + assert isinstance(self.debug, bool) + if self.enable_decimals is not None: + assert isinstance(self.enable_decimals, bool) + # CMC 2024-04-10 consider hiding the `enable_decimals` member altogether def get_enable_decimals(self) -> bool: if self.enable_decimals is None: return DEFAULT_ENABLE_DECIMALS return self.enable_decimals + def as_cli(self): + ret = [] + if self.optimize is not None: + ret.append(" --optimize " + str(self.optimize)) + if self.experimental_codegen is True: + ret.append(" --experimental-codegen") + if self.evm_version is not None: + ret.append(" --evm-version " + self.evm_version) + if self.debug is True: + ret.append(" --debug") + if self.enable_decimals is True: + ret.append(" --enable-decimals") + + return "".join(ret) + + def as_dict(self): + ret = dataclasses.asdict(self) + # compiler_version is not a compiler input, it can only come from + # source code pragma. + ret.pop("compiler_version", None) + ret = {k: v for (k, v) in ret.items() if v is not None} + if "optimize" in ret: + ret["optimize"] = str(ret["optimize"]) + return ret + + @classmethod + def from_dict(cls, data): + data = data.copy() + if "optimize" in data: + data["optimize"] = OptimizationLevel.from_string(data["optimize"]) + return cls(**data) + + +def merge_settings( + one: Settings, two: Settings, lhs_source="compiler settings", rhs_source="source pragma" +) -> Settings: + def _merge_one(lhs, rhs, helpstr): + if lhs is not None and rhs is not None and lhs != rhs: + # aesthetics, conjugate the verbs per english rules + s1 = "" if lhs_source.endswith("s") else "s" + s2 = "" if rhs_source.endswith("s") else "s" + raise ValueError( + f"settings conflict!\n\n {lhs_source}: {one}\n {rhs_source}: {two}\n\n" + f"({lhs_source} indicate{s1} {helpstr} {lhs}, but {rhs_source} indicate{s2} {rhs}.)" + ) + return lhs if rhs is None else rhs + + ret = Settings() + ret.evm_version = _merge_one(one.evm_version, two.evm_version, "evm version") + ret.optimize = _merge_one(one.optimize, two.optimize, "optimize") + ret.experimental_codegen = _merge_one( + one.experimental_codegen, two.experimental_codegen, "experimental codegen" + ) + ret.enable_decimals = _merge_one(one.enable_decimals, two.enable_decimals, "enable-decimals") + + return ret + # CMC 2024-04-10 do we need it to be Optional? _settings = None diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 183dd63b76..3c0444b1ca 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -350,6 +350,10 @@ class ParserException(Exception): """Contract source cannot be parsed.""" +class BadArchive(Exception): + """Bad archive""" + + class UnimplementedException(VyperException): """Some feature is known to be not implemented""" diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 3a1c912392..718581c20c 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union from vyper import ast as vy_ast -from vyper.compiler.input_bundle import CompilerInput, FileInput +from vyper.compiler.input_bundle import CompilerInput from vyper.exceptions import CompilerPanic, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType @@ -119,7 +119,7 @@ def __hash__(self): return hash(id(self.module_t)) -@dataclass +@dataclass(frozen=True) class ImportInfo(AnalysisResult): typ: Union[ModuleInfo, "InterfaceT"] alias: str # the name in the namespace @@ -133,9 +133,7 @@ def to_dict(self): ret["source_id"] = self.compiler_input.source_id ret["path"] = str(self.compiler_input.path) ret["resolved_path"] = str(self.compiler_input.resolved_path) - - if isinstance(self.compiler_input, FileInput): - ret["file_sha256sum"] = self.compiler_input.sha256sum + ret["file_sha256sum"] = self.compiler_input.sha256sum return ret diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index dd7546732a..06469ccef2 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -16,6 +16,7 @@ from vyper.exceptions import ( BorrowException, CallViolation, + CompilerPanic, DuplicateImport, EvmVersionException, ExceptionList, @@ -907,6 +908,7 @@ def _import_to_path(level: int, module_str: str) -> PurePath: BUILTIN_PREFIXES = ["ethereum.ercs"] +# TODO: could move this to analysis/common.py or something def _is_builtin(module_str): return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) @@ -915,8 +917,8 @@ def _is_builtin(module_str): def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, InterfaceT]: - if not _is_builtin(module_str): - raise ModuleNotFound(module_str) + if not _is_builtin(module_str): # pragma: nocover + raise CompilerPanic("unreachable!") builtins_path = vyper.builtins.interfaces.__path__[0] # hygiene: convert to relpath to avoid leaking user directory info diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index cac9b63be5..b3e3f2ef2b 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -22,10 +22,10 @@ from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType -from vyper.utils import OrderedSet +from vyper.utils import OrderedSet, sha256sum if TYPE_CHECKING: - from vyper.semantics.analysis.base import ModuleInfo + from vyper.semantics.analysis.base import ImportInfo, ModuleInfo class InterfaceT(_UserType): @@ -412,6 +412,39 @@ def imported_modules(self) -> dict[str, "ModuleInfo"]: ret[info.alias] = module_info return ret + @cached_property + def reachable_imports(self) -> list["ImportInfo"]: + """ + Return (recursively) reachable imports from this module as a list in + depth-first (descendants-first) order. + """ + ret = [] + for s in self.import_stmts: + info = s._metadata["import_info"] + + # NOTE: this needs to be redone if interfaces can import other interfaces + if not isinstance(info.typ, InterfaceT): + ret.extend(info.typ.typ.reachable_imports) + + ret.append(info) + + return ret + + @cached_property + def integrity_sum(self) -> str: + acc = [sha256sum(self._module.full_source_code)] + for s in self.import_stmts: + info = s._metadata["import_info"] + + if isinstance(info.typ, InterfaceT): + # NOTE: this needs to be redone if interfaces can import other interfaces + acc.append(info.compiler_input.sha256sum) + else: + assert isinstance(info.typ.typ, ModuleT) + acc.append(info.typ.typ.integrity_sum) + + return sha256sum("".join(acc)) + def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: for s in self.imported_modules.values(): if s.module_t == needle: diff --git a/vyper/utils.py b/vyper/utils.py index 600f5552ab..a1fed4087c 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -191,6 +191,12 @@ def sha256sum(s: str) -> str: return hashlib.sha256(s.encode("utf-8")).digest().hex() +def get_long_version(): + from vyper import __long_version__ + + return __long_version__ + + # Converts four bytes to an integer def fourbytes_to_int(inp): return (inp[0] << 24) + (inp[1] << 16) + (inp[2] << 8) + inp[3] From 93147be821189a42a4157c6aca7b54810fbee519 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 8 May 2024 10:42:20 -0400 Subject: [PATCH 3/3] feat[venom]: optimize `get_basic_block()` (#4002) `get_basic_block()` is a hotspot in venom (up to 35% of total compilation time!). this optimizes `get_basic_block()`, on a large contract near the 24kb limit this reduces time spent in venom from 3s to 1s (total time from 6s to 4s). note on the same contract, time spent in the IRnode optimizer pipeline is 2s - so time in venom is now smaller than time in legacy optimizer(!) notes: - refactor to use dict for basic_blocks - clean up basic blocks API hide basic blocks behind `get_basic_blocks()` iterator and `num_basic_blocks`. --- vyper/venom/analysis/cfg.py | 6 +- vyper/venom/analysis/dfg.py | 2 +- vyper/venom/analysis/dominators.py | 2 +- vyper/venom/analysis/dup_requirements.py | 2 +- vyper/venom/analysis/liveness.py | 4 +- vyper/venom/function.py | 97 ++++++++----------- vyper/venom/ir_node_to_venom.py | 7 +- vyper/venom/passes/dft.py | 4 +- vyper/venom/passes/normalization.py | 4 +- vyper/venom/passes/remove_unused_variables.py | 2 +- vyper/venom/passes/sccp/sccp.py | 2 +- vyper/venom/passes/simplify_cfg.py | 29 +++--- 12 files changed, 72 insertions(+), 89 deletions(-) diff --git a/vyper/venom/analysis/cfg.py b/vyper/venom/analysis/cfg.py index 2a521ab131..6bd7e538e9 100644 --- a/vyper/venom/analysis/cfg.py +++ b/vyper/venom/analysis/cfg.py @@ -10,12 +10,12 @@ class CFGAnalysis(IRAnalysis): def analyze(self) -> None: fn = self.function - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): bb.cfg_in = OrderedSet() bb.cfg_out = OrderedSet() bb.out_vars = OrderedSet() - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] assert ( @@ -29,7 +29,7 @@ def analyze(self) -> None: fn.get_basic_block(op.value).add_cfg_in(bb) # Fill in the "out" set for each basic block - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): for in_bb in bb.cfg_in: in_bb.add_cfg_out(bb) diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index 8b113e74bc..dc7076d5de 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -33,7 +33,7 @@ def analyze(self): # %16 = iszero %15 # dfg_outputs of %15 is (%15 = add %13 %14) # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): for inst in bb.instructions: operands = inst.get_inputs() res = inst.get_outputs() diff --git a/vyper/venom/analysis/dominators.py b/vyper/venom/analysis/dominators.py index c0b149d880..129d1d0f22 100644 --- a/vyper/venom/analysis/dominators.py +++ b/vyper/venom/analysis/dominators.py @@ -153,7 +153,7 @@ def as_graph(self) -> str: Generate a graphviz representation of the dominator tree. """ lines = ["digraph dominator_tree {"] - for bb in self.fn.basic_blocks: + for bb in self.fn.get_basic_blocks(): if bb == self.entry_block: continue idom = self.immediate_dominator(bb) diff --git a/vyper/venom/analysis/dup_requirements.py b/vyper/venom/analysis/dup_requirements.py index 015c7c5871..3452bc2e0f 100644 --- a/vyper/venom/analysis/dup_requirements.py +++ b/vyper/venom/analysis/dup_requirements.py @@ -4,7 +4,7 @@ class DupRequirementsAnalysis(IRAnalysis): def analyze(self): - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): last_liveness = bb.out_vars for inst in reversed(bb.instructions): inst.dup_requirements = OrderedSet() diff --git a/vyper/venom/analysis/liveness.py b/vyper/venom/analysis/liveness.py index 95853e57aa..5e78aa4ff3 100644 --- a/vyper/venom/analysis/liveness.py +++ b/vyper/venom/analysis/liveness.py @@ -15,7 +15,7 @@ def analyze(self): self._reset_liveness() while True: changed = False - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): changed |= self._calculate_out_vars(bb) changed |= self._calculate_liveness(bb) @@ -23,7 +23,7 @@ def analyze(self): break def _reset_liveness(self) -> None: - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): bb.out_vars = OrderedSet() for inst in bb.instructions: inst.liveness = OrderedSet() diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 556be28246..eace17af0d 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -13,57 +13,42 @@ class IRFunction: name: IRLabel # symbol name ctx: "IRContext" # type: ignore # noqa: F821 args: list - basic_blocks: list[IRBasicBlock] last_label: int last_variable: int + _basic_block_dict: dict[str, IRBasicBlock] # Used during code generation _ast_source_stack: list[IRnode] _error_msg_stack: list[str] - _bb_index: dict[str, int] def __init__(self, name: IRLabel, ctx: "IRContext" = None) -> None: # type: ignore # noqa: F821 self.ctx = ctx self.name = name self.args = [] - self.basic_blocks = [] + self._basic_block_dict = {} self.last_variable = 0 self._ast_source_stack = [] self._error_msg_stack = [] - self._bb_index = {} self.append_basic_block(IRBasicBlock(name, self)) @property def entry(self) -> IRBasicBlock: - return self.basic_blocks[0] + return next(self.get_basic_blocks()) - def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + def append_basic_block(self, bb: IRBasicBlock): """ Append basic block to function. """ - assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" - self.basic_blocks.append(bb) - - return self.basic_blocks[-1] - - def _get_basicblock_index(self, label: str): - # perf: keep an "index" of labels to block indices to - # perform fast lookup. - # TODO: maybe better just to throw basic blocks in an ordered - # dict of some kind. - ix = self._bb_index.get(label, -1) - if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: - return ix - # do a reindex - self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) - # sanity check - no duplicate labels - assert len(self._bb_index) == len( - self.basic_blocks - ), f"Duplicate labels in function '{self.name}' {self._bb_index} {self.basic_blocks}" - return self._bb_index[label] + assert isinstance(bb, IRBasicBlock), bb + assert bb.label.name not in self._basic_block_dict + self._basic_block_dict[bb.label.name] = bb + + def remove_basic_block(self, bb: IRBasicBlock): + assert isinstance(bb, IRBasicBlock), bb + del self._basic_block_dict[bb.label.name] def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ @@ -71,33 +56,31 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: If label is None, return the last basic block. """ if label is None: - return self.basic_blocks[-1] - ix = self._get_basicblock_index(label) - return self.basic_blocks[ix] + return next(reversed(self._basic_block_dict.values())) + + return self._basic_block_dict[label] + + def clear_basic_blocks(self): + self._basic_block_dict.clear() - def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + def get_basic_blocks(self) -> Iterator[IRBasicBlock]: """ - Get basic block after label. + Get an iterator over this function's basic blocks """ - ix = self._get_basicblock_index(label.value) - if 0 <= ix < len(self.basic_blocks) - 1: - return self.basic_blocks[ix + 1] - raise AssertionError(f"Basic block after '{label}' not found") + return iter(self._basic_block_dict.values()) + + @property + def num_basic_blocks(self) -> int: + return len(self._basic_block_dict) def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: """ Get basic blocks that are terminal. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): if bb.is_terminal: yield bb - def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: - """ - Get basic blocks that point to the given basic block - """ - return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_variable(self) -> IRVariable: self.last_variable += 1 return IRVariable(f"%{self.last_variable}") @@ -109,15 +92,14 @@ def remove_unreachable_blocks(self) -> int: self._compute_reachability() removed = [] - new_basic_blocks = [] # Remove unreachable basic blocks - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): if not bb.is_reachable: removed.append(bb) - else: - new_basic_blocks.append(bb) - self.basic_blocks = new_basic_blocks + + for bb in removed: + self.remove_basic_block(bb) # Remove phi instructions that reference removed basic blocks for bb in removed: @@ -142,7 +124,7 @@ def _compute_reachability(self) -> None: """ Compute reachability of basic blocks. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): bb.reachable = OrderedSet() bb.is_reachable = False @@ -172,7 +154,7 @@ def normalized(self) -> bool: Having a normalized CFG makes calculation of stack layout easier when emitting assembly. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): # Ignore if there are no multiple predecessors if len(bb.cfg_in) <= 1: continue @@ -211,22 +193,23 @@ def chain_basic_blocks(self) -> None: Otherwise, append a stop instruction. This is necessary for the IR to be valid, and is done after the IR is generated. """ - for i, bb in enumerate(self.basic_blocks): + bbs = list(self.get_basic_blocks()) + for i, bb in enumerate(bbs): if not bb.is_terminated: - if len(self.basic_blocks) - 1 > i: + if i < len(bbs) - 1: # TODO: revisit this. When contructor calls internal functions they # are linked to the last ctor block. Should separate them before this # so we don't have to handle this here - if self.basic_blocks[i + 1].label.value.startswith("internal"): + if bbs[i + 1].label.value.startswith("internal"): bb.append_instruction("stop") else: - bb.append_instruction("jmp", self.basic_blocks[i + 1].label) + bb.append_instruction("jmp", bbs[i + 1].label) else: bb.append_instruction("exit") def copy(self): new = IRFunction(self.name) - new.basic_blocks = self.basic_blocks.copy() + new._basic_block_dict = self._basic_block_dict.copy() new.last_label = self.last_label new.last_variable = self.last_variable return new @@ -246,11 +229,11 @@ def _make_label(bb): ret = "digraph G {\n" - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): for out_bb in bb.cfg_out: ret += f' "{bb.label.value}" -> "{out_bb.label.value}"\n' - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): ret += f' "{bb.label.value}" [shape=plaintext, ' ret += f'label={_make_label(bb)}, fontname="Courier" fontsize="8"]\n' @@ -259,6 +242,6 @@ def _make_label(bb): def __repr__(self) -> str: str = f"IRFunction: {self.name}\n" - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): str += f"{bb}\n" return str.strip() diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index b4465e9f7b..61b3c081ff 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -135,10 +135,9 @@ def _append_jmp(fn: IRFunction, label: IRLabel) -> None: bb.append_instruction("jmp", label) -def _new_block(fn: IRFunction) -> IRBasicBlock: +def _new_block(fn: IRFunction) -> None: bb = IRBasicBlock(fn.ctx.get_next_label(), fn) - bb = fn.append_basic_block(bb) - return bb + fn.append_basic_block(bb) def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0): @@ -328,7 +327,7 @@ def _convert_ir_bb(fn, ir, symbols): # exit bb exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn) - exit_bb = fn.append_basic_block(exit_bb) + fn.append_basic_block(exit_bb) if_ret = fn.get_next_variable() if then_ret_val is not None and else_ret_val is not None: diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index e4e27ed813..06366e4336 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -74,8 +74,8 @@ def run_pass(self) -> None: self.fence_id = 0 self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() - basic_blocks = self.function.basic_blocks + basic_blocks = list(self.function.get_basic_blocks()) - self.function.basic_blocks = [] + self.function.clear_basic_blocks() for bb in basic_blocks: self._process_basic_block(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 83c565b1be..cf44c3cf89 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -58,7 +58,7 @@ def _run_pass(self) -> int: self.analyses_cache.request_analysis(CFGAnalysis) # Split blocks that need splitting - for bb in fn.basic_blocks: + for bb in list(fn.get_basic_blocks()): if len(bb.cfg_in) > 1: self._split_basic_block(bb) @@ -71,7 +71,7 @@ def _run_pass(self) -> int: def run_pass(self): fn = self.function - for _ in range(len(fn.basic_blocks) * 2): + for _ in range(fn.num_basic_blocks * 2): if self._run_pass() == 0: break else: diff --git a/vyper/venom/passes/remove_unused_variables.py b/vyper/venom/passes/remove_unused_variables.py index b7fb3abbf0..a4cd737e98 100644 --- a/vyper/venom/passes/remove_unused_variables.py +++ b/vyper/venom/passes/remove_unused_variables.py @@ -9,7 +9,7 @@ def run_pass(self): self.analyses_cache.request_analysis(LivenessAnalysis) - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): for i, inst in enumerate(bb.instructions[:-1]): if inst.volatile: continue diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 7f3fc7e03e..577030dea6 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -87,7 +87,7 @@ def _calculate_sccp(self, entry: IRBasicBlock): and the work list. The `_propagate_constants()` method is responsible for updating the IR with the constant values. """ - self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.basic_blocks} + self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.get_basic_blocks()} dummy = IRBasicBlock(IRLabel("__dummy_start"), self.fn) self.work_list.append(FlowWorkItem(dummy, entry)) diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index bb5233eba0..08582fee96 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -30,7 +30,7 @@ def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): break inst.operands[inst.operands.index(b.label)] = a.label - self.function.basic_blocks.remove(b) + self.function.remove_basic_block(b) def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb = b.cfg_out.first() @@ -44,7 +44,7 @@ def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) - self.function.basic_blocks.remove(b) + self.function.remove_basic_block(b) def _collapse_chained_blocks_r(self, bb: IRBasicBlock): """ @@ -87,31 +87,32 @@ def _optimize_empty_basicblocks(self) -> int: Remove empty basic blocks. """ fn = self.function - count = 0 - i = 0 - while i < len(fn.basic_blocks): - bb = fn.basic_blocks[i] + worklist = list(fn.get_basic_blocks()) + i = count = 0 + while i < len(worklist): + bb = worklist[i] i += 1 + if len(bb.instructions) > 0: continue + next_bb = worklist[i] + replaced_label = bb.label - replacement_label = fn.basic_blocks[i].label if i < len(fn.basic_blocks) else None - if replacement_label is None: - continue + replacement_label = next_bb.label # Try to preserve symbol labels if replaced_label.is_symbol: replaced_label, replacement_label = replacement_label, replaced_label - fn.basic_blocks[i].label = replacement_label + next_bb.label = replacement_label - for bb2 in fn.basic_blocks: + for bb2 in fn.get_basic_blocks(): for inst in bb2.instructions: for op in inst.operands: if isinstance(op, IRLabel) and op.value == replaced_label.value: op.value = replacement_label.value - fn.basic_blocks.remove(bb) + fn.remove_basic_block(bb) i -= 1 count += 1 @@ -121,7 +122,7 @@ def run_pass(self): fn = self.function entry = fn.entry - for _ in range(len(fn.basic_blocks)): + for _ in range(fn.num_basic_blocks): changes = self._optimize_empty_basicblocks() changes += fn.remove_unreachable_blocks() if changes == 0: @@ -131,7 +132,7 @@ def run_pass(self): self.analyses_cache.force_analysis(CFGAnalysis) - for _ in range(len(fn.basic_blocks)): # essentially `while True` + for _ in range(fn.num_basic_blocks): # essentially `while True` self._collapse_chained_blocks(entry) if fn.remove_unreachable_blocks() == 0: break