diff --git a/scarb/src/bin/scarb/args.rs b/scarb/src/bin/scarb/args.rs index 4b007d799..40dbe944e 100644 --- a/scarb/src/bin/scarb/args.rs +++ b/scarb/src/bin/scarb/args.rs @@ -228,6 +228,14 @@ pub struct ExpandArgs { #[command(flatten)] pub features: FeaturesSpec, + /// Specify the target to expand by target kind. + #[arg(long)] + pub target_kind: Option, + + /// Specify the target to expand by target name. + #[arg(long)] + pub target_name: Option, + /// Do not attempt formatting. #[arg(long, default_value_t = false)] pub ugly: bool, diff --git a/scarb/src/bin/scarb/commands/expand.rs b/scarb/src/bin/scarb/commands/expand.rs index 67939ab71..e6c018962 100644 --- a/scarb/src/bin/scarb/commands/expand.rs +++ b/scarb/src/bin/scarb/commands/expand.rs @@ -1,7 +1,8 @@ use anyhow::Result; +use smol_str::ToSmolStr; use crate::args::ExpandArgs; -use scarb::core::Config; +use scarb::core::{Config, TargetKind}; use scarb::ops; use scarb::ops::ExpandOpts; @@ -12,6 +13,8 @@ pub fn run(args: ExpandArgs, config: &Config) -> Result<()> { let opts = ExpandOpts { features: args.features.try_into()?, ugly: args.ugly, + target_name: args.target_name.map(|n| n.to_smolstr()), + target_kind: args.target_kind.map(TargetKind::try_new).transpose()?, }; ops::expand(package, opts, &ws) } diff --git a/scarb/src/ops/compile.rs b/scarb/src/ops/compile.rs index 63252fbfe..e894f2326 100644 --- a/scarb/src/ops/compile.rs +++ b/scarb/src/ops/compile.rs @@ -17,6 +17,7 @@ use crate::core::{ FeatureName, PackageId, PackageName, TargetKind, Utf8PathWorkspaceExt, Workspace, }; use crate::ops; +use crate::ops::get_test_package_ids; #[derive(Debug, Clone)] pub enum FeaturesSelector { @@ -81,26 +82,7 @@ where let resolve = ops::resolve_workspace(ws)?; // Add test compilation units to build - let packages = packages - .into_iter() - .flat_map(|package_id| { - let package = ws.members().find(|p| p.id == package_id).unwrap(); - let mut result: Vec = package - .manifest - .targets - .iter() - .filter(|t| t.is_test()) - .map(|t| { - package - .id - .for_test_target(t.group_id.clone().unwrap_or(t.name.clone())) - }) - .collect(); - result.push(package_id); - result - }) - .collect::>(); - + let packages = get_test_package_ids(packages, ws); let compilation_units = ops::generate_compilation_units(&resolve, &opts.features, ws)? .into_iter() .filter(|cu| { diff --git a/scarb/src/ops/expand.rs b/scarb/src/ops/expand.rs index a15226f51..38e895816 100644 --- a/scarb/src/ops/expand.rs +++ b/scarb/src/ops/expand.rs @@ -3,9 +3,8 @@ use crate::compiler::helpers::{build_compiler_config, write_string}; use crate::compiler::{CairoCompilationUnit, CompilationUnit, CompilationUnitAttributes}; use crate::core::{Package, TargetKind, Workspace}; use crate::ops; -use crate::ops::FeaturesOpts; +use crate::ops::{get_test_package_ids, FeaturesOpts}; use anyhow::{anyhow, bail, Context, Result}; -use cairo_lang_compiler::db::RootDatabase; use cairo_lang_compiler::diagnostics::DiagnosticsError; use cairo_lang_defs::db::DefsGroup; use cairo_lang_defs::ids::{LanguageElementId, ModuleId, ModuleItemId}; @@ -18,11 +17,14 @@ use cairo_lang_parser::db::ParserGroup; use cairo_lang_syntax::node::helpers::UsePathEx; use cairo_lang_syntax::node::{ast, TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::Upcast; +use smol_str::SmolStr; use std::collections::HashSet; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ExpandOpts { pub features: FeaturesOpts, + pub target_kind: Option, + pub target_name: Option, pub ugly: bool, } @@ -38,30 +40,49 @@ pub fn expand(package: Package, opts: ExpandOpts, ws: &Workspace<'_>) -> Result< .map(|unit| ops::compile::compile_unit(unit.clone(), ws)) .collect::>>()?; - let Some(compilation_unit) = compilation_units.into_iter().find(|unit| { - unit.main_package_id() == package.id - && unit.main_component().target_kind() == TargetKind::LIB - }) else { - bail!("compilation unit not found for `{package_name}`") - }; - let CompilationUnit::Cairo(compilation_unit) = compilation_unit else { - bail!("only cairo compilation units can be expanded") - }; - let ScarbDatabase { db, .. } = build_scarb_root_database(&compilation_unit, ws)?; - let mut compiler_config = build_compiler_config(&compilation_unit, ws); - compiler_config - .diagnostics_reporter - .ensure(&db) - .map_err(|err| err.into()) - .map_err(|err| { - if !suppress_error(&err) { - ws.config().ui().anyhow(&err); - } + let compilation_units = compilation_units + .into_iter() + // We rewrite group compilation units to single source paths ones. We value simplicity over + // performance here, as expand output will be read by people rather than tooling. + .flat_map(|unit| match unit { + CompilationUnit::Cairo(unit) => unit + .rewrite_to_single_source_paths() + .into_iter() + .map(CompilationUnit::Cairo) + .collect::>(), + // We include non-cairo compilation units here, so we can show better error msg later. + _ => vec![unit], + }) + .filter(|unit| { + let target_kind = if opts.target_name.is_none() && opts.target_kind.is_none() { + // If no target specifier is used - default to lib. + Some(TargetKind::LIB) + } else { + opts.target_kind.clone() + }; + // Includes test package ids. + get_test_package_ids(vec![package.id], ws).contains(&unit.main_package_id()) + // We can use main_component below, as targets are not grouped. + && target_kind.as_ref() + .map_or(true, |kind| unit.main_component().target_kind() == *kind) + && opts + .target_name + .as_ref() + .map_or(true, |name| unit.main_component().first_target().name == *name) + }) + .map(|unit| match unit { + CompilationUnit::Cairo(unit) => Ok(unit), + _ => bail!("only cairo compilation units can be expanded"), + }) + .collect::>>()?; - anyhow!("could not check `{package_name}` due to previous error") - })?; + if compilation_units.is_empty() { + bail!("no compilation units found for `{package_name}`") + } - do_expand(&db, &compilation_unit, opts, ws)?; + for compilation_unit in compilation_units { + do_expand(&compilation_unit, opts.clone(), ws)?; + } Ok(()) } @@ -123,11 +144,23 @@ impl ModuleStack { } fn do_expand( - db: &RootDatabase, compilation_unit: &CairoCompilationUnit, opts: ExpandOpts, ws: &Workspace<'_>, ) -> Result<()> { + let ScarbDatabase { db, .. } = build_scarb_root_database(compilation_unit, ws)?; + let mut compiler_config = build_compiler_config(compilation_unit, ws); + compiler_config + .diagnostics_reporter + .ensure(&db) + .map_err(|err| err.into()) + .map_err(|err| { + if !suppress_error(&err) { + ws.config().ui().anyhow(&err); + } + + anyhow!("could not check due to previous error") + })?; let main_crate_id = db.intern_crate(CrateLongId::Real( compilation_unit.main_component().cairo_package_name(), )); @@ -142,13 +175,13 @@ fn do_expand( .context("failed to retrieve module main file syntax")?; let crate_modules = db.crate_modules(main_crate_id); - let item_asts = file_syntax.items(db); + let item_asts = file_syntax.items(&db); - let mut builder = PatchBuilder::new(db, &item_asts); + let mut builder = PatchBuilder::new(&db, &item_asts); let mut module_stack = ModuleStack::new(); for module_id in crate_modules.iter() { - builder.add_str(module_stack.register(module_id.full_path(db)).as_str()); + builder.add_str(module_stack.register(module_id.full_path(&db)).as_str()); let Some(module_items) = db.module_items(*module_id).to_option() else { continue; }; @@ -156,7 +189,7 @@ fn do_expand( for item_id in module_items.iter() { // We need to handle uses manually, as module data only includes use leaf instead of path. if let ModuleItemId::Use(use_id) = item_id { - let use_item = use_id.stable_ptr(db).lookup(db.upcast()); + let use_item = use_id.stable_ptr(&db).lookup(db.upcast()); let item = ast::UsePath::Leaf(use_item.clone()).get_item(db.upcast()); let item = item.use_path(db.upcast()); // We need to deduplicate multi-uses (`a::{b, c}`), which are split into multiple leaves. @@ -172,7 +205,7 @@ fn do_expand( if let ModuleItemId::Submodule(_) = item_id { continue; } - let node = item_id.stable_location(db).syntax_node(db); + let node = item_id.stable_location(&db).syntax_node(&db); builder.add_node(node); } } diff --git a/scarb/src/ops/resolve.rs b/scarb/src/ops/resolve.rs index cb64f8dcc..e780a3174 100644 --- a/scarb/src/ops/resolve.rs +++ b/scarb/src/ops/resolve.rs @@ -571,3 +571,29 @@ fn generate_cairo_plugin_compilation_units(member: &Package) -> Result, ws: &Workspace<'_>) -> Vec { + packages + .into_iter() + .flat_map(|package_id| { + let Some(package) = ws.members().find(|p| p.id == package_id) else { + return Vec::new(); + }; + let mut result: Vec = package + .manifest + .targets + .iter() + .filter(|t| t.is_test()) + .map(|t| { + package + .id + .for_test_target(t.group_id.clone().unwrap_or(t.name.clone())) + }) + .collect(); + result.push(package_id); + result + }) + .collect::>() +}