# llvm/clang/utils/ABITest/Enumeration.py

"""Utilities for enumeration of finite and countably infinite sets.
"""
from __future__ import absolute_import, division, print_function

###
# Countable iteration

# Simplifies some calculations
class Aleph0(int):
    """Singleton stand-in for a countably infinite size.

    The object is an ``int`` subclass (with underlying value 0) so that it can
    be mixed freely with ordinary integers in bounds arithmetic:

    * it orders above every other value and is equal only to itself;
    * addition and multiplication absorb the other operand (except that
      multiplying by zero yields zero);
    * subtraction is rejected with ValueError;
    * division by a non-zero value and raising to a non-zero power give
      back aleph0.

    Because the result of most operations is the singleton itself, callers
    can test for the infinite case with ``x is aleph0``.
    """

    _singleton = None

    def __new__(cls):
        # Every construction returns the one shared instance.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return "<aleph0>"

    def __str__(self):
        return "inf"

    # Rich comparisons.  Python 3 never calls __cmp__, so without these the
    # inherited int comparisons would treat aleph0 as the integer 0.  Since
    # this is an int subclass that overrides the methods, Python also gives
    # them priority as the reflected operation in ``5 < aleph0``.
    def __eq__(self, b):
        return self is b

    def __ne__(self, b):
        return self is not b

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return self is b

    def __gt__(self, b):
        return self is not b

    def __ge__(self, b):
        return True

    def __hash__(self):
        # Defining __eq__ disables the inherited hash on Python 3; restore an
        # identity-based one so the singleton stays usable in sets and dicts.
        return object.__hash__(self)

    def __cmp__(self, b):
        # Only consulted by Python 2's cmp(); kept in line with the rich
        # comparisons above.
        return 0 if self is b else 1

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")

    __rsub__ = __sub__

    def __add__(self, b):
        return self

    __radd__ = __add__

    def __mul__(self, b):
        # Zero annihilates; anything else is absorbed.
        if b == 0:
            return b
        return self

    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self

    # All division flavours (Python 2 and 3, normal and reflected) share the
    # same behaviour.
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


aleph0 = Aleph0()


def base(line):
    """Return the triangular number 0 + 1 + ... + line.

    pairToN and getNthPairInfo use this value as the rank of the first pair
    on diagonal number `line`.
    """
    doubled = line * (line + 1)
    return doubled // 2


def pairToN(pair):
    """Return the rank of the pair (x, y) in the diagonal enumeration.

    The pair lies on diagonal x + y, at offset y from that diagonal's first
    entry.
    """
    col, row = pair
    diagonal = col + row
    return base(diagonal) + row


def getNthPairInfo(N):
    """Return (line, index) locating rank N in the diagonal enumeration.

    `line` is the largest diagonal number with base(line) <= N and `index`
    is N's offset from the start of that diagonal.
    """
    # Avoid various singularities
    if N == 0:
        return (0, 0)

    # Gallop: keep doubling the lower bound while the doubled diagonal still
    # starts at or before N.
    lower = 1
    while base(lower << 1) <= N:
        lower <<= 1
    upper = lower << 1

    # Bisect, maintaining base(lower) <= N < base(upper).
    while upper - lower > 1:
        middle = (lower + upper) >> 1
        if base(middle) <= N:
            lower = middle
        else:
            upper = middle

    return lower, N - base(lower)


def getNthPair(N):
    """Return the N-th pair (x, y) of the unbounded diagonal enumeration."""
    diagonal, row = getNthPairInfo(N)
    col = diagonal - row
    return (col, row)


def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    W and H are each either a positive integer or aleph0 (unbounded).  When
    both are aleph0 the plain diagonal enumeration is used.  Otherwise, if
    useDivmod is true the pair is simply (N % W, N // W) computed on the
    orientation with W <= H; if it is false a diagonal line is slid across
    the W x H rectangle.

    Raises ValueError if either bound is not positive or if N does not
    satisfy 0 <= N < W * H.
    """

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    # A negative N would otherwise fall through and produce coordinates
    # outside the documented range, so reject it together with N >= W * H.
    if N < 0 or N >= W * H:
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W <= H: solve the transposed problem
    # and swap the coordinates back.
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W

    # Conceptually we want to slide a diagonal line across a
    # rectangle. This gives more interesting results for large
    # bounds than using divmod.

    # If in lower left, just return as usual: the first base(W) ranks lie on
    # diagonals 0 .. W-1, which fit inside the rectangle untouched.
    cornerSize = base(W)
    if N < cornerSize:
        return getNthPair(N)

    # Otherwise if in upper right, subtract from corner: count M backwards
    # from the last rank and mirror the resulting pair through the far
    # corner (W-1, H-1).  There is no far corner when H is unbounded.
    if H is not aleph0:
        M = W * H - N - 1
        if M < cornerSize:
            x, y = getNthPair(M)
            return (W - 1 - x, H - 1 - y)

    # Otherwise, compute line and index from number of times we
    # wrap: every diagonal in the middle band holds exactly W pairs.
    N = N - cornerSize
    index, offset = N % W, N // W
    # p = (W-1, 1+offset) + (-1,1)*index
    return (W - 1 - index, 1 + offset + index)


def getNthPairBoundedChecked(
    N, W=aleph0, H=aleph0, useDivmod=False, GNP=getNthPairBounded
):
    """Call GNP(N, W, H, useDivmod) and assert the pair is within bounds.

    GNP defaults to the getNthPairBounded function that was bound when this
    wrapper was defined.  Returns the pair as an (x, y) tuple.
    """
    x, y = GNP(N, W, H, useDivmod)
    inBounds = 0 <= x < W and 0 <= y < H
    assert inBounds
    return x, y


def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th tuple of W elements, bounding elements by H.

    With useLeftToRight the elements are peeled off one at a time: each step
    splits N with getNthPairBounded(N, H) into the next element and the
    remaining rank.  Otherwise the tuple is built by splitting N into a left
    and a right half and recursing on each.
    """

    if useLeftToRight:
        elements = []
        for _ in range(W):
            head, N = getNthPairBounded(N, H)
            elements.append(head)
        return tuple(elements)

    # Small widths are handled directly.
    if W == 0:
        return ()
    if W == 1:
        return (N,)
    if W == 2:
        return getNthPairBounded(N, H, H)

    # Split into two halves; H**k is the number of tuples of width k.
    leftWidth = W // 2
    rightWidth = W - leftWidth
    leftRank, rightRank = getNthPairBounded(N, H**leftWidth, H**rightWidth)
    leftPart = getNthNTuple(
        leftRank, leftWidth, H=H, useLeftToRight=useLeftToRight
    )
    rightPart = getNthNTuple(
        rightRank, rightWidth, H=H, useLeftToRight=useLeftToRight
    )
    return leftPart + rightPart


def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call GNT(N, W, H, useLeftToRight) and assert the result is well formed.

    Asserts that the result has exactly W elements and that every element is
    below H, then returns it unchanged.
    """
    result = GNT(N, W, H, useLeftToRight)
    assert len(result) == W
    assert all(element < H for element in result)
    return result


def getNthTuple(
    N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False
):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x, 0 <=
    y < maxElement.

    N == 0 is the empty tuple.  For any other N, N - 1 is split into a size
    selector S and a rank M, and the M-th tuple of S + 1 elements is
    returned.  useDivmod only affects the split when maxElement is aleph0.

    Raises NotImplementedError when maxElement is finite but maxSize is
    aleph0.
    """

    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()

    rank = N - 1
    if maxElement is aleph0:
        sizeIndex, tupleRank = getNthPairBounded(
            rank, maxSize, useDivmod=useDivmod
        )
    else:
        if maxSize is aleph0:
            raise NotImplementedError("Max element size without max size unhandled")
        # There are maxElement**k tuples of length k, for k = 1 .. maxSize.
        capacities = [maxElement**k for k in range(1, maxSize + 1)]
        sizeIndex, tupleRank = getNthPairVariableBounds(rank, capacities)

    return getNthNTuple(
        tupleRank, sizeIndex + 1, maxElement, useLeftToRight=useLeftToRight
    )


def getNthTupleChecked(
    N,
    maxSize=aleph0,
    maxElement=aleph0,
    useDivmod=False,
    useLeftToRight=False,
    GNT=getNthTuple,
):
    """Call GNT and assert the tuple respects maxSize and maxElement.

    Asserts len(result) <= maxSize and that every element is below
    maxElement, then returns the tuple unchanged.
    """
    # FIXME: maxsize is inclusive
    result = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(result) <= maxSize
    assert all(element < maxElement for element in result)
    return result


def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or if N does not satisfy
    0 <= N < sum(bounds).  Raises RuntimeError if no level contains N.
    """

    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Process the columns from the smallest bound to the largest.  Each
    # distinct bound closes off a horizontal "level": a rectangle spanning
    # every column that is still active, from the previous bound up to this
    # one.
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        W = len(active) - i  # columns at least this tall
        # aleph0 cannot be subtracted from, so an unbounded level keeps an
        # unbounded height.
        H = aleph0 if level is aleph0 else level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        # Skip this level entirely and move on to the next one.
        N -= levelSize
        prevLevel = level

    raise RuntimeError("Unexpected loop completion")


def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call GNVP(N, bounds) and assert the pair fits the given bounds.

    Asserts 0 <= x < len(bounds) and 0 <= y < bounds[x], then returns the
    pair as an (x, y) tuple.
    """
    x, y = GNVP(N, bounds)
    inBounds = 0 <= x < len(bounds) and 0 <= y < bounds[x]
    assert inBounds
    return (x, y)


###


def testPairs():
    """Print the enumeration of a 3 x 6 rectangle in both modes.

    Each rank is printed with its diagonal pair and its divmod pair, then the
    two resulting grids ("a" for diagonal, "b" for divmod) are printed with
    the highest row first, skipping rows that are entirely blank.
    """
    W, H = 3, 6
    blank = "  "
    diagonalGrid = [[blank] * 10 for _ in range(10)]
    divmodGrid = [[blank] * 10 for _ in range(10)]
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (x2, y2))
        label = "%2d" % i
        diagonalGrid[y][x] = label
        divmodGrid[y2][x2] = label

    for title, grid in (("-- a --", diagonalGrid), ("-- b --", divmodGrid)):
        print(title)
        for row in reversed(grid):
            if "".join(row).strip():
                print("  ".join(row))


def testPairsVB():
    """Print the first ranks of a variable-bounds enumeration.

    Uses columns bounded by [2, 2, 4, aleph0, 5, aleph0], prints each rank
    with its pair, then prints the filled grid with the highest row first,
    skipping rows that are entirely blank.
    """
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    grid = [["  " for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        print(i, (x, y))
        grid[y][x] = "%2d" % i

    print("-- a --")
    for ln in grid[::-1]:
        if "".join(ln).strip():
            print("  ".join(ln))


###

# Toggle to use checked versions of enumeration routines.
# Changing the condition to True rebinds the public names to the asserting
# wrappers. Each wrapper captured the unchecked function as a default
# argument when it was defined, so it still delegates to the original;
# calls that look the name up at run time (such as the recursive calls inside
# the unchecked functions) go through the wrapper.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked

# When executed as a script, print both demonstration enumerations.
if __name__ == "__main__":
    testPairs()

    testPairsVB()