diff --git a/src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs b/src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs new file mode 100644 index 000000000..f4320352c --- /dev/null +++ b/src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs @@ -0,0 +1,141 @@ +// ------------------------------------------------------------------------ +// Copyright 2023 The Dapr Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + +using System.Collections.Immutable; +using System.Text; +using Dapr.Actors.Generators.Extensions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; + + + +namespace Dapr.Actors.Generators; + +/// +/// Generates an extension method that can be used during dependency injection to register all actor types. +/// +[Generator] +public sealed class ActorRegistrationGenerator : IIncrementalGenerator +{ + private const string DaprActorType = "Dapr.Actors.Runtime.Actor"; + + /// + /// Initializes the generator and registers the syntax receiver. + /// + /// The to register callbacks on + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var classDeclarations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (s, _) => IsClassDeclaration(s), + transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)) + .Where(static m => m is not null); + + var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect()); + context.RegisterSourceOutput(compilationAndClasses, static (spc, source) => Execute(source.Left, source.Right, spc)); + } + + private static bool IsClassDeclaration(SyntaxNode node) => node is ClassDeclarationSyntax; + + private static INamedTypeSymbol? GetSemanticTargetForGeneration(GeneratorSyntaxContext context) + { + var classDeclaration = (ClassDeclarationSyntax)context.Node; + var model = context.SemanticModel; + + if (model.GetDeclaredSymbol(classDeclaration) is not INamedTypeSymbol classSymbol) + { + return null; + } + + var actorClass = context.SemanticModel.Compilation.GetTypeByMetadataName(DaprActorType); + return classSymbol.BaseType != null && classSymbol.BaseType.Equals(actorClass, SymbolEqualityComparer.Default) ? classSymbol : null; + } + + private static void Execute(Compilation compilation, ImmutableArray actorTypes, + SourceProductionContext context) + { + var validActorTypes = actorTypes.Where(static t => t is not null).Cast().ToList(); + var source = GenerateActorRegistrationSource(compilation, validActorTypes); + context.AddSource("ActorRegistrationExtensions.g.cs", SourceText.From(source, Encoding.UTF8)); + } + + /// + /// Generates the source code for the actor registration method. + /// + /// The current compilation context. + /// The list of actor types to register. + /// The generated source code as a string. + private static string GenerateActorRegistrationSource(Compilation compilation, IReadOnlyList actorTypes) + { +#pragma warning disable RS1035 + var registrations = string.Join(Environment.NewLine, +#pragma warning restore RS1035 + actorTypes.Select(t => $"options.Actors.RegisterActor<{t.ToDisplayString()}>();")); + + return $@" +using Microsoft.Extensions.DependencyInjection; +using Dapr.Actors.AspNetCore; +using Dapr.Actors.Runtime; +using Dapr.Actors; +using Dapr.Actors.AspNetCore; + +/// +/// Extension methods for registering Dapr actors. +/// +public static class ActorRegistrationExtensions +{{ + /// + /// Registers all discovered actor types with the Dapr actor runtime. + /// + /// The service collection to add the actors to. + /// Whether to include actor types from referenced assemblies. + public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false) + {{ + services.AddActors(options => + {{ + {registrations} + if (includeTransientReferences) + {{ + {GenerateTransientActorRegistrations(compilation)} + }} + }}); + }} +}}"; + } + + /// + /// Generates the registration code for actor types in referenced assemblies. + /// + /// The current compilation context. + /// The generated registration code as a string. + private static string GenerateTransientActorRegistrations(Compilation compilation) + { + var actorRegistrations = new List(); + + foreach (var reference in compilation.References) + { + if (compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol referencedCompilation) + { + actorRegistrations.AddRange(from type in referencedCompilation.GlobalNamespace.GetNamespaceTypes() + where type.BaseType?.ToDisplayString() == DaprActorType + select $"options.Actors.RegisterActor<{type.ToDisplayString()}>();"); + } + } + +#pragma warning disable RS1035 + return string.Join(Environment.NewLine, actorRegistrations); +#pragma warning restore RS1035 + } +} + diff --git a/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs b/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs index 6b45e86f3..7ea5c65f9 100644 --- a/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs +++ b/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs @@ -1,34 +1,33 @@ -namespace Dapr.Actors.Generators.Extensions +namespace Dapr.Actors.Generators.Extensions; + +internal static class IEnumerableExtensions { - internal static class IEnumerableExtensions + /// + /// Returns the index of the first item in the sequence that satisfies the predicate. If no item satisfies the predicate, -1 is returned. + /// + /// The type of objects in the . + /// in which to search. + /// Function performed to check whether an item satisfies the condition. + /// Return the zero-based index of the first occurrence of an element that satisfies the condition, if found; otherwise, -1. + internal static int IndexOf(this IEnumerable source, Func predicate) { - /// - /// Returns the index of the first item in the sequence that satisfies the predicate. If no item satisfies the predicate, -1 is returned. - /// - /// The type of objects in the . - /// in which to search. - /// Function performed to check whether an item satisfies the condition. - /// Return the zero-based index of the first occurrence of an element that satisfies the condition, if found; otherwise, -1. - internal static int IndexOf(this IEnumerable source, Func predicate) + if (predicate is null) { - if (predicate is null) - { - throw new ArgumentNullException(nameof(predicate)); - } + throw new ArgumentNullException(nameof(predicate)); + } - int index = 0; + int index = 0; - foreach (var item in source) + foreach (var item in source) + { + if (predicate(item)) { - if (predicate(item)) - { - return index; - } - - index++; + return index; } - return -1; + index++; } + + return -1; } } diff --git a/src/Dapr.Actors.Generators/Extensions/INamespaceSymbolExtensions.cs b/src/Dapr.Actors.Generators/Extensions/INamespaceSymbolExtensions.cs new file mode 100644 index 000000000..b094ff714 --- /dev/null +++ b/src/Dapr.Actors.Generators/Extensions/INamespaceSymbolExtensions.cs @@ -0,0 +1,33 @@ +using Microsoft.CodeAnalysis; + +namespace Dapr.Actors.Generators.Extensions; + +internal static class INamespaceSymbolExtensions +{ + /// + /// Recursively gets all the types in a namespace. + /// + /// The namespace symbol to search. + /// A collection of the named type symbols. + public static IEnumerable GetNamespaceTypes(this INamespaceSymbol namespaceSymbol) + { + foreach (var member in namespaceSymbol.GetMembers()) + { + switch (member) + { + case INamespaceSymbol nestedNamespace: + { + foreach (var nestedType in nestedNamespace.GetNamespaceTypes()) + { + yield return nestedType; + } + + break; + } + case INamedTypeSymbol namedType: + yield return namedType; + break; + } + } + } +} diff --git a/test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs b/test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs new file mode 100644 index 000000000..4ee4778e5 --- /dev/null +++ b/test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs @@ -0,0 +1,156 @@ +using System.Text; +using Dapr.Actors.Runtime; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Extensions.DependencyInjection; + +namespace Dapr.Actors.Generators.Test; + +public class ActorRegistrationGeneratorTests +{ + [Fact] + public void TestActorRegistrationGenerator_WithoutTransientReference() + { + const string source = @" +using Dapr.Actors.Runtime; + +public class MyActor : Actor, IMyActor +{ + public MyActor(ActorHost host) : base(host) { } +} + +public interface IMyActor : IActor +{ +} +"; + + const string expectedGeneratedCode = @" +using Microsoft.Extensions.DependencyInjection; +using Dapr.Actors.Runtime; + +/// +/// Extension methods for registering Dapr actors. +/// +public static class ActorRegistrationExtensions +{ + /// + /// Registers all discovered actor types with the Dapr actor runtime. + /// + /// The service collection to add the actors to. + /// Whether to include actor types from referenced assemblies. + public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false) + { + services.AddActors(options => + { + options.Actors.RegisterActor(); + if (includeTransientReferences) + { + + } + }); + } +}"; + + var generatedCode = GetGeneratedCode(source); + Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim()); + } + + [Fact] + public void TestActorRegistrationGenerator_WithTransientReference() + { + const string source = @" +using Dapr.Actors.Runtime; + +public class MyActor : Actor, IMyActor +{ + public MyActor(ActorHost host) : base(host) { } +} + +public interface IMyActor : IActor +{ +} +"; + + const string referencedSource = @" +using Dapr.Actors.Runtime; + +public class TransientActor : Actor, ITransientActor +{ + public TransientActor(ActorHost host) : base(host) { } +} + +public interface ITransientActor : IActor +{ +} +"; + + const string expectedGeneratedCode = @" +using Microsoft.Extensions.DependencyInjection; +using Dapr.Actors.Runtime; + +/// +/// Extension methods for registering Dapr actors. +/// +public static class ActorRegistrationExtensions +{ + /// + /// Registers all discovered actor types with the Dapr actor runtime. + /// + /// The service collection to add the actors to. + /// Whether to include actor types from referenced assemblies. + public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false) + { + services.AddActors(options => + { + options.Actors.RegisterActor(); + if (includeTransientReferences) + { + options.Actors.RegisterActor(); + } + }); + } +}"; + + var generatedCode = GetGeneratedCode(source, referencedSource); + Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim()); + } + + private static string GetGeneratedCode(string source, string? referencedSource = null) + { + var syntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(source, Encoding.UTF8)); + var references = new List + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Actor).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location) + }; + + var compilation = CSharpCompilation.Create("TestCompilation", + new[] { syntaxTree }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + if (referencedSource != null) + { + var referencedSyntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(referencedSource, Encoding.UTF8)); + var referencedCompilation = CSharpCompilation.Create("ReferencedCompilation", + new[] { referencedSyntaxTree }, + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + compilation = compilation.AddReferences(referencedCompilation.ToMetadataReference()); + } + + var generator = new ActorRegistrationGenerator(); + var driver = CSharpGeneratorDriver.Create(generator); + driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics); + + var generatedTrees = outputCompilation.SyntaxTrees.Skip(1).ToList(); + Assert.Single(generatedTrees); + + var generatedCode = generatedTrees[0].ToString(); + return generatedCode; + } +} diff --git a/test/Dapr.Actors.Generators.Test/Extensions/INamespaceExtensionsTests.cs b/test/Dapr.Actors.Generators.Test/Extensions/INamespaceExtensionsTests.cs new file mode 100644 index 000000000..7b1b5b7b3 --- /dev/null +++ b/test/Dapr.Actors.Generators.Test/Extensions/INamespaceExtensionsTests.cs @@ -0,0 +1,40 @@ +using Dapr.Actors.Generators.Extensions; +using Microsoft.CodeAnalysis.CSharp; + +namespace Dapr.Actors.Generators.Test.Extensions; + +public class INamespaceExtensionsTests +{ + [Fact] + public void GetNamespaceTypes_ReturnsAllTypesInNamespace() + { + // Arrange + const string source = @" +namespace TestNamespace +{ + public class ClassA { } + public class ClassB { } + + namespace NestedNamespace + { + public class ClassC { } + } +}"; + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var compilation = CSharpCompilation.Create("TestCompilation", new[] { syntaxTree }); + var namespaceSymbol = compilation.GlobalNamespace.GetNamespaceMembers().FirstOrDefault(n => n.Name == "TestNamespace"); + + // Act + if (namespaceSymbol != null) + { + var types = namespaceSymbol.GetNamespaceTypes().ToList(); + + // Assert + Assert.NotNull(namespaceSymbol); + Assert.Equal(3, types.Count); + Assert.Contains(types, t => t.Name == "ClassA"); + Assert.Contains(types, t => t.Name == "ClassB"); + Assert.Contains(types, t => t.Name == "ClassC"); + } + } +}