Commit 63743fc2 authored by Eric's avatar Eric

make ssh reuse connection configurable

parent c28bcc29
......@@ -26,6 +26,7 @@ type Config struct {
MaxIdleTime time.Duration `json:"SECURITY_MAX_IDLE_TIME"`
SftpRoot string `json:"TERMINAL_SFTP_ROOT" yaml:"SFTP_ROOT"`
ShowHiddenFile bool `yaml:"SFTP_SHOW_HIDDEN_FILE"`
ReuseConnection bool `yaml:"REUSE_CONNECTION"`
Name string `yaml:"NAME"`
SecretKey string `yaml:"SECRET_KEY"`
HostKeyFile string `yaml:"HOST_KEY_FILE"`
......@@ -131,6 +132,7 @@ var Conf = &Config{
UploadFailedReplay: true,
SftpRoot: "/tmp",
ShowHiddenFile: false,
ReuseConnection: true,
}
func SetConf(conf *Config) {
......
......@@ -4,6 +4,7 @@ import (
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/jumpserver/koko/pkg/config"
......@@ -93,19 +94,18 @@ func (p *ProxyServer) validatePermission() bool {
}
// getSSHConn 获取ssh连接
func (p *ProxyServer) getSSHConn(fromCache ...bool) (srvConn *srvconn.ServerSSHConnection, err error) {
func (p *ProxyServer) getSSHConn() (srvConn *srvconn.ServerSSHConnection, err error) {
pty := p.UserConn.Pty()
conf := config.GetConf()
srvConn = &srvconn.ServerSSHConnection{
User: p.User,
Asset: p.Asset,
SystemUser: p.SystemUser,
Overtime: time.Duration(config.GetConf().SSHTimeout) * time.Second,
Overtime: conf.SSHTimeout * time.Second,
ReuseConnection: conf.ReuseConnection,
CloseOnce: new(sync.Once),
}
if len(fromCache) > 0 && fromCache[0] {
err = srvConn.TryConnectFromCache(pty.Window.Height, pty.Window.Width, pty.Term)
} else {
err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term)
}
return
}
......@@ -127,14 +127,6 @@ func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection,
return
}
// getServerConnFromCache 从cache中获取ssh server连接
func (p *ProxyServer) getServerConnFromCache() (srvConn srvconn.ServerConnection, err error) {
if p.SystemUser.Protocol == "ssh" {
srvConn, err = p.getSSHConn(true)
}
return
}
// getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
err = p.getSystemUserUsernameIfNeed()
......@@ -154,7 +146,7 @@ func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err err
if p.SystemUser.Protocol == "telnet" {
return p.getTelnetConn()
} else {
return p.getSSHConn(false)
return p.getSSHConn()
}
}
......@@ -220,11 +212,7 @@ func (p *ProxyServer) Proxy() {
if !p.preCheckRequisite() {
return
}
// 先从cache中获取srv连接, 如果没有获得,则连接
srvConn, err := p.getServerConnFromCache()
if err != nil || srvConn == nil {
srvConn, err = p.getServerConn()
}
srvConn, err := p.getServerConn()
// 连接后端服务器失败
if err != nil {
p.sendConnectErrorMsg(err)
......
......@@ -17,7 +17,6 @@ import (
var (
sshClients = make(map[string]*SSHClient)
clientsRefCounter = make(map[*SSHClient]int)
clientLock = new(sync.RWMutex)
)
......@@ -32,8 +31,43 @@ var (
)
type SSHClient struct {
Client *gossh.Client
Username string
client *gossh.Client
username string
ref int
key string
mu *sync.RWMutex
}
func (s *SSHClient) refCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ref
}
func (s *SSHClient) increaseRef() {
s.mu.Lock()
defer s.mu.Unlock()
s.ref++
}
func (s *SSHClient) decreaseRef() {
s.mu.Lock()
defer s.mu.Unlock()
s.ref--
}
func (s *SSHClient) NewSession() (*gossh.Session, error) {
return s.client.NewSession()
}
func (s *SSHClient) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ref > 1 {
return nil
}
return s.client.Close()
}
type SSHClientConfig struct {
......@@ -150,7 +184,7 @@ func MakeConfig(asset *model.Asset, systemUser *model.SystemUser, timeout time.D
}
}
}
if systemUser.Password == "" && systemUser.PrivateKey == "" && systemUser.LoginMode != model.LoginModeManual{
if systemUser.Password == "" && systemUser.PrivateKey == "" && systemUser.LoginMode != model.LoginModeManual {
info := service.GetSystemUserAssetAuthInfo(systemUser.ID, asset.ID)
systemUser.Password = info.Password
systemUser.PrivateKey = info.PrivateKey
......@@ -173,78 +207,67 @@ func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Du
if err != nil {
return nil, err
}
return &SSHClient{Client: conn, Username: systemUser.Username}, err
return &SSHClient{client: conn, username: systemUser.Username, mu: new(sync.RWMutex)}, err
}
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *SSHClient, err error) {
client = GetClientFromCache(user, asset, systemUser)
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration,
useCache bool) (client *SSHClient, err error) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
switch {
case useCache:
client = getClientFromCache(key)
if client != nil {
if systemUser.Username == "" {
systemUser.Username = client.username
}
logger.Infof("Reuse connection: %s->%s@%s ref: %d",
user.Username, client.username, asset.IP, client.refCount())
return client, nil
}
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
}
client, err = newClient(asset, systemUser, timeout)
if err == nil {
clientLock.Lock()
sshClients[key] = client
clientsRefCounter[client] = 1
clientLock.Unlock()
if err == nil && useCache {
setClientCache(key, client)
}
return
}
func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *SSHClient) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
func getClientFromCache(key string) (client *SSHClient) {
clientLock.Lock()
defer clientLock.Unlock()
client, ok := sshClients[key]
if !ok {
return
}
if systemUser.Username == "" {
systemUser.Username = client.Username
return nil
}
var u = user.Username
var ip = asset.IP
clientsRefCounter[client]++
var counter = clientsRefCounter[client]
logger.Infof("Reuse connection: %s->%s@%s ref: %d", u, client.Username, ip, counter)
client.increaseRef()
return
}
func RecycleClient(client *SSHClient) {
clientLock.RLock()
counter, ok := clientsRefCounter[client]
clientLock.RUnlock()
if ok {
if counter == 1 {
logger.Debug("Recycle client: close it")
CloseClient(client)
} else {
func setClientCache(key string, client *SSHClient) {
clientLock.Lock()
clientsRefCounter[client]--
sshClients[key] = client
client.increaseRef()
client.key = key
clientLock.Unlock()
logger.Debugf("Recycle client: ref -1: %d", clientsRefCounter[client])
}
}
}
func CloseClient(client *SSHClient) {
clientLock.Lock()
defer clientLock.Unlock()
delete(clientsRefCounter, client)
var key string
for k, v := range sshClients {
if v == client {
key = k
break
func RecycleClient(client *SSHClient) {
// 0, 1: delete Cache, close client.
// default: client ref decrease.
if client == nil {
return
}
switch client.refCount() {
case 0, 1:
clientLock.Lock()
delete(sshClients, client.key)
clientLock.Unlock()
err := client.Close()
if err != nil {
logger.Info("Failed to close client err: ", err.Error())
}
if key != "" {
delete(sshClients, key)
default:
client.decreaseRef()
}
_ = client.Client.Close()
}
......@@ -32,6 +32,8 @@ type UserSftp struct {
RootPath string
ShowHidden bool
ReuseConnection bool
Overtime time.Duration
hosts map[string]*HostnameDir // key hostname or hostname.orgName
sftpClients map[string]*SftpConn // key %s@%s suName hostName
......@@ -42,6 +44,8 @@ func (u *UserSftp) initial(assets []model.Asset) {
conf := config.GetConf()
u.RootPath = conf.SftpRoot
u.ShowHidden = conf.ShowHiddenFile
u.ReuseConnection = conf.ReuseConnection
u.Overtime = conf.SSHTimeout * time.Second
u.hosts = make(map[string]*HostnameDir)
u.sftpClients = make(map[string]*SftpConn)
u.LogChan = make(chan *model.FTPLog, 10)
......@@ -92,9 +96,9 @@ func (u *UserSftp) ReadDir(path string) (res []os.FileInfo, err error) {
res, err = conn.client.ReadDir(realPath)
if !u.ShowHidden {
noHiddenFiles := make([]os.FileInfo, 0, len(res))
for i:=0; i<len(res);i++ {
for i := 0; i < len(res); i++ {
if !strings.HasPrefix(res[i].Name(), ".") {
noHiddenFiles = append(noHiddenFiles,res[i])
noHiddenFiles = append(noHiddenFiles, res[i])
}
}
return noHiddenFiles, err
......@@ -577,11 +581,11 @@ func (u *UserSftp) SendFTPLog(dataChan <-chan *model.FTPLog) {
}
func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (conn *SftpConn, err error) {
sshClient, err := NewClient(u.User, asset, sysUser, config.GetConf().SSHTimeout*time.Second)
sshClient, err := NewClient(u.User, asset, sysUser, u.Overtime, u.ReuseConnection)
if err != nil {
return
}
sftpClient, err := sftp.NewClient(sshClient.Client)
sftpClient, err := sftp.NewClient(sshClient.client)
if err != nil {
return
}
......
package srvconn
import (
"errors"
"io"
"sync"
"time"
gossh "golang.org/x/crypto/ssh"
......@@ -15,25 +15,21 @@ type ServerSSHConnection struct {
Asset *model.Asset
SystemUser *model.SystemUser
Overtime time.Duration
CloseOnce *sync.Once
ReuseConnection bool
client *SSHClient
session *gossh.Session
stdin io.WriteCloser
stdout io.Reader
closed bool
connected bool
}
func (sc *ServerSSHConnection) Protocol() string {
return "ssh"
}
func (sc *ServerSSHConnection) Username() string {
return sc.client.Username
}
func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
sess, err := sc.client.Client.NewSession()
sess, err := sc.client.NewSession()
if err != nil {
return
}
......@@ -60,32 +56,15 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
}
func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) {
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout())
if err != nil {
return
}
err = sc.invokeShell(h, w, term)
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection)
if err != nil {
return
}
sc.connected = true
return nil
}
func (sc *ServerSSHConnection) TryConnectFromCache(h, w int, term string) (err error) {
sc.client = GetClientFromCache(sc.User, sc.Asset, sc.SystemUser)
if sc.client == nil {
return errors.New("no client in cache")
}
err = sc.invokeShell(h, w, term)
if err != nil {
RecycleClient(sc.client)
return
}
sc.connected = true
return nil
return
}
func (sc *ServerSSHConnection) SetWinSize(h, w int) error {
......@@ -108,13 +87,9 @@ func (sc *ServerSSHConnection) Timeout() time.Duration {
}
func (sc *ServerSSHConnection) Close() (err error) {
sc.CloseOnce.Do(func() {
RecycleClient(sc.client)
if sc.closed || !sc.connected {
return
}
err = sc.session.Close()
if err != nil {
return
}
return
})
return sc.session.Close()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment