llvm/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp

//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.scan' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE

usingnamespacemlir;
usingnamespacemlir::vector;

/// This function checks to see if the vector combining kind
/// is consistent with the integer or float element type.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {}

namespace {
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
///
/// The source is unrolled along `reduction_dim` one slice at a time:
///   - inclusive: out[0] = in[0],   out[i] = out[i-1] <kind> in[i]
///   - exclusive: out[0] = init,    out[i] = out[i-1] <kind> in[i-1]
/// The second result is the last computed slice, reshaped to the type of the
/// initial value.
///
/// Example:
///
/// ```
///   %0:2 = vector.scan <add>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
///
/// is converted to:
///
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.addi %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.addi %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    Location loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    Type elType = destType.getElementType();

    // The slice-by-slice unrolling below needs every dimension size to be
    // known statically.
    if (destType.isScalable())
      return rewriter.notifyMatchFailure(scanOp,
                                         "scalable vectors are not supported");

    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return rewriter.notifyMatchFailure(
          scanOp, "combining kind does not match the element type");

    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    bool initialValueIs0D = initialValueType.getRank() == 0;

    // Zero-filled accumulator for the first result; every slice along
    // `reductionDim` is overwritten by the loop below.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, destType, rewriter.getZeroAttr(destType));

    // Each iteration handles one unit-sized slice along `reductionDim` and
    // the full extent of every other dimension.
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    SmallVector<int64_t> strides(destRank, 1);

    Value lastInput;
    Value lastOutput;
    for (int64_t i = 0, e = destShape[reductionDim]; i < e; ++i) {
      offsets[reductionDim] = i;
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, scanOp.getSource(), offsets, sizes, strides);

      Value output;
      if (i != 0) {
        // Inclusive scans combine with the current slice, exclusive scans
        // with the previous one.
        Value operand = inclusive ? input : lastInput;
        output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
                                            lastOutput, operand);
      } else if (inclusive) {
        output = input;
      } else if (initialValueIs0D) {
        // NOTE(review): shape_cast is avoided for 0-D vectors (older MLIR
        // releases rejected it); broadcast into the 1-element slice instead.
        output = rewriter.create<vector::BroadcastOp>(
            loc, input.getType(), scanOp.getInitialValue());
      } else {
        output = rewriter.create<vector::ShapeCastOp>(
            loc, input.getType(), scanOp.getInitialValue());
      }

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastInput = input;
      lastOutput = output;
    }

    // Reshape the last computed slice into the initial-value type (the
    // destination shape with `reductionDim` dropped).
    Value reduction;
    if (initialValueIs0D) {
      Value element = rewriter.create<vector::ExtractOp>(
          loc, lastOutput, ArrayRef<int64_t>{0});
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, element);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};
} // namespace

/// Registers the pattern that lowers `vector.scan` into arith ops and strided
/// slice extraction/insertion.
void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
}