diff --git a/interp/src/debugger/cidr.rs b/interp/src/debugger/cidr.rs index 64f99e5a2a..c6619cba50 100644 --- a/interp/src/debugger/cidr.rs +++ b/interp/src/debugger/cidr.rs @@ -11,7 +11,7 @@ use crate::interpreter::{ComponentInterpreter, ConstCell, Interpreter}; use crate::structures::names::{CompGroupName, ComponentQualifiedInstanceName}; use crate::structures::state_views::StateView; use crate::utils::AsRaw; -use crate::{interpreter_ir as iir, primitives::Serializable}; +use crate::{interpreter_ir as iir, serialization::Serializable}; use calyx_ir::{self as ir, Id, RRC}; diff --git a/interp/src/flatten/flat_ir/base.rs b/interp/src/flatten/flat_ir/base.rs index e91cd16fa2..a1a9bcc583 100644 --- a/interp/src/flatten/flat_ir/base.rs +++ b/interp/src/flatten/flat_ir/base.rs @@ -274,12 +274,21 @@ impl From for AssignmentWinner { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct AssignedValue { val: Value, winner: AssignmentWinner, } +impl std::fmt::Debug for AssignedValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AssignedValue") + .field("val", &format!("{}", &self.val)) + .field("winner", &self.winner) + .finish() + } +} + impl std::fmt::Display for AssignedValue { // TODO: replace with something more reasonable fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -342,7 +351,7 @@ impl AssignedValue { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] /// A wrapper struct around an option of an [AssignedValue] pub struct PortValue(Option); @@ -398,6 +407,12 @@ impl PortValue { pub fn new_implicit(val: Value) -> Self { Self(Some(AssignedValue::implicit_value(val))) } + + /// Sets the value to undefined and returns the former value if present. + /// This is equivalent to [Option::take] + pub fn set_undef(&mut self) -> Option { + self.0.take() + } } impl From> for PortValue { @@ -412,6 +427,12 @@ impl From for PortValue { } } +impl From for Option { + fn from(value: PortValue) -> Self { + value.0 + } +} + /// A global index for standard groups in the IR #[derive(Debug, Eq, Copy, Clone, PartialEq, Hash, PartialOrd, Ord)] pub struct GroupIdx(u32); diff --git a/interp/src/flatten/primitives/builder.rs b/interp/src/flatten/primitives/builder.rs index f2bbb8e9de..037a3cc867 100644 --- a/interp/src/flatten/primitives/builder.rs +++ b/interp/src/flatten/primitives/builder.rs @@ -106,7 +106,7 @@ pub fn build_primitive( idx_size: _, } => match mem_type { MemType::Seq => todo!("SeqMem primitives are not currently defined in the flat interpreter"), - MemType::Std => Box::new(StdMemD1::new(base_port, *width, false, *size as usize)) + MemType::Std => Box::new(CombMemD1::new(base_port, *width, false, *size as usize)) }, CellPrototype::MemD2 { mem_type, @@ -117,7 +117,7 @@ pub fn build_primitive( d1_idx_size: _, } => match mem_type { MemType::Seq => todo!("SeqMem primitives are not currently defined in the flat interpreter"), - MemType::Std => Box::new(StdMemD2::new(base_port, *width, false, (*d0_size as usize, *d1_size as usize))), + MemType::Std => Box::new(CombMemD2::new(base_port, *width, false, (*d0_size as usize, *d1_size as usize))), }, CellPrototype::MemD3 { mem_type, @@ -130,7 +130,7 @@ pub fn build_primitive( d2_idx_size: _, } => match mem_type { MemType::Seq => todo!("SeqMem primitives are not currently defined in the flat interpreter"), - MemType::Std => Box::new(StdMemD3::new(base_port, *width, false, (*d0_size as usize, *d1_size as usize, *d2_size as usize))), + MemType::Std => Box::new(CombMemD3::new(base_port, *width, false, (*d0_size as usize, *d1_size as usize, *d2_size as usize))), }, CellPrototype::MemD4 { mem_type, @@ -145,7 +145,7 @@ pub fn build_primitive( d3_idx_size: _, }=> match mem_type { MemType::Seq => todo!("SeqMem primitives are not currently defined in the flat interpreter"), - MemType::Std => Box::new(StdMemD4::new(base_port, *width, false, (*d0_size as usize, *d1_size as usize, *d2_size as usize, *d3_size as usize))), + MemType::Std => Box::new(CombMemD4::new(base_port, *width, false, (*d0_size as usize, *d1_size as usize, *d2_size as usize, *d3_size as usize))), }, CellPrototype::Unknown(_, _) => todo!(), } diff --git a/interp/src/flatten/primitives/combinational.rs b/interp/src/flatten/primitives/combinational.rs index a1b3899490..7d1dfee6c4 100644 --- a/interp/src/flatten/primitives/combinational.rs +++ b/interp/src/flatten/primitives/combinational.rs @@ -3,7 +3,6 @@ use std::ops::Not; use bitvec::vec::BitVec; use crate::{ - errors::InterpreterResult, flatten::{ flat_ir::prelude::{AssignedValue, GlobalPortIdx, PortValue}, primitives::{ @@ -84,13 +83,6 @@ impl Primitive for StdMux { } } - fn reset(&mut self, port_map: &mut PortMap) -> InterpreterResult<()> { - ports![&self.base; out: Self::OUT]; - port_map.write_undef_unchecked(out); - - Ok(()) - } - fn has_stateful(&self) -> bool { false } diff --git a/interp/src/flatten/primitives/macros.rs b/interp/src/flatten/primitives/macros.rs index 79162465f8..9571e656fd 100644 --- a/interp/src/flatten/primitives/macros.rs +++ b/interp/src/flatten/primitives/macros.rs @@ -102,10 +102,6 @@ macro_rules! comb_primitive { false } - fn reset(&mut self, map:&mut $crate::flatten::structures::environment::PortMap) -> $crate::errors::InterpreterResult<()> { - self.exec_comb(map)?; - Ok(()) - } } }; diff --git a/interp/src/flatten/primitives/prim_trait.rs b/interp/src/flatten/primitives/prim_trait.rs index d5d22b5bc9..c6f712a6ae 100644 --- a/interp/src/flatten/primitives/prim_trait.rs +++ b/interp/src/flatten/primitives/prim_trait.rs @@ -2,7 +2,7 @@ use crate::{ debugger::PrintCode, errors::InterpreterResult, flatten::{flat_ir::base::GlobalPortIdx, structures::environment::PortMap}, - primitives::Serializable, + serialization::Serializable, values::Value, }; @@ -30,6 +30,7 @@ impl From<(Value, GlobalPortIdx)> for AssignResult { } /// An enum used to denote whether or not committed updates changed the state +#[derive(Debug)] pub enum UpdateStatus { Unchanged, Changed, @@ -50,7 +51,7 @@ impl UpdateStatus { /// If the status is unchanged and other is changed, updates the status of /// self to changed, otherwise does nothing pub fn update(&mut self, other: Self) { - if !self.is_changed() && other.is_changed() { + if !self.as_bool() && other.as_bool() { *self = UpdateStatus::Changed; } } @@ -60,7 +61,7 @@ impl UpdateStatus { /// /// [`Changed`]: UpdateStatus::Changed #[must_use] - pub fn is_changed(&self) -> bool { + pub fn as_bool(&self) -> bool { matches!(self, Self::Changed) } } @@ -69,7 +70,7 @@ impl std::ops::BitOr for UpdateStatus { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { - if self.is_changed() || rhs.is_changed() { + if self.as_bool() || rhs.as_bool() { UpdateStatus::Changed } else { UpdateStatus::Unchanged @@ -81,7 +82,7 @@ impl std::ops::BitOr for &UpdateStatus { type Output = UpdateStatus; fn bitor(self, rhs: Self) -> Self::Output { - if self.is_changed() || rhs.is_changed() { + if self.as_bool() || rhs.as_bool() { UpdateStatus::Changed } else { UpdateStatus::Unchanged @@ -106,10 +107,6 @@ pub trait Primitive { Ok(UpdateStatus::Unchanged) } - fn reset(&mut self, _port_map: &mut PortMap) -> InterpreterResult<()> { - Ok(()) - } - fn has_comb(&self) -> bool { true } diff --git a/interp/src/flatten/primitives/stateful/memories.rs b/interp/src/flatten/primitives/stateful/memories.rs index 9c5c9a98bc..ddf86cc37e 100644 --- a/interp/src/flatten/primitives/stateful/memories.rs +++ b/interp/src/flatten/primitives/stateful/memories.rs @@ -1,7 +1,7 @@ use crate::{ - errors::{InterpreterError, InterpreterResult}, + errors::InterpreterError, flatten::{ - flat_ir::prelude::{AssignedValue, GlobalPortIdx, PortValue}, + flat_ir::prelude::{AssignedValue, GlobalPortIdx}, primitives::{ declare_ports, make_getters, ports, prim_trait::{UpdateResult, UpdateStatus}, @@ -9,13 +9,14 @@ use crate::{ }, structures::environment::PortMap, }, - primitives::{Entry, Serializable}, + serialization::{Entry, Serializable, Shape}, values::Value, }; pub struct StdReg { base_port: GlobalPortIdx, internal_state: Value, + done_is_high: bool, } impl StdReg { @@ -26,6 +27,7 @@ impl StdReg { Self { base_port, internal_state, + done_is_high: false, } } } @@ -43,7 +45,7 @@ impl Primitive for StdReg { let done_port = if port_map[reset].as_bool().unwrap_or_default() { self.internal_state = Value::zeroes(self.internal_state.width()); port_map - .insert_val(done, AssignedValue::cell_value(Value::bit_low())) + .insert_val(done, AssignedValue::cell_value(Value::bit_low()))? } else if port_map[write_en].as_bool().unwrap_or_default() { self.internal_state = port_map[input] .as_option() @@ -51,12 +53,20 @@ impl Primitive for StdReg { .val() .clone(); - port_map - .insert_val(done, AssignedValue::cell_value(Value::bit_high())) + self.done_is_high = true; + + port_map.insert_val( + done, + AssignedValue::cell_value(Value::bit_high()), + )? | port_map.insert_val( + out_idx, + AssignedValue::cell_value(self.internal_state.clone()), + )? } else { + self.done_is_high = false; port_map - .insert_val(done, AssignedValue::cell_value(Value::bit_low())) - }?; + .insert_val(done, AssignedValue::cell_value(Value::bit_low()))? + }; Ok(done_port | port_map.insert_val( @@ -65,18 +75,24 @@ impl Primitive for StdReg { )?) } - fn reset(&mut self, port_map: &mut PortMap) -> InterpreterResult<()> { - ports![&self.base_port; done: Self::DONE]; - port_map[done] = PortValue::new_cell(Value::bit_low()); - Ok(()) - } - fn exec_comb(&self, port_map: &mut PortMap) -> UpdateResult { - ports![&self.base_port; out_idx: Self::OUT]; - port_map.insert_val( + ports![&self.base_port; + done: Self::DONE, + out_idx: Self::OUT]; + let out_signal = port_map.insert_val( out_idx, AssignedValue::cell_value(self.internal_state.clone()), - ) + )?; + let done_signal = port_map.insert_val( + done, + AssignedValue::cell_value(if self.done_is_high { + Value::bit_high() + } else { + Value::bit_low() + }), + )?; + + Ok(out_signal | done_signal) } fn serialize( @@ -102,9 +118,13 @@ pub trait MemAddresser { port_map: &PortMap, base_port: GlobalPortIdx, ) -> Option; + + fn get_dimensions(&self) -> Shape; } -pub struct MemD1; +pub struct MemD1 { + d0_size: usize, +} impl MemAddresser for MemD1 { fn calculate_addr( @@ -128,6 +148,10 @@ impl MemAddresser for MemD1 { } else { Self::COMB_ADDR0 + 1 }; + + fn get_dimensions(&self) -> Shape { + Shape::D1((self.d0_size,)) + } } impl MemD1 { @@ -135,6 +159,7 @@ impl MemD1 { } pub struct MemD2 { + d0_size: usize, d1_size: usize, } @@ -171,9 +196,14 @@ impl MemAddresser for MemD2 { } else { Self::COMB_ADDR1 + 1 }; + + fn get_dimensions(&self) -> Shape { + Shape::D2((self.d0_size, self.d1_size)) + } } pub struct MemD3 { + d0_size: usize, d1_size: usize, d2_size: usize, } @@ -219,9 +249,14 @@ impl MemAddresser for MemD3 { } else { Self::COMB_ADDR2 + 1 }; + + fn get_dimensions(&self) -> Shape { + Shape::D3((self.d0_size, self.d1_size, self.d2_size)) + } } pub struct MemD4 { + d0_size: usize, d1_size: usize, d2_size: usize, d3_size: usize, @@ -279,24 +314,29 @@ impl MemAddresser for MemD4 { } else { Self::COMB_ADDR3 + 1 }; + + fn get_dimensions(&self) -> Shape { + Shape::D4((self.d0_size, self.d1_size, self.d2_size, self.d3_size)) + } } -pub struct StdMem { +pub struct CombMem { base_port: GlobalPortIdx, internal_state: Vec, allow_invalid_access: bool, width: u32, addresser: M, + done_is_high: bool, } -impl StdMem { +impl CombMem { declare_ports![ - WRITE_DATA: M::NON_ADDRESS_BASE + 1, - WRITE_EN: M::NON_ADDRESS_BASE + 2, - CLK: M::NON_ADDRESS_BASE + 3, - RESET: M::NON_ADDRESS_BASE + 4, - READ_DATA: M::NON_ADDRESS_BASE + 5, - DONE: M::NON_ADDRESS_BASE + 6 + WRITE_DATA: M::NON_ADDRESS_BASE, + WRITE_EN: M::NON_ADDRESS_BASE + 1, + CLK: M::NON_ADDRESS_BASE + 2, + RESET: M::NON_ADDRESS_BASE + 3, + READ_DATA: M::NON_ADDRESS_BASE + 4, + DONE: M::NON_ADDRESS_BASE + 5 ]; make_getters![base_port; @@ -308,25 +348,36 @@ impl StdMem { ]; } -impl Primitive for StdMem { +impl Primitive for CombMem { fn exec_comb(&self, port_map: &mut PortMap) -> UpdateResult { let addr = self.addresser.calculate_addr(port_map, self.base_port); let read_data = self.read_data(); - if addr.is_some() && addr.unwrap() < self.internal_state.len() { - Ok(port_map.insert_val( - read_data, - AssignedValue::cell_value( - self.internal_state[addr.unwrap()].clone(), - ), - )?) - } - // either the address is undefined or it is outside the range of valid addresses - else { - // throw error on cycle boundary rather than here - port_map.write_undef(read_data)?; - Ok(UpdateStatus::Unchanged) - } + let read = + if addr.is_some() && addr.unwrap() < self.internal_state.len() { + port_map.insert_val( + read_data, + AssignedValue::cell_value( + self.internal_state[addr.unwrap()].clone(), + ), + )? + } + // either the address is undefined or it is outside the range of valid addresses + else { + // throw error on cycle boundary rather than here + port_map.write_undef(read_data)?; + UpdateStatus::Unchanged + }; + + let done_signal = port_map.insert_val( + self.done(), + AssignedValue::cell_value(if self.done_is_high { + Value::bit_high() + } else { + Value::bit_low() + }), + )?; + Ok(done_signal | read) } fn exec_cycle(&mut self, port_map: &mut PortMap) -> UpdateResult { @@ -344,8 +395,10 @@ impl Primitive for StdMem { .as_option() .ok_or(InterpreterError::UndefinedWrite)?; self.internal_state[addr] = write_data.val().clone(); + self.done_is_high = true; port_map.insert_val(done, AssignedValue::cell_b_high())? } else { + self.done_is_high = false; port_map.insert_val(done, AssignedValue::cell_b_low())? }; @@ -355,24 +408,24 @@ impl Primitive for StdMem { AssignedValue::cell_value(self.internal_state[addr].clone()), )? | done) } else { + port_map.write_undef(read_data)?; Ok(done) } } - fn reset(&mut self, port_map: &mut PortMap) -> InterpreterResult<()> { - let (read_data, done) = (self.read_data(), self.done()); - - port_map.write_undef_unchecked(read_data); - port_map[done] = PortValue::new_cell(Value::bit_low()); - - Ok(()) - } - fn serialize( &self, - _code: Option, + code: Option, ) -> Serializable { - todo!("StdMemD1::serialize") + let code = code.unwrap_or_default(); + + Serializable::Array( + self.internal_state + .iter() + .map(|x| Entry::from_val_code(x, &code)) + .collect(), + self.addresser.get_dimensions(), + ) } fn has_serializable_state(&self) -> bool { @@ -381,12 +434,12 @@ impl Primitive for StdMem { } // type aliases -pub type StdMemD1 = StdMem>; -pub type StdMemD2 = StdMem>; -pub type StdMemD3 = StdMem>; -pub type StdMemD4 = StdMem>; +pub type CombMemD1 = CombMem>; +pub type CombMemD2 = CombMem>; +pub type CombMemD3 = CombMem>; +pub type CombMemD4 = CombMem>; -impl StdMemD1 { +impl CombMemD1 { pub fn new( base: GlobalPortIdx, width: u32, @@ -400,12 +453,13 @@ impl StdMemD1 { internal_state, allow_invalid_access: allow_invalid, width, - addresser: MemD1::, + addresser: MemD1:: { d0_size: size }, + done_is_high: false, } } } -impl StdMemD2 { +impl CombMemD2 { pub fn new( base: GlobalPortIdx, width: u32, @@ -419,12 +473,16 @@ impl StdMemD2 { internal_state, allow_invalid_access: allow_invalid, width, - addresser: MemD2:: { d1_size: size.1 }, + addresser: MemD2:: { + d0_size: size.0, + d1_size: size.1, + }, + done_is_high: false, } } } -impl StdMemD3 { +impl CombMemD3 { pub fn new( base: GlobalPortIdx, width: u32, @@ -440,14 +498,16 @@ impl StdMemD3 { allow_invalid_access: allow_invalid, width, addresser: MemD3:: { + d0_size: size.0, d1_size: size.1, d2_size: size.2, }, + done_is_high: false, } } } -impl StdMemD4 { +impl CombMemD4 { pub fn new( base: GlobalPortIdx, width: u32, @@ -463,10 +523,12 @@ impl StdMemD4 { allow_invalid_access: allow_invalid, width, addresser: MemD4:: { + d0_size: size.0, d1_size: size.1, d2_size: size.2, d3_size: size.3, }, + done_is_high: false, } } } diff --git a/interp/src/flatten/structures/environment/env.rs b/interp/src/flatten/structures/environment/env.rs index 78ae6132e5..5b17761405 100644 --- a/interp/src/flatten/structures/environment/env.rs +++ b/interp/src/flatten/structures/environment/env.rs @@ -1,4 +1,3 @@ -use ahash::HashSet; use itertools::Itertools; use super::{assignments::AssignmentBundle, program_counter::ProgramCounter}; @@ -20,10 +19,10 @@ use crate::{ }, primitives::{self, prim_trait::UpdateStatus, Primitive}, structures::{ - environment::program_counter::{ControlPoint, SearchPath}, - index_trait::IndexRef, + environment::program_counter::ControlPoint, index_trait::IndexRef, }, }, + values::Value, }; use std::{collections::VecDeque, fmt::Debug}; @@ -134,6 +133,15 @@ impl CellLedger { self.as_comp() .expect("Unwrapped cell ledger as component but received primitive") } + + #[must_use] + pub(crate) fn as_primitive(&self) -> Option<&dyn Primitive> { + if let Self::Primitive { cell_dyn } = self { + Some(&**cell_dyn) + } else { + None + } + } } impl Debug for CellLedger { @@ -348,15 +356,32 @@ impl<'a> Environment<'a> { let definition = &self.ctx.secondary[comp.port_offset_map[port]]; println!( - " {}: {}", + " {}: {} ({:?})", self.ctx.secondary[definition.name], - self.ports[&info.index_bases + port] + self.ports[&info.index_bases + port], + &info.index_bases + port ); } + let cell_idx = &info.index_bases + cell_off; + if definition.prototype.is_component() { - let child_target = &info.index_bases + cell_off; - self.print_component(child_target, hierarchy); + self.print_component(cell_idx, hierarchy); + } else if self.cells[cell_idx] + .as_primitive() + .unwrap() + .has_serializable_state() + { + println!( + " INTERNAL_DATA: {}", + serde_json::to_string_pretty( + &self.cells[cell_idx] + .as_primitive() + .unwrap() + .serialize(None) + ) + .unwrap() + ) } } @@ -377,7 +402,8 @@ impl<'a> Environment<'a> { } /// A wrapper struct for the environment that provides the functions used to -/// simulate the actual program +/// simulate the actual program. This is just to keep the simulation logic under +/// a different namespace than the environment to avoid confusion pub struct Simulator<'a> { env: Environment<'a>, } @@ -482,7 +508,7 @@ impl<'a> Simulator<'a> { control_points .iter() .map(|node| { - match &self.ctx().primary[node.control_node] { + match &self.ctx().primary[node.control_node_idx] { ControlNode::Enable(e) => { (node.comp, self.ctx().primary[e.group()].assignments) } @@ -511,36 +537,25 @@ impl<'a> Simulator<'a> { } pub fn step(&mut self) -> InterpreterResult<()> { - /// attempts to get the next node for the given control point, if found - /// it replaces the given node. Returns true if the node was found and - /// replaced, returns false otherwise - fn get_next(node: &mut ControlPoint, ctx: &Context) -> bool { - let path = SearchPath::find_path_from_root(node.control_node, ctx); - let next = path.next_node(&ctx.primary.control); - if let Some(next) = next { - *node = node.new_w_comp(next); - true - } else { - //need to remove the node from the list now - false - } - } - // place to keep track of what groups we need to conclude at the end of // this step. These are indices into the program counter + // In the future it may be worthwhile to preallocate some space to these + // buffers. Can pick anything from zero to the number of nodes in the + // program counter as the size let mut leaf_nodes = vec![]; + let mut done_groups = vec![]; self.env.pc.vec_mut().retain_mut(|node| { // just considering a single node case for the moment - match &self.env.ctx.primary[node.control_node] { + match &self.env.ctx.primary[node.control_node_idx] { ControlNode::Seq(seq) => { if !seq.is_empty() { let next = seq.stms()[0]; - *node = node.new_w_comp(next); + *node = node.new_retain_comp(next); true } else { - get_next(node, self.env.ctx) + node.mutate_into_next(self.env.ctx) } } ControlNode::Par(_par) => todo!("not ready for par yet"), @@ -567,7 +582,7 @@ impl<'a> Simulator<'a> { }; let target = if result { i.tbranch() } else { i.fbranch() }; - *node = node.new_w_comp(target); + *node = node.new_retain_comp(target); true } ControlNode::While(w) => { @@ -594,50 +609,127 @@ impl<'a> Simulator<'a> { if result { // enter the body - *node = node.new_w_comp(w.body()); + *node = node.new_retain_comp(w.body()); true } else { // ascend the tree - get_next(node, self.env.ctx) + node.mutate_into_next(self.env.ctx) } } // ===== leaf nodes ===== - ControlNode::Empty(_) => get_next(node, self.env.ctx), - ControlNode::Enable(_) => { - leaf_nodes.push(node.clone()); - true + ControlNode::Empty(_) => node.mutate_into_next(self.env.ctx), + ControlNode::Enable(e) => { + let done_local = self.env.ctx.primary[e.group()].done; + let done_idx = &self.env.cells[node.comp] + .as_comp() + .unwrap() + .index_bases + + done_local; + + if !self.env.ports[done_idx].as_bool().unwrap_or_default() { + leaf_nodes.push(node.clone()); + true + } else { + done_groups.push(( + node.clone(), + self.env.ports[done_idx].clone(), + )); + // remove from the list now + false + } } ControlNode::Invoke(_) => todo!("invokes not implemented yet"), } }); - self.simulate_combinational(&leaf_nodes)?; + self.undef_all_ports(); + for (node, val) in &done_groups { + match &self.env.ctx.primary[node.control_node_idx] { + ControlNode::Enable(e) => { + let go_local = self.env.ctx.primary[e.group()].go; + let done_local = self.env.ctx.primary[e.group()].done; + let index_bases = &self.env.cells[node.comp] + .as_comp() + .unwrap() + .index_bases; + let done_idx = index_bases + done_local; + let go_idx = index_bases + go_local; + + // retain done condition from before + self.env.ports[done_idx] = val.clone(); + self.env.ports[go_idx] = + PortValue::new_implicit(Value::bit_high()); + } + ControlNode::Invoke(_) => todo!(), + _ => { + unreachable!("non-leaf node included in list of done nodes. This should never happen, please report it.") + } + } + } - let parent_cells: HashSet = self - .get_assignments(&leaf_nodes) - .iter() - .flat_map(|(cell, assigns)| { - assigns.iter().map(|x| { - let assign = &self.env.ctx.primary[x]; - self.get_parent_cell(assign.dst, *cell) - }) - }) - .flatten() - .collect(); + for node in &leaf_nodes { + match &self.env.ctx.primary[node.control_node_idx] { + ControlNode::Enable(e) => { + let go_local = self.env.ctx.primary[e.group()].go; + let index_bases = &self.env.cells[node.comp] + .as_comp() + .unwrap() + .index_bases; + + // set go high + let go_idx = index_bases + go_local; + self.env.ports[go_idx] = + PortValue::new_implicit(Value::bit_high()); + } + ControlNode::Invoke(_) => todo!(), + non_leaf => { + unreachable!("non-leaf node {:?} included in list of leaf nodes. This should never happen, please report it.", non_leaf) + } + } + } + + self.simulate_combinational(&leaf_nodes)?; - for cell in parent_cells { - match &mut self.env.cells[cell] { + for cell in self.env.cells.values_mut() { + match cell { CellLedger::Primitive { cell_dyn } => { cell_dyn.exec_cycle(&mut self.env.ports)?; } - CellLedger::Component(_) => todo!(), + CellLedger::Component(_) => {} + } + } + + // need to cleanup the finished groups + for (node, _) in done_groups { + if let Some(next) = ControlPoint::get_next(&node, self.env.ctx) { + self.env.pc.insert_node(next) } } Ok(()) } + fn is_done(&self) -> bool { + assert!( + self.ctx().primary[self.ctx().entry_point].control.is_some(), + "flat interpreter doesn't handle a fully structural entrypoint program yet" + ); + // TODO griffin: need to handle structural components + self.env.pc.is_done() + } + + /// Evaluate the entire program + pub fn run_program(&mut self) -> InterpreterResult<()> { + while !self.is_done() { + dbg!("calling step"); + // self.env.print_pc(); + self.print_env(); + self.step()? + } + Ok(()) + } + fn evaluate_guard( &self, guard: GuardIdx, @@ -684,6 +776,12 @@ impl<'a> Simulator<'a> { } } + fn undef_all_ports(&mut self) { + for (_idx, port_val) in self.env.ports.iter_mut() { + port_val.set_undef(); + } + } + fn simulate_combinational( &mut self, control_points: &[ControlPoint], @@ -691,17 +789,6 @@ impl<'a> Simulator<'a> { let assigns_bundle = self.get_assignments(control_points); let mut has_changed = true; - let parent_cells: HashSet = assigns_bundle - .iter() - .flat_map(|(cell, assigns)| { - assigns.iter().map(|x| { - let assign = &self.env.ctx.primary[x]; - self.get_parent_cell(assign.dst, *cell) - }) - }) - .flatten() - .collect(); - while has_changed { has_changed = false; @@ -730,17 +817,21 @@ impl<'a> Simulator<'a> { } // Run all the primitives - let changed: bool = parent_cells + let changed: bool = self + .env + .cells + .range() .iter() - .map(|x| match &mut self.env.cells[*x] { + .filter_map(|x| match &mut self.env.cells[x] { CellLedger::Primitive { cell_dyn } => { - cell_dyn.exec_comb(&mut self.env.ports) + Some(cell_dyn.exec_comb(&mut self.env.ports)) } - CellLedger::Component(_) => todo!(), + CellLedger::Component(_) => None, }) - .fold_ok(false, |has_changed, update| { - has_changed | update.is_changed() - })?; + .fold_ok(UpdateStatus::Unchanged, |has_changed, update| { + has_changed | update + })? + .as_bool(); has_changed |= changed; } @@ -750,11 +841,8 @@ impl<'a> Simulator<'a> { pub fn _main_test(&mut self) { self.env.print_pc(); - for _x in self.env.pc.iter() { - // println!("{:?} next {:?}", x, self.find_next_control_point(x)) - } + let _ = self.run_program(); self.env.print_pc(); self.print_env(); - // println!("{:?}", self.get_assignments()) } } diff --git a/interp/src/flatten/structures/environment/program_counter.rs b/interp/src/flatten/structures/environment/program_counter.rs index 90ed42195a..d40446f690 100644 --- a/interp/src/flatten/structures/environment/program_counter.rs +++ b/interp/src/flatten/structures/environment/program_counter.rs @@ -11,27 +11,46 @@ use crate::flatten::{ use itertools::{FoldWhile, Itertools}; /// Simple struct containing both the component instance and the active leaf -/// node in the component +/// node in the component. This is used to represent an active execution of some +/// portion of the control tree #[derive(Debug, Hash, Eq, PartialEq, Clone)] pub struct ControlPoint { pub comp: GlobalCellIdx, - pub control_node: ControlIdx, + pub control_node_idx: ControlIdx, } impl ControlPoint { pub fn new(comp: GlobalCellIdx, control_leaf: ControlIdx) -> Self { Self { comp, - control_node: control_leaf, + control_node_idx: control_leaf, } } /// Constructs a new [ControlPoint] from an existing one by copying over the /// component identifier but changing the leaf node - pub fn new_w_comp(&self, target: ControlIdx) -> Self { + pub fn new_retain_comp(&self, target: ControlIdx) -> Self { Self { comp: self.comp, - control_node: target, + control_node_idx: target, + } + } + + pub fn get_next(node: &Self, ctx: &Context) -> Option { + let path = SearchPath::find_path_from_root(node.control_node_idx, ctx); + let next = path.next_node(&ctx.primary.control); + next.map(|x| node.new_retain_comp(x)) + } + + /// Attempts to get the next node for the given control point, if found + /// it replaces the given node. Returns true if the node was found and + /// replaced, returns false otherwise + pub fn mutate_into_next(&mut self, ctx: &Context) -> bool { + if let Some(next) = Self::get_next(self, ctx) { + *self = next; + true + } else { + false } } } @@ -79,7 +98,7 @@ impl SearchPath { } pub fn source_node(&self) -> Option<&SearchNode> { - self.path.get(0) + self.path.first() } pub fn len(&self) -> usize { @@ -232,8 +251,37 @@ impl SearchPath { search_index: None, }) } - ControlNode::If(_) => todo!(), - ControlNode::While(_) => todo!(), + ControlNode::If(i) => { + if let Some(idx) = &mut node.search_index { + if idx.is_true_branch() { + *idx = SearchIndex::new(SearchIndex::FALSE_BRANCH); + current_path.path.push(SearchNode { + node: i.fbranch(), + search_index: None, + }) + } else { + current_path.path.pop(); + } + } else { + node.search_index = + Some(SearchIndex::new(SearchIndex::TRUE_BRANCH)); + current_path.path.push(SearchNode { + node: i.tbranch(), + search_index: None, + }) + } + } + ControlNode::While(w) => { + if node.search_index.is_some() { + current_path.path.pop(); + } else { + node.search_index = Some(SearchIndex::new(0)); + current_path.path.push(SearchNode { + node: w.body(), + search_index: None, + }) + } + } } } @@ -294,7 +342,7 @@ impl ProgramCounter { if let Some(current) = ctx.primary[root].control { vec.push(ControlPoint { comp: root_cell, - control_node: current, + control_node_idx: current, }) } else { todo!( @@ -323,6 +371,10 @@ impl ProgramCounter { pub(crate) fn vec_mut(&mut self) -> &mut Vec { &mut self.vec } + + pub(crate) fn insert_node(&mut self, node: ControlPoint) { + self.vec.push(node) + } } impl<'a> IntoIterator for &'a ProgramCounter { diff --git a/interp/src/flatten/structures/indexed_map.rs b/interp/src/flatten/structures/indexed_map.rs index 5b201d9695..78abde87a3 100644 --- a/interp/src/flatten/structures/indexed_map.rs +++ b/interp/src/flatten/structures/indexed_map.rs @@ -1,4 +1,4 @@ -use super::index_trait::{IndexRangeIterator, IndexRef}; +use super::index_trait::{IndexRange, IndexRangeIterator, IndexRef}; use std::{ marker::PhantomData, ops::{self, Index}, @@ -13,6 +13,18 @@ where phantom: PhantomData, } +impl IndexedMap +where + K: IndexRef + PartialOrd, +{ + /// Produces a range containing all the keys in the input map. This is + /// similar to [IndexedMap::keys] but has an independent lifetime from the + /// map + pub fn range(&self) -> IndexRange { + IndexRange::new(K::new(0), K::new(self.len())) + } +} + impl IndexedMap where K: IndexRef, @@ -98,6 +110,17 @@ where self.data.iter().enumerate().map(|(i, v)| (K::new(i), v)) } + pub fn iter_mut(&mut self) -> impl Iterator { + self.data + .iter_mut() + .enumerate() + .map(|(i, v)| (K::new(i), v)) + } + + pub fn values_mut(&mut self) -> impl Iterator { + self.data.iter_mut() + } + pub fn keys(&self) -> impl Iterator + '_ { // TODO (griffin): Make this an actual struct instead self.data.iter().enumerate().map(|(i, _)| K::new(i)) diff --git a/interp/src/flatten/structures/printer.rs b/interp/src/flatten/structures/printer.rs index 2aa7a1d524..8a3f5ba5c2 100644 --- a/interp/src/flatten/structures/printer.rs +++ b/interp/src/flatten/structures/printer.rs @@ -173,14 +173,21 @@ impl<'a> Printer<'a> { match &self.ctx.primary[control] { ControlNode::Empty(_) => String::new(), ControlNode::Enable(e) => text_utils::indent( - self.ctx.secondary[self.ctx.primary[e.group()].name()].clone() - + ";", + format!( + "{}; ({:?})", + self.ctx.secondary[self.ctx.primary[e.group()].name()] + .clone(), + control + ), indent, ), // TODO Griffin: refactor into shared function rather than copy-paste? ControlNode::Seq(s) => { - let mut seq = text_utils::indent("seq {\n", indent); + let mut seq = text_utils::indent( + format!("seq {{ ({:?})\n", control), + indent, + ); for stmt in s.stms() { let child = self.format_control(parent, *stmt, indent + 1); seq += &child; diff --git a/interp/src/interpreter/component_interpreter.rs b/interp/src/interpreter/component_interpreter.rs index 778500df51..737b941ffe 100644 --- a/interp/src/interpreter/component_interpreter.rs +++ b/interp/src/interpreter/component_interpreter.rs @@ -471,8 +471,8 @@ impl Primitive for ComponentInterpreter { fn serialize( &self, _signed: Option, - ) -> crate::primitives::Serializable { - crate::primitives::Serializable::Full( + ) -> crate::serialization::Serializable { + crate::serialization::Serializable::Full( self.get_env() .gen_serializer(matches!(_signed, Some(PrintCode::Binary))), ) diff --git a/interp/src/lib.rs b/interp/src/lib.rs index 414d2f8964..10739d5875 100644 --- a/interp/src/lib.rs +++ b/interp/src/lib.rs @@ -1,5 +1,6 @@ pub mod interpreter; pub mod primitives; +mod serialization; pub use utils::MemoryMap; pub mod configuration; pub mod debugger; diff --git a/interp/src/primitives/combinational.rs b/interp/src/primitives/combinational.rs index fbe9986918..ce246fef0f 100644 --- a/interp/src/primitives/combinational.rs +++ b/interp/src/primitives/combinational.rs @@ -5,9 +5,12 @@ use super::{ primitive_traits::Named, Primitive, }; -use crate::logging::warn; use crate::values::Value; use crate::{comb_primitive, errors::InterpreterError}; +use crate::{ + logging::warn, + serialization::{Entry, Serializable}, +}; use bitvec::vec::BitVec; use calyx_ir as ir; use std::ops::Not; @@ -77,12 +80,9 @@ impl Primitive for StdConst { fn serialize( &self, code: Option, - ) -> super::Serializable { + ) -> Serializable { let code = code.unwrap_or(crate::debugger::PrintCode::Unsigned); - super::Serializable::Val(super::Entry::from_val_code( - &self.value, - &code, - )) + Serializable::Val(Entry::from_val_code(&self.value, &code)) } } diff --git a/interp/src/primitives/mod.rs b/interp/src/primitives/mod.rs index 6523d7ffa3..0d95116e94 100644 --- a/interp/src/primitives/mod.rs +++ b/interp/src/primitives/mod.rs @@ -1,8 +1,7 @@ mod primitive_traits; -pub use primitive_traits::Entry; + pub use primitive_traits::Named; pub use primitive_traits::Primitive; -pub use primitive_traits::Serializable; pub mod combinational; pub(super) mod prim_utils; diff --git a/interp/src/primitives/primitive_traits.rs b/interp/src/primitives/primitive_traits.rs index 276012d2de..df888bf1c4 100644 --- a/interp/src/primitives/primitive_traits.rs +++ b/interp/src/primitives/primitive_traits.rs @@ -1,18 +1,10 @@ use crate::{ - errors::InterpreterResult, - interpreter::ComponentInterpreter, - structures::state_views::{FullySerialize, StateView}, - utils::PrintCode, - values::Value, + errors::InterpreterResult, interpreter::ComponentInterpreter, + serialization::Serializable, structures::state_views::StateView, + utils::PrintCode, values::Value, }; use calyx_ir as ir; -use fraction::Fraction; - -use itertools::Itertools; -use serde::Serialize; -use std::fmt::Debug; -use std::fmt::Display; /// A trait indicating that the thing has a name pub trait Named { @@ -71,276 +63,3 @@ pub trait Primitive: Named { None } } - -/// An enum wrapping over a tuple representing the shape of a multi-dimensional -/// array -#[derive(Clone)] -pub enum Shape { - D1((usize,)), - D2((usize, usize)), - D3((usize, usize, usize)), - D4((usize, usize, usize, usize)), -} - -impl Shape { - fn is_1d(&self) -> bool { - matches!(self, Shape::D1(_)) - } - - pub(crate) fn dim_str(&self) -> String { - match self { - Shape::D1(_) => String::from("1D"), - Shape::D2(_) => String::from("2D"), - Shape::D3(_) => String::from("3D"), - Shape::D4(_) => String::from("4D"), - } - } -} -impl From for Shape { - fn from(u: usize) -> Self { - Shape::D1((u,)) - } -} -impl From<(usize,)> for Shape { - fn from(u: (usize,)) -> Self { - Shape::D1(u) - } -} -impl From<(usize, usize)> for Shape { - fn from(u: (usize, usize)) -> Self { - Shape::D2(u) - } -} - -impl From<(usize, usize, usize)> for Shape { - fn from(u: (usize, usize, usize)) -> Self { - Shape::D3(u) - } -} - -impl From<(usize, usize, usize, usize)> for Shape { - fn from(u: (usize, usize, usize, usize)) -> Self { - Shape::D4(u) - } -} - -/// A wrapper enum used during serialization. It represents either an unsigned integer, -/// or a signed integer and is serialized as the underlying integer. This also allows -/// mixed serialization of signed and unsigned values -#[derive(Serialize, Clone)] -#[serde(untagged)] -pub enum Entry { - U(u64), - I(i64), - Frac(Fraction), - Value(Value), -} - -impl From for Entry { - fn from(u: u64) -> Self { - Self::U(u) - } -} - -impl From for Entry { - fn from(i: i64) -> Self { - Self::I(i) - } -} - -impl From for Entry { - fn from(f: Fraction) -> Self { - Self::Frac(f) - } -} - -impl Entry { - pub fn from_val_code(val: &Value, code: &PrintCode) -> Self { - match code { - PrintCode::Unsigned => val.as_u64().into(), - PrintCode::Signed => val.as_i64().into(), - PrintCode::UFixed(f) => val.as_ufp(*f).into(), - PrintCode::SFixed(f) => val.as_sfp(*f).into(), - PrintCode::Binary => Entry::Value(val.clone()), - } - } -} - -impl Display for Entry { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Entry::U(v) => write!(f, "{}", v), - Entry::I(v) => write!(f, "{}", v), - Entry::Frac(v) => write!(f, "{}", v), - Entry::Value(v) => write!(f, "{}", v), - } - } -} - -impl Debug for Entry { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self) - } -} - -#[derive(Clone)] -pub enum Serializable { - Empty, - Val(Entry), - Array(Vec, Shape), - Full(FullySerialize), -} - -impl Serializable { - pub fn has_state(&self) -> bool { - !matches!(self, Serializable::Empty) - } -} - -impl Display for Serializable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Serializable::Empty => write!(f, ""), - Serializable::Val(v) => write!(f, "{}", v), - Serializable::Array(arr, shape) => { - write!(f, "{}", format_array(arr, shape)) - } - full @ Serializable::Full(_) => { - write!(f, "{}", serde_json::to_string(full).unwrap()) - } - } - } -} - -impl Serialize for Serializable { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match self { - Serializable::Empty => serializer.serialize_unit(), - Serializable::Val(u) => u.serialize(serializer), - Serializable::Array(arr, shape) => { - let arr: Vec<&Entry> = arr.iter().collect(); - if shape.is_1d() { - return arr.serialize(serializer); - } - // there's probably a better way to write this - match shape { - Shape::D2(shape) => { - let mem = arr - .iter() - .chunks(shape.1) - .into_iter() - .map(|x| x.into_iter().collect::>()) - .collect::>(); - mem.serialize(serializer) - } - Shape::D3(shape) => { - let mem = arr - .iter() - .chunks(shape.1 * shape.2) - .into_iter() - .map(|x| { - x.into_iter() - .chunks(shape.2) - .into_iter() - .map(|y| y.into_iter().collect::>()) - .collect::>() - }) - .collect::>(); - mem.serialize(serializer) - } - Shape::D4(shape) => { - let mem = arr - .iter() - .chunks(shape.2 * shape.1 * shape.3) - .into_iter() - .map(|x| { - x.into_iter() - .chunks(shape.2 * shape.3) - .into_iter() - .map(|y| { - y.into_iter() - .chunks(shape.3) - .into_iter() - .map(|z| { - z.into_iter() - .collect::>() - }) - .collect::>() - }) - .collect::>() - }) - .collect::>(); - mem.serialize(serializer) - } - Shape::D1(_) => unreachable!(), - } - } - Serializable::Full(s) => s.serialize(serializer), - } - } -} - -impl Serialize for dyn Primitive { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.serialize(None).serialize(serializer) - } -} - -fn format_array(arr: &[Entry], shape: &Shape) -> String { - match shape { - Shape::D2(shape) => { - let mem = arr - .iter() - .chunks(shape.1) - .into_iter() - .map(|x| x.into_iter().collect::>()) - .collect::>(); - format!("{:?}", mem) - } - Shape::D3(shape) => { - let mem = arr - .iter() - .chunks(shape.1 * shape.0) - .into_iter() - .map(|x| { - x.into_iter() - .chunks(shape.2) - .into_iter() - .map(|y| y.into_iter().collect::>()) - .collect::>() - }) - .collect::>(); - format!("{:?}", mem) - } - Shape::D4(shape) => { - let mem = arr - .iter() - .chunks(shape.2 * shape.1 * shape.3) - .into_iter() - .map(|x| { - x.into_iter() - .chunks(shape.2 * shape.3) - .into_iter() - .map(|y| { - y.into_iter() - .chunks(shape.3) - .into_iter() - .map(|z| z.into_iter().collect::>()) - .collect::>() - }) - .collect::>() - }) - .collect::>(); - format!("{:?}", mem) - } - Shape::D1(_) => { - format!("{:?}", arr) - } - } -} diff --git a/interp/src/primitives/stateful/math.rs b/interp/src/primitives/stateful/math.rs index c9bf0bc130..02563bf7dd 100644 --- a/interp/src/primitives/stateful/math.rs +++ b/interp/src/primitives/stateful/math.rs @@ -1,8 +1,9 @@ use super::super::prim_utils::{get_inputs, get_param, ShiftBuffer}; use super::super::primitive_traits::Named; -use super::super::{Entry, Primitive, Serializable}; +use super::super::Primitive; use crate::errors::{InterpreterError, InterpreterResult}; use crate::logging::{self, warn}; +use crate::serialization::{Entry, Serializable}; use crate::utils::PrintCode; use crate::validate; use crate::values::Value; diff --git a/interp/src/primitives/stateful/mem_utils.rs b/interp/src/primitives/stateful/mem_utils.rs index 7ec5093c53..1803e7d4c2 100644 --- a/interp/src/primitives/stateful/mem_utils.rs +++ b/interp/src/primitives/stateful/mem_utils.rs @@ -2,10 +2,8 @@ use calyx_ir as ir; use crate::{ errors::{InterpreterError, InterpreterResult}, - primitives::{ - prim_utils::{get_inputs, get_params}, - primitive_traits::Shape, - }, + primitives::prim_utils::{get_inputs, get_params}, + serialization::Shape, validate_friendly, values::Value, }; diff --git a/interp/src/primitives/stateful/memories.rs b/interp/src/primitives/stateful/memories.rs index 15ffa7a246..01b47e9573 100644 --- a/interp/src/primitives/stateful/memories.rs +++ b/interp/src/primitives/stateful/memories.rs @@ -4,8 +4,9 @@ use crate::{ errors::{InterpreterError, InterpreterResult}, primitives::{ prim_utils::{get_inputs, get_param, output}, - Entry, Named, Primitive, Serializable, + Named, Primitive, }, + serialization::{Entry, Serializable}, utils::construct_bindings, validate, validate_friendly, values::Value, diff --git a/interp/src/serialization.rs b/interp/src/serialization.rs new file mode 100644 index 0000000000..2d23042a9b --- /dev/null +++ b/interp/src/serialization.rs @@ -0,0 +1,282 @@ +use fraction::Fraction; +use itertools::Itertools; +use serde::Serialize; +use std::fmt::{Debug, Display}; + +use crate::{ + primitives::Primitive, structures::state_views::FullySerialize, + utils::PrintCode, values::Value, +}; + +/// An enum wrapping over a tuple representing the shape of a multi-dimensional +/// array +#[derive(Clone)] +pub enum Shape { + D1((usize,)), + D2((usize, usize)), + D3((usize, usize, usize)), + D4((usize, usize, usize, usize)), +} + +impl Shape { + fn is_1d(&self) -> bool { + matches!(self, Shape::D1(_)) + } + + pub(crate) fn dim_str(&self) -> String { + match self { + Shape::D1(_) => String::from("1D"), + Shape::D2(_) => String::from("2D"), + Shape::D3(_) => String::from("3D"), + Shape::D4(_) => String::from("4D"), + } + } +} +impl From for Shape { + fn from(u: usize) -> Self { + Shape::D1((u,)) + } +} +impl From<(usize,)> for Shape { + fn from(u: (usize,)) -> Self { + Shape::D1(u) + } +} +impl From<(usize, usize)> for Shape { + fn from(u: (usize, usize)) -> Self { + Shape::D2(u) + } +} + +impl From<(usize, usize, usize)> for Shape { + fn from(u: (usize, usize, usize)) -> Self { + Shape::D3(u) + } +} + +impl From<(usize, usize, usize, usize)> for Shape { + fn from(u: (usize, usize, usize, usize)) -> Self { + Shape::D4(u) + } +} + +/// A wrapper enum used during serialization. It represents either an unsigned integer, +/// or a signed integer and is serialized as the underlying integer. This also allows +/// mixed serialization of signed and unsigned values +#[derive(Serialize, Clone)] +#[serde(untagged)] +pub enum Entry { + U(u64), + I(i64), + Frac(Fraction), + Value(Value), +} + +impl From for Entry { + fn from(u: u64) -> Self { + Self::U(u) + } +} + +impl From for Entry { + fn from(i: i64) -> Self { + Self::I(i) + } +} + +impl From for Entry { + fn from(f: Fraction) -> Self { + Self::Frac(f) + } +} + +impl Entry { + pub fn from_val_code(val: &Value, code: &PrintCode) -> Self { + match code { + PrintCode::Unsigned => val.as_u64().into(), + PrintCode::Signed => val.as_i64().into(), + PrintCode::UFixed(f) => val.as_ufp(*f).into(), + PrintCode::SFixed(f) => val.as_sfp(*f).into(), + PrintCode::Binary => Entry::Value(val.clone()), + } + } +} + +impl Display for Entry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Entry::U(v) => write!(f, "{}", v), + Entry::I(v) => write!(f, "{}", v), + Entry::Frac(v) => write!(f, "{}", v), + Entry::Value(v) => write!(f, "{}", v), + } + } +} + +impl Debug for Entry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self) + } +} + +#[derive(Clone)] +pub enum Serializable { + Empty, + Val(Entry), + Array(Vec, Shape), + Full(FullySerialize), +} + +impl Serializable { + pub fn has_state(&self) -> bool { + !matches!(self, Serializable::Empty) + } +} + +impl Display for Serializable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Serializable::Empty => write!(f, ""), + Serializable::Val(v) => write!(f, "{}", v), + Serializable::Array(arr, shape) => { + write!(f, "{}", format_array(arr, shape)) + } + full @ Serializable::Full(_) => { + write!(f, "{}", serde_json::to_string(full).unwrap()) + } + } + } +} + +impl Serialize for Serializable { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Serializable::Empty => serializer.serialize_unit(), + Serializable::Val(u) => u.serialize(serializer), + Serializable::Array(arr, shape) => { + let arr: Vec<&Entry> = arr.iter().collect(); + if shape.is_1d() { + return arr.serialize(serializer); + } + // there's probably a better way to write this + match shape { + Shape::D2(shape) => { + let mem = arr + .iter() + .chunks(shape.1) + .into_iter() + .map(|x| x.into_iter().collect::>()) + .collect::>(); + mem.serialize(serializer) + } + Shape::D3(shape) => { + let mem = arr + .iter() + .chunks(shape.1 * shape.2) + .into_iter() + .map(|x| { + x.into_iter() + .chunks(shape.2) + .into_iter() + .map(|y| y.into_iter().collect::>()) + .collect::>() + }) + .collect::>(); + mem.serialize(serializer) + } + Shape::D4(shape) => { + let mem = arr + .iter() + .chunks(shape.2 * shape.1 * shape.3) + .into_iter() + .map(|x| { + x.into_iter() + .chunks(shape.2 * shape.3) + .into_iter() + .map(|y| { + y.into_iter() + .chunks(shape.3) + .into_iter() + .map(|z| { + z.into_iter() + .collect::>() + }) + .collect::>() + }) + .collect::>() + }) + .collect::>(); + mem.serialize(serializer) + } + Shape::D1(_) => unreachable!(), + } + } + Serializable::Full(s) => s.serialize(serializer), + } + } +} + +impl Serialize for dyn Primitive { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.serialize(None).serialize(serializer) + } +} + +fn format_array(arr: &[Entry], shape: &Shape) -> String { + match shape { + Shape::D2(shape) => { + let mem = arr + .iter() + .chunks(shape.1) + .into_iter() + .map(|x| x.into_iter().collect::>()) + .collect::>(); + format!("{:?}", mem) + } + Shape::D3(shape) => { + let mem = arr + .iter() + .chunks(shape.1 * shape.0) + .into_iter() + .map(|x| { + x.into_iter() + .chunks(shape.2) + .into_iter() + .map(|y| y.into_iter().collect::>()) + .collect::>() + }) + .collect::>(); + format!("{:?}", mem) + } + Shape::D4(shape) => { + let mem = arr + .iter() + .chunks(shape.2 * shape.1 * shape.3) + .into_iter() + .map(|x| { + x.into_iter() + .chunks(shape.2 * shape.3) + .into_iter() + .map(|y| { + y.into_iter() + .chunks(shape.3) + .into_iter() + .map(|z| z.into_iter().collect::>()) + .collect::>() + }) + .collect::>() + }) + .collect::>(); + format!("{:?}", mem) + } + Shape::D1(_) => { + format!("{:?}", arr) + } + } +} diff --git a/interp/src/structures/state_views.rs b/interp/src/structures/state_views.rs index f44adcfb6d..8f09a7bb78 100644 --- a/interp/src/structures/state_views.rs +++ b/interp/src/structures/state_views.rs @@ -13,7 +13,8 @@ use crate::{ environment::{InterpreterState, PrimitiveMap}, interpreter::ConstCell, interpreter_ir as iir, - primitives::{Entry, Primitive, Serializable}, + primitives::Primitive, + serialization::{Entry, Serializable}, utils::AsRaw, values::Value, };