diff --git a/business_objects/general.py b/business_objects/general.py
index 76147ae5..f0ce8c05 100644
--- a/business_objects/general.py
+++ b/business_objects/general.py
@@ -11,10 +11,13 @@
 from threading import Lock
 
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.sql import Select
+import os
 
 __THREAD_LOCK = Lock()
 
+IS_DEV = os.getenv("IS_DEV", "false").lower() in {"true", "1", "yes", "y"}
+
 session_lookup = {}
 
 
@@ -23,7 +26,11 @@ def get_ctx_token() -> Any:
     session_uuid = str(uuid.uuid4())
     session_id = request_id_ctx_var.set(session_uuid)
 
-    call_stack = "".join(traceback.format_stack()[-5:])
+    if IS_DEV:
+        # formatting stack traces is costly and they are only useful for debugging
+        call_stack = "".join(traceback.format_stack()[-5:])
+    else:
+        call_stack = "Activate dev mode to see call stack"
     with __THREAD_LOCK:
         session_lookup[session_uuid] = {
             "session_id": session_uuid,
diff --git a/business_objects/organization.py b/business_objects/organization.py
index e20abe7e..84fa420b 100644
--- a/business_objects/organization.py
+++ b/business_objects/organization.py
@@ -4,16 +4,30 @@
 from submodules.model import enums
-from ..session import session
+from ..session import session, request_id_ctx_var
 from ..models import Organization, Project, User
 from ..business_objects import project, user, general
 from ..util import prevent_sql_injection
+from ..db_cache import TTLCacheDecorator, CacheEnum
+from ..session_wrapper import with_session
 
 
 def get(id: str) -> Organization:
     return session.query(Organization).get(id)
 
 
+@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id")
+@with_session()
+def get_org_cached(id: str) -> Organization:
+    print("get_org_cached with session:", request_id_ctx_var.get(), flush=True)
+    o = get(id)
+    if not o:
+        return None
+    general.expunge(o)
+    general.make_transient(o)
+    return o
+
+
 def get_by_name(name: str) -> Organization:
     return session.query(Organization).filter(Organization.name == name).first()
 
@@ -118,7 +132,7 @@ def log_admin_requests(org_id: str) -> str:  # enum AdminLogLevel
     if not org_id:  # e.g. not assigned to an organization => not logged
         return None
-    if o := get(org_id):
+    if o := get_org_cached(org_id):
         return o.log_admin_requests
     return None
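
A note on the get_org_cached pattern above: expunge plus make_transient detaches
the Organization from its session before the value is cached, so it stays
readable after session.remove(); the flip side is that relationships that were
not loaded before detaching will not lazy-load later. A minimal usage sketch
(the submodules.model import path is inferred from the imports in
organization.py; adjust it to wherever this package is actually mounted):

    from submodules.model.business_objects import organization

    level = organization.log_admin_requests(org_id)  # first call hits the DB
    level = organization.log_admin_requests(org_id)  # within the 5-minute TTL: cache hit
    # after writing to the organization, drop the stale entry explicitly:
    organization.get_org_cached.invalidate(id=org_id)
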
diff --git a/business_objects/user.py b/business_objects/user.py
index a20210b6..5d5cf5f7 100644
--- a/business_objects/user.py
+++ b/business_objects/user.py
@@ -1,10 +1,13 @@
 from datetime import datetime
 from . import general, organization, team_member
 from .. import User, enums
-from ..session import session
+from ..session import session, request_id_ctx_var
 from typing import List, Optional
 from sqlalchemy import sql
+from ..db_cache import TTLCacheDecorator, CacheEnum
+from ..session_wrapper import with_session
 from ..util import prevent_sql_injection
 
 
@@ -13,6 +16,22 @@ def get(user_id: str) -> User:
     return session.query(User).get(user_id)
 
 
+@TTLCacheDecorator(CacheEnum.USER, 5, "user_id")
+@with_session()
+def get_user_cached(user_id: str) -> User:
+    """
+    Get a user by id; returns a detached User instance.
+    """
+    print("get_user_cached with session:", request_id_ctx_var.get(), flush=True)
+    user = get(user_id)
+    if not user:
+        return None
+
+    general.expunge(user)
+    general.make_transient(user)
+    return user
+
+
 def get_by_id_list(user_ids: List[str]) -> List[User]:
     return session.query(User).filter(User.id.in_(user_ids)).all()
diff --git a/cognition_objects/project.py b/cognition_objects/project.py
index 9ebdd4d3..06f30ba6 100644
--- a/cognition_objects/project.py
+++ b/cognition_objects/project.py
@@ -1,13 +1,15 @@
 from typing import List, Optional, Dict, Any, Iterable
 from ..business_objects import general, team_resource, user
 from ..cognition_objects import consumption_log, consumption_summary
-from ..session import session
+from ..session import session, request_id_ctx_var
 from ..models import CognitionProject, TeamMember, TeamResource
 from .. import enums
 from datetime import datetime
 from ..util import prevent_sql_injection
 from sqlalchemy.orm.attributes import flag_modified
 from copy import deepcopy
+from ..db_cache import TTLCacheDecorator, CacheEnum
+from ..session_wrapper import with_session
 
 
 def get(project_id: str) -> CognitionProject:
@@ -18,6 +20,18 @@
     )
 
 
+@TTLCacheDecorator(CacheEnum.PROJECT, 5, "project_id")
+@with_session()
+def get_cached(project_id: str) -> CognitionProject:
+    print("get_project_cached with session:", request_id_ctx_var.get(), flush=True)
+    p = get(project_id)
+    if not p:
+        return None
+    general.expunge(p)
+    general.make_transient(p)
+    return p
+
+
 def get_org_id(project_id: str) -> str:
     if p := get(project_id):
         return str(p.organization_id)
@@ -42,6 +56,18 @@
     )
 
 
+@TTLCacheDecorator(CacheEnum.PROJECT, 5, "project_id", "user_id")
+@with_session()
+def get_by_user_cached(project_id: str, user_id: str) -> CognitionProject:
+    print("get_by_user_cached with session:", request_id_ctx_var.get(), flush=True)
+    p = get_by_user(project_id, user_id)
+    if not p:
+        return None
+    general.expunge(p)
+    general.make_transient(p)
+    return p
+
+
 def get_all(org_id: str, order_by_name: bool = False) -> List[CognitionProject]:
     query = session.query(CognitionProject).filter(
         CognitionProject.organization_id == org_id
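
Both cached getters above share the CacheEnum.PROJECT namespace but use
different key tuples, (project_id,) versus (project_id, user_id), so their
entries cannot collide; it also means a write path has to invalidate each
getter separately. A sketch of what that could look like (update_project_name
is hypothetical; only the invalidate calls come from this diff):

    def update_project_name(project_id: str, user_id: str, name: str) -> None:
        ...  # perform the UPDATE as elsewhere in this module
        # same cache namespace, different key tuples -> two invalidations
        get_cached.invalidate(project_id=project_id)
        get_by_user_cached.invalidate(project_id=project_id, user_id=user_id)
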
diff --git a/db_cache.py b/db_cache.py
new file mode 100644
index 00000000..8dbff58e
--- /dev/null
+++ b/db_cache.py
@@ -0,0 +1,179 @@
+import time
+import functools
+import inspect
+import threading
+from enum import Enum
+from .daemon import run_without_db_token
+
+
+# Enum for logical cache separation
+class CacheEnum(Enum):
+    DEFAULT = "default"
+    USER = "user"
+    ORGANIZATION = "organization"
+    PROJECT = "project"
+    TEAM = "team"
+    # extend with more categories as needed
+
+
+# Global cache map: each cache_type -> its own dict of key -> (value, expires_at)
+_GLOBAL_CACHE_MAP = {}
+# Lock to protect cache operations
+_CACHE_LOCK = threading.Lock()
+
+
+def _cleanup_expired():
+    while True:
+        time.sleep(60 * 60)  # run every hour
+        now = time.time()
+        with _CACHE_LOCK:
+            for cache in _GLOBAL_CACHE_MAP.values():
+                # collect expired keys first, then delete
+                expired_keys = [key for key, (_, exp) in cache.items() if now >= exp]
+                for key in expired_keys:
+                    del cache[key]
+
+
+def start_cleanup_thread():
+    # run via the daemon helper rather than a bare daemon thread so the
+    # cleanup loop never holds a DB token
+    run_without_db_token(_cleanup_expired)
+
+
+class TTLCacheDecorator:
+    def __init__(self, cache_type=CacheEnum.DEFAULT, ttl_minutes=None, *key_fields):
+        """
+        cache_type: namespace for the cache
+        ttl_minutes: time-to-live for cache entries, in minutes
+        key_fields: argument names (str) used to build the cache key;
+            positional lookups are not supported
+        """
+        if not isinstance(cache_type, CacheEnum):
+            raise TypeError("cache_type must be a CacheEnum member")
+        if ttl_minutes is None:
+            raise ValueError("ttl_minutes must be specified")
+        self.cache_type = cache_type
+        # convert minutes to seconds
+        self.ttl = ttl_minutes * 60
+        # only named fields are allowed as key parts
+        for f in key_fields:
+            if not isinstance(f, str):
+                raise TypeError("key_fields must be argument names (strings)")
+        self.key_fields = key_fields
+
+    def __call__(self, fn):
+        sig = inspect.signature(fn)
+
+        @functools.wraps(fn)
+        def wrapper(*args, **kwargs):
+            bound = sig.bind_partial(*args, **kwargs)
+            # build the cache key tuple from the named fields
+            try:
+                key = tuple(bound.arguments[field] for field in self.key_fields)
+            except KeyError as e:
+                raise KeyError(f"Missing argument for cache key: {e}")
+
+            now = time.time()
+            with _CACHE_LOCK:
+                cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {})
+                entry = cache.get(key)
+                if entry:
+                    value, expires_at = entry
+                    if now < expires_at:
+                        # print(f"Cache hit for {key} in {self.cache_type}")
+                        return value
+                    # expired
+                    del cache[key]
+
+            # miss or expired
+            # print(f"No cache hit for {key} in {self.cache_type}")
+            result = fn(*args, **kwargs)
+            with _CACHE_LOCK:
+                cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {})
+                cache[key] = (result, now + self.ttl)
+            return result
+
+        # management methods attached to the wrapped function
+        def invalidate(**kws):
+            try:
+                key = tuple(kws[field] for field in self.key_fields)
+            except KeyError as e:
+                raise KeyError(f"Missing argument for invalidate key: {e}")
+            with _CACHE_LOCK:
+                _GLOBAL_CACHE_MAP.get(self.cache_type, {}).pop(key, None)
+
+        def update(value, **kws):
+            try:
+                key = tuple(kws[field] for field in self.key_fields)
+            except KeyError as e:
+                raise KeyError(f"Missing argument for update key: {e}")
+            with _CACHE_LOCK:
+                cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {})
+                cache[key] = (value, time.time() + self.ttl)
+
+        def clear_all():
+            with _CACHE_LOCK:
+                _GLOBAL_CACHE_MAP[self.cache_type] = {}
+
+        wrapper.invalidate = invalidate
+        wrapper.update = update
+        wrapper.clear_all = clear_all
+        wrapper.ttl = self.ttl
+        wrapper.key_fields = self.key_fields
+        wrapper.cache_type = self.cache_type
+
+        return wrapper
+
+
+# ─── GLOBAL INVALIDATE / UPDATE ──────────────────────────────────────────────
+
+
+def invalidate_cache(cache_type: CacheEnum, key: tuple):
+    """
+    Remove a single entry from the given cache.
+    key must be the exact tuple used when caching.
+    """
+    if not isinstance(cache_type, CacheEnum):
+        raise TypeError("cache_type must be a CacheEnum member")
+    with _CACHE_LOCK:
+        _GLOBAL_CACHE_MAP.get(cache_type, {}).pop(key, None)
+
+
+def update_cache(cache_type: CacheEnum, key: tuple, value, ttl_minutes: float):
+    """
+    Force-set a value under `cache_type` and `key`, overriding any existing entry.
+    """
+    if not isinstance(cache_type, CacheEnum):
+        raise TypeError("cache_type must be a CacheEnum member")
+    expires_at = time.time() + ttl_minutes * 60
+    with _CACHE_LOCK:
+        cache = _GLOBAL_CACHE_MAP.setdefault(cache_type, {})
+        cache[key] = (value, expires_at)
+
+
+# Example usage (note the tuple syntax):
+# invalidate_cache(CacheEnum.USER, (user_id,))
+# update_cache(CacheEnum.USER, (user_id,), some_value, ttl_minutes=5)
+
+
+# Example usage:
+# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id')
+# def get_user_by_id(user_id):
+#     print(f"Fetching user {user_id} from database")
+#     return {"id": user_id, "name": "John"}
+
+# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id')
+# def get_admin_user(user_id, dummy=None):
+#     print(f"Fetching admin user {user_id} from database")
+#     return {"id": user_id, "role": "admin"}
+
+# @TTLCacheDecorator(CacheEnum.PROJECT, 120, 'project_id', 'record_id')
+# def get_record(project_id, record_id):
+#     print(f"Fetching record {project_id}, {record_id} from database")
+#     return {"project_id": project_id, "id": record_id, "value": "Some data"}
+
+# Management examples:
+# get_user_by_id.invalidate(user_id=1)
+# get_user_by_id.update({"id": 1, "name": "Jane"}, user_id=1)
+# get_user_by_id.clear_all()
+# get_record.clear_all()
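
db_cache.py has no dependencies beyond the daemon helper, so its behavior is
easy to verify in isolation. A minimal sketch, assuming the module is
importable as laid out in the diff (square is illustrative only):

    from db_cache import TTLCacheDecorator, CacheEnum

    @TTLCacheDecorator(CacheEnum.DEFAULT, 1, "x")  # 1-minute TTL, keyed on "x"
    def square(x):
        print("computing", x)
        return x * x

    square(x=3)             # computes and caches under key (3,)
    square(x=3)             # served from cache, no print
    square.invalidate(x=3)  # drop the entry
    square(x=3)             # computes again

Note the stacking order used throughout this PR: TTLCacheDecorator sits outside
with_session, so a cache hit returns before any session or request-id work
happens.
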
+ """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(cache_type, {}).pop(key, None) + + +def update_cache(cache_type: CacheEnum, key: tuple, value, ttl_minutes: float): + """ + Force-set a value in cache under `cache_type` and `key`, overriding any existing entry. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + expires_at = time.time() + ttl_minutes * 60 + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(cache_type, {}) + cache[key] = (value, expires_at) + + +# Example usage: !Note the tuples syntax +# invalidate_cache(CacheEnum.USER, (user_id,)) +# update_cache(CacheEnum.USER, (user_id,), some_value, ttl_minutes=5) + + +# Example usage: +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_user_by_id(user_id): +# print(f"Fetching user {user_id} from database") +# return {"id": user_id, "name": "John"} + +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_admin_user(user_id, dummy=None): +# print(f"Fetching admin user {user_id} from database") +# return {"id": user_id, "role": "admin"} + +# @TTLCacheDecorator(CacheEnum.RECORD, 120, 'project_id', 'record_id') +# def get_record(project_id, record_id): +# print(f"Fetching record {project_id}, {record_id} from database") +# return {"project_id": project_id, "id": record_id, "value": "Some data"} + +# Management examples: +# get_user_by_id.invalidate(user_id=1) +# get_user_by_id.update({"id":1, "name":"Jane"}, user_id=1) +# get_user_by_id.clear_all() +# get_record.clear_all() diff --git a/session.py b/session.py index 9e588ea1..3f65f36c 100644 --- a/session.py +++ b/session.py @@ -88,3 +88,22 @@ def __start_session_cleanup(): except Exception: traceback.print_exc() time.sleep(10) + + +def pool_report(): + """ + Returns a dict with pool metrics for the engine bound to the given + SQLAlchemy Session (or global `engine` if sess is None). + """ + # eng = sess.get_bind() if sess else engine + pool = engine.pool + + return { + "pool_size": pool.size(), + "checked_in": pool.checkedin(), + "overflow": pool.overflow(), + "checked_out": pool.checkedout(), + "max_overflow": pool._max_overflow, + "total_capacity": pool.size() + pool._max_overflow, + "available": (pool.size() + pool._max_overflow) - pool.checkedout(), + } diff --git a/session_wrapper.py b/session_wrapper.py new file mode 100644 index 00000000..5ff5196a --- /dev/null +++ b/session_wrapper.py @@ -0,0 +1,150 @@ +# db_utils.py + +import asyncio + +# import uuid +from .business_objects import general +from contextvars import copy_context +from .session import session, request_id_ctx_var +import functools + + +def _run_with_session( + fn, *args, auto_remove: bool = True, new_session: bool = True, **kwargs +): + """ + Sync helper: ensures a request-id is set (or reset), runs fn(*args, **kwargs), + then optionally removes the session. 
diff --git a/session_wrapper.py b/session_wrapper.py
new file mode 100644
index 00000000..5ff5196a
--- /dev/null
+++ b/session_wrapper.py
@@ -0,0 +1,93 @@
+# session_wrapper.py
+
+import asyncio
+
+from .business_objects import general
+from contextvars import copy_context
+from .session import session, request_id_ctx_var
+import functools
+
+
+def _run_with_session(
+    fn, *args, auto_remove: bool = True, new_session: bool = True, **kwargs
+):
+    """
+    Sync helper: ensures a request ID is set (or reset), runs fn(*args, **kwargs),
+    then optionally removes the session.
+
+    Args:
+        fn: the DB function to run
+        auto_remove: if True, calls session.remove() after execution
+        new_session: if True, always assign a fresh UUID as the request ID
+    """
+    # decide on request-ID behavior
+    if new_session or request_id_ctx_var.get() is None:
+        # generates a unique request id for this session
+        general.get_ctx_token()
+
+    try:
+        # the scoped session uses request_id_ctx_var under the hood
+        return fn(*args, **kwargs)
+    except Exception:
+        session.rollback()
+        raise
+    finally:
+        if auto_remove:
+            session.remove()
+
+
+def with_session(auto_remove: bool = True, new_session: bool = False):
+    """
+    Decorator for sync DB functions.
+
+    Args:
+        auto_remove: session.remove() after fn returns (default True)
+        new_session: force a fresh session UUID for each call (default False)
+
+    Usage:
+        @with_session()
+        def read_data(...):
+            return session.query(...)
+    """
+
+    def decorator(fn):
+        @functools.wraps(fn)
+        def wrapper(*args, **kwargs):
+            return _run_with_session(
+                fn,
+                *args,
+                auto_remove=auto_remove,
+                new_session=new_session,
+                **kwargs,
+            )
+
+        return wrapper
+
+    return decorator
+
+
+async def run_db(
+    fn, *args, auto_remove: bool = True, new_session: bool = False, **kwargs
+):
+    """
+    Async helper: runs a sync @with_session function in a threadpool.
+
+    Args:
+        fn: the @with_session-decorated function to call
+        auto_remove: pass-through to control session removal
+        new_session: pass-through to force a fresh session UUID
+    """
+    ctx = copy_context()
+
+    def call():
+        # explicitly pass keyword-only args
+        return ctx.run(
+            _run_with_session,
+            fn,
+            *args,
+            auto_remove=auto_remove,
+            new_session=new_session,
+            **kwargs,
+        )
+
+    return await asyncio.to_thread(call)
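
Tying the pieces together: the cached getters are sync and session-wrapped, and
run_db is the bridge for async callers. A sketch of an async caller
(handle_request is hypothetical; the import paths again assume the
submodules.model mount point):

    from submodules.model.business_objects.user import get_user_cached
    from submodules.model.session_wrapper import run_db

    async def handle_request(user_id: str):
        # copy_context() inside run_db carries request_id_ctx_var into the
        # worker thread, so the scoped session resolves correctly there
        user = await run_db(get_user_cached, user_id)
        return user  # a detached User instance, or None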