chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pthread.h>

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"
#include "mediapipe/tasks/cc/genai/inference/proto/llm_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/proto/transformer_params.pb.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/memory_mapped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/metadata_utils.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/model_data.h"
#include "mediapipe/tasks/cc/genai/inference/utils/llm_utils/scoped_file.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm_builder_factory.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/llm_weights.h"
#include "sentencepiece/src/normalizer.h"  // from @com_google_sentencepiece
#include "sentencepiece/src/sentencepiece_processor.h"  // from @com_google_sentencepiece
#include "tensorflow/lite/model_builder.h"

namespace {

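// Number of trailing characters to withhold from streaming so that a stop
// token split across decode steps can still be matched and stripped.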
constexpr size_t kCheckLastKChars = 10;

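// Engine state shared across sessions. Owns the tokenizer, the optional
// denormalizer, and the XNNPACK model; all are raw pointers released from
// unique_ptrs during creation and deleted in the destructor.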
struct LlmInferenceEngineCpu_Engine {
  sentencepiece::SentencePieceProcessor* tokenizer;
  sentencepiece::normalizer::Normalizer* normalizer;
  mediapipe::tasks::genai::xnn_utils::Llm* llm;
  int start_token_id;
  std::vector<std::string> stop_tokens;
  size_t max_num_tokens;
  ~LlmInferenceEngineCpu_Engine() {
    delete tokenizer;
    // `delete` on a null pointer is a no-op, so the optional normalizer needs
    // no explicit check.
    delete normalizer;
    delete llm;
  }
};

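// Mutable per-session state. `last_10_char` buffers the undelivered tail of
// the response (see kCheckLastKChars); `final_output` accumulates everything
// delivered so far for the synchronous API.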
struct LlmInferenceEngineCpu_Session {
  const LlmInferenceEngineCpu_Engine* engine;
  std::string prompt;
  int max_num_output_tokens;
  int response_count;
  std::string last_10_char;
  std::string final_output;
  std::function<void(std::string)> cpu_callback;
  bool early_stop;
  pthread_t work_id;
  ~LlmInferenceEngineCpu_Session() {
    // Only join if a worker thread was actually started; PredictSync resets
    // `work_id` to 0 after joining, and joining twice is undefined behavior.
    if (work_id != 0) {
      pthread_join(work_id, nullptr);
    }
  }
};

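// Decode loop: emits one token per iteration, holding back the last
// kCheckLastKChars characters so a partially generated stop token is never
// streamed to the callback. Runs on the worker thread started in
// LlmInferenceEngine_Session_PredictAsync.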
void* next_token_function(void* args) {
  auto* cpu_session = static_cast<LlmInferenceEngineCpu_Session*>(args);
  // Iterate rather than recurse per token so long generations cannot exhaust
  // the worker thread's stack.
  while (cpu_session->response_count++ < cpu_session->max_num_output_tokens) {
    if (cpu_session->early_stop) {
      break;
    }

    std::vector<int> token_ids_per_step;
    auto status = cpu_session->engine->llm->GetNextToken(&token_ids_per_step);
    if (!status.ok()) {
      ABSL_LOG(FATAL) << "Failed to generate output: " << status;
    }

    // For future multithreading support.
    if (cpu_session->early_stop) {
      break;
    }

    if (cpu_session->response_count == cpu_session->max_num_output_tokens) {
      cpu_session->early_stop = true;
    }

    std::string token =
        cpu_session->engine->tokenizer->IdToPiece(token_ids_per_step[0]);
    if (cpu_session->engine->normalizer != nullptr) {
      token = cpu_session->engine->normalizer->Normalize(token);
    }
    cpu_session->last_10_char.append(token);

    for (const auto& stop_token : cpu_session->engine->stop_tokens) {
      // find() returns std::string::npos as a size_t, so the index must not
      // be narrowed to an int.
      size_t stop_index = cpu_session->last_10_char.find(stop_token);
      if (stop_index != std::string::npos) {
        cpu_session->early_stop = true;
        cpu_session->last_10_char =
            cpu_session->last_10_char.substr(0, stop_index);
        break;
      }
    }

    std::string ready_char;
    if (cpu_session->early_stop) {
      ready_char = cpu_session->last_10_char;
    } else if (cpu_session->last_10_char.size() > kCheckLastKChars) {
      // Everything but the last kCheckLastKChars characters cannot be part of
      // a pending stop token and is safe to release to the caller.
      ready_char = cpu_session->last_10_char.substr(
          0, cpu_session->last_10_char.size() - kCheckLastKChars);
      cpu_session->last_10_char = cpu_session->last_10_char.substr(
          cpu_session->last_10_char.size() - kCheckLastKChars);
    }
    cpu_session->final_output.append(ready_char);

    cpu_session->cpu_callback(ready_char);
  }
  return nullptr;
}

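// Worker-thread entry point: encodes the prompt, prepends the start token,
// primes the model with the prompt ids, then runs the decode loop above.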
void* start_llm_function(void* args) {
  auto* cpu_session = static_cast<LlmInferenceEngineCpu_Session*>(args);

  std::vector<int> prompt_ids = {};

  auto status =
      cpu_session->engine->tokenizer->Encode(cpu_session->prompt, &prompt_ids);

  if (!status.ok()) {
    ABSL_LOG(FATAL) << "Failed to encode input: " << status;
  }
  prompt_ids.insert(prompt_ids.begin(), cpu_session->engine->start_token_id);

  ABSL_CHECK_OK(cpu_session->engine->llm->SeekTimeStep(0));
  ABSL_CHECK_OK(cpu_session->engine->llm->AddInputTokens({prompt_ids}));

  // The budget for generated tokens is whatever remains of the context window
  // after the prompt; clamp at zero so an over-long prompt cannot underflow
  // the unsigned arithmetic.
  cpu_session->max_num_output_tokens =
      prompt_ids.size() < cpu_session->engine->max_num_tokens
          ? static_cast<int>(cpu_session->engine->max_num_tokens -
                             prompt_ids.size())
          : 0;

  next_token_function(args);

  return nullptr;
}

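// Builds a CPU engine: maps the model file, checks that its metadata names
// the "cpu" backend, constructs the XNNPACK llm plus its SentencePiece
// tokenizer, and returns the bundle with ownership passed to the caller.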
absl::StatusOr<LlmInferenceEngine_Engine*>
LlmInferenceEngine_CreateEngine_Helper(const LlmModelSettings* model_settings) {
  MP_ASSIGN_OR_RETURN(auto model_file,
                      mediapipe::tasks::genai::llm_utils::ScopedFile::Open(
                          model_settings->model_path));
  MP_ASSIGN_OR_RETURN(auto model_data,
                      mediapipe::tasks::genai::llm_utils::ModelData::Create(
                          std::move(model_file)));

  if (model_settings->number_of_supported_lora_ranks != 0) {
    ABSL_LOG(FATAL) << "LoRA on CPU is not supported yet.";
  }

  auto llm_params_proto = model_data->GetLlmParameters();
  auto llm_params =
      mediapipe::tasks::genai::xnn_utils::LlmParams::FromLLMParametersProto(
          llm_params_proto);

  auto model_type = model_data->GetModelType();
  RET_CHECK(model_type) << "Failed to get model type.";

  MP_ASSIGN_OR_RETURN(auto backend,
                      model_data->ReadMetadata(
                          mediapipe::tasks::genai::llm_utils::kLlmBackendName));
  RET_CHECK_EQ(backend, "cpu");

  // Create directory for tokenizer and model cache file.
  if (model_settings->cache_dir != nullptr) {
    auto s = mediapipe::file::RecursivelyCreateDir(model_settings->cache_dir);
    if (!s.ok()) {
      ABSL_LOG(WARNING) << s;
    }
  }

  MP_ASSIGN_OR_RETURN(auto spm_model_content,
                      model_data->ReadMetadata("spm_vocab_model"));

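  // All needed metadata has been extracted, so release the mapped model data;
  // the weight loader below reopens the file by path.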
  model_data.reset();

  llm_params.seq_size_T = model_settings->max_num_tokens;
  llm_params.cache_dir = model_settings->cache_dir;

  auto weight_loader = std::make_unique<
      mediapipe::tasks::genai::xnn_utils::DefaultLlmWeightsLoader>(
      model_settings->model_path, llm_params);

  auto runtime_configs =
      std::make_unique<mediapipe::tasks::genai::xnn_utils::RuntimeConfigs>();

  MP_ASSIGN_OR_RETURN(
      auto builder,
      mediapipe::tasks::genai::xnn_utils::CreateLlmBuilder(
          llm_params, std::move(runtime_configs), nullptr, *model_type));

  MP_ASSIGN_OR_RETURN(auto llm,
                      mediapipe::tasks::genai::xnn_utils::Llm::CreateLlm(
                          std::move(weight_loader), std::move(builder)));

  auto tokenizer = std::make_unique<sentencepiece::SentencePieceProcessor>();
  MP_RETURN_IF_ERROR(tokenizer->LoadFromSerializedProto(spm_model_content));

  // Smoke-test the tokenizer so a broken vocabulary fails engine creation
  // here instead of crashing later on the worker thread.
  std::vector<int> prompt_ids;
  MP_RETURN_IF_ERROR(tokenizer->Encode("hello", &prompt_ids));

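  // Models that carry a denormalizer spec with a precompiled charsmap need
  // each decoded piece run through a Normalizer before it is surfaced.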
  std::unique_ptr<sentencepiece::normalizer::Normalizer> normalizer;
  if (tokenizer->model_proto().has_denormalizer_spec() &&
      tokenizer->model_proto().denormalizer_spec().has_precompiled_charsmap() &&
      !tokenizer->model_proto()
           .denormalizer_spec()
           .precompiled_charsmap()
           .empty()) {
    normalizer = std::make_unique<sentencepiece::normalizer::Normalizer>(
        tokenizer->model_proto().denormalizer_spec());
  }

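  // Hand the parts over to the engine; its destructor deletes them.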
  std::unique_ptr<LlmInferenceEngineCpu_Engine> engine(
      new LlmInferenceEngineCpu_Engine{
          .tokenizer = tokenizer.release(),
          .normalizer = normalizer.release(),
          .llm = llm.release(),
          .start_token_id = llm_params_proto.start_token_id(),
          .stop_tokens =
              std::vector<std::string>(llm_params_proto.stop_tokens().begin(),
                                       llm_params_proto.stop_tokens().end()),
          .max_num_tokens = model_settings->max_num_tokens,
      });

  return engine.release();
}

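// Sessions only reference the shared engine and carry per-conversation
// decoding state; `session_config` is currently unused on the CPU path.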
absl::StatusOr<LlmInferenceEngine_Session*>
LlmInferenceEngine_CreateSession_Helper(
    const LlmInferenceEngineCpu_Engine* engine,
    const LlmSessionConfig* session_config) {
  std::unique_ptr<LlmInferenceEngineCpu_Session> session(
      new LlmInferenceEngineCpu_Session{.engine = engine});

  return session.release();
}

}  // namespace

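// Frees the strings and the array handed out in an LlmResponseContext.
// Callers must invoke this exactly once per context they receive.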
void LlmInferenceEngine_CloseResponseContext(
    LlmResponseContext* response_context) {
  for (size_t i = 0; i < response_context->response_count; i++) {
    free(const_cast<char*>(response_context->response_array[i]));
  }
  free(response_context->response_array);
  response_context->response_array = nullptr;
  response_context->response_count = 0;
}

int LlmInferenceEngine_CreateEngine(const LlmModelSettings* model_settings,
                                    LlmInferenceEngine_Engine** engine_out,
                                    char** error_msg) {
  auto engine = LlmInferenceEngine_CreateEngine_Helper(model_settings);
  if (!engine.ok()) {
    if (error_msg) {
      *error_msg = strdup(
          absl::StrCat("Failed to create engine: ", engine.status().ToString())
              .c_str());
    }
    return static_cast<int>(engine.status().code());
  }
  *engine_out = engine.value();
  return 0;
}

void LlmInferenceEngine_Engine_Delete(LlmInferenceEngine_Engine* engine) {
  delete reinterpret_cast<LlmInferenceEngineCpu_Engine*>(engine);
}

int LlmInferenceEngine_CreateSession(LlmInferenceEngine_Engine* engine,
                                     const LlmSessionConfig* session_config,
                                     LlmInferenceEngine_Session** session_out,
                                     char** error_msg) {
  auto cpu_engine = reinterpret_cast<LlmInferenceEngineCpu_Engine*>(engine);
  auto session =
      LlmInferenceEngine_CreateSession_Helper(cpu_engine, session_config);
  if (!session.ok()) {
    if (error_msg) {
      *error_msg = strdup(absl::StrCat("Failed to create session: ",
                                       session.status().ToString())
                              .c_str());
    }
    return static_cast<int>(session.status().code());
  }
  *session_out = session.value();
  return 0;
}

void LlmInferenceEngine_Session_Delete(LlmInferenceEngine_Session* session) {
  delete reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
}

int LlmInferenceEngine_Session_AddQueryChunk(
    LlmInferenceEngine_Session* session, const char* input, char** error_msg) {
  auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
  // Accumulate chunks so repeated calls build up a single prompt, as the
  // function name implies.
  cpu_session->prompt += input;
  return 0;
}

LlmResponseContext LlmInferenceEngine_Session_PredictSync(
    LlmInferenceEngine_Session* session) {
  // Drive the async path with a callback that just frees each streamed chunk;
  // the worker accumulates the full text in `final_output`.
  LlmInferenceEngine_Session_PredictAsync(
      session, nullptr,
      [](void* callback_context, LlmResponseContext* response_context) {
        LlmInferenceEngine_CloseResponseContext(response_context);
        delete response_context;
      });

  auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
  pthread_join(cpu_session->work_id, nullptr);
  cpu_session->work_id = 0;
  auto final_output = cpu_session->final_output;

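  // Package the accumulated response as a single-element array in the
  // malloc/free domain expected by LlmInferenceEngine_CloseResponseContext.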
  char** result = (char**)malloc(sizeof(char*) * 1);
  if (result == nullptr) {
    ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
  }

  result[0] = (char*)malloc(sizeof(char) * (final_output.size() + 1));
  if (result[0] == nullptr) {
    ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
  }

  snprintf(result[0], final_output.size() + 1, "%s", final_output.c_str());

  LlmResponseContext response_context = {
      .response_array = result,
      .response_count = 1,
      .done = true,
  };

  return response_context;
}

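// Streams partial responses to `callback` from a worker thread. Each callback
// invocation transfers ownership of a heap-allocated LlmResponseContext (and
// its strings) to the callee.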
void LlmInferenceEngine_Session_PredictAsync(
    LlmInferenceEngine_Session* session, void* callback_context,
    void (*callback)(void* callback_context,
                     LlmResponseContext* response_context)) {
  auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);

  cpu_session->cpu_callback = [=](std::string responses) -> void {
    char** result = (char**)malloc(sizeof(char*) * 1);
    if (result == nullptr) {
      ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
    }

    result[0] = (char*)malloc(sizeof(char) * (responses.size() + 1));
    if (result[0] == nullptr) {
      ABSL_LOG(FATAL) << "Failed to allocate result for cpu session.";
    }

    snprintf(result[0], responses.size() + 1, "%s", responses.c_str());
    auto response_context = std::make_unique<LlmResponseContext>();
    response_context->response_array = result;
    response_context->response_count = 1;
    response_context->done = cpu_session->early_stop;
    callback(callback_context, response_context.release());
  };

  cpu_session->final_output = "";
  cpu_session->last_10_char = "";
  cpu_session->early_stop = false;

  // Spawn the decoding worker; PredictSync or the session destructor joins it
  // through `work_id`.
  cpu_session->work_id = 0;
  int rc = pthread_create(&cpu_session->work_id, nullptr, start_llm_function,
                          cpu_session);
  if (rc != 0) {
    ABSL_LOG(FATAL) << "Failed to create work thread: " << rc;
  }
}

int LlmInferenceEngine_Session_Clone(
    LlmInferenceEngine_Session* session,
    LlmInferenceEngine_Session** cloned_session, char** error_msg) {
  *error_msg = strdup("Not implemented");
  return 12;
}

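// Returns the number of tokens `input` encodes to under the engine's
// tokenizer, or -1 with `error_msg` set if encoding fails.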
int LlmInferenceEngine_Session_SizeInTokens(LlmInferenceEngine_Session* session,
                                            const char* input,
                                            char** error_msg) {
  auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);
  std::vector<int> output_ids;
  auto status = cpu_session->engine->tokenizer->Encode(input, &output_ids);
  if (!status.ok()) {
    if (error_msg) {
      *error_msg = strdup(status.ToString().c_str());
    }
    return -1;
  }
  return output_ids.size();
}