chromium/third_party/libaom/source/libaom/av1/encoder/tx_search.c

/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/cfl.h"
#include "av1/common/reconintra.h"
#include "av1/encoder/block.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
#include "av1/common/idct.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/random.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/sorting_network.h"
#include "av1/encoder/tx_prune_model_weights.h"
#include "av1/encoder/tx_search.h"
#include "av1/encoder/txb_rdopt.h"

#define PROB_THRESH_OFFSET_TX_TYPE

struct rdcost_block_args {};

TxCandidateInfo;

// origin_threshold * 128 / 100
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] =;

// lookup table for predict_skip_txfm
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] =;

// look-up table for sqrt of number of pixels in a transform block
// rounded up to the nearest integer.
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] =;

static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {}

static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
                                      const int64_t ref_best_rd,
                                      const uint32_t hash) {}

static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info,
                                    RD_STATS *const rd_stats,
                                    MACROBLOCK *const x) {}

int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row,
                            int blk_col, const BLOCK_SIZE plane_bsize,
                            const BLOCK_SIZE tx_bsize,
                            unsigned int *block_mse_q8) {}

// Computes the residual block's SSE and mean on all visible 4x4s in the
// transform block
static inline int64_t pixel_diff_stats(
    MACROBLOCK *x, int plane, int blk_row, int blk_col,
    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {}

// Uses simple features on top of DCT coefficients to quickly predict
// whether optimal RD decision is to skip encoding the residual.
// The sse value is stored in dist.
static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
                             int reduced_tx_set) {}

// Used to set proper context for early termination with skip = 1.
static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
                                 BLOCK_SIZE bsize, int64_t dist) {}

static inline void save_mb_rd_info(int n4, uint32_t hash,
                                   const MACROBLOCK *const x,
                                   const RD_STATS *const rd_stats,
                                   MB_RD_RECORD *mb_rd_record) {}

static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
                                 const SPEED_FEATURES *sf,
                                 int tx_size_search_method) {}

static inline void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode);

// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
// 0: Do not collect any RD stats
// 1: Collect RD stats for transform units
// 2: Collect RD stats for partition units
#if CONFIG_COLLECT_RD_STATS

static inline void get_energy_distribution_fine(
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    // functions for the 16 (very small) sub-blocks of this block.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params->use_highbitdepth) {
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] +=
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
        }
    } else {
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
        }
    }
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[1]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[2]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[3]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[5]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[6]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[7]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[9]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[10]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[11]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[13]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[14]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[15]);
  }

  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
                 esq[12] + esq[13] + esq[14] + esq[15];
  if (total > 0) {
    const double e_recip = 1.0 / total;
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    if (need_4th) {
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    }
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    if (need_4th) {
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    }
  } else {
    hordist[0] = verdist[0] = 0.25;
    hordist[1] = verdist[1] = 0.25;
    hordist[2] = verdist[2] = 0.25;
    if (need_4th) {
      hordist[3] = verdist[3] = 0.25;
    }
  }
}

static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int err = diff[j * stride + i];
      sum += err * err;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += abs(diff[j * stride + i]);
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static inline void get_2x2_normalized_sses_and_sads(
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    int src_stride, const uint8_t *const dst, int dst_stride,
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    double *const sad_norm_arr) {
  const BLOCK_SIZE tx_bsize_half =
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
    const int half_width = block_size_wide[tx_bsize] / 2;
    const int half_height = block_size_high[tx_bsize] / 2;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const int16_t *const this_src_diff =
            src_diff + row * half_height * diff_stride + col * half_width;
        if (sse_norm_arr) {
          sse_norm_arr[row * 2 + col] =
              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
        }
        if (sad_norm_arr) {
          sad_norm_arr[row * 2 + col] =
              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
        }
      }
    }
  } else {  // use function pointers to calculate stats
    const int half_width = block_size_wide[tx_bsize_half];
    const int half_height = block_size_high[tx_bsize_half];
    const int num_samples_half = half_width * half_height;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const uint8_t *const this_src =
            src + row * half_height * src_stride + col * half_width;
        const uint8_t *const this_dst =
            dst + row * half_height * dst_stride + col * half_width;

        if (sse_norm_arr) {
          unsigned int this_sse;
          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
                                             dst_stride, &this_sse);
          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
        }

        if (sad_norm_arr) {
          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
              this_src, src_stride, this_dst, dst_stride);
          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
        }
      }
    }
  }
}

#if CONFIG_COLLECT_RD_STATS == 1
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += diff[j * stride + i];
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}
static inline void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Generate small sample to restrict output size.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 > 0) return;

  const char output_file[] = "tu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int num_samples = txw * txh;

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;

  fprintf(fout, "%g %g", rate_norm, dist_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm = (double)sad / num_samples;

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];

  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);

  const double mean = get_mean(src_diff, diff_stride, txw, txh);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if CONFIG_COLLECT_RD_STATS >= 2
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs =
        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
    unsigned int sse;

    if (plane) continue;

    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
                            pd->dst.stride, &sse);
    total_sse += sse;
  }
  total_sse <<= 4;
  return total_sse;
}

static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (md->ready) {
    if (sse < md->dist_mean) {
      *est_residue_cost = 0;
      *est_dist = sse;
    } else {
      *est_dist = (int64_t)round(md->dist_mean);
      const double est_ld = md->a * sse + md->b;
      // Clamp estimated rate cost by INT_MAX / 2.
      // TODO([email protected]): find better solution than clamping.
      if (fabs(est_ld) < 1e-2) {
        *est_residue_cost = INT_MAX / 2;
      } else {
        double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
        if (est_residue_cost_dbl < 0) {
          *est_residue_cost = 0;
        } else {
          *est_residue_cost =
              (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
        }
      }
      if (*est_residue_cost <= 0) {
        *est_residue_cost = 0;
        *est_dist = sse;
      }
    }
    return 1;
  }
  return 0;
}

static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                            const TileDataEnc *tile_data,
                                            MACROBLOCK *x,
                                            const RD_STATS *const rd_stats,
                                            BLOCK_SIZE plane_bsize) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
      (tile_data == NULL ||
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
    return;
  (void)tile_data;
  // Generate small sample to restrict output size.
  static unsigned int seed = 95014;

  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
      1)
    return;

  const char output_file[] = "pu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int diff_stride = block_size_wide[plane_bsize];
  int bw, bh;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
                     &bh);
  const int num_samples = bw * bh;
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int shift = (xd->bd - 8);

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;
  const double rdcost_norm =
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;

  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src = p->src.buf;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int16_t *const src_diff = p->src_diff;

  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm =
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  if (shift) {
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rdcost_norm =
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS >= 2
#endif  // CONFIG_COLLECT_RD_STATS

static inline void inverse_transform_block_facade(MACROBLOCK *const x,
                                                  int plane, int block,
                                                  int blk_row, int blk_col,
                                                  int eob, int reduced_tx_set) {}

static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {}

static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {}

// Compute the pixel domain distortion from src and dst on all visible 4x4s in
// the
// transform block.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {}

static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {}

// pruning thresholds for prune_txk_type and prune_txk_type_separ
static const int prune_factors[5] =;  // scale 1000
static const int mul_factors[5] =;       // scale 100

// R-D costs are sorted in ascending order.
static inline void sort_rd(int64_t rds[], int txk[], int len) {}

static inline int64_t av1_block_error_qm(const tran_low_t *coeff,
                                         const tran_low_t *dqcoeff,
                                         intptr_t block_size,
                                         const qm_val_t *qmatrix,
                                         const int16_t *scan, int64_t *ssz) {}

static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
                                        TX_SIZE tx_size,
                                        const qm_val_t *qmatrix,
                                        const int16_t *scan, int64_t *out_dist,
                                        int64_t *out_sse) {}

static uint16_t prune_txk_type_separ(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
    int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx,
    int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) {}

static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, TX_SIZE tx_size, int blk_row,
                               int blk_col, BLOCK_SIZE plane_bsize,
                               int *txk_map, uint16_t allowed_tx_mask,
                               int prune_factor, const TXB_CTX *const txb_ctx,
                               int reduced_tx_set_used) {}

// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average
static const float *prune_2D_adaptive_thresholds[] =;

