chromium/components/named_mojo_ipc_server/named_mojo_server_endpoint_connector_win.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/named_mojo_ipc_server/named_mojo_server_endpoint_connector.h"

#include <windows.h>

#include <string.h>

#include <memory>
#include <utility>

#include "base/check.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/process/process_handle.h"
#include "base/sequence_checker.h"
#include "base/synchronization/waitable_event.h"
#include "base/synchronization/waitable_event_watcher.h"
#include "base/task/current_thread.h"
#include "base/task/sequenced_task_runner.h"
#include "base/thread_annotations.h"
#include "base/threading/sequence_bound.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "base/win/scoped_handle.h"
#include "base/win/windows_types.h"
#include "components/named_mojo_ipc_server/connection_info.h"
#include "components/named_mojo_ipc_server/endpoint_options.h"
#include "mojo/public/cpp/platform/named_platform_channel.h"
#include "mojo/public/cpp/platform/platform_channel_endpoint.h"
#include "mojo/public/cpp/platform/platform_handle.h"

namespace named_mojo_ipc_server {
namespace {

// Delay between a failed connection attempt (see OnError()) and the next call
// to Connect().
constexpr base::TimeDelta kRetryConnectionTimeout = base::Seconds(3);

// Windows implementation of NamedMojoServerEndpointConnector.
//
// Repeatedly creates a server endpoint for the named pipe described by the
// EndpointOptions, waits asynchronously (via overlapped ConnectNamedPipe()) for
// a client to connect, hands the connected endpoint to the delegate, and then
// starts listening again. Failures are retried after kRetryConnectionTimeout.
//
// All state is accessed on the sequence guarded by |sequence_checker_|.
class NamedMojoServerEndpointConnectorWin final
    : public NamedMojoServerEndpointConnector {
 public:
  explicit NamedMojoServerEndpointConnectorWin(
      const EndpointOptions& options,
      base::SequenceBound<Delegate> delegate);
  NamedMojoServerEndpointConnectorWin(
      const NamedMojoServerEndpointConnectorWin&) = delete;
  NamedMojoServerEndpointConnectorWin& operator=(
      const NamedMojoServerEndpointConnectorWin&) = delete;
  ~NamedMojoServerEndpointConnectorWin() override;

 private:
  // Callback run by |client_connection_watcher_| when
  // |client_connected_event_| becomes signaled. |event| must be
  // |client_connected_event_|.
  void OnConnectedEventSignaled(base::WaitableEvent* event);

  // Creates a new server endpoint and starts waiting for a client to connect
  // to it. Must not be called while a connection is already pending.
  void Connect();

  // Called once a client is connected to |pending_named_pipe_handle_|. Passes
  // the connected endpoint to the delegate and calls Connect() again to accept
  // the next client.
  void OnReady();

  // Drops any pending connection state and schedules another Connect() after
  // kRetryConnectionTimeout.
  void OnError();

  // Stops watching for connections, resets |client_connected_event_| and
  // closes |pending_named_pipe_handle_|.
  void ResetConnectionObjects();

  // Overrides for NamedMojoServerEndpointConnector.
  bool TryStart() override;

  // Watches |client_connected_event_| and runs OnConnectedEventSignaled().
  base::WaitableEventWatcher client_connection_watcher_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Non-null when there is a pending connection.
  base::win::ScopedHandle pending_named_pipe_handle_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Signaled by ConnectNamedPipe() once |pending_named_pipe_handle_| is
  // connected to a client.
  base::WaitableEvent client_connected_event_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Object to allow ConnectNamedPipe() to run asynchronously.
  OVERLAPPED connect_overlapped_ GUARDED_BY_CONTEXT(sequence_checker_);

  // Fires Connect() again after a failed attempt; started by OnError().
  base::OneShotTimer retry_connect_timer_ GUARDED_BY_CONTEXT(sequence_checker_);
};

// Forwards |options| and |delegate| to the base class and creates
// |client_connected_event_| as a manual-reset event that starts out
// unsignaled. Does not start listening; that happens in TryStart().
NamedMojoServerEndpointConnectorWin::NamedMojoServerEndpointConnectorWin(
    const EndpointOptions& options,
    base::SequenceBound<Delegate> delegate)
    : NamedMojoServerEndpointConnector(options, std::move(delegate)),
      client_connected_event_(base::WaitableEvent::ResetPolicy::MANUAL,
                              base::WaitableEvent::InitialState::NOT_SIGNALED) {
  // The connector reports every endpoint and client through |delegate_|, so a
  // null delegate is a programming error.
  DCHECK(delegate_);
}

// Must run on the sequence guarded by |sequence_checker_|. Members clean up
// after themselves; nothing is torn down explicitly here.
//
// NOTE(review): if an overlapped ConnectNamedPipe() is still pending at this
// point, |connect_overlapped_| goes away together with this object. Presumably
// closing |pending_named_pipe_handle_| cancels the request in time -- confirm.
NamedMojoServerEndpointConnectorWin::~NamedMojoServerEndpointConnectorWin() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

void NamedMojoServerEndpointConnectorWin::Connect() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!pending_named_pipe_handle_.IsValid());

