diff --git a/pkg/apis/numaflow/v1alpha1/udf.go b/pkg/apis/numaflow/v1alpha1/udf.go index 573ddcbca..7a1a44c70 100644 --- a/pkg/apis/numaflow/v1alpha1/udf.go +++ b/pkg/apis/numaflow/v1alpha1/udf.go @@ -51,6 +51,9 @@ func (in UDF) getContainers(req getContainerReq) ([]corev1.Container, []corev1.C func (in UDF) getMainContainer(req getContainerReq) corev1.Container { if in.GroupBy == nil { + if req.executeRustBinary { + return containerBuilder{}.init(req).command(NumaflowRustBinary).args("processor", "--type="+string(VertexTypeMapUDF), "--isbsvc-type="+string(req.isbSvcType), "--rust").build() + } args := []string{"processor", "--type=" + string(VertexTypeMapUDF), "--isbsvc-type=" + string(req.isbSvcType)} return containerBuilder{}. init(req).args(args...).build() diff --git a/rust/Cargo.lock b/rust/Cargo.lock index a210284fc..beec59aa4 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1722,6 +1722,28 @@ dependencies = [ "uuid", ] +[[package]] +name = "numaflow" +version = "0.2.1" +source = "git+https://github.com/numaproj/numaflow-rs.git?rev=9ca9362ad511084501520e5a37d40cdcd0cdc9d9#9ca9362ad511084501520e5a37d40cdcd0cdc9d9" +dependencies = [ + "chrono", + "futures-util", + "hyper-util", + "prost 0.13.3", + "prost-types 0.13.3", + "serde", + "serde_json", + "thiserror 1.0.69", + "tokio", + "tokio-stream", + "tokio-util", + "tonic", + "tonic-build", + "tracing", + "uuid", +] + [[package]] name = "numaflow-core" version = "0.1.0" @@ -1736,7 +1758,7 @@ dependencies = [ "futures", "hyper-util", "kube", - "numaflow 0.1.1", + "numaflow 0.2.1", "numaflow-models", "numaflow-pb", "numaflow-pulsar", diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index b4688a135..38cabb704 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -49,7 +49,7 @@ async-nats = "0.38.0" [dev-dependencies] tempfile = "3.11.0" -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "ddd879588e11455921f1ca958ea2b3c076689293" } +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "9ca9362ad511084501520e5a37d40cdcd0cdc9d9" } pulsar = { version = "6.3.0", default-features = false, features = ["tokio-rustls-runtime"] } [build-dependencies] diff --git a/rust/numaflow-core/src/config/pipeline.rs b/rust/numaflow-core/src/config/pipeline.rs index 9509e8f4a..1368b0b32 100644 --- a/rust/numaflow-core/src/config/pipeline.rs +++ b/rust/numaflow-core/src/config/pipeline.rs @@ -14,6 +14,8 @@ use crate::config::components::source::SourceConfig; use crate::config::components::transformer::{TransformerConfig, TransformerType}; use crate::config::get_vertex_replica; use crate::config::pipeline::isb::{BufferReaderConfig, BufferWriterConfig}; +use crate::config::pipeline::map::MapMode; +use crate::config::pipeline::map::MapVtxConfig; use crate::error::Error; use crate::Result; @@ -23,6 +25,11 @@ const DEFAULT_LOOKBACK_WINDOW_IN_SECS: u16 = 120; const ENV_NUMAFLOW_SERVING_JETSTREAM_URL: &str = "NUMAFLOW_ISBSVC_JETSTREAM_URL"; const ENV_NUMAFLOW_SERVING_JETSTREAM_USER: &str = "NUMAFLOW_ISBSVC_JETSTREAM_USER"; const ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD: &str = "NUMAFLOW_ISBSVC_JETSTREAM_PASSWORD"; +const DEFAULT_GRPC_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; // 64 MB +const DEFAULT_MAP_SOCKET: &str = "/var/run/numaflow/map.sock"; +pub(crate) const DEFAULT_BATCH_MAP_SOCKET: &str = "/var/run/numaflow/batchmap.sock"; +pub(crate) const DEFAULT_STREAM_MAP_SOCKET: &str = "/var/run/numaflow/mapstream.sock"; +const DEFAULT_MAP_SERVER_INFO_FILE: &str = "/var/run/numaflow/mapper-server-info"; pub(crate) mod isb; @@ -69,6 +76,84 @@ pub(crate) struct SourceVtxConfig { pub(crate) transformer_config: Option, } +pub(crate) mod map { + use std::collections::HashMap; + + use numaflow_models::models::Udf; + + use crate::config::pipeline::{ + DEFAULT_GRPC_MAX_MESSAGE_SIZE, DEFAULT_MAP_SERVER_INFO_FILE, DEFAULT_MAP_SOCKET, + }; + use crate::error::Error; + + /// A map can be run in different modes. + #[derive(Debug, Clone, PartialEq)] + pub enum MapMode { + Unary, + Batch, + Stream, + } + + impl MapMode { + pub(crate) fn from_str(s: &str) -> Option { + match s { + "unary-map" => Some(MapMode::Unary), + "stream-map" => Some(MapMode::Stream), + "batch-map" => Some(MapMode::Batch), + _ => None, + } + } + } + + #[derive(Debug, Clone, PartialEq)] + pub(crate) struct MapVtxConfig { + pub(crate) concurrency: usize, + pub(crate) map_type: MapType, + pub(crate) map_mode: MapMode, + } + + #[derive(Debug, Clone, PartialEq)] + pub(crate) enum MapType { + UserDefined(UserDefinedConfig), + Builtin(BuiltinConfig), + } + + impl TryFrom> for MapType { + type Error = Error; + fn try_from(udf: Box) -> std::result::Result { + if let Some(builtin) = udf.builtin { + Ok(MapType::Builtin(BuiltinConfig { + name: builtin.name, + kwargs: builtin.kwargs, + args: builtin.args, + })) + } else if let Some(_container) = udf.container { + Ok(MapType::UserDefined(UserDefinedConfig { + grpc_max_message_size: DEFAULT_GRPC_MAX_MESSAGE_SIZE, + socket_path: DEFAULT_MAP_SOCKET.to_string(), + server_info_path: DEFAULT_MAP_SERVER_INFO_FILE.to_string(), + })) + } else { + Err(Error::Config("Invalid UDF".to_string())) + } + } + } + + #[derive(Debug, Clone, PartialEq)] + pub(crate) struct UserDefinedConfig { + pub grpc_max_message_size: usize, + pub socket_path: String, + pub server_info_path: String, + } + + #[derive(Debug, Clone, PartialEq)] + pub(crate) struct BuiltinConfig { + pub(crate) name: String, + pub(crate) kwargs: Option>, + pub(crate) args: Option>, + } +} + #[derive(Debug, Clone, PartialEq)] pub(crate) struct SinkVtxConfig { pub(crate) sink_config: SinkConfig, @@ -79,6 +164,7 @@ pub(crate) struct SinkVtxConfig { pub(crate) enum VertexType { Source(SourceVtxConfig), Sink(SinkVtxConfig), + Map(MapVtxConfig), } impl std::fmt::Display for VertexType { @@ -86,6 +172,7 @@ impl std::fmt::Display for VertexType { match self { VertexType::Source(_) => write!(f, "Source"), VertexType::Sink(_) => write!(f, "Sink"), + VertexType::Map(_) => write!(f, "Map"), } } } @@ -182,6 +269,12 @@ impl PipelineConfig { }, fb_sink_config, }) + } else if let Some(map) = vertex_obj.spec.udf { + VertexType::Map(MapVtxConfig { + concurrency: batch_size as usize, + map_type: map.try_into()?, + map_mode: MapMode::Unary, + }) } else { return Err(Error::Config( "Only source and sink are supported ATM".to_string(), @@ -283,7 +376,7 @@ impl PipelineConfig { Ok(PipelineConfig { batch_size: batch_size as usize, paf_concurrency: env::var("PAF_BATCH_SIZE") - .unwrap_or("30000".to_string()) + .unwrap_or((DEFAULT_BATCH_SIZE * 2).to_string()) .parse() .unwrap(), read_timeout: Duration::from_millis(timeout_in_ms as u64), @@ -301,11 +394,13 @@ impl PipelineConfig { #[cfg(test)] mod tests { + use numaflow_models::models::{Container, Function, Udf}; use numaflow_pulsar::source::PulsarSourceConfig; use super::*; use crate::config::components::sink::{BlackholeConfig, LogConfig, SinkType}; use crate::config::components::source::{GeneratorConfig, SourceType}; + use crate::config::pipeline::map::{MapType, UserDefinedConfig}; #[test] fn test_default_pipeline_config() { @@ -360,7 +455,7 @@ mod tests { vertex_name: "out".to_string(), replica: 0, batch_size: 500, - paf_concurrency: 30000, + paf_concurrency: 1000, read_timeout: Duration::from_secs(1), js_client_config: isb::jetstream::ClientConfig { url: "localhost:4222".to_string(), @@ -371,7 +466,7 @@ mod tests { name: "in".to_string(), reader_config: BufferReaderConfig { partitions: 1, - streams: vec![("default-simple-pipeline-out-0".into(), 0)], + streams: vec![("default-simple-pipeline-out-0", 0)], wip_ack_interval: Duration::from_secs(1), }, partitions: 0, @@ -407,7 +502,7 @@ mod tests { vertex_name: "in".to_string(), replica: 0, batch_size: 1000, - paf_concurrency: 30000, + paf_concurrency: 1000, read_timeout: Duration::from_secs(1), js_client_config: isb::jetstream::ClientConfig { url: "localhost:4222".to_string(), @@ -460,7 +555,7 @@ mod tests { vertex_name: "in".to_string(), replica: 0, batch_size: 50, - paf_concurrency: 30000, + paf_concurrency: 1000, read_timeout: Duration::from_secs(1), js_client_config: isb::jetstream::ClientConfig { url: "localhost:4222".to_string(), @@ -498,4 +593,120 @@ mod tests { assert_eq!(pipeline_config, expected); } + + #[test] + fn test_map_vertex_config_user_defined() { + let udf = Udf { + builtin: None, + container: Some(Box::from(Container { + args: None, + command: None, + env: None, + env_from: None, + image: None, + image_pull_policy: None, + liveness_probe: None, + ports: None, + readiness_probe: None, + resources: None, + security_context: None, + volume_mounts: None, + })), + group_by: None, + }; + + let map_type = MapType::try_from(Box::new(udf)).unwrap(); + assert!(matches!(map_type, MapType::UserDefined(_))); + + let map_vtx_config = MapVtxConfig { + concurrency: 10, + map_type, + map_mode: MapMode::Unary, + }; + + assert_eq!(map_vtx_config.concurrency, 10); + if let MapType::UserDefined(config) = map_vtx_config.map_type { + assert_eq!(config.grpc_max_message_size, DEFAULT_GRPC_MAX_MESSAGE_SIZE); + assert_eq!(config.socket_path, DEFAULT_MAP_SOCKET); + assert_eq!(config.server_info_path, DEFAULT_MAP_SERVER_INFO_FILE); + } else { + panic!("Expected UserDefined map type"); + } + } + + #[test] + fn test_map_vertex_config_builtin() { + let udf = Udf { + builtin: Some(Box::from(Function { + args: None, + kwargs: None, + name: "cat".to_string(), + })), + container: None, + group_by: None, + }; + + let map_type = MapType::try_from(Box::new(udf)).unwrap(); + assert!(matches!(map_type, MapType::Builtin(_))); + + let map_vtx_config = MapVtxConfig { + concurrency: 5, + map_type, + map_mode: MapMode::Unary, + }; + + assert_eq!(map_vtx_config.concurrency, 5); + if let MapType::Builtin(config) = map_vtx_config.map_type { + assert_eq!(config.name, "cat"); + assert!(config.kwargs.is_none()); + assert!(config.args.is_none()); + } else { + panic!("Expected Builtin map type"); + } + } + + #[test] + fn test_pipeline_config_load_map_vertex() { + let pipeline_cfg_base64 = "eyJtZXRhZGF0YSI6eyJuYW1lIjoic2ltcGxlLXBpcGVsaW5lLW1hcCIsIm5hbWVzcGFjZSI6ImRlZmF1bHQiLCJjcmVhdGlvblRpbWVzdGFtcCI6bnVsbH0sInNwZWMiOnsibmFtZSI6Im1hcCIsInVkZiI6eyJjb250YWluZXIiOnsidGVtcGxhdGUiOiJkZWZhdWx0In19LCJsaW1pdHMiOnsicmVhZEJhdGNoU2l6ZSI6NTAwLCJyZWFkVGltZW91dCI6IjFzIiwiYnVmZmVyTWF4TGVuZ3RoIjozMDAwMCwiYnVmZmVyVXNhZ2VMaW1pdCI6ODB9LCJzY2FsZSI6eyJtaW4iOjF9LCJwaXBlbGluZU5hbWUiOiJzaW1wbGUtcGlwZWxpbmUiLCJpbnRlclN0ZXBCdWZmZXJTZXJ2aWNlTmFtZSI6IiIsInJlcGxpY2FzIjowLCJmcm9tRWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoibWFwIiwiY29uZGl0aW9ucyI6bnVsbCwiZnJvbVZlcnRleFR5cGUiOiJTb3VyY2UiLCJmcm9tVmVydGV4UGFydGl0aW9uQ291bnQiOjEsImZyb21WZXJ0ZXhMaW1pdHMiOnsicmVhZEJhdGNoU2l6ZSI6NTAwLCJyZWFkVGltZW91dCI6IjFzIiwiYnVmZmVyTWF4TGVuZ3RoIjozMDAwMCwiYnVmZmVyVXNhZ2VMaW1pdCI6ODB9LCJ0b1ZlcnRleFR5cGUiOiJNYXAiLCJ0b1ZlcnRleFBhcnRpdGlvbkNvdW50IjoxLCJ0b1ZlcnRleExpbWl0cyI6eyJyZWFkQmF0Y2hTaXplIjo1MDAsInJlYWRUaW1lb3V0IjoiMXMiLCJidWZmZXJNYXhMZW5ndGgiOjMwMDAwLCJidWZmZXJVc2FnZUxpbWl0Ijo4MH19XSwid2F0ZXJtYXJrIjp7Im1heERlbGF5IjoiMHMifX0sInN0YXR1cyI6eyJwaGFzZSI6IiIsInJlcGxpY2FzIjowLCJkZXNpcmVkUmVwbGljYXMiOjAsImxhc3RTY2FsZWRBdCI6bnVsbH19"; + + let env_vars = [("NUMAFLOW_ISBSVC_JETSTREAM_URL", "localhost:4222")]; + let pipeline_config = + PipelineConfig::load(pipeline_cfg_base64.to_string(), env_vars).unwrap(); + + let expected = PipelineConfig { + pipeline_name: "simple-pipeline".to_string(), + vertex_name: "map".to_string(), + replica: 0, + batch_size: 500, + paf_concurrency: 1000, + read_timeout: Duration::from_secs(1), + js_client_config: isb::jetstream::ClientConfig { + url: "localhost:4222".to_string(), + user: None, + password: None, + }, + from_vertex_config: vec![FromVertexConfig { + name: "in".to_string(), + reader_config: BufferReaderConfig { + partitions: 1, + streams: vec![("default-simple-pipeline-map-0", 0)], + wip_ack_interval: Duration::from_secs(1), + }, + partitions: 0, + }], + to_vertex_config: vec![], + vertex_config: VertexType::Map(MapVtxConfig { + concurrency: 500, + map_type: MapType::UserDefined(UserDefinedConfig { + grpc_max_message_size: DEFAULT_GRPC_MAX_MESSAGE_SIZE, + socket_path: DEFAULT_MAP_SOCKET.to_string(), + server_info_path: DEFAULT_MAP_SERVER_INFO_FILE.to_string(), + }), + map_mode: MapMode::Unary, + }), + metrics_config: MetricsConfig::default(), + }; + + assert_eq!(pipeline_config, expected); + } } diff --git a/rust/numaflow-core/src/error.rs b/rust/numaflow-core/src/error.rs index e82a93e2d..0e499d068 100644 --- a/rust/numaflow-core/src/error.rs +++ b/rust/numaflow-core/src/error.rs @@ -16,6 +16,9 @@ pub enum Error { #[error("Transformer Error - {0}")] Transformer(String), + #[error("Mapper Error - {0}")] + Mapper(String), + #[error("Forwarder Error - {0}")] Forwarder(String), diff --git a/rust/numaflow-core/src/lib.rs b/rust/numaflow-core/src/lib.rs index 727a119f1..d65380f8d 100644 --- a/rust/numaflow-core/src/lib.rs +++ b/rust/numaflow-core/src/lib.rs @@ -51,6 +51,9 @@ mod pipeline; /// Tracker to track the completeness of message processing. mod tracker; +/// Map is a feature that allows users to execute custom code to transform their data. +mod mapper; + pub async fn run() -> Result<()> { let cln_token = CancellationToken::new(); let shutdown_cln_token = cln_token.clone(); diff --git a/rust/numaflow-core/src/mapper.rs b/rust/numaflow-core/src/mapper.rs new file mode 100644 index 000000000..56d0f51f3 --- /dev/null +++ b/rust/numaflow-core/src/mapper.rs @@ -0,0 +1,31 @@ +//! Numaflow supports flatmap operation through [map::MapHandle] an actor interface. +//! +//! The [map::MapHandle] orchestrates reading messages from the input stream, invoking the map operation, +//! and sending the mapped messages to the output stream. +//! +//! The [map::MapHandle] reads messages from the input stream and invokes the map operation based on the +//! mode: +//! - Unary: Concurrent operations controlled using permits and `tokio::spawn`. +//! - Batch: Synchronous operations, one batch at a time, followed by an invoke. +//! - Stream: Concurrent operations controlled using permits and `tokio::spawn`, followed by an +//! invoke. +//! +//! Error handling in unary and stream operations with concurrency N: +//! ```text +//! (Read) <----- (error_tx) <-------- + +//! | | +//! + -->-- (tokio map task 1) -->--- + +//! | | +//! + -->-- (tokio map task 2) -->--- + +//! | | +//! : : +//! | | +//! + -->-- (tokio map task N) -->--- + +//! ``` +//! In case of errors in unary/stream, tasks will write to the error channel (`error_tx`), and the `MapHandle` +//! will stop reading new requests and return an error. +//! +//! Error handling in batch operation is easier because it is synchronous and one batch at a time. If there +//! is an error, the [map::MapHandle] will stop reading new requests and return an error. + +pub(crate) mod map; diff --git a/rust/numaflow-core/src/mapper/map.rs b/rust/numaflow-core/src/mapper/map.rs new file mode 100644 index 000000000..8c279376a --- /dev/null +++ b/rust/numaflow-core/src/mapper/map.rs @@ -0,0 +1,1176 @@ +use crate::config::pipeline::map::MapMode; +use crate::error; +use crate::error::Error; +use crate::mapper::map::user_defined::{ + UserDefinedBatchMap, UserDefinedStreamMap, UserDefinedUnaryMap, +}; +use crate::message::Message; +use crate::tracker::TrackerHandle; +use numaflow_pb::clients::map::map_client::MapClient; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore}; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; +use tonic::transport::Channel; +pub(super) mod user_defined; + +/// UnaryActorMessage is a message that is sent to the UnaryMapperActor. +struct UnaryActorMessage { + message: Message, + respond_to: oneshot::Sender>>, +} + +/// BatchActorMessage is a message that is sent to the BatchMapperActor. +struct BatchActorMessage { + messages: Vec, + respond_to: Vec>>>, +} + +/// StreamActorMessage is a message that is sent to the StreamMapperActor. +struct StreamActorMessage { + message: Message, + respond_to: mpsc::Sender>, +} + +/// UnaryMapperActor is responsible for handling the unary map operation. +struct UnaryMapperActor { + receiver: mpsc::Receiver, + mapper: UserDefinedUnaryMap, +} + +impl UnaryMapperActor { + fn new(receiver: mpsc::Receiver, mapper: UserDefinedUnaryMap) -> Self { + Self { receiver, mapper } + } + + async fn handle_message(&mut self, msg: UnaryActorMessage) { + self.mapper.unary_map(msg.message, msg.respond_to).await; + } + + async fn run(mut self) { + while let Some(msg) = self.receiver.recv().await { + self.handle_message(msg).await; + } + } +} + +/// BatchMapActor is responsible for handling the batch map operation. +struct BatchMapActor { + receiver: mpsc::Receiver, + mapper: UserDefinedBatchMap, +} + +impl BatchMapActor { + fn new(receiver: mpsc::Receiver, mapper: UserDefinedBatchMap) -> Self { + Self { receiver, mapper } + } + + async fn handle_message(&mut self, msg: BatchActorMessage) { + self.mapper.batch_map(msg.messages, msg.respond_to).await; + } + + async fn run(mut self) { + while let Some(msg) = self.receiver.recv().await { + self.handle_message(msg).await; + } + } +} + +/// StreamMapActor is responsible for handling the stream map operation. +struct StreamMapActor { + receiver: mpsc::Receiver, + mapper: UserDefinedStreamMap, +} + +impl StreamMapActor { + fn new(receiver: mpsc::Receiver, mapper: UserDefinedStreamMap) -> Self { + Self { receiver, mapper } + } + + async fn handle_message(&mut self, msg: StreamActorMessage) { + self.mapper.stream_map(msg.message, msg.respond_to).await; + } + + async fn run(mut self) { + while let Some(msg) = self.receiver.recv().await { + self.handle_message(msg).await; + } + } +} + +/// ActorSender is an enum to store the handles to different types of actors. +#[derive(Clone)] +enum ActorSender { + Unary(mpsc::Sender), + Batch(mpsc::Sender), + Stream(mpsc::Sender), +} + +/// MapHandle is responsible for reading messages from the stream and invoke the map operation +/// on those messages and send the mapped messages to the output stream. +pub(crate) struct MapHandle { + batch_size: usize, + read_timeout: Duration, + concurrency: usize, + tracker: TrackerHandle, + actor_sender: ActorSender, + task_handle: JoinHandle<()>, +} + +/// Abort all the background tasks when the mapper is dropped. +impl Drop for MapHandle { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +/// Response channel size for streaming map. +const STREAMING_MAP_RESP_CHANNEL_SIZE: usize = 10; + +impl MapHandle { + /// Creates a new mapper with the given batch size, concurrency, client, and tracker handle. + /// It spawns the appropriate actor based on the map mode. + pub(crate) async fn new( + map_mode: MapMode, + batch_size: usize, + read_timeout: Duration, + concurrency: usize, + client: MapClient, + tracker_handle: TrackerHandle, + ) -> error::Result { + let task_handle; + + // Based on the map mode, spawn the appropriate map actor + // and store the sender handle in the actor_sender. + let actor_sender = match map_mode { + MapMode::Unary => { + let (sender, receiver) = mpsc::channel(batch_size); + let mapper_actor = UnaryMapperActor::new( + receiver, + UserDefinedUnaryMap::new(batch_size, client).await?, + ); + + let handle = tokio::spawn(async move { + mapper_actor.run().await; + }); + task_handle = handle; + ActorSender::Unary(sender) + } + MapMode::Batch => { + let (batch_sender, batch_receiver) = mpsc::channel(batch_size); + let batch_mapper_actor = BatchMapActor::new( + batch_receiver, + UserDefinedBatchMap::new(batch_size, client).await?, + ); + + let handle = tokio::spawn(async move { + batch_mapper_actor.run().await; + }); + task_handle = handle; + ActorSender::Batch(batch_sender) + } + MapMode::Stream => { + let (stream_sender, stream_receiver) = mpsc::channel(batch_size); + let stream_mapper_actor = StreamMapActor::new( + stream_receiver, + UserDefinedStreamMap::new(batch_size, client).await?, + ); + + let handle = tokio::spawn(async move { + stream_mapper_actor.run().await; + }); + task_handle = handle; + ActorSender::Stream(stream_sender) + } + }; + + Ok(Self { + actor_sender, + batch_size, + read_timeout, + concurrency, + tracker: tracker_handle, + task_handle, + }) + } + + /// Maps the input stream of messages and returns the output stream and the handle to the + /// background task. In case of critical errors it stops reading from the input stream and + /// returns the error using the join handle. + pub(crate) async fn streaming_map( + &self, + input_stream: ReceiverStream, + ) -> error::Result<(ReceiverStream, JoinHandle>)> { + let (output_tx, output_rx) = mpsc::channel(self.batch_size); + let (error_tx, mut error_rx) = mpsc::channel(1); + + let actor_handle = self.actor_sender.clone(); + let tracker = self.tracker.clone(); + let semaphore = Arc::new(Semaphore::new(self.concurrency)); + let batch_size = self.batch_size; + let read_timeout = self.read_timeout; + + let handle = tokio::spawn(async move { + let mut input_stream = input_stream; + + // based on the map mode, send the message to the appropriate actor handle. + match actor_handle { + ActorSender::Unary(map_handle) => loop { + // we need tokio select here because we have to listen to both the input stream + // and the error channel. If there is an error, we need to discard all the messages + // in the tracker and stop processing the input stream. + tokio::select! { + read_msg = input_stream.next() => { + if let Some(read_msg) = read_msg { + let permit = Arc::clone(&semaphore).acquire_owned().await.map_err(|e| Error::Mapper(format!("failed to acquire semaphore: {}", e)))?; + let error_tx = error_tx.clone(); + Self::unary( + map_handle.clone(), + permit, + read_msg, + output_tx.clone(), + tracker.clone(), + error_tx, + ).await; + } else { + break; + } + }, + Some(error) = error_rx.recv() => { + // if there is an error, discard all the messages in the tracker and return the error. + tracker.discard_all().await?; + return Err(error); + }, + } + }, + + ActorSender::Batch(map_handle) => { + let timeout_duration = read_timeout; + let chunked_stream = input_stream.chunks_timeout(batch_size, timeout_duration); + tokio::pin!(chunked_stream); + // we don't need to tokio spawn here because, unlike unary and stream, batch is a blocking operation, + // and we process one batch at a time. + while let Some(batch) = chunked_stream.next().await { + if !batch.is_empty() { + if let Err(e) = Self::batch( + map_handle.clone(), + batch, + output_tx.clone(), + tracker.clone(), + ) + .await + { + // if there is an error, discard all the messages in the tracker and return the error. + tracker.discard_all().await?; + return Err(e); + } + } + } + } + + ActorSender::Stream(map_handle) => loop { + // we need tokio select here because we have to listen to both the input stream + // and the error channel. If there is an error, we need to discard all the messages + // in the tracker and stop processing the input stream. + tokio::select! { + read_msg = input_stream.next() => { + if let Some(read_msg) = read_msg { + let permit = Arc::clone(&semaphore).acquire_owned().await.map_err(|e| Error::Mapper(format!("failed to acquire semaphore: {}", e)))?; + let error_tx = error_tx.clone(); + Self::stream( + map_handle.clone(), + permit, + read_msg, + output_tx.clone(), + tracker.clone(), + error_tx, + ).await; + } else { + break; + } + }, + Some(error) = error_rx.recv() => { + // if there is an error, discard all the messages in the tracker and return the error. + tracker.discard_all().await?; + return Err(error); + }, + } + }, + } + Ok(()) + }); + + Ok((ReceiverStream::new(output_rx), handle)) + } + + /// performs unary map operation on the given message and sends the mapped messages to the output + /// stream. It updates the tracker with the number of messages sent. If there are any errors, it + /// sends the error to the error channel. + /// + /// We use permit to limit the number of concurrent map unary operations, so that at any point in time + /// we don't have more than `concurrency` number of map operations running. + async fn unary( + map_handle: mpsc::Sender, + permit: OwnedSemaphorePermit, + read_msg: Message, + output_tx: mpsc::Sender, + tracker_handle: TrackerHandle, + error_tx: mpsc::Sender, + ) { + let output_tx = output_tx.clone(); + + // short-lived tokio spawns we don't need structured concurrency here + tokio::spawn(async move { + let _permit = permit; + + let (sender, receiver) = oneshot::channel(); + let msg = UnaryActorMessage { + message: read_msg.clone(), + respond_to: sender, + }; + + if let Err(e) = map_handle.send(msg).await { + let _ = error_tx + .send(Error::Mapper(format!("failed to send message: {}", e))) + .await; + return; + } + + match receiver.await { + Ok(Ok(mut mapped_messages)) => { + // update the tracker with the number of messages sent and send the mapped messages + if let Err(e) = tracker_handle + .update( + read_msg.id.offset.clone(), + mapped_messages.len() as u32, + true, + ) + .await + { + error_tx.send(e).await.expect("failed to send error"); + return; + } + for mapped_message in mapped_messages.drain(..) { + output_tx + .send(mapped_message) + .await + .expect("failed to send response"); + } + } + Ok(Err(e)) => { + error_tx.send(e).await.expect("failed to send error"); + } + Err(e) => { + error_tx + .send(Error::Mapper(format!("failed to receive message: {}", e))) + .await + .expect("failed to send error"); + } + } + }); + } + + /// performs batch map operation on the given batch of messages and sends the mapped messages to + /// the output stream. It updates the tracker with the number of messages sent. + async fn batch( + map_handle: mpsc::Sender, + batch: Vec, + output_tx: mpsc::Sender, + tracker_handle: TrackerHandle, + ) -> error::Result<()> { + let (senders, receivers): (Vec<_>, Vec<_>) = + batch.iter().map(|_| oneshot::channel()).unzip(); + let msg = BatchActorMessage { + messages: batch, + respond_to: senders, + }; + + map_handle + .send(msg) + .await + .map_err(|e| Error::Mapper(format!("failed to send message: {}", e)))?; + + for receiver in receivers { + match receiver.await { + Ok(Ok(mut mapped_messages)) => { + let offset = mapped_messages.first().unwrap().id.offset.clone(); + tracker_handle + .update(offset.clone(), mapped_messages.len() as u32, true) + .await?; + for mapped_message in mapped_messages.drain(..) { + output_tx + .send(mapped_message) + .await + .expect("failed to send response"); + } + } + Ok(Err(e)) => { + return Err(e); + } + Err(e) => { + return Err(Error::Mapper(format!("failed to receive message: {}", e))); + } + } + } + Ok(()) + } + + /// performs stream map operation on the given message and sends the mapped messages to the output + /// stream. It updates the tracker with the number of messages sent. If there are any errors, + /// it sends the error to the error channel. + /// + /// We use permit to limit the number of concurrent map unary operations, so that at any point in time + /// we don't have more than `concurrency` number of map operations running. + async fn stream( + map_handle: mpsc::Sender, + permit: OwnedSemaphorePermit, + read_msg: Message, + output_tx: mpsc::Sender, + tracker_handle: TrackerHandle, + error_tx: mpsc::Sender, + ) { + let output_tx = output_tx.clone(); + + tokio::spawn(async move { + let _permit = permit; + + let (sender, mut receiver) = mpsc::channel(STREAMING_MAP_RESP_CHANNEL_SIZE); + let msg = StreamActorMessage { + message: read_msg.clone(), + respond_to: sender, + }; + + if let Err(e) = map_handle.send(msg).await { + let _ = error_tx + .send(Error::Mapper(format!("failed to send message: {}", e))) + .await; + return; + } + + while let Some(result) = receiver.recv().await { + match result { + Ok(mapped_message) => { + let offset = mapped_message.id.offset.clone(); + if let Err(e) = tracker_handle.update(offset.clone(), 1, false).await { + error_tx.send(e).await.expect("failed to send error"); + return; + } + if let Err(e) = output_tx.send(mapped_message).await { + error_tx + .send(Error::Mapper(format!("failed to send message: {}", e))) + .await + .expect("failed to send error"); + return; + } + } + Err(e) => { + error_tx.send(e).await.expect("failed to send error"); + return; + } + } + } + + if let Err(e) = tracker_handle.update(read_msg.id.offset, 0, true).await { + error_tx.send(e).await.expect("failed to send error"); + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Result; + use std::time::Duration; + + use crate::message::{MessageID, Offset, StringOffset}; + use crate::shared::grpc::create_rpc_channel; + use numaflow::mapstream; + use numaflow::{batchmap, map}; + use numaflow_pb::clients::map::map_client::MapClient; + use tempfile::TempDir; + use tokio::sync::mpsc::Sender; + use tokio::sync::oneshot; + + struct SimpleMapper; + + #[tonic::async_trait] + impl map::Mapper for SimpleMapper { + async fn map(&self, input: map::MapRequest) -> Vec { + let message = map::Message::new(input.value) + .keys(input.keys) + .tags(vec!["test".to_string()]); + vec![message] + } + } + + #[tokio::test] + async fn mapper_operations() -> Result<()> { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("map.sock"); + let server_info_file = tmp_dir.path().join("map-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + map::Server::new(SimpleMapper) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + let tracker_handle = TrackerHandle::new(); + + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let mapper = MapHandle::new( + MapMode::Unary, + 500, + Duration::from_millis(1000), + 10, + client, + tracker_handle.clone(), + ) + .await?; + + let message = Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(Offset::String(crate::message::StringOffset::new( + "0".to_string(), + 0, + ))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + let (output_tx, mut output_rx) = mpsc::channel(10); + + let semaphore = Arc::new(Semaphore::new(10)); + let permit = semaphore.acquire_owned().await.unwrap(); + let (error_tx, mut error_rx) = mpsc::channel(1); + + let ActorSender::Unary(input_tx) = mapper.actor_sender.clone() else { + panic!("Expected Unary actor sender"); + }; + + MapHandle::unary( + input_tx, + permit, + message, + output_tx, + tracker_handle, + error_tx, + ) + .await; + + // check for errors + assert!(error_rx.recv().await.is_none()); + + let mapped_message = output_rx.recv().await.unwrap(); + assert_eq!(mapped_message.value, "hello"); + + // we need to drop the mapper, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } + + #[tokio::test] + async fn test_map_stream() -> Result<()> { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("map.sock"); + let server_info_file = tmp_dir.path().join("map-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + map::Server::new(SimpleMapper) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let tracker_handle = TrackerHandle::new(); + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let mapper = MapHandle::new( + MapMode::Unary, + 10, + Duration::from_millis(10), + 10, + client, + tracker_handle.clone(), + ) + .await?; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + for i in 0..5 { + let message = Message { + keys: Arc::from(vec![format!("key_{}", i)]), + tags: None, + value: format!("value_{}", i).into(), + offset: Some(Offset::String(StringOffset::new(i.to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: i.to_string().into(), + index: i, + }, + headers: Default::default(), + }; + input_tx.send(message).await.unwrap(); + } + drop(input_tx); + + let (output_stream, map_handle) = mapper.streaming_map(input_stream).await?; + + let mut output_rx = output_stream.into_inner(); + + for i in 0..5 { + let mapped_message = output_rx.recv().await.unwrap(); + assert_eq!(mapped_message.value, format!("value_{}", i)); + } + + // we need to drop the mapper, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + assert!( + map_handle.is_finished(), + "Expected mapper to have shut down" + ); + Ok(()) + } + + struct PanicCat; + + #[tonic::async_trait] + impl map::Mapper for PanicCat { + async fn map(&self, _input: map::MapRequest) -> Vec { + panic!("PanicCat panicked!"); + } + } + + #[tokio::test] + async fn test_map_stream_with_panic() -> Result<()> { + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("map.sock"); + let server_info_file = tmp_dir.path().join("map-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + map::Server::new(PanicCat) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start() + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let tracker_handle = TrackerHandle::new(); + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let mapper = MapHandle::new( + MapMode::Unary, + 500, + Duration::from_millis(1000), + 10, + client, + tracker_handle.clone(), + ) + .await?; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + let message = Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + input_tx.send(message).await.unwrap(); + + let (_output_stream, map_handle) = mapper.streaming_map(input_stream).await?; + + // Await the join handle and expect an error due to the panic + let result = map_handle.await.unwrap(); + assert!(result.is_err(), "Expected an error due to panic"); + assert!(result + .unwrap_err() + .to_string() + .contains("PanicCat panicked!")); + + // we need to drop the mapper, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } + + struct SimpleBatchMap; + + #[tonic::async_trait] + impl batchmap::BatchMapper for SimpleBatchMap { + async fn batchmap( + &self, + mut input: tokio::sync::mpsc::Receiver, + ) -> Vec { + let mut responses: Vec = Vec::new(); + while let Some(datum) = input.recv().await { + let mut response = batchmap::BatchResponse::from_id(datum.id); + response.append(batchmap::Message { + keys: Option::from(datum.keys), + value: datum.value, + tags: None, + }); + responses.push(response); + } + responses + } + } + + #[tokio::test] + async fn batch_mapper_operations() -> Result<()> { + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("batch_map.sock"); + let server_info_file = tmp_dir.path().join("batch_map-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + batchmap::Server::new(SimpleBatchMap) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + let tracker_handle = TrackerHandle::new(); + + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let mapper = MapHandle::new( + MapMode::Batch, + 500, + Duration::from_millis(1000), + 10, + client, + tracker_handle.clone(), + ) + .await?; + + let messages = vec![ + Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }, + Message { + keys: Arc::from(vec!["second".into()]), + tags: None, + value: "world".into(), + offset: Some(Offset::String(StringOffset::new("1".to_string(), 1))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "1".to_string().into(), + index: 1, + }, + headers: Default::default(), + }, + ]; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + for message in messages { + input_tx.send(message).await.unwrap(); + } + drop(input_tx); + + let (output_stream, map_handle) = mapper.streaming_map(input_stream).await?; + let mut output_rx = output_stream.into_inner(); + + let mapped_message1 = output_rx.recv().await.unwrap(); + assert_eq!(mapped_message1.value, "hello"); + + let mapped_message2 = output_rx.recv().await.unwrap(); + assert_eq!(mapped_message2.value, "world"); + + // we need to drop the mapper, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + assert!( + map_handle.is_finished(), + "Expected mapper to have shut down" + ); + Ok(()) + } + + struct PanicBatchMap; + + #[tonic::async_trait] + impl batchmap::BatchMapper for PanicBatchMap { + async fn batchmap( + &self, + _input: mpsc::Receiver, + ) -> Vec { + panic!("PanicBatchMap panicked!"); + } + } + + #[tokio::test] + async fn test_batch_map_with_panic() -> Result<()> { + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("batch_map_panic.sock"); + let server_info_file = tmp_dir.path().join("batch_map_panic-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + batchmap::Server::new(PanicBatchMap) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let tracker_handle = TrackerHandle::new(); + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let mapper = MapHandle::new( + MapMode::Batch, + 500, + Duration::from_millis(1000), + 10, + client, + tracker_handle.clone(), + ) + .await?; + + let messages = vec![ + Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }, + Message { + keys: Arc::from(vec!["second".into()]), + tags: None, + value: "world".into(), + offset: Some(Offset::String(StringOffset::new("1".to_string(), 1))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "1".to_string().into(), + index: 1, + }, + headers: Default::default(), + }, + ]; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + for message in messages { + input_tx.send(message).await.unwrap(); + } + drop(input_tx); + + let (_output_stream, map_handle) = mapper.streaming_map(input_stream).await?; + + // Await the join handle and expect an error due to the panic + let result = map_handle.await.unwrap(); + assert!(result.is_err(), "Expected an error due to panic"); + + // we need to drop the mapper, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } + + struct FlatmapStream; + + #[tonic::async_trait] + impl mapstream::MapStreamer for FlatmapStream { + async fn map_stream( + &self, + input: mapstream::MapStreamRequest, + tx: Sender, + ) { + let payload_str = String::from_utf8(input.value).unwrap_or_default(); + let splits: Vec<&str> = payload_str.split(',').collect(); + + for split in splits { + let message = mapstream::Message::new(split.as_bytes().to_vec()) + .keys(input.keys.clone()) + .tags(vec![]); + if tx.send(message).await.is_err() { + break; + } + } + } + } + + #[tokio::test] + async fn map_stream_operations() -> Result<()> { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("map_stream.sock"); + let server_info_file = tmp_dir.path().join("map_stream-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let _handle = tokio::spawn(async move { + mapstream::Server::new(FlatmapStream) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + let tracker_handle = TrackerHandle::new(); + + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let mapper = MapHandle::new( + MapMode::Stream, + 500, + Duration::from_millis(1000), + 10, + client, + tracker_handle.clone(), + ) + .await?; + + let message = Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "test,map,stream".into(), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + input_tx.send(message).await.unwrap(); + drop(input_tx); + + let (mut output_stream, map_handle) = mapper.streaming_map(input_stream).await?; + + let mut responses = vec![]; + while let Some(response) = output_stream.next().await { + responses.push(response); + } + + assert_eq!(responses.len(), 3); + // convert the bytes value to string and compare + let values: Vec = responses + .iter() + .map(|r| String::from_utf8(Vec::from(r.value.clone())).unwrap()) + .collect(); + assert_eq!(values, vec!["test", "map", "stream"]); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + map_handle.is_finished(), + "Expected mapper to have shut down" + ); + Ok(()) + } + + struct PanicFlatmapStream; + + #[tonic::async_trait] + impl mapstream::MapStreamer for PanicFlatmapStream { + async fn map_stream( + &self, + _input: mapstream::MapStreamRequest, + _tx: Sender, + ) { + panic!("PanicFlatmapStream panicked!"); + } + } + + #[tokio::test] + async fn map_stream_panic_case() -> Result<()> { + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("map_stream_panic.sock"); + let server_info_file = tmp_dir.path().join("map_stream_panic-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + mapstream::Server::new(PanicFlatmapStream) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let client = MapClient::new(create_rpc_channel(sock_file).await?); + let tracker_handle = TrackerHandle::new(); + let mapper = MapHandle::new( + MapMode::Stream, + 500, + Duration::from_millis(1000), + 10, + client, + tracker_handle, + ) + .await?; + + let message = Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "panic".into(), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + input_tx.send(message).await.unwrap(); + + let (_output_stream, map_handle) = mapper.streaming_map(input_stream).await?; + + // Await the join handle and expect an error due to the panic + let result = map_handle.await.unwrap(); + assert!(result.is_err(), "Expected an error due to panic"); + assert!(result + .unwrap_err() + .to_string() + .contains("PanicFlatmapStream panicked!")); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(mapper); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } +} diff --git a/rust/numaflow-core/src/mapper/map/user_defined.rs b/rust/numaflow-core/src/mapper/map/user_defined.rs new file mode 100644 index 000000000..6bc816c40 --- /dev/null +++ b/rust/numaflow-core/src/mapper/map/user_defined.rs @@ -0,0 +1,721 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use numaflow_pb::clients::map::{self, map_client::MapClient, MapRequest, MapResponse}; +use tokio::sync::Mutex; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::Channel; +use tonic::{Request, Streaming}; +use tracing::error; + +use crate::config::get_vertex_name; +use crate::error::{Error, Result}; +use crate::message::{Message, MessageID, Offset}; + +type ResponseSenderMap = + Arc>>)>>>; + +type StreamResponseSenderMap = + Arc>)>>>; + +struct ParentMessageInfo { + offset: Offset, + event_time: DateTime, + headers: HashMap, +} + +/// UserDefinedUnaryMap is a grpc client that sends unary requests to the map server +/// and forwards the responses. +pub(in crate::mapper) struct UserDefinedUnaryMap { + read_tx: mpsc::Sender, + senders: ResponseSenderMap, + task_handle: tokio::task::JoinHandle<()>, +} + +/// Abort the background task that receives responses when the UserDefinedBatchMap is dropped. +impl Drop for UserDefinedUnaryMap { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +impl UserDefinedUnaryMap { + /// Performs handshake with the server and creates a new UserDefinedMap. + pub(in crate::mapper) async fn new( + batch_size: usize, + mut client: MapClient, + ) -> Result { + let (read_tx, read_rx) = mpsc::channel(batch_size); + let resp_stream = create_response_stream(read_tx.clone(), read_rx, &mut client).await?; + + // map to track the oneshot sender for each request along with the message info + let sender_map = Arc::new(Mutex::new(HashMap::new())); + + // background task to receive responses from the server and send them to the appropriate + // oneshot sender based on the message id + let task_handle = tokio::spawn(Self::receive_unary_responses( + Arc::clone(&sender_map), + resp_stream, + )); + + let mapper = Self { + read_tx, + senders: sender_map, + task_handle, + }; + + Ok(mapper) + } + + /// receive responses from the server and gets the corresponding oneshot response sender from the map + /// and sends the response. + async fn receive_unary_responses( + sender_map: ResponseSenderMap, + mut resp_stream: Streaming, + ) { + while let Some(resp) = match resp_stream.message().await { + Ok(message) => message, + Err(e) => { + let error = Error::Mapper(format!("failed to receive map response: {}", e)); + let mut senders = sender_map.lock().await; + for (_, (_, sender)) in senders.drain() { + let _ = sender.send(Err(error.clone())); + } + None + } + } { + process_response(&sender_map, resp).await + } + } + + /// Handles the incoming message and sends it to the server for mapping. + pub(in crate::mapper) async fn unary_map( + &mut self, + message: Message, + respond_to: oneshot::Sender>>, + ) { + let key = message.offset.clone().unwrap().to_string(); + let msg_info = ParentMessageInfo { + offset: message.offset.clone().expect("offset can never be none"), + event_time: message.event_time, + headers: message.headers.clone(), + }; + + self.senders + .lock() + .await + .insert(key, (msg_info, respond_to)); + + self.read_tx + .send(message.into()) + .await + .expect("failed to send message"); + } +} + +/// UserDefinedBatchMap is a grpc client that sends batch requests to the map server +/// and forwards the responses. +pub(in crate::mapper) struct UserDefinedBatchMap { + read_tx: mpsc::Sender, + senders: ResponseSenderMap, + task_handle: tokio::task::JoinHandle<()>, +} + +/// Abort the background task that receives responses when the UserDefinedBatchMap is dropped. +impl Drop for UserDefinedBatchMap { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +impl UserDefinedBatchMap { + /// Performs handshake with the server and creates a new UserDefinedMap. + pub(in crate::mapper) async fn new( + batch_size: usize, + mut client: MapClient, + ) -> Result { + let (read_tx, read_rx) = mpsc::channel(batch_size); + let resp_stream = create_response_stream(read_tx.clone(), read_rx, &mut client).await?; + + // map to track the oneshot response sender for each request along with the message info + let sender_map = Arc::new(Mutex::new(HashMap::new())); + + // background task to receive responses from the server and send them to the appropriate + // oneshot response sender based on the id + let task_handle = tokio::spawn(Self::receive_batch_responses( + Arc::clone(&sender_map), + resp_stream, + )); + + let mapper = Self { + read_tx, + senders: sender_map, + task_handle, + }; + Ok(mapper) + } + + /// receive responses from the server and gets the corresponding oneshot response sender from the map + /// and sends the response. + async fn receive_batch_responses( + sender_map: ResponseSenderMap, + mut resp_stream: Streaming, + ) { + while let Some(resp) = match resp_stream.message().await { + Ok(message) => message, + Err(e) => { + let error = Error::Mapper(format!("failed to receive map response: {}", e)); + let mut senders = sender_map.lock().await; + for (_, (_, sender)) in senders.drain() { + sender + .send(Err(error.clone())) + .expect("failed to send error response"); + } + None + } + } { + if let Some(map::TransmissionStatus { eot: true }) = resp.status { + if !sender_map.lock().await.is_empty() { + error!("received EOT but not all responses have been received"); + } + continue; + } + + process_response(&sender_map, resp).await + } + } + + /// Handles the incoming message and sends it to the server for mapping. + pub(in crate::mapper) async fn batch_map( + &mut self, + messages: Vec, + respond_to: Vec>>>, + ) { + for (message, respond_to) in messages.into_iter().zip(respond_to) { + let key = message.offset.clone().unwrap().to_string(); + let msg_info = ParentMessageInfo { + offset: message.offset.clone().expect("offset can never be none"), + event_time: message.event_time, + headers: message.headers.clone(), + }; + + self.senders + .lock() + .await + .insert(key, (msg_info, respond_to)); + self.read_tx + .send(message.into()) + .await + .expect("failed to send message"); + } + + // send eot request + self.read_tx + .send(MapRequest { + request: None, + id: "".to_string(), + handshake: None, + status: Some(map::TransmissionStatus { eot: true }), + }) + .await + .expect("failed to send eot request"); + } +} + +/// Processes the response from the server and sends it to the appropriate oneshot sender +/// based on the message id entry in the map. +async fn process_response(sender_map: &ResponseSenderMap, resp: MapResponse) { + let msg_id = resp.id; + if let Some((msg_info, sender)) = sender_map.lock().await.remove(&msg_id) { + let mut response_messages = vec![]; + for (i, result) in resp.results.into_iter().enumerate() { + let message = Message { + id: MessageID { + vertex_name: get_vertex_name().to_string().into(), + index: i as i32, + offset: msg_info.offset.to_string().into(), + }, + keys: Arc::from(result.keys), + tags: Some(Arc::from(result.tags)), + value: result.value.into(), + offset: Some(msg_info.offset.clone()), + event_time: msg_info.event_time, + headers: msg_info.headers.clone(), + }; + response_messages.push(message); + } + sender + .send(Ok(response_messages)) + .expect("failed to send response"); + } +} + +/// Performs handshake with the server and returns the response stream to receive responses. +async fn create_response_stream( + read_tx: mpsc::Sender, + read_rx: mpsc::Receiver, + client: &mut MapClient, +) -> Result> { + let handshake_request = MapRequest { + request: None, + id: "".to_string(), + handshake: Some(map::Handshake { sot: true }), + status: None, + }; + + read_tx + .send(handshake_request) + .await + .map_err(|e| Error::Mapper(format!("failed to send handshake request: {}", e)))?; + + let mut resp_stream = client + .map_fn(Request::new(ReceiverStream::new(read_rx))) + .await? + .into_inner(); + + let handshake_response = resp_stream.message().await?.ok_or(Error::Mapper( + "failed to receive handshake response".to_string(), + ))?; + + if handshake_response.handshake.map_or(true, |h| !h.sot) { + return Err(Error::Mapper("invalid handshake response".to_string())); + } + + Ok(resp_stream) +} + +/// UserDefinedStreamMap is a grpc client that sends stream requests to the map server +pub(in crate::mapper) struct UserDefinedStreamMap { + read_tx: mpsc::Sender, + senders: StreamResponseSenderMap, + task_handle: tokio::task::JoinHandle<()>, +} + +/// Abort the background task that receives responses when the UserDefinedBatchMap is dropped. +impl Drop for UserDefinedStreamMap { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +impl UserDefinedStreamMap { + /// Performs handshake with the server and creates a new UserDefinedMap. + pub(in crate::mapper) async fn new( + batch_size: usize, + mut client: MapClient, + ) -> Result { + let (read_tx, read_rx) = mpsc::channel(batch_size); + let resp_stream = create_response_stream(read_tx.clone(), read_rx, &mut client).await?; + + // map to track the oneshot response sender for each request along with the message info + let sender_map = Arc::new(Mutex::new(HashMap::new())); + + // background task to receive responses from the server and send them to the appropriate + // mpsc sender based on the id + let task_handle = tokio::spawn(Self::receive_stream_responses( + Arc::clone(&sender_map), + resp_stream, + )); + + let mapper = Self { + read_tx, + senders: sender_map, + task_handle, + }; + Ok(mapper) + } + + /// receive responses from the server and gets the corresponding oneshot sender from the map + /// and sends the response. + async fn receive_stream_responses( + sender_map: StreamResponseSenderMap, + mut resp_stream: Streaming, + ) { + while let Some(resp) = match resp_stream.message().await { + Ok(message) => message, + Err(e) => { + let error = Error::Mapper(format!("failed to receive map response: {}", e)); + let mut senders = sender_map.lock().await; + for (_, (_, sender)) in senders.drain() { + let _ = sender.send(Err(error.clone())).await; + } + None + } + } { + let (message_info, response_sender) = sender_map + .lock() + .await + .remove(&resp.id) + .expect("map entry should always be present"); + + // once we get eot, we can drop the sender to let the callee + // know that we are done sending responses + if let Some(map::TransmissionStatus { eot: true }) = resp.status { + continue; + } + + for (i, result) in resp.results.into_iter().enumerate() { + let message = Message { + id: MessageID { + vertex_name: get_vertex_name().to_string().into(), + index: i as i32, + offset: message_info.offset.to_string().into(), + }, + keys: Arc::from(result.keys), + tags: Some(Arc::from(result.tags)), + value: result.value.into(), + offset: None, + event_time: message_info.event_time, + headers: message_info.headers.clone(), + }; + response_sender + .send(Ok(message)) + .await + .expect("failed to send response"); + } + + // Write the sender back to the map, because we need to send + // more responses for the same request + sender_map + .lock() + .await + .insert(resp.id, (message_info, response_sender)); + } + } + + /// Handles the incoming message and sends it to the server for mapping. + pub(in crate::mapper) async fn stream_map( + &mut self, + message: Message, + respond_to: mpsc::Sender>, + ) { + let key = message.offset.clone().unwrap().to_string(); + let msg_info = ParentMessageInfo { + offset: message.offset.clone().expect("offset can never be none"), + event_time: message.event_time, + headers: message.headers.clone(), + }; + + self.senders + .lock() + .await + .insert(key, (msg_info, respond_to)); + + self.read_tx + .send(message.into()) + .await + .expect("failed to send message"); + } +} + +#[cfg(test)] +mod tests { + use numaflow::mapstream; + use std::error::Error; + use std::sync::Arc; + use std::time::Duration; + + use numaflow::batchmap::Server; + use numaflow::{batchmap, map}; + use numaflow_pb::clients::map::map_client::MapClient; + use tempfile::TempDir; + + use crate::mapper::map::user_defined::{ + UserDefinedBatchMap, UserDefinedStreamMap, UserDefinedUnaryMap, + }; + use crate::message::{MessageID, StringOffset}; + use crate::shared::grpc::create_rpc_channel; + + struct Cat; + + #[tonic::async_trait] + impl map::Mapper for Cat { + async fn map(&self, input: map::MapRequest) -> Vec { + let message = map::Message::new(input.value).keys(input.keys).tags(vec![]); + vec![message] + } + } + + #[tokio::test] + async fn map_operations() -> Result<(), Box> { + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("map.sock"); + let server_info_file = tmp_dir.path().join("map-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + map::Server::new(Cat) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = + UserDefinedUnaryMap::new(500, MapClient::new(create_rpc_channel(sock_file).await?)) + .await?; + + let message = crate::message::Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(crate::message::Offset::String(StringOffset::new( + "0".to_string(), + 0, + ))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + let (tx, rx) = tokio::sync::oneshot::channel(); + + tokio::time::timeout(Duration::from_secs(2), client.unary_map(message, tx)) + .await + .unwrap(); + + let messages = rx.await.unwrap(); + assert!(messages.is_ok()); + assert_eq!(messages?.len(), 1); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } + + struct SimpleBatchMap; + + #[tonic::async_trait] + impl batchmap::BatchMapper for SimpleBatchMap { + async fn batchmap( + &self, + mut input: tokio::sync::mpsc::Receiver, + ) -> Vec { + let mut responses: Vec = Vec::new(); + while let Some(datum) = input.recv().await { + let mut response = batchmap::BatchResponse::from_id(datum.id); + response.append(batchmap::Message { + keys: Option::from(datum.keys), + value: datum.value, + tags: None, + }); + responses.push(response); + } + responses + } + } + + #[tokio::test] + async fn batch_map_operations() -> Result<(), Box> { + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("batch_map.sock"); + let server_info_file = tmp_dir.path().join("batch_map-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + Server::new(SimpleBatchMap) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = + UserDefinedBatchMap::new(500, MapClient::new(create_rpc_channel(sock_file).await?)) + .await?; + + let messages = vec![ + crate::message::Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(crate::message::Offset::String(StringOffset::new( + "0".to_string(), + 0, + ))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }, + crate::message::Message { + keys: Arc::from(vec!["second".into()]), + tags: None, + value: "world".into(), + offset: Some(crate::message::Offset::String(StringOffset::new( + "1".to_string(), + 1, + ))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "1".to_string().into(), + index: 1, + }, + headers: Default::default(), + }, + ]; + + let (tx1, rx1) = tokio::sync::oneshot::channel(); + let (tx2, rx2) = tokio::sync::oneshot::channel(); + + tokio::time::timeout( + Duration::from_secs(2), + client.batch_map(messages, vec![tx1, tx2]), + ) + .await + .unwrap(); + + let messages1 = rx1.await.unwrap(); + let messages2 = rx2.await.unwrap(); + + assert!(messages1.is_ok()); + assert!(messages2.is_ok()); + assert_eq!(messages1?.len(), 1); + assert_eq!(messages2?.len(), 1); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } + + struct FlatmapStream; + + #[tonic::async_trait] + impl mapstream::MapStreamer for FlatmapStream { + async fn map_stream( + &self, + input: mapstream::MapStreamRequest, + tx: tokio::sync::mpsc::Sender, + ) { + let payload_str = String::from_utf8(input.value).unwrap_or_default(); + let splits: Vec<&str> = payload_str.split(',').collect(); + + for split in splits { + let message = mapstream::Message::new(split.as_bytes().to_vec()) + .keys(input.keys.clone()) + .tags(vec![]); + if tx.send(message).await.is_err() { + break; + } + } + } + } + + #[tokio::test] + async fn map_stream_operations() -> Result<(), Box> { + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("map_stream.sock"); + let server_info_file = tmp_dir.path().join("map_stream-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + mapstream::Server::new(FlatmapStream) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start_with_shutdown(shutdown_rx) + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = + UserDefinedStreamMap::new(500, MapClient::new(create_rpc_channel(sock_file).await?)) + .await?; + + let message = crate::message::Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "test,map,stream".into(), + offset: Some(crate::message::Offset::String(StringOffset::new( + "0".to_string(), + 0, + ))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + let (tx, mut rx) = tokio::sync::mpsc::channel(3); + + tokio::time::timeout(Duration::from_secs(2), client.stream_map(message, tx)) + .await + .unwrap(); + + let mut responses = vec![]; + while let Some(response) = rx.recv().await { + responses.push(response.unwrap()); + } + + assert_eq!(responses.len(), 3); + // convert the bytes value to string and compare + let values: Vec = responses + .iter() + .map(|r| String::from_utf8(Vec::from(r.value.clone())).unwrap()) + .collect(); + assert_eq!(values, vec!["test", "map", "stream"]); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); + + shutdown_tx + .send(()) + .expect("failed to send shutdown signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } +} diff --git a/rust/numaflow-core/src/message.rs b/rust/numaflow-core/src/message.rs index 2b3ca0b5f..a33b4a704 100644 --- a/rust/numaflow-core/src/message.rs +++ b/rust/numaflow-core/src/message.rs @@ -8,6 +8,7 @@ use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use base64::Engine; use bytes::{Bytes, BytesMut}; use chrono::{DateTime, Utc}; +use numaflow_pb::clients::map::MapRequest; use numaflow_pb::clients::sink::sink_request::Request; use numaflow_pb::clients::sink::Status::{Failure, Fallback, Success}; use numaflow_pb::clients::sink::{sink_response, SinkRequest}; @@ -285,7 +286,10 @@ impl From for SourceTransformRequest { Self { request: Some( numaflow_pb::clients::sourcetransformer::source_transform_request::Request { - id: message.id.to_string(), + id: message + .offset + .expect("offset should be present") + .to_string(), keys: message.keys.to_vec(), value: message.value.to_vec(), event_time: prost_timestamp_from_utc(message.event_time), @@ -298,6 +302,23 @@ impl From for SourceTransformRequest { } } +impl From for MapRequest { + fn from(message: Message) -> Self { + Self { + request: Some(numaflow_pb::clients::map::map_request::Request { + keys: message.keys.to_vec(), + value: message.value.to_vec(), + event_time: prost_timestamp_from_utc(message.event_time), + watermark: None, + headers: message.headers, + }), + id: message.offset.unwrap().to_string(), + handshake: None, + status: None, + } + } +} + /// Convert [`read_response::Result`] to [`Message`] impl TryFrom for Message { type Error = Error; diff --git a/rust/numaflow-core/src/metrics.rs b/rust/numaflow-core/src/metrics.rs index 866e58f2c..fa79e457b 100644 --- a/rust/numaflow-core/src/metrics.rs +++ b/rust/numaflow-core/src/metrics.rs @@ -10,6 +10,7 @@ use axum::http::{Response, StatusCode}; use axum::response::IntoResponse; use axum::{routing::get, Router}; use axum_server::tls_rustls::RustlsConfig; +use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; @@ -116,6 +117,7 @@ pub(crate) enum PipelineContainerState { ), ), Sink((Option>, Option>)), + Map(Option>), } /// The global register of all metrics. @@ -689,6 +691,14 @@ async fn sidecar_livez(State(state): State) -> impl I } } } + PipelineContainerState::Map(map_client) => { + if let Some(mut map_client) = map_client { + if map_client.is_ready(Request::new(())).await.is_err() { + error!("Pipeline map client is not ready"); + return StatusCode::INTERNAL_SERVER_ERROR; + } + } + } }, } StatusCode::NO_CONTENT @@ -943,8 +953,8 @@ mod tests { async fn ack(&self, _: Vec) {} - async fn pending(&self) -> usize { - 0 + async fn pending(&self) -> Option { + Some(0) } async fn partitions(&self) -> Option> { diff --git a/rust/numaflow-core/src/monovertex.rs b/rust/numaflow-core/src/monovertex.rs index ba488cc8f..1518a3c9f 100644 --- a/rust/numaflow-core/src/monovertex.rs +++ b/rust/numaflow-core/src/monovertex.rs @@ -127,8 +127,8 @@ mod tests { async fn ack(&self, _: Vec) {} - async fn pending(&self) -> usize { - 0 + async fn pending(&self) -> Option { + Some(0) } async fn partitions(&self) -> Option> { diff --git a/rust/numaflow-core/src/monovertex/forwarder.rs b/rust/numaflow-core/src/monovertex/forwarder.rs index b04868048..51851e4ee 100644 --- a/rust/numaflow-core/src/monovertex/forwarder.rs +++ b/rust/numaflow-core/src/monovertex/forwarder.rs @@ -111,9 +111,9 @@ impl Forwarder { sink_writer_handle, ) { Ok((reader_result, transformer_result, sink_writer_result)) => { - reader_result?; - transformer_result?; sink_writer_result?; + transformer_result?; + reader_result?; Ok(()) } Err(e) => Err(Error::Forwarder(format!( @@ -206,9 +206,11 @@ mod tests { } } - async fn pending(&self) -> usize { - self.num - self.sent_count.load(Ordering::SeqCst) - + self.yet_to_ack.read().unwrap().len() + async fn pending(&self) -> Option { + Some( + self.num - self.sent_count.load(Ordering::SeqCst) + + self.yet_to_ack.read().unwrap().len(), + ) } async fn partitions(&self) -> Option> { diff --git a/rust/numaflow-core/src/pipeline.rs b/rust/numaflow-core/src/pipeline.rs index 434b9aa6d..d2cb77091 100644 --- a/rust/numaflow-core/src/pipeline.rs +++ b/rust/numaflow-core/src/pipeline.rs @@ -7,6 +7,7 @@ use tokio_util::sync::CancellationToken; use tracing::info; use crate::config::pipeline; +use crate::config::pipeline::map::MapVtxConfig; use crate::config::pipeline::{PipelineConfig, SinkVtxConfig, SourceVtxConfig}; use crate::metrics::{PipelineContainerState, UserDefinedContainerState}; use crate::pipeline::forwarder::source_forwarder; @@ -36,6 +37,10 @@ pub(crate) async fn start_forwarder( info!("Starting sink forwarder"); start_sink_forwarder(cln_token, config.clone(), sink.clone()).await?; } + pipeline::VertexType::Map(map) => { + info!("Starting map forwarder"); + start_map_forwarder(cln_token, config.clone(), map.clone()).await?; + } } Ok(()) } @@ -75,8 +80,8 @@ async fn start_source_forwarder( start_metrics_server( config.metrics_config.clone(), UserDefinedContainerState::Pipeline(PipelineContainerState::Source(( - source_grpc_client.clone(), - transformer_grpc_client.clone(), + source_grpc_client, + transformer_grpc_client, ))), ) .await; @@ -94,6 +99,92 @@ async fn start_source_forwarder( Ok(()) } +async fn start_map_forwarder( + cln_token: CancellationToken, + config: PipelineConfig, + map_vtx_config: MapVtxConfig, +) -> Result<()> { + let js_context = create_js_context(config.js_client_config.clone()).await?; + + // Only the reader config of the first "from" vertex is needed, as all "from" vertices currently write + // to a common buffer, in the case of a join. + let reader_config = &config + .from_vertex_config + .first() + .ok_or_else(|| error::Error::Config("No from vertex config found".to_string()))? + .reader_config; + + // Create buffer writers and buffer readers + let mut forwarder_components = vec![]; + let mut mapper_grpc_client = None; + for stream in reader_config.streams.clone() { + let tracker_handle = TrackerHandle::new(); + + let buffer_reader = create_buffer_reader( + stream, + reader_config.clone(), + js_context.clone(), + tracker_handle.clone(), + config.batch_size, + ) + .await?; + + let (mapper, mapper_rpc_client) = create_components::create_mapper( + config.batch_size, + config.read_timeout, + map_vtx_config.clone(), + tracker_handle.clone(), + cln_token.clone(), + ) + .await?; + + if let Some(mapper_rpc_client) = mapper_rpc_client { + mapper_grpc_client = Some(mapper_rpc_client); + } + + let buffer_writer = create_buffer_writer( + &config, + js_context.clone(), + tracker_handle.clone(), + cln_token.clone(), + ) + .await; + forwarder_components.push((buffer_reader, buffer_writer, mapper)); + } + + start_metrics_server( + config.metrics_config.clone(), + UserDefinedContainerState::Pipeline(PipelineContainerState::Map(mapper_grpc_client)), + ) + .await; + + let mut forwarder_tasks = vec![]; + for (buffer_reader, buffer_writer, mapper) in forwarder_components { + info!(%buffer_reader, "Starting forwarder for buffer reader"); + let forwarder = forwarder::map_forwarder::MapForwarder::new( + buffer_reader, + mapper, + buffer_writer, + cln_token.clone(), + ) + .await; + let task = tokio::spawn(async move { forwarder.start().await }); + forwarder_tasks.push(task); + } + + let results = try_join_all(forwarder_tasks) + .await + .map_err(|e| error::Error::Forwarder(e.to_string()))?; + + for result in results { + error!(?result, "Forwarder task failed"); + result?; + } + + info!("All forwarders have stopped successfully"); + Ok(()) +} + async fn start_sink_forwarder( cln_token: CancellationToken, config: PipelineConfig, @@ -120,6 +211,7 @@ async fn start_sink_forwarder( reader_config.clone(), js_context.clone(), tracker_handle.clone(), + config.batch_size, ) .await?; buffer_readers.push(buffer_reader); @@ -159,17 +251,19 @@ async fn start_sink_forwarder( ) .await; - let task = tokio::spawn({ - let config = config.clone(); - async move { forwarder.start(config.clone()).await } - }); - + let task = tokio::spawn(async move { forwarder.start().await }); forwarder_tasks.push(task); } - try_join_all(forwarder_tasks) + let results = try_join_all(forwarder_tasks) .await .map_err(|e| error::Error::Forwarder(e.to_string()))?; + + for result in results { + error!(?result, "Forwarder task failed"); + result?; + } + info!("All forwarders have stopped successfully"); Ok(()) } @@ -194,6 +288,7 @@ async fn create_buffer_reader( reader_config: BufferReaderConfig, js_context: Context, tracker_handle: TrackerHandle, + batch_size: usize, ) -> Result { JetstreamReader::new( stream.0, @@ -201,6 +296,7 @@ async fn create_buffer_reader( js_context, reader_config, tracker_handle, + batch_size, ) .await } @@ -228,12 +324,15 @@ async fn create_js_context(config: pipeline::isb::jetstream::ClientConfig) -> Re #[cfg(test)] mod tests { + use crate::pipeline::pipeline::map::MapMode; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use async_nats::jetstream; use async_nats::jetstream::{consumer, stream}; + use numaflow::map; + use tempfile::TempDir; use tokio_stream::StreamExt; use super::*; @@ -242,6 +341,7 @@ mod tests { use crate::config::components::source::GeneratorConfig; use crate::config::components::source::SourceConfig; use crate::config::components::source::SourceType; + use crate::config::pipeline::map::{MapType, UserDefinedConfig}; use crate::config::pipeline::PipelineConfig; use crate::pipeline::pipeline::isb; use crate::pipeline::pipeline::isb::{BufferReaderConfig, BufferWriterConfig}; @@ -250,6 +350,8 @@ mod tests { use crate::pipeline::pipeline::{SinkVtxConfig, SourceVtxConfig}; use crate::pipeline::tests::isb::BufferFullStrategy::RetryUntilSuccess; + // e2e test for source forwarder, reads from generator and writes to + // multi-partitioned buffer. #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_forwarder_for_source_vertex() { @@ -389,6 +491,8 @@ mod tests { } } + // e2e test for sink forwarder, reads from multi-partitioned buffer and + // writes to sink. #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_forwarder_for_sink_vertex() { @@ -407,9 +511,6 @@ mod tests { const MESSAGE_COUNT: usize = 10; let mut consumers = vec![]; - // Create streams to which the generator source vertex we create later will forward - // messages to. The consumers created for the corresponding streams will be used to ensure - // that messages were actually written to the streams. for stream_name in &streams { let stream_name = *stream_name; // Delete stream if it exists @@ -546,4 +647,247 @@ mod tests { context.delete_stream(stream_name).await.unwrap(); } } + + struct SimpleCat; + + #[tonic::async_trait] + impl map::Mapper for SimpleCat { + async fn map(&self, input: map::MapRequest) -> Vec { + let message = map::Message::new(input.value) + .keys(input.keys) + .tags(vec!["test-forwarder".to_string()]); + vec![message] + } + } + + // e2e test for map forwarder, reads from multi-partitioned buffer, invokes map + // and writes to multi-partitioned buffer. + #[cfg(feature = "nats-tests")] + #[tokio::test] + async fn test_forwarder_for_map_vertex() { + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("map.sock"); + let server_info_file = tmp_dir.path().join("mapper-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let _handle = tokio::spawn(async move { + map::Server::new(SimpleCat) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start() + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + // Unique names for the streams we use in this test + let input_streams = vec![ + "default-test-forwarder-for-map-vertex-in-0", + "default-test-forwarder-for-map-vertex-in-1", + "default-test-forwarder-for-map-vertex-in-2", + "default-test-forwarder-for-map-vertex-in-3", + "default-test-forwarder-for-map-vertex-in-4", + ]; + + let output_streams = vec![ + "default-test-forwarder-for-map-vertex-out-0", + "default-test-forwarder-for-map-vertex-out-1", + "default-test-forwarder-for-map-vertex-out-2", + "default-test-forwarder-for-map-vertex-out-3", + "default-test-forwarder-for-map-vertex-out-4", + ]; + + let js_url = "localhost:4222"; + let client = async_nats::connect(js_url).await.unwrap(); + let context = jetstream::new(client); + + const MESSAGE_COUNT: usize = 10; + let mut input_consumers = vec![]; + let mut output_consumers = vec![]; + for stream_name in &input_streams { + let stream_name = *stream_name; + // Delete stream if it exists + let _ = context.delete_stream(stream_name).await; + let _stream = context + .get_or_create_stream(stream::Config { + name: stream_name.into(), + subjects: vec![stream_name.into()], + max_message_size: 64 * 1024, + max_messages: 10000, + ..Default::default() + }) + .await + .unwrap(); + + // Publish some messages into the stream + use chrono::{TimeZone, Utc}; + + use crate::message::{Message, MessageID, Offset, StringOffset}; + let message = Message { + keys: Arc::from(vec!["key1".to_string()]), + tags: None, + value: vec![1, 2, 3].into(), + offset: Some(Offset::String(StringOffset::new("123".to_string(), 0))), + event_time: Utc.timestamp_opt(1627846261, 0).unwrap(), + id: MessageID { + vertex_name: "vertex".to_string().into(), + offset: "123".to_string().into(), + index: 0, + }, + headers: HashMap::new(), + }; + let message: bytes::BytesMut = message.try_into().unwrap(); + + for _ in 0..MESSAGE_COUNT { + context + .publish(stream_name.to_string(), message.clone().into()) + .await + .unwrap() + .await + .unwrap(); + } + + let c: consumer::PullConsumer = context + .create_consumer_on_stream( + consumer::pull::Config { + name: Some(stream_name.to_string()), + ack_policy: consumer::AckPolicy::Explicit, + ..Default::default() + }, + stream_name, + ) + .await + .unwrap(); + + input_consumers.push((stream_name.to_string(), c)); + } + + // Create output streams and consumers + for stream_name in &output_streams { + let stream_name = *stream_name; + // Delete stream if it exists + let _ = context.delete_stream(stream_name).await; + let _stream = context + .get_or_create_stream(stream::Config { + name: stream_name.into(), + subjects: vec![stream_name.into()], + max_message_size: 64 * 1024, + max_messages: 1000, + ..Default::default() + }) + .await + .unwrap(); + + let c: consumer::PullConsumer = context + .create_consumer_on_stream( + consumer::pull::Config { + name: Some(stream_name.to_string()), + ack_policy: consumer::AckPolicy::Explicit, + ..Default::default() + }, + stream_name, + ) + .await + .unwrap(); + output_consumers.push((stream_name.to_string(), c)); + } + + let pipeline_config = PipelineConfig { + pipeline_name: "simple-map-pipeline".to_string(), + vertex_name: "in".to_string(), + replica: 0, + batch_size: 1000, + paf_concurrency: 1000, + read_timeout: Duration::from_secs(1), + js_client_config: isb::jetstream::ClientConfig { + url: "localhost:4222".to_string(), + user: None, + password: None, + }, + to_vertex_config: vec![ToVertexConfig { + name: "map-out".to_string(), + writer_config: BufferWriterConfig { + streams: output_streams + .iter() + .enumerate() + .map(|(i, stream_name)| ((*stream_name).to_string(), i as u16)) + .collect(), + partitions: 5, + max_length: 30000, + usage_limit: 0.8, + buffer_full_strategy: RetryUntilSuccess, + }, + conditions: None, + }], + from_vertex_config: vec![FromVertexConfig { + name: "map-in".to_string(), + reader_config: BufferReaderConfig { + partitions: 5, + streams: input_streams + .iter() + .enumerate() + .map(|(i, key)| (*key, i as u16)) + .collect(), + wip_ack_interval: Duration::from_secs(1), + }, + partitions: 0, + }], + vertex_config: VertexType::Map(MapVtxConfig { + concurrency: 10, + map_type: MapType::UserDefined(UserDefinedConfig { + grpc_max_message_size: 4 * 1024 * 1024, + socket_path: sock_file.to_str().unwrap().to_string(), + server_info_path: server_info_file.to_str().unwrap().to_string(), + }), + map_mode: MapMode::Unary, + }), + metrics_config: MetricsConfig { + metrics_server_listen_port: 2469, + lag_check_interval_in_secs: 5, + lag_refresh_interval_in_secs: 3, + lookback_window_in_secs: 120, + }, + }; + + let cancellation_token = CancellationToken::new(); + let forwarder_task = tokio::spawn({ + let cancellation_token = cancellation_token.clone(); + async move { + start_forwarder(cancellation_token, pipeline_config) + .await + .unwrap(); + } + }); + + // Wait for a few messages to be forwarded + tokio::time::sleep(Duration::from_secs(3)).await; + cancellation_token.cancel(); + // token cancellation is not aborting the forwarder since we fetch messages from jetstream + // as a stream of messages (not using `consumer.batch()`). + // See `JetstreamReader::start` method in src/pipeline/isb/jetstream/reader.rs + //forwarder_task.await.unwrap(); + forwarder_task.abort(); + + // make sure we have mapped and written all messages to downstream + let mut written_count = 0; + for (_, mut stream_consumer) in output_consumers { + written_count += stream_consumer.info().await.unwrap().num_pending; + } + assert_eq!(written_count, (MESSAGE_COUNT * input_streams.len()) as u64); + + // make sure all the upstream messages are read and acked + for (_, mut stream_consumer) in input_consumers { + let con_info = stream_consumer.info().await.unwrap(); + assert_eq!(con_info.num_pending, 0); + assert_eq!(con_info.num_ack_pending, 0); + } + + // Delete all streams created in this test + for stream_name in input_streams.iter().chain(output_streams.iter()) { + context.delete_stream(stream_name).await.unwrap(); + } + } } diff --git a/rust/numaflow-core/src/pipeline/forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder.rs index e87a15ef4..3fb39e5a7 100644 --- a/rust/numaflow-core/src/pipeline/forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder.rs @@ -35,6 +35,10 @@ /// the Write is User-defined Sink or builtin. pub(crate) mod sink_forwarder; +/// Forwarder specific to Mapper where Reader is ISB, UDF is User-defined Mapper, +/// Write is ISB. +pub(crate) mod map_forwarder; + /// Source where the Reader is builtin or User-defined Source, Write is ISB, /// with an optional Transformer. pub(crate) mod source_forwarder; diff --git a/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs new file mode 100644 index 000000000..afc08a667 --- /dev/null +++ b/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs @@ -0,0 +1,63 @@ +use tokio_util::sync::CancellationToken; + +use crate::error::Error; +use crate::mapper::map::MapHandle; +use crate::pipeline::isb::jetstream::reader::JetstreamReader; +use crate::pipeline::isb::jetstream::writer::JetstreamWriter; +use crate::Result; + +/// Map forwarder is a component which starts a streaming reader, a mapper, and a writer +/// and manages the lifecycle of these components. +pub(crate) struct MapForwarder { + jetstream_reader: JetstreamReader, + mapper: MapHandle, + jetstream_writer: JetstreamWriter, + cln_token: CancellationToken, +} + +impl MapForwarder { + pub(crate) async fn new( + jetstream_reader: JetstreamReader, + mapper: MapHandle, + jetstream_writer: JetstreamWriter, + cln_token: CancellationToken, + ) -> Self { + Self { + jetstream_reader, + mapper, + jetstream_writer, + cln_token, + } + } + + pub(crate) async fn start(&self) -> Result<()> { + // Create a child cancellation token only for the reader so that we can stop the reader first + let reader_cancellation_token = self.cln_token.child_token(); + let (read_messages_stream, reader_handle) = self + .jetstream_reader + .streaming_read(reader_cancellation_token.clone()) + .await?; + + let (mapped_messages_stream, mapper_handle) = + self.mapper.streaming_map(read_messages_stream).await?; + + let writer_handle = self + .jetstream_writer + .streaming_write(mapped_messages_stream) + .await?; + + // Join the reader, mapper, and writer + match tokio::try_join!(reader_handle, mapper_handle, writer_handle) { + Ok((reader_result, mapper_result, writer_result)) => { + writer_result?; + mapper_result?; + reader_result?; + Ok(()) + } + Err(e) => Err(Error::Forwarder(format!( + "Error while joining reader, mapper, and writer: {:?}", + e + ))), + } + } +} diff --git a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs index 7153a4ff1..1d560e94e 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs @@ -1,6 +1,5 @@ use tokio_util::sync::CancellationToken; -use crate::config::pipeline::PipelineConfig; use crate::error::Error; use crate::pipeline::isb::jetstream::reader::JetstreamReader; use crate::sink::SinkWriter; @@ -27,12 +26,12 @@ impl SinkForwarder { } } - pub(crate) async fn start(&self, pipeline_config: PipelineConfig) -> Result<()> { + pub(crate) async fn start(&self) -> Result<()> { // Create a child cancellation token only for the reader so that we can stop the reader first let reader_cancellation_token = self.cln_token.child_token(); let (read_messages_stream, reader_handle) = self .jetstream_reader - .streaming_read(reader_cancellation_token.clone(), &pipeline_config) + .streaming_read(reader_cancellation_token.clone()) .await?; let sink_writer_handle = self @@ -43,8 +42,8 @@ impl SinkForwarder { // Join the reader and sink writer match tokio::try_join!(reader_handle, sink_writer_handle) { Ok((reader_result, sink_writer_result)) => { - reader_result?; sink_writer_result?; + reader_result?; Ok(()) } Err(e) => Err(Error::Forwarder(format!( diff --git a/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs index d494cbbd9..b81ddaf80 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs @@ -81,9 +81,9 @@ impl SourceForwarder { writer_handle, ) { Ok((reader_result, transformer_result, sink_writer_result)) => { - reader_result?; - transformer_result?; sink_writer_result?; + transformer_result?; + reader_result?; Ok(()) } Err(e) => Err(Error::Forwarder(format!( @@ -180,9 +180,11 @@ mod tests { } } - async fn pending(&self) -> usize { - self.num - self.sent_count.load(Ordering::SeqCst) - + self.yet_to_ack.read().unwrap().len() + async fn pending(&self) -> Option { + Some( + self.num - self.sent_count.load(Ordering::SeqCst) + + self.yet_to_ack.read().unwrap().len(), + ) } async fn partitions(&self) -> Option> { @@ -212,7 +214,7 @@ mod tests { let cln_token = CancellationToken::new(); let (src_shutdown_tx, src_shutdown_rx) = oneshot::channel(); - let tmp_dir = tempfile::TempDir::new().unwrap(); + let tmp_dir = TempDir::new().unwrap(); let sock_file = tmp_dir.path().join("source.sock"); let server_info_file = tmp_dir.path().join("source-server-info"); diff --git a/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs b/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs index 4513cb918..79b8572ef 100644 --- a/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs +++ b/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs @@ -12,8 +12,8 @@ use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; use tracing::{error, info}; +use crate::config::get_vertex_name; use crate::config::pipeline::isb::BufferReaderConfig; -use crate::config::pipeline::PipelineConfig; use crate::error::Error; use crate::message::{IntOffset, Message, MessageID, Offset, ReadAck}; use crate::metrics::{ @@ -33,6 +33,7 @@ pub(crate) struct JetstreamReader { config: BufferReaderConfig, consumer: PullConsumer, tracker_handle: TrackerHandle, + batch_size: usize, } impl JetstreamReader { @@ -42,6 +43,7 @@ impl JetstreamReader { js_ctx: Context, config: BufferReaderConfig, tracker_handle: TrackerHandle, + batch_size: usize, ) -> Result { let mut config = config; @@ -69,6 +71,7 @@ impl JetstreamReader { config: config.clone(), consumer, tracker_handle, + batch_size, }) } @@ -81,10 +84,8 @@ impl JetstreamReader { pub(crate) async fn streaming_read( &self, cancel_token: CancellationToken, - pipeline_config: &PipelineConfig, ) -> Result<(ReceiverStream, JoinHandle>)> { - let (messages_tx, messages_rx) = mpsc::channel(2 * pipeline_config.batch_size); - let pipeline_config = pipeline_config.clone(); + let (messages_tx, messages_rx) = mpsc::channel(2 * self.batch_size); let handle: JoinHandle> = tokio::spawn({ let consumer = self.consumer.clone(); @@ -143,20 +144,23 @@ impl JetstreamReader { } }; - message.offset = Some(Offset::Int(IntOffset::new( + let offset = Offset::Int(IntOffset::new( msg_info.stream_sequence, partition_idx, - ))); + )); - message.id = MessageID { - vertex_name: pipeline_config.vertex_name.clone().into(), - offset: msg_info.stream_sequence.to_string().into(), + let message_id = MessageID { + vertex_name: get_vertex_name().to_string().into(), + offset: offset.to_string().into(), index: 0, }; + message.offset = Some(offset.clone()); + message.id = message_id.clone(); + // Insert the message into the tracker and wait for the ack to be sent back. let (ack_tx, ack_rx) = oneshot::channel(); - tracker_handle.insert(message.id.offset.clone(), ack_tx).await?; + tracker_handle.insert(message_id.offset.clone(), ack_tx).await?; tokio::spawn(Self::start_work_in_progress( jetstream_message, @@ -164,9 +168,14 @@ impl JetstreamReader { config.wip_ack_interval, )); - messages_tx.send(message).await.map_err(|e| { - Error::ISB(format!("Error while sending message to channel: {:?}", e)) - })?; + if let Err(e) = messages_tx.send(message).await { + // nak the read message and return + tracker_handle.discard(message_id.offset.clone()).await?; + return Err(Error::ISB(format!( + "Failed to send message to receiver: {:?}", + e + ))); + } pipeline_metrics() .forwarder @@ -313,17 +322,14 @@ mod tests { context.clone(), buf_reader_config, TrackerHandle::new(), + 500, ) .await .unwrap(); - let pipeline_cfg_base64 = "eyJtZXRhZGF0YSI6eyJuYW1lIjoic2ltcGxlLXBpcGVsaW5lLW91dCIsIm5hbWVzcGFjZSI6ImRlZmF1bHQiLCJjcmVhdGlvblRpbWVzdGFtcCI6bnVsbH0sInNwZWMiOnsibmFtZSI6Im91dCIsInNpbmsiOnsiYmxhY2tob2xlIjp7fSwicmV0cnlTdHJhdGVneSI6eyJvbkZhaWx1cmUiOiJyZXRyeSJ9fSwibGltaXRzIjp7InJlYWRCYXRjaFNpemUiOjUwMCwicmVhZFRpbWVvdXQiOiIxcyIsImJ1ZmZlck1heExlbmd0aCI6MzAwMDAsImJ1ZmZlclVzYWdlTGltaXQiOjgwfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19LCJwaXBlbGluZU5hbWUiOiJzaW1wbGUtcGlwZWxpbmUiLCJpbnRlclN0ZXBCdWZmZXJTZXJ2aWNlTmFtZSI6IiIsInJlcGxpY2FzIjowLCJmcm9tRWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoib3V0IiwiY29uZGl0aW9ucyI6bnVsbCwiZnJvbVZlcnRleFR5cGUiOiJTb3VyY2UiLCJmcm9tVmVydGV4UGFydGl0aW9uQ291bnQiOjEsImZyb21WZXJ0ZXhMaW1pdHMiOnsicmVhZEJhdGNoU2l6ZSI6NTAwLCJyZWFkVGltZW91dCI6IjFzIiwiYnVmZmVyTWF4TGVuZ3RoIjozMDAwMCwiYnVmZmVyVXNhZ2VMaW1pdCI6ODB9LCJ0b1ZlcnRleFR5cGUiOiJTaW5rIiwidG9WZXJ0ZXhQYXJ0aXRpb25Db3VudCI6MSwidG9WZXJ0ZXhMaW1pdHMiOnsicmVhZEJhdGNoU2l6ZSI6NTAwLCJyZWFkVGltZW91dCI6IjFzIiwiYnVmZmVyTWF4TGVuZ3RoIjozMDAwMCwiYnVmZmVyVXNhZ2VMaW1pdCI6ODB9fV0sIndhdGVybWFyayI6eyJtYXhEZWxheSI6IjBzIn19LCJzdGF0dXMiOnsicGhhc2UiOiIiLCJyZXBsaWNhcyI6MCwiZGVzaXJlZFJlcGxpY2FzIjowLCJsYXN0U2NhbGVkQXQiOm51bGx9fQ==".to_string(); - - let env_vars = [("NUMAFLOW_ISBSVC_JETSTREAM_URL", "localhost:4222")]; - let pipeline_config = PipelineConfig::load(pipeline_cfg_base64, env_vars).unwrap(); let reader_cancel_token = CancellationToken::new(); let (mut js_reader_rx, js_reader_task) = js_reader - .streaming_read(reader_cancel_token.clone(), &pipeline_config) + .streaming_read(reader_cancel_token.clone()) .await .unwrap(); @@ -413,17 +419,14 @@ mod tests { context.clone(), buf_reader_config, tracker_handle.clone(), + 1, ) .await .unwrap(); - let pipeline_cfg_base64 = "eyJtZXRhZGF0YSI6eyJuYW1lIjoic2ltcGxlLXBpcGVsaW5lLW91dCIsIm5hbWVzcGFjZSI6ImRlZmF1bHQiLCJjcmVhdGlvblRpbWVzdGFtcCI6bnVsbH0sInNwZWMiOnsibmFtZSI6Im91dCIsInNpbmsiOnsiYmxhY2tob2xlIjp7fSwicmV0cnlTdHJhdGVneSI6eyJvbkZhaWx1cmUiOiJyZXRyeSJ9fSwibGltaXRzIjp7InJlYWRCYXRjaFNpemUiOjUwMCwicmVhZFRpbWVvdXQiOiIxcyIsImJ1ZmZlck1heExlbmd0aCI6MzAwMDAsImJ1ZmZlclVzYWdlTGltaXQiOjgwfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19LCJwaXBlbGluZU5hbWUiOiJzaW1wbGUtcGlwZWxpbmUiLCJpbnRlclN0ZXBCdWZmZXJTZXJ2aWNlTmFtZSI6IiIsInJlcGxpY2FzIjowLCJmcm9tRWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoib3V0IiwiY29uZGl0aW9ucyI6bnVsbCwiZnJvbVZlcnRleFR5cGUiOiJTb3VyY2UiLCJmcm9tVmVydGV4UGFydGl0aW9uQ291bnQiOjEsImZyb21WZXJ0ZXhMaW1pdHMiOnsicmVhZEJhdGNoU2l6ZSI6NTAwLCJyZWFkVGltZW91dCI6IjFzIiwiYnVmZmVyTWF4TGVuZ3RoIjozMDAwMCwiYnVmZmVyVXNhZ2VMaW1pdCI6ODB9LCJ0b1ZlcnRleFR5cGUiOiJTaW5rIiwidG9WZXJ0ZXhQYXJ0aXRpb25Db3VudCI6MSwidG9WZXJ0ZXhMaW1pdHMiOnsicmVhZEJhdGNoU2l6ZSI6NTAwLCJyZWFkVGltZW91dCI6IjFzIiwiYnVmZmVyTWF4TGVuZ3RoIjozMDAwMCwiYnVmZmVyVXNhZ2VMaW1pdCI6ODB9fV0sIndhdGVybWFyayI6eyJtYXhEZWxheSI6IjBzIn19LCJzdGF0dXMiOnsicGhhc2UiOiIiLCJyZXBsaWNhcyI6MCwiZGVzaXJlZFJlcGxpY2FzIjowLCJsYXN0U2NhbGVkQXQiOm51bGx9fQ==".to_string(); - - let env_vars = [("NUMAFLOW_ISBSVC_JETSTREAM_URL", "localhost:4222")]; - let pipeline_config = PipelineConfig::load(pipeline_cfg_base64, env_vars).unwrap(); let reader_cancel_token = CancellationToken::new(); let (mut js_reader_rx, js_reader_task) = js_reader - .streaming_read(reader_cancel_token.clone(), &pipeline_config) + .streaming_read(reader_cancel_token.clone()) .await .unwrap(); @@ -438,7 +441,7 @@ mod tests { event_time: Utc::now(), id: MessageID { vertex_name: "vertex".to_string().into(), - offset: format!("{}", i + 1).into(), + offset: format!("{}-0", i + 1).into(), index: i, }, headers: HashMap::new(), diff --git a/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs b/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs index a99d43856..e71335a57 100644 --- a/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs +++ b/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs @@ -12,6 +12,7 @@ use async_nats::jetstream::Context; use bytes::{Bytes, BytesMut}; use tokio::sync::Semaphore; use tokio::task::JoinHandle; +use tokio::time; use tokio::time::{sleep, Instant}; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; @@ -31,11 +32,11 @@ use crate::Result; const DEFAULT_RETRY_INTERVAL_MILLIS: u64 = 10; const DEFAULT_REFRESH_INTERVAL_SECS: u64 = 1; -#[derive(Clone)] /// Writes to JetStream ISB. Exposes both write and blocking methods to write messages. /// It accepts a cancellation token to stop infinite retries during shutdown. /// JetstreamWriter is one to many mapping of streams to write messages to. It also /// maintains the buffer usage metrics for each stream. +#[derive(Clone)] pub(crate) struct JetstreamWriter { config: Arc>, js_ctx: Context, @@ -183,6 +184,9 @@ impl JetstreamWriter { let mut messages_stream = messages_stream; let mut hash = DefaultHasher::new(); + let mut processed_msgs_count: usize = 0; + let mut last_logged_at = time::Instant::now(); + while let Some(message) = messages_stream.next().await { // if message needs to be dropped, ack and continue // TODO: add metric for dropped count @@ -241,6 +245,17 @@ impl JetstreamWriter { offset: message.id.offset, }) .await?; + + processed_msgs_count += 1; + if last_logged_at.elapsed().as_secs() >= 1 { + info!( + "Processed {} messages in {:?}", + processed_msgs_count, + std::time::Instant::now() + ); + processed_msgs_count = 0; + last_logged_at = Instant::now(); + } } Ok(()) }); diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index 9dd0f3959..bde1f6059 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; @@ -9,6 +10,10 @@ use tonic::transport::Channel; use crate::config::components::sink::{SinkConfig, SinkType}; use crate::config::components::source::{SourceConfig, SourceType}; use crate::config::components::transformer::TransformerConfig; +use crate::config::pipeline::map::{MapMode, MapType, MapVtxConfig}; +use crate::config::pipeline::{DEFAULT_BATCH_MAP_SOCKET, DEFAULT_STREAM_MAP_SOCKET}; +use crate::error::Error; +use crate::mapper::map::MapHandle; use crate::shared::grpc; use crate::shared::server_info::{sdk_server_info, ContainerType}; use crate::sink::{SinkClientType, SinkWriter, SinkWriterBuilder}; @@ -147,7 +152,7 @@ pub(crate) async fn create_sink_writer( } /// Creates a transformer if it is configured -pub async fn create_transformer( +pub(crate) async fn create_transformer( batch_size: usize, transformer_config: Option, tracker_handle: TrackerHandle, @@ -197,6 +202,66 @@ pub async fn create_transformer( Ok((None, None)) } +pub(crate) async fn create_mapper( + batch_size: usize, + read_timeout: Duration, + map_config: MapVtxConfig, + tracker_handle: TrackerHandle, + cln_token: CancellationToken, +) -> error::Result<(MapHandle, Option>)> { + match map_config.map_type { + MapType::UserDefined(mut config) => { + let server_info = + sdk_server_info(config.server_info_path.clone().into(), cln_token.clone()).await?; + + // based on the map mode that is set in the server info, we will override the socket path + // so that the clients can connect to the appropriate socket. + let config = match server_info.get_map_mode().unwrap_or(MapMode::Unary) { + MapMode::Unary => config, + MapMode::Batch => { + config.socket_path = DEFAULT_BATCH_MAP_SOCKET.into(); + config + } + MapMode::Stream => { + config.socket_path = DEFAULT_STREAM_MAP_SOCKET.into(); + config + } + }; + + let metric_labels = metrics::sdk_info_labels( + config::get_component_type().to_string(), + config::get_vertex_name().to_string(), + server_info.language.clone(), + server_info.version.clone(), + ContainerType::Sourcer.to_string(), + ); + metrics::global_metrics() + .sdk_info + .get_or_create(&metric_labels) + .set(1); + + let mut map_grpc_client = + MapClient::new(grpc::create_rpc_channel(config.socket_path.clone().into()).await?) + .max_encoding_message_size(config.grpc_max_message_size) + .max_decoding_message_size(config.grpc_max_message_size); + grpc::wait_until_mapper_ready(&cln_token, &mut map_grpc_client).await?; + Ok(( + MapHandle::new( + server_info.get_map_mode().unwrap_or(MapMode::Unary), + batch_size, + read_timeout, + map_config.concurrency, + map_grpc_client.clone(), + tracker_handle, + ) + .await?, + Some(map_grpc_client), + )) + } + MapType::Builtin(_) => Err(Error::Mapper("Builtin mapper is not supported".to_string())), + } +} + /// Creates a source type based on the configuration pub async fn create_source( batch_size: usize, @@ -311,8 +376,8 @@ mod tests { async fn ack(&self, _offset: Vec) {} - async fn pending(&self) -> usize { - 0 + async fn pending(&self) -> Option { + Some(0) } async fn partitions(&self) -> Option> { diff --git a/rust/numaflow-core/src/shared/grpc.rs b/rust/numaflow-core/src/shared/grpc.rs index 3500524f0..bedfd2e13 100644 --- a/rust/numaflow-core/src/shared/grpc.rs +++ b/rust/numaflow-core/src/shared/grpc.rs @@ -5,6 +5,7 @@ use axum::http::Uri; use backoff::retry::Retry; use backoff::strategy::fixed; use chrono::{DateTime, TimeZone, Timelike, Utc}; +use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; @@ -81,6 +82,26 @@ pub(crate) async fn wait_until_transformer_ready( Ok(()) } +/// Waits until the mapper server is ready, by doing health checks +pub(crate) async fn wait_until_mapper_ready( + cln_token: &CancellationToken, + client: &mut MapClient, +) -> error::Result<()> { + loop { + if cln_token.is_cancelled() { + return Err(Error::Forwarder( + "Cancellation token is cancelled".to_string(), + )); + } + match client.is_ready(Request::new(())).await { + Ok(_) => break, + Err(_) => sleep(Duration::from_secs(1)).await, + } + info!("Waiting for mapper client to be ready..."); + } + Ok(()) +} + pub(crate) fn prost_timestamp_from_utc(t: DateTime) -> Option { Some(Timestamp { seconds: t.timestamp(), diff --git a/rust/numaflow-core/src/shared/server_info.rs b/rust/numaflow-core/src/shared/server_info.rs index ee3b1c8d6..757636841 100644 --- a/rust/numaflow-core/src/shared/server_info.rs +++ b/rust/numaflow-core/src/shared/server_info.rs @@ -12,12 +12,14 @@ use tokio::time::sleep; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +use crate::config::pipeline::map::MapMode; use crate::error::{self, Error}; use crate::shared::server_info::version::SdkConstraints; // Constant to represent the end of the server info. // Equivalent to U+005C__END__. const END: &str = "U+005C__END__"; +const MAP_MODE_KEY: &str = "MAP_MODE"; #[derive(Debug, Eq, PartialEq, Clone, Hash)] pub enum ContainerType { @@ -88,6 +90,17 @@ pub(crate) struct ServerInfo { pub(crate) metadata: Option>, // Metadata is optional } +impl ServerInfo { + pub(crate) fn get_map_mode(&self) -> Option { + if let Some(metadata) = &self.metadata { + if let Some(map_mode) = metadata.get(MAP_MODE_KEY) { + return MapMode::from_str(map_mode); + } + } + None + } +} + /// sdk_server_info waits until the server info file is ready and check whether the /// server is compatible with Numaflow. pub(crate) async fn sdk_server_info( @@ -415,21 +428,25 @@ mod version { go_version_map.insert(ContainerType::SourceTransformer, "0.9.0-z".to_string()); go_version_map.insert(ContainerType::Sinker, "0.9.0-z".to_string()); go_version_map.insert(ContainerType::FbSinker, "0.9.0-z".to_string()); + go_version_map.insert(ContainerType::Mapper, "0.9.0-z".to_string()); let mut python_version_map = HashMap::new(); python_version_map.insert(ContainerType::Sourcer, "0.9.0rc100".to_string()); python_version_map.insert(ContainerType::SourceTransformer, "0.9.0rc100".to_string()); python_version_map.insert(ContainerType::Sinker, "0.9.0rc100".to_string()); python_version_map.insert(ContainerType::FbSinker, "0.9.0rc100".to_string()); + python_version_map.insert(ContainerType::Mapper, "0.9.0rc100".to_string()); let mut java_version_map = HashMap::new(); java_version_map.insert(ContainerType::Sourcer, "0.9.0-z".to_string()); java_version_map.insert(ContainerType::SourceTransformer, "0.9.0-z".to_string()); java_version_map.insert(ContainerType::Sinker, "0.9.0-z".to_string()); java_version_map.insert(ContainerType::FbSinker, "0.9.0-z".to_string()); + java_version_map.insert(ContainerType::Mapper, "0.9.0-z".to_string()); let mut rust_version_map = HashMap::new(); rust_version_map.insert(ContainerType::Sourcer, "0.1.0-z".to_string()); rust_version_map.insert(ContainerType::SourceTransformer, "0.1.0-z".to_string()); rust_version_map.insert(ContainerType::Sinker, "0.1.0-z".to_string()); rust_version_map.insert(ContainerType::FbSinker, "0.1.0-z".to_string()); + rust_version_map.insert(ContainerType::Mapper, "0.1.0-z".to_string()); let mut m = HashMap::new(); m.insert("go".to_string(), go_version_map); diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index 8be9d8549..4d280d372 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -247,8 +247,6 @@ impl Source { info!("Started streaming source with batch size: {}", batch_size); let handle = tokio::spawn(async move { - let mut processed_msgs_count: usize = 0; - let mut last_logged_at = time::Instant::now(); // this semaphore is used only if read-ahead is disabled. we hold this semaphore to // make sure we can read only if the current inflight ones are ack'ed. let semaphore = Arc::new(Semaphore::new(1)); @@ -312,7 +310,7 @@ impl Source { // insert the offset and the ack one shot in the tracker. tracker_handle - .insert(offset.to_string().into(), resp_ack_tx) + .insert(message.id.offset.clone(), resp_ack_tx) .await?; // store the ack one shot in the batch to invoke ack later. @@ -343,17 +341,6 @@ impl Source { None }, )); - - processed_msgs_count += n; - if last_logged_at.elapsed().as_secs() >= 1 { - info!( - "Processed {} messages in {:?}", - processed_msgs_count, - std::time::Instant::now() - ); - processed_msgs_count = 0; - last_logged_at = time::Instant::now(); - } } }); Ok((ReceiverStream::new(messages_rx), handle)) @@ -504,8 +491,8 @@ mod tests { } } - async fn pending(&self) -> usize { - self.yet_to_ack.read().unwrap().len() + async fn pending(&self) -> Option { + Some(self.yet_to_ack.read().unwrap().len()) } async fn partitions(&self) -> Option> { diff --git a/rust/numaflow-core/src/source/user_defined.rs b/rust/numaflow-core/src/source/user_defined.rs index 758f8a6fc..e5717c12a 100644 --- a/rust/numaflow-core/src/source/user_defined.rs +++ b/rust/numaflow-core/src/source/user_defined.rs @@ -292,8 +292,8 @@ mod tests { } } - async fn pending(&self) -> usize { - self.yet_to_ack.read().unwrap().len() + async fn pending(&self) -> Option { + Some(self.yet_to_ack.read().unwrap().len()) } async fn partitions(&self) -> Option> { diff --git a/rust/numaflow-core/src/tracker.rs b/rust/numaflow-core/src/tracker.rs index a4ef30e24..a8ccaca54 100644 --- a/rust/numaflow-core/src/tracker.rs +++ b/rust/numaflow-core/src/tracker.rs @@ -12,7 +12,6 @@ use std::collections::HashMap; use bytes::Bytes; use tokio::sync::{mpsc, oneshot}; -use tracing::warn; use crate::error::Error; use crate::message::ReadAck; @@ -43,6 +42,7 @@ enum ActorMessage { Discard { offset: String, }, + DiscardAll, // New variant for discarding all messages #[cfg(test)] IsEmpty { respond_to: oneshot::Sender, @@ -56,11 +56,10 @@ struct Tracker { receiver: mpsc::Receiver, } -/// Implementation of Drop for Tracker to send Nak for unacknowledged messages. impl Drop for Tracker { fn drop(&mut self) { - for (offset, entry) in self.entries.drain() { - warn!(?offset, "Sending Nak for unacknowledged message"); + // clear the entries from the map and send nak + for (_, entry) in self.entries.drain() { entry .ack_send .send(ReadAck::Nak) @@ -103,6 +102,9 @@ impl Tracker { ActorMessage::Discard { offset } => { self.handle_discard(offset); } + ActorMessage::DiscardAll => { + self.handle_discard_all().await; + } #[cfg(test)] ActorMessage::IsEmpty { respond_to } => { let is_empty = self.entries.is_empty(); @@ -118,7 +120,7 @@ impl Tracker { TrackerEntry { ack_send: respond_to, count: 0, - eof: false, + eof: true, }, ); } @@ -126,8 +128,18 @@ impl Tracker { /// Updates an existing entry in the tracker with the number of expected messages and EOF status. fn handle_update(&mut self, offset: String, count: u32, eof: bool) { if let Some(entry) = self.entries.get_mut(&offset) { - entry.count = count; + entry.count += count; entry.eof = eof; + // if the count is zero, we can send an ack immediately + // this is case where map stream will send eof true after + // receiving all the messages. + if entry.count == 0 { + let entry = self.entries.remove(&offset).unwrap(); + entry + .ack_send + .send(ReadAck::Ack) + .expect("Failed to send ack"); + } } } @@ -138,7 +150,7 @@ impl Tracker { if entry.count > 0 { entry.count -= 1; } - if entry.count == 0 || entry.eof { + if entry.count == 0 && entry.eof { entry .ack_send .send(ReadAck::Ack) @@ -158,6 +170,16 @@ impl Tracker { .expect("Failed to send nak"); } } + + /// Discards all entries from the tracker and sends a nak for each. + async fn handle_discard_all(&mut self) { + for (_, entry) in self.entries.drain() { + entry + .ack_send + .send(ReadAck::Nak) + .expect("Failed to send nak"); + } + } } /// TrackerHandle provides an interface to interact with the Tracker. @@ -231,6 +253,15 @@ impl TrackerHandle { Ok(()) } + /// Discards all messages from the Tracker and sends a nak for each. + pub(crate) async fn discard_all(&self) -> Result<()> { + let message = ActorMessage::DiscardAll; + self.sender + .send(message) + .await + .map_err(|e| Error::Tracker(format!("{:?}", e)))?; + Ok(()) + } /// Checks if the Tracker is empty. Used for testing to make sure all messages are acknowledged. #[cfg(test)] pub(crate) async fn is_empty(&self) -> Result { @@ -293,7 +324,7 @@ mod tests { // Update the message with a count of 3 handle - .update("offset1".to_string().into(), 3, false) + .update("offset1".to_string().into(), 3, true) .await .unwrap(); diff --git a/rust/numaflow-core/src/transformer.rs b/rust/numaflow-core/src/transformer.rs index 0b26a7e76..6f9298b7c 100644 --- a/rust/numaflow-core/src/transformer.rs +++ b/rust/numaflow-core/src/transformer.rs @@ -6,7 +6,6 @@ use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; use tonic::transport::Channel; -use tracing::error; use crate::error::Error; use crate::message::Message; @@ -15,7 +14,7 @@ use crate::tracker::TrackerHandle; use crate::transformer::user_defined::UserDefinedTransformer; use crate::Result; -/// User-Defined Transformer extends Numaflow to add custom sources supported outside the builtins. +/// User-Defined Transformer is a custom transformer that can be built by the user. /// /// [User-Defined Transformer]: https://numaflow.numaproj.io/user-guide/sources/transformer/overview/#build-your-own-transformer pub(crate) mod user_defined; @@ -60,13 +59,22 @@ impl TransformerActor { } } -/// StreamingTransformer, transforms messages in a streaming fashion. +/// Transformer, transforms messages in a streaming fashion. pub(crate) struct Transformer { batch_size: usize, sender: mpsc::Sender, concurrency: usize, tracker_handle: TrackerHandle, + task_handle: JoinHandle<()>, } + +/// Aborts the actor task when the transformer is dropped. +impl Drop for Transformer { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + impl Transformer { pub(crate) async fn new( batch_size: usize, @@ -80,7 +88,7 @@ impl Transformer { UserDefinedTransformer::new(batch_size, client).await?, ); - tokio::spawn(async move { + let task_handle = tokio::spawn(async move { transformer_actor.run().await; }); @@ -89,23 +97,25 @@ impl Transformer { concurrency, sender, tracker_handle, + task_handle, }) } /// Applies the transformation on the message and sends it to the next stage, it blocks if the /// concurrency limit is reached. - pub(crate) async fn transform( + async fn transform( transform_handle: mpsc::Sender, permit: OwnedSemaphorePermit, read_msg: Message, output_tx: mpsc::Sender, tracker_handle: TrackerHandle, - ) -> Result<()> { + error_tx: mpsc::Sender, + ) { // only if we have tasks < max_concurrency - let output_tx = output_tx.clone(); // invoke transformer and then wait for the one-shot + // short-lived tokio spawns we don't need structured concurrency here tokio::spawn(async move { let start_time = tokio::time::Instant::now(); let _permit = permit; @@ -117,32 +127,41 @@ impl Transformer { }; // invoke trf - transform_handle - .send(msg) - .await - .expect("failed to send message"); + if let Err(e) = transform_handle.send(msg).await { + let _ = error_tx + .send(Error::Transformer(format!("failed to send message: {}", e))) + .await; + return; + } // wait for one-shot match receiver.await { Ok(Ok(mut transformed_messages)) => { - tracker_handle + if let Err(e) = tracker_handle .update( read_msg.id.offset.clone(), transformed_messages.len() as u32, - false, + true, ) .await - .expect("failed to update tracker"); + { + let _ = error_tx.send(e).await; + return; + } for transformed_message in transformed_messages.drain(..) { let _ = output_tx.send(transformed_message).await; } } - Err(_) | Ok(Err(_)) => { - error!("Failed to transform message"); - tracker_handle - .discard(read_msg.id.offset.clone()) - .await - .expect("failed to discard tracker"); + Ok(Err(e)) => { + let _ = error_tx.send(e).await; + } + Err(e) => { + let _ = error_tx + .send(Error::Transformer(format!( + "failed to receive message: {}", + e + ))) + .await; } } monovertex_metrics() @@ -151,40 +170,59 @@ impl Transformer { .get_or_create(mvtx_forward_metric_labels()) .observe(start_time.elapsed().as_micros() as f64); }); - - Ok(()) } - /// Starts reading messages in the form of chunks and transforms them and - /// sends them to the next stage. + /// Starts the transformation of the stream of messages and returns the transformed stream. pub(crate) fn transform_stream( &self, input_stream: ReceiverStream, ) -> Result<(ReceiverStream, JoinHandle>)> { let (output_tx, output_rx) = mpsc::channel(self.batch_size); + // channel to transmit errors from the transformer tasks to the main task + let (error_tx, mut error_rx) = mpsc::channel(1); + let transform_handle = self.sender.clone(); let tracker_handle = self.tracker_handle.clone(); - // FIXME: batch_size should not be used, introduce a new config called udf concurrency let semaphore = Arc::new(Semaphore::new(self.concurrency)); let handle = tokio::spawn(async move { let mut input_stream = input_stream; - while let Some(read_msg) = input_stream.next().await { - let permit = Arc::clone(&semaphore).acquire_owned().await.map_err(|e| { - Error::Transformer(format!("failed to acquire semaphore: {}", e)) - })?; - - Self::transform( - transform_handle.clone(), - permit, - read_msg, - output_tx.clone(), - tracker_handle.clone(), - ) - .await?; + // we do a tokio::select! loop to handle the input stream and the error channel + // in case of any errors in the transformer tasks we need to shut down the mapper + // and discard all the messages in the tracker. + loop { + tokio::select! { + x = input_stream.next() => { + if let Some(read_msg) = x { + let permit = Arc::clone(&semaphore) + .acquire_owned() + .await + .map_err(|e| Error::Transformer(format!("failed to acquire semaphore: {}", e)))?; + + let error_tx = error_tx.clone(); + Self::transform( + transform_handle.clone(), + permit, + read_msg, + output_tx.clone(), + tracker_handle.clone(), + error_tx, + ).await; + } else { + break; + } + }, + Some(error) = error_rx.recv() => { + // discard all the messages in the tracker since it's a critical error, and + // we are shutting down + tracker_handle.discard_all().await?; + return Err(error); + }, + } } + Ok(()) }); @@ -202,6 +240,7 @@ mod tests { use tokio::sync::oneshot; use super::*; + use crate::message::StringOffset; use crate::message::{Message, MessageID, Offset}; use crate::shared::grpc::create_rpc_channel; @@ -248,10 +287,7 @@ mod tests { keys: Arc::from(vec!["first".into()]), tags: None, value: "hello".into(), - offset: Some(Offset::String(crate::message::StringOffset::new( - "0".to_string(), - 0, - ))), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), event_time: chrono::Utc::now(), id: MessageID { vertex_name: "vertex_name".to_string().into(), @@ -265,14 +301,19 @@ mod tests { let semaphore = Arc::new(Semaphore::new(10)); let permit = semaphore.acquire_owned().await.unwrap(); + let (error_tx, mut error_rx) = mpsc::channel(1); Transformer::transform( transformer.sender.clone(), permit, message, output_tx, tracker_handle, + error_tx, ) - .await?; + .await; + + // check for errors + assert!(error_rx.recv().await.is_none()); let transformed_message = output_rx.recv().await.unwrap(); assert_eq!(transformed_message.value, "hello"); @@ -325,10 +366,7 @@ mod tests { keys: Arc::from(vec![format!("key_{}", i)]), tags: None, value: format!("value_{}", i).into(), - offset: Some(Offset::String(crate::message::StringOffset::new( - i.to_string(), - 0, - ))), + offset: Some(Offset::String(StringOffset::new(i.to_string(), 0))), event_time: chrono::Utc::now(), id: MessageID { vertex_name: "vertex_name".to_string().into(), @@ -368,4 +406,78 @@ mod tests { ); Ok(()) } + + struct SimpleTransformerPanic; + + #[tonic::async_trait] + impl sourcetransform::SourceTransformer for SimpleTransformerPanic { + async fn transform( + &self, + _input: sourcetransform::SourceTransformRequest, + ) -> Vec { + panic!("SimpleTransformerPanic panicked!"); + } + } + + #[tokio::test] + async fn test_transform_stream_with_panic() -> Result<()> { + let tmp_dir = TempDir::new().unwrap(); + let sock_file = tmp_dir.path().join("sourcetransform.sock"); + let server_info_file = tmp_dir.path().join("sourcetransformer-server-info"); + + let server_info = server_info_file.clone(); + let server_socket = sock_file.clone(); + let handle = tokio::spawn(async move { + sourcetransform::Server::new(SimpleTransformerPanic) + .with_socket_file(server_socket) + .with_server_info_file(server_info) + .start() + .await + .expect("server failed"); + }); + + // wait for the server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + let tracker_handle = TrackerHandle::new(); + let client = SourceTransformClient::new(create_rpc_channel(sock_file).await?); + let transformer = Transformer::new(500, 10, client, tracker_handle.clone()).await?; + + let (input_tx, input_rx) = mpsc::channel(10); + let input_stream = ReceiverStream::new(input_rx); + + let message = Message { + keys: Arc::from(vec!["first".into()]), + tags: None, + value: "hello".into(), + offset: Some(Offset::String(StringOffset::new("0".to_string(), 0))), + event_time: chrono::Utc::now(), + id: MessageID { + vertex_name: "vertex_name".to_string().into(), + offset: "0".to_string().into(), + index: 0, + }, + headers: Default::default(), + }; + + input_tx.send(message).await.unwrap(); + + let (_output_stream, transform_handle) = transformer.transform_stream(input_stream)?; + + // Await the join handle and expect an error due to the panic + let result = transform_handle.await.unwrap(); + assert!(result.is_err(), "Expected an error due to panic"); + assert!(result.unwrap_err().to_string().contains("panic")); + + // we need to drop the transformer, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(transformer); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); + Ok(()) + } } diff --git a/rust/numaflow-core/src/transformer/user_defined.rs b/rust/numaflow-core/src/transformer/user_defined.rs index 9a82275ac..398d5a4bc 100644 --- a/rust/numaflow-core/src/transformer/user_defined.rs +++ b/rust/numaflow-core/src/transformer/user_defined.rs @@ -1,11 +1,11 @@ use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use numaflow_pb::clients::sourcetransformer::{ self, source_transform_client::SourceTransformClient, SourceTransformRequest, SourceTransformResponse, }; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Channel; use tonic::{Request, Streaming}; @@ -28,6 +28,14 @@ struct ParentMessageInfo { pub(super) struct UserDefinedTransformer { read_tx: mpsc::Sender, senders: ResponseSenderMap, + task_handle: tokio::task::JoinHandle<()>, +} + +/// Aborts the background task when the UserDefinedTransformer is dropped. +impl Drop for UserDefinedTransformer { + fn drop(&mut self) { + self.task_handle.abort(); + } } impl UserDefinedTransformer { @@ -65,15 +73,19 @@ impl UserDefinedTransformer { // map to track the oneshot sender for each request along with the message info let sender_map = Arc::new(Mutex::new(HashMap::new())); + // background task to receive responses from the server and send them to the appropriate + // oneshot sender based on the message id + let task_handle = tokio::spawn(Self::receive_responses( + Arc::clone(&sender_map), + resp_stream, + )); + let transformer = Self { read_tx, - senders: Arc::clone(&sender_map), + senders: sender_map, + task_handle, }; - // background task to receive responses from the server and send them to the appropriate - // oneshot sender based on the message id - tokio::spawn(Self::receive_responses(sender_map, resp_stream)); - Ok(transformer) } @@ -83,29 +95,32 @@ impl UserDefinedTransformer { sender_map: ResponseSenderMap, mut resp_stream: Streaming, ) { - while let Some(resp) = resp_stream - .message() - .await - .expect("failed to receive response") - { + while let Some(resp) = match resp_stream.message().await { + Ok(message) => message, + Err(e) => { + let error = + Error::Transformer(format!("failed to receive transformer response: {}", e)); + let mut senders = sender_map.lock().await; + for (_, (_, sender)) in senders.drain() { + let _ = sender.send(Err(error.clone())); + } + None + } + } { let msg_id = resp.id; - if let Some((msg_info, sender)) = sender_map - .lock() - .expect("map entry should always be present") - .remove(&msg_id) - { + if let Some((msg_info, sender)) = sender_map.lock().await.remove(&msg_id) { let mut response_messages = vec![]; for (i, result) in resp.results.into_iter().enumerate() { let message = Message { id: MessageID { vertex_name: get_vertex_name().to_string().into(), index: i as i32, - offset: msg_info.offset.to_string().into(), + offset: msg_info.offset.clone().to_string().into(), }, keys: Arc::from(result.keys), tags: Some(Arc::from(result.tags)), value: result.value.into(), - offset: None, + offset: Some(msg_info.offset.clone()), event_time: utc_from_timestamp(result.event_time), headers: msg_info.headers.clone(), }; @@ -124,7 +139,12 @@ impl UserDefinedTransformer { message: Message, respond_to: oneshot::Sender>>, ) { - let msg_id = message.id.to_string(); + let key = message + .offset + .clone() + .expect("offset should be present") + .to_string(); + let msg_info = ParentMessageInfo { offset: message.offset.clone().expect("offset can never be none"), headers: message.headers.clone(), @@ -132,10 +152,13 @@ impl UserDefinedTransformer { self.senders .lock() - .unwrap() - .insert(msg_id, (msg_info, respond_to)); + .await + .insert(key, (msg_info, respond_to)); - self.read_tx.send(message.into()).await.unwrap(); + self.read_tx + .send(message.into()) + .await + .expect("failed to send message"); } }