diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 90e3644dd5..2a36cebd34 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -65,6 +65,7 @@ pub mod scalar;
 pub mod shape;
 mod sort;
 mod storage;
+pub mod streaming;
 mod strided_index;
 mod tensor;
 mod tensor_cat;
@@ -84,6 +85,7 @@ pub use indexer::IndexOp;
 pub use layout::Layout;
 pub use shape::{Shape, D};
 pub use storage::Storage;
+pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
 pub use strided_index::{StridedBlocks, StridedIndex};
 pub use tensor::{Tensor, TensorId};
 pub use variable::Var;
diff --git a/candle-core/src/streaming.rs b/candle-core/src/streaming.rs
new file mode 100644
index 0000000000..f70ec51e6c
--- /dev/null
+++ b/candle-core/src/streaming.rs
@@ -0,0 +1,206 @@
+use crate::{Result, Shape, Tensor};
+
+pub trait Dim: crate::shape::Dim + Copy {}
+impl<T: crate::shape::Dim + Copy> Dim for T {}
+
+/// A stream tensor is used in streaming modules. It can either contain an actual tensor or be
+/// empty.
+#[derive(Clone)]
+pub struct StreamTensor(Option<Tensor>);
+
+impl std::fmt::Debug for StreamTensor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self.0 {
+            Some(t) => write!(f, "{:?}", t.shape()),
+            None => write!(f, "Empty"),
+        }
+    }
+}
+
+impl std::convert::From<Option<Tensor>> for StreamTensor {
+    fn from(value: Option<Tensor>) -> Self {
+        Self(value)
+    }
+}
+
+impl std::convert::From<Tensor> for StreamTensor {
+    fn from(value: Tensor) -> Self {
+        Self(Some(value))
+    }
+}
+
+impl std::convert::From<()> for StreamTensor {
+    fn from(_value: ()) -> Self {
+        Self(None)
+    }
+}
+
+impl StreamTensor {
+    pub fn empty() -> Self {
+        Self(None)
+    }
+
+    pub fn from_tensor(tensor: Tensor) -> Self {
+        Self(Some(tensor))
+    }
+
+    pub fn shape(&self) -> Option<&Shape> {
+        self.0.as_ref().map(|t| t.shape())
+    }
+
+    pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
+        let xs = match (&self.0, &rhs.0) {
+            (Some(lhs), Some(rhs)) => {
+                let xs = Tensor::cat(&[lhs, rhs], dim)?;
+                Some(xs)
+            }
+            (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
+            (None, None) => None,
+        };
+        Ok(Self(xs))
+    }
+
+    pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
+        match &self.0 {
+            None => Ok(0),
+            Some(v) => v.dim(dim),
+        }
+    }
+
+    pub fn reset(&mut self) {
+        self.0 = None
+    }
+
+    pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<Self> {
+        let t = match &self.0 {
+            None => None,
+            Some(t) => {
+                let seq_len = t.dim(dim)?;
+                if seq_len <= offset {
+                    None
+                } else {
+                    let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
+                    Some(t)
+                }
+            }
+        };
+        Ok(Self(t))
+    }
+
+    /// Splits the stream tensor on the time axis `dim` with the first `lhs_len` elements
+    /// returned in the first output and the remaining elements in the second output.
+    pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
+        match &self.0 {
+            None => Ok((Self::empty(), Self::empty())),
+            Some(t) => {
+                let seq_len = t.dim(dim)?;
+                let lhs_len = usize::min(seq_len, lhs_len);
+                if lhs_len == 0 {
+                    Ok((Self::empty(), t.clone().into()))
+                } else {
+                    let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
+                    let rhs_len = seq_len - lhs_len;
+                    let rhs = if rhs_len == 0 {
+                        Self::empty()
+                    } else {
+                        Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
+                    };
+                    Ok((lhs, rhs))
+                }
+            }
+        }
+    }
+
+    pub fn as_option(&self) -> Option<&Tensor> {
+        self.0.as_ref()
+    }
+
+    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
+        match &self.0 {
+            None => Ok(Self::empty()),
+            Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
+        }
+    }
+}
+
+/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
+/// some internal buffering so that enough data has been received for the module to be able to
+/// perform some operations.
+pub trait StreamingModule {
+    // TODO: Should we also have a flush method?
+    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
+    fn reset_state(&mut self);
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum BinOp {
+    Add,
+    Mul,
+    Sub,
+    Div,
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamingBinOp {
+    prev_lhs: StreamTensor,
+    prev_rhs: StreamTensor,
+    pub op: BinOp,
+    pub dim: crate::D,
+}
+
+impl StreamingBinOp {
+    pub fn new(op: BinOp, dim: crate::D) -> Self {
+        Self {
+            prev_lhs: StreamTensor::empty(),
+            prev_rhs: StreamTensor::empty(),
+            op,
+            dim,
+        }
+    }
+
+    pub fn reset_state(&mut self) {
+        self.prev_lhs.reset();
+        self.prev_rhs.reset();
+    }
+
+    pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
+        match self.op {
+            BinOp::Add => Tensor::add(lhs, rhs),
+            BinOp::Mul => Tensor::mul(lhs, rhs),
+            BinOp::Sub => Tensor::sub(lhs, rhs),
+            BinOp::Div => Tensor::div(lhs, rhs),
+        }
+    }
+
+    pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
+        let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
+        let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
+        let lhs_len = lhs.seq_len(self.dim)?;
+        let rhs_len = rhs.seq_len(self.dim)?;
+        let common_len = usize::min(lhs_len, rhs_len);
+        let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
+        let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
+        let ys = match (lhs.0, rhs.0) {
+            (Some(lhs), Some(rhs)) => {
+                let ys = self.forward(&lhs, &rhs)?;
+                StreamTensor::from_tensor(ys)
+            }
+            (None, None) => StreamTensor::empty(),
+            (lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
+        };
+        self.prev_lhs = prev_lhs;
+        self.prev_rhs = prev_rhs;
+        Ok(ys)
+    }
+}
+
+/// Simple wrapper that doesn't do any buffering.
+pub struct Map<T: crate::Module>(T);
+
+impl<T: crate::Module> StreamingModule for Map<T> {
+    fn reset_state(&mut self) {}
+
+    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+        xs.apply(&self.0)
+    }
+}
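A minimal usage sketch (not part of the patch) of how `StreamTensor` chunks could be concatenated and split along the time axis. It assumes a `candle_core` build with this change applied; the shapes and values are illustrative only.

```rust
use candle_core::{DType, Device, Result, StreamTensor, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Two chunks arriving over time, shaped (batch, time, channels).
    let chunk1 = StreamTensor::from_tensor(Tensor::zeros((1, 3, 2), DType::F32, &dev)?);
    let chunk2 = StreamTensor::from_tensor(Tensor::ones((1, 5, 2), DType::F32, &dev)?);

    // Concatenate on the time axis (dim 1): 3 + 5 = 8 steps are now buffered.
    let buffered = chunk1.cat2(&chunk2, 1)?;
    assert_eq!(buffered.seq_len(1)?, 8);

    // Split off the first 6 steps for processing, keeping the remaining 2 pending.
    let (ready, pending) = buffered.split(1, 6)?;
    assert_eq!(ready.seq_len(1)?, 6);
    assert_eq!(pending.seq_len(1)?, 2);

    // An empty stream tensor contributes nothing to a concatenation.
    let empty = StreamTensor::empty();
    assert_eq!(empty.cat2(&ready, 1)?.seq_len(1)?, 6);
    Ok(())
}
```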
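Similarly, a sketch (again not part of the patch) of stepping a `StreamingBinOp` with unevenly sized chunks: the op only emits output for the overlapping time steps and buffers whichever side is ahead. Note that `BinOp` is reached through the `streaming` module, since only `StreamTensor`, `StreamingBinOp` and `StreamingModule` are re-exported at the crate root.

```rust
use candle_core::streaming::BinOp;
use candle_core::{DType, Device, Result, StreamTensor, StreamingBinOp, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut add = StreamingBinOp::new(BinOp::Add, D::Minus1);

    // First step: lhs carries 4 time steps but rhs only 2, so the op emits the
    // 2 overlapping steps and keeps the 2 extra lhs steps buffered internally.
    let lhs = StreamTensor::from_tensor(Tensor::ones((1, 4), DType::F32, &dev)?);
    let rhs = StreamTensor::from_tensor(Tensor::ones((1, 2), DType::F32, &dev)?);
    let ys = add.step(&lhs, &rhs)?;
    assert_eq!(ys.seq_len(D::Minus1)?, 2);

    // Second step: rhs catches up with 2 more steps while lhs sends nothing,
    // so the buffered lhs steps are consumed and 2 more output steps appear.
    let rhs: StreamTensor = Tensor::ones((1, 2), DType::F32, &dev)?.into();
    let ys = add.step(&StreamTensor::empty(), &rhs)?;
    assert_eq!(ys.seq_len(D::Minus1)?, 2);

    // Drop any buffered state, e.g. when starting a new sequence.
    add.reset_state();
    Ok(())
}
```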