diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs index 856a7b4d1c2142..66b7c5f2f85442 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs @@ -146,18 +146,20 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName) var tryStatements = new List(); tryStatements.AddRange(statements.Marshal); - var invokeStatement = statements.InvokeStatement; + BlockSyntax fixedBlock = Block(statements.PinnedMarshal); if (_setLastError) { StatementSyntax clearLastError = MarshallerHelpers.CreateClearLastSystemErrorStatement(SuccessErrorCode); StatementSyntax getLastError = MarshallerHelpers.CreateGetLastSystemErrorStatement(LastErrorIdentifier); - invokeStatement = Block(clearLastError, invokeStatement, getLastError); + fixedBlock = fixedBlock.AddStatements(clearLastError, statements.InvokeStatement, getLastError); } - invokeStatement = statements.Pin.NestFixedStatements(invokeStatement); - - tryStatements.Add(invokeStatement); + else + { + fixedBlock = fixedBlock.AddStatements(statements.InvokeStatement); + } + tryStatements.Add(statements.Pin.NestFixedStatements(fixedBlock)); // = true; if (!statements.GuaranteedUnmarshal.IsEmpty) { @@ -166,7 +168,7 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName) LiteralExpression(SyntaxKind.TrueLiteralExpression)))); } - tryStatements.AddRange(statements.KeepAlive); + tryStatements.AddRange(statements.NotifyForSuccessfulInvoke); tryStatements.AddRange(statements.Unmarshal); List allStatements = setupStatements; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs index 50572f3c6da4df..5f46ca36737818 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs @@ -18,9 +18,10 @@ public struct GeneratedStatements public ImmutableArray Setup { get; init; } public ImmutableArray Marshal { get; init; } public ImmutableArray Pin { get; init; } + public ImmutableArray PinnedMarshal { get; init; } public StatementSyntax InvokeStatement { get; init; } public ImmutableArray Unmarshal { get; init; } - public ImmutableArray KeepAlive { get; init; } + public ImmutableArray NotifyForSuccessfulInvoke { get; init; } public ImmutableArray GuaranteedUnmarshal { get; init; } public ImmutableArray Cleanup { get; init; } @@ -31,9 +32,11 @@ public static GeneratedStatements Create(BoundGenerators marshallers, StubCodeCo Setup = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Setup }), Marshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Marshal }), Pin = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Pin }).Cast().ToImmutableArray(), + PinnedMarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.PinnedMarshal }), InvokeStatement = GenerateStatementForNativeInvoke(marshallers, context with { CurrentStage = StubCodeContext.Stage.Invoke }, expressionToInvoke), - Unmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Unmarshal }), - KeepAlive = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.KeepAlive }), + Unmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.UnmarshalCapture }) + .AddRange(GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Unmarshal })), + NotifyForSuccessfulInvoke = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.NotifyForSuccessfulInvoke }), GuaranteedUnmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.GuaranteedUnmarshal }), Cleanup = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Cleanup }), }; @@ -48,7 +51,7 @@ private static ImmutableArray GenerateStatementsForStubContext( statementsToUpdate.AddRange(retStatements); } - if (context.CurrentStage is StubCodeContext.Stage.Unmarshal or StubCodeContext.Stage.GuaranteedUnmarshal) + if (context.CurrentStage is StubCodeContext.Stage.UnmarshalCapture or StubCodeContext.Stage.Unmarshal or StubCodeContext.Stage.GuaranteedUnmarshal) { // For Unmarshal and GuaranteedUnmarshal stages, use the topologically sorted // marshaller list to generate the marshalling statements @@ -113,10 +116,12 @@ private static SyntaxTriviaList GenerateStageTrivia(StubCodeContext.Stage stage) StubCodeContext.Stage.Setup => "Perform required setup.", StubCodeContext.Stage.Marshal => "Convert managed data to native data.", StubCodeContext.Stage.Pin => "Pin data in preparation for calling the P/Invoke.", + StubCodeContext.Stage.PinnedMarshal => "Convert managed data to native data that requires the managed data to be pinned.", StubCodeContext.Stage.Invoke => "Call the P/Invoke.", + StubCodeContext.Stage.UnmarshalCapture => "Capture the native data into marshaller instances in case conversion to managed data throws an exception.", StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.", StubCodeContext.Stage.Cleanup => "Perform required cleanup.", - StubCodeContext.Stage.KeepAlive => "Keep alive any managed objects that need to stay alive across the call.", + StubCodeContext.Stage.NotifyForSuccessfulInvoke => "Keep alive any managed objects that need to stay alive across the call.", StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.", _ => throw new ArgumentOutOfRangeException(nameof(stage)) }; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index e2d5a0513846af..80afc1731f0c94 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -212,10 +212,13 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo { return CreateNativeCollectionMarshaller(info, context, collectionMarshallingInfo, marshallingStrategy); } - - if (marshalInfo.NativeValueType is not null) + else if (marshalInfo.NativeValueType is not null) { - marshallingStrategy = DecorateWithTwoStageMarshallingStrategy(marshalInfo, marshallingStrategy); + marshallingStrategy = new CustomNativeTypeWithToFromNativeValueMarshalling(marshallingStrategy, marshalInfo.NativeValueType.Syntax); + if (marshalInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.NativeType) && marshalInfo.MarshallingFeatures.HasFlag(CustomTypeMarshallerFeatures.TwoStageMarshalling)) + { + marshallingStrategy = new PinnableMarshallerTypeMarshalling(marshallingStrategy); + } } IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); @@ -283,18 +286,6 @@ private static void ValidateCustomNativeTypeMarshallingSupported(TypePositionInf } } - private static ICustomNativeTypeMarshallingStrategy DecorateWithTwoStageMarshallingStrategy(NativeMarshallingAttributeInfo marshalInfo, ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller) - { - TypeSyntax nativeValueTypeSyntax = marshalInfo.NativeValueType!.Syntax; - - if (marshalInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.NativeType) && marshalInfo.MarshallingFeatures.HasFlag(CustomTypeMarshallerFeatures.TwoStageMarshalling)) - { - return new PinnableMarshallerTypeMarshalling(nativeTypeMarshaller, nativeValueTypeSyntax); - } - - return new CustomNativeTypeWithToFromNativeValueMarshalling(nativeTypeMarshaller, nativeValueTypeSyntax); - } - private IMarshallingGenerator CreateNativeCollectionMarshaller( TypePositionInfo info, StubCodeContext context, @@ -324,10 +315,10 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new LinearCollectionWithNonBlittableElementsMarshalling(marshallingStrategy, elementMarshaller, elementInfo, numElementsExpression); } - // Explicitly insert the Value property handling here (before numElements handling) so that the numElements handling will be emitted before the Value property handling in unmarshalling. - if (collectionInfo.NativeValueType is not null) + marshallingStrategy = new CustomNativeTypeWithToFromNativeValueMarshalling(marshallingStrategy, collectionInfo.NativeValueType.Syntax); + if (collectionInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.NativeType) && collectionInfo.MarshallingFeatures.HasFlag(CustomTypeMarshallerFeatures.TwoStageMarshalling)) { - marshallingStrategy = DecorateWithTwoStageMarshallingStrategy(collectionInfo, marshallingStrategy); + marshallingStrategy = new PinnableMarshallerTypeMarshalling(marshallingStrategy); } TypeSyntax nativeElementType = elementMarshaller.AsNativeType(elementInfo); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs index c770312495087c..de758a34d23e0b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomNativeTypeMarshallingGenerator.cs @@ -63,6 +63,18 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont return _nativeTypeMarshaller.GeneratePinStatements(info, context); } break; + case StubCodeContext.Stage.PinnedMarshal: + if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + { + return _nativeTypeMarshaller.GeneratePinnedMarshalStatements(info, context); + } + break; + case StubCodeContext.Stage.UnmarshalCapture: + if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + { + return _nativeTypeMarshaller.GenerateUnmarshalCaptureStatements(info, context); + } + break; case StubCodeContext.Stage.Unmarshal: if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In) || (_enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs index ff9873c93b3b9e..00c7f6e2b941b0 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/DelegateMarshaller.cs @@ -87,7 +87,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont LiteralExpression(SyntaxKind.NullLiteralExpression)))); } break; - case StubCodeContext.Stage.KeepAlive: + case StubCodeContext.Stage.NotifyForSuccessfulInvoke: if (info.RefKind != RefKind.Out) { yield return ExpressionStatement( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomNativeTypeMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomNativeTypeMarshallingStrategy.cs index 5ba85d6806dc15..0d89208515d78f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomNativeTypeMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomNativeTypeMarshallingStrategy.cs @@ -22,12 +22,16 @@ internal interface ICustomNativeTypeMarshallingStrategy IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments); + IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context); + IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context); IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context); IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context); + IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context); + IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context); bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context); @@ -71,6 +75,11 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i .WithArgumentList(ArgumentList(SeparatedList(nativeTypeConstructorArguments))))); } + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { // If the current element is being marshalled by-value [Out], then don't call the ToManaged method and do the assignment. @@ -92,6 +101,11 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo IdentifierName(ShapeMemberNames.Value.ToManaged))))); } + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { yield return Argument(IdentifierName(context.GetIdentifiers(info).managed)); @@ -181,6 +195,16 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i { yield return statement; } + } + + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); + + foreach (StatementSyntax statement in _innerMarshaller.GeneratePinnedMarshalStatements(info, subContext)) + { + yield return statement; + } // = .ToNativeValue(); yield return ExpressionStatement( @@ -205,7 +229,7 @@ private static StatementSyntax GenerateFromNativeValueInvocation(TypePositionInf ArgumentList(SingletonSeparatedList(Argument(IdentifierName(context.GetIdentifiers(info).native)))))); } - public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) { var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); @@ -213,6 +237,11 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo { yield return GenerateFromNativeValueInvocation(info, context, subContext); } + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalStatements(info, subContext)) { @@ -300,6 +329,11 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i } } + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + } + private static bool StackAllocOptimizationValid(TypePositionInfo info, StubCodeContext context) { return context.SingleFrameSpansNativeContext && (!info.IsByRef || info.RefKind == RefKind.In); @@ -320,6 +354,11 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf return _innerMarshaller.GenerateSetupStatements(info, context); } + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + } + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { return _innerMarshaller.GenerateUnmarshalStatements(info, context); @@ -390,6 +429,11 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i return _innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); } + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + } + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) { return _innerMarshaller.GeneratePinStatements(info, context); @@ -400,6 +444,11 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf return _innerMarshaller.GenerateSetupStatements(info, context); } + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + } + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { return _innerMarshaller.GenerateUnmarshalStatements(info, context); @@ -422,71 +471,36 @@ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) internal sealed class PinnableMarshallerTypeMarshalling : ICustomNativeTypeMarshallingStrategy { private readonly ICustomNativeTypeMarshallingStrategy _innerMarshaller; - private readonly TypeSyntax _nativeValueType; - public PinnableMarshallerTypeMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller, TypeSyntax nativeValueType) + public PinnableMarshallerTypeMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller) { _innerMarshaller = innerMarshaller; - _nativeValueType = nativeValueType; - } - - private static bool CanPinMarshaller(TypePositionInfo info, StubCodeContext context) - { - return context.SingleFrameSpansNativeContext && !info.IsManagedReturnPosition && !info.IsByRef || info.RefKind == RefKind.In; } public TypeSyntax AsNativeType(TypePositionInfo info) { - return _nativeValueType; + return _innerMarshaller.AsNativeType(info); } public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) { - var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); - - if (!context.AdditionalTemporaryStateLivesAcrossStages) - { - // .Value = ; - yield return GenerateFromNativeValueInvocation(info, context, subContext); - } - - foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, subContext)) - { - yield return statement; - } + return _innerMarshaller.GenerateCleanupStatements(info, context); } public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) { - var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); - foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, subContext, nativeTypeConstructorArguments)) - { - yield return statement; - } - - if (!CanPinMarshaller(info, context)) - yield return GenerateToNativeValueInvocation(info, context, subContext); + return _innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); } - private static StatementSyntax GenerateToNativeValueInvocation(TypePositionInfo info, StubCodeContext context, CustomNativeTypeWithToFromNativeValueContext subContext) + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) { - // = .ToNativeValue(); - return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(context.GetIdentifiers(info).native), - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(subContext.GetIdentifiers(info).native), - IdentifierName(ShapeMemberNames.Value.ToNativeValue)), - ArgumentList()))); + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); } public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) { // The type of the ignored identifier isn't relevant, so we use void* for all. - // fixed (void* = &) - // + // fixed (void* = &); var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); yield return FixedStatement( VariableDeclaration( @@ -495,56 +509,27 @@ public IEnumerable GeneratePinStatements(TypePositionInfo info, VariableDeclarator(Identifier(context.GetAdditionalIdentifier(info, "ignored"))) .WithInitializer(EqualsValueClause( IdentifierName(subContext.GetIdentifiers(info).native))))), - GenerateToNativeValueInvocation(info, context, subContext)); + EmptyStatement()); } public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) { - var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); - yield return LocalDeclarationStatement( - VariableDeclaration( - _innerMarshaller.AsNativeType(info), - SingletonSeparatedList( - VariableDeclarator(subContext.GetIdentifiers(info).native) - .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.DefaultLiteralExpression)))))); - - foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, subContext)) - { - yield return statement; - } + return _innerMarshaller.GenerateSetupStatements(info, context); } - private static StatementSyntax GenerateFromNativeValueInvocation(TypePositionInfo info, StubCodeContext context, CustomNativeTypeWithToFromNativeValueContext subContext) + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) { - // .FromNativeValue(); - return ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(subContext.GetIdentifiers(info).native), - IdentifierName(ShapeMemberNames.Value.FromNativeValue)), - ArgumentList(SingletonSeparatedList(Argument(IdentifierName(context.GetIdentifiers(info).native)))))); + return _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); } public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { - var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); - - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) - { - // .Value = ; - yield return GenerateFromNativeValueInvocation(info, context, subContext); - } - - foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalStatements(info, subContext)) - { - yield return statement; - } + return _innerMarshaller.GenerateUnmarshalStatements(info, context); } public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { - var subContext = new CustomNativeTypeWithToFromNativeValueContext(context); - return _innerMarshaller.GetNativeTypeConstructorArguments(info, subContext); + return _innerMarshaller.GetNativeTypeConstructorArguments(info, context); } public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) @@ -596,6 +581,12 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i return _innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); } + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) { return _innerMarshaller.GeneratePinStatements(info, context); @@ -606,18 +597,7 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf return _innerMarshaller.GenerateSetupStatements(info, context); } - private IEnumerable GenerateUnmarshallerCollectionInitialization(TypePositionInfo info, StubCodeContext context) - { - string marshalerIdentifier = MarshallerHelpers.GetMarshallerIdentifier(info, context); - if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) - { - yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(marshalerIdentifier), - ImplicitObjectCreationExpression().AddArgumentListArguments(Argument(_sizeOfElementExpression)))); - } - } - - public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) { // To fulfill the generic contiguous collection marshaller design, // we need to emit code to initialize the collection marshaller with the size of native elements @@ -628,13 +608,28 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo { yield return statement; } - - foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalStatements(info, context)) + foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context)) { yield return statement; } } + private IEnumerable GenerateUnmarshallerCollectionInitialization(TypePositionInfo info, StubCodeContext context) + { + string marshalerIdentifier = MarshallerHelpers.GetMarshallerIdentifier(info, context); + if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) + { + yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(marshalerIdentifier), + ImplicitObjectCreationExpression().AddArgumentListArguments(Argument(_sizeOfElementExpression)))); + } + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalStatements(info, context); + } + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { foreach (ArgumentSyntax arg in _innerMarshaller.GetNativeTypeConstructorArguments(info, context)) @@ -741,6 +736,11 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i ArgumentList())))))); } + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + } + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) { return _innerMarshaller.GeneratePinStatements(info, context); @@ -751,6 +751,11 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf return _innerMarshaller.GenerateSetupStatements(info, context); } + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + } + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { string nativeIdentifier = context.GetIdentifiers(info).native; @@ -999,7 +1004,7 @@ private LocalDeclarationStatementSyntax GeneratedManagedValuesDestinationDeclara ArgumentList(SingletonSeparatedList(Argument(IdentifierName(numElementsIdentifier)))))))))); } - private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, ExpressionSyntax lengthExpression) + private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, ExpressionSyntax lengthExpression, params StubCodeContext.Stage[] stagesToGeneratePerElement) { string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(info, context); string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context); @@ -1008,11 +1013,6 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in managedSpanIdentifier, nativeSpanIdentifier, context); - var elementSubContext = new LinearCollectionElementMarshallingCodeContext( - context.CurrentStage, - managedSpanIdentifier, - nativeSpanIdentifier, - context); TypePositionInfo localElementInfo = _elementInfo with { @@ -1022,7 +1022,13 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in NativeIndex = info.NativeIndex }; - List elementStatements = _elementMarshaller.Generate(localElementInfo, elementSubContext).ToList(); + List elementStatements = new(); + + foreach (var stage in stagesToGeneratePerElement) + { + var elementSubContext = elementSetupSubContext with { CurrentStage = stage }; + elementStatements.AddRange(_elementMarshaller.Generate(localElementInfo, elementSubContext)); + } if (elementStatements.Any()) { @@ -1032,12 +1038,12 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in if (_elementMarshaller.AsNativeType(_elementInfo) is PointerTypeSyntax elementNativeType) { - PointerNativeTypeAssignmentRewriter rewriter = new(elementSubContext.GetIdentifiers(localElementInfo).native, elementNativeType); + PointerNativeTypeAssignmentRewriter rewriter = new(elementSetupSubContext.GetIdentifiers(localElementInfo).native, elementNativeType); marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement); } // Iterate through the elements of the native collection to unmarshal them - return MarshallerHelpers.GetForLoop(lengthExpression, elementSubContext.IndexerIdentifier) + return MarshallerHelpers.GetForLoop(lengthExpression, elementSetupSubContext.IndexerIdentifier) .WithStatement(marshallingStatement); } return EmptyStatement(); @@ -1053,7 +1059,8 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i StatementSyntax contentsCleanupStatements = GenerateContentsMarshallingStatement(info, context, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(MarshallerHelpers.GetNativeSpanIdentifier(info, context)), - IdentifierName("Length"))); + IdentifierName("Length")), + StubCodeContext.Stage.Cleanup); if (!contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement)) { @@ -1101,9 +1108,21 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i GenerateContentsMarshallingStatement(info, context, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(MarshallerHelpers.GetManagedSpanIdentifier(info, context)), - IdentifierName("Length")))); + IdentifierName("Length")), + StubCodeContext.Stage.Marshal, + // Using the PinnedMarshal stage here isn't strictly valid as we don't guarantee that GetPinnableReference + // is pinned, but for our existing marshallers this is not an issue and we'll be removing support for stateful element marshallers soon + // (at which point we can remove this) + // and address this problem better when we bring them back in the future. + StubCodeContext.Stage.PinnedMarshal)); + } + + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GeneratePinnedMarshalStatements(info, context); } + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) { return _innerMarshaller.GeneratePinStatements(info, context); @@ -1114,6 +1133,11 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf return _innerMarshaller.GenerateSetupStatements(info, context); } + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + } + private StatementSyntax GenerateByValueUnmarshalStatement(TypePositionInfo info, StubCodeContext context) { // Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents, @@ -1186,7 +1210,9 @@ private StatementSyntax GenerateByValueUnmarshalStatement(TypePositionInfo info, managedValuesDeclaration, nativeValuesDeclaration, GenerateContentsMarshallingStatement(info, context, - IdentifierName(numElementsIdentifier))); + IdentifierName(numElementsIdentifier), + StubCodeContext.Stage.UnmarshalCapture, + StubCodeContext.Stage.Unmarshal)); } public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) @@ -1206,7 +1232,9 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo GeneratedManagedValuesDestinationDeclaration(info, context, numElementsIdentifier), GenerateNativeValuesSourceDeclaration(info, context, numElementsIdentifier), GenerateContentsMarshallingStatement(info, context, - IdentifierName(numElementsIdentifier))); + IdentifierName(numElementsIdentifier), + StubCodeContext.Stage.UnmarshalCapture, + StubCodeContext.Stage.Unmarshal)); } foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalStatements(info, context)) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs index 84d17e8c88ed91..9822d1cf70ba9a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs @@ -33,6 +33,11 @@ public enum Stage /// Pin, + /// + /// Convert managed data to native data, assuming that any values pinned in the stage are pinned. + /// + PinnedMarshal, + /// /// Call the generated P/Invoke /// @@ -42,20 +47,26 @@ public enum Stage /// Invoke, + /// + /// Capture native values to ensure that we do not leak if an exception is thrown during unmarshalling + /// + UnmarshalCapture, + /// /// Convert native data to managed data /// Unmarshal, /// - /// Perform any cleanup required + /// Notify a marshaller object that the Invoke stage and all stages preceeding the Invoke stage + /// successfully completed without any exceptions. /// - Cleanup, + NotifyForSuccessfulInvoke, /// - /// Keep alive any managed objects that need to stay alive across the call. + /// Perform any cleanup required /// - KeepAlive, + Cleanup, /// /// Convert native data to managed data even in the case of an exception during diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SyntaxExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SyntaxExtensions.cs index 02a0d80743d59e..2457328d2e6c5b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SyntaxExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SyntaxExtensions.cs @@ -52,7 +52,7 @@ public static StatementSyntax NestFixedStatements(this ImmutableArray= 0; i--) { @@ -60,6 +60,15 @@ public static StatementSyntax NestFixedStatements(this ImmutableArray