Unverified Commit 9cfcadc2 authored by 老广's avatar 老广 Committed by GitHub

服务账号注册机制更改 (#2079)

* [Update] 服务账号注册

* [Update] 修改settings配置

* [Update] 修改settings

* [Update] 整理terminal api

* [Update] 修改terminal api

* [Update] 修改terminal注册机制
parent 363985ee
import json import json
import ldap
from django.db import models from django.db import models
from django.core.cache import cache from django.core.cache import cache
from django.db.utils import ProgrammingError, OperationalError from django.db.utils import ProgrammingError, OperationalError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.conf import settings from django.conf import settings
from django_auth_ldap.config import LDAPSearch, LDAPSearchUnion
from .utils import get_signer from .utils import get_signer
......
...@@ -5,6 +5,7 @@ from rest_framework import permissions ...@@ -5,6 +5,7 @@ from rest_framework import permissions
from django.contrib.auth.mixins import UserPassesTestMixin from django.contrib.auth.mixins import UserPassesTestMixin
from django.shortcuts import redirect from django.shortcuts import redirect
from django.http.response import HttpResponseForbidden from django.http.response import HttpResponseForbidden
from django.conf import settings
from orgs.utils import current_org from orgs.utils import current_org
...@@ -96,3 +97,12 @@ class SuperUserRequiredMixin(UserPassesTestMixin): ...@@ -96,3 +97,12 @@ class SuperUserRequiredMixin(UserPassesTestMixin):
def test_func(self): def test_func(self):
if self.request.user.is_authenticated and self.request.user.is_superuser: if self.request.user.is_authenticated and self.request.user.is_superuser:
return True return True
class WithBootstrapToken(permissions.BasePermission):
def has_permission(self, request, view):
authorization = request.META.get('HTTP_AUTHORIZATION', '')
if not authorization:
return False
request_bootstrap_token = authorization.split()[-1]
return settings.BOOTSTRAP_TOKEN == request_bootstrap_token
...@@ -30,6 +30,9 @@ CONFIG = load_user_config() ...@@ -30,6 +30,9 @@ CONFIG = load_user_config()
# SECURITY WARNING: keep the secret key used in production secret! # SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = CONFIG.SECRET_KEY SECRET_KEY = CONFIG.SECRET_KEY
# SECURITY WARNING: keep the token secret, remove it if all coco, guacamole ok
BOOTSTRAP_TOKEN = CONFIG.BOOTSTRAP_TOKEN
# SECURITY WARNING: don't run with debug turned on in production! # SECURITY WARNING: don't run with debug turned on in production!
DEBUG = CONFIG.DEBUG DEBUG = CONFIG.DEBUG
...@@ -499,9 +502,11 @@ USER_GUIDE_URL = "" ...@@ -499,9 +502,11 @@ USER_GUIDE_URL = ""
SWAGGER_SETTINGS = { SWAGGER_SETTINGS = {
'DEFAULT_AUTO_SCHEMA_CLASS': 'jumpserver.swagger.CustomSwaggerAutoSchema',
'SECURITY_DEFINITIONS': { 'SECURITY_DEFINITIONS': {
'basic': { 'basic': {
'type': 'basic' 'type': 'basic'
} }
}, },
} }
from drf_yasg.inspectors import SwaggerAutoSchema
from rest_framework import permissions
from drf_yasg.views import get_schema_view
from drf_yasg import openapi
class CustomSwaggerAutoSchema(SwaggerAutoSchema):
def get_tags(self, operation_keys):
if len(operation_keys) > 2 and operation_keys[1].startswith('v'):
return [operation_keys[2]]
return super().get_tags(operation_keys)
def get_swagger_view(version='v1'):
from .urls import api_v1_patterns, api_v2_patterns
if version == "v2":
patterns = api_v2_patterns
else:
patterns = api_v1_patterns
schema_view = get_schema_view(
openapi.Info(
title="Jumpserver API Docs",
default_version=version,
description="Jumpserver Restful api docs",
terms_of_service="https://www.jumpserver.org",
contact=openapi.Contact(email="support@fit2cloud.com"),
license=openapi.License(name="GPLv2 License"),
),
public=True,
patterns=patterns,
permission_classes=(permissions.AllowAny,),
)
return schema_view
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
from __future__ import unicode_literals from __future__ import unicode_literals
import re
import os
from django.urls import path, include, re_path from django.urls import path, include, re_path
from django.conf import settings from django.conf import settings
from django.conf.urls.static import static from django.conf.urls.static import static
from django.conf.urls.i18n import i18n_patterns from django.conf.urls.i18n import i18n_patterns
from django.views.i18n import JavaScriptCatalog from django.views.i18n import JavaScriptCatalog
from rest_framework.response import Response
from django.views.decorators.csrf import csrf_exempt
from django.http import HttpResponse
from django.utils.encoding import iri_to_uri
from rest_framework import permissions
from drf_yasg.views import get_schema_view
from drf_yasg import openapi
from .views import IndexView, LunaView, I18NView from .views import IndexView, LunaView, I18NView
from .swagger import get_swagger_view
api_v1_patterns = [
path('api/', include([
path('users/v1/', include('users.urls.api_urls', namespace='api-users')),
path('assets/v1/', include('assets.urls.api_urls', namespace='api-assets')),
path('perms/v1/', include('perms.urls.api_urls', namespace='api-perms')),
path('terminal/v1/', include('terminal.urls.api_urls', namespace='api-terminal')),
path('ops/v1/', include('ops.urls.api_urls', namespace='api-ops')),
path('audits/v1/', include('audits.urls.api_urls', namespace='api-audits')),
path('orgs/v1/', include('orgs.urls.api_urls', namespace='api-orgs')),
path('common/v1/', include('common.urls.api_urls', namespace='api-common')),
]))
]
schema_view = get_schema_view( api_v2_patterns = [
openapi.Info( path('api/', include([
title="Jumpserver API Docs", path('terminal/v2/', include('terminal.urls.api_urls_v2', namespace='api-terminal-v2')),
default_version='v1', path('users/v2/', include('users.urls.api_urls_v2', namespace='api-users-v2')),
description="Jumpserver Restful api docs", ]))
terms_of_service="https://www.jumpserver.org",
contact=openapi.Contact(email="support@fit2cloud.com"),
license=openapi.License(name="GPLv2 License"),
),
public=True,
permission_classes=(permissions.AllowAny,),
)
api_url_pattern = re.compile(r'^/api/(?P<version>\w+)/(?P<app>\w+)/(?P<extra>.*)$')
class HttpResponseTemporaryRedirect(HttpResponse):
status_code = 307
def __init__(self, redirect_to):
HttpResponse.__init__(self)
self['Location'] = iri_to_uri(redirect_to)
@csrf_exempt
def redirect_format_api(request, *args, **kwargs):
_path, query = request.path, request.GET.urlencode()
matched = api_url_pattern.match(_path)
if matched:
version, app, extra = matched.groups()
_path = '/api/{app}/{version}/{extra}?{query}'.format(**{
"app": app, "version": version, "extra": extra,
"query": query
})
return HttpResponseTemporaryRedirect(_path)
else:
return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404)
v1_api_patterns = [
path('users/v1/', include('users.urls.api_urls', namespace='api-users')),
path('assets/v1/', include('assets.urls.api_urls', namespace='api-assets')),
path('perms/v1/', include('perms.urls.api_urls', namespace='api-perms')),
path('terminal/v1/', include('terminal.urls.api_urls', namespace='api-terminal')),
path('ops/v1/', include('ops.urls.api_urls', namespace='api-ops')),
path('audits/v1/', include('audits.urls.api_urls', namespace='api-audits')),
path('orgs/v1/', include('orgs.urls.api_urls', namespace='api-orgs')),
path('common/v1/', include('common.urls.api_urls', namespace='api-common')),
] ]
app_view_patterns = [ app_view_patterns = [
...@@ -78,6 +42,7 @@ app_view_patterns = [ ...@@ -78,6 +42,7 @@ app_view_patterns = [
path('auth/', include('authentication.urls.view_urls'), name='auth'), path('auth/', include('authentication.urls.view_urls'), name='auth'),
] ]
if settings.XPACK_ENABLED: if settings.XPACK_ENABLED:
app_view_patterns.append(path('xpack/', include('xpack.urls', namespace='xpack'))) app_view_patterns.append(path('xpack/', include('xpack.urls', namespace='xpack')))
...@@ -87,12 +52,13 @@ js_i18n_patterns = i18n_patterns( ...@@ -87,12 +52,13 @@ js_i18n_patterns = i18n_patterns(
urlpatterns = [ urlpatterns = [
path('', IndexView.as_view(), name='index'), path('', IndexView.as_view(), name='index'),
path('', include(api_v2_patterns)),
path('', include(api_v1_patterns)),
path('luna/', LunaView.as_view(), name='luna-error'), path('luna/', LunaView.as_view(), name='luna-error'),
path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'), path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'),
path('settings/', include('common.urls.view_urls', namespace='settings')), path('settings/', include('common.urls.view_urls', namespace='settings')),
path('common/', include('common.urls.view_urls', namespace='common')), path('common/', include('common.urls.view_urls', namespace='common')),
path('api/v1/', redirect_format_api), # path('api/v2/', include(api_v2_patterns)),
path('api/', include(v1_api_patterns)),
# External apps url # External apps url
path('captcha/', include('captcha.urls')), path('captcha/', include('captcha.urls')),
...@@ -104,7 +70,13 @@ urlpatterns += js_i18n_patterns ...@@ -104,7 +70,13 @@ urlpatterns += js_i18n_patterns
if settings.DEBUG: if settings.DEBUG:
urlpatterns += [ urlpatterns += [
re_path('swagger(?P<format>\.json|\.yaml)$', schema_view.without_ui(cache_timeout=None), name='schema-json'), re_path('^swagger(?P<format>\.json|\.yaml)$',
path('docs/', schema_view.with_ui('swagger', cache_timeout=None), name="docs"), get_swagger_view().without_ui(cache_timeout=1), name='schema-json'),
path('redoc/', schema_view.with_ui('redoc', cache_timeout=None), name='redoc'), path('docs/', get_swagger_view().with_ui('swagger', cache_timeout=1), name="docs"),
path('redoc/', get_swagger_view().with_ui('redoc', cache_timeout=1), name='redoc'),
re_path('^v2/swagger(?P<format>\.json|\.yaml)$',
get_swagger_view().without_ui(cache_timeout=1), name='schema-json'),
path('docs/v2/', get_swagger_view("v2").with_ui('swagger', cache_timeout=1), name="docs"),
path('redoc/v2/', get_swagger_view("v2").with_ui('redoc', cache_timeout=1), name='redoc'),
] ]
import datetime import datetime
import re
from django.http import HttpResponse, HttpResponseRedirect from django.http import HttpResponse, HttpResponseRedirect
from django.conf import settings from django.conf import settings
...@@ -8,6 +9,10 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -8,6 +9,10 @@ from django.utils.translation import ugettext_lazy as _
from django.db.models import Count from django.db.models import Count
from django.shortcuts import redirect from django.shortcuts import redirect
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from rest_framework.response import Response
from django.views.decorators.csrf import csrf_exempt
from django.http import HttpResponse
from django.utils.encoding import iri_to_uri
from users.models import User from users.models import User
from assets.models import Asset from assets.models import Asset
...@@ -188,3 +193,29 @@ class I18NView(View): ...@@ -188,3 +193,29 @@ class I18NView(View):
response = HttpResponseRedirect(referer_url) response = HttpResponseRedirect(referer_url)
response.set_cookie(settings.LANGUAGE_COOKIE_NAME, lang) response.set_cookie(settings.LANGUAGE_COOKIE_NAME, lang)
return response return response
api_url_pattern = re.compile(r'^/api/(?P<version>\w+)/(?P<app>\w+)/(?P<extra>.*)$')
class HttpResponseTemporaryRedirect(HttpResponse):
status_code = 307
def __init__(self, redirect_to):
HttpResponse.__init__(self)
self['Location'] = iri_to_uri(redirect_to)
@csrf_exempt
def redirect_format_api(request, *args, **kwargs):
_path, query = request.path, request.GET.urlencode()
matched = api_url_pattern.match(_path)
if matched:
version, app, extra = matched.groups()
_path = '/api/{app}/{version}/{extra}?{query}'.format(**{
"app": app, "version": version, "extra": extra,
"query": query
})
return HttpResponseTemporaryRedirect(_path)
else:
return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404)
...@@ -19,13 +19,13 @@ class TaskViewSet(viewsets.ModelViewSet): ...@@ -19,13 +19,13 @@ class TaskViewSet(viewsets.ModelViewSet):
queryset = Task.objects.all() queryset = Task.objects.all()
serializer_class = TaskSerializer serializer_class = TaskSerializer
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
label = None # label = None
help_text = '' # help_text = ''
class TaskRun(generics.RetrieveAPIView): class TaskRun(generics.RetrieveAPIView):
queryset = Task.objects.all() queryset = Task.objects.all()
serializer_class = TaskViewSet # serializer_class = TaskViewSet
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
......
# -*- coding: utf-8 -*-
#
# -*- coding: utf-8 -*-
#
from .terminal import *
from .session import *
from .task import *
# -*- coding: utf-8 -*-
#
import logging
from rest_framework.views import APIView, Response
from rest_framework_bulk import BulkModelViewSet
from common.utils import get_object_or_none
from common.permissions import IsOrgAdminOrAppUser
from ...models import Session, Task
from ...serializers import v1 as serializers
__all__ = ['TaskViewSet', 'KillSessionAPI']
logger = logging.getLogger(__file__)
class TaskViewSet(BulkModelViewSet):
queryset = Task.objects.all()
serializer_class = serializers.TaskSerializer
permission_classes = (IsOrgAdminOrAppUser,)
class KillSessionAPI(APIView):
permission_classes = (IsOrgAdminOrAppUser,)
model = Task
def post(self, request, *args, **kwargs):
validated_session = []
for session_id in request.data:
session = get_object_or_none(Session, id=session_id)
if session and not session.is_finished:
validated_session.append(session_id)
self.model.objects.create(
name="kill_session", args=session.id,
terminal=session.terminal,
)
return Response({"ok": validated_session})
# -*- coding: utf-8 -*-
#
from collections import OrderedDict
import logging
import uuid
from django.core.cache import cache
from django.shortcuts import get_object_or_404, redirect
from django.utils import timezone
from rest_framework import viewsets
from rest_framework.views import APIView, Response
from rest_framework.permissions import AllowAny
from common.utils import get_object_or_none
from common.permissions import IsAppUser, IsOrgAdminOrAppUser, IsSuperUser
from ...models import Terminal, Status, Session
from ...serializers import v1 as serializers
__all__ = [
'TerminalViewSet', 'TerminalTokenApi', 'StatusViewSet', 'TerminalConfig',
]
logger = logging.getLogger(__file__)
class TerminalViewSet(viewsets.ModelViewSet):
queryset = Terminal.objects.filter(is_deleted=False)
serializer_class = serializers.TerminalSerializer
permission_classes = (IsSuperUser,)
def create(self, request, *args, **kwargs):
name = request.data.get('name')
remote_ip = request.META.get('REMOTE_ADDR')
x_real_ip = request.META.get('X-Real-IP')
remote_addr = x_real_ip or remote_ip
terminal = get_object_or_none(Terminal, name=name, is_deleted=False)
if terminal:
msg = 'Terminal name %s already used' % name
return Response({'msg': msg}, status=409)
serializer = self.serializer_class(data={
'name': name, 'remote_addr': remote_addr
})
if serializer.is_valid():
terminal = serializer.save()
# App should use id, token get access key, if accepted
token = uuid.uuid4().hex
cache.set(token, str(terminal.id), 3600)
data = {"id": str(terminal.id), "token": token, "msg": "Need accept"}
return Response(data, status=201)
else:
data = serializer.errors
logger.error("Register terminal error: {}".format(data))
return Response(data, status=400)
def get_permissions(self):
if self.action == "create":
self.permission_classes = (AllowAny,)
return super().get_permissions()
class TerminalTokenApi(APIView):
permission_classes = (AllowAny,)
queryset = Terminal.objects.filter(is_deleted=False)
def get(self, request, *args, **kwargs):
try:
terminal = self.queryset.get(id=kwargs.get('terminal'))
except Terminal.DoesNotExist:
terminal = None
token = request.query_params.get("token")
if terminal is None:
return Response('May be reject by administrator', status=401)
if token is None or cache.get(token, "") != str(terminal.id):
return Response('Token is not valid', status=401)
if not terminal.is_accepted:
return Response("Terminal was not accepted yet", status=400)
if not terminal.user or not terminal.user.access_key.all():
return Response("No access key generate", status=401)
access_key = terminal.user.access_key.first()
data = OrderedDict()
data['access_key'] = {'id': access_key.id, 'secret': access_key.secret}
return Response(data, status=200)
class StatusViewSet(viewsets.ModelViewSet):
queryset = Status.objects.all()
serializer_class = serializers.StatusSerializer
permission_classes = (IsOrgAdminOrAppUser,)
session_serializer_class = serializers.SessionSerializer
task_serializer_class = serializers.TaskSerializer
def create(self, request, *args, **kwargs):
from_gua = self.request.query_params.get("from_guacamole", None)
if not from_gua:
self.handle_sessions()
super().create(request, *args, **kwargs)
tasks = self.request.user.terminal.task_set.filter(is_finished=False)
serializer = self.task_serializer_class(tasks, many=True)
return Response(serializer.data, status=201)
def handle_sessions(self):
sessions_active = []
for session_data in self.request.data.get("sessions", []):
self.create_or_update_session(session_data)
if not session_data["is_finished"]:
sessions_active.append(session_data["id"])
sessions_in_db_active = Session.objects.filter(
is_finished=False,
terminal=self.request.user.terminal.id
)
for session in sessions_in_db_active:
if str(session.id) not in sessions_active:
session.is_finished = True
session.date_end = timezone.now()
session.save()
def create_or_update_session(self, session_data):
session_data["terminal"] = self.request.user.terminal.id
_id = session_data["id"]
session = get_object_or_none(Session, id=_id)
if session:
serializer = serializers.SessionSerializer(
data=session_data, instance=session
)
else:
serializer = serializers.SessionSerializer(data=session_data)
if serializer.is_valid():
session = serializer.save()
return session
else:
msg = "session data is not valid {}: {}".format(
serializer.errors, str(serializer.data)
)
logger.error(msg)
return None
def get_queryset(self):
terminal_id = self.kwargs.get("terminal", None)
if terminal_id:
terminal = get_object_or_404(Terminal, id=terminal_id)
self.queryset = terminal.status_set.all()
return self.queryset
def perform_create(self, serializer):
serializer.validated_data["terminal"] = self.request.user.terminal
return super().perform_create(serializer)
def get_permissions(self):
if self.action == "create":
self.permission_classes = (IsAppUser,)
return super().get_permissions()
class TerminalConfig(APIView):
permission_classes = (IsAppUser,)
def get(self, request):
user = request.user
terminal = user.terminal
configs = terminal.config
return Response(configs, status=200)
\ No newline at end of file
# -*- coding: utf-8 -*-
#
from .terminal import *
# -*- coding: utf-8 -*-
#
from rest_framework import viewsets
from common.permissions import IsSuperUser, WithBootstrapToken
from ...models import Terminal
from ...serializers import v2 as serializers
__all__ = ['TerminalViewSet', 'TerminalRegistrationViewSet']
class TerminalViewSet(viewsets.ModelViewSet):
queryset = Terminal.objects.filter(is_deleted=False)
serializer_class = serializers.TerminalSerializer
permission_classes = [IsSuperUser]
class TerminalRegistrationViewSet(viewsets.ModelViewSet):
queryset = Terminal.objects.filter(is_deleted=False)
serializer_class = serializers.TerminalRegistrationSerializer
permission_classes = [WithBootstrapToken]
http_method_names = ['post']
...@@ -16,7 +16,7 @@ from .backends.command.models import AbstractSessionCommand ...@@ -16,7 +16,7 @@ from .backends.command.models import AbstractSessionCommand
class Terminal(models.Model): class Terminal(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=32, verbose_name=_('Name')) name = models.CharField(max_length=32, verbose_name=_('Name'))
remote_addr = models.CharField(max_length=128, verbose_name=_('Remote Address')) remote_addr = models.CharField(max_length=128, blank=True, verbose_name=_('Remote Address'))
ssh_port = models.IntegerField(verbose_name=_('SSH Port'), default=2222) ssh_port = models.IntegerField(verbose_name=_('SSH Port'), default=2222)
http_port = models.IntegerField(verbose_name=_('HTTP Port'), default=5000) http_port = models.IntegerField(verbose_name=_('HTTP Port'), default=5000)
command_storage = models.CharField(max_length=128, verbose_name=_("Command storage"), default='default') command_storage = models.CharField(max_length=128, verbose_name=_("Command storage"), default='default')
...@@ -68,6 +68,10 @@ class Terminal(models.Model): ...@@ -68,6 +68,10 @@ class Terminal(models.Model):
}) })
return configs return configs
@property
def service_account(self):
return self.user
def create_app_user(self): def create_app_user(self):
random = uuid.uuid4().hex[:6] random = uuid.uuid4().hex[:6]
user, access_key = User.create_app_user( user, access_key = User.create_app_user(
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.core.cache import cache from django.core.cache import cache
from django.utils import timezone
from rest_framework import serializers from rest_framework import serializers
from rest_framework_bulk.serializers import BulkListSerializer from rest_framework_bulk.serializers import BulkListSerializer
from common.mixins import BulkSerializerMixin from common.mixins import BulkSerializerMixin
from .models import Terminal, Status, Session, Task from ..models import Terminal, Status, Session, Task
from .backends import get_multi_command_storage from ..backends import get_multi_command_storage
class TerminalSerializer(serializers.ModelSerializer): class TerminalSerializer(serializers.ModelSerializer):
...@@ -33,6 +31,8 @@ class TerminalSerializer(serializers.ModelSerializer): ...@@ -33,6 +31,8 @@ class TerminalSerializer(serializers.ModelSerializer):
return cache.get(key) return cache.get(key)
class SessionSerializer(BulkSerializerMixin, serializers.ModelSerializer): class SessionSerializer(BulkSerializerMixin, serializers.ModelSerializer):
command_amount = serializers.SerializerMethodField() command_amount = serializers.SerializerMethodField()
command_store = get_multi_command_storage() command_store = get_multi_command_storage()
...@@ -71,3 +71,4 @@ class TaskSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -71,3 +71,4 @@ class TaskSerializer(BulkSerializerMixin, serializers.ModelSerializer):
class ReplaySerializer(serializers.Serializer): class ReplaySerializer(serializers.Serializer):
file = serializers.FileField() file = serializers.FileField()
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from common.utils import get_request_ip
from users.serializers.v2 import ServiceAccountRegistrationSerializer
from ..models import Terminal
__all__ = ['TerminalSerializer', 'TerminalRegistrationSerializer']
class TerminalSerializer(serializers.ModelSerializer):
class Meta:
model = Terminal
fields = [
'id', 'name', 'remote_addr', 'comment',
]
read_only_fields = ['id', 'remote_addr']
class TerminalRegistrationSerializer(serializers.ModelSerializer):
service_account = ServiceAccountRegistrationSerializer(read_only=True)
service_account_serializer = None
class Meta:
model = Terminal
fields = [
'id', 'name', 'remote_addr', 'comment', 'service_account'
]
read_only_fields = ['id', 'remote_addr', 'service_account']
def validate(self, attrs):
self.service_account_serializer = ServiceAccountRegistrationSerializer(data=attrs)
self.service_account_serializer.is_valid(raise_exception=True)
return attrs
def create(self, validated_data):
request = self.context.get('request')
sa = self.service_account_serializer.save()
instance = super().create(validated_data)
instance.is_accepted = True
instance.user = sa
instance.remote_addr = get_request_ip(request)
instance.save()
return instance
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.urls import path from django.urls import path, include
from rest_framework_bulk.routes import BulkRouter from rest_framework_bulk.routes import BulkRouter
from .. import api from ..api import v1 as api
app_name = 'terminal' app_name = 'terminal'
...@@ -20,7 +20,7 @@ router.register(r'status', api.StatusViewSet, 'status') ...@@ -20,7 +20,7 @@ router.register(r'status', api.StatusViewSet, 'status')
urlpatterns = [ urlpatterns = [
path('sessions/<uuid:pk>/replay/', path('sessions/<uuid:pk>/replay/',
api.SessionReplayV2ViewSet.as_view({'get': 'retrieve', 'post': 'create'}), api.SessionReplayViewSet.as_view({'get': 'retrieve', 'post': 'create'}),
name='session-replay'), name='session-replay'),
path('tasks/kill-session/', api.KillSessionAPI.as_view(), name='kill-session'), path('tasks/kill-session/', api.KillSessionAPI.as_view(), name='kill-session'),
path('terminal/<uuid:terminal>/access-key/', api.TerminalTokenApi.as_view(), path('terminal/<uuid:terminal>/access-key/', api.TerminalTokenApi.as_view(),
...@@ -33,3 +33,6 @@ urlpatterns = [ ...@@ -33,3 +33,6 @@ urlpatterns = [
] ]
urlpatterns += router.urls urlpatterns += router.urls
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
from django.urls import path
from rest_framework_bulk.routes import BulkRouter
from ..api import v2 as api
app_name = 'terminal'
router = BulkRouter()
router.register(r'terminal', api.TerminalViewSet, 'terminal')
router.register(r'terminal-registrations', api.TerminalRegistrationViewSet, 'terminal-registration')
urlpatterns = [
]
urlpatterns += router.urls
...@@ -17,8 +17,8 @@ from orgs.mixins import RootOrgViewMixin ...@@ -17,8 +17,8 @@ from orgs.mixins import RootOrgViewMixin
from ..serializers import UserSerializer from ..serializers import UserSerializer
from ..tasks import write_login_log_async from ..tasks import write_login_log_async
from ..models import User, LoginLog from ..models import User, LoginLog
from ..utils import check_user_valid, generate_token, \ from ..utils import check_user_valid, check_otp_code, \
check_otp_code, increase_login_failed_count, is_block_login, \ increase_login_failed_count, is_block_login, \
clean_failed_count clean_failed_count
from ..hands import Asset, SystemUser from ..hands import Asset, SystemUser
...@@ -79,7 +79,7 @@ class UserAuthApi(RootOrgViewMixin, APIView): ...@@ -79,7 +79,7 @@ class UserAuthApi(RootOrgViewMixin, APIView):
self.write_login_log(request, data) self.write_login_log(request, data)
# 登陆成功,清除原来的缓存计数 # 登陆成功,清除原来的缓存计数
clean_failed_count(username, ip) clean_failed_count(username, ip)
token = generate_token(request, user) token = user.create_bearer_token(request)
return Response( return Response(
{'token': token, 'user': self.serializer_class(user).data} {'token': token, 'user': self.serializer_class(user).data}
) )
...@@ -123,7 +123,6 @@ class UserAuthApi(RootOrgViewMixin, APIView): ...@@ -123,7 +123,6 @@ class UserAuthApi(RootOrgViewMixin, APIView):
'user_agent': user_agent, 'user_agent': user_agent,
} }
data.update(tmp_data) data.update(tmp_data)
write_login_log_async.delay(**data) write_login_log_async.delay(**data)
...@@ -185,7 +184,7 @@ class UserToken(APIView): ...@@ -185,7 +184,7 @@ class UserToken(APIView):
user = request.user user = request.user
msg = None msg = None
if user: if user:
token = generate_token(request, user) token = user.create_bearer_token(request)
return Response({'Token': token, 'Keyword': 'Bearer'}, status=200) return Response({'Token': token, 'Keyword': 'Bearer'}, status=200)
else: else:
return Response({'error': msg}, status=406) return Response({'error': msg}, status=406)
...@@ -223,7 +222,7 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView): ...@@ -223,7 +222,7 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView):
'status': True 'status': True
} }
self.write_login_log(request, data) self.write_login_log(request, data)
token = generate_token(request, user) token = user.create_bearer_token(request)
return Response( return Response(
{ {
'token': token, 'token': token,
......
...@@ -11,13 +11,14 @@ from rest_framework.permissions import IsAuthenticated ...@@ -11,13 +11,14 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework_bulk import BulkModelViewSet from rest_framework_bulk import BulkModelViewSet
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from common.permissions import IsOrgAdmin, IsCurrentUserOrReadOnly, \
IsOrgAdminOrAppUser
from common.mixins import IDInFilterMixin
from common.utils import get_logger
from orgs.utils import current_org
from ..serializers import UserSerializer, UserPKUpdateSerializer, \ from ..serializers import UserSerializer, UserPKUpdateSerializer, \
UserUpdateGroupSerializer, ChangeUserPasswordSerializer UserUpdateGroupSerializer, ChangeUserPasswordSerializer
from ..models import User from ..models import User
from orgs.utils import current_org
from common.permissions import IsOrgAdmin, IsCurrentUserOrReadOnly, IsOrgAdminOrAppUser
from common.mixins import IDInFilterMixin
from common.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -31,15 +32,16 @@ __all__ = [ ...@@ -31,15 +32,16 @@ __all__ = [
class UserViewSet(IDInFilterMixin, BulkModelViewSet): class UserViewSet(IDInFilterMixin, BulkModelViewSet):
filter_fields = ('username', 'email', 'name', 'id') filter_fields = ('username', 'email', 'name', 'id')
search_fields = filter_fields search_fields = filter_fields
queryset = User.objects.exclude(role="App") queryset = User.objects.all()
serializer_class = UserSerializer serializer_class = UserSerializer
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
pagination_class = LimitOffsetPagination pagination_class = LimitOffsetPagination
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
org_users = current_org.get_org_users() if current_org.is_real():
queryset = queryset.filter(id__in=org_users) org_users = current_org.get_org_users()
queryset = queryset.filter(id__in=org_users)
return queryset return queryset
def get_permissions(self): def get_permissions(self):
......
# -*- coding: utf-8 -*-
#
from .user import *
# -*- coding: utf-8 -*-
#
from rest_framework import viewsets
from common.permissions import WithBootstrapToken
from ...serializers import v2 as serializers
class ServiceAccountRegistrationViewSet(viewsets.ModelViewSet):
serializer_class = serializers.ServiceAccountRegistrationSerializer
permission_classes = (WithBootstrapToken,)
http_method_names = ['post']
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import base64
import uuid import uuid
import hashlib
import time import time
from django.core.cache import cache from django.core.cache import cache
...@@ -12,11 +10,10 @@ from django.utils.translation import ugettext as _ ...@@ -12,11 +10,10 @@ from django.utils.translation import ugettext as _
from django.utils.six import text_type from django.utils.six import text_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import authentication, exceptions, permissions from rest_framework import authentication, exceptions
from rest_framework.authentication import CSRFCheck from rest_framework.authentication import CSRFCheck
from common.utils import get_object_or_none, make_signature, http_to_unixtime from common.utils import get_object_or_none, make_signature, http_to_unixtime
from .utils import refresh_token
from .models import User, AccessKey, PrivateToken from .models import User, AccessKey, PrivateToken
...@@ -144,7 +141,6 @@ class AccessTokenAuthentication(authentication.BaseAuthentication): ...@@ -144,7 +141,6 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
if not user: if not user:
msg = _('Invalid token or cache refreshed.') msg = _('Invalid token or cache refreshed.')
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
refresh_token(token, user)
return user, None return user, None
......
...@@ -17,7 +17,7 @@ class AccessKey(models.Model): ...@@ -17,7 +17,7 @@ class AccessKey(models.Model):
secret = models.UUIDField(verbose_name='AccessKeySecret', secret = models.UUIDField(verbose_name='AccessKeySecret',
default=uuid.uuid4, editable=False) default=uuid.uuid4, editable=False)
user = models.ForeignKey(User, verbose_name='User', user = models.ForeignKey(User, verbose_name='User',
on_delete=models.CASCADE, related_name='access_key') on_delete=models.CASCADE, related_name='access_keys')
def get_id(self): def get_id(self):
return str(self.id) return str(self.id)
...@@ -25,6 +25,9 @@ class AccessKey(models.Model): ...@@ -25,6 +25,9 @@ class AccessKey(models.Model):
def get_secret(self): def get_secret(self):
return str(self.secret) return str(self.secret)
def get_full_value(self):
return '{}:{}'.format(self.id, self.secret)
def __str__(self): def __str__(self):
return str(self.id) return str(self.id)
......
...@@ -2,19 +2,20 @@ ...@@ -2,19 +2,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid import uuid
import base64
from collections import OrderedDict from collections import OrderedDict
from django.conf import settings from django.conf import settings
from django.contrib.auth.hashers import make_password from django.contrib.auth.hashers import make_password
from django.contrib.auth.models import AbstractUser, UserManager from django.contrib.auth.models import AbstractUser
from django.core import signing from django.core import signing
from django.core.cache import cache
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils import timezone from django.utils import timezone
from django.shortcuts import reverse from django.shortcuts import reverse
from common.utils import get_signer, date_expired_default from common.utils import get_signer, date_expired_default
from orgs.mixins import OrgManager
from orgs.utils import current_org from orgs.utils import current_org
...@@ -274,15 +275,38 @@ class User(AbstractUser): ...@@ -274,15 +275,38 @@ class User(AbstractUser):
token = PrivateToken.objects.create(user=self) token = PrivateToken.objects.create(user=self)
return token.key return token.key
def refresh_private_token(self):
from .authentication import PrivateToken
PrivateToken.objects.filter(user=self).delete()
return PrivateToken.objects.create(user=self)
def create_bearer_token(self, request=None):
expiration = settings.TOKEN_EXPIRATION or 3600
if request:
remote_addr = request.META.get('REMOTE_ADDR', '')
else:
remote_addr = '0.0.0.0'
if not isinstance(remote_addr, bytes):
remote_addr = remote_addr.encode("utf-8")
remote_addr = base64.b16encode(remote_addr) # .replace(b'=', '')
token = cache.get('%s_%s' % (self.id, remote_addr))
if not token:
token = uuid.uuid4().hex
cache.set(token, self.id, expiration)
cache.set('%s_%s' % (self.id, remote_addr), token, expiration)
return token
def refresh_bearer_token(self, token):
pass
def create_access_key(self): def create_access_key(self):
from . import AccessKey from . import AccessKey
access_key = AccessKey.objects.create(user=self) access_key = AccessKey.objects.create(user=self)
return access_key return access_key
def refresh_private_token(self): @property
from .authentication import PrivateToken def access_key(self):
PrivateToken.objects.filter(user=self).delete() return self.access_keys.first()
return PrivateToken.objects.create(user=self)
def is_member_of(self, user_group): def is_member_of(self, user_group):
if user_group in self.groups.all(): if user_group in self.groups.all():
...@@ -345,7 +369,8 @@ class User(AbstractUser): ...@@ -345,7 +369,8 @@ class User(AbstractUser):
'phone': self.phone, 'phone': self.phone,
'otp_level': self.otp_level, 'otp_level': self.otp_level,
'comment': self.comment, 'comment': self.comment,
'date_expired': self.date_expired.strftime('%Y-%m-%d %H:%M:%S') if self.date_expired is not None else None 'date_expired': self.date_expired.strftime('%Y-%m-%d %H:%M:%S') \
if self.date_expired is not None else None
}) })
@classmethod @classmethod
......
# -*- coding: utf-8 -*-
#
from .v1 import *
\ No newline at end of file
...@@ -7,14 +7,16 @@ from rest_framework_bulk import BulkListSerializer ...@@ -7,14 +7,16 @@ from rest_framework_bulk import BulkListSerializer
from common.utils import get_signer, validate_ssh_public_key from common.utils import get_signer, validate_ssh_public_key
from common.mixins import BulkSerializerMixin from common.mixins import BulkSerializerMixin
from .models import User, UserGroup from ..models import User, UserGroup
signer = get_signer() signer = get_signer()
class UserSerializer(BulkSerializerMixin, serializers.ModelSerializer): class UserSerializer(BulkSerializerMixin, serializers.ModelSerializer):
groups_display = serializers.SerializerMethodField() groups_display = serializers.SerializerMethodField()
groups = serializers.PrimaryKeyRelatedField(many=True, queryset = UserGroup.objects.all(), required=False) groups = serializers.PrimaryKeyRelatedField(
many=True, queryset=UserGroup.objects.all(), required=False
)
class Meta: class Meta:
model = User model = User
......
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from ..models import User, AccessKey
class AccessKeySerializer(serializers.ModelSerializer):
class Meta:
model = AccessKey
fields = ['id', 'secret']
read_only_fields = ['id', 'secret']
class ServiceAccountRegistrationSerializer(serializers.ModelSerializer):
access_key = AccessKeySerializer(read_only=True)
class Meta:
model = User
fields = ['id', 'name', 'access_key']
read_only_fields = ['id', 'access_key']
def get_username(self):
return self.initial_data.get('name')
def get_email(self):
name = self.initial_data.get('name')
return '{}@serviceaccount.local'.format(name)
def validate_name(self, name):
email = self.get_email()
username = self.get_username()
if User.objects.filter(email=email) or \
User.objects.filter(username=username):
raise serializers.ValidationError('name not unique', code='unique')
return name
def create(self, validated_data):
validated_data['email'] = self.get_email()
validated_data['username'] = self.get_username()
validated_data['role'] = User.ROLE_APP
instance = super().create(validated_data)
instance.create_access_key()
return instance
...@@ -15,8 +15,7 @@ router.register(r'groups', api.UserGroupViewSet, 'user-group') ...@@ -15,8 +15,7 @@ router.register(r'groups', api.UserGroupViewSet, 'user-group')
urlpatterns = [ urlpatterns = [
# path(r'', api.UserListView.as_view()), # path('token/', api.UserToken.as_view(), name='user-token'),
path('token/', api.UserToken.as_view(), name='user-token'),
path('connection-token/', api.UserConnectionTokenApi.as_view(), name='connection-token'), path('connection-token/', api.UserConnectionTokenApi.as_view(), name='connection-token'),
path('profile/', api.UserProfileApi.as_view(), name='user-profile'), path('profile/', api.UserProfileApi.as_view(), name='user-profile'),
path('auth/', api.UserAuthApi.as_view(), name='user-auth'), path('auth/', api.UserAuthApi.as_view(), name='user-auth'),
...@@ -31,5 +30,6 @@ urlpatterns = [ ...@@ -31,5 +30,6 @@ urlpatterns = [
path('users/<uuid:pk>/groups/', api.UserUpdateGroupApi.as_view(), name='user-update-group'), path('users/<uuid:pk>/groups/', api.UserUpdateGroupApi.as_view(), name='user-update-group'),
path('groups/<uuid:pk>/users/', api.UserGroupUpdateUserApi.as_view(), name='user-group-update-user'), path('groups/<uuid:pk>/users/', api.UserGroupUpdateUserApi.as_view(), name='user-group-update-user'),
] ]
urlpatterns += router.urls urlpatterns += router.urls
#!/usr/bin/env python
# ~*~ coding: utf-8 ~*~
#
from __future__ import absolute_import
from django.urls import path, include
from rest_framework_bulk.routes import BulkRouter
from ..api import v2 as api
app_name = 'users'
router = BulkRouter()
router.register(r'service-account-registrations',
api.ServiceAccountRegistrationViewSet,
'service-account-registration')
urlpatterns = [
# path('token/', api.UserToken.as_view(), name='user-token'),
]
urlpatterns += router.urls
...@@ -202,24 +202,6 @@ def check_user_valid(**kwargs): ...@@ -202,24 +202,6 @@ def check_user_valid(**kwargs):
return None, _('Password or SSH public key invalid') return None, _('Password or SSH public key invalid')
def refresh_token(token, user, expiration=settings.TOKEN_EXPIRATION or 3600):
cache.set(token, user.id, expiration)
def generate_token(request, user):
expiration = settings.TOKEN_EXPIRATION or 3600
remote_addr = request.META.get('REMOTE_ADDR', '')
if not isinstance(remote_addr, bytes):
remote_addr = remote_addr.encode("utf-8")
remote_addr = base64.b16encode(remote_addr) # .replace(b'=', '')
token = cache.get('%s_%s' % (user.id, remote_addr))
if not token:
token = uuid.uuid4().hex
cache.set(token, user.id, expiration)
cache.set('%s_%s' % (user.id, remote_addr), token, expiration)
return token
def validate_ip(ip): def validate_ip(ip):
try: try:
ipaddress.ip_address(ip) ipaddress.ip_address(ip)
......
...@@ -76,3 +76,4 @@ aliyun-python-sdk-core-v3==2.9.1 ...@@ -76,3 +76,4 @@ aliyun-python-sdk-core-v3==2.9.1
aliyun-python-sdk-ecs==4.10.1 aliyun-python-sdk-ecs==4.10.1
python-keycloak==0.13.3 python-keycloak==0.13.3
python-keycloak-client==0.1.3 python-keycloak-client==0.1.3
rest_condition==1.0.3
# 名称 用户名 密码
test123 testq12 test123123123
#!/usr/bin/env python
import requests
import sys
admin_username = 'admin'
admin_password = 'admin'
domain_url = 'http://localhost:8080'
class UserCreation:
headers = {}
def __init__(self, username, password, domain):
self.username = username
self.password = password
self.domain = domain
def auth(self):
url = "{}/api/users/v1/token/".format(self.domain)
data = {"username": self.username, "password": self.password}
resp = requests.post(url, data=data)
if resp.status_code == 200:
data = resp.json()
self.headers.update({
'Authorization': '{} {}'.format(data['Keyword'], data['Token'])
})
else:
print("用户名 或 密码 或 地址 不对")
sys.exit(2)
def get_user_detail(self, name, url):
resp = requests.get(url, headers=self.headers)
if resp.status_code == 200:
data = resp.json()
if len(data) < 1:
return None
for d in data:
if d['name'] == name:
return d
return None
return None
def get_system_user_detail(self, name):
url = '{}/api/assets/v1/system-user/?name={}'.format(self.domain, name)
return self.get_user_detail(name, url)
def create_system_user(self, info):
system_user = self.get_system_user_detail(info.get('name'))
if system_user:
return system_user
url = '{}/api/assets/v1/system-user/'.format(self.domain)
resp = requests.post(url, data=info, headers=self.headers, json=False)
if resp.status_code == 201:
return resp.json()
else:
print("创建系统用户失败: {} {}".format(info['name'], resp.content))
return None
def set_system_user_auth(self, system_user, info):
url = '{}/api/assets/v1/system-user/{}/auth-info/'.format(
self.domain, system_user['id']
)
data = {'password': info.get('password')}
resp = requests.patch(url, data=data, headers=self.headers)
if resp.status_code > 300:
print("设置系统用户密码失败: {} {}".format(
system_user.get('name'), resp.content.decode()
))
else:
return True
def get_admin_user_detail(self, name):
url = '{}/api/assets/v1/admin-user/?name={}'.format(self.domain, name)
return self.get_user_detail(name, url)
def create_admin_user(self, info):
admin_user = self.get_admin_user_detail(info.get('name'))
if admin_user:
return admin_user
url = '{}/api/assets/v1/admin-user/'.format(self.domain)
resp = requests.post(url, data=info, headers=self.headers, json=False)
if resp.status_code == 201:
return resp.json()
else:
print("创建管理用户失败: {} {}".format(info['name'], resp.content.decode()))
return None
def set_admin_user_auth(self, admin_user, info):
url = '{}/api/assets/v1/admin-user/{}/auth/'.format(
self.domain, admin_user['id']
)
data = {'password': info.get('password')}
resp = requests.patch(url, data=data, headers=self.headers)
if resp.status_code > 300:
print("设置管理用户密码失败: {} {}".format(
admin_user.get('name'), resp.content.decode()
))
else:
return True
def create_system_users(self):
print("#"*10, " 开始创建系统用户 ", "#"*10)
users = []
f = open('system_users.txt')
for line in f:
line = line.strip()
if not line or line.startswith('#'):
continue
name, username, password, protocol, auto_push = line.split()[:5]
info = {
"name": name,
"username": username,
"password": password,
"protocol": protocol,
"auto_push": bool(int(auto_push)),
"login_mode": "auto"
}
users.append(info)
for i, info in enumerate(users, start=1):
system_user = self.create_system_user(info)
if system_user and self.set_system_user_auth(system_user, info):
print("[{}] 创建系统用户成功: {}".format(i, system_user['name']))
def create_admin_users(self):
print("\n", "#"*10, " 开始创建管理用户 ", "#"*10)
users = []
f = open('admin_users.txt')
for line in f:
line = line.strip()
if not line or line.startswith('#'):
continue
name, username, password = line.split()[:3]
info = {
"name": name,
"username": username,
"password": password,
}
users.append(info)
for i, info in enumerate(users, start=1):
admin_user = self.create_admin_user(info)
if admin_user and self.set_admin_user_auth(admin_user, info):
print("[{}] 创建管理用户成功: {}".format(i, admin_user['name']))
def main():
api = UserCreation(username=admin_username,
password=admin_password,
domain=domain_url)
api.auth()
api.create_system_users()
api.create_admin_users()
if __name__ == '__main__':
main()
# 名称 用户名 密码 协议[ssh,rdp] 自动推送[0不推送,1自动推送]
test123 test123 test123123123 ssh 0
test1323 test123 test123123123 ssh 0
1. 安装依赖包
$ pip install requests
2. 设置账号密码和地址
$ vim bulk_create_user.py # 设置为正确的值
admin_username = 'admin'
admin_password = 'admin'
domain_url = 'http://localhost:8081'
3. 配置需要添加的系统用户
$ vim system_users.txt
# 名称 用户名 密码
test123 testq12 test123123123
3. 配置需要添加的系统用户
$ vim system_users.txt
# 名称 用户名 密码 协议[ssh,rdp] 自动推送[0不推送,1自动推送]
test123 test123 test123123123 ssh 0
4. 运行
$ python bulk_create_user.py
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment