From eaf9a289490a26b881cf2c53e8c73511ebc782b3 Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Thu, 19 Sep 2024 14:58:22 -0400 Subject: [PATCH] Add custom serialization for scalars. Some scalar types need to be cast into some other SQL type during serialization. This can be defined as `custom_sql_serialization` on the scalar types in schema. --- edb/ir/ast.py | 2 ++ edb/ir/typeutils.py | 21 ++++++--------------- edb/pgsql/compiler/expr.py | 7 +++---- edb/pgsql/compiler/output.py | 4 ++-- edb/schema/scalars.py | 3 +++ 5 files changed, 16 insertions(+), 21 deletions(-) diff --git a/edb/ir/ast.py b/edb/ir/ast.py index dbaa1327828..e3a98eb03a9 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -185,6 +185,8 @@ class TypeRef(ImmutableBase): needs_custom_json_cast: bool = False # If this has a schema-configured backend type, what is it sql_type: typing.Optional[str] = None + # If this has a schema-configured custom sql serialization, what is it + custom_sql_serialization: typing.Optional[str] = None def __repr__(self) -> str: return f'' diff --git a/edb/ir/typeutils.py b/edb/ir/typeutils.py index d4dcfaba2c3..861f2a3cfbd 100644 --- a/edb/ir/typeutils.py +++ b/edb/ir/typeutils.py @@ -196,25 +196,12 @@ def is_persistent_tuple(typeref: irast.TypeRef) -> bool: return False -def get_custom_serialization( - typeref: irast.TypeRef, -) -> Optional[str]: - # FIXME: instead of hardcode we need to extract this from the - # schema/extension - if str(typeref.real_base_type.name_hint) in { - 'ext::postgis::box2d', - 'ext::postgis::box3d', - }: - return 'geometry' - - return None - - def needs_custom_serialization(typeref: irast.TypeRef) -> bool: # True if any component needs custom serialization return contains_predicate( typeref, - lambda typeref: get_custom_serialization(typeref) is not None + lambda typeref: + typeref.real_base_type.custom_sql_serialization is not None ) @@ -427,6 +414,7 @@ def _typeref( sql_type = None needs_custom_json_cast = False + custom_sql_serialization = None if isinstance(t, s_scalars.ScalarType): sql_type = t.resolve_sql_type(schema) if material_typeref is None: @@ -437,6 +425,8 @@ def _typeref( if jcast: needs_custom_json_cast = bool(jcast.get_code(schema)) + custom_sql_serialization = t.get_custom_sql_serialization(schema) + result = irast.TypeRef( id=t.id, name_hint=name_hint, @@ -457,6 +447,7 @@ def _typeref( is_opaque_union=t.get_is_opaque_union(schema), needs_custom_json_cast=needs_custom_json_cast, sql_type=sql_type, + custom_sql_serialization=custom_sql_serialization, ) elif isinstance(t, s_types.Tuple) and t.is_named(schema): schema, material_type = t.material_type(schema) diff --git a/edb/pgsql/compiler/expr.py b/edb/pgsql/compiler/expr.py index 90df793ab45..39d12e474e5 100644 --- a/edb/pgsql/compiler/expr.py +++ b/edb/pgsql/compiler/expr.py @@ -140,16 +140,15 @@ def compile_Parameter( if irtyputils.needs_custom_serialization(expr.typeref): if irtyputils.is_array(expr.typeref): - el_sql_type = irtyputils.get_custom_serialization( - expr.typeref.subtypes[0]) + subt = expr.typeref.subtypes[0] + el_sql_type = subt.real_base_type.custom_sql_serialization # Arrays of text encoded types need to come in as the custom type result = pgast.TypeCast( arg=result, type_name=pgast.TypeName(name=(f'{el_sql_type}[]',)), ) else: - el_sql_type = irtyputils.get_custom_serialization( - expr.typeref) + el_sql_type = expr.typeref.real_base_type.custom_sql_serialization assert el_sql_type is not None result = pgast.TypeCast( arg=result, diff --git a/edb/pgsql/compiler/output.py b/edb/pgsql/compiler/output.py index 802051bf9a9..7a99d5eefed 100644 --- a/edb/pgsql/compiler/output.py +++ b/edb/pgsql/compiler/output.py @@ -640,7 +640,7 @@ def serialize_custom_array( ] ) else: - el_sql_type = irtyputils.get_custom_serialization(el_type) + el_sql_type = el_type.real_base_type.custom_sql_serialization return pgast.TypeCast( arg=expr, type_name=pgast.TypeName(name=(f'{el_sql_type}[]',)), @@ -715,7 +715,7 @@ def output_as_value( elif irtyputils.is_tuple(ser_typeref): return serialize_custom_tuple(expr, styperef=ser_typeref, env=env) else: - el_sql_type = irtyputils.get_custom_serialization(ser_typeref) + el_sql_type = ser_typeref.real_base_type.custom_sql_serialization assert el_sql_type is not None val = pgast.TypeCast( arg=val, diff --git a/edb/schema/scalars.py b/edb/schema/scalars.py index 1791ec32ea2..748209ba4e2 100644 --- a/edb/schema/scalars.py +++ b/edb/schema/scalars.py @@ -88,6 +88,9 @@ class ScalarType( compcoef=0.0, ) + custom_sql_serialization = so.SchemaField( + str, default=None, inheritable=False, compcoef=0.0) + @classmethod def get_schema_class_displayname(cls) -> str: return 'scalar type'