Skip to content

Commit

Permalink
config: Add support for remapping Postgres configs into Gel enums (#8275
Browse files Browse the repository at this point in the history
)

We have a long-standing rule that configuration parameters should avoid
using a boolean type and use enums instead (though we've not generally
been super diligent about this always).  Postgres, on the other hand,
has lots of boolean settings.  To solve this, teach the config framework
how to remap backend config values onto arbitrary frontend enums.  In
general the below is all that is required:

    class EnabledDisabledEnum(enum.StrEnum):
        Enabled = "Enabled"
        Disabled = "Disabled"

    class EnabledDisabledType(
        EnumScalarType[EnabledDisabledEnum],
        edgeql_type="cfg::TestEnabledDisabledEnum",
    ):
        @classmethod
def get_translation_map(cls) -> Mapping[EnabledDisabledEnum, str]:
            return {
                EnabledDisabledEnum.Enabled: "true",
                EnabledDisabledEnum.Disabled: "false",
            }

BACKPORT NOTES: The patches won't work until the next two commits are
applied. Fixing patches wound up being downstream of this PR in
annoying way.
  • Loading branch information
elprans authored and msullivan committed Feb 4, 2025
1 parent 34923bd commit ace9b70
Show file tree
Hide file tree
Showing 12 changed files with 528 additions and 84 deletions.
44 changes: 20 additions & 24 deletions edb/ir/staeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

from edb.common import typeutils
from edb.common import parsing
from edb.common import uuidgen
from edb.common import value_dispatch
from edb.edgeql import ast as qlast
from edb.edgeql import compiler as qlcompiler
Expand Down Expand Up @@ -495,6 +494,9 @@ def bool_const_to_python(
def cast_const_to_python(ir: irast.TypeCast, schema: s_schema.Schema) -> Any:

schema, stype = irtyputils.ir_typeref_to_type(schema, ir.to_type)
if not isinstance(stype, s_scalars.ScalarType):
raise UnsupportedExpressionError(
"non-scalar casts are not supported in Python eval")
pytype = scalar_type_to_python_type(stype, schema)
sval = evaluate_to_python_val(ir.expr, schema=schema)
return python_cast(sval, pytype)
Expand Down Expand Up @@ -544,31 +546,23 @@ def schema_type_to_python_type(
f'{stype.get_displayname(schema)} is not representable in Python')


typemap = {
'std::str': str,
'std::anyint': int,
'std::anyfloat': float,
'std::decimal': decimal.Decimal,
'std::bigint': decimal.Decimal,
'std::bool': bool,
'std::json': str,
'std::uuid': uuidgen.UUID,
'std::duration': statypes.Duration,
'cfg::memory': statypes.ConfigMemory,
}


def scalar_type_to_python_type(
stype: s_types.Type,
stype: s_scalars.ScalarType,
schema: s_schema.Schema,
) -> type:
for basetype_name, pytype in typemap.items():
basetype = schema.get(
basetype_name, type=s_scalars.ScalarType, default=None)
if basetype and stype.issubclass(schema, basetype):
return pytype

if stype.is_enum(schema):
typname = stype.get_name(schema)
pytype = statypes.maybe_get_python_type_for_scalar_type_name(str(typname))
if pytype is None:
for ancestor in stype.get_ancestors(schema).objects(schema):
typname = ancestor.get_name(schema)
pytype = statypes.maybe_get_python_type_for_scalar_type_name(
str(typname))
if pytype is not None:
break

if pytype is not None:
return pytype
elif stype.is_enum(schema):
return str

raise UnsupportedExpressionError(
Expand Down Expand Up @@ -618,8 +612,10 @@ def object_type_to_spec(
ptype, schema, spec_class=spec_class,
parent=parent, _memo=_memo)
_memo[ptype] = pytype
else:
elif isinstance(ptype, s_scalars.ScalarType):
pytype = scalar_type_to_python_type(ptype, schema)
else:
raise UnsupportedExpressionError(f"unsupported cast type: {ptype}")

ptr_card: qltypes.SchemaCardinality = p.get_cardinality(schema)
if ptr_card.is_known():
Expand Down
234 changes: 231 additions & 3 deletions edb/ir/statypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,34 @@


from __future__ import annotations
from typing import Any, Optional
from typing import (
Any,
Callable,
ClassVar,
Generic,
Mapping,
Optional,
Self,
TypeVar,
)

import dataclasses
import datetime
import decimal
import enum
import functools
import re
import struct
import datetime
import uuid

import immutables

from edb import errors
from edb.common import parametric
from edb.common import uuidgen

from edb.schema import name as s_name
from edb.schema import objects as s_obj

MISSING: Any = object()

Expand Down Expand Up @@ -100,6 +117,14 @@ def __init__(self, val: str, /) -> None:
def to_backend_str(self) -> str:
raise NotImplementedError

@classmethod
def to_backend_expr(cls, expr: str) -> str:
raise NotImplementedError("{cls}.to_backend_expr()")

@classmethod
def to_frontend_expr(cls, expr: str) -> Optional[str]:
raise NotImplementedError("{cls}.to_frontend_expr()")

def to_json(self) -> str:
raise NotImplementedError

Expand Down Expand Up @@ -375,6 +400,14 @@ def to_timedelta(self) -> datetime.timedelta:
def to_backend_str(self) -> str:
return f'{self.to_microseconds()}us'

@classmethod
def to_backend_expr(cls, expr: str) -> str:
return f"edgedb_VER._interval_to_ms(({expr})::interval)::text || 'ms'"

@classmethod
def to_frontend_expr(cls, expr: str) -> Optional[str]:
return None

def to_json(self) -> str:
return self.to_iso8601()

Expand Down Expand Up @@ -494,6 +527,14 @@ def to_backend_str(self) -> str:

return f'{self._value}B'

@classmethod
def to_backend_expr(cls, expr: str) -> str:
return f"edgedb_VER.cfg_memory_to_str({expr})"

@classmethod
def to_frontend_expr(cls, expr: str) -> Optional[str]:
return f"(edgedb_VER.str_to_cfg_memory({expr})::text || 'B')"

def to_json(self) -> str:
return self.to_str()

Expand All @@ -503,8 +544,195 @@ def __repr__(self) -> str:
def __hash__(self) -> int:
return hash(self._value)

def __eq__(self, other: object) -> bool:
def __eq__(self, other: Any) -> bool:
if isinstance(other, ConfigMemory):
return self._value == other._value
else:
return False


typemap = {
'std::str': str,
'std::anyint': int,
'std::anyfloat': float,
'std::decimal': decimal.Decimal,
'std::bigint': decimal.Decimal,
'std::bool': bool,
'std::json': str,
'std::uuid': uuidgen.UUID,
'std::duration': Duration,
'cfg::memory': ConfigMemory,
}


def maybe_get_python_type_for_scalar_type_name(name: str) -> Optional[type]:
return typemap.get(name)


E = TypeVar("E", bound=enum.StrEnum)


class EnumScalarType(
ScalarType,
parametric.SingleParametricType[E],
Generic[E],
):
"""Configuration value represented by a custom string enum type that
supports arbitrary value mapping to backend (Postgres) configuration
values, e.g mapping "Enabled"/"Disabled" enum to a bool value, etc.
We use SingleParametricType to obtain runtime access to the Generic
type arg to avoid having to copy-paste the constructors.
"""

_val: E
_eql_type: ClassVar[Optional[s_name.QualName]]

def __init_subclass__(
cls,
*,
edgeql_type: Optional[str] = None,
**kwargs: Any,
) -> None:
global typemap
super().__init_subclass__(**kwargs)
if edgeql_type is not None:
if edgeql_type in typemap:
raise TypeError(
f"{edgeql_type} is already a registered EnumScalarType")
typemap[edgeql_type] = cls
cls._eql_type = s_name.QualName.from_string(edgeql_type)

def __init__(
self,
val: E | str,
) -> None:
if isinstance(val, self.type):
self._val = val
elif isinstance(val, str):
try:
self._val = self.type(val)
except ValueError:
raise errors.InvalidValueError(
f'unexpected backend value for '
f'{self.__class__.__name__}: {val!r}'
) from None

def to_str(self) -> str:
return str(self._val)

def to_json(self) -> str:
return self._val

def encode(self) -> bytes:
return self._val.encode("utf8")

@classmethod
def get_translation_map(cls) -> Mapping[E, str]:
raise NotImplementedError

@classmethod
def decode(cls, data: bytes) -> Self:
return cls(val=cls.type(data.decode("utf8")))

def __repr__(self) -> str:
return f"<statypes.{self.__class__.__name__} '{self._val}'>"

def __hash__(self) -> int:
return hash(self._val)

def __eq__(self, other: Any) -> bool:
if isinstance(other, type(self)):
return self._val == other._val
else:
return NotImplemented

def __reduce__(self) -> tuple[
Callable[..., EnumScalarType[Any]],
tuple[
Optional[tuple[type, ...] | type],
E,
],
]:
assert type(self).is_fully_resolved(), \
f'{type(self)} parameters are not resolved'

cls: type[EnumScalarType[E]] = self.__class__
types: Optional[tuple[type, ...]] = self.orig_args
if types is None or not cls.is_anon_parametrized():
typeargs = None
else:
typeargs = types[0] if len(types) == 1 else types
return (cls.__restore__, (typeargs, self._val))

@classmethod
def __restore__(
cls,
typeargs: Optional[tuple[type, ...] | type],
val: E,
) -> Self:
if typeargs is None or cls.is_anon_parametrized():
obj = cls(val)
else:
obj = cls[typeargs](val) # type: ignore[index]

return obj

@classmethod
def get_edgeql_typeid(cls) -> uuid.UUID:
return s_obj.get_known_type_id('std::str')

@classmethod
def get_edgeql_type(cls) -> s_name.QualName:
"""Return fully-qualified name of the scalar type for this setting."""
assert cls._eql_type is not None
return cls._eql_type

def to_backend_str(self) -> str:
"""Convert static frontend config value to backend config value."""
return self.get_translation_map()[self._val]

@classmethod
def to_backend_expr(cls, expr: str) -> str:
"""Convert dynamic backend config value to frontend config value."""
cases_list = []
for fe_val, be_val in cls.get_translation_map().items():
cases_list.append(f"WHEN lower('{fe_val}') THEN '{be_val}'")
cases = "\n".join(cases_list)
errmsg = f"unexpected frontend value for {cls.__name__}: %s"
err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))"
return (
f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND "
f"FROM lower(({expr})) AS f(v))"
)

