Commit
Introduce an lru_method_cache decorator for caching methods (#8267)
The decorator arranges for a cache to be stored inside the object on
first call.

Inspired by #8261 but without per-object toil.
Fixes #5377.
msullivan authored Jan 29, 2025
1 parent 6d79b49 commit 0fe5693
Showing 8 changed files with 70 additions and 18 deletions.
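
For orientation, a minimal usage sketch of the new decorator (the Catalog class and lookup method below are hypothetical, invented for illustration; only the edb.common.lru import comes from this diff):

from edb.common import lru

class Catalog:
    @lru.lru_method_cache(128)
    def lookup(self, name: str) -> str:
        # Stand-in for an expensive computation; results are cached per instance.
        return name.upper()

c = Catalog()
c.lookup("foo")  # first call stores a wrapped lru_cache on c itself and caches the result
c.lookup("foo")  # served from c's own cache; other Catalog instances are unaffected
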
52 changes: 52 additions & 0 deletions edb/common/lru.py
@@ -20,6 +20,10 @@
from __future__ import annotations

import collections.abc
import functools


from typing import TypeVar, Callable


class LRUMapping(collections.abc.MutableMapping):
@@ -75,3 +79,51 @@ def __len__(self):

    def __iter__(self):
        return iter(self._dict)


Tf = TypeVar('Tf', bound=Callable)


class _NoPickle:
    def __init__(self, obj):
        self.obj = obj

    def __bool__(self):
        return bool(self.obj)

    def __getstate__(self):
        return ()

    def __setstate__(self, _d):
        self.obj = None


def lru_method_cache(size: int | None=128) -> Callable[[Tf], Tf]:
    """A version of lru_cache for methods that shouldn't leak memory.
    Basically the idea is that we generate a per-object lru-cached
    partially applied method.
    Since pickling an lru_cache of a lambda or a functools.partial
    doesn't work, we wrap it in a _NoPickle object that doesn't pickle
    its contents.
    """
    def transformer(f: Tf) -> Tf:
        key = f'__{f.__name__}_cached'

        def func(self, *args, **kwargs):
            _m = getattr(self, key, None)
            if not _m:
                _m = _NoPickle(
                    functools.lru_cache(size)(functools.partial(f, self))
                )
                setattr(self, key, _m)
            return _m.obj(*args, **kwargs)

        return func # type: ignore

    return transformer


def method_cache(f: Tf) -> Tf:
    return lru_method_cache(None)(f)
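
To illustrate what the _NoPickle wrapper is for, here is a rough sketch of the intended pickling behaviour, assuming a plain picklable class (Widget is invented for this example): pickling drops the per-object cache, and the first call on the unpickled copy rebuilds it lazily.

import pickle

from edb.common import lru

class Widget:
    def __init__(self, base: int) -> None:
        self.base = base

    @lru.method_cache
    def doubled(self) -> int:
        return self.base * 2

w = Widget(21)
assert w.doubled() == 42            # first call builds the per-object cache
w2 = pickle.loads(pickle.dumps(w))  # works: _NoPickle.__getstate__ skips the unpicklable lru_cache
assert w2.doubled() == 42           # __setstate__ left obj falsy, so the cache is rebuilt on demand
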
9 changes: 5 additions & 4 deletions edb/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from edb import errors
from edb.common import adapter
from edb.common import english
from edb.common import lru

from . import casts as s_casts
from . import functions as s_func
@@ -1218,7 +1219,7 @@ def get_operators(
type=s_oper.Operator,
)

@functools.lru_cache()
@lru.lru_method_cache()
def _get_casts(
self,
stype: s_types.Type,
@@ -1274,7 +1275,7 @@ def get_referrers(
return self._get_referrers(
scls, scls_type=scls_type, field_name=field_name)

@functools.lru_cache()
@lru.lru_method_cache()
def _get_referrers(
self,
scls: so.Object,
@@ -1312,8 +1313,8 @@ def _get_referrers(

return frozenset(referrers) # type: ignore

@functools.lru_cache()
def get_referrers_ex( # type: ignore
@lru.lru_method_cache()
def get_referrers_ex(
self,
scls: so.Object,
*,
5 changes: 3 additions & 2 deletions edb/server/compiler/sertypes.py
@@ -45,8 +45,9 @@

from edb import errors
from edb.common import binwrapper
from edb.common import value_dispatch
from edb.common import lru
from edb.common import uuidgen
from edb.common import value_dispatch

from edb.protocol import enums as p_enums
from edb.server import config
@@ -1900,7 +1901,7 @@ def get_global_type_rep(


class CompilationConfigSerializer(InputShapeSerializer):
@functools.lru_cache(64)
@lru.lru_method_cache(64)
def encode_configs(
self, *configs: immutables.Map[str, config.SettingValue] | None
) -> bytes:
7 changes: 4 additions & 3 deletions edb/server/compiler_pool/pool.py
@@ -37,6 +37,7 @@
import immutables

from edb.common import debug
from edb.common import lru

from edb.pgsql import params as pgparams

@@ -189,7 +190,7 @@ def _get_init_args(self):
assert self._dbindex is not None
return self._make_init_args(*self._dbindex.get_cached_compiler_args())

@functools.lru_cache(1)
@lru.lru_method_cache(1)
def _make_init_args(self, dbs, global_schema_pickle, system_config):
init_args = (
dbs,
@@ -1103,7 +1104,7 @@ async def stop(self):
if worker.done():
(await worker).close()

@functools.lru_cache(1)
@lru.lru_method_cache(1)
def _make_init_args(self, dbs, global_schema_pickle, system_config):
init_args = (
dbs,
@@ -1337,7 +1338,7 @@ def drop_tenant(self, client_id: int):
for worker in self._workers.values():
worker.invalidate(client_id)

@functools.cache
@lru.method_cache
def _get_init_args(self):
init_args = (
self._backend_runtime_params,
3 changes: 2 additions & 1 deletion edb/server/compiler_pool/server.py
@@ -36,6 +36,7 @@
import immutables

from edb.common import debug
from edb.common import lru
from edb.common import markup

from .. import metrics
@@ -204,7 +205,7 @@ def _init(self, kwargs: dict[str, typing.Any]) -> None:
# this is deferred to _init_server()
pass

@functools.cache
@lru.method_cache
def _get_init_args(self):
init_args = (
self._backend_runtime_params,
4 changes: 2 additions & 2 deletions edb/server/pgcluster.py
@@ -36,7 +36,6 @@

import asyncio
import copy
import functools
import hashlib
import json
import logging
@@ -53,6 +52,7 @@

from edb import buildmeta
from edb import errors
from edb.common import lru
from edb.common import supervisor
from edb.common import uuidgen

@@ -925,7 +925,7 @@ def stop_watching(self) -> None:
if self._ha_backend is not None:
self._ha_backend.stop_watching()

@functools.cache
@lru.method_cache
def get_client_id(self) -> int:
tenant_id = self._instance_params.tenant_id
if self._ha_backend is not None:
4 changes: 2 additions & 2 deletions edb/server/tenant.py
@@ -35,7 +35,6 @@
import asyncio
import contextlib
import dataclasses
import functools
import json
import logging
import os
@@ -54,6 +53,7 @@
from edb import buildmeta
from edb import errors
from edb.common import asyncutil
from edb.common import lru
from edb.common import retryloop
from edb.common import verutils
from edb.common.log import current_tenant
@@ -355,7 +355,7 @@ def get_pg_dbname(self, dbname: str) -> str:
def get_pgaddr(self) -> pgconnparams.ConnectionParams:
return self._cluster.get_pgaddr()

@functools.lru_cache
@lru.method_cache
def get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams:
return self._cluster.get_runtime_params()

4 changes: 0 additions & 4 deletions pyproject.toml
@@ -259,10 +259,6 @@ ignore = [

# TODO: enable this
"B905", # zip() without an explicit strict= parameter

# TODO: enable this (this was tried before - it is non-trivial)
"B019", # Use of functools.lru_cache or functools.cache on methods can lead
# to memory leaks
]
flake8-bugbear.extend-immutable-calls = ["immutables.Map"]
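
For context on the removed ignore: B019 flags exactly the leak that lru_method_cache sidesteps, because a plain functools.lru_cache on a method lives at class level and keeps a strong reference to every self it has seen. A small illustration (the Leaky class is made up for this sketch, not part of the diff):

import functools
import gc
import weakref

class Leaky:
    @functools.lru_cache(maxsize=128)   # B019: the cache is shared at class level
    def compute(self, x: int) -> int:
        return x * 2

obj = Leaky()
obj.compute(1)
ref = weakref.ref(obj)
del obj
gc.collect()
assert ref() is not None   # the class-level cache still references the instance, so it is never freed
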

Expand Down
