// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/elevation_service/elevator.h"
#include <dpapi.h>
#include <oleauto.h>
#include <stdint.h>
#include <string>
#include <vector>
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/numerics/safe_conversions.h"
#include "base/process/process.h"
#include "base/strings/sys_string_conversions.h"
#include "base/version_info/channel.h"
#include "base/version_info/version_info.h"
#include "base/win/scoped_localalloc.h"
#include "base/win/win_util.h"
#include "build/branding_buildflags.h"
#include "chrome/elevation_service/caller_validation.h"
#include "chrome/elevation_service/elevated_recovery_impl.h"
#include "chrome/install_static/install_util.h"
#include "third_party/abseil-cpp/absl/cleanup/cleanup.h"
#if BUILDFLAG(GOOGLE_CHROME_BRANDING)
#include "chrome/elevation_service/internal/elevation_service_internal.h"
#endif // BUILDFLAG(GOOGLE_CHROME_BRANDING)
namespace elevation_service {
namespace {
// Returns a base::Process of the process making the RPC call to us, or invalid
// base::Process if could not be determined.
base::Process GetCallingProcess() {
// Validation should always be done impersonating the caller.
HANDLE calling_process_handle;
RPC_STATUS status = I_RpcOpenClientProcess(
nullptr, PROCESS_QUERY_LIMITED_INFORMATION, &calling_process_handle);
// RPC_S_NO_CALL_ACTIVE indicates that the caller is local process.
if (status == RPC_S_NO_CALL_ACTIVE)
return base::Process::Current();
if (status != RPC_S_OK)
return base::Process();
return base::Process(calling_process_handle);
}
} // namespace
HRESULT Elevator::RunRecoveryCRXElevated(const wchar_t* crx_path,
const wchar_t* browser_appid,
const wchar_t* browser_version,
const wchar_t* session_id,
DWORD caller_proc_id,
ULONG_PTR* proc_handle) {
base::win::ScopedHandle scoped_proc_handle;
HRESULT hr = RunChromeRecoveryCRX(base::FilePath(crx_path), browser_appid,
browser_version, session_id, caller_proc_id,
&scoped_proc_handle);
*proc_handle = base::win::HandleToUint32(scoped_proc_handle.Take());
return hr;
}
HRESULT Elevator::EncryptData(ProtectionLevel protection_level,
const BSTR plaintext,
BSTR* ciphertext,
DWORD* last_error) {
if (protection_level >= ProtectionLevel::PROTECTION_MAX) {
return kErrorUnsupportedProtectionLevel;
}
UINT length = ::SysStringByteLen(plaintext);
if (!length)
return E_INVALIDARG;
std::string plaintext_str(reinterpret_cast<char*>(plaintext), length);
#if BUILDFLAG(GOOGLE_CHROME_BRANDING)
auto pre_process_result = PreProcessData(plaintext_str);
if (!pre_process_result.has_value()) {
return pre_process_result.error();
}
plaintext_str.swap(*pre_process_result);
#endif // BUILDFLAG(GOOGLE_CHROME_BRANDING)
HRESULT hr = ::CoImpersonateClient();
if (FAILED(hr))
return hr;
DATA_BLOB intermediate = {};
{
absl::Cleanup revert_to_self = [] { ::CoRevertToSelf(); };
const auto calling_process = GetCallingProcess();
if (!calling_process.IsValid())
return kErrorCouldNotObtainCallingProcess;
const auto validation_data =
GenerateValidationData(protection_level, calling_process);
if (!validation_data.has_value()) {
return validation_data.error();
}
const auto data =
std::string(validation_data->cbegin(), validation_data->cend());
std::string data_to_encrypt;
AppendStringWithLength(data, data_to_encrypt);
AppendStringWithLength(plaintext_str, data_to_encrypt);
DATA_BLOB input = {};
input.cbData = base::checked_cast<DWORD>(data_to_encrypt.length());
input.pbData = const_cast<BYTE*>(
reinterpret_cast<const BYTE*>(data_to_encrypt.data()));
if (!::CryptProtectData(
&input, /*szDataDescr=*/
base::SysUTF8ToWide(version_info::GetProductName()).c_str(),
nullptr, nullptr, nullptr, /*dwFlags=*/CRYPTPROTECT_AUDIT,
&intermediate)) {
*last_error = ::GetLastError();
return kErrorCouldNotEncryptWithUserContext;
}
}
DATA_BLOB output = {};
{
base::win::ScopedLocalAlloc intermediate_freer(intermediate.pbData);
if (!::CryptProtectData(
&intermediate,
/*szDataDescr=*/
base::SysUTF8ToWide(version_info::GetProductName()).c_str(),
nullptr, nullptr, nullptr, /*dwFlags=*/CRYPTPROTECT_AUDIT,
&output)) {
*last_error = ::GetLastError();
return kErrorCouldNotEncryptWithSystemContext;
}
}
base::win::ScopedLocalAlloc output_freer(output.pbData);
*ciphertext = ::SysAllocStringByteLen(reinterpret_cast<LPCSTR>(output.pbData),
output.cbData);
if (!*ciphertext)
return E_OUTOFMEMORY;
return S_OK;
}
HRESULT Elevator::DecryptData(const BSTR ciphertext,
BSTR* plaintext,
DWORD* last_error) {
UINT length = ::SysStringByteLen(ciphertext);
if (!length)
return E_INVALIDARG;
DATA_BLOB input = {};
input.cbData = length;
input.pbData = reinterpret_cast<BYTE*>(ciphertext);
DATA_BLOB intermediate = {};
// Decrypt using the SYSTEM dpapi store.
if (!::CryptUnprotectData(&input, nullptr, nullptr, nullptr, nullptr, 0,
&intermediate)) {
*last_error = ::GetLastError();
return kErrorCouldNotDecryptWithSystemContext;
}
base::win::ScopedLocalAlloc intermediate_freer(intermediate.pbData);
HRESULT hr = ::CoImpersonateClient();
if (FAILED(hr))
return hr;
std::string plaintext_str;
{
DATA_BLOB output = {};
absl::Cleanup revert_to_self = [] { ::CoRevertToSelf(); };
// Decrypt using the user store.
if (!::CryptUnprotectData(&intermediate, nullptr, nullptr, nullptr, nullptr,
0, &output)) {
*last_error = ::GetLastError();
return kErrorCouldNotDecryptWithUserContext;
}
base::win::ScopedLocalAlloc output_freer(output.pbData);
std::string mutable_plaintext(reinterpret_cast<char*>(output.pbData),
output.cbData);
const std::string validation_data = PopFromStringFront(mutable_plaintext);
if (validation_data.empty()) {
return E_INVALIDARG;
}
const auto data =
std::vector<uint8_t>(validation_data.cbegin(), validation_data.cend());
const auto process = GetCallingProcess();
if (!process.IsValid()) {
*last_error = ::GetLastError();
return kErrorCouldNotObtainCallingProcess;
}
// Note: Validation should always be done using caller impersonation token.
std::string log_message;
HRESULT validation_result = ValidateData(process, data, &log_message);
if (FAILED(validation_result)) {
*last_error = ::GetLastError();
// Only enable extended logging on Dev channel.
if (install_static::GetChromeChannel() == version_info::Channel::DEV &&
!log_message.empty()) {
*plaintext =
::SysAllocStringByteLen(log_message.c_str(), log_message.length());
}
return validation_result;
}
plaintext_str = PopFromStringFront(mutable_plaintext);
}
#if BUILDFLAG(GOOGLE_CHROME_BRANDING)
auto post_process_result = PostProcessData(plaintext_str);
if (!post_process_result.has_value()) {
return post_process_result.error();
}
plaintext_str.swap(*post_process_result);
#endif // BUILDFLAG(GOOGLE_CHROME_BRANDING)
*plaintext =
::SysAllocStringByteLen(plaintext_str.c_str(), plaintext_str.length());
if (!*plaintext)
return E_OUTOFMEMORY;
return S_OK;
}
// static
void Elevator::AppendStringWithLength(const std::string& to_append,
std::string& base) {
uint32_t size = base::checked_cast<uint32_t>(to_append.length());
base.append(reinterpret_cast<char*>(&size), sizeof(size));
base.append(to_append);
}
// static
std::string Elevator::PopFromStringFront(std::string& str) {
uint32_t size;
if (str.length() < sizeof(size))
return std::string();
auto it = str.begin();
// Obtain the size.
memcpy(&size, str.c_str(), sizeof(size));
// Skip over the size field.
std::string value;
if (size) {
it += sizeof(size);
// Pull the string out.
value.assign(it, it + size);
DCHECK_EQ(value.length(), base::checked_cast<std::string::size_type>(size));
}
// Trim the string to the remainder.
str = str.substr(sizeof(size) + size);
return value;
}
} // namespace elevation_service