//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements lowering from high level async operations to async.coro // and async.runtime operations. // //===----------------------------------------------------------------------===// #include <utility> #include "mlir/Dialect/Async/Passes.h" #include "PassDetail.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #include <optional> namespace mlir { #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME #include "mlir/Dialect/Async/Passes.h.inc" } // namespace mlir usingnamespacemlir; usingnamespacemlir::async; #define DEBUG_TYPE … // Prefix for functions outlined from `async.execute` op regions. static constexpr const char kAsyncFnPrefix[] = …; namespace { class AsyncToAsyncRuntimePass : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { … }; } // namespace namespace { class AsyncFuncToAsyncRuntimePass : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> { … }; } // namespace /// Function targeted for coroutine transformation has two additional blocks at /// the end: coroutine cleanup and coroutine suspension. 
/// /// async.await op lowering additionally creates a resume block for each /// operation to enable non-blocking waiting via coroutine suspension. namespace { struct CoroMachinery { … }; } // namespace FuncCoroMapPtr; /// Utility to partially update the regular function CFG to the coroutine CFG /// compatible with LLVM coroutines switched-resume lowering using /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block /// that branches into the preexisting entry block. Also inserts trailing blocks. /// /// The result types of the passed `func` start with an optional `async.token` /// followed by some number of `async.value`s. /// /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// /// - `entry` block sets up the coroutine. /// - `set_error` block sets the state of the completion token and async values to error. /// - `cleanup` block cleans up the coroutine state. /// - `suspend` block after the @llvm.coro.end() defines what value will be /// returned to the initial caller of a coroutine. Everything before the /// @llvm.coro.end() will be executed at every suspension point.
/// /// Coroutine structure (only the important bits): /// /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) /// { /// ^entry(<function-arguments>): /// %token = <async token> : !async.token // create async runtime token /// %value = <async value> : !async.value<T> // create async value /// %id = async.coro.getId // create a coroutine id /// %hdl = async.coro.begin %id // create a coroutine handle /// cf.br ^preexisting_entry_block /// /// /* preexisting blocks modified to branch to the cleanup block */ /// /// ^set_error: // this block created lazily only if needed (see code below) /// async.runtime.set_error %token : !async.token /// async.runtime.set_error %value : !async.value<T> /// cf.br ^cleanup /// /// ^cleanup: /// async.coro.free %hdl // delete the coroutine state /// cf.br ^suspend /// /// ^suspend: /// async.coro.end %hdl // marks the end of a coroutine /// return %token, %value : !async.token, !async.value<T> /// } /// static CoroMachinery setupCoroMachinery(func::FuncOp func) { … } // Lazily creates `set_error` block only if it is required for lowering to the // runtime operations (see for example lowering of assert operation). static Block *setupSetErrorBlock(CoroMachinery &coro) { … } //===----------------------------------------------------------------------===// // async.execute op outlining to the coroutine functions. //===----------------------------------------------------------------------===// /// Outline the body region attached to the `async.execute` op into a standalone /// function. /// /// Note that this is not reversible transformation. 
static std::pair<func::FuncOp, CoroMachinery> outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { … } //===----------------------------------------------------------------------===// // Convert async.create_group operation to async.runtime.create_group //===----------------------------------------------------------------------===// namespace { class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.add_to_group operation to async.runtime.add_to_group. //===----------------------------------------------------------------------===// namespace { class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.func, async.return and async.call operations to non-blocking // operations based on llvm coroutine //===----------------------------------------------------------------------===// namespace { //===----------------------------------------------------------------------===// // Convert async.func operation to func.func //===----------------------------------------------------------------------===// class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> { … }; //===----------------------------------------------------------------------===// // Convert async.call operation to func.call //===----------------------------------------------------------------------===// class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> { … }; //===----------------------------------------------------------------------===// // Convert async.return operation to async.runtime operations. 
//===----------------------------------------------------------------------===// class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.await and async.await_all operations to the async.runtime.await // or async.runtime.await_and_resume operations. //===----------------------------------------------------------------------===// namespace { template <typename AwaitType, typename AwaitableType> class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { … }; /// Lowering for `async.await` with a token operand. class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { … }; /// Lowering for `async.await` with a value operand. class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { … }; /// Lowering for `async.await_all` operation. class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.yield operation to async.runtime operations. //===----------------------------------------------------------------------===// class YieldOpLowering : public OpConversionPattern<async::YieldOp> { … }; //===----------------------------------------------------------------------===// // Convert cf.assert operation to cf.cond_br into `set_error` block. 
//===----------------------------------------------------------------------===// class AssertOpLowering : public OpConversionPattern<cf::AssertOp> { … }; //===----------------------------------------------------------------------===// void AsyncToAsyncRuntimePass::runOnOperation() { … } //===----------------------------------------------------------------------===// void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns( RewritePatternSet &patterns, ConversionTarget &target) { … } void AsyncFuncToAsyncRuntimePass::runOnOperation() { … } std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { … } std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncFuncToAsyncRuntimePass() { … }