diff --git a/edb/ir/staeval.py b/edb/ir/staeval.py index 81fcbd32ce0..7ce152d8af8 100644 --- a/edb/ir/staeval.py +++ b/edb/ir/staeval.py @@ -43,7 +43,6 @@ from edb.common import typeutils from edb.common import parsing -from edb.common import uuidgen from edb.common import value_dispatch from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler @@ -495,6 +494,9 @@ def bool_const_to_python( def cast_const_to_python(ir: irast.TypeCast, schema: s_schema.Schema) -> Any: schema, stype = irtyputils.ir_typeref_to_type(schema, ir.to_type) + if not isinstance(stype, s_scalars.ScalarType): + raise UnsupportedExpressionError( + "non-scalar casts are not supported in Python eval") pytype = scalar_type_to_python_type(stype, schema) sval = evaluate_to_python_val(ir.expr, schema=schema) return python_cast(sval, pytype) @@ -544,31 +546,23 @@ def schema_type_to_python_type( f'{stype.get_displayname(schema)} is not representable in Python') -typemap = { - 'std::str': str, - 'std::anyint': int, - 'std::anyfloat': float, - 'std::decimal': decimal.Decimal, - 'std::bigint': decimal.Decimal, - 'std::bool': bool, - 'std::json': str, - 'std::uuid': uuidgen.UUID, - 'std::duration': statypes.Duration, - 'cfg::memory': statypes.ConfigMemory, -} - - def scalar_type_to_python_type( - stype: s_types.Type, + stype: s_scalars.ScalarType, schema: s_schema.Schema, ) -> type: - for basetype_name, pytype in typemap.items(): - basetype = schema.get( - basetype_name, type=s_scalars.ScalarType, default=None) - if basetype and stype.issubclass(schema, basetype): - return pytype - - if stype.is_enum(schema): + typname = stype.get_name(schema) + pytype = statypes.maybe_get_python_type_for_scalar_type_name(str(typname)) + if pytype is None: + for ancestor in stype.get_ancestors(schema).objects(schema): + typname = ancestor.get_name(schema) + pytype = statypes.maybe_get_python_type_for_scalar_type_name( + str(typname)) + if pytype is not None: + break + + if pytype is not None: + return pytype + elif stype.is_enum(schema): return str raise UnsupportedExpressionError( @@ -618,8 +612,10 @@ def object_type_to_spec( ptype, schema, spec_class=spec_class, parent=parent, _memo=_memo) _memo[ptype] = pytype - else: + elif isinstance(ptype, s_scalars.ScalarType): pytype = scalar_type_to_python_type(ptype, schema) + else: + raise UnsupportedExpressionError(f"unsupported cast type: {ptype}") ptr_card: qltypes.SchemaCardinality = p.get_cardinality(schema) if ptr_card.is_known(): diff --git a/edb/ir/statypes.py b/edb/ir/statypes.py index 237d21b6cb9..8e86853f3b4 100644 --- a/edb/ir/statypes.py +++ b/edb/ir/statypes.py @@ -18,17 +18,34 @@ from __future__ import annotations -from typing import Any, Optional +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Mapping, + Optional, + Self, + TypeVar, +) import dataclasses +import datetime +import decimal +import enum import functools import re import struct -import datetime +import uuid import immutables from edb import errors +from edb.common import parametric +from edb.common import uuidgen + +from edb.schema import name as s_name +from edb.schema import objects as s_obj MISSING: Any = object() @@ -100,6 +117,14 @@ def __init__(self, val: str, /) -> None: def to_backend_str(self) -> str: raise NotImplementedError + @classmethod + def to_backend_expr(cls, expr: str) -> str: + raise NotImplementedError("{cls}.to_backend_expr()") + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + raise NotImplementedError("{cls}.to_frontend_expr()") + def to_json(self) -> str: raise NotImplementedError @@ -375,6 +400,14 @@ def to_timedelta(self) -> datetime.timedelta: def to_backend_str(self) -> str: return f'{self.to_microseconds()}us' + @classmethod + def to_backend_expr(cls, expr: str) -> str: + return f"edgedb_VER._interval_to_ms(({expr})::interval)::text || 'ms'" + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + return None + def to_json(self) -> str: return self.to_iso8601() @@ -494,6 +527,14 @@ def to_backend_str(self) -> str: return f'{self._value}B' + @classmethod + def to_backend_expr(cls, expr: str) -> str: + return f"edgedb_VER.cfg_memory_to_str({expr})" + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + return f"(edgedb_VER.str_to_cfg_memory({expr})::text || 'B')" + def to_json(self) -> str: return self.to_str() @@ -503,8 +544,195 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(self._value) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, ConfigMemory): return self._value == other._value else: return False + + +typemap = { + 'std::str': str, + 'std::anyint': int, + 'std::anyfloat': float, + 'std::decimal': decimal.Decimal, + 'std::bigint': decimal.Decimal, + 'std::bool': bool, + 'std::json': str, + 'std::uuid': uuidgen.UUID, + 'std::duration': Duration, + 'cfg::memory': ConfigMemory, +} + + +def maybe_get_python_type_for_scalar_type_name(name: str) -> Optional[type]: + return typemap.get(name) + + +E = TypeVar("E", bound=enum.StrEnum) + + +class EnumScalarType( + ScalarType, + parametric.SingleParametricType[E], + Generic[E], +): + """Configuration value represented by a custom string enum type that + supports arbitrary value mapping to backend (Postgres) configuration + values, e.g mapping "Enabled"/"Disabled" enum to a bool value, etc. + + We use SingleParametricType to obtain runtime access to the Generic + type arg to avoid having to copy-paste the constructors. + """ + + _val: E + _eql_type: ClassVar[Optional[s_name.QualName]] + + def __init_subclass__( + cls, + *, + edgeql_type: Optional[str] = None, + **kwargs: Any, + ) -> None: + global typemap + super().__init_subclass__(**kwargs) + if edgeql_type is not None: + if edgeql_type in typemap: + raise TypeError( + f"{edgeql_type} is already a registered EnumScalarType") + typemap[edgeql_type] = cls + cls._eql_type = s_name.QualName.from_string(edgeql_type) + + def __init__( + self, + val: E | str, + ) -> None: + if isinstance(val, self.type): + self._val = val + elif isinstance(val, str): + try: + self._val = self.type(val) + except ValueError: + raise errors.InvalidValueError( + f'unexpected backend value for ' + f'{self.__class__.__name__}: {val!r}' + ) from None + + def to_str(self) -> str: + return str(self._val) + + def to_json(self) -> str: + return self._val + + def encode(self) -> bytes: + return self._val.encode("utf8") + + @classmethod + def get_translation_map(cls) -> Mapping[E, str]: + raise NotImplementedError + + @classmethod + def decode(cls, data: bytes) -> Self: + return cls(val=cls.type(data.decode("utf8"))) + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self._val) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, type(self)): + return self._val == other._val + else: + return NotImplemented + + def __reduce__(self) -> tuple[ + Callable[..., EnumScalarType[Any]], + tuple[ + Optional[tuple[type, ...] | type], + E, + ], + ]: + assert type(self).is_fully_resolved(), \ + f'{type(self)} parameters are not resolved' + + cls: type[EnumScalarType[E]] = self.__class__ + types: Optional[tuple[type, ...]] = self.orig_args + if types is None or not cls.is_anon_parametrized(): + typeargs = None + else: + typeargs = types[0] if len(types) == 1 else types + return (cls.__restore__, (typeargs, self._val)) + + @classmethod + def __restore__( + cls, + typeargs: Optional[tuple[type, ...] | type], + val: E, + ) -> Self: + if typeargs is None or cls.is_anon_parametrized(): + obj = cls(val) + else: + obj = cls[typeargs](val) # type: ignore[index] + + return obj + + @classmethod + def get_edgeql_typeid(cls) -> uuid.UUID: + return s_obj.get_known_type_id('std::str') + + @classmethod + def get_edgeql_type(cls) -> s_name.QualName: + """Return fully-qualified name of the scalar type for this setting.""" + assert cls._eql_type is not None + return cls._eql_type + + def to_backend_str(self) -> str: + """Convert static frontend config value to backend config value.""" + return self.get_translation_map()[self._val] + + @classmethod + def to_backend_expr(cls, expr: str) -> str: + """Convert dynamic backend config value to frontend config value.""" + cases_list = [] + for fe_val, be_val in cls.get_translation_map().items(): + cases_list.append(f"WHEN lower('{fe_val}') THEN '{be_val}'") + cases = "\n".join(cases_list) + errmsg = f"unexpected frontend value for {cls.__name__}: %s" + err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))" + return ( + f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND " + f"FROM lower(({expr})) AS f(v))" + ) + + @classmethod + def to_frontend_expr(cls, expr: str) -> Optional[str]: + """Convert dynamic frontend config value to backend config value.""" + cases_list = [] + for fe_val, be_val in cls.get_translation_map().items(): + cases_list.append(f"WHEN lower('{be_val}') THEN '{fe_val}'") + cases = "\n".join(cases_list) + errmsg = f"unexpected backend value for {cls.__name__}: %s" + err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))" + return ( + f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND " + f"FROM lower(({expr})) AS f(v))" + ) + + +class EnabledDisabledEnum(enum.StrEnum): + Enabled = "Enabled" + Disabled = "Disabled" + + +class EnabledDisabledType( + EnumScalarType[EnabledDisabledEnum], + edgeql_type="cfg::TestEnabledDisabledEnum", +): + @classmethod + def get_translation_map(cls) -> Mapping[EnabledDisabledEnum, str]: + return { + EnabledDisabledEnum.Enabled: "true", + EnabledDisabledEnum.Disabled: "false", + } diff --git a/edb/lib/_testmode.edgeql b/edb/lib/_testmode.edgeql index 6fb29b1f620..761a5dc53bb 100644 --- a/edb/lib/_testmode.edgeql +++ b/edb/lib/_testmode.edgeql @@ -56,7 +56,9 @@ CREATE TYPE cfg::TestInstanceConfigStatTypes EXTENDING cfg::TestInstanceConfig { }; -CREATE SCALAR TYPE cfg::TestEnum extending enum; +CREATE SCALAR TYPE cfg::TestEnum EXTENDING enum; +CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum + EXTENDING enum; ALTER TYPE cfg::AbstractConfig { @@ -141,6 +143,12 @@ ALTER TYPE cfg::AbstractConfig { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::backend_setting := '"max_connections"'; }; + + CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum { + CREATE ANNOTATION cfg::internal := 'true'; + CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"'; + SET default := cfg::TestEnabledDisabledEnum.Enabled; + }; }; diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 8fcd01f5ed7..0a696b43072 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -1747,11 +1747,11 @@ class AssertJSONTypeFunction(trampoline.VersionedFunction): 'wrong_object_type', msg => coalesce( msg, - ( - 'expected JSON ' - || array_to_string(typenames, ' or ') - || '; got JSON ' - || coalesce(jsonb_typeof(val), 'UNKNOWN') + format( + 'expected JSON %s; got JSON %s: %s', + array_to_string(typenames, ' or '), + coalesce(jsonb_typeof(val), 'UNKNOWN'), + val::text ) ), detail => detail @@ -3392,7 +3392,7 @@ class InterpretConfigValueToJsonFunction(trampoline.VersionedFunction): - for memory size: we always convert to kilobytes; - already unitless numbers are left as is. - See https://www.postgresql.org/docs/12/config-setting.html + See https://www.postgresql.org/docs/current/config-setting.html for information about the units Postgres config system has. """ @@ -3444,6 +3444,53 @@ def __init__(self) -> None: ) +class PostgresJsonConfigValueToFrontendConfigValueFunction( + trampoline.VersionedFunction, +): + """Convert a Postgres config value to frontend config value. + + Most values are retained as-is, but some need translation, which + is implemented as a to_frontend_expr() on the corresponding + setting ScalarType. + """ + + def __init__(self, config_spec: edbconfig.Spec) -> None: + variants_list = [] + for setting in config_spec.values(): + if ( + setting.backend_setting + and isinstance(setting.type, type) + and issubclass(setting.type, statypes.ScalarType) + ): + conv_expr = setting.type.to_frontend_expr('"value"->>0') + if conv_expr is not None: + variants_list.append(f""" + WHEN {ql(setting.backend_setting)} + THEN to_jsonb({conv_expr}) + """) + + variants = "\n".join(variants_list) + text = f""" + SELECT ( + CASE "setting_name" + {variants} + ELSE "value" + END + ) + """ + + super().__init__( + name=('edgedb', '_postgres_json_config_value_to_fe_config_value'), + args=[ + ('setting_name', ('text',)), + ('value', ('jsonb',)) + ], + returns=('jsonb',), + volatility='immutable', + text=text, + ) + + class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): """Convert a Postgres setting to JSON value. @@ -3471,26 +3518,10 @@ class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): text = r""" SELECT - (CASE - - WHEN parsed_value.unit != '' - THEN - edgedb_VER._interpret_config_value_to_json( - parsed_value.val, - settings.vartype, - 1, - parsed_value.unit - ) - - ELSE - edgedb_VER._interpret_config_value_to_json( - "setting_value", - settings.vartype, - settings.multiplier, - settings.unit - ) - - END) + edgedb_VER._postgres_json_config_value_to_fe_config_value( + "setting_name", + backend_json_value.value + ) FROM LATERAL ( SELECT regexp_match( @@ -3521,8 +3552,29 @@ class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): as vartype, COALESCE(settings_in.multiplier, '1') as multiplier, COALESCE(settings_in.unit, '') as unit - ) as settings + ) AS settings + CROSS JOIN LATERAL + (SELECT + (CASE + WHEN parsed_value.unit != '' + THEN + edgedb_VER._interpret_config_value_to_json( + parsed_value.val, + settings.vartype, + 1, + parsed_value.unit + ) + + ELSE + edgedb_VER._interpret_config_value_to_json( + "setting_value", + settings.vartype, + settings.multiplier, + settings.unit + ) + END) AS value + ) AS backend_json_value """ def __init__(self) -> None: @@ -3748,11 +3800,14 @@ class SysConfigFullFunction(trampoline.VersionedFunction): pg_config AS ( SELECT spec.name, - edgedb_VER._interpret_config_value_to_json( - settings.setting, - settings.vartype, - settings.multiplier, - settings.unit + edgedb_VER._postgres_json_config_value_to_fe_config_value( + settings.name, + edgedb_VER._interpret_config_value_to_json( + settings.setting, + settings.vartype, + settings.multiplier, + settings.unit + ) ) AS value, source AS source, TRUE AS is_backend @@ -4123,12 +4178,9 @@ def __init__(self, config_spec: edbconfig.Spec) -> None: valql = '"value"->>0' if ( isinstance(setting.type, type) - and issubclass(setting.type, statypes.Duration) + and issubclass(setting.type, statypes.ScalarType) ): - valql = f""" - edgedb_VER._interval_to_ms(({valql})::interval)::text \ - || 'ms' - """ + valql = setting.type.to_backend_expr(valql) variants_list.append(f''' WHEN "name" = {ql(setting_name)} @@ -5244,6 +5296,8 @@ def get_bootstrap_commands( dbops.CreateFunction(TypeIDToConfigType()), dbops.CreateFunction(ConvertPostgresConfigUnitsFunction()), dbops.CreateFunction(InterpretConfigValueToJsonFunction()), + dbops.CreateFunction( + PostgresJsonConfigValueToFrontendConfigValueFunction(config_spec)), dbops.CreateFunction(PostgresConfigValueToJsonFunction()), dbops.CreateFunction(SysConfigFullFunction()), dbops.CreateFunction(SysConfigUncachedFunction()), diff --git a/edb/pgsql/patches.py b/edb/pgsql/patches.py index ca29cf381d2..d5bd67b1bf9 100644 --- a/edb/pgsql/patches.py +++ b/edb/pgsql/patches.py @@ -67,4 +67,18 @@ def get_version_key(num_patches: int): '''), ('sql-introspection', ''), ('metaschema-sql', 'SysConfigFullFunction'), + # 6.0b3 or 6.0rc1 + ('edgeql+schema+config+testmode', ''' +CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum + EXTENDING enum; +ALTER TYPE cfg::AbstractConfig { + CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum { + CREATE ANNOTATION cfg::internal := 'true'; + CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"'; + SET default := cfg::TestEnabledDisabledEnum.Enabled; + }; +}; +'''), + ('metaschema-sql', 'PostgresConfigValueToJsonFunction'), + ('metaschema-sql', 'SysConfigFullFunction'), ] diff --git a/edb/schema/utils.py b/edb/schema/utils.py index 88ddac369dc..bc98b25d250 100644 --- a/edb/schema/utils.py +++ b/edb/schema/utils.py @@ -1379,6 +1379,15 @@ def const_ast_from_python(val: Any, with_secrets: bool=False) -> qlast.Expr: ), expr=qlast.Constant.string(value=val.to_iso8601()), ) + elif isinstance(val, statypes.EnumScalarType): + qltype = val.get_edgeql_type() + return qlast.TypeCast( + type=qlast.TypeName( + maintype=qlast.ObjectRef( + module=qltype.module, name=qltype.name), + ), + expr=qlast.Constant.string(value=val.to_str()), + ) elif isinstance(val, statypes.CompositeType): return qlast.InsertQuery( subject=name_to_ast_ref(sn.name_from_string(val._tspec.name)), diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index 21735b58004..e7a526de6f4 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -50,9 +50,10 @@ from edb import errors from edb import edgeql -from edb.ir import statypes from edb.ir import typeutils as irtyputils from edb.edgeql import ast as qlast +from edb.edgeql import codegen as qlcodegen +from edb.edgeql import qltypes from edb.common import debug from edb.common import devmode @@ -71,6 +72,7 @@ from edb.schema import schema as s_schema from edb.schema import std as s_std from edb.schema import types as s_types +from edb.schema import utils as s_utils from edb.server import args as edbargs from edb.server import config @@ -614,7 +616,6 @@ def compile_bootstrap_script( expected_cardinality_one: bool = False, output_format: edbcompiler.OutputFormat = edbcompiler.OutputFormat.JSON, ) -> Tuple[s_schema.Schema, str]: - ctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=schema, @@ -1956,13 +1957,13 @@ async def _configure( backend_params.has_configfile_access ) ): - if isinstance(setting.default, statypes.Duration): - val = f'"{setting.default.to_iso8601()}"' - else: - val = repr(setting.default) - script = f''' - CONFIGURE INSTANCE SET {setting.name} := {val}; - ''' + script = qlcodegen.generate_source( + qlast.ConfigSet( + name=qlast.ObjectRef(name=setting.name), + scope=qltypes.ConfigScope.INSTANCE, + expr=s_utils.const_ast_from_python(setting.default), + ) + ) schema, sql = compile_bootstrap_script(compiler, schema, script) await _execute(ctx.conn, sql) diff --git a/edb/server/compiler/sertypes.py b/edb/server/compiler/sertypes.py index dcc83fdeb49..ffc91280c91 100644 --- a/edb/server/compiler/sertypes.py +++ b/edb/server/compiler/sertypes.py @@ -2098,11 +2098,29 @@ class EnumDesc(SchemaTypeDesc): names: list[str] ancestors: Optional[list[TypeDesc]] - def encode(self, data: str) -> bytes: - return _encode_str(data) + @functools.cached_property + def _decoder(self) -> Callable[[bytes], Any]: + assert self.name is not None + pytype = statypes.maybe_get_python_type_for_scalar_type_name(self.name) + if pytype is not None and issubclass(pytype, statypes.ScalarType): + return pytype.decode + else: + return _decode_str + + @functools.cached_property + def _encoder(self) -> Callable[[Any], bytes]: + assert self.name is not None + pytype = statypes.maybe_get_python_type_for_scalar_type_name(self.name) + if pytype is not None and issubclass(pytype, statypes.ScalarType): + return pytype.encode + else: + return _encode_str + + def encode(self, data: Any) -> bytes: + return self._encoder(data) - def decode(self, data: bytes) -> str: - return _decode_str(data) + def decode(self, data: bytes) -> Any: + return self._decoder(data) @dataclasses.dataclass(frozen=True, kw_only=True) diff --git a/edb/server/config/ops.py b/edb/server/config/ops.py index fb66e90ea31..b9f7dab744b 100644 --- a/edb/server/config/ops.py +++ b/edb/server/config/ops.py @@ -98,6 +98,9 @@ def coerce_single_value(setting: spec.Setting, value: Any) -> Any: elif (isinstance(value, (str, int)) and _issubclass(setting.type, statypes.ConfigMemory)): return statypes.ConfigMemory(value) + elif (isinstance(value, str) and + _issubclass(setting.type, statypes.EnumScalarType)): + return setting.type(value) else: raise errors.ConfigurationError( f'invalid value type for the {setting.name!r} setting') @@ -352,6 +355,8 @@ def spec_to_json(spec: spec.Spec): typeid = s_obj.get_known_type_id('std::duration') elif _issubclass(setting.type, statypes.ConfigMemory): typeid = s_obj.get_known_type_id('cfg::memory') + elif _issubclass(setting.type, statypes.EnumScalarType): + typeid = setting.type.get_edgeql_typeid() elif isinstance(setting.type, types.ConfigTypeSpec): typeid = types.CompositeConfigType.get_edgeql_typeid() else: @@ -420,6 +425,8 @@ def value_from_json_value(spec: spec.Spec, setting: spec.Setting, value: Any): return statypes.Duration.from_iso8601(value) elif _issubclass(setting.type, statypes.ConfigMemory): return statypes.ConfigMemory(value) + elif _issubclass(setting.type, statypes.EnumScalarType): + return setting.type(value) else: return value diff --git a/edb/server/config/spec.py b/edb/server/config/spec.py index 6288b08978b..e41b8ef8af8 100644 --- a/edb/server/config/spec.py +++ b/edb/server/config/spec.py @@ -40,8 +40,12 @@ from . import types -SETTING_TYPES = {str, int, bool, float, - statypes.Duration, statypes.ConfigMemory} +SETTING_TYPES = { + str, + int, + bool, + float, +} @dataclasses.dataclass(frozen=True, eq=True) @@ -64,13 +68,20 @@ class Setting: protected: bool = False def __post_init__(self) -> None: - if (self.type not in SETTING_TYPES and - not isinstance(self.type, types.ConfigTypeSpec)): + if ( + self.type not in SETTING_TYPES + and not isinstance(self.type, types.ConfigTypeSpec) + and not ( + isinstance(self.type, type) + and issubclass(self.type, statypes.ScalarType) + ) + ): raise ValueError( f'invalid config setting {self.name!r}: ' f'type is expected to be either one of ' f'{{str, int, bool, float}} ' - f'or an edb.server.config.types.ConfigType subclass') + f'or an edb.server.config.types.ConfigType ', + f'or edb.ir.statypes.ScalarType subclass') if self.set_of: if not isinstance(self.default, frozenset): @@ -269,8 +280,10 @@ def _load_spec_from_type( ptype, schema, spec_class=types.ConfigTypeSpec, ) - else: + elif isinstance(ptype, s_scalars.ScalarType): pytype = staeval.scalar_type_to_python_type(ptype, schema) + else: + raise RuntimeError(f"unsupported config value type: {ptype}") attributes = {} for a, v in p.get_annotations(schema).items(schema): diff --git a/edb/server/config/types.py b/edb/server/config/types.py index 975cbbd8dbb..168bcfdd61f 100644 --- a/edb/server/config/types.py +++ b/edb/server/config/types.py @@ -203,6 +203,11 @@ def from_pyvalue( and isinstance(value, str | int) ): value = statypes.ConfigMemory(value) + elif ( + _issubclass(f_type, statypes.EnumScalarType) + and isinstance(value, str) + ): + value = f_type(value) elif not isinstance(f_type, type) or not isinstance(value, f_type): raise cls._err( diff --git a/tests/test_server_config.py b/tests/test_server_config.py index bb897039171..f59660f8fb9 100644 --- a/tests/test_server_config.py +++ b/tests/test_server_config.py @@ -40,6 +40,7 @@ from edb.protocol import messages from edb.testbase import server as tb + from edb.schema import objects as s_obj from edb.ir import statypes @@ -2140,6 +2141,96 @@ async def test_server_config_query_timeout(self): new_aborted += float(line.split(' ')[1]) self.assertEqual(orig_aborted, new_aborted) + @unittest.skipIf( + "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ, + "cannot use CONFIGURE INSTANCE in multi-tenant mode", + ) + async def test_server_config_custom_enum(self): + async def assert_conf(con, name, expected_val): + val = await con.query_single(f''' + select assert_single(cfg::Config.{name}) + ''') + + self.assertEqual( + str(val), + expected_val + ) + + async with tb.start_edgedb_server( + security=args.ServerSecurityMode.InsecureDevMode, + ) as sd: + c1 = await sd.connect() + c2 = await sd.connect() + + await c2.query('create database test') + t1 = await sd.connect(database='test') + + # check that the default was set correctly + await assert_conf( + c1, '__check_function_bodies', 'Enabled') + + #### + + await c1.query(''' + configure instance set + __check_function_bodies + := cfg::TestEnabledDisabledEnum.Disabled; + ''') + + for c in {c1, c2, t1}: + await assert_conf( + c, '__check_function_bodies', 'Disabled') + + #### + + await t1.query(''' + configure current database set + __check_function_bodies := + cfg::TestEnabledDisabledEnum.Enabled; + ''') + + for c in {c1, c2}: + await assert_conf( + c, '__check_function_bodies', 'Disabled') + + await assert_conf( + t1, '__check_function_bodies', 'Enabled') + + #### + + await c2.query(''' + configure session set + __check_function_bodies := + cfg::TestEnabledDisabledEnum.Disabled; + ''') + await assert_conf( + c1, '__check_function_bodies', 'Disabled') + await assert_conf( + t1, '__check_function_bodies', 'Enabled') + await assert_conf( + c2, '__check_function_bodies', 'Disabled') + + #### + await c1.query(''' + configure instance reset + __check_function_bodies; + ''') + await t1.query(''' + configure session set + __check_function_bodies := + cfg::TestEnabledDisabledEnum.Disabled; + ''') + await assert_conf( + c1, '__check_function_bodies', 'Enabled') + await assert_conf( + t1, '__check_function_bodies', 'Disabled') + await assert_conf( + c2, '__check_function_bodies', 'Disabled') + + await c1.aclose() + await c2.aclose() + await t1.aclose() + class TestStaticServerConfig(tb.TestCase): @test.xerror("static config args not supported")