diff --git a/src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs b/src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs
new file mode 100644
index 000000000..f4320352c
--- /dev/null
+++ b/src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs
@@ -0,0 +1,141 @@
+// ------------------------------------------------------------------------
+// Copyright 2023 The Dapr Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ------------------------------------------------------------------------
+
+using System.Collections.Immutable;
+using System.Text;
+using Dapr.Actors.Generators.Extensions;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Text;
+
+
+
+namespace Dapr.Actors.Generators;
+
+///
+/// Generates an extension method that can be used during dependency injection to register all actor types.
+///
+[Generator]
+public sealed class ActorRegistrationGenerator : IIncrementalGenerator
+{
+ private const string DaprActorType = "Dapr.Actors.Runtime.Actor";
+
+ ///
+ /// Initializes the generator and registers the syntax receiver.
+ ///
+ /// The to register callbacks on
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ var classDeclarations = context.SyntaxProvider
+ .CreateSyntaxProvider(
+ predicate: static (s, _) => IsClassDeclaration(s),
+ transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
+ .Where(static m => m is not null);
+
+ var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect());
+ context.RegisterSourceOutput(compilationAndClasses, static (spc, source) => Execute(source.Left, source.Right, spc));
+ }
+
+ private static bool IsClassDeclaration(SyntaxNode node) => node is ClassDeclarationSyntax;
+
+ private static INamedTypeSymbol? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
+ {
+ var classDeclaration = (ClassDeclarationSyntax)context.Node;
+ var model = context.SemanticModel;
+
+ if (model.GetDeclaredSymbol(classDeclaration) is not INamedTypeSymbol classSymbol)
+ {
+ return null;
+ }
+
+ var actorClass = context.SemanticModel.Compilation.GetTypeByMetadataName(DaprActorType);
+ return classSymbol.BaseType != null && classSymbol.BaseType.Equals(actorClass, SymbolEqualityComparer.Default) ? classSymbol : null;
+ }
+
+ private static void Execute(Compilation compilation, ImmutableArray actorTypes,
+ SourceProductionContext context)
+ {
+ var validActorTypes = actorTypes.Where(static t => t is not null).Cast().ToList();
+ var source = GenerateActorRegistrationSource(compilation, validActorTypes);
+ context.AddSource("ActorRegistrationExtensions.g.cs", SourceText.From(source, Encoding.UTF8));
+ }
+
+ ///
+ /// Generates the source code for the actor registration method.
+ ///
+ /// The current compilation context.
+ /// The list of actor types to register.
+ /// The generated source code as a string.
+ private static string GenerateActorRegistrationSource(Compilation compilation, IReadOnlyList actorTypes)
+ {
+#pragma warning disable RS1035
+ var registrations = string.Join(Environment.NewLine,
+#pragma warning restore RS1035
+ actorTypes.Select(t => $"options.Actors.RegisterActor<{t.ToDisplayString()}>();"));
+
+ return $@"
+using Microsoft.Extensions.DependencyInjection;
+using Dapr.Actors.AspNetCore;
+using Dapr.Actors.Runtime;
+using Dapr.Actors;
+using Dapr.Actors.AspNetCore;
+
+///
+/// Extension methods for registering Dapr actors.
+///
+public static class ActorRegistrationExtensions
+{{
+ ///
+ /// Registers all discovered actor types with the Dapr actor runtime.
+ ///
+ /// The service collection to add the actors to.
+ /// Whether to include actor types from referenced assemblies.
+ public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
+ {{
+ services.AddActors(options =>
+ {{
+ {registrations}
+ if (includeTransientReferences)
+ {{
+ {GenerateTransientActorRegistrations(compilation)}
+ }}
+ }});
+ }}
+}}";
+ }
+
+ ///
+ /// Generates the registration code for actor types in referenced assemblies.
+ ///
+ /// The current compilation context.
+ /// The generated registration code as a string.
+ private static string GenerateTransientActorRegistrations(Compilation compilation)
+ {
+ var actorRegistrations = new List();
+
+ foreach (var reference in compilation.References)
+ {
+ if (compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol referencedCompilation)
+ {
+ actorRegistrations.AddRange(from type in referencedCompilation.GlobalNamespace.GetNamespaceTypes()
+ where type.BaseType?.ToDisplayString() == DaprActorType
+ select $"options.Actors.RegisterActor<{type.ToDisplayString()}>();");
+ }
+ }
+
+#pragma warning disable RS1035
+ return string.Join(Environment.NewLine, actorRegistrations);
+#pragma warning restore RS1035
+ }
+}
+
diff --git a/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs b/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs
index 6b45e86f3..7ea5c65f9 100644
--- a/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs
+++ b/src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs
@@ -1,34 +1,33 @@
-namespace Dapr.Actors.Generators.Extensions
+namespace Dapr.Actors.Generators.Extensions;
+
+internal static class IEnumerableExtensions
{
- internal static class IEnumerableExtensions
+ ///
+ /// Returns the index of the first item in the sequence that satisfies the predicate. If no item satisfies the predicate, -1 is returned.
+ ///
+ /// The type of objects in the .
+ /// in which to search.
+ /// Function performed to check whether an item satisfies the condition.
+ /// Return the zero-based index of the first occurrence of an element that satisfies the condition, if found; otherwise, -1.
+ internal static int IndexOf(this IEnumerable source, Func predicate)
{
- ///
- /// Returns the index of the first item in the sequence that satisfies the predicate. If no item satisfies the predicate, -1 is returned.
- ///
- /// The type of objects in the .
- /// in which to search.
- /// Function performed to check whether an item satisfies the condition.
- /// Return the zero-based index of the first occurrence of an element that satisfies the condition, if found; otherwise, -1.
- internal static int IndexOf(this IEnumerable source, Func predicate)
+ if (predicate is null)
{
- if (predicate is null)
- {
- throw new ArgumentNullException(nameof(predicate));
- }
+ throw new ArgumentNullException(nameof(predicate));
+ }
- int index = 0;
+ int index = 0;
- foreach (var item in source)
+ foreach (var item in source)
+ {
+ if (predicate(item))
{
- if (predicate(item))
- {
- return index;
- }
-
- index++;
+ return index;
}
- return -1;
+ index++;
}
+
+ return -1;
}
}
diff --git a/src/Dapr.Actors.Generators/Extensions/INamespaceSymbolExtensions.cs b/src/Dapr.Actors.Generators/Extensions/INamespaceSymbolExtensions.cs
new file mode 100644
index 000000000..b094ff714
--- /dev/null
+++ b/src/Dapr.Actors.Generators/Extensions/INamespaceSymbolExtensions.cs
@@ -0,0 +1,33 @@
+using Microsoft.CodeAnalysis;
+
+namespace Dapr.Actors.Generators.Extensions;
+
+internal static class INamespaceSymbolExtensions
+{
+ ///
+ /// Recursively gets all the types in a namespace.
+ ///
+ /// The namespace symbol to search.
+ /// A collection of the named type symbols.
+ public static IEnumerable GetNamespaceTypes(this INamespaceSymbol namespaceSymbol)
+ {
+ foreach (var member in namespaceSymbol.GetMembers())
+ {
+ switch (member)
+ {
+ case INamespaceSymbol nestedNamespace:
+ {
+ foreach (var nestedType in nestedNamespace.GetNamespaceTypes())
+ {
+ yield return nestedType;
+ }
+
+ break;
+ }
+ case INamedTypeSymbol namedType:
+ yield return namedType;
+ break;
+ }
+ }
+ }
+}
diff --git a/test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs b/test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs
new file mode 100644
index 000000000..4ee4778e5
--- /dev/null
+++ b/test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs
@@ -0,0 +1,156 @@
+using System.Text;
+using Dapr.Actors.Runtime;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.Text;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Dapr.Actors.Generators.Test;
+
+public class ActorRegistrationGeneratorTests
+{
+ [Fact]
+ public void TestActorRegistrationGenerator_WithoutTransientReference()
+ {
+ const string source = @"
+using Dapr.Actors.Runtime;
+
+public class MyActor : Actor, IMyActor
+{
+ public MyActor(ActorHost host) : base(host) { }
+}
+
+public interface IMyActor : IActor
+{
+}
+";
+
+ const string expectedGeneratedCode = @"
+using Microsoft.Extensions.DependencyInjection;
+using Dapr.Actors.Runtime;
+
+///
+/// Extension methods for registering Dapr actors.
+///
+public static class ActorRegistrationExtensions
+{
+ ///
+ /// Registers all discovered actor types with the Dapr actor runtime.
+ ///
+ /// The service collection to add the actors to.
+ /// Whether to include actor types from referenced assemblies.
+ public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
+ {
+ services.AddActors(options =>
+ {
+ options.Actors.RegisterActor();
+ if (includeTransientReferences)
+ {
+
+ }
+ });
+ }
+}";
+
+ var generatedCode = GetGeneratedCode(source);
+ Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim());
+ }
+
+ [Fact]
+ public void TestActorRegistrationGenerator_WithTransientReference()
+ {
+ const string source = @"
+using Dapr.Actors.Runtime;
+
+public class MyActor : Actor, IMyActor
+{
+ public MyActor(ActorHost host) : base(host) { }
+}
+
+public interface IMyActor : IActor
+{
+}
+";
+
+ const string referencedSource = @"
+using Dapr.Actors.Runtime;
+
+public class TransientActor : Actor, ITransientActor
+{
+ public TransientActor(ActorHost host) : base(host) { }
+}
+
+public interface ITransientActor : IActor
+{
+}
+";
+
+ const string expectedGeneratedCode = @"
+using Microsoft.Extensions.DependencyInjection;
+using Dapr.Actors.Runtime;
+
+///
+/// Extension methods for registering Dapr actors.
+///
+public static class ActorRegistrationExtensions
+{
+ ///
+ /// Registers all discovered actor types with the Dapr actor runtime.
+ ///
+ /// The service collection to add the actors to.
+ /// Whether to include actor types from referenced assemblies.
+ public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
+ {
+ services.AddActors(options =>
+ {
+ options.Actors.RegisterActor();
+ if (includeTransientReferences)
+ {
+ options.Actors.RegisterActor();
+ }
+ });
+ }
+}";
+
+ var generatedCode = GetGeneratedCode(source, referencedSource);
+ Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim());
+ }
+
+ private static string GetGeneratedCode(string source, string? referencedSource = null)
+ {
+ var syntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(source, Encoding.UTF8));
+ var references = new List
+ {
+ MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(Actor).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location)
+ };
+
+ var compilation = CSharpCompilation.Create("TestCompilation",
+ new[] { syntaxTree },
+ references,
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+
+ if (referencedSource != null)
+ {
+ var referencedSyntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(referencedSource, Encoding.UTF8));
+ var referencedCompilation = CSharpCompilation.Create("ReferencedCompilation",
+ new[] { referencedSyntaxTree },
+ references,
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+
+ compilation = compilation.AddReferences(referencedCompilation.ToMetadataReference());
+ }
+
+ var generator = new ActorRegistrationGenerator();
+ var driver = CSharpGeneratorDriver.Create(generator);
+ driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics);
+
+ var generatedTrees = outputCompilation.SyntaxTrees.Skip(1).ToList();
+ Assert.Single(generatedTrees);
+
+ var generatedCode = generatedTrees[0].ToString();
+ return generatedCode;
+ }
+}
diff --git a/test/Dapr.Actors.Generators.Test/Extensions/INamespaceExtensionsTests.cs b/test/Dapr.Actors.Generators.Test/Extensions/INamespaceExtensionsTests.cs
new file mode 100644
index 000000000..7b1b5b7b3
--- /dev/null
+++ b/test/Dapr.Actors.Generators.Test/Extensions/INamespaceExtensionsTests.cs
@@ -0,0 +1,40 @@
+using Dapr.Actors.Generators.Extensions;
+using Microsoft.CodeAnalysis.CSharp;
+
+namespace Dapr.Actors.Generators.Test.Extensions;
+
+public class INamespaceExtensionsTests
+{
+ [Fact]
+ public void GetNamespaceTypes_ReturnsAllTypesInNamespace()
+ {
+ // Arrange
+ const string source = @"
+namespace TestNamespace
+{
+ public class ClassA { }
+ public class ClassB { }
+
+ namespace NestedNamespace
+ {
+ public class ClassC { }
+ }
+}";
+ var syntaxTree = CSharpSyntaxTree.ParseText(source);
+ var compilation = CSharpCompilation.Create("TestCompilation", new[] { syntaxTree });
+ var namespaceSymbol = compilation.GlobalNamespace.GetNamespaceMembers().FirstOrDefault(n => n.Name == "TestNamespace");
+
+ // Act
+ if (namespaceSymbol != null)
+ {
+ var types = namespaceSymbol.GetNamespaceTypes().ToList();
+
+ // Assert
+ Assert.NotNull(namespaceSymbol);
+ Assert.Equal(3, types.Count);
+ Assert.Contains(types, t => t.Name == "ClassA");
+ Assert.Contains(types, t => t.Name == "ClassB");
+ Assert.Contains(types, t => t.Name == "ClassC");
+ }
+ }
+}