Commit ae977b66 authored by ibuler's avatar ibuler

[Feature] 修改上传日志方式

parent 49d0c399
...@@ -3,12 +3,17 @@ import os ...@@ -3,12 +3,17 @@ import os
import time import time
import threading import threading
import logging import logging
import multiprocessing
from jms.service import AppService from jms.service import AppService
from .config import Config from .config import Config
from .sshd import SSHServer from .sshd import SSHServer
from .httpd import HttpServer from .httpd import HttpServer
from .logging import create_logger from .logging import create_logger
from .queue import MemoryQueue
from .record import ServerCommandRecorder, ServerReplayRecorder, \
START_SENTINEL, DONE_SENTINEL
__version__ = '0.4.0' __version__ = '0.4.0'
...@@ -41,6 +46,9 @@ class Coco: ...@@ -41,6 +46,9 @@ class Coco:
'HEARTBEAT_INTERVAL': 5, 'HEARTBEAT_INTERVAL': 5,
'MAX_CONNECTIONS': 500, 'MAX_CONNECTIONS': 500,
'ADMINS': '', 'ADMINS': '',
'QUEUE_ENGINE': 'memory',
'QUEUE_MAX_SIZE': 0,
'MAX_PUSH_THREADS': 5,
# 'MAX_RECORD_OUTPUT_LENGTH': 4096, # 'MAX_RECORD_OUTPUT_LENGTH': 4096,
} }
...@@ -55,6 +63,10 @@ class Coco: ...@@ -55,6 +63,10 @@ class Coco:
self._service = None self._service = None
self._sshd = None self._sshd = None
self._httpd = None self._httpd = None
self._command_queue = None
self._replay_queue = None
self._command_recorder = None
self._replay_recorder = None
@property @property
def service(self): def service(self):
...@@ -81,10 +93,62 @@ class Coco: ...@@ -81,10 +93,62 @@ class Coco:
def load_extra_conf_from_server(self): def load_extra_conf_from_server(self):
pass pass
def initial_queue(self):
logger.debug("Initial app queue")
queue_size = int(self.config['QUEUE_MAX_SIZE'])
# Todo: For other queue
if self.config['QUEUE_ENGINE'] == 'memory':
self._command_queue = MemoryQueue(queue_size)
self._replay_queue = MemoryQueue(queue_size)
else:
self._command_queue = MemoryQueue(queue_size)
self._replay_queue = MemoryQueue(queue_size)
def initial_recorder(self):
if self.config['REPLAY_STORE_ENGINE'] == 'server':
self._replay_recorder = ServerReplayRecorder(self)
else:
self._replay_recorder = ServerReplayRecorder(self)
if self.config['COMMAND_STORE_ENGINE'] == 'server':
self._command_recorder = ServerCommandRecorder(self)
else:
self._command_recorder = ServerCommandRecorder(self)
def keep_push_record(self):
threads = []
def push_command(q, callback, size=10):
while not self.stop_evt.is_set():
data_set = q.mget(size)
callback(data_set)
def push_replay(q, callback, size=10):
while not self.stop_evt.is_set():
data_set = q.mget(size)
callback(data_set)
for i in range(self.config['MAX_PUSH_THREADS']):
t = threading.Thread(target=push_command, args=(
self._command_queue, self._command_recorder.record_command,
))
threads.append(t)
t = threading.Thread(target=push_replay, args=(
self._replay_queue, self._replay_recorder.record_replay,
))
threads.append(t)
for t in threads:
t.start()
logger.info("Start push record process: {}".format(t))
def bootstrap(self): def bootstrap(self):
self.make_logger() self.make_logger()
self.initial_queue()
self.initial_recorder()
self.service.initial() self.service.initial()
self.load_extra_conf_from_server() self.load_extra_conf_from_server()
self.keep_push_record()
self.keep_heartbeat() self.keep_heartbeat()
self.monitor_sessions() self.monitor_sessions()
...@@ -93,6 +157,8 @@ class Coco: ...@@ -93,6 +157,8 @@ class Coco:
tasks = self.service.terminal_heartbeat(_sessions) tasks = self.service.terminal_heartbeat(_sessions)
if tasks: if tasks:
self.handle_task(tasks) self.handle_task(tasks)
logger.info("Command queue size: {}".format(self._command_queue.qsize()))
logger.info("Replay queue size: {}".format(self._replay_queue.qsize()))
def keep_heartbeat(self): def keep_heartbeat(self):
def func(): def func():
...@@ -157,6 +223,7 @@ class Coco: ...@@ -157,6 +223,7 @@ class Coco:
for client in self.clients: for client in self.clients:
self.remove_client(client) self.remove_client(client)
time.sleep(1) time.sleep(1)
self.stop_evt.set()
self.sshd.shutdown() self.sshd.shutdown()
self.httpd.shutdown() self.httpd.shutdown()
logger.info("Grace shutdown the server") logger.info("Grace shutdown the server")
...@@ -179,15 +246,57 @@ class Coco: ...@@ -179,15 +246,57 @@ class Coco:
def add_session(self, session): def add_session(self, session):
with self.lock: with self.lock:
self.sessions.append(session) self.sessions.append(session)
self.put_command_start_queue(session)
self.put_replay_done_queue(session)
self.heartbeat() self.heartbeat()
def remove_session(self, session): def remove_session(self, session):
with self.lock: with self.lock:
logger.info("Remove session: {}".format(session)) logger.info("Remove session: {}".format(session))
self.sessions.remove(session) self.sessions.remove(session)
del session.server self.put_command_done_queue(session)
del session.client self.put_replay_done_queue(session)
del session
self.heartbeat() self.heartbeat()
def put_replay_queue(self, session, data):
logger.debug("Put replay data: {} {}".format(session, data))
self._replay_queue.put((
session.id, data, time.time()
))
def put_replay_start_queue(self, session):
self._replay_queue.put((
session.id, START_SENTINEL, time.time()
))
def put_replay_done_queue(self, session):
self._replay_queue.put((
session.id, DONE_SENTINEL, time.time()
))
def put_command_queue(self, session, _input, _output):
logger.debug("Put command data: {} {} {}".format(session, _input, _output))
self._command_queue.put((
session.id, _input[:128], _output[:1024], session.client.user.username,
session.server.asset.hostname, session.server.system_user.username,
time.time()
))
def put_command_start_queue(self, session):
self._command_queue.put((
session.id, START_SENTINEL, START_SENTINEL,
session.client.user.username,
session.server.asset.hostname,
session.server.system_user.username,
time.time()
))
def put_command_done_queue(self, session):
self._command_queue.put((
session.id, DONE_SENTINEL, DONE_SENTINEL,
session.client.user.username,
session.server.asset.hostname,
session.server.system_user.username,
time.time()
))
...@@ -5,4 +5,5 @@ class PermissionFailed(Exception): ...@@ -5,4 +5,5 @@ class PermissionFailed(Exception):
pass pass
class NoAppException(Exception):
pass
...@@ -3,13 +3,12 @@ import socket ...@@ -3,13 +3,12 @@ import socket
import threading import threading
import logging import logging
import time import time
import weakref
import paramiko import paramiko
from .session import Session from .session import Session
from .models import Server from .models import Server
from .record import LocalFileReplayRecorder, LocalFileCommandRecorder, \
ServerReplayRecorder, ServerCommandRecorder
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning
...@@ -20,44 +19,27 @@ BUF_SIZE = 4096 ...@@ -20,44 +19,27 @@ BUF_SIZE = 4096
class ProxyServer: class ProxyServer:
def __init__(self, app, client): def __init__(self, app, client):
self.app = app self._app = weakref.ref(app)
self.client = client self.client = client
self.request = client.request self.request = client.request
self.server = None self.server = None
self.connecting = True self.connecting = True
self.session = None self.session = None
@property
def app(self):
return self._app()
def proxy(self, asset, system_user): def proxy(self, asset, system_user):
self.send_connecting_message(asset, system_user) self.send_connecting_message(asset, system_user)
self.server = self.get_server_conn(asset, system_user) self.server = self.get_server_conn(asset, system_user)
if self.server is None: if self.server is None:
return return
self.session = Session(self.client, self.server) self.session = Session(self.app, self.client, self.server)
self.app.add_session(self.session) self.app.add_session(self.session)
self.watch_win_size_change_async() self.watch_win_size_change_async()
self.add_recorder()
self.session.bridge() self.session.bridge()
def add_recorder(self):
"""
上传记录,如果配置的是server,就上传到服务器端,实例化对应的recorder,
将来有计划直接上传到 es和oss
:return:
"""
if self.app.config["REPLAY_STORE_ENGINE"].lower() == "server":
replay_recorder = ServerReplayRecorder(self.app, self.session)
else:
replay_recorder = LocalFileReplayRecorder(self.app, self.session)
if self.app.config["COMMAND_STORE_ENGINE"].lower() == "server":
command_recorder = ServerCommandRecorder(self.app, self.session)
else:
command_recorder = LocalFileCommandRecorder(self.app, self.session)
self.session.add_recorder(replay_recorder)
self.session.record_replay_async()
self.server.add_recorder(command_recorder)
self.server.record_command_async()
def validate_permission(self, asset, system_user): def validate_permission(self, asset, system_user):
""" """
验证用户是否有连接改资产的权限 验证用户是否有连接改资产的权限
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import logging import logging
import socket import socket
import threading import threading
import weakref
import os import os
from jms.models import Asset, AssetGroup from jms.models import Asset, AssetGroup
...@@ -20,7 +21,7 @@ class InteractiveServer: ...@@ -20,7 +21,7 @@ class InteractiveServer:
_sentinel = object() _sentinel = object()
def __init__(self, app, client): def __init__(self, app, client):
self.app = app self._app = weakref.ref(app)
self.client = client self.client = client
self.request = client.request self.request = client.request
self.assets = None self.assets = None
...@@ -29,6 +30,10 @@ class InteractiveServer: ...@@ -29,6 +30,10 @@ class InteractiveServer:
self.get_user_assets_async() self.get_user_assets_async()
self.get_user_asset_groups_async() self.get_user_asset_groups_async()
@property
def app(self):
return self._app()
def display_banner(self): def display_banner(self):
self.client.send(char.CLEAR_CHAR) self.client.send(char.CLEAR_CHAR)
logo_path = os.path.join(self.app.root_path, "logo.txt") logo_path = os.path.join(self.app.root_path, "logo.txt")
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import logging import logging
import paramiko import paramiko
import threading import threading
import weakref
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
...@@ -17,11 +18,15 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -17,11 +18,15 @@ class SSHInterface(paramiko.ServerInterface):
""" """
def __init__(self, app, request): def __init__(self, app, request):
self.app = app self._app = weakref.ref(app)
self.request = request self.request = request
self.event = threading.Event() self.event = threading.Event()
self.auth_valid = False self.auth_valid = False
@property
def app(self):
return self._app()
def check_auth_interactive(self, username, submethods): def check_auth_interactive(self, username, submethods):
logger.info("Check auth interactive: %s %s" % (username, submethods)) logger.info("Check auth interactive: %s %s" % (username, submethods))
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
......
import json import json
import queue
import threading import threading
import datetime import datetime
import logging import logging
import weakref
from . import char from . import char
from . import utils from . import utils
from .record import START_SENTINEL, DONE_SENTINEL
BUF_SIZE = 4096 BUF_SIZE = 4096
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
...@@ -83,44 +84,24 @@ class Server: ...@@ -83,44 +84,24 @@ class Server:
self._input_initial = False self._input_initial = False
self._in_vim_state = False self._in_vim_state = False
self.recorders = []
self.filters = [] self.filters = []
self._input = "" self._input = ""
self._output = "" self._output = ""
self.command_queue = queue.Queue() self._session_ref = None
def add_recorder(self, recorder): @property
self.recorders.append(recorder) def session(self):
return self._session_ref() if self._session_ref is not None else None
def remove_recorder(self, recorder):
self.recorders.remove(recorder)
def add_filter(self, _filter): def add_filter(self, _filter):
self.filters.append(_filter) self.filters.append(_filter)
def remove_filter(self, _filter):
self.filters.remove(_filter)
def record_command_async(self):
def func():
logger.info("Start server command record thread: {}".format(self))
for recorder in self.recorders:
recorder.start()
while not self.stop_evt.is_set():
_input, _output = self.command_queue.get()
if _input is None:
break
for recorder in self.recorders:
recorder.record_command(datetime.datetime.now(), _input, _output)
logger.info("Exit server command record thread: {}".format(self))
for recorder in self.recorders:
recorder.done()
thread = threading.Thread(target=func)
thread.start()
def fileno(self): def fileno(self):
return self.chan.fileno() return self.chan.fileno()
def set_session(self, session):
self._session_ref = weakref.ref(session)
def send(self, b): def send(self, b):
if isinstance(b, str): if isinstance(b, str):
b = b.encode("utf-8") b = b.encode("utf-8")
...@@ -138,7 +119,7 @@ class Server: ...@@ -138,7 +119,7 @@ class Server:
self._input, self._output, self._input, self._output,
"#" * 30 + " End " + "#" * 30, "#" * 30 + " End " + "#" * 30,
)) ))
self.command_queue.put((self._input, self._output)) self.session.put_command(self._input, self._output)
del self.input_data[:] del self.input_data[:]
del self.output_data[:] del self.output_data[:]
# self._input = "" # self._input = ""
...@@ -148,6 +129,7 @@ class Server: ...@@ -148,6 +129,7 @@ class Server:
def recv(self, size): def recv(self, size):
data = self.chan.recv(size) data = self.chan.recv(size)
self.session.put_replay(data)
if self._input_initial: if self._input_initial:
if self._in_input_state: if self._in_input_state:
self.input_data.append(data) self.input_data.append(data)
...@@ -160,7 +142,6 @@ class Server: ...@@ -160,7 +142,6 @@ class Server:
self.chan.close() self.chan.close()
self.stop_evt.set() self.stop_evt.set()
self.chan.transport.close() self.chan.transport.close()
self.command_queue.put((None, None))
@staticmethod @staticmethod
def _have_enter_char(s): def _have_enter_char(s):
......
# -*- coding: utf-8 -*-
#
import queue
class MultiQueueMixin:
def mget(self, size=1, block=True, timeout=5):
items = []
for i in range(size):
try:
items.append(self.get(block=block, timeout=timeout))
except queue.Empty:
break
return items
class MemoryQueue(MultiQueueMixin, queue.Queue):
pass
...@@ -2,229 +2,99 @@ ...@@ -2,229 +2,99 @@
# #
import abc import abc
import tarfile
import threading
import time
import os
import logging import logging
import base64
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
BUF_SIZE = 1024 BUF_SIZE = 1024
START_SENTINEL = object()
DONE_SENTINEL = object()
class ReplayRecorder(metaclass=abc.ABCMeta): class ReplayRecorder(metaclass=abc.ABCMeta):
def __init__(self, app, session): def __init__(self, app):
self.app = app self.app = app
self.session = session
@abc.abstractmethod @abc.abstractmethod
def record_replay(self, now, timedelta, size, data): def record_replay(self, data_set):
pass """
记录replay数据
:param data_set: 数据集 [("session", "data", "timestamp"),]
:return:
"""
for data in data_set:
if data[1] is START_SENTINEL:
data_set.remove(data)
self.session_start(data[0])
if data[1] is DONE_SENTINEL:
data_set.remove(data)
self.session_done(data[0])
@abc.abstractmethod @abc.abstractmethod
def start(self): def session_done(self, session_id):
pass pass
@abc.abstractmethod @abc.abstractmethod
def done(self): def session_start(self, session_id):
pass pass
class CommandRecorder(metaclass=abc.ABCMeta): class CommandRecorder(metaclass=abc.ABCMeta):
def __init__(self, app, session): def __init__(self, app):
self.app = app self.app = app
self.session = session
@abc.abstractmethod @abc.abstractmethod
def record_command(self, now, _input, _output): def record_command(self, data_set):
pass """
:param data_set: 数据集
[("session", "input", "output", "user",
"asset", "system_user", "timestamp"),]
:return:
"""
for data in data_set:
if data[1] is START_SENTINEL:
data_set.remove(data)
self.session_start(data[0])
if data[1] is DONE_SENTINEL:
data_set.remove(data)
self.session_done(data[0])
@abc.abstractmethod @abc.abstractmethod
def start(self): def session_start(self, session_id):
pass pass
@abc.abstractmethod @abc.abstractmethod
def done(self): def session_done(self, session_id):
pass pass
class LocalFileReplayRecorder(ReplayRecorder): class ServerReplayRecorder(ReplayRecorder):
def __init__(self, app, session):
super().__init__(app, session)
self.session_dir = ""
self.data_filename = ""
self.time_filename = ""
self.data_f = None
self.time_f = None
self.prepare_file()
def prepare_file(self):
self.session_dir = os.path.join(
self.app.config["SESSION_DIR"],
self.session.date_created.strftime("%Y-%m-%d"),
str(self.session.id)
)
if not os.path.isdir(self.session_dir):
os.makedirs(self.session_dir)
self.data_filename = os.path.join(self.session_dir, "data.txt")
self.time_filename = os.path.join(self.session_dir, "time.txt")
try:
self.data_f = open(self.data_filename, "wb")
self.time_f = open(self.time_filename, "w")
except IOError as e:
logger.debug(e)
self.done()
def record_replay(self, now, timedelta, size, data):
logger.debug("File recorder replay: ({},{},{})".format(timedelta, size, data))
self.time_f.write("{} {}\n".format(timedelta, size))
self.data_f.write(data)
def start(self):
logger.info("Session record start: {}".format(self.session))
self.data_f.write("Session {} started on {}\n".format(self.session, time.asctime()).encode("utf-8"))
def done(self):
logger.debug("Session record done: {}".format(self.session))
self.data_f.write("Session {} done on {}\n".format(self.session, time.asctime()).encode("utf-8"))
for f in (self.data_f, self.time_f):
try:
f.close()
except IOError:
pass
class LocalFileCommandRecorder(CommandRecorder):
def __init__(self, app, session):
super().__init__(app, session)
self.cmd_filename = ""
self.cmd_f = None
self.session_dir = ""
self.prepare_file()
def prepare_file(self):
self.session_dir = os.path.join(
self.app.config["SESSION_DIR"],
self.session.date_created.strftime("%Y-%m-%d"),
str(self.session.id)
)
if not os.path.isdir(self.session_dir):
os.makedirs(self.session_dir)
self.cmd_filename = os.path.join(self.session_dir, "command.txt")
try:
self.cmd_f = open(self.cmd_filename, "wb")
except IOError as e:
logger.debug(e)
self.done()
def record_command(self, now, _input, _output):
logger.debug("File recorder command: ({},{})".format(_input, _output))
self.cmd_f.write("{}\n".format(now.strftime("%Y-%m-%d %H:%M:%S")))
self.cmd_f.write("$ {}\n".format(_input))
self.cmd_f.write("{}\n\n".format(_output))
self.cmd_f.flush()
def start(self):
pass
def done(self): def record_replay(self, data_set):
try: super().record_replay(data_set)
self.cmd_f.close() print(data_set)
except:
pass def session_start(self, session_id):
print("Session {} start".format(session_id))
class ServerReplayRecorder(LocalFileReplayRecorder): def session_done(self, session_id):
print("Session {} done".format(session_id))
def done(self):
super().done()
self.push_record() class ServerCommandRecorder(CommandRecorder):
def archive_record(self): def record_command(self, data_set):
filename = os.path.join(self.session_dir, "replay.tar.bz2") super().record_command(data_set)
logger.debug("Start archive log: {}".format(filename)) print(data_set)
tar = tarfile.open(filename, "w:bz2")
os.chdir(self.session_dir) def session_start(self, session_id):
time_filename = os.path.basename(self.time_filename) print("Session {} start".format(session_id))
data_filename = os.path.basename(self.data_filename)
for i in (time_filename, data_filename): def session_done(self, session_id):
tar.add(i) print("Session {} done".format(session_id))
tar.close()
return filename
def push_archive_record(self, archive):
logger.debug("Start push replay record to server")
return self.app.service.push_session_replay(archive, str(self.session.id))
def push_record(self):
logger.info("Start push replay record to server")
def func():
archive = self.archive_record()
for i in range(1, 5):
result = self.push_archive_record(archive)
if not result:
logger.error("Push replay error, try again")
time.sleep(5)
continue
else:
break
thread = threading.Thread(target=func)
thread.start()
class ServerCommandRecorder(LocalFileCommandRecorder):
def record_command(self, now, _input, _output):
logger.debug("File recorder command: ({},{})".format(_input, _output))
self.cmd_f.write("{}|{}|{}\n".format(
int(now.timestamp()),
base64.b64encode(_input.encode("utf-8")).decode('utf-8'),
base64.b64encode(_output.encode("utf-8")).decode('utf-8'),
).encode('utf-8'))
def start(self):
pass
def done(self):
super().done()
self.push_record()
def archive_record(self):
filename = os.path.join(self.session_dir, "command.tar.bz2")
logger.debug("Start archive command record: {}".format(filename))
tar = tarfile.open(filename, "w:bz2")
os.chdir(self.session_dir)
cmd_filename = os.path.basename(self.cmd_filename)
tar.add(cmd_filename)
tar.close()
return filename
def push_archive_record(self, archive):
logger.debug("Start push command record to server")
return self.app.service.push_session_command(archive, str(self.session.id))
def push_record(self):
logger.info("Start push command record to server")
def func():
archive = self.archive_record()
for i in range(1, 5):
result = self.push_archive_record(archive)
if not result:
logger.error("Push command record error, try again")
time.sleep(5)
continue
else:
break
thread = threading.Thread(target=func)
thread.start()
...@@ -8,6 +8,7 @@ import logging ...@@ -8,6 +8,7 @@ import logging
import datetime import datetime
import time import time
import selectors import selectors
import weakref
BUF_SIZE = 1024 BUF_SIZE = 1024
...@@ -16,18 +17,23 @@ logger = logging.getLogger(__file__) ...@@ -16,18 +17,23 @@ logger = logging.getLogger(__file__)
class Session: class Session:
def __init__(self, client, server): def __init__(self, app, client, server):
self.id = str(uuid.uuid4()) self.id = str(uuid.uuid4())
self._app = weakref.ref(app)
self.client = client # Master of the session, it's a client sock self.client = client # Master of the session, it's a client sock
self.server = server # Server channel self.server = server # Server channel
self.watchers = [] # Only watch session self._watchers = [] # Only watch session
self.sharers = [] # Join to the session, read and write self._sharers = [] # Join to the session, read and write
self.replaying = True self.replaying = True
self.date_created = datetime.datetime.now() self.date_created = datetime.datetime.now()
self.date_finished = None self.date_finished = None
self.recorders = []
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.sel = selectors.DefaultSelector() self.sel = selectors.DefaultSelector()
self.server.set_session(self)
@property
def app(self):
return self._app()
def add_watcher(self, watcher, silent=False): def add_watcher(self, watcher, silent=False):
""" """
...@@ -41,12 +47,12 @@ class Session: ...@@ -41,12 +47,12 @@ class Session:
if not silent: if not silent:
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode("utf-8")) watcher.send("Welcome to watch session {}\r\n".format(self.id).encode("utf-8"))
self.sel.register(watcher, selectors.EVENT_READ) self.sel.register(watcher, selectors.EVENT_READ)
self.watchers.append(watcher) self._watchers.append(watcher)
def remove_watcher(self, watcher): def remove_watcher(self, watcher):
logger.info("Session %s remove watcher %s" % (self.id, watcher)) logger.info("Session %s remove watcher %s" % (self.id, watcher))
self.sel.unregister(watcher) self.sel.unregister(watcher)
self.watchers.remove(watcher) self._watchers.remove(watcher)
def add_sharer(self, sharer, silent=False): def add_sharer(self, sharer, silent=False):
""" """
...@@ -60,7 +66,7 @@ class Session: ...@@ -60,7 +66,7 @@ class Session:
sharer.send("Welcome to join session: {}\r\n" sharer.send("Welcome to join session: {}\r\n"
.format(self.id).encode("utf-8")) .format(self.id).encode("utf-8"))
self.sel.register(sharer, selectors.EVENT_READ) self.sel.register(sharer, selectors.EVENT_READ)
self.sharers.append(sharer) self._sharers.append(sharer)
def remove_sharer(self, sharer): def remove_sharer(self, sharer):
logger.info("Session %s remove sharer %s" % (self.id, sharer)) logger.info("Session %s remove sharer %s" % (self.id, sharer))
...@@ -68,13 +74,7 @@ class Session: ...@@ -68,13 +74,7 @@ class Session:
.format(self.id, datetime.datetime.now()) .format(self.id, datetime.datetime.now())
.encode("utf-8")) .encode("utf-8"))
self.sel.unregister(sharer) self.sel.unregister(sharer)
self.sharers.remove(sharer) self._sharers.remove(sharer)
def add_recorder(self, recorder):
self.recorders.append(recorder)
def remove_recorder(self, recorder):
self.recorders.remove(recorder)
def bridge(self): def bridge(self):
""" """
...@@ -90,31 +90,31 @@ class Session: ...@@ -90,31 +90,31 @@ class Session:
data = sock.recv(BUF_SIZE) data = sock.recv(BUF_SIZE)
if sock == self.server: if sock == self.server:
if len(data) == 0: if len(data) == 0:
msg = "Server close the connection: {}".format(self.server) msg = "Server close the connection"
logger.info(msg) logger.info(msg)
for watcher in [self.client] + self.watchers + self.sharers: for watcher in [self.client] + self._watchers + self._sharers:
watcher.send(msg.encode('utf-8')) watcher.send(msg.encode('utf-8'))
self.close() self.close()
break break
for watcher in [self.client] + self.watchers + self.sharers: for watcher in [self.client] + self._watchers + self._sharers:
watcher.send(data) watcher.send(data)
elif sock == self.client: elif sock == self.client:
if len(data) == 0: if len(data) == 0:
msg = "Client close the connection: {}".format(self.client) msg = "Client close the connection: {}".format(self.client)
logger.info(msg) logger.info(msg)
for watcher in self.watchers + self.sharers: for watcher in self._watchers + self._sharers:
watcher.send(msg.encode("utf-8")) watcher.send(msg.encode("utf-8"))
self.close() self.close()
break break
self.server.send(data) self.server.send(data)
elif sock in self.sharers: elif sock in self._sharers:
if len(data) == 0: if len(data) == 0:
logger.info("Sharer {} leave the session {}".format(sock, self.id)) logger.info("Sharer {} leave the session {}".format(sock, self.id))
self.remove_sharer(sock) self.remove_sharer(sock)
self.server.send(data) self.server.send(data)
elif sock in self.watchers: elif sock in self._watchers:
if len(data) == 0: if len(data) == 0:
self.watchers.remove(sock) self._watchers.remove(sock)
logger.info("Watcher {} leave the session {}".format(sock, self.id)) logger.info("Watcher {} leave the session {}".format(sock, self.id))
logger.info("Session stop event set: {}".format(self.id)) logger.info("Session stop event set: {}".format(self.id))
...@@ -122,36 +122,18 @@ class Session: ...@@ -122,36 +122,18 @@ class Session:
logger.debug("Resize server chan size {}*{}".format(width, height)) logger.debug("Resize server chan size {}*{}".format(width, height))
self.server.resize_pty(width=width, height=height) self.server.resize_pty(width=width, height=height)
def record_replay_async(self): def put_command(self, _input, _output):
def func(): self.app.put_command_queue(self, _input, _output)
parent, child = socket.socketpair()
self.add_watcher(parent) def put_replay(self, data):
logger.info("Start record replay thread: {}".format(self.id)) self.app.put_replay_queue(self, data)
for recorder in self.recorders:
recorder.start()
while not self.stop_evt.is_set():
start_t = time.time()
data = child.recv(BUF_SIZE)
end_t = time.time()
size = len(data)
now = datetime.datetime.now()
timedelta = '{:.4f}'.format(end_t - start_t)
if size == 0:
break
for recorder in self.recorders:
recorder.record_replay(now, timedelta, size, data)
logger.info("Exit record replay thread: {}".format(self.id))
for recorder in self.recorders:
recorder.done()
thread = threading.Thread(target=func)
thread.start()
def close(self): def close(self):
logger.info("Close the session: {} ".format(self.id)) logger.info("Close the session: {} ".format(self.id))
self.stop_evt.set() self.stop_evt.set()
self.date_finished = datetime.datetime.now() self.date_finished = datetime.datetime.now()
self.server.close() self.server.close()
for c in self.watchers + self.sharers: for c in self._watchers + self._sharers:
c.close() c.close()
def to_json(self): def to_json(self):
......
...@@ -19,6 +19,8 @@ import pytz ...@@ -19,6 +19,8 @@ import pytz
from email.utils import formatdate from email.utils import formatdate
from queue import Queue, Empty from queue import Queue, Empty
from .exception import NoAppException
BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
...@@ -361,4 +363,5 @@ def compile_message(): ...@@ -361,4 +363,5 @@ def compile_message():
pass pass
ugettext = _gettext() ugettext = _gettext()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment