// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = tensorflow;
namespace {
// Converts the ImageFrame into a Tensor with floating-point values.
// Each value is normalized using the provided mean and stddev.
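// For example, with mean = {127.5} and stddev = {127.5}, a pixel value of 255
// maps to (255 - 127.5) / 127.5 = 1.0 and a pixel value of 0 maps to -1.0.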
std::unique_ptr<tf::Tensor> ImageFrameToNormalizedTensor(
const ImageFrame& image_frame,
const mediapipe::proto_ns::RepeatedField<float>& mean,
const mediapipe::proto_ns::RepeatedField<float>& stddev) {
const int cols = image_frame.Width();
const int rows = image_frame.Height();
const int channels = image_frame.NumberOfChannels();
const uint8_t* pixel = image_frame.PixelData();
const int width_padding = image_frame.WidthStep() - cols * channels;
auto tensor = ::absl::make_unique<tf::Tensor>(
tf::DT_FLOAT, tf::TensorShape({rows, cols, channels}));
auto tensor_data = tensor->tensor<float, 3>();
for (int row = 0; row < rows; ++row) {
for (int col = 0; col < cols; ++col) {
for (int channel = 0; channel < channels; ++channel) {
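// Select the normalization constants: per-channel values when several are
// given, otherwise a single shared value (defaulting to mean 0, stddev 1).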
float mean_value = 0;
if (mean.size() > 1) {
mean_value = mean[channel];
} else if (!mean.empty()) {
mean_value = mean[0];
}
float stddev_value = 1;
if (stddev.size() > 1) {
stddev_value = stddev[channel];
} else if (!stddev.empty()) {
stddev_value = stddev[0];
}
tensor_data(row, col, channel) =
(pixel[channel] - mean_value) / stddev_value;
}
pixel += channels;
}
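// Skip any padding bytes at the end of the row.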
pixel += width_padding;
}
return tensor;
}
} // namespace
// Converts ImageFrames to TensorFlow Tensors.
//
// The calculator expects one input (a packet containing an ImageFrame) and
// generates one output (a packet containing a tf::Tensor holding the same
// pixel data). The output tensor will be 3D with dimensions corresponding to
// height, width, and the number of channels (e.g. 3 for RGB or 1 for GRAY8).
//
// This calculator supports ImageFrame objects with any valid format (SRGB,
// SRGBA, GRAY8, GRAY16, and VEC32F1). It will generate a Tensor using DT_UINT8
// for the first three types, DT_UINT16 for GRAY16, and DT_FLOAT for VEC32F1.
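// For example, a 640x480 SRGB frame is converted to a DT_UINT8 tensor of
// shape [480, 640, 3].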
//
// The ImageFrame data can be packed or padded. The pixel data will be copied
// to the Tensor in row-major order.
//
// Example config:
// node {
// calculator: "ImageFrameToTensorCalculator"
// input_stream: "scaled_frames"
// output_stream: "video_tensors"
// }
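//
// Example config with float output and normalization options (a sketch only;
// the options extension path below is assumed to follow the usual MediaPipe
// pattern):
// node {
//   calculator: "ImageFrameToTensorCalculator"
//   input_stream: "scaled_frames"
//   output_stream: "video_tensors"
//   options {
//     [mediapipe.ImageFrameToTensorCalculatorOptions.ext] {
//       data_type: DT_FLOAT
//       mean: 127.5
//       stddev: 127.5
//     }
//   }
// }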
class ImageFrameToTensorCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
ImageFrameToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(ImageFrameToTensorCalculator);
absl::Status ImageFrameToTensorCalculator::GetContract(CalculatorContract* cc) {
// Require exactly one input stream.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<ImageFrame>(
// Input ImageFrame.
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Index(0).Set<tf::Tensor>(
// Output TensorFlow Tensor.
);
return absl::OkStatus();
}
absl::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<ImageFrameToTensorCalculatorOptions>();
// Inform the framework that we always output at the same timestamp as the
// input packet.
cc->SetOffset(TimestampDiff(0));
return absl::OkStatus();
}
absl::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) {
const Packet& input_item = cc->Inputs().Index(0).Value();
RET_CHECK(!input_item.IsEmpty()) << "Input cannot be empty.";
// Extract the ImageFrame and metadata from the input packet.
const ImageFrame& video_frame = input_item.Get<ImageFrame>();
const int bytes_per_pixel = video_frame.ByteDepth();
std::unique_ptr<tf::Tensor> tensor;
if (options_.has_data_type()) {
RET_CHECK_EQ(bytes_per_pixel, 1) << "Unsupported image format ("
<< bytes_per_pixel << " bytes per pixel)";
const tf::DataType data_type = options_.data_type();
RET_CHECK_EQ(data_type, tf::DT_FLOAT)
<< "Unsupported data type " << data_type;
RET_CHECK_GT(options_.stddev().size(), 0) << "You must set a stddev.";
RET_CHECK_GT(options_.stddev()[0], 0.0f) << "The stddev must be positive.";
if (options_.stddev().size() > 1) {
RET_CHECK_EQ(options_.stddev().size(), video_frame.NumberOfChannels())
<< "If specifying multiple stddev normalization values, "
<< "the number must match the number of image channels.";
}
if (options_.mean().size() > 1) {
RET_CHECK_EQ(options_.mean().size(), video_frame.NumberOfChannels())
<< "If specifying multiple mean normalization values, "
<< "the number must match the number of image channels.";
}
tensor = ImageFrameToNormalizedTensor(video_frame, options_.mean(),
options_.stddev());
} else {
const int height = video_frame.Height();
const int width = video_frame.Width();
const int num_channels = video_frame.NumberOfChannels();
const int num_components = width * height * num_channels;
tf::TensorShape tensor_shape({height, width, num_channels});
// Use uint8, uint16, or float as the TF type, depending on the bytes per
// pixel of the ImageFrame.
tf::DataType data_type;
if (bytes_per_pixel == 1) {
data_type = tf::DT_UINT8;
} else if (bytes_per_pixel == 2) {
data_type = tf::DT_UINT16;
} else if (bytes_per_pixel == 4) {
data_type = tf::DT_FLOAT;
} else {
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported image format (", bytes_per_pixel, " bytes per pixel)"));
}
// This failure should never trigger, but it protects the code against
// internal TF changes.
RET_CHECK(tf::DataTypeCanUseMemcpy(data_type))
<< "Tensor data type does not support memcpy (type=" << data_type
<< ")";
// Create the output tensor.
tensor = ::absl::make_unique<tf::Tensor>(data_type, tensor_shape);
// Copy pixel data from the ImageFrame to the tensor.
if (data_type == tf::DT_UINT8) {
uint8_t* dst = tensor->flat<uint8_t>().data();
video_frame.CopyToBuffer(dst, num_components);
} else if (data_type == tf::DT_UINT16) {
uint16_t* dst = tensor->flat<uint16_t>().data();
video_frame.CopyToBuffer(dst, num_components);
} else {
float* dst = tensor->flat<float>().data();
video_frame.CopyToBuffer(dst, num_components);
}
}
cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp());
return absl::OkStatus();
}
} // namespace mediapipe