// llvm/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

//===- SparseTensorConversion.cpp - Sparse tensor primitives conversion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor primitives into calls into a runtime
// support library. Sparse tensor types are converted into opaque pointers
// to the underlying sparse storage schemes. The use of opaque pointers
// together with runtime support library keeps the conversion relatively
// simple, but at the expense of IR opacity, which obscures opportunities
// for subsequent optimization of the IR. An alternative is provided by
// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

usingnamespacemlir;
usingnamespacemlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Maps each sparse tensor type to an opaque pointer.
///
/// The `std::optional<Type>` return type allows the function to report
/// "no conversion" for some inputs (presumably non-sparse types -- TODO
/// confirm).
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static std::optional<Type> convertSparseTensorTypes(Type type) {}

/// Generates a call to look up a level-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
///
/// `tensor` is the tensor being queried and `lvl` the level whose size is
/// requested; the result is returned as a `Value`.
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genLvlSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t lvl) {}

/// Generates a call to look up a dimension-size.  N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
///
/// `tensor` is the tensor being queried and `dim` the dimension whose size
/// is requested; the result is returned as a `Value`.
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor,
                            uint64_t dim) {}

/// Looks up a level-size by returning a statically-computed constant
/// (when possible), or by calling `genLvlSizeCall` (when dynamic).
///
/// `stt` describes the type of `tensor`; `lvl` selects the level.
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Level lvl) {}

/// Looks up a dimension-size by returning a constant from the shape
/// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes
/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes
/// of dense tensors).
///
/// `stt` describes the type of `tensor`; `dim` selects the dimension.
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                 SparseTensorType stt, Value tensor,
                                 Dimension dim) {}

/// Populates `out` with the dimension-sizes of the given tensor, whose type
/// is described by `stt`.
///
/// NOTE(review): the body is empty in this source, so as written `out` is
/// left unchanged. TODO: confirm the implementation was not elided.
static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt,
                         Value tensor, SmallVectorImpl<Value> &out) {}

/// Returns an array with the dimension-sizes of the given tensor.
/// If the `tensor` parameter is null (the default, `Value()`), the tensor
/// type is assumed to have a static shape.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static SmallVector<Value> getDimSizes(OpBuilder &builder, Location loc,
                                      SparseTensorType stt,
                                      Value tensor = Value()) {}

/// Generates an uninitialized buffer of the given size `sz` and element
/// type `tp`, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, this buffer
/// must be explicitly deallocated by the client.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {}

/// Generates a temporary buffer for the level-types of the encoding carried
/// by `stt`.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                               SparseTensorType stt) {}

/// Extracts the bare (aligned) pointer that points to the storage of
/// `tensor`, returned as a single `Value`.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value extractBarePtrFromTensor(OpBuilder &builder, Location loc,
                                      Value tensor) {}

/// Generates a buffer of pointers for the given level tensors (`lvlTensors`)
/// and values tensor (`valTensor`).
///
/// NOTE(review): this description is inferred from the function's name and
/// parameters; the previous header duplicated the level-types description of
/// `genLvlTypesBuffer`, which does not match this signature. TODO: confirm
/// the exact layout of the produced buffer.
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
                               ValueRange lvlTensors, Value valTensor) {}

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
/// for materializing sparse tensors into the computation. This abstraction
/// reduces the need for modifications when the API changes.
///
/// NOTE(review): the class body is empty in this source (no members are
/// visible). TODO: confirm the implementation was not elided.
class NewCallParams final {};

/// Generates a call to obtain the values array of the tensor referenced by
/// `ptr`, whose type is described by `stt`.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genValuesCall(OpBuilder &builder, Location loc,
                           SparseTensorType stt, Value ptr) {}

/// Generates a call to obtain the positions array at level `l` of the tensor
/// referenced by `ptr`, whose type is described by `stt`.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {}

/// Generates a call to obtain the coordinates array at level `l` of the
/// tensor referenced by `ptr`, whose type is described by `stt`.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {}

/// Generates a call to obtain the coordinates array (AoS view), starting at
/// level `l`, of the tensor referenced by `ptr`, whose type is described by
/// `stt`.
///
/// NOTE(review): the body is empty in this source, so control would flow off
/// the end of a value-returning function (undefined behavior if called).
/// TODO: confirm the implementation was not elided.
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
                                      SparseTensorType stt, Value ptr,
                                      Level l) {}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse conversion rule for returns (`func::ReturnOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {};

/// Sparse conversion rule for accessing level-sizes (`LvlOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {};

/// Sparse conversion rule for trivial tensor casts (`tensor::CastOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {};

/// Sparse conversion rule for the reinterpret-map operator
/// (`ReinterpretMapOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {};

/// Sparse conversion rule for the new operator (`NewOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {};

/// Sparse conversion rule for the alloc operator
/// (`bufferization::AllocTensorOp`).
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {};

/// Sparse conversion rule for the empty tensor (`tensor::EmptyOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {};

/// Sparse conversion rule for the reorder-COO operator (`ReorderCOOOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorReorderCOOConverter
    : public OpConversionPattern<ReorderCOOOp> {};

/// Sparse conversion rule for the dealloc operator
/// (`bufferization::DeallocTensorOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {};

/// Sparse conversion rule for position accesses (`ToPositionsOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorToPositionsConverter
    : public OpConversionPattern<ToPositionsOp> {};

/// Sparse conversion rule for coordinate accesses (`ToCoordinatesOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {};

/// Sparse conversion rule for coordinate accesses, AoS style
/// (`ToCoordinatesBufferOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {};

/// Sparse conversion rule for value accesses (`ToValuesOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {};

/// Sparse conversion rule for the number-of-entries operator
/// (`NumberOfEntriesOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {};

/// Sparse conversion rule for tensor rematerialization (`LoadOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {};

/// Sparse conversion rule for the insertion operator (`tensor::InsertOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorInsertConverter
    : public OpConversionPattern<tensor::InsertOp> {};

/// Sparse conversion rule for the expand operator (`ExpandOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {};

/// Sparse conversion rule for the compress operator (`CompressOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {};

/// Sparse conversion rule for the sparse_tensor.assemble operator
/// (`AssembleOp`).
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {};

/// Sparse conversion rule for the sparse_tensor.disassemble operator
/// (`DisassembleOp`).
/// Note that the current implementation simply exposes the buffers to
/// the external client. This assumes the client only reads the buffers
/// (usually copying them into external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copying is done here already using the out-levels
/// and out-values clause.
/// NOTE(review): the class body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
class SparseTensorDisassembleConverter
    : public OpConversionPattern<DisassembleOp> {};

/// Sparse conversion rule for the has-runtime-library query
/// (`HasRuntimeLibraryOp`).
/// NOTE(review): the struct body is empty in this source (no constructor or
/// `matchAndRewrite` is visible). TODO: confirm it was not elided.
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into opaque pointer.
//===----------------------------------------------------------------------===//

/// Constructs the type converter that maps sparse tensor types to opaque
/// pointers.
/// NOTE(review): the body is empty in this source, so no conversion
/// callbacks are registered here (`convertSparseTensorTypes` is never
/// referenced). TODO: confirm the implementation was not elided.
mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {}

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations, i.e. the rules that
/// convert sparse tensor primitives into runtime support library calls.
/// `typeConverter` is the type converter the rules are expected to use.
/// NOTE(review): the body is empty in this source, so as written no patterns
/// are added to `patterns`, and none of the converter classes in this file
/// are referenced. TODO: confirm the implementation was not elided.
void mlir::populateSparseTensorConversionPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns) {}