llvm/llvm/lib/Target/DirectX/DXILOpBuilder.cpp

//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains a helper class for building DXIL op functions.
//===----------------------------------------------------------------------===//

#include "DXILOpBuilder.h"
#include "DXILConstants.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

using namespace llvm;
using namespace llvm::dxil;

// Prefix placed in front of the op-code class name when the name of a DXIL op
// function is formed (see constructOverloadName).
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";

namespace {
// Flags identifying the type a DXIL operation is overloaded on. Every
// enumerator other than UNDEFINED occupies its own bit, so a set of kinds can
// be stored in a single mask (see OpOverload::ValidTys).
enum OverloadKind : uint16_t {
  UNDEFINED = 0,
  VOID = 1,
  HALF = 1 << 1,
  FLOAT = 1 << 2,
  DOUBLE = 1 << 3,
  I1 = 1 << 4,
  I8 = 1 << 5,
  I16 = 1 << 6,
  I32 = 1 << 7,
  I64 = 1 << 8,
  UserDefineType = 1 << 9,
  ObjectType = 1 << 10,
};
// A major/minor version pair; used below to tag properties with the DXIL
// version they apply from.
struct Version {
  unsigned Major = 0;
  unsigned Minor = 0;
};

// Overload types accepted by an operation, tagged with a DXIL version.
struct OpOverload {
  Version DXILVersion;
  // Mask of OverloadKind flags. tryCreateOp treats a mask equal to UNDEFINED
  // as "no overload check required".
  uint16_t ValidTys;
};
} // namespace

// Shader stages in which an operation may be used, tagged with a DXIL version.
struct OpStage {
  Version DXILVersion;
  // Mask that tryCreateOp compares against ShaderKind values; a value equal to
  // ShaderKind::removed is reported as a removed operation.
  uint32_t ValidStages;
};

// Attributes of an operation, tagged with a DXIL version.
struct OpAttribute {
  Version DXILVersion;
  // NOTE(review): presumably a bit mask of attribute flags; nothing in the
  // visible part of this file reads it - confirm against DXILOperation.inc.
  uint32_t ValidAttrs;
};

/// Return the textual suffix used for a scalar overload kind.
///
/// UNDEFINED is spelled the same as VOID ("void"). UserDefineType and
/// ObjectType have no fixed spelling, so passing either of them (or a value
/// that is not a single enumerator) is a programming error.
static const char *getOverloadTypeName(OverloadKind Kind) {
  switch (Kind) {
  // No overload type at all.
  case OverloadKind::UNDEFINED:
  case OverloadKind::VOID:
    return "void";

  // Integer kinds.
  case OverloadKind::I1:
    return "i1";
  case OverloadKind::I8:
    return "i8";
  case OverloadKind::I16:
    return "i16";
  case OverloadKind::I32:
    return "i32";
  case OverloadKind::I64:
    return "i64";

  // Floating point kinds.
  case OverloadKind::HALF:
    return "f16";
  case OverloadKind::FLOAT:
    return "f32";
  case OverloadKind::DOUBLE:
    return "f64";

  // Kinds whose name depends on the concrete type; handled by the caller.
  case OverloadKind::UserDefineType:
  case OverloadKind::ObjectType:
    break;
  }
  llvm_unreachable("invalid overload type for name");
}

/// Classify \p Ty as an OverloadKind.
///
/// A null type counts as VOID. Struct types are classified by their first
/// element, pointer types map to UserDefineType, and any type not otherwise
/// listed maps to UNDEFINED. Integer types must be 1, 8, 16, 32 or 64 bits
/// wide.
static OverloadKind getOverloadKind(Type *Ty) {
  // A missing type is treated the same as an explicit void type.
  if (!Ty || Ty->isVoidTy())
    return OverloadKind::VOID;

  if (Ty->isHalfTy())
    return OverloadKind::HALF;
  if (Ty->isFloatTy())
    return OverloadKind::FLOAT;
  if (Ty->isDoubleTy())
    return OverloadKind::DOUBLE;
  if (Ty->isPointerTy())
    return OverloadKind::UserDefineType;

  if (Ty->isStructTy()) {
    // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
    // how we're handling overloads and remove the `OverloadKind` proxy enum.
    return getOverloadKind(cast<StructType>(Ty)->getElementType(0));
  }

  if (!Ty->isIntegerTy())
    return OverloadKind::UNDEFINED;

  // Only the integer widths below have an overload kind.
  switch (cast<IntegerType>(Ty)->getBitWidth()) {
  case 1:
    return OverloadKind::I1;
  case 8:
    return OverloadKind::I8;
  case 16:
    return OverloadKind::I16;
  case 32:
    return OverloadKind::I32;
  case 64:
    return OverloadKind::I64;
  default:
    break;
  }
  llvm_unreachable("invalid overload type");
}

/// Return the name used for overload type \p Ty of kind \p Kind.
///
/// Scalar kinds use their fixed spelling, UserDefineType and ObjectType use
/// the name of the struct type, and anything else falls back to the printed
/// form of \p Ty.
///
/// NOTE(review): getOverloadKind reports UserDefineType for pointer types,
/// while this function casts \p Ty to StructType for that kind - confirm that
/// pointer overloads never reach this path.
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
  if (Kind < OverloadKind::UserDefineType)
    return getOverloadTypeName(Kind);

  const bool IsNamedStructKind = Kind == OverloadKind::UserDefineType ||
                                 Kind == OverloadKind::ObjectType;
  if (IsNamedStructKind)
    return cast<StructType>(Ty)->getStructName().str();

  // Unknown kind: print the type itself.
  std::string Printed;
  raw_string_ostream Stream(Printed);
  Ty->print(Stream);
  return Stream.str();
}

// Static properties.
struct OpCodeProperty {
  dxil::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  dxil::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  llvm::SmallVector<OpOverload> Overloads;
  llvm::SmallVector<OpStage> Stages;
  llvm::SmallVector<OpAttribute> Attributes;
  int OverloadParamIndex; // parameter index which control the overload.
                          // When < 0, should be only 1 overload type.
};

// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
#define DXIL_OP_OPERATION_TABLE
#include "DXILOperation.inc"
#undef DXIL_OP_OPERATION_TABLE

/// Build the name of the DXIL op function for \p Prop.
///
/// The name is "dx.op." followed by the op-code class name; unless \p Kind is
/// VOID, a "." and the name of the overload type \p Ty are appended.
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
                                         const OpCodeProperty &Prop) {
  std::string FnName =
      (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
  if (Kind != OverloadKind::VOID) {
    FnName += ".";
    FnName += getTypeName(Kind, Ty);
  }
  return FnName;
}

/// Build the name of an overloaded DXIL type.
///
/// For VOID the result is \p TypeName unchanged; otherwise the spelling of the
/// scalar overload kind is appended. Only scalar kinds are accepted.
static std::string constructOverloadTypeName(OverloadKind Kind,
                                             StringRef TypeName) {
  std::string FullName = TypeName.str();
  if (Kind != OverloadKind::VOID) {
    assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
    FullName += getOverloadTypeName(Kind);
  }
  return FullName;
}

/// Return the struct type called \p Name in \p Ctx, creating it with the
/// element types \p EltTys when no type of that name exists yet.
static StructType *getOrCreateStructType(StringRef Name,
                                         ArrayRef<Type *> EltTys,
                                         LLVMContext &Ctx) {
  if (StructType *Existing = StructType::getTypeByName(Ctx, Name))
    return Existing;
  return StructType::create(Ctx, EltTys, Name);
}

/// Return the "dx.types.ResRet.<overload>" struct type for \p ElementTy: four
/// members of type \p ElementTy followed by one i32 member.
static StructType *getResRetType(Type *ElementTy) {
  LLVMContext &Context = ElementTy->getContext();
  const std::string StructName = constructOverloadTypeName(
      getOverloadKind(ElementTy), "dx.types.ResRet.");
  Type *const Int32Ty = Type::getInt32Ty(Context);
  Type *Members[] = {ElementTy, ElementTy, ElementTy, ElementTy, Int32Ty};
  return getOrCreateStructType(StructName, Members, Context);
}

/// Return the "dx.types.Handle" struct type, whose only member is a pointer
/// obtained from PointerType::getUnqual.
static StructType *getHandleType(LLVMContext &Ctx) {
  Type *HandleMember = PointerType::getUnqual(Ctx);
  return getOrCreateStructType("dx.types.Handle", HandleMember, Ctx);
}

/// Return the "dx.types.ResBind" struct type, creating it as
/// { i32, i32, i32, i8 } when it does not exist in \p Context yet.
static StructType *getResBindType(LLVMContext &Context) {
  StructType *Existing = StructType::getTypeByName(Context, "dx.types.ResBind");
  if (Existing)
    return Existing;

  Type *I32 = Type::getInt32Ty(Context);
  Type *I8 = Type::getInt8Ty(Context);
  Type *Members[] = {I32, I32, I32, I8};
  return StructType::create(Members, "dx.types.ResBind");
}

/// Return the "dx.types.ResourceProperties" struct type, creating it as
/// { i32, i32 } when it does not exist in \p Context yet.
static StructType *getResPropsType(LLVMContext &Context) {
  StructType *Existing =
      StructType::getTypeByName(Context, "dx.types.ResourceProperties");
  if (Existing)
    return Existing;

  Type *I32 = Type::getInt32Ty(Context);
  Type *Members[] = {I32, I32};
  return StructType::create(Members, "dx.types.ResourceProperties");
}

/// Translate an OpParamType into the corresponding LLVM type in \p Ctx.
///
/// OpParamType::OverloadTy yields \p OverloadTy as passed in; the ResRet,
/// Handle, ResBind and ResProps kinds yield the matching "dx.types.*" struct
/// types.
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
                                    Type *OverloadTy) {
  switch (Kind) {
  // The overload type is supplied by the caller.
  case OpParamType::OverloadTy:
    return OverloadTy;

  // Void and floating point types.
  case OpParamType::VoidTy:
    return Type::getVoidTy(Ctx);
  case OpParamType::HalfTy:
    return Type::getHalfTy(Ctx);
  case OpParamType::FloatTy:
    return Type::getFloatTy(Ctx);
  case OpParamType::DoubleTy:
    return Type::getDoubleTy(Ctx);

  // Integer types.
  case OpParamType::Int1Ty:
    return Type::getInt1Ty(Ctx);
  case OpParamType::Int8Ty:
    return Type::getInt8Ty(Ctx);
  case OpParamType::Int16Ty:
    return Type::getInt16Ty(Ctx);
  case OpParamType::Int32Ty:
    return Type::getInt32Ty(Ctx);
  case OpParamType::Int64Ty:
    return Type::getInt64Ty(Ctx);

  // Resource return structs, one per element type.
  case OpParamType::ResRetHalfTy:
    return getResRetType(Type::getHalfTy(Ctx));
  case OpParamType::ResRetFloatTy:
    return getResRetType(Type::getFloatTy(Ctx));
  case OpParamType::ResRetInt16Ty:
    return getResRetType(Type::getInt16Ty(Ctx));
  case OpParamType::ResRetInt32Ty:
    return getResRetType(Type::getInt32Ty(Ctx));

  // Other named struct types.
  case OpParamType::HandleTy:
    return getHandleType(Ctx);
  case OpParamType::ResBindTy:
    return getResBindType(Ctx);
  case OpParamType::ResPropsTy:
    return getResPropsType(Ctx);
  }
  llvm_unreachable("Invalid parameter kind");
}

/// Translate a triple environment into the matching ShaderKind value.
/// Environments that do not name one of the shader stages below are a
/// programming error.
static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) {
  switch (EnvType) {
  case Triple::Vertex:
    return ShaderKind::vertex;
  case Triple::Hull:
    return ShaderKind::hull;
  case Triple::Domain:
    return ShaderKind::domain;
  case Triple::Geometry:
    return ShaderKind::geometry;
  case Triple::Pixel:
    return ShaderKind::pixel;
  case Triple::Compute:
    return ShaderKind::compute;
  case Triple::Library:
    return ShaderKind::library;
  case Triple::Amplification:
    return ShaderKind::amplification;
  case Triple::Mesh:
    return ShaderKind::mesh;
  case Triple::RayGeneration:
    return ShaderKind::raygeneration;
  case Triple::Intersection:
    return ShaderKind::intersection;
  case Triple::AnyHit:
    return ShaderKind::anyhit;
  case Triple::ClosestHit:
    return ShaderKind::closesthit;
  case Triple::Miss:
    return ShaderKind::miss;
  case Triple::Callable:
    return ShaderKind::callable;
  default:
    // Every other environment is rejected below.
    break;
  }
  llvm_unreachable(
      "Shader Kind Not Found - Invalid DXIL Environment Specified");
}

/// Build the argument type list of a DXIL op function: a leading i32 (the
/// opcode) followed by the LLVM type of every entry in \p Types, in order.
static SmallVector<Type *>
getArgTypesFromOpParamTypes(ArrayRef<dxil::OpParamType> Types,
                            LLVMContext &Context, Type *OverloadTy) {
  SmallVector<Type *> Result;
  // The opcode always comes first.
  Result.push_back(Type::getInt32Ty(Context));
  for (size_t I = 0, E = Types.size(); I != E; ++I)
    Result.push_back(getTypeFromOpParamType(Types[I], Context, OverloadTy));
  return Result;
}

/// Construct the function type of the DXIL op function for \p OpCode, i.e. the
/// type of a function with the following prototype
///     <return-type> dx.op.<opclass>.<overload>(i32 opcode, <param types>)
/// The return type and <param types> come from the DXIL_OP_FUNCTION_TYPE
/// entries in DXILOperation.inc; \p OverloadTy is substituted wherever an
/// entry uses OpParamType::OverloadTy. The leading i32 opcode parameter is
/// added by getArgTypesFromOpParamTypes.
static FunctionType *getDXILOpFunctionType(dxil::OpCode OpCode,
                                           LLVMContext &Context,
                                           Type *OverloadTy) {

  switch (OpCode) {
// Each DXIL_OP_FUNCTION_TYPE entry of the included file expands to one case
// label that returns the (non-variadic) function type of that operation.
#define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...)                            \
  case OpCode:                                                                 \
    return FunctionType::get(                                                  \
        getTypeFromOpParamType(RetType, Context, OverloadTy),                  \
        getArgTypesFromOpParamTypes({__VA_ARGS__}, Context, OverloadTy),       \
        /*isVarArg=*/false);
#include "DXILOperation.inc"
  }
  llvm_unreachable("Invalid OpCode?");
}

/// Find the entry of \p PropList that applies to \p DXILVer.
///
/// \p PropList is expected to be sorted in ascending order of DXIL version.
/// The list is scanned from the back, and the index of the first entry whose
/// version is not greater than \p DXILVer is returned. When no entry
/// qualifies (including an empty list) the result is std::nullopt.
template <typename T>
static std::optional<size_t> getPropIndex(ArrayRef<T> PropList,
                                          const VersionTuple DXILVer) {
  for (size_t Remaining = PropList.size(); Remaining > 0; --Remaining) {
    const size_t Candidate = Remaining - 1;
    const auto &PropVer = PropList[Candidate].DXILVersion;
    if (VersionTuple(PropVer.Major, PropVer.Minor) <= DXILVer)
      return Candidate;
  }
  return std::nullopt;
}

namespace llvm {
namespace dxil {

// Cache the DXIL version and shader stage from the target triple of M.
// The triple itself is not re-validated here: checks that it is well-formed
// and names a supported target would have been done when the module M was
// constructed in the earlier stages of compilation. The only condition
// rejected is an unknown environment, which is a fatal error.
DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) {
  Triple TargetTriple(Triple(M.getTargetTriple()));
  DXILVersion = TargetTriple.getDXILVersion();
  ShaderStage = TargetTriple.getEnvironment();

  // A known environment type is all that is required.
  if (ShaderStage != Triple::UnknownEnvironment)
    return;

  report_fatal_error(
      Twine(DXILVersion.getAsString()) +
          ": Unknown Compilation Target Shader Stage specified ",
      /*gen_crash_diag*/ false);
}

static Error makeOpError(dxil::OpCode OpCode, Twine Msg) {
  return make_error<StringError>(
      Twine("Cannot create ") + getOpCodeName(OpCode) + " operation: " + Msg,
      inconvertibleErrorCode());
}

/// Try to build a call to the DXIL operation \p OpCode through IRB.
///
/// \param OpCode Operation to call.
/// \param Args   Operands of the operation, excluding the opcode constant,
///               which is injected as the first call argument here.
/// \param Name   Name passed on to the created call instruction.
/// \param RetTy  Return type; only consulted when the operation is overloaded
///               on its return type, in which case it must not be null.
/// \returns the created call, or an Error when the overload type cannot be
/// determined, the overload type or shader stage is not valid for the
/// operation, or no overload/stage entry applies to the module's DXIL
/// version.
Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
                                                ArrayRef<Value *> Args,
                                                const Twine &Name,
                                                Type *RetTy) {
  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);

  // Work out which type, if any, the operation is overloaded on.
  Type *OverloadTy = nullptr;
  if (Prop->OverloadParamIndex == 0) {
    if (!RetTy)
      return makeOpError(OpCode, "Op overloaded on unknown return type");
    OverloadTy = RetTy;
  } else if (Prop->OverloadParamIndex > 0) {
    // The index counts including the return type
    unsigned ArgIndex = Prop->OverloadParamIndex - 1;
    if (ArgIndex >= Args.size())
      return makeOpError(OpCode, "Wrong number of arguments");
    OverloadTy = Args[ArgIndex]->getType();
  }

  FunctionType *DXILOpFT =
      getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);

  std::optional<size_t> OlIndexOrErr =
      getPropIndex(ArrayRef(Prop->Overloads), DXILVersion);
  if (!OlIndexOrErr.has_value())
    return makeOpError(OpCode, Twine("No valid overloads for DXIL version ") +
                                   DXILVersion.getAsString());

  uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys;

  OverloadKind Kind = getOverloadKind(OverloadTy);

  // Check if the operation supports overload types and OverloadTy is valid
  // per the specified types for the operation
  if ((ValidTyMask != OverloadKind::UNDEFINED) &&
      (ValidTyMask & (uint16_t)Kind) == 0)
    return makeOpError(OpCode, "Invalid overload type");

  // Perform necessary checks to ensure Opcode is valid in the targeted shader
  // kind
  std::optional<size_t> StIndexOrErr =
      getPropIndex(ArrayRef(Prop->Stages), DXILVersion);
  if (!StIndexOrErr.has_value())
    return makeOpError(OpCode, Twine("No valid stage for DXIL version ") +
                                   DXILVersion.getAsString());

  // Keep the full width of OpStage::ValidStages; a narrower local would
  // silently drop any stage bits above bit 15.
  uint32_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages;

  // Ensure valid shader stage properties are specified
  if (ValidShaderKindMask == ShaderKind::removed)
    return makeOpError(OpCode, "Operation has been removed");

  // Shader stage need not be validated since getShaderKindEnum() fails
  // for unknown shader stage.

  // Verify the target shader stage is valid for the DXIL operation
  ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage);
  if (!(ValidShaderKindMask & ModuleStagekind))
    return makeOpError(OpCode, "Invalid stage");

  std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
  FunctionCallee DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);

  // We need to inject the opcode as the first argument.
  SmallVector<Value *> OpArgs;
  OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
  OpArgs.append(Args.begin(), Args.end());

  return IRB.CreateCall(DXILFn, OpArgs, Name);
}

/// Build a call to the DXIL operation \p OpCode, treating any failure as a
/// programming error.
///
/// The parameters are forwarded unchanged to tryCreateOp. cantFail is used so
/// that, in builds with assertions, the text of the underlying Error is part
/// of the diagnostic instead of being discarded.
CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
                                  const Twine &Name, Type *RetTy) {
  return cantFail(tryCreateOp(OpCode, Args, Name, RetTy),
                  "Invalid arguments for operation");
}

/// Return the "dx.types.ResRet.*" struct type for \p ElementTy by forwarding
/// to the file-local helper of the same name.
StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
  StructType *ResRetTy = ::getResRetType(ElementTy);
  return ResRetTy;
}

/// Return the "dx.types.Handle" struct type in the context used by IRB.
StructType *DXILOpBuilder::getHandleType() {
  LLVMContext &Ctx = IRB.getContext();
  return ::getHandleType(Ctx);
}

/// Build a constant of type "dx.types.ResBind" whose members are, in order,
/// \p LowerBound, \p UpperBound and \p SpaceID as i32 values followed by the
/// underlying value of \p RC as an i8.
Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound,
                                    uint32_t SpaceID, dxil::ResourceClass RC) {
  Type *I32 = IRB.getInt32Ty();
  Type *I8 = IRB.getInt8Ty();
  Constant *Members[] = {ConstantInt::get(I32, LowerBound),
                         ConstantInt::get(I32, UpperBound),
                         ConstantInt::get(I32, SpaceID),
                         ConstantInt::get(I8, llvm::to_underlying(RC))};
  return ConstantStruct::get(getResBindType(IRB.getContext()), Members);
}

/// Build a constant of type "dx.types.ResourceProperties" holding \p Word0 and
/// \p Word1 as its two i32 members.
Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) {
  Type *I32 = IRB.getInt32Ty();
  Constant *Members[] = {ConstantInt::get(I32, Word0),
                         ConstantInt::get(I32, Word1)};
  return ConstantStruct::get(getResPropsType(IRB.getContext()), Members);
}

/// Return the name of \p DXILOp by forwarding to the global getOpCodeName
/// provided by the included DXILOperation.inc.
const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
  const char *OpName = ::getOpCodeName(DXILOp);
  return OpName;
}
} // namespace dxil
} // namespace llvm