folly/folly/io/async/test/MockAsyncSocketLegacyObserver.h

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// TODO(lume): remove this file in future diff once we replace LegacyObserver
// with AsyncSocket::ManagedObserver

#pragma once

#include <algorithm>
#include <cstdint>
#include <vector>

#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSocketException.h>
#include <folly/portability/GMock.h>

namespace folly {
namespace test {

class MockAsyncSocketLegacyLifecycleObserver
    : public AsyncSocket::LegacyLifecycleObserver {
 public:
  using AsyncSocket::LegacyLifecycleObserver::LegacyLifecycleObserver;
  MOCK_METHOD(void, observerAttachMock, (AsyncSocket*));
  MOCK_METHOD(void, observerDetachMock, (AsyncSocket*));
  MOCK_METHOD(void, destroyMock, (AsyncSocket*));
  MOCK_METHOD(void, closeMock, (AsyncSocket*));
  MOCK_METHOD(void, connectAttemptMock, (AsyncSocket*));
  MOCK_METHOD(void, connectSuccessMock, (AsyncSocket*));
  MOCK_METHOD(
      void, connectErrorMock, (AsyncSocket*, const AsyncSocketException&));
  MOCK_METHOD(void, evbAttachMock, (AsyncSocket*, EventBase*));
  MOCK_METHOD(void, evbDetachMock, (AsyncSocket*, EventBase*));
  MOCK_METHOD(
      void, byteEventMock, (AsyncSocket*, const AsyncSocket::ByteEvent&));
  MOCK_METHOD(void, byteEventsEnabledMock, (AsyncSocket*));
  MOCK_METHOD(
      void,
      byteEventsUnavailableMock,
      (AsyncSocket*, const AsyncSocketException&));
  MOCK_METHOD(
      void,
      prewriteMock,
      (AsyncSocket*, const PrewriteState&, PrewriteRequestContainer&));

 private:
  void observerAttach(AsyncSocket* socket) noexcept override {
    observerAttachMock(socket);
  }
  void observerDetach(AsyncSocket* socket) noexcept override {
    observerDetachMock(socket);
  }
  void destroy(AsyncSocket* socket) noexcept override { destroyMock(socket); }
  void close(AsyncSocket* socket) noexcept override { closeMock(socket); }
  void connectAttempt(AsyncSocket* socket) noexcept override {
    connectAttemptMock(socket);
  }
  void connectSuccess(AsyncSocket* socket) noexcept override {
    connectSuccessMock(socket);
  }
  void connectError(
      AsyncSocket* socket, const AsyncSocketException& ex) noexcept override {
    connectErrorMock(socket, ex);
  }
  void evbAttach(AsyncSocket* socket, EventBase* eb) noexcept override {
    evbAttachMock(socket, eb);
  }
  void evbDetach(AsyncSocket* socket, EventBase* eb) noexcept override {
    evbDetachMock(socket, eb);
  }
  void byteEvent(
      AsyncSocket* socket, const AsyncSocket::ByteEvent& ev) noexcept override {
    byteEventMock(socket, ev);
  }
  void byteEventsEnabled(AsyncSocket* socket) noexcept override {
    byteEventsEnabledMock(socket);
  }
  void byteEventsUnavailable(
      AsyncSocket* socket, const AsyncSocketException& ex) noexcept override {
    byteEventsUnavailableMock(socket, ex);
  }
  void prewrite(
      AsyncSocket* socket,
      const PrewriteState& state,
      PrewriteRequestContainer& container) noexcept override {
    prewriteMock(socket, state, container);
  }
};

/**
 * Extends mock class to simplify ByteEvents tests.
 */
class MockAsyncSocketLegacyLifecycleObserverForByteEvents
    : public MockAsyncSocketLegacyLifecycleObserver {
 public:
  MockAsyncSocketLegacyLifecycleObserverForByteEvents(
      AsyncSocket* socket,
      const MockAsyncSocketLegacyLifecycleObserverForByteEvents::Config&
          observerConfig)
      : MockAsyncSocketLegacyLifecycleObserver(observerConfig),
        socket_(socket) {
    ON_CALL(*this, byteEventMock(testing::_, testing::_))
        .WillByDefault(testing::Invoke(
            [this](
                AsyncSocket* socketport, const AsyncSocket::ByteEvent& event) {
              CHECK_EQ(this->socket_, socketport);
              byteEvents_.emplace_back(event);
            }));
    ON_CALL(*this, byteEventsEnabledMock(testing::_))
        .WillByDefault(testing::Invoke([this](AsyncSocket* socketport) {
          CHECK_EQ(this->socket_, socketport);
          byteEventsEnabledCalled_++;
        }));

    ON_CALL(*this, byteEventsUnavailableMock(testing::_, testing::_))
        .WillByDefault(testing::Invoke(
            [this](AsyncSocket* socketport, const AsyncSocketException& ex) {
              CHECK_EQ(this->socket_, socketport);
              byteEventsUnavailableCalled_++;
              byteEventsUnavailableCalledEx_.emplace(ex);
            }));
    socket_->addLifecycleObserver(this);
  }

  const std::vector<AsyncSocket::ByteEvent>& getByteEvents() {
    return byteEvents_;
  }

  folly::Optional<AsyncSocket::ByteEvent> getByteEventReceivedWithOffset(
      const uint64_t offset, const AsyncSocket::ByteEvent::Type type) {
    for (const auto& byteEvent : byteEvents_) {
      if (type == byteEvent.type && offset == byteEvent.offset) {
        return byteEvent;
      }
    }
    return folly::none;
  }

  folly::Optional<uint64_t> maxOffsetForByteEventReceived(
      const AsyncSocket::ByteEvent::Type type) {
    folly::Optional<uint64_t> maybeMaxOffset;
    for (const auto& byteEvent : byteEvents_) {
      if (type == byteEvent.type &&
          (!maybeMaxOffset.has_value() ||
           maybeMaxOffset.value() <= byteEvent.offset)) {
        maybeMaxOffset = byteEvent.offset;
      }
    }
    return maybeMaxOffset;
  }

  bool checkIfByteEventReceived(
      const AsyncSocket::ByteEvent::Type type, const uint64_t offset) {
    for (const auto& byteEvent : byteEvents_) {
      if (type == byteEvent.type && offset == byteEvent.offset) {
        return true;
      }
    }
    return false;
  }

  void waitForByteEvent(
      const AsyncSocket::ByteEvent::Type type, const uint64_t offset) {
    while (!checkIfByteEventReceived(type, offset)) {
      socket_->getEventBase()->loopOnce();
    }
  }

  // Exposed ByteEvent helper fields with const
  const uint32_t& byteEventsEnabledCalled{byteEventsEnabledCalled_};
  const uint32_t& byteEventsUnavailableCalled{byteEventsUnavailableCalled_};
  const folly::Optional<AsyncSocketException>& byteEventsUnavailableCalledEx{
      byteEventsUnavailableCalledEx_};
  const std::vector<AsyncSocket::ByteEvent>& byteEvents{byteEvents_};

 private:
  AsyncSocket* socket_;

  // ByteEvents helpers
  uint32_t byteEventsEnabledCalled_{0};
  uint32_t byteEventsUnavailableCalled_{0};
  folly::Optional<AsyncSocketException> byteEventsUnavailableCalledEx_;
  std::vector<AsyncSocket::ByteEvent> byteEvents_;
};

} // namespace test
} // namespace folly