forked from toolshed/abra
		
	
		
			
				
	
	
		
			178 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			178 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| //go:build windows
 | |
| 
 | |
| package socket
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"sync"
 | |
| 	"syscall"
 | |
| 	"unsafe"
 | |
| 
 | |
| 	"github.com/Microsoft/go-winio/pkg/guid"
 | |
| 	"golang.org/x/sys/windows"
 | |
| )
 | |
| 
 | |
| //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
 | |
| 
 | |
| //sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
 | |
| //sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
 | |
| //sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
 | |
| 
 | |
// socketError is the failure sentinel named by the [failretval==socketError] clauses of the
// //sys directives in this file: a 32-bit all-ones value widened to uintptr.
const socketError uintptr = 0xffffffff
 | |
| 
 | |
// Sentinel errors declared by this package. Where each one is returned is not visible
// from this file; compare against them with errors.Is.
var (
	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

	// ErrSocketClosed wraps [net.ErrClosed], so errors.Is(err, net.ErrClosed) matches it too.
	ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
	// ErrInvalidPointer is the sentinel for an invalid pointer.
	ErrInvalidPointer = errors.New("invalid pointer")
	// ErrAddrFamily is the sentinel for an address family problem.
	ErrAddrFamily = errors.New("address family")
	// ErrBufferSize is the sentinel for a buffer size problem.
	ErrBufferSize = errors.New("buffer size")
)
 | |
| 
 | |
| // todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
 | |
| 
 | |
| // GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
 | |
| // If rsa is not large enough, the [windows.WSAEFAULT] is returned.
 | |
| func GetSockName(s windows.Handle, rsa RawSockaddr) error {
 | |
| 	ptr, l, err := rsa.Sockaddr()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
 | |
| 	// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
 | |
| 	return getsockname(s, ptr, &l)
 | |
| }
 | |
| 
 | |
| // GetPeerName returns the remote address the socket is connected to.
 | |
| //
 | |
| // See [GetSockName] for more information.
 | |
| func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
 | |
| 	ptr, l, err := rsa.Sockaddr()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	return getpeername(s, ptr, &l)
 | |
| }
 | |
| 
 | |
| func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
 | |
| 	ptr, l, err := rsa.Sockaddr()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	return bind(s, ptr, l)
 | |
| }
 | |
| 
 | |
| // "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
 | |
| // their sockaddr interface, so they cannot be used with HvsockAddr
 | |
| // Replicate functionality here from
 | |
| // https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
 | |
| 
 | |
| // The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
 | |
| // runtime via a WSAIoctl call:
 | |
| // https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
 | |
| 
 | |
| type runtimeFunc struct {
 | |
| 	id   guid.GUID
 | |
| 	once sync.Once
 | |
| 	addr uintptr
 | |
| 	err  error
 | |
| }
 | |
| 
 | |
| func (f *runtimeFunc) Load() error {
 | |
| 	f.once.Do(func() {
 | |
| 		var s windows.Handle
 | |
| 		s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
 | |
| 		if f.err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		defer windows.CloseHandle(s) //nolint:errcheck
 | |
| 
 | |
| 		var n uint32
 | |
| 		f.err = windows.WSAIoctl(s,
 | |
| 			windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
 | |
| 			(*byte)(unsafe.Pointer(&f.id)),
 | |
| 			uint32(unsafe.Sizeof(f.id)),
 | |
| 			(*byte)(unsafe.Pointer(&f.addr)),
 | |
| 			uint32(unsafe.Sizeof(f.addr)),
 | |
| 			&n,
 | |
| 			nil, // overlapped
 | |
| 			0,   // completionRoutine
 | |
| 		)
 | |
| 	})
 | |
| 	return f.err
 | |
| }
 | |
| 
 | |
| var (
 | |
| 	// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
 | |
| 	WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
 | |
| 		Data1: 0x25a207b9,
 | |
| 		Data2: 0xddf3,
 | |
| 		Data3: 0x4660,
 | |
| 		Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
 | |
| 	}
 | |
| 
 | |
| 	connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
 | |
| )
 | |
| 
 | |
| func ConnectEx(
 | |
| 	fd windows.Handle,
 | |
| 	rsa RawSockaddr,
 | |
| 	sendBuf *byte,
 | |
| 	sendDataLen uint32,
 | |
| 	bytesSent *uint32,
 | |
| 	overlapped *windows.Overlapped,
 | |
| ) error {
 | |
| 	if err := connectExFunc.Load(); err != nil {
 | |
| 		return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
 | |
| 	}
 | |
| 	ptr, n, err := rsa.Sockaddr()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
 | |
| }
 | |
| 
 | |
| // BOOL LpfnConnectex(
 | |
| //   [in]           SOCKET s,
 | |
| //   [in]           const sockaddr *name,
 | |
| //   [in]           int namelen,
 | |
| //   [in, optional] PVOID lpSendBuffer,
 | |
| //   [in]           DWORD dwSendDataLength,
 | |
| //   [out]          LPDWORD lpdwBytesSent,
 | |
| //   [in]           LPOVERLAPPED lpOverlapped
 | |
| // )
 | |
| 
 | |
| func connectEx(
 | |
| 	s windows.Handle,
 | |
| 	name unsafe.Pointer,
 | |
| 	namelen int32,
 | |
| 	sendBuf *byte,
 | |
| 	sendDataLen uint32,
 | |
| 	bytesSent *uint32,
 | |
| 	overlapped *windows.Overlapped,
 | |
| ) (err error) {
 | |
| 	r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
 | |
| 		uintptr(s),
 | |
| 		uintptr(name),
 | |
| 		uintptr(namelen),
 | |
| 		uintptr(unsafe.Pointer(sendBuf)),
 | |
| 		uintptr(sendDataLen),
 | |
| 		uintptr(unsafe.Pointer(bytesSent)),
 | |
| 		uintptr(unsafe.Pointer(overlapped)),
 | |
| 	)
 | |
| 
 | |
| 	if r1 == 0 {
 | |
| 		if e1 != 0 {
 | |
| 			err = error(e1)
 | |
| 		} else {
 | |
| 			err = syscall.EINVAL
 | |
| 		}
 | |
| 	}
 | |
| 	return err
 | |
| }
 |