Unverified Commit 3c0549f4 authored by 老广's avatar 老广 Committed by GitHub

Merge pull request #60 from jumpserver/dev

Dev
parents accd0e67 d2e789de
...@@ -16,6 +16,7 @@ require ( ...@@ -16,6 +16,7 @@ require (
github.com/gliderlabs/ssh v0.2.3-0.20190711180243-866d0ddf7991 github.com/gliderlabs/ssh v0.2.3-0.20190711180243-866d0ddf7991
github.com/go-playground/form v3.1.4+incompatible // indirect github.com/go-playground/form v3.1.4+incompatible // indirect
github.com/gorilla/mux v1.7.2 github.com/gorilla/mux v1.7.2
github.com/gorilla/websocket v1.4.0
github.com/jarcoal/httpmock v1.0.4 github.com/jarcoal/httpmock v1.0.4
github.com/kataras/neffos v0.0.7 github.com/kataras/neffos v0.0.7
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
......
...@@ -9,13 +9,12 @@ import ( ...@@ -9,13 +9,12 @@ import (
"github.com/jumpserver/koko/pkg/cctx" "github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/common" "github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/config"
"github.com/jumpserver/koko/pkg/i18n"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/service" "github.com/jumpserver/koko/pkg/service"
) )
var mfaInstruction = i18n.T("Please enter 6 digits.") var mfaInstruction = "Please enter 6 digits."
var mfaQuestion = i18n.T("[MFA auth]: ") var mfaQuestion = "[MFA auth]: "
const ( const (
actionAccepted = "Accepted" actionAccepted = "Accepted"
...@@ -33,7 +32,7 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult) ...@@ -33,7 +32,7 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult)
} }
remoteAddr := strings.Split(ctx.RemoteAddr().String(), ":")[0] remoteAddr := strings.Split(ctx.RemoteAddr().String(), ":")[0]
resp, err := service.Authenticate(username, password, publicKey, remoteAddr, "T") resp, err := service.Authenticate(username, password, publicKey, remoteAddr, "ST")
if err != nil { if err != nil {
action = actionFailed action = actionFailed
logger.Infof("%s %s for %s from %s", action, authMethod, username, remoteAddr) logger.Infof("%s %s for %s from %s", action, authMethod, username, remoteAddr)
......
...@@ -76,7 +76,7 @@ func (c *Client) marshalData(data interface{}) (reader io.Reader, error error) { ...@@ -76,7 +76,7 @@ func (c *Client) marshalData(data interface{}) (reader io.Reader, error error) {
} }
func (c *Client) parseUrlQuery(url string, params []map[string]string) string { func (c *Client) parseUrlQuery(url string, params []map[string]string) string {
if len(params) != 1 { if len(params) < 1 {
return url return url
} }
var query []string var query []string
...@@ -119,7 +119,7 @@ func (c *Client) setAuthHeader(r *http.Request) { ...@@ -119,7 +119,7 @@ func (c *Client) setAuthHeader(r *http.Request) {
} }
} }
func (c *Client) SetReqHeaders(req *http.Request) { func (c *Client) SetReqHeaders(req *http.Request, params []map[string]string) {
if len(c.Headers) != 0 { if len(c.Headers) != 0 {
for k, v := range c.Headers { for k, v := range c.Headers {
req.Header.Set(k, v) req.Header.Set(k, v)
...@@ -128,8 +128,13 @@ func (c *Client) SetReqHeaders(req *http.Request) { ...@@ -128,8 +128,13 @@ func (c *Client) SetReqHeaders(req *http.Request) {
if req.Header.Get("Content-Type") == "" { if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
} }
req.Header.Set("user-Agent", "koko-client") req.Header.Set("User-Agent", "koko-client")
c.setAuthHeader(req) c.setAuthHeader(req)
if len(params) >= 2 {
for k, v := range params[1] {
req.Header.Set(k, v)
}
}
} }
func (c *Client) NewRequest(method, url string, body interface{}, params []map[string]string) (req *http.Request, err error) { func (c *Client) NewRequest(method, url string, body interface{}, params []map[string]string) (req *http.Request, err error) {
...@@ -139,16 +144,16 @@ func (c *Client) NewRequest(method, url string, body interface{}, params []map[s ...@@ -139,16 +144,16 @@ func (c *Client) NewRequest(method, url string, body interface{}, params []map[s
return return
} }
req, err = http.NewRequest(method, url, reader) req, err = http.NewRequest(method, url, reader)
c.SetReqHeaders(req) c.SetReqHeaders(req, params)
return req, err return req, err
} }
// Do wrapper http.Client Do() for using auth and error handle // Do wrapper http.Client Do() for using auth and error handle
// params: // params:
// 1. query string if set {"name": "ibuler"} // 1. query string if set {"name": "ibuler"}
func (c *Client) Do(method, url string, data, res interface{}, params ...map[string]string) (err error) { func (c *Client) Do(method, url string, data, res interface{}, params ...map[string]string) (resp *http.Response, err error) {
req, err := c.NewRequest(method, url, data, params) req, err := c.NewRequest(method, url, data, params)
resp, err := c.http.Do(req) resp, err = c.http.Do(req)
if err != nil { if err != nil {
return return
} }
...@@ -167,7 +172,7 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str ...@@ -167,7 +172,7 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str
return return
} }
// Unmarshal response body to result struct // Unmarshal response body to result struct
if res != nil { if res != nil && resp.StatusCode >= 200 && resp.StatusCode < 300 {
err = json.Unmarshal(body, res) err = json.Unmarshal(body, res)
if err != nil { if err != nil {
msg := fmt.Sprintf("%s %s failed, unmarshal '%s' response failed: %s", req.Method, req.URL, body[:12], err) msg := fmt.Sprintf("%s %s failed, unmarshal '%s' response failed: %s", req.Method, req.URL, body[:12], err)
...@@ -178,23 +183,23 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str ...@@ -178,23 +183,23 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str
return return
} }
func (c *Client) Get(url string, res interface{}, params ...map[string]string) (err error) { func (c *Client) Get(url string, res interface{}, params ...map[string]string) (resp *http.Response, err error) {
return c.Do("GET", url, nil, res, params...) return c.Do("GET", url, nil, res, params...)
} }
func (c *Client) Post(url string, data interface{}, res interface{}, params ...map[string]string) (err error) { func (c *Client) Post(url string, data interface{}, res interface{}, params ...map[string]string) (resp *http.Response, err error) {
return c.Do("POST", url, data, res, params...) return c.Do("POST", url, data, res, params...)
} }
func (c *Client) Delete(url string, res interface{}, params ...map[string]string) (err error) { func (c *Client) Delete(url string, res interface{}, params ...map[string]string) (resp *http.Response, err error) {
return c.Do("DELETE", url, nil, res, params...) return c.Do("DELETE", url, nil, res, params...)
} }
func (c *Client) Put(url string, data interface{}, res interface{}, params ...map[string]string) (err error) { func (c *Client) Put(url string, data interface{}, res interface{}, params ...map[string]string) (resp *http.Response, err error) {
return c.Do("PUT", url, data, res, params...) return c.Do("PUT", url, data, res, params...)
} }
func (c *Client) Patch(url string, data interface{}, res interface{}, params ...map[string]string) (err error) { func (c *Client) Patch(url string, data interface{}, res interface{}, params ...map[string]string) (resp *http.Response, err error) {
return c.Do("PATCH", url, data, res, params...) return c.Do("PATCH", url, data, res, params...)
} }
...@@ -248,7 +253,7 @@ func (c *Client) UploadFile(url string, gFile string, res interface{}, params .. ...@@ -248,7 +253,7 @@ func (c *Client) UploadFile(url string, gFile string, res interface{}, params ..
url = c.parseUrl(url, params) url = c.parseUrl(url, params)
req, err := http.NewRequest("POST", url, buf) req, err := http.NewRequest("POST", url, buf)
req.Header.Set("Content-Type", bodyWriter.FormDataContentType()) req.Header.Set("Content-Type", bodyWriter.FormDataContentType())
c.SetReqHeaders(req) c.SetReqHeaders(req, params)
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
return return
......
...@@ -22,29 +22,6 @@ import ( ...@@ -22,29 +22,6 @@ import (
"github.com/jumpserver/koko/pkg/utils" "github.com/jumpserver/koko/pkg/utils"
) )
type assetsCacheContainer struct {
mapData map[string][]model.Asset
lock *sync.RWMutex
}
func (c *assetsCacheContainer) Get(key string) ([]model.Asset, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
value, ok := c.mapData[key]
return value, ok
}
func (c *assetsCacheContainer) SetValue(key string, value []model.Asset) {
c.lock.Lock()
defer c.lock.Unlock()
c.mapData[key] = value
}
var userAssetsCached = assetsCacheContainer{
mapData: make(map[string][]model.Asset),
lock: new(sync.RWMutex),
}
func SessionHandler(sess ssh.Session) { func SessionHandler(sess ssh.Session) {
pty, _, ok := sess.Pty() pty, _, ok := sess.Pty()
if ok { if ok {
...@@ -98,7 +75,7 @@ func (h *interactiveHandler) Initial() { ...@@ -98,7 +75,7 @@ func (h *interactiveHandler) Initial() {
} }
func (h *interactiveHandler) loadAssetsFromCache() { func (h *interactiveHandler) loadAssetsFromCache() {
if assets, ok := userAssetsCached.Get(h.user.ID); ok { if assets, ok := service.GetUserAssetsFromCache(h.user.ID); ok {
h.assets = assets h.assets = assets
close(h.assetDataLoaded) close(h.assetDataLoaded)
} else { } else {
...@@ -109,8 +86,8 @@ func (h *interactiveHandler) loadAssetsFromCache() { ...@@ -109,8 +86,8 @@ func (h *interactiveHandler) loadAssetsFromCache() {
func (h *interactiveHandler) firstLoadAssetAndNodes() { func (h *interactiveHandler) firstLoadAssetAndNodes() {
h.loadUserAssets("1") h.loadUserAssets("1")
h.loadUserAssetNodes("1") h.loadUserNodes("1")
logger.Debug("first Load Asset And Nodes done") logger.Debug("First load assets and nodes done")
close(h.nodeDataLoaded) close(h.nodeDataLoaded)
select { select {
case <-h.assetDataLoaded: case <-h.assetDataLoaded:
...@@ -198,6 +175,9 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) { ...@@ -198,6 +175,9 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) {
} }
default: default:
switch { switch {
case line == "exit", line == "quit":
logger.Info("exit session")
return
case strings.Index(line, "/") == 0: case strings.Index(line, "/") == 0:
searchWord := strings.TrimSpace(line[1:]) searchWord := strings.TrimSpace(line[1:])
assets := h.searchAsset(searchWord) assets := h.searchAsset(searchWord)
...@@ -315,7 +295,7 @@ func (h *interactiveHandler) displayNodes(nodes []model.Node) { ...@@ -315,7 +295,7 @@ func (h *interactiveHandler) displayNodes(nodes []model.Node) {
func (h *interactiveHandler) refreshAssetsAndNodesData() { func (h *interactiveHandler) refreshAssetsAndNodesData() {
h.loadUserAssets("2") h.loadUserAssets("2")
h.loadUserAssetNodes("2") h.loadUserNodes("2")
_, err := io.WriteString(h.term, i18n.T("Refresh done")+"\n\r") _, err := io.WriteString(h.term, i18n.T("Refresh done")+"\n\r")
if err != nil { if err != nil {
logger.Error("refresh Assets Nodes err:", err) logger.Error("refresh Assets Nodes err:", err)
...@@ -324,14 +304,15 @@ func (h *interactiveHandler) refreshAssetsAndNodesData() { ...@@ -324,14 +304,15 @@ func (h *interactiveHandler) refreshAssetsAndNodesData() {
func (h *interactiveHandler) loadUserAssets(cachePolicy string) { func (h *interactiveHandler) loadUserAssets(cachePolicy string) {
assets := service.GetUserAssets(h.user.ID, cachePolicy, "") assets := service.GetUserAssets(h.user.ID, cachePolicy, "")
userAssetsCached.SetValue(h.user.ID, assets)
h.mu.Lock() h.mu.Lock()
h.assets = assets h.assets = assets
h.mu.Unlock() h.mu.Unlock()
} }
func (h *interactiveHandler) loadUserAssetNodes(cachePolicy string) { func (h *interactiveHandler) loadUserNodes(cachePolicy string) {
h.mu.Lock()
h.nodes = service.GetUserNodes(h.user.ID, cachePolicy) h.nodes = service.GetUserNodes(h.user.ID, cachePolicy)
h.mu.Unlock()
} }
func (h *interactiveHandler) searchAsset(key string) (assets []model.Asset) { func (h *interactiveHandler) searchAsset(key string) (assets []model.Asset) {
...@@ -359,23 +340,16 @@ func (h *interactiveHandler) searchAsset(key string) (assets []model.Asset) { ...@@ -359,23 +340,16 @@ func (h *interactiveHandler) searchAsset(key string) (assets []model.Asset) {
assets = append(assets, assetValue) assets = append(assets, assetValue)
} }
} }
// assetsData, _ := Cached.Load(h.user.ID)
// for _, assetValue := range assetsData.([]model.Asset) {
// if isSubstring([]string{assetValue.IP, assetValue.Hostname, assetValue.Comment}, key) {
// assets = append(assets, assetValue)
// }
// }
return assets return assets
} }
func (h *interactiveHandler) searchNodeAssets(num int) (assets []model.Asset) { func (h *interactiveHandler) searchNodeAssets(num int) (assets model.AssetList) {
if num > len(h.nodes) || num == 0 { if num > len(h.nodes) || num == 0 {
return assets return assets
} }
return h.nodes[num-1].AssetsGranted node := h.nodes[num-1]
assets = service.GetUserNodeAssets(h.user.ID, node.ID, "1")
return
} }
func (h *interactiveHandler) Proxy(ctx context.Context) { func (h *interactiveHandler) Proxy(ctx context.Context) {
...@@ -396,7 +370,7 @@ func ConstructAssetNodeTree(assetNodes []model.Node) treeprint.Tree { ...@@ -396,7 +370,7 @@ func ConstructAssetNodeTree(assetNodes []model.Node) treeprint.Tree {
tree := treeprint.New() tree := treeprint.New()
for i := 0; i < len(assetNodes); i++ { for i := 0; i < len(assetNodes); i++ {
r := strings.LastIndex(assetNodes[i].Key, ":") r := strings.LastIndex(assetNodes[i].Key, ":")
if r < 0 { if _, ok := treeMap[assetNodes[i].Key[:r]]; r < 0 || !ok {
subtree := tree.AddBranch(fmt.Sprintf("%s.%s(%s)", subtree := tree.AddBranch(fmt.Sprintf("%s.%s(%s)",
strconv.Itoa(i+1), assetNodes[i].Name, strconv.Itoa(i+1), assetNodes[i].Name,
strconv.Itoa(assetNodes[i].AssetsAmount))) strconv.Itoa(assetNodes[i].AssetsAmount)))
...@@ -409,7 +383,6 @@ func ConstructAssetNodeTree(assetNodes []model.Node) treeprint.Tree { ...@@ -409,7 +383,6 @@ func ConstructAssetNodeTree(assetNodes []model.Node) treeprint.Tree {
strconv.Itoa(assetNodes[i].AssetsAmount))) strconv.Itoa(assetNodes[i].AssetsAmount)))
treeMap[assetNodes[i].Key] = nodeTree treeMap[assetNodes[i].Key] = nodeTree
} }
} }
return tree return tree
} }
......
...@@ -76,7 +76,6 @@ func (u *UserVolume) Info(path string) (elfinder.FileDir, error) { ...@@ -76,7 +76,6 @@ func (u *UserVolume) Info(path string) (elfinder.FileDir, error) {
} }
if filename == "." { if filename == "." {
filename = originFileInfo.Name() filename = originFileInfo.Name()
fmt.Println("askldkasdlala")
} }
rest.Name = filename rest.Name = filename
rest.Hash = hashPath(u.Uuid, filepath.Join(dirPath, filename)) rest.Hash = hashPath(u.Uuid, filepath.Join(dirPath, filename))
......
package httpd
import (
"net"
"net/http"
"sync"
"time"
"github.com/kataras/neffos"
gorilla "github.com/gorilla/websocket"
)
// DefaultUpgrader is a gorilla/websocket Upgrader with all fields set to the default values.
var DefaultUpgrader = Upgrader(gorilla.Upgrader{})
// Upgrader is a `neffos.Upgrader` type for the gorilla/websocket subprotocol implementation.
// Should be used on `New` to construct the neffos server.
func Upgrader(upgrader gorilla.Upgrader) neffos.Upgrader {
return func(w http.ResponseWriter, r *http.Request) (neffos.Socket, error) {
header := w.Header()
header.Set("Access-Control-Allow-Origin", "*")
underline, err := upgrader.Upgrade(w, r, header)
if err != nil {
return nil, err
}
return newSocket(underline, r, false), nil
}
}
// Socket completes the `neffos.Socket` interface,
// it describes the underline websocket connection.
type Socket struct {
UnderlyingConn *gorilla.Conn
request *http.Request
client bool
mu sync.Mutex
}
func newSocket(underline *gorilla.Conn, request *http.Request, client bool) *Socket {
return &Socket{
UnderlyingConn: underline,
request: request,
client: client,
}
}
// NetConn returns the underline net connection.
func (s *Socket) NetConn() net.Conn {
return s.UnderlyingConn.UnderlyingConn()
}
// Request returns the http request value.
func (s *Socket) Request() *http.Request {
return s.request
}
// ReadData reads binary or text messages from the remote connection.
func (s *Socket) ReadData(timeout time.Duration) ([]byte, error) {
for {
if timeout > 0 {
s.UnderlyingConn.SetReadDeadline(time.Now().Add(timeout))
}
opCode, data, err := s.UnderlyingConn.ReadMessage()
if err != nil {
return nil, err
}
if opCode != gorilla.BinaryMessage && opCode != gorilla.TextMessage {
// if gorilla.IsUnexpectedCloseError(err, gorilla.CloseGoingAway) ...
continue
}
return data, err
}
}
// WriteBinary sends a binary message to the remote connection.
func (s *Socket) WriteBinary(body []byte, timeout time.Duration) error {
return s.write(body, gorilla.BinaryMessage, timeout)
}
// WriteText sends a text message to the remote connection.
func (s *Socket) WriteText(body []byte, timeout time.Duration) error {
return s.write(body, gorilla.TextMessage, timeout)
}
func (s *Socket) write(body []byte, opCode int, timeout time.Duration) error {
if timeout > 0 {
s.UnderlyingConn.SetWriteDeadline(time.Now().Add(timeout))
}
s.mu.Lock()
err := s.UnderlyingConn.WriteMessage(opCode, body)
s.mu.Unlock()
return err
}
...@@ -150,14 +150,13 @@ type Domain struct { ...@@ -150,14 +150,13 @@ type Domain struct {
} }
type Node struct { type Node struct {
ID string `json:"id"` ID string `json:"id"`
Key string `json:"key"` Key string `json:"key"`
Name string `json:"name"` Name string `json:"name"`
Value string `json:"value"` Value string `json:"value"`
Parent string `json:"parent"` Parent string `json:"parent"`
AssetsGranted []Asset `json:"assets_granted"` AssetsAmount int `json:"assets_amount"`
AssetsAmount int `json:"assets_amount"` OrgID string `json:"org_id"`
OrgID string `json:"org_id"`
} }
type nodeSortBy func(node1, node2 *Node) bool type nodeSortBy func(node1, node2 *Node) bool
...@@ -207,7 +206,6 @@ func keySort(node1, node2 *Node) bool { ...@@ -207,7 +206,6 @@ func keySort(node1, node2 *Node) bool {
} else { } else {
return true return true
} }
} }
return true return true
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.SystemUserAuthInfo) { func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.SystemUserAuthInfo) {
Url := fmt.Sprintf(SystemUserAssetAuthURL, systemUserID, assetID) Url := fmt.Sprintf(SystemUserAssetAuthURL, systemUserID, assetID)
err := authClient.Get(Url, &info) _, err := authClient.Get(Url, &info)
if err != nil { if err != nil {
logger.Error("Get system user Asset auth info failed") logger.Error("Get system user Asset auth info failed")
} }
...@@ -59,7 +59,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt ...@@ -59,7 +59,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
]`*/ ]`*/
Url := fmt.Sprintf(SystemUserCmdFilterRulesListURL, systemUserID) Url := fmt.Sprintf(SystemUserCmdFilterRulesListURL, systemUserID)
err = authClient.Get(Url, &rules) _, err = authClient.Get(Url, &rules)
if err != nil { if err != nil {
logger.Error("Get system user auth info failed") logger.Error("Get system user auth info failed")
} }
...@@ -68,7 +68,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt ...@@ -68,7 +68,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
func GetSystemUser(systemUserID string) (info model.SystemUser) { func GetSystemUser(systemUserID string) (info model.SystemUser) {
Url := fmt.Sprintf(SystemUserDetailURL, systemUserID) Url := fmt.Sprintf(SystemUserDetailURL, systemUserID)
err := authClient.Get(Url, &info) _, err := authClient.Get(Url, &info)
if err != nil { if err != nil {
logger.Errorf("Get system user %s failed", systemUserID) logger.Errorf("Get system user %s failed", systemUserID)
} }
...@@ -77,7 +77,7 @@ func GetSystemUser(systemUserID string) (info model.SystemUser) { ...@@ -77,7 +77,7 @@ func GetSystemUser(systemUserID string) (info model.SystemUser) {
func GetAsset(assetID string) (asset model.Asset) { func GetAsset(assetID string) (asset model.Asset) {
Url := fmt.Sprintf(AssetDetailURL, assetID) Url := fmt.Sprintf(AssetDetailURL, assetID)
err := authClient.Get(Url, &asset) _, err := authClient.Get(Url, &asset)
if err != nil { if err != nil {
logger.Errorf("Get Asset %s failed\n", assetID) logger.Errorf("Get Asset %s failed\n", assetID)
} }
...@@ -86,7 +86,7 @@ func GetAsset(assetID string) (asset model.Asset) { ...@@ -86,7 +86,7 @@ func GetAsset(assetID string) (asset model.Asset) {
func GetDomainWithGateway(gID string) (domain model.Domain) { func GetDomainWithGateway(gID string) (domain model.Domain) {
url := fmt.Sprintf(DomainDetailURL, gID) url := fmt.Sprintf(DomainDetailURL, gID)
err := authClient.Get(url, &domain) _, err := authClient.Get(url, &domain)
if err != nil { if err != nil {
logger.Errorf("Get domain %s failed: %s", gID, err) logger.Errorf("Get domain %s failed: %s", gID, err)
} }
...@@ -95,7 +95,7 @@ func GetDomainWithGateway(gID string) (domain model.Domain) { ...@@ -95,7 +95,7 @@ func GetDomainWithGateway(gID string) (domain model.Domain) {
func GetTokenAsset(token string) (tokenUser model.TokenUser) { func GetTokenAsset(token string) (tokenUser model.TokenUser) {
Url := fmt.Sprintf(TokenAssetURL, token) Url := fmt.Sprintf(TokenAssetURL, token)
err := authClient.Get(Url, &tokenUser) _, err := authClient.Get(Url, &tokenUser)
if err != nil { if err != nil {
logger.Error("Get Token Asset info failed: ", err) logger.Error("Get Token Asset info failed: ", err)
} }
......
package service
import (
"sync"
"github.com/jumpserver/koko/pkg/model"
)
type assetsCacheContainer struct {
mapData map[string]model.AssetList
mapETag map[string]string
mu *sync.RWMutex
}
func (c *assetsCacheContainer) Get(key string) (model.AssetList, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapData[key]
return value, ok
}
func (c *assetsCacheContainer) GetETag(key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapETag[key]
return value, ok
}
func (c *assetsCacheContainer) SetValue(key string, value model.AssetList) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapData[key] = value
}
func (c *assetsCacheContainer) SetETag(key string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapETag[key] = value
}
type nodesCacheContainer struct {
mapData map[string]model.NodeList
mapETag map[string]string
mu *sync.RWMutex
}
func (c *nodesCacheContainer) Get(key string) (model.NodeList, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapData[key]
return value, ok
}
func (c *nodesCacheContainer) GetETag(key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapETag[key]
return value, ok
}
func (c *nodesCacheContainer) SetValue(key string, value model.NodeList) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapData[key] = value
}
func (c *nodesCacheContainer) SetETag(key string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapETag[key] = value
}
...@@ -66,7 +66,7 @@ func validateAccessAuth() { ...@@ -66,7 +66,7 @@ func validateAccessAuth() {
func MustLoadServerConfigOnce() { func MustLoadServerConfigOnce() {
var data map[string]interface{} var data map[string]interface{}
err := authClient.Get(TerminalConfigURL, &data) _, err := authClient.Get(TerminalConfigURL, &data)
if err != nil { if err != nil {
logger.Error("Load config from server error: ", err) logger.Error("Load config from server error: ", err)
return return
...@@ -86,7 +86,7 @@ func MustLoadServerConfigOnce() { ...@@ -86,7 +86,7 @@ func MustLoadServerConfigOnce() {
func LoadConfigFromServer() (err error) { func LoadConfigFromServer() (err error) {
conf := config.GetConf() conf := config.GetConf()
err = authClient.Get(TerminalConfigURL, conf) _, err = authClient.Get(TerminalConfigURL, conf)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -2,37 +2,97 @@ package service ...@@ -2,37 +2,97 @@ package service
import ( import (
"fmt" "fmt"
"sync"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
) )
var userAssetsCached = assetsCacheContainer{
mapData: make(map[string]model.AssetList),
mapETag: make(map[string]string),
mu: new(sync.RWMutex),
}
var userNodesCached = nodesCacheContainer{
mapData: make(map[string]model.NodeList),
mapETag: make(map[string]string),
mu: new(sync.RWMutex),
}
func GetUserAssetsFromCache(userID string) (assets model.AssetList, ok bool) {
assets, ok = userAssetsCached.Get(userID)
return
}
func GetUserAssets(userID, cachePolicy, assetId string) (assets model.AssetList) { func GetUserAssets(userID, cachePolicy, assetId string) (assets model.AssetList) {
if cachePolicy == "" { if cachePolicy == "" {
cachePolicy = "1" cachePolicy = "1"
} }
headers := make(map[string]string)
if etag, ok := userAssetsCached.GetETag(userID); ok && cachePolicy == "1" && assetId == "" {
headers["If-None-Match"] = etag
}
payload := map[string]string{"cache_policy": cachePolicy} payload := map[string]string{"cache_policy": cachePolicy}
if assetId != "" { if assetId != "" {
payload["id"] = assetId payload["id"] = assetId
} }
Url := fmt.Sprintf(UserAssetsURL, userID) Url := fmt.Sprintf(UserAssetsURL, userID)
err := authClient.Get(Url, &assets, payload) resp, err := authClient.Get(Url, &assets, payload, headers)
if err != nil { if err != nil {
logger.Error("Get user assets error: ", err) logger.Error("Get user assets error: ", err)
return
}
if resp.StatusCode == 200 && resp.Header.Get("ETag") != "" {
newETag := resp.Header.Get("ETag")
userAssetsCached.SetValue(userID, assets)
userAssetsCached.SetETag(userID, newETag)
} else if resp.StatusCode == 304 {
assets, _ = userAssetsCached.Get(userID)
} }
return return
} }
func GetUserNodesFromCache(userID string) (nodes model.NodeList, ok bool) {
nodes, ok = userNodesCached.Get(userID)
return
}
func GetUserNodes(userID, cachePolicy string) (nodes model.NodeList) { func GetUserNodes(userID, cachePolicy string) (nodes model.NodeList) {
if cachePolicy == "" { if cachePolicy == "" {
cachePolicy = "1" cachePolicy = "1"
} }
headers := make(map[string]string)
if etag, ok := userNodesCached.GetETag(userID); ok && cachePolicy == "1" {
headers["If-None-Match"] = etag
}
payload := map[string]string{"cache_policy": cachePolicy} payload := map[string]string{"cache_policy": cachePolicy}
Url := fmt.Sprintf(UserNodesAssetsURL, userID) Url := fmt.Sprintf(UserNodesListURL, userID)
err := authClient.Get(Url, &nodes, payload) resp, err := authClient.Get(Url, &nodes, payload, headers)
if err != nil { if err != nil {
logger.Error("Get user nodes error: ", err) logger.Error("Get user nodes error: ", err)
} }
if resp.StatusCode == 200 && resp.Header.Get("ETag") != "" {
userNodesCached.SetValue(userID, nodes)
userNodesCached.SetETag(userID, resp.Header.Get("ETag"))
} else if resp.StatusCode == 304 {
nodes, _ = userNodesCached.Get(userID)
}
return
}
func GetUserNodeAssets(userID, nodeID, cachePolicy string) (assets model.AssetList) {
if cachePolicy == "" {
cachePolicy = "1"
}
payload := map[string]string{"cache_policy": cachePolicy, "all": "1"}
Url := fmt.Sprintf(UserNodeAssetsListURL, userID, nodeID)
_, err := authClient.Get(Url, &assets, payload)
if err != nil {
logger.Error("Get user node assets error: ", err)
return
}
return return
} }
...@@ -48,7 +108,7 @@ func ValidateUserAssetPermission(userID, assetID, systemUserID, action string) b ...@@ -48,7 +108,7 @@ func ValidateUserAssetPermission(userID, assetID, systemUserID, action string) b
var res struct { var res struct {
Msg bool `json:"msg"` Msg bool `json:"msg"`
} }
err := authClient.Get(Url, &res, payload) _, err := authClient.Get(Url, &res, payload)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
......
...@@ -13,7 +13,7 @@ func RegisterTerminal(name, token, comment string) (res model.Terminal) { ...@@ -13,7 +13,7 @@ func RegisterTerminal(name, token, comment string) (res model.Terminal) {
} }
client.Headers["Authorization"] = fmt.Sprintf("BootstrapToken %s", token) client.Headers["Authorization"] = fmt.Sprintf("BootstrapToken %s", token)
data := map[string]string{"name": name, "comment": comment} data := map[string]string{"name": name, "comment": comment}
err := client.Post(TerminalRegisterURL, data, &res) _, err := client.Post(TerminalRegisterURL, data, &res)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
...@@ -25,7 +25,7 @@ func TerminalHeartBeat(sIds []string) (res []model.TerminalTask) { ...@@ -25,7 +25,7 @@ func TerminalHeartBeat(sIds []string) (res []model.TerminalTask) {
data := map[string][]string{ data := map[string][]string{
"sessions": sIds, "sessions": sIds,
} }
err := authClient.Post(TerminalHeartBeatURL, data, &res) _, err := authClient.Post(TerminalHeartBeatURL, data, &res)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
...@@ -34,7 +34,7 @@ func TerminalHeartBeat(sIds []string) (res []model.TerminalTask) { ...@@ -34,7 +34,7 @@ func TerminalHeartBeat(sIds []string) (res []model.TerminalTask) {
func CreateSession(data map[string]interface{}) bool { func CreateSession(data map[string]interface{}) bool {
var res map[string]interface{} var res map[string]interface{}
err := authClient.Post(SessionListURL, data, &res) _, err := authClient.Post(SessionListURL, data, &res)
if err == nil { if err == nil {
return true return true
} }
...@@ -51,7 +51,7 @@ func FinishSession(data map[string]interface{}) { ...@@ -51,7 +51,7 @@ func FinishSession(data map[string]interface{}) {
"date_end": data["date_end"], "date_end": data["date_end"],
} }
Url := fmt.Sprintf(SessionDetailURL, sid) Url := fmt.Sprintf(SessionDetailURL, sid)
err := authClient.Patch(Url, payload, &res) _, err := authClient.Patch(Url, payload, &res)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
...@@ -63,7 +63,7 @@ func FinishReply(sid string) bool { ...@@ -63,7 +63,7 @@ func FinishReply(sid string) bool {
var res map[string]interface{} var res map[string]interface{}
data := map[string]bool{"has_replay": true} data := map[string]bool{"has_replay": true}
Url := fmt.Sprintf(SessionDetailURL, sid) Url := fmt.Sprintf(SessionDetailURL, sid)
err := authClient.Patch(Url, data, &res) _, err := authClient.Patch(Url, data, &res)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return false return false
...@@ -75,7 +75,7 @@ func FinishTask(tid string) bool { ...@@ -75,7 +75,7 @@ func FinishTask(tid string) bool {
var res map[string]interface{} var res map[string]interface{}
data := map[string]bool{"is_finished": true} data := map[string]bool{"is_finished": true}
Url := fmt.Sprintf(FinishTaskURL, tid) Url := fmt.Sprintf(FinishTaskURL, tid)
err := authClient.Patch(Url, data, &res) _, err := authClient.Patch(Url, data, &res)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return false return false
...@@ -94,7 +94,7 @@ func PushSessionReplay(sessionID, gZipFile string) (err error) { ...@@ -94,7 +94,7 @@ func PushSessionReplay(sessionID, gZipFile string) (err error) {
} }
func PushSessionCommand(commands []*model.Command) (err error) { func PushSessionCommand(commands []*model.Command) (err error) {
err = authClient.Post(SessionCommandURL, commands, nil) _, err = authClient.Post(SessionCommandURL, commands, nil)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
...@@ -102,7 +102,7 @@ func PushSessionCommand(commands []*model.Command) (err error) { ...@@ -102,7 +102,7 @@ func PushSessionCommand(commands []*model.Command) (err error) {
} }
func PushFTPLog(data *model.FTPLog) (err error) { func PushFTPLog(data *model.FTPLog) (err error) {
err = authClient.Post(FTPLogListURL, data, nil) _, err = authClient.Post(FTPLogListURL, data, nil)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
......
...@@ -26,7 +26,9 @@ const ( ...@@ -26,7 +26,9 @@ const (
FTPLogListURL = "/api/audits/v1/ftp-log/" // 上传 ftp日志 FTPLogListURL = "/api/audits/v1/ftp-log/" // 上传 ftp日志
UserAssetsURL = "/api/perms/v1/users/%s/assets/" //获取用户授权的所有资产 UserAssetsURL = "/api/perms/v1/users/%s/assets/" //获取用户授权的所有资产
UserNodesAssetsURL = "/api/perms/v1/users/%s/nodes-assets/" // 获取用户授权的所有节点信息 节点分组 UserNodesAssetsURL = "/api/perms/v1/users/%s/nodes-assets/" // 获取用户授权的所有节点信息 节点分组
UserNodesListURL = "/api/perms/v1/users/%s/nodes/"
UserNodeAssetsListURL = "/api/perms/v1/users/%s/nodes/%s/assets/"
ValidateUserAssetPermissionURL = "/api/perms/v1/asset-permissions/user/validate/" //0不使用缓存 1 使用缓存 2 刷新缓存 ValidateUserAssetPermissionURL = "/api/perms/v1/asset-permissions/user/validate/" //0不使用缓存 1 使用缓存 2 刷新缓存
) )
...@@ -23,13 +23,13 @@ func Authenticate(username, password, publicKey, remoteAddr, loginType string) ( ...@@ -23,13 +23,13 @@ func Authenticate(username, password, publicKey, remoteAddr, loginType string) (
"remote_addr": remoteAddr, "remote_addr": remoteAddr,
"login_type": loginType, "login_type": loginType,
} }
err = client.Post(UserAuthURL, data, &resp) _, err = client.Post(UserAuthURL, data, &resp)
return return
} }
func GetUserDetail(userID string) (user *model.User) { func GetUserDetail(userID string) (user *model.User) {
Url := fmt.Sprintf(UserDetailURL, userID) Url := fmt.Sprintf(UserDetailURL, userID)
err := authClient.Get(Url, &user) _, err := authClient.Get(Url, &user)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
...@@ -37,14 +37,14 @@ func GetUserDetail(userID string) (user *model.User) { ...@@ -37,14 +37,14 @@ func GetUserDetail(userID string) (user *model.User) {
} }
func GetProfile() (user *model.User, err error) { func GetProfile() (user *model.User, err error) {
err = authClient.Get(UserProfileURL, &user) _, err = authClient.Get(UserProfileURL, &user)
return user, err return user, err
} }
func GetUserByUsername(username string) (user *model.User, err error) { func GetUserByUsername(username string) (user *model.User, err error) {
var users []*model.User var users []*model.User
payload := map[string]string{"username": username} payload := map[string]string{"username": username}
err = authClient.Get(UserListURL, &users, payload) _, err = authClient.Get(UserListURL, &users, payload)
if err != nil { if err != nil {
return return
} }
...@@ -61,7 +61,7 @@ func CheckUserOTP(seed, code string) (resp *AuthResp, err error) { ...@@ -61,7 +61,7 @@ func CheckUserOTP(seed, code string) (resp *AuthResp, err error) {
"seed": seed, "seed": seed,
"otp_code": code, "otp_code": code,
} }
err = client.Post(UserAuthOTPURL, data, &resp) _, err = client.Post(UserAuthOTPURL, data, &resp)
if err != nil { if err != nil {
return return
} }
...@@ -72,6 +72,6 @@ func CheckUserCookie(sessionID, csrfToken string) (user *model.User, err error) ...@@ -72,6 +72,6 @@ func CheckUserCookie(sessionID, csrfToken string) (user *model.User, err error)
cli := newClient() cli := newClient()
cli.SetCookie("csrftoken", csrfToken) cli.SetCookie("csrftoken", csrfToken)
cli.SetCookie("sessionid", sessionID) cli.SetCookie("sessionid", sessionID)
err = cli.Get(UserProfileURL, &user) _, err = cli.Get(UserProfileURL, &user)
return return
} }
...@@ -65,7 +65,7 @@ func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) { ...@@ -65,7 +65,7 @@ func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
if sc.PrivateKey != "" { if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKeyWithPassphrase([]byte(sc.PrivateKey), []byte(sc.Password)); err != nil { if signer, err := gossh.ParsePrivateKeyWithPassphrase([]byte(sc.PrivateKey), []byte(sc.Password)); err != nil {
err = fmt.Errorf("parse private key error: %s", err) err = fmt.Errorf("parse private key error: %s", err)
return config, err logger.Error(err.Error())
} else { } else {
authMethods = append(authMethods, gossh.PublicKeys(signer)) authMethods = append(authMethods, gossh.PublicKeys(signer))
} }
...@@ -201,7 +201,7 @@ func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model. ...@@ -201,7 +201,7 @@ func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.
if !ok { if !ok {
return return
} }
if systemUser.Username == ""{ if systemUser.Username == "" {
systemUser.Username = client.Username systemUser.Username = client.Username
} }
var u = user.Username var u = user.Username
......
...@@ -186,12 +186,12 @@ func (u *UserSftp) RemoveDirectory(path string) error { ...@@ -186,12 +186,12 @@ func (u *UserSftp) RemoveDirectory(path string) error {
} }
err := u.removeDirectoryAll(conn.client, realPath) err := u.removeDirectoryAll(conn.client, realPath)
filename := realPath filename := realPath
isSucess := false isSuccess := false
operate := model.OperateRemoveDir operate := model.OperateRemoveDir
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return err return err
} }
...@@ -247,12 +247,12 @@ func (u *UserSftp) Remove(path string) error { ...@@ -247,12 +247,12 @@ func (u *UserSftp) Remove(path string) error {
} }
err := conn.client.Remove(realPath) err := conn.client.Remove(realPath)
filename := realPath filename := realPath
isSucess := false isSuccess := false
operate := model.OperateDelete operate := model.OperateDelete
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return err return err
} }
...@@ -283,12 +283,12 @@ func (u *UserSftp) MkdirAll(path string) error { ...@@ -283,12 +283,12 @@ func (u *UserSftp) MkdirAll(path string) error {
err := conn.client.MkdirAll(realPath) err := conn.client.MkdirAll(realPath)
filename := realPath filename := realPath
isSucess := false isSuccess := false
operate := model.OperateMkdir operate := model.OperateMkdir
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return err return err
} }
...@@ -320,12 +320,12 @@ func (u *UserSftp) Rename(oldNamePath, newNamePath string) error { ...@@ -320,12 +320,12 @@ func (u *UserSftp) Rename(oldNamePath, newNamePath string) error {
err := conn1.client.Rename(oldRealPath, newRealPath) err := conn1.client.Rename(oldRealPath, newRealPath)
filename := fmt.Sprintf("%s=>%s", oldRealPath, newRealPath) filename := fmt.Sprintf("%s=>%s", oldRealPath, newRealPath)
isSucess := false isSuccess := false
operate := model.OperateRename operate := model.OperateRename
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return err return err
} }
...@@ -357,12 +357,12 @@ func (u *UserSftp) Symlink(oldNamePath, newNamePath string) error { ...@@ -357,12 +357,12 @@ func (u *UserSftp) Symlink(oldNamePath, newNamePath string) error {
err := conn1.client.Symlink(oldRealPath, newRealPath) err := conn1.client.Symlink(oldRealPath, newRealPath)
filename := fmt.Sprintf("%s=>%s", oldRealPath, newRealPath) filename := fmt.Sprintf("%s=>%s", oldRealPath, newRealPath)
isSucess := false isSuccess := false
operate := model.OperateSymlink operate := model.OperateSymlink
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return err return err
} }
...@@ -393,12 +393,12 @@ func (u *UserSftp) Create(path string) (*sftp.File, error) { ...@@ -393,12 +393,12 @@ func (u *UserSftp) Create(path string) (*sftp.File, error) {
} }
sf, err := conn.client.Create(realPath) sf, err := conn.client.Create(realPath)
filename := realPath filename := realPath
isSucess := false isSuccess := false
operate := model.OperateUpload operate := model.OperateUpload
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return sf, err return sf, err
} }
...@@ -427,12 +427,12 @@ func (u *UserSftp) Open(path string) (*sftp.File, error) { ...@@ -427,12 +427,12 @@ func (u *UserSftp) Open(path string) (*sftp.File, error) {
} }
sf, err := conn.client.Open(realPath) sf, err := conn.client.Open(realPath)
filename := realPath filename := realPath
isSucess := false isSuccess := false
operate := model.OperateDownaload operate := model.OperateDownaload
if err == nil { if err == nil {
isSucess = true isSuccess = true
} }
u.CreateFTPLog(host.asset, su, operate, filename, isSucess) u.CreateFTPLog(host.asset, su, operate, filename, isSuccess)
return sf, err return sf, err
} }
...@@ -593,6 +593,9 @@ func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) ...@@ -593,6 +593,9 @@ func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser)
func (u *UserSftp) Close() { func (u *UserSftp) Close() {
for _, client := range u.sftpClients { for _, client := range u.sftpClients {
if client == nil {
continue
}
client.Close() client.Close()
} }
close(u.LogChan) close(u.LogChan)
...@@ -645,6 +648,9 @@ type SftpConn struct { ...@@ -645,6 +648,9 @@ type SftpConn struct {
} }
func (s *SftpConn) Close() { func (s *SftpConn) Close() {
if s.client == nil {
return
}
_ = s.client.Close() _ = s.client.Close()
RecycleClient(s.conn) RecycleClient(s.conn)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment