// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
import {
Model,
ModelLoader as ModelLoaderBase,
ModelResponse,
ModelResponseError,
ModelState,
} from '../../core/on_device_model/types.js';
import {signal} from '../../core/reactive/signal.js';
import {
assertExhaustive,
assertExists,
assertNotReached,
} from '../../core/utils/assert.js';
import {shorten} from '../../core/utils/utils.js';
import {
FormatFeature,
LoadModelResult,
ModelState as MojoModelState,
ModelStateMonitorReceiver,
ModelStateType,
OnDeviceModelRemote,
PageHandlerRemote,
ResponseChunk,
ResponseSummary,
SessionRemote,
StreamingResponderCallbackRouter,
} from './types.js';
// The model accepts at most 2048 input tokens, and 3 words roughly equal 4
// tokens. A conservative number of tokens is reserved for the prompt
// template, and the remainder is converted into a word budget.
// TODO: b/336477498 - Get the token limit from server and accurately count the
// token size with the same tokenizer.
// TODO(shik): Make this configurable.
const INPUT_TOKEN_LIMIT = 2048;
const TEMPLATE_TOKEN_RESERVE = 100;
const TOKENS_PER_WORD_GROUP = 4;
const WORDS_PER_WORD_GROUP = 3;
const MAX_CONTENT_WORDS = Math.floor(
  ((INPUT_TOKEN_LIMIT - TEMPLATE_TOKEN_RESERVE) / TOKENS_PER_WORD_GROUP) *
    WORDS_PER_WORD_GROUP,
);
/**
 * Safety score thresholds applied to the request (the formatted prompt).
 *
 * Each key is the id of a safety class and each value is the score at or
 * above which content is treated as unsafe for that class.
 *
 * The safety class that each id corresponds to can be found at
 * //google3/chrome/intelligence/ondevice/data/example_text_safety.txtpb.
 *
 * TODO: b/349723775 - Adjust the threshold to the final one.
 */
const REQUEST_SAFETY_SCORE_THRESHOLDS = new Map<number, number>([
  [4, 0.65],
  [23, 0.65],
]);
/**
 * Safety score thresholds applied to the model response.
 *
 * Keys are safety class ids; values are the score at or above which the
 * response is treated as unsafe for that class.
 */
const RESPONSE_SAFETY_SCORE_THRESHOLDS = new Map<number, number>([
  [0, 0.65], [1, 0.65], [2, 0.65], [3, 0.65],
  [4, 0.65], [5, 0.65], [6, 0.65], [7, 0.65],
  [9, 0.65], [10, 0.65], [11, 0.65],
  [12, 0.8],
  [18, 0.7],
  [20, 0.65],
  [21, 0.8],
  [23, 0.65],
]);
/**
 * Cleans up raw model output for display.
 *
 * Every U+2581 character is replaced by a space, each run of consecutive
 * newlines is collapsed to a single newline, and leading / trailing
 * whitespace is removed from the whole string.
 */
function parseResponse(res: string): string {
  // U+2581 (▁) looks like an underscore but is a different character.
  const withSpaces = res.replace(/\u2581/g, ' ');
  const singleNewlines = withSpaces.replace(/\n{2,}/g, '\n');
  return singleNewlines.trim();
}
/**
 * Base class for models that run through an `OnDeviceModelRemote`.
 *
 * Subclasses implement `execute` and typically delegate to
 * `formatAndExecute`, which formats the prompt, checks the safety of both the
 * prompt and the response, and runs the model.
 */
abstract class OnDeviceModel<T> implements Model<T> {
  constructor(
    private readonly remote: OnDeviceModelRemote,
    private readonly pageRemote: PageHandlerRemote,
    private readonly modelId: string,
  ) {
    // TODO(pihsun): Handle disconnection error
  }

  abstract execute(content: string): Promise<ModelResponse<T>>;

  /**
   * Closes the connection to the underlying model remote.
   */
  close(): void {
    this.remote.$.close();
  }

  /**
   * Formats the prompt with the specified `formatFeature`, runs the prompt
   * through the model, and returns the result.
   *
   * Returns a `GENERAL` error when the prompt can't be formatted, and an
   * `UNSAFE` error when either the formatted prompt or the model output
   * reaches one of the safety score thresholds.
   *
   * The key of the fields of each different model / formatFeature
   * combination can be found in
   * //google3/chromeos/odml_foundations/lib/inference/features/models/.
   */
  protected async formatAndExecute(
    formatFeature: FormatFeature,
    fields: Record<string, string>,
  ): Promise<ModelResponse<string>> {
    const prompt = await this.formatInput(formatFeature, fields);
    if (prompt === null) {
      console.error('formatInput returns null, wrong model?');
      return {kind: 'error', error: ModelResponseError.GENERAL};
    }

    const requestUnsafe = await this.contentIsUnsafe(
      prompt,
      REQUEST_SAFETY_SCORE_THRESHOLDS,
    );
    if (requestUnsafe) {
      return {kind: 'error', error: ModelResponseError.UNSAFE};
    }

    const output = await this.executeRaw(prompt);
    const responseUnsafe = await this.contentIsUnsafe(
      output,
      RESPONSE_SAFETY_SCORE_THRESHOLDS,
    );
    if (responseUnsafe) {
      return {kind: 'error', error: ModelResponseError.UNSAFE};
    }
    return {kind: 'success', result: output};
  }

  /**
   * Asks the page handler to format `fields` into a prompt for this model
   * and the given `feature`, and returns the `result` of that call.
   */
  private async formatInput(
    feature: FormatFeature,
    fields: Record<string, string>,
  ) {
    const formatted = await this.pageRemote.formatModelInput(
      {value: this.modelId},
      feature,
      fields,
    );
    return formatted.result;
  }

  /**
   * Classifies `content` and reports whether any class listed in
   * `thresholds` has a score at or above its threshold.
   *
   * Returns false when no class scores are available. Classes without a
   * score are skipped.
   */
  private async contentIsUnsafe(
    content: string,
    thresholds: Map<number, number>,
  ): Promise<boolean> {
    const info = await this.remote.classifyTextSafety(content);
    const scores = info.safetyInfo?.classScores;
    if (scores === null || scores === undefined) {
      return false;
    }
    for (const [classId, limit] of thresholds) {
      const score = scores[classId];
      if (score === undefined) {
        continue;
      }
      if (score >= limit) {
        return true;
      }
    }
    return false;
  }

  /**
   * Runs `text` through the model in a new session and resolves with the
   * concatenation of all streamed chunks, with leading whitespace removed.
   *
   * The session and the response router are closed once the model reports
   * completion.
   */
  private executeRaw(text: string): Promise<string> {
    const session = new SessionRemote();
    this.remote.startSession(session.$.bindNewPipeAndPassReceiver());

    const router = new StreamingResponderCallbackRouter();
    // TODO(pihsun): Error handling.
    const done = Promise.withResolvers<string>();
    let output = '';

    const chunkListenerId = router.onResponse.addListener(
      (chunk: ResponseChunk) => {
        output += chunk.text;
      },
    );
    const completeListenerId = router.onComplete.addListener(
      (_summary: ResponseSummary) => {
        // Tear everything down before handing the text back.
        router.removeListener(chunkListenerId);
        router.removeListener(completeListenerId);
        router.$.close();
        session.$.close();
        done.resolve(output.trimStart());
      },
    );

    session.execute(
      {
        text,
        ignoreContext: false,
        maxTokens: null,
        tokenOffset: null,
        maxOutputTokens: null,
        unusedSafetyInterval: null,
        topK: 1,
        temperature: 0,
      },
      router.$.bindNewPipeAndPassRemote(),
    );
    return done.promise;
  }
}
/**
 * Model that produces a summary of a transcription.
 */
export class SummaryModel extends OnDeviceModel<string> {
  /**
   * Shortens `content` to at most `MAX_CONTENT_WORDS` words, runs it through
   * the audio summary feature and returns the cleaned-up summary. Errors from
   * the underlying execution are returned unchanged.
   */
  override async execute(content: string): Promise<ModelResponse<string>> {
    const transcription = shorten(content, MAX_CONTENT_WORDS);
    const resp = await this.formatAndExecute(FormatFeature.kAudioSummary, {
      transcription,
    });
    // TODO(pihsun): `Result` monadic helper class?
    if (resp.kind !== 'error') {
      return {kind: 'success', result: parseResponse(resp.result)};
    }
    return resp;
  }
}
/**
 * Model that suggests titles for a transcription.
 */
export class TitleSuggestionModel extends OnDeviceModel<string[]> {
  /**
   * Shortens `content` to at most `MAX_CONTENT_WORDS` words, runs it through
   * the audio title feature and extracts up to three suggested titles from
   * the bulleted response. Errors from the underlying execution are returned
   * unchanged.
   */
  override async execute(content: string): Promise<ModelResponse<string[]>> {
    const transcription = shorten(content, MAX_CONTENT_WORDS);
    const resp = await this.formatAndExecute(FormatFeature.kAudioTitle, {
      transcription,
    });
    if (resp.kind === 'error') {
      return resp;
    }

    // Each line should start with `- ` and the title.
    const lineStart = '- ';
    const maxTitles = 3;
    const titles: string[] = [];
    for (const rawLine of parseResponse(resp.result).split('\n')) {
      // `parseResponse` only trims the response as a whole, so individual
      // lines can still carry leading whitespace (e.g. from a replaced
      // U+2581) or a trailing `\r`. Trim per line so those bullets are not
      // silently dropped.
      const line = rawLine.trim();
      if (!line.startsWith(lineStart)) {
        continue;
      }
      const title = line.substring(lineStart.length).trim();
      // A bullet without any text is not a usable suggestion.
      if (title !== '') {
        titles.push(title);
      }
    }
    return {kind: 'success', result: titles.slice(0, maxTitles)};
  }
}
/**
 * Maps the mojo `ModelState` onto the app-level `ModelState` union.
 *
 * For the installing state the mojo `progress` field is asserted to exist and
 * forwarded. `MIN_VALUE` / `MAX_VALUE` are never expected as a state type and
 * are reported through `assertNotReached`.
 */
export function mojoModelStateToModelState(state: MojoModelState): ModelState {
  const {type} = state;
  switch (type) {
    case ModelStateType.kNotInstalled: {
      return {kind: 'notInstalled'};
    }
    case ModelStateType.kInstalling: {
      const progress = assertExists(state.progress);
      return {kind: 'installing', progress};
    }
    case ModelStateType.kInstalled: {
      return {kind: 'installed'};
    }
    case ModelStateType.kError: {
      return {kind: 'error'};
    }
    case ModelStateType.kUnavailable: {
      return {kind: 'unavailable'};
    }
    case ModelStateType.MIN_VALUE:
    case ModelStateType.MAX_VALUE: {
      return assertNotReached(
        `Got MIN_VALUE or MAX_VALUE from mojo ModelStateType: ${type}`,
      );
    }
    default: {
      // Compile-time check that every state type is handled above.
      assertExhaustive(type);
    }
  }
}
/**
 * Base loader for on-device models.
 *
 * Tracks the installation state of the model identified by `modelId` through
 * the page handler, and loads the model on request.
 */
abstract class ModelLoader<T> extends ModelLoaderBase<T> {
  /** Current model state; starts as unavailable until `init` reports one. */
  override state = signal<ModelState>({kind: 'unavailable'});

  /** Identifier of the model passed to the page handler. */
  protected abstract readonly modelId: string;

  /** Wraps a loaded model remote into the concrete model class. */
  abstract createModel(remote: OnDeviceModelRemote): OnDeviceModel<T>;

  constructor(protected readonly remote: PageHandlerRemote) {
    super();
  }

  /**
   * Registers a state monitor for the model and stores the initial state
   * returned by the page handler in `state`.
   */
  async init(): Promise<void> {
    const update = (state: MojoModelState) => {
      this.state.value = mojoModelStateToModelState(state);
    };
    const monitor = new ModelStateMonitorReceiver({update});

    // This should be relatively quick since in recorder_app_ui.cc we just
    // return the cached state here, but we await here to avoid UI showing
    // temporary unavailable state.
    const {state} = await this.remote.addModelMonitor(
      {value: this.modelId},
      monitor.$.bindNewPipeAndPassRemote(),
    );
    update(state);
  }

  /**
   * Loads the model and returns it wrapped by `createModel`.
   *
   * @throws Error when the page handler reports a result other than
   *     `LoadModelResult.kSuccess`. Any rejection from the `loadModel` call
   *     is rethrown unchanged.
   */
  override async load(): Promise<Model<T>> {
    const newModel = new OnDeviceModelRemote();
    // Bind before entering the try block so that the cleanup below only ever
    // closes a bound remote.
    const receiver = newModel.$.bindNewPipeAndPassReceiver();
    try {
      const {result} = await this.remote.loadModel(
        {value: this.modelId},
        receiver,
      );
      if (result !== LoadModelResult.kSuccess) {
        // TODO(pihsun): Dedicated error type?
        throw new Error(`Load model failed: ${result}`);
      }
    } catch (e) {
      // Nobody will use the remote after a failed load, so close our end of
      // the pipe instead of leaving it open.
      newModel.$.close();
      throw e;
    }
    return this.createModel(newModel);
  }
}
/**
 * Loader for `SummaryModel`.
 */
export class SummaryModelLoader extends ModelLoader<string> {
  protected override modelId = '73caa678-45cb-4007-abb9-f04e431376da';

  /**
   * Builds a `SummaryModel` from the loaded model remote, this loader's page
   * handler remote and its model id.
   */
  override createModel(remote: OnDeviceModelRemote): SummaryModel {
    const {remote: pageRemote, modelId} = this;
    return new SummaryModel(remote, pageRemote, modelId);
  }
}
/**
 * Loader for `TitleSuggestionModel`.
 */
export class TitleSuggestionModelLoader extends ModelLoader<string[]> {
  protected override modelId = '1bdd5282-2d14-413c-bf43-9ea6d55c38a6';

  /**
   * Builds a `TitleSuggestionModel` from the loaded model remote, this
   * loader's page handler remote and its model id.
   */
  override createModel(remote: OnDeviceModelRemote): TitleSuggestionModel {
    const {remote: pageRemote, modelId} = this;
    return new TitleSuggestionModel(remote, pageRemote, modelId);
  }
}