#include <vector>
#include "testing/gtest/include/gtest/gtest.h"
#include "build/buildflag.h"
#include "third_party/mediapipe/buildflags.h"
#include "third_party/mediapipe/src/mediapipe/framework/calculator.pb.h"
#include "third_party/mediapipe/src/mediapipe/framework/calculator_graph.h"
#include "third_party/mediapipe/src/mediapipe/framework/formats/image_frame.h"
namespace {
// Returns the byte value associated with coordinates (`x`, `y`) for the given
// `iteration`. The OnGpu test below calls it twice per position: once to fill
// the input frame and once to compute the value it expects to read back.
// NOTE(review): presumably marked [[maybe_unused]] because the OnGpu caller is
// compiled only when MEDIAPIPE_BUILD_WITH_GPU_SUPPORT is set - confirm.
[[maybe_unused]] uint8_t GetPixelValue(int x, int y, int iteration) { … }
// `OnCpu` case of the MediaEffectsMediaPipe suite. Unlike the OnGpu case
// below, it is not guarded by MEDIAPIPE_BUILD_WITH_GPU_SUPPORT, so it is
// compiled in every configuration.
TEST(MediaEffectsMediaPipe, OnCpu) { … }
#if BUILDFLAG(MEDIAPIPE_BUILD_WITH_GPU_SUPPORT)
// Pushes a sequence of frames through a two-node graph
// ("ImageFrameToGpuBufferCalculator" followed by
// "GpuBufferToImageFrameCalculator") and verifies that the bytes written into
// every input frame are read back unchanged from the graph output.
TEST(MediaEffectsMediaPipe, OnGpu) {
  constexpr int kNumFrames = 10;
  constexpr int kFrameWidth = 10;
  constexpr int kFrameHeight = 10;

  // Graph topology:
  //   "in" -> ImageFrameToGpuBufferCalculator -> "out1"
  //        -> GpuBufferToImageFrameCalculator -> "out"
  mediapipe::CalculatorGraphConfig config;
  config.add_input_stream("in");
  config.add_output_stream("out");

  auto* node1 = config.add_node();
  *node1->mutable_calculator() = "ImageFrameToGpuBufferCalculator";
  node1->add_input_stream("in");
  node1->add_output_stream("out1");

  auto* node2 = config.add_node();
  *node2->mutable_calculator() = "GpuBufferToImageFrameCalculator";
  node2->add_input_stream("out1");
  node2->add_output_stream("out");

  mediapipe::CalculatorGraph graph;
  absl::Status status = graph.Initialize(config);
  ASSERT_TRUE(status.ok()) << "Failed to initialize calculator graph, error: "
                           << status;

  // The poller is attached before the run is started so that it is in place
  // when packets are fed into the graph below.
  absl::StatusOr<mediapipe::OutputStreamPoller> maybe_poller =
      graph.AddOutputStreamPoller("out");
  ASSERT_TRUE(maybe_poller.ok())
      << "Failed to obtain poller, error: " << maybe_poller.status();
  mediapipe::OutputStreamPoller poller = std::move(maybe_poller.value());

  status = graph.StartRun({});
  ASSERT_TRUE(status.ok()) << "Failed to start run on the graph, error: "
                           << status;

  // Feed `kNumFrames` frames, each stamped with its iteration index.
  for (int i = 0; i < kNumFrames; ++i) {
    mediapipe::ImageFrame input(mediapipe::ImageFormat::SRGBA, kFrameWidth,
                                kFrameHeight);
    uint8_t* pixel_data = input.MutablePixelData();
    // WidthStep() is the byte distance between consecutive rows, so it must be
    // multiplied by the row index. Only the first Width() bytes of every row
    // are written; the check below reads back exactly the same bytes.
    for (int row = 0; row < input.Height(); ++row) {
      for (int col = 0; col < input.Width(); ++col) {
        pixel_data[row * input.WidthStep() + col] = GetPixelValue(col, row, i);
      }
    }

    status = graph.AddPacketToInputStream(
        "in", mediapipe::MakePacket<mediapipe::ImageFrame>(std::move(input))
                  .At(mediapipe::Timestamp(i)));
    ASSERT_TRUE(status.ok())
        << "Failed to add packet to input stream, error: " << status;
  }

  // Closing the input lets the poller loop below terminate once all packets
  // have been delivered.
  status = graph.CloseInputStream("in");
  ASSERT_TRUE(status.ok()) << "Failed to close the input stream, error: "
                           << status;

  int iteration = 0;
  mediapipe::Packet packet;
  while (poller.Next(&packet)) {
    const mediapipe::ImageFrame& output = packet.Get<mediapipe::ImageFrame>();
    // Guard the indexing below: a frame of unexpected size must fail loudly
    // instead of being compared over a different set of bytes.
    ASSERT_EQ(output.Width(), kFrameWidth)
        << "unexpected output width at iteration " << iteration;
    ASSERT_EQ(output.Height(), kFrameHeight)
        << "unexpected output height at iteration " << iteration;

    const uint8_t* pixel_data = output.PixelData();
    for (int row = 0; row < output.Height(); ++row) {
      for (int col = 0; col < output.Width(); ++col) {
        const uint8_t expected = GetPixelValue(col, row, iteration);
        const uint8_t got = pixel_data[row * output.WidthStep() + col];
        // uint8_t streams as a character; cast so the message shows numbers.
        EXPECT_EQ(got, expected)
            << "data mismatch at row " << row << ", column " << col
            << ", iteration " << iteration << ", expected "
            << static_cast<int>(expected) << ", got " << static_cast<int>(got);
      }
    }
    ++iteration;
  }

  EXPECT_EQ(iteration, kNumFrames)
      << "insufficient number of packets read from the graph!";

  status = graph.WaitUntilDone();
  ASSERT_TRUE(status.ok()) << "WaitUntilDone() failed, error: " << status;
}
#endif
}