static inline float get_adaptive_thresholds(
    TX_SIZE tx_size, TxSetType tx_set_type,
    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {}

static inline void get_energy_distribution_finer(const int16_t *diff,
                                                 int stride, int bw, int bh,
                                                 float *hordist,
                                                 float *verdist) {}

static inline bool check_bit_mask(uint16_t mask, int val) {}

static inline void set_bit_mask(uint16_t *mask, int val) {}

static inline void unset_bit_mask(uint16_t *mask, int val) {}

static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                        int blk_row, int blk_col, TxSetType tx_set_type,
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
                        uint16_t *allowed_tx_mask) {}

static float get_dev(float mean, double x2_sum, int num) {}

// Writes the features required by the ML model to predict tx split based on
// mean and standard deviation values of the block and sub-blocks.
// Returns the number of elements written to the output array which is at most
// 12 currently. Hence 'features' buffer should be able to accommodate at least
// 12 elements.
static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
                                        int bh, float *features) {}

static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {}

static inline uint16_t get_tx_mask(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
    int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
    int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {}

#if CONFIG_RD_DEBUG
static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         int txb_coeff_cost) {
  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;
}
#endif

static inline int cost_coeffs(MACROBLOCK *x, int plane, int block,
                              TX_SIZE tx_size, const TX_TYPE tx_type,
                              const TXB_CTX *const txb_ctx,
                              int reduced_tx_set_used) {}

static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
                                          QUANT_PARAM *quant_param, int plane,
                                          int block, TX_SIZE tx_size,
                                          int quant_b_adapt, int qstep,
                                          unsigned int coeff_opt_satd_threshold,
                                          int skip_trellis, int dc_only_blk) {}

// Predict DC only blocks if the residual variance is below a qstep based
// threshold.For such blocks, transform type search is bypassed.
static inline void predict_dc_only_block(
    MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
    int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
    int *dc_only_blk) {}

// Search for the best transform type for a given transform block.
// This function can be used for both inter and intra, both luma and chroma.
static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                           int block, int blk_row, int blk_col,
                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                           const TXB_CTX *const txb_ctx,
                           FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
                           int64_t ref_best_rd, RD_STATS *best_rd_stats) {}

// Pick transform type for a luma transform block of tx_size. Note this function
// is used only for inter-predicted blocks.
static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
                              TX_SIZE tx_size, int blk_row, int blk_col,
                              int block, int plane_bsize, TXB_CTX *txb_ctx,
                              RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode,
                              int64_t ref_rdcost) {}

static inline void try_tx_block_no_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {}

static inline void try_tx_block_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {}

static float get_var(float mean, double x2_sum, int num) {}

static inline void get_blk_var_dev(const int16_t *data, int stride, int bw,
                                   int bh, float *dev_of_mean,
                                   float *var_of_vars) {}

static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
                                    int blk_row, int blk_col, TX_SIZE tx_size,
                                    int *try_no_split, int *try_split,
                                    int pruning_level) {}

// Search for the best transform partition(recursive)/type for a given
// inter-predicted luma block. The obtained transform selection will be saved
// in xd->mi[0], the corresponding RD stats will be saved in rd_stats.
static inline void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) {}

static inline void choose_largest_tx_size(const AV1_COMP *const cpi,
                                          MACROBLOCK *x, RD_STATS *rd_stats,
                                          int64_t ref_best_rd, BLOCK_SIZE bs) {}

static inline void choose_smallest_tx_size(const AV1_COMP *const cpi,
                                           MACROBLOCK *x, RD_STATS *rd_stats,
                                           int64_t ref_best_rd, BLOCK_SIZE bs) {}

#if !CONFIG_REALTIME_ONLY
static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
                                            int blk_col, BLOCK_SIZE bsize,
                                            TX_SIZE tx_size) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];

  // Disable the pruning logic using NN model for the following cases:
  // 1) Lossless coding as only 4x4 transform is evaluated in this case
  // 2) When transform and current block sizes do not match as the features are
  // obtained over the current block
  // 3) When operating bit-depth is not 8-bit as the input features are not
  // scaled according to bit-depth.
  if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
      xd->bd != 8)
    return;

  // Currently NN model based pruning is supported only when largest transform
  // size is 8x8
  if (tx_size != TX_8X8) return;

  // Neural network model is a sequential neural net and was trained using SGD
  // optimizer. The model can be further improved in terms of speed/quality by
  // considering the following experiments:
  // 1) Generate ML model by training with balanced data for different learning
  // rates and optimizers.
  // 2) Experiment with ML model by adding features related to the statistics of
  // top and left pixels to capture the accuracy of reconstructed neighbouring
  // pixels for 4x4 blocks numbered 1, 2, 3 in 8x8 block, source variance of 4x4
  // sub-blocks, etc.
  // 3) Generate ML models for transform blocks other than 8x8.
  const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
  const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;

  float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
  const int diff_stride = block_size_wide[bsize];

  const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
                        MI_SIZE * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);

  features[feature_idx++] = log1pf((float)x->source_variance);

  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
  const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f);
  features[feature_idx++] = log_dc_q_square;
  assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
  for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
    features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
                  av1_intra_tx_split_8x8_std[i];
  }

  float score;
  av1_nn_predict(features, nn_config, 1, &score);

  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
  if (score <= intra_tx_prune_thresh[0])
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
  else if (score > intra_tx_prune_thresh[1])
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
}
#endif  // !CONFIG_REALTIME_ONLY

/*!\brief Transform type search for luma macroblock with fixed transform size.
 *
 * \ingroup transform_search
 * Search for the best transform type and return the transform coefficients RD
 * cost of current luma macroblock with the given uniform transform size.
 *
 * \param[in]    x              Pointer to structure holding the data for the
                                current encoding macroblock
 * \param[in]    cpi            Top-level encoder structure
 * \param[in]    rd_stats       Pointer to struct to keep track of the RD stats
 * \param[in]    ref_best_rd    Best RD cost seen for this block so far
 * \param[in]    bs             Size of the current macroblock
 * \param[in]    tx_size        The given transform size
 * \param[in]    ftxs_mode      Transform search mode specifying desired speed
                                and quality tradeoff
 * \param[in]    skip_trellis   Binary flag indicating if trellis optimization
                                should be skipped
 * \return       An int64_t value that is the best RD cost found.
 */
static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                                RD_STATS *rd_stats, int64_t ref_best_rd,
                                BLOCK_SIZE bs, TX_SIZE tx_size,
                                FAST_TX_SEARCH_MODE ftxs_mode,
                                int skip_trellis) {}

// Search for the best uniform transform size and type for current coding block.
static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
                                               MACROBLOCK *x,
                                               RD_STATS *rd_stats,
                                               int64_t ref_best_rd,
                                               BLOCK_SIZE bs) {}

// Search for the best transform type for the given transform block in the
// given plane/channel, and calculate the corresponding RD cost.
static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
                                 BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                                 void *arg) {}

int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                              RD_STATS *rd_stats, int64_t ref_best_rd,
                              BLOCK_SIZE bs, TX_SIZE tx_size) {}

// Search for the best transform type for a luma inter-predicted block, given
// the transform block partitions.
// This function is used only when some speed features are enabled.
static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                                int blk_col, int block, TX_SIZE tx_size,
                                BLOCK_SIZE plane_bsize, int depth,
                                ENTROPY_CONTEXT *above_ctx,
                                ENTROPY_CONTEXT *left_ctx,
                                TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
                                int64_t ref_best_rd, RD_STATS *rd_stats,
                                FAST_TX_SEARCH_MODE ftxs_mode) {}

// search for tx type with tx sizes already decided for a inter-predicted luma
// partition block. It's used only when some speed features are enabled.
// Return value 0: early termination triggered, no valid rd cost available;
//              1: rd cost values are valid.
static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {}

// Search for the best transform size and type for current inter-predicted
// luma block with recursive transform block partitioning. The obtained
// transform selection will be saved in xd->mi[0], the corresponding RD stats
// will be saved in rd_stats. The returned value is the corresponding RD cost.
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd) {}

// Return 1 to terminate transform search early. The decision is made based on
// the comparison with the reference RD cost and the model-estimated RD cost.
static inline int model_based_tx_search_prune(const AV1_COMP *cpi,
                                              MACROBLOCK *x, BLOCK_SIZE bsize,
                                              int64_t ref_best_rd) {}

void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                         int64_t ref_best_rd) {}

void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bs,
                                       int64_t ref_best_rd) {}

int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
                  BLOCK_SIZE bsize, int64_t ref_best_rd) {}

void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
                          RD_STATS *rd_stats, int64_t ref_best_rd,
                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
                          int skip_trellis) {}

int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {}