forked from toolshed/abra
.gitea
cli
cmd
pkg
app
autocomplete
client
compose
config
container
context
dns
formatter
git
integration
limit
lint
recipe
secret
server
service
ssh
ssh.go
upstream
web
scripts
tests
.drone.yml
.e2e.env.sample
.envrc.sample
.gitignore
.goreleaser.yml
Makefile
README.md
go.mod
go.sum
renovate.json
We could default to ~/.ssh/id_rsa, but if that file doesn't exist we would just confuse people in the logs. It is best to rely on the ssh-agent, which overrides this anyway. We will document this. See coop-cloud/organising#277.
544 lines
13 KiB
Go
544 lines
13 KiB
Go
package ssh
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"coopcloud.tech/abra/pkg/context"
|
|
"github.com/AlecAivazis/survey/v2"
|
|
dockerSSHPkg "github.com/docker/cli/cli/connhelper/ssh"
|
|
sshPkg "github.com/gliderlabs/ssh"
|
|
"github.com/kevinburke/ssh_config"
|
|
"github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
|
)
|
|
|
|
// KnownHostsPath is the location of the user's SSH known_hosts file. It is
// computed once, at package initialisation, from the HOME environment
// variable.
var KnownHostsPath = filepath.Join(
	os.Getenv("HOME"),
	".ssh",
	"known_hosts",
)
|
|
|
|
type Client struct {
|
|
SSHClient *ssh.Client
|
|
}
|
|
|
|
// HostConfig is a SSH host config.
|
|
type HostConfig struct {
|
|
Host string
|
|
IdentityFile string
|
|
Port string
|
|
User string
|
|
}
|
|
|
|
// Exec cmd on the remote host and return stderr and stdout
|
|
func (c *Client) Exec(cmd string) ([]byte, error) {
|
|
session, err := c.SSHClient.NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
|
|
return session.CombinedOutput(cmd)
|
|
}
|
|
|
|
// Close the underlying SSH connection
|
|
func (c *Client) Close() error {
|
|
return c.SSHClient.Close()
|
|
}
|
|
|
|
// New creates a new SSH client connection.
|
|
func New(domainName, sshAuth, username, port string) (*Client, error) {
|
|
var client *Client
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(domainName)
|
|
if err != nil {
|
|
return client, nil
|
|
}
|
|
|
|
if sshAuth == "identity-file" {
|
|
var err error
|
|
client, err = connectWithAgentTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
5*time.Second,
|
|
)
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
} else {
|
|
password := ""
|
|
prompt := &survey.Password{
|
|
Message: "SSH password?",
|
|
}
|
|
if err := survey.AskOne(prompt, &password); err != nil {
|
|
return client, err
|
|
}
|
|
|
|
var err error
|
|
client, err = connectWithPasswordTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
password,
|
|
5*time.Second,
|
|
)
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// sudoWriter supports sudo command handling.
//
// RunSudoCmd installs it as both Stdout and Stderr of an SSH session. When a
// chunk of output contains the "sudo_password" prompt (the prompt string
// RunSudoCmd passes to `sudo -p`), the stored password is written to stdin
// once and the chunk is not recorded. All other output is accumulated in b.
type sudoWriter struct {
	b     bytes.Buffer // captured command output
	pw    string       // password to answer the first prompt with; cleared once used
	stdin io.Writer    // remote stdin, used to answer the prompt
	m     sync.Mutex   // guards b, pw and stdin
}

// Write satisfies the io.Writer interface for sudoWriter.
//
// The whole method runs under w.m: the same writer receives both stdout and
// stderr, which golang.org/x/crypto/ssh copies on separate goroutines, so pw
// and stdin need the same protection as the buffer.
func (w *sudoWriter) Write(p []byte) (int, error) {
	w.m.Lock()
	defer w.m.Unlock()

	if strings.Contains(string(p), "sudo_password") {
		pw := w.pw
		// Clear before writing so a repeated prompt (e.g. a rejected
		// password) never re-sends the secret.
		w.pw = ""
		if _, err := w.stdin.Write([]byte(pw + "\n")); err != nil {
			// Surface the failure instead of silently leaving sudo waiting.
			return 0, err
		}
		return len(p), nil
	}

	return w.b.Write(p)
}
|
|
|
|
// RunSudoCmd runs SSH commands and streams output
|
|
func RunSudoCmd(cmd, passwd string, cl *Client) error {
|
|
session, err := cl.SSHClient.NewSession()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
|
|
sudoCmd := fmt.Sprintf("SSH_ASKPASS=/usr/bin/ssh-askpass; sudo -p sudo_password -S %s", cmd)
|
|
|
|
w := &sudoWriter{pw: passwd}
|
|
w.stdin, err = session.StdinPipe()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
session.Stdout = w
|
|
session.Stderr = w
|
|
|
|
modes := ssh.TerminalModes{
|
|
ssh.ECHO: 0,
|
|
ssh.TTY_OP_ISPEED: 14400,
|
|
ssh.TTY_OP_OSPEED: 14400,
|
|
}
|
|
|
|
err = session.RequestPty("xterm", 80, 40, modes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := session.Run(sudoCmd); err != nil {
|
|
return fmt.Errorf("%s", string(w.b.Bytes()))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EnsureKnowHostsFiles ensures that ~/.ssh/known_hosts is created
|
|
func EnsureKnowHostsFiles() error {
|
|
if _, err := os.Stat(KnownHostsPath); os.IsNotExist(err) {
|
|
logrus.Debugf("missing %s, creating now", KnownHostsPath)
|
|
file, err := os.OpenFile(KnownHostsPath, os.O_CREATE, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
file.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetHostKey checks if a host key is registered in the ~/.ssh/known_hosts file
|
|
func GetHostKey(hostname string) (bool, sshPkg.PublicKey, error) {
|
|
var hostKey sshPkg.PublicKey
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(hostname)
|
|
if err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
|
|
if err := EnsureKnowHostsFiles(); err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
|
|
file, err := os.Open(KnownHostsPath)
|
|
if err != nil {
|
|
return false, hostKey, err
|
|
}
|
|
defer file.Close()
|
|
|
|
scanner := bufio.NewScanner(file)
|
|
for scanner.Scan() {
|
|
fields := strings.Split(scanner.Text(), " ")
|
|
if len(fields) != 3 {
|
|
continue
|
|
}
|
|
|
|
hostnameAndPort := fmt.Sprintf("%s:%s", ctxConnDetails.Host, ctxConnDetails.Port)
|
|
hashed := knownhosts.Normalize(hostnameAndPort)
|
|
|
|
if strings.Contains(fields[0], hashed) {
|
|
var err error
|
|
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
|
|
if err != nil {
|
|
return false, hostKey, fmt.Errorf("error parsing server SSH host key %q: %v", fields[2], err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
if hostKey != nil {
|
|
logrus.Debugf("server SSH host key present in ~/.ssh/known_hosts for %s", hostname)
|
|
return true, hostKey, nil
|
|
}
|
|
|
|
return false, hostKey, nil
|
|
}
|
|
|
|
// InsertHostKey adds a new host key to the ~/.ssh/known_hosts file
|
|
func InsertHostKey(hostname string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
file, err := os.OpenFile(KnownHostsPath, os.O_APPEND|os.O_WRONLY, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
hashedHostname := knownhosts.Normalize(hostname)
|
|
lineHostname := knownhosts.Line([]string{hashedHostname}, pubKey)
|
|
_, err = file.WriteString(fmt.Sprintf("%s\n", lineHostname))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hashedRemote := knownhosts.Normalize(remote.String())
|
|
lineRemote := knownhosts.Line([]string{hashedRemote}, pubKey)
|
|
_, err = file.WriteString(fmt.Sprintf("%s\n", lineRemote))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logrus.Debugf("SSH host key generated: %s", lineHostname)
|
|
logrus.Debugf("SSH host key generated: %s", lineRemote)
|
|
|
|
return nil
|
|
}
|
|
|
|
// HostKeyAddCallback ensures server ssh host keys are handled
|
|
func HostKeyAddCallback(hostnameAndPort string, remote net.Addr, pubKey ssh.PublicKey) error {
|
|
exists, _, err := GetHostKey(hostnameAndPort)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if exists {
|
|
hostname := strings.Split(hostnameAndPort, ":")[0]
|
|
logrus.Debugf("server SSH host key found for %s", hostname)
|
|
return nil
|
|
}
|
|
|
|
if !exists {
|
|
hostname := strings.Split(hostnameAndPort, ":")[0]
|
|
parsedPubKey := FingerprintSHA256(pubKey)
|
|
|
|
fmt.Printf(fmt.Sprintf(`
|
|
You are attempting to make an SSH connection to a server but there is no entry
|
|
in your ~/.ssh/known_hosts file which confirms that you have already validated
|
|
that this is indeed the server you want to connect to. Please take a moment to
|
|
validate the following SSH host key, it is important.
|
|
|
|
Host: %s
|
|
Fingerprint: %s
|
|
|
|
If this is confusing to you, you can read the article below and learn how to
|
|
validate this fingerprint safely. Thanks to the comrades at cyberia.club for
|
|
writing this extensive guide <3
|
|
|
|
https://sequentialread.com/understanding-the-secure-shell-protocol-ssh/
|
|
|
|
`, hostname, parsedPubKey))
|
|
|
|
response := false
|
|
prompt := &survey.Confirm{
|
|
Message: "are you sure you trust this host key?",
|
|
}
|
|
|
|
if err := survey.AskOne(prompt, &response); err != nil {
|
|
return err
|
|
}
|
|
|
|
if !response {
|
|
logrus.Fatal("exiting as requested")
|
|
}
|
|
|
|
logrus.Debugf("attempting to insert server SSH host key for %s, %s", hostnameAndPort, remote)
|
|
|
|
if err := InsertHostKey(hostnameAndPort, remote, pubKey); err != nil {
|
|
return err
|
|
}
|
|
|
|
logrus.Infof("successfully added server SSH host key for %s", hostname)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// connect makes the SSH connection
|
|
func connect(username, host, port string, authMethod ssh.AuthMethod, timeout time.Duration) (*Client, error) {
|
|
config := &ssh.ClientConfig{
|
|
User: username,
|
|
Auth: []ssh.AuthMethod{authMethod},
|
|
HostKeyCallback: HostKeyAddCallback, // the main reason why we fork
|
|
}
|
|
|
|
hostnameAndPort := fmt.Sprintf("%s:%s", host, port)
|
|
|
|
logrus.Debugf("tcp dialing %s", hostnameAndPort)
|
|
|
|
var conn net.Conn
|
|
var err error
|
|
conn, err = net.DialTimeout("tcp", hostnameAndPort, timeout)
|
|
if err != nil {
|
|
logrus.Debugf("tcp dialing %s failed, trying via ~/.ssh/config", hostnameAndPort)
|
|
hostConfig, err := GetHostConfig(host, username, port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn, err = net.DialTimeout("tcp", fmt.Sprintf("%s:%s", hostConfig.Host, hostConfig.Port), timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, hostnameAndPort, config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := ssh.NewClient(sshConn, chans, reqs)
|
|
c := &Client{SSHClient: client}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func connectWithAgentTimeout(host, username, port string, timeout time.Duration) (*Client, error) {
|
|
logrus.Debugf("using ssh-agent to make an SSH connection for %s", host)
|
|
|
|
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
agentCl := agent.NewClient(sshAgent)
|
|
authMethod := ssh.PublicKeysCallback(agentCl.Signers)
|
|
|
|
loadedKeys, err := agentCl.List()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var convertedKeys []string
|
|
for _, key := range loadedKeys {
|
|
convertedKeys = append(convertedKeys, key.String())
|
|
}
|
|
|
|
if len(convertedKeys) > 0 {
|
|
logrus.Debugf("ssh-agent has these keys loaded: %s", strings.Join(convertedKeys, ","))
|
|
} else {
|
|
logrus.Debug("ssh-agent has no keys loaded")
|
|
}
|
|
|
|
return connect(username, host, port, authMethod, timeout)
|
|
}
|
|
|
|
func connectWithPasswordTimeout(host, username, port, pass string, timeout time.Duration) (*Client, error) {
|
|
authMethod := ssh.Password(pass)
|
|
|
|
return connect(username, host, port, authMethod, timeout)
|
|
}
|
|
|
|
// EnsureHostKey ensures that a host key trusted and added to the ~/.ssh/known_hosts file
|
|
func EnsureHostKey(hostname string) error {
|
|
if hostname == "default" || hostname == "local" {
|
|
logrus.Debugf("not checking server SSH host key against local/default target")
|
|
return nil
|
|
}
|
|
|
|
exists, _, err := GetHostKey(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
|
|
ctxConnDetails, err := GetContextConnDetails(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = connectWithAgentTimeout(
|
|
ctxConnDetails.Host,
|
|
ctxConnDetails.User,
|
|
ctxConnDetails.Port,
|
|
5*time.Second,
|
|
)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// FingerprintSHA256 generates the SHA256 fingerprint for a server SSH host key
|
|
func FingerprintSHA256(key ssh.PublicKey) string {
|
|
hash := sha256.Sum256(key.Marshal())
|
|
b64hash := base64.StdEncoding.EncodeToString(hash[:])
|
|
trimmed := strings.TrimRight(b64hash, "=")
|
|
return fmt.Sprintf("SHA256:%s", trimmed)
|
|
}
|
|
|
|
// GetContextConnDetails retrieves SSH connection details from a docker context endpoint
|
|
func GetContextConnDetails(serverName string) (*dockerSSHPkg.Spec, error) {
|
|
dockerContextStore := context.NewDefaultDockerContextStore()
|
|
contexts, err := dockerContextStore.Store.List()
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
|
|
if strings.Contains(serverName, ":") {
|
|
serverName = strings.Split(serverName, ":")[0]
|
|
}
|
|
|
|
for _, ctx := range contexts {
|
|
endpoint, err := context.GetContextEndpoint(ctx)
|
|
if err != nil && strings.Contains(err.Error(), "does not exist") {
|
|
// No local context found, we can continue safely
|
|
continue
|
|
}
|
|
if ctx.Name == serverName {
|
|
ctxConnDetails, err := dockerSSHPkg.ParseURL(endpoint)
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
logrus.Debugf("found context connection details %v for %s", ctxConnDetails, serverName)
|
|
return ctxConnDetails, nil
|
|
}
|
|
}
|
|
|
|
hostConfig, err := GetHostConfig(serverName, "", "")
|
|
if err != nil {
|
|
return &dockerSSHPkg.Spec{}, err
|
|
}
|
|
|
|
logrus.Debugf("couldn't find a docker context matching %s", serverName)
|
|
logrus.Debugf("searching ~/.ssh/config for a Host entry for %s", serverName)
|
|
|
|
connDetails := &dockerSSHPkg.Spec{
|
|
Host: hostConfig.Host,
|
|
User: hostConfig.User,
|
|
Port: hostConfig.Port,
|
|
}
|
|
|
|
logrus.Debugf("using %v from ~/.ssh/config for connection details", connDetails)
|
|
|
|
return connDetails, nil
|
|
}
|
|
|
|
// GetHostConfig retrieves a ~/.ssh/config config for a host.
|
|
func GetHostConfig(hostname, username, port string) (HostConfig, error) {
|
|
var hostConfig HostConfig
|
|
|
|
if hostname == "" {
|
|
if hostname = ssh_config.Get(hostname, "Hostname"); hostname == "" {
|
|
logrus.Debugf("no hostname found in SSH config, assuming %s", hostname)
|
|
}
|
|
}
|
|
|
|
if username == "" {
|
|
if username = ssh_config.Get(hostname, "User"); username == "" {
|
|
systemUser, err := user.Current()
|
|
if err != nil {
|
|
return hostConfig, err
|
|
}
|
|
logrus.Debugf("no username found in SSH config or passed on command-line, assuming %s", username)
|
|
username = systemUser.Username
|
|
}
|
|
}
|
|
|
|
if port == "" {
|
|
if port = ssh_config.Get(hostname, "Port"); port == "" {
|
|
logrus.Debugf("no port found in SSH config or passed on command-line, assuming 22")
|
|
port = "22"
|
|
}
|
|
}
|
|
|
|
if idf := ssh_config.Get(hostname, "IdentityFile"); idf != "" && idf != "~/.ssh/identity" {
|
|
var err error
|
|
idf, err = identityFileAbsPath(idf)
|
|
if err != nil {
|
|
return hostConfig, err
|
|
}
|
|
hostConfig.IdentityFile = idf
|
|
} else {
|
|
logrus.Debugf("no identity file found in SSH config for %s", hostname)
|
|
hostConfig.IdentityFile = ""
|
|
}
|
|
|
|
hostConfig.Host = hostname
|
|
hostConfig.Port = port
|
|
hostConfig.User = username
|
|
|
|
logrus.Debugf("constructed SSH config %s for %s", hostConfig, hostname)
|
|
|
|
return hostConfig, nil
|
|
}
|
|
|
|
func identityFileAbsPath(relPath string) (string, error) {
|
|
var err error
|
|
var absPath string
|
|
|
|
if strings.HasPrefix(relPath, "~/") {
|
|
systemUser, err := user.Current()
|
|
if err != nil {
|
|
return absPath, err
|
|
}
|
|
absPath = filepath.Join(systemUser.HomeDir, relPath[2:])
|
|
} else {
|
|
absPath, err = filepath.Abs(relPath)
|
|
if err != nil {
|
|
return absPath, err
|
|
}
|
|
}
|
|
|
|
logrus.Debugf("resolved %s to %s to read the ssh identity file", relPath, absPath)
|
|
|
|
return absPath, nil
|
|
}
|