kubernetes/vendor/github.com/checkpoint-restore/go-criu/v5/main.go

package criu

import (
	"errors"
	"fmt"
	"os"
	"os/exec"
	"strconv"
	"syscall"

	"github.com/checkpoint-restore/go-criu/v5/rpc"
	"google.golang.org/protobuf/proto"
)

// Criu holds the state needed to talk to a CRIU process that was started
// in "swrk" mode: the child command, the local end of the socket used for
// the RPC exchange, and the path of the binary to execute.
type Criu struct {
	// swrkCmd is the started CRIU command. It is set by Prepare and reset
	// to nil by Cleanup; nil means no swrk process is currently tracked.
	swrkCmd  *exec.Cmd
	// swrkSk is the client end of the socketpair created in Prepare.
	// Requests are written to it and responses are read from it.
	swrkSk   *os.File
	// swrkPath is the binary that Prepare executes. MakeCriu initialises
	// it to "criu"; SetCriuPath overrides it.
	swrkPath string
}

// MakeCriu returns a new Criu object, the receiver for most operations in
// this package. The binary path is preset to "criu"; use SetCriuPath to
// point it somewhere else. No process is started until a request is made
// or Prepare is called.
func MakeCriu() *Criu {
	const defaultBinary = "criu"

	c := new(Criu)
	c.swrkPath = defaultBinary
	return c
}

// SetCriuPath sets the path of the CRIU binary that Prepare will execute.
// Use it when the binary is in a non-standard location. The new path only
// affects processes started after the call; the value is stored as given
// and is not validated here.
func (c *Criu) SetCriuPath(path string) {
	c.swrkPath = path
}

// Prepare sets up everything for the RPC communication to CRIU.
//
// It creates an AF_LOCAL/SOCK_SEQPACKET socketpair and starts the configured
// binary as "<path> swrk <fd>", where <fd> is the descriptor number of the
// server end of the pair. Only the client end is marked close-on-exec. The
// parent's copy of the server end is closed when Prepare returns, whether
// or not the start succeeded.
//
// On success the started command and the client end are stored on c for
// later requests and for Cleanup. On failure the error from Socketpair or
// from starting the command is returned and c is left unchanged.
func (c *Criu) Prepare() error {
	pair, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_SEQPACKET, 0)
	if err != nil {
		return err
	}
	clientFd, serverFd := pair[0], pair[1]

	client := os.NewFile(uintptr(clientFd), "criu-xprt-cln")
	// The client end is for this process only.
	syscall.CloseOnExec(clientFd)
	server := os.NewFile(uintptr(serverFd), "criu-xprt-srv")
	defer server.Close()

	// #nosec G204
	worker := exec.Command(c.swrkPath, "swrk", strconv.Itoa(serverFd))
	if err := worker.Start(); err != nil {
		client.Close()
		return err
	}

	c.swrkCmd, c.swrkSk = worker, client
	return nil
}

// Cleanup tears down the connection created by Prepare: it closes the
// client socket, waits for the swrk command to exit, and clears both
// fields so that a later request starts a fresh process. It does nothing
// when no command is tracked. Errors from closing the socket and from
// waiting on the command are discarded.
func (c *Criu) Cleanup() {
	if c.swrkCmd == nil {
		return
	}

	c.swrkSk.Close()
	c.swrkSk = nil

	// Best effort: the exit status of the swrk process is not reported.
	_ = c.swrkCmd.Wait()
	c.swrkCmd = nil
}

// sendAndRecv writes one serialized request to the swrk socket and then
// performs a single read for the response.
//
// It returns the receive buffer (always 2*4096 bytes long), the number of
// bytes that the read filled in, and any error from the write or the read.
// Callers must slice the buffer to the returned length before decoding.
func (c *Criu) sendAndRecv(reqB []byte) ([]byte, int, error) {
	sock := c.swrkSk

	if _, err := sock.Write(reqB); err != nil {
		return nil, 0, err
	}

	buf := make([]byte, 2*4096)
	got, err := sock.Read(buf)
	if err != nil {
		return nil, 0, err
	}

	return buf, got, nil
}

// doSwrk runs a request of the given type through doSwrkWithResp (without
// a features message) and discards the response body. It returns the
// error from the exchange, or an "unexpected CRIU RPC response" error when
// the final response has a different type than the request.
func (c *Criu) doSwrk(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy Notify) error {
	resp, err := c.doSwrkWithResp(reqType, opts, nfy, nil)

	switch {
	case err != nil:
		return err
	case resp.GetType() != reqType:
		return errors.New("unexpected CRIU RPC response")
	default:
		return nil
	}
}

// doSwrkWithResp sends a request of type reqType to the CRIU swrk process
// and returns the final response.
//
// opts is attached to the request as is and may be nil. When nfy is
// non-nil, NotifyScripts is enabled on the options; if opts is nil in that
// case an empty options message is allocated for the request so the flag
// can be set without dereferencing a nil pointer. A caller-supplied opts
// is modified in place. features, when non-nil, is attached to the request.
//
// If no swrk process is tracked yet, one is started with Prepare and torn
// down with Cleanup before returning. A process set up earlier by the
// caller is left running.
//
// Responses of type NOTIFY are dispatched to the matching nfy callback
// (unknown script names are ignored) and acknowledged with NotifySuccess,
// after which the next response is read. The loop ends with the first
// non-NOTIFY response.
//
// An error is returned when marshalling, the socket exchange or
// unmarshalling fails (response is nil), when CRIU reports failure, when a
// NOTIFY arrives although nfy is nil, or when a callback returns an error
// (the offending response is returned alongside the error in these cases).
func (c *Criu) doSwrkWithResp(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy Notify, features *rpc.CriuFeatures) (*rpc.CriuResp, error) {
	var resp *rpc.CriuResp

	if nfy != nil {
		// Notifications are requested through the options message, so make
		// sure there is one to carry the flag.
		if opts == nil {
			opts = &rpc.CriuOpts{}
		}
		opts.NotifyScripts = proto.Bool(true)
	}

	req := rpc.CriuReq{
		Type: &reqType,
		Opts: opts,
	}

	if features != nil {
		req.Features = features
	}

	if c.swrkCmd == nil {
		err := c.Prepare()
		if err != nil {
			return nil, err
		}

		// Only clean up a process that was started for this call.
		defer c.Cleanup()
	}

	for {
		reqB, err := proto.Marshal(&req)
		if err != nil {
			return nil, err
		}

		respB, respS, err := c.sendAndRecv(reqB)
		if err != nil {
			return nil, err
		}

		resp = &rpc.CriuResp{}
		// Only the first respS bytes of the buffer hold the message.
		err = proto.Unmarshal(respB[:respS], resp)
		if err != nil {
			return nil, err
		}

		if !resp.GetSuccess() {
			return resp, fmt.Errorf("operation failed (msg:%s err:%d)",
				resp.GetCrErrmsg(), resp.GetCrErrno())
		}

		respType := resp.GetType()
		if respType != rpc.CriuReqType_NOTIFY {
			break
		}
		if nfy == nil {
			return resp, errors.New("unexpected notify")
		}

		notify := resp.GetNotify()
		switch notify.GetScript() {
		case "pre-dump":
			err = nfy.PreDump()
		case "post-dump":
			err = nfy.PostDump()
		case "pre-restore":
			err = nfy.PreRestore()
		case "post-restore":
			err = nfy.PostRestore(notify.GetPid())
		case "network-lock":
			err = nfy.NetworkLock()
		case "network-unlock":
			err = nfy.NetworkUnlock()
		case "setup-namespaces":
			err = nfy.SetupNamespaces(notify.GetPid())
		case "post-setup-namespaces":
			err = nfy.PostSetupNamespaces()
		case "post-resume":
			err = nfy.PostResume()
		default:
			// Script names without a callback are acknowledged as-is.
			err = nil
		}

		if err != nil {
			return resp, err
		}

		// Acknowledge the notification; the next iteration sends this
		// message and reads the following response.
		req = rpc.CriuReq{
			Type:          &respType,
			NotifySuccess: proto.Bool(true),
		}
	}

	return resp, nil
}

// Dump dumps a process by sending a DUMP request with the given options.
// If nfy is non-nil, notifications are enabled on opts and the matching
// nfy callbacks are invoked while the request is in progress. An error is
// returned when the exchange fails or the final response is not of type
// DUMP.
func (c *Criu) Dump(opts *rpc.CriuOpts, nfy Notify) error {
	return c.doSwrk(rpc.CriuReqType_DUMP, opts, nfy)
}

// Restore restores a process by sending a RESTORE request with the given
// options. If nfy is non-nil, notifications are enabled on opts and the
// matching nfy callbacks are invoked while the request is in progress. An
// error is returned when the exchange fails or the final response is not
// of type RESTORE.
func (c *Criu) Restore(opts *rpc.CriuOpts, nfy Notify) error {
	return c.doSwrk(rpc.CriuReqType_RESTORE, opts, nfy)
}

// PreDump does a pre-dump by sending a PRE_DUMP request with the given
// options. If nfy is non-nil, notifications are enabled on opts and the
// matching nfy callbacks are invoked while the request is in progress. An
// error is returned when the exchange fails or the final response is not
// of type PRE_DUMP.
func (c *Criu) PreDump(opts *rpc.CriuOpts, nfy Notify) error {
	return c.doSwrk(rpc.CriuReqType_PRE_DUMP, opts, nfy)
}

// StartPageServer starts the page server by sending a PAGE_SERVER request
// with the given options. No notification callbacks are used. An error is
// returned when the exchange fails or the final response is not of type
// PAGE_SERVER.
func (c *Criu) StartPageServer(opts *rpc.CriuOpts) error {
	return c.doSwrk(rpc.CriuReqType_PAGE_SERVER, opts, nil)
}

// StartPageServerChld starts the page server with a PAGE_SERVER_CHLD
// request and returns, in this order, the PID and the port reported in the
// page-server section of the response. No notification callbacks are used.
// On error both numbers are zero.
func (c *Criu) StartPageServerChld(opts *rpc.CriuOpts) (int, int, error) {
	resp, err := c.doSwrkWithResp(rpc.CriuReqType_PAGE_SERVER_CHLD, opts, nil, nil)
	if err != nil {
		return 0, 0, err
	}

	info := resp.Ps
	pid, port := int(info.GetPid()), int(info.GetPort())

	return pid, port, nil
}

// GetCriuVersion executes the VERSION RPC call and returns the version
// as an integer. Major * 10000 + Minor * 100 + SubLevel
//
// When the response carries a Gitid the sublevel is dropped and the minor
// number is increased by one, so such a build compares as newer than the
// release it is based on.
//
// An error is returned when the RPC exchange fails, when the response is
// not of type VERSION, or when the response contains no version message.
func (c *Criu) GetCriuVersion() (int, error) {
	resp, err := c.doSwrkWithResp(rpc.CriuReqType_VERSION, nil, nil, nil)
	if err != nil {
		return 0, err
	}

	if resp.GetType() != rpc.CriuReqType_VERSION {
		return 0, fmt.Errorf("Unexpected CRIU RPC response")
	}

	// Guard against a response without a version message instead of
	// dereferencing a nil pointer.
	ver := resp.GetVersion()
	if ver == nil {
		return 0, errors.New("CRIU RPC response does not contain version information")
	}

	// The generated getters return zero for unset fields, so an absent
	// sublevel simply contributes nothing.
	version := int(ver.GetMajorNumber())*10000 +
		int(ver.GetMinorNumber())*100 +
		int(ver.GetSublevel())

	if ver.Gitid != nil {
		// taken from runc: if it is a git release -> increase minor by 1
		version -= (version % 100)
		version += 100
	}

	return version, nil
}

// IsCriuAtLeast reports whether the version returned by GetCriuVersion is
// greater than or equal to the given version, which uses the same integer
// encoding (Major * 10000 + Minor * 100 + SubLevel). If the version cannot
// be determined, it returns false together with the error.
func (c *Criu) IsCriuAtLeast(version int) (bool, error) {
	current, err := c.GetCriuVersion()
	if err != nil {
		return false, err
	}

	return current >= version, nil
}