// llvm/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp

//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a linalg transformation that splits a reduction
// dimension into a parallel dimension and a (smaller) reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

usingnamespacemlir;
usingnamespacemlir::linalg;

/// Entry point of the split-reduction transformation: splits a reduction
/// dimension of `op` into a parallel dimension and a reduction dimension.
///
/// \param b rewriter used to report the outcome.
/// \param op the Linalg op to transform.
/// \param controlSplitReductionFn presumably selects the split configuration
///        for `op` (ratio / insertion position) -- TODO confirm.
/// \param useAlloc presumably selects an allocation-based intermediate instead
///        of a tensor-based one -- TODO confirm.
/// \returns a failed `FailureOr` when the rewrite is not applied.
///
/// NOTE(review): no implementation of the transformation is present in this
/// file. An empty body here falls off the end of a function returning
/// `FailureOr<...>`, which is undefined behavior. Reporting a match failure
/// instead gives callers (including rewrite drivers) a well-defined
/// "not applied" result until the transformation is provided.
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  // `notifyMatchFailure` yields a failed `LogicalResult`, which converts to a
  // failed (empty) `FailureOr<SplitReductionResult>`.
  return b.notifyMatchFailure(op, "split reduction is not implemented");
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
///
/// Returns the indexing map of `opOperand` re-expressed over a loop space
/// with one extra dimension. The new dimension `kk` is inserted right after
/// the reduction dimension `k` (at position `reductionDimPos + 1`), and every
/// dimension after `k` is shifted up by one. Each use of `k` in the original
/// map becomes `k * reductionRatio + kk`.
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  MLIRContext *ctx = op.getContext();
  AffineExpr reductionDim = getAffineDimExpr(reductionDimPos, ctx);
  AffineExpr innerDim = getAffineDimExpr(reductionDimPos + 1, ctx);
  AffineMap map = op.getMatchingIndexingMap(&opOperand);

  // Identity over the original n loops, then make room for `kk` by shifting
  // every dim after `k` up by one:
  //   (d0, ..., dn-1) -> (d0, ..., dk, dk+2, ..., dn)   over n + 1 dims.
  AffineMap idMap = AffineMap::getMultiDimIdentityMap(map.getNumDims(), ctx);
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);

  // Substitute k -> k * ratio + kk. `kk` (d[reductionDimPos + 1]) is not used
  // by any other result after the shift, so this substitution is unambiguous.
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + innerDim,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);

  // Pull the operand's original access through the new loop space.
  return map.compose(composeMap);
}

/// Returns the indexing map of `opOperand` re-expressed over a loop space
/// with one extra dimension inserted at `reductionDimPos + 1`, so dims after
/// `reductionDimPos` are shifted up by one. The loop dimension
/// `d[reductionDimPos]` is added as a new result at result position
/// `reductionDimPos`, so the operand gains an extra dimension indexed by that
/// loop.
///
/// NOTE(review): `size` (presumably the extent of the inserted dimension) is
/// not needed to build the map and is not read here -- confirm against the
/// caller that computes the intermediate shape.
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  MLIRContext *ctx = op.getContext();
  AffineExpr reductionDim = getAffineDimExpr(reductionDimPos, ctx);
  AffineMap map = op.getMatchingIndexingMap(&opOperand);

  // Re-index the original map into the n + 1 loop space, leaving
  // d[reductionDimPos + 1] unused by existing results.
  AffineMap idMap = AffineMap::getMultiDimIdentityMap(map.getNumDims(), ctx);
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);

  // Expose d[reductionDimPos] as an additional result of the operand's map.
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation of the scaling-based split reduction.
/// Presumably the variant that rewrites indexing maps with
/// `scaleReductionDim` / `insertParallelDim` -- TODO confirm.
///
/// \param b rewriter used to report the outcome.
/// \param op the Linalg op to transform.
/// \param controlSplitReductionFn presumably selects the split configuration
///        for `op` -- TODO confirm.
/// \param useAlloc presumably selects an allocation-based intermediate --
///        TODO confirm.
/// \returns a failed `FailureOr` when the rewrite is not applied.
///
/// NOTE(review): no implementation of the transformation is present in this
/// file. An empty body falls off the end of a function returning
/// `FailureOr<...>` (undefined behavior), so report a match failure instead.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  // A failed `LogicalResult` converts to an empty `FailureOr`.
  return b.notifyMatchFailure(
      op, "split reduction by scaling is not implemented");
}

namespace {

/// Rewrite pattern that applies `splitReduction` to any LinalgOp, using the
/// control function and `useAlloc` flag it was constructed with.
struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // `FailureOr<T>` converts to `LogicalResult`: success iff a result exists.
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

/// Adds the split-reduction rewrite pattern to `patterns`.
/// \param controlSplitReductionFn forwarded to `splitReduction` for each
///        matched LinalgOp.
/// \param useAlloc forwarded to `splitReduction` for each matched LinalgOp.
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}