chromium/ash/lobster/lobster_session_impl_unittest.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ash/lobster/lobster_session_impl.h"

#include <string_view>

#include "ash/public/cpp/lobster/lobster_client.h"
#include "ash/public/cpp/lobster/lobster_session.h"
#include "ash/public/cpp/lobster/lobster_system_state.h"
#include "base/files/file_util.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash {
namespace {

class MockLobsterClient : public LobsterClient {
 public:
  MockLobsterClient() {}

  MockLobsterClient(const MockLobsterClient&) = delete;
  MockLobsterClient& operator=(const MockLobsterClient&) = delete;

  ~MockLobsterClient() override = default;

  MOCK_METHOD(void, SetActiveSession, (LobsterSession * session), (override));
  MOCK_METHOD(LobsterSystemState, GetSystemState, (), (override));
  MOCK_METHOD(void,
              RequestCandidates,
              (const std::string& query,
               int num_candidates,
               RequestCandidatesCallback),
              (override));
  MOCK_METHOD(void,
              InflateCandidate,
              (uint32_t seed,
               const std::string& query,
               InflateCandidateCallback),
              (override));
  MOCK_METHOD(bool,
              SubmitFeedback,
              (const std::string& query,
               const std::string& model_input,
               const std::string& description,
               const std::string& image_bytes),
              (override));
};

class LobsterSessionImplTest : public testing::Test {
 public:
  LobsterSessionImplTest() {}
  LobsterSessionImplTest(const LobsterSessionImplTest&) = delete;
  LobsterSessionImplTest& operator=(const LobsterSessionImplTest&) = delete;

  ~LobsterSessionImplTest() override = default;

  void RunUntilIdle() { task_environment_.RunUntilIdle(); }

 private:
  base::test::TaskEnvironment task_environment_;
};

TEST_F(LobsterSessionImplTest, RequestCandidatesWithThreeResults) {
  auto lobster_client = std::make_unique<MockLobsterClient>();

  EXPECT_CALL(*lobster_client,
              RequestCandidates(/*query=*/"a nice strawberry",
                                /*num_candidates=*/3, testing::_))
      .WillOnce(testing::Invoke([](std::string_view query, int num_candidates,
                                   RequestCandidatesCallback done_callback) {
        std::vector<LobsterImageCandidate> image_candidates = {
            LobsterImageCandidate(/*id=*/0, /*image_bytes=*/"a1b2c3",
                                  /*seed=*/20,
                                  /*query=*/"a nice strawberry"),
            LobsterImageCandidate(/*id=*/1, /*image_bytes=*/"d4e5f6",
                                  /*seed=*/21,
                                  /*query=*/"a nice strawberry"),
            LobsterImageCandidate(/*id=*/2, /*image_bytes=*/"g7h8i9",
                                  /*seed=*/22,
                                  /*query=*/"a nice strawberry")};
        std::move(done_callback).Run(std::move(image_candidates));
      }));

  LobsterSessionImpl session(std::move(lobster_client));

  base::test::TestFuture<const LobsterResult&> future;

  session.RequestCandidates(/*query=*/"a nice strawberry", /*num_candidates=*/3,
                            future.GetCallback());

  EXPECT_THAT(
      future.Get().value(),
      testing::ElementsAre(
          LobsterImageCandidate(/*expected_id=*/0,
                                /*expected_image_bytes=*/"a1b2c3",
                                /*seed=*/20, /*query=*/"a nice strawberry"),
          LobsterImageCandidate(/*expected_id=*/1,
                                /*expected_image_bytes=*/"d4e5f6",
                                /*seed=*/21, /*query=*/"a nice strawberry"),
          LobsterImageCandidate(/*expected_id=*/2,
                                /*expected_image_bytes=*/"g7h8i9",
                                /*seed=*/22, /*query=*/"a nice strawberry")));
}

TEST_F(LobsterSessionImplTest, RequestCandidatesReturnsUnknownError) {
  auto lobster_client = std::make_unique<MockLobsterClient>();

  EXPECT_CALL(*lobster_client,
              RequestCandidates(/*query=*/"a nice blueberry",
                                /*num_candidates=*/1, testing::_))
      .WillOnce(testing::Invoke([](std::string_view query, int num_candidates,
                                   RequestCandidatesCallback done_callback) {
        std::move(done_callback)
            .Run(base::unexpected(
                LobsterError(LobsterErrorCode::kUnknown, "unknown error")));
      }));

  LobsterSessionImpl session(std::move(lobster_client));

  base::test::TestFuture<const LobsterResult&> future;

  session.RequestCandidates(/*query=*/"a nice blueberry", /*num_candidates=*/1,
                            future.GetCallback());

  EXPECT_EQ(future.Get().error(),
            LobsterError(LobsterErrorCode::kUnknown, "unknown error"));
}

TEST_F(LobsterSessionImplTest, CanNotDownloadACandidateIfItIsNotCached) {
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice raspberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice raspberry"});

  LobsterSessionImpl session(std::make_unique<MockLobsterClient>(), store);

  base::test::TestFuture<bool> future;
  session.DownloadCandidate(/*id=*/2, base::FilePath("dummy_path"),
                            future.GetCallback());

  EXPECT_FALSE(future.Get());
}

TEST_F(LobsterSessionImplTest, CanDownloadACandiateIfItIsInCache) {
  auto lobster_client = std::make_unique<MockLobsterClient>();
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice strawberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice strawberry"});

  ON_CALL(*lobster_client,
          InflateCandidate(/*seed=*/21, testing::_, testing::_))
      .WillByDefault([](uint32_t seed, std::string_view query,
                        InflateCandidateCallback done_callback) {
        std::vector<LobsterImageCandidate> inflated_candidates = {
            LobsterImageCandidate(/*id=*/0, /*image_bytes=*/"a1b2c3",
                                  /*seed=*/30,
                                  /*query=*/"a nice strawberry")};
        std::move(done_callback).Run(std::move(inflated_candidates));
      });

  LobsterSessionImpl session(std::move(lobster_client), store);
  session.RequestCandidates("a nice strawberry", 2,
                            base::BindOnce([](const LobsterResult&) {}));
  RunUntilIdle();

  base::test::TestFuture<bool> future;

  session.DownloadCandidate(/*id=*/1, base::FilePath("dummy_path"),
                            future.GetCallback());

  EXPECT_TRUE(future.Get());
}

TEST_F(LobsterSessionImplTest,
       CanNotPreviewFeedbackForACandidateIfItIsNotCached) {
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice raspberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice raspberry"});

  LobsterSessionImpl session(std::make_unique<MockLobsterClient>(), store);
  base::test::TestFuture<const LobsterFeedbackPreviewResponse&> future;

  session.PreviewFeedback(/*id=*/2, future.GetCallback());

  EXPECT_EQ(future.Get().error(), "No candidate found.");
}

TEST_F(LobsterSessionImplTest, CanPreviewFeedbackForACandidateIfItIsInCache) {
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice raspberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice raspberry"});
  LobsterSessionImpl session(std::make_unique<MockLobsterClient>(), store);
  base::test::TestFuture<const LobsterFeedbackPreviewResponse&> future;

  session.PreviewFeedback(/*id=*/1, future.GetCallback());

  ASSERT_TRUE(future.Get().has_value());
  EXPECT_EQ(future.Get()->preview_image_bytes, "d4e5f6");
  std::map<std::string, std::string> expected_feedback_preview_fields = {
      {"model_version", "dummy_version"}, {"model_input", "a nice raspberry"}};
  EXPECT_EQ(future.Get()->fields, expected_feedback_preview_fields);
}

TEST_F(LobsterSessionImplTest,
       CanNotSubmitFeedbackForACandiateIfItIsNotCached) {
  auto lobster_client = std::make_unique<MockLobsterClient>();
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice raspberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice raspberry"});

  ON_CALL(*lobster_client, SubmitFeedback(/*query=*/"a nice raspberry",
                                          /*model_input=*/"dummy_version",
                                          /*description=*/"Awesome raspberry",
                                          /*image_bytes=*/"a1b2c3"))
      .WillByDefault(testing::Return(true));

  ON_CALL(*lobster_client, SubmitFeedback(/*query=*/"a nice raspberry",
                                          /*model_input=*/"dummy_version",
                                          /*description=*/"Awesome raspberry",
                                          /*image_bytes=*/"d4e5f6"))
      .WillByDefault(testing::Return(true));

  LobsterSessionImpl session(std::move(lobster_client), store);
  EXPECT_FALSE(session.SubmitFeedback(/*candidate_id*/ 2,
                                      /*description=*/"Awesome raspberry"));
}

TEST_F(LobsterSessionImplTest,
       CanNotSubmitFeedbackForACandiateIfSubmissionFails) {
  auto lobster_client = std::make_unique<MockLobsterClient>();
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice raspberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice raspberry"});

  ON_CALL(*lobster_client, SubmitFeedback(/*query=*/"a nice raspberry",
                                          /*model_input=*/"dummy_version",
                                          /*description=*/"Awesome raspberry",
                                          /*image_bytes=*/"a1b2c3"))
      .WillByDefault(testing::Return(false));

  LobsterSessionImpl session(std::move(lobster_client), store);
  EXPECT_FALSE(session.SubmitFeedback(/*candidate_id*/ 0,
                                      /*description=*/"Awesome raspberry"));
}

TEST_F(LobsterSessionImplTest, CanSubmitFeedbackForACandiateIfItIsInCache) {
  auto lobster_client = std::make_unique<MockLobsterClient>();
  LobsterCandidateStore store;
  store.Cache({.id = 0,
               .image_bytes = "a1b2c3",
               .seed = 20,
               .query = "a nice raspberry"});
  store.Cache({.id = 1,
               .image_bytes = "d4e5f6",
               .seed = 21,
               .query = "a nice raspberry"});

  EXPECT_CALL(*lobster_client,
              SubmitFeedback(/*query=*/"a nice raspberry",
                             /*model_input=*/"dummy_version",
                             /*description=*/"Awesome raspberry",
                             /*image_bytes=*/"a1b2c3"))
      .WillOnce(testing::Return(true));

  EXPECT_CALL(*lobster_client,
              SubmitFeedback(/*query=*/"a nice raspberry",
                             /*model_input=*/"dummy_version",
                             /*description=*/"Awesome raspberry",
                             /*image_bytes=*/"d4e5f6"))
      .WillOnce(testing::Return(true));

  LobsterSessionImpl session(std::move(lobster_client), store);
  EXPECT_TRUE(session.SubmitFeedback(/*candidate_id*/ 0,
                                     /*description=*/"Awesome raspberry"));
  EXPECT_TRUE(session.SubmitFeedback(/*candidate_id*/ 1,
                                     /*description=*/"Awesome raspberry"));
}

}  // namespace

}  // namespace ash