llvm/mlir/python/mlir/extras/types.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import builtins
from functools import partial
from typing import Optional, List

from ..ir import (
    Attribute,
    BF16Type,
    ComplexType,
    F16Type,
    F32Type,
    F64Type,
    Float8E3M4Type,
    Float8E4M3B11FNUZType,
    Float8E4M3FNType,
    Float8E4M3Type,
    Float8E5M2Type,
    FunctionType,
    IndexType,
    IntegerType,
    MemRefType,
    NoneType,
    OpaqueType,
    RankedTensorType,
    StridedLayoutAttr,
    StringAttr,
    TupleType,
    Type,
    UnrankedMemRefType,
    UnrankedTensorType,
    VectorType,
)

def index():
    """Return the type produced by ``IndexType.get()``."""
    return IndexType.get()


def i(width):
    """Build a signless integer type of the given ``width``.

    Thin wrapper over ``IntegerType.get_signless``.
    """
    signless_type = IntegerType.get_signless(width)
    return signless_type


def si(width):
    """Build a signed integer type of the given ``width``.

    Thin wrapper over ``IntegerType.get_signed``.
    """
    signed_type = IntegerType.get_signed(width)
    return signed_type


def ui(width):
    """Build an unsigned integer type of the given ``width``.

    Thin wrapper over ``IntegerType.get_unsigned``.
    """
    unsigned_type = IntegerType.get_unsigned(width)
    return unsigned_type


# NOTE: ``bool`` below shadows the builtin ``bool`` for the rest of this
# module; code further down must not rely on the builtin under that name.
def bool():
    """Return the 1-wide signless integer type, ``i(1)``."""
    return i(1)


def i8():
    """Return the signless integer type ``i(8)``."""
    return i(8)


def i16():
    """Return the signless integer type ``i(16)``."""
    return i(16)


def i32():
    """Return the signless integer type ``i(32)``."""
    return i(32)


def i64():
    """Return the signless integer type ``i(64)``."""
    return i(64)

def si8():
    """Return the signed integer type ``si(8)``."""
    return si(8)


def si16():
    """Return the signed integer type ``si(16)``."""
    return si(16)


def si32():
    """Return the signed integer type ``si(32)``."""
    return si(32)


def si64():
    """Return the signed integer type ``si(64)``."""
    return si(64)

def ui8():
    """Return the unsigned integer type ``ui(8)``."""
    return ui(8)


def ui16():
    """Return the unsigned integer type ``ui(16)``."""
    return ui(16)


def ui32():
    """Return the unsigned integer type ``ui(32)``."""
    return ui(32)


def ui64():
    """Return the unsigned integer type ``ui(64)``."""
    return ui(64)

def f16():
    """Return the type produced by ``F16Type.get()``."""
    return F16Type.get()


def f32():
    """Return the type produced by ``F32Type.get()``."""
    return F32Type.get()


def f64():
    """Return the type produced by ``F64Type.get()``."""
    return F64Type.get()


def bf16():
    """Return the type produced by ``BF16Type.get()``."""
    return BF16Type.get()

def f8E5M2():
    """Return the type produced by ``Float8E5M2Type.get()``."""
    return Float8E5M2Type.get()


def f8E4M3():
    """Return the type produced by ``Float8E4M3Type.get()``."""
    return Float8E4M3Type.get()


def f8E4M3FN():
    """Return the type produced by ``Float8E4M3FNType.get()``."""
    return Float8E4M3FNType.get()


def f8E4M3B11FNUZ():
    """Return the type produced by ``Float8E4M3B11FNUZType.get()``."""
    return Float8E4M3B11FNUZType.get()


def f8E3M4():
    """Return the type produced by ``Float8E3M4Type.get()``."""
    return Float8E3M4Type.get()

def none():
    """Return the type produced by ``NoneType.get()``."""
    return NoneType.get()


def complex(type):
    """Build a complex type over the element type ``type``.

    Thin wrapper over ``ComplexType.get``. This function shadows the builtin
    ``complex`` inside this module.
    """
    complex_type = ComplexType.get(type)
    return complex_type


def opaque(dialect_namespace, type_data):
    """Build an opaque type from a dialect namespace and its type data.

    Both arguments are forwarded positionally to ``OpaqueType.get``.
    """
    opaque_type = OpaqueType.get(dialect_namespace, type_data)
    return opaque_type


def _shaped(*shape, element_type: Type = None, type_constructor=None):
    """Shared argument handling for the shaped-type helpers in this module.

    The element type must be supplied exactly one way: either through the
    ``element_type`` keyword or as the last positional argument. Any other
    positional arguments are treated as the dimension sizes.

    Args:
        *shape: Dimension sizes, optionally followed by the element type.
        element_type: Element type, when it is not the last positional arg.
        type_constructor: Callable invoked as ``type_constructor(sizes, elt)``
            when there are sizes, and as ``type_constructor(elt)`` otherwise.

    Returns:
        Whatever ``type_constructor`` returns.

    Raises:
        ValueError: If ``type_constructor`` is missing, or if the element
            type is given both ways or not given at all.
    """
    if type_constructor is None:
        raise ValueError("shaped is an abstract base class - cannot be constructed.")

    # ``len(shape) > 0`` rather than ``bool(shape)``: the name ``bool`` is
    # rebound at module level and no longer refers to the builtin here.
    has_trailing_type = len(shape) > 0 and isinstance(shape[-1], Type)

    # Exactly one source for the element type is allowed. This also rejects
    # the call with no positional args and no ``element_type``, where indexing
    # ``shape[-1]`` would otherwise fail with an unrelated IndexError.
    if (element_type is not None) == has_trailing_type:
        raise ValueError(
            "Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
        )

    if element_type is not None:
        elt, sizes = element_type, shape
    else:
        elt, sizes = shape[-1], shape[:-1]

    if sizes:
        return type_constructor(sizes, elt)
    return type_constructor(elt)


def vector(
    *shape,
    element_type: Type = None,
    # ``builtins.bool`` is spelled out because the module-level name ``bool``
    # is rebound to the i1 type factory before this signature is evaluated.
    scalable: Optional[List[builtins.bool]] = None,
    scalable_dims: Optional[List[int]] = None,
):
    """Build a vector type through ``VectorType.get``.

    Args:
        *shape: Dimension sizes, optionally followed by the element type.
        element_type: Element type, when it is not the last positional arg.
        scalable: Forwarded unchanged as ``VectorType.get(scalable=...)``.
        scalable_dims: Forwarded unchanged as
            ``VectorType.get(scalable_dims=...)``.

    Returns:
        The result of ``VectorType.get`` as invoked by ``_shaped``.
    """
    make_vector = partial(
        VectorType.get, scalable=scalable, scalable_dims=scalable_dims
    )
    return _shaped(*shape, element_type=element_type, type_constructor=make_vector)


def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
    """Build a ranked or unranked tensor type.

    With at least one dimension size the result comes from
    ``RankedTensorType.get``; with no positional arguments, or with only the
    element type, it comes from ``UnrankedTensorType.get``.

    Args:
        *shape: Dimension sizes, optionally followed by the element type.
        element_type: Element type, when it is not the last positional arg.
        encoding: Optional string, wrapped with ``StringAttr.get`` and passed
            as ``encoding`` to ``RankedTensorType.get``.

    Raises:
        ValueError: If ``encoding`` is given for an unranked tensor.
    """
    # The attribute is created before the ranked/unranked decision is made.
    encoding_attr = None if encoding is None else StringAttr.get(encoding)

    only_element_type = len(shape) == 1 and isinstance(shape[-1], Type)
    if shape and not only_element_type:
        ranked_ctor = partial(RankedTensorType.get, encoding=encoding_attr)
        return _shaped(
            *shape, element_type=element_type, type_constructor=ranked_ctor
        )

    if encoding_attr is not None:
        raise ValueError("UnrankedTensorType does not support encoding.")
    return _shaped(
        *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
    )


def memref(
    *shape,
    element_type: Type = None,
    memory_space: Optional[int] = None,
    layout: Optional[StridedLayoutAttr] = None,
):
    """Build a ranked or unranked memref type.

    With at least one dimension size the result comes from ``MemRefType.get``;
    with no positional arguments, or with only the element type, it comes
    from ``UnrankedMemRefType.get``.

    Args:
        *shape: Dimension sizes, optionally followed by the element type.
        element_type: Element type, when it is not the last positional arg.
        memory_space: Optional value turned into an attribute with
            ``Attribute.parse(str(memory_space))``.
        layout: Passed as ``layout`` to ``MemRefType.get`` only; it is not
            forwarded on the unranked path.
    """
    space_attr = memory_space
    if space_attr is not None:
        space_attr = Attribute.parse(str(space_attr))

    only_element_type = len(shape) == 1 and isinstance(shape[-1], Type)
    if shape and not only_element_type:
        ctor = partial(MemRefType.get, memory_space=space_attr, layout=layout)
    else:
        ctor = partial(UnrankedMemRefType.get, memory_space=space_attr)
    return _shaped(*shape, element_type=element_type, type_constructor=ctor)


def tuple(*elements):
    """Build a tuple type from the positional ``elements``.

    The collected arguments are handed to ``TupleType.get_tuple`` as one
    sequence. This function shadows the builtin ``tuple`` inside this module.
    """
    member_types = elements
    return TupleType.get_tuple(member_types)


def function(*, inputs, results):
    """Build a function type from keyword-only ``inputs`` and ``results``.

    Both are forwarded positionally, in that order, to ``FunctionType.get``.
    """
    signature = FunctionType.get(inputs, results)
    return signature