6
6
#include < DirectXPackedVector.h>
7
7
8
8
#include < cstdlib>
9
+ #include < random>
9
10
#include < vector>
10
11
11
12
#include " dxc/Support/microcom.h"
@@ -358,6 +359,15 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) {
358
359
}
359
360
}
360
361
362
+ bool IsIntegralDataType (D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
363
+ return DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 ||
364
+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 ||
365
+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 ||
366
+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 ||
367
+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 ||
368
+ DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32;
369
+ }
370
+
361
371
struct TestVector {
362
372
private:
363
373
size_t NumVectors = 0 ;
@@ -534,39 +544,61 @@ struct TestVector {
534
544
}
535
545
}
536
546
537
- template <typename T> void fillSimpleTestData () {
538
- // Create a vector of (1, 1, 0, ...)
547
+ template <typename T>
548
+ void fillSimpleTestData (D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
549
+ std::mt19937 &Rnd) {
539
550
for (size_t I = 0 ; I < NumVectors; ++I) {
540
551
T *Vec = getVector<T>(I);
541
552
for (size_t J = 0 ; J < VectorSize; ++J)
542
- if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
543
- // Special case for HALF, which requires conversion from float
544
- Vec[J] = static_cast <T>(
545
- ConvertFloat32ToFloat16 ((J == 0 || J == 1 ) ? 1 .0f : 0 .0f ));
553
+ if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF> ||
554
+ std::is_same_v<T, float >) {
555
+ float Elt = 0 .0f ;
556
+ if (IsIntegralDataType (MatrixInterpretation)) {
557
+ Elt = (float )(Rnd () & 0x7 ) - 3 .0f ;
558
+ } else {
559
+ Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
560
+ }
561
+ if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
562
+ Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (Elt));
563
+ } else {
564
+ Vec[J] = static_cast <T>(Elt);
565
+ }
546
566
} else {
547
- Vec[J] = static_cast <T>((J == 0 || J == 1 ) ? 1 : 0 );
567
+ if constexpr (std::is_signed_v<T>) {
568
+ Vec[J] = static_cast <T>((int32_t )(Rnd () & 0xf ) - 8 );
569
+ } else {
570
+ Vec[J] = static_cast <T>((uint32_t )(Rnd () & 0xf ));
571
+ }
548
572
}
549
573
}
550
574
}
551
575
552
- template <typename T> void fillAllOnesTestData () {
553
- // Create a vector of (1, 1, 1, ...)
576
+ template <typename T> void FillSimpleMatrixTestData (std::mt19937 &Rnd) {
554
577
for (size_t I = 0 ; I < NumVectors; ++I) {
555
578
T *Vec = getVector<T>(I);
556
579
for (size_t J = 0 ; J < VectorSize; ++J)
557
580
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
558
- // Special case for HALF, which requires conversion from float
559
- Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (1 .0f ));
581
+ float Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
582
+ Vec[J] = static_cast <T>(ConvertFloat32ToFloat16 (Elt));
583
+ } else if constexpr (std::is_same_v<T, float >) {
584
+ float Elt = ((float )(Rnd () & 0x3 ) - 1 .0f ) / 2 .0f ;
585
+ Vec[J] = static_cast <T>(Elt);
560
586
} else {
561
- Vec[J] = static_cast <T>(1 );
587
+ if constexpr (std::is_signed_v<T>) {
588
+ Vec[J] = static_cast <T>((int32_t )(Rnd () & 0xf ) - 8 );
589
+ } else {
590
+ Vec[J] = static_cast <T>((uint32_t )(Rnd () & 0xf ));
591
+ }
562
592
}
563
593
}
564
594
}
565
595
566
596
static TestVector
567
597
createSimpleTestVector (size_t NumVectors, size_t VectorSize,
568
598
D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
569
- D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
599
+ D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
600
+ D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
601
+ std::mt19937 &Rnd) {
570
602
size_t ElementSize;
571
603
switch (DataType) {
572
604
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -600,35 +632,36 @@ struct TestVector {
600
632
TestVector Vec (NumVectors, VectorSize, ElementSize);
601
633
switch (DataType) {
602
634
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
603
- Vec.fillSimpleTestData <int8_t >();
635
+ Vec.fillSimpleTestData <int8_t >(MatrixInterpretation, Rnd );
604
636
break ;
605
637
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
606
- Vec.fillSimpleTestData <uint8_t >();
638
+ Vec.fillSimpleTestData <uint8_t >(MatrixInterpretation, Rnd );
607
639
break ;
608
640
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
609
- Vec.fillSimpleTestData <int16_t >();
641
+ Vec.fillSimpleTestData <int16_t >(MatrixInterpretation, Rnd );
610
642
break ;
611
643
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
612
- Vec.fillSimpleTestData <uint16_t >();
644
+ Vec.fillSimpleTestData <uint16_t >(MatrixInterpretation, Rnd );
613
645
break ;
614
646
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
615
- Vec.fillSimpleTestData <int32_t >();
647
+ Vec.fillSimpleTestData <int32_t >(MatrixInterpretation, Rnd );
616
648
break ;
617
649
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
618
650
if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
619
651
DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
620
- Vec.fillSimpleTestData <uint8_t >();
652
+ Vec.fillSimpleTestData <uint8_t >(MatrixInterpretation, Rnd );
621
653
} else {
622
- Vec.fillSimpleTestData <uint32_t >();
654
+ Vec.fillSimpleTestData <uint32_t >(MatrixInterpretation, Rnd );
623
655
}
624
656
break ;
625
657
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
626
658
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
627
659
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
628
- Vec.fillSimpleTestData <DirectX::PackedVector::HALF>();
660
+ Vec.fillSimpleTestData <DirectX::PackedVector::HALF>(MatrixInterpretation,
661
+ Rnd);
629
662
break ;
630
663
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
631
- Vec.fillSimpleTestData <float >();
664
+ Vec.fillSimpleTestData <float >(MatrixInterpretation, Rnd );
632
665
break ;
633
666
default :
634
667
throw std::invalid_argument (" Unsupported data type" );
@@ -638,7 +671,8 @@ struct TestVector {
638
671
639
672
static TestVector
640
673
createAllOnesTestMatrix (size_t NumVectors, size_t VectorSize,
641
- D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
674
+ D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
675
+ std::mt19937 &Rnd) {
642
676
size_t ElementSize;
643
677
switch (DataInterpretation) {
644
678
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -666,13 +700,13 @@ struct TestVector {
666
700
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
667
701
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
668
702
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
669
- Vec.fillAllOnesTestData <int8_t >();
703
+ Vec.FillSimpleMatrixTestData <int8_t >(Rnd );
670
704
break ;
671
705
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
672
706
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
673
707
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
674
708
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
675
- Vec.fillAllOnesTestData <float >();
709
+ Vec.FillSimpleMatrixTestData <float >(Rnd );
676
710
break ;
677
711
default :
678
712
throw std::invalid_argument (" Unsupported data type" );
@@ -724,10 +758,12 @@ struct TestVector {
724
758
ConvertInfo.DestInfo .NumColumns = (UINT)getVectorSize ();
725
759
726
760
if (MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) {
727
- ConvertInfo.DestInfo .DestStride = (UINT)getVectorSize () * DestEltSize;
761
+ ConvertInfo.DestInfo .DestStride =
762
+ ((UINT)getVectorSize () * DestEltSize + 15 ) & ~15 ;
728
763
} else if (MatrixLayout ==
729
764
D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) {
730
- ConvertInfo.DestInfo .DestStride = (UINT)getNumVectors () * DestEltSize;
765
+ ConvertInfo.DestInfo .DestStride =
766
+ ((UINT)getNumVectors () * DestEltSize + 15 ) & ~15 ;
731
767
}
732
768
733
769
// Get destination size using preview interface
0 commit comments