diff --git a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index a92005eb3..589b3b033 100644
--- a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -162,8 +162,8 @@ void GPUXToSPIRVPass::runOnOperation() {
       });
       typeConverter.addConversion(
           [&](xegpu::TensorDescType type) -> ::mlir::Type {
-            auto i64Type = ::mlir::IntegerType::get(context, 64);
-            return ::mlir::VectorType::get(2, i64Type);
+            auto i32Type = ::mlir::IntegerType::get(context, 32);
+            return ::mlir::VectorType::get(8, i32Type);
           });
       typeConverter.addConversion([&](::mlir::VectorType type) -> ::mlir::Type {
         unsigned rank = type.getRank();
diff --git a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
index 67f29668b..bf8700548 100644
--- a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
+++ b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
@@ -222,9 +222,11 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
 }
 
 /// @brief
-/// convert the tensor descriptor to [2xi64] which is of the format
-/// -> [base pointer: i64, offsetX: i32, offsetY: i32] for 2D tensor desc
-/// -> [base pointer: i64, unused] for 1D and scattered tensor desc
+/// assemble the tensor descriptor payload [8xi32], which has the format
+/// -> [base pointer, surface width, surface height, surface pitch,
+///     offsetX, offsetY, blockInfo] for 2D tensor desc
+/// -> [base pointer, unused] for 1D and scattered tensor desc
+/// only the base pointer is i64; all other fields are i32
 class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
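For orientation between the hunks: the lane layout that CreateNdDescToVCPattern assembles can be pictured as a plain C++ view of the v8i32 payload. A minimal sketch, assuming the lane/field mapping stated in the comment above; the struct and helper names are hypothetical and not part of this patch:

#include <cstdint>
#include <cstring>

struct NdDescPayload {               // hypothetical lane view of the v8i32 payload
  uint32_t lanes[8];
  void setBase(uint64_t base) {      // lanes 0-1 hold the single i64 base pointer
    std::memcpy(&lanes[0], &base, sizeof(base));
  }
  void setSurface(uint32_t widthBytes, uint32_t heightRows, uint32_t pitchBytes) {
    lanes[2] = widthBytes;           // surface width in bytes, minus 1
    lanes[3] = heightRows;           // surface height in rows, minus 1
    lanes[4] = pitchBytes;           // surface pitch in bytes, minus 1
  }
  void setOffsets(uint32_t x, uint32_t y) {
    lanes[5] = x;                    // offsetX
    lanes[6] = y;                    // offsetY
  }
  void setBlockInfo(uint32_t blockW, uint32_t blockH) {
    lanes[7] = ((blockH - 1) << 8) | (blockW - 1); // blockInfo
  }
};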
@@ -235,9 +237,9 @@ class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
     auto i32Type = rewriter.getI32Type();
     auto i64Type = rewriter.getI64Type();
     // payload
-    auto v4i32 = VectorType::get(4, i32Type);
-    auto v2i64 = VectorType::get(2, i64Type);
-    Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
+    auto v8i32 = VectorType::get(8, i32Type);
+    auto v4i64 = VectorType::get(4, i64Type);
+    Value payLoad = rewriter.create<spirv::UndefOp>(loc, v4i64);
     auto createIntConstant = [&](Type type, unsigned value) {
       auto attr = rewriter.getIntegerAttr(type, value);
       return rewriter.create<spirv::ConstantOp>(loc, type, attr);
@@ -247,10 +249,34 @@ class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
     };
     auto idx0 = createIntConstant(i32Type, 0);
     payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
+    payLoad = rewriter.create<spirv::BitcastOp>(loc, v8i32, payLoad);
     auto tileType = op.getTensorDesc().getType();
     auto rank = tileType.getRank();
     if (rank == 2) {
-      payLoad = rewriter.create<spirv::BitcastOp>(loc, v4i32, payLoad);
+      auto idx2 = createIntConstant(i32Type, 2);
+      auto idx3 = createIntConstant(i32Type, 3);
+      auto idx4 = createIntConstant(i32Type, 4);
+      auto idx5 = createIntConstant(i32Type, 5);
+      auto idx6 = createIntConstant(i32Type, 6);
+      auto idx7 = createIntConstant(i32Type, 7);
+      auto blockWidth = tileType.getShape()[1];
+      auto blockHeight = tileType.getShape()[0];
+      // fixme: support memref for now
+      auto memType = cast<MemRefType>(op.getSource().getType());
+      unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
+      auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
+      auto surfaceHeight = memType.getShape()[0] - 1;
+      // fixme: pitch = width for now
+      auto surfacePitch = surfaceWidth;
+      auto surfaceW = createIntConstant(i32Type, surfaceWidth);
+      auto surfaceH = createIntConstant(i32Type, surfaceHeight);
+      auto surfaceP = createIntConstant(i32Type, surfacePitch);
+      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
+                                                              surfaceW, idx2);
+      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
+                                                              surfaceH, idx3);
+      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
+                                                              surfaceP, idx4);
       auto createOffset = [&](unsigned idx) -> Value {
         Value val;
         if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) {
@@ -263,13 +289,14 @@ class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
       };
       auto offsetX = createOffset(1);
       auto offsetY = createOffset(0);
-      auto idx2 = createIntConstant(i32Type, 2);
-      auto idx3 = createIntConstant(i32Type, 3);
       payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              offsetX, idx2);
+                                                              offsetX, idx5);
       payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              offsetY, idx3);
-      payLoad = rewriter.create<spirv::BitcastOp>(loc, v2i64, payLoad);
+                                                              offsetY, idx6);
+      unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1);
+      auto blockInfo = createIntConstant(i32Type, blockVal);
+      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
+                                                              blockInfo, idx7);
     }
     rewriter.replaceOp(op, payLoad);
     return success();
   }
 };
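To make the encoding concrete, here are the constants this pattern produces for the shapes the new integration test uses: an 8x16 f16 tile over a static 1024x1024 memref. A standalone check with hand-computed numbers, not output of the pass:

#include <cassert>

int main() {
  unsigned bitWidth = 16;                            // f16 elements
  unsigned surfaceWidth = 1024 * (bitWidth / 8) - 1; // 2047 bytes
  unsigned surfaceHeight = 1024 - 1;                 // 1023 rows
  unsigned surfacePitch = surfaceWidth;              // pitch == width for now
  unsigned blockWidth = 16, blockHeight = 8;         // tile shape [0]=8, [1]=16
  unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1);
  assert(surfaceWidth == 2047 && surfaceHeight == 1023 && surfacePitch == 2047);
  assert(blockVal == 0x70f); // height-1 in bits 15:8, width-1 in bits 7:0
  return 0;
}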
@@ -283,32 +310,29 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern<UpdateNDOffsetOp> {
   matchAndRewrite(UpdateNDOffsetOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto desc = adaptor.getTensorDesc();
     auto i32Type = rewriter.getI32Type();
-    auto v4i32 = VectorType::get(4, i32Type);
-    auto v2i64 = VectorType::get(2, rewriter.getI64Type());
-    Value cast = rewriter.create<spirv::BitcastOp>(loc, v4i32, desc);
     auto offsets = adaptor.getOffsets();
+    auto desc = adaptor.getTensorDesc();
     for (auto i = 0; i < offsets.size(); i++) {
       auto offset = offsets[i];
       if (auto cst = dyn_cast<spirv::ConstantOp>(offset.getDefiningOp()))
         if (auto attr = dyn_cast<IntegerAttr>(cst.getValue());
             attr && attr.getInt() == 0)
           continue;
-      auto idx2 = rewriter.create<spirv::ConstantOp>(
-          loc, i32Type, rewriter.getIntegerAttr(i32Type, 2));
-      auto idx3 = rewriter.create<spirv::ConstantOp>(
-          loc, i32Type, rewriter.getIntegerAttr(i32Type, 3));
-      Value idx = i == 0 ? idx3 : idx2;
+      auto idx5 = rewriter.create<spirv::ConstantOp>(
+          loc, i32Type, rewriter.getIntegerAttr(i32Type, 5));
+      auto idx6 = rewriter.create<spirv::ConstantOp>(
+          loc, i32Type, rewriter.getIntegerAttr(i32Type, 6));
+      Value idx = i == 0 ? idx6 : idx5;
       auto oldOffset =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, cast, idx);
+          rewriter.create<spirv::VectorExtractDynamicOp>(loc, desc, idx);
       offset = rewriter.create<spirv::UConvertOp>(loc, i32Type, offset);
       auto newOffset =
           rewriter.create<spirv::IAddOp>(loc, i32Type, oldOffset, offset);
-      cast = rewriter.create<spirv::VectorInsertDynamicOp>(loc, v4i32, cast,
-                                                           newOffset, idx);
+      desc = rewriter.create<spirv::VectorInsertDynamicOp>(loc, desc, newOffset,
+                                                           idx);
     }
-    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(op, v2i64, cast);
+    rewriter.replaceOp(op, desc);
     return success();
   }
 };
@@ -322,14 +346,16 @@ class CreateDescToVCPattern : public OpConversionPattern<CreateDescOp> {
     auto loc = op.getLoc();
     auto i32Type = rewriter.getI32Type();
     auto i64Type = rewriter.getI64Type();
-    auto v2i64 = VectorType::get(2, i64Type);
-    Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
+    auto v8i32 = VectorType::get(8, i32Type);
+    auto v4i64 = VectorType::get(4, i64Type);
+    Value payLoad = rewriter.create<spirv::UndefOp>(loc, v4i64);
     auto base = rewriter.create<spirv::ConvertPtrToUOp>(loc, i64Type,
                                                         adaptor.getSource());
     auto idx0 = rewriter.create<spirv::ConstantOp>(
         loc, i32Type, rewriter.getIntegerAttr(i32Type, 0));
     payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
                                                             base, idx0);
+    payLoad = rewriter.create<spirv::BitcastOp>(loc, v8i32, payLoad);
     rewriter.replaceOp(op, payLoad);
     return success();
   }
@@ -369,6 +395,8 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
     auto i8Type = rewriter.getI8Type();
     auto i16Type = rewriter.getI16Type();
     auto i32Type = rewriter.getI32Type();
+    auto i64Type = rewriter.getI64Type();
+    auto v4i64 = VectorType::get(4, i64Type);
     auto vnni = false;
     auto transpose = false;
     if constexpr (isLoad) {
@@ -396,11 +424,9 @@
     auto nBlks = createIntConstant(i8Type, 1);
     auto tensorDesc = adaptor.getTensorDesc();
     auto idx0 = createIntConstant(i32Type, 0);
-    auto base =
-        rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
-    std::string typeStr;
-    VectorType newType;
-    std::tie(typeStr, newType) = encodeVectorType(rewriter, vecType, rank == 1);
+    auto cast = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
+    auto base = rewriter.create<spirv::VectorExtractDynamicOp>(loc, cast, idx0);
+    auto [typeStr, newType] = encodeVectorType(rewriter, vecType, rank == 1);
     SmallVector<Value> args;
     if (rank == 2) {
       auto blockWidth = tileType.getShape()[1];
@@ -411,7 +437,7 @@
       // static memref for now
       auto createDescOp =
           op.getTensorDesc().template getDefiningOp<CreateNdDescOp>();
-      auto memType = cast<MemRefType>(createDescOp.getSource().getType());
+      auto memType = llvm::cast<MemRefType>(createDescOp.getSource().getType());
       unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
       auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
       auto surfaceHeight = memType.getShape()[0] - 1;
@@ -420,14 +446,12 @@
       auto surfaceW = createIntConstant(i32Type, surfaceWidth);
       auto surfaceH = createIntConstant(i32Type, surfaceHeight);
       auto surfaceP = createIntConstant(i32Type, surfacePitch);
-      auto v4i32 = VectorType::get(4, i32Type);
-      tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
-      auto idx2 = createIntConstant(i32Type, 2);
-      auto idx3 = createIntConstant(i32Type, 3);
+      auto idx5 = createIntConstant(i32Type, 5);
+      auto idx6 = createIntConstant(i32Type, 6);
       auto offsetX =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
+          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx5);
       auto offsetY =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
+          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx6);
       args.assign({pred, l1CacheHint, l3CacheHint, dataum, trans, nBlks,
                    blockW, blockH, transform, base, surfaceW, surfaceH,
                    surfaceP, offsetX, offsetY});
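The rewritten UpdateNDOffsetToVCPattern boils down to a read-modify-write of payload lanes 5 and 6, with no bitcasts left in the loop. A plain-C++ stand-in for the extract/add/insert chain (the function name is hypothetical):

#include <cstdint>

// dim 0 steps the row offset (lane 6), dim 1 the column offset (lane 5),
// mirroring `Value idx = i == 0 ? idx6 : idx5;` in the pattern above.
// A constant step of 0 is skipped entirely, so stationary dims cost nothing.
void updateNdOffset(uint32_t lanes[8], int dim, uint32_t step) {
  unsigned lane = (dim == 0) ? 6 : 5;
  lanes[lane] += step;
}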
@@ -539,7 +563,6 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern<OpType> {
     auto i1Type = rewriter.getI1Type();
     auto i8Type = rewriter.getI8Type();
     auto i32Type = rewriter.getI32Type();
-    auto i64Type = rewriter.getI64Type();
     auto vnni = false;
     auto transpose = false;
     if constexpr (isLoad) {
@@ -607,66 +630,7 @@
       rawSendMsg |= 1 << 25;
     }
     auto msg = createIntConstant(i32Type, rawSendMsg);
-    // payload
-    // payload is v8i32 = [base:i64, surfaceWidth:i32, surfaceHeight:i32,
-    // surefacePitch:i32, offsetX:i32, offsetY:i32, blockInfo:i32]
-    // the base/surfaceInfo/blockInfo are staticly from the tensor desc
-    // while the offsetX/Y are dynamicly udpated
-    auto insertPoint = rewriter.saveInsertionPoint();
-    CreateNdDescOp createDescOp = *findDescOp(op.getTensorDesc());
-    rewriter.setInsertionPointAfter(createDescOp);
-    auto v8i32 = VectorType::get(8, i32Type);
-    auto v4i64 = VectorType::get(4, i64Type);
-    Value payLoad = rewriter.create<spirv::UndefOp>(loc, v4i64);
-    auto idx0 = createIntConstant(i32Type, 0);
-    auto desc = rewriter.getRemappedValue(createDescOp);
-    auto base = rewriter.create<spirv::VectorExtractDynamicOp>(loc, desc, idx0);
-    payLoad =
-        rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
-    payLoad = rewriter.create<spirv::BitcastOp>(loc, v8i32, payLoad);
-    if (rank == 2) {
-      auto idx2 = createIntConstant(i32Type, 2);
-      auto idx3 = createIntConstant(i32Type, 3);
-      auto idx4 = createIntConstant(i32Type, 4);
-      auto idx5 = createIntConstant(i32Type, 5);
-      auto idx6 = createIntConstant(i32Type, 6);
-      auto idx7 = createIntConstant(i32Type, 7);
-      auto blockWidth = tileType.getShape()[1];
-      auto blockHeight = tileType.getShape()[0];
-      // fixme: support memref for now
-      auto memType = cast<MemRefType>(createDescOp.getSource().getType());
-      unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
-      auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
-      auto surfaceHeight = memType.getShape()[0] - 1;
-      // fixme: pitch = width for now
-      auto surfacePitch = surfaceWidth;
-      auto surfaceW = createIntConstant(i32Type, surfaceWidth);
-      auto surfaceH = createIntConstant(i32Type, surfaceHeight);
-      auto surfaceP = createIntConstant(i32Type, surfacePitch);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              surfaceW, idx2);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              surfaceH, idx3);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              surfaceP, idx4);
-      unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1);
-      auto blockInfo = createIntConstant(i32Type, blockVal);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              blockInfo, idx7);
-      rewriter.restoreInsertionPoint(insertPoint);
-      auto v4i32 = VectorType::get(4, i32Type);
-      auto tensorDesc = adaptor.getTensorDesc();
-      tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
-      auto offsetX =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
-      auto offsetY =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              offsetX, idx5);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              offsetY, idx6);
-    }
-    rewriter.restoreInsertionPoint(insertPoint);
+    auto payLoad = adaptor.getTensorDesc();
     SmallVector<Value> args{modifier, execSize, pred, numSrc1, numDst,
                             sfid, extMsg, msg, payLoad};
     if constexpr (isLoad) {
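Since create_nd_tdesc now materializes the complete payload, the raw.send path simply forwards adaptor.getTensorDesc(), and patterns that only need the base address recover it by bitcasting the v8i32 payload to v4i64 and extracting lane 0. The same recovery in plain C++, for illustration only:

#include <cstdint>
#include <cstring>

uint64_t extractBase(const uint32_t lanes[8]) {
  uint64_t asI64[4];                 // the v8i32 -> v4i64 bitcast
  std::memcpy(asI64, lanes, sizeof(asI64));
  return asI64[0];                   // VectorExtractDynamic at index 0
}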
@@ -796,8 +760,10 @@ class GatherScatterToRawSend : public OpConversionPattern<OpType> {
     auto i8Type = rewriter.getI8Type();
     auto i32Type = rewriter.getI32Type();
     auto i64Type = rewriter.getI64Type();
+    auto v4i64 = VectorType::get(4, i64Type);
     auto tensorDesc = adaptor.getTensorDesc();
     auto idx0 = createIntConstant(i32Type, 0);
+    tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
     auto base =
         rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
     VectorType newType = VectorType::get(1, i32Type);
@@ -908,6 +874,7 @@ class AtomicToLsc : public OpConversionPattern<AtomicRMWOp> {
     auto i16Type = rewriter.getI16Type();
     auto i32Type = rewriter.getI32Type();
     auto i64Type = rewriter.getI64Type();
+    auto v4i64 = VectorType::get(4, i64Type);
     VectorType vecType = cast<VectorType>(op.getResult().getType());
     std::string funcName = "llvm_genx_lsc_xatomic_stateless_";
     auto [typeStr, newType] = encodeVectorType(rewriter, vecType, false, true);
@@ -936,6 +903,7 @@
     auto mask = createIntConstant(i8Type, 0);
 
     auto tensorDesc = adaptor.getTensorDesc();
+    tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
     auto idx0 = createIntConstant(i32Type, 0);
     auto base =
         rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
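The integration test below drives the new lowering through a GEMM whose k-loop advances both input descriptors with xegpu.update_nd_offset. Its launch geometry follows from the tile shape; a quick standalone check of the arithmetic behind the (%c128, %c64, %c1) grid (my numbers, not produced by the test runner):

#include <cassert>

int main() {
  int M = 1024, N = 1024;    // C is 1024x1024
  int tileM = 8, tileN = 16; // one 8x16 C tile per single-subgroup work-group
  assert(M / tileM == 128);  // blocks in x
  assert(N / tileN == 64);   // blocks in y
  return 0;
}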
diff --git a/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir
new file mode 100644
index 000000000..c58d2a6c1
--- /dev/null
+++ b/test/Integration/Dialect/XeGPU/gemm_1024x1024xf16.using.updateoffset.mlir
@@ -0,0 +1,111 @@
+// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
+// RUN:   --runner imex-cpu-runner -e main \
+// RUN:   --entry-point-result=void \
+// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
+// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
+// RUN:   --runner imex-cpu-runner -e main \
+// RUN:   --entry-point-result=void \
+// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
+module @gemm attributes {gpu.container_module} {
+  memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0>
+  memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0>
+  memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0>
+  func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} {
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c1 = arith.constant 1 : index
+    %memref = gpu.alloc host_shared () : memref<1024x1024xf16>
+    memref.copy %arg0, %memref : memref<1024x1024xf16> to memref<1024x1024xf16>
+    %memref_0 = gpu.alloc host_shared () : memref<1024x1024xf16>
+    memref.copy %arg1, %memref_0 : memref<1024x1024xf16> to memref<1024x1024xf16>
+    %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32>
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>)
+    gpu.dealloc %memref : memref<1024x1024xf16>
+    gpu.dealloc %memref_0 : memref<1024x1024xf16>
+    return %memref_1 : memref<1024x1024xf32>
+  }
+  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 128, 64, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %c0 = arith.constant 0 : index
+      %c16 = arith.constant 16 : index
+      %c8 = arith.constant 8 : index
+      %c1024 = arith.constant 1024 : index
+      %0 = gpu.block_id x
+      %1 = gpu.block_id y
+      %2 = arith.muli %0, %c8 : index
+      %3 = arith.muli %1, %c16 : index
+      %4 = xegpu.create_nd_tdesc %arg2[%2, %3] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %5 = xegpu.load_nd %4 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+      // each work-group has 1 subgroup. the subgroup calculates an [8x16 = 8x1024 * 1024x16] block
+      %7 = xegpu.create_nd_tdesc %arg0[%2, %c0] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>
+      %8 = xegpu.create_nd_tdesc %arg1[%c0, %3] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16>
+      %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5, %subA = %7, %subB = %8) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>) {
+        %9 = xegpu.load_nd %subA {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
+        %10 = xegpu.load_nd %subB {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+        %11 = xegpu.dpas %9, %10, %arg4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
+        %12 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+        %13 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+        scf.yield %11, %12, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>
+      }
+      xegpu.store_nd %6#0, %4 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      gpu.return
+    }
+  }
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %0 = memref.get_global @__constant_1024x1024xf16 : memref<1024x1024xf16>
+    %1 = memref.get_global @__constant_1024x1024xf16_ : memref<1024x1024xf16>
+    %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32>
+    %init = arith.constant 0.0 : f16
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    // fill the top-left 128x128 block
+    // A matrix: row-major, starts from 0.0, increases by 0.01 per element
+    // B matrix: A matrix + 1.0
+    scf.for %arg0 = %c0 to %c128 step %c1 {
+      scf.for %arg1 = %c0 to %c128 step %c1 {
+        %int0 = arith.index_cast %arg0 : index to i16
+        %int1 = arith.index_cast %arg1 : index to i16
+        %c128_i16 = arith.constant 128 : i16
+        %idx0 = arith.muli %int0, %c128_i16 : i16
+        %idx1 = arith.addi %int1, %idx0 : i16
+        %fp = arith.uitofp %idx1 : i16 to f16
+        %cst100 = arith.constant 100.0 : f16
+        %val0 = arith.divf %fp, %cst100 : f16
+        %cst1 = arith.constant 1.0 : f16
+        %val1 = arith.addf %val0, %cst1 : f16
+        memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xf16>
+        memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xf16>
+      }
+    }
+    // calculate the reference C matrix
+    scf.for %arg0 = %c0 to %c1024 step %c1 {
+      scf.for %arg1 = %c0 to %c1024 step %c1 {
+        %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32>
+        %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 {
+          %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16>
+          %b = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16>
+          %c = arith.mulf %a, %b : f16
+          %cc = arith.extf %c : f16 to f32
+          %ccc = arith.addf %cc, %arg3 : f32
+          scf.yield %ccc : f32
+        }
+        memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32>
+      }
+    }
+
+    %2 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32>
+    %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32>
+    //call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
+    %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32>
+    //call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> ()
+    // CHECK: [ALLCLOSE: TRUE]
+    call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> ()
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
+}
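The initialization loops in @main above are easier to scan in scalar form. A C++ rendering under the simplifying assumption that f16 rounding is ignored (the test itself computes the ramp in f16):

#include <vector>

int main() {
  std::vector<float> A(1024 * 1024, 0.0f), B(1024 * 1024, 0.0f);
  // only the top-left 128x128 corner is non-zero
  for (int i = 0; i < 128; ++i)
    for (int j = 0; j < 128; ++j) {
      float a = static_cast<float>(i * 128 + j) / 100.0f; // row-major ramp
      A[i * 1024 + j] = a;
      B[i * 1024 + j] = a + 1.0f; // B = A + 1.0
    }
  return 0;
}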
diff --git a/test/Integration/Dialect/XeGPU/lit.local.cfg b/test/Integration/Dialect/XeGPU/lit.local.cfg
index cf920ae42..e084b0d12 100644
--- a/test/Integration/Dialect/XeGPU/lit.local.cfg
+++ b/test/Integration/Dialect/XeGPU/lit.local.cfg
@@ -1,5 +1,6 @@
 local_excludes = [
     'gemm_1024x1024xf16.mlir',
+    'gemm_1024x1024xf16.using.updateoffset.mlir',
     'gemm_1024x1016x1016_f16_f16_f32.mlir',
     'load2d_dpas_store2d.mlir',
     'load2d-padding-f32.mlir',
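The CHECK line relies on printAllcloseF32 from the IMEX runner utilities to compare the GPU result with the f32 reference. An elementwise check of roughly this shape (the tolerances below are assumptions; the real ones live in the runner utils):

#include <cmath>
#include <cstddef>

bool allclose(const float *a, const float *b, size_t n,
              float rtol = 1e-3f, float atol = 1e-3f) {
  for (size_t i = 0; i < n; ++i)
    if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i]))
      return false; // would print [ALLCLOSE: FALSE]
  return true;      // would print [ALLCLOSE: TRUE]
}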