Commit 5ec6eae5 authored by 广宏伟's avatar 广宏伟

Merged in dev (pull request #58)

Dev
parents 2467b4e2 a2d2e325
...@@ -17,7 +17,7 @@ from .sshd import SSHServer ...@@ -17,7 +17,7 @@ from .sshd import SSHServer
from .httpd import HttpServer from .httpd import HttpServer
from .logger import create_logger from .logger import create_logger
from .tasks import TaskHandler from .tasks import TaskHandler
from .recorder import get_command_recorder_class, ServerReplayRecorder from .recorder import ReplayRecorder, CommandRecorder
from .utils import get_logger, register_app, register_service from .utils import get_logger, register_app, register_service
...@@ -56,7 +56,6 @@ class Coco: ...@@ -56,7 +56,6 @@ class Coco:
def __init__(self, root_path=None): def __init__(self, root_path=None):
self.root_path = root_path if root_path else BASE_DIR self.root_path = root_path if root_path else BASE_DIR
self.config = self.config_class(self.root_path, defaults=self.default_config)
self.sessions = [] self.sessions = []
self.clients = [] self.clients = []
self.lock = threading.Lock() self.lock = threading.Lock()
...@@ -67,8 +66,15 @@ class Coco: ...@@ -67,8 +66,15 @@ class Coco:
self.replay_recorder_class = None self.replay_recorder_class = None
self.command_recorder_class = None self.command_recorder_class = None
self._task_handler = None self._task_handler = None
self.config = None
self.init_config()
register_app(self) register_app(self)
def init_config(self):
self.config = self.config_class(
self.root_path, defaults=self.default_config
)
@property @property
def name(self): def name(self):
if self.config['NAME']: if self.config['NAME']:
...@@ -111,23 +117,21 @@ class Coco: ...@@ -111,23 +117,21 @@ class Coco:
)) ))
self.config.update(configs) self.config.update(configs)
def get_recorder_class(self): @staticmethod
self.replay_recorder_class = ServerReplayRecorder def new_command_recorder():
self.command_recorder_class = get_command_recorder_class(self.config) return CommandRecorder()
def new_command_recorder(self):
return self.command_recorder_class()
def new_replay_recorder(self): @staticmethod
return self.replay_recorder_class() def new_replay_recorder():
return ReplayRecorder()
def bootstrap(self): def bootstrap(self):
self.make_logger() self.make_logger()
self.service.initial() self.service.initial()
self.load_extra_conf_from_server() self.load_extra_conf_from_server()
self.get_recorder_class()
self.keep_heartbeat() self.keep_heartbeat()
self.monitor_sessions() self.monitor_sessions()
self.monitor_sessions_replay()
def heartbeat(self): def heartbeat(self):
_sessions = [s.to_json() for s in self.sessions] _sessions = [s.to_json() for s in self.sessions]
...@@ -156,6 +160,26 @@ class Coco: ...@@ -156,6 +160,26 @@ class Coco:
thread = threading.Thread(target=func) thread = threading.Thread(target=func)
thread.start() thread.start()
def monitor_sessions_replay(self):
interval = 10
recorder = self.new_replay_recorder()
log_dir = os.path.join(self.config['LOG_DIR'])
def func():
while not self.stop_evt.is_set():
active_sessions = [str(session.id) for session in self.sessions]
for filename in os.listdir(log_dir):
session_id = filename.split('.')[0]
if len(session_id) != 36:
continue
if session_id not in active_sessions:
recorder.file_path = os.path.join(log_dir, filename)
recorder.upload_replay(session_id, 1)
time.sleep(interval)
thread = threading.Thread(target=func)
thread.start()
def monitor_sessions(self): def monitor_sessions(self):
interval = self.config["HEARTBEAT_INTERVAL"] interval = self.config["HEARTBEAT_INTERVAL"]
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
import socket import socket
import uuid import uuid
import traceback from copy import deepcopy
from flask_socketio import SocketIO, Namespace, join_room from flask_socketio import SocketIO, Namespace, join_room
from flask import Flask, request, current_app, redirect from flask import Flask, request, current_app, redirect
......
...@@ -38,6 +38,9 @@ class InteractiveServer: ...@@ -38,6 +38,9 @@ class InteractiveServer:
@search_result.setter @search_result.setter
def search_result(self, value): def search_result(self, value):
if not value:
self._search_result = value
return
value = self.filter_system_users(value) value = self.filter_system_users(value)
self._search_result = value self._search_result = value
...@@ -88,7 +91,7 @@ class InteractiveServer: ...@@ -88,7 +91,7 @@ class InteractiveServer:
result = [] result = []
# 所有的 # 所有的
if q == '': if q in ('', None):
result = self.assets result = self.assets
# 用户输入的是数字,可能想使用id唯一键搜索 # 用户输入的是数字,可能想使用id唯一键搜索
elif q.isdigit() and self.search_result and \ elif q.isdigit() and self.search_result and \
...@@ -242,6 +245,7 @@ class InteractiveServer: ...@@ -242,6 +245,7 @@ class InteractiveServer:
self.search_assets(opt) self.search_assets(opt)
if self.search_result and len(self.search_result) == 1: if self.search_result and len(self.search_result) == 1:
asset = self.search_result[0] asset = self.search_result[0]
self.search_result = None
if asset.platform == "Windows": if asset.platform == "Windows":
self.client.send(warning( self.client.send(warning(
_("终端不支持登录windows, 请使用web terminal访问")) _("终端不支持登录windows, 请使用web terminal访问"))
......
...@@ -249,10 +249,12 @@ class WSProxy: ...@@ -249,10 +249,12 @@ class WSProxy:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
data = self.child.recv(BUF_SIZE) data = self.child.recv(BUF_SIZE)
except OSError: except (OSError, EOFError):
continue self.close()
break
if len(data) == 0: if len(data) == 0:
self.close() self.close()
break
data = data.decode(errors="ignore") data = data.decode(errors="ignore")
self.ws.emit("data", {'data': data, 'room': self.room_id}, self.ws.emit("data", {'data': data, 'room': self.room_id},
room=self.room_id) room=self.room_id)
...@@ -265,6 +267,7 @@ class WSProxy: ...@@ -265,6 +267,7 @@ class WSProxy:
thread.start() thread.start()
def close(self): def close(self):
self.ws.emit("logout", {"room": self.room_id}, room=self.room_id)
self.stop_event.set() self.stop_event.set()
try: try:
self.child.shutdown(1) self.child.shutdown(1)
......
...@@ -8,6 +8,7 @@ import time ...@@ -8,6 +8,7 @@ import time
import os import os
import gzip import gzip
import json import json
from copy import deepcopy
import jms_storage import jms_storage
...@@ -20,60 +21,6 @@ BUF_SIZE = 1024 ...@@ -20,60 +21,6 @@ BUF_SIZE = 1024
class ReplayRecorder(metaclass=abc.ABCMeta): class ReplayRecorder(metaclass=abc.ABCMeta):
def __init__(self, session=None):
self.session = session
@abc.abstractmethod
def record(self, data):
"""
记录replay数据
:param data: 数据 {
"session": "",
"data": "",
"timestamp": ""
}
:return:
"""
@abc.abstractmethod
def session_start(self, session_id):
print("Session start: {}".format(session_id))
pass
@abc.abstractmethod
def session_end(self, session_id):
print("Session end: {}".format(session_id))
pass
class CommandRecorder:
def __init__(self, session=None):
self.session = session
def record(self, data):
"""
:param data: 数据 {
"session":
"input":
"output":
"user":
"asset":
"system_user":
"timestamp":
}
:return:
"""
def session_start(self, session_id):
print("Session start: {}".format(session_id))
pass
def session_end(self, session_id):
print("Session end: {}".format(session_id))
pass
class ServerReplayRecorder(ReplayRecorder):
time_start = None time_start = None
storage = None storage = None
...@@ -81,6 +28,7 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -81,6 +28,7 @@ class ServerReplayRecorder(ReplayRecorder):
super().__init__() super().__init__()
self.file = None self.file = None
self.file_path = None self.file_path = None
self.get_storage()
def record(self, data): def record(self, data):
""" """
...@@ -107,70 +55,68 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -107,70 +55,68 @@ class ServerReplayRecorder(ReplayRecorder):
def session_end(self, session_id): def session_end(self, session_id):
self.file.write('"0":""}') self.file.write('"0":""}')
self.file.close() self.file.close()
if self.upload_replay(session_id): self.upload_replay(session_id)
logger.info("Succeed to push {}'s {}".format(session_id, "record"))
def get_storage(self):
config = deepcopy(current_app.config["REPLAY_STORAGE"])
config["SERVICE"] = app_service
self.storage = jms_storage.get_object_storage(config)
def upload_replay(self, session_id, times=3):
if times < 1:
if self.storage.type == 'jms':
return False
else: else:
logger.error("Failed to push {}'s {}".format(session_id, "record")) self.storage = jms_storage.JMSReplayStorage(app_service)
self.upload_replay(session_id, times=3)
def upload_replay(self, session_id): ok, msg = self.push_to_storage(session_id)
configs = app_service.load_config_from_server() if not ok:
logger.debug("upload_replay print config: {}".format(configs)) msg = 'Failed push replay file: {}, try again {}'.format(msg, times)
self.storage = jms_storage.init(configs["REPLAY_STORAGE"]) logger.warn(msg)
if not self.storage: self.upload_replay(session_id, times-1)
self.storage = jms_storage.jms(app_service) else:
if self.push_file(3, session_id): msg = 'Success push replay file: {}'.format(session_id)
logger.info(msg)
self.finish_replay(3, session_id)
os.unlink(self.file_path) os.unlink(self.file_path)
return True return True
else:
return False
def push_to_storage(self, session_id): def push_to_storage(self, session_id):
dt = time.strftime('%Y-%m-%d', time.localtime(self.time_start)) dt = time.strftime('%Y-%m-%d', time.localtime(self.time_start))
target = dt + '/' + session_id + '.replay.gz' target = dt + '/' + session_id + '.replay.gz'
return self.storage.upload_file(self.file_path, target) return self.storage.upload(self.file_path, target)
def push_file(self, times, session_id):
if times < 0:
if self.storage.type() == 'jms':
return False
else:
msg = "Failed push session {}'s replay log to storage".format(session_id)
logger.error(msg)
self.storage = jms_storage.jms(app_service)
return self.push_file(3, session_id)
if self.push_to_storage(session_id):
logger.info("Success push session: {}'s replay log to storage ".format(session_id))
return True
else:
msg = "Failed push session {}'s replay log to storage, try {} times".format(session_id, times)
logger.error(msg)
return self.push_file(times - 1, session_id)
def finish_replay(self, times, session_id): def finish_replay(self, times, session_id):
if times < 0: if times < 1:
logger.error("Failed finished session {}'s replay".format(session_id)) logger.error(
"Failed finished session {}'s replay".format(session_id)
)
return False return False
if app_service.finish_replay(session_id): if app_service.finish_replay(session_id):
logger.info("Success finish session {}'s replay ".format(session_id)) logger.info(
"Success finish session {}'s replay ".format(session_id)
)
return True return True
else: else:
logger.error("Failed finish session {}'s replay, try {} times".format(session_id, times)) msg = "Failed finish session {}'s replay, try {} times"
logger.error(msg.format(session_id, times))
return self.finish_replay(times - 1, session_id) return self.finish_replay(times - 1, session_id)
class ServerCommandRecorder(CommandRecorder, metaclass=Singleton): class CommandRecorder(metaclass=Singleton):
batch_size = 10 batch_size = 10
timeout = 5 timeout = 5
no = 0 no = 0
storage = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.queue = MemoryQueue() self.queue = MemoryQueue()
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.push_to_server_async() self.push_to_server_async()
self.__class__.no += 1 self.get_storage()
def record(self, data): def record(self, data):
if data and data['input']: if data and data['input']:
...@@ -179,72 +125,22 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -179,72 +125,22 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
data['timestamp'] = int(data['timestamp']) data['timestamp'] = int(data['timestamp'])
self.queue.put(data) self.queue.put(data)
def get_storage(self):
config = deepcopy(current_app.config["COMMAND_STORAGE"])
config['SERVICE'] = app_service
self.storage = jms_storage.get_log_storage(config)
def push_to_server_async(self): def push_to_server_async(self):
def func(): def func():
while not self.stop_evt.is_set(): while not self.stop_evt.is_set():
data_set = self.queue.mget(self.batch_size, timeout=self.timeout) data_set = self.queue.mget(self.batch_size, timeout=self.timeout)
logger.debug("<Session command recorder {}> queue size: {}".format( logger.debug("Session command remain push: {}".format(
self.no, self.queue.qsize()) self.queue.qsize())
) )
if not data_set: if not data_set:
continue continue
logger.debug("Send {} commands to server".format(len(data_set))) logger.debug("Send {} commands to server".format(len(data_set)))
ok = app_service.push_session_command(data_set) ok = self.storage.bulk_save(data_set)
if not ok:
self.queue.mput(data_set)
thread = threading.Thread(target=func)
thread.daemon = True
thread.start()
def session_start(self, session_id):
pass
def session_end(self, session_id):
pass
# def __del__(self):
# print("GC: Session command storage has been gc")
class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
batch_size = 10
timeout = 5
no = 0
default_hosts = ["http://localhost"]
def __init__(self):
super().__init__()
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_es_async()
self.__class__.no += 1
self.store = jms_storage.ESStore(
current_app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts)
)
if not self.store.ping():
raise AssertionError("ESCommand storage init error")
def record(self, data):
if data and data['input']:
data['input'] = data['input'][:128]
data['output'] = data['output'][:1024]
data['timestamp'] = int(data['timestamp'])
self.queue.put(data)
def push_to_es_async(self):
def func():
while not self.stop_evt.is_set():
data_set = self.queue.mget(self.batch_size,
timeout=self.timeout)
logger.debug(
"<Session command recorder {}> queue size: {}".format(
self.no, self.queue.qsize())
)
if not data_set:
continue
logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.store.bulk_save(data_set)
if not ok: if not ok:
self.queue.mput(data_set) self.queue.mput(data_set)
...@@ -253,25 +149,9 @@ class ESCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -253,25 +149,9 @@ class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
thread.start() thread.start()
def session_start(self, session_id): def session_start(self, session_id):
print("Session start: {}".format(session_id))
pass pass
def session_end(self, session_id): def session_end(self, session_id):
print("Session end: {}".format(session_id))
pass pass
# def __del__(self):
# print("GC: ES command storage has been gc".format(self))
def get_command_recorder_class(config):
command_storage = config["COMMAND_STORAGE"]
storage_type = command_storage.get('TYPE')
if storage_type == "elasticsearch":
return ESCommandRecorder
else:
return ServerCommandRecorder
#
# def get_replay_recorder_class(config):
# ServerReplayRecorder.client = jms_storage.init(config["REPLAY_STORAGE"])
# return ServerReplayRecorder
...@@ -18,8 +18,7 @@ idna==2.6 ...@@ -18,8 +18,7 @@ idna==2.6
itsdangerous==0.24 itsdangerous==0.24
Jinja2==2.10 Jinja2==2.10
jmespath==0.9.3 jmespath==0.9.3
jms-es-sdk==0.5.2 jms-storage==0.0.15
jms-storage==0.0.12
jumpserver-python-sdk==0.0.42 jumpserver-python-sdk==0.0.42
MarkupSafe==1.0 MarkupSafe==1.0
oss2==2.4.0 oss2==2.4.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment