Commit b9394148 authored by Eric's avatar Eric Committed by Eric_Lee

fix bug: cpu problem

parent 110b75c6
...@@ -34,9 +34,7 @@ func newParser() *Parser { ...@@ -34,9 +34,7 @@ func newParser() *Parser {
// Parse 解析用户输入输出, 拦截过滤用户输入输出 // Parse 解析用户输入输出, 拦截过滤用户输入输出
type Parser struct { type Parser struct {
userInputChan chan []byte
userOutputChan chan []byte userOutputChan chan []byte
srvInputChan chan []byte
srvOutputChan chan []byte srvOutputChan chan []byte
cmdRecordChan chan [2]string cmdRecordChan chan [2]string
...@@ -65,41 +63,44 @@ func (p *Parser) initial() { ...@@ -65,41 +63,44 @@ func (p *Parser) initial() {
p.cmdOutputParser = NewCmdParser() p.cmdOutputParser = NewCmdParser()
p.closed = make(chan struct{}) p.closed = make(chan struct{})
p.userInputChan = make(chan []byte, 1024)
p.userOutputChan = make(chan []byte, 1024)
p.srvInputChan = make(chan []byte, 1024)
p.srvOutputChan = make(chan []byte, 1024)
p.cmdRecordChan = make(chan [2]string, 1024) p.cmdRecordChan = make(chan [2]string, 1024)
} }
// ParseStream 解析数据流 // ParseStream 解析数据流
func (p *Parser) ParseStream() { func (p *Parser) ParseStream(userInChan, srvInChan <-chan []byte) (userOut, srvOut <-chan []byte) {
defer func() {
close(p.userOutputChan) p.userOutputChan = make(chan []byte, 1)
close(p.srvOutputChan) p.srvOutputChan = make(chan []byte, 1)
close(p.cmdRecordChan)
_ = p.cmdOutputParser.Close() go func() {
_ = p.cmdInputParser.Close() defer func() {
logger.Debug("Parser parse stream routine done") close(p.cmdRecordChan)
}() close(p.userOutputChan)
for { close(p.srvOutputChan)
select { _ = p.cmdOutputParser.Close()
case <-p.closed: _ = p.cmdInputParser.Close()
return logger.Debug("Parser parse stream routine done")
case b, ok := <-p.userInputChan: }()
if !ok { for {
return select {
} case <-p.closed:
b = p.ParseUserInput(b)
p.userOutputChan <- b
case b, ok := <-p.srvInputChan:
if !ok {
return return
case b, ok := <-userInChan:
if !ok {
return
}
b = p.ParseUserInput(b)
p.userOutputChan <- b
case b, ok := <-srvInChan:
if !ok {
return
}
b = p.ParseServerOutput(b)
p.srvOutputChan <- b
} }
b = p.ParseServerOutput(b)
p.srvOutputChan <- b
} }
} }()
return p.userOutputChan, p.srvOutputChan
} }
// Todo: parseMultipleInput 依然存在问题 // Todo: parseMultipleInput 依然存在问题
...@@ -170,7 +171,7 @@ func (p *Parser) parseZmodemState(b []byte) { ...@@ -170,7 +171,7 @@ func (p *Parser) parseZmodemState(b []byte) {
if bytes.Contains(b[:24], zmodemEndMark) { if bytes.Contains(b[:24], zmodemEndMark) {
logger.Debug("Zmodem end") logger.Debug("Zmodem end")
p.zmodemState = "" p.zmodemState = ""
} else if bytes.Contains(b[:24], zmodemCancelMark) { } else if bytes.Contains(b, zmodemCancelMark) {
logger.Debug("Zmodem cancel") logger.Debug("Zmodem cancel")
p.zmodemState = "" p.zmodemState = ""
} }
...@@ -200,9 +201,7 @@ func (p *Parser) splitCmdStream(b []byte) { ...@@ -200,9 +201,7 @@ func (p *Parser) splitCmdStream(b []byte) {
p.cmdInputParser.WriteData(b) p.cmdInputParser.WriteData(b)
return return
} }
// outputBuff 最大存储1024, 否则可能撑爆内存 p.cmdOutputParser.WriteData(b)
// 如果最后一个字符不是ascii, 可以截断了某个中文字符的一部分,为了安全继续添加
p.cmdOutputParser.WriteData(b)
} }
// ParseServerOutput 解析服务器输出 // ParseServerOutput 解析服务器输出
...@@ -249,6 +248,4 @@ func (p *Parser) Close() { ...@@ -249,6 +248,4 @@ func (p *Parser) Close() {
close(p.closed) close(p.closed)
} }
close(p.userInputChan)
close(p.srvInputChan)
} }
...@@ -3,6 +3,7 @@ package proxy ...@@ -3,6 +3,7 @@ package proxy
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"strings" "strings"
"time" "time"
...@@ -38,9 +39,6 @@ type SwitchSession struct { ...@@ -38,9 +39,6 @@ type SwitchSession struct {
replayRecorder *ReplyRecorder replayRecorder *ReplyRecorder
parser *Parser parser *Parser
userTran Transport
srvTran Transport
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
...@@ -111,8 +109,6 @@ func (s *SwitchSession) postBridge() { ...@@ -111,8 +109,6 @@ func (s *SwitchSession) postBridge() {
s.parser.Close() s.parser.Close()
s.replayRecorder.End() s.replayRecorder.End()
s.cmdRecorder.End() s.cmdRecorder.End()
_ = s.userTran.Close()
_ = s.srvTran.Close()
} }
// SetFilterRules 设置命令过滤规则 // SetFilterRules 设置命令过滤规则
...@@ -123,26 +119,30 @@ func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) { ...@@ -123,26 +119,30 @@ func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) {
// Bridge 桥接两个链接 // Bridge 桥接两个链接
func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) { func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) {
winCh := userConn.WinCh() winCh := userConn.WinCh()
// 将ReadWriter转换为Channel读写
s.srvTran = NewDirectTransport(s.ID, srvConn)
s.userTran = NewDirectTransport(s.ID, userConn)
defer func() { defer func() {
logger.Info("Session bridge done: ", s.ID) logger.Info("Session bridge done: ", s.ID)
_ = userConn.Close()
_ = srvConn.Close()
s.postBridge() s.postBridge()
}() }()
userInChan := make(chan []byte, 10)
srvInChan := make(chan []byte, 10)
// 处理数据流 // 处理数据流
go s.parser.ParseStream() userOutChan, srvOutChan := s.parser.ParseStream(userInChan, srvInChan)
// 记录命令 // 记录命令
go s.recordCommand() go s.recordCommand()
go LoopRead(userConn, userInChan)
go LoopRead(srvConn, srvInChan)
for { for {
select { select {
// 检测是否超过最大空闲时间 // 检测是否超过最大空闲时间
case <-time.After(s.MaxIdleTime * time.Minute): case <-time.After(s.MaxIdleTime * time.Minute):
msg := fmt.Sprintf(i18n.T("Connect idle more than %d minutes, disconnect"), s.MaxIdleTime) msg := fmt.Sprintf(i18n.T("Connect idle more than %d minutes, disconnect"), s.MaxIdleTime)
msg = utils.WrapperWarn(msg) msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(s.userTran, "\n\r"+msg) utils.IgnoreErrWriteString(userConn, "\n\r"+msg)
return return
// 手动结束 // 手动结束
case <-s.ctx.Done(): case <-s.ctx.Done():
...@@ -151,36 +151,27 @@ func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerCo ...@@ -151,36 +151,27 @@ func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerCo
utils.IgnoreErrWriteString(userConn, "\n\r"+msg) utils.IgnoreErrWriteString(userConn, "\n\r"+msg)
return return
// 监控窗口大小变化 // 监控窗口大小变化
case win := <-winCh: case win, ok := <-winCh:
_ = srvConn.SetWinSize(win.Height, win.Width)
logger.Debugf("Window server change: %d*%d", win.Height, win.Width)
// Server发来数据流入parser中
case p, ok := <-s.srvTran.Chan():
if !ok { if !ok {
return return
} }
s.parser.srvInputChan <- p _ = srvConn.SetWinSize(win.Height, win.Width)
// Server流入parser数据,经处理发给用户 logger.Debugf("Window server change: %d*%d", win.Height, win.Width)
case p, ok := <-s.parser.srvOutputChan: // 经过parse处理的server数据,发给user
case p, ok := <-srvOutChan:
if !ok { if !ok {
return return
} }
nw, _ := s.userTran.Write(p) nw, _ := userConn.Write(p)
if !s.parser.IsInZmodemRecvState() { if !s.parser.IsInZmodemRecvState() {
s.replayRecorder.Record(p[:nw]) s.replayRecorder.Record(p[:nw])
} }
// User发来的数据流流入parser // 经过parse处理的user数据,发给server
case p, ok := <-s.userTran.Chan(): case p, ok := <-userOutChan:
if !ok { if !ok {
return return
} }
s.parser.userInputChan <- p _, err = srvConn.Write(p)
// User发来的数据经parser处理,发给Server
case p, ok := <-s.parser.userOutputChan:
if !ok {
return
}
_, err = s.srvTran.Write(p)
} }
} }
} }
...@@ -203,3 +194,18 @@ func (s *SwitchSession) MapData() map[string]interface{} { ...@@ -203,3 +194,18 @@ func (s *SwitchSession) MapData() map[string]interface{} {
"date_end": dataEnd, "date_end": dataEnd,
} }
} }
func LoopRead(read io.Reader, inChan chan<- []byte) {
defer logger.Debug("loop read end")
for {
buf := make([]byte, 1024)
nr, err := read.Read(buf)
if err != nil {
break
}
if nr > 0 {
inChan <- buf[:nr]
}
}
close(inChan)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment