//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the dialect for the Toy IR: custom type parsing and // operation verification. // //===----------------------------------------------------------------------===// #include "toy/Dialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include <algorithm> #include <cassert> #include <cstddef> #include <cstdint> #include <string> usingnamespacemlir; usingnamespacemlir::toy; #include "toy/Dialect.cpp.inc" //===----------------------------------------------------------------------===// // ToyInlinerInterface //===----------------------------------------------------------------------===// /// This class defines the interface for handling inlining with Toy /// operations. struct ToyInlinerInterface : public DialectInlinerInterface { … }; //===----------------------------------------------------------------------===// // Toy Operations //===----------------------------------------------------------------------===// /// A generalized parser for binary operations. This parses the different forms /// of 'printBinaryOp' below. static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { … } /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { … } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// /// Build a constant operation. /// The builder is passed as an argument, so is the state that this method is /// expected to fill in order to build the operation. void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, double value) { … } /// The 'OpAsmParser' class provides a collection of methods for parsing /// various punctuation, as well as attributes, operands, types, etc. Each of /// these methods returns a `ParseResult`. This class is a wrapper around /// `LogicalResult` that can be converted to a boolean `true` value on failure, /// or `false` on success. This allows for easily chaining together a set of /// parser rules. These rules are used to populate an `mlir::OperationState` /// similarly to the `build` methods described above. mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { … } /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. void ConstantOp::print(mlir::OpAsmPrinter &printer) { … } /// Verify that the given attribute value is valid for the given type. static llvm::LogicalResult verifyConstantForType(mlir::Type type, mlir::Attribute opaqueValue, mlir::Operation *op) { … } /// Verifier for the constant operation. This corresponds to the `::verify(...)` /// in the op definition. llvm::LogicalResult ConstantOp::verify() { … } llvm::LogicalResult StructConstantOp::verify() { … } /// Infer the output shape of the ConstantOp, this is required by the shape /// inference interface. void ConstantOp::inferShapes() { … } //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value lhs, mlir::Value rhs) { … } mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { … } void AddOp::print(mlir::OpAsmPrinter &p) { … } /// Infer the output shape of the AddOp, this is required by the shape inference /// interface. void AddOp::inferShapes() { … } //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// /// Infer the output shape of the CastOp, this is required by the shape /// inference interface. void CastOp::inferShapes() { … } /// Returns true if the given set of input and result types are compatible with /// this cast operation. This is required by the `CastOpInterface` to verify /// this operation and provide other additional utilities. bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, llvm::StringRef name, mlir::FunctionType type, llvm::ArrayRef<mlir::NamedAttribute> attrs) { … } mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { … } void FuncOp::print(mlir::OpAsmPrinter &p) { … } //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, StringRef callee, ArrayRef<mlir::Value> arguments) { … } /// Return the callee of the generic call operation, this is required by the /// call interface. CallInterfaceCallable GenericCallOp::getCallableForCallee() { … } /// Set the callee for the generic call operation, this is required by the call /// interface. void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { … } /// Get the argument operands to the called function, this is required by the /// call interface. Operation::operand_range GenericCallOp::getArgOperands() { … } /// Get the argument operands to the called function as a mutable range, this is /// required by the call interface. MutableOperandRange GenericCallOp::getArgOperandsMutable() { … } //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value lhs, mlir::Value rhs) { … } mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { … } void MulOp::print(mlir::OpAsmPrinter &p) { … } /// Infer the output shape of the MulOp, this is required by the shape inference /// interface. void MulOp::inferShapes() { … } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// llvm::LogicalResult ReturnOp::verify() { … } //===----------------------------------------------------------------------===// // StructAccessOp //===----------------------------------------------------------------------===// void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, mlir::Value input, size_t index) { … } llvm::LogicalResult StructAccessOp::verify() { … } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value value) { … } void TransposeOp::inferShapes() { … } llvm::LogicalResult TransposeOp::verify() { … } //===----------------------------------------------------------------------===// // Toy Types //===----------------------------------------------------------------------===// namespace mlir { namespace toy { namespace detail { /// This class represents the internal storage of the Toy `StructType`. struct StructTypeStorage : public mlir::TypeStorage { … }; } // namespace detail } // namespace toy } // namespace mlir /// Create an instance of a `StructType` with the given element types. There /// *must* be at least one element type. StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) { … } /// Returns the element types of this struct type. llvm::ArrayRef<mlir::Type> StructType::getElementTypes() { … } /// Parse an instance of a type registered to the toy dialect. mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { … } /// Print an instance of a type registered to the toy dialect. void ToyDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &printer) const { … } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "toy/Ops.cpp.inc" //===----------------------------------------------------------------------===// // ToyDialect //===----------------------------------------------------------------------===// /// Dialect initialization, the instance will be owned by the context. This is /// the point of registration of types and operations for the dialect. void ToyDialect::initialize() { … } mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, mlir::Attribute value, mlir::Type type, mlir::Location loc) { … }