llvm/mlir/test/python/dialects/transform_vector_ext.py

# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector


def run_apply_patterns(f):
    """Build, fill and print a transform script around the pattern builder ``f``.

    Intended for use as a decorator: ``f`` is invoked immediately, at
    decoration time, with the insertion point placed inside the ``patterns``
    region of a ``transform.ApplyPatternsOp``. That op targets the body
    argument of a propagating ``transform.SequenceOp``. Afterwards a
    ``TEST: <function name>`` banner and the whole module are printed so the
    FileCheck directives in each test can match against them.

    Returns ``f`` unchanged, so the decorated name stays bound to the function.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        with InsertionPoint(test_module.body):
            seq_op = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq_op.body):
                patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                # Everything `f` creates lands inside the apply_patterns region.
                with InsertionPoint(patterns_op.patterns):
                    f()
                # Close the sequence body with a yield after the apply_patterns op.
                transform.YieldOp()
        banner_name = f.__name__
        print("\nTEST:", banner_name)
        print(test_module)
    return f


@run_apply_patterns
def non_configurable_patterns():
    """Emit the vector pattern ops that are built without any arguments.

    The ops are constructed in a fixed order, and the FileCheck directives
    below list their printed names in that same order.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
    # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
    # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
    # CHECK: transform.apply_patterns.vector.lower_broadcast
    # CHECK: transform.apply_patterns.vector.lower_masks
    # CHECK: transform.apply_patterns.vector.lower_masked_transfers
    # CHECK: transform.apply_patterns.vector.materialize_masks
    # CHECK: transform.apply_patterns.vector.lower_outerproduct
    # CHECK: transform.apply_patterns.vector.lower_gather
    # CHECK: transform.apply_patterns.vector.lower_scan
    # CHECK: transform.apply_patterns.vector.lower_shape_cast
    argument_free_ops = (
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        vector.ApplyRankReducingSubviewPatternsOp,
        vector.ApplyTransferPermutationPatternsOp,
        vector.ApplyLowerBroadcastPatternsOp,
        vector.ApplyLowerMasksPatternsOp,
        vector.ApplyLowerMaskedTransfersPatternsOp,
        vector.ApplyMaterializeMasksPatternsOp,
        vector.ApplyLowerOuterProductPatternsOp,
        vector.ApplyLowerGatherPatternsOp,
        vector.ApplyLowerScanPatternsOp,
        vector.ApplyLowerShapeCastPatternsOp,
    )
    # Construction order determines print order, so keep the tuple aligned
    # with the directive list above.
    for build_op in argument_free_ops:
        build_op()


@run_apply_patterns
def configurable_patterns():
    """Emit vector pattern ops configured through integer/boolean attributes.

    Each op is built with explicit keyword options, and the FileCheck
    directives below check that those options appear in the printed op.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    configured_ops = [
        (vector.ApplyLowerTransferPatternsOp, {"max_transfer_rank": 4}),
        (
            vector.ApplyTransferToScfPatternsOp,
            {"max_transfer_rank": 3, "full_unroll": True},
        ),
    ]
    for op_class, options in configured_ops:
        op_class(**options)


@run_apply_patterns
def enum_configurable_patterns():
    """Emit vector pattern ops configured through enum-valued attributes.

    Each op is built once with no arguments and then once per enum value of
    interest. Non-default values must show up in the printed op. For the
    no-argument and default-valued constructions, the negative directives
    assert that no strategy attribute is printed.
    """
    # Label this test's output like the sibling tests do, so FileCheck
    # confines the matches below to this test's module.
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-NOT: lowering_strategy
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-NOT: lowering_strategy
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-NOT: lowering_strategy
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-NOT: lowering_strategy
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-NOT: lowering_strategy
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-NOT: split_transfer_strategy
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-NOT: split_transfer_strategy
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )