From cf709a2215006849ee974b6f1d25ad3fb6831daa Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Tue, 6 May 2025 17:50:15 +0200 Subject: [PATCH 1/3] Adds cache decorator & cache or middleware calls --- business_objects/organization.py | 2 + business_objects/user.py | 3 + db_cache.py | 180 +++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 db_cache.py diff --git a/business_objects/organization.py b/business_objects/organization.py index e20abe7..843a317 100644 --- a/business_objects/organization.py +++ b/business_objects/organization.py @@ -8,8 +8,10 @@ from ..models import Organization, Project, User from ..business_objects import project, user, general from ..util import prevent_sql_injection +from ..db_cache import TTLCacheDecorator, CacheEnum +@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") def get(id: str) -> Organization: return session.query(Organization).get(id) diff --git a/business_objects/user.py b/business_objects/user.py index a20210b..dada9e8 100644 --- a/business_objects/user.py +++ b/business_objects/user.py @@ -5,10 +5,13 @@ from typing import List, Optional from sqlalchemy import sql +from ..db_cache import TTLCacheDecorator, CacheEnum + from ..util import prevent_sql_injection +@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") def get(user_id: str) -> User: return session.query(User).get(user_id) diff --git a/db_cache.py b/db_cache.py new file mode 100644 index 0000000..00f7949 --- /dev/null +++ b/db_cache.py @@ -0,0 +1,180 @@ +import time +import functools +import inspect +import threading +from enum import Enum +from .daemon import run_without_db_token + + +# Enum for logical cache separation\ +class CacheEnum(Enum): + DEFAULT = "default" + USER = "user" + ORGANIZATION = "organization" + TEAM = "team" + # extend with more categories as needed + + +# Global cache map: each cache_type -> its own dict of key -> (value, expires_at) +_GLOBAL_CACHE_MAP = {} +# Lock to protect cache operations +_CACHE_LOCK = threading.Lock() + + +def _cleanup_expired(): + while True: + time.sleep(60 * 60) # run every hour + now = time.time() + with _CACHE_LOCK: + for cache in _GLOBAL_CACHE_MAP.values(): + # collect expired keys first + expired_keys = [key for key, (_, exp) in cache.items() if now >= exp] + for key in expired_keys: + del cache[key] + + +# # Start cleanup thread as daemon +# _cleanup_thread = threading.Thread(target=_cleanup_expired, daemon=True) +# _cleanup_thread.start() + + +def start_cleanup_thread(): + run_without_db_token(_cleanup_expired) + + +class TTLCacheDecorator: + def __init__(self, cache_type=CacheEnum.DEFAULT, ttl_minutes=None, *key_fields): + """ + cache_type: namespace for the cache + ttl_minutes: time-to-live for cache entries, in minutes + key_fields: argument names (str) to build cache key; positions are not supported + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + if ttl_minutes is None: + raise ValueError("ttl_minutes must be specified") + self.cache_type = cache_type + # convert minutes to seconds + self.ttl = ttl_minutes * 60 + # only named fields + for f in key_fields: + if not isinstance(f, str): + raise TypeError("key_fields must be argument names (strings)") + self.key_fields = key_fields + + def __call__(self, fn): + sig = inspect.signature(fn) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + bound = sig.bind_partial(*args, **kwargs) + # build cache key tuple from named fields + try: + key = tuple(bound.arguments[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for cache key: {e}") + + now = time.time() + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {}) + entry = cache.get(key) + if entry: + value, expires_at = entry + if now < expires_at: + print(f"Cache hit for {key} in {self.cache_type}") + return value + # expired + del cache[key] + + # miss or expired + print(f"No cache hit for {key} in {self.cache_type}") + result = fn(*args, **kwargs) + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP[self.cache_type] + cache[key] = (result, now + self.ttl) + return result + + # management methods + def invalidate(**kws): + try: + key = tuple(kws[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for invalidate key: {e}") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(self.cache_type, {}).pop(key, None) + + def update(value, **kws): + try: + key = tuple(kws[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for update key: {e}") + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {}) + cache[key] = (value, time.time() + self.ttl) + + def clear_all(): + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP[self.cache_type] = {} + + wrapper.invalidate = invalidate + wrapper.update = update + wrapper.clear_all = clear_all + wrapper.ttl = self.ttl + wrapper.key_fields = self.key_fields + wrapper.cache_type = self.cache_type + + return wrapper + + +# ─── GLOBAL INVALIDATE / UPDATE ────────────────────────────────────────────── + + +def invalidate_cache(cache_type: CacheEnum, key: tuple): + """ + Remove a single entry from the given cache. + key must be the exact tuple used when caching. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(cache_type, {}).pop(key, None) + + +def update_cache(cache_type: CacheEnum, key: tuple, value, ttl_minutes: float): + """ + Force-set a value in cache under `cache_type` and `key`, overriding any existing entry. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + expires_at = time.time() + ttl_minutes * 60 + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(cache_type, {}) + cache[key] = (value, expires_at) + + +# Example usage: !Note the tuples syntax +# invalidate_cache(CacheEnum.USER, (user_id,)) +# update_cache(CacheEnum.USER, (user_id,), some_value, ttl_minutes=5) + + +# Example usage: +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_user_by_id(user_id): +# print(f"Fetching user {user_id} from database") +# return {"id": user_id, "name": "John"} + +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_admin_user(user_id, dummy=None): +# print(f"Fetching admin user {user_id} from database") +# return {"id": user_id, "role": "admin"} + +# @TTLCacheDecorator(CacheEnum.RECORD, 120, 'project_id', 'record_id') +# def get_record(project_id, record_id): +# print(f"Fetching record {project_id}, {record_id} from database") +# return {"project_id": project_id, "id": record_id, "value": "Some data"} + +# Management examples: +# get_user_by_id.invalidate(user_id=1) +# get_user_by_id.update({"id":1, "name":"Jane"}, user_id=1) +# get_user_by_id.clear_all() +# get_record.clear_all() From 064cb7db0b9650e41ac7e60d5630fc249b6f1b35 Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Wed, 7 May 2025 09:17:06 +0200 Subject: [PATCH 2/3] Remove prints --- db_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/db_cache.py b/db_cache.py index 00f7949..d6acb2c 100644 --- a/db_cache.py +++ b/db_cache.py @@ -81,13 +81,13 @@ def wrapper(*args, **kwargs): if entry: value, expires_at = entry if now < expires_at: - print(f"Cache hit for {key} in {self.cache_type}") + # print(f"Cache hit for {key} in {self.cache_type}") return value # expired del cache[key] # miss or expired - print(f"No cache hit for {key} in {self.cache_type}") + # print(f"No cache hit for {key} in {self.cache_type}") result = fn(*args, **kwargs) with _CACHE_LOCK: cache = _GLOBAL_CACHE_MAP[self.cache_type] From 8b8321d63f735818cb08b7224c8c7cc2bc1de935 Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Wed, 7 May 2025 11:55:48 +0200 Subject: [PATCH 3/3] Expunge test --- business_objects/organization.py | 13 +++++++++++-- business_objects/user.py | 20 ++++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/business_objects/organization.py b/business_objects/organization.py index 843a317..feb6d7b 100644 --- a/business_objects/organization.py +++ b/business_objects/organization.py @@ -11,11 +11,20 @@ from ..db_cache import TTLCacheDecorator, CacheEnum -@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") def get(id: str) -> Organization: return session.query(Organization).get(id) +@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") +def get_org_cached(id: str) -> Organization: + o = get(id) + if not o: + return None + general.expunge(o) + general.make_transient(o) + return o + + def get_by_name(name: str) -> Organization: return session.query(Organization).filter(Organization.name == name).first() @@ -120,7 +129,7 @@ def log_admin_requests(org_id: str) -> str: # enum AdminLogLevel if not org_id: # e.g. not assigned to an organization = not logged return None - if o := get(org_id): + if o := get_org_cached(org_id): return o.log_admin_requests return None diff --git a/business_objects/user.py b/business_objects/user.py index dada9e8..faf52c9 100644 --- a/business_objects/user.py +++ b/business_objects/user.py @@ -1,21 +1,37 @@ from datetime import datetime from . import general, organization, team_member from .. import User, enums +from typing import Dict, Any from ..session import session from typing import List, Optional from sqlalchemy import sql from ..db_cache import TTLCacheDecorator, CacheEnum - from ..util import prevent_sql_injection -@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") def get(user_id: str) -> User: return session.query(User).get(user_id) +@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") +def get_user_cached(user_id: str) -> User: + """ + Get user by id and return as dict + """ + user = get(user_id) + if not user: + return None + + general.expunge(user) + general.make_transient(user) + return user + # if not user: + # return {} + # return sql_alchemy_to_dict(user) + + def get_by_id_list(user_ids: List[str]) -> List[User]: return session.query(User).filter(User.id.in_(user_ids)).all()