-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
misc: add wgsl target and move module printing test case to filecheck (#3615)
A bit of cleanup
- Loading branch information
1 parent
f61725b
commit 1baf7d8
Showing
6 changed files
with
356 additions
and
368 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
from io import StringIO | ||
|
||
from xdsl.backend.wgsl.wgsl_printer import WGSLPrinter | ||
from xdsl.dialects import arith, gpu, memref, test | ||
from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType, i32 | ||
from xdsl.utils.test_value import TestSSAValue | ||
|
||
# Two independent test ops, each producing a single index-typed result; the
# tests in this file use them as operand producers.
lhs_op, rhs_op = (test.TestOp(result_types=[IndexType()]) for _ in range(2))
|
||
|
||
def test_gpu_global_id():
    """A GlobalIdOp along X should print as a read of `global_invocation_id.x`."""
    op = gpu.GlobalIdOp(gpu.DimensionAttr(gpu.DimensionEnum.X))

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0: u32 = global_invocation_id.x;" in emitted
|
||
|
||
def test_gpu_thread_id():
    """A ThreadIdOp along X should print as a read of `local_invocation_id.x`."""
    op = gpu.ThreadIdOp(gpu.DimensionAttr(gpu.DimensionEnum.X))

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0: u32 = local_invocation_id.x;" in emitted
|
||
|
||
def test_gpu_block_id():
    """A BlockIdOp along X should print as a read of `workgroup_id.x`."""
    op = gpu.BlockIdOp(gpu.DimensionAttr(gpu.DimensionEnum.X))

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0: u32 = workgroup_id.x;" in emitted
|
||
|
||
def test_gpu_grid_dim():
    """A GridDimOp along X should print as a read of `num_workgroups.x`."""
    op = gpu.GridDimOp(gpu.DimensionAttr(gpu.DimensionEnum.X))

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0: u32 = num_workgroups.x;" in emitted
|
||
|
||
def test_arith_constant_unsigned():
    """An index-typed constant 42 should print as the `u32` literal `42u`."""
    constant = arith.ConstantOp(IntegerAttr(42, IndexType()))

    stream = StringIO()
    WGSLPrinter().print(constant, stream)
    emitted = stream.getvalue()

    assert "let v0 : u32 = 42u;" in emitted
|
||
|
||
def test_arith_constant_unsigned_neg():
    """An index-typed constant -1 should print as `4294967295u`.

    The result carries the name hint "temp", which the expected output
    shows as `vtemp`.
    """
    constant = arith.ConstantOp(IntegerAttr(-1, IndexType()))
    constant.result.name_hint = "temp"

    stream = StringIO()
    WGSLPrinter().print(constant, stream)
    emitted = stream.getvalue()

    assert "let vtemp : u32 = 4294967295u;" in emitted
|
||
|
||
def test_arith_constant_signed():
    """A 32-bit integer constant 42 should print as an `i32` literal without suffix.

    The result carries the name hint "temp", which the expected output
    shows as `vtemp`.
    """
    constant = arith.ConstantOp(IntegerAttr(42, IntegerType(32)))
    constant.result.name_hint = "temp"

    stream = StringIO()
    WGSLPrinter().print(constant, stream)
    emitted = stream.getvalue()

    assert "let vtemp : i32 = 42;" in emitted
|
||
|
||
def test_arith_addi():
    """An integer addition of the two shared operands should print with `+`."""
    op = arith.AddiOp(lhs_op, rhs_op)

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0 = v1 + v2;" in emitted
|
||
|
||
def test_arith_subi():
    """An integer subtraction of the two shared operands should print with `-`."""
    op = arith.SubiOp(lhs_op, rhs_op)

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0 = v1 - v2;" in emitted
|
||
|
||
def test_arith_muli():
    """An integer multiplication of the two shared operands should print with `*`."""
    op = arith.MuliOp(lhs_op, rhs_op)

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0 = v1 * v2;" in emitted
|
||
|
||
def test_arith_addf():
    """A float addition of the two shared operands should print with `+`."""
    # NOTE(review): lhs_op/rhs_op are created with IndexType results at module
    # level, so this float op is built on index operands — confirm intended.
    op = arith.AddfOp(lhs_op, rhs_op)

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0 = v1 + v2;" in emitted
|
||
|
||
def test_arith_subf():
    """A float subtraction of the two shared operands should print with `-`."""
    # NOTE(review): lhs_op/rhs_op are created with IndexType results at module
    # level, so this float op is built on index operands — confirm intended.
    op = arith.SubfOp(lhs_op, rhs_op)

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0 = v1 - v2;" in emitted
|
||
|
||
def test_arith_mulf():
    """A float multiplication of the two shared operands should print with `*`."""
    # NOTE(review): lhs_op/rhs_op are created with IndexType results at module
    # level, so this float op is built on index operands — confirm intended.
    op = arith.MulfOp(lhs_op, rhs_op)

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "let v0 = v1 * v2;" in emitted
|
||
|
||
def test_memref_load():
    """A 2-D load from a 10x10 memref should print as a linearised array access.

    The expected index expression multiplies the first index by 10 and the
    second by 1.
    """
    buffer = TestSSAValue(memref.MemRefType(i32, [10, 10]))
    row, col = lhs_op.res[0], rhs_op.res[0]
    op = memref.LoadOp.get(buffer, [row, col])

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    # NOTE(review): `v1` appears both as the result name and as the first
    # index in the expected text — presumably an artefact of printing the op
    # in isolation; confirm against the printer's naming scheme.
    assert "let v1 = v0[10u * v1 + 1u * v2];" in emitted
|
||
|
||
def test_memref_store():
    """A 2-D store into a 10x10 memref should print as a linearised array assignment.

    The stored value is the result of a load from the same memref at the same
    indices; only the store op itself is printed.
    """
    buffer = TestSSAValue(memref.MemRefType(i32, [10, 10]))
    row, col = lhs_op.res[0], rhs_op.res[0]

    loaded = memref.LoadOp.get(buffer, [row, col])
    op = memref.StoreOp.get(loaded.res, buffer, [row, col])

    stream = StringIO()
    WGSLPrinter().print(op, stream)
    emitted = stream.getvalue()

    assert "v1[10u * v1 + 1u * v2] = v0;" in emitted
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
// RUN: xdsl-opt -t wgsl %s | filecheck %s | ||
|
||
// Input IR for the `-t wgsl` run above: one gpu.func kernel named
// "apply_kernel_kernel" inside a gpu.module of the same name. The CHECK
// directives after the module list the expected output for it, line by line.
builtin.module attributes {gpu.container_module} { | ||
"gpu.module"() ({ | ||
"gpu.func"() ({ | ||
// Kernel arguments: index scalars (%arg0-%arg2, %arg4), f32 scalars
// (%arg5-%arg10), and three 260x260 memrefs (%arg3, %arg11, %arg12).
// %arg12 is declared but not referenced by any op in the body.
^0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : memref<260x260xf32>, %arg4 : index, %arg5 : f32, %arg6 : f32, %arg7 : f32, %arg8 : f32, %arg9 : f32, %arg10 : f32, %arg11 : memref<260x260xf32>, %arg12: memref<260x260xindex>): | ||
%0 = "arith.constant"() {"value" = 2 : index} : () -> index | ||
// Block and thread ids along x and y.
%1 = "gpu.block_id"() {"dimension" = #gpu<dim x>} : () -> index | ||
%2 = "gpu.block_id"() {"dimension" = #gpu<dim y>} : () -> index | ||
%3 = "gpu.thread_id"() {"dimension" = #gpu<dim x>} : () -> index | ||
%4 = "gpu.thread_id"() {"dimension" = #gpu<dim y>} : () -> index | ||
// Index arithmetic; %15 and %16 are the indices later used by the store.
%5 = arith.muli %1, %arg0 : index | ||
%6 = arith.addi %5, %arg1 : index | ||
%7 = arith.muli %2, %arg2 : index | ||
%8 = arith.addi %7, %arg1 : index | ||
%9 = arith.muli %3, %arg2 : index | ||
%10 = arith.addi %9, %arg1 : index | ||
%11 = arith.muli %4, %arg2 : index | ||
%12 = arith.addi %11, %arg1 : index | ||
%13 = arith.addi %10, %6 : index | ||
%14 = arith.addi %12, %8 : index | ||
%15 = arith.addi %14, %0 : index | ||
%16 = arith.addi %13, %0 : index | ||
// Five loads from %arg3: at (%15, %16) and at four positions obtained by
// offsetting one of the two indices with %arg4 or %arg2.
%17 = "memref.load"(%arg3, %15, %16) {"nontemporal" = false} : (memref<260x260xf32>, index, index) -> f32 | ||
%18 = arith.addi %14, %arg4 : index | ||
%19 = arith.addi %18, %0 : index | ||
%20 = "memref.load"(%arg3, %19, %16) {"nontemporal" = false} : (memref<260x260xf32>, index, index) -> f32 | ||
%21 = arith.addi %14, %arg2 : index | ||
%22 = arith.addi %21, %0 : index | ||
%23 = "memref.load"(%arg3, %22, %16) {"nontemporal" = false} : (memref<260x260xf32>, index, index) -> f32 | ||
%24 = arith.addi %13, %arg4 : index | ||
%25 = arith.addi %24, %0 : index | ||
%26 = "memref.load"(%arg3, %15, %25) {"nontemporal" = false} : (memref<260x260xf32>, index, index) -> f32 | ||
%27 = arith.addi %13, %arg2 : index | ||
%28 = arith.addi %27, %0 : index | ||
%29 = "memref.load"(%arg3, %15, %28) {"nontemporal" = false} : (memref<260x260xf32>, index, index) -> f32 | ||
// Floating-point combination of the loaded values with the f32 arguments.
%30 = arith.mulf %17, %arg5 : f32 | ||
%31 = arith.mulf %20, %arg6 : f32 | ||
%32 = arith.mulf %23, %arg6 : f32 | ||
%33 = arith.mulf %17, %arg7 : f32 | ||
%34 = arith.addf %31, %32 : f32 | ||
%35 = arith.addf %34, %33 : f32 | ||
%36 = arith.mulf %26, %arg6 : f32 | ||
%37 = arith.mulf %29, %arg6 : f32 | ||
%38 = arith.addf %36, %37 : f32 | ||
// The only named (non-numeric) value; the CHECK lines expect it as `vtemp`.
%temp = arith.addf %38, %33 : f32 | ||
%40 = arith.addf %35, %temp : f32 | ||
%41 = arith.mulf %40, %arg8 : f32 | ||
%42 = arith.addf %30, %arg9 : f32 | ||
%43 = arith.addf %42, %41 : f32 | ||
%44 = arith.mulf %43, %arg10 : f32 | ||
// The final value is written to %arg11 at (%15, %16).
"memref.store"(%44, %arg11, %15, %16) {"nontemporal" = false} : (f32, memref<260x260xf32>, index, index) -> () | ||
"gpu.return"() : () -> () | ||
}) {"function_type" = (index, index, index, memref<260x260xf32>, index, f32, f32, f32, f32, f32, f32, memref<260x260xf32>, memref<260x260xindex>) -> (), | ||
"gpu.kernel", "gpu.known_block_size" = array<i32: 128, 1, 1>, "gpu.known_grid_size" = array<i32: 2, 256, 1>, | ||
"sym_name" = "apply_kernel_kernel", | ||
"workgroup_attributions" = 0 : i64 | ||
} : () -> () | ||
"gpu.module_end"() : () -> () | ||
}) {"sym_name" = "apply_kernel_kernel"} : () -> () | ||
} | ||
|
||
// CHECK: @group(0) @binding(0) | ||
// CHECK-NEXT: var<storage,read> varg0: u32; | ||
|
||
// CHECK: @group(0) @binding(1) | ||
// CHECK-NEXT: var<storage,read> varg1: u32; | ||
|
||
// CHECK: @group(0) @binding(2) | ||
// CHECK-NEXT: var<storage,read> varg2: u32; | ||
|
||
// CHECK: @group(0) @binding(3) | ||
// CHECK-NEXT: var<storage,read> varg3: array<f32>; | ||
|
||
// CHECK: @group(0) @binding(4) | ||
// CHECK-NEXT: var<storage,read> varg4: u32; | ||
|
||
// CHECK: @group(0) @binding(5) | ||
// CHECK-NEXT: var<storage,read> varg5: f32; | ||
|
||
// CHECK: @group(0) @binding(6) | ||
// CHECK-NEXT: var<storage,read> varg6: f32; | ||
|
||
// CHECK: @group(0) @binding(7) | ||
// CHECK-NEXT: var<storage,read> varg7: f32; | ||
|
||
// CHECK: @group(0) @binding(8) | ||
// CHECK-NEXT: var<storage,read> varg8: f32; | ||
|
||
// CHECK: @group(0) @binding(9) | ||
// CHECK-NEXT: var<storage,read> varg9: f32; | ||
|
||
// CHECK: @group(0) @binding(10) | ||
// CHECK-NEXT: var<storage,read> varg10: f32; | ||
|
||
// CHECK: @group(0) @binding(11) | ||
// CHECK-NEXT: var<storage,read_write> varg11: array<f32>; | ||
|
||
// CHECK: @group(0) @binding(12) | ||
// CHECK-NEXT: var<storage,read> varg12: array<u32>; | ||
|
||
// CHECK: @compute | ||
// CHECK-NEXT: @workgroup_size(128,1,1) | ||
// CHECK-NEXT: fn apply_kernel_kernel(@builtin(global_invocation_id) global_invocation_id : vec3<u32>, | ||
// CHECK-NEXT: @builtin(workgroup_id) workgroup_id : vec3<u32>, | ||
// CHECK-NEXT: @builtin(local_invocation_id) local_invocation_id : vec3<u32>, | ||
// CHECK-NEXT: @builtin(num_workgroups) num_workgroups : vec3<u32>) { | ||
|
||
// CHECK: let v0 : u32 = 2u; | ||
// CHECK-NEXT: let v1: u32 = workgroup_id.x; | ||
// CHECK-NEXT: let v2: u32 = workgroup_id.y; | ||
// CHECK-NEXT: let v3: u32 = local_invocation_id.x; | ||
// CHECK-NEXT: let v4: u32 = local_invocation_id.y; | ||
// CHECK-NEXT: let v5 = v1 * varg0; | ||
// CHECK-NEXT: let v6 = v5 + varg1; | ||
// CHECK-NEXT: let v7 = v2 * varg2; | ||
// CHECK-NEXT: let v8 = v7 + varg1; | ||
// CHECK-NEXT: let v9 = v3 * varg2; | ||
// CHECK-NEXT: let v10 = v9 + varg1; | ||
// CHECK-NEXT: let v11 = v4 * varg2; | ||
// CHECK-NEXT: let v12 = v11 + varg1; | ||
// CHECK-NEXT: let v13 = v10 + v6; | ||
// CHECK-NEXT: let v14 = v12 + v8; | ||
// CHECK-NEXT: let v15 = v14 + v0; | ||
// CHECK-NEXT: let v16 = v13 + v0; | ||
// CHECK-NEXT: let v17 = varg3[260u * v15 + 1u * v16]; | ||
// CHECK-NEXT: let v18 = v14 + varg4; | ||
// CHECK-NEXT: let v19 = v18 + v0; | ||
// CHECK-NEXT: let v20 = varg3[260u * v19 + 1u * v16]; | ||
// CHECK-NEXT: let v21 = v14 + varg2; | ||
// CHECK-NEXT: let v22 = v21 + v0; | ||
// CHECK-NEXT: let v23 = varg3[260u * v22 + 1u * v16]; | ||
// CHECK-NEXT: let v24 = v13 + varg4; | ||
// CHECK-NEXT: let v25 = v24 + v0; | ||
// CHECK-NEXT: let v26 = varg3[260u * v15 + 1u * v25]; | ||
// CHECK-NEXT: let v27 = v13 + varg2; | ||
// CHECK-NEXT: let v28 = v27 + v0; | ||
// CHECK-NEXT: let v29 = varg3[260u * v15 + 1u * v28]; | ||
// CHECK-NEXT: let v30 = v17 * varg5; | ||
// CHECK-NEXT: let v31 = v20 * varg6; | ||
// CHECK-NEXT: let v32 = v23 * varg6; | ||
// CHECK-NEXT: let v33 = v17 * varg7; | ||
// CHECK-NEXT: let v34 = v31 + v32; | ||
// CHECK-NEXT: let v35 = v34 + v33; | ||
// CHECK-NEXT: let v36 = v26 * varg6; | ||
// CHECK-NEXT: let v37 = v29 * varg6; | ||
// CHECK-NEXT: let v38 = v36 + v37; | ||
// CHECK-NEXT: let vtemp = v38 + v33; | ||
// CHECK-NEXT: let v39 = v35 + vtemp; | ||
// CHECK-NEXT: let v40 = v39 * varg8; | ||
// CHECK-NEXT: let v41 = v30 + varg9; | ||
// CHECK-NEXT: let v42 = v41 + v40; | ||
// CHECK-NEXT: let v43 = v42 * varg10; | ||
// CHECK-NEXT: varg11[260u * v15 + 1u * v16] = v43; | ||
// CHECK-NEXT: } |
Oops, something went wrong.