diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index c84677d26a8b69d..ff738fc2555734a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4234,8 +4234,13 @@ class SPIRV_CoopMatrixOfType allowedTypes> : "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()", "Cooperative Matrix">; +class SPIRV_MatrixOfType allowedTypes> : + ContainerType, SPIRV_IsMatrixType, + "::llvm::cast<::mlir::spirv::MatrixType>($_self).getElementType()", + "Matrix">; + class SPIRV_VectorOf : - VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>; + VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; class SPIRV_ScalarOrVectorOf : AnyTypeOf<[type, SPIRV_VectorOf]>; @@ -4248,6 +4253,9 @@ class SPIRV_MatrixOrCoopMatrixOf : AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>; +class SPIRV_MatrixOf : + SPIRV_MatrixOfType<[type]>; + def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; @@ -4387,7 +4395,8 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>; def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; -def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>; +def SPIRV_OC_OpVectorTimesMatrix : I32EnumAttrCase<"OpVectorTimesMatrix", 144>; +def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>; def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>; def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>; @@ -4559,7 +4568,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv, SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem, SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod, - SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector, + SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, + SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td index 5bd99386e008582..f2796861cdf5617 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -63,8 +63,7 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> { // ----- -def SPIRV_MatrixTimesScalarOp : SPIRV_Op< - "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> { +def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> { let summary = "Scale a floating-point matrix."; let description = [{ @@ -114,8 +113,11 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op< // ----- -def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> { - let summary = "Linear-algebraic multiply of matrix X vector."; +def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [ + Pure, + AllElementTypesMatch<["vector", "result"]> + ]> { + let summary = "Linear-algebraic Matrix X Vector."; let description = [{ Result Type must be a vector of floating-point type. @@ -140,12 +142,12 @@ def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> { ]; let arguments = (ins - SPIRV_AnyMatrix:$matrix, - SPIRV_AnyVector:$vector + SPIRV_MatrixOf:$matrix, + SPIRV_VectorOf:$vector ); let results = (outs - SPIRV_AnyVector:$result + SPIRV_VectorOf:$result ); let assemblyFormat = [{ @@ -198,4 +200,53 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> { // ----- +def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [ + Pure, + AllElementTypesMatch<["vector", "result"]> + ]> { + let summary = "Linear-algebraic Vector X Matrix."; + + let description = [{ + Result Type must be a vector of floating-point type. + + Vector must be a vector with the same Component Type as the Component + Type in Result Type. Its number of components must equal the number of + components in each column in Matrix. + + Matrix must be a matrix with the same Component Type as the Component + Type in Result Type. Its number of columns must equal the number of + components in Result Type. + + + + #### Example: + + ```mlir + %result = spirv.VectorTimesMatrix %vector, %matrix : vector<4xf32>, !spirv.matrix<4 x vector<4xf32>> -> vector<4xf32> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPIRV_C_Matrix]> + ]; + + let arguments = (ins + SPIRV_VectorOf:$vector, + SPIRV_MatrixOf:$matrix + ); + + let results = (outs + SPIRV_VectorOf:$result + ); + + let assemblyFormat = [{ + operands attr-dict `:` type($vector) `,` type($matrix) `->` type($result) + }]; +} + +// ----- + #endif // MLIR_DIALECT_SPIRV_IR_MATRIX_OPS diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 040bf6a34cea781..f0f03e989cb4750 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1717,10 +1717,32 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() { << resultType.getNumElements() << ") must match the matrix rows (" << matrixType.getNumRows() << ")"; - auto matrixElementType = matrixType.getElementType(); - if (matrixElementType != vectorType.getElementType() || - matrixElementType != resultType.getElementType()) - return emitOpError("matrix, vector, and result element types must match"); + if (matrixType.getElementType() != resultType.getElementType()) + return emitOpError("matrix and result element types must match"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.VectorTimesMatrix +//===----------------------------------------------------------------------===// + +LogicalResult spirv::VectorTimesMatrixOp::verify() { + auto vectorType = llvm::cast(getVector().getType()); + auto matrixType = llvm::cast(getMatrix().getType()); + auto resultType = llvm::cast(getType()); + + if (matrixType.getNumRows() != vectorType.getNumElements()) + return emitOpError("number of components in vector must equal the number " + "of components in each column in matrix"); + + if (resultType.getNumElements() != matrixType.getNumColumns()) + return emitOpError("number of columns in matrix must equal the number of " + "components in result"); + + if (matrixType.getElementType() != resultType.getElementType()) + return emitOpError("matrix must be a matrix with the same component type " + "as the component type in result"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir index 37e7514d664ef0e..ba95322dbf38bbc 100644 --- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir @@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.ReturnValue %result : vector<4xf32> } + // CHECK-LABEL: @vector_times_matrix_1 + spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" { + // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32> + %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32> + spirv.ReturnValue %result : vector<4xf32> + } + // CHECK-LABEL: @matrix_times_matrix_1 spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{ // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>> @@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 return } - // ----- func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){ @@ -135,7 +141,7 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 // ----- func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) { - // expected-error @+1 {{matrix, vector, and result element types must match}} + // expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}} %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32> return } @@ -155,3 +161,35 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3 %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32> return } + +// ----- + +func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) { + // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}} + %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32> + return +} + +// ----- + +func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) { + // expected-error @+1 {{number of columns in matrix must equal the number of components in result}} + %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32> + return +} + +// ----- + +func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xf16>, %arg1: !spirv.matrix<4 x vector<3xf32>>) { + // expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}} + %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf16>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32> + return +} + +// ----- + +func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) { + // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}} + %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32> + return +} diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir index 0ec1dc27e4e9323..452f8fc16f2588a 100644 --- a/mlir/test/Target/SPIRV/matrix.mlir +++ b/mlir/test/Target/SPIRV/matrix.mlir @@ -42,6 +42,13 @@ spirv.module Logical GLSL450 requires #spirv.vce { %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32> spirv.ReturnValue %result : vector<4xf32> } + + // CHECK-LABEL: @vector_times_matrix_1 + spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" { + // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32> + %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32> + spirv.ReturnValue %result : vector<4xf32> + } // CHECK-LABEL: @matrix_times_matrix_1 spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{