//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::arm_sve { #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc" } // namespace mlir::arm_sve using namespace mlir; using namespace mlir::arm_sve; // A tag to mark unrealized_conversions produced by this pass. This is used to // detect IR this pass failed to completely legalize, and report an error. // If everything was successfully legalized, no tagged ops will remain after // this pass. constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__"); /// Definitions: /// /// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE /// predicate registers. A full predicate is the smallest quantity that can be /// loaded/stored. /// /// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing /// dimension matches a legal SVE vector size (such as /// vector<[4]xi1>), but is too small to be stored to memory (i.e. smaller than /// an svbool). namespace { /// Checks if a vector type is an SVE mask [2]. bool isSVEMaskType(VectorType type) { … } VectorType widenScalableMaskTypeToSvbool(VectorType type) { … } /// A helper for cloning an op and replacing it with a new version, updated by a /// callback. 
template <typename TOp, typename TLegalizerCallback> void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op, TLegalizerCallback callback) { … } /// A helper for cloning an op and replacing it with a new version, updated by a /// callback, and an unrealized conversion back to the type of the replaced op. template <typename TOp, typename TLegalizerCallback> void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op, TLegalizerCallback callback) { … } /// Extracts the widened SVE memref value (that's legal to store/load) from the /// `unrealized_conversion_cast`s added by this pass. static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) { … } /// The default alignment of an alloca in LLVM may request overaligned sizes for /// SVE types, which will fail during stack frame allocation. This rewrite /// explicitly adds a reasonable alignment to allocas of scalable types. struct RelaxScalableVectorAllocaAlignment : public OpRewritePattern<memref::AllocaOp> { … }; /// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_ /// to load/store) with a wider allocation of svbool (_legal_ to load/store) /// followed by a tagged unrealized conversion to the original type. /// /// Example: /// ``` /// %alloca = memref.alloca() : memref<vector<[4]xi1>> /// ``` /// is rewritten into: /// ``` /// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>> /// %alloca = builtin.unrealized_conversion_cast %widened /// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>> /// {__arm_sve_legalize_vector_storage__} /// ``` template <typename AllocLikeOp> struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> { … }; /// Replaces vector.type_casts of unrealized conversions to SVE predicate memref /// types that are _illegal_ to load/store from (!= svbool [1]), with type casts /// of memref types that are _legal_ to load/store, followed by unrealized /// conversions. 
/// /// Example: /// ``` /// %alloca = builtin.unrealized_conversion_cast %widened /// : memref<vector<3x[16]xi1>> to memref<vector<3x[8]xi1>> /// {__arm_sve_legalize_vector_storage__} /// %cast = vector.type_cast %alloca /// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>> /// ``` /// is rewritten into: /// ``` /// %widened_cast = vector.type_cast %widened /// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>> /// %cast = builtin.unrealized_conversion_cast %widened_cast /// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>> /// {__arm_sve_legalize_vector_storage__} /// ``` struct LegalizeSVEMaskTypeCastConversion : public OpRewritePattern<vector::TypeCastOp> { … }; /// Replaces stores to unrealized conversions to SVE predicate memref types that /// are _illegal_ to load/store from (!= svbool [1]), with /// `arm_sve.convert_to_svbool`s followed by (legal) wider stores. /// /// Example: /// ``` /// memref.store %mask, %alloca[] : memref<vector<[8]xi1>> /// ``` /// is rewritten into: /// ``` /// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1> /// memref.store %svbool, %widened[] : memref<vector<[16]xi1>> /// ``` struct LegalizeSVEMaskStoreConversion : public OpRewritePattern<memref::StoreOp> { … }; /// Replaces loads from unrealized conversions to SVE predicate memref types /// that are _illegal_ to load/store from (!= svbool [1]), with (legal) /// wider loads, followed by `arm_sve.convert_from_svbool`s. 
/// /// Example: /// ``` /// %reload = memref.load %alloca[] : memref<vector<[4]xi1>> /// ``` /// is rewritten into: /// ``` /// %svbool = memref.load %widened[] : memref<vector<[16]xi1>> /// %reload = arm_sve.convert_from_svbool %svbool : vector<[4]xi1> /// ``` struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> { … }; } // namespace void mlir::arm_sve::populateLegalizeVectorStoragePatterns( RewritePatternSet &patterns) { … } namespace { struct LegalizeVectorStorage : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> { … }; } // namespace std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() { … }