Skip to content

Commit

Permalink
Merge pull request #1 from AaltoRSE/db_interfaces
Browse files Browse the repository at this point in the history
Db interfaces
  • Loading branch information
tpfau authored Nov 29, 2023
2 parents 7d5a634 + d3a2f3c commit af988b3
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,6 @@ bundles/
vendor/pkg/
pyenv
Vagrantfile

# vscode stuff
.vscode/**/*
140 changes: 140 additions & 0 deletions app/utils/key_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import redis
import pymongo
import secrets
import string
import os
import urllib


class KeyHandler:
def __init__(self, testing: bool = False):
if not testing:
# Needs to be escaped if necessary
mongo_user = urllib.parse.quote_plus(os.environ.get("MONGOUSER"))
mongo_password = urllib.parse.quote_plus(os.environ.get("MONGOPASSWORD"))

# Set up required endpoints. The
mongo_client = pymongo.MongoClient(
"mongodb://%s:%s@mongo:27017/" % (mongo_user, mongo_password)
)
redis_client = redis.StrictRedis(host="redis", port=6379, db=0)
self.setup(mongo_client, redis_client)

def setup(self, mongo_client: pymongo.MongoClient, redis_client: redis.Redis):
self.redis_client = redis_client
self.mongo_client = mongo_client
self.db = mongo_client["gateway"]
self.key_collection = self.db["apikeys"]
keyindices = self.key_collection.index_information()
# Make sure, that key is an index (avoids duplicates);
if not "key" in keyindices:
self.key_collection.create_index("key", unique=True)
self.user_collection = self.db["users"]
# Make sure, that username is an index (avoids duplicates when creating keys, which automatically adds a user if necessary);
userindices = self.user_collection.index_information()
if not "username" in userindices:
self.user_collection.create_index("username", unique=True)

def generate_api_key(self, length: int = 64):
"""
Function to generate an API key.
Parameters:
- length (int, optional): Length of the generated API key. Defaults to 64.
Returns:
- str: The generated API key.
"""
alphabet = string.ascii_letters + string.digits
api_key = "".join(secrets.choice(alphabet) for _ in range(length))
return api_key

def build_new_key_object(self, key: string, name: string):
"""
Function to create a new key object.
Parameters:
- key (str): The key value.
- name (str): The name associated with the key.
Returns:
- dict: A dictionary representing the key object with "active" status, key, and name.
"""
return {"active": True, "key": key, "name": name}

def check_key(self, key: string):
"""
Function to check if a key currently exists
Parameters:
- key (str): The key to check.
Returns:
- bool: True if the key exists
"""
return self.redis_client.sismember("keys", key)

def delete_key(self, key: string, user: string):
"""
Function to delete an existing key
Parameters:
- key (str): The key to check.
- user (str): The user that requests this deletion
"""
updated_user = self.user_collection.find_one_and_update(
{"username": user, "keys": {"$elemMatch": {"$eq": key}}},
{"$pull": {"keys": key}},
)
if not updated_user == None:
# We found, and updated the user, so we can remove the key
# removal should be instantaneous
self.key_collection.delete_one({"key": key})
self.redis_client.srem("keys", key)

def set_key_activity(self, key: string, user: string, active: bool):
"""
Function to set whether a key is active or not.
The key has to be owned by the user indicated.
Parameters:
- key (str): The key to check.
- user (str): The user that requests this deletion
- active (bool): whether to activate or deactivate the key
"""
user_has_key = self.user_collection.find_one(
{"username": user, "keys": {"$elemMatch": {"$eq": key}}}
)
if not user_has_key == None:
# the requesting user has access to this key
self.key_collection.update_one({"key": key}, {"$set": {"active": active}})
if active:
self.redis_client.sadd("keys", key)
else:
self.redis_client.srem("keys", key)

def create_key(self, user: string, name: string):
"""
Generates a unique API key and associates it with a specified user.
Args:
- user: Username of the user to whom the API key will be associated.
- name: Name or label for the API key.
Returns:
- api_key: The generated unique API key associated with the user.
"""
key_created = False
api_key = ""
while not key_created:
api_key = self.generate_api_key()
found = self.key_collection.find_one({"key": api_key})
if found == None:
self.key_collection.insert_one(self.build_new_key_object(api_key, name))
self.user_collection.update_one(
{"username": user}, {"$addToSet": {"keys": api_key}}, upsert=True
)
self.redis_client.sadd("keys", api_key)
key_created = True
return api_key
79 changes: 79 additions & 0 deletions app/utils/logging_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pymongo
from datetime import datetime
import os
import urllib


# Needs to be escaped if necessary
class LoggingHandler:
def __init__(self, testing=False):
if not testing:
mongo_user = urllib.parse.quote_plus(os.environ.get("MONGOUSER"))
mongo_password = urllib.parse.quote_plus(os.environ.get("MONGOPASSWORD"))
# Set up required endpoints.
mongo_client = pymongo.MongoClient(
"mongodb://%s:%s@mongo:27017/" % (mongo_user, mongo_password)
)
self.setup(mongo_client)

def setup(self, mongo_client):
self.db = mongo_client["gateway"]
self.log_collection = self.db["logs"]
self.user_collection = self.db["users"]

def create_log_entry(self, tokencount, model, source, sourcetype="apikey"):
"""
Function to create a log entry.
Parameters:
- tokencount (int): The count of tokens.
- model (str): The model related to the log entry.
- source (str): The source that authorized the request that is being logged. This could be a user name or an apikey.
- sourcetype (str): Specification of what kind of source authorized the request that is being logged (either 'apikey' or 'user').
Returns:
- dict: A dictionary representing the log entry with timestamp.
"""
return {
"tokencount": tokencount,
"model": model,
"source": source,
"sourcetype": sourcetype,
"timestamp": datetime.utcnow(), # Current timestamp in UTC
}

def log_usage_for_key(self, tokencount, model, key):
"""
Function to log usage for a specific key.
Parameters:
- tokencount (int): The count of tokens used.
- model (str): The model associated with the usage.
- key (str): The key for which the usage is logged.
"""
log_entry = self.create_log_entry(tokencount, model, key)
self.log_collection.insert_one(log_entry)

def log_usage_for_user(self, tokencount, model, user):
"""
Function to log usage for a specific user.
Parameters:
- tokencount (int): The count of tokens used.
- model (str): The model associated with the usage.
- user (str): The user for which the usage is logged.
"""
log_entry = self.create_log_entry(tokencount, model, user, "user")
self.log_collection.insert_one(log_entry)

def get_usage_for_user(self, username):
raise NotImplementedError

def get_usage_for_key(self, key):
raise NotImplementedError

def get_usage_for_model(self, model):
raise NotImplementedError

def get_usage_for_timerange(self, start, end):
raise NotImplementedError
68 changes: 68 additions & 0 deletions app/utils/redis_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import redis
import pymongo
import schedule
import time
import urllib
import os

# Needs to be escaped if necessary
mongo_user = urllib.parse.quote_plus(os.environ.get("MONGOUSER"))
mongo_password = urllib.parse.quote_plus(os.environ.get("MONGOPASSWORD"))


class RedisUpdater:
def __init__(self):
"""
Initializes Redis and MongoDB connections.
"""
# Redis connection
self.redis_client = redis.StrictRedis(host="redis", port=6379, db=0)
self.entries = [] # Placeholder for entries fetched from MongoDB

# MongoDB connection
self.mongo_client = mongo_client = pymongo.MongoClient(
"mongodb://%s:%s@mongo:27017/" % (mongo_user, mongo_password)
)
self.db = mongo_client["gateway"]
self.keyCollection = self.db["apikeys"]

def update_redis(self):
"""
Updates Redis with API keys fetched from MongoDB.
"""
# Fetch entries from MongoDB
self.fetch_entries_from_mongodb()
# Clean up the current keys, and then add the new ones.
# optimally, this would be done in an atomic call, but we will have to
# see how often someone
self.redis_client.delete("keys")
# Update Redis with fetched entries
self.redis_client.sadd("keys", *self.entries)

def fetch_entries_from_mongodb(self):
"""
Retrieves active API keys from MongoDB and stores them in self.entries.
"""
# Retrieve entries from MongoDB
currentKeys = self.collection.find({"active": True})
self.entries = [
entry["APIKEY"] for entry in self.collection.find({}, {"key": 1})
]

def start_scheduler(self):
"""
Initiates a scheduler to run update_redis every 15 minutes continuously.
"""
# Schedule the update every 15 minutes
schedule.every(15).minutes.do(self.update_redis)

# Run the scheduler continuously
while True:
schedule.run_pending()
# Only do this every 60 seconds - thats often enough.
time.sleep(60)


# Create an instance of the RedisUpdater and start the scheduler
updater = RedisUpdater()
updater.start_scheduler()
77 changes: 77 additions & 0 deletions app/utils/test_key_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from pytest_mock_resources import create_redis_fixture
from pytest_mock_resources import create_mongo_fixture
from key_handler import KeyHandler
import redis
import pymongo

redis = create_redis_fixture()
mongo = create_mongo_fixture()


def test_check_key(redis, mongo):
handler = KeyHandler(True)
handler.setup(mongo, redis)
keys = ["ABC", "DEF"]
redis.sadd("keys", *keys)
assert handler.check_key("ABC") == True
assert handler.check_key("DEF") == True
assert handler.check_key("BCE") == False


def test_create_key(redis, mongo):
handler = KeyHandler(True)
handler.setup(mongo, redis)
newKey = handler.create_key("NewUser", "NewKey")
db = mongo["gateway"]
user_collection = db["users"]
key_collection = db["apikeys"]
assert user_collection.count_documents({}) == 1
user = user_collection.find_one({})
assert user["username"] == "NewUser"
assert len(user["keys"]) == 1
assert key_collection.count_documents({}) == 1
assert handler.check_key(newKey) == True


def test_delete_key(redis, mongo):
handler = KeyHandler(True)
handler.setup(mongo, redis)
newKey = handler.create_key("NewUser", "NewKey")
db = mongo["gateway"]
user_collection = db["users"]
key_collection = db["apikeys"]
assert user_collection.count_documents({}) == 1
assert key_collection.count_documents({}) == 1
assert handler.check_key(newKey) == True
handler.delete_key(newKey, "NewUser")
user = user_collection.find_one({})
assert user["username"] == "NewUser"
assert len(user["keys"]) == 0
assert key_collection.count_documents({}) == 0
assert handler.check_key(newKey) == False


def test_set_key_activity(redis, mongo):
handler = KeyHandler(True)
handler.setup(mongo, redis)
# Create key
newKey = handler.create_key("NewUser", "NewKey")
db = mongo["gateway"]
user_collection = db["users"]
key_collection = db["apikeys"]
# Make sure key is valid at start
assert handler.check_key(newKey) == True
# deactivate key
handler.set_key_activity(newKey, "NewUser", False)
key_data = key_collection.find_one({"key": newKey})
assert key_data["active"] == False
# Ensure key still exists
user = user_collection.find_one({"username": "NewUser"})
assert len(user["keys"]) == 1
# and key is inactive
assert handler.check_key(newKey) == False
# reactivate and test that the key is now active again
handler.set_key_activity(newKey, "NewUser", True)
key_data = key_collection.find_one({"key": newKey})
assert key_data["active"] == True
assert handler.check_key(newKey) == True
Loading

0 comments on commit af988b3

Please sign in to comment.