//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Utils/CodegenUtils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "llvm/Support/FormatVariadic.h" usingnamespacemlir; usingnamespacesparse_tensor; //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// // Convert type range to new types range, with sparse tensors externalized. static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes, SmallVectorImpl<Type> *extraTypes, bool directOut) { … } // Convert input and output values to [dis]assemble ops for sparse tensors. static void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn, bool directOut) { … } //===----------------------------------------------------------------------===// // Rewriting rules. //===----------------------------------------------------------------------===// namespace { // A rewriting rules that converts public entry methods that use sparse tensors // as input parameters and/or output return values into wrapper methods that // [dis]assemble the individual tensors that constitute the actual storage used // externally into MLIR sparse tensors before calling the original method. // // In particular, each sparse tensor input // // void foo(..., t, ...) { } // // makes the original foo() internal and adds the following wrapper method // // void foo(..., t1..tn, ...) { // t = assemble t1..tn // _internal_foo(..., t, ...) // } // // and likewise, each output tensor // // ... T ... bar(...) { return ..., t, ...; } // // makes the original bar() internal and adds the following wrapper method // // ... T1..TN ... bar(..., t1'..tn') { // ..., t, ... = _internal_bar(...) // t1..tn = disassemble t, t1'..tn' // return ..., t1..tn, ... // } // // (with a direct-out variant without the disassemble). // struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// void mlir::populateSparseAssembler(RewritePatternSet &patterns, bool directOut) { … }