diff --git a/.github/workflows/docker-publish.yaml b/.github/workflows/docker-publish.yaml
index 316a3ad..2def5ea 100644
--- a/.github/workflows/docker-publish.yaml
+++ b/.github/workflows/docker-publish.yaml
@@ -38,7 +38,7 @@ jobs:
         pip install pylint
     - name: Analysing the code with pylint
       run: |
-        pylint $(git ls-files '*.py')
+        pylint main.py
   test:
     runs-on: ubuntu-latest
     steps:
@@ -50,7 +50,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pytest pytest-asyncio
+        pip install pytest pytest-asyncio pytest-mock
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
     - name: Test with pytest
       run: |
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..b44d61e
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,6 @@
+[MASTER]
+disable=
+    C0114, # docstrings
+    C0115,
+    C0116,
+    C0301 #line too long
diff --git a/main.py b/main.py
index a84a9c3..c7e8e3b 100644
--- a/main.py
+++ b/main.py
@@ -1,22 +1,33 @@
+import base64
+import csv
 import glob
+import hashlib
+import json
 import os
 from datetime import datetime
-import nextcord
-import hashlib
 import requests
-import base64
-import json
-from gnupg import GPG
-from cryptography.hazmat.primitives import serialization
-from cryptography.hazmat.primitives.asymmetric import ed25519
-import csv
 import boto3
+import nextcord
 from botocore.exceptions import ClientError
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import ed25519
+from gnupg import GPG
 
-global S3_BUCKET
-S3_ENABLED = False
 key_fingerprints = []
 gpg = GPG()
+S3_CLIENT = None
+SIGNING_KEY = None
+
+class ConfigurationError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
+
+
+class CryptographyError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
 
 
 def get_ephemeral_path():
@@ -26,37 +37,42 @@ def get_ephemeral_path():
     return ephemeral_path
 
 
+def is_s3_enabled():
+    return os.getenv('S3_ENABLED') == '1'
+
+
+def s3_bucket():
+    return os.getenv('S3_BUCKET')
+
+
 def init_s3():
     # Read Env
-    S3_ENABLED = os.getenv('S3_ENABLED') == '1'
-    S3_BUCKET = os.getenv('S3_BUCKET')
-    S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY')
-    S3_SECRET_KEY = os.getenv('S3_SECRET_KEY')
-    S3_ENDPOINT = os.getenv('S3_ENDPOINT')
+    s3_access_key = os.getenv('S3_ACCESS_KEY')
+    s3_secret_key = os.getenv('S3_SECRET_KEY')
+    s3_endpoint = os.getenv('S3_ENDPOINT')
 
     # Setup
-    if S3_ENABLED:
+    if is_s3_enabled():
         # Prepare S3 access
-        if S3_BUCKET is None:
-            raise Exception('S3 enabled but S3_BUCKET not set.')
+        if s3_bucket() is None:
+            raise ConfigurationError('S3 enabled but S3_BUCKET not set.')
 
-        if S3_ACCESS_KEY is None:
-            raise Exception('S3 enabled but S3_ACCESS_KEY not set.')
+        if s3_access_key is None:
+            raise ConfigurationError('S3 enabled but S3_ACCESS_KEY not set.')
 
-        if S3_SECRET_KEY is None:
-            raise Exception('S3 enabled but S3_SECRET_KEY not set.')
+        if s3_secret_key is None:
+            raise ConfigurationError('S3 enabled but S3_SECRET_KEY not set.')
 
-        if S3_ENDPOINT is None:
-            raise Exception('S3 enabled but S3_ENDPOINT not set.')
+        if s3_endpoint is None:
+            raise ConfigurationError('S3 enabled but S3_ENDPOINT not set.')
 
         return boto3.client(
             service_name='s3',
-            aws_access_key_id=S3_ACCESS_KEY,
-            aws_secret_access_key=S3_SECRET_KEY,
-            endpoint_url=S3_ENDPOINT,
+            aws_access_key_id=s3_access_key,
+            aws_secret_access_key=s3_secret_key,
+            endpoint_url=s3_endpoint,
         )
-    else:
-        return None
+    return None
 
 
 def hash_string(msg):
@@ -85,7 +101,7 @@ def extract_message(message):
     # process any attachments/files
     for attach in message.attachments:
         # Download file
