chromium/third_party/mediapipe/src/mediapipe/util/frame_buffer/halide/rgb_resize_generator.cc

// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "Halide.h"
#include "mediapipe/util/frame_buffer/halide/common.h"

namespace {

using ::Halide::BoundaryConditions::repeat_edge;
using ::mediapipe::frame_buffer::halide::common::resize_bilinear_int;

// Halide generator that resizes an interleaved 8-bit RGB or RGBA image using
// the shared resize_bilinear_int() helper.
//
// NOTE: a Generator's Inputs and Outputs become parameters of the generated
// function in declaration order, so the member order below is part of the
// generated function's signature and must not be shuffled.
class RgbResize : public Halide::Generator<RgbResize> {
 public:
  // Pure variables for the horizontal and vertical pixel coordinates.
  Var x{"x"};
  Var y{"y"};

  // Source image, indexed as (x, y, channel).
  Input<Buffer<uint8_t, 3>> src_rgb{"src_rgb"};

  // Per-axis scale factors forwarded to resize_bilinear_int(). Each defaults
  // to 1.0 and is constrained to the range [0, 1024].
  // NOTE(review): the exact meaning (e.g. source pixels per destination pixel)
  // is defined by the helper — confirm there.
  Input<float> scale_x{"scale_x", 1.0f, 0.0f, 1024.0f};
  Input<float> scale_y{"scale_y", 1.0f, 0.0f, 1024.0f};

  // Resized image: a 3-D uint8 Func with the same (x, y, channel) layout.
  Output<Func> dst_rgb{"dst_rgb", UInt(8), 3};

  // Defines the resize algorithm.
  void generate();
  // Applies loop ordering, specializations and buffer layout constraints.
  void schedule();
};

void RgbResize::generate() {
  // Resize each of the RGB planes independently.
  resize_bilinear_int(repeat_edge(src_rgb), dst_rgb, scale_x, scale_y);
}

void RgbResize::schedule() {
  Halide::Func dst_rgb_func = dst_rgb;
  Halide::Var c = dst_rgb_func.args()[2];
  Halide::OutputImageParam rgb_output = dst_rgb_func.output_buffer();
  Halide::Expr input_rgb_channels = src_rgb.dim(2).extent();
  Halide::Expr output_rgb_channels = rgb_output.dim(2).extent();
  Halide::Expr min_width =
      Halide::min(src_rgb.dim(0).extent(), rgb_output.dim(0).extent());

  // Specialize the generated code for RGB and RGBA (input and output channels
  // must match); further, specialize the vectorized implementation so it only
  // runs on images wide enough to support it.
  const int vector_size = natural_vector_size<uint8_t>();
  const Expr channel_specializations[] = {
      input_rgb_channels == 3 && output_rgb_channels == 3,
      input_rgb_channels == 4 && output_rgb_channels == 4,
  };
  dst_rgb_func.reorder(c, x, y);
  for (const Expr& channel_specialization : channel_specializations) {
    dst_rgb_func.specialize(channel_specialization && min_width >= vector_size)
        .unroll(c)
        .vectorize(x, vector_size);
  }

  // Require that the input/output buffer be interleaved and tightly-
  // packed; that is, either RGBRGBRGB[...] or RGBARGBARGBA[...],
  // without gaps between pixels.
  src_rgb.dim(0).set_stride(input_rgb_channels);
  src_rgb.dim(2).set_stride(1);
  rgb_output.dim(0).set_stride(output_rgb_channels);
  rgb_output.dim(2).set_stride(1);

  // RGB planes starts at index zero in every dimension.
  src_rgb.dim(0).set_min(0);
  src_rgb.dim(1).set_min(0);
  src_rgb.dim(2).set_min(0);
  rgb_output.dim(0).set_min(0);
  rgb_output.dim(1).set_min(0);
  rgb_output.dim(2).set_min(0);
}

}  // namespace

// Registers RgbResize with the Halide generator tooling under the name
// `rgb_resize_generator`.
HALIDE_REGISTER_GENERATOR(RgbResize, rgb_resize_generator)