Commit 7ad60ede authored by ibuler's avatar ibuler

[Update] 修改token

parent 3b4db38f
...@@ -205,7 +205,6 @@ class AssetsAmountMixin: ...@@ -205,7 +205,6 @@ class AssetsAmountMixin:
获取节点下所有资产数量速度太慢,所以需要重写,使用cache等方案 获取节点下所有资产数量速度太慢,所以需要重写,使用cache等方案
:return: :return:
""" """
return 0
if self._assets_amount is not None: if self._assets_amount is not None:
return self._assets_amount return self._assets_amount
cache_key = self._assets_amount_cache_key.format(self.key) cache_key = self._assets_amount_cache_key.format(self.key)
......
# Generated by Django 2.1.7 on 2019-07-26 09:53
from django.db import migrations, models
def migrate_loginlog_reason_to_str(apps, schema_editor):
db_alias = schema_editor.connection.alias
reason_map = {
"0": "",
"1": 'Username/password check failed',
"2": 'MFA authentication failed',
"3": "Username does not exist",
"4": "Password expired",
}
model = apps.get_model("audits", "UserLoginLog")
for k, v in reason_map.items():
model.objects.using(db_alias).filter(reason=k).update(reason=v)
class Migration(migrations.Migration):
dependencies = [
('audits', '0005_auto_20190228_1715'),
]
operations = [
migrations.AlterField(
model_name='userloginlog',
name='reason',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Reason'),
),
migrations.RunPython(migrate_loginlog_reason_to_str),
]
...@@ -72,20 +72,6 @@ class UserLoginLog(models.Model): ...@@ -72,20 +72,6 @@ class UserLoginLog(models.Model):
(MFA_UNKNOWN, _('-')), (MFA_UNKNOWN, _('-')),
) )
REASON_NOTHING = 0
REASON_PASSWORD = 1
REASON_MFA = 2
REASON_NOT_EXIST = 3
REASON_PASSWORD_EXPIRED = 4
REASON_CHOICE = (
(REASON_NOTHING, _('-')),
(REASON_PASSWORD, _('Username/password check failed')),
(REASON_MFA, _('MFA authentication failed')),
(REASON_NOT_EXIST, _("Username does not exist")),
(REASON_PASSWORD_EXPIRED, _("Password expired")),
)
STATUS_CHOICE = ( STATUS_CHOICE = (
(True, _('Success')), (True, _('Success')),
(False, _('Failed')) (False, _('Failed'))
...@@ -97,7 +83,7 @@ class UserLoginLog(models.Model): ...@@ -97,7 +83,7 @@ class UserLoginLog(models.Model):
city = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('Login city')) city = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('Login city'))
user_agent = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('User agent')) user_agent = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('User agent'))
mfa = models.SmallIntegerField(default=MFA_UNKNOWN, choices=MFA_CHOICE, verbose_name=_('MFA')) mfa = models.SmallIntegerField(default=MFA_UNKNOWN, choices=MFA_CHOICE, verbose_name=_('MFA'))
reason = models.SmallIntegerField(default=0, choices=REASON_CHOICE, verbose_name=_('Reason')) reason = models.CharField(default='', max_length=128, blank=True, verbose_name=_('Reason'))
status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status')) status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status'))
datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login')) datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login'))
......
...@@ -72,7 +72,7 @@ ...@@ -72,7 +72,7 @@
<td class="text-center">{{ login_log.ip }}</td> <td class="text-center">{{ login_log.ip }}</td>
<td class="text-center">{{ login_log.city }}</td> <td class="text-center">{{ login_log.city }}</td>
<td class="text-center">{{ login_log.get_mfa_display }}</td> <td class="text-center">{{ login_log.get_mfa_display }}</td>
<td class="text-center">{{ login_log.get_reason_display }}</td> <td class="text-center">{% trans login_log.reason %}</td>
<td class="text-center">{{ login_log.get_status_display }}</td> <td class="text-center">{{ login_log.get_status_display }}</td>
<td class="text-center">{{ login_log.datetime }}</td> <td class="text-center">{{ login_log.datetime }}</td>
</tr> </tr>
......
...@@ -2,3 +2,5 @@ ...@@ -2,3 +2,5 @@
# #
from .auth import * from .auth import *
from .token import *
from .mfa import *
...@@ -22,9 +22,11 @@ from users.models import User ...@@ -22,9 +22,11 @@ from users.models import User
from assets.models import Asset, SystemUser from assets.models import Asset, SystemUser
from audits.models import UserLoginLog as LoginLog from audits.models import UserLoginLog as LoginLog
from users.utils import ( from users.utils import (
check_user_valid, check_otp_code, increase_login_failed_count, check_otp_code, increase_login_failed_count,
is_block_login, clean_failed_count is_block_login, clean_failed_count
) )
from .. import const
from ..utils import check_user_valid
from ..serializers import OtpVerifySerializer from ..serializers import OtpVerifySerializer
from ..signals import post_auth_success, post_auth_failed from ..signals import post_auth_success, post_auth_failed
...@@ -53,27 +55,15 @@ class UserAuthApi(RootOrgViewMixin, APIView): ...@@ -53,27 +55,15 @@ class UserAuthApi(RootOrgViewMixin, APIView):
user, msg = self.check_user_valid(request) user, msg = self.check_user_valid(request)
if not user: if not user:
username = request.data.get('username', '') username = request.data.get('username', '')
exist = User.objects.filter(username=username).first() self.send_auth_signal(success=False, username=username, reason=msg)
reason = LoginLog.REASON_PASSWORD if exist else LoginLog.REASON_NOT_EXIST
self.send_auth_signal(success=False, username=username, reason=reason)
increase_login_failed_count(username, ip) increase_login_failed_count(username, ip)
return Response({'msg': msg}, status=401) return Response({'msg': msg}, status=401)
if user.password_has_expired:
self.send_auth_signal(
success=False, username=username,
reason=LoginLog.REASON_PASSWORD_EXPIRED
)
msg = _("The user {} password has expired, please update.".format(
user.username))
logger.info(msg)
return Response({'msg': msg}, status=401)
if not user.otp_enabled: if not user.otp_enabled:
self.send_auth_signal(success=True, user=user) self.send_auth_signal(success=True, user=user)
# 登陆成功,清除原来的缓存计数 # 登陆成功,清除原来的缓存计数
clean_failed_count(username, ip) clean_failed_count(username, ip)
token = user.create_bearer_token(request) token, expired_at = user.create_bearer_token(request)
return Response( return Response(
{'token': token, 'user': self.serializer_class(user).data} {'token': token, 'user': self.serializer_class(user).data}
) )
...@@ -167,10 +157,10 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView): ...@@ -167,10 +157,10 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView):
status=401 status=401
) )
if not check_otp_code(user.otp_secret_key, otp_code): if not check_otp_code(user.otp_secret_key, otp_code):
self.send_auth_signal(success=False, username=user.username, reason=LoginLog.REASON_MFA) self.send_auth_signal(success=False, username=user.username, reason=const.mfa_failed)
return Response({'msg': _('MFA certification failed')}, status=401) return Response({'msg': _('MFA certification failed')}, status=401)
self.send_auth_signal(success=True, user=user) self.send_auth_signal(success=True, user=user)
token = user.create_bearer_token(request) token, expired_at = user.create_bearer_token(request)
data = {'token': token, 'user': self.serializer_class(user).data} data = {'token': token, 'user': self.serializer_class(user).data}
return Response(data) return Response(data)
......
# -*- coding: utf-8 -*-
#
from rest_framework.permissions import AllowAny
from rest_framework.generics import CreateAPIView
from .. import serializers
class MFAChallengeApi(CreateAPIView):
permission_classes = (AllowAny,)
serializer_class = serializers.MFAChallengeSerializer
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid
from django.core.cache import cache
from django.utils.translation import ugettext as _
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.generics import CreateAPIView from rest_framework.generics import CreateAPIView
from rest_framework.views import APIView from drf_yasg.utils import swagger_auto_schema
from common.utils import get_request_ip, get_logger
from users.utils import (
check_otp_code, increase_login_failed_count,
is_block_login, clean_failed_count
)
from ..utils import check_user_valid
from ..signals import post_auth_success, post_auth_failed
from .. import serializers from .. import serializers
logger = get_logger(__name__)
__all__ = ['TokenCreateApi']
class AuthFailedError(Exception):
def __init__(self, msg, reason=None):
self.msg = msg
self.reason = reason
class MFARequiredError(Exception):
pass
class TokenCreateApi(CreateAPIView): class TokenCreateApi(CreateAPIView):
permission_classes = (AllowAny,) permission_classes = (AllowAny,)
serializer_class = serializers.BearerTokenSerializer serializer_class = serializers.BearerTokenSerializer
@staticmethod
def check_is_block(username, ip):
if is_block_login(username, ip):
msg = _("Log in frequently and try again later")
logger.warn(msg + ': ' + username + ':' + ip)
raise AuthFailedError(msg)
def check_user_valid(self):
request = self.request
username = request.data.get('username', '')
password = request.data.get('password', '')
public_key = request.data.get('public_key', '')
user, msg = check_user_valid(
username=username, password=password,
public_key=public_key
)
if not user:
raise AuthFailedError(msg)
return user
def create(self, request, *args, **kwargs):
username = self.request.data.get('username')
ip = self.request.data.get('remote_addr', None)
ip = ip or get_request_ip(self.request)
user = None
try:
self.check_is_block(username, ip)
user = self.check_user_valid()
if user.otp_enabled:
raise MFARequiredError()
self.send_auth_signal(success=True, user=user)
clean_failed_count(username, ip)
return super().create(request, *args, **kwargs)
except AuthFailedError as e:
increase_login_failed_count(username, ip)
self.send_auth_signal(success=False, user=user, username=username, reason=str(e))
return Response({'msg': str(e)}, status=401)
except MFARequiredError:
msg = _("MFA required")
seed = uuid.uuid4().hex
cache.set(seed, user.username, 300)
resp = {'msg': msg, "choices": ["otp"], "req": seed}
return Response(resp, status=300)
def send_auth_signal(self, success=True, user=None, username='', reason=''):
if success:
post_auth_success.send(
sender=self.__class__, user=user, request=self.request
)
else:
post_auth_failed.send(
sender=self.__class__, username=username,
request=self.request, reason=reason
)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.utils.translation import ugettext_lazy as _
password_failed = _('Username/password check failed')
mfa_failed = _('MFA authentication failed')
user_not_exist = _("Username does not exist")
password_expired = _("Password expired")
user_invalid = _('Disabled or expired')
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.core.cache import cache
from rest_framework import serializers from rest_framework import serializers
from users.models import User from users.models import User
...@@ -8,6 +9,7 @@ from .models import AccessKey ...@@ -8,6 +9,7 @@ from .models import AccessKey
__all__ = [ __all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer', 'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer',
] ]
...@@ -23,24 +25,26 @@ class OtpVerifySerializer(serializers.Serializer): ...@@ -23,24 +25,26 @@ class OtpVerifySerializer(serializers.Serializer):
code = serializers.CharField(max_length=6, min_length=6) code = serializers.CharField(max_length=6, min_length=6)
class BearerTokenSerializer(serializers.Serializer): class BearerTokenMixin(serializers.Serializer):
username = serializers.CharField()
password = serializers.CharField(allow_blank=True, write_only=True)
public_key = serializers.CharField(allow_blank=True, write_only=True)
token = serializers.CharField(read_only=True) token = serializers.CharField(read_only=True)
keyword = serializers.SerializerMethodField() keyword = serializers.SerializerMethodField()
date_expired = serializers.DateTimeField(read_only=True)
@staticmethod @staticmethod
def get_keyword(obj): def get_keyword(obj):
return 'Bearer' return 'Bearer'
def create(self, validated_data): def create_response(self, username):
username = validated_data["username"]
request = self.context.get("request") request = self.context.get("request")
user = User.objects.get(username=username) try:
user = User.objects.get(username=username)
except User.DoesNotExist:
raise serializers.ValidationError("username %s not exist" % username)
token, date_expired = user.create_bearer_token(request)
instance = { instance = {
"username": validated_data.get(username), "username": username,
"token": user.create_bearer_token(request), "token": token,
"date_expired": date_expired,
} }
return instance return instance
...@@ -48,3 +52,38 @@ class BearerTokenSerializer(serializers.Serializer): ...@@ -48,3 +52,38 @@ class BearerTokenSerializer(serializers.Serializer):
pass pass
class BearerTokenSerializer(BearerTokenMixin, serializers.Serializer):
username = serializers.CharField()
password = serializers.CharField(write_only=True, allow_null=True,
required=False)
public_key = serializers.CharField(write_only=True, allow_null=True,
required=False)
def create(self, validated_data):
username = validated_data.get("username")
return self.create_response(username)
class MFAChallengeSerializer(BearerTokenMixin, serializers.Serializer):
req = serializers.CharField(write_only=True)
auth_type = serializers.CharField(write_only=True)
code = serializers.CharField(write_only=True)
def validate_req(self, attr):
username = cache.get(attr)
if not username:
raise serializers.ValidationError("Not valid, may be expired")
self.context["username"] = username
def validate_code(self, code):
username = self.context["username"]
user = User.objects.get(username=username)
ok = user.check_otp(code)
if not ok:
msg = "Otp code not valid, may be expired"
raise serializers.ValidationError(msg)
def create(self, validated_data):
username = self.context["username"]
return self.create_response(username)
...@@ -13,6 +13,8 @@ app_name = 'authentication' ...@@ -13,6 +13,8 @@ app_name = 'authentication'
urlpatterns = [ urlpatterns = [
# path('token/', api.UserToken.as_view(), name='user-token'), # path('token/', api.UserToken.as_view(), name='user-token'),
path('auth/', api.UserAuthApi.as_view(), name='user-auth'), path('auth/', api.UserAuthApi.as_view(), name='user-auth'),
path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'),
path('mfa/challenge/', api.MFAChallengeApi.as_view(), name='mfa-challenge'),
path('connection-token/', path('connection-token/',
api.UserConnectionTokenApi.as_view(), name='connection-token'), api.UserConnectionTokenApi.as_view(), name='connection-token'),
path('otp/auth/', api.UserOtpAuthApi.as_view(), name='user-otp-auth'), path('otp/auth/', api.UserOtpAuthApi.as_view(), name='user-otp-auth'),
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from common.utils import get_ip_city, validate_ip from django.contrib.auth import authenticate
from common.utils import get_ip_city, get_object_or_none, validate_ip
from users.models import User
from . import const
def write_login_log(*args, **kwargs): def write_login_log(*args, **kwargs):
...@@ -16,3 +20,36 @@ def write_login_log(*args, **kwargs): ...@@ -16,3 +20,36 @@ def write_login_log(*args, **kwargs):
kwargs.update({'ip': ip, 'city': city}) kwargs.update({'ip': ip, 'city': city})
UserLoginLog.objects.create(**kwargs) UserLoginLog.objects.create(**kwargs)
def check_user_valid(**kwargs):
password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None)
email = kwargs.pop('email', None)
username = kwargs.pop('username', None)
if username:
user = get_object_or_none(User, username=username)
elif email:
user = get_object_or_none(User, email=email)
else:
user = None
if user is None:
return None, const.user_not_exist
elif not user.is_valid:
return None, const.user_invalid
elif user.password_has_expired:
return None, const.password_expired
if password and authenticate(username=username, password=password):
return user, ''
if public_key and user.public_key:
public_key_saved = user.public_key.split()
if len(public_key_saved) == 1:
if public_key == public_key_saved[0]:
return user, ''
elif len(public_key_saved) > 1:
if public_key == public_key_saved[1]:
return user, ''
return None, const.password_failed
...@@ -202,11 +202,11 @@ class GrantAssetsMixin(LabelFilterMixin): ...@@ -202,11 +202,11 @@ class GrantAssetsMixin(LabelFilterMixin):
data.append(asset) data.append(asset)
return data return data
def get_serializer(self, data=None, many=True): def get_serializer(self, assets_items=None, many=True):
if data is None: if assets_items is None:
data = [] assets_items = []
data = self.get_serializer_queryset(data) assets_items = self.get_serializer_queryset(assets_items)
return super().get_serializer(data=data, many=many) return super().get_serializer(assets_items, many=many)
def filter_queryset_by_id(self, assets_items): def filter_queryset_by_id(self, assets_items):
i = self.request.query_params.get("id") i = self.request.query_params.get("id")
......
...@@ -140,11 +140,11 @@ class UserGrantedNodesApi(UserPermissionCacheMixin, NodesWithUngroupMixin, ListA ...@@ -140,11 +140,11 @@ class UserGrantedNodesApi(UserPermissionCacheMixin, NodesWithUngroupMixin, ListA
_nodes.append(node) _nodes.append(node)
return _nodes return _nodes
def get_serializer(self, data=None, many=True): def get_serializer(self, nodes_items=None, many=True):
if data is None: if nodes_items is None:
data = [] nodes_items = []
nodes = self.get_nodes(data) nodes = self.get_nodes(nodes_items)
return super().get_serializer(data=nodes, many=True) return super().get_serializer(nodes, many=True)
def get_queryset(self): def get_queryset(self):
user = self.get_object() user = self.get_object()
...@@ -267,11 +267,11 @@ class UserGrantedNodesWithAssetsApi(UserPermissionCacheMixin, NodesWithUngroupMi ...@@ -267,11 +267,11 @@ class UserGrantedNodesWithAssetsApi(UserPermissionCacheMixin, NodesWithUngroupMi
queryset.append(node) queryset.append(node)
return queryset return queryset
def get_serializer(self, data=None, many=True): def get_serializer(self, nodes_items=None, many=True):
if data is None: if nodes_items is None:
data = [] nodes_items = []
queryset = self.get_serializer_queryset(data) queryset = self.get_serializer_queryset(nodes_items)
return super().get_serializer(data=queryset, many=many) return super().get_serializer(queryset, many=many)
def get_queryset(self): def get_queryset(self):
user = self.get_object() user = self.get_object()
...@@ -298,10 +298,10 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserGrantedNodesWithAssetsApi): ...@@ -298,10 +298,10 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserGrantedNodesWithAssetsApi):
assets_only_fields = ParserNode.assets_only_fields assets_only_fields = ParserNode.assets_only_fields
system_users_only_fields = ParserNode.system_users_only_fields system_users_only_fields = ParserNode.system_users_only_fields
def get_serializer(self, data=None, many=True): def get_serializer(self, nodes_items=None, many=True):
if data is None: if nodes_items is None:
data = [] nodes_items = []
_queryset = super().get_serializer_queryset(data) _queryset = super().get_serializer_queryset(nodes_items)
queryset = [] queryset = []
for node in _queryset: for node in _queryset:
......
...@@ -102,13 +102,14 @@ class PermAssetsAmountUtil(PermStackUtilMixin): ...@@ -102,13 +102,14 @@ class PermAssetsAmountUtil(PermStackUtilMixin):
self.debug("出栈: {} 栈顶: {}".format( self.debug("出栈: {} 栈顶: {}".format(
_node['key'], self.stack.top['key'] if self.stack.top else None) _node['key'], self.stack.top['key'] if self.stack.top else None)
) )
_node["assets_amount"] = len(_node["all_assets"] | _node["assets"]) _node["all_assets"] = _node["all_assets"] | _node["assets"]
_node["assets_amount"] = len(_node["all_assets"])
self._nodes[_node.pop("key")] = _node self._nodes[_node.pop("key")] = _node
if not self.stack.top: if not self.stack.top:
return return
self.stack.top["all_assets"]\ self.stack.top["all_assets"]\
.update(_node["all_assets"] | _node["assets"]) .update(_node["all_assets"])
def compute_nodes_assets_amount(self, nodes_with_assets): def compute_nodes_assets_amount(self, nodes_with_assets):
self.stack = Stack() self.stack = Stack()
......
...@@ -232,7 +232,8 @@ class TokenMixin: ...@@ -232,7 +232,8 @@ class TokenMixin:
token = uuid.uuid4().hex token = uuid.uuid4().hex
cache.set(token, self.id, expiration) cache.set(token, self.id, expiration)
cache.set('%s_%s' % (self.id, remote_addr), token, expiration) cache.set('%s_%s' % (self.id, remote_addr), token, expiration)
return token date_expired = timezone.now() + timezone.timedelta(seconds=expiration)
return token, date_expired
def refresh_bearer_token(self, token): def refresh_bearer_token(self, token):
pass pass
......
...@@ -7,18 +7,14 @@ import pyotp ...@@ -7,18 +7,14 @@ import pyotp
import base64 import base64
import logging import logging
import ipaddress
from django.http import Http404 from django.http import Http404
from django.conf import settings from django.conf import settings
from django.contrib.auth.mixins import UserPassesTestMixin
from django.contrib.auth import authenticate
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django.core.cache import cache from django.core.cache import cache
from datetime import datetime from datetime import datetime
from common.tasks import send_mail_async from common.tasks import send_mail_async
from common.utils import reverse, get_object_or_none, get_ip_city from common.utils import reverse
from .models import User
logger = logging.getLogger('jumpserver') logger = logging.getLogger('jumpserver')
...@@ -177,37 +173,6 @@ def send_reset_ssh_key_mail(user): ...@@ -177,37 +173,6 @@ def send_reset_ssh_key_mail(user):
send_mail_async.delay(subject, message, recipient_list, html_message=message) send_mail_async.delay(subject, message, recipient_list, html_message=message)
def check_user_valid(**kwargs):
password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None)
email = kwargs.pop('email', None)
username = kwargs.pop('username', None)
if username:
user = get_object_or_none(User, username=username)
elif email:
user = get_object_or_none(User, email=email)
else:
user = None
if user is None:
return None, _('User not exist')
elif not user.is_valid:
return None, _('Disabled or expired')
if password and authenticate(username=username, password=password):
return user, ''
if public_key and user.public_key:
public_key_saved = user.public_key.split()
if len(public_key_saved) == 1:
if public_key == public_key_saved[0]:
return user, ''
elif len(public_key_saved) > 1:
if public_key == public_key_saved[1]:
return user, ''
return None, _('Password or SSH public key invalid')
def get_user_or_tmp_user(request): def get_user_or_tmp_user(request):
user = request.user user = request.user
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment