Skip to content

Commit

Permalink
add sessionmaker in settings (#112)
Browse files Browse the repository at this point in the history
* add sessionmaker in settings

* cleanup

* future not needed with sqlalchemy 2

* fix 3.8

* show values

* don't check kwargs
  • Loading branch information
malmans2 authored Feb 13, 2024
1 parent b11bd30 commit 075b768
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 72 deletions.
4 changes: 2 additions & 2 deletions cacholote/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
# When expiration is provided, only get entries with matching expiration
filters.append(database.CacheEntry.expiration == settings.expiration)

with settings.sessionmaker() as session:
with settings.instantiated_sessionmaker() as session:
for cache_entry in session.scalars(
sa.select(database.CacheEntry)
.filter(*filters)
Expand All @@ -97,7 +97,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
warnings.warn(f"can NOT encode output: {ex!r}", UserWarning)
return result

with settings.sessionmaker() as session:
with settings.instantiated_sessionmaker() as session:
session.add(cache_entry)
return _decode_and_update(session, cache_entry, settings)

Expand Down
12 changes: 6 additions & 6 deletions cacholote/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def delete(
Keyword arguments of functions to delete from cache
"""
hexdigest = encode._hexdigestify_python_call(func_to_del, *args, **kwargs)
with config.get().sessionmaker() as session:
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(database.CacheEntry.key == hexdigest)
):
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_unknown_files(self, lock_validity_period: Optional[float]) -> Set[str]:

unknown_sizes = {k: v for k, v in self.sizes.items() if k not in files_to_skip}
if unknown_sizes:
with config.get().sessionmaker() as session:
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(sa.select(database.CacheEntry)):
json.loads(
cache_entry._result_as_string,
Expand Down Expand Up @@ -231,14 +231,14 @@ def delete_cache_files(
start_timestamp = utils.utcnow()

# Get entries to clean
with config.get().sessionmaker() as session:
with config.get().instantiated_sessionmaker() as session:
cache_entry_ids = session.scalars(
sa.select(database.CacheEntry.id).filter(*filters).order_by(*sorters)
)

# Loop over entries
for cache_entry_id in cache_entry_ids:
with config.get().sessionmaker() as session:
with config.get().instantiated_sessionmaker() as session:
filters = [
database.CacheEntry.id == cache_entry_id,
# skip entries updated while cleaning
Expand Down Expand Up @@ -322,14 +322,14 @@ def clean_invalid_cache_entries(
if check_expiration:
filters.append(database.CacheEntry.expiration <= utils.utcnow())
if filters:
with config.get().sessionmaker() as session:
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(*filters)
):
_delete_cache_entry(session, cache_entry)

if try_decode:
with config.get().sessionmaker() as session:
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(sa.select(database.CacheEntry)):
try:
decode.loads(cache_entry._result_as_string)
Expand Down
76 changes: 41 additions & 35 deletions cacholote/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,19 @@
_SETTINGS: Optional[Settings] = None
_DEFAULT_CACHE_DIR = pathlib.Path(tempfile.gettempdir()) / "cacholote"
_DEFAULT_CACHE_DIR.mkdir(exist_ok=True)
_DEFAULT_CACHE_DB_URLPATH = f"sqlite:///{_DEFAULT_CACHE_DIR / 'cacholote.db'}"
_DEFAULT_CACHE_FILES_URLPATH = f"{_DEFAULT_CACHE_DIR / 'cache_files'}"
_DEFAULT_LOGGER = structlog.get_logger(
wrapper_class=structlog.make_filtering_bound_logger(logging.WARNING)
)

_CONFIG_NOT_SET_MSG = (
"Configuration settings have not been set. Run `cacholote.config.reset()`."
)


class Settings(pydantic_settings.BaseSettings):
use_cache: bool = True
cache_db_urlpath: str = f"sqlite:///{_DEFAULT_CACHE_DIR / 'cacholote.db'}"
cache_db_urlpath: Optional[str] = _DEFAULT_CACHE_DB_URLPATH
create_engine_kwargs: Dict[str, Any] = {}
cache_files_urlpath: str = f"{_DEFAULT_CACHE_DIR / 'cache_files'}"
sessionmaker: Optional[sa.orm.sessionmaker] = None # type: ignore[type-arg]
cache_files_urlpath: str = _DEFAULT_CACHE_FILES_URLPATH
cache_files_urlpath_readonly: Optional[str] = None
cache_files_storage_options: Dict[str, Any] = {}
xarray_cache_type: Literal[
Expand Down Expand Up @@ -80,35 +79,42 @@ def validate_expiration(
raise ValueError(f"Expiration is missing the timezone info. {expiration=}")
return expiration

def make_cache_dir(self) -> None:
@pydantic.model_validator(mode="after")
def make_cache_dir(self) -> "Settings":
fs, _, (urlpath, *_) = fsspec.get_fs_token_paths(
self.cache_files_urlpath,
storage_options=self.cache_files_storage_options,
)
fs.mkdirs(urlpath, exist_ok=True)
return self

def set_engine_and_session(self, force_reset: bool = False) -> None:
if (
force_reset
or database.ENGINE is None
or database.SESSIONMAKER is None
or str(database.ENGINE.url) != self.cache_db_urlpath
):
database._set_engine_and_session(
self.cache_db_urlpath, self.create_engine_kwargs
@pydantic.model_validator(mode="after")
def check_mutually_exclusive(self) -> "Settings":
if self.sessionmaker and self.cache_db_urlpath:
raise ValueError(
f"`sessionmaker` is mutually exclusive with `{self.cache_db_urlpath=}`."
)
if not (self.sessionmaker or self.cache_db_urlpath):
raise ValueError(
"Please provide either `sessionmaker` or `cache_db_urlpath`."
)
return self

@property
def engine(self) -> sa.engine.Engine:
if database.ENGINE is None:
raise ValueError(_CONFIG_NOT_SET_MSG)
return database.ENGINE
def instantiated_sessionmaker(self) -> sa.orm.sessionmaker: # type: ignore[type-arg]
if self.sessionmaker is None:
self.sessionmaker = database.cached_sessionmaker(
self.cache_db_urlpath, **self.create_engine_kwargs
)
self.cache_db_urlpath = None
self.create_engine_kwargs = {}
return self.sessionmaker

@property
def sessionmaker(self) -> sa.orm.sessionmaker: # type: ignore[type-arg]
if database.SESSIONMAKER is None:
raise ValueError(_CONFIG_NOT_SET_MSG)
return database.SESSIONMAKER
def engine(self) -> sa.engine.Engine:
engine = self.instantiated_sessionmaker.kw["bind"]
assert isinstance(engine, sa.engine.Engine)
return engine

model_config = pydantic_settings.SettingsConfigDict(
case_sensitive=False, env_prefix="cacholote_"
Expand All @@ -124,10 +130,12 @@ class set:
----------
use_cache: bool, default: True
Enable/disable cache.
cache_db_urlpath: str, default:"sqlite:////system_tmp_dir/cacholote/cacholote.db"
cache_db_urlpath: str, None, default:"sqlite:////system_tmp_dir/cacholote/cacholote.db"
URL for cache database (driver://user:pass@host/database).
create_engine_kwargs: dict, default: {}
Keyword arguments for ``sqlalchemy.create_engine``
sessionmaker: sessionmaker, optional
sqlalchemy.sessionamaker, mutually exclusive with cache_db_urlpath and create_engine_kwargs
cache_files_urlpath: str, default:"/system_tmp_dir/cacholote/cache_files"
URL for cache files (protocol://location).
cache_files_storage_options: dict, default: {}
Expand Down Expand Up @@ -155,15 +163,16 @@ class set:

def __init__(self, **kwargs: Any):
self._old_settings = get()
self._old_engine = database.ENGINE
self._old_session = database.SESSIONMAKER

model_dump = self._old_settings.model_dump()
if kwargs.get("cache_db_urlpath") or kwargs.get("create_engine_kwargs"):
model_dump["sessionmaker"] = None
if kwargs.get("sessionmaker") is not None:
model_dump["cache_db_urlpath"] = None
model_dump["create_engine_kwargs"] = {}

global _SETTINGS
_SETTINGS = Settings(**{**self._old_settings.model_dump(), **kwargs})
_SETTINGS.make_cache_dir()
_SETTINGS.set_engine_and_session(
self._old_settings.create_engine_kwargs != _SETTINGS.create_engine_kwargs
)
_SETTINGS = Settings(**{**model_dump, **kwargs})

def __enter__(self) -> Settings:
return get()
Expand All @@ -174,9 +183,6 @@ def __exit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
database.ENGINE = self._old_engine
database.SESSIONMAKER = self._old_session

global _SETTINGS
_SETTINGS = self._old_settings

Expand Down
18 changes: 7 additions & 11 deletions cacholote/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# limitations under the License.

import datetime
import functools
import json
import warnings
from typing import Any, Dict, Optional
from typing import Any

import sqlalchemy as sa
import sqlalchemy.orm
Expand All @@ -28,9 +29,6 @@
datetime.MAXYEAR, 12, 31, tzinfo=datetime.timezone.utc
)

ENGINE: Optional[sa.engine.Engine] = None
SESSIONMAKER: Optional[sa.orm.sessionmaker] = None # type: ignore[type-arg]

Base = sa.orm.declarative_base()


Expand Down Expand Up @@ -79,10 +77,8 @@ def _commit_or_rollback(session: sa.orm.Session) -> None:
session.rollback()


def _set_engine_and_session(
cache_db_urlpath: str, create_engine_kwargs: Dict[str, Any]
) -> None:
global ENGINE, SESSIONMAKER
ENGINE = sa.create_engine(cache_db_urlpath, future=True, **create_engine_kwargs)
Base.metadata.create_all(ENGINE)
SESSIONMAKER = sa.orm.sessionmaker(ENGINE)
@functools.lru_cache()
def cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker: # type: ignore[type-arg]
engine = sa.create_engine(url, **kwargs)
Base.metadata.create_all(engine)
return sa.orm.sessionmaker(engine)
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests
from moto.moto_server.threaded_moto_server import ThreadedMotoServer

from cacholote import config
from cacholote import config, database


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -44,6 +44,7 @@ def set_cache(
) -> Iterator[str]:
param = getattr(request, "param", "file")
if param.lower() == "cads":
database.cached_sessionmaker.cache_clear()
test_bucket_name = "test-bucket"
client_kwargs = create_test_bucket(s3_server, test_bucket_name)
with config.set(
Expand Down
42 changes: 27 additions & 15 deletions tests/test_01_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,31 @@
import pytest
import sqlalchemy as sa

from cacholote import config, database
from cacholote import config

does_not_raise = contextlib.nullcontext


def test_change_sessionmaker(tmp_path: pathlib.Path) -> None:
old_sessionmaker = config.get().instantiated_sessionmaker
new_db = "sqlite:///" + str(tmp_path / "dummy.db")

with config.set(cache_db_urlpath=new_db):
new_sessionmaker = config.get().instantiated_sessionmaker
assert new_sessionmaker is not old_sessionmaker
assert config.get().instantiated_sessionmaker is old_sessionmaker

with config.set(sessionmaker=new_sessionmaker):
assert config.get().instantiated_sessionmaker is new_sessionmaker
assert config.get().instantiated_sessionmaker is old_sessionmaker

config.set(cache_db_urlpath=new_db)
assert config.get().instantiated_sessionmaker is new_sessionmaker

config.set(sessionmaker=old_sessionmaker)
assert config.get().instantiated_sessionmaker is old_sessionmaker


def test_change_cache_db_urlpath(tmp_path: pathlib.Path) -> None:
old_db = config.get().cache_db_urlpath
new_db = "sqlite:///" + str(tmp_path / "dummy.db")
Expand All @@ -36,7 +56,7 @@ def test_set_engine_and_sessionmaker(
tmp_path: pathlib.Path, key: str, reset: bool
) -> None:
old_engine = config.get().engine
old_sessionmaker = config.get().sessionmaker
old_sessionmaker = config.get().instantiated_sessionmaker

kwargs: Dict[str, Any] = {}
if key == "cache_db_urlpath":
Expand All @@ -51,20 +71,20 @@ def test_set_engine_and_sessionmaker(
with config.set(**kwargs):
if reset:
assert config.get().engine is not old_engine
assert config.get().sessionmaker is not old_sessionmaker
assert config.get().instantiated_sessionmaker is not old_sessionmaker
else:
assert config.get().engine is old_engine
assert config.get().sessionmaker is old_sessionmaker
assert config.get().instantiated_sessionmaker is old_sessionmaker
assert config.get().engine is old_engine
assert config.get().sessionmaker is old_sessionmaker
assert config.get().instantiated_sessionmaker is old_sessionmaker

config.set(**kwargs)
if reset:
assert config.get().engine is not old_engine
assert config.get().sessionmaker is not old_sessionmaker
assert config.get().instantiated_sessionmaker is not old_sessionmaker
else:
assert config.get().engine is old_engine
assert config.get().sessionmaker is old_sessionmaker
assert config.get().instantiated_sessionmaker is old_sessionmaker


def test_env_variables(tmp_path: pathlib.Path) -> None:
Expand Down Expand Up @@ -117,11 +137,3 @@ def test_set_expiration(
) -> None:
with raises:
config.set(expiration=expiration)


@pytest.mark.parametrize("set_cache", ["off"], indirect=True)
def test_engine_and_session_are_initialized() -> None:
with config.set():
pass
assert database.SESSIONMAKER is not None
assert database.ENGINE is not None
9 changes: 7 additions & 2 deletions tests/test_40_xarray_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from cacholote import cache, config, decode, encode, extra_encoders, utils

dask = pytest.importorskip("dask")

try:
import xarray as xr
except ImportError:
Expand Down Expand Up @@ -35,11 +37,14 @@ def test_dictify_xr_dataset(tmp_path: pathlib.Path) -> None:

# Create sample dataset
ds = xr.Dataset({"foo": [0]}, attrs={})
with dask.config.set({"tokenize.ensure-deterministic": True}):
token = dask.base.tokenize(ds)

# Check dict
actual = extra_encoders.dictify_xr_object(ds)
href = f"{readonly_dir}/247fd17e087ae491996519c097e70e48.nc"
local_path = f"{tmp_path}/cache_files/247fd17e087ae491996519c097e70e48.nc"
print(fsspec.filesystem("file").ls(f"{tmp_path}/cache_files"))
href = f"{readonly_dir}/{token}.nc"
local_path = f"{tmp_path}/cache_files/{token}.nc"
expected = {
"type": "python_call",
"callable": "cacholote.extra_encoders:decode_xr_dataset",
Expand Down

0 comments on commit 075b768

Please sign in to comment.