llvm/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

//===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "llvm/Support/Compiler.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <memory>

usingnamespacemlir;
usingnamespacemlir::sparse_tensor;

namespace // namespace

/// Vector multiplication (conjunction) of 3 vectors, i.e.:
///   a(i) = b(i) * c(i) * d(i)
/// which should form the single lattice point
/// {
///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor_2) )
/// }
/// After optimization, the dense dimension should be kept, even though it
/// appears in the middle:
/// {
///   lat( i_01_D / (tensor_0 * tensor_1 * tensor_2) )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP()
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF

/// Vector multiplication (conjunction) of 2 vectors into a sparse output,
/// i.e.:
///   o(i) = b(i) * c(i) * o(i)
/// which should form the single lattice point (note how a synthetic tensor
/// i_03_U is created for the sparse output)
/// {
///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
/// After optimization, the synthetic tensor should be preserved:
/// {
///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP()
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT

/// Vector addition (disjunction) of 2 vectors, i.e.:
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// After optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vectors):
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_DISJ
FOREVERY_COMMON_DISJ_BINOP()
#undef IMPL_MERGER_TEST_DISJ

/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_CONJ
FOREVERY_COMMON_CONJ_BINOP()
#undef IMPL_MERGER_TEST_CONJ

/// Vector multiplication (conjunction) then addition (disjunction), i.e.:
///   a(i) = b(i) * c(i) + d(i)
/// which should form
/// {
///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
///    lat( i_00 i_01 / tensor_0 * tensor_1 )
///    lat( i_02 / tensor_2 )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_CONJ_DISJ
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP()
#undef IMPL_MERGER_TEST_CONJ_DISJ

/// Vector addition (disjunction) then addition (disjunction), i.e.:
///   a(i) = b(i) + c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
///   lat( i_02 i_01 / tensor_2 + tensor_1 )
///   lat( i_02 i_00 / tensor_2 + tensor_0 )
///   lat( i_01 i_00 / tensor_1 + tensor_0 )
///   lat( i_02 / tensor_2 )
///   lat( i_01 / tensor_1 )
///   lat( i_00 / tensor_0 )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_DISJ_DISJ
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP()
#undef IMPL_MERGER_TEST_DISJ_DISJ

/// Vector multiplication (conjunction) then multiplication (conjunction),
/// i.e.:
///   a(i) = b(i) * c(i) * d(i)
/// which should form
/// {
///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_CONJ_CONJ
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP()
#undef IMPL_MERGER_TEST_CONJ_CONJ

/// Vector addition (disjunction) of 2 vectors, i.e.:
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
///   lat( i_00 / sparse_tensor_0 )
///   lat( i_01 / dense_tensor_1 )
/// }
/// which should be optimized to
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
///   lat( i_01 / dense_tensor_1 ) (no sparse dimension)
/// }
///
/// lat( i_00 / sparse_tensor_0 ) should be optimized out, as it differs from
/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) only in a dense
/// dimension.
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED)
FOREVERY_COMMON_DISJ_BINOP()
// Undefine the macro defined above so the helper does not stay defined past
// this point.
#undef IMPL_MERGER_TEST_OPTIMIZED_DISJ

/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// It should be optimized to
/// {
///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// since i_01 is a dense dimension.
// NOTE(review): the IMPL macro has an empty body here and is not passed to the
// generator invocation below -- confirm this is intended.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ
FOREVERY_COMMON_CONJ_BINOP()
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ

/// Vector element-wise comparison (disjunction) of 2 vectors, i.e.:
///   a(i) = b(i) cmp c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
///   lat( i_00 / tensor_0 cmp 0 )
///   lat( i_01 / 0 cmp tensor_1 )
/// }
/// After optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vectors):
/// {
///   lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
///   lat( i_00 / tensor_0 cmp 0 )
///   lat( i_01 / 0 cmp tensor_1 )
/// }
// NOTE(review): the test body is empty, so none of the lattice points
// described above are asserted here.
TEST_P(MergerTest3T1L, vector_cmp) {}

/// Vector element-wise comparison (disjunction) of 2 vectors, i.e.:
///   a(i) = b(i) cmp c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) )
///   lat( i_00 / sparse_tensor_0 cmp 0 )
///   lat( i_01 / 0 cmp dense_tensor_1 )
/// }
/// which should be optimized to
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) (not singleton)
///   lat( i_01 / 0 cmp dense_tensor_1 ) (no sparse dimension)
/// }
///
/// lat( i_00 / sparse_tensor_0 cmp 0 ) should be optimized out, as it differs
/// from lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) only in a
/// dense dimension.
// NOTE(review): the test body is empty, so none of the lattice points
// described above are asserted here.
TEST_P(MergerTest3T1LD, vector_cmp) {}