#!/usr/bin/python3

import os
import sys
import click
import json
import subprocess
import logging
import docker
import restic
import tarfile
import io
from pythonjsonlogger import jsonlogger
from datetime import datetime, timezone
from restic.errors import ResticFailedError
from pathlib import Path
from shutil import copyfile, rmtree

# Host directory in which Docker stores named volumes.
VOLUME_PATH = "/var/lib/docker/volumes/"
# Staging directory for copied swarm secrets; it is appended to every backup.
SECRET_PATH = '/secrets/'
# App selected with --host (dots replaced by underscores), used as the
# volume/secret name prefix and restic tag. None means "no app selected".
SERVICE = None

logger = logging.getLogger("backupbot")

# Custom SUMMARY level (55, above CRITICAL) for the final backup report,
# exposed as logging.SUMMARY and via logger.summary(...).
logging.addLevelName(55, 'SUMMARY')
logging.SUMMARY = 55
logger.summary = lambda message, *args, **kwargs: logger.log(
    55, message, *args, **kwargs)


def handle_exception(exc_type, exc_value, exc_traceback):
    """Send uncaught exceptions to the backupbot logger.

    KeyboardInterrupt is handed to the interpreter's default hook so Ctrl+C
    behaves as usual; any other exception is logged at CRITICAL level with its
    traceback attached.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        logger.critical("Uncaught exception",
                        exc_info=(exc_type, exc_value, exc_traceback))
        return
    sys.__excepthook__(exc_type, exc_value, exc_traceback)


# Install the hook process-wide.
sys.excepthook = handle_exception


@click.group()
@click.option('-l', '--log', 'loglevel')
@click.option('-m', '--machine-logs', 'machine_logs', is_flag=True)
@click.option('service', '--host', '-h', envvar='SERVICE')
@click.option('repository', '--repo', '-r', envvar='RESTIC_REPOSITORY', required=True)
def cli(loglevel, service, repository, machine_logs):
    """Back up and restore Docker swarm app volumes and secrets with restic."""
    # Runs before every sub-command: stores --host/--repo, configures logging,
    # exports *_FILE secrets and makes sure the restic repository exists.
    global SERVICE
    if service:
        SERVICE = service.replace('.', '_')
    if repository:
        os.environ['RESTIC_REPOSITORY'] = repository
    if loglevel:
        numeric_level = getattr(logging, loglevel.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % loglevel)
        logger.setLevel(numeric_level)
    # Always attach a handler. Without one, records fall through to logging's
    # last-resort handler, which only emits WARNING and above, so --log
    # info/debug would be silently ignored unless --machine-logs was given.
    # The default formatter prints just the message, as the last-resort
    # handler did.
    log_handler = logging.StreamHandler()
    if machine_logs:
        formatter = jsonlogger.JsonFormatter(
            "%(levelname)s %(filename)s %(lineno)s %(process)d %(message)s", rename_fields={"levelname": "message_type"})
        log_handler.setFormatter(formatter)
    logger.addHandler(log_handler)

    export_secrets()
    init_repo()


def init_repo():
    """Configure the restic wrapper and create the repository if it is missing.

    Reads the repository location from $RESTIC_REPOSITORY and uses the swarm
    secret at /var/run/secrets/restic_password as password file. The repository
    is probed by reading its config; if restic reports that the config file
    cannot be opened, the repository is initialised. Any other restic failure
    is re-raised.
    """
    repository = os.environ['RESTIC_REPOSITORY']
    logger.debug(f"set restic repository location: {repository}")
    restic.repository = repository
    restic.password_file = '/var/run/secrets/restic_password'
    try:
        restic.cat.config()
    except ResticFailedError as error:
        if 'unable to open config file' not in str(error):
            raise
        logger.info(f"Initialized restic repo: {restic.init()}")


def export_secrets():
    """Expose ``NAME_FILE`` environment variables as plain ``NAME`` variables.

    For every variable ending in ``_FILE`` (except those containing
    ``COMPOSE_FILE``), read the file it points to and store its contents in the
    variable without the suffix. Names that merely end in ``FILE`` without the
    underscore (e.g. ``PROFILE``) are ignored; previously they were opened as
    paths and overwritten.
    """
    # Iterate over a snapshot of the names because new variables are added below.
    for env in list(os.environ):
        if not env.endswith('_FILE') or 'COMPOSE_FILE' in env:
            continue
        name = env.removesuffix('_FILE')
        if not name:
            # A variable literally named "_FILE" has no target name.
            continue
        logger.debug(f"exported secret: {env}")
        with open(os.environ[env]) as file:
            os.environ[name] = file.read()


@cli.command()
@click.option('retries', '--retries', '-r', envvar='RETRIES', default=1)
def create(retries):
    """Back up the volumes and secrets of every stack labelled backupbot.backup."""
    pre_commands, post_commands, backup_paths, apps = get_backup_cmds()
    copy_secrets(apps)
    backup_paths.append(SECRET_PATH)
    run_commands(pre_commands)
    try:
        backup_volumes(backup_paths, apps, int(retries))
    finally:
        # backup_volumes exits the process once its retries are exhausted; run
        # the backupbot.backup.post-hook commands anyway so they are not skipped
        # after a failed backup.
        run_commands(post_commands)


def get_backup_cmds():
    """Collect the backup plan from the swarm services on this host.

    Every service whose ``backupbot.backup`` label is set to a value other than
    a falsy spelling ("false", "0", "no", "off") contributes:

    * its stack name (``com.docker.stack.namespace``) to the app list,
    * all volume directories ``VOLUME_PATH/<stack>_*`` to the backup paths,
    * its ``backupbot.backup.pre-hook`` / ``post-hook`` commands, keyed by the
      service's locally running container (skipped if none is running).

    Returns:
        tuple: ``(pre_commands, post_commands, backup_paths, backup_apps)``;
        the commands are ``{container: command}`` dicts, the others are lists.
    """
    client = docker.from_env()
    # Only swarm-managed containers carry the service-name label; containers
    # started otherwise (e.g. plain ``docker run``) would raise KeyError.
    container_by_service = {
        c.labels['com.docker.swarm.service.name']: c
        for c in client.containers.list()
        if 'com.docker.swarm.service.name' in c.labels}
    backup_paths = set()
    backup_apps = set()
    pre_commands = {}
    post_commands = {}
    for s in client.services.list():
        labels = s.attrs['Spec'].get('Labels') or {}
        backup = labels.get('backupbot.backup')
        # Label values are strings and bool("false") is True, so explicitly
        # treat the usual falsy spellings as "backup disabled".
        if not backup or backup.strip().lower() in ('false', '0', 'no', 'off'):
            continue
        stack_name = labels.get('com.docker.stack.namespace')
        if not stack_name:
            logger.error(
                f"Service {s.name} is not part of a stack, it can not be backed up")
            continue
        # All labelled stacks are backed up even if --host is given. Skipping
        # stacks with stack_name != SERVICE here would back up only that app,
        # but according to the original author this decreases restic performance.
        backup_apps.add(stack_name)
        backup_paths.update(Path(VOLUME_PATH).glob(f"{stack_name}_*"))
        if not (container := container_by_service.get(s.name)):
            logger.error(
                f"Container {s.name} is not running, hooks can not be executed")
            continue
        if prehook := labels.get('backupbot.backup.pre-hook'):
            pre_commands[container] = prehook
        if posthook := labels.get('backupbot.backup.post-hook'):
            post_commands[container] = posthook
    return pre_commands, post_commands, list(backup_paths), list(backup_apps)


def copy_secrets(apps):
    """Copy the swarm secrets of the given apps into SECRET_PATH.

    SECRET_PATH is wiped and recreated first. For every service whose
    ``com.docker.stack.namespace`` label is in ``apps``, each secret file
    mounted into its locally running container is copied to
    ``SECRET_PATH + SecretName``. Missing containers or secret files are logged
    and skipped.

    Args:
        apps: stack names whose secrets should be collected.
    """
    # TODO: check if it is deployed
    rmtree(SECRET_PATH, ignore_errors=True)
    os.mkdir(SECRET_PATH)
    client = docker.from_env()
    # Only swarm-managed containers carry the service-name label; others would
    # raise KeyError.
    container_by_service = {
        c.labels['com.docker.swarm.service.name']: c
        for c in client.containers.list()
        if 'com.docker.swarm.service.name' in c.labels}
    for s in client.services.list():
        spec = s.attrs['Spec']
        # Services created outside a stack have no namespace label; they can
        # not belong to ``apps`` and must not abort the whole run.
        app_name = (spec.get('Labels') or {}).get('com.docker.stack.namespace')
        if app_name not in apps:
            continue
        app_secs = spec['TaskTemplate']['ContainerSpec'].get('Secrets')
        if not app_secs:
            continue
        if not (container := container_by_service.get(s.name)):
            logger.error(
                f"Container {s.name} is not running, secrets can not be copied.")
            continue
        for sec in app_secs:
            src = f'/var/lib/docker/containers/{container.id}/mounts/secrets/{sec["SecretID"]}'
            if not Path(src).exists():
                logger.error(
                    f"For the secret {sec['SecretName']} the file {src} does not exist for {s.name}")
                continue
            copyfile(src, SECRET_PATH + sec['SecretName'])


def run_commands(commands):
    """Run backup hook commands inside their containers.

    A leading ``bash -c`` / ``sh -c`` and one pair of surrounding quotes are
    stripped from each command, which is then executed via ``bash -c`` with
    ``set -o pipefail`` so that a failing pipeline stage is reported. Failures
    are logged, not raised.

    Args:
        commands: mapping of docker container -> command string; empty
            commands are skipped.
    """
    for container, command in commands.items():
        if not command:
            continue
        # Remove bash/sh wrapping
        command = command.removeprefix('bash -c').removeprefix('sh -c').removeprefix(' ')
        # Remove quotes surrounding the command
        if len(command) >= 2 and command[0] == command[-1] and command[0] in ('"', "'"):
            command = command[1:-1]
        # Use bash's pipefail to return exit codes inside a pipe to prevent silent failure
        script = f"set -o pipefail;{command}"
        # Human-readable form for the logs (same text as before).
        display = f"bash -c '{script}'"
        logger.info(f"run command in {container.name}:")
        logger.info(display)
        # Pass an argv list: a plain string is re-split shell-style by the
        # Docker SDK, which mangles commands containing single quotes.
        result = container.exec_run(['bash', '-c', script])
        # Hook output is arbitrary bytes; don't crash the backup on bad UTF-8.
        output = result.output.decode(errors='replace')
        if result.exit_code:
            logger.error(
                f"Failed to run command {display} in {container.name}: {output}")
        else:
            logger.info(output)


def backup_volumes(backup_paths, apps, retries, dry_run=False):
    """Run ``restic backup`` on backup_paths, retrying after failures.

    Args:
        backup_paths: paths to back up.
        apps: app names, passed to restic as snapshot tags.
        retries: number of additional attempts after the first failure.
        dry_run: forwarded to restic.

    On success the restic result is logged at SUMMARY level. Once every attempt
    has failed, the process exits with status 1.
    """
    while True:
        try:
            result = restic.backup(backup_paths, dry_run=dry_run, tags=apps)
        except ResticFailedError as error:
            logger.error(
                f"Backup failed for {apps}. Could not Backup these paths: {backup_paths}")
            logger.exception(error)
            if retries <= 0:
                # sys.exit instead of the site-module ``exit`` helper, which is
                # meant for interactive use and is absent under ``python -S``.
                sys.exit(1)
            retries -= 1
        else:
            logger.summary("backup finished", extra=result)
            return


@cli.command()
@click.option('snapshot', '--snapshot', '-s', envvar='SNAPSHOT', default='latest')
@click.option('target', '--target', '-t', envvar='TARGET', default='/')
@click.option('noninteractive', '--noninteractive', envvar='NONINTERACTIVE', is_flag=True)
def restore(snapshot, target, noninteractive):
    """Restore Docker volumes from a restic snapshot.

    With --host only the volumes of that app are restored. Unless
    --noninteractive is given, the restore has to be confirmed by typing YES.
    """
    # Todo: recommend to shutdown the container
    service_paths = VOLUME_PATH
    if SERVICE:
        service_paths = service_paths + f'{SERVICE}_*'
    matches = restic.snapshots(snapshot_id=snapshot)
    # Check the lookup result, not the requested ID (which defaults to
    # 'latest'); an empty result would otherwise crash on matches[0] below.
    if not matches:
        logger.error(f"No Snapshots with ID {snapshot}")
        sys.exit(1)
    if not noninteractive:
        # NOTE(review): restic timestamps may carry nanosecond precision, which
        # datetime.fromisoformat only accepts from Python 3.11 on — confirm the
        # runtime version.
        snapshot_date = datetime.fromisoformat(matches[0]['time'])
        delta = datetime.now(tz=timezone.utc) - snapshot_date
        print(
            f"You are going to restore Snapshot {snapshot} of {service_paths} at {target}")
        print(f"This snapshot is {delta} old")
        print(
            f"THIS COMMAND WILL IRREVERSIBLY OVERWRITES {target}{service_paths.removeprefix('/')}")
        prompt = input("Type YES (uppercase) to continue: ")
        if prompt != 'YES':
            logger.error("Restore aborted")
            sys.exit(1)
    print(f"Restoring Snapshot {snapshot} of {service_paths} at {target}")
    # TODO: use tags if no snapshot is selected, to use a snapshot including SERVICE
    result = restic.restore(snapshot_id=snapshot,
                            include=service_paths, target_dir=target)
    logger.debug(result)


@cli.command()
def snapshots():
    # Print "<time> <id>" for each snapshot. With --host, only snapshots tagged
    # with that app are shown. Warn when nothing matched.
    found = False
    for snap in restic.snapshots():
        tags = snap.get('tags')
        if SERVICE and not (tags and SERVICE in tags):
            continue
        print(snap['time'], snap['id'])
        found = True
    if not found:
        err_msg = "No Snapshots found"
        if SERVICE:
            err_msg += f" for app {SERVICE.replace('_', '.')}"
        logger.warning(err_msg)


@cli.command()
@click.option('snapshot', '--snapshot', '-s', envvar='SNAPSHOT', default='latest')
@click.option('path', '--path', '-p', envvar='INCLUDE_PATH')
def ls(snapshot, path):
    # Print "<ctime>\t<path>" for every listed entry. Entries without a path
    # (presumably the snapshot record restic emits first) are skipped.
    for entry in list_files(snapshot, path):
        entry_path = entry.get('path')
        if entry_path:
            print(f"{entry['ctime']}\t{entry_path}")


def list_files(snapshot, path):
    """Return the parsed JSON output of ``restic ls`` for a snapshot.

    Args:
        snapshot: snapshot ID (or ``latest``).
        path: optional path inside the snapshot to restrict the listing to.

    With --host set, the snapshot is looked up among those tagged with the app.
    If restic reports that no matching snapshot exists, an error is logged and
    the process exits with status 1; other restic failures are re-raised.

    Returns:
        list of dicts, one per line of restic's output.
    """
    cmd = restic.cat.base_command() + ['ls']
    if SERVICE:
        cmd += ['--tag', SERVICE]
    cmd.append(snapshot)
    if path:
        cmd.append(path)
    try:
        output = restic.internal.command_executor.execute(cmd)
    except ResticFailedError as error:
        if 'no snapshot found' not in str(error):
            raise
        err_msg = f'There is no snapshot "{snapshot}"'
        if SERVICE:
            err_msg += f' for the app "{SERVICE}"'
        logger.error(err_msg)
        sys.exit(1)
    # restic prints one JSON object per line. JSON escapes newlines inside
    # strings, so splitting on '\n' is safe; splitting on a separator such as
    # '|' breaks for paths that contain that character.
    return [json.loads(line) for line in output.split('\n') if line.strip()]


@cli.command()
@click.option('snapshot', '--snapshot', '-s', envvar='SNAPSHOT', default='latest')
@click.option('path', '--path', '-p', envvar='INCLUDE_PATH')
@click.option('volumes', '--volumes', '-v', envvar='VOLUMES')
@click.option('secrets', '--secrets', '-c', is_flag=True, envvar='SECRETS')
def download(snapshot, path, volumes, secrets):
    """Download a path, volumes and/or secrets from a snapshot.

    Everything is packed into /tmp/backup.tar.gz. Without --path, --volumes or
    --secrets, both the volumes and the secrets of --host are downloaded.
    """
    file_dumps = []
    if not any([path, volumes, secrets]):
        volumes = secrets = True
    if path:
        file_dumps.append(_dump_path(snapshot, path))
    if volumes:
        if not SERVICE:
            logger.error("Please specify '--host' when using '--volumes'")
            sys.exit(1)
        file_dumps.extend(_dump_volumes(snapshot))
    if secrets:
        if not SERVICE:
            logger.error("Please specify '--host' when using '--secrets'")
            sys.exit(1)
        file_dumps.append(_dump_secrets(snapshot))
    _write_archive('/tmp/backup.tar.gz', file_dumps)


def _tar_member(name, data):
    """Pair ``data`` with a TarInfo named ``name`` that describes its size."""
    tarinfo = tarfile.TarInfo(name=name)
    tarinfo.size = len(data)
    return data, tarinfo


def _dump_path(snapshot, path):
    """Dump one file or directory from the snapshot; directories become .tar."""
    path = path.removesuffix('/')
    binary_output = dump(snapshot, path)
    filetypes = [f.get('type') for f in list_files(snapshot, path)
                 if f.get('path') == path]
    if not filetypes:
        logger.error(f"The path {path} was not found in snapshot '{snapshot}'")
        sys.exit(1)
    filename = Path(path).name
    if filetypes[0] == 'dir':
        # restic dumps a directory as a tar stream
        filename = filename + ".tar"
    return _tar_member(filename, binary_output)


def _dump_volumes(snapshot):
    """Dump each volume directory of SERVICE directly below VOLUME_PATH."""
    members = []
    volume_root = Path(VOLUME_PATH)
    prefix = f"{SERVICE}_"
    for f in list_files(snapshot, VOLUME_PATH)[1:]:
        volume = Path(f['path'])
        # Volumes are named "<app>_<volume>". Requiring the "_" keeps "app" from
        # also matching "app2_*". Taking only direct children of VOLUME_PATH
        # skips ancestors and directories nested inside a volume.
        if (f.get('type') == 'dir' and volume.parent == volume_root
                and volume.name.startswith(prefix)):
            binary_output = dump(snapshot, f['path'])
            members.append(_tar_member(f"{volume.name}.tar", binary_output))
    return members


def _dump_secrets(snapshot):
    """Collect SERVICE's secret files from SECRET_PATH into one JSON member."""
    secret_values = {}
    secret_root = Path(SECRET_PATH)
    prefix = f"{SERVICE}_"
    for f in list_files(snapshot, SECRET_PATH)[1:]:
        secret_file = Path(f['path'])
        if (f.get('type') == 'file' and secret_file.parent == secret_root
                and secret_file.name.startswith(prefix)):
            secret = dump(snapshot, f['path']).decode()
            secret_values[secret_file.name.removeprefix(prefix)] = secret
    return _tar_member(f"{SERVICE}.json", json.dumps(secret_values).encode())


def _write_archive(archive_path, file_dumps):
    """Write (bytes, TarInfo) pairs to a gzip-compressed tar and report its size."""
    with tarfile.open(archive_path, "w:gz") as tar:
        print(f"Writing files to {archive_path}...")
        for binary_output, tarinfo in file_dumps:
            tar.addfile(tarinfo, fileobj=io.BytesIO(binary_output))
    size = get_formatted_size(archive_path)
    print(
        f"Backup has been written to {archive_path} with a size of {size}")


def get_formatted_size(file_path):
    """Return the size of ``file_path`` as a human-readable string.

    Uses 1024-based units from Bytes up to TB, rounded to three decimals.
    Sizes of 1024 TB and more stay in TB (e.g. "2048.0 TB"); the previous loop
    divided once too often and reported them as a small TB value.
    """
    file_size = os.path.getsize(file_path)
    units = ['Bytes', 'KB', 'MB', 'GB', 'TB']
    # Only divide while a larger unit is still available.
    for unit in units[:-1]:
        if file_size < 1024:
            return f"{round(file_size, 3)} {unit}"
        file_size /= 1024
    return f"{round(file_size, 3)} {units[-1]}"


def dump(snapshot, path):
    """Return the raw bytes produced by ``restic dump`` for ``path``.

    Args:
        snapshot: snapshot ID (or ``latest``).
        path: file or directory inside the snapshot.

    With --host set, the snapshot is looked up among those tagged with the app.
    On failure restic's stderr is logged and the process exits with status 1.
    """
    cmd = restic.cat.base_command() + ['dump']
    if SERVICE:
        cmd += ['--tag', SERVICE]
    cmd += [snapshot, path]
    print(f"Dumping {path} from snapshot '{snapshot}'")
    proc = subprocess.run(cmd, capture_output=True)
    if proc.returncode:
        # stderr is bytes; decode it so the log shows text rather than b'...'.
        stderr = proc.stderr.decode(errors='replace').strip()
        logger.error(
            f"error while dumping {path} from snapshot '{snapshot}': {stderr}")
        sys.exit(1)
    return proc.stdout


if __name__ == "__main__":
    # Script entry point: hand control to the click command group.
    cli()