Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs(gateway): cleanup reshard example #2224

Merged
merged 1 commit into from
Dec 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 85 additions & 71 deletions examples/gateway-reshard.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use futures_util::StreamExt;
use std::{env, sync::Arc, time::Duration};
use std::{
env,
future::{poll_fn, Future},
task::Poll,
time::Duration,
};
use tokio::time;
use twilight_gateway::{
stream::{self, ShardEventStream, ShardMessageStream},
Config, ConfigBuilder, Event, Intents, Shard, ShardId,
Config, ConfigBuilder, Intents, Shard, ShardId,
};
use twilight_http::Client;

Expand All @@ -13,31 +18,22 @@ async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();

let token = env::var("DISCORD_TOKEN")?;
let client = Arc::new(Client::new(token.clone()));
let config = Config::new(
token.clone(),
Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT,
);
let client = Client::new(token.clone());
let config = Config::new(token, Intents::GUILDS);
let config_callback = |_, builder: ConfigBuilder| builder.build();

let mut shards = stream::create_recommended(&client, config.clone(), &config_callback)
let mut shards = stream::create_recommended(&client, config.clone(), config_callback)
.await?
.collect::<Vec<_>>();

loop {
// Run `gateway_runner` and `reshard` concurrently until the first one
// finishes.
// Run the two futures concurrently, returning when the first branch
// completes and cancels the other one.
tokio::select! {
// Gateway_runner only finishes on errors, so break the loop and exit
// the program.
_ = gateway_runner(Arc::clone(&client), shards) => break,
// Resharding complete! Time to run `gateway_runner` with the new
// list of shards.
Ok(Some(new_shards)) = reshard(&client, config.clone(), config_callback) => {
// Assign the new list of shards to `shards`, dropping the
// old list.
shards = new_shards;
},
_ = runner(shards) => break,
new_shards = reshard(&client, config.clone(), config_callback) => {
shards = new_shards?;
}
}
}

Expand All @@ -46,13 +42,13 @@ async fn main() -> anyhow::Result<()> {

// Instrument to differentiate between the logs produced here and in `reshard`.
#[tracing::instrument(skip_all)]
async fn gateway_runner(client: Arc<Client>, mut shards: Vec<Shard>) {
async fn runner(mut shards: Vec<Shard>) {
let mut stream = ShardEventStream::new(shards.iter_mut());

loop {
let event = match stream.next().await {
Some((_, Ok(event))) => event,
Some((_, Err(source))) => {
while let Some((shard, event)) = stream.next().await {
let event = match event {
Ok(event) => event,
Err(source) => {
tracing::warn!(?source, "error receiving event");

if source.is_fatal() {
Expand All @@ -61,73 +57,91 @@ async fn gateway_runner(client: Arc<Client>, mut shards: Vec<Shard>) {

continue;
}
None => break,
};

tokio::spawn(event_handler(Arc::clone(&client), event));
tracing::debug!(?event, shard = ?shard.id(), "received event");
}
}

async fn event_handler(client: Arc<Client>, event: Event) -> anyhow::Result<()> {
match event {
Event::MessageCreate(message) if message.content == "!ping" => {
client
.create_message(message.channel_id)
.content("Pong!")?
.await?;
}
_ => {}
}

Ok(())
}

// Instrument to differentiate between the logs produced here and
// in `gateway_runner`.
// Instrument to differentiate between the logs produced here and in `runner`.
#[tracing::instrument(skip_all)]
async fn reshard(
client: &Client,
config: Config,
config_callback: impl Fn(ShardId, ConfigBuilder) -> Config,
) -> anyhow::Result<Option<Vec<Shard>>> {
) -> anyhow::Result<Vec<Shard>> {
// Reshard every eight hours. This is an arbitrary number.
const RESHARD_DURATION: Duration = Duration::from_secs(60 * 60 * 8);

// Reshard every eight hours.
time::sleep(RESHARD_DURATION).await;

let mut shards = stream::create_recommended(client, config, config_callback)
.await?
.collect::<Vec<_>>();
let info = client.gateway().authed().await?.model().await?;

let mut shards =
stream::create_range(.., info.shards, config, config_callback).collect::<Vec<_>>();

let expected_duration = estimate_identifed(
info.shards,
info.session_start_limit.max_concurrency,
info.session_start_limit.remaining,
Duration::from_millis(info.session_start_limit.reset_after),
info.session_start_limit.total,
);
tokio::pin! {
let timeout = time::sleep(expected_duration);
}
// Register timer.
poll_fn(|cx| {
_ = timeout.as_mut().poll(cx);
Poll::Ready(())
})
.await;

// Before swapping the old and new list of shards, try to identify them.
// Don't try too hard, however, as large bots may never have all shards
// identified at the same time.
let mut identified = vec![false; shards.len()];
// Don't deserialize any events (with `ShardEventStream`) as the already
// running shards will handle them (the events are duplicated).
let mut stream = ShardMessageStream::new(shards.iter_mut());

// Drive the new list of shards until they are all identified.
while !identified.iter().all(|&shard| shard) {
match stream.next().await {
Some((_, Err(source))) => {
tracing::warn!(?source, "error receiving event");

if source.is_fatal() {
// When returning `None` `reshard` will be called again,
// retrying after `RESHARD_DURATION`.
// A fatal error will however most likely also be
// encountered for the currently running list of shards at
// the same time, exciting the application.
return Ok(None);
}

continue;
loop {
let identified_count = identified.iter().map(|&i| i as usize).sum::<usize>();
tokio::select! {
_ = &mut timeout, if identified_count >= (identified.len() * 3) / 4 => {
drop(stream);
break;
}
Some((shard, _)) => {
identified[shard.id().number() as usize] = shard.status().is_identified();
Some(res) = stream.next() => {
match res {
(_, Err(source)) => {
tracing::warn!(?source, "error receiving message");

if source.is_fatal() {
anyhow::bail!(source);
}
}
(shard, _) => {
identified[shard.id().number() as usize] = shard.status().is_identified();
}
}
}
None => return Ok(None),
}
}

drop(stream);
Ok(Some(shards))
Ok(shards)
}

fn estimate_identifed(
shards: u64,
max_concurrency: u64,
remaining: u64,
reset_after: Duration,
total: u64,
) -> Duration {
const DAY: Duration = Duration::from_secs(60 * 60 * 24);

let refills = shards / remaining;
let buckets = (shards as f32 / max_concurrency as f32).round() as u64;
reset_after * (refills > 0) as u32
+ (1..refills).map(|_| DAY).sum::<Duration>()
+ Duration::from_secs(5 * buckets % total)
}