From 3ad1b6fd12a6890fda34687633c1469f9a63516d Mon Sep 17 00:00:00 2001 From: Nenad Date: Wed, 31 Jul 2024 14:58:50 +0200 Subject: [PATCH] Fix refund --- .../applications/coin_flip/src/contract.cairo | 65 +++++++++++++------ .../coin_flip/src/mock_randomness.cairo | 20 +++--- .../applications/coin_flip/src/tests.cairo | 30 ++++----- 3 files changed, 69 insertions(+), 46 deletions(-) diff --git a/listings/applications/coin_flip/src/contract.cairo b/listings/applications/coin_flip/src/contract.cairo index a59f63cd..c2e356fc 100644 --- a/listings/applications/coin_flip/src/contract.cairo +++ b/listings/applications/coin_flip/src/contract.cairo @@ -29,16 +29,24 @@ pub mod CoinFlip { use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; #[derive(Drop, starknet::Store)] - struct FlipData { + struct RefundData { flipper: ContractAddress, - deposit: u256, - refunded: bool + amount: u256, + } + + #[derive(Drop, starknet::Store)] + struct LastRequestData { + flip_id: u64, + flipper: ContractAddress, + last_balance: u256, } #[storage] struct Storage { eth_dispatcher: IERC20Dispatcher, - flips: LegacyMap, + flips: LegacyMap, + last_received_request_id: Option, + refunds: LegacyMap, nonce: u64, randomness_contract_address: ContractAddress, } @@ -83,6 +91,7 @@ pub mod CoinFlip { pub const CALLER_NOT_RANDOMNESS: felt252 = 'Caller not randomness contract'; pub const INVALID_ADDRESS: felt252 = 'Invalid address'; pub const INVALID_FLIP_ID: felt252 = 'No flip with the given ID'; + pub const NOTHING_TO_REFUND: felt252 = 'Nothing to refund'; pub const ONLY_FLIPPER_CAN_REFUND: felt252 = 'Only the flipper can refund'; pub const REQUESTOR_NOT_SELF: felt252 = 'Requestor is not self'; pub const TRANSFER_FAILED: felt252 = 'Transfer failed'; @@ -120,7 +129,7 @@ pub mod CoinFlip { let flip_id = self._request_my_randomness(); - self.flips.write(flip_id, FlipData { flipper, deposit, refunded: false }); + self.flips.write(flip_id, flipper); self.emit(Event::Flipped(Flipped { flip_id, flipper })); } @@ -131,26 +140,27 @@ pub mod CoinFlip { fn refund(ref self: ContractState, flip_id: u64) { let caller = get_caller_address(); - let FlipData { flipper, deposit, refunded } = self.flips.read(flip_id); - assert(flipper.is_non_zero(), Errors::INVALID_FLIP_ID); + let flipper = self.flips.read(flip_id); assert(flipper == caller, Errors::ONLY_FLIPPER_CAN_REFUND); - assert(!refunded, Errors::ALREADY_REFUNDED); - let randomness_dispatcher = IRandomnessDispatcher { - contract_address: self.randomness_contract_address.read() - }; + let eth_dispatcher = self.eth_dispatcher.read(); - let total_paid = randomness_dispatcher.get_total_fees(get_contract_address(), flip_id); + if let Option::Some(data) = self.last_received_request_id.read() { + let to_refund = eth_dispatcher.balance_of(get_contract_address()) + - data.last_balance; + self.refunds.write(data.flip_id, RefundData { flipper, amount: to_refund }); + self.last_received_request_id.write(Option::None); + } - let to_refund: u256 = deposit - total_paid; + let RefundData { flipper, amount } = self.refunds.read(flip_id); + assert(flipper.is_non_zero(), Errors::NOTHING_TO_REFUND); - self.flips.write(flip_id, FlipData { flipper, deposit, refunded: true }); + self.refunds.write(flip_id, RefundData { flipper: Zero::zero(), amount: 0 }); - let eth_dispatcher = self.eth_dispatcher.read(); - let success = eth_dispatcher.transfer(flipper, to_refund); + let success = eth_dispatcher.transfer(flipper, amount); assert(success, Errors::TRANSFER_FAILED); - self.emit(Event::Refunded(Refunded { flip_id, flipper, amount: to_refund })); + self.emit(Event::Refunded(Refunded { flip_id, flipper, amount })); } } @@ -206,10 +216,27 @@ pub mod CoinFlip { } fn _process_coin_flip(ref self: ContractState, flip_id: u64, random_value: @felt252) { - let flipData = self.flips.read(flip_id); - let flipper = flipData.flipper; + let flipper = self.flips.read(flip_id); assert(flipper.is_non_zero(), Errors::INVALID_FLIP_ID); + let eth_dispatcher = self.eth_dispatcher.read(); + let current_balance = eth_dispatcher.balance_of(get_contract_address()); + if let Option::Some(data) = self.last_received_request_id.read() { + self + .refunds + .write( + data.flip_id, + RefundData { flipper, amount: current_balance - data.last_balance } + ); + } + self + .last_received_request_id + .write( + Option::Some( + LastRequestData { flip_id, flipper, last_balance: current_balance } + ) + ); + // The chance of a flipped coin landing sideways is approximately 1 in 6000. // https://journals.aps.org/pre/abstract/10.1103/PhysRevE.48.2547 // diff --git a/listings/applications/coin_flip/src/mock_randomness.cairo b/listings/applications/coin_flip/src/mock_randomness.cairo index 683f5abb..79750aa3 100644 --- a/listings/applications/coin_flip/src/mock_randomness.cairo +++ b/listings/applications/coin_flip/src/mock_randomness.cairo @@ -1,5 +1,3 @@ -use starknet::ContractAddress; - #[starknet::contract] pub mod MockRandomness { use pragma_lib::abi::IRandomness; @@ -15,7 +13,7 @@ pub mod MockRandomness { struct Storage { eth_dispatcher: IERC20Dispatcher, next_request_id: u64, - total_fees: LegacyMap::<(ContractAddress, u64), u256>, + total_fees: LegacyMap<(ContractAddress, u64), u256>, } #[event] @@ -27,8 +25,6 @@ pub mod MockRandomness { pub const TRANSFER_FAILED: felt252 = 'Transfer failed'; } - pub const PREMIUM_FEE: u128 = 100_000_000; - #[constructor] fn constructor(ref self: ContractState, eth_address: ContractAddress) { assert(eth_address.is_non_zero(), Errors::INVALID_ADDRESS); @@ -49,8 +45,7 @@ pub mod MockRandomness { let caller = get_caller_address(); let this = get_contract_address(); - let total_fee = (callback_fee_limit / 2 + self.compute_premium_fee(callback_address)) - .into(); + let total_fee: u256 = callback_fee_limit.into() * 5; let eth_dispatcher = self.eth_dispatcher.read(); let success = eth_dispatcher.transfer_from(caller, this, total_fee); assert(success, Errors::TRANSFER_FAILED); @@ -78,10 +73,10 @@ pub mod MockRandomness { ) { let requestor = IPragmaVRFDispatcher { contract_address: callback_address }; requestor.receive_random_words(requestor_address, request_id, random_words, calldata); - } - - fn compute_premium_fee(self: @ContractState, caller_address: ContractAddress) -> u128 { - PREMIUM_FEE + let eth_dispatcher = self.eth_dispatcher.read(); + let success = eth_dispatcher + .transfer(requestor_address, (callback_fee_limit - callback_fee).into()); + assert(success, Errors::TRANSFER_FAILED); } fn get_total_fees( @@ -91,6 +86,9 @@ pub mod MockRandomness { } + fn compute_premium_fee(self: @ContractState, caller_address: ContractAddress) -> u128 { + panic!("unimplemented 'compute_premium_fee'") + } fn update_status( ref self: ContractState, requestor_address: ContractAddress, diff --git a/listings/applications/coin_flip/src/tests.cairo b/listings/applications/coin_flip/src/tests.cairo index 1d3b534d..68b61691 100644 --- a/listings/applications/coin_flip/src/tests.cairo +++ b/listings/applications/coin_flip/src/tests.cairo @@ -18,7 +18,7 @@ fn deploy() -> (ICoinFlipDispatcher, IRandomnessDispatcher, IERC20Dispatcher, Co let eth_contract = declare("ERC20Upgradeable").unwrap(); let eth_name: ByteArray = "Ethereum"; let eth_symbol: ByteArray = "ETH"; - let eth_supply: u256 = CoinFlip::CALLBACK_FEE_LIMIT.into() * 10; + let eth_supply: u256 = CoinFlip::CALLBACK_FEE_LIMIT.into() * 20; let mut eth_ctor_calldata = array![]; let deployer = contract_address_const::<'deployer'>(); ((eth_name, eth_symbol, eth_supply, deployer), deployer).serialize(ref eth_ctor_calldata); @@ -59,7 +59,10 @@ fn deploy() -> (ICoinFlipDispatcher, IRandomnessDispatcher, IERC20Dispatcher, Co fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { let (coin_flip, randomness, eth, deployer) = deploy(); - let callback_fee_limit = coin_flip.get_expected_deposit(); + let expected_deposit = coin_flip.get_expected_deposit(); + let expected_callback_fee = CoinFlip::CALLBACK_FEE_LIMIT / 2; + let expected_total_fee: u256 = expected_deposit + - (CoinFlip::CALLBACK_FEE_LIMIT - expected_callback_fee).into(); let mut spy = spy_events(SpyOn::One(coin_flip.contract_address)); @@ -82,7 +85,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { ] ); - assert_eq!(eth.balance_of(deployer), original_balance - callback_fee_limit); + assert_eq!(eth.balance_of(deployer), original_balance - expected_deposit); randomness .submit_random( @@ -92,7 +95,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { 0, coin_flip.contract_address, CoinFlip::CALLBACK_FEE_LIMIT, - CoinFlip::CALLBACK_FEE_LIMIT, + expected_callback_fee, array![random_word_1].span(), array![].span(), array![] @@ -125,13 +128,12 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { coin_flip.refund(expected_request_id); stop_cheat_caller_address(coin_flip.contract_address); - assert_eq!( - eth.balance_of(deployer), - original_balance - - randomness.get_total_fees(coin_flip.contract_address, expected_request_id) - ); + assert_eq!(eth.balance_of(deployer), original_balance - expected_total_fee); let original_balance = eth.balance_of(deployer); + let expected_callback_fee = CoinFlip::CALLBACK_FEE_LIMIT / 2 + 1000; + let expected_total_fee: u256 = expected_deposit + - (CoinFlip::CALLBACK_FEE_LIMIT - expected_callback_fee).into(); start_cheat_caller_address(coin_flip.contract_address, deployer); coin_flip.flip(); @@ -151,7 +153,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { ] ); - assert_eq!(eth.balance_of(deployer), original_balance - callback_fee_limit); + assert_eq!(eth.balance_of(deployer), original_balance - expected_deposit); randomness .submit_random( @@ -161,7 +163,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { 0, coin_flip.contract_address, CoinFlip::CALLBACK_FEE_LIMIT, - CoinFlip::CALLBACK_FEE_LIMIT, + expected_callback_fee, array![random_word_2].span(), array![].span(), array![] @@ -194,11 +196,7 @@ fn test_two_flips(random_word_1: felt252, random_word_2: felt252) { coin_flip.refund(expected_request_id); stop_cheat_caller_address(coin_flip.contract_address); - assert_eq!( - eth.balance_of(deployer), - original_balance - - randomness.get_total_fees(coin_flip.contract_address, expected_request_id) - ); + assert_eq!(eth.balance_of(deployer), original_balance - expected_total_fee); } #[test]