// chromium/chrome/browser/resources/on_device_internals/tools.ts

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import '//resources/cr_elements/cr_button/cr_button.js';
import '//resources/cr_elements/cr_collapse/cr_collapse.js';
import '//resources/cr_elements/cr_hidden_style.css.js';
import '//resources/cr_elements/cr_input/cr_input.js';
import '//resources/cr_elements/cr_shared_vars.css.js';
import '//resources/cr_elements/cr_textarea/cr_textarea.js';
import '//resources/cr_elements/cr_expand_button/cr_expand_button.js';

import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js';
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';

import {BrowserProxy} from './browser_proxy.js';
import type {ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter} from './on_device_model.mojom-webui.js';
import {getTemplate} from './tools.html.js';

/**
 * One prompt/response exchange shown on the page. The element keeps the
 * in-flight exchange in `currentResponse_` and finished ones in `responses_`.
 */
interface Response {
  // The prompt text that was submitted for execution.
  text: string;
  // The model output accumulated so far from streamed response chunks.
  response: string;
  // Always set to 'response' by this file; presumably a CSS class consumed by
  // the template -- TODO confirm against tools.html.
  responseClass: string;
  // Always initialized to false in this file and never changed here.
  retracted: boolean;
  // Set to true when the service crashes while this exchange is in flight.
  error: boolean;
}

/**
 * Merged into the class of the same name below so that `this.$` is typed.
 * Each key is expected to match the id of a <cr-input> in the element's
 * template (the template itself lives in tools.html and is not visible here).
 */
interface OnDeviceInternalsToolsElement {
  $: {
    // Input holding the model path that is read when loading a model.
    modelInput: CrInputElement,
    // Input validated before execution; its inner <input> gets step="0.1".
    temperatureInput: CrInputElement,
    // Prompt input; receives focus after a model loads or a response ends.
    textInput: CrInputElement,
    // Input validated before execution.
    topKInput: CrInputElement,
  };
}

/**
 * Converts a PerformanceClass value into the label displayed on the page.
 *
 * @param performanceClass The performance class reported by the browser.
 * @returns The matching label, or 'Error' for any value without one.
 */
function getPerformanceClassText(performanceClass: PerformanceClass): string {
  // Built on every call (rather than at module scope) so the enum is only
  // touched when the function actually runs.
  const labelByClass = new Map([
    [PerformanceClass.kVeryLow, 'Very Low'],
    [PerformanceClass.kLow, 'Low'],
    [PerformanceClass.kMedium, 'Medium'],
    [PerformanceClass.kHigh, 'High'],
    [PerformanceClass.kVeryHigh, 'Very High'],
    [PerformanceClass.kGpuBlocked, 'GPU blocked'],
    [PerformanceClass.kFailedToLoadLibrary, 'Failed to load native library'],
  ]);
  const label = labelByClass.get(performanceClass);
  return label ?? 'Error';
}

/**
 * Polymer element for the on-device model internals "tools" page. It loads a
 * model from a user-supplied path, opens a session on it, lets the user add
 * context and execute prompts, and keeps a history of the responses.
 */
class OnDeviceInternalsToolsElement extends PolymerElement {
  static get is() {
    return 'on-device-internals-tools';
  }

  static get template() {
    return getTemplate();
  }

  static get properties() {
    return {
      modelPath_: {
        type: String,
        value: '',
      },
      error_: String,
      text_: String,
      loadModelStart_: {
        type: Number,
        value: 0,
      },
      currentResponse_: {
        type: Object,
        value: null,
      },
      responses_: {
        type: Array,
        value: () => [],
      },
      model_: {
        type: Object,
        value: null,
      },
      performanceClassText_: {
        type: String,
        value: 'Loading...',
      },
      contextExpanded_: Boolean,
      contextLength_: Number,
      contextText_: String,
      topK_: Number,
      temperature_: Number,
    };
  }

  static get observers() {
    return [
      'onModelOrErrorChanged_(model_, error_)',
    ];
  }

  private contextExpanded_: boolean = false;
  // Running count produced by onAddContextClick_; reset for each new session.
  private contextLength_: number = 0;
  private contextText_: string;
  // The exchange currently being streamed, or null when idle.
  private currentResponse_: Response|null;
  private error_: string;
  // Milliseconds the last successful model load took.
  private loadModelDuration_: number;
  // Timestamp (ms) at which the pending load started; 0 means "not loading".
  private loadModelStart_: number;
  private modelPath_: string;
  private model_: OnDeviceModelRemote|null;
  private performanceClassText_: string;
  // Finished exchanges, newest first.
  private responses_: Response[];
  private session_: SessionRemote|null = null;
  private temperature_: number = 0;
  private text_: string;
  private topK_: number = 1;

