#include <functional>
#include <numeric>
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"
usingnamespacemlir;
usingnamespacemlir::tosa;
namespace {
/// Produces a new DenseElementsAttr derived from `toTransform` by way of the
/// conversion callback `toApply`, which maps a `const SrcValType &` to a
/// `TargetValType`.
///
/// \param toTransform the source constant attribute.
/// \param toApply     per-value conversion callback.
/// \param targetType  type handed to the implementation; the name suggests it
///                    is the element type of the resulting attribute.
/// \returns the transformed DenseElementsAttr.
// NOTE(review): the body is elided in this view. From the name, this presumably
// invokes `toApply` once per element of `toTransform` and builds a result with
// the same shape and element type `targetType` — confirm against the
// implementation.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<TargetValType(const SrcValType &)> &toApply,
TargetType targetType) { … }
// Explicit instantiation of applyElementWise for APFloat -> APFloat callbacks
// with a FloatType target.
// NOTE(review): the reason for the explicit instantiation is not visible here.
// The template definition is in this same translation unit, so implicit
// instantiation would normally suffice — confirm whether this is still needed.
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &)> &toApply,
FloatType targetType);
/// Validation helper for rewrite patterns: inspects the tensor value `toCheck`
/// and returns a LogicalResult.
///
/// \param toCheck  tensor-typed value to validate.
/// \param location TOSA op the check is performed on behalf of.
/// \param rewriter rewriter used by the calling pattern.
// NOTE(review): the body is elided. By its name, this presumably returns
// failure (reported through `rewriter`, likely via notifyMatchFailure on
// `location`) when the element type of `toCheck` is not a floating-point
// type, and success otherwise — confirm.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
PatternRewriter &rewriter) { … }
/// Validation helper for rewrite patterns: inspects the tensor value `toCheck`
/// and returns a LogicalResult.
///
/// \param toCheck  tensor-typed value to validate.
/// \param location TOSA op the check is performed on behalf of.
/// \param rewriter rewriter used by the calling pattern.
// NOTE(review): the body is elided. By its name, this presumably returns
// failure (reported through `rewriter`) unless `toCheck` is produced by a TOSA
// constant op holding a dense elements attribute — confirm.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter) { … }
/// Validation helper for rewrite patterns: inspects the tensor value `toCheck`
/// and returns a LogicalResult.
///
/// \param toCheck  tensor-typed value to validate.
/// \param location TOSA op the check is performed on behalf of.
/// \param rewriter rewriter used by the calling pattern.
// NOTE(review): the body is elided. The name suggests it combines the checks of
// notifyIfNotFloat and notifyIfNoTosaDenseConstantTensor, succeeding only for
// a constant, dense, floating-point TOSA tensor — confirm.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
TosaOp location,
PatternRewriter &rewriter) { … }
/// Decides whether the TOSA op `unaryOp`, whose constant input is `values`,
/// should be folded.
///
/// \returns true if folding should proceed.
// NOTE(review): the body is elided, so the exact heuristic is not visible. It
// presumably weighs folding cost (e.g. number of users of the constant, or
// whether `values` is a splat) — confirm before relying on it.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) { … }
/// Builds a DenseElementsAttr from the element range `data` for a transpose
/// from `inputType` to `outputType` under the permutation `permValues`.
///
/// \param data       range of element values, of generic type `RangeType`.
/// \param inputType  shaped type the data is laid out in.
/// \param outputType shaped type of the result.
/// \param permValues dimension permutation.
// NOTE(review): the body is elided. Presumably each element is moved to the
// position obtained by permuting its coordinates with `permValues`, assuming
// row-major layout — confirm.
template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) { … }
/// Transposes the constant `attr` from `inputType` to `outputType` using
/// `permValues`, returning the result as a DenseElementsAttr.
// NOTE(review): the body is elided. Presumably it dispatches on the element
// type of `attr` to the matching transposeType instantiation, and may return a
// null attribute for unsupported element types — confirm.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) { … }
/// Rewrite pattern rooted at tosa::TransposeOp.
// NOTE(review): the body is elided. By its name, it presumably replaces a
// transpose of a constant tensor with a new, pre-transposed constant — confirm.
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> { … };
/// Rewrite pattern rooted at ReciprocalOp.
// NOTE(review): the body is elided. By its name, it presumably replaces the
// reciprocal of a constant float tensor with a constant holding the computed
// reciprocals — confirm.
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> { … };
/// Maps a flat element `index` within a tensor of shape `tensorShape` to a
/// vector of per-dimension coordinates.
// NOTE(review): the body is elided. Presumably this uses row-major
// (last-dimension-fastest) order and is the inverse of getIndexFromPosition —
// confirm.
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) { … }
/// Maps per-dimension coordinates `position` within a tensor of shape
/// `tensorShape` to a flat element index.
// NOTE(review): the body is elided. Presumably this uses row-major order and is
// the inverse of getPositionFromIndex — confirm.
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
llvm::ArrayRef<int64_t> tensorShape) { … }
/// Computes one output value of a constant reduction as an llvm::APInt.
///
/// \param oldTensorAttr constant input of the reduction.
/// \param oldShape      shape of `oldTensorAttr`.
/// \param reductionAxis axis being reduced.
/// \param reductionIndex identifies which output element is being computed.
// NOTE(review): the body is elided. Presumably it combines every element of
// `oldTensorAttr` that maps to output element `reductionIndex` along
// `reductionAxis`, using the combining operation implied by `OperationType`.
// Being APInt-based, it presumably supports integer element types only —
// confirm.
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
llvm::ArrayRef<int64_t> oldShape,
int64_t reductionAxis,
int64_t reductionIndex) { … }
/// Rewrite pattern template rooted at `OperationType`.
// NOTE(review): the body is elided. By its name, it is presumably instantiated
// for the TOSA reduce ops and replaces a reduction over a constant input with
// a precomputed constant, likely via calculateReducedValue — confirm.
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> { … };
}
/// Out-of-line definition of mlir::tosa::populateTosaConstantReduction, which
/// registers constant-reduction patterns into `patterns` for context `ctx`.
///
/// \param aggressiveReduceConstant tuning flag for the registered patterns.
// NOTE(review): the body is elided. Presumably it adds ReduceConstantOptimization
// instantiations for the TOSA reduce ops and forwards
// `aggressiveReduceConstant` to them (e.g. to fold even when the constant has
// other users) — confirm.
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
RewritePatternSet &patterns,
bool aggressiveReduceConstant) { … }
/// Out-of-line definition of mlir::tosa::populateTosaFoldConstantTransposePatterns,
/// which registers transpose constant-folding patterns into `patterns` for
/// context `ctx`.
// NOTE(review): the body is elided. Presumably it adds TosaFoldConstantTranspose
// — confirm.
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
MLIRContext *ctx, RewritePatternSet &patterns) { … }
/// Out-of-line definition of mlir::tosa::populateTosaFoldConstantReciprocalPatterns,
/// which registers reciprocal constant-folding patterns into `patterns` for
/// context `ctx`.
// NOTE(review): the body is elided. Presumably it adds TosaFoldConstantReciprocal
// — confirm.
void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) { … }