From 3f4a7225b1e6e2735f91f5de0898e6357a7ea411 Mon Sep 17 00:00:00 2001
From: ibuler <ibuler@qq.com>
Date: Tue, 30 Dec 2014 22:48:39 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9bug=EF=BC=8C=E5=A2=9E?=
 =?UTF-8?q?=E5=8A=A0ssh=5Fkey=E5=92=8C=E6=99=AE=E9=80=9A=E5=AF=86=E7=A0=81?=
 =?UTF-8?q?=E7=99=BB=E9=99=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 connect.py       | 114 ++++++++++++++++++++++++++++++++---------------
 jasset/models.py |  26 ++++++-----
 juser/models.py  |   4 +-
 3 files changed, 98 insertions(+), 46 deletions(-)

diff --git a/connect.py b/connect.py
index ee6cff31..d2002675 100755
--- a/connect.py
+++ b/connect.py
@@ -57,6 +57,16 @@ def alert_print(string):
     sys.exit()
 
 
+class ServerError(Exception):
+    def __init__(self, error):
+        self.error = error
+
+    def __str__(self):
+        return self.error
+
+    __repr__ = __str__
+
+
 class PyCrypt(object):
     """It's used to encrypt and decrypt password."""
     def __init__(self, key):
@@ -114,21 +124,25 @@ def posix_shell(chan, username, host):
     today_connect_log_dir = os.path.join(connect_log_dir, today)
     log_filename = '%s_%s_%s.log' % (username, host, date_now)
     log_file_path = os.path.join(today_connect_log_dir, log_filename)
-    user = User.objects.get(username=username)
-    asset = Asset.objects.get(ip=host)
-    pid = os.getpid()
 
+    try:
+        user = User.objects.get(username=username)
+        asset = Asset.objects.get(ip=host)
+    except ObjectDoesNotExist:
+        raise ServerError('user %s or asset %s does not exist.' % (username, host))
+
+    pid = os.getpid()
     if not os.path.isdir(today_connect_log_dir):
         try:
             os.makedirs(today_connect_log_dir)
             os.chmod(today_connect_log_dir, 0777)
         except OSError:
-            alert_print('Create %s failed, Please modify %s permission.' % (today_connect_log_dir, connect_log_dir))
+            raise ServerError('Create %s failed, Please modify %s permission.' % (today_connect_log_dir, connect_log_dir))
 
     try:
         log_file = open(log_file_path, 'a')
     except IOError:
-        alert_print('Create logfile failed, Please modify %s permission.' % today_connect_log_dir)
+        raise ServerError('Create logfile failed, Please modify %s permission.' % today_connect_log_dir)
 
     log = Log(user=user, asset=asset, log_path=log_file_path, start_time=timestamp_start, pid=pid)
     log.save()
@@ -177,7 +191,7 @@ def get_user_host(username):
     try:
         user = User.objects.get(username=username)
     except ObjectDoesNotExist:
-        return {'Error': ['0', "Username \033[1;31m%s\033[0m doesn't exist on Jumpserver." % username]}, ['Error']
+        raise ServerError("Username \033[1;31m%s\033[0m doesn't exist on Jumpserver." % username)
     else:
         perm_all = user.permission_set.all()
         for perm in perm_all:
@@ -194,30 +208,58 @@ def get_connect_item(username, ip):
         asset = Asset.objects.get(ip=ip)
         port = asset.port
     except ObjectDoesNotExist:
-        red_print("Host %s isn't exist." % ip)
-        return
-
-    user = User.objects.get(username=username)
-    if asset.ldap_enable:
-        ldap_pwd = cryptor.decrypt(user.ldap_pwd)
-        return username, ldap_pwd, ip, port
-    elif asset.ssh_key_enable:
-        ssh_key_pwd = cryptor.decrypt(user.ssh_key_pwd)
-        return username, ssh_key_pwd, ip, port
-    else:
-        perms = asset.permission_set.all()
-        perm = perms[0]
+        raise ServerError("Host %s does not exist." % ip)
+
+    try:
+        user = User.objects.get(username=username)
+    except ObjectDoesNotExist:
+        raise ServerError('User %s does not exist.' % username)
+
+    if asset.login_type == 'L':
+        try:
+            ldap_pwd = cryptor.decrypt(user.ldap_pwd)
+        except TypeError:
+            raise ServerError('Decrypt %s ldap password error.' % username)
+        return 'L', username, ldap_pwd, ip, port
+    elif asset.login_type == 'S':
+        try:
+            ssh_key_pwd = cryptor.decrypt(user.ssh_key_pwd2)
+        except TypeError:
+            raise ServerError('Decrypt %s ssh key password error.' % username)
+        return 'S', username, ssh_key_pwd, ip, port
+    elif asset.login_type == 'P':
+        try:
+            ssh_pwd = cryptor.decrypt(user.ssh_pwd)
+        except TypeError:
+            raise ServerError('Decrypt %s ssh password error.' % username)
+        return 'P', username, ssh_pwd, ip, port
+    elif asset.login_type == 'M':
+        perms = asset.permission_set.filter(user=user)
+        try:
+            perm = perms[0]
+        except IndexError:
+            raise ServerError('Permission %s to %s does not exist.' % (username, ip))
 
         if perm.role == 'SU':
+            username_super = asset.username_super
             try:
-                return asset.username_super, cryptor.decrypt(asset.password_super), ip, port
+                password_super = cryptor.decrypt(asset.password_super)
             except TypeError:
-                red_print('User %s password error to decrypt.' % username)
-        else:
+                raise ServerError('Decrypt %s map to %s password in %s error.' % (username, username_super, ip))
+            return 'M', username_super, password_super, ip, port
+
+        elif perm.role == 'CU':
+            username_common = asset.username_common
             try:
-                return asset.username_common, cryptor.decrypt(asset.password_common), ip, port
+                password_common = asset.password_common
             except TypeError:
-                red_print('User %s password error to decrypt.' % username)
+                raise ServerError('Decrypt %s map to %s password in %s error.' % (username, username_common, ip))
+            return username_common, password_common, ip, port
+
+        else:
+            raise ServerError('Perm in %s for %s map role is not in ["SU", "CU"].' % (ip, username))
+    else:
+        raise ServerError('Login type is not in ["L", "S", "P", "M"]')
 
 
 def verify_connect(username, part_ip):
@@ -233,12 +275,8 @@ def verify_connect(username, part_ip):
     elif len(ip_matched) < 1:
         red_print('No Permission or No host.')
     else:
-        try:
-            username, password, host, port = get_connect_item(username, ip_matched[0])
-        except (ObjectDoesNotExist, IndexError):
-            red_print('Get get_connect_item Error.')
-        else:
-            connect(username, password, host, port, LOGIN_NAME)
+        login_type, username, password, host, port = get_connect_item(username, ip_matched[0])
+        connect(username, password, host, port, LOGIN_NAME, login_type=login_type)
 
 
 def print_prompt():
@@ -257,7 +295,7 @@ def print_user_host(username):
         print '[%s] %s -- %s' % (hosts_attr[ip][0], ip, hosts_attr[ip][1])
 
 
-def connect(username, password, host, port, login_name):
+def connect(username, password, host, port, login_name, login_type='L'):
     """
     Connect server.
     """
@@ -275,11 +313,14 @@ def connect(username, password, host, port, login_name):
     ssh.load_system_host_keys()
     ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
     try:
-        ssh.connect(host, port=port, username=username, password=password, key_filename=key_filename, compress=True)
+        if login_type == 'L':
+            ssh.connect(host, port=port, username=username, password=password, key_filename=key_filename, compress=True)
+        else:
+            ssh.connect(host, port=port, username=username, password=password, compress=True)
     except paramiko.ssh_exception.AuthenticationException:
-        alert_print('Host Password Error, Please Correct it.')
+        raise ServerError('Authentication Error.')
     except socket.error:
-        alert_print('Connect SSH Socket Port Error, Please Correct it.')
+        raise ServerError('Connect SSH Socket Port Error, Please Correct it.')
 
     # Make a channel and set windows size
     global channel
@@ -320,6 +361,9 @@ if __name__ == '__main__':
             elif option in ['Q', 'q']:
                 sys.exit()
             else:
-                verify_connect(LOGIN_NAME, option)
+                try:
+                    verify_connect(LOGIN_NAME, option)
+                except ServerError, e:
+                    red_print(e)
     except IndexError:
         pass
\ No newline at end of file
diff --git a/jasset/models.py b/jasset/models.py
index 084e0368..253fcbf7 100644
--- a/jasset/models.py
+++ b/jasset/models.py
@@ -3,25 +3,31 @@ from juser.models import Group, User
 
 
 class IDC(models.Model):
-    name = models.CharField(max_length=80, unique=True)
-    comment = models.CharField(max_length=100, blank=True, null=True)
+    name = models.CharField(max_length=40, unique=True)
+    comment = models.CharField(max_length=80, blank=True, null=True)
 
     def __unicode__(self):
         return self.name
 
 
 class Asset(models.Model):
+    LOGIN_TYPE_CHOICES = (
+        ('L', 'LDAP'),
+        ('S', 'SSH_KEY'),
+        ('P', 'PASSWORD'),
+        ('M', 'MAP'),
+    )
     ip = models.IPAddressField(unique=True)
-    port = models.SmallIntegerField(max_length=40)
+    port = models.SmallIntegerField(max_length=5)
     idc = models.ForeignKey(IDC)
     group = models.ManyToManyField(Group)
-    ldap_enable = models.BooleanField(default=True)
-    ssh_key_enable = models.BooleanField(default=False)
-    username_common = models.CharField(max_length=80, blank=True, null=True)
-    password_common = models.CharField(max_length=160, blank=True, null=True)
-    username_super = models.CharField(max_length=80, blank=True, null=True)
-    password_super = models.CharField(max_length=160, blank=True, null=True)
-    date_added = models.IntegerField(max_length=80)
+    login_type = models.CharField(max_length=1, choices=LOGIN_TYPE_CHOICES, default='L')
+    username_common = models.CharField(max_length=20, blank=True, null=True)
+    password_common = models.CharField(max_length=80, blank=True, null=True)
+    username_super = models.CharField(max_length=20, blank=True, null=True)
+    password_super = models.CharField(max_length=80, blank=True, null=True)
+    date_added = models.IntegerField(max_length=12)
+    is_active = models.BooleanField(default=True)
     comment = models.CharField(max_length=100, blank=True, null=True)
 
     def __unicode__(self):
diff --git a/juser/models.py b/juser/models.py
index 26cc6c5d..3e2c9a52 100644
--- a/juser/models.py
+++ b/juser/models.py
@@ -22,7 +22,9 @@ class User(models.Model):
     role = models.CharField(max_length=2, choices=USER_ROLE_CHOICES, default='CU')
     group = models.ManyToManyField(Group)
     ldap_pwd = models.CharField(max_length=100)
-    ssh_key_pwd = models.CharField(max_length=100)
+    ssh_key_pwd1 = models.CharField(max_length=100)
+    ssh_key_pwd2 = models.CharField(max_length=100)
+    ssh_pwd = models.CharField(max_length=100)
     is_active = models.BooleanField(default=True)
     last_login = models.IntegerField(default=0)
     date_joined = models.IntegerField()
-- 
2.18.0