diff --git a/src/networking/client.rs b/src/networking/client.rs index 35254640..30add922 100644 --- a/src/networking/client.rs +++ b/src/networking/client.rs @@ -53,6 +53,10 @@ impl Client { self.local.acknowledge_packets(ack, ack_bitfield); self.remote.acknowledge_packet(sequence); } + + pub fn address(&self) -> std::net::SocketAddr { + self.address + } } #[cfg(test)] mod tests { diff --git a/src/networking/local.rs b/src/networking/local.rs index 3b772bfe..ef0abec8 100644 --- a/src/networking/local.rs +++ b/src/networking/local.rs @@ -39,7 +39,9 @@ impl Local { /// Acknowledges a packet with the given sequence number. This means that the remote has received the packet. pub fn acknowledge_packet(&mut self, sequence: u16) { let index = (sequence % 1024) as usize; - self.packet_data.set(index, true); + if self.sequence_buffer[index] == sequence { + self.packet_data.set(index, true); + } } pub fn acknowledge_packets(&mut self, ack: u16, ack_bitfield: u32) { @@ -118,32 +120,96 @@ mod tests { #[test] fn test_get_packet_data() { - let mut remote = Local::new(); - let packet_header = remote.get_packet_data(0); + let mut local = Local::new(); + let packet_header = local.get_packet_data(0); assert_eq!(packet_header, None); - let packet_header = remote.get_packet_data(1023); + let packet_header = local.get_packet_data(1023); assert_eq!(packet_header, None); - remote.get_sequence_number(); - let packet_header = remote.get_packet_data(0); + local.get_sequence_number(); + let packet_header = local.get_packet_data(0); assert_eq!(packet_header, Some(PacketInfo { acked: false })); - let packet_header = remote.get_packet_data(1023); + let packet_header = local.get_packet_data(1023); assert_eq!(packet_header, None); - remote.get_sequence_number(); - let packet_header = remote.get_packet_data(0); + local.get_sequence_number(); + let packet_header = local.get_packet_data(0); assert_eq!(packet_header, Some(PacketInfo { acked: false })); - let packet_header = remote.get_packet_data(1023); + let packet_header = local.get_packet_data(1023); assert_eq!(packet_header, None); - let packet_header = remote.get_packet_data(1); + let packet_header = local.get_packet_data(1); assert_eq!(packet_header, Some(PacketInfo { acked: false })); - remote.acknowledge_packet(0); - let packet_header = remote.get_packet_data(0); + local.acknowledge_packet(0); + let packet_header = local.get_packet_data(0); assert_eq!(packet_header, Some(PacketInfo { acked: true })); - let packet_header = remote.get_packet_data(1023); + let packet_header = local.get_packet_data(1023); assert_eq!(packet_header, None); - let packet_header = remote.get_packet_data(1); + let packet_header = local.get_packet_data(1); assert_eq!(packet_header, Some(PacketInfo { acked: false })); } + + #[test] + fn test_packet_acknowledgement() { + let mut local = Local::new(); + + for i in 0..32 { + local.get_sequence_number(); + } + + assert_eq!(local.unacknowledged_packets(), (0u16..32u16).collect::>()); + + for i in 0..32 { + local.acknowledge_packet(i); + } + + assert_eq!(local.unacknowledged_packets(), Vec::::new()); + + for i in 0..32 { + local.get_sequence_number(); + } + + assert_eq!(local.unacknowledged_packets(), (32u16..64u16).collect::>()); + + for i in 0..32 { + local.acknowledge_packet(i); + } + + assert_eq!(local.unacknowledged_packets(), (32u16..64u16).collect::>()); + + for i in 32..64 { + local.acknowledge_packet(i); + } + + assert_eq!(local.unacknowledged_packets(), Vec::::new()); + } + + #[test] + fn test_sparse_packet_acknowledgement() { + let mut local = Local::new(); + + for i in 0..32 { + local.get_sequence_number(); + } + + local.acknowledge_packet(0); + + assert_eq!(local.unacknowledged_packets(), (1u16..32u16).collect::>()); + + local.acknowledge_packet(2); + + assert_eq!(local.unacknowledged_packets(), (1u16..32u16).filter(|&i| i != 2).collect::>()); + + local.acknowledge_packet(4); + + assert_eq!(local.unacknowledged_packets(), (1u16..32u16).filter(|&i| i != 2 && i != 4).collect::>()); + + local.acknowledge_packet(1); + + assert_eq!(local.unacknowledged_packets(), (3u16..32u16).filter(|&i| i != 4).collect::>()); + + local.acknowledge_packet(3); + + assert_eq!(local.unacknowledged_packets(), (5u16..32u16).collect::>()); + } } \ No newline at end of file diff --git a/src/networking/remote.rs b/src/networking/remote.rs index 7d17e005..67126d63 100644 --- a/src/networking/remote.rs +++ b/src/networking/remote.rs @@ -34,6 +34,7 @@ impl Remote { /// Acknowledges a packet with the given sequence number. This means that the remote has received the packet. pub fn acknowledge_packet(&mut self, sequence: u16) { let index = (sequence % 1024) as usize; + let window_shift = sequence.max(self.ack) - self.ack; // If the packet sequence is more recent, we update the remote sequence number. if sequence_greater_than(sequence, self.ack) { @@ -44,18 +45,14 @@ impl Remote { self.receive_sequence_buffer[index] = u16::MAX; } - let index = (sequence % 1024) as usize; self.receive_sequence_buffer[index] = sequence; self.ack = sequence; } self.packet_data.set(index, true); - self.ack_bitfield = (self.ack_bitfield << 1) | 1; - } - pub fn unacknowledged_packets(&self) -> Vec { - self.receive_sequence_buffer.iter().enumerate().filter(|(i, &sequence)| sequence != u16::MAX && !self.packet_data.get(*i)).map(|(_, &e)| e).collect() + self.ack_bitfield = (self.ack_bitfield << window_shift) | 1 << ((self.ack - sequence) % 32); } pub fn get_ack(&self) -> u16 { @@ -101,4 +98,75 @@ impl BitArray where [u8; N / 8]: { #[cfg(test)] mod tests { use super::*; + + #[test] + fn test_packet_acknowledgement() { + let mut remote = Remote::new(); + + assert_eq!(remote.get_ack(), 0); + assert_eq!(remote.get_ack_bitfield(), 0); + + for i in 0..32 { + remote.acknowledge_packet(i); + } + + assert_eq!(remote.get_ack(), 31); + assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF); + + for i in 0..32 { + remote.acknowledge_packet(i); + } + + assert_eq!(remote.get_ack(), 31); + assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF); + + for i in 32..48 { + remote.acknowledge_packet(i); + } + + assert_eq!(remote.get_ack(), 47); + assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF); + + for i in 48..64 { + remote.acknowledge_packet(i); + } + + assert_eq!(remote.get_ack(), 63); + assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF); + } + + #[test] + fn test_sparse_packet_acknowledgement() { + let mut remote = Remote::new(); + + assert_eq!(remote.get_ack(), 0); + assert_eq!(remote.get_ack_bitfield(), 0); + + remote.acknowledge_packet(0); + + assert_eq!(remote.get_ack(), 0); + assert_eq!(remote.get_ack_bitfield(), 1 << 0); + + remote.acknowledge_packet(2); + + assert_eq!(remote.get_ack(), 2); + assert_eq!(remote.get_ack_bitfield(), 1 << 2 | 1 << 0); + + remote.acknowledge_packet(4); + + assert_eq!(remote.get_ack(), 4); + assert_eq!(remote.get_ack_bitfield(), 1 << 4 | 1 << 2 | 1 << 0); + + remote.acknowledge_packet(1); + + assert_eq!(remote.get_ack(), 4); + assert_eq!(remote.get_ack_bitfield(), 1 << 4 | 1 << 3 | 1 << 2 | 1 << 0); + + remote.acknowledge_packet(3); + + assert_eq!(remote.get_ack(), 4); + assert_eq!(remote.get_ack_bitfield(), 1 << 4 | 1 << 3 | 1 << 2 | 1 << 1 | 1 << 0); + + // TODO: Test when ack and sequence are more than 32 apart. + } } \ No newline at end of file diff --git a/src/networking/server.rs b/src/networking/server.rs index 7dea0e86..8cb0c2c5 100644 --- a/src/networking/server.rs +++ b/src/networking/server.rs @@ -14,6 +14,10 @@ impl Server { } fn connect(&mut self, address: std::net::SocketAddr) -> Result { + if let Some(i) = self.clients.iter().enumerate().find(|(i, client)| if let Some(client) = client { client.address() == address } else { false }).map(|(i, _)| i) { + return Ok(i); + } + for (i, client) in self.clients.iter_mut().enumerate() { if client.is_none() { *client = Some(Client::new(address)); @@ -24,6 +28,10 @@ impl Server { Err(ConnectionResults::ServerFull) } + fn disconnect(&mut self, client_index: usize) { + self.clients[client_index] = None; + } + fn send(&mut self, client_index: usize,) { if let Some(client) = self.clients[client_index].as_mut() { client.send(); @@ -51,6 +59,25 @@ mod tests { server.receive(client_index); } + #[test] + fn test_server_reconnect() { + let mut server = Server::new(); + + let client_index_0 = server.connect(std::net::SocketAddr::new(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), 6669)).unwrap(); + let client_index_1 = server.connect(std::net::SocketAddr::new(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), 6669)).unwrap(); + + assert_eq!(client_index_0, client_index_1); + } + + #[test] + fn test_server_disconnect() { + let mut server = Server::new(); + + let client_index = server.connect(std::net::SocketAddr::new(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), 6669)).unwrap(); + + server.disconnect(client_index); + } + #[test] fn test_exhaust_connections() { let mut server = Server::new();