Files
docker-cli/cli-plugins/socket/socket_test.go
Brian Goff d68cc0e8d0 plugin: closer-based plugin notification socket
This changes things to rely on a plugin server that manages all
connections made to the server.

An optional handler can be passed into the server when the caller wants
to do extra things with the connection.

It is the caller's responsibility to close the server.
When the server is closed, first all existing connections are closed
(and new connections are prevented).

Now the signal loop only needs to close the server and not deal with
`net.Conn`'s directly (or double-indirects as the case was before this
change).

The socket, when present in the filesystem, is no longer unlinked
eagerly, as reconnections require it to be present for the lifecycle of
the plugin server.

Co-authored-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
2024-03-21 15:08:19 -06:00

190 lines
4.8 KiB
Go

package socket
import (
"errors"
"io"
"io/fs"
"net"
"os"
"runtime"
"strings"
"sync/atomic"
"testing"
"time"
"gotest.tools/v3/assert"
"gotest.tools/v3/poll"
)
func TestPluginServer(t *testing.T) {
t.Run("connection closes with EOF when server closes", func(t *testing.T) {
called := make(chan struct{})
srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
assert.NilError(t, err)
assert.Assert(t, srv != nil, "returned nil server but no error")
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
defer conn.Close()
done := make(chan error, 1)
go func() {
_, err := conn.Read(make([]byte, 1))
done <- err
}()
select {
case <-called:
case <-time.After(10 * time.Millisecond):
t.Fatal("handler not called")
}
srv.Close()
select {
case err := <-done:
if !errors.Is(err, io.EOF) {
t.Fatalf("exepcted EOF error, got: %v", err)
}
case <-time.After(10 * time.Millisecond):
}
})
t.Run("allows reconnects", func(t *testing.T) {
var calls int32
h := func(_ net.Conn) {
atomic.AddInt32(&calls, 1)
}
srv, err := NewPluginServer(h)
assert.NilError(t, err)
defer srv.Close()
assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
waitForCalls := func(n int) {
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if atomic.LoadInt32(&calls) == int32(n) {
return poll.Success()
}
return poll.Continue("waiting for handler to be called")
})
}
otherConn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
otherConn.Close()
waitForCalls(1)
conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn.Close()
waitForCalls(2)
// and again but don't close the existing connection
conn2, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn2.Close()
waitForCalls(3)
srv.Close()
// now make sure we get EOF on the existing connections
buf := make([]byte, 1)
_, err = conn.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
_, err = conn2.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
})
t.Run("does not leak sockets to local directory", func(t *testing.T) {
srv, err := NewPluginServer(nil)
assert.NilError(t, err)
assert.Check(t, srv != nil, "returned nil server but no error")
checkDirNoNewPluginServer(t)
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
checkDirNoNewPluginServer(t)
})
}
func checkDirNoNewPluginServer(t *testing.T) {
t.Helper()
files, err := os.ReadDir(".")
assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")
for _, f := range files {
info, err := f.Info()
assert.NilError(t, err, "failed to check file info")
// check for a socket with `docker_cli_` in the name (from `SetupConn()`)
if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket {
t.Fatal("found socket in a local directory")
}
}
}
func TestConnectAndWait(t *testing.T) {
t.Run("calls cancel func on EOF", func(t *testing.T) {
srv, err := NewPluginServer(nil)
assert.NilError(t, err, "failed to setup server")
defer srv.Close()
done := make(chan struct{})
t.Setenv(EnvKey, srv.Addr().String())
cancelFunc := func() {
done <- struct{}{}
}
ConnectAndWait(cancelFunc)
select {
case <-done:
t.Fatal("unexpectedly done")
default:
}
srv.Close()
select {
case <-done:
case <-time.After(10 * time.Millisecond):
t.Fatal("cancel function not closed after 10ms")
}
})
// TODO: this test cannot be executed with `t.Parallel()`, due to
// relying on goroutine numbers to ensure correct behaviour
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
srv, err := NewPluginServer(nil)
assert.NilError(t, err, "failed to setup server")
defer srv.Close()
t.Setenv(EnvKey, srv.Addr().String())
numGoroutines := runtime.NumGoroutine()
ConnectAndWait(func() {})
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
srv.Close()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if runtime.NumGoroutine() > numGoroutines+1 {
return poll.Continue("waiting for connect goroutine to exit")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
})
}