llvm/mlir/python/mlir/dialects/memref.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import operator
from itertools import accumulate
from typing import Optional

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation


def _is_constant_int_like(i):
    """Return whether `i` is a `Value` defined by a `ConstantOp` of integer-like type.

    All of the following must hold:
      * `i` is a `Value`;
      * the value's owner is an `Operation`;
      * that operation's `opview` is a `ConstantOp`;
      * `_is_integer_like_type` accepts the value's type.
    """
    if not isinstance(i, Value):
        return False
    defining = i.owner
    # Values whose owner is not an Operation cannot be constants.
    if not isinstance(defining, Operation):
        return False
    if not isinstance(defining.opview, ConstantOp):
        return False
    return _is_integer_like_type(i.type)


def _is_static_int_like(i):
    """Return whether `i` denotes a statically known integer.

    Accepted are Python ints that `ShapedType.is_dynamic_size` does not flag
    as the dynamic sentinel, and anything `_is_constant_int_like` accepts.
    """
    if isinstance(i, int) and not ShapedType.is_dynamic_size(i):
        return True
    # Not a usable Python int: fall back to the constant-op check.
    return _is_constant_int_like(i)


def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infer the result memref type of a subview with static sizes and strides.

    Args:
      source_memref_type: Type of the memref being viewed. Its
        `get_strides_and_offset()`, `element_type` and `memory_space` are used.
      offsets: Per-dimension offsets. Entries that are not static int-like make
        the resulting layout offset dynamic.
      static_sizes: Per-dimension sizes; every entry must be static int-like.
      static_strides: Per-dimension strides; every entry must be static
        int-like.

    Returns:
      A tuple `(offsets, static_sizes, static_strides, result_type)`. The three
      sequences are new lists in which constant-op values have been replaced by
      their `literal_value`; `result_type` is the `MemRefType` built from them.

    Raises:
      ValueError: If a size, a stride, or a stride of the source type is not
        static int-like.
    """
    source_strides, source_offset = source_memref_type.get_strides_and_offset()

    # Take private list copies: tuples are accepted and the caller's sequences
    # are never modified.
    offsets = list(offsets)
    static_sizes = list(static_sizes)
    static_strides = list(static_strides)
    source_strides = list(source_strides)

    for group in (static_sizes, static_strides, source_strides):
        for entry in group:
            if not _is_static_int_like(entry):
                raise ValueError(
                    "Only inferring from python or mlir integer constant is supported."
                )

    def _as_python_int(entry):
        # Constant-op results are swapped for the literal they carry; anything
        # else (Python ints, non-constant values) passes through untouched.
        if _is_constant_int_like(entry):
            return entry.owner.opview.literal_value
        return entry

    offsets = [_as_python_int(entry) for entry in offsets]
    static_sizes = [_as_python_int(entry) for entry in static_sizes]
    static_strides = [_as_python_int(entry) for entry in static_strides]

    # The result offset is only computable when every offset, as well as the
    # offset of the source layout, is statically known.
    if all(_is_static_int_like(entry) for entry in [*offsets, source_offset]):
        target_offset = sum(
            (off * stride for off, stride in zip(offsets, source_strides)),
            source_offset,
        )
    else:
        target_offset = ShapedType.get_dynamic_size()

    target_strides = [
        src_stride * step for src_stride, step in zip(source_strides, static_strides)
    ]

    # Strides implied by the sizes alone: running products of the trailing
    # sizes, ending in 1. If the result matches them (and has zero offset), no
    # explicit layout is attached, so downstream ops (e.g., expand_shape) are
    # not handed a needlessly complicated type.
    default_strides = list(accumulate(reversed(static_sizes[1:]), operator.mul))
    default_strides.reverse()
    default_strides.append(1)

    uses_default_layout = target_strides == default_strides and target_offset == 0
    if uses_default_layout:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)

    result_type = MemRefType.get(
        static_sizes,
        source_memref_type.element_type,
        layout,
        source_memref_type.memory_space,
    )
    return offsets, static_sizes, static_strides, result_type


# Capture the `subview` that is in scope at this point (presumably the
# generated builder brought in by `from ._memref_ops_gen import *`) under a
# private name, because the `subview` wrapper defined below rebinds the name
# and still needs to call the original.
_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Build a subview of `source`, inferring the result type when possible.

    Args:
      source: The memref value to take a view of.
      offsets: Per-dimension offsets; `None` is treated as an empty list.
      sizes: Per-dimension sizes; `None` is treated as an empty list.
      strides: Per-dimension strides; `None` is treated as an empty list.
      result_type: Explicit result type. When omitted it is inferred, which
        requires all sizes, strides and source strides to be static int-like.
      loc: Forwarded unchanged to the generated builder.
      ip: Forwarded unchanged to the generated builder.

    Returns:
      Whatever the generated `subview` builder returns.

    Raises:
      ValueError: If `result_type` is omitted and a size, a stride, or a stride
        of the source type is not static int-like.
    """
    offsets = [] if offsets is None else offsets
    sizes = [] if sizes is None else sizes
    strides = [] if strides is None else strides

    # Queried unconditionally, i.e. also when the caller passes `result_type`.
    src_strides, _ = source.type.get_strides_and_offset()

    if result_type is None:
        can_infer = all(
            _is_static_int_like(entry)
            for group in (sizes, strides, src_strides)
            for entry in group
        )
        if not can_infer:
            raise ValueError(
                "mixed static/dynamic offset/sizes/strides requires explicit result type."
            )
        # If any are arith.constant results then this will canonicalize to python int
        # (which can then be used to fully specify the subview).
        inferred = _infer_memref_subview_result_type(
            source.type, offsets, sizes, strides
        )
        offsets, sizes, strides, result_type = inferred

    # Split every mixed list into its dynamic and static parts; the middle
    # ("packed") element of each triple is not needed here.
    dyn_offsets, _, static_offsets = _dispatch_mixed_values(offsets)
    dyn_sizes, _, static_sizes = _dispatch_mixed_values(sizes)
    dyn_strides, _, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        dyn_offsets,
        dyn_sizes,
        dyn_strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )