diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs index 08e64845d..ef7033296 100644 --- a/nativelink-scheduler/src/memory_awaited_action_db.rs +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -783,6 +783,15 @@ impl I + Clone + Send + Sync> AwaitedActionDbI }; *connected_clients += 1; + // Immediately mark the keep alive, we don't need to wake anyone + // so we always fake that it was not actually changed. + // Failing update the client could lead to the client connecting + // then not updating the keep alive in time, resulting in the + // operation timing out due to async behavior. + tx.send_if_modified(|awaited_action| { + awaited_action.update_client_keep_alive((self.now_fn)().now()); + false + }); let subscription = tx.subscribe(); self.client_operation_to_awaited_action diff --git a/nativelink-scheduler/src/store_awaited_action_db.rs b/nativelink-scheduler/src/store_awaited_action_db.rs index f482748e5..823589d6c 100644 --- a/nativelink-scheduler/src/store_awaited_action_db.rs +++ b/nativelink-scheduler/src/store_awaited_action_db.rs @@ -490,7 +490,7 @@ where .await .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")?; match maybe_awaited_action { - Some(awaited_action) => { + Some(mut awaited_action) => { // TODO(allada) We don't support joining completed jobs because we // need to also check that all the data is still in the cache. if awaited_action.state().stage.is_finished() { @@ -498,10 +498,23 @@ where } // TODO(allada) We only care about the operation_id here, we should // have a way to tell the decoder we only care about specific fields. - let operation_id = awaited_action.operation_id(); + let operation_id = awaited_action.operation_id().clone(); + + awaited_action.update_client_keep_alive((self.now_fn)().now()); + let update_res = inner_update_awaited_action(self.store.as_ref(), awaited_action) + .await + .err_tip(|| "In OperationSubscriber::changed"); + if let Err(err) = update_res { + event!( + Level::WARN, + "Error updating client keep alive in RedisAwaitedActionDb::try_subscribe - {err:?} - This is not a critical error, but we did decide to create a new action instead of joining an existing one." + ); + return Ok(None); + } + Ok(Some(OperationSubscriber::new( Some(client_operation_id.clone()), - OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())), + OperationIdToAwaitedAction(Cow::Owned(operation_id)), Arc::downgrade(&self.store), self.now_fn.clone(), ))) diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index 4ce28a243..c3fea27ad 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -2221,3 +2221,71 @@ async fn client_reconnect_keeps_action_alive() -> Result<(), Error> { Ok(()) } + +#[nativelink_test] +async fn client_timesout_job_then_same_action_requested() -> Result<(), Error> { + const CLIENT_ACTION_TIMEOUT_S: u64 = 60; + let task_change_notify = Arc::new(Notify::new()); + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( + &SimpleSpec { + worker_timeout_s: WORKER_TIMEOUT_S, + client_action_timeout_s: CLIENT_ACTION_TIMEOUT_S, + ..Default::default() + }, + memory_awaited_action_db_factory( + 0, + &task_change_notify.clone(), + MockInstantWrapped::default, + ), + || async move {}, + task_change_notify, + MockInstantWrapped::default, + ); + let action_digest = DigestInfo::new([99u8; 32], 512); + + { + let insert_timestamp = make_system_time(1); + let mut action_listener = + setup_action(&scheduler, action_digest, HashMap::new(), insert_timestamp) + .await + .unwrap(); + + // We should get one notification saying it's queued. + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Queued + ); + + let changed_fut = action_listener.changed(); + tokio::pin!(changed_fut); + + MockClock::advance(Duration::from_secs(2)); + scheduler.do_try_match_for_test().await.unwrap(); + assert_eq!(poll!(&mut changed_fut), Poll::Pending); + } + + MockClock::advance(Duration::from_secs(CLIENT_ACTION_TIMEOUT_S + 1)); + + { + let insert_timestamp = make_system_time(1); + let mut action_listener = + setup_action(&scheduler, action_digest, HashMap::new(), insert_timestamp) + .await + .unwrap(); + + // We should get one notification saying it's queued. + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Queued + ); + + let changed_fut = action_listener.changed(); + tokio::pin!(changed_fut); + + MockClock::advance(Duration::from_secs(2)); + tokio::task::yield_now().await; + assert_eq!(poll!(&mut changed_fut), Poll::Pending); + } + + Ok(()) +}