chromium/third_party/mediapipe/src/mediapipe/tasks/cc/genai/inference/utils/xnn_utils/sampling.cc

// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/sampling.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <memory>
#include <random>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/genai/inference/utils/xnn_utils/xnn_tensor.h"

namespace mediapipe::tasks::genai::xnn_utils {

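// Validates the sampling parameters for the requested strategy and constructs
// a Sampler. As an illustrative call (the values here are hypothetical), a
// top-k sampler with k = 40 and temperature 0.8 could be obtained via:
//   Sampler::Create(Sampler::Type::kTopK, /*top_k=*/40, /*top_p=*/0.0f,
//                   /*temperature=*/0.8f, /*seed=*/1234);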
absl::StatusOr<std::unique_ptr<Sampler>> Sampler::Create(Type type, int top_k,
                                                         float top_p,
                                                         float temperature,
                                                         int seed) {
  if (type == Type::kTopK || type == Type::kTopP) {
    RET_CHECK_GT(top_k, 1).SetCode(absl::StatusCode::kInvalidArgument)
        << "top_k must be > 1";
    RET_CHECK_GE(temperature, 0.0f).SetCode(absl::StatusCode::kInvalidArgument)
        << "temperature must be >= 0";
    RET_CHECK_LE(temperature, 1.0f).SetCode(absl::StatusCode::kInvalidArgument)
        << "temperature must be <= 1";
  }
  if (type == Type::kTopP) {
    RET_CHECK_GT(top_p, 0).SetCode(absl::StatusCode::kInvalidArgument)
        << "top_p must be between 0 and 1";
    RET_CHECK_LE(top_p, 1.0).SetCode(absl::StatusCode::kInvalidArgument)
        << "top_p must be between 0 and 1";
  }
  return absl::WrapUnique(new Sampler(type, top_k, top_p, temperature, seed));
}

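// Dispatches to the sampling strategy chosen at construction. `logits` must be
// a rank-3 tensor shaped (Batch, seq_len, vocab_size); one token id is
// produced per (batch, seq_len) position.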
absl::StatusOr<std::vector<std::vector<int>>> Sampler::Sample(
    const Tensor& logits) {
  if (logits.dims.size() != 3) {
    return absl::InvalidArgumentError(
        "Tensor must be (Batch, seq_len, vocab_size)");
  }

  switch (type_) {
    case Type::kGreedy:
      return SampleGreedy(logits);
    case Type::kTopK:
      return SampleTopK(logits);
    case Type::kTopP:
      return SampleTopP(logits);
    default:
      return absl::InvalidArgumentError("Unsupported sampler type");
  }
}

Sampler::Sampler(Type type, int top_k, float top_p, float temperature, int seed)
    : type_(type),
      top_k_(top_k),
      top_p_(top_p),
      temperature_(temperature),
      generator_(std::make_unique<std::mt19937>(seed)) {}

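// Greedy decoding: for each (batch, draft) position, returns the index of the
// largest logit over the vocabulary dimension. No randomness or temperature is
// involved.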
absl::StatusOr<std::vector<std::vector<int>>> Sampler::SampleGreedy(
    const Tensor& logits) {
  size_t batch_size = logits.dims[0];
  size_t draft_size = logits.dims[1];
  size_t vocab_size = logits.dims[2];

  const float* float_logits = logits.DataAs<float>();
  std::vector<std::vector<int>> outputs;
  outputs.reserve(batch_size);
  // Select the token with the highest logit directly.
  for (int batch = 0; batch < batch_size; ++batch) {
    outputs.push_back(std::vector<int>());
    outputs[batch].reserve(draft_size);
    for (int draft = 0; draft < draft_size; ++draft) {
      // The index of the first logit for this token.
      int token_index =
          (batch * draft_size * vocab_size) + (draft * vocab_size);
      float max_logit = float_logits[token_index];
      int max_id = 0;
      for (int v = 0; v < vocab_size; ++v) {
        float logit = float_logits[token_index + v];
        if (logit > max_logit) {
          max_logit = logit;
          max_id = v;
        }
      }
      outputs[batch].push_back(max_id);
    }
  }
  return outputs;
}

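// Top-k sampling: keeps the `top_k_` largest logits per position, converts
// them to (unnormalized) weights with a temperature-scaled softmax, and draws
// one token id from the resulting distribution.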
absl::StatusOr<std::vector<std::vector<int>>> Sampler::SampleTopK(
    const Tensor& logits) {
  const size_t batch_size = logits.dims[0];
  const size_t draft_size = logits.dims[1];
  const size_t vocab_size = logits.dims[2];
  const float* flat_data = logits.DataAs<float>();

  std::vector<std::vector<int>> outputs;
  outputs.reserve(batch_size);
  for (int batch = 0; batch < batch_size; ++batch) {
    outputs.push_back(std::vector<int>());
    outputs[batch].reserve(draft_size);
    for (int draft = 0; draft < draft_size; ++draft) {
      // The index of the first logit for this token.
      int token_index =
          (batch * draft_size * vocab_size) + (draft * vocab_size);
      std::vector<std::pair<float, int>> logits_ids;
      logits_ids.reserve(vocab_size);
      for (int v = 0; v < vocab_size; ++v) {
        float logit = flat_data[token_index + v];
        logits_ids.push_back(std::make_pair(logit, v));
      }
      MP_RETURN_IF_ERROR(SelectTopK(logits_ids, top_k_));
      // No need to normalize here; DoSampling's discrete_distribution
      // normalizes the weights.
      MP_RETURN_IF_ERROR(ScaledSoftmax(logits_ids, /*normalize=*/false));
      MP_ASSIGN_OR_RETURN(int sample_idx, DoSampling(logits_ids));
      outputs[batch].push_back(sample_idx);
    }
  }
  return outputs;
}

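// Top-p (nucleus) sampling: applies a top-k cut (or keeps the whole vocabulary
// when `top_k_` <= 0), a normalized temperature-scaled softmax, and then
// restricts to the smallest prefix whose cumulative probability reaches
// `top_p_` before sampling a token id.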
absl::StatusOr<std::vector<std::vector<int>>> Sampler::SampleTopP(
    const Tensor& logits) {
  const size_t batch_size = logits.dims[0];
  const size_t draft_size = logits.dims[1];
  const size_t vocab_size = logits.dims[2];
  const int k = top_k_ > 0 ? top_k_ : vocab_size;
  const float* flat_data = logits.DataAs<float>();

  std::vector<std::vector<int>> outputs;
  outputs.reserve(batch_size);
  for (int batch = 0; batch < batch_size; ++batch) {
    outputs.push_back(std::vector<int>());
    outputs[batch].reserve(draft_size);
    for (int draft = 0; draft < draft_size; ++draft) {
      // The index of the first logit for this token.
      int token_index =
          (batch * draft_size * vocab_size) + (draft * vocab_size);
      std::vector<std::pair<float, int>> logits_ids;
      logits_ids.reserve(vocab_size);
      for (int v = 0; v < vocab_size; ++v) {
        float logit = flat_data[token_index + v];
        logits_ids.push_back(std::make_pair(logit, v));
      }
      MP_RETURN_IF_ERROR(SelectTopK(logits_ids, k));
      MP_RETURN_IF_ERROR(ScaledSoftmax(logits_ids, /*normalize=*/true));
      MP_RETURN_IF_ERROR(SelectTopP(logits_ids, top_p_));
      MP_ASSIGN_OR_RETURN(int sample_idx, DoSampling(logits_ids));
      outputs[batch].push_back(sample_idx);
    }
  }
  return outputs;
}

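// Partially sorts `logits_ids` so that the k largest logits come first in
// descending order, then truncates the vector to those k entries.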
absl::Status Sampler::SelectTopK(std::vector<std::pair<float, int>>& logits_ids,
                                 int k) {
  if (k <= 0 || static_cast<size_t>(k) > logits_ids.size()) {
    return absl::InvalidArgumentError(
        "Top k value must be positive and must not exceed the number of "
        "logits.");
  }
  std::partial_sort(
      logits_ids.begin(), logits_ids.begin() + k, logits_ids.end(),
      [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
        // reverse order.
        return a.first > b.first;
      });
  logits_ids.resize(k);
  return absl::OkStatus();
}

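// Expects `logits_ids` to contain probabilities sorted in descending order and
// truncates it to the smallest prefix whose cumulative probability reaches
// `p`.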
absl::Status Sampler::SelectTopP(std::vector<std::pair<float, int>>& logits_ids,
                                 float p) {
  int included = 0;
  float prob_sum = 0.0;
  for (const auto& [prob, _] : logits_ids) {
    ++included;
    prob_sum += prob;
    if (prob_sum >= p) {
      break;
    }
  }
  if (included == 0) {
    return absl::InternalError("Bad top_p value.");
  }
  logits_ids.resize(included);
  return absl::OkStatus();
}

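// Applies a numerically stable, temperature-scaled softmax in place. The
// entries are assumed to be sorted in descending order, so the first entry is
// the maximum logit subtracted before exponentiation. When `normalize` is
// false the exponentiated values are left unnormalized.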
absl::Status Sampler::ScaledSoftmax(
    std::vector<std::pair<float, int>>& logits_ids, bool normalize) {
  // A temperature of 0 is treated as no scaling (scale of 1).
  const float scale = temperature_ != 0.0f ? 1.0f / temperature_ : 1.0f;
  double sum = 0.0;
  float max_logit = logits_ids[0].first;
  for (int i = 0; i < logits_ids.size(); ++i) {
    const float logit = logits_ids[i].first;
    const float p = expf(scale * (logit - max_logit));
    sum += p;
    logits_ids[i].first = p;
  }
  if (normalize) {
    for (int i = 0; i < logits_ids.size(); ++i) {
      logits_ids[i].first /= sum;
    }
  }
  return absl::OkStatus();
}

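// Draws one entry with probability proportional to its weight and returns the
// corresponding token id. std::discrete_distribution normalizes the weights,
// so the input does not need to sum to 1.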
absl::StatusOr<int> Sampler::DoSampling(
    std::vector<std::pair<float, int>>& logits_ids) {
  std::vector<float> probs;
  probs.reserve(logits_ids.size());
  for (const auto& [prob, _] : logits_ids) {
    probs.push_back(prob);
  }
  // Probabilities are normalized by `discrete_distribution`.
  std::discrete_distribution<> dist(probs.begin(), probs.end());
  int sample_idx = dist(*generator_);
  return logits_ids[sample_idx].second;
}

}  // namespace mediapipe::tasks::genai::xnn_utils