chromium/ios/chrome/browser/sessions/model/legacy_session_restoration_service.mm

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#import "ios/chrome/browser/sessions/model/legacy_session_restoration_service.h"

#import "base/check.h"
#import "base/containers/contains.h"
#import "base/functional/callback.h"
#import "base/strings/sys_string_conversions.h"
#import "ios/chrome/browser/sessions/model/session_restoration_browser_agent.h"
#import "ios/chrome/browser/sessions/model/session_service_ios.h"
#import "ios/chrome/browser/sessions/model/web_session_state_cache.h"
#import "ios/chrome/browser/sessions/model/web_session_state_tab_helper.h"
#import "ios/chrome/browser/shared/model/profile/profile_ios.h"
#import "ios/chrome/browser/shared/model/web_state_list/web_state_list.h"
#import "ios/web/public/session/crw_session_storage.h"
#import "ios/web/public/session/proto/storage.pb.h"
#import "ios/web/public/web_state.h"

namespace {

// An output iterator that counts how many time it has been incremented.
// Allows to check if sets has non-empty intersection without allocating.
template <typename T1, typename T2>
struct CountingOutputIterator {
  CountingOutputIterator& operator++() {
    ++count;
    return *this;
  }
  CountingOutputIterator& operator++(int) {
    ++count;
    return *this;
  }

  CountingOutputIterator& operator*() { return *this; }
  CountingOutputIterator& operator=(const T1&) { return *this; }
  CountingOutputIterator& operator=(const T2&) { return *this; }

  uint32_t count = 0;
};

// Override of CountingOutputIterator<T1, T2> when types are identical.
template <typename T>
struct CountingOutputIterator<T, T> {
  CountingOutputIterator& operator++() {
    ++count;
    return *this;
  }
  CountingOutputIterator& operator++(int) {
    ++count;
    return *this;
  }

  CountingOutputIterator& operator*() { return *this; }
  CountingOutputIterator& operator=(const T&) { return *this; }

  uint32_t count = 0;
};

// Returns whether the two sets have non-empty intersection.
template <typename Range1, typename Range2>
constexpr bool HasIntersection(Range1&& range1, Range2&& range2) {
  auto result = base::ranges::set_intersection(
      std::forward<Range1>(range1), std::forward<Range2>(range2),
      CountingOutputIterator<decltype(*range1.begin()),
                             decltype(*range2.begin())>{});
  return result.count != 0;
}

// Returns the sessions identifiers for all `browsers`.
std::set<std::string> GetSessionIdentifiers(
    const std::set<Browser*>& browsers) {
  std::set<std::string> result;
  for (const Browser* browser : browsers) {
    result.insert(base::SysNSStringToUTF8(
        SessionRestorationBrowserAgent::FromBrowser(browser)->GetSessionID()));
  }
  return result;
}

}  // namespace

LegacySessionRestorationService::LegacySessionRestorationService(
    bool enable_pinned_tabs,
    bool enable_tab_groups,
    const base::FilePath& storage_path,
    SessionServiceIOS* session_service_ios,
    WebSessionStateCache* web_session_state_cache)
    : enable_pinned_tabs_(enable_pinned_tabs),
      enable_tab_groups_(enable_tab_groups),
      storage_path_(storage_path),
      session_service_ios_(session_service_ios),
      web_session_state_cache_(web_session_state_cache) {
  DCHECK(session_service_ios_);
  DCHECK(web_session_state_cache_);
}

LegacySessionRestorationService::~LegacySessionRestorationService() {}

void LegacySessionRestorationService::Shutdown() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(browsers_.empty()) << "Disconnect() must be called for all Browser";

  [session_service_ios_ shutdown];
  session_service_ios_ = nil;

  web_session_state_cache_ = nil;
}

void LegacySessionRestorationService::AddObserver(
    SessionRestorationObserver* observer) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  observers_.AddObserver(observer);
}

void LegacySessionRestorationService::RemoveObserver(
    SessionRestorationObserver* observer) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  observers_.RemoveObserver(observer);
}

void LegacySessionRestorationService::SaveSessions() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  for (Browser* browser : browsers_) {
    // SessionRestorationBrowserAgent is not created for backup Browsers.
    if (SessionRestorationBrowserAgent* browser_agent =
            SessionRestorationBrowserAgent::FromBrowser(browser)) {
      browser_agent->SaveSession(/*immediately=*/true);
    }
  }
}

void LegacySessionRestorationService::ScheduleSaveSessions() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  for (Browser* browser : browsers_) {
    // SessionRestorationBrowserAgent is not created for backup Browsers.
    if (SessionRestorationBrowserAgent* browser_agent =
            SessionRestorationBrowserAgent::FromBrowser(browser)) {
      browser_agent->SaveSession(/*immediately=*/false);
    }
  }
}

void LegacySessionRestorationService::SetSessionID(
    Browser* browser,
    const std::string& identifier) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!base::Contains(browsers_, browser));
  browsers_.insert(browser);

  browser->GetWebStateList()->AddObserver(this);

  // Create the SessionRestorationBrowserAgent for browser.
  SessionRestorationBrowserAgent::CreateForBrowser(
      browser, session_service_ios_, enable_pinned_tabs_, enable_tab_groups_);

  SessionRestorationBrowserAgent* browser_agent =
      SessionRestorationBrowserAgent::FromBrowser(browser);

  browser_agent->AddObserver(this);
  browser_agent->SetSessionID(base::SysUTF8ToNSString(identifier));
}

void LegacySessionRestorationService::LoadSession(Browser* browser) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(base::Contains(browsers_, browser));
  SessionRestorationBrowserAgent::FromBrowser(browser)->RestoreSession();
}

