llvm/mlir/include/mlir/Interfaces/InferTypeOpInterface.h

//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
#define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {

class ShapedTypeComponents;
ReifiedRankedShapedTypeDims;

/// Reify the shape of the result of an operation (typically in terms of the
/// shape of its operands).
LogicalResult
reifyResultShapes(OpBuilder &b, Operation *op,
                  ReifiedRankedShapedTypeDims &reifiedReturnShapes);

/// Adaptor class to abstract the differences between whether value is from
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
class ShapeAdaptor {};

/// ShapedTypeComponents that represents the components of a ShapedType.
/// The components consist of
///  - A ranked or unranked shape with the dimension specification match those
///    of ShapeType's getShape() (e.g., dynamic dimension represented using
///    ShapedType::kDynamic)
///  - A element type, may be unset (nullptr)
///  - A attribute, may be unset (nullptr)
/// Used by ShapedType type inferences.
class ShapedTypeComponents {};

/// Range of values and shapes (corresponding effectively to Shapes dialect's
/// ValueShape type concept).
// Currently this exposes the Value (of operands) and Type of the Value. This is
// not ideal as then one can accidentally reference an out of date shape. This
// is done to both enable gradual switch and also as OpAdaptor doesn't currently
// allow returning anything other than Value.
class ValueShapeRange : public ValueRange::RangeBaseT {};

namespace detail {
// Helper function to infer return tensor returns types given element and
// shape inference function.
LogicalResult
inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
                       SmallVectorImpl<Type> &inferredReturnTypes);

/// Verifies that the inferred result types match the actual result types for
/// the op. Precondition: op implements InferTypeOpInterface.
LogicalResult verifyInferredResultTypes(Operation *op);
} // namespace detail

namespace OpTrait {
template <typename ConcreteType>
class InferTensorType;
} // namespace OpTrait
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/InferTypeOpInterface.h.inc"

namespace mlir {
namespace OpTrait {

/// Op trait deriving from TraitBase for ops of type `ConcreteType`.
/// NOTE(review): the trait body is empty in this view; the name suggests it
/// adapts return-type inference to the op's adaptor, so confirm the contract.
template <class ConcreteType>
class InferTypeOpAdaptor
    : public TraitBase<ConcreteType, InferTypeOpAdaptor> {};

/// Shaped-type counterpart of InferTypeOpAdaptor, also deriving from
/// TraitBase. NOTE(review): body empty in this view; presumably adapts
/// shaped-type component inference in the same way, so confirm.
template <class ConcreteType>
class InferShapedTypeOpAdaptor
    : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};

/// Tensor type inference trait: builds a tensor type out of the inferred
/// shape and element type.
/// Requires: the op implements both InferShapedTypeOpInterface and
///   InferTypeOpInterface. A looser requirement would be possible (e.g. the op
///   implements inferReturnTypeComponents and it always populates every element
///   type and shape, or fails), but the trait is currently only used where both
///   interfaces are implemented, so it is kept restricted for now.
template <class ConcreteType>
class InferTensorType
    : public TraitBase<ConcreteType, InferTensorType> {};

} // namespace OpTrait
} // namespace mlir

#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_