forked from toolshed/abra
This is part of trying to debug coop-cloud/organising#250, and also part of coop-cloud/docs.coopcloud.tech#27, where I now try to specify the same logic as `ssh -i <my-key-path>` in the underlying connection logic. This should help with being more explicit about which key is being used via the SSH config file.
605 lines
14 KiB
Go
605 lines
14 KiB
Go
package ssh
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"coopcloud.tech/abra/pkg/context"
|
|
"github.com/AlecAivazis/survey/v2"
|
|
dockerSSHPkg "github.com/docker/cli/cli/connhelper/ssh"
|
|
sshPkg "github.com/gliderlabs/ssh"
|
|
"github.com/kevinburke/ssh_config"
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
|
)
|
|
|
|
// KnownHostsPath is the location of the user's known_hosts file, built from
// the HOME environment variable when the package is initialised.
var KnownHostsPath = filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
|
|
|
|
type Client struct {
|
|
SSHClient *ssh.Client
|
|
}
|
|
|
|
// HostConfig holds the connection settings resolved for a single SSH host.
type HostConfig struct {
	Host         string // host name or address to dial
	IdentityFile string // path to the identity (private key) file, empty when none is set
	Port         string // TCP port, kept as a string
	User         string // login user name
}
|
|
|
|
// Exec cmd on the remote host and return stderr and stdout
|
|
func (c *Client) Exec(cmd string) ([]byte, error) {
|
|
session, err := c.SSHClient.NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
|
|
return session.CombinedOutput(cmd)
|
|
}
|
|
|
|
// Close the underlying SSH connection
|
|
func (c *Client) Close() error {
|
|
return c.SSHClient.Close()
|
|
}
|
|
|
|
// New creates a new SSH client connection.
|
|
func New(domainName, sshAuth, username, port string) (*Client, error) {
|
|
var client *Client
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(domainName)
|
|
if err != nil {
|
|
return client, nil
|
|
}
|
|
|
|
if sshAuth == "identity-file" {
|
|
var err error
|
|
client, err = connectWithAgentTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
5*time.Second,
|
|
)
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
} else {
|
|
password := ""
|
|
prompt := &survey.Password{
|
|
Message: "SSH password?",
|
|
}
|
|
if err := survey.AskOne(prompt, &password); err != nil {
|
|
return client, err
|
|
}
|
|
|
|
var err error
|
|
client, err = connectWithPasswordTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
password,
|
|
5*time.Second,
|
|
)
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// sudoWriter collects the output of a remote sudo command and answers sudo's
// password prompt. It is used as the stdout/stderr writer of an SSH session
// whose command was started with the prompt string "sudo_password".
type sudoWriter struct {
	b     bytes.Buffer // captured command output
	pw    string       // password sent on the prompt, cleared once used
	stdin io.Writer    // remote stdin, receives the password
	m     sync.Mutex   // guards b and pw
}

// Write implements io.Writer.
//
// A chunk that is exactly the prompt marker "sudo_password" is not recorded;
// instead the stored password followed by a newline is written to stdin and
// the password is cleared. Any other chunk is appended to the output buffer.
//
// The mutex is held for the whole call because the same writer serves both
// stdout and stderr, which may be written from different goroutines. A failure
// to deliver the password is returned to the caller rather than dropped.
func (w *sudoWriter) Write(p []byte) (int, error) {
	w.m.Lock()
	defer w.m.Unlock()

	if string(p) == "sudo_password" {
		_, err := w.stdin.Write([]byte(w.pw + "\n"))
		// Clear the password whether or not the write succeeded so it is not
		// kept in memory longer than needed.
		w.pw = ""
		if err != nil {
			return 0, err
		}
		return len(p), nil
	}

	return w.b.Write(p)
}
|
|
|
|
// RunSudoCmd runs SSH commands and streams output
|
|
func RunSudoCmd(cmd, passwd string, cl *Client) error {
|
|
session, err := cl.SSHClient.NewSession()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
|
|
cmd = "sudo -p " + "sudo_password" + " -S " + cmd
|
|
|
|
w := &sudoWriter{
|
|
pw: passwd,
|
|
}
|
|
w.stdin, err = session.StdinPipe()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
session.Stdout = w
|
|
session.Stderr = w
|
|
|
|
done := make(chan struct{})
|
|
scanner := bufio.NewScanner(session.Stdin)
|
|
|
|
go func() {
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
fmt.Println(line)
|
|
}
|
|
done <- struct{}{}
|
|
}()
|
|
|
|
if err := session.Start(cmd); err != nil {
|
|
return err
|
|
}
|
|
|
|
<-done
|
|
|
|
if err := session.Wait(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Exec runs a command on a remote and streams output
|
|
func Exec(cmd string, cl *Client) error {
|
|
session, err := cl.SSHClient.NewSession()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
|
|
stdout, err := session.StdoutPipe()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
stderr, err := session.StdoutPipe()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
stdoutDone := make(chan struct{})
|
|
stdoutScanner := bufio.NewScanner(stdout)
|
|
|
|
go func() {
|
|
for stdoutScanner.Scan() {
|
|
line := stdoutScanner.Text()
|
|
fmt.Println(line)
|
|
}
|
|
stdoutDone <- struct{}{}
|
|
}()
|
|
|
|
stderrDone := make(chan struct{})
|
|
stderrScanner := bufio.NewScanner(stderr)
|
|
|
|
go func() {
|
|
for stderrScanner.Scan() {
|
|
line := stderrScanner.Text()
|
|
fmt.Println(line)
|
|
}
|
|
stderrDone <- struct{}{}
|
|
}()
|
|
|
|
if err := session.Start(cmd); err != nil {
|
|
return err
|
|
}
|
|
|
|
<-stdoutDone
|
|
<-stderrDone
|
|
|
|
if err := session.Wait(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EnsureKnowHostsFiles ensures that ~/.ssh/known_hosts is created
|
|
func EnsureKnowHostsFiles() error {
|
|
if _, err := os.Stat(KnownHostsPath); os.IsNotExist(err) {
|
|
logrus.Debugf("missing %s, creating now", KnownHostsPath)
|
|
file, err := os.OpenFile(KnownHostsPath, os.O_CREATE, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
file.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetHostKey checks if a host key is registered in the ~/.ssh/known_hosts file
|
|
func GetHostKey(hostname string) (bool, sshPkg.PublicKey, error) {
|
|
var hostKey sshPkg.PublicKey
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(hostname)
|
|
if err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
|
|
if err := EnsureKnowHostsFiles(); err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
|
|
file, err := os.Open(KnownHostsPath)
|
|
if err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
defer file.Close()
|
|
|
|
scanner := bufio.NewScanner(file)
|
|
for scanner.Scan() {
|
|
fields := strings.Split(scanner.Text(), " ")
|
|
if len(fields) != 3 {
|
|
continue
|
|
}
|
|
|
|
hostnameAndPort := fmt.Sprintf("%s:%s", ctxConnDetails.Host, ctxConnDetails.Port)
|
|
hashed := knownhosts.Normalize(hostnameAndPort)
|
|
|
|
if strings.Contains(fields[0], hashed) {
|
|
var err error
|
|
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
|
|
if err != nil {
|
|
return false, hostKey, fmt.Errorf("error parsing server SSH host key %q: %v", fields[2], err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
if hostKey != nil {
|
|
logrus.Debugf("server SSH host key present in ~/.ssh/known_hosts for %s", hostname)
|
|
return true, hostKey, nil
|
|
}
|
|
|
|
return false, hostKey, nil
|
|
}
|
|
|
|
// InsertHostKey adds a new host key to the ~/.ssh/known_hosts file
|
|
func InsertHostKey(hostname string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
file, err := os.OpenFile(KnownHostsPath, os.O_APPEND|os.O_WRONLY, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
hashedHostname := knownhosts.Normalize(hostname)
|
|
lineHostname := knownhosts.Line([]string{hashedHostname}, pubKey)
|
|
_, err = file.WriteString(fmt.Sprintf("%s\n", lineHostname))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hashedRemote := knownhosts.Normalize(remote.String())
|
|
lineRemote := knownhosts.Line([]string{hashedRemote}, pubKey)
|
|
_, err = file.WriteString(fmt.Sprintf("%s\n", lineRemote))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logrus.Debugf("SSH host key generated: %s", lineHostname)
|
|
logrus.Debugf("SSH host key generated: %s", lineRemote)
|
|
|
|
return nil
|
|
}
|
|
|
|
// HostKeyAddCallback ensures server ssh host keys are handled
|
|
func HostKeyAddCallback(hostnameAndPort string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
exists, _, err := GetHostKey(hostnameAndPort)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if exists {
|
|
hostname := strings.Split(hostnameAndPort, ":")[0]
|
|
logrus.Debugf("server SSH host key found for %s, moving on", hostname)
|
|
return nil
|
|
}
|
|
|
|
if !exists {
|
|
hostname := strings.Split(hostnameAndPort, ":")[0]
|
|
parsedPubKey := FingerprintSHA256(pubKey)
|
|
|
|
fmt.Printf(fmt.Sprintf(`
|
|
You are attempting to make an SSH connection to a server but there is no entry
|
|
in your ~/.ssh/known_hosts file which confirms that this is indeed the server
|
|
you want to connect to. Please take a moment to validate the following SSH host
|
|
key, it is important.
|
|
|
|
Host: %s
|
|
Fingerprint: %s
|
|
|
|
If this is confusing to you, you can read the article below and learn how to
|
|
validate this fingerprint safely. Thanks to the comrades at cyberia.club for
|
|
writing this extensive guide <3
|
|
|
|
https://sequentialread.com/understanding-the-secure-shell-protocol-ssh/
|
|
|
|
`, hostname, parsedPubKey))
|
|
|
|
response := false
|
|
prompt := &survey.Confirm{
|
|
Message: "are you sure you trust this host key?",
|
|
}
|
|
|
|
if err := survey.AskOne(prompt, &response); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !response {
|
|
logrus.Fatal("exiting as requested")
|
|
}
|
|
|
|
logrus.Debugf("attempting to insert server SSH host key for %s, %s", hostnameAndPort, remote)
|
|
|
|
if err := InsertHostKey(hostnameAndPort, remote, pubKey); err != nil {
|
|
return err
|
|
}
|
|
|
|
logrus.Infof("successfully added server SSH host key for %s", hostname)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// connect makes the SSH connection
|
|
func connect(username, host, port string, authMethod ssh.AuthMethod, timeout time.Duration) (*Client, error) {
|
|
config := &ssh.ClientConfig{
|
|
User: username,
|
|
Auth: []ssh.AuthMethod{authMethod},
|
|
HostKeyCallback: HostKeyAddCallback, // the main reason why we fork
|
|
}
|
|
|
|
hostnameAndPort := fmt.Sprintf("%s:%s", host, port)
|
|
|
|
logrus.Debugf("tcp dialing %s", hostnameAndPort)
|
|
|
|
var conn net.Conn
|
|
var err error
|
|
conn, err = net.DialTimeout("tcp", hostnameAndPort, timeout)
|
|
if err != nil {
|
|
logrus.Debugf("tcp dialing %s failed, trying via ~/.ssh/config", hostnameAndPort)
|
|
hostConfig, err := GetHostConfig(host, username, port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn, err = net.DialTimeout("tcp", fmt.Sprintf("%s:%s", hostConfig.Host, hostConfig.Port), timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, hostnameAndPort, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := ssh.NewClient(sshConn, chans, reqs)
|
|
c := &Client{SSHClient: client}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func connectWithAgentTimeout(host, username, port string, timeout time.Duration) (*Client, error) {
|
|
logrus.Debugf("using ssh-agent to make an SSH connection for %s", host)
|
|
|
|
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
agentCl := agent.NewClient(sshAgent)
|
|
authMethod := ssh.PublicKeysCallback(agentCl.Signers)
|
|
|
|
loadedKeys, err := agentCl.List()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var convertedKeys []string
|
|
for _, key := range loadedKeys {
|
|
convertedKeys = append(convertedKeys, key.String())
|
|
}
|
|
|
|
if len(convertedKeys) > 0 {
|
|
logrus.Debugf("ssh-agent has these keys loaded: %s", strings.Join(convertedKeys, ","))
|
|
} else {
|
|
logrus.Debug("ssh-agent has no keys loaded")
|
|
}
|
|
|
|
return connect(username, host, port, authMethod, timeout)
|
|
}
|
|
|
|
func connectWithPasswordTimeout(host, username, port, pass string, timeout time.Duration) (*Client, error) {
|
|
authMethod := ssh.Password(pass)
|
|
|
|
return connect(username, host, port, authMethod, timeout)
|
|
}
|
|
|
|
// EnsureHostKey ensures that a host key trusted and added to the ~/.ssh/known_hosts file
|
|
func EnsureHostKey(hostname string) error {
|
|
if hostname == "default" || hostname == "local" {
|
|
logrus.Debugf("not checking server SSH host key against local/default target")
|
|
return nil
|
|
}
|
|
|
|
exists, _, err := GetHostKey(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = connectWithAgentTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
5*time.Second,
|
|
)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// FingerprintSHA256 generates the SHA256 fingerprint for a server SSH host key
|
|
func FingerprintSHA256(key ssh.PublicKey) string {
|
|
hash := sha256.Sum256(key.Marshal())
|
|
b64hash := base64.StdEncoding.EncodeToString(hash[:])
|
|
trimmed := strings.TrimRight(b64hash, "=")
|
|
return fmt.Sprintf("SHA256:%s", trimmed)
|
|
}
|
|
|
|
// GetContextConnDetails retrieves SSH connection details from a docker context endpoint
|
|
func GetContextConnDetails(serverName string) (*dockerSSHPkg.Spec, error) {
|
|
dockerContextStore := context.NewDefaultDockerContextStore()
|
|
contexts, err := dockerContextStore.Store.List()
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
|
|
if strings.Contains(serverName, ":") {
|
|
serverName = strings.Split(serverName, ":")[0]
|
|
}
|
|
|
|
for _, ctx := range contexts {
|
|
endpoint, err := context.GetContextEndpoint(ctx)
|
|
if err != nil && strings.Contains(err.Error(), "does not exist") {
|
|
// No local context found, we can continue safely
|
|
continue
|
|
}
|
|
if ctx.Name == serverName {
|
|
ctxConnDetails, err := dockerSSHPkg.ParseURL(endpoint)
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
logrus.Debugf("found context connection details %v for %s", ctxConnDetails, serverName)
|
|
return ctxConnDetails, nil
|
|
}
|
|
}
|
|
|
|
hostConfig, err := GetHostConfig(serverName, "", "")
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
|
|
logrus.Debugf("couldn't find a docker context matching %s", serverName)
|
|
logrus.Debugf("searching ~/.ssh/config for a Host entry for %s", serverName)
|
|
|
|
connDetails := &dockerSSHPkg.Spec{
|
|
Host: hostConfig.Host,
|
|
User: hostConfig.User,
|
|
Port: hostConfig.Port,
|
|
}
|
|
|
|
logrus.Debugf("using %v from ~/.ssh/config for connection details", connDetails)
|
|
|
|
return connDetails, nil
|
|
}
|
|
|
|
// GetHostConfig retrieves a ~/.ssh/config config for a host.
|
|
func GetHostConfig(hostname, username, port string) (HostConfig, error) {
|
|
var hostConfig HostConfig
|
|
|
|
var host, idf string
|
|
|
|
if host = ssh_config.Get(hostname, "Hostname"); host == "" {
|
|
logrus.Debugf("no hostname found in SSH config, assuming %s", hostname)
|
|
host = hostname
|
|
}
|
|
|
|
if username == "" {
|
|
if username = ssh_config.Get(hostname, "User"); username == "" {
|
|
systemUser, err := user.Current()
|
|
if err != nil {
|
|
return hostConfig, err
|
|
}
|
|
logrus.Debugf("no username found in SSH config or passed on command-line, assuming %s", username)
|
|
username = systemUser.Username
|
|
}
|
|
}
|
|
|
|
if port == "" {
|
|
if port = ssh_config.Get(hostname, "Port"); port == "" {
|
|
logrus.Debugf("no port found in SSH config or passed on command-line, assuming 22")
|
|
port = "22"
|
|
}
|
|
}
|
|
|
|
idf = ssh_config.Get(hostname, "IdentityFile")
|
|
if idf != "" {
|
|
var err error
|
|
idf, err = identityFileAbsPath(idf)
|
|
if err != nil {
|
|
return hostConfig, err
|
|
}
|
|
hostConfig.IdentityFile = idf
|
|
}
|
|
|
|
hostConfig.Host = host
|
|
hostConfig.Port = port
|
|
hostConfig.User = username
|
|
|
|
logrus.Debugf("constructed SSH config %s for %s", hostConfig, hostname)
|
|
|
|
return hostConfig, nil
|
|
}
|
|
|
|
func identityFileAbsPath(relPath string) (string, error) {
|
|
var err error
|
|
var absPath string
|
|
|
|
if strings.HasPrefix(relPath, "~/") {
|
|
systemUser, err := user.Current()
|
|
if err != nil {
|
|
return absPath, err
|
|
}
|
|
absPath = filepath.Join(systemUser.HomeDir, relPath[2:])
|
|
} else {
|
|
absPath, err = filepath.Abs(relPath)
|
|
if err != nil {
|
|
return absPath, err
|
|
}
|
|
}
|
|
|
|
logrus.Debugf("resolved %s to %s to read the ssh identity file", relPath, absPath)
|
|
|
|
return absPath, nil
|
|
}
|