-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from AaltoRSE/db_interfaces
Db interfaces
- Loading branch information
Showing
7 changed files
with
422 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,3 +185,6 @@ bundles/ | |
vendor/pkg/ | ||
pyenv | ||
Vagrantfile | ||
|
||
# vscode stuff | ||
.vscode/**/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import redis | ||
import pymongo | ||
import secrets | ||
import string | ||
import os | ||
import urllib | ||
|
||
|
||
class KeyHandler: | ||
def __init__(self, testing: bool = False): | ||
if not testing: | ||
# Needs to be escaped if necessary | ||
mongo_user = urllib.parse.quote_plus(os.environ.get("MONGOUSER")) | ||
mongo_password = urllib.parse.quote_plus(os.environ.get("MONGOPASSWORD")) | ||
|
||
# Set up required endpoints. The | ||
mongo_client = pymongo.MongoClient( | ||
"mongodb://%s:%s@mongo:27017/" % (mongo_user, mongo_password) | ||
) | ||
redis_client = redis.StrictRedis(host="redis", port=6379, db=0) | ||
self.setup(mongo_client, redis_client) | ||
|
||
def setup(self, mongo_client: pymongo.MongoClient, redis_client: redis.Redis): | ||
self.redis_client = redis_client | ||
self.mongo_client = mongo_client | ||
self.db = mongo_client["gateway"] | ||
self.key_collection = self.db["apikeys"] | ||
keyindices = self.key_collection.index_information() | ||
# Make sure, that key is an index (avoids duplicates); | ||
if not "key" in keyindices: | ||
self.key_collection.create_index("key", unique=True) | ||
self.user_collection = self.db["users"] | ||
# Make sure, that username is an index (avoids duplicates when creating keys, which automatically adds a user if necessary); | ||
userindices = self.user_collection.index_information() | ||
if not "username" in userindices: | ||
self.user_collection.create_index("username", unique=True) | ||
|
||
def generate_api_key(self, length: int = 64): | ||
""" | ||
Function to generate an API key. | ||
Parameters: | ||
- length (int, optional): Length of the generated API key. Defaults to 64. | ||
Returns: | ||
- str: The generated API key. | ||
""" | ||
alphabet = string.ascii_letters + string.digits | ||
api_key = "".join(secrets.choice(alphabet) for _ in range(length)) | ||
return api_key | ||
|
||
def build_new_key_object(self, key: string, name: string): | ||
""" | ||
Function to create a new key object. | ||
Parameters: | ||
- key (str): The key value. | ||
- name (str): The name associated with the key. | ||
Returns: | ||
- dict: A dictionary representing the key object with "active" status, key, and name. | ||
""" | ||
return {"active": True, "key": key, "name": name} | ||
|
||
def check_key(self, key: string): | ||
""" | ||
Function to check if a key currently exists | ||
Parameters: | ||
- key (str): The key to check. | ||
Returns: | ||
- bool: True if the key exists | ||
""" | ||
return self.redis_client.sismember("keys", key) | ||
|
||
def delete_key(self, key: string, user: string): | ||
""" | ||
Function to delete an existing key | ||
Parameters: | ||
- key (str): The key to check. | ||
- user (str): The user that requests this deletion | ||
""" | ||
updated_user = self.user_collection.find_one_and_update( | ||
{"username": user, "keys": {"$elemMatch": {"$eq": key}}}, | ||
{"$pull": {"keys": key}}, | ||
) | ||
if not updated_user == None: | ||
# We found, and updated the user, so we can remove the key | ||
# removal should be instantaneous | ||
self.key_collection.delete_one({"key": key}) | ||
self.redis_client.srem("keys", key) | ||
|
||
def set_key_activity(self, key: string, user: string, active: bool): | ||
""" | ||
Function to set whether a key is active or not. | ||
The key has to be owned by the user indicated. | ||
Parameters: | ||
- key (str): The key to check. | ||
- user (str): The user that requests this deletion | ||
- active (bool): whether to activate or deactivate the key | ||
""" | ||
user_has_key = self.user_collection.find_one( | ||
{"username": user, "keys": {"$elemMatch": {"$eq": key}}} | ||
) | ||
if not user_has_key == None: | ||
# the requesting user has access to this key | ||
self.key_collection.update_one({"key": key}, {"$set": {"active": active}}) | ||
if active: | ||
self.redis_client.sadd("keys", key) | ||
else: | ||
self.redis_client.srem("keys", key) | ||
|
||
def create_key(self, user: string, name: string): | ||
""" | ||
Generates a unique API key and associates it with a specified user. | ||
Args: | ||
- user: Username of the user to whom the API key will be associated. | ||
- name: Name or label for the API key. | ||
Returns: | ||
- api_key: The generated unique API key associated with the user. | ||
""" | ||
key_created = False | ||
api_key = "" | ||
while not key_created: | ||
api_key = self.generate_api_key() | ||
found = self.key_collection.find_one({"key": api_key}) | ||
if found == None: | ||
self.key_collection.insert_one(self.build_new_key_object(api_key, name)) | ||
self.user_collection.update_one( | ||
{"username": user}, {"$addToSet": {"keys": api_key}}, upsert=True | ||
) | ||
self.redis_client.sadd("keys", api_key) | ||
key_created = True | ||
return api_key |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import pymongo | ||
from datetime import datetime | ||
import os | ||
import urllib | ||
|
||
|
||
# Needs to be escaped if necessary | ||
class LoggingHandler: | ||
def __init__(self, testing=False): | ||
if not testing: | ||
mongo_user = urllib.parse.quote_plus(os.environ.get("MONGOUSER")) | ||
mongo_password = urllib.parse.quote_plus(os.environ.get("MONGOPASSWORD")) | ||
# Set up required endpoints. | ||
mongo_client = pymongo.MongoClient( | ||
"mongodb://%s:%s@mongo:27017/" % (mongo_user, mongo_password) | ||
) | ||
self.setup(mongo_client) | ||
|
||
def setup(self, mongo_client): | ||
self.db = mongo_client["gateway"] | ||
self.log_collection = self.db["logs"] | ||
self.user_collection = self.db["users"] | ||
|
||
def create_log_entry(self, tokencount, model, source, sourcetype="apikey"): | ||
""" | ||
Function to create a log entry. | ||
Parameters: | ||
- tokencount (int): The count of tokens. | ||
- model (str): The model related to the log entry. | ||
- source (str): The source that authorized the request that is being logged. This could be a user name or an apikey. | ||
- sourcetype (str): Specification of what kind of source authorized the request that is being logged (either 'apikey' or 'user'). | ||
Returns: | ||
- dict: A dictionary representing the log entry with timestamp. | ||
""" | ||
return { | ||
"tokencount": tokencount, | ||
"model": model, | ||
"source": source, | ||
"sourcetype": sourcetype, | ||
"timestamp": datetime.utcnow(), # Current timestamp in UTC | ||
} | ||
|
||
def log_usage_for_key(self, tokencount, model, key): | ||
""" | ||
Function to log usage for a specific key. | ||
Parameters: | ||
- tokencount (int): The count of tokens used. | ||
- model (str): The model associated with the usage. | ||
- key (str): The key for which the usage is logged. | ||
""" | ||
log_entry = self.create_log_entry(tokencount, model, key) | ||
self.log_collection.insert_one(log_entry) | ||
|
||
def log_usage_for_user(self, tokencount, model, user): | ||
""" | ||
Function to log usage for a specific user. | ||
Parameters: | ||
- tokencount (int): The count of tokens used. | ||
- model (str): The model associated with the usage. | ||
- user (str): The user for which the usage is logged. | ||
""" | ||
log_entry = self.create_log_entry(tokencount, model, user, "user") | ||
self.log_collection.insert_one(log_entry) | ||
|
||
def get_usage_for_user(self, username): | ||
raise NotImplementedError | ||
|
||
def get_usage_for_key(self, key): | ||
raise NotImplementedError | ||
|
||
def get_usage_for_model(self, model): | ||
raise NotImplementedError | ||
|
||
def get_usage_for_timerange(self, start, end): | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import redis | ||
import pymongo | ||
import schedule | ||
import time | ||
import urllib | ||
import os | ||
|
||
# Needs to be escaped if necessary | ||
mongo_user = urllib.parse.quote_plus(os.environ.get("MONGOUSER")) | ||
mongo_password = urllib.parse.quote_plus(os.environ.get("MONGOPASSWORD")) | ||
|
||
|
||
class RedisUpdater: | ||
def __init__(self): | ||
""" | ||
Initializes Redis and MongoDB connections. | ||
""" | ||
# Redis connection | ||
self.redis_client = redis.StrictRedis(host="redis", port=6379, db=0) | ||
self.entries = [] # Placeholder for entries fetched from MongoDB | ||
|
||
# MongoDB connection | ||
self.mongo_client = mongo_client = pymongo.MongoClient( | ||
"mongodb://%s:%s@mongo:27017/" % (mongo_user, mongo_password) | ||
) | ||
self.db = mongo_client["gateway"] | ||
self.keyCollection = self.db["apikeys"] | ||
|
||
def update_redis(self): | ||
""" | ||
Updates Redis with API keys fetched from MongoDB. | ||
""" | ||
# Fetch entries from MongoDB | ||
self.fetch_entries_from_mongodb() | ||
# Clean up the current keys, and then add the new ones. | ||
# optimally, this would be done in an atomic call, but we will have to | ||
# see how often someone | ||
self.redis_client.delete("keys") | ||
# Update Redis with fetched entries | ||
self.redis_client.sadd("keys", *self.entries) | ||
|
||
def fetch_entries_from_mongodb(self): | ||
""" | ||
Retrieves active API keys from MongoDB and stores them in self.entries. | ||
""" | ||
# Retrieve entries from MongoDB | ||
currentKeys = self.collection.find({"active": True}) | ||
self.entries = [ | ||
entry["APIKEY"] for entry in self.collection.find({}, {"key": 1}) | ||
] | ||
|
||
def start_scheduler(self): | ||
""" | ||
Initiates a scheduler to run update_redis every 15 minutes continuously. | ||
""" | ||
# Schedule the update every 15 minutes | ||
schedule.every(15).minutes.do(self.update_redis) | ||
|
||
# Run the scheduler continuously | ||
while True: | ||
schedule.run_pending() | ||
# Only do this every 60 seconds - thats often enough. | ||
time.sleep(60) | ||
|
||
|
||
# Create an instance of the RedisUpdater and start the scheduler | ||
updater = RedisUpdater() | ||
updater.start_scheduler() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from pytest_mock_resources import create_redis_fixture | ||
from pytest_mock_resources import create_mongo_fixture | ||
from key_handler import KeyHandler | ||
import redis | ||
import pymongo | ||
|
||
redis = create_redis_fixture() | ||
mongo = create_mongo_fixture() | ||
|
||
|
||
def test_check_key(redis, mongo): | ||
handler = KeyHandler(True) | ||
handler.setup(mongo, redis) | ||
keys = ["ABC", "DEF"] | ||
redis.sadd("keys", *keys) | ||
assert handler.check_key("ABC") == True | ||
assert handler.check_key("DEF") == True | ||
assert handler.check_key("BCE") == False | ||
|
||
|
||
def test_create_key(redis, mongo): | ||
handler = KeyHandler(True) | ||
handler.setup(mongo, redis) | ||
newKey = handler.create_key("NewUser", "NewKey") | ||
db = mongo["gateway"] | ||
user_collection = db["users"] | ||
key_collection = db["apikeys"] | ||
assert user_collection.count_documents({}) == 1 | ||
user = user_collection.find_one({}) | ||
assert user["username"] == "NewUser" | ||
assert len(user["keys"]) == 1 | ||
assert key_collection.count_documents({}) == 1 | ||
assert handler.check_key(newKey) == True | ||
|
||
|
||
def test_delete_key(redis, mongo): | ||
handler = KeyHandler(True) | ||
handler.setup(mongo, redis) | ||
newKey = handler.create_key("NewUser", "NewKey") | ||
db = mongo["gateway"] | ||
user_collection = db["users"] | ||
key_collection = db["apikeys"] | ||
assert user_collection.count_documents({}) == 1 | ||
assert key_collection.count_documents({}) == 1 | ||
assert handler.check_key(newKey) == True | ||
handler.delete_key(newKey, "NewUser") | ||
user = user_collection.find_one({}) | ||
assert user["username"] == "NewUser" | ||
assert len(user["keys"]) == 0 | ||
assert key_collection.count_documents({}) == 0 | ||
assert handler.check_key(newKey) == False | ||
|
||
|
||
def test_set_key_activity(redis, mongo): | ||
handler = KeyHandler(True) | ||
handler.setup(mongo, redis) | ||
# Create key | ||
newKey = handler.create_key("NewUser", "NewKey") | ||
db = mongo["gateway"] | ||
user_collection = db["users"] | ||
key_collection = db["apikeys"] | ||
# Make sure key is valid at start | ||
assert handler.check_key(newKey) == True | ||
# deactivate key | ||
handler.set_key_activity(newKey, "NewUser", False) | ||
key_data = key_collection.find_one({"key": newKey}) | ||
assert key_data["active"] == False | ||
# Ensure key still exists | ||
user = user_collection.find_one({"username": "NewUser"}) | ||
assert len(user["keys"]) == 1 | ||
# and key is inactive | ||
assert handler.check_key(newKey) == False | ||
# reactivate and test that the key is now active again | ||
handler.set_key_activity(newKey, "NewUser", True) | ||
key_data = key_collection.find_one({"key": newKey}) | ||
assert key_data["active"] == True | ||
assert handler.check_key(newKey) == True |
Oops, something went wrong.