llvm/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

usingnamespacemlir;
usingnamespacemlir::tosa;

// File-local helpers and rewrite patterns. Everything declared inside this
// anonymous namespace has internal linkage.
namespace {

/// Apply the given transformation \p toApply to every element of the tensor
/// \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType and mapped by
/// \p toApply to \p TargetValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values whose element type is \p targetType.
///
/// NOTE(review): the body is empty as written. This function returns a
/// DenseElementsAttr, and flowing off the end of a non-void function is
/// undefined behavior if it is ever called. It is also explicitly instantiated
/// below, so the empty body is compiled.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {}

// Explicit instantiation for the APFloat -> APFloat case with a FloatType
// target element type.
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

/// Checks whether the element type contained in \p toCheck is a float type.
/// \p location and \p rewriter are presumably used to report a match failure
/// when it is not (TODO confirm).
///
/// NOTE(review): the body is empty as written; reaching the end of this
/// non-void function is undefined behavior if it is called.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {}

/// Checks whether \p toCheck is a dense TOSA constant tensor. \p location and
/// \p rewriter are presumably used to report a match failure otherwise
/// (TODO confirm).
///
/// NOTE(review): the body is empty as written; reaching the end of this
/// non-void function is undefined behavior if it is called.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {}

/// Checks whether \p toCheck is a dense TOSA constant tensor with a float
/// element type. \p location and \p rewriter are presumably used to report a
/// match failure otherwise (TODO confirm).
///
/// NOTE(review): the body is empty as written; reaching the end of this
/// non-void function is undefined behavior if it is called.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {}

/// Heuristic to decide when to replace a unary operation on a constant with the
/// folded value.
/// Folding operations on constants can lead to increased memory usage
/// whenever the input cannot be replaced but a new constant is inserted. Hence,
/// this currently suggests folding only when the memory impact is
/// negligible.
/// \param unaryOp The unary operation being considered for folding.
/// \param values The constant input of \p unaryOp.
/// \returns Whether folding should be applied.
///
/// NOTE(review): the body is empty as written; reaching the end of this
/// non-void function is undefined behavior if it is called.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {}

/// Helper for transposition over an already-extracted range of element values
/// \p data. \p inputType and \p outputType are the shaped types before and after
/// the permutation \p permValues. Presumably called by transpose() below for
/// each supported element representation (TODO confirm).
///
/// NOTE(review): the body is empty as written; if this template is ever
/// instantiated and called, reaching the end is undefined behavior.
template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {}

// A type-specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible, to avoid allocating a large number of
// Attribute objects.
// NOTE(review): the body is empty as written; reaching the end of this
// non-void function is undefined behavior if it is called.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {}

// Rewrite pattern rooted on tosa::TransposeOp. By its name, it is meant to fold
// transposes of constants.
// NOTE(review): the struct is empty as written. It declares no constructor (the
// base-class constructors are not inherited without a using-declaration) and
// overrides no pattern hooks (e.g. matchAndRewrite).
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {};

// Rewrite pattern rooted on ReciprocalOp. By its name, it is meant to fold
// reciprocals of constants.
// NOTE(review): the struct is empty as written. It declares no constructor (the
// base-class constructors are not inherited without a using-declaration) and
// overrides no pattern hooks (e.g. matchAndRewrite).
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {};

/// Returns the per-axis position (one coordinate per dimension of
/// \p tensorShape) of the element located at linear index \p index.
/// The linearization order is not visible here; confirm it matches
/// getIndexFromPosition and the element storage order.
/// NOTE(review): the body is empty as written; reaching the end of this
/// non-void function is undefined behavior if it is called.

llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {}

/// Returns the linear index of the element located at the per-axis
/// \p position within a tensor of shape \p tensorShape. It is intended to be
/// the inverse of getPositionFromIndex (TODO confirm).
/// NOTE(review): the body is empty as written; reaching the end of this
/// non-void function is undefined behavior if it is called.

int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {}

/// Presumably computes the reduction selected by \p OperationType over
/// \p oldTensorAttr (with shape \p oldShape) along \p reductionAxis, for the
/// output element identified by \p reductionIndex. The result is an
/// llvm::APInt, which suggests integer element types only (TODO confirm).
/// NOTE(review): the body is empty as written; if this template is ever
/// instantiated and called, reaching the end is undefined behavior.
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {}

/// Rewrite pattern template rooted on \p OperationType. By its name, it is meant
/// to fold reductions of constant tensors.
/// NOTE(review): the struct is empty as written. It declares no constructor (the
/// base-class constructors are not inherited without a using-declaration) and
/// overrides no pattern hooks (e.g. matchAndRewrite).
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {};

} // namespace

/// Out-of-line definition of the mlir::tosa entry point meant to register the
/// constant-reduction rewrite patterns into \p patterns for context \p ctx.
/// \p aggressiveReduceConstant presumably controls how eagerly reductions of
/// constants are folded (TODO confirm against the declaration in Passes.h).
/// NOTE(review): the body is empty as written. Nothing is added to \p patterns,
/// and all three parameters are unused.
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {}

/// Out-of-line definition of the mlir::tosa entry point meant to register the
/// constant-transpose folding pattern into \p patterns for context \p ctx.
/// NOTE(review): the body is empty as written. Nothing is added to \p patterns,
/// and both parameters are unused.
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {}

/// Out-of-line definition of the mlir::tosa entry point meant to register the
/// constant-reciprocal folding pattern into \p patterns for context \p ctx.
/// NOTE(review): the body is empty as written. Nothing is added to \p patterns,
/// and both parameters are unused.
void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {}