//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a partial lowering of Toy operations to a combination of // affine loops, memref operations and standard operations. This lowering // expects that all calls have been inlined, and all shapes have been resolved. // //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" #include "toy/Dialect.h" #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/Casting.h" #include <algorithm> #include <cstdint> #include <functional> #include <memory> #include <utility> using namespace mlir; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. static MemRefType convertTensorToMemRef(RankedTensorType type) { … } /// Insert an allocation and deallocation for the given MemRefType. 
static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { … } /// This defines the function type used to process an iteration of a lowered /// loop. It takes as input an OpBuilder, a range of memRefOperands /// corresponding to the operands of the input operation, and the range of loop /// induction variables for the iteration. It returns a value to store at the /// current index of the iteration. LoopIterationFn; static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { … } namespace { //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> struct BinaryOpLowering : public ConversionPattern { … }; AddOpLowering; MulOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations //===----------------------------------------------------------------------===// struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { … }; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { … }; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { … }; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Return operations 
//===----------------------------------------------------------------------===// struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { … }; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Transpose operations //===----------------------------------------------------------------------===// struct TransposeOpLowering : public ConversionPattern { … }; } // namespace //===----------------------------------------------------------------------===// // ToyToAffineLoweringPass //===----------------------------------------------------------------------===// /// This is a partial lowering to affine loops of the toy operations that are /// computationally intensive (like matmul for example...) while keeping the /// rest of the code in the Toy dialect. namespace { struct ToyToAffineLoweringPass : public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> { … }; } // namespace void ToyToAffineLoweringPass::runOnOperation() { … } /// Create a pass for lowering to operations in the `Affine` and `Std` dialects, /// for a subset of the Toy IR (e.g. matmul). std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() { … }