chromium/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m

/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
#import <CoreGraphics/CoreGraphics.h>
#import <XCTest/XCTest.h>

#import "tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h"
#import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"

#define VerifyColoredLabel(coloredLabel, expectedR, expectedG, expectedB, expectedLabel) \
  XCTAssertEqual(coloredLabel.r, expectedR);                                             \
  XCTAssertEqual(coloredLabel.g, expectedG);                                             \
  XCTAssertEqual(coloredLabel.b, expectedB);                                             \
  XCTAssertEqualObjects(coloredLabel.label, expectedLabel)

// The maximum fraction of pixels in the candidate mask that can have a
// different class than the golden mask for the test to pass.
float const kGoldenMaskTolerance = 1e-2;

// Magnification factor used when creating the golden category masks to make
// them more human-friendly. Each pixel in the golden masks has its value
// multiplied by this factor, i.e. a value of 10 means class index 1, a value of
// 20 means class index 2, etc.
NSInteger const kGoldenMaskMagnificationFactor = 10;

NSInteger const deepLabV3SegmentationWidth = 257;

NSInteger const deepLabV3SegmentationHeight = 257;

@interface TFLImageSegmenterTests : XCTestCase

@property(nonatomic, nullable) NSString *modelPath;

@end

@implementation TFLImageSegmenterTests

- (void)setUp {
  // Put setup code here. This method is called before the invocation of each test method in the
  // class.
  [super setUp];
  self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3"
                                                                  ofType:@"tflite"];
  XCTAssertNotNil(self.modelPath);
}

- (void)compareWithDeepLabV3PartialColoredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels {
  VerifyColoredLabel(coloredLabels[0],
                     0,               // expectedR
                     0,               // expectedG
                     0,               // expectedB
                     @"background");  // expectedLabel

  VerifyColoredLabel(coloredLabels[1],
                     128,          // expectedR
                     0,            // expectedG
                     0,            // expectedB
                     @"aeroplane"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[2],
                     0,          // expectedR
                     128,        // expectedG
                     0,          // expectedB
                     @"bicycle"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[3],
                     128,     // expectedR
                     128,     // expectedG
                     0,       // expectedB
                     @"bird"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[4],
                     0,       // expectedR
                     0,       // expectedG
                     128,     // expectedB
                     @"boat"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[5],
                     128,       // expectedR
                     0,         // expectedG
                     128,       // expectedB
                     @"bottle"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[6],
                     0,      // expectedR
                     128,    // expectedG
                     128,    // expectedB
                     @"bus"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[7],
                     128,    // expectedR
                     128,    // expectedG
                     128,    // expectedB
                     @"car"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[8],
                     64,     // expectedR
                     0,      // expectedG
                     0,      // expectedB
                     @"cat"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[9],
                     192,      // expectedR
                     0,        // expectedG
                     0,        // expectedB
                     @"chair"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[10],
                     64,     // expectedR
                     128,    // expectedG
                     0,      // expectedB
                     @"cow"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[11],
                     192,             // expectedR
                     128,             // expectedG
                     0,               // expectedB
                     @"dining table"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[12],
                     64,     // expectedR
                     0,      // expectedG
                     128,    // expectedB
                     @"dog"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[13],
                     192,      // expectedR
                     0,        // expectedG
                     128,      // expectedB
                     @"horse"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[14],
                     64,           // expectedR
                     128,          // expectedG
                     128,          // expectedB
                     @"motorbike"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[15],
                     192,       // expectedR
                     128,       // expectedG
                     128,       // expectedB
                     @"person"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[16],
                     0,               // expectedR
                     64,              // expectedG
                     0,               // expectedB
                     @"potted plant"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[17],
                     128,      // expectedR
                     64,       // expectedG
                     0,        // expectedB
                     @"sheep"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[18],
                     0,       // expectedR
                     192,     // expectedG
                     0,       // expectedB
                     @"sofa"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[19],
                     128,      // expectedR
                     192,      // expectedG
                     0,        // expectedB
                     @"train"  // expectedLabel
  );

  VerifyColoredLabel(coloredLabels[20],
                     0,     // expectedR
                     64,    // expectedG
                     128,   // expectedB
                     @"tv"  // expectedLabel
  );
}

- (void)testSuccessfulImageSegmentationWithCategoryMask {
  TFLImageSegmenterOptions *imageSegmenterOptions =
      [[TFLImageSegmenterOptions alloc] initWithModelPath:self.modelPath];

  TFLImageSegmenter *imageSegmenter =
      [TFLImageSegmenter imageSegmenterWithOptions:imageSegmenterOptions error:nil];
  XCTAssertNotNil(imageSegmenter);

  GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:self.class
                                                 fileName:@"segmentation_input_rotation0"
                                                   ofType:@"jpg"];
  XCTAssertNotNil(gmlImage);

  TFLSegmentationResult *segmentationResult = [imageSegmenter segmentWithGMLImage:gmlImage
                                                                            error:nil];

  XCTAssertNotNil(segmentationResult);
  XCTAssertEqual(segmentationResult.segmentations.count, 1);

  XCTAssertNotNil(segmentationResult.segmentations[0].coloredLabels);
  [self compareWithDeepLabV3PartialColoredLabels:segmentationResult.segmentations[0].coloredLabels];

  XCTAssertNotNil(segmentationResult.segmentations[0].categoryMask);
  XCTAssertTrue(segmentationResult.segmentations[0].categoryMask.mask != nil);

  GMLImage *goldenImage = [GMLImage imageFromBundleWithClass:self.class
                                                    fileName:@"segmentation_golden_rotation0"
                                                      ofType:@"png"];

  XCTAssertNotNil(goldenImage);
  CVPixelBufferRef pixelBuffer = [goldenImage grayScalePixelBuffer];

  CVPixelBufferLockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly);

  UInt8 *pixelBufferBaseAddress = (UInt8 *)CVPixelBufferGetBaseAddress(pixelBuffer);

  XCTAssertEqual(deepLabV3SegmentationWidth,
                 segmentationResult.segmentations[0].categoryMask.width);
  XCTAssertEqual(deepLabV3SegmentationHeight,
                 segmentationResult.segmentations[0].categoryMask.height);

  NSInteger numPixels = deepLabV3SegmentationWidth * deepLabV3SegmentationHeight;

  float inconsistentPixels = 0;

  for (int i = 0; i < numPixels; i++)
    if (segmentationResult.segmentations[0].categoryMask.mask[i] * kGoldenMaskMagnificationFactor !=
        pixelBufferBaseAddress[i])
      inconsistentPixels += 1;

  CVPixelBufferUnlockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly);

  XCTAssertLessThan(inconsistentPixels / (float)numPixels, kGoldenMaskTolerance);
}

@end