Skip to content

Commit d674b5f

Browse files
authored
Merge pull request #952 from kngwyu/typed-pybuffer
Typed PyBuffer
2 parents be1b704 + 5939362 commit d674b5f

File tree

6 files changed

+129
-120
lines changed

6 files changed

+129
-120
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1111
### Changed
1212
- Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934)
1313
- Call `Py_Finalize` at exit to flush buffers, etc. [#943](https://github.com/PyO3/pyo3/pull/943)
14+
- Add type parameter to PyBuffer. #[951](https://github.com/PyO3/pyo3/pull/951)
1415

1516
### Removed
1617
- Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930)

src/buffer.rs

Lines changed: 65 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,22 @@
1818

1919
//! `PyBuffer` implementation
2020
use crate::err::{self, PyResult};
21-
use crate::{exceptions, ffi, AsPyPointer, PyAny, Python};
21+
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
2222
use std::ffi::CStr;
23+
use std::marker::PhantomData;
2324
use std::os::raw;
2425
use std::pin::Pin;
2526
use std::{cell, mem, ptr, slice};
2627

2728
/// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`.
2829
// use Pin<Box> because Python expects that the Py_buffer struct has a stable memory address
2930
#[repr(transparent)]
30-
pub struct PyBuffer(Pin<Box<ffi::Py_buffer>>);
31+
pub struct PyBuffer<T>(Pin<Box<ffi::Py_buffer>>, PhantomData<T>);
3132

3233
// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists.
3334
// Accessing the buffer contents is protected using the GIL.
34-
unsafe impl Send for PyBuffer {}
35-
unsafe impl Sync for PyBuffer {}
35+
unsafe impl<T> Send for PyBuffer<T> {}
36+
unsafe impl<T> Sync for PyBuffer<T> {}
3637

3738
#[derive(Copy, Clone, Eq, PartialEq)]
3839
pub enum ElementType {
@@ -146,29 +147,51 @@ fn is_matching_endian(c: u8) -> bool {
146147
}
147148

148149
/// Trait implemented for possible element types of `PyBuffer`.
149-
pub unsafe trait Element {
150+
pub unsafe trait Element: Copy {
150151
/// Gets whether the element specified in the format string is potentially compatible.
151152
/// Alignment and size are checked separately from this function.
152153
fn is_compatible_format(format: &CStr) -> bool;
153154
}
154155

155-
fn validate(b: &ffi::Py_buffer) {
156+
fn validate(b: &ffi::Py_buffer) -> PyResult<()> {
156157
// shape and stride information must be provided when we use PyBUF_FULL_RO
157-
assert!(!b.shape.is_null());
158-
assert!(!b.strides.is_null());
158+
if b.shape.is_null() {
159+
return Err(exceptions::BufferError::py_err("Shape is Null"));
160+
}
161+
if b.strides.is_null() {
162+
return Err(exceptions::BufferError::py_err("PyBuffer: Strides is Null"));
163+
}
164+
Ok(())
165+
}
166+
167+
impl<'source, T: Element> FromPyObject<'source> for PyBuffer<T> {
168+
fn extract(obj: &PyAny) -> PyResult<PyBuffer<T>> {
169+
Self::get(obj)
170+
}
159171
}
160172

161-
impl PyBuffer {
173+
impl<T: Element> PyBuffer<T> {
162174
/// Get the underlying buffer from the specified python object.
163-
pub fn get(py: Python, obj: &PyAny) -> PyResult<PyBuffer> {
175+
pub fn get(obj: &PyAny) -> PyResult<PyBuffer<T>> {
164176
unsafe {
165-
let mut buf = Box::pin(mem::zeroed::<ffi::Py_buffer>());
177+
let mut buf = Box::pin(ffi::Py_buffer::new());
166178
err::error_on_minusone(
167-
py,
179+
obj.py(),
168180
ffi::PyObject_GetBuffer(obj.as_ptr(), &mut *buf, ffi::PyBUF_FULL_RO),
169181
)?;
170-
validate(&buf);
171-
Ok(PyBuffer(buf))
182+
validate(&buf)?;
183+
let buf = PyBuffer(buf, PhantomData);
184+
// Type Check
185+
if mem::size_of::<T>() == buf.item_size()
186+
&& (buf.0.buf as usize) % mem::align_of::<T>() == 0
187+
&& T::is_compatible_format(buf.format())
188+
{
189+
Ok(buf)
190+
} else {
191+
Err(exceptions::BufferError::py_err(
192+
"Incompatible type as buffer",
193+
))
194+
}
172195
}
173196
}
174197

@@ -307,12 +330,8 @@ impl PyBuffer {
307330
///
308331
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
309332
/// to modify the values in the slice.
310-
pub fn as_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
311-
if mem::size_of::<T>() == self.item_size()
312-
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
313-
&& self.is_c_contiguous()
314-
&& T::is_compatible_format(self.format())
315-
{
333+
pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
334+
if self.is_c_contiguous() {
316335
unsafe {
317336
Some(slice::from_raw_parts(
318337
self.0.buf as *mut ReadOnlyCell<T>,
@@ -334,13 +353,8 @@ impl PyBuffer {
334353
///
335354
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
336355
/// to modify the values in the slice.
337-
pub fn as_mut_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
338-
if !self.readonly()
339-
&& mem::size_of::<T>() == self.item_size()
340-
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
341-
&& self.is_c_contiguous()
342-
&& T::is_compatible_format(self.format())
343-
{
356+
pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
357+
if !self.readonly() && self.is_c_contiguous() {
344358
unsafe {
345359
Some(slice::from_raw_parts(
346360
self.0.buf as *mut cell::Cell<T>,
@@ -361,15 +375,8 @@ impl PyBuffer {
361375
///
362376
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
363377
/// to modify the values in the slice.
364-
pub fn as_fortran_slice<'a, T: Element>(
365-
&'a self,
366-
_py: Python<'a>,
367-
) -> Option<&'a [ReadOnlyCell<T>]> {
368-
if mem::size_of::<T>() == self.item_size()
369-
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
370-
&& self.is_fortran_contiguous()
371-
&& T::is_compatible_format(self.format())
372-
{
378+
pub fn as_fortran_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
379+
if mem::size_of::<T>() == self.item_size() && self.is_fortran_contiguous() {
373380
unsafe {
374381
Some(slice::from_raw_parts(
375382
self.0.buf as *mut ReadOnlyCell<T>,
@@ -391,16 +398,8 @@ impl PyBuffer {
391398
///
392399
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
393400
/// to modify the values in the slice.
394-
pub fn as_fortran_mut_slice<'a, T: Element>(
395-
&'a self,
396-
_py: Python<'a>,
397-
) -> Option<&'a [cell::Cell<T>]> {
398-
if !self.readonly()
399-
&& mem::size_of::<T>() == self.item_size()
400-
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
401-
&& self.is_fortran_contiguous()
402-
&& T::is_compatible_format(self.format())
403-
{
401+
pub fn as_fortran_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
402+
if !self.readonly() && self.is_fortran_contiguous() {
404403
unsafe {
405404
Some(slice::from_raw_parts(
406405
self.0.buf as *mut cell::Cell<T>,
@@ -421,7 +420,7 @@ impl PyBuffer {
421420
/// To check whether the buffer format is compatible before calling this method,
422421
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
423422
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
424-
pub fn copy_to_slice<T: Element + Copy>(&self, py: Python, target: &mut [T]) -> PyResult<()> {
423+
pub fn copy_to_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
425424
self.copy_to_slice_impl(py, target, b'C')
426425
}
427426

@@ -434,28 +433,16 @@ impl PyBuffer {
434433
/// To check whether the buffer format is compatible before calling this method,
435434
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
436435
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
437-
pub fn copy_to_fortran_slice<T: Element + Copy>(
438-
&self,
439-
py: Python,
440-
target: &mut [T],
441-
) -> PyResult<()> {
436+
pub fn copy_to_fortran_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
442437
self.copy_to_slice_impl(py, target, b'F')
443438
}
444439

445-
fn copy_to_slice_impl<T: Element + Copy>(
446-
&self,
447-
py: Python,
448-
target: &mut [T],
449-
fort: u8,
450-
) -> PyResult<()> {
440+
fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
451441
if mem::size_of_val(target) != self.len_bytes() {
452442
return Err(exceptions::BufferError::py_err(
453443
"Slice length does not match buffer length.",
454444
));
455445
}
456-
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
457-
return incompatible_format_error();
458-
}
459446
unsafe {
460447
err::error_on_minusone(
461448
py,
@@ -473,23 +460,19 @@ impl PyBuffer {
473460
/// If the buffer is multi-dimensional, the elements are written in C-style order.
474461
///
475462
/// Fails if the buffer format is not compatible with type `T`.
476-
pub fn to_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
463+
pub fn to_vec(&self, py: Python) -> PyResult<Vec<T>> {
477464
self.to_vec_impl(py, b'C')
478465
}
479466

480467
/// Copies the buffer elements to a newly allocated vector.
481468
/// If the buffer is multi-dimensional, the elements are written in Fortran-style order.
482469
///
483470
/// Fails if the buffer format is not compatible with type `T`.
484-
pub fn to_fortran_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
471+
pub fn to_fortran_vec(&self, py: Python) -> PyResult<Vec<T>> {
485472
self.to_vec_impl(py, b'F')
486473
}
487474

488-
fn to_vec_impl<T: Element + Copy>(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
489-
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
490-
incompatible_format_error()?;
491-
unreachable!();
492-
}
475+
fn to_vec_impl(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
493476
let item_count = self.item_count();
494477
let mut vec: Vec<T> = Vec::with_capacity(item_count);
495478
unsafe {
@@ -520,7 +503,7 @@ impl PyBuffer {
520503
/// To check whether the buffer format is compatible before calling this method,
521504
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
522505
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
523-
pub fn copy_from_slice<T: Element + Copy>(&self, py: Python, source: &[T]) -> PyResult<()> {
506+
pub fn copy_from_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
524507
self.copy_from_slice_impl(py, source, b'C')
525508
}
526509

@@ -534,20 +517,11 @@ impl PyBuffer {
534517
/// To check whether the buffer format is compatible before calling this method,
535518
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
536519
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
537-
pub fn copy_from_fortran_slice<T: Element + Copy>(
538-
&self,
539-
py: Python,
540-
source: &[T],
541-
) -> PyResult<()> {
520+
pub fn copy_from_fortran_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
542521
self.copy_from_slice_impl(py, source, b'F')
543522
}
544523

545-
fn copy_from_slice_impl<T: Element + Copy>(
546-
&self,
547-
py: Python,
548-
source: &[T],
549-
fort: u8,
550-
) -> PyResult<()> {
524+
fn copy_from_slice_impl(&self, py: Python, source: &[T], fort: u8) -> PyResult<()> {
551525
if self.readonly() {
552526
return buffer_readonly_error();
553527
}
@@ -556,9 +530,6 @@ impl PyBuffer {
556530
"Slice length does not match buffer length.",
557531
));
558532
}
559-
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
560-
return incompatible_format_error();
561-
}
562533
unsafe {
563534
err::error_on_minusone(
564535
py,
@@ -589,19 +560,14 @@ impl PyBuffer {
589560
}
590561
}
591562

592-
fn incompatible_format_error() -> PyResult<()> {
593-
Err(exceptions::BufferError::py_err(
594-
"Slice type is incompatible with buffer format.",
595-
))
596-
}
597-
563+
#[inline(always)]
598564
fn buffer_readonly_error() -> PyResult<()> {
599565
Err(exceptions::BufferError::py_err(
600566
"Cannot write to read-only buffer.",
601567
))
602568
}
603569

604-
impl Drop for PyBuffer {
570+
impl<T> Drop for PyBuffer<T> {
605571
fn drop(&mut self) {
606572
let _gil_guard = Python::acquire_gil();
607573
unsafe { ffi::PyBuffer_Release(&mut *self.0) }
@@ -614,9 +580,9 @@ impl Drop for PyBuffer {
614580
/// The data cannot be modified through the reference, but other references may
615581
/// be modifying the data.
616582
#[repr(transparent)]
617-
pub struct ReadOnlyCell<T>(cell::UnsafeCell<T>);
583+
pub struct ReadOnlyCell<T: Element>(cell::UnsafeCell<T>);
618584

619-
impl<T: Copy> ReadOnlyCell<T> {
585+
impl<T: Element> ReadOnlyCell<T> {
620586
#[inline]
621587
pub fn get(&self) -> T {
622588
unsafe { *self.0.get() }
@@ -675,7 +641,7 @@ mod test {
675641
let gil = Python::acquire_gil();
676642
let py = gil.python();
677643
let bytes = py.eval("b'abcde'", None, None).unwrap();
678-
let buffer = PyBuffer::get(py, &bytes).unwrap();
644+
let buffer = PyBuffer::get(&bytes).unwrap();
679645
assert_eq!(buffer.dimensions(), 1);
680646
assert_eq!(buffer.item_count(), 5);
681647
assert_eq!(buffer.format().to_str().unwrap(), "B");
@@ -684,26 +650,18 @@ mod test {
684650
assert!(buffer.is_c_contiguous());
685651
assert!(buffer.is_fortran_contiguous());
686652

687-
assert!(buffer.as_slice::<f64>(py).is_none());
688-
assert!(buffer.as_slice::<i8>(py).is_none());
689-
690-
let slice = buffer.as_slice::<u8>(py).unwrap();
653+
let slice = buffer.as_slice(py).unwrap();
691654
assert_eq!(slice.len(), 5);
692655
assert_eq!(slice[0].get(), b'a');
693656
assert_eq!(slice[2].get(), b'c');
694657

695-
assert!(buffer.as_mut_slice::<u8>(py).is_none());
696-
697658
assert!(buffer.copy_to_slice(py, &mut [0u8]).is_err());
698659
let mut arr = [0; 5];
699660
buffer.copy_to_slice(py, &mut arr).unwrap();
700661
assert_eq!(arr, b"abcde" as &[u8]);
701662

702663
assert!(buffer.copy_from_slice(py, &[0u8; 5]).is_err());
703-
704-
assert!(buffer.to_vec::<i8>(py).is_err());
705-
assert!(buffer.to_vec::<u16>(py).is_err());
706-
assert_eq!(buffer.to_vec::<u8>(py).unwrap(), b"abcde");
664+
assert_eq!(buffer.to_vec(py).unwrap(), b"abcde");
707665
}
708666

709667
#[allow(clippy::float_cmp)] // The test wants to ensure that no precision was lost on the Python round-trip
@@ -716,21 +674,18 @@ mod test {
716674
.unwrap()
717675
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
718676
.unwrap();
719-
let buffer = PyBuffer::get(py, array).unwrap();
677+
let buffer = PyBuffer::get(array).unwrap();
720678
assert_eq!(buffer.dimensions(), 1);
721679
assert_eq!(buffer.item_count(), 4);
722680
assert_eq!(buffer.format().to_str().unwrap(), "f");
723681
assert_eq!(buffer.shape(), [4]);
724682

725-
assert!(buffer.as_slice::<f64>(py).is_none());
726-
assert!(buffer.as_slice::<i32>(py).is_none());
727-
728-
let slice = buffer.as_slice::<f32>(py).unwrap();
683+
let slice = buffer.as_slice(py).unwrap();
729684
assert_eq!(slice.len(), 4);
730685
assert_eq!(slice[0].get(), 1.0);
731686
assert_eq!(slice[3].get(), 2.5);
732687

733-
let mut_slice = buffer.as_mut_slice::<f32>(py).unwrap();
688+
let mut_slice = buffer.as_mut_slice(py).unwrap();
734689
assert_eq!(mut_slice.len(), 4);
735690
assert_eq!(mut_slice[0].get(), 1.0);
736691
mut_slice[3].set(2.75);
@@ -741,6 +696,6 @@ mod test {
741696
.unwrap();
742697
assert_eq!(slice[2].get(), 12.0);
743698

744-
assert_eq!(buffer.to_vec::<f32>(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
699+
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
745700
}
746701
}

0 commit comments

Comments
 (0)