Skip to content

Json bytes and bytearray to JSON #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 15 additions & 27 deletions benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extern crate test;
use test::{black_box, Bencher};

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PyString};

use _pydantic_core::SchemaValidator;

Expand All @@ -14,17 +14,21 @@ fn build_schema_validator(py: Python, code: &str) -> SchemaValidator {
SchemaValidator::py_new(py, schema).unwrap()
}

fn json<'a>(py: Python<'a>, code: &'a str) -> &'a PyAny {
black_box(PyString::new(py, code))
}

#[bench]
fn ints_json(bench: &mut Bencher) {
let gil = Python::acquire_gil();
let py = gil.python();
let validator = build_schema_validator(py, "{'type': 'int'}");

let result = validator.validate_json(py, black_box("123".to_string())).unwrap();
let result = validator.validate_json(py, json(py, "123")).unwrap();
let result_int: i64 = result.extract(py).unwrap();
assert_eq!(result_int, 123);

bench.iter(|| black_box(validator.validate_json(py, black_box("123".to_string())).unwrap()))
bench.iter(|| black_box(validator.validate_json(py, json(py, "123")).unwrap()))
}

#[bench]
Expand Down Expand Up @@ -53,10 +57,7 @@ fn list_int_json(bench: &mut Bencher) {
(0..100).map(|x| x.to_string()).collect::<Vec<String>>().join(",")
);

bench.iter(|| {
let input = black_box(code.clone());
black_box(validator.validate_json(py, input).unwrap())
})
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
}

fn list_int_input(py: Python<'_>) -> (SchemaValidator, PyObject) {
Expand Down Expand Up @@ -110,8 +111,7 @@ fn list_error_json(bench: &mut Bencher) {
.join(", ")
);

let input = black_box(code.clone());
match validator.validate_json(py, input) {
match validator.validate_json(py, json(py, &code)) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
Expand All @@ -122,12 +122,9 @@ fn list_error_json(bench: &mut Bencher) {
}
};

bench.iter(|| {
let input = black_box(code.clone());
match validator.validate_json(py, input) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => black_box(e),
}
bench.iter(|| match validator.validate_json(py, json(py, &code)) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => black_box(e),
})
}

Expand Down Expand Up @@ -197,10 +194,7 @@ fn list_any_json(bench: &mut Bencher) {
(0..100).map(|x| x.to_string()).collect::<Vec<String>>().join(",")
);

bench.iter(|| {
let input = black_box(code.clone());
black_box(validator.validate_json(py, input).unwrap())
})
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
}

#[bench]
Expand Down Expand Up @@ -242,10 +236,7 @@ fn dict_json(bench: &mut Bencher) {
.join(", ")
);

bench.iter(|| {
let input = black_box(code.to_string());
black_box(validator.validate_json(py, input).unwrap())
})
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
}

#[bench]
Expand Down Expand Up @@ -343,10 +334,7 @@ fn typed_dict_json(bench: &mut Bencher) {

let code = r#"{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 0}"#.to_string();

bench.iter(|| {
let input = black_box(code.clone());
black_box(validator.validate_json(py, input).unwrap())
})
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
}

#[bench]
Expand Down
6 changes: 3 additions & 3 deletions pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from pydantic_core._types import Schema

Expand All @@ -8,8 +8,8 @@ class SchemaValidator:
def __init__(self, schema: Schema) -> None: ...
def validate_python(self, input: Any) -> Any: ...
def isinstance_python(self, input: Any) -> bool: ...
def validate_json(self, input: str) -> Any: ...
def isinstance_json(self, input: str) -> bool: ...
def validate_json(self, input: Union[str, bytes, bytearray]) -> Any: ...
def isinstance_json(self, input: Union[str, bytes, bytearray]) -> bool: ...
def validate_assignment(self, field: str, input: Any, data: Dict[str, Any]) -> Dict[str, Any]: ...

class SchemaError(ValueError):
Expand Down
12 changes: 10 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::str::from_utf8;
use pyo3::exceptions::{PyAttributeError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString,
PyTime, PyTuple, PyType,
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet,
PyString, PyTime, PyTuple, PyType,
};
use pyo3::{intern, AsPyPointer};

Expand Down Expand Up @@ -60,6 +60,12 @@ impl<'a> Input<'a> for PyAny {
Err(_) => return Err(ValError::new(ErrorKind::StrUnicode, self)),
};
Ok(str.into())
} else if let Ok(py_byte_array) = self.cast_as::<PyByteArray>() {
let str = match from_utf8(unsafe { py_byte_array.as_bytes() }) {
Ok(s) => s,
Err(_) => return Err(ValError::new(ErrorKind::StrUnicode, self)),
};
Ok(str.into())
} else if self.cast_as::<PyBool>().is_ok() {
// do this before int and float parsing as `False` is cast to `0` and we don't want False to
// be returned as a string
Expand Down Expand Up @@ -270,6 +276,8 @@ impl<'a> Input<'a> for PyAny {
} else if let Ok(py_str) = self.cast_as::<PyString>() {
let string = py_str.to_string_lossy().to_string();
Ok(string.into_bytes().into())
} else if let Ok(py_byte_array) = self.cast_as::<PyByteArray>() {
Ok(py_byte_array.to_vec().into())
} else {
Err(ValError::new(ErrorKind::BytesType, self))
}
Expand Down
29 changes: 21 additions & 8 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ use std::fmt::Debug;

use enum_dispatch::enum_dispatch;

use pyo3::exceptions::PyRecursionError;
use pyo3::exceptions::{PyRecursionError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use serde_json::from_str as parse_json;
use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyString};

use crate::build_tools::{py_error, SchemaDict, SchemaError};
use crate::errors::{ErrorKind, ValError, ValLineError, ValResult, ValidationError};
Expand Down Expand Up @@ -99,8 +98,8 @@ impl SchemaValidator {
}
}

pub fn validate_json(&self, py: Python, input: String) -> PyResult<PyObject> {
match parse_json::<JsonInput>(&input) {
pub fn validate_json(&self, py: Python, input: &PyAny) -> PyResult<PyObject> {
match parse_json(input)? {
Ok(input) => {
let r = self.validator.validate(
py,
Expand All @@ -112,15 +111,15 @@ impl SchemaValidator {
r.map_err(|e| self.prepare_validation_err(py, e))
}
Err(e) => {
let line_err = ValLineError::new(ErrorKind::InvalidJson { error: e.to_string() }, &input);
let line_err = ValLineError::new(ErrorKind::InvalidJson { error: e.to_string() }, input);
let err = ValError::LineErrors(vec![line_err]);
Err(self.prepare_validation_err(py, err))
}
}
}

pub fn isinstance_json(&self, py: Python, input: String) -> PyResult<bool> {
match parse_json::<JsonInput>(&input) {
pub fn isinstance_json(&self, py: Python, input: &PyAny) -> PyResult<bool> {
match parse_json(input)? {
Ok(input) => {
match self.validator.validate(
py,
Expand Down Expand Up @@ -164,6 +163,20 @@ impl SchemaValidator {
}
}

fn parse_json(input: &PyAny) -> PyResult<serde_json::Result<JsonInput>> {
if let Ok(py_bytes) = input.cast_as::<PyBytes>() {
Ok(serde_json::from_slice(py_bytes.as_bytes()))
} else if let Ok(py_str) = input.cast_as::<PyString>() {
let str = py_str.to_str()?;
Ok(serde_json::from_str(str))
} else if let Ok(py_byte_array) = input.cast_as::<PyByteArray>() {
Ok(serde_json::from_slice(unsafe { py_byte_array.as_bytes() }))
} else {
let input_type = input.get_type().name().unwrap_or("unknown");
py_error!(PyTypeError; "JSON input must be str, bytes or bytearray, not {}", input_type)
}
}

pub trait BuildValidator: Sized {
const EXPECTED_TYPE: &'static str;

Expand Down
12 changes: 12 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ def test_bool(input_value, output_value):
assert v.validate_json(input_value) == output_value


@pytest.mark.parametrize('input_value', ['[1, 2, 3]', b'[1, 2, 3]', bytearray(b'[1, 2, 3]')])
def test_input_types(input_value):
v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}})
assert v.validate_json(input_value) == [1, 2, 3]


def test_input_type_invalid():
v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}})
with pytest.raises(TypeError, match='^JSON input must be str, bytes or bytearray, not list$'):
v.validate_json([])


def test_null():
assert SchemaValidator({'type': 'none'}).validate_json('null') is None

Expand Down
5 changes: 4 additions & 1 deletion tests/validators/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ def test_strict_bytes_validator():
assert v.validate_json('"foo"') == b'foo'

with pytest.raises(ValidationError, match='Value must be a valid bytes'):
assert v.validate_python('foo') == b'foo'
v.validate_python('foo')
with pytest.raises(ValidationError, match='Value must be a valid bytes'):
v.validate_python(bytearray(b'foo'))


def test_lax_bytes_validator():
v = SchemaValidator({'type': 'bytes'})

assert v.validate_python(b'foo') == b'foo'
assert v.validate_python('foo') == b'foo'
assert v.validate_python(bytearray(b'foo')) == b'foo'

assert v.validate_json('"foo"') == b'foo'

Expand Down
5 changes: 5 additions & 0 deletions tests/validators/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def test_str(py_or_json, input_value, expected):
[
('foobar', 'foobar'),
(b'foobar', 'foobar'),
(bytearray(b'foobar'), 'foobar'),
(b'\x81', Err('Value must be a valid string, unable to parse raw data as a unicode string [kind=str_unicode')),
(
bytearray(b'\x81'),
Err('Value must be a valid string, unable to parse raw data as a unicode string [kind=str_unicode'),
),
# null bytes are very annoying, but we can't really block them here
(b'\x00', '\x00'),
(123, '123'),
Expand Down