Unverified Commit 59927ffc authored by 老广's avatar 老广 Committed by GitHub

Merge pull request #2889 from jumpserver/perf

[Update] 优化协议
parents 303cf41b aac5eed9
......@@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _
from django import forms
from orgs.mixins import OrgModelForm
from assets.models import SystemUser, Protocol
from assets.models import SystemUser
from ..models import RemoteApp
from .. import const
......@@ -88,9 +88,7 @@ class RemoteAppCreateUpdateForm(RemoteAppTypeForms, OrgModelForm):
# 过滤RDP资产和系统用户
super().__init__(*args, **kwargs)
field_asset = self.fields['asset']
field_asset.queryset = field_asset.queryset.filter(
protocols__name=Protocol.PROTOCOL_RDP
)
field_asset.queryset = field_asset.queryset.has_protocol('rdp')
field_system_user = self.fields['system_user']
field_system_user.queryset = field_system_user.queryset.filter(
protocol=SystemUser.PROTOCOL_RDP
......
......@@ -37,7 +37,7 @@ __all__ = [
]
class AssetViewSet(LabelFilter, ApiMessageMixin, OrgBulkModelViewSet):
class AssetViewSet(LabelFilter, OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
......
......@@ -130,7 +130,7 @@ class NodeChildrenAsTreeApi(generics.ListAPIView):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets:
return queryset
assets = self.node.get_assets().prefetch_related("protocols").only(
assets = self.node.get_assets().only(
"id", "hostname", "ip", 'platform', "os", "org_id",
)
for asset in assets:
......
......@@ -6,33 +6,27 @@ from django.utils.translation import gettext_lazy as _
from common.utils import get_logger
from orgs.mixins import OrgModelForm
from ..models import Asset, Protocol, Node
from ..models import Asset, Node
logger = get_logger(__file__)
__all__ = [
'AssetCreateForm', 'AssetUpdateForm', 'AssetBulkUpdateForm',
'ProtocolForm'
'AssetCreateForm', 'AssetUpdateForm', 'AssetBulkUpdateForm', 'ProtocolForm',
]
class ProtocolForm(forms.ModelForm):
class Meta:
model = Protocol
fields = ['name', 'port']
widgets = {
'name': forms.Select(attrs={
'class': 'form-control protocol-name'
}),
'port': forms.TextInput(attrs={
'class': 'form-control protocol-port'
}),
}
class ProtocolForm(forms.Form):
name = forms.ChoiceField(
choices=Asset.PROTOCOL_CHOICES, label=_("Name"), initial='ssh',
widget=forms.Select(attrs={'class': 'form-control protocol-name'})
)
port = forms.IntegerField(
max_value=65534, min_value=1, label=_("Port"), initial=22,
widget=forms.TextInput(attrs={'class': 'form-control protocol-port'})
)
class AssetCreateForm(OrgModelForm):
PROTOCOL_CHOICES = Protocol.PROTOCOL_CHOICES
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.data:
......
......@@ -3,14 +3,6 @@
from django.db import migrations
def migrate_assets_protocol(apps, schema_editor):
asset_model = apps.get_model("assets", "Asset")
db_alias = schema_editor.connection.alias
assets = asset_model.objects.using(db_alias).all()
for asset in assets:
asset.protocols.create(name=asset.protocol, port=asset.port)
class Migration(migrations.Migration):
dependencies = [
......@@ -18,5 +10,4 @@ class Migration(migrations.Migration):
]
operations = [
migrations.RunPython(migrate_assets_protocol),
]
# Generated by Django 2.1.7 on 2019-07-05 05:48
from django.db import migrations
from django.db.models import F
from django.db.models import CharField, Value as V
from django.db.models.functions import Concat
def migrate_assets_protocol(apps, schema_editor):
asset_model = apps.get_model("assets", "Asset")
db_alias = schema_editor.connection.alias
assets = asset_model.objects.using(db_alias).all().annotate(
protocols_new=Concat(
'protocol', V('/'), 'port',
output_field=CharField(),
),
)
assets.update(protocols=F('protocols_new'))
class Migration(migrations.Migration):
dependencies = [
('assets', '0033_auto_20190624_2108'),
]
operations = [
migrations.RemoveField(
model_name='asset',
name='protocols',
),
migrations.AddField(
model_name='asset',
name='protocols',
field=CharField(blank=True, default='ssh/22', max_length=128, verbose_name='Protocols'),
),
migrations.RunPython(migrate_assets_protocol),
migrations.DeleteModel(name='Protocol'),
]
......@@ -6,16 +6,16 @@ import uuid
import logging
import random
from functools import reduce
from collections import OrderedDict
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from .user import AdminUser, SystemUser
from .utils import Connectivity
from orgs.mixins import OrgModelMixin, OrgManager
__all__ = ['Asset', 'Protocol']
__all__ = ['Asset']
logger = logging.getLogger(__name__)
......@@ -45,8 +45,12 @@ class AssetQuerySet(models.QuerySet):
def valid(self):
return self.active()
def has_protocol(self, name):
return self.filter(protocols__contains=name)
class Protocol(models.Model):
class ProtocolsMixin:
protocols = ''
PROTOCOL_SSH = 'ssh'
PROTOCOL_RDP = 'rdp'
PROTOCOL_TELNET = 'telnet'
......@@ -57,19 +61,42 @@ class Protocol(models.Model):
(PROTOCOL_TELNET, 'telnet (beta)'),
(PROTOCOL_VNC, 'vnc'),
)
PORT_VALIDATORS = [MaxValueValidator(65535), MinValueValidator(1)]
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=16, choices=PROTOCOL_CHOICES,
default=PROTOCOL_SSH, verbose_name=_("Name"))
port = models.IntegerField(default=22, verbose_name=_("Port"),
validators=PORT_VALIDATORS)
@property
def protocols_as_list(self):
if not self.protocols:
return []
return self.protocols.split(' ')
def __str__(self):
return "{}/{}".format(self.name, self.port)
@property
def protocols_as_dict(self):
d = OrderedDict()
protocols = self.protocols_as_list
for i in protocols:
if '/' not in i:
continue
name, port = i.split('/')[:2]
if not all([name, port]):
continue
d[name] = int(port)
return d
@property
def protocols_as_json(self):
return [
{"name": name, "port": port}
for name, port in self.protocols_as_dict.items()
]
def has_protocol(self, name):
return name in self.protocols_as_dict
@property
def ssh_port(self):
return self.protocols_as_dict.get("ssh", 22)
class Asset(OrgModelMixin):
class Asset(ProtocolsMixin, OrgModelMixin):
# Important
PLATFORM_CHOICES = (
('Linux', 'Linux'),
......@@ -84,12 +111,12 @@ class Asset(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True)
hostname = models.CharField(max_length=128, verbose_name=_('Hostname'))
protocol = models.CharField(max_length=128, default=Protocol.PROTOCOL_SSH,
choices=Protocol.PROTOCOL_CHOICES,
protocol = models.CharField(max_length=128, default=ProtocolsMixin.PROTOCOL_SSH,
choices=ProtocolsMixin.PROTOCOL_CHOICES,
verbose_name=_('Protocol'))
port = models.IntegerField(default=22, verbose_name=_('Port'))
protocols = models.ManyToManyField('Protocol', verbose_name=_("Protocol"))
protocols = models.CharField(max_length=128, default='ssh/22', blank=True, verbose_name=_("Protocols"))
platform = models.CharField(max_length=128, choices=PLATFORM_CHOICES, default='Linux', verbose_name=_('Platform'))
domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL)
nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes"))
......@@ -136,41 +163,9 @@ class Asset(OrgModelMixin):
warning = ''
if not self.is_active:
warning += ' inactive'
else:
return True, ''
return False, warning
@property
def protocols_name(self):
names = []
for protocol in self.protocols.all():
names.append(protocol.name)
return names
def has_protocol(self, name):
return name in self.protocols_name
def get_protocol_by_name(self, name):
for i in self.protocols.all():
if i.name.lower() == name.lower():
return i
return None
@property
def protocol_ssh(self):
return self.get_protocol_by_name("ssh")
@property
def protocol_rdp(self):
return self.get_protocol_by_name("rdp")
@property
def ssh_port(self):
if self.protocol_ssh:
port = self.protocol_ssh.port
else:
port = 22
return port
if warning:
return False, warning
return True, warning
def is_windows(self):
if self.platform in ("Windows", "Windows2016"):
......@@ -278,10 +273,7 @@ class Asset(OrgModelMixin):
'id': self.id,
'hostname': self.hostname,
'ip': self.ip,
'protocols': [
{"name": p.name, "port": p.port}
for p in self.protocols.all()
],
'protocols': self.protocols_as_list,
'platform': self.platform,
}
}
......@@ -314,7 +306,7 @@ class Asset(OrgModelMixin):
created_by='Fake')
try:
asset.save()
asset.protocols.create(name="ssh", port=22)
asset.protocols = 'ssh/22'
if nodes and len(nodes) > 3:
_nodes = random.sample(nodes, 3)
else:
......
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from rest_framework.validators import ValidationError
from django.db.models import Prefetch
from django.utils.translation import ugettext_lazy as _
from orgs.mixins import BulkOrgResourceModelSerializer
from common.serializers import AdaptedBulkListSerializer
from ..models import Asset, Protocol, Node, Label
from ..models import Asset, Node, Label
from .base import ConnectivitySerializer
__all__ = [
'AssetSerializer', 'AssetSimpleSerializer',
'ProtocolSerializer', 'ProtocolsRelatedField',
'ProtocolsField',
]
class ProtocolSerializer(serializers.ModelSerializer):
class Meta:
model = Protocol
fields = ["name", "port"]
class ProtocolField(serializers.RegexField):
protocols = '|'.join(dict(Asset.PROTOCOL_CHOICES).keys())
default_error_messages = {
'invalid': _('Protocol format should {}/{}'.format(protocols, '1-65535'))
}
regex = r'^(%s)/(\d{1,5})$' % protocols
def __init__(self, *args, **kwargs):
super().__init__(self.regex, **kwargs)
def validate_duplicate_protocols(values):
errors = []
names = []
for value in values:
if not value or '/' not in value:
continue
name = value.split('/')[0]
if name in names:
errors.append(_("Protocol duplicate: {}").format(name))
names.append(name)
errors.append('')
if any(errors):
raise serializers.ValidationError(errors)
class ProtocolsRelatedField(serializers.RelatedField):
def to_representation(self, value):
return str(value)
def to_internal_value(self, data):
if isinstance(data, dict):
return data
if '/' not in data:
raise ValidationError("protocol not contain /: {}".format(data))
v = data.split("/")
if len(v) != 2:
raise ValidationError("protocol format should be name/port: {}".format(data))
name, port = v
cleaned_data = {"name": name, "port": port}
return cleaned_data
class ProtocolsField(serializers.ListField):
default_validators = [validate_duplicate_protocols]
def __init__(self, *args, **kwargs):
kwargs['child'] = ProtocolField()
kwargs['allow_null'] = True
kwargs['allow_empty'] = True
kwargs['min_length'] = 1
kwargs['max_length'] = 4
super().__init__(*args, **kwargs)
def to_representation(self, value):
if not value:
return []
return value.split(' ')
class AssetSerializer(BulkOrgResourceModelSerializer):
protocols = ProtocolsRelatedField(
many=True, queryset=Protocol.objects.all(), label=_("Protocols")
)
protocols = ProtocolsField(label=_('Protocols'), required=False)
connectivity = ConnectivitySerializer(read_only=True, label=_("Connectivity"))
"""
......@@ -79,66 +97,32 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
queryset = queryset.prefetch_related(
Prefetch('nodes', queryset=Node.objects.all().only('id')),
Prefetch('labels', queryset=Label.objects.all().only('id')),
'protocols'
).select_related('admin_user', 'domain')
return queryset
@staticmethod
def validate_protocols(attr):
protocols_serializer = ProtocolSerializer(data=attr, many=True)
protocols_serializer.is_valid(raise_exception=True)
protocols_name = [i.get("name", "ssh") for i in attr]
errors = [{} for i in protocols_name]
for i, name in enumerate(protocols_name):
if name in protocols_name[:i]:
errors[i] = {"name": _("Protocol duplicate: {}").format(name)}
if any(errors):
raise ValidationError(errors)
return attr
def create(self, validated_data):
def compatible_with_old_protocol(self, validated_data):
protocols_data = validated_data.pop("protocols", [])
# 兼容老的api
protocol = validated_data.get("protocol")
name = validated_data.get("protocol")
port = validated_data.get("port")
if not protocols_data and protocol and port:
protocols_data = [{"name": protocol, "port": port}]
if not protocol and not port and protocols_data:
validated_data["protocol"] = protocols_data[0]["name"]
validated_data["port"] = protocols_data[0]["port"]
if not protocols_data and name and port:
protocols_data.insert(0, '/'.join([name, str(port)]))
elif not name and not port and protocols_data:
protocol = protocols_data[0].split('/')
validated_data["protocol"] = protocol[0]
validated_data["port"] = int(protocol[1])
if validated_data:
validated_data["protocols"] = ' '.join(protocols_data)
protocols_serializer = ProtocolSerializer(data=protocols_data, many=True)
protocols_serializer.is_valid(raise_exception=True)
protocols = protocols_serializer.save()
def create(self, validated_data):
self.compatible_with_old_protocol(validated_data)
instance = super().create(validated_data)
instance.protocols.set(protocols)
return instance
def update(self, instance, validated_data):
protocols_data = validated_data.pop("protocols", [])
# 兼容老的api
protocol = validated_data.get("protocol")
port = validated_data.get("port")
if not protocols_data and protocol and port:
protocols_data = [{"name": protocol, "port": port}]
if not protocol and not port and protocols_data:
validated_data["protocol"] = protocols_data[0]["name"]
validated_data["port"] = protocols_data[0]["port"]
protocols = None
if protocols_data:
protocols_serializer = ProtocolSerializer(data=protocols_data, many=True)
protocols_serializer.is_valid(raise_exception=True)
protocols = protocols_serializer.save()
instance = super().update(instance, validated_data)
if protocols:
instance.protocols.all().delete()
instance.protocols.set(protocols)
return instance
self.compatible_with_old_protocol(validated_data)
return super().update(instance, validated_data)
class AssetSimpleSerializer(serializers.ModelSerializer):
......
......@@ -70,11 +70,7 @@
</tr>
<tr>
<td>{% trans 'Protocol' %}</td>
<td>
{% for protocol in asset.protocols.all %}
<b>{{ protocol }}</b>
{% endfor %}
</td>
<td>{{ asset.protocols }}</td>
</tr>
<tr>
<td>{% trans 'Admin user' %}:</td>
......
# coding:utf-8
from __future__ import absolute_import, unicode_literals
import csv
import json
import uuid
import codecs
import chardet
from io import StringIO
from django.db import transaction
from django.contrib import messages
from django.utils.translation import ugettext_lazy as _
from django.views.generic import TemplateView, ListView, View
from django.views.generic.edit import CreateView, DeleteView, FormView, UpdateView
from django.views.generic import TemplateView, ListView
from django.views.generic.edit import FormMixin
from django.views.generic.edit import CreateView, DeleteView, UpdateView
from django.urls import reverse_lazy
from django.views.generic.detail import DetailView
from django.http import HttpResponse, JsonResponse
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.core.cache import cache
from django.utils import timezone
from django.shortcuts import redirect
from django.contrib.messages.views import SuccessMessageMixin
from django.forms.formsets import formset_factory
from common.mixins import JSONResponseMixin
from common.utils import get_object_or_none, get_logger
from common.permissions import PermissionsMixin, IsOrgAdmin, IsValidUser
from common.const import (
create_success_msg, update_success_msg, KEY_CACHE_RESOURCES_ID
)
from .. import forms
from ..models import Asset, AdminUser, SystemUser, Label, Node, Domain
from ..models import Asset, SystemUser, Label, Node
__all__ = [
......@@ -87,7 +75,7 @@ class UserAssetListView(PermissionsMixin, TemplateView):
return super().get_context_data(**kwargs)
class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView):
class AssetCreateView(PermissionsMixin, FormMixin, TemplateView):
model = Asset
form_class = forms.AssetCreateForm
template_name = 'assets/asset_create.html'
......@@ -112,16 +100,6 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView):
formset = ProtocolFormset()
return formset
def form_valid(self, form):
formset = self.get_protocol_formset()
valid = formset.is_valid()
if not valid:
return self.form_invalid(form)
protocols = formset.save()
instance = super().form_valid(form)
instance.protocols.set(protocols)
return instance
def get_context_data(self, **kwargs):
formset = self.get_protocol_formset()
context = {
......@@ -132,8 +110,32 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView):
kwargs.update(context)
return super().get_context_data(**kwargs)
def get_success_message(self, cleaned_data):
return create_success_msg % ({"name": cleaned_data["hostname"]})
class AssetUpdateView(PermissionsMixin, UpdateView):
model = Asset
form_class = forms.AssetUpdateForm
template_name = 'assets/asset_update.html'
success_url = reverse_lazy('assets:asset-list')
permission_classes = [IsOrgAdmin]
def get_protocol_formset(self):
ProtocolFormset = formset_factory(forms.ProtocolForm, extra=0, min_num=1, max_num=5)
if self.request.method == "POST":
formset = ProtocolFormset(self.request.POST)
else:
initial_data = self.object.protocols_as_json
formset = ProtocolFormset(initial=initial_data)
return formset
def get_context_data(self, **kwargs):
formset = self.get_protocol_formset()
context = {
'app': _('Assets'),
'action': _('Update asset'),
'formset': formset,
}
kwargs.update(context)
return super().get_context_data(**kwargs)
class AssetBulkUpdateView(PermissionsMixin, ListView):
......@@ -177,36 +179,6 @@ class AssetBulkUpdateView(PermissionsMixin, ListView):
return super().get_context_data(**kwargs)
class AssetUpdateView(PermissionsMixin, SuccessMessageMixin, UpdateView):
model = Asset
form_class = forms.AssetUpdateForm
template_name = 'assets/asset_update.html'
success_url = reverse_lazy('assets:asset-list')
permission_classes = [IsOrgAdmin]
def get_protocol_formset(self):
ProtocolFormset = formset_factory(forms.ProtocolForm, extra=0, min_num=1, max_num=5)
if self.request.method == "POST":
formset = ProtocolFormset(self.request.POST)
else:
initial_data = [{"name": p.name, "port": p.port} for p in self.object.protocols.all()]
formset = ProtocolFormset(initial=initial_data)
return formset
def get_context_data(self, **kwargs):
formset = self.get_protocol_formset()
context = {
'app': _('Assets'),
'action': _('Update asset'),
'formset': formset,
}
kwargs.update(context)
return super().get_context_data(**kwargs)
def get_success_message(self, cleaned_data):
return update_success_msg % ({"name": cleaned_data["hostname"]})
class AssetDeleteView(PermissionsMixin, DeleteView):
model = Asset
template_name = 'delete_confirm.html'
......@@ -222,7 +194,7 @@ class AssetDetailView(PermissionsMixin, DetailView):
def get_queryset(self):
return super().get_queryset().prefetch_related(
"nodes", "labels", "protocols"
"nodes", "labels",
).select_related('admin_user', 'domain')
def get_context_data(self, **kwargs):
......
......@@ -135,9 +135,7 @@ function getSelectedAssetsNode() {
var assetsNode = [];
nodes.forEach(function (node) {
if (node.meta.type === 'asset' && !node.isHidden) {
var protocols = $.map(node.meta.asset.protocols, function (v) {
return v.name
});
var protocols = node.meta.asset.protocols;
if (assetsNodeId.indexOf(node.id) === -1 && protocols.indexOf("ssh") > -1) {
assetsNodeId.push(node.id);
assetsNode.push(node)
......
......@@ -126,7 +126,7 @@ class GenerateTree:
for asset, system_users in assets.items():
self.add_asset(asset, system_users)
#@timeit
# #@timeit
def add_asset(self, asset, system_users=None):
nodes = asset.nodes.all()
nodes = self.node_util.get_nodes_by_queryset(nodes)
......@@ -493,12 +493,13 @@ class AssetPermissionUtil(AssetPermissionCacheMixin):
pattern.add(r'^{0}$|^{0}:'.format(node.key))
pattern = '|'.join(list(pattern))
if pattern:
assets = Asset.objects.filter(nodes__key__regex=pattern)\
.prefetch_related('nodes', "protocols")\
assets = Asset.objects.filter(nodes__key__regex=pattern) \
.prefetch_related('nodes')\
.only(*self.assets_only)\
.distinct()
else:
assets = []
assets = list(assets)
self.tree.add_assets_without_system_users(assets)
assets = self.tree.get_assets()
self._assets = assets
......@@ -598,7 +599,7 @@ def parse_asset_to_tree_node(node, asset, system_users):
'id': asset.id,
'hostname': asset.hostname,
'ip': asset.ip,
'protocols': [str(p) for p in asset.protocols.all()],
'protocols': asset.protocols_as_list,
'platform': asset.platform,
'domain': None if not asset.domain else asset.domain.id,
'is_active': asset.is_active,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment