diff --git a/Package.swift b/Package.swift index db38820..63028e8 100644 --- a/Package.swift +++ b/Package.swift @@ -47,7 +47,7 @@ let dependencies: [Package.Dependency] = [ ), .package( url: "https://github.com/grpc/grpc-swift-protobuf.git", - exact: "1.0.0-beta.3" + branch: "main" ), .package( url: "https://github.com/apple/swift-protobuf.git", diff --git a/Sources/GRPCInterceptors/ClientTracingInterceptor.swift b/Sources/GRPCInterceptors/ClientTracingInterceptor.swift deleted file mode 100644 index a4c85c9..0000000 --- a/Sources/GRPCInterceptors/ClientTracingInterceptor.swift +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright 2024, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -public import GRPCCore -internal import Tracing - -/// A client interceptor that injects tracing information into the request. -/// -/// The tracing information is taken from the current `ServiceContext`, and injected into the request's -/// metadata. It will then be picked up by the server-side ``ServerTracingInterceptor``. -/// -/// For more information, refer to the documentation for `swift-distributed-tracing`. -public struct ClientTracingInterceptor: ClientInterceptor { - private let injector: ClientRequestInjector - private let emitEventOnEachWrite: Bool - - /// Create a new instance of a ``ClientTracingInterceptor``. - /// - /// - Parameter emitEventOnEachWrite: If `true`, each request part sent and response part - /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response - /// start and end will be recorded as events. - public init(emitEventOnEachWrite: Bool = false) { - self.injector = ClientRequestInjector() - self.emitEventOnEachWrite = emitEventOnEachWrite - } - - /// This interceptor will inject as the request's metadata whatever `ServiceContext` key-value pairs - /// have been made available by the tracing implementation bootstrapped in your application. - /// - /// Which key-value pairs are injected will depend on the specific tracing implementation - /// that has been configured when bootstrapping `swift-distributed-tracing` in your application. - public func intercept( - request: StreamingClientRequest, - context: ClientContext, - next: ( - StreamingClientRequest, - ClientContext - ) async throws -> StreamingClientResponse - ) async throws -> StreamingClientResponse where Input: Sendable, Output: Sendable { - var request = request - let tracer = InstrumentationSystem.tracer - let serviceContext = ServiceContext.current ?? .topLevel - - tracer.inject( - serviceContext, - into: &request.metadata, - using: self.injector - ) - - return try await tracer.withSpan( - context.descriptor.fullyQualifiedMethod, - context: serviceContext, - ofKind: .client - ) { span in - span.addEvent("Request started") - - if self.emitEventOnEachWrite { - let wrappedProducer = request.producer - request.producer = { writer in - let eventEmittingWriter = HookedWriter( - wrapping: writer, - beforeEachWrite: { - span.addEvent("Sending request part") - }, - afterEachWrite: { - span.addEvent("Sent request part") - } - ) - - do { - try await wrappedProducer(RPCWriter(wrapping: eventEmittingWriter)) - } catch { - span.addEvent("Error encountered") - throw error - } - - span.addEvent("Request end") - } - } - - var response: StreamingClientResponse - do { - response = try await next(request, context) - } catch { - span.addEvent("Error encountered") - throw error - } - - switch response.accepted { - case .success(var success): - if self.emitEventOnEachWrite { - let onEachPartRecordingSequence = success.bodyParts.map { element in - span.addEvent("Received response part") - return element - } - let onFinishRecordingSequence = OnFinishAsyncSequence( - wrapping: onEachPartRecordingSequence - ) { - span.addEvent("Received response end") - } - success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence) - response.accepted = .success(success) - } else { - let onFinishRecordingSequence = OnFinishAsyncSequence(wrapping: success.bodyParts) { - span.addEvent("Received response end") - } - success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence) - response.accepted = .success(success) - } - case .failure: - span.addEvent("Received error response") - } - - return response - } - } -} - -/// An injector responsible for injecting the required instrumentation keys from the `ServiceContext` into -/// the request metadata. -struct ClientRequestInjector: Instrumentation.Injector { - typealias Carrier = Metadata - - func inject(_ value: String, forKey key: String, into carrier: inout Carrier) { - carrier.addString(value, forKey: key) - } -} diff --git a/Sources/GRPCInterceptors/HookedAsyncSequence.swift b/Sources/GRPCInterceptors/HookedAsyncSequence.swift new file mode 100644 index 0000000..fdbe100 --- /dev/null +++ b/Sources/GRPCInterceptors/HookedAsyncSequence.swift @@ -0,0 +1,80 @@ +/* + * Copyright 2025, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +internal struct HookedRPCAsyncSequence: AsyncSequence, Sendable +where Wrapped.Element: Sendable { + private let wrapped: Wrapped + + private let forEachElement: @Sendable (Wrapped.Element) -> Void + private let onFinish: @Sendable ((any Error)?) -> Void + + init( + wrapping sequence: Wrapped, + forEachElement: @escaping @Sendable (Wrapped.Element) -> Void, + onFinish: @escaping @Sendable ((any Error)?) -> Void + ) { + self.wrapped = sequence + self.forEachElement = forEachElement + self.onFinish = onFinish + } + + func makeAsyncIterator() -> HookedAsyncIterator { + HookedAsyncIterator( + self.wrapped, + forEachElement: self.forEachElement, + onFinish: self.onFinish + ) + } + + struct HookedAsyncIterator: AsyncIteratorProtocol { + typealias Element = Wrapped.Element + + private var wrapped: Wrapped.AsyncIterator + private let forEachElement: @Sendable (Wrapped.Element) -> Void + private let onFinish: @Sendable ((any Error)?) -> Void + + init( + _ sequence: Wrapped, + forEachElement: @escaping @Sendable (Wrapped.Element) -> Void, + onFinish: @escaping @Sendable ((any Error)?) -> Void + ) { + self.wrapped = sequence.makeAsyncIterator() + self.forEachElement = forEachElement + self.onFinish = onFinish + } + + mutating func next( + isolation actor: isolated (any Actor)? + ) async throws(Wrapped.Failure) -> Wrapped.Element? { + do { + if let element = try await self.wrapped.next(isolation: actor) { + self.forEachElement(element) + return element + } else { + self.onFinish(nil) + return nil + } + } catch { + self.onFinish(error) + throw error + } + } + + mutating func next() async throws -> Wrapped.Element? { + try await self.next(isolation: nil) + } + } +} diff --git a/Sources/GRPCInterceptors/HookedWriter.swift b/Sources/GRPCInterceptors/HookedWriter.swift index b93c491..1baec5b 100644 --- a/Sources/GRPCInterceptors/HookedWriter.swift +++ b/Sources/GRPCInterceptors/HookedWriter.swift @@ -18,27 +18,22 @@ internal import Tracing struct HookedWriter: RPCWriterProtocol { private let writer: any RPCWriterProtocol - private let beforeEachWrite: @Sendable () -> Void private let afterEachWrite: @Sendable () -> Void init( wrapping other: some RPCWriterProtocol, - beforeEachWrite: @Sendable @escaping () -> Void, afterEachWrite: @Sendable @escaping () -> Void ) { self.writer = other - self.beforeEachWrite = beforeEachWrite self.afterEachWrite = afterEachWrite } func write(_ element: Element) async throws { - self.beforeEachWrite() try await self.writer.write(element) self.afterEachWrite() } func write(contentsOf elements: some Sequence) async throws { - self.beforeEachWrite() try await self.writer.write(contentsOf: elements) self.afterEachWrite() } diff --git a/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift b/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift deleted file mode 100644 index f7a8f64..0000000 --- a/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2024, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -struct OnFinishAsyncSequence: AsyncSequence, Sendable { - private let _makeAsyncIterator: @Sendable () -> AsyncIterator - - init( - wrapping other: S, - onFinish: @escaping @Sendable () -> Void - ) where S.Element == Element, S: Sendable { - self._makeAsyncIterator = { - AsyncIterator(wrapping: other.makeAsyncIterator(), onFinish: onFinish) - } - } - - func makeAsyncIterator() -> AsyncIterator { - self._makeAsyncIterator() - } - - struct AsyncIterator: AsyncIteratorProtocol { - private var iterator: any AsyncIteratorProtocol - private var onFinish: (@Sendable () -> Void)? - - fileprivate init( - wrapping other: Iterator, - onFinish: @escaping @Sendable () -> Void - ) where Iterator: AsyncIteratorProtocol, Iterator.Element == Element { - self.iterator = other - self.onFinish = onFinish - } - - mutating func next() async throws -> Element? { - let elem = try await self.iterator.next() - - if elem == nil { - self.onFinish?() - self.onFinish = nil - } - - return elem as? Element - } - } -} diff --git a/Sources/GRPCInterceptors/Tracing/ClientOTelTracingInterceptor.swift b/Sources/GRPCInterceptors/Tracing/ClientOTelTracingInterceptor.swift new file mode 100644 index 0000000..b51834c --- /dev/null +++ b/Sources/GRPCInterceptors/Tracing/ClientOTelTracingInterceptor.swift @@ -0,0 +1,206 @@ +/* + * Copyright 2024-2025, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +public import GRPCCore +internal import Synchronization +package import Tracing + +/// A client interceptor that injects tracing information into the request. +/// +/// The tracing information is taken from the current `ServiceContext`, and injected into the request's +/// metadata. It will then be picked up by the server-side ``ServerTracingInterceptor``. +/// +/// For more information, refer to the documentation for `swift-distributed-tracing`. +public struct ClientOTelTracingInterceptor: ClientInterceptor { + private let injector: ClientRequestInjector + private let traceEachMessage: Bool + private var serverHostname: String + private var networkTransportMethod: String + + /// Create a new instance of a ``ClientOTelTracingInterceptor``. + /// + /// - Parameters: + /// - severHostname: The hostname of the RPC server. This will be the value for the `server.address` attribute in spans. + /// - networkTransportMethod: The transport in use (e.g. "tcp", "unix"). This will be the value for the + /// `network.transport` attribute in spans. + /// - traceEachMessage: If `true`, each request part sent and response part received will be recorded as a separate + /// event in a tracing span. Otherwise, only the request/response start and end will be recorded as events. + public init( + serverHostname: String, + networkTransportMethod: String, + traceEachMessage: Bool = true + ) { + self.injector = ClientRequestInjector() + self.serverHostname = serverHostname + self.networkTransportMethod = networkTransportMethod + self.traceEachMessage = traceEachMessage + } + + /// This interceptor will inject as the request's metadata whatever `ServiceContext` key-value pairs + /// have been made available by the tracing implementation bootstrapped in your application. + /// + /// Which key-value pairs are injected will depend on the specific tracing implementation + /// that has been configured when bootstrapping `swift-distributed-tracing` in your application. + /// + /// It will also inject all required and recommended span and event attributes, and set span status, as defined by OpenTelemetry's + /// documentation on: + /// - https://opentelemetry.io/docs/specs/semconv/rpc/rpc-spans + /// - https://opentelemetry.io/docs/specs/semconv/rpc/grpc/ + public func intercept( + request: StreamingClientRequest, + context: ClientContext, + next: ( + StreamingClientRequest, + ClientContext + ) async throws -> StreamingClientResponse + ) async throws -> StreamingClientResponse where Input: Sendable, Output: Sendable { + try await self.intercept( + tracer: InstrumentationSystem.tracer, + request: request, + context: context, + next: next + ) + } + + /// Same as ``intercept(request:context:next:)``, but allows specifying a `Tracer` for testing purposes. + package func intercept( + tracer: any Tracer, + request: StreamingClientRequest, + context: ClientContext, + next: ( + StreamingClientRequest, + ClientContext + ) async throws -> StreamingClientResponse + ) async throws -> StreamingClientResponse where Input: Sendable, Output: Sendable { + var request = request + let serviceContext = ServiceContext.current ?? .topLevel + + tracer.inject( + serviceContext, + into: &request.metadata, + using: self.injector + ) + + return try await tracer.withSpan( + context.descriptor.fullyQualifiedMethod, + context: serviceContext, + ofKind: .client + ) { span in + span.setOTelClientSpanGRPCAttributes( + context: context, + serverHostname: self.serverHostname, + networkTransportMethod: self.networkTransportMethod + ) + + if self.traceEachMessage { + let wrappedProducer = request.producer + request.producer = { writer in + let messageSentCounter = Atomic(1) + let eventEmittingWriter = HookedWriter( + wrapping: writer, + afterEachWrite: { + var event = SpanEvent(name: "rpc.message") + event.attributes[GRPCTracingKeys.rpcMessageType] = "SENT" + event.attributes[GRPCTracingKeys.rpcMessageID] = + messageSentCounter + .wrappingAdd(1, ordering: .sequentiallyConsistent) + .oldValue + span.addEvent(event) + } + ) + try await wrappedProducer(RPCWriter(wrapping: eventEmittingWriter)) + } + } + + var response = try await next(request, context) + switch response.accepted { + case .success(var success): + let hookedSequence: + HookedRPCAsyncSequence< + RPCAsyncSequence.Contents.BodyPart, any Error> + > + if self.traceEachMessage { + let messageReceivedCounter = Atomic(1) + hookedSequence = HookedRPCAsyncSequence(wrapping: success.bodyParts) { _ in + var event = SpanEvent(name: "rpc.message") + event.attributes[GRPCTracingKeys.rpcMessageType] = "RECEIVED" + event.attributes[GRPCTracingKeys.rpcMessageID] = + messageReceivedCounter + .wrappingAdd(1, ordering: .sequentiallyConsistent) + .oldValue + span.addEvent(event) + } onFinish: { error in + if let error { + if let errorCode = error.grpcErrorCode { + span.attributes[GRPCTracingKeys.grpcStatusCode] = errorCode.rawValue + } + span.setStatus(SpanStatus(code: .error)) + span.recordError(error) + } else { + span.attributes[GRPCTracingKeys.grpcStatusCode] = 0 + } + } + } else { + hookedSequence = HookedRPCAsyncSequence(wrapping: success.bodyParts) { _ in + // Nothing to do if traceEachMessage is false + } onFinish: { error in + if let error { + if let errorCode = error.grpcErrorCode { + span.attributes[GRPCTracingKeys.grpcStatusCode] = errorCode.rawValue + } + span.setStatus(SpanStatus(code: .error)) + span.recordError(error) + } else { + span.attributes[GRPCTracingKeys.grpcStatusCode] = 0 + } + } + } + + success.bodyParts = RPCAsyncSequence(wrapping: hookedSequence) + response.accepted = .success(success) + + case .failure(let error): + span.attributes[GRPCTracingKeys.grpcStatusCode] = error.code.rawValue + span.setStatus(SpanStatus(code: .error)) + span.recordError(error) + } + + return response + } + } +} + +/// An injector responsible for injecting the required instrumentation keys from the `ServiceContext` into +/// the request metadata. +struct ClientRequestInjector: Instrumentation.Injector { + typealias Carrier = Metadata + + func inject(_ value: String, forKey key: String, into carrier: inout Carrier) { + carrier.addString(value, forKey: key) + } +} + +extension Error { + var grpcErrorCode: RPCError.Code? { + if let rpcError = self as? RPCError { + return rpcError.code + } else if let rpcError = self as? any RPCErrorConvertible { + return rpcError.rpcErrorCode + } else { + return nil + } + } +} diff --git a/Sources/GRPCInterceptors/ServerTracingInterceptor.swift b/Sources/GRPCInterceptors/Tracing/ServerTracingInterceptor.swift similarity index 87% rename from Sources/GRPCInterceptors/ServerTracingInterceptor.swift rename to Sources/GRPCInterceptors/Tracing/ServerTracingInterceptor.swift index 413752d..eeaea16 100644 --- a/Sources/GRPCInterceptors/ServerTracingInterceptor.swift +++ b/Sources/GRPCInterceptors/Tracing/ServerTracingInterceptor.swift @@ -90,43 +90,28 @@ public struct ServerTracingInterceptor: ServerInterceptor { success.producer = { writer in let eventEmittingWriter = HookedWriter( wrapping: writer, - beforeEachWrite: { - span.addEvent("Sending response part") - }, afterEachWrite: { span.addEvent("Sent response part") } ) - let wrappedResult: Metadata - do { - wrappedResult = try await wrappedProducer( - RPCWriter(wrapping: eventEmittingWriter) - ) - } catch { - span.addEvent("Error encountered") - throw error - } + let wrappedResult = try await wrappedProducer( + RPCWriter(wrapping: eventEmittingWriter) + ) span.addEvent("Sent response end") return wrappedResult } } else { success.producer = { writer in - let wrappedResult: Metadata - do { - wrappedResult = try await wrappedProducer(writer) - } catch { - span.addEvent("Error encountered") - throw error - } - + let wrappedResult = try await wrappedProducer(writer) span.addEvent("Sent response end") return wrappedResult } } response = .init(accepted: .success(success)) + case .failure: span.addEvent("Sent error response") } diff --git a/Sources/GRPCInterceptors/Tracing/SpanAttributes+GRPCTracingKeys.swift b/Sources/GRPCInterceptors/Tracing/SpanAttributes+GRPCTracingKeys.swift new file mode 100644 index 0000000..186195a --- /dev/null +++ b/Sources/GRPCInterceptors/Tracing/SpanAttributes+GRPCTracingKeys.swift @@ -0,0 +1,130 @@ +/* + * Copyright 2025, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +internal import GRPCCore +internal import Tracing + +enum GRPCTracingKeys { + static let rpcSystem = "rpc.system" + static let rpcMethod = "rpc.method" + static let rpcService = "rpc.service" + static let rpcMessageID = "rpc.message.id" + static let rpcMessageType = "rpc.message.type" + static let grpcStatusCode = "rpc.grpc.status_code" + + static let serverAddress = "server.address" + static let serverPort = "server.port" + + static let clientAddress = "client.address" + static let clientPort = "client.port" + + static let networkTransport = "network.transport" + static let networkType = "network.type" + static let networkPeerAddress = "network.peer.address" + static let networkPeerPort = "network.peer.port" +} + +extension Span { + // See: https://opentelemetry.io/docs/specs/semconv/rpc/rpc-spans/ + func setOTelClientSpanGRPCAttributes( + context: ClientContext, + serverHostname: String, + networkTransportMethod: String + ) { + self.attributes[GRPCTracingKeys.rpcSystem] = "grpc" + self.attributes[GRPCTracingKeys.serverAddress] = serverHostname + self.attributes[GRPCTracingKeys.networkTransport] = networkTransportMethod + self.attributes[GRPCTracingKeys.rpcService] = context.descriptor.service.fullyQualifiedService + self.attributes[GRPCTracingKeys.rpcMethod] = context.descriptor.method + + // Set server address information + switch PeerAddress(context.remotePeer) { + case .ipv4(let address, let port): + self.attributes[GRPCTracingKeys.networkType] = "ipv4" + self.attributes[GRPCTracingKeys.networkPeerAddress] = address + self.attributes[GRPCTracingKeys.networkPeerPort] = port + self.attributes[GRPCTracingKeys.serverPort] = port + + case .ipv6(let address, let port): + self.attributes[GRPCTracingKeys.networkType] = "ipv6" + self.attributes[GRPCTracingKeys.networkPeerAddress] = address + self.attributes[GRPCTracingKeys.networkPeerPort] = port + self.attributes[GRPCTracingKeys.serverPort] = port + + case .unixDomainSocket(let path): + self.attributes[GRPCTracingKeys.networkPeerAddress] = path + + case .none: + // We don't recognise this address format, so don't populate any fields. + () + } + } +} + +package enum PeerAddress: Equatable { + case ipv4(address: String, port: Int?) + case ipv6(address: String, port: Int?) + case unixDomainSocket(path: String) + + package init?(_ address: String) { + // We expect this address to be of one of these formats: + // - ipv4:: for ipv4 addresses + // - ipv6:[]: for ipv6 addresses + // - unix: for UNIX domain sockets + + // First get the first component so that we know what type of address we're dealing with + let addressComponents = address.split(separator: ":", maxSplits: 1) + + guard addressComponents.count > 1 else { + // This is some unexpected/unknown format + return nil + } + + // Check what type the transport is... + switch addressComponents[0] { + case "ipv4": + let ipv4AddressComponents = addressComponents[1].split(separator: ":") + if ipv4AddressComponents.count == 2, let port = Int(ipv4AddressComponents[1]) { + self = .ipv4(address: String(ipv4AddressComponents[0]), port: port) + } else { + return nil + } + + case "ipv6": + if addressComponents[1].first == "[" { + // At this point, we are looking at an address with format: [
]: + // We drop the first character ('[') and split by ']:' to keep two components: the address + // and the port. + let ipv6AddressComponents = addressComponents[1].dropFirst().split(separator: "]:") + if ipv6AddressComponents.count == 2, let port = Int(ipv6AddressComponents[1]) { + self = .ipv6(address: String(ipv6AddressComponents[0]), port: port) + } else { + return nil + } + } else { + return nil + } + + case "unix": + // Whatever comes after "unix:" is the + self = .unixDomainSocket(path: String(addressComponents[1])) + + default: + // This is some unexpected/unknown format + return nil + } + } +} diff --git a/Tests/GRPCInterceptorsTests/PeerAddressTests.swift b/Tests/GRPCInterceptorsTests/PeerAddressTests.swift new file mode 100644 index 0000000..dc4249e --- /dev/null +++ b/Tests/GRPCInterceptorsTests/PeerAddressTests.swift @@ -0,0 +1,59 @@ +/* + * Copyright 2025, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCInterceptors +import Testing + +@Suite("PeerAddress tests") +struct PeerAddressTests { + @Test("IPv4 addresses are correctly parsed") + func testIPv4() { + let address = PeerAddress("ipv4:10.1.2.80:567") + #expect(address == .ipv4(address: "10.1.2.80", port: 567)) + } + + @Test("IPv6 addresses are correctly parsed") + func testIPv6() { + let address = PeerAddress("ipv6:[2001::130F:::09C0:876A:130B]:1234") + #expect(address == .ipv6(address: "2001::130F:::09C0:876A:130B", port: 1234)) + } + + @Test("Unix domain sockets are correctly parsed") + func testUDS() { + let address = PeerAddress("unix:some-path") + #expect(address == .unixDomainSocket(path: "some-path")) + } + + @Test( + "Unrecognised addresses return nil", + arguments: [ + "", + "unknown", + "in-process:1234", + "ipv4:", + "ipv4:1234", + "ipv6:", + "ipv6:123:456:789:123", + "ipv6:123:456:789]:123", + "ipv6:123:456:789]", + "unix", + ] + ) + func testOther(address: String) { + let address = PeerAddress(address) + #expect(address == nil) + } +} diff --git a/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift b/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift index cf847ca..cead2ab 100644 --- a/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift +++ b/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift @@ -1,5 +1,5 @@ /* - * Copyright 2024, gRPC Authors All rights reserved. + * Copyright 2024-2025, gRPC Authors All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,45 +15,67 @@ */ import GRPCCore +import GRPCInterceptors +import Testing import Tracing import XCTest -@testable import GRPCInterceptors +@Suite("OTel Tracing Client Interceptor Tests") +struct OTelTracingClientInterceptorTests { + private let tracer: TestTracer -final class TracingInterceptorTests: XCTestCase { - override class func setUp() { - InstrumentationSystem.bootstrap(TestTracer()) + init() { + self.tracer = TestTracer() } - func testClientInterceptor() async throws { + // - MARK: Client Interceptor Tests + + @Test( + "Successful RPC is recorded correctly", + arguments: OTelTracingInterceptorTestAddressType.allCases + ) + func testSuccessfulRPC(addressType: OTelTracingInterceptorTestAddressType) async throws { var serviceContext = ServiceContext.topLevel let traceIDString = UUID().uuidString - let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false) - let (stream, continuation) = AsyncStream.makeStream() + let (requestStream, requestStreamContinuation) = AsyncStream.makeStream() serviceContext.traceID = traceIDString // FIXME: use 'ServiceContext.withValue(serviceContext)' // // This is blocked on: https://github.com/apple/swift-service-context/pull/46 try await ServiceContext.$current.withValue(serviceContext) { + let interceptor = ClientOTelTracingInterceptor( + serverHostname: "someserver.com", + networkTransportMethod: "tcp", + traceEachMessage: false + ) let methodDescriptor = MethodDescriptor( - fullyQualifiedService: "TracingInterceptorTests", - method: "testClientInterceptor" + fullyQualifiedService: "OTelTracingClientInterceptorTests", + method: "testSuccessfulRPC" + ) + let testValues = self.getTestValues( + addressType: addressType, + methodDescriptor: methodDescriptor ) let response = try await interceptor.intercept( + tracer: self.tracer, request: .init(producer: { writer in try await writer.write(contentsOf: ["request1"]) try await writer.write(contentsOf: ["request2"]) }), - context: ClientContext(descriptor: methodDescriptor, remotePeer: "", localPeer: "") + context: ClientContext( + descriptor: methodDescriptor, + remotePeer: testValues.remotePeerAddress, + localPeer: testValues.localPeerAddress + ) ) { stream, _ in // Assert the metadata contains the injected context key-value. - XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"]) + #expect(stream.metadata == ["trace-id": "\(traceIDString)"]) - // Write into the response stream to make sure the `producer` closure's called. - let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation)) + // Write into the request stream to make sure the `producer` closure's called. + let writer = RPCWriter(wrapping: TestWriter(streamContinuation: requestStreamContinuation)) try await stream.producer(writer) - continuation.finish() + requestStreamContinuation.finish() return .init( metadata: [], @@ -66,62 +88,63 @@ final class TracingInterceptorTests: XCTestCase { ) } - var streamIterator = stream.makeAsyncIterator() - var element = await streamIterator.next() - XCTAssertEqual(element, "request1") - element = await streamIterator.next() - XCTAssertEqual(element, "request2") - element = await streamIterator.next() - XCTAssertNil(element) - - var messages = response.messages.makeAsyncIterator() - var message = try await messages.next() - XCTAssertEqual(message, ["response"]) - message = try await messages.next() - XCTAssertNil(message) - - let tracer = InstrumentationSystem.tracer as! TestTracer - XCTAssertEqual( - tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { - $0.name - }, - [ - "Request started", - "Received response end", - ] - ) + await assertStreamContentsEqual(["request1", "request2"], requestStream) + try await assertStreamContentsEqual([["response"]], response.messages) + + assertTestSpanComponents(forMethod: methodDescriptor) { events in + // No events are recorded + #expect(events.isEmpty) + } assertAttributes: { attributes in + #expect(attributes == testValues.expectedSpanAttributes) + } assertStatus: { status in + #expect(status == nil) + } assertErrors: { errors in + #expect(errors == []) + } } } - func testClientInterceptorAllEventsRecorded() async throws { - let methodDescriptor = MethodDescriptor( - fullyQualifiedService: "TracingInterceptorTests", - method: "testClientInterceptorAllEventsRecorded" - ) + @Test("All events are recorded when traceEachMessage is true") + func testAllEventsRecorded() async throws { var serviceContext = ServiceContext.topLevel let traceIDString = UUID().uuidString - let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true) - let (stream, continuation) = AsyncStream.makeStream() + + let (requestStream, requestStreamContinuation) = AsyncStream.makeStream() serviceContext.traceID = traceIDString // FIXME: use 'ServiceContext.withValue(serviceContext)' // // This is blocked on: https://github.com/apple/swift-service-context/pull/46 try await ServiceContext.$current.withValue(serviceContext) { + let interceptor = ClientOTelTracingInterceptor( + serverHostname: "someserver.com", + networkTransportMethod: "tcp", + traceEachMessage: true + ) + let methodDescriptor = MethodDescriptor( + fullyQualifiedService: "OTelTracingClientInterceptorTests", + method: "testAllEventsRecorded" + ) + let testValues = self.getTestValues(addressType: .ipv4, methodDescriptor: methodDescriptor) let response = try await interceptor.intercept( + tracer: self.tracer, request: .init(producer: { writer in try await writer.write(contentsOf: ["request1"]) try await writer.write(contentsOf: ["request2"]) }), - context: ClientContext(descriptor: methodDescriptor, remotePeer: "", localPeer: "") + context: ClientContext( + descriptor: methodDescriptor, + remotePeer: testValues.remotePeerAddress, + localPeer: testValues.localPeerAddress + ) ) { stream, _ in // Assert the metadata contains the injected context key-value. - XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"]) + #expect(stream.metadata == ["trace-id": "\(traceIDString)"]) - // Write into the response stream to make sure the `producer` closure's called. - let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation)) + // Write into the request stream to make sure the `producer` closure's called. + let writer = RPCWriter(wrapping: TestWriter(streamContinuation: requestStreamContinuation)) try await stream.producer(writer) - continuation.finish() + requestStreamContinuation.finish() return .init( metadata: [], @@ -134,44 +157,370 @@ final class TracingInterceptorTests: XCTestCase { ) } - var streamIterator = stream.makeAsyncIterator() - var element = await streamIterator.next() - XCTAssertEqual(element, "request1") - element = await streamIterator.next() - XCTAssertEqual(element, "request2") - element = await streamIterator.next() - XCTAssertNil(element) - - var messages = response.messages.makeAsyncIterator() - var message = try await messages.next() - XCTAssertEqual(message, ["response"]) - message = try await messages.next() - XCTAssertNil(message) - - let tracer = InstrumentationSystem.tracer as! TestTracer - XCTAssertEqual( - tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { - $0.name - }, - [ - "Request started", - // Recorded when `request1` is sent - "Sending request part", - "Sent request part", - // Recorded when `request2` is sent - "Sending request part", - "Sent request part", - // Recorded after all request parts have been sent - "Request end", - // Recorded when receiving response part - "Received response part", - // Recorded at end of response - "Received response end", + await assertStreamContentsEqual(["request1", "request2"], requestStream) + try await assertStreamContentsEqual([["response"]], response.messages) + + assertTestSpanComponents(forMethod: methodDescriptor) { events in + #expect( + events == [ + // Recorded when `request1` is sent + TestSpanEvent("rpc.message", ["rpc.message.type": "SENT", "rpc.message.id": 1]), + // Recorded when `request2` is sent + TestSpanEvent("rpc.message", ["rpc.message.type": "SENT", "rpc.message.id": 2]), + // Recorded when receiving response part + TestSpanEvent("rpc.message", ["rpc.message.type": "RECEIVED", "rpc.message.id": 1]), + ] + ) + } assertAttributes: { attributes in + #expect(attributes == testValues.expectedSpanAttributes) + } assertStatus: { status in + #expect(status == nil) + } assertErrors: { errors in + #expect(errors == []) + } + } + } + + @Test("RPC that throws is correctly recorded") + func testThrowingRPC() async throws { + var serviceContext = ServiceContext.topLevel + let traceIDString = UUID().uuidString + serviceContext.traceID = traceIDString + + // FIXME: use 'ServiceContext.withValue(serviceContext)' + // + // This is blocked on: https://github.com/apple/swift-service-context/pull/46 + await ServiceContext.$current.withValue(serviceContext) { + let interceptor = ClientOTelTracingInterceptor( + serverHostname: "someserver.com", + networkTransportMethod: "tcp", + traceEachMessage: false + ) + let methodDescriptor = MethodDescriptor( + fullyQualifiedService: "OTelTracingClientInterceptorTests", + method: "testThrowingRPC" + ) + do { + let _: StreamingClientResponse = try await interceptor.intercept( + tracer: self.tracer, + request: StreamingClientRequest(of: Void.self, producer: { writer in }), + context: ClientContext( + descriptor: methodDescriptor, + remotePeer: "ipv4:10.1.2.80:567", + localPeer: "ipv4:10.1.2.80:123" + ) + ) { stream, _ in + // Assert the metadata contains the injected context key-value. + #expect(stream.metadata == ["trace-id": "\(traceIDString)"]) + // Now throw + throw TracingInterceptorTestError.testError + } + Issue.record("Should have thrown") + } catch { + assertTestSpanComponents(forMethod: methodDescriptor) { events in + // No events are recorded + #expect(events.isEmpty) + } assertAttributes: { attributes in + // The attributes should not contain a grpc status code, as the request was never even sent. + #expect( + attributes == [ + "rpc.system": "grpc", + "rpc.method": .string(methodDescriptor.method), + "rpc.service": .string(methodDescriptor.service.fullyQualifiedService), + "server.address": "someserver.com", + "server.port": 567, + "network.peer.address": "10.1.2.80", + "network.peer.port": 567, + "network.transport": "tcp", + "network.type": "ipv4", + ] + ) + } assertStatus: { status in + #expect(status == nil) + } assertErrors: { errors in + #expect(errors == [.testError]) + } + } + } + } + + @Test("RPC with a failure response is correctly recorded") + func testFailedRPC() async throws { + var serviceContext = ServiceContext.topLevel + let traceIDString = UUID().uuidString + let (requestStream, requestStreamContinuation) = AsyncStream.makeStream() + serviceContext.traceID = traceIDString + + // FIXME: use 'ServiceContext.withValue(serviceContext)' + // + // This is blocked on: https://github.com/apple/swift-service-context/pull/46 + try await ServiceContext.$current.withValue(serviceContext) { + let interceptor = ClientOTelTracingInterceptor( + serverHostname: "someserver.com", + networkTransportMethod: "tcp", + traceEachMessage: false + ) + let methodDescriptor = MethodDescriptor( + fullyQualifiedService: "OTelTracingClientInterceptorTests", + method: "testFailedRPC" + ) + let response: StreamingClientResponse = try await interceptor.intercept( + tracer: self.tracer, + request: .init(producer: { writer in + try await writer.write(contentsOf: ["request"]) + }), + context: ClientContext( + descriptor: methodDescriptor, + remotePeer: "ipv4:10.1.2.80:567", + localPeer: "ipv4:10.1.2.80:123" + ) + ) { stream, _ in + // Assert the metadata contains the injected context key-value. + #expect(stream.metadata == ["trace-id": "\(traceIDString)"]) + + // Write into the request stream to make sure the `producer` closure's called. + let writer = RPCWriter(wrapping: TestWriter(streamContinuation: requestStreamContinuation)) + try await stream.producer(writer) + requestStreamContinuation.finish() + + return .init(error: RPCError(code: .unavailable, message: "This should not work")) + } + + await assertStreamContentsEqual(["request"], requestStream) + + switch response.accepted { + case .success: + Issue.record("Response should have failed") + return + + case .failure(let failure): + #expect(failure == RPCError(code: .unavailable, message: "This should not work")) + } + + assertTestSpanComponents(forMethod: methodDescriptor) { events in + // No events are recorded + #expect(events.isEmpty) + } assertAttributes: { attributes in + #expect( + attributes == [ + "rpc.system": "grpc", + "rpc.method": .string(methodDescriptor.method), + "rpc.service": .string(methodDescriptor.service.fullyQualifiedService), + "rpc.grpc.status_code": 14, // this is unavailable's raw code + "server.address": "someserver.com", + "server.port": 567, + "network.peer.address": "10.1.2.80", + "network.peer.port": 567, + "network.transport": "tcp", + "network.type": "ipv4", + ] + ) + } assertStatus: { status in + #expect(status == .some(.init(code: .error))) + } assertErrors: { errors in + #expect(errors.count == 1) + } + } + } + + @Test("Accepted server-streaming RPC that throws error during response is correctly recorded") + func testAcceptedRPCWithError() async throws { + var serviceContext = ServiceContext.topLevel + let traceIDString = UUID().uuidString + serviceContext.traceID = traceIDString + + // FIXME: use 'ServiceContext.withValue(serviceContext)' + // + // This is blocked on: https://github.com/apple/swift-service-context/pull/46 + try await ServiceContext.$current.withValue(serviceContext) { + let interceptor = ClientOTelTracingInterceptor( + serverHostname: "someserver.com", + networkTransportMethod: "tcp", + traceEachMessage: false + ) + let methodDescriptor = MethodDescriptor( + fullyQualifiedService: "OTelTracingClientInterceptorTests", + method: "testAcceptedRPCWithError" + ) + let response: StreamingClientResponse = try await interceptor.intercept( + tracer: self.tracer, + request: .init(producer: { writer in + try await writer.write(contentsOf: ["request"]) + }), + context: ClientContext( + descriptor: methodDescriptor, + remotePeer: "ipv4:10.1.2.80:567", + localPeer: "ipv4:10.1.2.80:123" + ) + ) { stream, _ in + // Assert the metadata contains the injected context key-value. + #expect(stream.metadata == ["trace-id": "\(traceIDString)"]) + + return .init( + metadata: [], + bodyParts: RPCAsyncSequence( + wrapping: AsyncThrowingStream { + $0.finish(throwing: RPCError(code: .unavailable, message: "This should be thrown")) + } + ) + ) + } + + switch response.accepted { + case .success(let success): + do { + for try await _ in success.bodyParts { + // We don't care about any received messages here - we're not even writing any. + } + } catch { + #expect( + error as? RPCError + == RPCError( + code: .unavailable, + message: "This should be thrown" + ) + ) + } + + case .failure: + Issue.record("Response should have been successful") + return + } + + assertTestSpanComponents(forMethod: methodDescriptor) { events in + // No events are recorded + #expect(events.isEmpty) + } assertAttributes: { attributes in + #expect( + attributes == [ + "rpc.system": "grpc", + "rpc.method": .string(methodDescriptor.method), + "rpc.service": .string(methodDescriptor.service.fullyQualifiedService), + "rpc.grpc.status_code": 14, // this is unavailable's raw code + "server.address": "someserver.com", + "server.port": 567, + "network.peer.address": "10.1.2.80", + "network.peer.port": 567, + "network.transport": "tcp", + "network.type": "ipv4", + ] + ) + } assertStatus: { status in + #expect(status == .some(.init(code: .error))) + } assertErrors: { errors in + #expect(errors.count == 1) + } + } + } + + // - MARK: Utilities + + private func getTestValues( + addressType: OTelTracingInterceptorTestAddressType, + methodDescriptor: MethodDescriptor + ) -> OTelTracingInterceptorTestCaseValues { + switch addressType { + case .ipv4: + return OTelTracingInterceptorTestCaseValues( + remotePeerAddress: "ipv4:10.1.2.80:567", + localPeerAddress: "ipv4:10.1.2.80:123", + expectedSpanAttributes: [ + "rpc.system": "grpc", + "rpc.method": .string(methodDescriptor.method), + "rpc.service": .string(methodDescriptor.service.fullyQualifiedService), + "rpc.grpc.status_code": 0, + "server.address": "someserver.com", + "server.port": 567, + "network.peer.address": "10.1.2.80", + "network.peer.port": 567, + "network.transport": "tcp", + "network.type": "ipv4", + ] + ) + + case .ipv6: + return OTelTracingInterceptorTestCaseValues( + remotePeerAddress: "ipv6:[2001::130F:::09C0:876A:130B]:1234", + localPeerAddress: "ipv6:[ff06:0:0:0:0:0:0:c3]:5678", + expectedSpanAttributes: [ + "rpc.system": "grpc", + "rpc.method": .string(methodDescriptor.method), + "rpc.service": .string(methodDescriptor.service.fullyQualifiedService), + "rpc.grpc.status_code": 0, + "server.address": "someserver.com", + "server.port": 1234, + "network.peer.address": "2001::130F:::09C0:876A:130B", + "network.peer.port": 1234, + "network.transport": "tcp", + "network.type": "ipv6", ] ) + + case .uds: + return OTelTracingInterceptorTestCaseValues( + remotePeerAddress: "unix:some-path", + localPeerAddress: "unix:some-path", + expectedSpanAttributes: [ + "rpc.system": "grpc", + "rpc.method": .string(methodDescriptor.method), + "rpc.service": .string(methodDescriptor.service.fullyQualifiedService), + "rpc.grpc.status_code": 0, + "server.address": "someserver.com", + "network.peer.address": "some-path", + "network.transport": "tcp", + ] + ) + } + } + + private func getTestSpanForMethod(_ methodDescriptor: MethodDescriptor) -> TestSpan { + return self.tracer.getSpan(ofOperation: methodDescriptor.fullyQualifiedMethod)! + } + + private func assertTestSpanComponents( + forMethod method: MethodDescriptor, + assertEvents: ([TestSpanEvent]) -> Void, + assertAttributes: (SpanAttributes) -> Void, + assertStatus: (SpanStatus?) -> Void, + assertErrors: ([TracingInterceptorTestError]) -> Void + ) { + let span = self.getTestSpanForMethod(method) + assertEvents(span.events.map({ TestSpanEvent($0) })) + assertAttributes(span.attributes) + assertStatus(span.status) + assertErrors(span.errors) + } + + private func assertStreamContentsEqual( + _ array: [T], + _ stream: any AsyncSequence + ) async throws { + var streamElements = [T]() + for try await element in stream { + streamElements.append(element) } + #expect(streamElements == array) + } + + private func assertStreamContentsEqual( + _ array: [T], + _ stream: any AsyncSequence + ) async { + var streamElements = [T]() + for await element in stream { + streamElements.append(element) + } + #expect(streamElements == array) + } +} + +final class TracingInterceptorTests: XCTestCase { + override class func setUp() { + InstrumentationSystem.bootstrap(TestTracer()) } + // - MARK: Server Interceptor Tests + func testServerInterceptorErrorResponse() async throws { let methodDescriptor = MethodDescriptor( fullyQualifiedService: "TracingInterceptorTests", @@ -194,7 +543,7 @@ final class TracingInterceptorTests: XCTestCase { let tracer = InstrumentationSystem.tracer as! TestTracer XCTAssertEqual( - tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + tracer.getEventsForTestSpan(ofOperation: methodDescriptor.fullyQualifiedMethod).map { $0.name }, [ @@ -264,7 +613,7 @@ final class TracingInterceptorTests: XCTestCase { let tracer = InstrumentationSystem.tracer as! TestTracer XCTAssertEqual( - tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + tracer.getEventsForTestSpan(ofOperation: methodDescriptor.fullyQualifiedMethod).map { $0.name }, [ @@ -334,21 +683,74 @@ final class TracingInterceptorTests: XCTestCase { let tracer = InstrumentationSystem.tracer as! TestTracer XCTAssertEqual( - tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + tracer.getEventsForTestSpan(ofOperation: methodDescriptor.fullyQualifiedMethod).map { $0.name }, [ "Received request start", "Received request end", // Recorded when `response1` is sent - "Sending response part", "Sent response part", // Recorded when `response2` is sent - "Sending response part", "Sent response part", // Recorded when we're done sending response "Sent response end", ] ) } + + private func getTestSpanForMethod(_ methodDescriptor: MethodDescriptor) -> TestSpan { + let tracer = InstrumentationSystem.tracer as! TestTracer + return tracer.getSpan(ofOperation: methodDescriptor.fullyQualifiedMethod)! + } + + private func assertTestSpanComponents( + forMethod method: MethodDescriptor, + assertEvents: ([TestSpanEvent]) -> Void, + assertAttributes: (SpanAttributes) -> Void, + assertStatus: (SpanStatus?) -> Void, + assertErrors: ([TracingInterceptorTestError]) -> Void + ) { + let span = self.getTestSpanForMethod(method) + assertEvents(span.events.map({ TestSpanEvent($0) })) + assertAttributes(span.attributes) + assertStatus(span.status) + assertErrors(span.errors) + } + + private func assertStreamContentsEqual( + _ array: [T], + _ stream: any AsyncSequence + ) async throws { + var streamElements = [T]() + for try await element in stream { + streamElements.append(element) + } + XCTAssertEqual(streamElements, array) + } + + private func assertStreamContentsEqual( + _ array: [T], + _ stream: any AsyncSequence + ) async { + var streamElements = [T]() + for await element in stream { + streamElements.append(element) + } + XCTAssertEqual(streamElements, array) + } +} + +enum OTelTracingInterceptorTestAddressType { + case ipv4 + case ipv6 + case uds + + static let allCases: [Self] = [.ipv4, .ipv6, .uds] +} + +struct OTelTracingInterceptorTestCaseValues { + let remotePeerAddress: String + let localPeerAddress: String + let expectedSpanAttributes: SpanAttributes } diff --git a/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift b/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift index a5af7df..a6b64b3 100644 --- a/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift +++ b/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift @@ -23,9 +23,12 @@ final class TestTracer: Tracer { private let testSpans: Mutex<[String: TestSpan]> = .init([:]) - func getEventsForTestSpan(ofOperationName operationName: String) -> [SpanEvent] { - let span = self.testSpans.withLock({ $0[operationName] }) - return span?.events ?? [] + func getSpan(ofOperation operationName: String) -> TestSpan? { + self.testSpans.withLock { $0[operationName] } + } + + func getEventsForTestSpan(ofOperation operationName: String) -> [SpanEvent] { + self.getSpan(ofOperation: operationName)?.events ?? [] } func extract( @@ -75,6 +78,7 @@ final class TestSpan: Span, Sendable { var attributes: Tracing.SpanAttributes var status: Tracing.SpanStatus? var events: [Tracing.SpanEvent] = [] + var errors: [TracingInterceptorTestError] } private let state: Mutex @@ -98,13 +102,26 @@ final class TestSpan: Span, Sendable { self.state.withLock { $0.events } } + var status: SpanStatus? { + self.state.withLock { $0.status } + } + + var errors: [TracingInterceptorTestError] { + self.state.withLock { $0.errors } + } + init( context: ServiceContextModule.ServiceContext, operationName: String, attributes: Tracing.SpanAttributes = [:], isRecording: Bool = true ) { - let state = State(context: context, operationName: operationName, attributes: attributes) + let state = State( + context: context, + operationName: operationName, + attributes: attributes, + errors: [] + ) self.state = Mutex(state) self.isRecording = isRecording } @@ -122,12 +139,8 @@ final class TestSpan: Span, Sendable { attributes: Tracing.SpanAttributes, at instant: @autoclosure () -> Instant ) where Instant: Tracing.TracerInstant { - self.setStatus( - .init( - code: .error, - message: "Error: \(error), attributes: \(attributes), at instant: \(instant())" - ) - ) + // For the purposes of these tests, we don't really care about the error being thrown + self.state.withLock { $0.errors.append(TracingInterceptorTestError.testError) } } func addLink(_ link: Tracing.SpanLink) { @@ -137,7 +150,7 @@ final class TestSpan: Span, Sendable { } func end(at instant: @autoclosure () -> Instant) where Instant: Tracing.TracerInstant { - self.setStatus(.init(code: .ok, message: "Ended at instant: \(instant())")) + // no-op } } @@ -192,3 +205,33 @@ struct TestWriter: RPCWriterProtocol { } } } + +struct TestSpanEvent: Equatable, CustomDebugStringConvertible { + var name: String + var attributes: SpanAttributes + + var debugDescription: String { + var attributesDescription = "" + self.attributes.forEach { key, value in + attributesDescription += " \(key): \(value)," + } + + return """ + (name: \(self.name), attributes: [\(attributesDescription)]) + """ + } + + init(_ name: String, _ attributes: SpanAttributes) { + self.name = name + self.attributes = attributes + } + + init(_ spanEvent: SpanEvent) { + self.name = spanEvent.name + self.attributes = spanEvent.attributes + } +} + +enum TracingInterceptorTestError: Error, Equatable { + case testError +} diff --git a/dev/license-check.sh b/dev/license-check.sh index ce643ac..a0ffd42 100755 --- a/dev/license-check.sh +++ b/dev/license-check.sh @@ -88,7 +88,7 @@ check_copyright_headers() { actual_sha=$(head -n "$((drop_first + expected_lines))" "$filename" \ | tail -n "$expected_lines" \ - | sed -e 's/201[56789]-20[12][0-9]/YEARS/' -e 's/20[12][0-9]/YEARS/' \ + | sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/20[12][0-9]/YEARS/' \ | shasum \ | awk '{print $1}')