Commit 02104017 authored by ibuler's avatar ibuler

[Update] 添加连接复用逻辑

parent 27b1ff43
...@@ -292,27 +292,37 @@ class Config(dict): ...@@ -292,27 +292,37 @@ class Config(dict):
if default_value is None: if default_value is None:
return v return v
tp = type(default_value) tp = type(default_value)
try: # 对bool特殊处理
if tp in [list, dict]: if tp is bool and isinstance(v, str):
v = json.loads(v) if v in ("true", "True", "1"):
return True
else: else:
v = tp(v) return False
if tp in [list, dict] and isinstance(v, str):
try:
v = json.loads(v)
return v
except json.JSONDecodeError:
return v
try:
v = tp(v)
except Exception: except Exception:
pass pass
return v return v
def __getitem__(self, item): def __getitem__(self, item):
# 先从设置的来
try: try:
value = super(Config, self).__getitem__(item) value = super().__getitem__(item)
except KeyError: except KeyError:
value = None value = None
if value is not None: if value is not None:
return self.convert_type(item, value) return value
# 其次从环境变量来
value = os.environ.get(item, None) value = os.environ.get(item, None)
if value is not None: if value is not None:
if value.isdigit(): if value.lower() == 'false':
value = int(value)
elif value.lower() == 'false':
value = False value = False
elif value.lower() == 'true': elif value.lower() == 'true':
value = True value = True
...@@ -368,7 +378,8 @@ defaults = { ...@@ -368,7 +378,8 @@ defaults = {
'ASSET_LIST_PAGE_SIZE': 'auto', 'ASSET_LIST_PAGE_SIZE': 'auto',
'SFTP_ROOT': '/tmp', 'SFTP_ROOT': '/tmp',
'SFTP_SHOW_HIDDEN_FILE': False, 'SFTP_SHOW_HIDDEN_FILE': False,
'UPLOAD_FAILED_REPLAY_ON_START': True 'UPLOAD_FAILED_REPLAY_ON_START': True,
'REUSE_CONNECTION': False,
} }
......
...@@ -24,32 +24,85 @@ AUTO_LOGIN = 'auto' ...@@ -24,32 +24,85 @@ AUTO_LOGIN = 'auto'
class SSHConnection: class SSHConnection:
connections = {}
@staticmethod @staticmethod
def get_system_user_auth(system_user, asset): def make_key(user, asset, system_user):
key = "{}_{}_{}".format(user.id, asset.id, system_user.id)
return key
@classmethod
def new_connection_from_cache(cls, user, asset, system_user):
if not config.REUSE_CONNECTION:
return None
key = cls.make_key(user, asset, system_user)
connection = cls.connections.get(key)
if not connection:
return None
connection.ref += 1
return connection
@classmethod
def set_connection_to_cache(cls, conn):
if not config.REUSE_CONNECTION:
return None
key = cls.make_key(conn.user, conn.asset, conn.system_user)
cls.connections[key] = conn
@classmethod
def new_connection(cls, user, asset, system_user):
connection = cls.new_connection_from_cache(user, asset, system_user)
if connection:
logger.debug("Reuse connection: {}->{}@{}".format(
user.username, asset.ip, system_user.username)
)
return connection
connection = cls(user, asset, system_user)
cls.set_connection_to_cache(connection)
return connection
@classmethod
def remove_ssh_connection(cls, conn):
key = "{}_{}_{}".format(conn.user.id, conn.asset.id, conn.system_user.id)
cls.connections.pop(key, None)
def __init__(self, user, asset, system_user):
self.user = user
self.asset = asset
self.system_user = system_user
self.client = None
self.sock = None
self.error = ""
self.ref = 1
def get_system_user_auth(self):
""" """
获取系统用户的认证信息,密码或秘钥 获取系统用户的认证信息,密码或秘钥
:return: system user have full info :return: system user have full info
""" """
password, private_key = \ password, private_key = \
app_service.get_system_user_auth_info(system_user, asset) app_service.get_system_user_auth_info(self.system_user, self.asset)
system_user.password = password self.system_user.password = password
system_user.private_key = private_key self.system_user.private_key = private_key
def get_ssh_client(self, asset, system_user): def get_ssh_client(self):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
sock = None sock = None
error = '' error = ''
if not system_user.password and not system_user.private_key: if not self.system_user.password and not self.system_user.private_key:
self.get_system_user_auth(system_user, asset) self.get_system_user_auth()
if asset.domain: if self.asset.domain:
sock = self.get_proxy_sock_v2(asset) sock = self.get_proxy_sock_v2(self.asset)
if not sock: if not sock:
error = 'Connect gateway failed.' error = 'Connect gateway failed.'
logger.error(error) logger.error(error)
asset = self.asset
system_user = self.system_user
try: try:
try: try:
ssh.connect( ssh.connect(
...@@ -86,30 +139,53 @@ class SSHConnection: ...@@ -86,30 +139,53 @@ class SSHConnection:
password_short, key_fingerprint, password_short, key_fingerprint,
)) ))
error += '\r\n' + str(e) if error else str(e) error += '\r\n' + str(e) if error else str(e)
return None, None, error ssh, sock, error = None, None, error
return ssh, sock, None self.client = ssh
self.sock = ssh
self.error = error
def get_transport(self, asset, system_user): def get_transport(self):
ssh, sock, msg = self.get_ssh_client(asset, system_user) if not self.client:
if ssh: self.get_ssh_client()
return ssh.get_transport(), sock, None if not self.client:
return self.client.get_transport()
else: else:
return None, None, msg return None
def get_channel(self, asset, system_user, term="xterm", width=80, height=24): def get_channel(self, term="xterm", width=80, height=24):
ssh, sock, msg = self.get_ssh_client(asset, system_user) if not self.client:
if ssh: self.get_ssh_client()
chan = ssh.invoke_shell(term, width=width, height=height) if self.client:
return chan, sock, None chan = self.client.invoke_shell(term, width=width, height=height)
return chan
else: else:
return None, sock, msg return None
def get_sftp(self, asset, system_user): def get_sftp(self):
ssh, sock, msg = self.get_ssh_client(asset, system_user) if not self.client:
if ssh: self.get_ssh_client()
return ssh.open_sftp(), sock, None if self.client:
return self.client.open_sftp()
else: else:
return None, sock, msg return None
def close(self):
if self.ref > 1:
self.ref -= 1
logger.debug("Connection ref -1: {}->{}@{}".format(
self.user.username, self.asset.hostname, self.system_user.username)
)
return
self.__class__.remove_ssh_connection(self)
try:
self.client.close()
if self.sock:
self.sock.close()
except Exception as e:
logger.error("Close connection error: ", e)
logger.debug("Close connection: {}->{}@{}".format(
self.user.username, self.asset.ip, self.system_user.username)
)
@staticmethod @staticmethod
def get_proxy_sock_v2(asset): def get_proxy_sock_v2(asset):
......
...@@ -398,20 +398,15 @@ class Server(BaseServer): ...@@ -398,20 +398,15 @@ class Server(BaseServer):
""" """
# Todo: Server name is not very suitable # Todo: Server name is not very suitable
def __init__(self, chan, sock, asset, system_user): def __init__(self, chan, connection, asset, system_user):
self.sock = sock self.connection = connection
self.asset = asset self.asset = asset
self.system_user = system_user self.system_user = system_user
super(Server, self).__init__(chan=chan) super(Server, self).__init__(chan=chan)
def close(self): def close(self):
super(Server, self).close() super(Server, self).close()
for i in range(5): self.connection.close()
if not self.chan.transport.is_alive():
break
self.chan.transport.close()
if self.sock:
self.sock.transport.close()
class WSProxy(object): class WSProxy(object):
......
...@@ -64,9 +64,9 @@ class ProxyServer: ...@@ -64,9 +64,9 @@ class ProxyServer:
def proxy(self): def proxy(self):
if not self.check_protocol(): if not self.check_protocol():
return return
self.get_system_user_username_if_need() self.server = self.get_server_conn_from_cache()
self.get_system_user_auth_or_manual_set() if not self.server:
self.server = self.get_server_conn() self.server = self.get_server_conn()
if self.server is None: if self.server is None:
return return
if self.client.closed: if self.client.closed:
...@@ -102,16 +102,25 @@ class ProxyServer: ...@@ -102,16 +102,25 @@ class ProxyServer:
} }
return app_service.validate_user_asset_permission(**kwargs) return app_service.validate_user_asset_permission(**kwargs)
def get_server_conn_from_cache(self):
server = None
if self.system_user.protocol == 'ssh':
server = self.get_ssh_server_conn(cache=True)
return server
def get_server_conn(self): def get_server_conn(self):
logger.info("Connect to {}:{} ...".format(self.asset.hostname, self.asset.port)) # 与获取连接
self.get_system_user_username_if_need()
self.get_system_user_auth_or_manual_set()
self.send_connecting_message() self.send_connecting_message()
logger.info("Connect to {}:{} ...".format(self.asset.hostname, self.asset.port))
if not self.validate_permission(): if not self.validate_permission():
msg = _('No permission') msg = _('No permission')
self.client.send_unicode(warning(wr(msg, before=2, after=0))) self.client.send_unicode(warning(wr(msg, before=2, after=0)))
server = None server = None
elif self.system_user.protocol == self.asset.protocol == 'telnet': elif self.system_user.protocol == 'telnet':
server = self.get_telnet_server_conn() server = self.get_telnet_server_conn()
elif self.system_user.protocol == self.asset.protocol == 'ssh': elif self.system_user.protocol == 'ssh':
server = self.get_ssh_server_conn() server = self.get_ssh_server_conn()
else: else:
server = None server = None
...@@ -129,21 +138,28 @@ class ProxyServer: ...@@ -129,21 +138,28 @@ class ProxyServer:
server = TelnetServer(sock, self.asset, self.system_user) server = TelnetServer(sock, self.asset, self.system_user)
return server return server
def get_ssh_server_conn(self): def get_ssh_server_conn(self, cache=False):
request = self.client.request request = self.client.request
term = request.meta.get('term', 'xterm') term = request.meta.get('term', 'xterm')
width = request.meta.get('width', 80) width = request.meta.get('width', 80)
height = request.meta.get('height', 24) height = request.meta.get('height', 24)
ssh = SSHConnection()
chan, sock, msg = ssh.get_channel( if cache:
self.asset, self.system_user, term=term, conn = SSHConnection.new_connection_from_cache(
width=width, height=height self.client.user, self.asset, self.system_user
) )
if not conn:
return None
else:
conn = SSHConnection.new_connection(
self.client.user, self.asset, self.system_user
)
chan = conn.get_channel(term=term, width=width, height=height)
if not chan: if not chan:
self.client.send_unicode(warning(wr(msg, before=1, after=0))) self.client.send_unicode(warning(wr(conn.error, before=1, after=0)))
server = None server = None
else: else:
server = Server(chan, sock, self.asset, self.system_user) server = Server(chan, conn, self.asset, self.system_user)
return server return server
def send_connecting_message(self): def send_connecting_message(self):
......
...@@ -87,9 +87,9 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -87,9 +87,9 @@ class SFTPServer(paramiko.SFTPServerInterface):
if asset.org_id: if asset.org_id:
key = "{}.{}".format(asset.hostname, asset.org_name) key = "{}.{}".format(asset.hostname, asset.org_name)
value['asset'] = asset value['asset'] = asset
value['system_users'] = {su.name: su value['system_users'] = {
su.name: su
for su in asset.system_users_granted for su in asset.system_users_granted
if su.protocol == "ssh" and su.login_mode == 'auto'
} }
hosts[key] = value hosts[key] = value
return hosts return hosts
...@@ -99,17 +99,9 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -99,17 +99,9 @@ class SFTPServer(paramiko.SFTPServerInterface):
super(SFTPServer, self).session_ended() super(SFTPServer, self).session_ended()
for _, v in self._sftp.items(): for _, v in self._sftp.items():
sftp = v['client'] sftp = v['client']
proxy = v.get('proxy') conn = v.get('connection')
chan = sftp.get_channel()
trans = chan.get_transport()
sftp.close() sftp.close()
conn.close()
active_channels = [c for c in trans._channels.values() if not c.closed]
if not active_channels:
trans.close()
if proxy:
proxy.close()
proxy.transport.close()
self._sftp = {} self._sftp = {}
def get_host_sftp(self, host, su): def get_host_sftp(self, host, su):
...@@ -121,17 +113,18 @@ class SFTPServer(paramiko.SFTPServerInterface): ...@@ -121,17 +113,18 @@ class SFTPServer(paramiko.SFTPServerInterface):
cache_key = '{}@{}'.format(su, host) cache_key = '{}@{}'.format(su, host)
if cache_key not in self._sftp: if cache_key not in self._sftp:
ssh = SSHConnection() conn = SSHConnection.new_connection(self.server.connection.user,
__sftp, proxy, msg = ssh.get_sftp(asset, system_user) asset, system_user)
__sftp = conn.get_sftp()
if __sftp: if __sftp:
sftp = { sftp = {
'client': __sftp, 'proxy': proxy, 'client': __sftp, 'connection': conn,
'home': __sftp.normalize('') 'home': __sftp.normalize('')
} }
self._sftp[cache_key] = sftp self._sftp[cache_key] = sftp
return sftp return sftp
else: else:
raise OSError("Can not connect asset sftp server: {}".format(msg)) raise OSError("Can not connect asset sftp server: {}".format(conn.error))
else: else:
return self._sftp[cache_key] return self._sftp[cache_key]
......
...@@ -57,3 +57,6 @@ BOOTSTRAP_TOKEN: <PleasgeChangeSameWithJumpserver> ...@@ -57,3 +57,6 @@ BOOTSTRAP_TOKEN: <PleasgeChangeSameWithJumpserver>
# SFTP是否显示隐藏文件 # SFTP是否显示隐藏文件
# SFTP_SHOW_HIDDEN_FILE: false # SFTP_SHOW_HIDDEN_FILE: false
# 是否复用和用户后端资产已建立的连接(用户不会复用其他用户的连接)
# REUSE_CONNECTION: false
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment