1
- // Copyright 2014-2016 bluss and ndarray developers.
1
+ // Copyright 2014-2020 bluss and ndarray developers.
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4
4
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -325,9 +325,9 @@ where
325
325
326
326
// Avoid initializing the memory in vec -- set it during iteration
327
327
unsafe {
328
- let mut c = Array :: uninitialized ( m) ;
329
- general_mat_vec_mul ( A :: one ( ) , self , rhs, A :: zero ( ) , & mut c ) ;
330
- c
328
+ let mut c = Array1 :: maybe_uninit ( m) ;
329
+ general_mat_vec_mul_impl ( A :: one ( ) , self , rhs, A :: zero ( ) , c . raw_view_mut ( ) . cast :: < A > ( ) ) ;
330
+ c. assume_init ( )
331
331
}
332
332
}
333
333
}
@@ -598,6 +598,30 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
598
598
S2 : Data < Elem = A > ,
599
599
S3 : DataMut < Elem = A > ,
600
600
A : LinalgScalar ,
601
+ {
602
+ unsafe {
603
+ general_mat_vec_mul_impl ( alpha, a, x, beta, y. raw_view_mut ( ) )
604
+ }
605
+ }
606
+
607
+ /// General matrix-vector multiplication
608
+ ///
609
+ /// Use a raw view for the destination vector, so that it can be uninitalized.
610
+ ///
611
+ /// ## Safety
612
+ ///
613
+ /// The caller must ensure that the raw view is valid for writing.
614
+ /// the destination may be uninitialized iff beta is zero.
615
+ unsafe fn general_mat_vec_mul_impl < A , S1 , S2 > (
616
+ alpha : A ,
617
+ a : & ArrayBase < S1 , Ix2 > ,
618
+ x : & ArrayBase < S2 , Ix1 > ,
619
+ beta : A ,
620
+ y : RawArrayViewMut < A , Ix1 > ,
621
+ ) where
622
+ S1 : Data < Elem = A > ,
623
+ S2 : Data < Elem = A > ,
624
+ A : LinalgScalar ,
601
625
{
602
626
let ( ( m, k) , k2) = ( a. dim ( ) , x. dim ( ) ) ;
603
627
let m2 = y. dim ( ) ;
@@ -626,22 +650,20 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
626
650
let x_stride = x. strides( ) [ 0 ] as blas_index;
627
651
let y_stride = y. strides( ) [ 0 ] as blas_index;
628
652
629
- unsafe {
630
- blas_sys:: $gemv(
631
- layout,
632
- a_trans,
633
- m as blas_index, // m, rows of Op(a)
634
- k as blas_index, // n, cols of Op(a)
635
- cast_as( & alpha) , // alpha
636
- a. ptr. as_ptr( ) as * const _, // a
637
- a_stride, // lda
638
- x. ptr. as_ptr( ) as * const _, // x
639
- x_stride,
640
- cast_as( & beta) , // beta
641
- y. ptr. as_ptr( ) as * mut _, // x
642
- y_stride,
643
- ) ;
644
- }
653
+ blas_sys:: $gemv(
654
+ layout,
655
+ a_trans,
656
+ m as blas_index, // m, rows of Op(a)
657
+ k as blas_index, // n, cols of Op(a)
658
+ cast_as( & alpha) , // alpha
659
+ a. ptr. as_ptr( ) as * const _, // a
660
+ a_stride, // lda
661
+ x. ptr. as_ptr( ) as * const _, // x
662
+ x_stride,
663
+ cast_as( & beta) , // beta
664
+ y. ptr. as_ptr( ) as * mut _, // x
665
+ y_stride,
666
+ ) ;
645
667
return ;
646
668
}
647
669
}
@@ -655,8 +677,9 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
655
677
/* general */
656
678
657
679
if beta. is_zero ( ) {
680
+ // when beta is zero, c may be uninitialized
658
681
Zip :: from ( a. outer_iter ( ) ) . and ( y) . apply ( |row, elt| {
659
- * elt = row. dot ( x) * alpha;
682
+ elt. write ( row. dot ( x) * alpha) ;
660
683
} ) ;
661
684
} else {
662
685
Zip :: from ( a. outer_iter ( ) ) . and ( y) . apply ( |row, elt| {
@@ -683,7 +706,7 @@ fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
683
706
#[ cfg( feature = "blas" ) ]
684
707
fn blas_compat_1d < A , S > ( a : & ArrayBase < S , Ix1 > ) -> bool
685
708
where
686
- S : Data ,
709
+ S : RawData ,
687
710
A : ' static ,
688
711
S :: Elem : ' static ,
689
712
{
0 commit comments