Commit ebb30424 authored by ibuler's avatar ibuler

Use process except thread

parent 216163f4
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
import logging
import os
BASE_DIR = os.path.dirname(os.path.abspath(__name__))
class Config:
SSH_HOST = ''
SSH_PORT = 2200
LOG_LEVEL = 'INFO'
LOG_DIR = os.path.join(BASE_DIR, 'logs')
LOG_FILENAME = 'ssh_server.log'
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'verbose': {
'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
},
'main': {
'datefmt': '%Y-%m-%d %H:%M:%S',
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
},
'simple': {
'format': '%(levelname)s %(message)s'
},
},
'handlers': {
'null': {
'level': 'DEBUG',
'class': 'logging.NullHandler',
},
'console': {
'level': 'DEBUG',
'class': 'logging.StreamHandler',
'formatter': 'main',
'stream': 'ext://sys.stdout',
},
'file': {
'level': 'DEBUG',
'class': 'logging.FileHandler',
'formatter': 'main',
'mode': 'a',
'filename': os.path.join(LOG_DIR, LOG_FILENAME),
},
},
'loggers': {
'jumpserver': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
'propagate': True,
},
'jumpserver.web_ssh_server': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
'propagate': True,
},
'jumpserver.ssh_server': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
'propagate': True,
}
}
}
def __init__(self):
pass
def __getattr__(self, item):
return None
class DevelopmentConfig(Config):
pass
class ProductionConfig(Config):
pass
class TestingConfig(Config):
pass
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig,
}
env = 'default'
if __name__ == '__main__':
pass
......@@ -7,20 +7,14 @@ import os
BASE_DIR = os.path.dirname(os.path.abspath(__name__))
LOG_LEVEL_CHOICES = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'critical': logging.CRITICAL
}
class Config:
LOG_LEVEL = ''
LOG_LEVEL = 'INFO'
LOG_DIR = os.path.join(BASE_DIR, 'logs')
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'verbose': {
'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
......@@ -47,35 +41,23 @@ class Config:
'level': 'DEBUG',
'class': 'logging.FileHandler',
'formatter': 'main',
'filename': os.path.join(PROJECT_DIR, 'logs', 'jumpserver.log')
'filename': LOG_DIR,
},
},
'loggers': {
'django': {
'handlers': ['null'],
'propagate': False,
'level': LOG_LEVEL,
},
'django.request': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL,
'propagate': False,
},
'django.server': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL,
'propagate': False,
},
'jumpserver': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
},
'jumpserver.users.api': {
'jumpserver.web_ssh_server': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
},
'jumpserver.users.view': {
'jumpserver.ssh_server': {
'handlers': ['console', 'file'],
# 'level': LOG_LEVEL_CHOICES.get(LOG_LEVEL, None) or LOG_LEVEL_CHOICES.get('info')
'level': LOG_LEVEL,
}
}
......@@ -88,6 +70,27 @@ class Config:
return None
class DevelopmentConfig(Config):
pass
class ProductionConfig(Config):
pass
class TestingConfig(Config):
pass
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'testing': TestingConfig,
'default': DevelopmentConfig,
}
env = 'default'
if __name__ == '__main__':
pass
......@@ -7,6 +7,7 @@ import base64
from binascii import hexlify
import sys
import threading
from multiprocessing import process
import traceback
import tty
import termios
......@@ -31,17 +32,21 @@ except IndexError:
pass
from django.conf import settings
from common.utils import get_logger
from users.utils import ssh_key_gen, check_user_is_valid
from utils import get_logger
logger = get_logger(__name__)
class SSHServerInterface(paramiko.ServerInterface):
host_key_path = os.path.join(BASE_DIR, 'host_rsa_key')
channel_pools = []
def __init__(self):
def __init__(self, client, addr):
self.event = threading.Event()
self.client = client
self.addr = addr
self.user = None
@classmethod
......@@ -70,19 +75,35 @@ class SSHServerInterface(paramiko.ServerInterface):
def check_auth_password(self, username, password):
self.user = check_user_is_valid(username=username, password=password)
if self.user:
logger.info('User: %s password auth passed' % username)
logger.info('Accepted password for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_SUCCESSFUL
else:
logger.warning('User: %s password auth failed' % username)
logger.info('Authentication password failed for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, public_key):
self.user = check_user_is_valid(username=username, public_key=public_key)
if self.user:
logger.info('User: %s public key auth passed' % username)
logger.info('Accepted public key for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_SUCCESSFUL
else:
logger.warning('User: %s public key auth failed' % username)
logger.info('Authentication public key failed for %(user)s from %(host)s port %(port)s ' % {
'user': username,
'host': self.addr[0],
'port': self.addr[1],
})
return paramiko.AUTH_FAILED
def get_allowed_auths(self, username):
......@@ -95,12 +116,20 @@ class SSHServerInterface(paramiko.ServerInterface):
def check_channel_shell_request(self, channel):
self.event.set()
self.__class__.channel_pools.append(channel)
return True
def check_channel_pty_request(self, channel, term, width, height, pixelwidth,
pixelheight, modes):
return True
def check_channel_window_change_request(self, channel, width, height, pixelwidth, pixelheight):
logger.info('Change window size %s * %s' % (width, height))
logger.info('Change length %s ' % len(self.__class__.channel_pools))
# for channel in self.__class__.channel_pools:
# channel.send("Hello world")
return True
class SSHServer:
def __init__(self, host='127.0.0.1', port=2200):
......@@ -110,18 +139,22 @@ class SSHServer:
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((self.host, self.port))
self.server_ssh = None
self.server_chan = None
self.server_channel = None
self.client_channel = None
def connect(self):
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(hostname='127.0.0.1', port=22, username='root', password='redhat')
self.server_ssh = ssh
self.server_chan = channel = ssh.invoke_shell(term='xterm')
self.server_channel = channel = ssh.invoke_shell(term='xterm')
return channel
def handle_ssh_request(self, client, addr):
logger.info("Get connection from " + str(addr))
logger.info("Get connection from %(host)s:%(port)s" % {
'host': addr[0],
'port': addr[1],
})
try:
transport = paramiko.Transport(client, gss_kex=False)
transport.set_gss_host(socket.getfqdn(""))
......@@ -132,70 +165,63 @@ class SSHServer:
raise
transport.add_server_key(SSHServerInterface.get_host_key())
ssh_interface = SSHServerInterface()
ssh_interface = SSHServerInterface(client, addr)
try:
transport.start_server(server=ssh_interface)
except paramiko.SSHException:
print('*** SSH negotiation failed.')
return
channel = transport.accept(20)
if channel is None:
self.client_channel = client_channel = transport.accept(20)
if client_channel is None:
print('*** No channel.')
return
print('Authenticated!')
channel.settimeout(100)
client_channel.settimeout(100)
channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
channel.send('Happy birthday to Robot Dave!\r\n\r\n')
client_channel.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
client_channel.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
client_channel.send('Happy birthday to Robot Dave!\r\n\r\n')
server_channel = self.connect()
if not ssh_interface.event.is_set():
print('*** Client never asked for a shell.')
return
server_data = []
input_mode = True
while True:
r, w, e = select.select([server_channel, channel], [], [])
r, w, x = select.select([client_channel, server_channel], [], [])
if channel in r:
recv_data = channel.recv(1024).decode('utf8')
# print("From client: " + repr(recv_data))
if len(recv_data) == 0:
if client_channel in r:
data_client = client_channel.recv(1024)
logger.info(data_client)
if len(data_client) == 0:
break
server_channel.send(recv_data)
# client_channel.send(data_client)
server_channel.send(data_client)
if server_channel in r:
recv_data = server_channel.recv(1024).decode('utf8')
# print("From server: " + repr(recv_data))
if len(recv_data) == 0:
data_server = server_channel.recv(1024)
if len(data_server) == 0:
break
channel.send(recv_data)
if len(recv_data) > 20:
server_data.append('...')
else:
server_data.append(recv_data)
try:
if repr(server_data[-2]) == u'\r\n':
result = server_data.pop()
server_data.pop()
command = ''.join(server_data)
server_data = []
print(">>> Command: %s" % command)
print(result)
except IndexError:
pass
print(server_data)
except Exception as e:
print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc()
try:
transport.close()
except:
pass
sys.exit(1)
client_channel.send(data_server)
# if len(recv_data) > 20:
# server_data.append('...')
# else:
# server_data.append(recv_data)
# try:
# if repr(server_data[-2]) == u'\r\n':
# result = server_data.pop()
# server_data.pop()
# command = ''.join(server_data)
# server_data = []
# except IndexError:
# pass
except Exception:
client_channel.close()
server_channel.close()
logger.info('Close with server %s from %s' % ('127.0.0.1', '127.0.0.1'))
def listen(self):
self.sock.listen(5)
......@@ -204,7 +230,9 @@ class SSHServer:
try:
client, addr = self.sock.accept()
print('Listening for connection ...')
t = threading.Thread(target=self.handle_ssh_request, args=(client, addr))
# t = threading.Thread(target=self.handle_ssh_request, args=(client, addr))
t = process.Process(target=self.handle_ssh_request, args=(client, addr))
t.daemon = True
t.start()
except Exception as e:
......
......@@ -3,6 +3,15 @@
#
import logging
from logging.config import dictConfig
from ssh_config import config, env
CONFIG_SSH_SERVER = config.get(env)
def get_logger(name):
dictConfig(CONFIG_SSH_SERVER.LOGGING)
return logging.getLogger('jumpserver.%s' % name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment