This changes things to rely on a plugin server that manages all connections made to the server. An optional handler can be passed into the server when the caller wants to do extra things with the connection. It is the caller's responsibility to close the server. When the server is closed, all existing connections are closed first (and new connections are prevented). Now the signal loop only needs to close the server and does not deal with `net.Conn`s directly (or double-indirects, as was the case before this change). The socket, when present in the filesystem, is no longer unlinked eagerly, as reconnections require it to be present for the lifecycle of the plugin server. Co-authored-by: Bjorn Neergaard <bjorn.neergaard@docker.com> Signed-off-by: Brian Goff <cpuguy83@gmail.com> Signed-off-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
149 lines
2.9 KiB
Go
149 lines
2.9 KiB
Go
package socket
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"runtime"
|
|
"sync"
|
|
)
|
|
|
|
// EnvKey represents the well-known environment variable used to pass the
// plugin being executed the name of the socket it should connect to in order
// to coordinate with the host CLI. ConnectAndWait looks this variable up and
// dials the Unix socket address it contains.
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
|
|
|
|
// NewPluginServer creates a plugin server that listens on a new Unix domain socket.
|
|
// `h` is called for each new connection to the socket in a goroutine.
|
|
func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
|
|
l, err := listen("docker_cli_" + randomID())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if h == nil {
|
|
h = func(net.Conn) {}
|
|
}
|
|
|
|
pl := &PluginServer{
|
|
l: l,
|
|
h: h,
|
|
}
|
|
|
|
go func() {
|
|
defer pl.Close()
|
|
for {
|
|
err := pl.accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return pl, nil
|
|
}
|
|
|
|
// PluginServer accepts connections on a Unix domain socket and keeps track of
// them so that they can all be closed together when the server is closed.
type PluginServer struct {
	// l is the listener that new connections are accepted from.
	l *net.UnixListener
	// h is run in its own goroutine for every accepted connection.
	h func(net.Conn)

	// mu guards conns and closed.
	mu sync.Mutex
	// conns holds every connection accepted so far that has not yet been
	// closed by closeAllConns.
	conns []net.Conn
	// closed is set once the server has begun shutting down; connections
	// accepted after that point are closed immediately.
	closed bool
}
|
|
|
|
func (pl *PluginServer) accept() error {
|
|
conn, err := pl.l.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pl.mu.Lock()
|
|
defer pl.mu.Unlock()
|
|
|
|
if pl.closed {
|
|
// handle potential race condition between Close and Accept
|
|
conn.Close()
|
|
return errors.New("plugin server is closed")
|
|
}
|
|
|
|
pl.conns = append(pl.conns, conn)
|
|
|
|
go pl.h(conn)
|
|
return nil
|
|
}
|
|
|
|
func (pl *PluginServer) Addr() net.Addr {
|
|
return pl.l.Addr()
|
|
}
|
|
|
|
// Close ensures that the server is no longer accepting new connections and closes all existing connections.
|
|
// Existing connections will receive [io.EOF].
|
|
func (pl *PluginServer) Close() error {
|
|
// Remove the listener socket, if it exists on the filesystem.
|
|
unlink(pl.l)
|
|
|
|
// Close connections first to ensure the connections get io.EOF instead of a connection reset.
|
|
pl.closeAllConns()
|
|
|
|
// Try to ensure that any active connections have a chance to receive io.EOF
|
|
runtime.Gosched()
|
|
|
|
return pl.l.Close()
|
|
}
|
|
|
|
func (pl *PluginServer) closeAllConns() {
|
|
pl.mu.Lock()
|
|
defer pl.mu.Unlock()
|
|
|
|
// Prevent new connections from being accepted
|
|
pl.closed = true
|
|
|
|
for _, conn := range pl.conns {
|
|
conn.Close()
|
|
}
|
|
|
|
pl.conns = nil
|
|
}
|
|
|
|
// randomID returns 16 bytes read from crypto/rand, encoded as a 32-character
// lowercase hexadecimal string. It panics if the random source fails.
func randomID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err) // This shouldn't happen
	}
	return hex.EncodeToString(raw[:])
}
|
|
|
|
// ConnectAndWait connects to the socket passed via well-known env var,
|
|
// if present, and attempts to read from it until it receives an EOF, at which
|
|
// point cb is called.
|
|
func ConnectAndWait(cb func()) {
|
|
socketAddr, ok := os.LookupEnv(EnvKey)
|
|
if !ok {
|
|
// if a plugin compiled against a more recent version of docker/cli
|
|
// is executed by an older CLI binary, ignore missing environment
|
|
// variable and behave as usual
|
|
return
|
|
}
|
|
addr, err := net.ResolveUnixAddr("unix", socketAddr)
|
|
if err != nil {
|
|
return
|
|
}
|
|
conn, err := net.DialUnix("unix", nil, addr)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
b := make([]byte, 1)
|
|
for {
|
|
_, err := conn.Read(b)
|
|
if errors.Is(err, io.EOF) {
|
|
cb()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|