// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/test/win/fake_network_cost_manager.h"
#include <netlistmgr.h>
#include <wrl/implements.h>
#include <map>
#include "base/task/sequenced_task_runner.h"
#include "net/base/network_cost_change_notifier_win.h"
using Microsoft::WRL::ClassicCom;
using Microsoft::WRL::ComPtr;
using Microsoft::WRL::RuntimeClass;
using Microsoft::WRL::RuntimeClassFlags;
namespace net {
namespace {
DWORD NlmConnectionCostFlagsFromConnectionCost(
NetworkChangeNotifier::ConnectionCost source_cost) {
switch (source_cost) {
case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNMETERED:
return (NLM_CONNECTION_COST_UNRESTRICTED | NLM_CONNECTION_COST_CONGESTED);
case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_METERED:
return (NLM_CONNECTION_COST_VARIABLE | NLM_CONNECTION_COST_ROAMING |
NLM_CONNECTION_COST_APPROACHINGDATALIMIT);
case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNKNOWN:
default:
return NLM_CONNECTION_COST_UNKNOWN;
}
}
void DispatchCostChangedEvent(ComPtr<INetworkCostManagerEvents> event_target,
DWORD cost) {
std::ignore =
event_target->CostChanged(cost, /*destination_address=*/nullptr);
}
} // namespace
// A fake implementation of `INetworkCostManager` that can simulate costs,
// changed costs and errors.
class FakeNetworkCostManager final
: public RuntimeClass<RuntimeClassFlags<ClassicCom>,
INetworkCostManager,
IConnectionPointContainer,
IConnectionPoint> {
public:
FakeNetworkCostManager(NetworkChangeNotifier::ConnectionCost connection_cost,
NetworkCostManagerStatus error_status)
: error_status_(error_status), connection_cost_(connection_cost) {}
// For each event sink in `event_sinks_`, call
// `INetworkCostManagerEvents::CostChanged()` with `changed_cost` on the event
// sink's task runner.
void PostCostChangedEvents(
NetworkChangeNotifier::ConnectionCost changed_cost) {
DWORD cost_for_changed_event;
std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration>
event_sinks_for_changed_event;
{
base::AutoLock auto_lock(member_lock_);
connection_cost_ = changed_cost;
cost_for_changed_event =
NlmConnectionCostFlagsFromConnectionCost(changed_cost);
// Get the snapshot of event sinks to notify. The snapshot collection
// creates a new `ComPtr` for each event sink, which increments each the
// event sink's reference count, ensuring that each event sink
// remains alive to receive the cost changed event notification.
event_sinks_for_changed_event = event_sinks_;
}
for (const auto& pair : event_sinks_for_changed_event) {
const auto& registration = pair.second;
registration.event_sink_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&DispatchCostChangedEvent, registration.event_sink_,
cost_for_changed_event));
}
}
// Implement the `INetworkCostManager` interface.
HRESULT
__stdcall GetCost(DWORD* cost,
NLM_SOCKADDR* destination_ip_address) override {
if (error_status_ == NetworkCostManagerStatus::kErrorGetCostFailed) {
return E_FAIL;
}
if (destination_ip_address != nullptr) {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
{
base::AutoLock auto_lock(member_lock_);
*cost = NlmConnectionCostFlagsFromConnectionCost(connection_cost_);
}
return S_OK;
}
HRESULT __stdcall GetDataPlanStatus(
NLM_DATAPLAN_STATUS* data_plan_status,
NLM_SOCKADDR* destination_ip_address) override {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
HRESULT __stdcall SetDestinationAddresses(
UINT32 length,
NLM_SOCKADDR* destination_ip_address_list,
VARIANT_BOOL append) override {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
// Implement the `IConnectionPointContainer` interface.
HRESULT __stdcall FindConnectionPoint(REFIID connection_point_id,
IConnectionPoint** result) override {
if (error_status_ ==
NetworkCostManagerStatus::kErrorFindConnectionPointFailed) {
return E_ABORT;
}
if (connection_point_id != IID_INetworkCostManagerEvents) {
return E_NOINTERFACE;
}
*result = static_cast<IConnectionPoint*>(this);
AddRef();
return S_OK;
}
HRESULT __stdcall EnumConnectionPoints(
IEnumConnectionPoints** results) override {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
// Implement the `IConnectionPoint` interface.
HRESULT __stdcall Advise(IUnknown* event_sink,
DWORD* event_sink_cookie) override {
if (error_status_ == NetworkCostManagerStatus::kErrorAdviseFailed) {
return E_NOT_VALID_STATE;
}
ComPtr<INetworkCostManagerEvents> cost_manager_event_sink;
HRESULT hr =
event_sink->QueryInterface(IID_PPV_ARGS(&cost_manager_event_sink));
if (hr != S_OK) {
return hr;
}
base::AutoLock auto_lock(member_lock_);
event_sinks_[next_event_sink_cookie_] = {
cost_manager_event_sink,
base::SequencedTaskRunner::GetCurrentDefault()};
*event_sink_cookie = next_event_sink_cookie_;
++next_event_sink_cookie_;
return S_OK;
}
HRESULT __stdcall Unadvise(DWORD event_sink_cookie) override {
base::AutoLock auto_lock(member_lock_);
auto it = event_sinks_.find(event_sink_cookie);
if (it == event_sinks_.end()) {
return ERROR_NOT_FOUND;
}
event_sinks_.erase(it);
return S_OK;
}
HRESULT __stdcall GetConnectionInterface(IID* result) override {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
HRESULT __stdcall GetConnectionPointContainer(
IConnectionPointContainer** result) override {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
HRESULT __stdcall EnumConnections(IEnumConnections** result) override {
NOTIMPLEMENTED();
return E_NOTIMPL;
}
// Implement the `IUnknown` interface.
HRESULT __stdcall QueryInterface(REFIID interface_id,
void** result) override {
if (error_status_ == NetworkCostManagerStatus::kErrorQueryInterfaceFailed) {
return E_NOINTERFACE;
}
return RuntimeClass<RuntimeClassFlags<ClassicCom>, INetworkCostManager,
IConnectionPointContainer,
IConnectionPoint>::QueryInterface(interface_id, result);
}
FakeNetworkCostManager(const FakeNetworkCostManager&) = delete;
FakeNetworkCostManager& operator=(const FakeNetworkCostManager&) = delete;
private:
// The error state for this `FakeNetworkCostManager` to simulate. Cannot be
// changed.
const NetworkCostManagerStatus error_status_;
// Synchronizes access to all members below.
base::Lock member_lock_;
NetworkChangeNotifier::ConnectionCost connection_cost_
GUARDED_BY(member_lock_);
DWORD next_event_sink_cookie_ GUARDED_BY(member_lock_) = 0;
struct EventSinkRegistration {
ComPtr<INetworkCostManagerEvents> event_sink_;
scoped_refptr<base::SequencedTaskRunner> event_sink_task_runner_;
};
std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration> event_sinks_
GUARDED_BY(member_lock_);
};
FakeNetworkCostManagerEnvironment::FakeNetworkCostManagerEnvironment() {
// Set up `NetworkCostChangeNotifierWin` to use the fake OS APIs.
NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
base::BindRepeating(
&FakeNetworkCostManagerEnvironment::FakeCoCreateInstance,
base::Unretained(this)));
}
FakeNetworkCostManagerEnvironment::~FakeNetworkCostManagerEnvironment() {
// Restore `NetworkCostChangeNotifierWin` to use the real OS APIs.
NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
base::BindRepeating(&CoCreateInstance));
}
HRESULT FakeNetworkCostManagerEnvironment::FakeCoCreateInstance(
REFCLSID class_id,
LPUNKNOWN outer_aggregate,
DWORD context_flags,
REFIID interface_id,
LPVOID* result) {
NetworkChangeNotifier::ConnectionCost connection_cost_for_new_instance;
NetworkCostManagerStatus error_status_for_new_instance;
{
base::AutoLock auto_lock(member_lock_);
connection_cost_for_new_instance = connection_cost_;
error_status_for_new_instance = error_status_;
}
if (error_status_for_new_instance ==
NetworkCostManagerStatus::kErrorCoCreateInstanceFailed) {
return E_ACCESSDENIED;
}
if (class_id != CLSID_NetworkListManager) {
return E_NOINTERFACE;
}
if (interface_id != IID_INetworkCostManager) {
return E_NOINTERFACE;
}
ComPtr<FakeNetworkCostManager> instance =
Microsoft::WRL::Make<FakeNetworkCostManager>(
connection_cost_for_new_instance, error_status_for_new_instance);
{
base::AutoLock auto_lock(member_lock_);
fake_network_cost_managers_.push_back(instance);
}
*result = instance.Detach();
return S_OK;
}
void FakeNetworkCostManagerEnvironment::SetCost(
NetworkChangeNotifier::ConnectionCost value) {
// Update the cost for each `INetworkCostManager` instance in
// `fake_network_cost_managers_`.
std::vector<Microsoft::WRL::ComPtr<FakeNetworkCostManager>>
fake_network_cost_managers_for_change_event;
{
base::AutoLock auto_lock(member_lock_);
connection_cost_ = value;
fake_network_cost_managers_for_change_event = fake_network_cost_managers_;
}
for (const auto& network_cost_manager :
fake_network_cost_managers_for_change_event) {
network_cost_manager->PostCostChangedEvents(/*connection_cost=*/value);
}
}
void FakeNetworkCostManagerEnvironment::SimulateError(
NetworkCostManagerStatus error_status) {
base::AutoLock auto_lock(member_lock_);
error_status_ = error_status;
}
} // namespace net