Skip to content

Commit

Permalink
[mlir][spirv] Add OpenCL fma op and lowering
Browse files Browse the repository at this point in the history
Also, it seems Khronos has changed html spec format so small adjustment to script was needed.
Base op parsing is also probably broken.

Differential Revision: https://reviews.llvm.org/D119678
  • Loading branch information
Hardcode84 committed Feb 15, 2022
1 parent 290e482 commit 32389d0
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 14 deletions.
79 changes: 69 additions & 10 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,46 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
let assemblyFormat = "operands attr-dict `:` type($result)";
}

// Base class for OpenCL binary ops.
class SPV_OCLTernaryOp<string mnemonic, int opcode, Type resultType,
Type operandType, list<Trait> traits = []> :
SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {

let arguments = (ins
SPV_ScalarOrVectorOf<operandType>:$x,
SPV_ScalarOrVectorOf<operandType>:$y,
SPV_ScalarOrVectorOf<operandType>:$z
);

let results = (outs
SPV_ScalarOrVectorOf<resultType>:$result
);

let hasVerifier = 0;
}

// Base class for OpenCL Ternary arithmetic ops where operand types and
// return type matches.
class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
list<Trait> traits = []> :
SPV_OCLTernaryOp<mnemonic, opcode, type, type,
traits # [SameOperandsAndResultType]> {
let assemblyFormat = "operands attr-dict `:` type($result)";
}


// -----

def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
let summary = [{
Error function of x encountered in integrating the normal distribution.
Compute the correctly rounded floating-point representation of the sum
of c with the infinitely precise product of a and b. Rounding of
intermediate products shall not occur. Edge case results are per the
IEEE 754-2008 standard.
}];

let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
floating-point values.

All of the operands, including the Result Type operand, must be of the
Expand All @@ -99,17 +130,13 @@ def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
<!-- End of AutoGen section -->

```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
float-scalar-vector-type
```mlir

#### Example:

```
%2 = spv.OCL.erf %0 : f32
%3 = spv.OCL.erf %1 : vector<3xf16>
%0 = spv.OCL.fma %a, %b, %c : f32
%1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
```
}];
}
Expand Down Expand Up @@ -179,6 +206,38 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {

// -----

def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
let summary = [{
Error function of x encountered in integrating the normal distribution.
}];

let description = [{
Result Type and x must be floating-point or vector(2,3,4,8,16) of
floating-point values.

All of the operands, including the Result Type operand, must be of the
same type.

<!-- End of AutoGen section -->

```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
float-scalar-vector-type
```mlir

#### Example:

```
%2 = spv.OCL.erf %0 : f32
%3 = spv.OCL.erf %1 : vector<3xf16>
```
}];
}

// -----

def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> {
let summary = "Exponentiation of Operand 1";

Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
typeConverter, patterns.getContext());

// OpenCL patterns
Expand All @@ -109,6 +109,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
return
}

// CHECK-LABEL: @float32_ternary_scalar
// CHECK-LABEL: @float32_ternary_scalar
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
// CHECK: spv.GLSL.Fma %{{.*}}: f32
%0 = math.fma %a, %b, %c : f32
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
return
}

// CHECK-LABEL: @float32_ternary_scalar
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
// CHECK: spv.OCL.fma %{{.*}}: f32
%0 = math.fma %a, %b, %c : f32
return
}

// CHECK-LABEL: @float32_ternary_vector
func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
%c: vector<4xf32>) {
// CHECK: spv.OCL.fma %{{.*}}: vector<4xf32>
%0 = math.fma %a, %b, %c : vector<4xf32>
return
}

} // end module
19 changes: 19 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,22 @@ func @sabs(%arg0 : i32) -> () {
return
}

// -----

//===----------------------------------------------------------------------===//
// spv.OCL.fma
//===----------------------------------------------------------------------===//

func @fma(%a : f32, %b : f32, %c : f32) -> () {
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%2 = spv.OCL.fma %a, %b, %c : f32
return
}

// -----

func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
%2 = spv.OCL.fma %a, %b, %c : vector<3xf32>
return
}
6 changes: 6 additions & 0 deletions mlir/test/Target/SPIRV/ocl-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
%0 = spv.OCL.fabs %arg0 : vector<16xf32>
spv.Return
}

spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
// CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%13 = spv.OCL.fma %arg0, %arg1, %arg2 : f32
spv.Return
}
}
2 changes: 1 addition & 1 deletion mlir/utils/spirv/gen_spirv_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_spirv_doc_from_html_spec(url, settings):
doc = {}

if settings.gen_ocl_ops:
section_anchor = spirv.find('h2', {'id': '_a_id_binary_a_binary_form'})
section_anchor = spirv.find('h2', {'id': '_binary_form'})
for section in section_anchor.parent.find_all('div', {'class': 'sect2'}):
for table in section.find_all('table'):
inst_html = table.tbody.tr.td
Expand Down

0 comments on commit 32389d0

Please sign in to comment.