Unverified Commit 9591d717 authored by 老广's avatar 老广 Committed by GitHub

[Update] 优化coco内存占用问题 (#153)

* [Update] 优化coco内存占用问题

* [Update] Send unicode

* [Update] Send unicode

* [Bugfix] 修复连接信息bug
parent 1eb34b69
......@@ -159,7 +159,7 @@ class Coco:
continue
# Session已正常关闭
if s.closed:
Session.remove_session(s)
Session.remove_session(s.id)
else:
check_session_idle_too_long(s)
except Exception as e:
......@@ -182,11 +182,8 @@ class Coco:
self.run_httpd()
signal.signal(signal.SIGTERM, lambda x, y: self.shutdown())
while True:
if self.stop_evt.is_set():
print("Coco receive term signal, exit")
break
time.sleep(3)
self.lock.acquire()
self.lock.acquire()
except KeyboardInterrupt:
self.shutdown()
......@@ -204,8 +201,8 @@ class Coco:
logger.info("Grace shutdown the server")
for connection in Connection.connections.values():
connection.close()
time.sleep(1)
self.heartbeat()
self.lock.release()
self.stop_evt.set()
self.sshd.shutdown()
self.httpd.shutdown()
......@@ -71,9 +71,7 @@ class SSHConnection:
)
transport = ssh.get_transport()
transport.set_keepalive(300)
except (paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
SSHException) as e:
except Exception as e:
password_short = "None"
key_fingerprint = "None"
if system_user.password:
......@@ -85,13 +83,11 @@ class SSHConnection:
)
logger.error("Connect {}@{}:{} auth failed, password: \
{}, key: {}".format(
{}, key: {}".format(
system_user.username, asset.ip, asset.port,
password_short, key_fingerprint,
))
return None, None, error + '\n' + str(e)
except (socket.error, socket.timeout) as e:
return None, None, error + '\n' + str(e)
return ssh, sock, None
def get_transport(self, asset, system_user):
......@@ -134,48 +130,15 @@ class SSHConnection:
password=gateway.password,
pkey=gateway.private_key_obj,
timeout=config['SSH_TIMEOUT'])
except (paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
SSHException, socket.error):
except:
continue
sock = ssh.get_transport().open_channel(
'direct-tcpip', (asset.ip, asset.port), ('127.0.0.1', 0)
)
break
return sock
def get_proxy_sock(self, asset):
sock = None
domain = app_service.get_domain_detail_with_gateway(
asset.domain
)
if not domain.has_ssh_gateway():
return None
for i in domain.gateways:
gateway = domain.random_ssh_gateway()
proxy_command = [
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(gateway.port),
"{}@{}".format(gateway.username, gateway.ip),
"-W", "{}:{}".format(asset.ip, asset.port), "-q",
]
if gateway.password:
proxy_command.insert(0, "sshpass -p {}".format(gateway.password))
if gateway.private_key:
gateway.set_key_dir(os.path.join(config['ROOT_PATH'], 'keys'))
proxy_command.append("-i {}".format(gateway.private_key_file))
proxy_command = ' '.join(proxy_command)
try:
sock = paramiko.ProxyCommand(proxy_command)
sock = ssh.get_transport().open_channel(
'direct-tcpip', (asset.ip, asset.port), ('127.0.0.1', 0)
)
break
except (paramiko.AuthenticationException,
paramiko.BadAuthenticationType, SSHException,
TimeoutError) as e:
logger.error(e)
continue
except:
return None
return sock
......@@ -235,7 +198,6 @@ class TelnetConnection:
self.asset.hostname
)
logger.info(msg)
self.client.send(b'\r\n' + data)
return self.sock, None
elif result is False:
self.sock.close()
......
......@@ -66,7 +66,7 @@ class SFTPVolume(BaseVolume):
data["dirs"] = 1
if self._is_root(path):
del data['phash']
data.pop('phash', None)
data['name'] = self.root_name
data['locked'] = 1
data['volume_id'] = self.get_volume_id()
......
......@@ -94,12 +94,12 @@ class InteractiveServer:
_("{T}8) Enter {green}r{end} to refresh your assets and nodes.{R}"),
_("{T}0) Enter {green}q{end} exit.{R}")
]
self.client.send(header.format(
self.client.send_unicode(header.format(
title="\033[1;32m", user=self.client.user, end="\033[0m",
T='\t', R='\r\n\r'
))
for item in menu:
self.client.send(item.format(
self.client.send_unicode(item.format(
green="\033[32m", end="\033[0m",
T='\t', R='\r\n\r'
))
......@@ -112,7 +112,7 @@ class InteractiveServer:
for i in f:
if i.decode('utf-8').startswith('#'):
continue
self.client.send(i.decode('utf-8').replace('\n', '\r\n'))
self.client.send_unicode(i.decode('utf-8').replace('\n', '\r\n'))
def dispatch(self, opt):
if opt is None:
......@@ -152,7 +152,7 @@ class InteractiveServer:
asset = assets[0]
if asset.protocol == "rdp" \
or asset.platform.lower().startswith("windows"):
self.client.send(warning(
self.client.send_unicode(warning(
_("Terminal does not support login rdp, "
"please use web terminal to access"))
)
......@@ -201,7 +201,7 @@ class InteractiveServer:
def display_assets_paging(self, assets):
if len(assets) == 0:
self.client.send(wr(_("No Assets"), before=0))
self.client.send_unicode(wr(_("No Assets"), before=0))
return
self.total_count = len(assets)
......@@ -270,15 +270,15 @@ class InteractiveServer:
)
size_list.append(comment_length)
fake_data.append(_("Comment"))
self.client.send(wr(title(format_with_zh(size_list, *fake_data))))
self.client.send_unicode(wr(title(format_with_zh(size_list, *fake_data))))
for index, asset in enumerate(self.results, 1):
data = [
index, asset.hostname, asset.ip,
asset.system_users_name_list, asset.comment
]
self.client.send(wr(format_with_zh(size_list, *data)))
self.client.send_unicode(wr(format_with_zh(size_list, *data)))
self.client.send(wr(title(
self.client.send_unicode(wr(title(
_("Page: {}, Count: {}, Total Page: {}, Total Count: {}").format(
self.page, len(self.results), self.total_pages,
self.total_count)), before=1)
......@@ -286,13 +286,13 @@ class InteractiveServer:
def display_page_bottom_prompt(self):
msg = wr(_('Tips: Enter the asset ID and log directly into the asset.'), before=1)
self.client.send(msg)
self.client.send_unicode(msg)
prompt_page_up = _("Page up: P/p")
prompt_page_down = _("Page down: Enter|N/n")
prompt_back = _("BACK: b/q")
prompts = [prompt_page_up, prompt_page_down, prompt_back]
prompt = '\t'.join(prompts)
self.client.send(wr(prompt, before=1))
self.client.send_unicode(wr(prompt, before=1))
def get_user_action(self):
opt = net_input(self.client, prompt=':')
......@@ -365,14 +365,14 @@ class InteractiveServer:
self.get_user_nodes()
if not self.nodes:
self.client.send(wr(_('No Nodes'), before=0))
self.client.send_unicode(wr(_('No Nodes'), before=0))
return
self.node_tree.show(key=lambda node: node.identifier)
self.client.send(wr(title(_("Node: [ ID.Name(Asset amount) ]")), before=0))
self.client.send(wr(self.node_tree._reader.replace('\n', '\r\n'), before=0))
self.client.send_unicode(wr(title(_("Node: [ ID.Name(Asset amount) ]")), before=0))
self.client.send_unicode(wr(self.node_tree._reader.replace('\n', '\r\n'), before=0))
prompt = _("Tips: Enter g+NodeID to display the host under the node, such as g1")
self.client.send(wr(title(prompt), before=1))
self.client.send_unicode(wr(title(prompt), before=1))
def display_node_assets(self, _id):
if self.nodes is None:
......@@ -380,7 +380,7 @@ class InteractiveServer:
if _id > len(self.nodes) or _id <= 0:
msg = wr(warning(_("There is no matched node, please re-enter")))
self.client.send(msg)
self.client.send_unicode(msg)
self.display_nodes_as_tree()
return
......@@ -409,7 +409,7 @@ class InteractiveServer:
return None
while True:
self.client.send(wr(_("Select a login:: "), after=1))
self.client.send_unicode(wr(_("Select a login:: "), after=1))
self.display_system_users(system_users)
opt = net_input(self.client, prompt="ID> ")
if opt.isdigit() and len(system_users) > int(opt):
......@@ -423,7 +423,7 @@ class InteractiveServer:
def display_system_users(self, system_users):
for index, system_user in enumerate(system_users):
self.client.send(wr("{} {}".format(index, system_user.name)))
self.client.send_unicode(wr("{} {}".format(index, system_user.name)))
#
# Proxy
......@@ -432,7 +432,7 @@ class InteractiveServer:
def proxy(self, asset):
system_user = self.choose_system_user(asset.system_users_granted)
if system_user is None:
self.client.send(_("No system user"))
self.client.send_unicode(_("No system user"))
return
forwarder = ProxyServer(self.client, asset, system_user)
forwarder.proxy()
......
......@@ -58,7 +58,7 @@ class Connection(object):
return
client.close()
self.__class__.clients_num -= 1
del self.clients[tid]
self.clients.pop(tid, None)
logger.info("Client {} leave, total {} now".format(
client, self.__class__.clients_num
))
......@@ -83,7 +83,7 @@ class Connection(object):
if not connection:
return
connection.close()
del cls.connections[cid]
cls.connections.pop(cid, None)
@classmethod
def get_connection(cls, cid):
......@@ -123,13 +123,11 @@ class Client(object):
return self.chan.fileno()
def send(self, b):
if isinstance(b, str):
b = b.encode("utf-8")
try:
return self.chan.send(b)
except OSError:
self.close()
return
return self.chan.send(b)
def send_unicode(self, s):
b = s.encode()
self.send(b)
@property
def closed(self):
......@@ -256,7 +254,7 @@ class BaseServer(object):
break
elif action == rule.DENY:
msg = _("Command `{}` is forbidden ........").format(cmd)
self.command_forbidden(msg)
data = self.command_forbidden(msg)
break
return data
......@@ -356,7 +354,7 @@ class BaseServer(object):
return self.chan.fileno()
def close(self):
logger.info("Closed server {}".format(self))
logger.info("Close server to {}".format(self))
self.r_input_output_data_filter(b'')
self.chan.close()
......@@ -399,8 +397,10 @@ class Server(BaseServer):
def close(self):
super(Server, self).close()
self.chan.transport.close()
logger.debug("Backend server closed")
for i in range(5):
if not self.chan.transport.is_alive():
break
self.chan.transport.close()
if self.sock:
self.sock.transport.close()
......
......@@ -11,7 +11,7 @@ from .connection import SSHConnection, TelnetConnection
from .service import app_service
from .config import config
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
get_logger, net_input, ugettext as _
get_logger, net_input, ugettext as _, ignore_error
logger = get_logger(__file__)
......@@ -48,7 +48,7 @@ class ProxyServer:
msg = 'System user <{}> and asset <{}> protocol are inconsistent.'.format(
self.system_user.name, self.asset.hostname
)
self.client.send(warning(wr(msg, before=1, after=0)))
self.client.send_unicode(warning(wr(msg, before=1, after=0)))
return False
return True
......@@ -68,12 +68,19 @@ class ProxyServer:
self.server = self.get_server_conn()
if self.server is None:
return
if self.client.closed:
self.server.close()
return
session = Session.new_session(self.client, self.server)
try:
session.bridge()
finally:
Session.remove_session(session.id)
self.server.close()
msg = 'Session end, total {} now'.format(
len(Session.sessions),
)
logger.info(msg)
def validate_permission(self):
"""
......@@ -88,7 +95,7 @@ class ProxyServer:
logger.info("Connect to {}:{} ...".format(self.asset.hostname, self.asset.port))
self.send_connecting_message()
if not self.validate_permission():
self.client.send(warning(_('No permission')))
self.client.send_unicode(warning(_('No permission')))
server = None
elif self.system_user.protocol == self.asset.protocol == 'telnet':
server = self.get_telnet_server_conn()
......@@ -97,14 +104,13 @@ class ProxyServer:
else:
server = None
self.connecting = False
self.client.send(b'\r\n')
return server
def get_telnet_server_conn(self):
telnet = TelnetConnection(self.asset, self.system_user, self.client)
sock, msg = telnet.get_socket()
if not sock:
self.client.send(warning(wr(msg, before=1, after=0)))
self.client.send_unicode(warning(wr(msg, before=1, after=0)))
server = None
else:
server = TelnetServer(sock, self.asset, self.system_user)
......@@ -121,24 +127,27 @@ class ProxyServer:
width=width, height=height
)
if not chan:
self.client.send(warning(wr(msg, before=1, after=0)))
self.client.send_unicode(warning(wr(msg, before=1, after=0)))
server = None
else:
server = Server(chan, sock, self.asset, self.system_user)
return server
def send_connecting_message(self):
@ignore_error
def func():
delay = 0.0
self.client.send(_('Connecting to {}@{} {:.1f}').format(
self.system_user, self.asset, delay)
msg = _('Connecting to {}@{} {:.1f}').format(
self.system_user, self.asset, delay
)
self.client.send_unicode(msg)
while self.connecting and delay < config['SSH_TIMEOUT']:
if 0 <= delay < 10:
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode())
self.client.send_unicode('\x08\x08\x08{:.1f}'.format(delay))
else:
self.client.send('\x08\x08\x08\x08{:.1f}'.format(delay).encode())
self.client.send_unicode('\x08\x08\x08\x08{:.1f}'.format(delay))
time.sleep(0.1)
delay += 0.1
self.client.send(b'\r\n')
thread = threading.Thread(target=func)
thread.start()
......@@ -100,7 +100,7 @@ class ReplayRecorder(object):
return False
if app_service.finish_replay(session_id):
logger.info(
logger.debug(
"Success finished session {}'s replay ".format(session_id)
)
return True
......@@ -146,7 +146,9 @@ class CommandRecorder(object):
def push_to_server_async(self):
def func():
while not self.stop_evt.is_set():
while True:
if self.stop_evt.is_set() and self.queue.empty():
break
data_set = self.queue.mget(self.batch_size, timeout=self.timeout)
size = self.queue.qsize()
if size > 0:
......@@ -166,7 +168,7 @@ class CommandRecorder(object):
pass
def session_end(self, session_id):
pass
self.stop_evt.set()
def get_command_recorder():
......
......@@ -62,7 +62,7 @@ class Session:
session.close()
app_service.finish_session(session.to_json())
app_service.finish_replay(sid)
del cls.sessions[sid]
cls.sessions.pop(sid, None)
def add_watcher(self, watcher, silent=False):
"""
......@@ -74,7 +74,7 @@ class Session:
"""
logger.debug("Session add watcher: {} -> {} ".format(self.id, watcher))
if not silent:
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode())
watcher.send_unicode("Welcome to watch session {}\r\n".format(self.id))
self.sel.register(watcher, selectors.EVENT_READ)
self._watchers.append(watcher)
......@@ -146,7 +146,7 @@ class Session:
if not msg:
msg = _("Terminated by administrator")
try:
self.client.send(wr(warn(msg), before=1))
self.client.send_unicode(wr(warn(msg), before=1))
except OSError:
pass
self.stop_evt.set()
......@@ -166,6 +166,8 @@ class Session:
self.sel.register(self.server, selectors.EVENT_READ)
self.sel.register(self.stop_evt, selectors.EVENT_READ)
self.sel.register(self.client.change_size_evt, selectors.EVENT_READ)
if self.client.closed:
return
while not self.is_finished:
events = self.sel.select(timeout=60)
for sock in [key.fileobj for key, _ in events]:
......@@ -202,7 +204,6 @@ class Session:
logger.debug("Resize server chan size {}*{}".format(width, height))
self.server.resize_pty(width=width, height=height)
@ignore_error
def close(self):
if self.closed:
logger.debug("Session has been closed: {} ".format(self.id))
......
......@@ -119,7 +119,7 @@ class SSHServer:
else:
msg = "Request type `{}:{}` not support now".format(kind, chan_type)
logger.error(msg)
client.send(msg)
client.send_unicode(msg)
finally:
connection = Connection.get_connection(client.connection_id)
connection.remove_client(client.id)
......
......@@ -300,7 +300,8 @@ def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
"""
input_data = []
parser = TtyIOParser()
client.send(wrap_with_line_feed(prompt, before=before, after=after))
msg = wrap_with_line_feed(prompt, before=before, after=after)
client.send_unicode(msg)
while True:
data = client.recv(1)
......@@ -319,7 +320,7 @@ def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
if data.startswith(b'\x03'):
# Ctrl-C
client.send('^C\r\n{} '.format(prompt).encode())
client.send_unicode('^C\r\n{} '.format(prompt))
input_data = []
continue
elif data.startswith(b'\x04'):
......@@ -339,7 +340,7 @@ def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
return option.strip()
else:
if sensitive:
client.send(len(data) * '*')
client.send_unicode((len(data) * '*'))
else:
client.send(data)
input_data.append(data)
......@@ -460,7 +461,6 @@ def ignore_error(func):
return resp
except Exception as e:
logger.error("Error occur: {} {}".format(func.__name__, e))
raise e
return wrapper
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment