Commit 1a0ff422 authored by ibuler's avatar ibuler

[Update] 优化树结构

parent 6d96b5db
...@@ -148,6 +148,7 @@ class AssetUserTestConnectiveApi(generics.RetrieveAPIView): ...@@ -148,6 +148,7 @@ class AssetUserTestConnectiveApi(generics.RetrieveAPIView):
Test asset users connective Test asset users connective
""" """
permission_classes = (IsOrgAdminOrAppUser,) permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.TaskIDSerializer
def get_asset_users(self): def get_asset_users(self):
username = self.request.GET.get('username') username = self.request.GET.get('username')
......
...@@ -26,6 +26,7 @@ from ..hands import IsOrgAdmin ...@@ -26,6 +26,7 @@ from ..hands import IsOrgAdmin
from ..models import Node from ..models import Node
from ..tasks import update_assets_hardware_info_util, test_asset_connectivity_util from ..tasks import update_assets_hardware_info_util, test_asset_connectivity_util
from .. import serializers from .. import serializers
from ..utils import NodeUtil
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -79,12 +80,10 @@ class NodeListAsTreeApi(generics.ListAPIView): ...@@ -79,12 +80,10 @@ class NodeListAsTreeApi(generics.ListAPIView):
serializer_class = TreeNodeSerializer serializer_class = TreeNodeSerializer
def get_queryset(self): def get_queryset(self):
queryset = [node.as_tree_node() for node in Node.objects.all()] queryset = Node.objects.all()
return queryset util = NodeUtil()
nodes = util.get_nodes_by_queryset(queryset)
def filter_queryset(self, queryset): queryset = [node.as_tree_node() for node in nodes]
if self.request.query_params.get('refresh', '0') == '1':
queryset = self.refresh_nodes(queryset)
return queryset return queryset
@staticmethod @staticmethod
...@@ -114,15 +113,11 @@ class NodeChildrenAsTreeApi(generics.ListAPIView): ...@@ -114,15 +113,11 @@ class NodeChildrenAsTreeApi(generics.ListAPIView):
def get_queryset(self): def get_queryset(self):
node_key = self.request.query_params.get('key') node_key = self.request.query_params.get('key')
if node_key: util = NodeUtil()
self.node = Node.objects.get(key=node_key) if not node_key:
queryset = self.node.get_children(with_self=False) node_key = Node.root().key
else: self.node = util.get_node_by_key(node_key)
self.is_root = True queryset = self.node.get_children(with_self=True)
self.node = Node.root()
queryset = list(self.node.get_children(with_self=True))
nodes_invalid = Node.objects.exclude(key__startswith=self.node.key)
queryset.extend(list(nodes_invalid))
queryset = [node.as_tree_node() for node in queryset] queryset = [node.as_tree_node() for node in queryset]
queryset = sorted(queryset) queryset = sorted(queryset)
return queryset return queryset
......
...@@ -46,12 +46,6 @@ class AssetQuerySet(models.QuerySet): ...@@ -46,12 +46,6 @@ class AssetQuerySet(models.QuerySet):
return self.active() return self.active()
class AssetManager(OrgManager):
def get_queryset(self):
queryset = super().get_queryset().prefetch_related("nodes", "protocols")
return queryset
class Protocol(models.Model): class Protocol(models.Model):
PROTOCOL_SSH = 'ssh' PROTOCOL_SSH = 'ssh'
PROTOCOL_RDP = 'rdp' PROTOCOL_RDP = 'rdp'
...@@ -131,7 +125,7 @@ class Asset(OrgModelMixin): ...@@ -131,7 +125,7 @@ class Asset(OrgModelMixin):
date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created')) date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created'))
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment')) comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)() objects = OrgManager.from_queryset(AssetQuerySet)()
def __str__(self): def __str__(self):
return '{0.hostname}({0.ip})'.format(self) return '{0.hostname}({0.ip})'.format(self)
...@@ -300,15 +294,20 @@ class Asset(OrgModelMixin): ...@@ -300,15 +294,20 @@ class Asset(OrgModelMixin):
@classmethod @classmethod
def generate_fake(cls, count=100): def generate_fake(cls, count=100):
from random import seed, choice from random import seed, choice
import forgery_py
from django.db import IntegrityError from django.db import IntegrityError
from .node import Node from .node import Node
from orgs.utils import get_current_org
from orgs.models import Organization
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
nodes = list(Node.objects.all()) nodes = list(Node.objects.all())
seed() seed()
for i in range(count): for i in range(count):
ip = [str(i) for i in random.sample(range(255), 4)] ip = [str(i) for i in random.sample(range(255), 4)]
asset = cls(ip='.'.join(ip), asset = cls(ip='.'.join(ip),
hostname=forgery_py.internet.user_name(True), hostname='.'.join(ip),
admin_user=choice(AdminUser.objects.all()), admin_user=choice(AdminUser.objects.all()),
created_by='Fake') created_by='Fake')
try: try:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid import uuid
import re
from django.db import models, transaction from django.db import models, transaction
from django.db.models import Q from django.db.models import Q
...@@ -15,54 +16,185 @@ from orgs.models import Organization ...@@ -15,54 +16,185 @@ from orgs.models import Organization
__all__ = ['Node'] __all__ = ['Node']
class Node(OrgModelMixin): class FamilyMixin:
id = models.UUIDField(default=uuid.uuid4, primary_key=True) _parents = None
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1' _children = None
value = models.CharField(max_length=128, verbose_name=_("Value")) _all_children = None
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
is_node = True is_node = True
_assets_amount = None
_full_value_cache_key = '_NODE_VALUE_{}'
_assets_amount_cache_key = '_NODE_ASSETS_AMOUNT_{}'
class Meta: @property
verbose_name = _("Node") def children(self):
ordering = ['key'] if self._children:
return self._children
pattern = r'^{0}:[0-9]+$'.format(self.key)
return Node.objects.filter(key__regex=pattern)
def __str__(self): @children.setter
return self.full_value def children(self, value):
self._children = value
def __eq__(self, other): @property
if not other: def all_children(self):
return False if self._all_children:
return self.id == other.id return self._all_children
pattern = r'^{0}:'.format(self.key)
return Node.objects.filter(
key__regex=pattern
)
def __gt__(self, other): def get_children(self, with_self=False):
if self.is_root() and not other.is_root(): children = list(self.children)
return True if with_self:
elif not self.is_root() and other.is_root(): children.append(self)
return False return children
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]
self_parent_key = self_key[:-1]
other_parent_key = other_key[:-1]
if self_parent_key == other_parent_key: def get_all_children(self, with_self=False):
return self.name > other.name children = self.all_children
if len(self_parent_key) < len(other_parent_key): if with_self:
return True children = list(children)
elif len(self_parent_key) > len(other_parent_key): children.append(self)
return False return children
return self_key > other_key
def __lt__(self, other): @property
return not self.__gt__(other) def parents(self):
if self._parents:
return self._parents
ancestor_keys = self.get_ancestor_keys()
ancestor = Node.objects.filter(
key__in=ancestor_keys
).order_by('key')
return ancestor
@parents.setter
def parents(self, value):
self._parents = value
def get_ancestor(self, with_self=False):
parents = self.parents
if with_self:
parents = list(parents)
parents.append(self)
return parents
@property @property
def name(self): def parent(self):
return self.value if self._parents:
return self._parents[0]
if self.is_root():
return self
try:
parent = Node.objects.get(key=self.parent_key)
return parent
except Node.DoesNotExist:
return Node.root()
@parent.setter
def parent(self, parent):
if not self.is_node:
self.key = parent.key + ':fake'
return
children = self.get_all_children()
old_key = self.key
with transaction.atomic():
self.key = parent.get_next_child_key()
for child in children:
child.key = child.key.replace(old_key, self.key, 1)
child.save()
self.save()
def get_sibling(self, with_self=False):
key = ':'.join(self.key.split(':')[:-1])
pattern = r'^{}:[0-9]+$'.format(key)
sibling = Node.objects.filter(
key__regex=pattern.format(self.key)
)
if not with_self:
sibling = sibling.exclude(key=self.key)
return sibling
def get_family(self):
ancestor = self.get_ancestor()
children = self.get_all_children()
return [*tuple(ancestor), self, *tuple(children)]
def get_ancestor_keys(self, with_self=False):
parent_keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def is_children(self, other):
pattern = re.compile(r'^{0}:[0-9]+$'.format(self.key))
return pattern.match(other.key)
def is_parent(self, other):
pattern = re.compile(r'^{0}:[0-9]+$'.format(other.key))
return pattern.match(self.key)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
@property
def parents_keys(self, with_self=False):
keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
keys.append(':'.join(key_list))
key_list.pop()
return keys
class FullValueMixin:
_full_value_cache_key = '_NODE_VALUE_{}'
_full_value = ''
key = ''
@property
def full_value(self):
if self._full_value:
return self._full_value
key = self._full_value_cache_key.format(self.key)
cached = cache.get(key)
if cached:
return cached
if self.is_root():
return self.value
parent_full_value = self.parent.full_value
value = parent_full_value + ' / ' + self.value
self.full_value = value
return value
@full_value.setter
def full_value(self, value):
self._full_value = value
key = self._full_value_cache_key.format(self.key)
cache.set(key, value, 3600*24)
def expire_full_value(self):
key = self._full_value_cache_key.format(self.key)
cache.delete_pattern(key+'*')
@classmethod
def expire_nodes_full_value(cls, nodes=None):
key = cls._full_value_cache_key.format('*')
cache.delete_pattern(key+'*')
from ..utils import NodeUtil
util = NodeUtil()
util.set_full_value()
class AssetsAmountMixin:
_assets_amount_cache_key = '_NODE_ASSETS_AMOUNT_{}'
_assets_amount = None
key = ''
@property @property
def assets_amount(self): def assets_amount(self):
...@@ -77,53 +209,77 @@ class Node(OrgModelMixin): ...@@ -77,53 +209,77 @@ class Node(OrgModelMixin):
if cached is not None: if cached is not None:
return cached return cached
assets_amount = self.get_all_assets().count() assets_amount = self.get_all_assets().count()
cache.set(cache_key, assets_amount, 3600) self.assets_amount = assets_amount
return assets_amount return assets_amount
@assets_amount.setter @assets_amount.setter
def assets_amount(self, value): def assets_amount(self, value):
self._assets_amount = value self._assets_amount = value
cache_key = self._assets_amount_cache_key.format(self.key)
cache.set(cache_key, value, 3600 * 24)
def expire_assets_amount(self): def expire_assets_amount(self):
ancestor_keys = self.get_ancestor_keys(with_self=True) ancestor_keys = self.get_ancestor_keys(with_self=True)
cache_keys = [self._assets_amount_cache_key.format(k) for k in ancestor_keys] cache_keys = [self._assets_amount_cache_key.format(k) for k in
ancestor_keys]
cache.delete_many(cache_keys) cache.delete_many(cache_keys)
@classmethod @classmethod
def expire_nodes_assets_amount(cls, nodes=None): def expire_nodes_assets_amount(cls, nodes=None):
if nodes: from ..utils import NodeUtil
for node in nodes:
node.expire_assets_amount()
return
key = cls._assets_amount_cache_key.format('*') key = cls._assets_amount_cache_key.format('*')
cache.delete_pattern(key) cache.delete_pattern(key)
util = NodeUtil(with_assets_amount=True)
util.set_assets_amount()
@property
def full_value(self):
key = self._full_value_cache_key.format(self.key)
cached = cache.get(key)
if cached:
return cached
if self.is_root():
return self.value
parent_full_value = self.parent.full_value
value = parent_full_value + ' / ' + self.value
key = self._full_value_cache_key.format(self.key)
cache.set(key, value, 3600)
return value
def expire_full_value(self): class Node(OrgModelMixin, FamilyMixin, FullValueMixin, AssetsAmountMixin):
key = self._full_value_cache_key.format(self.key) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
cache.delete_pattern(key+'*') key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
value = models.CharField(max_length=128, verbose_name=_("Value"))
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
@classmethod is_node = True
def expire_nodes_full_value(cls, nodes=None): _parents = None
if nodes:
for node in nodes: class Meta:
node.expire_full_value() verbose_name = _("Node")
return ordering = ['key']
key = cls._full_value_cache_key.format('*')
cache.delete_pattern(key+'*') def __str__(self):
return self.full_value
def __eq__(self, other):
if not other:
return False
return self.id == other.id
def __gt__(self, other):
# if self.is_root() and not other.is_root():
# return False
# elif not self.is_root() and other.is_root():
# return True
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]
self_parent_key = self_key[:-1]
other_parent_key = other_key[:-1]
if self_parent_key and other_parent_key and \
self_parent_key == other_parent_key:
return self.value > other.value
# if len(self_parent_key) < len(other_parent_key):
# return True
# elif len(self_parent_key) > len(other_parent_key):
# return False
return self_key > other_key
def __lt__(self, other):
return not self.__gt__(other)
@property
def name(self):
return self.value
@property @property
def level(self): def level(self):
...@@ -152,33 +308,6 @@ class Node(OrgModelMixin): ...@@ -152,33 +308,6 @@ class Node(OrgModelMixin):
child = self.__class__.objects.create(id=_id, key=child_key, value=value) child = self.__class__.objects.create(id=_id, key=child_key, value=value)
return child return child
def get_children(self, with_self=False):
pattern = r'^{0}$|^{0}:[0-9]+$' if with_self else r'^{0}:[0-9]+$'
return self.__class__.objects.filter(
key__regex=pattern.format(self.key)
)
def get_all_children(self, with_self=False):
pattern = r'^{0}$|^{0}:' if with_self else r'^{0}:'
return self.__class__.objects.filter(
key__regex=pattern.format(self.key)
)
def get_sibling(self, with_self=False):
key = ':'.join(self.key.split(':')[:-1])
pattern = r'^{}:[0-9]+$'.format(key)
sibling = self.__class__.objects.filter(
key__regex=pattern.format(self.key)
)
if not with_self:
sibling = sibling.exclude(key=self.key)
return sibling
def get_family(self):
ancestor = self.get_ancestor()
children = self.get_all_children()
return [*tuple(ancestor), self, *tuple(children)]
def get_assets(self): def get_assets(self):
from .asset import Asset from .asset import Asset
if self.is_default_node(): if self.is_default_node():
...@@ -214,52 +343,6 @@ class Node(OrgModelMixin): ...@@ -214,52 +343,6 @@ class Node(OrgModelMixin):
else: else:
return False return False
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
@property
def parent(self):
if self.is_root():
return self
try:
parent = self.__class__.objects.get(key=self.parent_key)
return parent
except Node.DoesNotExist:
return self.__class__.root()
@parent.setter
def parent(self, parent):
if not self.is_node:
self.key = parent.key + ':fake'
return
children = self.get_all_children()
old_key = self.key
with transaction.atomic():
self.key = parent.get_next_child_key()
for child in children:
child.key = child.key.replace(old_key, self.key, 1)
child.save()
self.save()
def get_ancestor_keys(self, with_self=False):
parent_keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def get_ancestor(self, with_self=False):
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
ancestor = self.__class__.objects.filter(
key__in=ancestor_keys
).order_by('key')
return ancestor
@classmethod @classmethod
def create_root_node(cls): def create_root_node(cls):
# 如果使用current_org 在set_current_org时会死循环 # 如果使用current_org 在set_current_org时会死循环
...@@ -310,9 +393,19 @@ class Node(OrgModelMixin): ...@@ -310,9 +393,19 @@ class Node(OrgModelMixin):
tree_node = TreeNode(**data) tree_node = TreeNode(**data)
return tree_node return tree_node
@classmethod
def get_queryset(cls):
from ..utils import NodeUtil
util = NodeUtil()
return util.nodes
@classmethod @classmethod
def generate_fake(cls, count=100): def generate_fake(cls, count=100):
import random import random
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
for i in range(count): for i in range(count):
node = random.choice(cls.objects.all()) node = random.choice(cls.objects.all())
node.create_child('Node {}'.format(i)) node.create_child('Node {}'.format(i))
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
# #
from django.utils.translation import ugettext_lazy as _ from django.db.models import Prefetch
from django.core.cache import cache
from django.utils import timezone
from common.utils import get_object_or_none from common.utils import get_object_or_none, get_logger
from .models import SystemUser, Label from common.struct import Stack
from .models import SystemUser, Label, Node, Asset
def get_assets_by_id_list(id_list): logger = get_logger(__file__)
return Asset.objects.filter(id__in=id_list).filter(is_active=True)
def get_system_users_by_id_list(id_list):
return SystemUser.objects.filter(id__in=id_list)
def get_system_user_by_name(name): def get_system_user_by_name(name):
...@@ -47,4 +41,154 @@ class LabelFilter: ...@@ -47,4 +41,154 @@ class LabelFilter:
return queryset return queryset
class NodeUtil:
def __init__(self, with_assets_amount=False, debug=False):
self.stack = Stack()
self._nodes = {}
self.with_assets_amount = with_assets_amount
self._debug = debug
self.init()
@staticmethod
def sorted_by(node):
return [int(i) for i in node.key.split(':')]
def get_all_nodes(self):
all_nodes = Node.objects.all()
if self.with_assets_amount:
all_nodes = all_nodes.prefetch_related(
Prefetch('assets', queryset=Asset.objects.all().only('id'))
)
for node in all_nodes:
node._assets = set(node.assets.all())
all_nodes = sorted(all_nodes, key=self.sorted_by)
guarder = Node(key='', value='Guarder')
guarder._assets = []
all_nodes.append(guarder)
return all_nodes
def push_to_stack(self, node):
# 入栈之前检查
# 如果栈是空的,证明是一颗树的根部
if self.stack.is_empty():
node._full_value = node.value
node._parents = []
else:
# 如果不是根节点,
# 该节点的祖先应该是父节点的祖先加上父节点
# 该节点的名字是父节点的名字+自己的名字
node._parents = [self.stack.top] + self.stack.top._parents
node._full_value = ' / '.join(
[self.stack.top._full_value, node.value]
)
node._children = []
node._all_children = []
self.debug("入栈: {}".format(node.key))
self.stack.push(node)
# 出栈
def pop_from_stack(self):
_node = self.stack.pop()
self.debug("出栈: {} 栈顶: {}".format(_node.key, self.stack.top.key if self.stack.top else None))
self._nodes[_node.key] = _node
if not self.stack.top:
return
if self.with_assets_amount:
self.stack.top._assets.update(_node._assets)
_node._assets_amount = len(_node._assets)
delattr(_node, '_assets')
self.stack.top._children.append(_node)
self.stack.top._all_children.extend([_node] + _node._children)
def init(self):
all_nodes = self.get_all_nodes()
for node in all_nodes:
self.debug("准备: {} 栈顶: {}".format(node.key, self.stack.top.key if self.stack.top else None))
# 入栈之前检查,该节点是不是栈顶节点的子节点
# 如果不是,则栈顶出栈
while self.stack.top and not self.stack.top.is_children(node):
self.pop_from_stack()
self.push_to_stack(node)
# 出栈最后一个
self.debug("剩余: {}".format(', '.join([n.key for n in self.stack])))
def get_nodes_by_queryset(self, queryset):
nodes = []
for n in queryset:
node = self._nodes.get(n.key)
if not node:
continue
nodes.append(nodes)
return [self]
def get_node_by_key(self, key):
return self._nodes.get(key)
def debug(self, msg):
self._debug and logger.debug(msg)
def set_assets_amount(self):
for node in self._nodes.values():
node.assets_amount = node._assets_amount
def set_full_value(self):
for node in self._nodes.values():
node.full_value = node._full_value
@property
def nodes(self):
return list(self._nodes.values())
# 使用给定节点生成一颗树
# 找到他们的祖先节点
# 可选找到他们的子孙节点
def get_family(self, nodes, with_children=False):
tree_nodes = set()
for n in nodes:
node = self.get_node_by_key(n.key)
if not node:
continue
tree_nodes.update(node._parents)
tree_nodes.add(node)
if with_children:
tree_nodes.update(node._children)
for n in tree_nodes:
delattr(n, '_children')
delattr(n, '_parents')
return list(tree_nodes)
def test_node_tree():
tree = NodeUtil()
for node in tree._nodes.values():
print("Check {}".format(node.key))
children_wanted = node.get_all_children().count()
children = len(node._children)
if children != children_wanted:
print("{} children not equal: {} != {}".format(node.key, children, children_wanted))
assets_amount_wanted = node.get_all_assets().count()
if node._assets_amount != assets_amount_wanted:
print("{} assets amount not equal: {} != {}".format(
node.key, node._assets_amount, assets_amount_wanted)
)
full_value_wanted = node.full_value
if node._full_value != full_value_wanted:
print("{} full value not equal: {} != {}".format(
node.key, node._full_value, full_value_wanted)
)
parents_wanted = node.get_ancestor().count()
parents = len(node._parents)
if parents != parents_wanted:
print("{} parents count not equal: {} != {}".format(
node.key, parents, parents_wanted)
)
# -*- coding: utf-8 -*-
#
class Stack(list):
def is_empty(self):
return len(self) == 0
@property
def top(self):
if self.is_empty():
return None
return self[-1]
@property
def bottom(self):
if self.is_empty():
return None
return self[0]
def size(self):
return len(self)
def push(self, item):
self.append(item)
...@@ -7,7 +7,7 @@ from django.conf.urls.static import static ...@@ -7,7 +7,7 @@ from django.conf.urls.static import static
from django.conf.urls.i18n import i18n_patterns from django.conf.urls.i18n import i18n_patterns
from django.views.i18n import JavaScriptCatalog from django.views.i18n import JavaScriptCatalog
from .views import IndexView, LunaView, I18NView from .views import IndexView, LunaView, I18NView, HealthCheckView
from .swagger import get_swagger_view from .swagger import get_swagger_view
api_v1 = [ api_v1 = [
...@@ -63,6 +63,7 @@ urlpatterns = [ ...@@ -63,6 +63,7 @@ urlpatterns = [
path('', IndexView.as_view(), name='index'), path('', IndexView.as_view(), name='index'),
path('', include(api_v2_patterns)), path('', include(api_v2_patterns)),
path('', include(api_v1_patterns)), path('', include(api_v1_patterns)),
path('api/health/', HealthCheckView.as_view(), name="health"),
path('luna/', LunaView.as_view(), name='luna-view'), path('luna/', LunaView.as_view(), name='luna-view'),
path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'), path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'),
path('settings/', include('settings.urls.view_urls', namespace='settings')), path('settings/', include('settings.urls.view_urls', namespace='settings')),
......
import datetime import datetime
import re import re
import time
from django.http import HttpResponse, HttpResponseRedirect from django.http import HttpResponse, HttpResponseRedirect
from django.conf import settings from django.conf import settings
...@@ -9,6 +10,7 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -9,6 +10,7 @@ from django.utils.translation import ugettext_lazy as _
from django.db.models import Count from django.db.models import Count
from django.shortcuts import redirect from django.shortcuts import redirect
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.http import HttpResponse from django.http import HttpResponse
from django.utils.encoding import iri_to_uri from django.utils.encoding import iri_to_uri
...@@ -222,3 +224,10 @@ def redirect_format_api(request, *args, **kwargs): ...@@ -222,3 +224,10 @@ def redirect_format_api(request, *args, **kwargs):
return HttpResponseTemporaryRedirect(_path) return HttpResponseTemporaryRedirect(_path)
else: else:
return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404) return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404)
class HealthCheckView(APIView):
permission_classes = ()
def get(self, request):
return Response({"status": 1, "time": int(time.time())})
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# #
from .ansible.inventory import BaseInventory from .ansible.inventory import BaseInventory
from assets.utils import get_assets_by_id_list, get_system_user_by_id
from common.utils import get_logger from common.utils import get_logger
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import traceback
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.shortcuts import redirect, get_object_or_404 from django.shortcuts import redirect, get_object_or_404
...@@ -33,8 +33,8 @@ class OrgManager(models.Manager): ...@@ -33,8 +33,8 @@ class OrgManager(models.Manager):
def get_queryset(self): def get_queryset(self):
queryset = super(OrgManager, self).get_queryset() queryset = super(OrgManager, self).get_queryset()
kwargs = {} kwargs = {}
_current_org = get_current_org()
_current_org = get_current_org()
if _current_org is None: if _current_org is None:
kwargs['id'] = None kwargs['id'] = None
elif _current_org.is_real(): elif _current_org.is_real():
...@@ -42,12 +42,17 @@ class OrgManager(models.Manager): ...@@ -42,12 +42,17 @@ class OrgManager(models.Manager):
elif _current_org.is_default(): elif _current_org.is_default():
queryset = queryset.filter(org_id="") queryset = queryset.filter(org_id="")
# lines = traceback.format_stack()
# print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>")
# for line in lines[-10:-5]:
# print(line)
# print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
queryset = queryset.filter(**kwargs) queryset = queryset.filter(**kwargs)
return queryset return queryset
def all(self): def all(self):
_current_org = get_current_org() if not current_org:
if _current_org is None:
msg = 'You can `objects.set_current_org(org).all()` then run it' msg = 'You can `objects.set_current_org(org).all()` then run it'
return self return self
else: else:
......
...@@ -258,7 +258,9 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserPermissionCacheMixin, ListAPIView) ...@@ -258,7 +258,9 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserPermissionCacheMixin, ListAPIView)
util.filter_permissions( util.filter_permissions(
system_users=self.system_user_id system_users=self.system_user_id
) )
print("111111111111")
nodes = util.get_nodes_with_assets() nodes = util.get_nodes_with_assets()
print("22222222222222")
for node, assets in nodes.items(): for node, assets in nodes.items():
data = parse_node_to_tree_node(node) data = parse_node_to_tree_node(node)
queryset.append(data) queryset.append(data)
......
...@@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgModelForm from orgs.mixins import OrgModelForm
from orgs.utils import current_org from orgs.utils import current_org
from perms.models import AssetPermission from perms.models import AssetPermission
from assets.models import Asset from assets.models import Asset, Node
__all__ = [ __all__ = [
'AssetPermissionForm', 'AssetPermissionForm',
......
...@@ -4,6 +4,7 @@ import uuid ...@@ -4,6 +4,7 @@ import uuid
from collections import defaultdict from collections import defaultdict
import json import json
from hashlib import md5 from hashlib import md5
import time
from django.utils import timezone from django.utils import timezone
from django.db.models import Q from django.db.models import Q
...@@ -17,6 +18,7 @@ from common.tree import TreeNode ...@@ -17,6 +18,7 @@ from common.tree import TreeNode
from .. import const from .. import const
from ..models import AssetPermission, Action from ..models import AssetPermission, Action
from ..hands import Node from ..hands import Node
from assets.utils import NodeUtil
logger = get_logger(__file__) logger = get_logger(__file__)
...@@ -35,9 +37,8 @@ class GenerateTree: ...@@ -35,9 +37,8 @@ class GenerateTree:
"asset_instance": set("system_user") "asset_instance": set("system_user")
} }
""" """
self.__all_nodes = list(Node.objects.all()) self.node_util = NodeUtil()
self.nodes = defaultdict(dict) self.nodes = defaultdict(dict)
self.direct_nodes = []
self._root_node = None self._root_node = None
self._ungroup_node = None self._ungroup_node = None
...@@ -48,10 +49,8 @@ class GenerateTree: ...@@ -48,10 +49,8 @@ class GenerateTree:
all_nodes = self.nodes.keys() all_nodes = self.nodes.keys()
# 如果没有授权节点,就放到默认的根节点下 # 如果没有授权节点,就放到默认的根节点下
if not all_nodes: if not all_nodes:
root_node = Node.root() return None
self.add_node(root_node) root_node = min(all_nodes)
else:
root_node = max(all_nodes)
self._root_node = root_node self._root_node = root_node
return root_node return root_node
...@@ -60,7 +59,10 @@ class GenerateTree: ...@@ -60,7 +59,10 @@ class GenerateTree:
if self._ungroup_node: if self._ungroup_node:
return self._ungroup_node return self._ungroup_node
node_id = const.UNGROUPED_NODE_ID node_id = const.UNGROUPED_NODE_ID
node_key = self.root_node.get_next_child_key() if self.root_node:
node_key = self.root_node.get_next_child_key()
else:
node_key = '0:0'
node_value = _("Default") node_value = _("Default")
node = Node(id=node_id, key=node_key, value=node_value) node = Node(id=node_id, key=node_key, value=node_value)
self.add_node(node) self.add_node(node)
...@@ -69,11 +71,11 @@ class GenerateTree: ...@@ -69,11 +71,11 @@ class GenerateTree:
def add_asset(self, asset, system_users): def add_asset(self, asset, system_users):
nodes = asset.nodes.all() nodes = asset.nodes.all()
in_nodes = set(self.direct_nodes) & set(nodes) for node in nodes:
for node in in_nodes: if node in self.nodes:
self.nodes[node][asset].update(system_users) self.nodes[node][asset].update(system_users)
if not in_nodes: else:
self.nodes[self.ungrouped_node][asset].update(system_users) self.nodes[self.ungrouped_node][asset].update(system_users)
def get_nodes(self): def get_nodes(self):
for node in self.nodes: for node in self.nodes:
...@@ -84,26 +86,14 @@ class GenerateTree: ...@@ -84,26 +86,14 @@ class GenerateTree:
node.assets_amount = len(assets) node.assets_amount = len(assets)
return self.nodes return self.nodes
# 添加节点时,追溯到根节点
def add_node(self, node): def add_node(self, node):
if node in self.nodes: self.nodes[node] = defaultdict(set)
return
else:
self.nodes[node] = defaultdict(set)
if node.is_root():
return
for n in self.__all_nodes:
if n.key == node.parent_key:
self.add_node(n)
break
# 添加树节点 # 添加树节点
def add_nodes(self, nodes): def add_nodes(self, nodes):
for node in nodes: need_nodes = self.node_util.get_family(nodes, with_children=True)
for node in need_nodes:
self.add_node(node) self.add_node(node)
self.add_nodes(node.get_all_children(with_self=False))
# 如果是直接授权的节点,则放到direct_nodes中
self.direct_nodes.append(node)
def get_user_permissions(user, include_group=True): def get_user_permissions(user, include_group=True):
...@@ -140,35 +130,28 @@ def get_system_user_permissions(system_user): ...@@ -140,35 +130,28 @@ def get_system_user_permissions(system_user):
) )
class AssetPermissionUtil: def timeit(func):
get_permissions_map = { def wrapper(*args, **kwargs):
"User": get_user_permissions, logger.debug("Start call: {}".format(func.__name__))
"UserGroup": get_user_group_permissions, now = time.time()
"Asset": get_asset_permissions, result = func(*args, **kwargs)
"Node": get_node_permissions, using = time.time() - now
"SystemUser": get_system_user_permissions, logger.debug("Call {} end, using: {:.2}".format(func.__name__, using))
} return result
return wrapper
class AssetGranted:
def __init__(self):
self.system_users = {}
class AssetPermissionCacheMixin:
CACHE_KEY_PREFIX = '_ASSET_PERM_CACHE_' CACHE_KEY_PREFIX = '_ASSET_PERM_CACHE_'
CACHE_META_KEY_PREFIX = '_ASSET_PERM_META_KEY_' CACHE_META_KEY_PREFIX = '_ASSET_PERM_META_KEY_'
CACHE_TIME = settings.ASSETS_PERM_CACHE_TIME CACHE_TIME = settings.ASSETS_PERM_CACHE_TIME
CACHE_POLICY_MAP = (('0', 'never'), ('1', 'using'), ('2', 'refresh')) CACHE_POLICY_MAP = (('0', 'never'), ('1', 'using'), ('2', 'refresh'))
def __init__(self, obj, cache_policy='0'):
self.object = obj
self.obj_id = str(obj.id)
self._permissions = None
self._permissions_id = None # 标记_permission的唯一值
self._assets = None
self._filter_id = 'None' # 当通过filter更改 permission是标记
self.cache_policy = cache_policy
self.tree = GenerateTree()
self.change_org_if_need()
@staticmethod
def change_org_if_need():
set_to_root_org()
@classmethod @classmethod
def is_not_using_cache(cls, cache_policy): def is_not_using_cache(cls, cache_policy):
return cls.CACHE_TIME == 0 or cache_policy in cls.CACHE_POLICY_MAP[0] return cls.CACHE_TIME == 0 or cache_policy in cls.CACHE_POLICY_MAP[0]
...@@ -190,94 +173,7 @@ class AssetPermissionUtil: ...@@ -190,94 +173,7 @@ class AssetPermissionUtil:
def _is_refresh_cache(self): def _is_refresh_cache(self):
return self.is_refresh_cache(self.cache_policy) return self.is_refresh_cache(self.cache_policy)
@property @timeit
def permissions(self):
if self._permissions:
return self._permissions
object_cls = self.object.__class__.__name__
func = self.get_permissions_map[object_cls]
permissions = func(self.object)
self._permissions = permissions
return permissions
def filter_permissions(self, **filters):
filters_json = json.dumps(filters, sort_keys=True)
self._permissions = self.permissions.filter(**filters)
self._filter_id = md5(filters_json.encode()).hexdigest()
@staticmethod
def _structured_system_user(system_users, actions):
"""
结构化系统用户
:param system_users:
:param actions:
:return: {system_user1: {'actions': set(), }, }
"""
_attr = {'actions': set(actions)}
_system_users = {system_user: _attr for system_user in system_users}
return _system_users
def get_nodes_direct(self):
"""
返回用户/组授权规则直接关联的节点
:return: {asset1: {system_user1: {'actions': set()},}}
"""
nodes = defaultdict(dict)
permissions = self.permissions.prefetch_related('nodes', 'system_users')
for perm in permissions:
actions = perm.actions.all()
self.tree.add_nodes(perm.nodes.all())
for node in perm.nodes.all():
system_users = perm.system_users.all()
system_users = self._structured_system_user(system_users, actions)
nodes[node].update(system_users)
return nodes
def get_assets_direct(self):
"""
返回用户授权规则直接关联的资产
:return: {asset1: {system_user1: {'actions': set()},}}
"""
assets = defaultdict(dict)
permissions = self.permissions.prefetch_related('assets', 'system_users')
for perm in permissions:
actions = perm.actions.all()
for asset in perm.assets.all().valid().prefetch_related('nodes'):
system_users = perm.system_users.filter(protocol__in=asset.protocols_name)
system_users = self._structured_system_user(system_users, actions)
assets[asset].update(system_users)
return assets
def get_assets_without_cache(self):
"""
:return: {asset1: set(system_user1,)}
"""
if self._assets:
return self._assets
assets = self.get_assets_direct()
nodes = self.get_nodes_direct()
for node, system_users in nodes.items():
_assets = node.get_all_assets().valid().prefetch_related('nodes')
for asset in _assets:
for system_user, attr_dict in system_users.items():
if not asset.has_protocol(system_user.protocol):
continue
if system_user in assets[asset]:
actions = assets[asset][system_user]['actions']
attr_dict['actions'].update(actions)
system_users.update({system_user: attr_dict})
assets[asset].update(system_users)
__assets = defaultdict(set)
for asset, system_users in assets.items():
for system_user, attr_dict in system_users.items():
setattr(system_user, 'actions', attr_dict['actions'])
__assets[asset] = set(system_users.keys())
self._assets = __assets
return self._assets
def get_cache_key(self, resource): def get_cache_key(self, resource):
cache_key = self.CACHE_KEY_PREFIX + '{obj_id}_{filter_id}_{resource}' cache_key = self.CACHE_KEY_PREFIX + '{obj_id}_{filter_id}_{resource}'
return cache_key.format( return cache_key.format(
...@@ -301,27 +197,6 @@ class AssetPermissionUtil: ...@@ -301,27 +197,6 @@ class AssetPermissionUtil:
cached = cache.get(self.asset_key) cached = cache.get(self.asset_key)
return cached return cached
def get_assets(self):
if self._is_not_using_cache():
return self.get_assets_from_cache()
elif self._is_refresh_cache():
self.expire_cache()
return self.get_assets_from_cache()
else:
self.expire_cache()
return self.get_assets_without_cache()
def get_nodes_with_assets_without_cache(self):
"""
返回节点并且包含资产
{"node": {"assets": set("system_user")}}
:return:
"""
assets = self.get_assets_without_cache()
for asset, system_users in assets.items():
self.tree.add_asset(asset, system_users)
return self.tree.get_nodes()
def get_nodes_with_assets_from_cache(self): def get_nodes_with_assets_from_cache(self):
cached = cache.get(self.node_key) cached = cache.get(self.node_key)
if not cached: if not cached:
...@@ -338,13 +213,6 @@ class AssetPermissionUtil: ...@@ -338,13 +213,6 @@ class AssetPermissionUtil:
else: else:
return self.get_nodes_with_assets_without_cache() return self.get_nodes_with_assets_without_cache()
def get_system_user_without_cache(self):
system_users = set()
permissions = self.permissions.prefetch_related('system_users')
for perm in permissions:
system_users.update(perm.system_users.all())
return system_users
def get_system_user_from_cache(self): def get_system_user_from_cache(self):
cached = cache.get(self.system_key) cached = cache.get(self.system_key)
if not cached: if not cached:
...@@ -418,6 +286,152 @@ class AssetPermissionUtil: ...@@ -418,6 +286,152 @@ class AssetPermissionUtil:
cache.delete_pattern(key) cache.delete_pattern(key)
class AssetPermissionUtil(AssetPermissionCacheMixin):
get_permissions_map = {
"User": get_user_permissions,
"UserGroup": get_user_group_permissions,
"Asset": get_asset_permissions,
"Node": get_node_permissions,
"SystemUser": get_system_user_permissions,
}
def __init__(self, obj, cache_policy='0'):
self.object = obj
self.obj_id = str(obj.id)
self._permissions = None
self._permissions_id = None # 标记_permission的唯一值
self._assets = None
self._filter_id = 'None' # 当通过filter更改 permission是标记
self.cache_policy = cache_policy
self.tree = GenerateTree()
self.change_org_if_need()
self.nodes = None
@staticmethod
def change_org_if_need():
set_to_root_org()
@property
def permissions(self):
if self._permissions:
return self._permissions
object_cls = self.object.__class__.__name__
func = self.get_permissions_map[object_cls]
permissions = func(self.object)
self._permissions = permissions
return permissions
@timeit
def filter_permissions(self, **filters):
filters_json = json.dumps(filters, sort_keys=True)
self._permissions = self.permissions.filter(**filters)
self._filter_id = md5(filters_json.encode()).hexdigest()
@staticmethod
@timeit
def _structured_system_user(system_users, actions):
"""
结构化系统用户
:param system_users:
:param actions:
:return: {system_user1: {'actions': set(), }, }
"""
_attr = {'actions': set(actions)}
_system_users = {system_user: _attr for system_user in system_users}
return _system_users
@timeit
def get_nodes_direct(self):
"""
返回用户/组授权规则直接关联的节点
:return: {asset1: {system_user1: {'actions': set()},}}
"""
nodes = defaultdict(dict)
permissions = self.permissions.prefetch_related('nodes', 'system_users', 'actions')
for perm in permissions:
actions = perm.actions.all()
for node in perm.nodes.all():
system_users = perm.system_users.all()
system_users = self._structured_system_user(system_users, actions)
nodes[node].update(system_users)
self.tree.add_nodes(nodes.keys())
# 替换成优化过的node
nodes = {self.tree.node_util.get_node_by_key(k.key): v for k, v in nodes.items()}
return nodes
@timeit
def get_assets_direct(self):
"""
返回用户授权规则直接关联的资产
:return: {asset1: {system_user1: {'actions': set()},}}
"""
assets = defaultdict(dict)
permissions = self.permissions.prefetch_related('assets', 'system_users')
for perm in permissions:
actions = perm.actions.all()
for asset in perm.assets.all().valid().prefetch_related('nodes'):
system_users = perm.system_users.filter(protocol__in=asset.protocols_name)
system_users = self._structured_system_user(system_users, actions)
assets[asset].update(system_users)
return assets
@timeit
def get_assets_without_cache(self):
"""
:return: {asset1: set(system_user1,)}
"""
if self._assets:
return self._assets
assets = self.get_assets_direct()
nodes = self.get_nodes_direct()
# for node, system_users in nodes.items():
# print(9999, node)
# _assets = node.get_all_valid_assets()
# print(".......... end .......")
# for asset in _assets:
# print(">>asset")
# for system_user, attr_dict in system_users.items():
# print(">>>system user")
# if not asset.has_protocol(system_user.protocol):
# continue
# if system_user in assets[asset]:
# actions = assets[asset][system_user]['actions']
# attr_dict['actions'].update(actions)
# system_users.update({system_user: attr_dict})
# print("<<<system user")
# print("<<<asset")
# assets[asset].update(system_users)
# print(">>>>>>")
#
__assets = defaultdict(set)
for asset, system_users in assets.items():
for system_user, attr_dict in system_users.items():
setattr(system_user, 'actions', attr_dict['actions'])
__assets[asset] = set(system_users.keys())
self._assets = __assets
return self._assets
@timeit
def get_nodes_with_assets_without_cache(self):
"""
返回节点并且包含资产
{"node": {"assets": set("system_user")}}
:return:
"""
assets = self.get_assets_without_cache()
for asset, system_users in assets.items():
self.tree.add_asset(asset, system_users)
return self.tree.get_nodes()
def get_system_user_without_cache(self):
system_users = set()
permissions = self.permissions.prefetch_related('system_users')
for perm in permissions:
system_users.update(perm.system_users.all())
return system_users
def is_obj_attr_has(obj, val, attrs=("hostname", "ip", "comment")): def is_obj_attr_has(obj, val, attrs=("hostname", "ip", "comment")):
if not attrs: if not attrs:
vals = [val for val in obj.__dict__.values() if isinstance(val, (str, int))] vals = [val for val in obj.__dict__.values() if isinstance(val, (str, int))]
......
...@@ -242,22 +242,3 @@ class CommandStorageDeleteAPI(APIView): ...@@ -242,22 +242,3 @@ class CommandStorageDeleteAPI(APIView):
storage_name = str(request.data.get('name')) storage_name = str(request.data.get('name'))
Setting.delete_storage('TERMINAL_COMMAND_STORAGE', storage_name) Setting.delete_storage('TERMINAL_COMMAND_STORAGE', storage_name)
return Response({"msg": _('Delete succeed')}, status=200) return Response({"msg": _('Delete succeed')}, status=200)
class DjangoSettingsAPI(APIView):
def get(self, request):
if not settings.DEBUG:
return Response("Not in debug mode")
data = {}
for i in [settings, getattr(settings, '_wrapped')]:
if not i:
continue
for k, v in i.__dict__.items():
if k and k.isupper():
try:
json.dumps(v)
data[k] = v
except (json.JSONDecodeError, TypeError):
data[k] = str(v)
return Response(data)
\ No newline at end of file
...@@ -15,5 +15,4 @@ urlpatterns = [ ...@@ -15,5 +15,4 @@ urlpatterns = [
path('terminal/replay-storage/delete/', api.ReplayStorageDeleteAPI.as_view(), name='replay-storage-delete'), path('terminal/replay-storage/delete/', api.ReplayStorageDeleteAPI.as_view(), name='replay-storage-delete'),
path('terminal/command-storage/create/', api.CommandStorageCreateAPI.as_view(), name='command-storage-create'), path('terminal/command-storage/create/', api.CommandStorageCreateAPI.as_view(), name='command-storage-create'),
path('terminal/command-storage/delete/', api.CommandStorageDeleteAPI.as_view(), name='command-storage-delete'), path('terminal/command-storage/delete/', api.CommandStorageDeleteAPI.as_view(), name='command-storage-delete'),
path('django-settings/', api.DjangoSettingsAPI.as_view(), name='django-settings'),
] ]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment