llvm/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

//===- LinalgInterfaces.td - Linalg Interfaces Declaration -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the structured interfaces for Linalg ops.
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_IR_LINALGINTERFACES
#define LINALG_IR_LINALGINTERFACES

include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/IR/OpBase.td"

// The 'LinalgContractionOpInterface' provides access to the
// 'ContractionOpInterface'.
def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
  let description = [{
    Interface for Linalg ops that perform a contraction. In general terms, a
    Linalg contraction:
      1. Has 2 inputs and 1 output.
      2. Has at least one reduction dimension.
      3. Has only projected permutation indexing maps.
      4. Has a body computing `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
         (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 are scalar unary
         operations that may change the type (e.g. for mixed precision).
    Consequently, vectorizing such an op needs only one special case: the
    (unique) MulOpType is vectorized into a `vector.contract`. Every other op
    is handled generically.
    More input arguments, as well as elementwise and constant operations that
    do not involve the reduction dimension(s), may be allowed in the future.
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyContractionInterface($_op); }];
  let verifyWithRegions = 1;
  let methods = [
    InterfaceMethod<
      /*desc=*/"Returns the left-hand side operand (operand #0).",
      /*retTy=*/"Value",
      /*methodName=*/"lhs",
      /*args=*/(ins),
      /*methodBody=*/[{ return $_op->getOperand(0); }]
    >,
    InterfaceMethod<
      /*desc=*/"Returns the right-hand side operand (operand #1).",
      /*retTy=*/"Value",
      /*methodName=*/"rhs",
      /*args=*/(ins),
      /*methodBody=*/[{ return $_op->getOperand(1); }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a row-major matmul.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isRowMajorMatmul",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isRowMajorMatmul($_op.getIndexingMaps());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a column-major
        matmul.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isColumnMajorMatmul",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isColumnMajorMatmul($_op.getIndexingMaps());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a row-major batch
        matmul.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isRowMajorBatchMatmul",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a vector-matrix
        multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isVecmat",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isVecmat($_op.getIndexingMaps());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a batched
        vector-matrix multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isBatchVecmat",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isBatchVecmat($_op.getIndexingMaps());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a matrix-vector
        multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isMatvec",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isMatvec($_op.getIndexingMaps());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true when the op's indexing maps describe a batched
        matrix-vector multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isBatchMatvec",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isBatchMatvec($_op.getIndexingMaps());
      }]
    >,
  ];
}

def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
  let description = [{
    Interface for Linalg convolution ops. In general terms, a convolution:
    1. Has an `image` operand and a `filter` operand.
    2. Has a single `output` operand.
    3. Has input indexing maps whose expressions satisfy
    ```
       AffineExpr ::== AffineDimExpr | ConvolvedExpr
       ConvolvedExpr ::== MulExpr (`+` MulExpr)+
       MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
    ```
    4. Has projected permutation maps for the filter and the output.
    5. Has loops that can each be classified as one of:
       - a loop over the batch dimension,
       - a loop over output image dimensions,
       - a loop over output channel dimensions,
       - a loop over convolved filter dimensions,
       - a loop over the input channel dimension.
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyConvolutionInterface($_op); }];
  let methods = [
    InterfaceMethod<
      /*desc=*/"Return the image operand (operand #0 by default).",
      /*retTy=*/"Value",
      /*methodName=*/"image",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return $_op->getOperand(0); }]
    >,
    InterfaceMethod<
      /*desc=*/"Return the filter operand (operand #1 by default).",
      /*retTy=*/"Value",
      /*methodName=*/"filter",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return $_op->getOperand(1); }]
    >,
  ];
}

def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
  let description = [{
    Interface for Linalg fill ops. In general terms, a fill operation:
    1. Has a scalar `value` operand.
    2. Has a single `output` operand.
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyFillInterface($_op); }];
  let methods = [
    InterfaceMethod<
      /*desc=*/"Return the fill value (operand #0 by default).",
      /*retTy=*/"Value",
      /*methodName=*/"value",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return $_op->getOperand(0); }]
    >,
    InterfaceMethod<
      /*desc=*/"Return the output operand (operand #1 by default).",
      /*retTy=*/"Value",
      /*methodName=*/"output",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return $_op->getOperand(1); }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the first result of the op, or a null Value when the op has no
        results.
      }],
      /*retTy=*/"Value",
      /*methodName=*/"result",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        Operation *fillOp = $_op.getOperation();
        // No results: hand back a null Value rather than indexing out of range.
        if (fillOp->getNumResults() == 0)
          return Value();
        return fillOp->getResult(0);
      }]
    >,
  ];
}

// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
    : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
  let description = [{
    Interface implemented by structured Linalg operations. It extends
    `DestinationStyleOpInterface` and describes the computation through one
    indexing map per operand (`getIndexingMaps`), one iterator type per loop
    (`getIteratorTypesArray`) and a single-block payload region (`getBlock`).
  }];
  let cppNamespace = "::mlir::linalg";
  let methods = [
    //===------------------------------------------------------------------===//
    // Loop types handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the number of parallel loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumParallelLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::count($_op.getIteratorTypesArray(),
                           utils::IteratorType::parallel);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the dims that are parallel loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getParallelDims",
      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return findPositionsOfType($_op.getIteratorTypesArray(),
                                   utils::IteratorType::parallel, res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of reduction loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumReductionLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::count($_op.getIteratorTypesArray(),
                           utils::IteratorType::reduction);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the dims that are reduction loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getReductionDims",
      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return findPositionsOfType($_op.getIteratorTypesArray(),
                                   utils::IteratorType::reduction, res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the total number of loops within the current operation.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getIteratorTypesArray().size();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true if the current operation has only one loop and it's a
        reduction loop.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasSingleReductionLoop",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto iters = $_op.getIteratorTypesArray();
        return iters.size() == 1 &&
               llvm::count(iters, utils::IteratorType::reduction) == 1;
      }]>,
    //===------------------------------------------------------------------===//
    // Input and Init arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return true if the payload uses the value loaded from `opOperand`. This
        is useful to avoid loading from "write-only" memory that may be
        uninitialized, as well as properly cloning "read-write" operands.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"payloadUsesValueFromOperand",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        unsigned bbArgNumber = opOperand->getOperandNumber();
        // The value is used iff the block argument matching the operand (same
        // position as the operand number) has at least one use.
        return !getBlock()->getArgument(bbArgNumber).use_empty();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an init tensor. This is true when it is
        an output tensor operand whose value is used in the payload region.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isInitTensor",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!$_op.isDpsInit(opOperand))
          return false;
        return payloadUsesValueFromOperand(opOperand);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `opOperand` rank or zero for scalars or vectors not wrapped
        within a tensor or a memref.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getRank",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        Type t = opOperand->get().getType();
        // A VectorType is an elemental type, do not consider its rank for the operand.
        if (isa<VectorType>(t))
          return 0;
        // Tensor and Memref container types have a rank.
        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
          // Failsafe.
          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
                 "expected a ranked tensor or memref in LinalgInterface::getRank");
          return shapedType.getRank();
        }
        return 0;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input block arguments of the region.
      }],
      /*retTy=*/"Block::BlockArgListType",
      /*methodName=*/"getRegionInputArgs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the output block arguments of the region.
      }],
      /*retTy=*/"Block::BlockArgListType",
      /*methodName=*/"getRegionOutputArgs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getBlock()->getArguments().take_back($_op.getNumDpsInits());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `opOperand` shape or an empty vector for scalars or vectors
        not wrapped within a tensor or a memref.
      }],
      /*retTy=*/"ArrayRef<int64_t>",
      /*methodName=*/"getShape",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        Type t = opOperand->get().getType();
        // A VectorType is an elemental type, do not consider its rank for the operand.
        if (isa<VectorType>(t))
          return {};
        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
          // Failsafe.
          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
                 "expected a ranked tensor or memref in LinalgInterface::getShape");
          return shapedType.getShape();
        }
        return {};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the block argument for an `opOperand`.
      }],
      /*retTy=*/"BlockArgument",
      /*methodName=*/"getMatchingBlockArgument",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        return getBlock()->getArgument(opOperand->getOperandNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the operand for a `blockArgument`.
      }],
      /*retTy=*/"OpOperand *",
      /*methodName=*/"getMatchingOpOperand",
      /*args=*/(ins "BlockArgument":$blockArgument),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(blockArgument.getOwner() == getBlock());
        return &this->getOperation()->getOpOperand(
            blockArgument.getArgNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input or output indexing map for `opOperand`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getMatchingIndexingMap",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        auto indexingMaps =
          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
        return *(indexingMaps.begin() + opOperand->getOperandNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing map for a `result`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getIndexingMapMatchingResult",
      /*args=*/(ins "OpResult":$result),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(result.getOwner() == this->getOperation());
        auto indexingMaps =
          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
        return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
                 result.getResultNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the operand of the region terminator that yields the value
        corresponding to the output (init) `opOperand`.
      }],
      /*retTy=*/"OpOperand *",
      /*methodName=*/"getMatchingYieldValue",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        int64_t resultIndex =
            opOperand->getOperandNumber() - $_op.getNumDpsInputs();
        // Bound by the number of inits, not the number of results: the
        // terminator yields one value per init, and ops with memref inits may
        // have no results at all.
        assert(resultIndex >= 0 && resultIndex < $_op.getNumDpsInits() &&
               "expected an init operand");
        Operation *yieldOp = getBlock()->getTerminator();
        return &yieldOp->getOpOperand(resultIndex);
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other interface methods.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the single block constituting the body of the operation by
        calling the getBody method on the concrete operation.
      }],
      /*retTy=*/"Block*",
      /*methodName=*/"getBlock",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Assume the concrete operation implements the
        // SingleBlockImplicitTerminator trait.
        return $_op.getBody();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return iterator types in the current operation.

        The default implementation assumes that the operation has an attribute
        `iterator_types`, but that is not always the case. Sometimes iterator
        types can be inferred from other parameters, and in such cases the
        default getIteratorTypesArray should be overridden.
      }],
      /*retTy=*/"SmallVector<utils::IteratorType>",
      /*methodName=*/"getIteratorTypesArray",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = $_op.getIteratorTypes()
                         .template getAsValueRange<IteratorTypeAttr,
                                                   utils::IteratorType>();
        return {range.begin(), range.end()};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the indexing map is depending on the current op instance.
        This means that the indexing map is dynamically synthesized by using the
        op instance's concrete attributes, instead of being static for all
        instances of the same op kind.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasDynamicIndexingMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return false; }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Verify all attributes used by indexing maps are valid.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"verifyIndexingMapRequiredAttributes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return success(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps attribute within the current operation.
      }],
      /*retTy=*/"ArrayAttr",
      /*methodName=*/"getIndexingMaps"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps within the current operation.
      }],
      /*retTy=*/"SmallVector<AffineMap>",
      /*methodName=*/"getIndexingMapsArray",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = $_op.getIndexingMaps()
          .template getAsValueRange<AffineMapAttr>();
        return {range.begin(), range.end()};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if any of the operands has a dynamic shape.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasDynamicShape",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the name registered for this op when lowering to an external
        library call.
      }],
      /*retTy=*/"std::string",
      /*methodName=*/"getLibraryCallName",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getLibraryCallName();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
         Return whether the op accesses the iteration indices.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasIndexSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/""
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return op operands that have a corresponding argument in the basic block.
        By default, the block should have an argument for each operand, but there
        are exceptions. For example, in `map` the output operand isn't used in
        the block.
      }],
      /*retTy=*/"::llvm::SmallVector<OpOperand *>",
      /*methodName=*/"getOpOperandsMatchingBBargs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        ::llvm::SmallVector<OpOperand *> result;
        result.reserve($_op->getNumOperands());
        llvm::transform(
          this->getOperation()->getOpOperands(),
          std::back_inserter(result),
          [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Given a dimension of the iteration space of a Linalg operation, finds an
        operand in the operation that is defined on such dimension. Returns
        whether such operand was found or not. If found, also returns the
        operand value and the dimension position within the operand.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"mapIterationSpaceDimToOperandDim",
      /*args=*/(ins "unsigned":$dimPos,
                    "::mlir::Value &":$operand,
                    "unsigned &":$operandDimPos),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Retrieve the operand and its dimension position from the first
        // operand with a permutation map that is defined on such dimension.
        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
          if (idxMap.isProjectedPermutation()) {
            if (auto mayOperandDim = idxMap.getResultPosition(
                getAffineDimExpr(dimPos, idxMap.getContext()))) {
              operand = $_op->getOperand(i);
              operandDimPos = *mayOperandDim;
              return success();
            }
          }
        }

        return failure();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Given a dimension of the iteration space of a Linalg operation, finds
        all the operands in the operation that are defined on such dimension.
        Returns all the operand values found and their dimension positions in
        `operandDimPairs`.
      }],
      /*retTy=*/"void",
      /*methodName=*/"mapIterationSpaceDimToAllOperandDims",
      /*args=*/(ins "unsigned":$dimPos,
                    "mlir::SmallVectorImpl<std::pair<Value, unsigned>>&":$operandDimPairs),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Only operands with projected-permutation maps are considered.
        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
          if (idxMap.isProjectedPermutation()) {
            if (auto mayOperandDim = idxMap.getResultPosition(
                getAffineDimExpr(dimPos, idxMap.getContext()))) {
              operandDimPairs.push_back({$_op->getOperand(i), *mayOperandDim});
            }
          }
        }
      }]
    >,
    //===------------------------------------------------------------------===//
    // Linalg generalization hooks.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Hook to provide a custom AffineMap used to compute all the operand
        subshapes given loop bounds. This is used to answer the question: "given
        an iteration space over the codomain, what are the subshapes of the
        operands involved in the computation".
        The default behavior is to just concatenate all the indexing maps.
        A custom AffineMap allows providing a map that can be used to
        compute subshapes even in cases where the concatenation of indexing maps
        (i.e. the data traversal order) is not a simple permutation of the loop
        traversal order. It is then possible to define ops with skewed data
        traversal order for which we can still easily compute hyperrectangular
        loop bounds and subviews.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getLoopsToShapesMap",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto maps = $_op.getIndexingMapsArray();
        return concatAffineMaps(maps);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Hook to provide a custom AffineMap used to construct the
        hyperrectangular loop iteration space given all the operand subshapes.
        This is used to answer the question:
        "Given a list of operand ranges, what is the subportion of the iteration
        space involved in the computation".
        This is the inverse problem of `getLoopsToShapesMap`.
        Return the empty AffineMap when such an AffineMap cannot be constructed.
        The default behavior is based on a very simple inference procedure that
        only works with permutation affine maps.
        A more advanced Tensor-Comprehension like inference is possible but has
        proven to be ambiguous in unfavorable cases.
        A safer and more robust alternative is to allow each op to define
        its own AffineMap.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getShapesToLoopsMap",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return inversePermutation(getLoopsToShapesMap());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Checks if the given operands can be dropped, and the remaining
        operands can still compute the bounds of the op.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"canOpOperandsBeDropped",
      /*args=*/(ins "ArrayRef<OpOperand *>":$droppedOperands),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Like `getShape`, but only returns statically-known information, without
        generating any new IR. For each shape dimension, returns >=0 if that
        dimension is statically known, or ShapedType::kDynamic otherwise.
      }],
      /*retTy=*/"SmallVector<int64_t>",
      /*methodName=*/"getStaticShape",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<int64_t> res;
        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
          llvm::append_range(res, getShape(&opOperand));
        return res;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns the statically-known loop ranges. Composes
        `getShapesToLoopsMap()` with the result of `getStaticShape`.
        Returns ShapedType::kDynamic for non-statically-known loop ranges.
        This is expected to be called by a valid Linalg op.
      }],
      /*retTy=*/"SmallVector<int64_t, 4>",
      /*methodName=*/"getStaticLoopRanges",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<int64_t> viewSizes = getStaticShape();
        AffineMap invertedMap = getShapesToLoopsMap();
        assert(invertedMap && "expected a valid Linalg op to call the method");
        return invertedMap.compose(viewSizes);
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other static interface methods.
    //===------------------------------------------------------------------===//
    StaticInterfaceMethod<
      /*desc=*/[{
        Returns the region builder for constructing the body for linalg.generic.
        Returns a null function if this named op does not define a region
        builder.
      }],
      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
      /*methodName=*/"getRegionBuilder",
      (ins),
      [{ return ConcreteOp::getRegionBuilder(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if all the indexing maps are projected permutations.
        Otherwise return false.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasOnlyProjectedPermutations",
      (ins),
      [{
        return llvm::all_of($_op.getIndexingMapsArray(),
                            [](AffineMap map) { return map.isProjectedPermutation(); });
      }]
    >
  ];

  let extraClassDeclaration = [{
    /// Return the flat list of all operand dimension sizes in the order they
    /// appear in the operands.
    SmallVector<OpFoldResult> createFlatListOfOperandDims(OpBuilder &, Location);

    /// Return the flat list of all operands' static dimension sizes in the
    /// order they appear in the operands. All operand dimension sizes have to
    /// be statically known.
    SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();

    /// Create the loop ranges to materialize the computation over the current
    /// operands. This is done by applying `getShapesToLoopsMap` to
    /// `createFlatListOfOperandDims`.
    SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);

    /// Compute the static loop sizes necessary to vectorize the computation.
    /// This is done by applying `getShapesToLoopsMap` to
    /// `createFlatListOfOperandStaticDims`.
    SmallVector<int64_t, 4> computeStaticLoopSizes();

    /// Returns the value that expresses the shape of the output in terms of
    /// shape of the input operands where possible.
    LogicalResult reifyResultShapes(OpBuilder &b,
        ReifiedRankedShapedTypeDims &reifiedReturnShapes);

    /// Return the index in the indexingMaps vector that corresponds to this
    /// `opOperand`.
    int64_t getIndexingMapIndex(OpOperand *opOperand);
  }];

  let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
  let verifyWithRegions = 1;
}

def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
  let description = [{
    Interface for operations that aggregate several simpler operations and can
    be decomposed into a sequence of them.
  }];
  let cppNamespace = "::mlir::linalg";
  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Decompose the operation into simpler operations built with `b`.

        On success, returns one `Value` per result of the original operation,
        in the same order as the original results, so the returned vector can
        be passed straight to
        `RewriterBase::replaceOp(this, returnedValues)`.
        The default implementation does not decompose anything and returns
        failure.
      }],
      /*retTy=*/"FailureOr<SmallVector<Value>>",
      /*methodName=*/"decomposeOperation",
      /*args=*/(ins "OpBuilder &":$b),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return failure(); }]
    >
  ];
}

#endif // LINALG_IR_LINALGINTERFACES