llvm/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""YAML serialization is routed through here to centralize common logic."""

import sys

try:
    import yaml
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        f"This tool requires PyYAML but it was not installed. "
        f"Recommend: {sys.executable} -m pip install PyYAML"
    ) from e

__all__ = [
    "yaml_dump",
    "yaml_dump_all",
    "YAMLObject",
]


class YAMLObject(yaml.YAMLObject):
    @classmethod
    def to_yaml(cls, dumper, self):
        """Default to a custom dictionary mapping."""
        return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())

    def to_yaml_custom_dict(self):
        raise NotImplementedError()

    def as_linalg_yaml(self):
        return yaml_dump(self)


def multiline_str_representer(dumper, data):
    if len(data.splitlines()) > 1:
        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
    else:
        return dumper.represent_scalar("tag:yaml.org,2002:str", data)


yaml.add_representer(str, multiline_str_representer)


def yaml_dump(data, sort_keys=False, **kwargs):
    return yaml.dump(data, sort_keys=sort_keys, **kwargs)


def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
    return yaml.dump_all(
        data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
    )