Skip to content

Commit

Permalink
feat(github): add github integration
Browse files Browse the repository at this point in the history
  • Loading branch information
ReenigneArcher committed Nov 10, 2024
1 parent 362aef2 commit 69a160c
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 46 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ praw==7.8.1
py-cord==2.6.1
python-dotenv==1.0.1
requests==2.32.3
requests-oauthlib==2.0.0
26 changes: 10 additions & 16 deletions src/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# standard imports
import os
import time

# development imports
Expand All @@ -8,33 +7,28 @@

# local imports
if True: # hack for flake8
from src import globals
from src.discord import bot as d_bot
from src import keep_alive
from src import webapp
from src.reddit import bot as r_bot


def main():
# to run in replit
try:
os.environ['REPL_SLUG']
except KeyError:
pass # not running in replit
else:
keep_alive.keep_alive() # Start the web server
webapp.start() # Start the web server

discord_bot = d_bot.Bot()
discord_bot.start_threaded() # Start the discord bot
globals.DISCORD_BOT = d_bot.Bot()
globals.DISCORD_BOT.start_threaded() # Start the discord bot

reddit_bot = r_bot.Bot()
reddit_bot.start_threaded() # Start the reddit bot
globals.REDDIT_BOT = r_bot.Bot()
globals.REDDIT_BOT.start_threaded() # Start the reddit bot

try:
while discord_bot.bot_thread.is_alive() or reddit_bot.bot_thread.is_alive():
while globals.DISCORD_BOT.bot_thread.is_alive() or globals.REDDIT_BOT.bot_thread.is_alive():
time.sleep(0.5)
except KeyboardInterrupt:
print("Keyboard Interrupt Detected")
discord_bot.stop()
reddit_bot.stop()
globals.DISCORD_BOT.stop()
globals.REDDIT_BOT.stop()


if __name__ == '__main__': # pragma: no cover
Expand Down
69 changes: 69 additions & 0 deletions src/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# standard imports
import os

# lib imports
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
from datetime import datetime, timedelta, UTC

# local imports
from src import common

CERT_FILE = os.path.join(common.data_dir, "cert.pem")
KEY_FILE = os.path.join(common.data_dir, "key.pem")


def check_expiration(cert_path: str) -> int:
with open(cert_path, "rb") as cert_file:
cert_data = cert_file.read()
cert = x509.load_pem_x509_certificate(cert_data, default_backend())
expiry_date = cert.not_valid_after_utc
return (expiry_date - datetime.now(UTC)).days


def generate_certificate():
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=4096,
)
subject = issuer = x509.Name([
x509.NameAttribute(x509.NameOID.COMMON_NAME, u"localhost"),
])
cert = x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
private_key.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.now(UTC)
).not_valid_after(
datetime.now(UTC) + timedelta(days=365)
).sign(private_key, hashes.SHA256())

with open(KEY_FILE, "wb") as f:
f.write(private_key.private_bytes(
encoding=Encoding.PEM,
format=PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=NoEncryption(),
))

with open(CERT_FILE, "wb") as f:
f.write(cert.public_bytes(Encoding.PEM))


def initialize_certificate() -> tuple[str, str]:
print("Initializing SSL certificate")
if os.path.exists(CERT_FILE) and os.path.exists(KEY_FILE):
cert_expires_in = check_expiration(CERT_FILE)
print(f"Certificate expires in {cert_expires_in} days.")
if cert_expires_in >= 90:
return CERT_FILE, KEY_FILE
print("Generating new certificate")
generate_certificate()
return CERT_FILE, KEY_FILE
22 changes: 22 additions & 0 deletions src/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# standard imports
import shelve
import threading


class Database:
def __init__(self, db_path):
self.db_path = db_path
self.lock = threading.Lock()

def __enter__(self):
self.lock.acquire()
self.db = shelve.open(self.db_path, writeback=True)
return self.db

def __exit__(self, exc_type, exc_val, exc_tb):
self.sync()
self.db.close()
self.lock.release()

def sync(self):
self.db.sync()
7 changes: 6 additions & 1 deletion src/discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import discord

# local imports
from src.common import bot_name, get_avatar_bytes, org_name
from src.common import bot_name, data_dir, get_avatar_bytes, org_name
from src.database import Database
from src.discord.tasks import daily_task
from src.discord.views import DonateCommandView

Expand All @@ -30,13 +31,17 @@ def __init__(self, *args, **kwargs):

self.bot_thread = threading.Thread(target=lambda: None)
self.token = os.environ['DISCORD_BOT_TOKEN']
self.db = Database(db_path=os.path.join(data_dir, 'discord_bot_database'))

self.load_extension(
name='src.discord.cogs',
recursive=True,
store=False,
)

with self.db as db:
db['oauth_states'] = {} # clear any oauth states from previous sessions

async def on_ready(self):
"""
Bot on ready event.
Expand Down
123 changes: 123 additions & 0 deletions src/discord/cogs/github_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# standard imports
import os

# lib imports
import discord
import requests
from requests_oauthlib import OAuth2Session


class GitHubCommandsCog(discord.Cog):
def __init__(self, bot):
self.bot = bot
self.token = os.getenv("GITHUB_TOKEN")
self.org_name = os.getenv("GITHUB_ORG_NAME", "LizardByte")
self.graphql_url = "https://api.github.com/graphql"
self.headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json"
}

@discord.slash_command(
name="get_sponsors",
description="Get list of GitHub sponsors",
default_member_permissions=discord.Permissions(manage_guild=True),
)
async def get_sponsors(
self,
ctx: discord.ApplicationContext,
):
"""
Get list of GitHub sponsors.
Parameters
----------
ctx : discord.ApplicationContext
Request message context.
"""
query = """
query {
organization(login: "%s") {
sponsorshipsAsMaintainer(first: 100) {
edges {
node {
sponsorEntity {
... on User {
login
name
avatarUrl
url
}
... on Organization {
login
name
avatarUrl
url
}
}
tier {
name
monthlyPriceInDollars
}
}
}
}
}
}
""" % self.org_name

response = requests.post(self.graphql_url, json={'query': query}, headers=self.headers)
data = response.json()

if 'errors' in data:
print(data['errors'])
await ctx.respond("An error occurred while fetching sponsors.", ephemeral=True)
return

message = "List of GitHub sponsors"
for edge in data['data']['organization']['sponsorshipsAsMaintainer']['edges']:
sponsor = edge['node']['sponsorEntity']
tier = edge['node'].get('tier', {})
tier_info = f" - Tier: {tier.get('name', 'N/A')} (${tier.get('monthlyPriceInDollars', 'N/A')}/month)"
message += f"\n* [{sponsor['login']}]({sponsor['url']}){tier_info}"

embed = discord.Embed(title="GitHub Sponsors", color=0x00ff00, description=message)

await ctx.respond(embed=embed, ephemeral=True)

@discord.slash_command(
name="link_github",
description="Validate GitHub sponsor status"
)
async def link_github(self, ctx: discord.ApplicationContext):
"""
Link Discord account with GitHub account, by validating Discord user's "GitHub" connected account status.
User to login to Discord via OAuth2, and check if their connected GitHub account is a sponsor of the project.
Parameters
----------
ctx : discord.ApplicationContext
Request message context.
"""
discord_oauth = OAuth2Session(
os.environ['DISCORD_CLIENT_ID'],
redirect_uri=os.environ['DISCORD_REDIRECT_URI'],
scope=[
"identify",
"connections",
],
)
authorization_url, state = discord_oauth.authorization_url("https://discord.com/oauth2/authorize")

with self.bot.db as db:
db['oauth_states'] = db.get('oauth_states', {})
db['oauth_states'][str(ctx.author.id)] = state
db.sync()

# Store the state in the user's session or database
await ctx.respond(f"Please authorize the application by clicking [here]({authorization_url}).", ephemeral=True)


def setup(bot: discord.Bot):
bot.add_cog(GitHubCommandsCog(bot=bot))
2 changes: 2 additions & 0 deletions src/globals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DISCORD_BOT = None
REDDIT_BOT = None
20 changes: 0 additions & 20 deletions src/keep_alive.py

This file was deleted.

11 changes: 2 additions & 9 deletions src/reddit/bot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# standard imports
from datetime import datetime
import os
import requests
import shelve
import sys
import threading
Expand All @@ -10,6 +9,7 @@
# lib imports
import praw
from praw import models
import requests

# local imports
from src import common
Expand All @@ -31,14 +31,7 @@ def __init__(self, **kwargs):
self.user_agent = kwargs.get('user_agent', f'{common.bot_name} {self.version}')
self.avatar = kwargs.get('avatar', common.get_bot_avatar(gravatar=os.environ['GRAVATAR_EMAIL']))
self.subreddit_name = kwargs.get('subreddit', os.getenv('PRAW_SUBREDDIT', 'LizardByte'))

if not kwargs.get('redirect_uri', None):
try: # for running in replit
self.redirect_uri = f'https://{os.environ["REPL_SLUG"]}.{os.environ["REPL_OWNER"].lower()}.repl.co'
except KeyError:
self.redirect_uri = os.getenv('REDIRECT_URI', 'http://localhost:8080')
else:
self.redirect_uri = kwargs['redirect_uri']
self.redirect_uri = kwargs.get('redirect_uri', os.getenv('REDIRECT_URI', 'http://localhost:8080'))

# directories
self.data_dir = common.data_dir
Expand Down
Loading

0 comments on commit 69a160c

Please sign in to comment.