Commit e9d0104a authored by BaiJiangJie's avatar BaiJiangJie

Merge branch 'dev' of github.com:jumpserver/jumpserver into fix_asset_count_dev

parents 7c694c68 757a31a5
...@@ -40,7 +40,9 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet): ...@@ -40,7 +40,9 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet):
permission_classes = (IsSuperUserOrAppUser,) permission_classes = (IsSuperUserOrAppUser,)
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()\
.prefetch_related('labels', 'nodes')\
.select_related('admin_user')
admin_user_id = self.request.query_params.get('admin_user_id') admin_user_id = self.request.query_params.get('admin_user_id')
node_id = self.request.query_params.get("node_id") node_id = self.request.query_params.get("node_id")
show_current_asset = self.request.query_params.get("show_current_asset") show_current_asset = self.request.query_params.get("show_current_asset")
...@@ -66,7 +68,6 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet): ...@@ -66,7 +68,6 @@ class AssetViewSet(IDInFilterMixin, LabelFilter, BulkModelViewSet):
queryset = queryset.filter( queryset = queryset.filter(
nodes__key__regex='^{}(:[0-9]+)*$'.format(node.key), nodes__key__regex='^{}(:[0-9]+)*$'.format(node.key),
).distinct() ).distinct()
return queryset return queryset
......
...@@ -59,42 +59,70 @@ class Asset(models.Model): ...@@ -59,42 +59,70 @@ class Asset(models.Model):
('Other', 'Other'), ('Other', 'Other'),
) )
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
ip = models.GenericIPAddressField(max_length=32, verbose_name=_('IP'), db_index=True) ip = models.GenericIPAddressField(max_length=32, verbose_name=_('IP'),
hostname = models.CharField(max_length=128, unique=True, verbose_name=_('Hostname')) db_index=True)
hostname = models.CharField(max_length=128, unique=True,
verbose_name=_('Hostname'))
port = models.IntegerField(default=22, verbose_name=_('Port')) port = models.IntegerField(default=22, verbose_name=_('Port'))
platform = models.CharField(max_length=128, choices=PLATFORM_CHOICES, default='Linux', verbose_name=_('Platform')) platform = models.CharField(max_length=128, choices=PLATFORM_CHOICES,
domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL) default='Linux', verbose_name=_('Platform'))
nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")) domain = models.ForeignKey("assets.Domain", null=True, blank=True,
related_name='assets', verbose_name=_("Domain"),
on_delete=models.SET_NULL)
nodes = models.ManyToManyField('assets.Node', default=default_node,
related_name='assets',
verbose_name=_("Nodes"))
is_active = models.BooleanField(default=True, verbose_name=_('Is active')) is_active = models.BooleanField(default=True, verbose_name=_('Is active'))
# Auth # Auth
admin_user = models.ForeignKey('assets.AdminUser', on_delete=models.PROTECT, null=True, verbose_name=_("Admin user")) admin_user = models.ForeignKey('assets.AdminUser', on_delete=models.PROTECT,
null=True, verbose_name=_("Admin user"))
# Some information # Some information
public_ip = models.GenericIPAddressField(max_length=32, blank=True, null=True, verbose_name=_('Public IP')) public_ip = models.GenericIPAddressField(max_length=32, blank=True,
number = models.CharField(max_length=32, null=True, blank=True, verbose_name=_('Asset number')) null=True,
verbose_name=_('Public IP'))
number = models.CharField(max_length=32, null=True, blank=True,
verbose_name=_('Asset number'))
# Collect # Collect
vendor = models.CharField(max_length=64, null=True, blank=True, verbose_name=_('Vendor')) vendor = models.CharField(max_length=64, null=True, blank=True,
model = models.CharField(max_length=54, null=True, blank=True, verbose_name=_('Model')) verbose_name=_('Vendor'))
sn = models.CharField(max_length=128, null=True, blank=True, verbose_name=_('Serial number')) model = models.CharField(max_length=54, null=True, blank=True,
verbose_name=_('Model'))
cpu_model = models.CharField(max_length=64, null=True, blank=True, verbose_name=_('CPU model')) sn = models.CharField(max_length=128, null=True, blank=True,
verbose_name=_('Serial number'))
cpu_model = models.CharField(max_length=64, null=True, blank=True,
verbose_name=_('CPU model'))
cpu_count = models.IntegerField(null=True, verbose_name=_('CPU count')) cpu_count = models.IntegerField(null=True, verbose_name=_('CPU count'))
cpu_cores = models.IntegerField(null=True, verbose_name=_('CPU cores')) cpu_cores = models.IntegerField(null=True, verbose_name=_('CPU cores'))
memory = models.CharField(max_length=64, null=True, blank=True, verbose_name=_('Memory')) memory = models.CharField(max_length=64, null=True, blank=True,
disk_total = models.CharField(max_length=1024, null=True, blank=True, verbose_name=_('Disk total')) verbose_name=_('Memory'))
disk_info = models.CharField(max_length=1024, null=True, blank=True, verbose_name=_('Disk info')) disk_total = models.CharField(max_length=1024, null=True, blank=True,
verbose_name=_('Disk total'))
os = models.CharField(max_length=128, null=True, blank=True, verbose_name=_('OS')) disk_info = models.CharField(max_length=1024, null=True, blank=True,
os_version = models.CharField(max_length=16, null=True, blank=True, verbose_name=_('OS version')) verbose_name=_('Disk info'))
os_arch = models.CharField(max_length=16, blank=True, null=True, verbose_name=_('OS arch'))
hostname_raw = models.CharField(max_length=128, blank=True, null=True, verbose_name=_('Hostname raw')) os = models.CharField(max_length=128, null=True, blank=True,
verbose_name=_('OS'))
labels = models.ManyToManyField('assets.Label', blank=True, related_name='assets', verbose_name=_("Labels")) os_version = models.CharField(max_length=16, null=True, blank=True,
created_by = models.CharField(max_length=32, null=True, blank=True, verbose_name=_('Created by')) verbose_name=_('OS version'))
date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created')) os_arch = models.CharField(max_length=16, blank=True, null=True,
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment')) verbose_name=_('OS arch'))
hostname_raw = models.CharField(max_length=128, blank=True, null=True,
verbose_name=_('Hostname raw'))
labels = models.ManyToManyField('assets.Label', blank=True,
related_name='assets',
verbose_name=_("Labels"))
created_by = models.CharField(max_length=32, null=True, blank=True,
verbose_name=_('Created by'))
date_created = models.DateTimeField(auto_now_add=True, null=True,
blank=True,
verbose_name=_('Date created'))
comment = models.TextField(max_length=128, default='', blank=True,
verbose_name=_('Comment'))
objects = AssetManager() objects = AssetManager()
...@@ -121,6 +149,22 @@ class Asset(models.Model): ...@@ -121,6 +149,22 @@ class Asset(models.Model):
nodes = self.nodes.all() or [Node.root()] nodes = self.nodes.all() or [Node.root()]
return nodes return nodes
@property
def nodes_cache_key(self):
key = "NODES_OF_{}".format(str(self.id))
return key
def get_nodes_or_cache(self):
cached = cache.get(self.nodes_cache_key)
if cached is not None:
return cached
nodes = list(self.get_nodes())
cache.set(self.nodes_cache_key, nodes, 3600)
return nodes
def expire_nodes_cache(self):
cache.delete(self.nodes_cache_key)
@property @property
def hardware_info(self): def hardware_info(self):
if self.cpu_count: if self.cpu_count:
......
...@@ -13,9 +13,6 @@ __all__ = ['Node'] ...@@ -13,9 +13,6 @@ __all__ = ['Node']
class Node(models.Model): class Node(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1' key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
# value = models.CharField(
# max_length=128, unique=True, verbose_name=_("Value")
# )
value = models.CharField(max_length=128, verbose_name=_("Value")) value = models.CharField(max_length=128, verbose_name=_("Value"))
child_mark = models.IntegerField(default=0) child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True) date_create = models.DateTimeField(auto_now_add=True)
...@@ -31,10 +28,11 @@ class Node(models.Model): ...@@ -31,10 +28,11 @@ class Node(models.Model):
@property @property
def full_value(self): def full_value(self):
if self == self.__class__.root(): ancestor = [a.value for a in self.ancestor]
if self.is_root():
return self.value return self.value
else: ancestor.append(self.value)
return '{} / {}'.format(self.parent.full_value, self.value) return ' / '.join(ancestor)
@property @property
def level(self): def level(self):
...@@ -108,7 +106,6 @@ class Node(models.Model): ...@@ -108,7 +106,6 @@ class Node(models.Model):
def parent(self): def parent(self):
if self.key == "0" or not self.key.startswith("0"): if self.key == "0" or not self.key.startswith("0"):
return self.__class__.root() return self.__class__.root()
parent_key = ":".join(self.key.split(":")[:-1]) parent_key = ":".join(self.key.split(":")[:-1])
try: try:
parent = self.__class__.objects.get(key=parent_key) parent = self.__class__.objects.get(key=parent_key)
...@@ -132,16 +129,17 @@ class Node(models.Model): ...@@ -132,16 +129,17 @@ class Node(models.Model):
@property @property
def ancestor(self): def ancestor(self):
if self.is_root():
ancestor = self.__class__.objects.filter(key='0')
else:
_key = self.key.split(':') _key = self.key.split(':')
ancestor_keys = [] ancestor_keys = []
if self.is_root():
return [self.__class__.root()]
for i in range(len(_key)-1): for i in range(len(_key)-1):
_key.pop() _key.pop()
ancestor_keys.append(':'.join(_key)) ancestor_keys.append(':'.join(_key))
return self.__class__.objects.filter(key__in=ancestor_keys) ancestor = self.__class__.objects.filter(key__in=ancestor_keys)
ancestor = list(ancestor)
return ancestor
@property @property
def ancestor_with_self(self): def ancestor_with_self(self):
......
...@@ -12,34 +12,11 @@ __all__ = [ ...@@ -12,34 +12,11 @@ __all__ = [
] ]
class NodeTMPSerializer(serializers.ModelSerializer):
parent = serializers.SerializerMethodField()
assets_amount = serializers.SerializerMethodField()
class Meta:
model = Node
fields = ['id', 'key', 'value', 'parent', 'assets_amount', 'is_node']
list_serializer_class = BulkListSerializer
@staticmethod
def get_parent(obj):
return obj.parent.id
@staticmethod
def get_assets_amount(obj):
return obj.get_all_assets().count()
def get_fields(self):
fields = super().get_fields()
field = fields["key"]
field.required = False
return fields
class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
""" """
资产的数据结构 资产的数据结构
""" """
nodes = serializers.SerializerMethodField()
class Meta: class Meta:
model = Asset model = Asset
...@@ -54,6 +31,10 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -54,6 +31,10 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
]) ])
return fields return fields
@staticmethod
def get_nodes(obj):
return [n.id for n in obj.get_nodes_or_cache()]
class AssetGrantedSerializer(serializers.ModelSerializer): class AssetGrantedSerializer(serializers.ModelSerializer):
""" """
......
...@@ -63,11 +63,14 @@ def on_system_user_assets_change(sender, instance=None, **kwargs): ...@@ -63,11 +63,14 @@ def on_system_user_assets_change(sender, instance=None, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_node_changed(sender, instance=None, **kwargs): def on_asset_node_changed(sender, instance=None, **kwargs):
if isinstance(instance, Asset) and kwargs['action'] == 'post_add': if isinstance(instance, Asset):
instance.expire_nodes_cache()
if kwargs['action'] == 'post_add':
logger.debug("Asset node change signal received") logger.debug("Asset node change signal received")
nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
system_users_assets = defaultdict(set) system_users_assets = defaultdict(set)
system_users = SystemUser.objects.filter(nodes__in=nodes) system_users = SystemUser.objects.filter(nodes__in=nodes)
# 清理节点缓存
for system_user in system_users: for system_user in system_users:
system_users_assets[system_user].update({instance}) system_users_assets[system_user].update({instance})
for system_user, assets in system_users_assets.items(): for system_user, assets in system_users_assets.items():
...@@ -76,9 +79,15 @@ def on_asset_node_changed(sender, instance=None, **kwargs): ...@@ -76,9 +79,15 @@ def on_asset_node_changed(sender, instance=None, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_assets_changed(sender, instance=None, **kwargs): def on_node_assets_changed(sender, instance=None, **kwargs):
if isinstance(instance, Node) and kwargs['action'] == 'post_add': if isinstance(instance, Node):
logger.debug("Node assets change signal received")
assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
# 清理资产节点缓存
for asset in assets:
asset.expire_nodes_cache()
if kwargs['action'] == 'post_add':
logger.debug("Node assets change signal received")
# 重新关联系统用户和资产的关系
system_users = SystemUser.objects.filter(nodes=instance) system_users = SystemUser.objects.filter(nodes=instance)
for system_user in system_users: for system_user in system_users:
system_user.assets.add(*tuple(assets)) system_user.assets.add(*tuple(assets))
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
# #
import os
import paramiko import paramiko
from paramiko.ssh_exception import SSHException
from common.utils import get_object_or_none from common.utils import get_object_or_none
from .models import Asset, SystemUser, Label from .models import Asset, SystemUser, Label
...@@ -49,22 +50,23 @@ def test_gateway_connectability(gateway): ...@@ -49,22 +50,23 @@ def test_gateway_connectability(gateway):
""" """
client = paramiko.SSHClient() client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
proxy = paramiko.SSHClient()
proxy_command = [ proxy.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
"ssh", "{}@{}".format(gateway.username, gateway.ip), proxy.set_missing_host_key_policy(paramiko.AutoAddPolicy())
"-p", str(gateway.port), "-W", "127.0.0.1:{}".format(gateway.port),
]
if gateway.password:
proxy_command.insert(0, "sshpass -p '{}'".format(gateway.password))
if gateway.private_key:
proxy_command.append("-i {}".format(gateway.private_key_file))
try: try:
sock = paramiko.ProxyCommand(" ".join(proxy_command)) proxy.connect(gateway.ip, username=gateway.username,
except paramiko.ProxyCommandFailure as e: password=gateway.password,
pkey=gateway.private_key_obj)
except(paramiko.AuthenticationException,
paramiko.BadAuthenticationType,
SSHException) as e:
return False, str(e) return False, str(e)
sock = proxy.get_transport().open_channel(
'direct-tcpip', ('127.0.0.1', gateway.port), ('127.0.0.1', 0)
)
try: try:
client.connect("127.0.0.1", port=gateway.port, client.connect("127.0.0.1", port=gateway.port,
username=gateway.username, username=gateway.username,
......
...@@ -147,13 +147,8 @@ class UserGrantedNodeAssetsApi(ListAPIView): ...@@ -147,13 +147,8 @@ class UserGrantedNodeAssetsApi(ListAPIView):
user = get_object_or_404(User, id=user_id) user = get_object_or_404(User, id=user_id)
else: else:
user = self.request.user user = self.request.user
node = get_object_or_404(Node, id=node_id)
nodes = AssetPermissionUtil.get_user_nodes_with_assets(user) nodes = AssetPermissionUtil.get_user_nodes_with_assets(user)
node = get_object_or_none(Node, id=node_id)
if not node:
unnode = [node for node in nodes if node.name == 'Unnode']
node = unnode[0] if unnode else None
assets = nodes.get(node, []) assets = nodes.get(node, [])
for asset, system_users in assets.items(): for asset, system_users in assets.items():
asset.system_users_granted = system_users asset.system_users_granted = system_users
......
...@@ -15,7 +15,7 @@ logger = get_logger(__file__) ...@@ -15,7 +15,7 @@ logger = get_logger(__file__)
class Tree: class Tree:
def __init__(self): def __init__(self):
self.__all_nodes = list(Node.objects.all()) self.__all_nodes = list(Node.objects.all().prefetch_related('assets'))
self.__node_asset_map = defaultdict(set) self.__node_asset_map = defaultdict(set)
self.nodes = defaultdict(dict) self.nodes = defaultdict(dict)
self.root = Node.root() self.root = Node.root()
...@@ -134,7 +134,7 @@ class AssetPermissionUtil: ...@@ -134,7 +134,7 @@ class AssetPermissionUtil:
_assets = cls.get_user_group_assets(group) _assets = cls.get_user_group_assets(group)
tree = Tree() tree = Tree()
for asset, _system_users in _assets.items(): for asset, _system_users in _assets.items():
_nodes = asset.get_nodes() _nodes = asset.get_nodes_or_cache()
tree.add_nodes(_nodes) tree.add_nodes(_nodes)
for node in _nodes: for node in _nodes:
tree.nodes[node][asset].update(_system_users) tree.nodes[node][asset].update(_system_users)
......
...@@ -123,6 +123,7 @@ def start_gunicorn(): ...@@ -123,6 +123,7 @@ def start_gunicorn():
'gunicorn', 'jumpserver.wsgi', 'gunicorn', 'jumpserver.wsgi',
'-b', bind, '-b', bind,
'-w', str(WORKERS), '-w', str(WORKERS),
'-k', 'eventlet',
'--access-logformat', log_format, '--access-logformat', log_format,
'-p', pid_file, '-p', pid_file,
] ]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment