refactor(barrier): explicitly maintain database barrier state separately in local barrier manager (#19556)
wenym1 authored Nov 29, 2024
1 parent ea2f775 commit 3b3a1c5
Showing 23 changed files with 597 additions and 330 deletions.
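At a high level, the meta node previously packed the database id into the upper 32 bits of a `u64` partial graph id and workers decoded it back out; after this commit the `database_id` travels as its own field on the barrier RPCs, the partial graph id shrinks to a `u32` (with `u32::MAX` reserved for a database's root graph), and the local barrier manager keeps per-database state keyed by that explicit id. A minimal sketch of the two id schemes, for orientation only (the real definitions are in `src/meta/src/barrier/rpc.rs` below):

```rust
// Illustration only, not code from this commit: contrasts the old packed
// encoding with the new explicit scheme using plain integers.

// Old: database id packed into the high 32 bits of a u64 partial graph id.
fn old_partial_graph_id(database_id: u32, creating_job_id: Option<u32>) -> u64 {
    ((database_id as u64) << 32) | (creating_job_id.unwrap_or(u32::MAX) as u64)
}

// New: the partial graph id is just the creating job id, or u32::MAX for the
// database's root graph; the database id is sent in its own proto field.
fn new_partial_graph_id(creating_job_id: Option<u32>) -> u32 {
    creating_job_id.unwrap_or(u32::MAX)
}

fn main() {
    assert_eq!(old_partial_graph_id(7, Some(42)), (7u64 << 32) | 42);
    assert_eq!(new_partial_graph_id(Some(42)), 42); // database id sent separately now
    assert_eq!(new_partial_graph_id(None), u32::MAX);
}
```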
21 changes: 15 additions & 6 deletions proto/stream_service.proto
@@ -12,9 +12,10 @@ option optimize_for = SPEED;
message InjectBarrierRequest {
string request_id = 1;
stream_plan.Barrier barrier = 2;
+ uint32 database_id = 3;
repeated uint32 actor_ids_to_collect = 4;
repeated uint32 table_ids_to_sync = 5;
- uint64 partial_graph_id = 6;
+ uint32 partial_graph_id = 6;

repeated common.ActorInfo broadcast_info = 8;
repeated stream_plan.StreamActor actors_to_build = 9;
@@ -48,9 +49,10 @@ message BarrierCompleteResponse {
uint32 worker_id = 5;
map<uint32, hummock.TableWatermarks> table_watermarks = 6;
repeated hummock.SstableInfo old_value_sstables = 7;
- uint64 partial_graph_id = 8;
+ uint32 partial_graph_id = 8;
// prev_epoch of barrier
uint64 epoch = 9;
+ uint32 database_id = 10;
}

message WaitEpochCommitRequest {
@@ -64,20 +66,27 @@ message WaitEpochCommitResponse {

message StreamingControlStreamRequest {
message InitialPartialGraph {
- uint64 partial_graph_id = 1;
+ uint32 partial_graph_id = 1;
repeated stream_plan.SubscriptionUpstreamInfo subscriptions = 2;
}

+ message DatabaseInitialPartialGraph {
+ uint32 database_id = 1;
+ repeated InitialPartialGraph graphs = 2;
+ }
+
message InitRequest {
- repeated InitialPartialGraph graphs = 1;
+ repeated DatabaseInitialPartialGraph databases = 1;
}

message CreatePartialGraphRequest {
- uint64 partial_graph_id = 1;
+ uint32 partial_graph_id = 1;
+ uint32 database_id = 2;
}

message RemovePartialGraphRequest {
- repeated uint64 partial_graph_ids = 1;
+ repeated uint32 partial_graph_ids = 1;
+ uint32 database_id = 2;
}

oneof request {
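For orientation, the new `InitRequest` nests one `DatabaseInitialPartialGraph` per database, each holding that database's partial graphs. A hypothetical request built from the prost-generated `Pb*` types (the aliases imported in the `src/meta/src/barrier/rpc.rs` hunk further down) might look like this; the field values are invented:

```rust
// Sketch only: shows the new nesting of InitRequest. The real construction
// appears in the src/meta/src/barrier/rpc.rs hunk later in this diff.
use risingwave_pb::stream_service::streaming_control_stream_request::{
    PbDatabaseInitialPartialGraph, PbInitRequest, PbInitialPartialGraph,
};

fn example_init_request() -> PbInitRequest {
    PbInitRequest {
        databases: vec![PbDatabaseInitialPartialGraph {
            database_id: 1, // hypothetical database
            graphs: vec![PbInitialPartialGraph {
                partial_graph_id: u32::MAX, // root graph of the database
                subscriptions: vec![],      // no subscriptions in this sketch
            }],
        }],
    }
}
```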
1 change: 1 addition & 0 deletions proto/task_service.proto
@@ -95,6 +95,7 @@ message GetStreamRequest {
uint32 down_actor_id = 2;
uint32 up_fragment_id = 3;
uint32 down_fragment_id = 4;
+ uint32 database_id = 5;
}

oneof value {
2 changes: 1 addition & 1 deletion src/common/src/catalog/mod.rs
@@ -177,7 +177,7 @@ pub struct DatabaseId {
}

impl DatabaseId {
- pub fn new(database_id: u32) -> Self {
+ pub const fn new(database_id: u32) -> Self {
DatabaseId { database_id }
}

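The constructor becoming `const` is a small quality-of-life change: a `DatabaseId` can now be formed in const contexts (fixed ids in constants or tests), presumably in support of the barrier refactor. A trivial sketch with an invented constant name:

```rust
use risingwave_common::catalog::DatabaseId;

// Hypothetical constant made possible by `pub const fn new`.
const TEST_DATABASE_ID: DatabaseId = DatabaseId::new(1);
```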
4 changes: 3 additions & 1 deletion src/compute/src/rpc/service/exchange_service.rs
@@ -19,6 +19,7 @@ use either::Either;
use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
use futures_async_stream::try_stream;
use risingwave_batch::task::BatchManager;
+ use risingwave_common::catalog::DatabaseId;
use risingwave_pb::task_service::exchange_service_server::ExchangeService;
use risingwave_pb::task_service::{
permits, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
@@ -93,6 +94,7 @@ impl ExchangeService for ExchangeServiceImpl {
down_actor_id,
up_fragment_id,
down_fragment_id,
+ database_id,
} = {
let req = request_stream
.next()
@@ -106,7 +108,7 @@ impl ExchangeService for ExchangeServiceImpl {

let receiver = self
.stream_mgr
- .take_receiver((up_actor_id, down_actor_id))
+ .take_receiver(DatabaseId::new(database_id), (up_actor_id, down_actor_id))
.await?;

// Map the remaining stream to add-permits.
5 changes: 2 additions & 3 deletions src/meta/src/barrier/checkpoint/control.rs
@@ -89,7 +89,7 @@ impl CheckpointControl {
resp: BarrierCompleteResponse,
control_stream_manager: &mut ControlStreamManager,
) -> MetaResult<()> {
- let database_id = from_partial_graph_id(resp.partial_graph_id).0;
+ let database_id = DatabaseId::new(resp.database_id);
self.databases
.get_mut(&database_id)
.expect("should exist")
@@ -435,8 +435,7 @@ impl DatabaseCheckpointControl {
partial_graph_id = resp.partial_graph_id,
"barrier collected"
);
- let (database_id, creating_job_id) = from_partial_graph_id(resp.partial_graph_id);
- assert_eq!(database_id, self.database_id);
+ let creating_job_id = from_partial_graph_id(resp.partial_graph_id);
match creating_job_id {
None => {
if let Some(node) = self.command_ctx_queue.get_mut(&prev_epoch) {
84 changes: 29 additions & 55 deletions src/meta/src/barrier/rpc.rs
@@ -31,7 +31,8 @@ use risingwave_pb::meta::PausedReason;
use risingwave_pb::stream_plan::barrier_mutation::Mutation;
use risingwave_pb::stream_plan::{Barrier, BarrierMutation, StreamActor, SubscriptionUpstreamInfo};
use risingwave_pb::stream_service::streaming_control_stream_request::{
- CreatePartialGraphRequest, PbInitRequest, PbInitialPartialGraph, RemovePartialGraphRequest,
+ CreatePartialGraphRequest, PbDatabaseInitialPartialGraph, PbInitRequest, PbInitialPartialGraph,
+ RemovePartialGraphRequest,
};
use risingwave_pb::stream_service::{
streaming_control_stream_request, streaming_control_stream_response, BarrierCompleteResponse,
@@ -54,25 +55,21 @@ use crate::{MetaError, MetaResult};

const COLLECT_ERROR_TIMEOUT: Duration = Duration::from_secs(3);

- fn to_partial_graph_id(database_id: DatabaseId, job_id: Option<TableId>) -> u64 {
- ((database_id.database_id as u64) << u32::BITS)
- | (job_id
- .map(|table| {
- assert_ne!(table.table_id, u32::MAX);
- table.table_id
- })
- .unwrap_or(u32::MAX) as u64)
+ fn to_partial_graph_id(job_id: Option<TableId>) -> u32 {
+ job_id
+ .map(|table| {
+ assert_ne!(table.table_id, u32::MAX);
+ table.table_id
+ })
+ .unwrap_or(u32::MAX)
}

- pub(super) fn from_partial_graph_id(partial_graph_id: u64) -> (DatabaseId, Option<TableId>) {
- let database_id = DatabaseId::new((partial_graph_id >> u32::BITS) as u32);
- let job_id = (partial_graph_id & (u32::MAX as u64)) as u32;
- let job_id = if job_id == u32::MAX {
+ pub(super) fn from_partial_graph_id(partial_graph_id: u32) -> Option<TableId> {
+ if partial_graph_id == u32::MAX {
None
} else {
- Some(TableId::new(job_id))
- };
- (database_id, job_id)
+ Some(TableId::new(partial_graph_id))
+ }
}

struct ControlStreamNode {
@@ -272,10 +269,13 @@ impl ControlStreamManager {
initial_subscriptions: impl Iterator<Item = (DatabaseId, &InflightSubscriptionInfo)>,
) -> PbInitRequest {
PbInitRequest {
- graphs: initial_subscriptions
- .map(|(database_id, info)| PbInitialPartialGraph {
- partial_graph_id: to_partial_graph_id(database_id, None),
- subscriptions: info.into_iter().collect_vec(),
+ databases: initial_subscriptions
+ .map(|(database_id, info)| PbDatabaseInitialPartialGraph {
+ database_id: database_id.database_id,
+ graphs: vec![PbInitialPartialGraph {
+ partial_graph_id: to_partial_graph_id(None),
+ subscriptions: info.into_iter().collect_vec(),
+ }],
})
.collect(),
}
@@ -335,7 +335,7 @@ impl ControlStreamManager {
"inject_barrier_err"
));

- let partial_graph_id = to_partial_graph_id(database_id, creating_table_id);
+ let partial_graph_id = to_partial_graph_id(creating_table_id);

let node_actors = InflightFragmentInfo::actor_ids_to_collect(pre_applied_graph_info);

@@ -399,6 +399,7 @@ impl ControlStreamManager {
InjectBarrierRequest {
request_id: Uuid::new_v4().to_string(),
barrier: Some(barrier),
+ database_id: database_id.database_id,
actor_ids_to_collect,
table_ids_to_sync: table_ids_to_sync
.iter()
@@ -451,14 +452,17 @@ impl ControlStreamManager {
database_id: DatabaseId,
creating_job_id: Option<TableId>,
) -> MetaResult<()> {
- let partial_graph_id = to_partial_graph_id(database_id, creating_job_id);
+ let partial_graph_id = to_partial_graph_id(creating_job_id);
self.nodes.iter().try_for_each(|(_, node)| {
node.handle
.request_sender
.send(StreamingControlStreamRequest {
request: Some(
streaming_control_stream_request::Request::CreatePartialGraph(
- CreatePartialGraphRequest { partial_graph_id },
+ CreatePartialGraphRequest {
+ database_id: database_id.database_id,
+ partial_graph_id,
+ },
),
),
})
@@ -477,7 +481,7 @@
}
let partial_graph_ids = creating_job_ids
.into_iter()
- .map(|job_id| to_partial_graph_id(database_id, Some(job_id)))
+ .map(|job_id| to_partial_graph_id(Some(job_id)))
.collect_vec();
self.nodes.iter().for_each(|(_, node)| {
if node.handle
@@ -487,6 +491,7 @@
streaming_control_stream_request::Request::RemovePartialGraph(
RemovePartialGraphRequest {
partial_graph_ids: partial_graph_ids.clone(),
+ database_id: database_id.database_id,
},
),
),
@@ -567,34 +572,3 @@ pub(super) fn merge_node_rpc_errors<E: Error + Send + Sync + 'static>(
});
anyhow!(concat).into()
}

- #[cfg(test)]
- mod tests {
- use risingwave_common::catalog::{DatabaseId, TableId};
-
- use crate::barrier::rpc::{from_partial_graph_id, to_partial_graph_id};
-
- #[test]
- fn test_partial_graph_id_convert() {
- fn test_convert(database_id: u32, job_id: Option<u32>) {
- let database_id = DatabaseId::new(database_id);
- let job_id = job_id.map(TableId::new);
- assert_eq!(
- (database_id, job_id),
- from_partial_graph_id(to_partial_graph_id(database_id, job_id))
- );
- }
- for database_id in [0, 1, 2, u32::MAX - 1, u32::MAX >> 1] {
- for job_id in [
- Some(0),
- Some(1),
- Some(2),
- None,
- Some(u32::MAX >> 1),
- Some(u32::MAX - 1),
- ] {
- test_convert(database_id, job_id);
- }
- }
- }
- }
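The `test_partial_graph_id_convert` module above is deleted because the database component no longer lives inside the partial graph id. An equivalent round-trip check for the new `u32` scheme is not part of this commit, but could look like the following sketch:

```rust
// Sketch only: round-trips the new u32 encoding, where u32::MAX marks the
// database's root partial graph (no creating job).
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::TableId;

    use crate::barrier::rpc::{from_partial_graph_id, to_partial_graph_id};

    #[test]
    fn test_partial_graph_id_convert_u32() {
        for job_id in [None, Some(0), Some(1), Some(u32::MAX - 1)] {
            let job_id = job_id.map(TableId::new);
            assert_eq!(job_id, from_partial_graph_id(to_partial_graph_id(job_id)));
        }
    }
}
```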
3 changes: 3 additions & 0 deletions src/rpc_client/src/compute_client.rs
@@ -17,6 +17,7 @@ use std::time::Duration;

use async_trait::async_trait;
use futures::StreamExt;
+ use risingwave_common::catalog::DatabaseId;
use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, STREAM_WINDOW_SIZE};
use risingwave_common::monitor::{EndpointExt, TcpConfig};
use risingwave_common::util::addr::HostAddr;
@@ -115,6 +116,7 @@ impl ComputeClient {
down_actor_id: u32,
up_fragment_id: u32,
down_fragment_id: u32,
+ database_id: DatabaseId,
) -> Result<(
Streaming<GetStreamResponse>,
mpsc::UnboundedSender<permits::Value>,
@@ -132,6 +134,7 @@
down_actor_id,
up_fragment_id,
down_fragment_id,
+ database_id: database_id.database_id,
})),
},
))
15 changes: 8 additions & 7 deletions src/stream/src/executor/dispatch.rs
@@ -1254,16 +1254,17 @@ mod tests {
},
));
barrier_test_env.inject_barrier(&b1, [actor_id]);
- barrier_test_env
- .shared_context
- .local_barrier_manager
- .flush_all_events()
- .await;
+ barrier_test_env.flush_all_events().await;

let input = Executor::new(
Default::default(),
- ReceiverExecutor::for_test(actor_id, rx, barrier_test_env.shared_context.clone())
- .boxed(),
+ ReceiverExecutor::for_test(
+ actor_id,
+ rx,
+ barrier_test_env.shared_context.clone(),
+ barrier_test_env.local_barrier_manager.clone(),
+ )
+ .boxed(),
);
let executor = Box::new(DispatchExecutor::new(
input,
16 changes: 15 additions & 1 deletion src/stream/src/executor/exchange/input.rs
@@ -161,6 +161,7 @@ pub struct RemoteInput {
}

use remote_input::RemoteInputStreamInner;
+ use risingwave_common::catalog::DatabaseId;

impl RemoteInput {
/// Create a remote input from compute client and related info. Should provide the corresponding
@@ -170,6 +171,7 @@ impl RemoteInput {
upstream_addr: HostAddr,
up_down_ids: UpDownActorIds,
up_down_frag: UpDownFragmentIds,
+ database_id: DatabaseId,
metrics: Arc<StreamingMetrics>,
batched_permits: usize,
) -> Self {
@@ -182,6 +184,7 @@ impl RemoteInput {
upstream_addr,
up_down_ids,
up_down_frag,
+ database_id,
metrics,
batched_permits,
),
@@ -194,6 +197,7 @@ mod remote_input {

use anyhow::Context;
use await_tree::InstrumentAwait;
+ use risingwave_common::catalog::DatabaseId;
use risingwave_common::util::addr::HostAddr;
use risingwave_pb::task_service::{permits, GetStreamResponse};
use risingwave_rpc_client::ComputeClientPool;
@@ -211,6 +215,7 @@
upstream_addr: HostAddr,
up_down_ids: UpDownActorIds,
up_down_frag: UpDownFragmentIds,
+ database_id: DatabaseId,
metrics: Arc<StreamingMetrics>,
batched_permits_limit: usize,
) -> RemoteInputStreamInner {
@@ -219,6 +224,7 @@
upstream_addr,
up_down_ids,
up_down_frag,
+ database_id,
metrics,
batched_permits_limit,
)
@@ -230,12 +236,19 @@
upstream_addr: HostAddr,
up_down_ids: UpDownActorIds,
up_down_frag: UpDownFragmentIds,
+ database_id: DatabaseId,
metrics: Arc<StreamingMetrics>,
batched_permits_limit: usize,
) {
let client = client_pool.get_by_addr(upstream_addr).await?;
let (stream, permits_tx) = client
- .get_stream(up_down_ids.0, up_down_ids.1, up_down_frag.0, up_down_frag.1)
+ .get_stream(
+ up_down_ids.0,
+ up_down_ids.1,
+ up_down_frag.0,
+ up_down_frag.1,
+ database_id,
+ )
.await?;

let up_actor_id = up_down_ids.0.to_string();
@@ -336,6 +349,7 @@ pub(crate) fn new_input(
upstream_addr,
(upstream_actor_id, actor_id),
(upstream_fragment_id, fragment_id),
+ context.database_id,
metrics,
context.config.developer.exchange_batched_permits,
)