Unverified Commit a125e737 authored by Eric_Lee's avatar Eric_Lee Committed by GitHub

[update] update new core API (#76)

parent 3a79c635
...@@ -14226,7 +14226,8 @@ $.fn.elfindercwd = function(fm, options) { ...@@ -14226,7 +14226,8 @@ $.fn.elfindercwd = function(fm, options) {
selectAll = function() { selectAll = function() {
var phash = fm.cwd().hash; var phash = fm.cwd().hash;
// fix select all display; remove cwd disable status
cwd.find('[id]:not(.'+clSelected+'):not(.elfinder-cwd-parent)').removeClass(clDisabled);
selectCheckbox && selectAllCheckbox.find('input').prop('checked', true); selectCheckbox && selectAllCheckbox.find('input').prop('checked', true);
fm.lazy(function() { fm.lazy(function() {
var files; var files;
......
...@@ -45,6 +45,7 @@ type Config struct { ...@@ -45,6 +45,7 @@ type Config struct {
Language string `yaml:"LANG"` Language string `yaml:"LANG"`
LanguageCode string `yaml:"LANGUAGE_CODE"` // Abandon LanguageCode string `yaml:"LANGUAGE_CODE"` // Abandon
UploadFailedReplay bool `yaml:"UPLOAD_FAILED_REPLAY_ON_START"` UploadFailedReplay bool `yaml:"UPLOAD_FAILED_REPLAY_ON_START"`
LoadPolicy string `yaml:"LOAD_POLICY"` // all, pagination
} }
func (c *Config) EnsureConfigValid() { func (c *Config) EnsureConfigValid() {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/config"
"github.com/jumpserver/koko/pkg/i18n" "github.com/jumpserver/koko/pkg/i18n"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
"github.com/jumpserver/koko/pkg/service"
"github.com/jumpserver/koko/pkg/utils" "github.com/jumpserver/koko/pkg/utils"
) )
...@@ -130,8 +131,8 @@ func (p *AssetPagination) Start() []model.Asset { ...@@ -130,8 +131,8 @@ func (p *AssetPagination) Start() []model.Asset {
} }
func (p *AssetPagination) displayPageAssets() { func (p *AssetPagination) displayPageAssets() {
Labels := []string{i18n.T("ID"), i18n.T("hostname"), i18n.T("IP"), i18n.T("systemUsers"), i18n.T("comment")} Labels := []string{i18n.T("ID"), i18n.T("hostname"), i18n.T("IP"), i18n.T("comment")}
fields := []string{"ID", "hostname", "IP", "systemUsers", "comment"} fields := []string{"ID", "hostname", "IP", "comment"}
data := make([]map[string]string, len(p.currentData)) data := make([]map[string]string, len(p.currentData))
for i, j := range p.currentData { for i, j := range p.currentData {
row := make(map[string]string) row := make(map[string]string)
...@@ -139,12 +140,6 @@ func (p *AssetPagination) displayPageAssets() { ...@@ -139,12 +140,6 @@ func (p *AssetPagination) displayPageAssets() {
row["hostname"] = j.Hostname row["hostname"] = j.Hostname
row["IP"] = j.IP row["IP"] = j.IP
systemUser := selectHighestPrioritySystemUsers(j.SystemUsers)
names := make([]string, len(systemUser))
for i := range systemUser {
names[i] = systemUser[i].Name
}
row["systemUsers"] = strings.Join(names, ",")
comments := make([]string, 0) comments := make([]string, 0)
for _, item := range strings.Split(strings.TrimSpace(j.Comment), "\r\n") { for _, item := range strings.Split(strings.TrimSpace(j.Comment), "\r\n") {
if strings.TrimSpace(item) == "" { if strings.TrimSpace(item) == "" {
...@@ -164,11 +159,10 @@ func (p *AssetPagination) displayPageAssets() { ...@@ -164,11 +159,10 @@ func (p *AssetPagination) displayPageAssets() {
Fields: fields, Fields: fields,
Labels: Labels, Labels: Labels,
FieldsSize: map[string][3]int{ FieldsSize: map[string][3]int{
"ID": {0, 0, 4}, "ID": {0, 0, 5},
"hostname": {0, 8, 0}, "hostname": {0, 8, 0},
"IP": {0, 15, 40}, "IP": {0, 15, 40},
"systemUsers": {0, 12, 0}, "comment": {0, 0, 0},
"comment": {0, 0, 0},
}, },
Data: data, Data: data,
TotalSize: w, TotalSize: w,
...@@ -191,3 +185,209 @@ func (p *AssetPagination) displayTipsInfo() { ...@@ -191,3 +185,209 @@ func (p *AssetPagination) displayTipsInfo() {
} }
} }
func NewUserPagination(term *utils.Terminal, uid, search string, proxy bool) *UserAssetPagination {
return &UserAssetPagination{
UserID: uid,
offset: 0,
limit: 0,
search: search,
term: term,
proxy: proxy,
Data: model.AssetsPaginationResponse{},
}
}
type UserAssetPagination struct {
UserID string
offset int
limit int
search string
term *utils.Terminal
proxy bool
Data model.AssetsPaginationResponse
}
func (p *UserAssetPagination) Start() []model.Asset {
p.term.SetPrompt(": ")
defer p.term.SetPrompt("Opt> ")
for {
p.retrieveData()
if p.proxy && p.Data.Total == 1 {
return p.Data.Data
}
// 无上下页,则退出循环
if p.Data.NextURL == "" && p.Data.PreviousURL == "" {
p.displayPageAssets()
return p.Data.Data
}
inLoop:
p.displayPageAssets()
p.displayTipsInfo()
line, err := p.term.ReadLine()
if err != nil {
return p.Data.Data
}
line = strings.TrimSpace(line)
switch len(line) {
case 0, 1:
switch strings.ToLower(line) {
case "p":
if p.Data.PreviousURL == "" {
continue
}
p.offset -= p.limit
case "", "n":
if p.Data.NextURL == "" {
continue
}
p.offset += p.limit
case "b", "q":
return []model.Asset{}
default:
if indexID, err := strconv.Atoi(line); err == nil {
if indexID > 0 && indexID <= len(p.Data.Data) {
return []model.Asset{p.Data.Data[indexID-1]}
}
}
goto inLoop
}
default:
if indexID, err := strconv.Atoi(line); err == nil {
if indexID > 0 && indexID <= len(p.Data.Data) {
return []model.Asset{p.Data.Data[indexID-1]}
}
}
goto inLoop
}
}
}
func (p *UserAssetPagination) displayPageAssets() {
if len(p.Data.Data) == 0 {
_, _ = p.term.Write([]byte(i18n.T("No Assets")))
_, _ = p.term.Write([]byte("\n\r"))
return
}
Labels := []string{i18n.T("ID"), i18n.T("hostname"), i18n.T("IP"), i18n.T("comment")}
fields := []string{"ID", "hostname", "IP", "comment"}
data := make([]map[string]string, len(p.Data.Data))
for i, j := range p.Data.Data {
row := make(map[string]string)
row["ID"] = strconv.Itoa(i + 1)
row["hostname"] = j.Hostname
row["IP"] = j.IP
comments := make([]string, 0)
for _, item := range strings.Split(strings.TrimSpace(j.Comment), "\r\n") {
if strings.TrimSpace(item) == "" {
continue
}
comments = append(comments, strings.ReplaceAll(strings.TrimSpace(item), " ", ","))
}
row["comment"] = strings.Join(comments, "|")
data[i] = row
}
w, _ := p.term.GetSize()
var pageSize int
var totalPage int
var currentPage int
var totalCount int
var currentOffset int
currentOffset = p.offset + len(p.Data.Data)
switch p.limit {
case 0:
pageSize = len(p.Data.Data)
totalCount = pageSize
totalPage = 1
currentPage = 1
default:
pageSize = p.limit
totalCount = p.Data.Total
switch totalCount % pageSize {
case 0:
totalPage = totalCount / pageSize
default:
totalPage = (totalCount / pageSize) + 1
}
switch currentOffset % pageSize {
case 0:
currentPage = currentOffset / pageSize
default:
currentPage = (currentOffset / pageSize) + 1
}
}
caption := fmt.Sprintf(i18n.T("Page: %d, Count: %d, Total Page: %d, Total Count: %d"),
currentPage, pageSize, totalPage, totalCount,
)
caption = utils.WrapperString(caption, utils.Green)
table := common.WrapperTable{
Fields: fields,
Labels: Labels,
FieldsSize: map[string][3]int{
"ID": {0, 0, 5},
"hostname": {0, 8, 0},
"IP": {0, 15, 40},
"comment": {0, 0, 0},
},
Data: data,
TotalSize: w,
Caption: caption,
TruncPolicy: common.TruncMiddle,
}
table.Initial()
_, _ = p.term.Write([]byte(utils.CharClear))
_, _ = p.term.Write([]byte(table.Display()))
}
func (p *UserAssetPagination) displayTipsInfo() {
tips := []string{
i18n.T("\nTips: Enter the asset ID and log directly into the asset.\n"),
i18n.T("\nPage up: P/p Page down: Enter|N/n BACK: b.\n"),
}
for _, tip := range tips {
_, _ = p.term.Write([]byte(tip))
}
}
func (p *UserAssetPagination) retrieveData() {
p.limit = GetPageSize(p.term)
if p.limit == 0 || p.offset < 0 || p.limit >= p.Data.Total {
p.offset = 0
}
p.Data = service.GetUserAssets(p.UserID, p.search, p.limit, p.offset)
}
func GetPageSize(term *utils.Terminal) int {
var (
pageSize int
minHeight = 8 // 分页显示的最小高度
)
_, height := term.GetSize()
conf := config.GetConf()
switch conf.AssetListPageSize {
case "auto":
pageSize = height - minHeight
case "all":
return 0
default:
if value, err := strconv.Atoi(conf.AssetListPageSize); err == nil {
pageSize = value
} else {
pageSize = height - minHeight
}
}
if pageSize <= 0 {
pageSize = 1
}
return pageSize
}
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"io" "io"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
...@@ -40,12 +39,9 @@ func newInteractiveHandler(sess ssh.Session, user *model.User) *interactiveHandl ...@@ -40,12 +39,9 @@ func newInteractiveHandler(sess ssh.Session, user *model.User) *interactiveHandl
wrapperSess := NewWrapperSession(sess) wrapperSess := NewWrapperSession(sess)
term := utils.NewTerminal(wrapperSess, "Opt> ") term := utils.NewTerminal(wrapperSess, "Opt> ")
handler := &interactiveHandler{ handler := &interactiveHandler{
sess: wrapperSess, sess: wrapperSess,
user: user, user: user,
term: term, term: term,
mu: new(sync.RWMutex),
nodeDataLoaded: make(chan struct{}),
assetDataLoaded: make(chan struct{}),
} }
handler.Initial() handler.Initial()
return handler return handler
...@@ -59,42 +55,33 @@ type interactiveHandler struct { ...@@ -59,42 +55,33 @@ type interactiveHandler struct {
assetSelect *model.Asset assetSelect *model.Asset
systemUserSelect *model.SystemUser systemUserSelect *model.SystemUser
assets model.AssetList
searchResult model.AssetList
nodes model.NodeList nodes model.NodeList
mu *sync.RWMutex searchResult []model.Asset
nodeDataLoaded chan struct{}
assetDataLoaded chan struct{} allAssets []model.Asset
search string
offset int
limit int
loadDataDone chan struct{}
loadPolicy string
} }
func (h *interactiveHandler) Initial() { func (h *interactiveHandler) Initial() {
h.loadPolicy = config.GetConf().LoadPolicy
h.displayBanner() h.displayBanner()
h.loadAssetsFromCache()
h.searchResult = make([]model.Asset, 0)
h.winWatchChan = make(chan bool) h.winWatchChan = make(chan bool)
h.loadDataDone = make(chan struct{})
go h.firstLoadData()
} }
func (h *interactiveHandler) loadAssetsFromCache() { func (h *interactiveHandler) firstLoadData() {
if assets, ok := service.GetUserAssetsFromCache(h.user.ID); ok {
h.assets = assets
close(h.assetDataLoaded)
} else {
h.assets = make([]model.Asset, 0)
}
go h.firstLoadAssetAndNodes()
}
func (h *interactiveHandler) firstLoadAssetAndNodes() {
h.loadUserAssets("1")
h.loadUserNodes("1") h.loadUserNodes("1")
logger.Debug("First load assets and nodes done") switch h.loadPolicy {
close(h.nodeDataLoaded) case "all":
select { h.loadAllAssets()
case <-h.assetDataLoaded:
return
default:
close(h.assetDataLoaded)
} }
close(h.loadDataDone)
} }
func (h *interactiveHandler) displayBanner() { func (h *interactiveHandler) displayBanner() {
...@@ -151,16 +138,14 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) { ...@@ -151,16 +138,14 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) {
break break
} }
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
<-h.assetDataLoaded
switch len(line) { switch len(line) {
case 0, 1: case 0, 1:
switch strings.ToLower(line) { switch strings.ToLower(line) {
case "", "p": case "", "p":
h.mu.RLock() // 展示所有的资产
h.displayAssets(h.assets) h.displayAllAssets()
h.mu.RUnlock()
case "g": case "g":
<-h.nodeDataLoaded <-h.loadDataDone
h.displayNodes(h.nodes) h.displayNodes(h.nodes)
case "h": case "h":
h.displayBanner() h.displayBanner()
...@@ -170,8 +155,7 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) { ...@@ -170,8 +155,7 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) {
logger.Info("exit session") logger.Info("exit session")
return return
default: default:
assets := h.searchAsset(line) h.searchAssetOrProxy(line)
h.displayAssetsOrProxy(assets)
} }
default: default:
switch { switch {
...@@ -180,29 +164,36 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) { ...@@ -180,29 +164,36 @@ func (h *interactiveHandler) Dispatch(ctx cctx.Context) {
return return
case strings.Index(line, "/") == 0: case strings.Index(line, "/") == 0:
searchWord := strings.TrimSpace(line[1:]) searchWord := strings.TrimSpace(line[1:])
assets := h.searchAsset(searchWord) h.searchAsset(searchWord)
h.displayAssets(assets)
case strings.Index(line, "g") == 0: case strings.Index(line, "g") == 0:
searchWord := strings.TrimSpace(strings.TrimPrefix(line, "g")) searchWord := strings.TrimSpace(strings.TrimPrefix(line, "g"))
if num, err := strconv.Atoi(searchWord); err == nil { if num, err := strconv.Atoi(searchWord); err == nil {
if num >= 0 { if num >= 0 {
<-h.nodeDataLoaded
assets := h.searchNodeAssets(num) assets := h.searchNodeAssets(num)
h.displayAssets(assets) h.displayAssets(assets)
continue continue
} }
} }
assets := h.searchAsset(line) h.searchAssetOrProxy(line)
h.displayAssetsOrProxy(assets)
default: default:
assets := h.searchAsset(line) h.searchAssetOrProxy(line)
h.displayAssetsOrProxy(assets)
} }
} }
} }
} }
func (h *interactiveHandler) displayAllAssets() {
switch h.loadPolicy {
case "all":
<-h.loadDataDone
h.displayAssets(h.allAssets)
default:
h.searchAsset("")
}
}
func (h *interactiveHandler) chooseSystemUser(systemUsers []model.SystemUser) model.SystemUser { func (h *interactiveHandler) chooseSystemUser(systemUsers []model.SystemUser) model.SystemUser {
length := len(systemUsers) length := len(systemUsers)
switch length { switch length {
...@@ -244,19 +235,6 @@ func (h *interactiveHandler) chooseSystemUser(systemUsers []model.SystemUser) mo ...@@ -244,19 +235,6 @@ func (h *interactiveHandler) chooseSystemUser(systemUsers []model.SystemUser) mo
return displaySystemUsers[0] return displaySystemUsers[0]
} }
// 当资产的数量为1的时候,就进行代理转化
func (h *interactiveHandler) displayAssetsOrProxy(assets []model.Asset) {
if len(assets) == 1 {
systemUser := h.chooseSystemUser(assets[0].SystemUsers)
h.assetSelect = &assets[0]
h.systemUserSelect = &systemUser
h.Proxy(context.TODO())
} else {
h.displayAssets(assets)
}
}
func (h *interactiveHandler) displayAssets(assets model.AssetList) { func (h *interactiveHandler) displayAssets(assets model.AssetList) {
if len(assets) == 0 { if len(assets) == 0 {
_, _ = io.WriteString(h.term, i18n.T("No Assets")+"\n\r") _, _ = io.WriteString(h.term, i18n.T("No Assets")+"\n\r")
...@@ -265,18 +243,16 @@ func (h *interactiveHandler) displayAssets(assets model.AssetList) { ...@@ -265,18 +243,16 @@ func (h *interactiveHandler) displayAssets(assets model.AssetList) {
pag := NewAssetPagination(h.term, sortedAssets) pag := NewAssetPagination(h.term, sortedAssets)
selectOneAssets := pag.Start() selectOneAssets := pag.Start()
if len(selectOneAssets) == 1 { if len(selectOneAssets) == 1 {
systemUser := h.chooseSystemUser(selectOneAssets[0].SystemUsers) systemUsers := service.GetUserAssetSystemUsers(h.user.ID, selectOneAssets[0].ID)
systemUser := h.chooseSystemUser(systemUsers)
h.assetSelect = &selectOneAssets[0] h.assetSelect = &selectOneAssets[0]
h.systemUserSelect = &systemUser h.systemUserSelect = &systemUser
h.Proxy(context.TODO()) h.Proxy(context.TODO())
} }
if pag.page.PageSize() >= pag.page.TotalCount() { if pag.page.PageSize() >= pag.page.TotalCount() {
h.searchResult = sortedAssets h.searchResult = sortedAssets
} else {
h.searchResult = h.searchResult[:0]
} }
} }
} }
func (h *interactiveHandler) displayNodes(nodes []model.Node) { func (h *interactiveHandler) displayNodes(nodes []model.Node) {
...@@ -294,7 +270,10 @@ func (h *interactiveHandler) displayNodes(nodes []model.Node) { ...@@ -294,7 +270,10 @@ func (h *interactiveHandler) displayNodes(nodes []model.Node) {
} }
func (h *interactiveHandler) refreshAssetsAndNodesData() { func (h *interactiveHandler) refreshAssetsAndNodesData() {
h.loadUserAssets("2") switch h.loadPolicy {
case "all":
h.loadAllAssets()
}
h.loadUserNodes("2") h.loadUserNodes("2")
_, err := io.WriteString(h.term, i18n.T("Refresh done")+"\n\r") _, err := io.WriteString(h.term, i18n.T("Refresh done")+"\n\r")
if err != nil { if err != nil {
...@@ -302,45 +281,75 @@ func (h *interactiveHandler) refreshAssetsAndNodesData() { ...@@ -302,45 +281,75 @@ func (h *interactiveHandler) refreshAssetsAndNodesData() {
} }
} }
func (h *interactiveHandler) loadUserAssets(cachePolicy string) {
assets := service.GetUserAssets(h.user.ID, cachePolicy, "")
h.mu.Lock()
h.assets = assets
h.mu.Unlock()
}
func (h *interactiveHandler) loadUserNodes(cachePolicy string) { func (h *interactiveHandler) loadUserNodes(cachePolicy string) {
h.mu.Lock()
h.nodes = service.GetUserNodes(h.user.ID, cachePolicy) h.nodes = service.GetUserNodes(h.user.ID, cachePolicy)
h.mu.Unlock()
} }
func (h *interactiveHandler) searchAsset(key string) (assets []model.Asset) { func (h *interactiveHandler) loadAllAssets() {
h.allAssets = service.GetUserAllAssets(h.user.ID)
}
func (h *interactiveHandler) searchAsset(key string) {
switch h.loadPolicy {
case "all":
<-h.loadDataDone
var searchData []model.Asset
switch len(h.searchResult) {
case 0:
searchData = h.allAssets
default:
searchData = h.searchResult
}
assets := searchFromLocalAssets(searchData, key)
h.displayAssets(assets)
default:
pag := NewUserPagination(h.term, h.user.ID, key, false)
h.searchResult = pag.Start()
}
}
func (h *interactiveHandler) searchAssetOrProxy(key string) {
if indexNum, err := strconv.Atoi(key); err == nil && len(h.searchResult) > 0 { if indexNum, err := strconv.Atoi(key); err == nil && len(h.searchResult) > 0 {
if indexNum > 0 && indexNum <= len(h.searchResult) { if indexNum > 0 && indexNum <= len(h.searchResult) {
assets = []model.Asset{h.searchResult[indexNum-1]} assetSelect := h.searchResult[indexNum-1]
systemUsers := service.GetUserAssetSystemUsers(h.user.ID, assetSelect.ID)
systemUserSelect := h.chooseSystemUser(systemUsers)
h.systemUserSelect = &systemUserSelect
h.assetSelect = &assetSelect
h.Proxy(context.Background())
return return
} }
} }
var searchData []model.Asset var assets []model.Asset
switch len(h.searchResult) { switch h.loadPolicy {
case 0: case "all":
h.mu.RLock() <-h.loadDataDone
searchData = h.assets var searchData []model.Asset
h.mu.RUnlock() switch len(h.searchResult) {
case 0:
searchData = h.allAssets
default:
searchData = h.searchResult
}
assets = searchFromLocalAssets(searchData, key)
if len(assets) != 1 {
h.displayAssets(assets)
return
}
default: default:
searchData = h.searchResult pag := NewUserPagination(h.term, h.user.ID, key, true)
assets = pag.Start()
} }
key = strings.ToLower(key) if len(assets) == 1 {
for _, assetValue := range searchData { systemUsers := service.GetUserAssetSystemUsers(h.user.ID, assets[0].ID)
contents := []string{strings.ToLower(assetValue.Hostname), systemUserSelect := h.chooseSystemUser(systemUsers)
strings.ToLower(assetValue.IP), strings.ToLower(assetValue.Comment)} h.systemUserSelect = &systemUserSelect
if isSubstring(contents, key) { h.assetSelect = &assets[0]
assets = append(assets, assetValue) h.Proxy(context.Background())
} return
} }
return assets h.searchResult = assets
} }
func (h *interactiveHandler) searchNodeAssets(num int) (assets model.AssetList) { func (h *interactiveHandler) searchNodeAssets(num int) (assets model.AssetList) {
...@@ -422,96 +431,15 @@ func selectHighestPrioritySystemUsers(systemUsers []model.SystemUser) []model.Sy ...@@ -422,96 +431,15 @@ func selectHighestPrioritySystemUsers(systemUsers []model.SystemUser) []model.Sy
return result return result
} }
//func (h *InteractiveHandler) JoinShareRoom(roomID string) { func searchFromLocalAssets(assets model.AssetList, key string) []model.Asset {
//sshConn := userhome.NewSSHConn(h.sess) displayAssets := make([]model.Asset, 0, len(assets))
//ctx, cancelFuc := context.WithCancel(h.sess.Context()) key = strings.ToLower(key)
// for _, assetValue := range assets {
//_, winCh, _ := h.sess.Pty() contents := []string{strings.ToLower(assetValue.Hostname),
//go func() { strings.ToLower(assetValue.IP), strings.ToLower(assetValue.Comment)}
// for { if isSubstring(contents, key) {
// select { displayAssets = append(displayAssets, assetValue)
// case <-ctx.Done(): }
// return }
// case win, ok := <-winCh: return displayAssets
// if !ok { }
// return
// }
// fmt.Println("join term change:", win)
// }
// }
//}()
//proxybak.Manager.JoinShareRoom(roomID, sshConn)
//logger.Info("exit room id:", roomID)
//cancelFuc()
//
//}
// /*
// 1. 创建SSHConn,符合core.Conn接口
// 2. 创建一个session Home
// 3. 创建一个NodeConn,及相关的channel 可以是MemoryChannel 或者是redisChannel
// 4. session Home 与 proxy channel 交换数据
// */
// ptyReq, winChan, _ := i.sess.Pty()
// sshConn := userhome.NewSSHConn(i.sess)
// serverAuth := transport.ServerAuth{
// SessionID: uuid.NewV4().String(),
// IP: asset.IP,
// port: asset.port,
// Username: systemUser.Username,
// password: systemUser.password,
// PublicKey: parsePrivateKey(systemUser.privateKey)}
//
// nodeConn, err := transport.NewNodeConn(i.sess.Context(), serverAuth, ptyReq, winChan)
// if err != nil {
// logger.Error(err)
// return err
// }
// defer func() {
// nodeConn.Close()
// data := map[string]interface{}{
// "id": nodeConn.SessionID,
// "user": i.user.Username,
// "asset": asset.Hostname,
// "org_id": asset.OrgID,
// "system_user": systemUser.Username,
// "login_from": "ST",
// "remote_addr": i.sess.RemoteAddr().String(),
// "is_finished": true,
// "date_start": nodeConn.StartTime.Format("2006-01-02 15:04:05 +0000"),
// "date_end": time.Now().UTC().Format("2006-01-02 15:04:05 +0000"),
// }
// postData, _ := json.Marshal(data)
// appService.FinishSession(nodeConn.SessionID, postData)
// appService.FinishReply(nodeConn.SessionID)
// }()
// data := map[string]interface{}{
// "id": nodeConn.SessionID,
// "user": i.user.Username,
// "asset": asset.Hostname,
// "org_id": asset.OrgID,
// "system_user": systemUser.Username,
// "login_from": "ST",
// "remote_addr": i.sess.RemoteAddr().String(),
// "is_finished": false,
// "date_start": nodeConn.StartTime.Format("2006-01-02 15:04:05 +0000"),
// "date_end": nil,
// }
// postData, err := json.Marshal(data)
//
// if !appService.CreateSession(postData) {
// return err
// }
//
// memChan := transport.NewMemoryAgent(nodeConn)
//
// Home := userhome.NewUserSessionHome(sshConn)
// logger.Info("session Home ID: ", Home.SessionID())
//
// err = proxy.Manager.session(i.sess.Context(), Home, memChan)
// if err != nil {
// logger.Error(err)
// }
// return err
//}
//
...@@ -41,7 +41,7 @@ func SftpHandler(sess ssh.Session) { ...@@ -41,7 +41,7 @@ func SftpHandler(sess ssh.Session) {
} }
func NewSFTPHandler(user *model.User, addr string) *sftpHandler { func NewSFTPHandler(user *model.User, addr string) *sftpHandler {
assets := service.GetUserAssets(user.ID, "1", "") assets := service.GetUserAllAssets(user.ID)
return &sftpHandler{srvconn.NewUserSFTP(user, addr, assets...)} return &sftpHandler{srvconn.NewUserSFTP(user, addr, assets...)}
} }
......
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
"github.com/jumpserver/koko/pkg/service" "github.com/jumpserver/koko/pkg/service"
"github.com/jumpserver/koko/pkg/srvconn" "github.com/jumpserver/koko/pkg/srvconn"
) )
func NewUserVolume(user *model.User, addr, hostId string) *UserVolume { func NewUserVolume(user *model.User, addr, hostId string) *UserVolume {
...@@ -24,9 +23,9 @@ func NewUserVolume(user *model.User, addr, hostId string) *UserVolume { ...@@ -24,9 +23,9 @@ func NewUserVolume(user *model.User, addr, hostId string) *UserVolume {
basePath := "/" basePath := "/"
switch hostId { switch hostId {
case "": case "":
assets = service.GetUserAssets(user.ID, "1", "") assets = service.GetUserAllAssets(user.ID)
default: default:
assets = service.GetUserAssets(user.ID, "1", hostId) assets = service.GetUserAssetByID(user.ID, hostId)
if len(assets) == 1 { if len(assets) == 1 {
homename = assets[0].Hostname homename = assets[0].Hostname
if assets[0].OrgID != "" { if assets[0].OrgID != "" {
...@@ -50,8 +49,8 @@ func NewUserVolume(user *model.User, addr, hostId string) *UserVolume { ...@@ -50,8 +49,8 @@ func NewUserVolume(user *model.User, addr, hostId string) *UserVolume {
type UserVolume struct { type UserVolume struct {
Uuid string Uuid string
*srvconn.UserSftp *srvconn.UserSftp
Homename string Homename string
basePath string basePath string
chunkFilesMap map[int]*sftp.File chunkFilesMap map[int]*sftp.File
lock *sync.Mutex lock *sync.Mutex
...@@ -142,13 +141,13 @@ func (u *UserVolume) GetFile(path string) (reader io.ReadCloser, err error) { ...@@ -142,13 +141,13 @@ func (u *UserVolume) GetFile(path string) (reader io.ReadCloser, err error) {
func (u *UserVolume) UploadFile(dirPath, uploadPath, filename string, reader io.Reader) (elfinder.FileDir, error) { func (u *UserVolume) UploadFile(dirPath, uploadPath, filename string, reader io.Reader) (elfinder.FileDir, error) {
var path string var path string
switch { switch {
case strings.Contains(uploadPath,filename): case strings.Contains(uploadPath, filename):
path = filepath.Join(dirPath, TrimPrefix(uploadPath)) path = filepath.Join(dirPath, TrimPrefix(uploadPath))
default: default:
path = filepath.Join(dirPath, filename) path = filepath.Join(dirPath, filename)
} }
logger.Debug("Volume upload file path: ", path," ", filename, " ",uploadPath) logger.Debug("Volume upload file path: ", path, " ", filename, " ", uploadPath)
var rest elfinder.FileDir var rest elfinder.FileDir
fd, err := u.UserSftp.Create(filepath.Join(u.basePath, path)) fd, err := u.UserSftp.Create(filepath.Join(u.basePath, path))
if err != nil { if err != nil {
...@@ -171,7 +170,7 @@ func (u *UserVolume) UploadChunk(cid int, dirPath, uploadPath, filename string, ...@@ -171,7 +170,7 @@ func (u *UserVolume) UploadChunk(cid int, dirPath, uploadPath, filename string,
u.lock.Unlock() u.lock.Unlock()
if !ok { if !ok {
switch { switch {
case strings.Contains(uploadPath,filename): case strings.Contains(uploadPath, filename):
path = filepath.Join(dirPath, TrimPrefix(uploadPath)) path = filepath.Join(dirPath, TrimPrefix(uploadPath))
case uploadPath != "": case uploadPath != "":
path = filepath.Join(dirPath, TrimPrefix(uploadPath), filename) path = filepath.Join(dirPath, TrimPrefix(uploadPath), filename)
...@@ -204,7 +203,7 @@ func (u *UserVolume) UploadChunk(cid int, dirPath, uploadPath, filename string, ...@@ -204,7 +203,7 @@ func (u *UserVolume) UploadChunk(cid int, dirPath, uploadPath, filename string,
func (u *UserVolume) MergeChunk(cid, total int, dirPath, uploadPath, filename string) (elfinder.FileDir, error) { func (u *UserVolume) MergeChunk(cid, total int, dirPath, uploadPath, filename string) (elfinder.FileDir, error) {
var path string var path string
switch { switch {
case strings.Contains(uploadPath,filename): case strings.Contains(uploadPath, filename):
path = filepath.Join(dirPath, TrimPrefix(uploadPath)) path = filepath.Join(dirPath, TrimPrefix(uploadPath))
case uploadPath != "": case uploadPath != "":
path = filepath.Join(dirPath, TrimPrefix(uploadPath), filename) path = filepath.Join(dirPath, TrimPrefix(uploadPath), filename)
...@@ -340,6 +339,6 @@ func hashPath(id, path string) string { ...@@ -340,6 +339,6 @@ func hashPath(id, path string) string {
return elfinder.CreateHash(id, path) return elfinder.CreateHash(id, path)
} }
func TrimPrefix(path string) string{ func TrimPrefix(path string) string {
return strings.TrimPrefix(path, "/") return strings.TrimPrefix(path, "/")
} }
\ No newline at end of file
package koko package koko
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
...@@ -34,18 +35,20 @@ func (c *Coco) Stop() { ...@@ -34,18 +35,20 @@ func (c *Coco) Stop() {
} }
func RunForever() { func RunForever() {
bootstrap() ctx,cancelFunc := context.WithCancel(context.Background())
bootstrap(ctx)
gracefulStop := make(chan os.Signal) gracefulStop := make(chan os.Signal)
signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
app := &Coco{} app := &Coco{}
app.Start() app.Start()
<-gracefulStop <-gracefulStop
cancelFunc()
app.Stop() app.Stop()
} }
func bootstrap() { func bootstrap(ctx context.Context) {
config.Initial() config.Initial()
logger.Initial() logger.Initial()
service.Initial() service.Initial(ctx)
Initial() Initial()
} }
...@@ -77,29 +77,27 @@ func assetSortByHostName(asset1, asset2 *Asset) bool { ...@@ -77,29 +77,27 @@ func assetSortByHostName(asset1, asset2 *Asset) bool {
type NodeList []Node type NodeList []Node
type AssetsPaginationResponse struct {
Total int `json:"count"`
NextURL string `json:"next"`
PreviousURL string `json:"previous"`
Data []Asset `json:"results"`
}
type Asset struct { type Asset struct {
ID string `json:"id"` ID string `json:"id"`
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
IP string `json:"ip"` IP string `json:"ip"`
Port int `json:"port"`
SystemUsers []SystemUser `json:"system_users_granted"`
IsActive bool `json:"is_active"`
SystemUsersJoin string `json:"system_users_join"`
Os string `json:"os"` Os string `json:"os"`
Domain string `json:"domain"` Domain string `json:"domain"`
Platform string `json:"platform"` Platform string `json:"platform"`
Comment string `json:"comment"` Comment string `json:"comment"`
Protocol string `json:"protocol"`
Protocols []string `json:"protocols,omitempty"` Protocols []string `json:"protocols,omitempty"`
OrgID string `json:"org_id"` OrgID string `json:"org_id"`
OrgName string `json:"org_name"` OrgName string `json:"org_name"`
} }
func (a *Asset) ProtocolPort(protocol string) int { func (a *Asset) ProtocolPort(protocol string) int {
// 向下兼容
if a.Protocols == nil {
return a.Port
}
for _, item := range a.Protocols { for _, item := range a.Protocols {
if strings.Contains(strings.ToLower(item), strings.ToLower(protocol)) { if strings.Contains(strings.ToLower(item), strings.ToLower(protocol)) {
proAndPort := strings.Split(item, "/") proAndPort := strings.Split(item, "/")
...@@ -123,9 +121,6 @@ func (a *Asset) ProtocolPort(protocol string) int { ...@@ -123,9 +121,6 @@ func (a *Asset) ProtocolPort(protocol string) int {
} }
func (a *Asset) IsSupportProtocol(protocol string) bool { func (a *Asset) IsSupportProtocol(protocol string) bool {
if a.Protocols == nil {
return a.Protocol == protocol
}
for _, item := range a.Protocols { for _, item := range a.Protocols {
if strings.Contains(strings.ToLower(item), strings.ToLower(protocol)) { if strings.Contains(strings.ToLower(item), strings.ToLower(protocol)) {
return true return true
......
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"os" "os"
"path" "path"
...@@ -15,7 +16,7 @@ import ( ...@@ -15,7 +16,7 @@ import (
var client = common.NewClient(30, "") var client = common.NewClient(30, "")
var authClient = common.NewClient(30, "") var authClient = common.NewClient(30, "")
func Initial() { func Initial(ctx context.Context) {
cf := config.GetConf() cf := config.GetConf()
keyPath := cf.AccessKeyFile keyPath := cf.AccessKeyFile
client.BaseHost = cf.CoreHost client.BaseHost = cf.CoreHost
...@@ -31,7 +32,7 @@ func Initial() { ...@@ -31,7 +32,7 @@ func Initial() {
authClient.Auth = ak authClient.Auth = ak
validateAccessAuth() validateAccessAuth()
MustLoadServerConfigOnce() MustLoadServerConfigOnce()
go KeepSyncConfigWithServer() go KeepSyncConfigWithServer(ctx)
} }
func newClient() *common.Client { func newClient() *common.Client {
...@@ -94,12 +95,18 @@ func LoadConfigFromServer() (err error) { ...@@ -94,12 +95,18 @@ func LoadConfigFromServer() (err error) {
return nil return nil
} }
func KeepSyncConfigWithServer() { func KeepSyncConfigWithServer(ctx context.Context) {
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
for { for {
err := LoadConfigFromServer() select {
if err != nil { case <-ctx.Done():
logger.Warn("Sync config with server error: ", err) logger.Info("Sync config with server exit.")
case <-ticker.C:
err := LoadConfigFromServer()
if err != nil {
logger.Warn("Sync config with server error: ", err)
}
} }
time.Sleep(60 * time.Second)
} }
} }
...@@ -2,60 +2,55 @@ package service ...@@ -2,60 +2,55 @@ package service
import ( import (
"fmt" "fmt"
"sync" "strconv"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
) )
var userAssetsCached = assetsCacheContainer{ func GetUserAssets(userID, search string, pageSize, offset int) (resp model.AssetsPaginationResponse) {
mapData: make(map[string]model.AssetList), if pageSize < 0 {
mapETag: make(map[string]string), pageSize = 0
mu: new(sync.RWMutex),
}
var userNodesCached = nodesCacheContainer{
mapData: make(map[string]model.NodeList),
mapETag: make(map[string]string),
mu: new(sync.RWMutex),
}
func GetUserAssetsFromCache(userID string) (assets model.AssetList, ok bool) {
assets, ok = userAssetsCached.Get(userID)
return
}
func GetUserAssets(userID, cachePolicy, assetId string) (assets model.AssetList) {
if cachePolicy == "" {
cachePolicy = "1"
}
headers := make(map[string]string)
if etag, ok := userAssetsCached.GetETag(userID); ok && cachePolicy == "1" && assetId == "" {
headers["If-None-Match"] = etag
} }
payload := map[string]string{"cache_policy": cachePolicy} params := map[string]string{
if assetId != "" { "search": search,
payload["id"] = assetId "limit": strconv.Itoa(pageSize),
"offset": strconv.Itoa(offset),
} }
Url := fmt.Sprintf(UserAssetsURL, userID)
resp, err := authClient.Get(Url, &assets, payload, headers)
Url := fmt.Sprintf(UserAssetsURL, userID)
var err error
if pageSize > 0 {
_, err = authClient.Get(Url, &resp, params)
} else {
var data model.AssetList
_, err = authClient.Get(Url, &data, params)
resp.Data = data
}
if err != nil { if err != nil {
logger.Error("Get user assets error: ", err) logger.Error("Get user assets error: ", err)
return
} }
if resp.StatusCode == 200 && resp.Header.Get("ETag") != "" { return
newETag := resp.Header.Get("ETag") }
userAssetsCached.SetValue(userID, assets)
userAssetsCached.SetETag(userID, newETag) func GetUserAllAssets(userID string) (assets []model.Asset) {
} else if resp.StatusCode == 304 { Url := fmt.Sprintf(UserAssetsURL, userID)
assets, _ = userAssetsCached.Get(userID) _, err := authClient.Get(Url, &assets)
if err != nil {
logger.Error("Get user all assets error: ", err)
} }
return return
} }
func GetUserNodesFromCache(userID string) (nodes model.NodeList, ok bool) { func GetUserAssetByID(userID, assertID string) (assets []model.Asset) {
nodes, ok = userNodesCached.Get(userID) params := map[string]string{
"id": assertID,
}
Url := fmt.Sprintf(UserAssetsURL, userID)
_, err := authClient.Get(Url, &assets, params)
if err != nil {
logger.Error("Get user asset by ID error: ", err)
}
return return
} }
...@@ -63,21 +58,20 @@ func GetUserNodes(userID, cachePolicy string) (nodes model.NodeList) { ...@@ -63,21 +58,20 @@ func GetUserNodes(userID, cachePolicy string) (nodes model.NodeList) {
if cachePolicy == "" { if cachePolicy == "" {
cachePolicy = "1" cachePolicy = "1"
} }
headers := make(map[string]string)
if etag, ok := userNodesCached.GetETag(userID); ok && cachePolicy == "1" {
headers["If-None-Match"] = etag
}
payload := map[string]string{"cache_policy": cachePolicy} payload := map[string]string{"cache_policy": cachePolicy}
Url := fmt.Sprintf(UserNodesListURL, userID) Url := fmt.Sprintf(UserNodesListURL, userID)
resp, err := authClient.Get(Url, &nodes, payload, headers) _, err := authClient.Get(Url, &nodes, payload)
if err != nil { if err != nil {
logger.Error("Get user nodes error: ", err) logger.Error("Get user nodes error: ", err)
} }
if resp.StatusCode == 200 && resp.Header.Get("ETag") != "" { return
userNodesCached.SetValue(userID, nodes) }
userNodesCached.SetETag(userID, resp.Header.Get("ETag"))
} else if resp.StatusCode == 304 { func GetUserAssetSystemUsers(userID, assetID string) (sysUsers []model.SystemUser) {
nodes, _ = userNodesCached.Get(userID) Url := fmt.Sprintf(UserAssetSystemUsersURL, userID, assetID)
_, err := authClient.Get(Url, &sysUsers)
if err != nil {
logger.Error("Get user asset system users error: ", err)
} }
return return
} }
......
...@@ -32,3 +32,9 @@ const ( ...@@ -32,3 +32,9 @@ const (
UserNodeAssetsListURL = "/api/perms/v1/users/%s/nodes/%s/assets/" UserNodeAssetsListURL = "/api/perms/v1/users/%s/nodes/%s/assets/"
ValidateUserAssetPermissionURL = "/api/perms/v1/asset-permissions/user/validate/" //0不使用缓存 1 使用缓存 2 刷新缓存 ValidateUserAssetPermissionURL = "/api/perms/v1/asset-permissions/user/validate/" //0不使用缓存 1 使用缓存 2 刷新缓存
) )
// 1.5.3
const (
UserAssetSystemUsersURL = "/api/v1/perms/users/%s/assets/%s/system-users/" // 获取用户授权资产的系统用户列表
)
...@@ -80,6 +80,7 @@ func (u *UserSftp) ReadDir(path string) (res []os.FileInfo, err error) { ...@@ -80,6 +80,7 @@ func (u *UserSftp) ReadDir(path string) (res []os.FileInfo, err error) {
} }
return return
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return res, sftp.ErrSshFxNoSuchFile return res, sftp.ErrSshFxNoSuchFile
...@@ -120,6 +121,7 @@ func (u *UserSftp) Stat(path string) (res os.FileInfo, err error) { ...@@ -120,6 +121,7 @@ func (u *UserSftp) Stat(path string) (res os.FileInfo, err error) {
res = NewFakeFile(req.host, true) res = NewFakeFile(req.host, true)
return return
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return res, sftp.ErrSshFxNoSuchFile return res, sftp.ErrSshFxNoSuchFile
...@@ -148,7 +150,7 @@ func (u *UserSftp) ReadLink(path string) (res string, err error) { ...@@ -148,7 +150,7 @@ func (u *UserSftp) ReadLink(path string) (res string, err error) {
if req.su == "" { if req.su == "" {
return res, sftp.ErrSshFxPermissionDenied return res, sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return res, sftp.ErrSshFxNoSuchFile return res, sftp.ErrSshFxNoSuchFile
...@@ -175,6 +177,7 @@ func (u *UserSftp) RemoveDirectory(path string) error { ...@@ -175,6 +177,7 @@ func (u *UserSftp) RemoveDirectory(path string) error {
if req.su == "" { if req.su == "" {
return sftp.ErrSshFxPermissionDenied return sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return sftp.ErrSshFxNoSuchFile return sftp.ErrSshFxNoSuchFile
...@@ -236,7 +239,7 @@ func (u *UserSftp) Remove(path string) error { ...@@ -236,7 +239,7 @@ func (u *UserSftp) Remove(path string) error {
if req.su == "" { if req.su == "" {
return sftp.ErrSshFxPermissionDenied return sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return sftp.ErrSshFxNoSuchFile return sftp.ErrSshFxNoSuchFile
...@@ -273,6 +276,7 @@ func (u *UserSftp) MkdirAll(path string) error { ...@@ -273,6 +276,7 @@ func (u *UserSftp) MkdirAll(path string) error {
if req.su == "" { if req.su == "" {
return sftp.ErrSshFxPermissionDenied return sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return sftp.ErrSshFxNoSuchFile return sftp.ErrSshFxNoSuchFile
...@@ -309,6 +313,7 @@ func (u *UserSftp) Rename(oldNamePath, newNamePath string) error { ...@@ -309,6 +313,7 @@ func (u *UserSftp) Rename(oldNamePath, newNamePath string) error {
if !ok { if !ok {
return sftp.ErrSshFxPermissionDenied return sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req1.su] su, ok := host.suMaps[req1.su]
if !ok { if !ok {
return sftp.ErrSshFxNoSuchFile return sftp.ErrSshFxNoSuchFile
...@@ -346,6 +351,7 @@ func (u *UserSftp) Symlink(oldNamePath, newNamePath string) error { ...@@ -346,6 +351,7 @@ func (u *UserSftp) Symlink(oldNamePath, newNamePath string) error {
if !ok { if !ok {
return sftp.ErrSshFxPermissionDenied return sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req1.su] su, ok := host.suMaps[req1.su]
if !ok { if !ok {
return sftp.ErrSshFxNoSuchFile return sftp.ErrSshFxNoSuchFile
...@@ -383,7 +389,7 @@ func (u *UserSftp) Create(path string) (*sftp.File, error) { ...@@ -383,7 +389,7 @@ func (u *UserSftp) Create(path string) (*sftp.File, error) {
if req.su == "" { if req.su == "" {
return nil, sftp.ErrSshFxPermissionDenied return nil, sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return nil, sftp.ErrSshFxNoSuchFile return nil, sftp.ErrSshFxNoSuchFile
...@@ -420,6 +426,7 @@ func (u *UserSftp) Open(path string) (*sftp.File, error) { ...@@ -420,6 +426,7 @@ func (u *UserSftp) Open(path string) (*sftp.File, error) {
if req.su == "" { if req.su == "" {
return nil, sftp.ErrSshFxPermissionDenied return nil, sftp.ErrSshFxPermissionDenied
} }
host.loadSystemUsers(u.User.ID)
su, ok := host.suMaps[req.su] su, ok := host.suMaps[req.su]
if !ok { if !ok {
return nil, sftp.ErrSshFxNoSuchFile return nil, sftp.ErrSshFxNoSuchFile
...@@ -506,6 +513,7 @@ func (u *UserSftp) GetSFTPAndRealPath(req requestMessage) (conn *SftpConn, realP ...@@ -506,6 +513,7 @@ func (u *UserSftp) GetSFTPAndRealPath(req requestMessage) (conn *SftpConn, realP
func (u *UserSftp) HostHasUniqueSu(hostKey string) (string, bool) { func (u *UserSftp) HostHasUniqueSu(hostKey string) (string, bool) {
if host, ok := u.hosts[hostKey]; ok { if host, ok := u.hosts[hostKey]; ok {
host.loadSystemUsers(u.User.ID)
return host.HasUniqueSu() return host.HasUniqueSu()
} }
return "", false return "", false
...@@ -616,13 +624,7 @@ type requestMessage struct { ...@@ -616,13 +624,7 @@ type requestMessage struct {
} }
func NewHostnameDir(asset *model.Asset) *HostnameDir { func NewHostnameDir(asset *model.Asset) *HostnameDir {
sus := make(map[string]*model.SystemUser) h := HostnameDir{asset: asset}
for i := 0; i < len(asset.SystemUsers); i++ {
if asset.SystemUsers[i].Protocol == "ssh" {
sus[asset.SystemUsers[i].Name] = &asset.SystemUsers[i]
}
}
h := HostnameDir{asset: asset, suMaps: sus}
return &h return &h
} }
...@@ -631,6 +633,19 @@ type HostnameDir struct { ...@@ -631,6 +633,19 @@ type HostnameDir struct {
suMaps map[string]*model.SystemUser suMaps map[string]*model.SystemUser
} }
func (h *HostnameDir) loadSystemUsers(userID string) {
if h.suMaps == nil {
sus := make(map[string]*model.SystemUser)
SystemUsers := service.GetUserAssetSystemUsers(userID, h.asset.ID)
for i := 0; i < len(SystemUsers); i++ {
if SystemUsers[i].Protocol == "ssh" {
sus[SystemUsers[i].Name] = &SystemUsers[i]
}
}
h.suMaps = sus
}
}
func (h *HostnameDir) HasUniqueSu() (string, bool) { func (h *HostnameDir) HasUniqueSu() (string, bool) {
sus := h.GetSystemUsers() sus := h.GetSystemUsers()
if len(sus) == 1 { if len(sus) == 1 {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment