Skip to content

Commit

Permalink
Merge pull request #24 from orlandos-nl/jo-minimal-mdns-support
Browse files Browse the repository at this point in the history
Minimal mDNS support
  • Loading branch information
Joannis authored Dec 10, 2022
2 parents 1d1eb5f + 677f035 commit c4a2290
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 9 deletions.
1 change: 1 addition & 0 deletions Sources/DNSClient/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public final class DNSClient: Resolver {
let dnsDecoder: DNSDecoder
let channel: Channel
let primaryAddress: SocketAddress
internal var isMulticast = false
var loop: EventLoop {
return channel.eventLoop
}
Expand Down
14 changes: 14 additions & 0 deletions Sources/DNSClient/DNSClient+Connect.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ extension DNSClient {
}
}

public static func connectMulticast(on group: EventLoopGroup) -> EventLoopFuture<DNSClient> {
do {
let address = try SocketAddress(ipAddress: "224.0.0.251", port: 5353)

return connect(on: group, config: [address]).flatMap { client in
let channel = client.channel as! MulticastChannel
client.isMulticast = true
return channel.joinGroup(address).map { client }
}
} catch {
return group.next().makeFailedFuture(UnableToParseConfig())
}
}

/// Connect to the dns server using TCP
///
/// - parameters:
Expand Down
7 changes: 6 additions & 1 deletion Sources/DNSClient/DNSClient+Query.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ extension DNSClient {
channel.eventLoop.flatSubmit {
self.messageID = self.messageID &+ 1

var options: MessageOptions = [.standardQuery, .recursionDesired]
var options: MessageOptions = [.standardQuery]

if !self.isMulticast {
options.insert(.recursionDesired)
}

if let additionalOptions = additionalOptions {
options.insert(additionalOptions)
}
Expand Down
6 changes: 5 additions & 1 deletion Sources/DNSClient/DNSDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ final class DNSDecoder: ChannelInboundHandler {
authorities: authorities,
additionalData: additionalData
)

if !header.options.contains(.answer) {
return
}

guard let query = messageCache[header.id] else {
throw UnknownQuery()
return
}

query.promise.succeed(message)
Expand Down
7 changes: 2 additions & 5 deletions Sources/DNSClient/DNSEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,11 @@ final class DNSEncoder: ChannelOutboundHandler {
let header = message.header

out.write(header)
var labelIndices = [String : UInt16]()

for question in message.questions {
for label in question.labels {
out.writeInteger(label.length, endianness: .big)
out.writeBytes(label.label)
}
out.writeCompressedLabels(question.labels, labelIndices: &labelIndices)

out.writeInteger(0, endianness: .big, as: UInt8.self)
out.writeInteger(question.type.rawValue, endianness: .big)
out.writeInteger(question.questionClass.rawValue, endianness: .big)
}
Expand Down
47 changes: 45 additions & 2 deletions Sources/DNSClient/Messages/Message.swift
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,58 @@ struct ZoneAuthority {
let minimumExpireTimeout: UInt32
}

extension Array where Element == DNSLabel {
extension Sequence where Element == DNSLabel {
public var string: String {
return self.compactMap { label in
if let string = String(bytes: label.label, encoding: .utf8), string.count > 0 {
return string
}

return nil
}.joined(separator: ".")
}.joined(separator: ".")
}
}

extension ByteBuffer {
/// Either write label index or list of labels
@discardableResult
mutating func writeCompressedLabels(_ labels: [DNSLabel], labelIndices: inout [String: UInt16]) -> Int {
var written = 0
var labels = labels
while !labels.isEmpty {
let label = labels.removeFirst()
// use combined labels as a key for a position in the packet
let key = labels.string
// if position exists output position or'ed with 0xc000 and return
if let labelIndex = labelIndices[key] {
written += writeInteger(labelIndex | 0xc000)
return written
} else {
// if no position exists for this combination of labels output the first label
labelIndices[key] = numericCast(writerIndex)
written += writeInteger(UInt8(label.label.count))
written += writeBytes(label.label)
}
}
// write end of labels
written += writeInteger(UInt8(0))
return written
}

/// write labels into DNS packet
@discardableResult
mutating func writeLabels(_ labels: [DNSLabel]) -> Int {
var written = 0
for label in labels {
written += writeInteger(UInt8(label.label.count))
written += writeBytes(label.label)
}

return written
}

func labelsSize(_ labels: [DNSLabel]) -> Int {
return labels.reduce(0, { $0 + 2 + $1.label.count })
}
}

Expand Down
3 changes: 3 additions & 0 deletions Sources/DNSClient/Messages/MessageOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ public struct MessageOptions: OptionSet, ExpressibleByIntegerLiteral {
public static let resultCodeNameError: MessageOptions = 0b00000000_00000011
public static let resultCodeNotImplemented: MessageOptions = 0b00000000_00000100
public static let resultCodeNotRefused: MessageOptions = 0b00000000_00000101

public static let multicastResponse: MessageOptions = 0b10000000_00000000
public static let mutlicastUnauthenticatedDataAcceptable: MessageOptions = 0b00000000_00010000

public var isAnswer: Bool {
return self.contains(.answer)
Expand Down
10 changes: 10 additions & 0 deletions Tests/DNSClientTests/DNSTCPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ final class DNSTCPClientTests: XCTestCase {
self.waitForExpectations(timeout: 5, handler: nil)
}

// func testMulticastDNS() async throws {
// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
// let client = try await DNSClient.connectMulticast(on: eventLoopGroup).get()
// let addresses = try await client.sendQuery(
// forHost: "my-host.local",
// type: .any
// ).get()
// print(addresses)
// }

func testThreadSafety() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let client = try await DNSClient.connectTCP(
Expand Down

0 comments on commit c4a2290

Please sign in to comment.