# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
    """Execute the test function ``f`` immediately and check for leaked contexts.

    Meant to be used as a decorator: the decorated function runs at definition
    time, preceded by a ``TEST: <name>`` banner that the FileCheck labels in
    this file anchor on. After it returns, a garbage-collection pass is forced
    and the number of live ``Context`` objects must be zero.

    Returns ``f`` itself so the module-level name stays bound to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect unreachable objects (e.g. reference cycles) so they cannot keep a
    # Context alive at the moment the live count is taken.
    gc.collect()
    assert Context._get_live_count() == 0
    return f
# CHECK-LABEL: TEST: testAffineMapCapsule
@run
def testAffineMapCapsule():
    """Round-trip an AffineMap through its C-API capsule."""
    with Context() as ctx:
        original = AffineMap.get_empty(ctx)
        # CHECK: mlir.ir.AffineMap._CAPIPtr
        capsule = original._CAPIPtr
        print(capsule)
        # Rebuilding from the capsule must yield an equal map that is still
        # attached to the context it was created in.
        rebuilt = AffineMap._CAPICreate(capsule)
        assert rebuilt == original
        assert rebuilt.context is ctx
# CHECK-LABEL: TEST: testAffineMapGet
@run
def testAffineMapGet():
    """Build AffineMaps through the various factories and probe error paths."""

    def print_expected_error(exc_type, fn, *args):
        # Call fn(*args), which must raise exc_type, and print the exception
        # message so FileCheck can match it. If the call unexpectedly succeeds,
        # fail right here instead of only leaving a gap in the output: two of
        # the expected messages below are identical, so a gap would otherwise
        # be reported against the wrong call.
        try:
            fn(*args)
        except exc_type as e:
            print(e)
        else:
            raise AssertionError(
                f"expected {exc_type.__name__} from {fn!r} called with {args!r}"
            )

    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)

        # CHECK: (d0, d1)[s0, s1, s2] -> ()
        map0 = AffineMap.get(2, 3, [])
        print(map0)

        # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
        map1 = AffineMap.get(2, 3, [d1, c2])
        print(map1)

        # CHECK: () -> (2)
        map2 = AffineMap.get(0, 0, [c2])
        print(map2)

        # CHECK: (d0, d1) -> (d0, d1)
        map3 = AffineMap.get(2, 0, [d0, d1])
        print(map3)

        # CHECK: (d0, d1) -> (d1)
        map4 = AffineMap.get(2, 0, [d1])
        print(map4)

        # CHECK: (d0, d1, d2) -> (d2, d0, d1)
        map5 = AffineMap.get_permutation([2, 0, 1])
        print(map5)

        # Maps built from equal inputs, or via the convenience factories,
        # compare equal to the generic AffineMap.get result.
        assert map1 == AffineMap.get(2, 3, [d1, c2])
        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
        assert map2 == AffineMap.get_constant(2)
        assert map3 == AffineMap.get_identity(2)
        assert map4 == AffineMap.get_minor_identity(2, 1)

        # CHECK: Invalid expression when attempting to create an AffineMap
        print_expected_error(RuntimeError, AffineMap.get, 1, 1, [1])

        # CHECK: Invalid expression (None?) when attempting to create an AffineMap
        print_expected_error(RuntimeError, AffineMap.get, 1, 1, [None])

        # CHECK: Invalid permutation when attempting to create an AffineMap
        print_expected_error(RuntimeError, AffineMap.get_permutation, [1, 0, 1])

        # CHECK: result position out of bounds
        print_expected_error(ValueError, map3.get_submap, [42])

        # CHECK: number of results out of bounds
        print_expected_error(ValueError, map3.get_minor_submap, 42)

        # CHECK: number of results out of bounds
        print_expected_error(ValueError, map3.get_major_submap, 42)
# CHECK-LABEL: TEST: testAffineMapDerive
@run
def testAffineMapDerive():
    """Derive sub-maps (explicit, major, minor) from a 5-D identity map."""
    with Context():
        identity5 = AffineMap.get_identity(5)

        # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
        print(identity5.get_submap([1, 2, 3]))

        # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
        print(identity5.get_major_submap(2))

        # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
        print(identity5.get_minor_submap(2))
# CHECK-LABEL: TEST: testAffineMapProperties
@run
def testAffineMapProperties():
    """Query the permutation predicates on three closely related maps."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        # A map that drops d1, a full permutation of the three dimensions, and
        # that same permutation declared with one extra symbol.
        projection = AffineMap.get(3, 0, [d2, d0])
        permutation = AffineMap.get(3, 0, [d2, d0, d1])
        with_symbol = AffineMap.get(3, 1, [d2, d0, d1])

        # Each map prints is_permutation, then is_projected_permutation.
        # projection: only a projected permutation.
        # CHECK: False
        # CHECK: True
        # permutation: both.
        # CHECK: True
        # CHECK: True
        # with_symbol: neither; it differs from `permutation` only by the symbol.
        # CHECK: False
        # CHECK: False
        for amap in (projection, permutation, with_symbol):
            print(amap.is_permutation)
            print(amap.is_projected_permutation)
# CHECK-LABEL: TEST: testAffineMapExprs
@run
def testAffineMapExprs():
    """Inspect the input counts and result expressions of a map."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        amap = AffineMap.get(3, 1, [d2, d0, d1])

        # CHECK: 3
        print(amap.n_dims)
        # CHECK: 4
        print(amap.n_inputs)
        # CHECK: 1
        print(amap.n_symbols)
        # The input count is the sum of the dimension and symbol counts.
        assert amap.n_inputs == amap.n_dims + amap.n_symbols

        # CHECK: 3
        print(len(amap.results))

        # Forward iteration over the results.
        # CHECK: d2
        # CHECK: d0
        # CHECK: d1
        for result in amap.results:
            print(result)

        # Reverse iteration through a negative-step slice.
        # CHECK: d1
        # CHECK: d0
        # CHECK: d2
        for result in amap.results[-1:-4:-1]:
            print(result)

        assert list(amap.results) == [d2, d0, d1]
# CHECK-LABEL: TEST: testCompressUnusedSymbols
@run
def testCompressUnusedSymbols():
    """Drop symbols that no map in a group references."""
    with Context() as ctx:
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        _, _, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
        # All three maps declare three symbols, but only s2 is referenced,
        # and only by the middle map.
        maps = [
            AffineMap.get(3, 3, [d2, d0, d1]),
            AffineMap.get(3, 3, [d2, d0 + s2, d1]),
            AffineMap.get(3, 3, [d1, d2, d0]),
        ]
        compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)

        # The input list is left as it was.
        # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
        print(maps)

        # The compressed maps keep one symbol, renumbered to s0.
        # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
        print(compressed_maps)
# CHECK-LABEL: TEST: testReplace
@run
def testReplace():
    """Substitute a constant for one symbol at a time."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        s0, s1, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
        base = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
        forty_two = AffineConstantExpr.get(42)

        # replace(expr, replacement, n_dims, n_symbols): the last two values
        # give the dimension/symbol counts of the resulting map.
        without_s0 = base.replace(s0, forty_two, 3, 3)
        without_s1 = base.replace(s1, forty_two, 3, 3)
        # Here the result also shrinks to two symbols; s2 no longer occurs.
        without_s2 = base.replace(s2, forty_two, 3, 2)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
        print(without_s0)
        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
        print(without_s1)
        # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
        print(without_s2)
# CHECK-LABEL: TEST: testHash
@run
def testHash():
    """Check that AffineMap hashing is consistent with equality."""
    with Context():
        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        m1 = AffineMap.get(2, 0, [d0, d1])
        m2 = AffineMap.get(2, 0, [d1, d0])
        # Equal maps built independently must hash identically.
        assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))

        dictionary = {}
        dictionary[m1] = 1
        dictionary[m2] = 2
        assert m1 in dictionary
        # m1 and m2 are different maps, so they must occupy separate entries
        # and each must retrieve its own value.
        assert len(dictionary) == 2
        assert dictionary[m1] == 1
        assert dictionary[m2] == 2
        # A freshly built map equal to m1 must find m1's entry.
        assert dictionary[AffineMap.get(2, 0, [d0, d1])] == 1