//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the folders and canonicalization patterns for SPIR-V ops. // //===----------------------------------------------------------------------===// #include <optional> #include <utility> #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" usingnamespacemlir; //===----------------------------------------------------------------------===// // Common utility functions //===----------------------------------------------------------------------===// /// Returns the boolean value under the hood if the given `boolAttr` is a scalar /// or splat vector bool constant. static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) { … } // Extracts an element from the given `composite` by following the given // `indices`. Returns a null Attribute if error happens. static Attribute extractCompositeElement(Attribute composite, ArrayRef<unsigned> indices) { … } static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) { … } //===----------------------------------------------------------------------===// // TableGen'erated canonicalizers //===----------------------------------------------------------------------===// namespace { #include "SPIRVCanonicalization.inc" } // namespace //===----------------------------------------------------------------------===// // spirv.AccessChainOp //===----------------------------------------------------------------------===// namespace { /// Combines chained `spirv::AccessChainOp` operations into one /// `spirv::AccessChainOp` operation. struct CombineChainedAccessChain final : OpRewritePattern<spirv::AccessChainOp> { … }; } // namespace void spirv::AccessChainOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // spirv.IAddCarry //===----------------------------------------------------------------------===// // We are required to use CompositeConstructOp to create a constant struct as // they are not yet implemented as constant, hence we can not do so in a fold. struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> { … }; void spirv::IAddCarryOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // spirv.[S|U]MulExtended //===----------------------------------------------------------------------===// // We are required to use CompositeConstructOp to create a constant struct as // they are not yet implemented as constant, hence we can not do so in a fold. template <typename MulOp, bool IsSigned> struct MulExtendedFold final : OpRewritePattern<MulOp> { … }; SMulExtendedOpFold; void spirv::SMulExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> { … }; UMulExtendedOpFold; void spirv::UMulExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // spirv.UMod //===----------------------------------------------------------------------===// // Input: // %0 = spirv.UMod %arg0, %const32 : i32 // %1 = spirv.UMod %0, %const4 : i32 // Output: // %0 = spirv.UMod %arg0, %const32 : i32 // %1 = spirv.UMod %arg0, %const4 : i32 // The transformation is only applied if one divisor is a multiple of the other. // TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants struct UModSimplification final : OpRewritePattern<spirv::UModOp> { … }; void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // spirv.BitcastOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) { … } //===----------------------------------------------------------------------===// // spirv.CompositeExtractOp //===----------------------------------------------------------------------===// OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.Constant //===----------------------------------------------------------------------===// OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) { … } //===----------------------------------------------------------------------===// // spirv.IAdd //===----------------------------------------------------------------------===// OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.IMul //===----------------------------------------------------------------------===// OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.ISub //===----------------------------------------------------------------------===// OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SDiv //===----------------------------------------------------------------------===// OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SMod //===----------------------------------------------------------------------===// OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SRem //===----------------------------------------------------------------------===// OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.UDiv //===----------------------------------------------------------------------===// OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.UMod //===----------------------------------------------------------------------===// OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SNegate //===----------------------------------------------------------------------===// OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.NotOp //===----------------------------------------------------------------------===// OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.LogicalAnd //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.LogicalEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.LogicalNotEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.LogicalNot //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) { … } void spirv::LogicalNotOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // spirv.LogicalOr //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SelectOp //===----------------------------------------------------------------------===// OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.IEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.INotEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SGreaterThan //===----------------------------------------------------------------------===// OpFoldResult spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SGreaterThanEqual //===----------------------------------------------------------------------===// OpFoldResult spirv::SGreaterThanEqualOp::fold( spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.UGreaterThan //===----------------------------------------------------------------------===// OpFoldResult spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.UGreaterThanEqual //===----------------------------------------------------------------------===// OpFoldResult spirv::UGreaterThanEqualOp::fold( spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SLessThan //===----------------------------------------------------------------------===// OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.SLessThanEqual //===----------------------------------------------------------------------===// OpFoldResult spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.ULessThan //===----------------------------------------------------------------------===// OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.ULessThanEqual //===----------------------------------------------------------------------===// OpFoldResult spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.ShiftLeftLogical //===----------------------------------------------------------------------===// OpFoldResult spirv::ShiftLeftLogicalOp::fold( spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.ShiftRightArithmetic //===----------------------------------------------------------------------===// OpFoldResult spirv::ShiftRightArithmeticOp::fold( spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.ShiftRightLogical //===----------------------------------------------------------------------===// OpFoldResult spirv::ShiftRightLogicalOp::fold( spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.BitwiseAndOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.BitwiseOrOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.BitwiseXorOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // spirv.mlir.selection //===----------------------------------------------------------------------===// namespace { // Blocks from the given `spirv.mlir.selection` operation must satisfy the // following layout: // // +-----------------------------------------------+ // | header block | // | spirv.BranchConditionalOp %cond, ^case0, ^case1 | // +-----------------------------------------------+ // / \ // ... // // // +------------------------+ +------------------------+ // | case #0 | | case #1 | // | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 | // | spirv.Branch ^merge | | spirv.Branch ^merge | // +------------------------+ +------------------------+ // // // ... // \ / // v // +-------------+ // | merge block | // +-------------+ // struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> { … }; LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection( Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { … } } // namespace void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … }