// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/base/network_cost_change_notifier_win.h"
#include <wrl.h>
#include <wrl/client.h>
#include "base/check.h"
#include "base/no_destructor.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/threading/scoped_thread_priority.h"
#include "base/win/com_init_util.h"
using Microsoft::WRL::ComPtr;
namespace net {
namespace {
// Translates the `NLM_CONNECTION_COST` flags reported by the OS into the
// cross-platform `NetworkChangeNotifier::ConnectionCost` value:
//   - exactly `NLM_CONNECTION_COST_UNKNOWN`           -> UNKNOWN
//   - `NLM_CONNECTION_COST_UNRESTRICTED` bit is set   -> UNMETERED
//   - anything else                                   -> METERED
NetworkChangeNotifier::ConnectionCost ConnectionCostFromNlmConnectionCost(
    DWORD connection_cost_flags) {
  if (connection_cost_flags == NLM_CONNECTION_COST_UNKNOWN) {
    return NetworkChangeNotifier::CONNECTION_COST_UNKNOWN;
  }

  const bool is_unrestricted =
      (connection_cost_flags & NLM_CONNECTION_COST_UNRESTRICTED) != 0;
  return is_unrestricted ? NetworkChangeNotifier::CONNECTION_COST_UNMETERED
                         : NetworkChangeNotifier::CONNECTION_COST_METERED;
}
// Returns a mutable reference to the process-wide callback used to create COM
// objects. It starts out bound to the real `CoCreateInstance()`; the returned
// reference may be assigned to in order to substitute another implementation.
NetworkCostChangeNotifierWin::CoCreateInstanceCallback&
GetCoCreateInstanceCallback() {
  using Callback = NetworkCostChangeNotifierWin::CoCreateInstanceCallback;

  // Function-local static wrapped in `NoDestructor`, so the callback is
  // created on first use and never destroyed.
  static base::NoDestructor<Callback> callback_storage{
      base::BindRepeating(&CoCreateInstance)};
  return *callback_storage;
}
} // namespace
// This class is used as an event sink to register for notifications from the
// `INetworkCostManagerEvents` interface. In particular, we are focused on
// getting notified when the connection cost changes.
//
// Instances are created with `CreateInstance()`, which also performs the
// registration. The owner must call `UnRegisterForNotifications()` to
// disconnect the sink from the connection point.
class NetworkCostManagerEventSinkWin final
    : public Microsoft::WRL::RuntimeClass<
          Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
          INetworkCostManagerEvents> {
 public:
  // Creates an event sink and registers it with the
  // `INetworkCostManagerEvents` connection point of `network_cost_manager`.
  //
  // On success, returns `S_OK` and stores the new sink in `*result`. On
  // failure, returns the failing `HRESULT` and leaves `*result` untouched.
  // `cost_changed_callback` runs every time `CostChanged()` is invoked.
  static HRESULT CreateInstance(
      INetworkCostManager* network_cost_manager,
      base::RepeatingClosure cost_changed_callback,
      ComPtr<NetworkCostManagerEventSinkWin>* result) {
    ComPtr<NetworkCostManagerEventSinkWin> instance =
        Microsoft::WRL::Make<net::NetworkCostManagerEventSinkWin>(
            cost_changed_callback);
    // `Make()` yields an empty `ComPtr` when the allocation fails; report that
    // to the caller instead of dereferencing a null pointer below.
    if (!instance) {
      return E_OUTOFMEMORY;
    }

    HRESULT hr = instance->RegisterForNotifications(network_cost_manager);
    if (hr != S_OK) {
      return hr;
    }

    *result = instance;
    return S_OK;
  }

  // Disconnects this sink from the connection point it was advised on, if
  // any. Safe to call more than once: after the first call the connection
  // point is released and later calls do nothing.
  void UnRegisterForNotifications() {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

    if (event_sink_connection_point_) {
      // Best effort: the result of `Unadvise()` is intentionally ignored.
      event_sink_connection_point_->Unadvise(event_sink_connection_cookie_);
      event_sink_connection_point_.Reset();
    }
  }

  // Implement the INetworkCostManagerEvents interface.
  HRESULT __stdcall CostChanged(DWORD /*cost*/,
                                NLM_SOCKADDR* /*socket_address*/) final {
    // It is possible to get multiple notifications in a short period of time.
    // Rather than worrying about whether this notification represents the
    // latest, just notify the owner who can get the current value from the
    // INetworkCostManager so we know that we're actually getting the correct
    // value.
    cost_changed_callback_.Run();
    return S_OK;
  }

  // Data plan status changes are not of interest; acknowledge and ignore.
  HRESULT __stdcall DataPlanStatusChanged(
      NLM_SOCKADDR* /*socket_address*/) final {
    return S_OK;
  }

  // Public so that `Microsoft::WRL::Make()` can construct the object; use
  // `CreateInstance()` instead of constructing directly. `explicit` prevents
  // accidental implicit conversion from a closure to an event sink.
  explicit NetworkCostManagerEventSinkWin(
      base::RepeatingClosure cost_changed_callback)
      : cost_changed_callback_(cost_changed_callback) {}

  NetworkCostManagerEventSinkWin(const NetworkCostManagerEventSinkWin&) =
      delete;
  NetworkCostManagerEventSinkWin& operator=(
      const NetworkCostManagerEventSinkWin&) = delete;

 private:
  ~NetworkCostManagerEventSinkWin() final = default;

  // Advises this object on the `INetworkCostManagerEvents` connection point
  // exposed by `cost_manager`. Returns `S_OK` on success, in which case the
  // connection point and cookie are remembered for `Unadvise()`; otherwise
  // returns the first failing `HRESULT` and leaves this object unregistered.
  HRESULT RegisterForNotifications(INetworkCostManager* cost_manager) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    base::win::AssertComInitialized();
    base::win::AssertComApartmentType(base::win::ComApartmentType::STA);

    ComPtr<IUnknown> this_event_sink_unknown;
    HRESULT hr = QueryInterface(IID_PPV_ARGS(&this_event_sink_unknown));

    // `NetworkCostManagerEventSinkWin::QueryInterface` for `IUnknown` must
    // succeed since it is implemented by this class.
    CHECK_EQ(hr, S_OK);

    ComPtr<IConnectionPointContainer> connection_point_container;
    hr =
        cost_manager->QueryInterface(IID_PPV_ARGS(&connection_point_container));
    if (hr != S_OK) {
      return hr;
    }

    Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point;
    hr = connection_point_container->FindConnectionPoint(
        IID_INetworkCostManagerEvents, &event_sink_connection_point);
    if (hr != S_OK) {
      return hr;
    }

    hr = event_sink_connection_point->Advise(this_event_sink_unknown.Get(),
                                             &event_sink_connection_cookie_);
    if (hr != S_OK) {
      return hr;
    }

    // Registration must only ever happen once per instance.
    CHECK_EQ(event_sink_connection_point_, nullptr);
    event_sink_connection_point_ = event_sink_connection_point;
    return S_OK;
  }

  // Run from `CostChanged()` to tell the owner that the cost may have changed.
  base::RepeatingClosure cost_changed_callback_;

  // The following members must be accessed on the sequence from
  // `sequence_checker_`
  SEQUENCE_CHECKER(sequence_checker_);

  // Cookie returned by `IConnectionPoint::Advise()`, needed for `Unadvise()`.
  DWORD event_sink_connection_cookie_ = 0;

  // Non-null only while this sink is registered for notifications.
  Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point_;
};
// static
// Builds a `NetworkCostChangeNotifierWin` bound to a dedicated COM STA task
// runner. `cost_changed_callback` is wrapped so that it is always posted back
// to the sequence that called `CreateInstance()`.
base::SequenceBound<NetworkCostChangeNotifierWin>
NetworkCostChangeNotifierWin::CreateInstance(
    CostChangedCallback cost_changed_callback) {
  // The notifier may block, runs at best-effort priority and does not hold up
  // shutdown.
  const base::TaskTraits com_sta_traits = {
      base::MayBlock(), base::TaskPriority::BEST_EFFORT,
      base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN};
  scoped_refptr<base::SequencedTaskRunner> notifier_task_runner =
      base::ThreadPool::CreateCOMSTATaskRunner(com_sta_traits);

  return base::SequenceBound<NetworkCostChangeNotifierWin>(
      notifier_task_runner,
      // Hop back to the creator's (and owner's) sequence before running
      // `cost_changed_callback`.
      base::BindPostTask(base::SequencedTaskRunner::GetCurrentDefault(),
                         cost_changed_callback));
}
NetworkCostChangeNotifierWin::NetworkCostChangeNotifierWin(
CostChangedCallback cost_changed_callback)
: cost_changed_callback_(cost_changed_callback) {
StartWatching();
}
// Stops watching for cost changes. Must run on the sequence tracked by
// `sequence_checker_`.
NetworkCostChangeNotifierWin::~NetworkCostChangeNotifierWin() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
// Unregister the event sink and release the cost manager.
StopWatching();
}
void NetworkCostChangeNotifierWin::StartWatching() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (base::win::GetVersion() < kSupportedOsVersion) {
return;
}
base::win::AssertComInitialized();
base::win::AssertComApartmentType(base::win::ComApartmentType::STA);
SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY();
// Create `INetworkListManager` using `CoCreateInstance()`. Tests may provide
// a fake implementation of `INetworkListManager` through an
// `OverrideCoCreateInstanceForTesting()`.
ComPtr<INetworkCostManager> cost_manager;
HRESULT hr = GetCoCreateInstanceCallback().Run(
CLSID_NetworkListManager, /*unknown_outer=*/nullptr, CLSCTX_ALL,
IID_INetworkCostManager, &cost_manager);
if (hr != S_OK) {
return;
}
// Subscribe to cost changed events.
hr = NetworkCostManagerEventSinkWin::CreateInstance(
cost_manager.Get(),
// Cost changed callbacks must run on this sequence to get the new cost
// from `INetworkCostManager`.
base::BindPostTask(
base::SequencedTaskRunner::GetCurrentDefault(),
base::BindRepeating(&NetworkCostChangeNotifierWin::HandleCostChanged,
weak_ptr_factory_.GetWeakPtr())),
&cost_manager_event_sink_);
if (hr != S_OK) {
return;
}
// Set the initial cost and inform observers of the initial value.
cost_manager_ = cost_manager;
HandleCostChanged();
}
// Tears down everything `StartWatching()` set up: disconnects and releases the
// event sink (when one exists) and then releases the cost manager.
void NetworkCostChangeNotifierWin::StopWatching() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (cost_manager_event_sink_ != nullptr) {
    // Disconnect from the connection point before dropping our reference.
    cost_manager_event_sink_->UnRegisterForNotifications();
    cost_manager_event_sink_.Reset();
  }

  cost_manager_.Reset();
}
void NetworkCostChangeNotifierWin::HandleCostChanged() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DWORD connection_cost_flags;
HRESULT hr = cost_manager_->GetCost(&connection_cost_flags,
/*destination_ip_address=*/nullptr);
if (hr != S_OK) {
connection_cost_flags = NLM_CONNECTION_COST_UNKNOWN;
}
NetworkChangeNotifier::ConnectionCost changed_cost =
ConnectionCostFromNlmConnectionCost(connection_cost_flags);
cost_changed_callback_.Run(changed_cost);
}
// static
// Replaces the process-wide callback that `StartWatching()` uses in place of
// `CoCreateInstance()`. The replacement stays in effect until overridden
// again.
void NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
    CoCreateInstanceCallback callback_for_testing) {
  CoCreateInstanceCallback& active_callback = GetCoCreateInstanceCallback();
  active_callback = callback_for_testing;
}
} // namespace net