Commit 5b05f664 authored by ibuler's avatar ibuler

[Update] 继续支持多进程

parent 01e4eb72
# -*- coding: utf-8 -*-
#
from werkzeug.local import LocalProxy
from functools import partial
stack = {}
__db_sessions = []
def _find(name):
if stack.get(name):
return stack[name]
else:
raise ValueError("Not found in stack: {}".format(name))
current_app = LocalProxy(partial(_find, 'current_app'))
...@@ -5,7 +5,7 @@ import uuid ...@@ -5,7 +5,7 @@ import uuid
import socket import socket
from .service import app_service from .service import app_service
from .struct import SizedList, SelectEvent from .utils import SizedList, SelectEvent
from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \ from .utils import wrap_with_line_feed as wr, wrap_with_warning as warning, \
ugettext as _ ugettext as _
from . import char, utils from . import char, utils
......
...@@ -14,7 +14,7 @@ import jms_storage ...@@ -14,7 +14,7 @@ import jms_storage
from .config import config from .config import config
from .utils import get_logger, Singleton from .utils import get_logger, Singleton
from .struct import MemoryQueue from .utils import MemoryQueue
from .service import app_service from .service import app_service
logger = get_logger(__file__) logger = get_logger(__file__)
......
...@@ -13,7 +13,7 @@ except ImportError: ...@@ -13,7 +13,7 @@ except ImportError:
from .utils import get_logger, wrap_with_warning as warn, \ from .utils import get_logger, wrap_with_warning as warn, \
wrap_with_line_feed as wr, ugettext as _, ignore_error wrap_with_line_feed as wr, ugettext as _, ignore_error
from .service import app_service from .service import app_service
from .struct import SelectEvent from .utils import SelectEvent
from .recorder import get_recorder from .recorder import get_recorder
BUF_SIZE = 1024 BUF_SIZE = 1024
......
...@@ -6,6 +6,9 @@ import os ...@@ -6,6 +6,9 @@ import os
import socket import socket
import threading import threading
import time import time
import random
import multiprocessing
from multiprocessing.reduction import recv_handle, send_handle, DupFd
import paramiko import paramiko
...@@ -17,13 +20,14 @@ from coco.sftp import SFTPServer ...@@ -17,13 +20,14 @@ from coco.sftp import SFTPServer
from coco.config import config from coco.config import config
logger = get_logger(__file__) logger = get_logger(__file__)
current_socks = []
BACKLOG = 5 BACKLOG = 5
class SSHServer: class SSHServer:
def __init__(self): def __init__(self):
self.stop_evt = threading.Event() self.stop_evt = multiprocessing.Event()
self.workers = [] self.workers = []
self.pipe = None self.pipe = None
...@@ -40,7 +44,8 @@ class SSHServer: ...@@ -40,7 +44,8 @@ class SSHServer:
with open(key_path, 'w') as f: with open(key_path, 'w') as f:
f.write(ssh_key) f.write(ssh_key)
def run(self): def start_master(self, in_p, out_p, workers):
in_p.close()
host = config["BIND_HOST"] host = config["BIND_HOST"]
port = config["SSHD_PORT"] port = config["SSHD_PORT"]
print('Starting ssh server at {}:{}'.format(host, port)) print('Starting ssh server at {}:{}'.format(host, port))
...@@ -48,17 +53,54 @@ class SSHServer: ...@@ -48,17 +53,54 @@ class SSHServer:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
sock.bind((host, port)) sock.bind((host, port))
sock.listen(BACKLOG) sock.listen(BACKLOG)
while not self.stop_evt.is_set(): while True:
try: try:
client, addr = sock.accept() client, addr = sock.accept()
t = threading.Thread(target=self.handle_connection, args=(client, addr)) worker = random.choice(workers)
t.daemon = True send_handle(out_p, client.fileno(), worker.pid)
t.start()
except IndexError as e: except IndexError as e:
logger.error("Start SSH server error: {}".format(e)) logger.error("Start SSH server error: {}".format(e))
def start_worker(self, in_p, out_p):
out_p.close()
while True:
fd = recv_handle(in_p)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd)
# print("Recv sock: {}".format(sock))
addr = sock.getpeername()
thread = threading.Thread(
target=self.handle_connection, args=(sock, addr)
)
thread.daemon = True
thread.start()
def start_workers(self, in_p, out_p):
workers = []
for i in range(4):
worker = multiprocessing.Process(target=self.start_worker, args=(in_p, out_p))
worker.daemon = True
workers.append(worker)
worker.start()
return workers
def run(self):
c1, c2 = multiprocessing.Pipe()
self.pipe = (c1, c2)
workers = self.start_workers(c1, c2)
self.workers = workers
server_p = multiprocessing.Process(
target=self.start_master, args=(c1, c2, workers), name='master'
)
server_p.start()
server_p.join()
c1.close()
c2.close()
print("Exit")
def handle_connection(self, sock, addr): def handle_connection(self, sock, addr):
logger.debug("Handle new connection from: {}".format(addr)) logger.debug("Handle new connection from: {}".format(addr))
time.sleep(4)
print("Sock is closed: {} 2".format(sock._closed))
transport = paramiko.Transport(sock, gss_kex=False) transport = paramiko.Transport(sock, gss_kex=False)
try: try:
transport.load_server_moduli() transport.load_server_moduli()
...@@ -71,6 +113,7 @@ class SSHServer: ...@@ -71,6 +113,7 @@ class SSHServer:
) )
connection = Connection.new_connection(addr=addr, sock=sock) connection = Connection.new_connection(addr=addr, sock=sock)
server = SSHInterface(connection) server = SSHInterface(connection)
print("Sock is closed: {} 3".format(transport.sock._closed))
try: try:
transport.start_server(server=server) transport.start_server(server=server)
while transport.is_active(): while transport.is_active():
...@@ -94,9 +137,9 @@ class SSHServer: ...@@ -94,9 +137,9 @@ class SSHServer:
transport.close() transport.close()
except paramiko.SSHException as e: except paramiko.SSHException as e:
logger.warning("SSH negotiation failed: {}".format(e)) logger.warning("SSH negotiation failed: {}".format(e))
except EOFError as e: except IndexError as e:
logger.warning("Handle connection EOF Error: {}".format(e)) logger.warning("Handle connection EOF Error: {}".format(e))
except Exception as e: except SyntaxError as e:
logger.error("Unexpect error occur on handle connection: {}".format(e)) logger.error("Unexpect error occur on handle connection: {}".format(e))
finally: finally:
Connection.remove_connection(connection.id) Connection.remove_connection(connection.id)
...@@ -129,8 +172,3 @@ class SSHServer: ...@@ -129,8 +172,3 @@ class SSHServer:
def shutdown(self): def shutdown(self):
self.stop_evt.set() self.stop_evt.set()
if __name__ == '__main__':
ssh_server = SSHServer()
ssh_server.run()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
import queue
import socket
class MultiQueueMixin:
def mget(self, size=1, block=True, timeout=5):
items = []
for i in range(size):
try:
items.append(self.get(block=block, timeout=timeout))
except queue.Empty:
break
return items
def mput(self, data_set):
for i in data_set:
self.put(i)
class MemoryQueue(MultiQueueMixin, queue.Queue, object):
pass
class SizedList(list):
def __init__(self, maxsize=0):
self.maxsize = maxsize
self.size = 0
super(list, self).__init__()
def append(self, b):
if self.maxsize == 0 or self.size < self.maxsize:
super(SizedList, self).append(b)
self.size += len(b)
def clean(self):
self.size = 0
del self[:]
class SelectEvent:
def __init__(self):
self.p1, self.p2 = socket.socketpair()
def set(self):
self.p2.send(b'0')
def fileno(self):
return self.p1.fileno()
def __getattr__(self, item):
return getattr(self.p1, item)
...@@ -8,6 +8,8 @@ import logging ...@@ -8,6 +8,8 @@ import logging
import re import re
import os import os
import gettext import gettext
import queue
import socket
from io import StringIO from io import StringIO
from binascii import hexlify from binascii import hexlify
from werkzeug.local import Local, LocalProxy from werkzeug.local import Local, LocalProxy
...@@ -464,4 +466,53 @@ def ignore_error(func): ...@@ -464,4 +466,53 @@ def ignore_error(func):
return wrapper return wrapper
class MultiQueueMixin:
def mget(self, size=1, block=True, timeout=5):
items = []
for i in range(size):
try:
items.append(self.get(block=block, timeout=timeout))
except queue.Empty:
break
return items
def mput(self, data_set):
for i in data_set:
self.put(i)
class MemoryQueue(MultiQueueMixin, queue.Queue, object):
pass
class SizedList(list):
def __init__(self, maxsize=0):
self.maxsize = maxsize
self.size = 0
super(list, self).__init__()
def append(self, b):
if self.maxsize == 0 or self.size < self.maxsize:
super(SizedList, self).append(b)
self.size += len(b)
def clean(self):
self.size = 0
del self[:]
class SelectEvent:
def __init__(self):
self.p1, self.p2 = socket.socketpair()
def set(self):
self.p2.send(b'0')
def fileno(self):
return self.p1.fileno()
def __getattr__(self, item):
return getattr(self.p1, item)
ugettext = LocalProxy(partial(_find, 'LANGUAGE_CODE')) ugettext = LocalProxy(partial(_find, 'LANGUAGE_CODE'))
# -*- coding: utf-8 -*-
#
from coco.sshd import SSHServer
ssh = SSHServer()
if __name__ == '__main__':
ssh.run()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment