chromium/third_party/blink/web_tests/external/wpt/webnn/validation_tests/argMinMax.https.any.js

// META: title=validation tests for WebNN API argMin/Max operations
// META: global=window,dedicatedworker
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

const kArgMinMaxOperators = [
  'argMin',
  'argMax',
];

const label = 'arg_min_max_1_!';

const tests = [
  {
    name: '[argMin/Max] Test with default options.',
    input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
    axis: 0,
    output: {dimensions: [2, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with axes=1.',
    input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
    axis: 1,
    output: {dimensions: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int32',
    input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int32',
    },
    output: {dimensions: [1, 3, 4]}
  },
  {
    name: '[argMin/Max] Test with outputDataType=int64',
    input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
    axis: 1,
    options: {
      outputDataType: 'int64',
    },
    output: {dimensions: [1, 3, 4]}
  },
  {
    name:
        '[argMin/Max] Throw if the value in axis is greater than or equal to input rank.',
    input: {dataType: 'float32', dimensions: [1, 2, 3, 4]},
    axis: 4,
    options: {
      label: label,
    },
  },
  {
    name: '[argMin/Max] Throw if input is a scalar and axis=0.',
    input: {dataType: 'float32', dimensions: []},
    axis: 0,
    options: {
      label: label,
    },
  },
];

function runTests(operatorName, tests) {
  tests.forEach(test => {
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const input = builder.input(
        'input',
        { dataType: test.input.dataType, dimensions: test.input.dimensions });
      const axis = test.axis;
      if (test.options && test.options.outputDataType !== undefined) {
        if (context.opSupportLimits()[operatorName].output.dataTypes.includes(
          test.options.outputDataType)) {
          const output = builder[operatorName](input, axis, test.options);
          assert_equals(output.dataType(), test.options.outputDataType);
          assert_array_equals(output.shape(), test.output.dimensions);
        } else {
          assert_throws_js(
            TypeError, () => builder[operatorName](input, axis, test.options));
        }
        return;
      }
      if (test.output) {
        const output = builder[operatorName](input, axis, test.options);
        assert_equals(output.dataType(), 'int32');
        assert_array_equals(output.shape(), test.output.dimensions);
      } else {
        const regrexp = /\[arg_min_max_1_\!\]/;
        assert_throws_with_label(
            () => builder[operatorName](input, axis, test.options), regrexp);
      }
    }, test.name.replace('[argMin/Max]', `[${operatorName}]`));
  });
}

kArgMinMaxOperators.forEach((operatorName) => {
  validateInputFromAnotherBuilder(operatorName);
  runTests(operatorName, tests);
});