Skip to content

Commit 810a7f9

Browse files
Add BStrStringMarshaller to source generator (#69213)
* Add BStrStringMarshaller to source generator * Convert to use void* for BStr and Utf16 marshallers' native types. Co-authored-by: Jan Kotas <[email protected]>
1 parent de27aa1 commit 810a7f9

File tree

9 files changed

+307
-26
lines changed

9 files changed

+307
-26
lines changed

src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems

+1
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@
866866
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\MarshalDirectiveException.cs" />
867867
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\AnsiStringMarshaller.cs" />
868868
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\ArrayMarshaller.cs" />
869+
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\BStrStringMarshaller.cs" />
869870
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerAttribute.cs" />
870871
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerDirection.cs" />
871872
<Compile Include="$(MSBuildThisFileDirectory)System\Runtime\InteropServices\Marshalling\CustomTypeMarshallerFeatures.cs" />
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Diagnostics;
5+
using System.Runtime.CompilerServices;
6+
using System.Text;
7+
8+
namespace System.Runtime.InteropServices.Marshalling
9+
{
10+
/// <summary>
11+
/// Marshaller for BSTR strings
12+
/// </summary>
13+
[CLSCompliant(false)]
14+
[CustomTypeMarshaller(typeof(string), BufferSize = 0x100,
15+
Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer)]
16+
public unsafe ref struct BStrStringMarshaller
17+
{
18+
private void* _ptrToFirstChar;
19+
private bool _allocated;
20+
21+
/// <summary>
22+
/// Initializes a new instance of the <see cref="BStrStringMarshaller"/>.
23+
/// </summary>
24+
/// <param name="str">The string to marshal.</param>
25+
public BStrStringMarshaller(string? str)
26+
: this(str, default)
27+
{ }
28+
29+
/// <summary>
30+
/// Initializes a new instance of the <see cref="BStrStringMarshaller"/>.
31+
/// </summary>
32+
/// <param name="str">The string to marshal.</param>
33+
/// <param name="buffer">Buffer that may be used for marshalling.</param>
34+
/// <remarks>
35+
/// The <paramref name="buffer"/> must not be movable - that is, it should not be
36+
/// on the managed heap or it should be pinned.
37+
/// <seealso cref="CustomTypeMarshallerFeatures.CallerAllocatedBuffer"/>
38+
/// </remarks>
39+
public BStrStringMarshaller(string? str, Span<ushort> buffer)
40+
{
41+
_allocated = false;
42+
43+
if (str is null)
44+
{
45+
_ptrToFirstChar = null;
46+
return;
47+
}
48+
49+
ushort* ptrToFirstChar;
50+
int lengthInBytes = checked(sizeof(char) * str.Length);
51+
52+
// A caller provided buffer must be at least (lengthInBytes + 6) bytes
53+
// in order to be constructed manually. The 6 extra bytes are 4 for byte length and 2 for wide null.
54+
int manualBstrNeeds = checked(lengthInBytes + 6);
55+
if (manualBstrNeeds > buffer.Length)
56+
{
57+
// Use precise byte count when the provided stack-allocated buffer is not sufficient
58+
ptrToFirstChar = (ushort*)Marshal.AllocBSTRByteLen((uint)lengthInBytes);
59+
_allocated = true;
60+
}
61+
else
62+
{
63+
// Set length and update buffer target
64+
byte* pBuffer = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer));
65+
*((uint*)pBuffer) = (uint)lengthInBytes;
66+
ptrToFirstChar = (ushort*)(pBuffer + sizeof(uint));
67+
}
68+
69+
// Confirm the size is properly set for the allocated BSTR.
70+
Debug.Assert(lengthInBytes == Marshal.SysStringByteLen((IntPtr)ptrToFirstChar));
71+
72+
// Copy characters from the managed string
73+
str.CopyTo(new Span<char>(ptrToFirstChar, str.Length));
74+
ptrToFirstChar[str.Length] = '\0'; // null-terminate
75+
_ptrToFirstChar = ptrToFirstChar;
76+
}
77+
78+
/// <summary>
79+
/// Returns the native value representing the string.
80+
/// </summary>
81+
/// <remarks>
82+
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
83+
/// </remarks>
84+
public void* ToNativeValue() => _ptrToFirstChar;
85+
86+
/// <summary>
87+
/// Sets the native value representing the string.
88+
/// </summary>
89+
/// <param name="value">The native value.</param>
90+
/// <remarks>
91+
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
92+
/// </remarks>
93+
public void FromNativeValue(void* value)
94+
{
95+
_ptrToFirstChar = value;
96+
_allocated = true;
97+
}
98+
99+
/// <summary>
100+
/// Returns the managed string.
101+
/// </summary>
102+
/// <remarks>
103+
/// <seealso cref="CustomTypeMarshallerDirection.Out"/>
104+
/// </remarks>
105+
public string? ToManaged()
106+
{
107+
if (_ptrToFirstChar is null)
108+
return null;
109+
110+
return Marshal.PtrToStringBSTR((IntPtr)_ptrToFirstChar);
111+
}
112+
113+
/// <summary>
114+
/// Frees native resources.
115+
/// </summary>
116+
/// <remarks>
117+
/// <seealso cref="CustomTypeMarshallerFeatures.UnmanagedResources"/>
118+
/// </remarks>
119+
public void FreeNative()
120+
{
121+
if (_allocated)
122+
Marshal.FreeBSTR((IntPtr)_ptrToFirstChar);
123+
}
124+
}
125+
}

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace System.Runtime.InteropServices.Marshalling
1313
Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling)]
1414
public unsafe ref struct Utf16StringMarshaller
1515
{
16-
private ushort* _nativeValue;
16+
private void* _nativeValue;
1717

1818
/// <summary>
1919
/// Initializes a new instance of the <see cref="Utf16StringMarshaller"/>.
@@ -25,7 +25,7 @@ public unsafe ref struct Utf16StringMarshaller
2525
/// <param name="str">The string to marshal.</param>
2626
public Utf16StringMarshaller(string? str)
2727
{
28-
_nativeValue = (ushort*)Marshal.StringToCoTaskMemUni(str);
28+
_nativeValue = (void*)Marshal.StringToCoTaskMemUni(str);
2929
}
3030

3131
/// <summary>
@@ -34,7 +34,7 @@ public Utf16StringMarshaller(string? str)
3434
/// <remarks>
3535
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
3636
/// </remarks>
37-
public ushort* ToNativeValue() => _nativeValue;
37+
public void* ToNativeValue() => _nativeValue;
3838

3939
/// <summary>
4040
/// Sets the native value representing the string.
@@ -43,7 +43,7 @@ public Utf16StringMarshaller(string? str)
4343
/// <remarks>
4444
/// <seealso cref="CustomTypeMarshallerFeatures.TwoStageMarshalling"/>
4545
/// </remarks>
46-
public void FromNativeValue(ushort* value) => _nativeValue = value;
46+
public void FromNativeValue(void* value) => _nativeValue = value;
4747

4848
/// <summary>
4949
/// Returns the managed string.

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs

+19-20
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
namespace Microsoft.Interop
1313
{
14-
1514
/// <summary>
1615
/// Type used to pass on default marshalling details.
1716
/// </summary>
@@ -72,7 +71,6 @@ public enum CharEncoding
7271
Undefined,
7372
Utf8,
7473
Utf16,
75-
Ansi,
7674
Custom
7775
}
7876

@@ -761,7 +759,13 @@ private bool TryCreateTypeBasedMarshallingInfo(
761759
}
762760
else
763761
{
764-
marshallingInfo = CreateStringMarshallingInfo(type, _defaultInfo.CharEncoding);
762+
marshallingInfo = _defaultInfo.CharEncoding switch
763+
{
764+
CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller),
765+
CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller),
766+
_ => throw new InvalidOperationException()
767+
};
768+
765769
return true;
766770
}
767771

@@ -842,30 +846,25 @@ private MarshallingInfo CreateStringMarshallingInfo(
842846
ITypeSymbol type,
843847
UnmanagedType unmanagedType)
844848
{
845-
CharEncoding charEncoding = unmanagedType switch
849+
string? marshallerName = unmanagedType switch
846850
{
847-
UnmanagedType.LPStr => CharEncoding.Ansi,
848-
UnmanagedType.LPTStr or UnmanagedType.LPWStr => CharEncoding.Utf16,
849-
MarshalAsInfo.UnmanagedType_LPUTF8Str => CharEncoding.Utf8,
850-
_ => CharEncoding.Undefined
851+
UnmanagedType.BStr => TypeNames.BStrStringMarshaller,
852+
UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller,
853+
UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller,
854+
MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller,
855+
_ => null
851856
};
852-
if (charEncoding == CharEncoding.Undefined)
857+
858+
if (marshallerName is null)
853859
return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding);
854860

855-
return CreateStringMarshallingInfo(type, charEncoding);
861+
return CreateStringMarshallingInfo(type, marshallerName);
856862
}
857863

858864
private MarshallingInfo CreateStringMarshallingInfo(
859865
ITypeSymbol type,
860-
CharEncoding charEncoding)
866+
string marshallerName)
861867
{
862-
string? marshallerName = charEncoding switch
863-
{
864-
CharEncoding.Ansi => TypeNames.AnsiStringMarshaller,
865-
CharEncoding.Utf16 => TypeNames.Utf16StringMarshaller,
866-
CharEncoding.Utf8 => TypeNames.Utf8StringMarshaller,
867-
_ => throw new InvalidOperationException()
868-
};
869868
INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName);
870869
if (stringMarshaller is null)
871870
return new MissingSupportMarshallingInfo();
@@ -876,9 +875,9 @@ private MarshallingInfo CreateStringMarshallingInfo(
876875
return CreateNativeMarshallingInfoForValue(
877876
type,
878877
stringMarshaller,
879-
default,
878+
null,
880879
customTypeMarshallerData.Value,
881-
allowPinningManagedType: charEncoding == CharEncoding.Utf16,
880+
allowPinningManagedType: marshallerName is TypeNames.Utf16StringMarshaller,
882881
useDefaultMarshalling: false);
883882
}
884883

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public static class TypeNames
1818
public const string CustomTypeMarshallerAttributeGenericPlaceholder = "System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder";
1919

2020
public const string AnsiStringMarshaller = "System.Runtime.InteropServices.Marshalling.AnsiStringMarshaller";
21+
public const string BStrStringMarshaller = "System.Runtime.InteropServices.Marshalling.BStrStringMarshaller";
2122
public const string Utf16StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller";
2223
public const string Utf8StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf8StringMarshaller";
2324

src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs

+16-2
Original file line numberDiff line numberDiff line change
@@ -2103,6 +2103,20 @@ public void FromNativeValue(byte* value) { }
21032103
public T[]? ToManaged() { throw null; }
21042104
public void FreeNative() { }
21052105
}
2106+
[System.CLSCompliant(false)]
2107+
[System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute(typeof(string), BufferSize = 0x100,
2108+
Features = System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.UnmanagedResources
2109+
| System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.CallerAllocatedBuffer
2110+
| System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.TwoStageMarshalling )]
2111+
public unsafe ref struct BStrStringMarshaller
2112+
{
2113+
public BStrStringMarshaller(string? str) { }
2114+
public BStrStringMarshaller(string? str, System.Span<ushort> buffer) { }
2115+
public void* ToNativeValue() { throw null; }
2116+
public void FromNativeValue(void* value) { }
2117+
public string? ToManaged() { throw null; }
2118+
public void FreeNative() { }
2119+
}
21062120
[System.AttributeUsageAttribute(System.AttributeTargets.Struct)]
21072121
public sealed partial class CustomTypeMarshallerAttribute : System.Attribute
21082122
{
@@ -2197,8 +2211,8 @@ public void FreeNative() { }
21972211
public unsafe ref struct Utf16StringMarshaller
21982212
{
21992213
public Utf16StringMarshaller(string? str) { }
2200-
public ushort* ToNativeValue() { throw null; }
2201-
public void FromNativeValue(ushort* value) { }
2214+
public void* ToNativeValue() { throw null; }
2215+
public void FromNativeValue(void* value) { }
22022216
public string? ToManaged() { throw null; }
22032217
public void FreeNative() { }
22042218
}

src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs

+77
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ private class EntryPoints
2323

2424
private const string UShortSuffix = "_ushort";
2525
private const string ByteSuffix = "_byte";
26+
private const string BStrSuffix = "_bstr";
2627

2728
public class Byte
2829
{
@@ -41,6 +42,15 @@ public class UShort
4142
public const string ReverseInplace = EntryPoints.ReverseInplace + UShortSuffix;
4243
public const string ReverseReplace = EntryPoints.ReverseReplace + UShortSuffix;
4344
}
45+
46+
public class BStr
47+
{
48+
public const string ReturnLength = EntryPoints.ReturnLength + BStrSuffix;
49+
public const string ReverseReturn = EntryPoints.ReverseReturn + BStrSuffix;
50+
public const string ReverseOut = EntryPoints.ReverseOut + BStrSuffix;
51+
public const string ReverseInplace = EntryPoints.ReverseInplace + BStrSuffix;
52+
public const string ReverseReplace = EntryPoints.ReverseReplace + BStrSuffix;
53+
}
4454
}
4555

4656
public partial class Utf16
@@ -185,6 +195,31 @@ public partial class LPStr
185195
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPStr)] ref string s);
186196
}
187197

198+
public partial class BStr
199+
{
200+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength)]
201+
public static partial int ReturnLength([MarshalAs(UnmanagedType.BStr)] string s);
202+
203+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength, StringMarshalling = StringMarshalling.Utf16)]
204+
public static partial int ReturnLength_IgnoreStringMarshalling([MarshalAs(UnmanagedType.BStr)] string s);
205+
206+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReturn)]
207+
[return: MarshalAs(UnmanagedType.BStr)]
208+
public static partial string Reverse_Return([MarshalAs(UnmanagedType.BStr)] string s);
209+
210+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseOut)]
211+
public static partial void Reverse_Out([MarshalAs(UnmanagedType.BStr)] string s, [MarshalAs(UnmanagedType.BStr)] out string ret);
212+
213+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
214+
public static partial void Reverse_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);
215+
216+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)]
217+
public static partial void Reverse_In([MarshalAs(UnmanagedType.BStr)] in string s);
218+
219+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReplace)]
220+
public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.BStr)] ref string s);
221+
}
222+
188223
public partial class StringMarshallingCustomType
189224
{
190225
public partial class Utf16
@@ -418,6 +453,48 @@ public void AnsiStringByRef(string value)
418453
Assert.Equal(expected, refValue);
419454
}
420455

456+
[Theory]
457+
[MemberData(nameof(UnicodeStrings))]
458+
public void BStrStringMarshalledAsExpected(string value)
459+
{
460+
int expectedLen = value != null ? value.Length : -1;
461+
462+
Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength(value));
463+
Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength_IgnoreStringMarshalling(value));
464+
}
465+
466+
[Theory]
467+
[MemberData(nameof(UnicodeStrings))]
468+
public void BStrStringReturn(string value)
469+
{
470+
string expected = ReverseChars(value);
471+
472+
Assert.Equal(expected, NativeExportsNE.BStr.Reverse_Return(value));
473+
474+
string ret;
475+
NativeExportsNE.BStr.Reverse_Out(value, out ret);
476+
Assert.Equal(expected, ret);
477+
}
478+
479+
[Theory]
480+
[MemberData(nameof(UnicodeStrings))]
481+
public void BStrStringByRef(string value)
482+
{
483+
string refValue = value;
484+
string expected = ReverseChars(value);
485+
486+
NativeExportsNE.BStr.Reverse_In(in refValue);
487+
Assert.Equal(value, refValue); // Should not be updated when using 'in'
488+
489+
refValue = value;
490+
NativeExportsNE.BStr.Reverse_Ref(ref refValue);
491+
Assert.Equal(expected, refValue);
492+
493+
refValue = value;
494+
NativeExportsNE.BStr.Reverse_Replace_Ref(ref refValue);
495+
Assert.Equal(expected, refValue);
496+
}
497+
421498
[Theory]
422499
[MemberData(nameof(UnicodeStrings))]
423500
public void StringMarshallingCustomType_MarshalledAsExpected(string value)

0 commit comments

Comments
 (0)