Skip to content

Commit

Permalink
Fix refund
Browse files Browse the repository at this point in the history
  • Loading branch information
Nenad committed Jul 31, 2024
1 parent 468ff98 commit 3ad1b6f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 46 deletions.
65 changes: 46 additions & 19 deletions listings/applications/coin_flip/src/contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,24 @@ pub mod CoinFlip {
use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait};

#[derive(Drop, starknet::Store)]
struct FlipData {
struct RefundData {
flipper: ContractAddress,
deposit: u256,
refunded: bool
amount: u256,
}

#[derive(Drop, starknet::Store)]
struct LastRequestData {
flip_id: u64,
flipper: ContractAddress,
last_balance: u256,
}

#[storage]
struct Storage {
eth_dispatcher: IERC20Dispatcher,
flips: LegacyMap<u64, FlipData>,
flips: LegacyMap<u64, ContractAddress>,
last_received_request_id: Option<LastRequestData>,
refunds: LegacyMap<u64, RefundData>,
nonce: u64,
randomness_contract_address: ContractAddress,
}
Expand Down Expand Up @@ -83,6 +91,7 @@ pub mod CoinFlip {
pub const CALLER_NOT_RANDOMNESS: felt252 = 'Caller not randomness contract';
pub const INVALID_ADDRESS: felt252 = 'Invalid address';
pub const INVALID_FLIP_ID: felt252 = 'No flip with the given ID';
pub const NOTHING_TO_REFUND: felt252 = 'Nothing to refund';
pub const ONLY_FLIPPER_CAN_REFUND: felt252 = 'Only the flipper can refund';
pub const REQUESTOR_NOT_SELF: felt252 = 'Requestor is not self';
pub const TRANSFER_FAILED: felt252 = 'Transfer failed';
Expand Down Expand Up @@ -120,7 +129,7 @@ pub mod CoinFlip {

let flip_id = self._request_my_randomness();

self.flips.write(flip_id, FlipData { flipper, deposit, refunded: false });
self.flips.write(flip_id, flipper);

self.emit(Event::Flipped(Flipped { flip_id, flipper }));
}
Expand All @@ -131,26 +140,27 @@ pub mod CoinFlip {

fn refund(ref self: ContractState, flip_id: u64) {
let caller = get_caller_address();
let FlipData { flipper, deposit, refunded } = self.flips.read(flip_id);
assert(flipper.is_non_zero(), Errors::INVALID_FLIP_ID);
let flipper = self.flips.read(flip_id);
assert(flipper == caller, Errors::ONLY_FLIPPER_CAN_REFUND);
assert(!refunded, Errors::ALREADY_REFUNDED);

let randomness_dispatcher = IRandomnessDispatcher {
contract_address: self.randomness_contract_address.read()
};
let eth_dispatcher = self.eth_dispatcher.read();

let total_paid = randomness_dispatcher.get_total_fees(get_contract_address(), flip_id);
if let Option::Some(data) = self.last_received_request_id.read() {
let to_refund = eth_dispatcher.balance_of(get_contract_address())
- data.last_balance;
self.refunds.write(data.flip_id, RefundData { flipper, amount: to_refund });
self.last_received_request_id.write(Option::None);
}

let to_refund: u256 = deposit - total_paid;
let RefundData { flipper, amount } = self.refunds.read(flip_id);
assert(flipper.is_non_zero(), Errors::NOTHING_TO_REFUND);

self.flips.write(flip_id, FlipData { flipper, deposit, refunded: true });
self.refunds.write(flip_id, RefundData { flipper: Zero::zero(), amount: 0 });

let eth_dispatcher = self.eth_dispatcher.read();
let success = eth_dispatcher.transfer(flipper, to_refund);
let success = eth_dispatcher.transfer(flipper, amount);
assert(success, Errors::TRANSFER_FAILED);

self.emit(Event::Refunded(Refunded { flip_id, flipper, amount: to_refund }));
self.emit(Event::Refunded(Refunded { flip_id, flipper, amount }));
}
}

Expand Down Expand Up @@ -206,10 +216,27 @@ pub mod CoinFlip {
}

fn _process_coin_flip(ref self: ContractState, flip_id: u64, random_value: @felt252) {
let flipData = self.flips.read(flip_id);
let flipper = flipData.flipper;
let flipper = self.flips.read(flip_id);
assert(flipper.is_non_zero(), Errors::INVALID_FLIP_ID);

let eth_dispatcher = self.eth_dispatcher.read();
let current_balance = eth_dispatcher.balance_of(get_contract_address());
if let Option::Some(data) = self.last_received_request_id.read() {
self
.refunds
.write(
data.flip_id,
RefundData { flipper, amount: current_balance - data.last_balance }
);
}
self
.last_received_request_id
.write(
Option::Some(
LastRequestData { flip_id, flipper, last_balance: current_balance }
)
);

// The chance of a flipped coin landing sideways is approximately 1 in 6000.
// https://journals.aps.org/pre/abstract/10.1103/PhysRevE.48.2547
//
Expand Down
20 changes: 9 additions & 11 deletions listings/applications/coin_flip/src/mock_randomness.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use starknet::ContractAddress;

#[starknet::contract]
pub mod MockRandomness {
use pragma_lib::abi::IRandomness;
Expand All @@ -15,7 +13,7 @@ pub mod MockRandomness {
struct Storage {
eth_dispatcher: IERC20Dispatcher,
next_request_id: u64,
total_fees: LegacyMap::<(ContractAddress, u64), u256>,
total_fees: LegacyMap<(ContractAddress, u64), u256>,
}

#[event]
Expand All @@ -27,8 +25,6 @@ pub mod MockRandomness {
pub const TRANSFER_FAILED: felt252 = 'Transfer failed';
}

pub const PREMIUM_FEE: u128 = 100_000_000;

#[constructor]
fn constructor(ref self: ContractState, eth_address: ContractAddress) {
assert(eth_address.is_non_zero(), Errors::INVALID_ADDRESS);
Expand All @@ -49,8 +45,7 @@ pub mod MockRandomness {
let caller = get_caller_address();
let this = get_contract_address();

let total_fee = (callback_fee_limit / 2 + self.compute_premium_fee(callback_address))
.into();
let total_fee: u256 = callback_fee_limit.into() * 5;
let eth_dispatcher = self.eth_dispatcher.read();
let success = eth_dispatcher.transfer_from(caller, this, total_fee);
assert(success, Errors::TRANSFER_FAILED);
Expand Down Expand Up @@ -78,10 +73,10 @@ pub mod MockRandomness {
) {
let requestor = IPragmaVRFDispatcher { contract_address: callback_address };
requestor.receive_random_words(requestor_address, request_id, random_words, calldata);
}

fn compute_premium_fee(self: @ContractState, caller_address: ContractAddress) -> u128 {
PREMIUM_FEE
let eth_dispatcher = self.eth_dispatcher.read();
let success = eth_dispatcher
.transfer(requestor_address, (callback_fee_limit - callback_fee).into());
assert(success, Errors::TRANSFER_FAILED);
}

fn get_total_fees(
Expand All @@ -91,6 +86,9 @@ pub mod MockRandomness {
}


fn compute_premium_fee(self: @ContractState, caller_address: ContractAddress) -> u128 {
panic!("unimplemented 'compute_premium_fee'")
}
fn update_status(
ref self: ContractState,
requestor_address: ContractAddress,
Expand Down
30 changes: 14 additions & 16 deletions listings/applications/coin_flip/src/tests.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn deploy() -> (ICoinFlipDispatcher, IRandomnessDispatcher, IERC20Dispatcher, Co
let eth_contract = declare("ERC20Upgradeable").unwrap();
let eth_name: ByteArray = "Ethereum";
let eth_symbol: ByteArray = "ETH";
let eth_supply: u256 = CoinFlip::CALLBACK_FEE_LIMIT.into() * 10;
let eth_supply: u256 = CoinFlip::CALLBACK_FEE_LIMIT.into() * 20;
let mut eth_ctor_calldata = array![];
let deployer = contract_address_const::<'deployer'>();
((eth_name, eth_symbol, eth_supply, deployer), deployer).serialize(ref eth_ctor_calldata);
Expand Down Expand Up @@ -59,7 +59,10 @@ fn deploy() -> (ICoinFlipDispatcher, IRandomnessDispatcher, IERC20Dispatcher, Co
fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
let (coin_flip, randomness, eth, deployer) = deploy();

let callback_fee_limit = coin_flip.get_expected_deposit();
let expected_deposit = coin_flip.get_expected_deposit();
let expected_callback_fee = CoinFlip::CALLBACK_FEE_LIMIT / 2;
let expected_total_fee: u256 = expected_deposit
- (CoinFlip::CALLBACK_FEE_LIMIT - expected_callback_fee).into();

let mut spy = spy_events(SpyOn::One(coin_flip.contract_address));

Expand All @@ -82,7 +85,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
]
);

assert_eq!(eth.balance_of(deployer), original_balance - callback_fee_limit);
assert_eq!(eth.balance_of(deployer), original_balance - expected_deposit);

randomness
.submit_random(
Expand All @@ -92,7 +95,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
0,
coin_flip.contract_address,
CoinFlip::CALLBACK_FEE_LIMIT,
CoinFlip::CALLBACK_FEE_LIMIT,
expected_callback_fee,
array![random_word_1].span(),
array![].span(),
array![]
Expand Down Expand Up @@ -125,13 +128,12 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
coin_flip.refund(expected_request_id);
stop_cheat_caller_address(coin_flip.contract_address);

assert_eq!(
eth.balance_of(deployer),
original_balance
- randomness.get_total_fees(coin_flip.contract_address, expected_request_id)
);
assert_eq!(eth.balance_of(deployer), original_balance - expected_total_fee);

let original_balance = eth.balance_of(deployer);
let expected_callback_fee = CoinFlip::CALLBACK_FEE_LIMIT / 2 + 1000;
let expected_total_fee: u256 = expected_deposit
- (CoinFlip::CALLBACK_FEE_LIMIT - expected_callback_fee).into();

start_cheat_caller_address(coin_flip.contract_address, deployer);
coin_flip.flip();
Expand All @@ -151,7 +153,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
]
);

assert_eq!(eth.balance_of(deployer), original_balance - callback_fee_limit);
assert_eq!(eth.balance_of(deployer), original_balance - expected_deposit);

randomness
.submit_random(
Expand All @@ -161,7 +163,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
0,
coin_flip.contract_address,
CoinFlip::CALLBACK_FEE_LIMIT,
CoinFlip::CALLBACK_FEE_LIMIT,
expected_callback_fee,
array![random_word_2].span(),
array![].span(),
array![]
Expand Down Expand Up @@ -194,11 +196,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) {
coin_flip.refund(expected_request_id);
stop_cheat_caller_address(coin_flip.contract_address);

assert_eq!(
eth.balance_of(deployer),
original_balance
- randomness.get_total_fees(coin_flip.contract_address, expected_request_id)
);
assert_eq!(eth.balance_of(deployer), original_balance - expected_total_fee);
}

#[test]
Expand Down

0 comments on commit 3ad1b6f

Please sign in to comment.