Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly call cctor and destructor (on non-itanium ABIs) when calling functions with copy-by-value parameters #1699 #1766

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/Generator/Generators/CSharp/CSharpMarshal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -768,11 +768,43 @@ private void MarshalRefClass(Class @class)
{
if (Context.Parameter.IsIndirect)
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
Context.Return.Write(paramInstance);
Method cctor = @class.HasNonTrivialCopyConstructor ? @class.Methods.First(c => c.IsCopyConstructor) : null;
if (cctor != null && cctor.IsGenerated)
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");

var nativeClass = typePrinter.PrintNative(@class);

var cctorName = CSharpSources.GetFunctionNativeIdentifier(Context.Context, cctor);

var defaultValue = "";
var TypePrinter = new CSharpTypePrinter(Context.Context);
var ExpressionPrinter = new CSharpExpressionPrinter(TypePrinter);
if (cctor.Parameters.Count > 1)
defaultValue = $", {ExpressionPrinter.VisitParameter(cctor.Parameters.Last())}";

Context.Before.WriteLine($"byte* __{Context.Parameter.Name}Memory = stackalloc byte[sizeof({nativeClass})];");
Context.Before.WriteLine($"__IntPtr __{Context.Parameter.Name}Ptr = (__IntPtr)__{Context.Parameter.Name}Memory;");
Context.Before.WriteLine($"{nativeClass}.{cctorName}(__{Context.Parameter.Name}Ptr, {Context.Parameter.Name}.__Instance{defaultValue});");
Context.Return.Write($"__{Context.Parameter.Name}Ptr");

if (Context.Context.ParserOptions.IsItaniumLikeAbi && @class.HasNonTrivialDestructor)
{
Method dtor = @class.Destructors.FirstOrDefault();
if (dtor != null)
{
// todo: virtual destructors?
Context.Cleanup.WriteLine($"{nativeClass}.dtor(__{Context.Parameter.Name}Ptr);");
}
}
}
else
{
Context.Return.Write(paramInstance);
}
}
else
{
Expand Down
10 changes: 8 additions & 2 deletions src/Generator/Generators/CSharp/CSharpSources.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3464,6 +3464,12 @@ public static string GetFunctionIdentifier(Function function)

public string GetFunctionNativeIdentifier(Function function,
bool isForDelegate = false)
{
return GetFunctionNativeIdentifier(Context, function, isForDelegate);
}

public static string GetFunctionNativeIdentifier(BindingContext context, Function function,
bool isForDelegate = false)
{
var identifier = new StringBuilder();

Expand Down Expand Up @@ -3494,12 +3500,12 @@ public string GetFunctionNativeIdentifier(Function function,
identifier.Append(Helpers.GetSuffixFor(specialization));

var internalParams = function.GatherInternalParams(
Context.ParserOptions.IsItaniumLikeAbi);
context.ParserOptions.IsItaniumLikeAbi);
var overloads = function.Namespace.GetOverloads(function)
.Where(f => (!f.Ignore ||
(f.OriginalFunction != null && !f.OriginalFunction.Ignore)) &&
(isForDelegate || internalParams.SequenceEqual(
f.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi),
f.GatherInternalParams(context.ParserOptions.IsItaniumLikeAbi),
new MarshallingParamComparer()))).ToList();
var index = -1;
if (overloads.Count > 1)
Expand Down
15 changes: 15 additions & 0 deletions tests/dotnet/CSharp/CSharp.Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1995,4 +1995,19 @@ public void TestPointerToClass()
Assert.IsTrue(CSharp.CSharp.PointerToClass.IsDefaultInstance);
Assert.IsTrue(CSharp.CSharp.PointerToClass.IsValid);
}

[Test]
public void TestCallByValueCopyConstructor()
{
using (var s = new CallByValueCopyConstructor())
{
s.A = 500;
CSharp.CSharp.CallByValueCopyConstructorFunction(s);
Assert.That(s.A, Is.EqualTo(500));
}

Assert.That(CallByValueCopyConstructor.ConstructorCalls, Is.EqualTo(1));
Assert.That(CallByValueCopyConstructor.CopyConstructorCalls, Is.EqualTo(1));
Assert.That(CallByValueCopyConstructor.DestructorCalls, Is.EqualTo(2));
}
}
26 changes: 26 additions & 0 deletions tests/dotnet/CSharp/CSharp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1791,3 +1791,29 @@ bool PointerTester::IsValid()
}

PointerTester* PointerToClass = &internalPointerTesterInstance;

int CallByValueCopyConstructor::constructorCalls = 0;
int CallByValueCopyConstructor::destructorCalls = 0;
int CallByValueCopyConstructor::copyConstructorCalls = 0;

CallByValueCopyConstructor::CallByValueCopyConstructor()
{
a = 0;
constructorCalls++;
}

CallByValueCopyConstructor::CallByValueCopyConstructor(const CallByValueCopyConstructor& other)
{
a = other.a;
copyConstructorCalls++;
}

CallByValueCopyConstructor::~CallByValueCopyConstructor()
{
destructorCalls++;
}

void CallByValueCopyConstructorFunction(CallByValueCopyConstructor s)
{
s.a = 99999;
}
13 changes: 13 additions & 0 deletions tests/dotnet/CSharp/CSharp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1603,3 +1603,16 @@ class DLL_API PointerTester
};

DLL_API extern PointerTester* PointerToClass;

struct DLL_API CallByValueCopyConstructor {
int a;
static int constructorCalls;
static int destructorCalls;
static int copyConstructorCalls;

CallByValueCopyConstructor();
~CallByValueCopyConstructor();
CallByValueCopyConstructor(const CallByValueCopyConstructor& other);
};

DLL_API void CallByValueCopyConstructorFunction(CallByValueCopyConstructor s);