//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Lower matrix intrinsics to vector operations. // // TODO: // * Improve fusion: // * Support more cases, e.g. multiply-add, multiply-sub, operands/results // transposed. // * Improve cost-modeling, e.g. choose different number of rows/columns // columns for tiles, consider cost of copies on alias. // //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/MatrixUtils.h" #include <cmath> usingnamespacellvm; usingnamespacePatternMatch; #define DEBUG_TYPE … static cl::opt<bool> FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix 
instructions.")); // TODO: Allow and use non-square tiles. static cl::opt<unsigned> TileSize( "fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc( "Tile size for matrix instruction fusion using square-shaped tiles.")); static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling.")); static cl::opt<bool> ForceFusion( "force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable.")); static cl::opt<bool> AllowContractEnabled( "matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error.")); static cl::opt<bool> VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false)); enum class MatrixLayoutTy { … }; static cl::opt<MatrixLayoutTy> MatrixLayout( "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout"))); static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false)); /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { … } /// Erase \p V from \p BB and move \II forward to avoid invalidating /// iterators. static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, BasicBlock &BB) { … } /// Return true if V is a splat of a value (which is used when multiplying a /// matrix with a scalar). static bool isSplat(Value *V) { … } /// Match any mul operation (fp or integer). 
template <typename LTy, typename RTy> auto m_AnyMul(const LTy &L, const RTy &R) { … } /// Match any add operation (fp or integer). template <typename LTy, typename RTy> auto m_AnyAdd(const LTy &L, const RTy &R) { … } namespace { // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) // assuming \p Stride elements between the starts of two consecutive vectors. // \p Stride must be >= \p NumElements. // For column-major matrices, the function computes the address of a column // vector and \p NumElements must be set to the number of elements in a column // (= number of rows of the matrix). For row-major matrices, the function // computes the address of a row vector and \p NumElements must be set to the // number of elements in a row (= number of columns of the matrix). // // Consider a 4x4 matrix in column-major layout like below // // 0 1 2 3 // 0 v_0_0 v_0_1 v_0_2 v_0_3 // 1 v_1_0 v_1_1 v_1_2 v_1_3 // 2 v_2_0 v_2_1 v_2_2 v_2_3 // 3 v_3_0 v_3_1 v_3_2 v_3_3 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1, // we need a pointer to the first element of the submatrix as base pointer. // Then we can use computeVectorAddr to compute the addresses for the columns // of the sub-matrix. // // Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) // -> just returns Base // Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) // -> returns Base + (1 * 4) // Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) // -> returns Base + (2 * 4) // // The graphic below illustrates the number of elements in a column (marked // with |) and the number of skipped elements (marked with }). 
// // v_0_0 v_0_1 {v_0_2 {v_0_3 // Base Col 1 Col 2 // | | | // v_1_0 |v_1_1 |v_1_2 |v_1_3 // v_2_0 |v_2_1 |v_2_2 |v_2_3 // v_3_0 {v_3_1 {v_3_2 v_3_3 // Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, unsigned NumElements, Type *EltType, IRBuilder<> &Builder) { … } namespace { struct ShapeInfo { … }; } // namespace static bool isUniformShape(Value *V) { … } /// Return the ShapeInfo for the result of \p I, if it can be determined. static std::optional<ShapeInfo> computeShapeInfoForInst(Instruction *I, const ValueMap<Value *, ShapeInfo> &ShapeMap) { … } /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. /// /// Currently, the lowering for each matrix intrinsic is done as follows: /// 1. Propagate the shape information from intrinsics to connected /// instructions. /// 2. Lower instructions with shape information (assuming column-major layout). /// The lowering works similarly using row-major layout. /// 2.1. Get column vectors for each argument. If we already lowered the /// definition of an argument, use the produced column vectors directly. /// If not, split the operand vector containing an embedded matrix into /// a set of column vectors. /// 2.2. Lower the instruction in terms of column-major operations, which /// yields a set of column vectors containing the result matrix. Note that we /// lower all instructions that have shape information. Besides the /// intrinsics, this includes stores for example. /// 2.3. Update uses of the lowered instruction. If we have shape information /// for a user, there is nothing to do, as we will look up the result /// column matrix when lowering the user. For other uses, we embed the /// result matrix in a flat vector and update the use. /// 2.4. Cache the result column matrix for the instruction we lowered. /// 3. After we lowered all instructions in a function, remove the now /// obsolete instructions. 
/// class LowerMatrixIntrinsics { … }; } // namespace PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { … } void LowerMatrixIntrinsicsPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { … }