chromium/chromecast/cast_core/grpc/grpc_server_streaming_test.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/files/scoped_temp_dir.h"
#include "base/rand_util.h"
#include "base/strings/stringprintf.h"
#include "base/synchronization/waitable_event.h"
#include "base/task/bind_post_task.h"
#include "base/task/thread_pool.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "base/threading/thread_restrictions.h"
#include "base/time/time.h"
#include "base/uuid.h"
#include "chromecast/cast_core/grpc/grpc_server.h"
#include "chromecast/cast_core/grpc/status_matchers.h"
#include "chromecast/cast_core/grpc/test_service.castcore.pb.h"
#include "chromecast/cast_core/grpc/test_utils.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace cast {
namespace utils {
namespace {

using ::cast::test::StatusIs;
using ::testing::IsNull;

const auto kEventTimeout = base::Seconds(1);
const auto kServerStopTimeout = base::Seconds(1);

class GrpcServerStreamingTest : public ::testing::Test {
 protected:
  // Creates a fresh temporary directory and derives a per-test Unix domain
  // socket endpoint of the form "unix:<temp dir>/cast-uds-<suffix>" inside it.
  GrpcServerStreamingTest() {
    CHECK(temp_dir_.CreateUniqueTempDir());
    // substr(24) keeps only the last 12 hex digits of the random UUID,
    // presumably to keep the socket path short (Unix socket paths have a
    // length limit) while still being unique enough for a test.
    const std::string socket_name =
        "cast-uds-" +
        base::Uuid::GenerateRandomV4().AsLowercaseString().substr(24);
    endpoint_ = "unix:" + temp_dir_.GetPath().AppendASCII(socket_name).value();
  }

  // Mock-time task environment shared by all tests in this fixture.
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  // Owns the directory that holds the socket file.
  base::ScopedTempDir temp_dir_;
  // gRPC endpoint both the server and the stubs connect to.
  std::string endpoint_;
};

TEST_F(GrpcServerStreamingTest, ServerStreamingCallSucceeds) {
  // The server streams a random number N of responses "test_bar1".."test_barN"
  // and then completes the call with OK. The client checks each response in
  // order and that exactly N of them arrived before the call was done.
  const int kMaxResponseCount = base::RandInt(10, 300);
  int server_response_count = 0;

  // Writes the next sequentially numbered response through |reactor|.
  auto write_next_response =
      [&](ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
        TestResponse response;
        response.set_bar(
            base::StringPrintf("test_bar%d", ++server_response_count));
        reactor->Write(std::move(response));
      };

  auto writes_available_callback = base::BindLambdaForTesting(
      [&](grpc::Status status,
          ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
        CU_CHECK_OK(status);
        if (server_response_count >= kMaxResponseCount) {
          // All responses sent: finish the stream.
          LOG(INFO) << "Writing finished";
          reactor->Write(grpc::Status::OK);
          return;
        }
        write_next_response(reactor);
      });
  auto call_handler = base::BindLambdaForTesting(
      [&](TestRequest request,
          ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
        EXPECT_EQ(request.foo(), "test_foo");

        // Every response after the first is produced from the
        // writes-available callback.
        reactor->SetWritesAvailableCallback(
            std::move(writes_available_callback));
        write_next_response(reactor);
      });

  GrpcServer server;
  server.SetHandler<ServerStreamingServiceHandler::StreamingCall>(
      std::move(call_handler));
  ASSERT_THAT(server.Start(endpoint_), StatusIs(grpc::StatusCode::OK));

  ServerStreamingServiceStub stub(endpoint_);
  auto call = stub.CreateCall<ServerStreamingServiceStub::StreamingCall>();
  call.request().set_foo("test_foo");
  int client_response_count = 0;
  base::WaitableEvent stream_finished_event;
  std::move(call).InvokeAsync(base::BindLambdaForTesting(
      [&](GrpcStatusOr<TestResponse> response, bool done) {
        CU_CHECK_OK(response);
        if (!done) {
          EXPECT_EQ(response->bar(),
                    base::StringPrintf("test_bar%d", ++client_response_count));
          return;
        }
        stream_finished_event.Signal();
      }));
  ASSERT_TRUE(stream_finished_event.TimedWait(kEventTimeout));
  ASSERT_EQ(client_response_count, kMaxResponseCount);

  test::StopGrpcServer(server, kServerStopTimeout);
}

TEST_F(GrpcServerStreamingTest, ServerStreamingCallFailsRightAway) {
  // The handler completes the call with NOT_FOUND without writing any
  // response, so the client should get a single, final error callback.
  GrpcServer server;
  server.SetHandler<ServerStreamingServiceHandler::StreamingCall>(
      base::BindLambdaForTesting(
          [&](TestRequest request,
              ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
            EXPECT_EQ(request.foo(), "test_foo");
            reactor->Write(
                grpc::Status(grpc::StatusCode::NOT_FOUND, "not found"));
          }));
  ASSERT_THAT(server.Start(endpoint_), StatusIs(grpc::StatusCode::OK));

  ServerStreamingServiceStub stub(endpoint_);
  auto call = stub.CreateCall<ServerStreamingServiceStub::StreamingCall>();
  call.request().set_foo("test_foo");
  base::WaitableEvent response_received_event;
  std::move(call).InvokeAsync(base::BindLambdaForTesting(
      [&](GrpcStatusOr<TestResponse> response, bool done) {
        // Non-fatal checks: a mismatch is reported as a normal test failure
        // (instead of a CHECK crash of the whole binary), and the event is
        // always signaled so the main thread does not sit out the timeout.
        EXPECT_TRUE(done);
        EXPECT_THAT(response.status(),
                    StatusIs(grpc::StatusCode::NOT_FOUND, "not found"));
        response_received_event.Signal();
      }));
  ASSERT_TRUE(response_received_event.TimedWait(kEventTimeout));

  test::StopGrpcServer(server, kServerStopTimeout);
}

TEST_F(GrpcServerStreamingTest, ServerStreamingCallCancelledIfServerIsStopped) {
  // The handler immediately completes every call with OK.
  GrpcServer server;
  server.SetHandler<ServerStreamingServiceHandler::StreamingCall>(
      base::BindLambdaForTesting(
          [&](TestRequest request,
              ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
            reactor->Write(grpc::Status::OK);
          }));
  ASSERT_THAT(server.Start(endpoint_), StatusIs(grpc::StatusCode::OK));

  // Phase 1: while the server is running, a call completes successfully.
  ServerStreamingServiceStub stub(endpoint_);
  auto call_while_running =
      stub.CreateCall<ServerStreamingServiceStub::StreamingCall>();
  call_while_running.request().set_foo("test_foo");
  // Auto-reset so the same event can be awaited once per phase.
  base::WaitableEvent call_completed_event{
      base::WaitableEvent::ResetPolicy::AUTOMATIC};
  std::move(call_while_running)
      .InvokeAsync(base::BindLambdaForTesting(
          [&](GrpcStatusOr<TestResponse> response, bool done) {
            CU_ASSERT_OK(response);
            call_completed_event.Signal();
          }));
  ASSERT_TRUE(call_completed_event.TimedWait(kEventTimeout));

  // Phase 2: once the server is stopped, a new call on the same stub fails
  // with UNAVAILABLE.
  test::StopGrpcServer(server, kServerStopTimeout);

  auto call_after_stop =
      stub.CreateCall<ServerStreamingServiceStub::StreamingCall>();
  std::move(call_after_stop)
      .InvokeAsync(base::BindLambdaForTesting(
          [&](GrpcStatusOr<TestResponse> response, bool done) {
            ASSERT_THAT(response, StatusIs(grpc::StatusCode::UNAVAILABLE));
            call_completed_event.Signal();
          }));
  ASSERT_TRUE(call_completed_event.TimedWait(kEventTimeout));
}

TEST_F(GrpcServerStreamingTest, ServerStreamingCallIsCancelledByClient) {
  // Server side: write one response, then wait until the stream is aborted
  // by the client's cancellation.
  base::WaitableEvent server_aborted_event;
  auto on_writes_available = base::BindLambdaForTesting(
      [&](grpc::Status status,
          ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
        // This callback can be invoked with an ABORTED error at any point in
        // time, so invocations with an OK status are ignored.
        if (!status.ok()) {
          ASSERT_THAT(status, StatusIs(grpc::StatusCode::ABORTED));
          // No reactor is handed back once the stream is aborted.
          ASSERT_THAT(reactor, IsNull());
          server_aborted_event.Signal();
        }
      });
  auto on_call = base::BindLambdaForTesting(
      [&](TestRequest request,
          ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
        EXPECT_EQ(request.foo(), "test_foo");
        reactor->SetWritesAvailableCallback(std::move(on_writes_available));
        TestResponse first_response;
        first_response.set_bar("test_bar");
        reactor->Write(std::move(first_response));
      });

  GrpcServer server;
  server.SetHandler<ServerStreamingServiceHandler::StreamingCall>(
      std::move(on_call));
  ASSERT_THAT(server.Start(endpoint_), StatusIs(grpc::StatusCode::OK));

  // Client side: exactly one successful response is expected, followed by a
  // single CANCELLED callback after the call is cancelled.
  size_t client_callback_count = 0;
  base::WaitableEvent client_event{
      base::WaitableEvent::ResetPolicy::AUTOMATIC};
  ServerStreamingServiceStub stub(endpoint_);
  auto call = stub.CreateCall<ServerStreamingServiceStub::StreamingCall>();
  call.request().set_foo("test_foo");
  auto context = std::move(call).InvokeAsync(base::BindLambdaForTesting(
      [&](GrpcStatusOr<TestResponse> response, bool done) {
        ++client_callback_count;
        if (client_callback_count != 1) {
          EXPECT_EQ(client_callback_count, 2u);
          ASSERT_THAT(response, StatusIs(grpc::StatusCode::CANCELLED));
          client_event.Signal();
          return;
        }
        CU_CHECK_OK(response);
        EXPECT_EQ(response->bar(), "test_bar");
        client_event.Signal();
      }));
  ASSERT_TRUE(client_event.TimedWait(kEventTimeout));

  // Cancel from the client, then wait until both the server and the client
  // have observed the cancellation.
  context.Cancel();
  ASSERT_TRUE(server_aborted_event.TimedWait(kEventTimeout));
  ASSERT_TRUE(client_event.TimedWait(kEventTimeout));
  task_environment_.RunUntilIdle();

  test::StopGrpcServer(server, kServerStopTimeout);
  task_environment_.RunUntilIdle();
}

TEST_F(GrpcServerStreamingTest,
       ServerStreamingCallIsCancelledByClientInActiveReactor) {
  GrpcServer server;

  // The handler only stashes the reactor and writes nothing, so the call is
  // still active on the server when the client cancels it.
  base::WaitableEvent server_request_received;
  // Set by the handler before |server_request_received| is signaled and only
  // read on this thread after waiting on that event. Initialized so a missing
  // handler invocation is caught by the check below instead of dereferencing
  // an indeterminate pointer.
  ServerStreamingServiceHandler::StreamingCall::Reactor* cancelled_reactor =
      nullptr;
  server.SetHandler<ServerStreamingServiceHandler::StreamingCall>(
      base::BindLambdaForTesting(
          [&](TestRequest request,
              ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
            cancelled_reactor = reactor;
            server_request_received.Signal();
          }));
  ASSERT_THAT(server.Start(endpoint_), StatusIs(grpc::StatusCode::OK));

  base::WaitableEvent response_received_event;
  ServerStreamingServiceStub stub(endpoint_);
  auto call = stub.CreateCall<ServerStreamingServiceStub::StreamingCall>();
  auto context = std::move(call).InvokeAsync(base::BindLambdaForTesting(
      [&](GrpcStatusOr<TestResponse> response, bool done) {
        ASSERT_THAT(response, StatusIs(grpc::StatusCode::CANCELLED));
        response_received_event.Signal();
      }));

  // Wait for server to receive the request.
  ASSERT_TRUE(server_request_received.TimedWait(kEventTimeout));
  ASSERT_NE(cancelled_reactor, nullptr);

  // Cancel the client call and wait for the client to get the notification.
  context.Cancel();
  ASSERT_TRUE(response_received_event.TimedWait(kEventTimeout));

  // Installing a writes-available callback on the cancelled reactor is
  // expected to report ABORTED with a null reactor.
  base::WaitableEvent reactor_aborted;
  cancelled_reactor->SetWritesAvailableCallback(base::BindLambdaForTesting(
      [&](grpc::Status status,
          ServerStreamingServiceHandler::StreamingCall::Reactor* reactor) {
        ASSERT_THAT(reactor, IsNull());
        ASSERT_THAT(status, StatusIs(grpc::StatusCode::ABORTED));
        reactor_aborted.Signal();
      }));
  ASSERT_TRUE(reactor_aborted.TimedWait(kEventTimeout));

  // Presumably checks that writing to an already-aborted reactor is harmless.
  cancelled_reactor->Write(TestResponse());
  test::StopGrpcServer(server, kServerStopTimeout);
}

}  // namespace
}  // namespace utils
}  // namespace cast