  mojo::NamedPlatformChannel::Options options;
  options.server_name = options_.server_name;
  options.security_descriptor = options_.security_descriptor;
  // Must be set to false to allow multiple clients to connect.
  options.enforce_uniqueness = false;
  mojo::PlatformChannelServerEndpoint server_endpoint =
      mojo::NamedPlatformChannel(options).TakeServerEndpoint();
  if (!server_endpoint.is_valid()) {
    OnError();
    return;
  }

  delegate_.AsyncCall(&Delegate::OnServerEndpointCreated);

  pending_named_pipe_handle_ =
      server_endpoint.TakePlatformHandle().TakeHandle();
  // The |lpOverlapped| argument of ConnectNamedPipe() has the annotation of
  // [in, out, optional], so we reset the content before passing it in, just to
  // be safe.
  memset(&connect_overlapped_, 0, sizeof(connect_overlapped_));
  connect_overlapped_.hEvent = client_connected_event_.handle();
  BOOL ok =
      ConnectNamedPipe(pending_named_pipe_handle_.Get(), &connect_overlapped_);
  if (ok) {
    PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
    OnError();
    return;
  }

  const DWORD err = GetLastError();
  switch (err) {
    case ERROR_PIPE_CONNECTED:
      // A client has connected before the server calls ConnectNamedPipe().
      OnReady();
      return;
    case ERROR_IO_PENDING:
      client_connection_watcher_.StartWatching(
          &client_connected_event_,
          base::BindOnce(
              &NamedMojoServerEndpointConnectorWin::OnConnectedEventSignaled,
              base::Unretained(this)),
          base::SequencedTaskRunner::GetCurrentDefault());
      return;
    default:
      PLOG(ERROR) << "Unexpected error: " << err;
      OnError();
      return;
  }
}

// Callback registered with |client_connection_watcher_| in Connect(). Runs
// when |client_connected_event_| is signaled and hands over to OnReady().
// |event| is expected to be |client_connected_event_|.
void NamedMojoServerEndpointConnectorWin::OnConnectedEventSignaled(
    base::WaitableEvent* event) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK_EQ(&client_connected_event_, event);

  OnReady();
}

void NamedMojoServerEndpointConnectorWin::OnReady() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  auto info = std::make_unique<ConnectionInfo>();
  if (!GetNamedPipeClientProcessId(pending_named_pipe_handle_.Get(),
                                   &info->pid)) {
    PLOG(ERROR) << "Failed to get peer PID";
    OnError();
    return;
  }
  mojo::PlatformChannelEndpoint endpoint(
      mojo::PlatformHandle(std::move(pending_named_pipe_handle_)));
  if (!endpoint.is_valid()) {
    LOG(ERROR) << "Endpoint is invalid.";
    OnError();
    return;
  }
  ResetConnectionObjects();
  delegate_.AsyncCall(&Delegate::OnClientConnected)
      .WithArgs(std::move(endpoint), std::move(info));
  Connect();
}

void NamedMojoServerEndpointConnectorWin::OnError() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  ResetConnectionObjects();
  retry_connect_timer_.Start(FROM_HERE, kRetryConnectionTimeout, this,
                             &NamedMojoServerEndpointConnectorWin::Connect);
}

// Clears all per-connection state so that a later Connect() starts from
// scratch. Called from both OnReady() and OnError().
void NamedMojoServerEndpointConnectorWin::ResetConnectionObjects() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Stop the watcher before touching the event it observes.
  client_connection_watcher_.StopWatching();
  // The event is manual-reset, so it has to be cleared explicitly before it is
  // reused for the next ConnectNamedPipe() call.
  client_connected_event_.Reset();
  pending_named_pipe_handle_.Close();
}

// Starts listening for client connections. Always returns true: a failed
// attempt inside Connect() is not reported here but retried via OnError().
bool NamedMojoServerEndpointConnectorWin::TryStart() {
  Connect();
  return true;
}

}  // namespace

// Creates the Windows connector bound to |io_sequence|. |options| and
// |delegate| are forwarded to the NamedMojoServerEndpointConnectorWin
// constructor.
// static
base::SequenceBound<NamedMojoServerEndpointConnector>
NamedMojoServerEndpointConnector::Create(
    scoped_refptr<base::SequencedTaskRunner> io_sequence,
    const EndpointOptions& options,
    base::SequenceBound<Delegate> delegate) {
  // |io_sequence| is a by-value parameter that is not used again, so move it
  // instead of copying to avoid an extra ref-count increment and decrement.
  return base::SequenceBound<NamedMojoServerEndpointConnectorWin>(
      std::move(io_sequence), options, std::move(delegate));
}

}  // namespace named_mojo_ipc_server