chromium/remoting/host/security_key/security_key_auth_handler_win.cc

// Copyright 2016 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/host/security_key/security_key_auth_handler.h"

#include <cstdint>
#include <map>
#include <memory>
#include <string>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread_checker.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "base/win/win_util.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "remoting/base/logging.h"
#include "remoting/host/client_session_details.h"
#include "remoting/host/mojom/remote_security_key.mojom.h"
#include "remoting/host/security_key/security_key_ipc_constants.h"

namespace remoting {

namespace {

// The timeout used to disconnect a client from the IPC Server if it forgets to
// send a request after it is connected.  This ensures the server channel is not
// blocked forever.
constexpr base::TimeDelta kInitialRequestTimeout = base::Seconds(5);

// This value represents the amount of time to wait for a security key request
// from the client before terminating the connection.
constexpr base::TimeDelta kSecurityKeyRequestTimeout = base::Seconds(60);

struct ActiveConnection {
  mojo::ReceiverId receiver_id;
  base::OneShotTimer disconnect_timer;
  mojom::SecurityKeyForwarder::OnSecurityKeyRequestCallback
      on_security_key_request_callback;
};

}  // namespace

// Implements the mojom::SecurityKeyForwarder interface and handles incoming SK
// requests from the IPC client. The caller is responsible for running the IPC
// server and passing in new connections through BindSecurityKeyForwarder().
//
// Each bound receiver is tracked as an ActiveConnection keyed by a locally
// assigned int connection id. A connection is closed (receiver removed from
// |receiver_set_| and its entry erased) when its disconnect timer fires, when a
// second request arrives before the first was answered, or through
// SendErrorAndCloseConnection(). The entry is also erased when the IPC peer
// disconnects. Methods are expected to run on one thread (|thread_checker_|).
// TODO(joedow): Update SecurityKeyAuthHandler impls to run on a separate IO
// thread instead of the thread it was created on: crbug.com/591739
class SecurityKeyAuthHandlerWin : public SecurityKeyAuthHandler,
                                  public mojom::SecurityKeyForwarder {
 public:
  // |client_session_details| must be non-null (DCHECKed by the constructor).
  explicit SecurityKeyAuthHandlerWin(
      ClientSessionDetails* client_session_details);

  SecurityKeyAuthHandlerWin(const SecurityKeyAuthHandlerWin&) = delete;
  SecurityKeyAuthHandlerWin& operator=(const SecurityKeyAuthHandlerWin&) =
      delete;

  ~SecurityKeyAuthHandlerWin() override;

 private:
  // On Windows, sizeof(int) != sizeof(ReceiverId), so we can't just use the
  // receiver ID as the connection ID.
  using ActiveConnections = std::map</* connection_id */ int, ActiveConnection>;

  // SecurityKeyAuthHandler interface.
  void BindSecurityKeyForwarder(
      mojo::PendingReceiver<mojom::SecurityKeyForwarder> receiver) override;
  void CreateSecurityKeyConnection() override;
  bool IsValidConnectionId(int security_key_connection_id) const override;
  void SendClientResponse(int security_key_connection_id,
                          const std::string& response) override;
  void SendErrorAndCloseConnection(int security_key_connection_id) override;
  void SetSendMessageCallback(const SendMessageCallback& callback) override;
  size_t GetActiveConnectionCountForTest() const override;
  void SetRequestTimeoutForTest(base::TimeDelta timeout) override;

  // mojom::SecurityKeyForwarder interface.
  // Forwards |request_data| through |send_message_callback_| and holds
  // |callback| until SendClientResponse() answers it.
  void OnSecurityKeyRequest(const std::string& request_data,
                            OnSecurityKeyRequestCallback callback) override;

  // Disconnect handler for |receiver_set_|: erases the entry belonging to the
  // receiver whose peer went away.
  void OnIpcPeerDisconnected();

  // Closes the connection created for a security key forwarding session.
  void CloseSecurityKeyRequestConnection(int connection_id);

  // Returns a closure (bound through |weak_factory_|) that calls
  // CloseSecurityKeyRequestConnection(connection_id); used as the task for
  // each connection's disconnect timer.
  base::OnceClosure GetCloseConnectionClosure(int connection_id);

  // Represents the last id assigned to a new security key request connection.
  int last_connection_id_ = 0;

  // Sends a security key extension message to the client when called.
  SendMessageCallback send_message_callback_;

  // Interface which provides details about the client session.
  // In this file it is only checked for non-null in the constructor.
  raw_ptr<ClientSessionDetails> client_session_details_ = nullptr;

  // Tracks the connection created for each security key forwarding session.
  ActiveConnections active_connections_;

  // All bound SecurityKeyForwarder receivers. Each receiver's context is its
  // connection id, read back through current_context().
  mojo::ReceiverSet<mojom::SecurityKeyForwarder, /* connection_id */ int>
      receiver_set_;

  // Ensures SecurityKeyAuthHandlerWin methods are called on the same thread.
  base::ThreadChecker thread_checker_;

  // Declared last so it is destroyed first, invalidating outstanding weak
  // pointers before the other members are torn down.
  base::WeakPtrFactory<SecurityKeyAuthHandlerWin> weak_factory_{this};
};

std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create(
    ClientSessionDetails* client_session_details,
    const SendMessageCallback& send_message_callback,
    scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) {
  // |file_task_runner| is not used by the Windows implementation.
  // The handler is held through the base type on purpose: the setter below is
  // private in SecurityKeyAuthHandlerWin but public in the interface.
  std::unique_ptr<SecurityKeyAuthHandler> handler =
      std::make_unique<SecurityKeyAuthHandlerWin>(client_session_details);
  handler->SetSendMessageCallback(send_message_callback);
  return handler;
}

SecurityKeyAuthHandlerWin::SecurityKeyAuthHandlerWin(
    ClientSessionDetails* client_session_details)
    : client_session_details_(client_session_details) {
  DCHECK(client_session_details_);

  // One handler serves every receiver in the set. When an IPC peer goes away,
  // drop the bookkeeping for its connection. The weak pointer makes the
  // handler a no-op once |weak_factory_| has been invalidated.
  auto on_peer_disconnected =
      base::BindRepeating(&SecurityKeyAuthHandlerWin::OnIpcPeerDisconnected,
                          weak_factory_.GetWeakPtr());
  receiver_set_.set_disconnect_handler(std::move(on_peer_disconnected));
}

SecurityKeyAuthHandlerWin::~SecurityKeyAuthHandlerWin() {
  // Tear-down must happen on the thread that used the handler.
  DCHECK(thread_checker_.CalledOnValidThread());
}

void SecurityKeyAuthHandlerWin::BindSecurityKeyForwarder(
    mojo::PendingReceiver<mojom::SecurityKeyForwarder> receiver) {
  DCHECK(thread_checker_.CalledOnValidThread());

  // Allocate the next id. It doubles as the receiver context so that
  // OnSecurityKeyRequest() can map an incoming call back to its entry.
  ++last_connection_id_;
  const int connection_id = last_connection_id_;

  // Creates (value-initialized) or finds the entry for this id.
  ActiveConnection& entry =
      active_connections_.try_emplace(connection_id).first->second;
  entry.receiver_id =
      receiver_set_.Add(this, std::move(receiver), connection_id);

  // A client that never issues a request is dropped once
  // kInitialRequestTimeout elapses.
  entry.disconnect_timer.Start(FROM_HERE, kInitialRequestTimeout,
                               GetCloseConnectionClosure(connection_id));
}

void SecurityKeyAuthHandlerWin::CreateSecurityKeyConnection() {
  // Intentionally empty: the caller owns the IPC server here and hands each
  // new connection over through BindSecurityKeyForwarder().
}

bool SecurityKeyAuthHandlerWin::IsValidConnectionId(int connection_id) const {
  DCHECK(thread_checker_.CalledOnValidThread());
  // An id is valid from BindSecurityKeyForwarder() until the connection is
  // closed or its IPC peer disconnects, i.e. while it has a map entry.
  return active_connections_.count(connection_id) != 0;
}

void SecurityKeyAuthHandlerWin::SendClientResponse(
    int connection_id,
    const std::string& response_data) {
  DCHECK(thread_checker_.CalledOnValidThread());

  auto iter = active_connections_.find(connection_id);
  if (iter == active_connections_.end()) {
    HOST_LOG << "Invalid security key connection ID received: "
             << connection_id;
    return;
  }
  ActiveConnection& connection = iter->second;
  std::move(connection.on_security_key_request_callback).Run(response_data);
  // Reset the timer to give the client a chance to send another request.
  connection.disconnect_timer.Start(FROM_HERE, kSecurityKeyRequestTimeout,
                                    GetCloseConnectionClosure(connection_id));
}

void SecurityKeyAuthHandlerWin::SendErrorAndCloseConnection(int connection_id) {
  DCHECK(thread_checker_.CalledOnValidThread());

  // Step 1: reply to the IPC client with the error payload.
  // Step 2: unbind the receiver and drop the connection's bookkeeping.
  SendClientResponse(connection_id, kSecurityKeyConnectionError);
  CloseSecurityKeyRequestConnection(connection_id);
}

void SecurityKeyAuthHandlerWin::SetSendMessageCallback(
    const SendMessageCallback& callback) {
  DCHECK(thread_checker_.CalledOnValidThread());
  // Keep a copy. OnSecurityKeyRequest() runs it with the connection id and
  // request payload to forward each request to the client.
  send_message_callback_ = callback;
}

size_t SecurityKeyAuthHandlerWin::GetActiveConnectionCountForTest() const {
  // Test-only: every tracked connection owns exactly one map entry.
  const size_t tracked_connections = active_connections_.size();
  return tracked_connections;
}

void SecurityKeyAuthHandlerWin::SetRequestTimeoutForTest(
    base::TimeDelta /* timeout */) {
  // This implementation always uses the fixed kInitialRequestTimeout and
  // kSecurityKeyRequestTimeout values, and its tests never override them.
  // Reaching this therefore indicates a test bug.
  NOTREACHED_IN_MIGRATION();
}

void SecurityKeyAuthHandlerWin::OnSecurityKeyRequest(
    const std::string& request_data,
    OnSecurityKeyRequestCallback callback) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(send_message_callback_);

  int connection_id = receiver_set_.current_context();
  auto iter = active_connections_.find(connection_id);
  DCHECK(iter != active_connections_.end());
  ActiveConnection& connection = iter->second;
  if (connection.on_security_key_request_callback) {
    LOG(ERROR) << "Received security key request while waiting for a response";
    CloseSecurityKeyRequestConnection(connection_id);
    return;
  }
  // Reset the timer to give the client a chance to send the response.
  connection.disconnect_timer.Start(FROM_HERE, kSecurityKeyRequestTimeout,
                                    GetCloseConnectionClosure(connection_id));
  connection.on_security_key_request_callback = std::move(callback);
  send_message_callback_.Run(connection_id, request_data);
}

void SecurityKeyAuthHandlerWin::OnIpcPeerDisconnected() {
  DCHECK(thread_checker_.CalledOnValidThread());
  // Inside the disconnect handler, current_context() names the connection
  // whose peer went away; forget its bookkeeping.
  const int disconnected_id = receiver_set_.current_context();
  active_connections_.erase(disconnected_id);
}

void SecurityKeyAuthHandlerWin::CloseSecurityKeyRequestConnection(
    int connection_id) {
  DCHECK(thread_checker_.CalledOnValidThread());
  const auto found = active_connections_.find(connection_id);
  const bool known = found != active_connections_.end();
  if (!known) {
    LOG(ERROR) << "Connection ID " << connection_id << " doesn't exist.";
    return;
  }
  // Unbind the mojo receiver first. Then drop the entry, which also destroys
  // its timer and any unanswered reply callback.
  receiver_set_.Remove(found->second.receiver_id);
  active_connections_.erase(found);
}

base::OnceClosure SecurityKeyAuthHandlerWin::GetCloseConnectionClosure(
    int connection_id) {
  // Bound through a weak pointer so a queued timer task does nothing once
  // |weak_factory_| has been invalidated.
  auto weak_self = weak_factory_.GetWeakPtr();
  return base::BindOnce(
      &SecurityKeyAuthHandlerWin::CloseSecurityKeyRequestConnection,
      std::move(weak_self), connection_id);
}

}  // namespace remoting