Skip to content

Commit 8dfaac9

Browse files
authored
Mark containing type for generated stub functions as unsafe (#60979)
1 parent d7c96e8 commit 8dfaac9

File tree

4 files changed

+47
-7
lines changed

4 files changed

+47
-7
lines changed

src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
145145
(data, ct) =>
146146
{
147147
IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.CalculateStubInformation, data));
148-
return (data.Syntax, StubContext: CalculateStubInformation(data.Syntax, data.Symbol, data.Environment, ct));
148+
return (data.Syntax, StubContext: CalculateStubInformation(data.Symbol, data.Environment, ct));
149149
}
150150
)
151151
.WithComparer(Comparers.CalculatedContextWithSyntax)
@@ -244,6 +244,17 @@ private static SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenLis
244244
return new SyntaxTokenList(strippedTokens);
245245
}
246246

247+
private static SyntaxTokenList AddToModifiers(SyntaxTokenList modifiers, SyntaxKind modifierToAdd)
248+
{
249+
if (modifiers.IndexOf(modifierToAdd) >= 0)
250+
return modifiers;
251+
252+
int idx = modifiers.IndexOf(SyntaxKind.PartialKeyword);
253+
return idx >= 0
254+
? modifiers.Insert(idx, Token(modifierToAdd))
255+
: modifiers.Add(Token(modifierToAdd));
256+
}
257+
247258
private static TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclarationSyntax typeDeclaration)
248259
{
249260
return TypeDeclaration(
@@ -279,6 +290,9 @@ private static MemberDeclarationSyntax WrapMethodInContainingScopes(DllImportStu
279290
MemberDeclarationSyntax containingType = CreateTypeDeclarationWithoutTrivia(stub.StubContainingTypes.First())
280291
.AddMembers(stubMethod);
281292

293+
// Mark containing type as unsafe such that all the generated functions will be in an unsafe context.
294+
containingType = containingType.WithModifiers(AddToModifiers(containingType.Modifiers, SyntaxKind.UnsafeKeyword));
295+
282296
// Add type to the remaining containing types (skipping the first which was handled above)
283297
foreach (TypeDeclarationSyntax typeDecl in stub.StubContainingTypes.Skip(1))
284298
{
@@ -339,8 +353,6 @@ private static GeneratedDllImportData ProcessGeneratedDllImportAttribute(Attribu
339353
bool setLastError = false;
340354
bool throwOnUnmappableChar = false;
341355

342-
var stubDllImportData = new GeneratedDllImportData(attrData.ConstructorArguments[0].Value!.ToString());
343-
344356
// All other data on attribute is defined as NamedArguments.
345357
foreach (KeyValuePair<string, TypedConstant> namedArg in attrData.NamedArguments)
346358
{
@@ -398,7 +410,7 @@ private static GeneratedDllImportData ProcessGeneratedDllImportAttribute(Attribu
398410
};
399411
}
400412

401-
private static IncrementalStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
413+
private static IncrementalStubGenerationContext CalculateStubInformation(IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
402414
{
403415
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
404416
INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
@@ -506,6 +518,10 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
506518
dllImport = dllImport.AddAttributeLists(AttributeList(SeparatedList(forwardedAttributes)));
507519
}
508520

521+
dllImport = dllImport.WithLeadingTrivia(
522+
Comment("//"),
523+
Comment("// Local P/Invoke"),
524+
Comment("//"));
509525
code = code.AddStatements(dllImport);
510526

511527
return (PrintGeneratedSource(originalSyntax, dllImportStub.StubContext, code), dllImportStub.Diagnostics.AddRange(diagnostics.Diagnostics));
@@ -514,7 +530,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
514530
private MemberDeclarationSyntax PrintForwarderStub(MethodDeclarationSyntax userDeclaredMethod, IncrementalStubGenerationContext stub)
515531
{
516532
SyntaxTokenList modifiers = StripTriviaFromModifiers(userDeclaredMethod.Modifiers);
517-
modifiers = modifiers.Insert(modifiers.IndexOf(SyntaxKind.PartialKeyword), Token(SyntaxKind.ExternKeyword));
533+
modifiers = AddToModifiers(modifiers, SyntaxKind.ExternKeyword);
518534
// Create stub function
519535
MethodDeclarationSyntax stubMethod = MethodDeclaration(stub.StubContext.StubReturnType, userDeclaredMethod.Identifier)
520536
.WithModifiers(modifiers)

src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/PInvokeStubCodeGenerator.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,14 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
215215
}
216216
}
217217

218+
/// <summary>
219+
/// Generate the method body of the p/invoke stub.
220+
/// </summary>
221+
/// <param name="dllImportName">Name of the target DllImport function to invoke</param>
222+
/// <returns>Method body of the p/invoke stub</returns>
223+
/// <remarks>
224+
/// The generated code assumes it will be in an unsafe context.
225+
/// </remarks>
218226
public BlockSyntax GeneratePInvokeBody(string dllImportName)
219227
{
220228
bool invokeReturnsVoid = _retMarshaller.TypeInfo.ManagedType == SpecialTypeInfo.Void;
@@ -330,8 +338,7 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName)
330338
if (!_stubReturnsVoid)
331339
allStatements.Add(ReturnStatement(IdentifierName(ReturnIdentifier)));
332340

333-
// Wrap all statements in an unsafe block
334-
return Block(UnsafeStatement(Block(allStatements)));
341+
return Block(allStatements);
335342

336343
void GenerateStatementsForStage(Stage stage, List<StatementSyntax> statementsToUpdate)
337344
{

src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/CodeSnippets.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ partial class InnerClass
141141
}
142142
";
143143

144+
/// <summary>
145+
/// Containing type with and without unsafe
146+
/// </summary>
147+
public static readonly string UnsafeContext = @"
148+
using System.Runtime.InteropServices;
149+
partial class Test
150+
{
151+
[GeneratedDllImport(""DoesNotExist"")]
152+
public static partial void Method1();
153+
}
154+
unsafe partial class Test
155+
{
156+
[GeneratedDllImport(""DoesNotExist"")]
157+
public static partial int* Method2();
158+
}
159+
";
144160
/// <summary>
145161
/// Declaration with user defined EntryPoint.
146162
/// </summary>

src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/Compiles.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()
2222
yield return new[] { CodeSnippets.MultipleAttributes };
2323
yield return new[] { CodeSnippets.NestedNamespace };
2424
yield return new[] { CodeSnippets.NestedTypes };
25+
yield return new[] { CodeSnippets.UnsafeContext };
2526
yield return new[] { CodeSnippets.UserDefinedEntryPoint };
2627
yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
2728
yield return new[] { CodeSnippets.DefaultParameters };

0 commit comments

Comments
 (0)