llvm/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains DXIL intrinsic expansions for those intrinsics
/// that don't have opcodes in DirectX Intermediate Language (DXIL).
//===----------------------------------------------------------------------===//

#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"

#define DEBUG_TYPE "dxil-intrinsic-expansion"

using namespace llvm;

// Legacy pass-manager entry point for the DXIL intrinsic expansion. It is a
// module pass whose runOnModule() and getAnalysisUsage() are defined out of
// line further down in this file.
class DXILIntrinsicExpansionLegacy : public ModulePass {
public:
  static char ID; // Pass identification.

  DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;
};

// Returns true when F is one of the intrinsics this pass rewrites into other
// operations, and false for every other function.
static bool isIntrinsicExpansion(Function &F) {
  switch (F.getIntrinsicID()) {
  // Target-independent LLVM intrinsics.
  case Intrinsic::abs:
  case Intrinsic::atan2:
  case Intrinsic::exp:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::pow:
  // DirectX-specific intrinsics.
  case Intrinsic::dx_all:
  case Intrinsic::dx_any:
  case Intrinsic::dx_clamp:
  case Intrinsic::dx_uclamp:
  case Intrinsic::dx_lerp:
  case Intrinsic::dx_length:
  case Intrinsic::dx_normalize:
  case Intrinsic::dx_fdot:
  case Intrinsic::dx_sdot:
  case Intrinsic::dx_udot:
  case Intrinsic::dx_sign:
  case Intrinsic::dx_step:
    return true;
  default:
    return false;
  }
}

// Lower abs(X) to smax(X, 0 - X). Only the first operand of Orig is read; the
// replacement works for both scalar and fixed-vector integer operands.
static Value *expandAbs(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Type *InputTy = Input->getType();

  // Build a zero of the operand type, splatted across all lanes for vectors.
  Constant *ZeroVal = ConstantInt::get(InputTy->getScalarType(), 0);
  if (InputTy->isVectorTy()) {
    const unsigned NumElts = cast<FixedVectorType>(InputTy)->getNumElements();
    ZeroVal = ConstantVector::getSplat(ElementCount::getFixed(NumElts), ZeroVal);
  }

  Value *Negated = Builder.CreateSub(ZeroVal, Input);
  return Builder.CreateIntrinsic(InputTy, Intrinsic::smax, {Input, Negated},
                                 nullptr, "dx.max");
}

// Emit the DXIL float dot-product intrinsic for the vector operands A and B.
// The element count of A (2, 3 or 4) selects dx.dot2, dx.dot3 or dx.dot4; any
// other element count is reported as a fatal error. The call is built at the
// position of Orig and returns a scalar of A's element type.
static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
  Type *ATy = A->getType();
  [[maybe_unused]] Type *BTy = B->getType();
  assert(ATy->isVectorTy() && BTy->isVectorTy());
  assert(ATy->getScalarType()->isFloatingPointTy());

  auto *AVec = dyn_cast<FixedVectorType>(ATy);
  const unsigned NumElts = AVec->getNumElements();

  // Pick the opcode that matches the vector width.
  Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
  if (NumElts == 2) {
    DotIntrinsic = Intrinsic::dx_dot2;
  } else if (NumElts == 3) {
    DotIntrinsic = Intrinsic::dx_dot3;
  } else if (NumElts != 4) {
    report_fatal_error(
        Twine("Invalid dot product input vector: length is outside 2-4"),
        /* gen_crash_diag=*/false);
  }

  IRBuilder<> Builder(Orig);
  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
                                 ArrayRef<Value *>{A, B}, nullptr, "dot");
}

// Convenience overload: emit the DXIL float dot intrinsic using the first two
// operands of Orig as the vectors to multiply. Opcode selection and placement
// are delegated to the three-argument overload.
static Value *expandFloatDotIntrinsic(CallInst *Orig) {
  Value *LHS = Orig->getOperand(0);
  Value *RHS = Orig->getOperand(1);
  return expandFloatDotIntrinsic(Orig, LHS, RHS);
}

