//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements an ODS (and C++) generator from a YAML form // derived from the mathematical expression of linalg named ops. Typically a // math oriented DSL will be used to export the essential representation to // this form, and maintaining the SOT at the math level (versus recreating it // in MLIR) is deemed to have systemic value. // //===----------------------------------------------------------------------===// #include "mlir/AsmParser/AsmParser.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/YAMLTraits.h" #include <optional> usingnamespacemlir; Input; MappingTraits; ScalarEnumerationTraits; ScalarTraits; #define DEBUG_TYPE … //===----------------------------------------------------------------------===// // Mapping structs (correspond to data types in the YAML description). // TODO: Since this is a schema/part of the contract, it should be moved to // a real header. 
//===----------------------------------------------------------------------===// namespace { struct LinalgYAMLContext { … }; struct LinalgOpMetadata { … }; struct SerializedAffineMap { … }; enum class LinalgOperandDefKind { … }; struct LinalgOperandDef { … }; enum class LinalgIteratorTypeDef { … }; struct LinalgIndexingMapsConfig { … }; struct ScalarExpression; enum class ScalarFnKind { … }; struct ScalarFn { … }; struct ScalarExpression { … }; struct ScalarAssign { … }; struct LinalgStructuredOpConfig { … }; struct LinalgOpConfig { … }; } // namespace //===----------------------------------------------------------------------===// // Mapping traits. //===----------------------------------------------------------------------===// LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef) LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap) LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef) LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign) LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression) LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(…) namespace llvm { namespace yaml { /// Top-level type containing op metadata and one of a concrete op type. /// Currently, the only defined op type is `structured_op` (maps to /// `LinalgStructuredOpConfig`). template <> struct MappingTraits<LinalgOpConfig> { … }; /// A structured op models (at most) a single contraction by modeling /// - A list of named arguments (`LinalgOperandDef`), which can be inputs, /// outputs, or index attributes. /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). template <> struct MappingTraits<LinalgStructuredOpConfig> { … }; /// Maps a named tensor, scalar or attribute argument to an operation, /// consisting of: /// - `name`: Must be unique within the operation. /// - `usage`: How the argument is used (input, output, attribute, etc). 
/// - `type_var`: The symbolic type variable that binds to the element or self /// type of the tensor or scalar argument, respectively. /// - `shape_map`: An optional AffineMap from all op symbols to the shape of /// the argument. Only tensor arguments have a `shape_map`. Each shape must /// be normalized over the same list of symbols and have no dimension /// inputs. /// - `index_attr_map`: An optional AffineMap from all op symbols to the /// index attribute symbols. During op creation these symbols are replaced /// by the corresponding `name` index attribute values. Only index attribute /// arguments have an `index_attr_map`. /// - `default_indices`: An optional default initialization for index /// attribute arguments. /// - `default_fn`: An optional default initialization for function attribute /// arguments. template <> struct MappingTraits<LinalgOperandDef> { … }; /// Usage enum for a named argument. template <> struct ScalarEnumerationTraits<LinalgOperandDefKind> { … }; /// Iterator type enum. template <> struct ScalarEnumerationTraits<LinalgIteratorTypeDef> { … }; /// Metadata about the op (name, C++ name, and documentation). template <> struct MappingTraits<LinalgOpMetadata> { … }; /// How the op's indexing maps are produced. Must be one of: /// - static_indexing_maps: A static list of AffineMaps, possibly with /// some symbols that bind to attributes of the op. Each indexing map must /// be normalized over the same list of dimensions, and its symbols must /// match the symbols for argument shapes. template <> struct MappingTraits<LinalgIndexingMapsConfig> { … }; /// Models an assignment to a named output. /// - The `arg` name must match a named output. /// - The `value` is a scalar expression for computing the value to /// assign (see `ScalarExpression`). template <> struct MappingTraits<ScalarAssign> { … }; /// A scalar expression (RHS of an assignment). Must be one of: /// - `scalar_arg`: An operation argument. /// - `scalar_const`: A constant definition. 
/// - `scalar_index`: An iteration index. /// - `scalar_fn`: A named function (see `ScalarFn`). template <> struct MappingTraits<ScalarExpression> { … }; /// Scalar function kind enum. template <> struct ScalarEnumerationTraits<ScalarFnKind> { … }; /// A scalar expression that evaluates a named function. /// Functions are generally "math" level and type polymorphic. Builtin /// functions include: /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` template <> struct MappingTraits<ScalarFn> { … }; /// Helper mapping which accesses an AffineMapAttr as a serialized string of /// the same. template <> struct ScalarTraits<SerializedAffineMap> { … }; } // namespace yaml } // namespace llvm namespace { //===----------------------------------------------------------------------===// // Generation utilities //===----------------------------------------------------------------------===// class GenerationContext { … }; } // namespace static std::string generateCppExpression(SerializedAffineMap self, StringRef contextName) { … } template <typename Container> static std::string interleaveToString(Container &container, StringRef separator) { … } static std::optional<int> findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) { … } // Try to map the TypeVar to a predefined or an argument type. static std::optional<std::string> findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) { … } static ScalarAssign *findAssignment(StringRef name, std::vector<ScalarAssign> &assignments) { … } // Return true if the operand is a function attribute. static bool isFunctionAttribute(LinalgOperandDefKind kind) { … } // Return true if the operand is an attribute. static bool isAttribute(LinalgOperandDefKind kind) { … } // Get the enum name for the given operand kind. std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) { … } // Get the enum name for the given function kind. 
std::string convertFunctionKindToEnumName(ScalarFnKind kind) { … } //===----------------------------------------------------------------------===// // Templates //===----------------------------------------------------------------------===// // A single line banner format. Parameters: // {0}: Single line comment static const char bannerFormat[] = …; //===----------------------------------------------------------------------===// // Named generic op generation. // These ops map at most a single contraction that complies with the limitations // of a linalg.generic. //===----------------------------------------------------------------------===// // Template for Linalg named ops' ODS definitions. Parameters: // {0}: ODS/C++ op name // {1}: assembly op mnemonic // {2}: op interface list // {3}: documentation (summary + description) // {4}: op attribute list // {5}: builder methods taking standalone attribute parameters // {6}: additional method definitions // {7}: additional methods for attributes used by indexing maps static const char structuredOpOdsHeaderFormat[] = …; // Builder method taking attribute parameters. Parameters: // {0}: Class name // {1}: Comma interleaved attribute parameters // {2}: Attribute initialization static const char structuredOpBuilderFormat[] = …; // The getIteratorTypesArray() method for structured ops. Parameters: // {0}: Class name // {1}: Comma interleaved iterator type names. static const char structuredOpIteratorTypesFormat[] = …; // The getIteratorTypesArray() method for rank polymorphic structured ops. // Parameters: // {0}: Class name static const char rankPolyStructuredOpIteratorTypesFormat[] = …; // The indexing_maps() method for structured ops. Parameters: // {0}: Class name // {1}: Comma-separated list of dimension variable names. // {2}: Statements static const char structuredOpIndexingMapsFormat[] = …; // The indexing_maps() method for rank polymorphic structured ops. 
Parameters: // {0}: Class name static const char rankPolyStructuredOpIndexingMapsFormat[] = …; // Implementations of fold, getEffects and getSpeculatability. // Parameters: // {0}: Class name const char structuredOpFoldersFormat[] = …; // Implementation of parse/print. // Parameters: // {0}: Class name static const char structuredOpParserFormat[] = …; static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, GenerationContext &genContext) { … } static LogicalResult generateNamedGenericOpDefns(LinalgOpConfig &opConfig, GenerationContext &genContext) { … } static LogicalResult generateOp(LinalgOpConfig &opConfig, GenerationContext &genContext) { … } //===----------------------------------------------------------------------===// // Command line options and main //===----------------------------------------------------------------------===// static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), llvm::cl::value_desc("YAML filename")); static llvm::cl::opt<std::string> outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("")); static llvm::cl::opt<std::string> outputCppImplFilename("o-impl", llvm::cl::desc("C++ implementation file name"), llvm::cl::value_desc("filename"), llvm::cl::init("")); int main(int argc, char **argv) { … }