Commit f2e5c268 authored by ibuler's avatar ibuler

[Feature] 添加ssh interface

parent ecf43f40
import os import os
import time import time
import threading import threading
from queue import Queue
from .config import Config from .config import Config
from .sshd import SSHServer from .sshd import SSHServer
...@@ -30,16 +31,24 @@ class Coco: ...@@ -30,16 +31,24 @@ class Coco:
'SSH_PASSWORD_AUTH': True, 'SSH_PASSWORD_AUTH': True,
'SSH_PUBLIC_KEY_AUTH': True, 'SSH_PUBLIC_KEY_AUTH': True,
'HEARTBEAT_INTERVAL': 5, 'HEARTBEAT_INTERVAL': 5,
'MAX_CONNECTIONS': 500,
} }
def __init__(self, name=None): def __init__(self, name=None, root_path=None):
self.config = self.config_class(BASE_DIR, defaults=self.default_config) self.config = self.config_class(BASE_DIR, defaults=self.default_config)
self.sessions = [] self.sessions = []
self.connections = []
self.sshd = SSHServer(self)
self.ws = None
if name: if name:
self.name = name self.name = name
else: else:
self.name = self.config['NAME'] self.name = self.config['NAME']
if root_path is None:
self.make_logger() self.make_logger()
def make_logger(self): def make_logger(self):
...@@ -61,20 +70,31 @@ class Coco: ...@@ -61,20 +70,31 @@ class Coco:
'host': self.config['BIND_HOST'], 'port': self.config['WS_PORT']}) 'host': self.config['BIND_HOST'], 'port': self.config['WS_PORT']})
print('Quit the server with CONTROL-C.') print('Quit the server with CONTROL-C.')
exit_queue = Queue()
try: try:
self.run_sshd() if self.config["SSHD_PORT"] != 0:
self.run_ws() self.run_sshd()
if self.config['WS_PORT'] != 0:
self.run_ws()
if exit_queue.get():
self.shutdown()
except KeyboardInterrupt: except KeyboardInterrupt:
self.shutdown() self.shutdown()
def run_sshd(self): def run_sshd(self):
thread = threading.Thread(target=SSHServer(self).run, args=()) thread = threading.Thread(target=self.sshd.run, args=())
thread.daemon = True
thread.start()
def run_ws(self): def run_ws(self):
pass pass
def shutdown(self): def shutdown(self):
pass print("Grace shutdown the server")
self.sshd.shutdown()
def monitor_session(self): def monitor_session(self):
pass pass
......
#!coding: utf-8
BACKSPACE_CHAR = {b'\x08': b'\x08\x1b[K', b'\x7f': b'\x08\x1b[K'}
ENTER_CHAR = [b'\r', b'\n', b'\r\n']
UNSUPPORTED_CHAR = {b'\x15': 'Ctrl-U', b'\x0c': 'Ctrl-L', b'\x05': 'Ctrl-E'}
CLEAR_CHAR = b'\x1b[H\x1b[2J'
BELL_CHAR = b'\x07'
#!coding: utf-8
import socket
from . import char
class InteractiveServer:
def __init__(self, app, request, chan):
self.app = app
self.request = request
self.client = chan
def display_banner(self):
self.client.send(char.CLEAR_CHAR)
banner = u"""\n\033[1;32m %s, 欢迎使用Jumpserver开源跳板机系统 \033[0m\r\n\r
1) 输入 \033[32mID\033[0m 直接登录 或 输入\033[32m部分 IP,主机名,备注\033[0m 进行搜索登录(如果唯一).\r
2) 输入 \033[32m/\033[0m + \033[32mIP, 主机名 or 备注 \033[0m搜索. 如: /ip\r
3) 输入 \033[32mP/p\033[0m 显示您有权限的主机.\r
4) 输入 \033[32mG/g\033[0m 显示您有权限的主机组.\r
5) 输入 \033[32mG/g\033[0m\033[0m + \033[32m组ID\033[0m 显示该组下主机. 如: g1\r
6) 输入 \033[32mE/e\033[0m 批量执行命令.(未完成)\r
7) 输入 \033[32mU/u\033[0m 批量上传文件.(未完成)\r
8) 输入 \033[32mD/d\033[0m 批量下载文件.(未完成)\r
9) 输入 \033[32mH/h\033[0m 帮助.\r
0) 输入 \033[32mQ/q\033[0m 退出.\r\n""" % self.request.user
self.client.send(banner)
def get_input(self, prompt='Opt> '):
pass
def dispatch(self):
pass
def run(self):
self.display_banner()
while True:
try:
self.dispatch()
except socket.error:
break
self.close()
def close(self):
pass
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import logging import logging
import paramiko import paramiko
import threading
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
...@@ -15,12 +16,13 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -15,12 +16,13 @@ class SSHInterface(paramiko.ServerInterface):
https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py https://github.com/paramiko/paramiko/blob/master/demos/demo_server.py
""" """
def __init__(self, app, *args, *kwargs): def __init__(self, app, request):
self.app = app self.app = app
self.request = request
self.event = threading.Event()
def check_auth_interactive(self, username, submethods): def check_auth_interactive(self, username, submethods):
""" """
:param username: the username of the authenticating client :param username: the username of the authenticating client
:param submethods: a comma-separated list of methods preferred :param submethods: a comma-separated list of methods preferred
by the client (usually empty) by the client (usually empty)
...@@ -32,6 +34,7 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -32,6 +34,7 @@ class SSHInterface(paramiko.ServerInterface):
def check_auth_interactive_response(self, responses): def check_auth_interactive_response(self, responses):
logger.info("Check auth interactive response: %s " % responses) logger.info("Check auth interactive response: %s " % responses)
# TODO:MFA Auth
pass pass
def enable_auth_gssapi(self): def enable_auth_gssapi(self):
...@@ -52,32 +55,52 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -52,32 +55,52 @@ class SSHInterface(paramiko.ServerInterface):
def validate_auth(self, username, password="", key=""): def validate_auth(self, username, password="", key=""):
# Todo: Implement it # Todo: Implement it
self.request.user = "guang"
return True return True
def check_channel_direct_tcpip_request(self, chanid, origin, destination): def check_channel_direct_tcpip_request(self, chanid, origin, destination):
logger.info("Check channel direct tcpip request: %d %s %s" % logger.debug("Check channel direct tcpip request: %d %s %s" %
(chanid, origin, destination)) (chanid, origin, destination))
self.request.type = 'direct-tcpip'
self.request.meta = {
'chanid': chanid, 'origin': origin,
'destination': destination,
}
self.event.set()
return 0 return 0
def check_channel_env_request(self, channel, name, value): def check_channel_env_request(self, channel, name, value):
logger.info("Check channel env request: %s, %s, %s" % logger.debug("Check channel env request: %s, %s, %s" %
(channel, name, value)) (channel, name, value))
return False return False
def check_channel_exec_request(self, channel, command): def check_channel_exec_request(self, channel, command):
logger.info("Check channel exec request: %s `%s`" % logger.debug("Check channel exec request: %s `%s`" %
(channel, command)) (channel, command))
self.request.type = 'exec'
self.request.meta = {'channel': channel, 'command': command}
self.event.set()
return False return False
def check_channel_forward_agent_request(self, channel): def check_channel_forward_agent_request(self, channel):
logger.info("Check channel forward agent request: %s" % channel) logger.debug("Check channel forward agent request: %s" % channel)
self.request.type = "forward-agent"
self.request.meta = {'channel': channel}
self.event.set()
return False return False
def check_channel_pty_request( def check_channel_pty_request(
self, channel, term, width, height, self, channel, term, width, height,
pixelwidth, pixelheight, modes): pixelwidth, pixelheight, modes):
logger.info("Check channel pty request: %s %s %s %s %s %s %s" % logger.debug("Check channel pty request: %s %s %s %s %s %s" %
(channel, term, width, height, pixelwidth,pixelheight, modes)) (channel, term, width, height, pixelwidth, pixelheight))
self.request.type = 'pty'
self.request.meta = {
'channel': channel, 'term': term, 'width': width,
'height': height, 'pixelwidth': pixelwidth,
'pixelheight': pixelheight, 'models': modes,
}
self.event.set()
return True return True
def check_channel_request(self, kind, chanid): def check_channel_request(self, kind, chanid):
...@@ -90,23 +113,39 @@ class SSHInterface(paramiko.ServerInterface): ...@@ -90,23 +113,39 @@ class SSHInterface(paramiko.ServerInterface):
def check_channel_subsystem_request(self, channel, name): def check_channel_subsystem_request(self, channel, name):
logger.info("Check channel subsystem request: %s %s" % (channel, name)) logger.info("Check channel subsystem request: %s %s" % (channel, name))
self.request.type = 'subsystem'
self.request.meta = {'channel': channel, 'name': name}
self.event.set()
return False return False
def check_channel_window_change_request( def check_channel_window_change_request(self, channel, width, height,
self, channel, width, height, pixelwidth, pixelheight): pixelwidth, pixelheight):
# Todo: implement window size change self.request.meta['width'] = width
self.request.meta['height'] = height
self.request.meta['pixelwidth'] = pixelwidth
self.request.meta['pixelheight'] = pixelheight
self.request.change_size_event.set()
return True return True
def check_channel_x11_request( def check_channel_x11_request(self, channel, single_connection,
self, channel, single_connection, auth_protocol, auth_cookie, auth_protocol, auth_cookie, screen_number):
screen_number):
logger.info("Check channel x11 request %s %s %s %s %s" % logger.info("Check channel x11 request %s %s %s %s %s" %
(channel, single_connection, auth_protocol, (channel, single_connection, auth_protocol,
auth_cookie, screen_number)) auth_cookie, screen_number))
self.request.type = 'x11'
self.request.meta = {
'channel': channel, 'single_connection': single_connection,
'auth_protocol': auth_protocol, 'auth_cookie': auth_cookie,
'screen_number': screen_number,
}
self.event.set()
return False return False
def check_port_forward_request(self, address, port): def check_port_forward_request(self, address, port):
logger.info("Check channel subsystem request: %s %s" % (address, port)) logger.info("Check channel port forward request: %s %s" % (address, port))
self.request.type = 'port-forward'
self.request.meta = {'address': address, 'port': port}
self.event.set()
return False return False
def get_banner(self): def get_banner(self):
......
#! coding: utf-8 #! coding: utf-8
import os
import logging import logging
import socket import socket
import threading
import paramiko
import sys
from .utils import ssh_key_gen
from .interface import SSHInterface
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
BACKLOG = 5 BACKLOG = 5
class SSHServer: class Request:
def __init__(self, client, addr):
self.type = ""
self.meta = {}
self.client = client
self.addr = addr
self.user = None
self.change_size_event = threading.Event()
self.win_size = {}
class SSHServer:
def __init__(self, app=None): def __init__(self, app=None):
self.app = app self.app = app
self.stop_event = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.host_key_path = os.path.join(self.app.root_path, 'keys', 'host_rsa_key')
self.host_key = self.get_host_key()
def run(self): def listen(self):
host = self.app.config["BIND_HOST"] host = self.app.config["BIND_HOST"]
port = self.app.config["SSHD_PORT"] port = self.app.config["SSHD_PORT"]
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) print('Starting shh server at %(host)s:%(port)s' %
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(BACKLOG)
print('Starting ssh server at %(host)s:%(port)s' %
{"host": host, "port": port}) {"host": host, "port": port})
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((host, port))
self.sock.listen(BACKLOG)
def get_host_key(self):
if not os.path.isfile(self.host_key_path):
self.gen_host_key()
return paramiko.RSAKey(filename=self.host_key_path)
def gen_host_key(self):
ssh_key, _ = ssh_key_gen()
with open(self.host_key_path, 'w') as f:
f.write(ssh_key)
def run(self):
self.listen()
max_conn_num = self.app['MAX_CONNECTIONS']
while not self.stop_event.is_set():
try:
client, addr = self.sock.accept()
logger.info("Get ssh request from %s: %s" % (addr[0], addr[1]))
if len(self.app.connections) >= max_conn_num:
client.close()
logger.warning("Arrive max connection number %s, "
"reject new request %s:%s" %
(max_conn_num, addr[0], addr[1]))
else:
self.app.connections.append((client, addr))
thread = threading.Thread(target=self.handle, args=(client, addr))
thread.daemon = True
thread.start()
except Exception as e:
logger.error("SSH server error: %s" % e)
def handle(self, client, addr):
transport = paramiko.Transport(client, gss_kex=False)
try:
transport.load_server_moduli()
except IOError:
logger.warning("Failed load moduli -- gex will be unsupported")
transport.add_server_key(self.host_key)
request = Request(client, addr)
server = SSHInterface(self.app, request)
try:
transport.start_server(server=server)
except paramiko.SSHException:
logger.warning("SSH negotiation failed.")
sys.exit(1)
chan = transport.accept(10)
if chan is None:
logger.warning("No ssh channel get")
sys.exit(1)
server.event.wait(5)
if not server.event.is_set():
logger.warning("Client not request a valid request")
sys.exit(2)
self.dispatch(request, chan)
def dispatch(self, request, channel):
if request.type == 'pty':
pass
elif request.type == 'exec':
pass
elif request.type == 'subsystem':
pass
else:
channel.send("Not support request type: %s" % request.type)
def shutdown(self): def shutdown(self):
pass self.stop_event.set()
#!coding: utf-8
import os
import paramiko
from io import StringIO
def ssh_key_string_to_obj(text):
key_f = StringIO(text)
key = None
try:
key = paramiko.RSAKey.from_private_key(key_f)
except paramiko.SSHException:
pass
try:
key = paramiko.DSSKey.from_private_key(key_f)
except paramiko.SSHException:
pass
return key
def ssh_pubkey_gen(private_key=None, username='jumpserver', hostname='localhost'):
if isinstance(private_key, str):
private_key = ssh_key_string_to_obj(private_key)
if not isinstance(private_key, (paramiko.RSAKey, paramiko.DSSKey)):
raise IOError('Invalid private key')
public_key = "%(key_type)s %(key_content)s %(username)s@%(hostname)s" % {
'key_type': private_key.get_name(),
'key_content': private_key.get_base64(),
'username': username,
'hostname': hostname,
}
return public_key
def ssh_key_gen(length=2048, type='rsa', password=None,
username='jumpserver', hostname=None):
"""Generate user ssh private and public key
Use paramiko RSAKey generate it.
:return private key str and public key str
"""
if hostname is None:
hostname = os.uname()[1]
f = StringIO()
try:
if type == 'rsa':
private_key_obj = paramiko.RSAKey.generate(length)
elif type == 'dsa':
private_key_obj = paramiko.DSSKey.generate(length)
else:
raise IOError('SSH private key must be `rsa` or `dsa`')
private_key_obj.write_private_key(f, password=password)
private_key = f.getvalue()
public_key = ssh_pubkey_gen(private_key_obj, username=username, hostname=hostname)
return private_key, public_key
except IOError:
raise IOError('These is error when generate ssh key.')
\ No newline at end of file
...@@ -17,6 +17,9 @@ APP_NAME = "coco" ...@@ -17,6 +17,9 @@ APP_NAME = "coco"
# 监听的WS端口号,默认5000 # 监听的WS端口号,默认5000
# WS_PORT = 5000 # WS_PORT = 5000
# 最大连接线程数
# MAX_CONNECTIONS = 500
# 是否开启DEBUG # 是否开启DEBUG
# DEBUG = True # DEBUG = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment