diff --git a/README.md b/README.md index 2d1f458..8abee81 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ [![GitHub Workflow Status (CI)](https://img.shields.io/github/actions/workflow/status/lizardbyte/support-bot/ci.yml.svg?branch=master&label=CI%20build&logo=github&style=for-the-badge)](https://github.com/LizardByte/support-bot/actions/workflows/ci.yml?query=branch%3Amaster) [![Codecov](https://img.shields.io/codecov/c/gh/LizardByte/support-bot.svg?token=900Q93P1DE&style=for-the-badge&logo=codecov&label=codecov)](https://app.codecov.io/gh/LizardByte/support-bot) -Support bot written in python to help manage LizardByte communities. The current focus is discord and reddit, but other -platforms such as GitHub discussions/issues could be added. +Support bot written in python to help manage LizardByte communities. The current focus is Discord and Reddit, but other +platforms such as GitHub discussions/issues might be added in the future. ## Overview @@ -31,6 +31,9 @@ platforms such as GitHub discussions/issues could be added. | variable | required | default | description | |-------------------------|----------|------------------------------------------------------|---------------------------------------------------------------| | DISCORD_BOT_TOKEN | True | `None` | Token from Bot page on discord developer portal. | +| DISCORD_CLIENT_ID | True | `None` | Discord OAuth2 client id. | +| DISCORD_CLIENT_SECRET | True | `None` | Discord OAuth2 client secret. | +| DISCORD_REDIRECT_URI | False | `https://localhost:8080/discord/callback` | The redirect uri for OAuth2. Must be publicly accessible. | | DAILY_TASKS | False | `true` | Daily tasks on or off. | | DAILY_RELEASES | False | `true` | Send a message for each game released on this day in history. | | DAILY_CHANNEL_ID | False | `None` | Required if daily_tasks is enabled. | @@ -41,11 +44,6 @@ platforms such as GitHub discussions/issues could be added. | SUPPORT_COMMANDS_REPO | False | `https://github.com/LizardByte/support-bot-commands` | Repository for support commands. | | SUPPORT_COMMANDS_BRANCH | False | `master` | Branch for support commands. | -* Running bot: - * `python -m src` -* Invite bot to server: - * `https://discord.com/api/oauth2/authorize?client_id=&permissions=8&scope=bot%20applications.commands` - ### Reddit @@ -62,7 +60,13 @@ platforms such as GitHub discussions/issues could be added. | DISCORD_WEBHOOK | False | None | URL of webhook to send discord notifications to | | GRAVATAR_EMAIL | False | None | Gravatar email address to get avatar from | | REDDIT_USERNAME | True | None | Reddit username | -* | REDDIT_PASSWORD | True | None | Reddit password | + | REDDIT_PASSWORD | True | None | Reddit password | + +### Start -* Running bot: - * `python -m src` +```bash +python -m src +``` + +* Invite bot to server: + * `https://discord.com/api/oauth2/authorize?client_id=&permissions=8&scope=bot%20applications.commands` diff --git a/requirements.txt b/requirements.txt index 2f48774..7581d55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ praw==7.8.1 py-cord==2.6.1 python-dotenv==1.0.1 requests==2.32.3 +requests-oauthlib==2.0.0 diff --git a/src/__main__.py b/src/__main__.py index 7968744..c971466 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,5 +1,4 @@ # standard imports -import os import time # development imports @@ -8,33 +7,28 @@ # local imports if True: # hack for flake8 + from src.common import globals from src.discord import bot as d_bot - from src import keep_alive + from src.common import webapp from src.reddit import bot as r_bot def main(): - # to run in replit - try: - os.environ['REPL_SLUG'] - except KeyError: - pass # not running in replit - else: - keep_alive.keep_alive() # Start the web server + webapp.start() # Start the web server - discord_bot = d_bot.Bot() - discord_bot.start_threaded() # Start the discord bot + globals.DISCORD_BOT = d_bot.Bot() + globals.DISCORD_BOT.start_threaded() # Start the discord bot - reddit_bot = r_bot.Bot() - reddit_bot.start_threaded() # Start the reddit bot + globals.REDDIT_BOT = r_bot.Bot() + globals.REDDIT_BOT.start_threaded() # Start the reddit bot try: - while discord_bot.bot_thread.is_alive() or reddit_bot.bot_thread.is_alive(): + while globals.DISCORD_BOT.bot_thread.is_alive() or globals.REDDIT_BOT.bot_thread.is_alive(): time.sleep(0.5) except KeyboardInterrupt: print("Keyboard Interrupt Detected") - discord_bot.stop() - reddit_bot.stop() + globals.DISCORD_BOT.stop() + globals.REDDIT_BOT.stop() if __name__ == '__main__': # pragma: no cover diff --git a/src/common/__init__.py b/src/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/common.py b/src/common/common.py similarity index 100% rename from src/common.py rename to src/common/common.py diff --git a/src/common/crypto.py b/src/common/crypto.py new file mode 100644 index 0000000..a59cf77 --- /dev/null +++ b/src/common/crypto.py @@ -0,0 +1,69 @@ +# standard imports +import os + +# lib imports +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption +from datetime import datetime, timedelta, UTC + +# local imports +from src.common import common + +CERT_FILE = os.path.join(common.data_dir, "cert.pem") +KEY_FILE = os.path.join(common.data_dir, "key.pem") + + +def check_expiration(cert_path: str) -> int: + with open(cert_path, "rb") as cert_file: + cert_data = cert_file.read() + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + expiry_date = cert.not_valid_after_utc + return (expiry_date - datetime.now(UTC)).days + + +def generate_certificate(): + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=4096, + ) + subject = issuer = x509.Name([ + x509.NameAttribute(x509.NameOID.COMMON_NAME, u"localhost"), + ]) + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + private_key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.now(UTC) + ).not_valid_after( + datetime.now(UTC) + timedelta(days=365) + ).sign(private_key, hashes.SHA256()) + + with open(KEY_FILE, "wb") as f: + f.write(private_key.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=NoEncryption(), + )) + + with open(CERT_FILE, "wb") as f: + f.write(cert.public_bytes(Encoding.PEM)) + + +def initialize_certificate() -> tuple[str, str]: + print("Initializing SSL certificate") + if os.path.exists(CERT_FILE) and os.path.exists(KEY_FILE): + cert_expires_in = check_expiration(CERT_FILE) + print(f"Certificate expires in {cert_expires_in} days.") + if cert_expires_in >= 90: + return CERT_FILE, KEY_FILE + print("Generating new certificate") + generate_certificate() + return CERT_FILE, KEY_FILE diff --git a/src/common/database.py b/src/common/database.py new file mode 100644 index 0000000..b27fdd4 --- /dev/null +++ b/src/common/database.py @@ -0,0 +1,22 @@ +# standard imports +import shelve +import threading + + +class Database: + def __init__(self, db_path): + self.db_path = db_path + self.lock = threading.Lock() + + def __enter__(self): + self.lock.acquire() + self.db = shelve.open(self.db_path, writeback=True) + return self.db + + def __exit__(self, exc_type, exc_val, exc_tb): + self.sync() + self.db.close() + self.lock.release() + + def sync(self): + self.db.sync() diff --git a/src/common/globals.py b/src/common/globals.py new file mode 100644 index 0000000..f185cab --- /dev/null +++ b/src/common/globals.py @@ -0,0 +1,2 @@ +DISCORD_BOT = None +REDDIT_BOT = None diff --git a/src/common/webapp.py b/src/common/webapp.py new file mode 100644 index 0000000..46bd50f --- /dev/null +++ b/src/common/webapp.py @@ -0,0 +1,150 @@ +# standard imports +import asyncio +import os +from threading import Thread +from typing import Tuple + +# lib imports +import discord +from flask import Flask, jsonify, redirect, request, Response +from requests_oauthlib import OAuth2Session + +# local imports +from src.common import crypto +from src.common import globals + + +DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") +DISCORD_CLIENT_SECRET = os.getenv("DISCORD_CLIENT_SECRET") +DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", "https://localhost:8080/discord/callback") + +app = Flask('LizardByte-bot') + + +@app.route('/') +def main(): + return "LizardByte-bot is live!" + + +@app.route("/discord/callback") +def discord_callback(): + # get all active states from the global state manager + with globals.DISCORD_BOT.db as db: + active_states = db['oauth_states'] + + discord_oauth = OAuth2Session(DISCORD_CLIENT_ID, redirect_uri=DISCORD_REDIRECT_URI) + token = discord_oauth.fetch_token("https://discord.com/api/oauth2/token", + client_secret=DISCORD_CLIENT_SECRET, + authorization_response=request.url) + + # Fetch the user's Discord profile + response = discord_oauth.get("https://discord.com/api/users/@me") + discord_user = response.json() + + # if the user is not in the active states, return an error + if discord_user['id'] not in active_states: + return "Invalid state" + + # remove the user from the active states + del active_states[discord_user['id']] + + # Fetch the user's connected accounts + connections_response = discord_oauth.get("https://discord.com/api/users/@me/connections") + connections = connections_response.json() + + with globals.DISCORD_BOT.db as db: + db['discord_users'] = db.get('discord_users', {}) + db['discord_users'][discord_user['id']] = { + 'discord_username': discord_user['username'], + 'discord_global_name': discord_user['global_name'], + 'github_id': None, + 'github_username': None, + 'token': token, # TODO: should we store the token at all? + } + + for connection in connections: + if connection['type'] == 'github': + db['discord_users'][discord_user['id']]['github_id'] = connection['id'] + db['discord_users'][discord_user['id']]['github_username'] = connection['name'] + + # Redirect to our main website + return redirect("https://app.lizardbyte.dev") + + +@app.route("/webhook/", methods=["POST"]) +def webhook(source: str) -> Tuple[Response, int]: + """ + Process webhooks from various sources. + + * GitHub sponsors: https://github.com/sponsors/LizardByte/dashboard/webhooks + * GitHub status: https://www.githubstatus.com + + Parameters + ---------- + source : str + The source of the webhook (e.g., 'github_sponsors', 'github_status'). + + Returns + ------- + flask.Response + Response to the webhook request + """ + valid_sources = ["github_sponsors", "github_status"] + + if source not in valid_sources: + return jsonify({"status": "error", "message": "Invalid source"}), 400 + + print(f"received webhook from {source}") + data = request.json + print(f"received webhook data: \n{data}") + + if source == "github_sponsors": + # ensure the secret matches + # if data['secret'] != os.getenv("GITHUB_SPONSORS_WEBHOOK_SECRET_KEY"): + # return jsonify({"status": "error", "message": "Invalid secret"}), 400 + + # process the webhook data + if data['action'] == "created": + message = f'New GitHub sponsor: {data["sponsorship"]["sponsor"]["login"]}' + + # create a discord embed + embed = discord.Embed( + author=discord.EmbedAuthor( + name=data["sponsorship"]["sponsor"]["login"], + url=data["sponsorship"]["sponsor"]["url"], + icon_url=data["sponsorship"]["sponsor"]["avatar_url"], + ), + color=0x00ff00, + description=message, + footer=discord.EmbedFooter( + text=f"Sponsored at {data['sponsorship']['created_at']}", + ), + title="New GitHub Sponsor", + ) + message = asyncio.run_coroutine_threadsafe( + globals.DISCORD_BOT.send_message_to_channel( + channel_id=os.getenv("DISCORD_SPONSORS_CHANNEL_ID"), + embeds=[embed], + ), globals.DISCORD_BOT.loop) + message.result() # wait for the message to be sent + + return jsonify({"status": "success"}), 200 + + +def run(): + cert_file, key_file = crypto.initialize_certificate() + + app.run( + host="0.0.0.0", + port=8080, + ssl_context=(cert_file, key_file) + ) + + +def start(): + server = Thread( + name="Flask", + daemon=True, + target=run, + ) + server.start() diff --git a/src/discord/bot.py b/src/discord/bot.py index a9baf6c..46faed9 100644 --- a/src/discord/bot.py +++ b/src/discord/bot.py @@ -7,7 +7,8 @@ import discord # local imports -from src.common import bot_name, get_avatar_bytes, org_name +from src.common.common import bot_name, data_dir, get_avatar_bytes, org_name +from src.common.database import Database from src.discord.tasks import daily_task from src.discord.views import DonateCommandView @@ -30,6 +31,7 @@ def __init__(self, *args, **kwargs): self.bot_thread = threading.Thread(target=lambda: None) self.token = os.environ['DISCORD_BOT_TOKEN'] + self.db = Database(db_path=os.path.join(data_dir, 'discord_bot_database')) self.load_extension( name='src.discord.cogs', @@ -37,6 +39,9 @@ def __init__(self, *args, **kwargs): store=False, ) + with self.db as db: + db['oauth_states'] = {} # clear any oauth states from previous sessions + async def on_ready(self): """ Bot on ready event. @@ -71,6 +76,32 @@ async def on_ready(self): else: print("'DAILY_TASKS' environment variable is disabled") + async def send_message_to_channel( + self, + channel_id: int, + message: str = None, + embeds: list[discord.Embed] = None, + ) -> discord.Message: + """ + Send a message to a specific channel. + + Parameters + ---------- + channel_id : int + The ID of the channel to send the message to. + message : str, optional + The message to send. + embeds : list[discord.Embed], optional + A list of embeds to send. + + Returns + ------- + discord.Message + The message that was sent. + """ + channel = await self.fetch_channel(channel_id) + return await channel.send(content=message, embeds=embeds) + def start_threaded(self): try: # Login the bot in a separate thread diff --git a/src/discord/cogs/base_commands.py b/src/discord/cogs/base_commands.py index 99734d8..94b42f1 100644 --- a/src/discord/cogs/base_commands.py +++ b/src/discord/cogs/base_commands.py @@ -3,7 +3,7 @@ from discord.commands import Option # local imports -from src.common import avatar, bot_name, org_name, version +from src.common.common import avatar, bot_name, org_name, version from src.discord.views import DonateCommandView from src.discord import cogs_common diff --git a/src/discord/cogs/fun_commands.py b/src/discord/cogs/fun_commands.py index 98e53f2..820fd1d 100644 --- a/src/discord/cogs/fun_commands.py +++ b/src/discord/cogs/fun_commands.py @@ -7,7 +7,7 @@ import requests # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name from src.discord.views import RefundCommandView from src.discord import cogs_common diff --git a/src/discord/cogs/github_commands.py b/src/discord/cogs/github_commands.py new file mode 100644 index 0000000..2054ab4 --- /dev/null +++ b/src/discord/cogs/github_commands.py @@ -0,0 +1,123 @@ +# standard imports +import os + +# lib imports +import discord +import requests +from requests_oauthlib import OAuth2Session + + +class GitHubCommandsCog(discord.Cog): + def __init__(self, bot): + self.bot = bot + self.token = os.getenv("GITHUB_TOKEN") + self.org_name = os.getenv("GITHUB_ORG_NAME", "LizardByte") + self.graphql_url = "https://api.github.com/graphql" + self.headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + @discord.slash_command( + name="get_sponsors", + description="Get list of GitHub sponsors", + default_member_permissions=discord.Permissions(manage_guild=True), + ) + async def get_sponsors( + self, + ctx: discord.ApplicationContext, + ): + """ + Get list of GitHub sponsors. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + query = """ + query { + organization(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + edges { + node { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + name + monthlyPriceInDollars + } + } + } + } + } + } + """ % self.org_name + + response = requests.post(self.graphql_url, json={'query': query}, headers=self.headers) + data = response.json() + + if 'errors' in data: + print(data['errors']) + await ctx.respond("An error occurred while fetching sponsors.", ephemeral=True) + return + + message = "List of GitHub sponsors" + for edge in data['data']['organization']['sponsorshipsAsMaintainer']['edges']: + sponsor = edge['node']['sponsorEntity'] + tier = edge['node'].get('tier', {}) + tier_info = f" - Tier: {tier.get('name', 'N/A')} (${tier.get('monthlyPriceInDollars', 'N/A')}/month)" + message += f"\n* [{sponsor['login']}]({sponsor['url']}){tier_info}" + + embed = discord.Embed(title="GitHub Sponsors", color=0x00ff00, description=message) + + await ctx.respond(embed=embed, ephemeral=True) + + @discord.slash_command( + name="link_github", + description="Validate GitHub sponsor status" + ) + async def link_github(self, ctx: discord.ApplicationContext): + """ + Link Discord account with GitHub account, by validating Discord user's "GitHub" connected account status. + + User to login to Discord via OAuth2, and check if their connected GitHub account is a sponsor of the project. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + discord_oauth = OAuth2Session( + os.environ['DISCORD_CLIENT_ID'], + redirect_uri=os.environ['DISCORD_REDIRECT_URI'], + scope=[ + "identify", + "connections", + ], + ) + authorization_url, state = discord_oauth.authorization_url("https://discord.com/oauth2/authorize") + + with self.bot.db as db: + db['oauth_states'] = db.get('oauth_states', {}) + db['oauth_states'][str(ctx.author.id)] = state + db.sync() + + # Store the state in the user's session or database + await ctx.respond(f"Please authorize the application by clicking [here]({authorization_url}).", ephemeral=True) + + +def setup(bot: discord.Bot): + bot.add_cog(GitHubCommandsCog(bot=bot)) diff --git a/src/discord/cogs/moderator_commands.py b/src/discord/cogs/moderator_commands.py index 2464b7d..2f88697 100644 --- a/src/discord/cogs/moderator_commands.py +++ b/src/discord/cogs/moderator_commands.py @@ -7,7 +7,7 @@ from discord.commands import Option # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name # constants recommended_channel_desc = 'Select the recommended channel' # hack for flake8 F722 diff --git a/src/discord/cogs/support_commands.py b/src/discord/cogs/support_commands.py index edb1502..8995d76 100644 --- a/src/discord/cogs/support_commands.py +++ b/src/discord/cogs/support_commands.py @@ -11,7 +11,7 @@ from mistletoe.markdown_renderer import MarkdownRenderer # local imports -from src.common import avatar, bot_name, data_dir +from src.common.common import avatar, bot_name, data_dir from src.discord.views import DocsCommandView from src.discord import cogs_common diff --git a/src/discord/tasks.py b/src/discord/tasks.py index d4249dd..15ef652 100644 --- a/src/discord/tasks.py +++ b/src/discord/tasks.py @@ -9,7 +9,7 @@ from igdb.wrapper import IGDBWrapper # local imports -from src.common import avatar, bot_name, bot_url +from src.common.common import avatar, bot_name, bot_url from src.discord.helpers import igdb_authorization, month_dictionary diff --git a/src/discord/views.py b/src/discord/views.py index 4435d8e..f5649d5 100644 --- a/src/discord/views.py +++ b/src/discord/views.py @@ -7,7 +7,7 @@ from discord.ui.button import Button # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name from src.discord.helpers import get_json from src.discord.modals import RefundModal diff --git a/src/keep_alive.py b/src/keep_alive.py deleted file mode 100644 index 74ab1c9..0000000 --- a/src/keep_alive.py +++ /dev/null @@ -1,20 +0,0 @@ -from flask import Flask -from threading import Thread -import os - -app = Flask('') - - -@app.route('/') -def main(): - return f"{os.environ['REPL_SLUG']} is live!" - - -def run(): - app.run(host="0.0.0.0", port=8080) - - -def keep_alive(): - server = Thread(name="Flask", target=run) - server.setDaemon(daemonic=True) - server.start() diff --git a/src/reddit/bot.py b/src/reddit/bot.py index 7520b9e..b1c3b82 100644 --- a/src/reddit/bot.py +++ b/src/reddit/bot.py @@ -1,18 +1,20 @@ # standard imports +import asyncio from datetime import datetime import os -import requests import shelve import sys import threading import time # lib imports +import discord import praw from praw import models # local imports -from src import common +from src.common import common +from src.common import globals class Bot: @@ -31,14 +33,7 @@ def __init__(self, **kwargs): self.user_agent = kwargs.get('user_agent', f'{common.bot_name} {self.version}') self.avatar = kwargs.get('avatar', common.get_bot_avatar(gravatar=os.environ['GRAVATAR_EMAIL'])) self.subreddit_name = kwargs.get('subreddit', os.getenv('PRAW_SUBREDDIT', 'LizardByte')) - - if not kwargs.get('redirect_uri', None): - try: # for running in replit - self.redirect_uri = f'https://{os.environ["REPL_SLUG"]}.{os.environ["REPL_OWNER"].lower()}.repl.co' - except KeyError: - self.redirect_uri = os.getenv('REDIRECT_URI', 'http://localhost:8080') - else: - self.redirect_uri = kwargs['redirect_uri'] + self.redirect_uri = kwargs.get('redirect_uri', os.getenv('REDIRECT_URI', 'http://localhost:8080')) # directories self.data_dir = common.data_dir @@ -66,7 +61,7 @@ def __init__(self, **kwargs): @staticmethod def validate_env() -> bool: required_env = [ - 'DISCORD_WEBHOOK', + 'DISCORD_REDDIT_CHANNEL_ID', 'PRAW_CLIENT_ID', 'PRAW_CLIENT_SECRET', 'REDDIT_PASSWORD', @@ -141,7 +136,7 @@ def process_submission(self, submission: models.Submission): print(f'submission id: {submission.id}') print(f'submission title: {submission.title}') print('---------') - if os.getenv('DISCORD_WEBHOOK'): + if os.getenv('DISCORD_REDDIT_CHANNEL_ID'): self.discord(submission=submission) self.flair(submission=submission) self.karma(submission=submission) @@ -175,37 +170,33 @@ def discord(self, submission: models.Submission): submission_time = datetime.fromtimestamp(submission.created_utc) - # create the discord message - # todo: use the running discord bot, directly instead of using a webhook - discord_webhook = { - 'username': 'LizardByte-Bot', - 'avatar_url': self.avatar, - 'embeds': [ - { - 'author': { - 'name': str(submission.author), - 'url': f'https://www.reddit.com/user/{submission.author}', - 'icon_url': str(redditor.icon_img) - }, - 'title': str(submission.title), - 'url': str(submission.url), - 'description': str(submission.selftext), - 'color': color, - 'thumbnail': { - 'url': 'https://www.redditstatic.com/desktop2x/img/snoo_discovery@1x.png' - }, - 'footer': { - 'text': f'Posted on r/{self.subreddit_name} at {submission_time}', - 'icon_url': 'https://www.redditstatic.com/desktop2x/img/favicon/favicon-32x32.png' - } - } - ] - } - - # actually send the message - r = requests.post(os.environ['DISCORD_WEBHOOK'], json=discord_webhook) - - if r.status_code == 204: # successful completion of request, no additional content + # create the discord embed + embed = discord.Embed( + author=discord.EmbedAuthor( + name=str(submission.author), + url=f'https://www.reddit.com/user/{submission.author}', + icon_url=str(redditor.icon_img), + ), + title=submission.title, + url=submission.url, + description=submission.selftext, + color=color, + thumbnail='https://www.redditstatic.com/desktop2x/img/snoo_discovery@1x.png', + footer=discord.EmbedFooter( + text=f'Posted on r/{self.subreddit_name} at {submission_time}', + icon_url='https://www.redditstatic.com/desktop2x/img/favicon/favicon-32x32.png' + ) + ) + + # actually send the embed + message = asyncio.run_coroutine_threadsafe( + globals.DISCORD_BOT.send_message_to_channel( + channel_id=os.getenv("DISCORD_REDDIT_CHANNEL_ID"), + embeds=[embed], + ), globals.DISCORD_BOT.loop) + message = message.result() # wait for the message to be sent + + if message: with self.lock, shelve.open(self.db) as db: # the shelve doesn't update unless we recreate the main key submissions = db['submissions'] diff --git a/tests/unit/discord/test_discord_bot.py b/tests/unit/discord/test_discord_bot.py index 500722c..b4aeb1b 100644 --- a/tests/unit/discord/test_discord_bot.py +++ b/tests/unit/discord/test_discord_bot.py @@ -6,7 +6,7 @@ import pytest_asyncio # local imports -from src import common +from src.common import common from src.discord import bot as discord_bot diff --git a/tests/unit/reddit/test_reddit_bot.py b/tests/unit/reddit/test_reddit_bot.py index 8ff1a84..8f3196b 100644 --- a/tests/unit/reddit/test_reddit_bot.py +++ b/tests/unit/reddit/test_reddit_bot.py @@ -161,7 +161,7 @@ def _submission(self, bot, recorder): def test_validate_env(self, bot): with patch.dict( os.environ, { - "DISCORD_WEBHOOK": "test", + "DISCORD_REDDIT_CHANNEL_ID": "test", "PRAW_CLIENT_ID": "test", "PRAW_CLIENT_SECRET": "test", "REDDIT_PASSWORD": "test",