Skip to content

Commit

Permalink
refactor(mempool_test_utils): rename contract (#1115)
Browse files Browse the repository at this point in the history
Rename:

- FeatureAccount -> Contract, this was an incorrect name, since we use
this abstraction for both non-accounts and accounts. This is basically
an enhancement of FeatureContract, allowing to construct it from a
deploy account tx, as well as through the literal constructor.

- `account: Contract` -> `contract: Contract`.

Also added docs to clarify this issue.

Note: we still unfortunately expose `ThinStateDiffBuilder#fund()` for
non-account contracts, which is only protected via an assert, but
AFAICS mending this requires touchups inside FeatureContract itself,
which is out of scope here.

Co-Authored-By: Gilad Chase <gilad@starkware.com>
  • Loading branch information
giladchase and Gilad Chase authored Oct 1, 2024
1 parent c09e262 commit 618fafe
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 48 deletions.
27 changes: 14 additions & 13 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl MultiAccountTransactionGenerator {
self.register_account(account_contract);
}

pub fn accounts(&self) -> Vec<FeatureAccount> {
pub fn accounts(&self) -> Vec<Contract> {
self.account_tx_generators.iter().map(|tx_gen| &tx_gen.account).copied().collect()
}
}
Expand All @@ -296,7 +296,7 @@ impl MultiAccountTransactionGenerator {
/// TODO: add more transaction generation methods as needed.
#[derive(Debug)]
pub struct AccountTransactionGenerator {
account: FeatureAccount,
account: Contract,
nonce_manager: SharedNonceManager,
}

Expand Down Expand Up @@ -377,10 +377,8 @@ impl AccountTransactionGenerator {
let default_deploy_account_tx =
generate_deploy_account_with_salt(&account, contract_address_salt);

let mut account_tx_generator = Self {
account: FeatureAccount::new(account, &default_deploy_account_tx),
nonce_manager,
};
let mut account_tx_generator =
Self { account: Contract::new(account, &default_deploy_account_tx), nonce_manager };
// Bump the account nonce after transaction creation.
account_tx_generator.next_nonce();

Expand All @@ -394,22 +392,22 @@ impl AccountTransactionGenerator {
// not related to an actual deploy account transaction, which is the way real account addresses are
// calculated.
#[derive(Clone, Copy, Debug)]
pub struct FeatureAccount {
pub account: FeatureContract,
pub struct Contract {
pub contract: FeatureContract,
pub sender_address: ContractAddress,
}

impl FeatureAccount {
impl Contract {
pub fn class_hash(&self) -> ClassHash {
self.account.get_class_hash()
self.contract.get_class_hash()
}

pub fn cairo_version(&self) -> CairoVersion {
self.account.cairo_version()
self.contract.cairo_version()
}

pub fn raw_class(&self) -> String {
self.account.get_raw_class()
self.contract.get_raw_class()
}

fn new(account: FeatureContract, deploy_account_tx: &RpcTransaction) -> Self {
Expand All @@ -426,7 +424,10 @@ impl FeatureAccount {
"{account:?} is not an account"
);

Self { account, sender_address: deploy_account_tx.calculate_sender_address().unwrap() }
Self {
contract: account,
sender_address: deploy_account_tx.calculate_sender_address().unwrap(),
}
}
}

Expand Down
75 changes: 40 additions & 35 deletions crates/tests-integration/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use blockifier::test_utils::{
use blockifier::transaction::objects::FeeType;
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
use indexmap::IndexMap;
use mempool_test_utils::starknet_api_test_utils::FeatureAccount;
use mempool_test_utils::starknet_api_test_utils::Contract;
use papyrus_common::pending_classes::PendingClasses;
use papyrus_rpc::{run_server, RpcConfig};
use papyrus_storage::body::BodyStorageWriter;
Expand Down Expand Up @@ -61,23 +61,23 @@ type ContractClassesMap =
/// Creates a papyrus storage reader and spawns a papyrus rpc server for it.
/// Returns the address of the rpc server.
/// A variable number of identical accounts and test contracts are initialized and funded.
pub async fn spawn_test_rpc_state_reader(test_defined_accounts: Vec<FeatureAccount>) -> SocketAddr {
pub async fn spawn_test_rpc_state_reader(test_defined_accounts: Vec<Contract>) -> SocketAddr {
let block_context = BlockContext::create_for_testing();

let into_dummy_feature_account = |account: FeatureContract| FeatureAccount {
account,
sender_address: account.get_instance_address(0),
let into_contract = |contract: FeatureContract| Contract {
contract,
sender_address: contract.get_instance_address(0),
};
let default_test_contracts = [
FeatureContract::TestContract(CairoVersion::Cairo0),
FeatureContract::TestContract(CairoVersion::Cairo1),
]
.into_iter()
.map(into_dummy_feature_account)
.map(into_contract)
.collect();

let erc20_contract = FeatureContract::ERC20(CairoVersion::Cairo0);
let erc20_contract = into_dummy_feature_account(erc20_contract);
let erc20_contract = into_contract(erc20_contract);

let storage_reader = initialize_papyrus_test_state(
block_context.chain_info(),
Expand All @@ -90,9 +90,9 @@ pub async fn spawn_test_rpc_state_reader(test_defined_accounts: Vec<FeatureAccou

fn initialize_papyrus_test_state(
chain_info: &ChainInfo,
test_defined_accounts: Vec<FeatureAccount>,
default_test_contracts: Vec<FeatureAccount>,
erc20_contract: FeatureAccount,
test_defined_accounts: Vec<Contract>,
default_test_contracts: Vec<Contract>,
erc20_contract: Contract,
) -> StorageReader {
let state_diff = prepare_state_diff(
chain_info,
Expand All @@ -111,18 +111,18 @@ fn initialize_papyrus_test_state(

fn prepare_state_diff(
chain_info: &ChainInfo,
test_defined_accounts: &[FeatureAccount],
default_test_contracts: &[FeatureAccount],
erc20_contract: &FeatureAccount,
test_defined_accounts: &[Contract],
default_test_contracts: &[Contract],
erc20_contract: &Contract,
) -> ThinStateDiff {
let mut state_diff_builder = ThinStateDiffBuilder::new(chain_info);

// Setup the common test contracts that are used by default in all test invokes.
// TODO(batcher): this does nothing until we actually start excuting stuff in the batcher.
state_diff_builder.set_accounts(default_test_contracts).declare().deploy();
state_diff_builder.set_contracts(default_test_contracts).declare().deploy();

// Declare and deploy and the ERC20 contract, so that transfers from it can be made.
state_diff_builder.set_accounts(std::slice::from_ref(erc20_contract)).declare().deploy();
state_diff_builder.set_contracts(std::slice::from_ref(erc20_contract)).declare().deploy();

// TODO(deploy_account_support): once we have batcher with execution, replace with:
// ```
Expand All @@ -135,22 +135,22 @@ fn prepare_state_diff(
}

fn prepare_compiled_contract_classes(
contract_classes_to_retrieve: impl Iterator<Item = FeatureAccount>,
contract_classes_to_retrieve: impl Iterator<Item = Contract>,
) -> ContractClassesMap {
let mut cairo0_contract_classes = Vec::new();
let mut cairo1_contract_classes = Vec::new();
for account in contract_classes_to_retrieve {
match account.cairo_version() {
for contract in contract_classes_to_retrieve {
match contract.cairo_version() {
CairoVersion::Cairo0 => {
cairo0_contract_classes.push((
account.class_hash(),
serde_json::from_str(&account.raw_class()).unwrap(),
contract.class_hash(),
serde_json::from_str(&contract.raw_class()).unwrap(),
));
}
CairoVersion::Cairo1 => {
cairo1_contract_classes.push((
account.class_hash(),
serde_json::from_str(&account.raw_class()).unwrap(),
contract.class_hash(),
serde_json::from_str(&contract.raw_class()).unwrap(),
));
}
}
Expand Down Expand Up @@ -235,9 +235,11 @@ async fn run_papyrus_rpc_server(storage_reader: StorageReader) -> SocketAddr {
addr
}

/// Constructs a thin state diff from lists of contracts, where each contract can be declared,
/// deployed, and in case it is an account, funded.
#[derive(Default)]
struct ThinStateDiffBuilder<'a> {
accounts: &'a [FeatureAccount],
contracts: &'a [Contract],
deprecated_declared_classes: Vec<ClassHash>,
declared_classes: IndexMap<ClassHash, starknet_api::core::CompiledClassHash>,
deployed_contracts: IndexMap<ContractAddress, ClassHash>,
Expand Down Expand Up @@ -266,34 +268,37 @@ impl<'a> ThinStateDiffBuilder<'a> {
}
}

fn set_accounts(&mut self, accounts: &'a [FeatureAccount]) -> &mut Self {
self.accounts = accounts;
fn set_contracts(&mut self, contracts: &'a [Contract]) -> &mut Self {
self.contracts = contracts;
self
}

fn declare(&mut self) -> &mut Self {
for account in self.accounts {
match account.cairo_version() {
CairoVersion::Cairo0 => self.deprecated_declared_classes.push(account.class_hash()),
for contract in self.contracts {
match contract.cairo_version() {
CairoVersion::Cairo0 => {
self.deprecated_declared_classes.push(contract.class_hash())
}
CairoVersion::Cairo1 => {
self.declared_classes.insert(account.class_hash(), Default::default());
self.declared_classes.insert(contract.class_hash(), Default::default());
}
}
}
self
}

fn deploy(&mut self) -> &mut Self {
for account in self.accounts {
self.deployed_contracts.insert(account.sender_address, account.class_hash());
for contract in self.contracts {
self.deployed_contracts.insert(contract.sender_address, contract.class_hash());
}
self
}

/// Only applies for contracts that are accounts, for non-accounts only declare and deploy work.
fn fund(&mut self) -> &mut Self {
for account in self.accounts {
for account in self.contracts {
assert_matches!(
account.account,
account.contract,
FeatureContract::AccountWithLongValidate(_)
| FeatureContract::AccountWithoutValidations(_)
| FeatureContract::FaultyAccount(_),
Expand All @@ -313,8 +318,8 @@ impl<'a> ThinStateDiffBuilder<'a> {
}

// TODO(deploy_account_support): delete method once we have batcher with execution.
fn inject_accounts_into_state(&mut self, accounts_defined_in_the_test: &'a [FeatureAccount]) {
self.set_accounts(accounts_defined_in_the_test).declare().deploy().fund();
fn inject_accounts_into_state(&mut self, accounts_defined_in_the_test: &'a [Contract]) {
self.set_contracts(accounts_defined_in_the_test).declare().deploy().fund();

// Set nonces as 1 in the state so that subsequent invokes can pass validation.
self.nonces = self
Expand Down

0 comments on commit 618fafe

Please sign in to comment.