// chromium/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/sources/TFLAudioClassifier.m

/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
#import "tensorflow_lite_support/ios/task/audio/sources/TFLAudioClassifier.h"
#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h"
#import "tensorflow_lite_support/ios/task/audio/core/audio_tensor/utils/sources/TFLAudioTensor+Utils.h"
#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h"

#include "tensorflow_lite_support/c/task/audio/audio_classifier.h"

@interface TFLAudioClassifier () {
  // Input buffer size the underlying C classifier expects, cached at init
  // time via TfLiteAudioClassifierGetRequiredInputBufferSize(). Used to size
  // the audio tensor and audio record created by the factory methods below.
  NSInteger _requiredCBufferSize;
  // Audio format (channel count, sample rate) required by the C classifier,
  // cached at init time. Owned by this instance; released in -dealloc via
  // TfLiteAudioFormatDelete().
  TfLiteAudioFormat *_requiredCAudioFormat;
}
/** Audio Classifier backed by C API */
@property(nonatomic) TfLiteAudioClassifier *audioClassifier;
@end

@implementation TFLAudioClassifierOptions

@synthesize baseOptions;
@synthesize classificationOptions;

/**
 * Initializes the options with default base options and default
 * classification options, so callers only need to set the fields they care
 * about.
 */
- (instancetype)init {
  self = [super init];
  if (self == nil) {
    return nil;
  }
  self.baseOptions = [[TFLBaseOptions alloc] init];
  self.classificationOptions = [[TFLClassificationOptions alloc] init];
  return self;
}

/**
 * Convenience initializer that additionally records the path of the model
 * file in the base options.
 */
- (instancetype)initWithModelPath:(NSString *)modelPath {
  if ((self = [self init]) != nil) {
    self.baseOptions.modelFile.filePath = modelPath;
  }
  return self;
}

@end

@implementation TFLAudioClassifier

- (void)dealloc {
  // Release the cached C audio format and the C classifier. If init failed
  // partway through, one or both may be NULL; the C delete functions are
  // expected to tolerate NULL like free() — TODO confirm against the C API.
  TfLiteAudioFormatDelete(_requiredCAudioFormat);
  TfLiteAudioClassifierDelete(_audioClassifier);
}

/**
 * Queries the C layer for the input buffer size the model requires and caches
 * it in `_requiredCBufferSize`.
 *
 * @param error Populated when the C layer reports an error.
 * @return YES if a positive buffer size was obtained, NO otherwise.
 */
- (BOOL)populateRequiredBufferSizeWithError:(NSError **)error {
  TfLiteSupportError *requiredBufferSizeError = NULL;
  _requiredCBufferSize =
      TfLiteAudioClassifierGetRequiredInputBufferSize(_audioClassifier, &requiredBufferSizeError);

  // Populate iOS error if C Error is not null and afterwards delete it.
  if (![TFLCommonUtils checkCError:requiredBufferSizeError toError:error]) {
    TfLiteSupportErrorDelete(requiredBufferSizeError);
  }

  return _requiredCBufferSize > 0;
}

/**
 * Queries the C layer for the audio format the model requires and caches it
 * in `_requiredCAudioFormat`. The returned format is owned by this instance
 * and released in -dealloc.
 *
 * @param error Populated when the C layer reports an error.
 * @return YES if a non-NULL format was obtained, NO otherwise.
 */
- (BOOL)populateRequiredAudioFormatWithError:(NSError **)error {
  // NULL (not nil) for a plain C pointer, consistent with the other C-error
  // declarations in this file.
  TfLiteSupportError *getAudioFormatError = NULL;
  _requiredCAudioFormat =
      TfLiteAudioClassifierGetRequiredAudioFormat(_audioClassifier, &getAudioFormatError);

  // Populate iOS error if C Error is not null and afterwards delete it.
  if (![TFLCommonUtils checkCError:getAudioFormatError toError:error]) {
    TfLiteSupportErrorDelete(getAudioFormatError);
  }

  return _requiredCAudioFormat != NULL;
}

/**
 * Designated initializer wrapping an already-created C classifier. Takes
 * ownership of `audioClassifier`: on failure, returning nil under ARC
 * releases self, so -dealloc still deletes the C classifier.
 *
 * @param audioClassifier A valid C classifier handle; ownership transfers to
 *     the new instance.
 * @param error Populated if the required buffer size or audio format cannot
 *     be fetched from the C layer.
 * @return The initialized classifier, or nil on failure.
 */
- (nullable instancetype)initWithAudioClassifier:(TfLiteAudioClassifier *)audioClassifier
                                           error:(NSError **)error {
  self = [super init];
  if (self) {
    _audioClassifier = audioClassifier;
    if (![self populateRequiredBufferSizeWithError:error]) {
      return nil;
    }

    if (![self populateRequiredAudioFormatWithError:error]) {
      return nil;
    }
  }
  return self;
}

