From e9e721a06c4aaceafeb67b3c2272372a2c6f30a3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 18 Apr 2025 11:07:07 +0200 Subject: [PATCH 1/2] Fixing the CI (grpc path). --- router/src/http/server.rs | 4 +- router/src/logging.rs | 102 ++++++++++++++++++++------------------ router/tests/common.rs | 1 + 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 506c9fa5..f805744a 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1833,7 +1833,9 @@ pub async fn run( .layer(Extension(info)) .layer(Extension(prom_handle.clone())) .layer(OtelAxumLayer::default()) - .layer(axum::middleware::from_fn(logging::trace_context_middleware)) + .layer(axum::middleware::from_fn( + logging::http::trace_context_middleware, + )) .layer(DefaultBodyLimit::max(payload_limit)) .layer(cors_layer); diff --git a/router/src/logging.rs b/router/src/logging.rs index 4ac4db1e..43306499 100644 --- a/router/src/logging.rs +++ b/router/src/logging.rs @@ -1,6 +1,3 @@ -use axum::{extract::Request, middleware::Next, response::Response}; -use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId}; -use opentelemetry::Context; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; use opentelemetry_sdk::propagation::TraceContextPropagator; @@ -10,56 +7,63 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; -struct TraceParent { - #[allow(dead_code)] - version: u8, - trace_id: TraceId, - parent_id: SpanId, - trace_flags: TraceFlags, -} - -fn parse_traceparent(header_value: &str) -> Option { - let parts: Vec<&str> = header_value.split('-').collect(); - if parts.len() != 4 { - return None; +#[cfg(feature = "http")] +pub mod http { + use axum::{extract::Request, middleware::Next, response::Response}; + use opentelemetry::trace::{SpanContext, TraceContextExt}; + use opentelemetry::trace::{SpanId, TraceFlags, TraceId}; + use opentelemetry::Context; + struct TraceParent { + #[allow(dead_code)] + version: u8, + trace_id: TraceId, + parent_id: SpanId, + trace_flags: TraceFlags, } - let version = u8::from_str_radix(parts[0], 16).ok()?; - if version == 0xff { - return None; + fn parse_traceparent(header_value: &str) -> Option { + let parts: Vec<&str> = header_value.split('-').collect(); + if parts.len() != 4 { + return None; + } + + let version = u8::from_str_radix(parts[0], 16).ok()?; + if version == 0xff { + return None; + } + + let trace_id = TraceId::from_hex(parts[1]).ok()?; + let parent_id = SpanId::from_hex(parts[2]).ok()?; + let trace_flags = u8::from_str_radix(parts[3], 16).ok()?; + + Some(TraceParent { + version, + trace_id, + parent_id, + trace_flags: TraceFlags::new(trace_flags), + }) } - let trace_id = TraceId::from_hex(parts[1]).ok()?; - let parent_id = SpanId::from_hex(parts[2]).ok()?; - let trace_flags = u8::from_str_radix(parts[3], 16).ok()?; - - Some(TraceParent { - version, - trace_id, - parent_id, - trace_flags: TraceFlags::new(trace_flags), - }) -} - -pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response { - let context = request - .headers() - .get("traceparent") - .and_then(|v| v.to_str().ok()) - .and_then(parse_traceparent) - .map(|traceparent| { - Context::new().with_remote_span_context(SpanContext::new( - traceparent.trace_id, - traceparent.parent_id, - traceparent.trace_flags, - true, - Default::default(), - )) - }); - - request.extensions_mut().insert(context); - - next.run(request).await + pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response { + let context = request + .headers() + .get("traceparent") + .and_then(|v| v.to_str().ok()) + .and_then(parse_traceparent) + .map(|traceparent| { + Context::new().with_remote_span_context(SpanContext::new( + traceparent.trace_id, + traceparent.parent_id, + traceparent.trace_flags, + true, + Default::default(), + )) + }); + + request.extensions_mut().insert(context); + + next.run(request).await + } } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: diff --git a/router/tests/common.rs b/router/tests/common.rs index 55fdf5f5..1ae4333b 100644 --- a/router/tests/common.rs +++ b/router/tests/common.rs @@ -67,6 +67,7 @@ pub async fn start_server(model_id: String, revision: Option, dtype: DTy None, None, "text-embeddings-inference.server".to_owned(), + 9000, None, ) }); From f851d77426940d394383e9c8d7a188a485447cc8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 18 Apr 2025 10:47:38 +0200 Subject: [PATCH 2/2] Warmup padded models too. --- router/src/lib.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 66b1e240..f1b8ba26 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -249,13 +249,11 @@ pub async fn run( .await .context("Model backend is not healthy")?; - if !backend.padded_model { - tracing::info!("Warming up model"); - backend - .warmup(max_input_length, max_batch_tokens, max_batch_requests) - .await - .context("Model backend is not healthy")?; - } + tracing::info!("Warming up model"); + backend + .warmup(max_input_length, max_batch_tokens, max_batch_requests) + .await + .context("Model backend is not healthy")?; let max_batch_requests = backend .max_batch_size