Skip to content

Commit

Permalink
feat: Builder and HugrMut add_op_xxx default to open extensions (#622)
Browse files Browse the repository at this point in the history
A lot of this is refactoring existing code that explicitly uses
`add_node_xxx(NodeType::open_extensions(...))` to `add_op_xxx(...)`, and
updating tests - in most cases just `validate` -> `update_validate`.

Algorithmically, there was only one change - in `infer.rs` - where
Functions, FuncDefs and Aliases are given the empty ExtensionSet if they
are open (otherwise these end up unsolved).

Also there are some gymnastics in a couple of the tests where we want
the correct error out of validation but can't do inference because the
Hugr is, well, invalid....so we use a couple of different techniques to
carry solutions from earlier, valid, Hugrs over to the invalid one.
 
closes #424
  • Loading branch information
acl-cqc authored Oct 31, 2023
1 parent 92b936e commit 5565027
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 130 deletions.
4 changes: 2 additions & 2 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ pub(crate) mod test {
])
);
transform_cfg_to_nested(&mut IdentityCfgMap::new(rc));
h.validate(&PRELUDE_REGISTRY).unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(1, depth(&h, exit));
for n in [split, left, right, merge, head, tail] {
Expand Down Expand Up @@ -753,7 +753,7 @@ pub(crate) mod test {
let root = h.root();
let m = SiblingMut::<CfgID>::try_new(&mut h, root).unwrap();
transform_cfg_to_nested(&mut IdentityCfgMap::new(m));
h.validate(&PRELUDE_REGISTRY).unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(3, depth(&h, head));
for n in [split, left, right, merge] {
Expand Down
12 changes: 6 additions & 6 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,18 @@ pub(crate) mod test {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Input {
ops::Input {
types: signature.input,
}),
},
)
.unwrap();
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
ops::Output {
types: signature.output,
}),
},
)
.unwrap();
hugr
Expand Down
5 changes: 2 additions & 3 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
let case_node =
// add case before any existing subsequent cases
if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
// TODO: Allow this to be non-pure
self.hugr_mut().add_node_before(sibling_node, NodeType::open_extensions(case_op))?
self.hugr_mut().add_op_before(sibling_node, case_op)?
} else {
self.add_child_node(NodeType::open_extensions(case_op))?
self.add_child_op(case_op)?
};

self.case_nodes[case] = Some(case_node);
Expand Down
156 changes: 75 additions & 81 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ impl UnificationContext {
m_output,
node_type.op_signature().extension_reqs,
);
if matches!(
node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
) {
self.add_solution(m_input, ExtensionSet::new());
}
}
// We have a solution for everything!
Some(sig) => {
Expand All @@ -337,16 +343,16 @@ impl UnificationContext {
| Some(EdgeKind::ControlFlow)
)
}) {
let m_tgt = *self
.extensions
.get(&(tgt_node, Direction::Incoming))
.unwrap();
for (src_node, _) in hugr.linked_ports(tgt_node, port) {
let m_src = self
.extensions
.get(&(src_node, Direction::Outgoing))
.unwrap();
let m_tgt = self
.extensions
.get(&(tgt_node, Direction::Incoming))
.unwrap();
self.add_constraint(*m_src, Constraint::Equal(*m_tgt));
self.add_constraint(*m_src, Constraint::Equal(m_tgt));
}
}
}
Expand Down Expand Up @@ -720,11 +726,11 @@ mod test {
let root_node = NodeType::open_extensions(op);
let mut hugr = Hugr::new(root_node);

let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT]));
let output = NodeType::open_extensions(ops::Output::new(type_row![NAT]));
let input = ops::Input::new(type_row![NAT, NAT]);
let output = ops::Output::new(type_row![NAT]);

let input = hugr.add_node_with_parent(hugr.root(), input)?;
let output = hugr.add_node_with_parent(hugr.root(), output)?;
let input = hugr.add_op_with_parent(hugr.root(), input)?;
let output = hugr.add_op_with_parent(hugr.root(), output)?;

assert_matches!(hugr.get_io(hugr.root()), Some(_));

Expand All @@ -740,29 +746,29 @@ mod test {
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&C));

let add_a = hugr.add_node_with_parent(
let add_a = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_a_sig,
}),
},
)?;
let add_b = hugr.add_node_with_parent(
let add_b = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_b_sig,
}),
},
)?;
let add_ab = hugr.add_node_with_parent(
let add_ab = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_ab_sig,
}),
},
)?;
let mult_c = hugr.add_node_with_parent(
let mult_c = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_c_sig,
}),
},
)?;

hugr.connect(input, 0, add_a, 0)?;
Expand Down Expand Up @@ -896,29 +902,26 @@ mod test {
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);

let add_r = hugr.add_node_with_parent(
let add_r = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_r_sig,
}),
},
)?;

// Dangling thingy
let src_sig = FunctionType::new(type_row![], type_row![NAT])
.with_extension_delta(&ExtensionSet::new());

let src = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG { signature: src_sig }),
)?;
let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?;

let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]);
// Mult has open extension requirements, which we should solve to be "R"
let mult = hugr.add_node_with_parent(
let mult = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_sig,
}),
},
)?;

hugr.connect(input, 0, add_r, 0)?;
Expand Down Expand Up @@ -978,18 +981,18 @@ mod test {
) -> Result<[Node; 3], Box<dyn Error>> {
let op: OpType = op.into();

let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
let input = hugr.add_node_with_parent(
let node = hugr.add_op_with_parent(parent, op)?;
let input = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Input {
ops::Input {
types: op_sig.input,
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Output {
ops::Output {
types: op_sig.output,
}),
},
)?;
Ok([node, input, output])
}
Expand All @@ -1010,20 +1013,20 @@ mod test {
Into::<OpType>::into(op).signature(),
)?;

let lift1 = hugr.add_node_with_parent(
let lift1 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: first_ext,
}),
},
)?;

let lift2 = hugr.add_node_with_parent(
let lift2 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: second_ext,
}),
},
)?;

hugr.connect(case_in, 0, lift1, 0)?;
Expand Down Expand Up @@ -1088,17 +1091,17 @@ mod test {
}));

let root = hugr.root();
let input = hugr.add_node_with_parent(
let input = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Input {
ops::Input {
types: type_row![NAT],
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Output {
ops::Output {
types: type_row![NAT],
}),
},
)?;

// Make identical dataflow nodes which add extension requirement "A" or "B"
Expand All @@ -1119,12 +1122,12 @@ mod test {
.unwrap();

let lift = hugr
.add_node_with_parent(
.add_op_with_parent(
node,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: ext,
}),
},
)
.unwrap();

Expand Down Expand Up @@ -1171,7 +1174,7 @@ mod test {

let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?;

let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?;
let dfg = hugr.add_op_with_parent(bb, op)?;

hugr.connect(bb_in, 0, dfg, 0)?;
hugr.connect(dfg, 0, bb_out, 0)?;
Expand Down Expand Up @@ -1203,23 +1206,20 @@ mod test {
extension_delta: entry_extensions,
};

let exit = hugr.add_node_with_parent(
let exit = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::BasicBlock::Exit {
ops::BasicBlock::Exit {
cfg_outputs: exit_types.into(),
}),
},
)?;

let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?;
let entry_in = hugr.add_node_with_parent(
let entry = hugr.add_op_before(exit, dfb)?;
let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?;
let entry_out = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(ops::Input { types: inputs }),
)?;
let entry_out = hugr.add_node_with_parent(
entry,
NodeType::open_extensions(ops::Output {
ops::Output {
types: vec![entry_tuple_sum].into(),
}),
},
)?;

Ok(([entry, entry_in, entry_out], exit))
Expand Down Expand Up @@ -1270,12 +1270,12 @@ mod test {
type_row![NAT],
)?;

let mkpred = hugr.add_node_with_parent(
let mkpred = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
A,
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a),
)),
),
)?;

// Internal wiring for DFGs
Expand Down Expand Up @@ -1366,12 +1366,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], twoway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down Expand Up @@ -1455,12 +1452,12 @@ mod test {
type_row![NAT],
)?;

let entry_dfg = hugr.add_node_with_parent(
let entry_dfg = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext),
)),
),
)?;

hugr.connect(entry_in, 0, entry_dfg, 0)?;
Expand Down Expand Up @@ -1536,12 +1533,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down
Loading

0 comments on commit 5565027

Please sign in to comment.