// llvm/mlir/include/mlir/Dialect/Arith/Utils/Utils.h

//===- Utils.h - General Arith transformation utilities ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various transformation utilities for
// the Arith dialect. These are not passes by themselves but are used
// either by passes, optimization sequences, or in turn by other transformation
// utilities.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARITH_UTILS_UTILS_H
#define MLIR_DIALECT_ARITH_UTILS_UTILS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {

ReassociationIndices;

/// Infer the output shape for a {memref|tensor}.expand_shape when it is
/// possible to do so.
///
/// Note: This should *only* be used to implement
/// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
/// If you need to infer the output shape you should use the static method of
/// `ExpandShapeOp` instead of calling this.
///
/// `inputShape` is the shape of the tensor or memref being expanded as a
/// sequence of SSA values or constants. `expandedType` is the output shape of
/// the expand_shape operation. `reassociation` is the reassociation denoting
/// the output dims each input dim is mapped to.
///
/// Returns the inferred output shape as a list of OpFoldResults, following
/// the conventions for the output_shape and static_output_shape inputs to the
/// expand_shape ops, or std::nullopt when the shape cannot be inferred.
std::optional<SmallVector<OpFoldResult>>
inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<OpFoldResult> inputShape);

/// Builds an op matcher that recognizes an `arith::ConstantIndexOp`.
detail::op_matcher<arith::ConstantIndexOp>
matchConstantIndex();

/// Returns a bit vector that marks the positions in `shape` whose size is 1,
/// sized according to `rank`.
/// NOTE(review): this contract is inferred from the function name and
/// signature; confirm against the implementation (e.g. whether `rank` is
/// required to equal `shape.size()`).
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
                                            ArrayRef<int64_t> shape);

/// Materializes `ofr` as a Value. If `ofr` already holds a Value, that Value
/// is returned; if it holds an integer attribute, a ConstantOp is created for
/// it. Any other kind of attribute is unsupported.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
                                    OpFoldResult ofr);

/// Materializes `ofr` as a Value. If `ofr` already holds a Value, that Value
/// is returned; if it holds an integer attribute, a ConstantIndexOp is created
/// for it. Any other kind of attribute is unsupported.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                      OpFoldResult ofr);

/// Batch form of the single-OpFoldResult overload above: materializes each
/// element of `valueOrAttrVec` as a Value.
SmallVector<Value>
getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                ArrayRef<OpFoldResult> valueOrAttrVec);

/// Casts the index-like `value` (an index or an integer) to the index-like
/// `targetType`. When `value` already has type `targetType`, it is returned
/// as-is and no cast is built.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
                                      Type targetType, Value value);

/// Converts the scalar `operand` to `toType`. When no conversion is possible,
/// a warning is emitted and `operand` is handed back unchanged, which will
/// most likely surface later as a verification failure.
/// NOTE(review): `isUnsignedCast` presumably selects unsigned rather than
/// signed semantics for the conversion; confirm against the implementation.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
                           Type toType, bool isUnsignedCast);

/// Create a constant of type `type` at location `loc` whose value is `value`.
/// If `type` is a shaped type, a splat constant of the given value is created;
/// otherwise a scalar constant is created. Constants are folded if possible.

/// APInt overload: `value` must match the element type of `type`.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                  const APInt &value);
/// Convenience overload taking a plain integer.
/// NOTE(review): presumably `type` must have an integer-like element type and
/// `value` is converted to that type; confirm against the implementation.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                  int64_t value);
/// APFloat overload: `value` must match the element type of `type`.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                  const APFloat &value);

/// Yields the integer type of the integer held in `ofr`.
/// Attribute kinds other than integer attributes are unsupported.
Type getType(OpFoldResult ofr);

/// Helper for building simple arithmetic quantities with minimal type
/// inference support.
/// NOTE(review): as declared here the struct has no members, so it cannot
/// build anything; confirm whether its constructor and builder methods were
/// intentionally omitted from this header.
struct ArithBuilder {
};

namespace arith {

/// Builds the product of a sequence of values.
/// If `values` = (v0, v1, ..., vn), then the returned value is
/// v0 * v1 * ... * vn. All values must have the same type.
///
/// The overload without `resultType` requires `values` to contain at least
/// one element; the result then has the same type as the elements of
/// `values`. In the overload with `resultType`, an empty `values` yields 1
/// with type `resultType`.
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
                    Type resultType);

/// Maps the float type name `name` to the corresponding float type in `ctx`.
/// Returns an empty optional when `name` does not map to a float type.
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);

} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H