chromium/chrome/browser/ash/power/ml/smart_dim/ml_agent_unittest.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/power/ml/smart_dim/ml_agent.h"

#include <memory>

#include "ash/constants/ash_features.h"
#include "base/containers/flat_map.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/strings/stringprintf.h"
#include "chromeos/dbus/machine_learning/machine_learning_client.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/test/chromeos_test_utils.h"
#include "components/assist_ranker/proto/example_preprocessor.pb.h"
#include "content/public/test/browser_task_environment.h"
#include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash {
namespace power {
namespace ml {
namespace {

// Inactivity score the fake ML service connection is configured to return. The
// value is arbitrary; kQuantizedTestInactivityScore is its quantization via a
// sigmoid transform.
constexpr double kTestInactivityScore{-3.7};
constexpr int kQuantizedTestInactivityScore{2};

// Sigmoid quantization of k20190521ModelDefaultDimThreshold (-0.5), the builtin
// threshold for SmartDimModelV3. It is higher than the quantized test
// inactivity score, which implies a NO_DIM decision.
constexpr int kQuantizedBuiltinThreshold{37};

// An arbitrary dim threshold lower than kTestInactivityScore, plus its sigmoid
// quantization. A threshold this low implies a DIM decision.
constexpr double kLowDimThreshold{-10.0};
constexpr int kQuantizedLowDimThreshold{0};

// Returns the path of |file_name| inside the "smart_dim" test data directory
// (src/chromeos/test/data/smart_dim). CHECK-fails if the lookup reports
// failure.
base::FilePath GetTestDataPath(const std::string& file_name) {
  base::FilePath resolved_path;
  const bool found = chromeos::test_utils::GetTestDataPath(
      "smart_dim", file_name, &resolved_path);
  CHECK(found);
  return resolved_path;
}

// Hands SmartDimMlAgent::OnComponentReady() a component built from test data,
// so the download worker can be initialized. The tuple is
// (metadata JSON, example preprocessor config, model string).
// The metadata JSON's "threshold" field is |threshold|. The preprocessor config
// is read from 20181115_example_preprocessor_config.pb. The model string is a
// placeholder.
//
// |threshold| is taken by value: a double is cheaper to copy than to reference.
void LoadDownloadableSmartDimComponent(double threshold) {
  // %f renders |threshold| in fixed notation (e.g. -10.000000).
  constexpr char kJsonStringTemplate[] =
      "{"
      "\"input_nodes\": [3],"
      "\"output_nodes\": [5],"
      "\"threshold\": %f,"
      "\"expected_feature_size\": 343,"
      "\"metrics_model_name\": \"SmartDimModel\""
      "}";
  const std::string json_string =
      base::StringPrintf(kJsonStringTemplate, threshold);

  // NOTE(review): presumably the fake ML service never parses these bytes, so
  // any placeholder works — confirm against FakeServiceConnectionImpl.
  const std::string model_string = "This is a model string";

  std::string pb_string;
  const base::FilePath pb_path =
      GetTestDataPath("20181115_example_preprocessor_config.pb");
  CHECK(base::ReadFileToString(pb_path, &pb_string));

  SmartDimMlAgent::GetInstance()->OnComponentReady(
      std::make_tuple(json_string, pb_string, model_string));
}

// Builds the fixed feature vector passed to every RequestDimDecision() call in
// these tests. Trailing "Bucketizes to" notes give the bucketized value
// expected for the raw input on that line.
UserActivityEvent::Features DefaultFeatures() {
  UserActivityEvent::Features result;

  result.set_battery_percent(96.0);  // Bucketizes to 95.
  result.set_device_management(UserActivityEvent::Features::UNMANAGED);
  result.set_device_mode(UserActivityEvent::Features::CLAMSHELL);
  result.set_device_type(UserActivityEvent::Features::CHROMEBOOK);
  result.set_key_events_in_last_hour(290);  // Bucketizes to 200.
  result.set_last_activity_day(UserActivityEvent::Features::THU);
  result.set_last_activity_time_sec(25920);       // Bucketizes to 7.
  result.set_last_user_activity_time_sec(25920);  // Bucketizes to 7.
  result.set_mouse_events_in_last_hour(2600);     // Bucketizes to 2000.
  result.set_on_battery(false);
  result.set_previous_negative_actions_count(3);
  result.set_previous_positive_actions_count(0);
  result.set_recent_time_active_sec(190);
  result.set_video_playing_time_sec(0);
  result.set_on_to_dim_sec(30);
  result.set_dim_to_screen_off_sec(10);
  result.set_time_since_last_key_sec(30);
  result.set_time_since_last_mouse_sec(688);
  result.set_time_since_video_ended_sec(1100);  // Bucketizes to 900.
  result.set_has_form_entry(false);
  result.set_source_id(123);  // Not used.
  result.set_engagement_score(40);
  result.set_tab_domain("//mail.google.com");

  return result;
}

// RequestDimDecision() callback used by the tests. It verifies that
// |prediction| carries |expected_response| together with the quantized
// |expected_threshold| and |expected_score|. It then sets |*callback_done| so
// the test can confirm the callback actually ran.
void CheckResult(bool* callback_done,
                 const int expected_threshold,
                 const int expected_score,
                 UserActivityEvent::ModelPrediction::Response expected_response,
                 UserActivityEvent::ModelPrediction prediction) {
  const auto actual_response = prediction.response();
  const auto actual_threshold = prediction.decision_threshold();
  const auto actual_score = prediction.inactivity_score();

  EXPECT_EQ(expected_response, actual_response);
  EXPECT_EQ(expected_threshold, actual_threshold);
  EXPECT_EQ(expected_score, actual_score);

  *callback_done = true;
}

}  // namespace

class SmartDimMlAgentTest : public testing::Test {
 public:
  SmartDimMlAgentTest()
      : task_environment_(
            base::test::TaskEnvironment::MainThreadType::IO,
            base::test::TaskEnvironment::ThreadPoolExecutionMode::QUEUED) {}

  SmartDimMlAgentTest(const SmartDimMlAgentTest&) = delete;
  SmartDimMlAgentTest& operator=(const SmartDimMlAgentTest&) = delete;

  void SetUp() override {
    chromeos::MachineLearningClient::InitializeFake();
    chromeos::machine_learning::ServiceConnection::
        UseFakeServiceConnectionForTesting(&fake_service_connection_);
    chromeos::machine_learning::ServiceConnection::GetInstance()->Initialize();
    fake_service_connection_.SetOutputValue(
        std::vector<int64_t>{1L}, std::vector<double>{kTestInactivityScore});
  }

  void TearDown() override { chromeos::MachineLearningClient::Shutdown(); }

 protected:
  chromeos::machine_learning::FakeServiceConnectionImpl
      fake_service_connection_;
  // DownloadWorker::InitializeFromComponent posts task to BrowserThread::UI,
  // while content::BrowserTaskEnvironment provides BrowserThread support in
  // unittest.
  content::BrowserTaskEnvironment task_environment_;

 private:
  data_decoder::test::InProcessDataDecoder in_process_data_decoder_;
};

// Covers two behaviours:
// 1. The agent uses the builtin worker until a downloadable component has been
//    loaded, and the download worker afterwards, as reported by
//    IsDownloadWorkerReady().
// 2. The agent combines the worker's score with the active threshold to reach
//    the expected DIM / NO_DIM decision.
TEST_F(SmartDimMlAgentTest, SwitchBetweenWorkers) {
  auto* const agent = SmartDimMlAgent::GetInstance();
  agent->ResetForTesting();

  constexpr auto kDim = UserActivityEvent::ModelPrediction::DIM;
  constexpr auto kNoDim = UserActivityEvent::ModelPrediction::NO_DIM;

  // No component has been loaded yet, so the download worker is not ready.
  EXPECT_FALSE(agent->IsDownloadWorkerReady());

  // Builtin worker: its threshold is reported as kQuantizedBuiltinThreshold.
  // That threshold is high, so the expected decision is NO_DIM.
  bool callback_done = false;
  agent->RequestDimDecision(
      DefaultFeatures(),
      base::BindOnce(&CheckResult, &callback_done, kQuantizedBuiltinThreshold,
                     kQuantizedTestInactivityScore, kNoDim));
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(callback_done);

  // Loading a downloadable component switches the agent to the download worker.
  LoadDownloadableSmartDimComponent(kLowDimThreshold);
  task_environment_.RunUntilIdle();
  ASSERT_TRUE(agent->IsDownloadWorkerReady());

  // Download worker: its threshold is reported as kQuantizedLowDimThreshold.
  // That threshold is low, so the expected decision is DIM.
  callback_done = false;
  agent->RequestDimDecision(
      DefaultFeatures(),
      base::BindOnce(&CheckResult, &callback_done, kQuantizedLowDimThreshold,
                     kQuantizedTestInactivityScore, kDim));
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(callback_done);
}

// Checks that CancelableOnceCallback lets exactly one callback run when two
// RequestDimDecision() calls are made before any callback has run.
TEST_F(SmartDimMlAgentTest, CheckCancelableCallback) {
  auto* const agent = SmartDimMlAgent::GetInstance();
  agent->ResetForTesting();

  bool callback_done = false;
  int num_callbacks_run = 0;
  const auto on_decision = [](bool* done, int* runs,
                              UserActivityEvent::ModelPrediction prediction) {
    *done = true;
    ++(*runs);
  };

  for (int request = 0; request < 2; ++request) {
    agent->RequestDimDecision(
        DefaultFeatures(),
        base::BindOnce(on_decision, &callback_done, &num_callbacks_run));
  }

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(callback_done);
  EXPECT_EQ(1, num_callbacks_run);
}

// Checks that CancelPreviousRequest() keeps an already-issued dim decision
// request from ever running its callback.
TEST_F(SmartDimMlAgentTest, CheckCanceledRequest) {
  auto* const agent = SmartDimMlAgent::GetInstance();
  agent->ResetForTesting();

  bool callback_done = false;
  const auto on_decision = [](bool* done,
                              UserActivityEvent::ModelPrediction prediction) {
    *done = true;
  };

  agent->RequestDimDecision(DefaultFeatures(),
                            base::BindOnce(on_decision, &callback_done));
  agent->CancelPreviousRequest();

  task_environment_.RunUntilIdle();
  EXPECT_FALSE(callback_done);
}

// Checks what happens when the ML service fails to load the model or to create
// the graph executor. The download worker first reports ready ("fake-ready").
// It is then marked not ready, and requests are served by the builtin worker.
TEST_F(SmartDimMlAgentTest, LoadModelFailure) {
  auto* const agent = SmartDimMlAgent::GetInstance();
  agent->ResetForTesting();

  // Sets |*done| when a dim decision callback runs.
  const auto mark_done = [](bool* done,
                            UserActivityEvent::ModelPrediction prediction) {
    *done = true;
  };

  // Make fake_service_connection_ fail model loads and respond asynchronously.
  // This mimics the real ml-service failing to load a bad flatbuffer model.
  fake_service_connection_.SetLoadModelFailure();
  fake_service_connection_.SetAsyncMode(true);

  // Until ml-service reports the load failure, OnConnectionError has not been
  // invoked, so the download worker reports ready (fake-ready).
  LoadDownloadableSmartDimComponent(kLowDimThreshold);
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(agent->IsDownloadWorkerReady());

  // A request issued during the fake-ready state doesn't crash and gets no
  // answer yet.
  bool fake_ready_callback_done = false;
  agent->RequestDimDecision(
      DefaultFeatures(), base::BindOnce(mark_done, &fake_ready_callback_done));
  task_environment_.RunUntilIdle();
  EXPECT_FALSE(fake_ready_callback_done);

  // ml-service now reports the load failure. OnConnectionError is invoked and
  // the download worker is marked not ready.
  fake_service_connection_.RunPendingCalls();
  task_environment_.RunUntilIdle();
  EXPECT_FALSE(agent->IsDownloadWorkerReady());

  // Reset fake_service_connection_ so the builtin worker can process requests.
  fake_service_connection_.SetAsyncMode(false);
  fake_service_connection_.SetExecuteSuccess();

  // Requests issued after the fake-ready state are processed successfully.
  // A fresh flag is used so that a late callback from the fake-ready request
  // cannot satisfy this expectation on its own.
  bool recovered_callback_done = false;
  agent->RequestDimDecision(
      DefaultFeatures(), base::BindOnce(mark_done, &recovered_callback_done));
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(recovered_callback_done);
}

}  // namespace ml
}  // namespace power
}  // namespace ash