diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 472e096..58be6ac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,10 +41,16 @@ jobs: - name: Test with pytest id: test env: + CI_EVENT_ID: ${{ github.event.number || github.sha }} GITHUB_PYTEST: "true" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} DISCORD_BOT_TOKEN: ${{ secrets.DISCORD_TEST_BOT_TOKEN }} - DISCORD_WEBHOOK: ${{ secrets.DISCORD_TEST_BOT_WEBHOOK }} + DISCORD_GITHUB_STATUS_CHANNEL_ID: ${{ vars.DISCORD_GITHUB_STATUS_CHANNEL_ID }} + DISCORD_REDDIT_CHANNEL_ID: ${{ vars.DISCORD_REDDIT_CHANNEL_ID }} + DISCORD_SPONSORS_CHANNEL_ID: ${{ vars.DISCORD_SPONSORS_CHANNEL_ID }} GRAVATAR_EMAIL: ${{ secrets.GRAVATAR_EMAIL }} + IGDB_CLIENT_ID: ${{ secrets.TWITCH_CLIENT_ID }} + IGDB_CLIENT_SECRET: ${{ secrets.TWITCH_CLIENT_SECRET }} PRAW_CLIENT_ID: ${{ secrets.REDDIT_CLIENT_ID }} PRAW_CLIENT_SECRET: ${{ secrets.REDDIT_CLIENT_SECRET }} REDDIT_USERNAME: ${{ secrets.REDDIT_USERNAME }} diff --git a/Dockerfile b/Dockerfile index 8cf0ff5..9f14448 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,19 +17,26 @@ ENV COMMIT=${COMMIT} ARG DAILY_TASKS=true ARG DAILY_RELEASES=true ARG DAILY_TASKS_UTC_HOUR=12 +ARG DISCORD_GITHUB_STATUS_CHANNEL_ID +ARG DISCORD_REDDIT_CHANNEL_ID +ARG DISCORD_SPONSORS_CHANNEL_ID # Secret config -ARG DISCORD_BOT_TOKEN ARG DAILY_CHANNEL_ID +ARG DISCORD_BOT_TOKEN +ARG DISCORD_CLIENT_ID +ARG DISCORD_CLIENT_SECRET +ARG DISCORD_REDIRECT_URI +ARG GITHUB_CLIENT_ID +ARG GITHUB_CLIENT_SECRET +ARG GITHUB_REDIRECT_URI +ARG GITHUB_WEBHOOK_SECRET_KEY ARG GRAVATAR_EMAIL ARG IGDB_CLIENT_ID ARG IGDB_CLIENT_SECRET ARG PRAW_CLIENT_ID ARG PRAW_CLIENT_SECRET ARG PRAW_SUBREDDIT -ARG DISCORD_WEBHOOK -ARG GRAVATAR_EMAIL -ARG REDIRECT_URI # Environment variables ENV DAILY_TASKS=$DAILY_TASKS @@ -37,6 +44,16 @@ ENV DAILY_RELEASES=$DAILY_RELEASES ENV DAILY_CHANNEL_ID=$DAILY_CHANNEL_ID ENV DAILY_TASKS_UTC_HOUR=$DAILY_TASKS_UTC_HOUR ENV DISCORD_BOT_TOKEN=$DISCORD_BOT_TOKEN +ENV DISCORD_CLIENT_ID=$DISCORD_CLIENT_ID +ENV DISCORD_CLIENT_SECRET=$DISCORD_CLIENT_SECRET +ENV DISCORD_GITHUB_STATUS_CHANNEL_ID=$DISCORD_GITHUB_STATUS_CHANNEL_ID +ENV DISCORD_REDDIT_CHANNEL_ID=$DISCORD_REDDIT_CHANNEL_ID +ENV DISCORD_REDIRECT_URI=$DISCORD_REDIRECT_URI +ENV DISCORD_SPONSORS_CHANNEL_ID=$DISCORD_SPONSORS_CHANNEL_ID +ENV GITHUB_CLIENT_ID=$GITHUB_CLIENT_ID +ENV GITHUB_CLIENT_SECRET=$GITHUB_CLIENT_SECRET +ENV GITHUB_REDIRECT_URI=$GITHUB_REDIRECT_URI +ENV GITHUB_WEBHOOK_SECRET_KEY=$GITHUB_WEBHOOK_SECRET_KEY ENV GRAVATAR_EMAIL=$GRAVATAR_EMAIL ENV IGDB_CLIENT_ID=$IGDB_CLIENT_ID ENV IGDB_CLIENT_SECRET=$IGDB_CLIENT_SECRET @@ -44,8 +61,6 @@ ENV PRAW_CLIENT_ID=$PRAW_CLIENT_ID ENV PRAW_CLIENT_SECRET=$PRAW_CLIENT_SECRET ENV PRAW_SUBREDDIT=$PRAW_SUBREDDIT ENV DISCORD_WEBHOOK=$DISCORD_WEBHOOK -ENV GRAVATAR_EMAIL=$GRAVATAR_EMAIL -ENV REDIRECT_URI=$REDIRECT_URI SHELL ["/bin/bash", "-o", "pipefail", "-c"] # install dependencies @@ -69,7 +84,7 @@ RUN <<_SETUP set -e # replace the version in the code -sed -i "s/version = '0.0.0'/version = '${BUILD_VERSION}'/g" src/common.py +sed -i "s/version = '0.0.0'/version = '${BUILD_VERSION}'/g" src/common/common.py # install dependencies python -m pip install --no-cache-dir -r requirements.txt diff --git a/README.md b/README.md index 2d1f458..bf63759 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ [![GitHub Workflow Status (CI)](https://img.shields.io/github/actions/workflow/status/lizardbyte/support-bot/ci.yml.svg?branch=master&label=CI%20build&logo=github&style=for-the-badge)](https://github.com/LizardByte/support-bot/actions/workflows/ci.yml?query=branch%3Amaster) [![Codecov](https://img.shields.io/codecov/c/gh/LizardByte/support-bot.svg?token=900Q93P1DE&style=for-the-badge&logo=codecov&label=codecov)](https://app.codecov.io/gh/LizardByte/support-bot) -Support bot written in python to help manage LizardByte communities. The current focus is discord and reddit, but other -platforms such as GitHub discussions/issues could be added. +Support bot written in python to help manage LizardByte communities. The current focus is Discord and Reddit, but other +platforms such as GitHub discussions/issues might be added in the future. ## Overview @@ -28,23 +28,28 @@ platforms such as GitHub discussions/issues could be added. :exclamation: if using Docker these can be arguments. :warning: Never publicly expose your tokens, secrets, or ids. -| variable | required | default | description | -|-------------------------|----------|------------------------------------------------------|---------------------------------------------------------------| -| DISCORD_BOT_TOKEN | True | `None` | Token from Bot page on discord developer portal. | -| DAILY_TASKS | False | `true` | Daily tasks on or off. | -| DAILY_RELEASES | False | `true` | Send a message for each game released on this day in history. | -| DAILY_CHANNEL_ID | False | `None` | Required if daily_tasks is enabled. | -| DAILY_TASKS_UTC_HOUR | False | `12` | The hour to run daily tasks. | -| GRAVATAR_EMAIL | False | `None` | Gravatar email address for bot avatar. | -| IGDB_CLIENT_ID | False | `None` | Required if daily_releases is enabled. | -| IGDB_CLIENT_SECRET | False | `None` | Required if daily_releases is enabled. | -| SUPPORT_COMMANDS_REPO | False | `https://github.com/LizardByte/support-bot-commands` | Repository for support commands. | -| SUPPORT_COMMANDS_BRANCH | False | `master` | Branch for support commands. | - -* Running bot: - * `python -m src` -* Invite bot to server: - * `https://discord.com/api/oauth2/authorize?client_id=&permissions=8&scope=bot%20applications.commands` +| variable | required | default | description | +|----------------------------------|----------|------------------------------------------------------|---------------------------------------------------------------| +| DISCORD_BOT_TOKEN | True | `None` | Token from Bot page on discord developer portal. | +| DISCORD_CLIENT_ID | True | `None` | Discord OAuth2 client id. | +| DISCORD_CLIENT_SECRET | True | `None` | Discord OAuth2 client secret. | +| DISCORD_GITHUB_STATUS_CHANNEL_ID | True | `None` | Channel ID to send GitHub status updates to. | +| DISCORD_REDDIT_CHANNEL_ID | True | `None` | Channel ID to send Reddit post updates to. | +| DISCORD_REDIRECT_URI | False | `https://localhost:8080/discord/callback` | The redirect uri for OAuth2. Must be publicly accessible. | +| DISCORD_SPONSORS_CHANNEL_ID | True | `None` | Channel ID to send sponsorship updates to. | +| GITHUB_CLIENT_ID | True | `None` | GitHub OAuth2 client id. | +| GITHUB_CLIENT_SECRET | True | `None` | GitHub OAuth2 client secret. | +| GITHUB_REDIRECT_URI | False | `https://localhost:8080/github/callback` | The redirect uri for OAuth2. Must be publicly accessible. | +| GITHUB_WEBHOOK_SECRET_KEY | True | `None` | A secret value to ensure webhooks are from trusted sources. | +| DAILY_TASKS | False | `true` | Daily tasks on or off. | +| DAILY_RELEASES | False | `true` | Send a message for each game released on this day in history. | +| DAILY_CHANNEL_ID | False | `None` | Required if daily_tasks is enabled. | +| DAILY_TASKS_UTC_HOUR | False | `12` | The hour to run daily tasks. | +| GRAVATAR_EMAIL | False | `None` | Gravatar email address for bot avatar. | +| IGDB_CLIENT_ID | False | `None` | Required if daily_releases is enabled. | +| IGDB_CLIENT_SECRET | False | `None` | Required if daily_releases is enabled. | +| SUPPORT_COMMANDS_REPO | False | `https://github.com/LizardByte/support-bot-commands` | Repository for support commands. | +| SUPPORT_COMMANDS_BRANCH | False | `master` | Branch for support commands. | ### Reddit @@ -62,7 +67,13 @@ platforms such as GitHub discussions/issues could be added. | DISCORD_WEBHOOK | False | None | URL of webhook to send discord notifications to | | GRAVATAR_EMAIL | False | None | Gravatar email address to get avatar from | | REDDIT_USERNAME | True | None | Reddit username | -* | REDDIT_PASSWORD | True | None | Reddit password | + | REDDIT_PASSWORD | True | None | Reddit password | + +### Start -* Running bot: - * `python -m src` +```bash +python -m src +``` + +* Invite bot to server: + * `https://discord.com/api/oauth2/authorize?client_id=&permissions=8&scope=bot%20applications.commands` diff --git a/assets/favicon.ico b/assets/favicon.ico new file mode 100644 index 0000000..79620bf Binary files /dev/null and b/assets/favicon.ico differ diff --git a/requirements-dev.txt b/requirements-dev.txt index 8efeb9b..32ab25f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,4 @@ betamax-serializers==0.2.1 pytest==8.3.3 pytest-asyncio==0.24.0 pytest-cov==6.0.0 +pytest-mock==3.14.0 diff --git a/requirements.txt b/requirements.txt index d3f6ce3..e28e05a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +cryptography==43.0.3 Flask==3.1.0 GitPython==3.1.43 igdb-api-v4==0.3.3 @@ -7,3 +8,4 @@ praw==7.8.1 py-cord==2.6.1 python-dotenv==1.0.1 requests==2.32.3 +requests-oauthlib==2.0.0 diff --git a/src/__main__.py b/src/__main__.py index 7968744..5411ceb 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,40 +1,33 @@ # standard imports -import os import time # development imports from dotenv import load_dotenv load_dotenv(override=False) # environment secrets take priority over .env file -# local imports -if True: # hack for flake8 - from src.discord import bot as d_bot - from src import keep_alive - from src.reddit import bot as r_bot +# local imports, import after env loaded +from src.common import globals # noqa: E402 +from src.discord import bot as d_bot # noqa: E402 +from src.common import webapp # noqa: E402 +from src.reddit import bot as r_bot # noqa: E402 def main(): - # to run in replit - try: - os.environ['REPL_SLUG'] - except KeyError: - pass # not running in replit - else: - keep_alive.keep_alive() # Start the web server + webapp.start() # Start the web server - discord_bot = d_bot.Bot() - discord_bot.start_threaded() # Start the discord bot + globals.DISCORD_BOT = d_bot.Bot() + globals.DISCORD_BOT.start_threaded() # Start the discord bot - reddit_bot = r_bot.Bot() - reddit_bot.start_threaded() # Start the reddit bot + globals.REDDIT_BOT = r_bot.Bot() + globals.REDDIT_BOT.start_threaded() # Start the reddit bot try: - while discord_bot.bot_thread.is_alive() or reddit_bot.bot_thread.is_alive(): + while globals.DISCORD_BOT.bot_thread.is_alive() or globals.REDDIT_BOT.bot_thread.is_alive(): time.sleep(0.5) except KeyboardInterrupt: print("Keyboard Interrupt Detected") - discord_bot.stop() - reddit_bot.stop() + globals.DISCORD_BOT.stop() + globals.REDDIT_BOT.stop() if __name__ == '__main__': # pragma: no cover diff --git a/src/common/__init__.py b/src/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/common.py b/src/common/common.py similarity index 79% rename from src/common.py rename to src/common/common.py index ef33e84..7997a13 100644 --- a/src/common.py +++ b/src/common/common.py @@ -7,6 +7,17 @@ import requests +colors = { + 'black': 0x000000, + 'green': 0x00ff00, + 'orange': 0xffa500, + 'purple': 0x9147ff, + 'red': 0xff0000, + 'white': 0xffffff, + 'yellow': 0xffff00, +} + + def get_bot_avatar(gravatar: str) -> str: """ Get Gravatar image url. @@ -36,15 +47,17 @@ def get_avatar_bytes(): return avatar_img -def get_data_dir(): +def get_app_dirs(): # parent directory name of this file, not full path - parent_dir = os.path.dirname(os.path.abspath(__file__)).split(os.sep)[-2] + parent_dir = os.path.dirname(os.path.abspath(__file__)).split(os.sep)[-3] if parent_dir == 'app': # running in Docker container + a = '/app' d = '/data' else: # running locally + a = os.getcwd() d = os.path.join(os.getcwd(), 'data') os.makedirs(d, exist_ok=True) - return d + return a, d # constants @@ -52,5 +65,5 @@ def get_data_dir(): org_name = 'LizardByte' bot_name = f'{org_name}-Bot' bot_url = 'https://app.lizardbyte.dev' -data_dir = get_data_dir() +app_dir, data_dir = get_app_dirs() version = '0.0.0' diff --git a/src/common/crypto.py b/src/common/crypto.py new file mode 100644 index 0000000..a59cf77 --- /dev/null +++ b/src/common/crypto.py @@ -0,0 +1,69 @@ +# standard imports +import os + +# lib imports +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption +from datetime import datetime, timedelta, UTC + +# local imports +from src.common import common + +CERT_FILE = os.path.join(common.data_dir, "cert.pem") +KEY_FILE = os.path.join(common.data_dir, "key.pem") + + +def check_expiration(cert_path: str) -> int: + with open(cert_path, "rb") as cert_file: + cert_data = cert_file.read() + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + expiry_date = cert.not_valid_after_utc + return (expiry_date - datetime.now(UTC)).days + + +def generate_certificate(): + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=4096, + ) + subject = issuer = x509.Name([ + x509.NameAttribute(x509.NameOID.COMMON_NAME, u"localhost"), + ]) + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + private_key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.now(UTC) + ).not_valid_after( + datetime.now(UTC) + timedelta(days=365) + ).sign(private_key, hashes.SHA256()) + + with open(KEY_FILE, "wb") as f: + f.write(private_key.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=NoEncryption(), + )) + + with open(CERT_FILE, "wb") as f: + f.write(cert.public_bytes(Encoding.PEM)) + + +def initialize_certificate() -> tuple[str, str]: + print("Initializing SSL certificate") + if os.path.exists(CERT_FILE) and os.path.exists(KEY_FILE): + cert_expires_in = check_expiration(CERT_FILE) + print(f"Certificate expires in {cert_expires_in} days.") + if cert_expires_in >= 90: + return CERT_FILE, KEY_FILE + print("Generating new certificate") + generate_certificate() + return CERT_FILE, KEY_FILE diff --git a/src/common/database.py b/src/common/database.py new file mode 100644 index 0000000..b27fdd4 --- /dev/null +++ b/src/common/database.py @@ -0,0 +1,22 @@ +# standard imports +import shelve +import threading + + +class Database: + def __init__(self, db_path): + self.db_path = db_path + self.lock = threading.Lock() + + def __enter__(self): + self.lock.acquire() + self.db = shelve.open(self.db_path, writeback=True) + return self.db + + def __exit__(self, exc_type, exc_val, exc_tb): + self.sync() + self.db.close() + self.lock.release() + + def sync(self): + self.db.sync() diff --git a/src/common/globals.py b/src/common/globals.py new file mode 100644 index 0000000..f185cab --- /dev/null +++ b/src/common/globals.py @@ -0,0 +1,2 @@ +DISCORD_BOT = None +REDDIT_BOT = None diff --git a/src/common/sponsors.py b/src/common/sponsors.py new file mode 100644 index 0000000..56b342a --- /dev/null +++ b/src/common/sponsors.py @@ -0,0 +1,73 @@ +# standard imports +import os +from typing import Union + +# lib imports +import requests + + +tier_map = { + 't4-sponsors': 15, + 't3-sponsors': 10, + 't2-sponsors': 5, + 't1-sponsors': 3, +} + + +def get_github_sponsors() -> Union[dict, False]: + """ + Get list of GitHub sponsors. + + Returns + ------- + Union[dict, False] + JSON response containing the list of sponsors. False if an error occurred. + """ + token = os.getenv("GITHUB_TOKEN") + org_name = os.getenv("GITHUB_ORG_NAME", "LizardByte") + + graphql_url = "https://api.github.com/graphql" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + query = """ + query { + organization(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + edges { + node { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + name + monthlyPriceInDollars + } + } + } + } + } + } + """ % org_name + + response = requests.post(graphql_url, json={'query': query}, headers=headers) + data = response.json() + + if 'errors' in data or 'message' in data: + print(data) + print('::error::An error occurred while fetching sponsors.') + return False + + return data diff --git a/src/common/time.py b/src/common/time.py new file mode 100644 index 0000000..d0f0f86 --- /dev/null +++ b/src/common/time.py @@ -0,0 +1,19 @@ +# standard imports +import datetime + + +def iso_to_datetime(iso_str): + """ + Convert an ISO 8601 string to a datetime object. + + Parameters + ---------- + iso_str : str + The ISO 8601 string to convert. + + Returns + ------- + datetime.datetime + The datetime object. + """ + return datetime.datetime.fromisoformat(iso_str) diff --git a/src/common/webapp.py b/src/common/webapp.py new file mode 100644 index 0000000..921fcbd --- /dev/null +++ b/src/common/webapp.py @@ -0,0 +1,338 @@ +# standard imports +import asyncio +import html +import os +from threading import Thread +from typing import Tuple + +# lib imports +import discord +from flask import Flask, jsonify, redirect, request, Response, send_from_directory +from requests_oauthlib import OAuth2Session +from werkzeug.middleware.proxy_fix import ProxyFix + +# local imports +from src.common.common import app_dir, colors +from src.common import crypto +from src.common import globals +from src.common import time + + +DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") +DISCORD_CLIENT_SECRET = os.getenv("DISCORD_CLIENT_SECRET") +DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", "https://localhost:8080/discord/callback") + +GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") +GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") +GITHUB_REDIRECT_URI = os.getenv("GITHUB_REDIRECT_URI", "https://localhost:8080/github/callback") + +app = Flask( + import_name='LizardByte-bot', + static_folder=os.path.join(app_dir, 'assets'), +) + +# this allows us to log the real IP address of the client, instead of the IP address of the proxy host +app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) + + +status_colors_map = { + 'investigating': colors['red'], + 'identified': colors['orange'], + 'monitoring': colors['yellow'], + 'resolved': colors['green'], + 'operational': colors['green'], + 'major_outage': colors['red'], + 'partial_outage': colors['orange'], + 'degraded_performance': colors['yellow'], +} + + +def html_to_md(html: str) -> str: + """ + Convert HTML to markdown. + + Parameters + ---------- + html : str + The HTML string to convert to markdown. + + Returns + ------- + str + The markdown string. + """ + replacements = { + '
': '\n', + '
': '\n', + '
': '\n', + '': '**', + '': '**', + } + + for old, new in replacements.items(): + html = html.replace(old, new) + + return html + + +@app.route('/status') +def status(): + return "LizardByte-bot is live!" + + +@app.route("/favicon.ico") +def favicon(): + return send_from_directory( + directory=app.static_folder, + path="favicon.ico", + mimetype="image/vnd.microsoft.icon", + ) + + +@app.route("/discord/callback") +def discord_callback(): + # errors will be in the query parameters + if 'error' in request.args: + return Response(html.escape(request.args['error_description']), status=400) + + # get all active states from the global state manager + with globals.DISCORD_BOT.db as db: + active_states = db['oauth_states'] + + discord_oauth = OAuth2Session(DISCORD_CLIENT_ID, redirect_uri=DISCORD_REDIRECT_URI) + token = discord_oauth.fetch_token( + token_url="https://discord.com/api/oauth2/token", + client_secret=DISCORD_CLIENT_SECRET, + authorization_response=request.url + ) + + # Fetch the user's Discord profile + response = discord_oauth.get( + url="https://discord.com/api/users/@me", + headers={ + "Accept": "application/json", + "Authorization": f"Bearer {token['access_token']}", + }, + ) + discord_user = response.json() + + # if the user is not in the active states, return an error + if discord_user['id'] not in active_states: + globals.DISCORD_BOT.update_cached_message( + author_id=discord_user['id'], + reason='failure', + ) + return Response("Invalid state", status=400) + + # remove the user from the active states + del active_states[discord_user['id']] + + # Fetch the user's connected accounts + connections_response = discord_oauth.get("https://discord.com/api/users/@me/connections") + connections = connections_response.json() + + with globals.DISCORD_BOT.db as db: + db['discord_users'] = db.get('discord_users', {}) + db['discord_users'][discord_user['id']] = { + 'discord_username': discord_user['username'], + 'discord_global_name': discord_user['global_name'], + 'github_id': None, + 'github_username': None, + } + + for connection in connections: + if connection['type'] == 'github': + db['discord_users'][discord_user['id']]['github_id'] = connection['id'] + db['discord_users'][discord_user['id']]['github_username'] = connection['name'] + + globals.DISCORD_BOT.update_cached_message( + author_id=discord_user['id'], + reason='success', + ) + + # Redirect to our main website + return redirect("https://app.lizardbyte.dev") + + +@app.route("/github/callback") +def github_callback(): + # errors will be in the query parameters + if 'error' in request.args: + return Response(html.escape(request.args['error_description']), status=400) + + # the state is sent as a query parameter in the redirect URL + state = request.args.get('state') + + # get all active states from the global state manager + with globals.DISCORD_BOT.db as db: + active_states = db['oauth_states'] + + github_oauth = OAuth2Session(GITHUB_CLIENT_ID, redirect_uri=GITHUB_REDIRECT_URI) + token = github_oauth.fetch_token( + token_url="https://github.com/login/oauth/access_token", + client_secret=GITHUB_CLIENT_SECRET, + authorization_response=request.url + ) + + # Fetch the user's GitHub profile + response = github_oauth.get( + url="https://api.github.com/user", + headers={ + "Accept": "application/vnd.github.v3+json", + "Authorization": f"token {token['access_token']}", + }, + ) + github_user = response.json() + + # if the user is not in the active states, return an error + for discord_user_id, _state in active_states.items(): + if state == _state: + break + else: + return Response("Invalid state", status=400) + + # remove the user from the active states + del active_states[discord_user_id] + + # get discord user data + discord_user_future = asyncio.run_coroutine_threadsafe( + globals.DISCORD_BOT.fetch_user(int(discord_user_id)), + globals.DISCORD_BOT.loop + ) + discord_user = discord_user_future.result() + + with globals.DISCORD_BOT.db as db: + db['discord_users'] = db.get('discord_users', {}) + db['discord_users'][discord_user_id] = { + 'discord_username': discord_user.name, + 'discord_global_name': discord_user.global_name, + 'github_id': github_user['id'], + 'github_username': github_user['login'], + } + + globals.DISCORD_BOT.update_cached_message( + author_id=discord_user_id, + reason='success', + ) + + # Redirect to our main website + return redirect("https://app.lizardbyte.dev") + + +@app.route("/webhook//", methods=["POST"]) +def webhook(source: str, key: str) -> Tuple[Response, int]: + """ + Process webhooks from various sources. + + * GitHub sponsors: https://github.com/sponsors/LizardByte/dashboard/webhooks + * GitHub status: https://www.githubstatus.com + + Parameters + ---------- + source : str + The source of the webhook (e.g., 'github_sponsors', 'github_status'). + key : str + The secret key for the webhook. This must match an environment variable. + + Returns + ------- + flask.Response + Response to the webhook request + """ + valid_sources = [ + "github_sponsors", + "github_status", + ] + + if source not in valid_sources: + return jsonify({"status": "error", "message": "Invalid source"}), 400 + + if key != os.getenv("GITHUB_WEBHOOK_SECRET_KEY"): + return jsonify({"status": "error", "message": "Invalid key"}), 400 + + print(f"received webhook from {source}") + data = request.json + print(f"received webhook data: \n{data}") + + # process the webhook data + if source == "github_sponsors": + if data['action'] == "created": + embed = discord.Embed( + author=discord.EmbedAuthor( + name=data["sponsorship"]["sponsor"]["login"], + url=data["sponsorship"]["sponsor"]["url"], + icon_url=data["sponsorship"]["sponsor"]["avatar_url"], + ), + color=colors['green'], + timestamp=time.iso_to_datetime(data['sponsorship']['created_at']), + title="New GitHub Sponsor", + ) + globals.DISCORD_BOT.send_message( + channel_id=os.getenv("DISCORD_SPONSORS_CHANNEL_ID"), + embed=embed, + ) + + elif source == "github_status": + # https://support.atlassian.com/statuspage/docs/enable-webhook-notifications + + embed = discord.Embed( + title="GitHub Status Update", + description=data['page']['status_description'], + color=colors['green'], + ) + + # handle component updates + if 'component_update' in data: + component_update = data['component_update'] + component = data['component'] + embed = discord.Embed( + color=status_colors_map.get(component_update['new_status'], colors['orange']), + description=f"Status changed from {component_update['old_status']} to {component_update['new_status']}", + timestamp=time.iso_to_datetime(component_update['created_at']), + title=f"Component Update: {component['name']}", + ) + embed.add_field(name="Component ID", value=component['id']) + embed.add_field(name="Component Status", value=component['status']) + + # handle incident updates + if 'incident' in data: + incident = data['incident'] + try: + update = incident['incident_updates'][0] + except (IndexError, KeyError): + return jsonify({"status": "error", "message": "No incident updates"}), 400 + + embed = discord.Embed( + color=status_colors_map.get(update['status'], colors['orange']), + timestamp=time.iso_to_datetime(incident['created_at']), + title=f"Incident: {incident['name']}", + url=incident.get('shortlink', 'https://www.githubstatus.com'), + ) + embed.add_field(name="Level", value=incident['impact'], inline=False) + embed.add_field(name=update['status'], value=html_to_md(update['body']), inline=False) + + globals.DISCORD_BOT.send_message( + channel_id=os.getenv("DISCORD_GITHUB_STATUS_CHANNEL_ID"), + embed=embed, + ) + + return jsonify({"status": "success"}), 200 + + +def run(): + cert_file, key_file = crypto.initialize_certificate() + + app.run( + host="0.0.0.0", + port=8080, + ssl_context=(cert_file, key_file) + ) + + +def start(): + server = Thread( + name="Flask", + daemon=True, + target=run, + ) + server.start() diff --git a/src/discord/bot.py b/src/discord/bot.py index a9baf6c..390558d 100644 --- a/src/discord/bot.py +++ b/src/discord/bot.py @@ -2,13 +2,14 @@ import asyncio import os import threading +from typing import Literal, Optional # lib imports import discord # local imports -from src.common import bot_name, get_avatar_bytes, org_name -from src.discord.tasks import daily_task +from src.common.common import bot_name, data_dir, get_avatar_bytes, org_name +from src.common.database import Database from src.discord.views import DonateCommandView @@ -21,6 +22,9 @@ class Bot(discord.Bot): when the bot is ready. """ def __init__(self, *args, **kwargs): + # tasks need to be imported here to avoid circular imports + from src.discord import tasks + if 'intents' not in kwargs: intents = discord.Intents.all() kwargs['intents'] = intents @@ -30,6 +34,11 @@ def __init__(self, *args, **kwargs): self.bot_thread = threading.Thread(target=lambda: None) self.token = os.environ['DISCORD_BOT_TOKEN'] + self.db = Database(db_path=os.path.join(data_dir, 'discord_bot_database')) + self.ephemeral_db = {} + self.clean_ephemeral_cache = tasks.clean_ephemeral_cache + self.daily_task = tasks.daily_task + self.role_update_task = tasks.role_update_task self.load_extension( name='src.discord.cogs', @@ -37,12 +46,15 @@ def __init__(self, *args, **kwargs): store=False, ) + with self.db as db: + db['oauth_states'] = {} # clear any oauth states from previous sessions + async def on_ready(self): """ Bot on ready event. This function runs when the discord bot is ready. The function will update the bot presence, update the username - and avatar, and start daily tasks. + and avatar, and start tasks. """ print(f'py-cord version: {discord.__version__}') print(f'Logged in as {self.user.name} (ID: {self.user.id})') @@ -59,18 +71,197 @@ async def on_ready(self): self.add_view(DonateCommandView()) # register view for persistent listening - await self.sync_commands() + self.clean_ephemeral_cache.start(bot=self) + self.role_update_task.start(bot=self) try: os.environ['DAILY_TASKS'] except KeyError: - daily_task.start(bot=self) + self.daily_task.start(bot=self) else: if os.environ['DAILY_TASKS'].lower() == 'true': - daily_task.start(bot=self) + self.daily_task.start(bot=self) else: print("'DAILY_TASKS' environment variable is disabled") + await self.sync_commands() + + async def async_send_message( + self, + channel_id: int, + message: str = None, + embed: discord.Embed = None, + ) -> Optional[discord.Message]: + """ + Send a message to a specific channel asynchronously. If the embeds are too large, they will be shortened. + Additionally, if the total size of the embeds is too large, they will be sent in separate messages. + + Parameters + ---------- + channel_id : int + The ID of the channel to send the message to. + message : str, optional + The message to send. + embed : discord.Embed, optional + The embed to send. + + Returns + ------- + discord.Message + The message that was sent. + """ + # ensure we have a message or embeds to send + if not message and not embed: + return + + if embed and len(embed) > 6000: + cut_length = len(embed) - 6000 + 3 + embed.description = embed.description[:-cut_length] + "..." + if embed and embed.description and len(embed.description) > 4096: + cut_length = len(embed.description) - 4096 + 3 + embed.description = embed.description[:-cut_length] + "..." + + channel = await self.fetch_channel(channel_id) + return await channel.send(content=message, embed=embed) + + def send_message( + self, + channel_id: int, + message: str = None, + embed: discord.Embed = None, + ) -> discord.Message: + """ + Send a message to a specific channel synchronously. + + Parameters + ---------- + channel_id : int + The ID of the channel to send the message to. + message : str, optional + The message to send. + embed : discord.Embed, optional + The embed to send. + + Returns + ------- + discord.Message + The message that was sent. + """ + future = asyncio.run_coroutine_threadsafe( + self.async_send_message( + channel_id=channel_id, + message=message, + embed=embed, + ), self.loop) + return future.result() + + async def async_update_cached_message( + self, + author_id: int, + reason: str, + ) -> bool: + """ + Update the original message with the reason asynchronously. + + After the message is updated, it will be removed from the cache. + + Parameters + ---------- + author_id : int + Author ID to update the cache. + reason : str + Reason to update the cache. Must be one of the following: 'duplicate', 'failure', 'success', 'timeout'. + + Returns + ------- + bool + True if the message was updated, False otherwise. + """ + reasons = { + 'duplicate': "This request was invalidated due to a new request.", + 'failure': "An error occurred while linking your GitHub account.", + 'success': "Your GitHub account is now linked.", + 'timeout': "The request has timed out.", + } + + db = self.ephemeral_db + db['github_cache_context'] = db.get('github_cache_context', {}) + + if str(author_id) not in db['github_cache_context']: + return False + + await db['github_cache_context'][str(author_id)]['response'].edit( + content=reasons[reason], + ) + + # remove the context from the cache + del db['github_cache_context'][str(author_id)] + + return True + + def update_cached_message( + self, + author_id: int, + reason: str, + ) -> bool: + """ + Update the original message with the reason synchronously. + + After the message is updated, it will be removed from the cache. + + Parameters + ---------- + author_id : int + Author ID to update the cache. + reason : str + Reason to update the cache. Must be one of the following: 'duplicate', 'failure', 'success', 'timeout'. + + Returns + ------- + bool + True if the message was updated, False otherwise. + """ + future = asyncio.run_coroutine_threadsafe( + self.async_update_cached_message( + author_id=author_id, + reason=reason, + ), self.loop) + return future.result() + + def create_thread( + self, + message: discord.Message, + name: str, + auto_archive_duration: Literal[60, 1440, 4320, 10080] = discord.MISSING, + slowmode_delay: int = discord.MISSING, + ) -> discord.Thread: + """ + Create a thread from a message. + + Parameters + ---------- + message : discord.Message + The message to create the thread from. + name : str + The name of the thread. + auto_archive_duration : Literal[60, 1440, 4320, 10080], optional + The duration in minutes before the thread is automatically archived. + slowmode_delay : int, optional + The slowmode delay for the thread. + + Returns + ------- + discord.Thread + The thread that was created. + """ + future = asyncio.run_coroutine_threadsafe( + message.create_thread( + name=name, + auto_archive_duration=auto_archive_duration, + slowmode_delay=slowmode_delay, + ), self.loop) + return future.result() + def start_threaded(self): try: # Login the bot in a separate thread @@ -85,14 +276,12 @@ def start_threaded(self): self.stop() def stop(self, future: asyncio.Future = None): - print("Attempting to stop daily tasks") - daily_task.stop() + print("Attempting to stop tasks") + self.daily_task.stop() + self.role_update_task.stop() + self.clean_ephemeral_cache.stop() print("Attempting to close bot connection") if self.bot_thread is not None and self.bot_thread.is_alive(): asyncio.run_coroutine_threadsafe(self.close(), self.loop) self.bot_thread.join() print("Closed bot") - - # Set a result for the future to mark it as done (unit testing) - if future and not future.done(): - future.set_result(None) diff --git a/src/discord/cogs/base_commands.py b/src/discord/cogs/base_commands.py index 99734d8..fd87cf7 100644 --- a/src/discord/cogs/base_commands.py +++ b/src/discord/cogs/base_commands.py @@ -3,7 +3,7 @@ from discord.commands import Option # local imports -from src.common import avatar, bot_name, org_name, version +from src.common.common import avatar, bot_name, colors, org_name, version from src.discord.views import DonateCommandView from src.discord import cogs_common @@ -39,7 +39,7 @@ async def help_command( description += f"\n\nVersion: {version}\n" - embed = discord.Embed(description=description, color=0xE5A00D) + embed = discord.Embed(description=description, color=colors['orange']) embed.set_footer(text=bot_name, icon_url=avatar) await ctx.respond(embed=embed, ephemeral=True) diff --git a/src/discord/cogs/fun_commands.py b/src/discord/cogs/fun_commands.py index 98e53f2..395595d 100644 --- a/src/discord/cogs/fun_commands.py +++ b/src/discord/cogs/fun_commands.py @@ -7,7 +7,7 @@ import requests # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name, colors from src.discord.views import RefundCommandView from src.discord import cogs_common @@ -56,7 +56,7 @@ async def random_command( else: description = None - embed = discord.Embed(title=quote, description=description, color=0x00ff00) + embed = discord.Embed(title=quote, description=description, color=colors['green']) embed.set_footer(text=bot_name, icon_url=avatar) if user: @@ -91,7 +91,7 @@ async def refund_command( embed = discord.Embed(title="Refund request", description="Original purchase price: $0.00\n\n" "Select the button below to request a full refund!", - color=0xDC143C) + color=colors['red']) embed.set_footer(text=bot_name, icon_url=avatar) if user: diff --git a/src/discord/cogs/github_commands.py b/src/discord/cogs/github_commands.py new file mode 100644 index 0000000..0e85dad --- /dev/null +++ b/src/discord/cogs/github_commands.py @@ -0,0 +1,145 @@ +# standard imports +from datetime import datetime, timedelta, UTC +import os + +# lib imports +import discord +from requests_oauthlib import OAuth2Session + +# local imports +from src.common.common import colors +from src.common import sponsors + + +link_github_platform_description = 'Platform to link' # hack for flake8 F722 +link_github_platform_choices = [ # hack for flake8 F821 + "discord", + "github", +] + + +class GitHubCommandsCog(discord.Cog): + def __init__(self, bot): + self.bot = bot + + @discord.slash_command( + name="get_sponsors", + description="Get list of GitHub sponsors", + default_member_permissions=discord.Permissions(manage_guild=True), + ) + async def get_sponsors( + self, + ctx: discord.ApplicationContext, + ): + """ + Get list of GitHub sponsors. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + data = sponsors.get_github_sponsors() + + if not data: + await ctx.respond("An error occurred while fetching sponsors.", ephemeral=True) + return + + message = "List of GitHub sponsors" + for edge in data['data']['organization']['sponsorshipsAsMaintainer']['edges']: + sponsor = edge['node']['sponsorEntity'] + tier = edge['node'].get('tier', {}) + tier_info = f" - Tier: {tier.get('name', 'N/A')} (${tier.get('monthlyPriceInDollars', 'N/A')}/month)" + message += f"\n* [{sponsor['login']}]({sponsor['url']}){tier_info}" + + embed = discord.Embed(title="GitHub Sponsors", color=colors['green'], description=message) + + await ctx.respond(embed=embed, ephemeral=True) + + @discord.slash_command( + name="link_github", + description="Validate GitHub sponsor status" + ) + async def link_github( + self, + ctx: discord.ApplicationContext, + platform: discord.Option( + str, + description=link_github_platform_description, + choices=link_github_platform_choices, + required=True, + ), + ): + """ + Link Discord account with GitHub account. + + This works by authenticating to GitHub or to Discord and checking the user's "GitHub" connected account status. + + User to login via OAuth2. + If the Discord option is selected, then check if their connected GitHub account is a sponsor of the project. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + platform : str + Platform to link. + """ + platform_map = { + 'discord': { + 'auth_url': "https://discord.com/api/oauth2/authorize", + 'client_id': os.environ['DISCORD_CLIENT_ID'], + 'redirect_uri': os.environ['DISCORD_REDIRECT_URI'], + 'scope': [ + "identify", + "connections", + ], + }, + 'github': { + 'auth_url': "https://github.com/login/oauth/authorize", + 'client_id': os.environ['GITHUB_CLIENT_ID'], + 'redirect_uri': os.environ['GITHUB_REDIRECT_URI'], + 'scope': [ + "read:user", + ], + }, + } + + auth = OAuth2Session( + client_id=platform_map[platform]['client_id'], + redirect_uri=platform_map[platform]['redirect_uri'], + scope=platform_map[platform]['scope'], + ) + authorization_url, state = auth.authorization_url(platform_map[platform]['auth_url']) + + # Store the state in the user's session or database + with self.bot.db as db: + db['oauth_states'] = db.get('oauth_states', {}) + db['oauth_states'][str(ctx.author.id)] = state + db.sync() + + response = await ctx.respond( + f"Please authorize the application by clicking [here]({authorization_url}).", + ephemeral=True, + ) + + now = datetime.now(UTC) + db = self.bot.ephemeral_db + db['github_cache_context'] = db.get('github_cache_context', {}) + + # if there is a current context, update the original response on discord + if str(ctx.author.id) in db['github_cache_context']: + await self.bot.async_update_cached_message( + author_id=ctx.author.id, + reason='duplicate', + ) + + db['github_cache_context'][str(ctx.author.id)] = { + 'created_at': now, + 'expires_at': now + timedelta(seconds=300), + 'response': response, + } + + +def setup(bot: discord.Bot): + bot.add_cog(GitHubCommandsCog(bot=bot)) diff --git a/src/discord/cogs/moderator_commands.py b/src/discord/cogs/moderator_commands.py index 2464b7d..22ef8db 100644 --- a/src/discord/cogs/moderator_commands.py +++ b/src/discord/cogs/moderator_commands.py @@ -7,7 +7,7 @@ from discord.commands import Option # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name, colors # constants recommended_channel_desc = 'Select the recommended channel' # hack for flake8 F722 @@ -56,7 +56,7 @@ async def channel_command( embed = discord.Embed( title="Incorrect channel", description=f"Please move discussion to {recommended_channel.mention}", - color=0x00ff00, + color=colors['orange'], ) permission_ch_id = '' @@ -150,52 +150,63 @@ async def user_info_command( embed.set_author(name=user.name) embed.set_thumbnail(url=user.display_avatar.url) - if user.colour.value: # If user has a role with a color - embed.colour = user.colour + embed.colour = user.color if user.color.value else colors['white'] + + with self.bot.db as db: + user_data = db.get('discord_users', {}).get(str(user.id)) + if user_data and user_data.get('github_username'): + embed.add_field( + name="GitHub", + value=f"[{user_data['github_username']}](https://github.com/{user_data['github_username']})", + inline=False, + ) if isinstance(user, discord.User): # Checks if the user in the server embed.set_footer(text="This user is not in this server.") - else: # We end up here if the user is a discord.Member object - embed.add_field( - name="Joined Server at", - value=f'{discord.utils.format_dt(user.joined_at, "R")}\n' - f'{discord.utils.format_dt(user.joined_at, "F")}', - inline=False, - ) # When the user joined the server - - # get User Roles - roles = [role.name for role in user.roles] - roles.pop(0) # remove @everyone role - embed.add_field( - name="Server Roles", - value='\n'.join(roles) if roles else "No roles", - inline=False, - ) - - # get User Status, such as Server Owner, Server Moderator, Server Admin, etc. - user_status = [] - if user.guild.owner_id == user.id: - user_status.append("Server Owner") - if user.guild_permissions.administrator: - user_status.append("Server Admin") - if user.guild_permissions.manage_guild: - user_status.append("Server Moderator") - embed.add_field( - name="User Status", - value='\n'.join(user_status), - inline=False, - ) - - if user.premium_since: # If the user is boosting the server - boosting_value = (f'{discord.utils.format_dt(user.premium_since, "R")}\n' - f'{discord.utils.format_dt(user.premium_since, "F")}') - else: - boosting_value = "Not boosting" - embed.add_field( - name="Boosting Since", - value=boosting_value, - inline=False, - ) + await ctx.respond(embeds=[embed]) + return + + # We end up here if the user is a discord.Member object + embed.add_field( + name="Joined Server at", + value=f'{discord.utils.format_dt(user.joined_at, "R")}\n' + f'{discord.utils.format_dt(user.joined_at, "F")}', + inline=False, + ) # When the user joined the server + + # get User Roles + roles = [role.name for role in user.roles] + roles.pop(0) # remove @everyone role + embed.add_field( + name="Server Roles", + value='\n'.join(roles) if roles else "No roles", + inline=False, + ) + + # get User Status, such as Server Owner, Server Moderator, Server Admin, etc. + user_status = [] + if user.guild.owner_id == user.id: + user_status.append("Server Owner") + if user.guild_permissions.administrator: + user_status.append("Server Admin") + if user.guild_permissions.manage_guild: + user_status.append("Server Moderator") + embed.add_field( + name="User Status", + value='\n'.join(user_status), + inline=False, + ) + + if user.premium_since: # If the user is boosting the server + boosting_value = (f'{discord.utils.format_dt(user.premium_since, "R")}\n' + f'{discord.utils.format_dt(user.premium_since, "F")}') + else: + boosting_value = "Not boosting" + embed.add_field( + name="Boosting Since", + value=boosting_value, + inline=False, + ) await ctx.respond(embeds=[embed]) # Sends the embed diff --git a/src/discord/cogs/support_commands.py b/src/discord/cogs/support_commands.py index edb1502..ace82f8 100644 --- a/src/discord/cogs/support_commands.py +++ b/src/discord/cogs/support_commands.py @@ -11,7 +11,7 @@ from mistletoe.markdown_renderer import MarkdownRenderer # local imports -from src.common import avatar, bot_name, data_dir +from src.common.common import avatar, bot_name, colors, data_dir from src.discord.views import DocsCommandView from src.discord import cogs_common @@ -130,7 +130,7 @@ async def project_command(ctx: discord.ApplicationContext, command: str): f"{project}/{command}.md") embed = discord.Embed( - color=0xF1C232, + color=colors['yellow'], description=description, timestamp=datetime.datetime.now(tz=datetime.timezone.utc), title="See on GitHub", @@ -165,7 +165,7 @@ async def docs_command( user : discord.Member Username to mention in response. """ - embed = discord.Embed(title="Select a project", color=0xF1C232) + embed = discord.Embed(title="Select a project", color=colors['yellow']) embed.set_footer(text=bot_name, icon_url=avatar) if user: diff --git a/src/discord/tasks.py b/src/discord/tasks.py index d4249dd..8e63a1d 100644 --- a/src/discord/tasks.py +++ b/src/discord/tasks.py @@ -1,5 +1,7 @@ # standard imports -from datetime import datetime +import asyncio +import copy +from datetime import datetime, UTC import json import os @@ -9,165 +11,263 @@ from igdb.wrapper import IGDBWrapper # local imports -from src.common import avatar, bot_name, bot_url +from src.common.common import avatar, bot_name, bot_url, colors +from src.common import sponsors +from src.discord.bot import Bot from src.discord.helpers import igdb_authorization, month_dictionary +@tasks.loop(seconds=30) +async def clean_ephemeral_cache(bot: Bot) -> bool: + """ + Clean ephemeral messages in cache. + + This function runs on a schedule, every 30 seconds. + Check the ephemeral database for expired messages and delete them. + """ + for key, value in copy.deepcopy(bot.ephemeral_db.get('github_cache_context', {})).items(): + if value['expires_at'] < datetime.now(UTC): + bot.update_cached_message(author_id=int(key), reason='timeout') + + return True + + @tasks.loop(minutes=60.0) -async def daily_task(bot: discord.Bot): +async def daily_task(bot: Bot) -> bool: """ Run daily task loop. This function runs on a schedule, every 60 minutes. Create an embed and thread for each game released on this day in history (according to IGDB), if enabled. + + Returns + ------- + bool + True if the task ran successfully, False otherwise. + """ + date = datetime.now(UTC) + if date.hour != int(os.getenv(key='DAILY_TASKS_UTC_HOUR', default=12)): + return False + + daily_releases = True if os.getenv(key='DAILY_RELEASES', default='true').lower() == 'true' else False + if not daily_releases: + print("'DAILY_RELEASES' environment variable is disabled") + return False + + try: + channel_id = int(os.environ['DAILY_CHANNEL_ID']) + except KeyError: + print("'DAILY_CHANNEL_ID' not defined in environment variables.") + return False + + igdb_auth = igdb_authorization(client_id=os.environ['IGDB_CLIENT_ID'], + client_secret=os.environ['IGDB_CLIENT_SECRET']) + wrapper = IGDBWrapper(client_id=os.environ['IGDB_CLIENT_ID'], auth_token=igdb_auth['access_token']) + + end_point = 'release_dates' + fields = [ + 'human', + 'game.name', + 'game.summary', + 'game.url', + 'game.genres.name', + 'game.rating', + 'game.cover.url', + 'game.artworks.url', + 'game.platforms.name', + 'game.platforms.url' + ] + + where = f'human="{month_dictionary[date.month]} {date.day:02d}"*' + limit = 500 + query = f'fields {", ".join(fields)}; where {where}; limit {limit};' + + byte_array = bytes(wrapper.api_request(endpoint=end_point, query=query)) + json_result = json.loads(byte_array) + + game_ids = [] + + for game in json_result: + try: + game_id = game['game']['id'] + except KeyError: + continue + + if game_id in game_ids: + continue # do not repeat the same game... even though it could be a different platform + game_ids.append(game_id) + + try: + embed = discord.Embed( + title=game['game']['name'], + url=game['game']['url'], + description=game['game']['summary'][0:2000 - 1], + color=colors['purple'] + ) + except KeyError: + continue + + try: + rating = round(game['game']['rating'] / 20, 1) + embed.add_field( + name='Average Rating', + value=f'⭐{rating}', + inline=True + ) + except KeyError: + continue + if rating < 4.0: # reduce the number of messages per day + continue + + try: + embed.add_field( + name='Release Date', + value=game['human'], + inline=True + ) + except KeyError: + pass + + try: + embed.set_thumbnail(url=f"https:{game['game']['cover']['url'].replace('_thumb', '_original')}") + except KeyError: + pass + + try: + embed.set_image(url=f"https:{game['game']['artworks'][0]['url'].replace('_thumb', '_original')}") + except KeyError: + pass + + try: + platforms = ', '.join(platform['name'] for platform in game['game']['platforms']) + name = 'Platforms' if len(game['game']['platforms']) > 1 else 'Platform' + + embed.add_field( + name=name, + value=platforms, + inline=False + ) + except KeyError: + pass + + try: + genres = ', '.join(genre['name'] for genre in game['game']['genres']) + name = 'Genres' if len(game['game']['genres']) > 1 else 'Genre' + + embed.add_field( + name=name, + value=genres, + inline=False + ) + except KeyError: + pass + + embed.set_author( + name=bot_name, + url=bot_url, + icon_url=avatar + ) + + embed.set_footer( + text='Data provided by IGDB', + icon_url='https://www.igdb.com/favicon-196x196.png' + ) + + message = bot.send_message(channel_id=channel_id, embed=embed) + thread = bot.create_thread(message=message, name=embed.title) + + print(f'thread created: {thread.name}') + + return True + + +@tasks.loop(minutes=1.0) +async def role_update_task(bot: Bot) -> bool: """ - if datetime.utcnow().hour == int(os.getenv(key='DAILY_TASKS_UTC_HOUR', default=12)): - daily_releases = True if os.getenv(key='DAILY_RELEASES', default='true').lower() == 'true' else False - if not daily_releases: - print("'DAILY_RELEASES' environment variable is disabled") + Run the role update task. + + This function runs on a schedule, every 1 minute. + If the current time is not divisible by 10, return False. e.g. Run every 10 minutes. + + Returns + ------- + bool + True if the task ran successfully, False otherwise. + """ + if datetime.now(UTC).minute not in list(range(0, 60, 10)): + return False + + # check each user in the database for their GitHub sponsor status + with bot.db as db: + discord_users = db.get('discord_users', {}) + + if not discord_users: + return False + + github_sponsors = sponsors.get_github_sponsors() + + for user_id, user_data in discord_users.items(): + # get the currently revocable roles, to ensure we don't remove roles that were added by another integration + # i.e.; any role that was added by our bot is safe to remove + revocable_roles = user_data.get('roles', []).copy() + + # check if the user is a GitHub sponsor + for edge in github_sponsors['data']['organization']['sponsorshipsAsMaintainer']['edges']: + sponsor = edge['node']['sponsorEntity'] + if sponsor['login'] == user_data['github_username']: + # user is a sponsor + user_data['github_sponsor'] = True + + monthly_amount = edge['node'].get('tier', {}).get('monthlyPriceInDollars', 0) + + for tier, amount in sponsors.tier_map.items(): + if monthly_amount >= amount: + user_data['roles'] = [tier, 'supporters'] + break + else: + user_data['roles'] = [] + + break else: - try: - channel = bot.get_channel(int(os.environ['DAILY_CHANNEL_ID'])) - except KeyError: - print("'DAILY_CHANNEL_ID' not defined in environment variables.") - else: - igdb_auth = igdb_authorization(client_id=os.environ['IGDB_CLIENT_ID'], - client_secret=os.environ['IGDB_CLIENT_SECRET']) - wrapper = IGDBWrapper(client_id=os.environ['IGDB_CLIENT_ID'], auth_token=igdb_auth['access_token']) - - end_point = 'release_dates' - fields = [ - 'human', - 'game.name', - 'game.summary', - 'game.url', - 'game.genres.name', - 'game.rating', - 'game.cover.url', - 'game.artworks.url', - 'game.platforms.name', - 'game.platforms.url' - ] - - where = f'human="{month_dictionary[datetime.utcnow().month]} {datetime.utcnow().day:02d}"*' - limit = 500 - query = f'fields {", ".join(fields)}; where {where}; limit {limit};' - - byte_array = bytes(wrapper.api_request(endpoint=end_point, query=query)) - json_result = json.loads(byte_array) - - game_ids = [] - - for game in json_result: - color = 0x9147FF - - try: - game_id = game['game']['id'] - except KeyError: - continue - else: - if game_id not in game_ids: - game_ids.append(game_id) - else: # do not repeat the same game... even though it could be a different platform - continue - - try: - embed = discord.Embed( - title=game['game']['name'], - url=game['game']['url'], - description=game['game']['summary'][0:2000 - 1], - color=color - ) - except KeyError: - continue - - try: - embed.add_field( - name='Release Date', - value=game['human'], - inline=True - ) - except KeyError: - pass - - try: - rating = round(game['game']['rating'] / 20, 1) - embed.add_field( - name='Average Rating', - value=f'⭐{rating}', - inline=True - ) - - if rating < 4.0: # reduce number of messages per day - continue - except KeyError: - continue - - try: - embed.set_thumbnail( - url=f"https:{game['game']['cover']['url'].replace('_thumb', '_original')}" - ) - except KeyError: - pass - - try: - embed.set_image( - url=f"https:{game['game']['artworks'][0]['url'].replace('_thumb', '_original')}" - ) - except KeyError: - pass - - try: - platforms = '' - name = 'Platform' - - for platform in game['game']['platforms']: - if platforms: - platforms += ", " - name = 'Platforms' - platforms += platform['name'] - - embed.add_field( - name=name, - value=platforms, - inline=False - ) - except KeyError: - pass - - try: - genres = '' - name = 'Genre' - - for genre in game['game']['genres']: - if genres: - genres += ", " - name = 'Genres' - genres += genre['name'] - - embed.add_field( - name=name, - value=genres, - inline=False - ) - except KeyError: - pass - - try: - embed.set_author( - name=bot_name, - url=bot_url, - icon_url=avatar - ) - except KeyError: - pass - - embed.set_footer( - text='Data provided by IGDB', - icon_url='https://www.igdb.com/favicon-196x196.png' - ) - - message = await channel.send(embed=embed) - thread = await message.create_thread(name=embed.title) - - print(f'thread created: {thread.name}') + # user is not a sponsor + user_data['github_sponsor'] = False + user_data['roles'] = [] + + if user_data.get('github_username'): + user_data['roles'].append('github-users') + + # update the discord user roles + for g in bot.guilds: + roles = g.roles + + role_map = { + 'github-users': discord.utils.get(roles, name='github-users'), + 'supporters': discord.utils.get(roles, name='supporters'), + 't1-sponsors': discord.utils.get(roles, name='t1-sponsors'), + 't2-sponsors': discord.utils.get(roles, name='t2-sponsors'), + 't3-sponsors': discord.utils.get(roles, name='t3-sponsors'), + 't4-sponsors': discord.utils.get(roles, name='t4-sponsors'), + } + + user_roles = user_data['roles'] + + for user_role, role in role_map.items(): + member = g.get_member(int(user_id)) + role = role_map.get(user_role, None) + if not member or not role: + continue + + if user_role in user_roles: + # await member.add_roles(role) + add_future = asyncio.run_coroutine_threadsafe(member.add_roles(role), bot.loop) + add_future.result() + elif user_role in revocable_roles: + # await member.remove_roles(role) + remove_future = asyncio.run_coroutine_threadsafe(member.remove_roles(role), bot.loop) + remove_future.result() + + with bot.db as db: + db['discord_users'] = discord_users + db.sync() + + return True diff --git a/src/discord/views.py b/src/discord/views.py index 4435d8e..c756e1c 100644 --- a/src/discord/views.py +++ b/src/discord/views.py @@ -7,7 +7,7 @@ from discord.ui.button import Button # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name, colors from src.discord.helpers import get_json from src.discord.modals import RefundModal @@ -89,13 +89,13 @@ def check_completion_status(self) -> Tuple[bool, discord.Embed]: if complete: embed.title = self.docs_project embed.description = f'The selected docs are available at {url}' - embed.color = 0x39FF14 + embed.color = colors['green'] embed.url = url else: # info is not complete embed.title = "Select the remaining values" embed.description = None - embed.color = 0xF1C232 + embed.color = colors['orange'] embed.url = None return complete, embed @@ -113,7 +113,7 @@ async def on_timeout(self): if not complete: embed.title = "Command timed out..." - embed.color = 0xDC143C + embed.color = colors['red'] delete_after = 30 # delete after 30 seconds else: delete_after = None # do not delete diff --git a/src/keep_alive.py b/src/keep_alive.py deleted file mode 100644 index 74ab1c9..0000000 --- a/src/keep_alive.py +++ /dev/null @@ -1,20 +0,0 @@ -from flask import Flask -from threading import Thread -import os - -app = Flask('') - - -@app.route('/') -def main(): - return f"{os.environ['REPL_SLUG']} is live!" - - -def run(): - app.run(host="0.0.0.0", port=8080) - - -def keep_alive(): - server = Thread(name="Flask", target=run) - server.setDaemon(daemonic=True) - server.start() diff --git a/src/reddit/bot.py b/src/reddit/bot.py index 7520b9e..e0c5755 100644 --- a/src/reddit/bot.py +++ b/src/reddit/bot.py @@ -1,18 +1,19 @@ # standard imports from datetime import datetime import os -import requests import shelve import sys import threading import time # lib imports +import discord import praw from praw import models # local imports -from src import common +from src.common import common +from src.common import globals class Bot: @@ -31,14 +32,7 @@ def __init__(self, **kwargs): self.user_agent = kwargs.get('user_agent', f'{common.bot_name} {self.version}') self.avatar = kwargs.get('avatar', common.get_bot_avatar(gravatar=os.environ['GRAVATAR_EMAIL'])) self.subreddit_name = kwargs.get('subreddit', os.getenv('PRAW_SUBREDDIT', 'LizardByte')) - - if not kwargs.get('redirect_uri', None): - try: # for running in replit - self.redirect_uri = f'https://{os.environ["REPL_SLUG"]}.{os.environ["REPL_OWNER"].lower()}.repl.co' - except KeyError: - self.redirect_uri = os.getenv('REDIRECT_URI', 'http://localhost:8080') - else: - self.redirect_uri = kwargs['redirect_uri'] + self.redirect_uri = kwargs.get('redirect_uri', os.getenv('REDIRECT_URI', 'http://localhost:8080')) # directories self.data_dir = common.data_dir @@ -66,7 +60,7 @@ def __init__(self, **kwargs): @staticmethod def validate_env() -> bool: required_env = [ - 'DISCORD_WEBHOOK', + 'DISCORD_REDDIT_CHANNEL_ID', 'PRAW_CLIENT_ID', 'PRAW_CLIENT_SECRET', 'REDDIT_PASSWORD', @@ -141,7 +135,7 @@ def process_submission(self, submission: models.Submission): print(f'submission id: {submission.id}') print(f'submission title: {submission.title}') print('---------') - if os.getenv('DISCORD_WEBHOOK'): + if os.getenv('DISCORD_REDDIT_CHANNEL_ID'): self.discord(submission=submission) self.flair(submission=submission) self.karma(submission=submission) @@ -166,46 +160,39 @@ def discord(self, submission: models.Submission): try: color = int(submission.link_flair_background_color, 16) except Exception: - color = int('ffffff', 16) + color = common.colors['white'] try: redditor = self.reddit.redditor(name=submission.author) except Exception: return - submission_time = datetime.fromtimestamp(submission.created_utc) - - # create the discord message - # todo: use the running discord bot, directly instead of using a webhook - discord_webhook = { - 'username': 'LizardByte-Bot', - 'avatar_url': self.avatar, - 'embeds': [ - { - 'author': { - 'name': str(submission.author), - 'url': f'https://www.reddit.com/user/{submission.author}', - 'icon_url': str(redditor.icon_img) - }, - 'title': str(submission.title), - 'url': str(submission.url), - 'description': str(submission.selftext), - 'color': color, - 'thumbnail': { - 'url': 'https://www.redditstatic.com/desktop2x/img/snoo_discovery@1x.png' - }, - 'footer': { - 'text': f'Posted on r/{self.subreddit_name} at {submission_time}', - 'icon_url': 'https://www.redditstatic.com/desktop2x/img/favicon/favicon-32x32.png' - } - } - ] - } - - # actually send the message - r = requests.post(os.environ['DISCORD_WEBHOOK'], json=discord_webhook) - - if r.status_code == 204: # successful completion of request, no additional content + # create the discord embed + embed = discord.Embed( + author=discord.EmbedAuthor( + name=str(submission.author), + url=f'https://www.reddit.com/user/{submission.author}', + icon_url=str(redditor.icon_img), + ), + color=color, + description=submission.selftext, + footer=discord.EmbedFooter( + text=f'Posted on r/{self.subreddit_name}', + icon_url='https://www.redditstatic.com/desktop2x/img/favicon/favicon-32x32.png' + ), + title=submission.title, + url=f"https://www.reddit.com{submission.permalink}", + timestamp=datetime.fromtimestamp(submission.created_utc), + thumbnail='https://www.redditstatic.com/desktop2x/img/snoo_discovery@1x.png', + ) + + # actually send the embed + message = globals.DISCORD_BOT.send_message( + channel_id=os.getenv("DISCORD_REDDIT_CHANNEL_ID"), + embed=embed, + ) + + if message: with self.lock, shelve.open(self.db) as db: # the shelve doesn't update unless we recreate the main key submissions = db['submissions'] diff --git a/tests/conftest.py b/tests/conftest.py index a9455c6..6a8d14a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,68 @@ +# standard imports +import os +import time + # lib imports import dotenv +import pytest + +# local imports +from src.common import globals dotenv.load_dotenv(override=False) # environment secrets take priority over .env file + +# import after env loaded +from src.discord import bot as d_bot # noqa: E402 + + +@pytest.fixture(scope='session') +def discord_bot(): + bot = d_bot.Bot() + bot.start_threaded() + globals.DISCORD_BOT = bot + + while not bot.is_ready(): # Wait until the bot is ready + time.sleep(1) + + bot.role_update_task.stop() + bot.daily_task.stop() + bot.clean_ephemeral_cache.stop() + + yield bot + + bot.stop() + globals.DISCORD_BOT = None + + +@pytest.fixture(scope='function') +def discord_db_users(discord_bot): + with discord_bot.db as db: + db['discord_users'] = { + '939171917578002502': { + 'discord_username': 'test_user', + 'discord_global_name': 'Test User', + 'github_id': 'test_user', + 'github_username': 'test_user', + 'roles': [ + 'supporters', + ] + } + } + db['oauth_states'] = {'939171917578002502': 'valid_state'} + db.sync() # Ensure the data is written to the shelve + + yield + + with discord_bot.db as db: + db['discord_users'] = {} + db['oauth_states'] = {} + db.sync() # Ensure the data is written to the shelve + + +@pytest.fixture(scope='function') +def no_github_token(): + og_token = os.getenv('GITHUB_TOKEN') + del os.environ['GITHUB_TOKEN'] + yield + + os.environ['GITHUB_TOKEN'] = og_token diff --git a/tests/fixtures/certs/expired/cert.pem b/tests/fixtures/certs/expired/cert.pem new file mode 100644 index 0000000..4084b6f --- /dev/null +++ b/tests/fixtures/certs/expired/cert.pem @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEtDCCApygAwIBAgIUO9/D0xVF8jI0w7lJQEbuOAkpeXIwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI0MTEyMzAwNDUwM1oXDTI0MTEy +MzAwNDUwM1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEAx2ytodLmx/I7DRe6JTGn98I/DEcRdow+f+6UjjIQczPB +jD97JsfV45eVaIWmRMjqn+A8zAKnsBdRpGlFwbAdG174cu/BLdNb/OoVxCSkiZpH +wtmuRVofOgo2VnFTjgG7gu/4GV7SIOsngz/uB+W7xw/GVQfEsDld3cjiLObn0aFv +D7oKE6wAP0gmbGrNDmvFikVUQIU0tuWfc6DLK2QEh00KOsOAjX9bwk7cOO6+W4Xg +EAqYx0XsyNPkWOG8D9FzttnC1UESSoOyr41ne6sn5knkxdk00dxAVpabt7Z1Do4u +NlAR324+u62GkMP9Pv0tXU8YZxnHFT5njXDJMKXwj7vWiHPzR7Ykw6/fCUOYOcop +rOweSUmevmUIGZKQVyKDrcLumGC4IPfv+UCPQyA3eUEyTPnkDawDRepha0FsZ7Pt +R/0Ftm7XW5u4HFMhRyrnDrHBGNywUg+bYGl7MWIqr0p3o9CVENDIRgLyuDrCnnMB +UFrqpbPp7Q6Z64ohdpvb+eJRYBCUJkbbFawUa/SXe6c5/cFAFwoNgN0UHcBvWKpV +INc5WJPgiaHsauADeUuiU4+n3ZdOu8YMCpei+lM+eRR3KadZ2/UE9lzWQ7PfKDtM +iIeon6oudZIlaTJPsV/AFIwJadJKpxYhgJ6JtlcORBUgzaWFEXRL+/rhVRZInNUC +AwEAATANBgkqhkiG9w0BAQsFAAOCAgEAOzJmOXAcERo/kuB11AnrNqk5bpxqF1Gl +ORNxUQflB0f3qooHkuPH6CdrrZ32yUIN+54fcVCCnQfx04PCC4bPFRreTCqyPCtb +Oinfk5BgEmIvE4x9PibPcmQG6zQfHHqOQzsxio6Fjhfk+iL9Fy30W2K3RBvIicOM +BQ+kGysltV+9tMX4wI/VnLCN5LORbBX7fiMnFtmVKeLZalnOWcMqZuc6opQFjWzg +r4vqu6//STkrCvze4tLUMipS8uKXQ9hvrdiXgQGHOZDhRaQCC+TAXYxPn2pxvYYK +l7dlQS1mWY8pPB7X9FMsACmZR2myBIqbHzFsde+Mqyf5fWHihtWwNYPYreCXKZdr +A7LtgQG9KhTUQO9HjFkbG/VYiH5rPUlewd+qLVvdZ8vFS6ZMvMH7eJPdL0ubuM4s +vTDgPXxqE4GqfzuT0d+vmJujllkiOYdbkDNRYv0rekojNbJcNyyDCs1056ke5JPr +//XfgeW1Lwz1yL9xB5U1lqVUaGIifzihO69yNESUSh/niuwDeWYkz/bgo9oM3L9+ +f1WznzC/tcibq+d9V6PE7KRiGfS5ZbRxAm95wrnRurZYkM+eeZHDmPs3InfYe0Zj +WarJjoO+x/+/ErjgsVUHt9JqB8GdXO3Xg7c5bkrt6LqgYxZ2GUDZZSbe/MTktYsp +E/Y7rCRq6LQ= +-----END CERTIFICATE----- diff --git a/tests/fixtures/certs/expired/key.pem b/tests/fixtures/certs/expired/key.pem new file mode 100644 index 0000000..a016fbb --- /dev/null +++ b/tests/fixtures/certs/expired/key.pem @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJJwIBAAKCAgEAx2ytodLmx/I7DRe6JTGn98I/DEcRdow+f+6UjjIQczPBjD97 +JsfV45eVaIWmRMjqn+A8zAKnsBdRpGlFwbAdG174cu/BLdNb/OoVxCSkiZpHwtmu +RVofOgo2VnFTjgG7gu/4GV7SIOsngz/uB+W7xw/GVQfEsDld3cjiLObn0aFvD7oK +E6wAP0gmbGrNDmvFikVUQIU0tuWfc6DLK2QEh00KOsOAjX9bwk7cOO6+W4XgEAqY +x0XsyNPkWOG8D9FzttnC1UESSoOyr41ne6sn5knkxdk00dxAVpabt7Z1Do4uNlAR +324+u62GkMP9Pv0tXU8YZxnHFT5njXDJMKXwj7vWiHPzR7Ykw6/fCUOYOcoprOwe +SUmevmUIGZKQVyKDrcLumGC4IPfv+UCPQyA3eUEyTPnkDawDRepha0FsZ7PtR/0F +tm7XW5u4HFMhRyrnDrHBGNywUg+bYGl7MWIqr0p3o9CVENDIRgLyuDrCnnMBUFrq +pbPp7Q6Z64ohdpvb+eJRYBCUJkbbFawUa/SXe6c5/cFAFwoNgN0UHcBvWKpVINc5 +WJPgiaHsauADeUuiU4+n3ZdOu8YMCpei+lM+eRR3KadZ2/UE9lzWQ7PfKDtMiIeo +n6oudZIlaTJPsV/AFIwJadJKpxYhgJ6JtlcORBUgzaWFEXRL+/rhVRZInNUCAwEA +AQKCAgADzJG4OfzUhUxTsQaGS95fzW8HDFmMURqltEVXOiPvFebThagSco8kEVCy +14z11YAGwK5X0psgMymGgMzn5jN/wHzqL6AV/+dKN6lnfa02w94nG5+Cybc7k1M6 +rVkCpQzN70ViMli9cM1lZjPiKaG8ppPILeg01TrxDTEl2tZCu5kSiyBDBK1Sh0zY +FubGJg5y1mRHAGKjM1eoy8DjGDov26tcuDm8OFdmqbrvSLkOpEvC8ni7nxzmLIc2 +nEJJaNuT+a0JA/7VtZGTX5W/mOCfNfwqOruTXedJ3v+jbdHoD5RYy4izoXWHfMRK +ALnT193j36xe1nJg+LnfS21BxH+DLNk4OkdiZJzOo3+7uiklZNO0vo6+8RVAckSB +Cvot8x84kUgPPzNyxVilzL5lvfaGrRpAWkPgOFmxyYYkjoenFh0FkWF5foF3sm/S +oSqf6/ojOzxbazI1oZ0yx3YMhfxnTdxQOngy9QeNhEF37ZuU4tutWInyuiIE2u5s +9XNXO7hYqdPqrSV+JOIMuYPTLTxdsKkSbdS4tHUuLt7mO0E2opo1ti8/lvIy10qA +eF2bm/Vcrpt5Zs406uTOZD1a3JaFwVKC1rzVNEDrVomRxsie6Hubkhd8X3SaBx1z +0bicV4yzhPFYu3iqNS2ZU9cJa4H8qJoQ8n1H8Yhi1O7Sc9CBkQKCAQEA7qlyMfPl +N0Vk45BSIWnD+vCaxQDYgqqDXlluGR8NgL0Yhst9Q0XpcP12eIKcTZ6pzdfUe/vK +NsCJXQnDLtxlIzWznXyiY1n/hNRW2dwwYEzPi8IZzE2RM1PltoHn7noVeAmZiayb +mf9M1pRUhbfVsg2gpbGnIR/XLC/1Ll2xNChkSiD5lbRNccRQZqkkIjYl0qCpvnoQ +LQCqdFmVgz0YAkooOf1n2vxVL4g1Ad3761FRSXqsPCNfvPY5+8LhFfBauuMoeS4t +DFjss7yyA6zA9+hxke3dqHVVnXWRTZkuuCzrcT5zOGv31K3l97BmCcKW4lXq6Gja +o0786rg3dpFOOQKCAQEA1emCNwEx5nt2HE070e6TwHBrMPWPKLionodxPsD2hsb0 +RfQKwqxabVVQKEcx+ZpsFFRUF79cHqVwIK6umTwfUdkrgxosjUqr+XYpINfICGln +KIp/nfGac1Sfx528rj7X5ZPaVzldEwCiReZmW29cS/XMKj5fpWXsRTBDr6tvenJn +UiY/6xo5vftFuCHVs4MUntK5BBNSYt3AJq/ZTZx9Lso2+EhkVA1cOlf05TRcWaiI +5KUklf7Rm9owxB3L8uWuBDlxmNymYbn2SuZywaWZlLBA2dTokUK200S4Yhzq/Xrz +dKFn0uTMnATgyj9FUqVZn7p40i7jF93bbBebmMDDfQKCAQBGHMdsf18mRp+l7r8C +C+VEMiz1lRMGB/vB2vnqLWI1INg0uVEaU06KIBwOuSgb8XGnBDHrHoRAY323NGf/ +u0WG+37B1FyMXWMgbZT6OaKIl+gdAa+8gkkW0B3a6Pzu5TSrZ/6QIIIx0nuLSlYu +VlxUC4bXRoJ3y7fVxlz7+xBU50zXLirEXQynUGniTuxLlKa14vca+xcHcXuh5LN0 +s5z7BzgcGSLKhXitFxGjc8hPUDtWH9C7dhTpGVjdalnfrRWqc5NvTi5zwyf+gX+2 +bqjd6455tWx50caOFHzUVB0ShDfCs/r7Z1SOSWwWwN6pHV5gLaduEWextEG+3tGE +ZpmZAoIBAA1FeYC0IEZubnt/BzEVHjGYR+43rfQW0M9VE9+S1Tiza0BTzb8aNloG +Kvz0vdMAk6gHO1hl1O9J0FUWwVpccoz/bkWqAA2cDmNhw1d4S77J206WmShRbwWs +wGUAEk61M2vY6njy5CVjqq2vh7YwiIdl7o7IY+K9GhWI0wo5FqeAJYzhNqH9dIum +5UJxRvLmNQdNh5ELKddcbql3y4GXLeUTQqnQw/i7A3fTMSxvPTOK00NsQ4LS1mpW +9SOVvauKOGumrLeRKPlzMiafeYsuHQMulDdvkCZC/1jIMLBVnvavBB++S9S3wUIE +w3WIy2I/Q/o29XwE0K4QY6anKE4n13kCggEAL0pUNXZc9Sz3auY6PmWCJX+XEggJ +CX8DbUADPKVWGMRekhbSkcgeTZOzqDg3Tcbd5PBQ+PdP9/MG316S0KCvMeBFn4ze +GtKPFfFruS3aLfocQLHP/9p+dEL1g5LjaGvuWQDD++X4EfmNTFe+1yOCb6GTyFBg +AU1QlkSSqQ1TVlH5URUQsLizIwlhTSGy/9J1ylFcLwCnl4uW8VR+exul1tx8zX6P +YWFci1ppCwbwIpSY8zjE8MT5MA3KvOg22gdJUhBhHw36xsYs4MBbzP6YIwm0d+7M +vEcenOCHCMyOdA/Y4/036KnOBaRMyjZ/pABCK1KKaCPZQZR/FaRyyfJ36g== +-----END RSA PRIVATE KEY----- diff --git a/tests/unit/common/test_crypto.py b/tests/unit/common/test_crypto.py new file mode 100644 index 0000000..293d6bf --- /dev/null +++ b/tests/unit/common/test_crypto.py @@ -0,0 +1,65 @@ +# standard imports +import os +from datetime import datetime, UTC + +# lib imports +from cryptography import x509 +import pytest + +# local imports +from src.common.crypto import check_expiration, generate_certificate, initialize_certificate, CERT_FILE, KEY_FILE + + +@pytest.fixture(scope='module') +def setup_certificates(): + # Ensure the certificates are generated for testing + if not os.path.exists(CERT_FILE) or not os.path.exists(KEY_FILE): + generate_certificate() + yield + # Cleanup after tests + if os.path.exists(CERT_FILE): + os.remove(CERT_FILE) + if os.path.exists(KEY_FILE): + os.remove(KEY_FILE) + + +@pytest.fixture(scope='function') +def clear_certificates(): + os.remove(CERT_FILE) + os.remove(KEY_FILE) + yield + + +def test_check_expiration(setup_certificates): + days_left = check_expiration(CERT_FILE) + assert days_left <= 365 + assert days_left >= 364 + + +def test_check_expiration_expired(): + cert_file = os.path.join("tests", "fixtures", "certs", "expired", "cert.pem") + days_left = check_expiration(cert_file) + assert days_left < 0 + + +def test_generate_certificate(setup_certificates): + assert os.path.exists(CERT_FILE) + assert os.path.exists(KEY_FILE) + + with open(CERT_FILE, "rb") as cert_file: + cert_data = cert_file.read() + + cert = x509.load_pem_x509_certificate(cert_data) + assert cert.not_valid_after_utc > datetime.now(UTC) + + +@pytest.mark.parametrize("fixture", ["setup_certificates", "clear_certificates"]) +def test_initialize_certificate(request, fixture): + request.getfixturevalue(fixture) + cert_file, key_file = initialize_certificate() + assert os.path.exists(cert_file) + assert os.path.exists(key_file) + + cert_expires_in = check_expiration(cert_file) + assert cert_expires_in <= 365 + assert cert_expires_in >= 364 diff --git a/tests/unit/common/test_sponsors.py b/tests/unit/common/test_sponsors.py new file mode 100644 index 0000000..b0b0aa8 --- /dev/null +++ b/tests/unit/common/test_sponsors.py @@ -0,0 +1,17 @@ +# local imports +from src.common import sponsors + + +def test_get_github_sponsors(): + data = sponsors.get_github_sponsors() + assert data + assert 'errors' not in data + assert 'data' in data + assert 'organization' in data['data'] + assert 'sponsorshipsAsMaintainer' in data['data']['organization'] + assert 'edges' in data['data']['organization']['sponsorshipsAsMaintainer'] + + +def test_get_github_sponsors_error(no_github_token): + data = sponsors.get_github_sponsors() + assert not data diff --git a/tests/unit/common/test_time.py b/tests/unit/common/test_time.py new file mode 100644 index 0000000..ff41e9e --- /dev/null +++ b/tests/unit/common/test_time.py @@ -0,0 +1,17 @@ +# standard imports +import datetime + +# lib imports +import pytest + +# local imports +from src.common import time + + +@pytest.mark.parametrize("iso_str, expected", [ + ("2024-11-23T20:29:48", datetime.datetime(2024, 11, 23, 20, 29, 48)), + ("2023-01-01T00:00:00", datetime.datetime(2023, 1, 1, 0, 0, 0)), + ("2022-12-31T23:59:59", datetime.datetime(2022, 12, 31, 23, 59, 59)), +]) +def test_iso_to_datetime(iso_str, expected): + assert time.iso_to_datetime(iso_str) == expected diff --git a/tests/unit/common/test_webapp.py b/tests/unit/common/test_webapp.py new file mode 100644 index 0000000..ee2fce9 --- /dev/null +++ b/tests/unit/common/test_webapp.py @@ -0,0 +1,349 @@ +# standard imports +import os +from unittest.mock import Mock + +# lib imports +import pytest + +# local imports +from src.common import webapp + + +@pytest.fixture(scope='function') +def test_client(): + """Create a test client for testing webapp endpoints""" + app = webapp.app + app.testing = True + + client = app.test_client() + + # Create a test client using the Flask application configured for testing + with client as test_client: + # Establish an application context + with app.app_context(): + yield test_client # this is where the testing happens! + + +def test_status(test_client): + """ + WHEN the '/status' page is requested (GET) + THEN check that the response is valid + """ + response = test_client.get('/status') + assert response.status_code == 200 + + +def test_favicon(test_client): + """ + WHEN the '/favicon.ico' file is requested (GET) + THEN check that the response is valid + THEN check the content type is 'image/vnd.microsoft.icon' + """ + response = test_client.get('/favicon.ico') + assert response.status_code == 200 + assert response.content_type == 'image/vnd.microsoft.icon' + + +def test_discord_callback_success(test_client, mocker, discord_db_users): + """ + WHEN the '/discord/callback' endpoint is requested (GET) with valid data + THEN check that the response is a redirect to the main website + """ + mocker.patch.dict(os.environ, { + "DISCORD_CLIENT_ID": "test_client_id", + "DISCORD_CLIENT_SECRET": "test_client_secret", + "DISCORD_REDIRECT_URI": "https://localhost:8080/discord/callback" + }) + + mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) + mocker.patch('src.common.webapp.OAuth2Session.get', side_effect=[ + Mock(json=lambda: { + 'id': '939171917578002502', + 'username': 'discord_user', + 'global_name': 'discord_global_name', + }), + Mock(json=lambda: [ + { + 'type': 'github', + 'id': 'github_user_id', + 'name': 'github_user_login', + } + ]) + ]) + + response = test_client.get('/discord/callback?state=valid_state') + + assert response.status_code == 302 + assert response.location == "https://app.lizardbyte.dev" + + +def test_discord_callback_invalid_state(test_client, mocker, discord_db_users): + """ + WHEN the '/discord/callback' endpoint is requested (GET) with an invalid state + THEN check that the response is 'Invalid state' + """ + mocker.patch.dict(os.environ, { + "DISCORD_CLIENT_ID": "test_client_id", + "DISCORD_CLIENT_SECRET": "test_client_secret", + "DISCORD_REDIRECT_URI": "https://localhost:8080/discord/callback" + }) + + mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) + mocker.patch('src.common.webapp.OAuth2Session.get', return_value=Mock(json=lambda: { + 'id': '1234567890', + 'username': 'discord_user', + 'global_name': 'discord_global_name', + })) + + response = test_client.get('/discord/callback?state=invalid_state') + + assert response.data == b'Invalid state' + assert response.status_code == 400 + + +def test_discord_callback_error_in_request(test_client): + """ + WHEN the '/discord/callback' endpoint is requested (GET) with an error in the request + THEN check that the response is the error description + """ + response = test_client.get('/discord/callback?error=access_denied&error_description=The+user+denied+access') + + assert response.data == b'The user denied access' + assert response.status_code == 400 + + +def test_github_callback_success(test_client, mocker, discord_db_users): + """ + WHEN the '/github/callback' endpoint is requested (GET) with valid data + THEN check that the response is a redirect to the main website + """ + mocker.patch.dict(os.environ, { + "GITHUB_CLIENT_ID": "test_client_id", + "GITHUB_CLIENT_SECRET": "test_client_secret", + "GITHUB_REDIRECT_URI": "https://localhost:8080/github/callback" + }) + + mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) + mocker.patch('src.common.webapp.OAuth2Session.get', side_effect=[ + Mock(json=lambda: { + 'id': 'github_user_id', + 'login': 'github_user_login', + }), + Mock(json=lambda: { + 'id': 'github_user_id', + 'login': 'github_user_login', + }) + ]) + + response = test_client.get('/github/callback?state=valid_state') + + assert response.status_code == 302 + assert response.location == "https://app.lizardbyte.dev" + + +def test_github_callback_invalid_state(test_client, mocker, discord_db_users): + """ + WHEN the '/github/callback' endpoint is requested (GET) with an invalid state + THEN check that the response is 'Invalid state' + """ + mocker.patch.dict(os.environ, { + "GITHUB_CLIENT_ID": "test_client_id", + "GITHUB_CLIENT_SECRET": "test_client_secret", + "GITHUB_REDIRECT_URI": "https://localhost:8080/github/callback" + }) + + mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) + mocker.patch('src.common.webapp.OAuth2Session.get', return_value=Mock(json=lambda: { + 'id': 'github_user_id', + 'login': 'github_user_login', + })) + + response = test_client.get('/github/callback?state=invalid_state') + + assert response.data == b'Invalid state' + assert response.status_code == 400 + + +def test_github_callback_error_in_request(test_client): + """ + WHEN the '/github/callback' endpoint is requested (GET) with an error in the request + THEN check that the response is the error description + """ + response = test_client.get('/github/callback?error=access_denied&error_description=The+user+denied+access') + + assert response.data == b'The user denied access' + assert response.status_code == 400 + + +def test_webhook_invalid_source(test_client): + """ + WHEN the '/webhook//' endpoint is requested (POST) with an invalid source + THEN check that the response is 'Invalid source' + """ + response = test_client.post('/webhook/invalid_source/invalid_key') + assert response.json == {"status": "error", "message": "Invalid source"} + assert response.status_code == 400 + + +def test_webhook_invalid_key(test_client, mocker): + """ + WHEN the '/webhook//' endpoint is requested (POST) with an invalid key + THEN check that the response is 'Invalid key' + """ + mocker.patch.dict(os.environ, {"GITHUB_WEBHOOK_SECRET_KEY": "valid_key"}) + response = test_client.post('/webhook/github_sponsors/invalid_key') + assert response.json == {"status": "error", "message": "Invalid key"} + assert response.status_code == 400 + + +def test_webhook_github_sponsors(discord_bot, test_client, mocker): + """ + WHEN the '/webhook/github_sponsors/' endpoint is requested (POST) with valid data + THEN check that the response is 'success' + """ + mocker.patch.dict(os.environ, {"GITHUB_WEBHOOK_SECRET_KEY": "valid_key"}) + data = { + 'action': 'created', + 'sponsorship': { + 'sponsor': { + 'login': 'octocat', + 'url': 'https://github.com/octocat', + 'avatar_url': 'https://avatars.githubusercontent.com/u/583231', + }, + 'created_at': '1970-01-01T00:00:00Z', + }, + } + response = test_client.post('/webhook/github_sponsors/valid_key', json=data) + assert response.json == {"status": "success"} + assert response.status_code == 200 + + +@pytest.mark.parametrize("data, expected_status", [ + # https://support.atlassian.com/statuspage/docs/enable-webhook-notifications/ + ({ + "meta": { + "unsubscribe": "https://statustest.flyingkleinbrothers.com:5000/?unsubscribe=j0vqr9kl3513", + "documentation": "https://doers.statuspage.io/customer-notifications/webhooks/", + }, + "page": { + "id": "j2mfxwj97wnj", + "status_indicator": "major", + "status_description": "Partial System Outage", + }, + "component_update": { + "created_at": "2013-05-29T21:32:28Z", + "new_status": "operational", + "old_status": "major_outage", + "id": "k7730b5v92bv", + "component_id": "rb5wq1dczvbm", + }, + "component": { + "created_at": "2013-05-29T21:32:28Z", + "id": "rb5wq1dczvbm", + "name": "Some Component", + "status": "operational", + }, + }, 200), + ({ + "meta": { + "unsubscribe": "https://statustest.flyingkleinbrothers.com:5000/?unsubscribe=j0vqr9kl3513", + "documentation": "https://doers.statuspage.io/customer-notifications/webhooks/", + }, + "page": { + "id": "j2mfxwj97wnj", + "status_indicator": "critical", + "status_description": "Major System Outage", + }, + "incident": { + "backfilled": False, + "created_at": "2013-05-29T15:08:51-06:00", + "impact": "critical", + "impact_override": None, + "monitoring_at": "2013-05-29T16:07:53-06:00", + "postmortem_body": None, + "postmortem_body_last_updated_at": None, + "postmortem_ignored": False, + "postmortem_notified_subscribers": False, + "postmortem_notified_twitter": False, + "postmortem_published_at": None, + "resolved_at": None, + "scheduled_auto_transition": False, + "scheduled_for": None, + "scheduled_remind_prior": False, + "scheduled_reminded_at": None, + "scheduled_until": None, + "shortlink": "https://j.mp/18zyDQx", + "status": "monitoring", + "updated_at": "2013-05-29T16:30:35-06:00", + "id": "lbkhbwn21v5q", + "organization_id": "j2mfxwj97wnj", + "incident_updates": [ + { + "body": "A fix has been implemented and we are monitoring the results.", + "created_at": "2013-05-29T16:07:53-06:00", + "display_at": "2013-05-29T16:07:53-06:00", + "status": "monitoring", + "twitter_updated_at": None, + "updated_at": "2013-05-29T16:09:09-06:00", + "wants_twitter_update": False, + "id": "drfcwbnpxnr6", + "incident_id": "lbkhbwn21v5q", + }, + { + "body": "We are waiting for the cloud to come back online " + "and will update when we have further information", + "created_at": "2013-05-29T15:18:51-06:00", + "display_at": "2013-05-29T15:18:51-06:00", + "status": "identified", + "twitter_updated_at": None, + "updated_at": "2013-05-29T15:28:51-06:00", + "wants_twitter_update": False, + "id": "2rryghr4qgrh", + "incident_id": "lbkhbwn21v5q", + }, + { + "body": "The cloud, located in Norther Virginia, has once again gone the way of the dodo.", + "created_at": "2013-05-29T15:08:51-06:00", + "display_at": "2013-05-29T15:08:51-06:00", + "status": "investigating", + "twitter_updated_at": None, + "updated_at": "2013-05-29T15:28:51-06:00", + "wants_twitter_update": False, + "id": "qbbsfhy5s9kk", + "incident_id": "lbkhbwn21v5q", + }, + ], + "name": "Virginia Is Down", + }, + }, 200), + ({ + "meta": { + "unsubscribe": "https://statustest.flyingkleinbrothers.com:5000/?unsubscribe=j0vqr9kl3513", + "documentation": "https://doers.statuspage.io/customer-notifications/webhooks/", + }, + "page": { + "id": "j2mfxwj97wnj", + "status_indicator": "critical", + "status_description": "Major System Outage", + }, + "incident": { + "incident_updates": [], + "name": "Virginia Is Down", + }, + }, 400), +]) +def test_webhook_github_status(discord_bot, test_client, mocker, data, expected_status): + """ + WHEN the '/webhook/github_status/' endpoint is requested (POST) with valid data + THEN check that the response is 'success' + """ + mocker.patch.dict(os.environ, {"GITHUB_WEBHOOK_SECRET_KEY": "valid_key"}) + response = test_client.post('/webhook/github_status/valid_key', json=data) + assert response.status_code == expected_status + + if expected_status == 200: + assert response.json == {"status": "success"} + + if expected_status == 400: + assert response.json["status"] == "error" + assert response.json["message"] diff --git a/tests/unit/discord/test_discord_bot.py b/tests/unit/discord/test_discord_bot.py index 500722c..cb53388 100644 --- a/tests/unit/discord/test_discord_bot.py +++ b/tests/unit/discord/test_discord_bot.py @@ -1,42 +1,79 @@ # standard imports import asyncio +import os # lib imports +import discord import pytest -import pytest_asyncio # local imports -from src import common -from src.discord import bot as discord_bot - - -@pytest_asyncio.fixture -async def bot(): - # event_loop fixture is deprecated - _loop = asyncio.get_event_loop() - - bot = discord_bot.Bot(loop=_loop) - future = asyncio.run_coroutine_threadsafe(bot.start(token=bot.token), _loop) - await bot.wait_until_ready() # Wait until the bot is ready - yield bot - bot.stop(future=future) - - # wait for the bot to finish - counter = 0 - while not future.done() and counter < 30: - await asyncio.sleep(1) - counter += 1 - future.cancel() # Cancel the bot when the tests are done - - -@pytest.mark.asyncio -async def test_bot_on_ready(bot): - assert bot is not None - assert bot.guilds - assert bot.guilds[0].name == "ReenigneArcher's test server" - assert bot.user.id == 939171917578002502 - assert bot.user.name == common.bot_name - assert bot.user.avatar +from src.common import common + + +def test_bot_on_ready(discord_bot): + assert discord_bot is not None + assert discord_bot.guilds + assert discord_bot.guilds[0].name == "ReenigneArcher's test server" + assert discord_bot.user.id == 939171917578002502 + assert discord_bot.user.name == common.bot_name + assert discord_bot.user.avatar # compare the bot avatar to our intended avatar - assert await bot.user.avatar.read() == common.get_avatar_bytes() + future = asyncio.run_coroutine_threadsafe(discord_bot.user.avatar.read(), discord_bot.loop) + assert future.result() == common.get_avatar_bytes() + + +@pytest.mark.parametrize("message, embed", [ + (None, None), + (f"This is a test message from {os.getenv('CI_EVENT_ID', 'local')}.", None), + (None, discord.Embed( + title="Test Embed 1", + description="This is a test embed from the unit tests.", + color=0x00ff00, + )), + (None, discord.Embed( + title="Test Embed 2", + description=f"{'a' * 4097}", # ensure embed description is larger than 4096 characters + color=0xff0000, + )), + (None, discord.Embed( + title="Test Embed 3", + description=f"{'a' * 4096}", + color=0xff0000, + footer=discord.EmbedFooter( + text=f"{'b' * 2000}" # ensure embed total size is larger than 6000 characters + ), + )), +]) +def test_send_message(discord_bot, message, embed): + channel_id = int(os.environ['DISCORD_GITHUB_STATUS_CHANNEL_ID']) + msg = discord_bot.send_message(channel_id=channel_id, message=message, embed=embed) + + if not message and not embed: + assert msg is None + return + + if message: + assert msg.content == message + else: + assert msg.content == '' + + assert msg.channel.id == channel_id + assert msg.author.id == 939171917578002502 + assert msg.author.name == common.bot_name + + avatar_future = asyncio.run_coroutine_threadsafe(msg.author.avatar.read(), discord_bot.loop) + assert avatar_future.result() == common.get_avatar_bytes() + + assert msg.author.display_name == common.bot_name + assert msg.author.discriminator == "7085" + assert msg.author.bot is True + assert msg.author.system is False + + if embed: + assert msg.embeds[0].title == embed.title + assert msg.embeds[0].description == embed.description[:4093] + "..." if len( + embed.description) > 4096 or len(embed) > 6000 else embed.description + assert msg.embeds[0].color == embed.color + if embed.footer: + assert msg.embeds[0].footer.text == embed.footer.text diff --git a/tests/unit/discord/test_tasks.py b/tests/unit/discord/test_tasks.py new file mode 100644 index 0000000..49f16f0 --- /dev/null +++ b/tests/unit/discord/test_tasks.py @@ -0,0 +1,130 @@ +# standard imports +from datetime import datetime, timedelta, timezone, UTC +import os + +# lib imports +import pytest + +# local imports +from src.discord import tasks + + +def set_env_variable(env_var_name, request): + og_value = os.environ.get(env_var_name) + new_value = request.param + if new_value is not None: + os.environ[env_var_name] = new_value + yield + if og_value is not None: + os.environ[env_var_name] = og_value + elif env_var_name in os.environ: + del os.environ[env_var_name] + + +@pytest.fixture(scope='function') +def set_daily_channel_id(request): + yield from set_env_variable('DAILY_CHANNEL_ID', request) + + +@pytest.fixture(scope='function') +def set_daily_releases(request): + yield from set_env_variable('DAILY_RELEASES', request) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("db_start, expected_keys", [ + ( + { + '1': { + 'expires_at': datetime.now(UTC), + }, + '2': { + 'expires_at': datetime.now(UTC) - timedelta(minutes=1) + }, + '3': { + 'expires_at': datetime.now(UTC) - timedelta(minutes=2) + }, + '4': { + 'expires_at': datetime.now(UTC) - timedelta(minutes=3) + }, + '5': { + 'expires_at': datetime.now(UTC) - timedelta(minutes=4) + }, + '6': { + 'expires_at': datetime.now(UTC) - timedelta(minutes=5) + }, + '7': { + 'expires_at': datetime.now(UTC) - timedelta(minutes=10) + }, + }, + ['1', '2', '3', '4', '5'] + ) +]) +async def test_clean_ephemeral_cache(discord_bot, mocker, db_start, expected_keys): + """ + GIVEN a database with ephemeral cache entries + WHEN the clean_ephemeral_cache task is called + THEN expired entries are removed from the database + """ + # Mock the edit method of the response objects + for entry in db_start.values(): + entry['response'] = mocker.Mock() + entry['response'].edit = mocker.AsyncMock() + + # Mock the bot's ephemeral_db + discord_bot.ephemeral_db = { + 'github_cache_context': db_start + } + + # Run the clean_ephemeral_cache task + await tasks.clean_ephemeral_cache(bot=discord_bot) + + # Assert the ephemeral_db is as expected + for k, v in discord_bot.ephemeral_db['github_cache_context'].items(): + assert k in expected_keys, f"Key {k} should not be in the database" + assert v['expires_at'] >= datetime.now(UTC) - timedelta(minutes=5), f"Key {k} should not have expired" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("skip, set_daily_releases, set_daily_channel_id, expected", [ + (True, 'false', None, False), + (False, 'false', None, False), + (False, 'true', None, False), + (False, 'true', os.environ['DISCORD_GITHUB_STATUS_CHANNEL_ID'], True), +], indirect=["set_daily_releases", "set_daily_channel_id"]) +async def test_daily_task(discord_bot, mocker, skip, set_daily_releases, set_daily_channel_id, expected): + """ + WHEN the daily task is called + THEN check that the task runs without error + """ + # Patch datetime.datetime at the location where it's imported in `tasks` + mock_datetime = mocker.patch('src.discord.tasks.datetime', autospec=True) + mock_datetime.now.return_value = datetime(2023, 1, 1, 1 if skip else 12, 0, 0, tzinfo=timezone.utc) + + # Run the daily task + result = await tasks.daily_task(bot=discord_bot) + + assert result is expected + + # Verify that datetime.now() was called + mock_datetime.now.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("skip", [True, False]) +async def test_role_update_task(discord_bot, discord_db_users, mocker, skip): + """ + WHEN the role update task is called + THEN check that the task runs without error + """ + # Patch datetime.datetime at the location where it's imported in `tasks` + mock_datetime = mocker.patch('src.discord.tasks.datetime', autospec=True) + mock_datetime.now.return_value = datetime(2023, 1, 1, 0, 1 if skip else 0, 0, tzinfo=timezone.utc) + + # Run the task + result = await tasks.role_update_task(bot=discord_bot) + + assert result is not skip + + # Verify that datetime.now() was called + mock_datetime.now.assert_called_once() diff --git a/tests/unit/reddit/test_reddit_bot.py b/tests/unit/reddit/test_reddit_bot.py index 8ff1a84..07a38bc 100644 --- a/tests/unit/reddit/test_reddit_bot.py +++ b/tests/unit/reddit/test_reddit_bot.py @@ -161,7 +161,7 @@ def _submission(self, bot, recorder): def test_validate_env(self, bot): with patch.dict( os.environ, { - "DISCORD_WEBHOOK": "test", + "DISCORD_REDDIT_CHANNEL_ID": "test", "PRAW_CLIENT_ID": "test", "PRAW_CLIENT_SECRET": "test", "REDDIT_PASSWORD": "test", @@ -198,7 +198,7 @@ def test_process_comment(self, bot, recorder, request, slash_command_comment): assert db['comments'][slash_command_comment.id]['slash_command']['project'] == 'sunshine' assert db['comments'][slash_command_comment.id]['slash_command']['command'] == 'vban' - def test_process_submission(self, bot, recorder, request, _submission): + def test_process_submission(self, bot, discord_bot, recorder, request, _submission): with recorder.use_cassette(request.node.name): bot.process_submission(submission=_submission) with bot.lock, shelve.open(bot.db) as db: @@ -213,7 +213,7 @@ def test_comment_loop(self, bot, recorder, request): comment = bot._comment_loop(test=True) assert comment.author - def test_submission_loop(self, bot, recorder, request): + def test_submission_loop(self, bot, discord_bot, recorder, request): with recorder.use_cassette(request.node.name): submission = bot._submission_loop(test=True) assert submission.author