Unverified Commit 164f48e1 authored by 老广's avatar 老广 Committed by GitHub

Dev beta (#3048)

* [Update] 统一url地址

* [Update] 修改api

* [Update] 使用规范的签名

* [Update] 修改url

* [Update] 修改swagger

* [Update] 添加serializer class避免报错

* [Update] 修改token

* [Update] 支持api key

* [Update] 支持生成api key

* [Update] 修改api重定向

* [Update] 修改翻译

* [Update] 添加说明文档

* [Update] 修复浏览器关闭后session不失效的问题

* [Update] 修改一些内容

* [Update] 修改 jms脚本

* [Update] 修改重定向

* [Update] 修改搜索trim

* [Update] 修改搜索trim

* [Update] 添加sys log

* [Bugfix] 修改登陆错误

* [Update] 优化User操作private_token的接口 (#3091)

* [Update] 优化User操作private_token的接口

* [Update] 优化User操作private_token的接口 2

* [Bugfix] 解决授权了一个节点,当移动节点后,被移动的节点下的资产会放到未分组节点下的问题

* [Update] 升级jquery

* [Update] 默认使用page

* [Update] 修改使用Orgmodel view set

* [Update] 支持 nv的硬盘 https://github.com/jumpserver/jumpserver/issues/1804

* [UPdate] 解决命令执行宽度问题

* [Update] 优化节点

* [Update] 修改nodes过多时创建比较麻烦

* [Update] 修改导入

* [Update] 节点获取更新

* [Update] 修改nodes

* [Update] nodes显示full value

* [Update] 统一使用nodes select2 函数

* [Update] 修改磁盘大小小数

* [Update] 修改 Node service

* [Update] 优化授权节点

* [Update] 修改 node permission

* [Update] 修改asset permission

* [Stash]

* [Update] 修改node assets api

* [Update] 修改tree service,支持资产数量

* [Update] 修改暂时完成

* [Update] 修改一些bug
parent fe6f7bcf
...@@ -3,9 +3,8 @@ ...@@ -3,9 +3,8 @@
from rest_framework import generics from rest_framework import generics
from rest_framework.pagination import LimitOffsetPagination
from rest_framework_bulk import BulkModelViewSet
from orgs.mixins.api import OrgBulkModelViewSet
from ..hands import IsOrgAdmin, IsAppUser from ..hands import IsOrgAdmin, IsAppUser
from ..models import RemoteApp from ..models import RemoteApp
from ..serializers import RemoteAppSerializer, RemoteAppConnectionInfoSerializer from ..serializers import RemoteAppSerializer, RemoteAppConnectionInfoSerializer
...@@ -16,13 +15,12 @@ __all__ = [ ...@@ -16,13 +15,12 @@ __all__ = [
] ]
class RemoteAppViewSet(BulkModelViewSet): class RemoteAppViewSet(OrgBulkModelViewSet):
filter_fields = ('name',) filter_fields = ('name',)
search_fields = filter_fields search_fields = filter_fields
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
queryset = RemoteApp.objects.all() queryset = RemoteApp.objects.all()
serializer_class = RemoteAppSerializer serializer_class = RemoteAppSerializer
pagination_class = LimitOffsetPagination
class RemoteAppConnectionInfoApi(generics.RetrieveAPIView): class RemoteAppConnectionInfoApi(generics.RetrieveAPIView):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django import forms from django import forms
from orgs.mixins import OrgModelForm from orgs.mixins.forms import OrgModelForm
from assets.models import SystemUser from assets.models import SystemUser
from ..models import RemoteApp from ..models import RemoteApp
......
...@@ -5,7 +5,7 @@ import uuid ...@@ -5,7 +5,7 @@ import uuid
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgModelMixin from orgs.mixins.models import OrgModelMixin
from common.fields.model import EncryptJsonDictTextField from common.fields.model import EncryptJsonDictTextField
from .. import const from .. import const
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from rest_framework import serializers from rest_framework import serializers
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .. import const from .. import const
from ..models import RemoteApp from ..models import RemoteApp
......
# coding:utf-8 # coding:utf-8
# #
from django.urls import path from django.urls import path, re_path
from rest_framework_bulk.routes import BulkRouter from rest_framework_bulk.routes import BulkRouter
from common import api as capi
from .. import api from .. import api
app_name = 'applications' app_name = 'applications'
router = BulkRouter() router = BulkRouter()
router.register(r'remote-app', api.RemoteAppViewSet, 'remote-app') router.register(r'remote-apps', api.RemoteAppViewSet, 'remote-app')
urlpatterns = [ urlpatterns = [
path('remote-apps/<uuid:pk>/connection-info/', path('remote-apps/<uuid:pk>/connection-info/',
api.RemoteAppConnectionInfoApi.as_view(), api.RemoteAppConnectionInfoApi.as_view(),
name='remote-app-connection-info') name='remote-app-connection-info')
] ]
old_version_urlpatterns = [
re_path('(?P<resource>remote-app)/.*', capi.redirect_plural_name_api)
]
urlpatterns += router.urls urlpatterns += router.urls + old_version_urlpatterns
...@@ -17,8 +17,7 @@ from django.db import transaction ...@@ -17,8 +17,7 @@ from django.db import transaction
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from rest_framework import generics from rest_framework import generics
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from rest_framework.pagination import LimitOffsetPagination
from common.mixins import IDInCacheFilterMixin from common.mixins import IDInCacheFilterMixin
from common.utils import get_logger from common.utils import get_logger
...@@ -36,7 +35,7 @@ __all__ = [ ...@@ -36,7 +35,7 @@ __all__ = [
] ]
class AdminUserViewSet(IDInCacheFilterMixin, BulkModelViewSet): class AdminUserViewSet(OrgBulkModelViewSet):
""" """
Admin user api set, for add,delete,update,list,retrieve resource Admin user api set, for add,delete,update,list,retrieve resource
""" """
...@@ -46,11 +45,6 @@ class AdminUserViewSet(IDInCacheFilterMixin, BulkModelViewSet): ...@@ -46,11 +45,6 @@ class AdminUserViewSet(IDInCacheFilterMixin, BulkModelViewSet):
queryset = AdminUser.objects.all() queryset = AdminUser.objects.all()
serializer_class = serializers.AdminUserSerializer serializer_class = serializers.AdminUserSerializer
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
pagination_class = LimitOffsetPagination
def get_queryset(self):
queryset = super().get_queryset().all()
return queryset
class AdminUserAuthApi(generics.UpdateAPIView): class AdminUserAuthApi(generics.UpdateAPIView):
...@@ -98,7 +92,6 @@ class AdminUserTestConnectiveApi(generics.RetrieveAPIView): ...@@ -98,7 +92,6 @@ class AdminUserTestConnectiveApi(generics.RetrieveAPIView):
class AdminUserAssetsListView(generics.ListAPIView): class AdminUserAssetsListView(generics.ListAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.AssetSimpleSerializer serializer_class = serializers.AssetSimpleSerializer
pagination_class = LimitOffsetPagination
filter_fields = ("hostname", "ip") filter_fields = ("hostname", "ip")
http_method_names = ['get'] http_method_names = ['get']
search_fields = filter_fields search_fields = filter_fields
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid
import random import random
from rest_framework import generics from rest_framework import generics
from rest_framework.views import APIView
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet
from rest_framework_bulk import ListBulkCreateUpdateDestroyAPIView
from rest_framework.pagination import LimitOffsetPagination
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.urls import reverse_lazy
from django.core.cache import cache
from django.db.models import Q from django.db.models import Q
from common.mixins import IDInCacheFilterMixin, ApiMessageMixin
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser
from orgs.mixins import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from ..const import CACHE_KEY_ASSET_BULK_UPDATE_ID_PREFIX
from ..models import Asset, AdminUser, Node from ..models import Asset, AdminUser, Node
from .. import serializers from .. import serializers
from ..tasks import update_asset_hardware_info_manual, \ from ..tasks import update_asset_hardware_info_manual, \
...@@ -31,9 +21,9 @@ from ..utils import LabelFilter ...@@ -31,9 +21,9 @@ from ..utils import LabelFilter
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
'AssetViewSet', 'AssetListUpdateApi', 'AssetViewSet',
'AssetRefreshHardwareApi', 'AssetAdminUserTestApi', 'AssetRefreshHardwareApi', 'AssetAdminUserTestApi',
'AssetGatewayApi', 'AssetBulkUpdateSelectAPI' 'AssetGatewayApi',
] ]
...@@ -46,7 +36,6 @@ class AssetViewSet(LabelFilter, OrgBulkModelViewSet): ...@@ -46,7 +36,6 @@ class AssetViewSet(LabelFilter, OrgBulkModelViewSet):
ordering_fields = ("hostname", "ip", "port", "cpu_cores") ordering_fields = ("hostname", "ip", "port", "cpu_cores")
queryset = Asset.objects.all() queryset = Asset.objects.all()
serializer_class = serializers.AssetSerializer serializer_class = serializers.AssetSerializer
pagination_class = LimitOffsetPagination
permission_classes = (IsOrgAdminOrAppUser,) permission_classes = (IsOrgAdminOrAppUser,)
success_message = _("%(hostname)s was %(action)s successfully") success_message = _("%(hostname)s was %(action)s successfully")
...@@ -73,19 +62,21 @@ class AssetViewSet(LabelFilter, OrgBulkModelViewSet): ...@@ -73,19 +62,21 @@ class AssetViewSet(LabelFilter, OrgBulkModelViewSet):
node = get_object_or_404(Node, id=node_id) node = get_object_or_404(Node, id=node_id)
show_current_asset = self.request.query_params.get("show_current_asset") in ('1', 'true') show_current_asset = self.request.query_params.get("show_current_asset") in ('1', 'true')
# 当前节点是顶层节点, 并且仅显示直接资产
if node.is_root() and show_current_asset: if node.is_root() and show_current_asset:
queryset = queryset.filter( queryset = queryset.filter(
Q(nodes=node_id) | Q(nodes__isnull=True) Q(nodes=node_id) | Q(nodes__isnull=True)
) ).distinct()
# 当前节点是顶层节点,显示所有资产
elif node.is_root() and not show_current_asset: elif node.is_root() and not show_current_asset:
pass return queryset
# 当前节点不是鼎城节点,只显示直接资产
elif not node.is_root() and show_current_asset: elif not node.is_root() and show_current_asset:
queryset = queryset.filter(nodes=node) queryset = queryset.filter(nodes=node)
else: else:
queryset = queryset.filter( children = node.get_all_children(with_self=True)
nodes__key__regex='^{}(:[0-9]+)*$'.format(node.key), queryset = queryset.filter(nodes__in=children).distinct()
) return queryset
return queryset.distinct()
def filter_admin_user_id(self, queryset): def filter_admin_user_id(self, queryset):
admin_user_id = self.request.query_params.get('admin_user_id') admin_user_id = self.request.query_params.get('admin_user_id')
...@@ -102,30 +93,6 @@ class AssetViewSet(LabelFilter, OrgBulkModelViewSet): ...@@ -102,30 +93,6 @@ class AssetViewSet(LabelFilter, OrgBulkModelViewSet):
return queryset return queryset
class AssetListUpdateApi(IDInCacheFilterMixin, ListBulkCreateUpdateDestroyAPIView):
"""
Asset bulk update api
"""
queryset = Asset.objects.all()
serializer_class = serializers.AssetSerializer
permission_classes = (IsOrgAdmin,)
class AssetBulkUpdateSelectAPI(APIView):
permission_classes = (IsOrgAdmin,)
def post(self, request, *args, **kwargs):
assets_id = request.data.get('assets_id', '')
if assets_id:
spm = uuid.uuid4().hex
key = CACHE_KEY_ASSET_BULK_UPDATE_ID_PREFIX.format(spm)
cache.set(key, assets_id, 300)
url = reverse_lazy('assets:asset-bulk-update') + '?spm=%s' % spm
return Response({'url': url})
error = _('Please select assets that need to be updated')
return Response({'error': error}, status=400)
class AssetRefreshHardwareApi(generics.RetrieveAPIView): class AssetRefreshHardwareApi(generics.RetrieveAPIView):
""" """
Refresh asset hardware info Refresh asset hardware info
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# #
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework import viewsets, status, generics from rest_framework import generics
from rest_framework.pagination import LimitOffsetPagination
from rest_framework import filters from rest_framework import filters
from rest_framework_bulk import BulkModelViewSet from rest_framework_bulk import BulkModelViewSet
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.http import Http404
from common.permissions import IsOrgAdminOrAppUser, NeedMFAVerify from common.permissions import IsOrgAdminOrAppUser, NeedMFAVerify
from common.utils import get_object_or_none, get_logger from common.utils import get_object_or_none, get_logger
...@@ -53,7 +53,6 @@ class AssetUserSearchBackend(filters.BaseFilterBackend): ...@@ -53,7 +53,6 @@ class AssetUserSearchBackend(filters.BaseFilterBackend):
class AssetUserViewSet(IDInCacheFilterMixin, BulkModelViewSet): class AssetUserViewSet(IDInCacheFilterMixin, BulkModelViewSet):
pagination_class = LimitOffsetPagination
serializer_class = serializers.AssetUserSerializer serializer_class = serializers.AssetUserSerializer
permission_classes = [IsOrgAdminOrAppUser] permission_classes = [IsOrgAdminOrAppUser]
http_method_names = ['get', 'post'] http_method_names = ['get', 'post']
...@@ -67,6 +66,9 @@ class AssetUserViewSet(IDInCacheFilterMixin, BulkModelViewSet): ...@@ -67,6 +66,9 @@ class AssetUserViewSet(IDInCacheFilterMixin, BulkModelViewSet):
AssetUserFilterBackend, AssetUserSearchBackend, AssetUserFilterBackend, AssetUserSearchBackend,
) )
def allow_bulk_destroy(self, qs, filtered):
return False
def get_queryset(self): def get_queryset(self):
# 尽可能先返回更少的数据 # 尽可能先返回更少的数据
username = self.request.GET.get('username') username = self.request.GET.get('username')
...@@ -115,14 +117,6 @@ class AssetUserAuthInfoApi(generics.RetrieveAPIView): ...@@ -115,14 +117,6 @@ class AssetUserAuthInfoApi(generics.RetrieveAPIView):
serializer_class = serializers.AssetUserAuthInfoSerializer serializer_class = serializers.AssetUserAuthInfoSerializer
permission_classes = [IsOrgAdminOrAppUser, NeedMFAVerify] permission_classes = [IsOrgAdminOrAppUser, NeedMFAVerify]
def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
status_code = status.HTTP_200_OK
if not instance:
status_code = status.HTTP_400_BAD_REQUEST
return Response(serializer.data, status=status_code)
def get_object(self): def get_object(self):
query_params = self.request.query_params query_params = self.request.query_params
username = query_params.get('username') username = query_params.get('username')
...@@ -133,8 +127,7 @@ class AssetUserAuthInfoApi(generics.RetrieveAPIView): ...@@ -133,8 +127,7 @@ class AssetUserAuthInfoApi(generics.RetrieveAPIView):
manger = AssetUserManager() manger = AssetUserManager()
instance = manger.get(username, asset, prefer=prefer) instance = manger.get(username, asset, prefer=prefer)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) raise Http404("Not found")
return None
else: else:
return instance return instance
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from rest_framework_bulk import BulkModelViewSet
from rest_framework.pagination import LimitOffsetPagination
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from orgs.mixins.api import OrgBulkModelViewSet
from ..hands import IsOrgAdmin from ..hands import IsOrgAdmin
from ..models import CommandFilter, CommandFilterRule from ..models import CommandFilter, CommandFilterRule
from .. import serializers from .. import serializers
...@@ -13,21 +12,19 @@ from .. import serializers ...@@ -13,21 +12,19 @@ from .. import serializers
__all__ = ['CommandFilterViewSet', 'CommandFilterRuleViewSet'] __all__ = ['CommandFilterViewSet', 'CommandFilterRuleViewSet']
class CommandFilterViewSet(BulkModelViewSet): class CommandFilterViewSet(OrgBulkModelViewSet):
filter_fields = ("name",) filter_fields = ("name",)
search_fields = filter_fields search_fields = filter_fields
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
queryset = CommandFilter.objects.all() queryset = CommandFilter.objects.all()
serializer_class = serializers.CommandFilterSerializer serializer_class = serializers.CommandFilterSerializer
pagination_class = LimitOffsetPagination
class CommandFilterRuleViewSet(BulkModelViewSet): class CommandFilterRuleViewSet(OrgBulkModelViewSet):
filter_fields = ("content",) filter_fields = ("content",)
search_fields = filter_fields search_fields = filter_fields
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.CommandFilterRuleSerializer serializer_class = serializers.CommandFilterRuleSerializer
pagination_class = LimitOffsetPagination
def get_queryset(self): def get_queryset(self):
fpk = self.kwargs.get('filter_pk') fpk = self.kwargs.get('filter_pk')
......
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
from rest_framework_bulk import BulkModelViewSet
from rest_framework.views import APIView, Response from rest_framework.views import APIView, Response
from rest_framework.pagination import LimitOffsetPagination
from django.views.generic.detail import SingleObjectMixin from django.views.generic.detail import SingleObjectMixin
from common.utils import get_logger from common.utils import get_logger
from common.permissions import IsOrgAdmin, IsAppUser, IsOrgAdminOrAppUser from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser
from orgs.mixins.api import OrgBulkModelViewSet
from ..models import Domain, Gateway from ..models import Domain, Gateway
from .. import serializers from .. import serializers
...@@ -16,11 +14,10 @@ logger = get_logger(__file__) ...@@ -16,11 +14,10 @@ logger = get_logger(__file__)
__all__ = ['DomainViewSet', 'GatewayViewSet', "GatewayTestConnectionApi"] __all__ = ['DomainViewSet', 'GatewayViewSet', "GatewayTestConnectionApi"]
class DomainViewSet(BulkModelViewSet): class DomainViewSet(OrgBulkModelViewSet):
queryset = Domain.objects.all() queryset = Domain.objects.all()
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.DomainSerializer serializer_class = serializers.DomainSerializer
pagination_class = LimitOffsetPagination
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset().all() queryset = super().get_queryset().all()
...@@ -37,13 +34,12 @@ class DomainViewSet(BulkModelViewSet): ...@@ -37,13 +34,12 @@ class DomainViewSet(BulkModelViewSet):
return super().get_permissions() return super().get_permissions()
class GatewayViewSet(BulkModelViewSet): class GatewayViewSet(OrgBulkModelViewSet):
filter_fields = ("domain__name", "name", "username", "ip", "domain") filter_fields = ("domain__name", "name", "username", "ip", "domain")
search_fields = filter_fields search_fields = filter_fields
queryset = Gateway.objects.all() queryset = Gateway.objects.all()
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.GatewaySerializer serializer_class = serializers.GatewaySerializer
pagination_class = LimitOffsetPagination
class GatewayTestConnectionApi(SingleObjectMixin, APIView): class GatewayTestConnectionApi(SingleObjectMixin, APIView):
......
...@@ -13,11 +13,10 @@ ...@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from rest_framework.pagination import LimitOffsetPagination
from django.db.models import Count from django.db.models import Count
from common.utils import get_logger from common.utils import get_logger
from orgs.mixins import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from ..hands import IsOrgAdmin from ..hands import IsOrgAdmin
from ..models import Label from ..models import Label
from .. import serializers from .. import serializers
...@@ -32,7 +31,6 @@ class LabelViewSet(OrgBulkModelViewSet): ...@@ -32,7 +31,6 @@ class LabelViewSet(OrgBulkModelViewSet):
search_fields = filter_fields search_fields = filter_fields
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.LabelSerializer serializer_class = serializers.LabelSerializer
pagination_class = LimitOffsetPagination
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
if request.query_params.get("distinct"): if request.query_params.get("distinct"):
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from rest_framework import generics, mixins, viewsets import time
from rest_framework import generics
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.response import Response from rest_framework.response import Response
...@@ -22,11 +24,11 @@ from django.shortcuts import get_object_or_404 ...@@ -22,11 +24,11 @@ from django.shortcuts import get_object_or_404
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer from common.tree import TreeNodeSerializer
from orgs.mixins.api import OrgModelViewSet
from ..hands import IsOrgAdmin from ..hands import IsOrgAdmin
from ..models import Node from ..models import Node
from ..tasks import update_assets_hardware_info_util, test_asset_connectivity_util from ..tasks import update_assets_hardware_info_util, test_asset_connectivity_util
from .. import serializers from .. import serializers
from ..utils import NodeUtil
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -39,29 +41,25 @@ __all__ = [ ...@@ -39,29 +41,25 @@ __all__ = [
] ]
class NodeViewSet(viewsets.ModelViewSet): class NodeViewSet(OrgModelViewSet):
filter_fields = ('value', 'key', ) filter_fields = ('value', 'key', 'id')
search_fields = filter_fields search_fields = ('value', )
queryset = Node.objects.all() queryset = Node.objects.all()
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer serializer_class = serializers.NodeSerializer
# 仅支持根节点指直接创建,子节点下的节点需要通过children接口创建
def perform_create(self, serializer): def perform_create(self, serializer):
child_key = Node.root().get_next_child_key() child_key = Node.root().get_next_child_key()
serializer.validated_data["key"] = child_key serializer.validated_data["key"] = child_key
serializer.save() serializer.save()
def update(self, request, *args, **kwargs): def perform_update(self, serializer):
node = self.get_object() node = self.get_object()
if node.is_root(): if node.is_root() and node.value != serializer.validated_data['value']:
node_value = node.value msg = _("You can't update the root node name")
post_value = request.data.get('value') raise ValidationError({"error": msg})
if node_value != post_value: return super().perform_update(serializer)
return Response(
{"msg": _("You can't update the root node name")},
status=400
)
return super().update(request, *args, **kwargs)
class NodeListAsTreeApi(generics.ListAPIView): class NodeListAsTreeApi(generics.ListAPIView):
...@@ -79,21 +77,72 @@ class NodeListAsTreeApi(generics.ListAPIView): ...@@ -79,21 +77,72 @@ class NodeListAsTreeApi(generics.ListAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = TreeNodeSerializer serializer_class = TreeNodeSerializer
@staticmethod
def to_tree_queryset(queryset):
queryset = [node.as_tree_node() for node in queryset]
return queryset
def get_queryset(self): def get_queryset(self):
queryset = Node.objects.all() queryset = Node.objects.all()
util = NodeUtil()
nodes = util.get_nodes_by_queryset(queryset)
queryset = [node.as_tree_node() for node in nodes]
return queryset return queryset
@staticmethod def filter_queryset(self, queryset):
def refresh_nodes(queryset): queryset = super().filter_queryset(queryset)
Node.expire_nodes_assets_amount() queryset = self.to_tree_queryset(queryset)
Node.expire_nodes_full_value() return queryset
class NodeChildrenApi(generics.ListCreateAPIView):
queryset = Node.objects.all()
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer
instance = None
def initial(self, request, *args, **kwargs):
self.instance = self.get_object()
return super().initial(request, *args, **kwargs)
def perform_create(self, serializer):
data = serializer.validated_data
_id = data.get("id")
value = data.get("value")
if not value:
value = self.instance.get_next_child_preset_name()
node = self.instance.create_child(value=value, _id=_id)
# 避免查询 full value
node._full_value = node.value
serializer.instance = node
def get_object(self):
pk = self.kwargs.get('pk') or self.request.query_params.get('id')
key = self.request.query_params.get("key")
if not pk and not key:
node = Node.root()
return node
if pk:
node = get_object_or_404(Node, pk=pk)
else:
node = get_object_or_404(Node, key=key)
return node
def get_queryset(self):
query_all = self.request.query_params.get("all", "0") == "all"
if not self.instance:
return Node.objects.none()
if self.instance.is_root():
with_self = True
else:
with_self = False
if query_all:
queryset = self.instance.get_all_children(with_self=with_self)
else:
queryset = self.instance.get_children(with_self=with_self)
return queryset return queryset
class NodeChildrenAsTreeApi(generics.ListAPIView): class NodeChildrenAsTreeApi(NodeChildrenApi):
""" """
节点子节点作为树返回, 节点子节点作为树返回,
[ [
...@@ -106,39 +155,26 @@ class NodeChildrenAsTreeApi(generics.ListAPIView): ...@@ -106,39 +155,26 @@ class NodeChildrenAsTreeApi(generics.ListAPIView):
] ]
""" """
permission_classes = (IsOrgAdmin,)
serializer_class = TreeNodeSerializer serializer_class = TreeNodeSerializer
node = None http_method_names = ['get']
is_root = False
def get_queryset(self): def get_queryset(self):
self.check_need_refresh_nodes() queryset = super().get_queryset()
node_key = self.request.query_params.get('key')
util = NodeUtil()
# 是否包含自己
with_self = False
if not node_key:
node_key = Node.root().key
with_self = True
self.node = util.get_node_by_key(node_key)
queryset = self.node.get_children(with_self=with_self)
queryset = [node.as_tree_node() for node in queryset] queryset = [node.as_tree_node() for node in queryset]
queryset = self.add_assets_if_need(queryset)
queryset = sorted(queryset) queryset = sorted(queryset)
return queryset return queryset
def filter_assets(self, queryset): def add_assets_if_need(self, queryset):
include_assets = self.request.query_params.get('assets', '0') == '1' include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets: if not include_assets:
return queryset return queryset
assets = self.node.get_assets().only( assets = self.instance.get_assets().only(
"id", "hostname", "ip", 'platform', "os", "org_id", "protocols", "id", "hostname", "ip", 'platform', "os",
"org_id", "protocols",
) )
for asset in assets: for asset in assets:
queryset.append(asset.as_tree_node(self.node)) queryset.append(asset.as_tree_node(self.instance))
return queryset
def filter_queryset(self, queryset):
queryset = self.filter_assets(queryset)
return queryset return queryset
def check_need_refresh_nodes(self): def check_need_refresh_nodes(self):
...@@ -146,59 +182,6 @@ class NodeChildrenAsTreeApi(generics.ListAPIView): ...@@ -146,59 +182,6 @@ class NodeChildrenAsTreeApi(generics.ListAPIView):
Node.refresh_nodes() Node.refresh_nodes()
class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView):
queryset = Node.objects.all()
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer
instance = None
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
def post(self, request, *args, **kwargs):
instance = self.get_object()
if not request.data.get("value"):
request.data["value"] = instance.get_next_child_preset_name()
return super().post(request, *args, **kwargs)
def create(self, request, *args, **kwargs):
instance = self.get_object()
value = request.data.get("value")
_id = request.data.get('id') or None
values = [child.value for child in instance.get_children()]
if value in values:
raise ValidationError(
'The same level node name cannot be the same'
)
node = instance.create_child(value=value, _id=_id)
return Response(self.serializer_class(instance=node).data, status=201)
def get_object(self):
pk = self.kwargs.get('pk') or self.request.query_params.get('id')
if not pk:
node = Node.root()
else:
node = get_object_or_404(Node, pk=pk)
return node
def get_queryset(self):
queryset = []
query_all = self.request.query_params.get("all")
node = self.get_object()
if node is None:
node = Node.root()
node.assets__count = node.get_all_assets().count()
queryset.append(node)
if query_all:
children = node.get_all_children()
else:
children = node.get_children()
queryset.extend(list(children))
return queryset
class NodeAssetsApi(generics.ListAPIView): class NodeAssetsApi(generics.ListAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.AssetSerializer serializer_class = serializers.AssetSerializer
......
...@@ -16,18 +16,17 @@ ...@@ -16,18 +16,17 @@
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from rest_framework import generics from rest_framework import generics
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet
from rest_framework.pagination import LimitOffsetPagination
from common.serializers import CeleryTaskSerializer
from common.utils import get_logger from common.utils import get_logger
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser
from common.mixins import IDInCacheFilterMixin from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import OrgBulkModelViewSet
from ..models import SystemUser, Asset from ..models import SystemUser, Asset
from .. import serializers from .. import serializers
from ..tasks import push_system_user_to_assets_manual, \ from ..tasks import (
test_system_user_connectivity_manual, push_system_user_a_asset_manual, \ push_system_user_to_assets_manual, test_system_user_connectivity_manual,
test_system_user_connectivity_a_asset push_system_user_a_asset_manual, test_system_user_connectivity_a_asset,
)
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -49,7 +48,6 @@ class SystemUserViewSet(OrgBulkModelViewSet): ...@@ -49,7 +48,6 @@ class SystemUserViewSet(OrgBulkModelViewSet):
queryset = SystemUser.objects.all() queryset = SystemUser.objects.all()
serializer_class = serializers.SystemUserSerializer serializer_class = serializers.SystemUserSerializer
permission_classes = (IsOrgAdminOrAppUser,) permission_classes = (IsOrgAdminOrAppUser,)
pagination_class = LimitOffsetPagination
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset().all() queryset = super().get_queryset().all()
...@@ -92,6 +90,7 @@ class SystemUserPushApi(generics.RetrieveAPIView): ...@@ -92,6 +90,7 @@ class SystemUserPushApi(generics.RetrieveAPIView):
""" """
queryset = SystemUser.objects.all() queryset = SystemUser.objects.all()
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = CeleryTaskSerializer
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
system_user = self.get_object() system_user = self.get_object()
...@@ -108,6 +107,7 @@ class SystemUserTestConnectiveApi(generics.RetrieveAPIView): ...@@ -108,6 +107,7 @@ class SystemUserTestConnectiveApi(generics.RetrieveAPIView):
""" """
queryset = SystemUser.objects.all() queryset = SystemUser.objects.all()
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = CeleryTaskSerializer
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
system_user = self.get_object() system_user = self.get_object()
...@@ -118,7 +118,6 @@ class SystemUserTestConnectiveApi(generics.RetrieveAPIView): ...@@ -118,7 +118,6 @@ class SystemUserTestConnectiveApi(generics.RetrieveAPIView):
class SystemUserAssetsListView(generics.ListAPIView): class SystemUserAssetsListView(generics.ListAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.AssetSimpleSerializer serializer_class = serializers.AssetSimpleSerializer
pagination_class = LimitOffsetPagination
filter_fields = ("hostname", "ip") filter_fields = ("hostname", "ip")
http_method_names = ['get'] http_method_names = ['get']
search_fields = filter_fields search_fields = filter_fields
......
...@@ -4,7 +4,7 @@ from django import forms ...@@ -4,7 +4,7 @@ from django import forms
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from common.utils import get_logger from common.utils import get_logger
from orgs.mixins import OrgModelForm from orgs.mixins.forms import OrgModelForm
from ..models import Asset, Node from ..models import Asset, Node
...@@ -29,9 +29,14 @@ class ProtocolForm(forms.Form): ...@@ -29,9 +29,14 @@ class ProtocolForm(forms.Form):
class AssetCreateForm(OrgModelForm): class AssetCreateForm(OrgModelForm):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.data:
return
nodes_field = self.fields['nodes'] nodes_field = self.fields['nodes']
nodes_field.choices = ((n.id, n.full_value) for n in if self.instance:
Node.get_queryset()) nodes_field.choices = ((n.id, n.full_value) for n in
self.instance.nodes.all())
else:
nodes_field.choices = []
class Meta: class Meta:
model = Asset model = Asset
...@@ -42,7 +47,7 @@ class AssetCreateForm(OrgModelForm): ...@@ -42,7 +47,7 @@ class AssetCreateForm(OrgModelForm):
] ]
widgets = { widgets = {
'nodes': forms.SelectMultiple(attrs={ 'nodes': forms.SelectMultiple(attrs={
'class': 'select2', 'data-placeholder': _('Nodes') 'class': 'nodes-select2', 'data-placeholder': _('Nodes')
}), }),
'admin_user': forms.Select(attrs={ 'admin_user': forms.Select(attrs={
'class': 'select2', 'data-placeholder': _('Admin user') 'class': 'select2', 'data-placeholder': _('Admin user')
...@@ -68,6 +73,17 @@ class AssetCreateForm(OrgModelForm): ...@@ -68,6 +73,17 @@ class AssetCreateForm(OrgModelForm):
class AssetUpdateForm(OrgModelForm): class AssetUpdateForm(OrgModelForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.data:
return
nodes_field = self.fields['nodes']
if self.instance:
nodes_field.choices = ((n.id, n.full_value) for n in
self.instance.nodes.all())
else:
nodes_field.choices = []
class Meta: class Meta:
model = Asset model = Asset
fields = [ fields = [
...@@ -77,7 +93,7 @@ class AssetUpdateForm(OrgModelForm): ...@@ -77,7 +93,7 @@ class AssetUpdateForm(OrgModelForm):
] ]
widgets = { widgets = {
'nodes': forms.SelectMultiple(attrs={ 'nodes': forms.SelectMultiple(attrs={
'class': 'select2', 'data-placeholder': _('Node') 'class': 'nodes-select2', 'data-placeholder': _('Node')
}), }),
'admin_user': forms.Select(attrs={ 'admin_user': forms.Select(attrs={
'class': 'select2', 'data-placeholder': _('Admin user') 'class': 'select2', 'data-placeholder': _('Admin user')
......
...@@ -5,7 +5,7 @@ from django.core.exceptions import ValidationError ...@@ -5,7 +5,7 @@ from django.core.exceptions import ValidationError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
import re import re
from orgs.mixins import OrgModelForm from orgs.mixins.forms import OrgModelForm
from ..models import CommandFilter, CommandFilterRule from ..models import CommandFilter, CommandFilterRule
__all__ = ['CommandFilterForm', 'CommandFilterRuleForm'] __all__ = ['CommandFilterForm', 'CommandFilterRuleForm']
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from django import forms from django import forms
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from orgs.mixins import OrgModelForm from orgs.mixins.forms import OrgModelForm
from ..models import Domain, Asset, Gateway from ..models import Domain, Asset, Gateway
from .user import PasswordAndKeyAuthForm from .user import PasswordAndKeyAuthForm
......
...@@ -4,7 +4,7 @@ from django import forms ...@@ -4,7 +4,7 @@ from django import forms
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from common.utils import validate_ssh_private_key, ssh_pubkey_gen, get_logger from common.utils import validate_ssh_private_key, ssh_pubkey_gen, get_logger
from orgs.mixins import OrgModelForm from orgs.mixins.forms import OrgModelForm
from ..models import AdminUser, SystemUser from ..models import AdminUser, SystemUser
logger = get_logger(__file__) logger = get_logger(__file__)
......
# Generated by Django 2.1.7 on 2019-07-24 12:02
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0036_auto_20190716_1535'),
]
operations = [
migrations.AlterField(
model_name='adminuser',
name='_become_pass',
field=models.CharField(blank=True, default='', max_length=128),
),
]
...@@ -13,7 +13,7 @@ from django.db import models ...@@ -13,7 +13,7 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from .utils import Connectivity from .utils import Connectivity
from orgs.mixins import OrgModelMixin, OrgManager from orgs.mixins.models import OrgModelMixin, OrgManager
__all__ = ['Asset', 'ProtocolsMixin'] __all__ = ['Asset', 'ProtocolsMixin']
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -345,7 +345,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin): ...@@ -345,7 +345,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
else: else:
_nodes = [Node.default_node()] _nodes = [Node.default_node()]
asset.nodes.set(_nodes) asset.nodes.set(_nodes)
asset.system_users = [choice(SystemUser.objects.all()) for i in range(3)]
logger.debug('Generate fake asset : %s' % asset.ip) logger.debug('Generate fake asset : %s' % asset.ip)
except IntegrityError: except IntegrityError:
print('Error continue') print('Error continue')
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgManager from orgs.mixins.models import OrgManager
from .base import AssetUser from .base import AssetUser
__all__ = ['AuthBook'] __all__ = ['AuthBook']
......
...@@ -15,7 +15,7 @@ from common.utils import ( ...@@ -15,7 +15,7 @@ from common.utils import (
) )
from common.validators import alphanumeric from common.validators import alphanumeric
from common import fields from common import fields
from orgs.mixins import OrgModelMixin from orgs.mixins.models import OrgModelMixin
from .utils import private_key_validator, Connectivity from .utils import private_key_validator, Connectivity
signer = get_signer() signer = get_signer()
......
...@@ -7,7 +7,7 @@ from django.db import models ...@@ -7,7 +7,7 @@ from django.db import models
from django.core.validators import MinValueValidator, MaxValueValidator from django.core.validators import MinValueValidator, MaxValueValidator
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgModelMixin from orgs.mixins.models import OrgModelMixin
__all__ = [ __all__ = [
......
...@@ -9,7 +9,7 @@ import paramiko ...@@ -9,7 +9,7 @@ import paramiko
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgModelMixin from orgs.mixins.models import OrgModelMixin
from .base import AssetUser from .base import AssetUser
__all__ = ['Domain', 'Gateway'] __all__ = ['Domain', 'Gateway']
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import uuid import uuid
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgModelMixin from orgs.mixins.models import OrgModelMixin
class Label(OrgModelMixin): class Label(OrgModelMixin):
......
This diff is collapsed.
...@@ -31,7 +31,7 @@ class AdminUser(AssetUser): ...@@ -31,7 +31,7 @@ class AdminUser(AssetUser):
become = models.BooleanField(default=True) become = models.BooleanField(default=True)
become_method = models.CharField(choices=BECOME_METHOD_CHOICES, default='sudo', max_length=4) become_method = models.CharField(choices=BECOME_METHOD_CHOICES, default='sudo', max_length=4)
become_user = models.CharField(default='root', max_length=64) become_user = models.CharField(default='root', max_length=64)
_become_pass = models.CharField(default='', max_length=128) _become_pass = models.CharField(default='', blank=True, max_length=128)
CONNECTIVITY_CACHE_KEY = '_ADMIN_USER_CONNECTIVE_{}' CONNECTIVITY_CACHE_KEY = '_ADMIN_USER_CONNECTIVE_{}'
_prefer = "admin_user" _prefer = "admin_user"
......
...@@ -6,7 +6,7 @@ from rest_framework import serializers ...@@ -6,7 +6,7 @@ from rest_framework import serializers
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from ..models import Node, AdminUser from ..models import Node, AdminUser
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import AuthSerializer, AuthSerializerMixin from .base import AuthSerializer, AuthSerializerMixin
......
...@@ -4,7 +4,7 @@ from rest_framework import serializers ...@@ -4,7 +4,7 @@ from rest_framework import serializers
from django.db.models import Prefetch from django.db.models import Prefetch
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from ..models import Asset, Node, Label from ..models import Asset, Node, Label
from .base import ConnectivitySerializer from .base import ConnectivitySerializer
......
...@@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _ ...@@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _
from rest_framework import serializers from rest_framework import serializers
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import AuthBook, Asset from ..models import AuthBook, Asset
from ..backends import AssetUserManager from ..backends import AssetUserManager
from .base import ConnectivitySerializer, AuthSerializerMixin from .base import ConnectivitySerializer, AuthSerializerMixin
......
...@@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _
from common.fields import ChoiceDisplayField from common.fields import ChoiceDisplayField
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from ..models import CommandFilter, CommandFilterRule, SystemUser from ..models import CommandFilter, CommandFilterRule, SystemUser
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
class CommandFilterSerializer(BulkOrgResourceModelSerializer): class CommandFilterSerializer(BulkOrgResourceModelSerializer):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from rest_framework import serializers from rest_framework import serializers
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import Domain, Gateway from ..models import Domain, Gateway
from .base import AuthSerializerMixin from .base import AuthSerializerMixin
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from rest_framework import serializers from rest_framework import serializers
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import Label from ..models import Label
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from rest_framework import serializers from rest_framework import serializers
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import Asset, Node from ..models import Asset, Node
...@@ -13,22 +13,21 @@ __all__ = [ ...@@ -13,22 +13,21 @@ __all__ = [
class NodeSerializer(BulkOrgResourceModelSerializer): class NodeSerializer(BulkOrgResourceModelSerializer):
assets_amount = serializers.IntegerField(read_only=True)
name = serializers.ReadOnlyField(source='value') name = serializers.ReadOnlyField(source='value')
value = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("value"))
class Meta: class Meta:
model = Node model = Node
only_fields = ['id', 'key', 'value', 'org_id'] only_fields = ['id', 'key', 'value', 'org_id']
fields = only_fields + ['name', 'assets_amount'] fields = only_fields + ['name', 'full_value']
read_only_fields = [ read_only_fields = ['key', 'org_id']
'key', 'name', 'assets_amount', 'org_id',
]
def validate_value(self, data): def validate_value(self, data):
instance = self.instance if self.instance else Node.root() if not self.instance and not data:
children = instance.parent.get_children() return data
children_values = [node.value for node in children if node != instance] instance = self.instance
if data in children_values: siblings = instance.get_siblings()
if siblings.filter(value=data):
raise serializers.ValidationError( raise serializers.ValidationError(
_('The same level node name cannot be the same') _('The same level node name cannot be the same')
) )
......
...@@ -4,7 +4,7 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -4,7 +4,7 @@ from django.utils.translation import ugettext_lazy as _
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from common.utils import ssh_pubkey_gen from common.utils import ssh_pubkey_gen
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import SystemUser from ..models import SystemUser
from .base import AuthSerializer, AuthSerializerMixin from .base import AuthSerializer, AuthSerializerMixin
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from collections import defaultdict from collections import defaultdict
from django.db.models.signals import post_save, m2m_changed, post_delete from django.db.models.signals import post_save, m2m_changed, pre_delete
from django.dispatch import receiver from django.dispatch import receiver
from common.utils import get_logger from common.utils import get_logger
...@@ -38,15 +38,13 @@ def on_asset_created_or_update(sender, instance=None, created=False, **kwargs): ...@@ -38,15 +38,13 @@ def on_asset_created_or_update(sender, instance=None, created=False, **kwargs):
test_asset_conn_on_created(instance) test_asset_conn_on_created(instance)
# 过期节点资产数量 # 过期节点资产数量
nodes = instance.nodes.all() Node.refresh_nodes()
Node.expire_nodes_assets_amount(nodes)
@receiver(post_delete, sender=Asset, dispatch_uid="my_unique_identifier") @receiver(pre_delete, sender=Asset, dispatch_uid="my_unique_identifier")
def on_asset_delete(sender, instance=None, **kwargs): def on_asset_delete(sender, instance=None, **kwargs):
# 过期节点资产数量 # 过期节点资产数量
nodes = instance.nodes.all() Node.refresh_nodes()
Node.expire_nodes_assets_amount(nodes)
@receiver(post_save, sender=SystemUser, dispatch_uid="my_unique_identifier") @receiver(post_save, sender=SystemUser, dispatch_uid="my_unique_identifier")
...@@ -80,19 +78,18 @@ def on_asset_node_changed(sender, instance=None, **kwargs): ...@@ -80,19 +78,18 @@ def on_asset_node_changed(sender, instance=None, **kwargs):
logger.debug("Asset nodes change signal received") logger.debug("Asset nodes change signal received")
Asset.expire_all_nodes_keys_cache() Asset.expire_all_nodes_keys_cache()
if isinstance(instance, Asset): if isinstance(instance, Asset):
if kwargs['action'] == 'pre_remove': # nodes = []
nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) # if kwargs['action'] == 'pre_remove':
Node.expire_nodes_assets_amount(nodes) # nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
if kwargs['action'] == 'post_add': if kwargs['action'] == 'post_add':
nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
Node.expire_nodes_assets_amount(nodes)
system_users_assets = defaultdict(set) system_users_assets = defaultdict(set)
system_users = SystemUser.objects.filter(nodes__in=nodes) system_users = SystemUser.objects.filter(nodes__in=nodes)
# 清理节点缓存
for system_user in system_users: for system_user in system_users:
system_users_assets[system_user].update({instance}) system_users_assets[system_user].update({instance})
for system_user, assets in system_users_assets.items(): for system_user, assets in system_users_assets.items():
system_user.assets.add(*tuple(assets)) system_user.assets.add(*tuple(assets))
Node.refresh_nodes()
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
...@@ -100,7 +97,6 @@ def on_node_assets_changed(sender, instance=None, **kwargs): ...@@ -100,7 +97,6 @@ def on_node_assets_changed(sender, instance=None, **kwargs):
if isinstance(instance, Node): if isinstance(instance, Node):
logger.debug("Node assets change signal {} received".format(instance)) logger.debug("Node assets change signal {} received".format(instance))
# 当节点和资产关系发生改变时,过期资产数量缓存 # 当节点和资产关系发生改变时,过期资产数量缓存
instance.expire_assets_amount()
assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
if kwargs['action'] == 'post_add': if kwargs['action'] == 'post_add':
# 重新关联系统用户和资产的关系 # 重新关联系统用户和资产的关系
...@@ -112,7 +108,7 @@ def on_node_assets_changed(sender, instance=None, **kwargs): ...@@ -112,7 +108,7 @@ def on_node_assets_changed(sender, instance=None, **kwargs):
@receiver(post_save, sender=Node) @receiver(post_save, sender=Node)
def on_node_update_or_created(sender, instance=None, created=False, **kwargs): def on_node_update_or_created(sender, instance=None, created=False, **kwargs):
if instance and not created: if instance and not created:
instance.expire_full_value() Node.refresh_nodes()
@receiver(post_save, sender=AuthBook) @receiver(post_save, sender=AuthBook)
......
...@@ -24,7 +24,7 @@ FORKS = 10 ...@@ -24,7 +24,7 @@ FORKS = 10
TIMEOUT = 60 TIMEOUT = 60
logger = get_logger(__file__) logger = get_logger(__file__)
CACHE_MAX_TIME = 60*60*2 CACHE_MAX_TIME = 60*60*2
disk_pattern = re.compile(r'^hd|sd|xvd|vd') disk_pattern = re.compile(r'^hd|sd|xvd|vd|nv')
PERIOD_TASK = os.environ.get("PERIOD_TASK", "on") PERIOD_TASK = os.environ.get("PERIOD_TASK", "on")
...@@ -62,7 +62,7 @@ def clean_hosts_by_protocol(system_user, assets): ...@@ -62,7 +62,7 @@ def clean_hosts_by_protocol(system_user, assets):
return hosts return hosts
@shared_task @shared_task(queue="ansible")
def set_assets_hardware_info(assets, result, **kwargs): def set_assets_hardware_info(assets, result, **kwargs):
""" """
Using ops task run result, to update asset info Using ops task run result, to update asset info
...@@ -106,7 +106,7 @@ def set_assets_hardware_info(assets, result, **kwargs): ...@@ -106,7 +106,7 @@ def set_assets_hardware_info(assets, result, **kwargs):
for dev, dev_info in info.get('ansible_devices', {}).items(): for dev, dev_info in info.get('ansible_devices', {}).items():
if disk_pattern.match(dev) and dev_info['removable'] == '0': if disk_pattern.match(dev) and dev_info['removable'] == '0':
disk_info[dev] = dev_info['size'] disk_info[dev] = dev_info['size']
___disk_total = '%s %s' % sum_capacity(disk_info.values()) ___disk_total = '%.1f %s' % sum_capacity(disk_info.values())
___disk_info = json.dumps(disk_info) ___disk_info = json.dumps(disk_info)
# ___platform = info.get('ansible_system', 'Unknown') # ___platform = info.get('ansible_system', 'Unknown')
...@@ -148,7 +148,7 @@ def update_assets_hardware_info_util(assets, task_name=None): ...@@ -148,7 +148,7 @@ def update_assets_hardware_info_util(assets, task_name=None):
return result return result
@shared_task @shared_task(queue="ansible")
def update_asset_hardware_info_manual(asset): def update_asset_hardware_info_manual(asset):
task_name = _("Update asset hardware info: {}").format(asset.hostname) task_name = _("Update asset hardware info: {}").format(asset.hostname)
update_assets_hardware_info_util( update_assets_hardware_info_util(
...@@ -156,7 +156,7 @@ def update_asset_hardware_info_manual(asset): ...@@ -156,7 +156,7 @@ def update_asset_hardware_info_manual(asset):
) )
@shared_task @shared_task(queue="ansible")
def update_assets_hardware_info_period(): def update_assets_hardware_info_period():
""" """
Update asset hardware period task Update asset hardware period task
...@@ -170,7 +170,7 @@ def update_assets_hardware_info_period(): ...@@ -170,7 +170,7 @@ def update_assets_hardware_info_period():
## ADMIN USER CONNECTIVE ## ## ADMIN USER CONNECTIVE ##
@shared_task @shared_task(queue="ansible")
def test_asset_connectivity_util(assets, task_name=None): def test_asset_connectivity_util(assets, task_name=None):
from ops.utils import update_or_create_ansible_task from ops.utils import update_or_create_ansible_task
...@@ -227,7 +227,7 @@ def test_asset_connectivity_util(assets, task_name=None): ...@@ -227,7 +227,7 @@ def test_asset_connectivity_util(assets, task_name=None):
return results_summary return results_summary
@shared_task @shared_task(queue="ansible")
def test_asset_connectivity_manual(asset): def test_asset_connectivity_manual(asset):
task_name = _("Test assets connectivity: {}").format(asset) task_name = _("Test assets connectivity: {}").format(asset)
summary = test_asset_connectivity_util([asset], task_name=task_name) summary = test_asset_connectivity_util([asset], task_name=task_name)
...@@ -238,7 +238,7 @@ def test_asset_connectivity_manual(asset): ...@@ -238,7 +238,7 @@ def test_asset_connectivity_manual(asset):
return True, "" return True, ""
@shared_task @shared_task(queue="ansible")
def test_admin_user_connectivity_util(admin_user, task_name): def test_admin_user_connectivity_util(admin_user, task_name):
""" """
Test asset admin user can connect or not. Using ansible api do that Test asset admin user can connect or not. Using ansible api do that
...@@ -254,7 +254,7 @@ def test_admin_user_connectivity_util(admin_user, task_name): ...@@ -254,7 +254,7 @@ def test_admin_user_connectivity_util(admin_user, task_name):
return summary return summary
@shared_task @shared_task(queue="ansible")
@register_as_period_task(interval=3600) @register_as_period_task(interval=3600)
def test_admin_user_connectivity_period(): def test_admin_user_connectivity_period():
""" """
...@@ -276,7 +276,7 @@ def test_admin_user_connectivity_period(): ...@@ -276,7 +276,7 @@ def test_admin_user_connectivity_period():
cache.set(key, 1, 60*40) cache.set(key, 1, 60*40)
@shared_task @shared_task(queue="ansible")
def test_admin_user_connectivity_manual(admin_user): def test_admin_user_connectivity_manual(admin_user):
task_name = _("Test admin user connectivity: {}").format(admin_user.name) task_name = _("Test admin user connectivity: {}").format(admin_user.name)
test_admin_user_connectivity_util(admin_user, task_name) test_admin_user_connectivity_util(admin_user, task_name)
...@@ -286,7 +286,7 @@ def test_admin_user_connectivity_manual(admin_user): ...@@ -286,7 +286,7 @@ def test_admin_user_connectivity_manual(admin_user):
## System user connective ## ## System user connective ##
@shared_task @shared_task(queue="ansible")
def test_system_user_connectivity_util(system_user, assets, task_name): def test_system_user_connectivity_util(system_user, assets, task_name):
""" """
Test system cant connect his assets or not. Test system cant connect his assets or not.
...@@ -344,14 +344,14 @@ def test_system_user_connectivity_util(system_user, assets, task_name): ...@@ -344,14 +344,14 @@ def test_system_user_connectivity_util(system_user, assets, task_name):
return results_summary return results_summary
@shared_task @shared_task(queue="ansible")
def test_system_user_connectivity_manual(system_user): def test_system_user_connectivity_manual(system_user):
task_name = _("Test system user connectivity: {}").format(system_user) task_name = _("Test system user connectivity: {}").format(system_user)
assets = system_user.get_all_assets() assets = system_user.get_all_assets()
return test_system_user_connectivity_util(system_user, assets, task_name) return test_system_user_connectivity_util(system_user, assets, task_name)
@shared_task @shared_task(queue="ansible")
def test_system_user_connectivity_a_asset(system_user, asset): def test_system_user_connectivity_a_asset(system_user, asset):
task_name = _("Test system user connectivity: {} => {}").format( task_name = _("Test system user connectivity: {} => {}").format(
system_user, asset system_user, asset
...@@ -359,7 +359,7 @@ def test_system_user_connectivity_a_asset(system_user, asset): ...@@ -359,7 +359,7 @@ def test_system_user_connectivity_a_asset(system_user, asset):
return test_system_user_connectivity_util(system_user, [asset], task_name) return test_system_user_connectivity_util(system_user, [asset], task_name)
@shared_task @shared_task(queue="ansible")
def test_system_user_connectivity_period(): def test_system_user_connectivity_period():
if PERIOD_TASK != "on": if PERIOD_TASK != "on":
logger.debug("Period task disabled, test system user connectivity pass") logger.debug("Period task disabled, test system user connectivity pass")
...@@ -483,7 +483,7 @@ def get_push_system_user_tasks(host, system_user): ...@@ -483,7 +483,7 @@ def get_push_system_user_tasks(host, system_user):
return tasks return tasks
@shared_task @shared_task(queue="ansible")
def push_system_user_util(system_user, assets, task_name): def push_system_user_util(system_user, assets, task_name):
from ops.utils import update_or_create_ansible_task from ops.utils import update_or_create_ansible_task
if not system_user.is_need_push(): if not system_user.is_need_push():
...@@ -519,14 +519,14 @@ def push_system_user_util(system_user, assets, task_name): ...@@ -519,14 +519,14 @@ def push_system_user_util(system_user, assets, task_name):
task.run() task.run()
@shared_task @shared_task(queue="ansible")
def push_system_user_to_assets_manual(system_user): def push_system_user_to_assets_manual(system_user):
assets = system_user.get_all_assets() assets = system_user.get_all_assets()
task_name = _("Push system users to assets: {}").format(system_user.name) task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name=task_name) return push_system_user_util(system_user, assets, task_name=task_name)
@shared_task @shared_task(queue="ansible")
def push_system_user_a_asset_manual(system_user, asset): def push_system_user_a_asset_manual(system_user, asset):
task_name = _("Push system users to asset: {} => {}").format( task_name = _("Push system users to asset: {} => {}").format(
system_user.name, asset system_user.name, asset
...@@ -534,7 +534,7 @@ def push_system_user_a_asset_manual(system_user, asset): ...@@ -534,7 +534,7 @@ def push_system_user_a_asset_manual(system_user, asset):
return push_system_user_util(system_user, [asset], task_name=task_name) return push_system_user_util(system_user, [asset], task_name=task_name)
@shared_task @shared_task(queue="ansible")
def push_system_user_to_assets(system_user, assets): def push_system_user_to_assets(system_user, assets):
task_name = _("Push system users to assets: {}").format(system_user.name) task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name) return push_system_user_util(system_user, assets, task_name)
...@@ -569,7 +569,7 @@ def get_test_asset_user_connectivity_tasks(asset): ...@@ -569,7 +569,7 @@ def get_test_asset_user_connectivity_tasks(asset):
return tasks return tasks
@shared_task @shared_task(queue="ansible")
def test_asset_user_connectivity_util(asset_user, task_name, run_as_admin=False): def test_asset_user_connectivity_util(asset_user, task_name, run_as_admin=False):
""" """
:param asset_user: <AuthBook>对象 :param asset_user: <AuthBook>对象
...@@ -602,7 +602,7 @@ def test_asset_user_connectivity_util(asset_user, task_name, run_as_admin=False) ...@@ -602,7 +602,7 @@ def test_asset_user_connectivity_util(asset_user, task_name, run_as_admin=False)
asset_user.set_connectivity(summary) asset_user.set_connectivity(summary)
@shared_task @shared_task(queue="ansible")
def test_asset_users_connectivity_manual(asset_users, run_as_admin=False): def test_asset_users_connectivity_manual(asset_users, run_as_admin=False):
""" """
:param asset_users: <AuthBook>对象 :param asset_users: <AuthBook>对象
......
...@@ -236,7 +236,8 @@ function onBodyMouseDown(event){ ...@@ -236,7 +236,8 @@ function onBodyMouseDown(event){
} }
function onRename(event, treeId, treeNode, isCancel){ function onRename(event, treeId, treeNode, isCancel){
var url = "{% url 'api-assets:node-detail' pk=DEFAULT_PK %}".replace("{{ DEFAULT_PK }}", current_node_id); var url = "{% url 'api-assets:node-detail' pk=DEFAULT_PK %}"
.replace("{{ DEFAULT_PK }}", current_node_id);
var data = {"value": treeNode.name}; var data = {"value": treeNode.name};
if (isCancel){ if (isCancel){
return return
...@@ -247,10 +248,13 @@ function onRename(event, treeId, treeNode, isCancel){ ...@@ -247,10 +248,13 @@ function onRename(event, treeId, treeNode, isCancel){
method: "PATCH", method: "PATCH",
success_message: "{% trans 'Rename success' %}", success_message: "{% trans 'Rename success' %}",
success: function () { success: function () {
treeNode.name = treeNode.name + ' (' + treeNode.meta.node.assets_amount + ')'; var assets_amount = treeNode.meta.node.assets_amount;
if (!assets_amount) {
assets_amount = 0;
}
treeNode.name = treeNode.name + ' (' + assets_amount + ')';
zTree.updateNode(treeNode); zTree.updateNode(treeNode);
console.log("Success: " + treeNode.name) },
}
}) })
} }
......
...@@ -88,9 +88,9 @@ ...@@ -88,9 +88,9 @@
<form> <form>
<tr> <tr>
<td colspan="2" class="no-borders"> <td colspan="2" class="no-borders">
<select data-placeholder="{% trans 'Select nodes' %}" id="nodes_selected" class="select2" style="width: 100%" multiple="" tabindex="4"> <select data-placeholder="{% trans 'Select nodes' %}" id="nodes_selected" class="nodes-select2" style="width: 100%" multiple="" tabindex="4">
{% for node in nodes %} {% for node in nodes %}
<option value="{{ node.id }}" id="opt_{{ node.id }}" >{{ node }}</option> <option value="{{ node.id }}" id="opt_{{ node.id }}" >{{ node.full_value }}</option>
{% endfor %} {% endfor %}
</select> </select>
</td> </td>
...@@ -140,7 +140,8 @@ function replaceNodeAssetsAdminUser(nodes) { ...@@ -140,7 +140,8 @@ function replaceNodeAssetsAdminUser(nodes) {
jumpserver.nodes_selected = {}; jumpserver.nodes_selected = {};
$(document).ready(function () { $(document).ready(function () {
$('.select2').select2() var url = "{% url 'api-assets:node-list' %}";
nodesSelect2Init(".nodes-select2", url)
.on('select2:select', function(evt) { .on('select2:select', function(evt) {
var data = evt.params.data; var data = evt.params.data;
jumpserver.nodes_selected[data.id] = data.text; jumpserver.nodes_selected[data.id] = data.text;
......
...@@ -110,6 +110,8 @@ $(document).ready(function () { ...@@ -110,6 +110,8 @@ $(document).ready(function () {
$('.select2').select2({ $('.select2').select2({
allowClear: true allowClear: true
}); });
var url = "{% url 'api-assets:node-list' %}";
nodesSelect2Init(".nodes-select2", url);
$(".labels").select2({ $(".labels").select2({
allowClear: true, allowClear: true,
templateSelection: format templateSelection: format
......
...@@ -195,10 +195,7 @@ ...@@ -195,10 +195,7 @@
<form> <form>
<tr> <tr>
<td colspan="2" class="no-borders"> <td colspan="2" class="no-borders">
<select data-placeholder="{% trans 'Nodes' %}" id="groups_selected" class="select2 groups" style="width: 100%" multiple="" tabindex="4"> <select data-placeholder="{% trans 'Nodes' %}" id="groups_selected" class="nodes-select2 groups" style="width: 100%" multiple="" tabindex="4">
{% for node in nodes_remain %}
<option value="{{ node.id }}" id="opt_{{ node.id }}" >{{ node }}</option>
{% endfor %}
</select> </select>
</td> </td>
</tr> </tr>
...@@ -211,7 +208,7 @@ ...@@ -211,7 +208,7 @@
{% for node in asset.nodes.all %} {% for node in asset.nodes.all %}
<tr> <tr>
<td ><b class="bdg_node" data-gid={{ node.id }}>{{ node }}</b></td> <td ><b class="bdg_node" data-gid={{ node.id }}>{{ node.full_value }}</b></td>
<td> <td>
<button class="btn btn-danger pull-right btn-xs btn-leave-node" type="button"><i class="fa fa-minus"></i></button> <button class="btn btn-danger pull-right btn-xs btn-leave-node" type="button"><i class="fa fa-minus"></i></button>
</td> </td>
...@@ -291,7 +288,9 @@ function refreshAssetHardware() { ...@@ -291,7 +288,9 @@ function refreshAssetHardware() {
$(document).ready(function () { $(document).ready(function () {
$('.select2.groups').select2().on('select2:select', function(evt) { var url = "{% url 'api-assets:node-list' %}";
nodesSelect2Init(".nodes-select2", url)
.on('select2:select', function(evt) {
var data = evt.params.data; var data = evt.params.data;
jumpserver.nodes_selected[data.id] = data.text; jumpserver.nodes_selected[data.id] = data.text;
}).on('select2:unselect', function(evt) { }).on('select2:unselect', function(evt) {
......
...@@ -442,9 +442,10 @@ $(document).ready(function(){ ...@@ -442,9 +442,10 @@ $(document).ready(function(){
var success = function () { var success = function () {
asset_table.ajax.reload() asset_table.ajax.reload()
}; };
var url = "{% url 'api-assets:node-remove-assets' pk=DEFAULT_PK %}".replace("{{ DEFAULT_PK }}", current_node_id);
requestApi({ requestApi({
'url': '/api/assets/v1/nodes/' + current_node_id + '/assets/remove/', 'url': url,
'method': 'PUT', 'method': 'PUT',
'body': JSON.stringify(data), 'body': JSON.stringify(data),
'success': success 'success': success
......
...@@ -88,10 +88,7 @@ ...@@ -88,10 +88,7 @@
<form> <form>
<tr> <tr>
<td colspan="2" class="no-borders"> <td colspan="2" class="no-borders">
<select data-placeholder="{% trans 'Add to node' %}" id="node_selected" class="select2" style="width: 100%" multiple="" tabindex="4"> <select data-placeholder="{% trans 'Add to node' %}" id="node_selected" class="nodes-select2" style="width: 100%" multiple="" tabindex="4">
{% for node in nodes_remain %}
<option value="{{ node.id }}" id="opt_{{ node.id }}" >{{ node }}</option>
{% endfor %}
</select> </select>
</td> </td>
</tr> </tr>
...@@ -104,7 +101,7 @@ ...@@ -104,7 +101,7 @@
{% for node in system_user.nodes.all|sort %} {% for node in system_user.nodes.all|sort %}
<tr> <tr>
<td ><b class="bdg_node" data-gid={{ node.id }}>{{ node }}</b></td> <td ><b class="bdg_node" data-gid={{ node.id }}>{{ node.full_value }}</b></td>
<td> <td>
<button class="btn btn-danger pull-right btn-xs btn-remove-from-node" type="button"><i class="fa fa-minus"></i></button> <button class="btn btn-danger pull-right btn-xs btn-remove-from-node" type="button"><i class="fa fa-minus"></i></button>
</td> </td>
...@@ -156,6 +153,8 @@ jumpserver.nodes_selected = {}; ...@@ -156,6 +153,8 @@ jumpserver.nodes_selected = {};
$(document).ready(function () { $(document).ready(function () {
$('.select2').select2() $('.select2').select2()
var url = "{% url 'api-assets:node-list' %}";
nodesSelect2Init(".nodes-select2", url)
.on('select2:select', function(evt) { .on('select2:select', function(evt) {
var data = evt.params.data; var data = evt.params.data;
jumpserver.nodes_selected[data.id] = data.text; jumpserver.nodes_selected[data.id] = data.text;
......
...@@ -21,9 +21,10 @@ ...@@ -21,9 +21,10 @@
{% block custom_foot_js %} {% block custom_foot_js %}
<script> <script>
var treeUrl = "{% url 'api-perms:my-nodes-as-tree' %}?&cache_policy=1"; var treeUrl = "{% url 'api-perms:my-nodes-children-as-tree' %}?&cache_policy=1";
var assetTableUrl = "{% url 'api-perms:my-assets' %}?cache_policy=1"; var assetTableUrl = "{% url 'api-perms:my-assets' %}?cache_policy=1";
var selectUrl = '{% url "api-perms:my-node-assets" node_id=DEFAULT_PK %}?cache_policy=1&all=1'; var selectUrl = '{% url "api-perms:my-node-assets" node_id=DEFAULT_PK %}?cache_policy=1&all=1';
var systemUsersUrl = "{% url 'api-perms:my-asset-system-users' asset_id=DEFAULT_PK %}";
var showAssetHref = false; // Need input default true var showAssetHref = false; // Need input default true
var actions = { var actions = {
targets: 4, createdCell: function (td, cellData) { targets: 4, createdCell: function (td, cellData) {
......
# coding:utf-8 # coding:utf-8
from django.urls import path from django.urls import path, re_path
from rest_framework_nested import routers from rest_framework_nested import routers
# from rest_framework.routers import DefaultRouter # from rest_framework.routers import DefaultRouter
from rest_framework_bulk.routes import BulkRouter from rest_framework_bulk.routes import BulkRouter
from common import api as capi
from .. import api from .. import api
app_name = 'assets' app_name = 'assets'
router = BulkRouter() router = BulkRouter()
router.register(r'assets', api.AssetViewSet, 'asset') router.register(r'assets', api.AssetViewSet, 'asset')
router.register(r'admin-user', api.AdminUserViewSet, 'admin-user') router.register(r'admin-users', api.AdminUserViewSet, 'admin-user')
router.register(r'system-user', api.SystemUserViewSet, 'system-user') router.register(r'system-users', api.SystemUserViewSet, 'system-user')
router.register(r'labels', api.LabelViewSet, 'label') router.register(r'labels', api.LabelViewSet, 'label')
router.register(r'nodes', api.NodeViewSet, 'node') router.register(r'nodes', api.NodeViewSet, 'node')
router.register(r'domain', api.DomainViewSet, 'domain') router.register(r'domains', api.DomainViewSet, 'domain')
router.register(r'gateway', api.GatewayViewSet, 'gateway') router.register(r'gateways', api.GatewayViewSet, 'gateway')
router.register(r'cmd-filter', api.CommandFilterViewSet, 'cmd-filter') router.register(r'cmd-filters', api.CommandFilterViewSet, 'cmd-filter')
router.register(r'asset-user', api.AssetUserViewSet, 'asset-user') router.register(r'asset-users', api.AssetUserViewSet, 'asset-user')
router.register(r'asset-user-info', api.AssetUserExportViewSet, 'asset-user-info') router.register(r'asset-users-info', api.AssetUserExportViewSet, 'asset-user-info')
cmd_filter_router = routers.NestedDefaultRouter(router, r'cmd-filter', lookup='filter') cmd_filter_router = routers.NestedDefaultRouter(router, r'cmd-filters', lookup='filter')
cmd_filter_router.register(r'rules', api.CommandFilterRuleViewSet, 'cmd-filter-rule') cmd_filter_router.register(r'rules', api.CommandFilterRuleViewSet, 'cmd-filter-rule')
urlpatterns = [ urlpatterns = [
path('assets-bulk/', api.AssetListUpdateApi.as_view(), name='asset-bulk-update'),
path('asset/update/select/',
api.AssetBulkUpdateSelectAPI.as_view(), name='asset-bulk-update-select'),
path('assets/<uuid:pk>/refresh/', path('assets/<uuid:pk>/refresh/',
api.AssetRefreshHardwareApi.as_view(), name='asset-refresh'), api.AssetRefreshHardwareApi.as_view(), name='asset-refresh'),
path('assets/<uuid:pk>/alive/', path('assets/<uuid:pk>/alive/',
...@@ -35,36 +34,36 @@ urlpatterns = [ ...@@ -35,36 +34,36 @@ urlpatterns = [
path('assets/<uuid:pk>/gateway/', path('assets/<uuid:pk>/gateway/',
api.AssetGatewayApi.as_view(), name='asset-gateway'), api.AssetGatewayApi.as_view(), name='asset-gateway'),
path('asset-user/auth-info/', path('asset-users/auth-info/',
api.AssetUserAuthInfoApi.as_view(), name='asset-user-auth-info'), api.AssetUserAuthInfoApi.as_view(), name='asset-user-auth-info'),
path('asset-user/test-connective/', path('asset-users/test-connective/',
api.AssetUserTestConnectiveApi.as_view(), name='asset-user-connective'), api.AssetUserTestConnectiveApi.as_view(), name='asset-user-connective'),
path('admin-user/<uuid:pk>/nodes/', path('admin-users/<uuid:pk>/nodes/',
api.ReplaceNodesAdminUserApi.as_view(), name='replace-nodes-admin-user'), api.ReplaceNodesAdminUserApi.as_view(), name='replace-nodes-admin-user'),
path('admin-user/<uuid:pk>/auth/', path('admin-users/<uuid:pk>/auth/',
api.AdminUserAuthApi.as_view(), name='admin-user-auth'), api.AdminUserAuthApi.as_view(), name='admin-user-auth'),
path('admin-user/<uuid:pk>/connective/', path('admin-users/<uuid:pk>/connective/',
api.AdminUserTestConnectiveApi.as_view(), name='admin-user-connective'), api.AdminUserTestConnectiveApi.as_view(), name='admin-user-connective'),
path('admin-user/<uuid:pk>/assets/', path('admin-users/<uuid:pk>/assets/',
api.AdminUserAssetsListView.as_view(), name='admin-user-assets'), api.AdminUserAssetsListView.as_view(), name='admin-user-assets'),
path('system-user/<uuid:pk>/auth-info/', path('system-users/<uuid:pk>/auth-info/',
api.SystemUserAuthInfoApi.as_view(), name='system-user-auth-info'), api.SystemUserAuthInfoApi.as_view(), name='system-user-auth-info'),
path('system-user/<uuid:pk>/asset/<uuid:aid>/auth-info/', path('system-users/<uuid:pk>/asset/<uuid:aid>/auth-info/',
api.SystemUserAssetAuthInfoApi.as_view(), name='system-user-asset-auth-info'), api.SystemUserAssetAuthInfoApi.as_view(), name='system-user-asset-auth-info'),
path('system-user/<uuid:pk>/assets/', path('system-users/<uuid:pk>/assets/',
api.SystemUserAssetsListView.as_view(), name='system-user-assets'), api.SystemUserAssetsListView.as_view(), name='system-user-assets'),
path('system-user/<uuid:pk>/push/', path('system-users/<uuid:pk>/push/',
api.SystemUserPushApi.as_view(), name='system-user-push'), api.SystemUserPushApi.as_view(), name='system-user-push'),
path('system-user/<uuid:pk>/asset/<uuid:aid>/push/', path('system-users/<uuid:pk>/asset/<uuid:aid>/push/',
api.SystemUserPushToAssetApi.as_view(), name='system-user-push-to-asset'), api.SystemUserPushToAssetApi.as_view(), name='system-user-push-to-asset'),
path('system-user/<uuid:pk>/asset/<uuid:aid>/test/', path('system-users/<uuid:pk>/asset/<uuid:aid>/test/',
api.SystemUserTestAssetConnectivityApi.as_view(), name='system-user-test-to-asset'), api.SystemUserTestAssetConnectivityApi.as_view(), name='system-user-test-to-asset'),
path('system-user/<uuid:pk>/connective/', path('system-users/<uuid:pk>/connective/',
api.SystemUserTestConnectiveApi.as_view(), name='system-user-connective'), api.SystemUserTestConnectiveApi.as_view(), name='system-user-connective'),
path('system-user/<uuid:pk>/cmd-filter-rules/', path('system-users/<uuid:pk>/cmd-filter-rules/',
api.SystemUserCommandFilterRuleListApi.as_view(), name='system-user-cmd-filter-rule-list'), api.SystemUserCommandFilterRuleListApi.as_view(), name='system-user-cmd-filter-rule-list'),
path('nodes/tree/', api.NodeListAsTreeApi.as_view(), name='node-tree'), path('nodes/tree/', api.NodeListAsTreeApi.as_view(), name='node-tree'),
...@@ -89,10 +88,14 @@ urlpatterns = [ ...@@ -89,10 +88,14 @@ urlpatterns = [
path('nodes/refresh-assets-amount/', path('nodes/refresh-assets-amount/',
api.RefreshAssetsAmount.as_view(), name='refresh-assets-amount'), api.RefreshAssetsAmount.as_view(), name='refresh-assets-amount'),
path('gateway/<uuid:pk>/test-connective/', path('gateways/<uuid:pk>/test-connective/',
api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),
] ]
urlpatterns += router.urls + cmd_filter_router.urls old_version_urlpatterns = [
re_path('(?P<resource>admin-user|system-user|domain|gateway|cmd-filter|asset-user)/.*', capi.redirect_plural_name_api)
]
urlpatterns += router.urls + cmd_filter_router.urls + old_version_urlpatterns
This diff is collapsed.
...@@ -83,7 +83,6 @@ class AdminUserDetailView(PermissionsMixin, DetailView): ...@@ -83,7 +83,6 @@ class AdminUserDetailView(PermissionsMixin, DetailView):
context = { context = {
'app': _('Assets'), 'app': _('Assets'),
'action': _('Admin user detail'), 'action': _('Admin user detail'),
'nodes': Node.get_queryset(),
} }
kwargs.update(context) kwargs.update(context)
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
......
...@@ -16,8 +16,7 @@ from common.utils import get_object_or_none, get_logger ...@@ -16,8 +16,7 @@ from common.utils import get_object_or_none, get_logger
from common.permissions import PermissionsMixin, IsOrgAdmin, IsValidUser from common.permissions import PermissionsMixin, IsOrgAdmin, IsValidUser
from common.const import KEY_CACHE_RESOURCES_ID from common.const import KEY_CACHE_RESOURCES_ID
from .. import forms from .. import forms
from ..utils import NodeUtil from ..models import Asset, Label, Node
from ..models import Asset, SystemUser, Label, Node
__all__ = [ __all__ = [
...@@ -196,13 +195,9 @@ class AssetDetailView(PermissionsMixin, DetailView): ...@@ -196,13 +195,9 @@ class AssetDetailView(PermissionsMixin, DetailView):
).select_related('admin_user', 'domain') ).select_related('admin_user', 'domain')
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
nodes_remain = Node.objects.exclude(assets=self.object).only('key')
util = NodeUtil()
nodes_remain = util.get_nodes_by_queryset(nodes_remain)
context = { context = {
'app': _('Assets'), 'app': _('Assets'),
'action': _('Asset detail'), 'action': _('Asset detail'),
'nodes_remain': nodes_remain,
} }
kwargs.update(context) kwargs.update(context)
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
...@@ -98,14 +98,9 @@ class SystemUserAssetView(PermissionsMixin, DetailView): ...@@ -98,14 +98,9 @@ class SystemUserAssetView(PermissionsMixin, DetailView):
permission_classes = [IsOrgAdmin] permission_classes = [IsOrgAdmin]
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
from ..utils import NodeUtil
nodes_remain = Node.objects.exclude(systemuser=self.object)
util = NodeUtil()
nodes_remain = util.get_nodes_by_queryset(nodes_remain)
context = { context = {
'app': _('assets'), 'app': _('assets'),
'action': _('System user asset'), 'action': _('System user asset'),
'nodes_remain': nodes_remain
} }
kwargs.update(context) kwargs.update(context)
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
from django.apps import AppConfig from django.apps import AppConfig
from django.conf import settings
from django.db.models.signals import post_save
class AuditsConfig(AppConfig): class AuditsConfig(AppConfig):
...@@ -6,3 +8,5 @@ class AuditsConfig(AppConfig): ...@@ -6,3 +8,5 @@ class AuditsConfig(AppConfig):
def ready(self): def ready(self):
from . import signals_handler from . import signals_handler
if settings.SYSLOG_ENABLE:
post_save.connect(signals_handler.on_audits_log_create)
# Generated by Django 2.1.7 on 2019-07-26 09:53
from django.db import migrations, models
def migrate_loginlog_reason_to_str(apps, schema_editor):
db_alias = schema_editor.connection.alias
reason_map = {
"0": "",
"1": 'Username/password check failed',
"2": 'MFA authentication failed',
"3": "Username does not exist",
"4": "Password expired",
}
model = apps.get_model("audits", "UserLoginLog")
for k, v in reason_map.items():
model.objects.using(db_alias).filter(reason=k).update(reason=v)
class Migration(migrations.Migration):
dependencies = [
('audits', '0005_auto_20190228_1715'),
]
operations = [
migrations.AlterField(
model_name='userloginlog',
name='reason',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Reason'),
),
migrations.RunPython(migrate_loginlog_reason_to_str),
]
...@@ -5,7 +5,7 @@ from django.db.models import Q ...@@ -5,7 +5,7 @@ from django.db.models import Q
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils import timezone from django.utils import timezone
from orgs.mixins import OrgModelMixin from orgs.mixins.models import OrgModelMixin
__all__ = [ __all__ = [
'FTPLog', 'OperateLog', 'PasswordChangeLog', 'UserLoginLog', 'FTPLog', 'OperateLog', 'PasswordChangeLog', 'UserLoginLog',
...@@ -72,20 +72,6 @@ class UserLoginLog(models.Model): ...@@ -72,20 +72,6 @@ class UserLoginLog(models.Model):
(MFA_UNKNOWN, _('-')), (MFA_UNKNOWN, _('-')),
) )
REASON_NOTHING = 0
REASON_PASSWORD = 1
REASON_MFA = 2
REASON_NOT_EXIST = 3
REASON_PASSWORD_EXPIRED = 4
REASON_CHOICE = (
(REASON_NOTHING, _('-')),
(REASON_PASSWORD, _('Username/password check failed')),
(REASON_MFA, _('MFA authentication failed')),
(REASON_NOT_EXIST, _("Username does not exist")),
(REASON_PASSWORD_EXPIRED, _("Password expired")),
)
STATUS_CHOICE = ( STATUS_CHOICE = (
(True, _('Success')), (True, _('Success')),
(False, _('Failed')) (False, _('Failed'))
...@@ -97,7 +83,7 @@ class UserLoginLog(models.Model): ...@@ -97,7 +83,7 @@ class UserLoginLog(models.Model):
city = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('Login city')) city = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('Login city'))
user_agent = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('User agent')) user_agent = models.CharField(max_length=254, blank=True, null=True, verbose_name=_('User agent'))
mfa = models.SmallIntegerField(default=MFA_UNKNOWN, choices=MFA_CHOICE, verbose_name=_('MFA')) mfa = models.SmallIntegerField(default=MFA_UNKNOWN, choices=MFA_CHOICE, verbose_name=_('MFA'))
reason = models.SmallIntegerField(default=0, choices=REASON_CHOICE, verbose_name=_('Reason')) reason = models.CharField(default='', max_length=128, blank=True, verbose_name=_('Reason'))
status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status')) status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status'))
datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login')) datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login'))
......
...@@ -3,11 +3,36 @@ ...@@ -3,11 +3,36 @@
from rest_framework import serializers from rest_framework import serializers
from .models import FTPLog from terminal.models import Session
from . import models
class FTPLogSerializer(serializers.ModelSerializer): class FTPLogSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = FTPLog model = models.FTPLog
fields = '__all__'
class LoginLogSerializer(serializers.ModelSerializer):
class Meta:
model = models.UserLoginLog
fields = '__all__'
class OperateLogSerializer(serializers.ModelSerializer):
class Meta:
model = models.OperateLog
fields = '__all__'
class PasswordChangeLogSerializer(serializers.ModelSerializer):
class Meta:
model = models.PasswordChangeLog
fields = '__all__'
class SessionAuditSerializer(serializers.ModelSerializer):
class Meta:
model = Session
fields = '__all__' fields = '__all__'
...@@ -4,13 +4,18 @@ ...@@ -4,13 +4,18 @@
from django.db.models.signals import post_save, post_delete from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver from django.dispatch import receiver
from django.db import transaction from django.db import transaction
from rest_framework.renderers import JSONRenderer
from jumpserver.utils import current_request from jumpserver.utils import current_request
from common.utils import get_request_ip, get_logger from common.utils import get_request_ip, get_logger, get_syslogger
from users.models import User from users.models import User
from .models import OperateLog, PasswordChangeLog from terminal.models import Session
from . import models
from . import serializers
logger = get_logger(__name__) logger = get_logger(__name__)
sys_logger = get_syslogger("audits")
json_render = JSONRenderer()
MODELS_NEED_RECORD = ( MODELS_NEED_RECORD = (
...@@ -36,7 +41,7 @@ def create_operate_log(action, sender, resource): ...@@ -36,7 +41,7 @@ def create_operate_log(action, sender, resource):
} }
with transaction.atomic(): with transaction.atomic():
try: try:
OperateLog.objects.create(**data) models.OperateLog.objects.create(**data)
except Exception as e: except Exception as e:
logger.error("Create operate log error: {}".format(e)) logger.error("Create operate log error: {}".format(e))
...@@ -44,15 +49,15 @@ def create_operate_log(action, sender, resource): ...@@ -44,15 +49,15 @@ def create_operate_log(action, sender, resource):
@receiver(post_save, dispatch_uid="my_unique_identifier") @receiver(post_save, dispatch_uid="my_unique_identifier")
def on_object_created_or_update(sender, instance=None, created=False, **kwargs): def on_object_created_or_update(sender, instance=None, created=False, **kwargs):
if created: if created:
action = OperateLog.ACTION_CREATE action = models.OperateLog.ACTION_CREATE
else: else:
action = OperateLog.ACTION_UPDATE action = models.OperateLog.ACTION_UPDATE
create_operate_log(action, sender, instance) create_operate_log(action, sender, instance)
@receiver(post_delete, dispatch_uid="my_unique_identifier") @receiver(post_delete, dispatch_uid="my_unique_identifier")
def on_object_delete(sender, instance=None, **kwargs): def on_object_delete(sender, instance=None, **kwargs):
create_operate_log(OperateLog.ACTION_DELETE, sender, instance) create_operate_log(models.OperateLog.ACTION_DELETE, sender, instance)
@receiver(post_save, sender=User, dispatch_uid="my_unique_identifier") @receiver(post_save, sender=User, dispatch_uid="my_unique_identifier")
...@@ -61,7 +66,32 @@ def on_user_change_password(sender, instance=None, **kwargs): ...@@ -61,7 +66,32 @@ def on_user_change_password(sender, instance=None, **kwargs):
if not current_request or not current_request.user.is_authenticated: if not current_request or not current_request.user.is_authenticated:
return return
with transaction.atomic(): with transaction.atomic():
PasswordChangeLog.objects.create( models.PasswordChangeLog.objects.create(
user=instance, change_by=current_request.user, user=instance, change_by=current_request.user,
remote_addr=get_request_ip(current_request), remote_addr=get_request_ip(current_request),
) )
def on_audits_log_create(sender, instance=None, **kwargs):
if sender == models.UserLoginLog:
category = "login_log"
serializer = serializers.LoginLogSerializer
elif sender == models.FTPLog:
serializer = serializers.FTPLogSerializer
category = "ftp_log"
elif sender == models.OperateLog:
category = "operation_log"
serializer = serializers.OperateLogSerializer
elif sender == models.PasswordChangeLog:
category = "password_change_log"
serializer = serializers.PasswordChangeLogSerializer
elif sender == Session:
category = "host_session_log"
serializer = serializers.SessionAuditSerializer
else:
return
s = serializer(instance=instance)
data = json_render.render(s.data).decode(errors='ignore')
msg = "{} - {}".format(category, data)
sys_logger.info(msg)
...@@ -72,7 +72,7 @@ ...@@ -72,7 +72,7 @@
<td class="text-center">{{ login_log.ip }}</td> <td class="text-center">{{ login_log.ip }}</td>
<td class="text-center">{{ login_log.city }}</td> <td class="text-center">{{ login_log.city }}</td>
<td class="text-center">{{ login_log.get_mfa_display }}</td> <td class="text-center">{{ login_log.get_mfa_display }}</td>
<td class="text-center">{{ login_log.get_reason_display }}</td> <td class="text-center">{% trans login_log.reason %}</td>
<td class="text-center">{{ login_log.get_status_display }}</td> <td class="text-center">{{ login_log.get_status_display }}</td>
<td class="text-center">{{ login_log.datetime }}</td> <td class="text-center">{{ login_log.datetime }}</td>
</tr> </tr>
......
...@@ -2,3 +2,6 @@ ...@@ -2,3 +2,6 @@
# #
from .auth import * from .auth import *
from .token import *
from .mfa import *
from .access_key import *
# -*- coding: utf-8 -*-
#
from rest_framework.viewsets import ModelViewSet
from common.permissions import IsValidUser
from .. import serializers
class AccessKeyViewSet(ModelViewSet):
permission_classes = (IsValidUser,)
serializer_class = serializers.AccessKeySerializer
search_fields = ['^id', '^secret']
def get_queryset(self):
return self.request.user.access_keys.all()
def perform_create(self, serializer):
user = self.request.user
user.create_access_key()
...@@ -16,15 +16,17 @@ from rest_framework.views import APIView ...@@ -16,15 +16,17 @@ from rest_framework.views import APIView
from common.utils import get_logger, get_request_ip from common.utils import get_logger, get_request_ip
from common.permissions import IsOrgAdminOrAppUser, IsValidUser from common.permissions import IsOrgAdminOrAppUser, IsValidUser
from orgs.mixins import RootOrgViewMixin from orgs.mixins.api import RootOrgViewMixin
from users.serializers import UserSerializer from users.serializers import UserSerializer
from users.models import User from users.models import User
from assets.models import Asset, SystemUser from assets.models import Asset, SystemUser
from audits.models import UserLoginLog as LoginLog from audits.models import UserLoginLog as LoginLog
from users.utils import ( from users.utils import (
check_user_valid, check_otp_code, increase_login_failed_count, check_otp_code, increase_login_failed_count,
is_block_login, clean_failed_count is_block_login, clean_failed_count
) )
from .. import const
from ..utils import check_user_valid
from ..serializers import OtpVerifySerializer from ..serializers import OtpVerifySerializer
from ..signals import post_auth_success, post_auth_failed from ..signals import post_auth_success, post_auth_failed
...@@ -53,27 +55,15 @@ class UserAuthApi(RootOrgViewMixin, APIView): ...@@ -53,27 +55,15 @@ class UserAuthApi(RootOrgViewMixin, APIView):
user, msg = self.check_user_valid(request) user, msg = self.check_user_valid(request)
if not user: if not user:
username = request.data.get('username', '') username = request.data.get('username', '')
exist = User.objects.filter(username=username).first() self.send_auth_signal(success=False, username=username, reason=msg)
reason = LoginLog.REASON_PASSWORD if exist else LoginLog.REASON_NOT_EXIST
self.send_auth_signal(success=False, username=username, reason=reason)
increase_login_failed_count(username, ip) increase_login_failed_count(username, ip)
return Response({'msg': msg}, status=401) return Response({'msg': msg}, status=401)
if user.password_has_expired:
self.send_auth_signal(
success=False, username=username,
reason=LoginLog.REASON_PASSWORD_EXPIRED
)
msg = _("The user {} password has expired, please update.".format(
user.username))
logger.info(msg)
return Response({'msg': msg}, status=401)
if not user.otp_enabled: if not user.otp_enabled:
self.send_auth_signal(success=True, user=user) self.send_auth_signal(success=True, user=user)
# 登陆成功,清除原来的缓存计数 # 登陆成功,清除原来的缓存计数
clean_failed_count(username, ip) clean_failed_count(username, ip)
token = user.create_bearer_token(request) token, expired_at = user.create_bearer_token(request)
return Response( return Response(
{'token': token, 'user': self.serializer_class(user).data} {'token': token, 'user': self.serializer_class(user).data}
) )
...@@ -167,10 +157,10 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView): ...@@ -167,10 +157,10 @@ class UserOtpAuthApi(RootOrgViewMixin, APIView):
status=401 status=401
) )
if not check_otp_code(user.otp_secret_key, otp_code): if not check_otp_code(user.otp_secret_key, otp_code):
self.send_auth_signal(success=False, username=user.username, reason=LoginLog.REASON_MFA) self.send_auth_signal(success=False, username=user.username, reason=const.mfa_failed)
return Response({'msg': _('MFA certification failed')}, status=401) return Response({'msg': _('MFA certification failed')}, status=401)
self.send_auth_signal(success=True, user=user) self.send_auth_signal(success=True, user=user)
token = user.create_bearer_token(request) token, expired_at = user.create_bearer_token(request)
data = {'token': token, 'user': self.serializer_class(user).data} data = {'token': token, 'user': self.serializer_class(user).data}
return Response(data) return Response(data)
......
# -*- coding: utf-8 -*-
#
from rest_framework.permissions import AllowAny
from rest_framework.generics import CreateAPIView
from .. import serializers
class MFAChallengeApi(CreateAPIView):
permission_classes = (AllowAny,)
serializer_class = serializers.MFAChallengeSerializer
# -*- coding: utf-8 -*-
#
import uuid
from django.core.cache import cache
from django.utils.translation import ugettext as _
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.generics import CreateAPIView
from drf_yasg.utils import swagger_auto_schema
from common.utils import get_request_ip, get_logger
from users.utils import (
check_otp_code, increase_login_failed_count,
is_block_login, clean_failed_count
)
from ..utils import check_user_valid
from ..signals import post_auth_success, post_auth_failed
from .. import serializers
logger = get_logger(__name__)
__all__ = ['TokenCreateApi']
class AuthFailedError(Exception):
def __init__(self, msg, reason=None):
self.msg = msg
self.reason = reason
class MFARequiredError(Exception):
pass
class TokenCreateApi(CreateAPIView):
permission_classes = (AllowAny,)
serializer_class = serializers.BearerTokenSerializer
@staticmethod
def check_is_block(username, ip):
if is_block_login(username, ip):
msg = _("Log in frequently and try again later")
logger.warn(msg + ': ' + username + ':' + ip)
raise AuthFailedError(msg)
def check_user_valid(self):
request = self.request
username = request.data.get('username', '')
password = request.data.get('password', '')
public_key = request.data.get('public_key', '')
user, msg = check_user_valid(
username=username, password=password,
public_key=public_key
)
if not user:
raise AuthFailedError(msg)
return user
def create(self, request, *args, **kwargs):
username = self.request.data.get('username')
ip = self.request.data.get('remote_addr', None)
ip = ip or get_request_ip(self.request)
user = None
try:
self.check_is_block(username, ip)
user = self.check_user_valid()
if user.otp_enabled:
raise MFARequiredError()
self.send_auth_signal(success=True, user=user)
clean_failed_count(username, ip)
return super().create(request, *args, **kwargs)
except AuthFailedError as e:
increase_login_failed_count(username, ip)
self.send_auth_signal(success=False, user=user, username=username, reason=str(e))
return Response({'msg': str(e)}, status=401)
except MFARequiredError:
msg = _("MFA required")
seed = uuid.uuid4().hex
cache.set(seed, user.username, 300)
resp = {'msg': msg, "choices": ["otp"], "req": seed}
return Response(resp, status=300)
def send_auth_signal(self, success=True, user=None, username='', reason=''):
if success:
post_auth_success.send(
sender=self.__class__, user=user, request=self.request
)
else:
post_auth_failed.send(
sender=self.__class__, username=username,
request=self.request, reason=reason
)
...@@ -11,6 +11,7 @@ from django.utils.six import text_type ...@@ -11,6 +11,7 @@ from django.utils.six import text_type
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import authentication, exceptions from rest_framework import authentication, exceptions
from common.auth import signature
from rest_framework.authentication import CSRFCheck from rest_framework.authentication import CSRFCheck
from common.utils import get_object_or_none, make_signature, http_to_unixtime from common.utils import get_object_or_none, make_signature, http_to_unixtime
...@@ -108,8 +109,8 @@ class AccessKeyAuthentication(authentication.BaseAuthentication): ...@@ -108,8 +109,8 @@ class AccessKeyAuthentication(authentication.BaseAuthentication):
class AccessTokenAuthentication(authentication.BaseAuthentication): class AccessTokenAuthentication(authentication.BaseAuthentication):
keyword = 'Bearer' keyword = 'Bearer'
model = get_user_model()
expiration = settings.TOKEN_EXPIRATION or 3600 expiration = settings.TOKEN_EXPIRATION or 3600
model = get_user_model()
def authenticate(self, request): def authenticate(self, request):
auth = authentication.get_authorization_header(request).split() auth = authentication.get_authorization_header(request).split()
...@@ -133,8 +134,9 @@ class AccessTokenAuthentication(authentication.BaseAuthentication): ...@@ -133,8 +134,9 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
return self.authenticate_credentials(token) return self.authenticate_credentials(token)
def authenticate_credentials(self, token): def authenticate_credentials(self, token):
model = get_user_model()
user_id = cache.get(token) user_id = cache.get(token)
user = get_object_or_none(self.model, id=user_id) user = get_object_or_none(model, id=user_id)
if not user: if not user:
msg = _('Invalid token or cache refreshed.') msg = _('Invalid token or cache refreshed.')
...@@ -167,3 +169,25 @@ class SessionAuthentication(authentication.SessionAuthentication): ...@@ -167,3 +169,25 @@ class SessionAuthentication(authentication.SessionAuthentication):
# CSRF passed with authenticated user # CSRF passed with authenticated user
return user, None return user, None
class SignatureAuthentication(signature.SignatureAuthentication):
# The HTTP header used to pass the consumer key ID.
# A method to fetch (User instance, user_secret_string) from the
# consumer key ID, or None in case it is not found. Algorithm
# will be what the client has sent, in the case that both RSA
# and HMAC are supported at your site (and also for expansion).
model = get_user_model()
def fetch_user_data(self, key_id, algorithm="hmac-sha256"):
# ...
# example implementation:
try:
key = AccessKey.objects.get(id=key_id)
if not key.is_active:
return None, None
user, secret = key.user, str(key.secret)
return user, secret
except AccessKey.DoesNotExist:
return None, None
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.utils.translation import ugettext_lazy as _
password_failed = _('Username/password check failed')
mfa_failed = _('MFA authentication failed')
user_not_exist = _("Username does not exist")
password_expired = _("Password expired")
user_invalid = _('Disabled or expired')
# Generated by Django 2.1.7 on 2019-07-29 06:23
import datetime
from django.db import migrations, models
from django.utils.timezone import utc
class Migration(migrations.Migration):
dependencies = [
('authentication', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='accesskey',
name='date_created',
field=models.DateTimeField(auto_now_add=True, default=datetime.datetime(2019, 7, 29, 6, 23, 54, 115123, tzinfo=utc)),
preserve_default=False,
),
migrations.AddField(
model_name='accesskey',
name='is_active',
field=models.BooleanField(default=True, verbose_name='Active'),
),
]
...@@ -12,6 +12,8 @@ class AccessKey(models.Model): ...@@ -12,6 +12,8 @@ class AccessKey(models.Model):
default=uuid.uuid4, editable=False) default=uuid.uuid4, editable=False)
user = models.ForeignKey(settings.AUTH_USER_MODEL, verbose_name='User', user = models.ForeignKey(settings.AUTH_USER_MODEL, verbose_name='User',
on_delete=models.CASCADE, related_name='access_keys') on_delete=models.CASCADE, related_name='access_keys')
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
date_created = models.DateTimeField(auto_now_add=True)
def get_id(self): def get_id(self):
return str(self.id) return str(self.id)
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.core.cache import cache
from rest_framework import serializers from rest_framework import serializers
from users.models import User
from .models import AccessKey from .models import AccessKey
__all__ = ['AccessKeySerializer'] __all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer',
]
class AccessKeySerializer(serializers.ModelSerializer): class AccessKeySerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = AccessKey model = AccessKey
fields = ['id', 'secret'] fields = ['id', 'secret', 'is_active', 'date_created']
read_only_fields = ['id', 'secret'] read_only_fields = ['id', 'secret', 'date_created']
class OtpVerifySerializer(serializers.Serializer): class OtpVerifySerializer(serializers.Serializer):
code = serializers.CharField(max_length=6, min_length=6) code = serializers.CharField(max_length=6, min_length=6)
class BearerTokenMixin(serializers.Serializer):
token = serializers.CharField(read_only=True)
keyword = serializers.SerializerMethodField()
date_expired = serializers.DateTimeField(read_only=True)
@staticmethod
def get_keyword(obj):
return 'Bearer'
def create_response(self, username):
request = self.context.get("request")
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
raise serializers.ValidationError("username %s not exist" % username)
token, date_expired = user.create_bearer_token(request)
instance = {
"username": username,
"token": token,
"date_expired": date_expired,
}
return instance
def update(self, instance, validated_data):
pass
class BearerTokenSerializer(BearerTokenMixin, serializers.Serializer):
username = serializers.CharField()
password = serializers.CharField(write_only=True, allow_null=True,
required=False)
public_key = serializers.CharField(write_only=True, allow_null=True,
required=False)
def create(self, validated_data):
username = validated_data.get("username")
return self.create_response(username)
class MFAChallengeSerializer(BearerTokenMixin, serializers.Serializer):
req = serializers.CharField(write_only=True)
auth_type = serializers.CharField(write_only=True)
code = serializers.CharField(write_only=True)
def validate_req(self, attr):
username = cache.get(attr)
if not username:
raise serializers.ValidationError("Not valid, may be expired")
self.context["username"] = username
def validate_code(self, code):
username = self.context["username"]
user = User.objects.get(username=username)
ok = user.check_otp(code)
if not ok:
msg = "Otp code not valid, may be expired"
raise serializers.ValidationError(msg)
def create(self, validated_data):
username = self.context["username"]
return self.create_response(username)
{% extends '_modal.html' %}
{% load i18n %}
{% load static %}
{% block modal_id %}access_key_modal{% endblock %}
{% block modal_class %}modal-lg{% endblock %}
{% block modal_title%}{% trans "API key list" %}{% endblock %}
{% block modal_body %}
<style>
.inmodal .modal-body {
background: #fff;
}
#access_key_list_table_wrapper {
padding-top: 10px;
}
</style>
<div class="alert alert-info help-message">
{% trans 'Using api key sign api header, every requests header difference'%}, <a href="https://tools.ietf.org/html/draft-cavage-http-signatures-08">{% trans 'docs' %} </a>
</div>
<table class="table table-striped table-bordered table-hover " id="access_key_list_table" style="padding-top: 10px">
<thead>
<tr>
<th class="text-center">
<input type="checkbox" id="check_all" class="ipt_check_all" >
</th>
<th class="text-center">{% trans 'ID' %}</th>
<th class="text-center">{% trans 'Secret' %}</th>
<th class="text-center">{% trans 'Active' %}</th>
<th class="text-center">{% trans 'Date' %}</th>
<th class="text-center">{% trans 'Action' %}</th>
</tr>
</thead>
<tbody>
</tbody>
</table>
<div id="uc" hidden>
<button class="btn btn-primary btn-sm" id="create-btn" href="#"> {% trans "Create" %} </button>
</div>
<script>
var table = null;
function initTable() {
var options = {
ele: $('#access_key_list_table'),
columnDefs: [
{targets: 2, createdCell: function (td, cellData) {
var btn = '<button class="btn btn-primary btn-xs btn-secret" data-secret="SECRET">{% trans 'Show' %}</button>';
btn = btn.replace("SECRET", cellData);
$(td).html(btn)
}},
{targets: 3, createdCell: function (td, cellData) {
if (cellData) {
$(td).html('<i class="fa fa-check text-navy"></i>')
} else {
$(td).html('<i class="fa fa-times text-danger"></i>')
}
}},
{targets: 4, createdCell: function (td, cellData) {
var date = toSafeLocalDateStr(cellData);
$(td).html(date)
}},
{targets: 5, createdCell: function (td, cellData, rowData) {
var btn = '';
var btn_del = '<a class="btn btn-xs btn-danger m-l-xs btn-del" data-id="ID">{% trans "Delete" %}</a>';
var btn_inactive = '<a class="btn btn-xs btn-info m-l-xs btn-inactive" data-id="ID">{% trans "Disable" %}</a>';
var btn_active = '<a class="btn btn-xs btn-primary m-l-xs btn-active" data-id="ID">{% trans "Enable" %}</a>';
btn += btn_del;
if (rowData.is_active) {
btn += btn_inactive
} else {
btn += btn_active
}
btn = btn.replaceAll("ID", cellData);
$(td).html(btn);
}}
],
ajax_url: '{% url "api-auth:access-key-list" %}',
columns: [
{data: "id"},
{data: "id"},
{data: "secret"},
{data: "is_active"},
{data: "date_created"},
{data: "id", orderable: false}
],
uc_html: $('#uc').html()
};
table = jumpserver.initServerSideDataTable(options);
}
$(document).ready(function () {
}).on("show.bs.modal", "#access_key_modal", function () {
if (!table) {
initTable()
}
}).on("click", "#create-btn", function () {
var url = "{% url "api-auth:access-key-list" %}";
var data = {
url: url,
method: 'POST',
success: function () {
table.ajax.reload();
}
};
requestApi(data)
}).on("click", ".btn-secret", function () {
var $this = $(this);
$this.parent().html($this.data("secret"))
}).on("click", ".btn-del", function () {
var url = "{% url "api-auth:access-key-detail" pk=DEFAULT_PK %}";
url = url.replace("{{ DEFAULT_PK }}", $(this).data("id")) ;
objectDelete($(this), $(this).data("id"), url);
}).on("click", ".btn-active", function () {
var url = "{% url "api-auth:access-key-detail" pk=DEFAULT_PK %}";
url = url.replace("{{ DEFAULT_PK }}", $(this).data("id")) ;
var data = {
url: url,
body: JSON.stringify({"is_active": true}),
method: "PATCH",
success: function () {
table.ajax.reload();
}
};
requestApi(data)
}).on("click", ".btn-inactive", function () {
var url = "{% url "api-auth:access-key-detail" pk=DEFAULT_PK %}";
url = url.replace("{{ DEFAULT_PK }}", $(this).data("id")) ;
var data = {
url: url,
body: JSON.stringify({"is_active": false}),
method: "PATCH",
success: function () {
table.ajax.reload();
}
};
requestApi(data)
})
</script>
{% endblock %}
{% block modal_button %}
<button data-dismiss="modal" class="btn btn-white close_btn2" type="button">{% trans "Close" %}</button>
{% endblock %}
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
<link href="{% static 'css/login-style.css' %}" rel="stylesheet"> <link href="{% static 'css/login-style.css' %}" rel="stylesheet">
<!-- scripts --> <!-- scripts -->
<script src="{% static 'js/jquery-2.1.1.js' %}"></script> <script src="{% static 'js/jquery-3.1.1.min.js' %}"></script>
<script src="{% static 'js/plugins/sweetalert/sweetalert.min.js' %}"></script> <script src="{% static 'js/plugins/sweetalert/sweetalert.min.js' %}"></script>
<script src="{% static 'js/bootstrap.min.js' %}"></script> <script src="{% static 'js/bootstrap.min.js' %}"></script>
<script src="{% static 'js/plugins/datatables/datatables.min.js' %}"></script> <script src="{% static 'js/plugins/datatables/datatables.min.js' %}"></script>
......
...@@ -4,18 +4,27 @@ ...@@ -4,18 +4,27 @@
from __future__ import absolute_import from __future__ import absolute_import
from django.urls import path from django.urls import path
from rest_framework.routers import DefaultRouter
from .. import api from .. import api
router = DefaultRouter()
router.register('access-keys', api.AccessKeyViewSet, 'access-key')
app_name = 'authentication' app_name = 'authentication'
urlpatterns = [ urlpatterns = [
# path('token/', api.UserToken.as_view(), name='user-token'), # path('token/', api.UserToken.as_view(), name='user-token'),
path('auth/', api.UserAuthApi.as_view(), name='user-auth'), path('auth/', api.UserAuthApi.as_view(), name='user-auth'),
path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'),
path('mfa/challenge/', api.MFAChallengeApi.as_view(), name='mfa-challenge'),
path('connection-token/', path('connection-token/',
api.UserConnectionTokenApi.as_view(), name='connection-token'), api.UserConnectionTokenApi.as_view(), name='connection-token'),
path('otp/auth/', api.UserOtpAuthApi.as_view(), name='user-otp-auth'), path('otp/auth/', api.UserOtpAuthApi.as_view(), name='user-otp-auth'),
path('otp/verify/', api.UserOtpVerifyApi.as_view(), name='user-otp-verify'), path('otp/verify/', api.UserOtpVerifyApi.as_view(), name='user-otp-verify'),
] ]
urlpatterns += router.urls
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from common.utils import get_ip_city, validate_ip from django.contrib.auth import authenticate
from common.utils import get_ip_city, get_object_or_none, validate_ip
from users.models import User
from . import const
def write_login_log(*args, **kwargs): def write_login_log(*args, **kwargs):
...@@ -16,3 +20,36 @@ def write_login_log(*args, **kwargs): ...@@ -16,3 +20,36 @@ def write_login_log(*args, **kwargs):
kwargs.update({'ip': ip, 'city': city}) kwargs.update({'ip': ip, 'city': city})
UserLoginLog.objects.create(**kwargs) UserLoginLog.objects.create(**kwargs)
def check_user_valid(**kwargs):
password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None)
email = kwargs.pop('email', None)
username = kwargs.pop('username', None)
if username:
user = get_object_or_none(User, username=username)
elif email:
user = get_object_or_none(User, email=email)
else:
user = None
if user is None:
return None, const.user_not_exist
elif not user.is_valid:
return None, const.user_invalid
elif user.password_has_expired:
return None, const.password_expired
if password and authenticate(username=username, password=password):
return user, ''
if public_key and user.public_key:
public_key_saved = user.public_key.split()
if len(public_key_saved) == 1:
if public_key == public_key_saved[0]:
return user, ''
elif len(public_key_saved) > 1:
if public_key == public_key_saved[1]:
return user, ''
return None, const.password_failed
...@@ -26,6 +26,7 @@ from users.utils import ( ...@@ -26,6 +26,7 @@ from users.utils import (
) )
from ..signals import post_auth_success, post_auth_failed from ..signals import post_auth_success, post_auth_failed
from .. import forms from .. import forms
from .. import const
__all__ = [ __all__ = [
...@@ -81,7 +82,7 @@ class UserLoginView(FormView): ...@@ -81,7 +82,7 @@ class UserLoginView(FormView):
user = form.get_user() user = form.get_user()
# user password expired # user password expired
if user.password_has_expired: if user.password_has_expired:
reason = LoginLog.REASON_PASSWORD_EXPIRED reason = const.password_expired
self.send_auth_signal(success=False, username=user.username, reason=reason) self.send_auth_signal(success=False, username=user.username, reason=reason)
return self.render_to_response(self.get_context_data(password_expired=True)) return self.render_to_response(self.get_context_data(password_expired=True))
...@@ -96,7 +97,7 @@ class UserLoginView(FormView): ...@@ -96,7 +97,7 @@ class UserLoginView(FormView):
# write login failed log # write login failed log
username = form.cleaned_data.get('username') username = form.cleaned_data.get('username')
exist = User.objects.filter(username=username).first() exist = User.objects.filter(username=username).first()
reason = LoginLog.REASON_PASSWORD if exist else LoginLog.REASON_NOT_EXIST reason = const.password_failed if exist else const.user_not_exist
# limit user login failed count # limit user login failed count
ip = get_request_ip(self.request) ip = get_request_ip(self.request)
increase_login_failed_count(username, ip) increase_login_failed_count(username, ip)
...@@ -167,7 +168,7 @@ class UserLoginOtpView(FormView): ...@@ -167,7 +168,7 @@ class UserLoginOtpView(FormView):
else: else:
self.send_auth_signal( self.send_auth_signal(
success=False, username=user.username, success=False, username=user.username,
reason=LoginLog.REASON_MFA reason=const.mfa_failed
) )
form.add_error( form.add_error(
'otp_code', _('MFA code invalid, or ntp sync server time') 'otp_code', _('MFA code invalid, or ntp sync server time')
......
# -*- coding: utf-8 -*-
#
from rest_framework import authentication
from rest_framework import exceptions
from httpsig import HeaderVerifier, utils
"""
Reusing failure exceptions serves several purposes:
1. Lack of useful information regarding the failure inhibits attackers
from learning about valid keyIDs or other forms of information leakage.
Using the same actual object for any failure makes preventing such
leakage through mistakenly-distinct error messages less likely.
2. In an API scenario, the object is created once and raised many times
rather than generated on every failure, which could lead to higher loads
or memory usage in high-volume attack scenarios.
"""
FAILED = exceptions.AuthenticationFailed('Invalid signature.')
class SignatureAuthentication(authentication.BaseAuthentication):
"""
DRF authentication class for HTTP Signature support.
You must subclass this class in your own project and implement the
`fetch_user_data(self, keyId, algorithm)` method, returning a tuple of
the User object and a bytes object containing the user's secret. Note
that key_id and algorithm are DIRTY as they are supplied by the client
and so must be verified in your subclass!
You may set the following class properties in your subclass to configure
authentication for your particular use case:
:param www_authenticate_realm: Default: "api"
:param required_headers: Default: ["(request-target)", "date"]
"""
www_authenticate_realm = "api"
required_headers = ["(request-target)", "date"]
def fetch_user_data(self, key_id, algorithm=None):
"""Retuns a tuple (User, secret) or (None, None)."""
raise NotImplementedError()
def authenticate_header(self, request):
"""
DRF sends this for unauthenticated responses if we're the primary
authenticator.
"""
h = " ".join(self.required_headers)
return 'Signature realm="%s",headers="%s"' % (
self.www_authenticate_realm, h)
def authenticate(self, request):
"""
Perform the actual authentication.
Note that the exception raised is always the same. This is so that we
don't leak information about in/valid keyIds and other such useful
things.
"""
auth_header = authentication.get_authorization_header(request)
if not auth_header or len(auth_header) == 0:
return None
method, fields = utils.parse_authorization_header(auth_header)
# Ignore foreign Authorization headers.
if method.lower() != 'signature':
return None
# Verify basic header structure.
if len(fields) == 0:
raise FAILED
# Ensure all required fields were included.
if len({"keyid", "algorithm", "signature"} - set(fields.keys())) > 0:
raise FAILED
# Fetch the secret associated with the keyid
user, secret = self.fetch_user_data(
fields["keyid"],
algorithm=fields["algorithm"]
)
if not (user and secret):
raise FAILED
# Gather all request headers and translate them as stated in the Django docs:
# https://docs.djangoproject.com/en/1.6/ref/request-response/#django.http.HttpRequest.META
headers = {}
for key in request.META.keys():
if key.startswith("HTTP_") or \
key in ("CONTENT_TYPE", "CONTENT_LENGTH"):
header = key[5:].lower().replace('_', '-')
headers[header] = request.META[key]
# Verify headers
hs = HeaderVerifier(
headers,
secret,
required_headers=self.required_headers,
method=request.method.lower(),
path=request.get_full_path()
)
# All of that just to get to this.
if not hs.verify():
raise FAILED
return user, fields["keyid"]
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from .mixins import BulkListSerializerMixin
from rest_framework_bulk.serializers import BulkListSerializer from rest_framework_bulk.serializers import BulkListSerializer
from rest_framework import serializers
from .mixins import BulkListSerializerMixin
class AdaptedBulkListSerializer(BulkListSerializerMixin, BulkListSerializer): class AdaptedBulkListSerializer(BulkListSerializerMixin, BulkListSerializer):
pass pass
class CeleryTaskSerializer(serializers.Serializer):
task = serializers.CharField(read_only=True)
...@@ -29,6 +29,6 @@ def send_mail_async(*args, **kwargs): ...@@ -29,6 +29,6 @@ def send_mail_async(*args, **kwargs):
args = tuple(args) args = tuple(args)
try: try:
send_mail(*args, **kwargs) return send_mail(*args, **kwargs)
except Exception as e: except Exception as e:
logger.error("Sending mail error: {}".format(e)) logger.error("Sending mail error: {}".format(e))
...@@ -31,6 +31,10 @@ def get_logger(name=None): ...@@ -31,6 +31,10 @@ def get_logger(name=None):
return logging.getLogger('jumpserver.%s' % name) return logging.getLogger('jumpserver.%s' % name)
def get_syslogger(name=None):
return logging.getLogger('jms.%s' % name)
def timesince(dt, since='', default="just now"): def timesince(dt, since='', default="just now"):
""" """
Returns string representing "time since" e.g. Returns string representing "time since" e.g.
......
...@@ -379,6 +379,8 @@ defaults = { ...@@ -379,6 +379,8 @@ defaults = {
'ASSETS_PERM_CACHE_TIME': 3600*24, 'ASSETS_PERM_CACHE_TIME': 3600*24,
'SECURITY_MFA_VERIFY_TTL': 3600, 'SECURITY_MFA_VERIFY_TTL': 3600,
'ASSETS_PERM_CACHE_ENABLE': False, 'ASSETS_PERM_CACHE_ENABLE': False,
'SYSLOG_ADDR': '', # '192.168.0.1:514'
'SYSLOG_FACILITY': 'user',
'PERM_SINGLE_ASSET_TO_UNGROUP_NODE': False, 'PERM_SINGLE_ASSET_TO_UNGROUP_NODE': False,
} }
......
...@@ -217,6 +217,9 @@ LOGGING = { ...@@ -217,6 +217,9 @@ LOGGING = {
'simple': { 'simple': {
'format': '%(levelname)s %(message)s' 'format': '%(levelname)s %(message)s'
}, },
'syslog': {
'format': 'jumpserver: %(message)s'
},
'msg': { 'msg': {
'format': '%(message)s' 'format': '%(message)s'
} }
...@@ -249,19 +252,10 @@ LOGGING = { ...@@ -249,19 +252,10 @@ LOGGING = {
'backupCount': 7, 'backupCount': 7,
'filename': ANSIBLE_LOG_FILE, 'filename': ANSIBLE_LOG_FILE,
}, },
'gunicorn_file': { 'syslog': {
'encoding': 'utf8', 'level': 'INFO',
'level': 'DEBUG', 'class': 'logging.NullHandler',
'class': 'logging.handlers.RotatingFileHandler', 'formatter': 'syslog'
'formatter': 'msg',
'maxBytes': 1024*1024*100,
'backupCount': 2,
'filename': GUNICORN_LOG_FILE,
},
'gunicorn_console': {
'level': 'DEBUG',
'class': 'logging.StreamHandler',
'formatter': 'msg'
}, },
}, },
'loggers': { 'loggers': {
...@@ -271,25 +265,17 @@ LOGGING = { ...@@ -271,25 +265,17 @@ LOGGING = {
'level': LOG_LEVEL, 'level': LOG_LEVEL,
}, },
'django.request': { 'django.request': {
'handlers': ['console', 'file'], 'handlers': ['console', 'file', 'syslog'],
'level': LOG_LEVEL, 'level': LOG_LEVEL,
'propagate': False, 'propagate': False,
}, },
'django.server': { 'django.server': {
'handlers': ['console', 'file'], 'handlers': ['console', 'file', 'syslog'],
'level': LOG_LEVEL, 'level': LOG_LEVEL,
'propagate': False, 'propagate': False,
}, },
'jumpserver': { 'jumpserver': {
'handlers': ['console', 'file'], 'handlers': ['console', 'file', 'syslog'],
'level': LOG_LEVEL,
},
'jumpserver.users.api': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL,
},
'jumpserver.users.view': {
'handlers': ['console', 'file'],
'level': LOG_LEVEL, 'level': LOG_LEVEL,
}, },
'ops.ansible_api': { 'ops.ansible_api': {
...@@ -300,17 +286,28 @@ LOGGING = { ...@@ -300,17 +286,28 @@ LOGGING = {
'handlers': ['console', 'file'], 'handlers': ['console', 'file'],
'level': "INFO", 'level': "INFO",
}, },
'gunicorn': { 'jms_audits': {
'handlers': ['gunicorn_console', 'gunicorn_file'], 'handlers': ['syslog'],
'level': 'INFO', 'level': 'INFO'
}, },
# 'django.db': { 'django.db': {
# 'handlers': ['console', 'file'], 'handlers': ['console', 'file'],
# 'level': 'DEBUG' 'level': 'DEBUG'
# } }
} }
} }
SYSLOG_ENABLE = False
if CONFIG.SYSLOG_ADDR != '' and len(CONFIG.SYSLOG_ADDR.split(':')) == 2:
host, port = CONFIG.SYSLOG_ADDR.split(':')
SYSLOG_ENABLE = True
LOGGING['handlers']['syslog'].update({
'class': 'logging.handlers.SysLogHandler',
'facility': CONFIG.SYSLOG_FACILITY,
'address': (host, int(port)),
})
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/1.10/topics/i18n/ # https://docs.djangoproject.com/en/1.10/topics/i18n/
# LANGUAGE_CODE = 'en' # LANGUAGE_CODE = 'en'
...@@ -391,6 +388,7 @@ REST_FRAMEWORK = { ...@@ -391,6 +388,7 @@ REST_FRAMEWORK = {
'authentication.backends.api.AccessKeyAuthentication', 'authentication.backends.api.AccessKeyAuthentication',
'authentication.backends.api.AccessTokenAuthentication', 'authentication.backends.api.AccessTokenAuthentication',
'authentication.backends.api.PrivateTokenAuthentication', 'authentication.backends.api.PrivateTokenAuthentication',
'authentication.backends.api.SignatureAuthentication',
'authentication.backends.api.SessionAuthentication', 'authentication.backends.api.SessionAuthentication',
), ),
'DEFAULT_FILTER_BACKENDS': ( 'DEFAULT_FILTER_BACKENDS': (
...@@ -403,7 +401,7 @@ REST_FRAMEWORK = { ...@@ -403,7 +401,7 @@ REST_FRAMEWORK = {
'SEARCH_PARAM': "search", 'SEARCH_PARAM': "search",
'DATETIME_FORMAT': '%Y-%m-%d %H:%M:%S %z', 'DATETIME_FORMAT': '%Y-%m-%d %H:%M:%S %z',
'DATETIME_INPUT_FORMATS': ['iso-8601', '%Y-%m-%d %H:%M:%S %z'], 'DATETIME_INPUT_FORMATS': ['iso-8601', '%Y-%m-%d %H:%M:%S %z'],
# 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination', 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination',
# 'PAGE_SIZE': 15 # 'PAGE_SIZE': 15
} }
...@@ -601,9 +599,12 @@ USER_GUIDE_URL = "" ...@@ -601,9 +599,12 @@ USER_GUIDE_URL = ""
SWAGGER_SETTINGS = { SWAGGER_SETTINGS = {
'DEFAULT_AUTO_SCHEMA_CLASS': 'jumpserver.swagger.CustomSwaggerAutoSchema', 'DEFAULT_AUTO_SCHEMA_CLASS': 'jumpserver.swagger.CustomSwaggerAutoSchema',
'USE_SESSION_AUTH': True,
'SECURITY_DEFINITIONS': { 'SECURITY_DEFINITIONS': {
'basic': { 'Bearer': {
'type': 'basic' 'type': 'apiKey',
'name': 'Authorization',
'in': 'header'
} }
}, },
} }
......
...@@ -7,13 +7,44 @@ from drf_yasg import openapi ...@@ -7,13 +7,44 @@ from drf_yasg import openapi
class CustomSwaggerAutoSchema(SwaggerAutoSchema): class CustomSwaggerAutoSchema(SwaggerAutoSchema):
def get_tags(self, operation_keys): def get_tags(self, operation_keys):
if len(operation_keys) > 2 and operation_keys[1].startswith('v'): if len(operation_keys) > 2:
return [operation_keys[2]] return [operation_keys[0] + '_' + operation_keys[1]]
return super().get_tags(operation_keys) return super().get_tags(operation_keys)
def get_operation_id(self, operation_keys):
action = ''
dump_keys = [k for k in operation_keys]
if hasattr(self.view, 'action'):
action = self.view.action
if action == "bulk_destroy":
action = "bulk_delete"
if dump_keys[-2] == "children":
if self.path.find('id') < 0:
dump_keys.insert(-2, "root")
if dump_keys[0] == "perms" and dump_keys[1] == "users":
if self.path.find('{id}') < 0:
dump_keys.insert(2, "my")
if action.replace('bulk_', '') == dump_keys[-1]:
dump_keys[-1] = action
return super().get_operation_id(tuple(dump_keys))
def get_operation(self, operation_keys):
operation = super().get_operation(operation_keys)
operation.summary = operation.operation_id
return operation
def get_swagger_view(version='v1'): def get_swagger_view(version='v1'):
from .urls import api_v1_patterns, api_v2_patterns from .urls import api_v1, api_v2
from django.urls import path, include
api_v1_patterns = [
path('api/v1/', include(api_v1))
]
api_v2_patterns = [
path('api/v2/', include(api_v2))
]
if version == "v2": if version == "v2":
patterns = api_v2_patterns patterns = api_v2_patterns
else: else:
......
...@@ -7,26 +7,26 @@ from django.conf.urls.static import static ...@@ -7,26 +7,26 @@ from django.conf.urls.static import static
from django.conf.urls.i18n import i18n_patterns from django.conf.urls.i18n import i18n_patterns
from django.views.i18n import JavaScriptCatalog from django.views.i18n import JavaScriptCatalog
from .views import IndexView, LunaView, I18NView, HealthCheckView from .views import IndexView, LunaView, I18NView, HealthCheckView, redirect_format_api
from .swagger import get_swagger_view from .swagger import get_swagger_view
api_v1 = [ api_v1 = [
path('users/v1/', include('users.urls.api_urls', namespace='api-users')), path('users/', include('users.urls.api_urls', namespace='api-users')),
path('assets/v1/', include('assets.urls.api_urls', namespace='api-assets')), path('assets/', include('assets.urls.api_urls', namespace='api-assets')),
path('perms/v1/', include('perms.urls.api_urls', namespace='api-perms')), path('perms/', include('perms.urls.api_urls', namespace='api-perms')),
path('terminal/v1/', include('terminal.urls.api_urls', namespace='api-terminal')), path('terminal/', include('terminal.urls.api_urls', namespace='api-terminal')),
path('ops/v1/', include('ops.urls.api_urls', namespace='api-ops')), path('ops/', include('ops.urls.api_urls', namespace='api-ops')),
path('audits/v1/', include('audits.urls.api_urls', namespace='api-audits')), path('audits/', include('audits.urls.api_urls', namespace='api-audits')),
path('orgs/v1/', include('orgs.urls.api_urls', namespace='api-orgs')), path('orgs/', include('orgs.urls.api_urls', namespace='api-orgs')),
path('settings/v1/', include('settings.urls.api_urls', namespace='api-settings')), path('settings/', include('settings.urls.api_urls', namespace='api-settings')),
path('authentication/v1/', include('authentication.urls.api_urls', namespace='api-auth')), path('authentication/', include('authentication.urls.api_urls', namespace='api-auth')),
path('common/v1/', include('common.urls.api_urls', namespace='api-common')), path('common/', include('common.urls.api_urls', namespace='api-common')),
path('applications/v1/', include('applications.urls.api_urls', namespace='api-applications')), path('applications/', include('applications.urls.api_urls', namespace='api-applications')),
] ]
api_v2 = [ api_v2 = [
path('terminal/v2/', include('terminal.urls.api_urls_v2', namespace='api-terminal-v2')), path('terminal/', include('terminal.urls.api_urls_v2', namespace='api-terminal-v2')),
path('users/v2/', include('users.urls.api_urls_v2', namespace='api-users-v2')), path('users/', include('users.urls.api_urls_v2', namespace='api-users-v2')),
] ]
...@@ -48,30 +48,23 @@ if settings.XPACK_ENABLED: ...@@ -48,30 +48,23 @@ if settings.XPACK_ENABLED:
path('xpack/', include('xpack.urls.view_urls', namespace='xpack')) path('xpack/', include('xpack.urls.view_urls', namespace='xpack'))
) )
api_v1.append( api_v1.append(
path('xpack/v1/', include('xpack.urls.api_urls', namespace='api-xpack')) path('xpack/', include('xpack.urls.api_urls', namespace='api-xpack'))
) )
js_i18n_patterns = i18n_patterns( js_i18n_patterns = i18n_patterns(
path('jsi18n/', JavaScriptCatalog.as_view(), name='javascript-catalog'), path('jsi18n/', JavaScriptCatalog.as_view(), name='javascript-catalog'),
) )
api_v1_patterns = [
path('api/', include(api_v1))
]
api_v2_patterns = [
path('api/', include(api_v2))
]
urlpatterns = [ urlpatterns = [
path('', IndexView.as_view(), name='index'), path('', IndexView.as_view(), name='index'),
path('', include(api_v2_patterns)), path('api/v1/', include(api_v1)),
path('', include(api_v1_patterns)), path('api/v2/', include(api_v2)),
re_path('api/(?P<app>\w+)/(?P<version>v\d)/.*', redirect_format_api),
path('api/health/', HealthCheckView.as_view(), name="health"), path('api/health/', HealthCheckView.as_view(), name="health"),
path('luna/', LunaView.as_view(), name='luna-view'), path('luna/', LunaView.as_view(), name='luna-view'),
path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'), path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'),
path('settings/', include('settings.urls.view_urls', namespace='settings')), path('settings/', include('settings.urls.view_urls', namespace='settings')),
# path('api/v2/', include(api_v2_patterns)),
# External apps url # External apps url
path('captcha/', include('captcha.urls')), path('captcha/', include('captcha.urls')),
......
...@@ -2,14 +2,13 @@ import datetime ...@@ -2,14 +2,13 @@ import datetime
import re import re
import time import time
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect, JsonResponse
from django.conf import settings from django.conf import settings
from django.views.generic import TemplateView, View from django.views.generic import TemplateView, View
from django.utils import timezone from django.utils import timezone
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.db.models import Count from django.db.models import Count
from django.shortcuts import redirect from django.shortcuts import redirect
from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.http import HttpResponse from django.http import HttpResponse
...@@ -208,7 +207,7 @@ class I18NView(View): ...@@ -208,7 +207,7 @@ class I18NView(View):
return response return response
api_url_pattern = re.compile(r'^/api/(?P<version>\w+)/(?P<app>\w+)/(?P<extra>.*)$') api_url_pattern = re.compile(r'^/api/(?P<app>\w+)/(?P<version>v\d)/(?P<extra>.*)$')
@csrf_exempt @csrf_exempt
...@@ -216,18 +215,16 @@ def redirect_format_api(request, *args, **kwargs): ...@@ -216,18 +215,16 @@ def redirect_format_api(request, *args, **kwargs):
_path, query = request.path, request.GET.urlencode() _path, query = request.path, request.GET.urlencode()
matched = api_url_pattern.match(_path) matched = api_url_pattern.match(_path)
if matched: if matched:
version, app, extra = matched.groups() kwargs = matched.groupdict()
_path = '/api/{app}/{version}/{extra}?{query}'.format(**{ kwargs["query"] = query
"app": app, "version": version, "extra": extra, _path = '/api/{version}/{app}/{extra}?{query}'.format(**kwargs).rstrip("?")
"query": query
})
return HttpResponseTemporaryRedirect(_path) return HttpResponseTemporaryRedirect(_path)
else: else:
return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404) return JsonResponse({"msg": "Redirect url failed: {}".format(_path)}, status=404)
class HealthCheckView(APIView): class HealthCheckView(APIView):
permission_classes = () permission_classes = ()
def get(self, request): def get(self, request):
return Response({"status": 1, "time": int(time.time())}) return JsonResponse({"status": 1, "time": int(time.time())})
This diff is collapsed.
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import datetime import datetime
import json import json
import os
from collections import defaultdict from collections import defaultdict
from ansible import constants as C from ansible import constants as C
...@@ -41,7 +42,11 @@ class CallbackMixin: ...@@ -41,7 +42,11 @@ class CallbackMixin:
super().__init__() super().__init__()
if display: if display:
self._display = display self._display = display
cols = os.environ.get("TERM_COLS", None)
self._display.columns = 79 self._display.columns = 79
if cols and cols.isdigit():
self._display.columns = int(cols) - 1
def display(self, msg): def display(self, msg):
self._display.display(msg) self._display.display(msg)
......
...@@ -6,6 +6,7 @@ from rest_framework import viewsets, generics ...@@ -6,6 +6,7 @@ from rest_framework import viewsets, generics
from rest_framework.views import Response from rest_framework.views import Response
from common.permissions import IsOrgAdmin from common.permissions import IsOrgAdmin
from common.serializers import CeleryTaskSerializer
from orgs.utils import current_org from orgs.utils import current_org
from ..models import Task, AdHoc, AdHocRunHistory from ..models import Task, AdHoc, AdHocRunHistory
from ..serializers import TaskSerializer, AdHocSerializer, \ from ..serializers import TaskSerializer, AdHocSerializer, \
...@@ -33,7 +34,7 @@ class TaskViewSet(viewsets.ModelViewSet): ...@@ -33,7 +34,7 @@ class TaskViewSet(viewsets.ModelViewSet):
class TaskRun(generics.RetrieveAPIView): class TaskRun(generics.RetrieveAPIView):
queryset = Task.objects.all() queryset = Task.objects.all()
# serializer_class = TaskViewSet serializer_class = CeleryTaskSerializer
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from django.utils.translation import ugettext as _
from django.conf import settings from django.conf import settings
from orgs.mixins import RootOrgViewMixin from orgs.mixins.api import RootOrgViewMixin
from common.permissions import IsValidUser from common.permissions import IsValidUser
from perms.utils import AssetPermissionUtilV2
from ..models import CommandExecution from ..models import CommandExecution
from ..serializers import CommandExecutionSerializer from ..serializers import CommandExecutionSerializer
from ..tasks import run_command_execution from ..tasks import run_command_execution
...@@ -20,15 +23,33 @@ class CommandExecutionViewSet(RootOrgViewMixin, viewsets.ModelViewSet): ...@@ -20,15 +23,33 @@ class CommandExecutionViewSet(RootOrgViewMixin, viewsets.ModelViewSet):
user_id=str(self.request.user.id) user_id=str(self.request.user.id)
) )
def check_hosts(self, serializer):
data = serializer.validated_data
assets = data["hosts"]
system_user = data["run_as"]
util = AssetPermissionUtilV2(self.request.user)
util.filter_permissions(system_users=system_user.id)
permed_assets = util.get_assets().filter(id__in=[a.id for a in assets])
unpermed_assets = set(assets) - set(permed_assets)
if unpermed_assets:
msg = _("Not has host {} permission").format(
[str(a.id) for a in unpermed_assets]
)
raise ValidationError({"hosts": msg})
def check_permissions(self, request): def check_permissions(self, request):
if not settings.SECURITY_COMMAND_EXECUTION and request.user.is_common_user: if not settings.SECURITY_COMMAND_EXECUTION and request.user.is_common_user:
return self.permission_denied(request, "Command execution disabled") return self.permission_denied(request, "Command execution disabled")
return super().check_permissions(request) return super().check_permissions(request)
def perform_create(self, serializer): def perform_create(self, serializer):
self.check_hosts(serializer)
instance = serializer.save() instance = serializer.save()
instance.user = self.request.user instance.user = self.request.user
instance.save() instance.save()
cols = self.request.query_params.get("cols", '80')
rows = self.request.query_params.get("rows", '24')
transaction.on_commit(lambda: run_command_execution.apply_async( transaction.on_commit(lambda: run_command_execution.apply_async(
args=(instance.id,), task_id=str(instance.id) args=(instance.id,), kwargs={"cols": cols, "rows": rows},
task_id=str(instance.id)
)) ))
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import os import os
from kombu import Exchange, Queue
from celery import Celery from celery import Celery
# set the default Django settings module for the 'celery' program. # set the default Django settings module for the 'celery' program.
...@@ -15,6 +16,14 @@ configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')} ...@@ -15,6 +16,14 @@ configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
# Using a string here means the worker will not have to # Using a string here means the worker will not have to
# pickle the object when using Windows. # pickle the object when using Windows.
# app.config_from_object('django.conf:settings', namespace='CELERY') # app.config_from_object('django.conf:settings', namespace='CELERY')
configs["CELERY_QUEUES"] = [
Queue("celery", Exchange("celery"), routing_key="celery"),
Queue("ansible", Exchange("ansible"), routing_key="ansible"),
]
configs["CELERY_ROUTES"] = {
"ops.tasks.run_ansible_task": {'exchange': 'ansible', 'routing_key': 'ansible'},
}
app.namespace = 'CELERY' app.namespace = 'CELERY'
app.conf.update(configs) app.conf.update(configs)
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS]) app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
...@@ -30,8 +30,6 @@ class JMSBaseInventory(BaseInventory): ...@@ -30,8 +30,6 @@ class JMSBaseInventory(BaseInventory):
info.update(asset.get_auth_info()) info.update(asset.get_auth_info())
if asset.is_unixlike(): if asset.is_unixlike():
info["become"] = asset.admin_user.become_info info["become"] = asset.admin_user.become_info
for node in asset.nodes.all():
info["groups"].append(node.value)
if asset.is_windows(): if asset.is_windows():
info["vars"].update({ info["vars"].update({
"ansible_connection": "ssh", "ansible_connection": "ssh",
...@@ -45,7 +43,6 @@ class JMSBaseInventory(BaseInventory): ...@@ -45,7 +43,6 @@ class JMSBaseInventory(BaseInventory):
info["vars"].update({ info["vars"].update({
"domain": asset.domain.name, "domain": asset.domain.name,
}) })
info["groups"].append("domain_"+asset.domain.name)
return info return info
@staticmethod @staticmethod
......
# Generated by Django 2.1.7 on 2019-07-24 12:02
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('ops', '0006_auto_20190318_1023'),
]
operations = [
migrations.AlterField(
model_name='adhoc',
name='_become',
field=models.CharField(blank=True, default='', max_length=1024, verbose_name='Become'),
),
migrations.AlterField(
model_name='adhoc',
name='created_by',
field=models.CharField(blank=True, default='', max_length=64, null=True, verbose_name='Create by'),
),
migrations.AlterField(
model_name='adhoc',
name='run_as',
field=models.CharField(blank=True, default='', max_length=64, null=True, verbose_name='Username'),
),
]
...@@ -161,9 +161,9 @@ class AdHoc(models.Model): ...@@ -161,9 +161,9 @@ class AdHoc(models.Model):
_hosts = models.TextField(blank=True, verbose_name=_('Hosts')) # ['hostname1', 'hostname2'] _hosts = models.TextField(blank=True, verbose_name=_('Hosts')) # ['hostname1', 'hostname2']
hosts = models.ManyToManyField('assets.Asset', verbose_name=_("Host")) hosts = models.ManyToManyField('assets.Asset', verbose_name=_("Host"))
run_as_admin = models.BooleanField(default=False, verbose_name=_('Run as admin')) run_as_admin = models.BooleanField(default=False, verbose_name=_('Run as admin'))
run_as = models.CharField(max_length=64, default='', null=True, verbose_name=_('Username')) run_as = models.CharField(max_length=64, default='', blank=True, null=True, verbose_name=_('Username'))
_become = models.CharField(max_length=1024, default='', verbose_name=_("Become")) _become = models.CharField(max_length=1024, default='', blank=True, verbose_name=_("Become"))
created_by = models.CharField(max_length=64, default='', null=True, verbose_name=_('Create by')) created_by = models.CharField(max_length=64, default='', blank=True, null=True, verbose_name=_('Create by'))
date_created = models.DateTimeField(auto_now_add=True, db_index=True) date_created = models.DateTimeField(auto_now_add=True, db_index=True)
@property @property
......
...@@ -23,7 +23,7 @@ def rerun_task(): ...@@ -23,7 +23,7 @@ def rerun_task():
pass pass
@shared_task @shared_task(queue="ansible")
def run_ansible_task(tid, callback=None, **kwargs): def run_ansible_task(tid, callback=None, **kwargs):
""" """
:param tid: is the tasks serialized data :param tid: is the tasks serialized data
...@@ -45,6 +45,10 @@ def run_command_execution(cid, **kwargs): ...@@ -45,6 +45,10 @@ def run_command_execution(cid, **kwargs):
execution = get_object_or_none(CommandExecution, id=cid) execution = get_object_or_none(CommandExecution, id=cid)
if execution: if execution:
try: try:
os.environ.update({
"TERM_ROWS": kwargs.get("rows", ""),
"TERM_COLS": kwargs.get("cols", ""),
})
execution.run() execution.run()
except SoftTimeLimitExceeded: except SoftTimeLimitExceeded:
logger.error("Run time out") logger.error("Run time out")
...@@ -98,7 +102,7 @@ def create_or_update_registered_periodic_tasks(): ...@@ -98,7 +102,7 @@ def create_or_update_registered_periodic_tasks():
create_or_update_celery_periodic_tasks(task) create_or_update_celery_periodic_tasks(task)
@shared_task @shared_task(queue="ansible")
def hello(name, callback=None): def hello(name, callback=None):
import time import time
time.sleep(10) time.sleep(10)
...@@ -109,7 +113,9 @@ def hello(name, callback=None): ...@@ -109,7 +113,9 @@ def hello(name, callback=None):
# @after_app_shutdown_clean_periodic # @after_app_shutdown_clean_periodic
# @register_as_period_task(interval=30) # @register_as_period_task(interval=30)
def hello123(): def hello123():
p = subprocess.Popen('ls /tmp', shell=True)
print("{} Hello world".format(datetime.datetime.now().strftime("%H:%M:%S"))) print("{} Hello world".format(datetime.datetime.now().strftime("%H:%M:%S")))
return None
@shared_task @shared_task
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
{% load i18n %} {% load i18n %}
<head> <head>
<title>{% trans 'Task log' %}</title> <title>{% trans 'Task log' %}</title>
<script src="{% static 'js/jquery-2.1.1.js' %}"></script> <script src="{% static 'js/jquery-3.1.1.min.js' %}"></script>
<script src="{% static 'js/plugins/xterm/xterm.js' %}"></script> <script src="{% static 'js/plugins/xterm/xterm.js' %}"></script>
<link rel="stylesheet" href="{% static 'js/plugins/xterm/xterm.css' %}" /> <link rel="stylesheet" href="{% static 'js/plugins/xterm/xterm.css' %}" />
<style> <style>
......
...@@ -83,9 +83,50 @@ ...@@ -83,9 +83,50 @@
var zTree, show = 0; var zTree, show = 0;
var systemUserId = null; var systemUserId = null;
var url = null; var url = null;
var treeUrl = "{% url 'api-perms:my-nodes-assets-as-tree' %}?cache_policy=1"; var treeUrl = "{% url 'api-perms:my-nodes-children-with-assets-as-tree' %}?cache_policy=1";
function proposeGeometry(term) {
if (!term.element.parentElement) {
return null;
}
var parentElementStyle = window.getComputedStyle(term.element.parentElement);
var parentElementHeight = parseInt(parentElementStyle.getPropertyValue('height'));
var parentElementWidth = Math.max(0, parseInt(parentElementStyle.getPropertyValue('width')));
var elementStyle = window.getComputedStyle(term.element);
var elementPadding = {
top: parseInt(elementStyle.getPropertyValue('padding-top')),
bottom: parseInt(elementStyle.getPropertyValue('padding-bottom')),
right: parseInt(elementStyle.getPropertyValue('padding-right')),
left: parseInt(elementStyle.getPropertyValue('padding-left'))
};
var elementPaddingVer = elementPadding.top + elementPadding.bottom;
var elementPaddingHor = elementPadding.right + elementPadding.left;
var availableHeight = parentElementHeight - elementPaddingVer;
var availableWidth = parentElementWidth - elementPaddingHor - term._core.viewport.scrollBarWidth;
var geometry = {
cols: Math.floor(availableWidth / term._core.renderer.dimensions.actualCellWidth),
rows: Math.floor(availableHeight / term._core.renderer.dimensions.actualCellHeight)
};
return geometry;
}
function fit(term) {
var geometry = proposeGeometry(term);
if (geometry) {
if (term.rows !== geometry.rows || term.cols !== geometry.cols) {
term._core.renderer.clear();
term.resize(geometry.cols, geometry.rows);
}
}
}
function initTree() { function initTree() {
if (systemUserId) {
url = treeUrl + '&system_user=' + systemUserId
}
else{
url = treeUrl
}
var setting = { var setting = {
check: { check: {
enable: true enable: true
...@@ -99,6 +140,12 @@ function initTree() { ...@@ -99,6 +140,12 @@ function initTree() {
enable: true enable: true
} }
}, },
async: {
enable: true,
url: url,
autoParam: ["id=key", "name=n", "level=lv"],
type: 'get'
},
edit: { edit: {
enable: true, enable: true,
showRemoveBtn: false, showRemoveBtn: false,
...@@ -112,12 +159,7 @@ function initTree() { ...@@ -112,12 +159,7 @@ function initTree() {
onCheck: onCheck onCheck: onCheck
} }
}; };
if (systemUserId) {
url = treeUrl + '&system_user=' + systemUserId
}
else{
url = treeUrl
}
$.get(url, function(data, status){ $.get(url, function(data, status){
$.fn.zTree.init($("#assetTree"), setting, data); $.fn.zTree.init($("#assetTree"), setting, data);
...@@ -183,6 +225,7 @@ function initResultTerminal() { ...@@ -183,6 +225,7 @@ function initResultTerminal() {
screenKeys: false, screenKeys: false,
fontFamily: 'monaco, Consolas, "Lucida Console", monospace', fontFamily: 'monaco, Consolas, "Lucida Console", monospace',
fontSize: 14, fontSize: 14,
lineHeight: 1,
rightClickSelectsWord: true, rightClickSelectsWord: true,
disableStdin: true, disableStdin: true,
theme: { theme: {
...@@ -190,7 +233,9 @@ function initResultTerminal() { ...@@ -190,7 +233,9 @@ function initResultTerminal() {
} }
}); });
term.open(document.getElementById('term')); term.open(document.getElementById('term'));
term.write("{% trans 'Select the left asset, select the running system user, execute command in batch' %}" + "\r\n") var msg = "{% trans 'Select the left asset, select the running system user, execute command in batch' %}" + "\r\n";
fit(term);
term.write(msg)
} }
function wrapperError(msg) { function wrapperError(msg) {
...@@ -201,7 +246,8 @@ function execute() { ...@@ -201,7 +246,8 @@ function execute() {
if (!term) { if (!term) {
initResultTerminal() initResultTerminal()
} }
var url = '{% url "api-ops:command-execution-list" %}'; var size = 'rows=' + term.rows + '&cols=' + term.cols;
var url = '{% url "api-ops:command-execution-list" %}?' + size;
var run_as = systemUserId; var run_as = systemUserId;
var command = editor.getValue(); var command = editor.getValue();
var hosts = getSelectedAssetsNode().map(function (node) { var hosts = getSelectedAssetsNode().map(function (node) {
......
...@@ -65,9 +65,9 @@ class CommandExecutionStartView(PermissionsMixin, TemplateView): ...@@ -65,9 +65,9 @@ class CommandExecutionStartView(PermissionsMixin, TemplateView):
return super().get_permissions() return super().get_permissions()
def get_user_system_users(self): def get_user_system_users(self):
from perms.utils import AssetPermissionUtil from perms.utils import AssetPermissionUtilV2
user = self.request.user user = self.request.user
util = AssetPermissionUtil(user) util = AssetPermissionUtilV2(user)
system_users = [s for s in util.get_system_users() if s.protocol == 'ssh'] system_users = [s for s in util.get_system_users() if s.protocol == 'ssh']
return system_users return system_users
......
...@@ -14,7 +14,7 @@ from assets.models import Asset, Domain, AdminUser, SystemUser, Label ...@@ -14,7 +14,7 @@ from assets.models import Asset, Domain, AdminUser, SystemUser, Label
from perms.models import AssetPermission from perms.models import AssetPermission
from orgs.utils import current_org from orgs.utils import current_org
from common.utils import get_logger from common.utils import get_logger
from .mixins import OrgMembershipModelViewSetMixin from .mixins.api import OrgMembershipModelViewSetMixin
logger = get_logger(__file__) logger = get_logger(__file__)
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from .models import * # from .models import *
from .serializers import * # from .serializers import *
from .forms import * # from .forms import *
from .api import * # from .api import *
...@@ -34,6 +34,9 @@ class OrgBulkModelViewSet(IDInCacheFilterMixin, BulkModelViewSet): ...@@ -34,6 +34,9 @@ class OrgBulkModelViewSet(IDInCacheFilterMixin, BulkModelViewSet):
queryset = self.serializer_class.setup_eager_loading(queryset) queryset = self.serializer_class.setup_eager_loading(queryset)
return queryset return queryset
def allow_bulk_destroy(self, qs, filtered):
return False
class OrgMembershipModelViewSetMixin: class OrgMembershipModelViewSetMixin:
org = None org = None
......
...@@ -8,7 +8,7 @@ from perms.models import AssetPermission ...@@ -8,7 +8,7 @@ from perms.models import AssetPermission
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from .utils import set_current_org, get_current_org from .utils import set_current_org, get_current_org
from .models import Organization from .models import Organization
from .mixins import OrgMembershipSerializerMixin from .mixins.serializers import OrgMembershipSerializerMixin
class OrgSerializer(ModelSerializer): class OrgSerializer(ModelSerializer):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.urls import path from django.urls import re_path
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from common import api as capi
from .. import api from .. import api
...@@ -10,20 +12,18 @@ app_name = 'orgs' ...@@ -10,20 +12,18 @@ app_name = 'orgs'
router = DefaultRouter() router = DefaultRouter()
# 将会删除 # 将会删除
router.register(r'org/(?P<org_id>[0-9a-zA-Z\-]{36})/membership/admins',
api.OrgMembershipAdminsViewSet, 'membership-admins')
router.register(r'org/(?P<org_id>[0-9a-zA-Z\-]{36})/membership/users',
api.OrgMembershipUsersViewSet, 'membership-users'),
# 替换为这个
router.register(r'orgs/(?P<org_id>[0-9a-zA-Z\-]{36})/membership/admins', router.register(r'orgs/(?P<org_id>[0-9a-zA-Z\-]{36})/membership/admins',
api.OrgMembershipAdminsViewSet, 'membership-admins-2') api.OrgMembershipAdminsViewSet, 'membership-admins')
router.register(r'orgs/(?P<org_id>[0-9a-zA-Z\-]{36})/membership/users', router.register(r'orgs/(?P<org_id>[0-9a-zA-Z\-]{36})/membership/users',
api.OrgMembershipUsersViewSet, 'membership-users-2'), api.OrgMembershipUsersViewSet, 'membership-users'),
router.register(r'orgs', api.OrgViewSet, 'org') router.register(r'orgs', api.OrgViewSet, 'org')
old_version_urlpatterns = [
re_path('(?P<resource>org)/.*', capi.redirect_plural_name_api)
]
urlpatterns = [ urlpatterns = [
] ]
urlpatterns += router.urls urlpatterns += router.urls + old_version_urlpatterns
...@@ -7,7 +7,6 @@ from rest_framework.views import Response ...@@ -7,7 +7,6 @@ from rest_framework.views import Response
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView from rest_framework.generics import RetrieveUpdateAPIView, ListAPIView
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.pagination import LimitOffsetPagination
from common.permissions import IsOrgAdmin from common.permissions import IsOrgAdmin
from common.utils import get_object_or_none from common.utils import get_object_or_none
...@@ -31,7 +30,6 @@ class AssetPermissionViewSet(viewsets.ModelViewSet): ...@@ -31,7 +30,6 @@ class AssetPermissionViewSet(viewsets.ModelViewSet):
""" """
queryset = AssetPermission.objects.all() queryset = AssetPermission.objects.all()
serializer_class = serializers.AssetPermissionCreateUpdateSerializer serializer_class = serializers.AssetPermissionCreateUpdateSerializer
pagination_class = LimitOffsetPagination
filter_fields = ['name'] filter_fields = ['name']
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
...@@ -247,7 +245,6 @@ class AssetPermissionAddAssetApi(RetrieveUpdateAPIView): ...@@ -247,7 +245,6 @@ class AssetPermissionAddAssetApi(RetrieveUpdateAPIView):
class AssetPermissionAssetsApi(ListAPIView): class AssetPermissionAssetsApi(ListAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
pagination_class = LimitOffsetPagination
serializer_class = serializers.AssetPermissionAssetsSerializer serializer_class = serializers.AssetPermissionAssetsSerializer
filter_fields = ("hostname", "ip") filter_fields = ("hostname", "ip")
search_fields = filter_fields search_fields = filter_fields
......
...@@ -12,9 +12,6 @@ from django.views.decorators.http import condition ...@@ -12,9 +12,6 @@ from django.views.decorators.http import condition
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from common.utils import get_logger from common.utils import get_logger
from assets.utils import LabelFilterMixin from assets.utils import LabelFilterMixin
from ..utils import (
AssetPermissionUtil
)
from .. import const from .. import const
from ..hands import Asset, Node, SystemUser from ..hands import Asset, Node, SystemUser
from .. import serializers from .. import serializers
...@@ -24,119 +21,120 @@ logger = get_logger(__name__) ...@@ -24,119 +21,120 @@ logger = get_logger(__name__)
__all__ = ['UserPermissionCacheMixin', 'GrantAssetsMixin', 'NodesWithUngroupMixin'] __all__ = ['UserPermissionCacheMixin', 'GrantAssetsMixin', 'NodesWithUngroupMixin']
def get_etag(request, *args, **kwargs): # def get_etag(request, *args, **kwargs):
cache_policy = request.GET.get("cache_policy") # cache_policy = request.GET.get("cache_policy")
if cache_policy != '1': # if cache_policy != '1':
return None # return None
if not UserPermissionCacheMixin.CACHE_ENABLE: # if not UserPermissionCacheMixin.CACHE_ENABLE:
return None # return None
view = request.parser_context.get("view") # view = request.parser_context.get("view")
if not view: # if not view:
return None # return None
etag = view.get_meta_cache_id() # etag = view.get_meta_cache_id()
return etag # return etag
class UserPermissionCacheMixin: class UserPermissionCacheMixin:
cache_policy = '0' pass
RESP_CACHE_KEY = '_PERMISSION_RESPONSE_CACHE_V2_{}' # cache_policy = '0'
CACHE_ENABLE = settings.ASSETS_PERM_CACHE_ENABLE # RESP_CACHE_KEY = '_PERMISSION_RESPONSE_CACHE_V2_{}'
CACHE_TIME = settings.ASSETS_PERM_CACHE_TIME # CACHE_ENABLE = settings.ASSETS_PERM_CACHE_ENABLE
_object = None # CACHE_TIME = settings.ASSETS_PERM_CACHE_TIME
# _object = None
def get_object(self): #
return None # def get_object(self):
# return None
# 内部使用可控制缓存 #
def _get_object(self): # # 内部使用可控制缓存
if not self._object: # def _get_object(self):
self._object = self.get_object() # if not self._object:
return self._object # self._object = self.get_object()
# return self._object
def get_object_id(self): #
obj = self._get_object() # def get_object_id(self):
if obj: # obj = self._get_object()
return str(obj.id) # if obj:
return None # return str(obj.id)
# return None
def get_request_md5(self): #
path = self.request.path # def get_request_md5(self):
query = {k: v for k, v in self.request.GET.items()} # path = self.request.path
query.pop("_", None) # query = {k: v for k, v in self.request.GET.items()}
query = "&".join(["{}={}".format(k, v) for k, v in query.items()]) # query.pop("_", None)
full_path = "{}?{}".format(path, query) # query = "&".join(["{}={}".format(k, v) for k, v in query.items()])
return md5(full_path.encode()).hexdigest() # full_path = "{}?{}".format(path, query)
# return md5(full_path.encode()).hexdigest()
def get_meta_cache_id(self): #
obj = self._get_object() # def get_meta_cache_id(self):
util = AssetPermissionUtil(obj, cache_policy=self.cache_policy) # obj = self._get_object()
meta_cache_id = util.cache_meta.get('id') # util = AssetPermissionUtil(obj, cache_policy=self.cache_policy)
return meta_cache_id # meta_cache_id = util.cache_meta.get('id')
# return meta_cache_id
def get_response_cache_id(self): #
obj_id = self.get_object_id() # def get_response_cache_id(self):
request_md5 = self.get_request_md5() # obj_id = self.get_object_id()
meta_cache_id = self.get_meta_cache_id() # request_md5 = self.get_request_md5()
resp_cache_id = '{}_{}_{}'.format(obj_id, request_md5, meta_cache_id) # meta_cache_id = self.get_meta_cache_id()
return resp_cache_id # resp_cache_id = '{}_{}_{}'.format(obj_id, request_md5, meta_cache_id)
# return resp_cache_id
def get_response_from_cache(self): #
# 没有数据缓冲 # def get_response_from_cache(self):
meta_cache_id = self.get_meta_cache_id() # # 没有数据缓冲
if not meta_cache_id: # meta_cache_id = self.get_meta_cache_id()
logger.debug("Not get meta id: {}".format(meta_cache_id)) # if not meta_cache_id:
return None # logger.debug("Not get meta id: {}".format(meta_cache_id))
# 从响应缓冲里获取响应 # return None
key = self.get_response_key() # # 从响应缓冲里获取响应
data = cache.get(key) # key = self.get_response_key()
if not data: # data = cache.get(key)
logger.debug("Not get response from cache: {}".format(key)) # if not data:
return None # logger.debug("Not get response from cache: {}".format(key))
logger.debug("Get user permission from cache: {}".format(self.get_object())) # return None
response = Response(data) # logger.debug("Get user permission from cache: {}".format(self.get_object()))
return response # response = Response(data)
# return response
def expire_response_cache(self): #
obj_id = self.get_object_id() # def expire_response_cache(self):
expire_cache_id = '{}_{}'.format(obj_id, '*') # obj_id = self.get_object_id()
key = self.RESP_CACHE_KEY.format(expire_cache_id) # expire_cache_id = '{}_{}'.format(obj_id, '*')
cache.delete_pattern(key) # key = self.RESP_CACHE_KEY.format(expire_cache_id)
# cache.delete_pattern(key)
def get_response_key(self): #
resp_cache_id = self.get_response_cache_id() # def get_response_key(self):
key = self.RESP_CACHE_KEY.format(resp_cache_id) # resp_cache_id = self.get_response_cache_id()
return key # key = self.RESP_CACHE_KEY.format(resp_cache_id)
# return key
def set_response_to_cache(self, response): #
key = self.get_response_key() # def set_response_to_cache(self, response):
cache.set(key, response.data, self.CACHE_TIME) # key = self.get_response_key()
logger.debug("Set response to cache: {}".format(key)) # cache.set(key, response.data, self.CACHE_TIME)
# logger.debug("Set response to cache: {}".format(key))
@method_decorator(condition(etag_func=get_etag)) #
def get(self, request, *args, **kwargs): # @method_decorator(condition(etag_func=get_etag))
if not self.CACHE_ENABLE: # def get(self, request, *args, **kwargs):
self.cache_policy = '0' # if not self.CACHE_ENABLE:
else: # self.cache_policy = '0'
self.cache_policy = request.GET.get('cache_policy', '0') # else:
# self.cache_policy = request.GET.get('cache_policy', '0')
obj = self._get_object() #
if obj is None: # obj = self._get_object()
logger.debug("Not get response from cache: obj is none") # if obj is None:
return super().get(request, *args, **kwargs) # logger.debug("Not get response from cache: obj is none")
# return super().get(request, *args, **kwargs)
if AssetPermissionUtil.is_not_using_cache(self.cache_policy): #
logger.debug("Not get resp from cache: {}".format(self.cache_policy)) # if AssetPermissionUtil.is_not_using_cache(self.cache_policy):
return super().get(request, *args, **kwargs) # logger.debug("Not get resp from cache: {}".format(self.cache_policy))
elif AssetPermissionUtil.is_refresh_cache(self.cache_policy): # return super().get(request, *args, **kwargs)
logger.debug("Not get resp from cache: {}".format(self.cache_policy)) # elif AssetPermissionUtil.is_refresh_cache(self.cache_policy):
self.expire_response_cache() # logger.debug("Not get resp from cache: {}".format(self.cache_policy))
# self.expire_response_cache()
logger.debug("Try get response from cache") #
resp = self.get_response_from_cache() # logger.debug("Try get response from cache")
if not resp: # resp = self.get_response_from_cache()
resp = super().get(request, *args, **kwargs) # if not resp:
self.set_response_to_cache(resp) # resp = super().get(request, *args, **kwargs)
return resp # self.set_response_to_cache(resp)
# return resp
class NodesWithUngroupMixin: class NodesWithUngroupMixin:
...@@ -202,9 +200,11 @@ class GrantAssetsMixin(LabelFilterMixin): ...@@ -202,9 +200,11 @@ class GrantAssetsMixin(LabelFilterMixin):
data.append(asset) data.append(asset)
return data return data
def get_serializer(self, queryset_list, many=True): def get_serializer(self, assets_items=None, many=True):
data = self.get_serializer_queryset(queryset_list) if assets_items is None:
return super().get_serializer(data, many=True) assets_items = []
assets_items = self.get_serializer_queryset(assets_items)
return super().get_serializer(assets_items, many=many)
def filter_queryset_by_id(self, assets_items): def filter_queryset_by_id(self, assets_items):
i = self.request.query_params.get("id") i = self.request.query_params.get("id")
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment