Commit e7f9565f authored by ibuler's avatar ibuler

[Update] 修改使用client cache

parent 22beb659
......@@ -33,7 +33,7 @@ func (c *Coco) Stop() {
}
func RunForever() {
loadingBoot()
bootstrap()
gracefulStop := make(chan os.Signal)
signal.Notify(gracefulStop, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
app := &Coco{}
......@@ -42,7 +42,7 @@ func RunForever() {
app.Stop()
}
func loadingBoot() {
func bootstrap() {
config.Initial()
logger.Initial()
service.Initial()
......
......@@ -16,13 +16,13 @@ import (
func Initial() {
conf := config.GetConf()
if conf.UploadFailedReplay {
go uploadFailedReplay(conf.RootPath)
go uploadRemainReplay(conf.RootPath)
}
go keepHeartbeat(conf.HeartbeatDuration)
}
func uploadFailedReplay(rootPath string) {
func uploadRemainReplay(rootPath string) {
replayDir := filepath.Join(rootPath, "data", "replays")
err := common.EnsureDirExist(replayDir)
if err != nil {
......@@ -47,11 +47,11 @@ func uploadFailedReplay(rootPath string) {
}
return nil
})
logger.Debug("upload Replay Done")
logger.Debug("Upload remain replay done")
}
func keepHeartbeat(interval int) {
tick := time.Tick(time.Duration(interval) * time.Second)
func keepHeartbeat(interval time.Duration) {
tick := time.Tick(interval * time.Second)
for {
select {
case <-tick:
......@@ -63,6 +63,5 @@ func keepHeartbeat(interval int) {
}
}
}
}
}
......@@ -7,6 +7,7 @@ import (
"os"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
)
......@@ -20,9 +21,9 @@ type Config struct {
PublicKeyAuth bool `json:"TERMINAL_PUBLIC_KEY_AUTH" yaml:"PUBLIC_KEY_AUTH"`
CommandStorage map[string]interface{} `json:"TERMINAL_COMMAND_STORAGE"`
ReplayStorage map[string]interface{} `json:"TERMINAL_REPLAY_STORAGE" yaml:"REPLAY_STORAGE"`
SessionKeepDuration int `json:"TERMINAL_SESSION_KEEP_DURATION"`
SessionKeepDuration time.Duration `json:"TERMINAL_SESSION_KEEP_DURATION"`
TelnetRegex string `json:"TERMINAL_TELNET_REGEX"`
MaxIdleTime int `json:"SECURITY_MAX_IDLE_TIME"`
MaxIdleTime time.Duration `json:"SECURITY_MAX_IDLE_TIME"`
SftpRoot string `json:"TERMINAL_SFTP_ROOT" yaml:"SFTP_ROOT"`
Name string `yaml:"NAME"`
SecretKey string `yaml:"SECRET_KEY"`
......@@ -30,13 +31,13 @@ type Config struct {
CoreHost string `yaml:"CORE_HOST"`
BootstrapToken string `yaml:"BOOTSTRAP_TOKEN"`
BindHost string `yaml:"BIND_HOST"`
SSHPort int `yaml:"SSHD_PORT"`
HTTPPort int `yaml:"HTTPD_PORT"`
SSHTimeout int `yaml:"SSH_TIMEOUT"`
SSHPort string `yaml:"SSHD_PORT"`
HTTPPort string `yaml:"HTTPD_PORT"`
SSHTimeout time.Duration `yaml:"SSH_TIMEOUT"`
AccessKey string `yaml:"ACCESS_KEY"`
AccessKeyFile string `yaml:"ACCESS_KEY_FILE"`
LogLevel string `yaml:"LOG_LEVEL"`
HeartbeatDuration int `yaml:"HEARTBEAT_INTERVAL"`
HeartbeatDuration time.Duration `yaml:"HEARTBEAT_INTERVAL"`
RootPath string `yaml:"ROOT_PATH"`
Comment string `yaml:"COMMENT"`
Language string `yaml:"LANG"`
......@@ -110,9 +111,9 @@ var Conf = &Config{
CoreHost: "http://localhost:8080",
BootstrapToken: "",
BindHost: "0.0.0.0",
SSHPort: 2222,
SSHPort: "2222",
SSHTimeout: 15,
HTTPPort: 5000,
HTTPPort: "5000",
HeartbeatDuration: 10,
AccessKey: "",
AccessKeyFile: "data/keys/.access_key",
......
......@@ -40,7 +40,10 @@ func (c *assetsCacheContainer) SetValue(key string, value []model.Asset) {
c.mapData[key] = value
}
var userAssetsCached = assetsCacheContainer{mapData: make(map[string][]model.Asset), lock: new(sync.RWMutex)}
var userAssetsCached = assetsCacheContainer{
mapData: make(map[string][]model.Asset),
lock: new(sync.RWMutex),
}
func SessionHandler(sess ssh.Session) {
pty, _, ok := sess.Pty()
......
......@@ -2,7 +2,6 @@ package httpd
import (
"net/http"
"strconv"
"sync"
"github.com/googollee/go-socket.io"
......@@ -36,6 +35,6 @@ func StartHTTPServer() {
http.Handle("/socket.io/", server)
logger.Debug("start HTTP Serving ", conf.HTTPPort)
httpServer = &http.Server{Addr: conf.BindHost + ":" + strconv.Itoa(conf.HTTPPort), Handler: nil}
httpServer = &http.Server{Addr: conf.BindHost + ":" + conf.HTTPPort, Handler: nil}
logger.Fatal(httpServer.ListenAndServe())
}
......@@ -91,6 +91,23 @@ type Asset struct {
OrgName string `json:"org_name"`
}
type Gateway struct {
ID string `json:"id"`
Name string `json:"Name"`
IP string `json:"ip"`
Port int `json:"port"`
Protocol string `json:"protocol"`
Username string `json:"username"`
Password string `json:"password"`
PrivateKey string `json:"private_key"`
}
type Domain struct {
ID string `json:"id"`
Gateways []Gateway `json:"gateways"`
Name string `json:"name"`
}
type Node struct {
Id string `json:"id"`
Key string `json:"key"`
......
......@@ -80,11 +80,11 @@ func (p *Parser) initial() {
p.cmdRecordChan = make(chan [2]string, 1024)
}
func (p *Parser) Parse() {
func (p *Parser) ParseStream() {
defer func() {
close(p.userOutputChan)
close(p.srvOutputChan)
logger.Debug("Parser parse routine done")
logger.Debug("Parser parse stream routine done")
}()
for {
select {
......
package proxy
import (
"cocogo/pkg/srvconn"
"cocogo/pkg/utils"
"fmt"
"regexp"
"strconv"
"strings"
"time"
......@@ -14,6 +11,8 @@ import (
"cocogo/pkg/logger"
"cocogo/pkg/model"
"cocogo/pkg/service"
"cocogo/pkg/srvconn"
"cocogo/pkg/utils"
)
type ProxyServer struct {
......@@ -23,11 +22,8 @@ type ProxyServer struct {
SystemUser *model.SystemUser
}
// getSystemUserAuthOrManualSet 获取系统用户的认证信息或手动设置
func (p *ProxyServer) getSystemUserAuthOrManualSet() {
info := service.GetSystemUserAssetAuthInfo(p.SystemUser.Id, p.Asset.Id)
p.SystemUser.Password = info.Password
p.SystemUser.PrivateKey = info.PrivateKey
if p.SystemUser.LoginMode == model.LoginModeManual ||
(p.SystemUser.Password == "" && p.SystemUser.PrivateKey == "") {
term := utils.NewTerminal(p.UserConn, "password: ")
......@@ -35,11 +31,16 @@ func (p *ProxyServer) getSystemUserAuthOrManualSet() {
if err != nil {
logger.Errorf("Get password from user err %s", err.Error())
}
logger.Debug("Get password from user input: ", line)
p.SystemUser.Password = line
logger.Debug("Get password from user input: ", line)
} else {
info := service.GetSystemUserAssetAuthInfo(p.SystemUser.Id, p.Asset.Id)
p.SystemUser.Password = info.Password
p.SystemUser.PrivateKey = info.PrivateKey
}
}
// getSystemUserUsernameIfNeed 获取系统用户用户名,或手动设置
func (p *ProxyServer) getSystemUserUsernameIfNeed() {
if p.SystemUser.Username == "" {
var username string
......@@ -52,69 +53,76 @@ func (p *ProxyServer) getSystemUserUsernameIfNeed() {
}
}
p.SystemUser.Username = username
logger.Info("Get username from user input: ", username)
logger.Debug("Get username from user input: ", username)
}
}
// checkProtocolMatch 检查协议是否匹配
func (p *ProxyServer) checkProtocolMatch() bool {
return p.SystemUser.Protocol == p.Asset.Protocol
}
// checkProtocolIsGraph 检查协议是否是图形化的
func (p *ProxyServer) checkProtocolIsGraph() bool {
switch p.Asset.Protocol {
case "ssh", "telnet":
return true
default:
return false
default:
return true
}
}
// validatePermission 检查是否有权限连接
func (p *ProxyServer) validatePermission() bool {
return true
return service.ValidateUserAssetPermission(
p.User.ID, p.Asset.Id, p.SystemUser.Id, "connect",
)
}
func (p *ProxyServer) getSSHConn() (srvConn *srvconn.ServerSSHConnection, err error) {
proxyConfig := &srvconn.SSHClientConfig{}
sshConfig := srvconn.SSHClientConfig{
Host: p.Asset.Ip,
Port: strconv.Itoa(p.Asset.Port),
User: p.SystemUser.Username,
Password: p.SystemUser.Password,
PrivateKey: p.SystemUser.PrivateKey,
Overtime: config.GetConf().SSHTimeout,
Proxy: proxyConfig,
}
// getSSHConn 获取ssh连接
func (p *ProxyServer) getSSHConn(fromCache ...bool) (srvConn *srvconn.ServerSSHConnection, err error) {
pty := p.UserConn.Pty()
srvConn = &srvconn.ServerSSHConnection{
Name: p.Asset.Hostname,
Creator: p.User.Username,
SSHClientConfig: sshConfig,
User: p.User,
Asset: p.Asset,
SystemUser: p.SystemUser,
Overtime: time.Duration(config.GetConf().SSHTimeout) * time.Second,
}
if len(fromCache) > 0 && fromCache[0] {
err = srvConn.TryConnectFromCache(pty.Window.Height, pty.Window.Width, pty.Term)
} else {
err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term)
}
pty := p.UserConn.Pty()
err = srvConn.Connect(pty.Window.Height, pty.Window.Width, pty.Term)
fmt.Println("Error: ", err)
return
}
// getTelnetConn 获取telnet连接
func (p *ProxyServer) getTelnetConn() (srvConn *srvconn.ServerTelnetConnection, err error) {
conf := config.GetConf()
cusString := conf.TelnetRegex
pattern, _ := regexp.Compile(cusString)
srvConn = &srvconn.ServerTelnetConnection{
Name: p.Asset.Hostname,
Creator: p.User.ID,
Host: p.Asset.Ip,
Port: strconv.Itoa(p.Asset.Port),
User: p.SystemUser.Username,
Password: p.SystemUser.Password,
User: p.User,
Asset: p.Asset,
SystemUser: p.SystemUser,
CustomString: cusString,
CustomSuccessPattern: pattern,
Overtime: conf.SSHTimeout,
Overtime: time.Duration(conf.SSHTimeout) * time.Second,
}
err = srvConn.Connect(0, 0, "")
utils.IgnoreErrWriteString(p.UserConn, "\r\n")
return
}
// getServerConnFromCache 从cache中获取ssh server连接
func (p *ProxyServer) getServerConnFromCache() (srvConn srvconn.ServerConnection, err error) {
if p.SystemUser.Protocol == "ssh" {
srvConn, err = p.getSSHConn(true)
}
return
}
// getServerConn 获取获取server连接
func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err error) {
p.getSystemUserUsernameIfNeed()
p.getSystemUserAuthOrManualSet()
......@@ -123,19 +131,20 @@ func (p *ProxyServer) getServerConn() (srvConn srvconn.ServerConnection, err err
utils.IgnoreErrWriteString(p.UserConn, "\r\n")
close(done)
}()
go p.sendConnectingMsg(done, config.GetConf().SSHTimeout)
go p.sendConnectingMsg(done, config.GetConf().SSHTimeout*time.Second)
if p.Asset.Protocol == "telnet" {
return p.getTelnetConn()
} else {
return p.getSSHConn()
return p.getSSHConn(false)
}
}
func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delaySecond int) {
// sendConnectingMsg 发送连接信息
func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delayDuration time.Duration) {
delay := 0.0
msg := fmt.Sprintf(i18n.T("Connecting to %s@%s %.1f"), p.SystemUser.Username, p.Asset.Ip, delay)
utils.IgnoreErrWriteString(p.UserConn, msg)
for int(delay) < delaySecond {
for int(delay) < int(delayDuration/time.Second) {
select {
case <-done:
return
......@@ -149,6 +158,7 @@ func (p *ProxyServer) sendConnectingMsg(done chan struct{}, delaySecond int) {
}
}
// preCheckRequisite 检查是否满足条件
func (p *ProxyServer) preCheckRequisite() (ok bool) {
if !p.checkProtocolMatch() {
msg := utils.WrapperWarn(i18n.T("System user <%s> and asset <%s> protocol are inconsistent."))
......@@ -156,7 +166,7 @@ func (p *ProxyServer) preCheckRequisite() (ok bool) {
utils.IgnoreErrWriteString(p.UserConn, msg)
return
}
if !p.checkProtocolIsGraph() {
if p.checkProtocolIsGraph() {
msg := i18n.T("Terminal only support protocol ssh/telnet, please use web terminal to access")
msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(p.UserConn, msg)
......@@ -170,11 +180,17 @@ func (p *ProxyServer) preCheckRequisite() (ok bool) {
return true
}
// Proxy 代理
func (p *ProxyServer) Proxy() {
if !p.preCheckRequisite() {
return
}
srvConn, err := p.getServerConn()
// 先从cache中获取srv连接, 如果没有获得,则连接
srvConn, err := p.getServerConnFromCache()
if err != nil || srvConn == nil {
srvConn, err = p.getServerConn()
}
if err != nil {
msg := fmt.Sprintf("Connect asset %s error: %s\n\r", p.Asset.Hostname, err)
utils.IgnoreErrWriteString(p.UserConn, msg)
......@@ -193,8 +209,11 @@ func (p *ProxyServer) Proxy() {
sw.SetFilterRules(cmdRules)
AddSession(sw)
_ = sw.Bridge(p.UserConn, srvConn)
p.finishSession(sw)
RemoveSession(sw)
defer func() {
_ = srvConn.Close()
p.finishSession(sw)
RemoveSession(sw)
}()
}
func (p *ProxyServer) createSession(s *SwitchSession) bool {
......@@ -212,7 +231,7 @@ func (p *ProxyServer) finishSession(s *SwitchSession) {
data := s.MapData()
service.FinishSession(data)
service.FinishReply(s.Id)
logger.Debugf("finish session: %s", s.Id)
logger.Debugf("Finish session: %s", s.Id)
}
func (p *ProxyServer) GetFilterRules() []model.SystemUserFilterRule {
......
......@@ -31,7 +31,7 @@ type SwitchSession struct {
DateActive time.Time
finished bool
MaxIdleTime int
MaxIdleTime time.Duration
cmdRecorder *CommandRecorder
replayRecorder *ReplyRecorder
......@@ -50,9 +50,7 @@ func (s *SwitchSession) Initial() {
s.MaxIdleTime = config.GetConf().MaxIdleTime
s.cmdRecorder = NewCommandRecorder(s.Id)
s.replayRecorder = NewReplyRecord(s.Id)
s.parser = newParser()
s.ctx, s.cancel = context.WithCancel(context.Background())
}
......@@ -75,6 +73,7 @@ func (s *SwitchSession) recordCommand() {
}
}
// generateCommandResult 生成命令结果
func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command {
var input string
var output string
......@@ -104,6 +103,7 @@ func (s *SwitchSession) generateCommandResult(command [2]string) *model.Command
}
}
// postBridge 桥接结束以后执行操作
func (s *SwitchSession) postBridge() {
s.DateEnd = time.Now().UTC().Format("2006-01-02 15:04:05 +0000")
s.finished = true
......@@ -114,26 +114,31 @@ func (s *SwitchSession) postBridge() {
_ = s.srvTran.Close()
}
// SetFilterRules 设置命令过滤规则
func (s *SwitchSession) SetFilterRules(cmdRules []model.SystemUserFilterRule) {
s.parser.SetCMDFilterRules(cmdRules)
}
// Bridge 桥接两个链接
func (s *SwitchSession) Bridge(userConn UserConnection, srvConn srvconn.ServerConnection) (err error) {
winCh := userConn.WinCh()
// 将ReadWriter转换为Channel读写
s.srvTran = NewDirectTransport(s.Id, srvConn)
s.userTran = NewDirectTransport(s.Id, userConn)
defer func() {
logger.Info("session bridge done: ", s.Id)
logger.Info("Session bridge done: ", s.Id)
s.postBridge()
}()
go s.parser.Parse()
// 处理数据流
go s.parser.ParseStream()
// 记录命令
go s.recordCommand()
defer s.postBridge()
for {
select {
// 检测是否超过最大空闲时间
case <-time.After(time.Duration(s.MaxIdleTime) * time.Minute):
case <-time.After(s.MaxIdleTime * time.Minute):
msg := fmt.Sprintf(i18n.T("Connect idle more than %d minutes, disconnect"), s.MaxIdleTime)
msg = utils.WrapperWarn(msg)
utils.IgnoreErrWriteString(s.userTran, "\n\r"+msg)
......
......@@ -16,15 +16,6 @@ func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.System
return
}
func GetSystemUserAuthInfo(systemUserID string) (info model.SystemUserAuthInfo) {
Url := fmt.Sprintf(SystemUserAuthInfoURL, systemUserID)
err := authClient.Get(Url, &info)
if err != nil {
logger.Error("Get system user auth info failed")
}
return
}
func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilterRule, err error) {
/*[
{
......@@ -76,7 +67,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
}
func GetSystemUser(systemUserID string) (info model.SystemUser) {
Url := fmt.Sprintf(SystemUser, systemUserID)
Url := fmt.Sprintf(SystemUserDetailURL, systemUserID)
err := authClient.Get(Url, &info)
if err != nil {
logger.Errorf("Get system user %s failed", systemUserID)
......@@ -85,16 +76,25 @@ func GetSystemUser(systemUserID string) (info model.SystemUser) {
}
func GetAsset(assetID string) (asset model.Asset) {
Url := fmt.Sprintf(Asset, assetID)
Url := fmt.Sprintf(AssetDetailURL, assetID)
err := authClient.Get(Url, &asset)
if err != nil {
logger.Errorf("Get Asset %s failed", assetID)
logger.Errorf("Get Asset %s failed\n", assetID)
}
return
}
func GetDomainWithGateway(gID string) (domain model.Domain) {
url := fmt.Sprintf(DomainDetailURL, gID)
err := authClient.Get(url, &domain)
if err != nil {
logger.Errorf("Get domain %s failed", gID)
}
return
}
func GetTokenAsset(token string) (tokenUser model.TokenUser) {
Url := fmt.Sprintf(TokenAsset, token)
Url := fmt.Sprintf(TokenAssetUrl, token)
err := authClient.Get(Url, &tokenUser)
if err != nil {
logger.Error("Get Token Asset info failed")
......
package service
const (
UserAuthURL = "/api/users/v1/auth/" // post 验证用户登陆
UserProfileURL = "/api/users/v1/profile/" // 获取当前用户的基本信息
UserListUrl = "/api/users/v1/users/" // 用户列表地址
UserDetailURL = "/api/users/v1/users/%s/" // 获取用户信息
UserAuthOTPURL = "/api/users/v1/otp/auth/" // 验证OTP
AuthMFAURL = "/api/authentication/v1/otp/auth/" // MFA 验证用户信息
UserAuthURL = "/api/users/v1/auth/" // post 验证用户登陆
UserProfileURL = "/api/users/v1/profile/" // 获取当前用户的基本信息
UserListUrl = "/api/users/v1/users/" // 用户列表地址
UserDetailURL = "/api/users/v1/users/%s/" // 获取用户信息
UserAuthOTPURL = "/api/authentication/v1/otp/auth/" // 验证OTP
SystemUserAssetAuthURL = "/api/assets/v1/system-user/%s/asset/%s/auth-info/" // 该系统用户对某资产的授权
SystemUserAuthInfoURL = "/api/assets/v1/system-user/%s/auth-info/" // 该系统用户的授权
SystemUserCmdFilterRules = "/api/assets/v1/system-user/%s/cmd-filter-rules/" // 过滤规则url
SystemUser = "/api/assets/v1/system-user/%s" // 某个系统用户的信息
Asset = "/api/assets/v1/assets/%s/" // 某一个资产信息
TokenAsset = "/api/users/v1/connection-token/?token=%s" // Token name
SystemUserDetailURL = "/api/assets/v1/system-user/%s/" // 某个系统用户的信息
AssetDetailURL = "/api/assets/v1/assets/%s/" // 某一个资产信息
DomainDetailURL = "/api/assets/v1/domain/%s/"
TokenAssetUrl = "/api/users/v1/connection-token/?token=%s" // Token name
TerminalRegisterURL = "/api/terminal/v2/terminal-registrations/" // 注册当前coco
TerminalConfigURL = "/api/terminal/v1/terminal/config/" // 从jumpserver获取coco的配置
......
......@@ -24,10 +24,6 @@ func Authenticate(username, password, publicKey, remoteAddr, loginType string) (
"login_type": loginType,
}
err = client.Post(UserAuthURL, data, &resp)
if err != nil {
logger.Error(err)
}
return
}
......
package srvconn
import (
"io"
"time"
)
type ServerConnection interface {
io.ReadWriteCloser
Timeout() time.Duration
Protocol() string
SetWinSize(w, h int) error
}
package srvconn
import (
"cocogo/pkg/service"
"errors"
"fmt"
"net"
"strconv"
"sync"
"time"
gossh "golang.org/x/crypto/ssh"
"cocogo/pkg/config"
"cocogo/pkg/logger"
"cocogo/pkg/model"
)
......@@ -18,21 +21,139 @@ var (
clientLock = new(sync.RWMutex)
)
func newClient(user *model.User, asset *model.Asset,
systemUser *model.SystemUser) (client *gossh.Client, err error) {
cfg := SSHClientConfig{
type SSHClientConfig struct {
Host string
Port string
User string
Password string
PrivateKey string
PrivateKeyPath string
Timeout time.Duration
Proxy []*SSHClientConfig
proxyConn gossh.Conn
}
func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
authMethods := make([]gossh.AuthMethod, 0)
if sc.Password != "" {
authMethods = append(authMethods, gossh.Password(sc.Password))
}
if sc.PrivateKeyPath != "" {
if pubkey, err := GetPubKeyFromFile(sc.PrivateKeyPath); err != nil {
err = fmt.Errorf("parse private key from file error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(pubkey))
}
}
if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(sc.PrivateKey)); err != nil {
err = fmt.Errorf("parse private key error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(signer))
}
}
config = &gossh.ClientConfig{
User: sc.User,
Auth: authMethods,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Timeout: sc.Timeout,
}
return config, nil
}
func (sc *SSHClientConfig) DialProxy() (client *gossh.Client, err error) {
for _, p := range sc.Proxy {
client, err = p.Dial()
if err == nil {
return
}
}
return
}
func (sc *SSHClientConfig) Dial() (client *gossh.Client, err error) {
cfg, err := sc.Config()
if err != nil {
return
}
if len(sc.Proxy) > 0 {
logger.Debugf("Dial host proxy first")
proxyClient, err := sc.DialProxy()
if err != nil {
err = errors.New("connect proxy host error 1: " + err.Error())
logger.Error("Connect proxy host error 1: ", err.Error())
return client, err
}
proxySock, err := proxyClient.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port))
if err != nil {
err = errors.New("connect proxy host error 2: " + err.Error())
logger.Error("Connect proxy host error 2: ", err.Error())
return client, err
}
proxyConn, chans, reqs, err := gossh.NewClientConn(proxySock, net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return client, err
}
sc.proxyConn = proxyConn
client = gossh.NewClient(proxyConn, chans, reqs)
} else {
logger.Debugf("Dial host %s:%s", sc.Host, sc.Port)
client, err = gossh.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return
}
}
return client, nil
}
func (sc *SSHClientConfig) String() string {
return fmt.Sprintf("%s@%s:%s", sc.User, sc.Host, sc.Port)
}
func newClient(asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) {
proxyConfigs := make([]*SSHClientConfig, 0)
// 如果有网关则从网关中连接
if asset.Domain != "" {
domain := service.GetDomainWithGateway(asset.Domain)
if domain.ID != "" && len(domain.Gateways) > 1 {
for _, gateway := range domain.Gateways {
proxyConfigs = append(proxyConfigs, &SSHClientConfig{
Host: gateway.IP,
Port: strconv.Itoa(gateway.Port),
User: gateway.Username,
Password: gateway.Password,
PrivateKey: gateway.PrivateKey,
Timeout: timeout,
})
}
}
}
sshConfig := SSHClientConfig{
Host: asset.Ip,
Port: strconv.Itoa(asset.Port),
User: systemUser.Username,
Password: systemUser.Password,
PrivateKey: systemUser.PrivateKey,
Overtime: config.GetConf().SSHTimeout,
Timeout: timeout,
Proxy: proxyConfigs,
}
sshConfig = SSHClientConfig{
Host: "127.0.0.1",
Port: "22",
User: "root",
Password: "redhat",
Proxy: []*SSHClientConfig{
{Host: "192.168.244.185", Port: "22", User: "root", Password: "redhat"},
},
}
client, err = cfg.Dial()
client, err = sshConfig.Dial()
return
}
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *gossh.Client, err error) {
func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUser, timeout time.Duration) (client *gossh.Client, err error) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.Id, systemUser.Id)
clientLock.RLock()
client, ok := sshClients[key]
......@@ -52,7 +173,7 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse
return client, nil
}
client, err = newClient(user, asset, systemUser)
client, err = newClient(asset, systemUser, timeout)
if err == nil {
clientLock.Lock()
sshClients[key] = client
......@@ -62,6 +183,18 @@ func NewClient(user *model.User, asset *model.Asset, systemUser *model.SystemUse
return
}
func GetClientFromCache(user *model.User, asset *model.Asset, systemUser *model.SystemUser) (client *gossh.Client) {
key := fmt.Sprintf("%s_%s_%s", user.ID, asset.Id, systemUser.Id)
clientLock.Lock()
defer clientLock.Unlock()
client, ok := sshClients[key]
if !ok {
return
}
clientsRefCounter[client]++
return
}
func RecycleClient(client *gossh.Client) {
clientLock.Lock()
defer clientLock.Unlock()
......
......@@ -10,7 +10,9 @@ var testConnection = SSHClientConfig{
Port: "22",
User: "root",
Password: "redhat",
Proxy: &SSHClientConfig{Host: "192.168.244.185", Port: "22", User: "root", Password: "redhat"},
Proxy: []*SSHClientConfig{
{Host: "192.168.244.185", Port: "22", User: "root", Password: "redhat"},
},
}
func TestSSHConnection_Config(t *testing.T) {
......
package srvconn
import (
"fmt"
"github.com/pkg/errors"
"io"
"net"
"time"
gossh "golang.org/x/crypto/ssh"
)
type ServerConnection interface {
io.ReadWriteCloser
Timeout() time.Duration
Protocol() string
SetWinSize(w, h int) error
}
type SSHClientConfig struct {
Host string
Port string
User string
Password string
PrivateKey string
PrivateKeyPath string
Overtime int
Proxy *SSHClientConfig
proxyConn gossh.Conn
}
func (sc *SSHClientConfig) Timeout() time.Duration {
if sc.Overtime == 0 {
sc.Overtime = 30
}
return time.Duration(sc.Overtime) * time.Second
}
func (sc *SSHClientConfig) Config() (config *gossh.ClientConfig, err error) {
authMethods := make([]gossh.AuthMethod, 0)
if sc.Password != "" {
authMethods = append(authMethods, gossh.Password(sc.Password))
}
if sc.PrivateKeyPath != "" {
if pubkey, err := GetPubKeyFromFile(sc.PrivateKeyPath); err != nil {
err = fmt.Errorf("parse private key from file error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(pubkey))
}
}
if sc.PrivateKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(sc.PrivateKey)); err != nil {
err = fmt.Errorf("parse private key error: %s", err)
return config, err
} else {
authMethods = append(authMethods, gossh.PublicKeys(signer))
}
}
config = &gossh.ClientConfig{
User: sc.User,
Auth: authMethods,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Timeout: sc.Timeout(),
}
return config, nil
}
func (sc *SSHClientConfig) Dial() (client *gossh.Client, err error) {
cfg, err := sc.Config()
if err != nil {
return
}
if sc.Proxy != nil && sc.Proxy.Host != "" {
proxyClient, err := sc.Proxy.Dial()
if err != nil {
err = errors.New("connect proxy Host error1: " + err.Error())
return client, err
}
proxySock, err := proxyClient.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port))
if err != nil {
err = errors.New("connect proxy Host error2: " + err.Error())
return client, err
}
proxyConn, chans, reqs, err := gossh.NewClientConn(proxySock, net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return client, err
}
sc.proxyConn = proxyConn
client = gossh.NewClient(proxyConn, chans, reqs)
} else {
client, err = gossh.Dial("tcp", net.JoinHostPort(sc.Host, sc.Port), cfg)
if err != nil {
return
}
}
return client, nil
}
func (sc *SSHClientConfig) String() string {
return fmt.Sprintf("%s@%s:%s", sc.User, sc.Host, sc.Port)
}
"cocogo/pkg/model"
)
type ServerSSHConnection struct {
SSHClientConfig
Name string
Creator string
client *gossh.Client
session *gossh.Session
stdin io.WriteCloser
stdout io.Reader
closed bool
refCount int
User *model.User
Asset *model.Asset
SystemUser *model.SystemUser
Overtime time.Duration
client *gossh.Client
session *gossh.Session
stdin io.WriteCloser
stdout io.Reader
closed bool
refCount int
connected bool
}
func (sc *ServerSSHConnection) Protocol() string {
return "ssh"
}
func (sc *ServerSSHConnection) String() string {
return fmt.Sprintf("%s@%s:%s", sc.User, sc.Host, sc.Port)
}
func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
sess, err := sc.client.NewSession()
if err != nil {
......@@ -151,14 +57,29 @@ func (sc *ServerSSHConnection) invokeShell(h, w int, term string) (err error) {
}
func (sc *ServerSSHConnection) Connect(h, w int, term string) (err error) {
sc.client, err = sc.Dial()
sc.client, err = NewClient(sc.User, sc.Asset, sc.SystemUser, sc.Timeout())
if err != nil {
return
}
err = sc.invokeShell(h, w, term)
if err != nil {
return
}
sc.connected = true
return nil
}
func (sc *ServerSSHConnection) TryConnectFromCache(h, w int, term string) (err error) {
sc.client = GetClientFromCache(sc.User, sc.Asset, sc.SystemUser)
if sc.client == nil {
return errors.New("No client in cache")
}
err = sc.invokeShell(h, w, term)
if err != nil {
return
}
sc.connected = true
return nil
}
......@@ -174,15 +95,21 @@ func (sc *ServerSSHConnection) Write(p []byte) (n int, err error) {
return sc.stdin.Write(p)
}
func (sc *ServerSSHConnection) Timeout() time.Duration {
if sc.Overtime == 0 {
sc.Overtime = 30 * time.Second
}
return sc.Overtime
}
func (sc *ServerSSHConnection) Close() (err error) {
if sc.closed {
RecycleClient(sc.client)
if sc.closed || !sc.connected {
return
}
err = sc.session.Close()
err = sc.client.Close()
if sc.proxyConn != nil {
err = sc.proxyConn.Close()
if err != nil {
return
}
sc.closed = true
return
}
......@@ -2,12 +2,13 @@ package srvconn
import (
"bytes"
"cocogo/pkg/model"
"errors"
"net"
"regexp"
"strconv"
"time"
"github.com/pkg/errors"
"cocogo/pkg/logger"
)
......@@ -45,13 +46,10 @@ const (
)
type ServerTelnetConnection struct {
Name string
Creator string
Host string
Port string
User string
Password string
Overtime int
User *model.User
Asset *model.Asset
SystemUser *model.SystemUser
Overtime time.Duration
CustomString string
CustomSuccessPattern *regexp.Regexp
......@@ -61,9 +59,9 @@ type ServerTelnetConnection struct {
func (tc *ServerTelnetConnection) Timeout() time.Duration {
if tc.Overtime == 0 {
tc.Overtime = 30
tc.Overtime = 30 * time.Second
}
return time.Duration(tc.Overtime) * time.Second
return tc.Overtime
}
func (tc *ServerTelnetConnection) Protocol() string {
......@@ -120,12 +118,12 @@ func (tc *ServerTelnetConnection) login(data []byte) AuthStatus {
if incorrectPattern.Match(data) {
return AuthFailed
} else if usernamePattern.Match(data) {
_, _ = tc.conn.Write([]byte(tc.User + "\r\n"))
_, _ = tc.conn.Write([]byte(tc.SystemUser.Username + "\r\n"))
logger.Debug("usernamePattern ", tc.User)
return AuthPartial
} else if passwordPattern.Match(data) {
_, _ = tc.conn.Write([]byte(tc.Password + "\r\n"))
logger.Debug("passwordPattern ", tc.Password)
_, _ = tc.conn.Write([]byte(tc.SystemUser.Password + "\r\n"))
logger.Debug("passwordPattern ", tc.SystemUser.Password)
return AuthPartial
} else if successPattern.Match(data) {
return AuthSuccess
......@@ -139,7 +137,9 @@ func (tc *ServerTelnetConnection) login(data []byte) AuthStatus {
}
func (tc *ServerTelnetConnection) Connect(h, w int, term string) (err error) {
conn, err := net.DialTimeout("tcp", net.JoinHostPort(tc.Host, tc.Port), tc.Timeout())
var ip = tc.Asset.Ip
var port = strconv.Itoa(tc.Asset.Port)
conn, err := net.DialTimeout("tcp", net.JoinHostPort(ip, port), tc.Timeout())
if err != nil {
return
}
......
package sshd
import (
"strconv"
"net"
"github.com/gliderlabs/ssh"
......@@ -22,9 +22,10 @@ func StartServer() {
logger.Fatal("Load host key error: ", err)
}
logger.Infof("Start ssh server at %s:%d", conf.BindHost, conf.SSHPort)
addr := net.JoinHostPort(conf.BindHost, conf.SSHPort)
logger.Infof("Start ssh server at %s", addr)
sshServer = &ssh.Server{
Addr: conf.BindHost + ":" + strconv.Itoa(conf.SSHPort),
Addr: addr,
KeyboardInteractiveHandler: auth.CheckMFA,
PasswordHandler: auth.CheckUserPassword,
PublicKeyHandler: auth.CheckUserPublicKey,
......@@ -33,7 +34,7 @@ func StartServer() {
Handler: handler.SessionHandler,
SubsystemHandlers: map[string]ssh.SubsystemHandler{},
}
// Set Auth Handler
// Set sftp handler
sshServer.SetSubsystemHandler("sftp", handler.SftpHandler)
logger.Fatal(sshServer.ListenAndServe())
}
......@@ -41,8 +42,7 @@ func StartServer() {
func StopServer() {
err := sshServer.Close()
if err != nil {
logger.Debugf("ssh server close failed: %s", err.Error())
logger.Errorf("SSH server close failed: %s", err.Error())
}
logger.Debug("Close ssh Server")
logger.Debug("Close ssh server")
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment