From 2da9e49fe88e3665183ffe4d610c042af708e23b Mon Sep 17 00:00:00 2001 From: Dimitar Dobrev Date: Fri, 22 Oct 2021 17:29:13 +0300 Subject: [PATCH] Instantiate specialized classes nested in templates Signed-off-by: Dimitar Dobrev --- src/AST/ClassExtensions.cs | 14 ++++++++ src/CppParser/Parser.cpp | 33 ++++++++++++++++--- src/CppParser/Parser.h | 1 + .../Generators/CSharp/CSharpSources.cs | 8 ++--- .../CSharp/CSharpSourcesExtensions.cs | 16 ++++----- tests/CSharp/CSharpTemplates.cpp | 8 +++-- 6 files changed, 60 insertions(+), 20 deletions(-) diff --git a/src/AST/ClassExtensions.cs b/src/AST/ClassExtensions.cs index 5b1fabc371..514b52490b 100644 --- a/src/AST/ClassExtensions.cs +++ b/src/AST/ClassExtensions.cs @@ -233,6 +233,20 @@ public static Class GetInterface(this Class @class) return @interface; } + public static ClassTemplateSpecialization GetParentSpecialization(this Class @class) + { + Class currentClass = @class; + do + { + if (currentClass is ClassTemplateSpecialization specialization) + { + return specialization; + } + currentClass = currentClass.Namespace as Class; + } while (currentClass != null); + return null; + } + public static bool HasDependentValueFieldInLayout(this Class @class) { if (@class.Fields.Any(f => IsValueDependent(f.Type))) diff --git a/src/CppParser/Parser.cpp b/src/CppParser/Parser.cpp index 16d39596c8..f95b9c543a 100644 --- a/src/CppParser/Parser.cpp +++ b/src/CppParser/Parser.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -3053,8 +3054,7 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType) RD = const_cast(Type->getPointeeCXXRecordDecl()); ClassTemplateSpecializationDecl* CTS; if (!RD || - !(CTS = llvm::dyn_cast(RD)) || - CTS->isCompleteDefinition()) + !(CTS = llvm::dyn_cast(RD))) return; auto existingClient = c->getSema().getDiagnostics().getClient(); @@ -3065,8 +3065,7 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType) Scope Scope(nullptr, Scope::ScopeFlags::ClassScope, c->getSema().getDiagnostics()); c->getSema().TUScope = &Scope; - c->getSema().InstantiateClassTemplateSpecialization(CTS->getBeginLoc(), - CTS, TSK_ImplicitInstantiation, false); + InstantiateSpecialization(CTS); c->getSema().getDiagnostics().setClient(existingClient, false); c->getSema().TUScope = nullptr; @@ -3082,6 +3081,32 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType) } } +void Parser::InstantiateSpecialization(clang::ClassTemplateSpecializationDecl* CTS) +{ + using namespace clang; + + if (!CTS->isCompleteDefinition()) + { + c->getSema().InstantiateClassTemplateSpecialization(CTS->getBeginLoc(), + CTS, TSK_ImplicitInstantiation, false); + } + + for (auto Decl : CTS->decls()) + { + if (Decl->getKind() == Decl::Kind::CXXRecord) + { + CXXRecordDecl* Nested = cast(Decl); + CXXRecordDecl* Template = Nested->getInstantiatedFromMemberClass(); + if (Template && !Nested->isCompleteDefinition() && !Nested->hasDefinition()) + { + c->getSema().InstantiateClass(Nested->getBeginLoc(), Nested, Template, + MultiLevelTemplateArgumentList(CTS->getTemplateArgs()), + TSK_ImplicitInstantiation, false); + } + } + } +} + Parameter* Parser::WalkParameter(const clang::ParmVarDecl* PVD, const clang::SourceLocation& ParamStartLoc) { diff --git a/src/CppParser/Parser.h b/src/CppParser/Parser.h index 8a25f02a29..aae5deb490 100644 --- a/src/CppParser/Parser.h +++ b/src/CppParser/Parser.h @@ -137,6 +137,7 @@ class Parser std::string GetTypeName(const clang::Type* Type); bool CanCheckCodeGenInfo(clang::Sema & S, const clang::Type * Ty); void CompleteIfSpecializationType(const clang::QualType& QualType); + void InstantiateSpecialization(clang::ClassTemplateSpecializationDecl* CTS); Parameter* WalkParameter(const clang::ParmVarDecl* PVD, const clang::SourceLocation& ParamStartLoc); void SetBody(const clang::FunctionDecl* FD, Function* F); diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 45d540f5bf..6c21df1f6b 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -361,13 +361,11 @@ private void GenerateNestedInternals(string name, IEnumerable nestedClass private IEnumerable GetGeneratedClasses( Class dependentClass, IEnumerable specializedClasses) { - var specialization = specializedClasses.FirstOrDefault(s => s.IsGenerated) ?? - specializedClasses.First(); - if (dependentClass.HasDependentValueFieldInLayout()) - return specializedClasses; + return specializedClasses.KeepSingleAllPointersSpecialization(); - return new[] { specialization }; + return new[] { specializedClasses.FirstOrDefault(s => s.IsGenerated) ?? + specializedClasses.First()}; } public override void GenerateDeclarationCommon(Declaration decl) diff --git a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs index da08ec4945..fb01dbde06 100644 --- a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs +++ b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs @@ -24,11 +24,9 @@ public static void GenerateNativeConstructorsByValue( var printedClass = @class.Visit(gen.TypePrinter); if (@class.IsDependent) { - IEnumerable specializations = - @class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated); - if (@class.IsTemplate) - specializations = specializations.KeepSingleAllPointersSpecialization(); - foreach (var specialization in specializations) + foreach (var specialization in (from s in @class.GetSpecializedClassesToGenerate() + where s.IsGenerated + select s).KeepSingleAllPointersSpecialization()) gen.GenerateNativeConstructorByValue(specialization, printedClass); } else @@ -40,10 +38,10 @@ public static void GenerateNativeConstructorsByValue( public static IEnumerable KeepSingleAllPointersSpecialization( this IEnumerable specializations) { - Func allPointers = (TemplateArgument a) => - a.Type.Type?.Desugar().IsAddress() == true; - var groups = (from ClassTemplateSpecialization spec in specializations - group spec by spec.Arguments.All(allPointers) + static bool allPointers(TemplateArgument a) => a.Type.Type?.Desugar().IsAddress() == true; + var groups = (from @class in specializations + let spec = @class.GetParentSpecialization() + group @class by spec.Arguments.All(allPointers) into @group select @group).ToList(); foreach (var group in groups) diff --git a/tests/CSharp/CSharpTemplates.cpp b/tests/CSharp/CSharpTemplates.cpp index 2ae88924ae..898454e852 100644 --- a/tests/CSharp/CSharpTemplates.cpp +++ b/tests/CSharp/CSharpTemplates.cpp @@ -109,8 +109,12 @@ void forceUseSpecializations(IndependentFields _1, IndependentFields VirtualTemplate _6, VirtualTemplate _7, HasDefaultTemplateArgument _8, DerivedChangesTypeName _9, TemplateWithIndexer _10, TemplateWithIndexer _11, - TemplateWithIndexer _12, TemplateDerivedFromRegularDynamic _13, - IndependentFields > _14, std::string s) + TemplateWithIndexer _12, TemplateWithIndexer _13, + TemplateDerivedFromRegularDynamic _14, + IndependentFields> _15, + DependentPointerFields _16, IndependentFields _17, + TemplateWithIndexer _18, IndependentFields _19, + TemplateWithIndexer _20, std::string s) { }