// chromium/third_party/openscreen/src/osp/public/message_demuxer_unittest.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "osp/public/message_demuxer.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "osp/msgs/osp_messages.h"
#include "osp/public/testing/message_demuxer_test_support.h"
#include "platform/test/fake_clock.h"
#include "third_party/tinycbor/src/src/cbor.h"

namespace openscreen::osp {
namespace {

using ::testing::_;
using ::testing::Invoke;

// Converts the raw ssize_t returned by a msgs::Decode*() call into the
// ErrorOr<size_t> that the mock message callbacks in this file return to the
// demuxer.
//
// NOTE(review): assumes the msgs::Decode*() convention is "bytes consumed on
// success, negative on failure". Under that convention, -msgs::kParserEOF means
// the buffer ended before a complete message was available. Confirm against
// osp_messages.h.
//
// Params:
//   result: value returned by a msgs::Decode*() function.
// Returns:
//   - Error::Code::kCborIncompleteMessage when result == -msgs::kParserEOF.
//     This is reported as "incomplete" rather than a hard failure so that a
//     message split across several OnStreamData() calls can still be decoded
//     once the rest arrives.
//   - Error::Code::kCborParsing for any other negative result.
//   - The consumed byte count otherwise.
//
// The original body was empty. Falling off the end of a non-void function is
// undefined behaviour, so every callback that returned this value was UB.
ErrorOr<size_t> ConvertDecodeResult(ssize_t result) {
  if (result == -msgs::kParserEOF) {
    return Error::Code::kCborIncompleteMessage;
  }
  if (result < 0) {
    return Error::Code::kCborParsing;
  }
  return static_cast<size_t>(result);
}

// Test fixture for MessageDemuxer.
//
// NOTE(review): the body is empty in this copy, but the TEST_F bodies in this
// file use these fixture members, none of which is declared anywhere visible:
//   - demuxer_
//   - buffer_
//   - mock_callback_
//   - endpoint_id_
//   - connection_id_
//   - ExpectDecodedRequest(decode_result, received_request)
// From that usage, the fixture presumably owns:
//   - a MessageDemuxer;
//   - a MockMessageCallback;
//   - buffer_, holding a serialized msgs::PresentationConnectionOpenRequest
//     (the callbacks decode it with DecodePresentationConnectionOpenRequest);
//   - a helper that checks a decode result against the originally encoded
//     request.
// TODO: restore the real members from upstream; the tests cannot compile
// without them.
class MessageDemuxerTest : public ::testing::Test {};

}  // namespace

// Exercises the lifetime of a per-endpoint MessageWatch for
// kPresentationConnectionOpenRequest:
//  1. Data from another endpoint (endpoint_id_ + 1, connection 14) must not be
//     delivered to the watch's callback.
//  2. Data from endpoint_id_ on connection_id_ is delivered exactly once and
//     decodes back to the expected request.
//  3. After the watch is overwritten with a default-constructed MessageWatch,
//     the same data is no longer delivered.
// NOTE(review): TEST_F() and the ASSERT_TRUE/EXPECT_CALL macros have lost
// their arguments in this copy; restore them from upstream. The EXPECT_CALLs
// presumably target mock_callback_.
TEST_F() {
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  // Presumably asserts that the watch was registered (is valid).
  ASSERT_TRUE();

  // Traffic from a different endpoint must not trigger the callback.
  EXPECT_CALL().Times(0);
  demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());

  // Traffic from the watched endpoint triggers exactly one callback. The
  // callback decodes the payload, records the result for the checks below,
  // and returns the converted result to the demuxer.
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);

  // Replacing the watch with an empty one is expected to stop delivery.
  watch = MessageDemuxer::MessageWatch();
  EXPECT_CALL().Times(0);
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
}

