Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(blockifier, starknet_api): add explcit typing and arithmetic to gas amount/price, and fee #1173

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions crates/batcher/src/block_builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::num::NonZeroU128;

use async_trait::async_trait;
use blockifier::blockifier::block::{BlockInfo, BlockNumberHashPair, GasPrices};
use blockifier::blockifier::config::TransactionExecutorConfig;
Expand All @@ -21,7 +19,7 @@ use indexmap::IndexMap;
#[cfg(test)]
use mockall::automock;
use papyrus_storage::StorageReader;
use starknet_api::block::{BlockNumber, BlockTimestamp};
use starknet_api::block::{BlockNumber, BlockTimestamp, NonzeroGasPrice};
use starknet_api::core::ContractAddress;
use starknet_api::executable_transaction::Transaction;
use starknet_api::transaction::TransactionHash;
Expand Down Expand Up @@ -121,7 +119,7 @@ impl BlockBuilderFactory {
sequencer_address: execution_config.sequencer_address,
// TODO (yael 7/10/2024): add logic to compute gas prices
gas_prices: {
let tmp_val = NonZeroU128::new(1).unwrap();
let tmp_val = NonzeroGasPrice::MIN;
GasPrices::new(tmp_val, tmp_val, tmp_val, tmp_val, tmp_val, tmp_val)
},
use_kzg_da: execution_config.use_kzg_da,
Expand Down
44 changes: 21 additions & 23 deletions crates/blockifier/src/blockifier/block.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::num::NonZeroU128;

use log::warn;
use serde::{Deserialize, Serialize};
use starknet_api::block::{BlockHash, BlockNumber, BlockTimestamp};
use starknet_api::block::{BlockHash, BlockNumber, BlockTimestamp, GasPrice, NonzeroGasPrice};
use starknet_api::core::ContractAddress;
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;
Expand Down Expand Up @@ -30,34 +28,34 @@ pub struct BlockInfo {

#[derive(Clone, Debug)]
pub struct GasPrices {
eth_l1_gas_price: NonZeroU128, // In wei.
strk_l1_gas_price: NonZeroU128, // In fri.
eth_l1_data_gas_price: NonZeroU128, // In wei.
strk_l1_data_gas_price: NonZeroU128, // In fri.
eth_l2_gas_price: NonZeroU128, // In wei.
strk_l2_gas_price: NonZeroU128, // In fri.
eth_l1_gas_price: NonzeroGasPrice, // In wei.
strk_l1_gas_price: NonzeroGasPrice, // In fri.
eth_l1_data_gas_price: NonzeroGasPrice, // In wei.
strk_l1_data_gas_price: NonzeroGasPrice, // In fri.
eth_l2_gas_price: NonzeroGasPrice, // In wei.
strk_l2_gas_price: NonzeroGasPrice, // In fri.
}

#[derive(Debug)]
pub struct GasPricesForFeeType {
pub l1_gas_price: NonZeroU128,
pub l1_data_gas_price: NonZeroU128,
pub l2_gas_price: NonZeroU128,
pub l1_gas_price: NonzeroGasPrice,
pub l1_data_gas_price: NonzeroGasPrice,
pub l2_gas_price: NonzeroGasPrice,
}

impl GasPrices {
pub fn new(
eth_l1_gas_price: NonZeroU128,
strk_l1_gas_price: NonZeroU128,
eth_l1_data_gas_price: NonZeroU128,
strk_l1_data_gas_price: NonZeroU128,
eth_l2_gas_price: NonZeroU128,
strk_l2_gas_price: NonZeroU128,
eth_l1_gas_price: NonzeroGasPrice,
strk_l1_gas_price: NonzeroGasPrice,
eth_l1_data_gas_price: NonzeroGasPrice,
strk_l1_data_gas_price: NonzeroGasPrice,
eth_l2_gas_price: NonzeroGasPrice,
strk_l2_gas_price: NonzeroGasPrice,
) -> Self {
// TODO(Aner): fix backwards compatibility.
let expected_eth_l2_gas_price = VersionedConstants::latest_constants()
.convert_l1_to_l2_gas_price_round_up(eth_l1_gas_price.into());
if u128::from(eth_l2_gas_price) != expected_eth_l2_gas_price {
if GasPrice::from(eth_l2_gas_price) != expected_eth_l2_gas_price {
// TODO!(Aner): change to panic! Requires fixing several tests.
warn!(
"eth_l2_gas_price does not match expected! eth_l2_gas_price:{eth_l2_gas_price}, \
Expand All @@ -66,7 +64,7 @@ impl GasPrices {
}
let expected_strk_l2_gas_price = VersionedConstants::latest_constants()
.convert_l1_to_l2_gas_price_round_up(strk_l1_gas_price.into());
if u128::from(strk_l2_gas_price) != expected_strk_l2_gas_price {
if GasPrice::from(strk_l2_gas_price) != expected_strk_l2_gas_price {
// TODO!(Aner): change to panic! Requires fixing test_discounted_gas_overdraft
warn!(
"strk_l2_gas_price does not match expected! \
Expand All @@ -84,21 +82,21 @@ impl GasPrices {
}
}

pub fn get_l1_gas_price_by_fee_type(&self, fee_type: &FeeType) -> NonZeroU128 {
pub fn get_l1_gas_price_by_fee_type(&self, fee_type: &FeeType) -> NonzeroGasPrice {
match fee_type {
FeeType::Strk => self.strk_l1_gas_price,
FeeType::Eth => self.eth_l1_gas_price,
}
}

pub fn get_l1_data_gas_price_by_fee_type(&self, fee_type: &FeeType) -> NonZeroU128 {
pub fn get_l1_data_gas_price_by_fee_type(&self, fee_type: &FeeType) -> NonzeroGasPrice {
match fee_type {
FeeType::Strk => self.strk_l1_data_gas_price,
FeeType::Eth => self.eth_l1_data_gas_price,
}
}

pub fn get_l2_gas_price_by_fee_type(&self, fee_type: &FeeType) -> NonZeroU128 {
pub fn get_l2_gas_price_by_fee_type(&self, fee_type: &FeeType) -> NonzeroGasPrice {
match fee_type {
FeeType::Strk => self.strk_l2_gas_price,
FeeType::Eth => self.eth_l2_gas_price,
Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/blockifier/stateful_validator_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use assert_matches::assert_matches;
use rstest::rstest;
use starknet_api::transaction::{Fee, TransactionVersion, ValidResourceBounds};
use starknet_api::transaction::{TransactionVersion, ValidResourceBounds};

use crate::blockifier::stateful_validator::StatefulValidator;
use crate::context::BlockContext;
Expand Down Expand Up @@ -53,7 +53,7 @@ fn test_tx_validator(
class_hash,
validate_constructor,
// TODO(Arni, 1/5/2024): Add test for insufficient maximal resources.
max_fee: Fee(BALANCE),
max_fee: BALANCE,
resource_bounds: max_l1_resource_bounds,
..Default::default()
};
Expand Down
5 changes: 2 additions & 3 deletions crates/blockifier/src/blockifier/transaction_executor_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use rstest::rstest;
use starknet_api::execution_resources::GasAmount;
use starknet_api::test_utils::NonceManager;
use starknet_api::transaction::{Fee, TransactionVersion};
use starknet_api::{declare_tx_args, deploy_account_tx_args, felt, invoke_tx_args, nonce};
Expand Down Expand Up @@ -121,7 +120,7 @@ fn test_declare(
class_hash: declared_contract.get_class_hash(),
compiled_class_hash: declared_contract.get_compiled_class_hash(),
version: tx_version,
resource_bounds: l1_resource_bounds(GasAmount(0), DEFAULT_STRK_L1_GAS_PRICE),
resource_bounds: l1_resource_bounds(0_u8.into(), DEFAULT_STRK_L1_GAS_PRICE.into()),
},
calculate_class_info_for_testing(declared_contract.get_class()),
)
Expand All @@ -141,7 +140,7 @@ fn test_deploy_account(
let tx = deploy_account_tx(
deploy_account_tx_args! {
class_hash: account_contract.get_class_hash(),
resource_bounds: l1_resource_bounds(GasAmount(0), DEFAULT_STRK_L1_GAS_PRICE),
resource_bounds: l1_resource_bounds(0_u8.into(), DEFAULT_STRK_L1_GAS_PRICE.into()),
version,
},
&mut NonceManager::default(),
Expand Down
4 changes: 3 additions & 1 deletion crates/blockifier/src/bouncer_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::{HashMap, HashSet};
use cairo_vm::types::builtin_name::BuiltinName;
use rstest::rstest;
use starknet_api::core::{ClassHash, ContractAddress, PatriciaKey};
use starknet_api::transaction::Fee;
use starknet_api::{class_hash, contract_address, felt, patricia_key, storage_key};

use super::BouncerConfig;
Expand Down Expand Up @@ -186,7 +187,8 @@ fn test_bouncer_try_update(

use crate::fee::resources::{ComputationResources, TransactionResources};

let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]);
let state =
&mut test_state(&BlockContext::create_for_account_testing().chain_info, Fee(0), &[]);
let mut transactional_state = TransactionalState::create_transactional(state);

// Setup the bouncer.
Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/concurrency/fee_utils_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub fn test_fill_sequencer_balance_reads(
let chain_info = &block_context.chain_info;
let state = &mut test_state_inner(chain_info, BALANCE, &[(account, 1)], erc20_version);

let sequencer_balance = 100;
let sequencer_balance = Fee(100);
let sequencer_address = block_context.block_info.sequencer_address;
fund_account(chain_info, sequencer_address, sequencer_balance, &mut state.state);

Expand Down Expand Up @@ -58,7 +58,7 @@ pub fn test_add_fee_to_sequencer_balance(
) {
let block_context = BlockContext::create_for_account_testing();
let account = FeatureContract::Empty(CairoVersion::Cairo1);
let mut state = test_state(&block_context.chain_info, 0, &[(account, 1)]);
let mut state = test_state(&block_context.chain_info, Fee(0), &[(account, 1)]);
let (sequencer_balance_key_low, sequencer_balance_key_high) =
get_sequencer_balance_keys(&block_context);

Expand Down
7 changes: 5 additions & 2 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ fn test_run_parallel_txs(max_l1_resource_bounds: ValidResourceBounds) {
let deploy_account_tx_1 = deploy_account_tx(
deploy_account_tx_args! {
class_hash: account_without_validation.get_class_hash(),
resource_bounds: l1_resource_bounds(u128::from(!zero_bounds).into(), DEFAULT_STRK_L1_GAS_PRICE),
resource_bounds: l1_resource_bounds(
u8::from(!zero_bounds).into(),
DEFAULT_STRK_L1_GAS_PRICE.into()
),
},
&mut NonceManager::default(),
);
Expand All @@ -264,7 +267,7 @@ fn test_run_parallel_txs(max_l1_resource_bounds: ValidResourceBounds) {
let deployed_account_balance_key = get_fee_token_var_address(account_address);
let fee_token_address = chain_info.fee_token_address(&fee_type);
state_2
.set_storage_at(fee_token_address, deployed_account_balance_key, felt!(BALANCE))
.set_storage_at(fee_token_address, deployed_account_balance_key, felt!(BALANCE.0))
.unwrap();

let block_context_1 = block_context.clone();
Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/concurrency/worker_logic_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ fn test_worker_execute(max_l1_resource_bounds: ValidResourceBounds) {
let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap();
let execution_output = execution_output.as_ref().unwrap();
let result = execution_output.result.as_ref().unwrap();
let account_balance = BALANCE - result.receipt.fee.0;
let account_balance = BALANCE.0 - result.receipt.fee.0;
assert!(!result.is_reverted());

let erc20 = FeatureContract::ERC20(CairoVersion::Cairo0);
Expand Down Expand Up @@ -381,7 +381,7 @@ fn test_worker_execute(max_l1_resource_bounds: ValidResourceBounds) {
]),
storage: HashMap::from([
((test_contract_address, storage_key), felt!(0_u8)),
((erc_contract_address, account_balance_key_low), felt!(BALANCE)),
((erc_contract_address, account_balance_key_low), felt!(BALANCE.0)),
((erc_contract_address, account_balance_key_high), felt!(0_u8)),
]),
// When running an entry point, we load its contract class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use crate::{check_entry_point_execution_error_for_custom_hint, retdata};
#[test]
fn test_storage_read_write() {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(&ChainInfo::create_for_testing(), 0, &[(test_contract, 1)]);
let mut state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 1)]);

let key = felt!(1234_u16);
let value = felt!(18_u8);
Expand All @@ -79,7 +79,7 @@ fn test_storage_read_write() {
#[test]
fn test_library_call() {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(&ChainInfo::create_for_testing(), 0, &[(test_contract, 1)]);
let mut state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 1)]);
let inner_entry_point_selector = selector_from_name("test_storage_read_write");
let calldata = calldata![
test_contract.get_class_hash().0, // Class hash.
Expand All @@ -103,7 +103,7 @@ fn test_library_call() {
#[test]
fn test_nested_library_call() {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(&ChainInfo::create_for_testing(), 0, &[(test_contract, 1)]);
let mut state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 1)]);
let (key, value) = (255_u64, 44_u64);
let outer_entry_point_selector = selector_from_name("test_library_call");
let inner_entry_point_selector = selector_from_name("test_storage_read_write");
Expand Down Expand Up @@ -208,7 +208,7 @@ fn test_nested_library_call() {
fn test_call_contract() {
let chain_info = &ChainInfo::create_for_testing();
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(chain_info, 0, &[(test_contract, 1)]);
let mut state = test_state(chain_info, Fee(0), &[(test_contract, 1)]);
let test_address = test_contract.get_instance_address(0);

let trivial_external_entry_point = trivial_external_entry_point_new(test_contract);
Expand Down Expand Up @@ -277,7 +277,7 @@ fn test_replace_class() {
let chain_info = &ChainInfo::create_for_testing();
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let empty_contract = FeatureContract::Empty(CairoVersion::Cairo0);
let mut state = test_state(chain_info, 0, &[(test_contract, 1), (empty_contract, 1)]);
let mut state = test_state(chain_info, Fee(0), &[(test_contract, 1), (empty_contract, 1)]);
let test_address = test_contract.get_instance_address(0);
// Replace with undeclared class hash.
let calldata = calldata![felt!(1234_u16)];
Expand Down Expand Up @@ -342,8 +342,11 @@ fn test_deploy(
) {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let empty_contract = FeatureContract::Empty(CairoVersion::Cairo0);
let mut state =
test_state(&ChainInfo::create_for_testing(), 0, &[(empty_contract, 0), (test_contract, 1)]);
let mut state = test_state(
&ChainInfo::create_for_testing(),
Fee(0),
&[(empty_contract, 0), (test_contract, 1)],
);

let class_hash = if constructor_exists {
test_contract.get_class_hash()
Expand Down Expand Up @@ -418,7 +421,7 @@ fn test_block_info_syscalls(
calldata: Calldata,
) {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(&ChainInfo::create_for_testing(), 0, &[(test_contract, 1)]);
let mut state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 1)]);
let entry_point_selector = selector_from_name(&format!("test_get_{}", block_info_member_name));
let entry_point_call = CallEntryPoint {
entry_point_selector,
Expand Down Expand Up @@ -453,7 +456,7 @@ fn test_block_info_syscalls(
#[rstest]
fn test_tx_info(#[values(false, true)] only_query: bool) {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(&ChainInfo::create_for_testing(), 0, &[(test_contract, 1)]);
let mut state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 1)]);
let mut version = felt!(1_u8);
if only_query {
let simulate_version_base = Pow::pow(felt!(2_u8), QUERY_VERSION_BASE_BIT);
Expand Down Expand Up @@ -556,7 +559,7 @@ fn emit_events(
data: &[Felt],
) -> Result<CallInfo, EntryPointExecutionError> {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0);
let mut state = test_state(&ChainInfo::create_for_testing(), 0, &[(test_contract, 1)]);
let mut state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 1)]);
let calldata = Calldata(
[
n_emitted_events.to_owned(),
Expand Down
26 changes: 5 additions & 21 deletions crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::execution::execution_utils::execute_entry_point_call_wrapper;
use crate::state::state_api::{State, StateResult};
use crate::transaction::objects::{HasRelatedFeeType, TransactionInfo};
use crate::transaction::transaction_types::TransactionType;
use crate::utils::{u128_from_usize, usize_from_u128};
use crate::utils::usize_from_u128;
use crate::versioned_constants::{GasCosts, VersionedConstants};

#[cfg(test)]
Expand Down Expand Up @@ -269,26 +269,10 @@ impl EntryPointExecutionContext {
// transactions derive this value from the `max_fee`.
let tx_gas_upper_bound = match tx_info {
TransactionInfo::Deprecated(context) => {
let max_cairo_steps = context.max_fee.0
/ block_info.gas_prices.get_l1_gas_price_by_fee_type(&tx_info.fee_type());
// FIXME: This is saturating in the python bootstrapping test. Fix the value so
// that it'll fit in a usize and remove the `as`.
usize::try_from(max_cairo_steps).unwrap_or_else(|_| {
log::error!(
"Performed a saturating cast from u128 to usize: {max_cairo_steps:?}"
);
usize::MAX
})
}
TransactionInfo::Current(context) => {
// TODO(Ori, 1/2/2024): Write an indicative expect message explaining why the
// convertion works.
context
.l1_resource_bounds()
.max_amount
.try_into()
.expect("Failed to convert u64 to usize.")
context.max_fee
/ block_info.gas_prices.get_l1_gas_price_by_fee_type(&tx_info.fee_type())
}
TransactionInfo::Current(context) => context.l1_resource_bounds().max_amount.into(),
};

// Use saturating upper bound to avoid overflow. This is safe because the upper bound is
Expand All @@ -297,7 +281,7 @@ impl EntryPointExecutionContext {
let upper_bound_u128 = if gas_per_step.is_zero() {
u128::MAX
} else {
(gas_per_step.inv() * u128_from_usize(tx_gas_upper_bound)).to_integer()
(gas_per_step.inv() * tx_gas_upper_bound.0).to_integer()
};
let tx_upper_bound = usize_from_u128(upper_bound_u128).unwrap_or_else(|_| {
log::warn!(
Expand Down
Loading
Loading