Unverified Commit e2c18bb8 authored by BaiJiangJie's avatar BaiJiangJie Committed by GitHub

Merge pull request #218 from jumpserver/dev

Dev
parents 1b7446ca a38b9c1f
...@@ -39,6 +39,9 @@ class SSHConnection: ...@@ -39,6 +39,9 @@ class SSHConnection:
connection = cls.connections.get(key) connection = cls.connections.get(key)
if not connection: if not connection:
return None return None
if not connection.is_active:
cls.connections.pop(key, None)
return None
connection.ref += 1 connection.ref += 1
return connection return connection
...@@ -59,7 +62,9 @@ class SSHConnection: ...@@ -59,7 +62,9 @@ class SSHConnection:
) )
return connection return connection
connection = cls(user, asset, system_user) connection = cls(user, asset, system_user)
cls.set_connection_to_cache(connection) connection.connect()
if connection.is_active:
cls.set_connection_to_cache(connection)
return connection return connection
@classmethod @classmethod
...@@ -72,6 +77,7 @@ class SSHConnection: ...@@ -72,6 +77,7 @@ class SSHConnection:
self.asset = asset self.asset = asset
self.system_user = system_user self.system_user = system_user
self.client = None self.client = None
self.transport = None
self.sock = None self.sock = None
self.error = "" self.error = ""
self.ref = 1 self.ref = 1
...@@ -86,7 +92,7 @@ class SSHConnection: ...@@ -86,7 +92,7 @@ class SSHConnection:
self.system_user.password = password self.system_user.password = password
self.system_user.private_key = private_key self.system_user.private_key = private_key
def get_ssh_client(self): def connect(self):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
sock = None sock = None
...@@ -122,6 +128,7 @@ class SSHConnection: ...@@ -122,6 +128,7 @@ class SSHConnection:
) )
transport = ssh.get_transport() transport = ssh.get_transport()
transport.set_keepalive(20) transport.set_keepalive(20)
self.transport = transport
except Exception as e: except Exception as e:
password_short = "None" password_short = "None"
key_fingerprint = "None" key_fingerprint = "None"
...@@ -144,37 +151,44 @@ class SSHConnection: ...@@ -144,37 +151,44 @@ class SSHConnection:
self.sock = ssh self.sock = ssh
self.error = error self.error = error
def reconnect_if_need(self):
if not self.is_active:
self.connect()
if self.is_active:
return True
return False
def get_transport(self): def get_transport(self):
if not self.client: if self.reconnect_if_need():
self.get_ssh_client() return self.transport
if not self.client: return None
return self.client.get_transport()
else:
return None
def get_channel(self, term="xterm", width=80, height=24): def get_channel(self, term="xterm", width=80, height=24):
if not self.client: if self.reconnect_if_need():
self.get_ssh_client()
if self.client:
chan = self.client.invoke_shell(term, width=width, height=height) chan = self.client.invoke_shell(term, width=width, height=height)
return chan return chan
else: else:
return None return None
def get_sftp(self): def get_sftp(self):
if not self.client: if self.reconnect_if_need():
self.get_ssh_client()
if self.client:
return self.client.open_sftp() return self.client.open_sftp()
else: else:
return None return None
@property
def is_active(self):
return self.transport and self.transport.is_active()
def close(self): def close(self):
if self.ref > 1: if self.ref > 1:
self.ref -= 1 self.ref -= 1
logger.debug("Connection ref -1: {}->{}@{}".format( msg = "Connection ref -1: {}->{}@{}. {}".format(
self.user.username, self.asset.hostname, self.system_user.username) self.user.username, self.asset.hostname,
self.system_user.username, self.ref
) )
logger.debug(msg)
return return
self.__class__.remove_ssh_connection(self) self.__class__.remove_ssh_connection(self)
try: try:
...@@ -183,9 +197,12 @@ class SSHConnection: ...@@ -183,9 +197,12 @@ class SSHConnection:
self.sock.close() self.sock.close()
except Exception as e: except Exception as e:
logger.error("Close connection error: ", e) logger.error("Close connection error: ", e)
logger.debug("Close connection: {}->{}@{}".format(
self.user.username, self.asset.ip, self.system_user.username) msg = "Close connection: {}->{}@{}. Total connections live: {}".format(
self.user.username, self.asset.ip,
self.system_user.username, len(self.connections)
) )
logger.debug(msg)
@staticmethod @staticmethod
def get_proxy_sock_v2(asset): def get_proxy_sock_v2(asset):
......
...@@ -148,7 +148,7 @@ class ProxyServer: ...@@ -148,7 +148,7 @@ class ProxyServer:
conn = SSHConnection.new_connection_from_cache( conn = SSHConnection.new_connection_from_cache(
self.client.user, self.asset, self.system_user self.client.user, self.asset, self.system_user
) )
if not conn: if not conn or not conn.is_active:
return None return None
else: else:
conn = SSHConnection.new_connection( conn = SSHConnection.new_connection(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment