Unverified Commit 0be79d58 authored by Eric_Lee's avatar Eric_Lee Committed by GitHub

Bugfix (#116)


* [Bugfix] fix reuse ssh connet bug
parent c66c2487
...@@ -21,6 +21,8 @@ type ProxyServer struct { ...@@ -21,6 +21,8 @@ type ProxyServer struct {
User *model.User User *model.User
Asset *model.Asset Asset *model.Asset
SystemUser *model.SystemUser SystemUser *model.SystemUser
cacheSSHClient *srvconn.SSHClient
} }
// getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置 // getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置
...@@ -105,6 +107,7 @@ func (p *ProxyServer) getSSHConn() (srvConn *srvconn.ServerSSHConnection, err er ...@@ -105,6 +107,7 @@ func (p *ProxyServer) getSSHConn() (srvConn *srvconn.ServerSSHConnection, err er
ReuseConnection: conf.ReuseConnection, ReuseConnection: conf.ReuseConnection,
CloseOnce: new(sync.Once), CloseOnce: new(sync.Once),
} }
srvConn.SetSSHClient(p.cacheSSHClient)
err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term) err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term)
return return
} }
...@@ -129,12 +132,19 @@ func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection, ...@@ -129,12 +132,19 @@ func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection,
// getServerConn 获取获取server连接 // getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) { func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
if p.cacheSSHClient == nil {
done := make(chan struct{}) done := make(chan struct{})
defer func() { defer func() {
utils.IgnoreErrWriteString(p.UserConn, "\r\n") utils.IgnoreErrWriteString(p.UserConn, "\r\n")
close(done) close(done)
}() }()
go p.sendConnectingMsg(done, config.GetConf().SSHTimeout*time.Second) go p.sendConnectingMsg(done, config.GetConf().SSHTimeout*time.Second)
} else {
utils.IgnoreErrWriteString(p.UserConn, utils.WrapperString("You reuse SSH Conn from cache.\r\n", utils.Green))
logger.Infof("Request %s: Reuse connection for SSH. SSH client %p current ref: %d", p.UserConn.ID(),
p.cacheSSHClient, p.cacheSSHClient.RefCount())
}
if p.SystemUser.Protocol == "telnet" { if p.SystemUser.Protocol == "telnet" {
return p.getTelnetConn() return p.getTelnetConn()
} else { } else {
...@@ -193,6 +203,17 @@ func (p *ProxyServer) checkRequiredSystemUserInfo() error { ...@@ -193,6 +203,17 @@ func (p *ProxyServer) checkRequiredSystemUserInfo() error {
logger.Errorf("Get asset %s systemuser username err: %s", p.Asset.Hostname, err) logger.Errorf("Get asset %s systemuser username err: %s", p.Asset.Hostname, err)
return err return err
} }
if config.GetConf().ReuseConnection {
key := srvconn.MakeReuseSSHClientKey(p.User, p.Asset, p.SystemUser)
cacheSSHClient, ok := srvconn.GetClientFromCache(key)
if ok {
p.cacheSSHClient = cacheSSHClient
logger.Infof("Reuse connection for SFTP: %s->%s@%s. SSH client %p current ref: %d",
p.User.Username, p.SystemUser.Username, p.Asset.IP, cacheSSHClient, cacheSSHClient.RefCount())
return nil
}
}
if err := p.getSystemUserAuthOrManualSet(); err != nil { if err := p.getSystemUserAuthOrManualSet(); err != nil {
logger.Errorf("Get asset %s systemuser password/PrivateKey err: %s", p.Asset.Hostname, err) logger.Errorf("Get asset %s systemuser password/PrivateKey err: %s", p.Asset.Hostname, err)
return err return err
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
...@@ -42,7 +43,7 @@ type SSHClient struct { ...@@ -42,7 +43,7 @@ type SSHClient struct {
closed chan struct{} closed chan struct{}
} }
func (s *SSHClient) refCount() int { func (s *SSHClient) RefCount() int {
if s.isClosed() { if s.isClosed() {
return 0 return 0
} }
...@@ -273,34 +274,33 @@ func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Du ...@@ -273,34 +274,33 @@ func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Du
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration, func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration,
useCache bool) (client *SSHClient, err error) { useCache bool) (client *SSHClient, err error) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
switch {
case useCache:
client = getClientFromCache(key)
if client != nil {
if systemUser.Username == "" {
systemUser.Username = client.username
}
logger.Infof("Reuse connection: %s->%s@%s. SSH client %p current ref: %d",
user.Username, client.username, asset.IP, client, client.refCount())
return client, nil
}
}
client, err = newClient(asset, systemUser, timeout) client, err = newClient(asset, systemUser, timeout)
if err == nil && useCache { if err == nil && useCache {
key := MakeReuseSSHClientKey(user, asset, systemUser)
setClientCache(key, client) setClientCache(key, client)
} }
return return
} }
func getClientFromCache(key string) (client *SSHClient) { func searchSSHClientFromCache(prefixKey string) (client *SSHClient, ok bool) {
clientLock.Lock() clientLock.Lock()
defer clientLock.Unlock() defer clientLock.Unlock()
client, ok := sshClients[key] for key, cacheClient := range sshClients {
if !ok { if strings.HasPrefix(key, prefixKey) {
return nil cacheClient.increaseRef()
return cacheClient, true
} }
}
return
}
func GetClientFromCache(key string) (client *SSHClient, ok bool) {
clientLock.Lock()
defer clientLock.Unlock()
client, ok = sshClients[key]
if ok {
client.increaseRef() client.increaseRef()
}
return return
} }
...@@ -317,7 +317,7 @@ func RecycleClient(client *SSHClient) { ...@@ -317,7 +317,7 @@ func RecycleClient(client *SSHClient) {
return return
} }
client.decreaseRef() client.decreaseRef()
if client.refCount() == 0 { if client.RefCount() == 0 {
clientLock.Lock() clientLock.Lock()
delete(sshClients, client.key) delete(sshClients, client.key)
clientLock.Unlock() clientLock.Unlock()
...@@ -328,6 +328,10 @@ func RecycleClient(client *SSHClient) { ...@@ -328,6 +328,10 @@ func RecycleClient(client *SSHClient) {
logger.Infof("Success to close SSH client %p", client) logger.Infof("Success to close SSH client %p", client)
} }
} else { } else {
logger.Debugf("SSH client %p ref -1. current ref: %d", client, client.refCount()) logger.Debugf("SSH client %p ref -1. current ref: %d", client, client.RefCount())
} }
} }
func MakeReuseSSHClientKey(user *model.User, asset *model.Asset, systemUser *model.SystemUser) string {
return fmt.Sprintf("%s_%s_%s_%s", user.ID, asset.ID, systemUser.ID, systemUser.Username)
}
...@@ -589,10 +589,42 @@ func (u *UserSftp) SendFTPLog(dataChan <-chan *model.FTPLog) { ...@@ -589,10 +589,42 @@ func (u *UserSftp) SendFTPLog(dataChan <-chan *model.FTPLog) {
} }
func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (conn *SftpConn, err error) { func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (conn *SftpConn, err error) {
sshClient, err := NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection) var (
sshClient *SSHClient
ok bool
)
if u.ReuseConnection {
key := MakeReuseSSHClientKey(u.User, asset, sysUser)
switch sysUser.Username {
case "":
sshClient, ok = searchSSHClientFromCache(key)
if ok {
sysUser.Username = sshClient.username
}
default:
sshClient, ok = GetClientFromCache(key)
}
if !ok {
sshClient, err = NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection)
if err != nil {
logger.Errorf("Get new SSH client err: %s", err)
return
}
} else {
logger.Infof("Reuse connection for SFTP: %s->%s@%s. SSH client %p current ref: %d",
u.User.Username, sshClient.username, asset.IP, sshClient, sshClient.RefCount())
}
} else {
sshClient, err = NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection)
if err != nil { if err != nil {
logger.Errorf("Get new SSH client err: %s", err)
return return
} }
}
sftpClient, err := sftp.NewClient(sshClient.client) sftpClient, err := sftp.NewClient(sshClient.client)
if err != nil { if err != nil {
logger.Errorf("SSH client %p start sftp client session err %s", sshClient, err) logger.Errorf("SSH client %p start sftp client session err %s", sshClient, err)
......
...@@ -25,6 +25,12 @@ type ServerSSHConnection struct { ...@@ -25,6 +25,12 @@ type ServerSSHConnection struct {
stdout io.Reader stdout io.Reader
} }
func (sc *ServerSSHConnection) SetSSHClient(client *SSHClient) {
if client != nil {
sc.client = client
}
}
func (sc *ServerSSHConnection) Protocol() string { func (sc *ServerSSHConnection) Protocol() string {
return "ssh" return "ssh"
} }
...@@ -57,11 +63,13 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) { ...@@ -57,11 +63,13 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
} }
func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) { func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) {
if sc.client == nil {
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection) sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection)
if err != nil { if err != nil {
logger.Errorf("New SSH client err: %s", err) logger.Errorf("New SSH client err: %s", err)
return return
} }
}
err = sc.invokeShell(h, w, term) err = sc.invokeShell(h, w, term)
if err != nil { if err != nil {
logger.Errorf("SSH client %p start ssh shell session err %s", sc.client, err) logger.Errorf("SSH client %p start ssh shell session err %s", sc.client, err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment