From ace9b709f48944f4cbabbf2d9aca64f85a622fa8 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 30 Jan 2025 10:44:27 -0800 Subject: [PATCH] config: Add support for remapping Postgres configs into Gel enums (#8275) We have a long-standing rule that configuration parameters should avoid using a boolean type and use enums instead (though we've not generally been super diligent about this always). Postgres, on the other hand, has lots of boolean settings. To solve this, teach the config framework how to remap backend config values onto arbitrary frontend enums. In general the below is all that is required: class EnabledDisabledEnum(enum.StrEnum): Enabled = "Enabled" Disabled = "Disabled" class EnabledDisabledType( EnumScalarType[EnabledDisabledEnum], edgeql_type="cfg::TestEnabledDisabledEnum", ): @classmethod def get_translation_map(cls) -> Mapping[EnabledDisabledEnum, str]: return { EnabledDisabledEnum.Enabled: "true", EnabledDisabledEnum.Disabled: "false", } BACKPORT NOTES: The patches won't work until the next two commits are applied. Fixing patches wound up being downstream of this PR in annoying way. --- edb/ir/staeval.py | 44 +++--- edb/ir/statypes.py | 234 +++++++++++++++++++++++++++++++- edb/lib/_testmode.edgeql | 10 +- edb/pgsql/metaschema.py | 128 ++++++++++++----- edb/pgsql/patches.py | 14 ++ edb/schema/utils.py | 9 ++ edb/server/bootstrap.py | 19 +-- edb/server/compiler/sertypes.py | 26 +++- edb/server/config/ops.py | 7 + edb/server/config/spec.py | 25 +++- edb/server/config/types.py | 5 + tests/test_server_config.py | 91 +++++++++++++ 12 files changed, 528 insertions(+), 84 deletions(-) diff --git a/edb/ir/staeval.py b/edb/ir/staeval.py index 81fcbd32ce0..7ce152d8af8 100644 --- a/edb/ir/staeval.py +++ b/edb/ir/staeval.py @@ -43,7 +43,6 @@ from edb.common import typeutils from edb.common import parsing -from edb.common import uuidgen from edb.common import value_dispatch from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler @@ -495,6 +494,9 @@ def bool_const_to_python( def cast_const_to_python(ir: irast.TypeCast, schema: s_schema.Schema) -> Any: schema, stype = irtyputils.ir_typeref_to_type(schema, ir.to_type) + if not isinstance(stype, s_scalars.ScalarType): + raise UnsupportedExpressionError( + "non-scalar casts are not supported in Python eval") pytype = scalar_type_to_python_type(stype, schema) sval = evaluate_to_python_val(ir.expr, schema=schema) return python_cast(sval, pytype) @@ -544,31 +546,23 @@ def schema_type_to_python_type( f'{stype.get_displayname(schema)} is not representable in Python') -typemap = { - 'std::str': str, - 'std::anyint': int, - 'std::anyfloat': float, - 'std::decimal': decimal.Decimal, - 'std::bigint': decimal.Decimal, - 'std::bool': bool, - 'std::json': str, - 'std::uuid': uuidgen.UUID, - 'std::duration': statypes.Duration, - 'cfg::memory': statypes.ConfigMemory, -} - - def scalar_type_to_python_type( - stype: s_types.Type, + stype: s_scalars.ScalarType, schema: s_schema.Schema, ) -> type: - for basetype_name, pytype in typemap.items(): - basetype = schema.get( - basetype_name, type=s_scalars.ScalarType, default=None) - if basetype and stype.issubclass(schema, basetype): - return pytype - - if stype.is_enum(schema): + typname = stype.get_name(schema) + pytype = statypes.maybe_get_python_type_for_scalar_type_name(str(typname)) + if pytype is None: + for ancestor in stype.get_ancestors(schema).objects(schema): + typname = ancestor.get_name(schema) + pytype = statypes.maybe_get_python_type_for_scalar_type_name( + str(typname)) + if pytype is not None: + break + + if pytype is not None: + return pytype + elif stype.is_enum(schema): return str raise UnsupportedExpressionError( @@ -618,8 +612,10 @@ def object_type_to_spec( ptype, schema, spec_class=spec_class, parent=parent, _memo=_memo) _memo[ptype] = pytype - else: + elif isinstance(ptype, s_scalars.ScalarType): pytype = scalar_type_to_python_type(ptype, schema) + else: + raise UnsupportedExpressionError(f"unsupported cast type: {ptype}") ptr_card: qltypes.SchemaCardinality = p.get_cardinality(schema) if ptr_card.is_known(): diff --git a/edb/ir/statypes.py b/edb/ir/statypes.py index 237d21b6cb9..8e86853f3b4 100644 --- a/edb/ir/statypes.py +++ b/edb/ir/statypes.py @@ -18,17 +18,34 @@ from __future__ import annotations -from typing import Any, Optional +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Mapping, + Optional, + Self, + TypeVar, +) import dataclasses +import datetime +import decimal +import enum import functools import re import struct -import datetime +import uuid import immutables from edb import errors +from edb.common import parametric +from edb.common import uuidgen + +from edb.schema import name as s_name +from edb.schema import objects as s_obj MISSING: Any = object() @@ -100,6 +117,14 @@ def __init__(self, val: str, /) -> None: def to_backend_str(self) -> str: raise NotImplementedError + @classmethod + def to_backend_expr(cls, expr: str) -> str: + raise NotImplementedError("{cls}.to_backend_expr()") + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + raise NotImplementedError("{cls}.to_frontend_expr()") + def to_json(self) -> str: raise NotImplementedError @@ -375,6 +400,14 @@ def to_timedelta(self) -> datetime.timedelta: def to_backend_str(self) -> str: return f'{self.to_microseconds()}us' + @classmethod + def to_backend_expr(cls, expr: str) -> str: + return f"edgedb_VER._interval_to_ms(({expr})::interval)::text || 'ms'" + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + return None + def to_json(self) -> str: return self.to_iso8601() @@ -494,6 +527,14 @@ def to_backend_str(self) -> str: return f'{self._value}B' + @classmethod + def to_backend_expr(cls, expr: str) -> str: + return f"edgedb_VER.cfg_memory_to_str({expr})" + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + return f"(edgedb_VER.str_to_cfg_memory({expr})::text || 'B')" + def to_json(self) -> str: return self.to_str() @@ -503,8 +544,195 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(self._value) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, ConfigMemory): return self._value == other._value else: return False + + +typemap = { + 'std::str': str, + 'std::anyint': int, + 'std::anyfloat': float, + 'std::decimal': decimal.Decimal, + 'std::bigint': decimal.Decimal, + 'std::bool': bool, + 'std::json': str, + 'std::uuid': uuidgen.UUID, + 'std::duration': Duration, + 'cfg::memory': ConfigMemory, +} + + +def maybe_get_python_type_for_scalar_type_name(name: str) -> Optional[type]: + return typemap.get(name) + + +E = TypeVar("E", bound=enum.StrEnum) + + +class EnumScalarType( + ScalarType, + parametric.SingleParametricType[E], + Generic[E], +): + """Configuration value represented by a custom string enum type that + supports arbitrary value mapping to backend (Postgres) configuration + values, e.g mapping "Enabled"/"Disabled" enum to a bool value, etc. + + We use SingleParametricType to obtain runtime access to the Generic + type arg to avoid having to copy-paste the constructors. + """ + + _val: E + _eql_type: ClassVar[Optional[s_name.QualName]] + + def __init_subclass__( + cls, + *, + edgeql_type: Optional[str] = None, + **kwargs: Any, + ) -> None: + global typemap + super().__init_subclass__(**kwargs) + if edgeql_type is not None: + if edgeql_type in typemap: + raise TypeError( + f"{edgeql_type} is already a registered EnumScalarType") + typemap[edgeql_type] = cls + cls._eql_type = s_name.QualName.from_string(edgeql_type) + + def __init__( + self, + val: E | str, + ) -> None: + if isinstance(val, self.type): + self._val = val + elif isinstance(val, str): + try: + self._val = self.type(val) + except ValueError: + raise errors.InvalidValueError( + f'unexpected backend value for ' + f'{self.__class__.__name__}: {val!r}' + ) from None + + def to_str(self) -> str: + return str(self._val) + + def to_json(self) -> str: + return self._val + + def encode(self) -> bytes: + return self._val.encode("utf8") + + @classmethod + def get_translation_map(cls) -> Mapping[E, str]: + raise NotImplementedError + + @classmethod + def decode(cls, data: bytes) -> Self: + return cls(val=cls.type(data.decode("utf8"))) + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self._val) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, type(self)): + return self._val == other._val + else: + return NotImplemented + + def __reduce__(self) -> tuple[ + Callable[..., EnumScalarType[Any]], + tuple[ + Optional[tuple[type, ...] | type], + E, + ], + ]: + assert type(self).is_fully_resolved(), \ + f'{type(self)} parameters are not resolved' + + cls: type[EnumScalarType[E]] = self.__class__ + types: Optional[tuple[type, ...]] = self.orig_args + if types is None or not cls.is_anon_parametrized(): + typeargs = None + else: + typeargs = types[0] if len(types) == 1 else types + return (cls.__restore__, (typeargs, self._val)) + + @classmethod + def __restore__( + cls, + typeargs: Optional[tuple[type, ...] | type], + val: E, + ) -> Self: + if typeargs is None or cls.is_anon_parametrized(): + obj = cls(val) + else: + obj = cls[typeargs](val) # type: ignore[index] + + return obj + + @classmethod + def get_edgeql_typeid(cls) -> uuid.UUID: + return s_obj.get_known_type_id('std::str') + + @classmethod + def get_edgeql_type(cls) -> s_name.QualName: + """Return fully-qualified name of the scalar type for this setting.""" + assert cls._eql_type is not None + return cls._eql_type + + def to_backend_str(self) -> str: + """Convert static frontend config value to backend config value.""" + return self.get_translation_map()[self._val] + + @classmethod + def to_backend_expr(cls, expr: str) -> str: + """Convert dynamic backend config value to frontend config value.""" + cases_list = [] + for fe_val, be_val in cls.get_translation_map().items(): + cases_list.append(f"WHEN lower('{fe_val}') THEN '{be_val}'") + cases = "\n".join(cases_list) + errmsg = f"unexpected frontend value for {cls.__name__}: %s" + err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))" + return ( + f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND " + f"FROM lower(({expr})) AS f(v))" + ) + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + """Convert dynamic frontend config value to backend config value.""" + cases_list = [] + for fe_val, be_val in cls.get_translation_map().items(): + cases_list.append(f"WHEN lower('{be_val}') THEN '{fe_val}'") + cases = "\n".join(cases_list) + errmsg = f"unexpected backend value for {cls.__name__}: %s" + err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))" + return ( + f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND " + f"FROM lower(({expr})) AS f(v))" + ) + + +class EnabledDisabledEnum(enum.StrEnum): + Enabled = "Enabled" + Disabled = "Disabled" + + +class EnabledDisabledType( + EnumScalarType[EnabledDisabledEnum], + edgeql_type="cfg::TestEnabledDisabledEnum", +): + @classmethod + def get_translation_map(cls) -> Mapping[EnabledDisabledEnum, str]: + return { + EnabledDisabledEnum.Enabled: "true", + EnabledDisabledEnum.Disabled: "false", + } diff --git a/edb/lib/_testmode.edgeql b/edb/lib/_testmode.edgeql index 6fb29b1f620..761a5dc53bb 100644 --- a/edb/lib/_testmode.edgeql +++ b/edb/lib/_testmode.edgeql @@ -56,7 +56,9 @@ CREATE TYPE cfg::TestInstanceConfigStatTypes EXTENDING cfg::TestInstanceConfig { }; -CREATE SCALAR TYPE cfg::TestEnum extending enum; +CREATE SCALAR TYPE cfg::TestEnum EXTENDING enum; +CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum + EXTENDING enum; ALTER TYPE cfg::AbstractConfig { @@ -141,6 +143,12 @@ ALTER TYPE cfg::AbstractConfig { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::backend_setting := '"max_connections"'; }; + + CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum { + CREATE ANNOTATION cfg::internal := 'true'; + CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"'; + SET default := cfg::TestEnabledDisabledEnum.Enabled; + }; }; diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 8fcd01f5ed7..0a696b43072 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -1747,11 +1747,11 @@ class AssertJSONTypeFunction(trampoline.VersionedFunction): 'wrong_object_type', msg => coalesce( msg, - ( - 'expected JSON ' - || array_to_string(typenames, ' or ') - || '; got JSON ' - || coalesce(jsonb_typeof(val), 'UNKNOWN') + format( + 'expected JSON %s; got JSON %s: %s', + array_to_string(typenames, ' or '), + coalesce(jsonb_typeof(val), 'UNKNOWN'), + val::text ) ), detail => detail @@ -3392,7 +3392,7 @@ class InterpretConfigValueToJsonFunction(trampoline.VersionedFunction): - for memory size: we always convert to kilobytes; - already unitless numbers are left as is. - See https://www.postgresql.org/docs/12/config-setting.html + See https://www.postgresql.org/docs/current/config-setting.html for information about the units Postgres config system has. """ @@ -3444,6 +3444,53 @@ def __init__(self) -> None: ) +class PostgresJsonConfigValueToFrontendConfigValueFunction( + trampoline.VersionedFunction, +): + """Convert a Postgres config value to frontend config value. + + Most values are retained as-is, but some need translation, which + is implemented as a to_frontend_expr() on the corresponding + setting ScalarType. + """ + + def __init__(self, config_spec: edbconfig.Spec) -> None: + variants_list = [] + for setting in config_spec.values(): + if ( + setting.backend_setting + and isinstance(setting.type, type) + and issubclass(setting.type, statypes.ScalarType) + ): + conv_expr = setting.type.to_frontend_expr('"value"->>0') + if conv_expr is not None: + variants_list.append(f""" + WHEN {ql(setting.backend_setting)} + THEN to_jsonb({conv_expr}) + """) + + variants = "\n".join(variants_list) + text = f""" + SELECT ( + CASE "setting_name" + {variants} + ELSE "value" + END + ) + """ + + super().__init__( + name=('edgedb', '_postgres_json_config_value_to_fe_config_value'), + args=[ + ('setting_name', ('text',)), + ('value', ('jsonb',)) + ], + returns=('jsonb',), + volatility='immutable', + text=text, + ) + + class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): """Convert a Postgres setting to JSON value. @@ -3471,26 +3518,10 @@ class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): text = r""" SELECT - (CASE - - WHEN parsed_value.unit != '' - THEN - edgedb_VER._interpret_config_value_to_json( - parsed_value.val, - settings.vartype, - 1, - parsed_value.unit - ) - - ELSE - edgedb_VER._interpret_config_value_to_json( - "setting_value", - settings.vartype, - settings.multiplier, - settings.unit - ) - - END) + edgedb_VER._postgres_json_config_value_to_fe_config_value( + "setting_name", + backend_json_value.value + ) FROM LATERAL ( SELECT regexp_match( @@ -3521,8 +3552,29 @@ class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): as vartype, COALESCE(settings_in.multiplier, '1') as multiplier, COALESCE(settings_in.unit, '') as unit - ) as settings + ) AS settings + CROSS JOIN LATERAL + (SELECT + (CASE + WHEN parsed_value.unit != '' + THEN + edgedb_VER._interpret_config_value_to_json( + parsed_value.val, + settings.vartype, + 1, + parsed_value.unit + ) + + ELSE + edgedb_VER._interpret_config_value_to_json( + "setting_value", + settings.vartype, + settings.multiplier, + settings.unit + ) + END) AS value + ) AS backend_json_value """ def __init__(self) -> None: @@ -3748,11 +3800,14 @@ class SysConfigFullFunction(trampoline.VersionedFunction): pg_config AS ( SELECT spec.name, - edgedb_VER._interpret_config_value_to_json( - settings.setting, - settings.vartype, - settings.multiplier, - settings.unit + edgedb_VER._postgres_json_config_value_to_fe_config_value( + settings.name, + edgedb_VER._interpret_config_value_to_json( + settings.setting, + settings.vartype, + settings.multiplier, + settings.unit + ) ) AS value, source AS source, TRUE AS is_backend @@ -4123,12 +4178,9 @@ def __init__(self, config_spec: edbconfig.Spec) -> None: valql = '"value"->>0' if ( isinstance(setting.type, type) - and issubclass(setting.type, statypes.Duration) + and issubclass(setting.type, statypes.ScalarType) ): - valql = f""" - edgedb_VER._interval_to_ms(({valql})::interval)::text \ - || 'ms' - """ + valql = setting.type.to_backend_expr(valql) variants_list.append(f''' WHEN "name" = {ql(setting_name)} @@ -5244,6 +5296,8 @@ def get_bootstrap_commands( dbops.CreateFunction(TypeIDToConfigType()), dbops.CreateFunction(ConvertPostgresConfigUnitsFunction()), dbops.CreateFunction(InterpretConfigValueToJsonFunction()), + dbops.CreateFunction( + PostgresJsonConfigValueToFrontendConfigValueFunction(config_spec)), dbops.CreateFunction(PostgresConfigValueToJsonFunction()), dbops.CreateFunction(SysConfigFullFunction()), dbops.CreateFunction(SysConfigUncachedFunction()), diff --git a/edb/pgsql/patches.py b/edb/pgsql/patches.py index ca29cf381d2..d5bd67b1bf9 100644 --- a/edb/pgsql/patches.py +++ b/edb/pgsql/patches.py @@ -67,4 +67,18 @@ def get_version_key(num_patches: int): '''), ('sql-introspection', ''), ('metaschema-sql', 'SysConfigFullFunction'), + # 6.0b3 or 6.0rc1 + ('edgeql+schema+config+testmode', ''' +CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum + EXTENDING enum; +ALTER TYPE cfg::AbstractConfig { + CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum { + CREATE ANNOTATION cfg::internal := 'true'; + CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"'; + SET default := cfg::TestEnabledDisabledEnum.Enabled; + }; +}; +'''), + ('metaschema-sql', 'PostgresConfigValueToJsonFunction'), + ('metaschema-sql', 'SysConfigFullFunction'), ] diff --git a/edb/schema/utils.py b/edb/schema/utils.py index 88ddac369dc..bc98b25d250 100644 --- a/edb/schema/utils.py +++ b/edb/schema/utils.py @@ -1379,6 +1379,15 @@ def const_ast_from_python(val: Any, with_secrets: bool=False) -> qlast.Expr: ), expr=qlast.Constant.string(value=val.to_iso8601()), ) + elif isinstance(val, statypes.EnumScalarType): + qltype = val.get_edgeql_type() + return qlast.TypeCast( + type=qlast.TypeName( + maintype=qlast.ObjectRef( + module=qltype.module, name=qltype.name), + ), + expr=qlast.Constant.string(value=val.to_str()), + ) elif isinstance(val, statypes.CompositeType): return qlast.InsertQuery( subject=name_to_ast_ref(sn.name_from_string(val._tspec.name)), diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index 21735b58004..e7a526de6f4 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -50,9 +50,10 @@ from edb import errors from edb import edgeql -from edb.ir import statypes from edb.ir import typeutils as irtyputils from edb.edgeql import ast as qlast +from edb.edgeql import codegen as qlcodegen +from edb.edgeql import qltypes from edb.common import debug from edb.common import devmode @@ -71,6 +72,7 @@ from edb.schema import schema as s_schema from edb.schema import std as s_std from edb.schema import types as s_types +from edb.schema import utils as s_utils from edb.server import args as edbargs from edb.server import config @@ -614,7 +616,6 @@ def compile_bootstrap_script( expected_cardinality_one: bool = False, output_format: edbcompiler.OutputFormat = edbcompiler.OutputFormat.JSON, ) -> Tuple[s_schema.Schema, str]: - ctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=schema, @@ -1956,13 +1957,13 @@ async def _configure( backend_params.has_configfile_access ) ): - if isinstance(setting.default, statypes.Duration): - val = f'"{setting.default.to_iso8601()}"' - else: - val = repr(setting.default) - script = f''' - CONFIGURE INSTANCE SET {setting.name} := {val}; - ''' + script = qlcodegen.generate_source( + qlast.ConfigSet( + name=qlast.ObjectRef(name=setting.name), + scope=qltypes.ConfigScope.INSTANCE, + expr=s_utils.const_ast_from_python(setting.default), + ) + ) schema, sql = compile_bootstrap_script(compiler, schema, script) await _execute(ctx.conn, sql) diff --git a/edb/server/compiler/sertypes.py b/edb/server/compiler/sertypes.py index dcc83fdeb49..ffc91280c91 100644 --- a/edb/server/compiler/sertypes.py +++ b/edb/server/compiler/sertypes.py @@ -2098,11 +2098,29 @@ class EnumDesc(SchemaTypeDesc): names: list[str] ancestors: Optional[list[TypeDesc]] - def encode(self, data: str) -> bytes: - return _encode_str(data) + @functools.cached_property + def _decoder(self) -> Callable[[bytes], Any]: + assert self.name is not None + pytype = statypes.maybe_get_python_type_for_scalar_type_name(self.name) + if pytype is not None and issubclass(pytype, statypes.ScalarType): + return pytype.decode + else: + return _decode_str + + @functools.cached_property + def _encoder(self) -> Callable[[Any], bytes]: + assert self.name is not None + pytype = statypes.maybe_get_python_type_for_scalar_type_name(self.name) + if pytype is not None and issubclass(pytype, statypes.ScalarType): + return pytype.encode + else: + return _encode_str + + def encode(self, data: Any) -> bytes: + return self._encoder(data) - def decode(self, data: bytes) -> str: - return _decode_str(data) + def decode(self, data: bytes) -> Any: + return self._decoder(data) @dataclasses.dataclass(frozen=True, kw_only=True) diff --git a/edb/server/config/ops.py b/edb/server/config/ops.py index fb66e90ea31..b9f7dab744b 100644 --- a/edb/server/config/ops.py +++ b/edb/server/config/ops.py @@ -98,6 +98,9 @@ def coerce_single_value(setting: spec.Setting, value: Any) -> Any: elif (isinstance(value, (str, int)) and _issubclass(setting.type, statypes.ConfigMemory)): return statypes.ConfigMemory(value) + elif (isinstance(value, str) and + _issubclass(setting.type, statypes.EnumScalarType)): + return setting.type(value) else: raise errors.ConfigurationError( f'invalid value type for the {setting.name!r} setting') @@ -352,6 +355,8 @@ def spec_to_json(spec: spec.Spec): typeid = s_obj.get_known_type_id('std::duration') elif _issubclass(setting.type, statypes.ConfigMemory): typeid = s_obj.get_known_type_id('cfg::memory') + elif _issubclass(setting.type, statypes.EnumScalarType): + typeid = setting.type.get_edgeql_typeid() elif isinstance(setting.type, types.ConfigTypeSpec): typeid = types.CompositeConfigType.get_edgeql_typeid() else: @@ -420,6 +425,8 @@ def value_from_json_value(spec: spec.Spec, setting: spec.Setting, value: Any): return statypes.Duration.from_iso8601(value) elif _issubclass(setting.type, statypes.ConfigMemory): return statypes.ConfigMemory(value) + elif _issubclass(setting.type, statypes.EnumScalarType): + return setting.type(value) else: return value diff --git a/edb/server/config/spec.py b/edb/server/config/spec.py index 6288b08978b..e41b8ef8af8 100644 --- a/edb/server/config/spec.py +++ b/edb/server/config/spec.py @@ -40,8 +40,12 @@ from . import types -SETTING_TYPES = {str, int, bool, float, - statypes.Duration, statypes.ConfigMemory} +SETTING_TYPES = { + str, + int, + bool, + float, +} @dataclasses.dataclass(frozen=True, eq=True) @@ -64,13 +68,20 @@ class Setting: protected: bool = False def __post_init__(self) -> None: - if (self.type not in SETTING_TYPES and - not isinstance(self.type, types.ConfigTypeSpec)): + if ( + self.type not in SETTING_TYPES + and not isinstance(self.type, types.ConfigTypeSpec) + and not ( + isinstance(self.type, type) + and issubclass(self.type, statypes.ScalarType) + ) + ): raise ValueError( f'invalid config setting {self.name!r}: ' f'type is expected to be either one of ' f'{{str, int, bool, float}} ' - f'or an edb.server.config.types.ConfigType subclass') + f'or an edb.server.config.types.ConfigType ', + f'or edb.ir.statypes.ScalarType subclass') if self.set_of: if not isinstance(self.default, frozenset): @@ -269,8 +280,10 @@ def _load_spec_from_type( ptype, schema, spec_class=types.ConfigTypeSpec, ) - else: + elif isinstance(ptype, s_scalars.ScalarType): pytype = staeval.scalar_type_to_python_type(ptype, schema) + else: + raise RuntimeError(f"unsupported config value type: {ptype}") attributes = {} for a, v in p.get_annotations(schema).items(schema): diff --git a/edb/server/config/types.py b/edb/server/config/types.py index 975cbbd8dbb..168bcfdd61f 100644 --- a/edb/server/config/types.py +++ b/edb/server/config/types.py @@ -203,6 +203,11 @@ def from_pyvalue( and isinstance(value, str | int) ): value = statypes.ConfigMemory(value) + elif ( + _issubclass(f_type, statypes.EnumScalarType) + and isinstance(value, str) + ): + value = f_type(value) elif not isinstance(f_type, type) or not isinstance(value, f_type): raise cls._err( diff --git a/tests/test_server_config.py b/tests/test_server_config.py index bb897039171..f59660f8fb9 100644 --- a/tests/test_server_config.py +++ b/tests/test_server_config.py @@ -40,6 +40,7 @@ from edb.protocol import messages from edb.testbase import server as tb + from edb.schema import objects as s_obj from edb.ir import statypes @@ -2140,6 +2141,96 @@ async def test_server_config_query_timeout(self): new_aborted += float(line.split(' ')[1]) self.assertEqual(orig_aborted, new_aborted) + @unittest.skipIf( + "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ, + "cannot use CONFIGURE INSTANCE in multi-tenant mode", + ) + async def test_server_config_custom_enum(self): + async def assert_conf(con, name, expected_val): + val = await con.query_single(f''' + select assert_single(cfg::Config.{name}) + ''') + + self.assertEqual( + str(val), + expected_val + ) + + async with tb.start_edgedb_server( + security=args.ServerSecurityMode.InsecureDevMode, + ) as sd: + c1 = await sd.connect() + c2 = await sd.connect() + + await c2.query('create database test') + t1 = await sd.connect(database='test') + + # check that the default was set correctly + await assert_conf( + c1, '__check_function_bodies', 'Enabled') + + #### + + await c1.query(''' + configure instance set + __check_function_bodies + := cfg::TestEnabledDisabledEnum.Disabled; + ''') + + for c in {c1, c2, t1}: + await assert_conf( + c, '__check_function_bodies', 'Disabled') + + #### + + await t1.query(''' + configure current database set + __check_function_bodies := + cfg::TestEnabledDisabledEnum.Enabled; + ''') + + for c in {c1, c2}: + await assert_conf( + c, '__check_function_bodies', 'Disabled') + + await assert_conf( + t1, '__check_function_bodies', 'Enabled') + + #### + + await c2.query(''' + configure session set + __check_function_bodies := + cfg::TestEnabledDisabledEnum.Disabled; + ''') + await assert_conf( + c1, '__check_function_bodies', 'Disabled') + await assert_conf( + t1, '__check_function_bodies', 'Enabled') + await assert_conf( + c2, '__check_function_bodies', 'Disabled') + + #### + await c1.query(''' + configure instance reset + __check_function_bodies; + ''') + await t1.query(''' + configure session set + __check_function_bodies := + cfg::TestEnabledDisabledEnum.Disabled; + ''') + await assert_conf( + c1, '__check_function_bodies', 'Enabled') + await assert_conf( + t1, '__check_function_bodies', 'Disabled') + await assert_conf( + c2, '__check_function_bodies', 'Disabled') + + await c1.aclose() + await c2.aclose() + await t1.aclose() + class TestStaticServerConfig(tb.TestCase): @test.xerror("static config args not supported")