chromium/third_party/blink/renderer/platform/peerconnection/h265_parameter_sets_tracker.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/351564777): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "third_party/blink/renderer/platform/peerconnection/h265_parameter_sets_tracker.h"

#include <cstring>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/logging.h"
#include "third_party/webrtc/common_video/h265/h265_common.h"
#include "third_party/webrtc/common_video/h265/h265_pps_parser.h"
#include "third_party/webrtc/common_video/h265/h265_sps_parser.h"
#include "third_party/webrtc/common_video/h265/h265_vps_parser.h"

namespace blink {

namespace {

// Largest VPS/SPS/PPS NALU, in bytes and counting its start code, that the
// tracker is willing to cache. A parameter set that is empty or bigger than
// this makes MaybeFixBitstream() request a keyframe instead.
constexpr size_t kMaxParameterSetSizeBytes{1024};

}  // namespace

// Constructor and destructor are defaulted but defined out of line in this
// .cc rather than in the header. NOTE(review): presumably this keeps member
// types (parser, parameter-set maps) from having to be complete wherever the
// header is included; confirm against the header.
H265ParameterSetsTracker::H265ParameterSetsTracker() = default;
H265ParameterSetsTracker::~H265ParameterSetsTracker() = default;

// Special members of PpsData, the cached form of a PPS. MaybeFixBitstream()
// fills in `size` (NALU byte count, start code included), `sps_id` (the SPS
// the PPS references) and `payload` (a copy of those bytes). Only default
// construction, move construction/assignment and destruction are defined
// here; leaving them defaulted yields member-wise moves of whatever fields
// the header declares.
H265ParameterSetsTracker::PpsData::PpsData() = default;
H265ParameterSetsTracker::PpsData::PpsData(PpsData&& rhs) = default;
H265ParameterSetsTracker::PpsData& H265ParameterSetsTracker::PpsData::operator=(
    PpsData&& rhs) = default;
H265ParameterSetsTracker::PpsData::~PpsData() = default;

// Special members of SpsData, the cached form of an SPS. MaybeFixBitstream()
// fills in `size` (NALU byte count, start code included), `vps_id` (the VPS
// the SPS references) and `payload` (a copy of those bytes). Kept defaulted
// so moves stay member-wise for the fields declared in the header.
H265ParameterSetsTracker::SpsData::SpsData() = default;
H265ParameterSetsTracker::SpsData::SpsData(SpsData&& rhs) = default;
H265ParameterSetsTracker::SpsData& H265ParameterSetsTracker::SpsData::operator=(
    SpsData&& rhs) = default;
H265ParameterSetsTracker::SpsData::~SpsData() = default;

// Special members of VpsData, the cached form of a VPS. MaybeFixBitstream()
// fills in `size` (NALU byte count, start code included) and `payload` (a
// copy of those bytes). Kept defaulted so moves stay member-wise for the
// fields declared in the header.
H265ParameterSetsTracker::VpsData::VpsData() = default;
H265ParameterSetsTracker::VpsData::VpsData(VpsData&& rhs) = default;
H265ParameterSetsTracker::VpsData& H265ParameterSetsTracker::VpsData::operator=(
    VpsData&& rhs) = default;
H265ParameterSetsTracker::VpsData::~VpsData() = default;

H265ParameterSetsTracker::FixedBitstream
H265ParameterSetsTracker::MaybeFixBitstream(
    rtc::ArrayView<const uint8_t> bitstream) {
  if (!bitstream.size()) {
    return {PacketAction::kRequestKeyframe};
  }

  bool has_irap_nalu = false;
  bool prepend_vps = true, prepend_sps = true, prepend_pps = true;

  // Required size of fixed bitstream.
  size_t required_size = 0;
  H265ParameterSetsTracker::FixedBitstream fixed;
  fixed.action = PacketAction::kPassThrough;

  auto vps_data = vps_data_.end();
  auto sps_data = sps_data_.end();
  auto pps_data = pps_data_.end();
  std::optional<uint32_t> pps_id;
  uint32_t sps_id = 0, vps_id = 0;
  uint32_t slice_sps_id = 0, slice_pps_id = 0;

  parser_.ParseBitstream(
      rtc::ArrayView<const uint8_t>(bitstream.data(), bitstream.size()));

  std::vector<webrtc::H265::NaluIndex> nalu_indices =
      webrtc::H265::FindNaluIndices(bitstream.data(), bitstream.size());
  for (const auto& nalu_index : nalu_indices) {
    if (nalu_index.payload_size < 2) {
      // H.265 NALU header is at least 2 bytes.
      return {PacketAction::kRequestKeyframe};
    }
    const uint8_t* payload_start =
        bitstream.data() + nalu_index.payload_start_offset;
    const uint8_t* nalu_start = bitstream.data() + nalu_index.start_offset;
    size_t nalu_size = nalu_index.payload_size +
                       nalu_index.payload_start_offset -
                       nalu_index.start_offset;
    uint8_t nalu_type = webrtc::H265::ParseNaluType(payload_start[0]);

    std::optional<webrtc::H265VpsParser::VpsState> vps;
    std::optional<webrtc::H265SpsParser::SpsState> sps;

    switch (nalu_type) {
      case webrtc::H265::NaluType::kVps:
        // H.265 parameter set parsers expect NALU header already stripped.
        vps = webrtc::H265VpsParser::ParseVps(payload_start + 2,
                                              nalu_index.payload_size - 2);
        // Always replace VPS with the same ID. Same for other parameter sets.
        if (vps) {
          std::unique_ptr<VpsData> current_vps_data =
              std::make_unique<VpsData>();
          // Copy with start code included. Same for other parameter sets.
          if (!current_vps_data.get() || !nalu_size ||
              nalu_size > kMaxParameterSetSizeBytes) {
            return {PacketAction::kRequestKeyframe};
          }
          current_vps_data->size = nalu_size;
          uint8_t* vps_payload = new uint8_t[current_vps_data->size];
          memcpy(vps_payload, nalu_start, current_vps_data->size);
          current_vps_data->payload.reset(vps_payload);
          vps_data_.Set(vps->id, std::move(current_vps_data));
        }
        prepend_vps = false;
        break;
      case webrtc::H265::NaluType::kSps:
        sps = webrtc::H265SpsParser::ParseSps(payload_start + 2,
                                              nalu_index.payload_size - 2);
        if (sps) {
          std::unique_ptr<SpsData> current_sps_data =
              std::make_unique<SpsData>();
          if (!current_sps_data.get() || !nalu_size ||
              nalu_size > kMaxParameterSetSizeBytes) {
            return {PacketAction::kRequestKeyframe};
          }
          current_sps_data->size = nalu_size;
          current_sps_data->vps_id = sps->vps_id;
          uint8_t* sps_payload = new uint8_t[current_sps_data->size];
          memcpy(sps_payload, nalu_start, current_sps_data->size);
          current_sps_data->payload.reset(sps_payload);
          sps_data_.Set(sps->sps_id, std::move(current_sps_data));
        }
        prepend_sps = false;
        break;
      case webrtc::H265::NaluType::kPps:
        if (webrtc::H265PpsParser::ParsePpsIds(payload_start + 2,
                                               nalu_index.payload_size - 2,
                                               &slice_pps_id, &slice_sps_id)) {
          auto current_sps_data = sps_data_.find(slice_sps_id);
          if (current_sps_data == sps_data_.end()) {
            DLOG(WARNING) << "No SPS associated with current parsed PPS found.";
            fixed.action = PacketAction::kRequestKeyframe;
          } else {
            std::unique_ptr<PpsData> current_pps_data =
                std::make_unique<PpsData>();
            if (!current_pps_data.get() || !nalu_size ||
                nalu_size > kMaxParameterSetSizeBytes) {
              return {PacketAction::kRequestKeyframe};
            }
            current_pps_data->size = nalu_size;
            current_pps_data->sps_id = slice_sps_id;
            uint8_t* pps_payload = new uint8_t[current_pps_data->size];
            memcpy(pps_payload, nalu_start, current_pps_data->size);
            current_pps_data->payload.reset(pps_payload);
            pps_data_.Set(slice_pps_id, std::move(current_pps_data));
          }
          prepend_pps = false;
        }
        break;
      case webrtc::H265::NaluType::kBlaWLp:
      case webrtc::H265::NaluType::kBlaWRadl:
      case webrtc::H265::NaluType::kBlaNLp:
      case webrtc::H265::NaluType::kIdrWRadl:
      case webrtc::H265::NaluType::kIdrNLp:
      case webrtc::H265::NaluType::kCra:
        has_irap_nalu = true;
        pps_id = parser_.GetLastSlicePpsId();
        if (!pps_id) {
          DLOG(WARNING) << "Failed to parse PPS id from current slice.";
          fixed.action = PacketAction::kRequestKeyframe;
          break;
        }
        pps_data = pps_data_.find(pps_id.value());
        if (pps_data == pps_data_.end()) {
          DLOG(WARNING) << "PPS associated with current slice is not found.";
          fixed.action = PacketAction::kRequestKeyframe;
          break;
        }

        sps_id = (pps_data->value)->sps_id;
        sps_data = sps_data_.find(sps_id);
        if (sps_data == sps_data_.end()) {
          DLOG(WARNING) << "SPS associated with current slice is not found.";
          fixed.action = PacketAction::kRequestKeyframe;
          break;
        }

        vps_id = (sps_data->value)->vps_id;
        vps_data = vps_data_.find(vps_id);
        if (vps_data == vps_data_.end()) {
          DLOG(WARNING) << "VPS associated with current slice is not found.";
          fixed.action = PacketAction::kRequestKeyframe;
          break;
        }

        if (!prepend_vps && !prepend_sps && !prepend_pps) {
          fixed.action = PacketAction::kPassThrough;
        } else {
          required_size += vps_data->value->size + sps_data->value->size +
                           pps_data->value->size;

          required_size += bitstream.size();
          size_t offset = 0;

          fixed.bitstream = webrtc::EncodedImageBuffer::Create(required_size);
          memcpy(fixed.bitstream->data(), vps_data->value->payload.get(),
                 vps_data->value->size);
          offset += vps_data->value->size;
          memcpy(fixed.bitstream->data() + offset,
                 sps_data->value->payload.get(), sps_data->value->size);
          offset += sps_data->value->size;
          memcpy(fixed.bitstream->data() + offset,
                 pps_data->value->payload.get(), pps_data->value->size);
          offset += pps_data->value->size;
          memcpy(fixed.bitstream->data() + offset, bitstream.data(),
                 bitstream.size());

          fixed.action = PacketAction::kInsert;
        }
        break;
      default:
        break;
    }

    if (fixed.action == PacketAction::kRequestKeyframe) {
      return {PacketAction::kRequestKeyframe};
    } else if (fixed.action == PacketAction::kInsert) {
      return fixed;
    }

    if (has_irap_nalu) {
      break;
    }
  }

  fixed.action = PacketAction::kPassThrough;

  return fixed;
}

}  // namespace blink