chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file_win.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <windows.h>

#include <string>

#include "absl/cleanup/cleanup.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/scoped_file.h"

namespace mediapipe::tasks::genai::llm_utils {
namespace {

// MemoryMappedFile implementation backed by a Win32 file-mapping object.
//
// An instance takes over the mapping handle and the mapped view handed to its
// constructor: the destructor unmaps the view and then closes the handle.
class MemoryMappedFileWin final : public MemoryMappedFile {
 public:
  // `hmap` is the file-mapping handle, `data` the base address of a view
  // mapped from it, and `length` the byte count reported by length().
  MemoryMappedFileWin(HANDLE hmap, uint64_t length, void* data)
      : hmap_(hmap), length_(length), data_(data) {}

  // The destructor releases `hmap_` and `data_` unconditionally, so a copy
  // would unmap the view and close the handle a second time. Forbid copying
  // (which also suppresses the implicit move operations).
  MemoryMappedFileWin(const MemoryMappedFileWin&) = delete;
  MemoryMappedFileWin& operator=(const MemoryMappedFileWin&) = delete;

  ~MemoryMappedFileWin() override {
    // Release the view before the mapping object it was created from.
    ::UnmapViewOfFile(data_);
    ::CloseHandle(hmap_);
  }

  // Size of the mapped region in bytes, as supplied at construction.
  uint64_t length() override { return length_; }

  // Base address of the mapped view, as supplied at construction.
  void* data() override { return data_; }

 private:
  HANDLE hmap_;
  uint64_t length_;
  void* data_;
};

// Maps a region of `hfile` into memory and wraps it in a MemoryMappedFileWin.
//
// Args:
//   hfile:    handle of the file to map.
//   offset:   byte offset of the region; must be a multiple of
//             MemoryMappedFile::GetOffsetAlignment().
//   length:   number of bytes to map; 0 means "from `offset` to the end of
//             the file".
//   key:      optional NUL-terminated name of the mapping object. When
//             non-null, an existing mapping with that name is opened first
//             and a new one is created only if that fails. nullptr requests
//             an unnamed mapping.
//   writable: false maps with FILE_MAP_COPY / PAGE_WRITECOPY (copy-on-write
//             per the Win32 documentation); true maps with
//             FILE_MAP_ALL_ACCESS / PAGE_READWRITE.
//
// Returns the mapped file, or a RET_CHECK error status if the offset is
// misaligned, the requested region does not fit inside the file (or inside
// the address space), or any Win32 call fails.
absl::StatusOr<std::unique_ptr<MemoryMappedFile>> CreateImpl(HANDLE hfile,
                                                             uint64_t offset,
                                                             uint64_t length,
                                                             const char* key,
                                                             bool writable) {
  RET_CHECK_EQ(offset % MemoryMappedFile::GetOffsetAlignment(), 0)
      << "Offset must be a multiple of allocation granularity: " << offset
      << ", " << MemoryMappedFile::GetOffsetAlignment();

  LARGE_INTEGER size;
  RET_CHECK(::GetFileSizeEx(hfile, &size)) << "Failed to get size.";
  const uint64_t file_size = static_cast<uint64_t>(size.QuadPart);

  // Validate the region without ever computing `length + offset`, which could
  // wrap around for large inputs and make an out-of-range request look valid.
  RET_CHECK_LE(offset, file_size) << "Length and offset too large.";
  RET_CHECK_LE(length, file_size - offset) << "Length and offset too large.";
  if (length == 0) {
    length = file_size - offset;
  }

  // MapViewOfFile takes the view size as SIZE_T, which is narrower than
  // uint64_t on 32-bit targets. Refuse a length that would be truncated
  // rather than returning a view smaller than the reported length.
  const SIZE_T view_size = static_cast<SIZE_T>(length);
  RET_CHECK_EQ(static_cast<uint64_t>(view_size), length)
      << "Length does not fit in the address space.";

  DWORD access = FILE_MAP_COPY;
  DWORD protect = PAGE_WRITECOPY;
  if (writable) {
    access = FILE_MAP_ALL_ACCESS;
    protect = PAGE_READWRITE;
  }

  // Only a named mapping can already exist; skip the lookup for unnamed ones.
  HANDLE hmap = nullptr;
  if (key != nullptr) {
    hmap = ::OpenFileMappingA(access, FALSE, key);
  }
  if (hmap == nullptr) {
    hmap = ::CreateFileMappingA(hfile, nullptr, protect, 0, 0, key);
  }

  RET_CHECK(hmap) << "Failed to create mapping.";
  // Close the mapping handle on every early return below.
  auto close_hmap = absl::MakeCleanup([hmap] { ::CloseHandle(hmap); });

  ULARGE_INTEGER map_start = {};
  map_start.QuadPart = offset;
  void* mapped_region = ::MapViewOfFile(hmap, access, map_start.HighPart,
                                        map_start.LowPart, view_size);
  RET_CHECK(mapped_region) << "Failed to map.";

  // Success: MemoryMappedFileWin now owns the handle and closes it itself.
  std::move(close_hmap).Cancel();

  return std::make_unique<MemoryMappedFileWin>(hmap, length, mapped_region);
}

}  // namespace

// static
size_t MemoryMappedFile::GetOffsetAlignment() {
  SYSTEM_INFO sys_info;
  ::GetSystemInfo(&sys_info);
  return sys_info.dwAllocationGranularity;
}

// static
// Opens the file at `path` with ScopedFile::Open and maps it in its entirety
// (offset 0, length 0 meaning "to end of file") through an unnamed,
// non-writable mapping. Propagates the error if the file cannot be opened.
absl::StatusOr<std::unique_ptr<MemoryMappedFile>> MemoryMappedFile::Create(
    absl::string_view path) {
  MP_ASSIGN_OR_RETURN(auto opened_file, ScopedFile::Open(path));
  const auto file_handle = opened_file.file();
  return CreateImpl(file_handle, /*offset=*/0, /*length=*/0, /*key=*/nullptr,
                    /*writable=*/false);
}

// static
// Maps `length` bytes of the already-open `file`, starting at `offset`, as a
// non-writable view. `offset`/`length` are validated by CreateImpl (a
// `length` of 0 maps through the end of the file). A non-empty `key` names
// the mapping object; an empty `key` requests an unnamed mapping.
absl::StatusOr<std::unique_ptr<MemoryMappedFile>> MemoryMappedFile::Create(
    HANDLE file, uint64_t offset, uint64_t length, absl::string_view key) {
  // absl::string_view does not guarantee a terminating NUL, but the Win32 "A"
  // functions called by CreateImpl read the name as a C string. Hand them an
  // owned, NUL-terminated copy instead of the view's raw data pointer.
  const std::string key_copy(key);
  return CreateImpl(file, offset, length,
                    key_copy.empty() ? nullptr : key_copy.c_str(),
                    /*writable=*/false);
}

// Opens the file at `path` with ScopedFile::OpenWritable and maps it in its
// entirety (offset 0, length 0 meaning "to end of file") through an unnamed,
// writable mapping. Propagates the error if the file cannot be opened.
absl::StatusOr<std::unique_ptr<MemoryMappedFile>>
MemoryMappedFile::CreateMutable(absl::string_view path) {
  MP_ASSIGN_OR_RETURN(auto opened_file, ScopedFile::OpenWritable(path));
  const auto file_handle = opened_file.file();
  return CreateImpl(file_handle, /*offset=*/0, /*length=*/0, /*key=*/nullptr,
                    /*writable=*/true);
}

}  // namespace mediapipe::tasks::genai::llm_utils