0
0
forked from toolshed/abra
abra/pkg/ssh/ssh.go

227 lines
4.6 KiB
Go

package ssh
import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net"
	"os/user"
	"sync"
	"time"

	"github.com/AlecAivazis/survey/v2"
	"github.com/kevinburke/ssh_config"
	"github.com/sfreiberg/simplessh"
	"github.com/sirupsen/logrus"
)
// HostConfig holds the SSH connection settings resolved for a single host.
// Every value is kept as a plain string, exactly as read from the SSH config
// or supplied by the caller.
type HostConfig struct {
	Host         string // address to connect to (config Hostname, or the name given)
	IdentityFile string // IdentityFile value from the SSH config, if any
	Port         string // TCP port, e.g. "22"
	User         string // remote login name
}
// GetHostConfig resolves the SSH connection settings for hostname, consulting
// the user's ~/.ssh/config (via the ssh_config package) and falling back to
// defaults where nothing is configured.
//
// Resolution order:
//   - Host: the config's Hostname entry, otherwise hostname itself.
//   - User: username if non-empty, otherwise the config's User entry,
//     otherwise the name of the current local user.
//   - Port: port if non-empty, otherwise the config's Port entry,
//     otherwise "22".
//   - IdentityFile: the config's IdentityFile entry, if any.
//
// The only error returned comes from looking up the current local user; in
// that case the returned HostConfig is the zero value.
func GetHostConfig(hostname, username, port string) (HostConfig, error) {
	host := ssh_config.Get(hostname, "Hostname")
	if host == "" {
		logrus.Debugf("no hostname found in SSH config, assuming %s", hostname)
		host = hostname
	}

	if username == "" {
		username = ssh_config.Get(hostname, "User")
	}
	if username == "" {
		systemUser, err := user.Current()
		if err != nil {
			return HostConfig{}, fmt.Errorf("looking up current user: %w", err)
		}
		username = systemUser.Username
		// Log after assigning so the message reports the name actually used.
		logrus.Debugf("no username found in SSH config or passed on command-line, assuming %s", username)
	}

	if port == "" {
		port = ssh_config.Get(hostname, "Port")
	}
	if port == "" {
		logrus.Debugf("no port found in SSH config or passed on command-line, assuming 22")
		port = "22"
	}

	hostConfig := HostConfig{
		Host:         host,
		IdentityFile: ssh_config.Get(hostname, "IdentityFile"),
		Port:         port,
		User:         username,
	}
	logrus.Debugf("constructed SSH config %+v for %s", hostConfig, hostname)
	return hostConfig, nil
}
// New opens an SSH client connection to domainName.
//
// The real hostname, user and port are resolved through GetHostConfig, so
// ~/.ssh/config entries are honoured; non-empty username and port arguments
// take precedence over the config. When sshAuth is "identity-file" the
// connection authenticates through the SSH agent; any other value prompts
// interactively for a password. Each connection attempt uses a 5 second
// timeout.
func New(domainName, sshAuth, username, port string) (*simplessh.Client, error) {
	hostConfig, err := GetHostConfig(domainName, username, port)
	if err != nil {
		return nil, err
	}

	// Dial host:port explicitly. Previously only the host was passed, which
	// left simplessh to pick its default port and silently ignored the
	// resolved Port. An address that already carries a port (e.g.
	// "example.org:2222" passed as domainName) is used unchanged.
	addr := hostConfig.Host
	if _, _, splitErr := net.SplitHostPort(addr); splitErr != nil {
		addr = net.JoinHostPort(hostConfig.Host, hostConfig.Port)
	}

	const timeout = 5 * time.Second

	if sshAuth == "identity-file" {
		// NOTE(review): despite the mode name this authenticates via the SSH
		// agent; hostConfig.IdentityFile is not consulted here.
		return simplessh.ConnectWithAgentTimeout(addr, hostConfig.User, timeout)
	}

	password := ""
	prompt := &survey.Password{
		Message: "SSH password?",
	}
	if err := survey.AskOne(prompt, &password); err != nil {
		return nil, err
	}
	return simplessh.ConnectWithPasswordTimeout(addr, hostConfig.User, password, timeout)
}
// sudoWriter is used as the stdout/stderr of a remote command run under
// `sudo -p sudo_password -S`. It answers the sudo password prompt by writing
// pw to the session's stdin and collects every other chunk of output in b.
// Adapted from https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
type sudoWriter struct {
	b     bytes.Buffer // command output collected so far
	pw    string       // password sent on the first prompt, cleared afterwards
	stdin io.Writer    // the remote session's stdin
	m     sync.Mutex   // guards b and pw; the writer may serve stdout and stderr concurrently
}

// Write satisfies io.Writer for sudoWriter.
//
// A chunk that is exactly the prompt "sudo_password" is not recorded.
// Instead the stored password plus a newline is sent to stdin and the
// password is cleared, so it is only sent once (any later prompt receives an
// empty line). If sending fails, the error is returned and the password is
// kept. Any other chunk is appended to the output buffer.
func (w *sudoWriter) Write(p []byte) (int, error) {
	// Lock for the whole call: pw is read and cleared here, and the same
	// writer may be installed as both stdout and stderr.
	w.m.Lock()
	defer w.m.Unlock()

	if string(p) == "sudo_password" {
		if _, err := w.stdin.Write([]byte(w.pw + "\n")); err != nil {
			return 0, fmt.Errorf("sending sudo password: %w", err)
		}
		w.pw = ""
		return len(p), nil
	}
	return w.b.Write(p)
}
// RunSudoCmd runs cmd under sudo on the remote host behind cl, answering the
// sudo password prompt with passwd. The command's combined stdout and stderr
// are printed to the local stdout once the command has finished, including
// when it fails. The error from starting or waiting on the command is
// returned.
// Adapted from https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
func RunSudoCmd(cmd, passwd string, cl *simplessh.Client) error {
	session, err := cl.SSHClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	// -S makes sudo read the password from stdin; -p sets a fixed prompt
	// string that sudoWriter recognises and answers.
	cmd = "sudo -p " + "sudo_password" + " -S " + cmd

	w := &sudoWriter{
		pw: passwd,
	}
	w.stdin, err = session.StdinPipe()
	if err != nil {
		return err
	}
	session.Stdout = w
	session.Stderr = w

	if err := session.Start(cmd); err != nil {
		return err
	}

	// Output accumulates in w.b while the command runs. Read it only after
	// Wait, which also waits for stdout/stderr copying to finish, so the
	// buffer is complete. (Scanning session.Stdin here, as before, panicked:
	// StdinPipe leaves that field nil.)
	waitErr := session.Wait()

	w.m.Lock()
	output := w.b.String()
	w.m.Unlock()
	fmt.Print(output)

	return waitErr
}
// Exec runs cmd on the remote host behind cl. It prints every line the
// command writes to stdout or stderr to the local stdout as it arrives.
// It returns once the command has exited and all output has been printed.
// A non-nil error comes from session setup, from the command itself (see
// session.Wait), or from reading its output.
// Adapted from https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
func Exec(cmd string, cl *simplessh.Client) error {
	session, err := cl.SSHClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	stdout, err := session.StdoutPipe()
	if err != nil {
		return err
	}
	stderr, err := session.StderrPipe()
	if err != nil {
		return err
	}

	// Start before spawning readers so that a failed Start cannot leave
	// goroutines blocked forever.
	if err := session.Start(cmd); err != nil {
		return err
	}

	var wg sync.WaitGroup
	readErrs := make([]error, 2)
	for i, r := range []io.Reader{stdout, stderr} {
		wg.Add(1)
		go func(i int, r io.Reader) {
			defer wg.Done()
			readErrs[i] = printLines(r)
		}(i, r)
	}
	wg.Wait()

	if err := session.Wait(); err != nil {
		return err
	}
	for _, err := range readErrs {
		if err != nil {
			return err
		}
	}
	return nil
}

// printLines prints each line read from r to stdout. If scanning fails (e.g.
// a line exceeds the scanner's limit), the rest of r is drained so the remote
// command is not left blocked on unread output, and the scan error is
// returned.
func printLines(r io.Reader) error {
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		_, _ = io.Copy(io.Discard, r) // best effort: keep the stream flowing so Wait can return
		return err
	}
	return nil
}