Commit 4c443066 authored by BaiJiangJie's avatar BaiJiangJie

[Merge] Merge branch dev of github.com:jumpserver/jumpserver into github_dev

parents ee35ca36 482d1bb2
...@@ -11,8 +11,7 @@ from django.db.models import Q ...@@ -11,8 +11,7 @@ from django.db.models import Q
from common.mixins import IDInFilterMixin from common.mixins import IDInFilterMixin
from common.utils import get_logger from common.utils import get_logger
from ..hands import IsSuperUser, IsValidUser, IsSuperUserOrAppUser, \ from ..hands import IsSuperUser, IsValidUser, IsSuperUserOrAppUser
NodePermissionUtil
from ..models import Asset, SystemUser, AdminUser, Node from ..models import Asset, SystemUser, AdminUser, Node
from .. import serializers from .. import serializers
from ..tasks import update_asset_hardware_info_manual, \ from ..tasks import update_asset_hardware_info_manual, \
...@@ -22,7 +21,7 @@ from ..utils import LabelFilter ...@@ -22,7 +21,7 @@ from ..utils import LabelFilter
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
'AssetViewSet', 'UserAssetListView', 'AssetListUpdateApi', 'AssetViewSet', 'AssetListUpdateApi',
'AssetRefreshHardwareApi', 'AssetAdminUserTestApi' 'AssetRefreshHardwareApi', 'AssetAdminUserTestApi'
] ]
...@@ -71,19 +70,6 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet): ...@@ -71,19 +70,6 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet):
return queryset return queryset
class UserAssetListView(generics.ListAPIView):
queryset = Asset.objects.all()
serializer_class = serializers.AssetSerializer
permission_classes = (IsValidUser,)
def get_queryset(self):
assets_granted = NodePermissionUtil.get_user_assets(self.request.user).keys()
queryset = self.queryset.filter(
id__in=[asset.id for asset in assets_granted]
)
return queryset
class AssetListUpdateApi(IDInFilterMixin, ListBulkCreateUpdateDestroyAPIView): class AssetListUpdateApi(IDInFilterMixin, ListBulkCreateUpdateDestroyAPIView):
""" """
Asset bulk update api Asset bulk update api
......
...@@ -31,7 +31,7 @@ from .. import serializers ...@@ -31,7 +31,7 @@ from .. import serializers
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
'NodeViewSet', 'NodeChildrenApi', 'NodeViewSet', 'NodeChildrenApi',
'NodeAssetsApi', 'NodeWithAssetsApi', 'NodeAssetsApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'NodeAddAssetsApi', 'NodeRemoveAssetsApi',
'NodeReplaceAssetsApi', 'NodeReplaceAssetsApi',
'NodeAddChildrenApi', 'RefreshNodeHardwareInfoApi', 'NodeAddChildrenApi', 'RefreshNodeHardwareInfoApi',
...@@ -42,14 +42,7 @@ __all__ = [ ...@@ -42,14 +42,7 @@ __all__ = [
class NodeViewSet(BulkModelViewSet): class NodeViewSet(BulkModelViewSet):
queryset = Node.objects.all() queryset = Node.objects.all()
permission_classes = (IsSuperUser,) permission_classes = (IsSuperUser,)
# serializer_class = serializers.NodeSerializer serializer_class = serializers.NodeSerializer
def get_serializer_class(self):
show_current_asset = self.request.query_params.get('show_current_asset')
if show_current_asset:
return serializers.NodeCurrentSerializer
else:
return serializers.NodeSerializer
def perform_create(self, serializer): def perform_create(self, serializer):
child_key = Node.root().get_next_child_key() child_key = Node.root().get_next_child_key()
...@@ -57,32 +50,32 @@ class NodeViewSet(BulkModelViewSet): ...@@ -57,32 +50,32 @@ class NodeViewSet(BulkModelViewSet):
serializer.save() serializer.save()
class NodeWithAssetsApi(generics.ListAPIView): # class NodeWithAssetsApi(generics.ListAPIView):
permission_classes = (IsSuperUser,) # permission_classes = (IsSuperUser,)
serializers = serializers.NodeSerializer # serializers = serializers.NodeSerializer
#
def get_node(self): # def get_node(self):
pk = self.kwargs.get('pk') or self.request.query_params.get('node') # pk = self.kwargs.get('pk') or self.request.query_params.get('node')
if not pk: # if not pk:
node = Node.root() # node = Node.root()
else: # else:
node = get_object_or_404(Node, pk) # node = get_object_or_404(Node, pk)
return node # return node
#
def get_queryset(self): # def get_queryset(self):
queryset = [] # queryset = []
node = self.get_node() # node = self.get_node()
children = node.get_children() # children = node.get_children()
assets = node.get_assets() # assets = node.get_assets()
queryset.extend(list(children)) # queryset.extend(list(children))
#
for asset in assets: # for asset in assets:
node = Node() # node = Node()
node.id = asset.id # node.id = asset.id
node.parent = node.id # node.parent = node.id
node.value = asset.hostname # node.value = asset.hostname
queryset.append(node) # queryset.append(node)
return queryset # return queryset
class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView): class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView):
...@@ -147,9 +140,9 @@ class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView): ...@@ -147,9 +140,9 @@ class NodeChildrenApi(mixins.ListModelMixin, generics.CreateAPIView):
for asset in assets: for asset in assets:
node_fake = Node() node_fake = Node()
node_fake.id = asset.id node_fake.id = asset.id
node_fake.parent = node
node_fake.value = asset.hostname
node_fake.is_node = False node_fake.is_node = False
node_fake.parent_id = node.id
node_fake.value = asset.hostname
queryset.append(node_fake) queryset.append(node_fake)
queryset = sorted(queryset, key=lambda x: x.is_node, reverse=True) queryset = sorted(queryset, key=lambda x: x.is_node, reverse=True)
return queryset return queryset
...@@ -185,7 +178,7 @@ class NodeAddChildrenApi(generics.UpdateAPIView): ...@@ -185,7 +178,7 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
for node in children: for node in children:
if not node: if not node:
continue continue
node.set_parent(instance) node.parent = instance
return Response("OK") return Response("OK")
......
...@@ -14,4 +14,3 @@ ...@@ -14,4 +14,3 @@
from common.mixins import AdminUserRequiredMixin from common.mixins import AdminUserRequiredMixin
from common.permissions import IsAppUser, IsSuperUser, IsValidUser, IsSuperUserOrAppUser from common.permissions import IsAppUser, IsSuperUser, IsValidUser, IsSuperUserOrAppUser
from users.models import User, UserGroup from users.models import User, UserGroup
from perms.utils import NodePermissionUtil
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import uuid import uuid
import logging import logging
import random import random
from functools import reduce
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
...@@ -149,22 +150,15 @@ class Asset(models.Model): ...@@ -149,22 +150,15 @@ class Asset(models.Model):
nodes = self.nodes.all() or [Node.root()] nodes = self.nodes.all() or [Node.root()]
return nodes return nodes
@property def get_all_nodes(self, flat=False):
def nodes_cache_key(self): nodes = []
key = "NODES_OF_{}".format(str(self.id)) for node in self.get_nodes():
return key _nodes = node.get_ancestor(with_self=True)
_nodes.append(_nodes)
def get_nodes_or_cache(self): if flat:
cached = cache.get(self.nodes_cache_key) nodes = list(reduce(lambda x, y: set(x) | set(y), nodes))
if cached is not None:
return cached
nodes = list(self.get_nodes())
cache.set(self.nodes_cache_key, nodes, 3600)
return nodes return nodes
def expire_nodes_cache(self):
cache.delete(self.nodes_cache_key)
@property @property
def hardware_info(self): def hardware_info(self):
if self.cpu_count: if self.cpu_count:
......
...@@ -5,7 +5,7 @@ import uuid ...@@ -5,7 +5,7 @@ import uuid
from django.db import models, transaction from django.db import models, transaction
from django.db.models import Q from django.db.models import Q
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from common.utils import with_cache
__all__ = ['Node'] __all__ = ['Node']
...@@ -22,32 +22,36 @@ class Node(models.Model): ...@@ -22,32 +22,36 @@ class Node(models.Model):
def __str__(self): def __str__(self):
return self.full_value return self.full_value
def __eq__(self, other):
return self.key == other.key
def __gt__(self, other):
if self.is_root():
return True
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]
if len(self_key) < len(other_key):
return True
elif len(self_key) > len(other_key):
return False
else:
return self_key[-1] < other_key[-1]
@property @property
def name(self): def name(self):
return self.value return self.value
@property @property
def full_value(self): def full_value(self):
ancestor = [a.value for a in self.ancestor] ancestor = [a.value for a in self.get_ancestor(with_self=True)]
if self.is_root(): if self.is_root():
return self.value return self.value
ancestor.append(self.value)
return ' / '.join(ancestor) return ' / '.join(ancestor)
@property @property
def level(self): def level(self):
return len(self.key.split(':')) return len(self.key.split(':'))
def set_parent(self, instance):
children = self.get_all_children()
old_key = self.key
with transaction.atomic():
self.parent = instance
for child in children:
child.key = child.key.replace(old_key, self.key, 1)
child.save()
self.save()
def get_next_child_key(self): def get_next_child_key(self):
mark = self.child_mark mark = self.child_mark
self.child_mark += 1 self.child_mark += 1
...@@ -55,32 +59,35 @@ class Node(models.Model): ...@@ -55,32 +59,35 @@ class Node(models.Model):
return "{}:{}".format(self.key, mark) return "{}:{}".format(self.key, mark)
def create_child(self, value): def create_child(self, value):
child_key = self.get_next_child_key() with transaction.atomic():
child = self.__class__.objects.create(key=child_key, value=value) child_key = self.get_next_child_key()
return child child = self.__class__.objects.create(key=child_key, value=value)
return child
def get_children(self):
return self.__class__.objects.filter(
key__regex=r'^{}:[0-9]+$'.format(self.key)
)
def get_children_with_self(self): def get_children(self, with_self=False):
pattern = r'^{0}$|^{}:[0-9]+$' if with_self else r'^{}:[0-9]+$'
return self.__class__.objects.filter( return self.__class__.objects.filter(
key__regex=r'^{0}$|^{0}:[0-9]+$'.format(self.key) key__regex=pattern.format(self.key)
) )
def get_all_children(self): def get_all_children(self, with_self=False):
pattern = r'^{0}$|^{0}:' if with_self else r'^{0}'
return self.__class__.objects.filter( return self.__class__.objects.filter(
key__startswith='{}:'.format(self.key) key__regex=pattern.format(self.key)
) )
def get_all_children_with_self(self): def get_sibling(self, with_self=False):
return self.__class__.objects.filter( key = ':'.join(self.key.split(':')[:-1])
key__regex=r'^{0}$|^{0}:'.format(self.key) pattern = r'^{}:[0-9]+$'.format(key)
sibling = self.__class__.objects.filter(
key__regex=pattern.format(self.key)
) )
if not with_self:
sibling = sibling.exclude(key=self.key)
return sibling
def get_family(self): def get_family(self):
ancestor = self.ancestor ancestor = self.get_ancestor()
children = self.get_all_children() children = self.get_all_children()
return [*tuple(ancestor), self, *tuple(children)] return [*tuple(ancestor), self, *tuple(children)]
...@@ -91,7 +98,7 @@ class Node(models.Model): ...@@ -91,7 +98,7 @@ class Node(models.Model):
Q(nodes__id=self.id) | Q(nodes__isnull=True) Q(nodes__id=self.id) | Q(nodes__isnull=True)
) )
else: else:
assets = Asset.objects.filter(nodes__id=self.id) assets = self.assets.all()
return assets return assets
def get_valid_assets(self): def get_valid_assets(self):
...@@ -102,8 +109,8 @@ class Node(models.Model): ...@@ -102,8 +109,8 @@ class Node(models.Model):
if self.is_root(): if self.is_root():
assets = Asset.objects.all() assets = Asset.objects.all()
else: else:
nodes = self.get_all_children_with_self() pattern = r'^{0}$|^{0}:'.format(self.key)
assets = Asset.objects.filter(nodes__in=nodes).distinct() assets = Asset.objects.filter(nodes__key__regex=pattern)
return assets return assets
def get_all_valid_assets(self): def get_all_valid_assets(self):
...@@ -125,26 +132,33 @@ class Node(models.Model): ...@@ -125,26 +132,33 @@ class Node(models.Model):
@parent.setter @parent.setter
def parent(self, parent): def parent(self, parent):
self.key = parent.get_next_child_key() if self.is_node:
children = self.get_all_children()
old_key = self.key
with transaction.atomic():
self.key = parent.get_next_child_key()
for child in children:
child.key = child.key.replace(old_key, self.key, 1)
child.save()
self.save()
else:
self.key = parent.key+':fake'
@property def get_ancestor(self, with_self=False):
def ancestor(self):
if self.is_root(): if self.is_root():
ancestor = self.__class__.objects.filter(key='0') ancestor = self.__class__.objects.filter(key='0')
else: return ancestor
_key = self.key.split(':')
ancestor_keys = [] _key = self.key.split(':')
for i in range(len(_key)-1): if not with_self:
_key.pop() _key.pop()
ancestor_keys.append(':'.join(_key)) ancestor_keys = []
ancestor = self.__class__.objects.filter(key__in=ancestor_keys) for i in range(len(_key)):
ancestor = list(ancestor) ancestor_keys.append(':'.join(_key))
return ancestor _key.pop()
ancestor = self.__class__.objects.filter(
@property key__in=ancestor_keys
def ancestor_with_self(self): ).order_by('key')
ancestor = list(self.ancestor)
ancestor.insert(0, self)
return ancestor return ancestor
@classmethod @classmethod
...@@ -152,4 +166,6 @@ class Node(models.Model): ...@@ -152,4 +166,6 @@ class Node(models.Model):
obj, created = cls.objects.get_or_create( obj, created = cls.objects.get_or_create(
key='0', defaults={"key": '0', 'value': "ROOT"} key='0', defaults={"key": '0', 'value': "ROOT"}
) )
print(obj)
return obj return obj
...@@ -16,8 +16,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -16,8 +16,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
""" """
资产的数据结构 资产的数据结构
""" """
nodes = serializers.SerializerMethodField()
class Meta: class Meta:
model = Asset model = Asset
list_serializer_class = BulkListSerializer list_serializer_class = BulkListSerializer
...@@ -31,10 +29,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -31,10 +29,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
]) ])
return fields return fields
@staticmethod
def get_nodes(obj):
return [n.id for n in obj.get_nodes_or_cache()]
class AssetGrantedSerializer(serializers.ModelSerializer): class AssetGrantedSerializer(serializers.ModelSerializer):
""" """
......
...@@ -9,7 +9,7 @@ from .asset import AssetGrantedSerializer ...@@ -9,7 +9,7 @@ from .asset import AssetGrantedSerializer
__all__ = [ __all__ = [
'NodeSerializer', "NodeGrantedSerializer", "NodeAddChildrenSerializer", 'NodeSerializer', "NodeGrantedSerializer", "NodeAddChildrenSerializer",
"NodeAssetsSerializer", "NodeCurrentSerializer", "NodeAssetsSerializer",
] ]
...@@ -64,11 +64,11 @@ class NodeSerializer(serializers.ModelSerializer): ...@@ -64,11 +64,11 @@ class NodeSerializer(serializers.ModelSerializer):
@staticmethod @staticmethod
def get_parent(obj): def get_parent(obj):
return obj.parent.id return obj.parent.id if obj.is_node else obj.parent_id
@staticmethod @staticmethod
def get_assets_amount(obj): def get_assets_amount(obj):
return obj.get_all_assets().count() return obj.get_all_assets().count() if obj.is_node else 0
def get_fields(self): def get_fields(self):
fields = super().get_fields() fields = super().get_fields()
...@@ -77,12 +77,6 @@ class NodeSerializer(serializers.ModelSerializer): ...@@ -77,12 +77,6 @@ class NodeSerializer(serializers.ModelSerializer):
return fields return fields
class NodeCurrentSerializer(NodeSerializer):
@staticmethod
def get_assets_amount(obj):
return obj.get_assets().count()
class NodeAssetsSerializer(serializers.ModelSerializer): class NodeAssetsSerializer(serializers.ModelSerializer):
assets = serializers.PrimaryKeyRelatedField(many=True, queryset=Asset.objects.all()) assets = serializers.PrimaryKeyRelatedField(many=True, queryset=Asset.objects.all())
......
...@@ -64,7 +64,6 @@ def on_system_user_assets_change(sender, instance=None, **kwargs): ...@@ -64,7 +64,6 @@ def on_system_user_assets_change(sender, instance=None, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_node_changed(sender, instance=None, **kwargs): def on_asset_node_changed(sender, instance=None, **kwargs):
if isinstance(instance, Asset): if isinstance(instance, Asset):
instance.expire_nodes_cache()
if kwargs['action'] == 'post_add': if kwargs['action'] == 'post_add':
logger.debug("Asset node change signal received") logger.debug("Asset node change signal received")
nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
...@@ -81,10 +80,6 @@ def on_asset_node_changed(sender, instance=None, **kwargs): ...@@ -81,10 +80,6 @@ def on_asset_node_changed(sender, instance=None, **kwargs):
def on_node_assets_changed(sender, instance=None, **kwargs): def on_node_assets_changed(sender, instance=None, **kwargs):
if isinstance(instance, Node): if isinstance(instance, Node):
assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
# 清理资产节点缓存
for asset in assets:
asset.expire_nodes_cache()
if kwargs['action'] == 'post_add': if kwargs['action'] == 'post_add':
logger.debug("Node assets change signal received") logger.debug("Node assets change signal received")
# 重新关联系统用户和资产的关系 # 重新关联系统用户和资产的关系
......
...@@ -95,7 +95,7 @@ function initTree2() { ...@@ -95,7 +95,7 @@ function initTree2() {
}; };
var zNodes = []; var zNodes = [];
$.get("{% url 'api-assets:node-list' %}?show_current_asset=1", function(data, status){ $.get("{% url 'api-assets:node-list' %}", function(data, status){
$.each(data, function (index, value) { $.each(data, function (index, value) {
value["pId"] = value["parent"]; value["pId"] = value["parent"];
{#value["open"] = true;#} {#value["open"] = true;#}
......
...@@ -399,8 +399,7 @@ function initTree() { ...@@ -399,8 +399,7 @@ function initTree() {
}; };
var zNodes = []; var zNodes = [];
var query_params = {'show_current_asset': getCookie('show_current_asset')}; $.get("{% url 'api-assets:node-list' %}", function(data, status){
$.get("{% url 'api-assets:node-list' %}", query_params, function(data, status){
$.each(data, function (index, value) { $.each(data, function (index, value) {
value["pId"] = value["parent"]; value["pId"] = value["parent"];
if (value["key"] === "0") { if (value["key"] === "0") {
...@@ -436,7 +435,7 @@ $(document).ready(function(){ ...@@ -436,7 +435,7 @@ $(document).ready(function(){
initTable(); initTable();
initTree(); initTree();
if(getCookie('show_current_asset') === 'yes'){ if(getCookie('show_current_asset') === '1'){
$('#show_all_asset').css('display', 'inline-block'); $('#show_all_asset').css('display', 'inline-block');
} }
else{ else{
...@@ -564,7 +563,7 @@ $(document).ready(function(){ ...@@ -564,7 +563,7 @@ $(document).ready(function(){
hideRMenu(); hideRMenu();
$(this).css('display', 'none'); $(this).css('display', 'none');
$('#show_all_asset').css('display', 'inline-block'); $('#show_all_asset').css('display', 'inline-block');
setCookie('show_current_asset', 'yes'); setCookie('show_current_asset', '1');
location.reload(); location.reload();
}) })
.on('click', '.btn-show-all-asset', function(){ .on('click', '.btn-show-all-asset', function(){
......
...@@ -23,8 +23,6 @@ urlpatterns = [ ...@@ -23,8 +23,6 @@ urlpatterns = [
api.AssetRefreshHardwareApi.as_view(), name='asset-refresh'), api.AssetRefreshHardwareApi.as_view(), name='asset-refresh'),
url(r'^v1/assets/(?P<pk>[0-9a-zA-Z\-]{36})/alive/$', url(r'^v1/assets/(?P<pk>[0-9a-zA-Z\-]{36})/alive/$',
api.AssetAdminUserTestApi.as_view(), name='asset-alive-test'), api.AssetAdminUserTestApi.as_view(), name='asset-alive-test'),
url(r'^v1/assets/user-assets/$',
api.UserAssetListView.as_view(), name='user-asset-list'),
url(r'^v1/admin-user/(?P<pk>[0-9a-zA-Z\-]{36})/nodes/$', url(r'^v1/admin-user/(?P<pk>[0-9a-zA-Z\-]{36})/nodes/$',
api.ReplaceNodesAdminUserApi.as_view(), name='replace-nodes-admin-user'), api.ReplaceNodesAdminUserApi.as_view(), name='replace-nodes-admin-user'),
url(r'^v1/admin-user/(?P<pk>[0-9a-zA-Z\-]{36})/auth/$', url(r'^v1/admin-user/(?P<pk>[0-9a-zA-Z\-]{36})/auth/$',
...@@ -35,17 +33,26 @@ urlpatterns = [ ...@@ -35,17 +33,26 @@ urlpatterns = [
api.SystemUserPushApi.as_view(), name='system-user-push'), api.SystemUserPushApi.as_view(), name='system-user-push'),
url(r'^v1/system-user/(?P<pk>[0-9a-zA-Z\-]{36})/connective/$', url(r'^v1/system-user/(?P<pk>[0-9a-zA-Z\-]{36})/connective/$',
api.SystemUserTestConnectiveApi.as_view(), name='system-user-connective'), api.SystemUserTestConnectiveApi.as_view(), name='system-user-connective'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/children/$', api.NodeChildrenApi.as_view(), name='node-children'), url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/children/$',
api.NodeChildrenApi.as_view(), name='node-children'),
url(r'^v1/nodes/children/$', api.NodeChildrenApi.as_view(), name='node-children-2'), url(r'^v1/nodes/children/$', api.NodeChildrenApi.as_view(), name='node-children-2'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/children/add/$', api.NodeAddChildrenApi.as_view(), name='node-add-children'), url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/children/add/$',
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/$', api.NodeAssetsApi.as_view(), name='node-assets'), api.NodeAddChildrenApi.as_view(), name='node-add-children'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/add/$', api.NodeAddAssetsApi.as_view(), name='node-add-assets'), url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/$',
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/replace/$', api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'), api.NodeAssetsApi.as_view(), name='node-assets'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/remove/$', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'), url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/add/$',
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/refresh-hardware-info/$', api.RefreshNodeHardwareInfoApi.as_view(), name='node-refresh-hardware-info'), api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/test-connective/$', api.TestNodeConnectiveApi.as_view(), name='node-test-connective'), url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/replace/$',
api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/assets/remove/$',
api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/refresh-hardware-info/$',
api.RefreshNodeHardwareInfoApi.as_view(), name='node-refresh-hardware-info'),
url(r'^v1/nodes/(?P<pk>[0-9a-zA-Z\-]{36})/test-connective/$',
api.TestNodeConnectiveApi.as_view(), name='node-test-connective'),
url(r'^v1/gateway/(?P<pk>[0-9a-zA-Z\-]{36})/test-connective/$', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), url(r'^v1/gateway/(?P<pk>[0-9a-zA-Z\-]{36})/test-connective/$',
api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),
] ]
urlpatterns += router.urls urlpatterns += router.urls
......
...@@ -21,23 +21,13 @@ class MailTestingAPI(APIView): ...@@ -21,23 +21,13 @@ class MailTestingAPI(APIView):
serializer = self.serializer_class(data=request.data) serializer = self.serializer_class(data=request.data)
if serializer.is_valid(): if serializer.is_valid():
email_host_user = serializer.validated_data["EMAIL_HOST_USER"] email_host_user = serializer.validated_data["EMAIL_HOST_USER"]
kwargs = { for k, v in serializer.validated_data.items():
"host": serializer.validated_data["EMAIL_HOST"], if k.startswith('EMAIL'):
"port": serializer.validated_data["EMAIL_PORT"], setattr(settings, k, v)
"username": serializer.validated_data["EMAIL_HOST_USER"],
"password": serializer.validated_data["EMAIL_HOST_PASSWORD"],
"use_ssl": serializer.validated_data["EMAIL_USE_SSL"],
"use_tls": serializer.validated_data["EMAIL_USE_TLS"]
}
connection = get_connection(timeout=5, **kwargs)
try: try:
connection.open() subject = "Test"
except Exception as e: message = "Test smtp setting"
return Response({"error": str(e)}, status=401) send_mail(subject, message, email_host_user, [email_host_user])
try:
send_mail("Test", "Test smtp setting", email_host_user,
[email_host_user], connection=connection)
except Exception as e: except Exception as e:
return Response({"error": str(e)}, status=401) return Response({"error": str(e)}, status=401)
......
...@@ -2,6 +2,7 @@ from django.core.mail import send_mail ...@@ -2,6 +2,7 @@ from django.core.mail import send_mail
from django.conf import settings from django.conf import settings
from celery import shared_task from celery import shared_task
from .utils import get_logger from .utils import get_logger
from .models import Setting
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -21,6 +22,10 @@ def send_mail_async(*args, **kwargs): ...@@ -21,6 +22,10 @@ def send_mail_async(*args, **kwargs):
Example: Example:
send_mail_sync.delay(subject, message, recipient_list, fail_silently=False, html_message=None) send_mail_sync.delay(subject, message, recipient_list, fail_silently=False, html_message=None)
""" """
configs = Setting.objects.filter(name__startswith='EMAIL')
for config in configs:
setattr(settings, config.name, config.cleaned_value)
if len(args) == 3: if len(args) == 3:
args = list(args) args = list(args)
args[0] = settings.EMAIL_SUBJECT_PREFIX + args[0] args[0] = settings.EMAIL_SUBJECT_PREFIX + args[0]
......
...@@ -16,6 +16,7 @@ import calendar ...@@ -16,6 +16,7 @@ import calendar
import threading import threading
from io import StringIO from io import StringIO
import uuid import uuid
from functools import wraps
import paramiko import paramiko
import sshpubkeys import sshpubkeys
...@@ -395,3 +396,17 @@ class TeeObj: ...@@ -395,3 +396,17 @@ class TeeObj:
def close(self): def close(self):
self.file_obj.close() self.file_obj.close()
def with_cache(func):
cache = {}
key = "_{}.{}".format(func.__module__, func.__name__)
@wraps(func)
def wrapper(*args, **kwargs):
cached = cache.get(key)
if cached:
return cached
res = func(*args, **kwargs)
cache[key] = res
return res
return wrapper
...@@ -41,11 +41,11 @@ class AssetPermissionViewSet(viewsets.ModelViewSet): ...@@ -41,11 +41,11 @@ class AssetPermissionViewSet(viewsets.ModelViewSet):
asset = get_object_or_404(Asset, pk=asset_id) asset = get_object_or_404(Asset, pk=asset_id)
permissions = set(queryset.filter(assets=asset)) permissions = set(queryset.filter(assets=asset))
for node in asset.nodes.all(): for node in asset.nodes.all():
inherit_nodes.update(set(node.ancestor_with_self)) inherit_nodes.update(set(node.get_ancestor(with_self=True)))
elif node_id: elif node_id:
node = get_object_or_404(Node, pk=node_id) node = get_object_or_404(Node, pk=node_id)
permissions = set(queryset.filter(nodes=node)) permissions = set(queryset.filter(nodes=node))
inherit_nodes = node.ancestor inherit_nodes = node.get_ancestor()
for n in inherit_nodes: for n in inherit_nodes:
_permissions = queryset.filter(nodes=n) _permissions = queryset.filter(nodes=n)
...@@ -70,7 +70,8 @@ class UserGrantedAssetsApi(ListAPIView): ...@@ -70,7 +70,8 @@ class UserGrantedAssetsApi(ListAPIView):
else: else:
user = self.request.user user = self.request.user
for k, v in AssetPermissionUtil.get_user_assets(user).items(): util = AssetPermissionUtil(user)
for k, v in util.get_assets().items():
if k.is_unixlike(): if k.is_unixlike():
system_users_granted = [s for s in v if s.protocol == 'ssh'] system_users_granted = [s for s in v if s.protocol == 'ssh']
else: else:
...@@ -95,7 +96,8 @@ class UserGrantedNodesApi(ListAPIView): ...@@ -95,7 +96,8 @@ class UserGrantedNodesApi(ListAPIView):
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
else: else:
user = self.request.user user = self.request.user
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) util = AssetPermissionUtil(user)
nodes = util.get_nodes_with_assets()
return nodes.keys() return nodes.keys()
def get_permissions(self): def get_permissions(self):
...@@ -116,7 +118,8 @@ class UserGrantedNodesWithAssetsApi(ListAPIView): ...@@ -116,7 +118,8 @@ class UserGrantedNodesWithAssetsApi(ListAPIView):
else: else:
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) util = AssetPermissionUtil(user)
nodes = util.get_nodes_with_assets()
for node, _assets in nodes.items(): for node, _assets in nodes.items():
assets = _assets.keys() assets = _assets.keys()
for k, v in _assets.items(): for k, v in _assets.items():
...@@ -147,8 +150,9 @@ class UserGrantedNodeAssetsApi(ListAPIView): ...@@ -147,8 +150,9 @@ class UserGrantedNodeAssetsApi(ListAPIView):
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
else: else:
user = self.request.user user = self.request.user
util = AssetPermissionUtil(user)
node = get_object_or_404(Node, id=node_id) node = get_object_or_404(Node, id=node_id)
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) nodes = util.get_nodes_with_assets()
assets = nodes.get(node, []) assets = nodes.get(node, [])
for asset, system_users in assets.items(): for asset, system_users in assets.items():
asset.system_users_granted = system_users asset.system_users_granted = system_users
...@@ -172,7 +176,8 @@ class UserGroupGrantedAssetsApi(ListAPIView): ...@@ -172,7 +176,8 @@ class UserGroupGrantedAssetsApi(ListAPIView):
return queryset return queryset
user_group = get_object_or_404(UserGroup, id=user_group_id) user_group = get_object_or_404(UserGroup, id=user_group_id)
assets = AssetPermissionUtil.get_user_group_assets(user_group) util = AssetPermissionUtil(user_group)
assets = util.get_assets()
for k, v in assets.items(): for k, v in assets.items():
k.system_users_granted = v k.system_users_granted = v
queryset.append(k) queryset.append(k)
...@@ -189,7 +194,8 @@ class UserGroupGrantedNodesApi(ListAPIView): ...@@ -189,7 +194,8 @@ class UserGroupGrantedNodesApi(ListAPIView):
if group_id: if group_id:
group = get_object_or_404(UserGroup, id=group_id) group = get_object_or_404(UserGroup, id=group_id)
nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(group) util = AssetPermissionUtil(group)
nodes = util.get_nodes_with_assets()
return nodes.keys() return nodes.keys()
return queryset return queryset
...@@ -206,7 +212,8 @@ class UserGroupGrantedNodesWithAssetsApi(ListAPIView): ...@@ -206,7 +212,8 @@ class UserGroupGrantedNodesWithAssetsApi(ListAPIView):
return queryset return queryset
user_group = get_object_or_404(UserGroup, id=user_group_id) user_group = get_object_or_404(UserGroup, id=user_group_id)
nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(user_group) util = AssetPermissionUtil(user_group)
nodes = util.get_nodes_with_assets()
for node, _assets in nodes.items(): for node, _assets in nodes.items():
assets = _assets.keys() assets = _assets.keys()
for asset, system_users in _assets.items(): for asset, system_users in _assets.items():
...@@ -226,7 +233,8 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView): ...@@ -226,7 +233,8 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
user_group = get_object_or_404(UserGroup, id=user_group_id) user_group = get_object_or_404(UserGroup, id=user_group_id)
node = get_object_or_404(Node, id=node_id) node = get_object_or_404(Node, id=node_id)
nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(user_group) util = AssetPermissionUtil(user_group)
nodes = util.get_nodes_with_assets()
assets = nodes.get(node, []) assets = nodes.get(node, [])
for asset, system_users in assets.items(): for asset, system_users in assets.items():
asset.system_users_granted = system_users asset.system_users_granted = system_users
...@@ -246,7 +254,8 @@ class ValidateUserAssetPermissionView(APIView): ...@@ -246,7 +254,8 @@ class ValidateUserAssetPermissionView(APIView):
asset = get_object_or_404(Asset, id=asset_id) asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_id) system_user = get_object_or_404(SystemUser, id=system_id)
assets_granted = AssetPermissionUtil.get_user_assets(user) util = AssetPermissionUtil(user)
assets_granted = util.get_assets()
if system_user in assets_granted.get(asset, []): if system_user in assets_granted.get(asset, []):
return Response({'msg': True}, status=200) return Response({'msg': True}, status=200)
else: else:
......
This diff is collapsed.
#!/bin/bash #!/bin/bash
if grep -q 'source ~/.autoenv/activate.sh' ~/.bashrc; then if grep -q 'source /opt/autoenv/activate.sh' ~/.bashrc; then
echo -e "\033[31m 正在自动载入 python 环境 \033[0m" echo -e "\033[31m 正在自动载入 python 环境 \033[0m"
else else
echo -e "\033[31m 不支持自动升级,请参考 http://docs.jumpserver.org/zh/docs/upgrade.html 手动升级 \033[0m" echo -e "\033[31m 不支持自动升级,请参考 http://docs.jumpserver.org/zh/docs/upgrade.html 手动升级 \033[0m"
...@@ -40,5 +40,6 @@ git pull && pip install -r requirements/requirements.txt && cd utils && sh make_ ...@@ -40,5 +40,6 @@ git pull && pip install -r requirements/requirements.txt && cd utils && sh make_
cd .. && ./jms start all -d cd .. && ./jms start all -d
echo -e "\033[31m 请检查jumpserver是否启动成功 \033[0m" echo -e "\033[31m 请检查jumpserver是否启动成功 \033[0m"
echo -e "\033[31m 备份文件存放于$jumpserver_backup目录 \033[0m" echo -e "\033[31m 备份文件存放于$jumpserver_backup目录 \033[0m"
stty erase ^?
exit 0 exit 0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment