#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
// Expand the pass definition selected by GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
// from Passes.h.inc. This presumably provides impl::LinalgBlockPackMatmulBase,
// which LinalgBlockPackMatmul derives from further down in this file.
namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
usingnamespacemlir;
usingnamespacemlir::linalg;
/// Tries to reduce `range` to a single compile-time integer.
/// The std::optional return type implies an empty result when no constant
/// value can be derived.
/// NOTE(review): the body is not visible here. Presumably this is the constant
/// extent of the range (e.g. derived from its offset/size, possibly requiring
/// a unit stride) — confirm against the implementation.
static std::optional<int64_t> getConstantRange(const Range &range) { … }
/// Checks `linalgOp` against the candidate `tiles` for the loop dimensions
/// listed in `dims`, returning true or false.
/// NOTE(review): from the name, this presumably verifies that each tile size
/// evenly divides the corresponding dimension (no partial tiles). The body is
/// not visible here — confirm.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
ArrayRef<OpFoldResult> tiles,
ArrayRef<int64_t> dims) { … }
/// Applies a pack transpose to `packOp`, an operand pack of `linalgOp`, and
/// reports the outcome as FailureOr<PackTransposeResult>.
///
/// \param operandMap           Indexing map of the packed operand in `linalgOp`.
/// \param blocksStartDimPos    Dimension positions at which the blocks start.
/// \param transposeOuterBlocks Whether to transpose the outer (block) dims.
/// \param transposeInnerBlocks Whether to transpose the inner (intra-block) dims.
/// NOTE(review): parameter semantics are inferred from the names because the
/// body is elided here — confirm against the implementation.
static FailureOr<PackTransposeResult>
transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
tensor::PackOp packOp, AffineMap operandMap,
ArrayRef<unsigned> blocksStartDimPos,
bool transposeOuterBlocks, bool transposeInnerBlocks) { … }
/// Public entry point: block-packs the matmul-like `linalgOp` using `rewriter`.
/// `controlPackMatmul` supplies the packing configuration for this op.
/// Returns a PackResult on success, or failure otherwise.
/// NOTE(review): the exact failure conditions (e.g. an unsupported op or a
/// rejected configuration) live in the elided body — confirm there.
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
const ControlBlockPackMatmulFn &controlPackMatmul) { … }
namespace {
/// Rewrite pattern that matches operations of type `OpTy`.
/// Presumably it delegates to linalg::blockPackMatmul — body elided, confirm.
template <typename OpTy>
struct BlockPackMatmul : public OpRewritePattern<OpTy> { … };
/// Specialization of BlockPackMatmul for linalg::GenericOp.
/// NOTE(review): presumably this adds matching logic to recognize
/// matmul-like generics — body elided, confirm.
template <>
struct BlockPackMatmul<linalg::GenericOp>
: public OpRewritePattern<linalg::GenericOp> { … };
/// The LinalgBlockPackMatmul pass, built on the generated
/// impl::LinalgBlockPackMatmulBase (CRTP-style base parameterized on this type).
struct LinalgBlockPackMatmul
: public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> { … };
} // namespace
/// Adds the block-pack-matmul rewrite patterns to `patterns`, configured by
/// `controlFn`.
/// NOTE(review): the exact set of patterns registered (presumably
/// BlockPackMatmul instantiations) is in the elided body — confirm.
void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { … }