kubernetes/vendor/github.com/Microsoft/go-winio/internal/socket/socket.go

//go:build windows

package socket

import (
	"errors"
	"fmt"
	"net"
	"sync"
	"syscall"
	"unsafe"

	"github.com/Microsoft/go-winio/pkg/guid"
	"golang.org/x/sys/windows"
)

//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go

//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind

// socketError is the return value that the `//sys` directives in this file treat as failure
// (failretval==socketError): all 32 bits set, i.e. 0xFFFFFFFF, widened to uintptr.
const socketError = uintptr(^uint32(0))

// Sentinel errors exported by this package; compare against them with errors.Is.
var (
	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

	// ErrBufferSize is a sentinel error with the message "buffer size".
	// NOTE(review): presumably reported when a sockaddr buffer has an unexpected size -- confirm against callers.
	ErrBufferSize     = errors.New("buffer size")
	// ErrAddrFamily is a sentinel error with the message "address family".
	// NOTE(review): presumably reported on an address family mismatch -- confirm against callers.
	ErrAddrFamily     = errors.New("address family")
	// ErrInvalidPointer is a sentinel error with the message "invalid pointer".
	// NOTE(review): presumably reported for a nil or otherwise unusable pointer -- confirm against callers.
	ErrInvalidPointer = errors.New("invalid pointer")
	// ErrSocketClosed wraps net.ErrClosed, so errors.Is(err, net.ErrClosed) also matches it.
	ErrSocketClosed   = fmt.Errorf("socket closed: %w", net.ErrClosed)
)

// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)

// GetSockName fills the [RawSockaddr] rsa with the local address of socket s.
//
// The destination pointer and its size are obtained from rsa.Sockaddr; if that fails, the
// error is wrapped and returned without calling into Winsock.
// If rsa is not large enough, [windows.WSAEFAULT] is returned.
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
	name, namelen, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	// getsockname reports WSAEFAULT for an undersized buffer but does not update namelen to
	// the size it needs, so the only recourse would be retrying with ever larger buffers.
	return getsockname(s, name, &namelen)
}

// GetPeerName fills the [RawSockaddr] rsa with the remote address the socket s is connected to.
//
// The destination pointer and its size are obtained from rsa.Sockaddr; if that fails, the
// error is wrapped and returned without calling into Winsock.
//
// See [GetSockName] for more information.
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
	name, namelen, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return getpeername(s, name, &namelen)
}

// Bind calls the Winsock bind function on socket s with the address held in the
// [RawSockaddr] rsa.
//
// The address pointer and its size are obtained from rsa.Sockaddr; if that fails, the
// error is wrapped and returned without calling into Winsock.
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
	name, namelen, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}

	return bind(s, name, namelen)
}

// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
// their sockaddr interface, so they cannot be used with HvsockAddr.
// Replicate functionality here from
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go

// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
// runtime via a WSAIoctl call:
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks

// runtimeFunc describes a Winsock extension function whose address is resolved at runtime
// by Load rather than at link time.
type runtimeFunc struct {
	id   guid.GUID // identifier handed to WSAIoctl as the input buffer in Load
	once sync.Once // ensures the lookup in Load runs at most once
	addr uintptr   // output buffer for WSAIoctl in Load; holds the resolved function pointer
	err  error     // outcome of the lookup; returned by every call to Load
}

// Load resolves the extension function identified by f.id and stores its address in f.addr.
//
// The lookup runs at most once per runtimeFunc (guarded by f.once): a temporary TCP socket is
// created, WSAIoctl is issued on it with SIO_GET_EXTENSION_FUNCTION_POINTER, and the socket is
// closed again. The error from that single attempt is cached in f.err and returned by this
// and every subsequent call.
func (f *runtimeFunc) Load() error {
	f.once.Do(func() {
		var s windows.Handle
		s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
		if f.err != nil {
			return
		}
		// Sockets must be released with closesocket rather than CloseHandle, otherwise the
		// resources associated with the socket may not be freed. The result is deliberately
		// ignored: the socket only exists to issue the ioctl below.
		defer windows.Closesocket(s) //nolint:errcheck

		var n uint32 // bytes returned; required by WSAIoctl but otherwise unused
		f.err = windows.WSAIoctl(s,
			windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
			(*byte)(unsafe.Pointer(&f.id)),
			uint32(unsafe.Sizeof(f.id)),
			(*byte)(unsafe.Pointer(&f.addr)),
			uint32(unsafe.Sizeof(f.addr)),
			&n,
			nil, // overlapped
			0,   // completionRoutine
		)
	})
	return f.err
}

var (
	// todo: add `AcceptEx` and `GetAcceptExSockaddrs`

	// WSAID_CONNECTEX is the GUID stored in connectExFunc.id; runtimeFunc.Load passes it to
	// WSAIoctl with SIO_GET_EXTENSION_FUNCTION_POINTER to look up the ConnectEx function pointer.
	WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
		Data1: 0x25a207b9,
		Data2: 0xddf3,
		Data3: 0x4660,
		Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
	}

	// connectExFunc caches the runtime-resolved address of ConnectEx. ConnectEx calls its Load
	// method before connectEx reads connectExFunc.addr.
	connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
)

// ConnectEx invokes the Winsock ConnectEx extension function on socket fd with the address
// held in the [RawSockaddr] rsa.
//
// The ConnectEx function pointer is resolved on first use via connectExFunc.Load; a failure
// to load it is wrapped and returned. The address pointer and size come from rsa.Sockaddr,
// whose error is likewise wrapped, matching [GetSockName], [GetPeerName] and [Bind].
// sendBuf, sendDataLen, bytesSent and overlapped are forwarded unchanged to the underlying
// call. Any error from that call is returned as is.
func ConnectEx(
	fd windows.Handle,
	rsa RawSockaddr,
	sendBuf *byte,
	sendDataLen uint32,
	bytesSent *uint32,
	overlapped *windows.Overlapped,
) error {
	if err := connectExFunc.Load(); err != nil {
		return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
	}
	ptr, n, err := rsa.Sockaddr()
	if err != nil {
		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
	}
	return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
}

// connectEx calls the function pointer stored in connectExFunc.addr, passing its arguments in
// the order of the C prototype:
//
//	BOOL LpfnConnectex(
//	  [in]           SOCKET s,
//	  [in]           const sockaddr *name,
//	  [in]           int namelen,
//	  [in, optional] PVOID lpSendBuffer,
//	  [in]           DWORD dwSendDataLength,
//	  [out]          LPDWORD lpdwBytesSent,
//	  [in]           LPOVERLAPPED lpOverlapped
//	)
//
// connectExFunc.addr is only populated by connectExFunc.Load, which ConnectEx runs before
// calling this function.
//
// A non-zero result from the call yields a nil error. On a zero result the error number
// reported by the call is returned, or syscall.EINVAL when that number is zero.
func connectEx(
	s windows.Handle,
	name unsafe.Pointer,
	namelen int32,
	sendBuf *byte,
	sendDataLen uint32,
	bytesSent *uint32,
	overlapped *windows.Overlapped,
) (err error) {
	// The unsafe.Pointer -> uintptr conversions are kept inside the argument list of the call,
	// as required by the unsafe package rules for passing pointers to system calls.
	ret, _, errno := syscall.SyscallN(connectExFunc.addr,
		uintptr(s),
		uintptr(name),
		uintptr(namelen),
		uintptr(unsafe.Pointer(sendBuf)),
		uintptr(sendDataLen),
		uintptr(unsafe.Pointer(bytesSent)),
		uintptr(unsafe.Pointer(overlapped)),
	)

	switch {
	case ret != 0:
		// Non-zero result: the call succeeded.
		return nil
	case errno != 0:
		return errno
	default:
		// Failure was signalled without an accompanying error number.
		return syscall.EINVAL
	}
}