@classmethod
def to_frontend_expr(cls, expr: str) -> Optional[str]:
"""Convert dynamic frontend config value to backend config value."""
cases_list = []
for fe_val, be_val in cls.get_translation_map().items():
cases_list.append(f"WHEN lower('{be_val}') THEN '{fe_val}'")
cases = "\n".join(cases_list)
errmsg = f"unexpected backend value for {cls.__name__}: %s"
err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))"
return (
f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND "
f"FROM lower(({expr})) AS f(v))"
)


class EnabledDisabledEnum(enum.StrEnum):
Enabled = "Enabled"
Disabled = "Disabled"


class EnabledDisabledType(
EnumScalarType[EnabledDisabledEnum],
edgeql_type="cfg::TestEnabledDisabledEnum",
):
@classmethod
def get_translation_map(cls) -> Mapping[EnabledDisabledEnum, str]:
return {
EnabledDisabledEnum.Enabled: "true",
EnabledDisabledEnum.Disabled: "false",
}
10 changes: 9 additions & 1 deletion edb/lib/_testmode.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ CREATE TYPE cfg::TestInstanceConfigStatTypes EXTENDING cfg::TestInstanceConfig {
};


CREATE SCALAR TYPE cfg::TestEnum extending enum<One, Two, Three>;
CREATE SCALAR TYPE cfg::TestEnum EXTENDING enum<One, Two, Three>;
CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum
EXTENDING enum<Enabled, Disabled>;


ALTER TYPE cfg::AbstractConfig {
Expand Down Expand Up @@ -141,6 +143,12 @@ ALTER TYPE cfg::AbstractConfig {
CREATE ANNOTATION cfg::internal := 'true';
CREATE ANNOTATION cfg::backend_setting := '"max_connections"';
};

CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum {
CREATE ANNOTATION cfg::internal := 'true';
CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"';
SET default := cfg::TestEnabledDisabledEnum.Enabled;
};
};


Expand Down
Loading

0 comments on commit ace9b70

Please sign in to comment.