From 2d2986b81307b827dcd375a99258d8a6922de363 Mon Sep 17 00:00:00 2001 From: "Nathan (Blaise) Bruer" Date: Mon, 27 Jan 2025 15:07:30 -0600 Subject: [PATCH] Add `Closed` stream event to OriginEvents (#1570) Adds an event for when a stream is closed. --- .../nativelink/remote_execution/events.proto | 3 +- ...thub.trace_machina.nativelink.events.pb.rs | 4 +- nativelink-util/src/origin_event.rs | 130 +++++++++++++++--- nativelink-util/tests/origin_event_test.rs | 2 + 4 files changed, 117 insertions(+), 22 deletions(-) diff --git a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto index bb52a38c9..d087955d5 100644 --- a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto +++ b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto @@ -108,9 +108,10 @@ message StreamEvent { uint64 data_length = 3; WriteRequestOverride write_request = 4; google.longrunning.Operation operation = 5; + google.protobuf.Empty closed = 6; } - reserved 6; // NextId. + reserved 7; // NextId. } message Event { diff --git a/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs b/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs index f14b46676..22a8662f1 100644 --- a/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs +++ b/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs @@ -197,7 +197,7 @@ pub mod response_event { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct StreamEvent { - #[prost(oneof = "stream_event::Event", tags = "1, 2, 3, 4, 5")] + #[prost(oneof = "stream_event::Event", tags = "1, 2, 3, 4, 5, 6")] pub event: ::core::option::Option, } /// Nested message and enum types in `StreamEvent`. @@ -218,6 +218,8 @@ pub mod stream_event { Operation( super::super::super::super::super::super::google::longrunning::Operation, ), + #[prost(message, tag = "6")] + Closed(()), } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/nativelink-util/src/origin_event.rs b/nativelink-util/src/origin_event.rs index 6878af23a..fd8350073 100644 --- a/nativelink-util/src/origin_event.rs +++ b/nativelink-util/src/origin_event.rs @@ -17,6 +17,7 @@ use std::pin::Pin; use std::sync::{Arc, OnceLock}; use futures::future::ready; +use futures::task::{Context, Poll}; use futures::{Future, FutureExt, Stream, StreamExt}; use nativelink_proto::build::bazel::remote::execution::v2::{ ActionResult, BatchReadBlobsRequest, BatchReadBlobsResponse, BatchUpdateBlobsRequest, @@ -35,13 +36,15 @@ use nativelink_proto::google::bytestream::{ }; use nativelink_proto::google::longrunning::Operation; use nativelink_proto::google::rpc::Status; +use pin_project_lite::pin_project; use rand::RngCore; use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; use tonic::{Response, Status as TonicStatus, Streaming}; use uuid::Uuid; -use crate::make_symbol; use crate::origin_context::ActiveOriginContext; +use crate::{background_spawn, make_symbol}; const ORIGIN_EVENT_VERSION: u32 = 0; @@ -90,6 +93,7 @@ pub fn get_id_for_event(event: &Event) -> [u8; 2] { Some(stream_event::Event::DataLength(_)) => [0x03, 0x03], Some(stream_event::Event::WriteRequest(_)) => [0x03, 0x04], Some(stream_event::Event::Operation(_)) => [0x03, 0x05], + Some(stream_event::Event::Closed(())) => [0x03, 0x06], // Special case when stream has terminated. }, } } @@ -130,11 +134,17 @@ impl OriginEventCollector { } } - async fn publish_origin_event(&self, event: Event, parent_event_id: Option) -> Uuid { - let event_id = Uuid::now_v6(&get_node_id(Some(&event))); + /// Publishes an event to the origin event collector. + async fn publish_origin_event( + &self, + event: Event, + parent_event_id: Option, + maybe_event_id: Option, + ) -> Uuid { + let event_id = maybe_event_id.unwrap_or_else(|| Uuid::now_v6(&get_node_id(Some(&event)))); let parent_event_id = parent_event_id.map_or_else(String::new, |id| id.as_hyphenated().to_string()); - // Failing to send this event means that the receiver has been dropped. + // Ignore cases when channel is dropped. let _ = self .sender .send(OriginEvent { @@ -148,6 +158,78 @@ impl OriginEventCollector { .await; event_id } + + /// Publishes an event to the origin event collector. + /// If the buffer is full, the event will be sent in a background spawn. + /// This is useful for cases where the event is critical and must be sent, + /// but cannot await the send operation. + fn publish_origin_event_now_or_in_spawn(&self, event: Event, parent_event_id: Option) { + let event_id = Uuid::now_v6(&get_node_id(Some(&event))); + let parent_event_id = + parent_event_id.map_or_else(String::new, |id| id.as_hyphenated().to_string()); + self.sender + .try_send(OriginEvent { + version: ORIGIN_EVENT_VERSION, + event_id: event_id.as_hyphenated().to_string(), + parent_event_id, + bazel_request_metadata: self.bazel_metadata.clone(), + identity: self.identity.clone(), + event: Some(event), + }) + .map_or_else( + |e| match e { + TrySendError::Full(event) => { + let sender = self.sender.clone(); + background_spawn!("send_end_stream_origin_event", async move { + // Ignore cases when channel is dropped. + let _ = sender.send(event).await; + }); + } + // Ignore cases when channel is dropped. + TrySendError::Closed(_) => {} + }, + |()| {}, + ); + } +} + +pin_project! { + struct CloseEventStream { + #[pin] + inner: S, + ctx_impl: OriginEventContextImpl, + } + + impl PinnedDrop for CloseEventStream { + #[inline] + fn drop(this: Pin<&mut Self>) { + let event = Event { + event: Some(event::Event::Stream(StreamEvent { + event: Some(stream_event::Event::Closed(())) + })), + }; + // Try to send the event immediately, if we cannot because + // the buffer is full, do it in a background spawn. + this.ctx_impl.origin_event_collector + .publish_origin_event_now_or_in_spawn(event, Some(this.ctx_impl.parent_event_id)); + } + } +} + +impl Stream for CloseEventStream +where + S: Stream + Send, +{ + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let project = self.project(); + project.inner.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } } make_symbol!(ORIGIN_EVENT_COLLECTOR, OriginEventCollector); @@ -185,7 +267,7 @@ impl OriginEventContext<()> { let event = source_cb().as_event(); async move { let parent_event_id = origin_event_collector - .publish_origin_event(event, None) + .publish_origin_event(event, None, None) .await; OriginEventContext { inner: Some(OriginEventContextImpl { @@ -220,17 +302,20 @@ impl OriginEventContext { O: OriginEventSource + Send + 'a, S: Stream + Send + 'a, { - if self.inner.is_none() { + let Some(ctx_impl) = self.inner.clone() else { return Box::pin(stream); - } + }; let ctx = self.clone(); - Box::pin(stream.then(move |item| { - let ctx = ctx.clone(); - async move { - ctx.emit(|| &item).await; - item - } - })) + Box::pin(CloseEventStream { + inner: stream.then(move |item| { + let ctx = ctx.clone(); + async move { + ctx.emit(|| &item).await; + item + } + }), + ctx_impl, + }) } } @@ -246,7 +331,7 @@ pub trait OriginEventSource: Sized { fn publish<'a>(&self, ctx: &'a OriginEventContextImpl) -> impl Future + Send + 'a { let event = self.as_event(); ctx.origin_event_collector - .publish_origin_event(event, Some(ctx.parent_event_id)) + .publish_origin_event(event, Some(ctx.parent_event_id), None) // We don't need the Uuid here. .map(|_| ()) } @@ -400,10 +485,15 @@ fn to_batch_read_blobs_response_override(val: &BatchReadBlobsResponse) -> event: } #[inline] -fn to_empty(_: T) -> event::Event { +fn to_empty_response(_: T) -> event::Event { get_event_type!(Response, Empty, ()) } +#[inline] +fn to_empty_write_request(_: T) -> event::Event { + get_event_type!(Request, WriteRequest, ()) +} + // -- Requests -- impl_as_event! {Request, (), GetCapabilitiesRequest} @@ -414,7 +504,7 @@ impl_as_event! {Request, (), BatchReadBlobsRequest} impl_as_event! {Request, (), BatchUpdateBlobsRequest, BatchUpdateBlobsRequest, to_batch_update_blobs_request_override} impl_as_event! {Request, (), GetTreeRequest} impl_as_event! {Request, (), ReadRequest} -impl_as_event! {Request, (), Streaming, WriteRequest, to_empty} +impl_as_event! {Request, (), Streaming, WriteRequest, to_empty_write_request} impl_as_event! {Request, (), QueryWriteStatusRequest} impl_as_event! {Request, (), ExecuteRequest} impl_as_event! {Request, (), WaitExecutionRequest} @@ -425,13 +515,13 @@ impl_as_event! {Response, GetCapabilitiesRequest, ServerCapabilities} impl_as_event! {Response, GetActionResultRequest, ActionResult} impl_as_event! {Response, UpdateActionResultRequest, ActionResult} impl_as_event! {Response, Streaming, WriteResponse} -impl_as_event! {Response, ReadRequest, Pin> + Send + '_>>, Empty, to_empty} +impl_as_event! {Response, ReadRequest, Pin> + Send + '_>>, Empty, to_empty_response} impl_as_event! {Response, QueryWriteStatusRequest, QueryWriteStatusResponse} impl_as_event! {Response, FindMissingBlobsRequest, FindMissingBlobsResponse} impl_as_event! {Response, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse} impl_as_event! {Response, BatchReadBlobsRequest, BatchReadBlobsResponse, BatchReadBlobsResponseOverride, to_batch_read_blobs_response_override} -impl_as_event! {Response, GetTreeRequest, Pin> + Send + '_>>, Empty, to_empty} -impl_as_event! {Response, ExecuteRequest, Pin> + Send + '_>>, Empty, to_empty} +impl_as_event! {Response, GetTreeRequest, Pin> + Send + '_>>, Empty, to_empty_response} +impl_as_event! {Response, ExecuteRequest, Pin> + Send + '_>>, Empty, to_empty_response} // -- Streams -- diff --git a/nativelink-util/tests/origin_event_test.rs b/nativelink-util/tests/origin_event_test.rs index 7e845ebda..d08c99576 100644 --- a/nativelink-util/tests/origin_event_test.rs +++ b/nativelink-util/tests/origin_event_test.rs @@ -121,6 +121,7 @@ fn get_id_for_event_test() { Some(stream_event::Event::DataLength(_)) => [0x03, 0x03], Some(stream_event::Event::WriteRequest(_)) => [0x03, 0x04], Some(stream_event::Event::Operation(_)) => [0x03, 0x05], + Some(stream_event::Event::Closed(())) => [0x03, 0x06], // Don't forget to add new entries to test cases. } } @@ -161,4 +162,5 @@ fn get_id_for_event_test() { test_event!(Stream, DataLength); test_event!(Stream, WriteRequest); test_event!(Stream, Operation); + test_event!(Stream, Closed); }