chromium/third_party/mediapipe/src/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"

namespace mediapipe {

// Input side packet tag: path of the tfrecord file to read.
constexpr char kTFRecordPath[] = "TFRECORD_PATH";
// Optional input side packet tag: zero-based index of the record to fetch.
constexpr char kRecordIndex[] = "RECORD_INDEX";
// Output side packet tag: the record parsed as a tensorflow::Example.
constexpr char kExampleTag[] = "EXAMPLE";
// Output side packet tag: the record parsed as a tensorflow::SequenceExample.
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";

// Reads a single tensorflow example/sequence example from a tfrecord file and
// publishes it as an output side packet.
//
// Input side packets:
//   TFRECORD_PATH - std::string, path of the tfrecord file to read.
//   RECORD_INDEX  - int, optional. Zero-based index of the record to fetch.
//                   When it is not provided, the first record of the file is
//                   read.
//
// Output side packets (at least one of the two tags must be present):
//   EXAMPLE          - tensorflow::Example parsed from the selected record.
//   SEQUENCE_EXAMPLE - tensorflow::SequenceExample parsed from the selected
//                      record. Only used when EXAMPLE is not present; if both
//                      tags are connected, only EXAMPLE is declared and set.
//
// All of the work happens in Open(); Process() does nothing.
//
// Example config:
// node {
//   calculator: "TFRecordReaderCalculator"
//   input_side_packet: "TFRECORD_PATH:tfrecord_path"
//   input_side_packet: "RECORD_INDEX:record_index"
//   output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// }
class TFRecordReaderCalculator : public CalculatorBase {
 public:
  // Declares the side packet types described above and fails if neither
  // output tag is present.
  static absl::Status GetContract(CalculatorContract* cc);

  // Reads and parses the selected record and sets the output side packet.
  absl::Status Open(CalculatorContext* cc) override;
  // No-op; this calculator only produces side packets.
  absl::Status Process(CalculatorContext* cc) override;
};

// Declares the calculator's side packet contract:
//   - TFRECORD_PATH (required input): std::string.
//   - RECORD_INDEX (optional input): int.
//   - EXAMPLE or SEQUENCE_EXAMPLE (output): at least one must be present.
//     EXAMPLE takes precedence; SEQUENCE_EXAMPLE is only typed when EXAMPLE is
//     absent.
// Returns an error status when neither output tag is present.
absl::Status TFRecordReaderCalculator::GetContract(CalculatorContract* cc) {
  auto& side_inputs = cc->InputSidePackets();
  side_inputs.Tag(kTFRecordPath).Set<std::string>();
  if (side_inputs.HasTag(kRecordIndex)) {
    side_inputs.Tag(kRecordIndex).Set<int>();
  }

  auto& side_outputs = cc->OutputSidePackets();
  const bool emits_example = side_outputs.HasTag(kExampleTag);
  RET_CHECK(emits_example || side_outputs.HasTag(kSequenceExampleTag))
      << "TFRecordReaderCalculator must output either Tensorflow example or "
         "sequence example.";

  if (!emits_example) {
    side_outputs.Tag(kSequenceExampleTag).Set<tensorflow::SequenceExample>();
    return absl::OkStatus();
  }
  side_outputs.Tag(kExampleTag).Set<tensorflow::Example>();
  return absl::OkStatus();
}

// Opens the tfrecord file named by the TFRECORD_PATH side packet, reads
// records sequentially up to and including the one at RECORD_INDEX (0 when
// that side packet is absent), parses that record, and publishes it on the
// EXAMPLE output side packet if that tag is present, otherwise on
// SEQUENCE_EXAMPLE.
//
// Returns an error status when:
//   - RECORD_INDEX is negative,
//   - the file cannot be opened,
//   - any record up to the target index cannot be read (e.g. the file holds
//     fewer records than requested), or
//   - the target record cannot be parsed as the requested proto type.
absl::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) {
  const int target_idx =
      cc->InputSidePackets().HasTag(kRecordIndex)
          ? cc->InputSidePackets().Tag(kRecordIndex).Get<int>()
          : 0;
  // A negative index would otherwise skip the read loop entirely and return
  // OK without ever setting the output side packet.
  RET_CHECK(target_idx >= 0)
      << "RECORD_INDEX must be non-negative, got: " << target_idx;

  std::unique_ptr<tensorflow::RandomAccessFile> file;
  auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile(
      cc->InputSidePackets().Tag(kTFRecordPath).Get<std::string>(), &file);
  RET_CHECK(tf_status.ok())
      << "Failed to open tfrecord file: " << tf_status.ToString();
  // `file` is declared before `reader`, so it outlives the reader that
  // borrows it.
  tensorflow::io::RecordReader reader(file.get(),
                                      tensorflow::io::RecordReaderOptions());

  // Read records one after another; each read overwrites `example_str`, so
  // when the loop exits it holds the record at `target_idx`. The counter is
  // compared before being incremented so it never goes past `target_idx`.
  uint64_t offset = 0;
  tensorflow::tstring example_str;
  for (int current_idx = 0;; ++current_idx) {
    tf_status = reader.ReadRecord(&offset, &example_str);
    RET_CHECK(tf_status.ok())
        << "Failed to read tfrecord: " << tf_status.ToString();
    if (current_idx == target_idx) {
      break;
    }
  }

  // Both branches parse straight from the record buffer and check the
  // result, so a corrupt record is reported instead of silently producing an
  // empty or partially filled proto.
  if (cc->OutputSidePackets().HasTag(kExampleTag)) {
    tensorflow::Example tf_example;
    RET_CHECK(tf_example.ParseFromArray(
        example_str.data(), static_cast<int>(example_str.size())))
        << "Failed to parse tensorflow::Example from tfrecord at index "
        << target_idx;
    cc->OutputSidePackets()
        .Tag(kExampleTag)
        .Set(MakePacket<tensorflow::Example>(std::move(tf_example)));
  } else {
    tensorflow::SequenceExample tf_sequence_example;
    RET_CHECK(tf_sequence_example.ParseFromArray(
        example_str.data(), static_cast<int>(example_str.size())))
        << "Failed to parse tensorflow::SequenceExample from tfrecord at "
           "index "
        << target_idx;
    cc->OutputSidePackets()
        .Tag(kSequenceExampleTag)
        .Set(MakePacket<tensorflow::SequenceExample>(
            std::move(tf_sequence_example)));
  }

  return absl::OkStatus();
}

// Intentionally a no-op: the output side packet is produced in Open(), and
// GetContract() declares only side packets, so there is nothing to do here.
absl::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) {
  return absl::OkStatus();
}

// Registers the calculator so graph configs can refer to it by name
// (see the example config on the class comment: "TFRecordReaderCalculator").
REGISTER_CALCULATOR(TFRecordReaderCalculator);

}  // namespace mediapipe