Skip to content

Commit

Permalink
Merge pull request #33 from orlandos-nl/jo/sendable-support
Browse files Browse the repository at this point in the history
Add Sendable support
  • Loading branch information
Joannis authored Oct 24, 2024
2 parents 9c7cc44 + bc4091d commit 3cd2795
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/swift.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
swift-version: ["5.7", "5.8"]
swift-version: ["5.7", "5.8", "5.9", "5.10", "6.0"]
steps:
- uses: actions/checkout@v3

Expand Down
12 changes: 9 additions & 3 deletions Sources/DNSClient/Client.swift
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import NIO
import NIOConcurrencyHelpers

/// A DNS client that can be used to send queries to a DNS server.
/// The client is thread-safe and can be used from multiple threads. Supports both UDP and TCP, and multicast DNS. This client is not a full implementation of the DNS protocol, but only supports the most common queries. If you need more advanced features, you should use the `sendQuery` method to send a custom query.
/// This client is not a full resolver, and does not support caching, recursion, or other advanced features. If you need a full resolver, use the `Resolver` class.
public final class DNSClient: Resolver {
public final class DNSClient: Resolver, Sendable {
let dnsDecoder: DNSDecoder
let channel: Channel
let primaryAddress: SocketAddress
internal var isMulticast = false
private let isMulticastBox = NIOLockedValueBox(false)
internal var isMulticast: Bool {
get { isMulticastBox.withLockedValue { $0 } }
set { isMulticastBox.withLockedValue { $0 = newValue } }
}

var loop: EventLoop {
return channel.eventLoop
}
// Each query has an ID to keep track of which response belongs to which query
var messageID: UInt16 = 0
let messageID: Atomic<UInt16> = Atomic(value: 0)

internal init(channel: Channel, address: SocketAddress, decoder: DNSDecoder) {
self.channel = channel
Expand Down
5 changes: 3 additions & 2 deletions Sources/DNSClient/DNSClient+Query.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import NIO
import NIOConcurrencyHelpers

extension DNSClient {
/// Request A records
Expand Down Expand Up @@ -79,7 +80,7 @@ extension DNSClient {
/// - returns: A future with the response message
public func sendQuery(forHost address: String, type: DNSResourceType, additionalOptions: MessageOptions? = nil) -> EventLoopFuture<Message> {
channel.eventLoop.flatSubmit {
self.messageID = self.messageID &+ 1
let messageID = self.messageID.add(1)

var options: MessageOptions = [.standardQuery]

Expand All @@ -91,7 +92,7 @@ extension DNSClient {
options.insert(additionalOptions)
}

let header = DNSMessageHeader(id: self.messageID, options: options, questionCount: 1, answerCount: 0, authorityCount: 0, additionalRecordCount: 0)
let header = DNSMessageHeader(id: messageID, options: options, questionCount: 1, answerCount: 0, authorityCount: 0, additionalRecordCount: 0)
let labels = address.split(separator: ".").map(String.init).map(DNSLabel.init)
let question = QuestionSection(labels: labels, type: type, questionClass: .internet)
let message = Message(header: header, questions: [question], answers: [], authorities: [], additionalData: [])
Expand Down
76 changes: 39 additions & 37 deletions Sources/DNSClient/DNSDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,53 @@ final class EnvelopeInboundChannel: ChannelInboundHandler {
}
}

final class DNSDecoder: ChannelInboundHandler {
public final class DNSDecoder: ChannelInboundHandler {
let group: EventLoopGroup
var messageCache = [UInt16: SentQuery]()
var clients = [ObjectIdentifier: DNSClient]()
weak var mainClient: DNSClient?

init(group: EventLoopGroup) {
public init(group: EventLoopGroup) {
self.group = group
}

public typealias InboundIn = ByteBuffer
public typealias OutboundOut = Never

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let envelope = self.unwrapInboundIn(data)
var buffer = envelope
let message: Message

guard let header = buffer.readHeader() else {
context.fireErrorCaught(ProtocolError())
do {
message = try Self.parse(unwrapInboundIn(data))
} catch {
context.fireErrorCaught(error)
return
}

if !message.header.options.contains(.answer) {
return
}

guard let query = messageCache[message.header.id] else {
return
}

query.promise.succeed(message)
messageCache[message.header.id] = nil
}

public static func parse(_ buffer: ByteBuffer) throws -> Message {
var buffer = buffer

guard let header = buffer.readHeader() else {
throw ProtocolError()
}

var questions = [QuestionSection]()

for _ in 0..<header.questionCount {
guard let question = buffer.readQuestion() else {
context.fireErrorCaught(ProtocolError())
return
throw ProtocolError()
}

questions.append(question)
Expand All @@ -59,37 +78,20 @@ final class DNSDecoder: ChannelInboundHandler {
return records
}

do {
let answers = try resourceRecords(count: header.answerCount)
let authorities = try resourceRecords(count: header.authorityCount)
let additionalData = try resourceRecords(count: header.additionalRecordCount)

let message = Message(
header: header,
questions: questions,
answers: answers,
authorities: authorities,
additionalData: additionalData
)

if !header.options.contains(.answer) {
return
}

guard let query = messageCache[header.id] else {
return
}

query.promise.succeed(message)
messageCache[header.id] = nil
} catch {
messageCache[header.id]?.promise.fail(error)
messageCache[header.id] = nil
context.fireErrorCaught(error)
}
let answers = try resourceRecords(count: header.answerCount)
let authorities = try resourceRecords(count: header.authorityCount)
let additionalData = try resourceRecords(count: header.additionalRecordCount)

return Message(
header: header,
questions: questions,
answers: answers,
authorities: authorities,
additionalData: additionalData
)
}

func errorCaught(context ctx: ChannelHandlerContext, error: Error) {
public func errorCaught(context ctx: ChannelHandlerContext, error: Error) {
for query in self.messageCache.values {
query.promise.fail(error)
}
Expand Down
42 changes: 33 additions & 9 deletions Sources/DNSClient/DNSEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,36 @@ final class UInt16FrameEncoder: MessageToByteEncoder {
}
}

final class DNSEncoder: ChannelOutboundHandler {
typealias OutboundIn = Message
typealias OutboundOut = ByteBuffer
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
public final class DNSEncoder: ChannelOutboundHandler {
public typealias OutboundIn = Message
public typealias OutboundOut = ByteBuffer

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let message = unwrapOutboundIn(data)
let data = DNSEncoder.encodeMessage(message, allocator: context.channel.allocator)
do {
var labelIndices = [String: UInt16]()
let data = try DNSEncoder.encodeMessage(
message,
allocator: context.channel.allocator,
labelIndices: &labelIndices
)

context.write(wrapOutboundOut(data), promise: promise)
context.write(wrapOutboundOut(data), promise: promise)
} catch {
context.fireErrorCaught(error)
}
}

static func encodeMessage(_ message: Message, allocator: ByteBufferAllocator) -> ByteBuffer {
public static func encodeMessage(
_ message: Message,
allocator: ByteBufferAllocator,
labelIndices: inout [String: UInt16]
) throws -> ByteBuffer {
var out = allocator.buffer(capacity: 512)

let header = message.header

out.write(header)
var labelIndices = [String : UInt16]()

for question in message.questions {
out.writeCompressedLabels(question.labels, labelIndices: &labelIndices)
Expand All @@ -73,6 +85,18 @@ final class DNSEncoder: ChannelOutboundHandler {
out.writeInteger(question.questionClass.rawValue, endianness: .big)
}

for answer in message.answers {
try out.writeAnyRecord(answer, labelIndices: &labelIndices)
}

for authority in message.authorities {
try out.writeAnyRecord(authority, labelIndices: &labelIndices)
}

for additionalData in message.additionalData {
try out.writeAnyRecord(additionalData, labelIndices: &labelIndices)
}

return out
}
}
38 changes: 38 additions & 0 deletions Sources/DNSClient/Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,44 @@ extension ByteBuffer {
return QuestionSection(labels: labels, type: type, questionClass: dataClass)
}

mutating func writeRecord<RecordType: DNSResource>(
_ record: ResourceRecord<RecordType>,
labelIndices: inout [String: UInt16]
) throws {
writeCompressedLabels(record.domainName, labelIndices: &labelIndices)
writeInteger(record.dataType)
writeInteger(record.dataClass)
writeInteger(record.ttl)

try writeLengthPrefixed(as: UInt16.self) { buffer in
record.resource.write(into: &buffer, labelIndices: &labelIndices)
}
}

mutating func writeAnyRecord(
_ record: Record,
labelIndices: inout [String: UInt16]
) throws {
switch record {
case .aaaa(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .a(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .txt(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .cname(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .srv(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .mx(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .ptr(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
case .other(let resourceRecord):
try writeRecord(resourceRecord, labelIndices: &labelIndices)
}
}

mutating func readRecord() -> Record? {
guard
let labels = readLabels(),
Expand Down
Loading

0 comments on commit 3cd2795

Please sign in to comment.