# 'shape' Dialect

Description of operations & types within the Shape dialect as well as their
[usage](#different-stages-of-lowering-shape-dialect).

[include "Dialects/ShapeDialectOps.md"]

## Different stages of lowering Shape dialect

This section gives a brief overview of the different uses of the shape dialect
and the lowering between these uses. Currently there are three worlds, or
stages, of lowering of shape functions:

1.  _Error monadic/error carrying/user specification_:
    This "input" form carries both the shape and whether it is in an error
    state as a value. Hence at this level all operations are pure operations
    producing and consuming values, where the values could represent an error.

2.  _Constrained_:
    This form uses a variant of explicit evidence passing to allow leveraging
    existing compiler infrastructure to preserve safety information during
    optimization.

3.  _Side-effecting/asserting_:
    This final lowered form is an imperative form with side-effecting ops
    (e.g., assert) for final codegen.

We will step quickly through the lowering using the example of a matmul.

Starting from the shape function of matmul in the error monadic form
below[^wip_form1]:

```mlir
shape.function_library @shplib {

func.func @matmul(%lhs: !shape.value_shape, %rhs: !shape.value_shape) -> !shape.shape {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We could also allow rank etc operations directly on value_shape too, that
  // would make it nicer as "input" language, but keeping it explicit inside the
  // IR instead and then we could have helper methods in front-end language.
  %lhs_shape = shape.shape_of %lhs : !shape.value_shape -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : !shape.value_shape -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  // This is not minimal as one could ensure the ranks are the same below, also a
  // variadic meet would make it more concise too.
  %r = "shape.meet"(%lhs_rank, %rhs_rank) : (!shape.size, !shape.size) -> !shape.size
  %rank = shape.meet %c2, %r, error="requires rank 2 operands" :
    !shape.size, !shape.size -> !shape.size
  %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
    !shape.shape, !shape.shape -> !shape.shape
  %res = shape.concat %l0, %r1
  // Should have `shape.return %res requires %c, %rank` to enable
  return %res : !shape.shape
}

} mapping {
  foo.matmul = @matmul
}
```

*   We are using the default builtin func and return here. Preferably we'd use
    `shape_func` as a special function op that allows passing multiple results
    back that affect correct execution (e.g., serves as an error join).
    *   This would also mean one can't reify it inside a regular function
        without handling the `shape.return`; that is a feature here, as these
        are more of a template.
    *   Currently we also have not marked `meet` as having no side effects, to
        avoid DCE until we have `shape.return`, at which point computing the
        meet could be treated as purely computational, returning an error.
*   Meet represents a constraint that should hold, so it should not be used to
    check *if* something is equal. E.g., this means `meet` can't be used to
    represent

    ```
       either(meet(x, y), meet(y,z))
    ```

*   This could have been written more concisely as something like

    ```
      concat(lhs[0], rhs[1]) if rank(lhs) == 2 &&
        rank(rhs) == 2 && lhs[1] == rhs[0]
    ```

    but the front-end proper is not the focus here.
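
As a rough illustration of the error-carrying form, the following Python sketch
models the matmul shape function on plain tuples. It is not MLIR and not part of
the dialect: `ShapeError`, `meet`, and `matmul_shape` are hypothetical names
chosen for this sketch, and the final loop over the constraints is an assumed
stand-in for the error join that a `shape.return ... requires` would provide.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class ShapeError:
    """Hypothetical value representing a shape in the error state."""
    message: str


Shape = Tuple[int, ...]


def meet(a, b, error="operands do not match"):
    """Model of a constraint: yields the common value or an error value."""
    # Errors are ordinary values and propagate instead of raising.
    if isinstance(a, ShapeError):
        return a
    if isinstance(b, ShapeError):
        return b
    return a if a == b else ShapeError(error)


def matmul_shape(lhs: Shape, rhs: Shape) -> Union[Shape, ShapeError]:
    r = meet(len(lhs), len(rhs))
    rank = meet(2, r, error="requires rank 2 operands")
    l0, l1 = lhs[:1], lhs[1:]  # split_at index 1
    r0, r1 = rhs[:1], rhs[1:]
    c = meet(l1, r0, error="inner dimensions required to match")
    # Assumed error join: the result is only valid if every constraint held.
    for constraint in (rank, c):
        if isinstance(constraint, ShapeError):
            return constraint
    return l0 + r1  # concat
```

Every operation here is pure: a failed `meet` produces a value rather than
aborting, which is why the constraints have to be joined into the result
explicitly for them to affect the outcome.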

We lower to the "most" nested form directly (see
[test](https://github.com/tensorflow/tensorflow/blob/64062b5c51e04e370df26551d247496787d3f5c2/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L3088)
for an example reification along with legalization). In the above, the shape
function lived in a separate shape function library, while here we would
normally reify it as part of lowering; for simplicity it is shown as a
standalone shape function.

```mlir
func.func @matmul_shape1(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We allow `shape.shape_of` to return either a `!shape.shape` or
  // `tensor<?xindex>` type. When the input is a tensor, the most refined type
  // is a tensor of `index`, but that is not required.
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  %w1 = shape.cstr_eq %lhs_rank, %rhs_rank : !shape.witness
  %res = shape.assuming %w1 -> tensor<?xindex> {
    %r = shape.any %lhs_rank, %rhs_rank : (!shape.size, !shape.size) -> !shape.size
    // Error message needs an addition, currently only on cstr_require.
    %w2 = shape.cstr_eq %c2, %r, error="requires rank 2 operands"
    %res_1 = shape.assuming %w2 -> tensor<?xindex> {
      // Here the lowered
      //   %rank = shape.any %c2, %r : (!shape.size, !shape.size) -> !shape.size
      // is dead and so elided further. But if `%rank` was actually consumed,
      // then it could have been folded in `shape.any`.
      %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
        !shape.shape, !shape.shape -> !shape.shape
      %res_2 = shape.concat %l0, %r1
      shape.assuming_yield %res_2
    }
    shape.assuming_yield %res_1
  }
  return %res : tensor<?xindex>
}
```
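
The nesting in this constrained form can be modeled in a few lines of Python.
This is only a sketch under stated assumptions: `Witness`, `cstr_eq`, and
`assuming` are hypothetical Python stand-ins named after the ops, and having
`assuming` raise on a false witness is a simplification made here so the model
can run. In the dialect the region is merely valid under the witness, and what
happens on failure is decided by later lowering.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Witness:
    """Hypothetical evidence value: records whether a constraint held."""
    holds: bool
    error: str = ""


def cstr_eq(a, b, error=""):
    """Produce evidence that `a` equals `b` without failing eagerly."""
    return Witness(a == b, error)


def assuming(witness, region):
    """Run `region` only under the evidence carried by `witness`."""
    # Simplification for this sketch: a false witness raises here.
    if not witness.holds:
        raise ValueError(witness.error or "constraint violated")
    return region()


def matmul_shape(lhs, rhs):
    w1 = cstr_eq(len(lhs), len(rhs))

    def ranks_equal():
        w2 = cstr_eq(2, len(lhs), "requires rank 2 operands")

        def rank_is_two():
            # Splitting is only safe once the rank is known to be 2.
            l0, l1 = lhs[:1], lhs[1:]
            r0, r1 = rhs[:1], rhs[1:]
            w3 = cstr_eq(l1, r0, "inner dimensions required to match")
            return assuming(w3, lambda: l0 + r1)

        return assuming(w2, rank_is_two)

    return assuming(w1, ranks_equal)
```

Because each region consumes its witness as an ordinary value, the guarded
computation depends on the check through data flow, which is the property that
lets existing optimizations move code without dropping the safety information.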

We can now hoist computations of constraints where possible (which in the case
below is not too many, as we need to verify the rank before we can split):

```mlir
func.func @matmul_shape2(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> index
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> index
  %w1 = shape.cstr_eq %c2, %lhs_rank, error="requires rank 2 operands"
  %w2 = shape.cstr_eq %c2, %rhs_rank, error="requires rank 2 operands"
  %w = shape.assuming_all %w1, %w2
  %res = shape.assuming %w -> tensor<?xindex> {
    %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %w3 = shape.cstr_eq %l1, %r0, error="inner dimensions required to match"
    %res_1 = shape.assuming %w3 -> tensor<?xindex> {
      %res_2 = shape.concat %l0, %r1
      shape.assuming_yield %res_2
    }
    shape.assuming_yield %res_1
  }
  return %res : tensor<?xindex>
}
```
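
The effect of hoisting can be sketched in Python as follows. Both rank
constraints depend only on the inputs, so they can be computed up front and
combined into a single witness. The names `Witness`, `cstr_eq`, `assuming_all`,
and `rank_witness` are hypothetical stand-ins for this sketch, and collecting
the error messages into a tuple is an assumption made here, not something the
dialect specifies.

```python
from typing import NamedTuple, Tuple


class Witness(NamedTuple):
    """Hypothetical evidence value with the messages of failed constraints."""
    holds: bool
    errors: Tuple[str, ...] = ()


def cstr_eq(a, b, error):
    """Evidence that `a` equals `b`; records `error` when it does not."""
    return Witness(True) if a == b else Witness(False, (error,))


def assuming_all(*witnesses):
    """Conjunction of witnesses: holds only if every input witness holds."""
    errors = tuple(e for w in witnesses for e in w.errors)
    return Witness(all(w.holds for w in witnesses), errors)


def rank_witness(lhs, rhs):
    # Neither check depends on the other, so both are hoisted to the top
    # and combined before any guarded region is entered.
    w1 = cstr_eq(2, len(lhs), "requires rank 2 operands")
    w2 = cstr_eq(2, len(rhs), "requires rank 2 operands")
    return assuming_all(w1, w2)
```

The inner-dimension check cannot be hoisted the same way, since it needs the
split shapes, and splitting is only meaningful once the combined rank witness
holds.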

The above form can now be lowered to the fully imperative form (see
[test](https://github.com/tensorflow/mlir-hlo/blob/af14e1ded33c3164d4418c5d234b5b346b6d017c/tests/rank-specialization.mlir#L22)
for example).

```mlir
func.func @matmul_shape3(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> index
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> index
  // The ranks are plain `index` values here, so compare them with `arith`.
  %w1 = arith.cmpi eq, %lhs_rank, %rhs_rank : index
  %w2 = arith.cmpi eq, %c2, %lhs_rank : index
  %w3 = arith.andi %w1, %w2 : i1
  cf.assert %w3, "requires rank 2 operands"
  %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
    (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
    (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  %w4 = shape.shape_eq %l1, %r0 : tensor<?xindex>, tensor<?xindex>
  cf.assert %w4, "inner dimensions required to match"
  %res = shape.concat %l0, %r1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %res : tensor<?xindex>
}
}
```

*   In this case form 3 is as easy as, and close to, form 1 (but only because
    no reordering was required). So it is a good question whether the frontend
    authoring language could be more similar to the imperative form (under
    discussion).
*   The form presented above is an intermediate form during a lowering pass.
    If used as input we would need to restrict the optimizations on it, as the
    `shape` dialect operations are no longer connected by producer-consumer
    relationships to enforce guard checking.
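
For comparison, the imperative form maps almost line for line onto ordinary
Python. This is a sketch on plain tuples, with `check` and `matmul_shape` as
hypothetical names; `check` is used instead of the `assert` statement only so
that the checks survive running Python with optimizations enabled.

```python
def check(condition, message):
    """Side-effecting guard: aborts the computation when the check fails."""
    if not condition:
        raise AssertionError(message)


def matmul_shape(lhs, rhs):
    lhs_rank, rhs_rank = len(lhs), len(rhs)
    check(lhs_rank == rhs_rank and lhs_rank == 2, "requires rank 2 operands")
    # Nothing below consumes a value produced by the check above. Only the
    # statement order keeps the split after the rank check.
    l0, l1 = lhs[:1], lhs[1:]
    r0, r1 = rhs[:1], rhs[1:]
    check(l1 == r0, "inner dimensions required to match")
    return l0 + r1
```

The comment in the middle marks the point made in the second bullet: the guards
and the computation are related only by ordering, so a transformation that
reorders statements freely could move the split ahead of its check.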

The above could be further lowered by using `tensor.dim`, `tensor.from_elements`,
etc. (or one could even lower these by way of, say, the MHLO or TOSA dialect).

[^wip_form1]: This form is the least used inside the current workflows and needs more work. In particular, in the example we use `shape_func`, whereas in the code we instead use the standard func, as form 1 isn't used explicitly.