Skip to content

Commit f7e97e6

Browse files
committed
finish implement owned HUGR building
for Module and DFG
1 parent 01781ec commit f7e97e6

File tree

8 files changed

+87
-113
lines changed

8 files changed

+87
-113
lines changed

src/algorithm/nest_cfgs.rs

+9-13
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
400400
#[cfg(test)]
401401
pub(crate) mod test {
402402
use super::*;
403-
use crate::builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, HugrMutRef};
403+
use crate::builder::{
404+
BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, HugrMutRef, ModuleBuilder,
405+
};
404406
use crate::ops::{
405407
handle::{BasicBlockID, CfgID, ConstID, NodeHandle},
406408
ConstValue,
@@ -428,8 +430,7 @@ pub(crate) mod test {
428430
// /-> left --\
429431
// entry -> split > merge -> head -> tail -> exit
430432
// \-> right -/ \-<--<-/
431-
let mut builder = HugrBuilder::new();
432-
let mut module_builder = builder.module_hugr_builder();
433+
let mut module_builder = ModuleBuilder::new();
433434
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
434435
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
435436
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -452,8 +453,7 @@ pub(crate) mod test {
452453
let cfg_id = cfg_builder.finish_container()?;
453454

454455
func_builder.finish_with_outputs(cfg_id.outputs())?;
455-
module_builder.finish_container()?;
456-
let h = builder.finish()?;
456+
let h = module_builder.finish_hugr()?;
457457

458458
let (entry, exit) = (entry.node(), exit.node());
459459
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
@@ -482,8 +482,7 @@ pub(crate) mod test {
482482
// \-> right -/ \-<--<-/
483483
// Here we would like two consecutive regions, but there is no *edge* between
484484
// the conditional and the loop to indicate the boundary, so we cannot separate them.
485-
let mut builder = HugrBuilder::new();
486-
let mut module_builder = builder.module_hugr_builder();
485+
let mut module_builder = ModuleBuilder::new();
487486
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
488487
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
489488
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -504,9 +503,8 @@ pub(crate) mod test {
504503
let cfg_id = cfg_builder.finish_container()?;
505504

506505
func_builder.finish_with_outputs(cfg_id.outputs())?;
507-
module_builder.finish_container()?;
508506

509-
let h = builder.finish()?;
507+
let h = module_builder.finish_hugr()?;
510508

511509
let (entry, exit) = (entry.node(), exit.node());
512510
let (merge, tail) = (merge.node(), tail.node());
@@ -685,8 +683,7 @@ pub(crate) mod test {
685683
) -> Result<(Hugr, CfgID, BasicBlockID, BasicBlockID), BuildError> {
686684
//let sum2_type = SimpleType::new_predicate(2);
687685

688-
let mut builder = HugrBuilder::new();
689-
let mut module_builder = builder.module_hugr_builder();
686+
let mut module_builder = ModuleBuilder::new();
690687
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
691688
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
692689
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -720,9 +717,8 @@ pub(crate) mod test {
720717
let cfg_id = cfg_builder.finish_container()?;
721718

722719
func_builder.finish_with_outputs(cfg_id.outputs())?;
723-
module_builder.finish_container()?;
724720

725-
let h = builder.finish()?;
721+
let h = module_builder.finish_hugr()?;
726722

727723
Ok((h, *cfg_id.handle(), head, tail))
728724
}

src/builder.rs

+8-56
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
//!
33
use thiserror::Error;
44

5-
use crate::hugr::{HugrError, HugrMut, HugrView, Node, ValidationError, Wire};
5+
use crate::hugr::{HugrError, HugrMut, Node, ValidationError, Wire};
66
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
7-
use crate::ops::DataflowOp;
8-
use crate::types::{LinearType, Signature, TypeRow};
9-
use crate::Hugr;
7+
8+
use crate::types::LinearType;
109

1110
pub mod handle;
1211
pub use handle::BuildHandle;
1312

1413
mod build_traits;
15-
pub use build_traits::{Container, Dataflow};
14+
pub use build_traits::{Container, Dataflow, HugrBuilder};
1615

1716
mod dataflow;
1817
pub use dataflow::{DFGBuilder, DFGWrapper, FunctionBuilder};
@@ -69,51 +68,6 @@ pub enum BuildError {
6968
CircuitError(#[from] circuit_builder::CircuitBuildError),
7069
}
7170

72-
#[derive(Default)]
73-
/// Base builder, can generate builders for containers
74-
pub struct HugrBuilder {
75-
base: HugrMut,
76-
}
77-
78-
impl HugrBuilder {
79-
/// Initialize a new builder
80-
pub fn new() -> Self {
81-
// initially assume to be a module root, will be replaced if not.
82-
Self {
83-
base: HugrMut::new_module(),
84-
}
85-
}
86-
87-
/// Use this builder to build a module HUGR
88-
pub fn module_hugr_builder(&mut self) -> ModuleBuilder<&mut HugrMut> {
89-
ModuleBuilder(&mut self.base)
90-
}
91-
92-
/// Use this builder to build a DFG HUGR
93-
pub fn dfg_hugr_builder(
94-
&mut self,
95-
input: impl Into<TypeRow>,
96-
output: impl Into<TypeRow>,
97-
) -> Result<DFGBuilder<&mut HugrMut>, BuildError> {
98-
let input = input.into();
99-
let output = output.into();
100-
let root = self.base.hugr().root();
101-
let dfg_op = DataflowOp::DFG {
102-
signature: Signature::new_df(input.clone(), output.clone()),
103-
};
104-
self.base.replace_op(root, dfg_op);
105-
106-
DFGBuilder::create_with_io(&mut self.base, root, input, output)
107-
}
108-
109-
// TODO: CFG, BasicBlock, Def, Conditional, TailLoop, Case
110-
111-
/// Complete building and return HUGR, performing validation.
112-
pub fn finish(self) -> Result<Hugr, BuildError> {
113-
Ok(self.base.finish()?)
114-
}
115-
}
116-
11771
impl AsMut<HugrMut> for HugrMut {
11872
fn as_mut(&mut self) -> &mut HugrMut {
11973
self
@@ -138,8 +92,8 @@ mod test {
13892
use crate::Hugr;
13993

14094
use super::handle::BuildHandle;
141-
use super::{BuildError, Dataflow, FuncID, FunctionBuilder};
142-
use super::{Container, HugrBuilder};
95+
use super::HugrBuilder;
96+
use super::{BuildError, Dataflow, FuncID, FunctionBuilder, ModuleBuilder};
14397

14498
pub(super) const NAT: SimpleType = SimpleType::Classic(ClassicType::i64());
14599
pub(super) const F64: SimpleType = SimpleType::Classic(ClassicType::F64);
@@ -160,12 +114,10 @@ mod test {
160114
FunctionBuilder<&mut HugrMut, true>,
161115
) -> Result<BuildHandle<FuncID<true>>, BuildError>,
162116
) -> Result<Hugr, BuildError> {
163-
let mut builder = HugrBuilder::new();
164-
let mut module_builder = builder.module_hugr_builder();
117+
let mut module_builder = ModuleBuilder::new();
165118
let f_builder = module_builder.declare_and_def("main", signature)?;
166119

167120
f(f_builder)?;
168-
module_builder.finish_container()?;
169-
builder.finish()
121+
Ok(module_builder.finish_hugr()?)
170122
}
171123
}

src/builder/build_traits.rs

+22
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ pub trait Container {
6363
fn finish_container(self) -> Result<Self::ContainerHandle, BuildError>;
6464
}
6565

66+
/// Types implementing this trait can be used to build complete HUGRs
67+
/// (with varying root node types)
68+
pub trait HugrBuilder: Container {
69+
/// Finish building the HUGR, perform any validation checks and return it.
70+
fn finish_hugr(self) -> Result<Hugr, ValidationError>;
71+
}
72+
6673
/// Trait for building dataflow regions of a HUGR.
6774
pub trait Dataflow: Container {
6875
/// Return indices of input and output nodes.
@@ -655,3 +662,18 @@ fn if_copy_add_port(base: &mut HugrMut, src: Node) -> Option<usize> {
655662
None
656663
}
657664
}
665+
666+
pub trait DataflowHugrBuilder: HugrBuilder + Dataflow {
667+
fn finish_hugr_with_outputs(
668+
mut self,
669+
outputs: impl IntoIterator<Item = Wire>,
670+
) -> Result<Hugr, BuildError>
671+
where
672+
Self: Sized,
673+
{
674+
self.set_outputs(outputs)?;
675+
Ok(self.finish_hugr()?)
676+
}
677+
}
678+
679+
impl<T: HugrBuilder + Dataflow> DataflowHugrBuilder for T {}

src/builder/cfg.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ impl<B: HugrMutRef> BlockBuilder<B> {
184184

185185
#[cfg(test)]
186186
mod test {
187-
use crate::builder::HugrBuilder;
187+
use crate::builder::build_traits::HugrBuilder;
188+
use crate::builder::ModuleBuilder;
188189
use crate::{builder::test::NAT, ops::ConstValue, type_row, types::Signature};
189190

190191
use super::*;
@@ -193,8 +194,7 @@ mod test {
193194
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
194195

195196
let build_result = {
196-
let mut builder = HugrBuilder::new();
197-
let mut module_builder = builder.module_hugr_builder();
197+
let mut module_builder = ModuleBuilder::new();
198198
let main =
199199
module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
200200
let s1 = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -234,8 +234,7 @@ mod test {
234234

235235
func_builder.finish_with_outputs(cfg_id.outputs())?
236236
};
237-
module_builder.finish_container()?;
238-
builder.finish()
237+
module_builder.finish_hugr()
239238
};
240239

241240
assert_eq!(build_result.err(), None);

src/builder/conditional.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ impl<B: HugrMutRef> ConditionalBuilder<B> {
137137
mod test {
138138
use cool_asserts::assert_matches;
139139

140-
use crate::builder::HugrBuilder;
140+
use crate::builder::{HugrBuilder, ModuleBuilder};
141141
use crate::{
142142
builder::{
143143
test::{n_identity, NAT},
@@ -151,9 +151,8 @@ mod test {
151151

152152
#[test]
153153
fn basic_conditional() -> Result<(), BuildError> {
154-
let build_result = {
155-
let mut builder = HugrBuilder::new();
156-
let mut module_builder = builder.module_hugr_builder();
154+
let build_result: Result<Hugr, BuildError> = {
155+
let mut module_builder = ModuleBuilder::new();
157156
let main = module_builder
158157
.declare("main", Signature::new_df(type_row![NAT], type_row![NAT]))?;
159158
let tru_const = module_builder.constant(ConstValue::true_val())?;
@@ -180,8 +179,7 @@ mod test {
180179
let [int] = conditional_id.outputs_arr();
181180
fbuild.finish_with_outputs([int])?
182181
};
183-
module_builder.finish_container()?;
184-
builder.finish()
182+
Ok(module_builder.finish_hugr()?)
185183
};
186184

187185
assert_matches!(build_result, Ok(_));

src/builder/dataflow.rs

+19-21
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::build_traits::HugrBuilder;
12
use super::handle::BuildHandle;
23
use super::{BuildError, Container, Dataflow, DfgID, FuncID, HugrMutRef};
34

@@ -48,6 +49,7 @@ impl<T: HugrMutRef> DFGBuilder<T> {
4849
}
4950
}
5051
impl DFGBuilder<HugrMut> {
52+
/// Begin building a new DFG rooted HUGR.
5153
pub fn new(
5254
input: impl Into<TypeRow>,
5355
output: impl Into<TypeRow>,
@@ -61,6 +63,9 @@ impl DFGBuilder<HugrMut> {
6163
let root = base.hugr().root();
6264
DFGBuilder::create_with_io(base, root, input, output)
6365
}
66+
}
67+
68+
impl HugrBuilder for DFGBuilder<HugrMut> {
6469
fn finish_hugr(self) -> Result<Hugr, ValidationError> {
6570
self.base.finish()
6671
}
@@ -156,7 +161,8 @@ impl<B: HugrMutRef, T: From<BuildHandle<DfgID>>> Dataflow for DFGWrapper<B, T> {
156161
mod test {
157162
use cool_asserts::assert_matches;
158163

159-
use crate::builder::HugrBuilder;
164+
use crate::builder::build_traits::DataflowHugrBuilder;
165+
use crate::builder::ModuleBuilder;
160166
use crate::hugr::HugrView;
161167
use crate::{
162168
builder::{
@@ -172,8 +178,7 @@ mod test {
172178
#[test]
173179
fn nested_identity() -> Result<(), BuildError> {
174180
let build_result = {
175-
let mut builder = HugrBuilder::new();
176-
let mut module_builder = builder.module_hugr_builder();
181+
let mut module_builder = ModuleBuilder::new();
177182

178183
let _f_id = {
179184
let mut func_builder = module_builder.declare_and_def(
@@ -193,8 +198,7 @@ mod test {
193198

194199
func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?
195200
};
196-
module_builder.finish_container()?;
197-
builder.finish()
201+
module_builder.finish_hugr()
198202
};
199203

200204
assert_eq!(build_result.err(), None);
@@ -210,8 +214,7 @@ mod test {
210214
) -> Result<BuildHandle<FuncID<true>>, BuildError>,
211215
{
212216
let build_result = {
213-
let mut builder = HugrBuilder::new();
214-
let mut module_builder = builder.module_hugr_builder();
217+
let mut module_builder = ModuleBuilder::new();
215218

216219
let f_build = module_builder.declare_and_def(
217220
"main",
@@ -220,8 +223,7 @@ mod test {
220223

221224
f(f_build)?;
222225

223-
module_builder.finish_container()?;
224-
builder.finish()
226+
module_builder.finish_hugr()
225227
};
226228
assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);
227229

@@ -262,27 +264,24 @@ mod test {
262264
#[test]
263265
fn copy_insertion_qubit() {
264266
let builder = || {
265-
let mut builder = HugrBuilder::new();
266-
let mut module_builder = builder.module_hugr_builder();
267+
let mut module_builder = ModuleBuilder::new();
267268

268269
let f_build = module_builder
269270
.declare_and_def("main", Signature::new_df(type_row![QB], type_row![QB, QB]))?;
270271

271272
let [q1] = f_build.input_wires_arr();
272273
f_build.finish_with_outputs([q1, q1])?;
273274

274-
module_builder.finish_container()?;
275-
builder.finish()
275+
Ok(module_builder.finish_hugr()?)
276276
};
277277

278278
assert_eq!(builder(), Err(BuildError::NoCopyLinear(LinearType::Qubit)));
279279
}
280280

281281
#[test]
282282
fn simple_inter_graph_edge() {
283-
let builder = || {
284-
let mut builder = HugrBuilder::new();
285-
let mut module_builder = builder.module_hugr_builder();
283+
let builder = || -> Result<Hugr, BuildError> {
284+
let mut module_builder = ModuleBuilder::new();
286285

287286
let mut f_build = module_builder
288287
.declare_and_def("main", Signature::new_df(type_row![BIT], type_row![BIT]))?;
@@ -299,20 +298,19 @@ mod test {
299298

300299
f_build.finish_with_outputs([nested.out_wire(0)])?;
301300

302-
module_builder.finish_container()?;
303-
builder.finish()
301+
Ok(module_builder.finish_hugr()?)
304302
};
305303

306304
assert_matches!(builder(), Ok(_));
307305
}
308306

309307
#[test]
310308
fn dfg_hugr() -> Result<(), BuildError> {
311-
let mut dfg_builder = DFGBuilder::new(type_row![BIT], type_row![BIT])?;
309+
let dfg_builder = DFGBuilder::new(type_row![BIT], type_row![BIT])?;
312310

313-
dfg_builder.set_outputs(dfg_builder.input_wires())?;
311+
let [i1] = dfg_builder.input_wires_arr();
312+
let hugr = dfg_builder.finish_hugr_with_outputs([i1])?;
314313

315-
let hugr = dfg_builder.finish_hugr()?;
316314
assert_eq!(hugr.node_count(), 3);
317315
assert_matches!(hugr.root_type(), OpType::Dataflow(DataflowOp::DFG { .. }));
318316

0 commit comments

Comments
 (0)