Skip to content

Commit 5905dd3

Browse files
committed
new top level HugrBuilder
to allow for building non-module root HUGR for now only adds DFG builder at root level, other containers to come
1 parent 9019f43 commit 5905dd3

File tree

7 files changed

+134
-59
lines changed

7 files changed

+134
-59
lines changed

src/algorithm/nest_cfgs.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
400400
#[cfg(test)]
401401
pub(crate) mod test {
402402
use super::*;
403-
use crate::builder::{BuildError, CFGBuilder, Container, Dataflow, ModuleBuilder};
403+
use crate::builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder};
404404
use crate::ops::{
405405
handle::{BasicBlockID, CfgID, ConstID, NodeHandle},
406406
ConstValue,
@@ -428,7 +428,8 @@ pub(crate) mod test {
428428
// /-> left --\
429429
// entry -> split > merge -> head -> tail -> exit
430430
// \-> right -/ \-<--<-/
431-
let mut module_builder = ModuleBuilder::new();
431+
let mut builder = HugrBuilder::new();
432+
let mut module_builder = builder.module_builder();
432433
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
433434
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
434435
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -452,7 +453,7 @@ pub(crate) mod test {
452453

453454
func_builder.finish_with_outputs(cfg_id.outputs())?;
454455

455-
let h = module_builder.finish()?;
456+
let h = builder.finish()?;
456457

457458
let (entry, exit) = (entry.node(), exit.node());
458459
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
@@ -481,7 +482,8 @@ pub(crate) mod test {
481482
// \-> right -/ \-<--<-/
482483
// Here we would like two consecutive regions, but there is no *edge* between
483484
// the conditional and the loop to indicate the boundary, so we cannot separate them.
484-
let mut module_builder = ModuleBuilder::new();
485+
let mut builder = HugrBuilder::new();
486+
let mut module_builder = builder.module_builder();
485487
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
486488
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
487489
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -503,7 +505,7 @@ pub(crate) mod test {
503505

504506
func_builder.finish_with_outputs(cfg_id.outputs())?;
505507

506-
let h = module_builder.finish()?;
508+
let h = builder.finish()?;
507509

508510
let (entry, exit) = (entry.node(), exit.node());
509511
let (merge, tail) = (merge.node(), tail.node());
@@ -682,7 +684,8 @@ pub(crate) mod test {
682684
) -> Result<(Hugr, CfgID, BasicBlockID, BasicBlockID), BuildError> {
683685
//let sum2_type = SimpleType::new_predicate(2);
684686

685-
let mut module_builder = ModuleBuilder::new();
687+
let mut builder = HugrBuilder::new();
688+
let mut module_builder = builder.module_builder();
686689
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
687690
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
688691
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -717,7 +720,7 @@ pub(crate) mod test {
717720

718721
func_builder.finish_with_outputs(cfg_id.outputs())?;
719722

720-
let h = module_builder.finish()?;
723+
let h = builder.finish()?;
721724

722725
Ok((h, *cfg_id.handle(), head, tail))
723726
}

src/builder.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
//!
33
use thiserror::Error;
44

5-
use crate::hugr::{HugrError, Node, ValidationError, Wire};
5+
use crate::hugr::{HugrError, HugrMut, HugrView, Node, ValidationError, Wire};
66
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
7-
use crate::types::LinearType;
7+
use crate::ops::DataflowOp;
8+
use crate::types::{LinearType, Signature, TypeRow};
9+
use crate::Hugr;
810

911
pub mod handle;
1012
pub use handle::BuildHandle;
@@ -67,14 +69,60 @@ pub enum BuildError {
6769
CircuitError(#[from] circuit_builder::CircuitBuildError),
6870
}
6971

72+
#[derive(Default)]
73+
/// Base builder, can generate builders for containers
74+
pub struct HugrBuilder {
75+
base: HugrMut,
76+
}
77+
78+
impl HugrBuilder {
79+
/// Initialize a new builder
80+
pub fn new() -> Self {
81+
// initially assume to be a module root, will be replaced if not.
82+
Self {
83+
base: HugrMut::new_module(),
84+
}
85+
}
86+
87+
/// Use this builder to build a module HUGR
88+
pub fn module_builder(&mut self) -> ModuleBuilder {
89+
ModuleBuilder(&mut self.base)
90+
}
91+
92+
/// Use this builder to build a DFG HUGR
93+
pub fn root_dfg_builder(
94+
&mut self,
95+
input: impl Into<TypeRow>,
96+
output: impl Into<TypeRow>,
97+
) -> Result<DFGBuilder, BuildError> {
98+
let input = input.into();
99+
let output = output.into();
100+
let root = self.base.hugr().root();
101+
let dfg_op = DataflowOp::DFG {
102+
signature: Signature::new_df(input.clone(), output.clone()),
103+
};
104+
self.base.replace_op(root, dfg_op);
105+
106+
DFGBuilder::create_with_io(&mut self.base, root, input, output)
107+
}
108+
109+
// TODO: CFG, BasicBlock, Def, Conditional, TailLoop, Case
110+
111+
/// Complete building and return HUGR, performing validation.
112+
pub fn finish(self) -> Result<Hugr, BuildError> {
113+
Ok(self.base.finish()?)
114+
}
115+
}
116+
70117
#[cfg(test)]
71118
mod test {
72119

73120
use crate::types::{ClassicType, LinearType, Signature, SimpleType};
74-
use crate::{builder::ModuleBuilder, Hugr};
121+
use crate::Hugr;
75122

76123
use super::handle::BuildHandle;
77-
use super::{BuildError, Container, Dataflow, FuncID, FunctionBuilder};
124+
use super::HugrBuilder;
125+
use super::{BuildError, Dataflow, FuncID, FunctionBuilder};
78126

79127
pub(super) const NAT: SimpleType = SimpleType::Classic(ClassicType::i64());
80128
pub(super) const F64: SimpleType = SimpleType::Classic(ClassicType::F64);
@@ -93,11 +141,12 @@ mod test {
93141
signature: Signature,
94142
f: impl FnOnce(FunctionBuilder<true>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
95143
) -> Result<Hugr, BuildError> {
96-
let mut module_builder = ModuleBuilder::new();
144+
let mut builder = HugrBuilder::new();
145+
let mut module_builder = builder.module_builder();
97146
let f_builder = module_builder.declare_and_def("main", signature)?;
98147

99148
f(f_builder)?;
100149

101-
module_builder.finish()
150+
builder.finish()
102151
}
103152
}

src/builder/cfg.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,20 +184,17 @@ impl<'b> BlockBuilder<'b> {
184184

185185
#[cfg(test)]
186186
mod test {
187-
use crate::{
188-
builder::{module_builder::ModuleBuilder, test::NAT},
189-
ops::ConstValue,
190-
type_row,
191-
types::Signature,
192-
};
187+
use crate::builder::HugrBuilder;
188+
use crate::{builder::test::NAT, ops::ConstValue, type_row, types::Signature};
193189

194190
use super::*;
195191
#[test]
196192
fn basic_cfg() -> Result<(), BuildError> {
197193
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
198194

199195
let build_result = {
200-
let mut module_builder = ModuleBuilder::new();
196+
let mut builder = HugrBuilder::new();
197+
let mut module_builder = builder.module_builder();
201198
let main =
202199
module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
203200
let s1 = module_builder.constant(ConstValue::simple_unary_predicate())?;
@@ -237,7 +234,8 @@ mod test {
237234

238235
func_builder.finish_with_outputs(cfg_id.outputs())?
239236
};
240-
module_builder.finish()
237+
module_builder.finish()?;
238+
builder.finish()
241239
};
242240

243241
assert_eq!(build_result.err(), None);

src/builder/conditional.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ impl<'f> ConditionalBuilder<'f> {
135135
mod test {
136136
use cool_asserts::assert_matches;
137137

138+
use crate::builder::HugrBuilder;
138139
use crate::{
139140
builder::{
140-
module_builder::ModuleBuilder,
141141
test::{n_identity, NAT},
142142
Dataflow,
143143
},
@@ -150,7 +150,8 @@ mod test {
150150
#[test]
151151
fn basic_conditional() -> Result<(), BuildError> {
152152
let build_result = {
153-
let mut module_builder = ModuleBuilder::new();
153+
let mut builder = HugrBuilder::new();
154+
let mut module_builder = builder.module_builder();
154155
let main = module_builder
155156
.declare("main", Signature::new_df(type_row![NAT], type_row![NAT]))?;
156157
let tru_const = module_builder.constant(ConstValue::true_val())?;
@@ -177,7 +178,8 @@ mod test {
177178
let [int] = conditional_id.outputs_arr();
178179
fbuild.finish_with_outputs([int])?
179180
};
180-
module_builder.finish()
181+
module_builder.finish()?;
182+
builder.finish()
181183
};
182184

183185
assert_matches!(build_result, Ok(_));

src/builder/dataflow.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,10 @@ impl<'b, T: From<BuildHandle<DfgID>>> Dataflow for DFGWrapper<'b, T> {
135135
mod test {
136136
use cool_asserts::assert_matches;
137137

138+
use crate::builder::HugrBuilder;
139+
use crate::hugr::HugrView;
138140
use crate::{
139141
builder::{
140-
module_builder::ModuleBuilder,
141142
test::{n_identity, BIT, NAT, QB},
142143
BuildError,
143144
},
@@ -150,7 +151,8 @@ mod test {
150151
#[test]
151152
fn nested_identity() -> Result<(), BuildError> {
152153
let build_result = {
153-
let mut module_builder = ModuleBuilder::new();
154+
let mut builder = HugrBuilder::new();
155+
let mut module_builder = builder.module_builder();
154156

155157
let _f_id = {
156158
let mut func_builder = module_builder.declare_and_def(
@@ -170,7 +172,8 @@ mod test {
170172

171173
func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?
172174
};
173-
module_builder.finish()
175+
module_builder.finish()?;
176+
builder.finish()
174177
};
175178

176179
assert_eq!(build_result.err(), None);
@@ -184,7 +187,8 @@ mod test {
184187
F: FnOnce(FunctionBuilder<true>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
185188
{
186189
let build_result = {
187-
let mut module_builder = ModuleBuilder::new();
190+
let mut builder = HugrBuilder::new();
191+
let mut module_builder = builder.module_builder();
188192

189193
let f_build = module_builder.declare_and_def(
190194
"main",
@@ -193,7 +197,8 @@ mod test {
193197

194198
f(f_build)?;
195199

196-
module_builder.finish()
200+
module_builder.finish()?;
201+
builder.finish()
197202
};
198203
assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);
199204

@@ -234,15 +239,17 @@ mod test {
234239
#[test]
235240
fn copy_insertion_qubit() {
236241
let builder = || {
237-
let mut module_builder = ModuleBuilder::new();
242+
let mut builder = HugrBuilder::new();
243+
let mut module_builder = builder.module_builder();
238244

239245
let f_build = module_builder
240246
.declare_and_def("main", Signature::new_df(type_row![QB], type_row![QB, QB]))?;
241247

242248
let [q1] = f_build.input_wires_arr();
243249
f_build.finish_with_outputs([q1, q1])?;
244250

245-
module_builder.finish()
251+
module_builder.finish()?;
252+
builder.finish()
246253
};
247254

248255
assert_eq!(builder(), Err(BuildError::NoCopyLinear(LinearType::Qubit)));
@@ -251,7 +258,8 @@ mod test {
251258
#[test]
252259
fn simple_inter_graph_edge() {
253260
let builder = || {
254-
let mut module_builder = ModuleBuilder::new();
261+
let mut builder = HugrBuilder::new();
262+
let mut module_builder = builder.module_builder();
255263

256264
let mut f_build = module_builder
257265
.declare_and_def("main", Signature::new_df(type_row![BIT], type_row![BIT]))?;
@@ -268,9 +276,25 @@ mod test {
268276

269277
f_build.finish_with_outputs([nested.out_wire(0)])?;
270278

271-
module_builder.finish()
279+
module_builder.finish()?;
280+
builder.finish()
272281
};
273282

274283
assert_matches!(builder(), Ok(_));
275284
}
285+
286+
#[test]
287+
fn dfg_hugr() -> Result<(), BuildError> {
288+
let mut builder = HugrBuilder::new();
289+
290+
let dfg_builder = builder.root_dfg_builder(type_row![BIT], type_row![BIT])?;
291+
292+
n_identity(dfg_builder)?;
293+
294+
let hugr = builder.finish()?;
295+
assert_eq!(hugr.node_count(), 3);
296+
assert_matches!(hugr.root_type(), OpType::Dataflow(DataflowOp::DFG { .. }));
297+
298+
Ok(())
299+
}
276300
}

0 commit comments

Comments
 (0)