diff --git a/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift b/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift index 14f30c50d86..93eb135f30b 100644 --- a/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift +++ b/Sources/SwiftCompilerPluginMessageHandling/CompilerPluginMessageHandler.swift @@ -30,6 +30,9 @@ import SwiftSyntaxMacros @_spi(PluginMessage) public enum PluginFeature: String { case loadPluginLibrary = "load-plugin-library" + + /// Whether the plugin knows how to infer nonisolated conformances. + case inferNonisolatedConformances = "infer-nonisolated-conformances" } /// A type that provides the actual plugin functions. diff --git a/Sources/SwiftLibraryPluginProvider/LibraryPluginProvider.swift b/Sources/SwiftLibraryPluginProvider/LibraryPluginProvider.swift index 7cb686a0622..7e771081bc1 100644 --- a/Sources/SwiftLibraryPluginProvider/LibraryPluginProvider.swift +++ b/Sources/SwiftLibraryPluginProvider/LibraryPluginProvider.swift @@ -84,7 +84,7 @@ public class LibraryPluginProvider: PluginProvider { public static let shared: LibraryPluginProvider = LibraryPluginProvider() public var features: [PluginFeature] { - [.loadPluginLibrary] + [.loadPluginLibrary, .inferNonisolatedConformances] } public func loadPluginLibrary(libraryPath: String, moduleName: String) throws { diff --git a/Sources/SwiftSyntaxMacroExpansion/CMakeLists.txt b/Sources/SwiftSyntaxMacroExpansion/CMakeLists.txt index 892a2e30fe6..db4ad44325a 100644 --- a/Sources/SwiftSyntaxMacroExpansion/CMakeLists.txt +++ b/Sources/SwiftSyntaxMacroExpansion/CMakeLists.txt @@ -8,6 +8,7 @@ add_swift_syntax_library(SwiftSyntaxMacroExpansion MacroReplacement.swift MacroSpec.swift MacroSystem.swift + SyntaxProtocol+NonisolatedConformances.swift ) target_link_swift_syntax_libraries(SwiftSyntaxMacroExpansion PUBLIC diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift b/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift index ffd31a16ef4..357aba23cf5 100644 --- a/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift +++ b/Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift @@ -156,7 +156,7 @@ public func expandFreestandingMacro( (.codeItem, _), (.preamble, _), (.body, _): throw MacroExpansionError.unmatchedMacroRole(definition, macroRole) } - return expandedSyntax.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + return expandedSyntax.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } catch { context.addDiagnostics(from: error, node: node) return nil @@ -273,7 +273,7 @@ public func expandAttachedMacroWithoutCollapsing in: context ) return accessors.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } case (let attachedMacro as MemberAttributeMacro.Type, .memberAttribute): @@ -294,7 +294,7 @@ public func expandAttachedMacroWithoutCollapsing // Form a buffer containing an attribute list to return to the caller. return attributes.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } case (let attachedMacro as MemberMacro.Type, .member): @@ -313,7 +313,7 @@ public func expandAttachedMacroWithoutCollapsing // Form a buffer of member declarations to return to the caller. return members.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } case (let attachedMacro as PeerMacro.Type, .peer): @@ -326,7 +326,7 @@ public func expandAttachedMacroWithoutCollapsing // Form a buffer of peer declarations to return to the caller. return peers.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } case (let attachedMacro as ExtensionMacro.Type, .extension): @@ -357,7 +357,7 @@ public func expandAttachedMacroWithoutCollapsing // Form a buffer of peer declarations to return to the caller. return extensions.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } case (let attachedMacro as PreambleMacro.Type, .preamble): @@ -375,7 +375,7 @@ public func expandAttachedMacroWithoutCollapsing in: context ) return preamble.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } case (let attachedMacro as BodyMacro.Type, .body): @@ -400,7 +400,7 @@ public func expandAttachedMacroWithoutCollapsing } return body.map { - $0.formattedExpansion(definition.formatMode, indentationWidth: indentationWidth) + $0.adjustedMacroExpansion(for: definition, indentationWidth: indentationWidth) } default: @@ -511,15 +511,29 @@ public func expandAttachedMacro( } fileprivate extension SyntaxProtocol { - /// Perform a format if required and then trim any leading/trailing - /// whitespace. - func formattedExpansion(_ mode: FormatMode, indentationWidth: Trivia?) -> String { - switch mode { + /// Perform post-expansion adjustments to the result of a macro expansion. + /// + /// This applies adjustments to the result of a macro expansion to normalize + /// it for use in later tools. Each of the adjustments here should have a + /// corresponding configuration option in the `Macro` protocol. + func adjustedMacroExpansion( + for macro: Macro.Type, + indentationWidth: Trivia? + ) -> String { + var syntax = Syntax(self) + + // Infer nonisolated conformances. + if macro.inferNonisolatedConformances { + syntax = syntax.inferNonisolatedConformances() + } + + // Formatting. + switch macro.formatMode { case .auto: - return self.formatted(using: BasicFormat(indentationWidth: indentationWidth)) + return syntax.formatted(using: BasicFormat(indentationWidth: indentationWidth)) .trimmedDescription(matching: \.isWhitespace) case .disabled: - return Syntax(self).description + return syntax.description #if RESILIENT_LIBRARIES @unknown default: fatalError() diff --git a/Sources/SwiftSyntaxMacroExpansion/SyntaxProtocol+NonisolatedConformances.swift b/Sources/SwiftSyntaxMacroExpansion/SyntaxProtocol+NonisolatedConformances.swift new file mode 100644 index 00000000000..7bb22bc732a --- /dev/null +++ b/Sources/SwiftSyntaxMacroExpansion/SyntaxProtocol+NonisolatedConformances.swift @@ -0,0 +1,147 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file implements inference of "nonisolated" on the conformances that +// occur within macro-expanded code. It's meant to provide source compatibility +// + +import SwiftSyntax + +extension SyntaxProtocol { + /// Given some Swift syntax that may contain type definitions and extensions, + /// add "nonisolated" to protocol conformances when there are nonisolated + /// members. For example, given: + /// + /// extension X: P { + /// nonisolated func f() { } + /// } + /// + /// this operation will produce: + /// + /// extension X: nonisolated P { + /// nonisolated func f() { } + /// } + @_spi(Testing) @_spi(Compiler) + public func inferNonisolatedConformances() -> Syntax { + let rewriter = NonisolatedConformanceRewriter() + return rewriter.rewrite(self) + } +} + +fileprivate class NonisolatedConformanceRewriter: SyntaxRewriter { + override func visitAny(_ node: Syntax) -> Syntax? { + // We only care about decl groups (non-protocol nominal types + extensions) + // that have nonisolated members and an inheritance clause. + guard let declGroup = node.asProtocol(DeclGroupSyntax.self), + !declGroup.is(ProtocolDeclSyntax.self), + declGroup.containsNonisolatedMembers, + let inheritanceClause = declGroup.inheritanceClause + else { + return nil + } + + var skipFirst = + declGroup.is(ClassDeclSyntax.self) + || (declGroup.is(EnumDeclSyntax.self) && inheritanceClause.inheritedTypes.first?.looksLikeEnumRawType ?? false) + let inheritedTypes = inheritanceClause.inheritedTypes.map { inheritedType in + // If there's already a 'nonisolated' or some kind of custom attribute + if inheritedType.type.hasNonisolatedOrCustomAttribute { + return inheritedType + } + + if skipFirst { + skipFirst = false + return inheritedType + } + + return inheritedType.with(\.type, "nonisolated \(inheritedType.type)") + } + + return Syntax( + fromProtocol: declGroup.with( + \.inheritanceClause, + inheritanceClause.with( + \.inheritedTypes, + InheritedTypeListSyntax(inheritedTypes) + ) + ) + ) + } +} + +extension TypeSyntax { + /// Determine whether the given type has a 'nonisolated' specifier or a + /// custom attribute (that could be a global actor). + fileprivate var hasNonisolatedOrCustomAttribute: Bool { + var type = self + while let attributedType = type.as(AttributedTypeSyntax.self) { + // nonisolated + let hasNonisolated = attributedType.specifiers.contains { specifier in + if case .nonisolatedTypeSpecifier = specifier { + return true + } + + return false + } + if hasNonisolated { + return true + } + + // Any attribute will do. + if !attributedType.attributes.isEmpty { + return true + } + + type = attributedType.baseType + } + + return false + } +} + +extension InheritedTypeSyntax { + /// Determine whether this inherited type "looks like" a raw type, e.g., + /// if it's one of the integer types or String. This can only be an heuristic, + /// because it does not + fileprivate var looksLikeEnumRawType: Bool { + // TODO: We could probably use a utility to syntactically recognize types + // from the + var text = type.trimmed.description[...] + if text.starts(with: "Swift.") { + text = text.dropFirst(6) + } + + switch text { + case "Int", "Int8", "Int16", "Int32", "Int64", + "UInt", "UInt8", "UInt16", "UInt32", "UInt64", + "String": + return true + + default: return false + } + } +} +extension DeclModifierListSyntax { + /// Whether the modifier list contains "nonisolated". + fileprivate var hasNonisolated: Bool { + contains { $0.name.tokenKind == .keyword(.nonisolated) } + } +} + +extension DeclGroupSyntax { + /// Determine whether any of members is marked "nonisolated. + fileprivate var containsNonisolatedMembers: Bool { + memberBlock.members.lazy.map(\.decl).contains { + $0.asProtocol(WithModifiersSyntax.self)?.modifiers.hasNonisolated ?? false + } + } +} diff --git a/Sources/SwiftSyntaxMacros/MacroProtocols/Macro.swift b/Sources/SwiftSyntaxMacros/MacroProtocols/Macro.swift index 6203a6f9127..a5dba4fbac6 100644 --- a/Sources/SwiftSyntaxMacros/MacroProtocols/Macro.swift +++ b/Sources/SwiftSyntaxMacros/MacroProtocols/Macro.swift @@ -15,4 +15,29 @@ public protocol Macro { /// How the resulting expansion should be formatted, `.auto` by default. /// Use `.disabled` for the expansion to be used as is. static var formatMode: FormatMode { get } + + /// Whether to infer "nonisolated" on protocol conformances introduced in + /// the macro expansion when there are some nonisolated members in the + /// corresponding declaration group. When true, macro expansion will adjust + /// expanded code such as + /// + /// extension C: P { + /// nonisolated func f() { } + /// } + /// + /// to + /// + /// extension C: nonisolated P { + /// nonisolated func f() { } + /// } + /// + /// This operation defaults to `true`. Macros can implement it to return + /// `false` to prevent this adjustment to the macro-expanded code. + static var inferNonisolatedConformances: Bool { get } +} + +extension Macro { + /// Default implementation of the Macro protocol's + /// `inferNonisolatedConformances` that returns `true`. + public static var inferNonisolatedConformances: Bool { true } } diff --git a/Tests/SwiftSyntaxMacroExpansionTest/ExtensionMacroTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/ExtensionMacroTests.swift index eb6dae3773f..57cab4cb0ba 100644 --- a/Tests/SwiftSyntaxMacroExpansionTest/ExtensionMacroTests.swift +++ b/Tests/SwiftSyntaxMacroExpansionTest/ExtensionMacroTests.swift @@ -251,6 +251,41 @@ final class ExtensionMacroTests: XCTestCase { indentationWidth: indentationWidth ) } + + func testNonisolatedConformances() { + struct NonisolatedConformanceMacro: ExtensionMacro { + static func expansion( + of node: AttributeSyntax, + attachedTo declaration: some DeclGroupSyntax, + providingExtensionsOf type: some TypeSyntaxProtocol, + conformingTo protocols: [TypeSyntax], + in context: some MacroExpansionContext + ) throws -> [ExtensionDeclSyntax] { + return [ + (""" + extension \(type): P { + nonisolated func f() { } + } + """ as DeclSyntax).cast(ExtensionDeclSyntax.self) + ] + } + } + + assertMacroExpansion( + "@NonisolatedConformance struct Foo {}", + expandedSource: """ + struct Foo {} + + extension Foo: nonisolated P { + nonisolated func f() { + } + } + """, + macros: [ + "NonisolatedConformance": NonisolatedConformanceMacro.self + ] + ) + } } fileprivate struct SendableExtensionMacro: ExtensionMacro { diff --git a/Tests/SwiftSyntaxMacroExpansionTest/InferNonisolatedConformancesTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/InferNonisolatedConformancesTests.swift new file mode 100644 index 00000000000..d416bcfdb34 --- /dev/null +++ b/Tests/SwiftSyntaxMacroExpansionTest/InferNonisolatedConformancesTests.swift @@ -0,0 +1,161 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import SwiftSyntax +@_spi(Testing) import SwiftSyntaxMacroExpansion +import XCTest +import _SwiftSyntaxTestSupport + +final class InferNonisolatedConformancesTests: XCTestCase { + func testAddNonisolatedSimple() { + assertInferNonisolatedConformances( + """ + struct MyStruct: P, Q { + nonisolated func f() { } + } + """, + """ + struct MyStruct: nonisolated P, nonisolated Q { + nonisolated func f() { } + } + """ + ) + } + + func testAddNonisolatedNested() { + assertInferNonisolatedConformances( + """ + extension MyStruct: P, Q { + nonisolated func f() { } + + actor Inner: nonisolated R { + nonisolated var value: Int { 0 } + } + } + """, + """ + extension MyStruct: nonisolated P, nonisolated Q { + nonisolated func f() { } + + actor Inner: nonisolated R { + nonisolated var value: Int { 0 } + } + } + """ + ) + } + + func testNoAddWhenNoNonIsolated() { + assertInferNonisolatedConformances( + """ + struct MyStruct: P, Q { + func f() { } + } + """, + """ + struct MyStruct: P, Q { + func f() { } + } + """ + ) + } + + func testNoAddWhenExplicit() { + assertInferNonisolatedConformances( + """ + struct MyStruct: P, nonisolated Q, @MainActor R, S { + nonisolated func f() { } + } + """, + """ + struct MyStruct: nonisolated P, nonisolated Q, @MainActor R, nonisolated S { + nonisolated func f() { } + } + """ + ) + } + + func testNoAddHeuristics() { + assertInferNonisolatedConformances( + """ + class MyClass: P, Q { + nonisolated func f() { } + } + """, + """ + class MyClass: P, nonisolated Q { + nonisolated func f() { } + } + """ + ) + } + + func testNoAddRawType() { + assertInferNonisolatedConformances( + """ + enum MyEnum: Int, Q { + nonisolated func f() { } + } + """, + """ + enum MyEnum: Int, nonisolated Q { + nonisolated func f() { } + } + """ + ) + + assertInferNonisolatedConformances( + """ + enum MyEnum: P, Q { + nonisolated func f() { } + } + """, + """ + enum MyEnum: nonisolated P, nonisolated Q { + nonisolated func f() { } + } + """ + ) + } + + func testNoAddProtocol() { + assertInferNonisolatedConformances( + """ + protocol MyProtocol: P, Q { + nonisolated func f() { } + } + """, + """ + protocol MyProtocol: P, Q { + nonisolated func f() { } + } + """ + ) + } +} + +public func assertInferNonisolatedConformances( + _ original: DeclSyntax, + _ expected: DeclSyntax, + additionalInfo: @autoclosure () -> String? = nil, + file: StaticString = #filePath, + line: UInt = #line +) { + let result = original.inferNonisolatedConformances() + + assertStringsEqualWithDiff( + result.description, + expected.description, + file: file, + line: line + ) +}