Unverified Commit 40ac8f51 authored by 老广's avatar 老广 Committed by GitHub

Merge pull request #58 from jumpserver/dev

添加eventlet支持websocket模式,优化storage等
parents 14ccd1b0 64b655aa
......@@ -9,6 +9,8 @@ import threading
import socket
import json
import signal
import eventlet
from eventlet.debug import hub_prevent_multiple_readers
from jms.service import AppService
......@@ -17,9 +19,11 @@ from .sshd import SSHServer
from .httpd import HttpServer
from .logger import create_logger
from .tasks import TaskHandler
from .recorder import get_command_recorder_class, ServerReplayRecorder
from .utils import get_logger
from .recorder import ReplayRecorder, CommandRecorder
from .utils import get_logger, register_app, register_service
eventlet.monkey_patch()
hub_prevent_multiple_readers(False)
__version__ = '1.3.0'
......@@ -56,7 +60,6 @@ class Coco:
def __init__(self, root_path=None):
self.root_path = root_path if root_path else BASE_DIR
self.config = self.config_class(self.root_path, defaults=self.default_config)
self.sessions = []
self.clients = []
self.lock = threading.Lock()
......@@ -67,6 +70,14 @@ class Coco:
self.replay_recorder_class = None
self.command_recorder_class = None
self._task_handler = None
self.config = None
self.init_config()
register_app(self)
def init_config(self):
self.config = self.config_class(
self.root_path, defaults=self.default_config
)
@property
def name(self):
......@@ -79,24 +90,25 @@ class Coco:
def service(self):
if self._service is None:
self._service = AppService(self)
register_service(self._service)
return self._service
@property
def sshd(self):
if self._sshd is None:
self._sshd = SSHServer(self)
self._sshd = SSHServer()
return self._sshd
@property
def httpd(self):
if self._httpd is None:
self._httpd = HttpServer(self)
self._httpd = HttpServer()
return self._httpd
@property
def task_handler(self):
if self._task_handler is None:
self._task_handler = TaskHandler(self)
self._task_handler = TaskHandler()
return self._task_handler
def make_logger(self):
......@@ -109,24 +121,21 @@ class Coco:
))
self.config.update(configs)
def get_recorder_class(self):
self.replay_recorder_class = ServerReplayRecorder
self.command_recorder_class = get_command_recorder_class(self.config)
def new_command_recorder(self):
recorder = self.command_recorder_class(self)
return recorder
@staticmethod
def new_command_recorder():
return CommandRecorder()
def new_replay_recorder(self):
return self.replay_recorder_class(self)
@staticmethod
def new_replay_recorder():
return ReplayRecorder()
def bootstrap(self):
self.make_logger()
self.service.initial()
self.load_extra_conf_from_server()
self.get_recorder_class()
self.keep_heartbeat()
self.monitor_sessions()
self.monitor_sessions_replay()
def heartbeat(self):
_sessions = [s.to_json() for s in self.sessions]
......@@ -155,6 +164,31 @@ class Coco:
thread = threading.Thread(target=func)
thread.start()
def monitor_sessions_replay(self):
interval = 10
recorder = self.new_replay_recorder()
log_dir = os.path.join(self.config['LOG_DIR'])
def func():
while not self.stop_evt.is_set():
active_sessions = [str(session.id) for session in self.sessions]
for filename in os.listdir(log_dir):
session_id = filename.split('.')[0]
full_path = os.path.join(log_dir, filename)
if len(session_id) != 36:
continue
if session_id not in active_sessions:
recorder.file_path = full_path
ok = recorder.upload_replay(session_id, 1)
if not ok and os.path.getsize(full_path) == 0:
os.unlink(full_path)
time.sleep(interval)
thread = threading.Thread(target=func)
thread.start()
def monitor_sessions(self):
interval = self.config["HEARTBEAT_INTERVAL"]
......@@ -188,9 +222,11 @@ class Coco:
self.run_httpd()
signal.signal(signal.SIGTERM, lambda x, y: self.shutdown())
while self.stop_evt.wait(5):
print("Coco receive term signal, exit")
break
while True:
if self.stop_evt.is_set():
print("Coco receive term signal, exit")
break
time.sleep(3)
except KeyboardInterrupt:
self.stop_evt.set()
self.shutdown()
......@@ -218,13 +254,19 @@ class Coco:
def add_client(self, client):
with self.lock:
self.clients.append(client)
logger.info("New client {} join, total {} now".format(client, len(self.clients)))
logger.info("New client {} join, total {} now".format(
client, len(self.clients)
)
)
def remove_client(self, client):
with self.lock:
try:
self.clients.remove(client)
logger.info("Client {} leave, total {} now".format(client, len(self.clients)))
logger.info("Client {} leave, total {} now".format(
client, len(self.clients)
)
)
client.close()
except:
pass
......@@ -241,4 +283,5 @@ class Coco:
self.sessions.remove(session)
self.service.finish_session(session.to_json())
except ValueError:
logger.warning("Remove session: {} fail, maybe already removed".format(session))
\ No newline at end of file
msg = "Remove session: {} fail, maybe already removed"
logger.warning(msg.format(session))
# -*- coding: utf-8 -*-
#
import weakref
import os
import socket
import paramiko
from paramiko.ssh_exception import SSHException
from .ctx import app_service
from .utils import get_logger, get_private_key_fingerprint
logger = get_logger(__file__)
......@@ -15,21 +15,26 @@ TIMEOUT = 10
class SSHConnection:
def __init__(self, app):
self._app = weakref.ref(app)
@property
def app(self):
return self._app()
def get_system_user_auth(self, system_user):
"""
获取系统用户的认证信息,密码或秘钥
:return: system user have full info
"""
password, private_key = \
app_service.get_system_user_auth_info(system_user)
system_user.password = password
system_user.private_key = private_key
def get_ssh_client(self, asset, system_user):
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
sock = None
self.get_system_user_auth(system_user)
if not system_user.password and not system_user.private_key:
self.get_system_user_auth(system_user)
if asset.domain:
sock = self.get_proxy_sock(asset)
sock = self.get_proxy_sock_v2(asset)
try:
ssh.connect(
......@@ -56,44 +61,62 @@ class SSHConnection:
system_user.username, asset.ip, asset.port,
password_short, key_fingerprint,
))
return None, str(e)
return None, None, str(e)
except (socket.error, TimeoutError) as e:
return None, str(e)
return ssh, None
return None, None, str(e)
return ssh, sock, None
def get_transport(self, asset, system_user):
ssh, msg = self.get_ssh_client(asset, system_user)
ssh, sock, msg = self.get_ssh_client(asset, system_user)
if ssh:
return ssh.get_transport(), None
return ssh.get_transport(), sock, None
else:
return None, msg
return None, None, msg
def get_channel(self, asset, system_user, term="xterm", width=80, height=24):
ssh, msg = self.get_ssh_client(asset, system_user)
ssh, sock, msg = self.get_ssh_client(asset, system_user)
if ssh:
chan = ssh.invoke_shell(term, width=width, height=height)
return chan, None
return chan, sock, None
else:
return None, msg
return None, sock, msg
def get_sftp(self, asset, system_user):
ssh, msg = self.get_ssh_client(asset, system_user)
ssh, sock, msg = self.get_ssh_client(asset, system_user)
if ssh:
return ssh.open_sftp(), None
return ssh.open_sftp(), sock, None
else:
return None, msg
return None, sock, msg
def get_system_user_auth(self, system_user):
"""
获取系统用户的认证信息,密码或秘钥
:return: system user have full info
"""
system_user.password, system_user.private_key = \
self.app.service.get_system_user_auth_info(system_user)
@staticmethod
def get_proxy_sock_v2(asset):
sock = None
domain = app_service.get_domain_detail_with_gateway(
asset.domain
)
if not domain.has_ssh_gateway():
return None
for i in domain.gateways:
gateway = domain.random_ssh_gateway()
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
ssh.connect(gateway.ip, username=gateway.username,
password=gateway.password,
pkey=gateway.private_key_obj)
except(paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
SSHException):
continue
sock = ssh.get_transport().open_channel(
'direct-tcpip', (asset.ip, asset.port), ('127.0.0.1', 0)
)
break
return sock
def get_proxy_sock(self, asset):
sock = None
domain = self.app.service.get_domain_detail_with_gateway(
domain = app_service.get_domain_detail_with_gateway(
asset.domain
)
if not domain.has_ssh_gateway():
......
# -*- coding: utf-8 -*-
#
from werkzeug.local import LocalProxy
from functools import partial
stack = {}
def _find(name):
if stack.get(name):
return stack[name]
else:
raise ValueError("Not found in stack: {}".format(name))
current_app = LocalProxy(partial(_find, 'app'))
app_service = LocalProxy(partial(_find, 'service'))
# current_app = []
# current_service = []
......@@ -4,32 +4,25 @@
import os
import socket
import uuid
from flask_socketio import SocketIO, Namespace, join_room, leave_room
from flask_socketio import SocketIO, Namespace, join_room
from flask import Flask, request, current_app, redirect
from .models import Request, Client, WSProxy
from .proxy import ProxyServer
from .utils import get_logger
from .ctx import current_app, app_service
__version__ = '0.5.0'
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
logger = get_logger(__file__)
class BaseNamespace(Namespace):
clients = None
current_user = None
@property
def app(self):
app = current_app.config['coco']
return app
def on_connect(self):
self.current_user = self.get_current_user()
if self.current_user is None:
return redirect(current_app.config['LOGIN_URL'])
return redirect(self.socketio.config['LOGIN_URL'])
logger.debug("{} connect websocket".format(self.current_user))
def get_current_user(self):
......@@ -38,249 +31,253 @@ class BaseNamespace(Namespace):
token = request.headers.get("Authorization")
user = None
if session_id and csrf_token:
user = self.app.service.check_user_cookie(session_id, csrf_token)
user = app_service.check_user_cookie(session_id, csrf_token)
if token:
user = self.app.service.check_user_with_token(token)
user = app_service.check_user_with_token(token)
return user
def close(self):
try:
self.clients[request.sid]["client"].close()
except:
pass
class ProxyNamespace(BaseNamespace):
def __init__(self, *args, **kwargs):
"""
:param args:
:param kwargs:
self.connections = {
"request_sid": {
"room_id": {
"id": room_id,
"proxy": None,
"client": None,
"forwarder": None,
"request": None,
"cols": 80,
"rows": 24
},
...
},
...
}
"""
super().__init__(*args, **kwargs)
self.clients = dict()
self.rooms = dict()
def new_client(self):
room = str(uuid.uuid4())
client = {
"cols": int(request.cookies.get('cols', 80)),
"rows": int(request.cookies.get('rows', 24)),
"room": room,
"proxy": dict(),
"client": dict(),
"forwarder": dict(),
"request": self.make_coco_request()
self.connections = dict()
def new_connection(self):
self.connections[request.sid] = dict()
def new_room(self):
room_id = str(uuid.uuid4())
room = {
"id": room_id,
"proxy": None,
"client": None,
"forwarder": None,
"request": self.make_coco_request(),
"cols": 80,
"rows": 24
}
return client
def make_coco_request(self):
x_forwarded_for = request.headers.get("X-Forwarded-For", '').split(',')
if x_forwarded_for and x_forwarded_for[0]:
remote_ip = x_forwarded_for[0]
else:
remote_ip = request.remote_addr
self.connections[request.sid][room_id] = room
return room
width_request = request.cookies.get('cols')
@staticmethod
def get_win_size():
cols_request = request.cookies.get('cols')
rows_request = request.cookies.get('rows')
if width_request and width_request.isdigit():
width = int(width_request)
if cols_request and cols_request.isdigit():
cols = int(cols_request)
else:
width = 80
cols = 80
if rows_request and rows_request.isdigit():
rows = int(rows_request)
else:
rows = 24
return cols, rows
def make_coco_request(self):
x_forwarded_for = request.headers.get("X-Forwarded-For", '').split(',')
if x_forwarded_for and x_forwarded_for[0]:
remote_ip = x_forwarded_for[0]
else:
remote_ip = request.remote_addr
width, height = self.get_win_size()
req = Request((remote_ip, 0))
req.user = self.current_user
req.meta = {
"width": width,
"height": rows,
"height": height,
}
return req
def on_connect(self):
logger.debug("On connect event trigger")
super().on_connect()
client = self.new_client()
self.clients[request.sid] = client
self.rooms[client['room']] = {
"admin": request.sid,
"member": [],
"rw": []
}
join_room(client['room'])
def on_data(self, message):
"""
收到浏览器请求
:param message: {"data": "xxx", "room": "xxx"}
:return:
"""
room = message.get('room')
if not room:
return
room_proxy = self.clients[request.sid]['proxy'].get(room)
if room_proxy:
room_proxy.send({"data": message['data']})
self.new_connection()
def on_host(self, message):
# 此处获取主机的信息
logger.debug("On host event trigger")
connection = str(uuid.uuid4())
asset_id = message.get('uuid', None)
user_id = message.get('userid', None)
secret = message.get('secret', None)
room = self.new_room()
self.emit('room', {'room': connection, 'secret': secret})
self.emit('room', {'room': room["id"], 'secret': secret})
join_room(room["id"])
if not asset_id or not user_id:
# self.on_connect()
return
asset = self.app.service.get_asset(asset_id)
system_user = self.app.service.get_system_user(user_id)
asset = app_service.get_asset(asset_id)
system_user = app_service.get_system_user(user_id)
if not asset or not system_user:
self.on_connect()
return
child, parent = socket.socketpair()
self.clients[request.sid]["client"][connection] = Client(
parent, self.clients[request.sid]["request"]
)
self.clients[request.sid]["proxy"][connection] = WSProxy(
self, child, self.clients[request.sid]["room"], connection
)
self.clients[request.sid]["forwarder"][connection] = ProxyServer(
self.app, self.clients[request.sid]["client"][connection]
)
client = Client(parent, room["request"])
forwarder = ProxyServer(client)
room["client"] = client
room["forwarder"] = forwarder
room["proxy"] = WSProxy(self, child, room["id"])
room["cols"], room["rows"] = self.get_win_size()
self.socketio.start_background_task(
self.clients[request.sid]["forwarder"][connection].proxy,
asset, system_user
forwarder.proxy, asset, system_user
)
def on_data(self, message):
"""
收到浏览器请求
:param message: {"data": "xxx", "room": "xxx"}
:return:
"""
room_id = message.get('room')
room = self.connections.get(request.sid, {}).get(room_id)
if not room:
return
room["proxy"].send({"data": message['data']})
def on_token(self, message):
# 此处获取token含有的主机的信息
logger.debug("On token trigger")
logger.debug(message)
token = message.get('token', None)
secret = message.get('secret', None)
connection = str(uuid.uuid4())
self.emit('room', {'room': connection, 'secret': secret})
if not (token or secret):
logger.debug("token or secret is None")
self.emit('data', {'data': "\nOperation not permitted!", 'room': connection})
room = self.new_room()
self.emit('room', {'room': room["id"], 'secret': secret})
self.socketio.sleep(0)
if not token or not secret:
logger.debug("Token or secret is None")
self.emit('data', {'data': "\nOperation not permitted!",
'room': room["id"]})
self.emit('disconnect')
self.socketio.sleep(0)
return None
host = self.app.service.get_token_asset(token)
logger.debug(host)
if not host:
logger.debug("host is None")
self.emit('data', {'data': "\nOperation not permitted!", 'room': connection})
info = app_service.get_token_asset(token)
logger.debug(info)
if not info:
logger.debug("Token info is None")
self.emit('data', {'data': "\nOperation not permitted!",
'room': room["id"]})
self.emit('disconnect')
self.socketio.sleep(0)
return None
user_id = host.get('user', None)
logger.debug("self.current_user")
self.current_user = self.app.service.get_user_profile(user_id)
self.clients[request.sid]["request"].user = self.current_user
user_id = info.get('user', None)
self.current_user = app_service.get_user_profile(user_id)
room["request"].user = self.current_user
logger.debug(self.current_user)
# {
# "user": {UUID},
# "asset": {UUID},
# "system_user": {UUID}
# }
self.on_host({'secret': secret, 'uuid': host['asset'], 'userid': host['system_user']})
self.on_host({
'secret': secret,
'uuid': info['asset'],
'userid': info['system_user'],
})
def on_resize(self, message):
cols = message.get('cols')
rows = message.get('rows')
cols, rows = message.get('cols', None), message.get('rows', None)
logger.debug("On resize event trigger: {}*{}".format(cols, rows))
if cols and rows and self.clients[request.sid]["request"]:
self.clients[request.sid]["request"].meta['width'] = cols
self.clients[request.sid]["request"].meta['height'] = rows
self.clients[request.sid]["request"].change_size_event.set()
def on_room(self, session_id):
logger.debug("On room event trigger")
if session_id not in self.clients.keys():
self.emit(
'error', "no such session",
room=self.clients[request.sid]["room"]
)
else:
self.emit(
'room', self.clients[session_id]["room"],
room=self.clients[request.sid]["room"]
)
def on_join(self, room):
logger.debug("On join room event trigger")
self.on_leave(self.clients[request.id]["room"])
self.clients[request.sid]["room"] = room
self.rooms[room]["member"].append(request.sid)
join_room(room=room)
def on_leave(self, room):
logger.debug("On leave room event trigger")
if self.rooms[room]["admin"] == request.sid:
self.emit("data", "\nAdmin leave", room=room)
del self.rooms[room]
leave_room(room=room)
rooms = self.connections.get(request.sid)
if not rooms:
return
room = list(rooms.values())[0]
if rooms and (room["cols"], room["rows"]) != (cols, rows):
for room in rooms.values():
room["request"].meta.update({
'width': cols, 'height': rows
})
room["request"].change_size_event.set()
room.update({"cols": cols, "rows": rows})
def on_disconnect(self):
logger.debug("On disconnect event trigger")
self.on_leave(self.clients[request.sid]["room"])
try:
for connection in self.clients[request.sid]["client"]:
self.on_logout(connection)
del self.clients[request.sid]
except:
pass
def on_logout(self, connection):
logger.debug("On logout event trigger")
if connection:
if connection in self.clients[request.sid]["proxy"].keys():
self.clients[request.sid]["proxy"][connection].close()
del self.clients[request.sid]['proxy'][connection]
def logout(self, connection):
if connection and (request.sid in self.clients.keys()):
if connection in self.clients[request.sid]["proxy"].keys():
del self.clients[request.sid]["proxy"][connection]
if connection in self.clients[request.sid]["forwarder"].keys():
del self.clients[request.sid]["forwarder"][connection]
if connection in self.clients[request.sid]["client"].keys():
del self.clients[request.sid]["client"][connection]
rooms = {k: v for k, v in self.connections.get(request.sid, {}).items()}
for room_id in rooms:
try:
self.on_logout(room_id)
except Exception as e:
logger.warn(e)
del self.connections[request.sid]
def on_logout(self, room_id):
room = self.connections.get(request.sid, {}).get(room_id)
if room:
room["proxy"].close()
self.close_room(room_id)
del self.connections[request.sid][room_id]
del room
def on_ping(self):
self.emit('pong')
class HttpServer:
# prepare may be rewrite it
config = {
'SECRET_KEY': '',
'SECRET_KEY': 'someWOrkSD20KMS9330)&#',
'coco': None,
'LOGIN_URL': '/login'
}
async_mode = "threading"
def __init__(self, coco):
config = coco.config
init_kwargs = dict(
async_mode="eventlet",
# async_mode="threading",
# ping_timeout=20,
# ping_interval=10,
# engineio_logger=True,
# logger=True
)
def __init__(self):
config = {k: v for k, v in current_app.config.items()}
config.update(self.config)
config['coco'] = coco
self.flask_app = Flask(__name__, template_folder='dist')
self.flask_app.config.update(config)
self.socket_io = SocketIO()
self.register_routes()
self.register_error_handler()
def register_routes(self):
self.socket_io.on_namespace(ProxyNamespace('/ssh'))
@staticmethod
def on_error_default(e):
logger.exception(e)
def register_error_handler(self):
self.socket_io.on_error_default(self.on_error_default)
def run(self):
host = self.flask_app.config["BIND_HOST"]
port = self.flask_app.config["HTTPD_PORT"]
self.socket_io.init_app(self.flask_app, async_mode=self.async_mode)
print('Starting websocket server at {}:{}'.format(host, port))
self.socket_io.init_app(
self.flask_app,
**self.init_kwargs
)
self.socket_io.run(self.flask_app, port=port, host=host, debug=False)
def shutdown(self):
self.socket_io.stop()
pass
......@@ -4,16 +4,14 @@
import socket
import threading
import weakref
import os
from jms.models import Asset, AssetGroup
from . import char
from .utils import wrap_with_line_feed as wr, wrap_with_title as title, \
wrap_with_primary as primary, wrap_with_warning as warning, \
is_obj_attr_has, is_obj_attr_eq, sort_assets, TtyIOParser, \
ugettext as _, get_logger
wrap_with_warning as warning, is_obj_attr_has, is_obj_attr_eq, \
sort_assets, ugettext as _, get_logger, net_input, format_with_zh, \
item_max_length, size_of_str_with_zh
from .ctx import current_app, app_service
from .proxy import ProxyServer
logger = get_logger(__file__)
......@@ -22,19 +20,14 @@ logger = get_logger(__file__)
class InteractiveServer:
_sentinel = object()
def __init__(self, app, client):
self._app = weakref.ref(app)
def __init__(self, client):
self.client = client
self.request = client.request
self.assets = None
self._search_result = None
self.asset_groups = None
self.nodes = None
self.get_user_assets_async()
self.get_user_asset_groups_async()
@property
def app(self):
return self._app()
self.get_user_nodes_async()
@property
def search_result(self):
......@@ -45,12 +38,15 @@ class InteractiveServer:
@search_result.setter
def search_result(self, value):
if not value:
self._search_result = value
return
value = self.filter_system_users(value)
self._search_result = value
def display_banner(self):
self.client.send(char.CLEAR_CHAR)
logo_path = os.path.join(self.app.root_path, "logo.txt")
logo_path = os.path.join(current_app.root_path, "logo.txt")
if os.path.isfile(logo_path):
with open(logo_path, 'rb') as f:
for i in f:
......@@ -61,73 +57,16 @@ class InteractiveServer:
banner = _("""\n {title} {user}, 欢迎使用Jumpserver开源跳板机系统 {end}\r\n\r
1) 输入 {green}ID{end} 直接登录 或 输入{green}部分 IP,主机名,备注{end} 进行搜索登录(如果唯一).\r
2) 输入 {green}/{end} + {green}IP, 主机名{end} or {green}备注 {end}搜索. 如: /ip\r
3) 输入 {green}P/p{end} 显示您有权限的主机.\r
4) 输入 {green}G/g{end} 显示您有权限的主机组.\r
5) 输入 {green}G/g{end} + {green}组ID{end} 显示该组下主机. 如: g1\r
6) 输入 {green}H/h{end} 帮助.\r
0) 输入 {green}Q/q{end} 退出.\r\n""").format(
3) 输入 {green}p{end} 显示您有权限的主机.\r
4) 输入 {green}g{end} 显示您有权限的节点\r
5) 输入 {green}g{end} + {green}组ID{end} 显示节点下主机. 如: g1\r
6) 输入 {green}h{end} 帮助.\r
0) 输入 {green}q{end} 退出.\r\n""").format(
title="\033[1;32m", green="\033[32m",
end="\033[0m", user=self.client.user
)
self.client.send(banner)
def get_option(self, prompt='Opt> '):
"""实现了一个ssh input, 提示用户输入, 获取并返回
:return user input string
"""
# Todo: 实现自动hostname或IP补全
input_data = []
parser = TtyIOParser()
self.client.send(wr(prompt, before=1, after=0))
while True:
data = self.client.recv(10)
if len(data) == 0:
self.app.remove_client(self.client)
break
# Client input backspace
if data in char.BACKSPACE_CHAR:
# If input words less than 0, should send 'BELL'
if len(input_data) > 0:
data = char.BACKSPACE_CHAR[data]
input_data.pop()
else:
data = char.BELL_CHAR
self.client.send(data)
continue
if data.startswith(b'\x03'):
# Ctrl-C
self.client.send(b'^C\r\nOpt> ')
input_data = []
continue
elif data.startswith(b'\x04'):
# Ctrl-D
return 'q'
# Todo: Move x1b to char
if data.startswith(b'\x1b') or data in char.UNSUPPORTED_CHAR:
self.client.send(b'')
continue
# handle shell expect
multi_char_with_enter = False
if len(data) > 1 and data[-1] in char.ENTER_CHAR_ORDER:
self.client.send(data)
input_data.append(data[:-1])
multi_char_with_enter = True
# If user type ENTER we should get user input
if data in char.ENTER_CHAR or multi_char_with_enter:
self.client.send(wr(b'', after=2))
option = parser.parse_input(input_data)
del input_data[:]
return option.strip()
else:
self.client.send(data)
input_data.append(data)
def dispatch(self, opt):
if opt is None:
return self._sentinel
......@@ -136,9 +75,9 @@ class InteractiveServer:
elif opt in ['p', 'P', '']:
self.display_assets()
elif opt in ['g', 'G']:
self.display_asset_groups()
self.display_nodes()
elif opt.startswith("g") and opt.lstrip("g").isdigit():
self.display_group_assets(int(opt.lstrip("g")))
self.display_node_assets(int(opt.lstrip("g")))
elif opt in ['q', 'Q', 'exit', 'quit']:
return self._sentinel
elif opt in ['h', 'H']:
......@@ -152,21 +91,24 @@ class InteractiveServer:
result = []
# 所有的
if q == '':
if q in ('', None):
result = self.assets
# 用户输入的是数字,可能想使用id唯一键搜索
elif q.isdigit() and self.search_result and len(self.search_result) >= int(q):
elif q.isdigit() and self.search_result and \
len(self.search_result) >= int(q):
result = [self.search_result[int(q) - 1]]
# 全匹配到则直接返回全匹配的
if len(result) == 0:
_result = [asset for asset in self.assets if is_obj_attr_eq(asset, q)]
_result = [asset for asset in self.assets
if is_obj_attr_eq(asset, q)]
if len(_result) == 1:
result = _result
# 最后模糊匹配
if len(result) == 0:
result = [asset for asset in self.assets if is_obj_attr_has(asset, q)]
result = [asset for asset in self.assets
if is_obj_attr_has(asset, q)]
self.search_result = result
......@@ -177,52 +119,69 @@ class InteractiveServer:
"""
self.search_and_display('')
def display_asset_groups(self):
if self.asset_groups is None:
self.get_user_asset_groups()
def display_nodes(self):
if self.nodes is None:
self.get_user_nodes()
if len(self.asset_groups) == 0:
if len(self.nodes) == 0:
self.client.send(warning(_("无")))
return
fake_group = AssetGroup(name=_("Name"), assets_amount=_("Assets"), comment=_("Comment"))
id_max_length = max(len(str(len(self.asset_groups))), 5)
name_max_length = max(max([len(group.name) for group in self.asset_groups]), 15)
amount_max_length = max(len(str(max([group.assets_amount for group in self.asset_groups]))), 10)
header = '{1:>%d} {0.name:%d} {0.assets_amount:<%s} ' % (id_max_length, name_max_length, amount_max_length)
comment_length = max(self.request.meta["width"] - len(header.format(fake_group, id_max_length)), 2)
line = header + '{0.comment:%s}' % (comment_length // 2) # comment中可能有中文
header += "{0.comment:%s}" % comment_length
self.client.send(title(header.format(fake_group, "ID")))
for index, group in enumerate(self.asset_groups, 1):
self.client.send(wr(line.format(group, index)))
self.client.send(wr(_("总共: {}").format(len(self.asset_groups)), before=1))
def display_group_assets(self, _id):
if _id > len(self.asset_groups) or _id <= 0:
id_length = max(len(str(len(self.nodes))), 5)
name_length = item_max_length(self.nodes, 15, key=lambda x: x.name)
amount_length = item_max_length(self.nodes, 10,
key=lambda x: x.assets_amount)
size_list = [id_length, name_length, amount_length]
fake_data = ['ID', _("Name"), _("Assets")]
header_without_comment = format_with_zh(size_list, *fake_data)
comment_length = max(
self.request.meta["width"] -
size_of_str_with_zh(header_without_comment) - 1,
2
)
size_list.append(comment_length)
fake_data.append(_("Comment"))
self.client.send(title(format_with_zh(size_list, *fake_data)))
for index, group in enumerate(self.nodes, 1):
data = [index, group.name, group.assets_amount, group.comment]
self.client.send(wr(format_with_zh(size_list, *data)))
self.client.send(wr(_("总共: {}").format(len(self.nodes)), before=1))
def display_node_assets(self, _id):
if _id > len(self.nodes) or _id <= 0:
self.client.send(wr(warning("没有匹配分组,请重新输入")))
self.display_asset_groups()
self.display_nodes()
return
self.search_result = self.asset_groups[_id - 1].assets_granted
self.search_result = self.nodes[_id - 1].assets_granted
self.display_search_result()
def display_search_result(self):
self.search_result = sort_assets(self.search_result, self.app.config["ASSET_LIST_SORT_BY"])
fake_asset = Asset(hostname=_("Hostname"), ip=_("IP"), _system_users_name_list=_("LoginAs"),
comment=_("Comment"))
id_max_length = max(len(str(len(self.search_result))), 3)
hostname_max_length = max(max([len(asset.hostname) for asset in self.search_result + [fake_asset]]), 15)
sysuser_max_length = max([len(asset.system_users_name_list) for asset in self.search_result + [fake_asset]])
header = '{1:>%d} {0.hostname:%d} {0.ip:15} {0.system_users_name_list:%d} ' % \
(id_max_length, hostname_max_length, sysuser_max_length)
comment_length = self.request.meta["width"] - len(header.format(fake_asset, id_max_length))
comment_length = max([comment_length, 2])
line = header + '{0.comment:.%d}' % (comment_length // 2) # comment中可能有中文
header += '{0.comment:%s}' % comment_length
self.client.send(wr(title(header.format(fake_asset, "ID"))))
sort_by = current_app.config["ASSET_LIST_SORT_BY"]
self.search_result = sort_assets(self.search_result, sort_by)
fake_data = [_("ID"), _("Hostname"), _("IP"), _("LoginAs")]
id_length = max(len(str(len(self.search_result))), 4)
hostname_length = item_max_length(self.search_result, 15,
key=lambda x: x.hostname)
sysuser_length = item_max_length(self.search_result,
key=lambda x: x.system_users_name_list)
size_list = [id_length, hostname_length, 16, sysuser_length]
header_without_comment = format_with_zh(size_list, *fake_data)
comment_length = max(
self.request.meta["width"] -
size_of_str_with_zh(header_without_comment) - 1,
2
)
size_list.append(comment_length)
fake_data.append(_("Comment"))
self.client.send(wr(title(format_with_zh(size_list, *fake_data))))
for index, asset in enumerate(self.search_result, 1):
self.client.send(wr(line.format(asset, index)))
data = [
index, asset.hostname, asset.ip,
asset.system_users_name_list, asset.comment
]
self.client.send(wr(format_with_zh(size_list, *data)))
self.client.send(wr(_("总共: {} 匹配: {}").format(
len(self.assets), len(self.search_result)), before=1)
)
......@@ -231,43 +190,44 @@ class InteractiveServer:
self.search_assets(q)
self.display_search_result()
def get_user_asset_groups(self):
self.asset_groups = self.app.service.get_user_asset_groups(self.client.user)
def get_user_nodes(self):
self.nodes = app_service.get_user_asset_groups(self.client.user)
def get_user_asset_groups_async(self):
thread = threading.Thread(target=self.get_user_asset_groups)
def get_user_nodes_async(self):
thread = threading.Thread(target=self.get_user_nodes)
thread.start()
@staticmethod
def filter_system_users(assets):
for asset in assets:
system_users_granted = asset.system_users_granted
high_priority = max([s.priority for s in system_users_granted]) if system_users_granted else 1
system_users_cleaned = [s for s in system_users_granted if s.priority == high_priority]
high_priority = max([s.priority for s in system_users_granted]) \
if system_users_granted else 1
system_users_cleaned = [s for s in system_users_granted
if s.priority == high_priority]
asset.system_users_granted = system_users_cleaned
return assets
def get_user_assets(self):
self.assets = self.app.service.get_user_assets(self.client.user)
logger.debug("Get user {} assets total: {}".format(self.client.user, len(self.assets)))
self.assets = app_service.get_user_assets(self.client.user)
logger.debug("Get user {} assets total: {}".format(
self.client.user, len(self.assets))
)
def get_user_assets_async(self):
thread = threading.Thread(target=self.get_user_assets)
thread.start()
def choose_system_user(self, system_users):
# highest_priority = max([s.priority for s in system_users])
# system_users = [s for s in system_users if s == highest_priority]
if len(system_users) == 1:
return system_users[0]
elif len(system_users) == 0:
return None
while True:
self.client.send(wr(_("选择一个登: "), after=1))
self.client.send(wr(_("选择一个登: "), after=1))
self.display_system_users(system_users)
opt = self.get_option("ID> ")
opt = net_input(self.client, prompt="ID> ")
if opt.isdigit() and len(system_users) > int(opt):
return system_users[int(opt)]
elif opt in ['q', 'Q']:
......@@ -285,8 +245,11 @@ class InteractiveServer:
self.search_assets(opt)
if self.search_result and len(self.search_result) == 1:
asset = self.search_result[0]
self.search_result = None
if asset.platform == "Windows":
self.client.send(warning(_("终端不支持登录windows, 请使用web terminal访问")))
self.client.send(warning(
_("终端不支持登录windows, 请使用web terminal访问"))
)
return
self.proxy(asset)
else:
......@@ -297,14 +260,14 @@ class InteractiveServer:
if system_user is None:
self.client.send(_("没有系统用户"))
return
forwarder = ProxyServer(self.app, self.client)
forwarder = ProxyServer(self.client)
forwarder.proxy(asset, system_user)
def interact(self):
self.display_banner()
while True:
try:
opt = self.get_option()
opt = net_input(self.client, prompt='Opt> ', before=1)
rv = self.dispatch(opt)
if rv is self._sentinel:
break
......@@ -318,7 +281,7 @@ class InteractiveServer:
thread.start()
def close(self):
self.app.remove_client(self.client)
current_app.remove_client(self.client)
# def __del__(self):
# print("GC: Interactive class been gc")
......@@ -4,9 +4,9 @@
import paramiko
import threading
import weakref
from .utils import get_logger
from .ctx import current_app, app_service
logger = get_logger(__file__)
......@@ -19,22 +19,13 @@ class SSHInterface(paramiko.ServerInterface):
https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py
"""
def __init__(self, app, request):
self._app = weakref.ref(app)
self._request = weakref.ref(request)
def __init__(self, request):
self.request = request
self.event = threading.Event()
self.auth_valid = False
self.otp_auth = False
self.info = None
@property
def app(self):
return self._app()
@property
def request(self):
return self._request()
def check_auth_interactive(self, username, submethods):
logger.info("Check auth interactive: %s %s" % (username, submethods))
instructions = 'Please enter 6 digits.'
......@@ -55,7 +46,7 @@ class SSHInterface(paramiko.ServerInterface):
if not seed:
return paramiko.AUTH_FAILED
is_valid = self.app.service.authenticate_otp(seed, otp_code)
is_valid = app_service.authenticate_otp(seed, otp_code)
if is_valid:
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
......@@ -67,9 +58,9 @@ class SSHInterface(paramiko.ServerInterface):
supported = []
if self.otp_auth:
return 'keyboard-interactive'
if self.app.config["PASSWORD_AUTH"]:
if current_app.config["PASSWORD_AUTH"]:
supported.append("password")
if self.app.config["PUBLIC_KEY_AUTH"]:
if current_app.config["PUBLIC_KEY_AUTH"]:
supported.append("publickey")
return ",".join(supported)
......@@ -100,7 +91,7 @@ class SSHInterface(paramiko.ServerInterface):
return paramiko.AUTH_SUCCESSFUL
def validate_auth(self, username, password="", public_key=""):
info = self.app.service.authenticate(
info = app_service.authenticate(
username, password=password, public_key=public_key,
remote_addr=self.request.remote_ip
)
......
......@@ -49,6 +49,8 @@ def create_logger(app):
'coco': main_setting,
'paramiko': main_setting,
'jms': main_setting,
'socket.io': main_setting,
'engineio': main_setting,
}
)
......
......@@ -94,8 +94,9 @@ class Server:
"""
# Todo: Server name is not very suitable
def __init__(self, chan, asset, system_user):
def __init__(self, chan, sock, asset, system_user):
self.chan = chan
self.sock = sock
self.asset = asset
self.system_user = system_user
self.send_bytes = 0
......@@ -168,6 +169,8 @@ class Server:
self.stop_evt.set()
self.chan.close()
self.chan.transport.close()
if self.sock:
self.sock.transport.close()
@staticmethod
def _have_enter_char(s):
......@@ -218,7 +221,7 @@ class WSProxy:
```
"""
def __init__(self, ws, child, room, connection):
def __init__(self, ws, child, room_id):
"""
:param ws: websocket instance or handler, have write_message method
:param child: sock child pair
......@@ -226,9 +229,8 @@ class WSProxy:
self.ws = ws
self.child = child
self.stop_event = threading.Event()
self.room = room
self.room_id = room_id
self.auto_forward()
self.connection = connection
def send(self, msg):
"""
......@@ -247,12 +249,15 @@ class WSProxy:
while not self.stop_event.is_set():
try:
data = self.child.recv(BUF_SIZE)
except OSError:
continue
if len(data) == 0:
except (OSError, EOFError):
self.close()
break
if not data:
self.close()
break
data = data.decode(errors="ignore")
self.ws.emit("data", {'data': data, 'room': self.connection}, room=self.room)
self.ws.emit("data", {'data': data, 'room': self.room_id},
room=self.room_id)
if len(data) == BUF_SIZE:
time.sleep(0.1)
......@@ -262,11 +267,12 @@ class WSProxy:
thread.start()
def close(self):
self.ws.emit("logout", {"room": self.room_id}, room=self.room_id)
self.stop_event.set()
self.child.close()
self.ws.logout(self.connection)
try:
self.child.shutdown(1)
self.child.close()
except (OSError, EOFError):
pass
logger.debug("Proxy {} closed".format(self))
......@@ -4,15 +4,15 @@
import threading
import time
import weakref
from paramiko.ssh_exception import SSHException
from .session import Session
from .models import Server
from .connection import SSHConnection
from .ctx import current_app, app_service
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
get_logger
get_logger, net_input
logger = get_logger(__file__)
......@@ -21,42 +21,51 @@ BUF_SIZE = 4096
class ProxyServer:
def __init__(self, app, client):
self._app = weakref.ref(app)
def __init__(self, client):
self.client = client
self.server = None
self.connecting = True
self.stop_event = threading.Event()
@property
def app(self):
return self._app()
def get_system_user_auth(self, system_user):
"""
获取系统用户的认证信息,密码或秘钥
:return: system user have full info
"""
password, private_key = \
app_service.get_system_user_auth_info(system_user)
if not password and not private_key:
prompt = "{}'s password: ".format(system_user.username)
password = net_input(self.client, prompt=prompt, sensitive=True)
system_user.password = password
system_user.private_key = private_key
def proxy(self, asset, system_user):
self.get_system_user_auth(system_user)
self.send_connecting_message(asset, system_user)
self.server = self.get_server_conn(asset, system_user)
if self.server is None:
return
command_recorder = self.app.new_command_recorder()
replay_recorder = self.app.new_replay_recorder()
command_recorder = current_app.new_command_recorder()
replay_recorder = current_app.new_replay_recorder()
session = Session(
self.client, self.server,
command_recorder=command_recorder,
replay_recorder=replay_recorder,
)
self.app.add_session(session)
current_app.add_session(session)
self.watch_win_size_change_async()
session.bridge()
self.stop_event.set()
self.end_watch_win_size_change()
self.app.remove_session(session)
current_app.remove_session(session)
def validate_permission(self, asset, system_user):
"""
验证用户是否有连接改资产的权限
:return: True or False
"""
return self.app.service.validate_user_asset_permission(
return app_service.validate_user_asset_permission(
self.client.user.id, asset.id, system_user.id
)
......@@ -76,18 +85,19 @@ class ProxyServer:
pass
def get_ssh_server_conn(self, asset, system_user):
ssh = SSHConnection(self.app)
request = self.client.request
term = request.meta.get('term', 'xterm')
width = request.meta.get('width', 80)
height = request.meta.get('height', 24)
chan, msg = ssh.get_channel(asset, system_user, term=term,
width=width, height=height)
ssh = SSHConnection()
chan, sock, msg = ssh.get_channel(
asset, system_user, term=term, width=width, height=height
)
if not chan:
self.client.send(warning(wr(msg, before=1, after=0)))
server = None
else:
server = Server(chan, asset, system_user)
server = Server(chan, sock, asset, system_user)
self.connecting = False
self.client.send(b'\r\n')
return server
......@@ -116,9 +126,11 @@ class ProxyServer:
def send_connecting_message(self, asset, system_user):
def func():
delay = 0.0
self.client.send('Connecting to {}@{} {:.1f}'.format(system_user, asset, delay))
self.client.send('Connecting to {}@{} {:.1f}'.format(
system_user, asset, delay)
)
while self.connecting and delay < TIMEOUT:
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode('utf-8'))
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode())
time.sleep(0.1)
delay += 0.1
thread = threading.Thread(target=func)
......
......@@ -8,94 +8,27 @@ import time
import os
import gzip
import json
import shutil
from copy import deepcopy
import jms_storage
from .utils import get_logger
from .utils import get_logger, Singleton
from .alignment import MemoryQueue
from .ctx import current_app, app_service
logger = get_logger(__file__)
BUF_SIZE = 1024
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
class ReplayRecorder(metaclass=abc.ABCMeta):
def __init__(self, app, session=None):
self.app = app
self.session = session
@abc.abstractmethod
def record(self, data):
"""
记录replay数据
:param data: 数据 {
"session": "",
"data": "",
"timestamp": ""
}
:return:
"""
@abc.abstractmethod
def session_start(self, session_id):
print("Session start: {}".format(session_id))
pass
@abc.abstractmethod
def session_end(self, session_id):
print("Session end: {}".format(session_id))
pass
class CommandRecorder:
def __init__(self, app, session=None):
self.app = app
self.session = session
def record(self, data):
"""
:param data: 数据 {
"session":
"input":
"output":
"user":
"asset":
"system_user":
"timestamp":
}
:return:
"""
def session_start(self, session_id):
print("Session start: {}".format(session_id))
pass
def session_end(self, session_id):
print("Session end: {}".format(session_id))
pass
class ServerReplayRecorder(ReplayRecorder):
time_start = None
storage = None
def __init__(self, app):
super().__init__(app)
def __init__(self):
super().__init__()
self.file = None
self.file_path = None
self.get_storage()
def record(self, data):
"""
......@@ -114,78 +47,76 @@ class ServerReplayRecorder(ReplayRecorder):
def session_start(self, session_id):
self.time_start = time.time()
filename = session_id+'.replay.gz'
self.file_path = os.path.join(self.app.config['LOG_DIR'], filename)
filename = session_id + '.replay.gz'
self.file_path = os.path.join(current_app.config['LOG_DIR'], filename)
self.file = gzip.open(self.file_path, 'at')
self.file.write('{')
def session_end(self, session_id):
self.file.write('"0":""}')
self.file.close()
if self.upload_replay(session_id):
logger.info("Succeed to push {}'s {}".format(session_id, "record"))
else:
logger.error("Failed to push {}'s {}".format(session_id, "record"))
self.upload_replay(session_id)
def upload_replay(self, session_id):
configs = self.app.service.load_config_from_server()
logger.debug("upload_replay print config: {}".format(configs))
self.storage = jms_storage.init(configs["REPLAY_STORAGE"])
if not self.storage:
self.storage = jms_storage.jms(self.app.service)
if self.push_file(3, session_id):
def get_storage(self):
config = deepcopy(current_app.config["REPLAY_STORAGE"])
config["SERVICE"] = app_service
self.storage = jms_storage.get_object_storage(config)
def upload_replay(self, session_id, times=3):
if times < 1:
if self.storage.type == 'jms':
return False
else:
self.storage = jms_storage.JMSReplayStorage(app_service)
self.upload_replay(session_id, times=3)
ok, msg = self.push_to_storage(session_id)
if not ok:
msg = 'Failed push replay file: {}, try again {}'.format(msg, times)
logger.warn(msg)
self.upload_replay(session_id, times-1)
else:
msg = 'Success push replay file: {}'.format(session_id)
logger.info(msg)
self.finish_replay(3, session_id)
os.unlink(self.file_path)
return True
else:
return False
def push_to_storage(self, session_id):
dt = time.strftime('%Y-%m-%d', time.localtime(self.time_start))
target = dt + '/' + session_id + '.replay.gz'
return self.storage.upload_file(self.file_path, target)
def push_file(self, times, session_id):
if times < 0:
if self.storage.type() == 'jms':
return False
else:
msg = "Failed push session {}'s replay log to storage".format(session_id)
logger.error(msg)
self.storage = jms_storage.jms(self.app.service)
return self.push_file(3, session_id)
if self.push_to_storage(session_id):
logger.info("Success push session: {}'s replay log to storage ".format(session_id))
return True
else:
msg = "Failed push session {}'s replay log to storage, try {} times".format(session_id, times)
logger.error(msg)
return self.push_file(times - 1, session_id)
return self.storage.upload(self.file_path, target)
def finish_replay(self, times, session_id):
if times < 0:
logger.error("Failed finished session {}'s replay".format(session_id))
if times < 1:
logger.error(
"Failed finished session {}'s replay".format(session_id)
)
return False
if self.app.service.finish_replay(session_id):
logger.info("Success finish session {}'s replay ".format(session_id))
if app_service.finish_replay(session_id):
logger.info(
"Success finish session {}'s replay ".format(session_id)
)
return True
else:
logger.error("Failed finish session {}'s replay, try {} times".format(session_id, times))
msg = "Failed finish session {}'s replay, try {} times"
logger.error(msg.format(session_id, times))
return self.finish_replay(times - 1, session_id)
class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
class CommandRecorder(metaclass=Singleton):
batch_size = 10
timeout = 5
no = 0
storage = None
def __init__(self, app):
super().__init__(app)
def __init__(self):
super().__init__()
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_server_async()
self.__class__.no += 1
self.get_storage()
def record(self, data):
if data and data['input']:
......@@ -194,70 +125,22 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
data['timestamp'] = int(data['timestamp'])
self.queue.put(data)
def get_storage(self):
config = deepcopy(current_app.config["COMMAND_STORAGE"])
config['SERVICE'] = app_service
self.storage = jms_storage.get_log_storage(config)
def push_to_server_async(self):
def func():
while not self.stop_evt.is_set():
data_set = self.queue.mget(self.batch_size, timeout=self.timeout)
logger.debug("<Session command recorder {}> queue size: {}".format(
self.no, self.queue.qsize())
)
size = self.queue.qsize()
if size > 0:
logger.debug("Session command remain push: {}".format(size))
if not data_set:
continue
logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.app.service.push_session_command(data_set)
if not ok:
self.queue.mput(data_set)
thread = threading.Thread(target=func)
thread.daemon = True
thread.start()
def session_start(self, session_id):
pass
def session_end(self, session_id):
pass
# def __del__(self):
# print("GC: Session command storage has been gc")
class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
batch_size = 10
timeout = 5
no = 0
default_hosts = ["http://localhost"]
def __init__(self, app):
super().__init__(app)
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_es_async()
self.__class__.no += 1
self.store = jms_storage.ESStore(app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts))
if not self.store.ping():
raise AssertionError("ESCommand storage init error")
def record(self, data):
if data and data['input']:
data['input'] = data['input'][:128]
data['output'] = data['output'][:1024]
data['timestamp'] = int(data['timestamp'])
self.queue.put(data)
def push_to_es_async(self):
def func():
while not self.stop_evt.is_set():
data_set = self.queue.mget(self.batch_size,
timeout=self.timeout)
logger.debug(
"<Session command recorder {}> queue size: {}".format(
self.no, self.queue.qsize())
)
if not data_set:
continue
logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.store.bulk_save(data_set)
ok = self.storage.bulk_save(data_set)
if not ok:
self.queue.mput(data_set)
......@@ -266,25 +149,9 @@ class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
thread.start()
def session_start(self, session_id):
print("Session start: {}".format(session_id))
pass
def session_end(self, session_id):
print("Session end: {}".format(session_id))
pass
# def __del__(self):
# print("GC: ES command storage has been gc".format(self))
def get_command_recorder_class(config):
command_storage = config["COMMAND_STORAGE"]
storage_type = command_storage.get('TYPE')
if storage_type == "elasticsearch":
return ESCommandRecorder
else:
return ServerCommandRecorder
#
# def get_replay_recorder_class(config):
# ServerReplayRecorder.client = jms_storage.init(config["REPLAY_STORAGE"])
# return ServerReplayRecorder
......@@ -40,7 +40,7 @@ class Session:
"""
logger.info("Session add watcher: {} -> {} ".format(self.id, watcher))
if not silent:
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode("utf-8"))
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode())
self.sel.register(watcher, selectors.EVENT_READ)
self._watchers.append(watcher)
......
......@@ -2,6 +2,7 @@ import os
import tempfile
import paramiko
import time
from .ctx import app_service
from datetime import datetime
from .connection import SSHConnection
......@@ -16,6 +17,17 @@ class SFTPServer(paramiko.SFTPServerInterface):
self._sftp = {}
self.hosts = self.get_perm_hosts()
def session_ended(self):
super().session_ended()
for _, v in self._sftp.items():
sftp = v['sftp']
sock = v.get('sock')
sftp.close()
if sock:
sock.close()
sock.transport.close()
self._sftp = {}
def get_host_sftp(self, host, su):
asset = self.hosts.get(host)
system_user = None
......@@ -28,18 +40,18 @@ class SFTPServer(paramiko.SFTPServerInterface):
raise OSError("No asset or system user explicit")
if host not in self._sftp:
ssh = SSHConnection(self.server.app)
sftp, msg = ssh.get_sftp(asset, system_user)
ssh = SSHConnection()
sftp, sock, msg = ssh.get_sftp(asset, system_user)
if sftp:
self._sftp[host] = sftp
self._sftp[host] = {'sftp': sftp, 'sock': sock}
return sftp
else:
raise OSError("Can not connect asset sftp server")
raise OSError("Can not connect asset sftp server: {}".format(msg))
else:
return self._sftp[host]
return self._sftp[host]['sftp']
def get_perm_hosts(self):
assets = self.server.app.service.get_user_assets(
assets = app_service.get_user_assets(
self.server.request.user
)
return {asset.hostname: asset for asset in assets}
......@@ -89,7 +101,7 @@ class SFTPServer(paramiko.SFTPServerInterface):
"is_success": is_success,
}
for i in range(1, 4):
ok = self.server.app.service.create_ftp_log(data)
ok = app_service.create_ftp_log(data)
if ok:
break
else:
......
......@@ -12,6 +12,7 @@ from .interface import SSHInterface
from .interactive import InteractiveServer
from .models import Client, Request
from .sftp import SFTPServer
from .ctx import current_app
logger = get_logger(__file__)
BACKLOG = 5
......@@ -19,38 +20,41 @@ BACKLOG = 5
class SSHServer:
def __init__(self, app):
self.app = app
def __init__(self):
self.stop_evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.host_key_path = os.path.join(self.app.root_path, 'keys', 'host_rsa_key')
self.workers = []
self.pipe = None
@property
def host_key(self):
if not os.path.isfile(self.host_key_path):
self.gen_host_key()
return paramiko.RSAKey(filename=self.host_key_path)
host_key_path = os.path.join(current_app.root_path, 'keys', 'host_rsa_key')
if not os.path.isfile(host_key_path):
self.gen_host_key(host_key_path)
return paramiko.RSAKey(filename=host_key_path)
def gen_host_key(self):
@staticmethod
def gen_host_key(key_path):
ssh_key, _ = ssh_key_gen()
with open(self.host_key_path, 'w') as f:
with open(key_path, 'w') as f:
f.write(ssh_key)
def run(self):
host = self.app.config["BIND_HOST"]
port = self.app.config["SSHD_PORT"]
host = current_app.config["BIND_HOST"]
port = current_app.config["SSHD_PORT"]
print('Starting ssh server at {}:{}'.format(host, port))
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((host, port))
self.sock.listen(BACKLOG)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(BACKLOG)
while not self.stop_evt.is_set():
try:
sock, addr = self.sock.accept()
logger.info("Get ssh request from {}: {}".format(addr[0], addr[1]))
thread = threading.Thread(target=self.handle_connection, args=(sock, addr))
client, addr = sock.accept()
logger.info("Get ssh request from {}: {}".format(*addr))
thread = threading.Thread(target=self.handle_connection,
args=(client, addr))
thread.daemon = True
thread.start()
except Exception as e:
except IndexError as e:
logger.error("Start SSH server error: {}".format(e))
def handle_connection(self, sock, addr):
......@@ -65,7 +69,7 @@ class SSHServer:
'sftp', paramiko.SFTPServer, SFTPServer
)
request = Request(addr)
server = SSHInterface(self.app, request)
server = SSHInterface(request)
try:
transport.start_server(server=server)
except paramiko.SSHException:
......@@ -96,7 +100,7 @@ class SSHServer:
def handle_chan(self, chan, request):
client = Client(chan, request)
self.app.add_client(client)
current_app.add_client(client)
self.dispatch(client)
def dispatch(self, client):
......@@ -104,7 +108,7 @@ class SSHServer:
request_type = set(client.request.type)
if supported & request_type:
logger.info("Request type `pty`, dispatch to interactive mode")
InteractiveServer(self.app, client).interact()
InteractiveServer(client).interact()
elif 'subsystem' in request_type:
pass
else:
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
import weakref
from .ctx import current_app, app_service
from .utils import get_logger
logger = get_logger(__file__)
class TaskHandler:
def __init__(self):
self.routes = {
'kill_session': self.handle_kill_session
}
def __init__(self, app):
self._app = weakref.ref(app)
@property
def app(self):
return self._app()
def handle_kill_session(self, task):
@staticmethod
def handle_kill_session(task):
logger.info("Handle kill session task: {}".format(task.args))
session_id = task.args
session = None
for s in self.app.sessions:
for s in current_app.sessions:
if s.id == session_id:
session = s
break
if session:
session.terminate()
self.app.service.finish_task(task.id)
app_service.finish_task(task.id)
def handle(self, task):
if task.name == "kill_session":
self.handle_kill_session(task)
else:
logger.error("No handler for this task: {}".format(task.name))
func = self.routes.get(task.name)
return func(task)
......@@ -4,30 +4,35 @@
from __future__ import unicode_literals
import hashlib
import logging
import re
import os
import threading
import base64
import calendar
import time
import datetime
import gettext
from io import StringIO
from binascii import hexlify
import paramiko
import pyte
import pytz
from email.utils import formatdate
from queue import Queue, Empty
from .exception import NoAppException
from . import char
from .ctx import stack
BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
def ssh_key_string_to_obj(text, password=None):
key = None
try:
......@@ -289,17 +294,130 @@ def get_logger(file_name):
return logging.getLogger('coco.'+file_name)
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+')
def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
"""实现了一个ssh input, 提示用户输入, 获取并返回
def len_display(s):
length = 0
for i in s:
if zh_pattern.match(i):
length += 2
:return user input string
"""
input_data = []
parser = TtyIOParser()
client.send(wrap_with_line_feed(prompt, before=before, after=after))
while True:
data = client.recv(10)
if len(data) == 0:
break
# Client input backspace
if data in char.BACKSPACE_CHAR:
# If input words less than 0, should send 'BELL'
if len(input_data) > 0:
data = char.BACKSPACE_CHAR[data]
input_data.pop()
else:
data = char.BELL_CHAR
client.send(data)
continue
if data.startswith(b'\x03'):
# Ctrl-C
client.send('^C\r\n{} '.format(prompt).encode())
input_data = []
continue
elif data.startswith(b'\x04'):
# Ctrl-D
return 'q'
# Todo: Move x1b to char
if data.startswith(b'\x1b') or data in char.UNSUPPORTED_CHAR:
client.send(b'')
continue
# handle shell expect
multi_char_with_enter = False
if len(data) > 1 and data[-1] in char.ENTER_CHAR_ORDER:
if sensitive:
client.send(len(data) * '*')
else:
client.send(data)
input_data.append(data[:-1])
multi_char_with_enter = True
# If user type ENTER we should get user input
if data in char.ENTER_CHAR or multi_char_with_enter:
client.send(wrap_with_line_feed(b'', after=2))
option = parser.parse_input(input_data)
del input_data[:]
return option.strip()
else:
length += 1
if sensitive:
client.send(len(data) * '*')
else:
client.send(data)
input_data.append(data)
def register_app(app):
stack['app'] = app
def register_service(service):
stack['service'] = service
zh_pattern = re.compile(r'[\u4e00-\u9fa5]')
def find_chinese(s):
return zh_pattern.findall(s)
def align_with_zh(s, length, addin=' '):
if not isinstance(s, str):
s = str(s)
zh_len = len(find_chinese(s))
padding = length - (len(s) - zh_len) - zh_len*2
padding_content = ''
if padding > 0:
padding_content = addin*padding
return s + padding_content
def format_with_zh(size_list, *args):
data = []
for length, s in zip(size_list, args):
data.append(align_with_zh(s, length))
return ' '.join(data)
def size_of_str_with_zh(s):
if isinstance(s, int):
s = str(s)
try:
chinese = find_chinese(s)
except TypeError:
raise
return len(s) + len(chinese)
def item_max_length(_iter, maxi=None, mini=None, key=None):
if key:
_iter = [key(i) for i in _iter]
length = [size_of_str_with_zh(s) for s in _iter]
if not length:
return 1
if maxi:
length.append(maxi)
length = max(length)
if mini and length < mini:
length = mini
return length
def int_length(i):
return len(str(i))
ugettext = _gettext()
......@@ -12,15 +12,14 @@ cryptography==2.1.4
docutils==0.14
dotmap==1.2.20
elasticsearch==6.1.1
Flask==0.12.2
Flask==1.0.2
Flask-SocketIO==2.9.2
idna==2.6
itsdangerous==0.24
Jinja2==2.10
jmespath==0.9.3
jms-es-sdk==0.5.2
jms-storage==0.0.12
jumpserver-python-sdk==0.0.41
jms-storage==0.0.17
jumpserver-python-sdk==0.0.42
MarkupSafe==1.0
oss2==2.4.0
paramiko==2.4.0
......@@ -28,9 +27,9 @@ psutil==5.4.1
pyasn1==0.4.2
pycparser==2.18
PyNaCl==1.2.1
pyte==0.7.0
pyte==0.8.0
python-dateutil==2.6.1
python-engineio==2.0.1
python-engineio==2.1.0
python-gssapi==0.6.4
python-socketio==1.8.3
pytz==2017.3
......@@ -41,4 +40,5 @@ six==1.11.0
tornado==4.5.2
urllib3==1.22
wcwidth==0.1.7
Werkzeug==0.12.2
Werkzeug==0.14.1
eventlet==0.22
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment