chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_C_LLM_INFERENCE_ENGINE_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_C_LLM_INFERENCE_ENGINE_H_

#ifdef __cplusplus
#include <cstddef>
#else
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#endif

#ifndef ODML_EXPORT
#define ODML_EXPORT __attribute__((visibility("default")))
#endif  // ODML_EXPORT

#ifdef __cplusplus
extern "C" {
#endif

// Opaque handle to an LLM inference engine. Multiple sessions can be created
// from a single engine.
typedef void LlmInferenceEngine_Engine;

// Opaque handle to a single inference session created from an engine.
typedef void LlmInferenceEngine_Session;

// LlmActivationDataType defines the activation data type for the model.
typedef enum {
  // Use Default activation data type mentioned in the model metadata file.
  kLlmActivationDataTypeDefault = 0,

  // Use Float32 activation data type.
  kLlmActivationDataTypeFloat32 = 1,

  // Use Float16 activation data type.
  kLlmActivationDataTypeFloat16 = 2,

  // Use Int16 activation data type.
  kLlmActivationDataTypeInt16 = 3,

  // Use Int8 activation data type.
  kLlmActivationDataTypeInt8 = 4,
} LlmActivationDataType;

// LlmModelSettings configures the model and the engine created from it.
typedef struct {
  // Path to the model artifact.
  const char* model_path;

  // Directory path for storing the model-related tokenizer and cache weights.
  // The user is responsible for providing a directory that the program can
  // write to.
  const char* cache_dir;

  // Maximum number of tokens for input and output.
  size_t max_num_tokens;

  // Number of decode steps per sync. Used by GPU only. The default value is 3.
  size_t num_decode_steps_per_sync;

  // Sequence batch size for encoding. Used by GPU only. Number of input tokens
  // to process at a time for batch processing. Setting this value to 1 means
  // both the encoding and decoding share the same graph of sequence length
  // of 1. Setting this value to 0 means the batch size will be optimized
  // programmatically.
  size_t sequence_batch_size;

  // Number of supported LoRA ranks for the base model. Used by GPU only.
  size_t number_of_supported_lora_ranks;

  // The supported LoRA ranks for the base model. Used by GPU only.
  size_t* supported_lora_ranks;

  // Maximum Top-K value supported across all sessions created with this
  // engine. Used by GPU only. Creating a session with a Top-K value larger
  // than this is rejected with an error. If not provided, the maximum Top-K
  // defaults to 1, which means only greedy decoding is supported for any
  // session created with this engine.
  size_t max_top_k;

  // Optional setting for specific activation data type.
  LlmActivationDataType llm_activation_data_type;

  // Optional setting for the number of draft tokens to generate when using
  // speculative decoding. Setting to 0 will disable speculative decoding.
  size_t num_draft_tokens;
} LlmModelSettings;
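
// Example (illustrative sketch, not part of the API): populating the model
// settings before creating an engine. The paths and numeric values below are
// placeholder assumptions chosen for demonstration only.
//
//   LlmModelSettings model_settings = {0};
//   model_settings.model_path = "/tmp/model.task";  // placeholder path
//   model_settings.cache_dir = "/tmp/llm_cache";    // placeholder path
//   model_settings.max_num_tokens = 512;
//   model_settings.max_top_k = 40;
//   model_settings.llm_activation_data_type = kLlmActivationDataTypeDefault;
//   model_settings.num_draft_tokens = 0;  // Speculative decoding disabled.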

// LlmSessionConfig configures how to execute the model.
typedef struct {
  // Top K number of tokens to be sampled from for each decoding step.
  size_t topk;

  // Maximum cumulative probability over the tokens to sample from in each
  // decoding step for top-p / nucleus sampling.
  float topp;

  // Randomness when decoding the next token; 0.0f means greedy decoding.
  float temperature;

  // Random seed for reproducible sampling.
  size_t random_seed;

  // Path to the LoRA tflite flatbuffer file. Optional.
  // This is only compatible with GPU models.
  const char* lora_path;
} LlmSessionConfig;
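
// Example (illustrative sketch): a typical sampling configuration. The values
// below are placeholder assumptions, not defaults defined by this header.
//
//   LlmSessionConfig session_config = {0};
//   session_config.topk = 40;
//   session_config.topp = 0.95f;
//   session_config.temperature = 0.8f;
//   session_config.random_seed = 1;
//   session_config.lora_path = NULL;  // No LoRA weights applied.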

// LlmResponseContext is the return type for
// LlmInferenceEngine_Session_PredictSync.
typedef struct {
  // An array of strings. The size of the array depends on the number of
  // responses.
  char** response_array;

  // Number of responses.
  int response_count;

  // Whether all outputs for this session have been generated.
  bool done;
} LlmResponseContext;

// Frees all context within the LlmResponseContext.
ODML_EXPORT void LlmInferenceEngine_CloseResponseContext(
    LlmResponseContext* response_context);

// Creates an LlmInferenceEngine engine from the given model settings.
ODML_EXPORT int LlmInferenceEngine_CreateEngine(
    const LlmModelSettings* model_settings,
    LlmInferenceEngine_Engine** engine_out, char** error_msg);

// Frees the engine; this will wait until the graph is done executing.
ODML_EXPORT void LlmInferenceEngine_Engine_Delete(
    LlmInferenceEngine_Engine* engine);
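
// Example (illustrative sketch): creating and destroying an engine. This
// assumes a return value of 0 indicates success and that `error_msg`, when
// populated, must be released by the caller; both are assumptions rather than
// guarantees documented in this header.
//
//   LlmInferenceEngine_Engine* engine = NULL;
//   char* error_msg = NULL;
//   if (LlmInferenceEngine_CreateEngine(&model_settings, &engine,
//                                       &error_msg) != 0) {
//     // Inspect and free `error_msg` here.
//   }
//   // ... create sessions and run inference ...
//   LlmInferenceEngine_Engine_Delete(engine);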

// Create a LlmInferenceEngine session for executing a query.
ODML_EXPORT int LlmInferenceEngine_CreateSession(
    LlmInferenceEngine_Engine* engine, const LlmSessionConfig* session_config,
    LlmInferenceEngine_Session** session_out, char** error_msg);

// Frees the session; this will wait until the graph is done executing.
ODML_EXPORT void LlmInferenceEngine_Session_Delete(
    LlmInferenceEngine_Session* session);
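
// Example (illustrative sketch): creating a session from an existing engine,
// again assuming a return value of 0 indicates success.
//
//   LlmInferenceEngine_Session* session = NULL;
//   if (LlmInferenceEngine_CreateSession(engine, &session_config, &session,
//                                        &error_msg) != 0) {
//     // Inspect and free `error_msg` here.
//   }
//   // ... add query chunks and predict ...
//   LlmInferenceEngine_Session_Delete(session);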

// Add query chunk to the session. This can be called multiple times to add
// multiple query chunks before calling `PredictSync` or `PredictAsync`. The
// query chunks will be processed in the order they are added, similar to a
// concatenated prompt, but able to be processed in chunks.
ODML_EXPORT int LlmInferenceEngine_Session_AddQueryChunk(
    LlmInferenceEngine_Session* session, const char* input, char** error_msg);
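
// Example (illustrative sketch): splitting a prompt into chunks. The chunk
// contents are placeholders; error handling follows the same assumed
// zero-on-success convention as above.
//
//   LlmInferenceEngine_Session_AddQueryChunk(
//       session, "You are a helpful assistant.\n", &error_msg);
//   LlmInferenceEngine_Session_AddQueryChunk(
//       session, "Summarize the following article.\n", &error_msg);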

// Return the generated output based on the previously added query chunks in
// sync mode.
ODML_EXPORT LlmResponseContext
LlmInferenceEngine_Session_PredictSync(LlmInferenceEngine_Session* session);
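
// Example (illustrative sketch): running a synchronous prediction and freeing
// the result. Printing via printf (from <stdio.h>) is for demonstration only.
//
//   LlmResponseContext response =
//       LlmInferenceEngine_Session_PredictSync(session);
//   for (int i = 0; i < response.response_count; ++i) {
//     printf("%s\n", response.response_array[i]);
//   }
//   LlmInferenceEngine_CloseResponseContext(&response);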

// Runs the callback function in async mode.
// The callback will be invoked multiple times until `response_context.done`
// is `true`. You need to invoke `LlmInferenceEngine_CloseResponseContext` after
// each invocation to free memory.
// The callback context can be a pointer to any user defined data structure as
// it is passed to the callback unmodified.
ODML_EXPORT void LlmInferenceEngine_Session_PredictAsync(
    LlmInferenceEngine_Session* session, void* callback_context,
    void (*callback)(void* callback_context,
                     LlmResponseContext* response_context));
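
// Example (illustrative sketch): an async callback that streams partial
// results. `MyState` is a hypothetical user-defined type passed through as
// the callback context; it is not part of this API.
//
//   static void OnResponse(void* callback_context,
//                          LlmResponseContext* response_context) {
//     /* MyState* state = (MyState*)callback_context; */
//     for (int i = 0; i < response_context->response_count; ++i) {
//       printf("%s", response_context->response_array[i]);  // needs <stdio.h>
//     }
//     bool done = response_context->done;
//     LlmInferenceEngine_CloseResponseContext(response_context);
//     if (done) {
//       /* Signal completion, e.g. through `state`. */
//     }
//   }
//
//   LlmInferenceEngine_Session_PredictAsync(
//       session, /*callback_context=*/NULL, OnResponse);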

// Clone the provided session.
ODML_EXPORT int LlmInferenceEngine_Session_Clone(
    LlmInferenceEngine_Session* session,
    LlmInferenceEngine_Session** cloned_session, char** error_msg);

// Tokenizes an input prompt using a pre-existing processor and returns its
// length in tokens. Returns -1 if tokenization fails.
ODML_EXPORT int LlmInferenceEngine_Session_SizeInTokens(
    LlmInferenceEngine_Session* session, const char* input, char** error_msg);
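
// Example (illustrative sketch): checking a prompt's token count against the
// engine's configured `max_num_tokens` before adding it to the session.
// `prompt` stands for an arbitrary NUL-terminated C string, and
// `model_settings` refers to the placeholder settings shown earlier.
//
//   int size = LlmInferenceEngine_Session_SizeInTokens(session, prompt,
//                                                      &error_msg);
//   if (size < 0) {
//     // Tokenization failed; inspect `error_msg`.
//   } else if ((size_t)size > model_settings.max_num_tokens) {
//     // Prompt is too long for the configured context window.
//   }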

#ifdef __cplusplus
}  // extern C
#endif

#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_C_LLM_INFERENCE_ENGINE_H_