//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include <algorithm> usingnamespacemlir; usingnamespacemlir::linalg; /// Include the definitions of the copy operation interface. 
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" //===----------------------------------------------------------------------===// // Interface utility functions //===----------------------------------------------------------------------===// bool linalg::detail::canOpOperandsBeDroppedImpl( linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) { … } //===----------------------------------------------------------------------===// // CopyOpInterface implementation //===----------------------------------------------------------------------===// bool linalg::isaCopyOpInterface(LinalgOp linalgOp) { … } //===----------------------------------------------------------------------===// // FillOpInterface implementation //===----------------------------------------------------------------------===// std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) { … } //===----------------------------------------------------------------------===// // Elementwise Single Unary/Binary-OpInterface implementation //===----------------------------------------------------------------------===// static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp, unsigned arity) { … } bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) { … } bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) { … } //===----------------------------------------------------------------------===// // ContractionOpInterface implementation //===----------------------------------------------------------------------===// /// If the value is defined by a chain of unary side-effect-free ops, go up the /// use-def chain until the first value that isn't defined by such an op. // TODO: relax to multi-operands with constants, which are technically unary ops // as needed (e.g. add5). 
static Value getSourceSkipUnary(Value value) { … } bool mlir::linalg::detail::isContractionBody( Block &block, function_ref<bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs) { … } /// Returns true if the two operations are of the kinds specified by a pair of /// consecutive template arguments. template <typename AddOpTy, typename MulOpTy, typename... Args> static bool isPairTemplateImpl(Operation *add, Operation *mul) { … } /// Returns true if the block is a body of a contraction with the kinds of /// operations given pairwise by template arguments. template <typename... Args> static bool isContractionBody(Block &block) { … } /// Given an `indexingMap` and its corresponding `iterators`, returns /// the positions of the iterators of type `iter` that are indexed by /// the `indexingMap` as a permutation. This is useful to infer various /// subcomputations on a `LinalgOp`. This is performed by looking up /// each result in the `indexingMap` and determining whether: /// - It is a single AffineDimExpr. /// - It is the only result involving this AffineDimExpr. static llvm::SmallDenseSet<int64_t> findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef<utils::IteratorType> iterators, utils::IteratorType iter) { … } namespace { auto par = …; auto red = …; } // namespace /// Infer the iterator types from the init affine map. This looks at which dims /// are present in the map results, and returns an iterator types array with /// parallel types for dims that are present, and reduction types for dims that /// are not present. static FailureOr<SmallVector<utils::IteratorType>> inferIteratorsFromOutMap(AffineMap map) { … } /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form /// a matmul subcomputation within `linalgOp`. These dimensions are such that: /// 1. The m dimension is involved in an outer-product along LHS /// (i.e. it is a permutation on RES and LHS and does not appear in RHS). /// 2. 
The n dimension is involved in an outer-product along RHS /// (i.e. it is a permutation on RES and RHS and does not appear in LHS). /// 3. The k dimension appears as a permutation on LHS and RHS. /// 4. m, n and k appear only once in any given indexing map. /// 5. Optional batch dimensions that appear in all operands are captured. /// This allows e.g. detecting that some contraction is embedded within /// `linalgOp` with some orthogonal heuristic. static FailureOr<ContractionDimensions> inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, ArrayRef<utils::IteratorType> iterators) { … } FailureOr<ContractionDimensions> mlir::linalg::inferContractionDims(LinalgOp linalgOp) { … } FailureOr<ContractionDimensions> mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) { … } namespace mlir::linalg::detail { enum class MatchContractionResult { … }; } // namespace mlir::linalg::detail mlir::linalg::detail::MatchContractionResult mlir::linalg::detail::isContractionInterfaceImpl( Operation *op, mlir::linalg::ContractionDimensions *dimensions) { … } StringRef mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) { … } bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { … } /// Verify that a LinalgOp `op` is a contraction. /// A Linalg contraction is defined in general terms: /// 1. Has 2 input and 1 output shapes. /// 2. Has at least one reduction dimension. /// 3. Has only projected permutation indexing maps. /// 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary /// operations that may change the type (e.g. for mixed-precision). /// As a consequence, when vectorization of such an op occurs, the only special /// behavior is that the (unique) MulOpType is vectorized into a /// `vector.contract`. All other ops are handled in a generic fashion. 
/// In the future, we may wish to allow more input arguments and elementwise and /// constant operations that do not involve the reduction dimension(s). LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) { … } //===----------------------------------------------------------------------===// // ConvolutionOpInterface implementation //===----------------------------------------------------------------------===// /// Of the two given expressions, returns the one that is of type T (`lhs` gets /// preference over `rhs`). template <typename T> static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) { … } namespace { /// Walk the indexing expressions for input of a convolution operation to verify /// that it is of the right form, either /// - AffineDimExpr /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))? /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)* /// /// Classifies the AffineDimExpr as convolved dimensions or unconvolved /// dimensions and verifies each dimension occurs only once. struct ConvAccessExprWalker : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> { … }; } // namespace static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) { … } static SmallVector<int64_t, 2> getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) { … } /// Classifies dimensions in the `linalgOp` used by a convolution /// subcomputation, as captured by `inputExprWalker`. If /// `allowEmptyConvolvedDims` is not set, this will fail if there is not /// at least one convolved dimension pair (output image + filter loop). Convolution /// dimensions are specified in sorted order, and strides match the order of /// the filter loop dimensions, while the dilations match the order of the /// output image dimensions. 
static FailureOr<ConvolutionDimensions> inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims) { … } /// Find at least 1 parallel (output_image) and reduction (filter_loop) /// dimension candidates that form a convolution subcomputation within /// `linalgOp`. The LHS is assumed to be the convolution input while the /// RHS is assumed to be the filter. /// These dimensions are such that: /// 1. Optional batch dimensions that appear in the input and filter. /// 2. The output_image dimension is involved in a cross-correlation along LHS /// (i.e. it is a permutation on RES and LHS and has an associated /// filter_loop in RHS). /// 3. Optional output_channel dimension is involved in an outer-product along /// RHS (i.e. it is a permutation on RES and RHS and does not appear in /// LHS). /// 4. Optional input_channel dimension appears as a permutation on LHS and /// RHS. /// 5. The filter_loop dimension appears as a permutation on the RHS and /// represents the shape of the kernel cross-correlated along a /// corresponding output_image dim. /// 6. The input_channel dimension appears as a permutation on LHS and RHS. /// 7. All dimensions appear only once in any given indexing map. /// This allows e.g. detecting that some convolution is embedded within /// `linalgOp` with some orthogonal heuristic. /// When multiple dimension occurrences exist that match any classification, /// the indices are returned in sorted order. /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty. 
FailureOr<ConvolutionDimensions> mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) { … } namespace mlir::linalg::detail { enum class MatchConvolutionResult { … }; } // namespace mlir::linalg::detail mlir::linalg::detail::MatchConvolutionResult mlir::linalg::detail::isConvolutionInterfaceImpl( Operation *op, ConvolutionDimensions *dimensions, bool allowEmptyConvolvedDims) { … } StringRef mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) { … } bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims) { … } LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { … } //===----------------------------------------------------------------------===// // FillOpInterface implementation //===----------------------------------------------------------------------===// enum class MatchFillResult { … }; static MatchFillResult isFillInterfaceImpl(Operation *op) { … } LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { … } //===----------------------------------------------------------------------===// // StructuredOpInterface implementation //===----------------------------------------------------------------------===// SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { … } SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() { … } SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { … } SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() { … } /// Visitor to check if any of the given set of positions from AffineDimExprs /// are used within an AffineExpr. 
struct HasAffineDimExprVisitor : public AffineExprVisitor<HasAffineDimExprVisitor, bool> { … }; static std::pair<int64_t, int64_t> getResultsPositionInLoopsToShapeMap(LinalgOp &op) { … } LogicalResult LinalgOp::reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { … } /// Return the index in the indexingMaps vector that corresponds to this /// `opOperand`. int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) { … } LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { … }