Skip to content

Commit d5b61ba

Browse files
authored
Merge pull request #876 from rust-ndarray/uninit
Improve .assume_init() and use Array::maybe_uninit instead of Array::uninitialized
2 parents 79392bb + ba07345 commit d5b61ba

File tree

9 files changed

+157
-68
lines changed

9 files changed

+157
-68
lines changed

benches/bench1.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
extern crate test;
1111

12+
use std::mem::MaybeUninit;
13+
1214
use ndarray::ShapeBuilder;
1315
use ndarray::{arr0, arr1, arr2, azip, s};
1416
use ndarray::{Array, Array1, Array2, Axis, Ix, Zip};
@@ -269,9 +271,9 @@ fn add_2d_alloc_zip_uninit(bench: &mut test::Bencher) {
269271
let a = Array::<i32, _>::zeros((ADD2DSZ, ADD2DSZ));
270272
let b = Array::<i32, _>::zeros((ADD2DSZ, ADD2DSZ));
271273
bench.iter(|| unsafe {
272-
let mut c = Array::uninitialized(a.dim());
273-
azip!((&a in &a, &b in &b, c in c.raw_view_mut())
274-
std::ptr::write(c, a + b)
274+
let mut c = Array::<MaybeUninit<i32>, _>::maybe_uninit(a.dim());
275+
azip!((&a in &a, &b in &b, c in c.raw_view_mut().cast::<i32>())
276+
c.write(a + b)
275277
);
276278
c
277279
});

src/data_repr.rs

+21
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ use crate::extension::nonnull;
1010
/// *Don’t use this type directly—use the type alias
1111
/// [`Array`](type.Array.html) for the array type!*
1212
// Like a Vec, but with non-unique ownership semantics
13+
//
14+
// repr(C) to make it transmutable OwnedRepr<A> -> OwnedRepr<B> if
15+
// transmutable A -> B.
1316
#[derive(Debug)]
17+
#[repr(C)]
1418
pub struct OwnedRepr<A> {
1519
ptr: NonNull<A>,
1620
len: usize,
@@ -50,6 +54,23 @@ impl<A> OwnedRepr<A> {
5054
self.ptr
5155
}
5256

57+
/// Cast self into equivalent repr of other element type
58+
///
59+
/// ## Safety
60+
///
61+
/// Caller must ensure the two types have the same representation.
62+
/// **Panics** if sizes don't match (which is not a sufficient check).
63+
pub(crate) unsafe fn data_subst<B>(self) -> OwnedRepr<B> {
64+
// necessary but not sufficient check
65+
assert_eq!(mem::size_of::<A>(), mem::size_of::<B>());
66+
let self_ = ManuallyDrop::new(self);
67+
OwnedRepr {
68+
ptr: self_.ptr.cast::<B>(),
69+
len: self_.len,
70+
capacity: self_.capacity,
71+
}
72+
}
73+
5374
fn take_as_vec(&mut self) -> Vec<A> {
5475
let capacity = self.capacity;
5576
let len = self.len;

src/data_traits.rs

+32
Original file line numberDiff line numberDiff line change
@@ -526,28 +526,60 @@ unsafe impl<'a, A> DataMut for CowRepr<'a, A> where A: Clone {}
526526
pub trait RawDataSubst<A>: RawData {
527527
/// The resulting array storage of the same kind but substituted element type
528528
type Output: RawData<Elem = A>;
529+
530+
/// Unsafely translate the data representation from one element
531+
/// representation to another.
532+
///
533+
/// ## Safety
534+
///
535+
/// Caller must ensure the two types have the same representation.
536+
unsafe fn data_subst(self) -> Self::Output;
529537
}
530538

531539
impl<A, B> RawDataSubst<B> for OwnedRepr<A> {
532540
type Output = OwnedRepr<B>;
541+
542+
unsafe fn data_subst(self) -> Self::Output {
543+
self.data_subst()
544+
}
533545
}
534546

535547
impl<A, B> RawDataSubst<B> for OwnedArcRepr<A> {
536548
type Output = OwnedArcRepr<B>;
549+
550+
unsafe fn data_subst(self) -> Self::Output {
551+
OwnedArcRepr(Arc::from_raw(Arc::into_raw(self.0) as *const OwnedRepr<B>))
552+
}
537553
}
538554

539555
impl<A, B> RawDataSubst<B> for RawViewRepr<*const A> {
540556
type Output = RawViewRepr<*const B>;
557+
558+
unsafe fn data_subst(self) -> Self::Output {
559+
RawViewRepr::new()
560+
}
541561
}
542562

543563
impl<A, B> RawDataSubst<B> for RawViewRepr<*mut A> {
544564
type Output = RawViewRepr<*mut B>;
565+
566+
unsafe fn data_subst(self) -> Self::Output {
567+
RawViewRepr::new()
568+
}
545569
}
546570

547571
impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a A> {
548572
type Output = ViewRepr<&'a B>;
573+
574+
unsafe fn data_subst(self) -> Self::Output {
575+
ViewRepr::new()
576+
}
549577
}
550578

551579
impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a mut A> {
552580
type Output = ViewRepr<&'a mut B>;
581+
582+
unsafe fn data_subst(self) -> Self::Output {
583+
ViewRepr::new()
584+
}
553585
}

src/doc/ndarray_for_numpy_users/mod.rs

-1
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,6 @@
647647
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
648648
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
649649
//! [.t()]: ../../struct.ArrayBase.html#method.t
650-
//! [::uninitialized()]: ../../struct.ArrayBase.html#method.uninitialized
651650
//! [vec-* dot]: ../../struct.ArrayBase.html#method.dot
652651
//! [.visit()]: ../../struct.ArrayBase.html#method.visit
653652
//! [::zeros()]: ../../struct.ArrayBase.html#method.zeros

src/impl_special_element_types.rs

+2-18
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::mem::size_of;
10-
use std::mem::ManuallyDrop;
119
use std::mem::MaybeUninit;
1210

1311
use crate::imp_prelude::*;
@@ -37,12 +35,10 @@ where
3735
/// array's storage; it is for example possible to slice these in place, but that must
3836
/// only be done after all elements have been initialized.
3937
pub unsafe fn assume_init(self) -> ArrayBase<<S as RawDataSubst<A>>::Output, D> {
40-
// NOTE: Fully initialized includes elements not reachable in current slicing/view.
41-
4238
let ArrayBase { data, ptr, dim, strides } = self;
4339

44-
// transmute from storage of MaybeUninit<A> to storage of A
45-
let data = unlimited_transmute::<S, S::Output>(data);
40+
// "transmute" from storage of MaybeUninit<A> to storage of A
41+
let data = S::data_subst(data);
4642
let ptr = ptr.cast::<A>();
4743

4844
ArrayBase {
@@ -53,15 +49,3 @@ where
5349
}
5450
}
5551
}
56-
57-
/// Transmute from A to B.
58-
///
59-
/// Like transmute, but does not have the compile-time size check which blocks
60-
/// using regular transmute for "S to S::Output".
61-
///
62-
/// **Panics** if the size of A and B are different.
63-
unsafe fn unlimited_transmute<A, B>(data: A) -> B {
64-
assert_eq!(size_of::<A>(), size_of::<B>());
65-
let old_data = ManuallyDrop::new(data);
66-
(&*old_data as *const A as *const B).read()
67-
}

src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2014-2016 bluss and ndarray developers.
1+
// Copyright 2014-2020 bluss and ndarray developers.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -187,6 +187,7 @@ mod shape_builder;
187187
mod slice;
188188
mod split_at;
189189
mod stacking;
190+
mod traversal_utils;
190191
#[macro_use]
191192
mod zip;
192193

src/linalg/impl_linalg.rs

+45-22
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2014-2016 bluss and ndarray developers.
1+
// Copyright 2014-2020 bluss and ndarray developers.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -325,9 +325,9 @@ where
325325

326326
// Avoid initializing the memory in vec -- set it during iteration
327327
unsafe {
328-
let mut c = Array::uninitialized(m);
329-
general_mat_vec_mul(A::one(), self, rhs, A::zero(), &mut c);
330-
c
328+
let mut c = Array1::maybe_uninit(m);
329+
general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
330+
c.assume_init()
331331
}
332332
}
333333
}
@@ -598,6 +598,30 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
598598
S2: Data<Elem = A>,
599599
S3: DataMut<Elem = A>,
600600
A: LinalgScalar,
601+
{
602+
unsafe {
603+
general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
604+
}
605+
}
606+
607+
/// General matrix-vector multiplication
608+
///
609+
/// Use a raw view for the destination vector, so that it can be uninitalized.
610+
///
611+
/// ## Safety
612+
///
613+
/// The caller must ensure that the raw view is valid for writing.
614+
/// the destination may be uninitialized iff beta is zero.
615+
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
616+
alpha: A,
617+
a: &ArrayBase<S1, Ix2>,
618+
x: &ArrayBase<S2, Ix1>,
619+
beta: A,
620+
y: RawArrayViewMut<A, Ix1>,
621+
) where
622+
S1: Data<Elem = A>,
623+
S2: Data<Elem = A>,
624+
A: LinalgScalar,
601625
{
602626
let ((m, k), k2) = (a.dim(), x.dim());
603627
let m2 = y.dim();
@@ -626,22 +650,20 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
626650
let x_stride = x.strides()[0] as blas_index;
627651
let y_stride = y.strides()[0] as blas_index;
628652

629-
unsafe {
630-
blas_sys::$gemv(
631-
layout,
632-
a_trans,
633-
m as blas_index, // m, rows of Op(a)
634-
k as blas_index, // n, cols of Op(a)
635-
cast_as(&alpha), // alpha
636-
a.ptr.as_ptr() as *const _, // a
637-
a_stride, // lda
638-
x.ptr.as_ptr() as *const _, // x
639-
x_stride,
640-
cast_as(&beta), // beta
641-
y.ptr.as_ptr() as *mut _, // x
642-
y_stride,
643-
);
644-
}
653+
blas_sys::$gemv(
654+
layout,
655+
a_trans,
656+
m as blas_index, // m, rows of Op(a)
657+
k as blas_index, // n, cols of Op(a)
658+
cast_as(&alpha), // alpha
659+
a.ptr.as_ptr() as *const _, // a
660+
a_stride, // lda
661+
x.ptr.as_ptr() as *const _, // x
662+
x_stride,
663+
cast_as(&beta), // beta
664+
y.ptr.as_ptr() as *mut _, // x
665+
y_stride,
666+
);
645667
return;
646668
}
647669
}
@@ -655,8 +677,9 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
655677
/* general */
656678

657679
if beta.is_zero() {
680+
// when beta is zero, c may be uninitialized
658681
Zip::from(a.outer_iter()).and(y).apply(|row, elt| {
659-
*elt = row.dot(x) * alpha;
682+
elt.write(row.dot(x) * alpha);
660683
});
661684
} else {
662685
Zip::from(a.outer_iter()).and(y).apply(|row, elt| {
@@ -683,7 +706,7 @@ fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
683706
#[cfg(feature = "blas")]
684707
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
685708
where
686-
S: Data,
709+
S: RawData,
687710
A: 'static,
688711
S::Elem: 'static,
689712
{

src/stacking.rs

+24-23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2014-2016 bluss and ndarray developers.
1+
// Copyright 2014-2020 bluss and ndarray developers.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -8,6 +8,7 @@
88

99
use crate::error::{from_kind, ErrorKind, ShapeError};
1010
use crate::imp_prelude::*;
11+
use crate::traversal_utils::assign_to;
1112

1213
/// Stack arrays along the new axis.
1314
///
@@ -88,25 +89,23 @@ where
8889
let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
8990
res_dim.set_axis(axis, stacked_dim);
9091

91-
// we can safely use uninitialized values here because they are Copy
92-
// and we will only ever write to them
93-
let size = res_dim.size();
94-
let mut v = Vec::with_capacity(size);
95-
unsafe {
96-
v.set_len(size);
97-
}
98-
let mut res = Array::from_shape_vec(res_dim, v)?;
92+
// we can safely use uninitialized values here because we will
93+
// overwrite every one of them.
94+
let mut res = Array::maybe_uninit(res_dim);
9995

10096
{
10197
let mut assign_view = res.view_mut();
10298
for array in arrays {
10399
let len = array.len_of(axis);
104-
let (mut front, rest) = assign_view.split_at(axis, len);
105-
front.assign(array);
100+
let (front, rest) = assign_view.split_at(axis, len);
101+
assign_to(array, front);
106102
assign_view = rest;
107103
}
104+
debug_assert_eq!(assign_view.len(), 0);
105+
}
106+
unsafe {
107+
Ok(res.assume_init())
108108
}
109-
Ok(res)
110109
}
111110

112111
/// Stack arrays along the new axis.
@@ -158,22 +157,24 @@ where
158157

159158
res_dim.set_axis(axis, arrays.len());
160159

161-
// we can safely use uninitialized values here because they are Copy
162-
// and we will only ever write to them
163-
let size = res_dim.size();
164-
let mut v = Vec::with_capacity(size);
165-
unsafe {
166-
v.set_len(size);
167-
}
168-
let mut res = Array::from_shape_vec(res_dim, v)?;
160+
// we can safely use uninitialized values here because we will
161+
// overwrite every one of them.
162+
let mut res = Array::maybe_uninit(res_dim);
169163

170164
res.axis_iter_mut(axis)
171165
.zip(arrays.iter())
172-
.for_each(|(mut assign_view, array)| {
173-
assign_view.assign(&array);
166+
.for_each(|(assign_view, array)| {
167+
// assign_view is D::Larger::Smaller which is usually == D
168+
// (but if D is Ix6, we have IxD != Ix6 here; differing types
169+
// but same number of axes).
170+
let assign_view = assign_view.into_dimensionality::<D>()
171+
.expect("same-dimensionality cast");
172+
assign_to(array, assign_view);
174173
});
175174

176-
Ok(res)
175+
unsafe {
176+
Ok(res.assume_init())
177+
}
177178
}
178179

179180
/// Stack arrays along the new axis.

src/traversal_utils.rs

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2020 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::{
10+
IntoNdProducer,
11+
AssignElem,
12+
Zip,
13+
};
14+
15+
/// Assign values from producer P1 to producer P2
16+
/// P1 and P2 must be of the same shape and dimension
17+
pub(crate) fn assign_to<'a, P1, P2, A>(from: P1, to: P2)
18+
where P1: IntoNdProducer<Item = &'a A>,
19+
P2: IntoNdProducer<Dim = P1::Dim>,
20+
P2::Item: AssignElem<A>,
21+
A: Clone + 'a
22+
{
23+
Zip::from(from)
24+
.apply_assign_into(to, A::clone);
25+
}
26+

0 commit comments

Comments
 (0)