  private proxy_: BrowserProxy = BrowserProxy.getInstance();
  private responseRouter_: StreamingResponderCallbackRouter =
      new StreamingResponderCallbackRouter();

  /**
   * Kicks off the performance class query and configures the temperature
   * input's inner <input> to step in increments of 0.1.
   */
  override ready() {
    super.ready();
    this.getPerformanceClass_();
    this.$.temperatureInput.inputElement.step = '0.1';
  }

  /**
   * Asks the page handler for the estimated performance class and stores its
   * display label in `performanceClassText_`.
   */
  private async getPerformanceClass_() {
    this.performanceClassText_ = getPerformanceClassText(
        (await this.proxy_.handler.getEstimatedPerformanceClass())
            .performanceClass);
  }

  /**
   * Observer for `model_` and `error_`. Either change ends the "loading"
   * state; when a model is present the load duration is recorded and the
   * prompt input is focused.
   */
  private onModelOrErrorChanged_() {
    if (this.model_ !== null) {
      this.loadModelDuration_ = new Date().getTime() - this.loadModelStart_;
      this.$.textInput.focus();
    }
    this.loadModelStart_ = 0;
  }

  /** Click handler that starts loading the model named in the path input. */
  private onLoadClick_() {
    this.onModelSelected_();
  }

  /**
   * Closes the current streaming-response router and replaces it with a fresh
   * one. This drops every listener registered on the old router, so callbacks
   * belonging to an abandoned response can never fire again.
   */
  private resetResponseRouter_() {
    this.responseRouter_.$.close();
    this.responseRouter_ = new StreamingResponderCallbackRouter();
  }

  /**
   * Connection-error handler for the model remote. Marks any in-flight
   * response as failed, discards all state tied to the dead service and asks
   * the user to reload the model.
   */
  private onServiceCrashed_() {
    if (this.currentResponse_) {
      this.currentResponse_.error = true;
      this.addResponse_();
    }
    // onComplete will never arrive for a response that was cut short, so its
    // listeners would otherwise stay registered and run a second time on the
    // next execution.
    this.resetResponseRouter_();
    // The session belonged to the model that just went away; clearing it makes
    // onAddContextClick_ and onExecute_ bail out until a new one is started.
    this.session_ = null;
    this.error_ = 'Service crashed, please reload the model.';
    this.model_ = null;
    this.modelPath_ = '';
    this.loadModelStart_ = 0;
    this.$.modelInput.focus();
  }

  /**
   * Loads the model at the path currently in the model input. Any previously
   * loaded model is closed first. On success the new model is stored, a crash
   * listener is installed and a new session is started; on failure `error_`
   * is set instead.
   */
  private async onModelSelected_() {
    this.error_ = '';
    if (this.model_) {
      this.model_.$.close();
    }
    this.model_ = null;
    // Must be set after the assignments above: they trigger
    // onModelOrErrorChanged_, which resets loadModelStart_ to 0.
    this.loadModelStart_ = new Date().getTime();
    const modelPath = this.$.modelInput.value;
    // <if expr="is_win">
    // Windows file paths are std::wstring, so use Array<Number>.
    const processedPath = Array.from(modelPath, (c) => c.charCodeAt(0));
    // </if>
    // <if expr="not is_win">
    const processedPath = modelPath;
    // </if>
    const newModel = new OnDeviceModelRemote();
    const {result} = await this.proxy_.handler.loadModel(
        {path: processedPath}, newModel.$.bindNewPipeAndPassReceiver());
    if (result !== LoadModelResult.kSuccess) {
      this.error_ =
          'Unable to load model. Specify a correct and absolute path.';
    } else {
      this.model_ = newModel;
      this.model_.onConnectionError.addListener(() => {
        this.onServiceCrashed_();
      });
      this.startNewSession_();
      this.modelPath_ = modelPath;
    }
  }

