Commit e7f9565f authored by ibuler's avatar ibuler

[Update] 修改使用client cache

parent 22beb659
...@@ -33,7 +33,7 @@ func (c *Coco) Stop() { ...@@ -33,7 +33,7 @@ func (c *Coco) Stop() {
} }
func RunForever() { func RunForever() {
loadingBoot() bootstrap()
gracefulStop := make(chan os.Signal) gracefulStop := make(chan os.Signal)
signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
app := &Coco{} app := &Coco{}
...@@ -42,7 +42,7 @@ func RunForever() { ...@@ -42,7 +42,7 @@ func RunForever() {
app.Stop() app.Stop()
} }
func loadingBoot() { func bootstrap() {
config.Initial() config.Initial()
logger.Initial() logger.Initial()
service.Initial() service.Initial()
......
...@@ -16,13 +16,13 @@ import ( ...@@ -16,13 +16,13 @@ import (
func Initial() { func Initial() {
conf := config.GetConf() conf := config.GetConf()
if conf.UploadFailedReplay { if conf.UploadFailedReplay {
go uploadFailedReplay(conf.RootPath) go uploadRemainReplay(conf.RootPath)
} }
go keepHeartbeat(conf.HeartbeatDuration) go keepHeartbeat(conf.HeartbeatDuration)
} }
func uploadFailedReplay(rootPath string) { func uploadRemainReplay(rootPath string) {
replayDir := filepath.Join(rootPath, "data", "replays") replayDir := filepath.Join(rootPath, "data", "replays")
err := common.EnsureDirExist(replayDir) err := common.EnsureDirExist(replayDir)
if err != nil { if err != nil {
...@@ -47,11 +47,11 @@ func uploadFailedReplay(rootPath string) { ...@@ -47,11 +47,11 @@ func uploadFailedReplay(rootPath string) {
} }
return nil return nil
}) })
logger.Debug("upload Replay Done") logger.Debug("Upload remain replay done")
} }
func keepHeartbeat(interval int) { func keepHeartbeat(interval time.Duration) {
tick := time.Tick(time.Duration(interval) * time.Second) tick := time.Tick(interval * time.Second)
for { for {
select { select {
case <-tick: case <-tick:
...@@ -63,6 +63,5 @@ func keepHeartbeat(interval int) { ...@@ -63,6 +63,5 @@ func keepHeartbeat(interval int) {
} }
} }
} }
} }
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"time"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
...@@ -20,9 +21,9 @@ type Config struct { ...@@ -20,9 +21,9 @@ type Config struct {
PublicKeyAuth bool `json:"TERMINAL_PUBLIC_KEY_AUTH" yaml:"PUBLIC_KEY_AUTH"` PublicKeyAuth bool `json:"TERMINAL_PUBLIC_KEY_AUTH" yaml:"PUBLIC_KEY_AUTH"`
CommandStorage map[string]interface{} `json:"TERMINAL_COMMAND_STORAGE"` CommandStorage map[string]interface{} `json:"TERMINAL_COMMAND_STORAGE"`
ReplayStorage map[string]interface{} `json:"TERMINAL_REPLAY_STORAGE" yaml:"REPLAY_STORAGE"` ReplayStorage map[string]interface{} `json:"TERMINAL_REPLAY_STORAGE" yaml:"REPLAY_STORAGE"`
SessionKeepDuration int `json:"TERMINAL_SESSION_KEEP_DURATION"` SessionKeepDuration time.Duration `json:"TERMINAL_SESSION_KEEP_DURATION"`
TelnetRegex string `json:"TERMINAL_TELNET_REGEX"` TelnetRegex string `json:"TERMINAL_TELNET_REGEX"`
MaxIdleTime int `json:"SECURITY_MAX_IDLE_TIME"` MaxIdleTime time.Duration `json:"SECURITY_MAX_IDLE_TIME"`
SftpRoot string `json:"TERMINAL_SFTP_ROOT" yaml:"SFTP_ROOT"` SftpRoot string `json:"TERMINAL_SFTP_ROOT" yaml:"SFTP_ROOT"`
Name string `yaml:"NAME"` Name string `yaml:"NAME"`
SecretKey string `yaml:"SECRET_KEY"` SecretKey string `yaml:"SECRET_KEY"`
...@@ -30,13 +31,13 @@ type Config struct { ...@@ -30,13 +31,13 @@ type Config struct {
CoreHost string `yaml:"CORE_HOST"` CoreHost string `yaml:"CORE_HOST"`
BootstrapToken string `yaml:"BOOTSTRAP_TOKEN"` BootstrapToken string `yaml:"BOOTSTRAP_TOKEN"`
BindHost string `yaml:"BIND_HOST"` BindHost string `yaml:"BIND_HOST"`
SSHPort int `yaml:"SSHD_PORT"` SSHPort string `yaml:"SSHD_PORT"`
HTTPPort int `yaml:"HTTPD_PORT"` HTTPPort string `yaml:"HTTPD_PORT"`
SSHTimeout int `yaml:"SSH_TIMEOUT"` SSHTimeout time.Duration `yaml:"SSH_TIMEOUT"`
AccessKey string `yaml:"ACCESS_KEY"` AccessKey string `yaml:"ACCESS_KEY"`
AccessKeyFile string `yaml:"ACCESS_KEY_FILE"` AccessKeyFile string `yaml:"ACCESS_KEY_FILE"`
LogLevel string `yaml:"LOG_LEVEL"` LogLevel string `yaml:"LOG_LEVEL"`
HeartbeatDuration int `yaml:"HEARTBEAT_INTERVAL"` HeartbeatDuration time.Duration `yaml:"HEARTBEAT_INTERVAL"`
RootPath string `yaml:"ROOT_PATH"` RootPath string `yaml:"ROOT_PATH"`
Comment string `yaml:"COMMENT"` Comment string `yaml:"COMMENT"`
Language string `yaml:"LANG"` Language string `yaml:"LANG"`
...@@ -110,9 +111,9 @@ var Conf = &Config{ ...@@ -110,9 +111,9 @@ var Conf = &Config{
CoreHost: "http://localhost:8080", CoreHost: "http://localhost:8080",
BootstrapToken: "", BootstrapToken: "",
BindHost: "0.0.0.0", BindHost: "0.0.0.0",
SSHPort: 2222, SSHPort: "2222",
SSHTimeout: 15, SSHTimeout: 15,
HTTPPort: 5000, HTTPPort: "5000",
HeartbeatDuration: 10, HeartbeatDuration: 10,
AccessKey: "", AccessKey: "",
AccessKeyFile: "data/keys/.access_key", AccessKeyFile: "data/keys/.access_key",
......
...@@ -40,7 +40,10 @@ func (c *assetsCacheContainer) SetValue(key string, value []model.Asset) { ...@@ -40,7 +40,10 @@ func (c *assetsCacheContainer) SetValue(key string, value []model.Asset) {
c.mapData[key] = value c.mapData[key] = value
} }
var userAssetsCached = assetsCacheContainer{mapData: make(map[string][]model.Asset), lock: new(sync.RWMutex)} var userAssetsCached = assetsCacheContainer{
mapData: make(map[string][]model.Asset),
lock: new(sync.RWMutex),
}
func SessionHandler(sess ssh.Session) { func SessionHandler(sess ssh.Session) {
pty, _, ok := sess.Pty() pty, _, ok := sess.Pty()
......
...@@ -2,7 +2,6 @@ package httpd ...@@ -2,7 +2,6 @@ package httpd
import ( import (
"net/http" "net/http"
"strconv"
"sync" "sync"
"github.com/googollee/go-socket.io" "github.com/googollee/go-socket.io"
...@@ -36,6 +35,6 @@ func StartHTTPServer() { ...@@ -36,6 +35,6 @@ func StartHTTPServer() {
http.Handle("/socket.io/", server) http.Handle("/socket.io/", server)
logger.Debug("start HTTP Serving ", conf.HTTPPort) logger.Debug("start HTTP Serving ", conf.HTTPPort)
httpServer = &http.Server{Addr: conf.BindHost + ":" + strconv.Itoa(conf.HTTPPort), Handler: nil} httpServer = &http.Server{Addr: conf.BindHost + ":" + conf.HTTPPort, Handler: nil}
logger.Fatal(httpServer.ListenAndServe()) logger.Fatal(httpServer.ListenAndServe())
} }
...@@ -91,6 +91,23 @@ type Asset struct { ...@@ -91,6 +91,23 @@ type Asset struct {
OrgName string `json:"org_name"` OrgName string `json:"org_name"`
} }
type Gateway struct {
ID string `json:"id"`
Name string `json:"Name"`
IP string `json:"ip"`
Port int `json:"port"`
Protocol string `json:"protocol"`
Username string `json:"username"`
Password string `json:"password"`
PrivateKey string `json:"private_key"`
}
type Domain struct {
ID string `json:"id"`
Gateways []Gateway `json:"gateways"`
Name string `json:"name"`
}
type Node struct { type Node struct {
Id string `json:"id"` Id string `json:"id"`
Key string `json:"key"` Key string `json:"key"`
......
...@@ -80,11 +80,11 @@ func (p *Parser) initial() { ...@@ -80,11 +80,11 @@ func (p *Parser) initial() {
p.cmdRecordChan = make(chan [2]string, 1024) p.cmdRecordChan = make(chan [2]string, 1024)
} }
func (p *Parser) Parse() { func (p *Parser) ParseStream() {
defer func() { defer func() {
close(p.userOutputChan) close(p.userOutputChan)
close(p.srvOutputChan) close(p.srvOutputChan)
logger.Debug("Parser parse routine done") logger.Debug("Parser parse stream routine done")
}() }()
for { for {
select { select {
......
package proxy package proxy
import ( import (
"cocogo/pkg/srvconn"
"cocogo/pkg/utils"
"fmt" "fmt"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time" "time"
...@@ -14,6 +11,8 @@ import ( ...@@ -14,6 +11,8 @@ import (
"cocogo/pkg/logger" "cocogo/pkg/logger"
"cocogo/pkg/model" "cocogo/pkg/model"
"cocogo/pkg/service" "cocogo/pkg/service"
"cocogo/pkg/srvconn"
"cocogo/pkg/utils"
) )
type ProxyServer struct { type ProxyServer struct {
...@@ -23,11 +22,8 @@ type ProxyServer struct { ...@@ -23,11 +22,8 @@ type ProxyServer struct {
SystemUser *model.SystemUser SystemUser *model.SystemUser
} }
// getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置
func (p *ProxyServer) getSystemUserAuthOrManualSet() { func (p *ProxyServer) getSystemUserAuthOrManualSet() {
info := service.GetSystemUserAssetAuthInfo(p.SystemUser.Id, p.Asset.Id)
p.SystemUser.Password = info.Password
p.SystemUser.PrivateKey = info.PrivateKey
if p.SystemUser.LoginMode == model.LoginModeManual || if p.SystemUser.LoginMode == model.LoginModeManual ||
(p.SystemUser.Password == "" && p.SystemUser.PrivateKey == "") { (p.SystemUser.Password == "" && p.SystemUser.PrivateKey == "") {
term := utils.NewTerminal(p.UserConn, "password: ") term := utils.NewTerminal(p.UserConn, "password: ")
...@@ -35,11 +31,16 @@ func (p *ProxyServer) getSystemUserAuthOrManualSet() { ...@@ -35,11 +31,16 @@ func (p *ProxyServer) getSystemUserAuthOrManualSet() {
if err != nil { if err != nil {
logger.Errorf("Get password from user err %s", err.Error()) logger.Errorf("Get password from user err %s", err.Error())
} }
logger.Debug("Get password from user input: ", line)
p.SystemUser.Password = line p.SystemUser.Password = line
logger.Debug("Get password from user input: ", line)
} else {
info := service.GetSystemUserAssetAuthInfo(p.SystemUser.Id, p.Asset.Id)
p.SystemUser.Password = info.Password
p.SystemUser.PrivateKey = info.PrivateKey
} }
} }
// getSystemUserUsernameIfNeed 获取系统用户用户名,或手动设置
func (p *ProxyServer) getSystemUserUsernameIfNeed() { func (p *ProxyServer) getSystemUserUsernameIfNeed() {
if p.SystemUser.Username == "" { if p.SystemUser.Username == "" {
var username string var username string
...@@ -52,69 +53,76 @@ func (p *ProxyServer) getSystemUserUsernameIfNeed() { ...@@ -52,69 +53,76 @@ func (p *ProxyServer) getSystemUserUsernameIfNeed() {
} }
} }
p.SystemUser.Username = username p.SystemUser.Username = username
logger.Info("Get username from user input: ", username) logger.Debug("Get username from user input: ", username)
} }
} }
// checkProtocolMatch 检查协议是否匹配
func (p *ProxyServer) checkProtocolMatch() bool { func (p *ProxyServer) checkProtocolMatch() bool {
return p.SystemUser.Protocol == p.Asset.Protocol return p.SystemUser.Protocol == p.Asset.Protocol
} }
// checkProtocolIsGraph 检查协议是否是图形化的
func (p *ProxyServer) checkProtocolIsGraph() bool { func (p *ProxyServer) checkProtocolIsGraph() bool {
switch p.Asset.Protocol { switch p.Asset.Protocol {
case "ssh", "telnet": case "ssh", "telnet":
return true
default:
return false return false
default:
return true
} }
} }
// validatePermission 检查是否有权限连接
func (p *ProxyServer) validatePermission() bool { func (p *ProxyServer) validatePermission() bool {
return true return service.ValidateUserAssetPermission(
p.User.ID, p.Asset.Id, p.SystemUser.Id, "connect",
)
} }
func (p *ProxyServer) getSSHConn() (srvConn *srvconn.ServerSSHConnection, err error) { // getSSHConn 获取ssh连接
proxyConfig := &srvconn.SSHClientConfig{} func (p *ProxyServer) getSSHConn(fromCache ...bool) (srvConn *srvconn.ServerSSHConnection, err error) {
sshConfig := srvconn.SSHClientConfig{ pty := p.UserConn.Pty()
Host: p.Asset.Ip,
Port: strconv.Itoa(p.Asset.Port),
User: p.SystemUser.Username,
Password: p.SystemUser.Password,
PrivateKey: p.SystemUser.PrivateKey,
Overtime: config.GetConf().SSHTimeout,
Proxy: proxyConfig,
}
srvConn = &srvconn.ServerSSHConnection{ srvConn = &srvconn.ServerSSHConnection{
Name: p.Asset.Hostname, User: p.User,
Creator: p.User.Username, Asset: p.Asset,
SSHClientConfig: sshConfig, SystemUser: p.SystemUser,
Overtime: time.Duration(config.GetConf().SSHTimeout) * time.Second,
} }
pty := p.UserConn.Pty() if len(fromCache) > 0 && fromCache[0] {
err = srvConn.TryConnectFromCache(pty.Window.Height, pty.Window.Width, pty.Term)
} else {
err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term) err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term)
fmt.Println("Error: ", err) }
return return
} }
// getTelnetConn 获取telnet连接
func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection, err error) { func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection, err error) {
conf := config.GetConf() conf := config.GetConf()
cusString := conf.TelnetRegex cusString := conf.TelnetRegex
pattern, _ := regexp.Compile(cusString) pattern, _ := regexp.Compile(cusString)
srvConn = &srvconn.ServerTelnetConnection{ srvConn = &srvconn.ServerTelnetConnection{
Name: p.Asset.Hostname, User: p.User,
Creator: p.User.ID, Asset: p.Asset,
Host: p.Asset.Ip, SystemUser: p.SystemUser,
Port: strconv.Itoa(p.Asset.Port),
User: p.SystemUser.Username,
Password: p.SystemUser.Password,
CustomString: cusString, CustomString: cusString,
CustomSuccessPattern: pattern, CustomSuccessPattern: pattern,
Overtime: conf.SSHTimeout, Overtime: time.Duration(conf.SSHTimeout) * time.Second,
} }
err = srvConn.Connect(0, 0, "") err = srvConn.Connect(0, 0, "")
utils.IgnoreErrWriteString(p.UserConn, "\r\n") utils.IgnoreErrWriteString(p.UserConn, "\r\n")
return return
} }
// getServerConnFromCache 从cache中获取ssh server连接
func (p *ProxyServer) getServerConnFromCache() (srvConn srvconn.ServerConnection, err error) {
if p.SystemUser.Protocol == "ssh" {
srvConn, err = p.getSSHConn(true)
}
return
}
// getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) { func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
p.getSystemUserUsernameIfNeed() p.getSystemUserUsernameIfNeed()
p.getSystemUserAuthOrManualSet() p.getSystemUserAuthOrManualSet()
...@@ -123,19 +131,20 @@ func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err err ...@@ -123,19 +131,20 @@ func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err err
utils.IgnoreErrWriteString(p.UserConn, "\r\n") utils.IgnoreErrWriteString(p.UserConn, "\r\n")
close(done) close(done)
}() }()
go p.sendConnectingMsg(done, config.GetConf().SSHTimeout) go p.sendConnectingMsg(done, config.GetConf().SSHTimeout*time.Second)
if p.Asset.Protocol == "telnet" { if p.Asset.Protocol == "telnet" {
return p.getTelnetConn() return p.getTelnetConn()
} else { } else {
return p.getSSHConn() return p.getSSHConn(false)
} }
} }
func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delaySecond int) { // sendConnectingMsg 发送连接信息
func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delayDuration time.Duration) {
delay := 0.0 delay := 0.0
msg := fmt.Sprintf(i18n.T("Connecting to %s@%s %.1f"), p.SystemUser.Username, p.Asset.Ip, delay) msg := fmt.Sprintf(i18n.T("Connecting to %s@%s %.1f"), p.SystemUser.Username, p.Asset.Ip, delay)
utils.IgnoreErrWriteString(p.UserConn, msg) utils.IgnoreErrWriteString(p.UserConn, msg)
for int(delay) < delaySecond { for int(delay) < int(delayDuration/time.Second) {
select { select {
case <-done: case <-done:
return return
...@@ -149,6 +158,7 @@ func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delaySecond int) { ...@@ -149,6 +158,7 @@ func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delaySecond int) {
} }
} }
// preCheckRequisite 检查是否满足条件
func (p *ProxyServer) preCheckRequisite() (ok bool) { func (p *ProxyServer) preCheckRequisite() (ok bool) {
if !p.checkProtocolMatch() { if !p.checkProtocolMatch() {
msg := utils.WrapperWarn(i18n.T("System user <%s> and asset <%s> protocol are inconsistent.")) msg := utils.WrapperWarn(i18n.T("System user <%s> and asset <%s> protocol are inconsistent."))
...@@ -156,7 +166,7 @@ func (p *ProxyServer) preCheckRequisite() (ok bool) { ...@@ -156,7 +166,7 @@ func (p *ProxyServer) preCheckRequisite() (ok bool) {
utils.IgnoreErrWriteString(p.UserConn, msg) utils.IgnoreErrWriteString(p.UserConn, msg)
return return
} }
if !p.checkProtocolIsGraph() { if p.checkProtocolIsGraph() {
msg := i18n.T("Terminal only support protocol ssh/telnet, please use web terminal to access") msg := i18n.T("Terminal only support protocol ssh/telnet, please use web terminal to access")
msg = utils.WrapperWarn(msg) msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(p.UserConn, msg) utils.IgnoreErrWriteString(p.UserConn, msg)
...@@ -170,11 +180,17 @@ func (p *ProxyServer) preCheckRequisite() (ok bool) { ...@@ -170,11 +180,17 @@ func (p *ProxyServer) preCheckRequisite() (ok bool) {
return true return true
} }
// Proxy 代理
func (p *ProxyServer) Proxy() { func (p *ProxyServer) Proxy() {
if !p.preCheckRequisite() { if !p.preCheckRequisite() {
return return
} }
srvConn, err := p.getServerConn() // 先从cache中获取srv连接, 如果没有获得,则连接
srvConn, err := p.getServerConnFromCache()
if err != nil || srvConn == nil {
srvConn, err = p.getServerConn()
}
if err != nil { if err != nil {
msg := fmt.Sprintf("Connect asset %s error: %s\n\r", p.Asset.Hostname, err) msg := fmt.Sprintf("Connect asset %s error: %s\n\r", p.Asset.Hostname, err)
utils.IgnoreErrWriteString(p.UserConn, msg) utils.IgnoreErrWriteString(p.UserConn, msg)
...@@ -193,8 +209,11 @@ func (p *ProxyServer) Proxy() { ...@@ -193,8 +209,11 @@ func (p *ProxyServer) Proxy() {
sw.SetFilterRules(cmdRules) sw.SetFilterRules(cmdRules)
AddSession(sw) AddSession(sw)
_ = sw.Bridge(p.UserConn, srvConn) _ = sw.Bridge(p.UserConn, srvConn)
defer func() {
_ = srvConn.Close()
p.finishSession(sw) p.finishSession(sw)
RemoveSession(sw) RemoveSession(sw)
}()
} }
func (p *ProxyServer) createSession(s *SwitchSession) bool { func (p *ProxyServer) createSession(s *SwitchSession) bool {
...@@ -212,7 +231,7 @@ func (p *ProxyServer) finishSession(s *SwitchSession) { ...@@ -212,7 +231,7 @@ func (p *ProxyServer) finishSession(s *SwitchSession) {
data := s.MapData() data := s.MapData()
service.FinishSession(data) service.FinishSession(data)
service.FinishReply(s.Id) service.FinishReply(s.Id)
logger.Debugf("finish session: %s", s.Id) logger.Debugf("Finish session: %s", s.Id)
} }
func (p *ProxyServer) GetFilterRules() []model.SystemUserFilterRule { func (p *ProxyServer) GetFilterRules() []model.SystemUserFilterRule {
......
...@@ -31,7 +31,7 @@ type SwitchSession struct { ...@@ -31,7 +31,7 @@ type SwitchSession struct {
DateActive time.Time DateActive time.Time
finished bool finished bool
MaxIdleTime int MaxIdleTime time.Duration
cmdRecorder *CommandRecorder cmdRecorder *CommandRecorder
replayRecorder *ReplyRecorder replayRecorder *ReplyRecorder
...@@ -50,9 +50,7 @@ func (s *SwitchSession) Initial() { ...@@ -50,9 +50,7 @@ func (s *SwitchSession) Initial() {
s.MaxIdleTime = config.GetConf().MaxIdleTime s.MaxIdleTime = config.GetConf().MaxIdleTime
s.cmdRecorder = NewCommandRecorder(s.Id) s.cmdRecorder = NewCommandRecorder(s.Id)
s.replayRecorder = NewReplyRecord(s.Id) s.replayRecorder = NewReplyRecord(s.Id)
s.parser = newParser() s.parser = newParser()
s.ctx, s.cancel = context.WithCancel(context.Background()) s.ctx, s.cancel = context.WithCancel(context.Background())
} }
...@@ -75,6 +73,7 @@ func (s *SwitchSession) recordCommand() { ...@@ -75,6 +73,7 @@ func (s *SwitchSession) recordCommand() {
} }
} }
// generateCommandResult 生成命令结果
func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command { func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command {
var input string var input string
var output string var output string
...@@ -104,6 +103,7 @@ func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command ...@@ -104,6 +103,7 @@ func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command
} }
} }
// postBridge 桥接结束以后执行操作
func (s *SwitchSession) postBridge() { func (s *SwitchSession) postBridge() {
s.DateEnd = time.Now().UTC().Format("2006-01-02 15:04:05 +0000") s.DateEnd = time.Now().UTC().Format("2006-01-02 15:04:05 +0000")
s.finished = true s.finished = true
...@@ -114,26 +114,31 @@ func (s *SwitchSession) postBridge() { ...@@ -114,26 +114,31 @@ func (s *SwitchSession) postBridge() {
_ = s.srvTran.Close() _ = s.srvTran.Close()
} }
// SetFilterRules 设置命令过滤规则
func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) { func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) {
s.parser.SetCMDFilterRules(cmdRules) s.parser.SetCMDFilterRules(cmdRules)
} }
// Bridge 桥接两个链接
func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) { func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) {
winCh := userConn.WinCh() winCh := userConn.WinCh()
// 将ReadWriter转换为Channel读写
s.srvTran = NewDirectTransport(s.Id, srvConn) s.srvTran = NewDirectTransport(s.Id, srvConn)
s.userTran = NewDirectTransport(s.Id, userConn) s.userTran = NewDirectTransport(s.Id, userConn)
defer func() { defer func() {
logger.Info("session bridge done: ", s.Id) logger.Info("Session bridge done: ", s.Id)
s.postBridge()
}() }()
go s.parser.Parse() // 处理数据流
go s.parser.ParseStream()
// 记录命令
go s.recordCommand() go s.recordCommand()
defer s.postBridge()
for { for {
select { select {
// 检测是否超过最大空闲时间 // 检测是否超过最大空闲时间
case <-time.After(time.Duration(s.MaxIdleTime) * time.Minute): case <-time.After(s.MaxIdleTime * time.Minute):
msg := fmt.Sprintf(i18n.T("Connect idle more than %d minutes, disconnect"), s.MaxIdleTime) msg := fmt.Sprintf(i18n.T("Connect idle more than %d minutes, disconnect"), s.MaxIdleTime)
msg = utils.WrapperWarn(msg) msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(s.userTran, "\n\r"+msg) utils.IgnoreErrWriteString(s.userTran, "\n\r"+msg)
......
...@@ -16,15 +16,6 @@ func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.System ...@@ -16,15 +16,6 @@ func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.System
return return
} }
func GetSystemUserAuthInfo(systemUserID string) (info model.SystemUserAuthInfo) {
Url := fmt.Sprintf(SystemUserAuthInfoURL, systemUserID)
err := authClient.Get(Url, &info)
if err != nil {
logger.Error("Get system user auth info failed")
}
return
}
func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilterRule, err error) { func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilterRule, err error) {
/*[ /*[
{ {
...@@ -76,7 +67,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt ...@@ -76,7 +67,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
} }
func GetSystemUser(systemUserID string) (info model.SystemUser) { func GetSystemUser(systemUserID string) (info model.SystemUser) {
Url := fmt.Sprintf(SystemUser, systemUserID) Url := fmt.Sprintf(SystemUserDetailURL, systemUserID)
err := authClient.Get(Url, &info) err := authClient.Get(Url, &info)
if err != nil { if err != nil {
logger.Errorf("Get system user %s failed", systemUserID) logger.Errorf("Get system user %s failed", systemUserID)
...@@ -85,16 +76,25 @@ func GetSystemUser(systemUserID string) (info model.SystemUser) { ...@@ -85,16 +76,25 @@ func GetSystemUser(systemUserID string) (info model.SystemUser) {
} }
func GetAsset(assetID string) (asset model.Asset) { func GetAsset(assetID string) (asset model.Asset) {
Url := fmt.Sprintf(Asset, assetID) Url := fmt.Sprintf(AssetDetailURL, assetID)
err := authClient.Get(Url, &asset) err := authClient.Get(Url, &asset)
if err != nil { if err != nil {
logger.Errorf("Get Asset %s failed", assetID) logger.Errorf("Get Asset %s failed\n", assetID)
}
return
}
func GetDomainWithGateway(gID string) (domain model.Domain) {
url := fmt.Sprintf(DomainDetailURL, gID)
err := authClient.Get(url, &domain)
if err != nil {
logger.Errorf("Get domain %s failed", gID)
} }
return return
} }
func GetTokenAsset(token string) (tokenUser model.TokenUser) { func GetTokenAsset(token string) (tokenUser model.TokenUser) {
Url := fmt.Sprintf(TokenAsset, token) Url := fmt.Sprintf(TokenAssetUrl, token)
err := authClient.Get(Url, &tokenUser) err := authClient.Get(Url, &tokenUser)
if err != nil { if err != nil {
logger.Error("Get Token Asset info failed") logger.Error("Get Token Asset info failed")
......
...@@ -5,16 +5,14 @@ const ( ...@@ -5,16 +5,14 @@ const (
UserProfileURL = "/api/users/v1/profile/" // 获取当前用户的基本信息 UserProfileURL = "/api/users/v1/profile/" // 获取当前用户的基本信息
UserListUrl = "/api/users/v1/users/" // 用户列表地址 UserListUrl = "/api/users/v1/users/" // 用户列表地址
UserDetailURL = "/api/users/v1/users/%s/" // 获取用户信息 UserDetailURL = "/api/users/v1/users/%s/" // 获取用户信息
UserAuthOTPURL = "/api/users/v1/otp/auth/" // 验证OTP UserAuthOTPURL = "/api/authentication/v1/otp/auth/" // 验证OTP
AuthMFAURL = "/api/authentication/v1/otp/auth/" // MFA 验证用户信息
SystemUserAssetAuthURL = "/api/assets/v1/system-user/%s/asset/%s/auth-info/" // 该系统用户对某资产的授权 SystemUserAssetAuthURL = "/api/assets/v1/system-user/%s/asset/%s/auth-info/" // 该系统用户对某资产的授权
SystemUserAuthInfoURL = "/api/assets/v1/system-user/%s/auth-info/" // 该系统用户的授权
SystemUserCmdFilterRules = "/api/assets/v1/system-user/%s/cmd-filter-rules/" // 过滤规则url SystemUserCmdFilterRules = "/api/assets/v1/system-user/%s/cmd-filter-rules/" // 过滤规则url
SystemUser = "/api/assets/v1/system-user/%s" // 某个系统用户的信息 SystemUserDetailURL = "/api/assets/v1/system-user/%s/" // 某个系统用户的信息
Asset = "/api/assets/v1/assets/%s/" // 某一个资产信息 AssetDetailURL = "/api/assets/v1/assets/%s/" // 某一个资产信息
TokenAsset = "/api/users/v1/connection-token/?token=%s" // Token name DomainDetailURL = "/api/assets/v1/domain/%s/"
TokenAssetUrl = "/api/users/v1/connection-token/?token=%s" // Token name
TerminalRegisterURL = "/api/terminal/v2/terminal-registrations/" // 注册当前coco TerminalRegisterURL = "/api/terminal/v2/terminal-registrations/" // 注册当前coco
TerminalConfigURL = "/api/terminal/v1/terminal/config/" // 从jumpserver获取coco的配置 TerminalConfigURL = "/api/terminal/v1/terminal/config/" // 从jumpserver获取coco的配置
......
...@@ -24,10 +24,6 @@ func Authenticate(username, password, publicKey, remoteAddr, loginType string) ( ...@@ -24,10 +24,6 @@ func Authenticate(username, password, publicKey, remoteAddr, loginType string) (
"login_type": loginType, "login_type": loginType,
} }
err = client.Post(UserAuthURL, data, &resp) err = client.Post(UserAuthURL, data, &resp)
if err != nil {
logger.Error(err)
}
return return
} }
......
package srvconn
import (
"io"
"time"
)
type ServerConnection interface {
io.ReadWriteCloser
Timeout() time.Duration
Protocol() string
SetWinSize(w, h int) error
}
package srvconn package srvconn
import ( import (
"cocogo/pkg/service"
"errors"
"fmt" "fmt"
"net"
"strconv" "strconv"
"sync" "sync"
"time"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"cocogo/pkg/config"
"cocogo/pkg/logger" "cocogo/pkg/logger"
"cocogo/pkg/model" "cocogo/pkg/model"
) )
...@@ -18,21 +21,139 @@ var ( ...@@ -18,21 +21,139 @@ var (
clientLock = new(sync.RWMutex) clientLock = new(sync.RWMutex)
) )
func newClient(user *model.User, asset *model.Asset, type SSHClientConfig struct {
systemUser *model.SystemUser) (client *gossh.Client, err error) { Host string
cfg := SSHClientConfig{ Port string
User string
Password string
PrivateKey string
PrivateKeyPath string
Timeout time.Duration
Proxy []*SSHClientConfig
proxyConn gossh.Conn
}
func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
authMethods := make([]gossh.AuthMethod, 0)
if sc.Password != "" {
authMethods = append(authMethods, gossh.Password(sc.Password))
}
if sc.PrivateKeyPath != "" {
if pubkey, err := GetPubKeyFromFile(sc.PrivateKeyPath); err != nil {
err = fmt.Errorf("parse private key from file error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(pubkey))
}
}
if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(sc.PrivateKey)); err != nil {
err = fmt.Errorf("parse private key error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(signer))
}
}
config = &gossh.ClientConfig{
User: sc.User,
Auth: authMethods,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Timeout: sc.Timeout,
}
return config, nil
}
func (sc *SSHClientConfig) DialProxy() (client *gossh.Client, err error) {
for _, p := range sc.Proxy {
client, err = p.Dial()
if err == nil {
return
}
}
return
}
func (sc *SSHClientConfig) Dial() (client *gossh.Client, err error) {
cfg, err := sc.Config()
if err != nil {
return
}
if len(sc.Proxy) > 0 {
logger.Debugf("Dial host proxy first")
proxyClient, err := sc.DialProxy()
if err != nil {
err = errors.New("connect proxy host error 1: " + err.Error())
logger.Error("Connect proxy host error 1: ", err.Error())
return client, err
}
proxySock, err := proxyClient.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port))
if err != nil {
err = errors.New("connect proxy host error 2: " + err.Error())
logger.Error("Connect proxy host error 2: ", err.Error())
return client, err
}
proxyConn, chans, reqs, err := gossh.NewClientConn(proxySock, net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return client, err
}
sc.proxyConn = proxyConn
client = gossh.NewClient(proxyConn, chans, reqs)
} else {
logger.Debugf("Dial host %s:%s", sc.Host, sc.Port)
client, err = gossh.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return
}
}
return client, nil
}
func (sc *SSHClientConfig) String() string {
return fmt.Sprintf("%s@%s:%s", sc.User, sc.Host, sc.Port)
}
func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) {
proxyConfigs := make([]*SSHClientConfig, 0)
// 如果有网关则从网关中连接
if asset.Domain != "" {
domain := service.GetDomainWithGateway(asset.Domain)
if domain.ID != "" && len(domain.Gateways) > 1 {
for _, gateway := range domain.Gateways {
proxyConfigs = append(proxyConfigs, &SSHClientConfig{
Host: gateway.IP,
Port: strconv.Itoa(gateway.Port),
User: gateway.Username,
Password: gateway.Password,
PrivateKey: gateway.PrivateKey,
Timeout: timeout,
})
}
}
}
sshConfig := SSHClientConfig{
Host: asset.Ip, Host: asset.Ip,
Port: strconv.Itoa(asset.Port), Port: strconv.Itoa(asset.Port),
User: systemUser.Username, User: systemUser.Username,
Password: systemUser.Password, Password: systemUser.Password,
PrivateKey: systemUser.PrivateKey, PrivateKey: systemUser.PrivateKey,
Overtime: config.GetConf().SSHTimeout, Timeout: timeout,
Proxy: proxyConfigs,
}
sshConfig = SSHClientConfig{
Host: "127.0.0.1",
Port: "22",
User: "root",
Password: "redhat",
Proxy: []*SSHClientConfig{
{Host: "192.168.244.185", Port: "22", User: "root", Password: "redhat"},
},
} }
client, err = cfg.Dial() client, err = sshConfig.Dial()
return return
} }
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *gossh.Client, err error) { func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.Id, systemUser.Id) key := fmt.Sprintf("%s_%s_%s", user.ID, asset.Id, systemUser.Id)
clientLock.RLock() clientLock.RLock()
client, ok := sshClients[key] client, ok := sshClients[key]
...@@ -52,7 +173,7 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse ...@@ -52,7 +173,7 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse
return client, nil return client, nil
} }
client, err = newClient(user, asset, systemUser) client, err = newClient(asset, systemUser, timeout)
if err == nil { if err == nil {
clientLock.Lock() clientLock.Lock()
sshClients[key] = client sshClients[key] = client
...@@ -62,6 +183,18 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse ...@@ -62,6 +183,18 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse
return return
} }
func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *gossh.Client) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.Id, systemUser.Id)
clientLock.Lock()
defer clientLock.Unlock()
client, ok := sshClients[key]
if !ok {
return
}
clientsRefCounter[client]++
return
}
func RecycleClient(client *gossh.Client) { func RecycleClient(client *gossh.Client) {
clientLock.Lock() clientLock.Lock()
defer clientLock.Unlock() defer clientLock.Unlock()
......
...@@ -10,7 +10,9 @@ var testConnection = SSHClientConfig{ ...@@ -10,7 +10,9 @@ var testConnection = SSHClientConfig{
Port: "22", Port: "22",
User: "root", User: "root",
Password: "redhat", Password: "redhat",
Proxy: &SSHClientConfig{Host: "192.168.244.185", Port: "22", User: "root", Password: "redhat"}, Proxy: []*SSHClientConfig{
{Host: "192.168.244.185", Port: "22", User: "root", Password: "redhat"},
},
} }
func TestSSHConnection_Config(t *testing.T) { func TestSSHConnection_Config(t *testing.T) {
......
package srvconn package srvconn
import ( import (
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"io" "io"
"net"
"time" "time"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
)
type ServerConnection interface {
io.ReadWriteCloser
Timeout() time.Duration
Protocol() string
SetWinSize(w, h int) error
}
type SSHClientConfig struct {
Host string
Port string
User string
Password string
PrivateKey string
PrivateKeyPath string
Overtime int
Proxy *SSHClientConfig
proxyConn gossh.Conn
}
func (sc *SSHClientConfig) Timeout() time.Duration {
if sc.Overtime == 0 {
sc.Overtime = 30
}
return time.Duration(sc.Overtime) * time.Second
}
func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
authMethods := make([]gossh.AuthMethod, 0)
if sc.Password != "" {
authMethods = append(authMethods, gossh.Password(sc.Password))
}
if sc.PrivateKeyPath != "" {
if pubkey, err := GetPubKeyFromFile(sc.PrivateKeyPath); err != nil {
err = fmt.Errorf("parse private key from file error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(pubkey))
}
}
if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(sc.PrivateKey)); err != nil {
err = fmt.Errorf("parse private key error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(signer))
}
}
config = &gossh.ClientConfig{
User: sc.User,
Auth: authMethods,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Timeout: sc.Timeout(),
}
return config, nil
}
func (sc *SSHClientConfig) Dial() (client *gossh.Client, err error) {
cfg, err := sc.Config()
if err != nil {
return
}
if sc.Proxy != nil && sc.Proxy.Host != "" {
proxyClient, err := sc.Proxy.Dial()
if err != nil {
err = errors.New("connect proxy Host error1: " + err.Error())
return client, err
}
proxySock, err := proxyClient.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port))
if err != nil {
err = errors.New("connect proxy Host error2: " + err.Error())
return client, err
}
proxyConn, chans, reqs, err := gossh.NewClientConn(proxySock, net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return client, err
}
sc.proxyConn = proxyConn
client = gossh.NewClient(proxyConn, chans, reqs)
} else {
client, err = gossh.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return
}
}
return client, nil
}
func (sc *SSHClientConfig) String() string { "cocogo/pkg/model"
return fmt.Sprintf("%s@%s:%s", sc.User, sc.Host, sc.Port) )
}
type ServerSSHConnection struct { type ServerSSHConnection struct {
SSHClientConfig User *model.User
Name string Asset *model.Asset
Creator string SystemUser *model.SystemUser
Overtime time.Duration
client *gossh.Client client *gossh.Client
session *gossh.Session session *gossh.Session
...@@ -113,16 +22,13 @@ type ServerSSHConnection struct { ...@@ -113,16 +22,13 @@ type ServerSSHConnection struct {
stdout io.Reader stdout io.Reader
closed bool closed bool
refCount int refCount int
connected bool
} }
func (sc *ServerSSHConnection) Protocol() string { func (sc *ServerSSHConnection) Protocol() string {
return "ssh" return "ssh"
} }
func (sc *ServerSSHConnection) String() string {
return fmt.Sprintf("%s@%s:%s", sc.User, sc.Host, sc.Port)
}
func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) { func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
sess, err := sc.client.NewSession() sess, err := sc.client.NewSession()
if err != nil { if err != nil {
...@@ -151,14 +57,29 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) { ...@@ -151,14 +57,29 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
} }
func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) { func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) {
sc.client, err = sc.Dial() sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout())
if err != nil {
return
}
err = sc.invokeShell(h, w, term)
if err != nil { if err != nil {
return return
} }
sc.connected = true
return nil
}
func (sc *ServerSSHConnection) TryConnectFromCache(h, w int, term string) (err error) {
sc.client = GetClientFromCache(sc.User, sc.Asset, sc.SystemUser)
if sc.client == nil {
return errors.New("No client in cache")
}
err = sc.invokeShell(h, w, term) err = sc.invokeShell(h, w, term)
if err != nil { if err != nil {
return return
} }
sc.connected = true
return nil return nil
} }
...@@ -174,15 +95,21 @@ func (sc *ServerSSHConnection) Write(p []byte) (n int, err error) { ...@@ -174,15 +95,21 @@ func (sc *ServerSSHConnection) Write(p []byte) (n int, err error) {
return sc.stdin.Write(p) return sc.stdin.Write(p)
} }
func (sc *ServerSSHConnection) Timeout() time.Duration {
if sc.Overtime == 0 {
sc.Overtime = 30 * time.Second
}
return sc.Overtime
}
func (sc *ServerSSHConnection) Close() (err error) { func (sc *ServerSSHConnection) Close() (err error) {
if sc.closed { RecycleClient(sc.client)
if sc.closed || !sc.connected {
return return
} }
err = sc.session.Close() err = sc.session.Close()
err = sc.client.Close() if err != nil {
if sc.proxyConn != nil { return
err = sc.proxyConn.Close()
} }
sc.closed = true
return return
} }
...@@ -2,12 +2,13 @@ package srvconn ...@@ -2,12 +2,13 @@ package srvconn
import ( import (
"bytes" "bytes"
"cocogo/pkg/model"
"errors"
"net" "net"
"regexp" "regexp"
"strconv"
"time" "time"
"github.com/pkg/errors"
"cocogo/pkg/logger" "cocogo/pkg/logger"
) )
...@@ -45,13 +46,10 @@ const ( ...@@ -45,13 +46,10 @@ const (
) )
type ServerTelnetConnection struct { type ServerTelnetConnection struct {
Name string User *model.User
Creator string Asset *model.Asset
Host string SystemUser *model.SystemUser
Port string Overtime time.Duration
User string
Password string
Overtime int
CustomString string CustomString string
CustomSuccessPattern *regexp.Regexp CustomSuccessPattern *regexp.Regexp
...@@ -61,9 +59,9 @@ type ServerTelnetConnection struct { ...@@ -61,9 +59,9 @@ type ServerTelnetConnection struct {
func (tc *ServerTelnetConnection) Timeout() time.Duration { func (tc *ServerTelnetConnection) Timeout() time.Duration {
if tc.Overtime == 0 { if tc.Overtime == 0 {
tc.Overtime = 30 tc.Overtime = 30 * time.Second
} }
return time.Duration(tc.Overtime) * time.Second return tc.Overtime
} }
func (tc *ServerTelnetConnection) Protocol() string { func (tc *ServerTelnetConnection) Protocol() string {
...@@ -120,12 +118,12 @@ func (tc *ServerTelnetConnection) login(data []byte) AuthStatus { ...@@ -120,12 +118,12 @@ func (tc *ServerTelnetConnection) login(data []byte) AuthStatus {
if incorrectPattern.Match(data) { if incorrectPattern.Match(data) {
return AuthFailed return AuthFailed
} else if usernamePattern.Match(data) { } else if usernamePattern.Match(data) {
_, _ = tc.conn.Write([]byte(tc.User + "\r\n")) _, _ = tc.conn.Write([]byte(tc.SystemUser.Username + "\r\n"))
logger.Debug("usernamePattern ", tc.User) logger.Debug("usernamePattern ", tc.User)
return AuthPartial return AuthPartial
} else if passwordPattern.Match(data) { } else if passwordPattern.Match(data) {
_, _ = tc.conn.Write([]byte(tc.Password + "\r\n")) _, _ = tc.conn.Write([]byte(tc.SystemUser.Password + "\r\n"))
logger.Debug("passwordPattern ", tc.Password) logger.Debug("passwordPattern ", tc.SystemUser.Password)
return AuthPartial return AuthPartial
} else if successPattern.Match(data) { } else if successPattern.Match(data) {
return AuthSuccess return AuthSuccess
...@@ -139,7 +137,9 @@ func (tc *ServerTelnetConnection) login(data []byte) AuthStatus { ...@@ -139,7 +137,9 @@ func (tc *ServerTelnetConnection) login(data []byte) AuthStatus {
} }
func (tc *ServerTelnetConnection) Connect(h, w int, term string) (err error) { func (tc *ServerTelnetConnection) Connect(h, w int, term string) (err error) {
conn, err := net.DialTimeout("tcp", net.JoinHostPort(tc.Host, tc.Port), tc.Timeout()) var ip = tc.Asset.Ip
var port = strconv.Itoa(tc.Asset.Port)
conn, err := net.DialTimeout("tcp", net.JoinHostPort(ip, port), tc.Timeout())
if err != nil { if err != nil {
return return
} }
......
package sshd package sshd
import ( import (
"strconv" "net"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
...@@ -22,9 +22,10 @@ func StartServer() { ...@@ -22,9 +22,10 @@ func StartServer() {
logger.Fatal("Load host key error: ", err) logger.Fatal("Load host key error: ", err)
} }
logger.Infof("Start ssh server at %s:%d", conf.BindHost, conf.SSHPort) addr := net.JoinHostPort(conf.BindHost, conf.SSHPort)
logger.Infof("Start ssh server at %s", addr)
sshServer = &ssh.Server{ sshServer = &ssh.Server{
Addr: conf.BindHost + ":" + strconv.Itoa(conf.SSHPort), Addr: addr,
KeyboardInteractiveHandler: auth.CheckMFA, KeyboardInteractiveHandler: auth.CheckMFA,
PasswordHandler: auth.CheckUserPassword, PasswordHandler: auth.CheckUserPassword,
PublicKeyHandler: auth.CheckUserPublicKey, PublicKeyHandler: auth.CheckUserPublicKey,
...@@ -33,7 +34,7 @@ func StartServer() { ...@@ -33,7 +34,7 @@ func StartServer() {
Handler: handler.SessionHandler, Handler: handler.SessionHandler,
SubsystemHandlers: map[string]ssh.SubsystemHandler{}, SubsystemHandlers: map[string]ssh.SubsystemHandler{},
} }
// Set Auth Handler // Set sftp handler
sshServer.SetSubsystemHandler("sftp", handler.SftpHandler) sshServer.SetSubsystemHandler("sftp", handler.SftpHandler)
logger.Fatal(sshServer.ListenAndServe()) logger.Fatal(sshServer.ListenAndServe())
} }
...@@ -41,8 +42,7 @@ func StartServer() { ...@@ -41,8 +42,7 @@ func StartServer() {
func StopServer() { func StopServer() {
err := sshServer.Close() err := sshServer.Close()
if err != nil { if err != nil {
logger.Debugf("ssh server close failed: %s", err.Error()) logger.Errorf("SSH server close failed: %s", err.Error())
} }
logger.Debug("Close ssh Server") logger.Debug("Close ssh server")
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment