18
18
19
19
//! `PyBuffer` implementation
20
20
use crate :: err:: { self , PyResult } ;
21
- use crate :: { exceptions, ffi, AsPyPointer , PyAny , Python } ;
21
+ use crate :: { exceptions, ffi, AsPyPointer , FromPyObject , PyAny , PyNativeType , Python } ;
22
22
use std:: ffi:: CStr ;
23
+ use std:: marker:: PhantomData ;
23
24
use std:: os:: raw;
24
25
use std:: pin:: Pin ;
25
26
use std:: { cell, mem, ptr, slice} ;
26
27
27
28
/// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`.
28
29
// use Pin<Box> because Python expects that the Py_buffer struct has a stable memory address
29
30
#[ repr( transparent) ]
30
- pub struct PyBuffer ( Pin < Box < ffi:: Py_buffer > > ) ;
31
+ pub struct PyBuffer < T > ( Pin < Box < ffi:: Py_buffer > > , PhantomData < T > ) ;
31
32
32
33
// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists.
33
34
// Accessing the buffer contents is protected using the GIL.
34
- unsafe impl Send for PyBuffer { }
35
- unsafe impl Sync for PyBuffer { }
35
+ unsafe impl < T > Send for PyBuffer < T > { }
36
+ unsafe impl < T > Sync for PyBuffer < T > { }
36
37
37
38
#[ derive( Copy , Clone , Eq , PartialEq ) ]
38
39
pub enum ElementType {
@@ -146,29 +147,51 @@ fn is_matching_endian(c: u8) -> bool {
146
147
}
147
148
148
149
/// Trait implemented for possible element types of `PyBuffer`.
149
- pub unsafe trait Element {
150
+ pub unsafe trait Element : Copy {
150
151
/// Gets whether the element specified in the format string is potentially compatible.
151
152
/// Alignment and size are checked separately from this function.
152
153
fn is_compatible_format ( format : & CStr ) -> bool ;
153
154
}
154
155
155
- fn validate ( b : & ffi:: Py_buffer ) {
156
+ fn validate ( b : & ffi:: Py_buffer ) -> PyResult < ( ) > {
156
157
// shape and stride information must be provided when we use PyBUF_FULL_RO
157
- assert ! ( !b. shape. is_null( ) ) ;
158
- assert ! ( !b. strides. is_null( ) ) ;
158
+ if b. shape . is_null ( ) {
159
+ return Err ( exceptions:: BufferError :: py_err ( "Shape is Null" ) ) ;
160
+ }
161
+ if b. strides . is_null ( ) {
162
+ return Err ( exceptions:: BufferError :: py_err ( "PyBuffer: Strides is Null" ) ) ;
163
+ }
164
+ Ok ( ( ) )
165
+ }
166
+
167
+ impl < ' source , T : Element > FromPyObject < ' source > for PyBuffer < T > {
168
+ fn extract ( obj : & PyAny ) -> PyResult < PyBuffer < T > > {
169
+ Self :: get ( obj)
170
+ }
159
171
}
160
172
161
- impl PyBuffer {
173
+ impl < T : Element > PyBuffer < T > {
162
174
/// Get the underlying buffer from the specified python object.
163
- pub fn get ( py : Python , obj : & PyAny ) -> PyResult < PyBuffer > {
175
+ pub fn get ( obj : & PyAny ) -> PyResult < PyBuffer < T > > {
164
176
unsafe {
165
- let mut buf = Box :: pin ( mem :: zeroed :: < ffi:: Py_buffer > ( ) ) ;
177
+ let mut buf = Box :: pin ( ffi:: Py_buffer :: new ( ) ) ;
166
178
err:: error_on_minusone (
167
- py ,
179
+ obj . py ( ) ,
168
180
ffi:: PyObject_GetBuffer ( obj. as_ptr ( ) , & mut * buf, ffi:: PyBUF_FULL_RO ) ,
169
181
) ?;
170
- validate ( & buf) ;
171
- Ok ( PyBuffer ( buf) )
182
+ validate ( & buf) ?;
183
+ let buf = PyBuffer ( buf, PhantomData ) ;
184
+ // Type Check
185
+ if mem:: size_of :: < T > ( ) == buf. item_size ( )
186
+ && ( buf. 0 . buf as usize ) % mem:: align_of :: < T > ( ) == 0
187
+ && T :: is_compatible_format ( buf. format ( ) )
188
+ {
189
+ Ok ( buf)
190
+ } else {
191
+ Err ( exceptions:: BufferError :: py_err (
192
+ "Incompatible type as buffer" ,
193
+ ) )
194
+ }
172
195
}
173
196
}
174
197
@@ -307,12 +330,8 @@ impl PyBuffer {
307
330
///
308
331
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
309
332
/// to modify the values in the slice.
310
- pub fn as_slice < ' a , T : Element > ( & ' a self , _py : Python < ' a > ) -> Option < & ' a [ ReadOnlyCell < T > ] > {
311
- if mem:: size_of :: < T > ( ) == self . item_size ( )
312
- && ( self . 0 . buf as usize ) % mem:: align_of :: < T > ( ) == 0
313
- && self . is_c_contiguous ( )
314
- && T :: is_compatible_format ( self . format ( ) )
315
- {
333
+ pub fn as_slice < ' a > ( & ' a self , _py : Python < ' a > ) -> Option < & ' a [ ReadOnlyCell < T > ] > {
334
+ if self . is_c_contiguous ( ) {
316
335
unsafe {
317
336
Some ( slice:: from_raw_parts (
318
337
self . 0 . buf as * mut ReadOnlyCell < T > ,
@@ -334,13 +353,8 @@ impl PyBuffer {
334
353
///
335
354
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
336
355
/// to modify the values in the slice.
337
- pub fn as_mut_slice < ' a , T : Element > ( & ' a self , _py : Python < ' a > ) -> Option < & ' a [ cell:: Cell < T > ] > {
338
- if !self . readonly ( )
339
- && mem:: size_of :: < T > ( ) == self . item_size ( )
340
- && ( self . 0 . buf as usize ) % mem:: align_of :: < T > ( ) == 0
341
- && self . is_c_contiguous ( )
342
- && T :: is_compatible_format ( self . format ( ) )
343
- {
356
+ pub fn as_mut_slice < ' a > ( & ' a self , _py : Python < ' a > ) -> Option < & ' a [ cell:: Cell < T > ] > {
357
+ if !self . readonly ( ) && self . is_c_contiguous ( ) {
344
358
unsafe {
345
359
Some ( slice:: from_raw_parts (
346
360
self . 0 . buf as * mut cell:: Cell < T > ,
@@ -361,15 +375,8 @@ impl PyBuffer {
361
375
///
362
376
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
363
377
/// to modify the values in the slice.
364
- pub fn as_fortran_slice < ' a , T : Element > (
365
- & ' a self ,
366
- _py : Python < ' a > ,
367
- ) -> Option < & ' a [ ReadOnlyCell < T > ] > {
368
- if mem:: size_of :: < T > ( ) == self . item_size ( )
369
- && ( self . 0 . buf as usize ) % mem:: align_of :: < T > ( ) == 0
370
- && self . is_fortran_contiguous ( )
371
- && T :: is_compatible_format ( self . format ( ) )
372
- {
378
+ pub fn as_fortran_slice < ' a > ( & ' a self , _py : Python < ' a > ) -> Option < & ' a [ ReadOnlyCell < T > ] > {
379
+ if mem:: size_of :: < T > ( ) == self . item_size ( ) && self . is_fortran_contiguous ( ) {
373
380
unsafe {
374
381
Some ( slice:: from_raw_parts (
375
382
self . 0 . buf as * mut ReadOnlyCell < T > ,
@@ -391,16 +398,8 @@ impl PyBuffer {
391
398
///
392
399
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
393
400
/// to modify the values in the slice.
394
- pub fn as_fortran_mut_slice < ' a , T : Element > (
395
- & ' a self ,
396
- _py : Python < ' a > ,
397
- ) -> Option < & ' a [ cell:: Cell < T > ] > {
398
- if !self . readonly ( )
399
- && mem:: size_of :: < T > ( ) == self . item_size ( )
400
- && ( self . 0 . buf as usize ) % mem:: align_of :: < T > ( ) == 0
401
- && self . is_fortran_contiguous ( )
402
- && T :: is_compatible_format ( self . format ( ) )
403
- {
401
+ pub fn as_fortran_mut_slice < ' a > ( & ' a self , _py : Python < ' a > ) -> Option < & ' a [ cell:: Cell < T > ] > {
402
+ if !self . readonly ( ) && self . is_fortran_contiguous ( ) {
404
403
unsafe {
405
404
Some ( slice:: from_raw_parts (
406
405
self . 0 . buf as * mut cell:: Cell < T > ,
@@ -421,7 +420,7 @@ impl PyBuffer {
421
420
/// To check whether the buffer format is compatible before calling this method,
422
421
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
423
422
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
424
- pub fn copy_to_slice < T : Element + Copy > ( & self , py : Python , target : & mut [ T ] ) -> PyResult < ( ) > {
423
+ pub fn copy_to_slice ( & self , py : Python , target : & mut [ T ] ) -> PyResult < ( ) > {
425
424
self . copy_to_slice_impl ( py, target, b'C' )
426
425
}
427
426
@@ -434,28 +433,16 @@ impl PyBuffer {
434
433
/// To check whether the buffer format is compatible before calling this method,
435
434
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
436
435
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
437
- pub fn copy_to_fortran_slice < T : Element + Copy > (
438
- & self ,
439
- py : Python ,
440
- target : & mut [ T ] ,
441
- ) -> PyResult < ( ) > {
436
+ pub fn copy_to_fortran_slice ( & self , py : Python , target : & mut [ T ] ) -> PyResult < ( ) > {
442
437
self . copy_to_slice_impl ( py, target, b'F' )
443
438
}
444
439
445
- fn copy_to_slice_impl < T : Element + Copy > (
446
- & self ,
447
- py : Python ,
448
- target : & mut [ T ] ,
449
- fort : u8 ,
450
- ) -> PyResult < ( ) > {
440
+ fn copy_to_slice_impl ( & self , py : Python , target : & mut [ T ] , fort : u8 ) -> PyResult < ( ) > {
451
441
if mem:: size_of_val ( target) != self . len_bytes ( ) {
452
442
return Err ( exceptions:: BufferError :: py_err (
453
443
"Slice length does not match buffer length." ,
454
444
) ) ;
455
445
}
456
- if !T :: is_compatible_format ( self . format ( ) ) || mem:: size_of :: < T > ( ) != self . item_size ( ) {
457
- return incompatible_format_error ( ) ;
458
- }
459
446
unsafe {
460
447
err:: error_on_minusone (
461
448
py,
@@ -473,23 +460,19 @@ impl PyBuffer {
473
460
/// If the buffer is multi-dimensional, the elements are written in C-style order.
474
461
///
475
462
/// Fails if the buffer format is not compatible with type `T`.
476
- pub fn to_vec < T : Element + Copy > ( & self , py : Python ) -> PyResult < Vec < T > > {
463
+ pub fn to_vec ( & self , py : Python ) -> PyResult < Vec < T > > {
477
464
self . to_vec_impl ( py, b'C' )
478
465
}
479
466
480
467
/// Copies the buffer elements to a newly allocated vector.
481
468
/// If the buffer is multi-dimensional, the elements are written in Fortran-style order.
482
469
///
483
470
/// Fails if the buffer format is not compatible with type `T`.
484
- pub fn to_fortran_vec < T : Element + Copy > ( & self , py : Python ) -> PyResult < Vec < T > > {
471
+ pub fn to_fortran_vec ( & self , py : Python ) -> PyResult < Vec < T > > {
485
472
self . to_vec_impl ( py, b'F' )
486
473
}
487
474
488
- fn to_vec_impl < T : Element + Copy > ( & self , py : Python , fort : u8 ) -> PyResult < Vec < T > > {
489
- if !T :: is_compatible_format ( self . format ( ) ) || mem:: size_of :: < T > ( ) != self . item_size ( ) {
490
- incompatible_format_error ( ) ?;
491
- unreachable ! ( ) ;
492
- }
475
+ fn to_vec_impl ( & self , py : Python , fort : u8 ) -> PyResult < Vec < T > > {
493
476
let item_count = self . item_count ( ) ;
494
477
let mut vec: Vec < T > = Vec :: with_capacity ( item_count) ;
495
478
unsafe {
@@ -520,7 +503,7 @@ impl PyBuffer {
520
503
/// To check whether the buffer format is compatible before calling this method,
521
504
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
522
505
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
523
- pub fn copy_from_slice < T : Element + Copy > ( & self , py : Python , source : & [ T ] ) -> PyResult < ( ) > {
506
+ pub fn copy_from_slice ( & self , py : Python , source : & [ T ] ) -> PyResult < ( ) > {
524
507
self . copy_from_slice_impl ( py, source, b'C' )
525
508
}
526
509
@@ -534,20 +517,11 @@ impl PyBuffer {
534
517
/// To check whether the buffer format is compatible before calling this method,
535
518
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
536
519
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
537
- pub fn copy_from_fortran_slice < T : Element + Copy > (
538
- & self ,
539
- py : Python ,
540
- source : & [ T ] ,
541
- ) -> PyResult < ( ) > {
520
+ pub fn copy_from_fortran_slice ( & self , py : Python , source : & [ T ] ) -> PyResult < ( ) > {
542
521
self . copy_from_slice_impl ( py, source, b'F' )
543
522
}
544
523
545
- fn copy_from_slice_impl < T : Element + Copy > (
546
- & self ,
547
- py : Python ,
548
- source : & [ T ] ,
549
- fort : u8 ,
550
- ) -> PyResult < ( ) > {
524
+ fn copy_from_slice_impl ( & self , py : Python , source : & [ T ] , fort : u8 ) -> PyResult < ( ) > {
551
525
if self . readonly ( ) {
552
526
return buffer_readonly_error ( ) ;
553
527
}
@@ -556,9 +530,6 @@ impl PyBuffer {
556
530
"Slice length does not match buffer length." ,
557
531
) ) ;
558
532
}
559
- if !T :: is_compatible_format ( self . format ( ) ) || mem:: size_of :: < T > ( ) != self . item_size ( ) {
560
- return incompatible_format_error ( ) ;
561
- }
562
533
unsafe {
563
534
err:: error_on_minusone (
564
535
py,
@@ -589,19 +560,14 @@ impl PyBuffer {
589
560
}
590
561
}
591
562
592
- fn incompatible_format_error ( ) -> PyResult < ( ) > {
593
- Err ( exceptions:: BufferError :: py_err (
594
- "Slice type is incompatible with buffer format." ,
595
- ) )
596
- }
597
-
563
+ #[ inline( always) ]
598
564
fn buffer_readonly_error ( ) -> PyResult < ( ) > {
599
565
Err ( exceptions:: BufferError :: py_err (
600
566
"Cannot write to read-only buffer." ,
601
567
) )
602
568
}
603
569
604
- impl Drop for PyBuffer {
570
+ impl < T > Drop for PyBuffer < T > {
605
571
fn drop ( & mut self ) {
606
572
let _gil_guard = Python :: acquire_gil ( ) ;
607
573
unsafe { ffi:: PyBuffer_Release ( & mut * self . 0 ) }
@@ -614,9 +580,9 @@ impl Drop for PyBuffer {
614
580
/// The data cannot be modified through the reference, but other references may
615
581
/// be modifying the data.
616
582
#[ repr( transparent) ]
617
- pub struct ReadOnlyCell < T > ( cell:: UnsafeCell < T > ) ;
583
+ pub struct ReadOnlyCell < T : Element > ( cell:: UnsafeCell < T > ) ;
618
584
619
- impl < T : Copy > ReadOnlyCell < T > {
585
+ impl < T : Element > ReadOnlyCell < T > {
620
586
#[ inline]
621
587
pub fn get ( & self ) -> T {
622
588
unsafe { * self . 0 . get ( ) }
@@ -675,7 +641,7 @@ mod test {
675
641
let gil = Python :: acquire_gil ( ) ;
676
642
let py = gil. python ( ) ;
677
643
let bytes = py. eval ( "b'abcde'" , None , None ) . unwrap ( ) ;
678
- let buffer = PyBuffer :: get ( py , & bytes) . unwrap ( ) ;
644
+ let buffer = PyBuffer :: get ( & bytes) . unwrap ( ) ;
679
645
assert_eq ! ( buffer. dimensions( ) , 1 ) ;
680
646
assert_eq ! ( buffer. item_count( ) , 5 ) ;
681
647
assert_eq ! ( buffer. format( ) . to_str( ) . unwrap( ) , "B" ) ;
@@ -684,26 +650,18 @@ mod test {
684
650
assert ! ( buffer. is_c_contiguous( ) ) ;
685
651
assert ! ( buffer. is_fortran_contiguous( ) ) ;
686
652
687
- assert ! ( buffer. as_slice:: <f64 >( py) . is_none( ) ) ;
688
- assert ! ( buffer. as_slice:: <i8 >( py) . is_none( ) ) ;
689
-
690
- let slice = buffer. as_slice :: < u8 > ( py) . unwrap ( ) ;
653
+ let slice = buffer. as_slice ( py) . unwrap ( ) ;
691
654
assert_eq ! ( slice. len( ) , 5 ) ;
692
655
assert_eq ! ( slice[ 0 ] . get( ) , b'a' ) ;
693
656
assert_eq ! ( slice[ 2 ] . get( ) , b'c' ) ;
694
657
695
- assert ! ( buffer. as_mut_slice:: <u8 >( py) . is_none( ) ) ;
696
-
697
658
assert ! ( buffer. copy_to_slice( py, & mut [ 0u8 ] ) . is_err( ) ) ;
698
659
let mut arr = [ 0 ; 5 ] ;
699
660
buffer. copy_to_slice ( py, & mut arr) . unwrap ( ) ;
700
661
assert_eq ! ( arr, b"abcde" as & [ u8 ] ) ;
701
662
702
663
assert ! ( buffer. copy_from_slice( py, & [ 0u8 ; 5 ] ) . is_err( ) ) ;
703
-
704
- assert ! ( buffer. to_vec:: <i8 >( py) . is_err( ) ) ;
705
- assert ! ( buffer. to_vec:: <u16 >( py) . is_err( ) ) ;
706
- assert_eq ! ( buffer. to_vec:: <u8 >( py) . unwrap( ) , b"abcde" ) ;
664
+ assert_eq ! ( buffer. to_vec( py) . unwrap( ) , b"abcde" ) ;
707
665
}
708
666
709
667
#[ allow( clippy:: float_cmp) ] // The test wants to ensure that no precision was lost on the Python round-trip
@@ -716,21 +674,18 @@ mod test {
716
674
. unwrap ( )
717
675
. call_method ( "array" , ( "f" , ( 1.0 , 1.5 , 2.0 , 2.5 ) ) , None )
718
676
. unwrap ( ) ;
719
- let buffer = PyBuffer :: get ( py , array) . unwrap ( ) ;
677
+ let buffer = PyBuffer :: get ( array) . unwrap ( ) ;
720
678
assert_eq ! ( buffer. dimensions( ) , 1 ) ;
721
679
assert_eq ! ( buffer. item_count( ) , 4 ) ;
722
680
assert_eq ! ( buffer. format( ) . to_str( ) . unwrap( ) , "f" ) ;
723
681
assert_eq ! ( buffer. shape( ) , [ 4 ] ) ;
724
682
725
- assert ! ( buffer. as_slice:: <f64 >( py) . is_none( ) ) ;
726
- assert ! ( buffer. as_slice:: <i32 >( py) . is_none( ) ) ;
727
-
728
- let slice = buffer. as_slice :: < f32 > ( py) . unwrap ( ) ;
683
+ let slice = buffer. as_slice ( py) . unwrap ( ) ;
729
684
assert_eq ! ( slice. len( ) , 4 ) ;
730
685
assert_eq ! ( slice[ 0 ] . get( ) , 1.0 ) ;
731
686
assert_eq ! ( slice[ 3 ] . get( ) , 2.5 ) ;
732
687
733
- let mut_slice = buffer. as_mut_slice :: < f32 > ( py) . unwrap ( ) ;
688
+ let mut_slice = buffer. as_mut_slice ( py) . unwrap ( ) ;
734
689
assert_eq ! ( mut_slice. len( ) , 4 ) ;
735
690
assert_eq ! ( mut_slice[ 0 ] . get( ) , 1.0 ) ;
736
691
mut_slice[ 3 ] . set ( 2.75 ) ;
@@ -741,6 +696,6 @@ mod test {
741
696
. unwrap ( ) ;
742
697
assert_eq ! ( slice[ 2 ] . get( ) , 12.0 ) ;
743
698
744
- assert_eq ! ( buffer. to_vec:: < f32 > ( py) . unwrap( ) , [ 10.0 , 11.0 , 12.0 , 13.0 ] ) ;
699
+ assert_eq ! ( buffer. to_vec( py) . unwrap( ) , [ 10.0 , 11.0 , 12.0 , 13.0 ] ) ;
745
700
}
746
701
}
0 commit comments