Commit 4ba06008 authored by Eric's avatar Eric Committed by Eric_Lee

fix cpu bugs and panic

parent 8ba707fa
......@@ -12,7 +12,6 @@ import (
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
"github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/common"
......@@ -126,7 +125,7 @@ func (fs *sftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
if err != nil {
return nil, sftp.ErrSshFxPermissionDenied
}
sysUserDir.homeDirpath, err = client.Getwd()
sysUserDir.homeDirPath, err = client.Getwd()
if err != nil {
return nil, err
}
......@@ -193,7 +192,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) {
if err != nil {
return sftp.ErrSshFxPermissionDenied
}
suDir.homeDirpath, err = client.Getwd()
suDir.homeDirPath, err = client.Getwd()
if err != nil {
return err
}
......@@ -275,7 +274,7 @@ func (fs *sftpHandler) Filewrite(r *sftp.Request) (io.WriterAt, error) {
if err != nil {
return nil, sftp.ErrSshFxPermissionDenied
}
suDir.homeDirpath, err = client.Getwd()
suDir.homeDirPath, err = client.Getwd()
if err != nil {
return nil, err
}
......@@ -336,7 +335,7 @@ func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
if err != nil {
return nil, sftp.ErrSshFxPermissionDenied
}
suDir.homeDirpath, err = ftpClient.Getwd()
suDir.homeDirPath, err = ftpClient.Getwd()
if err != nil {
return nil, err
}
......@@ -364,12 +363,12 @@ func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
return NewReaderAt(f), err
}
func (fs *sftpHandler) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *gossh.Client, err error) {
func (fs *sftpHandler) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *srvconn.SSHClient, err error) {
sshClient, err = srvconn.NewClient(fs.user, asset, sysUser, config.GetConf().SSHTimeout*time.Second)
if err != nil {
return
}
sftpClient, err = sftp.NewClient(sshClient)
sftpClient, err = sftp.NewClient(sshClient.Client)
if err != nil {
return
}
......@@ -437,9 +436,9 @@ type SysUserDir struct {
rootPath string
systemUser *model.SystemUser
time time.Time
homeDirpath string
homeDirPath string
client *sftp.Client
conn *gossh.Client
conn *srvconn.SSHClient
}
func (su *SysUserDir) Name() string { return su.systemUser.Name }
......@@ -462,7 +461,7 @@ func (su *SysUserDir) ParsePath(path string) string {
var realPath string
switch strings.ToLower(su.rootPath) {
case "home", "~", "":
realPath = strings.ReplaceAll(path, su.prefix, su.homeDirpath)
realPath = strings.ReplaceAll(path, su.prefix, su.homeDirPath)
default:
realPath = strings.ReplaceAll(path, su.prefix, su.rootPath)
}
......
......@@ -2,6 +2,7 @@ package httpd
import (
"io"
"sync"
"github.com/gliderlabs/ssh"
socketio "github.com/googollee/go-socket.io"
......@@ -20,6 +21,7 @@ type Client struct {
Conn socketio.Conn
Closed bool
pty ssh.Pty
lock *sync.RWMutex
}
func (c *Client) WinCh() <-chan ssh.Window {
......@@ -39,6 +41,11 @@ func (c *Client) Read(p []byte) (n int, err error) {
}
func (c *Client) Write(p []byte) (n int, err error) {
c.lock.RLock()
defer c.lock.RUnlock()
if c.Closed {
return
}
data := DataMsg{Data: string(p), Room: c.Uuid}
n = len(p)
c.Conn.Emit("data", data)
......@@ -50,6 +57,8 @@ func (c *Client) Pty() ssh.Pty {
}
func (c *Client) Close() (err error) {
c.lock.Lock()
defer c.lock.Unlock()
if c.Closed {
return
}
......
......@@ -6,6 +6,7 @@ import (
"io"
"net"
"strings"
"sync"
"github.com/gliderlabs/ssh"
socketio "github.com/googollee/go-socket.io"
......@@ -83,11 +84,11 @@ func OnHostHandler(s socketio.Conn, message HostMsg) {
ctx := s.Context().(WebContext)
userR, userW := io.Pipe()
conn := conns.GetWebConn(s.ID())
addr,_,_ := net.SplitHostPort(s.RemoteAddr().String())
addr, _, _ := net.SplitHostPort(s.RemoteAddr().String())
client := &Client{
Uuid: clientID, Cid: conn.Cid, user: conn.User,addr:addr,
Uuid: clientID, Cid: conn.Cid, user: conn.User, addr: addr,
WinChan: make(chan ssh.Window, 100), Conn: s,
UserRead: userR, UserWrite: userW,
UserRead: userR, UserWrite: userW, lock: new(sync.RWMutex),
pty: ssh.Pty{Term: "xterm", Window: win},
}
client.WinChan <- win
......@@ -97,6 +98,7 @@ func OnHostHandler(s socketio.Conn, message HostMsg) {
Asset: &asset, SystemUser: &systemUser,
}
go func() {
defer logger.Debug("web proxy end")
proxySrv.Proxy()
s.Emit("logout", RoomMsg{Room: clientID})
}()
......@@ -148,7 +150,7 @@ func OnTokenHandler(s socketio.Conn, message TokenMsg) {
client := Client{
Uuid: clientID, Cid: conn.Cid, user: conn.User,
WinChan: make(chan ssh.Window, 100), Conn: s,
UserRead: userR, UserWrite: userW,
UserRead: userR, UserWrite: userW, lock: new(sync.RWMutex),
pty: ssh.Pty{Term: "xterm", Window: win},
}
client.WinChan <- win
......@@ -159,6 +161,7 @@ func OnTokenHandler(s socketio.Conn, message TokenMsg) {
Asset: &asset, SystemUser: &systemUser,
}
go func() {
defer logger.Debug("web proxy end")
proxySrv.Proxy()
s.Emit("logout", RoomMsg{Room: clientID})
}()
......
......@@ -10,7 +10,6 @@ import (
"github.com/LeeEirc/elfinder"
"github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
"github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/config"
......@@ -126,7 +125,7 @@ func (u *UserVolume) Info(path string) (elfinder.FileDir, error) {
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -205,7 +204,7 @@ func (u *UserVolume) List(path string) []elfinder.FileDir {
if err != nil {
return dirs
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return dirs
}
......@@ -284,7 +283,7 @@ func (u *UserVolume) GetFile(path string) (reader io.ReadCloser, err error) {
if err != nil {
return nil, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return nil, err
}
......@@ -355,7 +354,7 @@ func (u *UserVolume) UploadFile(dir, filename string, reader io.Reader) (elfinde
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -462,7 +461,7 @@ func (u *UserVolume) MergeChunk(cid, total int, dirPath, filename string) (elfin
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -567,7 +566,7 @@ func (u *UserVolume) MakeDir(dir, newDirname string) (elfinder.FileDir, error) {
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -640,7 +639,7 @@ func (u *UserVolume) MakeFile(dir, newFilename string) (elfinder.FileDir, error)
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -708,7 +707,7 @@ func (u *UserVolume) Rename(oldNamePath, newName string) (elfinder.FileDir, erro
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -783,7 +782,7 @@ func (u *UserVolume) Remove(path string) error {
if err != nil {
return os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return err
}
......@@ -854,7 +853,7 @@ func (u *UserVolume) Paste(dir, filename, suffix string, reader io.ReadCloser) (
if err != nil {
return rest, os.ErrPermission
}
sysUserVol.homeDirpath, err = sftClient.Getwd()
sysUserVol.homeDirPath, err = sftClient.Getwd()
if err != nil {
return rest, err
}
......@@ -905,12 +904,12 @@ func (u *UserVolume) RootFileDir() elfinder.FileDir {
return resFDir
}
func (u *UserVolume) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *gossh.Client, err error) {
func (u *UserVolume) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) (sftpClient *sftp.Client, sshClient *srvconn.SSHClient, err error) {
sshClient, err = srvconn.NewClient(u.user, asset, sysUser, config.GetConf().SSHTimeout*time.Second)
if err != nil {
return
}
sftpClient, err = sftp.NewClient(sshClient)
sftpClient, err = sftp.NewClient(sshClient.Client)
if err != nil {
return
}
......@@ -973,9 +972,9 @@ type sysUserVolume struct {
rootPath string
systemUser *model.SystemUser
homeDirpath string
homeDirPath string
client *sftp.Client
conn *gossh.Client
conn *srvconn.SSHClient
}
func (su *sysUserVolume) info() elfinder.FileDir {
......@@ -994,7 +993,7 @@ func (su *sysUserVolume) ParsePath(path string) string {
var realPath string
switch strings.ToLower(su.rootPath) {
case "home", "~", "":
realPath = strings.ReplaceAll(path, su.suPath, su.homeDirpath)
realPath = strings.ReplaceAll(path, su.suPath, su.homeDirPath)
default:
realPath = strings.ReplaceAll(path, su.suPath, su.rootPath)
}
......
......@@ -23,7 +23,7 @@ type ProxyServer struct {
}
// getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置
func (p *ProxyServer) getSystemUserAuthOrManualSet() {
func (p *ProxyServer) getSystemUserAuthOrManualSet() error {
info := service.GetSystemUserAssetAuthInfo(p.SystemUser.ID, p.Asset.ID)
p.SystemUser.Password = info.Password
p.SystemUser.PrivateKey = info.PrivateKey
......@@ -41,19 +41,24 @@ func (p *ProxyServer) getSystemUserAuthOrManualSet() {
line, err := term.ReadPassword(fmt.Sprintf("%s's password: ", p.SystemUser.Username))
if err != nil {
logger.Errorf("Get password from user err %s", err.Error())
return err
}
p.SystemUser.Password = line
logger.Debug("Get password from user input: ", line)
}
return nil
}
// getSystemUserUsernameIfNeed 获取系统用户用户名,或手动设置
func (p *ProxyServer) getSystemUserUsernameIfNeed() {
func (p *ProxyServer) getSystemUserUsernameIfNeed() (err error) {
if p.SystemUser.Username == "" {
var username string
term := utils.NewTerminal(p.UserConn, "username: ")
for {
username, _ = term.ReadLine()
username, err = term.ReadLine()
if err != nil {
return err
}
username = strings.TrimSpace(username)
if username != "" {
break
......@@ -62,6 +67,7 @@ func (p *ProxyServer) getSystemUserUsernameIfNeed() {
p.SystemUser.Username = username
logger.Debug("Get username from user input: ", username)
}
return
}
// checkProtocolMatch 检查协议是否匹配
......@@ -131,8 +137,14 @@ func (p *ProxyServer) getServerConnFromCache() (srvConn srvconn.ServerConnection
// getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
p.getSystemUserUsernameIfNeed()
p.getSystemUserAuthOrManualSet()
err = p.getSystemUserUsernameIfNeed()
if err != nil {
return
}
err = p.getSystemUserAuthOrManualSet()
if err != nil {
return
}
done := make(chan struct{})
defer func() {
utils.IgnoreErrWriteString(p.UserConn, "\r\n")
......@@ -213,7 +225,6 @@ func (p *ProxyServer) Proxy() {
if err != nil || srvConn == nil {
srvConn, err = p.getServerConn()
}
// 连接后端服务器失败
if err != nil {
p.sendConnectErrorMsg(err)
......@@ -224,9 +235,6 @@ func (p *ProxyServer) Proxy() {
if err != nil {
return
}
defer RemoveSession(sw)
_ = sw.Bridge(p.UserConn, srvConn)
defer func() {
_ = srvConn.Close()
RemoveSession(sw)
}()
}
......@@ -3,7 +3,6 @@ package srvconn
import (
"errors"
"fmt"
"github.com/jumpserver/koko/pkg/service"
"net"
"strconv"
"sync"
......@@ -13,11 +12,12 @@ import (
"github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model"
"github.com/jumpserver/koko/pkg/service"
)
var (
sshClients = make(map[string]*gossh.Client)
clientsRefCounter = make(map[*gossh.Client]int)
sshClients = make(map[string]*SSHClient)
clientsRefCounter = make(map[*SSHClient]int)
clientLock = new(sync.RWMutex)
)
......@@ -31,6 +31,11 @@ var (
"3des-cbc"}
)
type SSHClient struct {
Client *gossh.Client
Username string
}
type SSHClientConfig struct {
Host string `json:"host"`
Port string `json:"port"`
......@@ -58,7 +63,7 @@ func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
}
}
if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKeyWithPassphrase([]byte(sc.PrivateKey),[]byte(sc.Password)); err != nil {
if signer, err := gossh.ParsePrivateKeyWithPassphrase([]byte(sc.PrivateKey), []byte(sc.Password)); err != nil {
err = fmt.Errorf("parse private key error: %s", err)
return config, err
} else {
......@@ -162,13 +167,16 @@ func MakeConfig(asset *model.Asset, systemUser *model.SystemUser, timeout time.D
return
}
func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) {
func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *SSHClient, err error) {
sshConfig := MakeConfig(asset, systemUser, timeout)
client, err = sshConfig.Dial()
return
conn, err := sshConfig.Dial()
if err != nil {
return nil, err
}
return &SSHClient{Client: conn, Username: systemUser.Username}, err
}
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) {
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *SSHClient, err error) {
client = GetClientFromCache(user, asset, systemUser)
if client != nil {
return client, nil
......@@ -185,7 +193,7 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse
return
}
func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *gossh.Client) {
func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *SSHClient) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.ID, systemUser.ID)
clientLock.Lock()
defer clientLock.Unlock()
......@@ -196,15 +204,14 @@ func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.
var u = user.Username
var ip = asset.IP
var sysName = systemUser.Username
clientsRefCounter[client]++
var counter = clientsRefCounter[client]
logger.Infof("Reuse connection: %s->%s@%s ref: %d", u, sysName, ip, counter)
logger.Infof("Reuse connection: %s->%s@%s ref: %d", u, client.Username, ip, counter)
return
}
func RecycleClient(client *gossh.Client) {
func RecycleClient(client *SSHClient) {
clientLock.RLock()
counter, ok := clientsRefCounter[client]
clientLock.RUnlock()
......@@ -222,7 +229,7 @@ func RecycleClient(client *gossh.Client) {
}
}
func CloseClient(client *gossh.Client) {
func CloseClient(client *SSHClient) {
clientLock.Lock()
defer clientLock.Unlock()
......@@ -237,5 +244,5 @@ func CloseClient(client *gossh.Client) {
if key != "" {
delete(sshClients, key)
}
_ = client.Close()
_ = client.Client.Close()
}
......@@ -16,7 +16,7 @@ type ServerSSHConnection struct {
SystemUser *model.SystemUser
Overtime time.Duration
client *gossh.Client
client *SSHClient
session *gossh.Session
stdin io.WriteCloser
stdout io.Reader
......@@ -28,8 +28,12 @@ func (sc *ServerSSHConnection) Protocol() string {
return "ssh"
}
func (sc *ServerSSHConnection) Username() string {
return sc.client.Username
}
func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
sess, err := sc.client.NewSession()
sess, err := sc.client.Client.NewSession()
if err != nil {
return
}
......@@ -74,6 +78,9 @@ func (sc *ServerSSHConnection) TryConnectFromCache(h, w int, term string) (err e
if sc.client == nil {
return errors.New("no client in cache")
}
if sc.SystemUser.Username == "" {
sc.SystemUser.Username = sc.client.Username
}
err = sc.invokeShell(h, w, term)
if err != nil {
RecycleClient(sc.client)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment