diff --git a/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj b/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj
index 754dcb6..c57bfe8 100644
--- a/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj
+++ b/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj
@@ -9,7 +9,7 @@
-
+
diff --git a/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj b/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj
index 8f79365..a303304 100644
--- a/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj
+++ b/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj
@@ -32,11 +32,11 @@
-
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
-
+
all
diff --git a/OneOf.SourceGenerator/OneOfGenerator.cs b/OneOf.SourceGenerator/OneOfGenerator.cs
index a7c225b..0d803fb 100644
--- a/OneOf.SourceGenerator/OneOfGenerator.cs
+++ b/OneOf.SourceGenerator/OneOfGenerator.cs
@@ -1,21 +1,22 @@
using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
+using System.Threading;
namespace OneOf.SourceGenerator
{
[Generator]
- public class OneOfGenerator : ISourceGenerator
+ public class OneOfGenerator : IIncrementalGenerator
{
private const string AttributeName = "GenerateOneOfAttribute";
private const string AttributeNamespace = "OneOf";
- private readonly string _attributeText = $@"//
+ private static readonly string _attributeText = $@"//
using System;
#pragma warning disable 1591
@@ -29,158 +30,179 @@ internal sealed class {AttributeName} : Attribute
}}
";
- public void Execute(GeneratorExecutionContext context)
+ public void Initialize(IncrementalGeneratorInitializationContext context)
{
- if (context.SyntaxReceiver is not OneOfSyntaxReceiver receiver)
- {
- return;
- }
-
- Compilation compilation = context.Compilation;
+ context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
+ $"{AttributeName}.g.cs",
+ SourceText.From(_attributeText, Encoding.UTF8)));
+
+ IncrementalValuesProvider classesToGenerate = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ $"{AttributeNamespace}.{AttributeName}",
+ predicate: (node, _) => IsCandidateForGeneration(node),
+ transform: GetClassToGenerate)
+ .Where(static m => m is not null);
+
+ context.RegisterSourceOutput(classesToGenerate,
+ static (spc, classToGenerate) => Execute(in classToGenerate, spc));
+ }
- INamedTypeSymbol? attributeSymbol =
- compilation.GetTypeByMetadataName($"{AttributeNamespace}.{AttributeName}");
+ private static bool IsCandidateForGeneration(SyntaxNode node)
+ => node is ClassDeclarationSyntax classSyntax && classSyntax.AttributeLists.Count > 0;
- if (attributeSymbol is null)
+ private static void Execute(in ClassToGenerate? classToGenerate, SourceProductionContext context)
+ {
+ if (classToGenerate is { } validClass)
{
- return;
+ if (validClass.Error is not null)
+ {
+ context.ReportDiagnostic(validClass.Error);
+ return;
+ }
+
+ var result = GetClassCode(in validClass);
+ context.AddSource($"{validClass.Namespace}_{validClass.Name}.g.cs", SourceText.From(result, Encoding.UTF8));
}
+ }
- List<(INamedTypeSymbol, Location?)> namedTypeSymbols = new();
- foreach (ClassDeclarationSyntax classDeclaration in receiver.CandidateClasses)
+ private static ClassToGenerate? GetClassToGenerate(GeneratorAttributeSyntaxContext context, CancellationToken ct)
+ {
+ if (context.TargetSymbol is not INamedTypeSymbol classSymbol)
{
- SemanticModel model = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
- INamedTypeSymbol? namedTypeSymbol = model.GetDeclaredSymbol(classDeclaration);
-
- AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad =>
- ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false);
-
- if (attributeData is not null)
- {
- namedTypeSymbols.Add((namedTypeSymbol!,
- attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
- }
+ return null;
}
- foreach ((INamedTypeSymbol namedSymbol, Location? attributeLocation) in namedTypeSymbols)
+ // Check to see if the class has the attribute we're looking for, otherwise return null and do nothing
+ if (!classSymbol.GetAttributes().Any(a => a.AttributeClass?.Name == AttributeName
+ && a.AttributeClass?.ContainingNamespace.Name == AttributeNamespace))
{
- string? classSource = ProcessClass(namedSymbol, context, attributeLocation);
+ return null;
+ }
- if (classSource is null)
- {
- continue;
- }
+ ct.ThrowIfCancellationRequested();
- context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.g.cs", classSource);
+ if (ClassHasErrors(classSymbol, context, out ClassToGenerate? classWithError))
+ {
+ return classWithError;
}
+
+ IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs =
+ classSymbol.BaseType!.TypeParameters.Zip(classSymbol.BaseType.TypeArguments, (param, arg) => (param, arg));
+
+ return new ClassToGenerate(
+ classSymbol.Name,
+ classSymbol.ContainingNamespace.ToDisplayString(),
+ classSymbol.TypeArguments,
+ classSymbol.BaseType!.TypeArguments,
+ paramArgPairs);
}
- private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, Location? attributeLocation)
+ private static bool ClassHasErrors(INamedTypeSymbol classSymbol, GeneratorAttributeSyntaxContext context, out ClassToGenerate? classWithError)
{
- attributeLocation ??= Location.None;
+ classWithError = null;
if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default))
{
- CreateDiagnosticError(GeneratorDiagnosticDescriptors.TopLevelError);
- return null;
+ classWithError = CreateError(GeneratorDiagnosticDescriptors.TopLevelError, context.TargetNode.GetLocation(), classSymbol.Name);
+ return true;
}
- if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf")
+ if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != AttributeNamespace)
{
- CreateDiagnosticError(GeneratorDiagnosticDescriptors.WrongBaseType);
- return null;
+ classWithError = CreateError(GeneratorDiagnosticDescriptors.WrongBaseType, context.TargetNode.GetLocation(), classSymbol.Name);
+ return true;
}
- ImmutableArray typeArguments = classSymbol.BaseType.TypeArguments;
-
- foreach (ITypeSymbol typeSymbol in typeArguments)
+ foreach (ITypeSymbol typeSymbol in classSymbol.BaseType!.TypeArguments)
{
if (typeSymbol.Name == nameof(Object))
{
- CreateDiagnosticError(GeneratorDiagnosticDescriptors.ObjectIsOneOfType);
- return null;
+ classWithError = CreateError(GeneratorDiagnosticDescriptors.ObjectIsOneOfType, context.TargetNode.GetLocation(), classSymbol.Name);
+ return true;
}
if (typeSymbol.TypeKind == TypeKind.Interface)
{
- CreateDiagnosticError(GeneratorDiagnosticDescriptors.UserDefinedConversionsToOrFromAnInterfaceAreNotAllowed);
- return null;
+ classWithError = CreateError(GeneratorDiagnosticDescriptors.UserDefinedConversionsToOrFromAnInterfaceAreNotAllowed,
+ context.TargetNode.GetLocation(), classSymbol.Name);
+ return true;
}
}
- return GenerateClassSource(classSymbol, classSymbol.BaseType.TypeParameters, typeArguments);
-
- void CreateDiagnosticError(DiagnosticDescriptor descriptor)
- => context.ReportDiagnostic(Diagnostic.Create(descriptor, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
+ return false;
}
- private static string GenerateClassSource(INamedTypeSymbol classSymbol,
- ImmutableArray typeParameters, ImmutableArray typeArguments)
- {
- IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs =
- typeParameters.Zip(typeArguments, (param, arg) => (param, arg));
+ private static ClassToGenerate CreateError(DiagnosticDescriptor descriptor, Location location, string name)
+ => new(Diagnostic.Create(descriptor, location, name, DiagnosticSeverity.Error));
- string oneOfGenericPart = GetGenericPart(typeArguments);
+ private static string GetGenericPart(ImmutableArray typeArguments) =>
+ string.Join(", ", typeArguments.Select(x => x.ToDisplayString()));
- string classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}";
+ private static string? GetOpenGenericPart(ImmutableArray typeArguments)
+ {
+ if (!typeArguments.Any())
+ {
+ return null;
+ }
- StringBuilder source = new($@"//
-#pragma warning disable 1591
+ return $"<{GetGenericPart(typeArguments)}>";
+ }
-namespace {classSymbol.ContainingNamespace.ToDisplayString()}
-{{
- partial class {classNameWithGenericTypes}");
+ private static string GetClassCode(in ClassToGenerate classToGenerate)
+ {
+ string constructor = $"public {classToGenerate.Name}(OneOf.OneOf<{GetGenericPart(classToGenerate.OneOfBaseTypeArguments)}> _) : base(_) {{ }}";
- source.Append($@"
- {{
- public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }}
-");
+ string classNameWithGenericTypes = $"{classToGenerate.Name}{GetOpenGenericPart(classToGenerate.TypeArguments)}";
- foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in paramArgPairs)
+ StringBuilder sbParamArgPairs = new();
+ foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in classToGenerate.ParamArgPairs)
{
- source.Append($@"
+ sbParamArgPairs.Append($@"
public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_);
public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name};
");
}
- source.Append(@" }
-}");
- return source.ToString();
- }
+ return $@"//
+#pragma warning disable 1591
- private static string GetGenericPart(ImmutableArray typeArguments) =>
- string.Join(", ", typeArguments.Select(x => x.ToDisplayString()));
+namespace {classToGenerate.Namespace}
+{{
+ partial class {classNameWithGenericTypes}
+ {{
+ {constructor}
+ {sbParamArgPairs}
+ }}
+}}
+";
+ }
- private static string? GetOpenGenericPart(INamedTypeSymbol classSymbol)
+ internal sealed class ClassToGenerate
{
- if (!classSymbol.TypeArguments.Any())
+ public ClassToGenerate(
+ string name,
+ string @namespace,
+ ImmutableArray typeArguments,
+ ImmutableArray oneOfBaseTypeArguments,
+ IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs)
{
- return null;
+ Name = name;
+ Namespace = @namespace;
+ TypeArguments = typeArguments;
+ OneOfBaseTypeArguments = oneOfBaseTypeArguments;
+ ParamArgPairs = paramArgPairs;
}
- return $"<{GetGenericPart(classSymbol.TypeArguments)}>";
- }
+ public ClassToGenerate(Diagnostic error)
+ : this("", "", default, default, new List<(ITypeParameterSymbol param, ITypeSymbol arg)>())
+ => Error = error;
- public void Initialize(GeneratorInitializationContext context)
- {
- context.RegisterForPostInitialization(ctx =>
- ctx.AddSource($"{AttributeName}.g.cs", _attributeText));
- context.RegisterForSyntaxNotifications(() => new OneOfSyntaxReceiver());
- }
-
- internal class OneOfSyntaxReceiver : ISyntaxReceiver
- {
- public List CandidateClasses { get; } = new();
-
- public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
- {
- if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax
- && classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
- {
- CandidateClasses.Add(classDeclarationSyntax);
- }
- }
+ public string Name { get; }
+ public string Namespace { get; }
+ public ImmutableArray TypeArguments { get; }
+ public ImmutableArray OneOfBaseTypeArguments { get; }
+ public IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> ParamArgPairs { get; }
+ public Diagnostic? Error { get; }
}
}
-}
+}
\ No newline at end of file