diff --git a/components/engine/client/hijack.go b/components/engine/client/hijack.go index 2b14831fd2..35f5dd86dc 100644 --- a/components/engine/client/hijack.go +++ b/components/engine/client/hijack.go @@ -9,7 +9,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "strings" "time" "github.com/docker/docker/api/types" @@ -17,21 +16,6 @@ import ( "github.com/pkg/errors" ) -// tlsClientCon holds tls information and a dialed connection. -type tlsClientCon struct { - *tls.Conn - rawConn net.Conn -} - -func (c *tlsClientCon) CloseWrite() error { - // Go standard tls.Conn doesn't provide the CloseWrite() method so we do it - // on its underlying connection. - if conn, ok := c.rawConn.(types.CloseWriter); ok { - return conn.CloseWrite() - } - return nil -} - // postHijacked sends a POST request and hijacks the connection. func (cli *Client) postHijacked(ctx context.Context, path string, query url.Values, body interface{}, headers map[string][]string) (types.HijackedResponse, error) { bodyEncoded, err := encodeData(body) @@ -54,96 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err } -func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) { - return tlsDialWithDialer(new(net.Dialer), network, addr, config) -} - -// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in -// order to return our custom tlsClientCon struct which holds both the tls.Conn -// object _and_ its underlying raw connection. The rationale for this is that -// we need to be able to close the write end of the connection when attaching, -// which tls.Conn does not provide. -func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) { - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := dialer.Timeout - - if !dialer.Deadline.IsZero() { - deadlineTimeout := time.Until(dialer.Deadline) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - var errChannel chan error - - if timeout != 0 { - errChannel = make(chan error, 2) - time.AfterFunc(timeout, func() { - errChannel <- errors.New("") - }) - } - - proxyDialer, err := sockets.DialerFromEnvironment(dialer) - if err != nil { - return nil, err - } - - rawConn, err := proxyDialer.Dial(network, addr) - if err != nil { - return nil, err - } - // When we set up a TCP connection for hijack, there could be long periods - // of inactivity (a long running command with no output) that in certain - // network setups may cause ECONNTIMEOUT, leaving the client in an unknown - // state. Setting TCP KeepAlive on the socket connection will prohibit - // ECONNTIMEOUT unless the socket connection truly is broken - if tcpConn, ok := rawConn.(*net.TCPConn); ok { - tcpConn.SetKeepAlive(true) - tcpConn.SetKeepAlivePeriod(30 * time.Second) - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - if config.ServerName == "" { - // Make a copy to avoid polluting argument or default. - config = tlsConfigClone(config) - config.ServerName = hostname - } - - conn := tls.Client(rawConn, config) - - if timeout == 0 { - err = conn.Handshake() - } else { - go func() { - errChannel <- conn.Handshake() - }() - - err = <-errChannel - } - - if err != nil { - rawConn.Close() - return nil, err - } - - // This is Docker difference with standard's crypto/tls package: returned a - // wrapper which holds both the TLS and raw connections. - return &tlsClientCon{conn, rawConn}, nil -} - func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) { if tlsConfig != nil && proto != "unix" && proto != "npipe" { - // Notice this isn't Go standard's tls.Dial function - return tlsDial(proto, addr, tlsConfig) + return tls.Dial(proto, addr, tlsConfig) } if proto == "npipe" { return sockets.DialPipe(addr, 32*time.Second) diff --git a/components/engine/client/hijack_test.go b/components/engine/client/hijack_test.go new file mode 100644 index 0000000000..dea5addb8c --- /dev/null +++ b/components/engine/client/hijack_test.go @@ -0,0 +1,103 @@ +package client + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/docker/docker/api/server/httputils" + "github.com/docker/docker/api/types" + "github.com/gotestyourself/gotestyourself/assert" + "github.com/pkg/errors" + "golang.org/x/net/context" +) + +func TestTLSCloseWriter(t *testing.T) { + t.Parallel() + + var chErr chan error + ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + chErr = make(chan error, 1) + defer close(chErr) + if err := httputils.ParseForm(req); err != nil { + chErr <- errors.Wrap(err, "error parsing form") + http.Error(w, err.Error(), 500) + return + } + r, rw, err := httputils.HijackConnection(w) + if err != nil { + chErr <- errors.Wrap(err, "error hijacking connection") + http.Error(w, err.Error(), 500) + return + } + defer r.Close() + + fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n") + + buf := make([]byte, 5) + _, err = r.Read(buf) + if err != nil { + chErr <- errors.Wrap(err, "error reading from client") + return + } + _, err = rw.Write(buf) + if err != nil { + chErr <- errors.Wrap(err, "error writing to client") + return + } + })}} + + var ( + l net.Listener + err error + ) + for i := 1024; i < 10000; i++ { + l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i)) + if err == nil { + break + } + } + assert.Assert(t, err) + + ts.Listener = l + defer l.Close() + + defer func() { + if chErr != nil { + assert.Assert(t, <-chErr) + } + }() + + ts.StartTLS() + defer ts.Close() + + serverURL, err := url.Parse(ts.URL) + assert.Assert(t, err) + + client, err := NewClient("tcp://"+serverURL.Host, "", ts.Client(), nil) + assert.Assert(t, err) + + resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}}) + assert.Assert(t, err) + defer resp.Close() + + if _, ok := resp.Conn.(types.CloseWriter); !ok { + t.Fatal("tls conn did not implement the CloseWrite interface") + } + + _, err = resp.Conn.Write([]byte("hello")) + assert.Assert(t, err) + + b, err := ioutil.ReadAll(resp.Reader) + assert.Assert(t, err) + assert.Assert(t, string(b) == "hello") + assert.Assert(t, resp.CloseWrite()) + + // This should error since writes are closed + _, err = resp.Conn.Write([]byte("no")) + assert.Assert(t, err != nil) +} diff --git a/components/engine/client/tlsconfig_clone.go b/components/engine/client/tlsconfig_clone.go deleted file mode 100644 index 88200e92c3..0000000000 --- a/components/engine/client/tlsconfig_clone.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build go1.8 - -package client // import "github.com/docker/docker/client" - -import "crypto/tls" - -// tlsConfigClone returns a clone of tls.Config. This function is provided for -// compatibility for go1.7 that doesn't include this method in stdlib. -func tlsConfigClone(c *tls.Config) *tls.Config { - return c.Clone() -} diff --git a/components/engine/client/tlsconfig_clone_go17.go b/components/engine/client/tlsconfig_clone_go17.go deleted file mode 100644 index e298542367..0000000000 --- a/components/engine/client/tlsconfig_clone_go17.go +++ /dev/null @@ -1,33 +0,0 @@ -// +build go1.7,!go1.8 - -package client // import "github.com/docker/docker/client" - -import "crypto/tls" - -// tlsConfigClone returns a clone of tls.Config. This function is provided for -// compatibility for go1.7 that doesn't include this method in stdlib. -func tlsConfigClone(c *tls.Config) *tls.Config { - return &tls.Config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - } -}