llvm/mlir/include/mlir/Dialect/Transform/Interfaces/MatchInterfaces.h

//===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace transform {
class MatchOpInterface;

namespace detail {
/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
/// first operand.
///
/// Match ops may declare `matchOperation` with either a plain `Operation *` or
/// a `std::optional<Operation *>` as its first parameter. This helper invokes
/// it with "no operation": `nullptr` for the former, `std::nullopt` for the
/// latter. The distinction matters: passing `nullptr` to a
/// `std::optional<Operation *>` parameter would construct an *engaged*
/// optional holding a null pointer rather than an empty optional.
///
/// \param op      the match op whose `matchOperation` is invoked.
/// \param results forwarded unchanged to `matchOperation`.
/// \param state   forwarded unchanged to `matchOperation`.
/// \returns the result of `op.matchOperation(...)`.
template <typename OpTy>
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
                                                   TransformResults &results,
                                                   TransformState &state) {
  // Inspect the declared type of the first parameter of
  // `OpTy::matchOperation` to choose the matching "empty" argument.
  using FirstArgTy = typename llvm::function_traits<
      decltype(&OpTy::matchOperation)>::template arg_t<0>;
  if constexpr (std::is_same_v<FirstArgTy, Operation *>) {
    return op.matchOperation(nullptr, results, state);
  } else {
    return op.matchOperation(std::nullopt, results, state);
  }
}
} // namespace detail

/// Op trait for transform match ops, built on the CRTP helper
/// `OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait>`.
/// NOTE(review): the trait body is empty in this view; by its name it marks
/// matchers that match at most one payload operation — confirm against the
/// full definition.
template <typename OpTy>
class AtMostOneOpMatcherOpTrait
    : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {};

/// Op trait for transform match ops that derives from
/// `AtMostOneOpMatcherOpTrait<OpTy>`, so it inherits everything that trait
/// provides.
/// NOTE(review): by its name it presumably marks matchers of exactly one
/// payload operation — confirm against the full definition.
template <typename OpTy>
class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {};

/// Op trait for transform match ops, built on the CRTP helper
/// `OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait>`.
/// NOTE(review): by its name it presumably marks matchers operating on a
/// single payload value — confirm against the full definition.
template <typename OpTy>
class SingleValueMatcherOpTrait
    : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {};

//===----------------------------------------------------------------------===//
// Printing/parsing for positional specification matchers
//===----------------------------------------------------------------------===//

/// Parses a positional index specification for transform match operations.
/// The following forms are accepted:
///
///  - `all`: sets `isAll` and returns;
///  - comma-separated-integer-list: populates `rawDimList` with the values;
///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
///  with the values and sets `isInverted`.
///
/// The parsed specification is reported through the three out-parameters
/// `rawDimList`, `isInverted`, and `isAll`.
ParseResult parseTransformMatchDims(OpAsmParser &parser,
                                    DenseI64ArrayAttr &rawDimList,
                                    UnitAttr &isInverted, UnitAttr &isAll);

/// Prints a positional index specification for transform match operations.
/// Printing counterpart of `parseTransformMatchDims`; the specification is
/// given by `rawDimList`, `isInverted`, and `isAll`.
/// NOTE(review): presumably emits the same `all` / integer-list /
/// `except(...)` forms the parser accepts — confirm in the implementation.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
                             DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
                             UnitAttr isAll);

//===----------------------------------------------------------------------===//
// Utilities for positional specification matchers
//===----------------------------------------------------------------------===//

/// Checks whether the positional specification given by `raw`, `inverted`,
/// and `all` is valid, and reports errors otherwise.
/// NOTE(review): presumably diagnostics are emitted on `op` — confirm in the
/// implementation.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
                                         bool inverted, bool all);

/// Populates `result` with the positional identifiers relative to `maxNumber`.
/// If `isAll` is set, the result will contain all numbers from `0` to
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
/// values from `rawList` are interpreted as counting backwards from
/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while
/// non-negative numbers remain as is. If `isInverted` is set, populates
/// `result` with those values from the `0` to `maxNumber - 1` inclusive range
/// that don't appear in `rawList`. If `rawList` contains values that are
/// greater than or equal to `maxNumber` or less than `-maxNumber`, produces a
/// silenceable error at the given location. `maxNumber` must be positive. If
/// `rawList` contains duplicate numbers or numbers that become duplicate after
/// negative value remapping, emits a silenceable error.
DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
                          ArrayRef<int64_t> rawList, int64_t maxNumber,
                          SmallVectorImpl<int64_t> &result);

} // namespace transform
} // namespace mlir

#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"

#endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H