Commit ebb30424 authored by ibuler's avatar ibuler

Use process except thread

parent 216163f4
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
import logging
import os
BASE_DIR = os.path.dirname(os.path.abspath(__name__))
class Config:
SSH_HOST = ''
SSH_PORT = 2200
LOG_LEVEL = 'INFO'
LOG_DIR = os.path.join(BASE_DIR, 'logs')
LOG_FILENAME = 'ssh_server.log'
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'verbose': {
'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
},
'main': {
'datefmt': '%Y-%m-%d %H:%M:%S',
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
},
'simple': {
'format': '%(levelname)s %(message)s'
},
},
'handlers': {
'null': {
'level': 'DEBUG',
'class': 'logging.NullHandler',
},
'console': {
'level': 'DEBUG',
'class': 'logging.StreamHandler',
'formatter': 'main',
'stream': 'ext://sys.stdout',
},
'file': {
'level': 'DEBUG',
'class': 'logging.FileHandler',
'formatter': 'main',
'mode': 'a',
'filename': os.path.join(LOG_DIR, LOG_FILENAME),
},
},
'loggers': {
'jumpserver': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
'propagate': True,
},
'jumpserver.web_ssh_server': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
'propagate': True,
},
'jumpserver.ssh_server': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
'propagate': True,
}
}
}
def __init__(self):
pass
def __getattr__(self, item):
return None
class DevelopmentConfig(Config):
pass
class ProductionConfig(Config):
pass
class TestingConfig(Config):
pass
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig,
}
env = 'default'
if __name__ == '__main__':
pass
...@@ -7,20 +7,14 @@ import os ...@@ -7,20 +7,14 @@ import os
BASE_DIR = os.path.dirname(os.path.abspath(__name__)) BASE_DIR = os.path.dirname(os.path.abspath(__name__))
LOG_LEVEL_CHOICES = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'critical': logging.CRITICAL
}
class Config: class Config:
LOG_LEVEL = '' LOG_LEVEL = 'INFO'
LOG_DIR = os.path.join(BASE_DIR, 'logs') LOG_DIR = os.path.join(BASE_DIR, 'logs')
LOGGING = { LOGGING = {
'version': 1, 'version': 1,
'disable_existing_loggers': False,
'formatters': { 'formatters': {
'verbose': { 'verbose': {
'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s' 'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
...@@ -47,35 +41,23 @@ class Config: ...@@ -47,35 +41,23 @@ class Config:
'level': 'DEBUG', 'level': 'DEBUG',
'class': 'logging.FileHandler', 'class': 'logging.FileHandler',
'formatter': 'main', 'formatter': 'main',
'filename': os.path.join(PROJECT_DIR, 'logs', 'jumpserver.log') 'filename': LOG_DIR,
}, },
}, },
'loggers': { 'loggers': {
'django': {
'handlers': ['null'],
'propagate': False,
'level': LOG_LEVEL,
},
'django.request': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL,
'propagate': False,
},
'django.server': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL,
'propagate': False,
},
'jumpserver': { 'jumpserver': {
'handlers': ['console', 'file'], 'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL, 'level': LOG_LEVEL,
}, },
'jumpserver.users.api': { 'jumpserver.web_ssh_server': {
'handlers': ['console', 'file'], 'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL, 'level': LOG_LEVEL,
}, },
'jumpserver.users.view': { 'jumpserver.ssh_server': {
'handlers': ['console', 'file'], 'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL, 'level': LOG_LEVEL,
} }
} }
...@@ -88,6 +70,27 @@ class Config: ...@@ -88,6 +70,27 @@ class Config:
return None return None
class DevelopmentConfig(Config):
pass
class ProductionConfig(Config):
pass
class TestingConfig(Config):
pass
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig,
}
env = 'default'
if __name__ == '__main__': if __name__ == '__main__':
pass pass
...@@ -7,6 +7,7 @@ import base64 ...@@ -7,6 +7,7 @@ import base64
from binascii import hexlify from binascii import hexlify
import sys import sys
import threading import threading
from multiprocessing import process
import traceback import traceback
import tty import tty
import termios import termios
...@@ -31,17 +32,21 @@ except IndexError: ...@@ -31,17 +32,21 @@ except IndexError:
pass pass
from django.conf import settings from django.conf import settings
from common.utils import get_logger
from users.utils import ssh_key_gen, check_user_is_valid from users.utils import ssh_key_gen, check_user_is_valid
from utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
class SSHServerInterface(paramiko.ServerInterface): class SSHServerInterface(paramiko.ServerInterface):
host_key_path = os.path.join(BASE_DIR, 'host_rsa_key') host_key_path = os.path.join(BASE_DIR, 'host_rsa_key')
channel_pools = []
def __init__(self): def __init__(self, client, addr):
self.event = threading.Event() self.event = threading.Event()
self.client = client
self.addr = addr
self.user = None self.user = None
@classmethod @classmethod
...@@ -70,19 +75,35 @@ class SSHServerInterface(paramiko.ServerInterface): ...@@ -70,19 +75,35 @@ class SSHServerInterface(paramiko.ServerInterface):
def check_auth_password(self, username, password): def check_auth_password(self, username, password):
self.user = check_user_is_valid(username=username, password=password) self.user = check_user_is_valid(username=username, password=password)
if self.user: if self.user:
logger.info('User: %s password auth passed' % username) logger.info('Accepted password for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
else: else:
logger.warning('User: %s password auth failed' % username) logger.info('Authentication password failed for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, public_key): def check_auth_publickey(self, username, public_key):
self.user = check_user_is_valid(username=username, public_key=public_key) self.user = check_user_is_valid(username=username, public_key=public_key)
if self.user: if self.user:
logger.info('User: %s public key auth passed' % username) logger.info('Accepted public key for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
else: else:
logger.warning('User: %s public key auth failed' % username) logger.info('Authentication public key failed for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
def get_allowed_auths(self, username): def get_allowed_auths(self, username):
...@@ -95,12 +116,20 @@ class SSHServerInterface(paramiko.ServerInterface): ...@@ -95,12 +116,20 @@ class SSHServerInterface(paramiko.ServerInterface):
def check_channel_shell_request(self, channel): def check_channel_shell_request(self, channel):
self.event.set() self.event.set()
self.__class__.channel_pools.append(channel)
return True return True
def check_channel_pty_request(self, channel, term, width, height, pixelwidth, def check_channel_pty_request(self, channel, term, width, height, pixelwidth,
pixelheight, modes): pixelheight, modes):
return True return True
def check_channel_window_change_request(self, channel, width, height, pixelwidth, pixelheight):
logger.info('Change window size %s * %s' % (width, height))
logger.info('Change length %s ' % len(self.__class__.channel_pools))
# for channel in self.__class__.channel_pools:
# channel.send("Hello world")
return True
class SSHServer: class SSHServer:
def __init__(self, host='127.0.0.1', port=2200): def __init__(self, host='127.0.0.1', port=2200):
...@@ -110,18 +139,22 @@ class SSHServer: ...@@ -110,18 +139,22 @@ class SSHServer:
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((self.host, self.port)) self.sock.bind((self.host, self.port))
self.server_ssh = None self.server_ssh = None
self.server_chan = None self.server_channel = None
self.client_channel = None
def connect(self): def connect(self):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(hostname='127.0.0.1', port=22, username='root', password='redhat') ssh.connect(hostname='127.0.0.1', port=22, username='root', password='redhat')
self.server_ssh = ssh self.server_ssh = ssh
self.server_chan = channel = ssh.invoke_shell(term='xterm') self.server_channel = channel = ssh.invoke_shell(term='xterm')
return channel return channel
def handle_ssh_request(self, client, addr): def handle_ssh_request(self, client, addr):
logger.info("Get connection from " + str(addr)) logger.info("Get connection from %(host)s:%(port)s" % {
'host': addr[0],
'port': addr[1],
})
try: try:
transport = paramiko.Transport(client, gss_kex=False) transport = paramiko.Transport(client, gss_kex=False)
transport.set_gss_host(socket.getfqdn("")) transport.set_gss_host(socket.getfqdn(""))
...@@ -132,70 +165,63 @@ class SSHServer: ...@@ -132,70 +165,63 @@ class SSHServer:
raise raise
transport.add_server_key(SSHServerInterface.get_host_key()) transport.add_server_key(SSHServerInterface.get_host_key())
ssh_interface = SSHServerInterface() ssh_interface = SSHServerInterface(client, addr)
try: try:
transport.start_server(server=ssh_interface) transport.start_server(server=ssh_interface)
except paramiko.SSHException: except paramiko.SSHException:
print('*** SSH negotiation failed.') print('*** SSH negotiation failed.')
return return
channel = transport.accept(20) self.client_channel = client_channel = transport.accept(20)
if channel is None: if client_channel is None:
print('*** No channel.') print('*** No channel.')
return return
print('Authenticated!') print('Authenticated!')
channel.settimeout(100) client_channel.settimeout(100)
channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') client_channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n') client_channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
channel.send('Happy birthday to Robot Dave!\r\n\r\n') client_channel.send('Happy birthday to Robot Dave!\r\n\r\n')
server_channel = self.connect() server_channel = self.connect()
if not ssh_interface.event.is_set(): if not ssh_interface.event.is_set():
print('*** Client never asked for a shell.') print('*** Client never asked for a shell.')
return return
server_data = []
input_mode = True
while True: while True:
r, w, e = select.select([server_channel, channel], [], []) r, w, x = select.select([client_channel, server_channel], [], [])
if channel in r: if client_channel in r:
recv_data = channel.recv(1024).decode('utf8') data_client = client_channel.recv(1024)
# print("From client: " + repr(recv_data)) logger.info(data_client)
if len(recv_data) == 0: if len(data_client) == 0:
break break
server_channel.send(recv_data) # client_channel.send(data_client)
server_channel.send(data_client)
if server_channel in r: if server_channel in r:
recv_data = server_channel.recv(1024).decode('utf8') data_server = server_channel.recv(1024)
# print("From server: " + repr(recv_data)) if len(data_server) == 0:
if len(recv_data) == 0:
break break
channel.send(recv_data) client_channel.send(data_server)
if len(recv_data) > 20:
server_data.append('...') # if len(recv_data) > 20:
else: # server_data.append('...')
server_data.append(recv_data) # else:
try: # server_data.append(recv_data)
if repr(server_data[-2]) == u'\r\n': # try:
result = server_data.pop() # if repr(server_data[-2]) == u'\r\n':
server_data.pop() # result = server_data.pop()
command = ''.join(server_data) # server_data.pop()
server_data = [] # command = ''.join(server_data)
print(">>> Command: %s" % command) # server_data = []
print(result) # except IndexError:
except IndexError: # pass
pass
print(server_data) except Exception:
client_channel.close()
except Exception as e: server_channel.close()
print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e)) logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1'))
traceback.print_exc()
try:
transport.close()
except:
pass
sys.exit(1)
def listen(self): def listen(self):
self.sock.listen(5) self.sock.listen(5)
...@@ -204,7 +230,9 @@ class SSHServer: ...@@ -204,7 +230,9 @@ class SSHServer:
try: try:
client, addr = self.sock.accept() client, addr = self.sock.accept()
print('Listening for connection ...') print('Listening for connection ...')
t = threading.Thread(target=self.handle_ssh_request, args=(client, addr)) # t = threading.Thread(target=self.handle_ssh_request, args=(client, addr))
t = process.Process(target=self.handle_ssh_request, args=(client, addr))
t.daemon = True t.daemon = True
t.start() t.start()
except Exception as e: except Exception as e:
......
...@@ -3,6 +3,15 @@ ...@@ -3,6 +3,15 @@
# #
import logging import logging
from logging.config import dictConfig
from ssh_config import config, env
CONFIG_SSH_SERVER = config.get(env)
def get_logger(name):
dictConfig(CONFIG_SSH_SERVER.LOGGING)
return logging.getLogger('jumpserver.%s' % name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment