llvm/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp

//===- SPIRVWebGPUTransforms.cpp - WebGPU-specific transforms -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements SPIR-V transforms used when targeting WebGPU.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include <array>
#include <cstdint>

namespace mlir {
namespace spirv {
// Select the definition of the SPIRVWebGPUPreparePass from the generated
// Passes.h.inc; this provides impl::SPIRVWebGPUPreparePassBase, which
// WebGPUPreparePass derives from further down in this file.
#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

namespace mlir {
namespace spirv {
namespace {
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
static Attribute getScalarOrSplatAttr(Type type, int64_t value) {}

static Value lowerExtendedMultiplication(Operation *mulOp,
                                         PatternRewriter &rewriter, Value lhs,
                                         Value rhs, bool signExtendArguments) {}

//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//

template <typename MulExtendedOp, bool SignExtendArguments>
struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {};

ExpandSMulExtendedPattern;
ExpandUMulExtendedPattern;

struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {};

struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {};

struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {};

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
struct WebGPUPreparePass final
    : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {};
} // namespace

//===----------------------------------------------------------------------===//
// Public Interface
//===----------------------------------------------------------------------===//
void populateSPIRVExpandExtendedMultiplicationPatterns(
    RewritePatternSet &patterns) {}

void populateSPIRVExpandNonFiniteArithmeticPatterns(
    RewritePatternSet &patterns) {}

} // namespace spirv
} // namespace mlir