chromium/chromecast/cast_core/runtime/browser/message_port_handler.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromecast/cast_core/runtime/browser/message_port_handler.h"

#include <string_view>
#include <utility>

#include "base/logging.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/time/time.h"
#include "components/cast/message_port/platform_message_port.h"
#include "components/cast_receiver/browser/public/message_port_service.h"

namespace chromecast {
namespace {

// This is used as a timeout for both sending cast::web::Message requests and
// awaiting responses.  Reaching this timeout without a response from the peer
// will close the connection and Blink message port.
constexpr base::TimeDelta kMessageTimeout = base::Seconds(10);

}  // namespace

// Debug logging helpers that prefix every line with "channel <id>: ".
// They expand to an unqualified `channel_id_`, so they are only usable where
// that member is in scope (MessagePortHandler member functions).
// NOTE: DVLOG_CHANNEL is not currently used in this file.
#define DLOG_CHANNEL(level) DLOG(level) << "channel " << channel_id_ << ": "
#define DVLOG_CHANNEL(level) DVLOG(level) << "channel " << channel_id_ << ": "

// Wraps |message_port| and bridges it to the peer over |core_app_stub| under
// |channel_id|. |message_port_service| and |core_app_stub| are non-owning and
// must outlive this handler; completions and timeouts run on |task_runner|.
MessagePortHandler::MessagePortHandler(
    std::unique_ptr<cast_api_bindings::MessagePort> message_port,
    uint32_t channel_id,
    cast_receiver::MessagePortService* message_port_service,
    cast::v2::CoreMessagePortApplicationServiceStub* core_app_stub,
    scoped_refptr<base::SequencedTaskRunner> task_runner)
    : task_runner_(std::move(task_runner)),
      message_port_service_(message_port_service),
      core_app_stub_(core_app_stub),
      message_port_(std::move(message_port)),
      channel_id_(channel_id) {
  // Every dependency is later dereferenced without a null check:
  // |message_port_| immediately below, |task_runner_| whenever a message is
  // forwarded or a timeout is armed. Validate all of them up front, not just
  // the two raw pointers.
  DCHECK(task_runner_);
  DCHECK(message_port_service_);
  DCHECK(core_app_stub_);
  DCHECK(message_port_);
  // Make this handler the port's receiver (see OnMessage()/OnPipeError()).
  message_port_->SetReceiver(this);
}

// Defaulted: members (including |message_port_|) are released by their own
// destructors. Note the port is not explicitly Close()d here.
MessagePortHandler::~MessagePortHandler() = default;

// Dispatches one cast::web::Message received from the peer for this channel.
// Returns kFailedPrecondition once the local port has been closed, and
// kInvalidArgument for a message with no recognized payload.
cast_receiver::Status MessagePortHandler::HandleMessage(
    const cast::web::Message& message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // A handler whose port is already gone has nowhere to deliver anything.
  if (!message_port_) {
    return cast_receiver::Status(cast_receiver::StatusCode::kFailedPrecondition,
                                 "Invalid MessagePort");
  }

  switch (message.message_type_case()) {
    case cast::web::Message::kStatus: {
      const auto port_status = message.status().status();
      if (port_status == cast::web::MessagePortStatus_Status_ERROR) {
        // The peer reported a failure on this channel: tear it down.
        DLOG_CHANNEL(WARNING) << "Received error message";
        CloseAndRemove();
      } else if (port_status == cast::web::MessagePortStatus_Status_STARTED) {
        // Only the first STARTED releases requests that were queued while the
        // peer was not yet ready.
        const bool first_start = !started_;
        started_ = true;
        if (first_start && !pending_messages_.empty()) {
          ForwardNextMessage();
        }
      }
      return cast_receiver::OkStatus();
    }
    case cast::web::Message::kRequest: {
      const auto& request = message.request();
      DLOG_CHANNEL(INFO) << "Received request: " << request.data();

      // Each transferred port becomes a local pair: the client end is
      // registered with the service under the peer-chosen channel id, and the
      // server end is handed to |message_port_| together with the payload.
      std::vector<std::unique_ptr<cast_api_bindings::MessagePort>> transferred;
      transferred.reserve(request.ports_size());
      for (const auto& descriptor : request.ports()) {
        const auto port_channel_id = descriptor.channel().channel_id();
        std::unique_ptr<cast_api_bindings::MessagePort> client;
        std::unique_ptr<cast_api_bindings::MessagePort> server;
        cast_api_bindings::CreatePlatformMessagePortPair(&client, &server);
        message_port_service_->RegisterIncomingPort(port_channel_id,
                                                    std::move(client));
        transferred.push_back(std::move(server));

        // Tell the peer the newly created channel is live.
        cast::web::Message started;
        started.mutable_channel()->set_channel_id(port_channel_id);
        started.mutable_status()->set_status(
            cast::web::MessagePortStatus_Status_STARTED);
        ForwardMessage(std::move(started));
      }

      const bool posted = message_port_->PostMessageWithTransferables(
          request.data(), std::move(transferred));
      SendResponse(posted);
      return cast_receiver::OkStatus();
    }
    case cast::web::Message::kResponse: {
      // A response with no request in flight means both sides disagree about
      // the protocol state; this is treated as fatal.
      if (!is_awaiting_response_) {
        LOG(FATAL) << "Received response while not expecting one.";
      }
      message_timeout_callback_.Cancel();
      is_awaiting_response_ = false;
      // The next queued request may go out now only if the previous request's
      // send has also completed; otherwise it is drained when that send
      // completes.
      if (!has_outstanding_request_ && !pending_messages_.empty()) {
        ForwardNextMessage();
      }
      return cast_receiver::OkStatus();
    }
    default:
      return cast_receiver::Status(cast_receiver::StatusCode::kInvalidArgument,
                                   "Invalid cast::web::Message");
  }
}

// Closes the local message port and discards any state tied to it: the armed
// timeout and every queued-but-unsent message. Safe to call more than once.
void MessagePortHandler::Close() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DLOG_CHANNEL(INFO) << "Closing channel";
  message_timeout_callback_.Cancel();
  pending_messages_.clear();
  // HandleMessage() and OnPortMessagePosted() both expect this handler to
  // stay alive with a null |message_port_|. Guard against a repeated Close()
  // so it is not dereferenced after release.
  if (message_port_) {
    message_port_->Close();
    message_port_.reset();
  }
}

// Closes the port, then asks the service to forget this channel.
// NOTE(review): Remove() presumably destroys this handler, so nothing may
// touch members after it returns; confirm against MessagePortService.
void MessagePortHandler::CloseAndRemove() {
  const auto channel_to_remove = channel_id_;
  Close();
  message_port_service_->Remove(channel_to_remove);
}

// Closes the channel because of |error|, tells the peer with an ERROR status
// (fire-and-forget; the result is ignored), and then removes the channel from
// the service.
void MessagePortHandler::CloseWithError(CloseError error) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  switch (error) {
    case CloseError::kPipeError:
      DLOG_CHANNEL(INFO) << "Closing with pipe error";
      break;
    case CloseError::kTimeout:
      DLOG_CHANNEL(INFO) << "Closing from timeout";
      break;
  }
  Close();

  // The stub is called directly, not through ForwardMessage(), because
  // Close() has just discarded the local queue.
  auto error_call = core_app_stub_->CreateCall<
      cast::v2::CoreMessagePortApplicationServiceStub::PostMessage>();
  error_call.request().mutable_channel()->set_channel_id(channel_id_);
  error_call.request().mutable_status()->set_status(
      cast::web::MessagePortStatus_Status_ERROR);
  std::move(error_call).InvokeAsync(base::DoNothing());

  message_port_service_->Remove(channel_id_);
}

void MessagePortHandler::SendResponse(bool result) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  cast::web::Message message;
  message.mutable_response()->set_result(result);
  message.mutable_channel()->set_channel_id(channel_id_);
  ForwardMessage(std::move(message));
}

// Pops the oldest queued message and sends it now. Callers guarantee the
// queue is non-empty and that no request is in flight or awaiting a response.
void MessagePortHandler::ForwardNextMessage() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!is_awaiting_response_);
  DCHECK(!pending_messages_.empty());
  DCHECK(!has_outstanding_request_);
  auto head = std::move(pending_messages_.front());
  pending_messages_.pop_front();
  ForwardMessageNow(std::move(head));
}

// Sends |message| to the peer, or queues it if it is a request that cannot go
// out yet. Requests are strictly serialized: a request is sent only when the
// peer has STARTED, nothing is queued ahead of it, and no earlier request is
// in flight or awaiting its response. Non-request messages are never queued.
// Always returns true.
bool MessagePortHandler::ForwardMessage(cast::web::Message message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const bool request_path_clear = started_ && !is_awaiting_response_ &&
                                  pending_messages_.empty() &&
                                  !has_outstanding_request_;
  if (!message.has_request() || request_path_clear) {
    ForwardMessageNow(std::move(message));
  } else {
    pending_messages_.emplace_back(std::move(message));
  }
  return true;
}

// Issues an async PostMessage call for |message| and (re)arms the timeout.
// Sending a request also marks it as in flight and awaiting a response.
void MessagePortHandler::ForwardMessageNow(cast::web::Message message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!message.has_request() || !has_outstanding_request_);
  const bool is_request = message.has_request();

  auto post_call = core_app_stub_->CreateCall<
      cast::v2::CoreMessagePortApplicationServiceStub::PostMessage>(
      std::move(message));
  // The completion is bounced back onto |task_runner_| and bound through a
  // WeakPtr, so it is not run if this handler has been destroyed by then.
  auto on_posted = base::BindPostTask(
      task_runner_, base::BindOnce(&MessagePortHandler::OnPortMessagePosted,
                                   weak_factory_.GetWeakPtr(), is_request));
  std::move(post_call).InvokeAsync(std::move(on_posted));

  if (is_request) {
    has_outstanding_request_ = true;
    is_awaiting_response_ = true;
  }
  ResetTimeout();
}

// (Re)arms the single channel timeout: if it fires, the channel is closed via
// CloseWithError(kTimeout). Resetting |message_timeout_callback_| replaces the
// previously wrapped callback (assuming it is a base::CancelableOnceClosure,
// which also cancels the old one), so only the newest deadline can fire.
void MessagePortHandler::ResetTimeout() {
  auto on_timeout =
      base::BindOnce(&MessagePortHandler::CloseWithError,
                     weak_factory_.GetWeakPtr(), CloseError::kTimeout);
  message_timeout_callback_.Reset(std::move(on_timeout));
  task_runner_->PostDelayedTask(FROM_HERE,
                                message_timeout_callback_.callback(),
                                kMessageTimeout);
}

void MessagePortHandler::OnPortMessagePosted(
    bool was_request,
    cast::utils::GrpcStatusOr<cast::web::MessagePortStatus> response_or) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  has_outstanding_request_ = false;
  message_timeout_callback_.Cancel();
  if (!message_port_) {
    return;
  }

  if (!response_or.ok() ||
      response_or->status() != cast::web::MessagePortStatus_Status_OK) {
    DLOG_CHANNEL(WARNING) << "Send failed (" << response_or.ToString() << ", "
                          << cast::web::MessagePortStatus_Status_Name(
                                 response_or->status())
                          << ")";
    CloseAndRemove();
    return;
  }

  if (was_request && is_awaiting_response_) {
    ResetTimeout();
  } else if (!is_awaiting_response_ && !pending_messages_.empty()) {
    ForwardNextMessage();
  }
}

bool MessagePortHandler::OnMessage(
    std::string_view message,
    std::vector<std::unique_ptr<cast_api_bindings::MessagePort>> ports) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  cast::web::Message request;
  request.mutable_request()->set_data(std::string(message));
  request.mutable_channel()->set_channel_id(channel_id_);
  std::vector<cast::web::Message> started_notifications;
  started_notifications.reserve(ports.size());
  for (auto& port : ports) {
    auto* descriptor = request.mutable_request()->mutable_ports()->Add();
    uint32_t channel_id =
        message_port_service_->RegisterOutgoingPort(std::move(port));
    descriptor->mutable_channel()->set_channel_id(channel_id);
    descriptor->set_sequence_number(0);

    cast::web::Message notification;
    notification.mutable_channel()->set_channel_id(channel_id);
    notification.mutable_status()->set_status(
        cast::web::MessagePortStatus_Status_STARTED);
    started_notifications.push_back(std::move(notification));
  }
  ForwardMessage(std::move(request));
  for (auto& notification : started_notifications) {
    ForwardMessage(std::move(notification));
  }
  return true;
}

void MessagePortHandler::OnPipeError() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CloseWithError(CloseError::kPipeError);
}

}  // namespace chromecast