llvm/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp

//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert type range to new types range, with sparse tensors externalized
// (presumably each sparse tensor type is replaced by the types of the
// individual arrays that make up its storage, as described on
// SparseFuncAssembler -- TODO confirm).
//
// Parameters (roles partly inferred from names; confirm against callers):
//   types      - the input types to convert.
//   convTypes  - output vector (by reference) meant to receive the converted
//                types. Note: this parameter has the same name as the
//                function, so inside the body it shadows the function name.
//   extraTypes - passed as a pointer; presumably optional (may be null) and
//                meant to collect additional types for output buffers --
//                TODO confirm.
//   directOut  - presumably selects the "direct-out" variant (no
//                disassemble) mentioned on SparseFuncAssembler -- TODO confirm.
//
// NOTE(review): in this version of the file the body is empty, so nothing is
// appended to `convTypes` or `extraTypes`. Confirm whether the implementation
// was stripped.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {}

// Convert input and output values through [dis]assemble ops for sparse
// tensors.
//
// Parameters (roles inferred from names; TODO confirm against callers):
//   builder, loc - builder and location for any operations this would create.
//   types        - presumably the types corresponding to `fromVals`.
//   fromVals     - the values to convert.
//   extraVals    - additional values; presumably output buffers used when
//                  disassembling -- confirm.
//   toVals       - output vector (by reference) meant to receive the
//                  converted values.
//   extra        - presumably an index/offset into `extraVals` -- confirm.
//   isIn         - presumably true when converting inputs (assemble) and
//                  false when converting outputs (disassemble) -- confirm.
//   directOut    - presumably selects the "direct-out" variant (no
//                  disassemble) -- confirm.
//
// NOTE(review): in this version of the file the body is empty, so nothing is
// appended to `toVals` and `builder` creates no operations. Confirm whether
// the implementation was stripped.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods. For
// inputs, the wrapper assembles the individual tensors that constitute the
// actual storage used externally into MLIR sparse tensors before calling the
// original method; for outputs, it disassembles the returned sparse tensors
// afterwards.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
//
// NOTE(review): in this version of the file the struct body is empty. It
// declares no members (no constructor and no matchAndRewrite override), so it
// only inherits from OpRewritePattern<func::FuncOp>. Confirm whether the
// implementation was stripped.
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

// Public entry point for populating `patterns` with the sparse assembler
// rewriting rules. `directOut` presumably selects the "direct-out" variant of
// the generated wrapper methods -- TODO confirm.
//
// NOTE(review): in this version of the file the body is empty, so no patterns
// are added to `patterns` and `directOut` is unused. Confirm whether the
// implementation was stripped.
void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {}