// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_SAMPLING_H_
#define MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_SAMPLING_H_

#include <memory>
#include <random>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"

namespace mediapipe::tasks::genai::xnn_utils {

// TODO: b/331677973 - de-duplicate from
// third_party/odml/infra/genai/inference/calculators/top_p_sampler_impl.h
class Sampler {
 public:
  enum class Type { kGreedy, kTopK, kTopP };

  // Creates a Sampler.
  // * If the kGreedy sampler is used, the argmax is returned and all other
  //   arguments are ignored.
  // * If the kTopK sampler is used, the top k logit values are selected,
  //   followed by temperature scaling and softmax. Finally, a sample is
  //   drawn from the resulting distribution.
  // * If the kTopP sampler is used, the top k logits are first selected if
  //   k > 0; otherwise, k = vocab size. This is followed by temperature
  //   scaling and softmax. Next, the smallest set of highest-probability
  //   tokens whose cumulative probability is greater than or equal to top_p
  //   is kept. Lastly, a sample is drawn from the resulting distribution.
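  //
  // For example (illustrative): with softmax probabilities
  // {0.5, 0.3, 0.1, 0.1} and top_p = 0.8, only the first two entries are
  // kept (0.5 + 0.3 >= 0.8), and the sample is drawn from them after
  // renormalization.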
  static absl::StatusOr<std::unique_ptr<Sampler>> Create(Type type, int top_k,
                                                         float top_p,
                                                         float temperature,
                                                         int seed);
  // Given an input tensor of shape `(batch, seq_len, vocab_size)`, runs the
  // configured sampling algorithm to find a winning class for each sequence
  // position. The results are reported as a 2D vector of integer indices,
  // where the first axis corresponds to the batch size and the second axis
  // corresponds to the sequence length.
  absl::StatusOr<std::vector<std::vector<int>>> Sample(const Tensor& logits);
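
  // A minimal usage sketch (illustrative parameter values; assumes `logits`
  // is a populated Tensor of shape (batch, seq_len, vocab_size)):
  //
  //   auto sampler = Sampler::Create(Sampler::Type::kTopP, /*top_k=*/40,
  //                                  /*top_p=*/0.95f, /*temperature=*/0.8f,
  //                                  /*seed=*/1234);
  //   if (sampler.ok()) {
  //     auto ids = (*sampler)->Sample(logits);  // [batch][seq_len] indices.
  //   }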

 private:
  Sampler(Type type, int top_k, float top_p, float temperature, int seed);
  absl::StatusOr<std::vector<std::vector<int>>> SampleGreedy(
      const Tensor& logits);
  absl::StatusOr<std::vector<std::vector<int>>> SampleTopK(
      const Tensor& logits);
  absl::StatusOr<std::vector<std::vector<int>>> SampleTopP(
      const Tensor& logits);
  absl::Status SelectTopK(std::vector<std::pair<float, int>>& logits_ids,
                          int k);
  // `logits_ids` must be sorted and normalized.
  absl::Status SelectTopP(std::vector<std::pair<float, int>>& logits_ids,
                          float p);
  // `logits_ids` must be sorted.
  absl::Status ScaledSoftmax(std::vector<std::pair<float, int>>& logits_ids,
                             bool normalize);
  absl::StatusOr<int> DoSampling(
      std::vector<std::pair<float, int>>& logits_ids);

  Type type_;
  int top_k_;
  float top_p_;
  float temperature_;
  std::unique_ptr<std::mt19937> generator_;
};

}  // namespace mediapipe::tasks::genai::xnn_utils
#endif  // MEDIAPIPE_TASKS_GENAI_INFERENCE_UTILS_XNN_UTILS_SAMPLING_H_