Commit 70da177e authored by ibuler's avatar ibuler

Update api

parent 92d854b9
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
from rest_framework import viewsets, generics, mixins from rest_framework import viewsets, generics, mixins
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework_bulk import BulkListSerializer, BulkSerializerMixin, ListBulkCreateUpdateDestroyAPIView from rest_framework_bulk import BulkListSerializer, BulkSerializerMixin, ListBulkCreateUpdateDestroyAPIView
...@@ -90,34 +91,19 @@ class AssetListUpdateApi(IDInFilterMixin, ListBulkCreateUpdateDestroyAPIView): ...@@ -90,34 +91,19 @@ class AssetListUpdateApi(IDInFilterMixin, ListBulkCreateUpdateDestroyAPIView):
permission_classes = (IsSuperUser,) permission_classes = (IsSuperUser,)
class SystemUserAuthApi(APIView): class SystemUserAuthInfoApi(generics.RetrieveAPIView):
queryset = SystemUser.objects.all()
permission_classes = (IsSuperUserOrAppUser,) permission_classes = (IsSuperUserOrAppUser,)
def get(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
system_user_id = request.query_params.get('system_user_id', -1) system_user = self.get_object()
system_user_username = request.query_params.get('system_user_username', '') data = {
system_user = get_object_or_none(SystemUser, id=system_user_id, username=system_user_username)
if system_user:
if system_user.password:
password = signer.sign(system_user.password)
else:
password = signer.sign('')
if system_user.private_key:
private_key = signer.sign(system_user.private_key)
else:
private_key = signer.sign(None)
response = {
'id': system_user.id, 'id': system_user.id,
'password': password, 'name': system_user.name,
'private_key': private_key, 'username': system_user.username,
'password': system_user.password,
'private_key': system_user.private_key,
'auth_method': system_user.auth_method,
} }
return Response(data)
return Response(response)
else:
return Response({'msg': 'error system user id or username'}, status=401)
...@@ -263,7 +263,7 @@ class SystemUserForm(forms.ModelForm): ...@@ -263,7 +263,7 @@ class SystemUserForm(forms.ModelForm):
class Meta: class Meta:
model = SystemUser model = SystemUser
fields = [ fields = [
'name', 'username', 'protocol', 'auto_generate_key', 'password', 'private_key_file', 'as_default', 'name', 'username', 'protocol', 'auto_generate_key', 'password', 'private_key_file', 'auth_method',
'auto_push', 'auto_update', 'sudo', 'comment', 'shell', 'home', 'uid', 'auto_push', 'auto_update', 'sudo', 'comment', 'shell', 'home', 'uid',
] ]
widgets = { widgets = {
...@@ -273,8 +273,8 @@ class SystemUserForm(forms.ModelForm): ...@@ -273,8 +273,8 @@ class SystemUserForm(forms.ModelForm):
help_texts = { help_texts = {
'name': '* required', 'name': '* required',
'username': '* required', 'username': '* required',
'auth_push': 'Auto push system user to asset', 'auto_push': 'Auto push system user to asset',
'auth_update': 'Auto update system user ssh key', 'auto_update': 'Auto update system user ssh key',
} }
......
...@@ -95,13 +95,18 @@ class SystemUser(models.Model): ...@@ -95,13 +95,18 @@ class SystemUser(models.Model):
PROTOCOL_CHOICES = ( PROTOCOL_CHOICES = (
('ssh', 'ssh'), ('ssh', 'ssh'),
) )
AUTH_METHOD_CHOICES = (
('P', 'Password'),
('K', 'Public key'),
)
name = models.CharField(max_length=128, unique=True, verbose_name=_('Name')) name = models.CharField(max_length=128, unique=True, verbose_name=_('Name'))
username = models.CharField(max_length=16, verbose_name=_('Username')) username = models.CharField(max_length=16, verbose_name=_('Username'))
_password = models.CharField(max_length=256, blank=True, verbose_name=_('Password')) _password = models.CharField(max_length=256, blank=True, verbose_name=_('Password'))
protocol = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, default='ssh', verbose_name=_('Protocol')) protocol = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, default='ssh', verbose_name=_('Protocol'))
_private_key = models.CharField(max_length=4096, blank=True, verbose_name=_('SSH private key')) _private_key = models.CharField(max_length=4096, blank=True, verbose_name=_('SSH private key'))
_public_key = models.CharField(max_length=4096, blank=True, verbose_name=_('SSH public key')) _public_key = models.CharField(max_length=4096, blank=True, verbose_name=_('SSH public key'))
as_default = models.BooleanField(default=False, verbose_name=_('As default')) auth_method = models.CharField(choices=AUTH_METHOD_CHOICES, default='K',
max_length=1, verbose_name=_('Auth method'))
auto_push = models.BooleanField(default=True, verbose_name=_('Auto push')) auto_push = models.BooleanField(default=True, verbose_name=_('Auto push'))
auto_update = models.BooleanField(default=True, verbose_name=_('Auto update pass/key')) auto_update = models.BooleanField(default=True, verbose_name=_('Auto update pass/key'))
sudo = models.TextField(max_length=4096, default='/user/bin/whoami', verbose_name=_('Sudo')) sudo = models.TextField(max_length=4096, default='/user/bin/whoami', verbose_name=_('Sudo'))
......
...@@ -17,6 +17,7 @@ class AssetGroupSerializer(serializers.ModelSerializer): ...@@ -17,6 +17,7 @@ class AssetGroupSerializer(serializers.ModelSerializer):
def get_assets_amount(obj): def get_assets_amount(obj):
return obj.assets.count() return obj.assets.count()
class AssetUpdateGroupSerializer(serializers.ModelSerializer): class AssetUpdateGroupSerializer(serializers.ModelSerializer):
groups = serializers.PrimaryKeyRelatedField(many=True, queryset=AssetGroup.objects.all()) groups = serializers.PrimaryKeyRelatedField(many=True, queryset=AssetGroup.objects.all())
...@@ -24,6 +25,7 @@ class AssetUpdateGroupSerializer(serializers.ModelSerializer): ...@@ -24,6 +25,7 @@ class AssetUpdateGroupSerializer(serializers.ModelSerializer):
model = Asset model = Asset
fields = ['id', 'groups'] fields = ['id', 'groups']
class AssetUpdateSystemUserSerializer(serializers.ModelSerializer): class AssetUpdateSystemUserSerializer(serializers.ModelSerializer):
system_users = serializers.PrimaryKeyRelatedField(many=True, queryset=SystemUser.objects.all()) system_users = serializers.PrimaryKeyRelatedField(many=True, queryset=SystemUser.objects.all())
...@@ -31,6 +33,7 @@ class AssetUpdateSystemUserSerializer(serializers.ModelSerializer): ...@@ -31,6 +33,7 @@ class AssetUpdateSystemUserSerializer(serializers.ModelSerializer):
model = Asset model = Asset
fields = ['id', 'system_users'] fields = ['id', 'system_users']
class AdminUserSerializer(serializers.ModelSerializer): class AdminUserSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = AdminUser model = AdminUser
...@@ -52,6 +55,12 @@ class SystemUserSerializer(serializers.ModelSerializer): ...@@ -52,6 +55,12 @@ class SystemUserSerializer(serializers.ModelSerializer):
return fields return fields
class SystemUserSimpleSerializer(serializers.ModelSerializer):
class Meta:
model = SystemUser
fields = ('id', 'name', 'username')
class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
# system_users = SystemUserSerializer(many=True, read_only=True) # system_users = SystemUserSerializer(many=True, read_only=True)
# admin_user = AdminUserSerializer(many=False, read_only=True) # admin_user = AdminUserSerializer(many=False, read_only=True)
...@@ -75,7 +84,7 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer): ...@@ -75,7 +84,7 @@ class AssetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
class AssetGrantedSerializer(serializers.ModelSerializer): class AssetGrantedSerializer(serializers.ModelSerializer):
system_users = SystemUserSerializer(many=True, read_only=True) system_users = SystemUserSimpleSerializer(many=True, read_only=True)
is_inherited = serializers.SerializerMethodField() is_inherited = serializers.SerializerMethodField()
system_users_join = serializers.SerializerMethodField() system_users_join = serializers.SerializerMethodField()
......
...@@ -16,11 +16,12 @@ router.register(r'v1/system-user', api.SystemUserViewSet, 'system-user') ...@@ -16,11 +16,12 @@ router.register(r'v1/system-user', api.SystemUserViewSet, 'system-user')
urlpatterns = [ urlpatterns = [
url(r'^v1/assets_bulk/$', api.AssetListUpdateApi.as_view(), name='asset-bulk-update'), url(r'^v1/assets_bulk/$', api.AssetListUpdateApi.as_view(), name='asset-bulk-update'),
# url(r'^v1/idc/(?P<pk>[0-9]+)/assets/$', api.IDCAssetsApi.as_view(), name='api-idc-assets'), # url(r'^v1/idc/(?P<pk>[0-9]+)/assets/$', api.IDCAssetsApi.as_view(), name='api-idc-assets'),
url(r'^v1/system-user/auth/', api.SystemUserAuthApi.as_view(), name='system-user-auth'), url(r'^v1/system-user/(?P<pk>[0-9]+)/auth-info/', api.SystemUserAuthInfoApi.as_view(),
name='system-user-auth-info'),
url(r'^v1/assets/(?P<pk>\d+)/groups/$', url(r'^v1/assets/(?P<pk>\d+)/groups/$',
api.AssetUpdateGroupApi.as_view(), name='asset-update-group'), api.AssetUpdateGroupApi.as_view(), name='asset-update-group'),
url(r'^v1/assets/(?P<pk>\d+)/system-users/$', url(r'^v1/assets/(?P<pk>\d+)/system-users/$',
api.SystemUserUpdateApi.as_view(), name='asset-update-systemusers'), api.SystemUserUpdateApi.as_view(), name='asset-update-system-users'),
] ]
urlpatterns += router.urls urlpatterns += router.urls
......
...@@ -18,7 +18,6 @@ class LoginLog(models.Model): ...@@ -18,7 +18,6 @@ class LoginLog(models.Model):
username = models.CharField(max_length=20, verbose_name=_('Username')) username = models.CharField(max_length=20, verbose_name=_('Username'))
name = models.CharField(max_length=20, blank=True, verbose_name=_('Name')) name = models.CharField(max_length=20, blank=True, verbose_name=_('Name'))
login_type = models.CharField(choices=LOGIN_TYPE_CHOICE, max_length=2, verbose_name=_('Login type')) login_type = models.CharField(choices=LOGIN_TYPE_CHOICE, max_length=2, verbose_name=_('Login type'))
terminal = models.CharField(max_length=32, verbose_name=_('Terminal'))
login_ip = models.GenericIPAddressField(verbose_name=_('Login ip')) login_ip = models.GenericIPAddressField(verbose_name=_('Login ip'))
login_city = models.CharField(max_length=100, blank=True, null=True, verbose_name=_('Login city')) login_city = models.CharField(max_length=100, blank=True, null=True, verbose_name=_('Login city'))
user_agent = models.CharField(max_length=100, blank=True, null=True, verbose_name=_('User agent')) user_agent = models.CharField(max_length=100, blank=True, null=True, verbose_name=_('User agent'))
......
...@@ -17,25 +17,28 @@ def validate_ip(ip): ...@@ -17,25 +17,28 @@ def validate_ip(ip):
return False return False
def write_login_log(username, name='', login_type='W', def write_login_log(username, name='', login_type='',
terminal='', login_ip='', user_agent=''): login_ip='', user_agent=''):
if not (login_ip and validate_ip(login_ip)): if not (login_ip and validate_ip(login_ip)):
login_ip = '0.0.0.0' login_ip = '0.0.0.0'
if not name: if not name:
name = username name = username
login_city = get_ip_city(login_ip) login_city = get_ip_city(login_ip)
LoginLog.objects.create(username=username, name=name, login_type=login_type, login_ip=login_ip, LoginLog.objects.create(username=username, name=name, login_type=login_type,
terminal=terminal, login_city=login_city, user_agent=user_agent) login_ip=login_ip, login_city=login_city, user_agent=user_agent)
def get_ip_city(ip, timeout=10): def get_ip_city(ip, timeout=10):
# Taobao ip api: http://ip.taobao.com//service/getIpInfo.php?ip=8.8.8.8 # Taobao ip api: http://ip.taobao.com//service/getIpInfo.php?ip=8.8.8.8
# Sina ip api: http://int.dpool.sina.com.cn/iplookup/iplookup.php?ip=8.8.8.8&format=js # Sina ip api: http://int.dpool.sina.com.cn/iplookup/iplookup.php?ip=8.8.8.8&format=json
url = 'http://ip.taobao.com/service/getIpInfo.php?ip=' + ip url = 'http://int.dpool.sina.com.cn/iplookup/iplookup.php?ip=%s&format=json' % ip
try:
r = requests.get(url, timeout=timeout) r = requests.get(url, timeout=timeout)
except requests.Timeout:
r = None
city = 'Unknown' city = 'Unknown'
if r.status_code == 200: if r and r.status_code == 200:
try: try:
data = r.json() data = r.json()
if data['code'] == 0: if data['code'] == 0:
......
...@@ -100,7 +100,7 @@ class UserToken(APIView): ...@@ -100,7 +100,7 @@ class UserToken(APIView):
user, msg = check_user_valid(username=username, email=email, user, msg = check_user_valid(username=username, email=email,
password=password, public_key=public_key) password=password, public_key=public_key)
if user: if user:
token = generate_token(request) token = generate_token(request, user)
return Response({'Token': token, 'key': 'Bearer'}, status=200) return Response({'Token': token, 'key': 'Bearer'}, status=200)
else: else:
return Response({'error': msg}, status=406) return Response({'error': msg}, status=406)
...@@ -114,28 +114,22 @@ class UserProfile(APIView): ...@@ -114,28 +114,22 @@ class UserProfile(APIView):
class UserAuthApi(APIView): class UserAuthApi(APIView):
permission_classes = () permission_classes = (AllowAny,)
expiration = settings.CONFIG.TOKEN_EXPIRATION or 3600
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
username = request.data.get('username', '') username = request.data.get('username', '')
password = request.data.get('password', '') password = request.data.get('password', '')
public_key = request.data.get('public_key', '') public_key = request.data.get('public_key', '')
remote_addr = request.data.get('remote_addr', '') login_type = request.data.get('login_type', '')
terminal = request.data.get('applications', '') login_ip = request.META.get('REMOTE_ADDR', '')
login_type = request.data.get('login_type', 'T') user_agent = request.data.get('HTTP_USER_AGENT', '')
user = check_user_valid(username=username, password=password, public_key=public_key)
user, msg = check_user_valid(username=username, password=password, public_key=public_key)
if user: if user:
token = cache.get('%s_%s' % (user.id, remote_addr)) token = generate_token(request, user)
if not token: write_login_log_async.delay(user.username, name=user.name, user_agent=user_agent,
token = generate_token(request) login_ip=login_ip, login_type=login_type)
return Response({'token': token, 'user': user.to_json()})
cache.set(token, user.id, self.expiration)
cache.set('%s_%s' % (user.id, remote_addr), token, self.expiration)
write_login_log_async.delay(user.username, name=user.name, terminal=terminal,
login_ip=remote_addr, login_type=login_type)
return Response({'token': token, 'id': user.id, 'username': user.username,
'name': user.name, 'is_active': user.is_active})
else: else:
return Response({'msg': 'Invalid password or public key or user is not active or expired'}, status=401) return Response({'msg': msg}, status=401)
...@@ -43,7 +43,6 @@ class AccessKeyAuthentication(authentication.BaseAuthentication): ...@@ -43,7 +43,6 @@ class AccessKeyAuthentication(authentication.BaseAuthentication):
msg = _('Invalid signature header. Signature string should not contain spaces.') msg = _('Invalid signature header. Signature string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
try: try:
sign = auth[1].decode().split(':') sign = auth[1].decode().split(':')
if len(sign) != 2: if len(sign) != 2:
...@@ -58,7 +57,8 @@ class AccessKeyAuthentication(authentication.BaseAuthentication): ...@@ -58,7 +57,8 @@ class AccessKeyAuthentication(authentication.BaseAuthentication):
return self.authenticate_credentials(request, access_key_id, request_signature) return self.authenticate_credentials(request, access_key_id, request_signature)
def authenticate_credentials(self, request, access_key_id, request_signature): @staticmethod
def authenticate_credentials(request, access_key_id, request_signature):
access_key = get_object_or_none(AccessKey, id=access_key_id) access_key = get_object_or_none(AccessKey, id=access_key_id)
request_date = get_request_date_header(request) request_date = get_request_date_header(request)
if access_key is None or not access_key.user: if access_key is None or not access_key.user:
...@@ -109,7 +109,8 @@ class AccessTokenAuthentication(authentication.BaseAuthentication): ...@@ -109,7 +109,8 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(token) return self.authenticate_credentials(token)
def authenticate_credentials(self, token): @staticmethod
def authenticate_credentials(token):
user_id = cache.get(token) user_id = cache.get(token)
user = get_object_or_none(User, id=user_id) user = get_object_or_none(User, id=user_id)
......
...@@ -17,6 +17,7 @@ router.register(r'v1/user-groups', api.UserGroupViewSet, 'user-group') ...@@ -17,6 +17,7 @@ router.register(r'v1/user-groups', api.UserGroupViewSet, 'user-group')
urlpatterns = [ urlpatterns = [
url(r'^v1/token/$', api.UserToken.as_view(), name='user-token'), url(r'^v1/token/$', api.UserToken.as_view(), name='user-token'),
url(r'^v1/profile/$', api.UserProfile.as_view(), name='user-profile'), url(r'^v1/profile/$', api.UserProfile.as_view(), name='user-profile'),
url(r'^v1/auth/$', api.UserAuthApi.as_view(), name='user-auth'),
url(r'^v1/users/(?P<pk>\d+)/password/reset/$', api.UserResetPasswordApi.as_view(), name='user-reset-password'), url(r'^v1/users/(?P<pk>\d+)/password/reset/$', api.UserResetPasswordApi.as_view(), name='user-reset-password'),
url(r'^v1/users/(?P<pk>\d+)/public-key/reset/$', api.UserResetPKApi.as_view(), name='user-public-key-reset'), url(r'^v1/users/(?P<pk>\d+)/public-key/reset/$', api.UserResetPKApi.as_view(), name='user-public-key-reset'),
url(r'^v1/users/(?P<pk>\d+)/public-key/update/$', api.UserUpdatePKApi.as_view(), name='user-public-key-update'), url(r'^v1/users/(?P<pk>\d+)/public-key/update/$', api.UserUpdatePKApi.as_view(), name='user-public-key-update'),
......
...@@ -180,8 +180,8 @@ def send_reset_ssh_key_mail(user): ...@@ -180,8 +180,8 @@ def send_reset_ssh_key_mail(user):
def check_user_valid(**kwargs): def check_user_valid(**kwargs):
password = kwargs.pop('password', None) password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None) public_key = kwargs.pop('public_key', None)
email = kwargs.pop('email') email = kwargs.pop('email', None)
username = kwargs.pop('username') username = kwargs.pop('username', None)
if username: if username:
user = get_object_or_none(User, username=username) user = get_object_or_none(User, username=username)
...@@ -206,24 +206,23 @@ def check_user_valid(**kwargs): ...@@ -206,24 +206,23 @@ def check_user_valid(**kwargs):
elif len(public_key_saved) > 1: elif len(public_key_saved) > 1:
if public_key == public_key_saved[1]: if public_key == public_key_saved[1]:
return user, '' return user, ''
return None, _('Passowrd or SSH public key invalid') return None, _('Password or SSH public key invalid')
def refresh_token(token, user): def refresh_token(token, user, expiration=3600):
expiration = settings.CONFIG.TOKEN_EXPIRATION or 3600
cache.set(token, user.id, expiration) cache.set(token, user.id, expiration)
def generate_token(request): def generate_token(request, user):
expiration = settings.CONFIG.TOKEN_EXPIRATION or 3600 expiration = settings.CONFIG.TOKEN_EXPIRATION or 3600
remote_addr = request.META.get('REMOTE_ADDR', '') remote_addr = request.META.get('REMOTE_ADDR', '')
remote_addr = base64.b16encode(remote_addr).replace('=', '') remote_addr = base64.b16encode(remote_addr).replace('=', '')
token = cache.get('%s_%s' % (request.user.id, remote_addr)) token = cache.get('%s_%s' % (user.id, remote_addr))
if not token: if not token:
token = uuid.uuid4().get_hex() token = uuid.uuid4().get_hex()
print('Set cache: %s' % token) print('Set cache: %s' % token)
cache.set(token, request.user.id, expiration) cache.set(token, user.id, expiration)
cache.set('%s_%s' % (request.user.id, remote_addr), token, expiration) cache.set('%s_%s' % (user.id, remote_addr), token, expiration)
return token return token
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment