Unverified Commit ffb4a262 authored by BaiJiangJie's avatar BaiJiangJie Committed by GitHub

Merge pull request #3266 from jumpserver/bugfix

[Update] 修改正则表达式
parents 11f710b3 5055a9f3
...@@ -45,7 +45,7 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend): ...@@ -45,7 +45,7 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend):
@staticmethod @staticmethod
def perform_query(pattern, queryset): def perform_query(pattern, queryset):
return queryset.filter(nodes__key__regex=pattern) return queryset.filter(nodes__key__regex=pattern).distinct()
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
node, has_query_arg = self.get_query_node(request) node, has_query_arg = self.get_query_node(request)
......
...@@ -112,7 +112,7 @@ class NodesRelationMixin: ...@@ -112,7 +112,7 @@ class NodesRelationMixin:
def get_all_nodes(self, flat=False): def get_all_nodes(self, flat=False):
nodes = [] nodes = []
for node in self.get_nodes(): for node in self.get_nodes():
_nodes = node.get_ancestor(with_self=True) _nodes = node.get_ancestors(with_self=True)
nodes.append(_nodes) nodes.append(_nodes)
if flat: if flat:
nodes = list(reduce(lambda x, y: set(x) | set(y), nodes)) nodes = list(reduce(lambda x, y: set(x) | set(y), nodes))
......
...@@ -10,7 +10,7 @@ from django.utils.translation import ugettext_lazy as _ ...@@ -10,7 +10,7 @@ from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext from django.utils.translation import ugettext
from django.core.cache import cache from django.core.cache import cache
from common.utils import get_logger, timeit from common.utils import get_logger, timeit, lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import set_current_org, get_current_org, tmp_to_org from orgs.utils import set_current_org, get_current_org, tmp_to_org
from orgs.models import Organization from orgs.models import Organization
...@@ -74,10 +74,6 @@ class TreeMixin: ...@@ -74,10 +74,6 @@ class TreeMixin:
t = time.time() t = time.time()
cache.set(key, t, ttl) cache.set(key, t, ttl)
@property
def _tree(self):
return self.__class__.tree()
@staticmethod @staticmethod
def refresh_user_tree_cache(): def refresh_user_tree_cache():
""" """
...@@ -108,43 +104,105 @@ class FamilyMixin: ...@@ -108,43 +104,105 @@ class FamilyMixin:
nodes_keys_clean.append(key) nodes_keys_clean.append(key)
return nodes_keys_clean return nodes_keys_clean
@property @classmethod
def children(self): def get_node_all_children_key_pattern(cls, key, with_self=True):
return self.get_children(with_self=False) pattern = r'^{0}:'.format(key)
if with_self:
@property pattern += r'|^{0}$'.format(key)
def all_children(self): return pattern
return self.get_all_children(with_self=False)
def get_children_key_pattern(self, with_self=False): @classmethod
pattern = r'^{0}:[0-9]+$'.format(self.key) def get_node_children_key_pattern(cls, key, with_self=True):
pattern = r'^{0}:[0-9]+$'.format(key)
if with_self: if with_self:
pattern += r'|^{0}$'.format(self.key) pattern += r'|^{0}$'.format(key)
return pattern return pattern
def get_children_key_pattern(self, with_self=False):
return self.get_node_children_key_pattern(self.key, with_self=with_self)
def get_all_children_pattern(self, with_self=False):
return self.get_node_all_children_key_pattern(self.key, with_self=with_self)
def is_children(self, other):
children_pattern = other.get_children_key_pattern(with_self=False)
return re.match(children_pattern, self.key)
def get_children(self, with_self=False): def get_children(self, with_self=False):
pattern = self.get_children_key_pattern(with_self=with_self) pattern = self.get_children_key_pattern(with_self=with_self)
return Node.objects.filter(key__regex=pattern) return Node.objects.filter(key__regex=pattern)
def get_all_children_pattern(self, with_self=False):
pattern = r'^{0}:'.format(self.key)
if with_self:
pattern += r'|^{0}$'.format(self.key)
return pattern
def get_all_children(self, with_self=False): def get_all_children(self, with_self=False):
pattern = self.get_all_children_pattern(with_self=with_self) pattern = self.get_all_children_pattern(with_self=with_self)
children = Node.objects.filter(key__regex=pattern) children = Node.objects.filter(key__regex=pattern)
return children return children
@property @property
def parents(self): def children(self):
return self.get_ancestor(with_self=False) return self.get_children(with_self=False)
@property
def all_children(self):
return self.get_all_children(with_self=False)
def create_child(self, value, _id=None):
with transaction.atomic():
child_key = self.get_next_child_key()
child = self.__class__.objects.create(
id=_id, key=child_key, value=value
)
return child
def get_ancestor(self, with_self=False): def get_next_child_key(self):
mark = self.child_mark
self.child_mark += 1
self.save()
return "{}:{}".format(self.key, mark)
def get_next_child_preset_name(self):
name = ugettext("New node")
values = [
child.value[child.value.rfind(' '):]
for child in self.get_children()
if child.value.startswith(name)
]
values = [int(value) for value in values if value.strip().isdigit()]
count = max(values) + 1 if values else 1
return '{} {}'.format(name, count)
# Parents
@classmethod
def get_node_ancestor_keys(cls, key, with_self=False):
parent_keys = []
key_list = key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def get_ancestor_keys(self, with_self=False):
return self.get_node_ancestor_keys(
self.key, with_self=with_self
)
@property
def ancestors(self):
return self.get_ancestors(with_self=False)
def get_ancestors(self, with_self=False):
ancestor_keys = self.get_ancestor_keys(with_self=with_self) ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys) return self.__class__.objects.filter(key__in=ancestor_keys)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
def is_parent(self, other):
return other.is_children(self)
@property @property
def parent(self): def parent(self):
if self.is_org_root(): if self.is_org_root():
...@@ -177,103 +235,33 @@ class FamilyMixin: ...@@ -177,103 +235,33 @@ class FamilyMixin:
return sibling return sibling
def get_family(self): def get_family(self):
ancestor = self.get_ancestor() ancestors = self.get_ancestors()
children = self.get_all_children() children = self.get_all_children()
return [*tuple(ancestor), self, *tuple(children)] return [*tuple(ancestors), self, *tuple(children)]
@classmethod
def get_nodes_ancestor_keys_by_key(cls, key, with_self=False):
parent_keys = []
key_list = key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def get_ancestor_keys(self, with_self=False):
return self.__class__.get_nodes_ancestor_keys_by_key(
self.key, with_self=with_self
)
def is_children(self, other):
pattern = r'^{0}:[0-9]+$'.format(self.key)
return re.match(pattern, other.key)
def is_parent(self, other):
return other.is_children(self)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
@property
def parents_keys(self, with_self=False):
keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
keys.append(':'.join(key_list))
key_list.pop()
return keys
def get_next_child_key(self):
mark = self.child_mark
self.child_mark += 1
self.save()
return "{}:{}".format(self.key, mark)
def get_next_child_preset_name(self):
name = ugettext("New node")
values = [
child.value[child.value.rfind(' '):]
for child in self.get_children()
if child.value.startswith(name)
]
values = [int(value) for value in values if value.strip().isdigit()]
count = max(values) + 1 if values else 1
return '{} {}'.format(name, count)
def create_child(self, value, _id=None):
with transaction.atomic():
child_key = self.get_next_child_key()
child = self.__class__.objects.create(
id=_id, key=child_key, value=value
)
return child
class FullValueMixin: class FullValueMixin:
_full_value = None
key = '' key = ''
@property @lazyproperty
def full_value(self): def full_value(self):
if self.is_org_root(): if self.is_org_root():
return self.value return self.value
if self._full_value is not None: value = self.tree().get_node_full_tag(self.key)
return self._full_value
value = self._tree.get_node_full_tag(self.key)
return value return value
class NodeAssetsMixin: class NodeAssetsMixin:
_assets_amount = None
key = '' key = ''
id = None id = None
@property @lazyproperty
def assets_amount(self): def assets_amount(self):
""" """
获取节点下所有资产数量速度太慢,所以需要重写,使用cache等方案 获取节点下所有资产数量速度太慢,所以需要重写,使用cache等方案
:return: :return:
""" """
if self._assets_amount is not None: amount = self.tree().assets_amount(self.key)
return self._assets_amount
amount = self._tree.assets_amount(self.key)
return amount return amount
def get_all_assets(self): def get_all_assets(self):
...@@ -298,13 +286,35 @@ class NodeAssetsMixin: ...@@ -298,13 +286,35 @@ class NodeAssetsMixin:
return self.get_all_assets().valid() return self.get_all_assets().valid()
@classmethod @classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None): def _get_nodes_all_assets(cls, nodes_keys):
"""
当节点比较多的时候,这种正则方式性能差极了
:param nodes_keys:
:return:
"""
from .asset import Asset from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
nodes_children_pattern = set()
for key in nodes_keys:
children_pattern = cls.get_node_all_children_key_pattern(key)
nodes_children_pattern.add(children_pattern)
pattern = '|'.join(nodes_children_pattern)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
@classmethod
def get_nodes_all_assets_ids(cls, nodes_keys):
nodes_keys = cls.clean_children_keys(nodes_keys) nodes_keys = cls.clean_children_keys(nodes_keys)
assets_ids = set() assets_ids = set()
for key in nodes_keys: for key in nodes_keys:
node_assets_ids = cls.tree().all_assets(key) node_assets_ids = cls.tree().all_assets(key)
assets_ids.update(set(node_assets_ids)) assets_ids.update(set(node_assets_ids))
return assets_ids
@classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None):
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
assets_ids = cls.get_nodes_all_assets_ids(nodes_keys)
if extra_assets_ids: if extra_assets_ids:
assets_ids.update(set(extra_assets_ids)) assets_ids.update(set(extra_assets_ids))
return Asset.objects.filter(id__in=assets_ids) return Asset.objects.filter(id__in=assets_ids)
......
...@@ -148,18 +148,12 @@ class SystemUser(AssetUser): ...@@ -148,18 +148,12 @@ class SystemUser(AssetUser):
return True, None return True, None
def get_all_assets(self): def get_all_assets(self):
from .node import Node from assets.models import Node
args = [Q(systemuser=self)]
pattern = set()
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
nodes_keys = Node.clean_children_keys(nodes_keys) assets_ids = set(self.assets.all().values_list('id', flat=True))
for key in nodes_keys: nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys)
pattern.add(r'^{0}$|^{0}:'.format(key)) assets_ids.update(nodes_assets_ids)
pattern = '|'.join(list(pattern)) assets = Asset.objects.filter(id__in=assets_ids)
if pattern:
args.append(Q(nodes__key__regex=pattern))
args = reduce(lambda x, y: x | y, args)
assets = Asset.objects.filter(args).distinct()
return assets return assets
class Meta: class Meta:
......
...@@ -70,10 +70,14 @@ class TreeService(Tree): ...@@ -70,10 +70,14 @@ class TreeService(Tree):
continue continue
self.nodes_assets_map[key].add(asset_id) self.nodes_assets_map[key].add(asset_id)
def all_children(self, nid, with_self=True, deep=False): def all_children_ids(self, nid, with_self=True):
children_ids = self.expand_tree(nid) children_ids = self.expand_tree(nid)
if not with_self: if not with_self:
next(children_ids) next(children_ids)
return list(children_ids)
def all_children(self, nid, with_self=True, deep=False):
children_ids = self.all_children_ids(nid, with_self=with_self)
return [self.get_node(i, deep=deep) for i in children_ids] return [self.get_node(i, deep=deep) for i in children_ids]
def ancestors(self, nid, with_self=False, deep=False): def ancestors(self, nid, with_self=False, deep=False):
......
...@@ -75,7 +75,7 @@ class AssetPermissionViewSet(viewsets.ModelViewSet): ...@@ -75,7 +75,7 @@ class AssetPermissionViewSet(viewsets.ModelViewSet):
return queryset return queryset
if not node: if not node:
return queryset.none() return queryset.none()
nodes = node.get_ancestor(with_self=True) nodes = node.get_ancestors(with_self=True)
queryset = queryset.filter(nodes__in=nodes) queryset = queryset.filter(nodes__in=nodes)
return queryset return queryset
...@@ -99,11 +99,11 @@ class AssetPermissionViewSet(viewsets.ModelViewSet): ...@@ -99,11 +99,11 @@ class AssetPermissionViewSet(viewsets.ModelViewSet):
for key in inherit_nodes_keys: for key in inherit_nodes_keys:
if key is None: if key is None:
continue continue
ancestor_keys = Node.get_nodes_ancestor_keys_by_key(key, with_self=True) ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True)
inherit_all_nodes.update(ancestor_keys) inherit_all_nodes.update(ancestor_keys)
queryset = queryset.filter( queryset = queryset.filter(
Q(assets__in=assets) | Q(nodes__key__in=inherit_all_nodes) Q(assets__in=assets) | Q(nodes__key__in=inherit_all_nodes)
) ).distinct()
return queryset return queryset
def filter_user(self, queryset): def filter_user(self, queryset):
......
...@@ -96,17 +96,12 @@ class AssetPermission(BasePermission): ...@@ -96,17 +96,12 @@ class AssetPermission(BasePermission):
) )
def get_all_assets(self): def get_all_assets(self):
args = [Q(granted_by_permissions=self)] from assets.models import Node
pattern = set()
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
nodes_keys = Node.clean_children_keys(nodes_keys) assets_ids = set(self.assets.all().values_list('id', flat=True))
for key in nodes_keys: nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys)
pattern.add(r'^{0}$|^{0}:'.format(key)) assets_ids.update(nodes_assets_ids)
pattern = '|'.join(list(pattern)) assets = Asset.objects.filter(id__in=assets_ids)
if pattern:
args.append(Q(nodes__key__regex=pattern))
args = reduce(lambda x, y: x | y, args)
assets = Asset.objects.filter(args).distinct()
return assets return assets
@classmethod @classmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment