Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
JeremyKuhne committed Nov 15, 2024
2 parents 3c2549c + 9023e72 commit c54532b
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/thirtytwo/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ CoGetClassObject
CombineRgn
COMBOBOXINFO_BUTTON_STATE
CopyImage
CoTaskMemAlloc
CoTaskMemFree
CountClipboardFormats
CreateActCtx
Expand Down Expand Up @@ -224,6 +225,7 @@ IMarshal
IModalWindow
InitCommonControlsEx
InitVariantFromDoubleArray
INoMarshal
INTERFACEDATA
INVALID_HANDLE_VALUE
InvalidateRect
Expand Down
80 changes: 80 additions & 0 deletions src/thirtytwo/Win32/System/Com/Lifetime.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Jeremy W. Kuhne. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Runtime.InteropServices;

namespace Windows.Win32.System.Com;

/// <summary>
/// Lifetime management helper for a COM callable wrapper. It holds the created <typeparamref name="TObject"/>
/// wrapper with he given <typeparamref name="TVTable"/>.
/// </summary>
/// <remarks>
/// <para>
/// This should not be created directly. Instead use <see cref="Lifetime{TVTable, TObject}.Allocate"/>.
/// </para>
/// <para>
/// A COM object's memory layout is a virtual function table (vtable) pointer followed by instance data. We're
/// effectively manually creating a COM object here that contains instance data of a GCHandle to the related
/// managed object and a ref count.
/// </para>
/// </remarks>
public unsafe struct Lifetime<TVTable, TObject> where TVTable : unmanaged
{
public TVTable* VTable;
public IUnknown* Handle;
public uint RefCount;

public static unsafe uint AddRef(IUnknown* @this)
=> Interlocked.Increment(ref ((Lifetime<TVTable, TObject>*)@this)->RefCount);

public static unsafe uint Release(IUnknown* @this)
{
var lifetime = (Lifetime<TVTable, TObject>*)@this;
Debug.Assert(lifetime->RefCount > 0);
uint count = Interlocked.Decrement(ref lifetime->RefCount);
if (count == 0)
{
GCHandle.FromIntPtr((nint)lifetime->Handle).Free();
Interop.CoTaskMemFree(lifetime);
}

return count;
}

/// <summary>
/// Allocate a lifetime wrapper for the given <paramref name="object"/> with the given
/// <paramref name="vtable"/>.
/// </summary>
/// <remarks>
/// <para>
/// This creates a <see cref="GCHandle"/> to root the <paramref name="object"/> until ref
/// counting has gone to zero.
/// </para>
/// <para>
/// The <paramref name="vtable"/> should be fixed, typically as a static. Com calls always
/// include the "this" pointer as the first argument.
/// </para>
/// </remarks>
public static unsafe Lifetime<TVTable, TObject>* Allocate(TObject @object, TVTable* vtable)
{
// Manually allocate a native instance of this struct.
var wrapper = (Lifetime<TVTable, TObject>*)Interop.CoTaskMemAlloc((nuint)sizeof(Lifetime<TVTable, TObject>));

// Assign a pointer to the vtable, allocate a GCHandle for the related object, and set the initial ref count.
wrapper->VTable = vtable;
wrapper->Handle = (IUnknown*)GCHandle.ToIntPtr(GCHandle.Alloc(@object));
wrapper->RefCount = 1;

return wrapper;
}

/// <summary>
/// Gets the object wrapped by a lifetime wrapper.
/// </summary>
public static TObject? GetObject(IUnknown* @this)
{
var lifetime = (Lifetime<TVTable, TObject>*)@this;
return (TObject?)GCHandle.FromIntPtr((nint)lifetime->Handle).Target;
}
}
145 changes: 145 additions & 0 deletions src/thirtytwo_tests/Win32/System/Com/ComTests.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// Copyright (c) Jeremy W. Kuhne. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Windows.Dialogs;
using Windows.Win32.Foundation;
using Windows.Win32.System.Com.Marshal;
using Windows.Win32.UI.Shell;
using InteropMarshal = global::System.Runtime.InteropServices.Marshal;

namespace Windows.Win32.System.Com;

Expand All @@ -27,4 +32,144 @@ public void Com_GetComPointer_SameInterfaceInstance()

Assert.True(iEvents1.Pointer == iEvents2.Pointer);
}

[Fact]
public void Com_BuiltInCom_RCW_Behavior()
{
UnknownTest unknown = new();
using ComScope<IUnknown> iUnknown = new(UnknownCCW.CreateInstance(unknown));

object rcw = InteropMarshal.GetObjectForIUnknown((IntPtr)iUnknown.Pointer);

unknown.AddRefCount.Should().Be(1);
unknown.ReleaseCount.Should().Be(1);
unknown.LastRefCount.Should().Be(2);
unknown.QueryInterfaceGuids.Should().BeEquivalentTo([
IUnknown.IID_Guid,
INoMarshal.IID_Guid,
IAgileObject.IID_Guid,
IMarshal.IID_Guid]);

// Release and FinalRelease look the same from our IUnknown's perspective
InteropMarshal.FinalReleaseComObject(rcw);

unknown.AddRefCount.Should().Be(1);
unknown.ReleaseCount.Should().Be(2);
unknown.LastRefCount.Should().Be(1);
unknown.QueryInterfaceGuids.Should().BeEquivalentTo([
IUnknown.IID_Guid,
INoMarshal.IID_Guid,
IAgileObject.IID_Guid,
IMarshal.IID_Guid]);
}

public interface IUnkownTest
{
public void QueryInterface(Guid riid);
public void AddRef(uint current);
public void Release(uint current);
}

public class UnknownTest : IUnkownTest
{
public int AddRefCount { get; private set; }
public int ReleaseCount { get; private set; }
public List<Guid> QueryInterfaceGuids { get; } = [];
public int LastRefCount { get; private set; }

void IUnkownTest.AddRef(uint current)
{
AddRefCount++;
LastRefCount = (int)current;
}

void IUnkownTest.QueryInterface(Guid riid)
{
QueryInterfaceGuids.Add(riid);
}

void IUnkownTest.Release(uint current)
{
ReleaseCount++;
LastRefCount = (int)current;
}
}

public static class UnknownCCW
{
public static unsafe IUnknown* CreateInstance(IUnkownTest @object)
=> (IUnknown*)Lifetime<IUnknown.Vtbl, IUnkownTest>.Allocate(@object, CCWVTable);

private static readonly IUnknown.Vtbl* CCWVTable = AllocateVTable();

private static unsafe IUnknown.Vtbl* AllocateVTable()
{
// Allocate and create a static VTable for this type projection.
var vtable = (IUnknown.Vtbl*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(UnknownCCW), sizeof(IUnknown.Vtbl));

// IUnknown
vtable->QueryInterface_1 = &QueryInterface;
vtable->AddRef_2 = &AddRef;
vtable->Release_3 = &Release;
return vtable;
}

[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
private static HRESULT QueryInterface(IUnknown* @this, Guid* riid, void** ppvObject)
{
if (ppvObject is null)
{
return HRESULT.E_POINTER;
}

var unknown = Lifetime<IUnknown.Vtbl, IUnkownTest>.GetObject(@this);
if (unknown is null)
{
return HRESULT.COR_E_OBJECTDISPOSED;
}

unknown.QueryInterface(*riid);

if (*riid == typeof(IUnknown).GUID)
{
*ppvObject = @this;
}
else
{
*ppvObject = null;
return HRESULT.E_NOINTERFACE;
}

Lifetime<IUnknown.Vtbl, IUnkownTest>.AddRef(@this);
return HRESULT.S_OK;
}

[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
private static uint AddRef(IUnknown* @this)
{
var unknown = Lifetime<IUnknown.Vtbl, IUnkownTest>.GetObject(@this);
if (unknown is null)
{
return HRESULT.COR_E_OBJECTDISPOSED;
}

uint current = Lifetime<IUnknown.Vtbl, IUnkownTest>.AddRef(@this);
unknown.AddRef(current);
return current;
}

[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
private static uint Release(IUnknown* @this)
{
var unknown = Lifetime<IUnknown.Vtbl, IUnkownTest>.GetObject(@this);
if (unknown is null)
{
return HRESULT.COR_E_OBJECTDISPOSED;
}

uint current = Lifetime<IUnknown.Vtbl, IUnkownTest>.Release(@this);
unknown.Release(current);
return current;
}
}
}

0 comments on commit c54532b

Please sign in to comment.