package ssh import ( "bufio" "bytes" "fmt" "io" "os/user" "sync" "time" "github.com/AlecAivazis/survey/v2" "github.com/kevinburke/ssh_config" "github.com/sfreiberg/simplessh" "github.com/sirupsen/logrus" ) // HostConfig is a SSH host config. type HostConfig struct { Host string IdentityFile string Port string User string } // GetHostConfig retrieves a ~/.ssh/config config for a host. func GetHostConfig(hostname, username, port string) (HostConfig, error) { var hostConfig HostConfig var host, idf string if host = ssh_config.Get(hostname, "Hostname"); host == "" { logrus.Debugf("no hostname found in SSH config, assuming %s", hostname) host = hostname } if username == "" { if username = ssh_config.Get(hostname, "User"); username == "" { systemUser, err := user.Current() if err != nil { return hostConfig, err } logrus.Debugf("no username found in SSH config or passed on command-line, assuming %s", username) username = systemUser.Username } } if port == "" { if port = ssh_config.Get(hostname, "Port"); port == "" { logrus.Debugf("no port found in SSH config or passed on command-line, assuming 22") port = "22" } } idf = ssh_config.Get(hostname, "IdentityFile") hostConfig.Host = host if idf != "" { hostConfig.IdentityFile = idf } hostConfig.Port = port hostConfig.User = username logrus.Debugf("constructed SSH config %s for %s", hostConfig, hostname) return hostConfig, nil } // New creates a new SSH client connection. func New(domainName, sshAuth, username, port string) (*simplessh.Client, error) { var client *simplessh.Client hostConfig, err := GetHostConfig(domainName, username, port) if err != nil { return client, err } if sshAuth == "identity-file" { var err error client, err = simplessh.ConnectWithAgentTimeout(hostConfig.Host, hostConfig.User, 5*time.Second) if err != nil { return client, err } } else { password := "" prompt := &survey.Password{ Message: "SSH password?", } if err := survey.AskOne(prompt, &password); err != nil { return client, err } var err error client, err = simplessh.ConnectWithPasswordTimeout(hostConfig.Host, hostConfig.User, password, 5*time.Second) if err != nil { return client, err } } return client, nil } // sudoWriter supports sudo command handling. // https://github.com/sfreiberg/simplessh/blob/master/simplessh.go type sudoWriter struct { b bytes.Buffer pw string stdin io.Writer m sync.Mutex } // Write satisfies the write interface for sudoWriter. // https://github.com/sfreiberg/simplessh/blob/master/simplessh.go func (w *sudoWriter) Write(p []byte) (int, error) { if string(p) == "sudo_password" { w.stdin.Write([]byte(w.pw + "\n")) w.pw = "" return len(p), nil } w.m.Lock() defer w.m.Unlock() return w.b.Write(p) } // RunSudoCmd runs SSH commands and streams output. // https://github.com/sfreiberg/simplessh/blob/master/simplessh.go func RunSudoCmd(cmd, passwd string, cl *simplessh.Client) error { session, err := cl.SSHClient.NewSession() if err != nil { return err } defer session.Close() cmd = "sudo -p " + "sudo_password" + " -S " + cmd w := &sudoWriter{ pw: passwd, } w.stdin, err = session.StdinPipe() if err != nil { return err } session.Stdout = w session.Stderr = w done := make(chan struct{}) scanner := bufio.NewScanner(session.Stdin) go func() { for scanner.Scan() { line := scanner.Text() fmt.Println(line) } done <- struct{}{} }() if err := session.Start(cmd); err != nil { return err } <-done if err := session.Wait(); err != nil { return err } return err } // Exec runs a command on a remote and streams output. // https://github.com/sfreiberg/simplessh/blob/master/simplessh.go func Exec(cmd string, cl *simplessh.Client) error { session, err := cl.SSHClient.NewSession() if err != nil { return err } defer session.Close() stdout, err := session.StdoutPipe() if err != nil { return err } stderr, err := session.StdoutPipe() if err != nil { return err } stdoutDone := make(chan struct{}) stdoutScanner := bufio.NewScanner(stdout) go func() { for stdoutScanner.Scan() { line := stdoutScanner.Text() fmt.Println(line) } stdoutDone <- struct{}{} }() stderrDone := make(chan struct{}) stderrScanner := bufio.NewScanner(stderr) go func() { for stderrScanner.Scan() { line := stderrScanner.Text() fmt.Println(line) } stderrDone <- struct{}{} }() if err := session.Start(cmd); err != nil { return err } <-stdoutDone <-stderrDone if err := session.Wait(); err != nil { return err } return nil }