//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::async; //===----------------------------------------------------------------------===// // Async Runtime C API declaration. 
//===----------------------------------------------------------------------===// static constexpr const char *kAddRef = …; static constexpr const char *kDropRef = …; static constexpr const char *kCreateToken = …; static constexpr const char *kCreateValue = …; static constexpr const char *kCreateGroup = …; static constexpr const char *kEmplaceToken = …; static constexpr const char *kEmplaceValue = …; static constexpr const char *kSetTokenError = …; static constexpr const char *kSetValueError = …; static constexpr const char *kIsTokenError = …; static constexpr const char *kIsValueError = …; static constexpr const char *kIsGroupError = …; static constexpr const char *kAwaitToken = …; static constexpr const char *kAwaitValue = …; static constexpr const char *kAwaitGroup = …; static constexpr const char *kExecute = …; static constexpr const char *kGetValueStorage = …; static constexpr const char *kAddTokenToGroup = …; static constexpr const char *kAwaitTokenAndExecute = …; static constexpr const char *kAwaitValueAndExecute = …; static constexpr const char *kAwaitAllAndExecute = …; static constexpr const char *kGetNumWorkerThreads = …; namespace { /// Async Runtime API function types. /// /// Because we can't create an API function signature for the type-parametrized /// async.value type, we use opaque pointers (!llvm.ptr) instead. After /// lowering, all async data types become opaque pointers at runtime. struct AsyncAPI { … }; } // namespace /// Adds Async Runtime C API declarations to the module. static void addAsyncRuntimeApiDeclarations(ModuleOp module) { … } //===----------------------------------------------------------------------===// // Coroutine resume function wrapper. //===----------------------------------------------------------------------===// static constexpr const char *kResume = …; /// A function that takes a coroutine handle and calls the `llvm.coro.resume` /// intrinsic. We need this function to be able to pass it to the async /// runtime execute API.
static void addResumeFunction(ModuleOp module) { … } //===----------------------------------------------------------------------===// // Convert Async dialect types to LLVM types. //===----------------------------------------------------------------------===// namespace { /// AsyncRuntimeTypeConverter only converts types from the Async dialect to /// their runtime type (opaque pointers) and does not convert any other types. class AsyncRuntimeTypeConverter : public TypeConverter { … }; /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter /// as type converter. Allows access to it via the 'getTypeConverter' /// convenience method. template <typename SourceOp> class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.id to @llvm.coro.id intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.begin to @llvm.coro.begin intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.free to @llvm.coro.free intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.end to @llvm.coro.end intrinsic. 
//===----------------------------------------------------------------------===// namespace { class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.save to @llvm.coro.save intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. //===----------------------------------------------------------------------===// namespace { /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and /// branch to the appropriate block based on the return code. /// /// Before: /// /// ^suspended: /// "opBefore"(...) /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup /// ^resume: /// "op"(...) /// ^cleanup: ... /// ^suspend: ... /// /// After: /// /// ^suspended: /// "opBefore"(...) /// %suspend = llvm.intr.coro.suspend ... /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] /// ^resume: /// "op"(...) /// ^cleanup: ... /// ^suspend: ... /// class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.create to the corresponding runtime API call.
// // To allocate storage for the async values we use getelementptr trick: // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt //===----------------------------------------------------------------------===// namespace { class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.create_group to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeCreateGroupOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.set_available to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeSetAvailableOpLowering : public OpConversionPattern<RuntimeSetAvailableOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.set_error to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeSetErrorOpLowering : public OpConversionPattern<RuntimeSetErrorOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.is_error to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.await to the corresponding runtime API call. 
//===----------------------------------------------------------------------===// namespace { class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.await_and_resume to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeAwaitAndResumeOpLowering : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.resume to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeResumeOpLowering : public AsyncOpConversionPattern<RuntimeResumeOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.store to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.load to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.add_to_group to the corresponding runtime API call. 
//===----------------------------------------------------------------------===// namespace { class RuntimeAddToGroupOpLowering : public OpConversionPattern<RuntimeAddToGroupOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.num_worker_threads to the corresponding runtime API // call. //===----------------------------------------------------------------------===// namespace { class RuntimeNumWorkerThreadsOpLowering : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Async reference counting ops lowering (`async.runtime.add_ref` and // `async.runtime.drop_ref` to the corresponding API calls). //===----------------------------------------------------------------------===// namespace { template <typename RefCountingOp> class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { … }; class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { … }; class RuntimeDropRefOpLowering : public RefCountingOpLowering<RuntimeDropRefOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Convert return operations that return async values from async regions. //===----------------------------------------------------------------------===// namespace { class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> { … }; } // namespace //===----------------------------------------------------------------------===// namespace { struct ConvertAsyncToLLVMPass : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> { … }; } // namespace void ConvertAsyncToLLVMPass::runOnOperation() { … } //===----------------------------------------------------------------------===// // Patterns for structural type conversions for the Async dialect operations. 
//===----------------------------------------------------------------------===// namespace { class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { … }; // Dummy pattern to trigger the appropriate type conversion / materialization. class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { … }; // Dummy pattern to trigger the appropriate type conversion / materialization. class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { … }; } // namespace void mlir::populateAsyncStructuralTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { … }