From 559ca6c5e00cb263defbee759cc170d6fce1adc0 Mon Sep 17 00:00:00 2001 From: LucasLvy Date: Mon, 23 Oct 2023 11:44:39 +0200 Subject: [PATCH 1/5] feat(hints): load next transaction --- src/hints/hints_raw.rs | 4 ++++ src/hints/mod.rs | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/hints/hints_raw.rs b/src/hints/hints_raw.rs index c0c36527..486975c0 100644 --- a/src/hints/hints_raw.rs +++ b/src/hints/hints_raw.rs @@ -59,3 +59,7 @@ pub const ENTER_SYSCALL_SCOPES: &str = "vm_enter_scope({\n '__deprecated_class_hashes': __deprecated_class_hashes,\n 'transactions': \ iter(os_input.transactions),\n 'execution_helper': execution_helper,\n 'deprecated_syscall_handler': \ deprecated_syscall_handler,\n 'syscall_handler': syscall_handler,\n '__dict_manager': __dict_manager,\n})"; + +pub const LOAD_NEXT_TX: &str = "tx = next(transactions)\ntx_type_bytes = \ + tx.tx_type.name.encode(\"ascii\")\nids.tx_type = int.from_bytes(tx_type_bytes, \ + \"big\")"; diff --git a/src/hints/mod.rs b/src/hints/mod.rs index 3d2e01a1..42d86ea9 100644 --- a/src/hints/mod.rs +++ b/src/hints/mod.rs @@ -3,6 +3,7 @@ pub mod hints_raw; use std::collections::HashMap; use std::rc::Rc; +use std::slice::Iter; use cairo_vm::felt::Felt252; use cairo_vm::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::{ @@ -17,7 +18,7 @@ use cairo_vm::vm::errors::hint_errors::HintError; use cairo_vm::vm::vm_core::VirtualMachine; use crate::config::DEFAULT_INPUT_PATH; -use crate::io::StarknetOsInput; +use crate::io::{InternalTransaction, StarknetOsInput}; pub fn sn_hint_processor() -> BuiltinHintProcessor { let mut hint_processor = BuiltinHintProcessor::new_empty(); @@ -240,3 +241,22 @@ pub fn enter_syscall_scopes( ) -> Result<(), HintError> { Ok(()) } + +/// Implements hint: +/// +/// tx = next(transactions) +/// tx_type_bytes = tx.tx_type.name.encode("ascii") +/// ids.tx_type = int.from_bytes(tx_type_bytes, "big") +pub fn load_next_tx( + vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap, +) -> Result<(), HintError> { + let mut transactions = exec_scopes.get::>("transactions")?; + // Safe to unwrap because the remaining number of txs is checked in the cairo code. + let tx = transactions.next().unwrap(); + exec_scopes.insert_value(transactions, "transactions"); + insert_value_from_var_name("tx_type", Felt252::from_bytes_be(tx.r#type.as_bytes()), vm, ids_data, ap_tracking) +} From eae48853cf77f82e1f9bfb6c086f0b5048f45bda Mon Sep 17 00:00:00 2001 From: LucasLvy Date: Mon, 23 Oct 2023 15:03:51 +0200 Subject: [PATCH 2/5] test(load next tx): ignored test --- scripts/setup-tests.sh | 1 + src/hints/mod.rs | 8 ++++-- src/io/mod.rs | 2 +- tests/hints.rs | 27 ++++++++++++++++++- tests/programs/load_next_tx.cairo | 44 +++++++++++++++++++++++++++++++ 5 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 tests/programs/load_next_tx.cairo diff --git a/scripts/setup-tests.sh b/scripts/setup-tests.sh index 840fa25c..e78c4ee0 100755 --- a/scripts/setup-tests.sh +++ b/scripts/setup-tests.sh @@ -48,6 +48,7 @@ cairo-compile tests/programs/fact.cairo --output build/programs/fact.json cairo-compile tests/programs/load_deprecated_class.cairo --output build/programs/load_deprecated_class.json --cairo_path cairo-lang/src cairo-compile tests/programs/initialize_state_changes.cairo --output build/programs/initialize_state_changes.json --cairo_path cairo-lang/src cairo-compile tests/programs/get_block_mapping.cairo --output build/programs/get_block_mapping.json --cairo_path cairo-lang/src +cairo-compile tests/programs/load_next_tx.cairo --output build/programs/load_next_tx.json --cairo_path cairo-lang/src # compile os with debug info cairo-compile cairo-lang/src/starkware/starknet/core/os/os.cairo --output build/os_debug.json --cairo_path cairo-lang/src diff --git a/src/hints/mod.rs b/src/hints/mod.rs index 42d86ea9..98b6e75d 100644 --- a/src/hints/mod.rs +++ b/src/hints/mod.rs @@ -1,6 +1,7 @@ pub mod block_context; pub mod hints_raw; +use std::any::Any; use std::collections::HashMap; use std::rc::Rc; use std::slice::Iter; @@ -234,11 +235,14 @@ pub fn transactions_len( /// }) pub fn enter_syscall_scopes( _vm: &mut VirtualMachine, - _exec_scopes: &mut ExecutionScopes, + exec_scopes: &mut ExecutionScopes, _ids_data: &HashMap, _ap_tracking: &ApTracking, _constants: &HashMap, ) -> Result<(), HintError> { + let os_input = exec_scopes.get::("os_input").unwrap(); + let transactions: Box = Box::new([os_input.transactions.into_iter()].into_iter()); + exec_scopes.enter_scope(HashMap::from_iter([(String::from("transactions"), transactions)])); Ok(()) } @@ -257,6 +261,6 @@ pub fn load_next_tx( let mut transactions = exec_scopes.get::>("transactions")?; // Safe to unwrap because the remaining number of txs is checked in the cairo code. let tx = transactions.next().unwrap(); - exec_scopes.insert_value(transactions, "transactions"); + exec_scopes.insert_value("transactions", transactions); insert_value_from_var_name("tx_type", Felt252::from_bytes_be(tx.r#type.as_bytes()), vm, ids_data, ap_tracking) } diff --git a/src/io/mod.rs b/src/io/mod.rs index b7bbbeff..cb719a08 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -62,7 +62,7 @@ pub struct StorageCommitment { } #[serde_as] -#[derive(Deserialize, Clone, Debug, Serialize)] +#[derive(Deserialize, Clone, Debug, Serialize, Default)] pub struct InternalTransaction { #[serde_as(as = "Felt252Str")] pub hash_value: Felt252, diff --git a/tests/hints.rs b/tests/hints.rs index 06215291..0f0ceeb4 100644 --- a/tests/hints.rs +++ b/tests/hints.rs @@ -15,7 +15,10 @@ use snos::hints::block_context::{ get_block_mapping, load_deprecated_class_facts, load_deprecated_inner, sequencer_address, }; use snos::hints::hints_raw::*; -use snos::hints::{check_deprecated_class_hash, initialize_class_hashes, initialize_state_changes, starknet_os_input}; +use snos::hints::{ + check_deprecated_class_hash, enter_syscall_scopes, initialize_class_hashes, initialize_state_changes, load_next_tx, + starknet_os_input, +}; use snos::io::StarknetOsInput; #[fixture] @@ -112,3 +115,25 @@ fn get_block_mapping_test(mut os_input_hint_processor: BuiltinHintProcessor) { ); check_output_vs_python(run_output, program, true); } + +#[rstest] +#[ignore] +fn load_next_tx_test(mut os_input_hint_processor: BuiltinHintProcessor) { + let program = "build/programs/load_next_tx.json"; + + let load_os_input = HintFunc(Box::new(starknet_os_input)); + os_input_hint_processor.add_hint(String::from(STARKNET_OS_INPUT), Rc::new(load_os_input)); + + let load_scopes = HintFunc(Box::new(enter_syscall_scopes)); + os_input_hint_processor.add_hint(String::from(ENTER_SYSCALL_SCOPES), Rc::new(load_scopes)); + + let load_transaction = HintFunc(Box::new(load_next_tx)); + os_input_hint_processor.add_hint(String::from(LOAD_NEXT_TX), Rc::new(load_transaction)); + + let run_output = cairo_run( + &fs::read(program).unwrap(), + &CairoRunConfig { layout: "starknet", relocate_mem: true, trace_enabled: true, ..Default::default() }, + &mut os_input_hint_processor, + ); + check_output_vs_python(run_output, program, true); +} diff --git a/tests/programs/load_next_tx.cairo b/tests/programs/load_next_tx.cairo new file mode 100644 index 00000000..18cde8c8 --- /dev/null +++ b/tests/programs/load_next_tx.cairo @@ -0,0 +1,44 @@ +%builtins output pedersen range_check + +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.segments import relocate_segment +from starkware.cairo.common.cairo_builtins import HashBuiltin +from starkware.starknet.core.os.output import OsCarriedOutputs + +from starkware.starknet.core.os.contract_class.deprecated_compiled_class import ( + DeprecatedCompiledClassFact, + deprecated_validate_entry_points, + deprecated_compiled_class_hash, + DeprecatedContractEntryPoint, +) + +const DEPRECATED_COMPILED_CLASS_VERSION = 0; + +func main{output_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { + alloc_locals; + + let (initial_carried_outputs: OsCarriedOutputs*) = alloc(); + %{ + from starkware.starknet.core.os.os_input import StarknetOsInput + + os_input = StarknetOsInput.load(data=program_input) + + ids.initial_carried_outputs.messages_to_l1 = segments.add_temp_segment() + ids.initial_carried_outputs.messages_to_l2 = segments.add_temp_segment() + %} + + %{ + vm_enter_scope({ + '__deprecated_class_hashes': __deprecated_class_hashes, + 'transactions': iter(os_input.transactions), + 'execution_helper': execution_helper, + 'deprecated_syscall_handler': deprecated_syscall_handler, + 'syscall_handler': syscall_handler, + '__dict_manager': __dict_manager, + }) + %} + + %{ vm_exit_scope() %} + + return (); +} From 3f6d2e4063fdbda0fdb07efff8a91f65f07a7916 Mon Sep 17 00:00:00 2001 From: LucasLvy Date: Mon, 23 Oct 2023 18:47:08 +0200 Subject: [PATCH 3/5] WIP --- src/hints/hints_raw.rs | 4 ++ src/hints/mod.rs | 1 + src/hints/transaction_context.rs | 21 +++++++ src/io/mod.rs | 1 + src/lib.rs | 95 +++++++++++++++++++++++++++++--- 5 files changed, 115 insertions(+), 7 deletions(-) create mode 100644 src/hints/transaction_context.rs diff --git a/src/hints/hints_raw.rs b/src/hints/hints_raw.rs index 486975c0..60bb4cde 100644 --- a/src/hints/hints_raw.rs +++ b/src/hints/hints_raw.rs @@ -63,3 +63,7 @@ pub const ENTER_SYSCALL_SCOPES: &str = pub const LOAD_NEXT_TX: &str = "tx = next(transactions)\ntx_type_bytes = \ tx.tx_type.name.encode(\"ascii\")\nids.tx_type = int.from_bytes(tx_type_bytes, \ \"big\")"; + +pub const LOAD_CONTRACT_ADDRESS: &str = "from starkware.starknet.business_logic.transaction.objects import \ + InternalL1Handler\nids.contract_address = (\ntx.contract_address if \ + isinstance(tx, InternalL1Handler) else tx.sender_address\n)"; diff --git a/src/hints/mod.rs b/src/hints/mod.rs index 98b6e75d..155eda94 100644 --- a/src/hints/mod.rs +++ b/src/hints/mod.rs @@ -1,5 +1,6 @@ pub mod block_context; pub mod hints_raw; +// pub mod transaction_context; use std::any::Any; use std::collections::HashMap; diff --git a/src/hints/transaction_context.rs b/src/hints/transaction_context.rs new file mode 100644 index 00000000..8980e73c --- /dev/null +++ b/src/hints/transaction_context.rs @@ -0,0 +1,21 @@ +use crate::io::InternalTransaction; + +/// Implements hint: +/// +/// from starkware.starknet.business_logic.transaction.objects import InternalL1Handler +/// ids.contract_address = ( +/// tx.contract_address if isinstance(tx, InternalL1Handler) else tx.sender_address +/// ) +pub fn load_transaction_context( + vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap, +) -> Result<(), HintError> { + let mut transactions = exec_scopes.get::>("transactions")?; + // Safe to unwrap because the remaining number of txs is checked in the cairo code. + let tx = transactions.next().unwrap(); + exec_scopes.insert_value("transactions", transactions); + insert_value_from_var_name("tx_type", Felt252::from_bytes_be(tx.r#type.as_bytes()), vm, ids_data, ap_tracking) +} diff --git a/src/io/mod.rs b/src/io/mod.rs index cb719a08..a9a6a25c 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -106,6 +106,7 @@ pub struct InternalTransaction { pub r#type: String, } +#[derive(Debug)] pub struct StarknetOsOutput { /// The state commitment before this block. pub prev_state_root: Felt252, diff --git a/src/lib.rs b/src/lib.rs index a5bd22ba..3a89c508 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ use cairo_vm::vm::runners::cairo_runner::CairoRunner; use cairo_vm::vm::vm_core::VirtualMachine; use config::StarknetGeneralConfig; use error::SnOsError; +use num_traits::Num; use state::SharedState; use crate::io::StarknetOsOutput; @@ -79,10 +80,10 @@ impl SnOsRunner { cairo_runner .run_until_pc(end, &mut vm, &mut sn_hint_processor) .map_err(|err| VmException::from_vm_error(&cairo_runner, &vm, err)) - .map_err(|e| SnOsError::Runner(e.into()))?; + .map_err(|e| SnOsError::Runner(e.into())); cairo_runner .end_run(cairo_run_config.disable_trace_padding, false, &mut vm, &mut sn_hint_processor) - .map_err(|e| SnOsError::Runner(e.into()))?; + .map_err(|e| SnOsError::Runner(e.into())); if cairo_run_config.proof_mode { cairo_runner.finalize_segments(&mut vm).map_err(|e| SnOsError::Runner(e.into()))?; @@ -118,6 +119,85 @@ impl SnOsRunner { } }) .collect(); + let os_output = vec![ + Felt252::from_str_radix("5", 10).unwrap(), + Felt252::from_str_radix("200681068043714771978294967736222413892373265451181245269365696587346998380", 10) + .unwrap(), + Felt252::from_str_radix("5", 10).unwrap(), + Felt252::from_str_radix("2", 10).unwrap(), + Felt252::from_str_radix("2", 10).unwrap(), + Felt252::from_str_radix("4", 10).unwrap(), + Felt252::from_str_radix("6", 10).unwrap(), + Felt252::from_str_radix("5", 10).unwrap(), + Felt252::from_str_radix("100516143779279430775707828199600578312537898796928552917232883557759234322", 10) + .unwrap(), + Felt252::from_str_radix("0", 10).unwrap(), + Felt252::from_str_radix("35204018158445673560851558076088854146605956506855338357946372855484348775", 10) + .unwrap(), + Felt252::from_str_radix("1", 10).unwrap(), + Felt252::from_str_radix("2", 10).unwrap(), + Felt252::from_str_radix("5", 10).unwrap(), + Felt252::from_str_radix("100516143779279430775707828199600578312537898796928552917232883557759234322", 10) + .unwrap(), + Felt252::from_str_radix("34028236692093846346337460743176821142", 10).unwrap(), + Felt252::from_str_radix("69269496341425719426402089224874584819743134075306502400687571826086987209", 10) + .unwrap(), + Felt252::from_str_radix("13", 10).unwrap(), + Felt252::from_str_radix("46", 10).unwrap(), + Felt252::from_str_radix("30", 10).unwrap(), + Felt252::from_str_radix("221543030371090279154099648482303080997145207855149800960303587058346405278", 10) + .unwrap(), + Felt252::from_str_radix("31", 10).unwrap(), + Felt252::from_str_radix("153672706898142968531", 10).unwrap(), + Felt252::from_str_radix("32", 10).unwrap(), + Felt252::from_str_radix("9", 10).unwrap(), + Felt252::from_str_radix("81567992657121201822719584870756232234855806740606093104123927385410749460", 10) + .unwrap(), + Felt252::from_str_radix("2", 10).unwrap(), + Felt252::from_str_radix("131641924399560670288987069486918367964567650624282359632691221293624835245", 10) + .unwrap(), + Felt252::from_str_radix("326212205117017662403990886779887590398051155242173007037667265340317986446", 10) + .unwrap(), + Felt252::from_str_radix("200681068043714771978294967736222413892373265451181245269365696587346998380", 10) + .unwrap(), + Felt252::from_str_radix("34028236692093846346337460743176821141", 10).unwrap(), + Felt252::from_str_radix("208452472809998532760646017254057231099592366165685967544992326891398085023", 10) + .unwrap(), + Felt252::from_str_radix("5", 10).unwrap(), + Felt252::from_str_radix("7", 10).unwrap(), + Felt252::from_str_radix("31", 10).unwrap(), + Felt252::from_str_radix("53", 10).unwrap(), + Felt252::from_str_radix("44", 10).unwrap(), + Felt252::from_str_radix("66", 10).unwrap(), + Felt252::from_str_radix("171542524625682182385553640995899254084645198956708755167645702779965225616", 10) + .unwrap(), + Felt252::from_str_radix("10", 10).unwrap(), + Felt252::from_str_radix("171542524625682182385553640995899254084645198956708755167645702779965225617", 10) + .unwrap(), + Felt252::from_str_radix("20", 10).unwrap(), + Felt252::from_str_radix("222163306951389421296717391987130197751942633868181938423174889893366401376", 10) + .unwrap(), + Felt252::from_str_radix("34028236692093846346337460743176821140", 10).unwrap(), + Felt252::from_str_radix("326212205117017662403990886779887590398051155242173007037667265340317986446", 10) + .unwrap(), + Felt252::from_str_radix("5", 10).unwrap(), + Felt252::from_str_radix("1", 10).unwrap(), + Felt252::from_str_radix("11", 10).unwrap(), + Felt252::from_str_radix("97", 10).unwrap(), + Felt252::from_str_radix("55", 10).unwrap(), + Felt252::from_str_radix("88", 10).unwrap(), + Felt252::from_str_radix("66", 10).unwrap(), + Felt252::from_str_radix("99", 10).unwrap(), + Felt252::from_str_radix("261876760381503837851236634655062773110976680464358301683405235391247340282", 10) + .unwrap(), + Felt252::from_str_radix("44272185776902923874", 10).unwrap(), + Felt252::from_str_radix("330209860549393888721793468867835607193970854666866966631900875791400281196", 10) + .unwrap(), + Felt252::from_str_radix("34028236692093846346337460743176821146", 10).unwrap(), + Felt252::from_str_radix("326212205117017662403990886779887590398051155242173007037667265340317986446", 10) + .unwrap(), + Felt252::from(0), + ]; let prev_state_root = os_output[0].clone(); let new_state_root = os_output[1].clone(); let block_number = os_output[2].clone(); @@ -125,20 +205,20 @@ impl SnOsRunner { let config_hash = os_output[4].clone(); let os_output = &os_output[5..]; let messages_to_l1_size = ::from_be_bytes(os_output[0].to_be_bytes()[..8].try_into().unwrap()); - let messages_to_l1 = os_output[1..messages_to_l1_size].to_vec(); + let messages_to_l1 = os_output[1..1 + messages_to_l1_size].to_vec(); let os_output = &os_output[messages_to_l1_size + 1..]; let messages_to_l2_size = ::from_be_bytes(os_output[0].to_be_bytes()[..8].try_into().unwrap()); - let messages_to_l2 = os_output[1..messages_to_l2_size].to_vec(); + let messages_to_l2 = os_output[1..1 + messages_to_l2_size].to_vec(); let os_output = &os_output[messages_to_l2_size + 1..]; let state_updates_size = ::from_be_bytes(os_output[0].to_be_bytes()[..8].try_into().unwrap()); - let state_updates = os_output[1..state_updates_size].to_vec(); + let state_updates = os_output[1..1 + state_updates_size].to_vec(); let os_output = &os_output[state_updates_size + 1..]; let contract_class_diff_size = ::from_be_bytes(os_output[0].to_be_bytes()[..8].try_into().unwrap()); - let contract_class_diff = os_output[1..contract_class_diff_size].to_vec(); - StarknetOsOutput::new( + let contract_class_diff = os_output[1..1 + contract_class_diff_size].to_vec(); + let real_output = StarknetOsOutput::new( prev_state_root, new_state_root, block_number, @@ -149,6 +229,7 @@ impl SnOsRunner { state_updates, contract_class_diff, ); + println!("{:?}", real_output); vm.verify_auto_deductions().map_err(|e| SnOsError::Runner(e.into()))?; cairo_runner.read_return_values(&mut vm).map_err(|e| SnOsError::Runner(e.into()))?; From 6470c5c7c3c28a3de1cb423ac2522eb80b637b98 Mon Sep 17 00:00:00 2001 From: Ben Goebel Date: Mon, 23 Oct 2023 11:51:00 -0600 Subject: [PATCH 4/5] tx fmt, remove os output to new branch --- src/lib.rs | 87 +++---------------------------------------- tests/common/utils.rs | 6 +-- 2 files changed, 6 insertions(+), 87 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3a89c508..415ad40a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,6 @@ use cairo_vm::vm::runners::cairo_runner::CairoRunner; use cairo_vm::vm::vm_core::VirtualMachine; use config::StarknetGeneralConfig; use error::SnOsError; -use num_traits::Num; use state::SharedState; use crate::io::StarknetOsOutput; @@ -80,10 +79,12 @@ impl SnOsRunner { cairo_runner .run_until_pc(end, &mut vm, &mut sn_hint_processor) .map_err(|err| VmException::from_vm_error(&cairo_runner, &vm, err)) - .map_err(|e| SnOsError::Runner(e.into())); + .map_err(|e| SnOsError::Runner(e.into()))?; + + // End the Cairo VM run cairo_runner .end_run(cairo_run_config.disable_trace_padding, false, &mut vm, &mut sn_hint_processor) - .map_err(|e| SnOsError::Runner(e.into())); + .map_err(|e| SnOsError::Runner(e.into()))?; if cairo_run_config.proof_mode { cairo_runner.finalize_segments(&mut vm).map_err(|e| SnOsError::Runner(e.into()))?; @@ -119,85 +120,7 @@ impl SnOsRunner { } }) .collect(); - let os_output = vec![ - Felt252::from_str_radix("5", 10).unwrap(), - Felt252::from_str_radix("200681068043714771978294967736222413892373265451181245269365696587346998380", 10) - .unwrap(), - Felt252::from_str_radix("5", 10).unwrap(), - Felt252::from_str_radix("2", 10).unwrap(), - Felt252::from_str_radix("2", 10).unwrap(), - Felt252::from_str_radix("4", 10).unwrap(), - Felt252::from_str_radix("6", 10).unwrap(), - Felt252::from_str_radix("5", 10).unwrap(), - Felt252::from_str_radix("100516143779279430775707828199600578312537898796928552917232883557759234322", 10) - .unwrap(), - Felt252::from_str_radix("0", 10).unwrap(), - Felt252::from_str_radix("35204018158445673560851558076088854146605956506855338357946372855484348775", 10) - .unwrap(), - Felt252::from_str_radix("1", 10).unwrap(), - Felt252::from_str_radix("2", 10).unwrap(), - Felt252::from_str_radix("5", 10).unwrap(), - Felt252::from_str_radix("100516143779279430775707828199600578312537898796928552917232883557759234322", 10) - .unwrap(), - Felt252::from_str_radix("34028236692093846346337460743176821142", 10).unwrap(), - Felt252::from_str_radix("69269496341425719426402089224874584819743134075306502400687571826086987209", 10) - .unwrap(), - Felt252::from_str_radix("13", 10).unwrap(), - Felt252::from_str_radix("46", 10).unwrap(), - Felt252::from_str_radix("30", 10).unwrap(), - Felt252::from_str_radix("221543030371090279154099648482303080997145207855149800960303587058346405278", 10) - .unwrap(), - Felt252::from_str_radix("31", 10).unwrap(), - Felt252::from_str_radix("153672706898142968531", 10).unwrap(), - Felt252::from_str_radix("32", 10).unwrap(), - Felt252::from_str_radix("9", 10).unwrap(), - Felt252::from_str_radix("81567992657121201822719584870756232234855806740606093104123927385410749460", 10) - .unwrap(), - Felt252::from_str_radix("2", 10).unwrap(), - Felt252::from_str_radix("131641924399560670288987069486918367964567650624282359632691221293624835245", 10) - .unwrap(), - Felt252::from_str_radix("326212205117017662403990886779887590398051155242173007037667265340317986446", 10) - .unwrap(), - Felt252::from_str_radix("200681068043714771978294967736222413892373265451181245269365696587346998380", 10) - .unwrap(), - Felt252::from_str_radix("34028236692093846346337460743176821141", 10).unwrap(), - Felt252::from_str_radix("208452472809998532760646017254057231099592366165685967544992326891398085023", 10) - .unwrap(), - Felt252::from_str_radix("5", 10).unwrap(), - Felt252::from_str_radix("7", 10).unwrap(), - Felt252::from_str_radix("31", 10).unwrap(), - Felt252::from_str_radix("53", 10).unwrap(), - Felt252::from_str_radix("44", 10).unwrap(), - Felt252::from_str_radix("66", 10).unwrap(), - Felt252::from_str_radix("171542524625682182385553640995899254084645198956708755167645702779965225616", 10) - .unwrap(), - Felt252::from_str_radix("10", 10).unwrap(), - Felt252::from_str_radix("171542524625682182385553640995899254084645198956708755167645702779965225617", 10) - .unwrap(), - Felt252::from_str_radix("20", 10).unwrap(), - Felt252::from_str_radix("222163306951389421296717391987130197751942633868181938423174889893366401376", 10) - .unwrap(), - Felt252::from_str_radix("34028236692093846346337460743176821140", 10).unwrap(), - Felt252::from_str_radix("326212205117017662403990886779887590398051155242173007037667265340317986446", 10) - .unwrap(), - Felt252::from_str_radix("5", 10).unwrap(), - Felt252::from_str_radix("1", 10).unwrap(), - Felt252::from_str_radix("11", 10).unwrap(), - Felt252::from_str_radix("97", 10).unwrap(), - Felt252::from_str_radix("55", 10).unwrap(), - Felt252::from_str_radix("88", 10).unwrap(), - Felt252::from_str_radix("66", 10).unwrap(), - Felt252::from_str_radix("99", 10).unwrap(), - Felt252::from_str_radix("261876760381503837851236634655062773110976680464358301683405235391247340282", 10) - .unwrap(), - Felt252::from_str_radix("44272185776902923874", 10).unwrap(), - Felt252::from_str_radix("330209860549393888721793468867835607193970854666866966631900875791400281196", 10) - .unwrap(), - Felt252::from_str_radix("34028236692093846346337460743176821146", 10).unwrap(), - Felt252::from_str_radix("326212205117017662403990886779887590398051155242173007037667265340317986446", 10) - .unwrap(), - Felt252::from(0), - ]; + let prev_state_root = os_output[0].clone(); let new_state_root = os_output[1].clone(); let block_number = os_output[2].clone(); diff --git a/tests/common/utils.rs b/tests/common/utils.rs index af4dff8c..8fe2847c 100644 --- a/tests/common/utils.rs +++ b/tests/common/utils.rs @@ -62,11 +62,7 @@ pub fn deprecated_cairo_python_run(program: &str, with_input: bool) -> String { let mut raw = String::from_utf8(cmd_out.stdout).unwrap(); raw.push_str(&String::from_utf8(cmd_out.stderr).unwrap()); - raw.trim_start_matches("Program output:") - .trim_start_matches("\n ") - .trim_end_matches("\n\n") - .replace(' ', "") - .to_string() + raw.trim_start_matches("Program output:").trim_start_matches("\n ").trim_end_matches("\n\n").replace(' ', "") } pub fn raw_deploy( From 9bdf191763dc945df3ecf19e42201e27c7c77772 Mon Sep 17 00:00:00 2001 From: Ben Goebel Date: Mon, 23 Oct 2023 11:57:19 -0600 Subject: [PATCH 5/5] remove syscall test file --- scripts/debug-hint.sh | 2 +- tests/common/syscall_handler_test.py | 657 --------------------------- 2 files changed, 1 insertion(+), 658 deletions(-) delete mode 100644 tests/common/syscall_handler_test.py diff --git a/scripts/debug-hint.sh b/scripts/debug-hint.sh index 03103c41..7bc844db 100755 --- a/scripts/debug-hint.sh +++ b/scripts/debug-hint.sh @@ -17,4 +17,4 @@ else echo -e "$1 sucessfully recompiled...\n" fi -cargo test -q "$1_test" -- --nocapture +cargo test "$1_test" -- --nocapture diff --git a/tests/common/syscall_handler_test.py b/tests/common/syscall_handler_test.py deleted file mode 100644 index fccaca4b..00000000 --- a/tests/common/syscall_handler_test.py +++ /dev/null @@ -1,657 +0,0 @@ -from typing import NamedTuple, Optional, cast - -import pytest - -from starkware.cairo.common.cairo_secp import secp_utils -from starkware.cairo.common.structs import CairoStructProxy -from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME -from starkware.cairo.lang.compiler.test_utils import short_string_to_felt -from starkware.cairo.lang.vm.memory_dict import MemoryDict -from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager -from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.python.math_utils import EC_INFINITY -from starkware.python.utils import snake_to_camel_case -from starkware.starknet.business_logic.execution.execute_entry_point import ExecuteEntryPoint -from starkware.starknet.business_logic.execution.objects import ( - ExecutionResourcesManager, - OrderedEvent, - OrderedL2ToL1Message, - TransactionExecutionContext, -) -from starkware.starknet.business_logic.state.state import CachedSyncState -from starkware.starknet.business_logic.state.state_api import SyncState -from starkware.starknet.business_logic.state.state_api_objects import BlockInfo -from starkware.starknet.business_logic.state.storage_domain import StorageDomain -from starkware.starknet.business_logic.state.test_utils import EmptySyncStateReader -from starkware.starknet.core.os.syscall_handler import ( - KECCAK_FULL_RATE_IN_U64S, - BusinessLogicSyscallHandler, - to_uint256, -) -from starkware.starknet.definitions import constants -from starkware.starknet.definitions.constants import GasCost -from starkware.starknet.definitions.error_codes import CairoErrorCode -from starkware.starknet.definitions.general_config import StarknetGeneralConfig -from starkware.starknet.services.api.contract_class.contract_class import CompiledClass -from starkware.starknet.services.api.contract_class.contract_class_test_utils import ( - get_compiled_class_by_name, -) - -CURRENT_BLOCK_NUMBER = 40 # Some number bigger then STORED_BLOCK_HASH_BUFFER. -CONTRACT_ADDRESS = 1991 - - -# Fixtures. - - -@pytest.fixture(scope="module") -def general_config() -> StarknetGeneralConfig: - return StarknetGeneralConfig() - - -@pytest.fixture(scope="module") -def compiled_class() -> CompiledClass: - return get_compiled_class_by_name("test_contract") - - -@pytest.fixture -def state(compiled_class: CompiledClass) -> CachedSyncState: - """ - Returns a state with a deployed contract. - """ - block_timestamp = 1 - state = CachedSyncState( - state_reader=EmptySyncStateReader(), - block_info=BlockInfo.create_for_testing( - block_number=CURRENT_BLOCK_NUMBER, block_timestamp=block_timestamp - ), - compiled_class_cache={}, - ) - class_hash = 28 - compiled_class_hash = 6 - - # Declare new version class. - state.compiled_classes[compiled_class_hash] = compiled_class - state.set_compiled_class_hash(class_hash=class_hash, compiled_class_hash=compiled_class_hash) - # Deploy. - state.deploy_contract(contract_address=CONTRACT_ADDRESS, class_hash=class_hash) - - return state - - -@pytest.fixture -def tx_execution_context() -> TransactionExecutionContext: - return TransactionExecutionContext.create_for_testing( - account_contract_address=11, max_fee=22, nonce=33 - ) - - -@pytest.fixture -def entry_point() -> ExecuteEntryPoint: - return ExecuteEntryPoint.create_for_testing( - contract_address=CONTRACT_ADDRESS, calldata=[1], entry_point_selector=2, caller_address=3 - ) - - -@pytest.fixture -def syscall_handler( - state: CachedSyncState, - tx_execution_context: TransactionExecutionContext, - general_config: StarknetGeneralConfig, - entry_point: ExecuteEntryPoint, -) -> BusinessLogicSyscallHandler: - segments = MemorySegmentManager(memory=MemoryDict({}), prime=DEFAULT_PRIME) - return BusinessLogicSyscallHandler( - state=state, - resources_manager=ExecutionResourcesManager.empty(), - segments=segments, - tx_execution_context=tx_execution_context, - initial_syscall_ptr=segments.add(), - entry_point=entry_point, - general_config=general_config, - support_reverted=True, - ) - - -def test_storage_write(state: SyncState, syscall_handler: BusinessLogicSyscallHandler): - """ - Tests the SyscallHandler's storage_write syscall. - """ - # Positive flow. - key, value = 1970, 555 - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="storage_write", - request=syscall_handler.structs.StorageWriteRequest(reserved=0, key=key, value=value), - ) - assert ( - state.get_storage_at( - storage_domain=StorageDomain.ON_CHAIN, contract_address=CONTRACT_ADDRESS, key=key - ) - == value - ) - - # Negative flow - out of gas. - new_value = 777 - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="storage_write", - request=syscall_handler.structs.StorageWriteRequest(reserved=0, key=key, value=new_value), - out_of_gas=True, - ) - # Storage should not be changed. - assert ( - state.get_storage_at( - storage_domain=StorageDomain.ON_CHAIN, contract_address=CONTRACT_ADDRESS, key=key - ) - == value - ) - - -def test_storage_read(state: SyncState, syscall_handler: BusinessLogicSyscallHandler): - """ - Tests the SyscallHandler's storage_read syscall. - """ - # Set a non-trivial value to storage. - key, value = 2023, 777 - state.set_storage_at( - storage_domain=StorageDomain.ON_CHAIN, - contract_address=CONTRACT_ADDRESS, - key=key, - value=value, - ) - - structs = syscall_handler.structs - - # Positive flow. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="storage_read", - request=structs.StorageReadRequest(reserved=0, key=key), - response_struct=structs.StorageReadResponse, - expected_response=structs.StorageReadResponse(value=value), - ) - - # Negative flow - out of gas. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="storage_read", - request=structs.StorageReadRequest(reserved=0, key=key), - out_of_gas=True, - ) - - -def test_emit_event(syscall_handler: BusinessLogicSyscallHandler): - """ - Tests the SyscallHandler's emit_event syscall. - """ - structs = syscall_handler.structs - keys_start = syscall_handler.segments.add() - keys = [5] - keys_end = syscall_handler.segments.load_data(ptr=keys_start, data=keys) - data_start = syscall_handler.segments.add() - data = [6] - data_end = syscall_handler.segments.load_data(ptr=data_start, data=data) - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="emit_event", - request=structs.EmitEventRequest( - keys_start=keys_start, keys_end=keys_end, data_start=data_start, data_end=data_end - ), - response_struct=None, - expected_response=None, - ) - - assert len(syscall_handler.events) == 1 - assert syscall_handler.events == [OrderedEvent(order=0, keys=keys, data=data)] - - -def test_send_message_to_l1(syscall_handler: BusinessLogicSyscallHandler): - """ - Tests the SyscallHandler's send_message_to_l1 syscall. - """ - structs = syscall_handler.structs - to_address = 0 - payload_start = syscall_handler.segments.add() - payload = [5] - payload_end = syscall_handler.segments.load_data(ptr=payload_start, data=payload) - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="send_message_to_l1", - request=structs.SendMessageToL1Request( - to_address=to_address, payload_start=payload_start, payload_end=payload_end - ), - response_struct=None, - expected_response=None, - ) - - assert len(syscall_handler.l2_to_l1_messages) == 1 - assert syscall_handler.l2_to_l1_messages == [ - OrderedL2ToL1Message(order=0, to_address=to_address, payload=payload) - ] - - -def test_get_block_hash(syscall_handler: BusinessLogicSyscallHandler): - """ - Tests the SyscallHandler's get_block_hash syscall. - """ - structs = syscall_handler.structs - - # Positive flow. - - # Initialize block number -> block hash entry. - block_number = CURRENT_BLOCK_NUMBER - constants.STORED_BLOCK_HASH_BUFFER - block_hash = 1995 - syscall_handler.state.set_storage_at( - StorageDomain.ON_CHAIN, - contract_address=constants.BLOCK_HASH_CONTRACT_ADDRESS, - key=block_number, - value=block_hash, - ) - - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="get_block_hash", - request=structs.GetBlockHashRequest(block_number=block_number), - response_struct=structs.GetBlockHashResponse, - expected_response=structs.GetBlockHashResponse(block_hash=block_hash), - ) - - # Negative flow - requested block hash is out of range. - - block_number = CURRENT_BLOCK_NUMBER - constants.STORED_BLOCK_HASH_BUFFER + 1 - syscall_handler_failure_test( - syscall_handler=syscall_handler, - syscall_name="get_block_hash", - request=structs.GetBlockHashRequest(block_number=block_number), - initial_gas=GasCost.GET_BLOCK_HASH.value, - expected_error_code=CairoErrorCode.BLOCK_NUMBER_OUT_OF_RANGE, - ) - - -def test_get_execution_info( - syscall_handler: BusinessLogicSyscallHandler, - state: SyncState, - tx_execution_context: TransactionExecutionContext, - general_config: StarknetGeneralConfig, - entry_point: ExecuteEntryPoint, -): - """ - Tests the SyscallHandler's get_execution_info syscall. - """ - structs = syscall_handler.structs - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="get_execution_info", - request=structs.EmptyRequest(), - response_struct=structs.GetExecutionInfoResponse, - ) - - # Read and check response. - memory = syscall_handler.segments.memory - response = structs.GetExecutionInfoResponse.from_ptr( - memory=memory, addr=syscall_handler.syscall_ptr - 1 - ) - execution_info = structs.ExecutionInfo.from_ptr(memory=memory, addr=response.execution_info) - assert execution_info == structs.ExecutionInfo( - block_info=execution_info.block_info, - tx_info=execution_info.tx_info, - caller_address=entry_point.caller_address, - contract_address=entry_point.contract_address, - selector=entry_point.entry_point_selector, - ) - block_info = structs.BlockInfo.from_ptr(memory=memory, addr=execution_info.block_info) - assert block_info == structs.BlockInfo( - block_number=state.block_info.block_number, - block_timestamp=state.block_info.block_timestamp, - sequencer_address=state.block_info.sequencer_address, - ) - tx_info = structs.TxInfo.from_ptr(memory=memory, addr=execution_info.tx_info) - assert tx_info == structs.TxInfo( - version=tx_execution_context.version, - account_contract_address=tx_execution_context.account_contract_address, - max_fee=tx_execution_context.max_fee, - signature_start=tx_info.signature_start, - signature_end=tx_info.signature_end, - transaction_hash=tx_execution_context.transaction_hash, - chain_id=general_config.chain_id.value, - nonce=tx_execution_context.nonce, - ) - signature_len = tx_info.signature_end - tx_info.signature_start - signature = memory.get_range_as_ints(addr=tx_info.signature_start, size=signature_len) - assert signature == tx_execution_context.signature - - -def test_secp256k1_syscalls( - syscall_handler: BusinessLogicSyscallHandler, -): - """ - Tests the SyscallHandler's secp256k1 syscalls. - """ - - structs = syscall_handler.structs - segments = syscall_handler.segments - memory = segments.memory - - # Negative flow - invalid argument. - syscall_handler_failure_test( - syscall_handler=syscall_handler, - syscall_name="secp256k1_new", - request=structs.Secp256k1NewRequest( - x=to_uint256(structs, secp_utils.SECP_P), y=to_uint256(structs, 0) - ), - initial_gas=GasCost.SECP256K1_NEW.value, - expected_error_code=CairoErrorCode.INVALID_ARGUMENT, - ) - - # Positive flow - (0, 0) is the point at infinity. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="secp256k1_new", - request=structs.Secp256k1NewRequest(x=to_uint256(structs, 0), y=to_uint256(structs, 0)), - response_struct=structs.Secp256k1NewResponse, - ) - # We cannot use expected_response since the response is a newly allocated segment, - # so we test it manually. - response = structs.Secp256k1NewResponse.from_ptr( - memory, syscall_handler.syscall_ptr - structs.Secp256k1NewResponse.size - ) - assert response.not_on_curve == 0 - p0 = response.ec_point - assert syscall_handler.ec_points[p0] == EC_INFINITY - - x = 0xF728B4FA42485E3A0A5D2F346BAA9455E3E70682C2094CAC629F6FBED82C07CD - y = 0x8E182CA967F38E1BD6A49583F43F187608E031AB54FC0C4A8F0DC94FAD0D0611 - - # Positive flow - a point on the curve. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="secp256k1_new", - request=structs.Secp256k1NewRequest(x=to_uint256(structs, x), y=to_uint256(structs, y)), - response_struct=structs.Secp256k1NewResponse, - ) - - # Check the expected response. - response = structs.Secp256k1NewResponse.from_ptr( - memory, syscall_handler.syscall_ptr - structs.Secp256k1NewResponse.size - ) - - assert response.not_on_curve == 0 - p1 = response.ec_point - assert syscall_handler.ec_points[p1] == (x, y) - - # Positive flow - a point on the curve. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="secp256k1_get_point_from_x", - request=structs.Secp256k1GetPointFromXRequest(x=to_uint256(structs, x), y_parity=1), - response_struct=structs.Secp256k1NewResponse, - ) - - # Check the expected response. - response = structs.Secp256k1NewResponse.from_ptr( - memory, syscall_handler.syscall_ptr - structs.Secp256k1NewResponse.size - ) - assert response.not_on_curve == 0 - assert syscall_handler.ec_points[response.ec_point] == (x, y) - - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="secp256k1_get_xy", - request=structs.Secp256k1GetXyRequest(ec_point=response.ec_point), - response_struct=structs.Secp256k1GetXyResponse, - expected_response=structs.Secp256k1GetXyResponse( - x=to_uint256(structs, x), y=to_uint256(structs, y) - ), - ) - - # Positive flow - Add two points. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="secp256k1_add", - request=structs.Secp256k1AddRequest(p0=p0, p1=p1), - response_struct=structs.Secp256k1OpResponse, - ) - - # Positive flow - 17 * p0. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="secp256k1_mul", - request=structs.Secp256k1MulRequest(p=p0, scalar=to_uint256(structs, 17)), - response_struct=structs.Secp256k1OpResponse, - ) - - -def test_keccak_good_case( - syscall_handler: BusinessLogicSyscallHandler, -): - """ - Tests the SyscallHandler's keccak syscall. - """ - - structs = syscall_handler.structs - - data = list(range(1, 3 * KECCAK_FULL_RATE_IN_U64S + 1)) - start = syscall_handler.segments.gen_arg(data) - - # Positive flow. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="keccak", - request=structs.KeccakRequest( - input_start=start, input_end=start + KECCAK_FULL_RATE_IN_U64S - ), - response_struct=structs.KeccakResponse, - expected_response=structs.KeccakResponse( - result_low=0xEC687BE9C50D2218388DA73622E8FDD5, - result_high=0xD2EB808DFBA4703C528D145DFE6571AF, - ), - additional_gas=GasCost.KECCAK_ROUND_COST.value, - ) - - assert syscall_handler.resources_manager.syscall_counter["keccak"] == 1 - - # Positive flow. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="keccak", - request=structs.KeccakRequest( - input_start=start, input_end=start + 3 * KECCAK_FULL_RATE_IN_U64S - ), - response_struct=structs.KeccakResponse, - expected_response=structs.KeccakResponse( - result_low=0xEB56A947B570E88C145BD535C9831146, - result_high=0xF7BA51D4400150464F414250B163C1CB, - ), - additional_gas=3 * GasCost.KECCAK_ROUND_COST.value, - ) - - # We expected the syscall above to count as 3 as it does 3 keccak rounds. - assert syscall_handler.resources_manager.syscall_counter["keccak"] == 4 - - -@pytest.mark.parametrize("input_len", [1, KECCAK_FULL_RATE_IN_U64S - 1]) -def test_keccak_invalid_input_lengh( - syscall_handler: BusinessLogicSyscallHandler, - input_len, -): - structs = syscall_handler.structs - - data = list(range(input_len)) - start = syscall_handler.segments.gen_arg(data) - initial_gas = GasCost.KECCAK_ROUND_COST.value - - syscall_handler_failure_test( - syscall_handler=syscall_handler, - syscall_name="keccak", - request=structs.KeccakRequest(input_start=start, input_end=start + len(data)), - initial_gas=initial_gas, - expected_error_code=CairoErrorCode.INVALID_INPUT_LEN, - ) - - -def test_keccak_out_of_gas( - syscall_handler: BusinessLogicSyscallHandler, -): - structs = syscall_handler.structs - n_blocks = 2 - data = list(range(n_blocks * KECCAK_FULL_RATE_IN_U64S)) - start = syscall_handler.segments.gen_arg(data) - - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="keccak", - request=structs.KeccakRequest(input_start=start, input_end=start + len(data)), - out_of_gas=True, - additional_gas=n_blocks * GasCost.KECCAK_ROUND_COST.value, - ) - - -def test_replace_class( - state: CachedSyncState, - compiled_class: CompiledClass, - syscall_handler: BusinessLogicSyscallHandler, -): - """ - Tests the SyscallHandler's replace_class syscall. - """ - # Declare new version class. - class_hash = 10028 - compiled_class_hash = 10006 - - state.compiled_classes[compiled_class_hash] = compiled_class - state.set_compiled_class_hash(class_hash=class_hash, compiled_class_hash=compiled_class_hash) - - # Check that the contract's class does not match the class hash before the replacement. - assert state.get_class_hash_at(contract_address=CONTRACT_ADDRESS) != class_hash - - # Positive flow. - syscall_handler_test_body( - syscall_handler=syscall_handler, - syscall_name="replace_class", - request=syscall_handler.structs.ReplaceClassRequest(class_hash=class_hash), - ) - - assert state.get_class_hash_at(contract_address=CONTRACT_ADDRESS) == class_hash - - -# Utilities. - - -def execute_syscall( - syscall_handler: BusinessLogicSyscallHandler, - syscall_name: str, - request: tuple, - initial_gas: int, -) -> RelocatableValue: - syscall_ptr = syscall_handler.syscall_ptr - segments = syscall_handler.segments - structs = syscall_handler.structs - - # Prepare request. - selector = short_string_to_felt(snake_to_camel_case(syscall_name)) - request_header = structs.RequestHeader(selector=selector, gas=initial_gas) - - # Write request. - segments.write_arg(ptr=syscall_ptr, arg=request_header) - updated_syscall_ptr = syscall_ptr + len(request_header) - - flat_request = segments.gen_typed_args(cast(NamedTuple, request)) - segments.write_arg(ptr=updated_syscall_ptr, arg=flat_request) - updated_syscall_ptr += len(flat_request) - - # Execute. - syscall_handler.syscall(syscall_ptr=syscall_ptr) - - return updated_syscall_ptr - - -def syscall_handler_test_body( - syscall_handler: BusinessLogicSyscallHandler, - syscall_name: str, - request: tuple, - out_of_gas: bool = False, - response_struct: Optional[CairoStructProxy] = None, - expected_response: Optional[tuple] = None, - additional_gas: Optional[int] = None, -): - required_gas = GasCost[syscall_name.upper()].int_value - GasCost.SYSCALL_BASE.value - if additional_gas is not None: - required_gas += additional_gas - if out_of_gas: - initial_gas = required_gas - 1 - final_gas = initial_gas - failure_flag = 1 - else: - initial_gas = required_gas - final_gas = 0 - failure_flag = 0 - - updated_syscall_ptr = execute_syscall( - syscall_handler=syscall_handler, - syscall_name=syscall_name, - request=request, - initial_gas=initial_gas, - ) - - structs = syscall_handler.structs - segments = syscall_handler.segments - - # Read and validate response header. - response_header = structs.ResponseHeader.from_ptr( - memory=segments.memory, addr=updated_syscall_ptr - ) - updated_syscall_ptr += len(response_header) - assert response_header == structs.ResponseHeader(gas=final_gas, failure_flag=failure_flag) - - # Read and validate response body. - if out_of_gas: - assert response_struct is None and expected_response is None - response = structs.FailureReason.from_ptr(memory=segments.memory, addr=updated_syscall_ptr) - updated_syscall_ptr += len(response) - array = segments.memory.get_range(response.start, response.end - response.start) - assert array == [CairoErrorCode.OUT_OF_GAS.to_felt()] - else: - if response_struct is not None: - response = response_struct.from_ptr(memory=segments.memory, addr=updated_syscall_ptr) - updated_syscall_ptr += response_struct.size - if expected_response is not None: - assert response == expected_response - - # Validate that the handler advanced the syscall pointer correctly. - assert syscall_handler.syscall_ptr == updated_syscall_ptr - - -def syscall_handler_failure_test( - syscall_handler: BusinessLogicSyscallHandler, - syscall_name: str, - request: tuple, - initial_gas: int, - expected_error_code: CairoErrorCode, -): - updated_syscall_ptr = execute_syscall( - syscall_handler=syscall_handler, - syscall_name=syscall_name, - request=request, - initial_gas=initial_gas, - ) - - required_gas = GasCost[syscall_name.upper()].int_value - GasCost.SYSCALL_BASE.value - - structs = syscall_handler.structs - segments = syscall_handler.segments - - # Read and validate response header. - response_header = structs.ResponseHeader.from_ptr( - memory=segments.memory, addr=updated_syscall_ptr - ) - updated_syscall_ptr += len(response_header) - assert response_header == structs.ResponseHeader(gas=initial_gas - required_gas, failure_flag=1) - - response = structs.FailureReason.from_ptr(memory=segments.memory, addr=updated_syscall_ptr) - updated_syscall_ptr += len(response) - array = segments.memory.get_range(response.start, response.end - response.start) - assert array == [expected_error_code.to_felt()] - - # Validate that the handler advanced the syscall pointer correctly. - assert syscall_handler.syscall_ptr == updated_syscall_ptr