kubernetes/cluster/images/etcd/migrate/integration_test.go

//go:build integration
// +build integration

/*
Copyright 2018 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
	"bytes"
	cryptorand "crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"flag"
	"fmt"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/blang/semver/v4"
	"k8s.io/klog/v2"
	netutils "k8s.io/utils/net"
)

// testSupportedVersions is the set of etcd versions the migration tests treat
// as supported.
var testSupportedVersions = mustParseSupportedVersions([]string{"3.0.17", "3.1.12"})

// testVersionPrevious and testVersionLatest are the older and newer members of
// testSupportedVersions, respectively.
var (
	testVersionPrevious = &EtcdVersion{semver.MustParse("3.0.17")}
	testVersionLatest   = &EtcdVersion{semver.MustParse("3.1.12")}
)

// init enables klog output for the code under test. It registers klog's flags
// via klog.InitFlags(nil), then sets logtostderr=true and v=9.
// A failure to set either flag indicates a broken setup, so init panics rather
// than silently running with the default logging configuration.
func init() {
	// Enable klog which is used in dependencies.
	klog.InitFlags(nil)
	if err := flag.Set("logtostderr", "true"); err != nil {
		panic(fmt.Errorf("setting klog flag logtostderr: %w", err))
	}
	if err := flag.Set("v", "9"); err != nil {
		panic(fmt.Errorf("setting klog flag v: %w", err))
	}
}

// TestMigrate exercises each migration scenario end to end.
//
// For every scenario it:
//   - starts a cluster at the start version;
//   - writes and reads back one key per member;
//   - stops and migrates the members one at a time to the end version;
//   - checks that every key can be read from every member;
//   - stops the cluster;
//   - verifies version.txt and the on-disk permissions.
//
// Callbacks passed to parallel run on goroutines other than the test goroutine.
// The testing package forbids t.Fatal/t.FailNow there, so those callbacks report
// problems with t.Errorf and return instead.
func TestMigrate(t *testing.T) {
	migrations := []struct {
		title            string
		memberCount      int
		startVersion     string
		endVersion       string
		protocol         string
		clientListenUrls string
	}{
		// upgrades
		{"v3-v3-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https", ""},
		{"oldest-newest-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https", ""},
		{"v3-v3-up-with-additional-client-url", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https", "http://127.0.0.1:2379,http://10.128.0.1:2379"},

		// warning: v2->v3 ha upgrades not currently supported.
		{"ha-v3-v3-up", 3, "3.0.17/etcd3", "3.1.12/etcd3", "https", ""},

		// downgrades
		{"v3-v3-down", 1, "3.1.12/etcd3", "3.0.17/etcd3", "https", ""},

		// warning: ha downgrades not yet supported.
	}

	for _, m := range migrations {
		t.Run(m.title, func(t *testing.T) {
			start := mustParseEtcdVersionPair(m.startVersion)
			end := mustParseEtcdVersionPair(m.endVersion)

			testCfgs := clusterConfig(t, m.title, m.memberCount, m.protocol, m.clientListenUrls)

			servers := make([]*EtcdMigrateServer, 0, len(testCfgs))
			for _, cfg := range testCfgs {
				client, err := NewEtcdMigrateClient(cfg)
				if err != nil {
					t.Fatalf("Failed to create client: %v", err)
				}
				servers = append(servers, NewEtcdMigrateServer(cfg, client))
			}

			// Start the servers.
			parallel(servers, func(server *EtcdMigrateServer) {
				dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
				if err != nil {
					t.Errorf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
					return
				}
				migrator := &Migrator{server.cfg, dataDir, server.client}
				if err := migrator.MigrateIfNeeded(start); err != nil {
					t.Errorf("Migration failed: %v", err)
					return
				}
				if err := server.Start(start.version); err != nil {
					t.Errorf("Failed to start server: %v", err)
				}
			})

			// Write a value to each server, read it back.
			parallel(servers, func(server *EtcdMigrateServer) {
				key := fmt.Sprintf("/registry/%s", server.cfg.name)
				value := fmt.Sprintf("value-%s", server.cfg.name)
				if err := server.client.Put(start.version, key, value); err != nil {
					t.Errorf("failed to write text value: %v", err)
					return
				}

				checkVal, err := server.client.Get(start.version, key)
				if err != nil {
					t.Errorf("Error getting %s for validation: %v", key, err)
				}
				if checkVal != value {
					t.Errorf("Expected %s from %s but got %s", value, key, checkVal)
				}
			})

			// Migrate the servers in series. serial runs on the test goroutine,
			// so t.Fatalf is valid here.
			serial(servers, func(server *EtcdMigrateServer) {
				err := server.Stop()
				if err != nil {
					t.Fatalf("Stop server failed: %v", err)
				}
				dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
				if err != nil {
					t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
				}
				migrator := &Migrator{server.cfg, dataDir, server.client}
				err = migrator.MigrateIfNeeded(end)
				if err != nil {
					t.Fatalf("Migration failed: %v", err)
				}
				err = server.Start(end.version)
				if err != nil {
					t.Fatalf("Start server failed: %v", err)
				}
			})

			// Check that all test values can be read back from all the servers.
			parallel(servers, func(server *EtcdMigrateServer) {
				for _, s := range servers {
					key := fmt.Sprintf("/registry/%s", s.cfg.name)
					value := fmt.Sprintf("value-%s", s.cfg.name)
					checkVal, err := server.client.Get(end.version, key)
					if err != nil {
						t.Errorf("Error getting %s after migrating from %s to %s: %v", key, start, end, err)
					}
					if checkVal != value {
						t.Errorf("Expected %s from %s but got %s when reading after migrating from %s to %s", value, key, checkVal, start, end)
					}
				}
			})

			// Stop the servers.
			parallel(servers, func(server *EtcdMigrateServer) {
				if err := server.Stop(); err != nil {
					t.Errorf("Failed to stop server: %v", err)
				}
			})

			// Check that version.txt contains the correct end version. This runs
			// on the test goroutine, so failures here may use t.Fatalf.
			serial(servers, func(server *EtcdMigrateServer) {
				dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
				if err != nil {
					t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
				}
				v, err := dataDir.versionFile.Read()
				if err != nil {
					t.Fatalf("Failed to read version.txt file: %v", err)
				}
				if !v.Equals(end) {
					t.Errorf("Expected version.txt to contain %s but got %s", end, v)
				}
				// Integration tests are run in a docker container with umask of 0022.
				checkPermissions(t, server.cfg.dataDirectory, 0755|os.ModeDir)
				checkPermissions(t, dataDir.versionFile.path, 0644)
			})
		})
	}
}

// parallel calls fn once for every server, each call on its own goroutine, and
// blocks until all of the calls have returned. Because fn does not run on the
// test goroutine, it must not call t.Fatal/t.FailNow.
func parallel(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
	var pending sync.WaitGroup
	for i := range servers {
		srv := servers[i] // per-iteration copy for the goroutine below
		pending.Add(1)
		go func() {
			defer pending.Done()
			fn(srv)
		}()
	}
	pending.Wait()
}

// serial calls fn for each server in slice order on the calling goroutine,
// starting the next call only after the previous one has returned.
func serial(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
	for i := 0; i < len(servers); i++ {
		fn(servers[i])
	}
}

// checkPermissions reports a test error when the mode of path differs from
// expected. The mode comes from os.Stat, so symlinks are followed. Failures
// are reported with t.Errorf rather than t.Fatalf, so the helper is safe to call
// from goroutines other than the one running the test. If path cannot be
// stat'ed, the error is reported and no mode comparison is made.
func checkPermissions(t *testing.T, path string, expected os.FileMode) {
	t.Helper()
	info, err := os.Stat(path)
	if err != nil {
		t.Errorf("Failed to stat file %s: %v", path, err)
		return
	}
	if got := info.Mode(); got != expected {
		t.Errorf("Expected permissions for file %s of %s, but got %s", path, expected, got)
	}
}

// clusterConfig returns one EtcdMigrateCfg per member of a test cluster.
//
// Member i of the cluster:
//   - is named "<name>-<i>";
//   - uses client port 2379+10000*i;
//   - uses peer URL <protocol>://127.0.0.1:<2380+10000*i>;
//   - stores its data in /tmp/etcd-data-dir-<member>.
//
// Every member shares the same initial-cluster string. When protocol is "https",
// every member also gets the peer TLS flags from getOrCreateTLSPeerCertArgs as
// extra server arguments.
func clusterConfig(t *testing.T, name string, memberCount int, protocol string, clientListenUrls string) []*EtcdMigrateCfg {
	// First pass: compute each member's name and peer URL exactly once.
	var memberNames, peerURLs, clusterEntries []string
	for i := 0; i < memberCount; i++ {
		member := fmt.Sprintf("%s-%d", name, i)
		url := fmt.Sprintf("%s://127.0.0.1:%d", protocol, uint64(2380+i*10000))
		memberNames = append(memberNames, member)
		peerURLs = append(peerURLs, url)
		clusterEntries = append(clusterEntries, member+"="+url)
	}
	initialCluster := strings.Join(clusterEntries, ",")

	var extraArgs string
	if protocol == "https" {
		extraArgs = getOrCreateTLSPeerCertArgs(t)
	}

	// Second pass: build the per-member configs.
	cfgs := make([]*EtcdMigrateCfg, 0, len(memberNames))
	for i, member := range memberNames {
		cfgs = append(cfgs, &EtcdMigrateCfg{
			binPath:           "/usr/local/bin",
			name:              member,
			initialCluster:    initialCluster,
			port:              uint64(2379 + i*10000),
			peerListenUrls:    peerURLs[i],
			peerAdvertiseUrls: peerURLs[i],
			clientListenUrls:  clientListenUrls,
			etcdDataPrefix:    "/registry",
			ttlKeysDirectory:  "/registry/events",
			supportedVersions: testSupportedVersions,
			dataDirectory:     fmt.Sprintf("/tmp/etcd-data-dir-%s", member),
			etcdServerArgs:    extraArgs,
		})
	}
	return cfgs
}

// getOrCreateTLSPeerCertArgs makes sure a test certificate/key pair for
// localhost and 127.0.0.1 exists at /tmp/certs/test.crt and /tmp/certs/test.key.
// It returns the etcd flags that turn on peer client-certificate auth with that
// pair, using the certificate both as the trusted CA and as the peer
// certificate. The test is aborted if the files cannot be created.
func getOrCreateTLSPeerCertArgs(t *testing.T) string {
	const certDir = "/tmp/certs"
	var (
		certFile = filepath.Join(certDir, "test.crt")
		keyFile  = filepath.Join(certDir, "test.key")
	)

	spec := TestCertSpec{host: "localhost", ips: []string{"127.0.0.1"}}
	if err := getOrCreateTestCertFiles(certFile, keyFile, spec); err != nil {
		t.Fatalf("failed to create server cert: %v", err)
	}

	flags := []string{
		"--peer-client-cert-auth",
		"--peer-trusted-ca-file=" + certFile,
		"--peer-cert-file=" + certFile,
		"--peer-key-file=" + keyFile,
	}
	return strings.Join(flags, " ")
}

// TestCertSpec describes the names a generated self-signed test certificate is
// issued for.
type TestCertSpec struct {
	// host is the primary name. It becomes an IP SAN when it parses as an IP
	// address and a DNS SAN otherwise.
	host string
	// names are additional DNS subject alternative names.
	names []string
	// ips are additional IP subject alternative names, in textual form.
	ips []string
}

// getOrCreateTestCertFiles makes sure a PEM certificate exists at certFileName
// and a PEM private key exists at keyFileName.
//
// If both files already exist they are reused as-is; their contents are not
// checked against spec. Otherwise a new self-signed pair is generated for spec
// and both files are (re)written, with missing parent directories created
// (mode 0755). Newly created files get mode 0644 for the certificate and 0600
// for the private key, so the key is not world-readable. os.WriteFile does not
// change the mode of a file that already exists.
//
// Errors from directory creation, generation, or writing are returned.
func getOrCreateTestCertFiles(certFileName, keyFileName string, spec TestCertSpec) error {
	_, certErr := os.Stat(certFileName)
	_, keyErr := os.Stat(keyFileName)
	if certErr == nil && keyErr == nil {
		return nil
	}

	certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names)
	if err != nil {
		return fmt.Errorf("generating self-signed certificate for %q: %w", spec.host, err)
	}

	outputs := []struct {
		path string
		data []byte
		perm os.FileMode
	}{
		{certFileName, certPem, 0644},
		{keyFileName, keyPem, 0600},
	}
	for _, out := range outputs {
		if err := os.MkdirAll(filepath.Dir(out.path), 0755); err != nil {
			return err
		}
		if err := os.WriteFile(out.path, out.data, out.perm); err != nil {
			return err
		}
	}
	return nil
}

// parseIPList converts each textual address in ips with netutils.ParseIPSloppy,
// preserving order. An entry that does not parse contributes a nil net.IP.
// An empty input yields a nil slice.
func parseIPList(ips []string) []net.IP {
	var parsed []net.IP
	for i := range ips {
		parsed = append(parsed, netutils.ParseIPSloppy(ips[i]))
	}
	return parsed
}

// generateSelfSignedCertKey creates a self-signed CA certificate for host and
// returns it PEM-encoded together with its PEM-encoded PKCS#1 RSA private key
// (2048 bits).
//
// host appears in the subject common name ("<host>@<unix seconds>"). It is also
// added as a subject alternative name: an IP SAN when it parses as an IP
// address, a DNS SAN otherwise. alternateIPs and alternateDNS are appended as
// extra SANs.
//
// The certificate has serial number 1, is valid from the Unix epoch until
// 100*365 days from now, and permits both server and client authentication.
func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
	key, err := rsa.GenerateKey(cryptorand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	// Collect SANs: the host itself first, then the caller-supplied extras.
	var ipSANs []net.IP
	var dnsSANs []string
	if hostIP := netutils.ParseIPSloppy(host); hostIP != nil {
		ipSANs = append(ipSANs, hostIP)
	} else {
		dnsSANs = append(dnsSANs, host)
	}
	ipSANs = append(ipSANs, alternateIPs...)
	dnsSANs = append(dnsSANs, alternateDNS...)

	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix())},
		NotBefore:             time.Unix(0, 0),
		NotAfter:              time.Now().Add(100 * 365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		IPAddresses:           ipSANs,
		DNSNames:              dnsSANs,
	}

	// Self-signed: the template is both the certificate and its issuer.
	der, err := x509.CreateCertificate(cryptorand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return nil, nil, err
	}

	toPEM := func(blockType string, data []byte) ([]byte, error) {
		var buf bytes.Buffer
		if err := pem.Encode(&buf, &pem.Block{Type: blockType, Bytes: data}); err != nil {
			return nil, err
		}
		return buf.Bytes(), nil
	}

	certPEM, err := toPEM("CERTIFICATE", der)
	if err != nil {
		return nil, nil, err
	}
	keyPEM, err := toPEM("RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(key))
	if err != nil {
		return nil, nil, err
	}
	return certPEM, keyPEM, nil
}

// mustParseEtcdVersionPair is ParseEtcdVersionPair for inputs known to be valid,
// such as "3.1.12/etcd3" (a "<version>/<storage-version>" string). It panics
// with the parse error if s cannot be parsed.
func mustParseEtcdVersionPair(s string) *EtcdVersionPair {
	pair, err := ParseEtcdVersionPair(s)
	if err == nil {
		return pair
	}
	panic(err)
}

// mustParseSupportedVersions parses a slice of etcd version strings (for example
// []string{"3.0.17", "3.1.12"}) into SupportedVersions via ParseSupportedVersions.
// It panics if parsing fails; the panic value wraps the parse error and names
// the offending input.
func mustParseSupportedVersions(list []string) SupportedVersions {
	versions, err := ParseSupportedVersions(list)
	if err != nil {
		panic(fmt.Errorf("parsing supported versions %q: %w", list, err))
	}
	return versions
}