Unverified Commit 0c215e5d authored by 老广's avatar 老广 Committed by GitHub

Merge pull request #34 from jumpserver/dev

Merge to master
parents 08b15983 1c8dfae8
...@@ -8,5 +8,3 @@ logs/* ...@@ -8,5 +8,3 @@ logs/*
conf.py conf.py
host_rsa_key host_rsa_key
sessions/* sessions/*
Dockerfile
conf_docker.py
...@@ -10,7 +10,7 @@ pre version. ...@@ -10,7 +10,7 @@ pre version.
## Install ## Install
$ git clone http://xxxx $ git clone https://github.com/jumpserver/coco.git
## Setting ## Setting
...@@ -31,7 +31,7 @@ Also some config you need kown: ...@@ -31,7 +31,7 @@ Also some config you need kown:
## Start ## Start
# python ssh_server.py # python run_server.py
When your start ssh server, It will register with jumpserver api, When your start ssh server, It will register with jumpserver api,
......
...@@ -6,8 +6,8 @@ import datetime ...@@ -6,8 +6,8 @@ import datetime
import os import os
import time import time
import threading import threading
import logging
import socket import socket
import json
from jms.service import AppService from jms.service import AppService
...@@ -17,18 +17,20 @@ from .httpd import HttpServer ...@@ -17,18 +17,20 @@ from .httpd import HttpServer
from .logger import create_logger from .logger import create_logger
from .tasks import TaskHandler from .tasks import TaskHandler
from .recorder import get_command_recorder_class, get_replay_recorder_class from .recorder import get_command_recorder_class, get_replay_recorder_class
from .utils import get_logger
__version__ = '0.4.0' __version__ = '0.5.0'
BASE_DIR = os.path.dirname(os.path.dirname(__file__)) BASE_DIR = os.path.dirname(os.path.dirname(__file__))
logger = logging.getLogger(__file__) logger = get_logger(__file__)
class Coco: class Coco:
config_class = Config config_class = Config
default_config = { default_config = {
'NAME': socket.gethostname(), 'DEFAULT_NAME': socket.gethostname(),
'NAME': None,
'CORE_HOST': 'http://127.0.0.1:8080', 'CORE_HOST': 'http://127.0.0.1:8080',
'DEBUG': True, 'DEBUG': True,
'BIND_HOST': '0.0.0.0', 'BIND_HOST': '0.0.0.0',
...@@ -42,19 +44,18 @@ class Coco: ...@@ -42,19 +44,18 @@ class Coco:
'LOG_DIR': os.path.join(BASE_DIR, 'logs'), 'LOG_DIR': os.path.join(BASE_DIR, 'logs'),
'SESSION_DIR': os.path.join(BASE_DIR, 'sessions'), 'SESSION_DIR': os.path.join(BASE_DIR, 'sessions'),
'ASSET_LIST_SORT_BY': 'hostname', # hostname, ip 'ASSET_LIST_SORT_BY': 'hostname', # hostname, ip
'SSH_PASSWORD_AUTH': True, 'PASSWORD_AUTH': True,
'SSH_PUBLIC_KEY_AUTH': True, 'PUBLIC_KEY_AUTH': True,
'HEARTBEAT_INTERVAL': 5, 'HEARTBEAT_INTERVAL': 5,
'MAX_CONNECTIONS': 500, 'MAX_CONNECTIONS': 500,
'ADMINS': '', 'ADMINS': '',
'REPLAY_RECORD_ENGINE': 'server', # local, server 'COMMAND_STORAGE': {'TYPE': 'server'}, # server
'COMMAND_RECORD_ENGINE': 'server', # local, server, elasticsearch(not yet) 'REPLAY_STORAGE': {'TYPE': 'server'},
} }
def __init__(self, name=None, root_path=None): def __init__(self, root_path=None):
self.root_path = root_path if root_path else BASE_DIR self.root_path = root_path if root_path else BASE_DIR
self.config = self.config_class(self.root_path, defaults=self.default_config) self.config = self.config_class(self.root_path, defaults=self.default_config)
self.name = name if name else self.config["NAME"]
self.sessions = [] self.sessions = []
self.clients = [] self.clients = []
self.lock = threading.Lock() self.lock = threading.Lock()
...@@ -66,6 +67,13 @@ class Coco: ...@@ -66,6 +67,13 @@ class Coco:
self.command_recorder_class = None self.command_recorder_class = None
self._task_handler = None self._task_handler = None
@property
def name(self):
if self.config['NAME']:
return self.config['NAME']
else:
return self.config['DEFAULT_NAME']
@property @property
def service(self): def service(self):
if self._service is None: if self._service is None:
...@@ -93,16 +101,20 @@ class Coco: ...@@ -93,16 +101,20 @@ class Coco:
def make_logger(self): def make_logger(self):
create_logger(self) create_logger(self)
# Todo: load some config from server like replay and common upload
def load_extra_conf_from_server(self): def load_extra_conf_from_server(self):
pass configs = self.service.load_config_from_server()
logger.debug("Loading config from server: {}".format(
json.dumps(configs)
))
self.config.update(configs)
def initial_recorder(self): def get_recorder_class(self):
self.replay_recorder_class = get_replay_recorder_class(self) self.replay_recorder_class = get_replay_recorder_class(self.config)
self.command_recorder_class = get_command_recorder_class(self) self.command_recorder_class = get_command_recorder_class(self.config)
def new_command_recorder(self): def new_command_recorder(self):
return self.command_recorder_class(self) recorder = self.command_recorder_class(self)
return recorder
def new_replay_recorder(self): def new_replay_recorder(self):
return self.replay_recorder_class(self) return self.replay_recorder_class(self)
...@@ -111,7 +123,7 @@ class Coco: ...@@ -111,7 +123,7 @@ class Coco:
self.make_logger() self.make_logger()
self.service.initial() self.service.initial()
self.load_extra_conf_from_server() self.load_extra_conf_from_server()
self.initial_recorder() self.get_recorder_class()
self.keep_heartbeat() self.keep_heartbeat()
self.monitor_sessions() self.monitor_sessions()
...@@ -193,6 +205,7 @@ class Coco: ...@@ -193,6 +205,7 @@ class Coco:
for client in self.clients: for client in self.clients:
self.remove_client(client) self.remove_client(client)
time.sleep(1) time.sleep(1)
self.heartbeat()
self.stop_evt.set() self.stop_evt.set()
self.sshd.shutdown() self.sshd.shutdown()
self.httpd.shutdown() self.httpd.shutdown()
...@@ -216,10 +229,10 @@ class Coco: ...@@ -216,10 +229,10 @@ class Coco:
def add_session(self, session): def add_session(self, session):
with self.lock: with self.lock:
self.sessions.append(session) self.sessions.append(session)
self.heartbeat_async() self.service.create_session(session.to_json())
def remove_session(self, session): def remove_session(self, session):
with self.lock: with self.lock:
logger.info("Remove session: {}".format(session)) logger.info("Remove session: {}".format(session))
self.sessions.remove(session) self.sessions.remove(session)
self.heartbeat_async() self.service.finish_session(session.to_json())
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
BACKSPACE_CHAR = {b'\x08': b'\x08\x1b[K', b'\x7f': b'\x08\x1b[K'} BACKSPACE_CHAR = {b'\x08': b'\x08\x1b[K', b'\x7f': b'\x08\x1b[K'}
ENTER_CHAR = [b'\r', b'\n', b'\r\n'] ENTER_CHAR = [b'\r', b'\n', b'\r\n']
ENTER_CHAR_ORDER = [ord(b'\r'), ord(b'\n')]
UNSUPPORTED_CHAR = {b'\x15': 'Ctrl-U', b'\x0c': 'Ctrl-L', b'\x05': 'Ctrl-E'} UNSUPPORTED_CHAR = {b'\x15': 'Ctrl-U', b'\x0c': 'Ctrl-L', b'\x05': 'Ctrl-E'}
CLEAR_CHAR = b'\x1b[H\x1b[2J' CLEAR_CHAR = b'\x1b[H\x1b[2J'
BELL_CHAR = b'\x07' BELL_CHAR = b'\x07'
NEW_LINE = b'\r\n' NEW_LINE = b'\r\n'
RZ_PROTOCOL_CHAR = b'**\x18B0900000000a87c\r\x8a\x11'
This diff is collapsed.
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import logging
import socket import socket
import threading import threading
import weakref import weakref
import os import os
from jms.models import Asset, AssetGroup from jms.models import Asset, AssetGroup
from . import char from . import char
from .utils import wrap_with_line_feed as wr, wrap_with_title as title, \ from .utils import wrap_with_line_feed as wr, wrap_with_title as title, \
wrap_with_primary as primary, wrap_with_warning as warning, \ wrap_with_primary as primary, wrap_with_warning as warning, \
is_obj_attr_has, is_obj_attr_eq, sort_assets, TtyIOParser, \ is_obj_attr_has, is_obj_attr_eq, sort_assets, TtyIOParser, \
ugettext as _ ugettext as _, get_logger
from .proxy import ProxyServer from .proxy import ProxyServer
logger = logging.getLogger(__file__) logger = get_logger(__file__)
class InteractiveServer: class InteractiveServer:
...@@ -42,7 +41,7 @@ class InteractiveServer: ...@@ -42,7 +41,7 @@ class InteractiveServer:
if self._search_result: if self._search_result:
return self._search_result return self._search_result
else: else:
return None return []
@search_result.setter @search_result.setter
def search_result(self, value): def search_result(self, value):
...@@ -81,8 +80,10 @@ class InteractiveServer: ...@@ -81,8 +80,10 @@ class InteractiveServer:
input_data = [] input_data = []
parser = TtyIOParser() parser = TtyIOParser()
self.client.send(wr(prompt, before=1, after=0)) self.client.send(wr(prompt, before=1, after=0))
while True: while True:
data = self.client.recv(10) data = self.client.recv(10)
logger.debug(data)
if len(data) == 0: if len(data) == 0:
self.app.remove_client(self.client) self.app.remove_client(self.client)
break break
...@@ -97,6 +98,15 @@ class InteractiveServer: ...@@ -97,6 +98,15 @@ class InteractiveServer:
self.client.send(data) self.client.send(data)
continue continue
if data.startswith(b'\x03'):
# Ctrl-C
self.client.send(b'^C\r\nOpt> ')
input_data = []
continue
elif data.startswith(b'\x04'):
# Ctrl-D
return 'q'
# Todo: Move x1b to char # Todo: Move x1b to char
if data.startswith(b'\x1b') or data in char.UNSUPPORTED_CHAR: if data.startswith(b'\x1b') or data in char.UNSUPPORTED_CHAR:
self.client.send(b'') self.client.send(b'')
...@@ -104,7 +114,7 @@ class InteractiveServer: ...@@ -104,7 +114,7 @@ class InteractiveServer:
# handle shell expect # handle shell expect
multi_char_with_enter = False multi_char_with_enter = False
if len(data) > 1 and data[-1] in char.ENTER_CHAR: if len(data) > 1 and data[-1] in char.ENTER_CHAR_ORDER:
self.client.send(data) self.client.send(data)
input_data.append(data[:-1]) input_data.append(data[:-1])
multi_char_with_enter = True multi_char_with_enter = True
...@@ -124,13 +134,13 @@ class InteractiveServer: ...@@ -124,13 +134,13 @@ class InteractiveServer:
return self._sentinel return self._sentinel
elif opt.startswith("/"): elif opt.startswith("/"):
self.search_and_display(opt.lstrip("/")) self.search_and_display(opt.lstrip("/"))
elif opt in ['p', 'P']: elif opt in ['p', 'P', '']:
self.display_assets() self.display_assets()
elif opt in ['g', 'G']: elif opt in ['g', 'G']:
self.display_asset_groups() self.display_asset_groups()
elif opt.startswith("g") and opt.lstrip("g").isdigit(): elif opt.startswith("g") and opt.lstrip("g").isdigit():
self.display_group_assets(int(opt.lstrip("g"))) self.display_group_assets(int(opt.lstrip("g")))
elif opt in ['q', 'Q']: elif opt in ['q', 'Q', 'exit', 'quit']:
return self._sentinel return self._sentinel
elif opt in ['h', 'H']: elif opt in ['h', 'H']:
self.display_banner() self.display_banner()
...@@ -173,7 +183,7 @@ class InteractiveServer: ...@@ -173,7 +183,7 @@ class InteractiveServer:
self.get_user_asset_groups() self.get_user_asset_groups()
if len(self.asset_groups) == 0: if len(self.asset_groups) == 0:
self.client.send(warning(_("Nothing"))) self.client.send(warning(_("")))
return return
fake_group = AssetGroup(name=_("Name"), assets_amount=_("Assets"), comment=_("Comment")) fake_group = AssetGroup(name=_("Name"), assets_amount=_("Assets"), comment=_("Comment"))
...@@ -182,25 +192,26 @@ class InteractiveServer: ...@@ -182,25 +192,26 @@ class InteractiveServer:
amount_max_length = max(len(str(max([group.assets_amount for group in self.asset_groups]))), 10) amount_max_length = max(len(str(max([group.assets_amount for group in self.asset_groups]))), 10)
header = '{1:>%d} {0.name:%d} {0.assets_amount:<%s} ' % (id_max_length, name_max_length, amount_max_length) header = '{1:>%d} {0.name:%d} {0.assets_amount:<%s} ' % (id_max_length, name_max_length, amount_max_length)
comment_length = self.request.meta["width"] - len(header.format(fake_group, id_max_length)) comment_length = self.request.meta["width"] - len(header.format(fake_group, id_max_length))
line = header + '{0.comment:%s}' % (comment_length//2) # comment中可能有中文 line = header + '{0.comment:%s}' % (comment_length // 2) # comment中可能有中文
header += "{0.comment:%s}" % comment_length header += "{0.comment:%s}" % comment_length
self.client.send(title(header.format(fake_group, "ID"))) self.client.send(title(header.format(fake_group, "ID")))
for index, group in enumerate(self.asset_groups, 1): for index, group in enumerate(self.asset_groups, 1):
self.client.send(wr(line.format(group, index))) self.client.send(wr(line.format(group, index)))
self.client.send(wr(_("Total: {}").format(len(self.asset_groups)), before=1)) self.client.send(wr(_("总共: {}").format(len(self.asset_groups)), before=1))
def display_group_assets(self, _id): def display_group_assets(self, _id):
if _id > len(self.asset_groups) or _id <= 0: if _id > len(self.asset_groups) or _id <= 0:
self.client.send(wr(warning("Not match group, select again"))) self.client.send(wr(warning("没有匹配分组,请重新输入")))
self.display_asset_groups() self.display_asset_groups()
return return
self.search_result = self.asset_groups[_id-1].assets_granted self.search_result = self.asset_groups[_id - 1].assets_granted
self.display_search_result() self.display_search_result()
def display_search_result(self): def display_search_result(self):
self.search_result = sort_assets(self.search_result, self.app.config["ASSET_LIST_SORT_BY"]) self.search_result = sort_assets(self.search_result, self.app.config["ASSET_LIST_SORT_BY"])
fake_asset = Asset(hostname=_("Hostname"), ip=_("IP"), _system_users_name_list=_("LoginAs"), comment=_("Comment")) fake_asset = Asset(hostname=_("Hostname"), ip=_("IP"), _system_users_name_list=_("LoginAs"),
comment=_("Comment"))
id_max_length = max(len(str(len(self.search_result))), 3) id_max_length = max(len(str(len(self.search_result))), 3)
hostname_max_length = max(max([len(asset.hostname) for asset in self.search_result + [fake_asset]]), 15) hostname_max_length = max(max([len(asset.hostname) for asset in self.search_result + [fake_asset]]), 15)
sysuser_max_length = max([len(asset.system_users_name_list) for asset in self.search_result + [fake_asset]]) sysuser_max_length = max([len(asset.system_users_name_list) for asset in self.search_result + [fake_asset]])
...@@ -212,7 +223,7 @@ class InteractiveServer: ...@@ -212,7 +223,7 @@ class InteractiveServer:
self.client.send(wr(title(header.format(fake_asset, "ID")))) self.client.send(wr(title(header.format(fake_asset, "ID"))))
for index, asset in enumerate(self.search_result, 1): for index, asset in enumerate(self.search_result, 1):
self.client.send(wr(line.format(asset, index))) self.client.send(wr(line.format(asset, index)))
self.client.send(wr(_("Total: {} Matched: {}").format( self.client.send(wr(_("总共: {} 匹配: {}").format(
len(self.assets), len(self.search_result)), before=1) len(self.assets), len(self.search_result)), before=1)
) )
...@@ -254,7 +265,7 @@ class InteractiveServer: ...@@ -254,7 +265,7 @@ class InteractiveServer:
return None return None
while True: while True:
self.client.send(wr(_("Choose one to login: "), after=1)) self.client.send(wr(_("选择一个登陆: "), after=1))
self.display_system_users(system_users) self.display_system_users(system_users)
opt = self.get_option("ID> ") opt = self.get_option("ID> ")
if opt.isdigit() and len(system_users) > int(opt): if opt.isdigit() and len(system_users) > int(opt):
...@@ -272,15 +283,19 @@ class InteractiveServer: ...@@ -272,15 +283,19 @@ class InteractiveServer:
def search_and_proxy(self, opt): def search_and_proxy(self, opt):
self.search_assets(opt) self.search_assets(opt)
if len(self.search_result) == 1: if self.search_result and len(self.search_result) == 1:
self.proxy(self.search_result[0]) asset = self.search_result[0]
if asset.platform == "Windows":
self.client.send(warning(_("终端不支持登录windows, 请使用web terminal访问")))
return
self.proxy(asset)
else: else:
self.display_search_result() self.display_search_result()
def proxy(self, asset): def proxy(self, asset):
system_user = self.choose_system_user(asset.system_users_granted) system_user = self.choose_system_user(asset.system_users_granted)
if system_user is None: if system_user is None:
self.client.send(_("No user")) self.client.send(_("没有系统用户"))
return return
forwarder = ProxyServer(self.app, self.client) forwarder = ProxyServer(self.app, self.client)
forwarder.proxy(asset, system_user) forwarder.proxy(asset, system_user)
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import logging
import paramiko import paramiko
import threading import threading
import weakref import weakref
from .utils import get_logger
logger = logging.getLogger(__file__) logger = get_logger(__file__)
class SSHInterface(paramiko.ServerInterface): class SSHInterface(paramiko.ServerInterface):
...@@ -43,9 +43,9 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -43,9 +43,9 @@ class SSHInterface(paramiko.ServerInterface):
def get_allowed_auths(self, username): def get_allowed_auths(self, username):
supported = [] supported = []
if self.app.config["SSH_PASSWORD_AUTH"]: if self.app.config["PASSWORD_AUTH"]:
supported.append("password") supported.append("password")
if self.app.config["SSH_PUBLIC_KEY_AUTH"]: if self.app.config["PUBLIC_KEY_AUTH"]:
supported.append("publickey") supported.append("publickey")
return ",".join(supported) return ",".join(supported)
......
...@@ -4,43 +4,56 @@ ...@@ -4,43 +4,56 @@
import os import os
import logging import logging
from logging import StreamHandler from logging.config import dictConfig
from logging.handlers import TimedRotatingFileHandler
LOG_LEVELS = {
'DEBUG': logging.DEBUG,
'INFO': logging.INFO,
'WARN': logging.WARNING,
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'FATAL': logging.FATAL,
'CRITICAL': logging.CRITICAL,
}
def create_logger(app): def create_logger(app):
level = app.config['LOG_LEVEL'] level = app.config['LOG_LEVEL']
level = LOG_LEVELS.get(level, logging.INFO)
log_dir = app.config.get('LOG_DIR') log_dir = app.config.get('LOG_DIR')
log_path = os.path.join(log_dir, 'coco.log') log_path = os.path.join(log_dir, 'coco.log')
main_setting = {
'handlers': ['console', 'file'],
'level': level,
'propagate': False,
}
config = dict(
version=1,
formatters={
"main": {
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S',
},
'simple': {
'format': '%(asctime)s [%(levelname)-8s] %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S',
}
},
handlers={
'null': {
'level': 'DEBUG',
'class': 'logging.NullHandler',
},
'console': {
'level': 'DEBUG',
'class': 'logging.StreamHandler',
'formatter': 'main'
},
'file': {
'level': 'DEBUG',
'class': 'logging.FileHandler',
'formatter': 'main',
'filename': log_path,
},
},
loggers={
'coco': main_setting,
'paramiko': main_setting,
'jms': main_setting,
}
)
dictConfig(config)
logger = logging.getLogger() logger = logging.getLogger()
return logger
main_formatter = logging.Formatter(
fmt='%(asctime)s [%(module)s %(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# main_formatter = logging.Formatter(
# fmt='%(asctime)s [%(levelname)s] %(message)s',
# datefmt='%Y-%m-%d %H:%M:%S'
# )
console_handler = StreamHandler()
file_handler = TimedRotatingFileHandler(
filename=log_path, when='D', backupCount=10
)
for handler in [console_handler, file_handler]:
handler.setFormatter(main_formatter)
logger.addHandler(handler)
logger.setLevel(level)
logging.getLogger("requests").setLevel(logging.WARNING)
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import threading import threading
import datetime import datetime
import logging
import weakref import weakref
from . import char from . import char
from . import utils from . import utils
BUF_SIZE = 4096 BUF_SIZE = 4096
logger = logging.getLogger(__file__) logger = utils.get_logger(__file__)
class Request: class Request:
...@@ -23,6 +22,18 @@ class Request: ...@@ -23,6 +22,18 @@ class Request:
self.date_start = datetime.datetime.now() self.date_start = datetime.datetime.now()
class SizedList(list):
def __init__(self, maxsize=0):
self.maxsize = maxsize
self.size = 0
super().__init__()
def append(self, b):
if self.maxsize == 0 or self.size < self.maxsize:
super().append(b)
self.size += len(b)
class Client: class Client:
""" """
Client is the request client. Nothing more to say Client is the request client. Nothing more to say
...@@ -79,8 +90,8 @@ class Server: ...@@ -79,8 +90,8 @@ class Server:
self.recv_bytes = 0 self.recv_bytes = 0
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.input_data = [] self.input_data = SizedList(maxsize=1024)
self.output_data = [] self.output_data = SizedList(maxsize=1024)
self._in_input_state = True self._in_input_state = True
self._input_initial = False self._input_initial = False
self._in_vim_state = False self._in_vim_state = False
...@@ -102,7 +113,7 @@ class Server: ...@@ -102,7 +113,7 @@ class Server:
else: else:
return None return None
def send(self, b): def parse(self, b):
if isinstance(b, str): if isinstance(b, str):
b = b.encode("utf-8") b = b.encode("utf-8")
if not self._input_initial: if not self._input_initial:
...@@ -119,10 +130,14 @@ class Server: ...@@ -119,10 +130,14 @@ class Server:
self._input, self._output, self._input, self._output,
"#" * 30 + " End " + "#" * 30, "#" * 30 + " End " + "#" * 30,
)) ))
if self._input:
self.session.put_command(self._input, self._output) self.session.put_command(self._input, self._output)
del self.input_data[:] del self.input_data[:]
del self.output_data[:] del self.output_data[:]
self._in_input_state = True self._in_input_state = True
def send(self, b):
self.parse(b)
return self.chan.send(b) return self.chan.send(b)
def recv(self, size): def recv(self, size):
...@@ -137,9 +152,10 @@ class Server: ...@@ -137,9 +152,10 @@ class Server:
def close(self): def close(self):
logger.info("Closed server {}".format(self)) logger.info("Closed server {}".format(self))
self.parse(b'')
self.chan.close() self.chan.close()
self.stop_evt.set() self.stop_evt.set()
self.chan.transport.close() self.chan.close()
@staticmethod @staticmethod
def _have_enter_char(s): def _have_enter_char(s):
...@@ -149,10 +165,14 @@ class Server: ...@@ -149,10 +165,14 @@ class Server:
return False return False
def _parse_output(self): def _parse_output(self):
if not self.output_data:
return ''
parser = utils.TtyIOParser() parser = utils.TtyIOParser()
return parser.parse_output(self.output_data) return parser.parse_output(self.output_data)
def _parse_input(self): def _parse_input(self):
if not self.input_data or self.input_data[0] == char.RZ_PROTOCOL_CHAR:
return
parser = utils.TtyIOParser() parser = utils.TtyIOParser()
return parser.parse_input(self.input_data) return parser.parse_input(self.input_data)
...@@ -213,7 +233,10 @@ class WSProxy: ...@@ -213,7 +233,10 @@ class WSProxy:
def forward(self): def forward(self):
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try:
data = self.child.recv(BUF_SIZE) data = self.child.recv(BUF_SIZE)
except OSError:
continue
if len(data) == 0: if len(data) == 0:
self.close() self.close()
self.ws.emit("data", {'data': data.decode("utf-8"), 'room': self.connection}, room=self.room) self.ws.emit("data", {'data': data.decode("utf-8"), 'room': self.connection}, room=self.room)
...@@ -226,3 +249,9 @@ class WSProxy: ...@@ -226,3 +249,9 @@ class WSProxy:
def close(self): def close(self):
self.stop_event.set() self.stop_event.set()
self.child.close() self.child.close()
self.ws.logout(self.connection)
logger.debug("Proxy {} closed".format(self))
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import socket import socket
import threading import threading
import logging
import time import time
import weakref import weakref
import paramiko import paramiko
...@@ -13,10 +12,10 @@ from paramiko.ssh_exception import SSHException ...@@ -13,10 +12,10 @@ from paramiko.ssh_exception import SSHException
from .session import Session from .session import Session
from .models import Server from .models import Server
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \ from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
get_private_key_fingerprint get_private_key_fingerprint, get_logger
logger = logging.getLogger(__file__) logger = get_logger(__file__)
TIMEOUT = 8 TIMEOUT = 8
BUF_SIZE = 4096 BUF_SIZE = 4096
...@@ -93,7 +92,7 @@ class ProxyServer: ...@@ -93,7 +92,7 @@ class ProxyServer:
timeout=TIMEOUT, compress=True, auth_timeout=10, timeout=TIMEOUT, compress=True, auth_timeout=10,
look_for_keys=False look_for_keys=False
) )
except (paramiko.AuthenticationException, paramiko.BadAuthenticationType): except (paramiko.AuthenticationException, paramiko.BadAuthenticationType, SSHException):
admins = self.app.config['ADMINS'] or 'administrator' admins = self.app.config['ADMINS'] or 'administrator'
self.client.send(warning(wr( self.client.send(warning(wr(
"Authenticate with server failed, contact {}".format(admins), "Authenticate with server failed, contact {}".format(admins),
......
...@@ -3,17 +3,20 @@ ...@@ -3,17 +3,20 @@
# #
import abc import abc
import logging
import threading import threading
import time import time
import os import os
import gzip import gzip
import json import json
import shutil import shutil
import boto3 # AWS S3 sdk
from jms_es_sdk import ESStore
from .utils import get_logger
from .alignment import MemoryQueue from .alignment import MemoryQueue
logger = logging.getLogger(__file__) logger = get_logger(__file__)
BUF_SIZE = 1024 BUF_SIZE = 1024
...@@ -126,11 +129,45 @@ class ServerReplayRecorder(ReplayRecorder): ...@@ -126,11 +129,45 @@ class ServerReplayRecorder(ReplayRecorder):
logger.info("Succeed to push {}'s {}".format(session_id, "record")) logger.info("Succeed to push {}'s {}".format(session_id, "record"))
else: else:
logger.error("Failed to push {}'s {}".format(session_id, "record")) logger.error("Failed to push {}'s {}".format(session_id, "record"))
self.push_to_server(session_id)
def push_to_server(self, session_id): def push_to_server(self, session_id):
if self.upload_replay(3, session_id):
if self.finish_replay(3, session_id):
return True
else:
return False
else:
return False
def push_local(self, session_id):
return self.app.service.push_session_replay(os.path.join(self.app.config['LOG_DIR'], session_id + '.replay.gz'), return self.app.service.push_session_replay(os.path.join(self.app.config['LOG_DIR'], session_id + '.replay.gz'),
session_id) session_id)
def upload_replay(self, times, session_id):
if times > 0:
if self.push_local(session_id):
logger.info("success push session: {}'s replay log ".format(session_id))
return True
else:
logger.error("failed report session {}'s replay log, try {} times".format(session_id, times))
return self.upload_replay(times - 1, session_id)
else:
logger.error("failed report session {}'s replay log".format(session_id))
return False
def finish_replay(self, times, session_id):
if times > 0:
if self.app.service.finish_replay(session_id):
logger.info("success report session {}'s replay log ".format(session_id))
return True
else:
logger.error("failed report session {}'s replay log, try {} times".format(session_id, times))
return self.finish_replay(times - 1, session_id)
else:
logger.error("failed report session {}'s replay log".format(session_id))
return False
def __del__(self): def __del__(self):
print("{} has been gc".format(self)) print("{} has been gc".format(self))
del self.file del self.file
...@@ -183,18 +220,126 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton): ...@@ -183,18 +220,126 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
print("{} has been gc".format(self)) print("{} has been gc".format(self))
def get_command_recorder_class(app): class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
command_engine = app.config["COMMAND_RECORD_ENGINE"] batch_size = 10
timeout = 5
no = 0
default_hosts = ["http://localhost"]
if command_engine == "server": def __init__(self, app):
return ServerCommandRecorder super().__init__(app)
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_es_async()
self.__class__.no += 1
self.store = ESStore(app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts))
if not self.store.ping():
raise AssertionError("ESCommand storage init error")
def record(self, data):
if data and data['input']:
data['input'] = data['input'][:128]
data['output'] = data['output'][:1024]
data['timestamp'] = int(data['timestamp'])
self.queue.put(data)
def push_to_es_async(self):
def func():
while not self.stop_evt.is_set():
data_set = self.queue.mget(self.batch_size,
timeout=self.timeout)
logger.debug(
"<Session command recorder {}> queue size: {}".format(
self.no, self.queue.qsize())
)
if not data_set:
continue
logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.store.bulk_save(data_set)
if not ok:
self.queue.mput(data_set)
thread = threading.Thread(target=func)
thread.daemon = True
thread.start()
def session_start(self, session_id):
pass
def session_end(self, session_id):
pass
def __del__(self):
print("{} has been gc".format(self))
class S3ReplayRecorder(ServerReplayRecorder):
def __init__(self, app):
super().__init__(app)
self.bucket = app.config["REPLAY_STORAGE"].get("BUCKET", "jumpserver")
self.REGION = app.config["REPLAY_STORAGE"].get("REGION", None)
self.ACCESS_KEY = app.config["REPLAY_STORAGE"].get("ACCESS_KEY", None)
self.SECRET_KEY = app.config["REPLAY_STORAGE"].get("SECRET_KEY", None)
if self.ACCESS_KEY and self.REGION and self.SECRET_KEY:
self.s3 = boto3.client('s3',
region_name=self.REGION,
aws_access_key_id=self.ACCESS_KEY,
aws_secret_access_key=self.SECRET_KEY)
else:
self.s3 = boto3.client('s3')
def push_to_s3(self, session_id):
logger.debug("push to server")
try:
self.s3.upload_file(
os.path.join(self.app.config['LOG_DIR'], session_id + '.replay.gz'),
self.bucket,
time.strftime('%Y-%m-%d', time.localtime(
self.starttime)) + '/' + session_id + '.replay.gz')
return True
except:
return False
def upload_replay(self, times, session_id):
if times > 0:
if self.push_to_s3(session_id):
logger.info("success push session: {}'s replay log to S3 ".format(session_id))
return True
else:
logger.error("failed report session {}'s replay log to S3, try {} times".format(session_id, times))
return self.upload_replay(times - 1, session_id)
else:
logger.error("failed report session {}'s replay log S3, try to push to local".format(session_id))
return self.upload_replay_to_local(3, session_id)
def upload_replay_to_local(self, times, session_id):
if times > 0:
if self.push_local(session_id):
logger.info("success push session: {}'s replay log ".format(session_id))
return True
else:
logger.error("failed report session {}'s replay log, try {} times".format(session_id, times))
return self.upload_replay_to_local(times - 1, session_id)
else:
logger.error("failed report session {}'s replay log".format(session_id))
return False
def get_command_recorder_class(config):
command_storage = config["COMMAND_STORAGE"]
storage_type = command_storage.get('TYPE')
if storage_type == "elasticsearch":
return ESCommandRecorder
else: else:
return ServerCommandRecorder return ServerCommandRecorder
def get_replay_recorder_class(app): def get_replay_recorder_class(config):
replay_engine = app.config["REPLAY_RECORD_ENGINE"] replay_storage = config["REPLAY_STORAGE"]
if replay_engine == "server": logger.debug(replay_storage)
return ServerReplayRecorder storage_type = replay_storage.get('TYPE')
if storage_type == "s3":
return S3ReplayRecorder
else: else:
return ServerReplayRecorder return ServerReplayRecorder
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
# #
import threading import threading
import uuid import uuid
import logging
import datetime import datetime
import selectors import selectors
import time import time
from .utils import get_logger
BUF_SIZE = 1024 BUF_SIZE = 1024
logger = logging.getLogger(__file__) logger = get_logger(__file__)
class Session: class Session:
...@@ -27,6 +28,7 @@ class Session: ...@@ -27,6 +28,7 @@ class Session:
self._command_recorder = command_recorder self._command_recorder = command_recorder
self._replay_recorder = replay_recorder self._replay_recorder = replay_recorder
self.server.set_session(self) self.server.set_session(self)
self.date_last_active = datetime.datetime.utcnow()
def add_watcher(self, watcher, silent=False): def add_watcher(self, watcher, silent=False):
""" """
...@@ -128,6 +130,8 @@ class Session: ...@@ -128,6 +130,8 @@ class Session:
logger.info(msg) logger.info(msg)
self.close() self.close()
break break
self.date_last_active = datetime.datetime.utcnow()
for watcher in [self.client] + self._watchers + self._sharers: for watcher in [self.client] + self._watchers + self._sharers:
watcher.send(data) watcher.send(data)
elif sock == self.client: elif sock == self.client:
...@@ -170,6 +174,7 @@ class Session: ...@@ -170,6 +174,7 @@ class Session:
"login_from": "ST", "login_from": "ST",
"remote_addr": self.client.addr[0], "remote_addr": self.client.addr[0],
"is_finished": True if self.stop_evt.is_set() else False, "is_finished": True if self.stop_evt.is_set() else False,
"date_last_active": self.date_last_active.strftime("%Y-%m-%d %H:%M:%S") + " +0000",
"date_start": self.date_created.strftime("%Y-%m-%d %H:%M:%S") + " +0000", "date_start": self.date_created.strftime("%Y-%m-%d %H:%M:%S") + " +0000",
"date_end": self.date_end.strftime("%Y-%m-%d %H:%M:%S") + " +0000" if self.date_end else None "date_end": self.date_end.strftime("%Y-%m-%d %H:%M:%S") + " +0000" if self.date_end else None
} }
......
...@@ -3,20 +3,16 @@ ...@@ -3,20 +3,16 @@
# #
import os import os
import logging
import socket import socket
import threading import threading
import paramiko import paramiko
import sys
import time from .utils import ssh_key_gen, get_logger
from .utils import ssh_key_gen
from .interface import SSHInterface from .interface import SSHInterface
from .interactive import InteractiveServer from .interactive import InteractiveServer
from .models import Client, Request from .models import Client, Request
logger = logging.getLogger(__file__) logger = get_logger(__file__)
BACKLOG = 5 BACKLOG = 5
...@@ -90,14 +86,12 @@ class SSHServer: ...@@ -90,14 +86,12 @@ class SSHServer:
def handle_chan(self, chan, request): def handle_chan(self, chan, request):
client = Client(chan, request) client = Client(chan, request)
print(chan)
print(request)
self.app.add_client(client) self.app.add_client(client)
self.dispatch(client) self.dispatch(client)
def dispatch(self, client): def dispatch(self, client):
request_type = client.request.type request_type = client.request.type
if request_type == 'pty': if request_type == 'pty' or request_type == 'x11':
logger.info("Request type `pty`, dispatch to interactive mode") logger.info("Request type `pty`, dispatch to interactive mode")
InteractiveServer(self.app, client).interact() InteractiveServer(self.app, client).interact()
elif request_type == 'exec': elif request_type == 'exec':
......
# coding: utf-8 #!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
import weakref import weakref
import logging
logger = logging.getLogger(__file__) from .utils import get_logger
logger = get_logger(__file__)
class TaskHandler: class TaskHandler:
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import hashlib import hashlib
import logging
import re import re
import os import os
import threading import threading
...@@ -85,61 +86,6 @@ def ssh_key_gen(length=2048, type='rsa', password=None, ...@@ -85,61 +86,6 @@ def ssh_key_gen(length=2048, type='rsa', password=None,
raise IOError('These is error when generate ssh key.') raise IOError('These is error when generate ssh key.')
def content_md5(data):
"""计算data的MD5值,经过Base64编码并返回str类型。
返回值可以直接作为HTTP Content-Type头部的值
"""
if isinstance(data, str):
data = hashlib.md5(data.encode('utf-8'))
value = base64.b64encode(data.digest())
return value.decode('utf-8')
_STRPTIME_LOCK = threading.Lock()
_GMT_FORMAT = "%a, %d %b %Y %H:%M:%S GMT"
_ISO8601_FORMAT = "%Y-%m-%dT%H:%M:%S.000Z"
def to_unixtime(time_string, format_string):
with _STRPTIME_LOCK:
return int(calendar.timegm(time.strptime(str(time_string), format_string)))
def http_date(timeval=None):
"""返回符合HTTP标准的GMT时间字符串,用strftime的格式表示就是"%a, %d %b %Y %H:%M:%S GMT"。
但不能使用strftime,因为strftime的结果是和locale相关的。
"""
return formatdate(timeval, usegmt=True)
def http_to_unixtime(time_string):
"""把HTTP Date格式的字符串转换为UNIX时间(自1970年1月1日UTC零点的秒数)。
HTTP Date形如 `Sat, 05 Dec 2015 11:10:29 GMT` 。
"""
return to_unixtime(time_string, _GMT_FORMAT)
def iso8601_to_unixtime(time_string):
"""把ISO8601时间字符串(形如,2012-02-24T06:07:48.000Z)转换为UNIX时间,精确到秒。"""
return to_unixtime(time_string, _ISO8601_FORMAT)
def make_signature(access_key_secret, date=None):
if isinstance(date, bytes):
date = bytes.decode(date)
if isinstance(date, int):
date_gmt = http_date(date)
elif date is None:
date_gmt = http_date(int(time.time()))
else:
date_gmt = date
data = str(access_key_secret) + "\n" + date_gmt
return content_md5(data)
class TtyIOParser(object): class TtyIOParser(object):
def __init__(self, width=80, height=24): def __init__(self, width=80, height=24):
self.screen = pyte.Screen(width, height) self.screen = pyte.Screen(width, height)
...@@ -162,9 +108,12 @@ class TtyIOParser(object): ...@@ -162,9 +108,12 @@ class TtyIOParser(object):
for d in data: for d in data:
self.stream.feed(d) self.stream.feed(d)
try:
for line in self.screen.display: for line in self.screen.display:
if line.strip(): if line.strip():
output.append(line) output.append(line)
except IndexError:
pass
self.screen.reset() self.screen.reset()
return sep.join(output[0:-1]).strip() return sep.join(output[0:-1]).strip()
...@@ -283,10 +232,6 @@ def wrap_with_title(text): ...@@ -283,10 +232,6 @@ def wrap_with_title(text):
return wrap_with_color(text, color='black', background='green') return wrap_with_color(text, color='black', background='green')
def b64encode_as_string(data):
return base64.b64encode(data).decode("utf-8")
def split_string_int(s): def split_string_int(s):
"""Split string or int """Split string or int
...@@ -320,37 +265,6 @@ def sort_assets(assets, order_by='hostname'): ...@@ -320,37 +265,6 @@ def sort_assets(assets, order_by='hostname'):
return assets return assets
class PKey(object):
@classmethod
def from_string(cls, key_string):
try:
pkey = paramiko.RSAKey(file_obj=StringIO(key_string))
return pkey
except paramiko.SSHException:
try:
pkey = paramiko.DSSKey(file_obj=StringIO(key_string))
return pkey
except paramiko.SSHException:
return None
def timestamp_to_datetime_str(ts):
datetime_format = '%Y-%m-%dT%H:%M:%S.%fZ'
dt = datetime.datetime.fromtimestamp(ts, tz=pytz.timezone('UTC'))
return dt.strftime(datetime_format)
class MultiQueue(Queue):
def mget(self, size=1, block=True, timeout=5):
items = []
for i in range(size):
try:
items.append(self.get(block=block, timeout=timeout))
except Empty:
break
return items
def _gettext(): def _gettext():
gettext.bindtextdomain("coco", os.path.join(BASE_DIR, "locale")) gettext.bindtextdomain("coco", os.path.join(BASE_DIR, "locale"))
gettext.textdomain("coco") gettext.textdomain("coco")
...@@ -371,4 +285,21 @@ def compile_message(): ...@@ -371,4 +285,21 @@ def compile_message():
pass pass
def get_logger(file_name):
return logging.getLogger('coco.'+file_name)
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+')
def len_display(s):
length = 0
for i in s:
if zh_pattern.match(i):
length += 2
else:
length += 1
return length
ugettext = _gettext() ugettext = _gettext()
...@@ -9,10 +9,10 @@ BASE_DIR = os.path.dirname(__file__) ...@@ -9,10 +9,10 @@ BASE_DIR = os.path.dirname(__file__)
class Config: class Config:
""" """
Coco config file Coco config file, coco also load config from server update setting below
""" """
# 项目名称, 会用来向Jumpserver注册, 识别而已, 不能重复 # 项目名称, 会用来向Jumpserver注册, 识别而已, 不能重复
# APP_NAME = "localhost" # NAME = "localhost"
# Jumpserver项目的url, api请求注册会使用 # Jumpserver项目的url, api请求注册会使用
# CORE_HOST = os.environ.get("CORE_HOST") or 'http://127.0.0.1:8080' # CORE_HOST = os.environ.get("CORE_HOST") or 'http://127.0.0.1:8080'
...@@ -49,16 +49,22 @@ class Config: ...@@ -49,16 +49,22 @@ class Config:
# ASSET_LIST_SORT_BY = 'ip' # ASSET_LIST_SORT_BY = 'ip'
# 登录是否支持密码认证 # 登录是否支持密码认证
# SSH_PASSWORD_AUTH = True # PASSWORD_AUTH = True
# 登录是否支持秘钥认证 # 登录是否支持秘钥认证
# SSH_PUBLIC_KEY_AUTH = True # PUBLIC_KEY_AUTH = True
# 和Jumpserver 保持心跳时间间隔 # 和Jumpserver 保持心跳时间间隔
# HEARTBEAT_INTERVAL = 5 # HEARTBEAT_INTERVAL = 5
# Admin的名字,出问题会提示给用户 # Admin的名字,出问题会提示给用户
# ADMINS = '' # ADMINS = ''
COMMAND_STORAGE = {
"TYPE": "server"
}
REPLAY_STORAGE = {
"TYPE": "server"
}
config = Config() config = Config()
asn1crypto==0.23.0 asn1crypto==0.23.0
bcrypt==3.1.4 bcrypt==3.1.4
boto3==1.5.18
botocore==1.8.32
certifi==2017.11.5 certifi==2017.11.5
cffi==1.11.2 cffi==1.11.2
chardet==3.0.4 chardet==3.0.4
...@@ -28,4 +30,5 @@ tornado==4.5.2 ...@@ -28,4 +30,5 @@ tornado==4.5.2
urllib3==1.22 urllib3==1.22
wcwidth==0.1.7 wcwidth==0.1.7
werkzeug==0.12.2 werkzeug==0.12.2
jumpserver-python-sdk==0.0.23 jumpserver-python-sdk==0.0.31
jms-es-sdk
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment