Skip to content

Commit dfb07f8

Browse files
authored
Json bytes and bytearray to JSON (#142)
* allow bytes and bytearray to json
* more support for bytearray
* match error message to python
* fix rust benchmarks
1 parent 9b4c074 commit dfb07f8

File tree

7 files changed

+70
-41
lines changed

7 files changed

+70
-41
lines changed

benches/main.rs

+15-27
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ extern crate test;
55
use test::{black_box, Bencher};
66

77
use pyo3::prelude::*;
8-
use pyo3::types::PyDict;
8+
use pyo3::types::{PyDict, PyString};
99

1010
use _pydantic_core::SchemaValidator;
1111

@@ -14,17 +14,21 @@ fn build_schema_validator(py: Python, code: &str) -> SchemaValidator {
1414
SchemaValidator::py_new(py, schema).unwrap()
1515
}
1616

17+
fn json<'a>(py: Python<'a>, code: &'a str) -> &'a PyAny {
18+
black_box(PyString::new(py, code))
19+
}
20+
1721
#[bench]
1822
fn ints_json(bench: &mut Bencher) {
1923
let gil = Python::acquire_gil();
2024
let py = gil.python();
2125
let validator = build_schema_validator(py, "{'type': 'int'}");
2226

23-
let result = validator.validate_json(py, black_box("123".to_string())).unwrap();
27+
let result = validator.validate_json(py, json(py, "123")).unwrap();
2428
let result_int: i64 = result.extract(py).unwrap();
2529
assert_eq!(result_int, 123);
2630

27-
bench.iter(|| black_box(validator.validate_json(py, black_box("123".to_string())).unwrap()))
31+
bench.iter(|| black_box(validator.validate_json(py, json(py, "123")).unwrap()))
2832
}
2933

3034
#[bench]
@@ -53,10 +57,7 @@ fn list_int_json(bench: &mut Bencher) {
5357
(0..100).map(|x| x.to_string()).collect::<Vec<String>>().join(",")
5458
);
5559

56-
bench.iter(|| {
57-
let input = black_box(code.clone());
58-
black_box(validator.validate_json(py, input).unwrap())
59-
})
60+
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
6061
}
6162

6263
fn list_int_input(py: Python<'_>) -> (SchemaValidator, PyObject) {
@@ -110,8 +111,7 @@ fn list_error_json(bench: &mut Bencher) {
110111
.join(", ")
111112
);
112113

113-
let input = black_box(code.clone());
114-
match validator.validate_json(py, input) {
114+
match validator.validate_json(py, json(py, &code)) {
115115
Ok(_) => panic!("unexpectedly valid"),
116116
Err(e) => {
117117
let v = e.value(py);
@@ -122,12 +122,9 @@ fn list_error_json(bench: &mut Bencher) {
122122
}
123123
};
124124

125-
bench.iter(|| {
126-
let input = black_box(code.clone());
127-
match validator.validate_json(py, input) {
128-
Ok(_) => panic!("unexpectedly valid"),
129-
Err(e) => black_box(e),
130-
}
125+
bench.iter(|| match validator.validate_json(py, json(py, &code)) {
126+
Ok(_) => panic!("unexpectedly valid"),
127+
Err(e) => black_box(e),
131128
})
132129
}
133130

@@ -197,10 +194,7 @@ fn list_any_json(bench: &mut Bencher) {
197194
(0..100).map(|x| x.to_string()).collect::<Vec<String>>().join(",")
198195
);
199196

200-
bench.iter(|| {
201-
let input = black_box(code.clone());
202-
black_box(validator.validate_json(py, input).unwrap())
203-
})
197+
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
204198
}
205199

206200
#[bench]
@@ -242,10 +236,7 @@ fn dict_json(bench: &mut Bencher) {
242236
.join(", ")
243237
);
244238

245-
bench.iter(|| {
246-
let input = black_box(code.to_string());
247-
black_box(validator.validate_json(py, input).unwrap())
248-
})
239+
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
249240
}
250241

251242
#[bench]
@@ -343,10 +334,7 @@ fn typed_dict_json(bench: &mut Bencher) {
343334

344335
let code = r#"{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 0}"#.to_string();
345336

346-
bench.iter(|| {
347-
let input = black_box(code.clone());
348-
black_box(validator.validate_json(py, input).unwrap())
349-
})
337+
bench.iter(|| black_box(validator.validate_json(py, json(py, &code)).unwrap()))
350338
}
351339

352340
#[bench]

pydantic_core/_pydantic_core.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Union
22

33
from pydantic_core._types import Schema
44

@@ -8,8 +8,8 @@ class SchemaValidator:
88
def __init__(self, schema: Schema) -> None: ...
99
def validate_python(self, input: Any) -> Any: ...
1010
def isinstance_python(self, input: Any) -> bool: ...
11-
def validate_json(self, input: str) -> Any: ...
12-
def isinstance_json(self, input: str) -> bool: ...
11+
def validate_json(self, input: Union[str, bytes, bytearray]) -> Any: ...
12+
def isinstance_json(self, input: Union[str, bytes, bytearray]) -> bool: ...
1313
def validate_assignment(self, field: str, input: Any, data: Dict[str, Any]) -> Dict[str, Any]: ...
1414

1515
class SchemaError(ValueError):

src/input/input_python.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::str::from_utf8;
33
use pyo3::exceptions::{PyAttributeError, PyTypeError};
44
use pyo3::prelude::*;
55
use pyo3::types::{
6-
PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString,
7-
PyTime, PyTuple, PyType,
6+
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet,
7+
PyString, PyTime, PyTuple, PyType,
88
};
99
use pyo3::{intern, AsPyPointer};
1010

@@ -60,6 +60,12 @@ impl<'a> Input<'a> for PyAny {
6060
Err(_) => return Err(ValError::new(ErrorKind::StrUnicode, self)),
6161
};
6262
Ok(str.into())
63+
} else if let Ok(py_byte_array) = self.cast_as::<PyByteArray>() {
64+
let str = match from_utf8(unsafe { py_byte_array.as_bytes() }) {
65+
Ok(s) => s,
66+
Err(_) => return Err(ValError::new(ErrorKind::StrUnicode, self)),
67+
};
68+
Ok(str.into())
6369
} else if self.cast_as::<PyBool>().is_ok() {
6470
// do this before int and float parsing as `False` is cast to `0` and we don't want False to
6571
// be returned as a string
@@ -270,6 +276,8 @@ impl<'a> Input<'a> for PyAny {
270276
} else if let Ok(py_str) = self.cast_as::<PyString>() {
271277
let string = py_str.to_string_lossy().to_string();
272278
Ok(string.into_bytes().into())
279+
} else if let Ok(py_byte_array) = self.cast_as::<PyByteArray>() {
280+
Ok(py_byte_array.to_vec().into())
273281
} else {
274282
Err(ValError::new(ErrorKind::BytesType, self))
275283
}

src/validators/mod.rs

+21-8
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ use std::fmt::Debug;
22

33
use enum_dispatch::enum_dispatch;
44

5-
use pyo3::exceptions::PyRecursionError;
5+
use pyo3::exceptions::{PyRecursionError, PyTypeError};
66
use pyo3::prelude::*;
7-
use pyo3::types::{PyAny, PyDict};
8-
use serde_json::from_str as parse_json;
7+
use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyString};
98

109
use crate::build_tools::{py_error, SchemaDict, SchemaError};
1110
use crate::errors::{ErrorKind, ValError, ValLineError, ValResult, ValidationError};
@@ -99,8 +98,8 @@ impl SchemaValidator {
9998
}
10099
}
101100

102-
pub fn validate_json(&self, py: Python, input: String) -> PyResult<PyObject> {
103-
match parse_json::<JsonInput>(&input) {
101+
pub fn validate_json(&self, py: Python, input: &PyAny) -> PyResult<PyObject> {
102+
match parse_json(input)? {
104103
Ok(input) => {
105104
let r = self.validator.validate(
106105
py,
@@ -112,15 +111,15 @@ impl SchemaValidator {
112111
r.map_err(|e| self.prepare_validation_err(py, e))
113112
}
114113
Err(e) => {
115-
let line_err = ValLineError::new(ErrorKind::InvalidJson { error: e.to_string() }, &input);
114+
let line_err = ValLineError::new(ErrorKind::InvalidJson { error: e.to_string() }, input);
116115
let err = ValError::LineErrors(vec![line_err]);
117116
Err(self.prepare_validation_err(py, err))
118117
}
119118
}
120119
}
121120

122-
pub fn isinstance_json(&self, py: Python, input: String) -> PyResult<bool> {
123-
match parse_json::<JsonInput>(&input) {
121+
pub fn isinstance_json(&self, py: Python, input: &PyAny) -> PyResult<bool> {
122+
match parse_json(input)? {
124123
Ok(input) => {
125124
match self.validator.validate(
126125
py,
@@ -164,6 +163,20 @@ impl SchemaValidator {
164163
}
165164
}
166165

166+
fn parse_json(input: &PyAny) -> PyResult<serde_json::Result<JsonInput>> {
167+
if let Ok(py_bytes) = input.cast_as::<PyBytes>() {
168+
Ok(serde_json::from_slice(py_bytes.as_bytes()))
169+
} else if let Ok(py_str) = input.cast_as::<PyString>() {
170+
let str = py_str.to_str()?;
171+
Ok(serde_json::from_str(str))
172+
} else if let Ok(py_byte_array) = input.cast_as::<PyByteArray>() {
173+
Ok(serde_json::from_slice(unsafe { py_byte_array.as_bytes() }))
174+
} else {
175+
let input_type = input.get_type().name().unwrap_or("unknown");
176+
py_error!(PyTypeError; "JSON input must be str, bytes or bytearray, not {}", input_type)
177+
}
178+
}
179+
167180
pub trait BuildValidator: Sized {
168181
const EXPECTED_TYPE: &'static str;
169182

tests/test_json.py

+12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ def test_bool(input_value, output_value):
1212
assert v.validate_json(input_value) == output_value
1313

1414

15+
@pytest.mark.parametrize('input_value', ['[1, 2, 3]', b'[1, 2, 3]', bytearray(b'[1, 2, 3]')])
16+
def test_input_types(input_value):
17+
v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}})
18+
assert v.validate_json(input_value) == [1, 2, 3]
19+
20+
21+
def test_input_type_invalid():
22+
v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}})
23+
with pytest.raises(TypeError, match='^JSON input must be str, bytes or bytearray, not list$'):
24+
v.validate_json([])
25+
26+
1527
def test_null():
1628
assert SchemaValidator({'type': 'none'}).validate_json('null') is None
1729

tests/validators/test_bytes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@ def test_strict_bytes_validator():
1414
assert v.validate_json('"foo"') == b'foo'
1515

1616
with pytest.raises(ValidationError, match='Value must be a valid bytes'):
17-
assert v.validate_python('foo') == b'foo'
17+
v.validate_python('foo')
18+
with pytest.raises(ValidationError, match='Value must be a valid bytes'):
19+
v.validate_python(bytearray(b'foo'))
1820

1921

2022
def test_lax_bytes_validator():
2123
v = SchemaValidator({'type': 'bytes'})
2224

2325
assert v.validate_python(b'foo') == b'foo'
2426
assert v.validate_python('foo') == b'foo'
27+
assert v.validate_python(bytearray(b'foo')) == b'foo'
2528

2629
assert v.validate_json('"foo"') == b'foo'
2730

tests/validators/test_string.py

+5
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ def test_str(py_or_json, input_value, expected):
4141
[
4242
('foobar', 'foobar'),
4343
(b'foobar', 'foobar'),
44+
(bytearray(b'foobar'), 'foobar'),
4445
(b'\x81', Err('Value must be a valid string, unable to parse raw data as a unicode string [kind=str_unicode')),
46+
(
47+
bytearray(b'\x81'),
48+
Err('Value must be a valid string, unable to parse raw data as a unicode string [kind=str_unicode'),
49+
),
4550
# null bytes are very annoying, but we can't really block them here
4651
(b'\x00', '\x00'),
4752
(123, '123'),

0 commit comments

Comments
 (0)