Skip to content

Commit

Permalink
CK-Tile Grouped GEMM refactor and post PR fixes (#1756)
Browse files Browse the repository at this point in the history
* Grouped gemm simple code refactor

* Offset invoker

* Invoke generic Run, and replace name of parrtitioner variable

* Tests fix type

* Removed namespaces

* Add template param to avoid implicit cast

* Remove generic function

* Constant value

* underline enum to int16_t

* Generalize partitioner function

* Remove whitespaces

* Rename function

* Using support

* Clang-format

* Clang-format

* Fn-partitioner description fn

* Typo

* Typo 2

* Better description

* Better description

* Refactor after review

* Use ctr instead of set fn

* Inovke ctr and typo

* Comments

* Remove unnecessary comment

* Review, remove modulo
  • Loading branch information
mozga-amd authored Jan 21, 2025
1 parent e7dce4d commit 3c93d3c
Show file tree
Hide file tree
Showing 17 changed files with 342 additions and 367 deletions.
8 changes: 4 additions & 4 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

Expand Down Expand Up @@ -49,7 +49,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
Expand All @@ -61,8 +61,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;

Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/03_gemm/universal_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;

using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
Expand Down
8 changes: 4 additions & 4 deletions example/ck_tile/16_batched_gemm/batched_gemm.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

Expand Down Expand Up @@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
Expand All @@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;

Expand Down
3 changes: 1 addition & 2 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"

namespace {

Expand Down Expand Up @@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
GemmEpilogue<CLayout>>;
}; // namespace

std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
}
Expand Down
8 changes: 4 additions & 4 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}

std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);

float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
20 changes: 10 additions & 10 deletions example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ float invoke_gemm(int n_warmup,
{

ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
gemm_workspace.Realloc(get_workspace_size(args));

float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
args,
Expand Down Expand Up @@ -128,16 +128,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];

stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{});

a_m_k_tensors.push_back(
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
b_k_n_tensors.push_back(
ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout)));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));

std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
Expand Down Expand Up @@ -178,7 +178,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
Expand Down
38 changes: 0 additions & 38 deletions example/ck_tile/17_grouped_gemm/utils.hpp

This file was deleted.

1 change: 0 additions & 1 deletion include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
Expand Down
57 changes: 51 additions & 6 deletions include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand All @@ -12,18 +12,37 @@

namespace ck_tile {

enum struct address_space_enum
template <typename, bool>
struct safe_underlying_type;

template <typename T>
struct safe_underlying_type<T, true>
{
using type = std::underlying_type_t<T>;
};

template <typename T>
struct safe_underlying_type<T, false>
{
using type = void;
};

template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;

enum struct address_space_enum : std::uint16_t
{
generic,
generic = 0,
global,
lds,
sgpr,
vgpr,
constant,
vgpr
};

enum struct memory_operation_enum
enum struct memory_operation_enum : std::uint16_t
{
set,
set = 0,
atomic_add,
atomic_max,
add
Expand Down Expand Up @@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif
}

#define CK_CONSTANT_ADDRESS_SPACE \
__attribute__((address_space( \
static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))

template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)(p); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

} // namespace ck_tile
37 changes: 0 additions & 37 deletions include/ck_tile/core/utility/amd_address_space.hpp

This file was deleted.

35 changes: 34 additions & 1 deletion include/ck_tile/host/host_tensor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -678,4 +678,37 @@ struct HostTensor
Descriptor mDesc;
Data mData;
};

template <typename TLayout>
auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
using namespace ck_tile::literals;

if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)

This comment has been minimized.

Copy link
@carlushuang

carlushuang Jan 22, 2025

Contributor

utilize the tensor_layout but not incude the corresponding header

{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
template <typename TLayout>
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
}

} // namespace ck_tile
9 changes: 6 additions & 3 deletions include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep

CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);

const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);

Expand Down
Loading

0 comments on commit 3c93d3c

Please sign in to comment.