// llvm/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s

//===----------------------------------------------------------------------===//
// CooperativeMatrix (KHR) extension ops.
//===----------------------------------------------------------------------===//

// Round-trips the length query for a Subgroup-scoped 8x16 i32 MatrixA type and
// returns the resulting i32.
// CHECK-LABEL: @cooperative_matrix_length
spirv.func @cooperative_matrix_length() -> i32 "None" {
  // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  %len = spirv.KHR.CooperativeMatrixLength
    : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.ReturnValue %len : i32
}

// -----

// Row-major load through a StorageBuffer i32 pointer, without any memory
// operand.
// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i32) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <RowMajor>
    : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
  spirv.Return
}

// Column-major load that additionally carries the Volatile memory operand.
// CHECK-LABEL: @cooperative_matrix_load_memoperand
spirv.func @cooperative_matrix_load_memoperand(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i32) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <ColumnMajor>, <Volatile>
    : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
  spirv.Return
}

// The pointee of the source pointer may be a vector (here vector<4xi32>)
// instead of a scalar; the result is a Subgroup-scoped MatrixB.
// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
spirv.func @cooperative_matrix_load_vector_ptr_type(%src : !spirv.ptr<vector<4xi32>, StorageBuffer>, %ld : i32) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <RowMajor>, <Volatile>
    : !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
  spirv.Return
}

// Load through a pointer in the Function storage class, producing a
// Subgroup-scoped MatrixAcc.
// CHECK-LABEL: @cooperative_matrix_load_function
spirv.func @cooperative_matrix_load_function(%src : !spirv.ptr<i32, Function>, %ld : i32) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
  // CHECK-SAME:   !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <RowMajor>
    : !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// The stride operand is an i16 here rather than the i32 used by the other load
// tests.
// CHECK-LABEL: @cooperative_matrix_load_stride_i16
spirv.func @cooperative_matrix_load_stride_i16(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i16) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <RowMajor>
    : !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
  spirv.Return
}

// Row-major store of a Workgroup-scoped MatrixA value through a StorageBuffer
// pointer, without any memory operand.
// CHECK-LABEL: @cooperative_matrix_store
spirv.func @cooperative_matrix_store(%dst : !spirv.ptr<i32, StorageBuffer>,
                                     %ld : i32,
                                     %mat : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
  spirv.KHR.CooperativeMatrixStore %dst, %mat, %ld, <RowMajor>
    : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
  spirv.Return
}

// Column-major store of a Subgroup-scoped MatrixB value that also carries the
// Volatile memory operand.
// CHECK-LABEL: @cooperative_matrix_store_memoperand
spirv.func @cooperative_matrix_store_memoperand(
    %dst : !spirv.ptr<i32, StorageBuffer>,
    %mat : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
    %ld : i32) "None" {
  // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
  // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
  spirv.KHR.CooperativeMatrixStore %dst, %mat, %ld, <ColumnMajor>, <Volatile>
    : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
  spirv.Return
}

// Column-major store whose stride operand is an i16.
// CHECK-LABEL: @cooperative_matrix_store_stride_i16
spirv.func @cooperative_matrix_store_stride_i16(
    %dst : !spirv.ptr<i32, StorageBuffer>,
    %mat : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
    %ld : i16) "None" {
  // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
  // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
  spirv.KHR.CooperativeMatrixStore %dst, %mat, %ld, <ColumnMajor>
    : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
  spirv.Return
}

// -----

// Negative test: the source pointer points to a struct, which the load must
// diagnose as neither a scalar nor a vector pointee.
spirv.func @cooperative_matrix_load_bad_ptr(%src : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %ld : i32) "None" {
  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <ColumnMajor> :
    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.Return
}

// -----

// Negative test: the matrix layout attribute is left out after the stride, so
// the parser reports a missing ',' at the ':'.
spirv.func @cooperative_matrix_load_missing_attr(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i32) "None" {
  // expected-error @+1 {{expected ','}}
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld :
    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.Return
}

// -----

// Negative test: MakePointerAvailable is diagnosed as a memory operand that is
// not compatible with a cooperative matrix load.
spirv.func @cooperative_matrix_load_bad_operad(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i32) "None" {
  // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <ColumnMajor>, <MakePointerAvailable> :
    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.Return
}

// -----

// Negative test: the Aligned memory operand on its own is reported as
// unhandled for a cooperative matrix load.
spirv.func @cooperative_matrix_load_aligned(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i32) "None" {
  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <ColumnMajor>, <Aligned> :
    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.Return
}

// -----

// Negative test: Aligned is still reported as unhandled when it is combined
// with Volatile in the same memory operand.
spirv.func @cooperative_matrix_load_aligned(%src : !spirv.ptr<i32, StorageBuffer>, %ld : i32) "None" {
  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
  %mat = spirv.KHR.CooperativeMatrixLoad %src, %ld, <ColumnMajor>, <Volatile|Aligned> :
    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.Return
}

// -----

// Negative test: the matrix layout attribute is left out after the stride, so
// the parser reports a missing ',' at the ':'. The trailing type list names a
// type for each of the three operands, so the absent layout is the only thing
// wrong with the op.
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                                  %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
  // expected-error @+1 {{expected ','}}
  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
  spirv.Return
}

// -----

// Negative test: a trailing ',' follows the stride but no layout attribute is
// given, so the parser reports a missing '<' at the ':'.
spirv.func @cooperative_matrix_store_missing_attr(
    %dst : !spirv.ptr<i32, StorageBuffer>,
    %ld : i32,
    %mat : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
  // expected-error @+1 {{expected '<'}}
  spirv.KHR.CooperativeMatrixStore %dst, %mat, %ld, :
    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
  spirv.Return
}

// -----

// Negative test: the object operand (operand #1) is an i32 instead of a
// cooperative matrix; the i32 stride value is passed in both positions.
spirv.func @cooperative_matrix_store_bad_object_type(%dst : !spirv.ptr<i32, StorageBuffer>,
                                                     %ld : i32) "None" {
  // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
  spirv.KHR.CooperativeMatrixStore %dst, %ld, %ld, <RowMajor> :
    !spirv.ptr<i32, StorageBuffer>, i32, i32
  spirv.Return
}

// -----

// Negative test: MakePointerVisible is diagnosed as a memory operand that is
// not compatible with a cooperative matrix store.
spirv.func @cooperative_matrix_store_bad_operand(
    %dst : !spirv.ptr<i32, StorageBuffer>,
    %ld : i32,
    %mat : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
  // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
  spirv.KHR.CooperativeMatrixStore %dst, %mat, %ld, <RowMajor>, <MakePointerVisible> :
    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
  spirv.Return
}

// -----

// Negative test: the Aligned memory operand is reported as unhandled for a
// cooperative matrix store.
spirv.func @cooperative_matrix_store(
    %dst : !spirv.ptr<i32, StorageBuffer>,
    %ld : i32,
    %mat : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
  spirv.KHR.CooperativeMatrixStore %dst, %mat, %ld, <RowMajor>, <Aligned> :
    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
  spirv.Return
}

// -----

// Basic multiply-add without matrix operands: an 8x16 MatrixA times a 16x4
// MatrixB accumulated into an 8x4 MatrixAcc, all i32 at Subgroup scope. The
// printed form is verified in addition to the op being accepted.
// CHECK-LABEL: @cooperative_matrix_muladd
spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
  // CHECK-SAME:   !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// Multiply-add with the optional matrix-operands attribute: a single flag, two
// flags, and three flags. The input spells multi-flag values with spaces
// around '|'; the printed form is verified for each op.
// CHECK-LABEL: @cooperative_matrix_muladd_matrix_operands
spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                                      %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                                      %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <AccSat> :
  // CHECK-SAME:   !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned|BSigned> :
  // CHECK-SAME:   !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned> :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned|BSigned|AccSat> :
  // CHECK-SAME:   !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned | AccSat> :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// Multiply-add on 4x4 f32 matrices at Subgroup scope, without matrix operands.
// The printed form is verified in addition to the op being accepted.
// CHECK-LABEL: @cooperative_matrix_muladd_f32
spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
                                          %b : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB>,
                                          %c : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
  // CHECK-SAME:   !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
        !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
        !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
  spirv.Return
}

// Mixed element widths: i8 MatrixA and MatrixB inputs accumulated into an i32
// MatrixAcc. The printed form is verified in addition to the op being accepted.
// CHECK-LABEL: @cooperative_matrix_muladd_i8_i32
spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                             %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
                                             %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
  // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
        !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// Three different element widths: i8 MatrixA, i16 MatrixB, and i32 MatrixAcc.
// The printed form is verified in addition to the op being accepted.
// CHECK-LABEL: @cooperative_matrix_muladd_i8_i16_i32
spirv.func @cooperative_matrix_muladd_i8_i16_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                                 %b : !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB>,
                                                 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
  // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
        !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// Multiply-add on 4x4 f16 matrices whose scope is Workgroup for all three
// operands. The printed form is verified in addition to the op being accepted.
// CHECK-LABEL: @cooperative_matrix_muladd_workgroup
spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
                                                %b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
                                                %c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
  // CHECK-SAME:   !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
  // CHECK-SAME:   !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB> ->
  // CHECK-SAME:   !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>
  %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
        !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
        !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB> ->
          !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: the first operand is typed with the MatrixB use, so operand
// #0 is diagnosed as not being of use MatrixA.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #0 must be of use 'MatrixA'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: only two SSA operands are given, so the parser reports a
// missing ',' at the ':' where the third operand should follow.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
  // expected-error @+1 {{expected ','}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: a matrix-operands attribute appears in the position of the
// third SSA operand, so the parser reports that an SSA operand was expected.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
  // expected-error @+1 {{expected SSA operand}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, <ASigned> :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: a fourth SSA operand is supplied where only a '<...>'
// attribute may follow, so the parser reports a missing '<'.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{expected '<'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: the second operand is typed with the MatrixA use, so operand
// #1 is diagnosed as not being of use MatrixB.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA>,
    %acc : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #1 must be of use 'MatrixB'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: the third operand is typed with the MatrixB use, so operand
// #2 is diagnosed as not being of use MatrixAcc.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #2 must be of use 'MatrixAcc'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>
  spirv.Return
}

// -----

// Negative test: the accumulator is 10x4 while the MatrixA operand is 8x16;
// this is diagnosed as a size mismatch on dimension 'M'.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'M'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: the MatrixB operand is declared 4x16 while the accumulator is
// 8x4; this is diagnosed as a size mismatch on dimension 'N'.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'N'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: the MatrixA operand is 8x16 while the MatrixB operand is
// 12x4; this is diagnosed as a size mismatch on dimension 'K'.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'K'}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: the accumulator is Workgroup-scoped while both inputs are
// Subgroup-scoped; this is diagnosed as a matrix scope mismatch.
spirv.func @cooperative_matrix_muladd(
    %lhs : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix scope mismatch}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc :
        !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: a matrix-operands attribute (AccSat) is attached to an op
// whose matrices all have f16 elements; matrix operands are diagnosed as
// requiring integer element types.
spirv.func @cooperative_matrix_muladd_matrix_operands(
    %lhs : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
    %rhs : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
    %acc : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op Matrix Operands require all matrix element types to be Integer Types}}
  %res = spirv.KHR.CooperativeMatrixMulAdd %lhs, %rhs, %acc, <AccSat> :
        !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
        !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB> ->
          !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

//===----------------------------------------------------------------------===//
// Standard ops that can be used with CooperativeMatrix types
//===----------------------------------------------------------------------===//

// Type aliases shared by the tests in this split: 2x2 Subgroup-scoped i32
// matrices in the MatrixA and MatrixB roles.
!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>

// The same shape, scope and roles with an f32 element type.
!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>

// These tests are kept in the same order as the list of compatible ops in the
// SPV_KHR_cooperative_matrix extension spec.

// Signed integer negation applied to an i32 MatrixA and an i32 MatrixB.
// CHECK-LABEL: @snegate
spirv.func @snegate(%ma: !matA_i32, %mb: !matB_i32) "None" {
  // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
  %neg_a = spirv.SNegate %ma : !matA_i32
  %neg_b = spirv.SNegate %mb : !matB_i32
  spirv.Return
}

// Floating-point negation applied to an f32 MatrixA and an f32 MatrixB.
// CHECK-LABEL: @fnegate
spirv.func @fnegate(%ma: !matA_f32, %mb: !matB_f32) "None" {
  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
  %neg_a = spirv.FNegate %ma : !matA_f32
  %neg_b = spirv.FNegate %mb : !matB_f32
  spirv.Return
}

// Integer addition of each i32 matrix with itself, once per matrix role.
// CHECK-LABEL: @iadd
spirv.func @iadd(%ma: !matA_i32, %mb: !matB_i32) "None" {
  // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %sum_a = spirv.IAdd %ma, %ma : !matA_i32
  %sum_b = spirv.IAdd %mb, %mb : !matB_i32
  spirv.Return
}

