Unverified Commit 40ac8f51 authored by 老广's avatar 老广 Committed by GitHub

Merge pull request #58 from jumpserver/dev

添加eventlet支持websocket模式,优化storage等
parents 14ccd1b0 64b655aa
...@@ -9,6 +9,8 @@ import threading ...@@ -9,6 +9,8 @@ import threading
import socket import socket
import json import json
import signal import signal
import eventlet
from eventlet.debug import hub_prevent_multiple_readers
from jms.service import AppService from jms.service import AppService
...@@ -17,9 +19,11 @@ from .sshd import SSHServer ...@@ -17,9 +19,11 @@ from .sshd import SSHServer
from .httpd import HttpServer from .httpd import HttpServer
from .logger import create_logger from .logger import create_logger
from .tasks import TaskHandler from .tasks import TaskHandler
from .recorder import get_command_recorder_class, ServerReplayRecorder from .recorder import ReplayRecorder, CommandRecorder
from .utils import get_logger from .utils import get_logger, register_app, register_service
eventlet.monkey_patch()
hub_prevent_multiple_readers(False)
__version__ = '1.3.0' __version__ = '1.3.0'
...@@ -56,7 +60,6 @@ class Coco: ...@@ -56,7 +60,6 @@ class Coco:
def __init__(self, root_path=None): def __init__(self, root_path=None):
self.root_path = root_path if root_path else BASE_DIR self.root_path = root_path if root_path else BASE_DIR
self.config = self.config_class(self.root_path, defaults=self.default_config)
self.sessions = [] self.sessions = []
self.clients = [] self.clients = []
self.lock = threading.Lock() self.lock = threading.Lock()
...@@ -67,6 +70,14 @@ class Coco: ...@@ -67,6 +70,14 @@ class Coco:
self.replay_recorder_class = None self.replay_recorder_class = None
self.command_recorder_class = None self.command_recorder_class = None
self._task_handler = None self._task_handler = None
self.config = None
self.init_config()
register_app(self)
def init_config(self):
self.config = self.config_class(
self.root_path, defaults=self.default_config
)
@property @property
def name(self): def name(self):
...@@ -79,24 +90,25 @@ class Coco: ...@@ -79,24 +90,25 @@ class Coco:
def service(self): def service(self):
if self._service is None: if self._service is None:
self._service = AppService(self) self._service = AppService(self)
register_service(self._service)
return self._service return self._service
@property @property
def sshd(self): def sshd(self):
if self._sshd is None: if self._sshd is None:
self._sshd = SSHServer(self) self._sshd = SSHServer()
return self._sshd return self._sshd
@property @property
def httpd(self): def httpd(self):
if self._httpd is None: if self._httpd is None:
self._httpd = HttpServer(self) self._httpd = HttpServer()
return self._httpd return self._httpd
@property @property
def task_handler(self): def task_handler(self):
if self._task_handler is None: if self._task_handler is None:
self._task_handler = TaskHandler(self) self._task_handler = TaskHandler()
return self._task_handler return self._task_handler
def make_logger(self): def make_logger(self):
...@@ -109,24 +121,21 @@ class Coco: ...@@ -109,24 +121,21 @@ class Coco:
)) ))
self.config.update(configs) self.config.update(configs)
def get_recorder_class(self): @staticmethod
self.replay_recorder_class = ServerReplayRecorder def new_command_recorder():
self.command_recorder_class = get_command_recorder_class(self.config) return CommandRecorder()
def new_command_recorder(self):
recorder = self.command_recorder_class(self)
return recorder
def new_replay_recorder(self): @staticmethod
return self.replay_recorder_class(self) def new_replay_recorder():
return ReplayRecorder()
def bootstrap(self): def bootstrap(self):
self.make_logger() self.make_logger()
self.service.initial() self.service.initial()
self.load_extra_conf_from_server() self.load_extra_conf_from_server()
self.get_recorder_class()
self.keep_heartbeat() self.keep_heartbeat()
self.monitor_sessions() self.monitor_sessions()
self.monitor_sessions_replay()
def heartbeat(self): def heartbeat(self):
_sessions = [s.to_json() for s in self.sessions] _sessions = [s.to_json() for s in self.sessions]
...@@ -155,6 +164,31 @@ class Coco: ...@@ -155,6 +164,31 @@ class Coco:
thread = threading.Thread(target=func) thread = threading.Thread(target=func)
thread.start() thread.start()
def monitor_sessions_replay(self):
interval = 10
recorder = self.new_replay_recorder()
log_dir = os.path.join(self.config['LOG_DIR'])
def func():
while not self.stop_evt.is_set():
active_sessions = [str(session.id) for session in self.sessions]
for filename in os.listdir(log_dir):
session_id = filename.split('.')[0]
full_path = os.path.join(log_dir, filename)
if len(session_id) != 36:
continue
if session_id not in active_sessions:
recorder.file_path = full_path
ok = recorder.upload_replay(session_id, 1)
if not ok and os.path.getsize(full_path) == 0:
os.unlink(full_path)
time.sleep(interval)
thread = threading.Thread(target=func)
thread.start()
def monitor_sessions(self): def monitor_sessions(self):
interval = self.config["HEARTBEAT_INTERVAL"] interval = self.config["HEARTBEAT_INTERVAL"]
...@@ -188,9 +222,11 @@ class Coco: ...@@ -188,9 +222,11 @@ class Coco:
self.run_httpd() self.run_httpd()
signal.signal(signal.SIGTERM, lambda x, y: self.shutdown()) signal.signal(signal.SIGTERM, lambda x, y: self.shutdown())
while self.stop_evt.wait(5): while True:
print("Coco receive term signal, exit") if self.stop_evt.is_set():
break print("Coco receive term signal, exit")
break
time.sleep(3)
except KeyboardInterrupt: except KeyboardInterrupt:
self.stop_evt.set() self.stop_evt.set()
self.shutdown() self.shutdown()
...@@ -218,13 +254,19 @@ class Coco: ...@@ -218,13 +254,19 @@ class Coco:
def add_client(self, client): def add_client(self, client):
with self.lock: with self.lock:
self.clients.append(client) self.clients.append(client)
logger.info("New client {} join, total {} now".format(client, len(self.clients))) logger.info("New client {} join, total {} now".format(
client, len(self.clients)
)
)
def remove_client(self, client): def remove_client(self, client):
with self.lock: with self.lock:
try: try:
self.clients.remove(client) self.clients.remove(client)
logger.info("Client {} leave, total {} now".format(client, len(self.clients))) logger.info("Client {} leave, total {} now".format(
client, len(self.clients)
)
)
client.close() client.close()
except: except:
pass pass
...@@ -241,4 +283,5 @@ class Coco: ...@@ -241,4 +283,5 @@ class Coco:
self.sessions.remove(session) self.sessions.remove(session)
self.service.finish_session(session.to_json()) self.service.finish_session(session.to_json())
except ValueError: except ValueError:
logger.warning("Remove session: {} fail, maybe already removed".format(session)) msg = "Remove session: {} fail, maybe already removed"
\ No newline at end of file logger.warning(msg.format(session))
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import weakref
import os import os
import socket import socket
import paramiko import paramiko
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
from .ctx import app_service
from .utils import get_logger, get_private_key_fingerprint from .utils import get_logger, get_private_key_fingerprint
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -15,21 +15,26 @@ TIMEOUT = 10 ...@@ -15,21 +15,26 @@ TIMEOUT = 10
class SSHConnection: class SSHConnection:
def __init__(self, app): def get_system_user_auth(self, system_user):
self._app = weakref.ref(app) """
获取系统用户的认证信息,密码或秘钥
@property :return: system user have full info
def app(self): """
return self._app() password, private_key = \
app_service.get_system_user_auth_info(system_user)
system_user.password = password
system_user.private_key = private_key
def get_ssh_client(self, asset, system_user): def get_ssh_client(self, asset, system_user):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
sock = None sock = None
self.get_system_user_auth(system_user)
if not system_user.password and not system_user.private_key:
self.get_system_user_auth(system_user)
if asset.domain: if asset.domain:
sock = self.get_proxy_sock(asset) sock = self.get_proxy_sock_v2(asset)
try: try:
ssh.connect( ssh.connect(
...@@ -56,44 +61,62 @@ class SSHConnection: ...@@ -56,44 +61,62 @@ class SSHConnection:
system_user.username, asset.ip, asset.port, system_user.username, asset.ip, asset.port,
password_short, key_fingerprint, password_short, key_fingerprint,
)) ))
return None, str(e) return None, None, str(e)
except (socket.error, TimeoutError) as e: except (socket.error, TimeoutError) as e:
return None, str(e) return None, None, str(e)
return ssh, None return ssh, sock, None
def get_transport(self, asset, system_user): def get_transport(self, asset, system_user):
ssh, msg = self.get_ssh_client(asset, system_user) ssh, sock, msg = self.get_ssh_client(asset, system_user)
if ssh: if ssh:
return ssh.get_transport(), None return ssh.get_transport(), sock, None
else: else:
return None, msg return None, None, msg
def get_channel(self, asset, system_user, term="xterm", width=80, height=24): def get_channel(self, asset, system_user, term="xterm", width=80, height=24):
ssh, msg = self.get_ssh_client(asset, system_user) ssh, sock, msg = self.get_ssh_client(asset, system_user)
if ssh: if ssh:
chan = ssh.invoke_shell(term, width=width, height=height) chan = ssh.invoke_shell(term, width=width, height=height)
return chan, None return chan, sock, None
else: else:
return None, msg return None, sock, msg
def get_sftp(self, asset, system_user): def get_sftp(self, asset, system_user):
ssh, msg = self.get_ssh_client(asset, system_user) ssh, sock, msg = self.get_ssh_client(asset, system_user)
if ssh: if ssh:
return ssh.open_sftp(), None return ssh.open_sftp(), sock, None
else: else:
return None, msg return None, sock, msg
def get_system_user_auth(self, system_user): @staticmethod
""" def get_proxy_sock_v2(asset):
获取系统用户的认证信息,密码或秘钥 sock = None
:return: system user have full info domain = app_service.get_domain_detail_with_gateway(
""" asset.domain
system_user.password, system_user.private_key = \ )
self.app.service.get_system_user_auth_info(system_user) if not domain.has_ssh_gateway():
return None
for i in domain.gateways:
gateway = domain.random_ssh_gateway()
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
ssh.connect(gateway.ip, username=gateway.username,
password=gateway.password,
pkey=gateway.private_key_obj)
except(paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
SSHException):
continue
sock = ssh.get_transport().open_channel(
'direct-tcpip', (asset.ip, asset.port), ('127.0.0.1', 0)
)
break
return sock
def get_proxy_sock(self, asset): def get_proxy_sock(self, asset):
sock = None sock = None
domain = self.app.service.get_domain_detail_with_gateway( domain = app_service.get_domain_detail_with_gateway(
asset.domain asset.domain
) )
if not domain.has_ssh_gateway(): if not domain.has_ssh_gateway():
......
# -*- coding: utf-8 -*-
#
from werkzeug.local import LocalProxy
from functools import partial
stack = {}
def _find(name):
if stack.get(name):
return stack[name]
else:
raise ValueError("Not found in stack: {}".format(name))
current_app = LocalProxy(partial(_find, 'app'))
app_service = LocalProxy(partial(_find, 'service'))
# current_app = []
# current_service = []
This diff is collapsed.
This diff is collapsed.
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
import paramiko import paramiko
import threading import threading
import weakref
from .utils import get_logger from .utils import get_logger
from .ctx import current_app, app_service
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -19,22 +19,13 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -19,22 +19,13 @@ class SSHInterface(paramiko.ServerInterface):
https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py
""" """
def __init__(self, app, request): def __init__(self, request):
self._app = weakref.ref(app) self.request = request
self._request = weakref.ref(request)
self.event = threading.Event() self.event = threading.Event()
self.auth_valid = False self.auth_valid = False
self.otp_auth = False self.otp_auth = False
self.info = None self.info = None
@property
def app(self):
return self._app()
@property
def request(self):
return self._request()
def check_auth_interactive(self, username, submethods): def check_auth_interactive(self, username, submethods):
logger.info("Check auth interactive: %s %s" % (username, submethods)) logger.info("Check auth interactive: %s %s" % (username, submethods))
instructions = 'Please enter 6 digits.' instructions = 'Please enter 6 digits.'
...@@ -55,7 +46,7 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -55,7 +46,7 @@ class SSHInterface(paramiko.ServerInterface):
if not seed: if not seed:
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
is_valid = self.app.service.authenticate_otp(seed, otp_code) is_valid = app_service.authenticate_otp(seed, otp_code)
if is_valid: if is_valid:
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
...@@ -67,9 +58,9 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -67,9 +58,9 @@ class SSHInterface(paramiko.ServerInterface):
supported = [] supported = []
if self.otp_auth: if self.otp_auth:
return 'keyboard-interactive' return 'keyboard-interactive'
if self.app.config["PASSWORD_AUTH"]: if current_app.config["PASSWORD_AUTH"]:
supported.append("password") supported.append("password")
if self.app.config["PUBLIC_KEY_AUTH"]: if current_app.config["PUBLIC_KEY_AUTH"]:
supported.append("publickey") supported.append("publickey")
return ",".join(supported) return ",".join(supported)
...@@ -100,7 +91,7 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -100,7 +91,7 @@ class SSHInterface(paramiko.ServerInterface):
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
def validate_auth(self, username, password="", public_key=""): def validate_auth(self, username, password="", public_key=""):
info = self.app.service.authenticate( info = app_service.authenticate(
username, password=password, public_key=public_key, username, password=password, public_key=public_key,
remote_addr=self.request.remote_ip remote_addr=self.request.remote_ip
) )
......
...@@ -49,6 +49,8 @@ def create_logger(app): ...@@ -49,6 +49,8 @@ def create_logger(app):
'coco': main_setting, 'coco': main_setting,
'paramiko': main_setting, 'paramiko': main_setting,
'jms': main_setting, 'jms': main_setting,
'socket.io': main_setting,
'engineio': main_setting,
} }
) )
......
...@@ -94,8 +94,9 @@ class Server: ...@@ -94,8 +94,9 @@ class Server:
""" """
# Todo: Server name is not very suitable # Todo: Server name is not very suitable
def __init__(self, chan, asset, system_user): def __init__(self, chan, sock, asset, system_user):
self.chan = chan self.chan = chan
self.sock = sock
self.asset = asset self.asset = asset
self.system_user = system_user self.system_user = system_user
self.send_bytes = 0 self.send_bytes = 0
...@@ -168,6 +169,8 @@ class Server: ...@@ -168,6 +169,8 @@ class Server:
self.stop_evt.set() self.stop_evt.set()
self.chan.close() self.chan.close()
self.chan.transport.close() self.chan.transport.close()
if self.sock:
self.sock.transport.close()
@staticmethod @staticmethod
def _have_enter_char(s): def _have_enter_char(s):
...@@ -218,7 +221,7 @@ class WSProxy: ...@@ -218,7 +221,7 @@ class WSProxy:
``` ```
""" """
def __init__(self, ws, child, room, connection): def __init__(self, ws, child, room_id):
""" """
:param ws: websocket instance or handler, have write_message method :param ws: websocket instance or handler, have write_message method
:param child: sock child pair :param child: sock child pair
...@@ -226,9 +229,8 @@ class WSProxy: ...@@ -226,9 +229,8 @@ class WSProxy:
self.ws = ws self.ws = ws
self.child = child self.child = child
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.room = room self.room_id = room_id
self.auto_forward() self.auto_forward()
self.connection = connection
def send(self, msg): def send(self, msg):
""" """
...@@ -247,12 +249,15 @@ class WSProxy: ...@@ -247,12 +249,15 @@ class WSProxy:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
data = self.child.recv(BUF_SIZE) data = self.child.recv(BUF_SIZE)
except OSError: except (OSError, EOFError):
continue
if len(data) == 0:
self.close() self.close()
break
if not data:
self.close()
break
data = data.decode(errors="ignore") data = data.decode(errors="ignore")
self.ws.emit("data", {'data': data, 'room': self.connection}, room=self.room) self.ws.emit("data", {'data': data, 'room': self.room_id},
room=self.room_id)
if len(data) == BUF_SIZE: if len(data) == BUF_SIZE:
time.sleep(0.1) time.sleep(0.1)
...@@ -262,11 +267,12 @@ class WSProxy: ...@@ -262,11 +267,12 @@ class WSProxy:
thread.start() thread.start()
def close(self): def close(self):
self.ws.emit("logout", {"room": self.room_id}, room=self.room_id)
self.stop_event.set() self.stop_event.set()
self.child.close() try:
self.ws.logout(self.connection) self.child.shutdown(1)
self.child.close()
except (OSError, EOFError):
pass
logger.debug("Proxy {} closed".format(self)) logger.debug("Proxy {} closed".format(self))
...@@ -4,15 +4,15 @@ ...@@ -4,15 +4,15 @@
import threading import threading
import time import time
import weakref
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
from .session import Session from .session import Session
from .models import Server from .models import Server
from .connection import SSHConnection from .connection import SSHConnection
from .ctx import current_app, app_service
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \ from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
get_logger get_logger, net_input
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -21,42 +21,51 @@ BUF_SIZE = 4096 ...@@ -21,42 +21,51 @@ BUF_SIZE = 4096
class ProxyServer: class ProxyServer:
def __init__(self, app, client): def __init__(self, client):
self._app = weakref.ref(app)
self.client = client self.client = client
self.server = None self.server = None
self.connecting = True self.connecting = True
self.stop_event = threading.Event() self.stop_event = threading.Event()
@property def get_system_user_auth(self, system_user):
def app(self): """
return self._app() 获取系统用户的认证信息,密码或秘钥
:return: system user have full info
"""
password, private_key = \
app_service.get_system_user_auth_info(system_user)
if not password and not private_key:
prompt = "{}'s password: ".format(system_user.username)
password = net_input(self.client, prompt=prompt, sensitive=True)
system_user.password = password
system_user.private_key = private_key
def proxy(self, asset, system_user): def proxy(self, asset, system_user):
self.get_system_user_auth(system_user)
self.send_connecting_message(asset, system_user) self.send_connecting_message(asset, system_user)
self.server = self.get_server_conn(asset, system_user) self.server = self.get_server_conn(asset, system_user)
if self.server is None: if self.server is None:
return return
command_recorder = self.app.new_command_recorder() command_recorder = current_app.new_command_recorder()
replay_recorder = self.app.new_replay_recorder() replay_recorder = current_app.new_replay_recorder()
session = Session( session = Session(
self.client, self.server, self.client, self.server,
command_recorder=command_recorder, command_recorder=command_recorder,
replay_recorder=replay_recorder, replay_recorder=replay_recorder,
) )
self.app.add_session(session) current_app.add_session(session)
self.watch_win_size_change_async() self.watch_win_size_change_async()
session.bridge() session.bridge()
self.stop_event.set() self.stop_event.set()
self.end_watch_win_size_change() self.end_watch_win_size_change()
self.app.remove_session(session) current_app.remove_session(session)
def validate_permission(self, asset, system_user): def validate_permission(self, asset, system_user):
""" """
验证用户是否有连接改资产的权限 验证用户是否有连接改资产的权限
:return: True or False :return: True or False
""" """
return self.app.service.validate_user_asset_permission( return app_service.validate_user_asset_permission(
self.client.user.id, asset.id, system_user.id self.client.user.id, asset.id, system_user.id
) )
...@@ -76,18 +85,19 @@ class ProxyServer: ...@@ -76,18 +85,19 @@ class ProxyServer:
pass pass
def get_ssh_server_conn(self, asset, system_user): def get_ssh_server_conn(self, asset, system_user):
ssh = SSHConnection(self.app)
request = self.client.request request = self.client.request
term = request.meta.get('term', 'xterm') term = request.meta.get('term', 'xterm')
width = request.meta.get('width', 80) width = request.meta.get('width', 80)
height = request.meta.get('height', 24) height = request.meta.get('height', 24)
chan, msg = ssh.get_channel(asset, system_user, term=term, ssh = SSHConnection()
width=width, height=height) chan, sock, msg = ssh.get_channel(
asset, system_user, term=term, width=width, height=height
)
if not chan: if not chan:
self.client.send(warning(wr(msg, before=1, after=0))) self.client.send(warning(wr(msg, before=1, after=0)))
server = None server = None
else: else:
server = Server(chan, asset, system_user) server = Server(chan, sock, asset, system_user)
self.connecting = False self.connecting = False
self.client.send(b'\r\n') self.client.send(b'\r\n')
return server return server
...@@ -116,9 +126,11 @@ class ProxyServer: ...@@ -116,9 +126,11 @@ class ProxyServer:
def send_connecting_message(self, asset, system_user): def send_connecting_message(self, asset, system_user):
def func(): def func():
delay = 0.0 delay = 0.0
self.client.send('Connecting to {}@{} {:.1f}'.format(system_user, asset, delay)) self.client.send('Connecting to {}@{} {:.1f}'.format(
system_user, asset, delay)
)
while self.connecting and delay < TIMEOUT: while self.connecting and delay < TIMEOUT:
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode('utf-8')) self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode())
time.sleep(0.1) time.sleep(0.1)
delay += 0.1 delay += 0.1
thread = threading.Thread(target=func) thread = threading.Thread(target=func)
......
This diff is collapsed.
...@@ -40,7 +40,7 @@ class Session: ...@@ -40,7 +40,7 @@ class Session:
""" """
logger.info("Session add watcher: {} -> {} ".format(self.id, watcher)) logger.info("Session add watcher: {} -> {} ".format(self.id, watcher))
if not silent: if not silent:
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode("utf-8")) watcher.send("Welcome to watch session {}\r\n".format(self.id).encode())
self.sel.register(watcher, selectors.EVENT_READ) self.sel.register(watcher, selectors.EVENT_READ)
self._watchers.append(watcher) self._watchers.append(watcher)
......
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import tempfile import tempfile
import paramiko import paramiko
import time import time
from .ctx import app_service
from datetime import datetime from datetime import datetime
from .connection import SSHConnection from .connection import SSHConnection
...@@ -16,6 +17,17 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -16,6 +17,17 @@ class SFTPServer(paramiko.SFTPServerInterface):
self._sftp = {} self._sftp = {}
self.hosts = self.get_perm_hosts() self.hosts = self.get_perm_hosts()
def session_ended(self):
super().session_ended()
for _, v in self._sftp.items():
sftp = v['sftp']
sock = v.get('sock')
sftp.close()
if sock:
sock.close()
sock.transport.close()
self._sftp = {}
def get_host_sftp(self, host, su): def get_host_sftp(self, host, su):
asset = self.hosts.get(host) asset = self.hosts.get(host)
system_user = None system_user = None
...@@ -28,18 +40,18 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -28,18 +40,18 @@ class SFTPServer(paramiko.SFTPServerInterface):
raise OSError("No asset or system user explicit") raise OSError("No asset or system user explicit")
if host not in self._sftp: if host not in self._sftp:
ssh = SSHConnection(self.server.app) ssh = SSHConnection()
sftp, msg = ssh.get_sftp(asset, system_user) sftp, sock, msg = ssh.get_sftp(asset, system_user)
if sftp: if sftp:
self._sftp[host] = sftp self._sftp[host] = {'sftp': sftp, 'sock': sock}
return sftp return sftp
else: else:
raise OSError("Can not connect asset sftp server") raise OSError("Can not connect asset sftp server: {}".format(msg))
else: else:
return self._sftp[host] return self._sftp[host]['sftp']
def get_perm_hosts(self): def get_perm_hosts(self):
assets = self.server.app.service.get_user_assets( assets = app_service.get_user_assets(
self.server.request.user self.server.request.user
) )
return {asset.hostname: asset for asset in assets} return {asset.hostname: asset for asset in assets}
...@@ -89,7 +101,7 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -89,7 +101,7 @@ class SFTPServer(paramiko.SFTPServerInterface):
"is_success": is_success, "is_success": is_success,
} }
for i in range(1, 4): for i in range(1, 4):
ok = self.server.app.service.create_ftp_log(data) ok = app_service.create_ftp_log(data)
if ok: if ok:
break break
else: else:
......
...@@ -12,6 +12,7 @@ from .interface import SSHInterface ...@@ -12,6 +12,7 @@ from .interface import SSHInterface
from .interactive import InteractiveServer from .interactive import InteractiveServer
from .models import Client, Request from .models import Client, Request
from .sftp import SFTPServer from .sftp import SFTPServer
from .ctx import current_app
logger = get_logger(__file__) logger = get_logger(__file__)
BACKLOG = 5 BACKLOG = 5
...@@ -19,38 +20,41 @@ BACKLOG = 5 ...@@ -19,38 +20,41 @@ BACKLOG = 5
class SSHServer: class SSHServer:
def __init__(self, app): def __init__(self):
self.app = app
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.workers = []
self.host_key_path = os.path.join(self.app.root_path, 'keys', 'host_rsa_key') self.pipe = None
@property @property
def host_key(self): def host_key(self):
if not os.path.isfile(self.host_key_path): host_key_path = os.path.join(current_app.root_path, 'keys', 'host_rsa_key')
self.gen_host_key() if not os.path.isfile(host_key_path):
return paramiko.RSAKey(filename=self.host_key_path) self.gen_host_key(host_key_path)
return paramiko.RSAKey(filename=host_key_path)
def gen_host_key(self): @staticmethod
def gen_host_key(key_path):
ssh_key, _ = ssh_key_gen() ssh_key, _ = ssh_key_gen()
with open(self.host_key_path, 'w') as f: with open(key_path, 'w') as f:
f.write(ssh_key) f.write(ssh_key)
def run(self): def run(self):
host = self.app.config["BIND_HOST"] host = current_app.config["BIND_HOST"]
port = self.app.config["SSHD_PORT"] port = current_app.config["SSHD_PORT"]
print('Starting ssh server at {}:{}'.format(host, port)) print('Starting ssh server at {}:{}'.format(host, port))
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.bind((host, port)) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.listen(BACKLOG) sock.bind((host, port))
sock.listen(BACKLOG)
while not self.stop_evt.is_set(): while not self.stop_evt.is_set():
try: try:
sock, addr = self.sock.accept() client, addr = sock.accept()
logger.info("Get ssh request from {}: {}".format(addr[0], addr[1])) logger.info("Get ssh request from {}: {}".format(*addr))
thread = threading.Thread(target=self.handle_connection, args=(sock, addr)) thread = threading.Thread(target=self.handle_connection,
args=(client, addr))
thread.daemon = True thread.daemon = True
thread.start() thread.start()
except Exception as e: except IndexError as e:
logger.error("Start SSH server error: {}".format(e)) logger.error("Start SSH server error: {}".format(e))
def handle_connection(self, sock, addr): def handle_connection(self, sock, addr):
...@@ -65,7 +69,7 @@ class SSHServer: ...@@ -65,7 +69,7 @@ class SSHServer:
'sftp', paramiko.SFTPServer, SFTPServer 'sftp', paramiko.SFTPServer, SFTPServer
) )
request = Request(addr) request = Request(addr)
server = SSHInterface(self.app, request) server = SSHInterface(request)
try: try:
transport.start_server(server=server) transport.start_server(server=server)
except paramiko.SSHException: except paramiko.SSHException:
...@@ -96,7 +100,7 @@ class SSHServer: ...@@ -96,7 +100,7 @@ class SSHServer:
def handle_chan(self, chan, request): def handle_chan(self, chan, request):
client = Client(chan, request) client = Client(chan, request)
self.app.add_client(client) current_app.add_client(client)
self.dispatch(client) self.dispatch(client)
def dispatch(self, client): def dispatch(self, client):
...@@ -104,7 +108,7 @@ class SSHServer: ...@@ -104,7 +108,7 @@ class SSHServer:
request_type = set(client.request.type) request_type = set(client.request.type)
if supported & request_type: if supported & request_type:
logger.info("Request type `pty`, dispatch to interactive mode") logger.info("Request type `pty`, dispatch to interactive mode")
InteractiveServer(self.app, client).interact() InteractiveServer(client).interact()
elif 'subsystem' in request_type: elif 'subsystem' in request_type:
pass pass
else: else:
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import weakref
from .ctx import current_app, app_service
from .utils import get_logger from .utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
class TaskHandler: class TaskHandler:
def __init__(self):
self.routes = {
'kill_session': self.handle_kill_session
}
def __init__(self, app): @staticmethod
self._app = weakref.ref(app) def handle_kill_session(task):
@property
def app(self):
return self._app()
def handle_kill_session(self, task):
logger.info("Handle kill session task: {}".format(task.args)) logger.info("Handle kill session task: {}".format(task.args))
session_id = task.args session_id = task.args
session = None session = None
for s in self.app.sessions: for s in current_app.sessions:
if s.id == session_id: if s.id == session_id:
session = s session = s
break break
if session: if session:
session.terminate() session.terminate()
self.app.service.finish_task(task.id) app_service.finish_task(task.id)
def handle(self, task): def handle(self, task):
if task.name == "kill_session": func = self.routes.get(task.name)
self.handle_kill_session(task) return func(task)
else:
logger.error("No handler for this task: {}".format(task.name))
...@@ -4,30 +4,35 @@ ...@@ -4,30 +4,35 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import hashlib
import logging import logging
import re import re
import os import os
import threading
import base64
import calendar
import time
import datetime
import gettext import gettext
from io import StringIO from io import StringIO
from binascii import hexlify from binascii import hexlify
import paramiko import paramiko
import pyte import pyte
import pytz
from email.utils import formatdate
from queue import Queue, Empty
from .exception import NoAppException from . import char
from .ctx import stack
BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
def ssh_key_string_to_obj(text, password=None): def ssh_key_string_to_obj(text, password=None):
key = None key = None
try: try:
...@@ -289,17 +294,130 @@ def get_logger(file_name): ...@@ -289,17 +294,130 @@ def get_logger(file_name):
return logging.getLogger('coco.'+file_name) return logging.getLogger('coco.'+file_name)
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
"""实现了一个ssh input, 提示用户输入, 获取并返回
:return user input string
def len_display(s): """
length = 0 input_data = []
for i in s: parser = TtyIOParser()
if zh_pattern.match(i): client.send(wrap_with_line_feed(prompt, before=before, after=after))
length += 2
while True:
data = client.recv(10)
if len(data) == 0:
break
# Client input backspace
if data in char.BACKSPACE_CHAR:
# If input words less than 0, should send 'BELL'
if len(input_data) > 0:
data = char.BACKSPACE_CHAR[data]
input_data.pop()
else:
data = char.BELL_CHAR
client.send(data)
continue
if data.startswith(b'\x03'):
# Ctrl-C
client.send('^C\r\n{} '.format(prompt).encode())
input_data = []
continue
elif data.startswith(b'\x04'):
# Ctrl-D
return 'q'
# Todo: Move x1b to char
if data.startswith(b'\x1b') or data in char.UNSUPPORTED_CHAR:
client.send(b'')
continue
# handle shell expect
multi_char_with_enter = False
if len(data) > 1 and data[-1] in char.ENTER_CHAR_ORDER:
if sensitive:
client.send(len(data) * '*')
else:
client.send(data)
input_data.append(data[:-1])
multi_char_with_enter = True
# If user type ENTER we should get user input
if data in char.ENTER_CHAR or multi_char_with_enter:
client.send(wrap_with_line_feed(b'', after=2))
option = parser.parse_input(input_data)
del input_data[:]
return option.strip()
else: else:
length += 1 if sensitive:
client.send(len(data) * '*')
else:
client.send(data)
input_data.append(data)
def register_app(app):
stack['app'] = app
def register_service(service):
stack['service'] = service
zh_pattern = re.compile(r'[\u4e00-\u9fa5]')
def find_chinese(s):
return zh_pattern.findall(s)
def align_with_zh(s, length, addin=' '):
if not isinstance(s, str):
s = str(s)
zh_len = len(find_chinese(s))
padding = length - (len(s) - zh_len) - zh_len*2
padding_content = ''
if padding > 0:
padding_content = addin*padding
return s + padding_content
def format_with_zh(size_list, *args):
data = []
for length, s in zip(size_list, args):
data.append(align_with_zh(s, length))
return ' '.join(data)
def size_of_str_with_zh(s):
if isinstance(s, int):
s = str(s)
try:
chinese = find_chinese(s)
except TypeError:
raise
return len(s) + len(chinese)
def item_max_length(_iter, maxi=None, mini=None, key=None):
if key:
_iter = [key(i) for i in _iter]
length = [size_of_str_with_zh(s) for s in _iter]
if not length:
return 1
if maxi:
length.append(maxi)
length = max(length)
if mini and length < mini:
length = mini
return length return length
def int_length(i):
return len(str(i))
ugettext = _gettext() ugettext = _gettext()
...@@ -12,15 +12,14 @@ cryptography==2.1.4 ...@@ -12,15 +12,14 @@ cryptography==2.1.4
docutils==0.14 docutils==0.14
dotmap==1.2.20 dotmap==1.2.20
elasticsearch==6.1.1 elasticsearch==6.1.1
Flask==0.12.2 Flask==1.0.2
Flask-SocketIO==2.9.2 Flask-SocketIO==2.9.2
idna==2.6 idna==2.6
itsdangerous==0.24 itsdangerous==0.24
Jinja2==2.10 Jinja2==2.10
jmespath==0.9.3 jmespath==0.9.3
jms-es-sdk==0.5.2 jms-storage==0.0.17
jms-storage==0.0.12 jumpserver-python-sdk==0.0.42
jumpserver-python-sdk==0.0.41
MarkupSafe==1.0 MarkupSafe==1.0
oss2==2.4.0 oss2==2.4.0
paramiko==2.4.0 paramiko==2.4.0
...@@ -28,9 +27,9 @@ psutil==5.4.1 ...@@ -28,9 +27,9 @@ psutil==5.4.1
pyasn1==0.4.2 pyasn1==0.4.2
pycparser==2.18 pycparser==2.18
PyNaCl==1.2.1 PyNaCl==1.2.1
pyte==0.7.0 pyte==0.8.0
python-dateutil==2.6.1 python-dateutil==2.6.1
python-engineio==2.0.1 python-engineio==2.1.0
python-gssapi==0.6.4 python-gssapi==0.6.4
python-socketio==1.8.3 python-socketio==1.8.3
pytz==2017.3 pytz==2017.3
...@@ -41,4 +40,5 @@ six==1.11.0 ...@@ -41,4 +40,5 @@ six==1.11.0
tornado==4.5.2 tornado==4.5.2
urllib3==1.22 urllib3==1.22
wcwidth==0.1.7 wcwidth==0.1.7
Werkzeug==0.12.2 Werkzeug==0.14.1
eventlet==0.22
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment