chromium/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "absl/types/optional.h"  // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace task {
namespace vision {

// Sigmoid used for score calibration; BuildSigmoidCalibrationParams associates
// one sigmoid with each label.
//
// NOTE(review): no members are declared for this struct in this header, even
// though the calibration file is described as carrying per-label parameters
// (scale, slope, etc.) — confirm where those fields are meant to live.
struct Sigmoid {
};

// Stream insertion for Sigmoid, so it can be written with e.g.
// `out << sigmoid`. Only declared here; the definition is elsewhere.
std::ostream& operator<<(std::ostream& out, const Sigmoid& sigmoid);

// Identifies the transformation function to apply when computing transformed
// scores during calibration.
//
// NOTE(review): no enumerators are declared in this header — confirm where the
// supported transformations are defined.
//
// The underlying type is written out explicitly; `int` is already the default
// for a scoped enum, so this is the same type as an unadorned `enum class`.
enum class ScoreTransformation : int {};

// The complete set of sigmoid calibration parameters for a model, as returned
// (on success) by BuildSigmoidCalibrationParams.
struct SigmoidCalibrationParameters {
};

// Calibrates predicted scores so that they are comparable across labels.
// Depending on which calibration parameters are used, a calibrated score may
// also be read, approximately, as the likelihood that the prediction is
// correct. For a given TF Lite model these parameters usually come from its
// TF Lite Metadata (see ScoreCalibrationOptions).
class ScoreCalibration {
};

// Builds SigmoidCalibrationParameters using data obtained from TF Lite Metadata
// (see ScoreCalibrationOptions in metadata schema).
//
// The provided `score_calibration_file` represents the contents of the score
// calibration associated file (TENSOR_AXIS_SCORE_CALIBRATION), i.e. one set of
// parameters (scale, slope, etc) per line. Each line must be in 1:1
// correspondence with `label_map_items`, so as to associate each sigmoid to its
// corresponding label name. Returns an error if no valid parameters could be
// built (e.g. malformed parameters).
tflite::support::StatusOr<SigmoidCalibrationParameters>
BuildSigmoidCalibrationParams(
    const tflite::ScoreCalibrationOptions& score_calibration_options,
    absl::string_view score_calibration_file,
    const std::vector<LabelMapItem>& label_map_items);

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_