diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index dc5093f4c..0d3d27bdc 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -16,8 +16,9 @@ use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, Port, }; use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; -use fxhash::FxHashMap; +use fxhash::{FxBuildHasher, FxHashMap}; use hugr_model::v0::{self as model}; +use petgraph::unionfind::UnionFind; use std::fmt::Write; pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; @@ -58,13 +59,8 @@ struct Context<'a> { /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, - /// Table that is used to track which ports are connected. - /// - /// Each group of ports that is connected together is represented by a - /// single link. When traversing the [`Hugr`] graph we assign a link to each - /// port by finding the smallest node/port pair among all the linked ports - /// and looking up the link for that pair in this table. - links: model::scope::LinkTable<(Node, Port)>, + /// Auxiliary structure for tracking the links between ports. + links: Links, /// The symbol table tracking symbols that are currently in scope. symbols: model::scope::SymbolTable<'a>, @@ -85,11 +81,13 @@ impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { let mut module = model::Module::default(); module.nodes.reserve(hugr.node_count()); + let links = Links::new(hugr); Self { hugr, module, bump, + links, term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), @@ -98,7 +96,6 @@ impl<'a> Context<'a> { implicit_imports: FxHashMap::default(), node_to_id: FxHashMap::default(), id_to_node: FxHashMap::default(), - links: model::scope::LinkTable::default(), } } @@ -144,19 +141,6 @@ impl<'a> Context<'a> { }; } - /// Returns the edge id for a given port, creating a new edge if necessary. 
- /// - /// Any two ports that are linked will be represented by the same link. - fn get_link_index(&mut self, node: Node, port: impl Into<Port>) -> model::LinkIndex { - // To ensure that linked ports are represented by the same edge, we take the minimum port - // among all the linked ports, including the one we started with. - let port = port.into(); - let linked_ports = self.hugr.linked_ports(node, port); - let all_ports = std::iter::once((node, port)).chain(linked_ports); - let repr = all_ports.min().unwrap(); - self.links.use_link(repr) - } - pub fn make_ports( &mut self, node: Node, @@ -167,7 +151,7 @@ impl<'a> Context<'a> { let mut links = BumpVec::with_capacity_in(ports.size_hint().0, self.bump); for port in ports.take(num_ports) { - links.push(self.get_link_index(node, port)); + links.push(self.links.use_link(node, port)); } links.into_bump_slice() @@ -715,7 +699,7 @@ impl<'a> Context<'a> { } if source.is_none() { - source = Some(self.get_link_index(child, IncomingPort::from(0))); + source = Some(self.links.use_link(child, IncomingPort::from(0))); } } } @@ -1151,6 +1135,81 @@ impl<'a> Context<'a> { } } +type FxIndexSet<K> = indexmap::IndexSet<K, FxBuildHasher>; + +/// Data structure for translating the edges between ports in the `Hugr` graph +/// into the hypergraph representation used by `hugr_model`. +struct Links { + /// Scoping helper that keeps track of the current nesting of regions + /// and translates the group of connected ports into a link index. + scope: model::scope::LinkTable<u32>, + + /// A mapping from each port to the group of connected ports it belongs to. + groups: FxHashMap<(Node, Port), u32>, +} + +impl Links { + /// Create the `Links` data structure from a `Hugr` graph by recording the + /// connectivity of the ports. + pub fn new(hugr: &Hugr) -> Self { + let scope = model::scope::LinkTable::new(); + + // We collect all ports that are in the hugr into an index set so that + // we have an association between the port and a numeric index. 
+ let node_ports: FxIndexSet<(Node, Port)> = hugr + .nodes() + .flat_map(|node| hugr.all_node_ports(node).map(move |port| (node, port))) + .collect(); + + // We then use a union-find data structure to group together all ports that are connected. + let mut uf = UnionFind::<u32>::new(node_ports.len()); + + for (i, (node, port)) in node_ports.iter().enumerate() { + if let Ok(port) = port.as_incoming() { + for (other_node, other_port) in hugr.linked_outputs(*node, port) { + let other_port = Port::from(other_port); + let j = node_ports.get_index_of(&(other_node, other_port)).unwrap(); + uf.union(i as u32, j as u32); + } + } + } + + // We then collect the association between the port and the group of connected ports it belongs to. + let groups = node_ports + .into_iter() + .enumerate() + .map(|(i, node_port)| (node_port, uf.find(i as u32))) + .collect(); + + Self { scope, groups } + } + + /// Enter an isolated region. + pub fn enter(&mut self, region: model::RegionId) { + self.scope.enter(region); + } + + /// Leave an isolated region, returning the number of links and ports in the region. + /// + /// # Panics + /// + /// Panics if there is no remaining open scope to exit. + pub fn exit(&mut self) -> (u32, u32) { + self.scope.exit() + } + + /// Obtain the link index for a node and port. + /// + /// # Panics + /// + /// Panics if the port does not exist in the [`Hugr`] that was passed to [`Self::new`]. 
+ pub fn use_link(&mut self, node: Node, port: impl Into<Port>) -> model::LinkIndex { + let port = port.into(); + let group = self.groups[&(node, port)]; + self.scope.use_link(group) + } +} + #[cfg(test)] mod test { use rstest::{fixture, rstest}; diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index cb6666ba9..0e9d11b4e 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -21,7 +21,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext))) (tag 0 [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext)))))) - (block [%6] [%3 %9] + (block [%6] [%3 %6] (signature (-> [(ctrl [?0])] [(ctrl [?0]) (ctrl [?0])] (ext))) (dfg [%7] [%8]