//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert SCF dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; //===----------------------------------------------------------------------===// // Context //===----------------------------------------------------------------------===// namespace mlir { struct ScfToSPIRVContextImpl { … }; } // namespace mlir /// We use ScfToSPIRVContext to store information about the lowering of the scf /// region that needs to be used later on. When we lower scf.for/scf.if we create /// VariableOp to store the results. We need to keep track of the VariableOp /// created as we need to insert stores into them when lowering Yield. Those /// StoreOps cannot be created earlier as they may use a different type than /// yield operands. ScfToSPIRVContext::ScfToSPIRVContext() { … } ScfToSPIRVContext::~ScfToSPIRVContext() = default; namespace { //===----------------------------------------------------------------------===// // Helper Functions //===----------------------------------------------------------------------===// /// Replaces SCF op outputs with SPIR-V variable loads. /// We create VariableOp to handle the result values of the control flow region.
/// spirv.mlir.loop/spirv.mlir.selection currently don't yield values. Right /// after the loop we load the value from the allocation and use it as the SCF /// op result. template <typename ScfOp, typename OpTy> void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, ConversionPatternRewriter &rewriter, ScfToSPIRVContextImpl *scfToSPIRVContext, ArrayRef<Type> returnTypes) { … } Region::iterator getBlockIt(Region &region, unsigned index) { … } //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// /// Common class for all SCF to SPIR-V patterns. template <typename OpTy> class SCFToSPIRVPattern : public OpConversionPattern<OpTy> { … }; //===----------------------------------------------------------------------===// // scf::ForOp //===----------------------------------------------------------------------===// /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { … }; //===----------------------------------------------------------------------===// // scf::IfOp //===----------------------------------------------------------------------===// /// Pattern to convert a scf::IfOp within kernel functions into /// spirv::SelectionOp.
struct IfOpConversion final : SCFToSPIRVPattern<scf::IfOp> { … }; //===----------------------------------------------------------------------===// // scf::YieldOp //===----------------------------------------------------------------------===// /// Conversion pattern anchored on scf::YieldOp. struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> { … }; //===----------------------------------------------------------------------===// // scf::WhileOp //===----------------------------------------------------------------------===// /// Conversion pattern anchored on scf::WhileOp. struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Public API //===----------------------------------------------------------------------===// /// Public entry point of this file. NOTE(review): presumably adds the /// conversion patterns defined above to `patterns` using `typeConverter` and /// `scfToSPIRVContext` -- the body is not visible here, confirm. void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns) { … }