Commit a9d15381 authored by ibuler's avatar ibuler

[Update] 修改一些方法

parent f8ff223f
...@@ -152,29 +152,13 @@ class Asset(models.Model): ...@@ -152,29 +152,13 @@ class Asset(models.Model):
def get_all_nodes(self, flat=False): def get_all_nodes(self, flat=False):
nodes = [] nodes = []
for node in self.get_nodes_or_cache(): for node in self.get_nodes():
_nodes = node.get_ancestor(with_self=True) _nodes = node.get_ancestor(with_self=True)
_nodes.append(_nodes) _nodes.append(_nodes)
if flat: if flat:
nodes = list(reduce(lambda x, y: set(x) | set(y), nodes)) nodes = list(reduce(lambda x, y: set(x) | set(y), nodes))
return nodes return nodes
@property
def nodes_cache_key(self):
key = "NODES_OF_{}".format(str(self.id))
return key
def get_nodes_or_cache(self):
cached = cache.get(self.nodes_cache_key)
if cached is not None:
return cached
nodes = list(self.get_nodes())
cache.set(self.nodes_cache_key, nodes, 3600)
return nodes
def expire_nodes_cache(self):
cache.delete(self.nodes_cache_key)
@property @property
def hardware_info(self): def hardware_info(self):
if self.cpu_count: if self.cpu_count:
......
...@@ -22,6 +22,21 @@ class Node(models.Model): ...@@ -22,6 +22,21 @@ class Node(models.Model):
def __str__(self): def __str__(self):
return self.full_value return self.full_value
def __eq__(self, other):
return self.key == other.key
def __gt__(self, other):
if self.is_root():
return True
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]
if len(self_key) < len(other_key):
return True
elif len(self_key) > len(other_key):
return False
else:
return self_key[-1] < other_key[-1]
@property @property
def name(self): def name(self):
return self.value return self.value
...@@ -93,7 +108,7 @@ class Node(models.Model): ...@@ -93,7 +108,7 @@ class Node(models.Model):
Q(nodes__id=self.id) | Q(nodes__isnull=True) Q(nodes__id=self.id) | Q(nodes__isnull=True)
) )
else: else:
assets = Asset.objects.filter(nodes__id=self.id) assets = self.assets.all()
return assets return assets
def get_valid_assets(self): def get_valid_assets(self):
...@@ -104,8 +119,8 @@ class Node(models.Model): ...@@ -104,8 +119,8 @@ class Node(models.Model):
if self.is_root(): if self.is_root():
assets = Asset.objects.all() assets = Asset.objects.all()
else: else:
nodes = self.get_all_children(with_self=True) pattern = r'^{0}$|^{0}:'.format(self.key)
assets = Asset.objects.filter(nodes__in=nodes).distinct() assets = Asset.objects.filter(nodes__key__regex=pattern)
return assets return assets
def get_all_valid_assets(self): def get_all_valid_assets(self):
...@@ -154,10 +169,3 @@ class Node(models.Model): ...@@ -154,10 +169,3 @@ class Node(models.Model):
return obj return obj
class Tree:
def __init__(self, root):
self.root = root
self.nodes = []
def add_node(self, node):
pass
...@@ -16,8 +16,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -16,8 +16,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
""" """
资产的数据结构 资产的数据结构
""" """
nodes = serializers.SerializerMethodField()
class Meta: class Meta:
model = Asset model = Asset
list_serializer_class = BulkListSerializer list_serializer_class = BulkListSerializer
...@@ -31,10 +29,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -31,10 +29,6 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
]) ])
return fields return fields
@staticmethod
def get_nodes(obj):
return [n.id for n in obj.get_nodes_or_cache()]
class AssetGrantedSerializer(serializers.ModelSerializer): class AssetGrantedSerializer(serializers.ModelSerializer):
""" """
......
...@@ -68,7 +68,10 @@ class NodeSerializer(serializers.ModelSerializer): ...@@ -68,7 +68,10 @@ class NodeSerializer(serializers.ModelSerializer):
@staticmethod @staticmethod
def get_assets_amount(obj): def get_assets_amount(obj):
return obj.get_all_assets().count() if obj.is_node:
return obj.get_all_assets().count()
else:
return 0
def get_fields(self): def get_fields(self):
fields = super().get_fields() fields = super().get_fields()
......
...@@ -64,7 +64,6 @@ def on_system_user_assets_change(sender, instance=None, **kwargs): ...@@ -64,7 +64,6 @@ def on_system_user_assets_change(sender, instance=None, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_node_changed(sender, instance=None, **kwargs): def on_asset_node_changed(sender, instance=None, **kwargs):
if isinstance(instance, Asset): if isinstance(instance, Asset):
instance.expire_nodes_cache()
if kwargs['action'] == 'post_add': if kwargs['action'] == 'post_add':
logger.debug("Asset node change signal received") logger.debug("Asset node change signal received")
nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
...@@ -81,10 +80,6 @@ def on_asset_node_changed(sender, instance=None, **kwargs): ...@@ -81,10 +80,6 @@ def on_asset_node_changed(sender, instance=None, **kwargs):
def on_node_assets_changed(sender, instance=None, **kwargs): def on_node_assets_changed(sender, instance=None, **kwargs):
if isinstance(instance, Node): if isinstance(instance, Node):
assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
# 清理资产节点缓存
for asset in assets:
asset.expire_nodes_cache()
if kwargs['action'] == 'post_add': if kwargs['action'] == 'post_add':
logger.debug("Node assets change signal received") logger.debug("Node assets change signal received")
# 重新关联系统用户和资产的关系 # 重新关联系统用户和资产的关系
......
...@@ -70,7 +70,8 @@ class UserGrantedAssetsApi(ListAPIView): ...@@ -70,7 +70,8 @@ class UserGrantedAssetsApi(ListAPIView):
else: else:
user = self.request.user user = self.request.user
for k, v in AssetPermissionUtil.get_user_assets(user).items(): util = AssetPermissionUtil(user)
for k, v in util.get_assets().items():
if k.is_unixlike(): if k.is_unixlike():
system_users_granted = [s for s in v if s.protocol == 'ssh'] system_users_granted = [s for s in v if s.protocol == 'ssh']
else: else:
...@@ -95,7 +96,8 @@ class UserGrantedNodesApi(ListAPIView): ...@@ -95,7 +96,8 @@ class UserGrantedNodesApi(ListAPIView):
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
else: else:
user = self.request.user user = self.request.user
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) util = AssetPermissionUtil(user)
nodes = util.get_nodes_with_assets()
return nodes.keys() return nodes.keys()
def get_permissions(self): def get_permissions(self):
...@@ -116,7 +118,8 @@ class UserGrantedNodesWithAssetsApi(ListAPIView): ...@@ -116,7 +118,8 @@ class UserGrantedNodesWithAssetsApi(ListAPIView):
else: else:
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) util = AssetPermissionUtil(user)
nodes = util.get_nodes_with_assets()
for node, _assets in nodes.items(): for node, _assets in nodes.items():
assets = _assets.keys() assets = _assets.keys()
for k, v in _assets.items(): for k, v in _assets.items():
...@@ -147,8 +150,9 @@ class UserGrantedNodeAssetsApi(ListAPIView): ...@@ -147,8 +150,9 @@ class UserGrantedNodeAssetsApi(ListAPIView):
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
else: else:
user = self.request.user user = self.request.user
util = AssetPermissionUtil(user)
node = get_object_or_404(Node, id=node_id) node = get_object_or_404(Node, id=node_id)
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) nodes = util.get_nodes_with_assets()
assets = nodes.get(node, []) assets = nodes.get(node, [])
for asset, system_users in assets.items(): for asset, system_users in assets.items():
asset.system_users_granted = system_users asset.system_users_granted = system_users
...@@ -172,7 +176,8 @@ class UserGroupGrantedAssetsApi(ListAPIView): ...@@ -172,7 +176,8 @@ class UserGroupGrantedAssetsApi(ListAPIView):
return queryset return queryset
user_group = get_object_or_404(UserGroup, id=user_group_id) user_group = get_object_or_404(UserGroup, id=user_group_id)
assets = AssetPermissionUtil.get_user_group_assets(user_group) util = AssetPermissionUtil(user_group)
assets = util.get_assets()
for k, v in assets.items(): for k, v in assets.items():
k.system_users_granted = v k.system_users_granted = v
queryset.append(k) queryset.append(k)
...@@ -189,7 +194,8 @@ class UserGroupGrantedNodesApi(ListAPIView): ...@@ -189,7 +194,8 @@ class UserGroupGrantedNodesApi(ListAPIView):
if group_id: if group_id:
group = get_object_or_404(UserGroup, id=group_id) group = get_object_or_404(UserGroup, id=group_id)
nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(group) util = AssetPermissionUtil(group)
nodes = util.get_nodes_with_assets()
return nodes.keys() return nodes.keys()
return queryset return queryset
...@@ -206,7 +212,8 @@ class UserGroupGrantedNodesWithAssetsApi(ListAPIView): ...@@ -206,7 +212,8 @@ class UserGroupGrantedNodesWithAssetsApi(ListAPIView):
return queryset return queryset
user_group = get_object_or_404(UserGroup, id=user_group_id) user_group = get_object_or_404(UserGroup, id=user_group_id)
nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(user_group) util = AssetPermissionUtil(user_group)
nodes = util.get_nodes_with_assets()
for node, _assets in nodes.items(): for node, _assets in nodes.items():
assets = _assets.keys() assets = _assets.keys()
for asset, system_users in _assets.items(): for asset, system_users in _assets.items():
...@@ -226,7 +233,8 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView): ...@@ -226,7 +233,8 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
user_group = get_object_or_404(UserGroup, id=user_group_id) user_group = get_object_or_404(UserGroup, id=user_group_id)
node = get_object_or_404(Node, id=node_id) node = get_object_or_404(Node, id=node_id)
nodes = AssetPermissionUtil.get_user_group_nodes_with_assets(user_group) util = AssetPermissionUtil(user_group)
nodes = util.get_nodes_with_assets()
assets = nodes.get(node, []) assets = nodes.get(node, [])
for asset, system_users in assets.items(): for asset, system_users in assets.items():
asset.system_users_granted = system_users asset.system_users_granted = system_users
...@@ -246,7 +254,8 @@ class ValidateUserAssetPermissionView(APIView): ...@@ -246,7 +254,8 @@ class ValidateUserAssetPermissionView(APIView):
asset = get_object_or_404(Asset, id=asset_id) asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_id) system_user = get_object_or_404(SystemUser, id=system_id)
assets_granted = AssetPermissionUtil.get_user_assets(user) util = AssetPermissionUtil(user)
assets_granted = util.get_assets()
if system_user in assets_granted.get(asset, []): if system_user in assets_granted.get(asset, []):
return Response({'msg': True}, status=200) return Response({'msg': True}, status=200)
else: else:
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment