// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ash/printing/oauth2/authorization_server_session.h"
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "base/containers/flat_set.h"
#include "base/functional/bind.h"
#include "base/strings/string_split.h"
#include "chrome/browser/ash/printing/oauth2/constants.h"
#include "chrome/browser/ash/printing/oauth2/http_exchange.h"
#include "chromeos/printing/uri.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "url/gurl.h"
namespace ash {
namespace printing {
namespace oauth2 {
// Converts a space-separated scope string into a set of scope tokens.
// Each token has surrounding whitespace trimmed and empty tokens are dropped,
// so an empty or all-whitespace `scope` yields an empty set.
base::flat_set<std::string> ParseScope(const std::string& scope) {
  return base::flat_set<std::string>(base::SplitString(
      scope, " ", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY));
}
// Creates a session bound to the token endpoint `token_endpoint_uri`.
// `url_loader_factory` is handed to the internal HTTP exchange object and
// `scope` becomes the initial set of scopes tracked by this session.
AuthorizationServerSession::AuthorizationServerSession(
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
    const GURL& token_endpoint_uri,
    base::flat_set<std::string>&& scope)
    : token_endpoint_uri_(token_endpoint_uri),
      // `scope` is a named rvalue reference, i.e. an lvalue expression; it
      // must be moved explicitly, otherwise the whole set is copied.
      scope_(std::move(scope)),
      http_exchange_(url_loader_factory) {}
// Compiler-generated destructor; defined here rather than in the header.
AuthorizationServerSession::~AuthorizationServerSession() = default;
// Returns true when every element of `scope` is also present in the scope set
// held by this session. An empty `scope` is trivially contained.
bool AuthorizationServerSession::ContainsAll(
    const base::flat_set<std::string>& scope) const {
  for (const std::string& required : scope) {
    if (scope_.find(required) == scope_.end()) {
      return false;
    }
  }
  return true;
}
// Appends `callback` to the list of callbacks stored in this session. The
// stored callbacks can later be retrieved with TakeWaitingList().
void AuthorizationServerSession::AddToWaitingList(StatusCallback callback) {
  callbacks_.emplace_back(std::move(callback));
}
// Hands all stored callbacks over to the caller and leaves the internal
// waiting list empty.
std::vector<StatusCallback> AuthorizationServerSession::TakeWaitingList() {
  std::vector<StatusCallback> taken = std::move(callbacks_);
  // Make the emptied state explicit instead of relying on moved-from state.
  callbacks_.clear();
  return taken;
}
void AuthorizationServerSession::SendFirstTokenRequest(
const std::string& client_id,
const std::string& authorization_code,
const std::string& code_verifier,
StatusCallback callback) {
net::PartialNetworkTrafficAnnotationTag partial_traffic_annotation =
net::DefinePartialNetworkTrafficAnnotation(
"printing_oauth2_first_token_request",
"printing_oauth2_http_exchange", R"(semantics {
description:
"This request opens OAuth 2 session with the Authorization Server by "
"asking it for an access token."
data:
"Identifier of the client obtained from the Authorization server during "
"registration and temporary security codes used during authorization "
"process."
})");
http_exchange_.Clear();
// Moves query parameters from URL to the content.
chromeos::Uri uri(token_endpoint_uri_.spec());
auto query = uri.GetQuery();
for (const auto& kv : query) {
http_exchange_.AddParamString(kv.first, kv.second);
}
uri.SetQuery({});
// Prepare the request.
http_exchange_.AddParamString("grant_type", "authorization_code");
http_exchange_.AddParamString("code", authorization_code);
http_exchange_.AddParamString("redirect_uri", kRedirectURI);
http_exchange_.AddParamString("client_id", client_id);
http_exchange_.AddParamString("code_verifier", code_verifier);
http_exchange_.Exchange(
"POST", GURL(uri.GetNormalized()), ContentFormat::kXWwwFormUrlencoded,
200, 400, partial_traffic_annotation,
base::BindOnce(&AuthorizationServerSession::OnFirstTokenResponse,
base::Unretained(this), std::move(callback)));
}
// Requests a new access token using the stored refresh token. The current
// access token is discarded first. When no refresh token is stored, `callback`
// is run immediately with StatusCode::kAuthorizationNeeded and no request is
// sent; otherwise `callback` is bound into OnNextTokenResponse().
void AuthorizationServerSession::SendNextTokenRequest(StatusCallback callback) {
  access_token_.clear();

  const bool has_refresh_token = !refresh_token_.empty();
  if (!has_refresh_token) {
    std::move(callback).Run(StatusCode::kAuthorizationNeeded,
                            "No refresh token");
    return;
  }

  http_exchange_.Clear();

  // Query parameters found in the token endpoint URL are sent as part of the
  // request content; the URL used for the request has its query removed.
  chromeos::Uri endpoint(token_endpoint_uri_.spec());
  for (const auto& [name, value] : endpoint.GetQuery()) {
    http_exchange_.AddParamString(name, value);
  }
  endpoint.SetQuery({});

  // Parameters of the refresh-token grant.
  http_exchange_.AddParamString("grant_type", "refresh_token");
  http_exchange_.AddParamString("refresh_token", refresh_token_);

  // The raw string below is annotation payload and is kept verbatim.
  net::PartialNetworkTrafficAnnotationTag annotation =
      net::DefinePartialNetworkTrafficAnnotation(
          "printing_oauth2_next_token_request", "printing_oauth2_http_exchange",
          R"(semantics {
description:
"This request refreshes OAuth 2 session with the Authorization Server by "
"asking it for a new access token."
data:
"A refresh token previously issued by the Authorization Server."
})");

  // NOTE(review): base::Unretained assumes `http_exchange_` never invokes the
  // handler after this object has been destroyed - confirm.
  auto on_response =
      base::BindOnce(&AuthorizationServerSession::OnNextTokenResponse,
                     base::Unretained(this), std::move(callback));
  http_exchange_.Exchange("POST", GURL(endpoint.GetNormalized()),
                          ContentFormat::kXWwwFormUrlencoded, 200, 400,
                          annotation, std::move(on_response));
}
// Handles the response to the request sent by SendFirstTokenRequest().
// - `status` other than kOK: forwards it to `callback` with the exchange's
//   error message.
// - Response fields cannot be read as expected: clears both tokens and reports
//   kInvalidResponse.
// - Otherwise: adds the scopes listed in the response to this session's scope
//   set and reports kOK together with the new access token.
void AuthorizationServerSession::OnFirstTokenResponse(StatusCallback callback,
                                                      StatusCode status) {
  if (status != StatusCode::kOK) {
    std::move(callback).Run(status, http_exchange_.GetErrorMessage());
    return;
  }

  // Fields are read in a fixed order and reading stops at the first failure.
  // NOTE(review): the boolean argument presumably marks a field as required -
  // confirm against HttpExchange.
  std::string scope_field;
  if (!http_exchange_.ParamStringGet("access_token", true, &access_token_) ||
      !http_exchange_.ParamStringEquals("token_type", true, "bearer") ||
      !http_exchange_.ParamStringGet("refresh_token", false,
                                     &refresh_token_) ||
      !http_exchange_.ParamStringGet("scope", false, &scope_field)) {
    // Do not keep partially parsed tokens around.
    access_token_.clear();
    refresh_token_.clear();
    std::move(callback).Run(StatusCode::kInvalidResponse,
                            http_exchange_.GetErrorMessage());
    return;
  }

  // Merge the scopes reported by the server into the ones already tracked.
  for (const std::string& granted : ParseScope(scope_field)) {
    scope_.insert(granted);
  }
  std::move(callback).Run(StatusCode::kOK, access_token_);
}
// Handles the response to the request sent by SendNextTokenRequest().
// - `status` is kInvalidAccessToken: reported to `callback` as
//   kAuthorizationNeeded ("Refresh token expired").
// - Any other non-kOK `status`: forwarded with the exchange's error message.
// - Response fields cannot be read as expected: clears both tokens and reports
//   kInvalidResponse.
// - Otherwise: reports kOK together with the new access token.
void AuthorizationServerSession::OnNextTokenResponse(StatusCallback callback,
                                                     StatusCode status) {
  if (status == StatusCode::kInvalidAccessToken) {
    std::move(callback).Run(StatusCode::kAuthorizationNeeded,
                            "Refresh token expired");
    return;
  }
  if (status != StatusCode::kOK) {
    std::move(callback).Run(status, http_exchange_.GetErrorMessage());
    return;
  }

  // Fields are read in a fixed order and reading stops at the first failure.
  // NOTE(review): the boolean argument presumably marks a field as required -
  // confirm against HttpExchange.
  if (!http_exchange_.ParamStringGet("access_token", true, &access_token_) ||
      !http_exchange_.ParamStringEquals("token_type", true, "bearer") ||
      !http_exchange_.ParamStringGet("refresh_token", false,
                                     &refresh_token_)) {
    // Do not keep partially parsed tokens around.
    access_token_.clear();
    refresh_token_.clear();
    std::move(callback).Run(StatusCode::kInvalidResponse,
                            http_exchange_.GetErrorMessage());
    return;
  }

  std::move(callback).Run(StatusCode::kOK, access_token_);
}
} // namespace oauth2
} // namespace printing
} // namespace ash