Fix failing unit tests on BMG #324

Merged: 5 commits, Apr 25, 2025
32 changes: 18 additions & 14 deletions test/unit/cute/intel_xe/copy_block.cpp
@@ -42,6 +42,7 @@ using namespace cutlass;
 using namespace syclcompat::experimental;
 
 #define SUBGROUP_SIZE (16)
+constexpr int row_alignment = 16; // Alignment requirement for Xe 2D Block Copy Instructions
 
 template <class TensorS, class TensorD, class TiledLoad, class TiledStore,
           class CopyOp = void>
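[Reviewer note, not part of the diff] The 16 above is a byte count: per the comment, it is the alignment requirement for Xe 2D block copy instructions, and the hunks below convert it into a per-dtype element count used to pad each row's pitch. A minimal sketch of that arithmetic with an illustrative N (assumes <cstdint>):

    // Convert a byte alignment into an element alignment, then pad the
    // logical row length N up to a multiple of it.
    constexpr int row_alignment = 16;                                      // bytes
    constexpr int elem_alignment = row_alignment / sizeof(std::uint32_t); // 4 elements
    constexpr int N = 10;                                                 // logical columns
    constexpr int row_pitch = (N + elem_alignment - 1) / elem_alignment * elem_alignment; // 12
    static_assert(row_pitch * sizeof(std::uint32_t) % row_alignment == 0);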
@@ -265,31 +266,34 @@ struct copy_op<uint32_t, load, store, M, N, true> {
   // Allocate and initialize
   //
   using dtype = uint32_t;
-  cutlass::host_vector<dtype> host_src(M * N);
-  cutlass::host_vector<dtype> host_output(M * N);
-
-  for (size_t i = 0; i < host_src.size(); ++i) {
-    host_src[i] = static_cast<dtype>(i);
+  constexpr int elem_alignment = row_alignment / sizeof(dtype);
+  constexpr int row_pitch_S = cute::ceil_div(N, elem_alignment) * elem_alignment;
+  constexpr int row_pitch_D = cute::ceil_div(M, elem_alignment) * elem_alignment;
+  using TensorLayoutS = decltype(make_layout(Shape<Int<M>, Int<N>>{}, make_stride(Int<row_pitch_S>{}, _1{})));
+  using TensorLayoutD = decltype(make_layout(Shape<Int<N>, Int<M>>{}, make_stride(Int<row_pitch_D>{}, _1{})));
+
+  cutlass::host_vector<dtype> host_src(M * row_pitch_S);
+  cutlass::host_vector<dtype> host_output(N * row_pitch_D);
+
+  for (size_t i = 0; i < cute::cosize(TensorLayoutS{}); ++i) {
+    host_src[TensorLayoutS{}(i)] = static_cast<dtype>(i);
   }
 
   cutlass::device_vector<dtype> device_src = host_src;
   cutlass::device_vector<dtype> device_output = host_output;
 
-  Tensor S =
-      make_tensor(make_gmem_ptr(device_src.data()),
-                  make_layout(Shape<Int<M>, Int<N>>{}, Stride<Int<N>, _1>{}));
-  Tensor D =
-      make_tensor(make_gmem_ptr(device_output.data()),
-                  make_layout(Shape<Int<N>, Int<M>>{}, Stride<Int<M>, _1>{}));
+  Tensor S = make_tensor(make_gmem_ptr(device_src.data()), TensorLayoutS{});
+  Tensor D = make_tensor(make_gmem_ptr(device_output.data()), TensorLayoutD{});
 
   auto tiled_load = make_tiled_copy(
-      Copy_Atom<Copy_Traits<load, decltype(S)>, dtype>{}.with(device_src.data(), M, N),
+      Copy_Atom<Copy_Traits<load, decltype(S)>, dtype>{}.with(S),
       Layout<Shape<Int<SUBGROUP_SIZE>, _1>>{},
       make_layout(shape_div(typename Copy_Traits<load, decltype(S)>::BlockShape{}, Shape<_16, _1>{})));
   auto tiled_store = make_tiled_copy(
-      Copy_Atom<Copy_Traits<store, decltype(D)>, dtype>{}.with(device_output.data(), N, M),
+      Copy_Atom<Copy_Traits<store, decltype(D)>, dtype>{}.with(D),
       Layout<Shape<_1, Int<SUBGROUP_SIZE>>>{},
-      make_layout(shape_div(typename Copy_Traits<store, decltype(S)>::BlockShape{}, Shape<_1, _16>{})));
+      make_layout(shape_div(typename Copy_Traits<store, decltype(D)>::BlockShape{}, Shape<_1, _16>{})));
   auto blockDim = syclcompat::dim3(size(tiled_load));
   //
   // Launch the kernel
@@ -306,7 +310,7 @@ struct copy_op<uint32_t, load, store, M, N, true> {
   host_output = device_output;
   for (int i = 0; i < N; ++i) {
     for (int j = 0; j < M; ++j) {
-      EXPECT_EQ(host_output[i * M + j], host_src[j * N + i]);
+      EXPECT_EQ(host_output[i * row_pitch_D + j], host_src[j * row_pitch_S + i]);
     }
   }
 }
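[Reviewer note, not part of the diff] The verification change falls out of the pitched layouts: once rows are padded, the buffers are no longer dense, so CuTe's size() (logical element count) and cosize() (one past the largest storage offset) diverge, and flat indices must use the row pitch. A small sketch of the distinction (assumes <cute/tensor.hpp>, <cassert>, and using namespace cute):

    auto layout = make_layout(make_shape(8, 10),    // M = 8 rows, N = 10 columns
                              make_stride(12, 1));  // row pitch padded to 12
    assert(size(layout) == 80);    // logical elements: 8 * 10
    assert(cosize(layout) == 94);  // storage span: 7*12 + 9 + 1
    // Buffers are therefore sized M * row_pitch (96 here), and logical (i, j)
    // lives at flat offset i * 12 + j, which is what the new EXPECT_EQ indexes.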
25 changes: 13 additions & 12 deletions test/unit/cute/intel_xe/copy_subgroup_block.cpp
@@ -43,7 +43,7 @@ using namespace syclcompat::experimental;
 
 template <class TensorS, class TensorD, uint32_t wg_tile_m, uint32_t wg_tile_n,
           uint32_t sg_tile_m, uint32_t sg_tile_n>
-void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) {
+void copy_kernel_vectorized(TensorS S, TensorD D) {
   using namespace cute;
 
   using Element = typename TensorS::value_type;
@@ -158,15 +158,19 @@ bool copy(uint32_t M, uint32_t N) {
   // Given a 2D shape, perform an efficient copy
   //
 
+  constexpr int elem_alignment = 16 / sizeof(dtype);
+  int row_pitch = cute::ceil_div(N, elem_alignment) * elem_alignment;
+
   auto tensor_shape = make_shape(M, N);
+  auto tensor_layout = make_layout(tensor_shape, make_stride(row_pitch, 1));
   auto block_shape = make_shape(Int<wg_tile_m>{}, Int<wg_tile_n>{});
   auto subgroup_shape = make_shape(Int<sg_tile_m>{}, Int<sg_tile_n>{});
 
   //
   // Allocate and initialize
   //
-  cutlass::host_vector<dtype> host_src(size(tensor_shape));
-  cutlass::host_vector<dtype> host_output(size(tensor_shape));
+  cutlass::host_vector<dtype> host_src(cute::cosize(tensor_layout));
+  cutlass::host_vector<dtype> host_output(cute::cosize(tensor_layout));
 
   for (size_t i = 0; i < host_src.size(); ++i) {
     host_src[i] = static_cast<dtype>(i);
@@ -179,10 +183,8 @@ bool copy(uint32_t M, uint32_t N) {
   // Make tensors
   //
 
-  Tensor tensor_S = make_tensor(make_gmem_ptr(device_src.data()),
-                                make_layout(tensor_shape, make_stride(N, 1)));
-  Tensor tensor_D = make_tensor(make_gmem_ptr(device_output.data()),
-                                make_layout(tensor_shape, make_stride(N, 1)));
+  Tensor tensor_S = make_tensor(make_gmem_ptr(device_src.data()), tensor_layout);
+  Tensor tensor_D = make_tensor(make_gmem_ptr(device_output.data()), tensor_layout);
 
   //
   // Tile tensors
@@ -216,7 +218,7 @@ bool copy(uint32_t M, uint32_t N) {
                     wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n>>(
       launch_policy{gridDim, blockDim,
                     kernel_properties{sycl_exp::sub_group_size<SUBGROUP_SIZE>}},
-      tensor_S, tensor_D, M, N);
+      tensor_S, tensor_D);
 
   syclcompat::wait_and_throw();
 
@@ -226,22 +228,21 @@ bool copy(uint32_t M, uint32_t N) {
 
   host_output = device_output;
 
-  auto surface_pitch = N;
   for (int i = 0; i < sg_tile_m && i < M; i++) {
     for (int j = 0; j < sg_tile_n && j < N; j++) {
-      EXPECT_EQ(host_output[surface_pitch * i + j], surface_pitch * i + j);
+      EXPECT_EQ(host_output[row_pitch * i + j], row_pitch * i + j);
     }
   }
 
   for (int i = sg_tile_m; i < sg_tile_m + 1 && i < M; i++) {
     for (int j = 0; j < sg_tile_n && j < N; j++) {
-      EXPECT_NE(host_output[surface_pitch * i + j], surface_pitch * i + j);
+      EXPECT_NE(host_output[row_pitch * i + j], row_pitch * i + j);
     }
   }
 
   for (int i = 0; i < sg_tile_m && i < M; i++) {
     for (int j = sg_tile_n; j < sg_tile_n + 1 && j < N; j++) {
-      EXPECT_NE(host_output[surface_pitch * i + j], surface_pitch * i + j);
+      EXPECT_NE(host_output[row_pitch * i + j], row_pitch * i + j);
     }
   }
   return true;
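[Reviewer note, not part of the diff] Dropping M and N from copy_kernel_vectorized works because a CuTe tensor carries its extents in its layout, so passing them separately was redundant. A sketch of how a kernel can recover them (assumes CuTe's shape accessors; body elided):

    template <class TensorS, class TensorD>
    void copy_kernel_vectorized(TensorS S, TensorD D) {
      auto M = cute::size<0>(S.shape());  // rows from the layout
      auto N = cute::size<1>(S.shape());  // columns from the layout
      // ... tile S and D into workgroup/subgroup blocks and copy as before ...
    }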
4 changes: 2 additions & 2 deletions test/unit/gemm/device/gemm_testbed_3x.hpp
@@ -4018,7 +4018,7 @@ template <typename Gemm, template <class T> class ActivationFunctor =
 // TODO(Codeplay): remove the test_batch option once batching is enabled for all tests
 bool TestXe(
     double alpha = 1.0, double beta = 0.0,
-    bool test_batch = true, int max_alignment = 4,
+    bool test_batch = true, int max_alignment = 8,
     CheckEquality check_relative_equality = CheckEquality::RELATIVE) {
   using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar;
   using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
@@ -4040,7 +4040,7 @@ bool TestXe(
   std::vector<int> problem_size_l = test_batch ? std::vector{1, 3, 4} : std::vector{1};
 
   constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});
-  std::vector<int> problem_size_k{TileShapeK};
+  std::vector<int> problem_size_k{TileShapeK, TileShapeK*32};
 
   using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode;
   std::vector decomposition_modes = {DecompositionMode::Heuristic};
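[Reviewer note, not part of the diff] Testing K = TileShapeK exercises exactly one mainloop iteration; adding TileShapeK*32 also covers the multi-iteration path, presumably where the BMG failures surfaced. An illustrative sketch of the coverage difference (not the testbed's actual loop; the TileShapeK value is assumed):

    constexpr int TileShapeK = 32;                      // assumed tile depth
    for (int k : {TileShapeK, TileShapeK * 32}) {
      int k_tiles = (k + TileShapeK - 1) / TileShapeK;  // 1 vs. 32 iterations
      // ... build the problem shape {m, n, k, l} and run/verify the GEMM ...
    }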
@@ -131,7 +131,7 @@ TEST(XE_Device_GemmUniversal_f16t_s4n_f32t_mixed_input_tensor_op_f32, 128x128x64
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
   // TODO(Codeplay): gemm batch doesn't work for mixed type
-  bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false, 8);
+  bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false, 16);
   EXPECT_TRUE(passed);
 }
 ////////////////////////////////////////////////////////////////////////////////
@@ -131,7 +131,7 @@ TEST(XE_Device_GemmUniversal_f16t_s4t_f32t_mixed_input_tensor_op_f32, 128x128x64
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
   // TODO(Codeplay): gemm batch doesn't work for mixed type
-  bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false, 8);
+  bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0, false, 32);
   EXPECT_TRUE(passed);
 }
 ////////////////////////////////////////////////////////////////////////////////
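[Reviewer note, not part of the diff] The two bumps differ because they track the narrow operand's packing: if max_alignment is denominated in elements of the smallest type, 16 bytes hold 32 int4 values, hence 32 for the s4t case versus 16 for s4n. The arithmetic, as a sketch under that assumption:

    constexpr int row_alignment_bits = 16 * 8;          // 16 bytes
    constexpr int s4_elems  = row_alignment_bits / 4;   // 32 int4 elements
    constexpr int f16_elems = row_alignment_bits / 16;  // 8 half elements
    static_assert(s4_elems == 32 && f16_elems == 8);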
2 changes: 1 addition & 1 deletion test/unit/gemm/device/xe_gemm_s8_s8_s32_tensor_op_s32.cpp
@@ -65,7 +65,7 @@ TEST(XE_Device_Gemm_s8t_s8t_s32t_tensor_op_s32, 256x256x32) {
   using LayoutA = layout::RowMajor;
   using LayoutB = layout::RowMajor;
   using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32<LayoutA, LayoutB>::Gemm;
-  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
+  EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, true, 16));
 
 /* TODO(Codeplay): Transposed copy are not implemented
/* TODO(Codeplay): Transposed copy are not implemented
Expand Down