Commit 4d2015df authored by 广宏伟's avatar 广宏伟

Merged in dev (pull request #56)

Dev
parents 27dd5afb 17d01a8e
...@@ -18,7 +18,7 @@ from .httpd import HttpServer ...@@ -18,7 +18,7 @@ from .httpd import HttpServer
from .logger import create_logger from .logger import create_logger
from .tasks import TaskHandler from .tasks import TaskHandler
from .recorder import get_command_recorder_class, ServerReplayRecorder from .recorder import get_command_recorder_class, ServerReplayRecorder
from .utils import get_logger from .utils import get_logger, register_app, register_service
__version__ = '1.3.0' __version__ = '1.3.0'
...@@ -67,6 +67,7 @@ class Coco: ...@@ -67,6 +67,7 @@ class Coco:
self.replay_recorder_class = None self.replay_recorder_class = None
self.command_recorder_class = None self.command_recorder_class = None
self._task_handler = None self._task_handler = None
register_app(self)
@property @property
def name(self): def name(self):
...@@ -79,24 +80,25 @@ class Coco: ...@@ -79,24 +80,25 @@ class Coco:
def service(self): def service(self):
if self._service is None: if self._service is None:
self._service = AppService(self) self._service = AppService(self)
register_service(self._service)
return self._service return self._service
@property @property
def sshd(self): def sshd(self):
if self._sshd is None: if self._sshd is None:
self._sshd = SSHServer(self) self._sshd = SSHServer()
return self._sshd return self._sshd
@property @property
def httpd(self): def httpd(self):
if self._httpd is None: if self._httpd is None:
self._httpd = HttpServer(self) self._httpd = HttpServer()
return self._httpd return self._httpd
@property @property
def task_handler(self): def task_handler(self):
if self._task_handler is None: if self._task_handler is None:
self._task_handler = TaskHandler(self) self._task_handler = TaskHandler()
return self._task_handler return self._task_handler
def make_logger(self): def make_logger(self):
...@@ -114,11 +116,10 @@ class Coco: ...@@ -114,11 +116,10 @@ class Coco:
self.command_recorder_class = get_command_recorder_class(self.config) self.command_recorder_class = get_command_recorder_class(self.config)
def new_command_recorder(self): def new_command_recorder(self):
recorder = self.command_recorder_class(self) return self.command_recorder_class()
return recorder
def new_replay_recorder(self): def new_replay_recorder(self):
return self.replay_recorder_class(self) return self.replay_recorder_class()
def bootstrap(self): def bootstrap(self):
self.make_logger() self.make_logger()
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import weakref
import os import os
import socket import socket
import paramiko import paramiko
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
from .ctx import app_service
from .utils import get_logger, get_private_key_fingerprint from .utils import get_logger, get_private_key_fingerprint
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -15,20 +15,13 @@ TIMEOUT = 10 ...@@ -15,20 +15,13 @@ TIMEOUT = 10
class SSHConnection: class SSHConnection:
def __init__(self, app):
self._app = weakref.ref(app)
@property
def app(self):
return self._app()
def get_system_user_auth(self, system_user): def get_system_user_auth(self, system_user):
""" """
获取系统用户的认证信息,密码或秘钥 获取系统用户的认证信息,密码或秘钥
:return: system user have full info :return: system user have full info
""" """
password, private_key = \ password, private_key = \
self.app.service.get_system_user_auth_info(system_user) app_service.get_system_user_auth_info(system_user)
system_user.password = password system_user.password = password
system_user.private_key = private_key system_user.private_key = private_key
...@@ -97,7 +90,7 @@ class SSHConnection: ...@@ -97,7 +90,7 @@ class SSHConnection:
def get_proxy_sock(self, asset): def get_proxy_sock(self, asset):
sock = None sock = None
domain = self.app.service.get_domain_detail_with_gateway( domain = app_service.get_domain_detail_with_gateway(
asset.domain asset.domain
) )
if not domain.has_ssh_gateway(): if not domain.has_ssh_gateway():
......
# -*- coding: utf-8 -*-
#
from werkzeug.local import LocalProxy
from functools import partial
stack = {}
def _find(name):
if stack.get(name):
return stack[name]
else:
raise ValueError("Not found in stack: {}".format(name))
current_app = LocalProxy(partial(_find, 'app'))
app_service = LocalProxy(partial(_find, 'service'))
# current_app = []
# current_service = []
...@@ -4,28 +4,23 @@ ...@@ -4,28 +4,23 @@
import os import os
import socket import socket
import uuid import uuid
from flask_socketio import SocketIO, Namespace, join_room, leave_room import traceback
from flask_socketio import SocketIO, Namespace, join_room
from flask import Flask, request, current_app, redirect from flask import Flask, request, current_app, redirect
from .models import Request, Client, WSProxy from .models import Request, Client, WSProxy
from .proxy import ProxyServer from .proxy import ProxyServer
from .utils import get_logger from .utils import get_logger
from .ctx import current_app, app_service
__version__ = '0.5.0'
BASE_DIR = os.path.dirname(os.path.dirname(__file__)) BASE_DIR = os.path.dirname(os.path.dirname(__file__))
logger = get_logger(__file__) logger = get_logger(__file__)
class BaseNamespace(Namespace): class BaseNamespace(Namespace):
clients = None
current_user = None current_user = None
@property
def app(self):
app = current_app.config['coco']
return app
def on_connect(self): def on_connect(self):
self.current_user = self.get_current_user() self.current_user = self.get_current_user()
if self.current_user is None: if self.current_user is None:
...@@ -38,230 +33,219 @@ class BaseNamespace(Namespace): ...@@ -38,230 +33,219 @@ class BaseNamespace(Namespace):
token = request.headers.get("Authorization") token = request.headers.get("Authorization")
user = None user = None
if session_id and csrf_token: if session_id and csrf_token:
user = self.app.service.check_user_cookie(session_id, csrf_token) user = app_service.check_user_cookie(session_id, csrf_token)
if token: if token:
user = self.app.service.check_user_with_token(token) user = app_service.check_user_with_token(token)
return user return user
def close(self):
try:
self.clients[request.sid]["client"].close()
except:
pass
class ProxyNamespace(BaseNamespace): class ProxyNamespace(BaseNamespace):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""
:param args:
:param kwargs:
self.connections = {
"request_sid": {
"room_id": {
"id": room_id,
"proxy": None,
"client": None,
"forwarder": None,
"request": None,
"cols": 80,
"rows": 24
},
...
},
...
}
"""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.clients = dict() self.connections = dict()
self.rooms = dict()
def new_connection(self):
def new_client(self): self.connections[request.sid] = dict()
room = str(uuid.uuid4())
client = { def new_room(self):
"cols": int(request.cookies.get('cols', 80)), room_id = str(uuid.uuid4())
"rows": int(request.cookies.get('rows', 24)), room = {
"room": room, "id": room_id,
"proxy": dict(), "proxy": None,
"client": dict(), "client": None,
"forwarder": dict(), "forwarder": None,
"request": self.make_coco_request() "request": self.make_coco_request(),
"cols": 80,
"rows": 24
} }
return client self.connections[request.sid][room_id] = room
return room
def make_coco_request(self):
x_forwarded_for = request.headers.get("X-Forwarded-For", '').split(',')
if x_forwarded_for and x_forwarded_for[0]:
remote_ip = x_forwarded_for[0]
else:
remote_ip = request.remote_addr
width_request = request.cookies.get('cols') @staticmethod
def get_win_size():
cols_request = request.cookies.get('cols')
rows_request = request.cookies.get('rows') rows_request = request.cookies.get('rows')
if width_request and width_request.isdigit(): if cols_request and cols_request.isdigit():
width = int(width_request) cols = int(cols_request)
else: else:
width = 80 cols = 80
if rows_request and rows_request.isdigit(): if rows_request and rows_request.isdigit():
rows = int(rows_request) rows = int(rows_request)
else: else:
rows = 24 rows = 24
return cols, rows
def make_coco_request(self):
x_forwarded_for = request.headers.get("X-Forwarded-For", '').split(',')
if x_forwarded_for and x_forwarded_for[0]:
remote_ip = x_forwarded_for[0]
else:
remote_ip = request.remote_addr
width, height = self.get_win_size()
req = Request((remote_ip, 0)) req = Request((remote_ip, 0))
req.user = self.current_user req.user = self.current_user
req.meta = { req.meta = {
"width": width, "width": width,
"height": rows, "height": height,
} }
return req return req
def on_connect(self): def on_connect(self):
logger.debug("On connect event trigger") logger.debug("On connect event trigger")
super().on_connect() super().on_connect()
client = self.new_client() self.new_connection()
self.clients[request.sid] = client
self.rooms[client['room']] = {
"admin": request.sid,
"member": [],
"rw": []
}
join_room(client['room'])
def on_data(self, message):
"""
收到浏览器请求
:param message: {"data": "xxx", "room": "xxx"}
:return:
"""
room = message.get('room')
if not room:
return
room_proxy = self.clients[request.sid]['proxy'].get(room)
if room_proxy:
room_proxy.send({"data": message['data']})
def on_host(self, message): def on_host(self, message):
# 此处获取主机的信息 # 此处获取主机的信息
logger.debug("On host event trigger") logger.debug("On host event trigger")
connection = str(uuid.uuid4())
asset_id = message.get('uuid', None) asset_id = message.get('uuid', None)
user_id = message.get('userid', None) user_id = message.get('userid', None)
secret = message.get('secret', None) secret = message.get('secret', None)
room = self.new_room()
self.emit('room', {'room': connection, 'secret': secret}) self.emit('room', {'room': room["id"], 'secret': secret})
join_room(room["id"])
if not asset_id or not user_id: if not asset_id or not user_id:
# self.on_connect() # self.on_connect()
return return
asset = self.app.service.get_asset(asset_id) asset = app_service.get_asset(asset_id)
system_user = self.app.service.get_system_user(user_id) system_user = app_service.get_system_user(user_id)
if not asset or not system_user: if not asset or not system_user:
self.on_connect() self.on_connect()
return return
child, parent = socket.socketpair() child, parent = socket.socketpair()
self.clients[request.sid]["client"][connection] = Client( client = Client(parent, room["request"])
parent, self.clients[request.sid]["request"] forwarder = ProxyServer(client)
) room["client"] = client
self.clients[request.sid]["proxy"][connection] = WSProxy( room["forwarder"] = forwarder
self, child, self.clients[request.sid]["room"], connection room["proxy"] = WSProxy(self, child, room["id"])
) room["cols"], room["rows"] = self.get_win_size()
self.clients[request.sid]["forwarder"][connection] = ProxyServer(
self.app, self.clients[request.sid]["client"][connection]
)
self.socketio.start_background_task( self.socketio.start_background_task(
self.clients[request.sid]["forwarder"][connection].proxy, forwarder.proxy, asset, system_user
asset, system_user
) )
def on_data(self, message):
"""
收到浏览器请求
:param message: {"data": "xxx", "room": "xxx"}
:return:
"""
room_id = message.get('room')
room = self.connections.get(request.sid, {}).get(room_id)
if not room:
return
room["proxy"].send({"data": message['data']})
def on_token(self, message): def on_token(self, message):
# 此处获取token含有的主机的信息 # 此处获取token含有的主机的信息
logger.debug("On token trigger") logger.debug("On token trigger")
logger.debug(message)
token = message.get('token', None) token = message.get('token', None)
secret = message.get('secret', None) secret = message.get('secret', None)
connection = str(uuid.uuid4()) room = self.new_room()
self.emit('room', {'room': connection, 'secret': secret}) self.emit('room', {'room': room["id"], 'secret': secret})
if not (token or secret): if not token or not secret:
logger.debug("token or secret is None") logger.debug("Token or secret is None")
self.emit('data', {'data': "\nOperation not permitted!", 'room': connection}) self.emit('data', {'data': "\nOperation not permitted!",
'room': room["id"]})
self.emit('disconnect') self.emit('disconnect')
return None return None
host = self.app.service.get_token_asset(token) info = app_service.get_token_asset(token)
logger.debug(host) logger.debug(info)
if not host: if not info:
logger.debug("host is None") logger.debug("Token info is None")
self.emit('data', {'data': "\nOperation not permitted!", 'room': connection}) self.emit('data', {'data': "\nOperation not permitted!",
'room': room["id"]})
self.emit('disconnect') self.emit('disconnect')
return None return None
user_id = host.get('user', None) user_id = info.get('user', None)
logger.debug("self.current_user") self.current_user = app_service.get_user_profile(user_id)
self.current_user = self.app.service.get_user_profile(user_id) room["request"].user = self.current_user
self.clients[request.sid]["request"].user = self.current_user
logger.debug(self.current_user) logger.debug(self.current_user)
self.on_host({'secret': secret, 'uuid': host['asset'], 'userid': host['system_user']}) self.on_host({
'secret': secret,
'uuid': info['asset'],
'userid': info['system_user'],
})
def on_resize(self, message): def on_resize(self, message):
cols = message.get('cols') cols, rows = message.get('cols', None), message.get('rows', None)
rows = message.get('rows')
logger.debug("On resize event trigger: {}*{}".format(cols, rows)) logger.debug("On resize event trigger: {}*{}".format(cols, rows))
if cols and rows and self.clients[request.sid]["request"]: rooms = self.connections.get(request.sid)
self.clients[request.sid]["request"].meta['width'] = cols if not rooms:
self.clients[request.sid]["request"].meta['height'] = rows return
self.clients[request.sid]["request"].change_size_event.set() room = list(rooms.values())[0]
if rooms and (room["cols"], room["rows"]) != (cols, rows):
def on_room(self, session_id): for room in rooms.values():
logger.debug("On room event trigger") room["request"].meta.update({
if session_id not in self.clients.keys(): 'width': cols, 'height': rows
self.emit( })
'error', "no such session", room["request"].change_size_event.set()
room=self.clients[request.sid]["room"] room.update({"cols": cols, "rows": rows})
)
else:
self.emit(
'room', self.clients[session_id]["room"],
room=self.clients[request.sid]["room"]
)
def on_join(self, room):
logger.debug("On join room event trigger")
self.on_leave(self.clients[request.id]["room"])
self.clients[request.sid]["room"] = room
self.rooms[room]["member"].append(request.sid)
join_room(room=room)
def on_leave(self, room):
logger.debug("On leave room event trigger")
if self.rooms[room]["admin"] == request.sid:
self.emit("data", "\nAdmin leave", room=room)
del self.rooms[room]
leave_room(room=room)
def on_disconnect(self): def on_disconnect(self):
logger.debug("On disconnect event trigger") logger.debug("On disconnect event trigger")
self.on_leave(self.clients[request.sid]["room"]) room_id_list = list(self.connections.get(request.sid, {}).keys())
for room_id in room_id_list:
try: try:
for connection in self.clients[request.sid]["client"]: self.on_logout(room_id)
self.on_logout(connection) except Exception as e:
del self.clients[request.sid] logger.warn(e)
except: del self.connections[request.sid]
pass
def on_logout(self, room_id):
def on_logout(self, connection): room = self.connections.get(request.sid, {}).get(room_id)
logger.debug("On logout event trigger") if room:
if connection: room["proxy"].close()
if connection in self.clients[request.sid]["proxy"].keys(): self.close_room(room_id)
self.clients[request.sid]["proxy"][connection].close() del self.connections[request.sid][room_id]
del self.clients[request.sid]['proxy'][connection] del room
def logout(self, connection):
if connection and (request.sid in self.clients.keys()):
if connection in self.clients[request.sid]["proxy"].keys():
del self.clients[request.sid]["proxy"][connection]
if connection in self.clients[request.sid]["forwarder"].keys():
del self.clients[request.sid]["forwarder"][connection]
if connection in self.clients[request.sid]["client"].keys():
del self.clients[request.sid]["client"][connection]
class HttpServer: class HttpServer:
# prepare may be rewrite it # prepare may be rewrite it
config = { config = {
'SECRET_KEY': '', 'SECRET_KEY': 'someWOrkSD20KMS9330)&#',
'coco': None, 'coco': None,
'LOGIN_URL': '/login' 'LOGIN_URL': '/login'
} }
async_mode = "threading" init_kwargs = dict(
# async_mode="gevent",
async_mode="threading",
ping_timeout=20,
ping_interval=10
)
def __init__(self, coco): def __init__(self):
config = coco.config config = {k: v for k, v in current_app.config.items()}
config.update(self.config) config.update(self.config)
config['coco'] = coco
self.flask_app = Flask(__name__, template_folder='dist') self.flask_app = Flask(__name__, template_folder='dist')
self.flask_app.config.update(config) self.flask_app.config.update(config)
self.socket_io = SocketIO() self.socket_io = SocketIO()
...@@ -270,11 +254,22 @@ class HttpServer: ...@@ -270,11 +254,22 @@ class HttpServer:
def register_routes(self): def register_routes(self):
self.socket_io.on_namespace(ProxyNamespace('/ssh')) self.socket_io.on_namespace(ProxyNamespace('/ssh'))
@staticmethod
def on_error_default(e):
traceback.print_exc()
logger.warn(e)
def register_error_handler(self):
self.socket_io.on_error_default(self.on_error_default)
def run(self): def run(self):
host = self.flask_app.config["BIND_HOST"] host = self.flask_app.config["BIND_HOST"]
port = self.flask_app.config["HTTPD_PORT"] port = self.flask_app.config["HTTPD_PORT"]
self.socket_io.init_app(self.flask_app, async_mode=self.async_mode) self.socket_io.init_app(
self.flask_app,
**self.init_kwargs
)
self.socket_io.run(self.flask_app, port=port, host=host, debug=False) self.socket_io.run(self.flask_app, port=port, host=host, debug=False)
def shutdown(self): def shutdown(self):
pass self.socket_io.server.close()
...@@ -4,16 +4,14 @@ ...@@ -4,16 +4,14 @@
import socket import socket
import threading import threading
import weakref
import os import os
from jms.models import Asset, AssetGroup
from . import char from . import char
from .utils import wrap_with_line_feed as wr, wrap_with_title as title, \ from .utils import wrap_with_line_feed as wr, wrap_with_title as title, \
wrap_with_primary as primary, wrap_with_warning as warning, \ wrap_with_warning as warning, is_obj_attr_has, is_obj_attr_eq, \
is_obj_attr_has, is_obj_attr_eq, sort_assets, TtyIOParser, \ sort_assets, ugettext as _, get_logger, net_input, format_with_zh, \
ugettext as _, get_logger item_max_length, size_of_str_with_zh
from .ctx import current_app, app_service
from .proxy import ProxyServer from .proxy import ProxyServer
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -22,19 +20,14 @@ logger = get_logger(__file__) ...@@ -22,19 +20,14 @@ logger = get_logger(__file__)
class InteractiveServer: class InteractiveServer:
_sentinel = object() _sentinel = object()
def __init__(self, app, client): def __init__(self, client):
self._app = weakref.ref(app)
self.client = client self.client = client
self.request = client.request self.request = client.request
self.assets = None self.assets = None
self._search_result = None self._search_result = None
self.asset_groups = None self.nodes = None
self.get_user_assets_async() self.get_user_assets_async()
self.get_user_asset_groups_async() self.get_user_nodes_async()
@property
def app(self):
return self._app()
@property @property
def search_result(self): def search_result(self):
...@@ -50,7 +43,7 @@ class InteractiveServer: ...@@ -50,7 +43,7 @@ class InteractiveServer:
def display_banner(self): def display_banner(self):
self.client.send(char.CLEAR_CHAR) self.client.send(char.CLEAR_CHAR)
logo_path = os.path.join(self.app.root_path, "logo.txt") logo_path = os.path.join(current_app.root_path, "logo.txt")
if os.path.isfile(logo_path): if os.path.isfile(logo_path):
with open(logo_path, 'rb') as f: with open(logo_path, 'rb') as f:
for i in f: for i in f:
...@@ -61,73 +54,16 @@ class InteractiveServer: ...@@ -61,73 +54,16 @@ class InteractiveServer:
banner = _("""\n {title} {user}, 欢迎使用Jumpserver开源跳板机系统 {end}\r\n\r banner = _("""\n {title} {user}, 欢迎使用Jumpserver开源跳板机系统 {end}\r\n\r
1) 输入 {green}ID{end} 直接登录 或 输入{green}部分 IP,主机名,备注{end} 进行搜索登录(如果唯一).\r 1) 输入 {green}ID{end} 直接登录 或 输入{green}部分 IP,主机名,备注{end} 进行搜索登录(如果唯一).\r
2) 输入 {green}/{end} + {green}IP, 主机名{end} or {green}备注 {end}搜索. 如: /ip\r 2) 输入 {green}/{end} + {green}IP, 主机名{end} or {green}备注 {end}搜索. 如: /ip\r
3) 输入 {green}P/p{end} 显示您有权限的主机.\r 3) 输入 {green}p{end} 显示您有权限的主机.\r
4) 输入 {green}G/g{end} 显示您有权限的主机组.\r 4) 输入 {green}g{end} 显示您有权限的节点\r
5) 输入 {green}G/g{end} + {green}组ID{end} 显示该组下主机. 如: g1\r 5) 输入 {green}g{end} + {green}组ID{end} 显示节点下主机. 如: g1\r
6) 输入 {green}H/h{end} 帮助.\r 6) 输入 {green}h{end} 帮助.\r
0) 输入 {green}Q/q{end} 退出.\r\n""").format( 0) 输入 {green}q{end} 退出.\r\n""").format(
title="\033[1;32m", green="\033[32m", title="\033[1;32m", green="\033[32m",
end="\033[0m", user=self.client.user end="\033[0m", user=self.client.user
) )
self.client.send(banner) self.client.send(banner)
def get_option(self, prompt='Opt> '):
"""实现了一个ssh input, 提示用户输入, 获取并返回
:return user input string
"""
# Todo: 实现自动hostname或IP补全
input_data = []
parser = TtyIOParser()
self.client.send(wr(prompt, before=1, after=0))
while True:
data = self.client.recv(10)
if len(data) == 0:
self.app.remove_client(self.client)
break
# Client input backspace
if data in char.BACKSPACE_CHAR:
# If input words less than 0, should send 'BELL'
if len(input_data) > 0:
data = char.BACKSPACE_CHAR[data]
input_data.pop()
else:
data = char.BELL_CHAR
self.client.send(data)
continue
if data.startswith(b'\x03'):
# Ctrl-C
self.client.send(b'^C\r\nOpt> ')
input_data = []
continue
elif data.startswith(b'\x04'):
# Ctrl-D
return 'q'
# Todo: Move x1b to char
if data.startswith(b'\x1b') or data in char.UNSUPPORTED_CHAR:
self.client.send(b'')
continue
# handle shell expect
multi_char_with_enter = False
if len(data) > 1 and data[-1] in char.ENTER_CHAR_ORDER:
self.client.send(data)
input_data.append(data[:-1])
multi_char_with_enter = True
# If user type ENTER we should get user input
if data in char.ENTER_CHAR or multi_char_with_enter:
self.client.send(wr(b'', after=2))
option = parser.parse_input(input_data)
del input_data[:]
return option.strip()
else:
self.client.send(data)
input_data.append(data)
def dispatch(self, opt): def dispatch(self, opt):
if opt is None: if opt is None:
return self._sentinel return self._sentinel
...@@ -136,9 +72,9 @@ class InteractiveServer: ...@@ -136,9 +72,9 @@ class InteractiveServer:
elif opt in ['p', 'P', '']: elif opt in ['p', 'P', '']:
self.display_assets() self.display_assets()
elif opt in ['g', 'G']: elif opt in ['g', 'G']:
self.display_asset_groups() self.display_nodes()
elif opt.startswith("g") and opt.lstrip("g").isdigit(): elif opt.startswith("g") and opt.lstrip("g").isdigit():
self.display_group_assets(int(opt.lstrip("g"))) self.display_node_assets(int(opt.lstrip("g")))
elif opt in ['q', 'Q', 'exit', 'quit']: elif opt in ['q', 'Q', 'exit', 'quit']:
return self._sentinel return self._sentinel
elif opt in ['h', 'H']: elif opt in ['h', 'H']:
...@@ -155,18 +91,21 @@ class InteractiveServer: ...@@ -155,18 +91,21 @@ class InteractiveServer:
if q == '': if q == '':
result = self.assets result = self.assets
# 用户输入的是数字,可能想使用id唯一键搜索 # 用户输入的是数字,可能想使用id唯一键搜索
elif q.isdigit() and self.search_result and len(self.search_result) >= int(q): elif q.isdigit() and self.search_result and \
len(self.search_result) >= int(q):
result = [self.search_result[int(q) - 1]] result = [self.search_result[int(q) - 1]]
# 全匹配到则直接返回全匹配的 # 全匹配到则直接返回全匹配的
if len(result) == 0: if len(result) == 0:
_result = [asset for asset in self.assets if is_obj_attr_eq(asset, q)] _result = [asset for asset in self.assets
if is_obj_attr_eq(asset, q)]
if len(_result) == 1: if len(_result) == 1:
result = _result result = _result
# 最后模糊匹配 # 最后模糊匹配
if len(result) == 0: if len(result) == 0:
result = [asset for asset in self.assets if is_obj_attr_has(asset, q)] result = [asset for asset in self.assets
if is_obj_attr_has(asset, q)]
self.search_result = result self.search_result = result
...@@ -177,52 +116,69 @@ class InteractiveServer: ...@@ -177,52 +116,69 @@ class InteractiveServer:
""" """
self.search_and_display('') self.search_and_display('')
def display_asset_groups(self): def display_nodes(self):
if self.asset_groups is None: if self.nodes is None:
self.get_user_asset_groups() self.get_user_nodes()
if len(self.asset_groups) == 0: if len(self.nodes) == 0:
self.client.send(warning(_("无"))) self.client.send(warning(_("无")))
return return
fake_group = AssetGroup(name=_("Name"), assets_amount=_("Assets"), comment=_("Comment")) id_length = max(len(str(len(self.nodes))), 5)
id_max_length = max(len(str(len(self.asset_groups))), 5) name_length = item_max_length(self.nodes, 15, key=lambda x: x.name)
name_max_length = max(max([len(group.name) for group in self.asset_groups]), 15) amount_length = item_max_length(self.nodes, 10,
amount_max_length = max(len(str(max([group.assets_amount for group in self.asset_groups]))), 10) key=lambda x: x.assets_amount)
header = '{1:>%d} {0.name:%d} {0.assets_amount:<%s} ' % (id_max_length, name_max_length, amount_max_length) size_list = [id_length, name_length, amount_length]
comment_length = max(self.request.meta["width"] - len(header.format(fake_group, id_max_length)), 2) fake_data = ['ID', _("Name"), _("Assets")]
line = header + '{0.comment:%s}' % (comment_length // 2) # comment中可能有中文 header_without_comment = format_with_zh(size_list, *fake_data)
header += "{0.comment:%s}" % comment_length comment_length = max(
self.client.send(title(header.format(fake_group, "ID"))) self.request.meta["width"] -
for index, group in enumerate(self.asset_groups, 1): size_of_str_with_zh(header_without_comment) - 1,
self.client.send(wr(line.format(group, index))) 2
self.client.send(wr(_("总共: {}").format(len(self.asset_groups)), before=1)) )
size_list.append(comment_length)
def display_group_assets(self, _id): fake_data.append(_("Comment"))
if _id > len(self.asset_groups) or _id <= 0:
self.client.send(title(format_with_zh(size_list, *fake_data)))
for index, group in enumerate(self.nodes, 1):
data = [index, group.name, group.assets_amount, group.comment]
self.client.send(wr(format_with_zh(size_list, *data)))
self.client.send(wr(_("总共: {}").format(len(self.nodes)), before=1))
def display_node_assets(self, _id):
if _id > len(self.nodes) or _id <= 0:
self.client.send(wr(warning("没有匹配分组,请重新输入"))) self.client.send(wr(warning("没有匹配分组,请重新输入")))
self.display_asset_groups() self.display_nodes()
return return
self.search_result = self.asset_groups[_id - 1].assets_granted self.search_result = self.nodes[_id - 1].assets_granted
self.display_search_result() self.display_search_result()
def display_search_result(self): def display_search_result(self):
self.search_result = sort_assets(self.search_result, self.app.config["ASSET_LIST_SORT_BY"]) sort_by = current_app.config["ASSET_LIST_SORT_BY"]
fake_asset = Asset(hostname=_("Hostname"), ip=_("IP"), _system_users_name_list=_("LoginAs"), self.search_result = sort_assets(self.search_result, sort_by)
comment=_("Comment")) fake_data = [_("ID"), _("Hostname"), _("IP"), _("LoginAs")]
id_max_length = max(len(str(len(self.search_result))), 3) id_length = max(len(str(len(self.search_result))), 4)
hostname_max_length = max(max([len(asset.hostname) for asset in self.search_result + [fake_asset]]), 15) hostname_length = item_max_length(self.search_result, 15,
sysuser_max_length = max([len(asset.system_users_name_list) for asset in self.search_result + [fake_asset]]) key=lambda x: x.hostname)
header = '{1:>%d} {0.hostname:%d} {0.ip:15} {0.system_users_name_list:%d} ' % \ sysuser_length = item_max_length(self.search_result,
(id_max_length, hostname_max_length, sysuser_max_length) key=lambda x: x.system_users_name_list)
comment_length = self.request.meta["width"] - len(header.format(fake_asset, id_max_length)) size_list = [id_length, hostname_length, 16, sysuser_length]
comment_length = max([comment_length, 2]) header_without_comment = format_with_zh(size_list, *fake_data)
line = header + '{0.comment:.%d}' % (comment_length // 2) # comment中可能有中文 comment_length = max(
header += '{0.comment:%s}' % comment_length self.request.meta["width"] -
self.client.send(wr(title(header.format(fake_asset, "ID")))) size_of_str_with_zh(header_without_comment) - 1,
2
)
size_list.append(comment_length)
fake_data.append(_("Comment"))
self.client.send(wr(title(format_with_zh(size_list, *fake_data))))
for index, asset in enumerate(self.search_result, 1): for index, asset in enumerate(self.search_result, 1):
self.client.send(wr(line.format(asset, index))) data = [
index, asset.hostname, asset.ip,
asset.system_users_name_list, asset.comment
]
self.client.send(wr(format_with_zh(size_list, *data)))
self.client.send(wr(_("总共: {} 匹配: {}").format( self.client.send(wr(_("总共: {} 匹配: {}").format(
len(self.assets), len(self.search_result)), before=1) len(self.assets), len(self.search_result)), before=1)
) )
...@@ -231,43 +187,44 @@ class InteractiveServer: ...@@ -231,43 +187,44 @@ class InteractiveServer:
self.search_assets(q) self.search_assets(q)
self.display_search_result() self.display_search_result()
def get_user_asset_groups(self): def get_user_nodes(self):
self.asset_groups = self.app.service.get_user_asset_groups(self.client.user) self.nodes = app_service.get_user_asset_groups(self.client.user)
def get_user_asset_groups_async(self): def get_user_nodes_async(self):
thread = threading.Thread(target=self.get_user_asset_groups) thread = threading.Thread(target=self.get_user_nodes)
thread.start() thread.start()
@staticmethod @staticmethod
def filter_system_users(assets): def filter_system_users(assets):
for asset in assets: for asset in assets:
system_users_granted = asset.system_users_granted system_users_granted = asset.system_users_granted
high_priority = max([s.priority for s in system_users_granted]) if system_users_granted else 1 high_priority = max([s.priority for s in system_users_granted]) \
system_users_cleaned = [s for s in system_users_granted if s.priority == high_priority] if system_users_granted else 1
system_users_cleaned = [s for s in system_users_granted
if s.priority == high_priority]
asset.system_users_granted = system_users_cleaned asset.system_users_granted = system_users_cleaned
return assets return assets
def get_user_assets(self): def get_user_assets(self):
self.assets = self.app.service.get_user_assets(self.client.user) self.assets = app_service.get_user_assets(self.client.user)
logger.debug("Get user {} assets total: {}".format(self.client.user, len(self.assets))) logger.debug("Get user {} assets total: {}".format(
self.client.user, len(self.assets))
)
def get_user_assets_async(self): def get_user_assets_async(self):
thread = threading.Thread(target=self.get_user_assets) thread = threading.Thread(target=self.get_user_assets)
thread.start() thread.start()
def choose_system_user(self, system_users): def choose_system_user(self, system_users):
# highest_priority = max([s.priority for s in system_users])
# system_users = [s for s in system_users if s == highest_priority]
if len(system_users) == 1: if len(system_users) == 1:
return system_users[0] return system_users[0]
elif len(system_users) == 0: elif len(system_users) == 0:
return None return None
while True: while True:
self.client.send(wr(_("选择一个登: "), after=1)) self.client.send(wr(_("选择一个登: "), after=1))
self.display_system_users(system_users) self.display_system_users(system_users)
opt = self.get_option("ID> ") opt = net_input(self.client, prompt="ID> ")
if opt.isdigit() and len(system_users) > int(opt): if opt.isdigit() and len(system_users) > int(opt):
return system_users[int(opt)] return system_users[int(opt)]
elif opt in ['q', 'Q']: elif opt in ['q', 'Q']:
...@@ -286,7 +243,9 @@ class InteractiveServer: ...@@ -286,7 +243,9 @@ class InteractiveServer:
if self.search_result and len(self.search_result) == 1: if self.search_result and len(self.search_result) == 1:
asset = self.search_result[0] asset = self.search_result[0]
if asset.platform == "Windows": if asset.platform == "Windows":
self.client.send(warning(_("终端不支持登录windows, 请使用web terminal访问"))) self.client.send(warning(
_("终端不支持登录windows, 请使用web terminal访问"))
)
return return
self.proxy(asset) self.proxy(asset)
else: else:
...@@ -297,14 +256,14 @@ class InteractiveServer: ...@@ -297,14 +256,14 @@ class InteractiveServer:
if system_user is None: if system_user is None:
self.client.send(_("没有系统用户")) self.client.send(_("没有系统用户"))
return return
forwarder = ProxyServer(self.app, self.client) forwarder = ProxyServer(self.client)
forwarder.proxy(asset, system_user) forwarder.proxy(asset, system_user)
def interact(self): def interact(self):
self.display_banner() self.display_banner()
while True: while True:
try: try:
opt = self.get_option() opt = net_input(self.client, prompt='Opt>', before=1)
rv = self.dispatch(opt) rv = self.dispatch(opt)
if rv is self._sentinel: if rv is self._sentinel:
break break
...@@ -318,7 +277,7 @@ class InteractiveServer: ...@@ -318,7 +277,7 @@ class InteractiveServer:
thread.start() thread.start()
def close(self): def close(self):
self.app.remove_client(self.client) current_app.remove_client(self.client)
# def __del__(self): # def __del__(self):
# print("GC: Interactive class been gc") # print("GC: Interactive class been gc")
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
import paramiko import paramiko
import threading import threading
import weakref
from .utils import get_logger from .utils import get_logger
from .ctx import current_app, app_service
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -19,22 +19,13 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -19,22 +19,13 @@ class SSHInterface(paramiko.ServerInterface):
https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py
""" """
def __init__(self, app, request): def __init__(self, request):
self._app = weakref.ref(app) self.request = request
self._request = weakref.ref(request)
self.event = threading.Event() self.event = threading.Event()
self.auth_valid = False self.auth_valid = False
self.otp_auth = False self.otp_auth = False
self.info = None self.info = None
@property
def app(self):
return self._app()
@property
def request(self):
return self._request()
def check_auth_interactive(self, username, submethods): def check_auth_interactive(self, username, submethods):
logger.info("Check auth interactive: %s %s" % (username, submethods)) logger.info("Check auth interactive: %s %s" % (username, submethods))
instructions = 'Please enter 6 digits.' instructions = 'Please enter 6 digits.'
...@@ -55,7 +46,7 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -55,7 +46,7 @@ class SSHInterface(paramiko.ServerInterface):
if not seed: if not seed:
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
is_valid = self.app.service.authenticate_otp(seed, otp_code) is_valid = app_service.authenticate_otp(seed, otp_code)
if is_valid: if is_valid:
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
...@@ -67,9 +58,9 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -67,9 +58,9 @@ class SSHInterface(paramiko.ServerInterface):
supported = [] supported = []
if self.otp_auth: if self.otp_auth:
return 'keyboard-interactive' return 'keyboard-interactive'
if self.app.config["PASSWORD_AUTH"]: if current_app.config["PASSWORD_AUTH"]:
supported.append("password") supported.append("password")
if self.app.config["PUBLIC_KEY_AUTH"]: if current_app.config["PUBLIC_KEY_AUTH"]:
supported.append("publickey") supported.append("publickey")
return ",".join(supported) return ",".join(supported)
...@@ -100,7 +91,7 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -100,7 +91,7 @@ class SSHInterface(paramiko.ServerInterface):
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
def validate_auth(self, username, password="", public_key=""): def validate_auth(self, username, password="", public_key=""):
info = self.app.service.authenticate( info = app_service.authenticate(
username, password=password, public_key=public_key, username, password=password, public_key=public_key,
remote_addr=self.request.remote_ip remote_addr=self.request.remote_ip
) )
......
...@@ -218,7 +218,7 @@ class WSProxy: ...@@ -218,7 +218,7 @@ class WSProxy:
``` ```
""" """
def __init__(self, ws, child, room, connection): def __init__(self, ws, child, room_id):
""" """
:param ws: websocket instance or handler, have write_message method :param ws: websocket instance or handler, have write_message method
:param child: sock child pair :param child: sock child pair
...@@ -226,9 +226,8 @@ class WSProxy: ...@@ -226,9 +226,8 @@ class WSProxy:
self.ws = ws self.ws = ws
self.child = child self.child = child
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.room = room self.room_id = room_id
self.auto_forward() self.auto_forward()
self.connection = connection
def send(self, msg): def send(self, msg):
""" """
...@@ -252,7 +251,9 @@ class WSProxy: ...@@ -252,7 +251,9 @@ class WSProxy:
if len(data) == 0: if len(data) == 0:
self.close() self.close()
data = data.decode(errors="ignore") data = data.decode(errors="ignore")
self.ws.emit("data", {'data': data, 'room': self.connection}, room=self.room) print("Send data: {}".format(data))
self.ws.emit("data", {'data': data, 'room': self.room_id},
room=self.room_id)
if len(data) == BUF_SIZE: if len(data) == BUF_SIZE:
time.sleep(0.1) time.sleep(0.1)
...@@ -265,7 +266,6 @@ class WSProxy: ...@@ -265,7 +266,6 @@ class WSProxy:
self.stop_event.set() self.stop_event.set()
self.child.shutdown(1) self.child.shutdown(1)
self.child.close() self.child.close()
self.ws.logout(self.connection)
logger.debug("Proxy {} closed".format(self)) logger.debug("Proxy {} closed".format(self))
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
import threading import threading
import time import time
import weakref
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
from .session import Session from .session import Session
from .models import Server from .models import Server
from .connection import SSHConnection from .connection import SSHConnection
from .ctx import current_app, app_service
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \ from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
get_logger, net_input get_logger, net_input
...@@ -21,24 +21,19 @@ BUF_SIZE = 4096 ...@@ -21,24 +21,19 @@ BUF_SIZE = 4096
class ProxyServer: class ProxyServer:
def __init__(self, app, client): def __init__(self, client):
self._app = weakref.ref(app)
self.client = client self.client = client
self.server = None self.server = None
self.connecting = True self.connecting = True
self.stop_event = threading.Event() self.stop_event = threading.Event()
@property
def app(self):
return self._app()
def get_system_user_auth(self, system_user): def get_system_user_auth(self, system_user):
""" """
获取系统用户的认证信息,密码或秘钥 获取系统用户的认证信息,密码或秘钥
:return: system user have full info :return: system user have full info
""" """
password, private_key = \ password, private_key = \
self.app.service.get_system_user_auth_info(system_user) app_service.get_system_user_auth_info(system_user)
if not password and not private_key: if not password and not private_key:
prompt = "{}'s password: ".format(system_user.username) prompt = "{}'s password: ".format(system_user.username)
password = net_input(self.client, prompt=prompt, sensitive=True) password = net_input(self.client, prompt=prompt, sensitive=True)
...@@ -51,26 +46,26 @@ class ProxyServer: ...@@ -51,26 +46,26 @@ class ProxyServer:
self.server = self.get_server_conn(asset, system_user) self.server = self.get_server_conn(asset, system_user)
if self.server is None: if self.server is None:
return return
command_recorder = self.app.new_command_recorder() command_recorder = current_app.new_command_recorder()
replay_recorder = self.app.new_replay_recorder() replay_recorder = current_app.new_replay_recorder()
session = Session( session = Session(
self.client, self.server, self.client, self.server,
command_recorder=command_recorder, command_recorder=command_recorder,
replay_recorder=replay_recorder, replay_recorder=replay_recorder,
) )
self.app.add_session(session) current_app.add_session(session)
self.watch_win_size_change_async() self.watch_win_size_change_async()
session.bridge() session.bridge()
self.stop_event.set() self.stop_event.set()
self.end_watch_win_size_change() self.end_watch_win_size_change()
self.app.remove_session(session) current_app.remove_session(session)
def validate_permission(self, asset, system_user): def validate_permission(self, asset, system_user):
""" """
验证用户是否有连接改资产的权限 验证用户是否有连接改资产的权限
:return: True or False :return: True or False
""" """
return self.app.service.validate_user_asset_permission( return app_service.validate_user_asset_permission(
self.client.user.id, asset.id, system_user.id self.client.user.id, asset.id, system_user.id
) )
...@@ -90,13 +85,14 @@ class ProxyServer: ...@@ -90,13 +85,14 @@ class ProxyServer:
pass pass
def get_ssh_server_conn(self, asset, system_user): def get_ssh_server_conn(self, asset, system_user):
ssh = SSHConnection(self.app)
request = self.client.request request = self.client.request
term = request.meta.get('term', 'xterm') term = request.meta.get('term', 'xterm')
width = request.meta.get('width', 80) width = request.meta.get('width', 80)
height = request.meta.get('height', 24) height = request.meta.get('height', 24)
chan, msg = ssh.get_channel(asset, system_user, term=term, ssh = SSHConnection()
width=width, height=height) chan, msg = ssh.get_channel(
asset, system_user, term=term, width=width, height=height
)
if not chan: if not chan:
self.client.send(warning(wr(msg, before=1, after=0))) self.client.send(warning(wr(msg, before=1, after=0)))
server = None server = None
...@@ -130,9 +126,11 @@ class ProxyServer: ...@@ -130,9 +126,11 @@ class ProxyServer:
def send_connecting_message(self, asset, system_user): def send_connecting_message(self, asset, system_user):
def func(): def func():
delay = 0.0 delay = 0.0
self.client.send('Connecting to {}@{} {:.1f}'.format(system_user, asset, delay)) self.client.send('Connecting to {}@{} {:.1f}'.format(
system_user, asset, delay)
)
while self.connecting and delay < TIMEOUT: while self.connecting and delay < TIMEOUT:
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode('utf-8')) self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode())
time.sleep(0.1) time.sleep(0.1)
delay += 0.1 delay += 0.1
thread = threading.Thread(target=func) thread = threading.Thread(target=func)
......
...@@ -8,33 +8,19 @@ import time ...@@ -8,33 +8,19 @@ import time
import os import os
import gzip import gzip
import json import json
import shutil
import jms_storage import jms_storage
from .utils import get_logger from .utils import get_logger, Singleton
from .alignment import MemoryQueue from .alignment import MemoryQueue
from .ctx import current_app, app_service
logger = get_logger(__file__) logger = get_logger(__file__)
BUF_SIZE = 1024 BUF_SIZE = 1024
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
class ReplayRecorder(metaclass=abc.ABCMeta): class ReplayRecorder(metaclass=abc.ABCMeta):
def __init__(self, app, session=None): def __init__(self, session=None):
self.app = app
self.session = session self.session = session
@abc.abstractmethod @abc.abstractmethod
...@@ -61,8 +47,7 @@ class ReplayRecorder(metaclass=abc.ABCMeta): ...@@ -61,8 +47,7 @@ class ReplayRecorder(metaclass=abc.ABCMeta):
class CommandRecorder: class CommandRecorder:
def __init__(self, app, session=None): def __init__(self, session=None):
self.app = app
self.session = session self.session = session
def record(self, data): def record(self, data):
...@@ -92,8 +77,8 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -92,8 +77,8 @@ class ServerReplayRecorder(ReplayRecorder):
time_start = None time_start = None
storage = None storage = None
def __init__(self, app): def __init__(self):
super().__init__(app) super().__init__()
self.file = None self.file = None
self.file_path = None self.file_path = None
...@@ -114,8 +99,8 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -114,8 +99,8 @@ class ServerReplayRecorder(ReplayRecorder):
def session_start(self, session_id): def session_start(self, session_id):
self.time_start = time.time() self.time_start = time.time()
filename = session_id+'.replay.gz' filename = session_id + '.replay.gz'
self.file_path = os.path.join(self.app.config['LOG_DIR'], filename) self.file_path = os.path.join(current_app.config['LOG_DIR'], filename)
self.file = gzip.open(self.file_path, 'at') self.file = gzip.open(self.file_path, 'at')
self.file.write('{') self.file.write('{')
...@@ -128,11 +113,11 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -128,11 +113,11 @@ class ServerReplayRecorder(ReplayRecorder):
logger.error("Failed to push {}'s {}".format(session_id, "record")) logger.error("Failed to push {}'s {}".format(session_id, "record"))
def upload_replay(self, session_id): def upload_replay(self, session_id):
configs = self.app.service.load_config_from_server() configs = app_service.load_config_from_server()
logger.debug("upload_replay print config: {}".format(configs)) logger.debug("upload_replay print config: {}".format(configs))
self.storage = jms_storage.init(configs["REPLAY_STORAGE"]) self.storage = jms_storage.init(configs["REPLAY_STORAGE"])
if not self.storage: if not self.storage:
self.storage = jms_storage.jms(self.app.service) self.storage = jms_storage.jms(app_service)
if self.push_file(3, session_id): if self.push_file(3, session_id):
os.unlink(self.file_path) os.unlink(self.file_path)
return True return True
...@@ -151,7 +136,7 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -151,7 +136,7 @@ class ServerReplayRecorder(ReplayRecorder):
else: else:
msg = "Failed push session {}'s replay log to storage".format(session_id) msg = "Failed push session {}'s replay log to storage".format(session_id)
logger.error(msg) logger.error(msg)
self.storage = jms_storage.jms(self.app.service) self.storage = jms_storage.jms(app_service)
return self.push_file(3, session_id) return self.push_file(3, session_id)
if self.push_to_storage(session_id): if self.push_to_storage(session_id):
...@@ -167,7 +152,7 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -167,7 +152,7 @@ class ServerReplayRecorder(ReplayRecorder):
logger.error("Failed finished session {}'s replay".format(session_id)) logger.error("Failed finished session {}'s replay".format(session_id))
return False return False
if self.app.service.finish_replay(session_id): if app_service.finish_replay(session_id):
logger.info("Success finish session {}'s replay ".format(session_id)) logger.info("Success finish session {}'s replay ".format(session_id))
return True return True
else: else:
...@@ -180,8 +165,8 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -180,8 +165,8 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
timeout = 5 timeout = 5
no = 0 no = 0
def __init__(self, app): def __init__(self):
super().__init__(app) super().__init__()
self.queue = MemoryQueue() self.queue = MemoryQueue()
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.push_to_server_async() self.push_to_server_async()
...@@ -204,7 +189,7 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -204,7 +189,7 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
if not data_set: if not data_set:
continue continue
logger.debug("Send {} commands to server".format(len(data_set))) logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.app.service.push_session_command(data_set) ok = app_service.push_session_command(data_set)
if not ok: if not ok:
self.queue.mput(data_set) self.queue.mput(data_set)
...@@ -228,13 +213,15 @@ class ESCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -228,13 +213,15 @@ class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
no = 0 no = 0
default_hosts = ["http://localhost"] default_hosts = ["http://localhost"]
def __init__(self, app): def __init__(self):
super().__init__(app) super().__init__()
self.queue = MemoryQueue() self.queue = MemoryQueue()
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.push_to_es_async() self.push_to_es_async()
self.__class__.no += 1 self.__class__.no += 1
self.store = jms_storage.ESStore(app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts)) self.store = jms_storage.ESStore(
current_app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts)
)
if not self.store.ping(): if not self.store.ping():
raise AssertionError("ESCommand storage init error") raise AssertionError("ESCommand storage init error")
......
...@@ -40,7 +40,7 @@ class Session: ...@@ -40,7 +40,7 @@ class Session:
""" """
logger.info("Session add watcher: {} -> {} ".format(self.id, watcher)) logger.info("Session add watcher: {} -> {} ".format(self.id, watcher))
if not silent: if not silent:
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode("utf-8")) watcher.send("Welcome to watch session {}\r\n".format(self.id).encode())
self.sel.register(watcher, selectors.EVENT_READ) self.sel.register(watcher, selectors.EVENT_READ)
self._watchers.append(watcher) self._watchers.append(watcher)
......
...@@ -34,7 +34,7 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -34,7 +34,7 @@ class SFTPServer(paramiko.SFTPServerInterface):
self._sftp[host] = sftp self._sftp[host] = sftp
return sftp return sftp
else: else:
raise OSError("Can not connect asset sftp server") raise OSError("Can not connect asset sftp server: {}".format(msg))
else: else:
return self._sftp[host] return self._sftp[host]
......
...@@ -12,6 +12,7 @@ from .interface import SSHInterface ...@@ -12,6 +12,7 @@ from .interface import SSHInterface
from .interactive import InteractiveServer from .interactive import InteractiveServer
from .models import Client, Request from .models import Client, Request
from .sftp import SFTPServer from .sftp import SFTPServer
from .ctx import current_app
logger = get_logger(__file__) logger = get_logger(__file__)
BACKLOG = 5 BACKLOG = 5
...@@ -19,38 +20,41 @@ BACKLOG = 5 ...@@ -19,38 +20,41 @@ BACKLOG = 5
class SSHServer: class SSHServer:
def __init__(self, app): def __init__(self):
self.app = app
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.workers = []
self.host_key_path = os.path.join(self.app.root_path, 'keys', 'host_rsa_key') self.pipe = None
@property @property
def host_key(self): def host_key(self):
if not os.path.isfile(self.host_key_path): host_key_path = os.path.join(current_app.root_path, 'keys', 'host_rsa_key')
self.gen_host_key() if not os.path.isfile(host_key_path):
return paramiko.RSAKey(filename=self.host_key_path) self.gen_host_key(host_key_path)
return paramiko.RSAKey(filename=host_key_path)
def gen_host_key(self): @staticmethod
def gen_host_key(key_path):
ssh_key, _ = ssh_key_gen() ssh_key, _ = ssh_key_gen()
with open(self.host_key_path, 'w') as f: with open(key_path, 'w') as f:
f.write(ssh_key) f.write(ssh_key)
def run(self): def run(self):
host = self.app.config["BIND_HOST"] host = current_app.config["BIND_HOST"]
port = self.app.config["SSHD_PORT"] port = current_app.config["SSHD_PORT"]
print('Starting ssh server at {}:{}'.format(host, port)) print('Starting ssh server at {}:{}'.format(host, port))
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.bind((host, port)) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.listen(BACKLOG) sock.bind((host, port))
sock.listen(BACKLOG)
while not self.stop_evt.is_set(): while not self.stop_evt.is_set():
try: try:
sock, addr = self.sock.accept() client, addr = sock.accept()
logger.info("Get ssh request from {}: {}".format(addr[0], addr[1])) logger.info("Get ssh request from {}: {}".format(*addr))
thread = threading.Thread(target=self.handle_connection, args=(sock, addr)) thread = threading.Thread(target=self.handle_connection,
args=(client, addr))
thread.daemon = True thread.daemon = True
thread.start() thread.start()
except Exception as e: except IndexError as e:
logger.error("Start SSH server error: {}".format(e)) logger.error("Start SSH server error: {}".format(e))
def handle_connection(self, sock, addr): def handle_connection(self, sock, addr):
...@@ -65,7 +69,7 @@ class SSHServer: ...@@ -65,7 +69,7 @@ class SSHServer:
'sftp', paramiko.SFTPServer, SFTPServer 'sftp', paramiko.SFTPServer, SFTPServer
) )
request = Request(addr) request = Request(addr)
server = SSHInterface(self.app, request) server = SSHInterface(request)
try: try:
transport.start_server(server=server) transport.start_server(server=server)
except paramiko.SSHException: except paramiko.SSHException:
...@@ -96,7 +100,7 @@ class SSHServer: ...@@ -96,7 +100,7 @@ class SSHServer:
def handle_chan(self, chan, request): def handle_chan(self, chan, request):
client = Client(chan, request) client = Client(chan, request)
self.app.add_client(client) current_app.add_client(client)
self.dispatch(client) self.dispatch(client)
def dispatch(self, client): def dispatch(self, client):
...@@ -104,7 +108,7 @@ class SSHServer: ...@@ -104,7 +108,7 @@ class SSHServer:
request_type = set(client.request.type) request_type = set(client.request.type)
if supported & request_type: if supported & request_type:
logger.info("Request type `pty`, dispatch to interactive mode") logger.info("Request type `pty`, dispatch to interactive mode")
InteractiveServer(self.app, client).interact() InteractiveServer(client).interact()
elif 'subsystem' in request_type: elif 'subsystem' in request_type:
pass pass
else: else:
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import weakref
from .ctx import current_app, app_service
from .utils import get_logger from .utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
class TaskHandler: class TaskHandler:
routes = None
def __init__(self, app): def init(self):
self._app = weakref.ref(app) self.routes = {
'kill_session': self.handle_kill_session
}
@property @staticmethod
def app(self): def handle_kill_session(task):
return self._app()
def handle_kill_session(self, task):
logger.info("Handle kill session task: {}".format(task.args)) logger.info("Handle kill session task: {}".format(task.args))
session_id = task.args session_id = task.args
session = None session = None
for s in self.app.sessions: for s in current_app.sessions:
if s.id == session_id: if s.id == session_id:
session = s session = s
break break
if session: if session:
session.terminate() session.terminate()
self.app.service.finish_task(task.id) app_service.finish_task(task.id)
def handle(self, task): def handle(self, task):
if task.name == "kill_session": func = self.routes.get(task.name)
self.handle_kill_session(task) return func(task)
else:
logger.error("No handler for this task: {}".format(task.name))
...@@ -15,10 +15,24 @@ import paramiko ...@@ -15,10 +15,24 @@ import paramiko
import pyte import pyte
from . import char from . import char
from .ctx import stack
BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
def ssh_key_string_to_obj(text, password=None): def ssh_key_string_to_obj(text, password=None):
key = None key = None
try: try:
...@@ -280,27 +294,14 @@ def get_logger(file_name): ...@@ -280,27 +294,14 @@ def get_logger(file_name):
return logging.getLogger('coco.'+file_name) return logging.getLogger('coco.'+file_name)
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
def len_display(s):
length = 0
for i in s:
if zh_pattern.match(i):
length += 2
else:
length += 1
return length
def net_input(client, prompt='Opt> ', sensitive=False):
"""实现了一个ssh input, 提示用户输入, 获取并返回 """实现了一个ssh input, 提示用户输入, 获取并返回
:return user input string :return user input string
""" """
input_data = [] input_data = []
parser = TtyIOParser() parser = TtyIOParser()
client.send(wrap_with_line_feed(prompt, before=0, after=0)) client.send(wrap_with_line_feed(prompt, before=before, after=after))
while True: while True:
data = client.recv(10) data = client.recv(10)
...@@ -355,4 +356,67 @@ def net_input(client, prompt='Opt> ', sensitive=False): ...@@ -355,4 +356,67 @@ def net_input(client, prompt='Opt> ', sensitive=False):
input_data.append(data) input_data.append(data)
def register_app(app):
stack['app'] = app
def register_service(service):
stack['service'] = service
zh_pattern = re.compile(r'[\u4e00-\u9fa5]')
def find_chinese(s):
return zh_pattern.findall(s)
def align_with_zh(s, length, addin=' '):
if not isinstance(s, str):
s = str(s)
zh_len = len(find_chinese(s))
padding = length - (len(s) - zh_len) - zh_len*2
padding_content = ''
if padding > 0:
padding_content = addin*padding
return s + padding_content
def format_with_zh(size_list, *args):
data = []
for length, s in zip(size_list, args):
data.append(align_with_zh(s, length))
return ' '.join(data)
def size_of_str_with_zh(s):
if isinstance(s, int):
s = str(s)
try:
chinese = find_chinese(s)
except TypeError:
print(type(s))
raise
return len(s) + len(chinese)
def item_max_length(_iter, maxi=None, mini=None, key=None):
if key:
_iter = [key(i) for i in _iter]
length = [size_of_str_with_zh(s) for s in _iter]
if maxi:
length.append(maxi)
length = max(length)
if mini and length < mini:
length = mini
return length
def int_length(i):
return len(str(i))
ugettext = _gettext() ugettext = _gettext()
...@@ -28,7 +28,7 @@ psutil==5.4.1 ...@@ -28,7 +28,7 @@ psutil==5.4.1
pyasn1==0.4.2 pyasn1==0.4.2
pycparser==2.18 pycparser==2.18
PyNaCl==1.2.1 PyNaCl==1.2.1
pyte==0.7.0 pyte==0.8.0
python-dateutil==2.6.1 python-dateutil==2.6.1
python-engineio==2.1.0 python-engineio==2.1.0
python-gssapi==0.6.4 python-gssapi==0.6.4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment