diff --git a/edb/common/lru.py b/edb/common/lru.py
index 89963fbd9d8..2bbb29c120c 100644
--- a/edb/common/lru.py
+++ b/edb/common/lru.py
@@ -20,6 +20,10 @@
 from __future__ import annotations
 
 import collections.abc
+import functools
+
+
+from typing import TypeVar, Callable
 
 
 class LRUMapping(collections.abc.MutableMapping):
@@ -75,3 +79,51 @@ def __len__(self):
 
     def __iter__(self):
         return iter(self._dict)
+
+
+Tf = TypeVar('Tf', bound=Callable)
+
+
+class _NoPickle:
+    def __init__(self, obj):
+        self.obj = obj
+
+    def __bool__(self):
+        return bool(self.obj)
+
+    def __getstate__(self):
+        return ()
+
+    def __setstate__(self, _d):
+        self.obj = None
+
+
+def lru_method_cache(size: int | None=128) -> Callable[[Tf], Tf]:
+    """A version of lru_cache for methods that shouldn't leak memory.
+
+    Basically the idea is that we generate a per-object lru-cached
+    partially applied method.
+
+    Since pickling an lru_cache of a lambda or a functools.partial
+    doesn't work, we wrap it in a _NoPickle object that doesn't pickle
+    its contents.
+    """
+    def transformer(f: Tf) -> Tf:
+        key = f'__{f.__name__}_cached'
+
+        def func(self, *args, **kwargs):
+            _m = getattr(self, key, None)
+            if not _m:
+                _m = _NoPickle(
+                    functools.lru_cache(size)(functools.partial(f, self))
+                )
+                setattr(self, key, _m)
+            return _m.obj(*args, **kwargs)
+
+        return func  # type: ignore
+
+    return transformer
+
+
+def method_cache(f: Tf) -> Tf:
+    return lru_method_cache(None)(f)
diff --git a/edb/schema/schema.py b/edb/schema/schema.py
index e0098fc607e..8967207ea3b 100644
--- a/edb/schema/schema.py
+++ b/edb/schema/schema.py
@@ -51,6 +51,7 @@
 from edb import errors
 from edb.common import adapter
 from edb.common import english
+from edb.common import lru
 
 from . import casts as s_casts
 from . import functions as s_func
@@ -1218,7 +1219,7 @@ def get_operators(
             type=s_oper.Operator,
         )
 
-    @functools.lru_cache()
+    @lru.lru_method_cache()
     def _get_casts(
         self,
         stype: s_types.Type,
@@ -1274,7 +1275,7 @@ def get_referrers(
         return self._get_referrers(
             scls, scls_type=scls_type, field_name=field_name)
 
-    @functools.lru_cache()
+    @lru.lru_method_cache()
     def _get_referrers(
         self,
         scls: so.Object,
@@ -1312,8 +1313,8 @@ def _get_referrers(
 
         return frozenset(referrers)  # type: ignore
 
-    @functools.lru_cache()
-    def get_referrers_ex(  # type: ignore
+    @lru.lru_method_cache()
+    def get_referrers_ex(
         self,
         scls: so.Object,
         *,
diff --git a/edb/server/compiler/sertypes.py b/edb/server/compiler/sertypes.py
index dcc83fdeb49..5a1b5a3a562 100644
--- a/edb/server/compiler/sertypes.py
+++ b/edb/server/compiler/sertypes.py
@@ -45,8 +45,9 @@
 from edb import errors
 
 from edb.common import binwrapper
-from edb.common import value_dispatch
+from edb.common import lru
 from edb.common import uuidgen
+from edb.common import value_dispatch
 
 from edb.protocol import enums as p_enums
 from edb.server import config
@@ -1900,7 +1901,7 @@ def get_global_type_rep(
 
 
 class CompilationConfigSerializer(InputShapeSerializer):
-    @functools.lru_cache(64)
+    @lru.lru_method_cache(64)
     def encode_configs(
         self, *configs: immutables.Map[str, config.SettingValue] | None
     ) -> bytes:
diff --git a/edb/server/compiler_pool/pool.py b/edb/server/compiler_pool/pool.py
index 4d3d69ca91c..9023024e6b8 100644
--- a/edb/server/compiler_pool/pool.py
+++ b/edb/server/compiler_pool/pool.py
@@ -37,6 +37,7 @@
 
 import immutables
 
 from edb.common import debug
+from edb.common import lru
 
 from edb.pgsql import params as pgparams
@@ -189,7 +190,7 @@ def _get_init_args(self):
         assert self._dbindex is not None
         return self._make_init_args(*self._dbindex.get_cached_compiler_args())
 
-    @functools.lru_cache(1)
+    @lru.lru_method_cache(1)
     def _make_init_args(self, dbs, global_schema_pickle, system_config):
         init_args = (
             dbs,
@@ -1103,7 +1104,7 @@ async def stop(self):
             if worker.done():
                 (await worker).close()
 
-    @functools.lru_cache(1)
+    @lru.lru_method_cache(1)
     def _make_init_args(self, dbs, global_schema_pickle, system_config):
         init_args = (
             dbs,
@@ -1337,7 +1338,7 @@ def drop_tenant(self, client_id: int):
         for worker in self._workers.values():
             worker.invalidate(client_id)
 
-    @functools.cache
+    @lru.method_cache
     def _get_init_args(self):
         init_args = (
             self._backend_runtime_params,
diff --git a/edb/server/compiler_pool/server.py b/edb/server/compiler_pool/server.py
index d5a5f42eca9..4596e226c8f 100644
--- a/edb/server/compiler_pool/server.py
+++ b/edb/server/compiler_pool/server.py
@@ -36,6 +36,7 @@
 import immutables
 
 from edb.common import debug
+from edb.common import lru
 from edb.common import markup
 
 from .. import metrics
@@ -204,7 +205,7 @@ def _init(self, kwargs: dict[str, typing.Any]) -> None:
         # this is deferred to _init_server()
         pass
 
-    @functools.cache
+    @lru.method_cache
     def _get_init_args(self):
         init_args = (
             self._backend_runtime_params,
diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py
index c863772b107..b745173ec9e 100644
--- a/edb/server/pgcluster.py
+++ b/edb/server/pgcluster.py
@@ -36,7 +36,6 @@
 
 import asyncio
 import copy
-import functools
 import hashlib
 import json
 import logging
@@ -53,6 +52,7 @@
 
 from edb import buildmeta
 from edb import errors
+from edb.common import lru
 from edb.common import supervisor
 from edb.common import uuidgen
 
@@ -925,7 +925,7 @@ def stop_watching(self) -> None:
         if self._ha_backend is not None:
             self._ha_backend.stop_watching()
 
-    @functools.cache
+    @lru.method_cache
     def get_client_id(self) -> int:
         tenant_id = self._instance_params.tenant_id
         if self._ha_backend is not None:
diff --git a/edb/server/tenant.py b/edb/server/tenant.py
index 50130f2a17f..bbdb0939e69 100644
--- a/edb/server/tenant.py
+++ b/edb/server/tenant.py
@@ -35,7 +35,6 @@
 import asyncio
 import contextlib
 import dataclasses
-import functools
 import json
 import logging
 import os
@@ -54,6 +53,7 @@
 from edb import buildmeta
 from edb import errors
 from edb.common import asyncutil
+from edb.common import lru
 from edb.common import retryloop
 from edb.common import verutils
 from edb.common.log import current_tenant
@@ -355,7 +355,7 @@ def get_pg_dbname(self, dbname: str) -> str:
 
     def get_pgaddr(self) -> pgconnparams.ConnectionParams:
         return self._cluster.get_pgaddr()
 
-    @functools.lru_cache
+    @lru.method_cache
     def get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams:
         return self._cluster.get_runtime_params()
diff --git a/pyproject.toml b/pyproject.toml
index a27450d7db9..6d601b285b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -259,10 +259,6 @@ ignore = [
 
     # TODO: enable this
     "B905",  # zip() without an explicit strict= parameter
-
-    # TODO: enable this (this was tried before - it is non-trivial)
-    "B019",  # Use of functools.lru_cache or functools.cache on methods can lead
-             # to memory leaks
 ]
 
 flake8-bugbear.extend-immutable-calls = ["immutables.Map"]
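For reviewers, a minimal usage sketch of the new decorator (not part of the patch). The Demo class, its lookup method, and the cache size are made up for illustration; the only thing assumed from the patch itself is the edb.common.lru module added above.

# Hypothetical example exercising lru.lru_method_cache on a throwaway class.
import pickle

from edb.common import lru


class Demo:
    def __init__(self, data):
        self.data = data

    @lru.lru_method_cache(2)
    def lookup(self, key):
        # Pretend this is expensive; results are cached per instance
        # (in a `__lookup_cached` attribute on `d`), per argument.
        return self.data.get(key)


d = Demo({'a': 1})
assert d.lookup('a') == 1   # computed once, then cached on `d`
assert d.lookup('a') == 1   # served from the per-instance cache

# Pickling still works: _NoPickle's __getstate__/__setstate__ drop the
# unpicklable lru_cache, and the copy lazily rebuilds it on the next call.
d2 = pickle.loads(pickle.dumps(d))
assert d2.lookup('a') == 1

Because the cache lives on the instance rather than on the function object (as it would with functools.lru_cache applied directly to a method), it is released together with the object, which is what lets the B019 memory-leak ignore in pyproject.toml be removed.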