llvm/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements scf.parallel to scf.for + async.execute conversion pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <utility>

// Expand the generated pass definitions from Passes.h.inc.
// GEN_PASS_DEF_ASYNCPARALLELFOR selects the AsyncParallelFor pass.
// The include is wrapped in `namespace mlir` so that the generated
// `impl::AsyncParallelForBase` (used by AsyncParallelForPass in this file)
// resolves as mlir::impl::AsyncParallelForBase.
namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFOR
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::async;

#define DEBUG_TYPE

namespace {

// NOTE(review): every struct below has an empty body in this view, yet
// members such as ParallelComputeFunctionArgs::blockIndex() and
// AsyncParallelForRewrite::matchAndRewrite() are defined out of line later in
// the file. The member declarations appear to have been stripped and must be
// restored for this translation unit to compile.

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on: (pseudocode)
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//      call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
// The pass itself; derives from the generated impl::AsyncParallelForBase.
struct AsyncParallelForPass
    : public impl::AsyncParallelForBase<AsyncParallelForPass> {};

// Rewrite pattern matching scf::ParallelOp and performing the conversion
// described above.
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {};

// Result of getParallelComputeFunctionType(). Per that function's comment, it
// holds a function type together with the implicit captures of the parallel
// body.
struct ParallelComputeFunctionType {};

// Helper struct to parse parallel compute function argument list.
// Out-of-line accessors: blockIndex(), blockSize(), tripCounts(),
// lowerBounds(), upperBounds(), steps(), captures().
struct ParallelComputeFunctionArgs {};

// Bounds passed to createParallelComputeFunction(); its contents are not
// visible here.
struct ParallelComputeFunctionBounds {};

// Result of createParallelComputeFunction(). It is consumed by
// createAsyncDispatchFunction(), doAsyncDispatch() and doSequentialDispatch().
struct ParallelComputeFunction {};

} // namespace

// Accessors for the leading `%block_index` and `%block_size` arguments of the
// parallel compute function, following the
// `parallel_compute_function(%block_index, %block_size, ...)` sketch in the
// pattern description.
// NOTE(review): both bodies are empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
BlockArgument ParallelComputeFunctionArgs::blockIndex() {}
BlockArgument ParallelComputeFunctionArgs::blockSize() {}

// Accessor for the trip-count arguments of the parallel compute function
// (presumably one per parallel loop dimension — TODO confirm).
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {}

// Accessor for the loop lower-bound arguments, presumably the `%lb*` operands
// of the original scf.parallel forwarded as "<parallel operation properties>".
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {}

// Accessor for the loop upper-bound arguments, presumably the `%ub*` operands
// of the original scf.parallel forwarded as "<parallel operation properties>".
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {}

// Accessor for the loop step arguments, presumably the `%s*` operands of the
// original scf.parallel forwarded as "<parallel operation properties>".
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {}

// Accessor for the trailing arguments that carry values implicitly captured by
// the parallel body (see getParallelComputeFunctionType(), which computes
// those captures).
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {}

// Produces one IntegerAttr per element of `values`. The name suggests it
// extracts the constant integer value of each operand. How non-constant values
// are represented (e.g. a null attribute) cannot be determined from this view
// — TODO confirm.
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {}

// Converts one-dimensional iteration index in the [0, tripCount) interval
// into multidimensional iteration coordinate.
//
//   b          - builder used to emit the index arithmetic.
//   index      - linear iteration index to decompose.
//   tripCounts - per-dimension trip counts; the result presumably has one
//                coordinate per entry (TODO confirm).
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {}

// Returns a function type and implicit captures for a parallel compute
// function. We'll need a list of implicit captures to setup block and value
// mapping when we'll clone the body of the parallel operation.
//
//   op       - the scf.parallel operation being converted.
//   rewriter - rewriter available for building types/values.
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {}

// Create a parallel compute function from the parallel operation.
//
//   op                        - the scf.parallel operation whose body is
//                               cloned into the new function.
//   bounds                    - loop bounds information for `op`.
//   numBlockAlignedInnerLoops - presumably the number of innermost loops
//                               whose iteration space aligns with the compute
//                               block size (TODO confirm).
//   rewriter                  - rewriter used to create the function.
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {}

// Creates recursive async dispatch function for the given parallel compute
// function. Dispatch function keeps splitting block range into halves until it
// reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   computeFunc - the parallel compute function the dispatcher calls per block.
//   rewriter    - rewriter used to create the dispatch function.
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a non-void function is undefined behavior.
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                            PatternRewriter &rewriter) {}

// Launch async dispatch of the parallel compute function.
// Corresponds to the "async dispatch is on" case in the pattern description:
// the `@async_dispatch` function is invoked over the [0, blockCount) range.
//
//   b, rewriter             - builders used to emit the dispatch code.
//   parallelComputeFunction - function that computes a single block.
//   op                      - the scf.parallel operation being replaced.
//   blockSize, blockCount   - compute block size and number of blocks.
//   tripCounts              - per-dimension trip counts of `op`.
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped), so as written this is a no-op.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
// Corresponds to the "async dispatch is off" case in the pattern description.
//
//   b, rewriter             - builders used to emit the dispatch loop.
//   parallelComputeFunction - function that computes a single block.
//   op                      - the scf.parallel operation being replaced.
//   blockSize, blockCount   - compute block size and number of blocks.
//   tripCounts              - per-dimension trip counts of `op`.
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped), so as written this is a no-op.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {}

// Entry point of the rewrite pattern: matches `op` (an scf.parallel) and
// replaces it with a parallel compute function plus async or sequential
// dispatch, as described at the top of the file. It returns success or
// failure to the pattern driver.
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a function returning LogicalResult is
// undefined behavior.
LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {}

// Pass driver. Presumably collects the patterns from
// populateAsyncParallelForPatterns() and applies them greedily (the file
// includes GreedyPatternRewriteDriver.h) — TODO confirm.
// NOTE(review): the body is empty in this view (implementation presumably
// stripped), so as written the pass does nothing.
void AsyncParallelForPass::runOnOperation() {}

// Creates the AsyncParallelFor pass with its default options.
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a function returning std::unique_ptr is
// undefined behavior.
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {}

// Creates the AsyncParallelFor pass with explicit options.
//
//   asyncDispatch    - selects async (recursive) vs. sequential dispatch, per
//                      the pattern description.
//   numWorkerThreads - presumably the expected worker-thread count used when
//                      sizing compute blocks (TODO confirm).
//   minTaskSize      - presumably a lower bound on iterations per task
//                      (TODO confirm).
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped). Reaching the end of a function returning std::unique_ptr is
// undefined behavior.
std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
                                                       int32_t numWorkerThreads,
                                                       int32_t minTaskSize) {}

// Adds the scf.parallel -> async rewrite pattern(s) to `patterns`.
//
//   patterns           - set that receives the pattern(s).
//   asyncDispatch      - selects async vs. sequential dispatch.
//   numWorkerThreads   - presumably the expected worker-thread count
//                        (TODO confirm).
//   computeMinTaskSize - callback presumably computing the minimal task size
//                        per matched loop (TODO confirm).
//
// NOTE(review): the body is empty in this view (implementation presumably
// stripped), so as written no pattern is added.
void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {}