chromium/components/ip_protection/android/android_auth_client_lib/client/java/src/org/chromium/components/ip_protection_auth/IpProtectionAuthClient.java

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package org.chromium.components.ip_protection_auth;

import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.content.ServiceConnection;
import android.content.pm.PackageManager;
import android.os.IBinder;
import android.os.RemoteException;

import androidx.annotation.GuardedBy;
import androidx.annotation.IntDef;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import org.jni_zero.CalledByNative;
import org.jni_zero.CalledByNativeForTesting;
import org.jni_zero.JNINamespace;

import org.chromium.base.ContextUtils;
import org.chromium.base.Log;
import org.chromium.base.ThreadUtils;
import org.chromium.base.TraceEvent;
import org.chromium.components.ip_protection_auth.common.IErrorCode;
import org.chromium.components.ip_protection_auth.common.IIpProtectionAuthAndSignCallback;
import org.chromium.components.ip_protection_auth.common.IIpProtectionAuthService;
import org.chromium.components.ip_protection_auth.common.IIpProtectionGetInitialDataCallback;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.HashSet;
import java.util.Set;

/** Client interface for the IP Protection Auth service. */
@JNINamespace("ip_protection::android")
public final class IpProtectionAuthClient implements AutoCloseable {
    private static final String TAG = "IppAuthClient";

    private static final String IP_PROTECTION_AUTH_ACTION =
            "android.net.http.IpProtectionAuthService";

    // mService being null signifies that the object has been closed by calling close().
    @Nullable private IIpProtectionAuthService mService;
    // We need to store this to unbind from the service.
    @Nullable private ConnectionSetup mConnection;
    @NonNull final CallbackTracker mCallbackTracker;

    // These values must be kept in sync with AuthRequestError in
    // ip_protection_auth_client_interface.h
    @Retention(RetentionPolicy.SOURCE)
    @IntDef({AuthRequestError.TRANSIENT, AuthRequestError.PERSISTENT, AuthRequestError.OTHER})
    public @interface AuthRequestError {
        int TRANSIENT = 0;
        int PERSISTENT = 1;
        int OTHER = 2;
    }

    /**
     * Convert IErrorCode value to AuthRequestError value.
     *
     * <p>Invalid values are mapped to AuthRequestError.OTHER.
     */
    private static int convertErrorCodeToAuthRequestError(int errorCode) {
        switch (errorCode) {
            case IErrorCode.IP_PROTECTION_AUTH_SERVICE_TRANSIENT_ERROR:
                return AuthRequestError.TRANSIENT;
            case IErrorCode.IP_PROTECTION_AUTH_SERVICE_PERSISTENT_ERROR:
                return AuthRequestError.PERSISTENT;
            default:
                return AuthRequestError.OTHER;
        }
    }

    /** This class must be used exclusively from the main thread. */
    private static final class ConnectionSetup implements ServiceConnection {
        @Nullable private IpProtectionAuthServiceCallback mCallback;
        private final Context mContext;
        private IpProtectionAuthClient mIpProtectionClient;
        private boolean mBound;

        ConnectionSetup(
                @NonNull Context context, @NonNull IpProtectionAuthServiceCallback callback) {
            mContext = context;
            mCallback = callback;
            mIpProtectionClient = null;
            mBound = true;
        }

        @Override
        public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
            try (TraceEvent event =
                    TraceEvent.scoped("IpProtectionAuthClient.Create.OnServiceConnected")) {
                // In some odd cases (b/357770633), this lifecycle method is triggered even
                // after the service is unbound. We ensure that the callback cannot be invoked more
                // than once from this object if it so happens.
                if (mCallback == null) {
                    return;
                }
                mIpProtectionClient =
                        new IpProtectionAuthClient(
                                this, IIpProtectionAuthService.Stub.asInterface(iBinder));
                mCallback.onResult(mIpProtectionClient);
                mCallback = null;
            }
        }

