From 8d8340c725e39031ef56e04e603c3b9241f15399 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 6 Jan 2025 11:27:53 +0530 Subject: [PATCH] feat: Implement `Sourcer` traits for serving source (#2301) Signed-off-by: Sreekanth Signed-off-by: Vigith Maurice Co-authored-by: Vigith Maurice --- rust/Cargo.lock | 212 +++++++++--- rust/Cargo.toml | 6 +- rust/numaflow-core/Cargo.toml | 11 +- rust/numaflow-core/src/config/components.rs | 16 +- rust/numaflow-core/src/message.rs | 8 +- rust/numaflow-core/src/metrics.rs | 2 - .../src/shared/create_components.rs | 22 +- rust/numaflow-core/src/source.rs | 11 + rust/numaflow-core/src/source/serving.rs | 206 ++++++++++++ rust/numaflow/src/main.rs | 13 +- rust/serving/Cargo.toml | 15 +- rust/serving/src/app.rs | 304 +++++------------- rust/serving/src/app/jetstream_proxy.rs | 284 ++++++++-------- rust/serving/src/config.rs | 144 +++------ rust/serving/src/error.rs | 7 + rust/serving/src/lib.rs | 47 +-- rust/serving/src/metrics.rs | 2 + rust/serving/src/pipeline.rs | 43 ++- rust/serving/src/source.rs | 292 +++++++++++++++++ 19 files changed, 1048 insertions(+), 597 deletions(-) create mode 100644 rust/numaflow-core/src/source/serving.rs create mode 100644 rust/serving/src/source.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index beec59aa4b..e3d90e2f05 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -53,39 +53,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" -[[package]] -name = "async-nats" -version = "0.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab8df97cb8fc4a884af29ab383e9292ea0939cfcdd7d2a17179086dc6c427e7f" -dependencies = [ - "base64 0.22.1", - "bytes", - "futures", - "memchr", - "nkeys", - "nuid", - "once_cell", - "portable-atomic", - "rand", - "regex", - "ring", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.2.0", - "rustls-webpki 0.102.8", - "serde", - "serde_json", - "serde_nanos", - "serde_repr", - "thiserror 1.0.69", - "time", - "tokio", - "tokio-rustls 0.26.0", - "tracing", - "tryhard", - "url", -] - [[package]] name = "async-nats" version = "0.38.0" @@ -221,7 +188,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.2", "tokio", - "tower 0.5.1", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -705,6 +672,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1125,6 +1107,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.5.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.10" @@ -1608,6 +1606,23 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nkeys" version = "0.4.4" @@ -1748,7 +1763,7 @@ dependencies = [ name = "numaflow-core" version = "0.1.0" dependencies = [ - "async-nats 0.38.0", + "async-nats", "axum", "axum-server", "backoff", @@ -1771,13 +1786,14 @@ dependencies = [ "pulsar", "rand", "rcgen", + "reqwest 0.12.12", "rustls 0.23.19", "semver", "serde", "serde_json", "serving", "tempfile", - "thiserror 2.0.3", + "thiserror 2.0.8", "tokio", "tokio-stream", "tokio-util", @@ -1821,7 +1837,7 @@ dependencies = [ "prost 0.11.9", "pulsar", "serde", - "thiserror 2.0.3", + "thiserror 2.0.8", "tokio", "tonic", "tracing", @@ -1842,12 +1858,50 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "openssl" +version = "0.10.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "ordered-float" version = "2.10.1" @@ -2025,6 +2079,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + [[package]] name = "portable-atomic" version = "1.10.0" @@ -2249,7 +2309,7 @@ dependencies = [ "rustc-hash 2.1.0", "rustls 0.23.19", "socket2", - "thiserror 2.0.3", + "thiserror 2.0.8", "tokio", "tracing", ] @@ -2268,7 +2328,7 @@ dependencies = [ "rustls 0.23.19", "rustls-pki-types", "slab", - "thiserror 2.0.3", + "thiserror 2.0.8", "tinyvec", "tracing", "web-time", @@ -2448,7 +2508,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -2462,24 +2522,28 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.9" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", "hyper 1.5.1", "hyper-rustls 0.27.3", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -2491,8 +2555,11 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", + "system-configuration 0.6.1", "tokio", + "tokio-native-tls", "tokio-rustls 0.26.0", + "tower 0.5.2", "tower-service", "url", "wasm-bindgen", @@ -2850,7 +2917,7 @@ name = "servesink" version = "0.1.0" dependencies = [ "numaflow 0.1.1", - "reqwest 0.12.9", + "reqwest 0.12.12", "tokio", "tonic", "tracing", @@ -2860,12 +2927,12 @@ dependencies = [ name = "serving" version = "0.1.0" dependencies = [ - "async-nats 0.35.1", "axum", "axum-macros", "axum-server", "backoff", "base64 0.22.1", + "bytes", "chrono", "hyper-util", "numaflow-models", @@ -2873,6 +2940,8 @@ dependencies = [ "prometheus-client", "rcgen", "redis", + "reqwest 0.12.12", + "rustls 0.23.19", "serde", "serde_json", "thiserror 1.0.69", @@ -3067,7 +3136,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation 0.9.4", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.9.4", + "system-configuration-sys 0.6.0", ] [[package]] @@ -3080,6 +3160,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.14.0" @@ -3104,11 +3194,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" dependencies = [ - "thiserror-impl 2.0.3", + "thiserror-impl 2.0.8", ] [[package]] @@ -3124,9 +3214,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" dependencies = [ "proc-macro2", "quote", @@ -3227,6 +3317,16 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-retry" version = "0.3.0" @@ -3370,14 +3470,14 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -3594,6 +3694,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8a6b41a1a4..75fd036128 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -58,8 +58,12 @@ numaflow-core = { path = "numaflow-core" } numaflow-models = { path = "numaflow-models" } backoff = { path = "backoff" } numaflow-pb = { path = "numaflow-pb" } -numaflow-pulsar = {path = "extns/numaflow-pulsar"} +numaflow-pulsar = { path = "extns/numaflow-pulsar" } tokio = "1.41.1" +bytes = "1.7.1" tracing = "0.1.40" axum = "0.7.5" axum-server = { version = "0.7.1", features = ["tls-rustls"] } +serde = { version = "1.0.204", features = ["derive"] } +rustls = { version = "0.23.12", features = ["aws_lc_rs"] } +reqwest = "0.12.12" diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index 38cabb704f..4a98303a1e 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -21,8 +21,10 @@ serving.workspace = true backoff.workspace = true axum.workspace = true axum-server.workspace = true +bytes.workspace = true +serde.workspace = true +rustls.workspace = true tonic = "0.12.3" -bytes = "1.7.1" thiserror = "2.0.3" tokio-util = "0.7.11" tokio-stream = "0.1.15" @@ -35,8 +37,6 @@ tower = "0.4.13" serde_json = "1.0.122" trait-variant = "0.1.2" rcgen = "0.13.1" -rustls = { version = "0.23.12", features = ["aws_lc_rs"] } -serde = { version = "1.0.204", features = ["derive"] } semver = "1.0" pep440_rs = "0.6.6" parking_lot = "0.12.3" @@ -50,6 +50,9 @@ async-nats = "0.38.0" [dev-dependencies] tempfile = "3.11.0" numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "9ca9362ad511084501520e5a37d40cdcd0cdc9d9" } -pulsar = { version = "6.3.0", default-features = false, features = ["tokio-rustls-runtime"] } +pulsar = { version = "6.3.0", default-features = false, features = [ + "tokio-rustls-runtime", +] } +reqwest = { workspace = true, features = ["json"] } [build-dependencies] diff --git a/rust/numaflow-core/src/config/components.rs b/rust/numaflow-core/src/config/components.rs index a49692060f..3dc0bf2a66 100644 --- a/rust/numaflow-core/src/config/components.rs +++ b/rust/numaflow-core/src/config/components.rs @@ -5,6 +5,7 @@ pub(crate) mod source { use std::collections::HashMap; use std::env; + use std::sync::Arc; use std::{fmt::Debug, time::Duration}; use bytes::Bytes; @@ -37,7 +38,9 @@ pub(crate) mod source { Generator(GeneratorConfig), UserDefined(UserDefinedConfig), Pulsar(PulsarSourceConfig), - Serving(serving::Settings), + // Serving source starts an Axum HTTP server in the background. + // The settings will be used as application state which gets cloned in each handler on each request. + Serving(Arc), } impl From> for SourceType { @@ -110,10 +113,7 @@ pub(crate) mod source { // There should be only one option (user-defined) to define the settings. fn try_from(cfg: Box) -> Result { let env_vars = env::vars().collect::>(); - - let mut settings: serving::Settings = env_vars - .try_into() - .map_err(|e: serving::Error| Error::Config(e.to_string()))?; + let mut settings: serving::Settings = env_vars.try_into()?; settings.tid_header = cfg.msg_id_header_key; @@ -148,7 +148,7 @@ pub(crate) mod source { } settings.redis.addr = cfg.store.url; - Ok(SourceType::Serving(settings)) + Ok(SourceType::Serving(Arc::new(settings))) } } @@ -168,6 +168,10 @@ pub(crate) mod source { return pulsar.try_into(); } + if let Some(serving) = source.serving.take() { + return serving.try_into(); + } + Err(Error::Config(format!("Invalid source type: {source:?}"))) } } diff --git a/rust/numaflow-core/src/message.rs b/rust/numaflow-core/src/message.rs index 00f5cca663..fe20613dad 100644 --- a/rust/numaflow-core/src/message.rs +++ b/rust/numaflow-core/src/message.rs @@ -37,7 +37,7 @@ pub(crate) struct Message { } /// Offset of the message which will be used to acknowledge the message. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub(crate) enum Offset { Int(IntOffset), String(StringOffset), @@ -62,7 +62,7 @@ impl Message { } /// IntOffset is integer based offset enum type. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct IntOffset { pub(crate) offset: u64, pub(crate) partition_idx: u16, @@ -84,7 +84,7 @@ impl fmt::Display for IntOffset { } /// StringOffset is string based offset enum type. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub(crate) struct StringOffset { /// offset could be a complex base64 string. pub(crate) offset: Bytes, @@ -120,7 +120,7 @@ pub(crate) enum ReadAck { } /// Message ID which is used to uniquely identify a message. It cheap to clone this. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub(crate) struct MessageID { pub(crate) vertex_name: Bytes, pub(crate) offset: Bytes, diff --git a/rust/numaflow-core/src/metrics.rs b/rust/numaflow-core/src/metrics.rs index fa79e457b8..2a672ec31d 100644 --- a/rust/numaflow-core/src/metrics.rs +++ b/rust/numaflow-core/src/metrics.rs @@ -600,8 +600,6 @@ pub(crate) async fn start_metrics_https_server( addr: SocketAddr, metrics_state: UserDefinedContainerState, ) -> crate::Result<()> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - // Generate a self-signed certificate let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| Error::Metrics(format!("Generating self-signed certificate: {}", e)))?; diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index bde1f6059e..b28f4caeee 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -1,15 +1,18 @@ +use std::sync::Arc; use std::time::Duration; use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; +use serving::ServingSource; use tokio_util::sync::CancellationToken; use tonic::transport::Channel; use crate::config::components::sink::{SinkConfig, SinkType}; use crate::config::components::source::{SourceConfig, SourceType}; use crate::config::components::transformer::TransformerConfig; +use crate::config::get_vertex_replica; use crate::config::pipeline::map::{MapMode, MapType, MapVtxConfig}; use crate::config::pipeline::{DEFAULT_BATCH_MAP_SOCKET, DEFAULT_STREAM_MAP_SOCKET}; use crate::error::Error; @@ -334,8 +337,23 @@ pub async fn create_source( None, )) } - SourceType::Serving(_) => { - unimplemented!("Serving as built-in source is not yet implemented") + SourceType::Serving(config) => { + let serving = ServingSource::new( + Arc::clone(config), + batch_size, + read_timeout, + *get_vertex_replica(), + ) + .await?; + Ok(( + Source::new( + batch_size, + source::SourceType::Serving(serving), + tracker_handle, + source_config.read_ahead, + ), + None, + )) } } } diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index 4d280d3725..a30fc9777e 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -37,6 +37,9 @@ pub(crate) mod generator; /// [Pulsar]: https://numaflow.numaproj.io/user-guide/sources/pulsar/ pub(crate) mod pulsar; +pub(crate) mod serving; +use serving::ServingSource; + /// Set of Read related items that has to be implemented to become a Source. pub(crate) trait SourceReader { #[allow(dead_code)] @@ -68,6 +71,7 @@ pub(crate) enum SourceType { generator::GeneratorLagReader, ), Pulsar(PulsarSource), + Serving(ServingSource), } enum ActorMessage { @@ -182,6 +186,13 @@ impl Source { actor.run().await; }); } + SourceType::Serving(serving) => { + tokio::spawn(async move { + let actor = + SourceActor::new(receiver, serving.clone(), serving.clone(), serving); + actor.run().await; + }); + } }; Self { read_batch_size: batch_size, diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs new file mode 100644 index 0000000000..b9fb6c72ed --- /dev/null +++ b/rust/numaflow-core/src/source/serving.rs @@ -0,0 +1,206 @@ +use std::sync::Arc; + +pub(crate) use serving::ServingSource; + +use crate::config::get_vertex_replica; +use crate::message::{MessageID, StringOffset}; +use crate::Error; +use crate::Result; + +use super::{get_vertex_name, Message, Offset}; + +impl TryFrom for Message { + type Error = Error; + + fn try_from(message: serving::Message) -> Result { + let offset = Offset::String(StringOffset::new(message.id.clone(), *get_vertex_replica())); + + Ok(Message { + // we do not support keys from HTTP client + keys: Arc::from(vec![]), + tags: None, + value: message.value, + offset: Some(offset.clone()), + event_time: Default::default(), + id: MessageID { + vertex_name: get_vertex_name().to_string().into(), + offset: offset.to_string().into(), + index: 0, + }, + headers: message.headers, + }) + } +} + +impl From for Error { + fn from(value: serving::Error) -> Self { + Error::Source(value.to_string()) + } +} + +impl super::SourceReader for ServingSource { + fn name(&self) -> &'static str { + "serving" + } + + async fn read(&mut self) -> Result> { + self.read_messages() + .await? + .into_iter() + .map(|msg| msg.try_into()) + .collect() + } + + fn partitions(&self) -> Vec { + vec![*get_vertex_replica()] + } +} + +impl super::SourceAcker for ServingSource { + /// HTTP response is sent only once we have confirmation that the message has been written to the ISB. + // TODO: Current implementation only works for `/v1/process/async` endpoint. + // For `/v1/process/{sync,sync_serve}` endpoints: https://github.com/numaproj/numaflow/issues/2308 + async fn ack(&mut self, offsets: Vec) -> Result<()> { + let mut serving_offsets = vec![]; + for offset in offsets { + let Offset::String(offset) = offset else { + return Err(Error::Source(format!( + "Expected string offset for Serving source. Got {offset:?}" + ))); + }; + serving_offsets.push(offset.to_string()); + } + self.ack_messages(serving_offsets).await?; + Ok(()) + } +} + +impl super::LagReader for ServingSource { + async fn pending(&mut self) -> Result> { + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + message::{Message, MessageID, Offset, StringOffset}, + source::{SourceAcker, SourceReader}, + }; + use std::{collections::HashMap, sync::Arc, time::Duration}; + + use bytes::Bytes; + use serving::{ServingSource, Settings}; + + use super::get_vertex_replica; + + type Result = std::result::Result>; + + #[test] + fn test_message_conversion() -> Result<()> { + const MSG_ID: &str = "b149ad7a-5690-4f0a"; + + let mut headers = HashMap::new(); + headers.insert("header-key".to_owned(), "header-value".to_owned()); + + let serving_message = serving::Message { + value: Bytes::from_static(b"test"), + id: MSG_ID.into(), + headers: headers.clone(), + }; + let message: Message = serving_message.try_into()?; + assert_eq!(message.value, Bytes::from_static(b"test")); + assert_eq!( + message.offset, + Some(Offset::String(StringOffset::new(MSG_ID.into(), 0))) + ); + assert_eq!( + message.id, + MessageID { + vertex_name: Bytes::new(), + offset: format!("{MSG_ID}-0").into(), + index: 0 + } + ); + + assert_eq!(message.headers, headers); + + Ok(()) + } + + #[test] + fn test_error_conversion() { + use crate::error::Error; + let error: Error = serving::Error::ParseConfig("Invalid config".to_owned()).into(); + if let Error::Source(val) = error { + assert_eq!(val, "ParseConfig Error - Invalid config".to_owned()); + } else { + panic!("Expected Error::Source() variant"); + } + } + + #[tokio::test] + async fn test_serving_source_reader_acker() -> Result<()> { + let settings = Settings { + app_listen_port: 2000, + ..Default::default() + }; + let settings = Arc::new(settings); + let mut serving_source = ServingSource::new( + Arc::clone(&settings), + 10, + Duration::from_millis(1), + *get_vertex_replica(), + ) + .await?; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); + + // Wait for the server + for _ in 0..10 { + let resp = client + .get(format!( + "https://localhost:{}/livez", + settings.app_listen_port + )) + .send() + .await; + if resp.is_ok() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let task_handle = tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut messages = serving_source.read().await.unwrap(); + if messages.is_empty() { + // Server has not received any requests yet + continue; + } + assert_eq!(messages.len(), 1); + let msg = messages.remove(0); + serving_source.ack(vec![msg.offset.unwrap()]).await.unwrap(); + break; + } + }); + + let resp = client + .post(format!( + "https://localhost:{}/v1/process/async", + settings.app_listen_port + )) + .json("test-payload") + .send() + .await?; + + assert!(resp.status().is_success()); + assert!(task_handle.await.is_ok()); + Ok(()) + } +} diff --git a/rust/numaflow/src/main.rs b/rust/numaflow/src/main.rs index 60e26ef850..9a5ab6fe82 100644 --- a/rust/numaflow/src/main.rs +++ b/rust/numaflow/src/main.rs @@ -1,7 +1,5 @@ -use std::collections::HashMap; use std::env; use std::error::Error; -use std::sync::Arc; use tracing::error; use tracing_subscriber::layer::SubscriberExt; @@ -31,14 +29,7 @@ async fn main() -> Result<(), Box> { async fn run() -> Result<(), Box> { let args: Vec = env::args().collect(); // Based on the argument, run the appropriate component. - if args.contains(&"--serving".to_string()) { - let env_vars: HashMap = env::vars().collect(); - let settings: serving::Settings = env_vars.try_into()?; - let settings = Arc::new(settings); - serving::serve(settings) - .await - .map_err(|e| format!("Error running serving: {e:?}"))?; - } else if args.contains(&"--servesink".to_string()) { + if args.contains(&"--servesink".to_string()) { servesink::servesink() .await .map_err(|e| format!("Error running servesink: {e:?}"))?; @@ -47,5 +38,5 @@ async fn run() -> Result<(), Box> { .await .map_err(|e| format!("Error running rust binary: {e:?}"))? } - Err("Invalid argument. Use --serving, --servesink, or --rust".into()) + Err("Invalid argument. Use --servesink, or --rust".into()) } diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index de2f8bb820..857d69db77 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -5,8 +5,7 @@ edition = "2021" [features] redis-tests = [] -nats-tests = [] -all-tests = ["redis-tests", "nats-tests"] +all-tests = ["redis-tests"] [lints] workspace = true @@ -18,7 +17,8 @@ numaflow-models.workspace = true backoff.workspace = true axum.workspace = true axum-server.workspace = true -async-nats = "0.35.1" +bytes.workspace = true +rustls.workspace = true axum-macros = "0.4.1" hyper-util = { version = "0.1.6", features = ["client-legacy"] } serde = { version = "1.0.204", features = ["derive"] } @@ -26,7 +26,11 @@ serde_json = "1.0.120" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["trace", "timeout"] } uuid = { version = "1.10.0", features = ["v4"] } -redis = { version = "0.26.0", features = ["tokio-comp", "aio", "connection-manager"] } +redis = { version = "0.26.0", features = [ + "tokio-comp", + "aio", + "connection-manager", +] } trait-variant = "0.1.2" chrono = { version = "0.4", features = ["serde"] } base64 = "0.22.1" @@ -35,3 +39,6 @@ parking_lot = "0.12.3" prometheus-client = "0.22.3" thiserror = "1.0.63" +[dev-dependencies] +reqwest = { workspace = true, features = ["json"] } +rustls.workspace = true diff --git a/rust/serving/src/app.rs b/rust/serving/src/app.rs index 56d4a33cb3..82ef1ef62e 100644 --- a/rust/serving/src/app.rs +++ b/rust/serving/src/app.rs @@ -1,9 +1,6 @@ use std::net::SocketAddr; -use std::sync::Arc; use std::time::Duration; -use async_nats::jetstream; -use async_nats::jetstream::Context; use axum::extract::{MatchedPath, State}; use axum::http::StatusCode; use axum::middleware::Next; @@ -25,12 +22,9 @@ use self::{ message_path::get_message_path, }; use crate::app::callback::store::Store; -use crate::app::tracker::MessageGraph; -use crate::config::JetStreamConfig; -use crate::pipeline::PipelineDCG; +use crate::metrics::capture_metrics; +use crate::AppState; use crate::Error::InitError; -use crate::Settings; -use crate::{app::callback::state::State as CallbackState, metrics::capture_metrics}; /// manage callbacks pub(crate) mod callback; @@ -41,7 +35,7 @@ mod jetstream_proxy; /// Return message path in response to UI requests mod message_path; // TODO: merge message_path and tracker mod response; -mod tracker; +pub(crate) mod tracker; /// Everything for numaserve starts here. The routing, middlewares, proxying, etc. // TODO @@ -49,16 +43,39 @@ mod tracker; // - [ ] outer fallback for /v1/direct /// Start the main application Router and the axum server. -pub(crate) async fn start_main_server( - settings: Arc, +pub(crate) async fn start_main_server( + app: AppState, tls_config: RustlsConfig, - pipeline_spec: PipelineDCG, -) -> crate::Result<()> { - let app_addr: SocketAddr = format!("0.0.0.0:{}", &settings.app_listen_port) +) -> crate::Result<()> +where + T: Clone + Send + Sync + Store + 'static, +{ + let app_addr: SocketAddr = format!("0.0.0.0:{}", &app.settings.app_listen_port) .parse() .map_err(|e| InitError(format!("{e:?}")))?; - let tid_header = settings.tid_header.clone(); + let handle = Handle::new(); + // Spawn a task to gracefully shutdown server. + tokio::spawn(graceful_shutdown(handle.clone())); + + info!(?app_addr, "Starting application server"); + + let router = router_with_auth(app).await?; + + axum_server::bind_rustls(app_addr, tls_config) + .handle(handle) + .serve(router.into_make_service()) + .await + .map_err(|e| InitError(format!("Starting web server for metrics: {}", e)))?; + + Ok(()) +} + +pub(crate) async fn router_with_auth(app: AppState) -> crate::Result +where + T: Clone + Send + Sync + Store + 'static, +{ + let tid_header = app.settings.tid_header.clone(); let layers = ServiceBuilder::new() // Add tracing to all requests .layer( @@ -85,45 +102,14 @@ pub(crate) async fn start_main_server( .layer( // Graceful shutdown will wait for outstanding requests to complete. Add a timeout so // requests don't hang forever. - TimeoutLayer::new(Duration::from_secs(settings.drain_timeout_secs)), + TimeoutLayer::new(Duration::from_secs(app.settings.drain_timeout_secs)), ) // Add auth middleware to all user facing routes .layer(middleware::from_fn_with_state( - settings.api_auth_token.clone(), + app.settings.api_auth_token.clone(), auth_middleware, )); - - // Create the message graph from the pipeline spec and the redis store - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).map_err(|e| { - InitError(format!( - "Creating message graph from pipeline spec: {:?}", - e - )) - })?; - - // Create a redis store to store the callbacks and the custom responses - let redis_store = - callback::store::redisstore::RedisConnection::new(settings.redis.clone()).await?; - let state = CallbackState::new(msg_graph, redis_store).await?; - - let handle = Handle::new(); - // Spawn a task to gracefully shutdown server. - tokio::spawn(graceful_shutdown(handle.clone())); - - // Create a Jetstream context - let js_context = create_js_context(&settings.jetstream).await?; - - let router = setup_app(settings, js_context, state).await?.layer(layers); - - info!(?app_addr, "Starting application server"); - - axum_server::bind_rustls(app_addr, tls_config) - .handle(handle) - .serve(router.into_make_service()) - .await - .map_err(|e| InitError(format!("Starting web server for metrics: {}", e)))?; - - Ok(()) + Ok(setup_app(app).await?.layer(layers)) } // Gracefully shutdown the server on receiving SIGINT or SIGTERM @@ -154,30 +140,6 @@ async fn graceful_shutdown(handle: Handle) { handle.graceful_shutdown(Some(Duration::from_secs(30))); } -async fn create_js_context(js_config: &JetStreamConfig) -> crate::Result { - // Connect to Jetstream with user and password if they are set - let js_client = match js_config.auth.as_ref() { - Some(auth) => { - async_nats::connect_with_options( - &js_config.url, - async_nats::ConnectOptions::with_user_and_password( - auth.username.clone(), - auth.password.clone(), - ), - ) - .await - } - _ => async_nats::connect(&js_config.url).await, - } - .map_err(|e| { - InitError(format!( - "Connecting to jetstream server {}: {}", - &js_config.url, e - )) - })?; - Ok(jetstream::new(js_client)) -} - const PUBLISH_ENDPOINTS: [&str; 3] = [ "/v1/process/sync", "/v1/process/sync_serve", @@ -228,28 +190,14 @@ async fn auth_middleware( } } -#[derive(Clone)] -pub(crate) struct AppState { - pub(crate) settings: Arc, - pub(crate) callback_state: CallbackState, - pub(crate) context: Context, -} - async fn setup_app( - settings: Arc, - context: Context, - callback_state: CallbackState, + app: AppState, ) -> crate::Result { - let app_state = AppState { - settings, - callback_state: callback_state.clone(), - context: context.clone(), - }; let parent = Router::new() .route("/health", get(health_check)) .route("/livez", get(livez)) // Liveliness check .route("/readyz", get(readyz)) - .with_state(app_state.clone()); // Readiness check + .with_state(app.clone()); // Readiness check // a pool based client implementation for direct proxy, this client is cloneable. let client: direct_proxy::Client = @@ -260,9 +208,9 @@ async fn setup_app( let app = parent .nest( "/v1/direct", - direct_proxy(client, app_state.settings.upstream_addr.clone()), + direct_proxy(client, app.settings.upstream_addr.clone()), ) - .nest("/v1/process", routes(app_state).await?); + .nest("/v1/process", routes(app).await?); Ok(app) } @@ -278,13 +226,7 @@ async fn livez() -> impl IntoResponse { async fn readyz( State(app): State>, ) -> impl IntoResponse { - if app.callback_state.clone().ready().await - && app - .context - .get_stream(&app.settings.jetstream.stream) - .await - .is_ok() - { + if app.callback_state.clone().ready().await { StatusCode::NO_CONTENT } else { StatusCode::INTERNAL_SERVER_ERROR @@ -308,188 +250,100 @@ async fn routes( #[cfg(test)] mod tests { - use async_nats::jetstream::stream; + use std::sync::Arc; + use axum::http::StatusCode; - use tokio::time::{sleep, Duration}; use tower::ServiceExt; use super::*; use crate::app::callback::store::memstore::InMemoryStore; - use crate::config::generate_certs; + use crate::Settings; + use callback::state::State as CallbackState; + use tokio::sync::mpsc; + use tracker::MessageGraph; const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; type Result = core::result::Result; type Error = Box; - #[tokio::test] - async fn test_start_main_server() -> Result<()> { - let (cert, key) = generate_certs()?; - - let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) - .await - .unwrap(); - - let settings = Arc::new(Settings { - app_listen_port: 0, - ..Settings::default() - }); - - let server = tokio::spawn(async move { - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let result = start_main_server(settings, tls_config, pipeline_spec).await; - assert!(result.is_ok()) - }); - - // Give the server a little bit of time to start - sleep(Duration::from_millis(50)).await; - - // Stop the server - server.abort(); - Ok(()) - } - - #[cfg(feature = "all-tests")] #[tokio::test] async fn test_setup_app() -> Result<()> { let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); let mem_store = InMemoryStore::new(); let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; + let (tx, _) = mpsc::channel(10); + let app = AppState { + message: tx, + settings, + callback_state, + }; - let result = setup_app(settings, context, callback_state).await; + let result = setup_app(app).await; assert!(result.is_ok()); Ok(()) } - #[cfg(feature = "all-tests")] #[tokio::test] - async fn test_livez() -> Result<()> { + async fn test_health_check_endpoints() -> Result<()> { let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; - let result = setup_app(settings, context, callback_state).await; + let (messages_tx, _messages_rx) = mpsc::channel(10); + let app = AppState { + message: messages_tx, + settings, + callback_state, + }; + + let router = setup_app(app).await.unwrap(); let request = Request::builder().uri("/livez").body(Body::empty())?; - - let response = result?.oneshot(request).await?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok(()) - } - - #[cfg(feature = "all-tests")] - #[tokio::test] - async fn test_readyz() -> Result<()> { - let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); - - let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - - let callback_state = CallbackState::new(msg_graph, mem_store).await?; - - let result = setup_app(settings, context, callback_state).await; let request = Request::builder().uri("/readyz").body(Body::empty())?; - - let response = result.unwrap().oneshot(request).await?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok(()) - } - #[tokio::test] - async fn test_health_check() { - let response = health_check().await; - let response = response.into_response(); + let request = Request::builder().uri("/health").body(Body::empty())?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::OK); + Ok(()) } - #[cfg(feature = "all-tests")] #[tokio::test] async fn test_auth_middleware() -> Result<()> { - let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); + let settings = Settings { + api_auth_token: Some("test-token".into()), + ..Default::default() + }; let mem_store = InMemoryStore::new(); let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; + let (messages_tx, _messages_rx) = mpsc::channel(10); + let app_state = AppState { - settings, + message: messages_tx, + settings: Arc::new(settings), callback_state, - context, }; - let app = Router::new() - .nest("/v1/process", routes(app_state).await.unwrap()) - .layer(middleware::from_fn_with_state( - Some("test_token".to_owned()), - auth_middleware, - )); - - let res = app + let router = router_with_auth(app_state).await.unwrap(); + let res = router .oneshot( axum::extract::Request::builder() + .method("POST") .uri("/v1/process/sync") .body(Body::empty()) .unwrap(), diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index af7d3917ff..6f61a0530f 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -1,6 +1,5 @@ -use std::{borrow::Borrow, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; -use async_nats::{jetstream::Context, HeaderMap as JSHeaderMap}; use axum::{ body::Bytes, extract::State, @@ -9,12 +8,13 @@ use axum::{ routing::post, Json, Router, }; +use tokio::sync::{mpsc, oneshot}; use tracing::error; use uuid::Uuid; use super::{callback::store::Store, AppState}; -use crate::app::callback::state; use crate::app::response::{ApiError, ServeResponse}; +use crate::{app::callback::state, Message, MessageWrapper}; // TODO: // - [ ] better health check @@ -37,10 +37,9 @@ const NUMAFLOW_RESP_ARRAY_LEN: &str = "Numaflow-Array-Len"; const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; struct ProxyState { + message: mpsc::Sender, tid_header: String, - context: Context, callback: state::State, - stream: String, callback_url: String, } @@ -48,10 +47,9 @@ pub(crate) async fn jetstream_proxy( state: AppState, ) -> crate::Result { let proxy_state = Arc::new(ProxyState { + message: state.message.clone(), tid_header: state.settings.tid_header.clone(), - context: state.context.clone(), callback: state.callback_state.clone(), - stream: state.settings.jetstream.stream.clone(), callback_url: format!( "https://{}:{}/v1/process/callback", state.settings.host_ip, state.settings.app_listen_port @@ -76,20 +74,34 @@ async fn sync_publish_serve( // Register the ID in the callback proxy state let notify = proxy_state.callback.clone().register(id.clone()); - if let Err(e) = publish_to_jetstream( - proxy_state.stream.clone(), - &proxy_state.callback_url, - headers, - body, - proxy_state.context.clone(), - proxy_state.tid_header.as_str(), - id.as_str(), - ) - .await - { + let mut msg_headers: HashMap = HashMap::new(); + for (key, value) in headers.iter() { + msg_headers.insert( + key.to_string(), + String::from_utf8_lossy(value.as_bytes()).to_string(), + ); + } + + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + + proxy_state + .message + .send(message) + .await + .expect("Failed to send request payload to Serving channel"); + + if let Err(e) = rx.await { // Deregister the ID in the callback proxy state if writing to Jetstream fails let _ = proxy_state.callback.clone().deregister(&id).await; - error!(error = ?e, "Publishing message to Jetstream for sync serve request"); + error!(error = ?e, "Waiting for acknowledgement for message"); return Err(ApiError::BadGateway( "Failed to write message to Jetstream".to_string(), )); @@ -143,21 +155,30 @@ async fn sync_publish( ) -> Result, ApiError> { let id = extract_id_from_headers(&proxy_state.tid_header, &headers); + let mut msg_headers: HashMap = HashMap::new(); + for (key, value) in headers.iter() { + msg_headers.insert( + key.to_string(), + String::from_utf8_lossy(value.as_bytes()).to_string(), + ); + } + + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + // Register the ID in the callback proxy state let notify = proxy_state.callback.clone().register(id.clone()); + proxy_state.message.send(message).await.unwrap(); // FIXME: - if let Err(e) = publish_to_jetstream( - proxy_state.stream.clone(), - &proxy_state.callback_url, - headers, - body, - proxy_state.context.clone(), - &proxy_state.tid_header, - id.as_str(), - ) - .await - { - // Deregister the ID in the callback proxy state if writing to Jetstream fails + if let Err(e) = rx.await { + // Deregister the ID in the callback proxy state if waiting for ack fails let _ = proxy_state.callback.clone().deregister(&id).await; error!(error = ?e, "Publishing message to Jetstream for sync request"); return Err(ApiError::BadGateway( @@ -192,62 +213,40 @@ async fn async_publish( body: Bytes, ) -> Result, ApiError> { let id = extract_id_from_headers(&proxy_state.tid_header, &headers); - let result = publish_to_jetstream( - proxy_state.stream.clone(), - &proxy_state.callback_url, - headers, - body, - proxy_state.context.clone(), - &proxy_state.tid_header, - id.as_str(), - ) - .await; + let mut msg_headers: HashMap = HashMap::new(); + for (key, value) in headers.iter() { + msg_headers.insert( + key.to_string(), + String::from_utf8_lossy(value.as_bytes()).to_string(), + ); + } - match result { + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + + proxy_state.message.send(message).await.unwrap(); // FIXME: + match rx.await { Ok(_) => Ok(Json(ServeResponse::new( "Successfully published message".to_string(), id, StatusCode::OK, ))), Err(e) => { - error!(error = ?e, "Publishing message to Jetstream"); + error!(error = ?e, "Waiting for message save confirmation"); Err(ApiError::InternalServerError( - "Failed to publish message to Jetstream".to_string(), + "Failed to save message".to_string(), )) } } } -/// Write to JetStream and return the metadata. It is responsible for getting the ID from the header. -async fn publish_to_jetstream( - stream: String, - callback_url: &str, - headers: HeaderMap, - body: Bytes, - js_context: Context, - id_header: &str, - id_header_value: &str, -) -> Result<(), async_nats::Error> { - let mut js_headers = JSHeaderMap::new(); - - // pass in the HTTP headers as jetstream headers - for (k, v) in headers.iter() { - js_headers.append(k.as_ref(), String::from_utf8_lossy(v.as_bytes()).borrow()) - } - - js_headers.append(id_header, id_header_value); // Use the passed ID - js_headers.append(CALLBACK_URL_KEY, callback_url); - - js_context - .publish_with_headers(stream, js_headers, body) - .await - .map_err(|e| format!("Publishing message to stream: {e:?}"))? - .await - .map_err(|e| format!("Waiting for acknowledgement of published message: {e:?}"))?; - - Ok(()) -} - // extracts the ID from the headers, if not found, generates a new UUID fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { headers.get(tid_header).map_or_else( @@ -256,13 +255,10 @@ fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { ) } -#[cfg(feature = "nats-tests")] #[cfg(test)] mod tests { use std::sync::Arc; - use async_nats::jetstream; - use async_nats::jetstream::stream; use axum::body::{to_bytes, Body}; use axum::extract::Request; use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; @@ -303,46 +299,47 @@ mod tests { #[tokio::test] async fn test_async_publish() -> Result<(), Box> { - let settings = Settings::default(); - let settings = Arc::new(settings); - let client = async_nats::connect(&settings.jetstream.url) - .await - .map_err(|e| format!("Connecting to Jetstream: {:?}", e))?; - - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "default"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e))?; + const ID_HEADER: &str = "X-Numaflow-ID"; + const ID_VALUE: &str = "foobar"; + let settings = Settings { + tid_header: ID_HEADER.into(), + ..Default::default() + }; let mock_store = MockStore {}; - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec) - .map_err(|e| format!("Failed to create message graph from pipeline spec: {:?}", e))?; - + let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mock_store).await?; + + let (messages_tx, mut messages_rx) = mpsc::channel::(10); + let response_collector = tokio::spawn(async move { + let message = messages_rx.recv().await.unwrap(); + let MessageWrapper { + confirm_save, + message, + } = message; + confirm_save.send(()).unwrap(); + message + }); + let app_state = AppState { + message: messages_tx, + settings: Arc::new(settings), callback_state, - context, - settings, }; + let app = jetstream_proxy(app_state).await?; let res = Request::builder() .method("POST") .uri("/async") .header(CONTENT_TYPE, "text/plain") - .header("id", id) + .header(ID_HEADER, ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.oneshot(res).await.unwrap(); + let message = response_collector.await.unwrap(); + assert_eq!(message.id, ID_VALUE); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -350,7 +347,7 @@ mod tests { result, json!({ "message": "Successfully published message", - "id": id, + "id": ID_VALUE, "code": 200 }) ); @@ -392,20 +389,12 @@ mod tests { #[tokio::test] async fn test_sync_publish() { - let settings = Settings::default(); - let client = async_nats::connect(&settings.jetstream.url).await.unwrap(); - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "sync_pub"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e)); + const ID_HEADER: &str = "X-Numaflow-ID"; + const ID_VALUE: &str = "foobar"; + let settings = Settings { + tid_header: ID_HEADER.into(), + ..Default::default() + }; let mem_store = InMemoryStore::new(); let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); @@ -413,16 +402,28 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); - let settings = Arc::new(settings); + let (messages_tx, mut messages_rx) = mpsc::channel(10); + + let response_collector = tokio::spawn(async move { + let message = messages_rx.recv().await.unwrap(); + let MessageWrapper { + confirm_save, + message, + } = message; + confirm_save.send(()).unwrap(); + message + }); + let app_state = AppState { - settings, + message: messages_tx, + settings: Arc::new(settings), callback_state: callback_state.clone(), - context, }; + let app = jetstream_proxy(app_state).await.unwrap(); tokio::spawn(async move { - let cbs = create_default_callbacks(id); + let cbs = create_default_callbacks(ID_VALUE); let mut retries = 0; loop { match callback_state.insert_callback_requests(cbs.clone()).await { @@ -442,11 +443,13 @@ mod tests { .method("POST") .uri("/sync") .header("Content-Type", "text/plain") - .header("id", id) + .header(ID_HEADER, ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.clone().oneshot(res).await.unwrap(); + let message = response_collector.await.unwrap(); + assert_eq!(message.id, ID_VALUE); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -454,7 +457,7 @@ mod tests { result, json!({ "message": "Successfully processed the message", - "id": id, + "id": ID_VALUE, "code": 200 }) ); @@ -462,20 +465,8 @@ mod tests { #[tokio::test] async fn test_sync_publish_serve() { + const ID_VALUE: &str = "foobar"; let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await.unwrap(); - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "sync_serve_pub"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e)); let mem_store = InMemoryStore::new(); let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); @@ -483,16 +474,28 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); + let (messages_tx, mut messages_rx) = mpsc::channel(10); + + let response_collector = tokio::spawn(async move { + let message = messages_rx.recv().await.unwrap(); + let MessageWrapper { + confirm_save, + message, + } = message; + confirm_save.send(()).unwrap(); + message + }); + let app_state = AppState { + message: messages_tx, settings, callback_state: callback_state.clone(), - context, }; let app = jetstream_proxy(app_state).await.unwrap(); // pipeline is in -> cat -> out, so we will have 3 callback requests - let cbs = create_default_callbacks(id); + let cbs = create_default_callbacks(ID_VALUE); // spawn a tokio task which will insert the callback requests to the callback state // if it fails, sleep for 10ms and retry @@ -531,11 +534,14 @@ mod tests { .method("POST") .uri("/sync_serve") .header("Content-Type", "text/plain") - .header("id", id) + .header("ID", ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.oneshot(res).await.unwrap(); + let message = response_collector.await.unwrap(); + assert_eq!(message.id, ID_VALUE); + assert_eq!(response.status(), StatusCode::OK); let content_len = response.headers().get(CONTENT_LENGTH).unwrap(); diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs index 7ba3778d00..16c2ee125c 100644 --- a/rust/serving/src/config.rs +++ b/rust/serving/src/config.rs @@ -1,71 +1,29 @@ use std::collections::HashMap; use std::fmt::Debug; -use async_nats::rustls; use base64::prelude::BASE64_STANDARD; use base64::Engine; use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; use serde::{Deserialize, Serialize}; -use crate::Error::ParseConfig; +use crate::{ + pipeline::PipelineDCG, + Error::{self, ParseConfig}, +}; const ENV_NUMAFLOW_SERVING_SOURCE_OBJECT: &str = "NUMAFLOW_SERVING_SOURCE_OBJECT"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_URL: &str = "NUMAFLOW_ISBSVC_JETSTREAM_URL"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM: &str = "NUMAFLOW_SERVING_JETSTREAM_STREAM"; const ENV_NUMAFLOW_SERVING_STORE_TTL: &str = "NUMAFLOW_SERVING_STORE_TTL"; const ENV_NUMAFLOW_SERVING_HOST_IP: &str = "NUMAFLOW_SERVING_HOST_IP"; const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_USER: &str = "NUMAFLOW_ISBSVC_JETSTREAM_USER"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD: &str = "NUMAFLOW_ISBSVC_JETSTREAM_PASSWORD"; const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; +const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| format!("Failed to generate cert {:?}", e))?; Ok((cert, key_pair)) } -#[derive(Deserialize, Clone, PartialEq)] -pub struct BasicAuth { - pub username: String, - pub password: String, -} - -impl Debug for BasicAuth { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let passwd_printable = if self.password.len() > 4 { - let passwd: String = self - .password - .chars() - .skip(self.password.len() - 2) - .take(2) - .collect(); - format!("***{}", passwd) - } else { - "*****".to_owned() - }; - write!(f, "{}:{}", self.username, passwd_printable) - } -} - -#[derive(Debug, Deserialize, Clone, PartialEq)] -pub struct JetStreamConfig { - pub stream: String, - pub url: String, - pub auth: Option, -} - -impl Default for JetStreamConfig { - fn default() -> Self { - Self { - stream: "default".to_owned(), - url: "localhost:4222".to_owned(), - auth: None, - } - } -} - #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct RedisConfig { pub addr: String, @@ -95,11 +53,11 @@ pub struct Settings { pub metrics_server_listen_port: u16, pub upstream_addr: String, pub drain_timeout_secs: u64, - pub jetstream: JetStreamConfig, pub redis: RedisConfig, /// The IP address of the numaserve pod. This will be used to construct the value for X-Numaflow-Callback-Url header pub host_ip: String, pub api_auth_token: Option, + pub pipeline_spec: PipelineDCG, } impl Default for Settings { @@ -110,10 +68,10 @@ impl Default for Settings { metrics_server_listen_port: 3001, upstream_addr: "localhost:8888".to_owned(), drain_timeout_secs: 10, - jetstream: JetStreamConfig::default(), redis: RedisConfig::default(), host_ip: "127.0.0.1".to_owned(), api_auth_token: None, + pipeline_spec: Default::default(), } } } @@ -133,7 +91,7 @@ pub struct CallbackStorageConfig { /// This implementation is to load settings from env variables impl TryFrom> for Settings { - type Error = crate::Error; + type Error = Error; fn try_from(env_vars: HashMap) -> std::result::Result { let host_ip = env_vars .get(ENV_NUMAFLOW_SERVING_HOST_IP) @@ -144,19 +102,27 @@ impl TryFrom> for Settings { })? .to_owned(); + let pipeline_spec: PipelineDCG = env_vars + .get(ENV_MIN_PIPELINE_SPEC) + .ok_or_else(|| { + Error::ParseConfig(format!( + "Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}" + )) + })? + .parse() + .map_err(|e| { + Error::ParseConfig(format!( + "Parsing pipeline spec: {}: error={e:?}", + env_vars.get(ENV_MIN_PIPELINE_SPEC).unwrap() + )) + })?; + let mut settings = Settings { host_ip, + pipeline_spec, ..Default::default() }; - if let Some(jetstream_url) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_URL) { - settings.jetstream.url = jetstream_url.to_owned(); - } - - if let Some(jetstream_stream) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM) { - settings.jetstream.stream = jetstream_stream.to_owned(); - } - if let Some(api_auth_token) = env_vars.get(ENV_NUMAFLOW_SERVING_AUTH_TOKEN) { settings.api_auth_token = Some(api_auth_token.to_owned()); } @@ -169,17 +135,6 @@ impl TryFrom> for Settings { })?; } - // If username is set, the password also must be set - if let Some(username) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_USER) { - let Some(password) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD) else { - return Err(ParseConfig(format!("Env variable {ENV_NUMAFLOW_SERVING_JETSTREAM_USER} is set, but {ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD} is not set"))); - }; - settings.jetstream.auth = Some(BasicAuth { - username: username.to_owned(), - password: password.to_owned(), - }); - } - // Update redis.ttl_secs from environment variable if let Some(ttl_secs) = env_vars.get(ENV_NUMAFLOW_SERVING_STORE_TTL) { let ttl_secs: u32 = ttl_secs.parse().map_err(|e| { @@ -213,17 +168,9 @@ impl TryFrom> for Settings { #[cfg(test)] mod tests { - use super::*; + use crate::pipeline::{Edge, Vertex}; - #[test] - fn test_basic_auth_debug_print() { - let auth = BasicAuth { - username: "js-auth-user".into(), - password: "js-auth-password".into(), - }; - let auth_debug = format!("{auth:?}"); - assert_eq!(auth_debug, "js-auth-user:***rd"); - } + use super::*; #[test] fn test_default_config() { @@ -234,8 +181,6 @@ mod tests { assert_eq!(settings.metrics_server_listen_port, 3001); assert_eq!(settings.upstream_addr, "localhost:8888"); assert_eq!(settings.drain_timeout_secs, 10); - assert_eq!(settings.jetstream.stream, "default"); - assert_eq!(settings.jetstream.url, "localhost:4222"); assert_eq!(settings.redis.addr, "redis://127.0.0.1:6379"); assert_eq!(settings.redis.max_tasks, 50); assert_eq!(settings.redis.retries, 5); @@ -246,21 +191,12 @@ mod tests { fn test_config_parse() { // Set up the environment variables let env_vars = [ - ( - ENV_NUMAFLOW_SERVING_JETSTREAM_URL, - "nats://isbsvc-default-js-svc.default.svc:4222", - ), - ( - ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM, - "ascii-art-pipeline-in-serving-source", - ), - (ENV_NUMAFLOW_SERVING_JETSTREAM_USER, "js-auth-user"), - (ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD, "js-user-password"), (ENV_NUMAFLOW_SERVING_HOST_IP, "10.2.3.5"), (ENV_NUMAFLOW_SERVING_AUTH_TOKEN, "api-auth-token"), (ENV_NUMAFLOW_SERVING_APP_PORT, "8443"), (ENV_NUMAFLOW_SERVING_STORE_TTL, "86400"), - (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0=") + (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0="), + (ENV_MIN_PIPELINE_SPEC, "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6InNlcnZpbmctaW4iLCJzb3VyY2UiOnsic2VydmluZyI6eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX19LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciIsImVudiI6W3sibmFtZSI6IlJVU1RfTE9HIiwidmFsdWUiOiJpbmZvIn1dfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2aW5nLXNpbmsiLCJzaW5rIjp7InVkc2luayI6eyJjb250YWluZXIiOnsiaW1hZ2UiOiJxdWF5LmlvL251bWFpby9udW1hZmxvdy1ycy9zaW5rLWxvZzpzdGFibGUiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9fX0sInJldHJ5U3RyYXRlZ3kiOnt9fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fV0sImVkZ2VzIjpbeyJmcm9tIjoic2VydmluZy1pbiIsInRvIjoic2VydmluZy1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH1dLCJsaWZlY3ljbGUiOnt9LCJ3YXRlcm1hcmsiOnt9fQ==") ]; // Call the config method @@ -277,14 +213,6 @@ mod tests { metrics_server_listen_port: 3001, upstream_addr: "localhost:8888".into(), drain_timeout_secs: 10, - jetstream: JetStreamConfig { - stream: "ascii-art-pipeline-in-serving-source".into(), - url: "nats://isbsvc-default-js-svc.default.svc:4222".into(), - auth: Some(BasicAuth { - username: "js-auth-user".into(), - password: "js-user-password".into(), - }), - }, redis: RedisConfig { addr: "redis://redis:6379".into(), max_tasks: 50, @@ -294,8 +222,22 @@ mod tests { }, host_ip: "10.2.3.5".into(), api_auth_token: Some("api-auth-token".into()), + pipeline_spec: PipelineDCG { + vertices: vec![ + Vertex { + name: "serving-in".into(), + }, + Vertex { + name: "serving-sink".into(), + }, + ], + edges: vec![Edge { + from: "serving-in".into(), + to: "serving-sink".into(), + conditions: None, + }], + }, }; - assert_eq!(settings, expected_config); } } diff --git a/rust/serving/src/error.rs b/rust/serving/src/error.rs index d53509c939..8d03c48234 100644 --- a/rust/serving/src/error.rs +++ b/rust/serving/src/error.rs @@ -1,4 +1,5 @@ use thiserror::Error; +use tokio::sync::oneshot; // TODO: introduce module level error handling @@ -44,6 +45,12 @@ pub enum Error { #[error("Init Error - {0}")] InitError(String), + #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] + ActorTaskTerminated(oneshot::error::RecvError), + + #[error("Serving source error - {0}")] + Source(String), + #[error("Other Error - {0}")] // catch-all variant for now Other(String), diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index 796313bdb2..001065ddfe 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -1,12 +1,13 @@ -use std::env; use std::net::SocketAddr; use std::sync::Arc; +use crate::app::callback::state::State as CallbackState; +use app::callback::store::Store; use axum_server::tls_rustls::RustlsConfig; +use tokio::sync::mpsc; use tracing::info; pub use self::error::{Error, Result}; -use self::pipeline::PipelineDCG; use crate::app::start_main_server; use crate::config::generate_certs; use crate::metrics::start_https_metrics_server; @@ -21,41 +22,43 @@ mod error; mod metrics; mod pipeline; -const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; +pub mod source; +use crate::source::MessageWrapper; +pub use source::{Message, ServingSource}; + +#[derive(Clone)] +pub(crate) struct AppState { + pub(crate) message: mpsc::Sender, + pub(crate) settings: Arc, + pub(crate) callback_state: CallbackState, +} + +pub(crate) async fn serve( + app: AppState, +) -> std::result::Result<(), Box> +where + T: Clone + Send + Sync + Store + 'static, +{ + // Setup the CryptoProvider (controls core cryptography used by rustls) for the process + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); -pub async fn serve( - settings: Arc, -) -> std::result::Result<(), Box> { let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) .await .map_err(|e| format!("Failed to create tls config {:?}", e))?; - // TODO: Move all env variables into one place. Some env variables are loaded when Settings is initialized - let pipeline_spec: PipelineDCG = env::var(ENV_MIN_PIPELINE_SPEC) - .map_err(|_| { - format!("Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}") - })? - .parse() - .map_err(|e| { - format!( - "Parsing pipeline spec: {}: error={e:?}", - env::var(ENV_MIN_PIPELINE_SPEC).unwrap() - ) - })?; - - info!(config = ?settings, ?pipeline_spec, "Starting server with config and pipeline spec"); + info!(config = ?app.settings, "Starting server with config and pipeline spec"); // Start the metrics server, which serves the prometheus metrics. let metrics_addr: SocketAddr = - format!("0.0.0.0:{}", &settings.metrics_server_listen_port).parse()?; + format!("0.0.0.0:{}", &app.settings.metrics_server_listen_port).parse()?; let metrics_server_handle = tokio::spawn(start_https_metrics_server(metrics_addr, tls_config.clone())); // Start the main server, which serves the application. - let app_server_handle = tokio::spawn(start_main_server(settings, tls_config, pipeline_spec)); + let app_server_handle = tokio::spawn(start_main_server(app, tls_config)); // TODO: is try_join the best? we need to short-circuit at the first failure tokio::try_join!(flatten(app_server_handle), flatten(metrics_server_handle))?; diff --git a/rust/serving/src/metrics.rs b/rust/serving/src/metrics.rs index 4c64760d4d..a605cc9988 100644 --- a/rust/serving/src/metrics.rs +++ b/rust/serving/src/metrics.rs @@ -175,6 +175,8 @@ mod tests { #[tokio::test] async fn test_start_metrics_server() -> Result<()> { + // Setup the CryptoProvider (controls core cryptography used by rustls) for the process + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) diff --git a/rust/serving/src/pipeline.rs b/rust/serving/src/pipeline.rs index d782e3d73a..cb491d7d88 100644 --- a/rust/serving/src/pipeline.rs +++ b/rust/serving/src/pipeline.rs @@ -10,7 +10,7 @@ use crate::Error::ParseConfig; // OperatorType is an enum that contains the types of operators // that can be used in the conditions for the edge. #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] -pub enum OperatorType { +pub(crate) enum OperatorType { #[serde(rename = "and")] And, #[serde(rename = "or")] @@ -42,40 +42,37 @@ impl From for OperatorType { } // Tag is a struct that contains the information about the tags for the edge -#[cfg_attr(test, derive(PartialEq))] -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Tag { - pub operator: Option, - pub values: Vec, +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub(crate) struct Tag { + pub(crate) operator: Option, + pub(crate) values: Vec, } // Conditions is a struct that contains the information about the conditions for the edge -#[cfg_attr(test, derive(PartialEq))] -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Conditions { - pub tags: Option, +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub(crate) struct Conditions { + pub(crate) tags: Option, } // Edge is a struct that contains the information about the edge in the pipeline. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Edge { - pub from: String, - pub to: String, - pub conditions: Option, +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub(crate) struct Edge { + pub(crate) from: String, + pub(crate) to: String, + pub(crate) conditions: Option, } /// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges /// from the pipeline spec -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde()] -pub struct PipelineDCG { - pub vertices: Vec, - pub edges: Vec, +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)] +pub(crate) struct PipelineDCG { + pub(crate) vertices: Vec, + pub(crate) edges: Vec, } -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Vertex { - pub name: String, +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub(crate) struct Vertex { + pub(crate) name: String, } impl FromStr for PipelineDCG { diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs new file mode 100644 index 0000000000..d038179672 --- /dev/null +++ b/rust/serving/src/source.rs @@ -0,0 +1,292 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::Instant; + +use crate::app::callback::state::State as CallbackState; +use crate::app::callback::store::redisstore::RedisConnection; +use crate::app::tracker::MessageGraph; +use crate::Settings; +use crate::{Error, Result}; + +/// [Message] with a oneshot for notifying when the message has been completed processed. +pub(crate) struct MessageWrapper { + // TODO: this might be more that saving to ISB. + pub(crate) confirm_save: oneshot::Sender<()>, + pub(crate) message: Message, +} + +/// Serving payload passed on to Numaflow. +#[derive(Debug)] +pub struct Message { + pub value: Bytes, + pub id: String, + pub headers: HashMap, +} + +enum ActorMessage { + Read { + batch_size: usize, + timeout_at: Instant, + reply_to: oneshot::Sender>>, + }, + Ack { + offsets: Vec, + reply_to: oneshot::Sender>, + }, +} + +/// Background actor that starts Axum server for accepting HTTP requests. +struct ServingSourceActor { + /// The HTTP handlers will put the message received from the payload to this channel + messages: mpsc::Receiver, + /// Channel for the actor handle to communicate with this actor + handler_rx: mpsc::Receiver, + /// Mapping from request's ID header (usually `X-Numaflow-Id` header) to a channel. + /// This sending a message on this channel notifies the HTTP handler function that the message + /// has been successfully processed. + tracker: HashMap>, + vertex_replica_id: u16, +} + +impl ServingSourceActor { + async fn start( + settings: Arc, + handler_rx: mpsc::Receiver, + request_channel_buffer_size: usize, + vertex_replica_id: u16, + ) -> Result<()> { + // Channel to which HTTP handlers will send request payload + let (messages_tx, messages_rx) = mpsc::channel(request_channel_buffer_size); + // Create a redis store to store the callbacks and the custom responses + let redis_store = RedisConnection::new(settings.redis.clone()).await?; + // Create the message graph from the pipeline spec and the redis store + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { + Error::InitError(format!( + "Creating message graph from pipeline spec: {:?}", + e + )) + })?; + let callback_state = CallbackState::new(msg_graph, redis_store).await?; + + tokio::spawn(async move { + let mut serving_actor = ServingSourceActor { + messages: messages_rx, + handler_rx, + tracker: HashMap::new(), + vertex_replica_id, + }; + serving_actor.run().await; + }); + let app = crate::AppState { + message: messages_tx, + settings, + callback_state, + }; + tokio::spawn(async move { + crate::serve(app).await.unwrap(); + }); + Ok(()) + } + + async fn run(&mut self) { + while let Some(msg) = self.handler_rx.recv().await { + self.handle_message(msg).await; + } + } + + async fn handle_message(&mut self, actor_msg: ActorMessage) { + match actor_msg { + ActorMessage::Read { + batch_size, + timeout_at, + reply_to, + } => { + let messages = self.read(batch_size, timeout_at).await; + let _ = reply_to.send(messages); + } + ActorMessage::Ack { offsets, reply_to } => { + let status = self.ack(offsets).await; + let _ = reply_to.send(status); + } + } + } + + async fn read(&mut self, count: usize, timeout_at: Instant) -> Result> { + let mut messages = vec![]; + loop { + // Stop if the read timeout has reached or if we have collected the requested number of messages + if messages.len() >= count || Instant::now() >= timeout_at { + break; + } + let next_msg = self.messages.recv(); + let message = match tokio::time::timeout_at(timeout_at, next_msg).await { + Ok(Some(msg)) => msg, + Ok(None) => { + // If we have collected at-least one message, we return those messages. + // The error will happen on all the subsequent read attempts too. + if messages.is_empty() { + return Err(Error::Other( + "Sending half of the Serving channel has disconnected".into(), + )); + } + tracing::error!("Sending half of the Serving channel has disconnected"); + return Ok(messages); + } + Err(_) => return Ok(messages), + }; + let MessageWrapper { + confirm_save, + message, + } = message; + + self.tracker.insert(message.id.clone(), confirm_save); + messages.push(message); + } + Ok(messages) + } + + async fn ack(&mut self, offsets: Vec) -> Result<()> { + let offset_suffix = format!("-{}", self.vertex_replica_id); + for offset in offsets { + let offset = offset.strip_suffix(&offset_suffix).ok_or_else(|| { + Error::Source(format!("offset does not end with '{}'", &offset_suffix)) + })?; + let confirm_save_tx = self + .tracker + .remove(offset) + .ok_or_else(|| Error::Source("offset was not found in the tracker".into()))?; + confirm_save_tx + .send(()) + .map_err(|e| Error::Source(format!("Sending on confirm_save channel: {e:?}")))?; + } + Ok(()) + } +} + +#[derive(Clone)] +pub struct ServingSource { + batch_size: usize, + // timeout for each batch read request + timeout: Duration, + actor_tx: mpsc::Sender, +} + +impl ServingSource { + pub async fn new( + settings: Arc, + batch_size: usize, + timeout: Duration, + vertex_replica_id: u16, + ) -> Result { + let (actor_tx, actor_rx) = mpsc::channel(2 * batch_size); + ServingSourceActor::start(settings, actor_rx, 2 * batch_size, vertex_replica_id).await?; + Ok(Self { + batch_size, + timeout, + actor_tx, + }) + } + + pub async fn read_messages(&self) -> Result> { + let start = Instant::now(); + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Read { + reply_to: tx, + batch_size: self.batch_size, + timeout_at: Instant::now() + self.timeout, + }; + let _ = self.actor_tx.send(actor_msg).await; + let messages = rx.await.map_err(Error::ActorTaskTerminated)??; + tracing::debug!( + count = messages.len(), + requested_count = self.batch_size, + time_taken_ms = start.elapsed().as_millis(), + "Got messages from Serving source" + ); + Ok(messages) + } + + pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Ack { + offsets, + reply_to: tx, + }; + let _ = self.actor_tx.send(actor_msg).await; + rx.await.map_err(Error::ActorTaskTerminated)??; + Ok(()) + } +} + +#[cfg(feature = "redis-tests")] +#[cfg(test)] +mod tests { + use std::{sync::Arc, time::Duration}; + + use crate::Settings; + + use super::ServingSource; + + type Result = std::result::Result>; + #[tokio::test] + async fn test_serving_source() -> Result<()> { + let settings = Arc::new(Settings::default()); + let serving_source = + ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1), 0).await?; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); + + // Wait for the server + for _ in 0..10 { + let resp = client + .get(format!( + "https://localhost:{}/livez", + settings.app_listen_port + )) + .send() + .await; + if resp.is_ok() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut messages = serving_source.read_messages().await.unwrap(); + if messages.is_empty() { + // Server has not received any requests yet + continue; + } + assert_eq!(messages.len(), 1); + let msg = messages.remove(0); + serving_source + .ack_messages(vec![format!("{}-0", msg.id)]) + .await + .unwrap(); + break; + } + }); + + let resp = client + .post(format!( + "https://localhost:{}/v1/process/async", + settings.app_listen_port + )) + .json("test-payload") + .send() + .await?; + + assert!(resp.status().is_success()); + Ok(()) + } +}