//===- Traits.cpp - Common op traits shared by dialects -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Traits.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/Support/FormatVariadic.h" #include <optional> usingnamespacemlir; bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2) { … } bool OpTrait::util::staticallyKnownBroadcastable( ArrayRef<SmallVector<int64_t, 6>> shapes) { … } bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2, SmallVectorImpl<int64_t> &resultShape) { … } /// Returns the shape of the given type. Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef<int64_t> getShape(Type type) { … } /// Returns the result broadcast composition type from the two given types by /// following NumPy broadcast semantics. Returned type may have dynamic shape if /// either of the input types has dynamic shape. Returns null type if the two /// given types are not broadcast-compatible. /// /// elementType, if specified, will be used as the element type of the /// broadcasted result type. Otherwise it is required that the element type of /// type1 and type2 is the same and this element type will be used as the /// resultant element type. Type OpTrait::util::getBroadcastedType(Type type1, Type type2, Type elementType) { … } /// Returns a tuple corresponding to whether range has tensor or vector type. 
///
/// `types` is accepted by value as any `iterator_range`-like object, and the
/// result is a `std::tuple<bool, bool>`.
///
/// NOTE(review): the body is not visible in this view. The tuple elements are
/// presumably ordered (has tensor, has vector), matching the function name --
/// confirm before relying on the order at call sites.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { … }

/// File-local helper that takes an `inferred` shape and an `existing` shape
/// and returns a `bool`.
///
/// NOTE(review): the body is not visible in this view. Presumably the result
/// is true when the inferred shape is an acceptable match for the shape that
/// is already present; the exact rule (rank mismatch, dynamic extents) must be
/// confirmed against the implementation.
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) { … }

/// File-local helper that takes a shape and returns a `std::string`.
///
/// NOTE(review): the body is not visible in this view. Presumably the string
/// is a human-readable rendering of `shape` for use in diagnostics; confirm
/// the exact format before depending on it (e.g. in test expectations).
static std::string getShapeString(ArrayRef<int64_t> shape) { … }

/// Verification entry point that takes the operation to check and returns a
/// `LogicalResult`.
///
/// NOTE(review): the body is not visible in this view. Presumably this checks
/// that the operand types of `op` are broadcast-compatible with each other and
/// with the result types, emitting an error on `op` on failure -- confirm
/// which combinations are rejected and what the diagnostics say.
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { … }