Commit e1a3cf3b authored by ibuler's avatar ibuler

Merge branch 'dev' into test

parents 87d152d2 808da201
...@@ -42,13 +42,13 @@ class Coco: ...@@ -42,13 +42,13 @@ class Coco:
'LOG_DIR': os.path.join(BASE_DIR, 'logs'), 'LOG_DIR': os.path.join(BASE_DIR, 'logs'),
'SESSION_DIR': os.path.join(BASE_DIR, 'sessions'), 'SESSION_DIR': os.path.join(BASE_DIR, 'sessions'),
'ASSET_LIST_SORT_BY': 'hostname', # hostname, ip 'ASSET_LIST_SORT_BY': 'hostname', # hostname, ip
'SSH_PASSWORD_AUTH': True, 'PASSWORD_AUTH': True,
'SSH_PUBLIC_KEY_AUTH': True, 'PUBLIC_KEY_AUTH': True,
'HEARTBEAT_INTERVAL': 5, 'HEARTBEAT_INTERVAL': 5,
'MAX_CONNECTIONS': 500, 'MAX_CONNECTIONS': 500,
'ADMINS': '', 'ADMINS': '',
'REPLAY_RECORD_ENGINE': 'server', # local, server 'COMMAND_STORAGE': {'TYPE': 'server'}, # server
'COMMAND_RECORD_ENGINE': 'server', # local, server, elasticsearch(not yet) 'REPLAY_RECORD_ENGINE': 'server',
} }
def __init__(self, name=None, root_path=None): def __init__(self, name=None, root_path=None):
...@@ -93,16 +93,17 @@ class Coco: ...@@ -93,16 +93,17 @@ class Coco:
def make_logger(self): def make_logger(self):
create_logger(self) create_logger(self)
# Todo: load some config from server like replay and common upload
def load_extra_conf_from_server(self): def load_extra_conf_from_server(self):
pass configs = self.service.load_config_from_server()
self.config.update(configs)
def initial_recorder(self): def get_recorder_class(self):
self.replay_recorder_class = get_replay_recorder_class(self) self.replay_recorder_class = get_replay_recorder_class(self.config)
self.command_recorder_class = get_command_recorder_class(self) self.command_recorder_class = get_command_recorder_class(self.config)
def new_command_recorder(self): def new_command_recorder(self):
return self.command_recorder_class(self) recorder = self.command_recorder_class(self)
return recorder
def new_replay_recorder(self): def new_replay_recorder(self):
return self.replay_recorder_class(self) return self.replay_recorder_class(self)
...@@ -111,7 +112,7 @@ class Coco: ...@@ -111,7 +112,7 @@ class Coco:
self.make_logger() self.make_logger()
self.service.initial() self.service.initial()
self.load_extra_conf_from_server() self.load_extra_conf_from_server()
self.initial_recorder() self.get_recorder_class()
self.keep_heartbeat() self.keep_heartbeat()
self.monitor_sessions() self.monitor_sessions()
...@@ -150,10 +151,10 @@ class Coco: ...@@ -150,10 +151,10 @@ class Coco:
for s in self.sessions: for s in self.sessions:
if not s.stop_evt.is_set(): if not s.stop_evt.is_set():
continue continue
if s.date_finished is None: if s.date_end is None:
self.remove_session(s) self.remove_session(s)
continue continue
delta = datetime.datetime.now() - s.date_finished delta = datetime.datetime.now() - s.date_end
if delta > datetime.timedelta(seconds=interval*5): if delta > datetime.timedelta(seconds=interval*5):
self.remove_session(s) self.remove_session(s)
time.sleep(interval) time.sleep(interval)
......
...@@ -26,7 +26,6 @@ class BaseWebSocketHandler: ...@@ -26,7 +26,6 @@ class BaseWebSocketHandler:
def prepare(self, request): def prepare(self, request):
# self.app = self.settings["app"] # self.app = self.settings["app"]
child, parent = socket.socketpair()
if request.headers.getlist("X-Forwarded-For"): if request.headers.getlist("X-Forwarded-For"):
remote_ip = request.headers.getlist("X-Forwarded-For")[0] remote_ip = request.headers.getlist("X-Forwarded-For")[0]
else: else:
...@@ -36,10 +35,6 @@ class BaseWebSocketHandler: ...@@ -36,10 +35,6 @@ class BaseWebSocketHandler:
self.clients[request.sid]["request"].meta = {"width": self.clients[request.sid]["cols"], self.clients[request.sid]["request"].meta = {"width": self.clients[request.sid]["cols"],
"height": self.clients[request.sid]["rows"]} "height": self.clients[request.sid]["rows"]}
# self.request.__dict__.update(request.__dict__) # self.request.__dict__.update(request.__dict__)
self.clients[request.sid]["client"] = Client(parent, self.clients[request.sid]["request"])
self.clients[request.sid]["proxy"] = WSProxy(self, child, self.clients[request.sid]["room"])
self.app.clients.append(self.clients[request.sid]["client"])
self.clients[request.sid]["forwarder"] = ProxyServer(self.app, self.clients[request.sid]["client"])
def check_origin(self, origin): def check_origin(self, origin):
return True return True
...@@ -64,9 +59,10 @@ class SSHws(Namespace, BaseWebSocketHandler): ...@@ -64,9 +59,10 @@ class SSHws(Namespace, BaseWebSocketHandler):
"cols": int(request.cookies.get('cols', 80)), "cols": int(request.cookies.get('cols', 80)),
"rows": int(request.cookies.get('rows', 24)), "rows": int(request.cookies.get('rows', 24)),
"room": room, "room": room,
"chan": None, # "chan": dict(),
"proxy": None, "proxy": dict(),
"client": None, "client": dict(),
"forwarder": dict(),
"request": None, "request": None,
} }
self.rooms[room] = { self.rooms[room] = {
...@@ -80,18 +76,31 @@ class SSHws(Namespace, BaseWebSocketHandler): ...@@ -80,18 +76,31 @@ class SSHws(Namespace, BaseWebSocketHandler):
self.prepare(request) self.prepare(request)
def on_data(self, message): def on_data(self, message):
if self.clients[request.sid]["proxy"]: if message['room'] and self.clients[request.sid]["proxy"][message['room']]:
self.clients[request.sid]["proxy"].send({"data": message}) self.clients[request.sid]["proxy"][message['room']].send({"data": message['data']})
def on_host(self, message): def on_host(self, message):
# 此处获取主机的信息 # 此处获取主机的信息
uuid = message.get('uuid', None) connection = str(uuid.uuid4())
assetID = message.get('uuid', None)
userid = message.get('userid', None) userid = message.get('userid', None)
if uuid and userid: self.emit('room', {'room': connection, 'secret': message['secret']})
asset = self.app.service.get_asset(uuid)
if assetID and userid:
asset = self.app.service.get_asset(assetID)
system_user = self.app.service.get_system_user(userid) system_user = self.app.service.get_system_user(userid)
if system_user: if system_user:
self.socketio.start_background_task(self.clients[request.sid]["forwarder"].proxy, asset, system_user)
child, parent = socket.socketpair()
self.clients[request.sid]["client"][connection] = Client(parent, self.clients[request.sid]["request"])
self.clients[request.sid]["proxy"][connection] = WSProxy(self, child, self.clients[request.sid]["room"],
connection)
self.app.clients.append(self.clients[request.sid]["client"][connection])
self.clients[request.sid]["forwarder"][connection] = ProxyServer(self.app,
self.clients[request.sid]["client"][connection])
self.socketio.start_background_task(self.clients[request.sid]["forwarder"][connection].proxy, asset,
system_user)
# self.forwarder.proxy(self.asset, system_user) # self.forwarder.proxy(self.asset, system_user)
else: else:
self.on_disconnect() self.on_disconnect()
...@@ -125,13 +134,21 @@ class SSHws(Namespace, BaseWebSocketHandler): ...@@ -125,13 +134,21 @@ class SSHws(Namespace, BaseWebSocketHandler):
def on_disconnect(self): def on_disconnect(self):
self.on_leave(self.clients[request.sid]["room"]) self.on_leave(self.clients[request.sid]["room"])
try: try:
# todo: there maybe have bug del self.clients[request.sid]
self.clients[request.sid]["proxy"].close()
except: except:
pass pass
# self.ssh.close() # self.ssh.close()
pass pass
def on_logout(self, connection):
print("logout", connection)
if connection:
self.clients[request.sid]["proxy"][connection].close()
del self.clients[request.sid]["proxy"][connection]
del self.clients[request.sid]["forwarder"][connection]
self.clients[request.sid]["client"][connection].close()
del self.clients[request.sid]["client"][connection]
class HttpServer: class HttpServer:
# prepare may be rewrite it # prepare may be rewrite it
...@@ -155,7 +172,7 @@ class HttpServer: ...@@ -155,7 +172,7 @@ class HttpServer:
port = self.app.config["HTTPD_PORT"] port = self.app.config["HTTPD_PORT"]
print('Starting websocket server at {}:{}'.format(host, port)) print('Starting websocket server at {}:{}'.format(host, port))
self.socketio.on_namespace(SSHws('/ssh').app(self.app)) self.socketio.on_namespace(SSHws('/ssh').app(self.app))
self.socketio.init_app(self.flask) self.socketio.init_app(self.flask, async_mode="threading")
self.socketio.run(self.flask, port=port, host=host) self.socketio.run(self.flask, port=port, host=host)
def shutdown(self): def shutdown(self):
......
...@@ -28,7 +28,7 @@ class InteractiveServer: ...@@ -28,7 +28,7 @@ class InteractiveServer:
self.client = client self.client = client
self.request = client.request self.request = client.request
self.assets = None self.assets = None
self.search_result = None self._search_result = None
self.asset_groups = None self.asset_groups = None
self.get_user_assets_async() self.get_user_assets_async()
self.get_user_asset_groups_async() self.get_user_asset_groups_async()
...@@ -37,6 +37,18 @@ class InteractiveServer: ...@@ -37,6 +37,18 @@ class InteractiveServer:
def app(self): def app(self):
return self._app() return self._app()
@property
def search_result(self):
if self._search_result:
return self._search_result
else:
return None
@search_result.setter
def search_result(self, value):
value = self.filter_system_users(value)
self._search_result = value
def display_banner(self): def display_banner(self):
self.client.send(char.CLEAR_CHAR) self.client.send(char.CLEAR_CHAR)
logo_path = os.path.join(self.app.root_path, "logo.txt") logo_path = os.path.join(self.app.root_path, "logo.txt")
...@@ -219,14 +231,13 @@ class InteractiveServer: ...@@ -219,14 +231,13 @@ class InteractiveServer:
def filter_system_users(assets): def filter_system_users(assets):
for asset in assets: for asset in assets:
system_users_granted = asset.system_users_granted system_users_granted = asset.system_users_granted
high_priority = max([s.priority for s in system_users_granted]) high_priority = max([s.priority for s in system_users_granted]) if system_users_granted else 1
system_users_cleaned = [s for s in system_users_granted if s.priority == high_priority] system_users_cleaned = [s for s in system_users_granted if s.priority == high_priority]
asset.system_users_granted = system_users_cleaned asset.system_users_granted = system_users_cleaned
return assets return assets
def get_user_assets(self): def get_user_assets(self):
assets = self.app.service.get_user_assets(self.client.user) self.assets = self.app.service.get_user_assets(self.client.user)
self.assets = self.filter_system_users(assets)
logger.debug("Get user {} assets total: {}".format(self.client.user, len(self.assets))) logger.debug("Get user {} assets total: {}".format(self.client.user, len(self.assets)))
def get_user_assets_async(self): def get_user_assets_async(self):
...@@ -261,7 +272,7 @@ class InteractiveServer: ...@@ -261,7 +272,7 @@ class InteractiveServer:
def search_and_proxy(self, opt): def search_and_proxy(self, opt):
self.search_assets(opt) self.search_assets(opt)
if len(self.search_result) == 1: if self.search_result and len(self.search_result) == 1:
self.proxy(self.search_result[0]) self.proxy(self.search_result[0])
else: else:
self.display_search_result() self.display_search_result()
......
...@@ -43,9 +43,9 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -43,9 +43,9 @@ class SSHInterface(paramiko.ServerInterface):
def get_allowed_auths(self, username): def get_allowed_auths(self, username):
supported = [] supported = []
if self.app.config["SSH_PASSWORD_AUTH"]: if self.app.config["PASSWORD_AUTH"]:
supported.append("password") supported.append("password")
if self.app.config["SSH_PUBLIC_KEY_AUTH"]: if self.app.config["PUBLIC_KEY_AUTH"]:
supported.append("publickey") supported.append("publickey")
return ",".join(supported) return ",".join(supported)
......
...@@ -186,7 +186,7 @@ class WSProxy: ...@@ -186,7 +186,7 @@ class WSProxy:
``` ```
""" """
def __init__(self, ws, child, room): def __init__(self, ws, child, room, connection):
""" """
:param ws: websocket instance or handler, have write_message method :param ws: websocket instance or handler, have write_message method
:param child: sock child pair :param child: sock child pair
...@@ -196,6 +196,7 @@ class WSProxy: ...@@ -196,6 +196,7 @@ class WSProxy:
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.room = room self.room = room
self.auto_forward() self.auto_forward()
self.connection = connection
def send(self, msg): def send(self, msg):
""" """
...@@ -215,7 +216,7 @@ class WSProxy: ...@@ -215,7 +216,7 @@ class WSProxy:
data = self.child.recv(BUF_SIZE) data = self.child.recv(BUF_SIZE)
if len(data) == 0: if len(data) == 0:
self.close() self.close()
self.ws.emit("data", data.decode("utf-8"), room=self.room) self.ws.emit("data", {'data': data.decode("utf-8"), 'room': self.connection}, room=self.room)
def auto_forward(self): def auto_forward(self):
thread = threading.Thread(target=self.forward, args=()) thread = threading.Thread(target=self.forward, args=())
...@@ -225,4 +226,3 @@ class WSProxy: ...@@ -225,4 +226,3 @@ class WSProxy:
def close(self): def close(self):
self.stop_event.set() self.stop_event.set()
self.child.close() self.child.close()
self.ws.on_disconnect()
...@@ -7,8 +7,8 @@ import threading ...@@ -7,8 +7,8 @@ import threading
import logging import logging
import time import time
import weakref import weakref
import paramiko import paramiko
from paramiko.ssh_exception import SSHException
from .session import Session from .session import Session
from .models import Server from .models import Server
...@@ -130,7 +130,10 @@ class ProxyServer: ...@@ -130,7 +130,10 @@ class ProxyServer:
width = self.request.meta.get('width', 80) width = self.request.meta.get('width', 80)
height = self.request.meta.get('height', 24) height = self.request.meta.get('height', 24)
logger.debug("Change win size: %s - %s" % (width, height)) logger.debug("Change win size: %s - %s" % (width, height))
try:
self.server.chan.resize_pty(width=width, height=height) self.server.chan.resize_pty(width=width, height=height)
except SSHException:
break
def watch_win_size_change_async(self): def watch_win_size_change_async(self):
thread = threading.Thread(target=self.watch_win_size_change) thread = threading.Thread(target=self.watch_win_size_change)
......
...@@ -11,6 +11,8 @@ import gzip ...@@ -11,6 +11,8 @@ import gzip
import json import json
import shutil import shutil
from jms_es_sdk import ESStore
from .alignment import MemoryQueue from .alignment import MemoryQueue
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
...@@ -183,17 +185,69 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -183,17 +185,69 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
print("{} has been gc".format(self)) print("{} has been gc".format(self))
def get_command_recorder_class(app): class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
command_engine = app.config["COMMAND_RECORD_ENGINE"] batch_size = 10
timeout = 5
no = 0
if command_engine == "server": def __init__(self, app):
return ServerCommandRecorder super().__init__(app)
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_es_async()
self.__class__.no += 1
self.store = ESStore(**app.config["COMMAND_RECORD_OPTIONS"])
if not self.store.ping():
raise AssertionError("ESCommand storage init error")
def record(self, data):
if data and data['input']:
data['input'] = data['input'][:128]
data['output'] = data['output'][:1024]
data['timestamp'] = int(data['timestamp'])
self.queue.put(data)
def push_to_es_async(self):
def func():
while not self.stop_evt.is_set():
data_set = self.queue.mget(self.batch_size,
timeout=self.timeout)
logger.debug(
"<Session command recorder {}> queue size: {}".format(
self.no, self.queue.qsize())
)
if not data_set:
continue
logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.store.bulk_save(data_set)
if not ok:
self.queue.mput(data_set)
thread = threading.Thread(target=func)
thread.daemon = True
thread.start()
def session_start(self, session_id):
pass
def session_end(self, session_id):
pass
def __del__(self):
print("{} has been gc".format(self))
def get_command_recorder_class(config):
command_storage = config["COMMAND_STORAGE"]
if command_storage['TYPE'] == "elasticsearch":
return ESCommandRecorder
else: else:
return ServerCommandRecorder return ServerCommandRecorder
def get_replay_recorder_class(app): def get_replay_recorder_class(config):
replay_engine = app.config["REPLAY_RECORD_ENGINE"] replay_engine = config["REPLAY_RECORD_ENGINE"]
if replay_engine == "server": if replay_engine == "server":
return ServerReplayRecorder return ServerReplayRecorder
else: else:
......
...@@ -9,6 +9,8 @@ import threading ...@@ -9,6 +9,8 @@ import threading
import paramiko import paramiko
import sys import sys
import time
from .utils import ssh_key_gen from .utils import ssh_key_gen
from .interface import SSHInterface from .interface import SSHInterface
from .interactive import InteractiveServer from .interactive import InteractiveServer
...@@ -48,13 +50,13 @@ class SSHServer: ...@@ -48,13 +50,13 @@ class SSHServer:
try: try:
sock, addr = self.sock.accept() sock, addr = self.sock.accept()
logger.info("Get ssh request from {}: {}".format(addr[0], addr[1])) logger.info("Get ssh request from {}: {}".format(addr[0], addr[1]))
thread = threading.Thread(target=self.handle, args=(sock, addr)) thread = threading.Thread(target=self.handle_connection, args=(sock, addr))
thread.daemon = True thread.daemon = True
thread.start() thread.start()
except Exception as e: except Exception as e:
logger.error("Start SSH server error: {}".format(e)) logger.error("Start SSH server error: {}".format(e))
def handle(self, sock, addr): def handle_connection(self, sock, addr):
transport = paramiko.Transport(sock, gss_kex=False) transport = paramiko.Transport(sock, gss_kex=False)
try: try:
transport.load_server_moduli() transport.load_server_moduli()
...@@ -73,23 +75,29 @@ class SSHServer: ...@@ -73,23 +75,29 @@ class SSHServer:
logger.warning("Handle EOF Error") logger.warning("Handle EOF Error")
return return
chan = transport.accept(10) while True:
chan = transport.accept()
if chan is None: if chan is None:
logger.warning("No ssh channel get") continue
return
server.event.wait(5) server.event.wait(5)
if not server.event.is_set(): if not server.event.is_set():
logger.warning("Client not request a valid request, exiting") logger.warning("Client not request a valid request, exiting")
return return
t = threading.Thread(target=self.handle_chan, args=(chan, request))
t.daemon = True
t.start()
def handle_chan(self, chan, request):
client = Client(chan, request) client = Client(chan, request)
print(chan)
print(request)
self.app.add_client(client) self.app.add_client(client)
self.dispatch(client) self.dispatch(client)
def dispatch(self, client): def dispatch(self, client):
request_type = client.request.type request_type = client.request.type
if request_type == 'pty': if request_type == 'pty' or request_type == 'x11':
logger.info("Request type `pty`, dispatch to interactive mode") logger.info("Request type `pty`, dispatch to interactive mode")
InteractiveServer(self.app, client).interact() InteractiveServer(self.app, client).interact()
elif request_type == 'exec': elif request_type == 'exec':
......
...@@ -49,10 +49,10 @@ class Config: ...@@ -49,10 +49,10 @@ class Config:
# ASSET_LIST_SORT_BY = 'ip' # ASSET_LIST_SORT_BY = 'ip'
# 登录是否支持密码认证 # 登录是否支持密码认证
# SSH_PASSWORD_AUTH = True # PASSWORD_AUTH = True
# 登录是否支持秘钥认证 # 登录是否支持秘钥认证
# SSH_PUBLIC_KEY_AUTH = True # PUBLIC_KEY_AUTH = True
# 和Jumpserver 保持心跳时间间隔 # 和Jumpserver 保持心跳时间间隔
# HEARTBEAT_INTERVAL = 5 # HEARTBEAT_INTERVAL = 5
......
...@@ -29,3 +29,4 @@ urllib3==1.22 ...@@ -29,3 +29,4 @@ urllib3==1.22
wcwidth==0.1.7 wcwidth==0.1.7
werkzeug==0.12.2 werkzeug==0.12.2
jumpserver-python-sdk==0.0.23 jumpserver-python-sdk==0.0.23
jms-es-sdk==0.5.1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment