chromium/chrome/browser/ash/smb_client/discovery/mdns_host_locator.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/smb_client/discovery/mdns_host_locator.h"

#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/strings/string_util.h"
#include "base/task/single_thread_task_runner.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#include "net/base/net_errors.h"
#include "net/dns/mdns_client.h"
#include "net/dns/public/dns_protocol.h"
#include "net/dns/record_rdata.h"

namespace ash::smb_client {

namespace {

using net::MDnsTransaction;

// Name queried by the PTR transaction to enumerate SMB services.
constexpr char kSmbMDnsServiceName[] = "_smb._tcp.local";

// Suffix that RemoveLocal() strips from hostnames.
constexpr char kMdnsLocalString[] = ".local";

// PTR lookups go to the network only and are not limited to a single result.
constexpr int32_t kPtrTransactionFlags = MDnsTransaction::QUERY_NETWORK;

// SRV lookups consult both the cache and the network and stop after the
// first result.
constexpr int32_t kSrvTransactionFlags = MDnsTransaction::QUERY_CACHE |
                                         MDnsTransaction::QUERY_NETWORK |
                                         MDnsTransaction::SINGLE_RESULT;

// A lookups are answered from the cache only and stop after the first result.
constexpr int32_t kATransactionFlags =
    MDnsTransaction::QUERY_CACHE | MDnsTransaction::SINGLE_RESULT;

}  // namespace

// Returns |raw_hostname| with a trailing ".local" removed. The suffix is
// matched without regard to ASCII case; names that do not carry it are
// returned unchanged.
Hostname RemoveLocal(const std::string& raw_hostname) {
  const bool has_local_suffix = base::EndsWith(
      raw_hostname, kMdnsLocalString, base::CompareCase::INSENSITIVE_ASCII);
  if (has_local_suffix) {
    const size_t suffix_length = strlen(kMdnsLocalString);
    DCHECK_GE(raw_hostname.size(), suffix_length);
    return raw_hostname.substr(0, raw_hostname.size() - suffix_length);
  }

  return raw_hostname;
}

// Runs a single mDNS discovery pass: a PTR query for SMB services, then an SRV
// query for each service found, then an A query for each SRV target. The
// object is constructed on the UI thread; every member function other than the
// constructor and FindHosts() asserts that it runs on one sequence.
class MDnsHostLocator::Impl {
 public:
  explicit Impl(scoped_refptr<base::SingleThreadTaskRunner> task_runner)
      : task_runner_(std::move(task_runner)) {
    // This object is created on the UI thread, so detach the sequence checker
    // and let it re-attach the first time we are run on the IO thread.
    DETACH_FROM_SEQUENCE(sequence_checker_);
  }

  Impl(const Impl&) = delete;
  Impl& operator=(const Impl&) = delete;

  ~Impl() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); }

  // Stores |callback| and posts the start of the query to |task_runner_|.
  // Must be called at most once per Impl.
  void FindHosts(FindHostsCallback callback);

 private:
  // Start running the mDNS query on the IO thread.
  void FindHostsOnIOThread();

  // Makes the MDnsClient start listening on port 5353 on each network
  // interface. Returns true on success.
  bool StartListening();

  // Creates a PTR transaction and finds all SMB services in the network.
  // Returns true if the transaction successfully starts.
  bool CreatePtrTransaction();

  // Creates an SRV transaction, which returns the hostname of |service|.
  void CreateSrvTransaction(const std::string& service);

  // Creates an A transaction, which returns the address of |raw_hostname|.
  void CreateATransaction(const std::string& raw_hostname);

  // Handler for the PTR transaction request. Records each service reported
  // and, once the transaction is done, starts resolving them.
  void OnPtrTransactionResponse(net::MDnsTransaction::Result result,
                                const net::RecordParsed* record);

  // Handler for the SRV transaction request.
  void OnSrvTransactionResponse(net::MDnsTransaction::Result result,
                                const net::RecordParsed* record);

  // Handler for the A transaction request.
  void OnATransactionResponse(const std::string& raw_hostname,
                              net::MDnsTransaction::Result result,
                              const net::RecordParsed* record);

  // Resolves services that were found through a PTR transaction request. If
  // no services were found, this calls the FindHostsCallback immediately.
  void ResolveServicesFound();

  // Marks one transaction as finished and fires the callback if there are no
  // more transactions left.
  void FireCallbackIfFinished();

  // Fires the callback immediately. If |success| is true, return with the hosts
  // that were found.
  void FireCallback(bool success);

  // IO thread task runner.
  scoped_refptr<base::SingleThreadTaskRunner> task_runner_;

  // Number of discovered services whose lookup has not finished yet. Set once
  // the PTR transaction is done.
  uint32_t remaining_transactions_ = 0;
  // Service names collected from PTR records.
  std::vector<std::string> services_;
  // Hosts resolved so far, filled in from A records.
  HostMap results_;
  // Run exactly once with the outcome of the query.
  FindHostsCallback callback_;

  // Network stack mDNS client.
  std::unique_ptr<net::MDnsSocketFactory> socket_factory_;
  std::unique_ptr<net::MDnsClient> mdns_client_;
  // Owns every transaction that was started successfully.
  std::vector<std::unique_ptr<net::MDnsTransaction>> transactions_;

  SEQUENCE_CHECKER(sequence_checker_);
};

// Captures the IO thread task runner and starts without an Impl. The deleter
// given to |impl_| is built from |io_task_runner_|, so that member has to be
// initialized first (NOTE(review): relies on the declaration order in the
// header, which is not visible here).
MDnsHostLocator::MDnsHostLocator()
    : io_task_runner_(content::GetIOThreadTaskRunner({})),
      impl_(nullptr, base::OnTaskRunnerDeleter(io_task_runner_)) {}

// Only asserts that destruction happens on the sequence this object is bound
// to; member cleanup is left to the members' own destructors and deleters.
MDnsHostLocator::~MDnsHostLocator() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Starts a new discovery and remembers |callback| for its result. Any query
// that is still in flight is abandoned: its weak pointers are invalidated and
// its Impl is replaced.
void MDnsHostLocator::FindHosts(FindHostsCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // Reset any existing query.
  weak_factory_.InvalidateWeakPtrs();
  impl_.reset(new Impl(io_task_runner_));
  callback_ = std::move(callback);

  // Bound through a weak pointer so a later reset silently drops the reply.
  auto on_done = base::BindOnce(&MDnsHostLocator::OnFindHostsDone,
                                weak_factory_.GetWeakPtr());

  // Impl reports its result by running this; it re-posts |on_done| to the
  // task runner that is current right now.
  auto post_back =
      base::BindOnce(&MDnsHostLocator::PostFindHostsDone,
                     base::SingleThreadTaskRunner::GetCurrentDefault(),
                     std::move(on_done));

  impl_->FindHosts(std::move(post_back));
}

// static
// Forwards a finished query to |task_runner|: binds |success| and a copy of
// |hosts| to |callback| and posts the resulting closure.
void MDnsHostLocator::PostFindHostsDone(
    scoped_refptr<base::TaskRunner> task_runner,
    FindHostsCallback callback,
    bool success,
    const HostMap& hosts) {
  auto reply = base::BindOnce(std::move(callback), success, hosts);
  task_runner->PostTask(FROM_HERE, std::move(reply));
}

void MDnsHostLocator::OnFindHostsDone(bool success, const HostMap& hosts) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  weak_factory_.InvalidateWeakPtrs();
  impl_.reset();

  std::move(callback_).Run(success, hosts);
}

// Saves |callback| and schedules FindHostsOnIOThread() on |task_runner_|.
// Expects that no callback has been stored before.
void MDnsHostLocator::Impl::FindHosts(FindHostsCallback callback) {
  DCHECK(callback_.is_null());
  callback_ = std::move(callback);

  auto start_query =
      base::BindOnce(&Impl::FindHostsOnIOThread, base::Unretained(this));
  task_runner_->PostTask(FROM_HERE, std::move(start_query));
}

// Entry point of the query: starts the mDNS client and then the PTR
// transaction. If either step fails the callback is fired with failure.
void MDnsHostLocator::Impl::FindHostsOnIOThread() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // The PTR transaction is only attempted when listening succeeded.
  const bool started = StartListening() && CreatePtrTransaction();
  if (started) {
    return;
  }

  LOG(ERROR) << "Failed to start MDnsHostLocator";
  FireCallback(false /* success */);
}

bool MDnsHostLocator::Impl::StartListening() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  socket_factory_ = net::MDnsSocketFactory::CreateDefault();
  mdns_client_ = net::MDnsClient::CreateDefault();
  int result = mdns_client_->StartListening(socket_factory_.get());
  if (result != net::OK) {
    LOG(ERROR) << "Error starting mDNS client: " << net::ErrorToString(result);
  }
  return result == net::OK;
}

bool MDnsHostLocator::Impl::CreatePtrTransaction() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  std::unique_ptr<MDnsTransaction> transaction =
      mdns_client_->CreateTransaction(
          net::dns_protocol::kTypePTR, kSmbMDnsServiceName,
          kPtrTransactionFlags,
          base::BindRepeating(&MDnsHostLocator::Impl::OnPtrTransactionResponse,
                              base::Unretained(this)));

  if (!transaction->Start()) {
    LOG(ERROR) << "Failed to start PTR transaction";
    return false;
  }

  transactions_.push_back(std::move(transaction));
  return true;
}

void MDnsHostLocator::Impl::CreateSrvTransaction(const std::string& service) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  std::unique_ptr<MDnsTransaction> transaction =
      mdns_client_->CreateTransaction(
          net::dns_protocol::kTypeSRV, service, kSrvTransactionFlags,
          base::BindRepeating(&MDnsHostLocator::Impl::OnSrvTransactionResponse,
                              base::Unretained(this)));

  if (!transaction->Start()) {
    // If the transaction fails to start, fire the callback if there are no more
    // transactions left to be processed.
    LOG(ERROR) << "Failed to start SRV transaction";

    FireCallbackIfFinished();
    return;
  }

  transactions_.push_back(std::move(transaction));
}

void MDnsHostLocator::Impl::CreateATransaction(
    const std::string& raw_hostname) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  std::unique_ptr<MDnsTransaction> transaction =
      mdns_client_->CreateTransaction(
          net::dns_protocol::kTypeA, raw_hostname, kATransactionFlags,
          base::BindRepeating(&MDnsHostLocator::Impl::OnATransactionResponse,
                              base::Unretained(this), raw_hostname));

  if (!transaction->Start()) {
    // If the transaction fails to start, fire the callback if there are no more
    // transactions left to be processed.
    LOG(ERROR) << "Failed to start A transaction";

    FireCallbackIfFinished();
    return;
  }

  transactions_.push_back(std::move(transaction));
}

// Handles responses from the PTR transaction. Each record adds one service
// name to |services_|. When the transaction is done, the pending count is set
// to the number of services and their resolution starts. Any other result
// ends the query with failure.
void MDnsHostLocator::Impl::OnPtrTransactionResponse(
    MDnsTransaction::Result result,
    const net::RecordParsed* record) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  switch (result) {
    case MDnsTransaction::Result::RESULT_RECORD: {
      DCHECK(record);

      const net::PtrRecordRdata* ptr_data =
          record->rdata<net::PtrRecordRdata>();
      DCHECK(ptr_data);

      services_.push_back(ptr_data->ptrdomain());
      break;
    }
    case MDnsTransaction::Result::RESULT_DONE:
      // One pending lookup per discovered service.
      remaining_transactions_ = services_.size();
      ResolveServicesFound();
      break;
    default:
      LOG(ERROR) << "Error getting a PTR transaction response";
      FireCallback(false /* success */);
      break;
  }
}

// Handles the response of an SRV transaction. A record leads to an A
// transaction for the SRV target; any other result counts this lookup as
// finished.
void MDnsHostLocator::Impl::OnSrvTransactionResponse(
    MDnsTransaction::Result result,
    const net::RecordParsed* record) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (result == MDnsTransaction::Result::RESULT_RECORD) {
    DCHECK(record);
    const net::SrvRecordRdata* srv_data = record->rdata<net::SrvRecordRdata>();
    DCHECK(srv_data);

    CreateATransaction(srv_data->target());
    return;
  }

  // No hostname could be obtained for this service. Fire the callback if
  // there are no more transactions.
  FireCallbackIfFinished();
}

void MDnsHostLocator::Impl::OnATransactionResponse(
    const std::string& raw_hostname,
    MDnsTransaction::Result result,
    const net::RecordParsed* record) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (result == MDnsTransaction::Result::RESULT_RECORD) {
    DCHECK(record);

    const net::ARecordRdata* ip = record->rdata<net::ARecordRdata>();
    DCHECK(ip);

    results_[RemoveLocal(raw_hostname)] = ip->address();
  }

  // Regardless of what the result is, check to see if the callback can be fired
  // after an A transaction returns.
  FireCallbackIfFinished();
}

// Starts an SRV transaction for every service collected from the PTR
// transaction. With no services at all, the callback is fired right away with
// success.
void MDnsHostLocator::Impl::ResolveServicesFound() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (services_.empty()) {
    // Nothing to resolve.
    FireCallback(true /* success */);
    return;
  }

  for (const std::string& service_name : services_) {
    CreateSrvTransaction(service_name);
  }
}

// Counts one pending lookup as finished and fires the callback with success
// once the count reaches zero.
void MDnsHostLocator::Impl::FireCallbackIfFinished() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  DCHECK_GT(remaining_transactions_, 0u);
  --remaining_transactions_;
  if (remaining_transactions_ != 0) {
    return;
  }

  FireCallback(true /* success */);
}

// Runs the stored callback with |success| and the hosts collected so far.
void MDnsHostLocator::Impl::FireCallback(bool success) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  // A successful result must not be reported while lookups are still pending.
  DCHECK(!(success && remaining_transactions_ != 0));

  std::move(callback_).Run(success, std::move(results_));
}

}  // namespace ash::smb_client