Skip to content

Commit

Permalink
Instantiate specialized classes nested in templates
Browse files Browse the repository at this point in the history
Signed-off-by: Dimitar Dobrev <[email protected]>
  • Loading branch information
ddobrev committed Oct 22, 2021
1 parent eca0db1 commit 2da9e49
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 20 deletions.
14 changes: 14 additions & 0 deletions src/AST/ClassExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,20 @@ public static Class GetInterface(this Class @class)
return @interface;
}

public static ClassTemplateSpecialization GetParentSpecialization(this Class @class)
{
Class currentClass = @class;
do
{
if (currentClass is ClassTemplateSpecialization specialization)
{
return specialization;
}
currentClass = currentClass.Namespace as Class;
} while (currentClass != null);
return null;
}

public static bool HasDependentValueFieldInLayout(this Class @class)
{
if (@class.Fields.Any(f => IsValueDependent(f.Type)))
Expand Down
33 changes: 29 additions & 4 deletions src/CppParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <clang/Parse/ParseAST.h>
#include <clang/Sema/Sema.h>
#include <clang/Sema/SemaConsumer.h>
#include <clang/Sema/Template.h>
#include <clang/Frontend/Utils.h>
#include <clang/Driver/Driver.h>
#include <clang/Driver/ToolChain.h>
Expand Down Expand Up @@ -3053,8 +3054,7 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType)
RD = const_cast<CXXRecordDecl*>(Type->getPointeeCXXRecordDecl());
ClassTemplateSpecializationDecl* CTS;
if (!RD ||
!(CTS = llvm::dyn_cast<ClassTemplateSpecializationDecl>(RD)) ||
CTS->isCompleteDefinition())
!(CTS = llvm::dyn_cast<ClassTemplateSpecializationDecl>(RD)))
return;

auto existingClient = c->getSema().getDiagnostics().getClient();
Expand All @@ -3065,8 +3065,7 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType)
Scope Scope(nullptr, Scope::ScopeFlags::ClassScope, c->getSema().getDiagnostics());
c->getSema().TUScope = &Scope;

c->getSema().InstantiateClassTemplateSpecialization(CTS->getBeginLoc(),
CTS, TSK_ImplicitInstantiation, false);
InstantiateSpecialization(CTS);

c->getSema().getDiagnostics().setClient(existingClient, false);
c->getSema().TUScope = nullptr;
Expand All @@ -3082,6 +3081,32 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType)
}
}

void Parser::InstantiateSpecialization(clang::ClassTemplateSpecializationDecl* CTS)
{
using namespace clang;

if (!CTS->isCompleteDefinition())
{
c->getSema().InstantiateClassTemplateSpecialization(CTS->getBeginLoc(),
CTS, TSK_ImplicitInstantiation, false);
}

for (auto Decl : CTS->decls())
{
if (Decl->getKind() == Decl::Kind::CXXRecord)
{
CXXRecordDecl* Nested = cast<CXXRecordDecl>(Decl);
CXXRecordDecl* Template = Nested->getInstantiatedFromMemberClass();
if (Template && !Nested->isCompleteDefinition() && !Nested->hasDefinition())
{
c->getSema().InstantiateClass(Nested->getBeginLoc(), Nested, Template,
MultiLevelTemplateArgumentList(CTS->getTemplateArgs()),
TSK_ImplicitInstantiation, false);
}
}
}
}

Parameter* Parser::WalkParameter(const clang::ParmVarDecl* PVD,
const clang::SourceLocation& ParamStartLoc)
{
Expand Down
1 change: 1 addition & 0 deletions src/CppParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class Parser
std::string GetTypeName(const clang::Type* Type);
bool CanCheckCodeGenInfo(clang::Sema & S, const clang::Type * Ty);
void CompleteIfSpecializationType(const clang::QualType& QualType);
void InstantiateSpecialization(clang::ClassTemplateSpecializationDecl* CTS);
Parameter* WalkParameter(const clang::ParmVarDecl* PVD,
const clang::SourceLocation& ParamStartLoc);
void SetBody(const clang::FunctionDecl* FD, Function* F);
Expand Down
8 changes: 3 additions & 5 deletions src/Generator/Generators/CSharp/CSharpSources.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,11 @@ private void GenerateNestedInternals(string name, IEnumerable<Class> nestedClass
private IEnumerable<Class> GetGeneratedClasses(
Class dependentClass, IEnumerable<Class> specializedClasses)
{
var specialization = specializedClasses.FirstOrDefault(s => s.IsGenerated) ??
specializedClasses.First();

if (dependentClass.HasDependentValueFieldInLayout())
return specializedClasses;
return specializedClasses.KeepSingleAllPointersSpecialization();

return new[] { specialization };
return new[] { specializedClasses.FirstOrDefault(s => s.IsGenerated) ??
specializedClasses.First()};
}

public override void GenerateDeclarationCommon(Declaration decl)
Expand Down
16 changes: 7 additions & 9 deletions src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ public static void GenerateNativeConstructorsByValue(
var printedClass = @class.Visit(gen.TypePrinter);
if (@class.IsDependent)
{
IEnumerable<Class> specializations =
@class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated);
if (@class.IsTemplate)
specializations = specializations.KeepSingleAllPointersSpecialization();
foreach (var specialization in specializations)
foreach (var specialization in (from s in @class.GetSpecializedClassesToGenerate()
where s.IsGenerated
select s).KeepSingleAllPointersSpecialization())
gen.GenerateNativeConstructorByValue(specialization, printedClass);
}
else
Expand All @@ -40,10 +38,10 @@ public static void GenerateNativeConstructorsByValue(
public static IEnumerable<Class> KeepSingleAllPointersSpecialization(
this IEnumerable<Class> specializations)
{
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
a.Type.Type?.Desugar().IsAddress() == true;
var groups = (from ClassTemplateSpecialization spec in specializations
group spec by spec.Arguments.All(allPointers)
static bool allPointers(TemplateArgument a) => a.Type.Type?.Desugar().IsAddress() == true;
var groups = (from @class in specializations
let spec = @class.GetParentSpecialization()
group @class by spec.Arguments.All(allPointers)
into @group
select @group).ToList();
foreach (var group in groups)
Expand Down
8 changes: 6 additions & 2 deletions tests/CSharp/CSharpTemplates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,12 @@ void forceUseSpecializations(IndependentFields<int> _1, IndependentFields<bool>
VirtualTemplate<int> _6, VirtualTemplate<bool> _7,
HasDefaultTemplateArgument<int, int> _8, DerivedChangesTypeName<T1> _9,
TemplateWithIndexer<int> _10, TemplateWithIndexer<T1> _11,
TemplateWithIndexer<T2*> _12, TemplateDerivedFromRegularDynamic<RegularDynamic> _13,
IndependentFields<OnlySpecialisedInTypeArg<double> > _14, std::string s)
TemplateWithIndexer<void*> _12, TemplateWithIndexer<UsedInTemplatedIndexer> _13,
TemplateDerivedFromRegularDynamic<RegularDynamic> _14,
IndependentFields<OnlySpecialisedInTypeArg<double>> _15,
DependentPointerFields<float> _16, IndependentFields<const T1&> _17,
TemplateWithIndexer<T2*> _18, IndependentFields<int(*)(int)> _19,
TemplateWithIndexer<const char*> _20, std::string s)
{
}

Expand Down

0 comments on commit 2da9e49

Please sign in to comment.