chromium/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/audio_classifier/TFLAudioClassifierTests.m

/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
#import <XCTest/XCTest.h>

#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
#import "tensorflow_lite_support/ios/task/audio/sources/TFLAudioClassifier.h"
#import "tensorflow_lite_support/ios/test/task/audio/core/audio_record/utils/sources/AVAudioPCMBuffer+Utils.h"

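// Verifies an NSError against the expected domain, code, and localized description. The
// description is checked with a substring match, since the full message may carry extra detail.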
#define VerifyError(error, expectedDomain, expectedCode, expectedLocalizedDescription)  \
  XCTAssertNotNil(error);                                                               \
  XCTAssertEqualObjects(error.domain, expectedDomain);                                  \
  XCTAssertEqual(error.code, expectedCode);                                             \
  XCTAssertNotEqual(                                                                    \
      [error.localizedDescription rangeOfString:expectedLocalizedDescription].location, \
      NSNotFound)

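// Verifies a single TFLCategory. The score is compared with a small absolute tolerance because
// the model emits floating point scores.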
#define VerifyCategory(category, expectedIndex, expectedScore, expectedLabel, expectedDisplayName) \
  XCTAssertEqual(category.index, expectedIndex);                                                   \
  XCTAssertEqualWithAccuracy(category.score, expectedScore, 1e-6);                                 \
  XCTAssertEqualObjects(category.label, expectedLabel);                                            \
  XCTAssertEqualObjects(category.displayName, expectedDisplayName)

#define VerifyClassifications(classifications, expectedHeadIndex, expectedCategoryCount) \
  XCTAssertEqual(classifications.categories.count, expectedCategoryCount);               \
  XCTAssertEqual(classifications.headIndex, expectedHeadIndex)

#define VerifyClassificationResult(classificationResult, expectedClassificationsCount) \
  XCTAssertNotNil(classificationResult);                                               \
  XCTAssertEqual(classificationResult.classifications.count, expectedClassificationsCount)

static NSString *const expectedTaskErrorDomain = @"org.tensorflow.lite.tasks";

NS_ASSUME_NONNULL_BEGIN

@interface TFLAudioClassifierTests : XCTestCase
@property(nonatomic, nullable) NSString *modelPath;
@property(nonatomic) AVAudioFormat *audioEngineFormat;
@end

// This category on TFLAudioRecord is private to the test files. It exposes the method that loads
// the audio record's buffer so that tests can avoid calling -[TFLAudioRecord
// startRecordingWithError:], which requires real microphone input. The method stays out of the
// public headers because it isn't useful to consumers of the framework.
@interface TFLAudioRecord (Tests)
- (void)convertAndLoadBuffer:(AVAudioPCMBuffer *)buffer
         usingAudioConverter:(AVAudioConverter *)audioConverter;
@end

@implementation TFLAudioClassifierTests

- (void)setUp {
  // Put setup code here. This method is called before the invocation of each test method in the
  // class.
  [super setUp];
  self.modelPath =
      [[NSBundle bundleForClass:self.class] pathForResource:@"yamnet_audio_classifier_with_metadata"
                                                     ofType:@"tflite"];
  XCTAssertNotNil(self.modelPath);

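  // 48 kHz mono Float32 (non-interleaved) mirrors the format typically vended by AVAudioEngine's
  // input node; -mockLoadBufferOfAudioRecord: uses it to mock mic input.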
  self.audioEngineFormat = [[AVAudioFormat alloc] initWithCommonFormat:AVAudioPCMFormatFloat32
                                                            sampleRate:48000
                                                              channels:1
                                                           interleaved:NO];
}

- (nullable AVAudioPCMBuffer *)bufferFromFileWithName:(NSString *)name
                                            extension:(NSString *)extension
                                          audioFormat:(TFLAudioFormat *)audioFormat {
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:name ofType:extension];
  return [AVAudioPCMBuffer loadPCMBufferFromFileWithPath:filePath audioFormat:audioFormat];
}

- (nullable AVAudioPCMBuffer *)bufferFromFileWithName:(NSString *)name
                                            extension:(NSString *)extension
                                     processingFormat:(AVAudioFormat *)processingFormat {
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:name ofType:extension];
  return [AVAudioPCMBuffer loadPCMBufferFromFileWithPath:filePath
                                        processingFormat:processingFormat];
}

- (void)mockLoadBufferOfAudioRecord:(TFLAudioRecord *)audioRecord {
  // Loading an AVAudioPCMBuffer from an array is not currently supported on iOS versions < 15.0.
  // Instead, audio samples from a wav file are loaded and converted into the same format as
  // AVAudioEngine's input node to mock input from the audio engine.
  AVAudioPCMBuffer *audioEngineBuffer = [self bufferFromFileWithName:@"speech"
                                                           extension:@"wav"
                                                    processingFormat:self.audioEngineFormat];
  XCTAssertNotNil(audioEngineBuffer);

  // Convert the buffer from the audio engine's input format to the format in which audio record
  // outputs its samples. This mocks the internal conversion audio record performs when
  // -[TFLAudioRecord startRecordingWithError:] is called.
  AVAudioFormat *recordingFormat = [[AVAudioFormat alloc]
      initWithCommonFormat:AVAudioPCMFormatFloat32
                sampleRate:audioRecord.audioFormat.sampleRate
                  channels:(AVAudioChannelCount)audioRecord.audioFormat.channelCount
               interleaved:YES];

  AVAudioConverter *audioConverter = [[AVAudioConverter alloc] initFromFormat:self.audioEngineFormat
                                                                     toFormat:recordingFormat];
  // Convert and load the buffer of `TFLAudioRecord`.
  [audioRecord convertAndLoadBuffer:audioEngineBuffer usingAudioConverter:audioConverter];
}
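
// For reference, a production client would drive the same pipeline roughly as follows. This is
// only a sketch assembled from the APIs exercised in this file (error handling elided); the tests
// avoid -[TFLAudioRecord startRecordingWithError:] because it requires real microphone input:
//
//   TFLAudioRecord *audioRecord = [audioClassifier createAudioRecordWithError:&error];
//   [audioRecord startRecordingWithError:&error];
//   TFLAudioTensor *audioTensor = [audioClassifier createInputAudioTensor];
//   [audioTensor loadAudioRecord:audioRecord withError:&error];
//   TFLClassificationResult *result = [audioClassifier classifyWithAudioTensor:audioTensor
//                                                                        error:&error];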

- (TFLAudioClassifier *)createAudioClassifierWithModelPath:(NSString *)modelPath {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:modelPath];
  TFLAudioClassifier *audioClassifier = [TFLAudioClassifier audioClassifierWithOptions:options
                                                                                 error:nil];
  XCTAssertNotNil(audioClassifier);

  return audioClassifier;
}

- (TFLAudioClassifier *)createAudioClassifierWithOptions:(TFLAudioClassifierOptions *)options {
  TFLAudioClassifier *audioClassifier = [TFLAudioClassifier audioClassifierWithOptions:options
                                                                                 error:nil];
  XCTAssertNotNil(audioClassifier);

  return audioClassifier;
}

- (TFLAudioTensor *)createAudioTensorWithAudioClassifier:(TFLAudioClassifier *)audioClassifier {
  // Create the audio tensor using audio classifier.
  TFLAudioTensor *audioTensor = [audioClassifier createInputAudioTensor];
  XCTAssertNotNil(audioTensor);

  return audioTensor;
}

- (void)loadAudioTensor:(TFLAudioTensor *)audioTensor fromWavFileWithName:(NSString *)fileName {
  // Load pcm buffer from file.
  AVAudioPCMBuffer *buffer = [self bufferFromFileWithName:fileName
                                                extension:@"wav"
                                              audioFormat:audioTensor.audioFormat];

  // Get float buffer from pcm buffer.
  TFLFloatBuffer *floatBuffer = buffer.floatBuffer;
  XCTAssertNotNil(floatBuffer);

  // Load float buffer into the audio tensor.
  [audioTensor loadBuffer:floatBuffer offset:0 size:floatBuffer.size error:nil];
}

- (TFLClassificationResult *)classifyWithAudioClassifier:(TFLAudioClassifier *)audioClassifier
                                             audioTensor:(TFLAudioTensor *)audioTensor
                                   expectedCategoryCount:(const NSInteger)expectedCategoryCount {
  TFLClassificationResult *classificationResult =
      [audioClassifier classifyWithAudioTensor:audioTensor error:nil];

  const NSInteger expectedClassificationsCount = 1;
  VerifyClassificationResult(classificationResult, expectedClassificationsCount);

  const NSInteger expectedHeadIndex = 0;
  VerifyClassifications(classificationResult.classifications[0], expectedHeadIndex,
                        expectedCategoryCount);

  return classificationResult;
}

- (void)validateClassificationResultForInferenceWithFloatBuffer:
    (NSArray<TFLCategory *> *)categories {
  VerifyCategory(categories[0],
                 0,          // expectedIndex
                 0.957031,   // expectedScore
                 @"Speech",  // expectedLabel
                 nil         // expectedDisplayName
  );
  VerifyCategory(categories[1],
                 500,                    // expectedIndex
                 0.019531,               // expectedScore
                 @"Inside, small room",  // expectedLabel
                 nil                     // expectedDisplayName
  );
}

- (void)validateCategoriesForInferenceWithAudioRecord:(NSArray<TFLCategory *> *)categories {
  // The categories verified here may differ slightly from those verified in
  // -[TFLAudioClassifierTests validateClassificationResultForInferenceWithFloatBuffer:]. This is
  // because inference with audio record involves more native internal conversions (to mock the
  // conversions performed by audio record) than inference with a float buffer does. Since each
  // native conversion by `AVAudioConverter` employs strategies to pick samples based on the
  // specified format, the samples passed in for inference in the float buffer and audio record
  // cases will be slightly different.
  VerifyCategory(categories[0],
                 0,          // expectedIndex
                 0.957031,   // expectedScore
                 @"Speech",  // expectedLabel
                 nil         // expectedDisplayName
  );
  VerifyCategory(categories[1],
                 500,                    // expectedIndex
                 0.019531,               // expectedScore
                 @"Inside, small room",  // expectedLabel
                 nil                     // expectedDisplayName
  );
}

- (void)validateCategoriesForInferenceWithLabelDenyList:(NSArray<TFLCategory *> *)categories {
  VerifyCategory(categories[0],
                 500,                    // expectedIndex
                 0.019531,               // expectedScore
                 @"Inside, small room",  // expectedLabel
                 nil                     // expectedDisplayName
  );
}

- (TFLAudioRecord *)createAudioRecordWithAudioClassifier:(TFLAudioClassifier *)audioClassifier {
  // Create the audio record using the audio classifier.
  TFLAudioRecord *audioRecord = [audioClassifier createAudioRecordWithError:nil];
  XCTAssertNotNil(audioRecord);

  return audioRecord;
}

- (void)testInferenceWithFloatBufferSucceeds {
  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithModelPath:self.modelPath];

  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

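  // The YAMNet model scores 521 audio event classes, so an unfiltered result should contain 521
  // categories.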
  const NSInteger expectedCategoryCount = 521;
  TFLClassificationResult *classificationResult =
      [self classifyWithAudioClassifier:audioClassifier
                            audioTensor:audioTensor
                  expectedCategoryCount:expectedCategoryCount];

  [self validateClassificationResultForInferenceWithFloatBuffer:classificationResult
                                                                    .classifications[0]
                                                                    .categories];
}

- (void)testInferenceWithAudioRecordSucceeds {
  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithModelPath:self.modelPath];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];

  TFLAudioRecord *audioRecord = [self createAudioRecordWithAudioClassifier:audioClassifier];
  // Mocks the loading of the internal buffer of audio record when new samples are input from the
  // mic.
  [self mockLoadBufferOfAudioRecord:audioRecord];
  // Load the audioRecord buffer into the audio tensor.
  [audioTensor loadAudioRecord:audioRecord withError:nil];

  const NSInteger expectedCategoryCount = 521;
  TFLClassificationResult *classificationResult =
      [self classifyWithAudioClassifier:audioClassifier
                            audioTensor:audioTensor
                  expectedCategoryCount:expectedCategoryCount];

  [self validateCategoriesForInferenceWithAudioRecord:classificationResult.classifications[0]
                                                          .categories];
}

- (void)testInferenceWithMaxResultsSucceeds {
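  // `maxResults` caps the returned categories to the top-scoring results; only the top 3
  // categories are expected in the result.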
  const NSInteger maxResults = 3;
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.maxResults = maxResults;

  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithOptions:options];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

  TFLClassificationResult *classificationResult = [self classifyWithAudioClassifier:audioClassifier
                                                                        audioTensor:audioTensor
                                                              expectedCategoryCount:maxResults];

  [self validateClassificationResultForInferenceWithFloatBuffer:classificationResult
                                                                    .classifications[0]
                                                                    .categories];
}

- (void)testInferenceWithClassNameAllowListAndDenyListFails {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.labelAllowList = @[ @"Speech" ];
  options.classificationOptions.labelDenyList = @[ @"Inside, small room" ];

  NSError *error = nil;
  TFLAudioClassifier *audioClassifier = [TFLAudioClassifier audioClassifierWithOptions:options
                                                                                 error:&error];
  XCTAssertNil(audioClassifier);
  VerifyError(error,
              expectedTaskErrorDomain,                  // expectedErrorDomain
              TFLSupportErrorCodeInvalidArgumentError,  // expectedErrorCode
              @"INVALID_ARGUMENT: `class_name_allowlist` and `class_name_denylist` are mutually "
              @"exclusive options."  // expectedErrorMessage
  );
}

- (void)testInferenceWithLabelAllowListSucceeds {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.labelAllowList = @[ @"Speech", @"Inside, small room" ];

  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithOptions:options];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

  TFLClassificationResult *classificationResult =
      [self classifyWithAudioClassifier:audioClassifier
                            audioTensor:audioTensor
                  expectedCategoryCount:options.classificationOptions.labelAllowList.count];

  [self validateClassificationResultForInferenceWithFloatBuffer:classificationResult
                                                                    .classifications[0]
                                                                    .categories];
}

- (void)testInferenceWithLabelDenyListSucceeds {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.labelDenyList = @[ @"Speech" ];

  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithOptions:options];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

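  // Deny-listing the single "Speech" label leaves 520 of the model's 521 classes.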
  const NSInteger expectedCategoryCount = 520;
  TFLClassificationResult *classificationResult =
      [self classifyWithAudioClassifier:audioClassifier
                            audioTensor:audioTensor
                  expectedCategoryCount:expectedCategoryCount];

  [self validateCategoriesForInferenceWithLabelDenyList:classificationResult.classifications[0]
                                                            .categories];
}

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnonnull"
- (void)testCreateAudioClassifierWithNilOptionsFails {
  NSError *error = nil;
  TFLAudioClassifier *audioClassifier = [TFLAudioClassifier audioClassifierWithOptions:nil
                                                                                 error:&error];

  XCTAssertNil(audioClassifier);
  VerifyError(error,
              expectedTaskErrorDomain,                              // expectedErrorDomain
              TFLSupportErrorCodeInvalidArgumentError,              // expectedErrorCode
              @"TFLAudioClassifierOptions argument cannot be nil."  // expectedErrorMessage
  );
}

- (void)testInferenceWithNilAudioTensorFails {
  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithModelPath:self.modelPath];

  NSError *error = nil;
  TFLClassificationResult *classificationResult = [audioClassifier classifyWithAudioTensor:nil
                                                                                     error:&error];

  XCTAssertNil(classificationResult);
  VerifyError(error,
              expectedTaskErrorDomain,                  // expectedErrorDomain
              TFLSupportErrorCodeInvalidArgumentError,  // expectedErrorCode
              @"audioTensor argument cannot be nil."    // expectedErrorMessage
  );
}
#pragma clang diagnostic pop

@end

NS_ASSUME_NONNULL_END