llvm/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp

//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering from high level async operations to async.coro
// and async.runtime operations.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include <memory>
#include <optional>

// Pull in the TableGen-generated pass base class definitions
// (`impl::AsyncToAsyncRuntimeBase` and `impl::AsyncFuncToAsyncRuntimeBase`,
// which the pass classes in this file derive from).
namespace mlir {
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::async;

#define DEBUG_TYPE
// Prefix for functions outlined from `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] =;

namespace {

/// Module pass lowering high-level async operations to `async.coro` and
/// `async.runtime` operations.
class AsyncToAsyncRuntimePass
    : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
  // Declared here so that the out-of-line definition later in this file is
  // valid; without it the class is also abstract and cannot be instantiated.
  void runOnOperation() override;
};

} // namespace

namespace {

/// Module pass lowering `async.func`, `async.call` and `async.return`
/// operations to coroutine-based async runtime operations.
class AsyncFuncToAsyncRuntimePass
    : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
public:
  // Declared here so that the out-of-line definition later in this file is
  // valid; without it the class is also abstract and cannot be instantiated.
  void runOnOperation() override;
};

} // namespace

/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension. A `set_error` block is
/// additionally created lazily, only when some lowering needs it.
///
/// async.await op lowering additionally creates a resume block for each
/// operation to enable non-blocking waiting via coroutine suspension.
///
/// NOTE(review): as shown here the struct declares no members; confirm against
/// the full definition which of the blocks/values above it records.
namespace {
struct CoroMachinery {};
} // namespace

FuncCoroMapPtr;

/// Utility to partially update the regular function CFG to the coroutine CFG
/// compatible with LLVM coroutines switched-resume lowering using
/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
/// that branches into the preexisting entry block. Also inserts trailing
/// blocks.
///
/// The result types of the passed `func` start with an optional `async.token`
/// followed by some number of `async.value`s.
///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
///  - `entry` block sets up the coroutine.
///  - `set_error` block sets completion token and async values state to error.
///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
///    returned to the initial caller of a coroutine. Everything before the
///    @llvm.coro.end() will be executed at every suspension point.
///
/// Coroutine structure (only the important bits):
///
///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
///   {
///     ^entry(<function-arguments>):
///       %token = <async token> : !async.token    // create async runtime token
///       %value = <async value> : !async.value<T> // create async value
///       %id = async.coro.getId                   // create a coroutine id
///       %hdl = async.coro.begin %id              // create a coroutine handle
///       cf.br ^preexisting_entry_block
///
///     /*  preexisting blocks modified to branch to the cleanup block */
///
///     ^set_error: // this block created lazily only if needed (see code below)
///       async.runtime.set_error %token : !async.token
///       async.runtime.set_error %value : !async.value<T>
///       cf.br ^cleanup
///
///     ^cleanup:
///       async.coro.free %hdl // delete the coroutine state
///       cf.br ^suspend
///
///     ^suspend:
///       async.coro.end %hdl // marks the end of a coroutine
///       return %token, %value : !async.token, !async.value<T>
///   }
///
/// NOTE(review): the body is empty in this view. The function returns
/// `CoroMachinery` by value, and flowing off the end of a non-void function is
/// undefined behavior if it is ever called; confirm the full implementation.
static CoroMachinery setupCoroMachinery(func::FuncOp func) {}

/// Lazily creates the `set_error` block, only if it is required for lowering to
/// the runtime operations (see for example the lowering of the assert
/// operation), and returns it.
///
/// NOTE(review): the body is empty in this view. The function returns
/// `Block *`, and flowing off the end without a `return` is undefined behavior
/// if it is ever called; confirm the full implementation.
static Block *setupSetErrorBlock(CoroMachinery &coro) {}

//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//

/// Outline the body region attached to the `async.execute` op into a standalone
/// function.
///
/// Returns a pair of the outlined `func::FuncOp` and a `CoroMachinery`
/// (presumably the coroutine machinery set up for that outlined function —
/// confirm). `symbolTable` is presumably used to insert the new function with a
/// unique name built from `kAsyncFnPrefix` — TODO confirm.
///
/// Note that this is not a reversible transformation.
///
/// NOTE(review): the body is empty in this view; flowing off the end of this
/// non-void function is undefined behavior if it is ever called.
static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {}

//===----------------------------------------------------------------------===//
// Convert async.create_group operation to async.runtime.create_group
//===----------------------------------------------------------------------===//

namespace {
/// Conversion pattern rewriting `async.create_group` into
/// `async.runtime.create_group`.
class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.add_to_group operation to async.runtime.add_to_group.
//===----------------------------------------------------------------------===//

namespace {
/// Conversion pattern rewriting `async.add_to_group` into
/// `async.runtime.add_to_group`.
class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.func, async.return and async.call operations to non-blocking
// operations based on llvm coroutine
//===----------------------------------------------------------------------===//

namespace {

//===----------------------------------------------------------------------===//
// Convert async.func operation to func.func.
//===----------------------------------------------------------------------===//

/// Conversion pattern rewriting an `async.func` into a `func.func`.
class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {};

//===----------------------------------------------------------------------===//
// Convert async.call operation to func.call.
//===----------------------------------------------------------------------===//

/// Conversion pattern rewriting an `async.call` into a `func.call`.
class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {};

//===----------------------------------------------------------------------===//
// Convert async.return operation to async.runtime operations.
//===----------------------------------------------------------------------===//

/// Conversion pattern rewriting an `async.return` into async.runtime
/// operations.
class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.await and async.await_all operations to the async.runtime.await
// or async.runtime.await_and_resume operations.
//===----------------------------------------------------------------------===//

namespace {
/// Common base for the await lowerings below, which convert to
/// `async.runtime.await` or `async.runtime.await_and_resume`. `AwaitType` is
/// the operation being converted and `AwaitableType` is the type of the
/// operand it waits on.
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {};

/// Lowering for `async.await` with a token operand.
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {};

/// Lowering for `async.await` with a value operand.
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {};

/// Lowering for `async.await_all` operation (awaits a group).
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {};

} // namespace

//===----------------------------------------------------------------------===//
// Convert async.yield operation to async.runtime operations.
//===----------------------------------------------------------------------===//

namespace {
/// Conversion pattern rewriting `async.yield` into async.runtime operations.
/// Kept in an anonymous namespace, like the other patterns in this file, so
/// the class has internal linkage and cannot clash with a same-named class in
/// another translation unit.
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// Convert cf.assert operation to cf.cond_br into `set_error` block.
//===----------------------------------------------------------------------===//

namespace {
/// Conversion pattern rewriting `cf.assert` into a `cf.cond_br` that branches
/// to the coroutine's `set_error` block on failure. Kept in an anonymous
/// namespace, like the other patterns in this file, so the class has internal
/// linkage.
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {};
} // namespace

//===----------------------------------------------------------------------===//
/// Entry point of the async-to-async-runtime module pass.
///
/// NOTE(review): the body is empty in this view, so as shown the pass leaves
/// the IR unchanged. This out-of-line definition requires a matching
/// `runOnOperation()` declaration in `AsyncToAsyncRuntimePass`.
void AsyncToAsyncRuntimePass::runOnOperation() {}

//===----------------------------------------------------------------------===//
/// Intended to add the `async.func` / `async.call` / `async.return` lowering
/// patterns to `patterns` and to configure op legality on `target` (inferred
/// from the name and signature — confirm).
///
/// NOTE(review): the body is empty in this view, so nothing is registered as
/// shown.
void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
    RewritePatternSet &patterns, ConversionTarget &target) {}

/// Entry point of the async-func-to-async-runtime module pass.
///
/// NOTE(review): the body is empty in this view, so as shown the pass leaves
/// the IR unchanged. This out-of-line definition requires a matching
/// `runOnOperation()` declaration in `AsyncFuncToAsyncRuntimePass`.
void AsyncFuncToAsyncRuntimePass::runOnOperation() {}

/// Creates an instance of the async-to-async-runtime module pass.
///
/// NOTE(review): as shown, the body has no `return` statement. Flowing off the
/// end of a function returning `std::unique_ptr` is undefined behavior. The
/// intended body is `return std::make_unique<AsyncToAsyncRuntimePass>();`,
/// which compiles only once the pass class declares `runOnOperation()` (so it
/// is no longer abstract).
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {}

/// Creates an instance of the async-func-to-async-runtime module pass.
///
/// NOTE(review): as shown, the body has no `return` statement. Flowing off the
/// end of a function returning `std::unique_ptr` is undefined behavior. The
/// intended body is `return std::make_unique<AsyncFuncToAsyncRuntimePass>();`,
/// which compiles only once the pass class declares `runOnOperation()` (so it
/// is no longer abstract).
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createAsyncFuncToAsyncRuntimePass() {}