diff --git a/src/thirtytwo/NativeMethods.txt b/src/thirtytwo/NativeMethods.txt index dd5ba7c..9d0d814 100644 --- a/src/thirtytwo/NativeMethods.txt +++ b/src/thirtytwo/NativeMethods.txt @@ -22,6 +22,7 @@ CoGetClassObject CombineRgn COMBOBOXINFO_BUTTON_STATE CopyImage +CoTaskMemAlloc CoTaskMemFree CountClipboardFormats CreateActCtx @@ -224,6 +225,7 @@ IMarshal IModalWindow InitCommonControlsEx InitVariantFromDoubleArray +INoMarshal INTERFACEDATA INVALID_HANDLE_VALUE InvalidateRect diff --git a/src/thirtytwo/Win32/System/Com/Lifetime.cs b/src/thirtytwo/Win32/System/Com/Lifetime.cs new file mode 100644 index 0000000..cf132c7 --- /dev/null +++ b/src/thirtytwo/Win32/System/Com/Lifetime.cs @@ -0,0 +1,80 @@ +// Copyright (c) Jeremy W. Kuhne. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.InteropServices; + +namespace Windows.Win32.System.Com; + +/// +/// Lifetime management helper for a COM callable wrapper. It holds the created +/// wrapper with he given . +/// +/// +/// +/// This should not be created directly. Instead use . +/// +/// +/// A COM object's memory layout is a virtual function table (vtable) pointer followed by instance data. We're +/// effectively manually creating a COM object here that contains instance data of a GCHandle to the related +/// managed object and a ref count. +/// +/// +public unsafe struct Lifetime where TVTable : unmanaged +{ + public TVTable* VTable; + public IUnknown* Handle; + public uint RefCount; + + public static unsafe uint AddRef(IUnknown* @this) + => Interlocked.Increment(ref ((Lifetime*)@this)->RefCount); + + public static unsafe uint Release(IUnknown* @this) + { + var lifetime = (Lifetime*)@this; + Debug.Assert(lifetime->RefCount > 0); + uint count = Interlocked.Decrement(ref lifetime->RefCount); + if (count == 0) + { + GCHandle.FromIntPtr((nint)lifetime->Handle).Free(); + Interop.CoTaskMemFree(lifetime); + } + + return count; + } + + /// + /// Allocate a lifetime wrapper for the given with the given + /// . + /// + /// + /// + /// This creates a to root the until ref + /// counting has gone to zero. + /// + /// + /// The should be fixed, typically as a static. Com calls always + /// include the "this" pointer as the first argument. + /// + /// + public static unsafe Lifetime* Allocate(TObject @object, TVTable* vtable) + { + // Manually allocate a native instance of this struct. + var wrapper = (Lifetime*)Interop.CoTaskMemAlloc((nuint)sizeof(Lifetime)); + + // Assign a pointer to the vtable, allocate a GCHandle for the related object, and set the initial ref count. + wrapper->VTable = vtable; + wrapper->Handle = (IUnknown*)GCHandle.ToIntPtr(GCHandle.Alloc(@object)); + wrapper->RefCount = 1; + + return wrapper; + } + + /// + /// Gets the object wrapped by a lifetime wrapper. + /// + public static TObject? GetObject(IUnknown* @this) + { + var lifetime = (Lifetime*)@this; + return (TObject?)GCHandle.FromIntPtr((nint)lifetime->Handle).Target; + } +} \ No newline at end of file diff --git a/src/thirtytwo_tests/Win32/System/Com/ComTests.cs b/src/thirtytwo_tests/Win32/System/Com/ComTests.cs index b2ff04a..7d5ab72 100644 --- a/src/thirtytwo_tests/Win32/System/Com/ComTests.cs +++ b/src/thirtytwo_tests/Win32/System/Com/ComTests.cs @@ -1,8 +1,13 @@ // Copyright (c) Jeremy W. Kuhne. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using Windows.Dialogs; +using Windows.Win32.Foundation; +using Windows.Win32.System.Com.Marshal; using Windows.Win32.UI.Shell; +using InteropMarshal = global::System.Runtime.InteropServices.Marshal; namespace Windows.Win32.System.Com; @@ -27,4 +32,144 @@ public void Com_GetComPointer_SameInterfaceInstance() Assert.True(iEvents1.Pointer == iEvents2.Pointer); } + + [Fact] + public void Com_BuiltInCom_RCW_Behavior() + { + UnknownTest unknown = new(); + using ComScope iUnknown = new(UnknownCCW.CreateInstance(unknown)); + + object rcw = InteropMarshal.GetObjectForIUnknown((IntPtr)iUnknown.Pointer); + + unknown.AddRefCount.Should().Be(1); + unknown.ReleaseCount.Should().Be(1); + unknown.LastRefCount.Should().Be(2); + unknown.QueryInterfaceGuids.Should().BeEquivalentTo([ + IUnknown.IID_Guid, + INoMarshal.IID_Guid, + IAgileObject.IID_Guid, + IMarshal.IID_Guid]); + + // Release and FinalRelease look the same from our IUnknown's perspective + InteropMarshal.FinalReleaseComObject(rcw); + + unknown.AddRefCount.Should().Be(1); + unknown.ReleaseCount.Should().Be(2); + unknown.LastRefCount.Should().Be(1); + unknown.QueryInterfaceGuids.Should().BeEquivalentTo([ + IUnknown.IID_Guid, + INoMarshal.IID_Guid, + IAgileObject.IID_Guid, + IMarshal.IID_Guid]); + } + + public interface IUnkownTest + { + public void QueryInterface(Guid riid); + public void AddRef(uint current); + public void Release(uint current); + } + + public class UnknownTest : IUnkownTest + { + public int AddRefCount { get; private set; } + public int ReleaseCount { get; private set; } + public List QueryInterfaceGuids { get; } = []; + public int LastRefCount { get; private set; } + + void IUnkownTest.AddRef(uint current) + { + AddRefCount++; + LastRefCount = (int)current; + } + + void IUnkownTest.QueryInterface(Guid riid) + { + QueryInterfaceGuids.Add(riid); + } + + void IUnkownTest.Release(uint current) + { + ReleaseCount++; + LastRefCount = (int)current; + } + } + + public static class UnknownCCW + { + public static unsafe IUnknown* CreateInstance(IUnkownTest @object) + => (IUnknown*)Lifetime.Allocate(@object, CCWVTable); + + private static readonly IUnknown.Vtbl* CCWVTable = AllocateVTable(); + + private static unsafe IUnknown.Vtbl* AllocateVTable() + { + // Allocate and create a static VTable for this type projection. + var vtable = (IUnknown.Vtbl*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(UnknownCCW), sizeof(IUnknown.Vtbl)); + + // IUnknown + vtable->QueryInterface_1 = &QueryInterface; + vtable->AddRef_2 = &AddRef; + vtable->Release_3 = &Release; + return vtable; + } + + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])] + private static HRESULT QueryInterface(IUnknown* @this, Guid* riid, void** ppvObject) + { + if (ppvObject is null) + { + return HRESULT.E_POINTER; + } + + var unknown = Lifetime.GetObject(@this); + if (unknown is null) + { + return HRESULT.COR_E_OBJECTDISPOSED; + } + + unknown.QueryInterface(*riid); + + if (*riid == typeof(IUnknown).GUID) + { + *ppvObject = @this; + } + else + { + *ppvObject = null; + return HRESULT.E_NOINTERFACE; + } + + Lifetime.AddRef(@this); + return HRESULT.S_OK; + } + + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])] + private static uint AddRef(IUnknown* @this) + { + var unknown = Lifetime.GetObject(@this); + if (unknown is null) + { + return HRESULT.COR_E_OBJECTDISPOSED; + } + + uint current = Lifetime.AddRef(@this); + unknown.AddRef(current); + return current; + } + + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])] + private static uint Release(IUnknown* @this) + { + var unknown = Lifetime.GetObject(@this); + if (unknown is null) + { + return HRESULT.COR_E_OBJECTDISPOSED; + } + + uint current = Lifetime.Release(@this); + unknown.Release(current); + return current; + } + } }