chromium/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/tensor_audio_test.py

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensor_audio."""
import unittest
import unittest.mock

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_lite_support.python.task.audio.core import audio_record
from tensorflow_lite_support.python.task.audio.core import tensor_audio
from tensorflow_lite_support.python.task.audio.core.pybinds import _pywrap_audio_buffer
from tensorflow_lite_support.python.test import test_util

# Alias of the pybind-wrapped C++ AudioFormat type used to build the
# TensorAudio instances under test.
_CppAudioFormat = _pywrap_audio_buffer.AudioFormat
# `unittest.mock` is a submodule that a bare `import unittest` does not load by
# itself; it is imported explicitly at the top of the file so this attribute
# access does not depend on another package having imported it first.
_mock = unittest.mock

# Audio format and buffer length shared by every test: mono audio sampled at
# 16 kHz with a 15600-sample buffer (0.975 s).
_CHANNELS = 1
_SAMPLE_RATE = 16000
_BUFFER_SIZE = 15600


class TensorAudioTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for loading audio into `tensor_audio.TensorAudio`."""

  def setUp(self):
    super().setUp()
    self.test_audio_path = test_util.get_test_data_path("speech.wav")
    # A fresh TensorAudio is built for every test so data loaded by one test
    # never leaks into another.
    self.test_tensor_audio = tensor_audio.TensorAudio(
        _CppAudioFormat(_CHANNELS, _SAMPLE_RATE), _BUFFER_SIZE)

  def _assert_default_format(self, tensor):
    """Asserts `tensor` uses the module's test format and buffer size.

    Args:
      tensor: The `TensorAudio` instance under test.
    """
    self.assertEqual(tensor.format.channels, _CHANNELS)
    self.assertEqual(tensor.format.sample_rate, _SAMPLE_RATE)
    self.assertEqual(tensor.buffer_size, _BUFFER_SIZE)
    self.assertIsInstance(tensor.buffer, np.ndarray)

  def test_create_from_wav_file_succeeds(self):
    # Loads TensorAudio object from WAV file.
    tensor = tensor_audio.TensorAudio.create_from_wav_file(
        self.test_audio_path, _BUFFER_SIZE)

    self._assert_default_format(tensor)
    # Reference value of the last buffer entry loaded from speech.wav.
    self.assertAlmostEqual(tensor.buffer[-1], -0.09640503)

  def test_create_from_wav_file_fails_with_empty_file_path(self):
    # Fails loading TensorAudio object from WAV file.
    with self.assertRaisesRegex(
        ValueError,
        "Data too short when trying to read string"):
      tensor_audio.TensorAudio.create_from_wav_file("", _BUFFER_SIZE)

  def test_load_from_array_succeeds_with_input_size_matches_buffer_size(self):
    # Loads audio data from a NumPy array.
    array = np.random.rand(_BUFFER_SIZE, _CHANNELS).astype(np.float32)
    self.test_tensor_audio.load_from_array(array)

    self._assert_default_format(self.test_tensor_audio)
    self.assertAllClose(self.test_tensor_audio.buffer, array)

  def test_load_from_array_succeeds_with_larger_input_size_and_default_params(
      self):
    # Loads audio data from a NumPy array twice as long as the buffer.
    array = np.random.rand(_BUFFER_SIZE * 2, _CHANNELS).astype(np.float32)
    self.test_tensor_audio.load_from_array(array)

    self._assert_default_format(self.test_tensor_audio)
    # Without offset/size, only the last `_BUFFER_SIZE` samples are kept.
    self.assertAllClose(self.test_tensor_audio.buffer, array[_BUFFER_SIZE:])

  @parameterized.parameters(
      (0, _BUFFER_SIZE), (_BUFFER_SIZE // 2, _BUFFER_SIZE))
  def test_load_from_array_succeeds_with_larger_input_size_and_params_specified(
      self, offset, size):
    # Loads a `size`-sample window starting at `offset` from a NumPy array.
    array = np.random.rand(_BUFFER_SIZE * 2, _CHANNELS).astype(np.float32)
    self.test_tensor_audio.load_from_array(array, offset, size)

    self._assert_default_format(self.test_tensor_audio)
    self.assertAllClose(self.test_tensor_audio.buffer,
                        array[offset:offset + size])

  def test_load_from_array_succeeds_with_smaller_input_size_and_default_params(
      self):
    # Loads audio data from a NumPy array shorter than the buffer.
    input_length = _BUFFER_SIZE // 3
    array = np.random.rand(input_length, _CHANNELS).astype(np.float32)
    self.test_tensor_audio.load_from_array(array)

    self._assert_default_format(self.test_tensor_audio)
    # The shorter input lands at the tail end of the buffer.
    self.assertAllClose(self.test_tensor_audio.buffer[-input_length:], array)

  @parameterized.parameters((0, 4000), (3900, 100))
  def test_load_from_array_succeeds_with_smaller_input_size_and_params_specified(
      self, offset, size):
    # Loads a `size`-sample window of a short array starting at `offset`.
    array = np.random.rand(_BUFFER_SIZE // 3, _CHANNELS).astype(np.float32)
    self.test_tensor_audio.load_from_array(array, offset, size)

    self._assert_default_format(self.test_tensor_audio)
    # The loaded window lands at the tail end of the buffer.
    self.assertAllClose(self.test_tensor_audio.buffer[-size:],
                        array[offset:offset + size])

  @parameterized.parameters((_BUFFER_SIZE // 2, _BUFFER_SIZE), (0, 20000))
  def test_load_from_array_fails_with_invalid_offset_size(self, offset, size):
    # Fails loading audio data from a NumPy array when offset + size exceeds
    # the array's length.
    array = np.random.rand(_BUFFER_SIZE, _CHANNELS).astype(np.float32)
    with self.assertRaisesRegex(
        ValueError,
        rf"Index out of range. offset {offset} \+ size {size} should be "
        f"<= src's length: {len(array)}"):
      self.test_tensor_audio.load_from_array(array, offset, size)

  def test_load_from_array_fails_with_invalid_number_of_channels(self):
    # Fails loading audio data from a NumPy array with an invalid
    # number of input channels. The array is built outside the context
    # manager so only `load_from_array` is expected to raise.
    array = np.random.rand(_BUFFER_SIZE, 2).astype(np.float32)
    with self.assertRaisesRegex(
        ValueError, f"Input audio contains an invalid number of channels. "
        f"Expect {_CHANNELS}."):
      self.test_tensor_audio.load_from_array(array)

  @_mock.patch("sounddevice.InputStream", return_value=_mock.MagicMock())
  def test_load_from_audio_record(self, mock_input_stream):
    record = audio_record.AudioRecord(_CHANNELS, _SAMPLE_RATE, _BUFFER_SIZE)

    # Get AudioRecord's audio callback function.
    _, mock_input_stream_init_args = mock_input_stream.call_args
    callback_fn = mock_input_stream_init_args["callback"]

    # Feed three half-buffer chunks; only the two most recent ones are
    # expected to end up in the TensorAudio buffer.
    chunk_size = _BUFFER_SIZE // 2
    input_data = []
    for _ in range(3):
      dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float)
      input_data.append(dummy_data)
      callback_fn(dummy_data)
    expected_data = np.concatenate(input_data[-2:])

    # Load audio data into TensorAudio from the AudioRecord instance.
    self.test_tensor_audio.load_from_audio_record(record)

    # Assert read all data in the float buffer.
    self.assertAllClose(self.test_tensor_audio.buffer, expected_data)

  @_mock.patch("sounddevice.InputStream", return_value=_mock.MagicMock())
  def test_load_from_audio_record_fails_with_invalid_buffer_size(self, _):
    # Fails loading audio data from an AudioRecord instance having
    # a buffer size less than that of TensorAudio.
    with self.assertRaisesRegex(
        ValueError,
        "The audio record's buffer size cannot be smaller than the tensor "
        "audio's sample count."):
      record = audio_record.AudioRecord(_CHANNELS, _SAMPLE_RATE, 10000)
      self.test_tensor_audio.load_from_audio_record(record)

  @_mock.patch("sounddevice.InputStream", return_value=_mock.MagicMock())
  def test_load_from_audio_record_fails_with_invalid_number_of_channels(
      self, _):
    # Fails loading audio data from an AudioRecord instance having
    # an invalid number of channels.
    with self.assertRaisesRegex(
        ValueError, f"The audio record's channel count doesn't match. "
        rf"Expects {_CHANNELS} channel\(s\)."):
      record = audio_record.AudioRecord(2, _SAMPLE_RATE, _BUFFER_SIZE)
      self.test_tensor_audio.load_from_audio_record(record)

  @_mock.patch("sounddevice.InputStream", return_value=_mock.MagicMock())
  def test_load_from_audio_record_fails_with_invalid_sample_rate(self, _):
    # Fails loading audio data from an AudioRecord instance having
    # an invalid sample rate.
    with self.assertRaisesRegex(
        ValueError, f"The audio record's sampling rate doesn't match. "
        f"Expects {_SAMPLE_RATE}Hz."):
      record = audio_record.AudioRecord(_CHANNELS, 20000, _BUFFER_SIZE)
      self.test_tensor_audio.load_from_audio_record(record)


if __name__ == "__main__":
  # Run the test cases in this module with TensorFlow's test runner.
  tf.test.main()