Skip to content

Commit 0b608a5

Browse files
committed
WIP try to parametrize by bool
1 parent d6d9a4e commit 0b608a5

File tree

7 files changed

+262
-151
lines changed

7 files changed

+262
-151
lines changed

hugr/src/extension/op_def.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ impl OpDef {
408408
// The type scheme may contain row variables so be of variable length;
409409
// these will have to be substituted to fixed-length concrete types when
410410
// the OpDef is instantiated into an actual OpType.
411-
ts.poly_func.validate_var_len(exts)?;
411+
ts.poly_func.validate(exts)?;
412412
}
413413
Ok(())
414414
}

hugr/src/types.rs

Lines changed: 109 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pub type TypeName = SmolStr;
3838
pub type TypeNameRef = str;
3939

4040
/// The kinds of edges in a HUGR, excluding Hierarchy.
41-
#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
41+
#[derive(Clone, PartialEq, Debug, serde::Serialize, serde::Deserialize)]
4242
#[non_exhaustive]
4343
pub enum EdgeKind {
4444
/// Control edges of a CFG region.
@@ -130,7 +130,7 @@ pub enum SumType {
130130
Unit { size: u8 },
131131
/// General case of a Sum type.
132132
#[allow(missing_docs)]
133-
General { rows: Vec<TypeRow> },
133+
General { rows: Vec<TypeRow<true>> },
134134
}
135135

136136
impl std::fmt::Display for SumType {
@@ -152,7 +152,7 @@ impl SumType {
152152
/// Initialize a new sum type.
153153
pub fn new<V>(variants: impl IntoIterator<Item = V>) -> Self
154154
where
155-
V: Into<TypeRow>,
155+
V: Into<TypeRow<true>>,
156156
{
157157
let rows = variants.into_iter().map(Into::into).collect_vec();
158158

@@ -170,7 +170,7 @@ impl SumType {
170170
}
171171

172172
/// Report the tag'th variant, if it exists.
173-
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> {
173+
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow<true>> {
174174
match self {
175175
SumType::Unit { size } if tag < (*size as usize) => Some(Type::EMPTY_TYPEROW_REF),
176176
SumType::General { rows } => rows.get(tag),
@@ -187,8 +187,8 @@ impl SumType {
187187
}
188188
}
189189

190-
impl From<SumType> for Type {
191-
fn from(sum: SumType) -> Type {
190+
impl <const RV:bool> From<SumType> for Type<RV> {
191+
fn from(sum: SumType) -> Self {
192192
match sum {
193193
SumType::Unit { size } => Type::new_unit_sum(size),
194194
SumType::General { rows } => Type::new_sum(rows),
@@ -199,7 +199,7 @@ impl From<SumType> for Type {
199199
#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)]
200200
#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))]
201201
/// Core types
202-
pub enum TypeEnum {
202+
pub enum TypeEnum<const ROWVARS:bool=false> {
203203
// TODO optimise with Box<CustomType> ?
204204
// or some static version of this?
205205
#[allow(missing_docs)]
@@ -223,14 +223,29 @@ pub enum TypeEnum {
223223
#[display(fmt = "Variable({})", _0)]
224224
Variable(usize, TypeBound),
225225
/// Variable index, and cache of inner TypeBound - matches a [TypeParam::List] of [TypeParam::Type]
226-
/// of this bound (checked in validation)
226+
/// of this bound (checked in validation). Should only exist for `Type<true>` and `TypeEnum<true>`.
227227
#[display(fmt = "RowVar({})", _0)]
228228
RowVariable(usize, TypeBound),
229229
#[allow(missing_docs)]
230230
#[display(fmt = "{}", "_0")]
231231
Sum(#[cfg_attr(test, proptest(strategy = "any_with::<SumType>(params)"))] SumType),
232232
}
233-
impl TypeEnum {
233+
234+
/*impl <const RV:bool> PartialEq<TypeEnum> for TypeEnum<RV> {
235+
fn eq(&self, other: &TypeEnum) -> bool {
236+
match (self, other) {
237+
(TypeEnum::Extension(e1), TypeEnum::Extension(e2)) => e1 == e2,
238+
(TypeEnum::Alias(a1), TypeEnum::Alias(a2)) => a1 == a2,
239+
(TypeEnum::Function(f1), TypeEnum::Function(f2)) => f1==f2,
240+
(TypeEnum::Variable(i1, b1), TypeEnum::Variable(i2, b2)) => i1==i2 && b1==b2,
241+
(TypeEnum::RowVariable(i1, b1), TypeEnum::RowVariable(i2, b2)) => i1==i2 && b1==b2,
242+
(TypeEnum::Sum(s1), TypeEnum::Sum(s2)) => s1 == s2,
243+
_ => false
244+
}
245+
}
246+
}*/
247+
248+
impl <const RV:bool> TypeEnum<RV> {
234249
/// The smallest type bound that covers the whole type.
235250
fn least_upper_bound(&self) -> TypeBound {
236251
match self {
@@ -249,10 +264,10 @@ impl TypeEnum {
249264
}
250265

251266
#[derive(
252-
Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
267+
Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
253268
)]
254269
#[display(fmt = "{}", "_0")]
255-
#[serde(into = "serialize::SerSimpleType", from = "serialize::SerSimpleType")]
270+
#[serde(into = "serialize::SerSimpleType", try_from = "serialize::SerSimpleType")]
256271
/// A HUGR type - the valid types of [EdgeKind::Value] and [EdgeKind::Const] edges.
257272
/// Such an edge is valid if the ports on either end agree on the [Type].
258273
/// Types have an optional [TypeBound] which places limits on the valid
@@ -273,24 +288,30 @@ impl TypeEnum {
273288
/// let func_type = Type::new_function(FunctionType::new_endo(vec![]));
274289
/// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable);
275290
/// ```
276-
pub struct Type(TypeEnum, TypeBound);
291+
pub struct Type<const ROWVARS:bool=false>(TypeEnum<ROWVARS>, TypeBound);
277292

278-
impl Type {
293+
/*impl<const RV:bool> PartialEq<Type> for Type<RV> {
294+
fn eq(&self, other: &Type) -> bool {
295+
self.0 == other.0 && self.1 == other.1
296+
}
297+
}*/
298+
299+
impl<const RV:bool> Type<RV> {
279300
/// An empty `TypeRow`. Provided here for convenience
280-
pub const EMPTY_TYPEROW: TypeRow = type_row![];
301+
pub const EMPTY_TYPEROW: TypeRow<RV> = type_row![];
281302
/// Unit type (empty tuple).
282303
pub const UNIT: Self = Self(TypeEnum::Sum(SumType::Unit { size: 1 }), TypeBound::Eq);
283304

284-
const EMPTY_TYPEROW_REF: &'static TypeRow = &Self::EMPTY_TYPEROW;
305+
const EMPTY_TYPEROW_REF: &'static TypeRow<RV> = &Self::EMPTY_TYPEROW;
285306

286307
/// Initialize a new function type.
287-
pub fn new_function(fun_ty: impl Into<FunctionType>) -> Self {
308+
pub fn new_function(fun_ty: impl Into<FunctionType<true>>) -> Self {
288309
Self::new(TypeEnum::Function(Box::new(fun_ty.into())))
289310
}
290311

291312
/// Initialize a new tuple type by providing the elements.
292313
#[inline(always)]
293-
pub fn new_tuple(types: impl Into<TypeRow>) -> Self {
314+
pub fn new_tuple(types: impl Into<TypeRow<true>>) -> Self {
294315
let row = types.into();
295316
match row.len() {
296317
0 => Self::UNIT,
@@ -300,7 +321,7 @@ impl Type {
300321

301322
/// Initialize a new sum type by providing the possible variant types.
302323
#[inline(always)]
303-
pub fn new_sum(variants: impl IntoIterator<Item = TypeRow>) -> Self where {
324+
pub fn new_sum(variants: impl IntoIterator<Item = TypeRow<true>>) -> Self where {
304325
Self::new(TypeEnum::Sum(SumType::new(variants)))
305326
}
306327

@@ -316,7 +337,7 @@ impl Type {
316337
Self::new(TypeEnum::Alias(alias))
317338
}
318339

319-
fn new(type_e: TypeEnum) -> Self {
340+
fn new(type_e: TypeEnum<RV>) -> Self {
320341
let bound = type_e.least_upper_bound();
321342
Self(type_e, bound)
322343
}
@@ -335,19 +356,6 @@ impl Type {
335356
Self(TypeEnum::Variable(idx, bound), bound)
336357
}
337358

338-
/// New use (occurrence) of the row variable with specified index.
339-
/// `bound` must be exactly that with which the variable was declared
340-
/// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound),
341-
/// which may be narrower than required for the use.
342-
/// For use in [OpDef] type schemes, or function types, only,
343-
/// not [FuncDefn] type schemes or as a Hugr port type.
344-
///
345-
/// [OpDef]: crate::extension::OpDef
346-
/// [FuncDefn]: crate::ops::FuncDefn
347-
pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self {
348-
Self(TypeEnum::RowVariable(idx, bound), bound)
349-
}
350-
351359
/// Report the least upper [TypeBound]
352360
#[inline(always)]
353361
pub const fn least_upper_bound(&self) -> TypeBound {
@@ -356,7 +364,7 @@ impl Type {
356364

357365
/// Report the component TypeEnum.
358366
#[inline(always)]
359-
pub const fn as_type_enum(&self) -> &TypeEnum {
367+
pub const fn as_type_enum(&self) -> &TypeEnum<RV> {
360368
&self.0
361369
}
362370

@@ -382,7 +390,6 @@ impl Type {
382390
/// [TypeDef]: crate::extension::TypeDef
383391
pub(crate) fn validate(
384392
&self,
385-
allow_row_vars: bool,
386393
extension_registry: &ExtensionRegistry,
387394
var_decls: &[TypeParam],
388395
) -> Result<(), SignatureError> {
@@ -391,16 +398,16 @@ impl Type {
391398
match &self.0 {
392399
TypeEnum::Sum(SumType::General { rows }) => rows
393400
.iter()
394-
.try_for_each(|row| row.validate_var_len(extension_registry, var_decls)),
401+
.try_for_each(|row| row.validate(extension_registry, var_decls)),
395402
TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there
396403
TypeEnum::Alias(_) => Ok(()),
397404
TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls),
398405
// Function values may be passed around without knowing their arity
399406
// (i.e. with row vars) as long as they are not called:
400-
TypeEnum::Function(ft) => ft.validate_var_len(extension_registry, var_decls),
407+
TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls),
401408
TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()),
402409
TypeEnum::RowVariable(idx, bound) => {
403-
if allow_row_vars {
410+
if RV {
404411
check_typevar_decl(var_decls, *idx, &TypeParam::new_list(*bound))
405412
} else {
406413
Err(SignatureError::RowVarWhereTypeExpected { idx: *idx })
@@ -411,29 +418,68 @@ impl Type {
411418

412419
/// Applies a substitution to a type.
413420
/// This may result in a row of types, if this [Type] is not really a single type but actually a row variable
414-
/// Invariants may be confirmed by validation:
415-
/// * If [Type::validate]`(false)` returns successfully, this method will return a Vec containing exactly one type
416-
/// * If [Type::validate]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self)
417-
/// return a Vec containing any number of [Type]s. These may (or not) pass [Type::validate]
418-
fn substitute(&self, t: &Substitution) -> Vec<Self> {
421+
/// (of course this can only occur for a `Type<true``). For a `Type<false>`, will always return exactly one element.
422+
fn subst_vec(&self, s: &Substitution) -> Vec<Self> {
419423
match &self.0 {
420-
TypeEnum::RowVariable(idx, bound) => t.apply_rowvar(*idx, *bound),
421-
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()],
424+
TypeEnum::RowVariable(idx, bound) =>
425+
if RV {s.apply_rowvar(idx, bound)} // ALAN Argh, type error here as Type<true> != Type<RV> even inside "if RV"
426+
else {panic!("Row Variable outside Row - should not have validated?")},
427+
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone().into()],
422428
TypeEnum::Variable(idx, bound) => {
423-
let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else {
429+
let TypeArg::Type { ty } = s.apply_var(*idx, &((*bound).into())) else {
424430
panic!("Variable was not a type - try validate() first")
425431
};
426-
vec![ty]
432+
vec![ty] // ALAN argh, can't convert parametrically from Type<false> to Type<RV>
427433
}
428-
TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(t))],
429-
TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(t))],
434+
TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(s))],
435+
TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(s))],
430436
TypeEnum::Sum(SumType::General { rows }) => {
431-
vec![Type::new_sum(rows.iter().map(|r| r.substitute(t)))]
437+
vec![Type::new_sum(rows.iter().map(|r| r.substitute(s)))]
432438
}
433439
}
434440
}
435441
}
436442

443+
impl TryFrom<Type<true>> for Type<false> {
444+
type Error = ConvertError;
445+
fn try_from(value: Type<true>) -> Result<Self, Self::Error> {
446+
Ok(Self(match value.0 {
447+
TypeEnum::Extension(e) => TypeEnum::Extension(e),
448+
TypeEnum::Alias(a) => TypeEnum::Alias(a),
449+
TypeEnum::Function(ft) => TypeEnum::Function(ft),
450+
TypeEnum::Variable(i, b) => TypeEnum::Variable(i,b),
451+
TypeEnum::RowVariable(_, _) => return Err(ConvertError),
452+
TypeEnum::Sum(st) => TypeEnum::Sum(st)
453+
}, value.1))
454+
}
455+
}
456+
457+
impl Type<false> {
458+
fn substitute(&self, s: &Substitution) -> Self {
459+
let v = self.subst_vec(s);
460+
let [r] = v.try_into().unwrap(); // No row vars, so every Type<false> produces exactly one
461+
r
462+
}
463+
}
464+
465+
impl Type<true> {
466+
/// New use (occurrence) of the row variable with specified index.
467+
/// `bound` must match that with which the variable was declared
468+
/// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound).
469+
/// For use in [OpDef], not [FuncDefn], type schemes only.
470+
///
471+
/// [OpDef]: crate::extension::OpDef
472+
/// [FuncDefn]: crate::ops::FuncDefn
473+
pub const fn new_row_var(idx: usize, bound: TypeBound) -> Self {
474+
Self(TypeEnum::RowVariable(idx, bound), bound)
475+
}
476+
477+
fn substitute(&self, s: &Substitution) -> Vec<Self> {
478+
self.subst_vec(s)
479+
}
480+
481+
}
482+
437483
/// Details a replacement of type variables with a finite list of known values.
438484
/// (Variables out of the range of the list will result in a panic)
439485
pub(crate) struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry);
@@ -448,7 +494,7 @@ impl<'a> Substitution<'a> {
448494
arg.clone()
449495
}
450496

451-
fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec<Type> {
497+
fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec<Type<true>> {
452498
let arg = self
453499
.0
454500
.get(idx)
@@ -476,6 +522,19 @@ impl<'a> Substitution<'a> {
476522
}
477523
}
478524

525+
impl From<Type<false>> for Type<true> {
526+
fn from(value: Type<false>) -> Self {
527+
Self(match value.0 {
528+
TypeEnum::Alias(a) => TypeEnum::Alias(a),
529+
TypeEnum::Extension(e) => TypeEnum::Extension(e),
530+
TypeEnum::Function(ft) => TypeEnum::Function(ft),
531+
TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound),
532+
TypeEnum::RowVariable(_, _) => panic!("Type<false> should not contain row variables"),
533+
TypeEnum::Sum(st) => TypeEnum::Sum(st),
534+
}, value.1)
535+
}
536+
}
537+
479538
pub(crate) fn check_typevar_decl(
480539
decls: &[TypeParam],
481540
idx: usize,

0 commit comments

Comments
 (0)