chromium/ash/lobster/lobster_session_impl.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ash/lobster/lobster_session_impl.h"

#include <map>
#include <memory>
#include <optional>
#include <string>

#include "ash/lobster/lobster_image_actuator.h"
#include "ash/public/cpp/lobster/lobster_client.h"
#include "ash/public/cpp/lobster/lobster_image_candidate.h"
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/types/expected.h"
#include "ui/base/ime/ash/ime_bridge.h"
#include "ui/base/ime/input_method.h"

namespace ash {

namespace {

// Returns the text input client of the input method reachable through the
// IME bridge's input context handler, or nullptr when there is no handler, no
// input method, or the input method has no text input client.
ui::TextInputClient* GetFocusedTextInputClient() {
  auto* input_context = IMEBridge::Get()->GetInputContextHandler();
  // GetInputContextHandler() may return null; guard before dereferencing
  // instead of crashing.
  if (!input_context) {
    return nullptr;
  }
  const ui::InputMethod* input_method = input_context->GetInputMethod();
  // GetTextInputClient() itself may return null, which is passed through.
  return input_method ? input_method->GetTextInputClient() : nullptr;
}

}  // namespace

LobsterSessionImpl::LobsterSessionImpl(
    std::unique_ptr<LobsterClient> client,
    const LobsterCandidateStore& candidate_store)
    : client_(std::move(client)),
      candidate_store_(candidate_store) {
  // Register as the client's active session; the destructor clears this
  // registration again by passing nullptr.
  LobsterClient* const owned_client = client_.get();
  owned_client->SetActiveSession(this);
}

LobsterSessionImpl::LobsterSessionImpl(std::unique_ptr<LobsterClient> client)
    : LobsterSessionImpl(std::move(client), LobsterCandidateStore()) {}

LobsterSessionImpl::~LobsterSessionImpl() {
  // Undo the registration made in the constructor so the client no longer
  // holds a pointer to this session once it is destroyed.
  LobsterClient* const owned_client = client_.get();
  owned_client->SetActiveSession(nullptr);
}

// Inflates the candidate identified by `candidate_id` and writes the resulting
// image bytes to `file_path`. `status_callback` reports the outcome as decided
// by InflateCandidateAndPerformAction().
void LobsterSessionImpl::DownloadCandidate(int candidate_id,
                                           const base::FilePath& file_path,
                                           StatusCallback status_callback) {
  // The destination path is bound now; the image bytes arrive later.
  ActionCallback write_to_disk = base::BindOnce(
      [](const base::FilePath& target_path, const std::string& bytes) {
        WriteImageToPath(target_path, bytes);
      },
      file_path);
  InflateCandidateAndPerformAction(candidate_id, std::move(write_to_disk),
                                   std::move(status_callback));
}

// Asks the client for `num_candidates` images for `query`. The response is
// routed through OnRequestCandidates(), which caches it before running
// `callback`.
void LobsterSessionImpl::RequestCandidates(const std::string& query,
                                           int num_candidates,
                                           RequestCandidatesCallback callback) {
  // Bound to a WeakPtr, so the response is dropped (and `callback` never runs)
  // if this session is destroyed before the client replies.
  auto on_response = base::BindOnce(&LobsterSessionImpl::OnRequestCandidates,
                                    weak_ptr_factory_.GetWeakPtr(),
                                    std::move(callback));
  client_->RequestCandidates(query, num_candidates, std::move(on_response));
}

void LobsterSessionImpl::CommitAsInsert(int candidate_id,
                                        StatusCallback status_callback) {
  InflateCandidateAndPerformAction(
      candidate_id, base::BindOnce([](const std::string& image_bytes) {
        InsertImageOrCopyToClipboard(GetFocusedTextInputClient(), image_bytes);
      }),
      std::move(status_callback));
}

// Inflates the candidate identified by `candidate_id` and writes the image
// bytes to `file_path`. `status_callback` reports the outcome as decided by
// InflateCandidateAndPerformAction().
void LobsterSessionImpl::CommitAsDownload(int candidate_id,
                                          const base::FilePath& file_path,
                                          StatusCallback status_callback) {
  // Capture-less functor; `file_path` is bound as an argument below, as
  // required by base::BindOnce.
  auto write_image = [](const base::FilePath& destination,
                        const std::string& bytes) {
    WriteImageToPath(destination, bytes);
  };
  InflateCandidateAndPerformAction(candidate_id,
                                   base::BindOnce(write_image, file_path),
                                   std::move(status_callback));
}

// Runs `callback` with a feedback preview (metadata fields plus the image
// bytes) for the stored candidate `candidate_id`, or with an error string when
// the candidate is not in the store.
void LobsterSessionImpl::PreviewFeedback(
    int candidate_id,
    LobsterPreviewFeedbackCallback callback) {
  std::optional<LobsterImageCandidate> candidate =
      candidate_store_.FindCandidateById(candidate_id);
  if (candidate) {
    // TODO: b/362403784 - add the proper version.
    std::move(callback).Run(LobsterFeedbackPreview(
        {{"model_version", "dummy_version"}, {"model_input", candidate->query}},
        candidate->image_bytes));
  } else {
    std::move(callback).Run(base::unexpected("No candidate found."));
  }
}

// Submits `description` as feedback for the stored candidate `candidate_id`,
// including its query and image bytes. Returns false without contacting the
// client when the candidate is not in the store; otherwise returns the
// client's result.
bool LobsterSessionImpl::SubmitFeedback(int candidate_id,
                                        const std::string& description) {
  std::optional<LobsterImageCandidate> candidate =
      candidate_store_.FindCandidateById(candidate_id);
  // The short-circuit skips the client call entirely for unknown candidates.
  // TODO: b/362403784 - add the proper version.
  return candidate.has_value() &&
         client_->SubmitFeedback(/*query=*/candidate->query,
                                 /*model_version=*/"dummy_version",
                                 /*description=*/description,
                                 /*image_bytes=*/candidate->image_bytes);
}

// Client response handler for RequestCandidates(). It stores every returned
// candidate so later id-based calls can find it, then forwards `result`
// unchanged, error or not, to `callback`.
void LobsterSessionImpl::OnRequestCandidates(RequestCandidatesCallback callback,
                                             const LobsterResult& result) {
  if (result.has_value()) {
    for (const auto& candidate : result.value()) {
      candidate_store_.Cache(candidate);
    }
  }
  std::move(callback).Run(result);
}

void LobsterSessionImpl::InflateCandidateAndPerformAction(
    int candidate_id,
    ActionCallback action_callback,
    StatusCallback status_callback) {
  std::optional<LobsterImageCandidate> candidate =
      candidate_store_.FindCandidateById(candidate_id);
  if (!candidate.has_value()) {
    LOG(ERROR) << "No candidate found.";
    std::move(status_callback).Run(false);
    return;
  }

  client_->InflateCandidate(
      candidate->seed, candidate->query,
      base::BindOnce(
          [](ActionCallback action_callback, const LobsterResult& result) {
            if (!result.has_value()) {
              LOG(ERROR) << "No image candidate";
              return false;
            }

            // TODO: b/348283703 - Return the value of action callback.
            std::move(action_callback).Run((*result)[0].image_bytes);
            return true;
          },
          std::move(action_callback))
          .Then(base::BindOnce(
              [](StatusCallback status_callback, bool success) {
                std::move(status_callback).Run(success);
              },
              std::move(status_callback))));
}

}  // namespace ash