From f628befb18b261706144190dc204c3394f8ec64b Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 16 Jan 2025 21:13:12 -0800 Subject: [PATCH] update --- benchmarks/pvc/benchmarks.hpp | 9 ++++++--- include/cute/arch/xe_copy_1B.hpp | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index 875e641fd3..90c2ba07f8 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -94,7 +94,8 @@ using PvcGemmBF16BF16FP32_RCR_6 = cutlass::gemm::device::GemmConfiguration< float, cutlass::layout::RowMajor, float, Shape<_8, _128, _32>, TiledMMA>>, - XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T>; + XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T, + Scheduler::Gemm>; using PvcGemmBF16BF16FP32_CRR_7 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelPVC, @@ -103,7 +104,8 @@ using PvcGemmBF16BF16FP32_CRR_7 = cutlass::gemm::device::GemmConfiguration< float, cutlass::layout::RowMajor, float, Shape<_8, _128, _32>, TiledMMA>>, - XE_2D_U16x16x16_LD_T, XE_2D_U16x32x32_LD_V>; + XE_2D_U16x16x16_LD_T, XE_2D_U16x32x32_LD_V, + Scheduler::Gemm>; using PvcGemmBF16BF16FP32_CCR_8 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelPVC, @@ -112,7 +114,8 @@ using PvcGemmBF16BF16FP32_CCR_8 = cutlass::gemm::device::GemmConfiguration< float, cutlass::layout::RowMajor, float, Shape<_8, _128, _32>, TiledMMA>>, - XE_2D_U16x16x16_LD_T, XE_2D_U16x16x16_LD_T>; + XE_2D_U16x16x16_LD_T, XE_2D_U16x16x16_LD_T, + Scheduler::Gemm>; CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); diff --git a/include/cute/arch/xe_copy_1B.hpp b/include/cute/arch/xe_copy_1B.hpp index f647507d1d..ce8193d350 100644 --- a/include/cute/arch/xe_copy_1B.hpp +++ b/include/cute/arch/xe_copy_1B.hpp @@ -394,6 +394,8 @@ struct XE_2D_U8x1x64_LD_N { }; struct XE_2D_U8x2x64_LD_N { + using BlockShape = Shape<_2, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -521,6 +523,8 @@ struct XE_2D_U8x16x64_LD_N { }; struct XE_2D_U8x32x64_LD_N { + using BlockShape = Shape<_32, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -554,6 +558,8 @@ struct XE_2D_U8x32x64_LD_N { struct XE_2D_U8x32x16_LD_V { + using BlockShape = Shape<_32, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -602,6 +608,8 @@ struct XE_2D_U8x32x32_LD_V { }; struct XE_2D_U8x32x64_LD_V { + using BlockShape = Shape<_32, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -618,6 +626,8 @@ struct XE_2D_U8x32x64_LD_V { }; struct XE_2D_U8x1x16_ST_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -634,6 +644,8 @@ struct XE_2D_U8x1x16_ST_N { }; struct XE_2D_U8x2x16_ST_N { + using BlockShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -650,6 +662,8 @@ struct XE_2D_U8x2x16_ST_N { }; struct XE_2D_U8x4x16_ST_N { + using BlockShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord,