# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.dialects._ods_common import _cext
def run(f):
    """Decorator that executes the test function ``f`` immediately.

    Prints a ``TEST: <name>`` banner for FileCheck labels to anchor on, calls
    ``f``, forces a garbage-collection pass, and asserts that no ``Context``
    object is still alive afterwards (i.e. the test leaked nothing). The
    function object is returned unchanged so the module-level name still
    refers to it.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect first so that contexts kept alive only by cycles are released
    # before counting.
    gc.collect()
    assert Context._get_live_count() == 0
    return f
def expect_index_error(callback):
    """Invoke ``callback`` and require that it raises ``IndexError``.

    Returns ``None`` when the ``IndexError`` occurs. Raises
    ``RuntimeError("Expected IndexError")`` if the callback returns normally.
    Any other exception raised by the callback propagates unchanged.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    """Walk the op/region/block hierarchy of a parsed module via iterators.

    Checks that the named collections (``op.regions``, ``region.blocks``,
    ``block.operations``) and plain iteration over a region/block yield equal
    results, then prints the nested structure and the iterator object types.
    """
    ctx = Context()
    # Required because the IR below uses the unregistered "custom.addi" op.
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )
    op = module.operation
    assert op.context is ctx
    # Get the block using iterators off of the named collections.
    regions = list(op.regions)
    blocks = list(regions[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")
    # Get the blocks from the default collection.
    default_blocks = list(regions[0])
    # They should compare equal regardless of how obtained.
    assert default_blocks == blocks
    # Should be able to get the operations from either the named collection
    # or the block.
    operations = list(blocks[0].operations)
    default_operations = list(blocks[0])
    assert default_operations == operations

    # Recursively print every region, block and op nested under `op`, adding
    # one more leading space per nesting level.
    def walk_operations(indent, op):
        for i, region in enumerate(op.regions):
            print(f"{indent}REGION {i}:")
            for j, block in enumerate(region):
                print(f"{indent} BLOCK {j}:")
                for k, child_op in enumerate(block):
                    print(f"{indent} OP {k}: {child_op}")
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 1: func.return
    walk_operations("", op)
    # The reprs of iter(...) show the concrete iterator class of each collection.
    # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
    # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
    print(" Region iter:", iter(op.regions))
    print(" Block iter:", iter(op.regions[0]))
    print("Operation iter:", iter(op.regions[0].blocks[0]))
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    """Walk the op/region/block hierarchy using ``len()`` and integer indexing.

    Every collection is accessed through ``range(len(...))`` plus ``[i]`` so
    that indexed access on region, block and operation lists is exercised.
    For each op the name of its parent op is printed as well.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )

    # Index-based recursive printer. The range(len(...)) loops are deliberate:
    # indexed access is what this test is about.
    def walk_operations(indent, op):
        for i in range(len(op.regions)):
            region = op.regions[i]
            print(f"{indent}REGION {i}:")
            for j in range(len(region.blocks)):
                block = region.blocks[j]
                print(f"{indent} BLOCK {j}:")
                for k in range(len(block.operations)):
                    child_op = block.operations[k]
                    print(f"{indent} OP {k}: {child_op}")
                    print(
                        f"{indent} OP {k}: parent {child_op.operation.parent.name}"
                    )
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: OP 0: parent builtin.module
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 0: parent func.func
    # CHECK: OP 1: func.return
    # CHECK: OP 1: parent func.func
    walk_operations("", module.operation)
# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    """Verify that regions and blocks report the op that owns them.

    Both the top-level module op and the nested ``func.func`` op must be
    returned by ``.owner`` of their first region and that region's first block.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
        context,
    )

    top_region = module.operation.regions[0]
    assert top_region.owner == module.operation
    assert top_region.blocks[0].owner == module.operation

    func_op = module.body.operations[0]
    func_region = func_op.operation.regions[0]
    assert func_region.owner == func_op
    assert func_region.blocks[0].owner == func_op
# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    """Exercise ``Block.arguments``.

    Covers iteration, ``arg_number``, in-place retyping with ``set_type``,
    slicing, concatenation of slices, and the ``.types`` accessor on both
    full and sliced argument lists.
    """
    with Context() as ctx:
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
return
}
""",
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to a signless integer of width 8 * (N + 1),
            # giving i8, i16 and i24 for the three arguments.
            new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
            arg.set_type(new_type)
        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")
        # Check that we can concatenate slices of argument lists.
        # [:2] holds two arguments and [1:] holds two, so the sum has four.
        # CHECK: Length: 4
        print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:]))
        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for t in entry_block.arguments.types:
            print("Type: ", t)
        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)
        # Check that slice addition works as expected.
        # The last argument followed by the first one.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print(f"Argument {arg.arg_number}, type {arg.type}")
# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    """Iterate the operand list of an op and print each operand's type.

    ``test.consumer`` takes a block argument (i32) and the result of
    ``test.producer`` (i64) as operands.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32) {
%0 = "test.producer"() : () -> i64
"test.consumer"(%arg0, %0) : (i32, i64) -> ()
return
}"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        # Index 1 in the entry block is the "test.consumer" op.
        consumer = entry_block.operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for i, operand in enumerate(consumer.operands):
            print(f"Operand {i}, type {operand.type}")
# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    """Exercise slicing of an op's operand list.

    Covers full slices, bounded and open-ended slices, stepped slices, and
    slices of slices, printing the producer op behind each operand.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
%3 = "test.producer3"() : () -> i64
%4 = "test.producer4"() : () -> i64
"test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
return
}"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        # Index 5 is "test.consumer", after the five producers.
        consumer = entry_block.operations[5]
        assert len(consumer.operands) == 5
        # Reversing twice must give back the original order, element by element.
        for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert left == right
        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        full_slice = consumer.operands[:]
        for operand in full_slice:
            print(operand)
        # CHECK: test.producer0
        # CHECK: test.producer1
        first_two = consumer.operands[0:2]
        for operand in first_two:
            print(operand)
        # CHECK: test.producer3
        # CHECK: test.producer4
        last_two = consumer.operands[3:]
        for operand in last_two:
            print(operand)
        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        even = consumer.operands[::2]
        for operand in even:
            print(operand)
        # [::2] yields operands 0, 2, 4; taking [1::2] of that leaves only 2.
        # CHECK: test.producer2
        fourth = consumer.operands[::2][1::2]
        for operand in fourth:
            print(operand)
# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    """Replace an op's operand through item assignment on ``operands``.

    The single operand of ``test.consumer`` is overwritten twice: once using
    index 0 and once using the negative index -1. After each assignment the
    operand is printed so FileCheck can see which producer it now refers to.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
"test.consumer"(%0) : (i64) -> ()
return
}"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # The previous version stored this in an unused local named `type`,
        # which shadowed the builtin. Keep the access but actually check it.
        operand_type = consumer.operands[0].type
        assert str(operand_type) == "i64"
        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])
        # Negative indices are accepted for assignment as well.
        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    """Create and print an op that is not attached to any block.

    The op has two si32 results, one empty region and two string attributes.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create(
            "custom.op1",
            results=[i32, i32],
            regions=1,
            attributes={
                "foo": StringAttr.get("foo_value"),
                "bar": StringAttr.get("bar_value"),
            },
        )
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)
    # TODO: Check successors once enough infra exists to do it properly.
# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
    """Insert detached ops into a block with ``InsertionPoint.insert``.

    Two new ops are inserted at the beginning of the function's entry block
    and appear in insertion order ahead of the existing ops. Inserting an op
    that is already attached must raise ``ValueError``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""",
        ctx,
    )
    # Create test op.
    with Location.unknown(ctx):
        op1 = Operation.create("custom.op1")
        op2 = Operation.create("custom.op2")
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        ip.insert(op2)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)
        # Trying to add a previously added op should raise.
        try:
            ip.insert(op1)
        except ValueError:
            pass
        else:
            assert False, "expected insert of attached op to raise"
# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
    """Build an op with a populated region and move it into a parsed module.

    A block with two si32 arguments is appended to the op's region and given
    a terminator. The whole op is then inserted at the start of a function's
    entry block, and its nested contents must print as part of the module.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create("custom.op1", regions=1)
        # append() takes the block argument types and returns the new block.
        block = op1.regions[0].blocks.append(i32, i32)
        # CHECK: "custom.op1"() ({
        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
        # CHECK: "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        ip = InsertionPoint(block)
        ip.insert(terminator)
        print(op1)
        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        # No context argument is passed; the context entered by the enclosing
        # `with` block is used.
        module = Module.parse(
            r"""
func.func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)
# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
    """Exercise an op's result list.

    Covers length, iteration with ``result_number``, the ``.types`` accessor,
    out-of-range positive and negative indexing, and the ``owner`` of an
    empty result list.
    """
    ctx = Context()
    module = Module.parse(
        r"""
func.func @f1() {
%0:3 = call @f2() : () -> (i32, f64, index)
call @f3() : () -> ()
return
}
func.func private @f2() -> (i32, f64, index)
func.func private @f3() -> ()
""",
        ctx,
    )
    caller = module.body.operations[0]
    call = caller.regions[0].blocks[0].operations[0]
    assert len(call.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call.results:
        print(f"Result {res.result_number}, type {res.type}")
    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call.results.types:
        print(f"Result type {t}")
    # Out of range
    # With 3 results, both index 3 and index -4 are out of bounds.
    expect_index_error(lambda: call.results[3])
    expect_index_error(lambda: call.results[-4])
    # The second call has no results; its (empty) list still knows its owner.
    no_results_call = caller.regions[0].blocks[0].operations[1]
    assert len(no_results_call.results) == 0
    assert no_results_call.results.owner == no_results_call
# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
    """Exercise slicing of an op's result list.

    Covers full, bounded, stepped and negative-step slices. Each sliced
    result keeps its original ``result_number``.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
func.func @f1() {
"some.op"() : () -> (i1, i2, i3, i4, i5)
return
}
"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer = entry_block.operations[0]
        assert len(producer.results) == 5
        # Reversing twice must give back the original order, element by element.
        for left, right in zip(producer.results, producer.results[::-1][::-1]):
            assert left == right
            assert left.result_number == right.result_number
        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        full_slice = producer.results[:]
        for res in full_slice:
            print(f"Result {res.result_number}, type {res.type}")
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        middle = producer.results[1:4]
        for res in middle:
            print(f"Result {res.result_number}, type {res.type}")
        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        odd = producer.results[1::2]
        for res in odd:
            print(f"Result {res.result_number}, type {res.type}")
        # [-2:0:-2] starts at index 3 and steps back by 2, stopping before 0.
        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        inverted_middle = producer.results[-2:0:-2]
        for res in inverted_middle:
            print(f"Result {res.result_number}, type {res.type}")
# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
    """Exercise an op's attribute dictionary.

    Covers length, lookup by name (including dotted names), typed value
    access, iteration over named attributes, and the exceptions raised for a
    missing key (``KeyError``) and an out-of-range integer index
    (``IndexError``).
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""",
        ctx,
    )
    op = module.body.operations[0]
    assert len(op.attributes) == 3
    iattr = op.attributes["some.attribute"]
    fattr = op.attributes["other.attribute"]
    sattr = op.attributes["dependent"]
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")
    # CHECK: Attribute value b'text'
    print(f"Attribute value {sattr.value_bytes}")
    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for attr in op.attributes:
        print(str(attr))
    # Check that exceptions are raised as expected.
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        pass
    else:
        assert False, "expected KeyError on accessing a non-existent attribute"
    try:
        op.attributes[42]
    except IndexError:
        pass
    else:
        assert False, "expected IndexError on accessing an out-of-bounds attribute"
# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
    """Exercise ``Operation.print`` and ``write_bytecode``.

    Covers:
    - printing to stdout, to a text stream and to a binary stream;
    - a bytecode round trip through ``Module.parse``;
    - debug-info and local-scope printing, and printing with an ``AsmState``;
    - generic-form printing with large elements elided;
    - ``skip_regions``.
    """
    ctx = Context()
    module = Module.parse(
        r"""
func.func @f1(%arg0: i32) -> i32 {
%0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
return %arg0 : i32
}
""",
        ctx,
    )
    # Test print to stdout.
    # CHECK: return %arg0 : i32
    module.operation.print()
    # Test print to text file.
    f = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f)
    str_value = f.getvalue()
    print(str_value.__class__)
    print(f.getvalue())
    # Test roundtrip to bytecode.
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream, desired_version=1)
    bytecode = bytecode_stream.getvalue()
    # Bytecode output is expected to begin with the bytes b"ML\xefR".
    assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
    # Parsing the bytecode back must reproduce the original textual IR exactly.
    module_roundtrip = Module.parse(bytecode, ctx)
    f = io.StringIO()
    module_roundtrip.operation.print(file=f)
    roundtrip_value = f.getvalue()
    assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
    # Test print to binary file.
    f = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f, binary=True)
    bytes_value = f.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)
    # Test print local_scope.
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
    module.operation.print(enable_debug_info=True, use_local_scope=True)
    # Test printing using state.
    state = AsmState(module.operation)
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
    module.operation.print(state)
    # Test print with options.
    # The 4-element constant exceeds large_elements_limit=2, so its value is
    # printed elided.
    # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
    # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True,
    )
    # Test print with skip_regions option
    # CHECK: func.func @f1(%arg0: i32) -> i32
    # CHECK-NOT: func.return
    module.body.operations[0].print(
        skip_regions=True,
    )
# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
    """Check which Python OpView class is chosen for each parsed op.

    - Registered ops map to their generated or extension classes.
    - Unregistered ops fall back to the plain ``OpView``.
    - An op class registered late with ``replace=True`` takes effect for ops
      fetched afterwards.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
%4 = arith.constant 0 : i32
"""
        )
        print(module)
        # addf should map to a known OpView class in the arithmetic dialect.
        # We know the OpView for it defines an 'lhs' attribute.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)
        # One of the custom ops should resolve to the default OpView.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))
        # Check again to make sure negative caching works.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))
        # constant should map to an extension OpView class in the arithmetic dialect.
        constant = module.body.operations[3]
        # CHECK: <mlir.dialects.arith.ConstantOp object
        print(repr(constant))
        # Checks that the arith extension is being registered successfully
        # (literal_value is a property on the extension class but not on the default OpView).
        # CHECK: literal value 0
        print("literal value", constant.literal_value)

        # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
        # is working correctly.
        # This local subclass replaces the registered arith.constant OpView.
        # Its __init__ converts a Python int/float into an IntegerAttr/FloatAttr
        # before delegating to the base class.
        @_cext.register_operation(arith._Dialect, replace=True)
        class ConstantOp(arith.ConstantOp):
            def __init__(self, result, value, *, loc=None, ip=None):
                if isinstance(value, int):
                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
                elif isinstance(value, float):
                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
                else:
                    super().__init__(value, loc=loc, ip=ip)

        # Re-fetching the same op now yields an instance of the local class.
        constant = module.body.operations[3]
        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
        print(repr(constant))
# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
    """Check the ``.result`` convenience property on ops.

    It raises ``ValueError`` for ops with zero or more than one result, and
    is usable on an op with exactly one result.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
"""
        )
        print(module)
        try:
            module.body.operations[0].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.no_result which has 0 results
            print(e)
        else:
            assert False, "Expected exception"
        try:
            module.body.operations[1].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.two_result which has 2 results
            print(e)
        else:
            assert False, "Expected exception"
        # CHECK: %1 = "custom.one_result"() : () -> f32
        print(module.body.operations[2])
def create_invalid_operation():
    """Return a detached ``builtin.module`` op that fails verification.

    The op is created with two regions, although ``builtin.module`` requires
    exactly one. Only the first region receives an (empty) block. Tests use
    it to confirm that printing an invalid op falls back to the generic
    printer instead of failing. Callers in this file invoke it inside a
    ``with Location.unknown(...)`` block.
    """
    bad_module = Operation.create("builtin.module", regions=2)
    first_region = bad_module.regions[0]
    first_region.blocks.append()
    return bad_module
# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
    """Printing and verifying an invalid op must degrade gracefully.

    ``str()`` of the invalid op falls back to the generic printer, and
    ``verify()`` raises ``MLIRError`` whose message includes the diagnostic
    and the offending op. If verification unexpectedly succeeds, the test
    now fails explicitly instead of silently printing nothing.
    """
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(invalid_op)
        try:
            invalid_op.verify()
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK: Verification failed:
            # CHECK: error: unknown: 'builtin.module' op requires one region
            # CHECK: note: unknown: see current operation:
            # CHECK: "builtin.module"() ({
            # CHECK: ^bb0:
            # CHECK: }, {
            # CHECK: }) : () -> ()
            # CHECK: >
            print(f"Exception: <{e}>")
        else:
            assert False, "expected verification of the invalid op to fail"
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
    """An invalid op nested inside a module must not break printing the module.

    The invalid op is created under an insertion point at the module body,
    so it becomes part of the module that is printed.
    """
    ctx = Context()
    with Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(module)
# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """``get_asm(binary=True)`` on an invalid op returns generic-form bytes."""
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
        print(invalid_op.get_asm(binary=True))
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
    """``Operation.create`` must reject malformed ``attributes`` dictionaries.

    Rejected cases:
    - non-string keys (``None`` and an int);
    - a value that is not an MLIR attribute (a ``Context``);
    - a ``None`` value.

    Each exception message is printed for FileCheck. Each case also fails
    explicitly if no exception is raised at all, rather than relying only on
    the missing output.
    """
    ctx = Context()
    with Location.unknown(ctx):
        try:
            Operation.create(
                "builtin.module", attributes={None: StringAttr.get("name")}
            )
        except Exception as e:
            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected a None attribute key to be rejected"
        try:
            Operation.create("builtin.module", attributes={42: StringAttr.get("name")})
        except Exception as e:
            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected an integer attribute key to be rejected"
        try:
            Operation.create("builtin.module", attributes={"some_key": ctx})
        except Exception as e:
            # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected a non-attribute value to be rejected"
        try:
            Operation.create("builtin.module", attributes={"some_key": None})
        except Exception as e:
            # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected a None attribute value to be rejected"
# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
    """Print ``operation.name`` for each top-level op in source order.

    The repeated "custom.op1" confirms that ops with the same name are
    reported independently.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
%0 = "custom.op1"() : () -> f32
%1 = "custom.op2"() : () -> i32
%2 = "custom.op1"() : () -> f32
""",
        ctx,
    )
    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    for op in module.body.operations:
        print(op.operation.name)
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip an ``Operation`` through its C-API PyCapsule.

    The capsule repr names ``mlir.ir.Operation._CAPIPtr``, and
    ``Operation._CAPICreate`` on it must return the very same Python object.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        original = Operation.create("custom.op1").operation
        capsule = original._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
        recreated = Operation._CAPICreate(capsule)
        assert recreated is original
# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
    """Erase an op from a module and confirm it disappears from the output.

    After erasing, creating another op at the same insertion point must still
    work.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")
            # CHECK: "custom.op1"
            print(m)
            op.operation.erase()
            # CHECK-NOT: "custom.op1"
            print(m)
            # Ensure we can create another operation
            Operation.create("custom.op2")
# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
    """Clone an op, erase the original, and confirm the clone remains.

    The clone is made while the module body is the active insertion point,
    and the op text is expected to still print after the original is erased.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")
            # CHECK: "custom.op1"
            print(m)
            clone = op.operation.clone()
            op.operation.erase()
            # CHECK: "custom.op1"
            print(m)
# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
    """An op created with an explicit ``loc`` reports that same location.

    The location is checked both on the object returned by
    ``Operation.create`` and on its ``.operation``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with context:
        named_loc = Location.name("loc")
        created = Operation.create("custom.op", loc=named_loc)
        for holder in (created, created.operation):
            assert holder.location == named_loc
# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
    """Move ops between modules using ``move_before`` / ``move_after``.

    Both functions of ``m2`` are moved around ``@foo`` in ``m1``, so ``m1``
    ends up with bar, foo, qux in that order and ``m2`` is left empty.
    """
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse(
            """
func.func private @bar()
func.func private @qux()
"""
        )
        foo = m1.body.operations[0]
        bar = m2.body.operations[0]
        qux = m2.body.operations[1]
        bar.move_before(foo)
        qux.move_after(foo)
        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(m1)
        # CHECK: module {
        # CHECK-NEXT: }
        print(m2)
# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
    """``Block.append`` of an op that lives in another block moves it.

    The op is moved rather than copied: ``@foo`` ends up after ``@bar`` in
    ``m2``, and ``m1`` is left empty.
    """
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse("func.func private @bar()")
        func = m1.body.operations[0]
        m2.body.append(func)
        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        print(m2)
        # CHECK: module {
        # CHECK-NEXT: }
        print(m1)
# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
    """Detach an op from its module; detaching it a second time must fail.

    The second ``detach_from_parent`` must raise a ``ValueError`` whose
    message mentions "has no parent". The printed module must no longer
    contain the detached function.
    """
    with Context():
        m1 = Module.parse("func.func private @foo()")
        func = m1.body.operations[0].detach_from_parent()
        try:
            func.detach_from_parent()
        except ValueError as e:
            # Re-raise unrelated ValueErrors so they are not mistaken for the
            # expected failure.
            if "has no parent" not in str(e):
                raise
        else:
            assert False, "expected ValueError when detaching a detached operation"
        print(m1)
        # CHECK-NOT: func private @foo
# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
    """An op view and its underlying ``.operation`` hash to the same value."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context, Location.unknown():
        created = Operation.create("custom.op1")
        view_hash = hash(created)
        assert view_hash == hash(created.operation)
# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
    """Exercise ``Operation.parse`` and the typed ``ModuleOp.parse``.

    Generic parsing returns the matching OpView subclass, or a plain
    ``OpView`` for an unregistered op. Typed parsing raises ``MLIRError``
    when the parsed op has a different name. ``source_name`` shows up in the
    op's location when debug info is printed.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        # Generic operation parsing.
        m = Operation.parse("module {}")
        o = Operation.parse('"test.foo"() : () -> ()')
        assert isinstance(m, ModuleOp)
        # Exact type check on purpose: the unregistered op must map to the
        # base OpView, not to a subclass.
        assert type(o) is OpView
        # Parsing specific operation.
        m = ModuleOp.parse("module {}")
        assert isinstance(m, ModuleOp)
        try:
            ModuleOp.parse('"test.foo"() : () -> ()')
        except MLIRError as e:
            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
            print(f"error: {e}")
        else:
            assert False, "expected error"
        o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
        print(
            f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
        )
# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
    """Exercise ``Operation.walk`` orders and ``WalkResult`` control codes.

    Covers:
    - the default post-order walk and an explicit pre-order walk;
    - ``WalkResult.INTERRUPT``, which stops after the first visited op;
    - ``WalkResult.SKIP`` on the root in pre-order, which is expected not to
      descend into nested ops;
    - a Python exception raised inside the callback, which is expected to
      reach the caller as ``RuntimeError``.

    Each case has its own distinctly named callback rather than redefining
    one name several times.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
builtin.module {
func.func @f() {
func.return
}
}
""",
        ctx,
    )

    def print_and_advance(op):
        print(op.name)
        return WalkResult.ADVANCE

    # Test post-order walk (default).
    # CHECK-NEXT: Post-order
    # CHECK-NEXT: func.return
    # CHECK-NEXT: func.func
    # CHECK-NEXT: builtin.module
    print("Post-order")
    module.operation.walk(print_and_advance)

    # Test pre-order walk.
    # CHECK-NEXT: Pre-order
    # CHECK-NEXT: builtin.module
    # CHECK-NEXT: func.func
    # CHECK-NEXT: func.return
    print("Pre-order")
    module.operation.walk(print_and_advance, WalkOrder.PRE_ORDER)

    # Test interrupt: only the first op visited in post-order is printed.
    # CHECK-NEXT: Interrupt post-order
    # CHECK-NEXT: func.return
    print("Interrupt post-order")

    def print_and_interrupt(op):
        print(op.name)
        return WalkResult.INTERRUPT

    module.operation.walk(print_and_interrupt)

    # Test skip: skipping the root in pre-order leaves its children unvisited.
    # CHECK-NEXT: Skip pre-order
    # CHECK-NEXT: builtin.module
    print("Skip pre-order")

    def print_and_skip(op):
        print(op.name)
        return WalkResult.SKIP

    module.operation.walk(print_and_skip, WalkOrder.PRE_ORDER)

    # Test exception.
    # CHECK: Exception
    # CHECK-NEXT: func.return
    # CHECK-NEXT: Exception raised
    print("Exception")

    def print_and_raise(op):
        print(op.name)
        raise ValueError

    try:
        module.operation.walk(print_and_raise)
    except RuntimeError:
        print("Exception raised")
    else:
        assert False, "expected the callback's exception to propagate out of walk"