diff --git a/benches/main.rs b/benches/main.rs index ff1a78364..a17ba09ca 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -5,7 +5,7 @@ extern crate test; use test::{black_box, Bencher}; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyDict, PyString}; use _pydantic_core::SchemaValidator; @@ -14,17 +14,21 @@ fn build_schema_validator(py: Python, code: &str) -> SchemaValidator { SchemaValidator::py_new(py, schema).unwrap() } +fn json<'a>(py: Python<'a>, code: &'a str) -> &'a PyAny { + black_box(PyString::new(py, code)) +} + #[bench] fn ints_json(bench: &mut Bencher) { let gil = Python::acquire_gil(); let py = gil.python(); let validator = build_schema_validator(py, "{'type': 'int'}"); - let result = validator.validate_json(py, black_box("123".to_string())).unwrap(); + let result = validator.validate_json(py, json(py, "123")).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, 123); - bench.iter(|| black_box(validator.validate_json(py, black_box("123".to_string())).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, json(py, "123")).unwrap())) } #[bench] @@ -53,10 +57,7 @@ fn list_int_json(bench: &mut Bencher) { (0..100).map(|x| x.to_string()).collect::>().join(",") ); - bench.iter(|| { - let input = black_box(code.clone()); - black_box(validator.validate_json(py, input).unwrap()) - }) + bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap())) } fn list_int_input(py: Python<'_>) -> (SchemaValidator, PyObject) { @@ -110,8 +111,7 @@ fn list_error_json(bench: &mut Bencher) { .join(", ") ); - let input = black_box(code.clone()); - match validator.validate_json(py, input) { + match validator.validate_json(py, json(py, &code)) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); @@ -122,12 +122,9 @@ fn list_error_json(bench: &mut Bencher) { } }; - bench.iter(|| { - let input = black_box(code.clone()); - match validator.validate_json(py, input) { - Ok(_) => panic!("unexpectedly valid"), - Err(e) => black_box(e), - } + bench.iter(|| match validator.validate_json(py, json(py, &code)) { + Ok(_) => panic!("unexpectedly valid"), + Err(e) => black_box(e), }) } @@ -197,10 +194,7 @@ fn list_any_json(bench: &mut Bencher) { (0..100).map(|x| x.to_string()).collect::>().join(",") ); - bench.iter(|| { - let input = black_box(code.clone()); - black_box(validator.validate_json(py, input).unwrap()) - }) + bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap())) } #[bench] @@ -242,10 +236,7 @@ fn dict_json(bench: &mut Bencher) { .join(", ") ); - bench.iter(|| { - let input = black_box(code.to_string()); - black_box(validator.validate_json(py, input).unwrap()) - }) + bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap())) } #[bench] @@ -343,10 +334,7 @@ fn typed_dict_json(bench: &mut Bencher) { let code = r#"{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 0}"#.to_string(); - bench.iter(|| { - let input = black_box(code.clone()); - black_box(validator.validate_json(py, input).unwrap()) - }) + bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap())) } #[bench] diff --git a/pydantic_core/_pydantic_core.pyi b/pydantic_core/_pydantic_core.pyi index 7eb75e44e..21bc10a39 100644 --- a/pydantic_core/_pydantic_core.pyi +++ b/pydantic_core/_pydantic_core.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from pydantic_core._types import Schema @@ -8,8 +8,8 @@ class SchemaValidator: def __init__(self, schema: Schema) -> None: ... def validate_python(self, input: Any) -> Any: ... def isinstance_python(self, input: Any) -> bool: ... - def validate_json(self, input: str) -> Any: ... - def isinstance_json(self, input: str) -> bool: ... + def validate_json(self, input: Union[str, bytes, bytearray]) -> Any: ... + def isinstance_json(self, input: Union[str, bytes, bytearray]) -> bool: ... def validate_assignment(self, field: str, input: Any, data: Dict[str, Any]) -> Dict[str, Any]: ... class SchemaError(ValueError): diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 9a48f2f22..d228f15d8 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -3,8 +3,8 @@ use std::str::from_utf8; use pyo3::exceptions::{PyAttributeError, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{ - PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString, - PyTime, PyTuple, PyType, + PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, + PyString, PyTime, PyTuple, PyType, }; use pyo3::{intern, AsPyPointer}; @@ -60,6 +60,12 @@ impl<'a> Input<'a> for PyAny { Err(_) => return Err(ValError::new(ErrorKind::StrUnicode, self)), }; Ok(str.into()) + } else if let Ok(py_byte_array) = self.cast_as::() { + let str = match from_utf8(unsafe { py_byte_array.as_bytes() }) { + Ok(s) => s, + Err(_) => return Err(ValError::new(ErrorKind::StrUnicode, self)), + }; + Ok(str.into()) } else if self.cast_as::().is_ok() { // do this before int and float parsing as `False` is cast to `0` and we don't want False to // be returned as a string @@ -270,6 +276,8 @@ impl<'a> Input<'a> for PyAny { } else if let Ok(py_str) = self.cast_as::() { let string = py_str.to_string_lossy().to_string(); Ok(string.into_bytes().into()) + } else if let Ok(py_byte_array) = self.cast_as::() { + Ok(py_byte_array.to_vec().into()) } else { Err(ValError::new(ErrorKind::BytesType, self)) } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 0aea74849..acf13fdd5 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -2,10 +2,9 @@ use std::fmt::Debug; use enum_dispatch::enum_dispatch; -use pyo3::exceptions::PyRecursionError; +use pyo3::exceptions::{PyRecursionError, PyTypeError}; use pyo3::prelude::*; -use pyo3::types::{PyAny, PyDict}; -use serde_json::from_str as parse_json; +use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyString}; use crate::build_tools::{py_error, SchemaDict, SchemaError}; use crate::errors::{ErrorKind, ValError, ValLineError, ValResult, ValidationError}; @@ -99,8 +98,8 @@ impl SchemaValidator { } } - pub fn validate_json(&self, py: Python, input: String) -> PyResult { - match parse_json::(&input) { + pub fn validate_json(&self, py: Python, input: &PyAny) -> PyResult { + match parse_json(input)? { Ok(input) => { let r = self.validator.validate( py, @@ -112,15 +111,15 @@ impl SchemaValidator { r.map_err(|e| self.prepare_validation_err(py, e)) } Err(e) => { - let line_err = ValLineError::new(ErrorKind::InvalidJson { error: e.to_string() }, &input); + let line_err = ValLineError::new(ErrorKind::InvalidJson { error: e.to_string() }, input); let err = ValError::LineErrors(vec![line_err]); Err(self.prepare_validation_err(py, err)) } } } - pub fn isinstance_json(&self, py: Python, input: String) -> PyResult { - match parse_json::(&input) { + pub fn isinstance_json(&self, py: Python, input: &PyAny) -> PyResult { + match parse_json(input)? { Ok(input) => { match self.validator.validate( py, @@ -164,6 +163,20 @@ impl SchemaValidator { } } +fn parse_json(input: &PyAny) -> PyResult> { + if let Ok(py_bytes) = input.cast_as::() { + Ok(serde_json::from_slice(py_bytes.as_bytes())) + } else if let Ok(py_str) = input.cast_as::() { + let str = py_str.to_str()?; + Ok(serde_json::from_str(str)) + } else if let Ok(py_byte_array) = input.cast_as::() { + Ok(serde_json::from_slice(unsafe { py_byte_array.as_bytes() })) + } else { + let input_type = input.get_type().name().unwrap_or("unknown"); + py_error!(PyTypeError; "JSON input must be str, bytes or bytearray, not {}", input_type) + } +} + pub trait BuildValidator: Sized { const EXPECTED_TYPE: &'static str; diff --git a/tests/test_json.py b/tests/test_json.py index a136deddc..ff6a4e526 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -12,6 +12,18 @@ def test_bool(input_value, output_value): assert v.validate_json(input_value) == output_value +@pytest.mark.parametrize('input_value', ['[1, 2, 3]', b'[1, 2, 3]', bytearray(b'[1, 2, 3]')]) +def test_input_types(input_value): + v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}}) + assert v.validate_json(input_value) == [1, 2, 3] + + +def test_input_type_invalid(): + v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}}) + with pytest.raises(TypeError, match='^JSON input must be str, bytes or bytearray, not list$'): + v.validate_json([]) + + def test_null(): assert SchemaValidator({'type': 'none'}).validate_json('null') is None diff --git a/tests/validators/test_bytes.py b/tests/validators/test_bytes.py index cd216ea15..41cefa6ac 100644 --- a/tests/validators/test_bytes.py +++ b/tests/validators/test_bytes.py @@ -14,7 +14,9 @@ def test_strict_bytes_validator(): assert v.validate_json('"foo"') == b'foo' with pytest.raises(ValidationError, match='Value must be a valid bytes'): - assert v.validate_python('foo') == b'foo' + v.validate_python('foo') + with pytest.raises(ValidationError, match='Value must be a valid bytes'): + v.validate_python(bytearray(b'foo')) def test_lax_bytes_validator(): @@ -22,6 +24,7 @@ def test_lax_bytes_validator(): assert v.validate_python(b'foo') == b'foo' assert v.validate_python('foo') == b'foo' + assert v.validate_python(bytearray(b'foo')) == b'foo' assert v.validate_json('"foo"') == b'foo' diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index 8123af827..f802bde10 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -41,7 +41,12 @@ def test_str(py_or_json, input_value, expected): [ ('foobar', 'foobar'), (b'foobar', 'foobar'), + (bytearray(b'foobar'), 'foobar'), (b'\x81', Err('Value must be a valid string, unable to parse raw data as a unicode string [kind=str_unicode')), + ( + bytearray(b'\x81'), + Err('Value must be a valid string, unable to parse raw data as a unicode string [kind=str_unicode'), + ), # null bytes are very annoying, but we can't really block them here (b'\x00', '\x00'), (123, '123'),