/**
 * Creates a classifier from the given options by converting them to their C
 * equivalents and invoking the C creation function.
 *
 * Memory notes: the memory allocated while copying `classificationOptions`
 * into the C struct is always freed before returning (on both the failure and
 * success paths), since the C layer copies what it needs during creation.
 *
 * @param options The classifier options; must not be nil.
 * @param error Populated on invalid arguments or C-layer failure.
 * @return The created classifier, or nil on failure.
 */
+ (nullable instancetype)audioClassifierWithOptions:(TFLAudioClassifierOptions *)options
                                              error:(NSError **)error {
  if (!options) {
    [TFLCommonUtils createCustomError:error
                             withCode:TFLSupportErrorCodeInvalidArgumentError
                          description:@"TFLAudioClassifierOptions argument cannot be nil."];
    return nil;
  }

  TfLiteAudioClassifierOptions cOptions = TfLiteAudioClassifierOptionsCreate();

  if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options)
                                               error:error]) {
    // Free any partially-copied allocations before bailing out.
    [options.classificationOptions
        deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];
    return nil;
  }

  [options.baseOptions copyToCOptions:&(cOptions.base_options)];

  TfLiteSupportError *cCreateClassifierError = NULL;
  TfLiteAudioClassifier *cAudioClassifier =
      TfLiteAudioClassifierFromOptions(&cOptions, &cCreateClassifierError);

  [options.classificationOptions
      deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)];

  // Populate iOS error if TfliteSupportError is not null and afterwards delete it.
  if (![TFLCommonUtils checkCError:cCreateClassifierError toError:error]) {
    TfLiteSupportErrorDelete(cCreateClassifierError);
  }

  // Return nil if classifier evaluates to nil. If an error was generated by the C layer, it has
  // already been populated to an NSError and deleted before returning from the method.
  if (!cAudioClassifier) {
    return nil;
  }

  return [[TFLAudioClassifier alloc] initWithAudioClassifier:cAudioClassifier error:error];
}

/**
 * Runs classification on the given audio tensor.
 *
 * @param audioTensor The tensor holding the audio samples to classify; must
 *     not be nil.
 * @param error Populated on invalid arguments or C-layer failure.
 * @return The classification result, or nil on failure.
 */
- (nullable TFLClassificationResult *)classifyWithAudioTensor:(TFLAudioTensor *)audioTensor
                                                        error:(NSError **)error {
  if (!audioTensor) {
    [TFLCommonUtils createCustomError:error
                             withCode:TFLSupportErrorCodeInvalidArgumentError
                          description:@"audioTensor argument cannot be nil."];
    return nil;
  }

  // Wrap the tensor's float buffer in the C audio-buffer struct expected by
  // the C classify entry point.
  TFLFloatBuffer *audioTensorBuffer = audioTensor.buffer;
  TfLiteAudioBuffer cAudioBuffer = [audioTensor cAudioBufferFromFloatBuffer:audioTensorBuffer];

  TfLiteSupportError *classifyError = NULL;
  TfLiteClassificationResult *cClassificationResult =
      TfLiteAudioClassifierClassify(_audioClassifier, &cAudioBuffer, &classifyError);

  // Populate iOS error if C Error is not null and afterwards delete it.
  if (![TFLCommonUtils checkCError:classifyError toError:error]) {
    TfLiteSupportErrorDelete(classifyError);
  }

  // Return nil if C result evaluates to nil. If an error was generated by the C layer, it has
  // already been populated to an NSError and deleted before returning from the method.
  if (!cClassificationResult) {
    return nil;
  }

  // Convert to the ObjC result type, then free the C result we now own.
  TFLClassificationResult *classificationResult =
      [TFLClassificationResult classificationResultWithCResult:cClassificationResult];
  TfLiteClassificationResultDelete(cClassificationResult);

  return classificationResult;
}

/**
 * Creates an audio tensor matching the model's required format, sized to the
 * required input buffer (total elements divided by channel count gives the
 * per-channel sample count).
 */
- (TFLAudioTensor *)createInputAudioTensor {
  TFLAudioFormat *format =
      [[TFLAudioFormat alloc] initWithChannelCount:_requiredCAudioFormat->channels
                                        sampleRate:_requiredCAudioFormat->sample_rate];
  return [[TFLAudioTensor alloc] initWithAudioFormat:format
                                         sampleCount:_requiredCBufferSize / format.channelCount];
}

/**
 * Creates an audio record matching the model's required format.
 *
 * @param error Populated if the audio record cannot be created.
 * @return The audio record, or nil on failure.
 */
- (TFLAudioRecord *)createAudioRecordWithError:(NSError **)error {
  // The sample count of audio record should be strictly longer than audio tensor's so that
  // clients could run TFLAudioRecord's `startRecordingWithError:`
  // together with TFLAudioClassifier's `classifyWithAudioTensor:error:`
  TFLAudioFormat *format =
      [[TFLAudioFormat alloc] initWithChannelCount:_requiredCAudioFormat->channels
                                        sampleRate:_requiredCAudioFormat->sample_rate];

  return [[TFLAudioRecord alloc] initWithAudioFormat:format
                                          bufferSize:_requiredCBufferSize * 2
                                               error:error];
}

@end