llvm/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

//===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILTranslateMetadata.h"
#include "DXILResource.h"
#include "DXILResourceAnalysis.h"
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/VersionTuple.h"
#include "llvm/TargetParser/Triple.h"
#include <cstdint>

using namespace llvm;
using namespace llvm::dxil;

namespace {
/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
/// for TranslateMetadata pass
class DiagnosticInfoTranslateMD : public DiagnosticInfo {
private:
  const Twine &Msg;
  const Module &Mod;

public:
  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
  /// the message to show. Note that this class does not copy this message, so
  /// this reference must be valid for the whole life time of the diagnostic.
  DiagnosticInfoTranslateMD(const Module &M, const Twine &Msg,
                            DiagnosticSeverity Severity = DS_Error)
      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}

  void print(DiagnosticPrinter &DP) const override {
    DP << Mod.getName() << ": " << Msg << '\n';
  }
};

/// Tags that identify the kind of each tag-value pair in an entry point's
/// property list. The numeric value of a tag is what gets written into the
/// metadata as an i32 constant, so the values are spelled out explicitly and
/// must not be renumbered.
enum class EntryPropsTag {
  ShaderFlags = 0,
  GSState = 1,
  DSState = 2,
  HSState = 3,
  NumThreads = 4,
  AutoBindingSpace = 5,
  RayPayloadSize = 6,
  RayAttribSize = 7,
  ShaderKind = 8,
  MSState = 9,
  ASStateTag = 10,
  WaveSize = 11,
  EntryRootSig = 12,
};

} // namespace

/// Appends a resource table record to the "dx.resources" named metadata of
/// \p M.
///
/// The record has four operands, in order: SRVs, UAVs, cbuffers and samplers.
/// Each operand is a list of per-resource metadata, or null when that class
/// has no resources. Resources come from \p DRM; UAV and cbuffer tables may
/// instead be written by \p MDResources, which must not coexist with the
/// corresponding \p DRM entries.
///
/// \returns the "dx.resources" node, or nullptr when there are no resources
/// at all (in which case the module is left untouched).
static NamedMDNode *emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
                                         const dxil::Resources &MDResources) {
  LLVMContext &Ctx = M.getContext();

  // Gather the metadata of every resource, one list per resource class.
  SmallVector<Metadata *> SRVList, UAVList, CBufferList, SamplerList;
  for (const ResourceInfo &Info : DRM.srvs())
    SRVList.push_back(Info.getAsMetadata(Ctx));
  for (const ResourceInfo &Info : DRM.uavs())
    UAVList.push_back(Info.getAsMetadata(Ctx));
  for (const ResourceInfo &Info : DRM.cbuffers())
    CBufferList.push_back(Info.getAsMetadata(Ctx));
  for (const ResourceInfo &Info : DRM.samplers())
    SamplerList.push_back(Info.getAsMetadata(Ctx));

  // A class without any resource is represented by a null table.
  auto MakeTable = [&Ctx](const SmallVectorImpl<Metadata *> &List) {
    Metadata *Table = nullptr;
    if (!List.empty())
      Table = MDNode::get(Ctx, List);
    return Table;
  };
  Metadata *SRVTable = MakeTable(SRVList);
  Metadata *UAVTable = MakeTable(UAVList);
  Metadata *CBufferTable = MakeTable(CBufferList);
  Metadata *SamplerTable = MakeTable(SamplerList);

  bool AnyResources = !DRM.empty();

  // Tables supplied by the older resource representation.
  if (MDResources.hasUAVs()) {
    assert(!UAVTable && "Old and new UAV representations can't coexist");
    UAVTable = MDResources.writeUAVs(M);
    AnyResources = true;
  }
  if (MDResources.hasCBuffers()) {
    assert(!CBufferTable &&
           "Old and new cbuffer representations can't coexist");
    CBufferTable = MDResources.writeCBuffers(M);
    AnyResources = true;
  }

  if (!AnyResources)
    return nullptr;

  Metadata *Tables[] = {SRVTable, UAVTable, CBufferTable, SamplerTable};
  NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
  ResourceMD->addOperand(MDNode::get(Ctx, Tables));
  return ResourceMD;
}

/// Maps a shader environment to the short stage abbreviation used in the
/// "dx.shaderModel" metadata (for example "ps" for Triple::Pixel).
///
/// Any environment that is not listed below is treated as unreachable.
static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
  switch (Env) {
  case Triple::Vertex:
    return "vs";
  case Triple::Pixel:
    return "ps";
  case Triple::Geometry:
    return "gs";
  case Triple::Hull:
    return "hs";
  case Triple::Domain:
    return "ds";
  case Triple::Compute:
    return "cs";
  case Triple::Mesh:
    return "ms";
  case Triple::Amplification:
    return "as";
  case Triple::Library:
    return "lib";
  default:
    llvm_unreachable("Unsupported environment for DXIL generation.");
  }
}

/// Returns the position of \p Env relative to Triple::Pixel within the
/// Triple::EnvironmentType enumeration; Triple::Pixel itself maps to 0.
/// NOTE(review): presumably this offset matches the DXIL shader kind
/// numbering - confirm against the enumerator order in Triple.h.
static uint32_t getShaderStage(Triple::EnvironmentType Env) {
  const uint32_t FirstStage = static_cast<uint32_t>(llvm::Triple::Pixel);
  const uint32_t ThisStage = static_cast<uint32_t>(Env);
  return ThisStage - FirstStage;
}

/// Encodes one entry property as a (tag, value) pair of constant metadata.
///
/// The tag is always emitted as an i32. The value is emitted as an i64 for
/// EntryPropsTag::ShaderFlags and as an i32 for EntryPropsTag::ShaderKind.
/// Every other tag is not implemented yet and is treated as unreachable.
///
/// \returns a two element vector holding the tag followed by the value.
static SmallVector<Metadata *>
getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
  auto AsConstantMD = [](IntegerType *Ty, uint64_t V) -> Metadata * {
    return ConstantAsMetadata::get(ConstantInt::get(Ty, V));
  };

  SmallVector<Metadata *> TagAndValue;
  TagAndValue.push_back(
      AsConstantMD(Type::getInt32Ty(Ctx), static_cast<int>(Tag)));

  switch (Tag) {
  // Tags whose value is a single integer.
  case EntryPropsTag::ShaderFlags:
    TagAndValue.push_back(AsConstantMD(Type::getInt64Ty(Ctx), Value));
    break;
  case EntryPropsTag::ShaderKind:
    TagAndValue.push_back(AsConstantMD(Type::getInt32Ty(Ctx), Value));
    break;
  // Tags that cannot be encoded by this helper yet.
  case EntryPropsTag::GSState:
  case EntryPropsTag::DSState:
  case EntryPropsTag::HSState:
  case EntryPropsTag::NumThreads:
  case EntryPropsTag::AutoBindingSpace:
  case EntryPropsTag::RayPayloadSize:
  case EntryPropsTag::RayAttribSize:
  case EntryPropsTag::MSState:
  case EntryPropsTag::ASStateTag:
  case EntryPropsTag::WaveSize:
  case EntryPropsTag::EntryRootSig:
    llvm_unreachable("NYI: Unhandled entry property tag");
  }
  return TagAndValue;
}

/// Builds the tag-value property list attached to an entry point's metadata.
///
/// \param EP Properties of the entry point. EP.Entry must be non-null because
///        the LLVMContext is obtained from it.
/// \param EntryShaderFlags Shader flags mask of the entry; a ShaderFlags
///        property is emitted only when it is non-zero.
/// \param ShaderProfile Target profile the module is being compiled for.
/// \returns a tuple of tag-value pairs, or nullptr when no property applies.
static MDTuple *
getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
                       const Triple::EnvironmentType ShaderProfile) {
  // The context can only be reached through the entry function, so a null
  // entry cannot be handled here. State the invariant up front instead of
  // testing the pointer after it has already been dereferenced.
  assert(EP.Entry && "Entry properties must refer to an entry function");
  LLVMContext &Ctx = EP.Entry->getContext();

  SmallVector<Metadata *> MDVals;
  if (EntryShaderFlags != 0)
    MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
                                        EntryShaderFlags, Ctx));

  // FIXME: support more props.
  // See https://github.com/llvm/llvm-project/issues/57948.
  // Add shader kind for lib entries.
  if (ShaderProfile == Triple::EnvironmentType::Library &&
      EP.ShaderStage != Triple::EnvironmentType::Library)
    MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
                                        getShaderStage(EP.ShaderStage), Ctx));

  // Compute entries additionally record their thread group dimensions as a
  // NumThreads tag followed by an (X, Y, Z) tuple of i32 constants.
  if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
    MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
        Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
    Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
                                     Type::getInt32Ty(Ctx), EP.NumThreadsX)),
                                 ConstantAsMetadata::get(ConstantInt::get(
                                     Type::getInt32Ty(Ctx), EP.NumThreadsY)),
                                 ConstantAsMetadata::get(ConstantInt::get(
                                     Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
    MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
  }

  if (MDVals.empty())
    return nullptr;
  return MDNode::get(Ctx, MDVals);
}

/// Assembles one entry point metadata record.
///
/// The record is a five element tuple made of:
///  * a reference to the entry point function (null when \p EntryFn is null)
///  * the name of the entry function (empty when \p EntryFn is null)
///  * the list of signatures
///  * the list of resources
///  * the list of tag-value pairs of shader capabilities and other properties
MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
                                MDNode *Resources, MDTuple *Properties,
                                LLVMContext &Ctx) {
  Metadata *FnRef = nullptr;
  StringRef FnName = "";
  if (EntryFn) {
    // ValueAsMetadata::get takes a mutable Value, hence the const_cast.
    FnRef = ValueAsMetadata::get(const_cast<Function *>(EntryFn));
    FnName = EntryFn->getName();
  }

  Metadata *Fields[] = {FnRef, MDString::get(Ctx, FnName), Signatures,
                        Resources, Properties};
  return MDNode::get(Ctx, Fields);
}

/// Creates the metadata record of the entry point described by \p EP.
///
/// The entry's property list is computed from \p EP, \p EntryShaderFlags and
/// \p ShaderProfile, then combined with \p Signatures and \p MDResources into
/// a single record. EP.Entry is dereferenced to reach the LLVMContext.
static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
                            MDNode *MDResources,
                            const uint64_t EntryShaderFlags,
                            const Triple::EnvironmentType ShaderProfile) {
  LLVMContext &Ctx = EP.Entry->getContext();
  MDTuple *Props = getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
  return constructEntryMetadata(EP.Entry, Signatures, MDResources, Props, Ctx);
}

/// Writes the validator version recorded in \p MMDI to the "dx.valver" named
/// metadata of \p M as a (major, minor) pair of i32 constants. A missing
/// minor component is written as 0. Nothing is emitted when the validator
/// version is empty.
static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
  const auto &ValVer = MMDI.ValidatorVersion;
  if (ValVer.empty())
    return;

  LLVMContext &Ctx = M.getContext();
  IRBuilder<> Builder(Ctx);
  auto AsI32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };
  Metadata *MajorMinor[] = {AsI32(ValVer.getMajor()),
                            AsI32(ValVer.getMinor().value_or(0))};

  // The version obtained from the DXIL Metadata Analysis pass replaces
  // whatever "dx.valver" held before.
  NamedMDNode *ValVerMD = M.getOrInsertNamedMetadata("dx.valver");
  ValVerMD->clearOperands();
  ValVerMD->addOperand(MDNode::get(Ctx, MajorMinor));
}

/// Appends a (stage, major, minor) record to the "dx.shaderModel" named
/// metadata of \p M. The stage is the short abbreviation of the shader
/// profile in \p MMDI and the two numbers are the shader model version as i32
/// constants; a missing minor component is written as 0.
static void emitShaderModelVersionMD(Module &M,
                                     const ModuleMetadataInfo &MMDI) {
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> Builder(Ctx);
  auto AsI32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };

  const VersionTuple ShaderModel = MMDI.ShaderModelVersion;
  Metadata *Record[] = {
      MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile)),
      AsI32(ShaderModel.getMajor()),
      AsI32(ShaderModel.getMinor().value_or(0))};

  NamedMDNode *ShaderModelMD = M.getOrInsertNamedMetadata("dx.shaderModel");
  ShaderModelMD->addOperand(MDNode::get(Ctx, Record));
}

/// Appends the DXIL version recorded in \p MMDI to the "dx.version" named
/// metadata of \p M as a (major, minor) pair of i32 constants. A missing
/// minor component is written as 0.
static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> Builder(Ctx);
  auto AsI32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };

  const VersionTuple Version = MMDI.DXILVersion;
  Metadata *MajorMinor[] = {AsI32(Version.getMajor()),
                            AsI32(Version.getMinor().value_or(0))};

  NamedMDNode *VersionMD = M.getOrInsertNamedMetadata("dx.version");
  VersionMD->addOperand(MDNode::get(Ctx, MajorMinor));
}

/// Creates the entry metadata record that stands for a library as a whole.
///
/// The record has no function, an empty name and no signatures; it carries
/// only the resource table \p RMD and, when \p ShaderFlags is non-zero, a
/// property list holding the ShaderFlags tag-value pair.
static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
                                        uint64_t ShaderFlags) {
  LLVMContext &Ctx = M.getContext();

  // FIXME: ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags
  // for each entry function. Currently, the ShaderFlags value provided by the
  // ShaderFlagsAnalysis pass is created by walking *all* the function
  // instructions of the module. Is it correct to use this value for the
  // metadata of the empty library entry?
  MDTuple *LibraryProps = nullptr;
  if (ShaderFlags != 0)
    LibraryProps = MDNode::get(
        Ctx,
        getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));

  // Apart from the resource table and the optional properties, every other
  // field of the library record is null.
  return constructEntryMetadata(/*EntryFn=*/nullptr, /*Signatures=*/nullptr,
                                RMD, LibraryProps, Ctx);
}

/// Emits all DXIL module-level metadata for \p M.
///
/// In order, this writes the validator version ("dx.valver"), the shader
/// model ("dx.shaderModel"), the DXIL version ("dx.version"), the resource
/// table ("dx.resources") and finally one record per entry point in
/// "dx.entryPoints". For a library profile an extra top-level record that
/// carries the module shader flags precedes the per-entry records.
///
/// Diagnostics are reported through the LLVMContext when a non-library module
/// has more than one entry, or when the stage of an entry differs from the
/// target profile. Metadata is still emitted after a diagnostic.
///
/// \param M Module to annotate.
/// \param DRM Resources of the module.
/// \param MDResources Resources in the older metadata representation.
/// \param ShaderFlags Shader flags computed for the whole module.
/// \param MMDI Module metadata information (versions, profile, entries).
static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                              const Resources &MDResources,
                              const ComputedShaderFlags &ShaderFlags,
                              const ModuleMetadataInfo &MMDI) {
  LLVMContext &Ctx = M.getContext();
  const bool IsLibrary =
      MMDI.ShaderProfile == Triple::EnvironmentType::Library;
  // Convert the flags to their 64-bit mask once and explicitly. Using the
  // flags object directly as an operand of a conditional expression whose
  // other operand is the literal 0 makes the expression's type `int`, which
  // drops every flag above bit 31.
  const uint64_t ModuleShaderFlags = static_cast<uint64_t>(ShaderFlags);

  emitValidatorVersionMD(M, MMDI);
  emitShaderModelVersionMD(M, MMDI);
  emitDXILVersionTupleMD(M, MMDI);
  NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, MDResources);
  MDNode *ResourceMD =
      (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
  // FIXME: Add support to construct Signatures
  // See https://github.com/llvm/llvm-project/issues/57928
  MDTuple *Signatures = nullptr;

  SmallVector<MDNode *> EntryFnMDNodes;
  if (IsLibrary)
    EntryFnMDNodes.emplace_back(
        emitTopLevelLibraryNode(M, ResourceMD, ModuleShaderFlags));
  else if (MMDI.EntryPropertyVec.size() > 1)
    Ctx.diagnose(DiagnosticInfoTranslateMD(
        M, "Non-library shader: One and only one entry expected"));

  // FIXME: ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags
  // for each entry function. For now, assume the shader flags value of entry
  // functions being compiled for a lib_* shader profile is 0, and use the
  // module-wide flags for the entry of any other profile.
  const uint64_t EntryShaderFlags = IsLibrary ? 0 : ModuleShaderFlags;

  for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
    if (!IsLibrary && EntryProp.ShaderStage != MMDI.ShaderProfile) {
      // All Twine temporaries live until the end of the diagnose() call, which
      // is as long as the diagnostic needs the message.
      Ctx.diagnose(DiagnosticInfoTranslateMD(
          M, "Shader stage '" + getShortShaderStage(EntryProp.ShaderStage) +
                 "' for entry '" + EntryProp.Entry->getName() +
                 "' different from specified target profile '" +
                 Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + "'"));
    }
    EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
                                            EntryShaderFlags,
                                            MMDI.ShaderProfile));
  }

  NamedMDNode *EntryPointsNamedMD =
      M.getOrInsertNamedMetadata("dx.entryPoints");
  for (MDNode *Entry : EntryFnMDNodes)
    EntryPointsNamedMD->addOperand(Entry);
}

/// New pass manager entry point: fetches the resource, shader flag and module
/// metadata analyses for \p M from \p MAM and emits the DXIL metadata.
///
/// \returns PreservedAnalyses::all().
PreservedAnalyses DXILTranslateMetadata::run(Module &M,
                                             ModuleAnalysisManager &MAM) {
  const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
  const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
  const ComputedShaderFlags &ShaderFlags =
      MAM.getResult<ShaderFlagsAnalysis>(M);
  // Bind by reference like the other results: the analysis result is only
  // read here, so copying it (including its entry list) is unnecessary.
  const dxil::ModuleMetadataInfo &MMDI =
      MAM.getResult<DXILMetadataAnalysis>(M);

  translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);

  return PreservedAnalyses::all();
}

namespace {
/// Legacy pass manager wrapper around translateMetadata().
class DXILTranslateMetadataLegacy : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}

  StringRef getPassName() const override { return "DXIL Translate Metadata"; }

  /// Requires the resource, shader flag and module metadata analyses, and
  /// declares the resource and module metadata analyses as preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DXILResourceWrapperPass>();
    AU.addRequired<DXILResourceMDWrapper>();
    AU.addRequired<ShaderFlagsAnalysisWrapper>();
    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
    AU.addPreserved<DXILResourceWrapperPass>();
    AU.addPreserved<DXILResourceMDWrapper>();
    AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
  }

  /// Emits the DXIL metadata for \p M. Always reports the module as modified.
  bool runOnModule(Module &M) override {
    const DXILResourceMap &DRM =
        getAnalysis<DXILResourceWrapperPass>().getResourceMap();
    const dxil::Resources &MDResources =
        getAnalysis<DXILResourceMDWrapper>().getDXILResource();
    const ComputedShaderFlags &ShaderFlags =
        getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
    // Bind by const reference like the other results: the metadata info is
    // only read here, so there is no need to copy it.
    const dxil::ModuleMetadataInfo &MMDI =
        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

    translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
    return true;
  }
};

} // namespace

// Out-of-class definition of the pass identification member declared in
// DXILTranslateMetadataLegacy.
char DXILTranslateMetadataLegacy::ID = 0;

/// Factory for the legacy pass manager version of the pass. Returns a newly
/// allocated DXILTranslateMetadataLegacy instance.
ModulePass *llvm::createDXILTranslateMetadataLegacyPass() {
  return new DXILTranslateMetadataLegacy();
}

// Initialization of the legacy pass under the argument
// "dxil-translate-metadata". The dependencies listed here are the same four
// analyses that getAnalysisUsage() marks as required.
INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
                      "DXIL Translate Metadata", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper)
INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
                    "DXIL Translate Metadata", false, false)