chromium/chrome/browser/segmentation_platform/android/javatests/src/org/chromium/chrome/browser/segmentation_platform/SegmentationPlatformServiceFactoryTest.java

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package org.chromium.chrome.browser.segmentation_platform;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;

import static org.chromium.base.test.util.Batch.PER_CLASS;

import androidx.test.filters.MediumTest;

import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;

import org.chromium.base.Callback;
import org.chromium.base.ThreadUtils;
import org.chromium.base.library_loader.LibraryLoader;
import org.chromium.base.test.BaseJUnit4ClassRunner;
import org.chromium.base.test.util.Batch;
import org.chromium.base.test.util.CallbackHelper;
import org.chromium.base.test.util.CommandLineFlags;
import org.chromium.chrome.browser.flags.ChromeSwitches;
import org.chromium.chrome.browser.profiles.ProfileManager;
import org.chromium.chrome.test.ChromeTabbedActivityTestRule;
import org.chromium.components.segmentation_platform.ClassificationResult;
import org.chromium.components.segmentation_platform.Constants;
import org.chromium.components.segmentation_platform.InputContext;
import org.chromium.components.segmentation_platform.PredictionOptions;
import org.chromium.components.segmentation_platform.ProcessedValue;
import org.chromium.components.segmentation_platform.SegmentationPlatformService;
import org.chromium.components.segmentation_platform.prediction_status.PredictionStatus;
import org.chromium.url.GURL;

import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Instrumentation tests for {@link SegmentationPlatformServiceFactory}.
 *
 * <p>Each test starts the main activity on a blank page, obtains the
 * {@link SegmentationPlatformService} for the last used regular profile and checks the
 * {@link ClassificationResult} delivered for a given segmentation key.
 *
 * <p>The classification callback only records the result and signals {@link #mCallbackHelper};
 * all assertions run on the test thread after the wait. An {@link AssertionError} thrown inside
 * the callback would be raised on whichever thread delivers it (presumably the UI thread) and
 * would not surface as a regular test failure.
 */
@RunWith(BaseJUnit4ClassRunner.class)
@CommandLineFlags.Add({ChromeSwitches.DISABLE_FIRST_RUN_EXPERIENCE})
@Batch(value = PER_CLASS)
public class SegmentationPlatformServiceFactoryTest {

    @Rule
    public ChromeTabbedActivityTestRule mActivityTestRule = new ChromeTabbedActivityTestRule();

    /** Signalled once the classification callback has been invoked. */
    private final CallbackHelper mCallbackHelper = new CallbackHelper();

    /** Written by the classification callback; read by the test thread after waiting. */
    private final AtomicReference<ClassificationResult> mResult = new AtomicReference<>();

    /**
     * Requests a classification for "intentional_user" with a null input context and expects a
     * NOT_READY status with no labels.
     */
    @Test
    @MediumTest
    public void testGetClassificationResult_withNullInputContext() throws TimeoutException {
        LibraryLoader.getInstance().ensureInitialized();
        mActivityTestRule.startMainActivityOnBlankPage();

        ThreadUtils.runOnUiThreadBlocking(
                () -> {
                    PredictionOptions options = new PredictionOptions(true);
                    requestClassificationResult("intentional_user", options, null);
                });

        ClassificationResult result = waitForClassificationResult();
        Assert.assertEquals(PredictionStatus.NOT_READY, result.status);
        assertThat(result.orderedLabels, Matchers.empty());
    }

    /**
     * Requests an on-demand classification for "contextual_page_actions" with only the price
     * tracking input set to 1.0 and expects a single "price_tracking" label.
     */
    @Test
    @MediumTest
    public void testGetClassificationResult_withOnDemandModel() throws TimeoutException {
        LibraryLoader.getInstance().ensureInitialized();
        mActivityTestRule.startMainActivityOnBlankPage();

        ThreadUtils.runOnUiThreadBlocking(
                () -> {
                    PredictionOptions options = PredictionOptions.forOndemand(false);
                    InputContext inputContext = new InputContext();
                    inputContext.addEntry(
                            Constants.CONTEXTUAL_PAGE_ACTIONS_PRICE_TRACKING_INPUT,
                            ProcessedValue.fromFloat(1.0f));
                    inputContext.addEntry(
                            Constants.CONTEXTUAL_PAGE_ACTIONS_READER_MODE_INPUT,
                            ProcessedValue.fromFloat(0.0f));
                    inputContext.addEntry(
                            Constants.CONTEXTUAL_PAGE_ACTIONS_PRICE_INSIGHTS_INPUT,
                            ProcessedValue.fromFloat(0.0f));
                    inputContext.addEntry("url", ProcessedValue.fromGURL(GURL.emptyGURL()));

                    requestClassificationResult(
                            "contextual_page_actions", options, inputContext);
                });

        ClassificationResult result = waitForClassificationResult();
        Assert.assertEquals(PredictionStatus.SUCCEEDED, result.status);
        assertThat(result.orderedLabels, hasSize(1));
        assertThat(result.orderedLabels, contains("price_tracking"));
    }

    /**
     * Looks up the service for the last used regular profile and asks it for a classification
     * result. The tests call this from inside {@code runOnUiThreadBlocking}.
     *
     * @param segmentationKey Key identifying the segmentation config to query.
     * @param options Prediction options forwarded to the service.
     * @param inputContext Inputs forwarded to the service; may be null.
     */
    private void requestClassificationResult(
            String segmentationKey, PredictionOptions options, InputContext inputContext) {
        SegmentationPlatformService service =
                SegmentationPlatformServiceFactory.getForProfile(
                        ProfileManager.getLastUsedRegularProfile());
        // Record the result before signalling so the waiting test thread always sees it.
        Callback<ClassificationResult> onResult =
                result -> {
                    mResult.set(result);
                    mCallbackHelper.notifyCalled();
                };
        service.getClassificationResult(segmentationKey, options, inputContext, onResult);
    }

    /**
     * Waits for the classification callback and returns the result it recorded.
     *
     * @return The non-null result handed to the callback.
     * @throws TimeoutException If {@code mCallbackHelper.waitForNext()} times out.
     */
    private ClassificationResult waitForClassificationResult() throws TimeoutException {
        mCallbackHelper.waitForNext();
        ClassificationResult result = mResult.get();
        Assert.assertNotNull("Classification callback delivered a null result", result);
        return result;
    }
}