chromium/media/cdm/win/test/media_foundation_clear_key_cdm.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/cdm/win/test/media_foundation_clear_key_cdm.h"

#include <mfapi.h>
#include <mferror.h>
#include <windows.media.protection.playready.h>
#include <wrl.h>
#include <wrl/client.h>
#include <wrl/implements.h>

#include "base/notreached.h"
#include "media/base/win/mf_feature_checks.h"
#include "media/base/win/mf_helpers.h"
#include "media/cdm/clear_key_cdm_common.h"
#include "media/cdm/win/test/media_foundation_clear_key_guids.h"
#include "media/cdm/win/test/media_foundation_clear_key_session.h"
#include "media/cdm/win/test/media_foundation_clear_key_trusted_input.h"
#include "media/cdm/win/test/mock_media_protection_pmp_server.h"

namespace media {

using Microsoft::WRL::ComPtr;
using Microsoft::WRL::MakeAndInitialize;

namespace {

// Stores |inspectable| in |property_set| under the key |name|.
//
// The property set is first queried for its IMap<HSTRING, IInspectable*>
// interface, and the value is then added through IMap::Insert(). Returns S_OK
// on success, otherwise the failing HRESULT of QueryInterface() or Insert().
static HRESULT AddPropertyToSet(
    _Inout_ ABI::Windows::Foundation::Collections::IPropertySet* property_set,
    _In_ LPCWSTR name,
    _In_ IInspectable* inspectable) {
  using InspectableMap =
      ABI::Windows::Foundation::Collections::IMap<HSTRING, IInspectable*>;

  ComPtr<InspectableMap> inspectable_map;
  RETURN_IF_FAILED(
      property_set->QueryInterface(IID_PPV_ARGS(&inspectable_map)));

  // Insert() needs somewhere to write its "replaced" flag; the value is not
  // consulted afterwards.
  Microsoft::WRL::Wrappers::HStringReference key(name);
  boolean was_replaced = false;
  RETURN_IF_FAILED(
      inspectable_map->Insert(key.Get(), inspectable, &was_replaced));

  return S_OK;
}

// Wraps |string| in a Windows.Foundation.PropertyValue and stores it in
// |property_set| under the key |name|.
//
// Returns S_OK on success, otherwise the first failing HRESULT from obtaining
// the PropertyValue activation factory, creating the value, or inserting it.
static HRESULT AddStringToPropertySet(
    _Inout_ ABI::Windows::Foundation::Collections::IPropertySet* property_set,
    _In_ LPCWSTR name,
    _In_ LPCWSTR string) {
  namespace foundation = ABI::Windows::Foundation;
  using Microsoft::WRL::Wrappers::HStringReference;

  // Obtain the statics interface that creates PropertyValue objects.
  ComPtr<foundation::IPropertyValueStatics> value_factory;
  HStringReference factory_class(RuntimeClass_Windows_Foundation_PropertyValue);
  RETURN_IF_FAILED(
      foundation::GetActivationFactory(factory_class.Get(), &value_factory));

  // Create the string value and hand it to the generic insertion helper.
  ComPtr<foundation::IPropertyValue> string_value;
  HStringReference string_ref(string);
  RETURN_IF_FAILED(
      value_factory->CreateString(string_ref.Get(), &string_value));
  RETURN_IF_FAILED(AddPropertyToSet(property_set, name, string_value.Get()));

  return S_OK;
}

// Wraps |value| in a Windows.Foundation.PropertyValue boolean and stores it in
// |property_set| under the key |name|.
//
// Returns S_OK on success, otherwise the first failing HRESULT from obtaining
// the PropertyValue activation factory, creating the value, or inserting it.
static HRESULT AddBoolToPropertySet(
    _Inout_ ABI::Windows::Foundation::Collections::IPropertySet* property_set,
    _In_ LPCWSTR name,
    _In_ BOOL value) {
  namespace foundation = ABI::Windows::Foundation;
  using Microsoft::WRL::Wrappers::HStringReference;

  // Obtain the statics interface that creates PropertyValue objects.
  ComPtr<foundation::IPropertyValueStatics> value_factory;
  HStringReference factory_class(RuntimeClass_Windows_Foundation_PropertyValue);
  RETURN_IF_FAILED(
      foundation::GetActivationFactory(factory_class.Get(), &value_factory));

  // Collapse the Win32 BOOL (any non-zero value counts as true) to a strict
  // true/false before creating the value.
  const bool as_bool = (value != FALSE);
  ComPtr<foundation::IPropertyValue> bool_value;
  RETURN_IF_FAILED(value_factory->CreateBoolean(as_bool, &bool_value));
  RETURN_IF_FAILED(AddPropertyToSet(property_set, name, bool_value.Get()));

  return S_OK;
}

}  // namespace

// Constructs the CDM object. The constructor itself only emits a verbose log
// entry; the PMP server setup visible in this file lives in
// RuntimeClassInitialize().
MediaFoundationClearKeyCdm::MediaFoundationClearKeyCdm() {
  DVLOG_FUNC(1);
}

// Destroys the CDM. Expected to run on the thread bound to |thread_checker_|
// (enforced by the DCHECK below).
MediaFoundationClearKeyCdm::~MediaFoundationClearKeyCdm() {
  DVLOG_FUNC(1);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  // Mark the object as shut down. The HRESULT is ignored; Shutdown() returns
  // MF_E_SHUTDOWN if it has already run.
  Shutdown();
}

// Initializes the CDM by building the property set for the media protection
// PMP server and creating a MockMediaProtectionPMPServer from it.
//
// |properties| is not read by this function. Returns S_OK on success, or the
// first failing HRESULT otherwise.
HRESULT MediaFoundationClearKeyCdm::RuntimeClassInitialize(
    _In_ IPropertyStore* properties) {
  DVLOG_FUNC(1);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  using PmpPropertySet = ABI::Windows::Foundation::Collections::IPropertySet;
  using PmpServerInterface =
      ABI::Windows::Media::Protection::IMediaProtectionPMPServer;

  static constexpr wchar_t kSystemIdProperty[] =
      L"Windows.Media.Protection.MediaProtectionSystemId";
  static constexpr wchar_t kUseHardwareLayerProperty[] =
      L"Windows.Media.Protection.UseHardwareProtectionLayer";

  ComPtr<PmpPropertySet> pmp_properties;
  Microsoft::WRL::Wrappers::HStringReference property_set_class(
      RuntimeClass_Windows_Foundation_Collections_PropertySet);
  RETURN_IF_FAILED(Windows::Foundation::ActivateInstance(
      property_set_class.Get(), &pmp_properties));

  // Workaround for creating an in-process PMP server: advertise the PlayReady
  // media protection system ID, because the MediaEngine will call
  // MFIsContentProtectionDeviceSupported() to decide whether the specified
  // protection system ID is supported.
  RETURN_IF_FAILED(
      AddStringToPropertySet(pmp_properties.Get(), kSystemIdProperty,
                             PLAYREADY_GUID_MEDIA_PROTECTION_SYSTEM_ID_STRING));

  // TRUE lets the system create an in-process PMP server while pretending to
  // use the hardware protection layer.
  RETURN_IF_FAILED(AddBoolToPropertySet(pmp_properties.Get(),
                                        kUseHardwareLayerProperty, TRUE));

  // The "Windows.Media.Protection.MediaProtectionSystemIdMapping" property
  // does not need to be added.

  // A custom PMP server is used so that the MediaEngine can create an
  // in-process PMP server regardless of the system's hardware decryption
  // capability.
  RETURN_IF_FAILED(
      (MakeAndInitialize<MockMediaProtectionPMPServer, PmpServerInterface>(
          &media_protection_pmp_server_, pmp_properties.Get())));

  return S_OK;
}

// IMFContentDecryptionModule
// Completes the content-enabler request right away by invoking |result|.
//
// Returns E_INVALIDARG when |content_enabler| or |result| is null, the failing
// HRESULT from the shutdown check or MFInvokeCallback(), and S_OK otherwise.
STDMETHODIMP MediaFoundationClearKeyCdm::SetContentEnabler(
    _In_ IMFContentEnabler* content_enabler,
    _In_ IMFAsyncResult* result) {
  DVLOG_FUNC(1);

  // No DCHECK_CALLED_ON_VALID_THREAD(thread_checker_) here: this method can be
  // called from a different MF thread.

  RETURN_IF_FAILED(GetShutdownStatus());

  const bool has_all_arguments =
      content_enabler != nullptr && result != nullptr;
  if (!has_all_arguments) {
    return E_INVALIDARG;
  }

  // The callback is invoked immediately; whether the key ID actually exists is
  // determined later in the decryptor's ProcessOutput().
  RETURN_IF_FAILED(MFInvokeCallback(result));

  return S_OK;
}

// Not supported by this CDM: reports the call through NOTIMPLEMENTED() and
// returns E_NOTIMPL. |notify| is never written.
STDMETHODIMP MediaFoundationClearKeyCdm::GetSuspendNotify(
    _COM_Outptr_ IMFCdmSuspendNotify** notify) {
  DVLOG_FUNC(3);

  // API not used.
  NOTIMPLEMENTED();
  return E_NOTIMPL;
}

// Not supported by this CDM: reports the call through NOTIMPLEMENTED() and
// returns E_NOTIMPL. |host| is ignored.
STDMETHODIMP MediaFoundationClearKeyCdm::SetPMPHostApp(IMFPMPHostApp* host) {
  DVLOG_FUNC(3);

  // API not used.
  NOTIMPLEMENTED();
  return E_NOTIMPL;
}

// Creates a MediaFoundationClearKeySession of |session_type| and returns it
// through |session|.
//
// The new session is given the shared AesDecryptor plus two weakly-bound
// callbacks (OnSessionIdCreated / OnSessionIdRemoved) so that this CDM can keep
// its session-ID map up to date.
//
// Returns E_INVALIDARG if |session| is null, the failing HRESULT from the
// shutdown check or from session creation, and S_OK otherwise. On every
// failure path |*session| is null, as required by the _COM_Outptr_ annotation.
STDMETHODIMP MediaFoundationClearKeyCdm::CreateSession(
    _In_ MF_MEDIAKEYSESSION_TYPE session_type,
    _In_ IMFContentDecryptionModuleSessionCallbacks* callbacks,
    _COM_Outptr_ IMFContentDecryptionModuleSession** session) {
  DVLOG_FUNC(1);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  if (!session) {
    return E_INVALIDARG;
  }

  // Clear the out-parameter before any early return so the caller never
  // observes a stale or uninitialized pointer on failure.
  *session = nullptr;

  RETURN_IF_FAILED(GetShutdownStatus());

  RETURN_IF_FAILED((MakeAndInitialize<MediaFoundationClearKeySession,
                                      IMFContentDecryptionModuleSession>(
      session, session_type, callbacks, GetAesDecryptor(),
      base::BindOnce(&MediaFoundationClearKeyCdm::OnSessionIdCreated,
                     weak_factory_.GetWeakPtr()),
      base::BindOnce(&MediaFoundationClearKeyCdm::OnSessionIdRemoved,
                     weak_factory_.GetWeakPtr()))));

  return S_OK;
}

// Not supported by this CDM: reports the call through NOTIMPLEMENTED() and
// returns E_NOTIMPL. Neither |server_certificate| nor
// |server_certificate_size| is read.
STDMETHODIMP MediaFoundationClearKeyCdm::SetServerCertificate(
    _In_reads_bytes_opt_(server_certificate_size)
        const BYTE* server_certificate,
    _In_ DWORD server_certificate_size) {
  DVLOG_FUNC(3);

  // API not used.
  NOTIMPLEMENTED();
  return E_NOTIMPL;
}

// Creates a MediaFoundationClearKeyTrustedInput backed by the shared
// AesDecryptor and returns it through |trusted_input|.
//
// |content_init_data| and |content_init_data_size| are not read.
//
// Returns E_INVALIDARG if |trusted_input| is null, the failing HRESULT from
// the shutdown check or from object creation, and S_OK otherwise. On every
// failure path |*trusted_input| is null, as required by the _COM_Outptr_
// annotation.
STDMETHODIMP MediaFoundationClearKeyCdm::CreateTrustedInput(
    _In_reads_bytes_(content_init_data_size) const BYTE* content_init_data,
    _In_ DWORD content_init_data_size,
    _COM_Outptr_ IMFTrustedInput** trusted_input) {
  DVLOG_FUNC(1);

  // This method can be called from a different MF thread, so the
  // DCHECK_CALLED_ON_VALID_THREAD(thread_checker_) is not checked here.

  if (!trusted_input) {
    return E_INVALIDARG;
  }

  // Clear the out-parameter before any early return so the caller never
  // observes a stale or uninitialized pointer on failure.
  *trusted_input = nullptr;

  RETURN_IF_FAILED(GetShutdownStatus());

  ComPtr<IMFTrustedInput> trusted_input_new;
  RETURN_IF_FAILED(
      (MakeAndInitialize<MediaFoundationClearKeyTrustedInput, IMFTrustedInput>(
          &trusted_input_new, GetAesDecryptor())));

  // Transfer the reference to the caller.
  *trusted_input = trusted_input_new.Detach();

  return S_OK;
}

// Reports the single protection system ID handled by this CDM.
//
// On success |*system_ids| points to a CoTaskMemAlloc()-allocated array that
// holds one GUID (the ClearKey protection system ID) and |*count| is 1.
// Returns E_OUTOFMEMORY, with |*system_ids| null and |*count| 0, if the
// allocation fails, or the failing HRESULT from the shutdown check.
STDMETHODIMP MediaFoundationClearKeyCdm::GetProtectionSystemIds(
    _Outptr_result_buffer_(*count) GUID** system_ids,
    _Out_ DWORD* count) {
  DVLOG_FUNC(1);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  RETURN_IF_FAILED(GetShutdownStatus());

  // Start from an empty result so the out-parameters are well defined if the
  // allocation below fails.
  *count = 0;
  *system_ids = nullptr;

  auto* const clear_key_id = static_cast<GUID*>(CoTaskMemAlloc(sizeof(GUID)));
  if (clear_key_id == nullptr) {
    return E_OUTOFMEMORY;
  }

  *clear_key_id = MEDIA_FOUNDATION_CLEARKEY_GUID_CLEARKEY_PROTECTION_SYSTEM_ID;

  *system_ids = clear_key_id;
  *count = 1;

  return S_OK;
}

// IMFGetService
// Serves MF_CONTENTDECRYPTIONMODULE_SERVICE requests using the PMP server
// created in RuntimeClassInitialize().
//
// - Any other |guid_service| yields MF_E_UNSUPPORTED_SERVICE.
// - A missing PMP server yields MF_INVALID_STATE_ERR.
// - A request for IMediaProtectionPMPServer returns the PMP server itself.
// - Any other |riid| is forwarded to the PMP server's IMFGetService with
//   MF_PMP_SERVICE as the service GUID.
STDMETHODIMP MediaFoundationClearKeyCdm::GetService(
    __RPC__in REFGUID guid_service,
    __RPC__in REFIID riid,
    __RPC__deref_out_opt LPVOID* object) {
  DVLOG_FUNC(1);

  // No DCHECK_CALLED_ON_VALID_THREAD(thread_checker_) here: this method can be
  // called from a different MF thread.

  RETURN_IF_FAILED(GetShutdownStatus());

  if (guid_service != MF_CONTENTDECRYPTIONMODULE_SERVICE) {
    return MF_E_UNSUPPORTED_SERVICE;
  }

  if (!media_protection_pmp_server_) {
    return MF_INVALID_STATE_ERR;
  }

  // Direct request for the PMP server interface.
  if (riid == ABI::Windows::Media::Protection::IID_IMediaProtectionPMPServer) {
    RETURN_IF_FAILED(media_protection_pmp_server_.CopyTo(riid, object));
    return S_OK;
  }

  // Everything else is delegated to the PMP server's own service lookup.
  ComPtr<IMFGetService> pmp_service_provider;
  RETURN_IF_FAILED(media_protection_pmp_server_.As(&pmp_service_provider));
  RETURN_IF_FAILED(
      pmp_service_provider->GetService(MF_PMP_SERVICE, riid, object));

  return S_OK;
}

// IMFShutdown
// Marks the CDM as shut down.
//
// Returns S_OK on the first call and MF_E_SHUTDOWN on every later call. The
// flag is read and written while holding |lock_|.
STDMETHODIMP MediaFoundationClearKeyCdm::Shutdown() {
  DVLOG_FUNC(1);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  base::AutoLock lock(lock_);

  // Remember the previous state, then set the flag unconditionally; writing
  // true over true is a no-op.
  const bool was_already_shut_down = is_shutdown_;
  is_shutdown_ = true;

  return was_already_shut_down ? MF_E_SHUTDOWN : S_OK;
}

// IMFShutdown::GetShutdownStatus implementation.
//
// Returns:
//   E_INVALIDARG        - |status| is null.
//   MF_E_INVALIDREQUEST - Shutdown() has not been called yet; |*status| is left
//                         untouched.
//   S_OK                - Shutdown() has been called; |*status| receives
//                         MFSHUTDOWN_COMPLETED.
STDMETHODIMP MediaFoundationClearKeyCdm::GetShutdownStatus(
    MFSHUTDOWN_STATUS* status) {
  DVLOG_FUNC(1);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  if (!status) {
    return E_INVALIDARG;
  }

  // Per IMFShutdown::GetShutdownStatus spec, MF_E_INVALIDREQUEST is returned if
  // Shutdown has not been called beforehand.
  base::AutoLock lock(lock_);
  if (!is_shutdown_) {
    return MF_E_INVALIDREQUEST;
  }

  // Shutdown() only flips |is_shutdown_| under the lock and returns, so once
  // the flag is set the shutdown has fully completed. Report that to the
  // caller instead of leaving the out-parameter uninitialized.
  *status = MFSHUTDOWN_COMPLETED;
  return S_OK;
}

// Returns the AesDecryptor shared by this CDM, creating it on first use.
//
// The decryptor's session callbacks are bound to this object through weak
// pointers taken from |weak_factory_|.
scoped_refptr<AesDecryptor> MediaFoundationClearKeyCdm::GetAesDecryptor() {
  DVLOG_FUNC(1);

  // Fast path: already created by an earlier call.
  if (aes_decryptor_) {
    return aes_decryptor_;
  }

  auto session_message_cb =
      base::BindRepeating(&MediaFoundationClearKeyCdm::OnSessionMessage,
                          weak_factory_.GetWeakPtr());
  auto session_closed_cb =
      base::BindRepeating(&MediaFoundationClearKeyCdm::OnSessionClosed,
                          weak_factory_.GetWeakPtr());
  auto session_keys_change_cb =
      base::BindRepeating(&MediaFoundationClearKeyCdm::OnSessionKeysChange,
                          weak_factory_.GetWeakPtr());

  aes_decryptor_ = base::MakeRefCounted<AesDecryptor>(
      session_message_cb, session_closed_cb, session_keys_change_cb,
      base::DoNothing());  // AesDecryptor never calls this.

  return aes_decryptor_;
}

void MediaFoundationClearKeyCdm::OnSessionMessage(
    const std::string& session_id,
    CdmMessageType message_type,
    const std::vector<uint8_t>& message) {
  DVLOG_FUNC(1) << "session_id=" << session_id
                << ", message_type=" << static_cast<int>(message_type);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  auto* session = FindSession(session_id);
  CHECK(session);
  session->OnSessionMessage(session_id, message_type, message);
}

void MediaFoundationClearKeyCdm::OnSessionClosed(
    const std::string& session_id,
    CdmSessionClosedReason reason) {
  DVLOG_FUNC(1) << "session_id=" << session_id
                << ", reason=" << static_cast<int>(reason);
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  auto* session = FindSession(session_id);
  CHECK(session);
  session->OnSessionClosed(session_id, reason);
}

void MediaFoundationClearKeyCdm::OnSessionKeysChange(
    const std::string& session_id,
    bool has_additional_usable_key,
    CdmKeysInfo keys_info) {
  DVLOG_FUNC(1) << "session_id=" << session_id
                << ", has_additional_usable_key=" << has_additional_usable_key;
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  auto* session = FindSession(session_id);
  CHECK(session);
  session->OnSessionKeysChange(session_id, has_additional_usable_key,
                               std::move(keys_info));
}

// Registers |session| under |session_id| in |sessions_| so that later
// AesDecryptor callbacks can be routed to it.
//
// |session_id| must not be registered yet and |session| must be non-null; both
// conditions are CHECK-enforced.
void MediaFoundationClearKeyCdm::OnSessionIdCreated(
    const std::string& session_id,
    Microsoft::WRL::ComPtr<IMFContentDecryptionModuleSession> session) {
  DVLOG_FUNC(1) << "session_id=" << session_id;
  CHECK(FindSession(session_id) == nullptr);
  CHECK(session);

  // |session| is a by-value sink parameter: move it into the map instead of
  // copying, which avoids a redundant AddRef()/Release() pair.
  sessions_.emplace(session_id, std::move(session));
}

// Drops the entry registered under |session_id| from |sessions_|. The entry
// must exist (CHECK-enforced).
void MediaFoundationClearKeyCdm::OnSessionIdRemoved(
    const std::string& session_id) {
  DVLOG_FUNC(1) << "session_id=" << session_id;

  const auto entry = sessions_.find(session_id);
  CHECK(entry != sessions_.end());

  sessions_.erase(entry);
}

// Looks up the session registered under |session_id|.
//
// Returns a non-owning pointer to the session, or nullptr when no session is
// registered under that ID.
MediaFoundationClearKeySession* MediaFoundationClearKeyCdm::FindSession(
    const std::string& session_id) {
  DVLOG_FUNC(3) << "session_id=" << session_id;

  const auto entry = sessions_.find(session_id);
  if (entry == sessions_.end()) {
    return nullptr;
  }

  // The map stores the sessions as interface pointers; downcast the stored
  // pointer to the concrete session type.
  return static_cast<MediaFoundationClearKeySession*>(entry->second.Get());
}

}  // namespace media