chromium/chrome/browser/ash/printing/oauth2/http_exchange.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/printing/oauth2/http_exchange.h"

#include <string>
#include <utility>

#include "base/functional/bind.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/notreached.h"
#include "base/strings/escape.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/values.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "services/network/public/cpp/resource_request.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/simple_url_loader.h"
#include "services/network/public/mojom/url_response_head.mojom.h"

namespace ash {
namespace printing {
namespace oauth2 {

namespace {

// Maps a ContentFormat to the MIME type sent as the upload's content type.
// Formats without a MIME representation hit NOTREACHED_IN_MIGRATION() and
// yield an empty string.
std::string ToString(ContentFormat format) {
  if (format == ContentFormat::kJson) {
    return "application/json";
  }
  if (format == ContentFormat::kXWwwFormUrlencoded) {
    return "application/x-www-form-urlencoded";
  }
  NOTREACHED_IN_MIGRATION();
  return "";
}

}  // namespace

// Stores `url_loader_factory`, which Exchange() later passes to the URL loader
// when starting a request. The parameter is taken by value, so it is moved
// into the member rather than copied; this avoids a redundant ref-count
// increment/decrement pair on the shared factory.
HttpExchange::HttpExchange(
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
    : url_loader_factory_(std::move(url_loader_factory)) {}

// Nothing to release by hand; the members clean up after themselves.
HttpExchange::~HttpExchange() = default;

// Drops all stored parameters, the last error message and the current URL
// loader (if any).
void HttpExchange::Clear() {
  content_.clear();
  error_msg_.clear();
  url_loader_ = nullptr;
}

// Stores the string `value` under the key `name` in `content_`, the
// dictionary that Exchange() serializes into the request payload.
// `name` must not be empty (enforced with a DCHECK).
void HttpExchange::AddParamString(const std::string& name,
                                  const std::string& value) {
  DCHECK(!name.empty());
  content_.Set(name, value);
}

void HttpExchange::AddParamArrayString(const std::string& name,
                                       const std::vector<std::string>& value) {
  DCHECK(!name.empty());
  base::Value::List list_node;
  for (const auto& value_element : value) {
    list_node.Append(value_element);
  }
  content_.Set(name, std::move(list_node));
}

// Serializes `content_` according to `request_format`, sends it to `url` with
// the HTTP method `http_method` and reports the outcome through `callback`.
//
// If the payload cannot be built, `callback` is run synchronously with
// StatusCode::kUnexpectedError and no request is sent. Otherwise the response
// is handled by OnURLLoaderCompleted(), which receives `success_http_status`
// and `error_http_status` as the only two HTTP status codes it accepts.
void HttpExchange::Exchange(
    const std::string& http_method,
    const GURL& url,
    ContentFormat request_format,
    int success_http_status,
    int error_http_status,
    const net::PartialNetworkTrafficAnnotationTag& partial_traffic_annotation,
    OnExchangeCompletedCallback callback) {
  // Upper bound on the size of the response body handed to DownloadToString().
  constexpr size_t kMaxResponseBodySize = 1024 * 1024;

  // Builds the request body from `content_`. It stays empty for any format
  // other than the two handled below.
  std::string payload;
  if (request_format == ContentFormat::kXWwwFormUrlencoded) {
    for (const auto [key, node] : content_) {
      if (!payload.empty()) {
        payload += "&";
      }
      payload += base::EscapeUrlEncodedData(key, true);
      payload += "=";
      // Only scalar values (bool, int, string) can be represented here.
      if (node.is_bool()) {
        payload += node.GetBool() ? "true" : "false";
      } else if (node.is_int()) {
        payload += base::NumberToString(node.GetInt());
      } else if (node.is_string()) {
        payload += base::EscapeUrlEncodedData(node.GetString(), true);
      } else {
        error_msg_ = "Cannot save a vector value as x-www-form-urlencoded.";
        std::move(callback).Run(StatusCode::kUnexpectedError);
        return;
      }
    }
  } else if (request_format == ContentFormat::kJson) {
    if (!base::JSONWriter::Write(content_, &payload)) {
      error_msg_ = "Cannot create JSON payload.";
      std::move(callback).Run(StatusCode::kUnexpectedError);
      return;
    }
  }

  // Describes the request: no credentials are attached and a JSON response is
  // requested.
  auto request = std::make_unique<network::ResourceRequest>();
  request->method = http_method;
  request->url = url;
  request->credentials_mode = network::mojom::CredentialsMode::kOmit;
  request->headers.SetHeader("Accept", "application/json");
  net::NetworkTrafficAnnotationTag traffic_annotation =
      net::CompleteNetworkTrafficAnnotation("printing_oauth2_http_exchange",
                                            partial_traffic_annotation, R"(
        semantics {
          sender: "Printers Manager"
          trigger: "User tries to use a printer that requires OAuth2 "
            "authorization."
          destination: OTHER
          destination_other: "Trusted Authorization Server"
        }
        policy {
          cookies_allowed: NO
          setting:
            "You can enable or disable this experimental feature via 'Enable "
            "OAuth when printing via the IPP protocol' on the chrome://flags "
            "tab."
          policy_exception_justification: "This is an experimental feature. "
            "The policy controlling this feature does not yet exist."
        })");
  url_loader_ =
      network::SimpleURLLoader::Create(std::move(request), traffic_annotation);

  // Attaches the body unless the request is meant to have none.
  if (request_format != ContentFormat::kEmpty) {
    url_loader_->AttachStringForUpload(payload, ToString(request_format));
  }
  // The body of an error response is wanted as well, because it may explain
  // why the request failed.
  url_loader_->SetAllowHttpErrorResults(true);
  // Unretained is safe here because |url_loader_| is owned by |this|.
  url_loader_->DownloadToString(
      url_loader_factory_.get(),
      base::BindOnce(&HttpExchange::OnURLLoaderCompleted,
                     base::Unretained(this), success_http_status,
                     error_http_status, std::move(callback)),
      kMaxResponseBodySize);
}

// Completion handler for the request started by Exchange(). Validates the
// response and runs `callback` exactly once with the resulting status:
//  * kConnectionError - network error or missing response headers;
//  * kServerError / kServerTemporarilyUnavailable - HTTP 500 / HTTP 503;
//  * kInvalidResponse - HTTP status other than `success_http_status` and
//    `error_http_status`, missing payload, payload that is not a JSON
//    dictionary of type application/json in a supported charset, or a
//    malformed error response;
//  * kOK - HTTP status equals `success_http_status`; `content_` holds the
//    parsed response;
//  * kInvalidAccessToken - error response with error == "invalid_grant";
//  * kAccessDenied - any other well-formed error response.
// For every status other than kOK, `error_msg_` describes the problem.
void HttpExchange::OnURLLoaderCompleted(
    int success_http_status,
    int error_http_status,
    OnExchangeCompletedCallback callback,
    std::unique_ptr<std::string> response_body) {
  // Checks for connection errors.
  const int net_error = url_loader_->NetError();
  if (!(net_error == net::OK && url_loader_->ResponseInfo() &&
        url_loader_->ResponseInfo()->headers)) {
    // Uses StrCat/NumberToString, whose headers this file includes directly,
    // rather than StringPrintf, whose header it does not include.
    error_msg_ = base::StrCat({"NetError=", base::NumberToString(net_error)});
    std::move(callback).Run(StatusCode::kConnectionError);
    return;
  }

  // Checks for special HTTP status codes. These take precedence over the
  // expected status codes passed by the caller.
  const auto http_status = GetHttpStatus();
  if (http_status == 500) {
    error_msg_ = "Internal Server Error (HTTP status code=500)";
    std::move(callback).Run(StatusCode::kServerError);
    return;
  }
  if (http_status == 503) {
    error_msg_ = "Service Unavailable (HTTP status code=503)";
    std::move(callback).Run(StatusCode::kServerTemporarilyUnavailable);
    return;
  }
  if (http_status != success_http_status && http_status != error_http_status) {
    error_msg_ = "Unexpected HTTP status: " + base::NumberToString(http_status);
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }

  // Checks if response body (payload) exists. It cannot be empty since we
  // expect a JSON object.
  if (!response_body || response_body->empty()) {
    error_msg_ = "Missing payload.";
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }

  // Checks if the payload of the response contains a JSON object.
  std::string mime_type;
  std::string charset;
  url_loader_->ResponseInfo()->headers->GetMimeTypeAndCharset(&mime_type,
                                                              &charset);
  if (mime_type != "application/json") {
    error_msg_ = "Unexpected type of payload: " + mime_type;
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }
  // A missing charset is accepted; otherwise only utf-8 and ascii are.
  if (!charset.empty() && charset != "utf-8" && charset != "ascii") {
    error_msg_ = "Unsupported charset: " + charset;
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }
  auto parsed = base::JSONReader::Read(*response_body);
  if (!parsed) {
    error_msg_ = "Cannot parse JSON payload.";
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }
  if (!parsed->is_dict()) {
    error_msg_ = "JSON payload is not a dictionary.";
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }
  // From now on `content_` holds the response instead of the request
  // parameters, so the Param*() accessors read the server's reply.
  content_ = std::move(parsed).value().TakeDict();

  // Exits if success.
  if (http_status == success_http_status) {
    std::move(callback).Run(StatusCode::kOK);
    return;
  }

  // Parses error response. Only "error" is mandatory; on failure the
  // accessor has already stored a description in `error_msg_`.
  std::string error;
  std::string error_description;
  std::string error_uri;
  if (!ParamStringGet("error", true, &error) ||
      !ParamStringGet("error_description", false, &error_description) ||
      !ParamStringGet("error_uri", false, &error_uri)) {
    std::move(callback).Run(StatusCode::kInvalidResponse);
    return;
  }

  // Handles the error response.
  error_msg_ = "error=" + error;
  if (!error_description.empty()) {
    error_msg_ += base::StrCat({"; description=", error_description});
  }
  if (!error_uri.empty()) {
    error_msg_ += base::StrCat({"; uri=", error_uri});
  }
  if (error == "invalid_grant") {
    std::move(callback).Run(StatusCode::kInvalidAccessToken);
  } else {
    std::move(callback).Run(StatusCode::kAccessDenied);
  }
}

// Returns the HTTP status code of the current response, or 0 when there is no
// URL loader, no response info or no response headers.
int HttpExchange::GetHttpStatus() const {
  if (!url_loader_) {
    return 0;
  }
  const auto* response_info = url_loader_->ResponseInfo();
  if (!response_info || !response_info->headers) {
    return 0;
  }
  return response_info->headers->response_code();
}

bool HttpExchange::ParamArrayStringContains(const std::string& name,
                                            bool required,
                                            const std::string& value) {
  base::Value* node = FindNode(name, required);
  if (!node) {
    return !required;
  }
  if (!node->is_list()) {
    error_msg_ = base::StrCat({"Field ", name, " must be an array."});
    return false;
  }
  const auto& node_as_list = node->GetList();
  for (auto& element : node_as_list) {
    if (element.is_string() && element.GetString() == value) {
      // Success!
      return true;
    }
  }
  error_msg_ =
      base::StrCat({"Field ", name, " must contain the value '", value, "'"});
  return false;
}

bool HttpExchange::ParamArrayStringEquals(
    const std::string& name,
    bool required,
    const std::vector<std::string>& value) {
  base::Value* node = FindNode(name, required);
  if (!node) {
    return !required;
  }
  if (!node->is_list()) {
    error_msg_ = base::StrCat({"Field ", name, " must be an array"});
    return false;
  }
  const base::Value::List& node_as_list = node->GetList();
  if (node_as_list.size() == value.size()) {
    // Compares the vectors, element by element.
    bool are_equal = true;
    for (size_t i = 0; i < value.size() && are_equal; ++i) {
      const auto& element = node_as_list[i];
      if (!element.is_string() || element.GetString() != value[i]) {
        are_equal = false;
      }
    }
    if (are_equal) {
      // Success!
      return true;
    }
  }
  // Vectors are different, builds an error message.
  error_msg_ = base::StrCat({"Field ", name, " must contain the value ["});
  for (auto& element : value) {
    error_msg_ += base::StrCat({"'", element, "',"});
  }
  // Removes the last comma and add a closing bracket.
  if (!value.empty()) {
    error_msg_.pop_back();
  }
  error_msg_ += "]";
  return false;
}

bool HttpExchange::ParamStringGet(const std::string& name,
                                  bool required,
                                  std::string* value) {
  base::Value* node = FindNode(name, required);
  if (!node) {
    return !required;
  }
  if (!node->is_string()) {
    error_msg_ = base::StrCat({"Field ", name, " must be a string"});
    return false;
  }
  if (required && node->GetString().empty()) {
    error_msg_ = base::StrCat({"Field ", name, " cannot be empty"});
    return false;
  }
  if (value) {
    *value = node->GetString();
  }
  return true;
}

bool HttpExchange::ParamStringEquals(const std::string& name,
                                     bool required,
                                     const std::string& value) {
  base::Value* node = FindNode(name, required);
  if (!node) {
    return !required;
  }
  if (!node->is_string()) {
    error_msg_ = base::StrCat({"Field ", name, " must be a string"});
    return false;
  }
  if (value != node->GetString()) {
    error_msg_ = base::StrCat({"Field ", name, " must be equal '", value, "'"});
    return false;
  }
  return true;
}

bool HttpExchange::ParamURLGet(const std::string& name,
                               bool required,
                               GURL* value) {
  base::Value* node = FindNode(name, required);
  if (!node) {
    return !required;
  }
  if (!node->is_string()) {
    error_msg_ = base::StrCat({"Field ", name, " must be an URL"});
    return false;
  }
  GURL gurl(node->GetString());
  if (gurl.is_valid() && gurl.IsStandard() && gurl.scheme() == "https") {
    // Success!
    if (value) {
      *value = gurl;
    }
    return true;
  }
  error_msg_ =
      base::StrCat({"Field ", name, " must be a valid URL of type 'https://'"});
  return false;
}

bool HttpExchange::ParamURLEquals(const std::string& name,
                                  bool required,
                                  const GURL& value) {
  base::Value* node = FindNode(name, required);
  if (!node) {
    return !required;
  }
  if (!node->is_string()) {
    error_msg_ = base::StrCat({"Field ", name, " must be an URL"});
    return false;
  }
  if (value != GURL(node->GetString())) {
    error_msg_ =
        base::StrCat({"Field ", name, " must be equal '", value.spec(), "'"});
    return false;
  }
  return true;
}

// Returns a reference to `error_msg_`, the message set by the failing paths of
// Exchange(), OnURLLoaderCompleted() and the Param*() accessors, and emptied
// by Clear().
const std::string& HttpExchange::GetErrorMessage() const {
  return error_msg_;
}

// Looks up the field `name` in `content_`. Returns nullptr when the field is
// absent; in that case `error_msg_` is set only if the field was `required`.
base::Value* HttpExchange::FindNode(const std::string& name, bool required) {
  base::Value* const found = content_.Find(name);
  if (found == nullptr && required) {
    error_msg_ = base::StrCat({"Field ", name, " is missing"});
  }
  return found;
}

}  // namespace oauth2
}  // namespace printing
}  // namespace ash