llvm/mlir/test/python/dialects/transform_interpreter.py

# RUN: %PYTHON %s | FileCheck %s

from mlir import ir
from mlir.dialects.transform import interpreter as interp


def test_in_context(f):
    """Decorator: run ``f`` right away under a fresh MLIR context and location.

    The context is entered first so that the unknown location is created
    inside it. ``f`` itself is handed back unchanged, so the decorated name
    stays bound to the original function.
    """
    with ir.Context():
        with ir.Location.unknown():
            f()
    return f


# Transform script whose entry point prints its root operation. Individual
# tests substitute the "from interpreter" label to tag their printed output.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = "from interpreter" }: !transform.any_op
    transform.yield
  }
}"""


@test_in_context
def print_self():
    """Apply the printing script with its own module as the payload."""
    source = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(source)
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)


# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield


@test_in_context
def print_other():
    """Apply the printing script to a separate, empty payload module.

    The expected output shows the payload attribute and none of the
    transform script.
    """
    script_text = print_root_module.replace("from interpreter", "print_other")
    transform = ir.Module.parse(script_text)
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    entry_point = transform.body.operations[0]
    interp.apply_named_sequence(payload, entry_point, transform)


# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload


@test_in_context
def transform_options():
    """Run the printing script with explicitly configured TransformOptions."""
    opts = interp.TransformOptions()
    opts.expensive_checks = False
    opts.enforce_single_top_level_transform_op = True

    script_text = print_root_module.replace("from interpreter", "transform_options")
    script = ir.Module.parse(script_text)
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    entry_point = script.body.operations[0]
    interp.apply_named_sequence(payload, entry_point, script, opts)


# CHECK-LABEL: transform_options


@test_in_context
def failed():
    """Using the payload module itself as the transform root must raise.

    The payload contains no transform op, so the interpreter is expected to
    reject it with a ValueError naming TransformOpInterface.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        )
    else:
        # Without this branch the test would silently pass if no error were raised.
        raise AssertionError("expected apply_named_sequence to raise ValueError")


# Entry-point script that only declares @callee1 and @callee2 (private named
# sequences without bodies) and calls @callee2 through transform.include.
# The callee definitions have to be merged in separately before it can run.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.include @callee2 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}"""

# Defines @callee2, which forwards its root to @callee1 via transform.include.
# @callee1 is only declared here, so its body must be provided by another module.
callee2_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
    transform.include @callee1 failures(propagate)
        (%root) : (!transform.any_op) -> ()
    transform.yield
  }
}
"""

# Defines @callee1, which prints its (read-only) root operation.
callee1_definition = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
    transform.print %root { name = "from interpreter" }: !transform.any_op
    transform.yield
  }
}
"""


@test_in_context
def include():
    """Merge both callee definitions into the entry script, then run it."""
    main = ir.Module.parse(print_root_via_include_module)
    # Parse every callee first, then merge them in order (callee1, callee2).
    callees = [
        ir.Module.parse(text) for text in (callee1_definition, callee2_definition)
    ]
    for callee in callees:
        interp.copy_symbols_and_merge_into(main, callee)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    entry_point = main.body.operations[0]
    interp.apply_named_sequence(main, entry_point, main)


@test_in_context
def partial_include():
    """Running with only @callee2 merged in must fail with ValueError.

    @callee1 remains a bodyless declaration in both the entry script and the
    merged module, so the include chain cannot be resolved.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e)
    else:
        # Without this branch the test would silently pass if no error were raised.
        raise AssertionError("expected apply_named_sequence to raise ValueError")


@test_in_context
def repeated_include():
    """Merging the same callee definitions twice must raise ValueError.

    The second merge would define @callee2 a second time.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e)
    else:
        # Without this branch the test would silently pass if no error were raised.
        raise AssertionError(
            "expected copy_symbols_and_merge_into to raise ValueError"
        )