//===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This header defines the `SparseTensorType` wrapper class. // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" namespace mlir { namespace sparse_tensor { //===----------------------------------------------------------------------===// /// A wrapper around `RankedTensorType`, which has three goals: /// /// (1) To provide a uniform API for querying aspects of sparse-tensor /// types; in particular, to make the "dimension" vs "level" distinction /// overt (i.e., explicit everywhere). Thus, throughout the sparsifier /// this class should be preferred over using `RankedTensorType` or /// `ShapedType` directly, since the methods of the latter do not make /// the "dimension" vs "level" distinction overt. /// /// (2) To provide a uniform abstraction over both sparse-tensor /// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`) /// and dense-tensor types (i.e., `RankedTensorType` without an encoding). /// That is, we want to manipulate dense-tensor types using the same API /// that we use for manipulating sparse-tensor types; both to keep the /// "dimension" vs "level" distinction overt, and to avoid needing to /// handle certain cases specially in the sparsifier. /// /// (3) To provide uniform handling of "defaults". In particular /// this means that dense-tensors should always return the same answers /// as sparse-tensors with a default encoding. But it additionally means /// that the answers should be normalized, so that there's no way to /// distinguish between non-provided data (which is filled in by default) /// vs explicitly-provided data which equals the defaults. /// class SparseTensorType { … }; /// Convenience methods to obtain a SparseTensorType from a Value. inline SparseTensorType getSparseTensorType(Value val) { … } inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) { … } } // namespace sparse_tensor } // namespace mlir #endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_