Skip to content

Commit

Permalink
Merge pull request #298 from zarns/feature/zarns/catanatron-rust
Browse files Browse the repository at this point in the history
Rust Rewrite
  • Loading branch information
bcollazo authored Jan 9, 2025
2 parents 6ae93a0 + d954062 commit 6cc44c1
Show file tree
Hide file tree
Showing 8 changed files with 891 additions and 91 deletions.
4 changes: 2 additions & 2 deletions catanatron_rust/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub enum ActionPrompt {
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Action {
// The first value in all these is the color of the player.
Roll(u8), // None. Log instead sets it to (int, int) rolled.
Roll(u8, Option<(u8, u8)>), // None. Log instead sets it to (int, int) rolled.
MoveRobber(u8, Coordinate, Option<u8>), // Log has extra element of card stolen.
Discard(u8), // value is None|Resource[].
BuildRoad(u8, EdgeId),
Expand Down Expand Up @@ -106,7 +106,7 @@ pub enum MapType {
// TODO: Make immutable and read-only
#[derive(Debug)]
pub struct GameConfiguration {
pub dicard_limit: u8,
pub discard_limit: u8,
pub vps_to_win: u8,
pub map_type: MapType,
pub num_players: u8,
Expand Down
2 changes: 1 addition & 1 deletion catanatron_rust/src/game.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ mod tests {
) -> (GlobalState, GameConfiguration, HashMap<u8, Box<dyn Player>>) {
let global_state = GlobalState::new();
let config = GameConfiguration {
dicard_limit: 7,
discard_limit: 7,
vps_to_win: 10,
map_type: MapType::Base,
num_players,
Expand Down
2 changes: 1 addition & 1 deletion catanatron_rust/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ fn main() {
println!("Colors {:?}", COLORS);

let config = GameConfiguration {
dicard_limit: 7,
discard_limit: 7,
vps_to_win: 10,
map_type: MapType::Base,
num_players: 2,
Expand Down
31 changes: 28 additions & 3 deletions catanatron_rust/src/map_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,21 @@ fn get_unit_vector(direction: Direction) -> (i8, i8, i8) {

#[derive(Debug, Clone, PartialEq)]
pub struct Hexagon {
// TODO: id?
pub(crate) nodes: HashMap<NodeRef, NodeId>,
pub(crate) edges: HashMap<EdgeRef, EdgeId>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct LandTile {
pub(crate) id: u8,
pub(crate) hexagon: Hexagon,
pub(crate) resource: Option<Resource>,
pub(crate) number: Option<u8>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct PortTile {
pub(crate) id: u8,
pub(crate) hexagon: Hexagon,
pub(crate) resource: Option<Resource>,
pub(crate) direction: Direction,
Expand Down Expand Up @@ -162,6 +163,10 @@ impl MapInstance {
pub fn get_neighbor_edges(&self, node_id: NodeId) -> Vec<EdgeId> {
self.edge_neighbors.get(&node_id).unwrap().clone()
}

pub fn get_adjacent_tiles(&self, node_id: NodeId) -> Option<&Vec<LandTile>> {
self.adjacent_land_tiles.get(&node_id)
}
}

impl MapInstance {
Expand All @@ -185,6 +190,9 @@ impl MapInstance {
let mut hexagons: HashMap<Coordinate, Hexagon> = HashMap::new();
let mut tiles: HashMap<Coordinate, Tile> = HashMap::new();
let mut autoinc = 0;
let mut tile_autoinc = 0;
let mut port_autoinc = 0;

for (&coordinate, &tile_slot) in map_template.topology.iter() {
let (nodes, edges, new_autoinc) = get_nodes_edges(&hexagons, coordinate, autoinc);
autoinc = new_autoinc;
Expand All @@ -194,6 +202,7 @@ impl MapInstance {
let resource = shuffled_tiles.pop().unwrap();
if resource.is_none() {
let land_tile = LandTile {
id: tile_autoinc,
hexagon: hexagon.clone(),
resource,
number: None,
Expand All @@ -202,12 +211,14 @@ impl MapInstance {
} else {
let number = shuffled_numbers.pop().unwrap();
let land_tile = LandTile {
id: tile_autoinc,
hexagon: hexagon.clone(),
resource,
number: Some(number),
};
tiles.insert(coordinate, Tile::Land(land_tile));
}
tile_autoinc += 1;
} else if tile_slot == TileSlot::Water {
let water_tile = WaterTile {
hexagon: hexagon.clone(),
Expand All @@ -225,11 +236,13 @@ impl MapInstance {
};
let resource = shuffled_ports.pop().unwrap();
let port_tile = PortTile {
id: port_autoinc,
hexagon: hexagon.clone(),
resource,
direction,
};
tiles.insert(coordinate, Tile::Port(port_tile));
port_autoinc += 1;
}

hexagons.insert(coordinate, hexagon);
Expand Down Expand Up @@ -275,8 +288,18 @@ impl MapInstance {
land_edges.insert(edge_id);
node_neighbors.entry(edge_id.0).or_default().push(edge_id.1);
node_neighbors.entry(edge_id.1).or_default().push(edge_id.0);
edge_neighbors.entry(edge_id.0).or_default().push(edge_id);
edge_neighbors.entry(edge_id.1).or_default().push(edge_id);

// Only insert edge into edge_neighbors if not already present
{
let edges_for_node_0 = edge_neighbors.entry(edge_id.0).or_default();
if !edges_for_node_0.contains(&edge_id) {
edges_for_node_0.push(edge_id);
}
let edges_for_node_1 = edge_neighbors.entry(edge_id.1).or_default();
if !edges_for_node_1.contains(&edge_id) {
edges_for_node_1.push(edge_id);
}
}
});
} else if let Tile::Port(port_tile) = tile {
let (a_noderef, b_noderef) = get_noderefs_from_port_direction(port_tile.direction);
Expand Down Expand Up @@ -588,6 +611,7 @@ mod tests {
assert_eq!(
map_instance.tiles.get(&(0, 0, 0)),
Some(&Tile::Land(LandTile {
id: 0,
hexagon: Hexagon {
nodes: HashMap::from([
(NodeRef::North, 0),
Expand All @@ -613,6 +637,7 @@ mod tests {
assert_eq!(
map_instance.land_tiles.get(&(1, -1, 0)),
Some(&LandTile {
id: 1,
hexagon: Hexagon {
nodes: HashMap::from([
(NodeRef::North, 6),
Expand Down
114 changes: 110 additions & 4 deletions catanatron_rust/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl State {
pub fn new_base() -> Self {
let global_state = GlobalState::new();
let config = GameConfiguration {
dicard_limit: 7,
discard_limit: 7,
vps_to_win: 10,
map_type: MapType::Base,
num_players: 4,
Expand Down Expand Up @@ -134,7 +134,7 @@ impl State {
pub fn can_play_dev(&self, dev_card: u8) -> bool {
let color = self.get_current_color();
let dev_card_index = dev_card as usize;
let has_one = self.vector[player_devhand_slice(color)][dev_card_index] > 0;
let has_one = self.vector[player_devhand_slice(self.config.num_players, color)][dev_card_index] > 0;
let has_played_in_turn = self.vector[HAS_PLAYED_DEV_CARD] == 1;
has_one && !has_played_in_turn
}
Expand All @@ -159,11 +159,19 @@ impl State {

// TODO: Maybe move to mutations(?)
pub fn get_mut_player_hand(&mut self, color: u8) -> &mut [u8] {
&mut self.vector[player_hand_slice(color)]
&mut self.vector[player_hand_slice(self.config.num_players, color)]
}

pub fn get_player_hand(&self, color: u8) -> &[u8] {
&self.vector[player_hand_slice(color)]
&self.vector[player_hand_slice(self.config.num_players, color)]
}

pub fn get_mut_player_devhand(&mut self, color: u8) -> &mut [u8] {
&mut self.vector[player_devhand_slice(self.config.num_players, color)]
}

pub fn get_player_devhand(&self, color: u8) -> &[u8] {
&self.vector[player_devhand_slice(self.config.num_players, color)]
}

pub fn winner(&self) -> Option<u8> {
Expand Down Expand Up @@ -226,6 +234,22 @@ impl State {
buildable.into_iter().collect()
}

pub fn buildable_node_ids(&self, color: u8,) -> Vec<u8> {
let road_subgraphs = match self.connected_components.get(&color) {
Some(components) => components,
None => &vec![],
};

let mut road_connected_nodes: HashSet<u8> = HashSet::new();
for component in road_subgraphs {
road_connected_nodes.extend(component);
}

road_connected_nodes.intersection(&self.board_buildable_ids)
.copied()
.collect()
}

fn get_connected_component_index(&self, color: u8, a: u8) -> Option<usize> {
let components = self.connected_components.get(&color).unwrap();
for (i, component) in components.iter().enumerate() {
Expand Down Expand Up @@ -256,6 +280,66 @@ impl State {
let (node1, node2) = edge;
node1 == a || node2 == a
}

fn dfs_longest_path(
&self,
node: NodeId,
parent: Option<NodeId>,
connected_set: &HashSet<NodeId>,
color: u8,
current_path: &mut Vec<EdgeId>,
best_path: &mut Vec<EdgeId>,
) {
// If current_path is longer than what we have, store it
if current_path.len() > best_path.len() {
*best_path = current_path.clone();
}

for &neighbor in &self.map_instance.get_neighbor_nodes(node) {
// Must be in the connected component
if !connected_set.contains(&neighbor) {
continue;
}
let edge = (node.min(neighbor), node.max(neighbor));

// Avoid going back to parent
if parent == Some(neighbor) {
continue;
}
// Skip roads not owned by us
if self.roads.get(&edge) != Some(&color) {
continue;
}
// Acyclic check
if current_path.contains(&edge) {
continue;
}

// Move forward
current_path.push(edge);
self.dfs_longest_path(neighbor, Some(node), connected_set, color, current_path, best_path);
current_path.pop();
}
}

pub fn longest_acyclic_path(&self, connected_node_set: &HashSet<NodeId>, color: u8) -> Vec<EdgeId> {
if connected_node_set.is_empty() {
return vec![];
}

let mut overall_best_path = Vec::new();

for &start_node in connected_node_set {
let mut current_path = Vec::new();
let mut best_path = Vec::new();

self.dfs_longest_path(start_node, None, connected_node_set, color, &mut current_path, &mut best_path);
if best_path.len() > overall_best_path.len() {
overall_best_path = best_path;
}
}
overall_best_path
}
}

#[cfg(test)]
Expand All @@ -277,4 +361,26 @@ mod tests {
assert!(!state.is_moving_robber());
assert!(!state.is_discarding());
}

#[test]
fn test_longest_acyclic_path() {
let mut state = State::new_base();
let color = 0;

state.roads.insert((0, 1), color);
state.roads.insert((1, 2), color);
state.roads.insert((2, 3), color);
state.roads.insert((3, 4), color);
state.roads.insert((4, 5), color);
state.roads.insert((0, 5), color);
state.roads.insert((0, 20), color);
state.roads.insert((20, 19), color);
state.roads.insert((20, 22), color);
state.roads.insert((22, 23), color);
state.roads.insert((6, 23), color);

let all_nodes = HashSet::from([0, 1, 2, 3, 4, 5, 19, 20, 22, 23, 6]);
let path = state.longest_acyclic_path(&all_nodes, color);
assert_eq!(path.len(), 10);
}
}
Loading

1 comment on commit 6cc44c1

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance Alert

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.30.

Benchmark suite Current: 6cc44c1 Previous: f2b016d Ratio
tests/integration_tests/test_speed.py::test_same_turn_alphabeta_speed 3.0558931910079745 iter/sec (stddev: 0.41233688221203635) 3.9830607182922346 iter/sec (stddev: 0.28893228611446836) 1.30

This comment was automatically generated by workflow.

Please sign in to comment.