diff --git a/examples/gateway-parallel.rs b/examples/gateway-parallel.rs index 6c4b845efe2..cbd1a5d932e 100644 --- a/examples/gateway-parallel.rs +++ b/examples/gateway-parallel.rs @@ -1,16 +1,8 @@ -//! Run the recommended number of shards in `available_parallelism` tasks. Note -//! that neither [`ShardEventStream`] nor [`ShardMessageStream`] is required for -//! parallelism; each shard can run on an independent task. -//! -//! [`ShardMessageStream`]: twilight_gateway::stream::ShardMessageStream - -use futures_util::{future::join_all, StreamExt}; -use std::{env, iter, thread}; +//! Run the recommended number of shards, each in their own task. + +use std::env; use tokio::{signal, sync::watch, task::JoinSet}; -use twilight_gateway::{ - stream::{self, ShardEventStream}, - CloseFrame, Config, Intents, Shard, -}; +use twilight_gateway::{stream, CloseFrame, Config, Intents, Shard}; use twilight_http::Client; #[tokio::main] @@ -19,37 +11,25 @@ async fn main() -> anyhow::Result<()> { let token = env::var("DISCORD_TOKEN")?; let client = Client::new(token.clone()); - let config = Config::new(token.clone(), Intents::GUILDS); - - let tasks = thread::available_parallelism()?.get(); + let config = Config::new(token, Intents::GUILDS); - // Split shards into a vec of `tasks` vecs of shards. - let init = iter::repeat_with(Vec::new) - .take(tasks) - .collect::>>(); let shards = stream::create_recommended(&client, config, |_, builder| builder.build()) .await? - .enumerate() - .fold(init, |mut fold, (idx, shard)| { - fold[idx % tasks].push(shard); - fold - }); + .collect::>(); let (tx, rx) = watch::channel(false); let mut set = JoinSet::new(); - for mut shards in shards { + for mut shard in shards { let mut rx = rx.clone(); set.spawn(async move { - // Run `process` and `rx.changed()` concurrently, returning when - // the first branch completes, cancelling the other one. + // Run the two futures concurrently, returning when the first branch + // completes, cancelling the other one. tokio::select! { - _ = process(shards.iter_mut()) => {}, + _ = runner(&mut shard) => {}, _ = rx.changed() => { - join_all(shards.iter_mut().map(|shard| async move { - shard.close(CloseFrame::NORMAL).await - })).await; + _ = shard.close(CloseFrame::NORMAL).await; } } }); @@ -68,12 +48,11 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn process(shards: impl Iterator) { - let mut stream = ShardEventStream::new(shards); +async fn runner(shard: &mut Shard) { loop { - let (shard, event) = match stream.next().await { - Some((shard, Ok(event))) => (shard, event), - Some((_, Err(source))) => { + let event = match shard.next_event().await { + Ok(event) => event, + Err(source) => { tracing::warn!(?source, "error receiving event"); if source.is_fatal() { @@ -82,7 +61,6 @@ async fn process(shards: impl Iterator) { continue; } - None => break, }; tracing::debug!(?event, shard = ?shard.id(), "received event");