From df927615252c445ea55caf72db7f7215f1e02936 Mon Sep 17 00:00:00 2001 From: maciektr Date: Fri, 4 Oct 2024 09:21:20 -0700 Subject: [PATCH] Make `compile_test_prepared_db` API more configurable (#6449) --- Cargo.lock | 1 + crates/bin/cairo-run/Cargo.toml | 1 + crates/bin/cairo-run/src/main.rs | 6 ++- crates/cairo-lang-starknet/src/contract.rs | 4 +- crates/cairo-lang-test-plugin/src/lib.rs | 45 +++++++++++++++++++--- crates/cairo-lang-test-runner/src/lib.rs | 4 +- crates/cairo-lang-test-runner/src/test.rs | 3 ++ 7 files changed, 53 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index acc32481df0..c9545c19d88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1160,6 +1160,7 @@ dependencies = [ "cairo-lang-runner", "cairo-lang-sierra-generator", "cairo-lang-starknet", + "cairo-lang-utils", "clap", ] diff --git a/crates/bin/cairo-run/Cargo.toml b/crates/bin/cairo-run/Cargo.toml index 141776634e6..56da9e96316 100644 --- a/crates/bin/cairo-run/Cargo.toml +++ b/crates/bin/cairo-run/Cargo.toml @@ -15,3 +15,4 @@ cairo-lang-diagnostics = { path = "../../cairo-lang-diagnostics", version = "~2. cairo-lang-runner = { path = "../../cairo-lang-runner", version = "~2.8.2" } cairo-lang-sierra-generator = { path = "../../cairo-lang-sierra-generator", version = "~2.8.2" } cairo-lang-starknet = { path = "../../cairo-lang-starknet", version = "~2.8.2" } +cairo-lang-utils = { path = "../../cairo-lang-utils", version = "~2.8.2" } diff --git a/crates/bin/cairo-run/src/main.rs b/crates/bin/cairo-run/src/main.rs index ea9034449a4..82535555d0f 100644 --- a/crates/bin/cairo-run/src/main.rs +++ b/crates/bin/cairo-run/src/main.rs @@ -14,7 +14,8 @@ use cairo_lang_runner::{ProfilingInfoCollectionConfig, SierraCasmRunner, Starkne use cairo_lang_sierra_generator::db::SierraGenGroup; use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug; use cairo_lang_sierra_generator::replace_ids::{DebugReplacer, SierraIdReplacer}; -use cairo_lang_starknet::contract::get_contracts_info; +use cairo_lang_starknet::contract::{find_contracts, get_contracts_info}; +use cairo_lang_utils::Upcast; use clap::Parser; /// Compiles a Cairo project and runs the function `main`. @@ -75,7 +76,8 @@ fn main() -> anyhow::Result<()> { anyhow::bail!("Program requires gas counter, please provide `--available-gas` argument."); } - let contracts_info = get_contracts_info(db, main_crate_ids, &replacer)?; + let contracts = find_contracts((*db).upcast(), &main_crate_ids); + let contracts_info = get_contracts_info(db, contracts, &replacer)?; let sierra_program = replacer.apply(&sierra_program); let runner = SierraCasmRunner::new( diff --git a/crates/cairo-lang-starknet/src/contract.rs b/crates/cairo-lang-starknet/src/contract.rs index 1c9827c66bf..3117a311b25 100644 --- a/crates/cairo-lang-starknet/src/contract.rs +++ b/crates/cairo-lang-starknet/src/contract.rs @@ -41,6 +41,7 @@ use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG}; mod test; /// Represents a declaration of a contract. +#[derive(Clone)] pub struct ContractDeclaration { /// The id of the module that defines the contract. pub submodule_id: SubmoduleId, @@ -293,10 +294,9 @@ pub struct ContractInfo { /// Returns the list of functions in a given module. pub fn get_contracts_info( db: &dyn SierraGenGroup, - main_crate_ids: Vec, + contracts: Vec, replacer: &T, ) -> Result, anyhow::Error> { - let contracts = find_contracts(db.upcast(), &main_crate_ids); let mut contracts_info = OrderedHashMap::default(); for contract in contracts { let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?; diff --git a/crates/cairo-lang-test-plugin/src/lib.rs b/crates/cairo-lang-test-plugin/src/lib.rs index e2af0128af0..877bded22f1 100644 --- a/crates/cairo-lang-test-plugin/src/lib.rs +++ b/crates/cairo-lang-test-plugin/src/lib.rs @@ -1,12 +1,13 @@ use std::default::Default; use std::sync::Arc; -use anyhow::Result; +use anyhow::{ensure, Result}; use cairo_lang_compiler::db::RootDatabase; use cairo_lang_compiler::diagnostics::DiagnosticsReporter; use cairo_lang_compiler::get_sierra_program_for_functions; use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId}; +use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_filesystem::ids::CrateId; use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId; use cairo_lang_semantic::db::SemanticGroup; @@ -23,7 +24,8 @@ use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug; use cairo_lang_sierra_generator::replace_ids::DebugReplacer; use cairo_lang_sierra_generator::statements_locations::StatementsLocations; use cairo_lang_starknet::contract::{ - find_contracts, get_contract_abi_functions, get_contracts_info, ContractInfo, + find_contracts, get_contract_abi_functions, get_contracts_info, ContractDeclaration, + ContractInfo, }; use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE}; use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST; @@ -52,6 +54,19 @@ pub struct TestsCompilationConfig { /// Adds the starknet contracts to the compiled tests. pub starknet: bool, + /// Contracts to compile. + /// If defined, only this contacts will be available in tests. + /// If not, all contracts from `contract_crate_ids` will be compiled. + pub contract_declarations: Option>, + + /// Crates to be searched for contracts. + /// If not defined, all crates will be searched. + pub contract_crate_ids: Option>, + + /// Crates to be searched for executable attributes. + /// If not defined, test crates will be searched. + pub executable_crate_ids: Option>, + /// Adds mapping used by [cairo-profiler](https://github.com/software-mansion/cairo-profiler) to /// [Annotations] in [DebugInfo] in the compiled tests. pub add_statements_functions: bool, @@ -75,12 +90,27 @@ pub struct TestsCompilationConfig { pub fn compile_test_prepared_db( db: &RootDatabase, tests_compilation_config: TestsCompilationConfig, - main_crate_ids: Vec, test_crate_ids: Vec, diagnostics_reporter: DiagnosticsReporter<'_>, ) -> Result { + ensure!( + tests_compilation_config.starknet + || tests_compilation_config.contract_declarations.is_none(), + "Contract declarations can be provided only when starknet is enabled." + ); + ensure!( + tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(), + "Contract crate ids can be provided only when starknet is enabled." + ); + + let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| { + find_contracts( + db, + &tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()), + ) + }); let all_entry_points = if tests_compilation_config.starknet { - find_contracts(db, &main_crate_ids) + contracts .iter() .flat_map(|contract| { chain!( @@ -96,7 +126,10 @@ pub fn compile_test_prepared_db( vec![] }; - let executable_functions = find_executable_function_ids(db, main_crate_ids.clone()); + let executable_functions = find_executable_function_ids( + db, + tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()), + ); let all_tests = find_all_tests(db, test_crate_ids.clone()); let func_ids = chain!( @@ -157,7 +190,7 @@ pub fn compile_test_prepared_db( ) }) .collect_vec(); - let contracts_info = get_contracts_info(db, main_crate_ids.clone(), &replacer)?; + let contracts_info = get_contracts_info(db, contracts, &replacer)?; let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo { executables, annotations, diff --git a/crates/cairo-lang-test-runner/src/lib.rs b/crates/cairo-lang-test-runner/src/lib.rs index 4ac720d1309..4e93155c706 100644 --- a/crates/cairo-lang-test-runner/src/lib.rs +++ b/crates/cairo-lang-test-runner/src/lib.rs @@ -71,6 +71,9 @@ impl TestRunner { starknet, add_statements_functions: config.run_profiler == RunProfilerConfig::Cairo, add_statements_code_locations: false, + contract_declarations: None, + contract_crate_ids: None, + executable_crate_ids: None, }, )?; Ok(Self { compiler, config }) @@ -262,7 +265,6 @@ impl TestCompiler { compile_test_prepared_db( &self.db, self.config.clone(), - self.main_crate_ids.clone(), self.test_crate_ids.clone(), diag_reporter, ) diff --git a/crates/cairo-lang-test-runner/src/test.rs b/crates/cairo-lang-test-runner/src/test.rs index 9c38228b8d7..c2a2de11da7 100644 --- a/crates/cairo-lang-test-runner/src/test.rs +++ b/crates/cairo-lang-test-runner/src/test.rs @@ -20,6 +20,9 @@ fn test_compiled_serialization() { starknet: true, add_statements_functions: false, add_statements_code_locations: false, + contract_declarations: None, + contract_crate_ids: None, + executable_crate_ids: None, }, ) .unwrap();