package ssh

import (
	"bufio"
	"bytes"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"os"
	"os/user"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"coopcloud.tech/abra/pkg/context"
	"github.com/AlecAivazis/survey/v2"
	dockerSSHPkg "github.com/docker/cli/cli/connhelper/ssh"
	sshPkg "github.com/gliderlabs/ssh"
	"github.com/kevinburke/ssh_config"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"golang.org/x/crypto/ssh/knownhosts"
)

// KnownHostsPath is the location of the user's known_hosts file,
// <home>/.ssh/known_hosts. The home directory comes from os.UserHomeDir,
// which also works where $HOME is not set (e.g. %USERPROFILE% on Windows);
// if that fails, $HOME is used as before.
var KnownHostsPath = func() string {
	home, err := os.UserHomeDir()
	if err != nil {
		home = os.Getenv("HOME")
	}
	return filepath.Join(home, ".ssh", "known_hosts")
}()

// Client wraps an established SSH connection, as returned by New.
type Client struct {
	// SSHClient is the underlying golang.org/x/crypto/ssh client on which
	// sessions are opened (see Exec and RunSudoCmd).
	SSHClient *ssh.Client
}

// HostConfig is a SSH host config.
//
// It holds the connection settings that GetHostConfig resolves for a host
// from ~/.ssh/config, combined with the values supplied by the caller.
type HostConfig struct {
	// Host is the hostname to dial, after any ssh config Hostname substitution.
	Host string
	// IdentityFile is the absolute path of the configured identity file, or
	// "" when none is configured.
	IdentityFile string
	// Port is the port to dial, kept as a string; it may be empty.
	Port string
	// User is the remote login user.
	User string
}

// Exec runs cmd in a new session on the remote host and returns the
// combined stdout and stderr output together with any error from running it.
func (c *Client) Exec(cmd string) ([]byte, error) {
	sess, sessErr := c.SSHClient.NewSession()
	if sessErr != nil {
		return nil, sessErr
	}
	defer sess.Close()

	output, runErr := sess.CombinedOutput(cmd)
	return output, runErr
}

// Close closes the wrapped SSH client connection and returns any error it
// reports.
func (c *Client) Close() error {
	sshClient := c.SSHClient
	return sshClient.Close()
}

// New creates a new SSH client connection to domainName.
//
// The host, user and port are resolved with GetContextConnDetails (a docker
// context of that name, else ~/.ssh/config). When sshAuth is "identity-file"
// the keys loaded in ssh-agent are used; any other value prompts
// interactively for a password. Both paths use a 5 second dial timeout.
//
// The username and port arguments are currently unused: the resolved
// connection details are always used instead.
func New(domainName, sshAuth, username, port string) (*Client, error) {
	ctxConnDetails, err := GetContextConnDetails(domainName)
	if err != nil {
		// Propagate the failure: returning a nil client with a nil error
		// would make callers dereference nil.
		return nil, err
	}

	if sshAuth == "identity-file" {
		return connectWithAgentTimeout(
			ctxConnDetails.Host,
			ctxConnDetails.User,
			ctxConnDetails.Port,
			5*time.Second,
		)
	}

	password := ""
	prompt := &survey.Password{
		Message: "SSH password?",
	}
	if err := survey.AskOne(prompt, &password); err != nil {
		return nil, err
	}

	return connectWithPasswordTimeout(
		ctxConnDetails.Host,
		ctxConnDetails.User,
		ctxConnDetails.Port,
		password,
		5*time.Second,
	)
}

// sudoWriter collects the output of a remote sudo command and answers its
// password prompt. RunSudoCmd starts sudo with `-p sudo_password`, so any
// chunk of output containing that marker is treated as the prompt: the
// stored password is written to stdin and the chunk is not recorded.
type sudoWriter struct {
	b     bytes.Buffer // captured output (prompt chunks excluded)
	pw    string       // password to send; cleared after the first prompt
	stdin io.Writer    // remote session stdin
	m     sync.Mutex   // guards b and pw; Stdout and Stderr may share this writer
}

// Write satisfies the io.Writer interface for sudoWriter. If p contains the
// sudo prompt marker, the password (followed by a newline) is sent to stdin
// exactly once; later prompts receive an empty line. Otherwise p is appended
// to the output buffer.
func (w *sudoWriter) Write(p []byte) (int, error) {
	// Hold the lock for the whole call: stdout and stderr can be copied in
	// from different goroutines, and both paths touch shared state (pw or b).
	w.m.Lock()
	defer w.m.Unlock()

	if bytes.Contains(p, []byte("sudo_password")) {
		pw := w.pw
		// Clear first so the password is never sent twice, even if the
		// write below fails.
		w.pw = ""
		if _, err := w.stdin.Write([]byte(pw + "\n")); err != nil {
			return 0, fmt.Errorf("sending sudo password: %w", err)
		}
		return len(p), nil
	}

	return w.b.Write(p)
}

// RunSudoCmd runs cmd under sudo on the remote host behind cl, answering
// sudo's password prompt with passwd. A pseudo-terminal is requested for the
// session. If the command fails, the returned error carries the command's
// captured output, or the session error itself when nothing was captured.
func RunSudoCmd(cmd, passwd string, cl *Client) error {
	session, err := cl.SSHClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	sudoCmd := fmt.Sprintf("SSH_ASKPASS=/usr/bin/ssh-askpass; sudo -p sudo_password -S %s", cmd)

	w := &sudoWriter{pw: passwd}
	w.stdin, err = session.StdinPipe()
	if err != nil {
		return err
	}

	// Route both streams through w so the prompt is caught on either one.
	session.Stdout = w
	session.Stderr = w

	modes := ssh.TerminalModes{
		ssh.ECHO:          0, // disable terminal echo of typed input
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}

	if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
		return err
	}

	if runErr := session.Run(sudoCmd); runErr != nil {
		// Read under the writer's lock for consistency with sudoWriter.Write.
		w.m.Lock()
		output := w.b.String()
		w.m.Unlock()

		if output == "" {
			// An error built from empty output would have an empty message;
			// surface the session error instead.
			return runErr
		}
		return fmt.Errorf("%s", output)
	}

	return nil
}

// EnsureKnowHostsFiles ensures that the known_hosts file at KnownHostsPath
// exists, creating it empty (mode 0600) if it is missing. The parent
// directory (normally ~/.ssh) is created with mode 0700 when needed. Stat
// errors other than "does not exist" are ignored, as before.
func EnsureKnowHostsFiles() error {
	if _, err := os.Stat(KnownHostsPath); !os.IsNotExist(err) {
		return nil
	}

	logrus.Debugf("missing %s, creating now", KnownHostsPath)

	// On a fresh account ~/.ssh itself may not exist yet.
	if err := os.MkdirAll(filepath.Dir(KnownHostsPath), 0700); err != nil {
		return err
	}

	file, err := os.OpenFile(KnownHostsPath, os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return err
	}
	return file.Close()
}

// GetHostKey checks if a host key is registered in the ~/.ssh/known_hosts file
// for hostname.
//
// hostname is resolved to a host and port with GetContextConnDetails, and the
// pair is normalised with knownhosts.Normalize. A line matches when one of the
// comma-separated host patterns in its first field equals that normalised
// address exactly. Hashed, wildcard and negated patterns are not interpreted.
// It returns whether a key was found, the first matching key, and any error
// from reading the file or parsing the matching line.
func GetHostKey(hostname string) (bool, sshPkg.PublicKey, error) {
	var hostKey sshPkg.PublicKey

	ctxConnDetails, err := GetContextConnDetails(hostname)
	if err != nil {
		return false, hostKey, err
	}

	if err := EnsureKnowHostsFiles(); err != nil {
		return false, hostKey, err
	}

	file, err := os.Open(KnownHostsPath)
	if err != nil {
		return false, hostKey, err
	}
	defer file.Close()

	hostnameAndPort := fmt.Sprintf("%s:%s", ctxConnDetails.Host, ctxConnDetails.Port)
	normalized := knownhosts.Normalize(hostnameAndPort)

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		// Expected shape: "hosts keytype base64key [comment]".
		fields := strings.Fields(scanner.Text())
		if len(fields) < 3 || !knownHostsFieldMatches(fields[0], normalized) {
			continue
		}

		key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fields[1] + " " + fields[2]))
		if err != nil {
			return false, hostKey, fmt.Errorf("error parsing server SSH host key %q: %v", fields[2], err)
		}
		hostKey = key
		break
	}
	if err := scanner.Err(); err != nil {
		return false, hostKey, err
	}

	if hostKey != nil {
		logrus.Debugf("server SSH host key present in ~/.ssh/known_hosts for %s", hostname)
		return true, hostKey, nil
	}

	return false, hostKey, nil
}

// knownHostsFieldMatches reports whether the comma-separated host list in the
// first field of a known_hosts line contains host exactly. A substring match
// is not enough: it would let an entry for "notexample.com" (or
// "!example.com") vouch for "example.com".
func knownHostsFieldMatches(hostsField, host string) bool {
	for _, candidate := range strings.Split(hostsField, ",") {
		if candidate == host {
			return true
		}
	}
	return false
}

// InsertHostKey adds a new host key to the ~/.ssh/known_hosts file.
//
// pubKey is appended under two entries: one for hostname and one for the
// remote address the connection reached. Both are normalised with
// knownhosts.Normalize. If both produce the same line, it is written once.
// The file is created with mode 0600 if it does not exist.
func InsertHostKey(hostname string, remote net.Addr, pubKey ssh.PublicKey) (err error) {
	file, err := os.OpenFile(KnownHostsPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
	if err != nil {
		return err
	}
	// A failed Close on a written file can mean the data never landed, so
	// report it unless an earlier error is already being returned.
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	lineHostname := knownhosts.Line([]string{knownhosts.Normalize(hostname)}, pubKey)
	lineRemote := knownhosts.Line([]string{knownhosts.Normalize(remote.String())}, pubKey)

	lines := []string{lineHostname}
	if lineRemote != lineHostname {
		lines = append(lines, lineRemote)
	}

	for _, line := range lines {
		if _, werr := fmt.Fprintln(file, line); werr != nil {
			return werr
		}
		logrus.Debugf("SSH host key generated: %s", line)
	}

	return nil
}

// HostKeyAddCallback ensures server ssh host keys are handled. It is the
// ssh.HostKeyCallback that connect installs.
//
// If GetHostKey finds a stored key for hostnameAndPort of the same type as
// the presented one, the two must be identical, otherwise an error is
// returned and the handshake fails. If no key is stored, the user is shown
// the presented key's SHA256 fingerprint and asked to confirm it. On
// confirmation the key is appended with InsertHostKey; declining terminates
// the process via logrus.Fatal.
func HostKeyAddCallback(hostnameAndPort string, remote net.Addr, pubKey ssh.PublicKey) error {
	exists, knownKey, err := GetHostKey(hostnameAndPort)
	if err != nil {
		return err
	}

	hostname := strings.Split(hostnameAndPort, ":")[0]

	if exists {
		// GetHostKey returns only the first matching entry, so only a key of
		// the same type can be compared reliably. A mismatch there means this
		// is not the server that was previously trusted.
		// NOTE(review): a key of a different type than the stored one is still
		// accepted; consider checking every matching known_hosts entry.
		if knownKey.Type() == pubKey.Type() && !bytes.Equal(knownKey.Marshal(), pubKey.Marshal()) {
			return fmt.Errorf(
				"server SSH host key for %s (%s) does not match the key in ~/.ssh/known_hosts; refusing to connect",
				hostname, FingerprintSHA256(pubKey),
			)
		}
		logrus.Debugf("server SSH host key found for %s", hostname)
		return nil
	}

	parsedPubKey := FingerprintSHA256(pubKey)

	// Pass the values as arguments rather than pre-formatting them with
	// Sprintf, so a '%' in the host name cannot be read as a verb.
	fmt.Printf(`
You are attempting to make an SSH connection to a server but there is no entry
in your ~/.ssh/known_hosts file which confirms that you have already validated
that this is indeed the server you want to connect to. Please take a moment to
validate the following SSH host key, it is important.

    Host:        %s
    Fingerprint: %s

If this is confusing to you, you can read the article below and learn how to
validate this fingerprint safely. Thanks to the comrades at cyberia.club for
writing this extensive guide <3

    https://sequentialread.com/understanding-the-secure-shell-protocol-ssh/

`, hostname, parsedPubKey)

	response := false
	prompt := &survey.Confirm{
		Message: "are you sure you trust this host key?",
	}

	if err := survey.AskOne(prompt, &response); err != nil {
		return err
	}

	if !response {
		logrus.Fatal("exiting as requested")
	}

	logrus.Debugf("attempting to insert server SSH host key for %s, %s", hostnameAndPort, remote)

	if err := InsertHostKey(hostnameAndPort, remote, pubKey); err != nil {
		return err
	}

	logrus.Infof("successfully added server SSH host key for %s", hostname)

	return nil
}

// connect makes the SSH connection: it dials host:port over TCP and performs
// the SSH handshake as username with authMethod. Host keys are checked by
// HostKeyAddCallback.
//
// If the direct dial fails, host is resolved again through GetHostConfig
// (with override), and the resulting Host:Port is dialled instead. The
// handshake, and therefore the host-key callback, still sees the original
// host:port. timeout applies to the TCP dials only; the handshake itself is
// given no deadline.
func connect(username, host, port string, authMethod ssh.AuthMethod, timeout time.Duration) (*Client, error) {
	target := fmt.Sprintf("%s:%s", host, port)
	logrus.Debugf("tcp dialing %s", target)

	conn, err := net.DialTimeout("tcp", target, timeout)
	if err != nil {
		logrus.Debugf("tcp dialing %s failed, trying via ~/.ssh/config", target)

		resolved, cfgErr := GetHostConfig(host, username, port, true)
		if cfgErr != nil {
			return nil, cfgErr
		}

		fallback := fmt.Sprintf("%s:%s", resolved.Host, resolved.Port)
		if conn, err = net.DialTimeout("tcp", fallback, timeout); err != nil {
			return nil, err
		}
	}

	clientConfig := &ssh.ClientConfig{
		User:            username,
		Auth:            []ssh.AuthMethod{authMethod},
		HostKeyCallback: HostKeyAddCallback, // the main reason why we fork
	}

	sshConn, chans, reqs, err := ssh.NewClientConn(conn, target, clientConfig)
	if err != nil {
		return nil, err
	}

	return &Client{SSHClient: ssh.NewClient(sshConn, chans, reqs)}, nil
}

// connectWithAgentTimeout connects like connect, authenticating as username
// with the keys held by the ssh-agent reachable at $SSH_AUTH_SOCK. The keys
// the agent reports are logged at debug level.
func connectWithAgentTimeout(host, username, port string, timeout time.Duration) (*Client, error) {
	logrus.Debugf("using ssh-agent to make an SSH connection for %s", host)

	sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
	if err != nil {
		return nil, err
	}
	// The agent is only needed for signing during authentication, which is
	// complete once connect returns an established connection. Release the
	// socket afterwards instead of leaking one descriptor per connection.
	defer sshAgent.Close()

	agentCl := agent.NewClient(sshAgent)
	authMethod := ssh.PublicKeysCallback(agentCl.Signers)

	loadedKeys, err := agentCl.List()
	if err != nil {
		return nil, err
	}

	convertedKeys := make([]string, 0, len(loadedKeys))
	for _, key := range loadedKeys {
		convertedKeys = append(convertedKeys, key.String())
	}

	if len(convertedKeys) > 0 {
		logrus.Debugf("ssh-agent has these keys loaded: %s", strings.Join(convertedKeys, ","))
	} else {
		logrus.Debug("ssh-agent has no keys loaded")
	}

	return connect(username, host, port, authMethod, timeout)
}

// connectWithPasswordTimeout connects like connect, authenticating as
// username with the plain password pass.
func connectWithPasswordTimeout(host, username, port, pass string, timeout time.Duration) (*Client, error) {
	return connect(username, host, port, ssh.Password(pass), timeout)
}

// EnsureHostKey ensures that a host key is trusted and added to the
// ~/.ssh/known_hosts file.
//
// Nothing is done for the "default" and "local" targets, or when GetHostKey
// already finds a key. Otherwise a connection is made with ssh-agent
// authentication, so that the handshake's host-key callback
// (HostKeyAddCallback) can prompt for and record the key. The connection is
// then closed.
func EnsureHostKey(hostname string) error {
	if hostname == "default" || hostname == "local" {
		logrus.Debugf("not checking server SSH host key against local/default target")
		return nil
	}

	exists, _, err := GetHostKey(hostname)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}

	ctxConnDetails, err := GetContextConnDetails(hostname)
	if err != nil {
		return err
	}

	cl, err := connectWithAgentTimeout(
		ctxConnDetails.Host,
		ctxConnDetails.User,
		ctxConnDetails.Port,
		5*time.Second,
	)
	if err != nil {
		return err
	}

	// The connection only existed to drive the handshake; the key is already
	// recorded, so a close failure is logged rather than returned.
	if err := cl.Close(); err != nil {
		logrus.Debugf("closing SSH connection to %s failed: %s", hostname, err)
	}

	return nil
}

// FingerprintSHA256 generates the SHA256 fingerprint for a server SSH host key:
// "SHA256:" followed by the unpadded standard base64 encoding of the SHA-256
// digest of the key's marshalled form.
func FingerprintSHA256(key ssh.PublicKey) string {
	digest := sha256.Sum256(key.Marshal())
	// RawStdEncoding emits no '=' padding, matching a TrimRight of "=".
	return "SHA256:" + base64.RawStdEncoding.EncodeToString(digest[:])
}

// GetContextConnDetails retrieves SSH connection details from a docker context endpoint.
//
// Any ":port" suffix on serverName is dropped before lookup. If a docker
// context with that name exists, its endpoint URL is parsed with the docker
// CLI SSH helper. If not, or if the context's endpoint does not exist, the
// host, user and port come from ~/.ssh/config via GetHostConfig.
func GetContextConnDetails(serverName string) (*dockerSSHPkg.Spec, error) {
	dockerContextStore := context.NewDefaultDockerContextStore()
	contexts, err := dockerContextStore.Store.List()
	if err != nil {
		return &dockerSSHPkg.Spec{}, err
	}

	if strings.Contains(serverName, ":") {
		serverName = strings.Split(serverName, ":")[0]
	}

	for _, ctx := range contexts {
		// Only the context with the requested name needs its endpoint resolved.
		if ctx.Name != serverName {
			continue
		}

		endpoint, err := context.GetContextEndpoint(ctx)
		if err != nil {
			if strings.Contains(err.Error(), "does not exist") {
				// No local context found, we can continue safely
				continue
			}
			// Report the real lookup failure instead of going on to parse
			// whatever endpoint came back with it.
			return &dockerSSHPkg.Spec{}, err
		}

		ctxConnDetails, err := dockerSSHPkg.ParseURL(endpoint)
		if err != nil {
			return &dockerSSHPkg.Spec{}, err
		}
		logrus.Debugf("found context connection details %v for %s", ctxConnDetails, serverName)
		return ctxConnDetails, nil
	}

	logrus.Debugf("couldn't find a docker context matching %s", serverName)
	logrus.Debugf("searching ~/.ssh/config for a Host entry for %s", serverName)

	hostConfig, err := GetHostConfig(serverName, "", "", false)
	if err != nil {
		return &dockerSSHPkg.Spec{}, err
	}

	connDetails := &dockerSSHPkg.Spec{
		Host: hostConfig.Host,
		User: hostConfig.User,
		Port: hostConfig.Port,
	}

	logrus.Debugf("using %v from ~/.ssh/config for connection details", connDetails)

	return connDetails, nil
}

// GetHostConfig retrieves a ~/.ssh/config config for a host.
//
// Each of hostname, username and port is taken from the ssh config (the
// Hostname, User and Port keys) when the given value is empty or override is
// true, and is otherwise kept as given:
//   - If no User is configured and username is empty, the current system
//     user is used.
//   - A configured Port of "22" only fills an empty port; any other configured
//     Port replaces the given one.
//   - IdentityFile is resolved to an absolute path. It is left empty when the
//     config yields nothing or "~/.ssh/identity".
func GetHostConfig(hostname, username, port string, override bool) (HostConfig, error) {
	var hostConfig HostConfig

	if hostname == "" || override {
		if sshHost := ssh_config.Get(hostname, "Hostname"); sshHost != "" {
			hostname = sshHost
		}
	}

	if username == "" || override {
		if sshUser := ssh_config.Get(hostname, "User"); sshUser != "" {
			username = sshUser
		} else if username == "" {
			// Only fall back to the local account when the caller gave no
			// user; override should not discard an explicit username.
			systemUser, err := user.Current()
			if err != nil {
				return hostConfig, err
			}
			username = systemUser.Username
		}
	}

	if port == "" || override {
		if sshPort := ssh_config.Get(hostname, "Port"); sshPort != "" {
			// ssh_config reports "22" as a dummy default even when no Port is
			// configured, so "22" must not override a port the caller already
			// has. Use the config value when no port was given, or when it
			// names some other port explicitly.
			if port == "" || sshPort != "22" {
				port = sshPort
			}
		}
	}

	if idf := ssh_config.Get(hostname, "IdentityFile"); idf != "" && idf != "~/.ssh/identity" {
		absIdf, err := identityFileAbsPath(idf)
		if err != nil {
			return hostConfig, err
		}
		hostConfig.IdentityFile = absIdf
	}

	hostConfig.Host = hostname
	hostConfig.Port = port
	hostConfig.User = username

	logrus.Debugf("constructed SSH config %+v for %s", hostConfig, hostname)

	return hostConfig, nil
}

// identityFileAbsPath converts an IdentityFile value from ~/.ssh/config into
// an absolute path. A leading "~/" is expanded to the current user's home
// directory; any other value is made absolute against the working directory.
// On error the returned path is empty.
func identityFileAbsPath(relPath string) (string, error) {
	var (
		resolved string
		err      error
	)

	switch {
	case strings.HasPrefix(relPath, "~/"):
		var current *user.User
		if current, err = user.Current(); err == nil {
			resolved = filepath.Join(current.HomeDir, strings.TrimPrefix(relPath, "~/"))
		}
	default:
		resolved, err = filepath.Abs(relPath)
	}

	if err != nil {
		return resolved, err
	}

	logrus.Debugf("resolved %s to %s to read the ssh identity file", relPath, resolved)

	return resolved, nil
}