chromium/third_party/wpt_tools/wpt/tools/wptrunner/wptrunner/update/tree.py

# mypy: allow-untyped-defs

import os
import re
import subprocess
import tempfile

from .. import vcs
from ..vcs import git, hg


def get_unique_name(existing, initial):
    """Get a name either equal to initial or of the form initial_N, for some
    integer N, that is not in the set existing.


    :param existing: Set of names that must not be chosen.
    :param initial: Name, or name prefix, to use"""
    if initial not in existing:
        return initial
    for i in range(len(existing) + 1):
        test = f"{initial}_{i + 1}"
        if test not in existing:
            return test
    assert False

class NoVCSTree:
    name = "non-vcs"

    def __init__(self, root=None):
        if root is None:
            root = os.path.abspath(os.curdir)
        self.root = root

    @classmethod
    def is_type(cls, path=None):
        return True

    @property
    def is_clean(self):
        return True

    def add_new(self, prefix=None):
        pass

    def add_ignored(self, sync_tree, prefix):
        pass

    def create_patch(self, patch_name, message):
        pass

    def update_patch(self, include=None):
        pass

    def commit_patch(self):
        pass


class HgTree:
    name = "mercurial"

    def __init__(self, root=None):
        if root is None:
            root = hg("root").strip()
        self.root = root
        self.hg = vcs.bind_to_repo(hg, self.root)

    def __getstate__(self):
        rv = self.__dict__.copy()
        del rv['hg']
        return rv

    def __setstate__(self, dict):
        self.__dict__.update(dict)
        self.hg = vcs.bind_to_repo(vcs.hg, self.root)

    @classmethod
    def is_type(cls, path=None):
        kwargs = {"log_error": False}
        if path is not None:
            kwargs["repo"] = path
        try:
            hg("root", **kwargs)
        except Exception:
            return False
        return True

    @property
    def is_clean(self):
        return self.hg("status").strip() == b""

    def add_new(self, prefix=None):
        if prefix is not None:
            args = ("-I", prefix)
        else:
            args = ()
        self.hg("add", *args)

    def add_ignored(self, sync_tree, prefix):
        pass

    def create_patch(self, patch_name, message):
        try:
            self.hg("qinit", log_error=False)
        except subprocess.CalledProcessError:
            pass

        patch_names = [item.strip() for item in self.hg("qseries").split(b"\n") if item.strip()]

        suffix = 0
        test_name = patch_name
        while test_name in patch_names:
            suffix += 1
            test_name = "%s-%i" % (patch_name, suffix)

        self.hg("qnew", test_name, "-X", self.root, "-m", message)

    def update_patch(self, include=None):
        if include is not None:
            args = []
            for item in include:
                args.extend(["-I", item])
        else:
            args = ()

        self.hg("qrefresh", *args)
        return True

    def commit_patch(self):
        self.hg("qfinish")

    def contains_commit(self, commit):
        try:
            self.hg("identify", "-r", commit.sha1)
            return True
        except subprocess.CalledProcessError:
            return False


class GitTree:
    name = "git"

    def __init__(self, root=None, log_error=True):
        if root is None:
            root = git("rev-parse", "--show-toplevel", log_error=log_error).strip().decode('utf-8')
        self.root = root
        self.git = vcs.bind_to_repo(git, self.root, log_error=log_error)
        self.message = None
        self.commit_cls = Commit

    def __getstate__(self):
        rv = self.__dict__.copy()
        del rv['git']
        return rv

    def __setstate__(self, dict):
        self.__dict__.update(dict)
        self.git = vcs.bind_to_repo(vcs.git, self.root)

    @classmethod
    def is_type(cls, path=None):
        kwargs = {"log_error": False}
        if path is not None:
            kwargs["repo"] = path
        try:
            git("rev-parse", "--show-toplevel", **kwargs)
        except Exception:
            return False
        return True

    @property
    def rev(self):
        """Current HEAD revision"""
        if vcs.is_git_root(self.root):
            return self.git("rev-parse", "HEAD").strip()
        else:
            return None

    @property
    def is_clean(self):
        return self.git("status").strip() == b""

    def add_new(self, prefix=None):
        """Add files to the staging area.

        :param prefix: None to include all files or a path prefix to
                       add all files under that path.
        """
        if prefix is None:
            args = ["-a"]
        else:
            args = ["--no-ignore-removal", prefix]
        self.git("add", *args)

    def add_ignored(self, sync_tree, prefix):
        """Add files to the staging area that are explicitly ignored by git.

        :param prefix: None to include all files or a path prefix to
                       add all files under that path.
        """
        with tempfile.TemporaryFile() as f:
            sync_tree.git("ls-tree", "-z", "-r", "--name-only", "HEAD", stdout=f)
            f.seek(0)
            ignored_files = sync_tree.git("check-ignore", "--no-index", "--stdin", "-z", stdin=f)
        args = []
        for entry in ignored_files.decode('utf-8').split('\0'):
            args.append(os.path.join(prefix, entry))
        if args:
            self.git("add", "--force", *args)

    def list_refs(self, ref_filter=None):
        """Get a list of sha1, name tuples for references in a repository.

        :param ref_filter: Pattern that reference name must match (from the end,
                           matching whole /-delimited segments only
        """
        args = []
        if ref_filter is not None:
            args.append(ref_filter)
        data = self.git("show-ref", *args)
        rv = []
        for line in data.split(b"\n"):
            if not line.strip():
                continue
            sha1, ref = line.split()
            rv.append((sha1, ref))
        return rv

    def list_remote(self, remote, ref_filter=None):
        """Return a list of (sha1, name) tupes for references in a remote.

        :param remote: URL of the remote to list.
        :param ref_filter: Pattern that the reference name must match.
        """
        args = []
        if ref_filter is not None:
            args.append(ref_filter)
        data = self.git("ls-remote", remote, *args)
        rv = []
        for line in data.split(b"\n"):
            if not line.strip():
                continue
            sha1, ref = line.split()
            rv.append((sha1, ref))
        return rv

    def get_remote_sha1(self, remote, branch):
        """Return the SHA1 of a particular branch in a remote.

        :param remote: the remote URL
        :param branch: the branch name"""
        for sha1, ref in self.list_remote(remote, branch):
            if ref.decode('utf-8') == "refs/heads/%s" % branch:
                return self.commit_cls(self, sha1.decode('utf-8'))
        assert False

    def create_patch(self, patch_name, message):
        # In git a patch is actually a commit
        self.message = message

    def update_patch(self, include=None):
        """Commit the staged changes, or changes to listed files.

        :param include: Either None, to commit staged changes, or a list
                        of filenames (which must already be in the repo)
                        to commit
        """
        if include is not None:
            args = tuple(include)
        else:
            args = ()

        if self.git("status", "-uno", "-z", *args).strip():
            self.git("add", *args)
            return True
        return False

    def commit_patch(self):
        assert self.message is not None

        if self.git("diff", "--name-only", "--staged", "-z").strip():
            self.git("commit", "-m", self.message)
            return True

        return False

    def init(self):
        self.git("init")
        assert vcs.is_git_root(self.root)

    def checkout(self, rev, branch=None, force=False):
        """Checkout a particular revision, optionally into a named branch.

        :param rev: Revision identifier (e.g. SHA1) to checkout
        :param branch: Branch name to use
        :param force: Force-checkout
        """
        assert rev is not None

        args = []
        if branch:
            branches = [ref[len("refs/heads/"):].decode('utf-8') for sha1, ref in self.list_refs()
                        if ref.startswith(b"refs/heads/")]
            branch = get_unique_name(branches, branch)

            args += ["-b", branch]

        if force:
            args.append("-f")
        args.append(rev)
        self.git("checkout", *args)

    def update(self, remote, remote_branch, local_branch):
        """Fetch from the remote and checkout into a local branch.

        :param remote: URL to the remote repository
        :param remote_branch: Branch on the remote repository to check out
        :param local_branch: Local branch name to check out into
        """
        if not vcs.is_git_root(self.root):
            self.init()
        self.git("clean", "-xdf")
        self.git("fetch", remote, f"{remote_branch}:{local_branch}")
        self.checkout(local_branch)
        self.git("submodule", "update", "--init", "--recursive")

    def clean(self):
        self.git("checkout", self.rev)
        self.git("branch", "-D", self.local_branch)

    def paths(self):
        """List paths in the tree"""
        repo_paths = [self.root] + [os.path.join(self.root, path)
                                    for path in self.submodules()]

        rv = []

        for repo_path in repo_paths:
            paths = vcs.git("ls-tree", "-r", "--name-only", "HEAD", repo=repo_path).split(b"\n")
            rv.extend(os.path.relpath(os.path.join(repo_path, item.decode('utf-8')), self.root) for item in paths
                      if item.strip())
        return rv

    def submodules(self):
        """List submodule directories"""
        output = self.git("submodule", "status", "--recursive")
        rv = []
        for line in output.split(b"\n"):
            line = line.strip()
            if not line:
                continue
            parts = line.split(b" ")
            rv.append(parts[1])
        return rv

    def contains_commit(self, commit):
        try:
            self.git("rev-parse", "--verify", commit.sha1)
            return True
        except subprocess.CalledProcessError:
            return False


class CommitMessage:
    def __init__(self, text):
        self.text = text
        self._parse_message()

    def __str__(self):
        return self.text

    def _parse_message(self):
        lines = self.text.splitlines()
        self.full_summary = lines[0]
        self.body = "\n".join(lines[1:])


class Commit:
    msg_cls = CommitMessage

    _sha1_re = re.compile("^[0-9a-f]{40}$")

    def __init__(self, tree, sha1):
        """Object representing a commit in a specific GitTree.

        :param tree: GitTree to which this commit belongs.
        :param sha1: Full sha1 string for the commit
        """
        assert self._sha1_re.match(sha1)

        self.tree = tree
        self.git = tree.git
        self.sha1 = sha1
        self.author, self.email, self.message = self._get_meta()

    def __getstate__(self):
        rv = self.__dict__.copy()
        del rv['git']
        return rv

    def __setstate__(self, dict):
        self.__dict__.update(dict)
        self.git = self.tree.git

    def _get_meta(self):
        author, email, message = self.git("show", "-s", "--format=format:%an\n%ae\n%B", self.sha1).decode('utf-8').split("\n", 2)
        return author, email, self.msg_cls(message)