// Floating-point addition of each f32 matrix with itself, once per matrix role.
// CHECK-LABEL: @fadd
spirv.func @fadd(%ma: !matA_f32, %mb: !matB_f32) "None" {
  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %sum_a = spirv.FAdd %ma, %ma : !matA_f32
  %sum_b = spirv.FAdd %mb, %mb : !matB_f32
  spirv.Return
}

// Integer subtraction of each i32 matrix from itself, once per matrix role.
// CHECK-LABEL: @isub
spirv.func @isub(%ma: !matA_i32, %mb: !matB_i32) "None" {
  // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %diff_a = spirv.ISub %ma, %ma : !matA_i32
  %diff_b = spirv.ISub %mb, %mb : !matB_i32
  spirv.Return
}

// Floating-point subtraction of each f32 matrix from itself, once per matrix
// role.
// CHECK-LABEL: @fsub
spirv.func @fsub(%ma: !matA_f32, %mb: !matB_f32) "None" {
  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %diff_a = spirv.FSub %ma, %ma : !matA_f32
  %diff_b = spirv.FSub %mb, %mb : !matB_f32
  spirv.Return
}

// Floating-point multiplication of each f32 matrix with itself, once per
// matrix role.
// CHECK-LABEL: @fmul
spirv.func @fmul(%ma: !matA_f32, %mb: !matB_f32) "None" {
  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %prod_a = spirv.FMul %ma, %ma : !matA_f32
  %prod_b = spirv.FMul %mb, %mb : !matB_f32
  spirv.Return
}

// Integer multiplication of each i32 matrix with itself, once per matrix role.
// CHECK-LABEL: @imul
spirv.func @imul(%ma: !matA_i32, %mb: !matB_i32) "None" {
  // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %prod_a = spirv.IMul %ma, %ma : !matA_i32
  %prod_b = spirv.IMul %mb, %mb : !matB_i32
  spirv.Return
}

// Floating-point division of each f32 matrix by itself, once per matrix role.
// CHECK-LABEL: @fdiv
spirv.func @fdiv(%ma: !matA_f32, %mb: !matB_f32) "None" {
  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %quot_a = spirv.FDiv %ma, %ma : !matA_f32
  %quot_b = spirv.FDiv %mb, %mb : !matB_f32
  spirv.Return
}

// Signed integer division of each i32 matrix by itself, once per matrix role.
// CHECK-LABEL: @sdiv
spirv.func @sdiv(%ma: !matA_i32, %mb: !matB_i32) "None" {
  // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %quot_a = spirv.SDiv %ma, %ma : !matA_i32
  %quot_b = spirv.SDiv %mb, %mb : !matB_i32
  spirv.Return
}

// Unsigned integer division of each i32 matrix by itself, once per matrix role.
// CHECK-LABEL: @udiv
spirv.func @udiv(%ma: !matA_i32, %mb: !matB_i32) "None" {
  // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
  // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
  %quot_a = spirv.UDiv %ma, %ma : !matA_i32
  %quot_b = spirv.UDiv %mb, %mb : !matB_i32
  spirv.Return
}

// Scales an f32 MatrixA by an f32 scalar. The output pattern names both
// operands explicitly, like the other binary-op patterns in this file, rather
// than letting one wildcard span the whole operand list.
// CHECK-LABEL: @matrix_times_scalar
spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
  // CHECK: spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, f32
  %p = spirv.MatrixTimesScalar %a, %b : !matA_f32, f32
  spirv.Return
}

// -----

// For binary arithmetic instructions with coop matrix operands, the types must
// match.

// Negative test: an integer add mixing a MatrixA operand with a MatrixB
// operand. The op is written in generic syntax, where each operand type is
// spelled out separately.
spirv.func @iadd(%ma: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
                 %mb: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
  // expected-error @+1 {{op requires the same type for all operands and results}}
  %sum = "spirv.IAdd"(%ma, %mb) :
    (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
    -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
  spirv.Return
}

// -----

// Negative test: a floating-point add mixing a MatrixA operand with a
// MatrixAcc operand, written in generic syntax with separate operand types.
spirv.func @fadd(%ma: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
                 %macc: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
  // expected-error @+1 {{op requires the same type for all operands and results}}
  %sum = "spirv.FAdd"(%ma, %macc) :
    (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
    -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
  spirv.Return
}

// -----

// Negative test: an f32 matrix is scaled by an f16 scalar, which is diagnosed
// because the element type and the scalar type differ.
spirv.func @matrix_times_scalar(%mat: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %scale: f16) "None" {
  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
  %scaled = spirv.MatrixTimesScalar %mat, %scale : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
  spirv.Return
}