kubernetes/pkg/util/iptables/monitor_test.go

//go:build linux
// +build linux

/*
Copyright 2019 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package iptables

import (
	"context"
	"fmt"
	"io"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"k8s.io/apimachinery/pkg/util/sets"
	utilwait "k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/utils/exec"
)

// We can't use the normal FakeExec because we don't know precisely how many times the
// monitor goroutine will do its checks, or precisely how its iptables calls will
// interleave with the main test goroutine's. So we use our own fake Exec implementation
// that implements a minimal iptables interface. This will need updates as
// iptables.runner changes its use of Exec.
type monitorFakeExec struct {
	sync.Mutex

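	// tables maps each table name to the set of chains currently in it.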
	tables map[string]sets.Set[string]

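	// block simulates xtables.lock contention; wasBlocked records whether any
	// call was actually rejected while blocked.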
	block      bool
	wasBlocked bool
}

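// newMonitorFakeExec returns a monitorFakeExec with empty mangle, filter, and
// nat tables.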
func newMonitorFakeExec() *monitorFakeExec {
	tables := make(map[string]sets.Set[string])
	tables["mangle"] = sets.New[string]()
	tables["filter"] = sets.New[string]()
	tables["nat"] = sets.New[string]()
	return &monitorFakeExec{tables: tables}
}

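// blockIPTables sets whether iptables calls (other than chain deletion) should
// fail as though the xtables lock were held.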
func (mfe *monitorFakeExec) blockIPTables(block bool) {
	mfe.Lock()
	defer mfe.Unlock()

	mfe.block = block
}

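// getWasBlocked reports whether any iptables call was rejected by the simulated
// lock since the last call, and resets the flag.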
func (mfe *monitorFakeExec) getWasBlocked() bool {
	mfe.Lock()
	defer mfe.Unlock()

	wasBlocked := mfe.wasBlocked
	mfe.wasBlocked = false
	return wasBlocked
}

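// Command, CommandContext, and LookPath implement exec.Interface.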
func (mfe *monitorFakeExec) Command(cmd string, args ...string) exec.Cmd {
	return &monitorFakeCmd{mfe: mfe, cmd: cmd, args: args}
}

func (mfe *monitorFakeExec) CommandContext(ctx context.Context, cmd string, args ...string) exec.Cmd {
	return mfe.Command(cmd, args...)
}

func (mfe *monitorFakeExec) LookPath(file string) (string, error) {
	return file, nil
}

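// monitorFakeCmd implements exec.Cmd for the few invocations the monitor
// actually makes; every other method panics.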
type monitorFakeCmd struct {
	mfe  *monitorFakeExec
	cmd  string
	args []string
}

func (mfc *monitorFakeCmd) CombinedOutput() ([]byte, error) {
	if mfc.cmd == cmdIPTablesRestore {
		// Only used for "iptables-restore --version", and the result doesn't matter
		return []byte{}, nil
	} else if mfc.cmd != cmdIPTables {
		panic("bad command " + mfc.cmd)
	}

	if len(mfc.args) == 1 && mfc.args[0] == "--version" {
		return []byte("iptables v1.6.2"), nil
	}

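	// Anything else should be a single chain operation of the form (assuming the
	// package's current wait constants):
	//   iptables -w 5 -W 100000 <op> <chain> -t <table>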
	if len(mfc.args) != 8 || mfc.args[0] != WaitString || mfc.args[1] != WaitSecondsValue || mfc.args[2] != WaitIntervalString || mfc.args[3] != WaitIntervalUsecondsValue || mfc.args[6] != "-t" {
		panic(fmt.Sprintf("bad args %#v", mfc.args))
	}
	op := operation(mfc.args[4])
	chainName := mfc.args[5]
	tableName := mfc.args[7]

	mfc.mfe.Lock()
	defer mfc.mfe.Unlock()

	table := mfc.mfe.tables[tableName]
	if table == nil {
		return []byte{}, fmt.Errorf("no such table %q", tableName)
	}

	// For ease-of-testing reasons, blockIPTables blocks create and list, but not
	// delete. Exit status 4 is how iptables reports a resource problem such as
	// failing to acquire the xtables lock, which is what the monitor treats as
	// "blocked, try again later" rather than "canary deleted".
	if mfc.mfe.block && op != opDeleteChain {
		mfc.mfe.wasBlocked = true
		return []byte{}, exec.CodeExitError{Code: 4, Err: fmt.Errorf("could not get xtables.lock, etc")}
	}

	switch op {
	case opCreateChain:
		// Insert is a no-op for an existing chain, which matches the monitor's
		// use of EnsureChain.
		table.Insert(chainName)
		return []byte{}, nil
	case opListChain:
		if table.Has(chainName) {
			return []byte{}, nil
		}
		return []byte{}, fmt.Errorf("no such chain %q", chainName)
	case opDeleteChain:
		table.Delete(chainName)
		return []byte{}, nil
	default:
		panic("should not be reached")
	}
}

func (mfc *monitorFakeCmd) SetStdin(in io.Reader) {
	// Used by getIPTablesRestoreVersionString(), can be ignored
}

func (mfc *monitorFakeCmd) Run() error {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) Output() ([]byte, error) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) SetDir(dir string) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) SetStdout(out io.Writer) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) SetStderr(out io.Writer) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) SetEnv(env []string) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) StdoutPipe() (io.ReadCloser, error) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) StderrPipe() (io.ReadCloser, error) {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) Start() error {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) Wait() error {
	panic("should not be reached")
}

func (mfc *monitorFakeCmd) Stop() {
	panic("should not be reached")
}

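// TestIPTablesMonitor exercises Monitor end to end: canary creation, reload on
// canary deletion, tolerance of a held xtables lock, and shutdown via the stop
// channel.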
func TestIPTablesMonitor(t *testing.T) {
	mfe := newMonitorFakeExec()
	ipt := New(mfe, ProtocolIPv4)

	var reloads uint32
	stopCh := make(chan struct{})

	canary := Chain("MONITOR-TEST-CANARY")
	tables := []Table{TableMangle, TableFilter, TableNAT}
	go ipt.Monitor(canary, tables, func() {
		if !ensureNoChains(mfe) {
			t.Errorf("reload called while canaries still exist")
		}
		atomic.AddUint32(&reloads, 1)
	}, 100*time.Millisecond, stopCh)

	// Monitor should create canary chains quickly
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	if err := waitForReloads(&reloads, 0); err != nil {
		t.Errorf("got unexpected reloads: %v", err)
	}

	// If we delete all of the chains, it should reload
	ipt.DeleteChain(TableMangle, canary)
	ipt.DeleteChain(TableFilter, canary)
	ipt.DeleteChain(TableNAT, canary)

	if err := waitForReloads(&reloads, 1); err != nil {
		t.Errorf("got unexpected number of reloads after flush: %v", err)
	}
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	// If we delete two chains, it should not reload yet
	ipt.DeleteChain(TableMangle, canary)
	ipt.DeleteChain(TableFilter, canary)

	if err := waitForNoReload(&reloads, 1); err != nil {
		t.Errorf("got unexpected number of reloads after partial flush: %v", err)
	}

	// Now ensure that "iptables -L" will get an error about the xtables.lock, and
	// delete the last chain. The monitor should not reload, because it can't actually
	// tell if the chain was deleted or not.
	mfe.blockIPTables(true)
	ipt.DeleteChain(TableNAT, canary)
	if err := waitForBlocked(mfe); err != nil {
		t.Errorf("failed waiting for monitor to be blocked from monitoring: %v", err)
	}

	// After unblocking the monitor, it should now reload
	mfe.blockIPTables(false)

	if err := waitForReloads(&reloads, 2); err != nil {
		t.Errorf("got unexpected number of reloads after slow flush: %v", err)
	}
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	// If we close the stop channel, it should stop running
	close(stopCh)

	if err := waitForNoReload(&reloads, 2); err != nil {
		t.Errorf("got unexpected number of reloads after stop: %v", err)
	}
	if !ensureNoChains(mfe) {
		t.Errorf("canaries still exist after stopping monitor")
	}

	// If we create a new monitor while the iptables lock is held, it will
	// retry creating canaries until it succeeds

	stopCh = make(chan struct{})
	_ = mfe.getWasBlocked()
	mfe.blockIPTables(true)
	go ipt.Monitor(canary, tables, func() {
		if !ensureNoChains(mfe) {
			t.Errorf("reload called while canaries still exist")
		}
		atomic.AddUint32(&reloads, 1)
	}, 100*time.Millisecond, stopCh)

	// Monitor should not have created canaries yet
	if !ensureNoChains(mfe) {
		t.Errorf("canary created while iptables blocked")
	}

	if err := waitForBlocked(mfe); err != nil {
		t.Errorf("failed waiting for monitor to fail creating canaries: %v", err)
	}

	mfe.blockIPTables(false)
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	close(stopCh)
}

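// waitForChains polls until the canary chain exists in all of the given tables.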
func waitForChains(mfe *monitorFakeExec, canary Chain, tables []Table) error {
	return utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
		mfe.Lock()
		defer mfe.Unlock()

		for _, table := range tables {
			if !mfe.tables[string(table)].Has(string(canary)) {
				return false, nil
			}
		}
		return true, nil
	})
}

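// ensureNoChains reports whether every fake table is empty.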
func ensureNoChains(mfe *monitorFakeExec) bool {
	mfe.Lock()
	defer mfe.Unlock()
	return mfe.tables["mangle"].Len() == 0 &&
		mfe.tables["filter"].Len() == 0 &&
		mfe.tables["nat"].Len() == 0
}

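// waitForReloads polls until the reload count reaches expected, then verifies
// that it did not overshoot.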
func waitForReloads(reloads *uint32, expected uint32) error {
	if atomic.LoadUint32(reloads) < expected {
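		// A poll timeout here is deliberately ignored; the exact-count check
		// below reports any failure.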
		utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
			return atomic.LoadUint32(reloads) >= expected, nil
		})
	}
	got := atomic.LoadUint32(reloads)
	if got != expected {
		return fmt.Errorf("expected %d, got %d", expected, got)
	}
	return nil
}

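// waitForNoReload verifies that the reload count stays at exactly expected,
// polling long enough that an unwanted reload would have been observed.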
func waitForNoReload(reloads *uint32, expected uint32) error {
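	// This poll is expected to time out; it only exits early if an unexpected
	// extra reload happens, which the exact-count check below then reports.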
	utilwait.PollImmediate(50*time.Millisecond, 250*time.Millisecond, func() (bool, error) {
		return atomic.LoadUint32(reloads) > expected, nil
	})

	got := atomic.LoadUint32(reloads)
	if got != expected {
		return fmt.Errorf("expected %d, got %d", expected, got)
	}
	return nil
}

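// waitForBlocked polls until at least one iptables call has been rejected by
// the simulated lock.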
func waitForBlocked(mfe *monitorFakeExec) error {
	return utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
		blocked := mfe.getWasBlocked()
		return blocked, nil
	})
}