Skip to content

Commit

Permalink
Make compile_test_prepared_db API more configurable (#6449)
Browse files Browse the repository at this point in the history
  • Loading branch information
maciektr authored Oct 4, 2024
1 parent d7813fb commit df92761
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 11 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/bin/cairo-run/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ cairo-lang-diagnostics = { path = "../../cairo-lang-diagnostics", version = "~2.
cairo-lang-runner = { path = "../../cairo-lang-runner", version = "~2.8.2" }
cairo-lang-sierra-generator = { path = "../../cairo-lang-sierra-generator", version = "~2.8.2" }
cairo-lang-starknet = { path = "../../cairo-lang-starknet", version = "~2.8.2" }
cairo-lang-utils = { path = "../../cairo-lang-utils", version = "~2.8.2" }
6 changes: 4 additions & 2 deletions crates/bin/cairo-run/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use cairo_lang_runner::{ProfilingInfoCollectionConfig, SierraCasmRunner, Starkne
use cairo_lang_sierra_generator::db::SierraGenGroup;
use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
use cairo_lang_sierra_generator::replace_ids::{DebugReplacer, SierraIdReplacer};
use cairo_lang_starknet::contract::get_contracts_info;
use cairo_lang_starknet::contract::{find_contracts, get_contracts_info};
use cairo_lang_utils::Upcast;
use clap::Parser;

/// Compiles a Cairo project and runs the function `main`.
Expand Down Expand Up @@ -75,7 +76,8 @@ fn main() -> anyhow::Result<()> {
anyhow::bail!("Program requires gas counter, please provide `--available-gas` argument.");
}

let contracts_info = get_contracts_info(db, main_crate_ids, &replacer)?;
let contracts = find_contracts((*db).upcast(), &main_crate_ids);
let contracts_info = get_contracts_info(db, contracts, &replacer)?;
let sierra_program = replacer.apply(&sierra_program);

let runner = SierraCasmRunner::new(
Expand Down
4 changes: 2 additions & 2 deletions crates/cairo-lang-starknet/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use crate::plugin::consts::{ABI_ATTR, ABI_ATTR_EMBED_V0_ARG};
mod test;

/// Represents a declaration of a contract.
#[derive(Clone)]
pub struct ContractDeclaration {
/// The id of the module that defines the contract.
pub submodule_id: SubmoduleId,
Expand Down Expand Up @@ -293,10 +294,9 @@ pub struct ContractInfo {
/// Returns the list of functions in a given module.
pub fn get_contracts_info<T: SierraIdReplacer>(
db: &dyn SierraGenGroup,
main_crate_ids: Vec<CrateId>,
contracts: Vec<ContractDeclaration>,
replacer: &T,
) -> Result<OrderedHashMap<Felt252, ContractInfo>, anyhow::Error> {
let contracts = find_contracts(db.upcast(), &main_crate_ids);
let mut contracts_info = OrderedHashMap::default();
for contract in contracts {
let (class_hash, contract_info) = analyze_contract(db, &contract, replacer)?;
Expand Down
45 changes: 39 additions & 6 deletions crates/cairo-lang-test-plugin/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::default::Default;
use std::sync::Arc;

use anyhow::Result;
use anyhow::{ensure, Result};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
use cairo_lang_compiler::get_sierra_program_for_functions;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
use cairo_lang_filesystem::db::FilesGroup;
use cairo_lang_filesystem::ids::CrateId;
use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
use cairo_lang_semantic::db::SemanticGroup;
Expand All @@ -23,7 +24,8 @@ use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
use cairo_lang_sierra_generator::replace_ids::DebugReplacer;
use cairo_lang_sierra_generator::statements_locations::StatementsLocations;
use cairo_lang_starknet::contract::{
find_contracts, get_contract_abi_functions, get_contracts_info, ContractInfo,
find_contracts, get_contract_abi_functions, get_contracts_info, ContractDeclaration,
ContractInfo,
};
use cairo_lang_starknet::plugin::consts::{CONSTRUCTOR_MODULE, EXTERNAL_MODULE, L1_HANDLER_MODULE};
use cairo_lang_starknet_classes::casm_contract_class::ENTRY_POINT_COST;
Expand Down Expand Up @@ -52,6 +54,19 @@ pub struct TestsCompilationConfig {
/// Adds the starknet contracts to the compiled tests.
pub starknet: bool,

/// Contracts to compile.
/// If defined, only this contacts will be available in tests.
/// If not, all contracts from `contract_crate_ids` will be compiled.
pub contract_declarations: Option<Vec<ContractDeclaration>>,

/// Crates to be searched for contracts.
/// If not defined, all crates will be searched.
pub contract_crate_ids: Option<Vec<CrateId>>,

/// Crates to be searched for executable attributes.
/// If not defined, test crates will be searched.
pub executable_crate_ids: Option<Vec<CrateId>>,

/// Adds mapping used by [cairo-profiler](https://github.com/software-mansion/cairo-profiler) to
/// [Annotations] in [DebugInfo] in the compiled tests.
pub add_statements_functions: bool,
Expand All @@ -75,12 +90,27 @@ pub struct TestsCompilationConfig {
pub fn compile_test_prepared_db(
db: &RootDatabase,
tests_compilation_config: TestsCompilationConfig,
main_crate_ids: Vec<CrateId>,
test_crate_ids: Vec<CrateId>,
diagnostics_reporter: DiagnosticsReporter<'_>,
) -> Result<TestCompilation> {
ensure!(
tests_compilation_config.starknet
|| tests_compilation_config.contract_declarations.is_none(),
"Contract declarations can be provided only when starknet is enabled."
);
ensure!(
tests_compilation_config.starknet || tests_compilation_config.contract_crate_ids.is_none(),
"Contract crate ids can be provided only when starknet is enabled."
);

let contracts = tests_compilation_config.contract_declarations.unwrap_or_else(|| {
find_contracts(
db,
&tests_compilation_config.contract_crate_ids.unwrap_or_else(|| db.crates()),
)
});
let all_entry_points = if tests_compilation_config.starknet {
find_contracts(db, &main_crate_ids)
contracts
.iter()
.flat_map(|contract| {
chain!(
Expand All @@ -96,7 +126,10 @@ pub fn compile_test_prepared_db(
vec![]
};

let executable_functions = find_executable_function_ids(db, main_crate_ids.clone());
let executable_functions = find_executable_function_ids(
db,
tests_compilation_config.executable_crate_ids.unwrap_or_else(|| test_crate_ids.clone()),
);
let all_tests = find_all_tests(db, test_crate_ids.clone());

let func_ids = chain!(
Expand Down Expand Up @@ -157,7 +190,7 @@ pub fn compile_test_prepared_db(
)
})
.collect_vec();
let contracts_info = get_contracts_info(db, main_crate_ids.clone(), &replacer)?;
let contracts_info = get_contracts_info(db, contracts, &replacer)?;
let sierra_program = ProgramArtifact::stripped(sierra_program).with_debug_info(DebugInfo {
executables,
annotations,
Expand Down
4 changes: 3 additions & 1 deletion crates/cairo-lang-test-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ impl TestRunner {
starknet,
add_statements_functions: config.run_profiler == RunProfilerConfig::Cairo,
add_statements_code_locations: false,
contract_declarations: None,
contract_crate_ids: None,
executable_crate_ids: None,
},
)?;
Ok(Self { compiler, config })
Expand Down Expand Up @@ -262,7 +265,6 @@ impl TestCompiler {
compile_test_prepared_db(
&self.db,
self.config.clone(),
self.main_crate_ids.clone(),
self.test_crate_ids.clone(),
diag_reporter,
)
Expand Down
3 changes: 3 additions & 0 deletions crates/cairo-lang-test-runner/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ fn test_compiled_serialization() {
starknet: true,
add_statements_functions: false,
add_statements_code_locations: false,
contract_declarations: None,
contract_crate_ids: None,
executable_crate_ids: None,
},
)
.unwrap();
Expand Down

0 comments on commit df92761

Please sign in to comment.