#!/usr/bin/env python3
# coding: utf-8

import os
import subprocess
import threading
import logging
import logging.handlers
import time
import argparse
import sys
import signal
from collections import defaultdict
import daemon
from daemon import pidfile

# Absolute directory containing this script. It is put first on sys.path so
# that the project-local `apps` package can be imported whatever the cwd is.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, BASE_DIR)

try:
    from apps.jumpserver import const
    __version__ = const.VERSION
except ImportError as e:
    # The version is informational only: print diagnostics and carry on with
    # a placeholder instead of aborting.
    print("Not found __version__: {}".format(e))
    print("Sys path: {}".format(sys.path))
    print("Python is: ")
    # Report the interpreter that is actually running this script; shelling
    # out to `which python` may name a different one and printed its exit code.
    print(sys.executable)
    __version__ = 'Unknown'
    try:
        import apps
        # List the package directory by absolute path so the diagnostic does
        # not depend on the current working directory.
        print("List apps: {}".format(os.listdir(os.path.join(BASE_DIR, 'apps'))))
        print('apps is: {}'.format(apps))
    except Exception:
        # Best-effort diagnostics only; never fail because of them.
        pass

# Load the user configuration. Every setting below is read from CONFIG, so an
# import failure prints a hint and aborts with exit status 1.
try:
    from apps.jumpserver.conf import load_user_config
    CONFIG = load_user_config()
except ImportError as e:
    print("Import error: {}".format(e))
    print("Could not find config file, `cp config_example.yml config.yml`")
    sys.exit(1)

os.environ["PYTHONIOENCODING"] = "UTF-8"
# Directories derived from the location of this script.
APPS_DIR = os.path.join(BASE_DIR, 'apps')
LOG_DIR = os.path.join(BASE_DIR, 'logs')
# NOTE(review): TMP_DIR is not referenced in the visible part of this file.
TMP_DIR = os.path.join(BASE_DIR, 'tmp')
# Settings taken from CONFIG; the value after `or` is used when the
# configured value is falsy.
HTTP_HOST = CONFIG.HTTP_BIND_HOST or '127.0.0.1'
HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080
WS_PORT = CONFIG.WS_LISTEN_PORT or 8082
DEBUG = CONFIG.DEBUG or False
LOG_LEVEL = CONFIG.LOG_LEVEL or 'INFO'

# NOTE(review): START_TIMEOUT is not referenced in the visible part of this file.
START_TIMEOUT = 40
# Defaults that the command-line entry point overrides (-w / -d options).
WORKERS = 4
DAEMON = False

# Set when the watcher should stop looping and shut the services down.
EXIT_EVENT = threading.Event()
# Held while the `processes` registry is inspected or modified.
LOCK = threading.Lock()
# Path of the watcher daemon's pid file; assigned in start_services_and_watch().
daemon_pid_file = ''


# Make sure the data directories exist. Each directory is created on its own
# so that an already existing "static" directory cannot prevent "media" from
# being created. Failures (e.g. missing permissions) stay best-effort.
for _data_sub_dir in ("static", "media"):
    try:
        os.makedirs(os.path.join(BASE_DIR, "data", _data_sub_dir), exist_ok=True)
    except OSError:
        pass


class LogPipe(threading.Thread):
    """Thread that forwards everything written to a pipe into a logger.

    The object exposes ``fileno()`` so it can be passed as ``stdout`` /
    ``stderr`` to ``subprocess.Popen``. Lines written to the pipe are read by
    this thread and logged to a rotating file (and optionally the console).
    """

    def __init__(self, name, file_path, to_stdout=False):
        """Create the pipe and the logger, then start the reader thread.

        :param name: logger name; also used as the thread name
        :param file_path: path of the rotating log file
        :param to_stdout: when true, lines are also sent to a console handler
        """
        threading.Thread.__init__(self)
        self.daemon = False
        self.name = name
        self.file_path = file_path
        self.to_stdout = to_stdout
        self.fdRead, self.fdWrite = os.pipe()
        self.pipeReader = os.fdopen(self.fdRead)
        self.logger = self.init_logger()
        self.start()

    def init_logger(self):
        """Return the logger named after this pipe, configuring it once.

        ``logging.getLogger`` returns the same object for the same name, so a
        second LogPipe for the same service (e.g. after a restart) must not
        add its handlers again; otherwise every line is written repeatedly.
        """
        _logger = logging.getLogger(self.name)
        _logger.setLevel(logging.INFO)
        if _logger.handlers:
            # Already configured by an earlier LogPipe with this name.
            return _logger

        _formatter = logging.Formatter('%(message)s')
        _handler = logging.handlers.RotatingFileHandler(
            self.file_path, mode='a', maxBytes=5*1024*1024, backupCount=5
        )
        _handler.setFormatter(_formatter)
        _handler.setLevel(logging.INFO)
        _logger.addHandler(_handler)
        if self.to_stdout:
            _console = logging.StreamHandler()
            _console.setLevel(logging.INFO)
            _console.setFormatter(_formatter)
            _logger.addHandler(_console)
        return _logger

    def fileno(self):
        """Return the write file descriptor of the pipe."""
        return self.fdWrite

    def run(self):
        """Log every line read from the pipe until end-of-file is reached."""
        # readline() returns '' only once every write end has been closed.
        for line in iter(self.pipeReader.readline, ''):
            self.logger.info(line.strip('\n'))
        self.pipeReader.close()

    def close(self):
        """Close the write end of the pipe."""
        os.close(self.fdWrite)


def check_database_connection():
    """Wait until ``manage.py showmigrations users`` succeeds.

    The command is tried up to 60 times, one second apart. The function
    returns as soon as it exits with status 0; if it never does, the process
    exits with status 10. The working directory is changed to the apps dir.
    """
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    # Use the interpreter running this script: a bare `python` on PATH may be
    # missing or may be one without the project's dependencies.
    command = [sys.executable or 'python3', 'manage.py', 'showmigrations', 'users']
    for _ in range(60):
        print("Check database connection ...")
        if subprocess.call(command) == 0:
            print("Database connect success")
            return
        time.sleep(1)
    print("Connection database failed, exit")
    sys.exit(10)


def make_migrations():
    """Apply pending database migrations with ``manage.py migrate``.

    The working directory is changed to the apps dir. A failing migration
    does not abort start-up, but it is reported instead of being ignored.
    """
    print("Check database structure change ...")
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    print("Migrate model change to database ...")
    # Run with the interpreter executing this script so the same environment
    # (and installed packages) is used.
    code = subprocess.call([sys.executable or 'python3', 'manage.py', 'migrate'])
    if code != 0:
        print("Migrate failed with exit code {}".format(code))


def collect_static():
    """Run ``manage.py collectstatic`` quietly and report success.

    The working directory is changed to the apps dir. Output of the command
    is discarded; the "done" message is printed only when it exits with 0.
    """
    print("Collect static files")
    os.chdir(os.path.join(BASE_DIR, 'apps'))
    # The previous shell string relied on `&>`, which is a bash extension.
    # With shell=True the command runs under /bin/sh, where e.g. dash parses
    # it as "run in background" and prints the done message immediately.
    # Redirecting through subprocess avoids the shell altogether.
    command = [
        sys.executable or 'python3', 'manage.py',
        'collectstatic', '--no-input', '-c',
    ]
    code = subprocess.call(
        command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )
    if code == 0:
        print("Collect static file done")


def prepare():
    """Run the pre-start steps in order: database check, migrate, static files."""
    steps = (check_database_connection, make_migrations, collect_static)
    for step in steps:
        step()


def check_pid(pid):
    """Check for the existence of a unix pid.

    Signal 0 performs the existence/permission check without delivering a
    signal.

    :param pid: process id to probe
    :return: True if a process with that pid exists, False otherwise
    """
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists but belongs to another user.
        return True
    except OSError:
        return False
    return True


def get_pid_file_path(s):
    """Return the pid file path under /tmp for the service named ``s``."""
    file_name = '{name}.pid'.format(name=s)
    return os.path.join('/tmp', file_name)


def get_log_file_path(s):
    """Return the log file path inside LOG_DIR for the service named ``s``."""
    file_name = '{name}.log'.format(name=s)
    return os.path.join(LOG_DIR, file_name)


def get_pid_from_file(path):
    """Read an integer pid from ``path``.

    :return: the pid, or 0 when the file is missing or does not hold an integer
    """
    if not os.path.isfile(path):
        return 0
    with open(path) as pid_fp:
        try:
            content = pid_fp.read()
            return int(content.strip())
        except ValueError:
            return 0


def get_pid(s):
    """Return the pid recorded in the pid file of service ``s`` (0 if none)."""
    return get_pid_from_file(get_pid_file_path(s))


def is_running(s, unlink=True):
    """Tell whether the process recorded for service ``s`` is alive.

    :param s: service name
    :param unlink: remove the pid file when it names a dead process
    :return: True when the recorded pid exists, False otherwise
    """
    path = get_pid_file_path(s)
    if not os.path.isfile(path):
        return False

    recorded = get_pid(s)
    # An unreadable pid file counts as "not running" and is left in place.
    if recorded == 0:
        return False
    if check_pid(recorded):
        return True

    # Stale pid file: the recorded process is gone.
    if unlink:
        os.unlink(path)
    return False


def parse_service(s):
    """Expand a service selector into a list of concrete service names.

    Known group names ('all', 'web', 'ws', 'task', 'gunicorn', 'celery') are
    expanded; a comma-separated selector is expanded item by item; anything
    else is returned as a single-element list.

    :param s: selector string given on the command line
    :return: list of service names, in a deterministic order and without
        duplicates (a new list on every call)
    """
    groups = {
        'all': [
            'gunicorn', 'celery_ansible', 'celery_default',
            'beat', 'flower', 'daphne',
        ],
        'web': ['gunicorn', 'flower', 'daphne'],
        'ws': ['daphne'],
        'task': ['celery_ansible', 'celery_default', 'beat'],
        'gunicorn': ['gunicorn', 'flower'],
        'celery': ['celery_ansible', 'celery_default'],
    }
    if s in groups:
        return list(groups[s])
    if "," in s:
        # Keep first-seen order: a set would make the start order depend on
        # string hashing and return a different type than the other branches.
        services = []
        for item in s.split(','):
            for name in parse_service(item):
                if name not in services:
                    services.append(name)
        return services
    return [s]


def get_start_gunicorn_kwargs():
    """Run the pre-start steps and build the Popen arguments for gunicorn.

    :return: dict with the command line ('cmd') and working directory ('cwd')
    """
    print("\n- Start Gunicorn WSGI HTTP Server")
    prepare()

    access_log_format = '%(h)s %(t)s "%(r)s" %(s)s %(b)s '
    listen_address = '{}:{}'.format(HTTP_HOST, HTTP_PORT)

    cmd = ['gunicorn', 'jumpserver.wsgi']
    cmd += ['-b', listen_address]
    cmd += ['-k', 'gthread']
    cmd += ['--threads', '10']
    cmd += ['-w', str(WORKERS)]
    cmd += ['--max-requests', '4096']
    cmd += ['--access-logformat', access_log_format]
    cmd += ['--access-logfile', '-']
    # Auto-reload on code change only in debug mode.
    if DEBUG:
        cmd += ['--reload']

    return {'cmd': cmd, 'cwd': APPS_DIR}


def get_start_daphne_kwargs():
    """Build the Popen arguments for the daphne ASGI server.

    :return: dict with the command line ('cmd') and working directory ('cwd')
    """
    print("\n- Start Daphne ASGI WS Server")
    cmd = ['daphne', 'jumpserver.asgi:application']
    cmd += ['-b', HTTP_HOST]
    cmd += ['-p', str(WS_PORT)]
    return {'cmd': cmd, 'cwd': APPS_DIR}


def get_start_celery_ansible_kwargs():
    """Build the Popen arguments for the celery worker of the 'ansible' queue."""
    print("\n- Start Celery as Distributed Task Queue: Ansible")
    worker_kwargs = get_start_worker_kwargs('ansible', 4)
    return worker_kwargs


def get_start_celery_default_kwargs():
    """Build the Popen arguments for the celery worker of the 'celery' queue."""
    print("\n- Start Celery as Distributed Task Queue: Celery")
    worker_kwargs = get_start_worker_kwargs('celery', 2)
    return worker_kwargs


def get_start_worker_kwargs(queue, num):
    """Build the Popen arguments for a celery worker.

    :param queue: queue name to consume; also used in the worker node name
    :param num: value passed to celery's ``-c`` option
    :return: dict with the command line ('cmd') and working directory ('cwd')
    """
    # TODO: these variables must be set, otherwise no ansible result is returned.
    env_defaults = [
        ('PYTHONOPTIMIZE', '1'),
        ('ANSIBLE_FORCE_COLOR', 'True'),
    ]
    if os.getuid() == 0:
        env_defaults.append(('C_FORCE_ROOT', '1'))
    # setdefault keeps any value already present in the environment.
    for key, value in env_defaults:
        os.environ.setdefault(key, value)

    node_name = '{}@%h'.format(queue)
    cmd = ['celery', 'worker']
    cmd += ['-A', 'ops']
    cmd += ['-l', 'INFO']
    cmd += ['-c', str(num)]
    cmd += ['-Q', queue]
    cmd += ['-n', node_name]
    return {"cmd": cmd, "cwd": APPS_DIR}


def get_start_flower_kwargs():
    """Build the Popen arguments for the flower task monitor.

    :return: dict with the command line ('cmd') and working directory ('cwd')
    """
    print("\n- Start Flower as Task Monitor")
    running_as_root = os.getuid() == 0
    if running_as_root:
        os.environ.setdefault('C_FORCE_ROOT', '1')

    flower_options = [
        '--url_prefix=flower',
        '--auto_refresh=False',
        '--max_tasks=1000',
        '--tasks_columns=uuid,name,args,state,received,started,runtime,worker',
    ]
    cmd = ['celery', 'flower', '-A', 'ops', '-l', 'INFO'] + flower_options
    return {"cmd": cmd, "cwd": APPS_DIR}


def get_start_beat_kwargs():
    """Build the Popen arguments for the celery beat scheduler.

    :return: dict with the command line ('cmd') and working directory ('cwd')
    """
    print("\n- Start Beat as Periodic Task Scheduler")
    # setdefault keeps any value already present in the environment.
    os.environ.setdefault('PYTHONOPTIMIZE', '1')
    running_as_root = os.getuid() == 0
    if running_as_root:
        os.environ.setdefault('C_FORCE_ROOT', '1')

    cmd = ['celery', 'beat']
    cmd += ['-A', 'ops']
    cmd += ['-l', 'INFO']
    cmd += ['--scheduler', "django_celery_beat.schedulers:DatabaseScheduler"]
    cmd += ['--max-interval', '60']
    return {"cmd": cmd, 'cwd': APPS_DIR}


# Registry of started services: service name -> subprocess.Popen object.
processes = {}


def watch_services():
    """Supervise the started services until an exit is requested.

    Every cycle the registered processes are polled; services found stopped
    are started again. A service whose retry counter exceeds ``max_retry``
    makes the watcher give up and exit. On exit (SIGTERM, Ctrl-C or too many
    failures) all services are stopped through ``clean_up()``.
    """
    max_retry = 3
    services_retry = defaultdict(int)
    stopped_services = {}

    def request_exit(signum, frame):
        # Only flag the request here. The handler runs in the main thread,
        # possibly while LOCK is held and `processes` is being iterated;
        # calling clean_up() from it could deadlock on the non-reentrant lock
        # or change the dict during iteration. The loop below does the cleanup.
        EXIT_EVENT.set()

    signal.signal(signal.SIGTERM, request_exit)

    def check_services():
        # Record which services have exited; a service that is still running
        # is cleared from the stopped list and gets its retry counter reset.
        for s, p in processes.items():
            try:
                p.wait(timeout=1)
                stopped_services[s] = ''
            except subprocess.TimeoutExpired:
                stopped_services.pop(s, None)
                services_retry.pop(s, None)
                continue

    def retry_start_stopped_services():
        for s in stopped_services:
            if services_retry[s] > max_retry:
                print("\nService start failed, exit: ", s)
                EXIT_EVENT.set()
                break

            print("\n> Find {} stopped, retry {}".format(
                s, services_retry[s] + 1)
            )
            p = start_service(s)
            processes[s] = p
            services_retry[s] += 1

    while not EXIT_EVENT.is_set():
        try:
            with LOCK:
                check_services()
            # Do not restart anything once an exit has been requested.
            if EXIT_EVENT.is_set():
                break
            retry_start_stopped_services()
            # Sleep up to 10 seconds, but wake up at once when exit is requested.
            EXIT_EVENT.wait(10)
        except KeyboardInterrupt:
            time.sleep(1)
            break
    clean_up()


def start_service(s):
    """Start the service named ``s`` and record its pid.

    The service's output (stdout and stderr) is routed through a LogPipe to
    its log file, and also to the console when not running as a daemon.

    :param s: concrete service name (not a group selector)
    :return: the ``subprocess.Popen`` object of the started process
    :raises ValueError: if ``s`` is not a known service name
    """
    services_kwargs = {
        "gunicorn": get_start_gunicorn_kwargs,
        "celery_ansible": get_start_celery_ansible_kwargs,
        "celery_default": get_start_celery_default_kwargs,
        "beat": get_start_beat_kwargs,
        "flower": get_start_flower_kwargs,
        "daphne": get_start_daphne_kwargs,
    }

    get_kwargs = services_kwargs.get(s)
    if get_kwargs is None:
        raise ValueError("Unknown service: {}".format(s))
    kwargs = get_kwargs()
    pid_file = get_pid_file_path(s)

    # Drop a pid file left over from a previous run.
    if os.path.isfile(pid_file):
        os.unlink(pid_file)
    cmd = kwargs.pop('cmd')

    log_file = get_log_file_path(s)
    _logger = LogPipe(s, log_file, to_stdout=not DAEMON)
    kwargs.update({"stderr": _logger, "stdout": _logger})
    try:
        p = subprocess.Popen(cmd, **kwargs)
    finally:
        # The child holds its own copy of the pipe's write end. Close ours so
        # the LogPipe thread sees end-of-file when the child exits, instead
        # of leaking a thread and two descriptors on every (re)start.
        _logger.close()
    with open(pid_file, 'w') as f:
        f.write(str(p.pid))
    return p


def start_services_and_watch(s):
    """Start every service selected by ``s`` and then supervise them.

    Services already running are only reported. In daemon mode the watcher
    runs inside a ``DaemonContext`` with its own pid file; otherwise it runs
    in the foreground.

    :param s: service selector as accepted by ``parse_service``
    """
    global daemon_pid_file

    print(time.ctime())
    banner = 'Jumpserver version {}, more see https://www.jumpserver.org'
    print(banner.format(__version__))

    for name in parse_service(s):
        if is_running(name):
            show_service_status(name)
            continue
        proc = start_service(name)
        time.sleep(2)
        processes[name] = proc

    if not DAEMON:
        watch_services()
        return

    show_service_status(s)
    daemon_pid_file = get_pid_file_path('jms')
    lock_file = pidfile.TimeoutPIDLockFile(daemon_pid_file)
    handlers = {
        signal.SIGTERM: clean_up,
        signal.SIGHUP: 'terminate',
    }
    context = daemon.DaemonContext(pidfile=lock_file, signal_map=handlers)
    with context:
        watch_services()


def stop_service(s, sig=15):
    """Send ``sig`` to every running service selected by ``s``.

    Services that are not running are only reported. When the selector is
    "all", the watcher daemon (pid file 'jms') is signalled as well, so it
    does not start the services again.

    :param s: service selector as accepted by ``parse_service``
    :param sig: signal number to send (default 15, SIGTERM)
    """
    # Use a separate loop variable: reusing `s` would overwrite the selector
    # and make the "all" check below unreachable.
    for name in parse_service(s):
        if not is_running(name):
            show_service_status(name)
            continue
        print("Stop service: {}".format(name))
        pid = get_pid(name)
        os.kill(pid, sig)
        with LOCK:
            processes.pop(name, None)

    if s == "all":
        pid = get_pid('jms')
        # pid 0 means "no pid file"; os.kill(0, ...) would signal the whole
        # process group of the caller. Also never signal ourselves here.
        if pid and pid != os.getpid() and check_pid(pid):
            os.kill(pid, sig)


def stop_multi_services(services):
    """Send signal 9 (SIGKILL) to each service selector in ``services``."""
    for selector in services:
        stop_service(selector, 9)


def stop_service_force(s):
    """Stop the services selected by ``s`` with signal 9 (SIGKILL)."""
    kill_signal = 9
    stop_service(s, kill_signal)


def clean_up(signum=None, frame=None):
    """Request exit and stop every registered service, waiting for each.

    The optional arguments make the function usable as a signal handler: it
    is registered in the daemon's signal map, where it is invoked with the
    signal number and the current stack frame. Both values are ignored.
    """
    if not EXIT_EVENT.is_set():
        EXIT_EVENT.set()
    # Iterate over a snapshot: stop_service() removes entries from `processes`.
    processes_dump = dict(processes)
    for s1, p1 in processes_dump.items():
        stop_service(s1)
        p1.wait()


def show_service_status(s):
    """Print, for each service selected by ``s``, whether it is running."""
    for name in parse_service(s):
        if not is_running(name):
            print("{} is stopped".format(name))
            continue
        print("{} is running: {}".format(name, get_pid(name)))


if __name__ == '__main__':
    # Command-line interface: <action> [service] [-d] [-w N]
    cli = argparse.ArgumentParser(
        description="""
        Jumpserver service control tools;

        Example: \r\n

        %(prog)s start all -d;
        """
    )
    cli.add_argument(
        'action', type=str, help="Action to run",
        choices=("start", "stop", "restart", "status"),
    )
    cli.add_argument(
        "service", type=str, nargs="?", default="all",
        help="The service to start",
        choices=("all", "web", "task", "gunicorn", "celery", "beat", "celery,beat", "flower"),
    )
    cli.add_argument('-d', '--daemon', nargs="?", const=1)
    cli.add_argument('-w', '--worker', type=int, nargs="?", const=4)
    args = cli.parse_args()

    # Command-line options override the module-level defaults.
    DAEMON = True if args.daemon else DAEMON
    WORKERS = args.worker if args.worker else WORKERS

    action, srv = args.action, args.service

    if action == "status":
        show_service_status(srv)
    elif action == "stop":
        stop_service(srv)
    elif action == "start":
        start_services_and_watch(srv)
        os._exit(0)
    elif action == "restart":
        # A restart always brings the services back up in daemon mode.
        DAEMON = True
        stop_service(srv)
        time.sleep(5)
        start_services_and_watch(srv)
    else:
        show_service_status(srv)