llvm/mlir/test/python/dialects/transform_extras.py

# RUN: %PYTHON %s | FileCheck %s

from typing import Callable
from mlir import ir
from mlir.dialects import scf, pdl
from mlir.dialects.transform import (
    structured,
    get_parent_op,
    apply_patterns_canonicalization,
    apply_cse,
    any_op_t,
)
from mlir.dialects.transform import FailurePropagationMode
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import (
    constant_param,
    OpHandle,
    insert_transform_script,
    sequence,
    apply_patterns,
)
from mlir.extras import types as T


def construct_and_print_in_module(f):
    """Decorator: build IR with ``f`` inside a fresh module and print the module.

    ``f`` is invoked once, at decoration time, with no arguments. It runs inside
    a new ``ir.Context`` with an unknown default location, and with the insertion
    point set to the body of a freshly created ``ir.Module``. A ``TEST:`` header
    naming ``f`` is printed first, then the populated module.

    Returns ``f`` unchanged so the decorated name stays bound to the function.
    """
    print(f"\nTEST: {f.__name__}")
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        body_ip = ir.InsertionPoint(module.body)
        with body_ip:
            f()
        print(module)
    return f


def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: build, dump and verify a transform script in a fresh module.

    Prints a ``TEST:`` header naming ``script``. Then, inside a new
    ``ir.Context`` with an unknown default location, it:

    1. creates an empty module;
    2. tags the module with the ``transform.with_named_sequence`` unit attribute;
    3. hands the module body to ``insert_transform_script`` with
       ``dump_script=True``, so the generated script is emitted for the
       expected-output comparison;
    4. calls ``verify()`` on the module operation.

    Returns ``script`` unchanged, matching ``construct_and_print_in_module``.
    The decorated name therefore stays bound to the test function instead of
    silently becoming ``None``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        # Named sequences are only accepted inside a module carrying this attr.
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(module.body, script=script, dump_script=True)
        module.operation.verify()
    return script


def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: like ``build_transform_script``, but via an ``InsertionPoint``.

    The only difference is the first argument to ``insert_transform_script``.
    Instead of the module body block, this passes
    ``ir.InsertionPoint.at_block_begin(module.body)``, which exercises the
    insertion-point overload of that helper.

    As with the block-based variant, it prints a ``TEST:`` header, marks the
    module with ``transform.with_named_sequence``, dumps the script
    (``dump_script=True``) and verifies the module.

    Returns ``script`` unchanged (consistent with
    ``construct_and_print_in_module``), so the decorated name is not replaced
    by ``None``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(
            ir.InsertionPoint.at_block_begin(module.body),
            script=script,
            dump_script=True,
        )
        module.operation.verify()
    return script


# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    """Empty script built through the insertion-point entry point.

    Nothing is added to the sequence. The expected output is a
    ``transform.named_sequence`` whose body is only ``transform.yield``.
    """
    pass
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }


# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    """``constant_param`` with an explicit 32-bit ``IntegerAttr``.

    The expected output is ``transform.param.constant 42 : i32``, with the
    parameter typed ``!transform.param<i32>``.
    """
    i32_attr = ir.IntegerAttr.get(T.i32(), 42)
    constant_param(i32_attr)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME:   !transform.param<i32>


# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    """``constant_param`` with a bare Python ``int``.

    Per the expected output, the value is materialized as ``42 : i64``, with
    the parameter typed ``!transform.param<i64>``.
    """
    constant_param(42)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME:   !transform.param<i64>


# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    """``constant_param`` with a ``SymbolRefAttr`` pointing at ``@symbol``.

    The expected output types this non-integer parameter as
    ``!transform.any_param``.
    """
    symbol_ref = ir.SymbolRefAttr.get(["symbol"])
    constant_param(symbol_ref)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME:   !transform.any_param


# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    """``constant_param`` with a ``TypeAttr`` wrapping ``i32``.

    The expected output is ``transform.param.constant i32``, typed
    ``!transform.any_param``.
    """
    i32_type_attr = ir.TypeAttr.get(T.i32())
    constant_param(i32_type_attr)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME:   !transform.any_param


# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    """Round-trip from an op handle to a value handle and back.

    ``get_result()`` with no index produces ``transform.get_result ...[0]``
    (a ``!transform.any_value``). ``get_defining_op()`` on that value then
    emits ``transform.get_defining_op``.
    """
    first_result = op.get_result()
    first_result.get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME:   !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]


# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    """``get_result()`` without an index targets result 0 of the root handle."""
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]


# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    """``match_ops`` with a single op class (``scf.ForOp``).

    Per the expected output, the match is typed with the specific
    ``!transform.op<"scf.for">`` rather than the generic op handle type.
    """
    op.match_ops(scf.ForOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME:    in %[[VAL_0]]
    # CHECK-SAME:      -> !transform.op<"scf.for">


# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    """``match_ops`` with the string ``"linalg.matmul"``.

    The expected output treats it as an op name: ``ops{["linalg.matmul"]}``.
    """
    op.match_ops("linalg.matmul")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    """``match_ops`` with the string ``"LinalgOp"``.

    The expected output treats it as an interface: ``interface{LinalgOp}``.
    """
    op.match_ops("LinalgOp")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    """``match_ops`` with the ``MatchInterfaceEnum.LinalgOp`` enum member.

    This is expected to produce the same ``interface{LinalgOp}`` match as the
    string spelling ``"LinalgOp"``.
    """
    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]


# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    """``match_ops`` with a list of two op classes.

    Both names land in a single ``ops{...}`` list, and the result uses the
    generic ``!transform.any_op`` type.
    """
    loop_kinds = [scf.ForOp, scf.ForallOp]
    op.match_ops(loop_kinds)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME:     -> !transform.any_op


# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    """``match_ops`` with op classes and an op-name string in one list.

    The emitted names keep the list order, and the result is
    ``!transform.any_op``.
    """
    mixed_targets = [scf.ForOp, "linalg.matmul", scf.ForallOp]
    op.match_ops(mixed_targets)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME:     -> !transform.any_op


# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    """``print`` with a message argument.

    The expected output is ``transform.print`` on the root handle, carrying
    ``{name = "message"}``.
    """
    op.print("message")
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}


# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    """``print`` without arguments emits a bare ``transform.print`` on the root."""
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]


# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    """Build a propagate-mode ``transform.sequence`` using the raw op builders.

    The script matches ``arith.addi`` ops and walks up to their enclosing
    ``scf.for``. It then unrolls that loop by a factor of 4.
    """
    # CHECK:   transform.sequence  failures(propagate) {
    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK:     transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK:   }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        # Handle to the arith.addi ops found under the sequence's block argument.
        addi_ops = structured_match(any_op_t(), target, ops=["arith.addi"])
        # The parent-loop handle is deliberately typed as !pdl.operation.
        parent_loop = get_parent_op(pdl.op_t(), addi_ops, op_name="scf.for")
        loop_unroll(parent_loop, 4)


# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    """Canonicalize, then CSE, inside a propagate-mode ``transform.sequence``.

    Canonicalization patterns are applied to the ``func.func`` that encloses
    the matched ``linalg.matmul``. CSE is then applied to a separate match of
    all ``func.func`` ops.
    """
    # CHECK:   transform.sequence  failures(propagate) {
    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK:     apply_patterns to %[[VAL_2]] {
    # CHECK:       transform.apply_patterns.canonicalization
    # CHECK:     } : !pdl.operation
    # CHECK:     %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK:   }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul_ops = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        enclosing_func = get_parent_op(pdl.op_t(), matmul_ops, op_name="func.func")

        # Canonicalization patterns scoped to the function holding the matmul.
        @apply_patterns(enclosing_func)
        def pats():
            apply_patterns_canonicalization()

        # A fresh, independent handle over every func.func receives the CSE.
        func_ops = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(func_ops)