Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(blockifier): unify get entry point for V1 and Native #1243

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 65 additions & 35 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use starknet_api::contract_class::{ContractClass as RawContractClass, EntryPoint
use starknet_api::core::EntryPointSelector;
use starknet_api::deprecated_contract_class::{
ContractClass as DeprecatedContractClass,
EntryPoint,
EntryPointOffset,
EntryPointV0,
Program as DeprecatedProgram,
};
use starknet_types_core::felt::Felt;
Expand Down Expand Up @@ -63,6 +63,21 @@ pub enum TrackedResource {
SierraGas, // AKA Sierra mode.
}

#[derive(Clone)]
pub enum Cairo1EntryPoint {
Casm(EntryPointV1),
Native(NativeEntryPoint),
}

impl Cairo1EntryPoint {
pub fn selector(&self) -> &EntryPointSelector {
match self {
Cairo1EntryPoint::Casm(ep) => &ep.selector,
Cairo1EntryPoint::Native(ep) => &ep.selector,
}
}
}

/// Represents a runnable Starknet contract class (meaning, the program is runnable by the VM).
#[derive(Clone, Debug, Eq, PartialEq, derive_more::From)]
pub enum ContractClass {
Expand All @@ -71,6 +86,28 @@ pub enum ContractClass {
V1Native(NativeContractClassV1),
}

pub fn get_entry_point(
contract_class: &ContractClass,
call: &CallEntryPoint,
) -> Result<Cairo1EntryPoint, PreExecutionError> {
call.verify_constructor()?;

let entry_points_of_same_type = contract_class.entry_points_of_same_type(call.entry_point_type);
let filtered_entry_points: Vec<_> = entry_points_of_same_type
.iter()
.filter(|ep| *ep.selector() == call.entry_point_selector)
.collect();

match &filtered_entry_points[..] {
[] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)),
[entry_point] => Ok((**entry_point).clone()),
_ => Err(PreExecutionError::DuplicatedEntryPointSelector {
selector: call.entry_point_selector,
typ: call.entry_point_type,
}),
}
}

impl TryFrom<RawContractClass> for ContractClass {
type Error = ProgramError;

Expand Down Expand Up @@ -107,6 +144,23 @@ impl ContractClass {
}
}

pub fn entry_points_of_same_type(
&self,
entry_point_type: EntryPointType,
) -> Vec<Cairo1EntryPoint> {
match self {
ContractClass::V0(_) => panic!("V0 contracts do not support entry points."),
ContractClass::V1(class) => class.entry_points_by_type[&entry_point_type]
.iter()
.map(|ep| Cairo1EntryPoint::Casm(ep.clone()))
.collect(),
ContractClass::V1Native(class) => class.entry_points_by_type[entry_point_type]
.iter()
.map(|ep| Cairo1EntryPoint::Native(ep.clone()))
.collect(),
}
}

pub fn get_visited_segments(
&self,
visited_pcs: &HashSet<usize>,
Expand Down Expand Up @@ -207,7 +261,7 @@ impl ContractClassV0 {
pub struct ContractClassV0Inner {
#[serde(deserialize_with = "deserialize_program")]
pub program: Program,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPoint>>,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPointV0>>,
}

impl TryFrom<DeprecatedContractClass> for ContractClassV0 {
Expand Down Expand Up @@ -253,21 +307,9 @@ impl ContractClassV1 {
&self,
call: &CallEntryPoint,
) -> Result<EntryPointV1, PreExecutionError> {
call.verify_constructor()?;

let entry_points_of_same_type = &self.0.entry_points_by_type[&call.entry_point_type];
let filtered_entry_points: Vec<_> = entry_points_of_same_type
.iter()
.filter(|ep| ep.selector == call.entry_point_selector)
.collect();

match &filtered_entry_points[..] {
[] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)),
[entry_point] => Ok((*entry_point).clone()),
_ => Err(PreExecutionError::DuplicatedEntryPointSelector {
selector: call.entry_point_selector,
typ: call.entry_point_type,
}),
match get_entry_point(&ContractClass::V1(self.clone()), call)? {
Cairo1EntryPoint::Casm(entry_point) => Ok(entry_point),
Cairo1EntryPoint::Native(_) => panic!("Unexpected entry point type."),
}
}

Expand Down Expand Up @@ -639,22 +681,10 @@ impl NativeContractClassV1 {
}

/// Returns an entry point into the natively compiled contract.
pub fn get_entry_point(&self, call: &CallEntryPoint) -> Result<&FunctionId, PreExecutionError> {
call.verify_constructor()?;

let entry_points_of_same_type = &self.0.entry_points_by_type[call.entry_point_type];
let filtered_entry_points: Vec<_> = entry_points_of_same_type
.iter()
.filter(|ep| ep.selector == call.entry_point_selector)
.collect();

match &filtered_entry_points[..] {
[] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)),
[entry_point] => Ok(&entry_point.function_id),
_ => Err(PreExecutionError::DuplicatedEntryPointSelector {
selector: call.entry_point_selector,
typ: call.entry_point_type,
}),
pub fn get_entry_point(&self, call: &CallEntryPoint) -> Result<FunctionId, PreExecutionError> {
match get_entry_point(&ContractClass::V1Native(self.clone()), call)? {
Cairo1EntryPoint::Native(entry_point) => Ok(entry_point.function_id),
Cairo1EntryPoint::Casm(_) => panic!("Unexpected entry point type."),
}
}
}
Expand Down Expand Up @@ -733,9 +763,9 @@ fn sierra_eps_to_native_eps(
sierra_eps.iter().map(|sierra_ep| NativeEntryPoint::from(func_ids, sierra_ep)).collect()
}

#[derive(Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
/// Provides a relation between a function in a contract and a compiled contract.
struct NativeEntryPoint {
pub struct NativeEntryPoint {
/// The selector is the key to find the function in the contract.
selector: EntryPointSelector,
/// And the function_id is the key to find the function in the compiled contract.
Expand Down
8 changes: 4 additions & 4 deletions crates/blockifier_reexecution/src/state_reader/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use flate2::bufread;
use serde::Deserialize;
use starknet_api::contract_class::EntryPointType;
use starknet_api::core::EntryPointSelector;
use starknet_api::deprecated_contract_class::{EntryPoint, EntryPointOffset};
use starknet_api::deprecated_contract_class::{EntryPointOffset, EntryPointV0};
use starknet_api::hash::StarkHash;
use starknet_core::types::{
CompressedLegacyContractClass,
Expand All @@ -37,16 +37,16 @@ pub struct MiddleSierraContractClass {
/// from legacy format to new `EntryPoint` struct.
pub fn map_entry_points_by_type_legacy(
entry_points_by_type: LegacyEntryPointsByType,
) -> HashMap<EntryPointType, Vec<EntryPoint>> {
) -> HashMap<EntryPointType, Vec<EntryPointV0>> {
let entry_types_to_points = HashMap::from([
(EntryPointType::Constructor, entry_points_by_type.constructor),
(EntryPointType::External, entry_points_by_type.external),
(EntryPointType::L1Handler, entry_points_by_type.l1_handler),
]);

let to_contract_entry_point = |entrypoint: &LegacyContractEntryPoint| -> EntryPoint {
let to_contract_entry_point = |entrypoint: &LegacyContractEntryPoint| -> EntryPointV0 {
let felt: StarkHash = StarkHash::from_bytes_be(&entrypoint.selector.to_bytes_be());
EntryPoint {
EntryPointV0 {
offset: EntryPointOffset(usize::try_from(entrypoint.offset).unwrap()),
selector: EntryPointSelector(felt),
}
Expand Down
8 changes: 4 additions & 4 deletions crates/papyrus_protobuf/src/converters/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ impl From<state::ContractClass> for protobuf::Cairo1Class {
}
}

impl TryFrom<protobuf::EntryPoint> for deprecated_contract_class::EntryPoint {
impl TryFrom<protobuf::EntryPoint> for deprecated_contract_class::EntryPointV0 {
type Error = ProtobufConversionError;
fn try_from(value: protobuf::EntryPoint) -> Result<Self, Self::Error> {
let selector_felt =
Expand All @@ -282,12 +282,12 @@ impl TryFrom<protobuf::EntryPoint> for deprecated_contract_class::EntryPoint {
value.offset.try_into().expect("Failed converting u64 to usize"),
);

Ok(deprecated_contract_class::EntryPoint { selector, offset })
Ok(deprecated_contract_class::EntryPointV0 { selector, offset })
}
}

impl From<deprecated_contract_class::EntryPoint> for protobuf::EntryPoint {
fn from(value: deprecated_contract_class::EntryPoint) -> Self {
impl From<deprecated_contract_class::EntryPointV0> for protobuf::EntryPoint {
fn from(value: deprecated_contract_class::EntryPointV0) -> Self {
protobuf::EntryPoint {
selector: Some(value.selector.0.into()),
offset: u64::try_from(value.offset.0).expect("Failed converting usize to u64"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_api::core::{CompiledClassHash, ContractAddress, Nonce};
use starknet_api::data_availability::DataAvailabilityMode;
use starknet_api::deprecated_contract_class::{
ContractClassAbiEntry as DeprecatedContractClassAbiEntry,
EntryPoint as DeprecatedEntryPoint,
EntryPointV0 as DeprecatedEntryPoint,
EventAbiEntry,
FunctionAbiEntry,
StructAbiEntry,
Expand Down
4 changes: 2 additions & 2 deletions crates/papyrus_rpc/src/v0_8/deprecated_contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::HashMap;
use papyrus_storage::db::serialization::StorageSerdeError;
use serde::{Deserialize, Serialize};
use starknet_api::contract_class::EntryPointType;
use starknet_api::deprecated_contract_class::{ContractClassAbiEntry, EntryPoint};
use starknet_api::deprecated_contract_class::{ContractClassAbiEntry, EntryPointV0};

use crate::compression_utils::compress_and_encode;

Expand All @@ -13,7 +13,7 @@ pub struct ContractClass {
/// A base64 encoding of the gzip-compressed JSON representation of program.
pub program: String,
/// The selector of each entry point is a unique identifier in the program.
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPoint>>,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPointV0>>,
}

impl TryFrom<starknet_api::deprecated_contract_class::ContractClass> for ContractClass {
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_storage/src/serialization/serializers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ use starknet_api::deprecated_contract_class::{
ConstructorType,
ContractClass as DeprecatedContractClass,
ContractClassAbiEntry,
EntryPoint as DeprecatedEntryPoint,
EntryPointOffset,
EntryPointV0 as DeprecatedEntryPoint,
EventAbiEntry,
EventType,
FunctionAbiEntry,
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_test_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ use starknet_api::deprecated_contract_class::{
ConstructorType,
ContractClass as DeprecatedContractClass,
ContractClassAbiEntry,
EntryPoint as DeprecatedEntryPoint,
EntryPointOffset,
EntryPointV0 as DeprecatedEntryPoint,
EventAbiEntry,
EventType,
FunctionAbiEntry,
Expand Down
8 changes: 4 additions & 4 deletions crates/starknet_api/src/deprecated_contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct ContractClass {
/// The selector of each entry point is a unique identifier in the program.
// TODO: Consider changing to IndexMap, since this is used for computing the
// class hash.
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPoint>>,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPointV0>>,
}

/// A [ContractClass](`crate::deprecated_contract_class::ContractClass`) abi entry.
Expand Down Expand Up @@ -167,16 +167,16 @@ where

/// An entry point of a [ContractClass](`crate::deprecated_contract_class::ContractClass`).
#[derive(Debug, Default, Clone, Eq, PartialEq, Hash, Deserialize, Serialize, PartialOrd, Ord)]
pub struct EntryPoint {
pub struct EntryPointV0 {
pub selector: EntryPointSelector,
pub offset: EntryPointOffset,
}

impl TryFrom<CasmContractEntryPoint> for EntryPoint {
impl TryFrom<CasmContractEntryPoint> for EntryPointV0 {
type Error = StarknetApiError;

fn try_from(value: CasmContractEntryPoint) -> Result<Self, Self::Error> {
Ok(EntryPoint {
Ok(EntryPointV0 {
selector: EntryPointSelector(StarkHash::from(value.selector)),
offset: EntryPointOffset(value.offset),
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use starknet_api::deprecated_contract_class::{
ConstructorType,
ContractClass as DeprecatedContractClass,
ContractClassAbiEntry,
EntryPoint as DeprecatedEntryPoint,
EntryPointOffset,
EntryPointV0 as DeprecatedEntryPoint,
FunctionAbiEntry,
Program,
TypedParameter,
Expand Down
2 changes: 1 addition & 1 deletion crates/starknet_client/src/writer/objects/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use starknet_api::contract_class::EntryPointType;
use starknet_api::core::{ClassHash, ContractAddress};
use starknet_api::deprecated_contract_class::{
ContractClassAbiEntry as DeprecatedContractClassAbiEntry,
EntryPoint as DeprecatedEntryPoint,
EntryPointV0 as DeprecatedEntryPoint,
};
use starknet_api::transaction::TransactionHash;

Expand Down
2 changes: 1 addition & 1 deletion crates/starknet_client/src/writer/objects/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use starknet_api::core::{
};
use starknet_api::deprecated_contract_class::{
ContractClassAbiEntry as DeprecatedContractClassAbiEntry,
EntryPoint as DeprecatedEntryPoint,
EntryPointV0 as DeprecatedEntryPoint,
};
use starknet_api::state::EntryPoint;
use starknet_api::transaction::{
Expand Down
Loading