        @Override
        public void onServiceDisconnected(ComponentName componentName) {
            try (TraceEvent event =
                    TraceEvent.scoped("IpProtectionAuthClient.Create.OnServiceDisconnected")) {
                unbindIfBound();
                mIpProtectionClient.mCallbackTracker.rejectUnresolvedCallbacks(
                        AuthRequestError.OTHER);
            }
        }

        @Override
        public void onBindingDied(ComponentName name) {
            try (TraceEvent event =
                    TraceEvent.scoped("IpProtectionAuthClient.Create.OnBindingDied")) {
                unbindIfBound();
                mIpProtectionClient.mCallbackTracker.rejectUnresolvedCallbacks(
                        AuthRequestError.OTHER);
            }
        }

        @Override
        public void onNullBinding(ComponentName name) {
            try (TraceEvent event =
                    TraceEvent.scoped("IpProtectionAuthClient.Create.OnNullBinding")) {
                unbindIfBound();
                // We ensure that the callback cannot be invoked more than once from this object.
                if (mCallback == null) {
                    return;
                }
                mCallback.onError("Service returned null from onBind()");
                mCallback = null;
            }
        }

        public void unbindIfBound() {
            ThreadUtils.assertOnUiThread();
            if (mBound) {
                mBound = false;
                mContext.unbindService(this);
            }
        }
    }

    /**
     * Enforces single-run semantics for callbacks across potentially competing threads and provides
     * mechanisms for running callbacks for abandoned requests.
     */
    static final class CallbackTracker {
        // Callbacks must never be run from code holding this lock.
        @NonNull private final Object mUnresolvedCallbacksLock;

        @GuardedBy("mUnresolvedCallbacksLock")
        @NonNull
        final Set<IpProtectionByteArrayCallback> mUnresolvedCallbacks;

        /**
         * A thread-safe wrapper around a callback that enforces that it's used (at most) once.
         *
         * <p>Note that this only enforces thread-safety for deciding whether a callback should be
         * run (in the case where there are threads competing to call the same callback). It does
         * not enforce any thread-safety for the underlying callback's logic.
         */
        public final class TrackedCallback implements IpProtectionByteArrayCallback {
            @NonNull private final IpProtectionByteArrayCallback mInner;

            // See CallbackTracker.wrapCallback.
            public TrackedCallback(@NonNull IpProtectionByteArrayCallback inner) {
                mInner = inner;
                synchronized (mUnresolvedCallbacksLock) {
                    mUnresolvedCallbacks.add(inner);
                }
            }

            @Override
            public void onResult(@NonNull byte[] result) {
                IpProtectionByteArrayCallback callback = unwrap();
                if (callback != null) {
                    callback.onResult(result);
                }
            }

            @Override
            public void onError(int authRequestError) {
                IpProtectionByteArrayCallback callback = unwrap();
                if (callback != null) {
                    callback.onError(authRequestError);
                }
            }

            /**
             * Extract the wrapped callback.
             *
             * <p>The previously wrapped callback will no longer be tracked.
             *
             * @return The unwrapped callback if the callback has not bean consumed yet (unwrapped
             *     or triggered), or null if the callback has already been consumed. Null does not
             *     necessarily indicate a programming error - merely that the callback was
             *     resolved/rejected due to some other mechanism first, perhaps on a competing
             *     thread.
             */
            @Nullable
            public IpProtectionByteArrayCallback unwrap() {
                final boolean isUnresolved;
                synchronized (mUnresolvedCallbacksLock) {
                    isUnresolved = mUnresolvedCallbacks.remove(mInner);
                }
                if (!isUnresolved) {
                    Log.w(TAG, "callback already used");
                    return null;
                }
                return mInner;
            }
        }

        public CallbackTracker() {
            mUnresolvedCallbacksLock = new Object();
            mUnresolvedCallbacks = new HashSet<>();
        }

        /**
         * Wraps a callback in a thread-safe wrapper that ensures the callback is used only once.
         *
         * <p>The wrapped callback should be considered owned by the callback tracker and thus only
         * be used via the wrapper or the callback tracker.
         */
        @NonNull
        public TrackedCallback wrapCallback(@NonNull IpProtectionByteArrayCallback callback) {
            return new TrackedCallback(callback);
        }

        /**
         * Unwraps the callbacks from all unresolved tracked callbacks and returns the unwrapped
         * callbacks.
         */
        private IpProtectionByteArrayCallback[] unwrapAllCallbacks() {
            final IpProtectionByteArrayCallback[] callbacks;
            synchronized (mUnresolvedCallbacksLock) {
                callbacks = mUnresolvedCallbacks.toArray(new IpProtectionByteArrayCallback[0]);
                mUnresolvedCallbacks.clear();
            }
            return callbacks;
        }

        /** Call onError for all unresolved callbacks. */
        public void rejectUnresolvedCallbacks(int authRequestError) {
            for (IpProtectionByteArrayCallback callback : unwrapAllCallbacks()) {
                callback.onError(authRequestError);
            }
        }
    }

    private final class IIpProtectionGetInitialDataCallbackStub
            extends IIpProtectionGetInitialDataCallback.Stub {
        private final CallbackTracker.TrackedCallback mCallback;

        IIpProtectionGetInitialDataCallbackStub(CallbackTracker.TrackedCallback callback) {
            mCallback = callback;
        }

        @Override
        public void reportResult(byte[] bytes) {
            if (bytes == null) {
                Log.e(TAG, "null getInitialData response");
                mCallback.onError(AuthRequestError.OTHER);
                return;
            }
            mCallback.onResult(bytes);
        }

        @Override
        public void reportError(int errorCode) {
            mCallback.onError(convertErrorCodeToAuthRequestError(errorCode));
        }
    }

    private final class IIpProtectionAuthAndSignCallbackStub
            extends IIpProtectionAuthAndSignCallback.Stub {
        private final CallbackTracker.TrackedCallback mCallback;

        IIpProtectionAuthAndSignCallbackStub(CallbackTracker.TrackedCallback callback) {
            mCallback = callback;
        }

        @Override
        public void reportResult(byte[] bytes) {
            if (bytes == null) {
                Log.e(TAG, "null authAndSign response");
                mCallback.onError(AuthRequestError.OTHER);
                return;
            }
            mCallback.onResult(bytes);
        }

        @Override
        public void reportError(int errorCode) {
            mCallback.onError(convertErrorCodeToAuthRequestError(errorCode));
        }
    }

    IpProtectionAuthClient(
            @NonNull ConnectionSetup connectionSetup,
            @NonNull IIpProtectionAuthService ipProtectionAuthService) {
        mConnection = connectionSetup;
        mService = ipProtectionAuthService;
        mCallbackTracker = new CallbackTracker();
    }

    @CalledByNativeForTesting
    public static void createConnectedInstanceForTesting(
            @NonNull String mockServicePackageName,
            @NonNull String mockServiceClassName,
            @NonNull IpProtectionAuthServiceCallback callback) {
        Intent intent = new Intent(IP_PROTECTION_AUTH_ACTION);
        intent.setClassName(mockServicePackageName, mockServiceClassName);
        createConnectedInstanceCommon(intent, PackageManager.MATCH_DISABLED_COMPONENTS, callback);
    }

    @CalledByNative
    public static void createConnectedInstance(@NonNull IpProtectionAuthServiceCallback callback) {
        // Use IP_PROTECTION_AUTH_ACTION to resolve system service that satisfies
        // the intent, going from implicit to explicit intent in createConnectedInstanceCommon.
        var intent = new Intent(IP_PROTECTION_AUTH_ACTION);
        createConnectedInstanceCommon(intent, PackageManager.MATCH_SYSTEM_ONLY, callback);
    }

    /**
     * Converts the given intent to an explicit intent and binds to the service.
     *
     * <p>The intent is converted to an explicit intent by resolving via the PackageManager. The
     * callback is called with either an IpProtectionAuthClient bound to the service or an error
     * string.
     */
    private static void createConnectedInstanceCommon(
            @NonNull Intent intent,
            int resolveFlags,
            @NonNull IpProtectionAuthServiceCallback callback) {
        var context = ContextUtils.getApplicationContext();
        var packageManager = context.getPackageManager();
        // When Chromium moves to API level 33 as minimum-supported version,
        // we can switch resolveService to use PackageManager.ResolveInfoFlags.
        var resolveInfo = packageManager.resolveService(intent, resolveFlags);
        if (resolveInfo == null || resolveInfo.serviceInfo == null) {
            final String error =
                    "Unable to locate the IP Protection authentication provider package ("
                            + intent.getAction()
                            + " action). This is expected if the host system is not set up to"
                            + " provide IP Protection services.";
            ThreadUtils.postOnUiThread(
                    () -> {
                        try (TraceEvent event =
                                TraceEvent.scoped(
                                        "IpProtectionAuthClient.Create.FailResolveInfo")) {
                            callback.onError(error);
                        }
                    });
            return;
        }
        var serviceInfo = resolveInfo.serviceInfo;
        var componentName = new ComponentName(serviceInfo.packageName, serviceInfo.name);
        intent.setComponent(componentName);
        ThreadUtils.postOnUiThread(
                () -> {
                    try (TraceEvent event =
                            TraceEvent.scoped("IpProtectionAuthClient.Create.TryBind")) {
                        var connectionSetup = new ConnectionSetup(context, callback);
                        try {
                            boolean binding =
                                    context.bindService(
                                            intent, connectionSetup, Context.BIND_AUTO_CREATE);
                            if (!binding) {
                                connectionSetup.unbindIfBound();
                                callback.onError("bindService() failed: returned false");
                                return;
                            }
                        } catch (SecurityException e) {
                            connectionSetup.unbindIfBound();
                            callback.onError("Failed to bind service: " + e);
                            return;
                        }
                    }
                });
    }

    // See documentation in native IpProtectionAuthClient::GetInitialData
    @CalledByNative
    public void getInitialData(byte[] request, IpProtectionByteArrayCallback callback) {
        try (TraceEvent event =
                TraceEvent.scoped("IpProtectionAuthClient.Request.GetInitialData")) {
            if (mService == null) {
                // This denotes a coding error by the caller so it makes sense to throw an
                // unchecked exception.
                throw new IllegalStateException("Already closed");
            }
            CallbackTracker.TrackedCallback trackedCallback =
                    mCallbackTracker.wrapCallback(callback);
            IIpProtectionGetInitialDataCallbackStub callbackStub =
                    new IIpProtectionGetInitialDataCallbackStub(trackedCallback);
            try {
                mService.getInitialData(request, callbackStub);
            } catch (RuntimeException | RemoteException ex) {
                Log.e(TAG, "error calling getInitialData", ex);
                trackedCallback.onError(AuthRequestError.OTHER);
            }
        }
    }

    // See documentation in native IpProtectionAuthClient::AuthAndSign
    @CalledByNative
    public void authAndSign(byte[] request, IpProtectionByteArrayCallback callback) {
        try (TraceEvent event = TraceEvent.scoped("IpProtectionAuthClient.Request.AuthAndSign")) {
            if (mService == null) {
                // This denotes a coding error by the caller so it makes sense to throw an
                // unchecked exception.
                throw new IllegalStateException("Already closed");
            }
            CallbackTracker.TrackedCallback trackedCallback =
                    mCallbackTracker.wrapCallback(callback);
            IIpProtectionAuthAndSignCallbackStub callbackStub =
                    new IIpProtectionAuthAndSignCallbackStub(trackedCallback);
            try {
                mService.authAndSign(request, callbackStub);
            } catch (RuntimeException | RemoteException ex) {
                Log.e(TAG, "error calling authAndSign", ex);
                trackedCallback.onError(AuthRequestError.OTHER);
            }
        }
    }

    @Override
    @CalledByNative
    public void close() {
        if (mService == null) {
            return;
        }
        ThreadUtils.runOnUiThread(mConnection::unbindIfBound);
        mConnection = null;
        mService = null;
        mCallbackTracker.rejectUnresolvedCallbacks(AuthRequestError.OTHER);
    }
}