#include "llvm/CodeGen/ReplaceWithVeclib.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

using namespace llvm;
#ifndef NDEBUG
namespace {
static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
if (!Mod)
Err.print("ReplaceWithVecLibTest", errs());
return Mod;
}
class ReplaceWithVecLibTest : public ::testing::Test {
std::string getLastLine(std::string Out) {
if (!Out.empty() && *(Out.cend() - 1) == '\n')
Out.pop_back();
size_t LastNL = Out.find_last_of('\n');
return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1);
}
protected:
LLVMContext Ctx;
std::string run(const VecDesc &VD, const char *IR) {
TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
TLII.addVectorizableFunctions({VD});
FunctionAnalysisManager FAM;
FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
FunctionPassManager FPM;
FPM.addPass(ReplaceWithVeclib());
std::unique_ptr<Module> M = parseIR(Ctx, IR);
PassBuilder PB;
PB.registerFunctionAnalyses(FAM);
bool DebugFlagPrev = llvm::DebugFlag;
llvm::DebugFlag = true;
testing::internal::CaptureStderr();
FPM.run(*M->getFunction("foo"), FAM);
llvm::DebugFlag = DebugFlagPrev;
return getLastLine(testing::internal::GetCapturedStderr());
}
};
}
// Input IR shared by the tests below. Function `foo` calls llvm.powi.f32.i32
// on a scalable <vscale x 4 x float> vector with the i32 exponent 3. The
// pointer itself is const so no test can accidentally repoint the shared
// fixture.
static const char *const IR = R"IR(
define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
%call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
ret <vscale x 4 x float> %call
}
declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
)IR";
// The mapping's VFABI prefix "_ZGVsMxvu" describes the second parameter as
// uniform ('u'), which fits the scalar i32 exponent of llvm.powi. The pass is
// therefore expected to replace the single call in `foo`.
TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
  const VecDesc PowiToVectorPowi = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
                                    ElementCount::getScalable(4),
                                    /*Masked=*/true, "_ZGVsMxvu"};
  const std::string LastDebugLine = run(PowiToVectorPowi, IR);
  EXPECT_EQ(LastDebugLine,
            "Intrinsic calls replaced with vector libraries: 1");
}
// The mapping's VFABI prefix "_ZGVsMxvv" describes the second parameter as a
// vector ('v'), which conflicts with the scalar i32 exponent at index 1. The
// pass must refuse the replacement and report why, without crashing.
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
  const VecDesc MismatchedPowiMapping = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
                                         ElementCount::getScalable(4),
                                         /*Masked=*/true, "_ZGVsMxvv"};
  const std::string LastDebugLine = run(MismatchedPowiMapping, IR);
  EXPECT_EQ(LastDebugLine,
            "replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong "
            "type at index 1: i32");
}
#endif