Commit 6b377ec5 authored by ibuler's avatar ibuler

[Update] 修改启动脚本

parent 44f8b978
...@@ -114,8 +114,6 @@ def hello(name, callback=None): ...@@ -114,8 +114,6 @@ def hello(name, callback=None):
# @after_app_shutdown_clean_periodic # @after_app_shutdown_clean_periodic
# @register_as_period_task(interval=30) # @register_as_period_task(interval=30)
def hello123(): def hello123():
p = subprocess.Popen('ls /tmp', shell=True)
print("{} Hello world".format(datetime.datetime.now().strftime("%H:%M:%S")))
return None return None
......
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
import os import os
import subprocess import subprocess
import threading import threading
import datetime
import logging import logging
import logging.handlers import logging.handlers
import time import time
import argparse import argparse
import sys import sys
import shutil
import signal import signal
from collections import defaultdict from collections import defaultdict
import daemon import daemon
...@@ -21,15 +23,15 @@ try: ...@@ -21,15 +23,15 @@ try:
from apps.jumpserver import const from apps.jumpserver import const
__version__ = const.VERSION __version__ = const.VERSION
except ImportError as e: except ImportError as e:
print("Not found __version__: {}".format(e)) logging.info("Not found __version__: {}".format(e))
print("Sys path: {}".format(sys.path)) logging.info("Sys path: {}".format(sys.path))
print("Python is: ") logging.info("Python is: ")
print(subprocess.call('which python', shell=True)) logging.info(subprocess.call('which python', shell=True))
__version__ = 'Unknown' __version__ = 'Unknown'
try: try:
import apps import apps
print("List apps: {}".format(os.listdir('apps'))) logging.info("List apps: {}".format(os.listdir('apps')))
print('apps is: {}'.format(apps)) logging.info('apps is: {}'.format(apps))
except: except:
pass pass
...@@ -37,8 +39,8 @@ try: ...@@ -37,8 +39,8 @@ try:
from apps.jumpserver.conf import load_user_config from apps.jumpserver.conf import load_user_config
CONFIG = load_user_config() CONFIG = load_user_config()
except ImportError as e: except ImportError as e:
print("Import error: {}".format(e)) logging.info("Import error: {}".format(e))
print("Could not find config file, `cp config_example.yml config.yml`") logging.info("Could not find config file, `cp config_example.yml config.yml`")
sys.exit(1) sys.exit(1)
os.environ["PYTHONIOENCODING"] = "UTF-8" os.environ["PYTHONIOENCODING"] = "UTF-8"
...@@ -54,11 +56,17 @@ LOG_LEVEL = CONFIG.LOG_LEVEL or 'INFO' ...@@ -54,11 +56,17 @@ LOG_LEVEL = CONFIG.LOG_LEVEL or 'INFO'
START_TIMEOUT = 40 START_TIMEOUT = 40
WORKERS = 4 WORKERS = 4
DAEMON = False DAEMON = False
LOG_KEEP_DAYS = 7
logging.basicConfig(
format='%(asctime)s %(message)s', level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S'
)
EXIT_EVENT = threading.Event() EXIT_EVENT = threading.Event()
LOCK = threading.Lock() LOCK = threading.Lock()
daemon_pid_file = '' files_preserve = []
logger = logging.getLogger()
try: try:
os.makedirs(os.path.join(BASE_DIR, "data", "static")) os.makedirs(os.path.join(BASE_DIR, "data", "static"))
...@@ -67,83 +75,32 @@ except: ...@@ -67,83 +75,32 @@ except:
pass pass
class LogPipe(threading.Thread):
def __init__(self, name, file_path, to_stdout=False):
"""Setup the object with a logger and a loglevel
and start the thread
"""
threading.Thread.__init__(self)
self.daemon = False
self.name = name
self.file_path = file_path
self.to_stdout = to_stdout
self.fdRead, self.fdWrite = os.pipe()
self.pipeReader = os.fdopen(self.fdRead)
self.logger = self.init_logger()
self.start()
def init_logger(self):
_logger = logging.getLogger(self.name)
_logger.setLevel(logging.INFO)
_formatter = logging.Formatter('%(message)s')
_handler = logging.handlers.RotatingFileHandler(
self.file_path, mode='a', maxBytes=5*1024*1024, backupCount=5
)
_handler.setFormatter(_formatter)
_handler.setLevel(logging.INFO)
_logger.addHandler(_handler)
if self.to_stdout:
_console = logging.StreamHandler()
_console.setLevel(logging.INFO)
_console.setFormatter(_formatter)
_logger.addHandler(_console)
return _logger
def fileno(self):
"""Return the write file descriptor of the pipe
"""
return self.fdWrite
def run(self):
"""Run the thread, logging everything.
"""
for line in iter(self.pipeReader.readline, ''):
self.logger.info(line.strip('\n'))
self.pipeReader.close()
def close(self):
"""Close the write end of the pipe.
"""
os.close(self.fdWrite)
def check_database_connection(): def check_database_connection():
os.chdir(os.path.join(BASE_DIR, 'apps')) os.chdir(os.path.join(BASE_DIR, 'apps'))
for i in range(60): for i in range(60):
print("Check database connection ...") logging.info("Check database connection ...")
code = subprocess.call("python manage.py showmigrations users ", shell=True) code = subprocess.call("python manage.py showmigrations users ", shell=True)
if code == 0: if code == 0:
print("Database connect success") logging.info("Database connect success")
return return
time.sleep(1) time.sleep(1)
print("Connection database failed, exist") logging.info("Connection database failed, exist")
sys.exit(10) sys.exit(10)
def make_migrations(): def make_migrations():
print("Check database structure change ...") logging.info("Check database structure change ...")
os.chdir(os.path.join(BASE_DIR, 'apps')) os.chdir(os.path.join(BASE_DIR, 'apps'))
print("Migrate model change to database ...") logging.info("Migrate model change to database ...")
subprocess.call('python3 manage.py migrate', shell=True) subprocess.call('python3 manage.py migrate', shell=True)
def collect_static(): def collect_static():
print("Collect static files") logging.info("Collect static files")
os.chdir(os.path.join(BASE_DIR, 'apps')) os.chdir(os.path.join(BASE_DIR, 'apps'))
command = 'python3 manage.py collectstatic --no-input -c &> /dev/null ' \ command = 'python3 manage.py collectstatic --no-input -c &> /dev/null '
'&& echo "Collect static file done"'
subprocess.call(command, shell=True) subprocess.call(command, shell=True)
logging.info("Collect static file done")
def prepare(): def prepare():
...@@ -213,8 +170,6 @@ def parse_service(s): ...@@ -213,8 +170,6 @@ def parse_service(s):
return ['daphne'] return ['daphne']
elif s == "task": elif s == "task":
return ["celery_ansible", "celery_default", "beat"] return ["celery_ansible", "celery_default", "beat"]
elif s == 'gunicorn':
return ['gunicorn', 'flower']
elif s == "celery": elif s == "celery":
return ["celery_ansible", "celery_default"] return ["celery_ansible", "celery_default"]
elif "," in s: elif "," in s:
...@@ -326,40 +281,72 @@ processes = {} ...@@ -326,40 +281,72 @@ processes = {}
def watch_services(): def watch_services():
max_retry = 3 max_retry = 3
signal.signal(signal.SIGTERM, lambda x, y: clean_up())
services_retry = defaultdict(int) services_retry = defaultdict(int)
stopped_services = {} stopped_services = {}
def check_services(): def check_services():
now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
for s, p in processes.items(): for s, p in processes.items():
print("{} Check service status: {} -> ".format(now, s), end='')
try: try:
p.wait(timeout=1) p.wait(timeout=1)
stopped_services[s] = ''
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
pass
if p.returncode is not None:
stopped_services[s] = ''
print("stopped")
else:
print("running")
stopped_services.pop(s, None) stopped_services.pop(s, None)
services_retry.pop(s, None) services_retry.pop(s, None)
continue
def retry_start_stopped_services(): def retry_start_stopped_services():
for s in stopped_services: for s in stopped_services:
if services_retry[s] > max_retry: if services_retry[s] > max_retry:
print("\nService start failed, exit: ", s) logging.info("Service start failed, exit: ", s)
EXIT_EVENT.set() EXIT_EVENT.set()
break break
print("\n> Find {} stopped, retry {}".format(
s, services_retry[s] + 1)
)
p = start_service(s) p = start_service(s)
logging.info("> Find {} stopped, retry {}, {}".format(
s, services_retry[s] + 1, p.pid)
)
processes[s] = p processes[s] = p
services_retry[s] += 1 services_retry[s] += 1
def rotate_log_if_need():
now = datetime.datetime.now()
tm = now.strftime('%H:%M')
if tm != '23:59':
return
suffix = now.strftime('%Y-%m-%d')
for s in processes:
log_path = get_log_file_path(s)
log_dir = os.path.dirname(log_path)
filename = os.path.basename(log_path)
pre_log_dir = os.path.join(log_dir, suffix)
if not os.path.exists(pre_log_dir):
os.mkdir(pre_log_dir)
pre_log_path = os.path.join(pre_log_dir, filename)
if os.path.isfile(log_path) and not os.path.isfile(pre_log_path):
logging.info("Rotate log file: {} => {}".format(log_path, pre_log_path))
shutil.copy(log_path, pre_log_path)
with open(log_path, 'w') as f:
pass
some_days_ago = now - datetime.timedelta(days=LOG_KEEP_DAYS)
days_ago_dir = os.path.join(LOG_DIR, some_days_ago.strftime('%Y-%m-%d'))
if os.path.exists(days_ago_dir):
logger.info("Remove old log: {}".format(days_ago_dir))
shutil.rmtree(days_ago_dir, ignore_errors=True)
while not EXIT_EVENT.is_set(): while not EXIT_EVENT.is_set():
try: try:
with LOCK: with LOCK:
check_services() check_services()
retry_start_stopped_services() retry_start_stopped_services()
time.sleep(10) rotate_log_if_need()
time.sleep(30)
except KeyboardInterrupt: except KeyboardInterrupt:
time.sleep(1) time.sleep(1)
break break
...@@ -383,13 +370,11 @@ def start_service(s): ...@@ -383,13 +370,11 @@ def start_service(s):
os.unlink(pid_file) os.unlink(pid_file)
cmd = kwargs.pop('cmd') cmd = kwargs.pop('cmd')
to_stdout = False log_file_path = get_log_file_path(s)
if not DAEMON: log_file_f = open(log_file_path, 'a')
to_stdout = True files_preserve.append(log_file_f)
log_file = get_log_file_path(s) kwargs['stderr'] = log_file_f
_logger = LogPipe(s, log_file, to_stdout=to_stdout) kwargs['stdout'] = log_file_f
stderr = stdout = _logger
kwargs.update({"stderr": stderr, "stdout": stdout})
p = subprocess.Popen(cmd, **kwargs) p = subprocess.Popen(cmd, **kwargs)
with open(pid_file, 'w') as f: with open(pid_file, 'w') as f:
f.write(str(p.pid)) f.write(str(p.pid))
...@@ -397,8 +382,8 @@ def start_service(s): ...@@ -397,8 +382,8 @@ def start_service(s):
def start_services_and_watch(s): def start_services_and_watch(s):
print(time.ctime()) logging.info(time.ctime())
print('Jumpserver version {}, more see https://www.jumpserver.org'.format( logging.info('Jumpserver version {}, more see https://www.jumpserver.org'.format(
__version__) __version__)
) )
...@@ -415,34 +400,46 @@ def start_services_and_watch(s): ...@@ -415,34 +400,46 @@ def start_services_and_watch(s):
watch_services() watch_services()
else: else:
show_service_status(s) show_service_status(s)
global daemon_pid_file context = get_daemon_context()
daemon_pid_file = get_pid_file_path('jms')
context = daemon.DaemonContext(
pidfile=pidfile.TimeoutPIDLockFile(daemon_pid_file),
signal_map={
signal.SIGTERM: clean_up,
signal.SIGHUP: 'terminate',
},
)
with context: with context:
watch_services() watch_services()
def stop_service(s, sig=15): def get_daemon_context():
services_set = parse_service(s) daemon_pid_file = get_pid_file_path('jms')
daemon_log_f = open(get_log_file_path('jms'), 'a')
files_preserve.append(daemon_log_f)
context = daemon.DaemonContext(
pidfile=pidfile.TimeoutPIDLockFile(daemon_pid_file),
stdout=daemon_log_f,
stderr=daemon_log_f,
files_preserve=files_preserve,
detach_process=True,
)
return context
def stop_service(srv, sig=15):
services_set = parse_service(srv)
for s in services_set: for s in services_set:
if not is_running(s): if not is_running(s):
show_service_status(s) show_service_status(s)
continue continue
print("Stop service: {}".format(s)) logging.info("Stop service: {}".format(s))
pid = get_pid(s) pid = get_pid(s)
os.kill(pid, sig) os.kill(pid, sig)
with LOCK: with LOCK:
processes.pop(s, None) processes.pop(s, None)
if s == "all": if srv == "all":
pid = get_pid('jms') stop_daemon_service()
os.kill(pid, sig)
def stop_daemon_service():
pid = get_pid('jms')
logging.info("Daemon pid is: {}".format(pid))
if pid:
os.kill(pid, 15)
def stop_multi_services(services): def stop_multi_services(services):
...@@ -468,9 +465,9 @@ def show_service_status(s): ...@@ -468,9 +465,9 @@ def show_service_status(s):
for ns in services_set: for ns in services_set:
if is_running(ns): if is_running(ns):
pid = get_pid(ns) pid = get_pid(ns)
print("{} is running: {}".format(ns, pid)) logging.info("{} is running: {}".format(ns, pid))
else: else:
print("{} is stopped".format(ns)) logging.info("{} is stopped".format(ns))
if __name__ == '__main__': if __name__ == '__main__':
...@@ -490,7 +487,7 @@ if __name__ == '__main__': ...@@ -490,7 +487,7 @@ if __name__ == '__main__':
) )
parser.add_argument( parser.add_argument(
"service", type=str, default="all", nargs="?", "service", type=str, default="all", nargs="?",
choices=("all", "web", "task", "gunicorn", "celery", "beat", "celery,beat", "flower"), choices=("all", "web", "task", "gunicorn", "celery", "beat", "celery,beat", "flower", "ws"),
help="The service to start", help="The service to start",
) )
parser.add_argument('-d', '--daemon', nargs="?", const=1) parser.add_argument('-d', '--daemon', nargs="?", const=1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment