package ssh

import (
	"bufio"
	"bytes"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"coopcloud.tech/abra/pkg/context"
	"github.com/AlecAivazis/survey/v2"
	dockerSSHPkg "github.com/docker/cli/cli/connhelper/ssh"
	sshPkg "github.com/gliderlabs/ssh"
	"github.com/kevinburke/ssh_config"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"golang.org/x/crypto/ssh/knownhosts"
)

var (
	// KnownHostsPath is the known_hosts file consulted and extended by this
	// package: $HOME/.ssh/known_hosts.
	KnownHostsPath = filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
)

// Client wraps an established *ssh.Client connection.
type Client struct {
	SSHClient *ssh.Client
}

// HostConfig is a SSH host config: the connection settings resolved for a
// single host.
type HostConfig struct {
	Host         string
	IdentityFile string
	Port         string
	User         string
}

// Exec runs cmd in a new session on the remote host and returns everything
// the command wrote to stdout and stderr, interleaved.
func (c *Client) Exec(cmd string) ([]byte, error) {
	sess, err := c.SSHClient.NewSession()
	if err != nil {
		return nil, err
	}
	defer sess.Close()

	output, err := sess.CombinedOutput(cmd)
	return output, err
}

// Close shuts down the underlying SSH connection.
func (c *Client) Close() error {
	sshClient := c.SSHClient
	return sshClient.Close()
}

// New creates a new SSH client connection.
func New(domainName, sshAuth, username, port string) (*Client, error) { var client *Client ctxConnDetails, err := GetContextConnDetails(domainName) if err != nil { return client, nil } if sshAuth == "identity-file" { var err error client, err = connectWithAgentTimeout( ctxConnDetails.Host, ctxConnDetails.User, ctxConnDetails.Port, 5*time.Second, ) if err != nil { return client, err } } else { password := "" prompt := &survey.Password{ Message: "SSH password?", } if err := survey.AskOne(prompt, &password); err != nil { return client, err } var err error client, err = connectWithPasswordTimeout( ctxConnDetails.Host, ctxConnDetails.User, ctxConnDetails.Port, password, 5*time.Second, ) if err != nil { return client, err } } return client, nil } // sudoWriter supports sudo command handling type sudoWriter struct { b bytes.Buffer pw string stdin io.Writer m sync.Mutex } // Write satisfies the write interface for sudoWriter func (w *sudoWriter) Write(p []byte) (int, error) { if strings.Contains(string(p), "sudo_password") { w.stdin.Write([]byte(w.pw + "\n")) w.pw = "" return len(p), nil } w.m.Lock() defer w.m.Unlock() return w.b.Write(p) } // RunSudoCmd runs SSH commands and streams output func RunSudoCmd(cmd, passwd string, cl *Client) error { session, err := cl.SSHClient.NewSession() if err != nil { return err } defer session.Close() sudoCmd := fmt.Sprintf("SSH_ASKPASS=/usr/bin/ssh-askpass; sudo -p sudo_password -S %s", cmd) w := &sudoWriter{pw: passwd} w.stdin, err = session.StdinPipe() if err != nil { return err } session.Stdout = w session.Stderr = w modes := ssh.TerminalModes{ ssh.ECHO: 0, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, } err = session.RequestPty("xterm", 80, 40, modes) if err != nil { return err } if err := session.Run(sudoCmd); err != nil { return fmt.Errorf("%s", string(w.b.Bytes())) } return nil } // EnsureKnowHostsFiles ensures that ~/.ssh/known_hosts is created func EnsureKnowHostsFiles() error { if _, err := os.Stat(KnownHostsPath); 
os.IsNotExist(err) { logrus.Debugf("missing %s, creating now", KnownHostsPath) file, err := os.OpenFile(KnownHostsPath, os.O_CREATE, 0600) if err != nil { return err } file.Close() } return nil } // GetHostKey checks if a host key is registered in the ~/.ssh/known_hosts file func GetHostKey(hostname string) (bool, sshPkg.PublicKey, error) { var hostKey sshPkg.PublicKey ctxConnDetails, err := GetContextConnDetails(hostname) if err != nil { return false, hostKey, err } if err := EnsureKnowHostsFiles(); err != nil { return false, hostKey, err } file, err := os.Open(KnownHostsPath) if err != nil { return false, hostKey, err } defer file.Close() scanner := bufio.NewScanner(file) for scanner.Scan() { fields := strings.Split(scanner.Text(), " ") if len(fields) != 3 { continue } hostnameAndPort := fmt.Sprintf("%s:%s", ctxConnDetails.Host, ctxConnDetails.Port) hashed := knownhosts.Normalize(hostnameAndPort) if strings.Contains(fields[0], hashed) { var err error hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) if err != nil { return false, hostKey, fmt.Errorf("error parsing server SSH host key %q: %v", fields[2], err) } break } } if hostKey != nil { logrus.Debugf("server SSH host key present in ~/.ssh/known_hosts for %s", hostname) return true, hostKey, nil } return false, hostKey, nil } // InsertHostKey adds a new host key to the ~/.ssh/known_hosts file func InsertHostKey(hostname string, remote net.Addr, pubKey ssh.PublicKey) error { file, err := os.OpenFile(KnownHostsPath, os.O_APPEND|os.O_WRONLY, 0600) if err != nil { return err } defer file.Close() hashedHostname := knownhosts.Normalize(hostname) lineHostname := knownhosts.Line([]string{hashedHostname}, pubKey) _, err = file.WriteString(fmt.Sprintf("%s\n", lineHostname)) if err != nil { return err } hashedRemote := knownhosts.Normalize(remote.String()) lineRemote := knownhosts.Line([]string{hashedRemote}, pubKey) _, err = file.WriteString(fmt.Sprintf("%s\n", lineRemote)) if err != nil { return err } 
logrus.Debugf("SSH host key generated: %s", lineHostname) logrus.Debugf("SSH host key generated: %s", lineRemote) return nil } // HostKeyAddCallback ensures server ssh host keys are handled func HostKeyAddCallback(hostnameAndPort string, remote net.Addr, pubKey ssh.PublicKey) error { exists, _, err := GetHostKey(hostnameAndPort) if err != nil { return err } if exists { hostname := strings.Split(hostnameAndPort, ":")[0] logrus.Debugf("server SSH host key found for %s", hostname) return nil } if !exists { hostname := strings.Split(hostnameAndPort, ":")[0] parsedPubKey := FingerprintSHA256(pubKey) fmt.Printf(fmt.Sprintf(` You are attempting to make an SSH connection to a server but there is no entry in your ~/.ssh/known_hosts file which confirms that you have already validated that this is indeed the server you want to connect to. Please take a moment to validate the following SSH host key, it is important. Host: %s Fingerprint: %s If this is confusing to you, you can read the article below and learn how to validate this fingerprint safely. 
Thanks to the comrades at cyberia.club for writing this extensive guide <3 https://sequentialread.com/understanding-the-secure-shell-protocol-ssh/ `, hostname, parsedPubKey)) response := false prompt := &survey.Confirm{ Message: "are you sure you trust this host key?", } if err := survey.AskOne(prompt, &response); err != nil { return err } if !response { logrus.Fatal("exiting as requested") } logrus.Debugf("attempting to insert server SSH host key for %s, %s", hostnameAndPort, remote) if err := InsertHostKey(hostnameAndPort, remote, pubKey); err != nil { return err } logrus.Infof("successfully added server SSH host key for %s", hostname) } return nil } // connect makes the SSH connection func connect(username, host, port string, authMethod ssh.AuthMethod, timeout time.Duration) (*Client, error) { config := &ssh.ClientConfig{ User: username, Auth: []ssh.AuthMethod{authMethod}, HostKeyCallback: HostKeyAddCallback, // the main reason why we fork } hostnameAndPort := fmt.Sprintf("%s:%s", host, port) logrus.Debugf("tcp dialing %s", hostnameAndPort) var conn net.Conn var err error conn, err = net.DialTimeout("tcp", hostnameAndPort, timeout) if err != nil { logrus.Debugf("tcp dialing %s failed, trying via ~/.ssh/config", hostnameAndPort) hostConfig, err := GetHostConfig(host, username, port, true) if err != nil { return nil, err } conn, err = net.DialTimeout("tcp", fmt.Sprintf("%s:%s", hostConfig.Host, hostConfig.Port), timeout) if err != nil { return nil, err } } sshConn, chans, reqs, err := ssh.NewClientConn(conn, hostnameAndPort, config) if err != nil { return nil, err } client := ssh.NewClient(sshConn, chans, reqs) c := &Client{SSHClient: client} return c, nil } func connectWithAgentTimeout(host, username, port string, timeout time.Duration) (*Client, error) { logrus.Debugf("using ssh-agent to make an SSH connection for %s", host) sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) if err != nil { return nil, err } agentCl := agent.NewClient(sshAgent) 
authMethod := ssh.PublicKeysCallback(agentCl.Signers) loadedKeys, err := agentCl.List() if err != nil { return nil, err } var convertedKeys []string for _, key := range loadedKeys { convertedKeys = append(convertedKeys, key.String()) } if len(convertedKeys) > 0 { logrus.Debugf("ssh-agent has these keys loaded: %s", strings.Join(convertedKeys, ",")) } else { logrus.Debug("ssh-agent has no keys loaded") } return connect(username, host, port, authMethod, timeout) } func connectWithPasswordTimeout(host, username, port, pass string, timeout time.Duration) (*Client, error) { authMethod := ssh.Password(pass) return connect(username, host, port, authMethod, timeout) } // EnsureHostKey ensures that a host key trusted and added to the ~/.ssh/known_hosts file func EnsureHostKey(hostname string) error { if hostname == "default" || hostname == "local" { logrus.Debugf("not checking server SSH host key against local/default target") return nil } exists, _, err := GetHostKey(hostname) if err != nil { return err } if exists { return nil } ctxConnDetails, err := GetContextConnDetails(hostname) if err != nil { return err } _, err = connectWithAgentTimeout( ctxConnDetails.Host, ctxConnDetails.User, ctxConnDetails.Port, 5*time.Second, ) if err != nil { return err } return nil } // FingerprintSHA256 generates the SHA256 fingerprint for a server SSH host key func FingerprintSHA256(key ssh.PublicKey) string { hash := sha256.Sum256(key.Marshal()) b64hash := base64.StdEncoding.EncodeToString(hash[:]) trimmed := strings.TrimRight(b64hash, "=") return fmt.Sprintf("SHA256:%s", trimmed) } // GetContextConnDetails retrieves SSH connection details from a docker context endpoint func GetContextConnDetails(serverName string) (*dockerSSHPkg.Spec, error) { dockerContextStore := context.NewDefaultDockerContextStore() contexts, err := dockerContextStore.Store.List() if err != nil { return &dockerSSHPkg.Spec{}, err } if strings.Contains(serverName, ":") { serverName = strings.Split(serverName, ":")[0] } 
for _, ctx := range contexts { endpoint, err := context.GetContextEndpoint(ctx) if err != nil && strings.Contains(err.Error(), "does not exist") { // No local context found, we can continue safely continue } if ctx.Name == serverName { ctxConnDetails, err := dockerSSHPkg.ParseURL(endpoint) if err != nil { return &dockerSSHPkg.Spec{}, err } logrus.Debugf("found context connection details %v for %s", ctxConnDetails, serverName) return ctxConnDetails, nil } } hostConfig, err := GetHostConfig(serverName, "", "", false) if err != nil { return &dockerSSHPkg.Spec{}, err } logrus.Debugf("couldn't find a docker context matching %s", serverName) logrus.Debugf("searching ~/.ssh/config for a Host entry for %s", serverName) connDetails := &dockerSSHPkg.Spec{ Host: hostConfig.Host, User: hostConfig.User, Port: hostConfig.Port, } logrus.Debugf("using %v from ~/.ssh/config for connection details", connDetails) return connDetails, nil } // GetHostConfig retrieves a ~/.ssh/config config for a host. func GetHostConfig(hostname, username, port string, override bool) (HostConfig, error) { var hostConfig HostConfig if hostname == "" || override { if sshHost := ssh_config.Get(hostname, "Hostname"); sshHost != "" { hostname = sshHost } } if username == "" || override { if sshUser := ssh_config.Get(hostname, "User"); sshUser != "" { username = sshUser } else { systemUser, err := user.Current() if err != nil { return hostConfig, err } username = systemUser.Username } } if port == "" || override { if sshPort := ssh_config.Get(hostname, "Port"); sshPort != "" { // skip override probably correct port with dummy default value from // ssh_config which is 22. only when the original port number is empty // should we try this default. this might not cover all cases // unfortunately. 
if port != "" && sshPort != "22" { port = sshPort } } } if idf := ssh_config.Get(hostname, "IdentityFile"); idf != "" && idf != "~/.ssh/identity" { var err error idf, err = identityFileAbsPath(idf) if err != nil { return hostConfig, err } hostConfig.IdentityFile = idf } else { hostConfig.IdentityFile = "" } hostConfig.Host = hostname hostConfig.Port = port hostConfig.User = username logrus.Debugf("constructed SSH config %s for %s", hostConfig, hostname) return hostConfig, nil } func identityFileAbsPath(relPath string) (string, error) { var err error var absPath string if strings.HasPrefix(relPath, "~/") { systemUser, err := user.Current() if err != nil { return absPath, err } absPath = filepath.Join(systemUser.HomeDir, relPath[2:]) } else { absPath, err = filepath.Abs(relPath) if err != nil { return absPath, err } } logrus.Debugf("resolved %s to %s to read the ssh identity file", relPath, absPath) return absPath, nil }