# RUN: %PYTHON %s | FileCheck %s
from mlir import ir
from mlir.dialects.transform import interpreter as interp
def test_in_context(f):
    """Decorator that immediately runs ``f`` under a fresh MLIR context.

    ``f`` is called once, without arguments, while a newly created
    ``ir.Context`` and an unknown ``ir.Location`` are active. The function
    object itself is returned unchanged.
    """
    with ir.Context():
        with ir.Location.unknown():
            f()
    return f
# Transform module whose @__transform_main entry point prints its %root
# argument under the name "from interpreter". Tests in this file replace that
# name with their own before parsing, so each printout is distinguishable.
print_root_module = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}"""
@test_in_context
def print_self():
    """Apply the printing sequence using the transform module as its own payload."""
    source = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(source)
    # The first operation in the module body is the named sequence to run.
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)
# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield
@test_in_context
def print_other():
    """Apply the printing sequence to a payload module distinct from the transform."""
    transform_text = print_root_module.replace("from interpreter", "print_other")
    transform_module = ir.Module.parse(transform_text)
    payload_module = ir.Module.parse("module attributes { this.is.payload } {}")
    # The first operation in the transform module body is the sequence to run.
    interp.apply_named_sequence(
        payload_module, transform_module.body.operations[0], transform_module
    )
# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload
@test_in_context
def transform_options():
    """Apply the printing sequence while passing an explicit options object."""
    opts = interp.TransformOptions()
    opts.expensive_checks = False
    opts.enforce_single_top_level_transform_op = True

    transform_text = print_root_module.replace(
        "from interpreter", "transform_options"
    )
    transform_module = ir.Module.parse(transform_text)
    payload_module = ir.Module.parse("module attributes { this.is.payload } {}")
    # The first operation in the transform module body is the sequence to run.
    entry_point = transform_module.body.operations[0]
    interp.apply_named_sequence(payload_module, entry_point, transform_module, opts)
# CHECK-LABEL: transform_options
@test_in_context
def failed():
    """Verify that a plain module is rejected when used as the transform root.

    The same payload module is passed as payload root, transform root and
    transform module. ``apply_named_sequence`` is expected to raise
    ``ValueError`` whose message mentions the missing TransformOpInterface.

    Raises:
        AssertionError: if no ``ValueError`` is raised, or if its message does
            not contain the expected text.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        )
    else:
        # Without this branch the test would pass vacuously if nothing is raised.
        raise AssertionError("expected apply_named_sequence to raise ValueError")
# Transform module named @print_root_via_include_module. It only declares
# @callee1 and @callee2 (no bodies), and its @__transform_main entry point
# includes @callee2 on %root.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.include @callee2 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}"""
# Module providing the body of @callee2, which includes @callee1 on %root.
# @callee1 itself is only declared here, not defined.
callee2_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
transform.include @callee1 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}
"""
# Module providing the body of @callee1, which prints %root under the name
# "from interpreter".
callee1_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}
"""
@test_in_context
def include():
    """Merge both callee definitions into the main module, then apply it to itself."""
    main = ir.Module.parse(print_root_via_include_module)
    # Parse every definition first, then merge them in the same order.
    definitions = [
        ir.Module.parse(text) for text in (callee1_definition, callee2_definition)
    ]
    for definition in definitions:
        interp.copy_symbols_and_merge_into(main, definition)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    entry_point = main.body.operations[0]
    interp.apply_named_sequence(main, entry_point, main)
@test_in_context
def partial_include():
    """Verify that applying fails when only @callee2 has been merged in.

    Only ``callee2_definition`` is merged into the main module, so @callee1
    remains a bodiless declaration. ``apply_named_sequence`` is expected to
    raise ``ValueError`` whose message contains "Failed to apply".

    Raises:
        AssertionError: if no ``ValueError`` is raised, or if its message does
            not contain the expected text.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)
    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e)
    else:
        # Without this branch the test would pass vacuously if nothing is raised.
        raise AssertionError("expected apply_named_sequence to raise ValueError")
@test_in_context
def repeated_include():
    """Verify that merging the same definition twice is rejected.

    ``callee2_definition`` is merged into the main module once, which must
    succeed. The second merge of the same module is expected to raise
    ``ValueError`` mentioning the doubly defined symbol @callee2.

    Raises:
        AssertionError: if the second merge raises no ``ValueError``, or if its
            message does not contain the expected text.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)
    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e)
    else:
        # Without this branch the test would pass vacuously if nothing is raised.
        raise AssertionError(
            "expected the second copy_symbols_and_merge_into to raise ValueError"
        )