chromium/third_party/fuzztest/src/fuzztest/internal/domains/protobuf_domain_impl.h

// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_PROTOBUF_DOMAIN_IMPL_H_
#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_PROTOBUF_DOMAIN_IMPL_H_

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "./fuzztest/domain_core.h"
#include "./fuzztest/internal/any.h"
#include "./fuzztest/internal/domains/arbitrary_impl.h"
#include "./fuzztest/internal/domains/container_of_impl.h"
#include "./fuzztest/internal/domains/domain.h"
#include "./fuzztest/internal/domains/domain_base.h"
#include "./fuzztest/internal/domains/element_of_impl.h"
#include "./fuzztest/internal/domains/optional_of_impl.h"
#include "./fuzztest/internal/logging.h"
#include "./fuzztest/internal/meta.h"
#include "./fuzztest/internal/serialization.h"
#include "./fuzztest/internal/status.h"
#include "./fuzztest/internal/type_support.h"

// GetMessage is a Windows macro. Undefine it here to avoid code clutter.
#ifdef _WIN32
#pragma push_macro("GetMessage")
#undef GetMessage
#endif

// Forward declarations of the few protobuf symbols this header refers to, so
// that it does not need to #include any protobuf header.
namespace google::protobuf {
class EnumDescriptor;

// NOTE(review): this declaration has to stay identical to protobuf's own
// declaration of GetEnumDescriptor; any drift in the signature declares a
// different template instead of redeclaring protobuf's. Confirm against the
// protobuf version in use.
template <typename E>
const EnumDescriptor* GetEnumDescriptor();
}  // namespace google::protobuf

