llvm/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different number of rows/columns
//     for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

usingnamespacellvm;
usingnamespacePatternMatch;

// Tag identifying this pass in debug output and optimization remarks. It must
// expand to a string literal; an empty expansion breaks every macro that
// consumes it.
#define DEBUG_TYPE "lower-matrix-intrinsics"

// Command-line options controlling matrix fusion and verification. All of the
// options in this group are registered as hidden.

// -fuse-matrix: toggles fusing of matrix instructions; defaults to true.
static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// -fuse-matrix-tile-size: edge length of the square tiles used for fusion;
// defaults to 4.
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
// -fuse-matrix-use-loops: request a loop nest for tiling; defaults to false.
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
// -force-fuse-matrix: fuse even when fusion is not considered profitable;
// defaults to false.
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
// -matrix-allow-contract: permit FMAs; defaults to false because, as the
// description states, results may differ due to less rounding error.
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

// -verify-matrix-shapes: toggles matrix shape verification; defaults to false.
static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

/// Memory layout of the matrices being lowered. Both enumerators are
/// referenced by the -matrix-default-layout option declared below, so they
/// must be declared here.
enum class MatrixLayoutTy {
  ColumnMajor, ///< Selected by "column-major"; the option's default.
  RowMajor     ///< Selected by "row-major".
};

// -matrix-default-layout: selects the default matrix layout. Accepts
// "column-major" or "row-major"; defaults to column-major.
static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

// -matrix-print-after-transpose-opt: defaults to false. Unlike the options
// above it carries no description and is not marked cl::Hidden.
// NOTE(review): presumably enables dumping IR after the transpose
// optimization step -- confirm against the use site.
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

/// Helper function that returns \p Scope itself if it is a subprogram, or
/// otherwise the subprogram attached to the local scope \p Scope.
///
/// NOTE(review): the body is empty in this view although the function returns
/// a pointer; the implementation is presumably elided -- confirm.
static DISubprogram *getSubprogram(DIScope *Scope) {}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
///
/// \p II is taken by reference as a reverse iterator over \p BB so that the
/// caller's iterator can be updated.
///
/// NOTE(review): the body is empty in this view; the implementation is
/// presumably elided -- confirm.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {}

/// Return true if \p V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
///
/// NOTE(review): the body is empty in this view although the function returns
/// bool; the implementation is presumably elided -- confirm.
static bool isSplat(Value *V) {}

/// Match any mul operation (fp or integer).
///
/// NOTE(review): \p L and \p R are presumably the sub-patterns for the two
/// operands; the body is empty in this view, so the returned matcher cannot be
/// confirmed from here.
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {}

/// Match any add operation (fp or integer).
///
/// NOTE(review): \p L and \p R are presumably the sub-patterns for the two
/// operands; the body is empty in this view, so the returned matcher cannot be
/// confirmed from here.
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrices, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrices, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
// NOTE(review): the body is empty in this view although the function returns
// a Value *; the implementation is presumably elided -- confirm.
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {}

namespace {
/// Shape information associated with a value; it is the mapped type of the
/// shape map consulted by computeShapeInfoForInst.
///
/// NOTE(review): no members are visible in this view; presumably it records
/// the number of rows and columns of a matrix -- confirm.
struct ShapeInfo {};
} // namespace

/// Predicate on \p V.
///
/// NOTE(review): judging by the name, this presumably reports whether \p V is
/// an operation whose operands and result all share the same shape; the body
/// is empty in this view, so confirm against the implementation.
static bool isUniformShape(Value *V) {}

/// Return the ShapeInfo for the result of \p I, if it can be determined.
///
/// \p ShapeMap maps values to their ShapeInfo and is only read (it is passed
/// by const reference). The result is std::nullopt when no shape can be
/// determined.
///
/// NOTE(review): the body is empty in this view; the implementation is
/// presumably elided -- confirm.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///    The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
/// NOTE(review): the class body is empty in this view; its members are
/// presumably elided -- confirm.
class LowerMatrixIntrinsics {};
} // namespace

/// Entry point of the pass for function \p F; \p AM is the function analysis
/// manager handed to the pass. Returns the set of analyses the pass preserves.
///
/// NOTE(review): the body is empty in this view, so which analyses are
/// requested from \p AM and which are reported as preserved cannot be
/// determined from here -- confirm against the implementation.
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {}

/// Pipeline-printing hook of the pass: receives the output stream \p OS and a
/// callback \p MapClassName2PassName that maps a StringRef to a StringRef.
///
/// NOTE(review): presumably writes the textual pipeline form of this pass
/// (pass name plus parameters) to \p OS; the body is empty in this view, so
/// the exact output cannot be confirmed from here.
void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {}