Commit 4ba06008 authored by Eric's avatar Eric Committed by Eric_Lee

fix cpu bugs and panic

parent 8ba707fa
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/pkg/sftp" "github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
"github.com/jumpserver/koko/pkg/cctx" "github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/common" "github.com/jumpserver/koko/pkg/common"
...@@ -126,7 +125,7 @@ func (fs *sftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) { ...@@ -126,7 +125,7 @@ func (fs *sftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
if err != nil { if err != nil {
return nil, sftp.ErrSshFxPermissionDenied return nil, sftp.ErrSshFxPermissionDenied
} }
sysUserDir.homeDirpath, err = client.Getwd() sysUserDir.homeDirPath, err = client.Getwd()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -193,7 +192,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) { ...@@ -193,7 +192,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) {
if err != nil { if err != nil {
return sftp.ErrSshFxPermissionDenied return sftp.ErrSshFxPermissionDenied
} }
suDir.homeDirpath, err = client.Getwd() suDir.homeDirPath, err = client.Getwd()
if err != nil { if err != nil {
return err return err
} }
...@@ -275,7 +274,7 @@ func (fs *sftpHandler) Filewrite(r *sftp.Request) (io.WriterAt, error) { ...@@ -275,7 +274,7 @@ func (fs *sftpHandler) Filewrite(r *sftp.Request) (io.WriterAt, error) {
if err != nil { if err != nil {
return nil, sftp.ErrSshFxPermissionDenied return nil, sftp.ErrSshFxPermissionDenied
} }
suDir.homeDirpath, err = client.Getwd() suDir.homeDirPath, err = client.Getwd()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -336,7 +335,7 @@ func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) { ...@@ -336,7 +335,7 @@ func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
if err != nil { if err != nil {
return nil, sftp.ErrSshFxPermissionDenied return nil, sftp.ErrSshFxPermissionDenied
} }
suDir.homeDirpath, err = ftpClient.Getwd() suDir.homeDirPath, err = ftpClient.Getwd()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -364,12 +363,12 @@ func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) { ...@@ -364,12 +363,12 @@ func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
return NewReaderAt(f), err return NewReaderAt(f), err
} }
func (fs *sftpHandler) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *gossh.Client, err error) { func (fs *sftpHandler) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *srvconn.SSHClient, err error) {
sshClient, err = srvconn.NewClient(fs.user, asset, sysUser, config.GetConf().SSHTimeout*time.Second) sshClient, err = srvconn.NewClient(fs.user, asset, sysUser, config.GetConf().SSHTimeout*time.Second)
if err != nil { if err != nil {
return return
} }
sftpClient, err = sftp.NewClient(sshClient) sftpClient, err = sftp.NewClient(sshClient.Client)
if err != nil { if err != nil {
return return
} }
...@@ -437,9 +436,9 @@ type SysUserDir struct { ...@@ -437,9 +436,9 @@ type SysUserDir struct {
rootPath string rootPath string
systemUser *model.SystemUser systemUser *model.SystemUser
time time.Time time time.Time
homeDirpath string homeDirPath string
client *sftp.Client client *sftp.Client
conn *gossh.Client conn *srvconn.SSHClient
} }
func (su *SysUserDir) Name() string { return su.systemUser.Name } func (su *SysUserDir) Name() string { return su.systemUser.Name }
...@@ -462,7 +461,7 @@ func (su *SysUserDir) ParsePath(path string) string { ...@@ -462,7 +461,7 @@ func (su *SysUserDir) ParsePath(path string) string {
var realPath string var realPath string
switch strings.ToLower(su.rootPath) { switch strings.ToLower(su.rootPath) {
case "home", "~", "": case "home", "~", "":
realPath = strings.ReplaceAll(path, su.prefix, su.homeDirpath) realPath = strings.ReplaceAll(path, su.prefix, su.homeDirPath)
default: default:
realPath = strings.ReplaceAll(path, su.prefix, su.rootPath) realPath = strings.ReplaceAll(path, su.prefix, su.rootPath)
} }
......
...@@ -2,6 +2,7 @@ package httpd ...@@ -2,6 +2,7 @@ package httpd
import ( import (
"io" "io"
"sync"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
socketio "github.com/googollee/go-socket.io" socketio "github.com/googollee/go-socket.io"
...@@ -20,6 +21,7 @@ type Client struct { ...@@ -20,6 +21,7 @@ type Client struct {
Conn socketio.Conn Conn socketio.Conn
Closed bool Closed bool
pty ssh.Pty pty ssh.Pty
lock *sync.RWMutex
} }
func (c *Client) WinCh() <-chan ssh.Window { func (c *Client) WinCh() <-chan ssh.Window {
...@@ -39,6 +41,11 @@ func (c *Client) Read(p []byte) (n int, err error) { ...@@ -39,6 +41,11 @@ func (c *Client) Read(p []byte) (n int, err error) {
} }
func (c *Client) Write(p []byte) (n int, err error) { func (c *Client) Write(p []byte) (n int, err error) {
c.lock.RLock()
defer c.lock.RUnlock()
if c.Closed {
return
}
data := DataMsg{Data: string(p), Room: c.Uuid} data := DataMsg{Data: string(p), Room: c.Uuid}
n = len(p) n = len(p)
c.Conn.Emit("data", data) c.Conn.Emit("data", data)
...@@ -50,6 +57,8 @@ func (c *Client) Pty() ssh.Pty { ...@@ -50,6 +57,8 @@ func (c *Client) Pty() ssh.Pty {
} }
func (c *Client) Close() (err error) { func (c *Client) Close() (err error) {
c.lock.Lock()
defer c.lock.Unlock()
if c.Closed { if c.Closed {
return return
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"sync"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
socketio "github.com/googollee/go-socket.io" socketio "github.com/googollee/go-socket.io"
...@@ -83,11 +84,11 @@ func OnHostHandler(s socketio.Conn, message HostMsg) { ...@@ -83,11 +84,11 @@ func OnHostHandler(s socketio.Conn, message HostMsg) {
ctx := s.Context().(WebContext) ctx := s.Context().(WebContext)
userR, userW := io.Pipe() userR, userW := io.Pipe()
conn := conns.GetWebConn(s.ID()) conn := conns.GetWebConn(s.ID())
addr,_,_ := net.SplitHostPort(s.RemoteAddr().String()) addr, _, _ := net.SplitHostPort(s.RemoteAddr().String())
client := &Client{ client := &Client{
Uuid: clientID, Cid: conn.Cid, user: conn.User,addr:addr, Uuid: clientID, Cid: conn.Cid, user: conn.User, addr: addr,
WinChan: make(chan ssh.Window, 100), Conn: s, WinChan: make(chan ssh.Window, 100), Conn: s,
UserRead: userR, UserWrite: userW, UserRead: userR, UserWrite: userW, lock: new(sync.RWMutex),
pty: ssh.Pty{Term: "xterm", Window: win}, pty: ssh.Pty{Term: "xterm", Window: win},
} }
client.WinChan <- win client.WinChan <- win
...@@ -97,6 +98,7 @@ func OnHostHandler(s socketio.Conn, message HostMsg) { ...@@ -97,6 +98,7 @@ func OnHostHandler(s socketio.Conn, message HostMsg) {
Asset: &asset, SystemUser: &systemUser, Asset: &asset, SystemUser: &systemUser,
} }
go func() { go func() {
defer logger.Debug("web proxy end")
proxySrv.Proxy() proxySrv.Proxy()
s.Emit("logout", RoomMsg{Room: clientID}) s.Emit("logout", RoomMsg{Room: clientID})
}() }()
...@@ -148,7 +150,7 @@ func OnTokenHandler(s socketio.Conn, message TokenMsg) { ...@@ -148,7 +150,7 @@ func OnTokenHandler(s socketio.Conn, message TokenMsg) {
client := Client{ client := Client{
Uuid: clientID, Cid: conn.Cid, user: conn.User, Uuid: clientID, Cid: conn.Cid, user: conn.User,
WinChan: make(chan ssh.Window, 100), Conn: s, WinChan: make(chan ssh.Window, 100), Conn: s,
UserRead: userR, UserWrite: userW, UserRead: userR, UserWrite: userW, lock: new(sync.RWMutex),
pty: ssh.Pty{Term: "xterm", Window: win}, pty: ssh.Pty{Term: "xterm", Window: win},
} }
client.WinChan <- win client.WinChan <- win
...@@ -159,6 +161,7 @@ func OnTokenHandler(s socketio.Conn, message TokenMsg) { ...@@ -159,6 +161,7 @@ func OnTokenHandler(s socketio.Conn, message TokenMsg) {
Asset: &asset, SystemUser: &systemUser, Asset: &asset, SystemUser: &systemUser,
} }
go func() { go func() {
defer logger.Debug("web proxy end")
proxySrv.Proxy() proxySrv.Proxy()
s.Emit("logout", RoomMsg{Room: clientID}) s.Emit("logout", RoomMsg{Room: clientID})
}() }()
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"github.com/LeeEirc/elfinder" "github.com/LeeEirc/elfinder"
"github.com/pkg/sftp" "github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
"github.com/jumpserver/koko/pkg/common" "github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/config"
...@@ -126,7 +125,7 @@ func (u *UserVolume) Info(path string) (elfinder.FileDir, error) { ...@@ -126,7 +125,7 @@ func (u *UserVolume) Info(path string) (elfinder.FileDir, error) {
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -205,7 +204,7 @@ func (u *UserVolume) List(path string) []elfinder.FileDir { ...@@ -205,7 +204,7 @@ func (u *UserVolume) List(path string) []elfinder.FileDir {
if err != nil { if err != nil {
return dirs return dirs
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return dirs return dirs
} }
...@@ -284,7 +283,7 @@ func (u *UserVolume) GetFile(path string) (reader io.ReadCloser, err error) { ...@@ -284,7 +283,7 @@ func (u *UserVolume) GetFile(path string) (reader io.ReadCloser, err error) {
if err != nil { if err != nil {
return nil, os.ErrPermission return nil, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -355,7 +354,7 @@ func (u *UserVolume) UploadFile(dir, filename string, reader io.Reader) (elfinde ...@@ -355,7 +354,7 @@ func (u *UserVolume) UploadFile(dir, filename string, reader io.Reader) (elfinde
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -462,7 +461,7 @@ func (u *UserVolume) MergeChunk(cid, total int, dirPath, filename string) (elfin ...@@ -462,7 +461,7 @@ func (u *UserVolume) MergeChunk(cid, total int, dirPath, filename string) (elfin
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -567,7 +566,7 @@ func (u *UserVolume) MakeDir(dir, newDirname string) (elfinder.FileDir, error) { ...@@ -567,7 +566,7 @@ func (u *UserVolume) MakeDir(dir, newDirname string) (elfinder.FileDir, error) {
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -640,7 +639,7 @@ func (u *UserVolume) MakeFile(dir, newFilename string) (elfinder.FileDir, error) ...@@ -640,7 +639,7 @@ func (u *UserVolume) MakeFile(dir, newFilename string) (elfinder.FileDir, error)
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -708,7 +707,7 @@ func (u *UserVolume) Rename(oldNamePath, newName string) (elfinder.FileDir, erro ...@@ -708,7 +707,7 @@ func (u *UserVolume) Rename(oldNamePath, newName string) (elfinder.FileDir, erro
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -783,7 +782,7 @@ func (u *UserVolume) Remove(path string) error { ...@@ -783,7 +782,7 @@ func (u *UserVolume) Remove(path string) error {
if err != nil { if err != nil {
return os.ErrPermission return os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return err return err
} }
...@@ -854,7 +853,7 @@ func (u *UserVolume) Paste(dir, filename, suffix string, reader io.ReadCloser) ( ...@@ -854,7 +853,7 @@ func (u *UserVolume) Paste(dir, filename, suffix string, reader io.ReadCloser) (
if err != nil { if err != nil {
return rest, os.ErrPermission return rest, os.ErrPermission
} }
sysUserVol.homeDirpath, err = sftClient.Getwd() sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil { if err != nil {
return rest, err return rest, err
} }
...@@ -905,12 +904,12 @@ func (u *UserVolume) RootFileDir() elfinder.FileDir { ...@@ -905,12 +904,12 @@ func (u *UserVolume) RootFileDir() elfinder.FileDir {
return resFDir return resFDir
} }
func (u *UserVolume) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *gossh.Client, err error) { func (u *UserVolume) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *srvconn.SSHClient, err error) {
sshClient, err = srvconn.NewClient(u.user, asset, sysUser, config.GetConf().SSHTimeout*time.Second) sshClient, err = srvconn.NewClient(u.user, asset, sysUser, config.GetConf().SSHTimeout*time.Second)
if err != nil { if err != nil {
return return
} }
sftpClient, err = sftp.NewClient(sshClient) sftpClient, err = sftp.NewClient(sshClient.Client)
if err != nil { if err != nil {
return return
} }
...@@ -973,9 +972,9 @@ type sysUserVolume struct { ...@@ -973,9 +972,9 @@ type sysUserVolume struct {
rootPath string rootPath string
systemUser *model.SystemUser systemUser *model.SystemUser
homeDirpath string homeDirPath string
client *sftp.Client client *sftp.Client
conn *gossh.Client conn *srvconn.SSHClient
} }
func (su *sysUserVolume) info() elfinder.FileDir { func (su *sysUserVolume) info() elfinder.FileDir {
...@@ -994,7 +993,7 @@ func (su *sysUserVolume) ParsePath(path string) string { ...@@ -994,7 +993,7 @@ func (su *sysUserVolume) ParsePath(path string) string {
var realPath string var realPath string
switch strings.ToLower(su.rootPath) { switch strings.ToLower(su.rootPath) {
case "home", "~", "": case "home", "~", "":
realPath = strings.ReplaceAll(path, su.suPath, su.homeDirpath) realPath = strings.ReplaceAll(path, su.suPath, su.homeDirPath)
default: default:
realPath = strings.ReplaceAll(path, su.suPath, su.rootPath) realPath = strings.ReplaceAll(path, su.suPath, su.rootPath)
} }
......
...@@ -23,7 +23,7 @@ type ProxyServer struct { ...@@ -23,7 +23,7 @@ type ProxyServer struct {
} }
// getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置 // getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置
func (p *ProxyServer) getSystemUserAuthOrManualSet() { func (p *ProxyServer) getSystemUserAuthOrManualSet() error {
info := service.GetSystemUserAssetAuthInfo(p.SystemUser.ID, p.Asset.ID) info := service.GetSystemUserAssetAuthInfo(p.SystemUser.ID, p.Asset.ID)
p.SystemUser.Password = info.Password p.SystemUser.Password = info.Password
p.SystemUser.PrivateKey = info.PrivateKey p.SystemUser.PrivateKey = info.PrivateKey
...@@ -41,19 +41,24 @@ func (p *ProxyServer) getSystemUserAuthOrManualSet() { ...@@ -41,19 +41,24 @@ func (p *ProxyServer) getSystemUserAuthOrManualSet() {
line, err := term.ReadPassword(fmt.Sprintf("%s's password: ", p.SystemUser.Username)) line, err := term.ReadPassword(fmt.Sprintf("%s's password: ", p.SystemUser.Username))
if err != nil { if err != nil {
logger.Errorf("Get password from user err %s", err.Error()) logger.Errorf("Get password from user err %s", err.Error())
return err
} }
p.SystemUser.Password = line p.SystemUser.Password = line
logger.Debug("Get password from user input: ", line) logger.Debug("Get password from user input: ", line)
} }
return nil
} }
// getSystemUserUsernameIfNeed 获取系统用户用户名,或手动设置 // getSystemUserUsernameIfNeed 获取系统用户用户名,或手动设置
func (p *ProxyServer) getSystemUserUsernameIfNeed() { func (p *ProxyServer) getSystemUserUsernameIfNeed() (err error) {
if p.SystemUser.Username == "" { if p.SystemUser.Username == "" {
var username string var username string
term := utils.NewTerminal(p.UserConn, "username: ") term := utils.NewTerminal(p.UserConn, "username: ")
for { for {
username, _ = term.ReadLine() username, err = term.ReadLine()
if err != nil {
return err
}
username = strings.TrimSpace(username) username = strings.TrimSpace(username)
if username != "" { if username != "" {
break break
...@@ -62,6 +67,7 @@ func (p *ProxyServer) getSystemUserUsernameIfNeed() { ...@@ -62,6 +67,7 @@ func (p *ProxyServer) getSystemUserUsernameIfNeed() {
p.SystemUser.Username = username p.SystemUser.Username = username
logger.Debug("Get username from user input: ", username) logger.Debug("Get username from user input: ", username)
} }
return
} }
// checkProtocolMatch 检查协议是否匹配 // checkProtocolMatch 检查协议是否匹配
...@@ -131,8 +137,14 @@ func (p *ProxyServer) getServerConnFromCache() (srvConn srvconn.ServerConnection ...@@ -131,8 +137,14 @@ func (p *ProxyServer) getServerConnFromCache() (srvConn srvconn.ServerConnection
// getServerConn 获取获取server连接 // getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) { func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
p.getSystemUserUsernameIfNeed() err = p.getSystemUserUsernameIfNeed()
p.getSystemUserAuthOrManualSet() if err != nil {
return
}
err = p.getSystemUserAuthOrManualSet()
if err != nil {
return
}
done := make(chan struct{}) done := make(chan struct{})
defer func() { defer func() {
utils.IgnoreErrWriteString(p.UserConn, "\r\n") utils.IgnoreErrWriteString(p.UserConn, "\r\n")
...@@ -213,7 +225,6 @@ func (p *ProxyServer) Proxy() { ...@@ -213,7 +225,6 @@ func (p *ProxyServer) Proxy() {
if err != nil || srvConn == nil { if err != nil || srvConn == nil {
srvConn, err = p.getServerConn() srvConn, err = p.getServerConn()
} }
// 连接后端服务器失败 // 连接后端服务器失败
if err != nil { if err != nil {
p.sendConnectErrorMsg(err) p.sendConnectErrorMsg(err)
...@@ -224,9 +235,6 @@ func (p *ProxyServer) Proxy() { ...@@ -224,9 +235,6 @@ func (p *ProxyServer) Proxy() {
if err != nil { if err != nil {
return return
} }
defer RemoveSession(sw)
_ = sw.Bridge(p.UserConn, srvConn) _ = sw.Bridge(p.UserConn, srvConn)
defer func() {
_ = srvConn.Close()
RemoveSession(sw)
}()
} }
...@@ -3,7 +3,6 @@ package srvconn ...@@ -3,7 +3,6 @@ package srvconn
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/jumpserver/koko/pkg/service"
"net" "net"
"strconv" "strconv"
"sync" "sync"
...@@ -13,11 +12,12 @@ import ( ...@@ -13,11 +12,12 @@ import (
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
"github.com/jumpserver/koko/pkg/service"
) )
var ( var (
sshClients = make(map[string]*gossh.Client) sshClients = make(map[string]*SSHClient)
clientsRefCounter = make(map[*gossh.Client]int) clientsRefCounter = make(map[*SSHClient]int)
clientLock = new(sync.RWMutex) clientLock = new(sync.RWMutex)
) )
...@@ -31,6 +31,11 @@ var ( ...@@ -31,6 +31,11 @@ var (
"3des-cbc"} "3des-cbc"}
) )
type SSHClient struct {
Client *gossh.Client
Username string
}
type SSHClientConfig struct { type SSHClientConfig struct {
Host string `json:"host"` Host string `json:"host"`
Port string `json:"port"` Port string `json:"port"`
...@@ -58,7 +63,7 @@ func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) { ...@@ -58,7 +63,7 @@ func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
} }
} }
if sc.PrivateKey != "" { if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKeyWithPassphrase([]byte(sc.PrivateKey),[]byte(sc.Password)); err != nil { if signer, err := gossh.ParsePrivateKeyWithPassphrase([]byte(sc.PrivateKey), []byte(sc.Password)); err != nil {
err = fmt.Errorf("parse private key error: %s", err) err = fmt.Errorf("parse private key error: %s", err)
return config, err return config, err
} else { } else {
...@@ -162,13 +167,16 @@ func MakeConfig(asset *model.Asset, systemUser *model.SystemUser, timeout time.D ...@@ -162,13 +167,16 @@ func MakeConfig(asset *model.Asset, systemUser *model.SystemUser, timeout time.D
return return
} }
func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) { func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *SSHClient, err error) {
sshConfig := MakeConfig(asset, systemUser, timeout) sshConfig := MakeConfig(asset, systemUser, timeout)
client, err = sshConfig.Dial() conn, err := sshConfig.Dial()
return if err != nil {
return nil, err
}
return &SSHClient{Client: conn, Username: systemUser.Username}, err
} }
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) { func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *SSHClient, err error) {
client = GetClientFromCache(user, asset, systemUser) client = GetClientFromCache(user, asset, systemUser)
if client != nil { if client != nil {
return client, nil return client, nil
...@@ -185,7 +193,7 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse ...@@ -185,7 +193,7 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse
return return
} }
func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *gossh.Client) { func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *SSHClient) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID) key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
clientLock.Lock() clientLock.Lock()
defer clientLock.Unlock() defer clientLock.Unlock()
...@@ -196,15 +204,14 @@ func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model. ...@@ -196,15 +204,14 @@ func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.
var u = user.Username var u = user.Username
var ip = asset.IP var ip = asset.IP
var sysName = systemUser.Username
clientsRefCounter[client]++ clientsRefCounter[client]++
var counter = clientsRefCounter[client] var counter = clientsRefCounter[client]
logger.Infof("Reuse connection: %s->%s@%s ref: %d", u, sysName, ip, counter) logger.Infof("Reuse connection: %s->%s@%s ref: %d", u, client.Username, ip, counter)
return return
} }
func RecycleClient(client *gossh.Client) { func RecycleClient(client *SSHClient) {
clientLock.RLock() clientLock.RLock()
counter, ok := clientsRefCounter[client] counter, ok := clientsRefCounter[client]
clientLock.RUnlock() clientLock.RUnlock()
...@@ -222,7 +229,7 @@ func RecycleClient(client *gossh.Client) { ...@@ -222,7 +229,7 @@ func RecycleClient(client *gossh.Client) {
} }
} }
func CloseClient(client *gossh.Client) { func CloseClient(client *SSHClient) {
clientLock.Lock() clientLock.Lock()
defer clientLock.Unlock() defer clientLock.Unlock()
...@@ -237,5 +244,5 @@ func CloseClient(client *gossh.Client) { ...@@ -237,5 +244,5 @@ func CloseClient(client *gossh.Client) {
if key != "" { if key != "" {
delete(sshClients, key) delete(sshClients, key)
} }
_ = client.Close() _ = client.Client.Close()
} }
...@@ -16,7 +16,7 @@ type ServerSSHConnection struct { ...@@ -16,7 +16,7 @@ type ServerSSHConnection struct {
SystemUser *model.SystemUser SystemUser *model.SystemUser
Overtime time.Duration Overtime time.Duration
client *gossh.Client client *SSHClient
session *gossh.Session session *gossh.Session
stdin io.WriteCloser stdin io.WriteCloser
stdout io.Reader stdout io.Reader
...@@ -28,8 +28,12 @@ func (sc *ServerSSHConnection) Protocol() string { ...@@ -28,8 +28,12 @@ func (sc *ServerSSHConnection) Protocol() string {
return "ssh" return "ssh"
} }
func (sc *ServerSSHConnection) Username() string {
return sc.client.Username
}
func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) { func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
sess, err := sc.client.NewSession() sess, err := sc.client.Client.NewSession()
if err != nil { if err != nil {
return return
} }
...@@ -74,6 +78,9 @@ func (sc *ServerSSHConnection) TryConnectFromCache(h, w int, term string) (err e ...@@ -74,6 +78,9 @@ func (sc *ServerSSHConnection) TryConnectFromCache(h, w int, term string) (err e
if sc.client == nil { if sc.client == nil {
return errors.New("no client in cache") return errors.New("no client in cache")
} }
if sc.SystemUser.Username == "" {
sc.SystemUser.Username = sc.client.Username
}
err = sc.invokeShell(h, w, term) err = sc.invokeShell(h, w, term)
if err != nil { if err != nil {
RecycleClient(sc.client) RecycleClient(sc.client)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment