chromium/third_party/mediapipe/src/mediapipe/calculators/tensor/text_to_tensor_calculator.cc

// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/memory_manager.h"
#include "mediapipe/framework/memory_manager_service.h"

namespace mediapipe {
namespace api2 {

// Trivially converts an input string into a Tensor that stores a copy of
// the string.
//
// Inputs:
//   TEXT - std::string
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing a single Tensor storing a copy of the input string.
//     Note that the underlying buffer of the Tensor is not necessarily
//     null-terminated. It is the graph writer's responsibility to copy the
//     correct number of characters when copying from this Tensor's buffer.
//
// Example:
//   node {
//     calculator: "TextToTensorCalculator"
//     input_stream: "TEXT:text"
//     output_stream: "TENSORS:tensors"
//   }
class TextToTensorCalculator : public Node {
 public:
  static constexpr Input<std::string> kTextIn{"TEXT"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};

  MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);

  absl::Status Open(CalculatorContext* cc) final;
  absl::Status Process(CalculatorContext* cc) override;

  static absl::Status UpdateContract(CalculatorContract* cc);

 private:
  // Enable pooling of AHWBs in Tensor instances.
  MemoryManager* memory_manager_ = nullptr;
};

absl::Status TextToTensorCalculator::Open(CalculatorContext* cc) {
  if (cc->Service(kMemoryManagerService).IsAvailable()) {
    memory_manager_ = &cc->Service(kMemoryManagerService).GetObject();
  }
  return absl::OkStatus();
}

absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
  absl::string_view text = kTextIn(cc).Get();
  int input_len = static_cast<int>(text.length());

  std::vector<Tensor> result;
  result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len}),
                    memory_manager_});
  std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
              input_len * sizeof(char));
  kTensorsOut(cc).Send(std::move(result));
  return absl::OkStatus();
}

// static
absl::Status TextToTensorCalculator::UpdateContract(CalculatorContract* cc) {
  cc->UseService(kMemoryManagerService).Optional();
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);

}  // namespace api2
}  // namespace mediapipe