chromium/third_party/xnnpack/src/src/subgraph/even-split.c

// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <stdint.h>  // For size_t.
#include <string.h>

#include "xnnpack.h"
#include "xnnpack/allocation-type.h"
#include "xnnpack/common.h"
#include "xnnpack/log.h"
#include "xnnpack/node-type.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/subgraph-validation.h"
#include "xnnpack/subgraph.h"
#include "pthreadpool.h"

// Per-output helper for the create stage. It receives one output Value ID, the
// node being lowered, the operator data to populate, and the output's position
// `index` among the node's outputs.
// NOTE(review): presumably creates the operator that produces output `index`
// inside `opdata` — confirm. The body is empty in this copy of the file; if a
// caller uses the result of a non-void function that falls off its closing
// brace, the behavior is undefined in C.
static enum xnn_status create_even_split_operator_helper(
    const uint32_t output_id,
    const struct xnn_node* node,
    struct xnn_operator_data* opdata,
    size_t index)
{}

// Shared create stage for an even-split node with `num_splits` outputs.
// NOTE(review): presumably invoked by the 2/3/4-output create callbacks with a
// fixed `num_splits` — confirm. The body is empty in this copy of the file,
// and the function is declared to return enum xnn_status.
static enum xnn_status create_even_split_n_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  struct xnn_code_cache* code_cache,
  size_t num_splits,
  xnn_weights_cache_t weights_cache)
{}

// Create callback that xnn_define_even_split_n installs as node->create for
// even-split nodes with two outputs.
// NOTE(review): presumably delegates to create_even_split_n_operator with
// num_splits = 2 — confirm. The body is empty in this copy of the file.
static enum xnn_status create_even_split2_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  struct xnn_code_cache* code_cache,
  xnn_weights_cache_t weights_cache)
{}

// Create callback that xnn_define_even_split_n installs as node->create for
// even-split nodes with three outputs.
// NOTE(review): presumably delegates to create_even_split_n_operator with
// num_splits = 3 — confirm. The body is empty in this copy of the file.
static enum xnn_status create_even_split3_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  struct xnn_code_cache* code_cache,
  xnn_weights_cache_t weights_cache)
{}

// Create callback that xnn_define_even_split_n installs as node->create for
// even-split nodes with four outputs.
// NOTE(review): presumably delegates to create_even_split_n_operator with
// num_splits = 4 — confirm. The body is empty in this copy of the file.
static enum xnn_status create_even_split4_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  struct xnn_code_cache* code_cache,
  xnn_weights_cache_t weights_cache)
{}

// Per-output helper for the reshape stage. It receives the Value table, the
// output's position `index`, the total number of splits, and the split `axis`.
// NOTE(review): presumably reshapes the operator for output `index` so that it
// covers 1/num_splits of the input along `axis` — confirm. The body is empty
// in this copy of the file.
static enum xnn_status reshape_even_split_operator_helper(
  const struct xnn_value* values,
  const uint32_t num_values,
  struct xnn_operator_data* opdata,
  size_t index,
  size_t num_splits,
  int32_t axis,
  pthreadpool_t threadpool)
{}

// Shared reshape stage for an even-split node with `num_splits` outputs.
// NOTE(review): presumably invoked by the 2/3/4-output reshape callbacks with
// a fixed `num_splits` — confirm. The body is empty in this copy of the file.
static enum xnn_status reshape_even_split_n_operator(
  struct xnn_operator_data* opdata,
  struct xnn_value* values,
  size_t num_values,
  size_t num_splits,
  pthreadpool_t threadpool)
{}

// Reshape callback that xnn_define_even_split_n installs as node->reshape for
// even-split nodes with two outputs.
// NOTE(review): presumably forwards to reshape_even_split_n_operator with
// num_splits = 2 — confirm. The body is empty in this copy of the file.
static enum xnn_status reshape_even_split2_operator(
  struct xnn_operator_data* opdata,
  struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{}

// Reshape callback that xnn_define_even_split_n installs as node->reshape for
// even-split nodes with three outputs.
// NOTE(review): presumably forwards to reshape_even_split_n_operator with
// num_splits = 3 — confirm. The body is empty in this copy of the file.
static enum xnn_status reshape_even_split3_operator(
  struct xnn_operator_data* opdata,
  struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{}

// Reshape callback that xnn_define_even_split_n installs as node->reshape for
// even-split nodes with four outputs.
// NOTE(review): presumably forwards to reshape_even_split_n_operator with
// num_splits = 4 — confirm. The body is empty in this copy of the file.
static enum xnn_status reshape_even_split4_operator(
  struct xnn_operator_data* opdata,
  struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{}

// Per-output helper for the setup stage. It receives the Value table, the
// output's position `index`, and a pointer to the input tensor's data.
// NOTE(review): presumably binds the input and output buffers for output
// `index` — confirm. The body is empty in this copy of the file.
static enum xnn_status setup_even_split_operator_helper(
  const struct xnn_value* values,
  const uint32_t num_values,
  const struct xnn_operator_data* opdata,
  size_t index,
  const void* input_data,
  pthreadpool_t threadpool)
{}

// Shared setup stage for an even-split node with `num_splits` outputs.
// NOTE(review): presumably invoked by the 2/3/4-output setup callbacks with a
// fixed `num_splits` — confirm. The body is empty in this copy of the file.
static enum xnn_status setup_even_split_n_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_value* values,
  size_t num_values,
  size_t num_splits,
  pthreadpool_t threadpool)
{}

// Setup callback that xnn_define_even_split_n installs as node->setup for
// even-split nodes with two outputs.
// NOTE(review): presumably forwards to setup_even_split_n_operator with
// num_splits = 2 — confirm. The body is empty in this copy of the file.
static enum xnn_status setup_even_split2_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{}

// Setup callback that xnn_define_even_split_n installs as node->setup for
// even-split nodes with three outputs.
// NOTE(review): presumably forwards to setup_even_split_n_operator with
// num_splits = 3 — confirm. The body is empty in this copy of the file.
static enum xnn_status setup_even_split3_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{}

// Setup callback that xnn_define_even_split_n installs as node->setup for
// even-split nodes with four outputs.
// NOTE(review): presumably forwards to setup_even_split_n_operator with
// num_splits = 4 — confirm. The body is empty in this copy of the file.
static enum xnn_status setup_even_split4_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{}

// Validates one output Value of an even-split node.
// xnn_define_even_split_n calls this once per output and passes the output's
// ordinal ("first" ... "fourth") as `nth`. Any status other than
// xnn_status_success aborts the node definition.
// NOTE(review): presumably checks the output ID and the output's shape against
// the input along `split_dim`, and uses `nth` in error messages — confirm. The
// body is empty in this copy of the file.
// NOTE(review): this function has external linkage and a generic name. If no
// header declares it, consider making it static to avoid symbol clashes.
enum xnn_status check_output_value(
  xnn_subgraph_t subgraph,
  int32_t split_dim,
  uint32_t input_id,
  uint32_t output_id,
  const char* nth,
  enum xnn_node_type node_type)
{}

// Validates one output of an even-split node whose input is quantized.
// xnn_define_even_split_n calls this per output only when the input's compute
// type is qs8 or qu8, and passes the output's ordinal as `nth`.
// NOTE(review): presumably verifies that the output's quantization parameters
// match the input's — confirm. The body is empty in this copy of the file.
// NOTE(review): this function has external linkage and a generic name. If no
// header declares it, consider making it static.
enum xnn_status check_output_compute_type(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  const char* nth,
  enum xnn_node_type node_type)
{}

// Shared implementation for even-split node definitions. It validates the
// input Value and every output Value, then appends an even-split node with
// `num_outputs` outputs to `subgraph`.
//
// Parameters:
//   node_type   - node type stored on the new node and used in diagnostics.
//   subgraph    - subgraph that receives the new node.
//   split_dim   - stored in node->params.even_split.axis and passed to
//                 check_output_value for per-output validation.
//   input_id    - Value ID of the tensor to split.
//   num_outputs - number of outputs; only 2, 3 and 4 are supported.
//   output_ids  - array of `num_outputs` output Value IDs.
//   flags       - copied verbatim into node->flags.
//
// Returns xnn_status_success on success. On failure it returns the status of
// the first failing validation helper, xnn_status_invalid_parameter for an
// unsupported output count or input datatype, or xnn_status_out_of_memory if
// the node cannot be allocated. No node is added on failure.
enum xnn_status xnn_define_even_split_n(
  enum xnn_node_type node_type,
  xnn_subgraph_t subgraph,
  int32_t split_dim,
  uint32_t input_id,
  size_t num_outputs,
  const uint32_t* output_ids,
  uint32_t flags)
{
  // Ordinal names that the per-output validation helpers use to identify each output.
  static const char* const output_ordinals[] = {"first", "second", "third", "fourth"};
  const size_t max_outputs = sizeof(output_ordinals) / sizeof(output_ordinals[0]);

  assert(num_outputs > 1);
  assert(num_outputs < 5);

  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) {
    return status;
  }

  // The asserts above disappear in release builds. Reject out-of-range counts
  // explicitly so that output_ids, output_ordinals and node->outputs are never
  // indexed beyond the cases handled below.
  if (num_outputs < 2 || num_outputs > max_outputs) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported number of outputs %zu (must be 2, 3, or 4)",
      xnn_node_type_to_string(node_type), input_id, num_outputs);
    return xnn_status_invalid_parameter;
  }

  if ((status = xnn_subgraph_check_input_node_id(node_type, input_id, subgraph->num_values)) != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  for (size_t i = 0; i < num_outputs; i++) {
    status = check_output_value(subgraph, split_dim, input_id, output_ids[i], output_ordinals[i], node_type);
    if (status != xnn_status_success) {
      return status;
    }
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (input_value->datatype) {
    case xnn_datatype_fp16:
      compute_type = xnn_compute_type_fp16;
      break;
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), input_id, xnn_datatype_to_string(input_value->datatype),
        input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Quantized inputs need an additional per-output compatibility check. Propagate
  // the first failure instead of discarding the helper's status.
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    for (size_t i = 0; i < num_outputs; i++) {
      status = check_output_compute_type(subgraph, input_id, output_ids[i], output_ordinals[i], node_type);
      if (status != xnn_status_success) {
        return status;
      }
    }
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.even_split.axis = split_dim;
  node->type = node_type;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = num_outputs;
  for (size_t i = 0; i < num_outputs; i++) {
    node->outputs[i] = output_ids[i];
  }

  // Each output count has its own create/reshape/setup callbacks.
  switch (num_outputs) {
    case 2:
      node->create = create_even_split2_operator;
      node->reshape = reshape_even_split2_operator;
      node->setup = setup_even_split2_operator;
      break;
    case 3:
      node->create = create_even_split3_operator;
      node->reshape = reshape_even_split3_operator;
      node->setup = setup_even_split3_operator;
      break;
    case 4:
      node->create = create_even_split4_operator;
      node->reshape = reshape_even_split4_operator;
      node->setup = setup_even_split4_operator;
      break;
    default:
      XNN_UNREACHABLE;
  }
  node->flags = flags;

  return xnn_status_success;
}

// Public entry point for defining an even-split node with two outputs
// (output1_id, output2_id) from input_id.
// NOTE(review): presumably forwards to xnn_define_even_split_n with the
// two-output even-split node type and num_outputs = 2 — confirm. The body is
// empty in this copy of the file, and the function is declared to return
// enum xnn_status.
enum xnn_status xnn_define_even_split2(
  xnn_subgraph_t subgraph,
  int32_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t flags)
{}

// Public entry point for defining an even-split node with three outputs
// (output1_id .. output3_id) from input_id.
// NOTE(review): presumably forwards to xnn_define_even_split_n with the
// three-output even-split node type and num_outputs = 3 — confirm. The body is
// empty in this copy of the file.
enum xnn_status xnn_define_even_split3(
  xnn_subgraph_t subgraph,
  int32_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t output3_id,
  uint32_t flags)
{}

// Public entry point for defining an even-split node with four outputs
// (output1_id .. output4_id) from input_id.
// NOTE(review): presumably forwards to xnn_define_even_split_n with the
// four-output even-split node type and num_outputs = 4 — confirm. The body is
// empty in this copy of the file.
enum xnn_status xnn_define_even_split4(
  xnn_subgraph_t subgraph,
  int32_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t output3_id,
  uint32_t output4_id,
  uint32_t flags)
{}