Unverified Commit 7808031b authored by Eric_Lee's avatar Eric_Lee Committed by GitHub

[Bugfix] Fix leaking goroutines and add more log info msg (#105)

parent 6a311821
...@@ -91,6 +91,7 @@ func (h *interactiveHandler) displayBanner() { ...@@ -91,6 +91,7 @@ func (h *interactiveHandler) displayBanner() {
func (h *interactiveHandler) watchWinSizeChange() { func (h *interactiveHandler) watchWinSizeChange() {
sessChan := h.sess.WinCh() sessChan := h.sess.WinCh()
winChan := sessChan winChan := sessChan
defer logger.Infof("Request %s: Windows change watch close", h.sess.Uuid)
for { for {
select { select {
case <-h.sess.Sess.Context().Done(): case <-h.sess.Sess.Context().Done():
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/pkg/sftp" "github.com/pkg/sftp"
uuid "github.com/satori/go.uuid"
"github.com/jumpserver/koko/pkg/cctx" "github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
...@@ -29,15 +30,17 @@ func SftpHandler(sess ssh.Session) { ...@@ -29,15 +30,17 @@ func SftpHandler(sess ssh.Session) {
FileCmd: userSftp, FileCmd: userSftp,
FileList: userSftp, FileList: userSftp,
} }
reqID := uuid.NewV4().String()
logger.Infof("SFTP request %s: Handler start", reqID)
req := sftp.NewRequestServer(sess, handlers) req := sftp.NewRequestServer(sess, handlers)
if err := req.Serve(); err == io.EOF { if err := req.Serve(); err == io.EOF {
_ = req.Close() logger.Debug("SFTP request %s: Exited session.", reqID)
userSftp.Close()
logger.Info("sftp client exited session.")
} else if err != nil { } else if err != nil {
logger.Error("sftp server completed with error:", err) logger.Errorf("SFTP request %s: Server completed with error %s", reqID, err)
} }
_ = req.Close()
userSftp.Close()
logger.Infof("SFTP request %s: Handler exit.", reqID)
} }
func NewSFTPHandler(user *model.User, addr string) *sftpHandler { func NewSFTPHandler(user *model.User, addr string) *sftpHandler {
......
...@@ -39,7 +39,7 @@ func (w *WrapperSession) readLoop() { ...@@ -39,7 +39,7 @@ func (w *WrapperSession) readLoop() {
} }
} }
_ = w.inWriter.Close() _ = w.inWriter.Close()
logger.Infof("Request %s read loop break", w.Uuid) logger.Infof("Request %s: Read loop break", w.Uuid)
} }
func (w *WrapperSession) Read(p []byte) (int, error) { func (w *WrapperSession) Read(p []byte) (int, error) {
......
...@@ -13,7 +13,7 @@ func OnELFinderConnect(c *neffos.NSConn, msg neffos.Message) error { ...@@ -13,7 +13,7 @@ func OnELFinderConnect(c *neffos.NSConn, msg neffos.Message) error {
return nil return nil
} }
func OnELFinderDisconnect(c *neffos.NSConn, msg neffos.Message) (error) { func OnELFinderDisconnect(c *neffos.NSConn, msg neffos.Message) error {
logger.Infof("Request %s: web folder ws disconnect", c.Conn.ID()) logger.Infof("Request %s: web folder ws disconnect", c.Conn.ID())
removeUserVolume(c.Conn.ID()) removeUserVolume(c.Conn.ID())
return nil return nil
......
...@@ -107,7 +107,7 @@ func StartHTTPServer() { ...@@ -107,7 +107,7 @@ func StartHTTPServer() {
// AuthDecorator(sftpHostConnectorView)).Methods("GET", "POST") // AuthDecorator(sftpHostConnectorView)).Methods("GET", "POST")
addr := net.JoinHostPort(conf.BindHost, conf.HTTPPort) addr := net.JoinHostPort(conf.BindHost, conf.HTTPPort)
logger.Debug("Start HTTP server at ", addr) logger.Info("Start HTTP server at ", addr)
httpServer = &http.Server{Addr: addr, Handler: router} httpServer = &http.Server{Addr: addr, Handler: router}
logger.Fatal(httpServer.ListenAndServe()) logger.Fatal(httpServer.ListenAndServe())
} }
......
...@@ -32,7 +32,7 @@ func (c *Coco) Start() { ...@@ -32,7 +32,7 @@ func (c *Coco) Start() {
func (c *Coco) Stop() { func (c *Coco) Stop() {
sshd.StopServer() sshd.StopServer()
httpd.StopHTTPServer() httpd.StopHTTPServer()
logger.Info("Quit The Coco") logger.Info("Quit The KoKo")
} }
func RunForever() { func RunForever() {
......
...@@ -28,7 +28,7 @@ var ( ...@@ -28,7 +28,7 @@ var (
const ( const (
CommandInputParserName = "Command Input parser" CommandInputParserName = "Command Input parser"
CommandOutputParserName = "Command output parser" CommandOutputParserName = "Command Output parser"
) )
func newParser(sid string) *Parser { func newParser(sid string) *Parser {
...@@ -78,7 +78,7 @@ func (p *Parser) ParseStream(userInChan, srvInChan <-chan []byte) (userOut, srvO ...@@ -78,7 +78,7 @@ func (p *Parser) ParseStream(userInChan, srvInChan <-chan []byte) (userOut, srvO
p.userOutputChan = make(chan []byte, 1) p.userOutputChan = make(chan []byte, 1)
p.srvOutputChan = make(chan []byte, 1) p.srvOutputChan = make(chan []byte, 1)
logger.Infof("Session %s: Parser start", p.id)
go func() { go func() {
defer func() { defer func() {
// 会话结束,结算命令结果 // 会话结束,结算命令结果
...@@ -88,7 +88,7 @@ func (p *Parser) ParseStream(userInChan, srvInChan <-chan []byte) (userOut, srvO ...@@ -88,7 +88,7 @@ func (p *Parser) ParseStream(userInChan, srvInChan <-chan []byte) (userOut, srvO
close(p.srvOutputChan) close(p.srvOutputChan)
_ = p.cmdOutputParser.Close() _ = p.cmdOutputParser.Close()
_ = p.cmdInputParser.Close() _ = p.cmdInputParser.Close()
logger.Infof("Session %s parser routine done", p.id) logger.Infof("Session %s: Parser routine done", p.id)
}() }()
for { for {
select { select {
......
...@@ -215,7 +215,7 @@ func (p *ProxyServer) Proxy() { ...@@ -215,7 +215,7 @@ func (p *ProxyServer) Proxy() {
// 创建Session // 创建Session
sw, err := CreateSession(p) sw, err := CreateSession(p)
if err != nil { if err != nil {
logger.Error("Create session failed.") logger.Errorf("Request %s: Create session failed: %s",p.UserConn.ID(), err.Error())
return return
} }
defer RemoveSession(sw) defer RemoveSession(sw)
......
...@@ -56,8 +56,10 @@ func (c *CommandRecorder) End() { ...@@ -56,8 +56,10 @@ func (c *CommandRecorder) End() {
} }
func (c *CommandRecorder) record() { func (c *CommandRecorder) record() {
cmdList := make([]*model.Command, 0) cmdList := make([]*model.Command, 0, 10)
maxRetry := 0 maxRetry := 0
logger.Infof("Session %s: Command recorder start", c.sessionID)
defer logger.Infof("Session %s: Command recorder close", c.sessionID)
for { for {
select { select {
case <-c.closed: case <-c.closed:
...@@ -66,7 +68,6 @@ func (c *CommandRecorder) record() { ...@@ -66,7 +68,6 @@ func (c *CommandRecorder) record() {
} }
case p, ok := <-c.queue: case p, ok := <-c.queue:
if !ok { if !ok {
logger.Debug("session command recorder close: ", c.sessionID)
return return
} }
cmdList = append(cmdList, p) cmdList = append(cmdList, p)
...@@ -136,7 +137,7 @@ func (r *ReplyRecorder) prepare() { ...@@ -136,7 +137,7 @@ func (r *ReplyRecorder) prepare() {
return return
} }
logger.Debug("Replay file path: ", r.absFilePath) logger.Infof("Session %s: Replay file path: %s",r.SessionID, r.absFilePath)
r.file, err = os.Create(r.absFilePath) r.file, err = os.Create(r.absFilePath)
if err != nil { if err != nil {
logger.Errorf("Create file %s error: %s\n", r.absFilePath, err) logger.Errorf("Create file %s error: %s\n", r.absFilePath, err)
...@@ -151,6 +152,8 @@ func (r *ReplyRecorder) End() { ...@@ -151,6 +152,8 @@ func (r *ReplyRecorder) End() {
} }
func (r *ReplyRecorder) uploadReplay() { func (r *ReplyRecorder) uploadReplay() {
logger.Infof("Session %s: Replay recorder is uploading", r.SessionID)
defer logger.Infof("Session %s: Replay recorder has uploaded", r.SessionID)
if !common.FileExists(r.absFilePath) { if !common.FileExists(r.absFilePath) {
logger.Debug("Replay file not found, passed: ", r.absFilePath) logger.Debug("Replay file not found, passed: ", r.absFilePath)
return return
......
package proxy package proxy
import ( import (
"errors"
"sync" "sync"
"time" "time"
...@@ -17,7 +18,7 @@ var lock = new(sync.RWMutex) ...@@ -17,7 +18,7 @@ var lock = new(sync.RWMutex)
func HandleSessionTask(task model.TerminalTask) { func HandleSessionTask(task model.TerminalTask) {
switch task.Name { switch task.Name {
case "kill_session": case "kill_session":
if ok := KillSession(task.Args); ok{ if ok := KillSession(task.Args); ok {
service.FinishTask(task.ID) service.FinishTask(task.ID)
} }
default: default:
...@@ -68,7 +69,7 @@ func CreateSession(p *ProxyServer) (sw *SwitchSession, err error) { ...@@ -68,7 +69,7 @@ func CreateSession(p *ProxyServer) (sw *SwitchSession, err error) {
msg = utils.WrapperWarn(msg) msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(p.UserConn, msg) utils.IgnoreErrWriteString(p.UserConn, msg)
logger.Error(msg) logger.Error(msg)
return return sw, errors.New("connect api server failed")
} }
// 获取系统用户的过滤规则,并设置 // 获取系统用户的过滤规则,并设置
cmdRules, err := service.GetSystemUserFilterRules(p.SystemUser.ID) cmdRules, err := service.GetSystemUserFilterRules(p.SystemUser.ID)
...@@ -76,6 +77,7 @@ func CreateSession(p *ProxyServer) (sw *SwitchSession, err error) { ...@@ -76,6 +77,7 @@ func CreateSession(p *ProxyServer) (sw *SwitchSession, err error) {
msg = utils.WrapperWarn(msg) msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(p.UserConn, msg) utils.IgnoreErrWriteString(p.UserConn, msg)
logger.Error(msg + err.Error()) logger.Error(msg + err.Error())
return sw, errors.New("connect api server failed")
} }
sw.SetFilterRules(cmdRules) sw.SetFilterRules(cmdRules)
AddSession(sw) AddSession(sw)
...@@ -96,5 +98,5 @@ func postSession(s *SwitchSession) bool { ...@@ -96,5 +98,5 @@ func postSession(s *SwitchSession) bool {
func finishSession(s *SwitchSession) { func finishSession(s *SwitchSession) {
data := s.MapData() data := s.MapData()
service.FinishSession(data) service.FinishSession(data)
logger.Debugf("Finish session: %s", s.ID) logger.Debugf("Session %s has finished", s.ID)
} }
...@@ -30,14 +30,11 @@ type SwitchSession struct { ...@@ -30,14 +30,11 @@ type SwitchSession struct {
DateStart string DateStart string
DateEnd string DateEnd string
DateActive time.Time
finished bool finished bool
MaxIdleTime time.Duration MaxIdleTime time.Duration
cmdRecorder *CommandRecorder cmdRules []model.SystemUserFilterRule
replayRecorder *ReplyRecorder
parser *Parser
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
...@@ -47,9 +44,7 @@ func (s *SwitchSession) Initial() { ...@@ -47,9 +44,7 @@ func (s *SwitchSession) Initial() {
s.ID = uuid.NewV4().String() s.ID = uuid.NewV4().String()
s.DateStart = common.CurrentUTCTime() s.DateStart = common.CurrentUTCTime()
s.MaxIdleTime = config.GetConf().MaxIdleTime s.MaxIdleTime = config.GetConf().MaxIdleTime
s.cmdRecorder = NewCommandRecorder(s.ID) s.cmdRules = make([]model.SystemUserFilterRule, 0)
s.replayRecorder = NewReplyRecord(s.ID)
s.parser = newParser(s.ID)
s.ctx, s.cancel = context.WithCancel(context.Background()) s.ctx, s.cancel = context.WithCancel(context.Background())
} }
...@@ -62,16 +57,18 @@ func (s *SwitchSession) Terminate() { ...@@ -62,16 +57,18 @@ func (s *SwitchSession) Terminate() {
s.cancel() s.cancel()
} }
func (s *SwitchSession) recordCommand() { func (s *SwitchSession) recordCommand(cmdRecordChan chan [2]string) {
logger.Infof("Session %s record command start", s.ID) // 命令记录
for command := range s.parser.cmdRecordChan { cmdRecorder := NewCommandRecorder(s.ID)
for command := range cmdRecordChan {
if command[0] == "" { if command[0] == "" {
continue continue
} }
cmd := s.generateCommandResult(command) cmd := s.generateCommandResult(command)
s.cmdRecorder.Record(cmd) cmdRecorder.Record(cmd)
} }
logger.Infof("Session %s record command stop", s.ID) // 关闭命令记录
cmdRecorder.End()
} }
// generateCommandResult 生成命令结果 // generateCommandResult 生成命令结果
...@@ -108,35 +105,53 @@ func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command ...@@ -108,35 +105,53 @@ func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command
func (s *SwitchSession) postBridge() { func (s *SwitchSession) postBridge() {
s.DateEnd = common.CurrentUTCTime() s.DateEnd = common.CurrentUTCTime()
s.finished = true s.finished = true
s.parser.Close()
s.replayRecorder.End()
s.cmdRecorder.End()
} }
// SetFilterRules 设置命令过滤规则 // SetFilterRules 设置命令过滤规则
func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) { func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) {
s.parser.SetCMDFilterRules(cmdRules) if len(cmdRules) > 0 {
s.cmdRules = cmdRules
}
} }
// Bridge 桥接两个链接 // Bridge 桥接两个链接
func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) { func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) {
winCh := userConn.WinCh() var (
parser *Parser
replayRecorder *ReplyRecorder
userInChan chan []byte
srvInChan chan []byte
)
parser = newParser(s.ID)
replayRecorder = NewReplyRecord(s.ID)
userInChan = make(chan []byte, 10)
srvInChan = make(chan []byte, 10)
// 设置parser的命令过滤规则
parser.SetCMDFilterRules(s.cmdRules)
// 处理数据流
userOutChan, srvOutChan := parser.ParseStream(userInChan, srvInChan)
defer func() { defer func() {
_ = userConn.Close() _ = userConn.Close()
_ = srvConn.Close() _ = srvConn.Close()
// 关闭parser
parser.Close()
// 关闭录像
replayRecorder.End()
s.postBridge() s.postBridge()
}() }()
userInChan := make(chan []byte, 10)
srvInChan := make(chan []byte, 10)
// 处理数据流
userOutChan, srvOutChan := s.parser.ParseStream(userInChan, srvInChan)
// 记录命令 // 记录命令
go s.recordCommand() go s.recordCommand(parser.cmdRecordChan)
go LoopRead(userConn, userInChan) go LoopRead(userConn, userInChan)
go LoopRead(srvConn, srvInChan) go LoopRead(srvConn, srvInChan)
winCh := userConn.WinCh()
for { for {
select { select {
// 检测是否超过最大空闲时间 // 检测是否超过最大空闲时间
...@@ -165,8 +180,8 @@ func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerCo ...@@ -165,8 +180,8 @@ func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerCo
return return
} }
nw, _ := userConn.Write(p) nw, _ := userConn.Write(p)
if !s.parser.IsInZmodemRecvState() { if !parser.IsInZmodemRecvState() {
s.replayRecorder.Record(p[:nw]) replayRecorder.Record(p[:nw])
} }
// 经过parse处理的user数据,发给server // 经过parse处理的user数据,发给server
case p, ok := <-userOutChan: case p, ok := <-userOutChan:
......
...@@ -11,7 +11,7 @@ func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.System ...@@ -11,7 +11,7 @@ func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.System
Url := fmt.Sprintf(SystemUserAssetAuthURL, systemUserID, assetID) Url := fmt.Sprintf(SystemUserAssetAuthURL, systemUserID, assetID)
_, err := authClient.Get(Url, &info) _, err := authClient.Get(Url, &info)
if err != nil { if err != nil {
logger.Error("Get system user Asset auth info failed") logger.Error("Get system user %s asset %s auth info failed", systemUserID, assetID)
} }
return return
} }
...@@ -61,7 +61,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt ...@@ -61,7 +61,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
_, err = authClient.Get(Url, &rules) _, err = authClient.Get(Url, &rules)
if err != nil { if err != nil {
logger.Error("Get system user auth info failed") logger.Errorf("Get system user %s filter rule failed", systemUserID)
} }
return return
} }
......
...@@ -104,8 +104,8 @@ func (s *SSHClient) isClosed() bool { ...@@ -104,8 +104,8 @@ func (s *SSHClient) isClosed() bool {
func KeepAlive(c *SSHClient, closed <-chan struct{}, keepInterval time.Duration) { func KeepAlive(c *SSHClient, closed <-chan struct{}, keepInterval time.Duration) {
t := time.NewTicker(keepInterval * time.Second) t := time.NewTicker(keepInterval * time.Second)
defer t.Stop() defer t.Stop()
logger.Debugf("SSH client %p keep alive start", c) logger.Infof("SSH client %p keep alive start", c)
defer logger.Debugf("SSH client %p keep alive stop", c) defer logger.Infof("SSH client %p keep alive stop", c)
for { for {
select { select {
case <-closed: case <-closed:
......
...@@ -583,7 +583,7 @@ func (u *UserSftp) SendFTPLog(dataChan <-chan *model.FTPLog) { ...@@ -583,7 +583,7 @@ func (u *UserSftp) SendFTPLog(dataChan <-chan *model.FTPLog) {
if err == nil { if err == nil {
break break
} }
logger.Errorf("create FTP log err: %s", err.Error()) logger.Errorf("Create FTP log err: %s", err.Error())
} }
} }
} }
...@@ -595,14 +595,19 @@ func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser) ...@@ -595,14 +595,19 @@ func (u *UserSftp) GetSftpClient(asset *model.Asset, sysUser *model.SystemUser)
} }
sftpClient, err := sftp.NewClient(sshClient.client) sftpClient, err := sftp.NewClient(sshClient.client)
if err != nil { if err != nil {
logger.Errorf("SSH client %p start sftp client session err %s", sshClient, err)
RecycleClient(sshClient) RecycleClient(sshClient)
return return nil, err
} }
HomeDirPath, err := sftpClient.Getwd() HomeDirPath, err := sftpClient.Getwd()
if err != nil { if err != nil {
return logger.Errorf("SSH client %p get home dir err %s", sshClient, err)
_ = sftpClient.Close()
RecycleClient(sshClient)
return nil, err
} }
logger.Infof("SSH client %p start sftp client session success", sshClient)
conn = &SftpConn{client: sftpClient, conn: sshClient, HomeDirPath: HomeDirPath} conn = &SftpConn{client: sftpClient, conn: sshClient, HomeDirPath: HomeDirPath}
return conn, err return conn, err
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
) )
...@@ -58,12 +59,16 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) { ...@@ -58,12 +59,16 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) { func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) {
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection) sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout(), sc.ReuseConnection)
if err != nil { if err != nil {
logger.Errorf("New SSH client err: %s", err)
return return
} }
err = sc.invokeShell(h, w, term) err = sc.invokeShell(h, w, term)
if err != nil { if err != nil {
logger.Errorf("SSH client %p start ssh shell session err %s", sc.client, err)
RecycleClient(sc.client) RecycleClient(sc.client)
return
} }
logger.Infof("SSH client %p start ssh shell session success", sc.client)
return return
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment