diff --git a/Sources/FormattingHelpers.swift b/Sources/FormattingHelpers.swift index 8517e6bb9..9dce078a7 100644 --- a/Sources/FormattingHelpers.swift +++ b/Sources/FormattingHelpers.swift @@ -1196,18 +1196,34 @@ extension Formatter { /// Whether or not the code block starting at the given `.startOfScope` token /// has a single statement. This makes it eligible to be used with implicit return. - func blockBodyHasSingleStatement(atStartOfScope startOfScopeIndex: Int) -> Bool { + func blockBodyHasSingleStatement( + atStartOfScope startOfScopeIndex: Int, + includingConditionalStatements: Bool = true + ) -> Bool { guard let endOfScopeIndex = endOfScope(at: startOfScopeIndex) else { return false } let startOfBody = self.startOfBody(atStartOfScope: startOfScopeIndex) + // The body should contain exactly one expression. + // We can confirm this by parsing the body with `parseExpressionRange`, + // and checking that the token after that expression is just the end of the scope. + guard var firstTokenInBody = index(of: .nonSpaceOrCommentOrLinebreak, after: startOfBody) else { + return false + } + + // Skip over any optional `return` keyword + if tokens[firstTokenInBody] == .keyword("return") { + guard let tokenAfterReturnKeyword = index(of: .nonSpaceOrCommentOrLinebreak, after: firstTokenInBody) else { return false } + firstTokenInBody = tokenAfterReturnKeyword + } + // In Swift 5.9+, if and switch statements where each branch is a single statement // are also considered single statements if options.swiftVersion >= "5.9", - let firstTokenInBody = index(of: .nonSpaceOrCommentOrLinebreak, after: startOfBody), + includingConditionalStatements, let conditionalBranches = conditionalBranches(at: firstTokenInBody) { let isSingleStatement = conditionalBranches.allSatisfy { branch in - blockBodyHasSingleStatement(atStartOfScope: branch.startOfBranch) + blockBodyHasSingleStatement(atStartOfScope: branch.startOfBranch, includingConditionalStatements: true) } let endOfStatement = conditionalBranches.last?.endOfBranch ?? firstTokenInBody @@ -1216,18 +1232,6 @@ extension Formatter { return isSingleStatement && isOnlyStatement } - // The body should contain exactly one expression. - // We can confirm this by parsing the body with `parseExpressionRange`, - // and checking that the token after that expression is just the end of the scope. - guard var firstTokenInBody = index(of: .nonSpaceOrCommentOrLinebreak, after: startOfBody) else { - return false - } - - // Skip over any optional `return` keyword - if tokens[firstTokenInBody] == .keyword("return") { - guard let tokenAfterReturnKeyword = index(of: .nonSpaceOrCommentOrLinebreak, after: firstTokenInBody) else { return false } - firstTokenInBody = tokenAfterReturnKeyword - } guard let expressionRange = parseExpressionRange(startingAt: firstTokenInBody), let nextIndexAfterExpression = index(of: .nonSpaceOrCommentOrLinebreak, after: expressionRange.upperBound) else { @@ -1257,6 +1261,26 @@ extension Formatter { /// If `index` is the start of an `if` or `switch` statement, /// finds and returns all of the statement branches. func conditionalBranches(at index: Int) -> [ConditionalBranch]? { + // Skip over any `try`, `try?`, `try!`, or `await` token, + // which are valid before an if/switch expression. + if tokens[index] == .keyword("await"), + let nextToken = self.index(of: .nonSpaceOrCommentOrLinebreak, after: index) + { + return conditionalBranches(at: nextToken) + } + + if tokens[index] == .keyword("try"), + let tokenAfterTry = self.index(of: .nonSpaceOrCommentOrLinebreak, after: index) + { + if tokens[tokenAfterTry] == .operator("!", .postfix) || tokens[tokenAfterTry] == .operator("?", .postfix), + let tokenAfterOperator = self.index(of: .nonSpaceOrCommentOrLinebreak, after: tokenAfterTry) + { + return conditionalBranches(at: tokenAfterOperator) + } else { + return conditionalBranches(at: tokenAfterTry) + } + } + if tokens[index] == .keyword("if") { return ifStatementBranches(at: index) } else if tokens[index] == .keyword("switch") { diff --git a/Sources/ParsingHelpers.swift b/Sources/ParsingHelpers.swift index 20fe73126..ac134aa77 100644 --- a/Sources/ParsingHelpers.swift +++ b/Sources/ParsingHelpers.swift @@ -1314,9 +1314,7 @@ extension Formatter { nextTokenAfterTry = nextTokenAfterTryOperator } - if let startOfFollowingExpressionIndex = index(of: .nonSpaceOrCommentOrLinebreak, after: nextTokenAfterTry), - let followingExpression = parseExpressionRange(startingAt: startOfFollowingExpressionIndex) - { + if let followingExpression = parseExpressionRange(startingAt: nextTokenAfterTry) { return startIndex ... followingExpression.upperBound } } diff --git a/Sources/Rules.swift b/Sources/Rules.swift index a21141fb8..1d82c867e 100644 --- a/Sources/Rules.swift +++ b/Sources/Rules.swift @@ -6200,7 +6200,17 @@ public struct _FormatRules { // because removing them could break the build. formatter.index(of: .nonSpaceOrCommentOrLinebreak, after: closureStartIndex) != closureEndIndex { - guard formatter.blockBodyHasSingleStatement(atStartOfScope: closureStartIndex) else { + /// Whether or not this closure has a single, simple expression in its body. + /// These closures can always be simplified / removed regardless of the context. + let hasSingleSimpleExpression = formatter.blockBodyHasSingleStatement(atStartOfScope: closureStartIndex, includingConditionalStatements: false) + + /// Whether or not this closure has a single if/switch expression in its body. + /// Since if/switch expressions are only valid in the `return` position or as an `=` assignment, + /// these closures can only sometimes be simplified / removed. + let hasSingleConditionalExpression = !hasSingleSimpleExpression && + formatter.blockBodyHasSingleStatement(atStartOfScope: closureStartIndex, includingConditionalStatements: true) + + guard hasSingleSimpleExpression || hasSingleConditionalExpression else { return } @@ -6244,6 +6254,43 @@ public struct _FormatRules { startIndex = prevIndex } + // Since if/switch expressions are only valid in the `return` position or as an `=` assignment, + // these closures can only sometimes be simplified / removed. + if hasSingleConditionalExpression { + // Find the `{` start of scope or `=` and verify that the entire following expression consists of just this closure. + var startOfScopeContainingClosure = formatter.startOfScope(at: startIndex) + var assignmentBeforeClosure = formatter.index(of: .operator("=", .infix), before: startIndex) + + let potentialStartOfExpressionContainingClosure: Int? + switch (startOfScopeContainingClosure, assignmentBeforeClosure) { + case (nil, nil): + potentialStartOfExpressionContainingClosure = nil + case (.some(let startOfScope), nil): + potentialStartOfExpressionContainingClosure = startOfScope + case (nil, let .some(assignmentBeforeClosure)): + potentialStartOfExpressionContainingClosure = assignmentBeforeClosure + case let (.some(startOfScope), .some(assignmentBeforeClosure)): + potentialStartOfExpressionContainingClosure = max(startOfScope, assignmentBeforeClosure) + } + + if let potentialStartOfExpressionContainingClosure = potentialStartOfExpressionContainingClosure { + guard var startOfExpressionIndex = formatter.index(of: .nonSpaceOrCommentOrLinebreak, after: potentialStartOfExpressionContainingClosure) + else { return } + + // Skip over any return token that may be present + if formatter.tokens[startOfExpressionIndex] == .keyword("return"), + let nextTokenIndex = formatter.index(of: .nonSpaceOrCommentOrLinebreak, after: startOfExpressionIndex) + { + startOfExpressionIndex = nextTokenIndex + } + + // Parse the expression and require that entire expression is simply just this closure. + guard let expressionRange = formatter.parseExpressionRange(startingAt: startOfExpressionIndex), + expressionRange == startIndex ... closureCallCloseParenIndex + else { return } + } + } + // If the closure is a property with an explicit `Void` type, // we can't remove the closure since the build would break // if the method is `@discardableResult` @@ -6276,14 +6323,13 @@ public struct _FormatRules { closureEndIndex -= 1 } - // remove the { }() tokens + // remove the trailing }() tokens, working backwards to not invalidate any indices formatter.removeToken(at: closureCallCloseParenIndex) formatter.removeToken(at: closureCallOpenParenIndex) formatter.removeToken(at: closureEndIndex) - formatter.removeTokens(in: startIndex ... closureStartIndex) - // Remove the initial return token, and any trailing space, if present - if let returnIndex = formatter.index(of: .nonSpaceOrCommentOrLinebreak, after: closureStartIndex - 1), + // Remove the initial return token, and any trailing space, if present. + if let returnIndex = formatter.index(of: .nonSpaceOrCommentOrLinebreak, after: closureStartIndex), formatter.token(at: returnIndex)?.string == "return" { while formatter.token(at: returnIndex + 1)?.isSpaceOrLinebreak == true { @@ -6292,6 +6338,9 @@ public struct _FormatRules { formatter.removeToken(at: returnIndex) } + + // Finally, remove then open `{` token + formatter.removeTokens(in: startIndex ... closureStartIndex) } } } diff --git a/Tests/ParsingHelpersTests.swift b/Tests/ParsingHelpersTests.swift index 21e89af14..387a05da8 100644 --- a/Tests/ParsingHelpersTests.swift +++ b/Tests/ParsingHelpersTests.swift @@ -1822,6 +1822,11 @@ class ParsingHelpersTests: XCTestCase { XCTAssert(isSingleExpression(#"#selector(Foo.bar)"#)) XCTAssert(isSingleExpression(#"#macro()"#)) XCTAssert(isSingleExpression(#"#outerMacro(12, #innerMacro(34), "some text")"#)) + XCTAssert(isSingleExpression(#"try { try printThrows(foo) }()"#)) + XCTAssert(isSingleExpression(#"try! { try printThrows(foo) }()"#)) + XCTAssert(isSingleExpression(#"try? { try printThrows(foo) }()"#)) + XCTAssert(isSingleExpression(#"await { await printAsync(foo) }()"#)) + XCTAssert(isSingleExpression(#"try await { try await printAsyncThrows(foo) }()"#)) XCTAssert(isSingleExpression(""" foo diff --git a/Tests/RulesTests+Redundancy.swift b/Tests/RulesTests+Redundancy.swift index 0cbb64506..8a8b71a4d 100644 --- a/Tests/RulesTests+Redundancy.swift +++ b/Tests/RulesTests+Redundancy.swift @@ -7724,6 +7724,69 @@ class RedundancyTests: RulesTests { testFormatting(for: input, output, rule: FormatRules.redundantClosure) } + func testRedundantClosureWithExplicitReturn() { + let input = """ + let foo = { return "Foo" }() + + let bar = { + return if Bool.random() { + "Bar" + } else { + "Baaz" + } + }() + """ + + let output = """ + let foo = "Foo" + + let bar = if Bool.random() { + "Bar" + } else { + "Baaz" + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["indent"]) + } + + func testRedundantClosureWithExplicitReturn2() { + let input = """ + func foo() -> String { + methodCall() + return { return "Foo" }() + } + + func bar() -> String { + methodCall() + return { "Bar" }() + } + + func baaz() -> String { + { return "Baaz" }() + } + """ + + let output = """ + func foo() -> String { + methodCall() + return "Foo" + } + + func bar() -> String { + methodCall() + return "Bar" + } + + func baaz() -> String { + "Baaz" + } + """ + + testFormatting(for: input, output, rule: FormatRules.redundantClosure, exclude: ["redundantReturn"]) + } + func testKeepsClosureThatIsNotCalled() { let input = """ let foo = { "Foo" } @@ -8106,6 +8169,307 @@ class RedundancyTests: RulesTests { exclude: ["redundantReturn", "indent"]) } + func testRedundantClosureDoesntLeaveInvalidSwitchExpressionInOperatorChain() { + let input = """ + private enum Format { + case uint8 + case uint16 + + var bytes: Int { + { + switch self { + case .uint8: UInt8.bitWidth + case .uint16: UInt16.bitWidth + } + }() / 8 + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, rule: FormatRules.redundantClosure, options: options) + } + + func testRedundantClosureDoesntLeaveInvalidIfExpressionInOperatorChain() { + let input = """ + private enum Format { + case uint8 + case uint16 + + var bytes: Int { + { + if self == .uint8 { + UInt8.bitWidth + } else { + UInt16.bitWidth + } + }() / 8 + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, rule: FormatRules.redundantClosure, options: options) + } + + func testRedundantClosureDoesntLeaveInvalidIfExpressionInOperatorChain2() { + let input = """ + private enum Format { + case uint8 + case uint16 + + var bytes: Int { + 8 / { + if self == .uint8 { + UInt8.bitWidth + } else { + UInt16.bitWidth + } + }() + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, rule: FormatRules.redundantClosure, options: options) + } + + func testRedundantClosureDoesntLeaveInvalidIfExpressionInOperatorChain3() { + let input = """ + private enum Format { + case uint8 + case uint16 + + var bytes = 8 / { + if self == .uint8 { + UInt8.bitWidth + } else { + UInt16.bitWidth + } + }() + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, rule: FormatRules.redundantClosure, options: options) + } + + func testRedundantClosureDoesRemoveRedundantIfStatementClosureInAssignmentPosition() { + let input = """ + private enum Format { + case uint8 + case uint16 + + var bytes = { + if self == .uint8 { + UInt8.bitWidth + } else { + UInt16.bitWidth + } + }() + } + """ + + let output = """ + private enum Format { + case uint8 + case uint16 + + var bytes = if self == .uint8 { + UInt8.bitWidth + } else { + UInt16.bitWidth + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["indent"]) + } + + func testRedundantClosureDoesntLeaveInvalidSwitchExpressionInArray() { + let input = """ + private func constraint() -> [Int] { + [ + 1, + 2, + { + if Bool.random() { + 3 + } else { + 4 + } + }(), + ] + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, rule: FormatRules.redundantClosure, options: options) + } + + func testRedundantClosureRemovesClosureAsReturnTryStatement() { + let input = """ + func method() -> Int { + return { + return try! if Bool.random() { + randomThrows() + } else { + randomThrows() + } + }() + } + """ + + let output = """ + func method() -> Int { + return try! if Bool.random() { + randomThrows() + } else { + randomThrows() + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["redundantReturn", "indent"]) + } + + func testRedundantClosureRemovesClosureAsReturnTryStatement2() { + let input = """ + func method() throws -> Int { + return try { + return try if Bool.random() { + randomThrows() + } else { + randomThrows() + } + }() + } + """ + + let output = """ + func method() throws -> Int { + return try if Bool.random() { + randomThrows() + } else { + randomThrows() + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["redundantReturn", "indent"]) + } + + func testRedundantClosureRemovesClosureAsReturnTryStatement3() { + let input = """ + func method() async throws -> Int { + return try await { + return try await if Bool.random() { + randomAsyncThrows() + } else { + randomAsyncThrows() + } + }() + } + """ + + let output = """ + func method() async throws -> Int { + return try await if Bool.random() { + randomAsyncThrows() + } else { + randomAsyncThrows() + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["redundantReturn", "indent"]) + } + + func testRedundantClosureRemovesClosureAsReturnTryStatement4() { + let input = """ + func method() -> Int { + return { + return try! if Bool.random() { + randomThrows() + } else { + randomThrows() + } + }() + } + """ + + let output = """ + func method() -> Int { + return try! if Bool.random() { + randomThrows() + } else { + randomThrows() + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["redundantReturn", "indent"]) + } + + func testRedundantClosureRemovesClosureAsReturnStatement() { + let input = """ + func method() -> Int { + return { + return if Bool.random() { + 42 + } else { + 43 + } + }() + } + """ + + let output = """ + func method() -> Int { + return if Bool.random() { + 42 + } else { + 43 + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["redundantReturn", "indent"]) + } + + func testRedundantClosureRemovesClosureAsImplicitReturnStatement() { + let input = """ + func method() -> Int { + { + if Bool.random() { + 42 + } else { + 43 + } + }() + } + """ + + let output = """ + func method() -> Int { + if Bool.random() { + 42 + } else { + 43 + } + } + """ + + let options = FormatOptions(swiftVersion: "5.9") + testFormatting(for: input, output, rule: FormatRules.redundantClosure, options: options, exclude: ["indent"]) + } + // MARK: Redundant optional binding func testRemovesRedundantOptionalBindingsInSwift5_7() {