'use strict';

// Every MLOperandDataType enum value defined by WebNN:
// https://www.w3.org/TR/webnn/#enumdef-mloperanddatatype
const allWebNNOperandDataTypes = [
  'float32', 'float16',
  'int32', 'uint32',
  'int64', 'uint64',
  'int8', 'uint8',
];

// Largest value of the WebIDL `unsigned long` type, whose range is
// [0, 4294967295] (i.e. 2 ** 32 - 1):
// https://webidl.spec.whatwg.org/#idl-unsigned-long
const kMaxUnsignedLong = 0xFFFFFFFF;

// Subsets of allWebNNOperandDataTypes grouped by numeric kind.
const floatingPointTypes = ['float32', 'float16'];
const signedIntegerTypes = ['int32', 'int64', 'int8'];

// WebIDL type name accepted by generateOutOfRangeValuesArray().
const unsignedLongType = 'unsigned long';
// Sample operand shapes of increasing rank, from a scalar (0-D) up to 5-D.
const dimensions0D = [];
const dimensions1D = [2];
const dimensions2D = [2, 3];
const dimensions3D = [2, 3, 4];
const dimensions4D = [2, 3, 4, 5];
const dimensions5D = [2, 3, 4, 5, 6];

// Offsets that decrease (-1) or increase (+1) a value by one.
const adjustOffsetsArray = [-1, 1];

// Every sample shape above, ordered by rank.
// TODO: add shapes with more than 5 dimensions.
const allWebNNDimensionsArray = [
  dimensions0D, dimensions1D, dimensions2D,
  dimensions3D, dimensions4D, dimensions5D,
];
// Values that are not integral JavaScript numbers. validateOptionsAxes() uses
// them as options.axes elements and expects a TypeError for each.
const notUnsignedLongAxisArray = [
  'abc',                   // string
  100n,                    // BigInt
  {value: 1},              // plain object
  [0, 1],                  // array object
  new Date('2024-01-01'),  // Date object
];
/**
 * Returns the rank (number of dimensions) of an operand shape.
 * @param {number[]} inputDimensions - Operand shape.
 * @returns {number} The shape's length.
 */
function getRank(inputDimensions) {
  const {length: rank} = inputDimensions;
  return rank;
}
/**
 * Returns every valid axis index for a shape: [0, 1, ..., rank - 1].
 * @param {number[]} inputDimensions - Operand shape.
 * @returns {number[]} Empty for a 0-D shape.
 */
function getAxisArray(inputDimensions) {
  const axes = [];
  for (let axis = 0; axis < inputDimensions.length; ++axis) {
    axes.push(axis);
  }
  return axes;
}
/**
 * Builds `axes` sequences that repeat a single axis, for checking that
 * operators reject duplicated axes.
 *
 * For every axis `a` of the input this yields `[a, a]` and, when the rank is
 * greater than 2, `[a, a, a]` right after it. Inputs of rank 0 or 1 yield [].
 * Example for a 3-D input:
 *   [[0, 0], [0, 0, 0], [1, 1], [1, 1, 1], [2, 2], [2, 2, 2]]
 *
 * TODO: also generate sequences in which only some values repeat. For a 3-D
 * input that means adding [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 2, 0],
 * [1, 0, 0], [2, 0, 0], [1, 1, 0], [1, 1, 2], [1, 0, 1], [1, 2, 1],
 * [0, 1, 1], [2, 1, 1], [2, 2, 0], [2, 2, 1], [2, 0, 2], [2, 1, 2],
 * [0, 2, 2] and [1, 2, 2].
 *
 * @param {number[]} inputDimensions - Shape of the input operand.
 * @returns {number[][]} Axes sequences that contain duplicated values.
 */
function getAxesArrayContainSameValues(inputDimensions) {
  const rank = inputDimensions.length;
  if (rank < 2) {
    return [];
  }
  // Tripled axes only fit when there are more than two axes.
  const repeatCounts = rank > 2 ? [2, 3] : [2];
  return Array.from({length: rank}, (_, axis) => axis)
      .flatMap(axis => repeatCounts.map(count => new Array(count).fill(axis)));
}
/**
 * Generates shapes that cannot be broadcast together with `dimensions`.
 *
 * Two families are produced, in this order:
 *  1. Same-rank shapes where one non-1 dimension is changed by -1 or +1
 *     (results equal to 1 are skipped). This family is omitted when every
 *     dimension except the last is 1, because those shapes would duplicate
 *     family 2.
 *  2. Shapes of rank 1 through rank + 1 made of leading 1s followed by the
 *     last dimension changed by -1 (only when it stays above 1) or +1. This
 *     family is omitted when the last dimension is 1.
 *
 * Example for [2, 3, 4]:
 *   [3, 3, 4], [2, 2, 4], [2, 4, 4], [2, 3, 3], [2, 3, 5],
 *   [3], [5], [1, 3], [1, 5], [1, 1, 3], [1, 1, 5], [1, 1, 1, 3], [1, 1, 1, 5]
 *
 * @param {number[]} dimensions - Reference shape.
 * @returns {number[][]} Shapes not broadcast-compatible with `dimensions`.
 * @throws {Error} If every dimension is 1 (including the 0-D shape []).
 */
function generateUnbroadcastableDimensionsArray(dimensions) {
  const isOne = size => size === 1;
  if (dimensions.every(isOne)) {
    throw new Error(`[${dimensions}] always can be broadcasted`);
  }

  const rank = dimensions.length;
  const unbroadcastable = [];

  // Family 1: same rank, one non-1 dimension nudged by one.
  const onlyLastMayDiffer = dimensions.slice(0, rank - 1).every(isOne);
  if (!onlyLastMayDiffer) {
    for (const [axis, size] of dimensions.entries()) {
      if (isOne(size)) {
        continue;
      }
      for (const delta of [-1, 1]) {
        const nudged = size + delta;
        if (isOne(nudged)) {
          continue;
        }
        const shape = dimensions.slice();
        shape[axis] = nudged;
        unbroadcastable.push(shape);
      }
    }
  }

  // Family 2: leading 1s followed by a nudged last dimension.
  const last = dimensions[rank - 1];
  if (!isOne(last)) {
    const nudgedLastSizes = last > 2 ? [last - 1, last + 1] : [last + 1];
    for (let leadingOnes = 0; leadingOnes <= rank; leadingOnes++) {
      for (const size of nudgedLastSizes) {
        unbroadcastable.push([...Array(leadingOnes).fill(1), size]);
      }
    }
  }
  return unbroadcastable;
}
/**
 * Returns the two integers just outside the valid range of a WebIDL integer
 * type: one below its minimum and one above its maximum.
 * @param {string} type - WebIDL type name; only 'unsigned long' is supported.
 * @returns {number[]} [minimum - 1, maximum + 1].
 * @throws {Error} For any other type name.
 */
function generateOutOfRangeValuesArray(type) {
  const rangeByType = new Map([
    ['unsigned long', [0, kMaxUnsignedLong]],
  ]);
  if (!rangeByType.has(type)) {
    throw new Error(`Unsupport ${type}`);
  }
  const [minimum, maximum] = rangeByType.get(type);
  return [minimum - 1, maximum + 1];
}
// Counters appended to operand names so that generated inputs get distinct
// names across tests.
let inputIndex = 0;
let inputAIndex = 0;
let inputBIndex = 0;

// MLContext used by the validation helpers; assigned in promise_setup below.
let context;

test(
    () => assert_not_equals(
        navigator.ml, undefined, "ml property is defined on navigator"));

promise_setup(async () => {
  if (navigator.ml !== undefined) {
    // The page's query string, minus its leading '?', names the device type.
    const deviceType = location.search.substring(1);
    context = await navigator.ml.createContext({deviceType});
  }
}, {explicit_timeout: true});
/**
 * Asserts that calling `func` throws a TypeError whose message matches
 * `regexp`; callers use this to check that the operator label is reported.
 * @param {Function} func - Callback expected to throw; invoked with this
 *     function's own `this`.
 * @param {RegExp} regexp - Pattern the error message must match.
 */
function assert_throws_with_label(func, regexp) {
  let threw = false;
  let error;
  try {
    func.call(this);
  } catch (e) {
    threw = true;
    error = e;
  }
  // Assert outside the try block: testharness assertions report failure by
  // throwing, so a failure raised inside it would be caught by the catch
  // clause and misreported as a wrong error name.
  assert_true(threw, 'Graph builder method unexpectedly succeeded');
  assert_equals(
      error.name, 'TypeError', 'the thrown error should be a TypeError');
  assert_not_equals(
      error.message.match(regexp), null,
      `the error message should match ${regexp}`);
}
/**
 * Registers a test checking that the binary builder method `operationName`
 * throws a TypeError whose message contains `[label]` when its two inputs
 * have shapes that cannot be broadcast together, in either argument order.
 * Does nothing when `navigator.ml` is undefined.
 * @param {String} operationName - MLGraphBuilder method taking
 *     (inputA, inputB, options).
 * @param {String} label - Value passed as `options.label`.
 */
function validateTwoInputsBroadcastable(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const inputDataTypes = context.opSupportLimits().input.dataTypes;
    const options = {label};
    const regexp = new RegExp('\\[' + label + '\\]');
    for (let dataType of allWebNNOperandDataTypes) {
      if (!inputDataTypes.includes(dataType)) {
        // Data types the context cannot accept as input must be rejected by
        // builder.input(). `dimensions:` is spelled out so the descriptor
        // really carries the 1-D shape.
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`,
                {dataType, dimensions: dimensions1D}));
        continue;
      }
      for (let dimensions of allWebNNDimensionsArray) {
        // A 0-D shape broadcasts with anything, and
        // generateUnbroadcastableDimensionsArray() throws for it.
        if (dimensions.length === 0) {
          continue;
        }
        const inputA =
            builder.input(`inputA${++inputAIndex}`, {dataType, dimensions});
        for (let unbroadcastableDimensions of
                 generateUnbroadcastableDimensionsArray(dimensions)) {
          const inputB = builder.input(
              `inputB${++inputBIndex}`,
              {dataType, dimensions: unbroadcastableDimensions});
          assert_equals(typeof builder[operationName], 'function');
          assert_throws_with_label(
              () => builder[operationName](inputA, inputB, options), regexp);
          assert_throws_with_label(
              () => builder[operationName](inputB, inputA, options), regexp);
        }
      }
    }
  }, `[${operationName}] TypeError is expected if two inputs aren't broadcastable`);
}
/**
 * Registers, for each given operation, a test checking that the binary
 * builder method throws a TypeError whose message contains `[label]` when its
 * two inputs have different data types. Does nothing when `navigator.ml` is
 * undefined.
 * @param {(String[]|String)} operationName - An operation name or an array of
 *     operation names.
 * @param {String} label - Value passed as `options.label`.
 * @throws {Error} If operationName is neither a string nor an array.
 */
function validateTwoInputsOfSameDataType(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  let operationNameArray;
  if (typeof operationName === 'string') {
    operationNameArray = [operationName];
  } else if (Array.isArray(operationName)) {
    operationNameArray = operationName;
  } else {
    throw new Error(`${operationName} should be an operation name string or an operation name string array`);
  }
  for (let subOperationName of operationNameArray) {
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const inputDataTypes = context.opSupportLimits().input.dataTypes;
      const options = {label};
      const regexp = new RegExp('\\[' + label + '\\]');
      for (let dataType of allWebNNOperandDataTypes) {
        if (!inputDataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`,
                  {dataType, dimensions: dimensions1D}));
          continue;
        }
        for (let dimensions of allWebNNDimensionsArray) {
          const inputA =
              builder.input(`inputA${++inputAIndex}`, {dataType, dimensions});
          for (let dataTypeB of allWebNNOperandDataTypes) {
            if (!inputDataTypes.includes(dataTypeB)) {
              // Pass the data type under the `dataType` key: without it the
              // descriptor lacks its required member and builder.input()
              // would throw no matter whether dataTypeB is supported.
              assert_throws_js(
                  TypeError,
                  () => builder.input(
                      `inputB${++inputBIndex}`,
                      {dataType: dataTypeB, dimensions: dimensions1D}));
              continue;
            }
            if (dataType === dataTypeB) {
              continue;
            }
            const inputB = builder.input(
                `inputB${++inputBIndex}`, {dataType: dataTypeB, dimensions});
            assert_equals(typeof builder[subOperationName], 'function');
            assert_throws_with_label(
                () => builder[subOperationName](inputA, inputB, options),
                regexp);
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if two inputs aren't of same data type`);
  }
}
/**
 * Validate options.axes for the given operation(s), e.g. argMin/Max,
 * layerNormalization and the reduction operations.
 *
 * For each operation three tests are registered. Each expects a TypeError
 * when:
 *  - an element of options.axes is not an unsigned long integer (an
 *    out-of-range number or a non-number value),
 *  - an element of options.axes is greater than or equal to the input rank,
 *  - options.axes contains the same axis more than once.
 * Does nothing when `navigator.ml` is undefined.
 *
 * @param {(String[]|String)} operationName - An operation name array or an
 *     operation name
 * @throws {Error} If operationName is neither a string nor an array.
 */
function validateOptionsAxes(operationName) {
  if (navigator.ml === undefined) {
    return;
  }
  let operationNameArray;
  if (typeof operationName === 'string') {
    operationNameArray = [operationName];
  } else if (Array.isArray(operationName)) {
    operationNameArray = operationName;
  } else {
    throw new Error(`${operationName} should be an operation name string or an operation name string array`);
  }
  const invalidAxisArray = generateOutOfRangeValuesArray(unsignedLongType);
  for (let subOperationName of operationNameArray) {
    // TypeError is expected if any of options.axes elements is not an
    // unsigned long integer.
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const inputDataTypes = context.opSupportLimits().input.dataTypes;
      for (let dataType of allWebNNOperandDataTypes) {
        if (!inputDataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`,
                  {dataType, dimensions: dimensions1D}));
          continue;
        }
        for (let dimensions of allWebNNDimensionsArray) {
          const rank = getRank(dimensions);
          if (rank >= 1) {
            const input =
                builder.input(`input${++inputIndex}`, {dataType, dimensions});
            for (let invalidAxis of invalidAxisArray) {
              assert_equals(typeof builder[subOperationName], 'function');
              // axes is a sequence, so the out-of-range value goes inside
              // one: a bare number would fail the sequence conversion and
              // never exercise the unsigned long element conversion.
              assert_throws_js(
                  TypeError,
                  () => builder[subOperationName](
                      input, {axes: [invalidAxis]}));
            }
            for (let axis of notUnsignedLongAxisArray) {
              assert_false(
                  typeof axis === 'number' && Number.isInteger(axis),
                  `[${subOperationName}] any of options.axes elements should be of 'unsigned long'`);
              assert_equals(typeof builder[subOperationName], 'function');
              assert_throws_js(
                  TypeError,
                  () => builder[subOperationName](input, {axes: [axis]}));
            }
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if any of options.axes elements is not an unsigned long interger`);
    // TypeError is expected if any of options.axes elements is greater than
    // or equal to the input rank.
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const inputDataTypes = context.opSupportLimits().input.dataTypes;
      for (let dataType of allWebNNOperandDataTypes) {
        if (!inputDataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`,
                  {dataType, dimensions: dimensions1D}));
          continue;
        }
        for (let dimensions of allWebNNDimensionsArray) {
          const rank = getRank(dimensions);
          if (rank >= 1) {
            const input =
                builder.input(`input${++inputIndex}`, {dataType, dimensions});
            assert_equals(typeof builder[subOperationName], 'function');
            assert_throws_js(
                TypeError,
                () => builder[subOperationName](input, {axes: [rank]}));
            assert_throws_js(
                TypeError,
                () => builder[subOperationName](input, {axes: [rank + 1]}));
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if any of options.axes elements is greater or equal to the size of input`);
    // TypeError is expected if two or more values are same in the axes
    // sequence.
    promise_test(async t => {
      const builder = new MLGraphBuilder(context);
      const inputDataTypes = context.opSupportLimits().input.dataTypes;
      for (let dataType of allWebNNOperandDataTypes) {
        if (!inputDataTypes.includes(dataType)) {
          assert_throws_js(
              TypeError,
              () => builder.input(
                  `inputA${++inputAIndex}`,
                  {dataType, dimensions: dimensions1D}));
          continue;
        }
        for (let dimensions of allWebNNDimensionsArray) {
          const rank = getRank(dimensions);
          if (rank >= 2) {
            const input =
                builder.input(`input${++inputIndex}`, {dataType, dimensions});
            const axesArrayContainSameValues =
                getAxesArrayContainSameValues(dimensions);
            for (let axes of axesArrayContainSameValues) {
              assert_equals(typeof builder[subOperationName], 'function');
              assert_throws_js(
                  TypeError, () => builder[subOperationName](input, {axes}));
            }
          }
        }
      }
    }, `[${subOperationName}] TypeError is expected if two or more values are same in the axes sequence`);
  }
}
// TODO: remove this method once all the data type limits of the unary
// operations are specified in context.opSupportLimits().
/**
 * Validate a unary operation.
 *
 * Registers two tests. In the first, building with each of
 * `supportedDataTypes` succeeds and yields an output with the input's data
 * type and shape. In the second, building with every other data type throws
 * a TypeError whose message contains `[label]`. Does nothing when
 * `navigator.ml` is undefined.
 * @param {String} operationName - An operation name
 * @param {Array} supportedDataTypes - Test building with these data types
 *     succeeds and test building with all other data types fails
 * @param {String} label - Value passed as `options.label` in the failure test.
 */
function validateUnaryOperation(operationName, supportedDataTypes, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const inputDataTypes = context.opSupportLimits().input.dataTypes;
    for (let dataType of supportedDataTypes) {
      if (!inputDataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`,
                {dataType, dimensions: dimensions1D}));
        continue;
      }
      for (let dimensions of allWebNNDimensionsArray) {
        const input = builder.input(`input`, {dataType, dimensions});
        assert_equals(typeof builder[operationName], 'function');
        const output = builder[operationName](input);
        assert_equals(output.dataType(), dataType);
        assert_array_equals(output.shape(), dimensions);
      }
    }
  }, `[${operationName}] Test building an unary operator with supported type.`);
  const unsupportedDataTypes =
      new Set(allWebNNOperandDataTypes).difference(new Set(supportedDataTypes));
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const inputDataTypes = context.opSupportLimits().input.dataTypes;
    const options = {label};
    const regexp = new RegExp('\\[' + label + '\\]');
    for (let dataType of unsupportedDataTypes) {
      if (!inputDataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`,
                {dataType, dimensions: dimensions1D}));
        continue;
      }
      for (let dimensions of allWebNNDimensionsArray) {
        const input = builder.input(`input`, {dataType, dimensions});
        assert_equals(typeof builder[operationName], 'function');
        assert_throws_with_label(
            () => builder[operationName](input, options), regexp);
      }
    }
  }, `[${operationName}] Throw if the dataType is not supported for an unary operator.`);
}
/**
 * Validate a single input operation.
 *
 * Registers two tests based on context.opSupportLimits()[operationName].
 * In the first, building with each supported data type (that builder.input()
 * also accepts) succeeds and yields an output with the input's data type and
 * shape. In the second, building with any other data type throws a TypeError
 * whose message contains `[label]`. Does nothing when `navigator.ml` is
 * undefined.
 * @param {String} operationName - An operation name
 * @param {String} label - Value passed as `options.label` in the failure test.
 */
function validateSingleInputOperation(operationName, label) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const limits = context.opSupportLimits();
    const inputDataTypes = limits.input.dataTypes;
    const supportedDataTypes = limits[operationName].input.dataTypes;
    assert_equals(typeof builder[operationName], 'function');
    for (let dataType of supportedDataTypes) {
      // Types the operator supports but builder.input() cannot create are
      // skipped.
      if (!inputDataTypes.includes(dataType)) {
        continue;
      }
      for (let dimensions of allWebNNDimensionsArray) {
        const input = builder.input(`input`, {dataType, dimensions});
        const output = builder[operationName](input);
        assert_equals(output.dataType(), dataType);
        assert_array_equals(output.shape(), dimensions);
      }
    }
  }, `[${operationName}] Test building the operator with supported data type.`);
  promise_test(async t => {
    const builder = new MLGraphBuilder(context);
    const limits = context.opSupportLimits();
    const inputDataTypes = limits.input.dataTypes;
    const unsupportedDataTypes =
        new Set(allWebNNOperandDataTypes)
            .difference(new Set(limits[operationName].input.dataTypes));
    const options = {label};
    const regexp = new RegExp('\\[' + label + '\\]');
    for (let dataType of unsupportedDataTypes) {
      if (!inputDataTypes.includes(dataType)) {
        assert_throws_js(
            TypeError,
            () => builder.input(
                `inputA${++inputAIndex}`,
                {dataType, dimensions: dimensions1D}));
        continue;
      }
      for (let dimensions of allWebNNDimensionsArray) {
        const input = builder.input(`input`, {dataType, dimensions});
        assert_equals(typeof builder[operationName], 'function');
        assert_throws_with_label(
            () => builder[operationName](input, options), regexp);
      }
    }
  }, `[${operationName}] Throw if the data type is not supported for the operator.`);
}
/**
 * Registers a test that the builder method `operatorName` throws a TypeError
 * when its only argument is an operand created by a different
 * MLGraphBuilder.
 * @param {String} operatorName - Name of the MLGraphBuilder method to test.
 * @param {Object} operatorDescriptor - Operand descriptor used to create the
 *     foreign input. Defaults to a float32 [2, 2] operand; operators that do
 *     not accept such an input should pass their own descriptor.
 */
function validateInputFromAnotherBuilder(operatorName, operatorDescriptor = {
  dataType: 'float32',
  dimensions: [2, 2],
}) {
  multi_builder_test(async (t, builder, otherBuilder) => {
    const foreignInput = otherBuilder.input('input', operatorDescriptor);
    assert_equals(typeof builder[operatorName], 'function');
    const buildWithForeignInput = () => builder[operatorName](foreignInput);
    assert_throws_js(TypeError, buildWithForeignInput);
  }, `[${operatorName}] throw if input is from another builder`);
}
/**
 * Registers two tests that the binary builder method `operatorName` throws a
 * TypeError when its first, or respectively its second, input was created by
 * a different MLGraphBuilder. Only use this helper for operators that accept
 * two float32 [2, 2] inputs.
 * @param {String} operatorName - Name of the MLGraphBuilder method to test.
 */
function validateTwoInputsFromMultipleBuilders(operatorName) {
  const opDescriptor = {dataType: 'float32', dimensions: [2, 2]};
  for (const position of ['first', 'second']) {
    multi_builder_test(async (t, builder, otherBuilder) => {
      const inputFromOtherBuilder = otherBuilder.input('other', opDescriptor);
      const input = builder.input('input', opDescriptor);
      assert_equals(typeof builder[operatorName], 'function');
      const args = position === 'first' ? [inputFromOtherBuilder, input] :
                                          [input, inputFromOtherBuilder];
      assert_throws_js(TypeError, () => builder[operatorName](...args));
    }, `[${operatorName}] throw if ${position} input is from another builder`);
  }
}
/**
 * Registers a promise test that calls `func(t, builder, otherBuilder)` with
 * two distinct MLGraphBuilders created from the shared test context.
 * Does nothing when `navigator.ml` is undefined.
 * @param {Function} func - Async test body.
 * @param {String} description - Test name.
 */
function multi_builder_test(func, description) {
  if (navigator.ml === undefined) {
    return;
  }
  promise_test(async t => {
    // Reuse the context created in promise_setup so these tests run on the
    // same device type (from the page's query string) as every other test
    // in the file, rather than on a default-configured context.
    const builder = new MLGraphBuilder(context);
    const otherBuilder = new MLGraphBuilder(context);
    await func(t, builder, otherBuilder);
  }, description);
}