chromium/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java

/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.audio;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import android.media.AudioFormat;
import android.media.AudioRecord;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;

/** Test for {@link TensorAudio}. */
@RunWith(Suite.class)
@SuiteClasses({
  TensorAudioTest.General.class,
})
public class TensorAudioTest {

  /** General tests of TensorAudio. */
  @RunWith(RobolectricTestRunner.class)
  public static final class General extends TensorAudioTest {
    @Test
    public void createSucceedsWithTensorAudioFormat() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(
              TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
      assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
      assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
      assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
    }

    @Test
    public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(
              TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
      assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
      assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
      assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
    }

    @Test
    public void createSucceededsWithDefaultArguments() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
      // Number of channels defaults to 1.
      assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
      assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
      assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
    }

    @Test
    public void createSucceedsWithAudioFormat() throws Exception {
      AudioFormat format =
          new AudioFormat.Builder()
              .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
              .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
              .setSampleRate(16000)
              .build();
      TensorAudio tensor = TensorAudio.create(format, 100);
      // STEREO has 2 channels
      assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
      assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
      // flatSize = channelCount * sampleCount
      assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
    }

    @Test
    public void createFailedWithInvalidSampleRate() throws Exception {
      IllegalArgumentException exception =
          assertThrows(
              IllegalArgumentException.class,
              () -> TensorAudio.create(TensorAudioFormat.builder().setSampleRate(0).build(), 100));
      // Sample rate 0 is not allowed
      assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
    }

    @Test
    public void createFailedWithInvalidChannels() throws Exception {
      IllegalArgumentException exception =
          assertThrows(
              IllegalArgumentException.class,
              () ->
                  TensorAudio.create(
                      TensorAudioFormat.builder().setSampleRate(1).setChannels(-1).build(), 100));
      // Negative channels is not allowed
      assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
    }

    @Test
    public void loadSucceedsFromArray() throws Exception {
      TensorAudioFormat format =
          TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
      TensorAudio tensor = TensorAudio.create(format, 2);
      assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);

      tensor.load(new float[] {2.f, 0});
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .usingTolerance(0.001f)
          .containsExactly(new float[] {0, 0, 2.f, 0});

      tensor.load(new float[] {2.f, 3.f}, 0, 2);
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .usingTolerance(0.001f)
          .containsExactly(new float[] {2.f, 0, 2.f, 3.f});

      tensor.load(new float[] {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}, 1, 6);
      // The sequence is longer than the ring buffer size so it's expected to keep only the last 4
      // numbers (index 3 to 6) of the load target sub-sequence (index 1 to 6).
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .usingTolerance(0.001f)
          .containsExactly(new float[] {5.f, 6.f, 7.f, 8.f});

      tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .usingTolerance(0.001f)
          .containsExactly(new float[] {7.f, 8.f, 1.f, -1.f});

      tensor.load(new short[] {1000, 2000, 3000, 0, 1000, Short.MIN_VALUE, 4000, 5000, 6000}, 3, 6);
      // The sequence is longer than the ring buffer size so it's expected to keep only the last 4
      // numbers.
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .usingTolerance(0.001f)
          .containsExactly(
              new float[] {
                -1.f, 4000.f / Short.MAX_VALUE, 5000.f / Short.MAX_VALUE, 6000.f / Short.MAX_VALUE
              });
    }

    @Test
    public void loadFailsWithIndexOutOfRange() throws Exception {
      TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
      TensorAudio tensor = TensorAudio.create(format, 5);

      assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));

      assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
    }

    @Test
    public void loadFailsWithIncompatibleInputSize() throws Exception {
      TensorAudioFormat format =
          TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
      TensorAudio tensor = TensorAudio.create(format, 5);

      assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));

      assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));

      assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));

      assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
    }

    @Test
    public void loadAudioRecordSucceeds() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
      tensor.load(new float[] {1, 2, 3, 4, 5});
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});

      AudioRecord record = mock(AudioRecord.class);
      when(record.getBufferSizeInFrames()).thenReturn(5);
      when(record.getChannelCount()).thenReturn(1);
      when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
      when(record.getFormat())
          .thenReturn(
              new AudioFormat.Builder()
                  .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
                  .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
                  .setSampleRate(16000)
                  .build());
      // Unused
      when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
          .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
      // Used
      when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
          .thenReturn(1);
      assertThat(tensor.load(record)).isEqualTo(1);
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});

      record = mock(AudioRecord.class);
      when(record.getBufferSizeInFrames()).thenReturn(5);
      when(record.getChannelCount()).thenReturn(1);
      when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
      when(record.getFormat())
          .thenReturn(
              new AudioFormat.Builder()
                  .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
                  .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
                  .setSampleRate(16000)
                  .build());
      // Used
      when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
          .thenReturn(2);
      // Unused
      when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
          .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
      assertThat(tensor.load(record)).isEqualTo(2);
      assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[] {5.f, 0, 0, 0});
    }

    @Test
    public void loadAudioRecordFailsWithErrorState() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
      tensor.load(new float[] {1, 2, 3, 4, 5});
      assertThat(tensor.getTensorBuffer().getFloatArray())
          .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});

      AudioRecord record = mock(AudioRecord.class);
      when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
      when(record.getFormat())
          .thenReturn(
              new AudioFormat.Builder()
                  .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
                  .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
                  .setSampleRate(16000)
                  .build());
      // Unused
      when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
          .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
      // Used
      when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
          .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
      IllegalStateException exception =
          assertThrows(IllegalStateException.class, () -> tensor.load(record));
      assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
    }

    @Test
    public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
      AudioRecord record = mock(AudioRecord.class);
      when(record.getFormat())
          .thenReturn(
              new AudioFormat.Builder()
                  .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
                  .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
                  .setSampleRate(16000)
                  .build());
      when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);

      IllegalArgumentException exception =
          assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
      assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
    }

    @Test
    public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
      TensorAudio tensor =
          TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
      AudioRecord record = mock(AudioRecord.class);
      when(record.getFormat())
          .thenReturn(
              new AudioFormat.Builder()
                  .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
                  .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
                  .setSampleRate(44100) // Mismatch
                  .build());

      IllegalArgumentException exception =
          assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
      assertThat(exception).hasMessageThat().ignoringCase().contains("Incompatible audio format");
    }
  }
}