Skip to content

Commit

Permalink
Added network tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
facundo-villa committed Mar 6, 2024
1 parent 16fc5a5 commit 1406aaa
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 20 deletions.
4 changes: 4 additions & 0 deletions src/networking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl Client {
self.local.acknowledge_packets(ack, ack_bitfield);
self.remote.acknowledge_packet(sequence);
}

pub fn address(&self) -> std::net::SocketAddr {
self.address
}
}
#[cfg(test)]
mod tests {
Expand Down
96 changes: 81 additions & 15 deletions src/networking/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ impl Local {
/// Acknowledges a packet with the given sequence number. This means that the remote has received the packet.
pub fn acknowledge_packet(&mut self, sequence: u16) {
let index = (sequence % 1024) as usize;
self.packet_data.set(index, true);
if self.sequence_buffer[index] == sequence {
self.packet_data.set(index, true);
}
}

pub fn acknowledge_packets(&mut self, ack: u16, ack_bitfield: u32) {
Expand Down Expand Up @@ -118,32 +120,96 @@ mod tests {

#[test]
fn test_get_packet_data() {
let mut remote = Local::new();
let packet_header = remote.get_packet_data(0);
let mut local = Local::new();
let packet_header = local.get_packet_data(0);
assert_eq!(packet_header, None);
let packet_header = remote.get_packet_data(1023);
let packet_header = local.get_packet_data(1023);
assert_eq!(packet_header, None);

remote.get_sequence_number();
let packet_header = remote.get_packet_data(0);
local.get_sequence_number();
let packet_header = local.get_packet_data(0);
assert_eq!(packet_header, Some(PacketInfo { acked: false }));
let packet_header = remote.get_packet_data(1023);
let packet_header = local.get_packet_data(1023);
assert_eq!(packet_header, None);

remote.get_sequence_number();
let packet_header = remote.get_packet_data(0);
local.get_sequence_number();
let packet_header = local.get_packet_data(0);
assert_eq!(packet_header, Some(PacketInfo { acked: false }));
let packet_header = remote.get_packet_data(1023);
let packet_header = local.get_packet_data(1023);
assert_eq!(packet_header, None);
let packet_header = remote.get_packet_data(1);
let packet_header = local.get_packet_data(1);
assert_eq!(packet_header, Some(PacketInfo { acked: false }));

remote.acknowledge_packet(0);
let packet_header = remote.get_packet_data(0);
local.acknowledge_packet(0);
let packet_header = local.get_packet_data(0);
assert_eq!(packet_header, Some(PacketInfo { acked: true }));
let packet_header = remote.get_packet_data(1023);
let packet_header = local.get_packet_data(1023);
assert_eq!(packet_header, None);
let packet_header = remote.get_packet_data(1);
let packet_header = local.get_packet_data(1);
assert_eq!(packet_header, Some(PacketInfo { acked: false }));
}

#[test]
fn test_packet_acknowledgement() {
let mut local = Local::new();

for i in 0..32 {
local.get_sequence_number();
}

assert_eq!(local.unacknowledged_packets(), (0u16..32u16).collect::<Vec<_>>());

for i in 0..32 {
local.acknowledge_packet(i);
}

assert_eq!(local.unacknowledged_packets(), Vec::<u16>::new());

for i in 0..32 {
local.get_sequence_number();
}

assert_eq!(local.unacknowledged_packets(), (32u16..64u16).collect::<Vec<_>>());

for i in 0..32 {
local.acknowledge_packet(i);
}

assert_eq!(local.unacknowledged_packets(), (32u16..64u16).collect::<Vec<_>>());

for i in 32..64 {
local.acknowledge_packet(i);
}

assert_eq!(local.unacknowledged_packets(), Vec::<u16>::new());
}

#[test]
fn test_sparse_packet_acknowledgement() {
let mut local = Local::new();

for i in 0..32 {
local.get_sequence_number();
}

local.acknowledge_packet(0);

assert_eq!(local.unacknowledged_packets(), (1u16..32u16).collect::<Vec<_>>());

local.acknowledge_packet(2);

assert_eq!(local.unacknowledged_packets(), (1u16..32u16).filter(|&i| i != 2).collect::<Vec<_>>());

local.acknowledge_packet(4);

assert_eq!(local.unacknowledged_packets(), (1u16..32u16).filter(|&i| i != 2 && i != 4).collect::<Vec<_>>());

local.acknowledge_packet(1);

assert_eq!(local.unacknowledged_packets(), (3u16..32u16).filter(|&i| i != 4).collect::<Vec<_>>());

local.acknowledge_packet(3);

assert_eq!(local.unacknowledged_packets(), (5u16..32u16).collect::<Vec<_>>());
}
}
78 changes: 73 additions & 5 deletions src/networking/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl Remote {
/// Acknowledges a packet with the given sequence number. This means that the remote has received the packet.
pub fn acknowledge_packet(&mut self, sequence: u16) {
let index = (sequence % 1024) as usize;
let window_shift = sequence.max(self.ack) - self.ack;

// If the packet sequence is more recent, we update the remote sequence number.
if sequence_greater_than(sequence, self.ack) {
Expand All @@ -44,18 +45,14 @@ impl Remote {
self.receive_sequence_buffer[index] = u16::MAX;
}

let index = (sequence % 1024) as usize;
self.receive_sequence_buffer[index] = sequence;

self.ack = sequence;
}

self.packet_data.set(index, true);
self.ack_bitfield = (self.ack_bitfield << 1) | 1;
}

pub fn unacknowledged_packets(&self) -> Vec<u16> {
self.receive_sequence_buffer.iter().enumerate().filter(|(i, &sequence)| sequence != u16::MAX && !self.packet_data.get(*i)).map(|(_, &e)| e).collect()
self.ack_bitfield = (self.ack_bitfield << window_shift) | 1 << ((self.ack - sequence) % 32);
}

pub fn get_ack(&self) -> u16 {
Expand Down Expand Up @@ -101,4 +98,75 @@ impl<const N: usize> BitArray<N> where [u8; N / 8]: {
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_packet_acknowledgement() {
let mut remote = Remote::new();

assert_eq!(remote.get_ack(), 0);
assert_eq!(remote.get_ack_bitfield(), 0);

for i in 0..32 {
remote.acknowledge_packet(i);
}

assert_eq!(remote.get_ack(), 31);
assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF);

for i in 0..32 {
remote.acknowledge_packet(i);
}

assert_eq!(remote.get_ack(), 31);
assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF);

for i in 32..48 {
remote.acknowledge_packet(i);
}

assert_eq!(remote.get_ack(), 47);
assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF);

for i in 48..64 {
remote.acknowledge_packet(i);
}

assert_eq!(remote.get_ack(), 63);
assert_eq!(remote.get_ack_bitfield(), 0xFFFF_FFFF);
}

#[test]
fn test_sparse_packet_acknowledgement() {
let mut remote = Remote::new();

assert_eq!(remote.get_ack(), 0);
assert_eq!(remote.get_ack_bitfield(), 0);

remote.acknowledge_packet(0);

assert_eq!(remote.get_ack(), 0);
assert_eq!(remote.get_ack_bitfield(), 1 << 0);

remote.acknowledge_packet(2);

assert_eq!(remote.get_ack(), 2);
assert_eq!(remote.get_ack_bitfield(), 1 << 2 | 1 << 0);

remote.acknowledge_packet(4);

assert_eq!(remote.get_ack(), 4);
assert_eq!(remote.get_ack_bitfield(), 1 << 4 | 1 << 2 | 1 << 0);

remote.acknowledge_packet(1);

assert_eq!(remote.get_ack(), 4);
assert_eq!(remote.get_ack_bitfield(), 1 << 4 | 1 << 3 | 1 << 2 | 1 << 0);

remote.acknowledge_packet(3);

assert_eq!(remote.get_ack(), 4);
assert_eq!(remote.get_ack_bitfield(), 1 << 4 | 1 << 3 | 1 << 2 | 1 << 1 | 1 << 0);

// TODO: Test when ack and sequence are more than 32 apart.
}
}
27 changes: 27 additions & 0 deletions src/networking/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ impl Server {
}

fn connect(&mut self, address: std::net::SocketAddr) -> Result<usize, ConnectionResults> {
if let Some(i) = self.clients.iter().enumerate().find(|(i, client)| if let Some(client) = client { client.address() == address } else { false }).map(|(i, _)| i) {
return Ok(i);
}

for (i, client) in self.clients.iter_mut().enumerate() {
if client.is_none() {
*client = Some(Client::new(address));
Expand All @@ -24,6 +28,10 @@ impl Server {
Err(ConnectionResults::ServerFull)
}

fn disconnect(&mut self, client_index: usize) {
self.clients[client_index] = None;
}

fn send(&mut self, client_index: usize,) {
if let Some(client) = self.clients[client_index].as_mut() {
client.send();
Expand Down Expand Up @@ -51,6 +59,25 @@ mod tests {
server.receive(client_index);
}

#[test]
fn test_server_reconnect() {
let mut server = Server::new();

let client_index_0 = server.connect(std::net::SocketAddr::new(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), 6669)).unwrap();
let client_index_1 = server.connect(std::net::SocketAddr::new(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), 6669)).unwrap();

assert_eq!(client_index_0, client_index_1);
}

#[test]
fn test_server_disconnect() {
let mut server = Server::new();

let client_index = server.connect(std::net::SocketAddr::new(std::net::Ipv4Addr::new(127, 0, 0, 1).into(), 6669)).unwrap();

server.disconnect(client_index);
}

#[test]
fn test_exhaust_connections() {
let mut server = Server::new();
Expand Down

0 comments on commit 1406aaa

Please sign in to comment.