Unverified Commit 59927ffc authored by 老广's avatar 老广 Committed by GitHub

Merge pull request #2889 from jumpserver/perf

[Update] 优化协议
parents 303cf41b aac5eed9
...@@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _ ...@@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _
from django import forms from django import forms
from orgs.mixins import OrgModelForm from orgs.mixins import OrgModelForm
from assets.models import SystemUser, Protocol from assets.models import SystemUser
from ..models import RemoteApp from ..models import RemoteApp
from .. import const from .. import const
...@@ -88,9 +88,7 @@ class RemoteAppCreateUpdateForm(RemoteAppTypeForms, OrgModelForm): ...@@ -88,9 +88,7 @@ class RemoteAppCreateUpdateForm(RemoteAppTypeForms, OrgModelForm):
# 过滤RDP资产和系统用户 # 过滤RDP资产和系统用户
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
field_asset = self.fields['asset'] field_asset = self.fields['asset']
field_asset.queryset = field_asset.queryset.filter( field_asset.queryset = field_asset.queryset.has_protocol('rdp')
protocols__name=Protocol.PROTOCOL_RDP
)
field_system_user = self.fields['system_user'] field_system_user = self.fields['system_user']
field_system_user.queryset = field_system_user.queryset.filter( field_system_user.queryset = field_system_user.queryset.filter(
protocol=SystemUser.PROTOCOL_RDP protocol=SystemUser.PROTOCOL_RDP
......
...@@ -37,7 +37,7 @@ __all__ = [ ...@@ -37,7 +37,7 @@ __all__ = [
] ]
class AssetViewSet(LabelFilter, ApiMessageMixin, OrgBulkModelViewSet): class AssetViewSet(LabelFilter, OrgBulkModelViewSet):
""" """
API endpoint that allows Asset to be viewed or edited. API endpoint that allows Asset to be viewed or edited.
""" """
......
...@@ -130,7 +130,7 @@ class NodeChildrenAsTreeApi(generics.ListAPIView): ...@@ -130,7 +130,7 @@ class NodeChildrenAsTreeApi(generics.ListAPIView):
include_assets = self.request.query_params.get('assets', '0') == '1' include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets: if not include_assets:
return queryset return queryset
assets = self.node.get_assets().prefetch_related("protocols").only( assets = self.node.get_assets().only(
"id", "hostname", "ip", 'platform', "os", "org_id", "id", "hostname", "ip", 'platform', "os", "org_id",
) )
for asset in assets: for asset in assets:
......
...@@ -6,33 +6,27 @@ from django.utils.translation import gettext_lazy as _ ...@@ -6,33 +6,27 @@ from django.utils.translation import gettext_lazy as _
from common.utils import get_logger from common.utils import get_logger
from orgs.mixins import OrgModelForm from orgs.mixins import OrgModelForm
from ..models import Asset, Protocol, Node from ..models import Asset, Node
logger = get_logger(__file__) logger = get_logger(__file__)
__all__ = [ __all__ = [
'AssetCreateForm', 'AssetUpdateForm', 'AssetBulkUpdateForm', 'AssetCreateForm', 'AssetUpdateForm', 'AssetBulkUpdateForm', 'ProtocolForm',
'ProtocolForm'
] ]
class ProtocolForm(forms.ModelForm): class ProtocolForm(forms.Form):
class Meta: name = forms.ChoiceField(
model = Protocol choices=Asset.PROTOCOL_CHOICES, label=_("Name"), initial='ssh',
fields = ['name', 'port'] widget=forms.Select(attrs={'class': 'form-control protocol-name'})
widgets = { )
'name': forms.Select(attrs={ port = forms.IntegerField(
'class': 'form-control protocol-name' max_value=65534, min_value=1, label=_("Port"), initial=22,
}), widget=forms.TextInput(attrs={'class': 'form-control protocol-port'})
'port': forms.TextInput(attrs={ )
'class': 'form-control protocol-port'
}),
}
class AssetCreateForm(OrgModelForm): class AssetCreateForm(OrgModelForm):
PROTOCOL_CHOICES = Protocol.PROTOCOL_CHOICES
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if not self.data: if not self.data:
......
...@@ -3,14 +3,6 @@ ...@@ -3,14 +3,6 @@
from django.db import migrations from django.db import migrations
def migrate_assets_protocol(apps, schema_editor):
asset_model = apps.get_model("assets", "Asset")
db_alias = schema_editor.connection.alias
assets = asset_model.objects.using(db_alias).all()
for asset in assets:
asset.protocols.create(name=asset.protocol, port=asset.port)
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
...@@ -18,5 +10,4 @@ class Migration(migrations.Migration): ...@@ -18,5 +10,4 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(migrate_assets_protocol),
] ]
# Generated by Django 2.1.7 on 2019-07-05 05:48
from django.db import migrations
from django.db.models import F
from django.db.models import CharField, Value as V
from django.db.models.functions import Concat
def migrate_assets_protocol(apps, schema_editor):
asset_model = apps.get_model("assets", "Asset")
db_alias = schema_editor.connection.alias
assets = asset_model.objects.using(db_alias).all().annotate(
protocols_new=Concat(
'protocol', V('/'), 'port',
output_field=CharField(),
),
)
assets.update(protocols=F('protocols_new'))
class Migration(migrations.Migration):
dependencies = [
('assets', '0033_auto_20190624_2108'),
]
operations = [
migrations.RemoveField(
model_name='asset',
name='protocols',
),
migrations.AddField(
model_name='asset',
name='protocols',
field=CharField(blank=True, default='ssh/22', max_length=128, verbose_name='Protocols'),
),
migrations.RunPython(migrate_assets_protocol),
migrations.DeleteModel(name='Protocol'),
]
...@@ -6,16 +6,16 @@ import uuid ...@@ -6,16 +6,16 @@ import uuid
import logging import logging
import random import random
from functools import reduce from functools import reduce
from collections import OrderedDict
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from .user import AdminUser, SystemUser from .user import AdminUser, SystemUser
from .utils import Connectivity from .utils import Connectivity
from orgs.mixins import OrgModelMixin, OrgManager from orgs.mixins import OrgModelMixin, OrgManager
__all__ = ['Asset', 'Protocol'] __all__ = ['Asset']
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,8 +45,12 @@ class AssetQuerySet(models.QuerySet): ...@@ -45,8 +45,12 @@ class AssetQuerySet(models.QuerySet):
def valid(self): def valid(self):
return self.active() return self.active()
def has_protocol(self, name):
return self.filter(protocols__contains=name)
class Protocol(models.Model): class ProtocolsMixin:
protocols = ''
PROTOCOL_SSH = 'ssh' PROTOCOL_SSH = 'ssh'
PROTOCOL_RDP = 'rdp' PROTOCOL_RDP = 'rdp'
PROTOCOL_TELNET = 'telnet' PROTOCOL_TELNET = 'telnet'
...@@ -57,19 +61,42 @@ class Protocol(models.Model): ...@@ -57,19 +61,42 @@ class Protocol(models.Model):
(PROTOCOL_TELNET, 'telnet (beta)'), (PROTOCOL_TELNET, 'telnet (beta)'),
(PROTOCOL_VNC, 'vnc'), (PROTOCOL_VNC, 'vnc'),
) )
PORT_VALIDATORS = [MaxValueValidator(65535), MinValueValidator(1)]
id = models.UUIDField(default=uuid.uuid4, primary_key=True) @property
name = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, def protocols_as_list(self):
default=PROTOCOL_SSH, verbose_name=_("Name")) if not self.protocols:
port = models.IntegerField(default=22, verbose_name=_("Port"), return []
validators=PORT_VALIDATORS) return self.protocols.split(' ')
def __str__(self): @property
return "{}/{}".format(self.name, self.port) def protocols_as_dict(self):
d = OrderedDict()
protocols = self.protocols_as_list
for i in protocols:
if '/' not in i:
continue
name, port = i.split('/')[:2]
if not all([name, port]):
continue
d[name] = int(port)
return d
@property
def protocols_as_json(self):
return [
{"name": name, "port": port}
for name, port in self.protocols_as_dict.items()
]
def has_protocol(self, name):
return name in self.protocols_as_dict
@property
def ssh_port(self):
return self.protocols_as_dict.get("ssh", 22)
class Asset(OrgModelMixin): class Asset(ProtocolsMixin, OrgModelMixin):
# Important # Important
PLATFORM_CHOICES = ( PLATFORM_CHOICES = (
('Linux', 'Linux'), ('Linux', 'Linux'),
...@@ -84,12 +111,12 @@ class Asset(OrgModelMixin): ...@@ -84,12 +111,12 @@ class Asset(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True) ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True)
hostname = models.CharField(max_length=128, verbose_name=_('Hostname')) hostname = models.CharField(max_length=128, verbose_name=_('Hostname'))
protocol = models.CharField(max_length=128, default=Protocol.PROTOCOL_SSH, protocol = models.CharField(max_length=128, default=ProtocolsMixin.PROTOCOL_SSH,
choices=Protocol.PROTOCOL_CHOICES, choices=ProtocolsMixin.PROTOCOL_CHOICES,
verbose_name=_('Protocol')) verbose_name=_('Protocol'))
port = models.IntegerField(default=22, verbose_name=_('Port')) port = models.IntegerField(default=22, verbose_name=_('Port'))
protocols = models.ManyToManyField('Protocol', verbose_name=_("Protocol")) protocols = models.CharField(max_length=128, default='ssh/22', blank=True, verbose_name=_("Protocols"))
platform = models.CharField(max_length=128, choices=PLATFORM_CHOICES, default='Linux', verbose_name=_('Platform')) platform = models.CharField(max_length=128, choices=PLATFORM_CHOICES, default='Linux', verbose_name=_('Platform'))
domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL) domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL)
nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")) nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes"))
...@@ -136,41 +163,9 @@ class Asset(OrgModelMixin): ...@@ -136,41 +163,9 @@ class Asset(OrgModelMixin):
warning = '' warning = ''
if not self.is_active: if not self.is_active:
warning += ' inactive' warning += ' inactive'
else: if warning:
return True, '' return False, warning
return False, warning return True, warning
@property
def protocols_name(self):
names = []
for protocol in self.protocols.all():
names.append(protocol.name)
return names
def has_protocol(self, name):
return name in self.protocols_name
def get_protocol_by_name(self, name):
for i in self.protocols.all():
if i.name.lower() == name.lower():
return i
return None
@property
def protocol_ssh(self):
return self.get_protocol_by_name("ssh")
@property
def protocol_rdp(self):
return self.get_protocol_by_name("rdp")
@property
def ssh_port(self):
if self.protocol_ssh:
port = self.protocol_ssh.port
else:
port = 22
return port
def is_windows(self): def is_windows(self):
if self.platform in ("Windows", "Windows2016"): if self.platform in ("Windows", "Windows2016"):
...@@ -278,10 +273,7 @@ class Asset(OrgModelMixin): ...@@ -278,10 +273,7 @@ class Asset(OrgModelMixin):
'id': self.id, 'id': self.id,
'hostname': self.hostname, 'hostname': self.hostname,
'ip': self.ip, 'ip': self.ip,
'protocols': [ 'protocols': self.protocols_as_list,
{"name": p.name, "port": p.port}
for p in self.protocols.all()
],
'platform': self.platform, 'platform': self.platform,
} }
} }
...@@ -314,7 +306,7 @@ class Asset(OrgModelMixin): ...@@ -314,7 +306,7 @@ class Asset(OrgModelMixin):
created_by='Fake') created_by='Fake')
try: try:
asset.save() asset.save()
asset.protocols.create(name="ssh", port=22) asset.protocols = 'ssh/22'
if nodes and len(nodes) > 3: if nodes and len(nodes) > 3:
_nodes = random.sample(nodes, 3) _nodes = random.sample(nodes, 3)
else: else:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from rest_framework import serializers from rest_framework import serializers
from rest_framework.validators import ValidationError
from django.db.models import Prefetch from django.db.models import Prefetch
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import BulkOrgResourceModelSerializer from orgs.mixins import BulkOrgResourceModelSerializer
from common.serializers import AdaptedBulkListSerializer from common.serializers import AdaptedBulkListSerializer
from ..models import Asset, Protocol, Node, Label from ..models import Asset, Node, Label
from .base import ConnectivitySerializer from .base import ConnectivitySerializer
__all__ = [ __all__ = [
'AssetSerializer', 'AssetSimpleSerializer', 'AssetSerializer', 'AssetSimpleSerializer',
'ProtocolSerializer', 'ProtocolsRelatedField', 'ProtocolsField',
] ]
class ProtocolSerializer(serializers.ModelSerializer): class ProtocolField(serializers.RegexField):
class Meta: protocols = '|'.join(dict(Asset.PROTOCOL_CHOICES).keys())
model = Protocol default_error_messages = {
fields = ["name", "port"] 'invalid': _('Protocol format should {}/{}'.format(protocols, '1-65535'))
}
regex = r'^(%s)/(\d{1,5})$' % protocols
def __init__(self, *args, **kwargs):
super().__init__(self.regex, **kwargs)
def validate_duplicate_protocols(values):
errors = []
names = []
for value in values:
if not value or '/' not in value:
continue
name = value.split('/')[0]
if name in names:
errors.append(_("Protocol duplicate: {}").format(name))
names.append(name)
errors.append('')
if any(errors):
raise serializers.ValidationError(errors)
class ProtocolsRelatedField(serializers.RelatedField):
def to_representation(self, value):
return str(value)
def to_internal_value(self, data): class ProtocolsField(serializers.ListField):
if isinstance(data, dict): default_validators = [validate_duplicate_protocols]
return data
if '/' not in data: def __init__(self, *args, **kwargs):
raise ValidationError("protocol not contain /: {}".format(data)) kwargs['child'] = ProtocolField()
v = data.split("/") kwargs['allow_null'] = True
if len(v) != 2: kwargs['allow_empty'] = True
raise ValidationError("protocol format should be name/port: {}".format(data)) kwargs['min_length'] = 1
name, port = v kwargs['max_length'] = 4
cleaned_data = {"name": name, "port": port} super().__init__(*args, **kwargs)
return cleaned_data
def to_representation(self, value):
if not value:
return []
return value.split(' ')
class AssetSerializer(BulkOrgResourceModelSerializer): class AssetSerializer(BulkOrgResourceModelSerializer):
protocols = ProtocolsRelatedField( protocols = ProtocolsField(label=_('Protocols'), required=False)
many=True, queryset=Protocol.objects.all(), label=_("Protocols")
)
connectivity = ConnectivitySerializer(read_only=True, label=_("Connectivity")) connectivity = ConnectivitySerializer(read_only=True, label=_("Connectivity"))
""" """
...@@ -79,66 +97,32 @@ class AssetSerializer(BulkOrgResourceModelSerializer): ...@@ -79,66 +97,32 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
queryset = queryset.prefetch_related( queryset = queryset.prefetch_related(
Prefetch('nodes', queryset=Node.objects.all().only('id')), Prefetch('nodes', queryset=Node.objects.all().only('id')),
Prefetch('labels', queryset=Label.objects.all().only('id')), Prefetch('labels', queryset=Label.objects.all().only('id')),
'protocols'
).select_related('admin_user', 'domain') ).select_related('admin_user', 'domain')
return queryset return queryset
@staticmethod def compatible_with_old_protocol(self, validated_data):
def validate_protocols(attr):
protocols_serializer = ProtocolSerializer(data=attr, many=True)
protocols_serializer.is_valid(raise_exception=True)
protocols_name = [i.get("name", "ssh") for i in attr]
errors = [{} for i in protocols_name]
for i, name in enumerate(protocols_name):
if name in protocols_name[:i]:
errors[i] = {"name": _("Protocol duplicate: {}").format(name)}
if any(errors):
raise ValidationError(errors)
return attr
def create(self, validated_data):
protocols_data = validated_data.pop("protocols", []) protocols_data = validated_data.pop("protocols", [])
# 兼容老的api # 兼容老的api
protocol = validated_data.get("protocol") name = validated_data.get("protocol")
port = validated_data.get("port") port = validated_data.get("port")
if not protocols_data and protocol and port: if not protocols_data and name and port:
protocols_data = [{"name": protocol, "port": port}] protocols_data.insert(0, '/'.join([name, str(port)]))
elif not name and not port and protocols_data:
if not protocol and not port and protocols_data: protocol = protocols_data[0].split('/')
validated_data["protocol"] = protocols_data[0]["name"] validated_data["protocol"] = protocol[0]
validated_data["port"] = protocols_data[0]["port"] validated_data["port"] = int(protocol[1])
if validated_data:
validated_data["protocols"] = ' '.join(protocols_data)
protocols_serializer = ProtocolSerializer(data=protocols_data, many=True) def create(self, validated_data):
protocols_serializer.is_valid(raise_exception=True) self.compatible_with_old_protocol(validated_data)
protocols = protocols_serializer.save()
instance = super().create(validated_data) instance = super().create(validated_data)
instance.protocols.set(protocols)
return instance return instance
def update(self, instance, validated_data): def update(self, instance, validated_data):
protocols_data = validated_data.pop("protocols", []) self.compatible_with_old_protocol(validated_data)
return super().update(instance, validated_data)
# 兼容老的api
protocol = validated_data.get("protocol")
port = validated_data.get("port")
if not protocols_data and protocol and port:
protocols_data = [{"name": protocol, "port": port}]
if not protocol and not port and protocols_data:
validated_data["protocol"] = protocols_data[0]["name"]
validated_data["port"] = protocols_data[0]["port"]
protocols = None
if protocols_data:
protocols_serializer = ProtocolSerializer(data=protocols_data, many=True)
protocols_serializer.is_valid(raise_exception=True)
protocols = protocols_serializer.save()
instance = super().update(instance, validated_data)
if protocols:
instance.protocols.all().delete()
instance.protocols.set(protocols)
return instance
class AssetSimpleSerializer(serializers.ModelSerializer): class AssetSimpleSerializer(serializers.ModelSerializer):
......
...@@ -70,11 +70,7 @@ ...@@ -70,11 +70,7 @@
</tr> </tr>
<tr> <tr>
<td>{% trans 'Protocol' %}</td> <td>{% trans 'Protocol' %}</td>
<td> <td>{{ asset.protocols }}</td>
{% for protocol in asset.protocols.all %}
<b>{{ protocol }}</b>
{% endfor %}
</td>
</tr> </tr>
<tr> <tr>
<td>{% trans 'Admin user' %}:</td> <td>{% trans 'Admin user' %}:</td>
......
# coding:utf-8 # coding:utf-8
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
import csv
import json
import uuid
import codecs
import chardet
from io import StringIO
from django.db import transaction
from django.contrib import messages from django.contrib import messages
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.views.generic import TemplateView, ListView, View from django.views.generic import TemplateView, ListView
from django.views.generic.edit import CreateView, DeleteView, FormView, UpdateView from django.views.generic.edit import FormMixin
from django.views.generic.edit import CreateView, DeleteView, UpdateView
from django.urls import reverse_lazy from django.urls import reverse_lazy
from django.views.generic.detail import DetailView from django.views.generic.detail import DetailView
from django.http import HttpResponse, JsonResponse
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.core.cache import cache from django.core.cache import cache
from django.utils import timezone
from django.shortcuts import redirect from django.shortcuts import redirect
from django.contrib.messages.views import SuccessMessageMixin from django.contrib.messages.views import SuccessMessageMixin
from django.forms.formsets import formset_factory from django.forms.formsets import formset_factory
from common.mixins import JSONResponseMixin
from common.utils import get_object_or_none, get_logger from common.utils import get_object_or_none, get_logger
from common.permissions import PermissionsMixin, IsOrgAdmin, IsValidUser from common.permissions import PermissionsMixin, IsOrgAdmin, IsValidUser
from common.const import ( from common.const import (
create_success_msg, update_success_msg, KEY_CACHE_RESOURCES_ID create_success_msg, update_success_msg, KEY_CACHE_RESOURCES_ID
) )
from .. import forms from .. import forms
from ..models import Asset, AdminUser, SystemUser, Label, Node, Domain from ..models import Asset, SystemUser, Label, Node
__all__ = [ __all__ = [
...@@ -87,7 +75,7 @@ class UserAssetListView(PermissionsMixin, TemplateView): ...@@ -87,7 +75,7 @@ class UserAssetListView(PermissionsMixin, TemplateView):
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView): class AssetCreateView(PermissionsMixin, FormMixin, TemplateView):
model = Asset model = Asset
form_class = forms.AssetCreateForm form_class = forms.AssetCreateForm
template_name = 'assets/asset_create.html' template_name = 'assets/asset_create.html'
...@@ -112,16 +100,6 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView): ...@@ -112,16 +100,6 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView):
formset = ProtocolFormset() formset = ProtocolFormset()
return formset return formset
def form_valid(self, form):
formset = self.get_protocol_formset()
valid = formset.is_valid()
if not valid:
return self.form_invalid(form)
protocols = formset.save()
instance = super().form_valid(form)
instance.protocols.set(protocols)
return instance
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
formset = self.get_protocol_formset() formset = self.get_protocol_formset()
context = { context = {
...@@ -132,8 +110,32 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView): ...@@ -132,8 +110,32 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView):
kwargs.update(context) kwargs.update(context)
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
def get_success_message(self, cleaned_data):
return create_success_msg % ({"name": cleaned_data["hostname"]}) class AssetUpdateView(PermissionsMixin, UpdateView):
model = Asset
form_class = forms.AssetUpdateForm
template_name = 'assets/asset_update.html'
success_url = reverse_lazy('assets:asset-list')
permission_classes = [IsOrgAdmin]
def get_protocol_formset(self):
ProtocolFormset = formset_factory(forms.ProtocolForm, extra=0, min_num=1, max_num=5)
if self.request.method == "POST":
formset = ProtocolFormset(self.request.POST)
else:
initial_data = self.object.protocols_as_json
formset = ProtocolFormset(initial=initial_data)
return formset
def get_context_data(self, **kwargs):
formset = self.get_protocol_formset()
context = {
'app': _('Assets'),
'action': _('Update asset'),
'formset': formset,
}
kwargs.update(context)
return super().get_context_data(**kwargs)
class AssetBulkUpdateView(PermissionsMixin, ListView): class AssetBulkUpdateView(PermissionsMixin, ListView):
...@@ -177,36 +179,6 @@ class AssetBulkUpdateView(PermissionsMixin, ListView): ...@@ -177,36 +179,6 @@ class AssetBulkUpdateView(PermissionsMixin, ListView):
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
class AssetUpdateView(PermissionsMixin, SuccessMessageMixin, UpdateView):
model = Asset
form_class = forms.AssetUpdateForm
template_name = 'assets/asset_update.html'
success_url = reverse_lazy('assets:asset-list')
permission_classes = [IsOrgAdmin]
def get_protocol_formset(self):
ProtocolFormset = formset_factory(forms.ProtocolForm, extra=0, min_num=1, max_num=5)
if self.request.method == "POST":
formset = ProtocolFormset(self.request.POST)
else:
initial_data = [{"name": p.name, "port": p.port} for p in self.object.protocols.all()]
formset = ProtocolFormset(initial=initial_data)
return formset
def get_context_data(self, **kwargs):
formset = self.get_protocol_formset()
context = {
'app': _('Assets'),
'action': _('Update asset'),
'formset': formset,
}
kwargs.update(context)
return super().get_context_data(**kwargs)
def get_success_message(self, cleaned_data):
return update_success_msg % ({"name": cleaned_data["hostname"]})
class AssetDeleteView(PermissionsMixin, DeleteView): class AssetDeleteView(PermissionsMixin, DeleteView):
model = Asset model = Asset
template_name = 'delete_confirm.html' template_name = 'delete_confirm.html'
...@@ -222,7 +194,7 @@ class AssetDetailView(PermissionsMixin, DetailView): ...@@ -222,7 +194,7 @@ class AssetDetailView(PermissionsMixin, DetailView):
def get_queryset(self): def get_queryset(self):
return super().get_queryset().prefetch_related( return super().get_queryset().prefetch_related(
"nodes", "labels", "protocols" "nodes", "labels",
).select_related('admin_user', 'domain') ).select_related('admin_user', 'domain')
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
......
...@@ -135,9 +135,7 @@ function getSelectedAssetsNode() { ...@@ -135,9 +135,7 @@ function getSelectedAssetsNode() {
var assetsNode = []; var assetsNode = [];
nodes.forEach(function (node) { nodes.forEach(function (node) {
if (node.meta.type === 'asset' && !node.isHidden) { if (node.meta.type === 'asset' && !node.isHidden) {
var protocols = $.map(node.meta.asset.protocols, function (v) { var protocols = node.meta.asset.protocols;
return v.name
});
if (assetsNodeId.indexOf(node.id) === -1 && protocols.indexOf("ssh") > -1) { if (assetsNodeId.indexOf(node.id) === -1 && protocols.indexOf("ssh") > -1) {
assetsNodeId.push(node.id); assetsNodeId.push(node.id);
assetsNode.push(node) assetsNode.push(node)
......
...@@ -126,7 +126,7 @@ class GenerateTree: ...@@ -126,7 +126,7 @@ class GenerateTree:
for asset, system_users in assets.items(): for asset, system_users in assets.items():
self.add_asset(asset, system_users) self.add_asset(asset, system_users)
#@timeit # #@timeit
def add_asset(self, asset, system_users=None): def add_asset(self, asset, system_users=None):
nodes = asset.nodes.all() nodes = asset.nodes.all()
nodes = self.node_util.get_nodes_by_queryset(nodes) nodes = self.node_util.get_nodes_by_queryset(nodes)
...@@ -493,12 +493,13 @@ class AssetPermissionUtil(AssetPermissionCacheMixin): ...@@ -493,12 +493,13 @@ class AssetPermissionUtil(AssetPermissionCacheMixin):
pattern.add(r'^{0}$|^{0}:'.format(node.key)) pattern.add(r'^{0}$|^{0}:'.format(node.key))
pattern = '|'.join(list(pattern)) pattern = '|'.join(list(pattern))
if pattern: if pattern:
assets = Asset.objects.filter(nodes__key__regex=pattern)\ assets = Asset.objects.filter(nodes__key__regex=pattern) \
.prefetch_related('nodes', "protocols")\ .prefetch_related('nodes')\
.only(*self.assets_only)\ .only(*self.assets_only)\
.distinct() .distinct()
else: else:
assets = [] assets = []
assets = list(assets)
self.tree.add_assets_without_system_users(assets) self.tree.add_assets_without_system_users(assets)
assets = self.tree.get_assets() assets = self.tree.get_assets()
self._assets = assets self._assets = assets
...@@ -598,7 +599,7 @@ def parse_asset_to_tree_node(node, asset, system_users): ...@@ -598,7 +599,7 @@ def parse_asset_to_tree_node(node, asset, system_users):
'id': asset.id, 'id': asset.id,
'hostname': asset.hostname, 'hostname': asset.hostname,
'ip': asset.ip, 'ip': asset.ip,
'protocols': [str(p) for p in asset.protocols.all()], 'protocols': asset.protocols_as_list,
'platform': asset.platform, 'platform': asset.platform,
'domain': None if not asset.domain else asset.domain.id, 'domain': None if not asset.domain else asset.domain.id,
'is_active': asset.is_active, 'is_active': asset.is_active,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment