#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
#include <algorithm>
#include <cassert>
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
// Copies the input patch for one 2D convolution window into one column of an
// im2col-style scratch buffer. (Purpose inferred from the name and
// parameters — TODO confirm against the implementation.)
//
// Parameters (meanings inferred from names; verify before relying on them):
//   input_shape           shape of the tensor pointed to by `in_data`.
//   w, h, b               presumably the output x, output y and batch index
//                         whose receptive field is being extracted.
//   kheight, kwidth       filter (patch) height and width.
//   stride_width, stride_height, pad_width, pad_height
//                         convolution stride and padding. NOTE: these are
//                         passed width-first, whereas the kernel size above is
//                         passed height-first; callers must match this order.
//   in_width, in_height, in_depth
//                         input spatial size and (presumably) channel count.
//   single_buffer_length  presumably the element count of one buffer column;
//   buffer_id             presumably selects which column of
//                         `conv_buffer_data` receives the patch.
//   in_data               source tensor data (read-only).
//   conv_buffer_data      destination buffer (non-const; written here).
//   zero_byte             byte value presumably used to fill patch positions
//                         that fall in the padding region (for quantized types
//                         this would typically be the zero point) — TODO
//                         confirm.
template <typename T>
inline void ExtractPatchIntoBufferColumn(
const RuntimeShape& input_shape, int w, int h, int b, int kheight,
int kwidth, int stride_width, int stride_height, int pad_width,
int pad_height, int in_width, int in_height, int in_depth,
int single_buffer_length, int buffer_id, const T* in_data,
T* conv_buffer_data, uint8_t zero_byte) { … }
// Builds an im2col buffer for a (possibly dilated) 2D convolution, with the
// kernel size taken from `filter_shape`. (Purpose inferred from the name —
// TODO confirm against the implementation.)
//
// This overload takes an array of padding values rather than a single byte:
//   zero_bytes      pointer to `zero_bytes_len` int32 values, presumably the
//                   fill value(s) for padded positions — possibly one per
//                   batch, or a single shared value when the length is 1.
//                   NOTE(review): confirm the accepted lengths with callers.
//   zero_bytes_len  number of entries readable through `zero_bytes`.
// Other parameters:
//   params          convolution parameters (stride/padding/dilation are
//                   presumably read from here).
//   input_shape, input_data    the convolution input.
//   filter_shape    filter tensor shape (source of the kernel size).
//   output_shape    shape of the convolution output.
//   im2col_data     destination buffer (non-const; written here).
template <typename T>
void DilatedIm2col(const ConvParams& params, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& filter_shape,
const RuntimeShape& output_shape, T* im2col_data,
const int32_t* zero_bytes, const int zero_bytes_len) { … }
// Builds an im2col buffer for a (possibly dilated) 2D convolution using a
// single padding fill value. (Purpose inferred from the name — TODO confirm.)
//
// NOTE: unlike the array-of-zero-points overload, `zero_byte` is the SECOND
// argument here, before `input_shape`; the argument orders of the two
// overloads are not parallel.
//   params          convolution parameters.
//   zero_byte       single byte value presumably used for every padded
//                   position (for quantized types, typically the zero point).
//   input_shape, input_data    the convolution input.
//   filter_shape    filter tensor shape (presumably the source of the kernel
//                   size).
//   output_shape    shape of the convolution output.
//   im2col_data     destination buffer (non-const; written here).
template <typename T>
void DilatedIm2col(const ConvParams& params, uint8_t zero_byte,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& filter_shape,
const RuntimeShape& output_shape, T* im2col_data) { … }
// Builds an im2col buffer for a 2D convolution with an explicit kernel size
// and a single padding fill value. (Purpose inferred from the name — TODO
// confirm.) No filter shape is taken; the kernel size comes from
// `kheight`/`kwidth`. Dilation is presumably not handled here, since separate
// DilatedIm2col functions exist — verify before using with dilated convs.
//   params            convolution parameters.
//   kheight, kwidth   filter height and width.
//   zero_byte         byte value presumably used for padded positions.
//   input_shape, input_data     the convolution input.
//   output_shape      shape of the destination buffer `output_data`.
//   output_data       destination buffer (non-const; written here).
template <typename T>
void Im2col(const ConvParams& params, int kheight, int kwidth,
uint8_t zero_byte, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& output_shape,
T* output_data) { … }
// Builds an im2col buffer for a 2D convolution with an explicit kernel size,
// taking an array of padding offsets instead of a single fill byte. (Purpose
// inferred from the name — TODO confirm.)
//   params              convolution parameters.
//   kheight, kwidth     filter height and width.
//   input_offsets       pointer to `input_offsets_size` int32 values;
//                       presumably per-batch fill values for padded positions
//                       (e.g. per-batch zero points). NOTE(review): confirm
//                       the expected length with callers.
//   input_offsets_size  number of entries readable through `input_offsets`.
//   input_shape, input_data     the convolution input.
//   output_shape        shape of the destination buffer `output_data`.
//   output_data         destination buffer (non-const; written here).
template <typename T>
void Im2col(const ConvParams& params, int kheight, int kwidth,
const int32_t* input_offsets, const int input_offsets_size,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) { … }
// 3D counterpart of the 2D patch-extraction helper: copies the input patch
// for one 3D convolution window into an im2col-style buffer. (Purpose
// inferred from the name and parameters — TODO confirm.)
//
// Unlike the 2D helper, every triple here is ordered depth, height, width.
// Also note the naming difference: here `in_depth` is a spatial dimension
// and the channel count is the separate `in_channel` argument.
//   b, d, h, w          presumably batch index and output depth/row/column
//                       of the window being extracted.
//   kdepth, kheight, kwidth              filter size.
//   stride_depth, stride_height, stride_width   convolution strides.
//   pad_depth, pad_height, pad_width            convolution padding.
//   in_depth, in_height, in_width, in_channel   input dimensions.
//   output_row_offset   presumably the position in `conv_buffer_data` at
//                       which this patch is written — TODO confirm units.
//   in_data             source tensor data (read-only).
//   conv_buffer_data    destination buffer (non-const; written here).
//   zero_byte           byte value presumably used for padded positions.
template <typename T>
inline void ExtractPatchIntoBufferColumn3D(
int b, int d, int h, int w,
int kdepth, int kheight, int kwidth,
int stride_depth, int stride_height, int stride_width,
int pad_depth, int pad_height, int pad_width,
int in_depth, int in_height, int in_width, int in_channel,
int output_row_offset, const T* in_data, T* conv_buffer_data,
uint8_t zero_byte) { … }
// Builds an im2col buffer for a 3D convolution with an explicit kernel size
// and a single padding fill value. (Purpose inferred from the name — TODO
// confirm.) Dilation is presumably not handled here, since a separate
// DilatedIm2col3D exists — verify before using with dilated convs.
//   params                  3D convolution parameters.
//   kdepth, kheight, kwidth filter size (depth, height, width order).
//   zero_byte               byte value presumably used for padded positions.
//   input_shape, input_data     the convolution input.
//   im2col_shape            shape of the destination buffer `im2col_data`.
//   im2col_data             destination buffer (non-const; written here).
template <typename T>
void Im2col3D(const Conv3DParams& params, int kdepth, int kheight, int kwidth,
uint8_t zero_byte, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& im2col_shape,
T* im2col_data) { … }
// Builds an im2col buffer for a (possibly dilated) 3D convolution. (Purpose
// inferred from the name — TODO confirm.) Dilation factors are presumably
// read from `params`.
//   params          3D convolution parameters.
//   filter_depth, filter_height, filter_width
//                   filter size (depth, height, width order).
//   zero_byte       byte value presumably used for padded positions.
//   input_shape, input_data     the convolution input.
//   im2col_shape    shape of the destination buffer `im2col_data`.
//   im2col_data     destination buffer (non-const; written here).
template <typename T>
inline void DilatedIm2col3D(const Conv3DParams& params, int filter_depth,
int filter_height, int filter_width,
uint8_t zero_byte, const RuntimeShape& input_shape,
const T* input_data,
const RuntimeShape& im2col_shape, T* im2col_data) { … }
}
}
#endif