chromium/remoting/host/security_key/security_key_message_handler.cc

// Copyright 2016 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "remoting/host/security_key/security_key_message_handler.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "remoting/base/logging.h"
#include "remoting/host/security_key/security_key_ipc_client.h"
#include "remoting/host/security_key/security_key_ipc_constants.h"
#include "remoting/host/security_key/security_key_message_reader_impl.h"
#include "remoting/host/security_key/security_key_message_writer_impl.h"

namespace remoting {

SecurityKeyMessageHandler::SecurityKeyMessageHandler() = default;

SecurityKeyMessageHandler::~SecurityKeyMessageHandler() {
  DCHECK(thread_checker_.CalledOnValidThread());
}

void SecurityKeyMessageHandler::Start(
    base::File message_read_stream,
    base::File message_write_stream,
    std::unique_ptr<SecurityKeyIpcClient> ipc_client,
    base::OnceClosure error_callback) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(message_read_stream.IsValid());
  DCHECK(message_write_stream.IsValid());
  DCHECK(ipc_client);
  DCHECK(!error_callback.is_null());
  DCHECK(error_callback_.is_null());

  if (!reader_) {
    reader_ = std::make_unique<SecurityKeyMessageReaderImpl>(
        std::move(message_read_stream));
  }

  if (!writer_) {
    writer_ = std::make_unique<SecurityKeyMessageWriterImpl>(
        std::move(message_write_stream));
  }

  ipc_client_ = std::move(ipc_client);
  error_callback_ = std::move(error_callback);

  reader_->Start(
      base::BindRepeating(&SecurityKeyMessageHandler::ProcessSecurityKeyMessage,
                          base::Unretained(this)),
      base::BindOnce(&SecurityKeyMessageHandler::OnError,
                     base::Unretained(this)));
}

void SecurityKeyMessageHandler::SetSecurityKeyMessageReaderForTest(
    std::unique_ptr<SecurityKeyMessageReader> reader) {
  DCHECK(!reader_);
  reader_ = std::move(reader);
}

void SecurityKeyMessageHandler::SetSecurityKeyMessageWriterForTest(
    std::unique_ptr<SecurityKeyMessageWriter> writer) {
  DCHECK(!writer_);
  writer_ = std::move(writer);
}

void SecurityKeyMessageHandler::ProcessSecurityKeyMessage(
    std::unique_ptr<SecurityKeyMessage> message) {
  DCHECK(thread_checker_.CalledOnValidThread());

  SecurityKeyMessageType message_type = message->type();
  HOST_LOG << "Received message from pipe. type="
           << static_cast<int>(message_type);
  if (message_type == SecurityKeyMessageType::CONNECT) {
    HandleConnectRequest(message->payload());
  } else if (message_type == SecurityKeyMessageType::REQUEST) {
    HandleSecurityKeyRequest(message->payload());
  } else {
    // uint8_t is handled as char, so we use uint16_t to show number here.
    LOG(ERROR) << "Unknown message type: "
               << static_cast<uint16_t>(message_type);
    SendMessage(SecurityKeyMessageType::UNKNOWN_COMMAND);
  }
}

void SecurityKeyMessageHandler::HandleIpcConnectionChange() {
  DCHECK(thread_checker_.CalledOnValidThread());
  SendMessageWithPayload(SecurityKeyMessageType::CONNECT_RESPONSE,
                         std::string(1, kConnectResponseActiveSession));
}

void SecurityKeyMessageHandler::HandleIpcConnectionError() {
  DCHECK(thread_checker_.CalledOnValidThread());
  SendMessageWithPayload(SecurityKeyMessageType::CONNECT_ERROR,
                         "Unknown error occurred during connection.");
}

void SecurityKeyMessageHandler::HandleSecurityKeyResponse(
    const std::string& response_data) {
  if (response_data.compare(kSecurityKeyConnectionError) == 0) {
    SendMessageWithPayload(SecurityKeyMessageType::REQUEST_ERROR,
                           "An error occurred during the request.");
    return;
  }

  if (response_data.empty()) {
    SendMessageWithPayload(SecurityKeyMessageType::REQUEST_ERROR,
                           "Invalid client response received.");
    return;
  }

  SendMessageWithPayload(SecurityKeyMessageType::REQUEST_RESPONSE,
                         response_data);
}

void SecurityKeyMessageHandler::HandleConnectRequest(
    const std::string& message_payload) {
  DCHECK(thread_checker_.CalledOnValidThread());
  if (!message_payload.empty()) {
    SendMessageWithPayload(SecurityKeyMessageType::CONNECT_ERROR,
                           "Unexpected payload data received.");
    return;
  }

  if (ipc_client_->CheckForSecurityKeyIpcServerChannel()) {
    // If we find an IPC server, then attempt to establish a connection.
    ipc_client_->EstablishIpcConnection(
        base::BindOnce(&SecurityKeyMessageHandler::HandleIpcConnectionChange,
                       base::Unretained(this)),
        base::BindOnce(&SecurityKeyMessageHandler::HandleIpcConnectionError,
                       base::Unretained(this)));
  } else {
    SendMessageWithPayload(SecurityKeyMessageType::CONNECT_RESPONSE,
                           std::string(1, kConnectResponseNoSession));
  }
}

void SecurityKeyMessageHandler::HandleSecurityKeyRequest(
    const std::string& message_payload) {
  DCHECK(thread_checker_.CalledOnValidThread());
  if (message_payload.empty()) {
    SendMessageWithPayload(SecurityKeyMessageType::REQUEST_ERROR,
                           "Request sent without request data.");
    return;
  }

  if (!ipc_client_->SendSecurityKeyRequest(
          message_payload,
          base::BindRepeating(
              &SecurityKeyMessageHandler::HandleSecurityKeyResponse,
              base::Unretained(this)))) {
    SendMessageWithPayload(SecurityKeyMessageType::REQUEST_ERROR,
                           "Failed to send request data.");
  }
}

void SecurityKeyMessageHandler::SendMessage(
    SecurityKeyMessageType message_type) {
  if (!writer_->WriteMessage(message_type)) {
    OnError();
  }
}

void SecurityKeyMessageHandler::SendMessageWithPayload(
    SecurityKeyMessageType message_type,
    const std::string& message_payload) {
  if (!writer_->WriteMessageWithPayload(message_type, message_payload)) {
    HOST_LOG << "Failed to send message to pipe. type="
             << static_cast<int>(message_type);
    OnError();
  } else {
    HOST_LOG << "Successfully sent message to pipe. type="
             << static_cast<int>(message_type);
  }
}

void SecurityKeyMessageHandler::OnError() {
  DCHECK(thread_checker_.CalledOnValidThread());
  ipc_client_.reset();
  writer_.reset();
  reader_.reset();

  if (!error_callback_.is_null()) {
    std::move(error_callback_).Run();
  }
}

}  // namespace remoting