diff --git a/ArchUnitNET/Loader/ArchLoader.cs b/ArchUnitNET/Loader/ArchLoader.cs index 93afeb378..1ba93d1df 100644 --- a/ArchUnitNET/Loader/ArchLoader.cs +++ b/ArchUnitNET/Loader/ArchLoader.cs @@ -98,7 +98,7 @@ private ArchLoader LoadAssembly(string fileName, bool includeDependencies, bool return this; } - private void LoadModule(string fileName, string nameSpace, bool includeDependencies, bool recursive) + private void LoadModule(string fileName, string nameSpace, bool includeDependencies, bool recursive, FilterFunc filterFunc = null) { try { @@ -112,7 +112,7 @@ private void LoadModule(string fileName, string nameSpace, bool includeDependenc { if (includeDependencies && recursive) { - AddReferencedAssembliesRecursively(assemblyReference, processedAssemblies, resolvedModules); + AddReferencedAssembliesRecursively(assemblyReference, processedAssemblies, resolvedModules, filterFunc); } else { @@ -148,7 +148,8 @@ private void LoadModule(string fileName, string nameSpace, bool includeDependenc } private void AddReferencedAssembliesRecursively(AssemblyNameReference currentAssemblyReference, - ICollection processedAssemblies, List resolvedModules) + ICollection processedAssemblies, List resolvedModules, + FilterFunc filterFunc) { if (processedAssemblies.Contains(currentAssemblyReference)) { @@ -161,11 +162,18 @@ private void AddReferencedAssembliesRecursively(AssemblyNameReference currentAss _assemblyResolver.AddLib(currentAssemblyReference); var assemblyDefinition = _assemblyResolver.Resolve(currentAssemblyReference) ?? throw new AssemblyResolutionException(currentAssemblyReference); - _archBuilder.AddAssembly(assemblyDefinition, false); - resolvedModules.AddRange(assemblyDefinition.Modules); + + var filterResult = filterFunc?.Invoke(assemblyDefinition); + if (filterResult?.LoadThisAssembly != false) + { + _archBuilder.AddAssembly(assemblyDefinition, false); + resolvedModules.AddRange(assemblyDefinition.Modules); + } + foreach (var reference in assemblyDefinition.Modules.SelectMany(m => m.AssemblyReferences)) { - AddReferencedAssembliesRecursively(reference, processedAssemblies, resolvedModules); + if (filterResult?.TraverseDependencies != false) + AddReferencedAssembliesRecursively(reference, processedAssemblies, resolvedModules, filterFunc); } } catch (AssemblyResolutionException) @@ -173,5 +181,20 @@ private void AddReferencedAssembliesRecursively(AssemblyNameReference currentAss //Failed to resolve assembly, skip it } } + + /// + /// Loads assemblies from dependency tree with user-defined filtration logic + /// + /// Assemblies to start traversal from + /// Delegate to control loading and traversal logic + /// + public ArchLoader LoadAssembliesRecursively(IEnumerable assemblies, FilterFunc filterFunc) + { + foreach (var assembly in assemblies) + { + LoadModule(assembly.Location, null, true, true, filterFunc); + } + return this; + } } } \ No newline at end of file diff --git a/ArchUnitNET/Loader/FilterResult.cs b/ArchUnitNET/Loader/FilterResult.cs new file mode 100644 index 000000000..7ab008652 --- /dev/null +++ b/ArchUnitNET/Loader/FilterResult.cs @@ -0,0 +1,53 @@ +// Copyright 2019 Florian Gather +// Copyright 2019 Fritz Brandhuber +// Copyright 2020 Pavel Fischer +// +// SPDX-License-Identifier: Apache-2.0 +// + +using Mono.Cecil; + +namespace ArchUnitNET.Loader +{ + /// + /// Type of delegate to control assemblies loading + /// + /// Current assembly definition + public delegate FilterResult FilterFunc(AssemblyDefinition assemblyDefinition); + + /// + /// Filter function result options + /// + public struct FilterResult + { + /// + /// Load this assembly and traverse its dependencies + /// + public static FilterResult LoadAndContinue = new FilterResult(true, true); + + /// + /// Do not load this assembly, but traverse its dependencies + /// + public static FilterResult SkipAndContinue = new FilterResult(true, false); + + /// + /// Load this assembly and do not traverse its dependencies + /// + public static FilterResult LoadAndStop = new FilterResult(false, true); + + /// + /// Do not load this assembly and do not traverse its dependencies + /// + public static FilterResult DontLoadAndStop = new FilterResult(false, false); + + private FilterResult(bool traverseDependencies, bool loadThisAssembly) + { + TraverseDependencies = traverseDependencies; + LoadThisAssembly = loadThisAssembly; + } + + internal bool TraverseDependencies { get; } + + internal bool LoadThisAssembly { get; } + } +} \ No newline at end of file diff --git a/ArchUnitNETTests/Loader/ArchLoaderTests.cs b/ArchUnitNETTests/Loader/ArchLoaderTests.cs index e67b8b11f..f28196d92 100644 --- a/ArchUnitNETTests/Loader/ArchLoaderTests.cs +++ b/ArchUnitNETTests/Loader/ArchLoaderTests.cs @@ -34,5 +34,31 @@ public void LoadAssembliesIncludingRecursiveDependencies() Assert.True(archUnitNetTestArchitectureWithRecursiveDependencies.Assemblies.Count() > 100); } + + [Fact] + public void LoadAssembliesRecursivelyWithCustomFilter() + { + FilterFunc filterFunc = assembly => assembly.Name.Name.StartsWith("ArchUnit") ? FilterResult.LoadAndContinue : FilterResult.DontLoadAndStop; + var loader = new ArchLoader(); + var architecture = loader.LoadAssembliesRecursively(new[] { typeof(BaseClass).Assembly }, filterFunc).Build(); + + Assert.Equal(3, architecture.Assemblies.Count()); + } + + [Fact] + public void LoadAssembliesRecursively_NestedDependencyOnly() + { + FilterFunc filterFunc = assembly => + { + if (assembly.Name.Name == "ArchUnitNet") + return FilterResult.LoadAndStop; + + return FilterResult.SkipAndContinue; + }; + var loader = new ArchLoader(); + var architecture = loader.LoadAssembliesRecursively(new[] { typeof(BaseClass).Assembly }, filterFunc).Build(); + + Assert.Equal(1, architecture.Assemblies.Count()); + } } } \ No newline at end of file