Skip to content

Commit 9ca69dd

Browse files
committed
Use static-abstracts + DIMs to remove the extra shape validation.
1 parent 7182f4b commit 9ca69dd

File tree

8 files changed

+84
-62
lines changed

8 files changed

+84
-62
lines changed

docs/design/libraries/ComInterfaceGenerator/VTableStubs.md

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public class VirtualMethodIndexAttribute : Attribute
5757

5858
```
5959

60-
A new interface will be defined and used by the source generator to fetch the native `this` pointer and the vtable that the function pointer is stored in. This interface is designed to provide an API that various native platforms, like COM, WinRT, or Swift, could use to provide support for multiple managed interface wrappers from a single native object. In particular, this interface was designed to ensure it is possible support a managed gesture to do an unmanaged "type cast" (i.e., `QueryInterface` in the COM and WinRT worlds).
60+
New interfaces will be defined and used by the source generator to fetch the native `this` pointer and the vtable that the function pointer is stored in. These interfaces are designed to provide an API that various native platforms, like COM, WinRT, or Swift, could use to provide support for multiple managed interface wrappers from a single native object. In particular, these interfaces are designed to ensure it is possible support a managed gesture to do an unmanaged "type cast" (i.e., `QueryInterface` in the COM and WinRT worlds).
6161

6262
```csharp
6363
namespace System.Runtime.InteropServices;
@@ -82,13 +82,24 @@ public readonly ref struct VirtualMethodTableInfo
8282

8383
public interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
8484
{
85-
VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
85+
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
86+
87+
public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
88+
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<T>
89+
{
90+
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
91+
}
92+
}
93+
94+
public interface IUnmanagedInterfaceType<T> where T : IEquatable<T>
95+
{
96+
public abstract static T TypeKey { get; }
8697
}
8798
```
8899

89100
## Required API Shapes
90101

91-
In addition to the provided APIs above, users will be required to add a `readonly static` field or `get`-able property to their user-defined interface type named `TypeKey`. The type of this member will be used as the `T` in `IUnmanagedVirtualMethodTableProvider<T>` and the value will be passed to `GetVirtualMethodTableInfoForKey`. This mechanism is designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`.
102+
The user will be required to implement `IUnmanagedVirtualMethodTableProvider<T>` on the type that provides the method tables, and `IUnmanagedInterfaceType<T>` on the type that defines the unmanaged interface. The `T` types must match between the two interfaces. This mechanism is designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`.
92103

93104
## Example Usage
94105

@@ -149,11 +160,11 @@ using System.Runtime.InteropServices;
149160
[assembly:DisableRuntimeMarshalling]
150161
151162
// Define the interface of the native API
152-
partial interface INativeAPI
163+
partial interface INativeAPI : IUnmanagedInterfaceType<NoCasting>
153164
{
154165
// There is no concept of casting for this API, but providing a type key is still required by the generator.
155166
// Use an empty readonly record struct to provide a type that implements IEquatable<T> but contains no data.
156-
readonly static NoCasting TypeKey = default;
167+
static NoCasting IUnmanagedInterfaceType.TypeKey => default;
157168
158169
[VirtualMethodIndex(0, ImplicitThisParameter = false, Direction = CustomTypeMarshallerDirection.In)]
159170
int GetVersion();
@@ -218,7 +229,7 @@ partial interface INativeAPI
218229
{
219230
int INativeAPI.GetVersion()
220231
{
221-
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey(INativeAPI.TypeKey);
232+
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
222233
int retVal;
223234
retVal = ((delegate* unmanaged<int>)vtable[0])();
224235
return retVal;
@@ -231,7 +242,7 @@ partial interface INativeAPI
231242
{
232243
int INativeAPI.Add(int x, int y)
233244
{
234-
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey(INativeAPI.TypeKey);
245+
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
235246
int retVal;
236247
retVal = ((delegate* unmanaged<int, int, int>)vtable[1])(x, y);
237248
return retVal;
@@ -244,7 +255,7 @@ partial interface INativeAPI
244255
{
245256
int INativeAPI.Multiply(int x, int y)
246257
{
247-
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey(INativeAPI.TypeKey);
258+
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
248259
int retVal;
249260
retVal = ((delegate* unmanaged<int, int, int>)vtable[2])(x, y);
250261
return retVal;
@@ -279,9 +290,9 @@ struct IUnknown
279290
using System;
280291
using System.Runtime.InteropServices;
281292

282-
interface IUnknown
293+
interface IUnknown: IUnmanagedInterfaceType<Guid>
283294
{
284-
public static readonly Guid TypeKey = Guid.Parse("00000000-0000-0000-C000-000000000046");
295+
static Guid IUnmanagedTypeInterfaceType<Guid>.TypeKey => Guid.Parse("00000000-0000-0000-C000-000000000046");
285296

286297
[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvStdcall), typeof(CallConvMemberFunction) })]
287298
[VirtualMethodIndex(0)]
@@ -347,7 +358,7 @@ partial interface IUnknown
347358
{
348359
int IUnknown.QueryInterface(in Guid riid, out IntPtr ppvObject)
349360
{
350-
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey(IUnknown.TypeKey);
361+
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
351362
int retVal;
352363
fixed (Guid* riid__gen_native = &riid)
353364
fixed (IntPtr* ppvObject__gen_native = &ppvObject)
@@ -364,7 +375,7 @@ partial interface IUnknown
364375
{
365376
uint IUnknown.AddRef()
366377
{
367-
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey(IUnknown.TypeKey);
378+
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
368379
uint retVal;
369380
retVal = ((delegate* unmanaged[Stdcall, MemberFunction]<IntPtr, uint>)vtable[1])(thisPtr);
370381
return retVal;
@@ -377,7 +388,7 @@ partial interface IUnknown
377388
{
378389
uint IUnknown.Release()
379390
{
380-
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey(IUnknown.TypeKey);
391+
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
381392
uint retVal;
382393
retVal = ((delegate* unmanaged[Stdcall, MemberFunction]<IntPtr, uint>)vtable[2])(thisPtr);
383394
return retVal;

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnm
117117
{
118118
var setupStatements = new List<StatementSyntax>
119119
{
120-
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider<<typeKeyType>>)this).GetVirtualMethodTableInfoForKey(<containingTypeName>.TypeKey)
120+
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider<<typeKeyType>>)this).GetVirtualMethodTableInfoForKey<<containingTypeName>>();
121121
ExpressionStatement(
122122
AssignmentExpression(
123123
SyntaxKind.SimpleAssignmentExpression,
@@ -141,15 +141,12 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnm
141141
TypeArgumentList(
142142
SingletonSeparatedList(typeKeyType.Syntax))),
143143
ThisExpression())),
144-
IdentifierName("GetVirtualMethodTableInfoForKey")))
144+
GenericName(
145+
Identifier("GetVirtualMethodTableInfoForKey"),
146+
TypeArgumentList(
147+
SingletonSeparatedList(containingTypeName)))))
145148
.WithArgumentList(
146-
ArgumentList(
147-
SingletonSeparatedList(
148-
Argument(
149-
MemberAccessExpression(
150-
SyntaxKind.SimpleMemberAccessExpression,
151-
containingTypeName,
152-
IdentifierName("TypeKey"))))))))
149+
ArgumentList())))
153150
};
154151

155152
GeneratedStatements statements = GeneratedStatements.Create(

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
228228
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
229229
INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
230230
INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute);
231+
INamedTypeSymbol iUnmanagedInterfaceTypeType = environment.Compilation.GetTypeByMetadataName(TypeNames.IUnmanagedInterfaceType_Metadata)!;
231232
// Get any attributes of interest on the method
232233
AttributeData? virtualMethodIndexAttr = null;
233234
AttributeData? lcidConversionAttr = null;
@@ -310,14 +311,14 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
310311
var typeKeyOwner = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
311312
ManagedTypeInfo typeKeyType = SpecialTypeInfo.Byte;
312313

313-
IFieldSymbol? typeKeyField = symbol.ContainingType.GetMembers("TypeKey").OfType<IFieldSymbol>().FirstOrDefault(f => f.IsStatic);
314-
if (typeKeyField is null)
314+
INamedTypeSymbol? iUnmanagedInterfaceTypeInstantiation = symbol.ContainingType.AllInterfaces.FirstOrDefault(iface => SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, iUnmanagedInterfaceTypeType));
315+
if (iUnmanagedInterfaceTypeInstantiation is null)
315316
{
316317
// Report invalid configuration
317318
}
318319
else
319320
{
320-
typeKeyType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(typeKeyField.Type);
321+
typeKeyType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iUnmanagedInterfaceTypeInstantiation.TypeArguments[0]);
321322
}
322323

323324
return new IncrementalStubGenerationContext(

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ public static class TypeNames
3232

3333
public const string IUnmanagedVirtualMethodTableProvider = "System.Runtime.InteropServices.IUnmanagedVirtualMethodTableProvider";
3434

35+
public const string IUnmanagedInterfaceType_Metadata = "System.Runtime.InteropServices.IUnmanagedInterfaceType`1";
36+
3537
public const string System_Span_Metadata = "System.Span`1";
3638
public const string System_Span = "System.Span";
3739
public const string System_ReadOnlySpan_Metadata = "System.ReadOnlySpan`1";

src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnmanagedVirtualMethodTableProvider.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,18 @@ public void Deconstruct(out IntPtr thisPointer, out ReadOnlySpan<IntPtr> virtual
2929

3030
public interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
3131
{
32-
VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
32+
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
33+
34+
public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
35+
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<T>
36+
{
37+
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
38+
}
3339
}
3440

41+
42+
public interface IUnmanagedInterfaceType<T> where T : IEquatable<T>
43+
{
44+
public abstract static T TypeKey { get; }
45+
}
3546
}

src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ internal partial class ImplicitThis
1818
{
1919
public readonly record struct NoCasting;
2020

21-
internal partial interface INativeObject
21+
internal partial interface INativeObject : IUnmanagedInterfaceType<NoCasting>
2222
{
23-
public static readonly NoCasting TypeKey = default;
23+
static NoCasting IUnmanagedInterfaceType<NoCasting>.TypeKey => default;
2424

2525
[VirtualMethodIndex(0, ImplicitThisParameter = true)]
2626
int GetData();

src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/NoImplicitThisTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ internal partial class NoImplicitThis
1818
{
1919
public readonly record struct NoCasting;
2020

21-
internal partial interface IStaticMethodTable
21+
internal partial interface IStaticMethodTable : IUnmanagedInterfaceType<NoCasting>
2222
{
23-
public static readonly NoCasting TypeKey = default;
23+
static NoCasting IUnmanagedInterfaceType<NoCasting>.TypeKey => default;
2424

2525
[VirtualMethodIndex(0, ImplicitThisParameter = false)]
2626
int Add(int x, int y);

0 commit comments

Comments
 (0)