llvm/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td

//===- SparseTensorTypes.td - Sparse tensor dialect types ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SPARSETENSOR_TYPES
#define SPARSETENSOR_TYPES

include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"

//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//

// Common base class for all types defined by the sparse tensor dialect. It
// forwards its arguments to `TypeDef` with the dialect fixed to
// `SparseTensor_Dialect`.
class SparseTensor_Type<string name,
                        list<Trait> traits = [],
                        string baseCppClass = "::mlir::Type">
    : TypeDef<SparseTensor_Dialect, name, traits, baseCppClass>;

//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Types.
//===----------------------------------------------------------------------===//

def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
  let summary = "Structured metadata for sparse tensor low-level storage scheme";
  let mnemonic = "storage_specifier";
  let assemblyFormat = "`<` qualified($encoding) `>`";

  let description = [{
    Values with storage_specifier types represent aggregated storage scheme
    metadata for the given sparse tensor encoding.  It currently holds
    a set of values for level-sizes, coordinate arrays, position arrays,
    and value array.  Note that the type is not yet stable and subject to
    change in the near future.

    Examples:

    ```mlir
    // A storage specifier that can be used to store storage scheme metadata from CSR matrix.
    !storage_specifier<#CSR>
    ```
  }];

  // The only parameter: the sparse tensor encoding the specifier is built for.
  let parameters = (ins SparseTensorEncodingAttr : $encoding);

  // The default builder, which would simply take the input sparse tensor
  // encoding attribute as-is, is skipped: the dimension level type has to be
  // normalized and fields that are irrelevant to the sparse tensor storage
  // scheme have to be removed first.
  let skipDefaultBuilders = 1;

  let builders = [
    // Explicit-context builder; only declared here (no inline body), so its
    // definition is expected out of line.
    TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
    // Same, with the context taken from the encoding attribute itself.
    TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
      return get(encoding.getContext(), encoding);
    }]>,
    // Builds from a type by looking up its sparse tensor encoding.
    TypeBuilderWithInferredContext<(ins "Type":$type), [{
      return get(getSparseTensorEncoding(type));
    }]>,
    // Builds from a value by forwarding its type to the builder above.
    TypeBuilderWithInferredContext<(ins "Value":$tensor), [{
      return get(tensor.getType());
    }]>
  ];
}

// Predicate that holds when the constrained type is a
// `::mlir::sparse_tensor::StorageSpecifierType`.
def IsSparseTensorStorageSpecifierTypePred : CPred<
    "::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">;

// Type constraint satisfied by `::mlir::sparse_tensor::StorageSpecifierType`.
// It reuses the named predicate rather than repeating the same `CPred`
// inline, matching how the iteration-space/iterator constraints are written.
def SparseTensorStorageSpecifier
    : Type<IsSparseTensorStorageSpecifierTypePred, "metadata",
          "::mlir::sparse_tensor::StorageSpecifierType">;

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Types.
//===----------------------------------------------------------------------===//

def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
  let mnemonic = "iter_space";
  let summary = "Abstract N-D iteration space extracted from a sparse tensor";

  let description = [{
    A sparse iteration space that represents an abstract N-D (sparse) iteration space
    extracted from a sparse tensor, i.e., a set of (crd_0, crd_1, ..., crd_N) for
    every stored element (usually nonzeros) in a sparse tensor between the specified
    [$loLvl, $hiLvl) levels.

    Examples:

    ```mlir
    // An iteration space extracted from a CSR tensor between levels [0, 2).
    !iter_space<#CSR, lvls = 0 to 2>
    ```
  }];

  // The encoding of the source tensor plus the half-open level range
  // [loLvl, hiLvl) covered by the space.
  let parameters = (ins
     SparseTensorEncodingAttr : $encoding,
     "Level" : $loLvl,
     "Level" : $hiLvl
  );

  let extraClassDeclaration = [{
     /// The dimension of the iteration space, i.e., `hiLvl - loLvl`.
     unsigned getSpaceDim() const {
       return getHiLvl() - getLoLvl();
     }

     /// Get the level types for the iteration space: the `getSpaceDim()`
     /// entries of the encoding's level types starting at `loLvl`.
     ArrayRef<LevelType> getLvlTypes() const {
       return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
     }

     /// Whether the iteration space is unique (i.e., no duplicated coordinate).
     /// Only the last level of the space is inspected. Declared `const` like
     /// the other accessors, since it only calls `const` members.
     bool isUnique() const {
       return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
     }

     /// Get the corresponding iterator type.
     ::mlir::sparse_tensor::IteratorType getIteratorType() const;
  }];

  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}

def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
  let mnemonic = "iterator";
  let summary = "Iterator over a sparse iteration space";

  let description = [{
    An iterator that points to the current element in the corresponding iteration space.

    Examples:

    ```mlir
    // An iterator that iterates over an iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
    !iterator<#CSR, lvls = 0 to 2>
    ```
  }];

  // Same parameter list as the iteration space type: the tensor encoding plus
  // the half-open level range [loLvl, hiLvl).
  let parameters = (ins
     SparseTensorEncodingAttr : $encoding,
     "Level" : $loLvl,
     "Level" : $hiLvl
  );

  let extraClassDeclaration = [{
     /// Get the corresponding iteration space type.
     ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;

     /// The queries below simply forward to the corresponding iteration
     /// space type.

     /// The dimension of the iterated space.
     unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
     /// The level types of the iterated space.
     ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
     /// Whether the iterated space is unique. Declared `const` like the other
     /// accessors; it only calls the `const` `getIterSpaceType()`.
     bool isUnique() const { return getIterSpaceType().isUnique(); }
  }];

  let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
}

// Predicate that holds when the constrained type is a
// `::mlir::sparse_tensor::IterSpaceType`.
def IsSparseSparseIterSpaceTypePred : CPred<
    "::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;

// Predicate that holds when the constrained type is a
// `::mlir::sparse_tensor::IteratorType`.
def IsSparseSparseIteratorTypePred : CPred<
    "::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;

// Type constraint matching any sparse iteration space type
// (`::mlir::sparse_tensor::IterSpaceType`).
def AnySparseIterSpace : Type<
    IsSparseSparseIterSpaceTypePred,
    "sparse iteration space",
    "::mlir::sparse_tensor::IterSpaceType">;

// Type constraint matching any sparse iterator type
// (`::mlir::sparse_tensor::IteratorType`).
def AnySparseIterator : Type<
    IsSparseSparseIteratorTypePred,
    "sparse iterator",
    "::mlir::sparse_tensor::IteratorType">;


#endif // SPARSETENSOR_TYPES