llvm/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

//===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header defines the `SparseTensorType` wrapper class.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include <cassert>
#include <optional>

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
/// A wrapper around `RankedTensorType`, which has three goals:
///
/// (1) To provide a uniform API for querying aspects of sparse-tensor
/// types; in particular, to make the "dimension" vs "level" distinction
/// overt (i.e., explicit everywhere).  Thus, throughout the sparsifier
/// this class should be preferred over using `RankedTensorType` or
/// `ShapedType` directly, since the methods of the latter do not make
/// the "dimension" vs "level" distinction overt.
///
/// (2) To provide a uniform abstraction over both sparse-tensor
/// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`)
/// and dense-tensor types (i.e., `RankedTensorType` without an encoding).
/// That is, we want to manipulate dense-tensor types using the same API
/// that we use for manipulating sparse-tensor types; both to keep the
/// "dimension" vs "level" distinction overt, and to avoid needing to
/// handle certain cases specially in the sparsifier.
///
/// (3) To provide uniform handling of "defaults".  In particular
/// this means that dense-tensors should always return the same answers
/// as sparse-tensors with a default encoding.  But it additionally means
/// that the answers should be normalized, so that there's no way to
/// distinguish between non-provided data (which is filled in by default)
/// and explicitly-provided data which equals the defaults.
///
class SparseTensorType {};

/// Convenience methods to obtain a SparseTensorType from a Value.
inline SparseTensorType getSparseTensorType(Value val) {}
inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {}

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_