Skip to content

Commit fa39c37

Browse files
ss2165aborgna-q
andauthored
Allow any op as root (#112)
* allow any op as module root * fix name of DFG * new top level HugrBuilder to allow for building non-module root HUGR for now only adds DFG builder at root level, other containers to come * rename to module_hugr_builder and dfg_hugr_builder * implement empty drop for builders to force call of finish * make finish return result uniformly * fix comment * check root has no edges * don't require copy is not child of root * use AsMut and AsRef to make builders generic over base * clean up drop macro * WIP owned builder specializations * remove Drop implmentations for builders * rename `finish()` to `finish_container()` * finish implement owned HUGR building for Module and DFG * CFG HugrBuilder * allow constants in any Container * implement HugrBuilder for CFGBuilder * separate further owned and borrowed builders * owned TailLoop builder * add owned conditional builder * owned FunctionBuilder * owned BlockBuilder * owned CaseBuilder * fix doc link * update FuncBuilder docstring * Apply suggestions from code review Co-authored-by: Agustín Borgna <[email protected]> * remove explicit lifetimes * take some clippy pedantic suggestions mostly docs * clean up const + dag validation --------- Co-authored-by: Agustín Borgna <[email protected]>
1 parent 8fefbad commit fa39c37

13 files changed

+772
-414
lines changed

src/algorithm/nest_cfgs.rs

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,10 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
400400
#[cfg(test)]
401401
pub(crate) mod test {
402402
use super::*;
403-
use crate::builder::{BuildError, CFGBuilder, Container, Dataflow, ModuleBuilder};
403+
use crate::builder::{
404+
BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder, HugrMutRef,
405+
ModuleBuilder, SubContainer,
406+
};
404407
use crate::ops::{
405408
handle::{BasicBlockID, CfgID, ConstID, NodeHandle},
406409
ConstValue,
@@ -430,8 +433,8 @@ pub(crate) mod test {
430433
// \-> right -/ \-<--<-/
431434
let mut module_builder = ModuleBuilder::new();
432435
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
433-
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
434-
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
436+
let pred_const = module_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
437+
let const_unit = module_builder.add_constant(ConstValue::simple_unary_predicate())?;
435438

436439
let mut func_builder = module_builder.define_function(&main)?;
437440
let [int] = func_builder.input_wires_arr();
@@ -448,11 +451,10 @@ pub(crate) mod test {
448451
cfg_builder.branch(&merge, 0, &head)?;
449452
let exit = cfg_builder.exit_block();
450453
cfg_builder.branch(&tail, 0, &exit)?;
451-
let cfg_id = cfg_builder.finish();
454+
let cfg_id = cfg_builder.finish_sub_container()?;
452455

453456
func_builder.finish_with_outputs(cfg_id.outputs())?;
454-
455-
let h = module_builder.finish()?;
457+
let h = module_builder.finish_hugr()?;
456458

457459
let (entry, exit) = (entry.node(), exit.node());
458460
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
@@ -483,8 +485,8 @@ pub(crate) mod test {
483485
// the conditional and the loop to indicate the boundary, so we cannot separate them.
484486
let mut module_builder = ModuleBuilder::new();
485487
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
486-
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
487-
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
488+
let pred_const = module_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
489+
let const_unit = module_builder.add_constant(ConstValue::simple_unary_predicate())?;
488490

489491
let mut func_builder = module_builder.define_function(&main)?;
490492
let [int] = func_builder.input_wires_arr();
@@ -499,11 +501,11 @@ pub(crate) mod test {
499501
cfg_builder.branch(&merge, 0, &tail)?; // trivial "loop body"
500502
let exit = cfg_builder.exit_block();
501503
cfg_builder.branch(&tail, 0, &exit)?;
502-
let cfg_id = cfg_builder.finish();
504+
let cfg_id = cfg_builder.finish_sub_container()?;
503505

504506
func_builder.finish_with_outputs(cfg_id.outputs())?;
505507

506-
let h = module_builder.finish()?;
508+
let h = module_builder.finish_hugr()?;
507509

508510
let (entry, exit) = (entry.node(), exit.node());
509511
let (merge, tail) = (merge.node(), tail.node());
@@ -602,7 +604,7 @@ pub(crate) mod test {
602604
Ok(())
603605
}
604606

605-
fn n_identity<T: Dataflow>(
607+
fn n_identity<T: DataflowSubContainer>(
606608
mut dataflow_builder: T,
607609
pred_const: &ConstID,
608610
) -> Result<T::ContainerHandle, BuildError> {
@@ -611,8 +613,8 @@ pub(crate) mod test {
611613
dataflow_builder.finish_with_outputs([u].into_iter().chain(w))
612614
}
613615

614-
fn build_if_then_else_merge(
615-
cfg: &mut CFGBuilder,
616+
fn build_if_then_else_merge<T: HugrMutRef>(
617+
cfg: &mut CFGBuilder<T>,
616618
const_pred: &ConstID,
617619
unit_const: &ConstID,
618620
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
@@ -624,8 +626,8 @@ pub(crate) mod test {
624626
Ok((split, merge))
625627
}
626628

627-
fn build_then_else_merge_from_if(
628-
cfg: &mut CFGBuilder,
629+
fn build_then_else_merge_from_if<T: HugrMutRef>(
630+
cfg: &mut CFGBuilder<T>,
629631
unit_const: &ConstID,
630632
split: BasicBlockID,
631633
) -> Result<BasicBlockID, BuildError> {
@@ -649,8 +651,8 @@ pub(crate) mod test {
649651
}
650652

651653
// Returns loop tail - caller must link header to tail, and provide 0th successor of tail
652-
fn build_loop_from_header(
653-
cfg: &mut CFGBuilder,
654+
fn build_loop_from_header<T: HugrMutRef>(
655+
cfg: &mut CFGBuilder<T>,
654656
const_pred: &ConstID,
655657
header: BasicBlockID,
656658
) -> Result<BasicBlockID, BuildError> {
@@ -663,8 +665,8 @@ pub(crate) mod test {
663665
}
664666

665667
// Result is header and tail. Caller must provide 0th successor of header (linking to tail), and 0th successor of tail.
666-
fn build_loop(
667-
cfg: &mut CFGBuilder,
668+
fn build_loop<T: HugrMutRef>(
669+
cfg: &mut CFGBuilder<T>,
668670
const_pred: &ConstID,
669671
unit_const: &ConstID,
670672
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
@@ -684,8 +686,8 @@ pub(crate) mod test {
684686

685687
let mut module_builder = ModuleBuilder::new();
686688
let main = module_builder.declare("main", Signature::new_df(vec![NAT], type_row![NAT]))?;
687-
let pred_const = module_builder.constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
688-
let const_unit = module_builder.constant(ConstValue::simple_unary_predicate())?;
689+
let pred_const = module_builder.add_constant(ConstValue::simple_predicate(0, 2))?; // Nothing here cares which
690+
let const_unit = module_builder.add_constant(ConstValue::simple_unary_predicate())?;
689691

690692
let mut func_builder = module_builder.define_function(&main)?;
691693
let [int] = func_builder.input_wires_arr();
@@ -713,11 +715,11 @@ pub(crate) mod test {
713715
cfg_builder.branch(&entry, 0, &head)?;
714716
cfg_builder.branch(&tail, 0, &exit)?;
715717

716-
let cfg_id = cfg_builder.finish();
718+
let cfg_id = cfg_builder.finish_sub_container()?;
717719

718720
func_builder.finish_with_outputs(cfg_id.outputs())?;
719721

720-
let h = module_builder.finish()?;
722+
let h = module_builder.finish_hugr()?;
721723

722724
Ok((h, *cfg_id.handle(), head, tail))
723725
}

src/builder.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
//!
33
use thiserror::Error;
44

5-
use crate::hugr::{HugrError, Node, ValidationError, Wire};
5+
use crate::hugr::{HugrError, HugrMut, Node, ValidationError, Wire};
66
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
7+
78
use crate::types::LinearType;
89

910
pub mod handle;
1011
pub use handle::BuildHandle;
1112

1213
mod build_traits;
13-
pub use build_traits::{Container, Dataflow};
14+
pub use build_traits::{
15+
Container, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, SubContainer,
16+
};
1417

1518
mod dataflow;
1619
pub use dataflow::{DFGBuilder, DFGWrapper, FunctionBuilder};
@@ -67,22 +70,40 @@ pub enum BuildError {
6770
CircuitError(#[from] circuit_builder::CircuitBuildError),
6871
}
6972

73+
impl AsMut<HugrMut> for HugrMut {
74+
fn as_mut(&mut self) -> &mut HugrMut {
75+
self
76+
}
77+
}
78+
impl AsRef<HugrMut> for HugrMut {
79+
fn as_ref(&self) -> &HugrMut {
80+
self
81+
}
82+
}
83+
84+
/// Trait allowing treating type as (im)mutable reference to [`HugrMut`]
85+
pub trait HugrMutRef: AsMut<HugrMut> + AsRef<HugrMut> {}
86+
impl HugrMutRef for HugrMut {}
87+
impl HugrMutRef for &mut HugrMut {}
88+
7089
#[cfg(test)]
7190
mod test {
7291

92+
use crate::hugr::HugrMut;
7393
use crate::types::{ClassicType, LinearType, Signature, SimpleType};
74-
use crate::{builder::ModuleBuilder, Hugr};
94+
use crate::Hugr;
7595

7696
use super::handle::BuildHandle;
77-
use super::{BuildError, Container, Dataflow, FuncID, FunctionBuilder};
97+
use super::{BuildError, FuncID, FunctionBuilder, ModuleBuilder};
98+
use super::{DataflowSubContainer, HugrBuilder};
7899

79100
pub(super) const NAT: SimpleType = SimpleType::Classic(ClassicType::i64());
80101
pub(super) const F64: SimpleType = SimpleType::Classic(ClassicType::F64);
81102
pub(super) const BIT: SimpleType = SimpleType::Classic(ClassicType::bit());
82103
pub(super) const QB: SimpleType = SimpleType::Linear(LinearType::Qubit);
83104

84105
/// Wire up inputs of a Dataflow container to the outputs.
85-
pub(super) fn n_identity<T: Dataflow>(
106+
pub(super) fn n_identity<T: DataflowSubContainer>(
86107
dataflow_builder: T,
87108
) -> Result<T::ContainerHandle, BuildError> {
88109
let w = dataflow_builder.input_wires();
@@ -91,13 +112,12 @@ mod test {
91112

92113
pub(super) fn build_main(
93114
signature: Signature,
94-
f: impl FnOnce(FunctionBuilder<true>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
115+
f: impl FnOnce(FunctionBuilder<&mut HugrMut>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
95116
) -> Result<Hugr, BuildError> {
96117
let mut module_builder = ModuleBuilder::new();
97118
let f_builder = module_builder.declare_and_def("main", signature)?;
98119

99120
f(f_builder)?;
100-
101-
module_builder.finish()
121+
Ok(module_builder.finish_hugr()?)
102122
}
103123
}

0 commit comments

Comments
 (0)