namespace fuzztest::internal {

// Sniff the API to get the types we need without naming them directly.
// This allows for a soft dependency on proto without having to #include its
// headers.
//
// NOTE(review): each line below is only a bare name followed by `;`. The
// alias definitions themselves are not visible here. They are presumably
// templates parameterized on the message type, given uses such as
// `ProtobufFieldDescriptor<typename T::Message>` later in this header.
// As written these lines declare nothing; confirm against the upstream header.
ProtobufReflection;
ProtobufDescriptor;
ProtobufFieldDescriptor;
ProtobufOneofDescriptor;

// Accessor template parameterized on a message type and a value type. Only the
// primary template is declared here; no definition is visible in this header.
template <typename Message, typename V>
class ProtocolBufferAccess;

// NOTE(review): FUZZTEST_INTERNAL_PROTO_ACCESS_ is defined here as an
// object-like macro with an empty replacement list. Each
// `FUZZTEST_INTERNAL_PROTO_ACCESS_;` line below therefore expands to an empty
// declaration (`;`) and produces no specialization of ProtocolBufferAccess.
// Presumably the macro is meant to take arguments and emit one specialization
// per field value type; confirm against upstream.
#define FUZZTEST_INTERNAL_PROTO_ACCESS_
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
FUZZTEST_INTERNAL_PROTO_ACCESS_;
#undef FUZZTEST_INTERNAL_PROTO_ACCESS_

// Tag types standing in for the C++ value type of enum fields
// (ProtoEnumTag) and message fields (ProtoMessageTag) when dispatching on a
// field's type. They are only declared in this header, so they are usable as
// template arguments but cannot be instantiated.
struct ProtoEnumTag;
struct ProtoMessageTag;

// Dynamic to static dispatch visitation.
// It will invoke:
//   visitor.VisitSingular<type>(field)  // for singular fields
//   visitor.VisitRepeated<type>(field)  // for repeated fields
// where `type` is:
//  - For bool, integrals, floating point and string: their C++ type.
//  - For enum: the tag type ProtoEnumTag.
//  - For message: the tag type ProtoMessageTag.
//
// NOTE(review): the body shown here is empty. As written, the function
// performs no visitation and its deduced return type is `void`; the contract
// above describes the intended dispatch. Confirm the body is present upstream.
template <typename FieldDescriptor, typename Visitor>
auto VisitProtobufField(const FieldDescriptor* field, Visitor visitor) {}

// Predicate over field descriptors (or any descriptor-like type `T`), used to
// select which fields a setting applies to. It matches the
// `std::function<bool(const FieldDescriptor*)>` filters taken by
// ProtobufDomainImpl's `With...Fields` methods. Being an alias (not a new
// type) lets `And` deduce `T` directly from such filters.
template <typename T>
using Predicate = std::function<bool(const T*)>;

// Returns a predicate that accepts every field (including a null pointer,
// which it never dereferences).
template <typename T>
Predicate<T> IncludeAll() {
  return [](const T*) { return true; };
}

// Returns a predicate that accepts the fields for which
// `field->is_optional()` is true. `field` must be non-null.
template <typename T>
Predicate<T> IsOptional() {
  return [](const T* field) { return field->is_optional(); };
}

// Returns a predicate that accepts the fields for which
// `field->is_repeated()` is true. `field` must be non-null.
template <typename T>
Predicate<T> IsRepeated() {
  return [](const T* field) { return field->is_repeated(); };
}

// Returns the conjunction of `lhs` and `rhs`. Evaluation short-circuits:
// `rhs` is only invoked when `lhs` accepts the field.
template <typename T>
Predicate<T> And(Predicate<T> lhs, Predicate<T> rhs) {
  // Move the operands into the closure; std::function::operator() is const,
  // so calling the captured copies from the const call operator is fine.
  return [lhs = std::move(lhs), rhs = std::move(rhs)](const T* field) {
    return lhs(field) && rhs(field);
  };
}

// Returns a domain transformer that returns its input domain unchanged.
// (The previous empty body flowed off the end of a value-returning function,
// which is undefined behavior when called.)
template <typename T>
std::function<Domain<T>(Domain<T>)> Identity() {
  return [](Domain<T> domain) { return domain; };
}

// Per-message configuration template. The class body shown here is empty, so
// its interface is not visible in this header.
// NOTE(review): presumably this is the policy type returned by
// ProtobufDomainUntypedImpl::GetPolicy(). ProtobufDomainImpl calls
// `SetDefaultDomainFor<Type>s` / `SetDomainTransformerFor<Type>s` on that
// object; confirm the members exist upstream.
template <typename Message>
class ProtoPolicy {};

// Template over a prototype type. The class body shown here is empty, so its
// interface is not visible in this header.
// NOTE(review): the name suggests a (possibly lazily resolved) pointer to a
// prototype message; confirm against upstream before relying on it.
template <typename Prototype>
class PrototypePtr {};

// Domain for std::unique_ptr<Message>, where the prototype is accepted as a
// constructor argument.
//
// Its corpus type is `absl::flat_hash_map<int, GenericDomainCorpusType>`
// (presumably keyed by field number; confirm). ProtobufDomainImpl constructs
// it as `{&T::default_instance(), /*use_lazy_initialization=*/true}`.
// NOTE(review): the class body is empty in this header, so none of the
// members ProtobufDomainImpl calls on it (GetField, WithField, GetPolicy,
// ...) are visible here.
template <typename Message>
class ProtobufDomainUntypedImpl
    : public domain_implementor::DomainBase<
          ProtobufDomainUntypedImpl<Message>, std::unique_ptr<Message>,
          absl::flat_hash_map<int, GenericDomainCorpusType>> {};

// Domain for `T` where `T` is a Protobuf message type.
// It is a small wrapper around `ProtobufDomainUntypedImpl` to make its API more
// convenient.
//
// All `With...` configuration methods are rvalue-qualified (`&&`) and return
// `ProtobufDomainImpl&&`, so they are meant to be chained on a temporary, e.g.
// `std::move(domain).WithInt32Field("a", d).WithFieldUnset("b")`. The typed
// setters generated by FUZZTEST_INTERNAL_WITH_FIELD forward to the wrapped
// untyped domain `inner_` or to its policy (`inner_.GetPolicy()`).
//
// NOTE(review): several members below (e.g. Init, GetValue, Self and the
// non-macro `With...` setters) have empty bodies in this header. For those
// with non-void return types, calling them would flow off the end of a
// value-returning function (undefined behavior). Presumably the bodies are
// elided in this copy; confirm.
template <typename T,
          typename UntypedImpl = ProtobufDomainUntypedImpl<typename T::Message>>
class ProtobufDomainImpl
    // The CRTP argument must name this exact specialization. Naming
    // `ProtobufDomainImpl<T>` would silently substitute the default
    // `UntypedImpl`, so DomainBase would refer to the wrong derived type
    // whenever a non-default UntypedImpl is supplied.
    : public domain_implementor::DomainBase<ProtobufDomainImpl<T, UntypedImpl>,
                                            T, corpus_type_t<UntypedImpl>> {
 public:
  using typename ProtobufDomainImpl::DomainBase::corpus_type;
  using typename ProtobufDomainImpl::DomainBase::value_type;
  using FieldDescriptor = ProtobufFieldDescriptor<typename T::Message>;

  corpus_type Init(absl::BitGenRef prng) {}

  uint64_t CountNumberOfFields(const corpus_type& val) {}

  uint64_t MutateNumberOfProtoFields(corpus_type& val) {}

  void Mutate(corpus_type& val, absl::BitGenRef prng, bool only_shrink) {}

  value_type GetValue(const corpus_type& v) const {}

  std::optional<corpus_type> FromValue(const value_type& value) const {}

  auto GetPrinter() const {}

  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {}

  IRObject SerializeCorpus(const corpus_type& v) const {}

  absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {}

  // Provide a conversion to the type that WithMessageField wants.
  // Makes it easier on the user.
  operator Domain<std::unique_ptr<typename T::Message>>() const {
    return inner_;
  }

  ProtobufDomainImpl&& Self() && {}

  ProtobufDomainImpl&& WithFieldsAlwaysSet(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {}

  ProtobufDomainImpl&& WithFieldsUnset(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {}

  ProtobufDomainImpl&& WithOptionalFieldsAlwaysSet(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {}

  ProtobufDomainImpl&& WithOptionalFieldsUnset(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsAlwaysSet(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsUnset(
      std::function<bool(const FieldDescriptor*)> filter =
          IncludeAll<FieldDescriptor>()) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsSize(int64_t size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsSize(
      std::function<bool(const FieldDescriptor*)> filter, int64_t size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsMinSize(int64_t min_size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsMinSize(
      std::function<bool(const FieldDescriptor*)> filter, int64_t min_size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(int64_t max_size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldsMaxSize(
      std::function<bool(const FieldDescriptor*)> filter, int64_t max_size) && {}

  ProtobufDomainImpl&& WithFieldUnset(absl::string_view field) && {}

  ProtobufDomainImpl&& WithFieldAlwaysSet(absl::string_view field) && {}

  ProtobufDomainImpl&& WithRepeatedFieldSize(
      absl::string_view field_name, std::optional<int64_t> min_size,
      std::optional<int64_t> max_size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldMinSize(absl::string_view field_name,
                                                int64_t min_size) && {}

  ProtobufDomainImpl&& WithRepeatedFieldMaxSize(absl::string_view field_name,
                                                int64_t max_size) && {}

  ProtobufDomainImpl&& WithOneofAlwaysSet(absl::string_view oneof_name) && {}

  // Generates, for one field value type, the family of typed configuration
  // methods:
  //   With<Camel>Field(field, domain)             domain for one field
  //   With<Camel>FieldAlwaysSet(field[, domain])  singular: never null;
  //                                               repeated: at least 1 element
  //   With<Camel>FieldUnset(field)                field is always null
  //   WithOptional<Camel>Field / WithRepeated<Camel>Field
  //   With{,Optional,Repeated}<Camel>Fields([filter,] domain)
  //   With<Camel>FieldsTransformed([filter,] transformer)
  // Arguments: `Camel` is the method-name infix, `cpp` the C++ value type of
  // the field, and `TAG` the type passed to GetFieldTypeDefaultDomain (the C++
  // type itself, or ProtoEnumTag / ProtoMessageTag).
  // Comments inside the macro body must use /* */: line splicing happens
  // before comments are removed, so a `//` comment would extend through every
  // continued line after it and comment out the rest of the macro.
#define FUZZTEST_INTERNAL_WITH_FIELD(Camel, cpp, TAG)                          \
  using Camel##type = MakeDependentType<cpp, T>;                               \
  ProtobufDomainImpl&& With##Camel##Field(absl::string_view field,             \
                                          Domain<Camel##type> domain)&& {      \
    const FieldDescriptor* descriptor = inner_.GetField(field);                \
    if (descriptor->is_repeated()) {                                           \
      inner_.WithField(                                                        \
          field, inner_.template GetOuterDomainForField</*is_repeated=*/true>( \
                     descriptor, std::move(domain)));                          \
    } else {                                                                   \
      inner_.WithOneofFieldWithoutNullnessConfiguration(field);                \
      inner_.WithField(                                                        \
          field,                                                               \
          inner_.template GetOuterDomainForField</*is_repeated=*/false>(       \
              descriptor, std::move(domain)));                                 \
    }                                                                          \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& With##Camel##FieldAlwaysSet(                            \
      absl::string_view field, Domain<Camel##type> domain)&& {                 \
    const FieldDescriptor* descriptor = inner_.GetField(field);                \
    if (descriptor->is_repeated()) {                                           \
      inner_.WithField(                                                        \
          field,                                                               \
          SequenceContainerOfImpl<std::vector<Camel##type>, decltype(domain)>( \
              std::move(domain))                                               \
              .WithMinSize(1));                                                \
    } else {                                                                   \
      inner_.WithOneofField(field, OptionalPolicy::kWithoutNull);              \
      inner_.WithField(                                                        \
          field, OptionalOfImpl<std::optional<Camel##type>>(std::move(domain)) \
                     .SetWithoutNull());                                       \
    }                                                                          \
    return std::move(*this);                                                   \
  }                                                                            \
  /* TODO(b/271123298): Remove the following two methods and replace them with \
  WithField(Unset/AlwaysSet) */                                                \
  ProtobufDomainImpl&& With##Camel##FieldUnset(absl::string_view field)&& {    \
    auto default_domain =                                                      \
        inner_.template GetFieldTypeDefaultDomain<TAG>(field);                 \
    inner_.WithOneofField(field, OptionalPolicy::kAlwaysNull);                 \
    inner_.WithField(field, OptionalOfImpl<std::optional<Camel##type>>(        \
                                std::move(default_domain))                     \
                                .SetAlwaysNull());                             \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& With##Camel##FieldAlwaysSet(                            \
      absl::string_view field)&& {                                             \
    return std::move(*this).With##Camel##FieldAlwaysSet(                       \
        field, inner_.template GetFieldTypeDefaultDomain<TAG>(field));         \
  }                                                                            \
  ProtobufDomainImpl&& WithOptional##Camel##Field(                             \
      absl::string_view field,                                                 \
      Domain<MakeDependentType<std::optional<cpp>, T>> domain)&& {             \
    FailIfIsOneof(field);                                                      \
    inner_.WithField(field, std::move(domain));                                \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& WithRepeated##Camel##Field(                             \
      absl::string_view field,                                                 \
      Domain<MakeDependentType<std::vector<cpp>, T>> domain)&& {               \
    inner_.WithField(field, std::move(domain));                                \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& With##Camel##Fields(Domain<Camel##type> domain)&& {     \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(                          \
        IncludeAll<FieldDescriptor>(), std::move(domain));                     \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& With##Camel##Fields(                                    \
      std::function<bool(const FieldDescriptor*)>&& filter,                    \
      Domain<Camel##type> domain)&& {                                          \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(std::move(filter),        \
                                                     std::move(domain));       \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& WithOptional##Camel##Fields(                            \
      Domain<Camel##type> domain)&& {                                          \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(                          \
        IsOptional<FieldDescriptor>(), std::move(domain));                     \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& WithOptional##Camel##Fields(                            \
      std::function<bool(const FieldDescriptor*)>&& filter,                    \
      Domain<Camel##type> domain)&& {                                          \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(                          \
        And(IsOptional<FieldDescriptor>(), std::move(filter)),                 \
        std::move(domain));                                                    \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& WithRepeated##Camel##Fields(                            \
      Domain<Camel##type> domain)&& {                                          \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(                          \
        IsRepeated<FieldDescriptor>(), std::move(domain));                     \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& WithRepeated##Camel##Fields(                            \
      std::function<bool(const FieldDescriptor*)>&& filter,                    \
      Domain<Camel##type> domain)&& {                                          \
    inner_.GetPolicy().SetDefaultDomainFor##Camel##s(                          \
        And(IsRepeated<FieldDescriptor>(), std::move(filter)),                 \
        std::move(domain));                                                    \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& With##Camel##FieldsTransformed(                         \
      std::function<Domain<Camel##type>(Domain<Camel##type>)>&&                \
          transformer)&& {                                                     \
    inner_.GetPolicy().SetDomainTransformerFor##Camel##s(                      \
        IncludeAll<FieldDescriptor>(), std::move(transformer));                \
    return std::move(*this);                                                   \
  }                                                                            \
  ProtobufDomainImpl&& With##Camel##FieldsTransformed(                         \
      std::function<bool(const FieldDescriptor*)>&& filter,                    \
      std::function<Domain<Camel##type>(Domain<Camel##type>)>&&                \
          transformer)&& {                                                     \
    inner_.GetPolicy().SetDomainTransformerFor##Camel##s(                      \
        std::move(filter), std::move(transformer));                            \
    return std::move(*this);                                                   \
  }

  // One invocation per field value type that VisitProtobufField dispatches on:
  // bool, the integral types, floating point, string, enum and message.
  // A function-like macro name without an argument list is not expanded, so
  // every line here must carry its (Camel, cpp, TAG) arguments.
  FUZZTEST_INTERNAL_WITH_FIELD(Bool, bool, bool)
  FUZZTEST_INTERNAL_WITH_FIELD(Int32, int32_t, int32_t)
  FUZZTEST_INTERNAL_WITH_FIELD(UInt32, uint32_t, uint32_t)
  FUZZTEST_INTERNAL_WITH_FIELD(Int64, int64_t, int64_t)
  FUZZTEST_INTERNAL_WITH_FIELD(UInt64, uint64_t, uint64_t)
  FUZZTEST_INTERNAL_WITH_FIELD(Float, float, float)
  FUZZTEST_INTERNAL_WITH_FIELD(Double, double, double)
  FUZZTEST_INTERNAL_WITH_FIELD(String, std::string, std::string)
  FUZZTEST_INTERNAL_WITH_FIELD(Enum, int, ProtoEnumTag)
  FUZZTEST_INTERNAL_WITH_FIELD(Protobuf, std::unique_ptr<typename T::Message>,
                               ProtoMessageTag)

#undef FUZZTEST_INTERNAL_WITH_FIELD

  // The following methods automatically cast Domain<Proto> to
  // Domain<unique_ptr<Message>>

  template <typename Protobuf>
  ProtobufDomainImpl&& WithProtobufField(absl::string_view field,
                                         Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithProtobufFieldAlwaysSet(absl::string_view field,
                                                  Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithProtobufFields(Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithProtobufFields(
      std::function<bool(const FieldDescriptor*)>&& filter,
      Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithOptionalProtobufFields(Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithOptionalProtobufFields(
      std::function<bool(const FieldDescriptor*)>&& filter,
      Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithRepeatedProtobufFields(Domain<Protobuf> domain) && {}

  template <typename Protobuf>
  ProtobufDomainImpl&& WithRepeatedProtobufFields(
      std::function<bool(const FieldDescriptor*)>&& filter,
      Domain<Protobuf> domain) && {}

  template <typename OptionalProtobufDomain>
  ProtobufDomainImpl&& WithOptionalProtobufField(
      absl::string_view field, OptionalProtobufDomain domain) && {}

  template <typename RepeatedProtobufDomain>
  ProtobufDomainImpl&& WithRepeatedProtobufField(
      absl::string_view field, RepeatedProtobufDomain domain) && {}

 private:
  void FailIfIsOneof(absl::string_view field) {}

  template <typename Inner>
  Domain<std::unique_ptr<typename T::Message>> ToUntypedProtoDomain(
      Inner inner_domain) {}

  template <typename Inner>
  Domain<std::optional<std::unique_ptr<typename T::Message>>>
  ToOptionalUntypedProtoDomain(Inner inner_domain) {}

  template <typename Inner>
  Domain<std::vector<std::unique_ptr<typename T::Message>>>
  ToRepeatedUntypedProtoDomain(Inner inner_domain) {}

  // The wrapped untyped domain, built from the message's default instance
  // with lazy initialization enabled.
  UntypedImpl inner_{&T::default_instance(), /*use_lazy_initialization=*/true};
};

// Arbitrary<T>() support for protobuf message types `T`.
// NOTE(review): only the specialization's template-id is visible here (no
// `template <typename T> class ... { ... }` around it), so `T` is undeclared
// at this point; confirm the full partial specialization upstream.
ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_v<T>>>;

// Arbitrary<T>() support for protobuf enum types `T`.
// NOTE(review): only the specialization's template-id is visible here (no
// `template <typename T> class ... { ... }` around it), so `T` is undeclared
// at this point; confirm the full partial specialization upstream.
ArbitraryImpl<T, std::enable_if_t<is_protocol_buffer_enum_v<T>>>;

}  // namespace fuzztest::internal

#ifdef _WIN32
#pragma pop_macro("GetMessage")
#endif

#endif  // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_PROTOBUF_DOMAIN_IMPL_H_