#include "components/segmentation_platform/embedder/model_provider_factory_impl.h"
#include "base/task/sequenced_task_runner.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "components/segmentation_platform/internal/execution/optimization_guide/optimization_guide_segmentation_model_provider.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
namespace segmentation_platform {
namespace {
// A ModelProvider subclass declared in an anonymous namespace, so it has
// internal linkage and is usable only within this translation unit.
// NOTE(review): the class body is elided in this view. Presumably this is a
// no-op provider used when no real model backend is available (the file
// includes machine_learning_tflite_buildflags.h). Confirm against the full
// definition and its use in ModelProviderFactoryImpl::CreateProvider().
class DummyModelProvider : public ModelProvider { … };
}  // namespace
// Constructs the factory.
//
// |optimization_guide_provider|: passed as a raw pointer. NOTE(review):
//     presumably non-owning and required to outlive this factory; confirm
//     against the header.
// |configs|: passed by non-const reference. Whether the factory only reads
//     these configs or also modifies them is not visible in this view.
// |background_task_runner|: ref-counted task runner passed by value, so the
//     factory can share ownership of it. NOTE(review): presumably used to run
//     model work off the calling sequence; confirm.
ModelProviderFactoryImpl::ModelProviderFactoryImpl(
optimization_guide::OptimizationGuideModelProvider*
optimization_guide_provider,
std::vector<std::unique_ptr<Config>>& configs,
scoped_refptr<base::SequencedTaskRunner> background_task_runner)
: … { … }
// Destructor, defaulted out of line in this .cc file.
// NOTE(review): this is usually done so that member types only need to be
// complete here rather than in the header; confirm that is the intent.
ModelProviderFactoryImpl::~ModelProviderFactoryImpl() = default;
// Returns a ModelProvider for |segment_id|. The result is returned as a
// std::unique_ptr, so the caller takes ownership.
// NOTE(review): the body is elided here. It is not visible whether this can
// return nullptr, or whether it falls back to DummyModelProvider for some
// build configurations or segments; confirm.
std::unique_ptr<ModelProvider> ModelProviderFactoryImpl::CreateProvider(
proto::SegmentId segment_id) { … }
// Returns a DefaultModelProvider for |segment_id|. The result is returned as a
// std::unique_ptr, so the caller takes ownership.
// NOTE(review): the body is elided here. Presumably it returns nullptr when
// the segment has no default model, and may consult TestDefaultModelOverride
// before the real lookup; confirm.
std::unique_ptr<DefaultModelProvider>
ModelProviderFactoryImpl::CreateDefaultProvider(proto::SegmentId segment_id) { … }
// Constructor and destructor, both defaulted out of line in this .cc file.
TestDefaultModelOverride::TestDefaultModelOverride() = default;
TestDefaultModelOverride::~TestDefaultModelOverride() = default;
// Returns a TestDefaultModelOverride by reference.
// NOTE(review): the body is elided here. Presumably this returns a lazily
// created, process-wide singleton. Its thread-safety and lifetime are not
// visible in this view; confirm.
TestDefaultModelOverride& TestDefaultModelOverride::GetInstance() { … }
// Returns a DefaultModelProvider for |target| as a std::unique_ptr,
// transferring ownership to the caller.
// NOTE(review): presumably this hands back (and removes) a provider
// previously registered via SetModelForTesting(), and returns nullptr when
// none is registered for |target|; confirm against the elided body.
std::unique_ptr<DefaultModelProvider>
TestDefaultModelOverride::TakeOwnershipOfModelProvider(
proto::SegmentId target) { … }
// Test-only hook. |default_provider| is taken by value as a std::unique_ptr,
// so this call takes ownership of it.
// NOTE(review): presumably the provider is stored keyed by |target| for a
// later TakeOwnershipOfModelProvider() call. Whether an existing entry for the
// same |target| is replaced is not visible in the elided body; confirm.
void TestDefaultModelOverride::SetModelForTesting(
proto::SegmentId target,
std::unique_ptr<DefaultModelProvider> default_provider) { … }
}