Skip to content

Commit 02210b8

Browse files
committed
Use PyArray_Check instead of downcasting to PyArray1<u8>
1 parent 2b72e00 commit 02210b8

File tree

1 file changed

+40
-7
lines changed

1 file changed

+40
-7
lines changed

bindings/python/src/tokenizer.rs

+40 -7
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
use std::collections::{hash_map::DefaultHasher, HashMap};
22
use std::hash::{Hash, Hasher};
33

4-
use numpy::PyArray1;
4+
use numpy::npyffi;
55
use pyo3::class::basic::CompareOp;
66
use pyo3::exceptions;
77
use pyo3::prelude::*;
88
use pyo3::types::*;
9+
use pyo3::AsPyPointer;
910
use tk::models::bpe::BPE;
1011
use tk::tokenizer::{
1112
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
@@ -258,8 +259,24 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
258259
/// Newtype around a `Vec<String>` that is extracted from a Python object via
/// its `FromPyObject` impl.
///
/// The visible part of that impl rejects anything that is not an `np.array`,
/// is not 1-dimensional, is not C- or F-contiguous, or whose dtype `type_num`
/// is not 19 (Unicode). The value can then be converted into a
/// `tk::InputSequence`.
struct PyArrayUnicode(Vec<String>);
259260
impl FromPyObject<'_> for PyArrayUnicode {
260261
fn extract(ob: &PyAny) -> PyResult<Self> {
261-
let array = ob.downcast::<PyArray1<u8>>()?;
262-
let arr = array.as_array_ptr();
262+
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
263+
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
264+
}
265+
let arr = ob.as_ptr() as *mut npyffi::PyArrayObject;
266+
if unsafe { (*arr).nd } != 1 {
267+
return Err(exceptions::PyTypeError::new_err(
268+
"Expected a 1 dimensional np.array",
269+
));
270+
}
271+
if unsafe { (*arr).flags }
272+
& (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
273+
== 0
274+
{
275+
return Err(exceptions::PyTypeError::new_err(
276+
"Expected a continuous np.array",
277+
));
278+
}
279+
let n_elem = unsafe { *(*arr).dimensions } as usize;
263280
let (type_num, elsize, alignment, data) = unsafe {
264281
let desc = (*arr).descr;
265282
(
@@ -269,7 +286,6 @@ impl FromPyObject<'_> for PyArrayUnicode {
269286
(*arr).data,
270287
)
271288
};
272-
let n_elem = array.shape()[0];
273289

274290
// type_num == 19 => Unicode
275291
if type_num != 19 {
@@ -310,10 +326,27 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
310326
/// Newtype around a `Vec<String>` that is extracted from a Python object via
/// its `FromPyObject` impl.
///
/// The visible part of that impl rejects anything that is not an `np.array`,
/// is not 1-dimensional, is not C- or F-contiguous, or whose dtype `type_num`
/// is not 17.
// NOTE(review): type_num 17 is presumably NumPy's object dtype (NPY_OBJECT) —
// confirm against the NumPy C-API type enumeration.
struct PyArrayStr(Vec<String>);
311327
impl FromPyObject<'_> for PyArrayStr {
312328
fn extract(ob: &PyAny) -> PyResult<Self> {
313-
let array = ob.downcast::<PyArray1<u8>>()?;
314-
let arr = array.as_array_ptr();
329+
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
330+
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
331+
}
332+
let arr = ob.as_ptr() as *mut npyffi::PyArrayObject;
333+
334+
if unsafe { (*arr).nd } != 1 {
335+
return Err(exceptions::PyTypeError::new_err(
336+
"Expected a 1 dimensional np.array",
337+
));
338+
}
339+
if unsafe { (*arr).flags }
340+
& (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
341+
== 0
342+
{
343+
return Err(exceptions::PyTypeError::new_err(
344+
"Expected a continuous np.array",
345+
));
346+
}
347+
let n_elem = unsafe { *(*arr).dimensions } as usize;
348+
315349
let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) };
316-
let n_elem = array.shape()[0];
317350

318351
if type_num != 17 {
319352
return Err(exceptions::PyTypeError::new_err(

0 commit comments

Comments
 (0)