//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// into an amx load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g.
/// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for amx to identify the
/// step in spill.) volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE …

static bool isAMXCast(Instruction *II) { … }

static bool isAMXIntrinsic(Value *I) { … }

static bool containsAMXCode(Function &F) { … }

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) { … }

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { … }

static std::pair<Value *, Value *> getShape(IntrinsicInst *II,
                                            unsigned OpNo) { … }

static std::pair<Value *, Value *> getShape(PHINode *Phi) { … }

namespace {
class X86LowerAMXType { … };
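
// For illustration only (the real implementations are elided in this listing):
// the load+bitcast rewrite described immediately below could be expressed with
// IRBuilder roughly as follows. Row, Col and the fixed 64-byte stride are
// assumptions for this sketch, not values taken from this file.
//
//   IRBuilder<> Builder(Bitcast);
//   Value *Stride = Builder.getInt64(64);
//   Value *Args[] = {Row, Col, LD->getPointerOperand(), Stride};
//   Value *NewTile = Builder.CreateIntrinsic(
//       Intrinsic::x86_tileloadd64_internal, /*Types=*/{}, Args);
//   Bitcast->replaceAllUsesWith(NewTile);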

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD,
                                         BitCastInst *Bitcast) { … }

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast,
                                          StoreInst *ST) { … }

// Transform a bitcast into <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { … }

bool X86LowerAMXType::visit() { … }
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) { … }

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { … }

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { … }

static bool isIncomingOfPHI(Instruction *I) { … }

// Make all AMX tile data volatile, shortening the live range of each tile
// register before fast register allocation.
namespace {
class X86VolatileTileData { … };

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { … }

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) { … }

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The memory used by these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { … }

// Store the defined tile and load it before each use.
// None of its users are PHIs.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { … }
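
// For illustration only: the store/reload pattern used to make tile data
// volatile (see createTileStore/replaceWithTileLoad above) could be built with
// IRBuilder roughly as follows. Row, Col and Ptr are assumptions for this
// sketch (the real helpers derive them from the defining intrinsic and from
// the alloca created at the function entry), and the tile def is assumed not
// to end its basic block.
//
//   IRBuilder<> Builder(TileDef->getNextNode());
//   Value *Stride = Builder.getInt64(64);
//   Value *Args[] = {Row, Col, Ptr, Stride, TileDef};
//   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal,
//                           /*Types=*/{}, Args);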

// Volatile Tile Model:
// 1) All the uses of tile data come from tileload in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) There are no terminator, call, or other amx instructions in the key amx
//    area.
bool X86VolatileTileData::volatileTileData() { … }
} // anonymous namespace

namespace {
class X86LowerAMXCast { … };

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) { … }

/// This function handles the following case:
///
/// A -> B amxcast
/// PHI
/// B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) { … }

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) { … }

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { … }

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) { … }

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { … }

// There might be remaining AMXcasts after combineAMXcast, and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { … }

bool X86LowerAMXCast::transformAllAMXCast() { … }
} // anonymous namespace

namespace {
class X86LowerAMXTypeLegacyPass : public FunctionPass { … };
} // anonymous namespace

static const char PassName[] = …;
char X86LowerAMXTypeLegacyPass::ID = …;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() { … }
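
// Usage note (a sketch, not code from this file): the pass is created through
// createX86LowerAMXTypePass() and is added to the X86 IR pass pipeline by the
// target machine, conceptually along the lines of:
//
//   // Inside the X86 pass configuration's IR-pass setup (simplified):
//   addPass(createX86LowerAMXTypePass());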