549 lines
13 KiB
Go
549 lines
13 KiB
Go
package ssh
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"coopcloud.tech/abra/pkg/context"
|
|
"github.com/AlecAivazis/survey/v2"
|
|
dockerSSHPkg "github.com/docker/cli/cli/connhelper/ssh"
|
|
sshPkg "github.com/gliderlabs/ssh"
|
|
"github.com/kevinburke/ssh_config"
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
|
)
|
|
|
|
// KnownHostsPath is the location of the known_hosts file that host keys are
// read from and appended to. It is built once, when the package is
// initialised, from the HOME environment variable.
var KnownHostsPath = filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
|
|
|
|
type Client struct {
|
|
SSHClient *ssh.Client
|
|
}
|
|
|
|
// HostConfig holds the connection settings resolved for a single SSH host.
type HostConfig struct {
	// Host is the host name or address to connect to.
	Host string
	// IdentityFile is the absolute path of the configured identity file, or
	// the empty string when none is configured.
	IdentityFile string
	// Port is the SSH port, kept as a string.
	Port string
	// User is the login name used on the remote host.
	User string
}
|
|
|
|
// Exec cmd on the remote host and return stderr and stdout
|
|
func (c *Client) Exec(cmd string) ([]byte, error) {
|
|
session, err := c.SSHClient.NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
|
|
return session.CombinedOutput(cmd)
|
|
}
|
|
|
|
// Close the underlying SSH connection
|
|
func (c *Client) Close() error {
|
|
return c.SSHClient.Close()
|
|
}
|
|
|
|
// New creates a new SSH client connection.
|
|
func New(domainName, sshAuth, username, port string) (*Client, error) {
|
|
var client *Client
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(domainName)
|
|
if err != nil {
|
|
return client, nil
|
|
}
|
|
|
|
if sshAuth == "identity-file" {
|
|
var err error
|
|
client, err = connectWithAgentTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
5*time.Second,
|
|
)
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
} else {
|
|
password := ""
|
|
prompt := &survey.Password{
|
|
Message: "SSH password?",
|
|
}
|
|
if err := survey.AskOne(prompt, &password); err != nil {
|
|
return client, err
|
|
}
|
|
|
|
var err error
|
|
client, err = connectWithPasswordTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
password,
|
|
5*time.Second,
|
|
)
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// sudoWriter collects the output of a remote sudo command and answers sudo's
// password prompt. It is meant to be installed as both Stdout and Stderr of a
// session, so Write may be called from more than one goroutine.
type sudoWriter struct {
	b     bytes.Buffer // output collected so far, prompts excluded
	pw    string       // password to send; cleared after the first prompt
	stdin io.Writer    // stdin of the remote command
	m     sync.Mutex   // guards b, pw and writes to stdin
}

// Write satisfies io.Writer for sudoWriter.
//
// A chunk containing the marker "sudo_password" is treated as sudo's password
// prompt: the stored password followed by a newline is written to stdin, the
// password is cleared, and the chunk itself is not buffered. Any later prompt
// is therefore answered with an empty line. Every other chunk is appended to
// the output buffer.
//
// It returns an error when the password cannot be written to stdin.
func (w *sudoWriter) Write(p []byte) (int, error) {
	// Take the lock before touching pw or stdin: stdout and stderr are copied
	// concurrently and may both carry a prompt.
	w.m.Lock()
	defer w.m.Unlock()

	if bytes.Contains(p, []byte("sudo_password")) {
		_, err := io.WriteString(w.stdin, w.pw+"\n")
		w.pw = ""
		if err != nil {
			return 0, fmt.Errorf("writing sudo password to stdin: %w", err)
		}
		return len(p), nil
	}

	return w.b.Write(p)
}
|
|
|
|
// RunSudoCmd runs SSH commands and streams output
|
|
func RunSudoCmd(cmd, passwd string, cl *Client) error {
|
|
session, err := cl.SSHClient.NewSession()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
|
|
sudoCmd := fmt.Sprintf("SSH_ASKPASS=/usr/bin/ssh-askpass; sudo -p sudo_password -S %s", cmd)
|
|
|
|
w := &sudoWriter{pw: passwd}
|
|
w.stdin, err = session.StdinPipe()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
session.Stdout = w
|
|
session.Stderr = w
|
|
|
|
modes := ssh.TerminalModes{
|
|
ssh.ECHO: 0,
|
|
ssh.TTY_OP_ISPEED: 14400,
|
|
ssh.TTY_OP_OSPEED: 14400,
|
|
}
|
|
|
|
err = session.RequestPty("xterm", 80, 40, modes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := session.Run(sudoCmd); err != nil {
|
|
return fmt.Errorf("%s", string(w.b.Bytes()))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EnsureKnowHostsFiles ensures that ~/.ssh/known_hosts is created
|
|
func EnsureKnowHostsFiles() error {
|
|
if _, err := os.Stat(KnownHostsPath); os.IsNotExist(err) {
|
|
logrus.Debugf("missing %s, creating now", KnownHostsPath)
|
|
file, err := os.OpenFile(KnownHostsPath, os.O_CREATE, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
file.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetHostKey checks if a host key is registered in the ~/.ssh/known_hosts file
|
|
func GetHostKey(hostname string) (bool, sshPkg.PublicKey, error) {
|
|
var hostKey sshPkg.PublicKey
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(hostname)
|
|
if err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
|
|
if err := EnsureKnowHostsFiles(); err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
|
|
file, err := os.Open(KnownHostsPath)
|
|
if err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
defer file.Close()
|
|
|
|
scanner := bufio.NewScanner(file)
|
|
for scanner.Scan() {
|
|
fields := strings.Split(scanner.Text(), " ")
|
|
if len(fields) != 3 {
|
|
continue
|
|
}
|
|
|
|
hostnameAndPort := fmt.Sprintf("%s:%s", ctxConnDetails.Host, ctxConnDetails.Port)
|
|
hashed := knownhosts.Normalize(hostnameAndPort)
|
|
|
|
if strings.Contains(fields[0], hashed) {
|
|
var err error
|
|
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
|
|
if err != nil {
|
|
return false, hostKey, fmt.Errorf("error parsing server SSH host key %q: %v", fields[2], err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
if hostKey != nil {
|
|
logrus.Debugf("server SSH host key present in ~/.ssh/known_hosts for %s", hostname)
|
|
return true, hostKey, nil
|
|
}
|
|
|
|
return false, hostKey, nil
|
|
}
|
|
|
|
// InsertHostKey adds a new host key to the ~/.ssh/known_hosts file
|
|
func InsertHostKey(hostname string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
file, err := os.OpenFile(KnownHostsPath, os.O_APPEND|os.O_WRONLY, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
hashedHostname := knownhosts.Normalize(hostname)
|
|
lineHostname := knownhosts.Line([]string{hashedHostname}, pubKey)
|
|
_, err = file.WriteString(fmt.Sprintf("%s\n", lineHostname))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hashedRemote := knownhosts.Normalize(remote.String())
|
|
lineRemote := knownhosts.Line([]string{hashedRemote}, pubKey)
|
|
_, err = file.WriteString(fmt.Sprintf("%s\n", lineRemote))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logrus.Debugf("SSH host key generated: %s", lineHostname)
|
|
logrus.Debugf("SSH host key generated: %s", lineRemote)
|
|
|
|
return nil
|
|
}
|
|
|
|
// HostKeyAddCallback ensures server ssh host keys are handled
|
|
func HostKeyAddCallback(hostnameAndPort string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
exists, _, err := GetHostKey(hostnameAndPort)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if exists {
|
|
hostname := strings.Split(hostnameAndPort, ":")[0]
|
|
logrus.Debugf("server SSH host key found for %s", hostname)
|
|
return nil
|
|
}
|
|
|
|
if !exists {
|
|
hostname := strings.Split(hostnameAndPort, ":")[0]
|
|
parsedPubKey := FingerprintSHA256(pubKey)
|
|
|
|
fmt.Printf(fmt.Sprintf(`
|
|
You are attempting to make an SSH connection to a server but there is no entry
|
|
in your ~/.ssh/known_hosts file which confirms that you have already validated
|
|
that this is indeed the server you want to connect to. Please take a moment to
|
|
validate the following SSH host key, it is important.
|
|
|
|
Host: %s
|
|
Fingerprint: %s
|
|
|
|
If this is confusing to you, you can read the article below and learn how to
|
|
validate this fingerprint safely. Thanks to the comrades at cyberia.club for
|
|
writing this extensive guide <3
|
|
|
|
https://sequentialread.com/understanding-the-secure-shell-protocol-ssh/
|
|
|
|
`, hostname, parsedPubKey))
|
|
|
|
response := false
|
|
prompt := &survey.Confirm{
|
|
Message: "are you sure you trust this host key?",
|
|
}
|
|
|
|
if err := survey.AskOne(prompt, &response); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !response {
|
|
logrus.Fatal("exiting as requested")
|
|
}
|
|
|
|
logrus.Debugf("attempting to insert server SSH host key for %s, %s", hostnameAndPort, remote)
|
|
|
|
if err := InsertHostKey(hostnameAndPort, remote, pubKey); err != nil {
|
|
return err
|
|
}
|
|
|
|
logrus.Infof("successfully added server SSH host key for %s", hostname)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// connect makes the SSH connection
|
|
func connect(username, host, port string, authMethod ssh.AuthMethod, timeout time.Duration) (*Client, error) {
|
|
config := &ssh.ClientConfig{
|
|
User: username,
|
|
Auth: []ssh.AuthMethod{authMethod},
|
|
HostKeyCallback: HostKeyAddCallback, // the main reason why we fork
|
|
}
|
|
|
|
hostnameAndPort := fmt.Sprintf("%s:%s", host, port)
|
|
|
|
logrus.Debugf("tcp dialing %s", hostnameAndPort)
|
|
|
|
var conn net.Conn
|
|
var err error
|
|
conn, err = net.DialTimeout("tcp", hostnameAndPort, timeout)
|
|
if err != nil {
|
|
logrus.Debugf("tcp dialing %s failed, trying via ~/.ssh/config", hostnameAndPort)
|
|
hostConfig, err := GetHostConfig(host, username, port, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn, err = net.DialTimeout("tcp", fmt.Sprintf("%s:%s", hostConfig.Host, hostConfig.Port), timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, hostnameAndPort, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := ssh.NewClient(sshConn, chans, reqs)
|
|
c := &Client{SSHClient: client}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func connectWithAgentTimeout(host, username, port string, timeout time.Duration) (*Client, error) {
|
|
logrus.Debugf("using ssh-agent to make an SSH connection for %s", host)
|
|
|
|
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
agentCl := agent.NewClient(sshAgent)
|
|
authMethod := ssh.PublicKeysCallback(agentCl.Signers)
|
|
|
|
loadedKeys, err := agentCl.List()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var convertedKeys []string
|
|
for _, key := range loadedKeys {
|
|
convertedKeys = append(convertedKeys, key.String())
|
|
}
|
|
|
|
if len(convertedKeys) > 0 {
|
|
logrus.Debugf("ssh-agent has these keys loaded: %s", strings.Join(convertedKeys, ","))
|
|
} else {
|
|
logrus.Debug("ssh-agent has no keys loaded")
|
|
}
|
|
|
|
return connect(username, host, port, authMethod, timeout)
|
|
}
|
|
|
|
func connectWithPasswordTimeout(host, username, port, pass string, timeout time.Duration) (*Client, error) {
|
|
authMethod := ssh.Password(pass)
|
|
|
|
return connect(username, host, port, authMethod, timeout)
|
|
}
|
|
|
|
// EnsureHostKey ensures that a host key trusted and added to the ~/.ssh/known_hosts file
|
|
func EnsureHostKey(hostname string) error {
|
|
if hostname == "default" || hostname == "local" {
|
|
logrus.Debugf("not checking server SSH host key against local/default target")
|
|
return nil
|
|
}
|
|
|
|
exists, _, err := GetHostKey(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = connectWithAgentTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
5*time.Second,
|
|
)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// FingerprintSHA256 generates the SHA256 fingerprint for a server SSH host key
|
|
func FingerprintSHA256(key ssh.PublicKey) string {
|
|
hash := sha256.Sum256(key.Marshal())
|
|
b64hash := base64.StdEncoding.EncodeToString(hash[:])
|
|
trimmed := strings.TrimRight(b64hash, "=")
|
|
return fmt.Sprintf("SHA256:%s", trimmed)
|
|
}
|
|
|
|
// GetContextConnDetails retrieves SSH connection details from a docker context endpoint
|
|
func GetContextConnDetails(serverName string) (*dockerSSHPkg.Spec, error) {
|
|
dockerContextStore := context.NewDefaultDockerContextStore()
|
|
contexts, err := dockerContextStore.Store.List()
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
|
|
if strings.Contains(serverName, ":") {
|
|
serverName = strings.Split(serverName, ":")[0]
|
|
}
|
|
|
|
for _, ctx := range contexts {
|
|
endpoint, err := context.GetContextEndpoint(ctx)
|
|
if err != nil && strings.Contains(err.Error(), "does not exist") {
|
|
// No local context found, we can continue safely
|
|
continue
|
|
}
|
|
if ctx.Name == serverName {
|
|
ctxConnDetails, err := dockerSSHPkg.ParseURL(endpoint)
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
logrus.Debugf("found context connection details %v for %s", ctxConnDetails, serverName)
|
|
return ctxConnDetails, nil
|
|
}
|
|
}
|
|
|
|
hostConfig, err := GetHostConfig(serverName, "", "", false)
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
|
|
logrus.Debugf("couldn't find a docker context matching %s", serverName)
|
|
logrus.Debugf("searching ~/.ssh/config for a Host entry for %s", serverName)
|
|
|
|
connDetails := &dockerSSHPkg.Spec{
|
|
Host: hostConfig.Host,
|
|
User: hostConfig.User,
|
|
Port: hostConfig.Port,
|
|
}
|
|
|
|
logrus.Debugf("using %v from ~/.ssh/config for connection details", connDetails)
|
|
|
|
return connDetails, nil
|
|
}
|
|
|
|
// GetHostConfig retrieves a ~/.ssh/config config for a host.
|
|
func GetHostConfig(hostname, username, port string, override bool) (HostConfig, error) {
|
|
var hostConfig HostConfig
|
|
|
|
if hostname == "" || override {
|
|
if sshHost := ssh_config.Get(hostname, "Hostname"); sshHost != "" {
|
|
hostname = sshHost
|
|
}
|
|
}
|
|
|
|
if username == "" || override {
|
|
if sshUser := ssh_config.Get(hostname, "User"); sshUser != "" {
|
|
username = sshUser
|
|
} else {
|
|
systemUser, err := user.Current()
|
|
if err != nil {
|
|
return hostConfig, err
|
|
}
|
|
username = systemUser.Username
|
|
}
|
|
}
|
|
|
|
if port == "" || override {
|
|
if sshPort := ssh_config.Get(hostname, "Port"); sshPort != "" {
|
|
// skip override probably correct port with dummy default value from
|
|
// ssh_config which is 22. only when the original port number is empty
|
|
// should we try this default. this might not cover all cases
|
|
// unfortunately.
|
|
if port != "" && sshPort != "22" {
|
|
port = sshPort
|
|
}
|
|
}
|
|
}
|
|
|
|
if idf := ssh_config.Get(hostname, "IdentityFile"); idf != "" && idf != "~/.ssh/identity" {
|
|
var err error
|
|
idf, err = identityFileAbsPath(idf)
|
|
if err != nil {
|
|
return hostConfig, err
|
|
}
|
|
hostConfig.IdentityFile = idf
|
|
} else {
|
|
hostConfig.IdentityFile = ""
|
|
}
|
|
|
|
hostConfig.Host = hostname
|
|
hostConfig.Port = port
|
|
hostConfig.User = username
|
|
|
|
logrus.Debugf("constructed SSH config %s for %s", hostConfig, hostname)
|
|
|
|
return hostConfig, nil
|
|
}
|
|
|
|
func identityFileAbsPath(relPath string) (string, error) {
|
|
var err error
|
|
var absPath string
|
|
|
|
if strings.HasPrefix(relPath, "~/") {
|
|
systemUser, err := user.Current()
|
|
if err != nil {
|
|
return absPath, err
|
|
}
|
|
absPath = filepath.Join(systemUser.HomeDir, relPath[2:])
|
|
} else {
|
|
absPath, err = filepath.Abs(relPath)
|
|
if err != nil {
|
|
return absPath, err
|
|
}
|
|
}
|
|
|
|
logrus.Debugf("resolved %s to %s to read the ssh identity file", relPath, absPath)
|
|
|
|
return absPath, nil
|
|
}
|