-        resp = requests.get(attach.url)
+        resp = requests.get(attach.url, timeout=30)
         resp.raise_for_status()
 
         attachments.append({
@@ -134,25 +150,26 @@ def write_to_storage(backup_msg):
 
     if not enc_msg.ok:
         print(f'Encryption failed: {enc_msg.status}')
-        raise Exception('Unable to encrypt')
+        raise CryptographyError('Unable to encrypt')
 
     # hash and sign msg
     enc_hash_str, enc_hash_b = hash_string(str(enc_msg))
-    signature = signing_key.sign(enc_hash_b).hex()
+    signature = SIGNING_KEY.sign(enc_hash_b).hex()
 
     manifest_path, _ = get_manifest_path(backup_msg["server"]["id"], backup_msg["channel"]["id"])
 
     # Write to manifest
-    with open(manifest_path, 'a') as manifest_file:
+    with open(manifest_path, 'a', encoding='utf-8') as manifest_file:
         writer = csv.writer(manifest_file)
         # Date/Time of message, msg hash, signature
         writer.writerow([backup_msg['created_at'], enc_hash_str, signature])
 
     # write to msg file
-    if S3_ENABLED:
-        s3.put_object(Bucket=S3_BUCKET, Key=f'messages/{enc_hash_str}', Body=str(enc_msg))
+    if is_s3_enabled():
+        S3_CLIENT.put_object(Bucket=s3_bucket(), Key=f'messages/{enc_hash_str}', Body=str(enc_msg))
     else:
-        with open(os.path.join(get_ephemeral_path(), f'{enc_hash_str}.msg'),'w') as msg_file:
+        msg_path = os.path.join(get_ephemeral_path(), f'{enc_hash_str}.msg')
+        with open(msg_path,'w', encoding='utf-8') as msg_file:
             msg_file.write(str(enc_msg))
 
     print(f'Message written: {enc_hash_str}')
@@ -166,11 +183,11 @@ def get_signing_key():
     keyfile = os.getenv('SIGN_KEY_PEM')
 
     if keyfile is None:
-        raise Exception('Signing key not configured. Please set SIGN_KEY_PEM.')
+        raise ConfigurationError('Signing key not configured. Please set SIGN_KEY_PEM.')
 
     # Verify file extension
     if os.path.splitext(keyfile)[1] != '.pem':
-        raise Exception('Signing key file not a pem file. make sure the extension is pem.')
+        raise ConfigurationError('Signing key file not a pem file. make sure the extension is pem.')
 
     if os.path.exists(keyfile):
         print(f'Loading signing key from {keyfile}...')
@@ -194,7 +211,7 @@ def get_signing_key():
             private_key.private_bytes(
                 encoding=serialization.Encoding.PEM,
                 format=serialization.PrivateFormat.PKCS8,
-                encryption_algorithm=serialization.NoEncryption()  # Change this if you want the key to be encrypted
+                encryption_algorithm=serialization.NoEncryption()
             )
         )
 
@@ -212,7 +229,8 @@ def get_signing_key():
 
 def seal_manifest(guild_id, channel_id):
     """
-    This method seals the manifest file for a guild and channel by hashing the IDs, loading the manifest contents,
+    This method seals the manifest file for a guild and channel by hashing the IDs,
+    loading the manifest contents,
     generating a signature for the manifest, and writing the signature to a seal file.
 
     :param guild_id: the ID of the guild
@@ -223,24 +241,26 @@ def seal_manifest(guild_id, channel_id):
     seal_path = manifest_path.replace('.manifest','.seal')
 
     # load the manifest contents
-    with open(manifest_path, 'r') as manifest:
+    with open(manifest_path, 'r', encoding='utf-8') as manifest:
         man_str = manifest.read()
     _, manifest_hash = hash_string(man_str)
-    man_signature = signing_key.sign(manifest_hash).hex()
-    with open(seal_path, 'w') as seal_file:
+    man_signature = SIGNING_KEY.sign(manifest_hash).hex()
+    with open(seal_path, 'w', encoding='utf-8') as seal_file:
         seal_file.write(man_signature)
 
 
 def get_manifest_path(guild_id, channel_id):
     channel_hash, _ = hash_string(str(channel_id))
     guild_hash, _ = hash_string(str(guild_id))
-    return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.manifest'), f'manifests/{guild_hash}-{channel_hash}'
+    return (os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.manifest'),
+            f'manifests/{guild_hash}-{channel_hash}')
 
 
 def get_manifest_seal_path(guild_id, channel_id):
     channel_hash, _ = hash_string(str(channel_id))
     guild_hash, _ = hash_string(str(guild_id))
-    return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.seal'), f'manifests/seals/{guild_hash}-{channel_hash}'
+    return (os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.seal'),
+            f'manifests/seals/{guild_hash}-{channel_hash}')
 
 
 async def backup_channel(channel, last_message_id):
@@ -259,10 +279,12 @@ async def backup_channel(channel, last_message_id):
     # download manifest from S3
     manifest_path, s3_manifest_path = get_manifest_path(channel.guild.id, channel.id)
 
-    if S3_ENABLED:
+    if is_s3_enabled():
         try:
-            s3.head_object(Bucket=S3_BUCKET, Key=s3_manifest_path)
-            s3.download_file(Bucket=S3_BUCKET, Filename=manifest_path, Key=s3_manifest_path)
+            S3_CLIENT.head_object(Bucket=s3_bucket(), Key=s3_manifest_path)
+            S3_CLIENT.download_file(Bucket=s3_bucket(),
+                                    Filename=manifest_path,
+                                    Key=s3_manifest_path)
         except ClientError as e:
             if e.response['Error']['Code'] == '404':
                 # The object does not exist.
@@ -288,11 +310,11 @@ async def backup_channel(channel, last_message_id):
     seal_manifest(channel.guild.id, channel.id)
 
     # upload to S3
-    if S3_ENABLED:
+    if is_s3_enabled():
         # upload manifest and seal
         manifest_seal_path, s3_manifest_seal_path = get_manifest_seal_path(channel.guild.id, channel.id)
-        s3.upload_file(manifest_path, S3_BUCKET, s3_manifest_path)
-        s3.upload_file(manifest_seal_path, S3_BUCKET, s3_manifest_seal_path)
+        S3_CLIENT.upload_file(manifest_path, s3_bucket(), s3_manifest_path)
+        S3_CLIENT.upload_file(manifest_seal_path, s3_bucket(), s3_manifest_seal_path)
 
     return after.id
 
@@ -306,7 +328,8 @@ def get_loc_path(channel):
     """
     guild_hash, _ = hash_string(str(channel.guild.id))
     channel_hash, _ = hash_string(str(channel.id))
-    return os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.loc'), f'locations/{guild_hash}-{channel_hash}'
+    return (os.path.join(get_ephemeral_path(), f'{guild_hash}-{channel_hash}.loc'),
+            f'locations/{guild_hash}-{channel_hash}')
 
 
 async def get_last_message_id(channel):
@@ -322,18 +345,17 @@ async def get_last_message_id(channel):
     loc_path, loc_s3_path = get_loc_path(channel)
 
     # If S3 is enabled go directly to S3
-    if S3_ENABLED:
+    if is_s3_enabled():
         try:
-            loc_obj = s3.get_object(Bucket=S3_BUCKET, Key=loc_s3_path)
+            loc_obj = S3_CLIENT.get_object(Bucket=s3_bucket(), Key=loc_s3_path)
             initial_content = loc_obj['Body'].read().decode()
             last_msg_id = int(initial_content)
         except ClientError:
             print("The location is non existent on S3.")
-            pass
     else:
         if os.path.exists(loc_path):
-            with open(loc_path, 'r') as f:
+            with open(loc_path, 'r', encoding='utf-8') as f:
                 f_content = f.read()
                 last_msg_id = int(f_content)
 
     return last_msg_id
@@ -348,10 +370,10 @@ async def set_last_message_id(channel, new_last_msg_id):
     :return: None
     """
     loc_path, loc_s3_path = get_loc_path(channel)
-    if S3_ENABLED:
-        s3.put_object(Bucket=S3_BUCKET, Key=loc_s3_path, Body=str(new_last_msg_id))
+    if is_s3_enabled():
+        S3_CLIENT.put_object(Bucket=s3_bucket(), Key=loc_s3_path, Body=str(new_last_msg_id))
     else:
-        with open(loc_path, "w") as file:
+        with open(loc_path, "w", encoding='utf-8') as file:
             file.write(str(new_last_msg_id))
 
 
@@ -409,33 +431,33 @@ def generate_directory_file(target_channels, current_datetime):
 
     if not enc_msg.ok:
         print(f'Encryption failed: {enc_msg.status}')
-        raise Exception('Unable to encrypt')
+        raise CryptographyError('Unable to encrypt')
 
     # Write directory to storage
-    if S3_ENABLED:
-        s3.put_object(Bucket=S3_BUCKET, Key=f'directories/{iso8601_format}', Body=str(enc_msg))
+    if is_s3_enabled():
+        S3_CLIENT.put_object(Bucket=s3_bucket(), Key=f'directories/{iso8601_format}', Body=str(enc_msg))
     else:
-        with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dir'), 'w') as file:
+        with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dir'), 'w', encoding='utf-8') as file:
            file.write(str(enc_msg))
    # generate seal
     _, manifest_hash = hash_string(str(enc_msg))
-    man_signature = signing_key.sign(manifest_hash).hex()
+    man_signature = SIGNING_KEY.sign(manifest_hash).hex()
 
     # write the directory to storage
-    if S3_ENABLED:
-        s3.put_object(Bucket=S3_BUCKET, Key=f'directories/seals/{iso8601_format}', Body=man_signature)
+    if is_s3_enabled():
+        S3_CLIENT.put_object(Bucket=s3_bucket(), Key=f'directories/seals/{iso8601_format}', Body=man_signature)
     else:
-        with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dirseal'), 'w') as seal_file:
+        with open(os.path.join(get_ephemeral_path(), f'{iso8601_format}.dirseal'), 'w', encoding='utf-8') as seal_file:
             seal_file.write(man_signature)
 
 
 def load_gpg_keys():
     print('Loading GPG keys...')
-    GPG_KEY_DIR = os.getenv('GPG_KEY_DIR')
-    if GPG_KEY_DIR is None:
-        raise Exception('No GPG key directory set. Please set GPG_KEY_DIR.')
-    key_files = glob.glob(os.path.join(GPG_KEY_DIR, '*.asc'))
+    gpg_key_dir = os.getenv('GPG_KEY_DIR')
+    if gpg_key_dir is None:
+        raise ConfigurationError('No GPG key directory set. Please set GPG_KEY_DIR.')
+    key_files = glob.glob(os.path.join(gpg_key_dir, '*.asc'))
     imported_keys = [gpg.import_keys_file(key_file) for key_file in key_files]
 
     # Trust keys
@@ -449,33 +471,43 @@ def load_gpg_keys():
     return [result.fingerprints[0] for result in imported_keys]
 
 
-if __name__ == '__main__':
-    # Main starting
-    print('EF Backup Bot starting...')
+def send_heartbeat(start=False):
+    url = os.getenv('HEARTBEAT_URL')
 
-    # Read config from env
-    TOKEN = os.getenv('DISCORD_TOKEN')
-    if TOKEN is None:
-        raise Exception('No discord token set. Please set DISCORD_TOKEN.')
-
-    HEARTBEAT_URL = os.getenv('HEARTBEAT_URL')
+    if url is None:
+        print('No HEARTBEAT_URL set. Monitoring disabled.')
+        return
+
+    if start:
+        url = url + '/start'
 
     # Send a start signal to heartbeat
     try:
-        requests.get(HEARTBEAT_URL + "/start", timeout=5)
+        requests.get(url, timeout=5)
     except requests.exceptions.RequestException:
         # If the network request fails for any reason, we don't want
         # it to prevent the main job from running
         pass
 
+
+if __name__ == '__main__':
+    # Main starting
+    print('EF Backup Bot starting...')
+
+    # Read config from env
+    TOKEN = os.getenv('DISCORD_TOKEN')
+    if TOKEN is None:
+        raise ConfigurationError('No discord token set. Please set DISCORD_TOKEN.')
+
+    send_heartbeat(start=True)
+
     # init S3
-    s3 = init_s3()
+    S3_CLIENT = init_s3()
 
     # prepare gpg
     key_fingerprints = load_gpg_keys()
 
     # load the signing key
-    signing_key = get_signing_key()
+    SIGNING_KEY = get_signing_key()
 
     # Prepare discord connection
     intents = nextcord.Intents.default()
@@ -507,37 +539,24 @@ async def on_ready():
             print(f'Backing up Channel {channel.id} on {channel.guild.id}')
 
             # Backup channels
-            try:
-                last_msg_id = await get_last_message_id(channel)
-                new_last_msg_id = await backup_channel(channel, last_msg_id)
-                await set_last_message_id(channel, new_last_msg_id)
-
-            except Exception as e:
-                print(f'Unable to backup: {e}')
+            last_msg_id = await get_last_message_id(channel)
+            new_last_msg_id = await backup_channel(channel, last_msg_id)
+            await set_last_message_id(channel, new_last_msg_id)
 
             # Backup threads in channel
             for thread in channel.threads:
                 print(f'Backing up Thread {thread.id} in Channel {channel.id} on {channel.guild.id}')
-                try:
-                    last_msg_id = await get_last_message_id(thread)
-                    new_last_msg_id = await backup_channel(thread, last_msg_id)
-                    await set_last_message_id(thread, new_last_msg_id)
-
-                except Exception as e:
-                    print(f'Unable to backup: {e}')
+                last_msg_id = await get_last_message_id(thread)
+                new_last_msg_id = await backup_channel(thread, last_msg_id)
+                await set_last_message_id(thread, new_last_msg_id)
 
         # Quit when done
         print('Notifying the heartbeat check...')
-        try:
-            requests.get(HEARTBEAT_URL, timeout=10)
-        except requests.exceptions.RequestException:
-            # If the network request fails for any reason, we don't want
-            # it to prevent the main job from running
-            pass
+        send_heartbeat()
 
         print('Done. exiting.')
         await client.close()
 
 
     # run the bot - still in main
-    client.run(TOKEN)
\ No newline at end of file
+    client.run(TOKEN)
diff --git a/requirements.txt b/requirements.txt
index 8605a5e..87f61b7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 aiohttp==3.9.1
 aiosignal==1.3.1
+astroid==3.0.2
 async-timeout==4.0.3
 attrs==23.2.0
 boto3==1.34.15
@@ -7,19 +8,33 @@ botocore==1.34.15
 certifi==2023.11.17
 cffi==1.16.0
 charset-normalizer==3.3.2
+coverage==7.4.0
 cryptography==41.0.7
+dill==0.3.7
+exceptiongroup==1.2.0
 frozenlist==1.4.1
 idna==3.6
+iniconfig==2.0.0
+isort==5.13.2
 jmespath==1.0.1
+mccabe==0.7.0
 multidict==6.0.4
 nextcord==2.6.0
+packaging==23.2
+platformdirs==4.1.0
+pluggy==1.3.0
 pycparser==2.21
+pylint==3.0.3
+pytest==7.4.4
+pytest-asyncio==0.23.3
+pytest-mock==3.12.0
 python-dateutil==2.8.2
 python-gnupg==0.5.2
 requests==2.31.0
 s3transfer==0.10.0
 six==1.16.0
+tomli==2.0.1
+tomlkit==0.12.3
 typing_extensions==4.9.0
 urllib3==1.26.18
-yarl==1.9.4
-pytest~=7.4.4
\ No newline at end of file
+yarl==1.9.4
\ No newline at end of file
diff --git a/test_main.py b/test_main.py
index bee870a..9f35959 100644
--- a/test_main.py
+++ b/test_main.py
@@ -73,7 +73,7 @@ def test_get_loc_s3(self,mocker,channel_id, guild_id, expected_path_loc, monkeyp
             './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f.manifest'
         ]
     ])
-    def test_get_manifest_path(self,mocker,channel_id, guild_id, expected_path_loc):
+    def test_get_manifest_path(self,channel_id, guild_id, expected_path_loc):
         main.EPHEMERAL_PATH = "./"
         path_loc, _ = main.get_manifest_path(guild_id, channel_id)
         assert path_loc == expected_path_loc
@@ -90,7 +90,7 @@ def test_get_manifest_path(self,mocker,channel_id, guild_id, expected_path_loc):
             'manifests/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f'
         ]
     ])
-    def test_get_manifest_s3(self,mocker,channel_id, guild_id, expected_path_loc):
+    def test_get_manifest_s3(self,channel_id, guild_id, expected_path_loc):
         main.EPHEMERAL_PATH = "./"
         _, s3_loc = main.get_manifest_path(guild_id, channel_id)
         assert s3_loc == expected_path_loc
@@ -107,7 +107,7 @@ def test_get_manifest_s3(self,mocker,channel_id, guild_id, expected_path_loc):
             './38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f.seal'
         ]
     ])
-    def test_get_manifest_seal_path(self, mocker, channel_id, guild_id, expected_path_loc):
+    def test_get_manifest_seal_path(self, channel_id, guild_id, expected_path_loc):
         main.EPHEMERAL_PATH = "./"
         path_loc, _ = main.get_manifest_seal_path(guild_id, channel_id)
         assert path_loc == expected_path_loc
@@ -124,7 +124,7 @@ def test_get_manifest_seal_path(self, mocker, channel_id, guild_id, expected_pat
             'manifests/seals/38db8582a710659583b9e1878e5c4fa96983292b13908824dd64dba7737de6e4-e88257514b2e0049279d50793b0ff55d6da824d793093ee47df1925205b6385f'
         ]
     ])
-    def test_get_manifest_seal_s3(self, mocker, channel_id, guild_id, expected_path_loc):
+    def test_get_manifest_seal_s3(self, channel_id, guild_id, expected_path_loc):
         main.EPHEMERAL_PATH = "./"
         _, s3_loc = main.get_manifest_seal_path(guild_id, channel_id)
         assert s3_loc == expected_path_loc
@@ -136,33 +136,34 @@ class TestConfiguration:
         ["1", "the-bucket", "12345-access", "12345-secret", "http://localhost:9000"],
     ])
     def test_successful_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint):
-        # prep env
-        monkeypatch.setenv("S3_ENABLED", s3enabled)
-        monkeypatch.setenv("S3_BUCKET", s3bucket)
-        monkeypatch.setenv("S3_ACCESS_KEY", s3accesskey)
-        monkeypatch.setenv("S3_SECRET_KEY", s3secretkey)
-        monkeypatch.setenv("S3_ENDPOINT", s3endpoint)
+        with monkeypatch.context() as m:
+            # prep env
+            m.setenv("S3_ENABLED", s3enabled)
+            m.setenv("S3_BUCKET", s3bucket)
+            m.setenv("S3_ACCESS_KEY", s3accesskey)
+            m.setenv("S3_SECRET_KEY", s3secretkey)
+            m.setenv("S3_ENDPOINT", s3endpoint)
 
-        s3_obj = main.init_s3()
+            s3_obj = main.init_s3()
 
-        assert s3_obj is not None
-        assert s3_obj.meta.endpoint_url == s3endpoint
-        assert s3_obj.meta.region_name == "us-east-1"
+            assert s3_obj is not None
+            assert s3_obj.meta.endpoint_url == s3endpoint
+            assert s3_obj.meta.region_name == "us-east-1"
 
     @pytest.mark.parametrize("s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint", [
         ["0", "the-bucket", "12345-access", "12345-secret", "http://localhost:9000"],
     ])
     def test_disabled_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint):
-        # prep env
-        monkeypatch.setenv("S3_ENABLED", s3enabled)
-        monkeypatch.setenv("S3_BUCKET", s3bucket)
-        monkeypatch.setenv("S3_ACCESS_KEY", s3accesskey)
-        monkeypatch.setenv("S3_SECRET_KEY", s3secretkey)
-        monkeypatch.setenv("S3_ENDPOINT", s3endpoint)
+        with monkeypatch.context() as m:
+            # prep env
+            m.setenv("S3_ENABLED", s3enabled)
+            m.setenv("S3_BUCKET", s3bucket)
+            m.setenv("S3_ACCESS_KEY", s3accesskey)
+            m.setenv("S3_SECRET_KEY", s3secretkey)
+            m.setenv("S3_ENDPOINT", s3endpoint)
 
-        s3_obj = main.init_s3()
-
-        assert s3_obj is None
+            s3_obj = main.init_s3()
+            assert s3_obj is None
 
     @pytest.mark.parametrize("s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint, expException", [
         ["1", None, "12345-access", "12345-secret", "http://localhost:9000", "S3 enabled but S3_BUCKET not set."],
@@ -171,19 +172,20 @@ def test_disabled_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s
         ["1", "bucket", "12345-access", "12345-secret", None, "S3 enabled but S3_ENDPOINT not set."],
     ])
     def test_failed_s3_init(self, monkeypatch, s3enabled, s3bucket, s3accesskey, s3secretkey, s3endpoint, expException):
-        # prep env
-        monkeypatch.setenv("S3_ENABLED", s3enabled)
-        if s3bucket is not None:
-            monkeypatch.setenv("S3_BUCKET", s3bucket)
-        if s3accesskey is not None:
-            monkeypatch.setenv("S3_ACCESS_KEY", s3accesskey)
-        if s3secretkey is not None:
-            monkeypatch.setenv("S3_SECRET_KEY", s3secretkey)
-        if s3endpoint is not None:
-            monkeypatch.setenv("S3_ENDPOINT", s3endpoint)
-
-        with pytest.raises(Exception, match=expException):
-            main.init_s3()
+        with monkeypatch.context() as m:
+            # prep env
+            m.setenv("S3_ENABLED", s3enabled)
+            if s3bucket is not None:
+                m.setenv("S3_BUCKET", s3bucket)
+            if s3accesskey is not None:
+                m.setenv("S3_ACCESS_KEY", s3accesskey)
+            if s3secretkey is not None:
+                m.setenv("S3_SECRET_KEY", s3secretkey)
+            if s3endpoint is not None:
+                m.setenv("S3_ENDPOINT", s3endpoint)
+
+            with pytest.raises(Exception, match=expException):
+                main.init_s3()
 
 
 class TestMisc:
@@ -196,22 +198,23 @@ def test_hash_string(self, hash_input, expected):
 
     @pytest.mark.asyncio
     async def test_last_message_id(self,monkeypatch,tmp_path,mocker):
-        msg_id = 123456543453
-        monkeypatch.setenv("EPHEMERAL_PATH", str(tmp_path.absolute()))
-        channel_mock = mocker.Mock()
-        channel_mock.id = 1234
-        channel_mock.guild.id = 5678
+        with monkeypatch.context() as m:
+            msg_id = 123456543453
+            m.setenv("EPHEMERAL_PATH", str(tmp_path.absolute()))
+            channel_mock = mocker.Mock()
+            channel_mock.id = 1234
+            channel_mock.guild.id = 5678
 
-        # Get empty id
-        empty_channel = await main.get_last_message_id(channel_mock)
-        assert empty_channel == -1
+            # Get empty id
+            empty_channel = await main.get_last_message_id(channel_mock)
+            assert empty_channel == -1
 
-        # store an id
-        await main.set_last_message_id(channel_mock, msg_id)
+            # store an id
+            await main.set_last_message_id(channel_mock, msg_id)
 
-        # recall id
-        non_empty_channel = await main.get_last_message_id(channel_mock)
-        assert non_empty_channel == msg_id
+            # recall id
+            non_empty_channel = await main.get_last_message_id(channel_mock)
+            assert non_empty_channel == msg_id
 
 class TestMessage:
     def test_extract_basic_message(self, mocker):
@@ -331,226 +334,231 @@ def test_extract_message_with_attachment(self, mocker):
 
 class TestSigning:
     def test_load_sign_key(self, tmp_path, monkeypatch):
-        key_file = tmp_path / 'priv_key.pem'
-        monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+        with monkeypatch.context() as m:
+            key_file = tmp_path / 'priv_key.pem'
+            m.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
 
-        # Should generate a new signing key
-        priv_key = main.get_signing_key()
+            # Should generate a new signing key
+            priv_key = main.get_signing_key()
 
-        # check ifs an ED25519 private key
-        assert isinstance(priv_key, ed25519.Ed25519PrivateKey)
+            # check ifs an ED25519 private key
+            assert isinstance(priv_key, ed25519.Ed25519PrivateKey)
 
-        # check if the pubkey was generated
-        assert os.path.exists(str(key_file.absolute()).replace('.pem','.pub'))
+            # check if the pubkey was generated
+            assert os.path.exists(str(key_file.absolute()).replace('.pem','.pub'))
 
-        # check if the perms are correct
-        key_file_stat = os.stat(str(key_file.absolute()))
-        assert key_file_stat.st_mode == 0o100600
+            # check if the perms are correct
+            key_file_stat = os.stat(str(key_file.absolute()))
+            assert key_file_stat.st_mode == 0o100600
 
-        # Should load the key from disk
-        priv_key_second = main.get_signing_key()
+            # Should load the key from disk
+            priv_key_second = main.get_signing_key()
 
-        # should be the same key - compare pubkeys
-        assert priv_key.public_key() == priv_key_second.public_key()
+            # should be the same key - compare pubkeys
+            assert priv_key.public_key() == priv_key_second.public_key()
 
     @pytest.mark.parametrize('filename, expException',[
         ['key.foo', 'Signing key file not a pem file. make sure the extension is pem.'],
         [None, 'Signing key not configured. Please set SIGN_KEY_PEM.']
     ])
     def test_load_sign_key_fail(self, tmp_path, monkeypatch, filename, expException):
-        if filename is not None:
-            key_file = tmp_path / filename
-            monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+        with monkeypatch.context() as m:
+            if filename is not None:
+                key_file = tmp_path / filename
+                m.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
 
-        # Should generate a new signing key
-        with pytest.raises(Exception, match=expException):
-            main.get_signing_key()
+            # Should generate a new signing key
+            with pytest.raises(Exception, match=expException):
+                main.get_signing_key()
 
 
 class TestStorage:
     def test_write_to_storage_fail_encrypt_no_keys(self, tmp_path, monkeypatch):
-        # message doesn't need to conform to format completely
-        dt_iso = datetime.datetime.utcnow().isoformat()
-        msg = {
-            'content': 'hello world',
-            'author': 'Testy McTestface',
-            'server': {
-                'id': 1234
-            },
-            'channel': {
-                'id': 1234
-            },
-            'created_at': dt_iso
-        }
-
-        # set the ephemeral path
-        monkeypatch.setenv('EPHEMERAL_PATH', str(tmp_path))
-
-        # No GPG keys loaded so it should fail here
-        with pytest.raises(Exception, match="No recipients specified with asymmetric encryption"):
-            main.write_to_storage(msg)
+        with monkeypatch.context() as m:
+            # message doesn't need to conform to format completely
+            dt_iso = datetime.datetime.utcnow().isoformat()
+            msg = {
+                'content': 'hello world',
+                'author': 'Testy McTestface',
+                'server': {
+                    'id': 1234
+                },
+                'channel': {
+                    'id': 1234
+                },
+                'created_at': dt_iso
+            }
+
+            # set the ephemeral path
+            m.setenv('EPHEMERAL_PATH', str(tmp_path))
+
+            # No GPG keys loaded so it should fail here
+            with pytest.raises(Exception, match="No recipients specified with asymmetric encryption"):
+                main.write_to_storage(msg)
 
     def test_write_to_storage(self, tmp_path, monkeypatch):
-        id = 1234
-        hash_for_id = "03ac674216f3e15c761ee1a5e255f067953623c8b388b4459e13f978d7c846f4"
-
-        # message doesn't need to conform to format completely
-        dt_iso = datetime.datetime.utcnow().isoformat()
-        msg = {
-            'content': 'hello world',
-            'author': 'Testy McTestface',
-            'server': {
-                'id': id
-            },
-            'channel': {
-                'id': id
-            },
-            'created_at': dt_iso
-        }
-
-        # set the ephemeral path
-        monkeypatch.setenv('EPHEMERAL_PATH', str(tmp_path))
-
-        # generate a signing key
-        key_file = tmp_path / 'priv_key.pem'
-        monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
-        main.signing_key = main.get_signing_key()
-
-        # load a GPG key
-        pub_key = '''-----BEGIN PGP PUBLIC KEY BLOCK-----
-
-mDMEZZ6XXxYJKwYBBAHaRw8BAQdAea3323zBNgy12RVKkCWWgfDe5vSLW3R9/6LS
-pqE/hxG0MUdQRyB0ZXN0IGtleSAoT05MWSBGT1IgVEVTVElORykgPG1hcmt1c0B0
-ZXN0Lm9yZz6IkwQTFgoAOxYhBMuek9p7pwAmbyYg1qWXo028DaarBQJlnpdfAhsD
-BQsJCAcCAiICBhUKCQgLAgQWAgMBAh4HAheAAAoJEKWXo028DaarUU8BAOyAmxed
-yWBHajYaEoyn0wfSEGIFVCXatsvcbYpL6hc+AQCrn/t+oC/OqrO4HWPhQDAEgYtW
-9TWOC3A6CYyodYdPD7g4BGWel18SCisGAQQBl1UBBQEBB0DLccDTMTVh0a7Su94Z
-ktDBAzTjYzQ5j2sxKe/OkK2VGQMBCAeIeAQYFgoAIBYhBMuek9p7pwAmbyYg1qWX
-o028DaarBQJlnpdfAhsMAAoJEKWXo028DaarD+EA/0SIgap5bj9FqE+TwVNILLuO
-UiwX/3AQaMi36RJ9oZYKAP9gIkwaL/m0Xu8WQiUNkATCHFsmauptqQw5V8GkSp0l
-Ag==
-=IhBg
------END PGP PUBLIC KEY BLOCK-----
-'''
-
-        monkeypatch.setenv('GPG_KEY_DIR', str(tmp_path.absolute()))
-        gpg_pub_key = tmp_path / 'gpg_pub_key.asc'
-        gpg_pub_key.write_text(pub_key)
-
-        main.key_fingerprints = main.load_gpg_keys()
-
-        main.write_to_storage(msg)
-
-        # a manifest should be in the tmp path now
-        result_manifest_path = tmp_path / f'{hash_for_id}-{hash_for_id}.manifest'
-        assert os.path.exists(result_manifest_path)
-
-        # read manifest
-        with open(result_manifest_path, 'r') as f:
-            manifest_content = f.read()
-            manifest_fields = manifest_content.split(',')
-
-        # we expect 3 fields in the manifest
-        assert len(manifest_fields) == 3
-
-        # we expect the first to be the iso date of the message
-        assert manifest_fields[0] == dt_iso
-
-        # check if the msg file exists
-        msg_hash = manifest_fields[1]
-
-        # Check if the filename is correct
-        msg_path = tmp_path / f'{msg_hash}.msg'
-        assert os.path.exists(str(msg_path.absolute()))
+        with monkeypatch.context() as m:
+            test_id = 1234
+            hash_for_id = "03ac674216f3e15c761ee1a5e255f067953623c8b388b4459e13f978d7c846f4"
+
+            # message doesn't need to conform to format completely
+            dt_iso = datetime.datetime.utcnow().isoformat()
+            msg = {
+                'content': 'hello world',
+                'author': 'Testy McTestface',
+                'server': {
+                    'id': test_id
+                },
+                'channel': {
+                    'id': test_id
+                },
+                'created_at': dt_iso
+            }
+
+            # set the ephemeral path
+            m.setenv('EPHEMERAL_PATH', str(tmp_path))
+
+            # generate a signing key
+            key_file = tmp_path / 'priv_key.pem'
+            m.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+            main.SIGNING_KEY = main.get_signing_key()
+
+            # load a GPG key
+            pub_key = '''-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+            mDMEZZ6XXxYJKwYBBAHaRw8BAQdAea3323zBNgy12RVKkCWWgfDe5vSLW3R9/6LS
+            pqE/hxG0MUdQRyB0ZXN0IGtleSAoT05MWSBGT1IgVEVTVElORykgPG1hcmt1c0B0
+            ZXN0Lm9yZz6IkwQTFgoAOxYhBMuek9p7pwAmbyYg1qWXo028DaarBQJlnpdfAhsD
+            BQsJCAcCAiICBhUKCQgLAgQWAgMBAh4HAheAAAoJEKWXo028DaarUU8BAOyAmxed
+            yWBHajYaEoyn0wfSEGIFVCXatsvcbYpL6hc+AQCrn/t+oC/OqrO4HWPhQDAEgYtW
+            9TWOC3A6CYyodYdPD7g4BGWel18SCisGAQQBl1UBBQEBB0DLccDTMTVh0a7Su94Z
+            ktDBAzTjYzQ5j2sxKe/OkK2VGQMBCAeIeAQYFgoAIBYhBMuek9p7pwAmbyYg1qWX
+            o028DaarBQJlnpdfAhsMAAoJEKWXo028DaarD+EA/0SIgap5bj9FqE+TwVNILLuO
+            UiwX/3AQaMi36RJ9oZYKAP9gIkwaL/m0Xu8WQiUNkATCHFsmauptqQw5V8GkSp0l
+            Ag==
+            =IhBg
+            -----END PGP PUBLIC KEY BLOCK-----
+            '''
+
+            m.setenv('GPG_KEY_DIR', str(tmp_path.absolute()))
+            gpg_pub_key = tmp_path / 'gpg_pub_key.asc'
+            gpg_pub_key.write_text(pub_key)
+
+            main.key_fingerprints = main.load_gpg_keys()
+
+            main.write_to_storage(msg)
+
+            # a manifest should be in the tmp path now
+            result_manifest_path = tmp_path / f'{hash_for_id}-{hash_for_id}.manifest'
+            assert os.path.exists(result_manifest_path)
+
+            # read manifest
+            with open(result_manifest_path, 'r') as f:
+                manifest_content = f.read()
+                manifest_fields = manifest_content.split(',')
+
+            # we expect 3 fields in the manifest
+            assert len(manifest_fields) == 3
+
+            # we expect the first to be the iso date of the message
+            assert manifest_fields[0] == dt_iso
+
+            # check if the msg file exists
+            msg_hash = manifest_fields[1]
+
+            # Check if the filename is correct
+            msg_path = tmp_path / f'{msg_hash}.msg'
+            assert os.path.exists(str(msg_path.absolute()))
 
     def test_write_directory_file(self, tmp_path, monkeypatch, mocker):
-        # disable the test for nextcord.TextChannel
-        mocker.patch('__main__.isinstance', return_value=True)
-
-        # set the ephemeral path
-        monkeypatch.setenv('EPHEMERAL_PATH', str(tmp_path))
-
-        # generate a signing key
-        key_file = tmp_path / 'priv_key.pem'
-        monkeypatch.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
-        main.signing_key = main.get_signing_key()
-
-        # load a GPG key
-        pub_key = '''-----BEGIN PGP PUBLIC KEY BLOCK-----
-
-        mDMEZZ6XXxYJKwYBBAHaRw8BAQdAea3323zBNgy12RVKkCWWgfDe5vSLW3R9/6LS
-        pqE/hxG0MUdQRyB0ZXN0IGtleSAoT05MWSBGT1IgVEVTVElORykgPG1hcmt1c0B0
-        ZXN0Lm9yZz6IkwQTFgoAOxYhBMuek9p7pwAmbyYg1qWXo028DaarBQJlnpdfAhsD
-        BQsJCAcCAiICBhUKCQgLAgQWAgMBAh4HAheAAAoJEKWXo028DaarUU8BAOyAmxed
-        yWBHajYaEoyn0wfSEGIFVCXatsvcbYpL6hc+AQCrn/t+oC/OqrO4HWPhQDAEgYtW
-        9TWOC3A6CYyodYdPD7g4BGWel18SCisGAQQBl1UBBQEBB0DLccDTMTVh0a7Su94Z
-        ktDBAzTjYzQ5j2sxKe/OkK2VGQMBCAeIeAQYFgoAIBYhBMuek9p7pwAmbyYg1qWX
-        o028DaarBQJlnpdfAhsMAAoJEKWXo028DaarD+EA/0SIgap5bj9FqE+TwVNILLuO
-        UiwX/3AQaMi36RJ9oZYKAP9gIkwaL/m0Xu8WQiUNkATCHFsmauptqQw5V8GkSp0l
-        Ag==
-        =IhBg
-        -----END PGP PUBLIC KEY BLOCK-----
-        '''
-
-        monkeypatch.setenv('GPG_KEY_DIR', str(tmp_path.absolute()))
-        gpg_pub_key = tmp_path / 'gpg_pub_key.asc'
-        gpg_pub_key.write_text(pub_key)
-
-        main.key_fingerprints = main.load_gpg_keys()
-
-        # the server list mock from discord
-        target_channels = []
-
-        channel1 = mocker.Mock(spec=nextcord.TextChannel)
-        channel1.guild.id = 1111
-        channel1.guild.name = "Server 1"
-        channel1.id = 110011
-        channel1.name = "Channel 1"
-        channel1.threads = []
-        target_channels.append(channel1)
-
-        channel2 = mocker.Mock(spec=nextcord.TextChannel)
-        channel2.guild.id = 1111
-        channel2.guild.name = "Server 1"
-        channel2.id = 220011
-        channel2.name = "Channel 2"
-        channel2.threads = []
-        target_channels.append(channel2)
-
-        channel3 = mocker.Mock(spec=nextcord.TextChannel)
-        channel3.guild.id = 2222
-        channel3.guild.name = "Server 2"
-        channel3.id = 220022
-        channel3.name = "Channel 1"
-        channel3.threads = []
-        target_channels.append(channel3)
-
-        # with threads
-        thread1 = mocker.Mock()
-        thread1.name = "Thread 1"
-        thread1.id = 121212
-
-        channel4 = mocker.Mock(spec=nextcord.TextChannel)
-        channel4.guild.id = 2222
-        channel4.guild.name = "Server 2"
-        channel4.id = 330022
-        channel4.name = "Channel 2"
-        channel4.threads = [thread1]
-        target_channels.append(channel4)
-
-        channel5 = mocker.Mock(spec=nextcord.VoiceChannel)
-        target_channels.append(channel5)
-
-
-        dt = datetime.datetime.now()
-        iso8601_format = dt.isoformat().replace(':', '-').replace('.', '-')
-
-        main.generate_directory_file(target_channels, dt)
-
-        # check if a directory was written
-        result_manifest_path = tmp_path / f'{iso8601_format}.dir'
-        assert os.path.exists(result_manifest_path)
-
-        # check if a directory seal was written
-        result_manifest_path = tmp_path / f'{iso8601_format}.dirseal'
-        assert os.path.exists(result_manifest_path)
\ No newline at end of file
+        with monkeypatch.context() as m:
+            # disable the test for nextcord.TextChannel
+            mocker.patch('__main__.isinstance', return_value=True)
+
+            # set the ephemeral path
+            m.setenv('EPHEMERAL_PATH', str(tmp_path))
+
+            # generate a signing key
+            key_file = tmp_path / 'priv_key.pem'
+            m.setenv('SIGN_KEY_PEM', str(key_file.absolute()))
+            main.SIGNING_KEY = main.get_signing_key()
+
+            # load a GPG key
+            pub_key = '''-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+            mDMEZZ6XXxYJKwYBBAHaRw8BAQdAea3323zBNgy12RVKkCWWgfDe5vSLW3R9/6LS
+            pqE/hxG0MUdQRyB0ZXN0IGtleSAoT05MWSBGT1IgVEVTVElORykgPG1hcmt1c0B0
+            ZXN0Lm9yZz6IkwQTFgoAOxYhBMuek9p7pwAmbyYg1qWXo028DaarBQJlnpdfAhsD
+            BQsJCAcCAiICBhUKCQgLAgQWAgMBAh4HAheAAAoJEKWXo028DaarUU8BAOyAmxed
+            yWBHajYaEoyn0wfSEGIFVCXatsvcbYpL6hc+AQCrn/t+oC/OqrO4HWPhQDAEgYtW
+            9TWOC3A6CYyodYdPD7g4BGWel18SCisGAQQBl1UBBQEBB0DLccDTMTVh0a7Su94Z
+            ktDBAzTjYzQ5j2sxKe/OkK2VGQMBCAeIeAQYFgoAIBYhBMuek9p7pwAmbyYg1qWX
+            o028DaarBQJlnpdfAhsMAAoJEKWXo028DaarD+EA/0SIgap5bj9FqE+TwVNILLuO
+            UiwX/3AQaMi36RJ9oZYKAP9gIkwaL/m0Xu8WQiUNkATCHFsmauptqQw5V8GkSp0l
+            Ag==
+            =IhBg
+            -----END PGP PUBLIC KEY BLOCK-----
+            '''
+
+            m.setenv('GPG_KEY_DIR', str(tmp_path.absolute()))
+            gpg_pub_key = tmp_path / 'gpg_pub_key.asc'
+            gpg_pub_key.write_text(pub_key)
+
+            main.key_fingerprints = main.load_gpg_keys()
+
+            # the server list mock from discord
+            target_channels = []
+
+            channel1 = mocker.Mock(spec=nextcord.TextChannel)
+            channel1.guild.id = 1111
+            channel1.guild.name = "Server 1"
+            channel1.id = 110011
+            channel1.name = "Channel 1"
+            channel1.threads = []
+            target_channels.append(channel1)
+
+            channel2 = mocker.Mock(spec=nextcord.TextChannel)
+            channel2.guild.id = 1111
+            channel2.guild.name = "Server 1"
+            channel2.id = 220011
+            channel2.name = "Channel 2"
+            channel2.threads = []
+            target_channels.append(channel2)
+
+            channel3 = mocker.Mock(spec=nextcord.TextChannel)
+            channel3.guild.id = 2222
+            channel3.guild.name = "Server 2"
+            channel3.id = 220022
+            channel3.name = "Channel 1"
+            channel3.threads = []
+            target_channels.append(channel3)
+
+            # with threads
+            thread1 = mocker.Mock()
+            thread1.name = "Thread 1"
+            thread1.id = 121212
+
+            channel4 = mocker.Mock(spec=nextcord.TextChannel)
+            channel4.guild.id = 2222
+            channel4.guild.name = "Server 2"
+            channel4.id = 330022
+            channel4.name = "Channel 2"
+            channel4.threads = [thread1]
+            target_channels.append(channel4)
+
+            channel5 = mocker.Mock(spec=nextcord.VoiceChannel)
+            target_channels.append(channel5)
+
+            dt = datetime.datetime.now()
+            iso8601_format = dt.isoformat().replace(':', '-').replace('.', '-')
+
+            main.generate_directory_file(target_channels, dt)
+
+            # check if a directory was written
+            result_manifest_path = tmp_path / f'{iso8601_format}.dir'
+            assert os.path.exists(result_manifest_path)
+
+            # check if a directory seal was written
+            result_manifest_path = tmp_path / f'{iso8601_format}.dirseal'
+            assert os.path.exists(result_manifest_path)
+