// chromium/third_party/coremltools/mlmodel/format/TreeEnsemble.proto

// Copyright (c) 2017, Apple Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-3-clause license that can be
// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause

/*
 * Each tree is a collection of nodes,
 * each of which is identified by a unique identifier.
 *
 * Each node is either a branch or a leaf node.
 * A branch node evaluates a value according to a behavior;
 * if true, the node identified by ``trueChildNodeId`` is evaluated next,
 * if false, the node identified by ``falseChildNodeId`` is evaluated next.
 * A leaf node adds the evaluation value to the base prediction value
 * to get the final prediction.
 *
 * A tree must have exactly one root node,
 * which has no parent node.
 * A tree must not terminate on a branch node.
 * All leaf nodes must be accessible
 * by evaluating one or more branch nodes in sequence,
 * starting from the root node.
 */

syntax = "proto3";
option optimize_for = LITE_RUNTIME;

import public "DataStructures.proto";

package CoreML.Specification;

/*
 * A tree ensemble post-evaluation transform.
 *
 * Used as the type of the ``postEvaluationTransform`` field
 * of both ``TreeEnsembleClassifier`` and ``TreeEnsembleRegressor``.
 *
 * NOTE(review): the exact formula behind each transform
 * is not specified in this file;
 * the per-value notes below are inferred from the names -- confirm.
 */
enum TreeEnsemblePostEvaluationTransform {
  // Zero value: in proto3 this is what an unset field reads back as.
  NoTransform = 0;
  // Presumably a softmax over the per-class prediction values.
  Classification_SoftMax = 1;
  // Presumably a logistic function applied to the prediction value.
  Regression_Logistic = 2;
  // Presumably a softmax variant that treats class 0 as the reference.
  Classification_SoftMaxWithZeroClassReference = 3;
}

/*
 * Tree ensemble parameters.
 *
 * Stores the nodes of the ensemble in one flat list (``nodes``),
 * along with the number of prediction dimensions
 * and the base prediction values
 * that leaf evaluation values are added to.
 */
message TreeEnsembleParameters {
  /*
   * One node of one tree.
   *
   * A node is either a branch or a leaf,
   * as selected by ``nodeBehavior``.
   */
  message TreeNode {
    /*
     * NOTE(review): presumably ``treeId`` selects the tree
     * this node belongs to
     * and ``nodeId`` is the node's identifier within that tree;
     * the uniqueness scope is not stated here -- confirm.
     */
    uint64 treeId = 1;
    uint64 nodeId = 2;

    /*
     * How a node is evaluated:
     * one of the ``BranchOn*`` comparison modes, or ``LeafNode``.
     *
     * ``BranchOnValueLessThanEqual`` is the zero value,
     * so a node whose ``nodeBehavior`` is left unset
     * reads back as that branch mode, not as a leaf.
     */
    enum TreeNodeBehavior {
      BranchOnValueLessThanEqual = 0;
      BranchOnValueLessThan = 1;
      BranchOnValueGreaterThanEqual = 2;
      BranchOnValueGreaterThan = 3;
      BranchOnValueEqual = 4;
      BranchOnValueNotEqual = 5;
      LeafNode = 6;
    }

    /*
     * The node behavior.
     *
     * If this is one of the ``BranchOn*`` modes,
     * then the branch parameters below must be filled in
     * to determine how the branching functions.
     */
    TreeNodeBehavior nodeBehavior = 3;

    /*
     * If the node behavior mode is a branch mode,
     * then these values must be filled in.
     *
     * ``trueChildNodeId`` and ``falseChildNodeId`` identify the node
     * that is evaluated next when the branch test is true or false.
     *
     * NOTE(review): presumably ``branchFeatureIndex`` selects the input
     * feature and ``branchFeatureValue`` is the value it is compared to,
     * and ``missingValueTracksTrueChild`` sends a missing feature value
     * down the true child -- confirm against the evaluator.
     */
    uint64 branchFeatureIndex = 10;
    double branchFeatureValue = 11;
    uint64 trueChildNodeId = 12;
    uint64 falseChildNodeId = 13;
    bool missingValueTracksTrueChild = 14;

    /*
     * The leaf mode.
     *
     * If ``nodeBehavior`` == ``LeafNode``,
     * then the evaluationValue is added to the base prediction value
     * in order to get the final prediction.
     * To support multiclass classification
     * as well as regression and binary classification,
     * the evaluation value is encoded here as a sparse vector,
     * with evaluationIndex being the index of the base vector
     * that the evaluation value is added to.
     * In the single class case,
     * it is expected that evaluationIndex is exactly 0.
     */
    message EvaluationInfo {
      uint64 evaluationIndex = 1;
      double evaluationValue = 2;
    }

    /*
     * The sparse evaluation vector of a leaf:
     * one entry per (evaluationIndex, evaluationValue) pair.
     */
    repeated EvaluationInfo evaluationInfo = 20;

    /*
     * The relative hit rate of a node for optimization purposes.
     *
     * This value has no effect on the accuracy of the result;
     * it allows the tree to optimize for frequent branches.
     * The value is relative,
     * compared to the hit rates of other branch nodes.
     *
     * You typically use a proportion of training samples
     * that reached this node
     * or some similar metric to derive this value.
     */
    double relativeHitRate = 30;
  }

  /*
   * The nodes of the ensemble.
   *
   * Each entry carries its own ``treeId`` and ``nodeId``.
   */
  repeated TreeNode nodes = 1;

  /*
   * The number of prediction dimensions or classes in the model.
   *
   * All instances of ``evaluationIndex`` in a leaf node
   * must be less than this value,
   * and the number of values in the ``basePredictionValue`` field
   * must be equal to this value.
   *
   * For regression,
   * this is the dimension of the prediction.
   * For classification,
   * this is the number of classes.
   */
  uint64 numPredictionDimensions = 2;

  /*
   * The base prediction value.
   *
   * Leaf evaluation values are added to these values
   * to get the final prediction.
   * The number of values in this field
   * must be equal to ``numPredictionDimensions``.
   */
  repeated double basePredictionValue = 3;
}

/*
 * A tree ensemble classifier.
 *
 * Combines the shared tree ensemble parameters
 * with a post-evaluation transform and a class label mapping.
 */
message TreeEnsembleClassifier {
  // The trees, prediction dimensionality, and base prediction values.
  TreeEnsembleParameters treeEnsemble = 1;
  // The transform selected for the ensemble's output;
  // reads back as ``NoTransform`` when unset.
  TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2;

  // Required class label mapping.
  // As a oneof, at most one of the two label vectors can be set at a time.
  // NOTE(review): ``StringVector`` and ``Int64Vector`` are not declared in
  // this file; presumably they come from DataStructures.proto -- confirm.
  oneof ClassLabels {
    StringVector stringClassLabels = 100;
    Int64Vector int64ClassLabels = 101;
  }
}

/*
 * A tree ensemble regressor.
 *
 * Combines the shared tree ensemble parameters
 * with a post-evaluation transform.
 */
message TreeEnsembleRegressor {
  // The trees, prediction dimensionality, and base prediction values.
  TreeEnsembleParameters treeEnsemble = 1;
  // The transform selected for the ensemble's output;
  // reads back as ``NoTransform`` when unset.
  TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2;
}