From a01c3bf8630f38c35812cc5e98a2b924eeaa4f61 Mon Sep 17 00:00:00 2001 From: Jesse Levier Date: Tue, 28 Mar 2023 09:47:15 -0400 Subject: [PATCH] Refactoring existing source generator to use new incremental api --- ...OneOf.SourceGenerator.AnalyzerTests.csproj | 2 +- .../OneOf.SourceGenerator.csproj | 4 +- OneOf.SourceGenerator/OneOfGenerator.cs | 224 ++++++++++-------- 3 files changed, 126 insertions(+), 104 deletions(-) diff --git a/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj b/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj index 754dcb6..c57bfe8 100644 --- a/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj +++ b/OneOf.SourceGenerator.AnalyzerTests/OneOf.SourceGenerator.AnalyzerTests.csproj @@ -9,7 +9,7 @@ - + diff --git a/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj b/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj index 8f79365..a303304 100644 --- a/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj +++ b/OneOf.SourceGenerator/OneOf.SourceGenerator.csproj @@ -32,11 +32,11 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all diff --git a/OneOf.SourceGenerator/OneOfGenerator.cs b/OneOf.SourceGenerator/OneOfGenerator.cs index a7c225b..0d803fb 100644 --- a/OneOf.SourceGenerator/OneOfGenerator.cs +++ b/OneOf.SourceGenerator/OneOfGenerator.cs @@ -1,21 +1,22 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Text; +using System.Threading; namespace OneOf.SourceGenerator { [Generator] - public class OneOfGenerator : ISourceGenerator + public class OneOfGenerator : IIncrementalGenerator { private const string AttributeName = "GenerateOneOfAttribute"; private const string AttributeNamespace = "OneOf"; - private readonly string _attributeText = $@"// + private static readonly string _attributeText = $@"// using System; #pragma warning disable 1591 @@ -29,158 +30,179 @@ internal sealed class {AttributeName} : Attribute }} "; - public void Execute(GeneratorExecutionContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - if (context.SyntaxReceiver is not OneOfSyntaxReceiver receiver) - { - return; - } - - Compilation compilation = context.Compilation; + context.RegisterPostInitializationOutput(ctx => ctx.AddSource( + $"{AttributeName}.g.cs", + SourceText.From(_attributeText, Encoding.UTF8))); + + IncrementalValuesProvider classesToGenerate = context.SyntaxProvider + .ForAttributeWithMetadataName( + $"{AttributeNamespace}.{AttributeName}", + predicate: (node, _) => IsCandidateForGeneration(node), + transform: GetClassToGenerate) + .Where(static m => m is not null); + + context.RegisterSourceOutput(classesToGenerate, + static (spc, classToGenerate) => Execute(in classToGenerate, spc)); + } - INamedTypeSymbol? attributeSymbol = - compilation.GetTypeByMetadataName($"{AttributeNamespace}.{AttributeName}"); + private static bool IsCandidateForGeneration(SyntaxNode node) + => node is ClassDeclarationSyntax classSyntax && classSyntax.AttributeLists.Count > 0; - if (attributeSymbol is null) + private static void Execute(in ClassToGenerate? classToGenerate, SourceProductionContext context) + { + if (classToGenerate is { } validClass) { - return; + if (validClass.Error is not null) + { + context.ReportDiagnostic(validClass.Error); + return; + } + + var result = GetClassCode(in validClass); + context.AddSource($"{validClass.Namespace}_{validClass.Name}.g.cs", SourceText.From(result, Encoding.UTF8)); } + } - List<(INamedTypeSymbol, Location?)> namedTypeSymbols = new(); - foreach (ClassDeclarationSyntax classDeclaration in receiver.CandidateClasses) + private static ClassToGenerate? GetClassToGenerate(GeneratorAttributeSyntaxContext context, CancellationToken ct) + { + if (context.TargetSymbol is not INamedTypeSymbol classSymbol) { - SemanticModel model = compilation.GetSemanticModel(classDeclaration.SyntaxTree); - INamedTypeSymbol? namedTypeSymbol = model.GetDeclaredSymbol(classDeclaration); - - AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad => - ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false); - - if (attributeData is not null) - { - namedTypeSymbols.Add((namedTypeSymbol!, - attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation())); - } + return null; } - foreach ((INamedTypeSymbol namedSymbol, Location? attributeLocation) in namedTypeSymbols) + // Check to see if the class has the attribute we're looking for, otherwise return null and do nothing + if (!classSymbol.GetAttributes().Any(a => a.AttributeClass?.Name == AttributeName + && a.AttributeClass?.ContainingNamespace.Name == AttributeNamespace)) { - string? classSource = ProcessClass(namedSymbol, context, attributeLocation); + return null; + } - if (classSource is null) - { - continue; - } + ct.ThrowIfCancellationRequested(); - context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.g.cs", classSource); + if (ClassHasErrors(classSymbol, context, out ClassToGenerate? classWithError)) + { + return classWithError; } + + IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs = + classSymbol.BaseType!.TypeParameters.Zip(classSymbol.BaseType.TypeArguments, (param, arg) => (param, arg)); + + return new ClassToGenerate( + classSymbol.Name, + classSymbol.ContainingNamespace.ToDisplayString(), + classSymbol.TypeArguments, + classSymbol.BaseType!.TypeArguments, + paramArgPairs); } - private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, Location? attributeLocation) + private static bool ClassHasErrors(INamedTypeSymbol classSymbol, GeneratorAttributeSyntaxContext context, out ClassToGenerate? classWithError) { - attributeLocation ??= Location.None; + classWithError = null; if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default)) { - CreateDiagnosticError(GeneratorDiagnosticDescriptors.TopLevelError); - return null; + classWithError = CreateError(GeneratorDiagnosticDescriptors.TopLevelError, context.TargetNode.GetLocation(), classSymbol.Name); + return true; } - if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf") + if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != AttributeNamespace) { - CreateDiagnosticError(GeneratorDiagnosticDescriptors.WrongBaseType); - return null; + classWithError = CreateError(GeneratorDiagnosticDescriptors.WrongBaseType, context.TargetNode.GetLocation(), classSymbol.Name); + return true; } - ImmutableArray typeArguments = classSymbol.BaseType.TypeArguments; - - foreach (ITypeSymbol typeSymbol in typeArguments) + foreach (ITypeSymbol typeSymbol in classSymbol.BaseType!.TypeArguments) { if (typeSymbol.Name == nameof(Object)) { - CreateDiagnosticError(GeneratorDiagnosticDescriptors.ObjectIsOneOfType); - return null; + classWithError = CreateError(GeneratorDiagnosticDescriptors.ObjectIsOneOfType, context.TargetNode.GetLocation(), classSymbol.Name); + return true; } if (typeSymbol.TypeKind == TypeKind.Interface) { - CreateDiagnosticError(GeneratorDiagnosticDescriptors.UserDefinedConversionsToOrFromAnInterfaceAreNotAllowed); - return null; + classWithError = CreateError(GeneratorDiagnosticDescriptors.UserDefinedConversionsToOrFromAnInterfaceAreNotAllowed, + context.TargetNode.GetLocation(), classSymbol.Name); + return true; } } - return GenerateClassSource(classSymbol, classSymbol.BaseType.TypeParameters, typeArguments); - - void CreateDiagnosticError(DiagnosticDescriptor descriptor) - => context.ReportDiagnostic(Diagnostic.Create(descriptor, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); + return false; } - private static string GenerateClassSource(INamedTypeSymbol classSymbol, - ImmutableArray typeParameters, ImmutableArray typeArguments) - { - IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs = - typeParameters.Zip(typeArguments, (param, arg) => (param, arg)); + private static ClassToGenerate CreateError(DiagnosticDescriptor descriptor, Location location, string name) + => new(Diagnostic.Create(descriptor, location, name, DiagnosticSeverity.Error)); - string oneOfGenericPart = GetGenericPart(typeArguments); + private static string GetGenericPart(ImmutableArray typeArguments) => + string.Join(", ", typeArguments.Select(x => x.ToDisplayString())); - string classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}"; + private static string? GetOpenGenericPart(ImmutableArray typeArguments) + { + if (!typeArguments.Any()) + { + return null; + } - StringBuilder source = new($@"// -#pragma warning disable 1591 + return $"<{GetGenericPart(typeArguments)}>"; + } -namespace {classSymbol.ContainingNamespace.ToDisplayString()} -{{ - partial class {classNameWithGenericTypes}"); + private static string GetClassCode(in ClassToGenerate classToGenerate) + { + string constructor = $"public {classToGenerate.Name}(OneOf.OneOf<{GetGenericPart(classToGenerate.OneOfBaseTypeArguments)}> _) : base(_) {{ }}"; - source.Append($@" - {{ - public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }} -"); + string classNameWithGenericTypes = $"{classToGenerate.Name}{GetOpenGenericPart(classToGenerate.TypeArguments)}"; - foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in paramArgPairs) + StringBuilder sbParamArgPairs = new(); + foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in classToGenerate.ParamArgPairs) { - source.Append($@" + sbParamArgPairs.Append($@" public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_); public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name}; "); } - source.Append(@" } -}"); - return source.ToString(); - } + return $@"// +#pragma warning disable 1591 - private static string GetGenericPart(ImmutableArray typeArguments) => - string.Join(", ", typeArguments.Select(x => x.ToDisplayString())); +namespace {classToGenerate.Namespace} +{{ + partial class {classNameWithGenericTypes} + {{ + {constructor} + {sbParamArgPairs} + }} +}} +"; + } - private static string? GetOpenGenericPart(INamedTypeSymbol classSymbol) + internal sealed class ClassToGenerate { - if (!classSymbol.TypeArguments.Any()) + public ClassToGenerate( + string name, + string @namespace, + ImmutableArray typeArguments, + ImmutableArray oneOfBaseTypeArguments, + IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs) { - return null; + Name = name; + Namespace = @namespace; + TypeArguments = typeArguments; + OneOfBaseTypeArguments = oneOfBaseTypeArguments; + ParamArgPairs = paramArgPairs; } - return $"<{GetGenericPart(classSymbol.TypeArguments)}>"; - } + public ClassToGenerate(Diagnostic error) + : this("", "", default, default, new List<(ITypeParameterSymbol param, ITypeSymbol arg)>()) + => Error = error; - public void Initialize(GeneratorInitializationContext context) - { - context.RegisterForPostInitialization(ctx => - ctx.AddSource($"{AttributeName}.g.cs", _attributeText)); - context.RegisterForSyntaxNotifications(() => new OneOfSyntaxReceiver()); - } - - internal class OneOfSyntaxReceiver : ISyntaxReceiver - { - public List CandidateClasses { get; } = new(); - - public void OnVisitSyntaxNode(SyntaxNode syntaxNode) - { - if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax - && classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - CandidateClasses.Add(classDeclarationSyntax); - } - } + public string Name { get; } + public string Namespace { get; } + public ImmutableArray TypeArguments { get; } + public ImmutableArray OneOfBaseTypeArguments { get; } + public IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> ParamArgPairs { get; } + public Diagnostic? Error { get; } } } -} +} \ No newline at end of file