Commit 7ad60ede authored by ibuler's avatar ibuler

[Update] 修改token

parent 3b4db38f
......@@ -205,7 +205,6 @@ class AssetsAmountMixin:
获取节点下所有资产数量速度太慢,所以需要重写,使用cache等方案
:return:
"""
return 0
if self._assets_amount is not None:
return self._assets_amount
cache_key = self._assets_amount_cache_key.format(self.key)
......
# Generated by Django 2.1.7 on 2019-07-26 09:53
from django.db import migrations, models
def migrate_loginlog_reason_to_str(apps, schema_editor):
db_alias = schema_editor.connection.alias
reason_map = {
"0": "",
"1": 'Username/password check failed',
"2": 'MFA authentication failed',
"3": "Username does not exist",
"4": "Password expired",
}
model = apps.get_model("audits", "UserLoginLog")
for k, v in reason_map.items():
model.objects.using(db_alias).filter(reason=k).update(reason=v)
class Migration(migrations.Migration):
dependencies = [
('audits', '0005_auto_20190228_1715'),
]
operations = [
migrations.AlterField(
model_name='userloginlog',
name='reason',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Reason'),
),
migrations.RunPython(migrate_loginlog_reason_to_str),
]
......@@ -72,20 +72,6 @@ class UserLoginLog(models.Model):
(MFA_UNKNOWN, _('-')),
)
REASON_NOTHING = 0
REASON_PASSWORD = 1
REASON_MFA = 2
REASON_NOT_EXIST = 3
REASON_PASSWORD_EXPIRED = 4
REASON_CHOICE = (
(REASON_NOTHING, _('-')),
(REASON_PASSWORD, _('Username/password check failed')),
(REASON_MFA, _('MFA authentication failed')),
(REASON_NOT_EXIST, _("Username does not exist")),
(REASON_PASSWORD_EXPIRED, _("Password expired")),
)
STATUS_CHOICE = (
(True, _('Success')),
(False, _('Failed'))
......@@ -97,7 +83,7 @@ class UserLoginLog(models.Model):
city = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('Login city'))
user_agent = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('User agent'))
mfa = models.SmallIntegerField(default=MFA_UNKNOWN, choices=MFA_CHOICE, verbose_name=_('MFA'))
reason = models.SmallIntegerField(default=0, choices=REASON_CHOICE, verbose_name=_('Reason'))
reason = models.CharField(default='', max_length=128, blank=True, verbose_name=_('Reason'))
status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status'))
datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login'))
......
......@@ -72,7 +72,7 @@
<td class="text-center">{{ login_log.ip }}</td>
<td class="text-center">{{ login_log.city }}</td>
<td class="text-center">{{ login_log.get_mfa_display }}</td>
<td class="text-center">{{ login_log.get_reason_display }}</td>
<td class="text-center">{% trans login_log.reason %}</td>
<td class="text-center">{{ login_log.get_status_display }}</td>
<td class="text-center">{{ login_log.datetime }}</td>
</tr>
......
......@@ -2,3 +2,5 @@
#
from .auth import *
from .token import *
from .mfa import *
......@@ -22,9 +22,11 @@ from users.models import User
from assets.models import Asset, SystemUser
from audits.models import UserLoginLog as LoginLog
from users.utils import (
check_user_valid, check_otp_code, increase_login_failed_count,
check_otp_code, increase_login_failed_count,
is_block_login, clean_failed_count
)
from .. import const
from ..utils import check_user_valid
from ..serializers import OtpVerifySerializer
from ..signals import post_auth_success, post_auth_failed
......@@ -53,27 +55,15 @@ class UserAuthApi(RootOrgViewMixin, APIView):
user, msg = self.check_user_valid(request)
if not user:
username = request.data.get('username', '')
exist = User.objects.filter(username=username).first()
reason = LoginLog.REASON_PASSWORD if exist else LoginLog.REASON_NOT_EXIST
self.send_auth_signal(success=False, username=username, reason=reason)
self.send_auth_signal(success=False, username=username, reason=msg)
increase_login_failed_count(username, ip)
return Response({'msg': msg}, status=401)
if user.password_has_expired:
self.send_auth_signal(
success=False, username=username,
reason=LoginLog.REASON_PASSWORD_EXPIRED
)
msg = _("The user {} password has expired, please update.".format(
user.username))
logger.info(msg)
return Response({'msg': msg}, status=401)
if not user.otp_enabled:
self.send_auth_signal(success=True, user=user)
# 登陆成功,清除原来的缓存计数
clean_failed_count(username, ip)
token = user.create_bearer_token(request)
token, expired_at = user.create_bearer_token(request)
return Response(
{'token': token, 'user': self.serializer_class(user).data}
)
......@@ -167,10 +157,10 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView):
status=401
)
if not check_otp_code(user.otp_secret_key, otp_code):
self.send_auth_signal(success=False, username=user.username, reason=LoginLog.REASON_MFA)
self.send_auth_signal(success=False, username=user.username, reason=const.mfa_failed)
return Response({'msg': _('MFA certification failed')}, status=401)
self.send_auth_signal(success=True, user=user)
token = user.create_bearer_token(request)
token, expired_at = user.create_bearer_token(request)
data = {'token': token, 'user': self.serializer_class(user).data}
return Response(data)
......
# -*- coding: utf-8 -*-
#
from rest_framework.permissions import AllowAny
from rest_framework.generics import CreateAPIView
from .. import serializers
class MFAChallengeApi(CreateAPIView):
permission_classes = (AllowAny,)
serializer_class = serializers.MFAChallengeSerializer
# -*- coding: utf-8 -*-
#
import uuid
from django.core.cache import cache
from django.utils.translation import ugettext as _
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.generics import CreateAPIView
from rest_framework.views import APIView
from drf_yasg.utils import swagger_auto_schema
from common.utils import get_request_ip, get_logger
from users.utils import (
check_otp_code, increase_login_failed_count,
is_block_login, clean_failed_count
)
from ..utils import check_user_valid
from ..signals import post_auth_success, post_auth_failed
from .. import serializers
logger = get_logger(__name__)
__all__ = ['TokenCreateApi']
class AuthFailedError(Exception):
def __init__(self, msg, reason=None):
self.msg = msg
self.reason = reason
class MFARequiredError(Exception):
pass
class TokenCreateApi(CreateAPIView):
permission_classes = (AllowAny,)
serializer_class = serializers.BearerTokenSerializer
@staticmethod
def check_is_block(username, ip):
if is_block_login(username, ip):
msg = _("Log in frequently and try again later")
logger.warn(msg + ': ' + username + ':' + ip)
raise AuthFailedError(msg)
def check_user_valid(self):
request = self.request
username = request.data.get('username', '')
password = request.data.get('password', '')
public_key = request.data.get('public_key', '')
user, msg = check_user_valid(
username=username, password=password,
public_key=public_key
)
if not user:
raise AuthFailedError(msg)
return user
def create(self, request, *args, **kwargs):
username = self.request.data.get('username')
ip = self.request.data.get('remote_addr', None)
ip = ip or get_request_ip(self.request)
user = None
try:
self.check_is_block(username, ip)
user = self.check_user_valid()
if user.otp_enabled:
raise MFARequiredError()
self.send_auth_signal(success=True, user=user)
clean_failed_count(username, ip)
return super().create(request, *args, **kwargs)
except AuthFailedError as e:
increase_login_failed_count(username, ip)
self.send_auth_signal(success=False, user=user, username=username, reason=str(e))
return Response({'msg': str(e)}, status=401)
except MFARequiredError:
msg = _("MFA required")
seed = uuid.uuid4().hex
cache.set(seed, user.username, 300)
resp = {'msg': msg, "choices": ["otp"], "req": seed}
return Response(resp, status=300)
def send_auth_signal(self, success=True, user=None, username='', reason=''):
if success:
post_auth_success.send(
sender=self.__class__, user=user, request=self.request
)
else:
post_auth_failed.send(
sender=self.__class__, username=username,
request=self.request, reason=reason
)
# -*- coding: utf-8 -*-
#
from django.utils.translation import ugettext_lazy as _
password_failed = _('Username/password check failed')
mfa_failed = _('MFA authentication failed')
user_not_exist = _("Username does not exist")
password_expired = _("Password expired")
user_invalid = _('Disabled or expired')
# -*- coding: utf-8 -*-
#
from django.core.cache import cache
from rest_framework import serializers
from users.models import User
......@@ -8,6 +9,7 @@ from .models import AccessKey
__all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer',
]
......@@ -23,24 +25,26 @@ class OtpVerifySerializer(serializers.Serializer):
code = serializers.CharField(max_length=6, min_length=6)
class BearerTokenSerializer(serializers.Serializer):
username = serializers.CharField()
password = serializers.CharField(allow_blank=True, write_only=True)
public_key = serializers.CharField(allow_blank=True, write_only=True)
class BearerTokenMixin(serializers.Serializer):
token = serializers.CharField(read_only=True)
keyword = serializers.SerializerMethodField()
date_expired = serializers.DateTimeField(read_only=True)
@staticmethod
def get_keyword(obj):
return 'Bearer'
def create(self, validated_data):
username = validated_data["username"]
def create_response(self, username):
request = self.context.get("request")
user = User.objects.get(username=username)
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
raise serializers.ValidationError("username %s not exist" % username)
token, date_expired = user.create_bearer_token(request)
instance = {
"username": validated_data.get(username),
"token": user.create_bearer_token(request),
"username": username,
"token": token,
"date_expired": date_expired,
}
return instance
......@@ -48,3 +52,38 @@ class BearerTokenSerializer(serializers.Serializer):
pass
class BearerTokenSerializer(BearerTokenMixin, serializers.Serializer):
username = serializers.CharField()
password = serializers.CharField(write_only=True, allow_null=True,
required=False)
public_key = serializers.CharField(write_only=True, allow_null=True,
required=False)
def create(self, validated_data):
username = validated_data.get("username")
return self.create_response(username)
class MFAChallengeSerializer(BearerTokenMixin, serializers.Serializer):
req = serializers.CharField(write_only=True)
auth_type = serializers.CharField(write_only=True)
code = serializers.CharField(write_only=True)
def validate_req(self, attr):
username = cache.get(attr)
if not username:
raise serializers.ValidationError("Not valid, may be expired")
self.context["username"] = username
def validate_code(self, code):
username = self.context["username"]
user = User.objects.get(username=username)
ok = user.check_otp(code)
if not ok:
msg = "Otp code not valid, may be expired"
raise serializers.ValidationError(msg)
def create(self, validated_data):
username = self.context["username"]
return self.create_response(username)
......@@ -13,6 +13,8 @@ app_name = 'authentication'
urlpatterns = [
# path('token/', api.UserToken.as_view(), name='user-token'),
path('auth/', api.UserAuthApi.as_view(), name='user-auth'),
path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'),
path('mfa/challenge/', api.MFAChallengeApi.as_view(), name='mfa-challenge'),
path('connection-token/',
api.UserConnectionTokenApi.as_view(), name='connection-token'),
path('otp/auth/', api.UserOtpAuthApi.as_view(), name='user-otp-auth'),
......
# -*- coding: utf-8 -*-
#
from django.utils.translation import ugettext as _
from common.utils import get_ip_city, validate_ip
from django.contrib.auth import authenticate
from common.utils import get_ip_city, get_object_or_none, validate_ip
from users.models import User
from . import const
def write_login_log(*args, **kwargs):
......@@ -16,3 +20,36 @@ def write_login_log(*args, **kwargs):
kwargs.update({'ip': ip, 'city': city})
UserLoginLog.objects.create(**kwargs)
def check_user_valid(**kwargs):
password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None)
email = kwargs.pop('email', None)
username = kwargs.pop('username', None)
if username:
user = get_object_or_none(User, username=username)
elif email:
user = get_object_or_none(User, email=email)
else:
user = None
if user is None:
return None, const.user_not_exist
elif not user.is_valid:
return None, const.user_invalid
elif user.password_has_expired:
return None, const.password_expired
if password and authenticate(username=username, password=password):
return user, ''
if public_key and user.public_key:
public_key_saved = user.public_key.split()
if len(public_key_saved) == 1:
if public_key == public_key_saved[0]:
return user, ''
elif len(public_key_saved) > 1:
if public_key == public_key_saved[1]:
return user, ''
return None, const.password_failed
......@@ -202,11 +202,11 @@ class GrantAssetsMixin(LabelFilterMixin):
data.append(asset)
return data
def get_serializer(self, data=None, many=True):
if data is None:
data = []
data = self.get_serializer_queryset(data)
return super().get_serializer(data=data, many=many)
def get_serializer(self, assets_items=None, many=True):
if assets_items is None:
assets_items = []
assets_items = self.get_serializer_queryset(assets_items)
return super().get_serializer(assets_items, many=many)
def filter_queryset_by_id(self, assets_items):
i = self.request.query_params.get("id")
......
......@@ -140,11 +140,11 @@ class UserGrantedNodesApi(UserPermissionCacheMixin, NodesWithUngroupMixin, ListA
_nodes.append(node)
return _nodes
def get_serializer(self, data=None, many=True):
if data is None:
data = []
nodes = self.get_nodes(data)
return super().get_serializer(data=nodes, many=True)
def get_serializer(self, nodes_items=None, many=True):
if nodes_items is None:
nodes_items = []
nodes = self.get_nodes(nodes_items)
return super().get_serializer(nodes, many=True)
def get_queryset(self):
user = self.get_object()
......@@ -267,11 +267,11 @@ class UserGrantedNodesWithAssetsApi(UserPermissionCacheMixin, NodesWithUngroupMi
queryset.append(node)
return queryset
def get_serializer(self, data=None, many=True):
if data is None:
data = []
queryset = self.get_serializer_queryset(data)
return super().get_serializer(data=queryset, many=many)
def get_serializer(self, nodes_items=None, many=True):
if nodes_items is None:
nodes_items = []
queryset = self.get_serializer_queryset(nodes_items)
return super().get_serializer(queryset, many=many)
def get_queryset(self):
user = self.get_object()
......@@ -298,10 +298,10 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserGrantedNodesWithAssetsApi):
assets_only_fields = ParserNode.assets_only_fields
system_users_only_fields = ParserNode.system_users_only_fields
def get_serializer(self, data=None, many=True):
if data is None:
data = []
_queryset = super().get_serializer_queryset(data)
def get_serializer(self, nodes_items=None, many=True):
if nodes_items is None:
nodes_items = []
_queryset = super().get_serializer_queryset(nodes_items)
queryset = []
for node in _queryset:
......
......@@ -102,13 +102,14 @@ class PermAssetsAmountUtil(PermStackUtilMixin):
self.debug("出栈: {} 栈顶: {}".format(
_node['key'], self.stack.top['key'] if self.stack.top else None)
)
_node["assets_amount"] = len(_node["all_assets"] | _node["assets"])
_node["all_assets"] = _node["all_assets"] | _node["assets"]
_node["assets_amount"] = len(_node["all_assets"])
self._nodes[_node.pop("key")] = _node
if not self.stack.top:
return
self.stack.top["all_assets"]\
.update(_node["all_assets"] | _node["assets"])
.update(_node["all_assets"])
def compute_nodes_assets_amount(self, nodes_with_assets):
self.stack = Stack()
......
......@@ -232,7 +232,8 @@ class TokenMixin:
token = uuid.uuid4().hex
cache.set(token, self.id, expiration)
cache.set('%s_%s' % (self.id, remote_addr), token, expiration)
return token
date_expired = timezone.now() + timezone.timedelta(seconds=expiration)
return token, date_expired
def refresh_bearer_token(self, token):
pass
......
......@@ -7,18 +7,14 @@ import pyotp
import base64
import logging
import ipaddress
from django.http import Http404
from django.conf import settings
from django.contrib.auth.mixins import UserPassesTestMixin
from django.contrib.auth import authenticate
from django.utils.translation import ugettext as _
from django.core.cache import cache
from datetime import datetime
from common.tasks import send_mail_async
from common.utils import reverse, get_object_or_none, get_ip_city
from .models import User
from common.utils import reverse
logger = logging.getLogger('jumpserver')
......@@ -177,37 +173,6 @@ def send_reset_ssh_key_mail(user):
send_mail_async.delay(subject, message, recipient_list, html_message=message)
def check_user_valid(**kwargs):
password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None)
email = kwargs.pop('email', None)
username = kwargs.pop('username', None)
if username:
user = get_object_or_none(User, username=username)
elif email:
user = get_object_or_none(User, email=email)
else:
user = None
if user is None:
return None, _('User not exist')
elif not user.is_valid:
return None, _('Disabled or expired')
if password and authenticate(username=username, password=password):
return user, ''
if public_key and user.public_key:
public_key_saved = user.public_key.split()
if len(public_key_saved) == 1:
if public_key == public_key_saved[0]:
return user, ''
elif len(public_key_saved) > 1:
if public_key == public_key_saved[1]:
return user, ''
return None, _('Password or SSH public key invalid')
def get_user_or_tmp_user(request):
user = request.user
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment