Skip to content

Commit

Permalink
Merge pull request #62 from ermakov-oleg/support-typeddict
Browse files Browse the repository at this point in the history
Added support for TypedDict
  • Loading branch information
ermakov-oleg authored May 31, 2023
2 parents a04b0ba + e1e8a66 commit 1cb19c4
Show file tree
Hide file tree
Showing 10 changed files with 298 additions and 39 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ There is support for generic types from the standard typing module:
* Enum
* List
* Dict
* TypedDict
* Mapping
* Sequence
* Tuple (fixed size)
Expand Down
45 changes: 42 additions & 3 deletions python/serpyco_rs/_describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from uuid import UUID

from attributes_doc import get_attributes_doc
from typing_extensions import assert_never, get_args
from typing_extensions import NotRequired, Required, assert_never, get_args, is_typeddict

from ._utils import to_camelcase
from .metadata import (
Expand Down Expand Up @@ -164,6 +164,15 @@ class EntityType(Type):
doc: Optional[str] = None


@dataclasses.dataclass
class TypedDictType(Type):
    """Describes a ``TypedDict`` class.

    Unlike ``EntityType`` no runtime class is stored: a TypedDict instance is a
    plain ``dict``, so only the field descriptions are needed to (de)serialize it.
    """

    # Generated, human-readable name (used e.g. in JSON Schema output).
    name: str
    # Described fields of the TypedDict.
    fields: Sequence[EntityField]
    # When True, fields whose serialized value is None are omitted on dump.
    omit_none: bool = False
    # Resolved (TypeVar, concrete type) pairs for generic TypedDicts.
    # NOTE: `()` is the idiomatic empty-tuple default (was `tuple()`).
    generics: Sequence[tuple[TypeVar, Any]] = ()
    # Docstring of the TypedDict class, if any.
    doc: Optional[str] = None


@dataclasses.dataclass
class OptionalType(Type):
inner: Type
Expand Down Expand Up @@ -242,6 +251,8 @@ def describe_type(t: Any, meta: Optional[_Meta] = None) -> Type:
metadata = _get_annotated_metadata(t)
if get_origin(t) == Annotated: # unwrap annotated
t = t.__origin__
if get_origin(t) in {Required, NotRequired}: # unwrap TypedDict special forms
t = t.__args__[0]
if hasattr(t, "__origin__"):
parameters = getattr(t.__origin__, "__parameters__", ())
args = t.__args__
Expand Down Expand Up @@ -358,7 +369,7 @@ def describe_type(t: Any, meta: Optional[_Meta] = None) -> Type:
if issubclass(t, (Enum, IntEnum)):
return EnumType(cls=t, custom_encoder=custom_encoder)

if dataclasses.is_dataclass(t) or _is_attrs(t):
if dataclasses.is_dataclass(t) or _is_attrs(t) or is_typeddict(t):
meta.add_to_state(meta_key, None)
entity_type = _describe_entity(
t=t,
Expand Down Expand Up @@ -422,7 +433,7 @@ def _describe_entity(
cls_none_format: NoneFormat,
custom_encoder: Optional[CustomEncoder[Any, Any]],
meta: _Meta,
) -> EntityType:
) -> Union[EntityType, TypedDictType]:
docs = get_attributes_doc(t)
try:
types = get_type_hints(t, include_extras=True)
Expand Down Expand Up @@ -454,6 +465,16 @@ def _describe_entity(
)
)

if is_typeddict(t):
return TypedDictType(
name=_generate_name(t, cls_filed_format, cls_none_format, generics),
fields=fields,
omit_none=cls_none_format is OmitNone,
generics=generics,
doc=t.__doc__,
custom_encoder=custom_encoder,
)

return EntityType(
cls=t,
name=_generate_name(t, cls_filed_format, cls_none_format, generics),
Expand All @@ -476,6 +497,16 @@ def _get_entity_fields(t: Any) -> Sequence[_Field[Any]]:
)
for f in dataclasses.fields(t)
]
if is_typeddict(t):
return [
_Field(
name=field_name,
type=field_type,
default=NOT_SET if _is_required_in_typeddict(t, field_name) else None,
default_factory=NOT_SET,
)
for field_name, field_type in t.__annotations__.items()
]
if _is_attrs(t):
assert attr
return [
Expand Down Expand Up @@ -602,3 +633,11 @@ def _is_literal_type(t: Any) -> bool:

def _is_attrs(t: Any) -> bool:
    """Return True when ``t`` is an attrs-decorated class and attrs is available."""
    if attr is None:
        # attrs is an optional dependency; without it nothing is an attrs class.
        return False
    return attr.has(t)


def _is_required_in_typeddict(t: Any, key: str) -> bool:
if is_typeddict(t):
if t.__total__:
return key not in t.__optional_keys__
return key in t.__required_keys__
raise RuntimeError(f'Expected TypedDict, got "{t!r}"')
10 changes: 10 additions & 0 deletions python/serpyco_rs/_json_schema/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ def _(arg: describe.EntityType, doc: Optional[str] = None) -> Schema:
)


@to_json_schema.register
def _(arg: describe.TypedDictType, doc: Optional[str] = None) -> Schema:
    """Convert a described TypedDict into a JSON Schema object type."""
    properties = {}
    for prop in arg.fields:
        # Fields flagged as properties are excluded from `properties`
        # (they still count toward `required`, matching other entity types).
        if prop.is_property:
            continue
        properties[prop.dict_key] = to_json_schema(prop.type, prop.doc)
    required_keys = [prop.dict_key for prop in arg.fields if prop.required]
    return ObjectType(
        properties=properties,
        required=required_keys or None,
        name=arg.name,
        description=arg.doc,
    )


@to_json_schema.register
def _(arg: describe.ArrayType, doc: Optional[str] = None) -> Schema:
return ArrayType(
Expand Down
85 changes: 82 additions & 3 deletions src/serializer/encoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,73 @@ impl Encoder for EntityEncoder {
}
}

/// Encoder for `TypedDict` values — plain Python dicts on both sides:
/// `dump` keys the output by each field's `dict_key`, `load` keys it by the
/// field's Python-side `name`.
#[derive(Debug, Clone)]
pub struct TypedDictEncoder {
    // When true, non-required fields that dump to `None` are omitted from the output.
    pub(crate) omit_none: bool,
    // Field descriptors (names, keys, per-field encoders, required flags).
    pub(crate) fields: Vec<Field>,
}

impl Encoder for TypedDictEncoder {
    /// Serialize a TypedDict (a plain Python dict) into a new dict keyed by
    /// each field's `dict_key`. Missing required fields raise ValidationError;
    /// missing optional fields are skipped.
    #[inline]
    fn dump(&self, value: *mut PyObject) -> PyResult<*mut PyObject> {
        // NOTE(review): the PyDict_New() result is not NULL-checked here —
        // confirm the `ffi!` macro covers allocation failure.
        let dict_ptr = ffi!(PyDict_New());

        for field in &self.fields {
            // Look the field up by its Python-side name.
            let field_val = match py_object_get_item(value, field.name.as_ptr()) {
                Ok(val) => val,
                Err(e) => {
                    if field.required {
                        // NOTE(review): `dict_ptr` (and anything already inserted)
                        // leaks on this early return — no Py_DECREF before Err.
                        return Err(ValidationError::new_err(format!(
                            "data dictionary is missing required parameter {} (err: {})",
                            &field.name, e
                        )));
                    } else {
                        // Optional field absent from the input: skip it.
                        continue;
                    }
                }
            }; // val RC +1

            // NOTE(review): if this `?` propagates an error, `field_val` and
            // `dict_ptr` both leak (their DECREFs below are never reached).
            let dump_result = field.encoder.dump(field_val)?; // new obj or RC +1

            // Insert unless the field is optional, omit_none is enabled, and the
            // dumped value is None.
            if field.required || !self.omit_none || dump_result != unsafe { NONE_PY_TYPE } {
                ffi!(PyDict_SetItem(
                    dict_ptr,
                    field.dict_key.as_ptr(),
                    dump_result
                )); // key and val RC +1
            }

            // Release our references; PyDict_SetItem took its own above.
            ffi!(Py_DECREF(field_val));
            ffi!(Py_DECREF(dump_result));
        }

        Ok(dict_ptr)
    }

    /// Deserialize a dict keyed by `dict_key` into a dict keyed by the
    /// Python-side field name (the shape a TypedDict expects).
    #[inline]
    fn load(&self, value: *mut PyObject) -> PyResult<*mut PyObject> {
        let dict_ptr = ffi!(PyDict_New());
        for field in &self.fields {
            // NOTE(review): if `field.encoder.load(val)?` propagates an error,
            // `val` and `dict_ptr` leak; also confirm whether `load` consumes
            // the reference returned by py_object_get_item.
            let val = match py_object_get_item(value, field.dict_key.as_ptr()) {
                Ok(val) => field.encoder.load(val)?, // new obj or RC +1
                Err(e) => {
                    if field.required {
                        return Err(ValidationError::new_err(format!(
                            "data dictionary is missing required parameter {} (err: {})",
                            &field.dict_key, e
                        )));
                    } else {
                        // Optional field absent: leave it out of the result.
                        // NOTE(review): field defaults are not applied here —
                        // confirm that is intentional for TypedDicts.
                        continue;
                    }
                }
            };
            ffi!(PyDict_SetItem(dict_ptr, field.name.as_ptr(), val)); // key and val RC +1
            ffi!(Py_DECREF(val));
        }
        Ok(dict_ptr)
    }
}

#[derive(Debug, Clone)]
pub struct UUIDEncoder;

Expand Down Expand Up @@ -396,16 +463,25 @@ impl Encoder for DateEncoder {
}
}

/// Concrete encoder kinds a `LazyEncoder` cell may eventually hold once the
/// described type has been fully resolved (used for recursive definitions —
/// see `RecursionHolder` handling in `get_encoder`).
#[derive(Debug)]
pub enum Encoders {
    Entity(EntityEncoder),
    TypedDict(TypedDictEncoder),
}

/// Late-bound encoder: the shared cell starts as `None` and is filled in by
/// `get_encoder` once the target entity/TypedDict encoder has been built,
/// allowing recursive type definitions.
#[derive(Debug, Clone)]
pub struct LazyEncoder {
    // Shared slot; dump/load error out if it is still `None` when called.
    pub(crate) inner: Arc<AtomicRefCell<Option<Encoders>>>,
}

impl Encoder for LazyEncoder {
#[inline]
fn dump(&self, value: *mut PyObject) -> PyResult<*mut PyObject> {
match self.inner.borrow().as_ref() {
Some(encoder) => encoder.dump(value),
Some(encoder) => match encoder {
Encoders::Entity(encoder) => encoder.dump(value),
Encoders::TypedDict(encoder) => encoder.dump(value),
},
None => Err(PyRuntimeError::new_err(
"[RUST] Invalid recursive encoder".to_string(),
)),
Expand All @@ -415,7 +491,10 @@ impl Encoder for LazyEncoder {
#[inline]
fn load(&self, value: *mut PyObject) -> PyResult<*mut PyObject> {
match self.inner.borrow().as_ref() {
Some(encoder) => encoder.load(value),
Some(encoder) => match encoder {
Encoders::Entity(encoder) => encoder.load(value),
Encoders::TypedDict(encoder) => encoder.load(value),
},
None => Err(PyRuntimeError::new_err(
"[RUST] Invalid recursive encoder".to_string(),
)),
Expand Down
82 changes: 52 additions & 30 deletions src/serializer/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::serializer::encoders::{
CustomEncoder, DateEncoder, DateTimeEncoder, LazyEncoder, TEncoder, TimeEncoder, UnionEncoder,
CustomEncoder, DateEncoder, DateTimeEncoder, Encoders, LazyEncoder, TEncoder, TimeEncoder,
TypedDictEncoder, UnionEncoder,
};
use atomic_refcell::AtomicRefCell;
use pyo3::prelude::*;
Expand All @@ -16,7 +17,7 @@ use super::encoders::{
NoopEncoder, OptionalEncoder, TupleEncoder, UUIDEncoder,
};

type EncoderStateValue = Arc<AtomicRefCell<Option<EntityEncoder>>>;
type EncoderStateValue = Arc<AtomicRefCell<Option<Encoders>>>;

#[pyclass]
#[derive(Debug)]
Expand Down Expand Up @@ -145,33 +146,7 @@ pub fn get_encoder(
let py_type = type_info.getattr(py, "cls")?;
let class_fields = type_info.getattr(py, "fields")?;
let omit_none = type_info.getattr(py, "omit_none")?.is_true(py)?;
let mut fields = vec![];

for field in class_fields.as_ref(py).iter()? {
let field = field?;
let f_name: &PyString = field.getattr("name")?.downcast()?;
let dict_key: &PyString = field.getattr("dict_key")?.downcast()?;
let required = field.getattr("required")?.is_true()?;
let f_type = get_object_type(field.getattr("type")?)?;
let f_default = field.getattr("default")?;
let f_default_factory = field.getattr("default_factory")?;

let fld = Field {
name: f_name.into(),
dict_key: dict_key.into(),
encoder: get_encoder(py, f_type, encoder_state)?,
required,
default: match is_not_set(f_default)? {
true => None,
false => Some(f_default.into()),
},
default_factory: match is_not_set(f_default_factory)? {
true => None,
false => Some(f_default_factory.into()),
},
};
fields.push(fld);
}
let fields = iterate_on_fields(py, class_fields, encoder_state)?;

let encoder = EntityEncoder {
fields,
Expand All @@ -180,7 +155,20 @@ pub fn get_encoder(
};
let python_object_id = type_info.as_ptr() as *const _ as usize;
let val = encoder_state.entry(python_object_id).or_default();
AtomicRefCell::<Option<EntityEncoder>>::borrow_mut(val).replace(encoder.clone());
AtomicRefCell::<Option<Encoders>>::borrow_mut(val)
.replace(Encoders::Entity(encoder.clone()));
wrap_with_custom_encoder(py, type_info, Box::new(encoder))?
}
Type::TypedDict(type_info) => {
let class_fields = type_info.getattr(py, "fields")?;
let omit_none = type_info.getattr(py, "omit_none")?.is_true(py)?;
let fields = iterate_on_fields(py, class_fields, encoder_state)?;

let encoder = TypedDictEncoder { fields, omit_none };
let python_object_id = type_info.as_ptr() as *const _ as usize;
let val = encoder_state.entry(python_object_id).or_default();
AtomicRefCell::<Option<Encoders>>::borrow_mut(val)
.replace(Encoders::TypedDict(encoder.clone()));
wrap_with_custom_encoder(py, type_info, Box::new(encoder))?
}
Type::RecursionHolder(type_info) => {
Expand Down Expand Up @@ -239,3 +227,37 @@ fn to_optional(py: Python<'_>, value: PyObject) -> Option<PyObject> {
false => Some(value),
}
}

/// Reads a Python iterable of field-description objects and converts each
/// entry into a `Field`, building the encoder for the field's type as it goes.
///
/// `encoder_state` is the shared registry passed on to `get_encoder` so that
/// nested/recursive field types resolve to the same lazy cells.
fn iterate_on_fields(
    py: Python<'_>,
    fields_attr: PyObject,
    encoder_state: &mut HashMap<usize, EncoderStateValue>,
) -> PyResult<Vec<Field>> {
    let mut result: Vec<Field> = vec![];

    for item in fields_attr.as_ref(py).iter()? {
        let field_info = item?;

        // Pull the per-field attributes off the Python descriptor object.
        let name: &PyString = field_info.getattr("name")?.downcast()?;
        let dict_key: &PyString = field_info.getattr("dict_key")?.downcast()?;
        let required = field_info.getattr("required")?.is_true()?;
        let field_type = get_object_type(field_info.getattr("type")?)?;
        let default = field_info.getattr("default")?;
        let default_factory = field_info.getattr("default_factory")?;

        // Build the field encoder before interpreting defaults (keeps the
        // original error-propagation order).
        let encoder = get_encoder(py, field_type, encoder_state)?;

        // NOT_SET sentinels map to `None`; anything else is a real default.
        let default = if is_not_set(default)? {
            None
        } else {
            Some(default.into())
        };
        let default_factory = if is_not_set(default_factory)? {
            None
        } else {
            Some(default_factory.into())
        };

        result.push(Field {
            name: name.into(),
            dict_key: dict_key.into(),
            encoder,
            required,
            default,
            default_factory,
        });
    }

    Ok(result)
}
5 changes: 5 additions & 0 deletions src/serializer/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub static mut DATETIME_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut DATE_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut ENUM_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut ENTITY_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut TYPED_DICT_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut OPTIONAL_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut ARRAY_TYPE: *mut PyObject = 0 as *mut PyObject;
pub static mut DICTIONARY_TYPE: *mut PyObject = 0 as *mut PyObject;
Expand Down Expand Up @@ -52,6 +53,7 @@ pub enum Type {
Date(Py<PyAny>),
Enum(Py<PyAny>),
Entity(Py<PyAny>),
TypedDict(Py<PyAny>),
Optional(Py<PyAny>),
Array(Py<PyAny>),
Dictionary(Py<PyAny>),
Expand Down Expand Up @@ -87,6 +89,8 @@ pub fn get_object_type(type_info: &PyAny) -> PyResult<Type> {
Ok(Type::Enum(type_info.into()))
} else if check_type!(type_info, ENTITY_TYPE) {
Ok(Type::Entity(type_info.into()))
} else if check_type!(type_info, TYPED_DICT_TYPE) {
Ok(Type::TypedDict(type_info.into()))
} else if check_type!(type_info, OPTIONAL_TYPE) {
Ok(Type::Optional(type_info.into()))
} else if check_type!(type_info, ARRAY_TYPE) {
Expand Down Expand Up @@ -123,6 +127,7 @@ pub fn init(py: Python<'_>) {
DATE_TYPE = get_attr_ptr!(describe, "DateType");
ENUM_TYPE = get_attr_ptr!(describe, "EnumType");
ENTITY_TYPE = get_attr_ptr!(describe, "EntityType");
TYPED_DICT_TYPE = get_attr_ptr!(describe, "TypedDictType");
OPTIONAL_TYPE = get_attr_ptr!(describe, "OptionalType");
ARRAY_TYPE = get_attr_ptr!(describe, "ArrayType");
DICTIONARY_TYPE = get_attr_ptr!(describe, "DictionaryType");
Expand Down
Loading

0 comments on commit 1cb19c4

Please sign in to comment.