# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This file contains functions to convert MemRef descriptors to NumPy arrays and vice versa.
import numpy as np
import ctypes
try:
import ml_dtypes
except ModuleNotFoundError:
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
ml_dtypes = None
class C128(ctypes.Structure):
    """Element layout of an MLIR double-precision complex: two C doubles."""

    _fields_ = [
        ("real", ctypes.c_double),
        ("imag", ctypes.c_double),
    ]
class C64(ctypes.Structure):
    """Element layout of an MLIR single-precision complex: two C floats."""

    _fields_ = [
        ("real", ctypes.c_float),
        ("imag", ctypes.c_float),
    ]
class F16(ctypes.Structure):
    """Opaque 16-bit storage for an MLIR Float16 element.

    ctypes has no half-precision type, so the raw bits are held in a 16-bit
    integer field.
    """

    _fields_ = [
        ("f16", ctypes.c_int16),
    ]
class BF16(ctypes.Structure):
    """Opaque 16-bit storage for an MLIR BFloat16 element.

    The raw bits live in a 16-bit integer field; the distinct field name keeps
    its NumPy struct dtype different from the one for F16.
    """

    _fields_ = [
        ("bf16", ctypes.c_int16),
    ]
class F8E5M2(ctypes.Structure):
    """Opaque 8-bit storage for an MLIR Float8E5M2 element (raw bits in an int8)."""

    _fields_ = [
        ("f8E5M2", ctypes.c_int8),
    ]
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
    """Map a NumPy dtype to the ctypes type used for memref elements.

    Complex and half-precision dtypes map to this module's structure types.
    bfloat16 and float8_e5m2 are recognised only when the optional
    ``ml_dtypes`` package was importable. Everything else is delegated to
    ``np.ctypeslib.as_ctypes_type``.

    Args:
      dtp: the element dtype (compared with ``==``, so anything NumPy can
        compare against a dtype is accepted).

    Returns:
      A ctypes type suitable for ``ctypes.POINTER(...)``.
    """
    # Element types ctypes cannot express natively -> layout structs.
    special_cases = (
        (np.dtype(np.complex128), C128),
        (np.dtype(np.complex64), C64),
        (np.dtype(np.float16), F16),
    )
    for numpy_dtype, struct_type in special_cases:
        if dtp == numpy_dtype:
            return struct_type

    # These dtypes only exist when ml_dtypes has been imported.
    if ml_dtypes is not None:
        if dtp == ml_dtypes.bfloat16:
            return BF16
        if dtp == ml_dtypes.float8_e5m2:
            return F8E5M2

    return np.ctypeslib.as_ctypes_type(dtp)
def to_numpy(array):
    """Reinterpret an array of element structs as the matching NumPy dtype.

    Arrays whose dtype is one of this module's layout structs (C128, C64, F16,
    BF16, F8E5M2) are returned as a same-buffer ``view`` with the
    corresponding native dtype. Any other array is returned unchanged.

    Raises:
      AssertionError: for BF16/F8E5M2 data when ``ml_dtypes`` is not installed
        (the check is an ``assert``, so it is skipped under ``python -O``).
    """
    element_dtype = array.dtype

    # Views NumPy can express without any extra packages.
    for struct_type, view_dtype in (
        (C128, "complex128"),
        (C64, "complex64"),
        (F16, "float16"),
    ):
        if element_dtype == struct_type:
            return array.view(view_dtype)

    # Low-precision floats whose dtype names are registered by ml_dtypes.
    for struct_type, view_dtype in (
        (BF16, "bfloat16"),
        (F8E5M2, "float8_e5m2"),
    ):
        if element_dtype == struct_type:
            assert (
                ml_dtypes is not None
            ), f"{view_dtype} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
            return array.view(view_dtype)

    return array
def make_nd_memref_descriptor(rank, dtype):
    """Create the ctypes structure type for a ranked memref descriptor.

    The layout is: allocated address, aligned element pointer, offset, then
    ``rank`` sizes followed by ``rank`` strides, all stored as C long longs.

    Args:
      rank: number of dimensions (intended for rank > 0; rank 0 yields
        zero-length shape/strides arrays).
      dtype: ctypes element type the ``aligned`` pointer points to.

    Returns:
      A new ``ctypes.Structure`` subclass named ``MemRefDescriptor``.
    """
    index_type = ctypes.c_longlong
    per_dim_array = index_type * rank

    class MemRefDescriptor(ctypes.Structure):
        """Ranked memref descriptor layout (rank > 0)."""

        _fields_ = [
            ("allocated", index_type),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", index_type),
            ("shape", per_dim_array),
            ("strides", per_dim_array),
        ]

    return MemRefDescriptor
def make_zero_d_memref_descriptor(dtype):
    """Create the ctypes structure type for a rank-0 memref descriptor.

    A rank-0 descriptor holds only the allocated address, the aligned element
    pointer and the offset; it has no shape or strides fields.

    Args:
      dtype: ctypes element type the ``aligned`` pointer points to.

    Returns:
      A new ``ctypes.Structure`` subclass named ``MemRefDescriptor``.
    """
    index_type = ctypes.c_longlong
    element_pointer = ctypes.POINTER(dtype)

    class MemRefDescriptor(ctypes.Structure):
        """Rank-0 memref descriptor layout."""

        _fields_ = [
            ("allocated", index_type),
            ("aligned", element_pointer),
            ("offset", index_type),
        ]

    return MemRefDescriptor
class UnrankedMemRefDescriptor(ctypes.Structure):
    """Unranked memref descriptor: a rank plus an untyped descriptor pointer."""

    _fields_ = [
        ("rank", ctypes.c_longlong),
        ("descriptor", ctypes.c_void_p),
    ]
def get_ranked_memref_descriptor(nparray):
    """Returns a ranked memref descriptor for the given numpy array.

    The descriptor stores the raw address of ``nparray``'s buffer (no data is
    copied), so keep ``nparray`` alive for as long as the descriptor is used.
    The offset is always 0. For rank > 0, the shape is copied from
    ``nparray.shape`` and the strides are converted from bytes to elements.

    Args:
      nparray: the NumPy array to describe; its dtype must be accepted by
        ``as_ctype``.

    Returns:
      An instance of the structure from ``make_zero_d_memref_descriptor``
      (0-d arrays) or ``make_nd_memref_descriptor`` (rank > 0).

    Raises:
      ValueError: if a byte stride of ``nparray`` is not a multiple of its
        item size and therefore cannot be expressed as an element stride.
    """
    ctp = as_ctype(nparray.dtype)
    rank = nparray.ndim
    if rank == 0:
        desc = make_zero_d_memref_descriptor(ctp)()
    else:
        desc = make_nd_memref_descriptor(rank, ctp)()
    desc.allocated = nparray.ctypes.data
    desc.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
    desc.offset = ctypes.c_longlong(0)
    if rank == 0:
        # Rank-0 descriptors have no shape/strides fields.
        return desc

    # NumPy expresses strides in bytes, while memref descriptors count
    # elements. A stride that is not a multiple of the item size (e.g. a view
    # of one field of a packed record dtype) would be silently truncated by
    # floor division and yield wrong data, so reject it explicitly.
    itemsize = nparray.itemsize
    elem_strides = []
    for byte_stride in nparray.strides:
        if byte_stride % itemsize != 0:
            raise ValueError(
                f"stride of {byte_stride} bytes is not a multiple of the "
                f"element size ({itemsize} bytes); copy the array first"
            )
        elem_strides.append(byte_stride // itemsize)

    index_array_t = ctypes.c_longlong * rank
    # Build the shape explicitly instead of assigning nparray.ctypes.shape:
    # that is a c_intp array, which is not guaranteed to be the same ctypes
    # type as c_longlong (e.g. on 32-bit platforms), and ctypes rejects the
    # assignment when the array types differ.
    desc.shape = index_array_t(*nparray.shape)
    desc.strides = index_array_t(*elem_strides)
    return desc
def get_unranked_memref_descriptor(nparray):
    """Wrap ``nparray`` in an unranked memref descriptor.

    A ranked descriptor is built first. The unranked wrapper then records the
    array's rank together with an untyped pointer to that ranked descriptor.
    """
    ranked = get_ranked_memref_descriptor(nparray)
    unranked = UnrankedMemRefDescriptor()
    unranked.rank = nparray.ndim
    # NOTE(review): the pointer is produced via ctypes.cast (not a bare integer
    # address), which appears to keep `ranked` referenced after this function
    # returns; confirm before replacing it with ctypes.addressof.
    unranked.descriptor = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
    return unranked
def move_aligned_ptr_by_offset(aligned_ptr, offset):
    """Return a pointer of the same type advanced by ``offset`` elements.

    Args:
      aligned_ptr: a non-NULL ctypes ``POINTER(T)`` instance (dereferenced to
        obtain the address and element size).
      offset: signed element count to advance by.

    Returns:
      A new pointer of ``type(aligned_ptr)`` at ``address + offset * sizeof(T)``.
    """
    base_address = ctypes.addressof(aligned_ptr.contents)
    element_bytes = ctypes.sizeof(aligned_ptr.contents)
    target_address = base_address + element_bytes * offset
    return ctypes.cast(target_address, type(aligned_ptr))
def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """Converts unranked memrefs to numpy arrays.

    An unranked descriptor stores only a rank and an untyped pointer, so the
    element dtype must be supplied by the caller. The pointer is reinterpreted
    as a ranked descriptor of the recorded rank, and the conversion is
    delegated to ``ranked_memref_to_numpy`` instead of duplicating its logic.

    Args:
      unranked_memref: object whose ``[0]`` is an ``UnrankedMemRefDescriptor``
        (e.g. a ctypes pointer to one).
      np_dtype: element dtype of the memref; must be accepted by ``as_ctype``.

    Returns:
      The NumPy array produced by ``ranked_memref_to_numpy``.
    """
    ctp = as_ctype(np_dtype)
    header = unranked_memref[0]
    ranked_t = make_nd_memref_descriptor(header.rank, ctp)
    ranked_ptr = ctypes.cast(header.descriptor, ctypes.POINTER(ranked_t))
    return ranked_memref_to_numpy(ranked_ptr)
def ranked_memref_to_numpy(ranked_memref):
    """Converts ranked memrefs to numpy arrays.

    The returned array views the memref's memory starting at
    ``aligned + offset`` (no copy). It uses the descriptor's shape and
    element strides and is passed through ``to_numpy`` so that layout structs
    become native dtypes.

    Both descriptor layouts are supported: those from
    ``make_nd_memref_descriptor`` and the rank-0 layout from
    ``make_zero_d_memref_descriptor``. The rank-0 layout has no shape/strides
    fields and previously failed with AttributeError.

    Args:
      ranked_memref: object whose ``[0]`` is a memref descriptor (e.g. a
        ctypes pointer to one).

    Returns:
      A NumPy array aliasing the memref data (0-d for rank-0 descriptors).
    """
    descriptor = ranked_memref[0]
    content_ptr = move_aligned_ptr_by_offset(descriptor.aligned, descriptor.offset)

    if not hasattr(descriptor, "shape"):
        # Rank-0 layout: view the single element as a 0-d array.
        scalar_arr = np.ctypeslib.as_array(content_ptr, shape=(1,)).reshape(())
        return to_numpy(scalar_arr)

    np_arr = np.ctypeslib.as_array(content_ptr, shape=descriptor.shape)
    # Descriptor strides are in elements; NumPy wants bytes.
    strided_arr = np.lib.stride_tricks.as_strided(
        np_arr,
        np.ctypeslib.as_array(descriptor.shape),
        np.ctypeslib.as_array(descriptor.strides) * np_arr.itemsize,
    )
    return to_numpy(strided_arr)