From 556502775e2158244187843c69ee572f85eeab2b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 31 Oct 2023 15:30:52 +0000 Subject: [PATCH] feat: Builder and HugrMut add_op_xxx default to open extensions (#622) A lot of this is refactoring existing code that explicitly uses `add_node_xxx(NodeType::open_extensions(...))` to `add_op_xxx(...)`, and updating tests - in most cases just `validate` -> `update_validate`. Algorithmically, there was only one change - in `infer.rs` - where Functions, FuncDefs and Aliases are given the empty ExtensionSet if they are open (otherwise these end up unsolved). Also there are some gymnastics in a couple of the tests where we want the correct error out of validation but can't do inference because the Hugr is, well, invalid....so we use a couple of different techniques to carry solutions from earlier, valid, Hugrs over to the invalid one. closes #424 --- src/algorithm/nest_cfgs.rs | 4 +- src/builder.rs | 12 +-- src/builder/conditional.rs | 5 +- src/extension/infer.rs | 156 +++++++++++++++----------------- src/hugr.rs | 11 +-- src/hugr/hugrmut.rs | 35 ++++--- src/hugr/rewrite/outline_cfg.rs | 4 +- src/hugr/validate.rs | 29 +++--- 8 files changed, 126 insertions(+), 130 deletions(-) diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index b7a9a1ec2..15154d4f2 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -646,7 +646,7 @@ pub(crate) mod test { ]) ); transform_cfg_to_nested(&mut IdentityCfgMap::new(rc)); - h.validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); assert_eq!(1, depth(&h, entry)); assert_eq!(1, depth(&h, exit)); for n in [split, left, right, merge, head, tail] { @@ -753,7 +753,7 @@ pub(crate) mod test { let root = h.root(); let m = SiblingMut::::try_new(&mut h, root).unwrap(); transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); - h.validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); assert_eq!(1, depth(&h, entry)); assert_eq!(3, depth(&h, head)); for n in [split, left, right, merge] { diff --git a/src/builder.rs b/src/builder.rs index c0260530b..52d2b7ca4 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -149,18 +149,18 @@ pub(crate) mod test { let mut hugr = Hugr::new(NodeType::pure(ops::DFG { signature: signature.clone(), })); - hugr.add_node_with_parent( + hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::Input { + ops::Input { types: signature.input, - }), + }, ) .unwrap(); - hugr.add_node_with_parent( + hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::Output { + ops::Output { types: signature.output, - }), + }, ) .unwrap(); hugr diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index f1af1ad0d..da8808eea 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -126,10 +126,9 @@ impl + AsRef> ConditionalBuilder { let case_node = // add case before any existing subsequent cases if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() { - // TODO: Allow this to be non-pure - self.hugr_mut().add_node_before(sibling_node, NodeType::open_extensions(case_op))? + self.hugr_mut().add_op_before(sibling_node, case_op)? } else { - self.add_child_node(NodeType::open_extensions(case_op))? + self.add_child_op(case_op)? }; self.case_nodes[case] = Some(case_node); diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 22344492b..db0b66694 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -316,6 +316,12 @@ impl UnificationContext { m_output, node_type.op_signature().extension_reqs, ); + if matches!( + node_type.tag(), + OpTag::Alias | OpTag::Function | OpTag::FuncDefn + ) { + self.add_solution(m_input, ExtensionSet::new()); + } } // We have a solution for everything! Some(sig) => { @@ -337,16 +343,16 @@ impl UnificationContext { | Some(EdgeKind::ControlFlow) ) }) { + let m_tgt = *self + .extensions + .get(&(tgt_node, Direction::Incoming)) + .unwrap(); for (src_node, _) in hugr.linked_ports(tgt_node, port) { let m_src = self .extensions .get(&(src_node, Direction::Outgoing)) .unwrap(); - let m_tgt = self - .extensions - .get(&(tgt_node, Direction::Incoming)) - .unwrap(); - self.add_constraint(*m_src, Constraint::Equal(*m_tgt)); + self.add_constraint(*m_src, Constraint::Equal(m_tgt)); } } } @@ -720,11 +726,11 @@ mod test { let root_node = NodeType::open_extensions(op); let mut hugr = Hugr::new(root_node); - let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT])); - let output = NodeType::open_extensions(ops::Output::new(type_row![NAT])); + let input = ops::Input::new(type_row![NAT, NAT]); + let output = ops::Output::new(type_row![NAT]); - let input = hugr.add_node_with_parent(hugr.root(), input)?; - let output = hugr.add_node_with_parent(hugr.root(), output)?; + let input = hugr.add_op_with_parent(hugr.root(), input)?; + let output = hugr.add_op_with_parent(hugr.root(), output)?; assert_matches!(hugr.get_io(hugr.root()), Some(_)); @@ -740,29 +746,29 @@ mod test { let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]) .with_extension_delta(&ExtensionSet::singleton(&C)); - let add_a = hugr.add_node_with_parent( + let add_a = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_a_sig, - }), + }, )?; - let add_b = hugr.add_node_with_parent( + let add_b = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_b_sig, - }), + }, )?; - let add_ab = hugr.add_node_with_parent( + let add_ab = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_ab_sig, - }), + }, )?; - let mult_c = hugr.add_node_with_parent( + let mult_c = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: mult_c_sig, - }), + }, )?; hugr.connect(input, 0, add_a, 0)?; @@ -896,29 +902,26 @@ mod test { let [input, output] = hugr.get_io(hugr.root()).unwrap(); let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - let add_r = hugr.add_node_with_parent( + let add_r = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: add_r_sig, - }), + }, )?; // Dangling thingy let src_sig = FunctionType::new(type_row![], type_row![NAT]) .with_extension_delta(&ExtensionSet::new()); - let src = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::DFG { signature: src_sig }), - )?; + let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?; let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]); // Mult has open extension requirements, which we should solve to be "R" - let mult = hugr.add_node_with_parent( + let mult = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::DFG { + ops::DFG { signature: mult_sig, - }), + }, )?; hugr.connect(input, 0, add_r, 0)?; @@ -978,18 +981,18 @@ mod test { ) -> Result<[Node; 3], Box> { let op: OpType = op.into(); - let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?; - let input = hugr.add_node_with_parent( + let node = hugr.add_op_with_parent(parent, op)?; + let input = hugr.add_op_with_parent( node, - NodeType::open_extensions(ops::Input { + ops::Input { types: op_sig.input, - }), + }, )?; - let output = hugr.add_node_with_parent( + let output = hugr.add_op_with_parent( node, - NodeType::open_extensions(ops::Output { + ops::Output { types: op_sig.output, - }), + }, )?; Ok([node, input, output]) } @@ -1010,20 +1013,20 @@ mod test { Into::::into(op).signature(), )?; - let lift1 = hugr.add_node_with_parent( + let lift1 = hugr.add_op_with_parent( case, - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: first_ext, - }), + }, )?; - let lift2 = hugr.add_node_with_parent( + let lift2 = hugr.add_op_with_parent( case, - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: second_ext, - }), + }, )?; hugr.connect(case_in, 0, lift1, 0)?; @@ -1088,17 +1091,17 @@ mod test { })); let root = hugr.root(); - let input = hugr.add_node_with_parent( + let input = hugr.add_op_with_parent( root, - NodeType::open_extensions(ops::Input { + ops::Input { types: type_row![NAT], - }), + }, )?; - let output = hugr.add_node_with_parent( + let output = hugr.add_op_with_parent( root, - NodeType::open_extensions(ops::Output { + ops::Output { types: type_row![NAT], - }), + }, )?; // Make identical dataflow nodes which add extension requirement "A" or "B" @@ -1119,12 +1122,12 @@ mod test { .unwrap(); let lift = hugr - .add_node_with_parent( + .add_op_with_parent( node, - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![NAT], new_extension: ext, - }), + }, ) .unwrap(); @@ -1171,7 +1174,7 @@ mod test { let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?; - let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?; + let dfg = hugr.add_op_with_parent(bb, op)?; hugr.connect(bb_in, 0, dfg, 0)?; hugr.connect(dfg, 0, bb_out, 0)?; @@ -1203,23 +1206,20 @@ mod test { extension_delta: entry_extensions, }; - let exit = hugr.add_node_with_parent( + let exit = hugr.add_op_with_parent( root, - NodeType::open_extensions(ops::BasicBlock::Exit { + ops::BasicBlock::Exit { cfg_outputs: exit_types.into(), - }), + }, )?; - let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; - let entry_in = hugr.add_node_with_parent( + let entry = hugr.add_op_before(exit, dfb)?; + let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?; + let entry_out = hugr.add_op_with_parent( entry, - NodeType::open_extensions(ops::Input { types: inputs }), - )?; - let entry_out = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Output { + ops::Output { types: vec![entry_tuple_sum].into(), - }), + }, )?; Ok(([entry, entry_in, entry_out], exit)) @@ -1270,12 +1270,12 @@ mod test { type_row![NAT], )?; - let mkpred = hugr.add_node_with_parent( + let mkpred = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( + make_opaque( A, FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a), - )), + ), )?; // Internal wiring for DFGs @@ -1366,12 +1366,9 @@ mod test { type_row![NAT], )?; - let entry_mid = hugr.add_node_with_parent( + let entry_mid = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( - UNKNOWN_EXTENSION, - FunctionType::new(vec![NAT], twoway(NAT)), - )), + make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))), )?; hugr.connect(entry_in, 0, entry_mid, 0)?; @@ -1455,12 +1452,12 @@ mod test { type_row![NAT], )?; - let entry_dfg = hugr.add_node_with_parent( + let entry_dfg = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( + make_opaque( UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), - )), + ), )?; hugr.connect(entry_in, 0, entry_dfg, 0)?; @@ -1536,12 +1533,9 @@ mod test { type_row![NAT], )?; - let entry_mid = hugr.add_node_with_parent( + let entry_mid = hugr.add_op_with_parent( entry, - NodeType::open_extensions(make_opaque( - UNKNOWN_EXTENSION, - FunctionType::new(vec![NAT], oneway(NAT)), - )), + make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))), )?; hugr.connect(entry_in, 0, entry_mid, 0)?; diff --git a/src/hugr.rs b/src/hugr.rs index a11099f55..d6dcd5ec6 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -239,8 +239,7 @@ impl Hugr { /// Add a node to the graph, with the default conversion from OpType to NodeType pub(crate) fn add_op(&mut self, op: impl Into) -> Node { - // TODO: Default to `NodeType::open_extensions` once we can infer extensions - self.add_node(NodeType::pure(op)) + self.add_node(NodeType::open_extensions(op)) } /// Add a node to the graph. @@ -356,7 +355,7 @@ impl From for PyErr { #[cfg(test)] mod test { - use super::{Hugr, HugrView, NodeType}; + use super::{Hugr, HugrView}; use crate::builder::test::closed_dfg_root_hugr; use crate::extension::ExtensionSet; use crate::hugr::HugrMut; @@ -392,12 +391,12 @@ mod test { FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), ); let [input, output] = hugr.get_io(hugr.root()).unwrap(); - let lift = hugr.add_node_with_parent( + let lift = hugr.add_op_with_parent( hugr.root(), - NodeType::open_extensions(ops::LeafOp::Lift { + ops::LeafOp::Lift { type_row: type_row![BIT], new_extension: "R".try_into().unwrap(), - }), + }, )?; hugr.connect(input, 0, lift, 0)?; hugr.connect(lift, 0, output, 0)?; diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index ac8bb53ce..fca006b5d 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -37,8 +37,7 @@ pub trait HugrMut: HugrMutInternals { parent: Node, op: impl Into, ) -> Result { - // TODO: Default to `NodeType::open_extensions` once we can infer extensions - self.add_node_with_parent(parent, NodeType::pure(op)) + self.add_node_with_parent(parent, NodeType::open_extensions(op)) } /// Add a node to the graph with a parent in the hierarchy. @@ -64,9 +63,9 @@ pub trait HugrMut: HugrMutInternals { self.hugr_mut().add_op_before(sibling, op) } - /// A generalisation of [`HugrMut::add_op_before`], needed temporarily until - /// add_op type methods all default to creating nodes with open extensions. - /// See issue #424 + /// Add a node to the graph as the previous sibling of another node. + /// + /// The sibling node's parent becomes the new node's parent. #[inline] fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result { self.valid_non_root(sibling)?; @@ -218,7 +217,7 @@ impl + AsMut> HugrMut for T { } fn add_op_before(&mut self, sibling: Node, op: impl Into) -> Result { - self.add_node_before(sibling, NodeType::pure(op)) + self.add_node_before(sibling, NodeType::open_extensions(op)) } fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result { @@ -601,16 +600,15 @@ mod test { #[test] fn simple_function() { - // Starts an empty builder - let mut builder = Hugr::default(); + let mut hugr = Hugr::default(); // Create the root module definition - let module: Node = builder.root(); + let module: Node = hugr.root(); // Start a main function with two nat inputs. // // `add_op` is equivalent to `add_root_op` followed by `set_parent` - let f: Node = builder + let f: Node = hugr .add_op_with_parent( module, ops::FuncDefn { @@ -621,22 +619,21 @@ mod test { .expect("Failed to add function definition node"); { - let f_in = builder - .add_op_with_parent(f, ops::Input::new(type_row![NAT])) + let f_in = hugr + .add_node_with_parent(f, NodeType::pure(ops::Input::new(type_row![NAT]))) .unwrap(); - let f_out = builder + let f_out = hugr .add_op_with_parent(f, ops::Output::new(type_row![NAT, NAT])) .unwrap(); - let noop = builder + let noop = hugr .add_op_with_parent(f, LeafOp::Noop { ty: NAT }) .unwrap(); - assert!(builder.connect(f_in, 0, noop, 0).is_ok()); - assert!(builder.connect(noop, 0, f_out, 0).is_ok()); - assert!(builder.connect(noop, 0, f_out, 1).is_ok()); + hugr.connect(f_in, 0, noop, 0).unwrap(); + hugr.connect(noop, 0, f_out, 0).unwrap(); + hugr.connect(noop, 0, f_out, 1).unwrap(); } - // Finish the construction and create the HUGR - builder.validate(&PRELUDE_REGISTRY).unwrap(); + hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); } } diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 589a85643..179ea486d 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -304,7 +304,7 @@ mod test { let (mut h, head, tail) = build_conditional_in_loop_cfg(false).unwrap(); h.update_validate(&PRELUDE_REGISTRY).unwrap(); do_outline_cfg_test(&mut h, head, tail, 1); - h.validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); } fn do_outline_cfg_test( @@ -406,7 +406,7 @@ mod test { let (new_block, new_cfg) = h .apply_rewrite(OutlineCfg::new(blocks_to_move.iter().copied())) .unwrap(); - h.validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); assert_eq!(new_block, h.children(h.root()).next().unwrap()); assert_matches!( h.get_optype(new_block), diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 1b56f874c..be8d5062d 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -886,13 +886,13 @@ mod test { let mut b = Hugr::new(NodeType::pure(dfg_op)); let root = b.root(); add_df_children(&mut b, root, 1); - assert_eq!(b.validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); } #[test] fn simple_hugr() { - let b = make_simple_hugr(2).0; - assert_eq!(b.validate(&EMPTY_REG), Ok(())); + let mut b = make_simple_hugr(2).0; + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); } #[test] @@ -919,7 +919,7 @@ mod test { ) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.update_validate(&EMPTY_REG), Err(ValidationError::ContainerWithoutChildren { node, .. }) => assert_eq!(node, new_def) ); @@ -927,9 +927,10 @@ mod test { add_df_children(&mut b, new_def, 2); b.set_parent(new_def, copy).unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.update_validate(&EMPTY_REG), Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) ); + let closure = b.infer_extensions().unwrap(); b.set_parent(new_def, root).unwrap(); // After moving the previous definition to a valid place, @@ -938,7 +939,7 @@ mod test { .add_op_with_parent(root, ops::Input::new(type_row![])) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate_with_extension_closure(closure, &EMPTY_REG), Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} ); } @@ -999,7 +1000,11 @@ mod test { .map_into() .collect_tuple() .unwrap(); - + // Write Extension annotations into the Hugr while it's still well-formed + // enough for us to compute them + let closure = b.infer_extensions().unwrap(); + b.instantiate_extensions(closure); + b.validate(&EMPTY_REG).unwrap(); b.replace_op( copy, NodeType::pure(ops::CFG { @@ -1035,7 +1040,7 @@ mod test { ) .unwrap(); b.add_other_edge(block, exit).unwrap(); - assert_eq!(b.validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); // Test malformed errors @@ -1377,7 +1382,7 @@ mod test { } #[test] fn unregistered_extension() { - let (h, def) = identity_hugr_with_type(USIZE_T); + let (mut h, def) = identity_hugr_with_type(USIZE_T); assert_eq!( h.validate(&EMPTY_REG), Err(ValidationError::SignatureError { @@ -1385,7 +1390,7 @@ mod test { cause: SignatureError::ExtensionNotFound(PRELUDE.name.clone()) }) ); - h.validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); } #[test] @@ -1416,7 +1421,9 @@ mod test { TypeBound::Any, )); assert_eq!( - identity_hugr_with_type(valid.clone()).0.validate(®), + identity_hugr_with_type(valid.clone()) + .0 + .update_validate(®), Ok(()) );