llvm/llvm/lib/Target/X86/X86LowerAMXType.cpp

//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx; basic elementwise operations are not
/// supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and only
/// AMX intrinsics can operate on the type, we need to transform <256 x i32>
/// load/store instructions into AMX load/store. If the bitcast cannot be
/// combined with a load/store, we transform the bitcast into an AMX
/// load/store paired with a <256 x i32> store/load.
///
/// If the front end does not use O0 but the middle/back end does (e.g.
/// "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we must make sure the AMX data
/// is volatile, because that is necessary for AMX fast register allocation.
/// (In fast register allocation, registers are allocated before spill/reload,
/// so there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will transfer to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

usingnamespacellvm;
usingnamespacePatternMatch;

#define DEBUG_TYPE

// NOTE(review): every helper below has an empty body in this view. The
// non-void ones would fall off the end (undefined behavior if called), so the
// implementations appear to have been elided from this copy -- confirm against
// the real source before relying on the descriptions below, which are inferred
// from names and from the IR examples elsewhere in this file.

// Presumably true when \p II is one of the AMX cast intrinsics
// (@llvm.x86.cast.vector.to.tile / @llvm.x86.cast.tile.to.vector, which the
// X86LowerAMXCast examples in this file use) -- TODO confirm.
static bool isAMXCast(Instruction *II) {}

// Presumably true when \p I is an AMX intrinsic value; the exact criterion is
// not visible here -- TODO confirm.
static bool isAMXIntrinsic(Value *I) {}

// Presumably reports whether \p F contains any AMX-related code, so the pass
// can skip functions without tiles -- TODO confirm.
static bool containsAMXCode(Function &F) {}

// Presumably creates an alloca of type \p Ty in the entry block of \p BB's
// function using \p Builder (the PHI example in this file places
// "%mem = alloca <256 x i32>" in bb_entry) -- TODO confirm.
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {}

// Presumably returns the first instruction of \p F's entry block that is not
// an alloca, i.e. an insertion point just past the stack slots -- TODO confirm.
static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {}

// Returns a pair of shape values for tile operand \p OpNo of AMX intrinsic
// \p II -- presumably (row, column), cf. "%row, %col" in the tileloadd64
// examples; TODO confirm.
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {}

// Overload returning the shape pair for a tile-typed \p Phi, presumably derived
// from its incoming values or users -- TODO confirm.
static std::pair<Value *, Value *> getShape(PHINode *Phi) {}

namespace {
// Rewrites <256 x i32> loads/stores that are bitcast to/from x86_amx into AMX
// tile load/store intrinsics. It lowers any bitcast that cannot be combined by
// going through memory (see the file header).
// NOTE(review): the class body is empty in this view, yet the member functions
// below are defined out of line; their declarations (and any data members)
// appear to have been elided, so this does not compile as shown.
class X86LowerAMXType {};

// Folds a <256 x i32> load that feeds a bitcast to x86_amx into a single AMX
// tile load:
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {}

// Folds a bitcast from x86_amx to <256 x i32> that feeds a store into a single
// AMX tile store:
// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {}

// Lowers a bitcast that could not be combined with an adjacent load/store by
// going through memory: a store of the source followed by a load of the
// destination type (per the file header). Presumably returns true if the IR
// was changed -- TODO confirm.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {}

// Driver: presumably scans the function for the bitcast patterns above and
// applies combineLoadBitcast / combineBitcastStore / transformBitcast,
// returning whether anything changed -- body elided here, TODO confirm.
bool X86LowerAMXType::visit() {}
} // anonymous namespace

// These helpers appear to support the volatile-tile model described in the
// file header (store each tile right after its def, reload it before use).
// NOTE(review): bodies are empty in this view; the non-void ones would be
// undefined behavior if called -- implementations appear elided. The
// descriptions below are inferred from names and the IR examples in this file.

// Presumably returns the stack slot (an alloca, cf. "%mem = alloca <256 x i32>"
// in the PHI example) used to spill tiles for \p BB's function -- TODO confirm.
static Value *getAllocaPos(BasicBlock *BB) {}

// Presumably emits "call void @llvm.x86.tilestored64.internal(mem, %td)" right
// after \p TileDef, storing it to \p Ptr, and returns the new store
// -- TODO confirm.
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {}

// Presumably inserts "@llvm.x86.tileloadd64.internal(mem)" reading from \p Ptr
// and redirects use \p U to the reloaded tile. \p IsPHI presumably marks that
// the use belongs to a PHI, which needs a different insertion point
// -- TODO confirm.
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {}

// Presumably true when \p I is used as an incoming value of some PHI node
// -- TODO confirm.
static bool isIncomingOfPHI(Instruction *I) {}

// Make all AMX tile data "volatile": every tile def is stored to memory right
// away and every use reloads it. This shortens the live range of each tile
// register before fast register allocation (see the file header).
namespace {
// NOTE(review): the class body is empty in this view, yet the member functions
// below are defined out of line; their declarations (and any data members)
// appear to have been elided, so this does not compile as shown.
class X86VolatileTileData {};

// Presumably tile-stores each value in \p Incomings (the incoming tiles of a
// PHI, see steps 2-3 in the volatileTilePHI description) to one shared stack
// slot right after its def, and returns that slot. The role of \p BB is not
// visible here -- TODO confirm.
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {}

// Presumably replaces the tile PHI \p PHI with a tileloadd64 from \p StorePtr
// (step 1 in the volatileTilePHI description) -- TODO confirm.
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {}

// Similar to volatileTileNonPHI, but this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def is changed to a tileload.
// 2) Each PHI incoming value is tile-stored right after its def.
// 3) These tileloads and tilestores all use the same memory.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {}

// Stores the tile defined by \p I right after its def and reloads it before
// each use. This is for defs none of whose users is a PHI (PHIs go through
// volatileTilePHI).
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {}

// Volatile Tile Model:
// 1) Every use of tile data is fed by a tileload issued just before it.
// 2) Every def of tile data is tile-stored to memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other AMX instruction appears in the key AMX area.
// Presumably returns true if the function was changed -- TODO confirm.
bool X86VolatileTileData::volatileTileData() {}

} // anonymous namespace

namespace {

// Combines and lowers the AMX cast intrinsics (@llvm.x86.cast.vector.to.tile /
// @llvm.x86.cast.tile.to.vector) shown in the examples below.
// NOTE(review): the class body is empty in this view although members are
// defined out of line below; their declarations appear to have been elided.
class X86LowerAMXCast {};

// NOTE(review): the body is empty here (undefined behavior if called as shown).
// Presumably erases \p I when it is trivially dead (consulting \p TLI), queues
// operands that may become dead into \p WorkList, and returns whether it
// erased anything -- TODO confirm.
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {}

/// This function handles the following case
///
///     A  ->  B    amxcast
///     PHI
///     B  ->  A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
/// Instructions made dead are presumably collected in \p DeadInst, and the
/// return value presumably reports success -- TODO confirm.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {}

// Folds a tile-to-vector cast whose result is stored into a single AMX tile
// store. Presumably returns true if the fold happened -- TODO confirm.
// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {}

// Folds a <256 x i32> load that feeds a vector-to-tile cast into a single AMX
// tile load. Presumably returns true if the fold happened -- TODO confirm.
// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {}

// Presumably tries combineCastStore / combineLoadCast on each cast in \p Casts
// and returns whether any fold happened -- TODO confirm.
bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {}

// Presumably the main combine step: collects the AMX casts in the function,
// cancels matching cast pairs (including through PHIs via
// optimizeAMXCastFromPhi), folds load/store pairs via combineLdSt, and removes
// dead instructions with the help of \p TLI -- TODO confirm.
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {}

// Lowers an AMX cast that survived combineAMXcast (i.e. one that could not be
// folded away), presumably by going through memory -- TODO confirm.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {}

// Presumably applies transformAMXCast to every remaining AMX cast and returns
// whether anything changed -- TODO confirm.
bool X86LowerAMXCast::transformAllAMXCast() {}

} // anonymous namespace

namespace {

// Legacy-pass-manager wrapper (a FunctionPass) for this file's AMX lowering;
// it is registered with INITIALIZE_PASS_* at the end of the file.
// NOTE(review): the class body is empty in this view. The registration code
// refers to X86LowerAMXTypeLegacyPass::ID, and FunctionPass subclasses must
// override the pure virtual runOnFunction, so the members appear to have been
// elided -- confirm against the real source.
class X86LowerAMXTypeLegacyPass : public FunctionPass {};

} // anonymous namespace

// Human-readable description shown by the pass registry.
static const char PassName[] = "Lower AMX type for load/store";
// Legacy-PM pass identifier. LLVM identifies the pass by the address of this
// object, so its value is irrelevant; 0 is the conventional initializer.
char X86LowerAMXTypeLegacyPass::ID = 0;
// Register the pass under DEBUG_TYPE with its dependencies. The trailing
// `false, false` are the CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

// Factory for the legacy-PM wrapper pass defined in this file. Returns a
// heap-allocated instance. As with other legacy-PM factories, ownership
// transfers to the PassManager it is added to. (The previous empty body fell
// off the end of a non-void function, which is undefined behavior.)
FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}