Skip to content

Commit

Permalink
Merge pull request #87 from JakobGM/thomasaarholt/serialize_camelcase
Browse files Browse the repository at this point in the history
Fix AliasGenerator
  • Loading branch information
thomasaarholt authored Oct 22, 2024
2 parents 61be2d2 + 02f6eba commit 10d6923
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
51 changes: 29 additions & 22 deletions src/patito/_pydantic/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def validate_annotation(
class DtypeResolver:
def __init__(self, annotation: Any | None):
self.annotation = annotation
self.schema = TypeAdapter(annotation).json_schema()
# mode='serialization' allows nested models with structs, see #86
self.schema = TypeAdapter(annotation).json_schema(mode="serialization")
self.defs = self.schema.get("$defs", {})

def valid_polars_dtypes(self) -> DataTypeGroup:
Expand Down Expand Up @@ -159,6 +160,7 @@ def _pydantic_subschema_to_valid_polars_types(
self.defs[props["$ref"].split("/")[-1]]
)
return DataTypeGroup([])

pyd_type = props.get("type")
if pyd_type == "array":
if "items" not in props:
Expand All @@ -169,28 +171,27 @@ def _pydantic_subschema_to_valid_polars_types(
return DataTypeGroup(
[pl.List(dtype) for dtype in item_dtypes], match_base_type=False
)

elif pyd_type == "object":
if "properties" not in props:
return DataTypeGroup([])
object_props = props["properties"]
struct_fields: list[pl.Field] = []
for name, sub_props in object_props.items():
dtype = self._default_polars_dtype_for_schema(sub_props)
assert dtype is not None
struct_fields.append(pl.Field(name, dtype))
return DataTypeGroup(
[
pl.Struct(
[
pl.Field(
name, self._default_polars_dtype_for_schema(sub_props)
)
for name, sub_props in object_props.items()
]
)
],
[pl.Struct(struct_fields)],
match_base_type=False,
) # for structs, return only the default dtype set to avoid combinatoric issues
return _pyd_type_to_valid_dtypes(
PydanticBaseType(pyd_type), props.get("format"), props.get("enum")
)

def _default_polars_dtype_for_schema(self, schema: dict) -> DataType | None:
def _default_polars_dtype_for_schema(
self, schema: dict[str, Any]
) -> DataType | None:
if "anyOf" in schema:
if len(schema["anyOf"]) == 2: # look for optionals first
schema = _without_optional(schema)
Expand All @@ -206,13 +207,14 @@ def _default_polars_dtype_for_schema(self, schema: dict) -> DataType | None:

def _pydantic_subschema_to_default_dtype(
self,
props: dict,
props: dict[str, Any],
) -> DataType | None:
if "column_info" in props: # user has specified in patito model
ci = ColumnInfo.model_validate_json(props["column_info"])
if ci.dtype is not None:
dtype = ci.dtype() if isinstance(ci.dtype, DataTypeClass) else ci.dtype
return dtype

if "type" not in props:
if "enum" in props:
raise TypeError("Mixed type enums not supported by patito.")
Expand All @@ -223,10 +225,12 @@ def _pydantic_subschema_to_default_dtype(
self.defs[props["$ref"].split("/")[-1]]
)
return None

pyd_type = props.get("type")
if pyd_type == "numeric":
pyd_type = "number"
if pyd_type == "array":

elif pyd_type == "array":
if "items" not in props:
raise NotImplementedError(
"Unexpected error processing pydantic schema. Please file an issue."
Expand All @@ -236,18 +240,21 @@ def _pydantic_subschema_to_default_dtype(
if inner_default_type is None:
return None
return pl.List(inner_default_type)
elif pyd_type == "object":

elif pyd_type == "object": # these are structs
if "properties" not in props:
raise NotImplementedError(
"dictionaries not currently supported by patito"
)
object_props = props["properties"]
return pl.Struct(
[
pl.Field(name, self._default_polars_dtype_for_schema(sub_props))
for name, sub_props in object_props.items()
]
)
object_props: dict[str, dict[str, str]] = props["properties"]
struct_fields: list[pl.Field] = []

for name, sub_props in object_props.items():
dtype = self._default_polars_dtype_for_schema(sub_props)
assert dtype is not None
struct_fields.append(pl.Field(name, dtype))
return pl.Struct(struct_fields)

return _pyd_type_to_default_dtype(
PydanticBaseType(pyd_type), props.get("format"), props.get("enum")
)
2 changes: 1 addition & 1 deletion src/patito/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def to_expr(va: str | AliasPath | AliasChoices) -> pl.Expr | None:
f"TODO figure out how this AliasPath behaves ({va})"
)
return (
pl.col(va.path[0]).list.get(va.path[1], null_on_oob=True)
pl.col(str(va.path[0])).list.get(va.path[1], null_on_oob=True)
if va.path[0] in self.collect_schema()
else None
)
Expand Down

0 comments on commit 10d6923

Please sign in to comment.