diff --git a/Cargo.toml b/Cargo.toml index b10befa..7c6059f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "mdns-sd" version = "0.11.5" authors = ["keepsimple "] edition = "2018" -rust-version = "1.63.0" +rust-version = "1.65.0" license = "Apache-2.0 OR MIT" repository = "https://github.com/keepsimple1/mdns-sd" documentation = "https://docs.rs/mdns-sd" @@ -17,6 +17,7 @@ logging = ["log"] default = ["async", "logging"] [dependencies] +fastrand = "2.1" flume = { version = "0.11", default-features = false } # channel between threads if-addrs = { version = "0.13", features = ["link-local"] } # get local IP addresses log = { version = "0.4", optional = true } # logging @@ -24,7 +25,7 @@ polling = "2.1" # select/poll sockets socket2 = { version = "0.5.5", features = ["all"] } # socket APIs [dev-dependencies] -env_logger = { version = "= 0.10.2", default-features = false, features= ["humantime"] } +env_logger = "0.11" fastrand = "2.1" humantime = "2.1" test-log = "= 0.2.14" diff --git a/examples/register.rs b/examples/register.rs index cc1281f..c2e4f1b 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -2,11 +2,11 @@ //! //! Run with: //! -//! cargo run --example register +//! cargo run --example register [options] //! //! Example: //! -//! cargo run --example register _my-hello._udp test1 +//! cargo run --example register _my-hello._udp instance1 host1 //! //! Options: //! "--unregister": automatically unregister after 2 seconds. @@ -16,7 +16,9 @@ use mdns_sd::{DaemonEvent, IfKind, ServiceDaemon, ServiceInfo}; use std::{env, thread, time::Duration}; fn main() { - env_logger::init(); + // setup env_logger with more precise timestamp. + let mut builder = env_logger::Builder::from_default_env(); + builder.format_timestamp_millis().init(); // Simple command line options. let args: Vec = env::args().collect(); @@ -52,11 +54,18 @@ fn main() { return; } }; + let hostname = match args.get(3) { + Some(arg) => arg, + None => { + print_usage(); + return; + } + }; // With `enable_addr_auto()`, we can give empty addrs and let the lib find them. // If the caller knows specific addrs to use, then assign the addrs here. let my_addrs = ""; - let service_hostname = format!("{}{}", instance_name, &service_type); + let service_hostname = format!("{}.local.", hostname); let port = 3456; // The key string in TXT properties is case insensitive. Only the first @@ -106,10 +115,11 @@ fn main() { fn print_usage() { println!("Usage:"); - println!("cargo run --example register [--unregister]"); + println!("cargo run --example register [options]"); println!("Options:"); println!("--unregister: automatically unregister after 2 seconds"); + println!("--disable-ipv6: not to use IPv6 interfaces."); println!(); println!("For example:"); - println!("cargo run --example register _my-hello._udp test1"); + println!("cargo run --example register _my-hello._udp instance1 host1"); } diff --git a/src/dns_cache.rs b/src/dns_cache.rs index e2d56fd..8ea30ce 100644 --- a/src/dns_cache.rs +++ b/src/dns_cache.rs @@ -277,7 +277,7 @@ impl DnsCache { srv_records.retain(|srv| { let expired = srv.get_record().is_expired(now); if expired { - debug!("expired SRV: {}: {}", ty_domain, srv.get_name()); + debug!("expired SRV: {}: {:?}", ty_domain, srv); expired_instances .entry(ty_domain.to_string()) .or_insert_with(HashSet::new) @@ -299,7 +299,7 @@ impl DnsCache { let expired = x.get_record().is_expired(now); if expired { if let Some(dns_ptr) = x.any().downcast_ref::() { - debug!("expired PTR: {:?}", dns_ptr); + debug!("expired PTR: domain:{ty_domain} record: {:?}", dns_ptr); expired_instances .entry(ty_domain.to_string()) .or_insert_with(HashSet::new) diff --git a/src/dns_parser.rs b/src/dns_parser.rs index 9b5810e..fbbdbcc 100644 --- a/src/dns_parser.rs +++ b/src/dns_parser.rs @@ -29,6 +29,21 @@ pub const TYPE_SRV: u16 = 33; pub const TYPE_NSEC: u16 = 47; // Negative responses pub const TYPE_ANY: u16 = 255; +pub(crate) const fn rr_type_name(rr_type: u16) -> &'static str { + match rr_type { + TYPE_A => "TYPE_A", + TYPE_CNAME => "TYPE_CNAME", + TYPE_PTR => "TYPE_PTR", + TYPE_HINFO => "TYPE_HINFO", + TYPE_TXT => "TYPE_TXT", + TYPE_AAAA => "TYPE_AAAA", + TYPE_SRV => "TYPE_SRV", + TYPE_NSEC => "TYPE_NSEC", + TYPE_ANY => "TYPE_ANY", + _ => "type_others", + } +} + pub const CLASS_IN: u16 = 1; pub const CLASS_MASK: u16 = 0x7FFF; pub const CLASS_CACHE_FLUSH: u16 = 0x8000; @@ -116,6 +131,9 @@ pub struct DnsRecord { /// Support re-query an instance before its PTR record expires. /// See https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 refresh: u64, // UNIX time in millis + + /// If conflict resolution decides to change the name, this is the new one. + new_name: Option, } impl DnsRecord { @@ -135,6 +153,7 @@ impl DnsRecord { created, expires, refresh, + new_name: None, } } @@ -230,6 +249,27 @@ impl DnsRecord { self.ttl -= (elapsed / 1000) as u32; } } + + pub(crate) fn set_new_name(&mut self, new_name: String) { + if new_name == self.entry.name { + self.new_name = None; + } else { + self.new_name = Some(new_name); + } + } + + pub(crate) fn get_new_name(&self) -> Option<&str> { + self.new_name.as_deref() + } + + /// Return the new name if exists, otherwise the regular name in DnsEntry. + pub(crate) fn get_name(&self) -> &str { + self.new_name.as_deref().unwrap_or(&self.entry.name) + } + + pub(crate) fn get_original_name(&self) -> &str { + &self.entry.name + } } impl PartialEq for DnsRecord { @@ -247,6 +287,41 @@ pub(crate) trait DnsRecordExt: fmt::Debug { /// Returns whether `other` record is considered the same except TTL. fn matches(&self, other: &dyn DnsRecordExt) -> bool; + /// Returns whether `other` record has the same rdata. + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool; + + /// Returns the result based on a byte-level comparison of `rdata`. + /// If `other` is not valid, returns `Greater`. + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering; + + /// Returns the result based on "lexicographically later" defined below. + fn compare(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + /* + RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2 + + ... The determination of "lexicographically later" is performed by first + comparing the record class (excluding the cache-flush bit described + in Section 10.2), then the record type, then raw comparison of the + binary content of the rdata without regard for meaning or structure. + If the record classes differ, then the numerically greater class is + considered "lexicographically later". Otherwise, if the record types + differ, then the numerically greater type is considered + "lexicographically later". If the rrtype and rrclass both match, + then the rdata is compared. ... + */ + match self.get_class().cmp(&other.get_class()) { + cmp::Ordering::Equal => match self.get_type().cmp(&other.get_type()) { + cmp::Ordering::Equal => self.compare_rdata(other), + not_equal => not_equal, + }, + not_equal => not_equal, + } + } + + /// Returns a human-readable string of rdata. + fn rdata_print(&self) -> String; + + /// Returns the class only, excluding class_flush / unique bit. fn get_class(&self) -> u16 { self.get_record().entry.class } @@ -255,9 +330,15 @@ pub(crate) trait DnsRecordExt: fmt::Debug { self.get_record().entry.cache_flush } + /// Return the new name if exists, otherwise the regular name in DnsEntry. fn get_name(&self) -> &str { - self.get_record().entry.name.as_str() + self.get_record().get_name() } + + fn get_original_name(&self) -> &str { + self.get_record().get_original_name() + } + fn get_type(&self) -> u16 { self.get_record().entry.ty } @@ -355,6 +436,25 @@ impl DnsRecordExt for DnsAddress { false } + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool { + if let Some(other_a) = other.any().downcast_ref::() { + return self.address == other_a.address; + } + false + } + + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + if let Some(other_a) = other.any().downcast_ref::() { + self.address.cmp(&other_a.address) + } else { + cmp::Ordering::Greater + } + } + + fn rdata_print(&self) -> String { + format!("{}", self.address) + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -398,6 +498,25 @@ impl DnsRecordExt for DnsPointer { false } + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool { + if let Some(other_ptr) = other.any().downcast_ref::() { + return self.alias == other_ptr.alias; + } + false + } + + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + if let Some(other_ptr) = other.any().downcast_ref::() { + self.alias.cmp(&other_ptr.alias) + } else { + cmp::Ordering::Greater + } + } + + fn rdata_print(&self) -> String { + self.alias.clone() + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -467,6 +586,55 @@ impl DnsRecordExt for DnsSrv { false } + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool { + if let Some(other_srv) = other.any().downcast_ref::() { + return self.host == other_srv.host + && self.port == other_srv.port + && self.weight == other_srv.weight + && self.priority == other_srv.priority; + } + false + } + + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + let Some(other_srv) = other.any().downcast_ref::() else { + return cmp::Ordering::Greater; + }; + + // 1. compare `priority` + match self + .priority + .to_be_bytes() + .cmp(&other_srv.priority.to_be_bytes()) + { + cmp::Ordering::Equal => { + // 2. compare `weight` + match self + .weight + .to_be_bytes() + .cmp(&other_srv.weight.to_be_bytes()) + { + cmp::Ordering::Equal => { + // 3. compare `port`. + match self.port.to_be_bytes().cmp(&other_srv.port.to_be_bytes()) { + cmp::Ordering::Equal => self.host.cmp(&other_srv.host), + not_equal => not_equal, + } + } + not_equal => not_equal, + } + } + not_equal => not_equal, + } + } + + fn rdata_print(&self) -> String { + format!( + "priority: {}, weight: {}, port: {}, host: {}", + self.priority, self.weight, self.port, self.host + ) + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -522,6 +690,25 @@ impl DnsRecordExt for DnsTxt { false } + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool { + if let Some(other_txt) = other.any().downcast_ref::() { + return self.text == other_txt.text; + } + false + } + + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + if let Some(other_txt) = other.any().downcast_ref::() { + self.text.cmp(&other_txt.text) + } else { + cmp::Ordering::Greater + } + } + + fn rdata_print(&self) -> String { + format!("{:?}", decode_txt(&self.text)) + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -581,6 +768,28 @@ impl DnsRecordExt for DnsHostInfo { false } + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool { + if let Some(other_hinfo) = other.any().downcast_ref::() { + return self.cpu == other_hinfo.cpu && self.os == other_hinfo.os; + } + false + } + + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + if let Some(other_hinfo) = other.any().downcast_ref::() { + match self.cpu.cmp(&other_hinfo.cpu) { + cmp::Ordering::Equal => self.os.cmp(&other_hinfo.os), + ordering => ordering, + } + } else { + cmp::Ordering::Greater + } + } + + fn rdata_print(&self) -> String { + format!("cpu: {}, os: {}", self.cpu, self.os) + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -664,6 +873,33 @@ impl DnsRecordExt for DnsNSec { false } + fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool { + if let Some(other_record) = other.any().downcast_ref::() { + return self.next_domain == other_record.next_domain + && self.type_bitmap == other_record.type_bitmap; + } + false + } + + fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering { + if let Some(other_nsec) = other.any().downcast_ref::() { + match self.next_domain.cmp(&other_nsec.next_domain) { + cmp::Ordering::Equal => self.type_bitmap.cmp(&other_nsec.type_bitmap), + ordering => ordering, + } + } else { + cmp::Ordering::Greater + } + } + + fn rdata_print(&self) -> String { + format!( + "next_domain: {}, type_bitmap len: {}", + self.next_domain, + self.type_bitmap.len() + ) + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -714,7 +950,7 @@ impl DnsOutPacket { let start_size = self.size; let record = record_ext.get_record(); - self.write_name(&record.entry.name); + self.write_name(record.get_name()); self.write_short(record.entry.ty); if record.entry.cache_flush { // check "multicast" @@ -901,9 +1137,9 @@ pub(crate) struct DnsOutgoing { multicast: bool, pub(crate) questions: Vec, pub(crate) answers: Vec<(DnsRecordBox, u64)>, - pub(crate) authorities: Vec, + pub(crate) authorities: Vec, pub(crate) additionals: Vec, - pub(crate) known_answer_count: i64, + pub(crate) known_answer_count: i64, // for internal maintenance only } impl DnsOutgoing { @@ -967,8 +1203,12 @@ impl DnsOutgoing { } /// A workaround as Rust doesn't allow us to pass DnsRecordBox in as `impl DnsRecordExt` - pub(crate) fn add_additional_answer_box(&mut self, answer_box: DnsRecordBox) { - self.additionals.push(answer_box); + pub(crate) fn add_answer_box(&mut self, answer_box: DnsRecordBox) { + self.answers.push((answer_box, 0)); + } + + pub(crate) fn add_authority(&mut self, record: DnsRecordBox) { + self.authorities.push(record); } /// Returns true if `answer` is added to the outgoing msg. @@ -996,7 +1236,6 @@ impl DnsOutgoing { answer: impl DnsRecordExt + Send + 'static, now: u64, ) -> bool { - debug!("Check for add_answer_at_time"); if now == 0 || !answer.get_record().is_expired(now) { debug!("add_answer push: {:?}", &answer); self.answers.push((Box::new(answer), now)); @@ -1115,7 +1354,7 @@ impl DnsOutgoing { } for auth in self.authorities.iter() { - auth_count += u16::from(packet.write_record(auth, 0)); + auth_count += u16::from(packet.write_record(auth.as_ref(), 0)); } for addi in self.additionals.iter() { @@ -1173,9 +1412,9 @@ pub struct DnsIncoming { offset: usize, data: Vec, pub(crate) questions: Vec, - /// This field includes records in the `answers` section - /// and in the `additionals` section. pub(crate) answers: Vec, + pub(crate) authorities: Vec, + pub(crate) additional: Vec, pub(crate) id: u16, flags: u16, pub(crate) num_questions: u16, @@ -1191,6 +1430,8 @@ impl DnsIncoming { data, questions: Vec::new(), answers: Vec::new(), + authorities: Vec::new(), + additional: Vec::new(), id: 0, flags: 0, num_questions: 0, @@ -1199,9 +1440,31 @@ impl DnsIncoming { num_additionals: 0, }; + /* + RFC 1035 section 4.1: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1 + ... + All communications inside of the domain protocol are carried in a single + format called a message. The top level format of message is divided + into 5 sections (some of which are empty in certain cases) shown below: + + +---------------------+ + | Header | + +---------------------+ + | Question | the question for the name server + +---------------------+ + | Answer | RRs answering the question + +---------------------+ + | Authority | RRs pointing toward an authority + +---------------------+ + | Additional | RRs holding additional information + +---------------------+ + */ incoming.read_header()?; incoming.read_questions()?; - incoming.read_others()?; + incoming.read_answers()?; + incoming.read_authorities()?; + incoming.read_additional()?; + Ok(incoming) } @@ -1266,14 +1529,25 @@ impl DnsIncoming { Ok(()) } - /// Decodes all answers, authorities and additionals. - fn read_others(&mut self) -> Result<()> { - let n = self - .num_answers - .checked_add(self.num_authorities) - .and_then(|x| x.checked_add(self.num_additionals)) - .ok_or_else(|| Error::Msg("read_others: overflow".to_string()))?; - debug!("read_others: {}", n); + fn read_answers(&mut self) -> Result<()> { + self.answers = self.read_rr_records(self.num_answers)?; + Ok(()) + } + + fn read_authorities(&mut self) -> Result<()> { + self.authorities = self.read_rr_records(self.num_authorities)?; + Ok(()) + } + + fn read_additional(&mut self) -> Result<()> { + self.additional = self.read_rr_records(self.num_additionals)?; + Ok(()) + } + + /// Decodes a sequence of RR records (in answers, authorities and additionals). + fn read_rr_records(&mut self, count: u16) -> Result> { + debug!("read_rr_records: {}", count); + let mut rr_records = Vec::new(); // RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.1 // @@ -1302,7 +1576,7 @@ impl DnsIncoming { // Muse have at least TYPE, CLASS, TTL, RDLENGTH fields: 10 bytes. const RR_HEADER_REMAIN: usize = 10; - for _ in 0..n { + for _ in 0..count { let name = self.read_name()?; let slice = &self.data[self.offset..]; @@ -1408,11 +1682,11 @@ impl DnsIncoming { if let Some(record) = rec { debug!("read_others: {:?}", &record); - self.answers.push(record); + rr_records.push(record); } } - Ok(()) + Ok(rr_records) } fn read_char_string(&mut self) -> String { diff --git a/src/service_daemon.rs b/src/service_daemon.rs index dc6fae8..a8a8c41 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -29,17 +29,17 @@ // in Service Discovery, the basic data structure is "Service Info". One Service Info // corresponds to a set of DNS Resource Records. #[cfg(feature = "logging")] -use crate::log::{debug, error, warn}; +use crate::log::{debug, error, info, warn}; use crate::{ dns_cache::DnsCache, dns_parser::{ - current_time_millis, ip_address_to_type, split_sub_domain, DnsAddress, DnsIncoming, - DnsOutgoing, DnsPointer, DnsRecordExt, DnsSrv, DnsTxt, CLASS_CACHE_FLUSH, CLASS_IN, - FLAGS_AA, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, MAX_MSG_ABSOLUTE, TYPE_A, TYPE_AAAA, TYPE_ANY, - TYPE_PTR, TYPE_SRV, TYPE_TXT, + current_time_millis, ip_address_to_type, rr_type_name, split_sub_domain, DnsAddress, + DnsIncoming, DnsOutgoing, DnsPointer, DnsRecordBox, DnsRecordExt, DnsSrv, DnsTxt, + CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_AA, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, MAX_MSG_ABSOLUTE, + TYPE_A, TYPE_AAAA, TYPE_ANY, TYPE_PTR, TYPE_SRV, TYPE_TXT, }, error::{Error, Result}, - service_info::ServiceInfo, + service_info::{DnsRegistry, Probe, ServiceInfo, ServiceStatus}, Receiver, }; use flume::{bounded, Sender, TrySendError}; @@ -593,6 +593,9 @@ impl ServiceDaemon { zc.resolve_updated_instances(instance_set); } + // Send out probing queries. + zc.probing_handler(); + // check IP changes. if now > next_ip_check { next_ip_check = now + IP_CHECK_INTERVAL_MILLIS; @@ -622,9 +625,9 @@ impl ServiceDaemon { zc.increase_counter(Counter::Register, 1); } - Command::RegisterResend(fullname) => { - debug!("announce service: {}", &fullname); - zc.exec_command_register_resend(fullname); + Command::RegisterResend(fullname, intf) => { + debug!("register-resend service: {fullname} on {:?}", &intf.addr); + zc.exec_command_register_resend(fullname, intf); } Command::Unregister(fullname, resp_s) => { @@ -889,8 +892,12 @@ struct Zeroconf { /// Local registered services, keyed by service full names. my_services: HashMap, + /// Received DNS records. cache: DnsCache, + /// Registered service records. + dns_registry_map: HashMap, + /// Active "Browse" commands. service_queriers: HashMap>, // @@ -967,6 +974,7 @@ impl Zeroconf { poll_id_count: 0, my_services: HashMap::new(), cache: DnsCache::new(), + dns_registry_map: HashMap::new(), hostname_resolvers: HashMap::new(), service_queriers: HashMap::new(), retransmissions: Vec::new(), @@ -1025,15 +1033,6 @@ impl Zeroconf { }); } - /// Add `addr` in my services that enabled `addr_auto`. - fn add_addr_in_my_services(&mut self, addr: IpAddr) { - for (_, service_info) in self.my_services.iter_mut() { - if service_info.is_addr_auto() { - service_info.insert_ipaddr(addr); - } - } - } - /// Remove `addr` in my services that enabled `addr_auto`. fn del_addr_in_my_services(&mut self, addr: &IpAddr) { for (_, service_info) in self.my_services.iter_mut() { @@ -1194,9 +1193,36 @@ impl Zeroconf { return; } - self.intf_socks.insert(intf, sock); + info!("add new interface {}: {new_ip}", intf.name); + let dns_registry = match self.dns_registry_map.get_mut(&intf) { + Some(registry) => registry, + None => self + .dns_registry_map + .entry(intf.clone()) + .or_insert(DnsRegistry::new()), + }; - self.add_addr_in_my_services(new_ip); + for (_, service_info) in self.my_services.iter_mut() { + if service_info.is_addr_auto() { + service_info.insert_ipaddr(new_ip); + + if announce_service_on_intf(dns_registry, service_info, &intf, &sock) { + info!( + "Announce service {} on {}", + service_info.get_fullname(), + intf.ip() + ); + service_info.set_status(&intf, ServiceStatus::Announced); + } else { + for timer in dns_registry.new_timers.drain(..) { + self.timers.push(Reverse(timer)); + } + service_info.set_status(&intf, ServiceStatus::Probing); + } + } + } + + self.intf_socks.insert(intf, sock); // Notify the monitors. self.notify_monitors(DaemonEvent::IpAdd(new_ip)); @@ -1225,9 +1251,9 @@ impl Zeroconf { } } - debug!("register service {:?}", &info); + info!("register service {:?}", &info); - let outgoing_addrs = self.send_unsolicited_response(&info); + let outgoing_addrs = self.send_unsolicited_response(&mut info); if !outgoing_addrs.is_empty() { self.notify_monitors(DaemonEvent::Announce( info.get_fullname().to_string(), @@ -1235,120 +1261,186 @@ impl Zeroconf { )); } - // RFC 6762 section 8.3. - // ..The Multicast DNS responder MUST send at least two unsolicited - // responses, one second apart. - let next_time = current_time_millis() + 1000; - // The key has to be lower case letter as DNS record name is case insensitive. // The info will have the original name. let service_fullname = info.get_fullname().to_lowercase(); - self.add_retransmission(next_time, Command::RegisterResend(service_fullname.clone())); self.my_services.insert(service_fullname, info); } /// Sends out announcement of `info` on every valid interface. /// Returns the list of interface IPs that sent out the announcement. - fn send_unsolicited_response(&self, info: &ServiceInfo) -> Vec { + fn send_unsolicited_response(&mut self, info: &mut ServiceInfo) -> Vec { let mut outgoing_addrs = Vec::new(); // Send the announcement on one interface per ip version. let mut multicast_sent_trackers = HashSet::new(); + let mut outgoing_intfs = Vec::new(); + for (intf, sock) in self.intf_socks.iter() { if let Some(tracker) = multicast_send_tracker(intf) { if multicast_sent_trackers.contains(&tracker) { continue; // No need to send again on the same interface with same ip version. } } - if self.broadcast_service_on_intf(info, intf, sock) { + + let dns_registry = match self.dns_registry_map.get_mut(intf) { + Some(registry) => registry, + None => self + .dns_registry_map + .entry(intf.clone()) + .or_insert(DnsRegistry::new()), + }; + + if announce_service_on_intf(dns_registry, info, intf, sock) { if let Some(tracker) = multicast_send_tracker(intf) { multicast_sent_trackers.insert(tracker); } outgoing_addrs.push(intf.ip()); + outgoing_intfs.push(intf.clone()); + + info!("Announce service {} on {}", info.get_fullname(), intf.ip()); + + info.set_status(intf, ServiceStatus::Announced); + } else { + for timer in dns_registry.new_timers.drain(..) { + self.timers.push(Reverse(timer)); + } + info.set_status(intf, ServiceStatus::Probing); } } + // RFC 6762 section 8.3. + // ..The Multicast DNS responder MUST send at least two unsolicited + // responses, one second apart. + let next_time = current_time_millis() + 1000; + for intf in outgoing_intfs { + self.add_retransmission( + next_time, + Command::RegisterResend(info.get_fullname().to_string(), intf), + ); + } + outgoing_addrs } - /// Send an unsolicited response for owned service via `intf_sock`. - /// Returns true if sent out successfully. - fn broadcast_service_on_intf( - &self, - info: &ServiceInfo, - intf: &Interface, - sock: &Socket, - ) -> bool { - let service_fullname = info.get_fullname(); - debug!("broadcast service {}", service_fullname); - let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA); - out.add_answer_at_time( - DnsPointer::new( - info.get_type(), - TYPE_PTR, - CLASS_IN, - info.get_other_ttl(), - info.get_fullname().to_string(), - ), - 0, - ); + /// Send probings or finish them if expired. Notify waiting services. + fn probing_handler(&mut self) { + let now = current_time_millis(); - if let Some(sub) = info.get_subtype() { - debug!("Adding subdomain {}", sub); - out.add_answer_at_time( - DnsPointer::new( - sub, - TYPE_PTR, - CLASS_IN, - info.get_other_ttl(), - info.get_fullname().to_string(), - ), - 0, - ); - } + for (intf, sock) in self.intf_socks.iter() { + let Some(dns_registry) = self.dns_registry_map.get_mut(intf) else { + continue; + }; - out.add_answer_at_time( - DnsSrv::new( - info.get_fullname(), - CLASS_IN | CLASS_CACHE_FLUSH, - info.get_host_ttl(), - info.get_priority(), - info.get_weight(), - info.get_port(), - info.get_hostname().to_string(), - ), - 0, - ); - out.add_answer_at_time( - DnsTxt::new( - info.get_fullname(), - CLASS_IN | CLASS_CACHE_FLUSH, - info.get_other_ttl(), - info.generate_txt(), - ), - 0, - ); + let mut expired_names = Vec::new(); + let mut out = DnsOutgoing::new(FLAGS_QR_QUERY); - let intf_addrs = info.get_addrs_on_intf(intf); - if intf_addrs.is_empty() { - debug!("No valid addrs to add on intf {:?}", &intf); - return false; - } - for address in intf_addrs { - out.add_answer_at_time( - DnsAddress::new( - info.get_hostname(), - ip_address_to_type(&address), - CLASS_IN | CLASS_CACHE_FLUSH, - info.get_host_ttl(), - address, - ), - 0, - ); - } + for (name, probe) in dns_registry.probing.iter_mut() { + if now >= probe.next_send { + if probe.expired(now) { + // move the record to active + expired_names.push(name.clone()); + } else { + out.add_question(name, TYPE_ANY); + + /* + RFC 6762 section 8.2: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2 + ... + for tiebreaking to work correctly in all + cases, the Authority Section must contain *all* the records and + proposed rdata being probed for uniqueness. + */ + for record in probe.records.iter() { + out.add_authority(record.clone()); + } - send_dns_outgoing(&out, intf, sock); - true + probe.update_next_send(now); + + // add timer + self.timers.push(Reverse(probe.next_send)); + } + } + } + + // send probing. + if !out.questions.is_empty() { + send_dns_outgoing(&out, intf, sock); + } + + let mut waiting_services = HashSet::new(); + + for name in expired_names { + let Some(probe) = dns_registry.probing.remove(&name) else { + continue; + }; + + // send notifications about name changes + for record in probe.records.iter() { + if let Some(new_name) = record.get_record().get_new_name() { + let event = DnsNameChange { + original: record.get_record().get_original_name().to_string(), + new_name: new_name.to_string(), + rr_type: record.get_type(), + intf_name: intf.name.to_string(), + }; + notify_monitors(&mut self.monitors, DaemonEvent::NameChange(event)); + } + } + + // move RR from probe to active. + info!( + "probe of '{name}' finished: move {} records to active. ({} waiting services)", + probe.records.len(), + probe.waiting_services.len(), + ); + + match dns_registry.active.get_mut(&name) { + Some(records) => { + records.extend(probe.records); + } + None => { + dns_registry.active.insert(name, probe.records); + } + } + + waiting_services.extend(probe.waiting_services); + } + + // wake up services waiting. + for service_name in waiting_services { + info!("try to announce service {service_name}"); + if let Some(info) = self.my_services.get_mut(&service_name) { + if announce_service_on_intf(dns_registry, info, intf, sock) { + let next_time = now + 1000; + let command = + Command::RegisterResend(info.get_fullname().to_string(), intf.clone()); + self.retransmissions.push(ReRun { next_time, command }); + self.timers.push(Reverse(next_time)); + + let fullname = match dns_registry.name_changes.get(&service_name) { + Some(new_name) => new_name.to_string(), + None => service_name.to_string(), + }; + + let mut hostname = info.get_hostname(); + if let Some(new_name) = dns_registry.name_changes.get(hostname) { + hostname = new_name; + } + + info!("wake up: announce service {} on {}", fullname, intf.ip()); + notify_monitors( + &mut self.monitors, + DaemonEvent::Announce( + service_name, + format!("{}:{}", hostname, &intf.ip()), + ), + ); + + info.set_status(intf, ServiceStatus::Announced); + } + } + } + } } fn unregister_service(&self, info: &ServiceInfo, intf: &Interface, sock: &Socket) -> Vec { @@ -1443,10 +1535,17 @@ impl Zeroconf { out.add_question(name, *qtype); for record in self.cache.get_known_answers(name, *qtype, now) { + /* + RFC 6762 section 7.1: https://datatracker.ietf.org/doc/html/rfc6762#section-7.1 + ... + When a Multicast DNS querier sends a query to which it already knows + some answers, it populates the Answer Section of the DNS query + message with those answers. + */ debug!("add known answer: {:?}", record); let mut new_record = record.clone(); new_record.get_record_mut().update_ttl(now); - out.add_additional_answer_box(new_record); + out.add_answer_box(new_record); } } @@ -1694,18 +1793,18 @@ impl Zeroconf { debug!( "handle_response: {} answers {} authorities {} additionals", &msg.answers.len(), - &msg.num_authorities, - &msg.num_additionals + &msg.authorities.len(), + &msg.additional.len() ); let now = current_time_millis(); // remove records that are expired. - msg.answers.retain(|record| { + let mut record_predicate = |record: &DnsRecordBox| { if !record.get_record().is_expired(now) { return true; } - debug!("record is expired, removing it from cache."); + info!("record is expired, removing it from cache."); if self.cache.remove(record) { // for PTR records, send event to listeners if let Some(dns_ptr) = record.any().downcast_ref::() { @@ -1720,7 +1819,13 @@ impl Zeroconf { } } false - }); + }; + msg.answers.retain(&mut record_predicate); + msg.authorities.retain(&mut record_predicate); + msg.additional.retain(&mut record_predicate); + + // check possible conflicts and handle them. + self.conflict_handler(&msg, intf); /// Represents a DNS record change that involves one service instance. struct InstanceChange { @@ -1737,7 +1842,12 @@ impl Zeroconf { // other. let mut changes = Vec::new(); let mut timers = Vec::new(); - for record in msg.answers { + for record in msg + .answers + .into_iter() + .chain(msg.authorities.into_iter()) + .chain(msg.additional.into_iter()) + { match self.cache.add_or_update(intf, record, &mut timers) { Some((dns_record, true)) => { timers.push(dns_record.get_record().get_expire_time()); @@ -1816,6 +1926,83 @@ impl Zeroconf { self.resolve_updated_instances(updated_instances); } + fn conflict_handler(&mut self, msg: &DnsIncoming, intf: &Interface) { + let Some(dns_registry) = self.dns_registry_map.get_mut(intf) else { + return; + }; + + let mut new_records = Vec::new(); + + for answer in msg.answers.iter() { + let name = answer.get_name(); + let Some(probe) = dns_registry.probing.get_mut(name) else { + continue; + }; + + probe.records.retain(|record| { + if record.get_type() == answer.get_type() + && record.get_class() == answer.get_class() + && !record.rrdata_match(answer.as_ref()) + { + info!( + "found conflict name: '{name}' record: {}: {} PEER: {}", + rr_type_name(record.get_type()), + record.rdata_print(), + answer.rdata_print() + ); + + // create a new name for this record + // then remove the old record in probing. + let mut new_record = record.clone(); + let new_name = name_change(name); + new_record.get_record_mut().set_new_name(new_name); + new_records.push(new_record); + return false; // old record is dropped from the probe. + } + + true + }); + } + + // Probing again with the new names. + let create_time = current_time_millis() + fastrand::u64(0..250); + + for record in new_records { + if dns_registry.update_hostname( + record.get_original_name(), + record.get_name(), + create_time, + ) { + self.timers.push(Reverse(create_time)); + } + + // remember the name changes + dns_registry.name_changes.insert( + record.get_record().entry.name.to_string(), + record.get_name().to_string(), + ); + + let probe = match dns_registry.probing.get_mut(record.get_name()) { + Some(p) => p, + None => { + let new_probe = dns_registry + .probing + .entry(record.get_name().to_string()) + .or_insert(Probe::new(create_time)); + self.timers.push(Reverse(new_probe.next_send)); + new_probe + } + }; + + info!( + "insert record with new name '{}' {} into probe", + record.get_name(), + rr_type_name(record.get_type()) + ); + probe.insert_record(record); + } + } + /// Resolve the updated (including new) instances. /// /// Note: it is possible that more than 1 PTR pointing to the same @@ -1881,6 +2068,10 @@ impl Zeroconf { if qtype == TYPE_PTR { for service in self.my_services.values() { + if service.get_status(intf) != ServiceStatus::Announced { + continue; + } + if question.entry.name == service.get_type() || service .get_subtype() @@ -1905,8 +2096,57 @@ impl Zeroconf { } } } else { + // Simultaneous Probe Tiebreaking (RFC 6762 section 8.2) + if qtype == TYPE_ANY && msg.num_authorities > 0 { + if let Some(dns_registry) = self.dns_registry_map.get_mut(intf) { + let probe_name = &question.entry.name; + + if let Some(probe) = dns_registry.probing.get_mut(probe_name) { + let now = current_time_millis(); + + // Only do tiebreaking if probe already started. + // This check also helps avoid redo tiebreaking if start time + // was postponed. + if probe.start_time < now { + let incoming_records: Vec<_> = msg + .authorities + .iter() + .filter(|r| r.get_name() == probe_name) + .collect(); + + /* + RFC 6762 section 8.2: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2 + ... + if the host finds that its own data is lexicographically later, it + simply ignores the other host's probe. If the host finds that its + own data is lexicographically earlier, then it defers to the winning + host by waiting one second, and then begins probing for this record + again. + */ + match probe.tiebreaking(&incoming_records) { + cmp::Ordering::Less => { + info!( + "tiebreaking '{}': LOST, will wait for one second", + probe_name + ); + probe.start_time = now + 1000; // wait and restart. + probe.next_send = now + 1000; + } + ordering => { + info!("tiebreaking '{}': {:?}", probe_name, ordering); + } + } + } + } + } + } + if qtype == TYPE_A || qtype == TYPE_AAAA || qtype == TYPE_ANY { for service in self.my_services.values() { + if service.get_status(intf) != ServiceStatus::Announced { + continue; + } + if service.get_hostname().to_lowercase() == question.entry.name.to_lowercase() { @@ -1945,6 +2185,10 @@ impl Zeroconf { None => continue, }; + if service.get_status(intf) != ServiceStatus::Announced { + continue; + } + if qtype == TYPE_SRV || qtype == TYPE_ANY { out.add_answer( &msg, @@ -2042,7 +2286,7 @@ impl Zeroconf { instance_name.to_string(), ); match sender.send(event) { - Ok(()) => debug!("Sent ServiceRemoved to listener successfully"), + Ok(()) => info!("notify_service_removal: sent ServiceRemoved to listener of {ty_domain}: {instance_name}"), Err(e) => error!("Failed to send event: {}", e), } } @@ -2259,20 +2503,42 @@ impl Zeroconf { } } - fn exec_command_register_resend(&mut self, fullname: String) { - match self.my_services.get(&fullname) { - Some(info) => { - let outgoing_addrs = self.send_unsolicited_response(info); - if !outgoing_addrs.is_empty() { - self.notify_monitors(DaemonEvent::Announce( - fullname, - format!("{:?}", &outgoing_addrs), - )); - } - self.increase_counter(Counter::RegisterResend, 1); + fn exec_command_register_resend(&mut self, fullname: String, intf: Interface) { + let Some(info) = self.my_services.get_mut(&fullname) else { + debug!("announce: cannot find such service {}", &fullname); + return; + }; + + let Some(dns_registry) = self.dns_registry_map.get_mut(&intf) else { + return; + }; + + let Some(sock) = self.intf_socks.get(&intf) else { + return; + }; + + if announce_service_on_intf(dns_registry, info, &intf, sock) { + let mut hostname = info.get_hostname(); + if let Some(new_name) = dns_registry.name_changes.get(hostname) { + hostname = new_name; } - None => debug!("announce: cannot find such service {}", &fullname), + let service_name = match dns_registry.name_changes.get(&fullname) { + Some(new_name) => new_name.to_string(), + None => fullname, + }; + + info!("resend: announce service {} on {}", service_name, intf.ip()); + + notify_monitors( + &mut self.monitors, + DaemonEvent::Announce(service_name, format!("{}:{}", hostname, &intf.ip())), + ); + info.set_status(&intf, ServiceStatus::Announced); + } else { + error!("register-resend should not fail"); } + + self.increase_counter(Counter::RegisterResend, 1); } /// Refresh cached service records with active queriers @@ -2366,6 +2632,16 @@ pub enum DaemonEvent { /// Daemon detected a IP address removed from the host. IpDel(IpAddr), + + NameChange(DnsNameChange), +} + +#[derive(Clone, Debug)] +pub struct DnsNameChange { + pub original: String, + pub new_name: String, + pub rr_type: u16, + pub intf_name: String, } /// Commands supported by the daemon @@ -2384,7 +2660,7 @@ enum Command { Unregister(String, Sender), // (fullname) /// Announce again a service to local network - RegisterResend(String), // (fullname) + RegisterResend(String, Interface), // (fullname) /// Resend unregister packet. UnregisterResend(Vec, Interface), // (packet content) @@ -2423,7 +2699,7 @@ impl fmt::Display for Command { Self::GetMetrics(_) => write!(f, "Command GetMetrics"), Self::Monitor(_) => write!(f, "Command Monitor"), Self::Register(_) => write!(f, "Command Register"), - Self::RegisterResend(_) => write!(f, "Command RegisterResend"), + Self::RegisterResend(_, _) => write!(f, "Command RegisterResend"), Self::SetOption(_) => write!(f, "Command SetOption"), Self::StopBrowse(_) => write!(f, "Command StopBrowse"), Self::StopResolveHostname(_) => write!(f, "Command StopResolveHostname"), @@ -2566,7 +2842,7 @@ fn my_ip_interfaces() -> Vec { fn send_dns_outgoing(out: &DnsOutgoing, intf: &Interface, sock: &Socket) -> Vec> { let qtype = if out.is_query() { "query" } else { "response" }; debug!( - "Multicasting {}: {} questions {} answers {} authorities {} additional", + "send outgoing {}: {} questions {} answers {} authorities {} additional", qtype, out.questions.len(), out.answers.len(), @@ -2615,6 +2891,159 @@ fn valid_instance_name(name: &str) -> bool { name.split('.').count() >= 5 } +fn notify_monitors(monitors: &mut Vec>, event: DaemonEvent) { + monitors.retain(|sender| { + if let Err(e) = sender.try_send(event.clone()) { + error!("notify_monitors: try_send: {}", &e); + if matches!(e, TrySendError::Disconnected(_)) { + return false; // This monitor is dropped. + } + } + true + }); +} + +/// Check if all unique records passed "probing", and if yes, create a packet +/// to announce the service. +fn prepare_announce( + info: &ServiceInfo, + intf: &Interface, + dns_registry: &mut DnsRegistry, +) -> Option { + let service_fullname = match dns_registry.name_changes.get(info.get_fullname()) { + Some(new_name) => new_name, + None => info.get_fullname(), + }; + + info!( + "prepare to announce service {service_fullname} on {}: {}", + &intf.name, + &intf.ip() + ); + let mut probing_count = 0; + let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA); + let create_time = current_time_millis() + fastrand::u64(0..250); + + out.add_answer_at_time( + DnsPointer::new( + info.get_type(), + TYPE_PTR, + CLASS_IN, + info.get_other_ttl(), + service_fullname.to_string(), + ), + 0, + ); + + if let Some(sub) = info.get_subtype() { + debug!("Adding subdomain {}", sub); + out.add_answer_at_time( + DnsPointer::new( + sub, + TYPE_PTR, + CLASS_IN, + info.get_other_ttl(), + service_fullname.to_string(), + ), + 0, + ); + } + + let intf_addrs = info.get_addrs_on_intf(intf); + if intf_addrs.is_empty() { + debug!("No valid addrs to add on intf {:?}", &intf); + return None; + } + + // SRV records. + let mut hostname = info.get_hostname().to_string(); + if let Some(new_name) = dns_registry.name_changes.get(&hostname) { + hostname = new_name.to_string(); + } + + let mut srv = DnsSrv::new( + info.get_fullname(), + CLASS_IN | CLASS_CACHE_FLUSH, + info.get_host_ttl(), + info.get_priority(), + info.get_weight(), + info.get_port(), + hostname, + ); + + if let Some(new_name) = dns_registry.name_changes.get(info.get_fullname()) { + srv.get_record_mut().set_new_name(new_name.to_string()); + } + + if dns_registry.is_probing_done(&srv, info.get_fullname(), create_time) { + out.add_answer_at_time(srv, 0); + } else { + probing_count += 1; + } + + // TXT records. + + let mut txt = DnsTxt::new( + info.get_fullname(), + CLASS_IN | CLASS_CACHE_FLUSH, + info.get_other_ttl(), + info.generate_txt(), + ); + + if let Some(new_name) = dns_registry.name_changes.get(info.get_fullname()) { + txt.get_record_mut().set_new_name(new_name.to_string()); + } + + if dns_registry.is_probing_done(&txt, info.get_fullname(), create_time) { + out.add_answer_at_time(txt, 0); + } else { + probing_count += 1; + } + + // Address records. + let hostname = info.get_hostname(); + for address in intf_addrs { + let mut dns_addr = DnsAddress::new( + hostname, + ip_address_to_type(&address), + CLASS_IN | CLASS_CACHE_FLUSH, + info.get_host_ttl(), + address, + ); + + if let Some(new_name) = dns_registry.name_changes.get(hostname) { + dns_addr.get_record_mut().set_new_name(new_name.to_string()); + } + + if dns_registry.is_probing_done(&dns_addr, info.get_fullname(), create_time) { + out.add_answer_at_time(dns_addr, 0); + } else { + probing_count += 1; + } + } + + if probing_count > 0 { + return None; + } + + Some(out) +} + +/// Send an unsolicited response for owned service via `intf` and `sock`. +/// Returns true if sent out successfully. +fn announce_service_on_intf( + dns_registry: &mut DnsRegistry, + info: &ServiceInfo, + intf: &Interface, + sock: &Socket, +) -> bool { + if let Some(out) = prepare_announce(info, intf, dns_registry) { + send_dns_outgoing(&out, intf, sock); + return true; + } + false +} + #[cfg(test)] mod tests { use super::{ @@ -2691,7 +3120,7 @@ mod tests { } #[test] - fn service_with_temporarily_invalidated_ptr() { + fn test_service_with_temporarily_invalidated_ptr() { // Create a daemon let d = ServiceDaemon::new().expect("Failed to create daemon"); @@ -2810,7 +3239,7 @@ mod tests { // let fullname = my_service.get_fullname().to_string(); // set SRV to expire soon. - let new_ttl = 2; // for testing only. + let new_ttl = 3; // for testing only. my_service._set_host_ttl(new_ttl); // register my service @@ -2820,7 +3249,7 @@ mod tests { let mdns_client = ServiceDaemon::new().expect("Failed to create mdns client"); let browse_chan = mdns_client.browse(service_type).unwrap(); - let timeout = Duration::from_secs(1); + let timeout = Duration::from_secs(2); let mut resolved = false; while let Ok(event) = browse_chan.recv_timeout(timeout) { @@ -2949,7 +3378,7 @@ mod tests { ) .unwrap(); - let new_ttl = 2; // for testing only. + let new_ttl = 3; // for testing only. my_service._set_other_ttl(new_ttl); // register my service @@ -2977,7 +3406,7 @@ mod tests { assert!(resolved); // wait over 80% of TTL, and refresh PTR should be sent out. - let timeout = Duration::from_millis(1800); + let timeout = Duration::from_millis(new_ttl as u64 * 1000 * 90 / 100); while let Ok(event) = browse_chan.recv_timeout(timeout) { println!("event: {:?}", &event); } @@ -2993,3 +3422,17 @@ mod tests { mdns_client.shutdown().unwrap(); } } + +/// Returns a new name based on the `original` to avoid conflicts. +/// +/// For example: +/// `foo.local.` becomes `foo (2).local.` +fn name_change(original: &str) -> String { + let mut parts: Vec<_> = original.split('.').collect(); + let Some(first_part) = parts.get_mut(0) else { + return format!("{original} (2)"); + }; + let new_name = format!("{} (2)", first_part); + *first_part = &new_name; + parts.join(".") +} diff --git a/src/service_info.rs b/src/service_info.rs index aacdd46..c656a4b 100644 --- a/src/service_info.rs +++ b/src/service_info.rs @@ -1,10 +1,14 @@ //! Define `ServiceInfo` to represent a service and its operations. #[cfg(feature = "logging")] -use crate::log::error; -use crate::{dns_parser::split_sub_domain, Error, Result}; +use crate::log::{error, info}; +use crate::{ + dns_parser::{rr_type_name, split_sub_domain, DnsRecordBox, DnsRecordExt, DnsSrv, TYPE_SRV}, + Error, Result, +}; use if_addrs::{IfAddr, Interface}; use std::{ + cmp, collections::{HashMap, HashSet}, convert::TryInto, fmt, @@ -38,6 +42,15 @@ pub struct ServiceInfo { weight: u16, txt_properties: TxtProperties, addr_auto: bool, // Let the system update addresses automatically. + + status: HashMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ServiceStatus { + Probing, + Announced, + Unknown, } impl ServiceInfo { @@ -121,6 +134,7 @@ impl ServiceInfo { weight: 0, txt_properties, addr_auto: false, + status: HashMap::new(), }; Ok(this) @@ -315,6 +329,24 @@ impl ServiceInfo { pub(crate) fn _set_other_ttl(&mut self, ttl: u32) { self.other_ttl = ttl; } + + pub(crate) fn set_status(&mut self, intf: &Interface, status: ServiceStatus) { + match self.status.get_mut(intf) { + Some(service_status) => { + *service_status = status; + } + None => { + self.status.entry(intf.clone()).or_insert(status); + } + } + } + + pub(crate) fn get_status(&self, intf: &Interface) -> ServiceStatus { + self.status + .get(intf) + .cloned() + .unwrap_or(ServiceStatus::Unknown) + } } /// Removes potentially duplicated ".local." at the end of "hostname". @@ -747,6 +779,259 @@ pub fn valid_two_addrs_on_intf(addr_a: &IpAddr, addr_b: &IpAddr, intf: &Interfac } } +/// A probing for a particular name. +#[derive(Debug)] +pub(crate) struct Probe { + /// All records probing for the same name. + pub(crate) records: Vec, + + /// The fullnames of services that are probing these records. + /// These are the original service names, will not change per conflicts. + pub(crate) waiting_services: Vec, + + /// The time (T) to send the first query . + pub(crate) start_time: u64, + + /// The time to send the next (including the first) query. + pub(crate) next_send: u64, +} + +impl Probe { + pub(crate) fn new(start_time: u64) -> Self { + // RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.1: + // + // "250 ms after the first query, the host should send a second; then, + // 250 ms after that, a third. If, by 250 ms after the third probe, no + // conflicting Multicast DNS responses have been received, the host may + // move to the next step, announcing. " + let next_send = start_time; + + Self { + records: Vec::new(), + waiting_services: Vec::new(), + start_time, + next_send, + } + } + + /// Add a new record with the same probing name in a sorted order. + pub(crate) fn insert_record(&mut self, record: DnsRecordBox) { + /* + RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2.1 + + " The records are sorted using the same lexicographical order as + described above, that is, if the record classes differ, the record + with the lower class number comes first. If the classes are the same + but the rrtypes differ, the record with the lower rrtype number comes + first." + */ + let insert_position = self + .records + .binary_search_by( + |existing| match existing.get_class().cmp(&record.get_class()) { + std::cmp::Ordering::Equal => existing.get_type().cmp(&record.get_type()), + other => other, + }, + ) + .unwrap_or_else(|pos| pos); + + self.records.insert(insert_position, record); + } + + /// Compares with `incoming` records. Returns `Less` if we yield. + pub(crate) fn tiebreaking(&self, incoming: &[&DnsRecordBox]) -> cmp::Ordering { + /* + RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2 + + " If the host finds that its + own data is lexicographically earlier, then it defers to the winning + host by waiting one second, and then begins probing for this record + again." + */ + let min_len = self.records.len().min(incoming.len()); + + // Compare elements up to the length of the shorter vector + for (i, incoming_record) in incoming.iter().enumerate().take(min_len) { + match self.records[i].compare(incoming_record.as_ref()) { + cmp::Ordering::Equal => continue, + other => return other, + } + } + + self.records.len().cmp(&incoming.len()) + } + + pub(crate) fn update_next_send(&mut self, now: u64) { + self.next_send = now + 250; + } + + /// Returns whether this probe is finished. + pub(crate) fn expired(&self, now: u64) -> bool { + // The 2nd query is T + 250ms, the 3rd query is T + 500ms, + // The expire time is T + 750ms + now >= self.start_time + 750 + } +} + +/// DNS records of all the registered services. +pub(crate) struct DnsRegistry { + /// keyed by the name of all related records. + /* + When a host is probing for a group of related records with the same + name (e.g., the SRV and TXT record describing a DNS-SD service), only + a single question need be placed in the Question Section, since query + type "ANY" (255) is used, which will elicit answers for all records + with that name. However, for tiebreaking to work correctly in all + cases, the Authority Section must contain *all* the records and + proposed rdata being probed for uniqueness. + */ + pub(crate) probing: HashMap, + + /// Already done probing, or no need to probe. + pub(crate) active: HashMap>, + + /// timers of the newly added probes. + pub(crate) new_timers: Vec, + + /// Mapping from original names to new names. + pub(crate) name_changes: HashMap, +} + +impl DnsRegistry { + pub(crate) fn new() -> Self { + Self { + probing: HashMap::new(), + active: HashMap::new(), + new_timers: Vec::new(), + name_changes: HashMap::new(), + } + } + + pub(crate) fn is_probing_done( + &mut self, + answer: &T, + service_name: &str, + start_time: u64, + ) -> bool + where + T: DnsRecordExt + Send + 'static, + { + if let Some(active_records) = self.active.get(answer.get_name()) { + for record in active_records.iter() { + if answer.matches(record.as_ref()) { + info!( + "found active record {} type {}", + answer.get_name(), + answer.get_type() + ); + return true; + } + } + } + + let probe = self + .probing + .entry(answer.get_name().to_string()) + .or_insert(Probe::new(start_time)); + + self.new_timers.push(probe.next_send); + + for record in probe.records.iter() { + if answer.matches(record.as_ref()) { + info!( + "found existing record {} in probe of '{}'", + rr_type_name(answer.get_type()), + answer.get_name(), + ); + probe.waiting_services.push(service_name.to_string()); + return false; // Found existing probe for the same record. + } + } + + info!( + "insert record {} into probe of {}", + rr_type_name(answer.get_type()), + answer.get_name(), + ); + probe.insert_record(answer.clone_box()); + probe.waiting_services.push(service_name.to_string()); + + false + } + + pub(crate) fn update_hostname( + &mut self, + original: &str, + new_name: &str, + probe_time: u64, + ) -> bool { + // check all records in "probing" and "active": + // if the record is SRV, and hostname is set to original, remove it. + // and add it to "probing" with "new_name". + + let mut found_records = Vec::new(); + let mut new_timer_added = false; + + for (_name, probe) in self.probing.iter_mut() { + probe.records.retain(|record| { + if record.get_type() == TYPE_SRV { + if let Some(srv) = record.any().downcast_ref::() { + if srv.host == original { + let mut new_record = srv.clone(); + new_record.host = new_name.to_string(); + found_records.push(new_record); + return false; + } + } + } + true + }); + } + + for (_name, records) in self.active.iter_mut() { + records.retain(|record| { + if record.get_type() == TYPE_SRV { + if let Some(srv) = record.any().downcast_ref::() { + if srv.host == original { + let mut new_record = srv.clone(); + new_record.host = new_name.to_string(); + found_records.push(new_record); + return false; + } + } + } + true + }); + } + + for record in found_records { + let probe = match self.probing.get_mut(record.get_name()) { + Some(p) => { + p.start_time = probe_time; // restart this probe. + p + } + None => { + let new_probe = self + .probing + .entry(record.get_name().to_string()) + .or_insert(Probe::new(probe_time)); + new_timer_added = true; + new_probe + } + }; + + info!( + "insert record {} with new hostname into probe: {}", + rr_type_name(record.get_type()), + record.get_name() + ); + probe.insert_record(Box::new(record)); + } + + new_timer_added + } +} + #[cfg(test)] mod tests { use super::{ diff --git a/tests/mdns_test.rs b/tests/mdns_test.rs index 620f53b..3352bc4 100644 --- a/tests/mdns_test.rs +++ b/tests/mdns_test.rs @@ -7,7 +7,7 @@ use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::thread::sleep; use std::time::{Duration, SystemTime}; -// use test_log::test; // commented out for debugging a flaky test in CI. +use test_log::test; /// This test covers: /// register(announce), browse(query), response, unregister, shutdown. @@ -186,7 +186,7 @@ fn integration_success() { println!("metrics: {:?}", &metrics); assert_eq!(metrics["register"], 1); assert_eq!(metrics["unregister"], 1); - assert_eq!(metrics["register-resend"], 1); + assert!(metrics["register-resend"] >= 1); println!("unique interface set: {:?}", unique_intf_idx_ip_ver_set); assert_eq!( @@ -540,7 +540,7 @@ fn service_with_named_interface_only() { // Browse again. let browse_chan = d.browse(my_ty_domain).unwrap(); - let timeout = Duration::from_secs(2); + let timeout = Duration::from_secs(3); let mut resolved = false; while let Ok(event) = browse_chan.recv_timeout(timeout) { @@ -745,14 +745,14 @@ fn subtype() { let d = ServiceDaemon::new().expect("Failed to create daemon"); // Register a service with a subdomain - let subtype_domain = "_directory._sub._test._tcp.local."; - let ty_domain = "_test._tcp.local."; + let subtype_domain = "_directory._sub._test-subtype._tcp.local."; + let ty_domain = "_test-subtype._tcp.local."; let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap(); let instance_name = now.as_micros().to_string(); // Create a unique name. let host_ipv4 = my_ip_interfaces()[0].ip().to_string(); - let host_name = "my_host.local."; + let host_name = "subtype_host.local."; let port = 5201; let my_service = ServiceInfo::new( subtype_domain, @@ -835,7 +835,10 @@ fn service_name_check() { assert!(result.is_ok()); // Verify that the service was published successfully. - let event = monitor.recv_timeout(Duration::from_millis(500)).unwrap(); + let publish_timeout = 1000; + let event = monitor + .recv_timeout(Duration::from_millis(publish_timeout)) + .unwrap(); assert!(matches!(event, DaemonEvent::Announce(_, _))); // Check for the internal upper limit of service name length max. @@ -930,10 +933,13 @@ fn instance_name_two_dots() { assert!(result.is_ok()); // Verify that the service was published successfully. - let event = monitor.recv_timeout(Duration::from_millis(500)).unwrap(); + let publish_timeout = 1000; + let event = monitor + .recv_timeout(Duration::from_millis(publish_timeout)) + .unwrap(); assert!(matches!(event, DaemonEvent::Announce(_, _))); - // Browseing the service. + // Browse the service. let receiver = server_daemon.browse(service_type).unwrap(); let mut resolved = false; let timeout = Duration::from_secs(2); @@ -1163,7 +1169,7 @@ fn hostname_resolution_timeout() { d.shutdown().unwrap(); } -#[test_log::test] +#[test] fn test_cache_flush_record() { // Create a daemon let server = ServiceDaemon::new().expect("Failed to create server"); @@ -1456,6 +1462,100 @@ fn test_domain_suffix_in_browse() { mdns_client.shutdown().unwrap(); } +#[test] +fn test_conflict_resolution() { + // This test registers two services using the same names, but different IP addresses. + let ty_domain = "_conflict-test._udp.local."; + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap(); + let instance_name = now.as_micros().to_string(); // Create a unique name. + let host_name = "conflict_host.local."; + let port = 5200; + + // Register the first service. + let server1 = ServiceDaemon::new().expect("failed to start server1"); + + // Get a single IPv4 address + let ip_addr1 = my_ip_interfaces() + .iter() + .find(|iface| iface.ip().is_ipv4()) + .map(|iface| iface.ip()) + .unwrap(); + + // Publish the service on server1 + let service1 = ServiceInfo::new(ty_domain, &instance_name, host_name, &ip_addr1, port, None) + .expect("valid service info"); + server1 + .register(service1) + .expect("Failed to register service1"); + + // wait for the service announced. + sleep(Duration::from_secs(1)); + + // Register the second service. + let server2 = ServiceDaemon::new().expect("failed to start server2"); + + // Modify the IPv4 address for the service. + let IpAddr::V4(ipv4) = ip_addr1 else { + assert!(false); + return; + }; + let bytes = ipv4.octets(); + let ip_addr2 = IpAddr::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3] + 1)); + + let service2 = ServiceInfo::new(ty_domain, &instance_name, host_name, &ip_addr2, port, None) + .expect("failed to create ServiceInfo for service2"); + server2 + .register(service2) + .expect("failed to register service2"); + + // Verify name change event for the second service, due to the name conflict. + let server2_monitor = server2.monitor().unwrap(); + let timeout = Duration::from_secs(2); + let mut name_changed = false; + while let Ok(event) = server2_monitor.recv_timeout(timeout) { + match event { + DaemonEvent::NameChange(change) => { + println!("server2 daemon event: {:?}", change); + name_changed = true; + break; + } + other => println!("server2 other event: {:?}", other), + } + } + assert!(name_changed); + + // Verify both services are resolved. + let client = ServiceDaemon::new().expect("failed to create mdns client"); + let receiver = client.browse(ty_domain).unwrap(); + + let timeout = Duration::from_secs(3); + let mut service_names = HashSet::new(); + + while let Ok(event) = receiver.recv_timeout(timeout) { + match event { + ServiceEvent::ServiceResolved(info) => { + println!( + "Resolved a service: {} host {} IP {:?}", + info.get_fullname(), + info.get_hostname(), + info.get_addresses_v4() + ); + + service_names.insert(info.get_fullname().to_string()); + if info.get_fullname().contains("(2)") { + break; + } + } + _ => {} + } + } + + // Verify that we have resolve two services instead of one. + assert_eq!(service_names.len(), 2); +} + /// A helper function to include a timestamp for println. fn timed_println(msg: String) { let now = SystemTime::now();