diff --git a/corelib/src/gas.cairo b/corelib/src/gas.cairo index f720204e08c..b47430d814a 100644 --- a/corelib/src/gas.cairo +++ b/corelib/src/gas.cairo @@ -35,5 +35,12 @@ pub extern fn withdraw_gas_all( costs: BuiltinCosts ) -> Option<()> implicits(RangeCheck, GasBuiltin) nopanic; + +/// Returns unused gas into the gas builtin. +/// +/// Useful for cases where different branches take different amounts of gas, but gas withdrawal is +/// the same for both. +pub extern fn redeposit_gas() implicits(GasBuiltin) nopanic; + /// Returns the `BuiltinCosts` table to be used in `withdraw_gas_all`. pub extern fn get_builtin_costs() -> BuiltinCosts nopanic; diff --git a/corelib/src/test/coupon_test.cairo b/corelib/src/test/coupon_test.cairo index 3cc03d57d64..304807ba868 100644 --- a/corelib/src/test/coupon_test.cairo +++ b/corelib/src/test/coupon_test.cairo @@ -1,5 +1,3 @@ -use crate::test::test_utils::assert_eq; - extern fn coupon_buy<T>() -> T nopanic; #[feature("corelib-internal-use")] @@ -22,6 +20,6 @@ fn test_arr_sum() { let available_gas = crate::testing::get_available_gas(); let res = arr_sum(arr); // Check that arr_sum did not consume any gas. - assert_eq(@crate::testing::get_available_gas(), @available_gas, 'Gas was consumed by arr_sum'); - assert_eq(@res, @12, 'Wrong array sum.'); + assert_ge!(core::testing::get_available_gas(), available_gas, "Gas was consumed by arr_sum."); + assert_eq!(res, 12); } diff --git a/crates/cairo-lang-lowering/src/optimizations/gas_redeposit.rs b/crates/cairo-lang-lowering/src/optimizations/gas_redeposit.rs new file mode 100644 index 00000000000..b94c220370c --- /dev/null +++ b/crates/cairo-lang-lowering/src/optimizations/gas_redeposit.rs @@ -0,0 +1,94 @@ +#[cfg(test)] +#[path = "gas_redeposit_test.rs"] +mod test; + +use cairo_lang_filesystem::flag::Flag; +use cairo_lang_filesystem::ids::FlagId; +use cairo_lang_semantic::corelib; +use itertools::Itertools; + +use crate::db::LoweringGroup; +use crate::ids::{ConcreteFunctionWithBodyId, SemanticFunctionIdEx}; +use crate::implicits::FunctionImplicitsTrait; +use crate::{BlockId, FlatBlockEnd, FlatLowered, Statement, StatementCall}; + +/// Adds redeposit gas actions. +/// +/// The algorithm is as follows: +/// Check if the function will have the `GasBuiltin` implicit after the lower_implicits stage. +/// If so, add a `redeposit_gas` call at the beginning of every branch in the code. +/// Otherwise, do nothing. +/// +/// Note that for implementation simplicity this stage must be applied before `LowerImplicits` +/// stage. +pub fn gas_redeposit( + db: &dyn LoweringGroup, + function_id: ConcreteFunctionWithBodyId, + lowered: &mut FlatLowered, +) { + if lowered.blocks.is_empty() { + return; + } + if !matches!( + db.get_flag(FlagId::new(db.upcast(), "add_redeposit_gas")), + Some(flag) if matches!(*flag, Flag::AddRedepositGas(true)) + ) { + return; + } + let gb_ty = corelib::get_core_ty_by_name(db.upcast(), "GasBuiltin".into(), vec![]); + // Checking if the implicits of this function past lowering includes `GasBuiltin`. + if let Ok(implicits) = db.function_with_body_implicits(function_id) { + if !implicits.into_iter().contains(&gb_ty) { + return; + } + } + assert!( + lowered.parameters.iter().all(|p| lowered.variables[*p].ty != gb_ty), + "`GasRedeposit` stage must be called before `LowerImplicits` stage" + ); + + let redeposit_gas = corelib::get_function_id( + db.upcast(), + corelib::core_submodule(db.upcast(), "gas"), + "redeposit_gas".into(), + vec![], + ) + .lowered(db); + let mut stack = vec![BlockId::root()]; + let mut visited = vec![false; lowered.blocks.len()]; + let mut redeposit_commands = vec![]; + while let Some(block_id) = stack.pop() { + if visited[block_id.0] { + continue; + } + visited[block_id.0] = true; + let block = &lowered.blocks[block_id]; + match &block.end { + FlatBlockEnd::Goto(block_id, _) => { + stack.push(*block_id); + } + FlatBlockEnd::Match { info } => { + let location = info.location().with_auto_generation_note(db, "withdraw_gas"); + for arm in info.arms() { + stack.push(arm.block_id); + redeposit_commands.push((arm.block_id, location)); + } + } + &FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(_) => {} + FlatBlockEnd::NotSet => unreachable!("Block end not set"), + } + } + for (block_id, location) in redeposit_commands { + let block = &mut lowered.blocks[block_id]; + block.statements.insert( + 0, + Statement::Call(StatementCall { + function: redeposit_gas, + inputs: vec![], + with_coupon: false, + outputs: vec![], + location, + }), + ); + } +} diff --git a/crates/cairo-lang-lowering/src/optimizations/gas_redeposit_test.rs b/crates/cairo-lang-lowering/src/optimizations/gas_redeposit_test.rs new file mode 100644 index 00000000000..e6e07d9bf83 --- /dev/null +++ b/crates/cairo-lang-lowering/src/optimizations/gas_redeposit_test.rs @@ -0,0 +1,77 @@ +use std::ops::Deref; +use std::sync::Arc; + +use cairo_lang_debug::DebugWithDb; +use cairo_lang_filesystem::db::FilesGroupEx; +use cairo_lang_filesystem::flag::Flag; +use cairo_lang_filesystem::ids::FlagId; +use cairo_lang_semantic::test_utils::setup_test_function; +use cairo_lang_test_utils::parse_test_file::{TestFileRunner, TestRunnerResult}; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; + +use super::gas_redeposit; +use crate::db::LoweringGroup; +use crate::fmt::LoweredFormatter; +use crate::ids::ConcreteFunctionWithBodyId; +use crate::test_utils::LoweringDatabaseForTesting; + +cairo_lang_test_utils::test_file_test_with_runner!( + gas_redeposit, + "src/optimizations/test_data", + { + gas_redeposit: "gas_redeposit", + }, + GetRedepositTestRunner +); + +struct GetRedepositTestRunner { + db: LoweringDatabaseForTesting, +} +impl Default for GetRedepositTestRunner { + fn default() -> Self { + let mut db = LoweringDatabaseForTesting::new(); + let flag = FlagId::new(&db, "add_redeposit_gas"); + db.set_flag(flag, Some(Arc::new(Flag::AddRedepositGas(true)))); + Self { db } + } +} + +impl TestFileRunner for GetRedepositTestRunner { + fn run( + &mut self, + inputs: &OrderedHashMap<String, String>, + _args: &OrderedHashMap<String, String>, + ) -> TestRunnerResult { + let db = &self.db; + let (test_function, semantic_diagnostics) = setup_test_function( + db, + inputs["function"].as_str(), + inputs["function_name"].as_str(), + inputs["module_code"].as_str(), + ) + .split(); + let function_id = + ConcreteFunctionWithBodyId::from_semantic(db, test_function.concrete_function_id); + + let before = + db.concrete_function_with_body_postpanic_lowered(function_id).unwrap().deref().clone(); + + let lowering_diagnostics = db.module_lowering_diagnostics(test_function.module_id).unwrap(); + + let mut after = before.clone(); + gas_redeposit(db, function_id, &mut after); + + TestRunnerResult::success(OrderedHashMap::from([ + ("semantic_diagnostics".into(), semantic_diagnostics), + ( + "before".into(), + format!("{:?}", before.debug(&LoweredFormatter::new(db, &before.variables))), + ), + ( + "after".into(), + format!("{:?}", after.debug(&LoweredFormatter::new(db, &after.variables))), + ), + ("lowering_diagnostics".into(), lowering_diagnostics.format(db)), + ])) + } +} diff --git a/crates/cairo-lang-lowering/src/optimizations/mod.rs b/crates/cairo-lang-lowering/src/optimizations/mod.rs index 59de4d2cc53..607a2ad2510 100644 --- a/crates/cairo-lang-lowering/src/optimizations/mod.rs +++ b/crates/cairo-lang-lowering/src/optimizations/mod.rs @@ -2,6 +2,7 @@ pub mod branch_inversion; pub mod cancel_ops; pub mod config; pub mod const_folding; +pub mod gas_redeposit; pub mod match_optimizer; pub mod remappings; pub mod reorder_statements; diff --git a/crates/cairo-lang-lowering/src/optimizations/strategy.rs b/crates/cairo-lang-lowering/src/optimizations/strategy.rs index 78452ce5f17..67e71703b31 100644 --- a/crates/cairo-lang-lowering/src/optimizations/strategy.rs +++ b/crates/cairo-lang-lowering/src/optimizations/strategy.rs @@ -1,6 +1,7 @@ use cairo_lang_diagnostics::Maybe; use cairo_lang_utils::{define_short_id, Intern, LookupIntern}; +use super::gas_redeposit::gas_redeposit; use crate::db::LoweringGroup; use crate::ids::ConcreteFunctionWithBodyId; use crate::implicits::lower_implicits; @@ -29,6 +30,7 @@ pub enum OptimizationPhase { ReorganizeBlocks, ReturnOptimization, SplitStructs, + GasRedeposit, /// The following is not really an optimization but we want to apply optimizations before and /// after it, so it is convenient to treat it as an optimization. LowerImplicits, @@ -56,6 +58,7 @@ impl OptimizationPhase { OptimizationPhase::ReturnOptimization => return_optimization(db, lowered), OptimizationPhase::SplitStructs => split_structs(lowered), OptimizationPhase::LowerImplicits => lower_implicits(db, function, lowered), + OptimizationPhase::GasRedeposit => gas_redeposit(db, function, lowered), } Ok(()) } @@ -120,6 +123,7 @@ pub fn baseline_optimization_strategy(db: &dyn LoweringGroup) -> OptimizationStr /// Query implementation of [crate::db::LoweringGroup::final_optimization_strategy]. pub fn final_optimization_strategy(db: &dyn LoweringGroup) -> OptimizationStrategyId { OptimizationStrategy(vec![ + OptimizationPhase::GasRedeposit, OptimizationPhase::LowerImplicits, OptimizationPhase::ReorganizeBlocks, OptimizationPhase::CancelOps, diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/gas_redeposit b/crates/cairo-lang-lowering/src/optimizations/test_data/gas_redeposit new file mode 100644 index 00000000000..b1d8231ce48 --- /dev/null +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/gas_redeposit @@ -0,0 +1,191 @@ +//! > Test gas redeposit skip. + +//! > test_runner_name +GetRedepositTestRunner + +//! > function +fn foo(x: felt252) -> felt252 { + if x == 0 { + heavy_op1() + } else { + heavy_op2() + } +} + +//! > function_name +foo + +//! > module_code +#[inline(never)] +fn heavy_op1() -> felt252 { + 0 +} + +#[inline(never)] +fn heavy_op2() -> felt252 { + 1 +} + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > before +Parameters: v0: core::felt252 +blk0 (root): +Statements: + (v1: core::felt252, v2: @core::felt252) <- snapshot(v0) + (v3: core::felt252) <- 0 + (v4: core::felt252, v5: @core::felt252) <- snapshot(v3) + (v6: core::bool) <- core::Felt252PartialEq::eq(v2, v5) +End: + Match(match_enum(v6) { + bool::False(v9) => blk2, + bool::True(v7) => blk1, + }) + +blk1: +Statements: + (v8: core::felt252) <- test::heavy_op1() +End: + Goto(blk3, {v8 -> v11}) + +blk2: +Statements: + (v10: core::felt252) <- test::heavy_op2() +End: + Goto(blk3, {v10 -> v11}) + +blk3: +Statements: +End: + Return(v11) + +//! > after +Parameters: v0: core::felt252 +blk0 (root): +Statements: + (v1: core::felt252, v2: @core::felt252) <- snapshot(v0) + (v3: core::felt252) <- 0 + (v4: core::felt252, v5: @core::felt252) <- snapshot(v3) + (v6: core::bool) <- core::Felt252PartialEq::eq(v2, v5) +End: + Match(match_enum(v6) { + bool::False(v9) => blk2, + bool::True(v7) => blk1, + }) + +blk1: +Statements: + (v8: core::felt252) <- test::heavy_op1() +End: + Goto(blk3, {v8 -> v11}) + +blk2: +Statements: + (v10: core::felt252) <- test::heavy_op2() +End: + Goto(blk3, {v10 -> v11}) + +blk3: +Statements: +End: + Return(v11) + +//! > ========================================================================== + +//! > Test gas redeposit. + +//! > test_runner_name +GetRedepositTestRunner + +//! > function +fn foo(x: felt252) -> felt252 { + if x == 0 { + heavy_op1() + } else { + heavy_op2() + } +} + +//! > function_name +foo + +//! > module_code +#[inline(never)] +fn heavy_op1() -> felt252 implicits(GasBuiltin) { + 0 +} + +#[inline(never)] +fn heavy_op2() -> felt252 { + 1 +} + +//! > semantic_diagnostics + +//! > before +Parameters: v0: core::felt252 +blk0 (root): +Statements: + (v1: core::felt252, v2: @core::felt252) <- snapshot(v0) + (v3: core::felt252) <- 0 + (v4: core::felt252, v5: @core::felt252) <- snapshot(v3) + (v6: core::bool) <- core::Felt252PartialEq::eq(v2, v5) +End: + Match(match_enum(v6) { + bool::False(v9) => blk2, + bool::True(v7) => blk1, + }) + +blk1: +Statements: + (v8: core::felt252) <- test::heavy_op1() +End: + Goto(blk3, {v8 -> v11}) + +blk2: +Statements: + (v10: core::felt252) <- test::heavy_op2() +End: + Goto(blk3, {v10 -> v11}) + +blk3: +Statements: +End: + Return(v11) + +//! > after +Parameters: v0: core::felt252 +blk0 (root): +Statements: + (v1: core::felt252, v2: @core::felt252) <- snapshot(v0) + (v3: core::felt252) <- 0 + (v4: core::felt252, v5: @core::felt252) <- snapshot(v3) + (v6: core::bool) <- core::Felt252PartialEq::eq(v2, v5) +End: + Match(match_enum(v6) { + bool::False(v9) => blk2, + bool::True(v7) => blk1, + }) + +blk1: +Statements: + () <- core::gas::redeposit_gas() + (v8: core::felt252) <- test::heavy_op1() +End: + Goto(blk3, {v8 -> v11}) + +blk2: +Statements: + () <- core::gas::redeposit_gas() + (v10: core::felt252) <- test::heavy_op2() +End: + Goto(blk3, {v10 -> v11}) + +blk3: +Statements: +End: + Return(v11) + +//! > lowering_diagnostics diff --git a/crates/cairo-lang-lowering/src/test_utils.rs b/crates/cairo-lang-lowering/src/test_utils.rs index 147966f4465..70e0d719364 100644 --- a/crates/cairo-lang-lowering/src/test_utils.rs +++ b/crates/cairo-lang-lowering/src/test_utils.rs @@ -38,24 +38,28 @@ impl salsa::ParallelDatabase for LoweringDatabaseForTesting { } } impl LoweringDatabaseForTesting { + pub fn new() -> Self { + let mut res = LoweringDatabaseForTesting { storage: Default::default() }; + init_files_group(&mut res); + let suite = get_default_plugin_suite(); + res.set_macro_plugins(suite.plugins); + res.set_inline_macro_plugins(suite.inline_macro_plugins.into()); + res.set_analyzer_plugins(suite.analyzer_plugins); + + let corelib_path = detect_corelib().expect("Corelib not found in default location."); + init_dev_corelib(&mut res, corelib_path); + init_lowering_group(&mut res, InliningStrategy::Default); + res + } + /// Snapshots the db for read only. pub fn snapshot(&self) -> LoweringDatabaseForTesting { LoweringDatabaseForTesting { storage: self.storage.snapshot() } } } -pub static SHARED_DB: LazyLock<Mutex<LoweringDatabaseForTesting>> = LazyLock::new(|| { - let mut res = LoweringDatabaseForTesting { storage: Default::default() }; - init_files_group(&mut res); - let suite = get_default_plugin_suite(); - res.set_macro_plugins(suite.plugins); - res.set_inline_macro_plugins(suite.inline_macro_plugins.into()); - res.set_analyzer_plugins(suite.analyzer_plugins); - let corelib_path = detect_corelib().expect("Corelib not found in default location."); - init_dev_corelib(&mut res, corelib_path); - init_lowering_group(&mut res, InliningStrategy::Default); - Mutex::new(res) -}); +pub static SHARED_DB: LazyLock<Mutex<LoweringDatabaseForTesting>> = + LazyLock::new(|| Mutex::new(LoweringDatabaseForTesting::new())); impl Default for LoweringDatabaseForTesting { fn default() -> Self { SHARED_DB.lock().unwrap().snapshot() diff --git a/crates/cairo-lang-sierra-generator/src/store_variables/mod.rs b/crates/cairo-lang-sierra-generator/src/store_variables/mod.rs index 9237859b8c1..bfac359fb0c 100644 --- a/crates/cairo-lang-sierra-generator/src/store_variables/mod.rs +++ b/crates/cairo-lang-sierra-generator/src/store_variables/mod.rs @@ -6,21 +6,15 @@ mod state; #[cfg(test)] mod test; -use cairo_lang_filesystem::flag::Flag; -use cairo_lang_filesystem::ids::FlagId; -use cairo_lang_semantic::corelib; use cairo_lang_sierra as sierra; use cairo_lang_sierra::extensions::lib_func::{LibfuncSignature, ParamSignature, SierraApChange}; use cairo_lang_sierra::ids::ConcreteLibfuncId; use cairo_lang_sierra::program::{GenBranchInfo, GenBranchTarget, GenStatement}; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use cairo_lang_utils::{extract_matches, Intern, LookupIntern}; +use cairo_lang_utils::{extract_matches, LookupIntern}; use itertools::zip_eq; use sierra::extensions::function_call::{CouponCallLibfunc, FunctionCallLibfunc}; -use sierra::extensions::gas::RedepositGasLibfunc; use sierra::extensions::NamedLibfunc; -use sierra::ids::GenericLibfuncId; -use sierra::program::ConcreteLibfuncLongId; use state::{ merge_optional_states, DeferredVariableInfo, DeferredVariableKind, VarState, VariablesState, }; @@ -98,8 +92,6 @@ struct AddStoreVariableStatements<'a> { /// state is added to the map. When the label is visited, it is merged with the known variables /// state, and removed from the map. future_states: OrderedHashMap<pre_sierra::LabelId, VariablesState>, - /// The type of `GasBuiltin`. - gb_ty: sierra::ids::ConcreteTypeId, } impl<'a> AddStoreVariableStatements<'a> { /// Constructs a new [AddStoreVariableStatements] object. @@ -109,13 +101,6 @@ impl<'a> AddStoreVariableStatements<'a> { local_variables, result: Vec::new(), future_states: OrderedHashMap::default(), - gb_ty: db - .get_concrete_type_id(corelib::get_core_ty_by_name( - db.upcast(), - "GasBuiltin".into(), - vec![], - )) - .unwrap(), } } @@ -360,12 +345,6 @@ impl<'a> AddStoreVariableStatements<'a> { // Optimization: check if there is a prefix of `push_values` that is already on the stack. let prefix_size = state.known_stack.compute_on_stack_prefix_size(push_values); - if prefix_size == 0 { - // Calling piggyback libfuncs only if not possibly breaking stack. - for pv in push_values { - self.extra_store_temp_piggyback(state, &pv.ty, &pv.var); - } - } for (i, pre_sierra::PushValue { var, var_on_stack, ty, dup }) in push_values.iter().enumerate() { @@ -588,48 +567,4 @@ impl<'a> AddStoreVariableStatements<'a> { } } } - - /// Adds lightweight calls to libfuncs when performing a store_temp on a variable already stored - /// as temp. This is currently used only for adding the `redeposit_gas` libfunc. - fn extra_store_temp_piggyback( - &mut self, - state: &mut VariablesState, - ty: &sierra::ids::ConcreteTypeId, - var: &sierra::ids::VarId, - ) { - if *ty == self.gb_ty { - self.extra_store_temp_gas_builtin_piggyback(state, var); - } - } - - /// Adds a call to the re-deposit gas libfunc for the given variable if relevant. - fn extra_store_temp_gas_builtin_piggyback( - &mut self, - state: &mut VariablesState, - var: &sierra::ids::VarId, - ) { - if !matches!( - self.db.get_flag(FlagId::new(self.db.upcast(), "add_redeposit_gas")), - Some(flag) if matches!(*flag, Flag::AddRedepositGas(true)) - ) { - return; - } - let new_var_state = match state.pop_var_state(var) { - VarState::TempVar { ty } => { - let redeposit_gas = ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::new_inline(RedepositGasLibfunc::STR_ID), - generic_args: vec![], - } - .intern(self.db); - let vars = [var.clone()]; - self.result.push(simple_statement(redeposit_gas, &vars, &vars)); - VarState::Deferred { - info: DeferredVariableInfo { ty, kind: DeferredVariableKind::Generic }, - } - } - other => other, - }; - - state.variables.insert(var.clone(), new_var_state); - } } diff --git a/crates/cairo-lang-starknet/cairo_level_tests/abi_dispatchers_tests.cairo b/crates/cairo-lang-starknet/cairo_level_tests/abi_dispatchers_tests.cairo index bb38c11813f..551004db917 100644 --- a/crates/cairo-lang-starknet/cairo_level_tests/abi_dispatchers_tests.cairo +++ b/crates/cairo-lang-starknet/cairo_level_tests/abi_dispatchers_tests.cairo @@ -112,9 +112,9 @@ fn test_validate_gas_cost() { let serialization_gas_usage = post_call_building_gas - post_serialization_gas; let entry_point_gas_usage = post_serialization_gas - post_call_gas; assert!( - call_building_gas_usage == 3150 - && serialization_gas_usage == 50950 - && entry_point_gas_usage == 141400, + call_building_gas_usage == 3250 + && serialization_gas_usage == 54450 + && entry_point_gas_usage == 144700, "Unexpected gas_usage: call_building: `{call_building_gas_usage}`. serialization: `{serialization_gas_usage}`. diff --git a/crates/cairo-lang-starknet/cairo_level_tests/interoperability.cairo b/crates/cairo-lang-starknet/cairo_level_tests/interoperability.cairo index e753eecb75d..055bb18b0ad 100644 --- a/crates/cairo-lang-starknet/cairo_level_tests/interoperability.cairo +++ b/crates/cairo-lang-starknet/cairo_level_tests/interoperability.cairo @@ -92,7 +92,7 @@ fn test_flow_safe_dispatcher() { // If the test is failing do to gas usage changes, update the gas limit by taking `test_flow` test // gas usage and add about 110000. #[test] -#[available_gas(820000)] +#[available_gas(826400)] #[should_panic(expected: ('Out of gas', 'ENTRYPOINT_FAILED',))] fn test_flow_out_of_gas() { // Calling the `test_flow` test but a low gas limit.