llvm/mlir/lib/Bindings/Python/TransformInterpreter.cpp

//===- TransformInterpreter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pybind classes for the transform dialect interpreter.
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Transform/Interpreter.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace {
struct PyMlirTransformOptions {
  PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
  PyMlirTransformOptions(PyMlirTransformOptions &&other) {
    options = other.options;
    other.options.ptr = nullptr;
  }
  PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;

  ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }

  MlirTransformOptions options;
};
} // namespace

/// Registers the transform dialect interpreter bindings on module `m`:
///   - class `TransformOptions`, wrapping `MlirTransformOptions`, with the
///     read/write properties `expensive_checks` and
///     `enforce_single_top_level_transform_op`;
///   - `apply_named_sequence(payload_root, transform_root, transform_module,
///     transform_options=TransformOptions())`;
///   - `copy_symbols_and_merge_into(target, other)`.
/// Both functions raise `ValueError` containing the collected diagnostics when
/// the underlying C API call reports failure.
static void populateTransformInterpreterSubmodule(py::module &m) {
  // Property accessors for TransformOptions; each forwards to the C API.
  auto getExpensiveChecks = [](const PyMlirTransformOptions &self) {
    return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
  };
  auto setExpensiveChecks = [](PyMlirTransformOptions &self, bool enable) {
    mlirTransformOptionsEnableExpensiveChecks(self.options, enable);
  };
  auto getEnforceSingleTopLevel = [](const PyMlirTransformOptions &self) {
    return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
        self.options);
  };
  auto setEnforceSingleTopLevel = [](PyMlirTransformOptions &self,
                                     bool enable) {
    mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, enable);
  };

  // The class is registered first: the default value of
  // `transform_options` below is an instance of it.
  py::class_<PyMlirTransformOptions> optionsClass(m, "TransformOptions",
                                                  py::module_local());
  optionsClass.def(py::init());
  optionsClass.def_property("expensive_checks", getExpensiveChecks,
                            setExpensiveChecks);
  optionsClass.def_property("enforce_single_top_level_transform_op",
                            getEnforceSingleTopLevel, setEnforceSingleTopLevel);

  auto applyNamedSequence = [](MlirOperation payloadRoot,
                               MlirOperation transformRoot,
                               MlirOperation transformModule,
                               const PyMlirTransformOptions &options) {
    mlir::python::CollectDiagnosticsToStringScope diagnostics(
        mlirOperationGetContext(transformRoot));

    // Calling back into Python to invalidate everything under the payload
    // root. This is awkward, but we don't have access to PyMlirContext
    // object here otherwise.
    py::object payloadHandle = py::cast(payloadRoot);
    py::object payloadContext = payloadHandle.attr("context");
    payloadContext.attr("_clear_live_operations_inside")(payloadRoot);

    MlirLogicalResult status = mlirTransformApplyNamedSequence(
        payloadRoot, transformRoot, transformModule, options.options);
    if (mlirLogicalResultIsFailure(status))
      throw py::value_error(
          "Failed to apply named transform sequence.\nDiagnostic message " +
          diagnostics.takeMessage());
  };

  auto copySymbolsAndMergeInto = [](MlirOperation target,
                                    MlirOperation other) {
    mlir::python::CollectDiagnosticsToStringScope diagnostics(
        mlirOperationGetContext(target));

    MlirLogicalResult status = mlirMergeSymbolsIntoFromClone(target, other);
    if (!mlirLogicalResultIsFailure(status))
      return;
    throw py::value_error("Failed to merge symbols.\nDiagnostic message " +
                          diagnostics.takeMessage());
  };

  m.def("apply_named_sequence", applyNamedSequence, py::arg("payload_root"),
        py::arg("transform_root"), py::arg("transform_module"),
        py::arg("transform_options") = PyMlirTransformOptions());
  m.def("copy_symbols_and_merge_into", copySymbolsAndMergeInto,
        py::arg("target"), py::arg("other"));
}

// Python extension entry point for `_mlirTransformInterpreter`: sets the
// module docstring, then registers the interpreter bindings on the module.
PYBIND11_MODULE(_mlirTransformInterpreter, mod) {
  mod.attr("__doc__") = "MLIR Transform dialect interpreter functionality.";
  populateTransformInterpreterSubmodule(mod);
}