// Lower an integer dot product to scalar arithmetic: the first pair of lanes
// is multiplied, and every following pair is folded into the running result
// with a multiply-add intrinsic (dx.imad for dx.sdot, dx.umad for dx.udot).
static Value *expandIntegerDotIntrinsic(CallInst *Orig,
                                        Intrinsic::ID DotIntrinsic) {
  assert(DotIntrinsic == Intrinsic::dx_sdot ||
         DotIntrinsic == Intrinsic::dx_udot);
  Value *LHSVec = Orig->getOperand(0);
  Value *RHSVec = Orig->getOperand(1);
  Type *LHSTy = LHSVec->getType();
  [[maybe_unused]] Type *RHSTy = RHSVec->getType();
  assert(LHSTy->isVectorTy() && RHSTy->isVectorTy());
  assert(LHSTy->getScalarType()->isIntegerTy());

  auto *LHSVecTy = dyn_cast<FixedVectorType>(LHSTy);
  const unsigned NumElts = LHSVecTy->getNumElements();

  const bool IsSigned = DotIntrinsic == Intrinsic::dx_sdot;
  const Intrinsic::ID MadIntrinsic =
      IsSigned ? Intrinsic::dx_imad : Intrinsic::dx_umad;

  IRBuilder<> Builder(Orig);

  // Seed the accumulator with the product of lane 0.
  Value *LHSElt = Builder.CreateExtractElement(LHSVec, static_cast<uint64_t>(0));
  Value *RHSElt = Builder.CreateExtractElement(RHSVec, static_cast<uint64_t>(0));
  Value *Acc = Builder.CreateMul(LHSElt, RHSElt);

  // Fold the remaining lanes in: Acc = LHS[Lane] * RHS[Lane] + Acc.
  for (unsigned Lane = 1; Lane < NumElts; ++Lane) {
    LHSElt = Builder.CreateExtractElement(LHSVec, Lane);
    RHSElt = Builder.CreateExtractElement(RHSVec, Lane);
    Acc = Builder.CreateIntrinsic(Acc->getType(), MadIntrinsic,
                                  ArrayRef<Value *>{LHSElt, RHSElt, Acc},
                                  nullptr, "dx.mad");
  }
  return Acc;
}

// Lower exp(X) to exp2(X * log2(e)). The tail-call marker and attributes of
// the original call are copied onto the new exp2 call.
static Value *expandExpIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Type *InputTy = Input->getType();

  // log2(e), splatted across all lanes when the operand is a vector.
  Constant *Log2e = ConstantFP::get(InputTy->getScalarType(), numbers::log2ef);
  if (InputTy->isVectorTy()) {
    const unsigned NumElts = cast<FixedVectorType>(InputTy)->getNumElements();
    Log2e = ConstantVector::getSplat(ElementCount::getFixed(NumElts), Log2e);
  }

  Value *Scaled = Builder.CreateFMul(Log2e, Input);
  CallInst *Exp2 = Builder.CreateIntrinsic(InputTy, Intrinsic::exp2, {Scaled},
                                           nullptr, "dx.exp2");
  Exp2->setTailCall(Orig->isTailCall());
  Exp2->setAttributes(Orig->getAttributes());
  return Exp2;
}

// Lower dx.any / dx.all. The operand is compared against zero (unordered
// not-equal for floating point, not-equal for integers). For a scalar that
// comparison is the result; for a vector the per-lane results are combined
// with OR (dx.any) or AND (dx.all).
static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
                                      Intrinsic::ID intrinsicId) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Type *InputTy = Input->getType();
  Type *ScalarTy = InputTy->getScalarType();
  const bool IsFloat = ScalarTy->isFloatingPointTy();

  // Scalar operand: a single "is non-zero" comparison is all that is needed.
  if (!InputTy->isVectorTy()) {
    if (IsFloat)
      return Builder.CreateFCmpUNE(Input, ConstantFP::get(ScalarTy, 0));
    return Builder.CreateICmpNE(Input, ConstantInt::get(ScalarTy, 0));
  }

  auto *VecTy = dyn_cast<FixedVectorType>(InputTy);
  const unsigned NumElts = VecTy->getNumElements();

  // Compare every lane against zero in one vector comparison.
  Constant *ScalarZero =
      IsFloat ? ConstantFP::get(ScalarTy, 0) : ConstantInt::get(ScalarTy, 0);
  Constant *ZeroSplat =
      ConstantVector::getSplat(ElementCount::getFixed(NumElts), ScalarZero);
  Value *NonZeroMask = IsFloat ? Builder.CreateFCmpUNE(Input, ZeroSplat)
                               : Builder.CreateICmpNE(Input, ZeroSplat);

  // Reduce the per-lane booleans into one value, lane by lane.
  Value *Combined =
      Builder.CreateExtractElement(NonZeroMask, static_cast<uint64_t>(0));
  for (unsigned Lane = 1; Lane < NumElts; ++Lane) {
    Value *LaneBit = Builder.CreateExtractElement(NonZeroMask, Lane);
    if (intrinsicId == Intrinsic::dx_any) {
      Combined = Builder.CreateOr(Combined, LaneBit);
    } else {
      assert(intrinsicId == Intrinsic::dx_all);
      Combined = Builder.CreateAnd(Combined, LaneBit);
    }
  }
  return Combined;
}

// Lower dx.length on a fixed vector to sqrt(x0*x0 + x1*x1 + ...), built from
// per-lane extracts, multiplies and adds followed by a scalar sqrt call.
// Any operand that is not a fixed vector with at least two elements is
// reported as a fatal error.
static Value *expandLengthIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  IRBuilder<> Builder(Orig);
  Type *Ty = X->getType();
  Type *EltTy = Ty->getScalarType();

  // Though dx.length does work on scalar type, we can optimize it to just emit
  // fabs, in CGBuiltin.cpp. We shouldn't see a scalar type here because
  // CGBuiltin.cpp should have emitted a fabs call.
  //
  // Validate the operand type before touching it: dyn_cast yields nullptr for
  // anything that is not a fixed vector, and an element extract is only
  // meaningful on a vector, so both must come after this check for the
  // diagnostic below to be reachable.
  auto *XVec = dyn_cast<FixedVectorType>(Ty);
  if (!XVec || XVec->getNumElements() <= 1)
    report_fatal_error(Twine("Invalid input type for length intrinsic"),
                       /* gen_crash_diag=*/false);
  unsigned XVecSize = XVec->getNumElements();

  // Accumulate the sum of squares, starting with lane 0.
  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
  Value *Sum = Builder.CreateFMul(Elt, Elt);
  for (unsigned I = 1; I < XVecSize; I++) {
    Elt = Builder.CreateExtractElement(X, I);
    Value *Mul = Builder.CreateFMul(Elt, Elt);
    Sum = Builder.CreateFAdd(Sum, Mul);
  }
  return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
                                 nullptr, "elt.sqrt");
}

// Lower lerp(X, Y, S) to X + S * (Y - X) using plain floating-point
// arithmetic instructions.
static Value *expandLerpIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Start = Orig->getOperand(0);
  Value *End = Orig->getOperand(1);
  Value *Weight = Orig->getOperand(2);

  Value *Delta = Builder.CreateFSub(End, Start);
  Value *ScaledDelta = Builder.CreateFMul(Weight, Delta);
  return Builder.CreateFAdd(Start, ScaledDelta, "dx.lerp");
}

// Lower a logarithm to LogConstVal * log2(X). With the default factor of
// ln(2) this yields the natural logarithm; other bases are obtained by
// passing a different factor. The tail-call marker and attributes of the
// original call are copied onto the new log2 call.
static Value *expandLogIntrinsic(CallInst *Orig,
                                 float LogConstVal = numbers::ln2f) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Type *InputTy = Input->getType();

  // Scale factor, splatted across all lanes when the operand is a vector.
  Constant *Scale = ConstantFP::get(InputTy->getScalarType(), LogConstVal);
  if (InputTy->isVectorTy()) {
    const unsigned NumElts = cast<FixedVectorType>(InputTy)->getNumElements();
    Scale = ConstantVector::getSplat(ElementCount::getFixed(NumElts), Scale);
  }

  CallInst *Log2 = Builder.CreateIntrinsic(InputTy, Intrinsic::log2, {Input},
                                           nullptr, "elt.log2");
  Log2->setTailCall(Orig->isTailCall());
  Log2->setAttributes(Orig->getAttributes());
  return Builder.CreateFMul(Scale, Log2);
}
// Lower log10(X) via the shared log expansion, using ln(2) / ln(10) as the
// factor applied to log2(X).
static Value *expandLog10Intrinsic(CallInst *Orig) {
  const float Log10Scale = numbers::ln2f / numbers::ln10f;
  return expandLogIntrinsic(Orig, Log10Scale);
}

// Normalize the operand of Orig. For a fixed vector, the dot product of the
// operand with itself gives the squared length; the vector is then multiplied
// by the reciprocal square root of that value. A scalar operand is divided by
// itself. A constant-zero scalar operand, or a dot product that is a constant
// zero, is reported as a fatal error.
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Type *ResultTy = Orig->getType();

  // Abort with Message when V is a floating-point constant equal to zero.
  auto RejectConstantZero = [](Value *V, const char *Message) {
    auto *Const = dyn_cast<ConstantFP>(V);
    if (Const && Const->getValueAPF().isZero())
      report_fatal_error(Twine(Message), /* gen_crash_diag=*/false);
  };

  auto *VecTy = dyn_cast<FixedVectorType>(ResultTy);
  if (!VecTy) {
    RejectConstantZero(Input, "Invalid input scalar: length is zero");
    // NOTE(review): X / X evaluates to 1.0 for any finite non-zero X,
    // including negative values - confirm this is the intended scalar result.
    return Builder.CreateFDiv(Input, Input);
  }

  Value *SquaredLength = expandFloatDotIntrinsic(Orig, Input, Input);

  // If the dot product is non-zero, then the length is non-zero.
  RejectConstantZero(SquaredLength, "Invalid input vector: length is zero");

  Value *InvLength = Builder.CreateIntrinsic(
      ResultTy->getScalarType(), Intrinsic::dx_rsqrt,
      ArrayRef<Value *>{SquaredLength}, nullptr, "dx.rsqrt");
  Value *InvLengthSplat =
      Builder.CreateVectorSplat(VecTy->getNumElements(), InvLength);
  return Builder.CreateFMul(Input, InvLengthSplat);
}

// Lower atan2(Y, X) to atan(Y / X) followed by a chain of selects that adjust
// the result according to ordered comparisons of X and Y against zero. The
// fast-math flags of the original call are installed on the builder, and its
// tail-call marker and attributes are copied onto the atan call.
static Value *expandAtan2Intrinsic(CallInst *Orig) {
  Value *Y = Orig->getOperand(0);
  Value *X = Orig->getOperand(1);
  Type *Ty = X->getType();

  IRBuilder<> Builder(Orig);
  Builder.setFastMathFlags(Orig->getFastMathFlags());

  Value *Quotient = Builder.CreateFDiv(Y, X);
  CallInst *Atan = Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Quotient},
                                           nullptr, "Elt.Atan");
  Atan->setTailCall(Orig->isTailCall());
  Atan->setAttributes(Orig->getAttributes());

  // Modify atan result based on https://en.wikipedia.org/wiki/Atan2.
  Constant *Zero = ConstantFP::get(Ty, 0);
  Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi);
  Constant *HalfPi = ConstantFP::get(Ty, llvm::numbers::pi / 2);
  Constant *NegHalfPi = ConstantFP::get(Ty, -llvm::numbers::pi / 2);
  Value *AtanPlusPi = Builder.CreateFAdd(Atan, Pi);
  Value *AtanMinusPi = Builder.CreateFSub(Atan, Pi);

  Value *XNeg = Builder.CreateFCmpOLT(X, Zero);
  Value *XZero = Builder.CreateFCmpOEQ(X, Zero);
  Value *YNonNeg = Builder.CreateFCmpOGE(Y, Zero);
  Value *YNeg = Builder.CreateFCmpOLT(Y, Zero);

  // Replace Current with Replacement in the lanes where both conditions hold.
  auto OverrideWhen = [&Builder](Value *CondA, Value *CondB,
                                 Value *Replacement,
                                 Value *Current) -> Value * {
    Value *Both = Builder.CreateAnd(CondA, CondB);
    return Builder.CreateSelect(Both, Replacement, Current);
  };

  // x > 0 -> atan.
  Value *Result = Atan;
  // x < 0, y >= 0 -> atan + pi.
  Result = OverrideWhen(XNeg, YNonNeg, AtanPlusPi, Result);
  // x < 0, y < 0 -> atan - pi.
  Result = OverrideWhen(XNeg, YNeg, AtanMinusPi, Result);
  // x == 0, y < 0 -> -pi/2.
  Result = OverrideWhen(XZero, YNeg, NegHalfPi, Result);
  // x == 0, y >= 0 -> pi/2.
  Result = OverrideWhen(XZero, YNonNeg, HalfPi, Result);

  return Result;
}

// Lower pow(X, Y) to exp2(log2(X) * Y). The tail-call marker and attributes
// of the original call are copied onto the final exp2 call.
static Value *expandPowIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Base = Orig->getOperand(0);
  Value *Exponent = Orig->getOperand(1);
  Type *BaseTy = Base->getType();

  CallInst *Log2 = Builder.CreateIntrinsic(BaseTy, Intrinsic::log2, {Base},
                                           nullptr, "elt.log2");
  Value *Scaled = Builder.CreateFMul(Log2, Exponent);
  CallInst *Exp2 = Builder.CreateIntrinsic(BaseTy, Intrinsic::exp2, {Scaled},
                                           nullptr, "elt.exp2");
  Exp2->setTailCall(Orig->isTailCall());
  Exp2->setAttributes(Orig->getAttributes());
  return Exp2;
}

// Lower step(X, Y) to a compare and select: the result is 0.0 where Y < X
// (ordered comparison) and 1.0 everywhere else. Vector operands get splatted
// constants so the select works lane by lane.
static Value *expandStepIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Edge = Orig->getOperand(0);
  Value *Val = Orig->getOperand(1);
  Type *Ty = Edge->getType();
  Type *ScalarTy = Ty->getScalarType();

  Value *BelowEdge = Builder.CreateFCmpOLT(Val, Edge);

  Constant *One = ConstantFP::get(ScalarTy, 1.0);
  Constant *Zero = ConstantFP::get(ScalarTy, 0.0);
  if (Ty->isVectorTy()) {
    auto *VecTy = dyn_cast<FixedVectorType>(Ty);
    const ElementCount Lanes = ElementCount::getFixed(VecTy->getNumElements());
    One = ConstantVector::getSplat(Lanes, One);
    Zero = ConstantVector::getSplat(Lanes, Zero);
  }

  return Builder.CreateSelect(BelowEdge, Zero, One);
}

// Select the "max" intrinsic used as the lower-bound step of a clamp:
// umax for dx.uclamp; for dx.clamp, smax on integer element types and maxnum
// on floating-point element types.
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
                                    Intrinsic::ID ClampIntrinsic) {
  if (ClampIntrinsic == Intrinsic::dx_uclamp)
    return Intrinsic::umax;
  assert(ClampIntrinsic == Intrinsic::dx_clamp);

  // Vectors are classified by their element type.
  Type *ScalarTy = ElemTy->getScalarType();
  if (!ScalarTy->isIntegerTy()) {
    assert(ScalarTy->isFloatingPointTy());
    return Intrinsic::maxnum;
  }
  return Intrinsic::smax;
}

// Select the "min" intrinsic used as the upper-bound step of a clamp:
// umin for dx.uclamp; for dx.clamp, smin on integer element types and minnum
// on floating-point element types.
static Intrinsic::ID getMinForClamp(Type *ElemTy,
                                    Intrinsic::ID ClampIntrinsic) {
  if (ClampIntrinsic == Intrinsic::dx_uclamp)
    return Intrinsic::umin;
  assert(ClampIntrinsic == Intrinsic::dx_clamp);

  // Vectors are classified by their element type.
  Type *ScalarTy = ElemTy->getScalarType();
  if (!ScalarTy->isIntegerTy()) {
    assert(ScalarTy->isFloatingPointTy());
    return Intrinsic::minnum;
  }
  return Intrinsic::smin;
}

// Lower clamp(X, Min, Max) to min(max(X, Min), Max). The concrete min/max
// intrinsics are chosen from the clamp flavour and the operand type.
static Value *expandClampIntrinsic(CallInst *Orig,
                                   Intrinsic::ID ClampIntrinsic) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Value *LowerBound = Orig->getOperand(1);
  Value *UpperBound = Orig->getOperand(2);
  Type *InputTy = Input->getType();

  // First raise the value to at least the lower bound...
  const Intrinsic::ID MaxID = getMaxForClamp(InputTy, ClampIntrinsic);
  auto *RaisedToMin = Builder.CreateIntrinsic(
      InputTy, MaxID, {Input, LowerBound}, nullptr, "dx.max");

  // ...then cap the result at the upper bound.
  const Intrinsic::ID MinID = getMinForClamp(InputTy, ClampIntrinsic);
  return Builder.CreateIntrinsic(InputTy, MinID, {RaisedToMin, UpperBound},
                                 nullptr, "dx.min");
}

// Lower sign(X) to zext(0 < X) - zext(X < 0), computed in the return type of
// the original call. Floating-point operands use ordered less-than compares;
// integer operands use signed less-than compares.
static Value *expandSignIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Input = Orig->getOperand(0);
  Type *InputTy = Input->getType();
  Type *ResultTy = Orig->getType();
  Constant *Zero = Constant::getNullValue(InputTy);

  const bool IsFloat = InputTy->getScalarType()->isFloatingPointTy();
  assert(IsFloat || InputTy->getScalarType()->isIntegerTy());

  Value *IsPositive = IsFloat ? Builder.CreateFCmpOLT(Zero, Input)
                              : Builder.CreateICmpSLT(Zero, Input);
  Value *IsNegative = IsFloat ? Builder.CreateFCmpOLT(Input, Zero)
                              : Builder.CreateICmpSLT(Input, Zero);

  // Widen both booleans to the result type; their difference is -1, 0 or 1.
  Value *PositiveBit = Builder.CreateZExt(IsPositive, ResultTy);
  Value *NegativeBit = Builder.CreateZExt(IsNegative, ResultTy);
  return Builder.CreateSub(PositiveBit, NegativeBit);
}

// Expand one call to the intrinsic F. When an expansion exists, every use of
// Orig is redirected to the replacement value, Orig is erased, and true is
// returned. Returns false, leaving Orig untouched, when F is not handled.
static bool expandIntrinsic(Function &F, CallInst *Orig) {
  const Intrinsic::ID IntrinsicId = F.getIntrinsicID();

  // Build the replacement value; nullptr means the intrinsic is not handled.
  Value *Replacement = [&]() -> Value * {
    switch (IntrinsicId) {
    case Intrinsic::abs:
      return expandAbs(Orig);
    case Intrinsic::atan2:
      return expandAtan2Intrinsic(Orig);
    case Intrinsic::exp:
      return expandExpIntrinsic(Orig);
    case Intrinsic::log:
      return expandLogIntrinsic(Orig);
    case Intrinsic::log10:
      return expandLog10Intrinsic(Orig);
    case Intrinsic::pow:
      return expandPowIntrinsic(Orig);
    case Intrinsic::dx_all:
    case Intrinsic::dx_any:
      return expandAnyOrAllIntrinsic(Orig, IntrinsicId);
    case Intrinsic::dx_uclamp:
    case Intrinsic::dx_clamp:
      return expandClampIntrinsic(Orig, IntrinsicId);
    case Intrinsic::dx_lerp:
      return expandLerpIntrinsic(Orig);
    case Intrinsic::dx_length:
      return expandLengthIntrinsic(Orig);
    case Intrinsic::dx_normalize:
      return expandNormalizeIntrinsic(Orig);
    case Intrinsic::dx_fdot:
      return expandFloatDotIntrinsic(Orig);
    case Intrinsic::dx_sdot:
    case Intrinsic::dx_udot:
      return expandIntegerDotIntrinsic(Orig, IntrinsicId);
    case Intrinsic::dx_sign:
      return expandSignIntrinsic(Orig);
    case Intrinsic::dx_step:
      return expandStepIntrinsic(Orig);
    default:
      return nullptr;
    }
  }();

  if (!Replacement)
    return false;

  Orig->replaceAllUsesWith(Replacement);
  Orig->eraseFromParent();
  return true;
}

// Expand every call to an expandable intrinsic in M and erase intrinsic
// declarations that end up without users after at least one of their calls
// was expanded.
//
// Returns true only if the module was actually modified, so that callers can
// report preserved analyses accurately when there was nothing to expand.
static bool expansionIntrinsics(Module &M) {
  bool Changed = false;
  // Early-inc iteration: the current function may be erased inside the loop.
  for (auto &F : make_early_inc_range(M.functions())) {
    if (!isIntrinsicExpansion(F))
      continue;
    bool IntrinsicExpanded = false;
    // Early-inc iteration: expanding a call erases it from the user list.
    for (User *U : make_early_inc_range(F.users())) {
      auto *IntrinsicCall = dyn_cast<CallInst>(U);
      if (!IntrinsicCall)
        continue;
      // Accumulate rather than overwrite so that one unexpanded call cannot
      // hide the fact that an earlier call was rewritten.
      IntrinsicExpanded |= expandIntrinsic(F, IntrinsicCall);
    }
    Changed |= IntrinsicExpanded;
    if (F.user_empty() && IntrinsicExpanded)
      F.eraseFromParent();
  }
  return Changed;
}

// New pass-manager entry point: runs the expansion over M and reports no
// analyses preserved when the expansion returns true, all analyses preserved
// otherwise.
PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
                                              ModuleAnalysisManager &) {
  const bool Expanded = expansionIntrinsics(M);
  return Expanded ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

// Legacy pass-manager entry point: runs the expansion over M and forwards its
// result.
bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
  const bool Result = expansionIntrinsics(M);
  return Result;
}

// Declares that this pass preserves DXILResourceWrapperPass. No required
// analyses are registered here.
void DXILIntrinsicExpansionLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addPreserved<DXILResourceWrapperPass>();
}

// Definition of the pass identification member declared in
// DXILIntrinsicExpansionLegacy.
char DXILIntrinsicExpansionLegacy::ID = 0;

// Registers the legacy pass using DEBUG_TYPE ("dxil-intrinsic-expansion") and
// the description "DXIL Intrinsic Expansion". Nothing is listed between the
// BEGIN and END macros.
INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
                      "DXIL Intrinsic Expansion", false, false)
INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
                    "DXIL Intrinsic Expansion", false, false)

// Factory for the legacy pass-manager version of the pass; returns a newly
// allocated DXILIntrinsicExpansionLegacy instance.
ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
  ModulePass *LegacyPass = new DXILIntrinsicExpansionLegacy();
  return LegacyPass;
}