// llvm/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"

#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <memory>
#include <optional>
#include <string>

// Instantiate the TableGen-generated pass base class used below as
// `impl::ConvertArithToSPIRVBase`.
namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTOSPIRV
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

// Tag for LLVM_DEBUG / `-debug-only=` output from this file. It must expand to
// a string literal: an empty definition breaks every LLVM_DEBUG use in
// assertion-enabled builds.
#define DEBUG_TYPE "arith-to-spirv-pattern"

// The rest of the file refers to MLIR types (Attribute, Type, Value, ...)
// unqualified.
using namespace mlir;

//===----------------------------------------------------------------------===//
// Conversion Helpers
//===----------------------------------------------------------------------===//

/// Converts the given `srcAttr` into a boolean attribute if it holds an
/// integral value. Returns null attribute if conversion fails.
///
/// A `BoolAttr` is returned as-is. An `IntegerAttr` maps to `true` iff its
/// value is non-zero. Any other attribute kind yields a null `BoolAttr`.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
    return boolAttr;
  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
    return builder.getBoolAttr(intAttr.getValue().getBoolValue());
  return {};
}

/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
/// Returns null attribute if conversion fails.
///
/// The conversion succeeds when the source value fits in `dstType`'s width,
/// read either as an unsigned or as a signed integer. Signless integers carry
/// no signedness of their own; the consuming op decides how to interpret them.
/// Accepting the signed reading is therefore a deliberate reinterpretation when
/// the bitwidth changes.
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
                                      Builder builder) {
  APInt value = srcAttr.getValue();
  unsigned dstWidth = dstType.getWidth();

  // Neither reading fits: the value cannot be represented in `dstType`.
  if (!value.isIntN(dstWidth) && !value.isSignedIntN(dstWidth))
    return {};

  // Re-materialize the (64-bit) integer value at the destination width.
  return builder.getIntegerAttr(dstType, srcAttr.getInt());
}

/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
///
/// The conversion is rejected unless it is exact: any overflow, underflow or
/// rounding reported by APFloat produces a null attribute.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
                                  Builder builder) {
  // Only converting to f32 is supported for now.
  if (!dstType.isF32())
    return {};

  APFloat dstVal = srcAttr.getValue();
  bool losesInfo = false;
  APFloat::opStatus status =
      dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
  if (status != APFloat::opOK || losesInfo)
    return {};

  return builder.getF32FloatAttr(dstVal.convertToFloat());
}

/// Returns true if the given `type` is a boolean scalar or vector type, i.e.
/// `i1` or a vector whose element type is `i1`.
static bool isBoolScalarOrVector(Type type) {
  assert(type && "Not a valid type");
  if (type.isInteger(1))
    return true;
  if (auto vecType = dyn_cast<VectorType>(type))
    return vecType.getElementType().isInteger(1);
  return false;
}

/// Creates a scalar/vector integer constant.
///
/// For a vector `type`, emits a splat `spirv.Constant` with every element set
/// to `value`. For an integer `type`, emits a scalar `spirv.Constant`. Returns
/// a null Value for any other type.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                       OpBuilder &builder, Location loc) {
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
    auto attr = SplatElementsAttr::get(vectorType, element);
    return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
  }

  if (isa<IntegerType>(type))
    return builder.create<spirv::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, value));

  return nullptr;
}

/// Returns true if scalar/vector type `a` and `b` have the same total number
/// of bits.
///
/// A scalar int/float contributes its bitwidth. A vector contributes element
/// bitwidth times element count. Any other type counts as 0 bits and never
/// compares equal.
static bool hasSameBitwidth(Type a, Type b) {
  auto getTotalBitwidth = [](Type type) -> unsigned {
    if (type.isIntOrFloat())
      return type.getIntOrFloatBitWidth();
    if (auto vecType = dyn_cast<VectorType>(type))
      return vecType.getElementTypeBitWidth() * vecType.getNumElements();
    return 0;
  };
  unsigned aBitwidth = getTotalBitwidth(a);
  unsigned bBitwidth = getTotalBitwidth(b);
  return aBitwidth != 0 && aBitwidth == bBitwidth;
}

/// Returns a source type conversion failure for `srcType` and operation `op`.
///
/// The failure is reported via `rewriter.notifyMatchFailure` at `op`'s
/// location, with a message that names `srcType`.
static LogicalResult
getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
                         Type srcType) {
  return rewriter.notifyMatchFailure(
      op->getLoc(),
      llvm::formatv("failed to convert source type '{0}'", srcType));
}

/// Returns a source type conversion failure for the result type of `op`.
///
/// `op` is expected to have exactly one result (asserted).
static LogicalResult
getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
  assert(op->getNumResults() == 1 && "expected a single-result op");
  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
}

/// Returns the snake_case spelling of the SPIR-V decoration `decor`, e.g.
/// `NoSignedWrap` -> `no_signed_wrap`.
// TODO: Move to some common place?
static std::string getDecorationString(spirv::Decoration decor) {
  return llvm::convertToSnakeFromCamelCase(spirv::stringifyDecoration(decor));
}

namespace {

// NOTE(review): in this view every conversion pattern below is declared with
// an empty body. None declares a constructor or overrides `matchAndRewrite`,
// so they add no conversion logic of their own. Confirm the bodies were not
// elided before relying on these patterns.

/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
/// operations. Op can potentially support overflow flags.
template <typename Op, typename SPIRVOp>
struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {};

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

/// Converts composite arith.constant operation to spirv.Constant.
struct ConstantCompositeOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {};

/// Converts scalar arith.constant operation to spirv.Constant.
struct ConstantScalarOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {};

//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//

/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
/// the sign of `signOperand`.
///
/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
/// the result is undefined."  So we cannot directly use spirv.SRem/spirv.SMod
/// if either operand can be negative. Emulate it via spirv.UMod.
///
/// `SignedAbsOp` is the signed absolute-value op to emit (GL or CL flavor).
/// `signOperand` must be either `lhs` or `rhs`.
template <typename SignedAbsOp>
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
                                    Value signOperand, OpBuilder &builder) {
  assert(lhs.getType() == rhs.getType() && "operand types must match");
  assert((lhs == signOperand || rhs == signOperand) &&
         "signOperand must be one of the operands");

  Type type = lhs.getType();

  // |lhs| umod |rhs| is well defined for every input sign.
  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
  Value absRem = builder.create<spirv::UModOp>(loc, type, lhsAbs, rhsAbs);

  // The sign operand is non-negative iff it equals its own absolute value.
  Value signAbs = lhs == signOperand ? lhsAbs : rhsAbs;
  Value isNonNegative =
      builder.create<spirv::IEqualOp>(loc, signOperand, signAbs);
  Value negatedRem = builder.create<spirv::SNegateOp>(loc, type, absRem);
  return builder.create<spirv::SelectOp>(loc, type, isNonNegative, absRem,
                                         negatedRem);
}

/// Converts arith.remsi to GLSL SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spirv.SRem and spirv.SMod.
struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {};

/// Converts arith.remsi to OpenCL SPIR-V ops.
struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {};

//===----------------------------------------------------------------------===//
// BitwiseOp
//===----------------------------------------------------------------------===//

/// Converts bitwise operations to SPIR-V operations. This is a special pattern
/// rather than the generic ElementwiseArithOpPattern because if the operands
/// are boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`).
/// For non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
struct BitwiseOpPattern final : public OpConversionPattern<Op> {};

//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//

/// Converts arith.xori to SPIR-V operations.
/// NOTE(review): presumably the non-boolean counterpart of
/// XOrIOpBooleanPattern; confirm.
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {};

/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
/// vector of i1.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {};

//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//

/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
/// of i1.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {};

//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//

/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
/// of i1.
struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {};

/// Converts arith.extsi if the type of source is neither i1 nor vector of i1.
/// NOTE(review): the original comment said "to spirv.Select", which looks
/// copy-pasted from ExtSII1Pattern; confirm the actual target op.
struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {};

//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//

/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
/// of i1.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {};

/// Converts arith.extui for cases where the type of source is neither i1 nor
/// vector of i1.
struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {};

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//

/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
/// of i1.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {};

/// Converts arith.trunci for cases where the type of result is neither i1
/// nor vector of i1.
struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {};

//===----------------------------------------------------------------------===//
// TypeCastingOp
//===----------------------------------------------------------------------===//

/// Maps an arith rounding mode to the equivalent SPIR-V FPRoundingMode.
/// Returns std::nullopt when SPIR-V has no equivalent mode.
static std::optional<spirv::FPRoundingMode>
convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
  switch (roundingMode) {
  case arith::RoundingMode::downward:
    return spirv::FPRoundingMode::RTN;
  case arith::RoundingMode::to_nearest_even:
    return spirv::FPRoundingMode::RTE;
  case arith::RoundingMode::toward_zero:
    return spirv::FPRoundingMode::RTZ;
  case arith::RoundingMode::upward:
    return spirv::FPRoundingMode::RTP;
  case arith::RoundingMode::to_nearest_away:
    // SPIR-V's FPRoundingMode has no round-to-nearest, ties-away-from-zero
    // mode.
    return std::nullopt;
  }
  llvm_unreachable("Unhandled rounding mode");
}

/// Converts type-casting standard operations to SPIR-V operations.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {};

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

/// Converts integer compare operation on i1 type operands to SPIR-V ops.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {};

/// Converts integer compare operation to SPIR-V ops.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {};

//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//

/// Converts floating-point comparison operations to SPIR-V ops.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {};

/// Converts floating point NaN check to SPIR-V ops. This pattern requires
/// Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {};

/// Converts floating point NaN check to SPIR-V ops. This pattern does not
/// require additional capability.
class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {};

//===----------------------------------------------------------------------===//
// AddUIExtendedOp
//===----------------------------------------------------------------------===//

/// Converts arith.addui_extended to spirv.IAddCarry.
class AddUIExtendedOpPattern final
    : public OpConversionPattern<arith::AddUIExtendedOp> {};

//===----------------------------------------------------------------------===//
// MulIExtendedOp
//===----------------------------------------------------------------------===//

/// Converts arith.mul*i_extended to spirv.*MulExtended.
template <typename ArithMulOp, typename SPIRVMulOp>
class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {};

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

/// Converts arith.select to spirv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {};

//===----------------------------------------------------------------------===//
// MinimumFOp, MaximumFOp
//===----------------------------------------------------------------------===//

/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
/// spirv.CL.fmax/fmin.
template <typename Op, typename SPIRVOp>
class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {};

//===----------------------------------------------------------------------===//
// MinNumFOp, MaxNumFOp
//===----------------------------------------------------------------------===//

/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
/// spirv.CL.fmax/fmin.
template <typename Op, typename SPIRVOp>
class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {};

} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

/// Populates `patterns` with the Arith-to-SPIR-V conversion patterns (per the
/// function name), with `typeConverter` presumably passed to each pattern.
///
/// NOTE(review): the body is empty in this view, so nothing is added to
/// `patterns`. Confirm the pattern list was not elided.
void mlir::arith::populateArithToSPIRVPatterns(
    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// Pass type for the Arith-to-SPIR-V conversion. Its base,
/// `impl::ConvertArithToSPIRVBase`, is the TableGen-generated class enabled by
/// GEN_PASS_DEF_CONVERTARITHTOSPIRV near the top of this file.
///
/// NOTE(review): no `runOnOperation` override is visible here. Confirm the
/// body was not elided; without one the pass does no work.
struct ConvertArithToSPIRVPass
    : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {};
} // namespace

/// Factory for the Arith-to-SPIR-V pass, going by its name and return type.
///
/// NOTE(review): this function has a non-void return type but no `return`
/// statement, so calling it is undefined behavior. The intended body is
/// presumably `return std::make_unique<ConvertArithToSPIRVPass>();`. That also
/// requires ConvertArithToSPIRVPass to override `runOnOperation`.
std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {}