chromium/third_party/mediapipe/src/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc

// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"

namespace mediapipe {

// Short alias for the TensorFlow namespace; the calculator below refers to
// TensorFlow types as tf::Tensor and tf::TensorShape.
namespace tf = ::tensorflow;

// Given an input Tensor (example dimensions [1, 1024, 1, 5]), it squeezes all
// dimensions with size 1, or dimensions at specific indices, producing a tensor
// containing identical data (example output dimensions [1024, 5]).
//
// Configuration (TensorSqueezeDimensionsCalculatorOptions): exactly one of
//   * squeeze_all_single_dims - remove every dimension of size 1, or
//   * dim (repeated)          - remove the dimensions at the given indices
// must be provided; Open() fails if both or neither are set.
//
// Input:  stream 0, a tf::Tensor.
// Output: stream 0, a tf::Tensor with the squeezed shape, emitted at the input
//         timestamp.
class TensorSqueezeDimensionsCalculator : public CalculatorBase {
 public:
  // Declares exactly one tf::Tensor input stream and exactly one tf::Tensor
  // output stream. Fails if the graph wires up a different number of streams.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Need one input";
    cc->Inputs().Index(0).Set<tf::Tensor>(
        // Input Tensor
    );
    RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) << "Need one output";
    cc->Outputs().Index(0).Set<tf::Tensor>(
        // Output Tensor Reduced Dimensions
    );
    return absl::OkStatus();
  }

  // Reads and validates the options. When explicit indices are configured they
  // are stored in descending order and validated here; otherwise the index
  // list is left empty and is filled in by the first call to Process().
  absl::Status Open(CalculatorContext* cc) override {
    options_ = cc->Options<TensorSqueezeDimensionsCalculatorOptions>();
    RET_CHECK(options_.squeeze_all_single_dims() ^ (options_.dim_size() > 0))
        << "Must specify dimensions to remove, or set squeeze_all_single_dims, "
           "but not both. Received options: "
        << options_.DebugString();
    if (options_.dim_size() > 0) {
      remove_dims_.assign(options_.dim().begin(), options_.dim().end());
      // Descending order lets Process() remove dimensions one at a time
      // without shifting the indices that are still pending.
      std::sort(remove_dims_.rbegin(), remove_dims_.rend());
      // Process() only checks the upper bound of each index, so reject
      // negative indices here instead of handing them to tf::TensorShape.
      // After the descending sort the smallest index is the last element.
      RET_CHECK(remove_dims_.back() >= 0)
          << "Dimension indices must be non-negative. Received options: "
          << options_.DebugString();
      remove_dims_initialized_ = true;
    }
    cc->SetOffset(0);
    return absl::OkStatus();
  }

  // Squeezes the configured dimensions out of the input tensor's shape and
  // emits a tensor with that shape built from the input via
  // tf::Tensor::CopyFrom. Fails if an index is out of range for the input or
  // refers to a dimension whose size is not 1.
  absl::Status Process(CalculatorContext* cc) override {
    const tf::Tensor& input_tensor = cc->Inputs().Index(0).Get<tf::Tensor>();
    tf::TensorShape tensor_shape = input_tensor.shape();
    if (!remove_dims_initialized_) {
      // Happens iff options.squeeze_all_single_dims is set.
      // Initialize remove_dims_ to all dimensions with size 1.
      // NOTE: this runs once, on the first tensor seen. Later packets reuse
      // the same indices and fail the checks below if their shape does not
      // have size 1 at those positions.
      InitializeToRemoveAllSingletonDimensions(tensor_shape);
      remove_dims_initialized_ = true;
    }
    for (const int dim : remove_dims_) {
      RET_CHECK_GT(tensor_shape.dims(), dim)
          << "Dimension " << dim
          << " does not exist in input tensor with num dimensions "
          << input_tensor.dims();
      RET_CHECK_EQ(tensor_shape.dim_size(dim), 1)
          << "Cannot remove dimension " << dim << " with size "
          << tensor_shape.dim_size(dim);
      tensor_shape.RemoveDim(dim);
    }

    auto output_tensor = std::make_unique<tf::Tensor>();
    RET_CHECK(output_tensor->CopyFrom(input_tensor, tensor_shape));
    // Ownership of the tensor is handed over to the output stream.
    cc->Outputs().Index(0).Add(output_tensor.release(), cc->InputTimestamp());
    return absl::OkStatus();
  }

  // Nothing to clean up.
  absl::Status Close(CalculatorContext* cc) override {
    return absl::OkStatus();
  }

 private:
  // Copy of the calculator options captured in Open().
  TensorSqueezeDimensionsCalculatorOptions options_;
  // Indices of the dimensions to remove, in descending order.
  std::vector<int32_t> remove_dims_;
  // True once remove_dims_ holds its final contents. Must start out false:
  // Open() only sets it when explicit indices are configured, and Process()
  // reads it to decide whether the singleton scan is still pending.
  bool remove_dims_initialized_ = false;

  // Appends to remove_dims_ the index of every size-1 dimension of
  // `tensor_shape`, highest index first. Logs an error (but does not fail)
  // when the shape has no such dimension, in which case the calculator passes
  // the shape through unchanged.
  void InitializeToRemoveAllSingletonDimensions(
      const tf::TensorShape& tensor_shape) {
    const int dims = tensor_shape.dims();
    for (int i = dims - 1; i >= 0; --i) {
      if (tensor_shape.dim_size(i) == 1) {
        remove_dims_.push_back(i);
      }
    }
    if (remove_dims_.empty()) {
      ABSL_LOG(ERROR)
          << "TensorSqueezeDimensionsCalculator is squeezing input with "
             "no single-dimensions. Calculator will be a no-op.";
      ABSL_LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape "
                      << tensor_shape.DebugString();
    }
  }
};
REGISTER_CALCULATOR(TensorSqueezeDimensionsCalculator);

}  // namespace mediapipe