#include "components/safe_browsing/content/renderer/phishing_classifier/scorer.h"
#include <math.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "base/logging.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/shared_memory_mapping.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/not_fatal_until.h"
#include "base/strings/string_number_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/trace_event/trace_event.h"
#include "components/safe_browsing/content/common/visual_utils.h"
#include "components/safe_browsing/content/renderer/phishing_classifier/features.h"
#include "components/safe_browsing/core/common/proto/client_model.pb.h"
#include "components/safe_browsing/core/common/proto/csd.pb.h"
#include "content/public/renderer/render_thread.h"
#include "crypto/sha2.h"
#include "skia/ext/image_operations.h"
#include "third_party/skia/include/core/SkBitmap.h"
#include "third_party/skia/include/core/SkColorSpace.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "third_party/tflite/src/tensorflow/lite/kernels/builtin_op_kernels.h"
#include "third_party/tflite/src/tensorflow/lite/op_resolver.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h"
#endif
namespace safe_browsing {
// File-local helpers for Scorer. Nothing here is visible outside this
// translation unit.
namespace {
// Returns a std::string built from the flatbuffer hash pointed to by `hash`.
// NOTE(review): the exact encoding (raw bytes vs. hex, null handling) is not
// visible from the signature — confirm against the implementation.
std::string HashToString(const flat::Hash* hash) { … }
// Records the outcome of a scorer-creation attempt.
// NOTE(review): presumably reported as a UMA histogram (base/metrics headers
// are included by this file) — confirm.
void RecordScorerCreationStatus(ScorerCreationStatus status) { … }
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Returns a newly allocated op resolver used when building the TFLite task
// models below.
// NOTE(review): the set of ops registered is not visible here; presumably the
// TFLite builtin kernels (builtin_op_kernels.h is included) — confirm.
std::unique_ptr<tflite::MutableOpResolver> CreateOpResolver() { … }
// Builds a TFLite task ImageClassifier from `model_data` (taken by value).
// NOTE(review): presumably `model_data` holds the serialized model bytes and a
// null pointer is returned on load failure — confirm.
std::unique_ptr<tflite::task::vision::ImageClassifier> CreateClassifier(
std::string model_data) { … }
// Builds a TFLite task ImageEmbedder from `model_data` (taken by value).
// NOTE(review): same assumptions as CreateClassifier (serialized model bytes,
// null on failure) — confirm.
std::unique_ptr<tflite::task::vision::ImageEmbedder> CreateImageEmbedder(
std::string model_data) { … }
// Produces the model input for `bitmap` as a std::string, parameterized by the
// target `width` x `height`. `image_embedding` defaults to false.
// NOTE(review): presumably resizes the bitmap and serializes its pixels, with
// `image_embedding` selecting the embedding-model variant — confirm.
std::string GetModelInput(const SkBitmap& bitmap,
int width,
int height,
bool image_embedding = false) { … }
// Wraps `model_input` (an input_width x input_height image) in the buffer type
// consumed by the TFLite task APIs. The return type is deduced (`auto`).
// NOTE(review): presumably a tflite::task::vision::FrameBuffer; it may
// reference `model_input` without copying, so `model_input` would need to
// outlive it — confirm.
auto CreateFrameBuffer(const std::string& model_input,
int input_width,
int input_height) { … }
// Last step of the classification pipeline. Takes ownership of `classifier`
// and reports a std::vector<double> through `callback`.
// NOTE(review): presumably classifies `model_input` and posts `callback` to
// `callback_task_runner` — confirm.
void OnModelInputCreatedForClassifier(
const std::string& model_input,
int input_width,
int input_height,
std::unique_ptr<tflite::task::vision::ImageClassifier> classifier,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
base::OnceCallback<void(std::vector<double>)> callback) { … }
// Last step of the image-embedding pipeline. Takes ownership of
// `image_embedder` and reports an ImageFeatureEmbedding through `callback`.
// NOTE(review): presumably embeds `model_input` and posts `callback` to
// `callback_task_runner` — confirm.
void OnModelInputCreatedForImageEmbedding(
const std::string& model_input,
int input_width,
int input_height,
std::unique_ptr<tflite::task::vision::ImageEmbedder> image_embedder,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
base::OnceCallback<void(ImageFeatureEmbedding)> callback) { … }
// Continuation run once `classifier` has been created (ownership is passed
// in). Forwards to the classifier pipeline with the same callback plumbing.
// NOTE(review): presumably builds the model input from `bitmap` via
// GetModelInput and then calls OnModelInputCreatedForClassifier — confirm.
void OnClassifierCreated(
const SkBitmap& bitmap,
int input_width,
int input_height,
std::unique_ptr<tflite::task::vision::ImageClassifier> classifier,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
base::OnceCallback<void(std::vector<double>)> callback) { … }
// Continuation run once `image_embedder` has been created (ownership is passed
// in). Forwards to the embedding pipeline with the same callback plumbing.
// NOTE(review): presumably builds the model input from `bitmap` via
// GetModelInput(..., /*image_embedding=*/true) and then calls
// OnModelInputCreatedForImageEmbedding — confirm.
void OnImageEmbedderCreated(
const SkBitmap& bitmap,
int input_width,
int input_height,
std::unique_ptr<tflite::task::vision::ImageEmbedder> image_embedder,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
base::OnceCallback<void(ImageFeatureEmbedding)> callback) { … }
#endif  // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
}  // namespace
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Runs the visual TFLite classifier held in `model_data` over `bitmap`, scaled
// to `input_width` x `input_height`. The resulting scores are delivered through
// `callback`.
// NOTE(review): presumably `callback` is posted back on `callback_task_runner`
// and the model work happens off the calling sequence — confirm.
// NOTE(review): `model_data` is taken by value here but by const reference in
// ApplyImageEmbeddingTfLiteModelHelper; consider aligning both with the header.
void Scorer::ApplyVisualTfLiteModelHelper(
const SkBitmap& bitmap,
int input_width,
int input_height,
std::string model_data,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
base::OnceCallback<void(std::vector<double>)> callback) { … }
// Runs the image-embedding TFLite model held in `model_data` over `bitmap`,
// scaled to `input_width` x `input_height`. The embedding is delivered through
// `callback`.
// NOTE(review): presumably `callback` is posted back on `callback_task_runner`
// — confirm.
void Scorer::ApplyImageEmbeddingTfLiteModelHelper(
const SkBitmap& bitmap,
int input_width,
int input_height,
const std::string& model_data,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
base::OnceCallback<void(ImageFeatureEmbedding)> callback) { … }
#endif  // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Converts a log-odds value into a probability.
// NOTE(review): presumably the logistic function 1 / (1 + exp(-log_odds))
// (<math.h> is included) — confirm, including handling of large |log_odds|.
double Scorer::LogOdds2Prob(const double log_odds) const { … }
// Special members are defaulted; Scorer instances are built via the static
// factories below.
Scorer::Scorer() = default;
Scorer::~Scorer() = default;
// Returns the shared ScorerStorage instance.
// NOTE(review): presumably a lazily-created, never-destroyed singleton;
// confirm thread/sequence expectations against the implementation.
ScorerStorage* ScorerStorage::GetInstance() { … }
ScorerStorage::ScorerStorage() = default;
ScorerStorage::~ScorerStorage() = default;
// Factory: builds a Scorer from the client-side model in `region` together
// with the visual TFLite model file `visual_tflite_model`. Both arguments are
// taken by value (moved in).
// NOTE(review): presumably returns nullptr when the model cannot be
// mapped/parsed and reports the outcome via RecordScorerCreationStatus —
// confirm.
std::unique_ptr<Scorer> Scorer::Create(base::ReadOnlySharedMemoryRegion region,
base::File visual_tflite_model) { … }
// Factory: like Create(), but additionally takes the image-embedding TFLite
// model file `image_embedding_model`.
// NOTE(review): same nullptr-on-failure assumption as Create() — confirm.
std::unique_ptr<Scorer> Scorer::CreateScorerWithImageEmbeddingModel(
base::ReadOnlySharedMemoryRegion region,
base::File visual_tflite_model,
base::File image_embedding_model) { … }
// Attaches an image-embedding TFLite model file to an already-constructed
// Scorer. `image_embedding_model` is taken by value.
// NOTE(review): behavior when the file is invalid, or when a model is already
// attached, is not visible here — confirm.
void Scorer::AttachImageEmbeddingModel(base::File image_embedding_model) { … }
// Returns the score contribution of a single model `rule` given the extracted
// `features`.
// NOTE(review): presumably the rule weight scaled by the values of the
// rule's features — confirm.
double Scorer::ComputeRuleScore(const flat::ClientSideModel_::Rule* rule,
const FeatureMap& features) const { … }
// Returns the overall model score for `features`.
// NOTE(review): presumably sums ComputeRuleScore over all rules and maps the
// result through LogOdds2Prob, i.e. yields a probability — confirm.
double Scorer::ComputeScore(const FeatureMap& features) const { … }
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Classifies `bitmap` with the visual TFLite model. Results are delivered
// through `callback` rather than returned.
// NOTE(review): presumably dispatches to ApplyVisualTfLiteModelHelper on a
// thread-pool task (thread_pool.h is included) and replies on the current
// sequence — confirm.
void Scorer::ApplyVisualTfLiteModel(
const SkBitmap& bitmap,
base::OnceCallback<void(std::vector<double>)> callback) const { … }
// Computes an image embedding for `bitmap` with the image-embedding TFLite
// model. The embedding is delivered through `callback`.
// NOTE(review): presumably dispatches to ApplyImageEmbeddingTfLiteModelHelper
// asynchronously; behavior when no embedding model is attached should be
// confirmed.
void Scorer::ApplyVisualTfLiteModelImageEmbedding(
const SkBitmap& bitmap,
base::OnceCallback<void(ImageFeatureEmbedding)> callback) const { … }
#endif  // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Read-only (const) accessors used by the phishing feature extractors and
// callers of the scorer.
int Scorer::model_version() const { … }
int Scorer::dom_model_version() const { … }
// Returns whether `str` is one of the page terms known to this scorer.
bool Scorer::has_page_term(const std::string& str) const { … }
// Returns a callback answering the same question as has_page_term().
// NOTE(review): if it binds `this` unretained, it must not outlive the
// Scorer — confirm.
base::RepeatingCallback<bool(const std::string&)>
Scorer::find_page_term_callback() const { … }
// Returns whether the hashed page word `page_word_hash` is known to this
// scorer.
bool Scorer::has_page_word(uint32_t page_word_hash) const { … }
// Returns a callback answering the same question as has_page_word().
// NOTE(review): same lifetime caveat as find_page_term_callback().
base::RepeatingCallback<bool(uint32_t)> Scorer::find_page_word_callback()
const { … }
size_t Scorer::max_words_per_term() const { … }
uint32_t Scorer::murmurhash3_seed() const { … }
size_t Scorer::max_shingles_per_page() const { … }
size_t Scorer::shingle_size() const { … }
float Scorer::threshold_probability() const { … }
int Scorer::tflite_model_version() const { … }
// Returns a reference to the TFLite threshold list; it remains valid only as
// long as the Scorer that owns it.
const google::protobuf::RepeatedPtrField<TfLiteModelMetadata::Threshold>&
Scorer::tflite_thresholds() const { … }
int Scorer::image_embedding_tflite_model_version() const { … }
// Installs `scorer`, taking ownership of it.
// NOTE(review): presumably replaces any previous scorer and notifies
// registered observers — confirm.
void ScorerStorage::SetScorer(std::unique_ptr<Scorer> scorer) { … }
// Drops the currently held scorer, if any.
void ScorerStorage::ClearScorer() { … }
// Returns the current scorer without transferring ownership (storage keeps the
// unique_ptr passed to SetScorer). NOTE(review): presumably nullptr when none
// is set — confirm.
Scorer* ScorerStorage::GetScorer() const { … }
// Registers / unregisters an observer. The storage does not take ownership of
// `observer`. NOTE(review): presumably notified when the scorer changes —
// confirm.
void ScorerStorage::AddObserver(ScorerStorage::Observer* observer) { … }
void ScorerStorage::RemoveObserver(ScorerStorage::Observer* observer) { … }
}