llvm/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp

//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate the
// 'vector.maskedload' and 'vector.maskedstore' operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"

using namespace mlir;

namespace {

/// Convert vector.maskedload
///
/// Emulates a 1-D `vector.maskedload` with one scalar load per lane, each
/// guarded by the corresponding mask bit.
///
/// Before:
///
///   %r = vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
///
/// After (first two lanes shown):
///
///   %m0 = vector.extract %mask[0]
///   %r0 = scf.if %m0 {
///     %v = memref.load %base[%idx_0, %idx_1]
///     %combined = vector.insert %v, %pass_thru[0]
///     scf.yield %combined
///   } else {
///     scf.yield %pass_thru
///   }
///   %idx_1_next = arith.addi %idx_1, %c1
///   %m1 = vector.extract %mask[1]
///   %r1 = scf.if %m1 {
///     %v = memref.load %base[%idx_0, %idx_1_next]
///     %combined = vector.insert %v, %r0[1]
///     scf.yield %combined
///   } else {
///     scf.yield %r0
///   }
///   ...
///
/// Only the innermost index is advanced (by one per lane). All uses of the
/// original op are replaced by the result of the last `scf.if`. Only 1-D masks
/// with a fixed (non-scalable) length are supported, since the lanes are
/// unrolled at compile time.
struct VectorMaskedLoadOpConverter final
    : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskVType = maskedLoadOp.getMaskVectorType();
    if (maskVType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          maskedLoadOp, "expected vector.maskedload with 1-D mask");
    // A scalable mask has no compile-time lane count to unroll over.
    if (maskVType.isScalable())
      return rewriter.notifyMatchFailure(maskedLoadOp,
                                         "cannot unroll a scalable mask");

    auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
    // The innermost index is advanced per lane, so at least one is required.
    if (indices.empty())
      return rewriter.notifyMatchFailure(maskedLoadOp,
                                         "expected at least one index");

    Location loc = maskedLoadOp.getLoc();
    int64_t maskLength = maskVType.getDimSize(0);
    Value mask = maskedLoadOp.getMask();
    Value base = maskedLoadOp.getBase();
    // Running result: starts as the pass-through vector; each active lane
    // overwrites one element with the value loaded from memory.
    Value result = maskedLoadOp.getPassThru();
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    for (int64_t lane = 0; lane < maskLength; ++lane) {
      Value maskBit = rewriter.create<vector::ExtractOp>(loc, mask, lane);
      // The region-builder callbacks run synchronously inside `create`, so
      // they observe the current `result` and `indices` for this lane.
      auto ifOp = rewriter.create<scf::IfOp>(
          loc, maskBit,
          [&](OpBuilder &builder, Location nestedLoc) {
            Value loaded =
                builder.create<memref::LoadOp>(nestedLoc, base, indices);
            Value combined = builder.create<vector::InsertOp>(
                nestedLoc, loaded, result, lane);
            builder.create<scf::YieldOp>(nestedLoc, combined);
          },
          [&](OpBuilder &builder, Location nestedLoc) {
            builder.create<scf::YieldOp>(nestedLoc, result);
          });
      result = ifOp.getResult(0);

      // Step to the next element; skipped after the last lane so no dead
      // `arith.addi` is emitted.
      if (lane + 1 < maskLength)
        indices.back() =
            rewriter.create<arith::AddIOp>(loc, indices.back(), one);
    }

    rewriter.replaceOp(maskedLoadOp, result);
    return success();
  }
};

/// Convert vector.maskedstore
///
/// Emulates a 1-D `vector.maskedstore` with one scalar store per lane, each
/// guarded by the corresponding mask bit.
///
/// Before:
///
///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
///
/// After (first two lanes shown):
///
///   %m0 = vector.extract %mask[0]
///   scf.if %m0 {
///     %extracted = vector.extract %value[0]
///     memref.store %extracted, %base[%idx_0, %idx_1]
///   }
///   %idx_1_next = arith.addi %idx_1, %c1
///   %m1 = vector.extract %mask[1]
///   scf.if %m1 {
///     %extracted = vector.extract %value[1]
///     memref.store %extracted, %base[%idx_0, %idx_1_next]
///   }
///   ...
///
/// Only the innermost index is advanced (by one per lane), and the original op
/// is erased. Only 1-D masks with a fixed (non-scalable) length are supported,
/// since the lanes are unrolled at compile time.
struct VectorMaskedStoreOpConverter final
    : OpRewritePattern<vector::MaskedStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskVType = maskedStoreOp.getMaskVectorType();
    if (maskVType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          maskedStoreOp, "expected vector.maskedstore with 1-D mask");
    // A scalable mask has no compile-time lane count to unroll over.
    if (maskVType.isScalable())
      return rewriter.notifyMatchFailure(maskedStoreOp,
                                         "cannot unroll a scalable mask");

    auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
    // The innermost index is advanced per lane, so at least one is required.
    if (indices.empty())
      return rewriter.notifyMatchFailure(maskedStoreOp,
                                         "expected at least one index");

    Location loc = maskedStoreOp.getLoc();
    int64_t maskLength = maskVType.getDimSize(0);
    Value mask = maskedStoreOp.getMask();
    Value base = maskedStoreOp.getBase();
    Value value = maskedStoreOp.getValueToStore();
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    for (int64_t lane = 0; lane < maskLength; ++lane) {
      Value maskBit = rewriter.create<vector::ExtractOp>(loc, mask, lane);
      // No else region: a cleared mask bit must leave memory untouched.
      rewriter.create<scf::IfOp>(
          loc, maskBit, [&](OpBuilder &builder, Location nestedLoc) {
            Value element =
                builder.create<vector::ExtractOp>(nestedLoc, value, lane);
            builder.create<memref::StoreOp>(nestedLoc, element, base, indices);
            builder.create<scf::YieldOp>(nestedLoc);
          });

      // Step to the next element; skipped after the last lane so no dead
      // `arith.addi` is emitted.
      if (lane + 1 < maskLength)
        indices.back() =
            rewriter.create<arith::AddIOp>(loc, indices.back(), one);
    }

    rewriter.eraseOp(maskedStoreOp);
    return success();
  }
};

} // namespace

/// Adds the patterns that emulate 1-D `vector.maskedload` and
/// `vector.maskedstore` with mask-guarded scalar `memref.load`/`memref.store`
/// ops, registered with the given `benefit`.
void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
      patterns.getContext(), benefit);
}