Skip to content

Commit

Permalink
Merge pull request #196 from marcelropos/command-logger-pipe
Browse files Browse the repository at this point in the history
✨ Add logger pipe command
  • Loading branch information
maxwai authored Feb 11, 2024
2 parents d1aaf67 + 34bb4bb commit f25789b
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 57 deletions.
47 changes: 40 additions & 7 deletions src/bot/checks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use poise::serenity_prelude::{ChannelId, Permissions, RoleId};
use poise::serenity_prelude::{ChannelId, GuildId, Permissions, RoleId};
use sqlx::{MySql, Pool};
use tracing::error;

use crate::bot::Context;
use crate::{env, mysql_lib};
Expand Down Expand Up @@ -55,6 +56,44 @@ pub async fn is_admin(ctx: Context<'_>) -> bool {
false
}

/// Returns false in case of an error
pub async fn is_bot_admin(ctx: Context<'_>) -> bool {
let author_id = ctx.author().id.0;
let main_guild_id = env::MAIN_GUILD_ID.get().unwrap();

let main_guild_member = ctx.http().get_member(*main_guild_id, author_id).await
.map_err(|err| error!(
error = err.to_string(),
member_id = author_id,
"Could not get main guild member"
)).ok();

let main_guild_member = if let Some(main_guild_member) = main_guild_member {
main_guild_member
} else {
return false;
};


let main_guild_roles = GuildId(*main_guild_id).roles(ctx.http()).await
.map_err(|err| error!(error = err.to_string(), "Could not get main guild roles"))
.ok();

let main_guild_roles = if let Some(main_guild_roles) = main_guild_roles {
main_guild_roles
} else {
return false;
};

let is_admin = main_guild_member.roles.iter().any(|role_id| {
main_guild_roles.get(role_id).map_or(false, |role| {
role.has_permission(Permissions::ADMINISTRATOR)
})
});

is_admin
}

/// Checks if the author is the owner of the guild where the message was sent
///
/// Will return false when:
Expand Down Expand Up @@ -110,9 +149,3 @@ pub async fn sent_in_setup_guild(ctx: Context<'_>, pool: &Pool<MySql>) -> bool {
}
false
}

/// Checks if the message was sent in a guild
#[allow(dead_code)]
pub async fn sent_in_guild(ctx: Context<'_>) -> bool {
ctx.guild_id().is_some()
}
64 changes: 62 additions & 2 deletions src/bot/commands.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use tracing::info;

use super::checks;
use crate::{
bot::{Context, Error},
logging, mysql_lib,
};
use std::time::SystemTime;
use crate::bot::{Context, Error};

/// ping command
#[poise::command(slash_command, prefix_command)]
Expand All @@ -18,4 +24,58 @@ pub async fn ping(ctx: Context<'_>) -> Result<(), Error> {
})
.await?;
Ok(())
}
}

