chromium/components/segmentation_platform/internal/tools/create_class.py

# Copyright 2022 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Script to generate header cc and unittest file for a class in chromium."""

# Usage text for the tool; passed to argparse as the parser description, so it
# is what `--help` prints. Kept as a raw string so the trailing backslashes in
# the example command lines are shown verbatim.
_DOCUMENTATION = r"""Usage:

To generate default model template files:
  python3 components/segmentation_platform/internal/tools/create_class.py \
      --segment_id MY_FEATURE_USER

To generate generic header and cc files:
  python3 components/segmentation_platform/internal/tools/create_class.py \
      --header src/dir/class_name.h

If any of the file already exists then prints a log and does not touch the
file, but still creates the remaining files.
"""

import argparse
import datetime
import logging
import os
import sys

# Template for the header of a generic (non-model) class.
# Placeholders: {year}, {macro} (include guard), {namespace}, {clas} (class
# name). Literal C++ braces are doubled because the text is rendered with
# str.format_map().
_HEADER_TEMPLATE = """// Copyright {year} The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef {macro}
#define {macro}

namespace {namespace} {{

class {clas} {{
 public:
  {clas}();
  ~{clas}();

  {clas}(const {clas}&) = delete;
  {clas}& operator=(const {clas}&) = delete;

 private:
}};

}}

#endif  // {macro}
"""

# Template for the .cc file of a generic (non-model) class.
# Placeholders: {year}, {file_path} (header to include), {namespace}, {clas}
# (class name). Literal C++ braces are doubled because the text is rendered
# with str.format_map().
_CC_TEMPLATE = """// Copyright {year} The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "{file_path}"

namespace {namespace} {{

{clas}::{clas}() = default;
{clas}::~{clas}() = default;

}}
"""

# Template for the unittest file of a generic (non-model) class: a gtest
# fixture with empty SetUp/TearDown overrides and one empty TEST_F.
# Placeholders: {year}, {file_path} (header to include), {namespace},
# {test_class} (fixture name). Literal C++ braces are doubled because the text
# is rendered with str.format_map().
_TEST_TEMPLATE = """// Copyright {year} The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "{file_path}"

#include "testing/gtest/include/gtest/gtest.h"

namespace {namespace} {{

class {test_class} : public testing::Test {{
 public:
  {test_class}() = default;
  ~{test_class}() override = default;

  void SetUp() override {{
    Test::SetUp();
  }}

  void TearDown() override {{
    Test::TearDown();
  }}

 protected:
}};

TEST_F({test_class}, Test) {{
}}

}}
"""

# Template for the header of a default segmentation model (used when
# --segment_id is given). Declares a feature flag and a DefaultModelProvider
# subclass.
# Placeholders: {year}, {macro} (include guard), {namespace}, {clas} (class
# name), {segmentation_key}. Literal C++ braces are doubled because the text is
# rendered with str.format_map().
_MODEL_HEADER_TEMPLATE = """// Copyright {year} The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef {macro}
#define {macro}

#include <memory>

#include "base/feature_list.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/model_provider.h"

namespace {namespace} {{

// Feature flag for enabling {clas} segment.
BASE_DECLARE_FEATURE(kSegmentationPlatform{clas});

// Model to predict whether the user belongs to {clas} segment.
class {clas} : public DefaultModelProvider {{
 public:
  static constexpr char k{clas}Key[] = "{segmentation_key}";
  static constexpr char k{clas}UmaName[] = "{clas}";

  {clas}();
  ~{clas}() override = default;

  {clas}(const {clas}&) = delete;
  {clas}& operator=(const {clas}&) = delete;

  static std::unique_ptr<Config> GetConfig();

  // ModelProvider implementation.
  std::unique_ptr<ModelConfig> GetModelConfig() override;
  void ExecuteModelWithInput(const ModelProvider::Request& inputs,
                             ExecutionCallback callback) override;
}};

}}

#endif  // {macro}
"""

# Template for the .cc file of a default segmentation model (used when
# --segment_id is given). It contains example UMA input features and a
# placeholder heuristic that the author is expected to replace (see the TODOs
# inside the generated code).
# Placeholders: {year}, {file_path} (header to include), {namespace}, {clas}
# (class name), {segment_id} (proto::SegmentId enumerator). Literal C++ braces
# are doubled because the text is rendered with str.format_map().
# <array> is included explicitly because the generated code uses std::array and
# the generated header only includes <memory>.
_MODEL_CC_TEMPLATE = """// Copyright {year} The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "{file_path}"

#include <array>
#include <memory>

#include "base/task/sequenced_task_runner.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/proto/aggregation.pb.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"

namespace {namespace} {{

BASE_FEATURE(kSegmentationPlatform{clas},
             "SegmentationPlatform{clas}",
             base::FEATURE_DISABLED_BY_DEFAULT);

namespace {{
using proto::SegmentId;

// Default parameters for {clas} model.
constexpr SegmentId kSegmentId = SegmentId::{segment_id};
constexpr int64_t kModelVersion = 1;
// Store 28 buckets of input data (28 days).
constexpr int64_t kSignalStorageLength = 28;
// Wait until we have 7 days of data.
constexpr int64_t kMinSignalCollectionLength = 7;
// Refresh the result every 7 days.
constexpr int64_t kResultTTLDays = 7;

// InputFeatures.

// Enum values for the Example.EnumHistogram.
constexpr std::array<int32_t, 3> kEnumValues{{
    0, 3, 4
}};

// Set UMA metrics to use as input.
// TODO: Fill in the necessary signals for prediction.
constexpr std::array<MetadataWriter::UMAFeature, 3> kUMAFeatures = {{
    // Total amount of times user action was recorded in last 14 days.
    MetadataWriter::UMAFeature::FromUserAction("UserActionName", 14),

    // Total value of all records of the histogram in last 7 days.
    MetadataWriter::UMAFeature::FromValueHistogram(
        "Example.ValueHistogram", 7, proto::Aggregation::SUM),

    // Total count of number of records of enum histogram with given values.
    MetadataWriter::UMAFeature::FromEnumHistogram(
        "Example.EnumHistogram",
        14,
        kEnumValues.data(),
        kEnumValues.size()),
}};

}}  // namespace

// static
std::unique_ptr<Config> {clas}::GetConfig() {{
  if (!base::FeatureList::IsEnabled(
          kSegmentationPlatform{clas})) {{
    return nullptr;
  }}
  auto config = std::make_unique<Config>();
  config->segmentation_key = k{clas}Key;
  config->segmentation_uma_name = k{clas}UmaName;
  config->AddSegmentId(kSegmentId,
                       std::make_unique<{clas}>());
  config->auto_execute_and_cache = false;
  return config;
}}

{clas}::{clas}()
    : DefaultModelProvider(kSegmentId) {{}}

std::unique_ptr<DefaultModelProvider::ModelConfig> {clas}::GetModelConfig() {{
  proto::SegmentationModelMetadata metadata;
  MetadataWriter writer(&metadata);
  writer.SetDefaultSegmentationMetadataConfig(
      kMinSignalCollectionLength,
      kSignalStorageLength);

  // Set output config.
  const char kNot{clas}Label[] = "Not{clas}";
  writer.AddOutputConfigForBinaryClassifier(
      0.5,
      /*positive_label=*/k{clas}UmaName,
      kNot{clas}Label);
  writer.AddPredictedResultTTLInOutputConfig(
      /*top_label_to_ttl_list=*/{{}},
      /*default_ttl=*/kResultTTLDays, proto::TimeUnit::DAY);

  // Set features.
  writer.AddUmaFeatures(kUMAFeatures.data(),
                        kUMAFeatures.size());

  return std::make_unique<ModelConfig>(std::move(metadata), kModelVersion);
}}

void {clas}::ExecuteModelWithInput(
    const ModelProvider::Request& inputs,
    ExecutionCallback callback) {{
  // Invalid inputs.
  if (inputs.size() != kUMAFeatures.size()) {{
    base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(std::move(callback), std::nullopt));
    return;
  }}

  // TODO: Update the heuristics here to return 1 when the user belongs to
  // {clas}.

  float result = 0;
  const int user_action_count = inputs[0];
  const int value_histogram_total = inputs[1];
  const int enum_hit_count = inputs[2];
  if (user_action_count && value_histogram_total && enum_hit_count)
    result = 1;

  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE,
      base::BindOnce(std::move(callback), ModelProvider::Response(1, result)));
}}

}}
"""

# Template for the unittest file of a default segmentation model (used when
# --segment_id is given). The fixture derives from DefaultModelTestBase and
# ships two tests: InitAndFetchModel and ExecuteModelWithInput.
# Placeholders: {year}, {file_path} (header to include), {namespace},
# {test_class} (fixture name), {clas} (model class name). Literal C++ braces are
# doubled because the text is rendered with str.format_map().
_MODEL_TEST_TEMPLATE = """// Copyright {year} The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "{file_path}"

#include "components/segmentation_platform/embedder/default_model/default_model_test_base.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace {namespace} {{

class {test_class} : public DefaultModelTestBase {{
 public:
  {test_class}() : DefaultModelTestBase(std::make_unique<{clas}>()) {{}}
  ~{test_class}() override = default;
}};

TEST_F({test_class}, InitAndFetchModel) {{
  ExpectInitAndFetchModel();
}}

TEST_F({test_class}, ExecuteModelWithInput) {{
  // TODO: Add test cases to verify if the heuristic returns the right segment.
  ExpectExecutionWithInput(/*inputs=*/{{1, 2, 3}}, /*expected_error=*/false,
                           /*expected_result=*/{{1}});
}}

}}
"""


def _GetLogger():
    """Returns the 'create_class' logger with its level set to INFO.

    Also requests INFO-level root logging configuration on every call.
    """
    info_level = logging.INFO
    logging.basicConfig(level=info_level)
    tool_logger = logging.getLogger('create_class')
    tool_logger.setLevel(info_level)
    return tool_logger


def _WriteFile(path, type_str, contents):
    """Writes contents to a new file at path without overwriting anything.

    If the path already exists, an error naming the path is logged and nothing
    is written.

    Args:
        path: Destination path of the file to create.
        type_str: Short label for the kind of file (e.g. 'Header'); used only
            in log messages.
        contents: Text to write into the file.
    """
    logger = _GetLogger()
    try:
        # Mode 'x' raises FileExistsError instead of truncating, so the
        # existence check and the creation happen in a single step.
        with open(path, 'x', encoding='utf-8') as f:
            logger.info('Writing %s file %s', type_str, path)
            f.write(contents)
    except FileExistsError:
        logger.error('%s file %s already exists', type_str, path)


def _GetClassNameFromFile(header):
    """Derives a CamelCase class name from a snake_case header file name.

    Args:
        header: Path to the header file, e.g. 'src/dir/class_name.h'.

    Returns:
        The base file name without its extension, with underscores removed and
        the first character of every underscore-separated word upper-cased,
        e.g. 'ClassName'. Characters other than the first of each word keep
        their original case.
    """
    # Strip only the trailing extension; a plain replace('.h', '') would also
    # remove a '.h' occurring elsewhere in the file name.
    file_base, _ = os.path.splitext(os.path.basename(header))
    # Splitting on '_' makes leading or repeated underscores yield empty words,
    # which are skipped so that no underscore can leak into the class name.
    return ''.join(word[0].upper() + word[1:]
                   for word in file_base.split('_') if word)


def _GetSegmentationKeyFromFile(header):
    """Returns the segmentation key for a header: its base name sans extension.

    Args:
        header: Path to the header file, e.g. 'dir/my_feature_user.h'.

    Returns:
        The file name with the directory part and the trailing extension
        removed, e.g. 'my_feature_user'.
    """
    # splitext() removes only the final extension, whereas replace('.h', '')
    # would delete every '.h' substring in the name.
    key, _ = os.path.splitext(os.path.basename(header))
    return key


def _GetHeader(args):
    """Returns the path of the header file to generate, based on args.

    --header takes precedence over --segment_id when both are set.

    Args:
        args: Parsed options; the `header` and `segment_id` attributes are read.

    Returns:
        args.header unchanged when it is set. Otherwise, for a segment ID,
        'components/segmentation_platform/embedder/default_model/<name>.h'
        where <name> is the lower-cased segment ID with at most one of the
        known 'OPTIMIZATION_TARGET_...' prefixes removed.

    Raises:
        ValueError: If --header does not end in '.h', or if neither --header
            nor --segment_id was provided.
    """
    if args.header:
        # Require a real '.h' suffix: a substring test would accept paths such
        # as 'page.html', for which no sensible .cc path can be derived.
        if not args.header.endswith('.h'):
            raise ValueError(
                'The --header argument should be a path to a .h header file')

        _GetLogger().info('Creating class for header %s', args.header)
        return args.header

    if args.segment_id:
        # The longer prefix is listed first so it wins over its own prefix.
        prefixes_to_remove = (
            'OPTIMIZATION_TARGET_SEGMENTATION_',
            'OPTIMIZATION_TARGET_',
        )
        _GetLogger().info('Creating default model for %s', args.segment_id)
        model_name = args.segment_id
        for prefix in prefixes_to_remove:
            if model_name.startswith(prefix):
                model_name = model_name[len(prefix):]
                break
        return (
            'components/segmentation_platform/embedder/default_model/%s.h' %
            model_name.lower())

    raise ValueError('Required either --header or --segment_id argument.')


def _CreateFilesForClass(args):
    """Creates the header, .cc and unittest files described by args.

    The default-model templates are used when args.segment_id is set; the
    generic class templates are used otherwise. The .cc and unittest files are
    placed next to the header, named '<root>.cc' and '<root>_unittest.cc'.

    Args:
        args: Parsed options; `header`, `segment_id` and `namespace` are read.

    Raises:
        ValueError: Propagated from _GetHeader() for invalid or missing
            arguments.
    """
    if args.segment_id:
        header_template = _MODEL_HEADER_TEMPLATE
        cc_template = _MODEL_CC_TEMPLATE
        test_template = _MODEL_TEST_TEMPLATE
    else:
        header_template = _HEADER_TEMPLATE
        cc_template = _CC_TEMPLATE
        test_template = _TEST_TEMPLATE

    header = _GetHeader(args)

    # Swap only the trailing extension. str.replace('.h', ...) would also
    # rewrite a '.h' that appears in a directory or earlier in the file name.
    path_root, _ = os.path.splitext(header)
    file_cc = path_root + '.cc'
    file_test = path_root + '_unittest.cc'

    class_name = _GetClassNameFromFile(header)
    format_args = {
        'year': datetime.date.today().year,
        'file_path': header,
        # Include guard: the path upper-cased with '/' and '.' turned into '_',
        # plus a trailing '_'.
        'macro': header.replace('/', '_').replace('.', '_').upper() + '_',
        'clas': class_name,
        'segment_id': args.segment_id,
        'segmentation_key': _GetSegmentationKeyFromFile(header),
        'namespace': args.namespace,
        'test_class': class_name + 'Test',
    }

    outputs = (
        (header, 'Header', header_template),
        (file_cc, 'CC', cc_template),
        (file_test, 'Test', test_template),
    )
    for path, type_str, template in outputs:
        _WriteFile(path, type_str, template.format_map(format_args))


def _CreateOptionParser():
    """Builds the command-line parser for the tool.

    Returns:
        An argparse.ArgumentParser accepting --header, --segment_id and
        --namespace. --header and --segment_id default to the empty string;
        --namespace defaults to 'segmentation_platform'.
    """
    # RawTextHelpFormatter keeps the line breaks of the usage text intact.
    parser = argparse.ArgumentParser(
        description=_DOCUMENTATION,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--header',
                        help='Path to the header file from src/',
                        default='')
    parser.add_argument('--segment_id',
                        help='The segment ID enum value',
                        default='')
    parser.add_argument('--namespace',
                        dest='namespace',
                        help='C++ namespace wrapping the generated code '
                        '(default: %(default)s)',
                        default='segmentation_platform')
    return parser


def main():
    """Parses the command-line options and generates the requested files."""
    options = _CreateOptionParser().parse_args()
    _CreateFilesForClass(options)


# Run the generator only when executed as a script, not when imported.
# main() has no return statement, so sys.exit() receives None (exit status 0)
# unless an exception propagates.
if __name__ == '__main__':
    sys.exit(main())