// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
module on_device_model.mojom;
import "mojo/public/mojom/base/file.mojom";
import "mojo/public/mojom/base/file_path.mojom";
import "mojo/public/mojom/base/read_only_file.mojom";
// Bundle of opened file resources that together define an adaptation.
[Stable]
struct AdaptationAssets {
  // The model weights may be supplied as an already-opened file, as a file
  // path, or as both. The backend type decides which form is used, or which
  // one is preferred when both are supplied:
  //   - APU backend: `weights_path` should be used.
  //   - GPU backend: `weights` should be used.
  // When neither is set, the operation should usually fail.
  // TODO(b/313919363): This should also be a ReadOnlyFile.
  mojo_base.mojom.File? weights;
  [MinVersion=1] mojo_base.mojom.FilePath? weights_path;
};
// Outcome of an attempt to detect the language of output text.
[Stable]
struct LanguageDetectionResult {
  // Code of the language that was detected. When detection was indeterminate
  // this is "und", per ISO 639-2.
  string code;

  // How reliable this result is, expressed in the range [0, 1].
  float reliability;
};
// Aggregate of the results of evaluating text for safety.
[Stable]
struct SafetyInfo {
  // One probability per safety class, each independent of the others and in
  // the range [0, 1].
  array<float> class_scores;

  // Language detection details. Set if and only if the safety config is
  // restricted by language.
  LanguageDetectionResult? language;
};
// One partial piece of a response, delivered through
// StreamingResponder.OnResponse().
[Stable]
struct ResponseChunk {
  // The text carried by this chunk of the response.
  string text;

  // Safety information, if any, evaluated over the whole response produced so
  // far, `text` included.
  SafetyInfo? safety_info;
};
// Describes a complete response after a StreamingResponder has finished
// streaming it.
[Stable]
struct ResponseSummary {
  // Safety information, if any, evaluated over the full response.
  SafetyInfo? safety_info;
};
// Receives the streamed response produced by a call to execute a model.
// Closing this pipe cancels the call to |Execute()|.
[Stable]
interface StreamingResponder {
  // Invoked every time another chunk of text becomes available.
  OnResponse@0(ResponseChunk chunk);

  // Invoked once, after all text for the query has been returned. After
  // OnComplete() no other method on this interface will be called. `summary`
  // carries metadata about the response that was streamed.
  OnComplete@1(ResponseSummary summary);
};
// Tells the caller when the model has finished processing context. Closing
// this pipe cancels the call to |AddContext()|.
[Stable]
interface ContextClient {
  // Invoked once context processing has finished; `tokens_processed` is the
  // number of tokens that were processed.
  OnComplete@0(uint32 tokens_processed);
};
// Parameters describing which adaptation should be loaded.
[Stable]
struct LoadAdaptationParams {
  // The assets that make up the adaptation.
  AdaptationAssets assets;
};
// Options that govern how the model treats a piece of input.
[Stable]
struct InputOptions {
  // Text of this input.
  string text;

  // Upper bound on the number of tokens to process. When unset, every token
  // of this input is processed.
  uint32? max_tokens;

  // Offset into the token vector (produced by tokenizing the text) at which
  // processing starts. When unset, processing starts at the first token.
  uint32? token_offset;

  // When true, this is a one-off call that wants the context to be ignored
  // for this input. This is less efficient than running on top of the current
  // context, so use it only when necessary.
  bool ignore_context;

  // Upper bound on the number of tokens output by a call to Execute(). When
  // unset, tokens are output until an end token or the maximum sequence
  // length is reached.
  uint32? max_output_tokens;

  // DEPRECATED: No longer does anything.
  uint32? unused_safety_interval;

  // The two params below control output sampling. A higher `top_k` means more
  // tokens are considered; a higher `temperature` means less likely tokens
  // are more probable.

  // `top_k` should lie between 1 and the max top K value the model was
  // initialized with.
  uint32? top_k;

  // `temperature` should be greater than 0.0. Values above 1.0 may give poor
  // results.
  float? temperature;
};
// A model session: context can be added to it, and an input can then be
// executed with that context.
[Stable]
interface Session {
  // Appends context to this session. Context added here builds on top of the
  // context from earlier calls to |AddContext()|.
  AddContext@0(
      InputOptions input,
      pending_remote<ContextClient>? client);

  // Runs the model on `input`, which is added on top of the context supplied
  // through |AddContext()|. The response is streamed to |response|; close the
  // |response| pipe to cancel the request.
  Execute@1(
      InputOptions input,
      pending_remote<StreamingResponder> response);

  // Reports the size of `text` in tokens. Returns 0 if the text is empty or
  // an error occurred.
  GetSizeInTokens@2(string text) => (uint32 size);

  // Reports the probability score of the first token in `text` on top of the
  // current context.
  Score@3(string text) => (float probability);

  // Clones this session. The clone has the same context as the current
  // session.
  [MinVersion=1] Clone@4(pending_receiver<Session> session);
};
// A model that has been loaded and can be queried. This interface must be
// controlled by the browser, and consumers must take care to sanitize inputs.
[Stable]
interface OnDeviceModel {
  // Begins a session with this model.
  //
  // Without multiple sessions support enabled for this model: if a session
  // starts before the previous one has completed, the previous session is
  // canceled, its mojo connection is closed, and its ongoing
  // StreamingResponder & ContextClient connections are closed as well.
  //
  // With multiple sessions support enabled: nothing is disconnected, and the
  // further requests are queued.
  StartSession@0(pending_receiver<Session> session);

  // Infers multiclass safety scores for `text` with this model's underlying
  // safety classifier, if it has one. Returns null when classification fails
  // or when no classifier is available.
  ClassifyTextSafety@1(string text) => (SafetyInfo? safety_info);

  // Detects the language of `text` with the language classifier. Returns null
  // when no classifier is available.
  DetectLanguage@2(string text) => (LanguageDetectionResult? result);

  // Loads the adaptation described by `params`. The adaptation is always
  // loaded on top of the base model.
  LoadAdaptation@3(LoadAdaptationParams params,
                   pending_receiver<OnDeviceModel> model)
      => (LoadModelResult result);
};
// Buckets the device according to an estimate of how fast it can run a model.
[Stable, Extensible]
enum PerformanceClass {
  // Running the benchmark hit an error. The device is likely unable to run
  // any models.
  [Default] kError,

  // The benchmark could not run because the GPU was blocked.
  kGpuBlocked,

  // The benchmark could not run because the library failed to load.
  kFailedToLoadLibrary,

  // The remaining values place the device into a range of performance
  // buckets.
  kVeryLow,
  kLow,
  kMedium,
  kHigh,
  kVeryHigh,
};
// Result reported for a load request; returned by
// OnDeviceModel.LoadAdaptation().
// NOTE(review): the per-value meanings below are inferred from the value
// names and from the like-named PerformanceClass values -- confirm.
[Stable, Extensible]
enum LoadModelResult {
  // The load succeeded.
  kSuccess,

  // Presumably the GPU was blocked, so the load could not proceed.
  kGpuBlocked,

  // Presumably the library failed to load. As the `[Default]` value of this
  // Extensible enum, this is also the fallback for unrecognized values.
  [Default] kFailedToLoadLibrary,
};