#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>
namespace mlir {
namespace linalg {
/// Predicate over a dense integer elements attribute.
/// NOTE(review): by its name, presumably returns true iff every element of
/// `attr` equals 1 (e.g. to require unit strides/dilations before rewriting);
/// body not visible here — confirm.
static bool hasAllOneValues(DenseIntElementsAttr attr) { … }
/// Builds, at `loc` via `builder`, an IR value combining `x` and `y`.
/// NOTE(review): presumably emits an addition op chosen by element type
/// (arith int/float or complex, given the headers included); confirm.
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { … }
/// Builds, at `loc` via `builder`, an IR value combining `x` and `y`.
/// `accType` is a type supplied by the caller.
/// NOTE(review): presumably `x`/`y` are converted to the accumulator type
/// `accType` before multiplying; body not visible here — confirm.
static Value createMul(Location loc, Value x, Value y, Type accType,
OpBuilder &builder) { … }
/// Expands a single `index` value into a list of values, one per entry
/// implied by `factors`, emitting any needed IR at `loc` via `b`.
/// NOTE(review): presumably a delinearization of `index` using `factors` as
/// the per-dimension basis; confirm ordering (outermost-first?) against body.
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
ArrayRef<int64_t> factors) { … }
/// Combines an output-position index `oIndex` and a filter index `fIndex`
/// with the constant `stride` into a single index value built via `b`.
/// NOTE(review): presumably computes `oIndex * stride + fIndex` (input index
/// of a strided convolution window); body not visible here — confirm.
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
Value fIndex, int64_t stride) { … }
/// Rewrites a `linalg::Conv2DNhwcHwcfOp` using `rewriter`.
/// Returns failure if the rewrite does not apply; otherwise a pair of ops.
/// NOTE(review): presumably the pair is (im2col-style packing op, op producing
/// the final result) — confirm against the body / public header.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { … }
/// Rewrites a `linalg::DepthwiseConv2DNhwcHwcOp` using `rewriter`.
/// Returns failure if the rewrite does not apply; otherwise a pair of ops.
/// NOTE(review): presumably the pair is (im2col-style packing op, op producing
/// the final result) — confirm against the body / public header.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
linalg::DepthwiseConv2DNhwcHwcOp convOp) { … }
/// Rewrites a `linalg::Conv2DNchwFchwOp` using `rewriter`.
/// Returns failure if the rewrite does not apply; otherwise a pair of ops.
/// NOTE(review): presumably the pair is (im2col-style packing op, op producing
/// the final result) — confirm against the body / public header.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { … }
/// Rewrites a `linalg::Conv2DNhwcFhwcOp` using `rewriter`.
/// Returns failure if the rewrite does not apply; otherwise a pair of ops.
/// NOTE(review): presumably the pair is (im2col-style packing op, op producing
/// the final result) — confirm against the body / public header.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { … }
// File-local rewrite patterns, one per supported convolution op type.
namespace {
/// Rewrite pattern matching `linalg::Conv2DNhwcHwcfOp`.
/// NOTE(review): presumably delegates to the matching `rewriteInIm2Col`
/// overload; body not visible here — confirm.
class ConvertConv2DNhwcHwcf final
: public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> { … };
/// Rewrite pattern matching `linalg::DepthwiseConv2DNhwcHwcOp`.
/// NOTE(review): presumably delegates to the matching `rewriteInIm2Col`
/// overload; body not visible here — confirm.
class ConvertDepthwiseConv2DNhwcHwc final
: public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> { … };
/// Rewrite pattern matching `linalg::Conv2DNchwFchwOp`.
/// NOTE(review): presumably delegates to the matching `rewriteInIm2Col`
/// overload; body not visible here — confirm.
class ConvertConv2DNchwFchw final
: public OpRewritePattern<linalg::Conv2DNchwFchwOp> { … };
/// Rewrite pattern matching `linalg::Conv2DNhwcFhwcOp`.
/// NOTE(review): presumably delegates to the matching `rewriteInIm2Col`
/// overload; body not visible here — confirm.
class ConvertConv2DNhwcFhwc final
: public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> { … };
}  // namespace
/// Public entry point that populates `patterns` with this file's
/// convolution-to-im2col rewrites.
/// NOTE(review): presumably adds the four file-local pattern classes
/// (ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
/// ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc); body not visible — confirm.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) { … }
}
}