Commit 83c0ed36 authored by ibuler's avatar ibuler

Update sdk

parent d9d0c4ba
...@@ -7,6 +7,8 @@ from .config import Config ...@@ -7,6 +7,8 @@ from .config import Config
from .sshd import SSHServer from .sshd import SSHServer
from .ws import WSServer from .ws import WSServer
from .logging import create_logger from .logging import create_logger
from .sdk import AppService
from .auth import AppAccessKey
__version__ = '0.4.0' __version__ = '0.4.0'
...@@ -46,6 +48,7 @@ class Coco: ...@@ -46,6 +48,7 @@ class Coco:
self.name = name self.name = name
self.lock = threading.Lock() self.lock = threading.Lock()
self.stop_evt = threading.Event() self.stop_evt = threading.Event()
self.service = None
if name is None: if name is None:
self.name = self.config['NAME'] self.name = self.config['NAME']
...@@ -62,6 +65,7 @@ class Coco: ...@@ -62,6 +65,7 @@ class Coco:
def prepare(self): def prepare(self):
self.sshd = SSHServer(self) self.sshd = SSHServer(self)
self.ws = WSServer(self) self.ws = WSServer(self)
self.initial_service()
def heartbeat(self): def heartbeat(self):
pass pass
...@@ -117,6 +121,9 @@ class Coco: ...@@ -117,6 +121,9 @@ class Coco:
except: except:
pass pass
def initial_service(self):
self.service = AppService(self)
def monitor_session(self): def monitor_session(self):
pass pass
...@@ -3,14 +3,26 @@ ...@@ -3,14 +3,26 @@
# #
import os import os
import six
import logging import logging
from io import IOBase import time
from . import utils from . import utils
from .exception import LoadAccessKeyError from .exception import LoadAccessKeyError
def make_signature(access_key_secret, date=None):
if isinstance(date, bytes):
date = date.decode("utf-8")
if isinstance(date, int):
date_gmt = utils.http_date(date)
elif date is None:
date_gmt = utils.http_date(int(time.time()))
else:
date_gmt = date
data = str(access_key_secret) + "\n" + date_gmt
return utils.content_md5(data)
class AccessKeyAuth(object): class AccessKeyAuth(object):
def __init__(self, access_key_id, access_key_secret): def __init__(self, access_key_id, access_key_secret):
self.id = access_key_id self.id = access_key_id
...@@ -48,7 +60,8 @@ class SessionAuth(object): ...@@ -48,7 +60,8 @@ class SessionAuth(object):
class Auth(object): class Auth(object):
def __init__(self, token=None, access_key_id=None, access_key_secret=None, def __init__(self, token=None, access_key_id=None,
access_key_secret=None,
session_id=None, csrf_token=None): session_id=None, csrf_token=None):
if token is not None: if token is not None:
...@@ -58,7 +71,7 @@ class Auth(object): ...@@ -58,7 +71,7 @@ class Auth(object):
elif session_id and csrf_token: elif session_id and csrf_token:
self.instance = SessionAuth(session_id, csrf_token) self.instance = SessionAuth(session_id, csrf_token)
else: else:
raise OSError('Need token or access_key_id, access_key_secret ' raise SyntaxError('Need token or access_key_id, access_key_secret '
'or session_id, csrf_token') 'or session_id, csrf_token')
def sign_request(self, req): def sign_request(self, req):
...@@ -70,22 +83,27 @@ class AccessKey(object): ...@@ -70,22 +83,27 @@ class AccessKey(object):
self.id = id self.id = id
self.secret = secret self.secret = secret
def clean(self, value, delimiter=':', silent=False): @staticmethod
def clean(value, delimiter=':', silent=False):
try: try:
self.id, self.secret = value.split(delimiter) id, secret = value.split(delimiter)
except (AttributeError, ValueError) as e: except (AttributeError, ValueError) as e:
if not silent: if not silent:
raise LoadAccessKeyError(e) raise LoadAccessKeyError(e)
return '', ''
else: else:
return ':'.join([self.id, self.secret]) return id, secret
def load_from_env(self, env, delimiter=':', silent=False): @classmethod
def load_from_env(cls, env, **kwargs):
value = os.environ.get(env) value = os.environ.get(env)
return self.clean(value, delimiter, silent) id, secret = cls.clean(value, **kwargs)
return cls(id, secret)
def load_from_f(self, f, delimiter=':', silent=False): @classmethod
def load_from_f(cls, f, **kwargs):
value = '' value = ''
if isinstance(f, six.string_types) and os.path.isfile(f): if isinstance(f, str) and os.path.isfile(f):
f = open(f) f = open(f)
if hasattr(f, 'read'): if hasattr(f, 'read'):
for line in f: for line in f:
...@@ -93,10 +111,11 @@ class AccessKey(object): ...@@ -93,10 +111,11 @@ class AccessKey(object):
value = line.strip() value = line.strip()
break break
f.close() f.close()
return self.clean(value, delimiter, silent) id, secret = cls.clean(value, **kwargs)
return cls(id, secret)
def save_to_f(self, f, silent=False): def save_to_f(self, f, silent=False):
if isinstance(f, six.string_types): if isinstance(f, str):
f = open(f, 'w') f = open(f, 'w')
try: try:
f.write(str('{0}:{1}'.format(self.id, self.secret))) f.write(str('{0}:{1}'.format(self.id, self.secret)))
...@@ -113,7 +132,6 @@ class AccessKey(object): ...@@ -113,7 +132,6 @@ class AccessKey(object):
def __str__(self): def __str__(self):
return '{0}:{1}'.format(self.id, self.secret) return '{0}:{1}'.format(self.id, self.secret)
__repr__ = __str__ __repr__ = __str__
...@@ -133,7 +151,7 @@ class ServiceAccessKey(AccessKey): ...@@ -133,7 +151,7 @@ class ServiceAccessKey(AccessKey):
default_key_store = os.path.join(os.environ.get('HOME', ''), '.access_key') default_key_store = os.path.join(os.environ.get('HOME', ''), '.access_key')
def __init__(self, id=None, secret=None, config=None): def __init__(self, id=None, secret=None, config=None):
super(ServiceAccessKey, self).__init__(id=id, secret=secret) super().__init__(id=id, secret=secret)
self.config = config or {} self.config = config or {}
self._key_store = None self._key_store = None
self._key_env = None self._key_env = None
......
...@@ -263,20 +263,4 @@ class Config(dict): ...@@ -263,20 +263,4 @@ class Config(dict):
return '<%s %s>' % (self.__class__.__name__, dict.__repr__(self)) return '<%s %s>' % (self.__class__.__name__, dict.__repr__(self))
API_URL_MAPPING = {
'terminal-register': '/api/applications/v1/terminal/register/',
'terminal-heatbeat': '/api/applications/v1/terminal/heatbeat/',
'send-proxy-log': '/api/audits/v1/proxy-log/receive/',
'finish-proxy-log': '/api/audits/v1/proxy-log/%s/',
'send-command-log': '/api/audits/v1/command-log/',
'send-record-log': '/api/audits/v1/record-log/',
'user-auth': '/api/users/v1/auth/',
'user-assets': '/api/perms/v1/user/%s/assets/',
'user-asset-groups': '/api/perms/v1/user/%s/asset-groups/',
'user-asset-groups-assets': '/api/perms/v1/user/my/asset-groups-assets/',
'assets-of-group': '/api/perms/v1/user/my/asset-group/%s/assets/',
'my-profile': '/api/users/v1/profile/',
'system-user-auth-info': '/api/assets/v1/system-user/%s/auth-info/',
'validate-user-asset-permission':
'/api/perms/v1/asset-permission/user/validate/',
}
...@@ -12,3 +12,6 @@ class LoadAccessKeyError(Exception): ...@@ -12,3 +12,6 @@ class LoadAccessKeyError(Exception):
class RequestError(Exception): class RequestError(Exception):
pass pass
class ResponseError(Exception):
pass
...@@ -19,7 +19,7 @@ class ProxyServer: ...@@ -19,7 +19,7 @@ class ProxyServer:
def __init__(self, app, client): def __init__(self, app, client):
self.app = app self.app = app
self.client = client self.client = client
self.request = client.request self.request = client.do
self.server = None self.server = None
self.connecting = True self.connecting = True
......
...@@ -18,7 +18,7 @@ class InteractiveServer: ...@@ -18,7 +18,7 @@ class InteractiveServer:
def __init__(self, app, client): def __init__(self, app, client):
self.app = app self.app = app
self.client = client self.client = client
self.request = client.request self.request = client.do
def display_banner(self): def display_banner(self):
self.client.send(char.CLEAR_CHAR) self.client.send(char.CLEAR_CHAR)
......
...@@ -11,15 +11,32 @@ import requests ...@@ -11,15 +11,32 @@ import requests
from requests.structures import CaseInsensitiveDict from requests.structures import CaseInsensitiveDict
from cachetools import cached, TTLCache from cachetools import cached, TTLCache
from .auth import Auth, ServiceAccessKey from .auth import Auth, ServiceAccessKey, AccessKey
from .utils import sort_assets, PKey, to_dotmap, timestamp_to_datetime_str from .utils import sort_assets, PKey, timestamp_to_datetime_str
from .exception import RequestError, LoadAccessKeyError from .exception import RequestError, LoadAccessKeyError, ResponseError
from .config import API_URL_MAPPING
_USER_AGENT = 'jms-sdk-py' _USER_AGENT = 'jms-sdk-py'
CACHED_TTL = os.environ.get('CACHED_TTL', 30) CACHED_TTL = os.environ.get('CACHED_TTL', 30)
API_URL_MAPPING = {
'terminal-register': '/api/applications/v1/terminal/register/',
'terminal-heatbeat': '/api/applications/v1/terminal/heatbeat/',
'send-proxy-log': '/api/audits/v1/proxy-log/receive/',
'finish-proxy-log': '/api/audits/v1/proxy-log/%s/',
'send-command-log': '/api/audits/v1/command-log/',
'send-record-log': '/api/audits/v1/record-log/',
'user-auth': '/api/users/v1/auth/',
'user-assets': '/api/perms/v1/user/%s/assets/',
'user-asset-groups': '/api/perms/v1/user/%s/asset-groups/',
'user-asset-groups-assets': '/api/perms/v1/user/my/asset-groups-assets/',
'assets-of-group': '/api/perms/v1/user/my/asset-group/%s/assets/',
'my-profile': '/api/users/v1/profile/',
'system-user-auth-info': '/api/assets/v1/system-user/%s/auth-info/',
'validate-user-asset-permission':
'/api/perms/v1/asset-permission/user/validate/',
}
class FakeResponse(object): class FakeResponse(object):
def __init__(self): def __init__(self):
...@@ -31,118 +48,96 @@ class FakeResponse(object): ...@@ -31,118 +48,96 @@ class FakeResponse(object):
class Request(object): class Request(object):
func_mapping = { methods = {
'get': requests.get, 'get': requests.get,
'post': requests.post, 'post': requests.post,
'patch': requests.patch, 'patch': requests.patch,
'put': requests.put, 'put': requests.put,
} }
def __init__(self, url, method='get', data=None, params=None, headers=None, def __init__(self, url, method='get', data=None, params=None,
content_type='application/json', app_name=''): headers=None, content_type='application/json'):
self.url = url self.url = url
self.method = method self.method = method
self.params = params or {} self.params = params or {}
self.result = None
if not isinstance(headers, dict): if not isinstance(headers, dict):
headers = {} headers = {}
self.headers = CaseInsensitiveDict(headers) self.headers = CaseInsensitiveDict(headers)
self.headers['Content-Type'] = content_type self.headers['Content-Type'] = content_type
if data is None: if data is None:
data = {} data = {}
self.data = json.dumps(data) self.data = json.dumps(data)
if 'User-Agent' not in self.headers: def do(self):
if app_name: result = self.methods.get(self.method)(
self.headers['User-Agent'] = _USER_AGENT + '/' + app_name
else:
self.headers['User-Agent'] = _USER_AGENT
def request(self):
self.result = self.func_mapping.get(self.method)(
url=self.url, headers=self.headers, url=self.url, headers=self.headers,
data=self.data, data=self.data, params=self.params)
params=self.params) return result
print(self.headers)
return self.result
class ApiRequest(object): class AppRequest(object):
api_url_mapping = API_URL_MAPPING
def __init__(self, app_name, endpoint, auth=None): def __init__(self, endpoint, auth=None):
self.app_name = app_name
self._auth = auth self._auth = auth
self.req = None
self.endpoint = endpoint self.endpoint = endpoint
@staticmethod @staticmethod
def parse_result(result): def clean_result(resp):
if resp.status_code >= 400:
return ResponseError("Response code is {0.code}: {0.text}".format(resp))
try: try:
content = result.json() result = resp.json()
except ValueError: except json.JSONDecodeError:
content = {'error': 'We only support json response'} return RequestError("Response json couldn't be decode: {0.text}".format(resp))
logging.warning(result.content) else:
logging.warning(content) return result
except AttributeError:
content = {'error': 'Request error'} def do(self, api_name=None, pk=None, method='get', use_auth=True,
return result, content
def request(self, api_name=None, pk=None, method='get', use_auth=True,
data=None, params=None, content_type='application/json'): data=None, params=None, content_type='application/json'):
if api_name in self.api_url_mapping: if api_name in API_URL_MAPPING:
path = self.api_url_mapping.get(api_name) path = API_URL_MAPPING.get(api_name)
if pk and '%s' in path: if pk and '%s' in path:
path = path % pk path = path % pk
else: else:
path = '/' path = '/'
url = self.endpoint.rstrip('/') + path url = self.endpoint.rstrip('/') + path
print(url) req = Request(url, method=method, data=data,
self.req = req = Request(url, method=method, data=data, params=params, content_type=content_type)
params=params, content_type=content_type,
app_name=self.app_name)
if use_auth: if use_auth:
if not self._auth: if not self._auth:
raise RequestError('Authentication required') raise RequestError('Authentication required')
else: else:
self._auth.sign_request(req) self._auth.sign_request(req)
try: try:
result = req.request() resp = req.do()
if result.status_code > 500: except (requests.ConnectionError, requests.ConnectTimeout) as e:
logging.warning('Server internal error') return RequestError("Connect endpoint: {} {}".format(self.endpoint, e))
except (requests.ConnectionError, requests.ConnectTimeout):
result = FakeResponse() return self.clean_result(resp)
logging.warning('Connect endpoint: {} error'.format(self.endpoint))
return self.parse_result(result)
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
kwargs['method'] = 'get' kwargs['method'] = 'get'
print("+"* 10) return self.do(*args, **kwargs)
print(*args)
print("+"* 10)
# print(**kwargs)
print("+"* 10)
return self.request(*args, **kwargs)
def post(self, *args, **kwargs): def post(self, *args, **kwargs):
kwargs['method'] = 'post' kwargs['method'] = 'post'
return self.request(*args, **kwargs) return self.do(*args, **kwargs)
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
kwargs['method'] = 'put' kwargs['method'] = 'put'
return self.request(*args, **kwargs) return self.do(*args, **kwargs)
def patch(self, *args, **kwargs): def patch(self, *args, **kwargs):
kwargs['method'] = 'patch' kwargs['method'] = 'patch'
return self.request(*args, **kwargs) return self.do(*args, **kwargs)
class AppService(ApiRequest): class AppService:
"""使用该类和Jumpserver api进行通信,将terminal用到的常用api进行了封装, """使用该类和Jumpserver api进行通信,将terminal用到的常用api进行了封装,
直接调用方法即可. 直接调用方法即可.
from jms import AppService from jms import AppService
...@@ -172,14 +167,50 @@ class AppService(ApiRequest): ...@@ -172,14 +167,50 @@ class AppService(ApiRequest):
""" """
access_key_class = ServiceAccessKey access_key_class = ServiceAccessKey
def __init__(self, app_name, endpoint, auth=None, config=None): def __init__(self, app):
super(AppService, self).__init__(app_name, endpoint, auth=auth) self.app = app
self.config = config # super(AppService, self).__init__(app_name, endpoint, auth=auth)
self.access_key = self.access_key_class(config=config) # self.config = config
self.user = None # self.access_key = self.access_key_class(config=config)
self.token = None self.access_key = None
self.session_id = None
self.csrf_token = None def load_access_key(self):
# Must be get access key if not register it
self.access_key = ServiceAccessKey(self).load()
if self.access_key is None:
self.register_and_wait_for_accept()
self.save_key_to_store()
def register_and_wait_for_accept(self):
"""注册Terminal, 通常第一次启动需要向Jumpserver注册
content: {
'terminal': {'id': 1, 'name': 'terminal name', ...},
'user': {
'username': 'same as terminal name',
'name': 'same as username',
},
'access_key_id': 'ACCESS KEY ID',
'access_key_secret': 'ACCESS KEY SECRET',
}
"""
r, content = self.post('terminal-register',
data={'name': self.app.name},
use_auth=False)
if r.status_code == 201:
logging.info('Your can save access_key: %s somewhere '
'or set it in config' % content['access_key_id'])
return True, to_dotmap(content)
elif r.status_code == 200:
logging.error('Terminal {} exist already, register failed'
.format(self.app_name))
else:
logging.error('Register terminal {} failed'.format(self.app_name))
return False, None
def save_key_so_store(self):
pass
def auth(self, access_key_id=None, access_key_secret=None): def auth(self, access_key_id=None, access_key_secret=None):
"""App认证, 请求api需要签名header """App认证, 请求api需要签名header
...@@ -205,32 +236,7 @@ class AppService(ApiRequest): ...@@ -205,32 +236,7 @@ class AppService(ApiRequest):
else: else:
raise LoadAccessKeyError('Load access key all failed, auth ignore') raise LoadAccessKeyError('Load access key all failed, auth ignore')
def register_terminal(self):
"""注册Terminal, 通常第一次启动需要向Jumpserver注册
content: {
'terminal': {'id': 1, 'name': 'terminal name', ...},
'user': {
'username': 'same as terminal name',
'name': 'same as username',
},
'access_key_id': 'ACCESS KEY ID',
'access_key_secret': 'ACCESS KEY SECRET',
}
"""
r, content = self.post('terminal-register',
data={'name': self.app_name},
use_auth=False)
if r.status_code == 201:
logging.info('Your can save access_key: %s somewhere '
'or set it in config' % content['access_key_id'])
return True, to_dotmap(content)
elif r.status_code == 200:
logging.error('Terminal {} exist already, register failed'
.format(self.app_name))
else:
logging.error('Register terminal {} failed'.format(self.app_name))
return False, None
def terminal_heatbeat(self): def terminal_heatbeat(self):
"""和Jumpserver维持心跳, 当Terminal断线后,jumpserver可以知晓 """和Jumpserver维持心跳, 当Terminal断线后,jumpserver可以知晓
......
...@@ -86,7 +86,7 @@ class SSHServer: ...@@ -86,7 +86,7 @@ class SSHServer:
self.dispatch(client) self.dispatch(client)
def dispatch(self, client): def dispatch(self, client):
request_type = client.request.type request_type = client.do.type
if request_type == 'pty': if request_type == 'pty':
InteractiveServer(self.app, client).activate() InteractiveServer(self.app, client).activate()
elif request_type == 'exec': elif request_type == 'exec':
......
...@@ -4,6 +4,7 @@ from __future__ import unicode_literals ...@@ -4,6 +4,7 @@ from __future__ import unicode_literals
import hashlib import hashlib
import re import re
import os
import threading import threading
import base64 import base64
import calendar import calendar
...@@ -11,12 +12,11 @@ import time ...@@ -11,12 +12,11 @@ import time
import datetime import datetime
from io import StringIO from io import StringIO
import paramiko
import pyte import pyte
import pytz import pytz
from email.utils import formatdate from email.utils import formatdate
import paramiko
from dotmap import DotMap
try: try:
...@@ -24,9 +24,6 @@ try: ...@@ -24,9 +24,6 @@ try:
except ImportError: except ImportError:
from queue import Queue, Empty from queue import Queue, Empty
from .compat import to_string, to_bytes
def ssh_key_string_to_obj(text): def ssh_key_string_to_obj(text):
key_f = StringIO(text) key_f = StringIO(text)
...@@ -258,21 +255,6 @@ def b64encode_as_string(data): ...@@ -258,21 +255,6 @@ def b64encode_as_string(data):
return to_string(base64.b64encode(data)) return to_string(base64.b64encode(data))
def make_signature(access_key_secret, date=None):
if isinstance(date, bytes):
date = date.decode("utf-8")
if isinstance(date, int):
date_gmt = http_date(date)
elif date is None:
date_gmt = http_date(int(time.time()))
else:
date_gmt = date
data = str(access_key_secret) + "\n" + date_gmt
return content_md5(data)
def split_string_int(s): def split_string_int(s):
"""Split string or int """Split string or int
...@@ -335,17 +317,6 @@ def timestamp_to_datetime_str(ts): ...@@ -335,17 +317,6 @@ def timestamp_to_datetime_str(ts):
return dt.strftime(datetime_format) return dt.strftime(datetime_format)
def to_dotmap(data):
"""将接受dict转换为DotMap"""
if isinstance(data, dict):
data = DotMap(data)
elif isinstance(data, list):
data = [DotMap(d) for d in data]
else:
raise ValueError('Dict or list type required...')
return data
class MultiQueue(Queue): class MultiQueue(Queue):
def mget(self, size=1, block=True, timeout=5): def mget(self, size=1, block=True, timeout=5):
items = [] items = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment