folly/folly/io/test/ShutdownSocketSetTest.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/io/ShutdownSocketSet.h>

#include <atomic>
#include <chrono>
#include <thread>

#include <glog/logging.h>

#include <folly/net/NetOps.h>
#include <folly/net/NetworkSocket.h>
#include <folly/portability/GTest.h>
#include <folly/synchronization/Baton.h>

namespace folly {
namespace test {

// Minimal TCP server used as a fixture by the tests in this file.
//
// The constructor opens a listening socket, records the port it was bound to
// and starts a background thread that accepts a connection. Both the listening
// socket and the accepted socket are registered with shutdownSocketSet_, which
// is the object under test.
class Server {
 public:
  // Creates, binds and listens on the accept socket, then starts the accept
  // thread.
  Server();

  // Records the requested stop mode (ABORTIVE or ORDERLY) and shuts down the
  // accept socket. The accept thread closes the accepted clients with that
  // mode once it leaves its loop.
  void stop(bool abortive);
  // Joins the accept thread.
  void join();
  // Port the accept socket is bound to, in host byte order.
  int port() const { return port_; }
  // Closes every socket currently held in fds_ (setting SO_LINGER to {1, 0}
  // first when `abortive` is true), empties fds_ and returns how many sockets
  // were in it.
  int closeClients(bool abortive);

  // Waits until the accept thread has accepted a connection, then calls
  // shutdownAll(abortive) on the ShutdownSocketSet.
  void shutdownAll(bool abortive);

 private:
  // Listening socket; created in the constructor.
  NetworkSocket acceptSocket_;
  // Port read back via getsockname() after bind.
  int port_;
  enum StopMode { NO_STOP, ORDERLY, ABORTIVE };
  // Written by stop(), polled by the accept thread.
  std::atomic<StopMode> stop_;
  // Runs the accept loop started by the constructor.
  std::thread serverThread_;
  // Accepted client sockets that have not been closed yet.
  std::vector<NetworkSocket> fds_;
  // Tracks the accept socket and every accepted socket.
  folly::ShutdownSocketSet shutdownSocketSet_;
  // Posted by the accept thread after a connection has been accepted and
  // registered; awaited by shutdownAll().
  folly::Baton<> baton_;
};

// Sets up the listening socket and launches the accept thread.
//
// The socket is bound to INADDR_ANY with port 0 and the port actually assigned
// is read back with getsockname() and stored in port_. Every setup step is
// CHECKed, so a failure aborts the test process.
//
// The accept thread loops while no stop has been requested. It registers each
// accepted socket with shutdownSocketSet_, remembers it in fds_ and posts
// baton_. CHECK(first) enforces that at most one connection is ever accepted.
// When the loop ends because stop() was called, the accepted clients are
// closed with the requested mode; in every case the accept socket is closed
// through the ShutdownSocketSet.
Server::Server() : acceptSocket_(), port_(0), stop_(NO_STOP) {
  acceptSocket_ = netops::socket(PF_INET, SOCK_STREAM, 0);
  CHECK_NE(acceptSocket_, NetworkSocket());
  shutdownSocketSet_.add(acceptSocket_);

  // Value-initialize so that sin_zero (and any platform-specific fields such
  // as sin_len) are zero instead of stack garbage when handed to bind().
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_port = 0;
  addr.sin_addr.s_addr = INADDR_ANY;
  CHECK_ERR(netops::bind(
      acceptSocket_, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)));

  CHECK_ERR(netops::listen(acceptSocket_, 10));

  socklen_t addrLen = sizeof(addr);
  CHECK_ERR(netops::getsockname(
      acceptSocket_, reinterpret_cast<sockaddr*>(&addr), &addrLen));

  port_ = ntohs(addr.sin_port);

  serverThread_ = std::thread([this] {
    bool first = true;
    while (stop_ == NO_STOP) {
      // Zeroed for the same reason as `addr` above; accept() fills it in.
      sockaddr_in peer{};
      socklen_t peerLen = sizeof(peer);
      auto fd = netops::accept(
          acceptSocket_, reinterpret_cast<sockaddr*>(&peer), &peerLen);
      if (fd == NetworkSocket()) {
        if (errno == EINTR) {
          continue;
        }
        if (errno == EINVAL || errno == ENOTSOCK) { // socket broken
          // Presumably the result of the accept socket having been shut down
          // by stop() or by ShutdownSocketSet::shutdownAll().
          break;
        }
      }
      // Any other accept failure is unexpected and aborts here.
      CHECK_NE(fd, NetworkSocket());
      shutdownSocketSet_.add(fd);
      fds_.push_back(fd);
      CHECK(first);
      first = false;
      baton_.post();
    }

    // Only an explicit stop() closes the clients; after a shutdownAll() the
    // accepted socket stays in fds_.
    if (stop_ != NO_STOP) {
      closeClients(stop_ == ABORTIVE);
    }

    shutdownSocketSet_.close(acceptSocket_);
  });
}

int Server::closeClients(bool abortive) {
  for (auto fd : fds_) {
    if (abortive) {
      struct linger l = {1, 0};
      CHECK_ERR(netops::setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
    }
    shutdownSocketSet_.close(fd);
  }
  int n = fds_.size();
  fds_.clear();
  return n;
}

// Blocks on baton_, which the accept thread posts after it has accepted and
// registered a connection, and then invokes shutdownAll(abortive) on the
// ShutdownSocketSet.
void Server::shutdownAll(bool abortive) {
  // Ensure the client socket is already part of the set before shutting down.
  baton_.wait();
  shutdownSocketSet_.shutdownAll(abortive);
}

void Server::stop(bool abortive) {
  stop_ = abortive ? ABORTIVE : ORDERLY;
  netops::shutdown(acceptSocket_, SHUT_RDWR);
}

// Joins the accept thread started by the constructor.
void Server::join() {
  serverThread_.join();
}

// Opens a TCP socket and connects it to the IPv4 loopback address on `port`
// (given in host byte order).
//
// Socket creation and connect are CHECKed, so a failure aborts the test
// process. Returns the connected socket; the caller is responsible for
// closing it.
NetworkSocket createConnectedSocket(int port) {
  auto sock = netops::socket(PF_INET, SOCK_STREAM, 0);
  CHECK_NE(sock, NetworkSocket());
  // Value-initialize so that sin_zero (and any platform-specific fields) are
  // zero instead of stack garbage when handed to connect().
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);
  // 127.0.0.1, spelled with the standard constant instead of a hand-built one.
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  CHECK_ERR(netops::connect(
      sock, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)));
  return sock;
}

// Checks what a connected client observes when the server is stopped via
// Server::stop(abortive).
//
// A helper thread stops and joins the server after a 200 ms delay while the
// calling thread blocks in recv(). An orderly stop must yield a zero-byte
// read; an abortive stop must fail with ECONNRESET. Afterwards the server is
// expected to have no client sockets left to close.
void runCloseTest(bool abortive) {
  Server server;
  auto clientSock = createConnectedSocket(server.port());

  std::thread stopThread([&server, abortive] {
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    server.stop(abortive);
    server.join();
  });

  char buf;
  int r = netops::recv(clientSock, &buf, 1, 0);
  if (!abortive) {
    EXPECT_EQ(0, r);
  } else {
    // Grab errno before anything else has a chance to overwrite it.
    int e = errno;
    EXPECT_EQ(-1, r);
    EXPECT_EQ(ECONNRESET, e);
  }

  netops::close(clientSock);
  stopThread.join();

  // closed by server when it exited
  EXPECT_EQ(0, server.closeClients(false));
}

// Server is stopped in orderly mode; the client's recv() must return 0.
TEST(ShutdownSocketSetTest, OrderlyClose) {
  runCloseTest(false);
}

// Server is stopped in abortive mode; the client's recv() must fail with
// ECONNRESET.
TEST(ShutdownSocketSetTest, AbortiveClose) {
  runCloseTest(true);
}

// Checks what a connected client observes when the server's sockets are shut
// down via Server::shutdownAll(abortive) instead of being closed.
//
// A helper thread triggers the shutdown and joins the server while the
// calling thread blocks in recv(). Because no stop() was requested, the
// server thread exits without closing its client socket, so exactly one
// socket must still be pending in closeClients() at the end.
void runKillTest(bool abortive) {
  Server server;
  auto clientSock = createConnectedSocket(server.port());

  std::thread killThread([&server, abortive] {
    server.shutdownAll(abortive);
    server.join();
  });

  char buf;
  int r = netops::recv(clientSock, &buf, 1, 0);

  // "abortive" is only a hint to ShutdownSocketSet, so an abortive kill may
  // show up either as a connection reset or as an orderly end-of-stream.
  if (abortive && r == -1) {
    EXPECT_EQ(ECONNRESET, errno);
  } else if (abortive) {
    EXPECT_EQ(r, 0);
  } else {
    EXPECT_EQ(0, r);
  }

  netops::close(clientSock);
  killThread.join();

  // NOT closed by server when it exited
  EXPECT_EQ(1, server.closeClients(false));
}

// shutdownAll(false) on the server's socket set; the client's recv() must
// return 0.
TEST(ShutdownSocketSetTest, OrderlyKill) {
  runKillTest(false);
}

// shutdownAll(true) on the server's socket set; the client's recv() may either
// fail with ECONNRESET or return 0.
TEST(ShutdownSocketSetTest, AbortiveKill) {
  runKillTest(true);
}
} // namespace test
} // namespace folly