Skip to content

Commit e70a45f

Browse files
Implement loading/storing input/output vectors through groupshared memory and improved input vector/matrix test patterns
1 parent ad20ee6 commit e70a45f

File tree

2 files changed

+146
-76
lines changed

2 files changed

+146
-76
lines changed

tools/clang/unittests/HLSLExec/CoopVec.h

+63-27
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <DirectXPackedVector.h>
77

88
#include <cstdlib>
9+
#include <random>
910
#include <vector>
1011

1112
#include "dxc/Support/microcom.h"
@@ -358,6 +359,15 @@ GetMatrixSrcDataType(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation) {
358359
}
359360
}
360361

362+
bool IsIntegralDataType(D3D12_LINEAR_ALGEBRA_DATATYPE DataType) {
363+
return DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8 ||
364+
DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8 ||
365+
DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16 ||
366+
DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16 ||
367+
DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32 ||
368+
DataType == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32;
369+
}
370+
361371
struct TestVector {
362372
private:
363373
size_t NumVectors = 0;
@@ -534,39 +544,61 @@ struct TestVector {
534544
}
535545
}
536546

537-
template <typename T> void fillSimpleTestData() {
538-
// Create a vector of (1, 1, 0, ...)
547+
template <typename T>
548+
void fillSimpleTestData(D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
549+
std::mt19937 &Rnd) {
539550
for (size_t I = 0; I < NumVectors; ++I) {
540551
T *Vec = getVector<T>(I);
541552
for (size_t J = 0; J < VectorSize; ++J)
542-
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
543-
// Special case for HALF, which requires conversion from float
544-
Vec[J] = static_cast<T>(
545-
ConvertFloat32ToFloat16((J == 0 || J == 1) ? 1.0f : 0.0f));
553+
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF> ||
554+
std::is_same_v<T, float>) {
555+
float Elt = 0.0f;
556+
if (IsIntegralDataType(MatrixInterpretation)) {
557+
Elt = (float)(Rnd() & 0x7) - 3.0f;
558+
} else {
559+
Elt = ((float)(Rnd() & 0x3) - 1.0f) / 2.0f;
560+
}
561+
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
562+
Vec[J] = static_cast<T>(ConvertFloat32ToFloat16(Elt));
563+
} else {
564+
Vec[J] = static_cast<T>(Elt);
565+
}
546566
} else {
547-
Vec[J] = static_cast<T>((J == 0 || J == 1) ? 1 : 0);
567+
if constexpr (std::is_signed_v<T>) {
568+
Vec[J] = static_cast<T>((int32_t)(Rnd() & 0xf) - 8);
569+
} else {
570+
Vec[J] = static_cast<T>((uint32_t)(Rnd() & 0xf));
571+
}
548572
}
549573
}
550574
}
551575

552-
template <typename T> void fillAllOnesTestData() {
553-
// Create a vector of (1, 1, 1, ...)
576+
template <typename T> void FillSimpleMatrixTestData(std::mt19937 &Rnd) {
554577
for (size_t I = 0; I < NumVectors; ++I) {
555578
T *Vec = getVector<T>(I);
556579
for (size_t J = 0; J < VectorSize; ++J)
557580
if constexpr (std::is_same_v<T, DirectX::PackedVector::HALF>) {
558-
// Special case for HALF, which requires conversion from float
559-
Vec[J] = static_cast<T>(ConvertFloat32ToFloat16(1.0f));
581+
float Elt = ((float)(Rnd() & 0x3) - 1.0f) / 2.0f;
582+
Vec[J] = static_cast<T>(ConvertFloat32ToFloat16(Elt));
583+
} else if constexpr (std::is_same_v<T, float>) {
584+
float Elt = ((float)(Rnd() & 0x3) - 1.0f) / 2.0f;
585+
Vec[J] = static_cast<T>(Elt);
560586
} else {
561-
Vec[J] = static_cast<T>(1);
587+
if constexpr (std::is_signed_v<T>) {
588+
Vec[J] = static_cast<T>((int32_t)(Rnd() & 0xf) - 8);
589+
} else {
590+
Vec[J] = static_cast<T>((uint32_t)(Rnd() & 0xf));
591+
}
562592
}
563593
}
564594
}
565595

566596
static TestVector
567597
createSimpleTestVector(size_t NumVectors, size_t VectorSize,
568598
D3D12_LINEAR_ALGEBRA_DATATYPE DataType,
569-
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
599+
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
600+
D3D12_LINEAR_ALGEBRA_DATATYPE MatrixInterpretation,
601+
std::mt19937 &Rnd) {
570602
size_t ElementSize;
571603
switch (DataType) {
572604
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -600,35 +632,36 @@ struct TestVector {
600632
TestVector Vec(NumVectors, VectorSize, ElementSize);
601633
switch (DataType) {
602634
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
603-
Vec.fillSimpleTestData<int8_t>();
635+
Vec.fillSimpleTestData<int8_t>(MatrixInterpretation, Rnd);
604636
break;
605637
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8:
606-
Vec.fillSimpleTestData<uint8_t>();
638+
Vec.fillSimpleTestData<uint8_t>(MatrixInterpretation, Rnd);
607639
break;
608640
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT16:
609-
Vec.fillSimpleTestData<int16_t>();
641+
Vec.fillSimpleTestData<int16_t>(MatrixInterpretation, Rnd);
610642
break;
611643
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
612-
Vec.fillSimpleTestData<uint16_t>();
644+
Vec.fillSimpleTestData<uint16_t>(MatrixInterpretation, Rnd);
613645
break;
614646
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
615-
Vec.fillSimpleTestData<int32_t>();
647+
Vec.fillSimpleTestData<int32_t>(MatrixInterpretation, Rnd);
616648
break;
617649
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
618650
if (DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8_T4_PACKED ||
619651
DataInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_UINT8_T4_PACKED) {
620-
Vec.fillSimpleTestData<uint8_t>();
652+
Vec.fillSimpleTestData<uint8_t>(MatrixInterpretation, Rnd);
621653
} else {
622-
Vec.fillSimpleTestData<uint32_t>();
654+
Vec.fillSimpleTestData<uint32_t>(MatrixInterpretation, Rnd);
623655
}
624656
break;
625657
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
626658
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
627659
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
628-
Vec.fillSimpleTestData<DirectX::PackedVector::HALF>();
660+
Vec.fillSimpleTestData<DirectX::PackedVector::HALF>(MatrixInterpretation,
661+
Rnd);
629662
break;
630663
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
631-
Vec.fillSimpleTestData<float>();
664+
Vec.fillSimpleTestData<float>(MatrixInterpretation, Rnd);
632665
break;
633666
default:
634667
throw std::invalid_argument("Unsupported data type");
@@ -638,7 +671,8 @@ struct TestVector {
638671

639672
static TestVector
640673
createAllOnesTestMatrix(size_t NumVectors, size_t VectorSize,
641-
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation) {
674+
D3D12_LINEAR_ALGEBRA_DATATYPE DataInterpretation,
675+
std::mt19937 &Rnd) {
642676
size_t ElementSize;
643677
switch (DataInterpretation) {
644678
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8:
@@ -666,13 +700,13 @@ struct TestVector {
666700
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT16:
667701
case D3D12_LINEAR_ALGEBRA_DATATYPE_SINT32:
668702
case D3D12_LINEAR_ALGEBRA_DATATYPE_UINT32:
669-
Vec.fillAllOnesTestData<int8_t>();
703+
Vec.FillSimpleMatrixTestData<int8_t>(Rnd);
670704
break;
671705
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3:
672706
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2:
673707
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT16:
674708
case D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32:
675-
Vec.fillAllOnesTestData<float>();
709+
Vec.FillSimpleMatrixTestData<float>(Rnd);
676710
break;
677711
default:
678712
throw std::invalid_argument("Unsupported data type");
@@ -724,10 +758,12 @@ struct TestVector {
724758
ConvertInfo.DestInfo.NumColumns = (UINT)getVectorSize();
725759

726760
if (MatrixLayout == D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR) {
727-
ConvertInfo.DestInfo.DestStride = (UINT)getVectorSize() * DestEltSize;
761+
ConvertInfo.DestInfo.DestStride =
762+
((UINT)getVectorSize() * DestEltSize + 15) & ~15;
728763
} else if (MatrixLayout ==
729764
D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR) {
730-
ConvertInfo.DestInfo.DestStride = (UINT)getNumVectors() * DestEltSize;
765+
ConvertInfo.DestInfo.DestStride =
766+
((UINT)getNumVectors() * DestEltSize + 15) & ~15;
731767
}
732768

733769
// Get destination size using preview interface

0 commit comments

Comments
 (0)