chromium/components/policy/core/common/scoped_critical_policy_section.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/policy/core/common/scoped_critical_policy_section.h"

#include <windows.h>

#include <userenv.h>

#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "components/policy/core/common/policy_types.h"

namespace policy {

namespace {

// Default "enter section" implementation. Enters the user critical policy
// section and then the machine one, and passes whatever handles were obtained
// to `callback`. A handle that could not be acquired is left as returned by
// ::EnterCriticalPolicySection (null on failure) and a warning is logged.
void EnterSection(
    ScopedCriticalPolicySection::OnSectionEnteredCallback callback) {
  // Both the user and the machine sections are needed. Microsoft documents that
  // the user section must be entered first to avoid a deadlock:
  // https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-entercriticalpolicysection
  //
  // Failing to acquire a section is not fatal: the caller reads policy anyway,
  // which matches the behavior from before these sections were used.
  ScopedCriticalPolicySection::Handles entered;

  entered.user_handle = ::EnterCriticalPolicySection(/*bMachine=*/false);
  if (!entered.user_handle) {
    PLOG(WARNING) << "Failed to enter user critical policy section.";
  }

  entered.machine_handle = ::EnterCriticalPolicySection(/*bMachine=*/true);
  if (!entered.machine_handle) {
    PLOG(WARNING) << "Failed to enter machine critical policy section.";
  }

  std::move(callback).Run(entered);
}

}  // namespace

// static
// Enters the critical policy sections using the default Windows implementation
// and runs `callback` on `task_runner` once they are entered or the wait times
// out.
void ScopedCriticalPolicySection::Enter(
    base::OnceClosure callback,
    const scoped_refptr<base::SequencedTaskRunner>& task_runner) {
  // The caller must already be on `task_runner`'s sequence.
  DCHECK(task_runner->RunsTasksInCurrentSequence());

  // A null enter-section callback selects the built-in EnterSection().
  EnterSectionCallback use_default_enter_section;
  EnterWithEnterSectionCallback(std::move(callback),
                                std::move(use_default_enter_section),
                                task_runner);
}

// static
// Same as Enter(), but lets the caller (e.g. a unit test) supply the function
// that actually enters the sections. A null `enter_section_callback` falls back
// to EnterSection().
void ScopedCriticalPolicySection::EnterWithEnterSectionCallback(
    base::OnceClosure callback,
    EnterSectionCallback enter_section_callback,
    const scoped_refptr<base::SequencedTaskRunner>& task_runner) {
  auto owned_section =
      std::make_unique<ScopedCriticalPolicySection>(task_runner);
  // Grab a raw pointer up front; ownership is moved into the completion
  // callback below, so `owned_section` must not be dereferenced after that.
  ScopedCriticalPolicySection* section = owned_section.get();

  if (enter_section_callback) {
    section->enter_section_callback_ = std::move(enter_section_callback);
  } else {
    section->enter_section_callback_ = base::BindOnce(&EnterSection);
  }

  // The section keeps itself alive through its own `callback_`. When that
  // callback runs, `callback` is invoked first and then `keep_alive` goes out
  // of scope, destroying the section (and leaving any entered sections).
  section->Init(base::BindOnce(
      [](std::unique_ptr<ScopedCriticalPolicySection> keep_alive,
         base::OnceClosure on_done) { std::move(on_done).Run(); },
      std::move(owned_section), std::move(callback)));
}

ScopedCriticalPolicySection::ScopedCriticalPolicySection(
    const scoped_refptr<base::SequencedTaskRunner>& task_runner)
    : task_runner_(task_runner) {}

// Leaves every critical policy section this object holds. Null handles (never
// entered, failed, or timed out) are skipped.
ScopedCriticalPolicySection::~ScopedCriticalPolicySection() {
  const auto leave_if_entered = [](auto handle) {
    if (handle) {
      ::LeaveCriticalPolicySection(handle);
    }
  };

  // Machine first, then user: the reverse of EnterSection(), which enters the
  // user section before the machine one.
  leave_if_entered(machine_handle_);
  leave_if_entered(user_handle_);
}

// Stores `callback` and starts waiting for the critical policy sections.
// `callback` runs (via OnSectionEntered()) on `task_runner_` exactly once:
// either when the enter-section callback reports its handles, or when the
// timeout below fires, whichever happens first.
void ScopedCriticalPolicySection::Init(base::OnceClosure callback) {
  DCHECK(!callback_);
  callback_ = std::move(callback);

  if (enter_section_callback_) {
    // Call ::EnterCriticalPolicySection on a different thread because the API
    // can take minutes to return.
    // Using `PostTask` instead of `PostTaskAndReplyWithResult` lets unit tests
    // mimic a blocking function easily.
    //
    // The reply is not bound directly to OnSectionEntered() through the
    // WeakPtr. If the timeout has already completed this object, `callback_`
    // has typically destroyed `this`, and a WeakPtr-bound method call would be
    // silently dropped. That would leak late-acquired section handles and keep
    // Group Policy processing blocked. Instead, such handles are released here.
    auto on_section_entered = base::BindPostTask(
        task_runner_,
        base::BindOnce(
            [](base::WeakPtr<ScopedCriticalPolicySection> section,
               Handles handles) {
              if (section) {
                section->OnSectionEntered(handles);
                return;
              }
              // Already completed (timed out) and destroyed: leave the
              // sections right away, in the same order as the destructor.
              if (handles.machine_handle) {
                ::LeaveCriticalPolicySection(handles.machine_handle);
              }
              if (handles.user_handle) {
                ::LeaveCriticalPolicySection(handles.user_handle);
              }
            },
            weak_factory_.GetWeakPtr()));
    base::ThreadPool::PostTask(
        FROM_HERE,
        {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
        base::BindOnce(std::move(enter_section_callback_),
                       std::move(on_section_entered)));
  }

  // Based on UMA data, a 15-second timeout is enough for 99.9% of cases. On
  // timeout, the callback runs without holding either section.
  constexpr base::TimeDelta kSectionTimeout = base::Seconds(15);
  task_runner_->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&ScopedCriticalPolicySection::OnSectionEntered,
                     weak_factory_.GetWeakPtr(), Handles()),
      kSectionTimeout);
}

// Adopts `handles` (either may be null, e.g. after a failure or on timeout) so
// that the destructor will leave the corresponding sections, then runs the
// completion callback.
void ScopedCriticalPolicySection::OnSectionEntered(Handles handles) {
  DCHECK(task_runner_->RunsTasksInCurrentSequence());
  user_handle_ = handles.user_handle;
  machine_handle_ = handles.machine_handle;
  // Running the callback may destroy `this` (EnterWithEnterSectionCallback
  // binds the owning unique_ptr into it), so nothing may follow this call.
  std::move(callback_).Run();
}

}  // namespace policy