llvm/mlir/test/python/dialects/llvm.py

# RUN: %PYTHON %s | FileCheck %s
# This is just a smoke test that the dialect is functional.

from mlir.ir import *
from mlir.dialects import llvm


def constructAndPrintInModule(f):
    """Decorator that runs the test ``f`` immediately and prints its module.

    First a ``TEST:`` banner carrying the function's name is printed. Then
    ``f`` is called inside a new ``Context`` with an unknown ``Location``,
    with the insertion point set on the body of a newly created, empty
    module. Finally that module is printed. ``f`` is returned unchanged, so
    the decorated name stays bound to the original function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context(), Location.unknown():
        mod = Module.create()
        with InsertionPoint(mod.body):
            f()
        # Printed only after f() returns so every op it built is included.
        print(mod)
    return f


# CHECK-LABEL: testStructType
@constructAndPrintInModule
def testStructType():
    """Exercise construction, printing, equality and mutation of struct types.

    Covers literal structs (``get_literal``), identified structs
    (``get_identified`` / ``set_body``), identified structs with a freshly
    uniqued name (``new_identified``) and opaque structs (``get_opaque``).
    """
    # Literal structs are described purely by their element list.
    print(llvm.StructType.get_literal([]))
    # CHECK: !llvm.struct<()>

    i8, i32, i64 = (IntegerType.get_signless(width) for width in (8, 32, 64))
    print(llvm.StructType.get_literal([i8, i32, i64]))
    print(llvm.StructType.get_literal([i32]))
    print(llvm.StructType.get_literal([i32, i32], packed=True))
    literal = llvm.StructType.get_literal([i8, i32, i64])
    assert len(literal.body) == 3
    print(*tuple(literal.body))
    assert literal.name is None
    # CHECK: !llvm.struct<(i8, i32, i64)>
    # CHECK: !llvm.struct<(i32)>
    # CHECK: !llvm.struct<packed (i32, i32)>
    # CHECK: i8 i32 i64

    # Literal structs with identical bodies compare equal; different bodies
    # compare unequal.
    assert llvm.StructType.get_literal([i32]) == llvm.StructType.get_literal([i32])
    assert llvm.StructType.get_literal([i32]) != llvm.StructType.get_literal([i64])

    # Identified structs are looked up by name and start out opaque.
    print(llvm.StructType.get_identified("foo"))
    print(llvm.StructType.get_identified("bar"))
    # CHECK: !llvm.struct<"foo", opaque>
    # CHECK: !llvm.struct<"bar", opaque>

    assert llvm.StructType.get_identified("foo") == llvm.StructType.get_identified(
        "foo"
    )
    assert llvm.StructType.get_identified("foo") != llvm.StructType.get_identified(
        "bar"
    )

    foo_struct = llvm.StructType.get_identified("foo")
    print(foo_struct.name)
    # Still opaque here, so there is no body yet and this prints None.
    print(foo_struct.body)
    assert foo_struct.opaque
    foo_struct.set_body([i32, i64])
    print(*tuple(foo_struct.body))
    print(foo_struct)
    assert not foo_struct.packed
    assert not foo_struct.opaque
    # A later lookup by the same name yields the now non-opaque struct.
    assert llvm.StructType.get_identified("foo") == foo_struct
    # CHECK: foo
    # CHECK: None
    # CHECK: i32 i64
    # CHECK: !llvm.struct<"foo", (i32, i64)>

    bar_struct = llvm.StructType.get_identified("bar")
    bar_struct.set_body([i32], packed=True)
    print(bar_struct)
    assert bar_struct.packed
    # CHECK: !llvm.struct<"bar", packed (i32)>

    # Same body, should not raise.
    foo_struct.set_body([i32, i64])

    # A body with different elements conflicts with the one already set.
    try:
        foo_struct.set_body([])
    except ValueError:
        pass
    else:
        assert False, "expected exception not raised"

    # Same elements but without packed=True also conflicts with bar's body.
    try:
        bar_struct.set_body([i32])
    except ValueError:
        pass
    else:
        assert False, "expected exception not raised"

    # new_identified yields a distinct struct on every call; its printed name
    # is "foo" followed by a uniquing suffix.
    print(llvm.StructType.new_identified("foo", []))
    assert llvm.StructType.new_identified("foo", []) != llvm.StructType.new_identified(
        "foo", []
    )
    # CHECK: !llvm.struct<"foo{{[^"]+}}

    opaque = llvm.StructType.get_opaque("opaque")
    print(opaque)
    assert opaque.opaque
    # CHECK: !llvm.struct<"opaque", opaque>


# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
    """Insert a single ``llvm.mlir.undef`` op of an f32 literal-struct type.

    Nothing is printed here; the op is only created at the current insertion
    point, and the module print done by the decorator is what gets matched.
    """
    mat64f32_t = Type.parse(
        "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
    )
    # The op handle is not needed afterwards; the op is already in the module.
    llvm.UndefOp(mat64f32_t)
    # CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>


# CHECK-LABEL: testPointerType
@constructAndPrintInModule
def testPointerType():
    """Print the LLVM pointer type without and with an explicit address space."""
    # CHECK: !llvm.ptr
    # CHECK: !llvm.ptr<1>
    for get_args in ((), (1,)):
        print(llvm.PointerType.get(*get_args))


# CHECK-LABEL: testConstant
@constructAndPrintInModule
def testConstant():
    """Build an i32 ``llvm.mlir.constant`` and print the op defining it."""
    value_attr = IntegerAttr.get(IntegerType.get_signless(32), 128)
    constant = llvm.mlir_constant(value_attr)
    # CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
    print(constant.owner)


# CHECK-LABEL: testIntrinsics
@constructAndPrintInModule
def testIntrinsics():
    """Chain constant, alloca and memset ops, printing each as it is built.

    Op creation order is significant: the SSA names captured by earlier
    printed lines are reused when matching the later ones.
    """
    i32_type = IntegerType.get_signless(32)
    ptr_type = llvm.PointerType.get()

    size = llvm.mlir_constant(IntegerAttr.get(i32_type, 128))
    # CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
    print(size.owner)

    # Allocates `size` x i32 and yields an !llvm.ptr to the buffer.
    buffer = llvm.alloca(ptr_type, size, i32_type)
    # CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
    print(buffer.owner)

    fill_byte = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
    # CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
    print(fill_byte.owner)

    # The trailing False becomes isVolatile = false. memset produces no values
    # (`-> ()`), so what gets printed is the op itself.
    memset_op = llvm.intr_memset(buffer, fill_byte, size, False)
    # CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
    print(memset_op)