Commit 4b741955 authored by ibuler's avatar ibuler

Update ssh_server to some class

parent 2c64b784
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
__version__ = '0.3.3'
import sys import sys
import os import os
import base64 import base64
import time
from binascii import hexlify from binascii import hexlify
import sys import sys
import threading import threading
...@@ -19,7 +23,6 @@ import select ...@@ -19,7 +23,6 @@ import select
import errno import errno
import paramiko import paramiko
import django import django
from paramiko.py3compat import b, u, decodebytes
BASE_DIR = os.path.abspath(os.path.dirname(__file__)) BASE_DIR = os.path.abspath(os.path.dirname(__file__))
APP_DIR = os.path.join(os.path.dirname(BASE_DIR), 'apps') APP_DIR = os.path.join(os.path.dirname(BASE_DIR), 'apps')
...@@ -33,13 +36,13 @@ except IndexError: ...@@ -33,13 +36,13 @@ except IndexError:
from django.conf import settings from django.conf import settings
from users.utils import ssh_key_gen, check_user_is_valid from users.utils import ssh_key_gen, check_user_is_valid
from utils import get_logger from utils import get_logger, SSHServerException
logger = get_logger(__name__) logger = get_logger(__name__)
class SSHServerInterface(paramiko.ServerInterface): class SSHServer(paramiko.ServerInterface):
host_key_path = os.path.join(BASE_DIR, 'host_rsa_key') host_key_path = os.path.join(BASE_DIR, 'host_rsa_key')
channel_pools = [] channel_pools = []
...@@ -47,7 +50,10 @@ class SSHServerInterface(paramiko.ServerInterface): ...@@ -47,7 +50,10 @@ class SSHServerInterface(paramiko.ServerInterface):
self.event = threading.Event() self.event = threading.Event()
self.client = client self.client = client
self.addr = addr self.addr = addr
self.username = None
self.user = None self.user = None
self.channel_width = None
self.channel_height = None
@classmethod @classmethod
def host_key(cls): def host_key(cls):
...@@ -73,34 +79,37 @@ class SSHServerInterface(paramiko.ServerInterface): ...@@ -73,34 +79,37 @@ class SSHServerInterface(paramiko.ServerInterface):
return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
def check_auth_password(self, username, password): def check_auth_password(self, username, password):
self.user = check_user_is_valid(username=username, password=password) self.user = user = check_user_is_valid(username=username, password=password)
self.username = username = user.username
if self.user: if self.user:
logger.info('Accepted password for %(user)s from %(host)s port %(port)s ' % { logger.info('Accepted password for %(username)s from %(host)s port %(port)s ' % {
'user': username, 'username': username,
'host': self.addr[0], 'host': self.addr[0],
'port': self.addr[1], 'port': self.addr[1],
}) })
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
else: else:
logger.info('Authentication password failed for %(user)s from %(host)s port %(port)s ' % { logger.info('Authentication password failed for %(username)s from %(host)s port %(port)s ' % {
'user': username, 'username': username,
'host': self.addr[0], 'host': self.addr[0],
'port': self.addr[1], 'port': self.addr[1],
}) })
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, public_key): def check_auth_publickey(self, username, public_key):
self.user = check_user_is_valid(username=username, public_key=public_key) self.user = user = check_user_is_valid(username=username, public_key=public_key)
self.username = username = user.username
if self.user: if self.user:
logger.info('Accepted public key for %(user)s from %(host)s port %(port)s ' % { logger.info('Accepted public key for %(username)s from %(host)s port %(port)s ' % {
'user': username, 'username': username,
'host': self.addr[0], 'host': self.addr[0],
'port': self.addr[1], 'port': self.addr[1],
}) })
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
else: else:
logger.info('Authentication public key failed for %(user)s from %(host)s port %(port)s ' % { logger.info('Authentication public key failed for %(username)s from %(host)s port %(port)s ' % {
'user': username, 'username': username,
'host': self.addr[0], 'host': self.addr[0],
'port': self.addr[1], 'port': self.addr[1],
}) })
...@@ -135,12 +144,18 @@ class BackendServer: ...@@ -135,12 +144,18 @@ class BackendServer:
self.ssh = None self.ssh = None
self.channel = None self.channel = None
def connect(self, term='xterm', width=80, height=24): def connect(self, term='xterm', width=80, height=24, timeout=10):
self.ssh = ssh = paramiko.SSHClient() self.ssh = ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(hostname=self.host, port=self.port, username=self.username, password=self.host_password, ssh.connect(hostname=self.host, port=self.port, username=self.username, password=self.host_password,
pkey=self.host_private_key, look_for_keys=False, allow_agent=True, compress=True) pkey=self.host_private_key, look_for_keys=False, allow_agent=True, compress=True, timeout=timeout)
self.channel = channel = ssh.invoke_shell(term=term, width=width, height=height) self.channel = channel = ssh.invoke_shell(term=term, width=width, height=height)
logger.info('Connect %(username)s@%(host)s:%(port)s successfully' % {
'username': self.username,
'host': self.host,
'port': self.port,
})
channel.settimeout(100)
return channel return channel
@property @property
...@@ -149,90 +164,108 @@ class BackendServer: ...@@ -149,90 +164,108 @@ class BackendServer:
@property @property
def host_private_key(self): def host_private_key(self):
return 'redhat' return None
class Navigation: class Navigation:
def __init__(self, username): def __init__(self, username, client_channel):
self.username = username self.username = username
self.client_channel = client_channel
def display_banner(self):
client_channel = self.client_channel
client_channel.send('\r\n\r\n\t\tWelcome to use Jumpserver open source system !\r\n\r\n')
client_channel.send('If use find some bug please contact us <ibuler@qq.com>\r\n')
# client_channel.send(self.username)
def display(self): def display(self):
self.display_banner()
def return_to_connect(self):
pass pass
class SSHServer: class JumpServer:
def __init__(self, host='127.0.0.1', port=2200): def __init__(self):
self.host = host self.listen_host = '0.0.0.0'
self.port = port self.listen_port = 2222
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.username = None
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.backend_host = None
self.sock.bind((self.host, self.port)) self.backend_port = None
self.server_ssh = None self.backend_username = None
self.server_channel = None self.backend_channel = None
self.client_channel = None self.client_channel = None
self.sock = None
def invoke_with_backend(self): def display_navigation(self, username, client_channel):
pass nav = Navigation(username, client_channel)
nav.display()
return '127.0.0.1', 22, 'root'
def display_navigation(self): def get_client_channel(self, client, addr):
pass transport = paramiko.Transport(client, gss_kex=False)
transport.set_gss_host(socket.getfqdn(""))
try:
transport.load_server_moduli()
except:
logger.warning('Failed to load moduli -- gex will be unsupported.')
raise
def make_client_channel(self): transport.add_server_key(SSHServer.get_host_key())
pass ssh_server = SSHServer(client, addr)
self.username = ssh_server.username
try:
transport.start_server(server=ssh_server)
except paramiko.SSHException:
logger.warning('SSH negotiation failed.')
self.client_channel = client_channel = transport.accept(20)
if client_channel is None:
logger.warning('No channel get.')
raise SSHServerException('No channel get.')
if not ssh_server.event.is_set():
logger.warning('Client never asked for a shell.')
raise SSHServerException('Client never asked for a shell.')
return client_channel
def get_backend_channel(self, host, port, username):
backend_server = BackendServer(host, port, username)
self.backend_channel = backend_channel = backend_server.connect()
if not backend_channel:
logger.warning('Connect %(username)s@%(host)s:%(port)s failed' % {
'username': username,
'host': host,
'port': port,
})
return backend_channel
def handle_ssh_request(self, client, addr): def handle_ssh_request(self, client, addr):
logger.info("Get connection from %(host)s:%(port)s" % { logger.info("Get ssh request from %(host)s:%(port)s" % {
'host': addr[0], 'host': addr[0],
'port': addr[1], 'port': addr[1],
}) })
try: try:
transport = paramiko.Transport(client, gss_kex=False) client_channel = self.get_client_channel(client, addr)
transport.set_gss_host(socket.getfqdn("")) host, port, username = self.display_navigation(self.username, client_channel)
try: backend_channel = self.get_backend_channel(host, port, username)
transport.load_server_moduli()
except:
logger.warning('(Failed to load moduli -- gex will be unsupported.)')
raise
transport.add_server_key(SSHServerInterface.get_host_key()) print(client_channel.get_id(), backend_channel.get_id())
ssh_interface = SSHServerInterface(client, addr)
try:
transport.start_server(server=ssh_interface)
except paramiko.SSHException:
print('*** SSH negotiation failed.')
return
self.client_channel = client_channel = transport.accept(20)
# self.client_channel = client_channel = transport.open_session()
# client_channel.get_pty(term='xterm')
if client_channel is None:
print('*** No channel.')
return
print('Authenticated!')
client_channel.settimeout(100)
client_channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
client_channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
client_channel.send('Happy birthday to Robot Dave!\r\n\r\n')
server_channel = self.connect()
if not ssh_interface.event.is_set():
print('*** Client never asked for a shell.')
return
while True: while True:
r, w, x = select.select([client_channel, server_channel], [], []) r, w, x = select.select([client_channel, backend_channel], [], [])
if client_channel in r: if client_channel in r:
data_client = client_channel.recv(1024) data_client = client_channel.recv(1024)
logger.info(data_client) logger.info(data_client)
if len(data_client) == 0: if len(data_client) == 0:
break break
# client_channel.send(data_client) backend_channel.send(data_client)
server_channel.send(data_client)
if server_channel in r: if backend_channel in r:
data_server = server_channel.recv(1024) data_server = backend_channel.recv(1024)
if len(data_server) == 0: if len(data_server) == 0:
break break
client_channel.send(data_server) client_channel.send(data_server)
...@@ -250,30 +283,35 @@ class SSHServer: ...@@ -250,30 +283,35 @@ class SSHServer:
# except IndexError: # except IndexError:
# pass # pass
except Exception: except IndexError:
logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1')) logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1'))
sys.exit(100) sys.exit(100)
def listen(self): def listen(self):
self.sock.listen(5) self.sock = sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print('Start ssh server %(host)s:%(port)s' % {'host': self.host, 'port': self.port}) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((self.listen_host, self.listen_port))
sock.listen(5)
print(time.ctime())
print('Jumpserver version %s, more see https://www.jumpserver.org' % __version__)
print('Starting ssh server at %(host)s:%(port)s' % {'host': self.listen_host, 'port': self.listen_port})
print('Quit the server with CONTROL-C.')
while True: while True:
try: try:
client, addr = self.sock.accept() client, addr = self.sock.accept()
print('Listening for connection ...')
# t = threading.Thread(target=self.handle_ssh_request, args=(client, addr))
t = process.Process(target=self.handle_ssh_request, args=(client, addr)) t = process.Process(target=self.handle_ssh_request, args=(client, addr))
t.daemon = True t.daemon = True
t.start() t.start()
except Exception as e: except Exception as e:
print('*** Bind failed: ' + str(e)) logger.error('Bind failed: ' + str(e))
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
if __name__ == '__main__': if __name__ == '__main__':
server = SSHServer(host='', port=2200) server = JumpServer()
try: try:
server.listen() server.listen()
except KeyboardInterrupt: except KeyboardInterrupt:
......
...@@ -15,3 +15,5 @@ def get_logger(name): ...@@ -15,3 +15,5 @@ def get_logger(name):
return logging.getLogger('jumpserver.%s' % name) return logging.getLogger('jumpserver.%s' % name)
class SSHServerException(Exception):
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment