[xegpu-to-vc-func] Is transpose_bit_width=16 supported? #895

Open
dchigarev opened this issue Sep 25, 2024 · 1 comment
Labels: question (Further information is requested)

Comments

dchigarev (Contributor) commented Sep 25, 2024

I'm trying to transpose a 16x16xf16 matrix using xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32}, but the values come out transposed at 32-bit granularity even though transpose_bit_width=16 was requested. Is this expected behavior or a bug?
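
The pattern in the output below is exactly what a transpose at 32-bit granularity produces: adjacent f16 pairs travel together as one 32-bit unit, so the 16x16xf16 matrix is effectively transposed as a 16x8 matrix of 32-bit elements. A minimal NumPy sketch (host-side illustration only, not IMEX code) that reproduces the observed layout:

import numpy as np

# 16x16 f16 matrix with distinct values.
a = np.arange(256, dtype=np.float16).reshape(16, 16)

# Element-wise transpose: what transpose_bit_width=16 should produce.
t16 = a.T

# 32-bit-granularity transpose: view adjacent f16 pairs as one 32-bit unit,
# transpose the units, then reinterpret the buffer as 16x16 f16 row-major.
packed = a.view(np.uint32)                             # shape (16, 8)
t32 = packed.T.copy().view(np.float16).reshape(16, 16)

print(t32[0, :4])  # a[0,0], a[0,1], a[1,0], a[1,1] -- the pattern in the output below
print(t16[0, :4])  # a[0,0], a[1,0], a[2,0], a[3,0]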

Reproducer
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp  \
// RUN:                                       --runner imex-cpu-runner -e main \
// RUN:                                       --entry-point-result=void \
// RUN:                                       --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime

module attributes {gpu.container_module} {
  gpu.module @transpose_16bit_loadnd attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Bfloat16ConversionINTEL, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, StorageBuffer16BitAccess, VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_bfloat16_conversion, SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_KHR_16bit_storage, SPV_NV_cooperative_matrix, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @transpose_16bit_loadnd(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 2, 2, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
  
      %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>

      // Load with the transpose requested at 16-bit granularity (transpose_bit_width = 16).
      %2 = xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
      xegpu.store_nd %2, %1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
      gpu.return
    }
  }

  func.func @main() {
    %c1 = arith.constant 1 : index // grid/block size used by gpu.launch_func below
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32

    %result = memref.alloc() : memref<16x16xf16>
    %resultc = memref.alloc() : memref<16x16xf16>
    %r_r = memref.cast %result : memref<16x16xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%r_r, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()

    %cast2 = memref.cast %result : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%cast2) : (memref<*xf16>) -> ()

    %gpu_result_index = gpu.alloc host_shared () : memref<16x16xf16>
    %gpu_result = gpu.alloc host_shared () : memref<16x16xf16>

    memref.copy %result, %gpu_result_index : memref<16x16xf16> to memref<16x16xf16>

    gpu.launch_func @transpose_16bit_loadnd::@transpose_16bit_loadnd blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)  args(%gpu_result_index : memref<16x16xf16>, %gpu_result : memref<16x16xf16>)
    memref.copy %gpu_result, %resultc : memref<16x16xf16> to memref<16x16xf16>
    %cast1 = memref.cast %resultc : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%cast1) : (memref<*xf16>) -> ()

    return
  }
  func.func private @printMemrefF16(memref<*xf16>)
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
}
Output
Original matrix:
[[-0.0335999,   -0.108459,   0.454346,   0.173096,   0.291992,   -0.437744,   0.150879,   0.243774,   -0.118896,   0.390625,   -0.337402,   0.184204,   0.148682,   0.109863,   0.131592,   0.167603], 
 [-0.383789,   0.182007,   0.157837,   0.016922,   0.403564,   -0.355957,   -0.465576,   -0.0371094,   -0.167603,   -0.0213776,   -0.107849,   -0.364502,   -0.49292,   -0.40625,   -0.474121,   0.259277], 
 [0.212158,   -0.0585938,   0.307861,   0.357178,   -0.0243073,   -0.301514,   0.157715,   0.0397949,   -0.115845,   -0.0805054,   0.354248,   0.288818,   -0.387695,   0.265137,   -0.191528,   0.23584], 
 [0.186035,   0.13623,   0.164795,   0.321777,   -0.131348,   0.189575,   0.437744,   -0.437256,   -0.488281,   0.104675,   0.223145,   0.468994,   0.471436,   0.289551,   -0.388184,   0.24231], 
 [0.152832,   -0.233521,   0.0818481,   -0.445312,   -0.0191803,   0.349854,   0.472168,   -0.358398,   -0.220459,   0.244751,   -0.0543518,   0.000132799,   0.288086,   0.0359192,   -0.0933838,   0.165527], 
 [-0.0643311,   -0.368896,   0.398438,   0.125854,   0.174438,   0.010788,   0.0161896,   -0.0637817,   -0.450928,   -0.256104,   -0.0791016,   0.197266,   -0.274658,   -0.172607,   -0.0960693,   0.376221], 
 [0.326416,   0.428223,   0.0844116,   -0.111023,   0.288574,   -0.287109,   0.147461,   0.489258,   -0.109314,   0.0188751,   -0.375732,   0.175903,   -0.309082,   -0.172852,   -0.499756,   -0.102051], 
 [-0.395508,   -0.160034,   -0.210571,   0.429688,   -0.302246,   -0.0577393,   -0.0242767,   -0.174194,   0.21228,   0.110107,   0.34082,   0.348877,   -0.255371,   0.156738,   0.143066,   -0.0538025], 
 [0.43042,   -0.496338,   0.0446472,   0.376465,   -0.153564,   -0.231934,   0.322266,   -0.2771,   -0.272949,   0.0265045,   0.293457,   -0.207886,   -0.248657,   -0.244141,   0.118164,   -0.167969], 
 [-0.318359,   0.33252,   0.192261,   -0.403564,   0.23877,   0.078064,   0.400391,   -0.290771,   -0.12323,   0.0836182,   0.265381,   -0.337891,   -0.431396,   0.262207,   0.0490723,   0.0157623], 
 [0.310059,   -0.481201,   -0.0360413,   0.371582,   0.39624,   0.413086,   0.307861,   0.499756,   0.0454102,   -0.2771,   0.352783,   -0.0714111,   0.184082,   0.4729,   -0.0998535,   -0.420166], 
 [0.445312,   -0.265381,   -0.182983,   -0.249146,   -0.437256,   -0.298828,   -0.418701,   -0.0429688,   0.0679932,   -0.256836,   -0.38208,   0.378174,   0.0784302,   -0.149658,   0.232544,   -0.249634], 
 [-0.318115,   -0.179443,   -0.33667,   -0.43335,   -0.129028,   0.0360718,   -0.48999,   0.333984,   0.356934,   0.238159,   -0.198608,   0.0809326,   0.0897827,   0.209839,   -0.0469055,   -0.409668], 
 [0.345947,   -0.0787354,   0.0486755,   0.098938,   0.0684204,   0.227295,   0.0414429,   0.465576,   0.204834,   0.419189,   0.297119,   -0.347412,   -0.0586853,   0.239746,   0.174805,   0.0572205], 
 [0.261963,   -0.0251923,   0.481201,   -0.470703,   -0.0614014,   0.305176,   -0.439209,   -0.0430603,   0.346924,   -0.411377,   0.00965118,   0.00774765,   -0.378906,   0.466309,   -0.13623,   -0.0748901], 
 [-0.241821,   0.000869751,   -0.336914,   -0.0773926,   0.469238,   -0.218994,   0.362305,   0.00957489,   -0.297852,   0.0365906,   -0.382568,   0.308594,   0.134277,   -0.322998,   -0.445557,   0.158325]]


Transposed matrix (requested with transpose_bit_width=16, but the result looks like a 32-bit transpose):
[[-0.0335999,   -0.108459,   -0.383789,   0.182007,   0.212158,   -0.0585938,   0.186035,   0.13623,   0.152832,   -0.233521,   -0.0643311,   -0.368896,   0.326416,   0.428223,   -0.395508,   -0.160034], 
 [0.43042,   -0.496338,   -0.318359,   0.33252,   0.310059,   -0.481201,   0.445312,   -0.265381,   -0.318115,   -0.179443,   0.345947,   -0.0787354,   0.261963,   -0.0251923,   -0.241821,   0.000869751], 
 [0.454346,   0.173096,   0.157837,   0.016922,   0.307861,   0.357178,   0.164795,   0.321777,   0.0818481,   -0.445312,   0.398438,   0.125854,   0.0844116,   -0.111023,   -0.210571,   0.429688], 
 [0.0446472,   0.376465,   0.192261,   -0.403564,   -0.0360413,   0.371582,   -0.182983,   -0.249146,   -0.33667,   -0.43335,   0.0486755,   0.098938,   0.481201,   -0.470703,   -0.336914,   -0.0773926], 
 [0.291992,   -0.437744,   0.403564,   -0.355957,   -0.0243073,   -0.301514,   -0.131348,   0.189575,   -0.0191803,   0.349854,   0.174438,   0.010788,   0.288574,   -0.287109,   -0.302246,   -0.0577393], 
 [-0.153564,   -0.231934,   0.23877,   0.078064,   0.39624,   0.413086,   -0.437256,   -0.298828,   -0.129028,   0.0360718,   0.0684204,   0.227295,   -0.0614014,   0.305176,   0.469238,   -0.218994], 
 [0.150879,   0.243774,   -0.465576,   -0.0371094,   0.157715,   0.0397949,   0.437744,   -0.437256,   0.472168,   -0.358398,   0.0161896,   -0.0637817,   0.147461,   0.489258,   -0.0242767,   -0.174194], 
 [0.322266,   -0.2771,   0.400391,   -0.290771,   0.307861,   0.499756,   -0.418701,   -0.0429688,   -0.48999,   0.333984,   0.0414429,   0.465576,   -0.439209,   -0.0430603,   0.362305,   0.00957489], 
 [-0.118896,   0.390625,   -0.167603,   -0.0213776,   -0.115845,   -0.0805054,   -0.488281,   0.104675,   -0.220459,   0.244751,   -0.450928,   -0.256104,   -0.109314,   0.0188751,   0.21228,   0.110107], 
 [-0.272949,   0.0265045,   -0.12323,   0.0836182,   0.0454102,   -0.2771,   0.0679932,   -0.256836,   0.356934,   0.238159,   0.204834,   0.419189,   0.346924,   -0.411377,   -0.297852,   0.0365906], 
 [-0.337402,   0.184204,   -0.107849,   -0.364502,   0.354248,   0.288818,   0.223145,   0.468994,   -0.0543518,   0.000132799,   -0.0791016,   0.197266,   -0.375732,   0.175903,   0.34082,   0.348877], 
 [0.293457,   -0.207886,   0.265381,   -0.337891,   0.352783,   -0.0714111,   -0.38208,   0.378174,   -0.198608,   0.0809326,   0.297119,   -0.347412,   0.00965118,   0.00774765,   -0.382568,   0.308594], 
 [0.148682,   0.109863,   -0.49292,   -0.40625,   -0.387695,   0.265137,   0.471436,   0.289551,   0.288086,   0.0359192,   -0.274658,   -0.172607,   -0.309082,   -0.172852,   -0.255371,   0.156738], 
 [-0.248657,   -0.244141,   -0.431396,   0.262207,   0.184082,   0.4729,   0.0784302,   -0.149658,   0.0897827,   0.209839,   -0.0586853,   0.239746,   -0.378906,   0.466309,   0.134277,   -0.322998], 
 [0.131592,   0.167603,   -0.474121,   0.259277,   -0.191528,   0.23584,   -0.388184,   0.24231,   -0.0933838,   0.165527,   -0.0960693,   0.376221,   -0.499756,   -0.102051,   0.143066,   -0.0538025], 
 [0.118164,   -0.167969,   0.0490723,   0.0157623,   -0.0998535,   -0.420166,   0.232544,   -0.249634,   -0.0469055,   -0.409668,   0.174805,   0.0572205,   -0.13623,   -0.0748901,   -0.445557,   0.158325]]
chencha3 (Contributor) commented Sep 25, 2024

On current public platforms, only 32-bit is supported.
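
If only a 32-bit transpose is available, the element-wise 16-bit transpose can still be recovered by interleaving the two f16 halves of each 32-bit lane after the transposed load. A host-side NumPy sketch of that extra shuffle (an illustration under a little-endian layout assumption, not XeGPU/IMEX API):

import numpy as np

a = np.arange(256, dtype=np.float16).reshape(16, 16)

# 32-bit-granularity transpose: pt[c, r] packs the pair (a[r, 2c], a[r, 2c+1]).
pt = a.view(np.uint32).T.copy()                # shape (8, 16)

# Split every 32-bit unit back into its two f16 halves (low half first).
halves = pt.view(np.float16).reshape(8, 16, 2)

t = np.empty((16, 16), dtype=np.float16)
t[0::2] = halves[:, :, 0]                      # even output rows <- low halves
t[1::2] = halves[:, :, 1]                      # odd output rows <- high halves
assert (t == a.T).all()                        # matches the element-wise transpose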
