// Copyright (c) 2017, Apple Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-3-clause license that can be
// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause
/*
* Each tree is a collection of nodes,
* each of which is identified by a unique identifier.
*
* Each node is either a branch or a leaf node.
* A branch node evaluates a value according to a behavior;
* if true, the node identified by ``trueChildNodeId`` is evaluated next,
* if false, the node identified by ``falseChildNodeId`` is evaluated next.
* A leaf node adds the evaluation value to the base prediction value
* to get the final prediction.
*
* A tree must have exactly one root node,
* which has no parent node.
* A tree must not terminate on a branch node.
* All leaf nodes must be accessible
* by evaluating one or more branch nodes in sequence,
* starting from the root node.
*/
syntax = "proto3";
// Generate code against the lite protobuf runtime,
// which omits descriptors and reflection from the generated classes.
option optimize_for = LITE_RUNTIME;
// ``import public`` re-exports DataStructures.proto:
// any file that imports this one can also use its definitions.
import public "DataStructures.proto";
package CoreML.Specification;
/*
 * The transform applied to a tree ensemble's output after evaluation.
 *
 * Selected through the ``postEvaluationTransform`` field of
 * ``TreeEnsembleClassifier`` and ``TreeEnsembleRegressor``.
 */
enum TreeEnsemblePostEvaluationTransform {
  // Zero value: a ``postEvaluationTransform`` left unset reads as this.
  NoTransform = 0;

  // NOTE(review): presumably a softmax over the per-class values - confirm.
  Classification_SoftMax = 1;

  // NOTE(review): presumably a logistic function applied to the
  // regression output - confirm.
  Regression_Logistic = 2;

  // NOTE(review): presumably a softmax computed against an implicit
  // zero-valued reference class - confirm against the evaluator.
  Classification_SoftMaxWithZeroClassReference = 3;
}
/*
 * Tree ensemble parameters.
 *
 * Carries the nodes that make up the ensemble's trees,
 * the number of prediction dimensions,
 * and the base prediction that leaf values are added to.
 */
message TreeEnsembleParameters {
  /*
   * One node of one tree.
   *
   * ``nodeBehavior`` decides which of the remaining fields apply:
   * the branch fields (10-14) for a branch mode,
   * ``evaluationInfo`` (20) for ``LeafNode``.
   */
  message TreeNode {
    // The tree this node belongs to.
    uint64 treeId = 1;

    // Identifier of this node; child links refer to nodes by this value.
    // NOTE(review): presumably unique only within ``treeId`` - confirm.
    uint64 nodeId = 2;

    /*
     * The ways a node can be evaluated:
     * six comparison-based branch modes and one leaf mode.
     *
     * NOTE(review): presumably the feature at ``branchFeatureIndex``
     * is the left operand of the comparison
     * and ``branchFeatureValue`` the right - confirm.
     */
    enum TreeNodeBehavior {
      BranchOnValueLessThanEqual = 0;
      BranchOnValueLessThan = 1;
      BranchOnValueGreaterThanEqual = 2;
      BranchOnValueGreaterThan = 3;
      BranchOnValueEqual = 4;
      BranchOnValueNotEqual = 5;
      LeafNode = 6;
    }

    /*
     * The behavior of this node: one of the branch modes or ``LeafNode``.
     *
     * ``BranchOnValueLessThanEqual`` is the zero value,
     * so a node that leaves this field unset reads as that branch mode.
     */
    TreeNodeBehavior nodeBehavior = 3;

    /*
     * Branch parameters.
     *
     * If the node behavior is a branch mode,
     * then these values must be filled in.
     * When the branch evaluates to true,
     * the node identified by ``trueChildNodeId`` is evaluated next;
     * otherwise the node identified by ``falseChildNodeId`` is.
     */
    uint64 branchFeatureIndex = 10;
    double branchFeatureValue = 11;
    uint64 trueChildNodeId = 12;
    uint64 falseChildNodeId = 13;

    // NOTE(review): presumably selects the true child (rather than the
    // false child) when the branch feature has no value - confirm.
    bool missingValueTracksTrueChild = 14;

    /*
     * Leaf parameters.
     *
     * If ``nodeBehavior`` == ``LeafNode``,
     * then each ``evaluationValue`` is added to the base prediction value
     * in order to get the final prediction.
     *
     * The evaluation value is encoded as a sparse vector
     * so that multiclass classification is supported
     * alongside regression and binary classification:
     * ``evaluationIndex`` is the index into the base vector
     * that ``evaluationValue`` is added to.
     * In the single class case,
     * ``evaluationIndex`` is expected to be exactly 0.
     */
    message EvaluationInfo {
      uint64 evaluationIndex = 1;
      double evaluationValue = 2;
    }

    // The sparse (index, value) entries contributed by a leaf node.
    repeated EvaluationInfo evaluationInfo = 20;

    /*
     * The relative hit rate of the node, used only for optimization.
     *
     * It has no effect on the accuracy of the result;
     * it lets the tree be optimized for frequently taken branches.
     * The value is relative:
     * it is compared to the hit rates of other branch nodes.
     *
     * It is typically derived from the proportion of training samples
     * that reached this node, or from a similar metric.
     */
    double relativeHitRate = 30;
  }

  // The nodes of the ensemble; each one names its tree through ``treeId``.
  repeated TreeNode nodes = 1;

  /*
   * The number of prediction dimensions or classes in the model.
   *
   * Every ``evaluationIndex`` in a leaf node must be less than this value,
   * and ``basePredictionValue`` must hold exactly this many values.
   *
   * For regression, this is the dimension of the prediction.
   * For classification, this is the number of classes.
   */
  uint64 numPredictionDimensions = 2;

  /*
   * The base prediction value that leaf evaluation values are added to.
   *
   * Must contain exactly ``numPredictionDimensions`` values.
   */
  repeated double basePredictionValue = 3;
}
/*
 * A tree ensemble classifier.
 *
 * Pairs the ensemble parameters with a post-evaluation transform
 * and the class labels.
 */
message TreeEnsembleClassifier {
  // The trees and base prediction that make up the model.
  TreeEnsembleParameters treeEnsemble = 1;

  // Transform applied after evaluation; unset reads as ``NoTransform``.
  TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2;

  /*
   * Required class label mapping,
   * given either as strings or as 64-bit integers.
   *
   * NOTE(review): ``StringVector`` and ``Int64Vector`` are not declared
   * in this file; presumably they come from DataStructures.proto - confirm.
   */
  oneof ClassLabels {
    StringVector stringClassLabels = 100;
    Int64Vector int64ClassLabels = 101;
  }
}
/*
 * A tree ensemble regressor.
 *
 * Pairs the ensemble parameters with a post-evaluation transform.
 */
message TreeEnsembleRegressor {
  // The trees and base prediction that make up the model.
  TreeEnsembleParameters treeEnsemble = 1;

  // Transform applied after evaluation; unset reads as ``NoTransform``.
  TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2;
}