//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// // intrinsics // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass replaces masked memory intrinsics - when unsupported by the target // - with a chain of basic blocks, that deal with the elements one-by-one if the // appropriate mask bit is set. // //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cassert> #include <optional> usingnamespacellvm; #define DEBUG_TYPE … namespace { class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass { … }; } // end anonymous namespace static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU); static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU); char ScalarizeMaskedMemIntrinLegacyPass::ID = …; 
INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, "Scalarize unsupported masked memory intrinsics", false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, "Scalarize unsupported masked memory intrinsics", false, false) FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() { … } static bool isConstantIntVector(Value *Mask) { … } static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, unsigned Idx) { … } // Translate a masked load intrinsic like // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align, // <16 x i1> %mask, <16 x i32> %passthru) // to a chain of basic blocks, loading the elements one-by-one if // the appropriate mask bit is set // // %1 = bitcast i8* %addr to i32* // %2 = extractelement <16 x i1> %mask, i32 0 // br i1 %2, label %cond.load, label %else // // cond.load: ; preds = %0 // %3 = getelementptr i32* %1, i32 0 // %4 = load i32* %3 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0 // br label %else // // else: ; preds = %0, %cond.load // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ %passthru, %0 ] // %6 = extractelement <16 x i1> %mask, i32 1 // br i1 %6, label %cond.load1, label %else2 // // cond.load1: ; preds = %else // %7 = getelementptr i32* %1, i32 1 // %8 = load i32* %7 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1 // br label %else2 // // else2: ; preds = %else, %cond.load1 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ] // %10 = extractelement <16 x i1> %mask, i32 2 // br i1 %10, label %cond.load4, label %else5 // static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } // Translate a masked store intrinsic, like // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, // <16 x i1> %mask)
// to a chain of basic blocks, that stores the elements one-by-one if // the appropriate mask bit is set // // %1 = bitcast i8* %addr to i32* // %2 = extractelement <16 x i1> %mask, i32 0 // br i1 %2, label %cond.store, label %else // // cond.store: ; preds = %0 // %3 = extractelement <16 x i32> %src, i32 0 // %4 = getelementptr i32* %1, i32 0 // store i32 %3, i32* %4 // br label %else // // else: ; preds = %0, %cond.store // %5 = extractelement <16 x i1> %mask, i32 1 // br i1 %5, label %cond.store1, label %else2 // // cond.store1: ; preds = %else // %6 = extractelement <16 x i32> %src, i32 1 // %7 = getelementptr i32* %1, i32 1 // store i32 %6, i32* %7 // br label %else2 // . . . static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } // Translate a masked gather intrinsic like // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, // <16 x i1> %Mask, <16 x i32> %Src) // to a chain of basic blocks, loading the elements one-by-one if // the appropriate mask bit is set // // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind // %Mask0 = extractelement <16 x i1> %Mask, i32 0 // br i1 %Mask0, label %cond.load, label %else // // cond.load: // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 // %Load0 = load i32, i32* %Ptr0, align 4 // %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0 // br label %else // // else: // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0] // %Mask1 = extractelement <16 x i1> %Mask, i32 1 // br i1 %Mask1, label %cond.load1, label %else2 // // cond.load1: // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 // %Load1 = load i32, i32* %Ptr1, align 4 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1 // br label %else2 // . . . 
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src // ret <16 x i32> %Result static void scalarizeMaskedGather(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } // Translate a masked scatter intrinsic, like // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4, // <16 x i1> %Mask) // to a chain of basic blocks, that stores the elements one-by-one if // the appropriate mask bit is set. // // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind // %Mask0 = extractelement <16 x i1> %Mask, i32 0 // br i1 %Mask0, label %cond.store, label %else // // cond.store: // %Elt0 = extractelement <16 x i32> %Src, i32 0 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 // store i32 %Elt0, i32* %Ptr0, align 4 // br label %else // // else: // %Mask1 = extractelement <16 x i1> %Mask, i32 1 // br i1 %Mask1, label %cond.store1, label %else2 // // cond.store1: // %Elt1 = extractelement <16 x i32> %Src, i32 1 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 // store i32 %Elt1, i32* %Ptr1, align 4 // br label %else2 // . . . 
static void scalarizeMaskedScatter(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } static void scalarizeMaskedExpandLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } static void scalarizeMaskedCompressStore(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { … } static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT) { … } bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) { … } PreservedAnalyses ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) { … } static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU) { … } static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU) { … }