/// admin-only. Sets the configured logger pipe channel to the current channel.
/// If the current channel is the current logger pipe channel, it will be deactivated.
#[poise::command(prefix_command, guild_only)]
pub async fn logger_pipe(ctx: Context<'_>) -> Result<(), Error> {
// Check permissions
if !checks::is_owner(ctx).await
&& !checks::is_admin(ctx).await
&& !checks::is_bot_admin(ctx).await
{
ctx.say("Missing permissions, requires admin permissions")
.await?;
return Ok(());
}

let guild_id = ctx.guild_id().unwrap();

let db = &ctx.data().database_pool;
let db_guild = if let Some(db_guild) = mysql_lib::get_guild(db, guild_id).await {
db_guild
} else {
ctx.say("Needs to be executed in an already setup guild")
.await?;
return Ok(());
};

let current_logger_pipe = db_guild.logger_pipe_channel;

if current_logger_pipe.is_some_and(|logger_pipe| logger_pipe == ctx.channel_id()) {
// Current channel is logger pipe
// => deactivate logger pipe
info!(?guild_id, "Deactivating logger pipe");

ctx.say("Deactivating logger pipe in current channel")
.await?;

mysql_lib::update_logger_pipe_channel(db, guild_id, None).await;

logging::remove_per_server_logging(guild_id);
} else {
// Logger pipe either not setup, or not the current channel
// => set current channel as logger pipe
info!(?guild_id, "Setting logger pipe");

ctx.say("Setting logger pipe to the current channel")
.await?;

mysql_lib::update_logger_pipe_channel(db, guild_id, Some(ctx.channel_id())).await;

logging::add_per_server_logging(guild_id, ctx.channel_id());
}

Ok(())
}
23 changes: 17 additions & 6 deletions src/bot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ pub async fn entrypoint(database_pool: Pool<MySql>, redis_client: Client) {
info!("Starting the bot");

let db_clone = database_pool.clone();
let framework = poise::Framework::builder()
let framework = Framework::builder()
.options(poise::FrameworkOptions {
commands: vec![commands::ping()],
commands: vec![
commands::ping(),
commands::logger_pipe()
],
allowed_mentions: Some({
let mut f = serenity::CreateAllowedMentions::default();
f.empty_parse()
Expand All @@ -46,13 +49,17 @@ pub async fn entrypoint(database_pool: Pool<MySql>, redis_client: Client) {
pre_command: |ctx| {
Box::pin(async move {
info!(
"Received Command from {} in channel {}: {}",
guild_id = ctx.guild_id().map(|id| id.0).unwrap_or(0),
"Received Command from @{}, in guild {}, in channel #{}: `{}`",
ctx.author().name,
ctx.guild()
.map(|guild| guild.name)
.unwrap_or("no-guild".to_string()),
ctx.channel_id()
.name(ctx.cache())
.await
.unwrap_or_else(|| { "Unknown".to_string() }),
ctx.invocation_string()
.unwrap_or("Unknown".to_string()),
ctx.invocation_string(),
);
})
},
Expand All @@ -74,7 +81,11 @@ pub async fn entrypoint(database_pool: Pool<MySql>, redis_client: Client) {

let built_framework = framework.build().await.expect("Err building poise client");

logging::setup_discord_logging(built_framework.clone(), db_clone).await;
logging::setup_discord_logging(
built_framework.client().cache_and_http.http.clone(),
db_clone,
)
.await;

built_framework
.start()
Expand Down
9 changes: 9 additions & 0 deletions src/env.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::env;
use std::str::FromStr;
use std::sync::OnceLock;

pub static MAIN_GUILD_ID: OnceLock<u64> = OnceLock::new();
Expand All @@ -10,8 +11,16 @@ pub static MYSQL_DATABASE: OnceLock<String> = OnceLock::new();
pub static MYSQL_USER: OnceLock<String> = OnceLock::new();
pub static MYSQL_PASSWORD: OnceLock<String> = OnceLock::new();
pub static BOT_TOKEN: OnceLock<String> = OnceLock::new();
pub static LOG_LEVEL: OnceLock<tracing::Level> = OnceLock::new();

pub fn init() {
LOG_LEVEL.get_or_init(|| {
env::var("RUST_LOG")
.map(|level| {
tracing::Level::from_str(level.as_str()).expect("Could not parse Logger level")
})
.unwrap_or(tracing::Level::INFO)
});
MAIN_GUILD_ID.get_or_init(|| {
env::var("MAIN_GUILD_ID")
.expect("No Main Guild ID given")
Expand Down
97 changes: 55 additions & 42 deletions src/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ use std::fs;
use std::num::NonZeroU64;
use std::sync::{Arc, OnceLock};

use poise::serenity_prelude::{ChannelId, GuildId, Http};
use poise::serenity_prelude::futures::executor::block_on;
use poise::serenity_prelude::{ChannelId, GuildId};
use rolling_file::RollingConditionBasic;
use sqlx::{MySql, Pool};
use tokio::task::spawn_blocking;
use tracing::{error, info, Level, Subscriber};
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{filter, fmt, Registry, reload};
use tracing_subscriber::prelude::*;
use tracing_subscriber::reload::Handle;
use tracing_subscriber::{filter, fmt, reload, Registry};

use crate::{bot, env, mysql_lib};
use crate::{env, mysql_lib};

const LOG_FILE_MAX_SIZE_MB: u64 = 10;
const MAX_AMOUNT_LOG_FILES: usize = 10;
Expand All @@ -29,7 +30,7 @@ pub async fn setup_logging() -> WorkerGuard {

let (rolling_file_writer, worker_guard) = tracing_appender::non_blocking(
rolling_file::BasicRollingFileAppender::new(
"./appdata/logs/hm-discord-bot",
"./appdata/logs/hm-discord-bot.log",
RollingConditionBasic::new().max_size(LOG_FILE_MAX_SIZE_MB * 1024 * 1024),
MAX_AMOUNT_LOG_FILES,
)
Expand All @@ -41,12 +42,21 @@ pub async fn setup_logging() -> WorkerGuard {
let (discord_layer_reloadable, log_reload_handle) = reload::Layer::new(discord_layer);

let discord_layer_filtered = discord_layer_reloadable
.with_filter(filter::Targets::new().with_target("discord", Level::INFO));
.with_filter(filter::Targets::new().with_target("hm_discord_bot", Level::INFO));

tracing_subscriber::registry()
.with(discord_layer_filtered)
.with(fmt::layer().compact())
.with(fmt::layer().compact().with_writer(rolling_file_writer))
.with(
fmt::layer()
.with_writer(std::io::stdout.with_max_level(*env::LOG_LEVEL.get().unwrap()))
.compact(),
)
.with(
fmt::layer()
.compact()
.with_writer(rolling_file_writer.with_max_level(*env::LOG_LEVEL.get().unwrap()))
.with_ansi(false),
)
.init();

info!("Setup logging");
Expand All @@ -57,20 +67,18 @@ pub async fn setup_logging() -> WorkerGuard {
}

/// Panics if called twice
pub async fn setup_discord_logging(framework: Arc<bot::Framework>, db: Pool<MySql>) {
pub async fn setup_discord_logging(discord_http: Arc<Http>, db: Pool<MySql>) {
modify_discord_layer(|discord_layer| {
discord_layer.poise_framework = Some(framework.clone());
discord_layer.discord_http = Some(discord_http.clone());
});

let http = &framework.client().cache_and_http.http;

// Setup main logging guild/channel
let main_guild = http
let main_guild_channels = discord_http
.get_channels(*env::MAIN_GUILD_ID.get().unwrap())
.await
.expect("Could not get main guild");

let main_logging_channel = main_guild[0].id;
let main_logging_channel = main_guild_channels[0].id;

modify_discord_layer(|discord_layer| {
discord_layer.main_log_channel = NonZeroU64::new(main_logging_channel.0);
Expand All @@ -93,6 +101,20 @@ pub async fn setup_discord_logging(framework: Arc<bot::Framework>, db: Pool<MySq
});
}

/// Panics if called before [`setup_discord_logging`]
pub fn add_per_server_logging(guild_id: GuildId, log_channel_id: ChannelId) {
modify_discord_layer(|layer| {
layer.guild_to_log_channel.insert(guild_id, log_channel_id);
});
}

/// Panics if called before [`setup_discord_logging`]
pub fn remove_per_server_logging(guild_id: GuildId) {
modify_discord_layer(|layer| {
layer.guild_to_log_channel.remove(&guild_id);
});
}

fn modify_discord_layer(f: impl FnOnce(&mut DiscordTracingLayer)) {
let result = DISCORD_LAYER_CHANGE_HANDLE.get().unwrap().modify(f);

Expand All @@ -104,24 +126,9 @@ fn modify_discord_layer(f: impl FnOnce(&mut DiscordTracingLayer)) {
}
}

#[allow(dead_code)]
/// Panics if called before [`install_framework`]
pub fn add_per_server_logging(guild_id: GuildId, log_channel_id: ChannelId) {
let layer_change_handle = DISCORD_LAYER_CHANGE_HANDLE.get().unwrap();
let result = layer_change_handle.modify(|layer| {
layer.guild_to_log_channel.insert(guild_id, log_channel_id);
});
if let Err(err) = result {
error!(
error = err.to_string(),
"Failed to install poise framework into discord tracing layer"
);
}
}

struct DiscordTracingLayer {
main_log_channel: Option<NonZeroU64>,
poise_framework: Option<Arc<bot::Framework>>,
discord_http: Option<Arc<Http>>,
/// HashMap of GuilId's -> ChannelId's
guild_to_log_channel: HashMap<GuildId, ChannelId>,
}
Expand All @@ -130,7 +137,7 @@ impl DiscordTracingLayer {
pub fn new() -> DiscordTracingLayer {
DiscordTracingLayer {
main_log_channel: None,
poise_framework: None,
discord_http: None,
guild_to_log_channel: HashMap::new(),
}
}
Expand All @@ -145,8 +152,8 @@ where
event: &tracing::Event<'_>,
_ctx: tracing_subscriber::layer::Context<'_, S>,
) {
let poise_framework = if let Some(poise_framework) = &self.poise_framework {
poise_framework
let discord_http = if let Some(discord_http) = &self.discord_http {
discord_http
} else {
return;
};
Expand All @@ -161,21 +168,27 @@ where
event.record(&mut guild_id_visitor);
let guild_id = guild_id_visitor.guild_id;

let http = &poise_framework.client().cache_and_http.http;

if let Some(channel_id) = self.main_log_channel {
let _ = block_on(
ChannelId(channel_id.get())
.send_message(http, |m| m.content(format!("{event_level} {message}"))),
);
let local_discord_http = discord_http.clone();
let message_copy = message.clone();
spawn_blocking(move || {
let _ = block_on(
ChannelId(channel_id.get()).send_message(local_discord_http, |m| {
m.content(format!("{event_level} {}", message_copy))
}),
);
});
}

if let Some(guild_id) = guild_id {
if let Some(channel_id) = self.guild_to_log_channel.get(&guild_id.get().into()) {
let _ = block_on(
channel_id
.send_message(http, |m| m.content(format!("{event_level} {message}"))),
);
let channel_id = *channel_id;
let local_discord_http = discord_http.clone();
spawn_blocking(move || {
let _ = block_on(channel_id.send_message(local_discord_http.clone(), |m| {
m.content(format!("{event_level} {message}"))
}));
});
}
}
}
Expand Down

0 comments on commit f25789b

Please sign in to comment.