Unverified Commit 0be79d58 authored by Eric_Lee's avatar Eric_Lee Committed by GitHub

Bugfix (#116)


* [Bugfix] fix reuse ssh connet bug
parent c66c2487
......@@ -21,6 +21,8 @@ type ProxyServer struct {
User *model.User
Asset *model.Asset
SystemUser *model.SystemUser
cacheSSHClient *srvconn.SSHClient
}
// getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置
......@@ -105,6 +107,7 @@ func (p *ProxyServer) getSSHConn() (srvConn *srvconn.ServerSSHConnection, err er
ReuseConnection: conf.ReuseConnection,
CloseOnce: new(sync.Once),
}
srvConn.SetSSHClient(p.cacheSSHClient)
err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term)
return
}
......@@ -129,12 +132,19 @@ func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection,
// getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
done := make(chan struct{})
defer func() {
utils.IgnoreErrWriteString(p.UserConn, "\r\n")
close(done)
}()
go p.sendConnectingMsg(done, config.GetConf().SSHTimeout*time.Second)
if p.cacheSSHClient == nil {
done := make(chan struct{})
defer func() {
utils.IgnoreErrWriteString(p.UserConn, "\r\n")
close(done)
}()
go p.sendConnectingMsg(done, config.GetConf().SSHTimeout*time.Second)
} else {
utils.IgnoreErrWriteString(p.UserConn, utils.WrapperString("You reuse SSH Conn from cache.\r\n", utils.Green))
logger.Infof("Request %s: Reuse connection for SSH. SSH client %p current ref: %d", p.UserConn.ID(),
p.cacheSSHClient, p.cacheSSHClient.RefCount())
}
if p.SystemUser.Protocol == "telnet" {
return p.getTelnetConn()
} else {
......@@ -193,6 +203,17 @@ func (p *ProxyServer) checkRequiredSystemUserInfo() error {
logger.Errorf("Get asset %s systemuser username err: %s", p.Asset.Hostname, err)
return err
}
if config.GetConf().ReuseConnection {
key := srvconn.MakeReuseSSHClientKey(p.User, p.Asset, p.SystemUser)
cacheSSHClient, ok := srvconn.GetClientFromCache(key)
if ok {
p.cacheSSHClient = cacheSSHClient
logger.Infof("Reuse connection for SFTP: %s->%s@%s. SSH client %p current ref: %d",
p.User.Username, p.SystemUser.Username, p.Asset.IP, cacheSSHClient, cacheSSHClient.RefCount())
return nil
}
}
if err := p.getSystemUserAuthOrManualSet(); err != nil {
logger.Errorf("Get asset %s systemuser password/PrivateKey err: %s", p.Asset.Hostname, err)
return err
......
......@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
......@@ -42,7 +43,7 @@ type SSHClient struct {
closed chan struct{}
}
func (s *SSHClient) refCount() int {
func (s *SSHClient) RefCount() int {
if s.isClosed() {
return 0
}
......@@ -273,34 +274,33 @@ func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Du
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration,
useCache bool) (client *SSHClient, err error) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
switch {
case useCache:
client = getClientFromCache(key)
if client != nil {
if systemUser.Username == "" {
systemUser.Username = client.username
}
logger.Infof("Reuse connection: %s->%s@%s. SSH client %p current ref: %d",
user.Username, client.username, asset.IP, client, client.refCount())
return client, nil
}
}
client, err = newClient(asset, systemUser, timeout)
if err == nil && useCache {
key := MakeReuseSSHClientKey(user, asset, systemUser)
setClientCache(key, client)
}
return
}
func getClientFromCache(key string) (client *SSHClient) {
func searchSSHClientFromCache(prefixKey string) (client *SSHClient, ok bool) {
clientLock.Lock()
defer clientLock.Unlock()
client, ok := sshClients[key]
if !ok {
return nil
for key, cacheClient := range sshClients {
if strings.HasPrefix(key, prefixKey) {
cacheClient.increaseRef()
return cacheClient, true
}
}
return
}
func GetClientFromCache(key string) (client *SSHClient, ok bool) {
clientLock.Lock()
defer clientLock.Unlock()
client, ok = sshClients[key]
if ok {
client.increaseRef()
}
client.increaseRef()
return
}
......@@ -317,7 +317,7 @@ func RecycleClient(client *SSHClient) {
return
}
client.decreaseRef()
if client.refCount() == 0 {
if client.RefCount() == 0 {
clientLock.Lock()
delete(sshClients, client.key)
clientLock.Unlock()
......@@ -328,6 +328,10 @@ func RecycleClient(client *SSHClient) {
logger.Infof("Success to close SSH client %p", client)
}
} else {
logger.Debugf("SSH client %p ref -1. current ref: %d", client, client.refCount())
logger.Debugf("SSH client %p ref -1. current ref: %d", client, client.RefCount())
}
}
func MakeReuseSSHClientKey(user *model.User, asset *model.Asset, systemUser *model.SystemUser) string {
return fmt.Sprintf("%s_%s_%s_%s", user.ID, asset.ID, systemUser.ID, systemUser.Username)
}
......@@ -589,10 +589,42 @@ func (u *UserSftp) SendFTPLog(dataChan <-chan *model.FTPLog) {
}
func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (conn *SftpConn, err error) {
sshClient, err := NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection)
if err != nil {
return
var (
sshClient *SSHClient
ok bool
)
if u.ReuseConnection {
key := MakeReuseSSHClientKey(u.User, asset, sysUser)
switch sysUser.Username {
case "":
sshClient, ok = searchSSHClientFromCache(key)
if ok {
sysUser.Username = sshClient.username
}
default:
sshClient, ok = GetClientFromCache(key)
}
if !ok {
sshClient, err = NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection)
if err != nil {
logger.Errorf("Get new SSH client err: %s", err)
return
}
} else {
logger.Infof("Reuse connection for SFTP: %s->%s@%s. SSH client %p current ref: %d",
u.User.Username, sshClient.username, asset.IP, sshClient, sshClient.RefCount())
}
} else {
sshClient, err = NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection)
if err != nil {
logger.Errorf("Get new SSH client err: %s", err)
return
}
}
sftpClient, err := sftp.NewClient(sshClient.client)
if err != nil {
logger.Errorf("SSH client %p start sftp client session err %s", sshClient, err)
......
......@@ -25,6 +25,12 @@ type ServerSSHConnection struct {
stdout io.Reader
}
func (sc *ServerSSHConnection) SetSSHClient(client *SSHClient) {
if client != nil {
sc.client = client
}
}
func (sc *ServerSSHConnection) Protocol() string {
return "ssh"
}
......@@ -57,10 +63,12 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
}
func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) {
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection)
if err != nil {
logger.Errorf("New SSH client err: %s", err)
return
if sc.client == nil {
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection)
if err != nil {
logger.Errorf("New SSH client err: %s", err)
return
}
}
err = sc.invokeShell(h, w, term)
if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment