chromium/ios/chrome/browser/webui/ui_bundled/on_device_llm_internals_ui.mm

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#import "ios/chrome/browser/webui/ui_bundled/on_device_llm_internals_ui.h"

#import <Foundation/Foundation.h>

#import "base/check_op.h"
#import "base/logging.h"
#import "base/strings/sys_string_conversions.h"
#import "base/values.h"
#import "components/optimization_guide/optimization_guide_buildflags.h"
#import "ios/chrome/browser/optimization_guide/resources/on_device_llm_buildflags.h"
#import "ios/chrome/browser/shared/model/profile/profile_ios.h"
#import "ios/chrome/browser/shared/model/url/chrome_url_constants.h"
#import "ios/chrome/grit/ios_resources.h"
#import "ios/web/public/webui/url_data_source_ios.h"
#import "ios/web/public/webui/web_ui_ios.h"
#import "ios/web/public/webui/web_ui_ios_data_source.h"
#import "ios/web/public/webui/web_ui_ios_message_handler.h"

#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
#import "components/optimization_guide/internal/third_party/odml/src/odml/infra/genai/inference/c/llm_inference_engine.h"  // nogncheck
#endif

namespace {

// Creates the data source for `kChromeUIOnDeviceLlmInternalsHost`, whose
// default resource is the on-device LLM internals HTML page.
web::WebUIIOSDataSource* CreateOnDeviceLlmInternalsUIHTMLSource() {
  web::WebUIIOSDataSource* html_source =
      web::WebUIIOSDataSource::Create(kChromeUIOnDeviceLlmInternalsHost);
  html_source->SetDefaultResource(IDR_IOS_ON_DEVICE_LLM_INTERNALS_HTML);
  return html_source;
}

// Message handler for the on-device LLM internals page. It answers the
// "requestModelInformation" and "initAndGenerateResponse" messages registered
// in RegisterMessages().
class OnDeviceLlmInternalsHandler : public web::WebUIIOSMessageHandler {
 public:
  OnDeviceLlmInternalsHandler();
  ~OnDeviceLlmInternalsHandler() override;

  // Not copyable or assignable.
  OnDeviceLlmInternalsHandler(const OnDeviceLlmInternalsHandler&) = delete;
  OnDeviceLlmInternalsHandler& operator=(const OnDeviceLlmInternalsHandler&) =
      delete;

  // WebUIIOSMessageHandler:
  void RegisterMessages() override;

 private:
  // Sends the build-configured model name to the page through the
  // `updateModelInformation` JavaScript function.
  void HandleRequestModelInformation(const base::Value::List& args);

  // Runs the query passed as the single element of `args` and sends the
  // result to the page through the `showResult` JavaScript function.
  void InitAndGenerateResponse(const base::Value::List& args);
};

}  // namespace

OnDeviceLlmInternalsHandler::OnDeviceLlmInternalsHandler() = default;

OnDeviceLlmInternalsHandler::~OnDeviceLlmInternalsHandler() = default;

// Binds each JavaScript message name to its handler method. base::Unretained
// is used for `this`, exactly as before.
void OnDeviceLlmInternalsHandler::RegisterMessages() {
  web::WebUIIOS* page_ui = web_ui();

  auto model_info_callback = base::BindRepeating(
      &OnDeviceLlmInternalsHandler::HandleRequestModelInformation,
      base::Unretained(this));
  page_ui->RegisterMessageCallback("requestModelInformation",
                                   std::move(model_info_callback));

  auto generate_callback = base::BindRepeating(
      &OnDeviceLlmInternalsHandler::InitAndGenerateResponse,
      base::Unretained(this));
  page_ui->RegisterMessageCallback("initAndGenerateResponse",
                                   std::move(generate_callback));
}

// Reports the model name set by the IOS_ON_DEVICE_LLM_NAME build flag to the
// page, or a placeholder when the flag is empty. `args` is not used.
void OnDeviceLlmInternalsHandler::HandleRequestModelInformation(
    const base::Value::List& args) {
  const std::string configured_name = BUILDFLAG(IOS_ON_DEVICE_LLM_NAME);
  const std::string display_name = configured_name.empty()
                                       ? std::string("(No model loaded)")
                                       : configured_name;

  base::ValueView js_args[] = {display_name};
  web_ui()->CallJavascriptFunction("updateModelInformation", js_args);
}

#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
namespace {

// Returns `error` when the inference engine provided one, otherwise a
// placeholder. Streaming a null `char*` into a log stream is undefined.
// NOTE(review): ownership of engine-provided error strings is not visible
// here; if the engine heap-allocates them they are currently leaked.
const char* ErrorMessageOrPlaceholder(const char* error) {
  return error ? error : "(no error message)";
}

// Creates a session on `llm_engine`, runs `query` through it and returns the
// first response. Returns a human-readable error string if any step fails.
// The session, once created, is always deleted before returning.
std::string RunQueryInNewSession(LlmInferenceEngine_Engine* llm_engine,
                                 const std::string& query) {
  const LlmSessionConfig session_config = {
      .topk = 40,
      .topp = 1.0f,
      .temperature = 0.8f,
      .random_seed = 0,
  };

  char* error = nullptr;
  LlmInferenceEngine_Session* llm_session = nullptr;
  int error_code = LlmInferenceEngine_CreateSession(llm_engine, &session_config,
                                                    &llm_session, &error);
  VLOG(1) << "LlmInferenceEngine_CreateSession error code: " << error_code;
  if (error_code != 0 || !llm_session) {
    VLOG(1) << "LlmInferenceEngine_CreateSession error message: "
            << ErrorMessageOrPlaceholder(error);
    return "Failed to create the LLM session.";
  }

  std::string response;
  // TODO(crbug.com/356608952): Not on the main thread.
  error_code = LlmInferenceEngine_Session_AddQueryChunk(
      llm_session, query.c_str(), &error);
  VLOG(1) << "LlmInferenceEngine_Session_AddQueryChunk error code: "
          << error_code;
  if (error_code != 0) {
    VLOG(1) << "LlmInferenceEngine_Session_AddQueryChunk error message: "
            << ErrorMessageOrPlaceholder(error);
    response = "Failed to add the query to the LLM session.";
  } else {
    LlmResponseContext llm_response_context =
        LlmInferenceEngine_Session_PredictSync(llm_session);
    // Do not index into an empty or missing response array.
    if (llm_response_context.response_count > 0 &&
        llm_response_context.response_array &&
        llm_response_context.response_array[0]) {
      response = llm_response_context.response_array[0];
    } else {
      response = "The LLM returned no response.";
    }
    LlmInferenceEngine_CloseResponseContext(&llm_response_context);
  }

  // TODO(crbug.com/356608952): Reuse the session across runs.
  LlmInferenceEngine_Session_Delete(llm_session);
  return response;
}

// Loads the model bundled in the app, whose file name and extension come from
// the IOS_ON_DEVICE_LLM_NAME and IOS_ON_DEVICE_LLM_EXTENSION build flags.
// Runs `query` through that model and returns the response, or a
// human-readable error string on failure. The engine, once created, is always
// deleted before returning.
std::string RunOnDeviceLlmQuery(const std::string& query) {
  std::string cache_dir = base::SysNSStringToUTF8(
      [[[NSFileManager defaultManager] temporaryDirectory] path]);
  VLOG(1) << "cache_dir: " << cache_dir;

  NSString* model_file_name =
      base::SysUTF8ToNSString(BUILDFLAG(IOS_ON_DEVICE_LLM_NAME));
  NSString* model_file_extension =
      base::SysUTF8ToNSString(BUILDFLAG(IOS_ON_DEVICE_LLM_EXTENSION));
  // -pathForResource:ofType: returns nil when the resource is missing, which
  // converts to an empty string.
  std::string model_file_path = base::SysNSStringToUTF8([[NSBundle mainBundle]
      pathForResource:model_file_name
               ofType:model_file_extension]);
  VLOG(1) << "model_file_path: " << model_file_path;
  if (model_file_path.empty()) {
    return "Model file not found in the app bundle.";
  }

  // `model_file_path` and `cache_dir` outlive the engine created below.
  const LlmModelSettings model_settings = {
      .model_path = model_file_path.c_str(),
      .cache_dir = cache_dir.c_str(),
      .max_num_tokens = 512,
      .num_decode_steps_per_sync = 3,
      .sequence_batch_size = 0,
  };

  char* error = nullptr;
  LlmInferenceEngine_Engine* llm_engine = nullptr;
  int error_code =
      LlmInferenceEngine_CreateEngine(&model_settings, &llm_engine, &error);
  VLOG(1) << "LlmInferenceEngine_CreateEngine error code: " << error_code;
  if (error_code != 0 || !llm_engine) {
    VLOG(1) << "LlmInferenceEngine_CreateEngine error message: "
            << ErrorMessageOrPlaceholder(error);
    return "Failed to create the LLM engine.";
  }

  std::string response = RunQueryInNewSession(llm_engine, query);

  // TODO(crbug.com/356608952): Reuse the engine across runs.
  LlmInferenceEngine_Engine_Delete(llm_engine);
  return response;
}

}  // namespace
#endif  // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)

// Handles "initAndGenerateResponse". `args` must hold exactly one string, the
// query. The response, or an error or placeholder message, is sent to the page
// through the `showResult` JavaScript function.
void OnDeviceLlmInternalsHandler::InitAndGenerateResponse(
    const base::Value::List& args) {
  CHECK_EQ(args.size(), 1u);

// iOS is bring-your-own-model. To enable the on-device code, run
// `gn args out/target` (or edit `~/.setup-gn`) and add:
//   ios_on_device_llm_path = /path/to/model.bin
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
  VLOG(1) << "Init LLM and generate response...";
  VLOG(1) << "query: " << args[0];

  std::string response = RunOnDeviceLlmQuery(args[0].GetString());
  VLOG(1) << "LLM internals: response: " << response;
#else
  std::string response = "No model loaded.";
#endif  // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)

  base::ValueView js_args[] = {response};
  web_ui()->CallJavascriptFunction("showResult", js_args);
}

// Attaches the message handler to `web_ui` and registers the page's HTML data
// source with the browser state that owns `web_ui`.
OnDeviceLlmInternalsUI::OnDeviceLlmInternalsUI(web::WebUIIOS* web_ui,
                                               const std::string& host)
    : web::WebUIIOSController(web_ui, host) {
  auto handler = std::make_unique<OnDeviceLlmInternalsHandler>();
  web_ui->AddMessageHandler(std::move(handler));

  ChromeBrowserState* browser_state = ChromeBrowserState::FromWebUIIOS(web_ui);
  web::WebUIIOSDataSource* html_source =
      CreateOnDeviceLlmInternalsUIHTMLSource();
  web::WebUIIOSDataSource::Add(browser_state, html_source);
}

OnDeviceLlmInternalsUI::~OnDeviceLlmInternalsUI() = default;