Commit 1a883bef authored by Eric's avatar Eric

[Update] fix config data race and pagination related issues

parent 596eea80
......@@ -57,14 +57,14 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult)
}
func CheckUserPassword(ctx ssh.Context, password string) ssh.AuthResult {
if !config.Conf.PasswordAuth {
if !config.GetConf().PasswordAuth {
return ssh.AuthFailed
}
return checkAuth(ctx, password, "")
}
func CheckUserPublicKey(ctx ssh.Context, key ssh.PublicKey) ssh.AuthResult {
if !config.Conf.PublicKeyAuth {
if !config.GetConf().PublicKeyAuth {
return ssh.AuthFailed
}
b := key.Marshal()
......
......@@ -3,12 +3,12 @@ package common
import "sync"
func NewPagination(data []interface{}, size int) *Pagination {
return &Pagination{
p := &Pagination{
data: data,
pageSize: size,
currentPage: 1,
lock: new(sync.RWMutex),
}
p.SetPageSize(size)
return p
}
type Pagination struct {
......
package common
import (
"fmt"
"strings"
"github.com/olekukonko/tablewriter"
......@@ -131,9 +130,7 @@ func (t *WrapperTable) CalculateColumnsSize() {
if delta == 0 {
break
}
fmt.Println(t.fieldsSize)
}
fmt.Println(canChangeCols)
}
}
......@@ -165,7 +162,6 @@ func (t *WrapperTable) convertDataToSlice() [][]string {
func (t *WrapperTable) Display() string {
t.CalculateColumnsSize()
fmt.Println(t.fieldsSize)
tableString := &strings.Builder{}
table := tablewriter.NewWriter(tableString)
......
......@@ -2,11 +2,13 @@ package config
import (
"encoding/json"
"gopkg.in/yaml.v2"
"io/ioutil"
"log"
"os"
"strings"
"sync"
"gopkg.in/yaml.v2"
)
type Config struct {
......@@ -99,9 +101,10 @@ func (c *Config) Load(filepath string) error {
return err
}
var lock = new(sync.RWMutex)
var name, _ = os.Hostname()
var rootPath, _ = os.Getwd()
var Conf = Config{
var Conf = &Config{
Name: name,
CoreHost: "http://localhost:8080",
BootstrapToken: "",
......@@ -120,3 +123,19 @@ var Conf = Config{
ReplayStorage: map[string]string{"TYPE": "server"},
CommandStorage: map[string]string{"TYPE": "server"},
}
func SetConf(conf *Config) {
lock.Lock()
defer lock.Unlock()
Conf = conf
}
func GetConf() *Config {
lock.RLock()
defer lock.RUnlock()
var conf Config
if confBytes, err := json.Marshal(Conf); err == nil {
_ = json.Unmarshal(confBytes, &conf)
}
return &conf
}
......@@ -64,8 +64,9 @@ type ColorMeta struct {
func displayBanner(sess ssh.Session, user string) {
title := defaultTitle
if config.Conf.HeaderTitle != "" {
title = config.Conf.HeaderTitle
cf := config.GetConf()
if cf.HeaderTitle != "" {
title = cf.HeaderTitle
}
prefix := utils.CharClear + utils.CharTab + utils.CharTab
......
......@@ -32,31 +32,29 @@ func (p *AssetPagination) Initial() {
}
pageSize := p.getPageSize()
p.page = common.NewPagination(pageData, pageSize)
firstPageData := p.page.GetPageData(1)
p.currentData = make([]model.Asset, len(firstPageData))
for i, item := range firstPageData {
p.currentData[i] = item.(model.Asset)
}
}
func (p *AssetPagination) getPageSize() int {
var (
pageSize int
minHeight = 8 // 分页显示的最小高度
)
_, height := p.term.GetSize()
switch config.Conf.AssetListPageSize {
switch config.GetConf().AssetListPageSize {
case "auto":
pageSize = height - 8
pageSize = height - minHeight
case "all":
pageSize = len(p.assets)
default:
if value, err := strconv.Atoi(config.Conf.AssetListPageSize); err == nil {
if value, err := strconv.Atoi(config.GetConf().AssetListPageSize); err == nil {
pageSize = value
} else {
pageSize = height - 8
pageSize = height - minHeight
}
}
if pageSize <= 0 {
......@@ -66,20 +64,22 @@ func (p *AssetPagination) getPageSize() int {
}
func (p *AssetPagination) Start() []model.Asset {
p.term.SetPrompt(": ")
defer p.term.SetPrompt("Opt> ")
for {
// 当前页是第一个,如果当前页数据小于page size,显示所有
if p.page.CurrentPage() == 1 && p.page.GetPageSize() > len(p.currentData) {
// 总数据小于page size,则显示所有资产且退出
if p.page.GetPageSize() >= p.page.TotalCount() {
p.currentData = p.assets
p.displayPageAssets()
return []model.Asset{}
}
p.displayPageAssets()
p.displayTipsInfo()
line, err := p.term.ReadLine()
if err != nil {
return []model.Asset{}
}
pageSize := p.getPageSize()
p.page.SetPageSize(pageSize)
......@@ -91,11 +91,11 @@ func (p *AssetPagination) Start() []model.Asset {
if !p.page.HasPrePage() {
continue
}
tmpData := p.page.GetPrePageData()
if len(p.currentData) != len(tmpData) {
p.currentData = make([]model.Asset, len(tmpData))
prePageData := p.page.GetPrePageData()
if len(p.currentData) != len(prePageData) {
p.currentData = make([]model.Asset, len(prePageData))
}
for i, item := range tmpData {
for i, item := range prePageData {
p.currentData[i] = item.(model.Asset)
}
......@@ -103,11 +103,11 @@ func (p *AssetPagination) Start() []model.Asset {
if !p.page.HasNextPage() {
continue
}
tmpData := p.page.GetNextPageData()
if len(p.currentData) != len(tmpData) {
p.currentData = make([]model.Asset, len(tmpData))
nextPageData := p.page.GetNextPageData()
if len(p.currentData) != len(nextPageData) {
p.currentData = make([]model.Asset, len(nextPageData))
}
for i, item := range tmpData {
for i, item := range nextPageData {
p.currentData[i] = item.(model.Asset)
}
case "b", "q":
......@@ -145,7 +145,6 @@ func (p *AssetPagination) displayPageAssets() {
names[i] = systemUser[i].Name
}
row["systemUsers"] = strings.Join(names, ",")
fmt.Println(row["系统用户"], len(row["系统用户"]))
row["comment"] = j.Comment
data[i] = row
}
......
......@@ -11,7 +11,6 @@ import (
"github.com/gliderlabs/ssh"
"github.com/olekukonko/tablewriter"
"github.com/xlab/treeprint"
"golang.org/x/crypto/ssh/terminal"
"cocogo/pkg/cctx"
"cocogo/pkg/logger"
......@@ -175,17 +174,7 @@ func (h *interactiveHandler) chooseSystemUser(systemUsers []model.SystemUser) mo
return systemUsers[0]
default:
}
displaySystemUsers := make([]model.SystemUser, 0)
model.SortSystemUserByPriority(systemUsers)
highestPriority := systemUsers[length-1].Priority
displaySystemUsers = append(displaySystemUsers, systemUsers[length-1])
for i := length - 2; i >= 0; i-- {
if highestPriority == systemUsers[i].Priority {
displaySystemUsers = append(displaySystemUsers, systemUsers[i])
}
}
displaySystemUsers := selectHighestPrioritySystemUsers(systemUsers)
if len(displaySystemUsers) == 1 {
return displaySystemUsers[0]
}
......@@ -196,20 +185,21 @@ func (h *interactiveHandler) chooseSystemUser(systemUsers []model.SystemUser) mo
table.Append([]string{strconv.Itoa(i + 1), displaySystemUsers[i].Username})
}
table.SetBorder(false)
count := 0
term := terminal.NewTerminal(h.sess, "num:")
for count < 3 {
h.term.SetPrompt("Select User: ")
defer h.term.SetPrompt("Opt> ")
for count := 0; count < 3; count++ {
table.Render()
line, err := term.ReadLine()
line, err := h.term.ReadLine()
if err != nil {
continue
break
}
line = strings.TrimSpace(line)
if num, err := strconv.Atoi(line); err == nil {
if num > 0 && num <= len(displaySystemUsers) {
return displaySystemUsers[num-1]
}
}
count++
}
return displaySystemUsers[0]
}
......@@ -224,15 +214,14 @@ func (h *interactiveHandler) displayAssetsOrProxy(assets []model.Asset) {
} else {
h.displayAssets(assets)
}
}
func (h *interactiveHandler) displayAssets(assets model.AssetList) {
if len(assets) == 0 {
_, _ = io.WriteString(h.term, "\r\n No Assets\r\n\r")
} else {
h.term.SetPrompt(": ")
pag := NewAssetPagination(h.term, assets)
pag.Initial()
selectOneAssets := pag.Start()
if len(selectOneAssets) == 1 {
systemUser := h.chooseSystemUser(selectOneAssets[0].SystemUsers)
......@@ -240,7 +229,6 @@ func (h *interactiveHandler) displayAssets(assets model.AssetList) {
h.systemUserSelect = &systemUser
h.Proxy(context.TODO())
}
h.term.SetPrompt("Opt> ")
}
}
......
......@@ -10,8 +10,9 @@ import (
)
func init() {
localePath := path.Join(config.Conf.RootPath, "locale")
if strings.HasPrefix(config.Conf.Language, "zh") {
cf := config.GetConf()
localePath := path.Join(cf.RootPath, "locale")
if strings.HasPrefix(cf.Language, "zh") {
gotext.Configure(localePath, "zh_CN", "coco")
} else {
gotext.Configure(localePath, "en_US", "coco")
......
......@@ -25,7 +25,8 @@ func Initial() {
LogFormat: "%time% [%lvl%] %msg%",
TimestampFormat: "2006-01-02 15:04:05",
}
level, ok := logLevels[strings.ToUpper(config.Conf.LogLevel)]
conf := config.GetConf()
level, ok := logLevels[strings.ToUpper(conf.LogLevel)]
if !ok {
level = logrus.InfoLevel
}
......@@ -37,7 +38,7 @@ func Initial() {
logger.SetLevel(level)
// Output to file
logFilePath := path.Join(config.Conf.RootPath, "logs", "coco.log")
logFilePath := path.Join(conf.RootPath, "logs", "coco.log")
logDirPath := path.Dir(logFilePath)
if common.FileExists(logDirPath) {
err := os.MkdirAll(logDirPath, os.ModePerm)
......
......@@ -61,7 +61,7 @@ func (p *ProxyServer) getSSHConn() (srvConn *ServerSSHConnection, err error) {
port: strconv.Itoa(p.Asset.Port),
user: p.SystemUser.Username,
password: p.SystemUser.Password,
timeout: config.Conf.SSHTimeout,
timeout: config.GetConf().SSHTimeout,
}
pty := p.UserConn.Pty()
done := make(chan struct{})
......@@ -90,7 +90,8 @@ func (p *ProxyServer) sendConnectingMsg(done chan struct{}) {
delay := 0.0
msg := fmt.Sprintf(i18n.T("Connecting to %s@%s %.1f"), p.SystemUser.Username, p.Asset.Ip, delay)
utils.IgnoreErrWriteString(p.UserConn, msg)
for int(delay) < config.Conf.SSHTimeout {
cf := config.GetConf()
for int(delay) < cf.SSHTimeout {
select {
case <-done:
return
......
......@@ -14,8 +14,6 @@ import (
"cocogo/pkg/model"
)
var conf = config.Conf
func NewCommandRecorder(sess *SwitchSession) (recorder *CommandRecorder) {
recorder = &CommandRecorder{Session: sess}
recorder.initial()
......@@ -125,7 +123,7 @@ func (r *ReplyRecorder) Record(b []byte) {
func (r *ReplyRecorder) prepare() {
sessionId := r.Session.Id
rootPath := conf.RootPath
rootPath := config.GetConf().RootPath
today := time.Now().UTC().Format("2006-01-02")
gzFileName := sessionId + ".replay.gz"
replayDir := filepath.Join(rootPath, "data", "replays", today)
......
......@@ -17,7 +17,7 @@ type CommandStorage interface {
}
func NewReplayStorage() ReplayStorage {
cf := config.Conf.ReplayStorage
cf := config.GetConf().ReplayStorage
tp, ok := cf["TYPE"]
if !ok {
tp = "server"
......@@ -29,7 +29,7 @@ func NewReplayStorage() ReplayStorage {
}
func NewCommandStorage() CommandStorage {
cf := config.Conf.CommandStorage
cf := config.GetConf().CommandStorage
tp, ok := cf["TYPE"]
if !ok {
tp = "server"
......
......@@ -82,8 +82,9 @@ func (ak *AccessKey) SaveToFile() error {
}
func (ak *AccessKey) Register(times int) error {
name := config.Conf.Name
token := config.Conf.BootstrapToken
cf := config.GetConf()
name := cf.Name
token := cf.BootstrapToken
comment := "Coco"
res := RegisterTerminal(name, token, comment)
......
......@@ -16,16 +16,17 @@ var client = common.NewClient(30, "")
var authClient = common.NewClient(30, "")
func Initial() {
keyPath := config.Conf.AccessKeyFile
client.BaseHost = config.Conf.CoreHost
authClient.BaseHost = config.Conf.CoreHost
cf := config.GetConf()
keyPath := cf.AccessKeyFile
client.BaseHost = cf.CoreHost
authClient.BaseHost = cf.CoreHost
client.SetHeader("X-JMS-ORG", "ROOT")
authClient.SetHeader("X-JMS-ORG", "ROOT")
if !path.IsAbs(config.Conf.AccessKeyFile) {
keyPath = filepath.Join(config.Conf.RootPath, keyPath)
if !path.IsAbs(cf.AccessKeyFile) {
keyPath = filepath.Join(cf.RootPath, keyPath)
}
ak := AccessKey{Value: config.Conf.AccessKey, Path: keyPath}
ak := AccessKey{Value: cf.AccessKey, Path: keyPath}
_ = ak.Load()
authClient.Auth = ak
validateAccessAuth()
......@@ -34,6 +35,7 @@ func Initial() {
}
func validateAccessAuth() {
cf := config.GetConf()
maxTry := 30
count := 0
for {
......@@ -43,7 +45,7 @@ func validateAccessAuth() {
}
if err != nil {
msg := "Connect server error or access key is invalid, remove %s run again"
logger.Errorf(msg, config.Conf.AccessKeyFile)
logger.Errorf(msg, cf.AccessKeyFile)
} else if user.Role != "App" {
logger.Error("Access role is not App, is: ", user.Role)
}
......@@ -76,8 +78,13 @@ func MustLoadServerConfigOnce() {
}
func LoadConfigFromServer() (err error) {
err = authClient.Get(TerminalConfigURL, &config.Conf)
conf := config.GetConf()
err = authClient.Get(TerminalConfigURL, conf)
if err != nil {
return err
}
config.SetConf(conf)
return nil
}
func KeepSyncConfigWithServer() {
......
......@@ -16,11 +16,11 @@ type HostKey struct {
}
func (hk *HostKey) loadHostKeyFromFile(keyPath string) (signer ssh.Signer, err error) {
_, err = os.Stat(conf.HostKeyFile)
_, err = os.Stat(keyPath)
if err != nil {
return
}
buf, err := ioutil.ReadFile(conf.HostKeyFile)
buf, err := ioutil.ReadFile(keyPath)
if err != nil {
return
}
......
......@@ -11,9 +11,8 @@ import (
"cocogo/pkg/logger"
)
var conf = config.Conf
func StartServer() {
conf := config.GetConf()
hostKey := HostKey{Value: conf.HostKey, Path: conf.HostKeyFile}
logger.Debug("Loading host key")
signer, err := hostKey.Load()
......
......@@ -12,12 +12,12 @@ import (
)
var (
conf = config.Conf
httpServer *http.Server
cons = &connections{container: make(map[string]*WebConn), mu: new(sync.RWMutex)}
)
func StartHTTPServer() {
conf := config.GetConf()
server, err := socketio.NewServer(nil)
if err != nil {
logger.Fatal(err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment