Skip to content

Add two new stages to prepare for our new custom type marshalling design. #70598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,20 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName)
var tryStatements = new List<StatementSyntax>();
tryStatements.AddRange(statements.Marshal);

var invokeStatement = statements.InvokeStatement;
BlockSyntax fixedBlock = Block(statements.PinnedMarshal);
if (_setLastError)
{
StatementSyntax clearLastError = MarshallerHelpers.CreateClearLastSystemErrorStatement(SuccessErrorCode);

StatementSyntax getLastError = MarshallerHelpers.CreateGetLastSystemErrorStatement(LastErrorIdentifier);

invokeStatement = Block(clearLastError, invokeStatement, getLastError);
fixedBlock = fixedBlock.AddStatements(clearLastError, statements.InvokeStatement, getLastError);
}
invokeStatement = statements.Pin.NestFixedStatements(invokeStatement);

tryStatements.Add(invokeStatement);
else
{
fixedBlock = fixedBlock.AddStatements(statements.InvokeStatement);
}
tryStatements.Add(statements.Pin.NestFixedStatements(fixedBlock));
// <invokeSucceeded> = true;
if (!statements.GuaranteedUnmarshal.IsEmpty)
{
Expand All @@ -166,7 +168,7 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName)
LiteralExpression(SyntaxKind.TrueLiteralExpression))));
}

tryStatements.AddRange(statements.KeepAlive);
tryStatements.AddRange(statements.NotifyForSuccessfulInvoke);
tryStatements.AddRange(statements.Unmarshal);

List<StatementSyntax> allStatements = setupStatements;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ public struct GeneratedStatements
public ImmutableArray<StatementSyntax> Setup { get; init; }
public ImmutableArray<StatementSyntax> Marshal { get; init; }
public ImmutableArray<FixedStatementSyntax> Pin { get; init; }
public ImmutableArray<StatementSyntax> PinnedMarshal { get; init; }
public StatementSyntax InvokeStatement { get; init; }
public ImmutableArray<StatementSyntax> Unmarshal { get; init; }
public ImmutableArray<StatementSyntax> KeepAlive { get; init; }
public ImmutableArray<StatementSyntax> NotifyForSuccessfulInvoke { get; init; }
public ImmutableArray<StatementSyntax> GuaranteedUnmarshal { get; init; }
public ImmutableArray<StatementSyntax> Cleanup { get; init; }

Expand All @@ -31,9 +32,11 @@ public static GeneratedStatements Create(BoundGenerators marshallers, StubCodeCo
Setup = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Setup }),
Marshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Marshal }),
Pin = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Pin }).Cast<FixedStatementSyntax>().ToImmutableArray(),
PinnedMarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.PinnedMarshal }),
InvokeStatement = GenerateStatementForNativeInvoke(marshallers, context with { CurrentStage = StubCodeContext.Stage.Invoke }, expressionToInvoke),
Unmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Unmarshal }),
KeepAlive = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.KeepAlive }),
Unmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.UnmarshalCapture })
.AddRange(GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Unmarshal })),
NotifyForSuccessfulInvoke = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.NotifyForSuccessfulInvoke }),
GuaranteedUnmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.GuaranteedUnmarshal }),
Cleanup = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Cleanup }),
};
Expand All @@ -48,7 +51,7 @@ private static ImmutableArray<StatementSyntax> GenerateStatementsForStubContext(
statementsToUpdate.AddRange(retStatements);
}

if (context.CurrentStage is StubCodeContext.Stage.Unmarshal or StubCodeContext.Stage.GuaranteedUnmarshal)
if (context.CurrentStage is StubCodeContext.Stage.UnmarshalCapture or StubCodeContext.Stage.Unmarshal or StubCodeContext.Stage.GuaranteedUnmarshal)
{
// For Unmarshal and GuaranteedUnmarshal stages, use the topologically sorted
// marshaller list to generate the marshalling statements
Expand Down Expand Up @@ -113,10 +116,12 @@ private static SyntaxTriviaList GenerateStageTrivia(StubCodeContext.Stage stage)
StubCodeContext.Stage.Setup => "Perform required setup.",
StubCodeContext.Stage.Marshal => "Convert managed data to native data.",
StubCodeContext.Stage.Pin => "Pin data in preparation for calling the P/Invoke.",
StubCodeContext.Stage.PinnedMarshal => "Convert managed data to native data that requires the managed data to be pinned.",
StubCodeContext.Stage.Invoke => "Call the P/Invoke.",
StubCodeContext.Stage.UnmarshalCapture => "Capture the native data into marshaller instances in case conversion to managed data throws an exception.",
StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.",
StubCodeContext.Stage.Cleanup => "Perform required cleanup.",
StubCodeContext.Stage.KeepAlive => "Keep alive any managed objects that need to stay alive across the call.",
StubCodeContext.Stage.NotifyForSuccessfulInvoke => "Keep alive any managed objects that need to stay alive across the call.",
StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.",
_ => throw new ArgumentOutOfRangeException(nameof(stage))
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,13 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
{
return CreateNativeCollectionMarshaller(info, context, collectionMarshallingInfo, marshallingStrategy);
}

if (marshalInfo.NativeValueType is not null)
else if (marshalInfo.NativeValueType is not null)
{
marshallingStrategy = DecorateWithTwoStageMarshallingStrategy(marshalInfo, marshallingStrategy);
marshallingStrategy = new CustomNativeTypeWithToFromNativeValueMarshalling(marshallingStrategy, marshalInfo.NativeValueType.Syntax);
if (marshalInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.NativeType) && marshalInfo.MarshallingFeatures.HasFlag(CustomTypeMarshallerFeatures.TwoStageMarshalling))
{
marshallingStrategy = new PinnableMarshallerTypeMarshalling(marshallingStrategy);
}
}

IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
Expand Down Expand Up @@ -283,18 +286,6 @@ private static void ValidateCustomNativeTypeMarshallingSupported(TypePositionInf
}
}

private static ICustomNativeTypeMarshallingStrategy DecorateWithTwoStageMarshallingStrategy(NativeMarshallingAttributeInfo marshalInfo, ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller)
{
TypeSyntax nativeValueTypeSyntax = marshalInfo.NativeValueType!.Syntax;

if (marshalInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.NativeType) && marshalInfo.MarshallingFeatures.HasFlag(CustomTypeMarshallerFeatures.TwoStageMarshalling))
{
return new PinnableMarshallerTypeMarshalling(nativeTypeMarshaller, nativeValueTypeSyntax);
}

return new CustomNativeTypeWithToFromNativeValueMarshalling(nativeTypeMarshaller, nativeValueTypeSyntax);
}

private IMarshallingGenerator CreateNativeCollectionMarshaller(
TypePositionInfo info,
StubCodeContext context,
Expand Down Expand Up @@ -324,10 +315,10 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new LinearCollectionWithNonBlittableElementsMarshalling(marshallingStrategy, elementMarshaller, elementInfo, numElementsExpression);
}

// Explicitly insert the Value property handling here (before numElements handling) so that the numElements handling will be emitted before the Value property handling in unmarshalling.
if (collectionInfo.NativeValueType is not null)
marshallingStrategy = new CustomNativeTypeWithToFromNativeValueMarshalling(marshallingStrategy, collectionInfo.NativeValueType.Syntax);
if (collectionInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.NativeType) && collectionInfo.MarshallingFeatures.HasFlag(CustomTypeMarshallerFeatures.TwoStageMarshalling))
{
marshallingStrategy = DecorateWithTwoStageMarshallingStrategy(collectionInfo, marshallingStrategy);
marshallingStrategy = new PinnableMarshallerTypeMarshalling(marshallingStrategy);
}

TypeSyntax nativeElementType = elementMarshaller.AsNativeType(elementInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
return _nativeTypeMarshaller.GeneratePinStatements(info, context);
}
break;
case StubCodeContext.Stage.PinnedMarshal:
if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out)
{
return _nativeTypeMarshaller.GeneratePinnedMarshalStatements(info, context);
}
break;
case StubCodeContext.Stage.UnmarshalCapture:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
{
return _nativeTypeMarshaller.GenerateUnmarshalCaptureStatements(info, context);
}
break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)
|| (_enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
LiteralExpression(SyntaxKind.NullLiteralExpression))));
}
break;
case StubCodeContext.Stage.KeepAlive:
case StubCodeContext.Stage.NotifyForSuccessfulInvoke:
if (info.RefKind != RefKind.Out)
{
yield return ExpressionStatement(
Expand Down
Loading