Commit 6d96b5db authored by ibuler's avatar ibuler

[Update] 修改org mixin

parent e8ebc941
...@@ -206,7 +206,7 @@ class Node(OrgModelMixin): ...@@ -206,7 +206,7 @@ class Node(OrgModelMixin):
return self.get_all_assets().valid() return self.get_all_assets().valid()
def is_default_node(self): def is_default_node(self):
return self.is_root() and self.key == '0' return self.is_root() and self.key == '1'
def is_root(self): def is_root(self):
if self.key.isdigit(): if self.key.isdigit():
......
...@@ -17,7 +17,7 @@ class AdminUserSerializer(BulkOrgResourceModelSerializer): ...@@ -17,7 +17,7 @@ class AdminUserSerializer(BulkOrgResourceModelSerializer):
""" """
class Meta: class Meta:
list_serializer_class = AdaptedBulkListSerializer # list_serializer_class = AdaptedBulkListSerializer
model = AdminUser model = AdminUser
fields = [ fields = [
'id', 'name', 'username', 'password', 'private_key', 'public_key', 'id', 'name', 'username', 'password', 'private_key', 'public_key',
......
...@@ -112,7 +112,6 @@ def to_dict(data): ...@@ -112,7 +112,6 @@ def to_dict(data):
@register.filter @register.filter
def sort(data): def sort(data):
print(data)
return sorted(data) return sorted(data)
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from werkzeug.local import Local
from django.db import models from django.db import models
from django.conf import settings
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.shortcuts import redirect, get_object_or_404 from django.shortcuts import redirect, get_object_or_404
from django.forms import ModelForm from django.forms import ModelForm
from django.http.response import HttpResponseForbidden from django.http.response import HttpResponseForbidden
from django.core.exceptions import ValidationError
from rest_framework import serializers from rest_framework import serializers
from rest_framework.validators import UniqueTogetherValidator from rest_framework.validators import UniqueTogetherValidator
...@@ -16,12 +13,12 @@ from common.utils import get_logger ...@@ -16,12 +13,12 @@ from common.utils import get_logger
from common.validators import ProjectUniqueValidator from common.validators import ProjectUniqueValidator
from common.mixins import BulkSerializerMixin from common.mixins import BulkSerializerMixin
from .utils import ( from .utils import (
current_org, set_current_org, set_to_root_org, get_current_org_id set_current_org, set_to_root_org, get_current_org, current_org,
get_current_org_id_for_serializer,
) )
from .models import Organization from .models import Organization
logger = get_logger(__file__) logger = get_logger(__file__)
local = Local()
__all__ = [ __all__ = [
'OrgManager', 'OrgViewGenericMixin', 'OrgModelMixin', 'OrgModelForm', 'OrgManager', 'OrgViewGenericMixin', 'OrgModelMixin', 'OrgModelForm',
...@@ -29,43 +26,28 @@ __all__ = [ ...@@ -29,43 +26,28 @@ __all__ = [
'OrgMembershipModelViewSetMixin', 'OrgResourceSerializerMixin', 'OrgMembershipModelViewSetMixin', 'OrgResourceSerializerMixin',
'BulkOrgResourceSerializerMixin', 'BulkOrgResourceModelSerializer', 'BulkOrgResourceSerializerMixin', 'BulkOrgResourceModelSerializer',
] ]
debug = settings.DEBUG
class OrgManager(models.Manager): class OrgManager(models.Manager):
def get_queryset(self): def get_queryset(self):
queryset = super(OrgManager, self).get_queryset() queryset = super(OrgManager, self).get_queryset()
kwargs = {} kwargs = {}
if not current_org: _current_org = get_current_org()
if _current_org is None:
kwargs['id'] = None kwargs['id'] = None
elif current_org.is_real(): elif _current_org.is_real():
kwargs['org_id'] = current_org.id kwargs['org_id'] = _current_org.id
elif current_org.is_default(): elif _current_org.is_default():
queryset = queryset.filter(org_id="") queryset = queryset.filter(org_id="")
queryset = queryset.filter(**kwargs)
return queryset
def filter_by_fullname(self, fullname, field=None): queryset = queryset.filter(**kwargs)
ori_org = current_org
value, org = self.model.split_fullname(fullname)
set_current_org(org)
if not field:
if hasattr(self.model, 'name'):
field = 'name'
elif hasattr(self.model, 'hostname'):
field = 'hostname'
queryset = self.get_queryset().filter(**{field: value})
set_current_org(ori_org)
return queryset return queryset
def get_object_by_fullname(self, fullname, field=None):
queryset = self.filter_by_fullname(fullname, field=field)
if len(queryset) == 1:
return queryset[0]
return None
def all(self): def all(self):
if not current_org: _current_org = get_current_org()
if _current_org is None:
msg = 'You can `objects.set_current_org(org).all()` then run it' msg = 'You can `objects.set_current_org(org).all()` then run it'
return self return self
else: else:
...@@ -73,35 +55,23 @@ class OrgManager(models.Manager): ...@@ -73,35 +55,23 @@ class OrgManager(models.Manager):
def set_current_org(self, org): def set_current_org(self, org):
if isinstance(org, str): if isinstance(org, str):
org = Organization.objects.get(name=org) org = Organization.get_instance(org)
set_current_org(org) set_current_org(org)
return self return self
class OrgModelMixin(models.Model): class OrgModelMixin(models.Model):
org_id = models.CharField(max_length=36, blank=True, default='', verbose_name=_("Organization"), db_index=True) org_id = models.CharField(max_length=36, blank=True, default='',
verbose_name=_("Organization"), db_index=True)
objects = OrgManager() objects = OrgManager()
sep = '@' sep = '@'
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if current_org and current_org.is_real(): if current_org is not None and current_org.is_real():
self.org_id = current_org.id self.org_id = current_org.id
return super().save(*args, **kwargs) return super().save(*args, **kwargs)
@classmethod
def split_fullname(cls, fullname, sep=None):
if not sep:
sep = cls.sep
index = fullname.rfind(sep)
if index == -1:
value = fullname
org = Organization.default()
else:
value = fullname[:index]
org = Organization.get_instance(fullname[index + 1:])
return value, org
@property @property
def org(self): def org(self):
from orgs.models import Organization from orgs.models import Organization
...@@ -126,41 +96,19 @@ class OrgModelMixin(models.Model): ...@@ -126,41 +96,19 @@ class OrgModelMixin(models.Model):
else: else:
return name return name
def validate_unique(self, exclude=None):
"""
Check unique constraints on the model and raise ValidationError if any
failed.
"""
self.org_id = current_org.id if current_org.is_real() else ''
if exclude and 'org_id' in exclude:
exclude.remove('org_id')
unique_checks, date_checks = self._get_unique_checks(exclude=exclude)
errors = self._perform_unique_checks(unique_checks)
date_errors = self._perform_date_checks(date_checks)
for k, v in date_errors.items():
errors.setdefault(k, []).extend(v)
if errors:
raise ValidationError(errors)
class Meta: class Meta:
abstract = True abstract = True
class OrgViewGenericMixin: class OrgViewGenericMixin:
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
if not current_org: if current_org is None:
return redirect('orgs:switch-a-org') return redirect('orgs:switch-a-org')
if not current_org.can_admin_by(request.user): if not current_org.can_admin_by(request.user):
print("{} cannot admin {}".format(request.user, current_org))
if request.user.is_org_admin: if request.user.is_org_admin:
return redirect('orgs:switch-a-org') return redirect('orgs:switch-a-org')
return HttpResponseForbidden() return HttpResponseForbidden()
else:
print(current_org.can_admin_by(request.user))
return super().dispatch(request, *args, **kwargs) return super().dispatch(request, *args, **kwargs)
...@@ -216,7 +164,7 @@ class OrgResourceSerializerMixin(serializers.Serializer): ...@@ -216,7 +164,7 @@ class OrgResourceSerializerMixin(serializers.Serializer):
由于HiddenField字段不可读,API获取资产信息时获取不到org_id, 由于HiddenField字段不可读,API获取资产信息时获取不到org_id,
但是coco需要资产的org_id字段,所以修改为CharField类型 但是coco需要资产的org_id字段,所以修改为CharField类型
""" """
org_id = serializers.ReadOnlyField(default=get_current_org_id, label=_("Organization")) org_id = serializers.ReadOnlyField(default=get_current_org_id_for_serializer, label=_("Organization"))
org_name = serializers.ReadOnlyField(label=_("Org name")) org_name = serializers.ReadOnlyField(label=_("Org name"))
def get_validators(self): def get_validators(self):
...@@ -236,7 +184,7 @@ class OrgResourceSerializerMixin(serializers.Serializer): ...@@ -236,7 +184,7 @@ class OrgResourceSerializerMixin(serializers.Serializer):
return fields return fields
class BulkOrgResourceSerializerMixin(BulkSerializerMixin, OrgResourceSerializerMixin): class BulkOrgResourceSerializerMixin(OrgResourceSerializerMixin, BulkSerializerMixin):
pass pass
......
import uuid import uuid
from django.db import models from django.db import models
from django.core.cache import cache
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from common.utils import is_uuid from common.utils import is_uuid
...@@ -16,9 +15,12 @@ class Organization(models.Model): ...@@ -16,9 +15,12 @@ class Organization(models.Model):
date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created')) date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created'))
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment')) comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment'))
orgs = None
CACHE_PREFIX = 'JMS_ORG_{}' CACHE_PREFIX = 'JMS_ORG_{}'
ROOT_ID_NAME = 'ROOT' ROOT_ID = '00000000-0000-0000-0000-000000000000'
DEFAULT_ID_NAME = 'DEFAULT' ROOT_NAME = 'ROOT'
DEFAULT_ID = 'DEFAULT'
DEFAULT_NAME = 'DEFAULT'
class Meta: class Meta:
verbose_name = _("Organization") verbose_name = _("Organization")
...@@ -27,33 +29,30 @@ class Organization(models.Model): ...@@ -27,33 +29,30 @@ class Organization(models.Model):
return self.name return self.name
def set_to_cache(self): def set_to_cache(self):
key_id = self.CACHE_PREFIX.format(self.id) if self.__class__.orgs is None:
key_name = self.CACHE_PREFIX.format(self.name) self.__class__.orgs = {}
cache.set(key_id, self, 3600) self.__class__.orgs[str(self.id)] = self
cache.set(key_name, self, 3600)
def expire_cache(self): def expire_cache(self):
key_id = self.CACHE_PREFIX.format(self.id) self.__class__.orgs.pop(str(self.id), None)
key_name = self.CACHE_PREFIX.format(self.name)
cache.delete(key_id)
cache.delete(key_name)
@classmethod @classmethod
def get_instance_from_cache(cls, oid): def get_instance_from_cache(cls, oid):
key = cls.CACHE_PREFIX.format(oid) if not cls.orgs or not isinstance(cls.orgs, dict):
return cache.get(key, None) return None
return cls.orgs.get(str(oid))
@classmethod @classmethod
def get_instance(cls, id_or_name, default=True): def get_instance(cls, id_or_name, default=False):
cached = cls.get_instance_from_cache(id_or_name) cached = cls.get_instance_from_cache(id_or_name)
if cached: if cached:
return cached return cached
if not id_or_name: if id_or_name is None:
return cls.default() if default else None return cls.default() if default else None
elif id_or_name == cls.DEFAULT_ID_NAME: elif id_or_name in [cls.DEFAULT_ID, cls.DEFAULT_NAME, '']:
return cls.default() return cls.default()
elif id_or_name == cls.ROOT_ID_NAME: elif id_or_name in [cls.ROOT_ID, cls.ROOT_NAME]:
return cls.root() return cls.root()
try: try:
...@@ -89,7 +88,7 @@ class Organization(models.Model): ...@@ -89,7 +88,7 @@ class Organization(models.Model):
return False return False
def is_real(self): def is_real(self):
return len(str(self.id)) == 36 return self.id not in (self.DEFAULT_NAME, self.ROOT_ID)
@classmethod @classmethod
def get_user_admin_orgs(cls, user): def get_user_admin_orgs(cls, user):
...@@ -105,20 +104,20 @@ class Organization(models.Model): ...@@ -105,20 +104,20 @@ class Organization(models.Model):
@classmethod @classmethod
def default(cls): def default(cls):
return cls(id=cls.DEFAULT_ID_NAME, name=cls.DEFAULT_ID_NAME) return cls(id=cls.DEFAULT_ID, name=cls.DEFAULT_NAME)
@classmethod @classmethod
def root(cls): def root(cls):
return cls(id=cls.ROOT_ID_NAME, name=cls.ROOT_ID_NAME) return cls(id=cls.ROOT_ID, name=cls.ROOT_NAME)
def is_root(self): def is_root(self):
if self.id is self.ROOT_ID_NAME: if self.id is self.ROOT_ID:
return True return True
else: else:
return False return False
def is_default(self): def is_default(self):
if self.id is self.DEFAULT_ID_NAME: if self.id is self.DEFAULT_ID:
return True return True
else: else:
return False return False
......
...@@ -10,12 +10,14 @@ def get_org_from_request(request): ...@@ -10,12 +10,14 @@ def get_org_from_request(request):
oid = request.session.get("oid") oid = request.session.get("oid")
if not oid: if not oid:
oid = request.META.get("HTTP_X_JMS_ORG") oid = request.META.get("HTTP_X_JMS_ORG")
if not oid:
oid = Organization.DEFAULT_ID
org = Organization.get_instance(oid) org = Organization.get_instance(oid)
return org return org
def set_current_org(org): def set_current_org(org):
setattr(thread_local, 'current_org', org.id) setattr(thread_local, 'current_org_id', org.id)
def set_to_default_org(): def set_to_default_org():
...@@ -31,13 +33,22 @@ def _find(attr): ...@@ -31,13 +33,22 @@ def _find(attr):
def get_current_org(): def get_current_org():
org_id = _find('current_org') org_id = get_current_org_id()
if org_id is None:
return None
org = Organization.get_instance(org_id) org = Organization.get_instance(org_id)
return org return org
def get_current_org_id(): def get_current_org_id():
org_id = _find('current_org') org_id = _find('current_org_id')
return org_id
def get_current_org_id_for_serializer():
org_id = get_current_org_id()
if org_id == Organization.DEFAULT_ID:
org_id = ''
return org_id return org_id
......
...@@ -13,7 +13,8 @@ class SwitchOrgView(DetailView): ...@@ -13,7 +13,8 @@ class SwitchOrgView(DetailView):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
pk = kwargs.get('pk') pk = kwargs.get('pk')
self.object = Organization.get_instance(pk) self.object = Organization.get_instance(pk)
request.session['oid'] = self.object.id.__str__() oid = str(self.object.id)
request.session['oid'] = oid
host = request.get_host() host = request.get_host()
referer = request.META.get('HTTP_REFERER') referer = request.META.get('HTTP_REFERER')
if referer.find(host) != -1: if referer.find(host) != -1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment