Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added source generator for actor type registration #1401

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
141 changes: 141 additions & 0 deletions src/Dapr.Actors.Generators/ActorRegistrationGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// ------------------------------------------------------------------------
// Copyright 2023 The Dapr Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------

using System.Collections.Immutable;
using System.Text;
using Dapr.Actors.Generators.Extensions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;



namespace Dapr.Actors.Generators;

/// <summary>
/// Generates an extension method that can be used during dependency injection to register all actor types.
/// </summary>
[Generator]
public sealed class ActorRegistrationGenerator : IIncrementalGenerator
{
private const string DaprActorType = "Dapr.Actors.Runtime.Actor";

/// <summary>
/// Initializes the generator and registers the syntax receiver.
/// </summary>
/// <param name="context">The <see cref="T:Microsoft.CodeAnalysis.IncrementalGeneratorInitializationContext" /> to register callbacks on</param>
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var classDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => IsClassDeclaration(s),
transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx))
.Where(static m => m is not null);

var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect());
context.RegisterSourceOutput(compilationAndClasses, static (spc, source) => Execute(source.Left, source.Right, spc));
}

private static bool IsClassDeclaration(SyntaxNode node) => node is ClassDeclarationSyntax;

private static INamedTypeSymbol? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
{
var classDeclaration = (ClassDeclarationSyntax)context.Node;
var model = context.SemanticModel;

if (model.GetDeclaredSymbol(classDeclaration) is not INamedTypeSymbol classSymbol)
{
return null;
}

var actorClass = context.SemanticModel.Compilation.GetTypeByMetadataName(DaprActorType);
return classSymbol.BaseType != null && classSymbol.BaseType.Equals(actorClass, SymbolEqualityComparer.Default) ? classSymbol : null;
}

private static void Execute(Compilation compilation, ImmutableArray<INamedTypeSymbol?> actorTypes,
SourceProductionContext context)
{
var validActorTypes = actorTypes.Where(static t => t is not null).Cast<INamedTypeSymbol>().ToList();
var source = GenerateActorRegistrationSource(compilation, validActorTypes);
context.AddSource("ActorRegistrationExtensions.g.cs", SourceText.From(source, Encoding.UTF8));
}

/// <summary>
/// Generates the source code for the actor registration method.
/// </summary>
/// <param name="compilation">The current compilation context.</param>
/// <param name="actorTypes">The list of actor types to register.</param>
/// <returns>The generated source code as a string.</returns>
private static string GenerateActorRegistrationSource(Compilation compilation, IReadOnlyList<INamedTypeSymbol> actorTypes)
{
#pragma warning disable RS1035
var registrations = string.Join(Environment.NewLine,
#pragma warning restore RS1035
actorTypes.Select(t => $"options.Actors.RegisterActor<{t.ToDisplayString()}>();"));

return $@"
using Microsoft.Extensions.DependencyInjection;
using Dapr.Actors.AspNetCore;
using Dapr.Actors.Runtime;
using Dapr.Actors;
using Dapr.Actors.AspNetCore;

/// <summary>
/// Extension methods for registering Dapr actors.
/// </summary>
public static class ActorRegistrationExtensions
{{
/// <summary>
/// Registers all discovered actor types with the Dapr actor runtime.
/// </summary>
/// <param name=""services"">The service collection to add the actors to.</param>
/// <param name=""includeTransientReferences"">Whether to include actor types from referenced assemblies.</param>
public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
{{
services.AddActors(options =>
{{
{registrations}
if (includeTransientReferences)
{{
{GenerateTransientActorRegistrations(compilation)}
}}
}});
}}
}}";
}

/// <summary>
/// Generates the registration code for actor types in referenced assemblies.
/// </summary>
/// <param name="compilation">The current compilation context.</param>
/// <returns>The generated registration code as a string.</returns>
private static string GenerateTransientActorRegistrations(Compilation compilation)
{
var actorRegistrations = new List<string>();

foreach (var reference in compilation.References)
{
if (compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol referencedCompilation)
{
actorRegistrations.AddRange(from type in referencedCompilation.GlobalNamespace.GetNamespaceTypes()
where type.BaseType?.ToDisplayString() == DaprActorType
select $"options.Actors.RegisterActor<{type.ToDisplayString()}>();");
}
}

#pragma warning disable RS1035
return string.Join(Environment.NewLine, actorRegistrations);
#pragma warning restore RS1035
}
}

45 changes: 22 additions & 23 deletions src/Dapr.Actors.Generators/Extensions/IEnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
namespace Dapr.Actors.Generators.Extensions
namespace Dapr.Actors.Generators.Extensions;

internal static class IEnumerableExtensions
{
internal static class IEnumerableExtensions
/// <summary>
/// Returns the index of the first item in the sequence that satisfies the predicate. If no item satisfies the predicate, -1 is returned.
/// </summary>
/// <typeparam name="T">The type of objects in the <see cref="IEnumerable{T}"/>.</typeparam>
/// <param name="source"><see cref="IEnumerable{T}"/> in which to search.</param>
/// <param name="predicate">Function performed to check whether an item satisfies the condition.</param>
/// <returns>Return the zero-based index of the first occurrence of an element that satisfies the condition, if found; otherwise, -1.</returns>
internal static int IndexOf<T>(this IEnumerable<T> source, Func<T, bool> predicate)
{
/// <summary>
/// Returns the index of the first item in the sequence that satisfies the predicate. If no item satisfies the predicate, -1 is returned.
/// </summary>
/// <typeparam name="T">The type of objects in the <see cref="IEnumerable{T}"/>.</typeparam>
/// <param name="source"><see cref="IEnumerable{T}"/> in which to search.</param>
/// <param name="predicate">Function performed to check whether an item satisfies the condition.</param>
/// <returns>Return the zero-based index of the first occurrence of an element that satisfies the condition, if found; otherwise, -1.</returns>
internal static int IndexOf<T>(this IEnumerable<T> source, Func<T, bool> predicate)
if (predicate is null)
{
if (predicate is null)
{
throw new ArgumentNullException(nameof(predicate));
}
throw new ArgumentNullException(nameof(predicate));
}

int index = 0;
int index = 0;

foreach (var item in source)
foreach (var item in source)
{
if (predicate(item))
{
if (predicate(item))
{
return index;
}

index++;
return index;
}

return -1;
index++;
}

return -1;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Microsoft.CodeAnalysis;

namespace Dapr.Actors.Generators.Extensions;

internal static class INamespaceSymbolExtensions
{
/// <summary>
/// Recursively gets all the types in a namespace.
/// </summary>
/// <param name="namespaceSymbol">The namespace symbol to search.</param>
/// <returns>A collection of the named type symbols.</returns>
public static IEnumerable<INamedTypeSymbol> GetNamespaceTypes(this INamespaceSymbol namespaceSymbol)
{
foreach (var member in namespaceSymbol.GetMembers())
{
switch (member)
{
case INamespaceSymbol nestedNamespace:
{
foreach (var nestedType in nestedNamespace.GetNamespaceTypes())
{
yield return nestedType;
}

break;
}
case INamedTypeSymbol namedType:
yield return namedType;
break;
}
}
}
}
156 changes: 156 additions & 0 deletions test/Dapr.Actors.Generators.Test/ActorRegistrationGeneratorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
using System.Text;
using Dapr.Actors.Runtime;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;
using Microsoft.Extensions.DependencyInjection;

namespace Dapr.Actors.Generators.Test;

public class ActorRegistrationGeneratorTests
{
[Fact]
public void TestActorRegistrationGenerator_WithoutTransientReference()
{
const string source = @"
using Dapr.Actors.Runtime;

public class MyActor : Actor, IMyActor
{
public MyActor(ActorHost host) : base(host) { }
}

public interface IMyActor : IActor
{
}
";

const string expectedGeneratedCode = @"
using Microsoft.Extensions.DependencyInjection;
using Dapr.Actors.Runtime;

/// <summary>
/// Extension methods for registering Dapr actors.
/// </summary>
public static class ActorRegistrationExtensions
{
/// <summary>
/// Registers all discovered actor types with the Dapr actor runtime.
/// </summary>
/// <param name=""services"">The service collection to add the actors to.</param>
/// <param name=""includeTransientReferences"">Whether to include actor types from referenced assemblies.</param>
public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
{
services.AddActors(options =>
{
options.Actors.RegisterActor<MyActor>();
if (includeTransientReferences)
{

}
});
}
}";

var generatedCode = GetGeneratedCode(source);
Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim());
}

[Fact]
public void TestActorRegistrationGenerator_WithTransientReference()
{
const string source = @"
using Dapr.Actors.Runtime;

public class MyActor : Actor, IMyActor
{
public MyActor(ActorHost host) : base(host) { }
}

public interface IMyActor : IActor
{
}
";

const string referencedSource = @"
using Dapr.Actors.Runtime;

public class TransientActor : Actor, ITransientActor
{
public TransientActor(ActorHost host) : base(host) { }
}

public interface ITransientActor : IActor
{
}
";

const string expectedGeneratedCode = @"
using Microsoft.Extensions.DependencyInjection;
using Dapr.Actors.Runtime;

/// <summary>
/// Extension methods for registering Dapr actors.
/// </summary>
public static class ActorRegistrationExtensions
{
/// <summary>
/// Registers all discovered actor types with the Dapr actor runtime.
/// </summary>
/// <param name=""services"">The service collection to add the actors to.</param>
/// <param name=""includeTransientReferences"">Whether to include actor types from referenced assemblies.</param>
public static void RegisterAllActors(this IServiceCollection services, bool includeTransientReferences = false)
{
services.AddActors(options =>
{
options.Actors.RegisterActor<MyActor>();
if (includeTransientReferences)
{
options.Actors.RegisterActor<TransientActor>();
}
});
}
}";

var generatedCode = GetGeneratedCode(source, referencedSource);
Assert.Equal(expectedGeneratedCode.Trim(), generatedCode.Trim());
}

private static string GetGeneratedCode(string source, string? referencedSource = null)
{
var syntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(source, Encoding.UTF8));
var references = new List<MetadataReference>
{
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
MetadataReference.CreateFromFile(typeof(Actor).Assembly.Location),
MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location)
};

var compilation = CSharpCompilation.Create("TestCompilation",
new[] { syntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

if (referencedSource != null)
{
var referencedSyntaxTree = CSharpSyntaxTree.ParseText(SourceText.From(referencedSource, Encoding.UTF8));
var referencedCompilation = CSharpCompilation.Create("ReferencedCompilation",
new[] { referencedSyntaxTree },
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

compilation = compilation.AddReferences(referencedCompilation.ToMetadataReference());
}

var generator = new ActorRegistrationGenerator();
var driver = CSharpGeneratorDriver.Create(generator);
driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var diagnostics);

var generatedTrees = outputCompilation.SyntaxTrees.Skip(1).ToList();
Assert.Single(generatedTrees);

var generatedCode = generatedTrees[0].ToString();
return generatedCode;
}
}
Loading
Loading