diff --git a/business_objects/organization.py b/business_objects/organization.py index e20abe7..feb6d7b 100644 --- a/business_objects/organization.py +++ b/business_objects/organization.py @@ -8,12 +8,23 @@ from ..models import Organization, Project, User from ..business_objects import project, user, general from ..util import prevent_sql_injection +from ..db_cache import TTLCacheDecorator, CacheEnum def get(id: str) -> Organization: return session.query(Organization).get(id) +@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") +def get_org_cached(id: str) -> Organization: + o = get(id) + if not o: + return None + general.expunge(o) + general.make_transient(o) + return o + + def get_by_name(name: str) -> Organization: return session.query(Organization).filter(Organization.name == name).first() @@ -118,7 +129,7 @@ def log_admin_requests(org_id: str) -> str: # enum AdminLogLevel if not org_id: # e.g. not assigned to an organization = not logged return None - if o := get(org_id): + if o := get_org_cached(org_id): return o.log_admin_requests return None diff --git a/business_objects/user.py b/business_objects/user.py index a20210b..faf52c9 100644 --- a/business_objects/user.py +++ b/business_objects/user.py @@ -1,10 +1,12 @@ from datetime import datetime from . import general, organization, team_member from .. import User, enums +from typing import Dict, Any from ..session import session from typing import List, Optional from sqlalchemy import sql +from ..db_cache import TTLCacheDecorator, CacheEnum from ..util import prevent_sql_injection @@ -13,6 +15,23 @@ def get(user_id: str) -> User: return session.query(User).get(user_id) +@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") +def get_user_cached(user_id: str) -> User: + """ + Get user by id and return as dict + """ + user = get(user_id) + if not user: + return None + + general.expunge(user) + general.make_transient(user) + return user + # if not user: + # return {} + # return sql_alchemy_to_dict(user) + + def get_by_id_list(user_ids: List[str]) -> List[User]: return session.query(User).filter(User.id.in_(user_ids)).all() diff --git a/db_cache.py b/db_cache.py new file mode 100644 index 0000000..d6acb2c --- /dev/null +++ b/db_cache.py @@ -0,0 +1,180 @@ +import time +import functools +import inspect +import threading +from enum import Enum +from .daemon import run_without_db_token + + +# Enum for logical cache separation\ +class CacheEnum(Enum): + DEFAULT = "default" + USER = "user" + ORGANIZATION = "organization" + TEAM = "team" + # extend with more categories as needed + + +# Global cache map: each cache_type -> its own dict of key -> (value, expires_at) +_GLOBAL_CACHE_MAP = {} +# Lock to protect cache operations +_CACHE_LOCK = threading.Lock() + + +def _cleanup_expired(): + while True: + time.sleep(60 * 60) # run every hour + now = time.time() + with _CACHE_LOCK: + for cache in _GLOBAL_CACHE_MAP.values(): + # collect expired keys first + expired_keys = [key for key, (_, exp) in cache.items() if now >= exp] + for key in expired_keys: + del cache[key] + + +# # Start cleanup thread as daemon +# _cleanup_thread = threading.Thread(target=_cleanup_expired, daemon=True) +# _cleanup_thread.start() + + +def start_cleanup_thread(): + run_without_db_token(_cleanup_expired) + + +class TTLCacheDecorator: + def __init__(self, cache_type=CacheEnum.DEFAULT, ttl_minutes=None, *key_fields): + """ + cache_type: namespace for the cache + ttl_minutes: time-to-live for cache entries, in minutes + key_fields: argument names (str) to build cache key; positions are not supported + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + if ttl_minutes is None: + raise ValueError("ttl_minutes must be specified") + self.cache_type = cache_type + # convert minutes to seconds + self.ttl = ttl_minutes * 60 + # only named fields + for f in key_fields: + if not isinstance(f, str): + raise TypeError("key_fields must be argument names (strings)") + self.key_fields = key_fields + + def __call__(self, fn): + sig = inspect.signature(fn) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + bound = sig.bind_partial(*args, **kwargs) + # build cache key tuple from named fields + try: + key = tuple(bound.arguments[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for cache key: {e}") + + now = time.time() + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {}) + entry = cache.get(key) + if entry: + value, expires_at = entry + if now < expires_at: + # print(f"Cache hit for {key} in {self.cache_type}") + return value + # expired + del cache[key] + + # miss or expired + # print(f"No cache hit for {key} in {self.cache_type}") + result = fn(*args, **kwargs) + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP[self.cache_type] + cache[key] = (result, now + self.ttl) + return result + + # management methods + def invalidate(**kws): + try: + key = tuple(kws[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for invalidate key: {e}") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(self.cache_type, {}).pop(key, None) + + def update(value, **kws): + try: + key = tuple(kws[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for update key: {e}") + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {}) + cache[key] = (value, time.time() + self.ttl) + + def clear_all(): + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP[self.cache_type] = {} + + wrapper.invalidate = invalidate + wrapper.update = update + wrapper.clear_all = clear_all + wrapper.ttl = self.ttl + wrapper.key_fields = self.key_fields + wrapper.cache_type = self.cache_type + + return wrapper + + +# ─── GLOBAL INVALIDATE / UPDATE ────────────────────────────────────────────── + + +def invalidate_cache(cache_type: CacheEnum, key: tuple): + """ + Remove a single entry from the given cache. + key must be the exact tuple used when caching. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(cache_type, {}).pop(key, None) + + +def update_cache(cache_type: CacheEnum, key: tuple, value, ttl_minutes: float): + """ + Force-set a value in cache under `cache_type` and `key`, overriding any existing entry. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + expires_at = time.time() + ttl_minutes * 60 + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(cache_type, {}) + cache[key] = (value, expires_at) + + +# Example usage: !Note the tuples syntax +# invalidate_cache(CacheEnum.USER, (user_id,)) +# update_cache(CacheEnum.USER, (user_id,), some_value, ttl_minutes=5) + + +# Example usage: +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_user_by_id(user_id): +# print(f"Fetching user {user_id} from database") +# return {"id": user_id, "name": "John"} + +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_admin_user(user_id, dummy=None): +# print(f"Fetching admin user {user_id} from database") +# return {"id": user_id, "role": "admin"} + +# @TTLCacheDecorator(CacheEnum.RECORD, 120, 'project_id', 'record_id') +# def get_record(project_id, record_id): +# print(f"Fetching record {project_id}, {record_id} from database") +# return {"project_id": project_id, "id": record_id, "value": "Some data"} + +# Management examples: +# get_user_by_id.invalidate(user_id=1) +# get_user_by_id.update({"id":1, "name":"Jane"}, user_id=1) +# get_user_by_id.clear_all() +# get_record.clear_all()