# llvm/mlir/test/python/dialects/affine.py

# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import func
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine
import mlir.extras.types as T


def constructAndPrintInModule(f):
    """Run test builder `f` inside a fresh module and print the result.

    Prints a `TEST: <name>` banner, creates a new Context with an unknown
    default Location, calls `f` with the insertion point set to the body of a
    new Module, and then prints that module. Returns `f` unchanged so the
    function can be used as a decorator.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            test_module = Module.create()
            with InsertionPoint(test_module.body):
                f()
            print(test_module)
    return f


# CHECK-LABEL: TEST: testAffineStoreOp
@constructAndPrintInModule
def testAffineStoreOp():
    """Build an `affine.store` whose indices are given by an explicit AffineMap."""
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type_out = MemRefType.get([12, 12], f32)

    # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
    @func.FuncOp.from_py_func(index_type)
    def affine_store_test(arg0):
        # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
        mem = memref.AllocOp(memref_type_out, [], []).result

        # One dimension and one symbol; the store below binds both to `arg0`.
        d0 = AffineDimExpr.get(0)
        s0 = AffineSymbolExpr.get(0)
        index_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])

        # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
        a1 = arith.ConstantOp(f32, 2.1)

        # The destination is matched through the captured allocation name
        # instead of a hard-coded SSA name.
        # CHECK: affine.store %[[A1]], %[[O_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
        affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=index_map)

        return mem


# CHECK-LABEL: TEST: testAffineDelinearizeInfer
@constructAndPrintInModule
def testAffineDelinearizeInfer():
    """Build `affine.delinearize_index` without passing explicit result types."""
    # CHECK: %[[C0:.*]] = arith.constant 0 : index
    c0 = arith.ConstantOp(T.index(), 0)
    # CHECK: %[[C1:.*]] = arith.constant 1 : index
    c1 = arith.ConstantOp(T.index(), 1)
    # The operands reuse the captures above rather than re-capturing them, so
    # the directive actually verifies which values feed the op.
    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1]] into (%[[C1]], %[[C0]]) : index, index
    affine.AffineDelinearizeIndexOp(c1, [c1, c0])


# CHECK-LABEL: TEST: testAffineLoadOp
@constructAndPrintInModule
def testAffineLoadOp():
    """Emit an `affine.load` indexed through a one-dim, one-symbol AffineMap."""
    elem_ty = F32Type.get()
    src_ty = MemRefType.get([10, 10], elem_ty)
    idx_ty = IndexType.get()

    # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
    @func.FuncOp.from_py_func(src_ty, idx_ty)
    def affine_load_test(source, offset):
        dim = AffineDimExpr.get(0)
        sym = AffineSymbolExpr.get(0)
        # Two results: one index expression per memref dimension.
        row_expr = sym * 3
        col_expr = dim + sym + 1
        access_map = AffineMap.get(1, 1, [row_expr, col_expr])

        # `offset` is passed twice: once for the dimension, once for the symbol.
        # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
        return affine.AffineLoadOp(
            elem_ty, source, indices=[offset, offset], map=access_map
        )


# CHECK-LABEL: TEST: testAffineForOp
@constructAndPrintInModule
def testAffineForOp():
    """Build an `affine.for` with map-based bounds, a step and one iter_arg."""
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type = MemRefType.get([1024], f32)

    # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
    # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
    # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
    @func.FuncOp.from_py_func(memref_type)
    def affine_for_op_test(buffer):
        # CHECK: %[[C1:.*]] = arith.constant 1 : index
        c1 = arith.ConstantOp(index_type, 1)
        # CHECK: %[[C2:.*]] = arith.constant 2 : index
        c2 = arith.ConstantOp(index_type, 2)
        # CHECK: %[[C3:.*]] = arith.constant 3 : index
        c3 = arith.ConstantOp(index_type, 3)
        # CHECK: %[[C9:.*]] = arith.constant 9 : index
        c9 = arith.ConstantOp(index_type, 9)

        # Affine expressions and maps for the loop bounds.
        ac0 = AffineConstantExpr.get(0)
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        lb = AffineMap.get(1, 1, [ac0, d0 + s0])
        ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])

        # Initial value of the loop-carried sum.
        # CHECK: %[[SUM_INIT:.*]] = arith.constant 0.000000e+00 : f32
        sum_0 = arith.ConstantOp(f32, 0.0)

        # CHECK: %{{.*}} = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[SUM_INIT]]) -> (f32) {
        loop = affine.AffineForOp(
            lb,
            ub,
            2,
            iter_args=[sum_0],
            lower_bound_operands=[c2, c3],
            upper_bound_operands=[c9, c1],
        )

        with InsertionPoint(loop.body):
            # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
            tmp = memref.LoadOp(buffer, [loop.induction_variable])
            sum_next = arith.AddFOp(loop.inner_iter_args[0], tmp)
            affine.AffineYieldOp([sum_next])


# CHECK-LABEL: TEST: testAffineForOpErrors
@constructAndPrintInModule
def testAffineForOpErrors():
    """Check the ValueError messages raised for invalid AffineForOp lower bounds.

    Every case must raise; a construction that unexpectedly succeeds fails the
    test through the `else` branch instead of passing silently.
    """
    c1 = arith.ConstantOp(T.index(), 1)
    c2 = arith.ConstantOp(T.index(), 2)
    c3 = arith.ConstantOp(T.index(), 3)

    # A concrete value as lower bound together with lower bound operands.
    try:
        affine.AffineForOp(
            c1,
            c2,
            1,
            lower_bound_operands=[c3],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert (
            e.args[0]
            == "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported."
        )
    else:
        raise AssertionError(
            "expected ValueError for a concrete lower bound with operands"
        )

    # A constant map takes no inputs, but two operands are supplied.
    try:
        affine.AffineForOp(
            AffineMap.get_constant(1),
            c2,
            1,
            lower_bound_operands=[c3, c3],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert (
            e.args[0]
            == "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2."
        )
    else:
        raise AssertionError(
            "expected ValueError for a wrong number of lower bound operands"
        )

    # An op used as lower bound. Built outside the `try` so that only the
    # AffineForOp construction is under test.
    two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c1])
    try:
        affine.AffineForOp(
            two_indices,
            c2,
            1,
            lower_bound_operands=[],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert e.args[0] == "Only a single concrete value is supported for lower bound."
    else:
        raise AssertionError("expected ValueError for a multi-result lower bound")

    # A Python float as lower bound.
    try:
        affine.AffineForOp(
            1.0,
            c2,
            1,
            lower_bound_operands=[],
            upper_bound_operands=[],
        )
    except ValueError as e:
        assert e.args[0] == "lower bound must be int | ResultValueT | AffineMap."
    else:
        raise AssertionError("expected ValueError for a float lower bound")


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise `affine.for_` as a Python `for` loop over various bound kinds.

    The nested functions cover value/value, value/int, int/value and int/int
    bounds, plus one loop that carries an iter_arg.
    """
    memref_t = T.memref(10, T.index())
    # Trailing underscore keeps the builtin `range` unshadowed.
    range_ = affine.for_

    # CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>

    # CHECK-LABEL:   func.func @range_loop_1(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_1(lb, ub, memref_v):
        for i in range_(lb, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_2(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to 10 {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_2(lb, ub, memref_v):
        for i in range_(lb, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_3(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_2]](%[[VAL_1]]) {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_3(lb, ub, memref_v):
        for i in range_(0, ub, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_4(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_4(lb, ub, memref_v):
        for i in range_(0, 10, step=1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            affine.yield_([])

    # CHECK-LABEL:   func.func @range_loop_8(
    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
    # CHECK:           %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
    # CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK:             memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK:             affine.yield %[[VAL_5]] : memref<10xindex>
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
    def range_loop_8(lb, ub, memref_v):
        # With iter_args the loop yields the induction variable and the
        # loop-carried value together.
        for i, it in range_(0, 10, iter_args=[memref_v]):
            add = arith.addi(i, i)
            memref.store(add, it, [i])
            affine.yield_([it])