void LegacySessionRestorationService::LoadWebStateStorage(
    Browser* browser,
    web::WebState* web_state,
    WebStateStorageCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!base::Contains(browsers_, browser)) {
    return;
  }

  web::proto::WebStateStorage storage;
  [web_state->BuildSessionStorage() serializeToProto:storage];
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(std::move(callback), std::move(storage)));
}

void LegacySessionRestorationService::AttachBackup(Browser* browser,
                                                   Browser* backup) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(base::Contains(browsers_, browser));
  DCHECK(!base::Contains(browsers_, backup));
  DCHECK(!base::Contains(browsers_to_backup_, browser));

  browsers_.insert(backup);
  browsers_to_backup_.insert(std::make_pair(browser, backup));
  backups_to_browser_.insert(std::make_pair(backup, browser));
}

void LegacySessionRestorationService::Disconnect(Browser* browser) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(base::Contains(browsers_, browser));
  browsers_.erase(browser);

  // Deal with backup Browser.
  if (base::Contains(backups_to_browser_, browser)) {
    Browser* original = backups_to_browser_[browser];
    backups_to_browser_.erase(browser);
    browsers_to_backup_.erase(original);
    return;
  }

  // Must disconnect the backup Browser before the original Browser.
  DCHECK(!base::Contains(browsers_to_backup_, browser));

  SessionRestorationBrowserAgent* browser_agent =
      SessionRestorationBrowserAgent::FromBrowser(browser);

  browser_agent->SaveSession(/* immediately */ true);
  browser_agent->RemoveObserver(this);

  // Destroy the SessionRestorationBrowserAgent for browser.
  SessionRestorationBrowserAgent::RemoveFromBrowser(browser);

  browser->GetWebStateList()->RemoveObserver(this);
}

std::unique_ptr<web::WebState>
LegacySessionRestorationService::CreateUnrealizedWebState(
    Browser* browser,
    web::proto::WebStateStorage storage) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return web::WebState::CreateWithStorageSession(
      web::WebState::CreateParams(browser->GetBrowserState()),
      [[CRWSessionStorage alloc] initWithProto:storage
                              uniqueIdentifier:web::WebStateID::NewUnique()
                              stableIdentifier:[[NSUUID UUID] UUIDString]],
      base::ReturnValueOnce<NSData*>(nil));
}

void LegacySessionRestorationService::DeleteDataForDiscardedSessions(
    const std::set<std::string>& identifiers,
    base::OnceClosure closure) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!HasIntersection(identifiers, GetSessionIdentifiers(browsers_)));
  NSMutableArray<NSString*>* sessions = [[NSMutableArray alloc] init];
  for (const std::string& identifier : identifiers) {
    [sessions addObject:base::SysUTF8ToNSString(identifier)];
  }
  [session_service_ios_ deleteSessions:sessions
                             directory:storage_path_
                            completion:std::move(closure)];
}

void LegacySessionRestorationService::InvokeClosureWhenBackgroundProcessingDone(
    base::OnceClosure closure) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  [session_service_ios_ shutdownWithClosure:std::move(closure)];
}

void LegacySessionRestorationService::PurgeUnassociatedData(
    base::OnceClosure closure) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  [web_session_state_cache_
      purgeUnassociatedDataWithCompletion:std::move(closure)];
}

void LegacySessionRestorationService::ParseDataForBrowserAsync(
    Browser* browser,
    WebStateStorageIterationCallback iter_callback,
    WebStateStorageIterationCompleteCallback done) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(base::Contains(browsers_, browser));
  using SessionMap = std::map<web::WebStateID, CRWSessionStorage*>;
  SessionMap sessions;

  // Collect the data synchronously (since it is available when using the
  // legacy implementation).
  WebStateList* const web_state_list = browser->GetWebStateList();
  const int web_state_list_count = web_state_list->count();
  for (int index = 0; index < web_state_list_count; ++index) {
    web::WebState* const web_state = web_state_list->GetWebStateAt(index);
    web::WebStateID const web_state_id = web_state->GetUniqueIdentifier();
    sessions.insert(
        std::make_pair(web_state_id, web_state->BuildSessionStorage()));
  }

  // Post a task that will invoke `iter_callback` for all WebState and
  // then `done`.
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE,
      base::BindOnce(
          [](SessionMap sessions, WebStateStorageIterationCallback iterator) {
            for (const auto& [web_state_id, session] : sessions) {
              web::proto::WebStateStorage storage;
              [session serializeToProto:storage];
              iterator.Run(web_state_id, std::move(storage));
            }
          },
          std::move(sessions), std::move(iter_callback))
          .Then(std::move(done)));
}

void LegacySessionRestorationService::WillStartSessionRestoration(
    Browser* browser) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  for (SessionRestorationObserver& observer : observers_) {
    observer.WillStartSessionRestoration(browser);
  }
}

void LegacySessionRestorationService::SessionRestorationFinished(
    Browser* browser,
    const std::vector<web::WebState*>& restored_web_states) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  for (SessionRestorationObserver& observer : observers_) {
    observer.SessionRestorationFinished(browser, restored_web_states);
  }
}

void LegacySessionRestorationService::WebStateListDidChange(
    WebStateList* web_state_list,
    const WebStateListChange& change,
    const WebStateListStatus& status) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  switch (change.type()) {
    case WebStateListChange::Type::kInsert: {
      const auto& typed_change = change.As<WebStateListChangeInsert>();
      WebSessionStateTabHelper::CreateForWebState(
          typed_change.inserted_web_state());
      break;
    }

    case WebStateListChange::Type::kReplace: {
      const auto& typed_change = change.As<WebStateListChangeReplace>();
      WebSessionStateTabHelper::CreateForWebState(
          typed_change.inserted_web_state());
      break;
    }

    default:
      break;
  }
}