chromium/services/webnn/dml/tensor_desc.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/dml/tensor_desc.h"

#include "base/check.h"
#include "base/check_op.h"
#include "services/webnn/dml/utils.h"
#include "services/webnn/webnn_utils.h"

namespace webnn::dml {

TensorDesc::TensorDesc(DML_TENSOR_DATA_TYPE data_type,
                       std::vector<uint32_t> dimensions)
    : TensorDesc(data_type, DML_TENSOR_FLAG_NONE, std::move(dimensions)) {}

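// Builds a DML_BUFFER_TENSOR_DESC from the data type, flags, dimensions and
// optional strides, normalizing scalar inputs to rank 1 and filling in default
// packed strides when none are provided.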
TensorDesc::TensorDesc(DML_TENSOR_DATA_TYPE data_type,
                       DML_TENSOR_FLAGS flags,
                       std::vector<uint32_t> dimensions,
                       std::vector<uint32_t> strides) {
  // DML (as of at least 1.11) requires the dimension count to be at least 1;
  // otherwise validation during operator creation fails with E_INVALIDARG. So
  // scalars must be conveyed with dimensions = [1].
  dimensions_ =
      dimensions.empty() ? std::vector<uint32_t>{1} : std::move(dimensions);

  if (!strides.empty()) {
    CHECK_EQ(dimensions_.size(), strides.size());
  }

  // If no strides are given, compute the default packed strides from the
  // dimensions, e.g. a tensor with dimensions [1, 2, 3, 4] has default strides
  // [24, 12, 4, 1]. See
  // https://docs.microsoft.com/en-us/windows/win32/direct3d12/dml-helper-functions#calculatestrides.
  strides_ =
      strides.empty() ? CalculateStrides(dimensions_) : std::move(strides);

  // CalculateDMLBufferTensorSize rounds the minimum implied size up to the
  // nearest 4 bytes. The buffer allocation already aligns chunks up to
  // DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT.
  uint64_t minimum_implied_size_in_bytes =
      CalculateDMLBufferTensorSize(data_type, dimensions_, strides_);

  buffer_desc_.DimensionCount = dimensions_.size();
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
  buffer_desc_.TotalTensorSizeInBytes = minimum_implied_size_in_bytes;
  buffer_desc_.GuaranteedBaseOffsetAlignment = 0;
  buffer_desc_.DataType = data_type;
  buffer_desc_.Flags = flags;

  tensor_desc_ = DML_TENSOR_DESC{DML_TENSOR_TYPE_BUFFER, &buffer_desc_};
}

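// The copy and move operations below must repoint buffer_desc_.Sizes and
// buffer_desc_.Strides at this object's own vectors, because
// DML_BUFFER_TENSOR_DESC stores raw pointers into their storage.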
TensorDesc::TensorDesc(const TensorDesc& other)
    : dimensions_(other.dimensions_),
      strides_(other.strides_),
      buffer_desc_(other.buffer_desc_) {
  // Update the internal pointers to dimensions and strides.
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
  tensor_desc_ = DML_TENSOR_DESC{DML_TENSOR_TYPE_BUFFER, &buffer_desc_};
}

TensorDesc::TensorDesc(TensorDesc&& other)
    : dimensions_(std::move(other.dimensions_)),
      strides_(std::move(other.strides_)),
      buffer_desc_(std::move(other.buffer_desc_)) {
  // Update the internal pointers to dimensions and strides.
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
  tensor_desc_ = DML_TENSOR_DESC{DML_TENSOR_TYPE_BUFFER, &buffer_desc_};
}

TensorDesc& TensorDesc::operator=(const TensorDesc& other) {
  dimensions_ = other.dimensions_;
  strides_ = other.strides_;
  buffer_desc_ = other.buffer_desc_;

  // Update the internal pointers to dimensions and strides.
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
  tensor_desc_ = DML_TENSOR_DESC{DML_TENSOR_TYPE_BUFFER, &buffer_desc_};

  return *this;
}

TensorDesc& TensorDesc::operator=(TensorDesc&& other) {
  dimensions_ = std::move(other.dimensions_);
  strides_ = std::move(other.strides_);
  buffer_desc_ = std::move(other.buffer_desc_);

  // Update the internal pointers to dimensions and strides.
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
  tensor_desc_ = DML_TENSOR_DESC{DML_TENSOR_TYPE_BUFFER, &buffer_desc_};

  return *this;
}

TensorDesc::~TensorDesc() = default;

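// Two tensor descs compare equal when their dimensions, strides, data type,
// total size and guaranteed base offset alignment all match; the tensor flags
// are not part of the comparison.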
bool TensorDesc::operator==(const TensorDesc& other) const {
  return dimensions_ == other.dimensions_ && strides_ == other.strides_ &&
         buffer_desc_.DataType == other.buffer_desc_.DataType &&
         buffer_desc_.TotalTensorSizeInBytes ==
             other.buffer_desc_.TotalTensorSizeInBytes &&
         buffer_desc_.GuaranteedBaseOffsetAlignment ==
             other.buffer_desc_.GuaranteedBaseOffsetAlignment;
}

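// Reorders the logical dimensions according to `permutation` by permuting the
// dimensions and strides together. The data in the underlying buffer is not
// moved; only the view described by this TensorDesc changes.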
void TensorDesc::Transpose(base::span<const uint32_t> permutation) {
  dimensions_ = PermuteArray(dimensions_, permutation);
  strides_ = PermuteArray(strides_, permutation);

  // Recompute the minimum implied size (which CalculateDMLBufferTensorSize
  // rounds up to the nearest 4 bytes) and check that permuting the dimensions
  // and strides together did not change it.
  uint64_t minimum_implied_size_in_bytes = CalculateDMLBufferTensorSize(
      buffer_desc_.DataType, dimensions_, strides_);
  CHECK_EQ(buffer_desc_.TotalTensorSizeInBytes, minimum_implied_size_in_bytes);

  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
}

// Changes the dimensions and strides of the TensorDesc by setting the stride
// of each broadcast dimension to 0 and reusing the existing stride values for
// the other dimensions. The behavior follows the helper function:
// https://learn.microsoft.com/en-us/windows/ai/directml/dml-helper-functions#calculatestrides
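// For example, broadcasting a tensor with dimensions {3, 1} and strides {1, 1}
// to broadcasted_dims {2, 3, 4} first prepends a leading 1 (dimensions
// {1, 3, 1}, strides {0, 1, 1}) and then yields dimensions {2, 3, 4} with
// strides {0, 1, 0}, so broadcast axes repeatedly read the same elements.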
void TensorDesc::BroadcastTo(base::span<const uint32_t> broadcasted_dims,
                             size_t ignorable_tails_count) {
  // When broadcasting to a 0D scalar (broadcasted_dims = {}), the original
  // tensor must already have dimensions {1}, and the dimensions_ and strides_
  // of the TensorDesc do not need to change.
  if (broadcasted_dims.empty()) {
    CHECK_EQ(ignorable_tails_count, 0u);
    CHECK_EQ(dimensions_.size(), 1u);
    CHECK_EQ(dimensions_[0], 1u);
    return;
  }

  const size_t original_rank = dimensions_.size();
  const size_t broadcasted_rank = broadcasted_dims.size();
  CHECK_LE(original_rank, broadcasted_rank);

  EnsureMinimumRank(broadcasted_rank, Alignment::kTrailing);

  CHECK_LE(ignorable_tails_count, original_rank);
  for (size_t i = 0; i < broadcasted_rank - ignorable_tails_count; ++i) {
    if (dimensions_[i] != broadcasted_dims[i]) {
      CHECK_EQ(dimensions_[i], 1u);
      dimensions_[i] = broadcasted_dims[i];
      strides_[i] = 0;
    }
  }
}

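// Pads dimensions_ with 1's (and strides_ with 0's) until the rank reaches
// `minimum_rank`. With Alignment::kTrailing the existing dimensions keep their
// trailing positions (1's are prepended); with Alignment::kLeading they keep
// their leading positions (1's are appended). For example, dimensions {2, 3}
// with strides {3, 1} padded to rank 4 with Alignment::kTrailing become
// {1, 1, 2, 3} with strides {0, 0, 3, 1}.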
void TensorDesc::EnsureMinimumRank(size_t minimum_rank, Alignment alignment) {
  if (dimensions_.size() >= minimum_rank) {
    return;
  }

  size_t insertion_count = minimum_rank - dimensions_.size();
  switch (alignment) {
    case Alignment::kLeading: {
      dimensions_.insert(dimensions_.end(), insertion_count, 1u);
      strides_.insert(strides_.end(), insertion_count, 0u);
      break;
    }
    case Alignment::kTrailing: {
      dimensions_.insert(dimensions_.begin(), insertion_count, 1u);
      strides_.insert(strides_.begin(), insertion_count, 0u);
      break;
    }
  }

  // Note the TotalTensorSizeInBytes is not changed by inserting ones in
  // dimensions.
  buffer_desc_.DimensionCount = dimensions_.size();
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
}

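// Expands the tensor to `minimum_rank` by mapping the i-th existing dimension
// to the target axis axes[i]; every other axis gets dimension 1 and stride 0
// so it can be broadcast. For example, dimensions {5} with strides {1} and
// axes {1} expanded to minimum_rank 3 become dimensions {1, 5, 1} with
// strides {0, 1, 0}.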
void TensorDesc::MakeBroadcastCompatible(size_t minimum_rank,
                                         base::span<const uint32_t> axes) {
  if (dimensions_.size() >= minimum_rank) {
    return;
  }

  CHECK_LE(axes.size(), dimensions_.size());
  std::vector<uint32_t> new_dimensions(minimum_rank, 1);
  std::vector<uint32_t> new_strides(minimum_rank, 0);
  for (size_t i = 0; i < axes.size(); ++i) {
    CHECK_LT(axes[i], minimum_rank);
    new_dimensions[axes[i]] = dimensions_[i];
    new_strides[axes[i]] = strides_[i];
  }

  dimensions_ = std::move(new_dimensions);
  strides_ = std::move(new_strides);
  buffer_desc_.DimensionCount = dimensions_.size();
  buffer_desc_.Sizes = dimensions_.data();
  buffer_desc_.Strides = strides_.data();
}

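// Overrides the declared total tensor size. The new size must be at least the
// minimum implied size for the current dimensions and strides, and must not
// shrink the previously declared total size.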
void TensorDesc::SetTotalTensorSizeInBytes(
    uint64_t new_total_tensor_size_bytes) {
  CHECK_GE(new_total_tensor_size_bytes,
           CalculateDMLBufferTensorSize(buffer_desc_.DataType, dimensions_,
                                        strides_));
  CHECK_GE(new_total_tensor_size_bytes, buffer_desc_.TotalTensorSizeInBytes);
  buffer_desc_.TotalTensorSizeInBytes = new_total_tensor_size_bytes;
}

}  // namespace webnn::dml