diff --git a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto index d087955d5..d9d52e8b2 100644 --- a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto +++ b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/events.proto @@ -16,6 +16,7 @@ syntax = "proto3"; package com.github.trace_machina.nativelink.events; +import "com/github/trace_machina/nativelink/remote_execution/worker_api.proto"; import "build/bazel/remote/execution/v2/remote_execution.proto"; import "google/bytestream/bytestream.proto"; import "google/devtools/build/v1/publish_build_event.proto"; @@ -81,7 +82,10 @@ message RequestEvent { google.bytestream.QueryWriteStatusRequest query_write_status_request = 10; build.bazel.remote.execution.v2.ExecuteRequest execute_request = 11; build.bazel.remote.execution.v2.WaitExecutionRequest wait_execution_request = 12; + + com.github.trace_machina.nativelink.remote_execution.StartExecute scheduler_start_execute = 13; } + reserved 14; // NextId. } message ResponseEvent { diff --git a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto index f3adbd3e9..d62ed69e0 100644 --- a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto +++ b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto @@ -169,7 +169,14 @@ message StartExecute { /// of the ActionResult. google.protobuf.Timestamp queued_timestamp = 3; - reserved 5; // NextId. + /// The post-computed platform properties that the scheduler has reserved for + /// the action. + build.bazel.remote.execution.v2.Platform platform = 5; + + /// The ID of the worker that is executing the action. + string worker_id = 6; + + reserved 7; // NextId. } /// This is a special message used to save actions into the CAS that can be used diff --git a/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs b/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs index 22a8662f1..1532bb633 100644 --- a/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs +++ b/nativelink-proto/genproto/com.github.trace_machina.nativelink.events.pb.rs @@ -100,7 +100,7 @@ pub struct WriteRequestOverride { pub struct RequestEvent { #[prost( oneof = "request_event::Event", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13" )] pub event: ::core::option::Option, } @@ -152,6 +152,8 @@ pub mod request_event { WaitExecutionRequest( super::super::super::super::super::super::build::bazel::remote::execution::v2::WaitExecutionRequest, ), + #[prost(message, tag = "13")] + SchedulerStartExecute(super::super::remote_execution::StartExecute), } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs index 5c6b1ba3d..bc4622560 100644 --- a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs +++ b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs @@ -143,6 +143,15 @@ pub struct StartExecute { /// / of the ActionResult. #[prost(message, optional, tag = "3")] pub queued_timestamp: ::core::option::Option<::prost_types::Timestamp>, + /// / The post-computed platform properties that the scheduler has reserved for + /// / the action. + #[prost(message, optional, tag = "5")] + pub platform: ::core::option::Option< + super::super::super::super::super::build::bazel::remote::execution::v2::Platform, + >, + /// / The ID of the worker that is executing the action. + #[prost(string, tag = "6")] + pub worker_id: ::prost::alloc::string::String, } /// / This is a special message used to save actions into the CAS that can be used /// / by programs like bb_browswer to inspect the history of a build. diff --git a/nativelink-scheduler/src/api_worker_scheduler.rs b/nativelink-scheduler/src/api_worker_scheduler.rs index 60b8dfdac..e689a44dc 100644 --- a/nativelink-scheduler/src/api_worker_scheduler.rs +++ b/nativelink-scheduler/src/api_worker_scheduler.rs @@ -250,7 +250,7 @@ impl ApiWorkerSchedulerImpl { let was_paused = !worker.can_accept_work(); // Note: We need to run this before dealing with backpressure logic. - let complete_action_res = worker.complete_action(operation_id); + let complete_action_res = worker.complete_action(operation_id).await; // Only pause if there's an action still waiting that will unpause. if (was_paused || due_to_backpressure) && worker.has_actions() { @@ -273,8 +273,9 @@ impl ApiWorkerSchedulerImpl { action_info: ActionInfoWithProps, ) -> Result<(), Error> { if let Some(worker) = self.workers.get_mut(&worker_id) { - let notify_worker_result = - worker.notify_update(WorkerUpdate::RunAction((operation_id, action_info.clone()))); + let notify_worker_result = worker + .notify_update(WorkerUpdate::RunAction((operation_id, action_info.clone()))) + .await; if notify_worker_result.is_err() { event!( @@ -314,7 +315,7 @@ impl ApiWorkerSchedulerImpl { let mut result = Ok(()); if let Some(mut worker) = self.remove_worker(worker_id) { // We don't care if we fail to send message to worker, this is only a best attempt. - let _ = worker.notify_update(WorkerUpdate::Disconnect); + let _ = worker.notify_update(WorkerUpdate::Disconnect).await; for (operation_id, _) in worker.running_action_infos.drain() { result = result.merge( self.worker_state_manager diff --git a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs index 9267c96fb..c654df85f 100644 --- a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs +++ b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs @@ -22,6 +22,8 @@ use nativelink_metric::{ use nativelink_util::action_messages::{ ActionInfo, ActionStage, ActionState, OperationId, WorkerId, }; +use nativelink_util::origin_context::ActiveOriginContext; +use nativelink_util::origin_event::{OriginMetadata, ORIGIN_EVENT_COLLECTOR}; use serde::{Deserialize, Serialize}; use static_assertions::{assert_eq_size, const_assert, const_assert_eq}; @@ -78,6 +80,9 @@ pub struct AwaitedAction { #[metric(help = "The state of the AwaitedAction")] state: Arc, + /// The origin metadata of the action. + maybe_origin_metadata: Option, + /// Number of attempts the job has been tried. #[metric(help = "The number of attempts the AwaitedAction has been tried")] pub attempts: usize, @@ -100,6 +105,11 @@ impl AwaitedAction { client_operation_id: operation_id.clone(), action_digest: action_info.unique_qualifier.digest(), }); + let maybe_origin_metadata = ActiveOriginContext::get_value(&ORIGIN_EVENT_COLLECTOR) + .ok() + .flatten() + .map(|v| v.metadata.clone()); + Self { version: AwaitedActionVersion(0), action_info, @@ -108,6 +118,7 @@ impl AwaitedAction { attempts: 0, last_worker_updated_timestamp: now, last_client_keepalive_timestamp: now, + maybe_origin_metadata, worker_id: None, state, } @@ -141,6 +152,10 @@ impl AwaitedAction { &self.state } + pub(crate) fn maybe_origin_metadata(&self) -> Option<&OriginMetadata> { + self.maybe_origin_metadata.as_ref() + } + pub(crate) fn worker_id(&self) -> Option { self.worker_id } diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index 7aaa6822a..7f6cca9be 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -33,6 +33,8 @@ use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProv use nativelink_util::operation_state_manager::{ ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, }; +use nativelink_util::origin_context::ActiveOriginContext; +use nativelink_util::origin_event::{OriginMetadata, ORIGIN_EVENT_COLLECTOR}; use nativelink_util::store_trait::Store; use parking_lot::{Mutex, MutexGuard}; use scopeguard::guard; @@ -109,16 +111,20 @@ fn subscribe_to_existing_action( struct CacheLookupActionStateResult { action_state: Arc, + maybe_origin_metadata: Option, change_called: bool, } #[async_trait] impl ActionStateResult for CacheLookupActionStateResult { - async fn as_state(&self) -> Result, Error> { - Ok(self.action_state.clone()) + async fn as_state(&self) -> Result<(Arc, Option), Error> { + Ok(( + self.action_state.clone(), + self.maybe_origin_metadata.clone(), + )) } - async fn changed(&mut self) -> Result, Error> { + async fn changed(&mut self) -> Result<(Arc, Option), Error> { if self.change_called { return Err(make_err!( Code::Internal, @@ -126,10 +132,13 @@ impl ActionStateResult for CacheLookupActionStateResult { )); } self.change_called = true; - Ok(self.action_state.clone()) + Ok(( + self.action_state.clone(), + self.maybe_origin_metadata.clone(), + )) } - async fn as_action_info(&self) -> Result, Error> { + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { // TODO(allada) We should probably remove as_action_info() // or implement it properly. return Err(make_err!( @@ -251,11 +260,17 @@ impl CacheLookupScheduler { action_digest: action_info.unique_qualifier.digest(), }; + let maybe_origin_metadata = + ActiveOriginContext::get_value(&ORIGIN_EVENT_COLLECTOR) + .ok() + .flatten() + .map(|v| v.metadata.clone()); for (client_operation_id, pending_tx) in pending_txs { action_state.client_operation_id = client_operation_id; // Ignore errors here, as the other end may have hung up. let _ = pending_tx.send(Ok(Box::new(CacheLookupActionStateResult { action_state: Arc::new(action_state.clone()), + maybe_origin_metadata: maybe_origin_metadata.clone(), change_called: false, }))); } diff --git a/nativelink-scheduler/src/default_scheduler_factory.rs b/nativelink-scheduler/src/default_scheduler_factory.rs index b66ca84fc..69fd47c7c 100644 --- a/nativelink-scheduler/src/default_scheduler_factory.rs +++ b/nativelink-scheduler/src/default_scheduler_factory.rs @@ -20,11 +20,12 @@ use nativelink_config::schedulers::{ }; use nativelink_config::stores::EvictionPolicy; use nativelink_error::{make_input_err, Error, ResultExt}; +use nativelink_proto::com::github::trace_machina::nativelink::events::OriginEvent; use nativelink_store::redis_store::RedisStore; use nativelink_store::store_manager::StoreManager; use nativelink_util::instant_wrapper::InstantWrapper; use nativelink_util::operation_state_manager::ClientStateManager; -use tokio::sync::Notify; +use tokio::sync::{mpsc, Notify}; use crate::cache_lookup_scheduler::CacheLookupScheduler; use crate::grpc_scheduler::GrpcScheduler; @@ -46,17 +47,19 @@ pub type SchedulerFactoryResults = ( pub fn scheduler_factory( spec: &SchedulerSpec, store_manager: &StoreManager, + maybe_origin_event_tx: Option<&mpsc::Sender>, ) -> Result { - inner_scheduler_factory(spec, store_manager) + inner_scheduler_factory(spec, store_manager, maybe_origin_event_tx) } fn inner_scheduler_factory( spec: &SchedulerSpec, store_manager: &StoreManager, + maybe_origin_event_tx: Option<&mpsc::Sender>, ) -> Result { let scheduler: SchedulerFactoryResults = match spec { SchedulerSpec::simple(spec) => { - simple_scheduler_factory(spec, store_manager, SystemTime::now)? + simple_scheduler_factory(spec, store_manager, SystemTime::now, maybe_origin_event_tx)? } SchedulerSpec::grpc(spec) => (Some(Arc::new(GrpcScheduler::new(spec)?)), None), SchedulerSpec::cache_lookup(spec) => { @@ -64,7 +67,7 @@ fn inner_scheduler_factory( .get_store(&spec.ac_store) .err_tip(|| format!("'ac_store': '{}' does not exist", spec.ac_store))?; let (action_scheduler, worker_scheduler) = - inner_scheduler_factory(&spec.scheduler, store_manager) + inner_scheduler_factory(&spec.scheduler, store_manager, maybe_origin_event_tx) .err_tip(|| "In nested CacheLookupScheduler construction")?; let cache_lookup_scheduler = Arc::new(CacheLookupScheduler::new( ac_store, @@ -74,7 +77,7 @@ fn inner_scheduler_factory( } SchedulerSpec::property_modifier(spec) => { let (action_scheduler, worker_scheduler) = - inner_scheduler_factory(&spec.scheduler, store_manager) + inner_scheduler_factory(&spec.scheduler, store_manager, maybe_origin_event_tx) .err_tip(|| "In nested PropertyModifierScheduler construction")?; let property_modifier_scheduler = Arc::new(PropertyModifierScheduler::new( spec, @@ -91,6 +94,7 @@ fn simple_scheduler_factory( spec: &SimpleSpec, store_manager: &StoreManager, now_fn: fn() -> SystemTime, + maybe_origin_event_tx: Option<&mpsc::Sender>, ) -> Result { match spec .experimental_backend @@ -104,8 +108,12 @@ fn simple_scheduler_factory( &task_change_notify.clone(), SystemTime::now, ); - let (action_scheduler, worker_scheduler) = - SimpleScheduler::new(spec, awaited_action_db, task_change_notify); + let (action_scheduler, worker_scheduler) = SimpleScheduler::new( + spec, + awaited_action_db, + task_change_notify, + maybe_origin_event_tx.cloned(), + ); Ok((Some(action_scheduler), Some(worker_scheduler))) } ExperimentalSimpleSchedulerBackend::redis(redis_config) => { @@ -134,8 +142,12 @@ fn simple_scheduler_factory( Default::default, ) .err_tip(|| "In state_manager_factory::redis_state_manager")?; - let (action_scheduler, worker_scheduler) = - SimpleScheduler::new(spec, awaited_action_db, task_change_notify); + let (action_scheduler, worker_scheduler) = SimpleScheduler::new( + spec, + awaited_action_db, + task_change_notify, + maybe_origin_event_tx.cloned(), + ); Ok((Some(action_scheduler), Some(worker_scheduler))) } } diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 9d308d739..63d3455d8 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -37,6 +37,7 @@ use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProv use nativelink_util::operation_state_manager::{ ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, }; +use nativelink_util::origin_event::OriginMetadata; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::{background_spawn, tls_utils}; use parking_lot::Mutex; @@ -55,13 +56,15 @@ struct GrpcActionStateResult { #[async_trait] impl ActionStateResult for GrpcActionStateResult { - async fn as_state(&self) -> Result, Error> { + async fn as_state(&self) -> Result<(Arc, Option), Error> { let mut action_state = self.rx.borrow().clone(); Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); - Ok(action_state) + // TODO(allada) We currently don't support OriginMetadata in this implementation, but + // we should. + Ok((action_state, None)) } - async fn changed(&mut self) -> Result, Error> { + async fn changed(&mut self) -> Result<(Arc, Option), Error> { self.rx.changed().await.map_err(|_| { make_err!( Code::Internal, @@ -70,10 +73,12 @@ impl ActionStateResult for GrpcActionStateResult { })?; let mut action_state = self.rx.borrow().clone(); Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); - Ok(action_state) + // TODO(allada) We currently don't support OriginMetadata in this implementation, but + // we should. + Ok((action_state, None)) } - async fn as_action_info(&self) -> Result, Error> { + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { // TODO(allada) We should probably remove as_action_info() // or implement it properly. return Err(make_err!( diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 744b27903..c91db82fc 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -16,10 +16,11 @@ use std::sync::Arc; use std::time::SystemTime; use async_trait::async_trait; -use futures::Future; +use futures::{Future, FutureExt}; use nativelink_config::schedulers::SimpleSpec; use nativelink_error::{Code, Error, ResultExt}; use nativelink_metric::{MetricsComponent, RootMetricsComponent}; +use nativelink_proto::com::github::trace_machina::nativelink::events::OriginEvent; use nativelink_util::action_messages::{ActionInfo, ActionState, OperationId, WorkerId}; use nativelink_util::instant_wrapper::InstantWrapper; use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider; @@ -27,12 +28,14 @@ use nativelink_util::operation_state_manager::{ ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, OperationFilter, OperationStageFlags, OrderDirection, UpdateOperationType, }; +use nativelink_util::origin_context::ActiveOriginContext; +use nativelink_util::origin_event::{OriginEventCollector, OriginMetadata, ORIGIN_EVENT_COLLECTOR}; use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::Notify; +use tokio::sync::{mpsc, Notify}; use tokio::time::Duration; use tokio_stream::StreamExt; -use tracing::{event, Level}; +use tracing::{event, info_span, Level}; use crate::api_worker_scheduler::ApiWorkerScheduler; use crate::awaited_action_db::AwaitedActionDb; @@ -73,8 +76,8 @@ impl SimpleSchedulerActionStateResult { #[async_trait] impl ActionStateResult for SimpleSchedulerActionStateResult { - async fn as_state(&self) -> Result, Error> { - let mut action_state = self + async fn as_state(&self) -> Result<(Arc, Option), Error> { + let (mut action_state, origin_metadata) = self .action_state_result .as_state() .await @@ -82,11 +85,11 @@ impl ActionStateResult for SimpleSchedulerActionStateResult { // We need to ensure the client is not aware of the downstream // operation id, so override it before it goes out. Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); - Ok(action_state) + Ok((action_state, origin_metadata)) } - async fn changed(&mut self) -> Result, Error> { - let mut action_state = self + async fn changed(&mut self) -> Result<(Arc, Option), Error> { + let (mut action_state, origin_metadata) = self .action_state_result .changed() .await @@ -94,10 +97,10 @@ impl ActionStateResult for SimpleSchedulerActionStateResult { // We need to ensure the client is not aware of the downstream // operation id, so override it before it goes out. Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); - Ok(action_state) + Ok((action_state, origin_metadata)) } - async fn as_action_info(&self) -> Result, Error> { + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { self.action_state_result .as_action_info() .await @@ -127,6 +130,9 @@ pub struct SimpleScheduler { #[metric(group = "worker_scheduler")] worker_scheduler: Arc, + /// The sender to send origin events to the origin events. + maybe_origin_event_tx: Option>, + /// Background task that tries to match actions to workers. If this struct /// is dropped the spawn will be cancelled as well. _task_worker_matching_spawn: JoinHandleDropGuard<()>, @@ -191,11 +197,13 @@ impl SimpleScheduler { workers: &ApiWorkerScheduler, matching_engine_state_manager: &dyn MatchingEngineStateManager, platform_property_manager: &PlatformPropertyManager, + maybe_origin_event_tx: Option<&mpsc::Sender>, ) -> Result<(), Error> { - let action_info = action_state_result - .as_action_info() - .await - .err_tip(|| "Failed to get action_info from as_action_info_result stream")?; + let (action_info, maybe_origin_metadata) = + action_state_result + .as_action_info() + .await + .err_tip(|| "Failed to get action_info from as_action_info_result stream")?; // TODO(allada) We should not compute this every time and instead store // it with the ActionInfo when we receive it. @@ -223,39 +231,62 @@ impl SimpleScheduler { } }; - // Extract the operation_id from the action_state. - let operation_id = { - let action_state = action_state_result - .as_state() + let attach_operation_fut = async move { + // Extract the operation_id from the action_state. + let operation_id = { + let (action_state, _origin_metadata) = action_state_result + .as_state() + .await + .err_tip(|| "Failed to get action_info from as_state_result stream")?; + action_state.client_operation_id.clone() + }; + + // Tell the matching engine that the operation is being assigned to a worker. + let assign_result = matching_engine_state_manager + .assign_operation(&operation_id, Ok(&worker_id)) .await - .err_tip(|| "Failed to get action_info from as_state_result stream")?; - action_state.client_operation_id.clone() - }; - - // Tell the matching engine that the operation is being assigned to a worker. - let assign_result = matching_engine_state_manager - .assign_operation(&operation_id, Ok(&worker_id)) - .await - .err_tip(|| "Failed to assign operation in do_try_match"); - if let Err(err) = assign_result { - if err.code == Code::Aborted { - // If the operation was aborted, it means that the operation was - // cancelled due to another operation being assigned to the worker. - return Ok(()); + .err_tip(|| "Failed to assign operation in do_try_match"); + if let Err(err) = assign_result { + if err.code == Code::Aborted { + // If the operation was aborted, it means that the operation was + // cancelled due to another operation being assigned to the worker. + return Ok(()); + } + // Any other error is a real error. + return Err(err); } - // Any other error is a real error. - return Err(err); - } - // Notify the worker to run the action. - { workers .worker_notify_run_action(worker_id, operation_id, action_info) .await .err_tip(|| { "Failed to run worker_notify_run_action in SimpleScheduler::do_try_match" }) - } + }; + tokio::pin!(attach_operation_fut); + + let attach_operation_fut = if let Some(origin_event_tx) = maybe_origin_event_tx { + let mut ctx = ActiveOriginContext::fork().unwrap_or_default(); + + // Populate our origin event collector with the origin metadata + // associated with the action. + ctx.replace_value(&ORIGIN_EVENT_COLLECTOR, move |maybe_old_collector| { + let origin_metadata = maybe_origin_metadata.unwrap_or_default(); + let Some(old_collector) = maybe_old_collector else { + return Some(Arc::new(OriginEventCollector::new( + origin_event_tx.clone(), + origin_metadata, + ))); + }; + Some(Arc::new(old_collector.clone_with_metadata(origin_metadata))) + }); + Arc::new(ctx) + .wrap_async(info_span!("do_try_match"), attach_operation_fut) + .left_future() + } else { + attach_operation_fut.right_future() + }; + attach_operation_fut.await } let mut result = Ok(()); @@ -272,6 +303,7 @@ impl SimpleScheduler { self.worker_scheduler.as_ref(), self.matching_engine_state_manager.as_ref(), self.platform_property_manager.as_ref(), + self.maybe_origin_event_tx.as_ref(), ) .await, ); @@ -285,6 +317,7 @@ impl SimpleScheduler { spec: &SimpleSpec, awaited_action_db: A, task_change_notify: Arc, + maybe_origin_event_tx: Option>, ) -> (Arc, Arc) { Self::new_with_callback( spec, @@ -302,6 +335,7 @@ impl SimpleScheduler { }, task_change_notify, SystemTime::now, + maybe_origin_event_tx, ) } @@ -317,6 +351,7 @@ impl SimpleScheduler { on_matching_engine_run: F, task_change_notify: Arc, now_fn: NowFn, + maybe_origin_event_tx: Option>, ) -> (Arc, Arc) { let platform_property_manager = Arc::new(PlatformPropertyManager::new( spec.supported_platform_properties @@ -389,6 +424,7 @@ impl SimpleScheduler { client_state_manager: state_manager.clone(), worker_scheduler, platform_property_manager, + maybe_origin_event_tx, _task_worker_matching_spawn: task_worker_matching_spawn, } }); diff --git a/nativelink-scheduler/src/simple_scheduler_state_manager.rs b/nativelink-scheduler/src/simple_scheduler_state_manager.rs index d07adfe09..c319901be 100644 --- a/nativelink-scheduler/src/simple_scheduler_state_manager.rs +++ b/nativelink-scheduler/src/simple_scheduler_state_manager.rs @@ -32,6 +32,7 @@ use nativelink_util::operation_state_manager::{ ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, OperationFilter, OperationStageFlags, OrderDirection, UpdateOperationType, WorkerStateManager, }; +use nativelink_util::origin_event::OriginMetadata; use tracing::{event, Level}; use super::awaited_action_db::{ @@ -47,15 +48,15 @@ struct ErrorActionStateResult(Error); #[async_trait] impl ActionStateResult for ErrorActionStateResult { - async fn as_state(&self) -> Result, Error> { + async fn as_state(&self) -> Result<(Arc, Option), Error> { Err(self.0.clone()) } - async fn changed(&mut self) -> Result, Error> { + async fn changed(&mut self) -> Result<(Arc, Option), Error> { Err(self.0.clone()) } - async fn as_action_info(&self) -> Result, Error> { + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { Err(self.0.clone()) } } @@ -102,15 +103,15 @@ where I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Unpin + Sync + 'static, { - async fn as_state(&self) -> Result, Error> { + async fn as_state(&self) -> Result<(Arc, Option), Error> { self.inner.as_state().await } - async fn changed(&mut self) -> Result, Error> { + async fn changed(&mut self) -> Result<(Arc, Option), Error> { self.inner.changed().await } - async fn as_action_info(&self) -> Result, Error> { + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { self.inner.as_action_info().await } } @@ -157,24 +158,26 @@ where I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Unpin + Sync + 'static, { - async fn as_state(&self) -> Result, Error> { - Ok(self + async fn as_state(&self) -> Result<(Arc, Option), Error> { + let awaited_action = self .awaited_action_sub .borrow() .await - .err_tip(|| "In MatchingEngineActionStateResult::as_state")? - .state() - .clone()) + .err_tip(|| "In MatchingEngineActionStateResult::as_state")?; + Ok(( + awaited_action.state().clone(), + awaited_action.maybe_origin_metadata().cloned(), + )) } - async fn changed(&mut self) -> Result, Error> { + async fn changed(&mut self) -> Result<(Arc, Option), Error> { let mut timeout_attempts = 0; loop { tokio::select! { awaited_action_result = self.awaited_action_sub.changed() => { return awaited_action_result .err_tip(|| "In MatchingEngineActionStateResult::changed") - .map(|v| v.state().clone()); + .map(|v| (v.state().clone(), v.maybe_origin_metadata().cloned())); } () = (self.now_fn)().sleep(self.no_event_action_timeout) => { // Timeout happened, do additional checks below. @@ -225,14 +228,16 @@ where } } - async fn as_action_info(&self) -> Result, Error> { - Ok(self + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { + let awaited_action = self .awaited_action_sub .borrow() .await - .err_tip(|| "In MatchingEngineActionStateResult::as_action_info")? - .action_info() - .clone()) + .err_tip(|| "In MatchingEngineActionStateResult::as_action_info")?; + Ok(( + awaited_action.action_info().clone(), + awaited_action.maybe_origin_metadata().cloned(), + )) } } @@ -777,7 +782,7 @@ where action_info: Arc, ) -> Result, Error> { let sub = self - .inner_add_operation(client_operation_id.clone(), action_info.clone()) + .inner_add_operation(client_operation_id, action_info.clone()) .await?; Ok(Box::new(ClientActionStateResult::new( diff --git a/nativelink-scheduler/src/store_awaited_action_db.rs b/nativelink-scheduler/src/store_awaited_action_db.rs index 823589d6c..f326456c4 100644 --- a/nativelink-scheduler/src/store_awaited_action_db.rs +++ b/nativelink-scheduler/src/store_awaited_action_db.rs @@ -339,7 +339,7 @@ impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction { ActionUniqueQualifier::Cachable(_) => Some(unique_qualifier), ActionUniqueQualifier::Uncachable(_) => None, }; - let mut output = Vec::with_capacity(1 + maybe_unique_qualifier.map_or(0, |_| 1)); + let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1)); if maybe_unique_qualifier.is_some() { output.push(( "unique_qualifier", diff --git a/nativelink-scheduler/src/worker.rs b/nativelink-scheduler/src/worker.rs index 985c287e8..2a96f72ef 100644 --- a/nativelink-scheduler/src/worker.rs +++ b/nativelink-scheduler/src/worker.rs @@ -23,7 +23,8 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; use nativelink_util::action_messages::{ActionInfo, OperationId, WorkerId}; -use nativelink_util::metrics_utils::{CounterWithTime, FuncCounterWrapper}; +use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime, FuncCounterWrapper}; +use nativelink_util::origin_event::OriginEventContext; use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue}; use tokio::sync::mpsc::UnboundedSender; @@ -52,6 +53,13 @@ pub enum WorkerUpdate { Disconnect, } +#[derive(MetricsComponent)] +pub struct PendingActionInfoData { + #[metric] + pub action_info: ActionInfoWithProps, + ctx: OriginEventContext, +} + /// Represents a connection to a worker and used as the medium to /// interact with the worker from the client/scheduler. #[derive(MetricsComponent)] @@ -69,7 +77,7 @@ pub struct Worker { /// The action info of the running actions on the worker. #[metric(group = "running_action_infos")] - pub running_action_infos: HashMap, + pub running_action_infos: HashMap, /// Timestamp of last time this worker had been communicated with. // Warning: Do not update this timestamp without updating the placement of the worker in @@ -139,7 +147,7 @@ impl Worker { .unwrap() .as_secs(), actions_completed: CounterWithTime::default(), - run_action: FuncCounterWrapper::default(), + run_action: AsyncCounterWrapper::default(), keep_alive: FuncCounterWrapper::default(), notify_disconnect: CounterWithTime::default(), }), @@ -159,10 +167,10 @@ impl Worker { } /// Notifies the worker of a requested state change. - pub fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> { + pub async fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> { match worker_update { WorkerUpdate::RunAction((operation_id, action_info)) => { - self.run_action(operation_id, action_info) + self.run_action(operation_id, action_info).await } WorkerUpdate::Disconnect => { self.metrics.notify_disconnect.inc(); @@ -180,7 +188,7 @@ impl Worker { }) } - fn run_action( + async fn run_action( &mut self, operation_id: OperationId, action_info: ActionInfoWithProps, @@ -188,33 +196,45 @@ impl Worker { let tx = &mut self.tx; let worker_platform_properties = &mut self.platform_properties; let running_action_infos = &mut self.running_action_infos; - self.metrics.run_action.wrap(move || { - let action_info_clone = action_info.clone(); - let operation_id_string = operation_id.to_string(); - running_action_infos.insert(operation_id, action_info.clone()); - reduce_platform_properties( - worker_platform_properties, - &action_info.platform_properties, - ); - send_msg_to_worker( - tx, - update_for_worker::Update::StartAction(StartExecute { + let worker_id = self.id.to_string(); + self.metrics + .run_action + .wrap(async move { + let action_info_clone = action_info.clone(); + let operation_id_string = operation_id.to_string(); + let start_execute = StartExecute { execute_request: Some(action_info_clone.inner.as_ref().into()), operation_id: operation_id_string, queued_timestamp: Some(action_info.inner.insert_timestamp.into()), - }), - ) - }) + platform: Some((&action_info.platform_properties).into()), + worker_id, + }; + reduce_platform_properties( + worker_platform_properties, + &action_info.platform_properties, + ); + + let ctx = OriginEventContext::new(|| &start_execute).await; + running_action_infos + .insert(operation_id, PendingActionInfoData { action_info, ctx }); + + send_msg_to_worker(tx, update_for_worker::Update::StartAction(start_execute)) + }) + .await } - pub(crate) fn complete_action(&mut self, operation_id: &OperationId) -> Result<(), Error> { - let action_info = self.running_action_infos.remove(operation_id).err_tip(|| { + pub(crate) async fn complete_action( + &mut self, + operation_id: &OperationId, + ) -> Result<(), Error> { + let pending_action_info = self.running_action_infos.remove(operation_id).err_tip(|| { format!( "Worker {} tried to complete operation {} that was not running", self.id, operation_id ) })?; - self.restore_platform_properties(&action_info.platform_properties); + pending_action_info.ctx.emit(|| &()).await; + self.restore_platform_properties(&pending_action_info.action_info.platform_properties); self.is_paused = false; self.metrics.actions_completed.inc(); Ok(()) @@ -263,7 +283,7 @@ struct Metrics { #[metric(help = "The number of actions completed for this worker.")] actions_completed: CounterWithTime, #[metric(help = "The number of actions started for this worker.")] - run_action: FuncCounterWrapper, + run_action: AsyncCounterWrapper, #[metric(help = "The number of keep_alive sent to this worker.")] keep_alive: FuncCounterWrapper, #[metric(help = "The number of notify_disconnect sent to this worker.")] diff --git a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs index 0bf4405bc..ef0e55ab8 100644 --- a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs +++ b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs @@ -193,10 +193,6 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { stage: ActionStage::Queued, action_digest: action_info.unique_qualifier.digest(), })); - // let platform_property_manager = Arc::new(PlatformPropertyManager::new(HashMap::from([( - // name, - // PropertyType::exact, - // )]))); let client_operation_id = OperationId::default(); let (_, (passed_client_operation_id, action_info)) = join!( context diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index c3fea27ad..b11e713e9 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -28,7 +28,9 @@ use nativelink_config::schedulers::{PropertyType, SimpleSpec}; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_metric::MetricsComponent; -use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, ExecuteRequest}; +use nativelink_proto::build::bazel::remote::execution::v2::{ + digest_function, ExecuteRequest, Platform, +}; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; @@ -169,6 +171,7 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -192,6 +195,8 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { }), operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), + platform: Some(Platform::default()), + worker_id: worker_id.to_string(), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); @@ -200,7 +205,7 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -240,6 +245,7 @@ async fn client_does_not_receive_update_timeout() -> Result<(), Error> { || async move {}, task_change_notify.clone(), MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -257,7 +263,7 @@ async fn client_does_not_receive_update_timeout() -> Result<(), Error> { // Trigger a do_try_match to ensure we get a state change. scheduler.do_try_match_for_test().await.unwrap(); assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); @@ -279,7 +285,7 @@ async fn client_does_not_receive_update_timeout() -> Result<(), Error> { { // Now we should have received a timeout and the action should have been // put back in the queue. - assert_eq!(changed_fut.await.unwrap().stage, ActionStage::Queued); + assert_eq!(changed_fut.await.unwrap().0.stage, ActionStage::Queued); } Ok(()) @@ -300,6 +306,7 @@ async fn find_executing_action() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -314,6 +321,7 @@ async fn find_executing_action() -> Result<(), Error> { .as_state() .await .unwrap() + .0 .client_operation_id .clone(); // Drop our receiver and look up a new one. @@ -341,6 +349,8 @@ async fn find_executing_action() -> Result<(), Error> { }), operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), + platform: Some(Platform::default()), + worker_id: worker_id.to_string(), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); @@ -349,7 +359,7 @@ async fn find_executing_action() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -380,6 +390,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest1 = DigestInfo::new([99u8; 32], 512); let action_digest2 = DigestInfo::new([88u8; 32], 512); @@ -412,6 +423,8 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err }), operation_id: "WILL BE SET BELOW".to_string(), queued_timestamp: Some(insert_timestamp1.into()), + platform: Some(Platform::default()), + worker_id: worker_id1.to_string(), }; let mut expected_start_execute_for_worker2 = StartExecute { @@ -423,6 +436,8 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err }), operation_id: "WILL BE SET BELOW".to_string(), queued_timestamp: Some(insert_timestamp2.into()), + platform: Some(Platform::default()), + worker_id: worker_id1.to_string(), }; let operation_id1 = { // Worker1 should now see first execution request. @@ -470,14 +485,16 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client1_action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = + client1_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client2_action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = + client2_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } @@ -499,14 +516,16 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client1_action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = + client1_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client2_action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = + client2_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } @@ -514,6 +533,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); expected_start_execute_for_worker1.operation_id = operation_id1.to_string(); + expected_start_execute_for_worker1.worker_id = worker_id2.to_string(); assert_eq!( msg_for_worker, UpdateForWorker { @@ -527,6 +547,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); expected_start_execute_for_worker2.operation_id = operation_id2.to_string(); + expected_start_execute_for_worker2.worker_id = worker_id2.to_string(); assert_eq!( msg_for_worker, UpdateForWorker { @@ -555,6 +576,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -574,7 +596,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> }; // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); operation_id @@ -591,7 +613,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { // Client should get notification saying it's been queued. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -607,7 +629,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -642,6 +664,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); let mut platform_properties = HashMap::new(); @@ -665,7 +688,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E { // Client should get notification saying it's been queued. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -679,7 +702,8 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E "prop".to_string(), PlatformPropertyValue::Exact("1".to_string()), ); - let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, worker2_properties).await?; + let mut rx_from_worker2 = + setup_new_worker(&scheduler, worker_id2, worker2_properties.clone()).await?; { // Worker should have been sent an execute command. let expected_msg_for_worker = UpdateForWorker { @@ -692,6 +716,8 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E }), operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), + platform: Some((&worker2_properties).into()), + worker_id: worker_id2.to_string(), })), }; let msg_for_worker = rx_from_worker2.recv().await.unwrap(); @@ -699,7 +725,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E } { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -733,6 +759,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -752,8 +779,10 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let (operation_id1, operation_id2) = { // Clients should get notification saying it's been queued. - let action_state1 = client1_action_listener.changed().await.unwrap(); - let action_state2 = client2_action_listener.changed().await.unwrap(); + let (action_state1, _maybe_origin_metadata) = + client1_action_listener.changed().await.unwrap(); + let (action_state2, _maybe_origin_metadata) = + client2_action_listener.changed().await.unwrap(); let operation_id1 = action_state1.client_operation_id.clone(); let operation_id2 = action_state2.client_operation_id.clone(); // Name is random so we set force it to be the same. @@ -784,6 +813,8 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { }), operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp1.into()), + platform: Some(Platform::default()), + worker_id: worker_id.to_string(), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); @@ -798,12 +829,12 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { // Most importantly the `name` (which is random) will be the same. expected_action_state.client_operation_id = operation_id1.clone(); assert_eq!( - client1_action_listener.changed().await.unwrap().as_ref(), + client1_action_listener.changed().await.unwrap().0.as_ref(), &expected_action_state ); expected_action_state.client_operation_id = operation_id2.clone(); assert_eq!( - client2_action_listener.changed().await.unwrap().as_ref(), + client2_action_listener.changed().await.unwrap().0.as_ref(), &expected_action_state ); } @@ -813,7 +844,8 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let insert_timestamp3 = make_system_time(2); let mut client3_action_listener = setup_action(&scheduler, action_digest, HashMap::new(), insert_timestamp3).await?; - let action_state = client3_action_listener.changed().await.unwrap().clone(); + let (action_state, _maybe_origin_metadata) = + client3_action_listener.changed().await.unwrap().clone(); expected_action_state.client_operation_id = action_state.client_operation_id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); } @@ -834,6 +866,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let worker_id: WorkerId = WorkerId(Uuid::new_v4()); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -849,7 +882,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), setup_action(&scheduler, action_digest, HashMap::new(), insert_timestamp).await?; { // Client should get notification saying it's being queued not executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -989,6 +1022,7 @@ async fn matching_engine_fails_sends_abort() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); // Initial worker calls do_try_match, so send it no items. senders.get_range_of_actions.send(vec![]).unwrap(); @@ -1034,6 +1068,7 @@ async fn matching_engine_fails_sends_abort() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); // senders.tx_get_awaited_action_by_id.send(Ok(None)).unwrap(); senders.get_range_of_actions.send(vec![]).unwrap(); @@ -1092,6 +1127,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1115,6 +1151,8 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { }), operation_id: "UNKNOWN HERE, WE WILL SET IT LATER".to_string(), queued_timestamp: Some(insert_timestamp.into()), + platform: Some(Platform::default()), + worker_id: worker_id1.to_string(), }; { @@ -1140,7 +1178,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); assert_eq!( action_state.as_ref(), &ActionState { @@ -1173,7 +1211,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); assert_eq!( action_state.as_ref(), &ActionState { @@ -1184,14 +1222,13 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { ); } { + start_execute.worker_id = worker_id2.to_string(); // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); assert_eq!( msg_for_worker, UpdateForWorker { - update: Some(update_for_worker::Update::StartAction( - start_execute.clone() - )), + update: Some(update_for_worker::Update::StartAction(start_execute)), } ); } @@ -1214,6 +1251,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1229,7 +1267,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err Some(update_for_worker::Update::StartAction(start_execute)) => { // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); start_execute.operation_id @@ -1287,7 +1325,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err { // Client should get notification saying it has been completed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -1315,6 +1353,7 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1328,6 +1367,7 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E .as_state() .await .unwrap() + .0 .client_operation_id .clone(); @@ -1404,7 +1444,7 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E .expect("Action not found"); { // Client should get notification saying it has been completed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -1433,6 +1473,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1450,7 +1491,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { } // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); } @@ -1531,6 +1572,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1562,6 +1604,8 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro }), operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), + platform: Some(Platform::default()), + worker_id: worker_id.to_string(), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); @@ -1581,7 +1625,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro { // Client should get notification saying it's being executed. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. expected_action_state.client_operation_id = action_state.client_operation_id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); @@ -1627,7 +1671,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro // Action should now be executing. expected_action_state.stage = ActionStage::Completed(action_result.clone()); assert_eq!( - action_listener.changed().await.unwrap().as_ref(), + action_listener.changed().await.unwrap().0.as_ref(), &expected_action_state ); } @@ -1643,7 +1687,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro .unwrap(); // We didn't disconnect our worker, so it will have scheduled it to the worker. expected_action_state.stage = ActionStage::Executing; - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); // The name of the action changed (since it's a new action), so update it. expected_action_state.client_operation_id = action_state.client_operation_id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); @@ -1674,6 +1718,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest1 = DigestInfo::new([11u8; 32], 512); let action_digest2 = DigestInfo::new([99u8; 32], 512); @@ -1712,8 +1757,8 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> v => panic!("Expected StartAction, got : {v:?}"), }; { - let state_1 = client1_action_listener.changed().await.unwrap(); - let state_2 = client2_action_listener.changed().await.unwrap(); + let (state_1, _maybe_origin_metadata) = client1_action_listener.changed().await.unwrap(); + let (state_2, _maybe_origin_metadata) = client2_action_listener.changed().await.unwrap(); // First client should be in an Executing state. assert_eq!(state_1.stage, ActionStage::Executing); // Second client should be in a queued state. @@ -1759,7 +1804,8 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> { // First action should now be completed. - let action_state = client1_action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = + client1_action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -1782,7 +1828,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> }; // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - client2_action_listener.changed().await.unwrap().stage, + client2_action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); operation_id @@ -1802,7 +1848,8 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> { // Our second client should be notified it completed. - let action_state = client2_action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = + client2_action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -1836,6 +1883,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest1 = DigestInfo::new([11u8; 32], 512); let action_digest2 = DigestInfo::new([99u8; 32], 512); @@ -1872,12 +1920,12 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { { // First client should be in an Executing state. assert_eq!( - client1_action_listener.changed().await.unwrap().stage, + client1_action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); // Second client should be in a queued state. assert_eq!( - client2_action_listener.changed().await.unwrap().stage, + client2_action_listener.changed().await.unwrap().0.stage, ActionStage::Queued ); } @@ -1903,6 +1951,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -1920,7 +1969,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> }; // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); OperationId::from(operation_id.as_str()) @@ -1936,7 +1985,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> { // Client should get notification saying it has been queued again. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -1957,7 +2006,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> } // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); } @@ -1974,7 +2023,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> { // Client should get notification saying it has been queued again. - let action_state = action_listener.changed().await.unwrap(); + let (action_state, _maybe_origin_metadata) = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. client_operation_id: action_state.client_operation_id.clone(), @@ -2056,6 +2105,7 @@ async fn ensure_scheduler_drops_inner_spawn() -> Result<(), Error> { }, task_change_notify, MockInstantWrapped::default, + None, ); assert_eq!(dropped.load(Ordering::Relaxed), false); @@ -2085,6 +2135,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -2109,7 +2160,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), }; // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); OperationId::from(operation_id.as_str()) @@ -2134,7 +2185,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), .err_tip(|| "worker went away")?; // Other tests check full data. We only care if client thinks we are Executing. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Executing ); } @@ -2160,6 +2211,7 @@ async fn client_reconnect_keeps_action_alive() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -2172,6 +2224,7 @@ async fn client_reconnect_keeps_action_alive() -> Result<(), Error> { .as_state() .await .unwrap() + .0 .client_operation_id .clone(); @@ -2191,7 +2244,7 @@ async fn client_reconnect_keeps_action_alive() -> Result<(), Error> { // We should get one notification saying it's queued. assert_eq!( - new_action_listener.changed().await.unwrap().stage, + new_action_listener.changed().await.unwrap().0.stage, ActionStage::Queued ); @@ -2240,6 +2293,7 @@ async fn client_timesout_job_then_same_action_requested() -> Result<(), Error> { || async move {}, task_change_notify, MockInstantWrapped::default, + None, ); let action_digest = DigestInfo::new([99u8; 32], 512); @@ -2252,7 +2306,7 @@ async fn client_timesout_job_then_same_action_requested() -> Result<(), Error> { // We should get one notification saying it's queued. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Queued ); @@ -2275,7 +2329,7 @@ async fn client_timesout_job_then_same_action_requested() -> Result<(), Error> { // We should get one notification saying it's queued. assert_eq!( - action_listener.changed().await.unwrap().stage, + action_listener.changed().await.unwrap().0.stage, ActionStage::Queued ); diff --git a/nativelink-scheduler/tests/utils/scheduler_utils.rs b/nativelink-scheduler/tests/utils/scheduler_utils.rs index 67bd7b493..d72442bff 100644 --- a/nativelink-scheduler/tests/utils/scheduler_utils.rs +++ b/nativelink-scheduler/tests/utils/scheduler_utils.rs @@ -24,6 +24,7 @@ use nativelink_util::action_messages::{ use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::origin_event::OriginMetadata; use tokio::sync::watch; pub const INSTANCE_NAME: &str = "foobar_instance_name"; @@ -73,13 +74,13 @@ impl TokioWatchActionStateResult { #[async_trait] impl ActionStateResult for TokioWatchActionStateResult { - async fn as_state(&self) -> Result, Error> { + async fn as_state(&self) -> Result<(Arc, Option), Error> { let mut action_state = self.rx.borrow().clone(); Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); - Ok(action_state) + Ok((action_state, None)) } - async fn changed(&mut self) -> Result, Error> { + async fn changed(&mut self) -> Result<(Arc, Option), Error> { self.rx.changed().await.map_err(|_| { make_err!( Code::Internal, @@ -88,10 +89,10 @@ impl ActionStateResult for TokioWatchActionStateResult { })?; let mut action_state = self.rx.borrow().clone(); Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); - Ok(action_state) + Ok((action_state, None)) } - async fn as_action_info(&self) -> Result, Error> { - Ok(self.action_info.clone()) + async fn as_action_info(&self) -> Result<(Arc, Option), Error> { + Ok((self.action_info.clone(), None)) } } diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index b8faf1284..897d5c635 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -206,7 +206,7 @@ impl ExecutionServer { async move { let mut action_listener = maybe_action_listener?; match action_listener.changed().await { - Ok(action_update) => { + Ok((action_update, _maybe_origin_metadata)) => { event!(Level::INFO, ?action_update, "Execute Resp Stream"); // If the action is finished we won't be sending any more updates. let maybe_action_listener = if action_update.stage.is_finished() { @@ -279,6 +279,7 @@ impl ExecutionServer { .as_state() .await .err_tip(|| "In ExecutionServer::inner_execute")? + .0 .client_operation_id .clone(), ), diff --git a/nativelink-store/src/compression_store.rs b/nativelink-store/src/compression_store.rs index 83d2899a2..0b98ba022 100644 --- a/nativelink-store/src/compression_store.rs +++ b/nativelink-store/src/compression_store.rs @@ -422,7 +422,7 @@ impl StoreDriver for CompressionStore { let read_fut = async move { let header = { // Read header. - static EMPTY_HEADER: Header = Header { + const EMPTY_HEADER: Header = Header { version: CURRENT_STREAM_FORMAT_VERSION, config: Lz4Config { block_size: 0 }, upload_size: UploadSizeInfo::ExactSize(0), diff --git a/nativelink-util/src/operation_state_manager.rs b/nativelink-util/src/operation_state_manager.rs index 0a3835f08..55b6deaeb 100644 --- a/nativelink-util/src/operation_state_manager.rs +++ b/nativelink-util/src/operation_state_manager.rs @@ -27,6 +27,7 @@ use crate::action_messages::{ }; use crate::common::DigestInfo; use crate::known_platform_property_provider::KnownPlatformPropertyProvider; +use crate::origin_event::OriginMetadata; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -47,12 +48,12 @@ impl Default for OperationStageFlags { #[async_trait] pub trait ActionStateResult: Send + Sync + 'static { - // Provides the current state of the action. - async fn as_state(&self) -> Result, Error>; - // Waits for the state of the action to change. - async fn changed(&mut self) -> Result, Error>; - // Provide result as action info. This behavior will not be supported by all implementations. - async fn as_action_info(&self) -> Result, Error>; + /// Provides the current state of the action. + async fn as_state(&self) -> Result<(Arc, Option), Error>; + /// Waits for the state of the action to change. + async fn changed(&mut self) -> Result<(Arc, Option), Error>; + /// Provide result as action info. This behavior will not be supported by all implementations. + async fn as_action_info(&self) -> Result<(Arc, Option), Error>; } /// The direction in which the results are ordered. diff --git a/nativelink-util/src/origin_context.rs b/nativelink-util/src/origin_context.rs index 081e4b197..c6d4728bc 100644 --- a/nativelink-util/src/origin_context.rs +++ b/nativelink-util/src/origin_context.rs @@ -16,6 +16,7 @@ use core::panic; use std::any::Any; use std::cell::RefCell; use std::clone::Clone; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::mem::ManuallyDrop; use std::pin::Pin; @@ -105,6 +106,29 @@ impl OriginContext { Self::default() } + /// Replaces the value for a given symbol on the context. + pub fn replace_value( + &mut self, + symbol: &'static impl Symbol, + cb: impl FnOnce(Option>) -> Option>, + ) { + let entry = self.data.entry(RawSymbolWrapper(symbol.as_ptr())); + let old_value = match &entry { + Entry::Occupied(data) => Arc::downcast(data.get().clone()).ok(), + Entry::Vacant(_) => None, + }; + match cb(old_value) { + Some(new_value) => { + entry.insert_entry(new_value); + } + None => { + if let Entry::Occupied(entry) = entry { + entry.remove(); + } + } + } + } + /// Sets the value for a given symbol on the context. pub fn set_value( &mut self, diff --git a/nativelink-util/src/origin_event.rs b/nativelink-util/src/origin_event.rs index fd8350073..df0c1419f 100644 --- a/nativelink-util/src/origin_event.rs +++ b/nativelink-util/src/origin_event.rs @@ -16,6 +16,8 @@ use std::marker::PhantomData; use std::pin::Pin; use std::sync::{Arc, OnceLock}; +use base64::prelude::BASE64_STANDARD_NO_PAD; +use base64::Engine; use futures::future::ready; use futures::task::{Context, Poll}; use futures::{Future, FutureExt, Stream, StreamExt}; @@ -30,6 +32,7 @@ use nativelink_proto::com::github::trace_machina::nativelink::events::{ response_event, stream_event, BatchReadBlobsResponseOverride, BatchUpdateBlobsRequestOverride, Event, OriginEvent, RequestEvent, ResponseEvent, StreamEvent, WriteRequestOverride, }; +use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::StartExecute; use nativelink_proto::google::bytestream::{ QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, WriteResponse, @@ -37,7 +40,9 @@ use nativelink_proto::google::bytestream::{ use nativelink_proto::google::longrunning::Operation; use nativelink_proto::google::rpc::Status; use pin_project_lite::pin_project; +use prost::Message; use rand::RngCore; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; use tonic::{Response, Status as TonicStatus, Streaming}; @@ -73,6 +78,7 @@ pub fn get_id_for_event(event: &Event) -> [u8; 2] { Some(request_event::Event::QueryWriteStatusRequest(_)) => [0x01, 0x0A], Some(request_event::Event::ExecuteRequest(_)) => [0x01, 0x0B], Some(request_event::Event::WaitExecutionRequest(_)) => [0x01, 0x0C], + Some(request_event::Event::SchedulerStartExecute(_)) => [0x01, 0x0D], }, Some(event::Event::Response(res)) => match res.event { None => [0x02, 0x00], @@ -115,22 +121,64 @@ pub fn get_node_id(event: Option<&Event>) -> [u8; 6] { node_id } +fn serialize_request_metadata( + value: &Option, + serializer: S, +) -> Result +where + S: Serializer, +{ + match value { + Some(msg) => serializer.serialize_some(&BASE64_STANDARD_NO_PAD.encode(msg.encode_to_vec())), + None => serializer.serialize_none(), + } +} + +fn deserialize_request_metadata<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let opt = Option::::deserialize(deserializer)?; + match opt { + Some(s) => { + let decoded = BASE64_STANDARD_NO_PAD + .decode(s.as_bytes()) + .map_err(serde::de::Error::custom)?; + RequestMetadata::decode(&*decoded) + .map_err(serde::de::Error::custom) + .map(Some) + } + None => Ok(None), + } +} + +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] +pub struct OriginMetadata { + pub identity: String, + #[serde( + serialize_with = "serialize_request_metadata", + deserialize_with = "deserialize_request_metadata" + )] + pub bazel_metadata: Option, +} + pub struct OriginEventCollector { sender: mpsc::Sender, - identity: String, - bazel_metadata: Option, + pub metadata: OriginMetadata, } impl OriginEventCollector { - pub fn new( - sender: mpsc::Sender, - identity: String, - bazel_metadata: Option, - ) -> Self { + pub fn new(sender: mpsc::Sender, metadata: OriginMetadata) -> Self { + Self { sender, metadata } + } + + #[must_use] + pub fn clone_with_metadata(&self, metadata: OriginMetadata) -> Self { Self { - sender, - identity, - bazel_metadata, + sender: self.sender.clone(), + metadata, } } @@ -151,8 +199,8 @@ impl OriginEventCollector { version: ORIGIN_EVENT_VERSION, event_id: event_id.as_hyphenated().to_string(), parent_event_id, - bazel_request_metadata: self.bazel_metadata.clone(), - identity: self.identity.clone(), + bazel_request_metadata: self.metadata.bazel_metadata.clone(), + identity: self.metadata.identity.clone(), event: Some(event), }) .await; @@ -172,8 +220,8 @@ impl OriginEventCollector { version: ORIGIN_EVENT_VERSION, event_id: event_id.as_hyphenated().to_string(), parent_event_id, - bazel_request_metadata: self.bazel_metadata.clone(), - identity: self.identity.clone(), + bazel_request_metadata: self.metadata.bazel_metadata.clone(), + identity: self.metadata.identity.clone(), event: Some(event), }) .map_or_else( @@ -508,6 +556,7 @@ impl_as_event! {Request, (), Streaming, WriteRequest, to_empty_wri impl_as_event! {Request, (), QueryWriteStatusRequest} impl_as_event! {Request, (), ExecuteRequest} impl_as_event! {Request, (), WaitExecutionRequest} +impl_as_event! {Request, (), StartExecute, SchedulerStartExecute} // -- Responses -- @@ -522,6 +571,7 @@ impl_as_event! {Response, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse} impl_as_event! {Response, BatchReadBlobsRequest, BatchReadBlobsResponse, BatchReadBlobsResponseOverride, to_batch_read_blobs_response_override} impl_as_event! {Response, GetTreeRequest, Pin> + Send + '_>>, Empty, to_empty_response} impl_as_event! {Response, ExecuteRequest, Pin> + Send + '_>>, Empty, to_empty_response} +impl_as_event! {Response, StartExecute, (), Empty} // -- Streams -- diff --git a/nativelink-util/src/origin_event_middleware.rs b/nativelink-util/src/origin_event_middleware.rs index ae73d9e5a..f870add19 100644 --- a/nativelink-util/src/origin_event_middleware.rs +++ b/nativelink-util/src/origin_event_middleware.rs @@ -29,7 +29,7 @@ use tower::Service; use tracing::trace_span; use crate::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY}; -use crate::origin_event::{OriginEventCollector, ORIGIN_EVENT_COLLECTOR}; +use crate::origin_event::{OriginEventCollector, OriginMetadata, ORIGIN_EVENT_COLLECTOR}; /// Default identity header name. /// Note: If this is changed, the default value in the [`IdentityHeaderSpec`] @@ -138,8 +138,10 @@ where &ORIGIN_EVENT_COLLECTOR, Arc::new(OriginEventCollector::new( origin_event_tx.clone(), - identity, - bazel_metadata, + OriginMetadata { + identity, + bazel_metadata, + }, )), ); } diff --git a/nativelink-util/src/platform_properties.rs b/nativelink-util/src/platform_properties.rs index 0a223974b..9c3a4fab5 100644 --- a/nativelink-util/src/platform_properties.rs +++ b/nativelink-util/src/platform_properties.rs @@ -18,6 +18,7 @@ use std::collections::HashMap; use nativelink_metric::{ publish, MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent, }; +use nativelink_proto::build::bazel::remote::execution::v2::platform::Property as ProtoProperty; use nativelink_proto::build::bazel::remote::execution::v2::Platform as ProtoPlatform; use serde::{Deserialize, Serialize}; @@ -69,6 +70,21 @@ impl From for PlatformProperties { } } +impl From<&PlatformProperties> for ProtoPlatform { + fn from(val: &PlatformProperties) -> Self { + ProtoPlatform { + properties: val + .properties + .iter() + .map(|(name, value)| ProtoProperty { + name: name.clone(), + value: value.as_str().to_string(), + }) + .collect(), + } + } +} + /// Holds the associated value of the key and type. /// /// Exact - Means the worker must have this exact value. diff --git a/nativelink-util/tests/origin_event_test.rs b/nativelink-util/tests/origin_event_test.rs index d08c99576..cf334b651 100644 --- a/nativelink-util/tests/origin_event_test.rs +++ b/nativelink-util/tests/origin_event_test.rs @@ -95,6 +95,7 @@ fn get_id_for_event_test() { Some(request_event::Event::QueryWriteStatusRequest(_)) => [0x01, 0x0A], Some(request_event::Event::ExecuteRequest(_)) => [0x01, 0x0B], Some(request_event::Event::WaitExecutionRequest(_)) => [0x01, 0x0C], + Some(request_event::Event::SchedulerStartExecute(_)) => [0x01, 0x0D], // Don't forget to add new entries to test cases. } } @@ -144,6 +145,7 @@ fn get_id_for_event_test() { test_event!(Request, QueryWriteStatusRequest); test_event!(Request, ExecuteRequest); test_event!(Request, WaitExecutionRequest); + test_event!(Request, SchedulerStartExecute); test_event!(Response, None); test_event!(Response, Error); diff --git a/nativelink-worker/tests/local_worker_test.rs b/nativelink-worker/tests/local_worker_test.rs index 39534e075..7dcc6d44a 100644 --- a/nativelink-worker/tests/local_worker_test.rs +++ b/nativelink-worker/tests/local_worker_test.rs @@ -35,6 +35,7 @@ use nativelink_config::stores::{FastSlowSpec, FilesystemSpec, MemorySpec, StoreS use nativelink_error::{make_err, make_input_err, Code, Error}; use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::platform::Property; +use nativelink_proto::build::bazel::remote::execution::v2::Platform; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ execute_result, ConnectionResult, ExecuteResult, KillOperationRequest, StartExecute, @@ -248,6 +249,8 @@ async fn blake3_digest_function_registerd_properly() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { execute_request: Some(execute_request), operation_id, queued_timestamp: None, + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -924,6 +930,8 @@ async fn upload_files_from_above_cwd_test() -> Result<(), Box Result<(), Box> execute_request: Some(execute_request), operation_id, queued_timestamp: Some(queued_timestamp.into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -1280,6 +1290,8 @@ async fn cleanup_happens_on_job_failure() -> Result<(), Box Result<(), Box> { execute_request: Some(execute_request), operation_id, queued_timestamp: Some(make_system_time(1000).into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -1563,6 +1577,8 @@ exit 0 execute_request: Some(execute_request), operation_id, queued_timestamp: Some(make_system_time(1000).into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -1735,6 +1751,8 @@ exit 0 execute_request: Some(execute_request), operation_id, queued_timestamp: Some(make_system_time(1000).into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -1876,6 +1894,8 @@ exit 1 execute_request: Some(execute_request), operation_id, queued_timestamp: Some(make_system_time(1000).into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -2391,6 +2411,8 @@ async fn ensure_worker_timeout_chooses_correct_values() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { execute_request: Some(execute_request), operation_id, queued_timestamp: Some(make_system_time(1000).into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .and_then(|action| { @@ -2800,6 +2828,8 @@ async fn kill_all_waits_for_all_tasks_to_finish() -> Result<(), Box Result<(), Box Result<(), Box> { execute_request: Some(execute_request), operation_id, queued_timestamp: None, + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), }, ) .await?; @@ -3361,6 +3395,8 @@ async fn running_actions_manager_respects_action_timeout() -> Result<(), Box>> = Vec::new(); + + let maybe_origin_event_tx = cfg + .experimental_origin_events + .as_ref() + .map(|origin_events_cfg| { + let mut max_queued_events = origin_events_cfg.max_event_queue_size; + if max_queued_events == 0 { + max_queued_events = DEFAULT_MAX_QUEUE_EVENTS; + } + let (tx, rx) = mpsc::channel(max_queued_events); + let store_name = origin_events_cfg.publisher.store.as_str(); + let store = store_manager.get_store(store_name).err_tip(|| { + format!("Could not get store {store_name} for origin event publisher") + })?; + + root_futures.push(Box::pin( + OriginEventPublisher::new(store, rx, shutdown_tx.clone()) + .run() + .map(Ok), + )); + + Ok::<_, Error>(tx) + }) + .transpose()?; + let mut action_schedulers = HashMap::new(); let mut worker_schedulers = HashMap::new(); for SchedulerConfig { name, spec } in cfg.schedulers.iter().flatten() { let (maybe_action_scheduler, maybe_worker_scheduler) = - scheduler_factory(spec, &store_manager) + scheduler_factory(spec, &store_manager, maybe_origin_event_tx.as_ref()) .err_tip(|| format!("Failed to create scheduler '{name}'"))?; if let Some(action_scheduler) = maybe_action_scheduler { action_schedulers.insert(name.clone(), action_scheduler.clone()); @@ -241,8 +267,6 @@ async fn inner_main( }) .collect(); - let mut root_futures: Vec>> = Vec::new(); - let root_metrics = Arc::new(RwLock::new(RootMetrics { stores: store_manager.clone(), servers: server_metrics, @@ -250,30 +274,6 @@ async fn inner_main( schedulers: action_schedulers.clone(), })); - let maybe_origin_event_tx = cfg - .experimental_origin_events - .as_ref() - .map(|origin_events_cfg| { - let mut max_queued_events = origin_events_cfg.max_event_queue_size; - if max_queued_events == 0 { - max_queued_events = DEFAULT_MAX_QUEUE_EVENTS; - } - let (tx, rx) = mpsc::channel(max_queued_events); - let store_name = origin_events_cfg.publisher.store.as_str(); - let store = store_manager.get_store(store_name).err_tip(|| { - format!("Could not get store {store_name} for origin event publisher") - })?; - - root_futures.push(Box::pin( - OriginEventPublisher::new(store, rx, shutdown_tx.clone()) - .run() - .map(Ok), - )); - - Ok::<_, Error>(tx) - }) - .transpose()?; - for (server_cfg, connected_clients_mux) in servers_and_clients { let services = server_cfg .services