# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
return f
# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
m1 = Module.parse(
"""
func.func private @foo()
func.func private @bar()"""
)
m2 = Module.parse(
"""
func.func private @qux()
func.func private @foo()
"foo.bar"() : () -> ()"""
)
symbol_table = SymbolTable(m1.operation)
# CHECK: func private @foo
# CHECK: func private @bar
assert "foo" in symbol_table
print(symbol_table["foo"])
assert "bar" in symbol_table
bar = symbol_table["bar"]
print(symbol_table["bar"])
assert "qux" not in symbol_table
del symbol_table["bar"]
try:
symbol_table.erase(symbol_table["bar"])
except KeyError:
pass
else:
assert False, "expected KeyError"
# CHECK: module
# CHECK: func private @foo()
print(m1)
assert "bar" not in symbol_table
try:
print(bar)
except RuntimeError as e:
if "the operation has been invalidated" not in str(e):
raise
else:
assert False, "expected RuntimeError due to invalidated operation"
qux = m2.body.operations[0]
m1.body.append(qux)
symbol_table.insert(qux)
assert "qux" in symbol_table
# Check that insertion actually renames this symbol in the symbol table.
foo2 = m2.body.operations[0]
m1.body.append(foo2)
updated_name = symbol_table.insert(foo2)
assert foo2.name.value != "foo"
assert foo2.name == updated_name
assert isinstance(updated_name, StringAttr)
# CHECK: module
# CHECK: func private @foo()
# CHECK: func private @qux()
# CHECK: func private @foo{{.*}}
print(m1)
try:
symbol_table.insert(m2.body.operations[0])
except ValueError as e:
if "Expected operation to have a symbol name" not in str(e):
raise
else:
assert False, "exepcted ValueError when adding a non-symbol"
# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
with Context() as ctx:
m = Module.parse(
"""
func.func private @foo() {
call @bar() : () -> ()
return
}
func.func private @bar()
"""
)
foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
# Do renaming just within `foo`.
SymbolTable.set_symbol_name(bar, "bam")
SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
# CHECK: call @bam()
# CHECK: func private @bam
print(m)
# CHECK: Foo symbol: StringAttr("foo")
# CHECK: Bar symbol: StringAttr("bam")
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
# Do renaming within the module.
SymbolTable.set_symbol_name(bar, "baz")
SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
# CHECK: call @baz()
# CHECK: func private @baz
print(m)
# CHECK: Foo symbol: StringAttr("foo")
# CHECK: Bar symbol: StringAttr("baz")
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
with Context() as ctx:
m = Module.parse(
"""
func.func private @foo() {
return
}
"""
)
foo = m.operation.regions[0].blocks[0].operations[0]
# CHECK: Existing visibility: StringAttr("private")
print(f"Existing visibility: {repr(SymbolTable.get_visibility(foo))}")
SymbolTable.set_visibility(foo, "public")
# CHECK: func public @foo
print(m)
# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
with Context() as ctx:
m = Module.parse(
"""
module @outer {
module @inner{
}
}
"""
)
def callback(symbol_table_op, uses_visible):
print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
# CHECK: SYMBOL TABLE: True: module @inner
# CHECK: SYMBOL TABLE: True: module @outer
SymbolTable.walk_symbol_tables(m.operation, True, callback)
# Make sure exceptions in the callback are handled.
def error_callback(symbol_table_op, uses_visible):
assert False, "Raised from python"
try:
SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
except RuntimeError as e:
# CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
print(f"GOT EXCEPTION: {e}")