llvm/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp

//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Pass/Pass.h"

usingnamespacemlir;
usingnamespaceindex;

namespace {

//===----------------------------------------------------------------------===//
// Trivial Conversions
//===----------------------------------------------------------------------===//

ConvertIndexAdd;
ConvertIndexSub;
ConvertIndexMul;
ConvertIndexDivS;
ConvertIndexDivU;
ConvertIndexRemS;
ConvertIndexRemU;
ConvertIndexMaxS;
ConvertIndexMaxU;
ConvertIndexMinS;
ConvertIndexMinU;

ConvertIndexShl;
ConvertIndexShrS;
ConvertIndexShrU;

/// It is the case that when we convert bitwise operations to SPIR-V operations
/// we must take into account the special pattern in SPIR-V that if the
/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
/// index.add is never a boolean operation so we can directly convert it to the
/// Bitwise[And|Or]Op.
ConvertIndexAnd;
ConvertIndexOr;
ConvertIndexXor;

//===----------------------------------------------------------------------===//
// ConvertConstantBool
//===----------------------------------------------------------------------===//

// Converts index.bool.constant operation to spirv.Constant. The result type
// (i1) and the boolean attribute are carried over unchanged.
struct ConvertIndexConstantBoolOpPattern final
    : OpConversionPattern<BoolConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
                                                   op.getValueAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertConstant
//===----------------------------------------------------------------------===//

// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
// when required, i.e. when the type converter lowers `index` to a 32-bit
// integer.
struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
    Type indexType = typeConverter->getIndexType();

    // The index attribute is stored as a 64-bit APInt; narrow it to the width
    // of the materialized index type (a no-op when that width is 64).
    APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
        op, indexType, IntegerAttr::get(indexType, value));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertIndexCeilDivS
//===----------------------------------------------------------------------===//

/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
/// conversion in IndexToLLVM. The `n*m > 0` test is evaluated as
/// `(n > 0) == (m > 0) && n != 0` so that the product can never overflow.
struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value n = adaptor.getLhs();
    Value m = adaptor.getRhs();
    Type nType = n.getType();

    // Constants in the converted integer type.
    Value zero = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, 0));
    Value posOne = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, 1));
    Value negOne = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, -1));

    // Compute `x = m > 0 ? -1 : 1`.
    Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
    Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);

    // Positive result: `(n + x) / m + 1`.
    Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
    Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
    Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);

    // Negative (or zero) result: `-(-n / m)`.
    Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
    Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
    Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);

    // Pick the positive result iff `n` and `m` have the same sign and `n` is
    // non-zero, i.e. the overflow-free equivalent of `n * m > 0`.
    Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
    Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertIndexCeilDivU
//===----------------------------------------------------------------------===//

/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
/// from the equivalent conversion in IndexToLLVM. The `n == 0` guard is what
/// keeps `n - 1` from wrapping around to the maximum unsigned value.
struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value n = adaptor.getLhs();
    Value m = adaptor.getRhs();
    Type nType = n.getType();

    // Constants in the converted integer type.
    Value zero = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, 0));
    Value one = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, 1));

    // Non-zero result: `(n - 1) / m + 1`.
    Value nMinusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
    Value quotient = rewriter.create<spirv::UDivOp>(loc, nMinusOne, m);
    Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);

    // Select 0 when `n == 0`, otherwise the computed quotient.
    Value nIsZero = rewriter.create<spirv::IEqualOp>(loc, n, zero);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, nIsZero, zero, plusOne);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertIndexFloorDivS
//===----------------------------------------------------------------------===//

/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
/// in IndexToLLVM. The `n*m < 0` test is evaluated as
/// `(n < 0) != (m < 0) && n != 0` so that the product can never overflow.
struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value n = adaptor.getLhs();
    Value m = adaptor.getRhs();
    Type nType = n.getType();

    // Constants in the converted integer type.
    Value zero = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, 0));
    Value posOne = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, 1));
    Value negOne = rewriter.create<spirv::ConstantOp>(
        loc, nType, IntegerAttr::get(nType, -1));

    // Compute `x = m < 0 ? 1 : -1`.
    Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
    Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);

    // Negative result: `-1 - (x - n) / m`.
    Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
    Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
    Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);

    // Non-negative result: plain truncating division `n / m`.
    Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);

    // Pick the negative result iff `n` and `m` have different signs and `n` is
    // non-zero, i.e. the overflow-free equivalent of `n * m < 0`.
    Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
    Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, negRes, posRes);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertIndexCast
//===----------------------------------------------------------------------===//

/// Convert a cast op. If the materialized index type is the same as the other
/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
/// zero extend when the result bitwidth is larger.
template <typename CastOp, typename ConvertOp>
struct ConvertIndexCast final : OpConversionPattern<CastOp> {};

ConvertIndexCastS;
ConvertIndexCastU;

//===----------------------------------------------------------------------===//
// ConvertIndexCmp
//===----------------------------------------------------------------------===//

// Helper template to replace the operation: rewrites `op` into the SPIR-V
// integer comparison `ICmpOp` applied to the already-converted operands.
// Always succeeds.
template <typename ICmpOp>
static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
  return success();
}

/// Lower `index.cmp` to the SPIR-V integer comparison matching its predicate.
struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Map each index predicate onto the corresponding signed/unsigned integer
    // comparison op.
    switch (op.getPred()) {
    case IndexCmpPredicate::EQ:
      return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::NE:
      return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::SGE:
      return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::SGT:
      return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::SLE:
      return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::SLT:
      return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::UGE:
      return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::UGT:
      return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::ULE:
      return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
    case IndexCmpPredicate::ULT:
      return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
    }
    llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
  }
};

//===----------------------------------------------------------------------===//
// ConvertIndexSizeOf
//===----------------------------------------------------------------------===//

/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
/// The bitwidth is the one chosen by the SPIR-V type converter for `index`.
struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
    Type indexType = typeConverter->getIndexType();
    unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
        op, indexType, IntegerAttr::get(indexType, bitwidth));
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

void index::populateIndexToSPIRVPatterns(
    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {}

//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//

namespace mlir {
#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// Pass that converts all `index` dialect ops to SPIR-V using the patterns
/// above. The width of the materialized index type is controlled by the
/// pass's `use64bitIndex` option.
struct ConvertIndexToSPIRVPass
    : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
  using Base::Base;

  void runOnOperation() override {
    Operation *op = getOperation();
    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
    std::unique_ptr<SPIRVConversionTarget> target =
        SPIRVConversionTarget::get(targetAttr);

    SPIRVConversionOptions options;
    options.use64bitIndex = this->use64bitIndex;
    SPIRVTypeConverter typeConverter(targetAttr, options);

    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
    // in patterns for other dialects.
    target->addLegalOp<UnrealizedConversionCastOp>();
    // Allow the SPIR-V operations we are converting to.
    target->addLegalDialect<spirv::SPIRVDialect>();
    // Fail hard when there are any remaining 'index' ops.
    target->addIllegalDialect<index::IndexDialect>();

    RewritePatternSet patterns(&getContext());
    index::populateIndexToSPIRVPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace