# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
import gc
from mlir.ir import *
from mlir.dialects import func
def run(f):
    """Run test function ``f`` right away and hand it back unchanged.

    Intended to be used as a decorator. It prints a banner containing the
    function's name (the label directives in this file match on that banner),
    calls the function, forces a garbage collection and then asserts that no
    ``Context`` is reported as live.

    Args:
        f: Zero-argument callable to execute.

    Returns:
        The same callable ``f``.

    Raises:
        AssertionError: If ``Context._get_live_count()`` is non-zero after
            the collection.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect before counting so objects dropped by the test are gone.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip a Value through its C API capsule and compare the result."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        original = Operation.create(
            "custom.op1", results=[IntegerType.get_signless(32)]
        ).result
        capsule = original._CAPIPtr
        # The capsule's repr is expected to carry the capsule name.
        assert '"mlir.ir.Value._CAPIPtr"' in repr(capsule)
        round_tripped = Value._CAPICreate(capsule)
        assert round_tripped == original
# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
    """Check that the owner of an op's result compares equal to that op."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int32 = IntegerType.get_signless(32)
        producer = Operation.create("custom.op1", results=[int32])
        result_owner = producer.result.owner
        assert result_owner == producer
# CHECK-LABEL: TEST: testBlockArgOwner
@run
def testBlockArgOwner():
    """Check that the owner of a block argument compares equal to its block."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @foo(%arg0: f32) {
return
}""",
        context,
    )
    # Named func_op so the imported `func` dialect module is not shadowed.
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    first_arg = entry.arguments[0]
    assert first_arg.owner == entry
# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
    """Check BlockArgument/OpResult ``isinstance`` on both kinds of value."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @foo(%arg0: f32) {
%0 = "some_dialect.some_op"() : () -> f64
return
}""",
        context,
    )
    # Named func_op so the imported `func` dialect module is not shadowed.
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]

    # A block argument is a BlockArgument and not an OpResult ...
    block_arg = entry.arguments[0]
    assert BlockArgument.isinstance(block_arg)
    assert not OpResult.isinstance(block_arg)

    # ... and an operation result is the other way around.
    first_op = entry.operations[0]
    op_result = first_op.results[0]
    assert not BlockArgument.isinstance(op_result)
    assert OpResult.isinstance(op_result)
# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
    """Check that a value hashes the same when reached as def and as operand."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
func.func @foo(%arg0: f32) -> f32 {
%0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
return %0 : f32
}""",
        context,
    )
    # Named func_op so the imported `func` dialect module is not shadowed.
    (func_op,) = module.body.operations
    entry = func_op.entry_block
    some_op, return_op = entry.operations

    # Block argument vs. the operand of the op consuming it.
    arg_as_def = entry.arguments[0]
    arg_as_use = some_op.operands[0]
    assert hash(arg_as_def) == hash(arg_as_use)

    # Op result vs. the operand of the return consuming it.
    result_as_def = some_op.result
    result_as_use = return_op.operands[0]
    assert hash(result_as_def) == hash(result_as_use)
# CHECK-LABEL: TEST: testValueUses
@run
def testValueUses():
    """Iterate the uses of a value consumed by two ops and print each use."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            produced = Operation.create("custom.op1", results=[int32]).results[0]
            first_user = Operation.create("custom.op2", operands=[produced])
            second_user = Operation.create("custom.op2", operands=[produced])

    expected_owners = [first_user, second_user]
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in produced.uses:
        assert use.owner in expected_owners
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
    """Replace all uses of one value with another and inspect both use lists."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            old_value = Operation.create("custom.op1", results=[int32]).results[0]
            first_user = Operation.create("custom.op2", operands=[old_value])
            second_user = Operation.create("custom.op2", operands=[old_value])
            new_value = Operation.create("custom.op3", results=[int32]).results[0]
            old_value.replace_all_uses_with(new_value)

    # The replaced value must have no uses left.
    assert not list(old_value.uses)

    expected_owners = [first_user, second_user]
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in new_value.uses:
        assert use.owner in expected_owners
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
    """Print values and their operand names under several naming modes.

    Covers the default ``get_name()``, ``get_name(state=...)`` with an
    ``AsmState`` built for the function, ``get_name(use_local_scope=True)``,
    block-argument names, the whole module, and finally a value whose owner
    was detached from its parent.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            module_value1 = Operation.create("custom.op1", results=[int32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(module_value1)

            module_value2 = Operation.create("custom.op2", results=[int32]).results[0]
            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
            print(module_value2)

            top_fn = func.FuncOp("test", ([int32, int32], []))
            fn_entry = Block.create_at_start(
                top_fn.operation.regions[0], [int32, int32]
            )

            with InsertionPoint(fn_entry):
                fn_value3 = Operation.create("custom.op3", results=[int32]).results[0]
                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
                print(fn_value3)

                fn_value4 = Operation.create("custom.op4", results=[int32]).results[0]
                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
                print(fn_value4)

                func.ReturnOp([])

        # Default naming, one line per value in creation order.
        # CHECK: %[[VAL1]]
        # CHECK: %[[VAL2]]
        # CHECK: %[[VAL3]]
        # CHECK: %[[VAL4]]
        for named in (module_value1, module_value2, fn_value3, fn_value4):
            print(named.get_name())

        print("With AsmState")
        # CHECK-LABEL: With AsmState
        fn_state = AsmState(top_fn.operation, use_local_scope=True)
        # CHECK: %0
        # CHECK: %1
        for named in (fn_value3, fn_value4):
            print(named.get_name(state=fn_state))

        print("With use_local_scope")
        # CHECK-LABEL: With use_local_scope
        # CHECK: %0
        # CHECK: %1
        for named in (fn_value3, fn_value4):
            print(named.get_name(use_local_scope=True))

        # CHECK: %[[ARG0:.*]]
        # CHECK: %[[ARG1:.*]]
        for arg_index in (0, 1):
            print(fn_entry.arguments[arg_index].get_name())

        # CHECK: module {
        # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
        # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
        # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
        # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
        # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
        # CHECK: return
        # CHECK: }
        # CHECK: }
        print(module)

        # Name of a value whose defining op no longer has a parent.
        detached_owner = module_value2.owner
        detached_owner.detach_from_parent()
        # CHECK: %0
        print(module_value2.get_name())
# CHECK-LABEL: TEST: testValueSetType
@run
def testValueSetType():
    """Change a value's type in place and print the value and its owner."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        narrow_type = IntegerType.get_signless(32)
        wide_type = IntegerType.get_signless(64)
        module = Module.create()
        with InsertionPoint(module.body):
            retyped = Operation.create("custom.op1", results=[narrow_type]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(retyped)

            retyped.set_type(wide_type)
            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
            print(retyped)

            # The defining op is expected to show the new type as well.
            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
            print(retyped.owner)
# CHECK-LABEL: TEST: testValueCasters
@run
def testValueCasters():
    """Exercise value casters registered for ``IntegerType``.

    First registers ``cast_int``, which wraps values in local pass-through
    subclasses, and prints values reached as results, operands and block
    arguments. Then checks that a second registration without
    ``replace=True`` raises ``RuntimeError``, replaces the caster with one
    that returns its input unchanged, and prints results again.

    Statement order and the exact indexing expressions matter here: the
    expected output interleaves lines printed by the casters themselves.
    """

    class NOPResult(OpResult):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPResult.__name__)

    class NOPValue(Value):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPValue.__name__)

    class NOPBlockArg(BlockArgument):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)

    @register_value_caster(IntegerType.static_typeid)
    def cast_int(v) -> Value:
        print("in caster", v.__class__.__name__)
        # Most specific classes first; the plain Value wrapper is the fallback.
        wrappers = (
            (OpResult, NOPResult),
            (BlockArgument, NOPBlockArg),
            (Value, NOPValue),
        )
        for base_class, wrapper_class in wrappers:
            if isinstance(v, base_class):
                return wrapper_class(v)

    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            results = Operation.create("custom.op1", results=[int32, int32]).results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", results[0].result_number, results[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", results[1].result_number, results[1])

            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("results slice", results[:1][0].result_number, results[:1][0])

            first_result, second_result = results
            # NOTE(review): the second print argument indexes the result list
            # again instead of reusing the unpacked names - presumably so the
            # caster runs once more before each line. Confirm before changing.
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", first_result.result_number, results[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", second_result.result_number, results[1])

            consumer = Operation.create(
                "custom.op2", operands=[first_result, second_result]
            )
            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
            print(consumer)

            # CHECK: in caster Value
            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 0", consumer.operands[0])
            # CHECK: in caster Value
            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 1", consumer.operands[1])

            # CHECK: in caster BlockArgument
            # CHECK: in caster BlockArgument
            @func.FuncOp.from_py_func(int32, int32)
            def reduction(arg0, arg1):
                # CHECK: as func arg 0 NOPBlockArg
                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
                # CHECK: as func arg 1 NOPBlockArg
                print("as func arg", arg1.arg_number, arg1.__class__.__name__)

            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
            print(
                "args slice",
                reduction.func_op.arguments[:1][0].arg_number,
                reduction.func_op.arguments[:1][0],
            )

    # Registering a second caster for the same type id must be rejected.
    try:

        @register_value_caster(IntegerType.static_typeid)
        def dont_cast_int_shouldnt_register(v):
            ...

    except RuntimeError as registration_error:
        # CHECK: Value caster is already registered: {{.*}}cast_int
        print(registration_error)

    # With replace=True the earlier caster is swapped out.
    @register_value_caster(IntegerType.static_typeid, replace=True)
    def dont_cast_int(v) -> OpResult:
        assert isinstance(v, OpResult)
        print("don't cast", v.result_number, v)
        return v

    with Location.unknown(context):
        int32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
            uncast = Operation.create("custom.op1", results=[int32]).result
            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
            print("result", uncast.result_number, uncast)

            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
            uncast = Operation.create("custom.op2", results=[int32]).results[0]
            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
            print("result", uncast.result_number, uncast)