  /**
   * Sends the text in `contextText_` to the current session as context,
   * updates the running context length and clears the context field. Does
   * nothing when there is no session or no text to add.
   */
  private onAddContextClick_() {
    if (this.session_ === null) {
      return;
    }
    // contextText_ is declared without a default value, so it may still be
    // undefined here. An empty submission is skipped entirely; otherwise it
    // would be counted as one unit of context below.
    const contextText = this.contextText_ ?? '';
    if (contextText.length === 0) {
      return;
    }
    this.session_.addContext(
        {
          text: contextText,
          ignoreContext: false,
          maxTokens: null,
          tokenOffset: null,
          maxOutputTokens: null,
          unusedSafetyInterval: null,
          topK: null,
          temperature: null,
        },
        null);
    // The capturing group keeps the whitespace runs in the split result, so
    // both words and the separators between them are counted. Presumably a
    // rough token estimate -- TODO confirm.
    this.contextLength_ += contextText.split(/(\s+)/).length;
    this.contextText_ = '';
  }

  /**
   * Opens a fresh session on the loaded model and resets the context length.
   * Does nothing when no model is loaded.
   */
  private startNewSession_() {
    if (this.model_ === null) {
      return;
    }
    this.contextLength_ = 0;
    this.session_ = new SessionRemote();
    this.model_.startSession(this.session_.$.bindNewPipeAndPassReceiver());
  }

  /**
   * Click handler that abandons the in-flight response: the router is
   * replaced so no further chunks arrive, and whatever was received so far is
   * moved into the history.
   */
  private onCancelClick_() {
    this.resetResponseRouter_();
    this.addResponse_();
  }

  /** Click handler that executes the current prompt. */
  private onExecuteClick_() {
    this.onExecute_();
  }

  /**
   * Moves the in-flight response (if any) to the front of `responses_`,
   * clears `currentResponse_` and returns focus to the prompt input.
   */
  private addResponse_() {
    // Guard against being called while idle so that a null entry is never
    // inserted into the history.
    if (this.currentResponse_) {
      this.unshift('responses_', this.currentResponse_);
    }
    this.currentResponse_ = null;
    this.$.textInput.focus();
  }

  /**
   * Executes `text_` on the current session using the top-K and temperature
   * inputs, and streams the output into a new `currentResponse_`. Does
   * nothing when there is no session or either input fails validation.
   */
  private onExecute_() {
    if (this.session_ === null) {
      return;
    }
    if (!this.$.topKInput.validate()) {
      return;
    }
    if (!this.$.temperatureInput.validate()) {
      return;
    }
    // Keep a handle on the router used for this execution so the listeners
    // below are removed from the same object they were added to.
    const router = this.responseRouter_;
    this.session_.execute(
        {
          text: this.text_,
          ignoreContext: false,
          maxTokens: null,
          tokenOffset: null,
          maxOutputTokens: null,
          unusedSafetyInterval: null,
          topK: this.topK_,
          temperature: this.temperature_,
        },
        router.$.bindNewPipeAndPassRemote());
    const onResponseId =
        router.onResponse.addListener((chunk: ResponseChunk) => {
          if (!this.currentResponse_) {
            // No response is being collected; ignore the stray chunk.
            return;
          }
          this.set(
              'currentResponse_.response',
              (this.currentResponse_.response + chunk.text).trimStart());
        });
    const onCompleteId =
        router.onComplete.addListener((_: ResponseSummary) => {
          this.addResponse_();
          router.removeListener(onResponseId);
          router.removeListener(onCompleteId);
        });
    this.currentResponse_ = {
      text: this.text_,
      response: '',
      responseClass: 'response',
      retracted: false,
      error: false,
    };
    this.text_ = '';
  }

  /** Whether a prompt can run: a model is loaded and nothing is in flight. */
  private canExecute_(): boolean {
    return !this.currentResponse_ && this.model_ !== null;
  }

  /** Whether a model load is currently pending. */
  private isLoading_(): boolean {
    return this.loadModelStart_ !== 0;
  }

  /**
   * Returns the status line describing the loaded model and its load time, or
   * an empty string when no model path is recorded.
   */
  private getModelText_(): string {
    if (this.modelPath_.length === 0) {
      return '';
    }
    return 'Model loaded from ' + this.modelPath_ + ' in ' +
        this.loadModelDuration_ + 'ms';
  }
}

// Adds the element's tag name to the global tag-name map so that typed DOM
// lookups for 'on-device-internals-tools' resolve to
// OnDeviceInternalsToolsElement.
declare global {
  interface HTMLElementTagNameMap {
    'on-device-internals-tools': OnDeviceInternalsToolsElement;
  }
}

// Register the class as the implementation of its custom element tag.
customElements.define(
    OnDeviceInternalsToolsElement.is, OnDeviceInternalsToolsElement);