chromium/chrome/browser/nearby_sharing/certificates/nearby_share_certificate_storage_impl.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/nearby_sharing/certificates/nearby_share_certificate_storage_impl.h"

#include <algorithm>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/base64url.h"
#include "base/json/values_util.h"
#include "base/memory/ptr_util.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/values.h"
#include "chrome/browser/nearby_sharing/certificates/common.h"
#include "chrome/browser/nearby_sharing/certificates/constants.h"
#include "chrome/browser/nearby_sharing/certificates/nearby_share_private_certificate.h"
#include "chrome/browser/nearby_sharing/common/nearby_share_prefs.h"
#include "components/cross_device/logging/logging.h"
#include "components/leveldb_proto/public/proto_database_provider.h"
#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/pref_service.h"
#include "third_party/nearby/sharing/proto/rpc_resources.pb.h"

namespace {

// Compare to leveldb_proto::Enums::InitStatus. Using a separate enum so that
// the values don't change.
// These values are persisted to logs. Entries should not be renumbered and
// numeric values should never be reused.
enum InitStatusMetric {
  kOK = 0,
  kNotInitialized = 1,
  kError = 2,
  kCorrupt = 3,
  kInvalidOperation = 4,
  kMaxValue = kInvalidOperation
};

// Records whether database initialization ultimately succeeded and, on
// success, how many Init() attempts it took.
void RecordInitializationSuccessRateMetric(bool success, size_t num_attempts) {
  base::UmaHistogramBoolean(
      "Nearby.Share.Certificates.Storage.InitializeSuccessRate", success);
  if (success) {
    // A successful initialization never exceeds the attempt limit, so the
    // exclusive maximum is one past that limit.
    base::UmaHistogramExactLinear(
        "Nearby.Share.Certificates.Storage.InitializeAttemptCount",
        num_attempts,
        kNearbyShareCertificateStorageMaxNumInitializeAttempts + 1);
  }
}

// Records the result of a single database Init() attempt, translated into the
// stable InitStatusMetric values.
void RecordInitializationAttemptResultMetric(
    leveldb_proto::Enums::InitStatus init_status) {
  InitStatusMetric metric;
  switch (init_status) {
    case leveldb_proto::Enums::InitStatus::kOK:
      metric = InitStatusMetric::kOK;
      break;
    case leveldb_proto::Enums::InitStatus::kNotInitialized:
      metric = InitStatusMetric::kNotInitialized;
      break;
    case leveldb_proto::Enums::InitStatus::kError:
      metric = InitStatusMetric::kError;
      break;
    case leveldb_proto::Enums::InitStatus::kCorrupt:
      metric = InitStatusMetric::kCorrupt;
      break;
    case leveldb_proto::Enums::InitStatus::kInvalidOperation:
      metric = InitStatusMetric::kInvalidOperation;
      break;
  }
  base::UmaHistogramEnumeration(
      "Nearby.Share.Certificates.Storage.InitializeAttemptResult", metric);
}

// Records whether a batch write of public certificates succeeded.
void RecordAddPublicCertificatesSuccessRateMetric(bool success) {
  base::UmaHistogramBoolean(
      "Nearby.Share.Certificates.Storage.AddPublicCertificatesSuccessRate",
      success);
}

// Records whether removal of expired public certificates succeeded.
void RecordRemoveExpiredPublicCertificatesSuccessMetric(bool success) {
  base::UmaHistogramBoolean(
      "Nearby.Share.Certificates.Storage."
      "RemoveExpiredPublicCertificatesSuccessRate",
      success);
}

// Path component, appended to the profile path, under which the public
// certificate database is stored.
constexpr base::FilePath::CharType kPublicCertificateDatabaseName[] =
    FILE_PATH_LITERAL("NearbySharePublicCertificateDatabase");

// Base64url-encodes |unencoded_string| (with padding). Used to turn raw
// certificate IDs into pref dictionary keys.
std::string EncodeString(const std::string& unencoded_string) {
  std::string encoded_string;
  base::Base64UrlEncode(unencoded_string,
                        base::Base64UrlEncodePolicy::INCLUDE_PADDING,
                        &encoded_string);
  return encoded_string;
}

// Inverse of EncodeString(). Returns std::nullopt if |encoded_string| is not
// valid padded base64url.
std::optional<std::string> DecodeString(const std::string& encoded_string) {
  std::string decoded_string;
  if (!base::Base64UrlDecode(encoded_string,
                             base::Base64UrlDecodePolicy::REQUIRE_PADDING,
                             &decoded_string)) {
    return std::nullopt;
  }
  return decoded_string;
}

// Orders (id, expiration) pairs by ascending expiration time.
bool SortBySecond(const std::pair<std::string, base::Time>& pair1,
                  const std::pair<std::string, base::Time>& pair2) {
  return pair1.second < pair2.second;
}

// Merges two expiration lists into one list sorted by expiration time. When an
// ID appears in both, the entry from |new_exp| wins.
NearbyShareCertificateStorageImpl::ExpirationList MergeExpirations(
    const NearbyShareCertificateStorageImpl::ExpirationList& old_exp,
    const NearbyShareCertificateStorageImpl::ExpirationList& new_exp) {
  // Seed the map with the new entries first; std::map::insert() never
  // overwrites an existing key, so old entries only fill in missing IDs.
  std::map<std::string, base::Time> merged_map(new_exp.begin(), new_exp.end());
  merged_map.insert(old_exp.begin(), old_exp.end());
  // Convert map to vector and sort by expiration time.
  NearbyShareCertificateStorageImpl::ExpirationList merged(merged_map.begin(),
                                                           merged_map.end());
  std::sort(merged.begin(), merged.end(), SortBySecond);
  return merged;
}

// Converts a proto Timestamp (seconds + nanos since the Unix epoch) to
// base::Time. Taken by const reference to avoid copying the proto message.
base::Time TimestampToTime(const nearby::sharing::proto::Timestamp& timestamp) {
  return base::Time::UnixEpoch() + base::Seconds(timestamp.seconds()) +
         base::Nanoseconds(timestamp.nanos());
}

}  // namespace

// Test-only override. When non-null, Create() delegates to it.
// static
NearbyShareCertificateStorageImpl::Factory*
    NearbyShareCertificateStorageImpl::Factory::test_factory_ = nullptr;

// Builds the storage backed by a leveldb_proto database located under
// |profile_path|, unless a test factory has been installed.
// static
std::unique_ptr<NearbyShareCertificateStorage>
NearbyShareCertificateStorageImpl::Factory::Create(
    PrefService* pref_service,
    leveldb_proto::ProtoDatabaseProvider* proto_database_provider,
    const base::FilePath& profile_path) {
  if (test_factory_) {
    return test_factory_->CreateInstance(pref_service, proto_database_provider,
                                         profile_path);
  }

  // Database work runs on its own best-effort sequence that is allowed to
  // block.
  auto public_certificate_db =
      proto_database_provider->GetDB<nearby::sharing::proto::PublicCertificate>(
          leveldb_proto::ProtoDbType::NEARBY_SHARE_PUBLIC_CERTIFICATE_DATABASE,
          profile_path.Append(kPublicCertificateDatabaseName),
          base::ThreadPool::CreateSequencedTaskRunner(
              {base::MayBlock(), base::TaskPriority::BEST_EFFORT}));
  return std::make_unique<NearbyShareCertificateStorageImpl>(
      pref_service, std::move(public_certificate_db));
}

// Installs (or, with nullptr, removes) the factory used by Create().
// static
void NearbyShareCertificateStorageImpl::Factory::SetFactoryForTesting(
    Factory* test_factory) {
  test_factory_ = test_factory;
}

NearbyShareCertificateStorageImpl::Factory::~Factory() = default;

// Loads the cached public-certificate expiration list from prefs, then starts
// asynchronous initialization of |proto_database|.
NearbyShareCertificateStorageImpl::NearbyShareCertificateStorageImpl(
    PrefService* pref_service,
    std::unique_ptr<
        leveldb_proto::ProtoDatabase<nearby::sharing::proto::PublicCertificate>>
        proto_database)
    : pref_service_(pref_service), db_(std::move(proto_database)) {
  // A failure here is not fatal: the expiration list simply starts out empty.
  // Log it rather than silently dropping the error.
  if (!FetchPublicCertificateExpirations()) {
    CD_LOG(ERROR, Feature::NS)
        << __func__
        << ": Failed to parse public certificate expirations from prefs.";
  }
  Initialize();
}

NearbyShareCertificateStorageImpl::~NearbyShareCertificateStorageImpl() =
    default;

// Starts (or retries) asynchronous initialization of the public certificate
// database. Once the attempt limit is exceeded, initialization is finished as
// failed instead of trying again.
void NearbyShareCertificateStorageImpl::Initialize() {
  // Only the uninitialized and failed states may (re)start initialization.
  if (init_status_ == InitStatus::kInitialized) {
    NOTREACHED_IN_MIGRATION();
    return;
  }

  ++num_initialize_attempts_;
  const bool attempts_exhausted =
      num_initialize_attempts_ >
      kNearbyShareCertificateStorageMaxNumInitializeAttempts;
  if (attempts_exhausted) {
    FinishInitialization(false);
    return;
  }

  CD_LOG(VERBOSE, Feature::NS)
      << __func__
      << ": Attempting to initialize public certificate database. "
         "Number of attempts: "
      << num_initialize_attempts_;
  // The start time is bound so the success duration can be measured.
  auto on_initialized =
      base::BindOnce(&NearbyShareCertificateStorageImpl::OnDatabaseInitialized,
                     weak_ptr_factory_.GetWeakPtr(), base::TimeTicks::Now());
  db_->Init(std::move(on_initialized));
}

void NearbyShareCertificateStorageImpl::DestroyAndReinitialize() {
  CD_LOG(ERROR, Feature::NS)
      << __func__
      << ": Public certificate database corrupt. Erasing and "
         "initializing new database.";
  init_status_ = InitStatus::kUninitialized;
  db_->Destroy(base::BindOnce(
      &NearbyShareCertificateStorageImpl::OnDatabaseDestroyedReinitialize,
      weak_ptr_factory_.GetWeakPtr()));
}

// Reacts to the result of a database Init() call.
// - kOK: initialization finishes successfully.
// - kError: initialization is retried.
// - kCorrupt: the database is destroyed and recreated.
// - Anything else: initialization fails.
// The per-attempt result metric is recorded after handling.
void NearbyShareCertificateStorageImpl::OnDatabaseInitialized(
    base::TimeTicks initialize_start_time,
    leveldb_proto::Enums::InitStatus status) {
  using InitResult = leveldb_proto::Enums::InitStatus;

  if (status == InitResult::kOK) {
    base::UmaHistogramLongTimes(
        "Nearby.Share.Certificates.Storage.InitializeSuccessDuration",
        base::TimeTicks::Now() - initialize_start_time);
    FinishInitialization(true);
  } else if (status == InitResult::kError) {
    // Retry; Initialize() enforces the maximum number of attempts.
    Initialize();
  } else if (status == InitResult::kCorrupt) {
    DestroyAndReinitialize();
  } else if (status == InitResult::kInvalidOperation ||
             status == InitResult::kNotInitialized) {
    FinishInitialization(false);
  }

  RecordInitializationAttemptResultMetric(status);
}

// Marks initialization as finished (successfully or not), records the overall
// success metric, and posts every callback that was deferred while the
// database was uninitialized.
void NearbyShareCertificateStorageImpl::FinishInitialization(bool success) {
  init_status_ = success ? InitStatus::kInitialized : InitStatus::kFailed;
  if (success) {
    CD_LOG(VERBOSE, Feature::NS)
        << __func__
        << ": Public certificate database initialization succeeded.";
  } else {
    CD_LOG(ERROR, Feature::NS)
        << __func__ << ": Public certificate database initialization failed.";
  }
  RecordInitializationSuccessRateMetric(success, num_initialize_attempts_);

  // Deferred callbacks run even if initialization failed, so that clients
  // waiting on them are not blocked. On failure, each deferred call sees the
  // failed state and reports failure immediately.
  while (!deferred_callbacks_.empty()) {
    base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, std::move(deferred_callbacks_.front()));
    deferred_callbacks_.pop();
  }
}

// Completion handler for the Destroy() issued by DestroyAndReinitialize().
// On success, clears the cached expirations and retries initialization;
// otherwise initialization is finished as failed.
void NearbyShareCertificateStorageImpl::OnDatabaseDestroyedReinitialize(
    bool success) {
  if (success) {
    // The database was wiped, so the cached expirations no longer correspond
    // to any stored certificate.
    public_certificate_expirations_.clear();
    SavePublicCertificateExpirations();
    Initialize();
    return;
  }

  CD_LOG(ERROR, Feature::NS)
      << __func__ << ": Failed to destroy public certificate database.";
  FinishInitialization(false);
}

// Completion handler for AddPublicCertificates(). On success, merges the
// written certificates' expirations into the cached list and persists it.
// |callback| receives |proceed| in either case.
void NearbyShareCertificateStorageImpl::AddPublicCertificatesCallback(
    std::unique_ptr<ExpirationList> new_expirations,
    ResultCallback callback,
    bool proceed) {
  RecordAddPublicCertificatesSuccessRateMetric(proceed);

  if (proceed) {
    CD_LOG(VERBOSE, Feature::NS)
        << __func__ << ": Successfully added public certificates.";
    public_certificate_expirations_ =
        MergeExpirations(public_certificate_expirations_, *new_expirations);
    SavePublicCertificateExpirations();
  } else {
    CD_LOG(ERROR, Feature::NS)
        << __func__ << ": Failed to add public certificates.";
  }

  std::move(callback).Run(proceed);
}

// Completion handler for RemoveExpiredPublicCertificates(). On success, drops
// the removed IDs from the cached expiration list and persists it.
// |callback| receives |proceed| in either case.
void NearbyShareCertificateStorageImpl::RemoveExpiredPublicCertificatesCallback(
    std::unique_ptr<base::flat_set<std::string>> ids_to_remove,
    ResultCallback callback,
    bool proceed) {
  RecordRemoveExpiredPublicCertificatesSuccessMetric(proceed);

  if (proceed) {
    CD_LOG(VERBOSE, Feature::NS)
        << __func__ << ": Expired public certificates successfully removed.";
    const base::flat_set<std::string>& removed_ids = *ids_to_remove;
    std::erase_if(public_certificate_expirations_,
                  [&removed_ids](const ExpirationList::value_type& entry) {
                    return removed_ids.contains(entry.first);
                  });
    SavePublicCertificateExpirations();
  } else {
    CD_LOG(ERROR, Feature::NS)
        << __func__ << ": Failed to remove expired public certificates.";
  }

  std::move(callback).Run(proceed);
}

// Loads every stored public certificate and passes them to |callback|.
// Fails immediately if initialization failed. If initialization is still in
// progress, the call is deferred until it finishes.
void NearbyShareCertificateStorageImpl::GetPublicCertificates(
    PublicCertificateCallback callback) {
  if (init_status_ == InitStatus::kFailed) {
    std::move(callback).Run(false, nullptr);
    return;
  }

  if (init_status_ == InitStatus::kUninitialized) {
    // FinishInitialization() posts deferred callbacks as tasks, which may run
    // after |this| is destroyed. Bind with a WeakPtr rather than
    // base::Unretained so such a task becomes a no-op instead of a
    // use-after-free.
    deferred_callbacks_.push(base::BindOnce(
        &NearbyShareCertificateStorageImpl::GetPublicCertificates,
        weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
    return;
  }

  CD_LOG(VERBOSE, Feature::NS)
      << __func__ << ": Calling LoadEntries on database.";
  db_->LoadEntries(std::move(callback));
}

// Reads the private certificates stored in prefs. Returns std::nullopt if any
// stored entry fails to deserialize.
std::optional<std::vector<NearbySharePrivateCertificate>>
NearbyShareCertificateStorageImpl::GetPrivateCertificates() const {
  const base::Value::List& stored_entries =
      pref_service_
          ->GetValue(prefs::kNearbySharingPrivateCertificateListPrefName)
          .GetList();

  std::vector<NearbySharePrivateCertificate> result;
  for (const base::Value& entry : stored_entries) {
    std::optional<NearbySharePrivateCertificate> parsed =
        NearbySharePrivateCertificate::FromDictionary(entry.GetDict());
    if (!parsed.has_value()) {
      return std::nullopt;
    }
    result.push_back(std::move(parsed.value()));
  }
  return result;
}

// Returns the soonest public-certificate expiration time, or std::nullopt if
// no public certificates are tracked.
std::optional<base::Time>
NearbyShareCertificateStorageImpl::NextPublicCertificateExpirationTime() const {
  // The cached list is kept sorted by expiration time, so its first entry is
  // the earliest.
  if (!public_certificate_expirations_.empty()) {
    return public_certificate_expirations_.front().second;
  }
  return std::nullopt;
}

void NearbyShareCertificateStorageImpl::ReplacePrivateCertificates(
    const std::vector<NearbySharePrivateCertificate>& private_certificates) {
  base::Value::List list;
  for (const NearbySharePrivateCertificate& cert : private_certificates) {
    list.Append(cert.ToDictionary());
  }
  pref_service_->SetList(prefs::kNearbySharingPrivateCertificateListPrefName,
                         std::move(list));
}

// Writes |public_certificates| to the database, keyed by secret ID. Their
// expirations are tracked once the write completes. Fails immediately if
// initialization failed. If initialization is still in progress, the call is
// deferred until it finishes.
void NearbyShareCertificateStorageImpl::AddPublicCertificates(
    const std::vector<nearby::sharing::proto::PublicCertificate>&
        public_certificates,
    ResultCallback callback) {
  if (init_status_ == InitStatus::kFailed) {
    std::move(callback).Run(false);
    return;
  }

  if (init_status_ == InitStatus::kUninitialized) {
    // FinishInitialization() posts deferred callbacks as tasks, which may run
    // after |this| is destroyed. Bind with a WeakPtr rather than
    // base::Unretained so such a task becomes a no-op instead of a
    // use-after-free.
    deferred_callbacks_.push(base::BindOnce(
        &NearbyShareCertificateStorageImpl::AddPublicCertificates,
        weak_ptr_factory_.GetWeakPtr(), public_certificates,
        std::move(callback)));
    return;
  }

  auto new_entries = std::make_unique<std::vector<
      std::pair<std::string, nearby::sharing::proto::PublicCertificate>>>();
  auto new_expirations = std::make_unique<ExpirationList>();
  new_entries->reserve(public_certificates.size());
  new_expirations->reserve(public_certificates.size());
  for (const nearby::sharing::proto::PublicCertificate& cert :
       public_certificates) {
    new_entries->emplace_back(cert.secret_id(), cert);
    new_expirations->emplace_back(cert.secret_id(),
                                  TimestampToTime(cert.end_time()));
  }
  std::sort(new_expirations->begin(), new_expirations->end(), SortBySecond);

  CD_LOG(VERBOSE, Feature::NS)
      << __func__
      << ": Calling UpdateEntries on public certificate database with "
      << public_certificates.size() << " certificates.";
  db_->UpdateEntries(
      std::move(new_entries), std::make_unique<std::vector<std::string>>(),
      base::BindOnce(
          &NearbyShareCertificateStorageImpl::AddPublicCertificatesCallback,
          weak_ptr_factory_.GetWeakPtr(), std::move(new_expirations),
          std::move(callback)));
}

// Deletes every public certificate that is expired at |now|, including the
// public-certificate clock-skew tolerance. Runs |callback| with true right
// away if nothing is expired. Fails immediately if initialization failed. If
// initialization is still in progress, the call is deferred until it finishes.
void NearbyShareCertificateStorageImpl::RemoveExpiredPublicCertificates(
    base::Time now,
    ResultCallback callback) {
  if (init_status_ == InitStatus::kFailed) {
    std::move(callback).Run(false);
    return;
  }

  if (init_status_ == InitStatus::kUninitialized) {
    // FinishInitialization() posts deferred callbacks as tasks, which may run
    // after |this| is destroyed. Bind with a WeakPtr rather than
    // base::Unretained so such a task becomes a no-op instead of a
    // use-after-free.
    deferred_callbacks_.push(base::BindOnce(
        &NearbyShareCertificateStorageImpl::RemoveExpiredPublicCertificates,
        weak_ptr_factory_.GetWeakPtr(), now, std::move(callback)));
    return;
  }

  auto ids_to_remove = std::make_unique<std::vector<std::string>>();
  for (const auto& pair : public_certificate_expirations_) {
    // Because the list is sorted by expiration time, break as soon as we
    // encounter an unexpired certificate. Apply a tolerance when evaluating
    // whether the certificate is expired to account for clock skew between
    // devices. This conforms to the GmsCore implementation.
    if (!IsNearbyShareCertificateExpired(
            now,
            /*not_after=*/pair.second,
            /*use_public_certificate_tolerance=*/true)) {
      break;
    }

    ids_to_remove->emplace_back(pair.first);
  }
  if (ids_to_remove->empty()) {
    std::move(callback).Run(true);
    return;
  }

  auto ids_to_add = std::make_unique<leveldb_proto::ProtoDatabase<
      nearby::sharing::proto::PublicCertificate>::KeyEntryVector>();

  // A set copy of the IDs is kept for the completion callback, since the
  // vector itself is handed to the database.
  auto ids_to_remove_set = std::make_unique<base::flat_set<std::string>>(
      ids_to_remove->begin(), ids_to_remove->end());

  CD_LOG(VERBOSE, Feature::NS)
      << __func__
      << ": Calling UpdateEntries on public certificate database to remove "
      << ids_to_remove->size() << " expired certificates.";
  db_->UpdateEntries(
      std::move(ids_to_add), std::move(ids_to_remove),
      base::BindOnce(&NearbyShareCertificateStorageImpl::
                         RemoveExpiredPublicCertificatesCallback,
                     weak_ptr_factory_.GetWeakPtr(),
                     std::move(ids_to_remove_set), std::move(callback)));
}

bool NearbyShareCertificateStorageImpl::FetchPublicCertificateExpirations() {
  const base::Value::Dict& dict = pref_service_->GetDict(
      prefs::kNearbySharingPublicCertificateExpirationDictPrefName);
  public_certificate_expirations_.clear();

  public_certificate_expirations_.reserve(dict.size());
  for (const auto pair : dict) {
    std::optional<std::string> id = DecodeString(pair.first);
    std::optional<base::Time> expiration = base::ValueToTime(pair.second);
    if (!id || !expiration)
      return false;

    public_certificate_expirations_.emplace_back(*id, *expiration);
  }
  std::sort(public_certificate_expirations_.begin(),
            public_certificate_expirations_.end(), SortBySecond);

  return true;
}

void NearbyShareCertificateStorageImpl::SavePublicCertificateExpirations() {
  base::Value::Dict dict;

  for (const std::pair<std::string, base::Time>& pair :
       public_certificate_expirations_) {
    dict.Set(EncodeString(pair.first), base::TimeToValue(pair.second));
  }

  pref_service_->SetDict(
      prefs::kNearbySharingPublicCertificateExpirationDictPrefName,
      std::move(dict));
}