llvm/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp

//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Math dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE

usingnamespacemlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Creates a `spirv.Constant` of `type` in which every element equals `value`.
///
/// `type` must be `i32` or a one-dimensional, fixed-length vector of `i32`.
/// For any other type this returns a null Value so the caller can bail out.
/// That includes other integer widths, non-integer elements, and
/// multi-dimensional or scalable vectors.
static Value getScalarOrVectorI32Constant(Type type, int value,
                                          OpBuilder &builder, Location loc) {
  if (auto vectorType = dyn_cast<VectorType>(type)) {
    // getI32VectorAttr builds a 1-D `vector<N x i32>` attribute. For any other
    // vector shape the attribute type would not match `type`, so reject it.
    if (vectorType.getRank() != 1 || vectorType.isScalable() ||
        !vectorType.getElementType().isInteger(32))
      return nullptr;
    SmallVector<int32_t> splat(vectorType.getNumElements(), value);
    return builder.create<spirv::ConstantOp>(loc, type,
                                             builder.getI32VectorAttr(splat));
  }

  if (type.isInteger(32))
    return builder.create<spirv::ConstantOp>(loc, type,
                                             builder.getI32IntegerAttr(value));

  return nullptr;
}

/// Returns true if `originalType` is supported by the math-to-spirv
/// conversion.
///
/// Supported types are integer, index, or float scalars, and vectors of rank
/// at most one, with a fixed length and such elements. Higher-level types are
/// expected to be lowered before this conversion runs, so they are rejected
/// here. Those include tensors, memrefs, and multi-dimensional or scalable
/// vectors.
static bool isSupportedSourceType(Type originalType) {
  if (originalType.isIntOrIndexOrFloat())
    return true;

  auto vecTy = dyn_cast<VectorType>(originalType);
  if (!vecTy)
    return false;

  // SPIR-V vectors are one-dimensional with a fixed component count.
  return vecTy.getElementType().isIntOrIndexOrFloat() &&
         !vecTy.isScalable() && vecTy.getRank() <= 1;
}

/// Checks that every operand and result type of `sourceOp` is accepted by
/// `isSupportedSourceType`.
///
/// On the first unsupported type, this reports a match failure naming that
/// type through `rewriter` and returns `failure()`. Otherwise it returns
/// `success()`. It is intended to simplify type checks in
/// `OpConversionPattern`s.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
                                        Operation *sourceOp) {
  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
  llvm::append_range(allTypes, sourceOp->getResultTypes());

  for (Type ty : allTypes) {
    if (isSupportedSourceType(ty))
      continue;
    return rewriter.notifyMatchFailure(
        sourceOp,
        llvm::formatv(
            "unsupported source type for Math to SPIR-V conversion: {0}", ty));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// only generates plain RewritePatterns.

namespace {
/// Converts elementwise unary, binary, and ternary standard operations to
/// SPIR-V operations.
///
/// Before rewriting, it verifies with `checkSourceOpTypes` that every operand
/// and result type of the source `Op` is supported. The rewrite itself is then
/// delegated unchanged to `spirv::ElementwiseOpPattern`.
template <typename Op, typename SPIRVOp>
struct CheckedElementwiseOpPattern final
    : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
  using BasePattern = spirv::ElementwiseOpPattern<Op, SPIRVOp>;
  // Reuse the base constructors, so this pattern is constructed with the same
  // arguments as the unchecked ElementwiseOpPattern.
  using BasePattern::BasePattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
      return res;

    return BasePattern::matchAndRewrite(op, adaptor, rewriter);
  }
};

/// Converts math.copysign to SPIR-V ops.
///
/// NOTE(review): the struct body is empty in this copy. It declares no
/// constructors and no `matchAndRewrite` override of its own. Presumably the
/// implementation was elided, so confirm before relying on or registering
/// this pattern.
struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {};

/// Converts math.ctlz to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for counting leading zeros. If the
/// Shader capability is supported, GL FindUMsb can be used to calculate it.
///
/// NOTE(review): the struct body is empty in this copy. It declares no
/// constructors and no `matchAndRewrite` override of its own. Presumably the
/// implementation was elided, so confirm before relying on or registering
/// this pattern.
struct CountLeadingZerosPattern final
    : public OpConversionPattern<math::CountLeadingZerosOp> {};

/// Converts math.expm1 to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for exp(x)-1, so this lowers it
/// explicitly to `ExpOp(x) - 1`. `ExpOp` is the SPIR-V exponential op chosen
/// by the template argument.
///
/// NOTE: this naive expansion loses precision for |x| near 0, which is the
/// accuracy that a dedicated expm1 exists to provide.
template <typename ExpOp>
struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
        failed(res))
      return res;

    Type type = this->getTypeConverter()->convertType(operation.getType());
    if (!type)
      return rewriter.notifyMatchFailure(operation, "type conversion failed");

    Location loc = operation.getLoc();
    Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
    Value one = spirv::ConstantOp::getOne(type, loc, rewriter);
    rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, type, exp, one);
    return success();
  }
};

/// Converts math.log1p to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for log(1+x), so this lowers it
/// explicitly to `LogOp(1 + x)`. `LogOp` is the SPIR-V natural-logarithm op
/// chosen by the template argument.
///
/// NOTE: this naive expansion loses precision for |x| near 0, which is the
/// accuracy that a dedicated log1p exists to provide.
template <typename LogOp>
struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
        failed(res))
      return res;

    Type type = this->getTypeConverter()->convertType(operation.getType());
    if (!type)
      return rewriter.notifyMatchFailure(operation, "type conversion failed");

    Location loc = operation.getLoc();
    Value one = spirv::ConstantOp::getOne(type, loc, rewriter);
    Value onePlus =
        rewriter.create<spirv::FAddOp>(loc, type, one, adaptor.getOperand());
    rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
    return success();
  }
};

/// Converts math.log2 and math.log10 to SPIR-V ops.
///
/// SPIR-V does not have direct operations for log2 and log10, so they are
/// lowered explicitly using:
///   log2(x) = log(x) * 1/log(2)
///   log10(x) = log(x) * 1/log(10)
///
/// `MathLogOp` is the math dialect op being converted (presumably
/// math::Log2Op or math::Log10Op). `SpirvLogOp` is presumably the SPIR-V
/// natural-logarithm op used to compute log(x). Confirm both at the
/// registration site.
///
/// NOTE(review): the struct body is empty in this copy. It declares no
/// constructors and no `matchAndRewrite` override of its own. Presumably the
/// implementation was elided, so confirm before relying on or registering
/// this pattern.
template <typename MathLogOp, typename SpirvLogOp>
struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {};

/// Converts math.powf to SPIR-V ops.
///
/// NOTE(review): the struct body is empty in this copy. It declares no
/// constructors and no `matchAndRewrite` override of its own. Presumably the
/// implementation was elided, so confirm before relying on or registering
/// this pattern.
struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {};

/// Converts math.round to GLSL SPIR-V extended ops.
///
/// NOTE(review): the struct body is empty in this copy. It declares no
/// constructors and no `matchAndRewrite` override of its own. Presumably the
/// implementation was elided, so confirm before relying on or registering
/// this pattern.
struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
/// Entry point for registering the Math-to-SPIR-V conversion patterns
/// (presumably into `patterns`, converting types with `typeConverter`).
///
/// NOTE(review): the body is empty in this copy. As written it adds nothing to
/// `patterns` and never reads `typeConverter`. Confirm whether the pattern
/// registrations were elided.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns) {}

} // namespace mlir