1
1
use std:: collections:: { hash_map:: DefaultHasher , HashMap } ;
2
2
use std:: hash:: { Hash , Hasher } ;
3
3
4
- use numpy:: PyArray1 ;
4
+ use numpy:: npyffi ;
5
5
use pyo3:: class:: basic:: CompareOp ;
6
6
use pyo3:: exceptions;
7
7
use pyo3:: prelude:: * ;
8
8
use pyo3:: types:: * ;
9
+ use pyo3:: AsPyPointer ;
9
10
use tk:: models:: bpe:: BPE ;
10
11
use tk:: tokenizer:: {
11
12
Model , PaddingDirection , PaddingParams , PaddingStrategy , PostProcessor , TokenizerImpl ,
@@ -258,8 +259,24 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
258
259
/// Newtype over an owned `Vec<String>`.
///
/// Built through the `FromPyObject` impl that follows, which reads a numpy
/// array's descriptor and compares its `type_num` against 19 (Unicode).
/// A `From<PyArrayUnicode>` conversion into `tk::InputSequence` also exists
/// in this file.
struct PyArrayUnicode ( Vec < String > ) ;
259
260
impl FromPyObject < ' _ > for PyArrayUnicode {
260
261
fn extract ( ob : & PyAny ) -> PyResult < Self > {
261
- let array = ob. downcast :: < PyArray1 < u8 > > ( ) ?;
262
- let arr = array. as_array_ptr ( ) ;
262
+ if unsafe { npyffi:: PyArray_Check ( ob. py ( ) , ob. as_ptr ( ) ) } == 0 {
263
+ return Err ( exceptions:: PyTypeError :: new_err ( "Expected an np.array" ) ) ;
264
+ }
265
+ let arr = ob. as_ptr ( ) as * mut npyffi:: PyArrayObject ;
266
+ if unsafe { ( * arr) . nd } != 1 {
267
+ return Err ( exceptions:: PyTypeError :: new_err (
268
+ "Expected a 1 dimensional np.array" ,
269
+ ) ) ;
270
+ }
271
+ if unsafe { ( * arr) . flags }
272
+ & ( npyffi:: NPY_ARRAY_C_CONTIGUOUS | npyffi:: NPY_ARRAY_F_CONTIGUOUS )
273
+ == 0
274
+ {
275
+ return Err ( exceptions:: PyTypeError :: new_err (
276
+ "Expected a continuous np.array" ,
277
+ ) ) ;
278
+ }
279
+ let n_elem = unsafe { * ( * arr) . dimensions } as usize ;
263
280
let ( type_num, elsize, alignment, data) = unsafe {
264
281
let desc = ( * arr) . descr ;
265
282
(
@@ -269,7 +286,6 @@ impl FromPyObject<'_> for PyArrayUnicode {
269
286
( * arr) . data ,
270
287
)
271
288
} ;
272
- let n_elem = array. shape ( ) [ 0 ] ;
273
289
274
290
// type_num == 19 => Unicode
275
291
if type_num != 19 {
@@ -310,10 +326,27 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
310
326
/// Newtype over an owned `Vec<String>`.
///
/// Built through the `FromPyObject` impl that follows, which returns a
/// `PyTypeError` when the numpy array's descriptor `type_num` is not 17.
// NOTE(review): 17 is presumably numpy's object dtype (NPY_OBJECT) — confirm
// against the numpy type-number enum.
struct PyArrayStr ( Vec < String > ) ;
311
327
impl FromPyObject < ' _ > for PyArrayStr {
312
328
fn extract ( ob : & PyAny ) -> PyResult < Self > {
313
- let array = ob. downcast :: < PyArray1 < u8 > > ( ) ?;
314
- let arr = array. as_array_ptr ( ) ;
329
+ if unsafe { npyffi:: PyArray_Check ( ob. py ( ) , ob. as_ptr ( ) ) } == 0 {
330
+ return Err ( exceptions:: PyTypeError :: new_err ( "Expected an np.array" ) ) ;
331
+ }
332
+ let arr = ob. as_ptr ( ) as * mut npyffi:: PyArrayObject ;
333
+
334
+ if unsafe { ( * arr) . nd } != 1 {
335
+ return Err ( exceptions:: PyTypeError :: new_err (
336
+ "Expected a 1 dimensional np.array" ,
337
+ ) ) ;
338
+ }
339
+ if unsafe { ( * arr) . flags }
340
+ & ( npyffi:: NPY_ARRAY_C_CONTIGUOUS | npyffi:: NPY_ARRAY_F_CONTIGUOUS )
341
+ == 0
342
+ {
343
+ return Err ( exceptions:: PyTypeError :: new_err (
344
+ "Expected a continuous np.array" ,
345
+ ) ) ;
346
+ }
347
+ let n_elem = unsafe { * ( * arr) . dimensions } as usize ;
348
+
315
349
let ( type_num, data) = unsafe { ( ( * ( * arr) . descr ) . type_num , ( * arr) . data ) } ;
316
- let n_elem = array. shape ( ) [ 0 ] ;
317
350
318
351
if type_num != 17 {
319
352
return Err ( exceptions:: PyTypeError :: new_err (
0 commit comments