//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the op definition spec of cooperative matrix multiply extension ops.
// We support both cooperative matrix extensions:
// - SPV_NV_cooperative_matrix
// - SPV_KHR_cooperative_matrix
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
#define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
// -----
// Op definition for `spirv.KHR.CooperativeMatrixLength`.
//
// The op takes no SSA operands: the cooperative matrix type being queried is
// carried as a type attribute, and the op yields one `SPIRV_Int32` result.
def SPIRV_KHRCooperativeMatrixLengthOp
    : SPIRV_KhrVendorOp<"CooperativeMatrixLength", [Pure]> {
  let summary = "Queries the number of cooperative matrix components";

  let description = [{
Number of components of a cooperative matrix type accessible to each
invocation when treated as a composite.
The type attribute must be a cooperative matrix type.
#### Example:
```
%0 = spirv.KHR.CooperativeMatrixLength :
!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
```
}];

  // Signature: one type attribute in, one 32-bit integer out.
  let arguments = (ins
    TypeAttrOf<SPIRV_AnyCooperativeMatrix>:$cooperative_matrix_type);
  let results = (outs SPIRV_Int32:$result);

  // Textual form: only the attribute dictionary and the queried type.
  let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";

  // Requirements declared for this op: SPIR-V 1.6 as both the minimum and
  // maximum version, plus the KHR cooperative matrix extension and capability.
  let availability = [
    MinVersion<SPIRV_V_1_6>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[SPV_KHR_cooperative_matrix]>,
    Capability<[SPIRV_C_CooperativeMatrixKHR]>
  ];

  // The verifier hook is switched off for this op; the type constraints in
  // `arguments` and `results` are the only checks stated in this definition.
  let hasVerifier = false;
}
// -----
// Op definition for `spirv.KHR.CooperativeMatrixLoad`.
//
// Operands:   `pointer` (SPIRV_AnyPtr) and `stride` (SPIRV_Integer).
// Attributes: `matrix_layout` (required) and `memory_operand` (optional).
// Result:     a cooperative matrix value.
def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
  let summary = "Loads a cooperative matrix through a pointer";

  // The examples use the `!spirv.coopmatrix<...>` type spelling, matching the
  // examples on the other KHR cooperative matrix ops in this file.
  let description = [{
Load a cooperative matrix through a pointer.
Result Type is the type of the loaded object. It must be a cooperative
matrix type.
Pointer is a pointer. Its type must be an OpTypePointer whose Type operand is
a scalar or vector type. If the Shader capability was declared, Pointer must
point into an array and any ArrayStride decoration on Pointer is ignored.
MemoryLayout specifies how matrix elements are laid out in memory. It must
come from a 32-bit integer constant instruction whose value corresponds to a
Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
description of the layouts and detailed layout-specific rules.
Stride further qualifies how matrix elements are laid out in memory. It must
be a scalar integer type and its exact semantics depend on MemoryLayout.
Memory Operand must be a Memory Operand literal. If not present, it is the
same as specifying None.
NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
as 'Memory Access'.
For a given dynamic instance of this instruction, all operands of this
instruction must be the same for all invocations in a given scope instance
(where the scope is the scope the cooperative matrix type was created with).
All invocations in a given scope instance must be active or all must be
inactive.
TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
support this optionality in the SPIR-V dialect.
#### Example:
```
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
: !spirv.ptr<i32, StorageBuffer>, i32
-> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
%1 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile>
: !spirv.ptr<f32, StorageBuffer>, i64
-> !spirv.coopmatrix<8x8xf32, Subgroup, MatrixAcc>
```
}];

  // Textual order is pointer, stride, layout, then the optional memory
  // operand. `type(operands)` lists the pointer and stride types.
  let assemblyFormat = [{
    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
      type(operands) `->` type($result)
  }];

  // Requirements declared for this op: SPIR-V 1.6 as both the minimum and
  // maximum version, plus the KHR cooperative matrix extension and capability.
  let availability = [
    MinVersion<SPIRV_V_1_6>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[SPV_KHR_cooperative_matrix]>,
    Capability<[SPIRV_C_CooperativeMatrixKHR]>
  ];

  // NOTE: the declaration order (layout before stride) differs from the
  // textual order in `assemblyFormat` (stride before layout). The convenience
  // builder below forwards its arguments in this declaration order.
  let arguments = (ins
    SPIRV_AnyPtr:$pointer,
    SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
    SPIRV_Integer:$stride,
    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
  );

  let results = (outs
    SPIRV_AnyCooperativeMatrix:$result
  );

  // Convenience builder without a memory operand. It takes the stride as a
  // `spirv::ConstantOp` and the layout as the raw enum, and forwards to
  // another `build` overload with a default-constructed
  // `spirv::MemoryAccessAttr` (presumably meaning "attribute absent" -- confirm
  // against the generated builder).
  let builders = [
    OpBuilder<(ins "Type":$result, "Value":$pointer,
                   "spirv::ConstantOp":$stride,
                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
      build($_builder, $_state, result, pointer, layout, stride,
            spirv::MemoryAccessAttr{});
    }]>
  ];
}
// -----
// Op definition for `spirv.KHR.CooperativeMatrixStore`.
//
// Operands:   `pointer`, the cooperative matrix `object`, and `stride`.
// Attributes: `matrix_layout` (required) and `memory_operand` (optional).
// The op produces no results.
def SPIRV_KHRCooperativeMatrixStoreOp
    : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
  let summary = "Stores a cooperative matrix through a pointer";

  let description = [{
Store a cooperative matrix through a pointer.
Pointer is a pointer. Its type must be an OpTypePointer whose Type operand
is a scalar or vector type. If the Shader capability was declared, Pointer
must point into an array and any ArrayStride decoration on Pointer is
ignored.
Object is the object to store. Its type must be an
OpTypeCooperativeMatrixKHR.
MemoryLayout specifies how matrix elements are laid out in memory. It must
come from a 32-bit integer constant instruction whose value corresponds to a
Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
description of the layouts and detailed layout-specific rules.
Stride further qualifies how matrix elements are laid out in memory. It must
be a scalar integer type and its exact semantics depend on MemoryLayout.
Memory Operand must be a Memory Operand literal. If not present, it is the
same as specifying None.
NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
as 'Memory Access'.
For a given dynamic instance of this instruction, all operands of this
instruction must be the same for all invocations in a given scope instance
(where the scope is the scope the cooperative matrix type was created with).
All invocations in a given scope instance must be active or all must be
inactive.
TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
support this optionality in the SPIR-V dialect.
#### Example:
```
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <RowMajor> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <ColumnMajor>, <Volatile> :
!spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64
```
}];

  // Signature. The layout attribute is declared before the stride operand,
  // whereas the textual form below prints the stride first.
  let arguments = (ins
    SPIRV_AnyPtr:$pointer,
    SPIRV_AnyCooperativeMatrix:$object,
    SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
    SPIRV_Integer:$stride,
    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
  );
  let results = (outs);

  // Textual order is pointer, object, stride, layout, then the optional
  // memory operand, followed by the types of all SSA operands.
  let assemblyFormat = [{
    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
      type(operands)
  }];

  // Requirements declared for this op: SPIR-V 1.6 as both the minimum and
  // maximum version, plus the KHR cooperative matrix extension and capability.
  let availability = [
    MinVersion<SPIRV_V_1_6>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[SPV_KHR_cooperative_matrix]>,
    Capability<[SPIRV_C_CooperativeMatrixKHR]>
  ];

  // Convenience builder without a memory operand: takes the stride as a
  // `spirv::ConstantOp` and the layout as the raw enum, then forwards
  // (pointer, object, layout, stride) plus a default-constructed
  // `spirv::MemoryAccessAttr` to another `build` overload.
  let builders = [
    OpBuilder<(ins "Value":$pointer, "Value":$object,
                   "spirv::ConstantOp":$stride,
                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
      build($_builder, $_state, pointer, object, layout, stride,
            spirv::MemoryAccessAttr{});
    }]>
  ];
}
// -----
// Op definition for `spirv.KHR.CooperativeMatrixMulAdd`.
//
// Operands:   cooperative matrices `a`, `b` and `c`.
// Attribute:  optional `matrix_operands`.
// Result:     a cooperative matrix; the `AllTypesMatch` trait ties its type to
//             the type of `c`, which is why the textual form spells only
//             `type($c)` after the arrow.
def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMulAdd",
[Pure, AllTypesMatch<["c", "result"]>]> {
  let summary = "Returns the result of `(A x B) + C` of matrices A, B, and C";

  // The grammar below includes the `,` that `assemblyFormat` requires before
  // the optional matrix operands (see the second example).
  let description = [{
Linear-algebraic matrix multiply of A by B and then component-wise add C.
The order of the operations is implementation-dependent. The internal
precision of floating-point operations is defined by the client API. Integer
operations used in the multiplication of A by B are performed at the
precision of the Result Type and the resulting value will equal the
low-order N bits of the correct result R, where N is the result width and R
is computed with enough precision to avoid overflow and underflow if the
SaturatingAccumulation Cooperative Matrix Operand is not present. If the
SaturatingAccumulation Cooperative Matrix Operand is present and overflow or
underflow occurs as part of calculating that intermediate result, the result
of the instruction is undefined. Integer additions of the elements of that
intermediate result with those of C are performed at the precision of Result
Type, are exact, and are saturating if the SaturatingAccumulation
Cooperative Matrix Operand is present, with the signedness of the saturation
being that of the components of Result Type. If the SaturatingAccumulation
Cooperative Matrix Operand is not present then the resulting value will
equal the low-order N bits of the correct result R, where N is the result
width and R is computed with enough precision to avoid overflow and
underflow.
Result Type must be a cooperative matrix type with M rows and N columns
whose Use must be MatrixAccumulatorKHR.
A is a cooperative matrix with M rows and K columns whose Use must be
MatrixAKHR.
B is a cooperative matrix with K rows and N columns whose Use must be
MatrixBKHR.
C is a cooperative matrix with M rows and N columns whose Use must be
MatrixAccumulatorKHR.
The values of M, N, and K must be consistent across the result and operands.
This is referred to as an MxNxK matrix multiply.
A, B, C, and Result Type must have the same scope, and this defines the
scope of the operation. A, B, C, and Result Type need not necessarily have
the same component type, this is defined by the client API.
If the Component Type of any matrix operand is an integer type, then its
components are treated as signed if the Matrix{A,B,C,Result}SignedComponents
Cooperative Matrix Operand is present and are treated as unsigned otherwise.
Cooperative Matrix Operands is an optional Cooperative Matrix Operand
literal. If not present, it is the same as specifying the Cooperative Matrix
Operand None.
For a given dynamic instance of this instruction, all invocations in a given
scope instance must be active or all must be inactive (where the scope is
the scope of the operation).
``` {.ebnf}
cooperative-matrixmuladd-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixMulAdd`
ssa-use `,` ssa-use `,` ssa-use
(`,` `<` matrix-operands `>`)? `:`
a-cooperative-matrix-type `,`
b-cooperative-matrix-type `->`
result-cooperative-matrix-type
```
#### Example:
```
%0 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC :
!spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
!spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
!spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
%1 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC, <ASigned | AccSat> :
!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
!spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
!spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
```
}];

  // Three operands, then an optional comma-separated matrix-operands
  // attribute. Only the types of `a`, `b` and `c` are written out.
  let assemblyFormat = [{
    $a `,` $b `,` $c ( `,` $matrix_operands^ )? attr-dict `:`
      type($a) `,` type($b) `->` type($c)
  }];

  // Requirements declared for this op: SPIR-V 1.6 as both the minimum and
  // maximum version, plus the KHR cooperative matrix extension and capability.
  let availability = [
    MinVersion<SPIRV_V_1_6>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[SPV_KHR_cooperative_matrix]>,
    Capability<[SPIRV_C_CooperativeMatrixKHR]>
  ];

  let arguments = (ins
    SPIRV_AnyCooperativeMatrix:$a,
    SPIRV_AnyCooperativeMatrix:$b,
    SPIRV_AnyCooperativeMatrix:$c,
    OptionalAttr<SPIRV_KHR_CooperativeMatrixOperandsAttr>:$matrix_operands
  );

  let results = (outs
    SPIRV_AnyCooperativeMatrix:$result
  );

  // Convenience builder for the form without matrix operands: forwards
  // (a, b, c) plus a default-constructed
  // `spirv::CooperativeMatrixOperandsKHRAttr` to another `build` overload.
  let builders = [
    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
      build($_builder, $_state, a, b, c,
            spirv::CooperativeMatrixOperandsKHRAttr{});
    }]>
  ];
}
// -----
#endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS