llvm/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp

//===- llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Test that the SPIR-V backend provides an API call that translates an LLVM IR
/// module into SPIR-V.
//
//===----------------------------------------------------------------------===//

#include "llvm/AsmParser/Parser.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
#include <gmock/gmock.h>
#include <string>
#include <utility>

using ::testing::StartsWith;

namespace llvm {

// Entry point under test, declared by hand here (presumably because the SPIR-V
// backend does not export it through a public header -- TODO confirm). The
// signature must stay identical to the backend's definition.
//   M             - module to translate.
//   SpirvObj      - receives the SPIR-V binary when translation succeeds.
//   ErrMsg        - receives a diagnostic when translation fails.
//   AllowExtNames - names of SPIR-V extensions the translator may use.
//   Opts          - llc-style command-line options applied to this call.
// Returns true on success and false on failure.
extern "C" bool
SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
                     const std::vector<std::string> &AllowExtNames,
                     const std::vector<std::string> &Opts);

/// Fixture that parses textual LLVM IR into a module and hands it to
/// SPIRVTranslateModule.
class SPIRVAPITest : public testing::Test {
protected:
  /// Parses \p Assembly into \c M and translates it to SPIR-V.
  ///
  /// \param Assembly      Textual LLVM IR; may be empty (an empty module).
  /// \param Result        Receives the SPIR-V binary produced on success.
  /// \param ErrMsg        Receives the diagnostic produced on failure.
  /// \param AllowExtNames Extension names forwarded to SPIRVTranslateModule.
  /// \param Opts          Command-line style options forwarded unchanged.
  /// \returns The status returned by SPIRVTranslateModule.
  ///
  /// Invalid IR is a bug in the test rather than in the code under test, so it
  /// aborts via report_fatal_error instead of being reported as a failure.
  bool toSpirv(StringRef Assembly, std::string &Result, std::string &ErrMsg,
               const std::vector<std::string> &AllowExtNames,
               const std::vector<std::string> &Opts) {
    // Tests reuse the same output strings across several calls; start from a
    // clean slate so expectations only see what this translation produced.
    Result.clear();
    ErrMsg.clear();

    SMDiagnostic ParseError;
    M = parseAssemblyString(Assembly, ParseError, Context);
    if (!M) {
      // SMDiagnostic::print prefixes the message with "<ProgName>: ", so pass
      // a program name rather than a sentence that already ends in ": ".
      ParseError.print("SPIRVAPITest", errs());
      report_fatal_error("Can't parse input assembly.");
    }
    bool Status =
        SPIRVTranslateModule(M.get(), Result, ErrMsg, AllowExtNames, Opts);
    if (!Status)
      errs() << ErrMsg << '\n'; // Keep successive diagnostics on their own line.
    return Status;
  }

  /// Context owning all IR parsed by this fixture.
  LLVMContext Context;
  /// Module from the most recent toSpirv call; replaced on every call.
  std::unique_ptr<Module> M;

  /// Calls __spirv_GroupBitwiseAndKHR, which only translates when the
  /// SPV_KHR_uniform_group_instructions extension is enabled.
  static constexpr StringRef ExtensionAssembly = R"(
    define dso_local spir_func void @test1() {
    entry:
      %res1 = tail call spir_func i32 @_Z26__spirv_GroupBitwiseAndKHR(i32 2, i32 0, i32 0)
      ret void
    }

    declare dso_local spir_func i32  @_Z26__spirv_GroupBitwiseAndKHR(i32, i32, i32)
  )";
  /// A kernel taking a byval struct pointer and calling a function with a
  /// vector argument; expected to translate with no extensions enabled.
  static constexpr StringRef OkAssembly = R"(
    %struct = type { [1 x i64] }

    define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
    entry:
      call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
      ret void
    }

    define spir_func void @bar(<2 x i32> noundef) {
    entry:
      ret void
    }
  )";
};

TEST_F(SPIRVAPITest, checkTranslateOk) {
  // Both an empty module and a small kernel must translate, with and without
  // an explicit target triple.
  StringRef Assemblies[] = {"", OkAssembly};
  // Command-line options whose names clash with options already registered by
  // llc/codegen must be written with a ' ' after the dash (e.g. "- mtriple=");
  // the plain spelling is rejected, as checkTranslateError shows.
  std::vector<std::string> SetOfOpts[] = {
      {}, {"- mtriple=spirv32-unknown-unknown"}};
  for (const auto &Opts : SetOfOpts) {
    SCOPED_TRACE(Opts.empty() ? std::string("no options") : Opts.front());
    // StringRef is a cheap non-owning view, so iterate by value.
    for (StringRef Assembly : Assemblies) {
      SCOPED_TRACE(Assembly.empty() ? "empty module" : "OkAssembly");
      std::string Result, Error;
      bool Status = toSpirv(Assembly, Result, Error, {}, Opts);
      // Separate expectations report which condition failed and include the
      // translator's diagnostic in the failure message.
      EXPECT_TRUE(Status) << Error;
      EXPECT_TRUE(Error.empty()) << Error;
      EXPECT_FALSE(Result.empty());
      EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
    }
  }
}

TEST_F(SPIRVAPITest, checkTranslateError) {
  std::string Result, Error;

  // An option that clashes with an llc/codegen-registered name, spelled
  // without the required ' ' prefix, is reported as an unknown argument.
  bool Status = toSpirv(OkAssembly, Result, Error, {},
                        {"-mtriple=spirv32-unknown-unknown"});
  EXPECT_FALSE(Status);
  EXPECT_TRUE(Result.empty());
  EXPECT_THAT(Error,
              StartsWith("SPIRVTranslateModule: Unknown command line argument "
                         "'-mtriple=spirv32-unknown-unknown'"));

  // An invalid optimization level is rejected. Reset the outputs first so this
  // sub-case cannot pass or fail because of text left by the call above.
  Result.clear();
  Error.clear();
  Status = toSpirv(OkAssembly, Result, Error, {}, {"- O 5"});
  EXPECT_FALSE(Status);
  EXPECT_TRUE(Result.empty());
  EXPECT_EQ(Error, "Invalid optimization level!");
}

TEST_F(SPIRVAPITest, checkTranslateSupportExtensionByOpts) {
  // Enabling the required extension through the --spirv-ext option lets the
  // builtin used by ExtensionAssembly be translated.
  std::string Result, Error;
  std::vector<std::string> Opts{
      "--spirv-ext=+SPV_KHR_uniform_group_instructions"};
  bool Status = toSpirv(ExtensionAssembly, Result, Error, {}, Opts);
  // Separate expectations report which condition failed and include the
  // translator's diagnostic in the failure message.
  EXPECT_TRUE(Status) << Error;
  EXPECT_TRUE(Error.empty()) << Error;
  EXPECT_FALSE(Result.empty());
  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
}

TEST_F(SPIRVAPITest, checkTranslateSupportExtensionByArg) {
  // Enabling the required extension through the AllowExtNames argument lets
  // the builtin used by ExtensionAssembly be translated.
  std::string Result, Error;
  std::vector<std::string> ExtNames{"SPV_KHR_uniform_group_instructions"};
  bool Status = toSpirv(ExtensionAssembly, Result, Error, ExtNames, {});
  // Separate expectations report which condition failed and include the
  // translator's diagnostic in the failure message.
  EXPECT_TRUE(Status) << Error;
  EXPECT_TRUE(Error.empty()) << Error;
  EXPECT_FALSE(Result.empty());
  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
}

TEST_F(SPIRVAPITest, checkTranslateSupportExtensionByArgList) {
  // The required extension is mixed with an unrelated one that is listed
  // twice; extra and duplicate entries must not prevent translation.
  std::string Result, Error;
  std::vector<std::string> ExtNames{"SPV_KHR_subgroup_rotate",
                                    "SPV_KHR_uniform_group_instructions",
                                    "SPV_KHR_subgroup_rotate"};
  bool Status = toSpirv(ExtensionAssembly, Result, Error, ExtNames, {});
  // Separate expectations report which condition failed and include the
  // translator's diagnostic in the failure message.
  EXPECT_TRUE(Status) << Error;
  EXPECT_TRUE(Error.empty()) << Error;
  EXPECT_FALSE(Result.empty());
  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
}

TEST_F(SPIRVAPITest, checkTranslateAllExtensions) {
  // "--spirv-ext=all" is expected to enable every extension the backend knows,
  // including the one ExtensionAssembly needs.
  std::string Result, Error;
  std::vector<std::string> Opts{"--spirv-ext=all"};
  bool Status = toSpirv(ExtensionAssembly, Result, Error, {}, Opts);
  // Separate expectations report which condition failed and include the
  // translator's diagnostic in the failure message.
  EXPECT_TRUE(Status) << Error;
  EXPECT_TRUE(Error.empty()) << Error;
  EXPECT_FALSE(Result.empty());
  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
}

TEST_F(SPIRVAPITest, checkTranslateUnknownExtensionByArg) {
  // An extension name the backend does not recognise, passed through
  // AllowExtNames, is reported via the return value and ErrMsg, and no SPIR-V
  // is produced.
  const std::vector<std::string> UnknownExt = {"SPV_XYZ_my_unknown_extension"};
  std::string Spirv;
  std::string Diag;
  EXPECT_FALSE(toSpirv(ExtensionAssembly, Spirv, Diag, UnknownExt, {}));
  EXPECT_TRUE(Spirv.empty());
  EXPECT_EQ(Diag, "Unknown SPIR-V extension: SPV_XYZ_my_unknown_extension");
}

#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
// Death tests: each case expects the process to terminate with the quoted
// diagnostic rather than return an error through toSpirv.
// NOTE(review): why these are limited to builds without NDEBUG is not visible
// here -- confirm before relaxing the guard.

TEST_F(SPIRVAPITest, checkTranslateExtensionError) {
  // No extensions enabled at all: the builtin in ExtensionAssembly cannot be
  // translated.
  std::string Spirv;
  std::string Diag;
  EXPECT_DEATH_IF_SUPPORTED(
      { toSpirv(ExtensionAssembly, Spirv, Diag, {}, {}); },
      "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
      "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
}

TEST_F(SPIRVAPITest, checkTranslateUnknownExtensionByOpts) {
  // An unrecognised extension name inside --spirv-ext is diagnosed while the
  // option itself is being parsed.
  std::string Spirv;
  std::string Diag;
  const std::vector<std::string> CmdLine = {
      "--spirv-ext=+SPV_XYZ_my_unknown_extension"};
  EXPECT_DEATH_IF_SUPPORTED(
      { toSpirv(ExtensionAssembly, Spirv, Diag, {}, CmdLine); },
      "SPIRVTranslateModule: for the --spirv-ext option: Unknown SPIR-V");
}

TEST_F(SPIRVAPITest, checkTranslateWrongExtensionByOpts) {
  // A valid but unrelated extension enabled via --spirv-ext does not satisfy
  // the builtin's requirement.
  std::string Spirv;
  std::string Diag;
  const std::vector<std::string> CmdLine = {
      "--spirv-ext=+SPV_KHR_subgroup_rotate"};
  EXPECT_DEATH_IF_SUPPORTED(
      { toSpirv(ExtensionAssembly, Spirv, Diag, {}, CmdLine); },
      "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
      "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
}

TEST_F(SPIRVAPITest, checkTranslateWrongExtensionByArg) {
  // Same as above, but the unrelated extension is passed via AllowExtNames.
  std::string Spirv;
  std::string Diag;
  const std::vector<std::string> Allowed = {"SPV_KHR_subgroup_rotate"};
  EXPECT_DEATH_IF_SUPPORTED(
      { toSpirv(ExtensionAssembly, Spirv, Diag, Allowed, {}); },
      "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
      "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
}
#endif

} // end namespace llvm