Commit 4d2015df authored by 广宏伟's avatar 广宏伟

Merged in dev (pull request #56)

Dev
parents 27dd5afb 17d01a8e
......@@ -18,7 +18,7 @@ from .httpd import HttpServer
from .logger import create_logger
from .tasks import TaskHandler
from .recorder import get_command_recorder_class, ServerReplayRecorder
from .utils import get_logger
from .utils import get_logger, register_app, register_service
__version__ = '1.3.0'
......@@ -67,6 +67,7 @@ class Coco:
self.replay_recorder_class = None
self.command_recorder_class = None
self._task_handler = None
register_app(self)
@property
def name(self):
......@@ -79,24 +80,25 @@ class Coco:
def service(self):
if self._service is None:
self._service = AppService(self)
register_service(self._service)
return self._service
@property
def sshd(self):
if self._sshd is None:
self._sshd = SSHServer(self)
self._sshd = SSHServer()
return self._sshd
@property
def httpd(self):
if self._httpd is None:
self._httpd = HttpServer(self)
self._httpd = HttpServer()
return self._httpd
@property
def task_handler(self):
if self._task_handler is None:
self._task_handler = TaskHandler(self)
self._task_handler = TaskHandler()
return self._task_handler
def make_logger(self):
......@@ -114,11 +116,10 @@ class Coco:
self.command_recorder_class = get_command_recorder_class(self.config)
def new_command_recorder(self):
recorder = self.command_recorder_class(self)
return recorder
return self.command_recorder_class()
def new_replay_recorder(self):
return self.replay_recorder_class(self)
return self.replay_recorder_class()
def bootstrap(self):
self.make_logger()
......
# -*- coding: utf-8 -*-
#
import weakref
import os
import socket
import paramiko
from paramiko.ssh_exception import SSHException
from .ctx import app_service
from .utils import get_logger, get_private_key_fingerprint
logger = get_logger(__file__)
......@@ -15,20 +15,13 @@ TIMEOUT = 10
class SSHConnection:
def __init__(self, app):
self._app = weakref.ref(app)
@property
def app(self):
return self._app()
def get_system_user_auth(self, system_user):
"""
获取系统用户的认证信息,密码或秘钥
:return: system user have full info
"""
password, private_key = \
self.app.service.get_system_user_auth_info(system_user)
app_service.get_system_user_auth_info(system_user)
system_user.password = password
system_user.private_key = private_key
......@@ -97,7 +90,7 @@ class SSHConnection:
def get_proxy_sock(self, asset):
sock = None
domain = self.app.service.get_domain_detail_with_gateway(
domain = app_service.get_domain_detail_with_gateway(
asset.domain
)
if not domain.has_ssh_gateway():
......
# -*- coding: utf-8 -*-
#
from werkzeug.local import LocalProxy
from functools import partial
stack = {}
def _find(name):
if stack.get(name):
return stack[name]
else:
raise ValueError("Not found in stack: {}".format(name))
current_app = LocalProxy(partial(_find, 'app'))
app_service = LocalProxy(partial(_find, 'service'))
# current_app = []
# current_service = []
This diff is collapsed.
This diff is collapsed.
......@@ -4,9 +4,9 @@
import paramiko
import threading
import weakref
from .utils import get_logger
from .ctx import current_app, app_service
logger = get_logger(__file__)
......@@ -19,22 +19,13 @@ class SSHInterface(paramiko.ServerInterface):
https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py
"""
def __init__(self, app, request):
self._app = weakref.ref(app)
self._request = weakref.ref(request)
def __init__(self, request):
self.request = request
self.event = threading.Event()
self.auth_valid = False
self.otp_auth = False
self.info = None
@property
def app(self):
return self._app()
@property
def request(self):
return self._request()
def check_auth_interactive(self, username, submethods):
logger.info("Check auth interactive: %s %s" % (username, submethods))
instructions = 'Please enter 6 digits.'
......@@ -55,7 +46,7 @@ class SSHInterface(paramiko.ServerInterface):
if not seed:
return paramiko.AUTH_FAILED
is_valid = self.app.service.authenticate_otp(seed, otp_code)
is_valid = app_service.authenticate_otp(seed, otp_code)
if is_valid:
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
......@@ -67,9 +58,9 @@ class SSHInterface(paramiko.ServerInterface):
supported = []
if self.otp_auth:
return 'keyboard-interactive'
if self.app.config["PASSWORD_AUTH"]:
if current_app.config["PASSWORD_AUTH"]:
supported.append("password")
if self.app.config["PUBLIC_KEY_AUTH"]:
if current_app.config["PUBLIC_KEY_AUTH"]:
supported.append("publickey")
return ",".join(supported)
......@@ -100,7 +91,7 @@ class SSHInterface(paramiko.ServerInterface):
return paramiko.AUTH_SUCCESSFUL
def validate_auth(self, username, password="", public_key=""):
info = self.app.service.authenticate(
info = app_service.authenticate(
username, password=password, public_key=public_key,
remote_addr=self.request.remote_ip
)
......
......@@ -218,7 +218,7 @@ class WSProxy:
```
"""
def __init__(self, ws, child, room, connection):
def __init__(self, ws, child, room_id):
"""
:param ws: websocket instance or handler, have write_message method
:param child: sock child pair
......@@ -226,9 +226,8 @@ class WSProxy:
self.ws = ws
self.child = child
self.stop_event = threading.Event()
self.room = room
self.room_id = room_id
self.auto_forward()
self.connection = connection
def send(self, msg):
"""
......@@ -252,7 +251,9 @@ class WSProxy:
if len(data) == 0:
self.close()
data = data.decode(errors="ignore")
self.ws.emit("data", {'data': data, 'room': self.connection}, room=self.room)
print("Send data: {}".format(data))
self.ws.emit("data", {'data': data, 'room': self.room_id},
room=self.room_id)
if len(data) == BUF_SIZE:
time.sleep(0.1)
......@@ -265,7 +266,6 @@ class WSProxy:
self.stop_event.set()
self.child.shutdown(1)
self.child.close()
self.ws.logout(self.connection)
logger.debug("Proxy {} closed".format(self))
......
......@@ -4,13 +4,13 @@
import threading
import time
import weakref
from paramiko.ssh_exception import SSHException
from .session import Session
from .models import Server
from .connection import SSHConnection
from .ctx import current_app, app_service
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
get_logger, net_input
......@@ -21,24 +21,19 @@ BUF_SIZE = 4096
class ProxyServer:
def __init__(self, app, client):
self._app = weakref.ref(app)
def __init__(self, client):
self.client = client
self.server = None
self.connecting = True
self.stop_event = threading.Event()
@property
def app(self):
return self._app()
def get_system_user_auth(self, system_user):
"""
获取系统用户的认证信息,密码或秘钥
:return: system user have full info
"""
password, private_key = \
self.app.service.get_system_user_auth_info(system_user)
app_service.get_system_user_auth_info(system_user)
if not password and not private_key:
prompt = "{}'s password: ".format(system_user.username)
password = net_input(self.client, prompt=prompt, sensitive=True)
......@@ -51,26 +46,26 @@ class ProxyServer:
self.server = self.get_server_conn(asset, system_user)
if self.server is None:
return
command_recorder = self.app.new_command_recorder()
replay_recorder = self.app.new_replay_recorder()
command_recorder = current_app.new_command_recorder()
replay_recorder = current_app.new_replay_recorder()
session = Session(
self.client, self.server,
command_recorder=command_recorder,
replay_recorder=replay_recorder,
)
self.app.add_session(session)
current_app.add_session(session)
self.watch_win_size_change_async()
session.bridge()
self.stop_event.set()
self.end_watch_win_size_change()
self.app.remove_session(session)
current_app.remove_session(session)
def validate_permission(self, asset, system_user):
"""
验证用户是否有连接改资产的权限
:return: True or False
"""
return self.app.service.validate_user_asset_permission(
return app_service.validate_user_asset_permission(
self.client.user.id, asset.id, system_user.id
)
......@@ -90,13 +85,14 @@ class ProxyServer:
pass
def get_ssh_server_conn(self, asset, system_user):
ssh = SSHConnection(self.app)
request = self.client.request
term = request.meta.get('term', 'xterm')
width = request.meta.get('width', 80)
height = request.meta.get('height', 24)
chan, msg = ssh.get_channel(asset, system_user, term=term,
width=width, height=height)
ssh = SSHConnection()
chan, msg = ssh.get_channel(
asset, system_user, term=term, width=width, height=height
)
if not chan:
self.client.send(warning(wr(msg, before=1, after=0)))
server = None
......@@ -130,9 +126,11 @@ class ProxyServer:
def send_connecting_message(self, asset, system_user):
def func():
delay = 0.0
self.client.send('Connecting to {}@{} {:.1f}'.format(system_user, asset, delay))
self.client.send('Connecting to {}@{} {:.1f}'.format(
system_user, asset, delay)
)
while self.connecting and delay < TIMEOUT:
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode('utf-8'))
self.client.send('\x08\x08\x08{:.1f}'.format(delay).encode())
time.sleep(0.1)
delay += 0.1
thread = threading.Thread(target=func)
......
......@@ -8,33 +8,19 @@ import time
import os
import gzip
import json
import shutil
import jms_storage
from .utils import get_logger
from .utils import get_logger, Singleton
from .alignment import MemoryQueue
from .ctx import current_app, app_service
logger = get_logger(__file__)
BUF_SIZE = 1024
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
class ReplayRecorder(metaclass=abc.ABCMeta):
def __init__(self, app, session=None):
self.app = app
def __init__(self, session=None):
self.session = session
@abc.abstractmethod
......@@ -61,8 +47,7 @@ class ReplayRecorder(metaclass=abc.ABCMeta):
class CommandRecorder:
def __init__(self, app, session=None):
self.app = app
def __init__(self, session=None):
self.session = session
def record(self, data):
......@@ -92,8 +77,8 @@ class ServerReplayRecorder(ReplayRecorder):
time_start = None
storage = None
def __init__(self, app):
super().__init__(app)
def __init__(self):
super().__init__()
self.file = None
self.file_path = None
......@@ -114,8 +99,8 @@ class ServerReplayRecorder(ReplayRecorder):
def session_start(self, session_id):
self.time_start = time.time()
filename = session_id+'.replay.gz'
self.file_path = os.path.join(self.app.config['LOG_DIR'], filename)
filename = session_id + '.replay.gz'
self.file_path = os.path.join(current_app.config['LOG_DIR'], filename)
self.file = gzip.open(self.file_path, 'at')
self.file.write('{')
......@@ -128,11 +113,11 @@ class ServerReplayRecorder(ReplayRecorder):
logger.error("Failed to push {}'s {}".format(session_id, "record"))
def upload_replay(self, session_id):
configs = self.app.service.load_config_from_server()
configs = app_service.load_config_from_server()
logger.debug("upload_replay print config: {}".format(configs))
self.storage = jms_storage.init(configs["REPLAY_STORAGE"])
if not self.storage:
self.storage = jms_storage.jms(self.app.service)
self.storage = jms_storage.jms(app_service)
if self.push_file(3, session_id):
os.unlink(self.file_path)
return True
......@@ -151,7 +136,7 @@ class ServerReplayRecorder(ReplayRecorder):
else:
msg = "Failed push session {}'s replay log to storage".format(session_id)
logger.error(msg)
self.storage = jms_storage.jms(self.app.service)
self.storage = jms_storage.jms(app_service)
return self.push_file(3, session_id)
if self.push_to_storage(session_id):
......@@ -167,7 +152,7 @@ class ServerReplayRecorder(ReplayRecorder):
logger.error("Failed finished session {}'s replay".format(session_id))
return False
if self.app.service.finish_replay(session_id):
if app_service.finish_replay(session_id):
logger.info("Success finish session {}'s replay ".format(session_id))
return True
else:
......@@ -180,8 +165,8 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
timeout = 5
no = 0
def __init__(self, app):
super().__init__(app)
def __init__(self):
super().__init__()
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_server_async()
......@@ -204,7 +189,7 @@ class ServerCommandRecorder(CommandRecorder, metaclass=Singleton):
if not data_set:
continue
logger.debug("Send {} commands to server".format(len(data_set)))
ok = self.app.service.push_session_command(data_set)
ok = app_service.push_session_command(data_set)
if not ok:
self.queue.mput(data_set)
......@@ -228,13 +213,15 @@ class ESCommandRecorder(CommandRecorder, metaclass=Singleton):
no = 0
default_hosts = ["http://localhost"]
def __init__(self, app):
super().__init__(app)
def __init__(self):
super().__init__()
self.queue = MemoryQueue()
self.stop_evt = threading.Event()
self.push_to_es_async()
self.__class__.no += 1
self.store = jms_storage.ESStore(app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts))
self.store = jms_storage.ESStore(
current_app.config["COMMAND_STORAGE"].get("HOSTS", self.default_hosts)
)
if not self.store.ping():
raise AssertionError("ESCommand storage init error")
......
......@@ -40,7 +40,7 @@ class Session:
"""
logger.info("Session add watcher: {} -> {} ".format(self.id, watcher))
if not silent:
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode("utf-8"))
watcher.send("Welcome to watch session {}\r\n".format(self.id).encode())
self.sel.register(watcher, selectors.EVENT_READ)
self._watchers.append(watcher)
......
......@@ -34,7 +34,7 @@ class SFTPServer(paramiko.SFTPServerInterface):
self._sftp[host] = sftp
return sftp
else:
raise OSError("Can not connect asset sftp server")
raise OSError("Can not connect asset sftp server: {}".format(msg))
else:
return self._sftp[host]
......
......@@ -12,6 +12,7 @@ from .interface import SSHInterface
from .interactive import InteractiveServer
from .models import Client, Request
from .sftp import SFTPServer
from .ctx import current_app
logger = get_logger(__file__)
BACKLOG = 5
......@@ -19,38 +20,41 @@ BACKLOG = 5
class SSHServer:
def __init__(self, app):
self.app = app
def __init__(self):
self.stop_evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.host_key_path = os.path.join(self.app.root_path, 'keys', 'host_rsa_key')
self.workers = []
self.pipe = None
@property
def host_key(self):
if not os.path.isfile(self.host_key_path):
self.gen_host_key()
return paramiko.RSAKey(filename=self.host_key_path)
host_key_path = os.path.join(current_app.root_path, 'keys', 'host_rsa_key')
if not os.path.isfile(host_key_path):
self.gen_host_key(host_key_path)
return paramiko.RSAKey(filename=host_key_path)
def gen_host_key(self):
@staticmethod
def gen_host_key(key_path):
ssh_key, _ = ssh_key_gen()
with open(self.host_key_path, 'w') as f:
with open(key_path, 'w') as f:
f.write(ssh_key)
def run(self):
host = self.app.config["BIND_HOST"]
port = self.app.config["SSHD_PORT"]
host = current_app.config["BIND_HOST"]
port = current_app.config["SSHD_PORT"]
print('Starting ssh server at {}:{}'.format(host, port))
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((host, port))
self.sock.listen(BACKLOG)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(BACKLOG)
while not self.stop_evt.is_set():
try:
sock, addr = self.sock.accept()
logger.info("Get ssh request from {}: {}".format(addr[0], addr[1]))
thread = threading.Thread(target=self.handle_connection, args=(sock, addr))
client, addr = sock.accept()
logger.info("Get ssh request from {}: {}".format(*addr))
thread = threading.Thread(target=self.handle_connection,
args=(client, addr))
thread.daemon = True
thread.start()
except Exception as e:
except IndexError as e:
logger.error("Start SSH server error: {}".format(e))
def handle_connection(self, sock, addr):
......@@ -65,7 +69,7 @@ class SSHServer:
'sftp', paramiko.SFTPServer, SFTPServer
)
request = Request(addr)
server = SSHInterface(self.app, request)
server = SSHInterface(request)
try:
transport.start_server(server=server)
except paramiko.SSHException:
......@@ -96,7 +100,7 @@ class SSHServer:
def handle_chan(self, chan, request):
client = Client(chan, request)
self.app.add_client(client)
current_app.add_client(client)
self.dispatch(client)
def dispatch(self, client):
......@@ -104,7 +108,7 @@ class SSHServer:
request_type = set(client.request.type)
if supported & request_type:
logger.info("Request type `pty`, dispatch to interactive mode")
InteractiveServer(self.app, client).interact()
InteractiveServer(client).interact()
elif 'subsystem' in request_type:
pass
else:
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
import weakref
from .ctx import current_app, app_service
from .utils import get_logger
logger = get_logger(__file__)
class TaskHandler:
routes = None
def __init__(self, app):
self._app = weakref.ref(app)
def init(self):
self.routes = {
'kill_session': self.handle_kill_session
}
@property
def app(self):
return self._app()
def handle_kill_session(self, task):
@staticmethod
def handle_kill_session(task):
logger.info("Handle kill session task: {}".format(task.args))
session_id = task.args
session = None
for s in self.app.sessions:
for s in current_app.sessions:
if s.id == session_id:
session = s
break
if session:
session.terminate()
self.app.service.finish_task(task.id)
app_service.finish_task(task.id)
def handle(self, task):
if task.name == "kill_session":
self.handle_kill_session(task)
else:
logger.error("No handler for this task: {}".format(task.name))
func = self.routes.get(task.name)
return func(task)
......@@ -15,10 +15,24 @@ import paramiko
import pyte
from . import char
from .ctx import stack
BASE_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
class Singleton(type):
def __init__(cls, *args, **kwargs):
cls.__instance = None
super().__init__(*args, **kwargs)
def __call__(cls, *args, **kwargs):
if cls.__instance is None:
cls.__instance = super().__call__(*args, **kwargs)
return cls.__instance
else:
return cls.__instance
def ssh_key_string_to_obj(text, password=None):
key = None
try:
......@@ -280,27 +294,14 @@ def get_logger(file_name):
return logging.getLogger('coco.'+file_name)
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+')
def len_display(s):
length = 0
for i in s:
if zh_pattern.match(i):
length += 2
else:
length += 1
return length
def net_input(client, prompt='Opt> ', sensitive=False):
def net_input(client, prompt='Opt> ', sensitive=False, before=0, after=0):
"""实现了一个ssh input, 提示用户输入, 获取并返回
:return user input string
"""
input_data = []
parser = TtyIOParser()
client.send(wrap_with_line_feed(prompt, before=0, after=0))
client.send(wrap_with_line_feed(prompt, before=before, after=after))
while True:
data = client.recv(10)
......@@ -355,4 +356,67 @@ def net_input(client, prompt='Opt> ', sensitive=False):
input_data.append(data)
def register_app(app):
stack['app'] = app
def register_service(service):
stack['service'] = service
zh_pattern = re.compile(r'[\u4e00-\u9fa5]')
def find_chinese(s):
return zh_pattern.findall(s)
def align_with_zh(s, length, addin=' '):
if not isinstance(s, str):
s = str(s)
zh_len = len(find_chinese(s))
padding = length - (len(s) - zh_len) - zh_len*2
padding_content = ''
if padding > 0:
padding_content = addin*padding
return s + padding_content
def format_with_zh(size_list, *args):
data = []
for length, s in zip(size_list, args):
data.append(align_with_zh(s, length))
return ' '.join(data)
def size_of_str_with_zh(s):
if isinstance(s, int):
s = str(s)
try:
chinese = find_chinese(s)
except TypeError:
print(type(s))
raise
return len(s) + len(chinese)
def item_max_length(_iter, maxi=None, mini=None, key=None):
if key:
_iter = [key(i) for i in _iter]
length = [size_of_str_with_zh(s) for s in _iter]
if maxi:
length.append(maxi)
length = max(length)
if mini and length < mini:
length = mini
return length
def int_length(i):
return len(str(i))
ugettext = _gettext()
......@@ -28,7 +28,7 @@ psutil==5.4.1
pyasn1==0.4.2
pycparser==2.18
PyNaCl==1.2.1
pyte==0.7.0
pyte==0.8.0
python-dateutil==2.6.1
python-engineio==2.1.0
python-gssapi==0.6.4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment