chromium/services/webnn/dml/utils.h

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_DML_UTILS_H_
#define SERVICES_WEBNN_DML_UTILS_H_

#include <string>
#include <vector>

#include "base/component_export.h"
#include "base/containers/span.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "third_party/microsoft_dxheaders/include/directml.h"
#include "third_party/microsoft_dxheaders/src/include/directx/d3d12.h"

// Windows SDK headers should be included after DirectX headers.
#include <wrl.h>

namespace webnn::dml {

uint64_t CalculateDMLBufferTensorSize(DML_TENSOR_DATA_TYPE data_type,
                                      const std::vector<uint32_t>& dimensions,
                                      const std::vector<uint32_t>& strides);

std::vector<uint32_t> CalculateStrides(base::span<const uint32_t> dimensions);

// Gets the ID3D12Device used to create the IDMLDevice1.
Microsoft::WRL::ComPtr<ID3D12Device> GetD3D12Device(IDMLDevice1* dml_device);

// Returns the maximum feature level supported by the DML device.
DML_FEATURE_LEVEL GetMaxSupportedDMLFeatureLevel(IDMLDevice1* dml_device);

// Creates a transition barrier which is used to specify the resource is
// transitioning from `before` to `after` states.
D3D12_RESOURCE_BARRIER COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateTransitionBarrier(ID3D12Resource* resource,
                            D3D12_RESOURCE_STATES before,
                            D3D12_RESOURCE_STATES after);

// Helper function to upload data from CPU to GPU, the resource can be created
// for a single buffer or a big buffer combined from multiple buffers.
void COMPONENT_EXPORT(WEBNN_SERVICE)
    UploadBufferWithBarrier(CommandRecorder* command_recorder,
                            Microsoft::WRL::ComPtr<ID3D12Resource> dst_buffer,
                            Microsoft::WRL::ComPtr<ID3D12Resource> src_buffer,
                            size_t buffer_size);

// Helper function to readback data from GPU to CPU, the resource can be created
// for a single buffer or a big buffer combined from multiple buffers.
void COMPONENT_EXPORT(WEBNN_SERVICE) ReadbackBufferWithBarrier(
    CommandRecorder* command_recorder,
    Microsoft::WRL::ComPtr<ID3D12Resource> readback_buffer,
    Microsoft::WRL::ComPtr<ID3D12Resource> default_buffer,
    size_t buffer_size);

// TODO(crbug.com/40278771): move buffer helpers into command recorder.
void COMPONENT_EXPORT(WEBNN_SERVICE)
    UploadBufferWithBarrier(CommandRecorder* command_recorder,
                            BufferImplDml* dst_buffer,
                            Microsoft::WRL::ComPtr<ID3D12Resource> src_buffer,
                            size_t buffer_size);

void COMPONENT_EXPORT(WEBNN_SERVICE)
    ReadbackBufferWithBarrier(CommandRecorder* command_recorder,
                              Microsoft::WRL::ComPtr<ID3D12Resource> dst_buffer,
                              BufferImplDml* src_buffer,
                              size_t buffer_size);

mojom::ErrorPtr CreateError(mojom::Error::Code error_code,
                            const std::string& error_message,
                            std::string_view label = "");

// Create a resource with `size` bytes in
// D3D12_RESOURCE_STATE_UNORDERED_ACCESS state from the default heap of the
// owned D3D12 device. For this method and the other two, if there are no
// errors, S_OK is returned and the created resource is returned via
// `resource`. Otherwise, the corresponding HRESULT error code is returned.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateDefaultBuffer(ID3D12Device* device,
                        uint64_t size,
                        const wchar_t* name_for_debugging,
                        Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Create a resource with `size` bytes in D3D12_RESOURCE_STATE_GENERIC_READ
// state from the uploading heap of the owned D3D12 device.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateUploadBuffer(ID3D12Device* device,
                       uint64_t size,
                       const wchar_t* name_for_debugging,
                       Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Create a resource with `size` bytes in D3D12_RESOURCE_STATE_COPY_DEST state
// from the reading-back heap of the owned D3D12 device.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateReadbackBuffer(ID3D12Device* device,
                         uint64_t size,
                         const wchar_t* name_for_debugging,
                         Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Create a resource with `size` bytes in
// D3D12_RESOURCE_STATE_UNORDERED_ACCESS state and from a custom heap with CPU
// memory pool (D3D12_MEMORY_POOL_L0) optimized for CPU uploading data to GPU.
// This type of buffer should only be created for GPU with UMA (Unified Memory
// Architecture).
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE)
    CreateCustomUploadBuffer(ID3D12Device* device,
                             uint64_t size,
                             const wchar_t* name_for_debugging,
                             Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Create a resource with `size` bytes in
// D3D12_RESOURCE_STATE_UNORDERED_ACCESS state and from a custom heap with CPU
// memory pool (D3D12_MEMORY_POOL_L0) optimized for CPU reading data back from
// GPU. This type of buffer should only be created for GPU with UMA (Unified
// Memory Architecture).
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE) CreateCustomReadbackBuffer(
    ID3D12Device* device,
    uint64_t size,
    const wchar_t* name_for_debugging,
    Microsoft::WRL::ComPtr<ID3D12Resource>& resource);

// Create a descriptor heap with D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV type,
// D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE flag and large enough for the
// number of descriptors.
HRESULT COMPONENT_EXPORT(WEBNN_SERVICE) CreateDescriptorHeap(
    ID3D12Device* device,
    uint32_t num_descriptors,
    const wchar_t* name_for_debugging,
    Microsoft::WRL::ComPtr<ID3D12DescriptorHeap>& descriptor_heap);

}  // namespace webnn::dml

#endif  // SERVICES_WEBNN_DML_UTILS_H_