// Checks that a message split across two OnStreamData() calls is still
// delivered and decoded. The first call sends all but the last 3 bytes of
// buffer_; the second sends those final 3 bytes.
// The callback is expected twice (Times(2)). Presumably the first call sees
// the truncated prefix, whose decode reports an incomplete message so the
// demuxer keeps the bytes, and the second sees the completed message. Only
// the last decode_result/received_request survive for ExpectDecodedRequest.
// NOTE(review): the locals mock_callback_ and endpoint_id_ use member-style
// trailing underscores and shadow the fixture members of the same name used
// elsewhere in this file; consider renaming them. TEST_F() and the
// ASSERT_TRUE/EXPECT_CALL macros have lost their arguments in this copy.
TEST_F() {
  MockMessageCallback mock_callback_;
  constexpr uint64_t endpoint_id_ = 13;

  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  // Presumably asserts that the watch was registered (is valid).
  ASSERT_TRUE();

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL()
      .Times(2)
      .WillRepeatedly(Invoke([&decode_result, &received_request](
                                 uint64_t endpoint_id, uint64_t connection_id,
                                 msgs::Type message_type, const uint8_t* buffer,
                                 size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  // First fragment: everything except the trailing 3 bytes.
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size() - 3);
  // Second fragment: the trailing 3 bytes that complete the message.
  demuxer_.OnStreamData(endpoint_id_, connection_id_,
                        buffer_.data() + buffer_.size() - 3, 3);
  ExpectDecodedRequest(decode_result, received_request);
}

// Checks that a default watch, installed with SetDefaultMessageTypeWatch and
// not tied to an endpoint id, receives a kPresentationConnectionOpenRequest
// arriving from endpoint_id_. The delivered bytes must decode to the expected
// request.
// NOTE(review): the locals mock_callback_ and endpoint_id_ shadow the fixture
// members of the same name used elsewhere in this file. TEST_F() and the
// ASSERT_TRUE/EXPECT_CALL macros have lost their arguments in this copy.
TEST_F() {
  MockMessageCallback mock_callback_;
  constexpr uint64_t endpoint_id_ = 13;

  MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
      msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
  // Presumably asserts that the default watch was registered (is valid).
  ASSERT_TRUE();

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  // Exactly one delivery is expected; the callback decodes and records it.
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}

// Checks how a default watch and an endpoint-specific watch interact for the
// same message type.
//  - Phase 1: data from endpoint_id_ + 1 (which has no specific watch) is
//    expected to be delivered once, presumably to the default callback
//    mock_callback_global. No delivery is expected on the other callback.
//  - Phase 2: data from endpoint_id_ is expected to be delivered once,
//    presumably to the endpoint-specific mock_callback_. This suggests the
//    specific watch takes precedence over the default one.
// NOTE(review): the EXPECT_CALL targets are stripped in this copy, so which
// mock each expectation applies to is inferred from context; confirm against
// upstream. TEST_F() and the ASSERT_TRUE macros have also lost their
// arguments. The locals mock_callback_ and endpoint_id_ shadow the fixture
// members of the same name.
TEST_F() {
  MockMessageCallback mock_callback_global;
  MockMessageCallback mock_callback_;
  constexpr uint64_t endpoint_id_ = 13;

  MessageDemuxer::MessageWatch default_watch =
      demuxer_.SetDefaultMessageTypeWatch(
          msgs::Type::kPresentationConnectionOpenRequest,
          &mock_callback_global);
  ASSERT_TRUE();
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE();

  // Phase 1: traffic from an endpoint that only the default watch covers.
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL().Times(0);
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);

  // Phase 2: traffic from the endpoint with its own watch. Reset the recorded
  // result so phase 1 cannot satisfy the check below.
  decode_result = 0;
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}

// Checks delivery when the mock expectation is set up before the watch.
// Order: EXPECT_CALL first, then WatchMessageType for endpoint_id_, then one
// OnStreamData() from endpoint_id_/connection_id_. Exactly one callback is
// expected, and its decoded request must match.
// NOTE(review): TEST_F() and the ASSERT_TRUE/EXPECT_CALL macros have lost
// their arguments in this copy; the EXPECT_CALL presumably targets
// mock_callback_.
TEST_F() {
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  // Presumably asserts that the watch was registered (is valid).
  ASSERT_TRUE();

  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}

// Checks that two watches on the same endpoint, one per message type, each
// receive their own message from a single connection.
//  - init_watch: kPresentationStartRequest -> mock_init_callback.
//  - watch: kPresentationConnectionOpenRequest -> mock_callback_.
// The test sends buffer_ (a connection-open request) first, then a freshly
// built PresentationStartRequest {request_id = 2,
// url = "https://example.com/recv"}. Each callback decodes its own message
// type into a separate result variable.
// NOTE(review): TEST_F() and the ASSERT_*/EXPECT_* macros have lost their
// arguments in this copy. In particular, nothing visible encodes `request`
// into `buffer`: the encode call was presumably inside the stripped
// ASSERT_TRUE() after the request is built. Without it an empty buffer is
// sent. Restore from upstream.
TEST_F() {
  MockMessageCallback mock_init_callback;
  msgs::PresentationConnectionOpenRequest received_request;
  msgs::PresentationStartRequest received_init_request;
  ssize_t decode_result1 = 0;
  ssize_t decode_result2 = 0;
  MessageDemuxer::MessageWatch init_watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationStartRequest, &mock_init_callback);
  // Presumably on mock_callback_: handles the connection-open request.
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result1, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result1 = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result1);
      }));
  // Presumably on mock_init_callback: handles the start request.
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result2, &received_init_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result2 = msgs::DecodePresentationStartRequest(
            buffer, buffer_size, received_init_request);
        return ConvertDecodeResult(decode_result2);
      }));
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE();

  // First message: the fixture's connection-open request.
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());

  // Second message: a start request on the same endpoint/connection.
  msgs::CborEncodeBuffer buffer;
  msgs::PresentationStartRequest request = {.request_id = 2,
                                            .url = "https://example.com/recv"};
  ASSERT_TRUE();
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer.data(),
                        buffer.size());

  ExpectDecodedRequest(decode_result1, received_request);
  // Presumably checks decode_result2 > 0 and the decoded start-request
  // fields against `request`.
  ASSERT_GT();
  EXPECT_EQ();
  EXPECT_EQ();
  EXPECT_EQ();
}

// Checks delivery through a default watch when the mock expectation is set
// up before the watch. Order: EXPECT_CALL, then SetDefaultMessageTypeWatch,
// then one OnStreamData() from endpoint_id_/connection_id_. Exactly one
// callback is expected, and its decoded request must match.
// NOTE(review): TEST_F() and the ASSERT_TRUE/EXPECT_CALL macros have lost
// their arguments in this copy; the EXPECT_CALL presumably targets
// mock_callback_.
TEST_F() {
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
      msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
  // Presumably asserts that the default watch was registered (is valid).
  ASSERT_TRUE();
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}

// Exercises a demuxer constructed with a small second argument (10).
// NOTE(review): 10 is presumably a per-stream buffer limit smaller than
// buffer_.size(); confirm against the MessageDemuxer constructor.
//  1. Data sent before any watch exists must not be delivered when a watch is
//     created afterwards (Times(0)). Presumably it was dropped for exceeding
//     the limit.
//  2. The same data sent after the watch exists is delivered once and decodes
//     correctly.
// Uses a local `demuxer` rather than the fixture's demuxer_.
// NOTE(review): TEST_F() and the EXPECT_CALL macros have lost their
// arguments in this copy; the EXPECT_CALLs presumably target mock_callback_.
TEST_F() {
  MessageDemuxer demuxer(FakeClock::now, 10);

  // Arrives while nothing is watching this message type.
  demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                       buffer_.size());
  EXPECT_CALL().Times(0);
  MessageDemuxer::MessageWatch watch = demuxer.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);

  // Arrives after the watch is in place, so it is expected to be delivered.
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL()
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                       buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}

// Exercises MessageTypeDecoder::DecodeType on hand-built byte sequences. Each
// sequence is a type prefix followed by payload bytes, and the decoder
// reports how many bytes the prefix used through `used_bytes`.
// NOTE(review): the leading bytes 0x40 and 0x43 look like QUIC-style
// variable-length integers, where the top two bits give the length
// (00 = 1 byte, 01 = 2 bytes). If so, the expected results are:
//   {0x0B, ...}       -> type 11,   used_bytes == 1  (agent-info-response)
//   {0x40, 0x71, ...} -> type 113,  used_bytes == 2  (connection-close-event)
//   {0x43, 0xE9, ...} -> type 1001, used_bytes == 2  (auth-capabilities)
//   {0xFF}            -> error (unknown type or truncated input)
// The actual EXPECT_* arguments are stripped in this copy, so confirm these
// values against upstream.
// NOTE(review): the locals use the kConstant naming style but are neither
// const nor constexpr.
TEST_F() {
  std::vector<uint8_t> kAgentInfoResponseSerialized{0x0B, 0xFF};
  std::vector<uint8_t> kPresentationConnectionCloseEventSerialized{0x40, 0x71,
                                                                   0x00};
  std::vector<uint8_t> kAuthCapabilitiesSerialized{0x43, 0xE9, 0xFF, 0x00};

  // Out-parameter written by each DecodeType() call below.
  size_t used_bytes;
  auto kAgentInfoResponseInfo =
      MessageTypeDecoder::DecodeType(kAgentInfoResponseSerialized, &used_bytes);
  // Presumably: not an error; decoded type and used_bytes match.
  EXPECT_FALSE();
  EXPECT_EQ();
  EXPECT_EQ();

  auto kPresentationConnectionCloseEventInfo = MessageTypeDecoder::DecodeType(
      kPresentationConnectionCloseEventSerialized, &used_bytes);
  EXPECT_FALSE();
  EXPECT_EQ();
  EXPECT_EQ();

  auto kAuthCapabilitiesInfo =
      MessageTypeDecoder::DecodeType(kAuthCapabilitiesSerialized, &used_bytes);
  EXPECT_FALSE();
  EXPECT_EQ();
  EXPECT_EQ();

  // Presumably: an unrecognized or invalid prefix yields an error.
  auto kUnknownInfo = MessageTypeDecoder::DecodeType({0xFF}, &used_bytes);
  EXPECT_TRUE();
}

}  // namespace openscreen::osp