package ssh

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net"
	"os/user"
	"sync"
	"time"

	"github.com/AlecAivazis/survey/v2"
	"github.com/kevinburke/ssh_config"
	"github.com/sfreiberg/simplessh"
	"github.com/sirupsen/logrus"
)

// HostConfig holds the connection settings resolved for a single SSH host.
// GetHostConfig fills it from the SSH config plus any explicitly supplied
// username and port.
type HostConfig struct {
	// Host is the address to connect to: the "Hostname" value from the SSH
	// config, or the name that was looked up when no such value exists.
	Host string
	// IdentityFile is the "IdentityFile" value from the SSH config; it stays
	// empty when the lookup yields nothing.
	IdentityFile string
	// Port is the SSH port, kept as a string.
	Port string
	// User is the login name used for the connection.
	User string
}

// GetHostConfig resolves the connection settings for hostname using the SSH
// config (queried through ssh_config.Get).
//
// Explicitly passed values win: a non-empty username or port is used as-is.
// Otherwise each setting is looked up in the SSH config and, if still empty,
// falls back to a default:
//   - host: hostname itself
//   - user: the current system user
//   - port: "22"
//
// IdentityFile is taken from the SSH config only and may be empty.
//
// An error is returned only when the current system user is needed and cannot
// be determined; the returned HostConfig is the zero value in that case.
func GetHostConfig(hostname, username, port string) (HostConfig, error) {
	host := ssh_config.Get(hostname, "Hostname")
	if host == "" {
		logrus.Debugf("no hostname found in SSH config, assuming %s", hostname)
		host = hostname
	}

	if username == "" {
		username = ssh_config.Get(hostname, "User")
	}
	if username == "" {
		systemUser, err := user.Current()
		if err != nil {
			return HostConfig{}, fmt.Errorf("looking up current user: %w", err)
		}
		// Assign before logging so the message reports the name actually used.
		username = systemUser.Username
		logrus.Debugf("no username found in SSH config or passed on command-line, assuming %s", username)
	}

	if port == "" {
		port = ssh_config.Get(hostname, "Port")
	}
	if port == "" {
		logrus.Debugf("no port found in SSH config or passed on command-line, assuming 22")
		port = "22"
	}

	hostConfig := HostConfig{
		Host:         host,
		IdentityFile: ssh_config.Get(hostname, "IdentityFile"),
		Port:         port,
		User:         username,
	}

	logrus.Debugf("constructed SSH config %s for %s", hostConfig, hostname)

	return hostConfig, nil
}

// New creates a new SSH client connection to domainName.
//
// The host, user and port are resolved with GetHostConfig, so empty username
// and port arguments fall back to the SSH config and then to defaults. The
// resolved port is included in the dialled address; without it a non-default
// port would be ignored.
//
// When sshAuth is "identity-file" the connection authenticates through the SSH
// agent; for any other value the user is prompted interactively for a
// password. Both paths use a 5 second connection timeout.
func New(domainName, sshAuth, username, port string) (*simplessh.Client, error) {
	hostConfig, err := GetHostConfig(domainName, username, port)
	if err != nil {
		return nil, err
	}

	// JoinHostPort also brackets IPv6 literals correctly.
	addr := net.JoinHostPort(hostConfig.Host, hostConfig.Port)

	if sshAuth == "identity-file" {
		return simplessh.ConnectWithAgentTimeout(addr, hostConfig.User, 5*time.Second)
	}

	password := ""
	prompt := &survey.Password{
		Message: "SSH password?",
	}
	if err := survey.AskOne(prompt, &password); err != nil {
		return nil, err
	}

	return simplessh.ConnectWithPasswordTimeout(addr, hostConfig.User, password, 5*time.Second)
}

// sudoWriter supports sudo command handling. It is installed as the output
// writer of a remote session: when it sees the sudo password prompt it answers
// on stdin, and it buffers everything else.
// https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
type sudoWriter struct {
	// b accumulates all output that is not the password prompt.
	b bytes.Buffer
	// pw is the sudo password; it is cleared after the first prompt.
	pw string
	// stdin is where the password is written when the prompt is seen.
	stdin io.Writer
	// m guards b and pw.
	m sync.Mutex
}

// Write satisfies the write interface for sudoWriter.
//
// A chunk that is exactly "sudo_password" is treated as the sudo prompt: the
// stored password followed by a newline is written to stdin, the password is
// cleared, and the chunk is reported as consumed. Any other chunk is appended
// to the internal buffer.
//
// The whole call runs under the mutex because the same writer may be shared
// between a session's stdout and stderr. A failure to deliver the password is
// returned to the caller rather than dropped.
// https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
func (w *sudoWriter) Write(p []byte) (int, error) {
	w.m.Lock()
	defer w.m.Unlock()

	if string(p) == "sudo_password" {
		_, err := w.stdin.Write([]byte(w.pw + "\n"))
		// Forget the password regardless of the outcome; it is single-use.
		w.pw = ""
		if err != nil {
			return 0, fmt.Errorf("writing sudo password: %w", err)
		}
		return len(p), nil
	}

	return w.b.Write(p)
}

// RunSudoCmd runs cmd under sudo on the remote host and prints its output.
//
// The command is prefixed with `sudo -p sudo_password -S` so that sudo emits a
// recognisable prompt and reads the password from stdin. A sudoWriter is
// installed as both stdout and stderr of the session; it answers the prompt
// with passwd and captures all other output. Once the command has finished,
// the captured output is printed to standard output, whether or not the
// command succeeded.
//
// The returned error is whatever creating the session, opening its stdin pipe
// or running the command reports.
// https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
func RunSudoCmd(cmd, passwd string, cl *simplessh.Client) error {
	session, err := cl.SSHClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	cmd = "sudo -p " + "sudo_password" + " -S " + cmd

	w := &sudoWriter{
		pw: passwd,
	}
	w.stdin, err = session.StdinPipe()
	if err != nil {
		return err
	}

	session.Stdout = w
	session.Stderr = w

	// Run starts the command and waits for it and for the output copying to
	// finish, so the buffer is complete afterwards.
	runErr := session.Run(cmd)

	// Output lives in the writer's buffer; session.Stdin is the command's
	// input and must not be read for output.
	w.m.Lock()
	fmt.Print(w.b.String())
	w.m.Unlock()

	return runErr
}

// Exec runs cmd on the remote host and streams its output line by line.
//
// Both the remote stdout and stderr are read concurrently and every line is
// printed to local standard output as it arrives. Exec returns once both
// streams are exhausted and the command has exited.
//
// The returned error is, in order of precedence: a failure to create the
// session or its pipes, a failure to start the command, the command's exit
// error from Wait, or the first error hit while scanning the output.
// https://github.com/sfreiberg/simplessh/blob/master/simplessh.go
func Exec(cmd string, cl *simplessh.Client) error {
	session, err := cl.SSHClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	stdout, err := session.StdoutPipe()
	if err != nil {
		return err
	}

	stderr, err := session.StderrPipe()
	if err != nil {
		return err
	}

	// Start before spawning readers so a failed start leaves no goroutines
	// behind.
	if err := session.Start(cmd); err != nil {
		return err
	}

	streams := []io.Reader{stdout, stderr}
	// Buffered so readers never block on reporting their result.
	scanErrs := make(chan error, len(streams))

	for _, stream := range streams {
		go func(r io.Reader) {
			scanner := bufio.NewScanner(r)
			for scanner.Scan() {
				fmt.Println(scanner.Text())
			}
			scanErrs <- scanner.Err()
		}(stream)
	}

	// Drain both streams before Wait so all output is printed first.
	var scanErr error
	for range streams {
		if err := <-scanErrs; err != nil && scanErr == nil {
			scanErr = err
		}
	}

	if err := session.Wait(); err != nil {
		return err
	}

	return scanErr
}