chromium/services/webnn/dml/tensor_desc_test.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/dml/tensor_desc.h"

#include <vector>

#include "base/numerics/safe_conversions.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/microsoft_dxheaders/include/directml.h"

namespace webnn::dml {

class WebNNTensorDescTest : public testing::Test {};

TEST_F(WebNNTensorDescTest, CreateAndCopyTensorDescA) {
  // Test creating and copying a TensorDesc with empty dimensions, and
  // whether its' members have been set valid values.
  TensorDesc tensor_a(DML_TENSOR_DATA_TYPE_FLOAT32, {});
  EXPECT_EQ(tensor_a.GetDataType(), DML_TENSOR_DATA_TYPE_FLOAT32);
  EXPECT_EQ(tensor_a.GetFlags(), DML_TENSOR_FLAG_NONE);
  EXPECT_EQ(tensor_a.GetDimensions(), std::vector<uint32_t>{1});
  EXPECT_EQ(tensor_a.GetStrides(), std::vector<uint32_t>{1});
  EXPECT_EQ(tensor_a.GetTotalTensorSizeInBytes(), 4u);

  TensorDesc tensor_a_copy1(tensor_a);
  EXPECT_EQ(tensor_a_copy1, tensor_a);

  TensorDesc tensor_a_copy2 = tensor_a;
  EXPECT_EQ(tensor_a_copy2, tensor_a);

  TensorDesc tensor_a_copy3(std::move(tensor_a_copy2));
  EXPECT_EQ(tensor_a_copy3, tensor_a);

  TensorDesc tensor_a_copy4 = std::move(tensor_a_copy3);
  EXPECT_EQ(tensor_a_copy4, tensor_a);

  DML_BUFFER_TENSOR_DESC& buffer_a_desc = tensor_a.buffer_desc_;
  EXPECT_EQ(tensor_a.tensor_desc_.Desc, &buffer_a_desc);
  EXPECT_EQ(buffer_a_desc.DimensionCount,
            base::checked_cast<uint32_t>(tensor_a.dimensions_.size()));
  EXPECT_EQ(buffer_a_desc.Sizes, tensor_a.dimensions_.data());
  EXPECT_EQ(buffer_a_desc.Strides, tensor_a.strides_.data());
}

TEST_F(WebNNTensorDescTest, CreateAndCopyTensorDescB) {
  // Test creating and copying a TensorDesc with DML_TENSOR_FLAG_OWNED_BY_DML
  // and dimensions, and whether its' members have been set valid values.
  std::vector<uint32_t> dimensions = {1, 2, 3}, strides = {6, 3, 1};
  TensorDesc tensor_b(DML_TENSOR_DATA_TYPE_FLOAT32,
                      DML_TENSOR_FLAG_OWNED_BY_DML, dimensions);
  EXPECT_EQ(tensor_b.GetDataType(), DML_TENSOR_DATA_TYPE_FLOAT32);
  EXPECT_EQ(tensor_b.GetFlags(), DML_TENSOR_FLAG_OWNED_BY_DML);
  EXPECT_EQ(tensor_b.GetDimensions(), dimensions);
  EXPECT_EQ(tensor_b.GetStrides(), strides);
  EXPECT_EQ(tensor_b.GetTotalTensorSizeInBytes(), 24u);

  TensorDesc tensor_b_copy1(tensor_b);
  EXPECT_EQ(tensor_b_copy1, tensor_b);

  TensorDesc tensor_b_copy2 = tensor_b;
  EXPECT_EQ(tensor_b_copy2, tensor_b);

  TensorDesc tensor_b_copy3(std::move(tensor_b_copy2));
  EXPECT_EQ(tensor_b_copy3, tensor_b);

  TensorDesc tensor_b_copy4 = std::move(tensor_b_copy3);
  EXPECT_EQ(tensor_b_copy4, tensor_b);

  DML_BUFFER_TENSOR_DESC& buffer_b_desc = tensor_b.buffer_desc_;
  EXPECT_EQ(tensor_b.tensor_desc_.Desc, &buffer_b_desc);
  EXPECT_EQ(buffer_b_desc.DimensionCount,
            base::checked_cast<uint32_t>(tensor_b.dimensions_.size()));
  EXPECT_EQ(buffer_b_desc.Sizes, tensor_b.dimensions_.data());
  EXPECT_EQ(buffer_b_desc.Strides, tensor_b.strides_.data());
}

TEST_F(WebNNTensorDescTest, CreateAndCopyTensorDescC) {
  // Test creating and copying a TensorDesc with strides and dimensions, and
  // whether its' members have been set valid values.
  std::vector<uint32_t> dimensions = {1, 2, 3}, strides = {6, 3, 1};
  TensorDesc tensor_c(DML_TENSOR_DATA_TYPE_FLOAT32,
                      DML_TENSOR_FLAG_OWNED_BY_DML, dimensions, strides);
  EXPECT_EQ(tensor_c.GetDataType(), DML_TENSOR_DATA_TYPE_FLOAT32);
  EXPECT_EQ(tensor_c.GetFlags(), DML_TENSOR_FLAG_OWNED_BY_DML);
  EXPECT_EQ(tensor_c.GetDimensions(), dimensions);
  EXPECT_EQ(tensor_c.GetStrides(), strides);
  EXPECT_EQ(tensor_c.GetTotalTensorSizeInBytes(), 24u);

  TensorDesc tensor_c_copy1(tensor_c);
  EXPECT_EQ(tensor_c_copy1, tensor_c);

  TensorDesc tensor_c_copy2 = tensor_c;
  EXPECT_EQ(tensor_c_copy2, tensor_c);

  TensorDesc tensor_c_copy3(std::move(tensor_c_copy2));
  EXPECT_EQ(tensor_c_copy3, tensor_c);

  TensorDesc tensor_c_copy4 = std::move(tensor_c_copy3);
  EXPECT_EQ(tensor_c_copy4, tensor_c);

  DML_BUFFER_TENSOR_DESC& buffer_c_desc = tensor_c.buffer_desc_;
  EXPECT_EQ(tensor_c.tensor_desc_.Desc, &buffer_c_desc);
  EXPECT_EQ(buffer_c_desc.DimensionCount,
            base::checked_cast<uint32_t>(tensor_c.dimensions_.size()));
  EXPECT_EQ(buffer_c_desc.Sizes, tensor_c.dimensions_.data());
  EXPECT_EQ(buffer_c_desc.Strides, tensor_c.strides_.data());
}

TEST_F(WebNNTensorDescTest, TransposeTensorDesc) {
  std::vector<uint32_t> dimensions = {1, 2, 3}, strides = {6, 3, 1};
  TensorDesc tensor(DML_TENSOR_DATA_TYPE_FLOAT32, DML_TENSOR_FLAG_OWNED_BY_DML,
                    dimensions, strides);
  EXPECT_EQ(tensor.GetDataType(), DML_TENSOR_DATA_TYPE_FLOAT32);
  EXPECT_EQ(tensor.GetFlags(), DML_TENSOR_FLAG_OWNED_BY_DML);
  EXPECT_EQ(tensor.GetDimensions(), dimensions);
  EXPECT_EQ(tensor.GetStrides(), strides);
  EXPECT_EQ(tensor.GetTotalTensorSizeInBytes(), 24u);

  std::vector<uint32_t> permutation = {2, 0, 1};
  tensor.Transpose(permutation);

  EXPECT_EQ(tensor.GetDimensions(), (std::vector<uint32_t>{3, 1, 2}));
  EXPECT_EQ(tensor.GetStrides(), (std::vector<uint32_t>{1, 6, 3}));
}

TEST_F(WebNNTensorDescTest, CreateAndBroadcastTensorDesc) {
  // Test creating a TensorDesc with dimensions and broadcasting it.
  std::vector<uint32_t> dimensions = {1, 1, 3};
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, dimensions);
  std::vector<uint32_t> broadcasted_dimensions = {1, 2, 3};
  tensor_desc.BroadcastTo(broadcasted_dimensions);
  EXPECT_EQ(tensor_desc.GetDimensions(), broadcasted_dimensions);
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{3, 0, 1}));
}

TEST_F(WebNNTensorDescTest, CreateTensorDescAndBroadcastTo0D) {
  // Test creating a TensorDesc with dimensions = {1} and broadcasting it to {}.
  std::vector<uint32_t> dimensions = {1};
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, dimensions);
  std::vector<uint32_t> broadcasted_dimensions = {};
  tensor_desc.BroadcastTo(broadcasted_dimensions);
  EXPECT_EQ(tensor_desc.GetDimensions(), (std::vector<uint32_t>{1}));
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{1}));
}

TEST_F(WebNNTensorDescTest, Create0DTensorDescAndBroadcastTo0D) {
  // Test creating a TensorDesc with dimensions = {} and broadcasting it to {}.
  std::vector<uint32_t> dimensions = {};
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, dimensions);
  std::vector<uint32_t> broadcasted_dimensions = {};
  tensor_desc.BroadcastTo(broadcasted_dimensions);
  EXPECT_EQ(tensor_desc.GetDimensions(), (std::vector<uint32_t>{1}));
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{1}));
}

TEST_F(WebNNTensorDescTest, EnsureMinimumRank) {
  // Test expanding the tensor from rank 3 to 5 in trailing alignment.
  std::vector<uint32_t> dimensions_a = {3, 4, 5};
  TensorDesc tensor_desc_a(DML_TENSOR_DATA_TYPE_INT32,
                           DML_TENSOR_FLAG_OWNED_BY_DML, dimensions_a);

  tensor_desc_a.EnsureMinimumRank(/*minimum_rank*/ 5,
                                  TensorDesc::Alignment::kTrailing);
  EXPECT_EQ(tensor_desc_a.GetDimensions(),
            (std::vector<uint32_t>{1, 1, 3, 4, 5}));
  EXPECT_EQ(tensor_desc_a.GetStrides(),
            (std::vector<uint32_t>{0, 0, 20, 5, 1}));

  // Test expanding the tensor from rank 4 to 8 in leading alignment.
  std::vector<uint32_t> dimensions_b = {1, 2, 3, 4};
  TensorDesc tensor_desc_b(DML_TENSOR_DATA_TYPE_FLOAT16,
                           DML_TENSOR_FLAG_OWNED_BY_DML, dimensions_b);

  tensor_desc_b.EnsureMinimumRank(/*minimum_rank*/ 8,
                                  TensorDesc::Alignment::kLeading);
  EXPECT_EQ(tensor_desc_b.GetDimensions(),
            (std::vector<uint32_t>{1, 2, 3, 4, 1, 1, 1, 1}));
  EXPECT_EQ(tensor_desc_b.GetStrides(),
            (std::vector<uint32_t>{24, 12, 4, 1, 0, 0, 0, 0}));
}

TEST_F(WebNNTensorDescTest, Make1DBroadcastCompatibleTo4D) {
  // Test creating a TensorDesc with dimensions = {2}, axes = {1} and
  // minimum_rank = 4.
  std::vector<uint32_t> dimensions = {2};
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, std::move(dimensions));
  uint32_t axes[1] = {1};
  tensor_desc.MakeBroadcastCompatible(4, axes);
  EXPECT_EQ(tensor_desc.GetDimensions(), (std::vector<uint32_t>{1, 2, 1, 1}));
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{0, 1, 0, 0}));
}

TEST_F(WebNNTensorDescTest, Make2DBroadcastCompatibleTo4D) {
  // Test creating a TensorDesc with dimensions = {2, 3}, axes = {1, 3} and
  // minimum_rank = 4.
  std::vector<uint32_t> dimensions = {2, 3};
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, std::move(dimensions));
  uint32_t axes[2] = {1, 3};
  tensor_desc.MakeBroadcastCompatible(4, axes);
  EXPECT_EQ(tensor_desc.GetDimensions(), (std::vector<uint32_t>{1, 2, 1, 3}));
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{0, 3, 0, 1}));
}

TEST_F(WebNNTensorDescTest, Make2DBroadcastCompatibleTo4DWithNoDefaultStrides) {
  // Test creating a TensorDesc with dimensions = {3, 2}, strides = {1, 3} axes
  // = {1, 3} and minimum_rank = 4.
  std::vector<uint32_t> dimensions = {3, 2};
  std::vector<uint32_t> strides = {1, 3};
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, std::move(dimensions),
                         std::move(strides));
  uint32_t axes[2] = {1, 3};
  tensor_desc.MakeBroadcastCompatible(4, axes);
  EXPECT_EQ(tensor_desc.GetDimensions(), (std::vector<uint32_t>{1, 3, 1, 2}));
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{0, 1, 0, 3}));
}

TEST_F(WebNNTensorDescTest, Make0DBroadcastCompatibleTo4D) {
  // Test creating a scale TensorDesc, axes = {} and
  // minimum_rank = 4.
  TensorDesc tensor_desc(DML_TENSOR_DATA_TYPE_FLOAT32,
                         DML_TENSOR_FLAG_OWNED_BY_DML, {});
  tensor_desc.MakeBroadcastCompatible(4, {});
  EXPECT_EQ(tensor_desc.GetDimensions(), (std::vector<uint32_t>{1, 1, 1, 1}));
  EXPECT_EQ(tensor_desc.GetStrides(), (std::vector<uint32_t>{0, 0, 0, 0}));
}
}  // namespace webnn::dml