chromium/components/exo/wayland/compatibility_test/template_client_event_receiver_version_tests.cc.tmpl

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <algorithm>
#include <memory>
#include <optional>
#include <type_traits>

#include "base/check.h"
#include "base/files/scoped_file.h"
#include "base/logging.h"
#include "components/exo/wayland/compatibility_test/client_compatibility_test.h"
{% for protocol in protocols %}
#include "components/exo/wayland/compatibility_test/generated-{{ protocol.filename }}-client-helpers.h"
{% endfor %}
#include "components/exo/wayland/compatibility_test/wayland_client_event_recorder.h"
#include "components/exo/wayland/compatibility_test/wayland_client_registry.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"


{#
Figure out which Wayland interfaces need to be tested.

We only need to test interfaces which have at least one event whose "since"
is greater than the minimum version needed to construct the interface, as
reported by get_minimum_version_to_construct(). The selected interfaces are
collected, in protocol order, into interfaces_to_test.
#}
{% set interfaces_to_test = [] %}
{% for protocol in protocols %}
  {% for interface in protocol.interfaces %}
    {% set min_version = get_minimum_version_to_construct(interface) %}
    {#
    The first selectattr keeps only events with a truthy "since"; the second
    keeps those whose "since" is greater than min_version. Presumably the
    first filter exists so that events without a "since" are never compared
    against a number -- confirm against the protocol data model.
    #}
    {% if interface.events|selectattr("since")|selectattr("since", ">", min_version)|list|length %}
      {% set _ = interfaces_to_test.append(interface) %}
    {% endif %}
  {% endfor %}
{% endfor %}

namespace exo {
namespace wayland {
namespace compatibility {
namespace test {
namespace {

// Each of these generated fixture classes can be overridden by
// non-generated code in wayland_client_event_receiver_version_fixtures.h.

{% for target_interface in interfaces_to_test %}
  {% set steps = get_construction_steps(target_interface) %}

// Generated default fixture for one target interface. It declares one
// overridable hook per construction step plus the shared validation entry
// points, and owns every Wayland object the test creates.
class {{ target_interface.name | pascal  }}EventTestBase : public ClientCompatibilityTest {
 protected:
  // Returns true if this test should be skipped. The base implementation
  // checks that the server supports the required interface versions.
  virtual bool ShouldSkip(uint32_t target_interface_version) const noexcept;
  {% for step in steps %}
    {% if step.ctor %}
      {% if step.ctor.message.is_event %}
  // Overridden by non-generated code to ensure the server sends the interface
  // creation event. The generated default does nothing.
  virtual void NudgeServerFor{{ step.instance_name | pascal }}Creation() noexcept;
      {% endif %}
  // Factory function that creates the object stored in {{ step.instance_name }}
  virtual std::unique_ptr<struct {{ step.interface.name }}> Make{{ step.instance_name | pascal }}() noexcept;
    {% endif %}
  {% endfor %}
  // Overridden by non-generated code to ensure the server sends the required
  // events. The generated default does nothing.
  virtual void NudgeServerForTestEvents() noexcept;

  // Connects to the server, builds every object in the construction chain,
  // records the events received on the target interface and checks them
  // against the "since" version of each event.
  void ValidateEventsImpl(uint32_t target_interface_version) noexcept;
  // Runs ValidateEventsImpl() and, on failure, logs a hint describing how the
  // generated fixture can be overridden.
  void ValidateEvents(uint32_t target_interface_version) noexcept;

  // Interfaces and helpers needed for the test, in construction order.
  std::unique_ptr<wl_display> display_;
  std::unique_ptr<WaylandClientRegistry> registry_;
  {% for step in steps %}
  std::unique_ptr<struct {{ step.interface.name }}> {{ step.instance_name }};
  {% endfor %}
};

{% endfor %}

}  // namespace
}  // namespace test
}  // namespace compatibility
}  // namespace wayland
}  // namespace exo

// This is intentionally included after all the fixture class definitions
// above, so that the header file can provide overrides.
#include "components/exo/wayland/compatibility_test/wayland_client_event_receiver_version_fixtures.h"

namespace exo {
namespace wayland {
namespace compatibility {
namespace test {
namespace {

{% for target_interface in interfaces_to_test %}
  {% set steps = get_construction_steps(target_interface) %}
  {% set base_fixture_name = target_interface.name | pascal + "EventTestBase" %}
// Make sure this is defined even if not complete.
struct {{ target_interface.name | pascal  }}EventTest;

// Use {{ target_interface.name | pascal  }}EventTest if it is defined, otherwise
// fall back to {{ target_interface.name | pascal  }}EventTestBase.
constexpr bool k{{ target_interface.name | pascal }}EventTestFixtureUsesOverride =
    std::is_convertible<
        {{ target_interface.name | pascal  }}EventTest*,
        {{ target_interface.name | pascal  }}EventTestBase*
    >::value;
using {{ target_interface.name | pascal  }}EventTestFixture =
    std::conditional_t<
        k{{ target_interface.name | pascal }}EventTestFixtureUsesOverride,
        {{ target_interface.name | pascal  }}EventTest,
        {{ target_interface.name | pascal  }}EventTestBase>;

bool {{ base_fixture_name }}::ShouldSkip(uint32_t target_interface_version) const noexcept {
  return false
  {% for step in steps %}
    {% if not step.ctor %}
      {% if step.interface.protocol == target_interface.protocol %}
    || !registry_->Has<struct {{step.interface.name}}>(
        std::max(target_interface_version, {{ step.minimum_version }}u))
      {% else %}
    || !registry_->Has<struct {{step.interface.name}}>({{ step.minimum_version }}u)
      {% endif %}
    {% endif %}
  {% endfor %}
  ;
}

  {% for step in steps %}
    {% if step.ctor %}
      {% if step.ctor.message.is_event %}
void {{ base_fixture_name }}::NudgeServerFor{{ step.instance_name | pascal }}Creation() noexcept {}

      {% endif %}
std::unique_ptr<struct {{ step.interface.name }}>
{{ base_fixture_name }}::Make{{ step.instance_name | pascal }}() noexcept {
      {% if not step.ctor.message.is_event %}
  auto {{ step.instance_name }}obj = std::unique_ptr<struct {{ step.interface.name }}>(
  {{ step.ctor.interface_step.interface.name }}_{{ step.ctor.message.name }}(
      {{ step.ctor.interface_step.instance_name }}.get()
        {% for arg in step.ctor.message.args %}
            {% if arg.type == "new_id" %}
            {% elif arg.type == "object" %}
      , {{ step.ctor.object_args[loop.index0].instance_name }}.get()
            {% elif arg.type == "fd" %}
      , -1 /* generated default {{ arg.type }} value for {{ arg.name }} */
            {% elif arg.type == "string" %}
      , "" /* generated default {{ arg.type }} value for {{ arg.name }} */
            {% elif arg.type == "array" %}
      , nullptr_t /* generated default {{ arg.type }} value for {{ arg.name }} */
            {% elif arg.type in ["int", "uint", "fixed"] %}
      , 0 /* generated default {{ arg.type }} value for {{ arg.name }} */
            {% else %}
      , nullptr_t /* generated default {{ arg.type }} value for {{ arg.name }} */
            {% endif %}
        {% endfor %}
      ));
      {% else %}
  std::unique_ptr<struct {{ step.interface.name }}> {{ step.instance_name }}obj;

  static {{ step.ctor.interface_step.interface.name }}_listener {{ step.ctor.interface_step.interface.name }}_listener = {
        {% for event in step.ctor.interface_step.interface.events %}
      /* {{ event.name }} */
      [](void* user_data_, struct {{ step.ctor.interface_step.interface.name }}* interface_
          {% for arg in event.args %}
          {{ get_c_arg_for_client_event_arg(arg) }}
          {% endfor %}
      ) {
          {% if event.name == step.ctor.message.name %}
        if (user_data_) {
            {% for arg in event.args %}
              {% if arg.type == "ctor" and arg.interface == step.interface.name %}
          auto* container = static_cast<std::unique_ptr<struct {{ step.interface.name }}>*>(context);
          (*container) = {{ arg.name }};
              {% endif %}
            {% endfor %}
        } else {
          LOG(INFO) << "Received unexpected " << "{{ step.ctor.interface_step.interface.name }}::{{ event.name }}";
        }
          {% else %}
        LOG(INFO) << "Received unexpected " << "{{ step.ctor.interface_step.interface.name }}::{{ event.name }}";
          {% endif %}

          {% for arg in event.args %}
            {% if event.args.type == "fd" %}
        base::ScopedFD unused_scoped_{{ arg.name }}(fd);
            {% endif %}
          {% endfor %}
      },
        {% endfor %}
  };

  {
    int err = {{ step.ctor.interface_step.interface.name }}_add_listener(
      {{ step.ctor.interface_step.instance_name }}.get(),
      &{{ step.ctor.interface_step.interface.name }}_listener,
      &{{ step.instance_name }});
    if (err != 0) {
      ADD_FAILURE() << "Failed to add listener!";
      return {};
    }
  }

  // Sync with the server to ensure it receives all requests
  {
    int err = wl_display_roundtrip(display_.get());
    if (err < 0) {
      ADD_FAILURE() << "Failed to roundtrip!";
      return {};
    }
  }

  NudgeServerFor{{ step.instance_name | pascal }}Creation();

  // Sync with the server again to process all events
  {
    int err = wl_display_roundtrip(display_.get());
    if (err < 0) {
      ADD_FAILURE() << "Failed to roundtrip!";
      return {};
    }
  }

  // We want to keep the instance but want to ignore further events.
  {{ step.ctor.interface_step.interface.name }}_set_user_data({{ step.ctor.interface_step.instance_name }}.get(), nullptr);

  if (!{{ step.instance_name }}.get())
    ADD_FAILURE()
        << "Required event {{ step.ctor.interface_step.interface.name }}::{{ step.ctor.message.name }}\n"
        << "not received! This is needed to construct a {{ step.interface.name }} instance.\n"
        << "You may need to define "
        << "{{ step.interface.name | pascal  }}EventTest::NudgeServerFor{{ step.instance_name | pascal }}Creation()";

      {% endif %}
  return {{ step.instance_name }}obj;
}

    {% endif %}
  {% endfor %}

// Default no-op hook, called by ValidateEventsImpl() between its last two
// roundtrips. A fixture override can use it to make the server send the
// events under test.
void {{ base_fixture_name }}::NudgeServerForTestEvents() noexcept {}

// Connects to the Wayland server, builds every object in the construction
// chain for {{ target_interface.name }}, records the versioned events received
// on it, and checks each against its "since" version:
//  - at exactly the "since" version the event is expected to arrive;
//  - at or above it, any received event must report the requested version;
//  - below it, the event must not arrive at all.
// Uses gtest ASSERT_*/GTEST_SKIP, so it may return early; callers should check
// HasFailure() rather than assume it ran to completion.
void {{ base_fixture_name }}::ValidateEventsImpl(uint32_t target_interface_version) noexcept {
  display_ = std::unique_ptr<wl_display>(wl_display_connect(nullptr));
  ASSERT_THAT(display_.get(), ::testing::IsTrue());

  registry_ = std::make_unique<WaylandClientRegistry>(display_.get());
  ASSERT_THAT(registry_.get(), ::testing::IsTrue());
  ASSERT_THAT(wl_display_roundtrip(display_.get()), ::testing::Ge(0));

  if (ShouldSkip(target_interface_version)) {
    GTEST_SKIP();
  }

  // Build each object in construction order: bind globals from the registry,
  // create everything else through its factory function.
  {% for step in steps %}
    {% if step.ctor %}
  {{ step.instance_name }} = Make{{ step.instance_name|pascal }}();
    {% else %}
      {% if step.interface.protocol == target_interface.protocol %}
  {{ step.instance_name }} = registry_->Bind<struct {{step.interface.name}}>(
      std::max(target_interface_version, {{ step.minimum_version }}u));
      {% else %}
  {{ step.instance_name }} = registry_->Bind<struct {{step.interface.name}}>(
      {{ step.minimum_version }}u);
      {% endif %}
    {% endif %}
  ASSERT_THAT({{ step.instance_name }}.get(), ::testing::IsTrue());
  {% endfor %}

  {% set min_version = get_minimum_version_to_construct(target_interface) %}
// Min version: {{ min_version }}
  {% for event in target_interface.events %}
// event {{ event.name }} since {{ event.since }} {{  event.since and event.since > min_version }}
  {% endfor %}

  // One handler per event. Only events introduced after the minimum
  // constructible version are recorded.
  static {{ target_interface.name }}_listener {{ target_interface.name }}_listener = {
  {% for event in target_interface.events %}
      /* {{ event.name }} */
      [](void* user_data_, struct {{ target_interface.name }}* interface_
    {% for arg in event.args %}
          {{ get_c_arg_for_client_event_arg(arg) }}
    {% endfor %}
      ) {
    {% if event.since and event.since > min_version %}
        auto* recorder = static_cast<EventRecorder*>(user_data_);
        uint32_t version = {{ target_interface.name }}_get_version(interface_);
        recorder->OnEvent(
            "{{ target_interface.name }}::{{ event.name }}",
            version);
    {% endif %}

    {% for arg in event.args %}
      {% if arg.type == "fd" %}
        // Received file descriptors are owned by the client; close them.
        base::ScopedFD unused_scoped_{{ arg.name }}({{ arg.name }});
      {% endif %}
    {% endfor %}
      },
  {% endfor %}
  };

  EventRecorder recorder;

  {
    int err = {{ target_interface.name }}_add_listener(
        {{ steps[-1].instance_name }}.get(),
        &{{ target_interface.name }}_listener,
        &recorder);
    ASSERT_THAT(err, ::testing::Eq(0)) << "Failed to add target listener!";
  }

  // Ensure the server side receives all the creation requests
  ASSERT_THAT(wl_display_roundtrip(display_.get()), ::testing::Ge(0));

  NudgeServerForTestEvents();

  // Ensure the server side events are received and processed
  ASSERT_THAT(wl_display_roundtrip(display_.get()), ::testing::Ge(0));

  // These expectations are really what this test is checking.
  {% for event in target_interface.events %}
    {% if event.since and event.since > min_version %}
  {
    const char* event_name =
        "{{ target_interface.name }}::{{ event.name }}";
    constexpr uint32_t event_since_version =
        {{ event.since or 1 }}u;
    std::optional<uint32_t> received_at_version =
        recorder.MaybeGetReceivedAtVersion(event_name);

    if (target_interface_version == event_since_version) {
      EXPECT_THAT(received_at_version, ::testing::Not(::testing::Eq(std::nullopt)))
          << "Failed to receive " << event_name << " at interface version "
          << target_interface_version
          << " as the event is defined as since version "
          << event_since_version << ".\n"
          << "\n"
          << "You may need to override:\n"
          << "    {{ target_interface.name | pascal  }}EventTest::NudgeServerForTestEvents()\n"
          << "to ensure the event is sent by the server.";
    }

    if (target_interface_version >= event_since_version) {
      ASSERT_THAT(received_at_version, ::testing::AnyOf(
                      ::testing::Eq(std::nullopt),
                      ::testing::Optional(::testing::Eq(target_interface_version))))
          << "Unexpected version mismatch. Test expected to create a\n"
          << "   {{ target_interface.name }}"
          << " at version "
          << target_interface_version
          << " but actually created the interface at version "
          << received_at_version.value_or(0u)
          << ".";
    }

    if (target_interface_version < event_since_version) {
      EXPECT_THAT(received_at_version, ::testing::Eq(std::nullopt))
          << "Unexpectedly received "
          << event_name
          << " at interface version "
          << received_at_version.value_or(0u)
          << ".\n"
          << "The event is defined as since version "
          << event_since_version
          << " and can cause the client to\n"
          << "abort or crash if sent to a client that uses an earlier version.";
    }
  }

      {% endif %}
  {% endfor %}
}

void {{ base_fixture_name }}::ValidateEvents(uint32_t target_interface_version) noexcept {
  ValidateEventsImpl(target_interface_version);

  if (HasFailure() && !k{{ target_interface.name | pascal }}EventTestFixtureUsesOverride) {
      LOG(INFO)
          << "\n"
          << "Note:\n"
          << "    {{ target_interface.name | pascal  }}EventTest\n"
          << "can be defined to override the default generated code for this test.\n"
          << "\n"
          << "For example, to disable/skip this test, add this to\n"
          << "components/exo/wayland/compatibility_test/wayland_client_event_receiver_version_fixtures.h:\n"
          << "\n"
          << "    struct {{ target_interface.name | pascal  }}EventTest\n"
          << "        : public {{ target_interface.name | pascal  }}EventTestBase {\n"
          << "         bool ShouldSkip(uint32_t) const noexcept override;\n"
          << "    };\n"
          << "    bool {{ target_interface.name | pascal  }}EventTest::ShouldSkip(uint32_t) const noexcept {\n"
          << "        return true;\n"
          << "    }\n"
          << "\n"
          << "Otherwise see that header for more general notes.";
    }
}

  {% for target_interface_version in get_versions_to_test_for_event_delivery(target_interface) %}
// Validates event delivery with the target interface requested at version
// {{ target_interface_version }}.
TEST_F({{ target_interface.name | pascal }}EventTestFixture, EventsForV{{ target_interface_version }}) {
  ValidateEvents({{ target_interface_version }}u);
}

  {% endfor %}
{% endfor %}
}  // namespace
}  // namespace test
}  // namespace compatibility
}  // namespace wayland
}  // namespace exo