chromium/base/win/com_init_balancer_unittest.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/win/com_init_balancer.h"

#include <shlobj.h>
#include <wrl/client.h>

#include "base/test/gtest_util.h"
#include "base/win/com_init_util.h"
#include "base/win/scoped_com_initializer.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace base {
namespace win {

using Microsoft::WRL::ComPtr;

// Balanced case with Uninitialization::kBlockPremature: a COM object can be
// created while the ScopedCOMInitializer is alive, and AssertComInitialized()
// is expected to hit NOTREACHED once the initializer has been destroyed.
TEST(TestComInitBalancer, BalancedPairsWithComBalancerEnabled) {
  {
    // Initialize COM; nothing below is meaningful if this fails, hence ASSERT.
    ScopedCOMInitializer com_initializer(
        ScopedCOMInitializer::Uninitialization::kBlockPremature);
    ASSERT_TRUE(com_initializer.Succeeded());

    // Create a COM object. |shell_link| is declared after |com_initializer|,
    // so it is released before the initializer is destroyed.
    ComPtr<IUnknown> shell_link;
    const HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                          IID_PPV_ARGS(&shell_link));
    // The HRESULT-aware assertion reports the failing HRESULT value, which
    // EXPECT_TRUE(SUCCEEDED(hr)) would hide.
    EXPECT_HRESULT_SUCCEEDED(hr);
  }

  // ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
  // The death expectation is only exercised when DCHECKs are enabled.
  if (DCHECK_IS_ON()) {
    EXPECT_NOTREACHED_DEATH(AssertComInitialized());
  }
}

// Unbalanced case with Uninitialization::kBlockPremature: extra
// ::CoUninitialize() calls made while the ScopedCOMInitializer is alive must
// not tear COM down, and AssertComInitialized() is expected to hit NOTREACHED
// once the initializer has been destroyed.
TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerEnabled) {
  {
    // Initialize COM; nothing below is meaningful if this fails, hence ASSERT.
    ScopedCOMInitializer com_initializer(
        ScopedCOMInitializer::Uninitialization::kBlockPremature);
    ASSERT_TRUE(com_initializer.Succeeded());

    // Attempt, twice, to prematurely uninitialize the COM library.
    ::CoUninitialize();
    ::CoUninitialize();

    // Must not die: COM is expected to still be initialized.
    AssertComInitialized();

    // Creating a COM object must still succeed. |shell_link| is declared after
    // |com_initializer|, so it is released before the initializer is destroyed.
    ComPtr<IUnknown> shell_link;
    const HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                          IID_PPV_ARGS(&shell_link));
    // The HRESULT-aware assertion reports the failing HRESULT value, which
    // EXPECT_TRUE(SUCCEEDED(hr)) would hide.
    EXPECT_HRESULT_SUCCEEDED(hr);
  }

  // ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
  // The death expectation is only exercised when DCHECKs are enabled.
  if (DCHECK_IS_ON()) {
    EXPECT_NOTREACHED_DEATH(AssertComInitialized());
  }
}

// Balanced case with Uninitialization::kAllow: a COM object can be created
// while the ScopedCOMInitializer is alive, and AssertComInitialized() is
// expected to hit NOTREACHED once the initializer has been destroyed.
TEST(TestComInitBalancer, BalancedPairsWithComBalancerDisabled) {
  {
    // Initialize COM; nothing below is meaningful if this fails, hence ASSERT.
    ScopedCOMInitializer com_initializer(
        ScopedCOMInitializer::Uninitialization::kAllow);
    ASSERT_TRUE(com_initializer.Succeeded());

    // Create a COM object. |shell_link| is declared after |com_initializer|,
    // so it is released before the initializer is destroyed.
    ComPtr<IUnknown> shell_link;
    const HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                          IID_PPV_ARGS(&shell_link));
    // The HRESULT-aware assertion reports the failing HRESULT value, which
    // EXPECT_TRUE(SUCCEEDED(hr)) would hide.
    EXPECT_HRESULT_SUCCEEDED(hr);
  }

  // ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
  // The death expectation is only exercised when DCHECKs are enabled.
  if (DCHECK_IS_ON()) {
    EXPECT_NOTREACHED_DEATH(AssertComInitialized());
  }
}

// Unbalanced case with Uninitialization::kAllow: extra ::CoUninitialize()
// calls made while the ScopedCOMInitializer is still alive are expected to
// uninitialize COM, so AssertComInitialized() hits NOTREACHED and creating a
// COM object fails with CO_E_NOTINITIALIZED.
TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerDisabled) {
  // Initialize COM; nothing below is meaningful if this fails, hence ASSERT.
  ScopedCOMInitializer com_initializer(
      ScopedCOMInitializer::Uninitialization::kAllow);
  ASSERT_TRUE(com_initializer.Succeeded());

  // Attempt, twice, to prematurely uninitialize the COM library.
  ::CoUninitialize();
  ::CoUninitialize();

  // COM is expected to be uninitialized now. The death expectation is only
  // exercised when DCHECKs are enabled.
  if (DCHECK_IS_ON()) {
    EXPECT_NOTREACHED_DEATH(AssertComInitialized());
  }

  // Creating a COM object is expected to fail because COM is not initialized.
  ComPtr<IUnknown> shell_link;
  const HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                        IID_PPV_ARGS(&shell_link));
  // The HRESULT-aware assertion reports the HRESULT value if the call
  // unexpectedly succeeds; the exact failure code is pinned separately.
  EXPECT_HRESULT_FAILED(hr);
  EXPECT_EQ(CO_E_NOTINITIALIZED, hr);
}

// With one kBlockPremature ScopedCOMInitializer alive, the balancer reference
// count is expected to be 1 and to stay at 1 after a premature
// ::CoUninitialize() call.
TEST(TestComInitBalancer, OneRegisteredSpyRefCount) {
  constexpr DWORD kExpectedRefCount = 1;

  ScopedCOMInitializer initializer(
      ScopedCOMInitializer::Uninitialization::kBlockPremature);
  ASSERT_TRUE(initializer.Succeeded());

  // Straight after initialization the count is expected to be 1.
  EXPECT_EQ(kExpectedRefCount,
            initializer.GetCOMBalancerReferenceCountForTesting());

  // Try to uninitialize the COM library ahead of time...
  ::CoUninitialize();

  // ...and check that the count did not move.
  EXPECT_EQ(kExpectedRefCount,
            initializer.GetCOMBalancerReferenceCountForTesting());
}

// With three kBlockPremature ScopedCOMInitializers alive, every initializer is
// expected to report a balancer reference count of 3, and three premature
// ::CoUninitialize() calls are expected to bring it down to 1 but no lower.
TEST(TestComInitBalancer, ThreeRegisteredSpiesRefCount) {
  constexpr auto kBlockMode =
      ScopedCOMInitializer::Uninitialization::kBlockPremature;
  constexpr DWORD kRefCountAfterInit = 3;
  constexpr DWORD kRefCountFloor = 1;

  // Construct all three initializers before checking any of them.
  ScopedCOMInitializer first(kBlockMode);
  ScopedCOMInitializer second(kBlockMode);
  ScopedCOMInitializer third(kBlockMode);
  ASSERT_TRUE(first.Succeeded());
  ASSERT_TRUE(second.Succeeded());
  ASSERT_TRUE(third.Succeeded());

  // Each initializer should observe the same count of 3.
  EXPECT_EQ(kRefCountAfterInit, first.GetCOMBalancerReferenceCountForTesting());
  EXPECT_EQ(kRefCountAfterInit,
            second.GetCOMBalancerReferenceCountForTesting());
  EXPECT_EQ(kRefCountAfterInit, third.GetCOMBalancerReferenceCountForTesting());

  // Try to uninitialize the COM library ahead of time, three times over.
  ::CoUninitialize();  // Count drops to 2.
  ::CoUninitialize();  // Count drops to 1.
  ::CoUninitialize();  // Count is expected to stay at 1.

  // The count must not have fallen below 1 for any initializer.
  EXPECT_EQ(kRefCountFloor, first.GetCOMBalancerReferenceCountForTesting());
  EXPECT_EQ(kRefCountFloor, second.GetCOMBalancerReferenceCountForTesting());
  EXPECT_EQ(kRefCountFloor, third.GetCOMBalancerReferenceCountForTesting());
}

}  // namespace win
}  // namespace base