Commit e9f3361d authored by Eric's avatar Eric

use sync.map replace some native map container

parent 8f5d95aa
package main package main
import "cocogo/pkg/sshd" import (
"cocogo/pkg/auth"
"cocogo/pkg/config"
"cocogo/pkg/sshd"
)
func main() { var (
conf *config.Config
appService *auth.Service
)
func init() {
configFile := "config.yml"
conf = config.LoadFromYaml(configFile)
appService = auth.NewAuthService(conf)
appService.LoadAccessKey()
appService.EnsureValidAuth()
appService.LoadTerminalConfig()
sshd.Initial(conf, appService)
}
func main() {
sshd.StartServer() sshd.StartServer()
} }
...@@ -8,7 +8,10 @@ require ( ...@@ -8,7 +8,10 @@ require (
github.com/gliderlabs/ssh v0.1.3 github.com/gliderlabs/ssh v0.1.3
github.com/kr/pty v1.1.4 // indirect github.com/kr/pty v1.1.4 // indirect
github.com/mattn/go-runewidth v0.0.4 github.com/mattn/go-runewidth v0.0.4
github.com/olekukonko/tablewriter v0.0.1
github.com/satori/go.uuid v1.2.0 github.com/satori/go.uuid v1.2.0
github.com/sirupsen/logrus v1.4.0 github.com/sirupsen/logrus v1.4.0
github.com/xlab/treeprint v0.0.0-20181112141820-a009c3971eca
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576 golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576
gopkg.in/yaml.v2 v2.2.2
) )
package asset
import "golang.org/x/crypto/ssh"
/*
{
"id": "060ba6be-a01d-41ef-b366-384b8a012274",
"hostname": "docker_test",
"ip": "127.0.0.1",
"port": 32768,
"system_users_granted": [
{
"id": "fbd39f8c-fa3e-4c2b-948e-ce1e0380b4f9",
"name": "docker_root",
"username": "root",
"priority": 20,
"protocol": "ssh",
"comment": "screencast",
"login_mode": "auto"
}
],
"is_active": true,
"system_users_join": "root",
"os": null,
"domain": null,
"platform": "Linux",
"comment": "screencast",
"protocol": "ssh",
"org_id": "",
"org_name": "DEFAULT"
}
*/
type Node struct {
IP string `json:"ip"`
Port string `json:"port"`
UserName string `json:"username"`
PassWord string `json:"password"`
PublicKey ssh.Signer
}
package auth package auth
import "github.com/gliderlabs/ssh" import "fmt"
type Service struct { type accessAuth struct {
accessKey string
accessSecret string
} }
var ( func (a accessAuth) Signature(date string) string {
service = new(Service) return fmt.Sprintf("Sign %s:%s", a.accessKey, MakeSignature(a.accessSecret, date))
)
func NewService() *Service {
return service
}
func (s *Service) SSHPassword(ctx ssh.Context, password string) bool {
ctx.SessionID()
Username := "softwareuser1"
Password := "123456"
if ctx.User() == Username && password == Password {
return true
}
return false
} }
This diff is collapsed.
package auth
import (
"crypto/md5"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
func HTTPGMTDate() string {
GmtDateLayout := "Mon, 02 Jan 2006 15:04:05 GMT"
return time.Now().UTC().Format(GmtDateLayout)
}
func MakeSignature(key, date string) string {
s := strings.Join([]string{key, date}, "\n")
return Base64Encode(MD5Encode([]byte(s)))
}
func Base64Encode(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
func MD5Encode(b []byte) string {
return fmt.Sprintf("%x", md5.Sum(b))
}
func MakeSureDirExit(filePath string) {
dirPath := filepath.Dir(filePath)
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err = os.Mkdir(dirPath, os.ModePerm)
if err != nil {
log.Info("could not create dir path:", dirPath)
os.Exit(1)
}
log.Info("create dir path:", dirPath)
return
}
log.Info("dir path exits:", dirPath)
}
package auth
const (
TerminalRegisterUrl = "/api/terminal/v2/terminal-registrations/" // 注册当前coco
TerminalConfigUrl = "/api/terminal/v1/terminal/config/" // 从jumpserver获取coco的配置
UserAuthUrl = "/api/users/v1/auth/" // post 验证用户登陆
UserProfileUrl = "/api/users/v1/profile/" // 获取当前用户的基本信息
UserAssetsUrl = "/api/perms/v1/user/%s/assets/" //获取用户授权的所有资产
UserNodesAssetsUrl = "/api/perms/v1/user/%s/nodes-assets/" // 获取用户授权的所有节点信息 节点分组
SystemUserAssetAuthUrl = "/api/assets/v1/system-user/%s/asset/%s/auth-info/" // 该系统用户对某资产的授权
SystemUserAuthUrl = "/api/assets/v1/system-user/%s/auth-info/" // 该系统用户的授权
ValidateUserAssetPermission = "/api/perms/v1/asset-permission/user/validate/" //0不使用缓存 1 使用缓存 2 刷新缓存
)
/*
/api/assets/v1/system-user/%s/asset/%s/auth-info/
/api/assets/v1/system-user/fbd39f8c-fa3e-4c2b-948e-ce1e0380b4f9/cmd-filter-rules/
*/
package config
import (
"fmt"
"io/ioutil"
"os"
"gopkg.in/yaml.v2"
)
func LoadFromYaml(filepath string) *Config {
c := createDefaultConfig()
body, err := ioutil.ReadFile(filepath)
if err != nil {
os.Exit(1)
}
e := yaml.Unmarshal(body, &c)
if e != nil {
fmt.Println("load yaml err")
os.Exit(1)
}
return &c
}
func createDefaultConfig() Config {
name, _ := os.Hostname()
rootPath, _ := os.Getwd()
return Config{
Name: name,
CoreHost: "http://localhost:8080",
BootstrapToken: "",
BindHost: "0.0.0.0",
SshPort: 2222,
HTTPPort: 5000,
CustomerAccessKey: "",
AccessKeyFile: "data/keys/.access_key",
LogLevel: "DEBUG",
RootPath: rootPath,
Comment: "Coco",
TermConfig: &TerminalConfig{},
}
}
type Config struct {
Name string `yaml:"NAME"`
CoreHost string `yaml:"CORE_HOST"`
BootstrapToken string `yaml:"BOOTSTRAP_TOKEN"`
BindHost string `yaml:"BIND_HOST"`
SshPort int `yaml:"SSHD_PORT"`
HTTPPort int `yaml:"HTTPD_PORT"`
CustomerAccessKey string `yaml:"ACCESS_KEY"`
AccessKeyFile string `yaml:"ACCESS_KEY_FILE"`
LogLevel string `yaml:"LOG_LEVEL"`
HeartBeat int `yaml:"HEARTBEAT_INTERVAL"`
RootPath string
Comment string
TermConfig *TerminalConfig
}
type TerminalConfig struct {
AssetListPageSize string `json:"TERMINAL_ASSET_LIST_PAGE_SIZE"`
AssetListSortBy string `json:"TERMINAL_ASSET_LIST_SORT_BY"`
CommandStorage Storage `json:"TERMINAL_COMMAND_STORAGE"`
HeaderTitle string `json:"TERMINAL_HEADER_TITLE"`
HeartBeatInterval int `json:"TERMINAL_HEARTBEAT_INTERVAL"`
HostKey string `json:"TERMINAL_HOST_KEY"`
PasswordAuth bool `json:"TERMINAL_PASSWORD_AUTH"`
PublicKeyAuth bool `json:"TERMINAL_PUBLIC_KEY_AUTH"`
RePlayStorage Storage `json:"TERMINAL_REPLAY_STORAGE"`
SessionKeepDuration int `json:"TERMINAL_SESSION_KEEP_DURATION"`
TelnetRegex string `json:"TERMINAL_TELNET_REGEX"`
SecurityMaxIdleTime int `json:"SECURITY_MAX_IDLE_TIME"`
}
type Storage struct {
Type string `json:"TYPE"`
}
...@@ -5,81 +5,98 @@ import ( ...@@ -5,81 +5,98 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/gliderlabs/ssh"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
) )
func NewNodeConn(c *gossh.Client, s *gossh.Session, useS Conn) (*NodeConn, error) { type Conn interface {
ptyReq, winCh, _ := useS.Pty() SessionID() string
err := s.RequestPty(ptyReq.Term, ptyReq.Window.Height, ptyReq.Window.Width, gossh.TerminalModes{})
User() string
UUID() uuid.UUID
Pty() (ssh.Pty, <-chan ssh.Window, bool)
Context() context.Context
io.Reader
io.WriteCloser
}
type ServerAuth struct {
IP string
Port int
UserName string
Password string
PublicKey gossh.Signer
}
func CreateNodeSession(authInfo ServerAuth) (c *gossh.Client, s *gossh.Session, err error) {
config := &gossh.ClientConfig{
User: authInfo.UserName,
Auth: []gossh.AuthMethod{
gossh.Password(authInfo.Password),
gossh.PublicKeys(authInfo.PublicKey),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}
client, err := gossh.Dial("tcp", fmt.Sprintf("%s:%d", authInfo.IP, authInfo.Port), config)
if err != nil { if err != nil {
return nil, err log.Error(err)
return c, s, err
} }
nodeStdin, err := s.StdinPipe() s, err = client.NewSession()
if err != nil { if err != nil {
return nil, err log.Error(err)
return c, s, err
} }
nodeStdout, err := s.StdoutPipe()
return client, s, nil
}
func NewNodeConn(authInfo ServerAuth, userS Conn) (*NodeConn, error) {
c, s, err := CreateNodeSession(authInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = s.Shell() ptyReq, winCh, _ := userS.Pty()
err = s.RequestPty(ptyReq.Term, ptyReq.Window.Height, ptyReq.Window.Width, gossh.TerminalModes{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx, cancel := context.WithCancel(useS.Context()) nodeStdin, err := s.StdinPipe()
Out, In := io.Pipe()
go func() {
for {
select {
case <-ctx.Done():
log.Info("NewNodeConn goroutine closed")
err = s.Close()
err = c.Close()
if err != io.EOF && err != nil {
log.Info(" sess Close():", err)
}
return
case win, ok := <-winCh:
if !ok {
return
}
err = s.WindowChange(win.Height, win.Width)
if err != nil { if err != nil {
log.Info("windowChange err: ", win) return nil, err
return
}
log.Info("windowChange: ", win)
}
} }
}() nodeStdout, err := s.StdoutPipe()
go func() {
nr, err := io.Copy(In, nodeStdout)
if err != nil { if err != nil {
log.Info("io copy err:", err) return nil, err
} }
err = In.Close() err = s.Shell()
if err != nil { if err != nil {
log.Info("io copy c.Close():", err) return nil, err
} }
cancel() ctx, cancelFunc := context.WithCancel(userS.Context())
log.Info("io copy int:", nr)
}()
nConn := &NodeConn{ nConn := &NodeConn{
uuid: uuid.NewV4(), uuid: uuid.NewV4(),
client: c, client: c,
conn: s, conn: s,
ctx: ctx,
ctxCancelFunc: cancelFunc,
stdin: nodeStdin, stdin: nodeStdin,
stdout: nodeStdout, stdout: nodeStdout,
cusOut: Out,
cusIn: In,
tParser: NewTerminalParser(), tParser: NewTerminalParser(),
inChan: make(chan []byte),
outChan: make(chan []byte),
} }
go nConn.windowChangeHandler(winCh)
return nConn, nil return nConn, nil
} }
...@@ -90,43 +107,80 @@ type NodeConn struct { ...@@ -90,43 +107,80 @@ type NodeConn struct {
conn *gossh.Session conn *gossh.Session
stdin io.Writer stdin io.Writer
stdout io.Reader stdout io.Reader
cusIn io.WriteCloser
cusOut io.ReadCloser
tParser *TerminalParser tParser *TerminalParser
currentCommandInput string currentCommandInput string
currentCommandResult string currentCommandResult string
rulerFilters []RuleFilter rulerFilters []RuleFilter
specialCommands []SpecialRuler specialCommands []SpecialRuler
inSpecialStatus bool inSpecialStatus bool
ctx context.Context
ctxCancelFunc context.CancelFunc
inChan chan []byte
outChan chan []byte
} }
func (n *NodeConn) UUID() uuid.UUID { func (n *NodeConn) UUID() uuid.UUID {
return n.uuid return n.uuid
} }
func (n *NodeConn) Read(b []byte) (nr int, err error) { func (n *NodeConn) Wait() error {
nr, err = n.cusOut.Read(b) return n.conn.Wait()
}
if n.tParser.Started && nr > 0 {
n.FilterSpecialCommand(b[:nr])
func (n *NodeConn) FilterSpecialCommand(b []byte) {
for _, specialCommand := range n.specialCommands {
if matched := specialCommand.MatchRule(b); matched {
switch { switch {
case n.inSpecialStatus: case specialCommand.EnterStatus():
// 进入特殊命令状态, n.inSpecialStatus = true
case n.tParser.InputStatus: case specialCommand.ExitStatus():
n.tParser.CmdInputBuf.Write(b[:nr]) n.inSpecialStatus = false
case n.tParser.OutputStatus:
n.tParser.CmdOutputBuf.Write(b[:nr]) }
default: }
}
}
func (n *NodeConn) FilterWhiteBlackRule(cmd string) bool {
for _, rule := range n.rulerFilters {
if rule.Match(cmd) {
return rule.BlockCommand()
}
} }
return false
}
func (n *NodeConn) windowChangeHandler(winCH <-chan ssh.Window) {
for {
select {
case <-n.ctx.Done():
log.Info("windowChangeHandler done")
return
case win, ok := <-winCH:
if !ok {
return
}
err := n.conn.WindowChange(win.Height, win.Width)
if err != nil {
log.Error("windowChange err: ", win)
return
}
log.Info("windowChange: ", win)
}
} }
return nr, err
} }
func (n *NodeConn) Write(b []byte) (nw int, err error) { func (n *NodeConn) handleRequest(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case buf, ok := <-n.inChan:
if !ok {
return
}
n.tParser.Once.Do( n.tParser.Once.Do(
func() { func() {
n.tParser.Started = true n.tParser.Started = true
...@@ -136,22 +190,19 @@ func (n *NodeConn) Write(b []byte) (nw int, err error) { ...@@ -136,22 +190,19 @@ func (n *NodeConn) Write(b []byte) (nw int, err error) {
case n.inSpecialStatus: case n.inSpecialStatus:
// 特殊的命令 vim 或者 rz // 特殊的命令 vim 或者 rz
case n.tParser.IsEnterKey(b): case n.tParser.IsEnterKey(buf):
n.currentCommandInput = n.tParser.ParseCommandInput() n.currentCommandInput = n.tParser.ParseCommandInput()
if n.FilterWhiteBlackRule(n.currentCommandInput) { if n.FilterWhiteBlackRule(n.currentCommandInput) {
msg := fmt.Sprintf("\r\n cmd '%s' is forbidden \r\n", n.currentCommandInput) msg := fmt.Sprintf("\r\n cmd '%s' is forbidden \r\n", n.currentCommandInput)
nw, err = n.cusIn.Write([]byte(msg)) n.outChan <- []byte(msg)
if err != nil {
return nw, err
}
ctrU := []byte{21, 13} // 清除行并换行 ctrU := []byte{21, 13} // 清除行并换行
nw, err = n.stdin.Write(ctrU) _, err := n.stdin.Write(ctrU)
if err != nil { if err != nil {
return nw, err log.Error(err)
} }
n.tParser.InputStatus = false n.tParser.InputStatus = false
n.tParser.OutputStatus = false n.tParser.OutputStatus = false
return len(b), nil continue
} }
n.tParser.InputStatus = false n.tParser.InputStatus = false
n.tParser.OutputStatus = true n.tParser.OutputStatus = true
...@@ -167,36 +218,57 @@ func (n *NodeConn) Write(b []byte) (nw int, err error) { ...@@ -167,36 +218,57 @@ func (n *NodeConn) Write(b []byte) (nw int, err error) {
} }
n.tParser.InputStatus = true n.tParser.InputStatus = true
} }
return n.stdin.Write(b)
}
func (n *NodeConn) Close() error { _, _ = n.stdin.Write(buf)
return n.cusOut.Close()
}
func (n *NodeConn) Wait() error { }
return n.conn.Wait() }
} }
func (n *NodeConn) FilterSpecialCommand(b []byte) { func (n *NodeConn) handleResponse(ctx context.Context) {
for _, specialCommand := range n.specialCommands { buf := make([]byte, maxBufferSize)
if matched := specialCommand.MatchRule(b); matched { defer close(n.outChan)
for {
nr, err := n.stdout.Read(buf)
if err != nil {
return
}
if n.tParser.Started && nr > 0 {
n.FilterSpecialCommand(buf[:nr])
switch { switch {
case specialCommand.EnterStatus(): case n.inSpecialStatus:
n.inSpecialStatus = true // 进入特殊命令状态,
case specialCommand.ExitStatus(): case n.tParser.InputStatus:
n.inSpecialStatus = false n.tParser.CmdInputBuf.Write(buf[:nr])
case n.tParser.OutputStatus:
n.tParser.CmdOutputBuf.Write(buf[:nr])
default:
}
} }
select {
case <-ctx.Done():
return
default:
copyBuf := make([]byte, len(buf[:nr]))
copy(copyBuf, buf[:nr])
n.outChan <- copyBuf
} }
} }
} }
func (n *NodeConn) FilterWhiteBlackRule(cmd string) bool { func (n *NodeConn) Close() {
for _, rule := range n.rulerFilters {
if rule.Match(cmd) { select {
return rule.BlockCommand() case <-n.ctx.Done():
} return
default:
_ = n.conn.Close()
_ = n.client.Close()
n.ctxCancelFunc()
} }
return false
} }
...@@ -15,10 +15,11 @@ type ProxyChannel interface { ...@@ -15,10 +15,11 @@ type ProxyChannel interface {
Wait() error Wait() error
} }
func NewMemoryChannel(n *NodeConn) *memoryChannel { func NewMemoryChannel(nConn *NodeConn, useS Conn) *memoryChannel {
m := &memoryChannel{ m := &memoryChannel{
uuid: n.UUID(), uuid: nConn.UUID(),
conn: n, conn: nConn,
} }
return m return m
} }
...@@ -33,58 +34,13 @@ func (m *memoryChannel) UUID() string { ...@@ -33,58 +34,13 @@ func (m *memoryChannel) UUID() string {
} }
func (m *memoryChannel) SendResponseChannel(ctx context.Context) <-chan []byte { func (m *memoryChannel) SendResponseChannel(ctx context.Context) <-chan []byte {
// 传入context, 可以从外层进行取消 go m.conn.handleResponse(ctx)
sendChannel := make(chan []byte) return m.conn.outChan
go func() {
defer close(sendChannel)
resp := make([]byte, maxBufferSize)
for {
nr, e := m.conn.Read(resp)
if e != nil {
log.Info("m.conn.Read(resp) err: ", e)
break
}
select {
case <-ctx.Done():
return
default:
sendChannel <- resp[:nr]
}
}
}()
return sendChannel
} }
func (m *memoryChannel) ReceiveRequestChannel(ctx context.Context) chan<- []byte { func (m *memoryChannel) ReceiveRequestChannel(ctx context.Context) chan<- []byte {
// 传入context, 可以从外层进行取消 go m.conn.handleRequest(ctx)
receiveChannel := make(chan []byte) return m.conn.inChan
go func() {
defer m.conn.Close()
for {
select {
case <-ctx.Done():
log.Info("ReceiveRequestChannel ctx done")
return
case reqBuf, ok := <-receiveChannel:
if !ok {
return
}
nw, e := m.conn.Write(reqBuf)
if e != nil && nw != len(reqBuf) {
return
}
}
}
}()
return receiveChannel
} }
func (m *memoryChannel) Wait() error { func (m *memoryChannel) Wait() error {
......
...@@ -5,95 +5,76 @@ import ( ...@@ -5,95 +5,76 @@ import (
"sync" "sync"
) )
type room struct { var Manager = &manager{
sessionID string container: new(sync.Map),
uHome SessionHome
pChan ProxyChannel
} }
var Manager = &manager{container: map[string]room{}}
type manager struct { type manager struct {
container map[string]room container *sync.Map
sync.RWMutex
} }
func (m *manager) add(uHome SessionHome, pChan ProxyChannel) { func (m *manager) add(uHome SessionHome) {
m.Lock() m.container.Store(uHome.SessionID(), uHome)
m.container[uHome.SessionID()] = room{
sessionID: uHome.SessionID(),
uHome: uHome,
pChan: pChan,
}
m.Unlock()
} }
func (m *manager) delete(roomID string) { func (m *manager) delete(roomID string) {
m.Lock() m.container.Delete(roomID)
delete(m.container, roomID)
m.Unlock()
} }
func (m *manager) search(roomID string) (SessionHome, bool) { func (m *manager) search(roomID string) (SessionHome, bool) {
m.RLock() if uHome, ok := m.container.Load(roomID); ok {
defer m.RUnlock() return uHome.(SessionHome), ok
if room, ok := m.container[roomID]; ok {
return room.uHome, ok
} }
return nil, false return nil, false
} }
func JoinShareRoom(roomID string, uConn Conn) { func (m *manager) JoinShareRoom(roomID string, uConn Conn) {
if userHome, ok := Manager.search(roomID); ok { if userHome, ok := m.search(roomID); ok {
userHome.AddConnection(uConn) userHome.AddConnection(uConn)
} }
} }
func ExitShareRoom(roomID string, uConn Conn) { func (m *manager) ExitShareRoom(roomID string, uConn Conn) {
if userHome, ok := Manager.search(roomID); ok { if userHome, ok := m.search(roomID); ok {
userHome.RemoveConnection(uConn) userHome.RemoveConnection(uConn)
} }
} }
func Switch(ctx context.Context, userHome SessionHome, pChannel ProxyChannel) error { func (m *manager) Switch(ctx context.Context, userHome SessionHome, pChannel ProxyChannel) error {
Manager.add(userHome, pChannel) m.add(userHome)
defer Manager.delete(userHome.SessionID()) defer m.delete(userHome.SessionID())
subCtx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup subCtx, cancelFunc := context.WithCancel(ctx)
wg.Add(2) userSendRequestStream := userHome.SendRequestChannel(subCtx)
go func(ctx context.Context, wg *sync.WaitGroup) { userReceiveStream := userHome.ReceiveResponseChannel(subCtx)
defer wg.Done() nodeRequestChan := pChannel.ReceiveRequestChannel(subCtx)
nodeSendResponseStream := pChannel.SendResponseChannel(subCtx)
userSendRequestStream := userHome.SendRequestChannel(ctx)
nodeRequestChan := pChannel.ReceiveRequestChannel(ctx) for userSendRequestStream != nil || nodeSendResponseStream != nil {
select {
for reqFromUser := range userSendRequestStream { case buf1, ok := <-userSendRequestStream:
nodeRequestChan <- reqFromUser if !ok {
log.Warn("userSendRequestStream close")
userSendRequestStream = nil
continue
} }
log.Info("userSendRequestStream close") nodeRequestChan <- buf1
close(nodeRequestChan) case buf2, ok := <-nodeSendResponseStream:
if !ok {
}(subCtx, &wg) log.Warn("nodeSendResponseStream close")
nodeSendResponseStream = nil
go func(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
userReceiveStream := userHome.ReceiveResponseChannel(ctx)
nodeSendResponseStream := pChannel.SendResponseChannel(ctx)
for resFromNode := range nodeSendResponseStream {
userReceiveStream <- resFromNode
}
log.Info("nodeSendResponseStream close")
close(userReceiveStream) close(userReceiveStream)
}(subCtx, &wg) cancelFunc()
err := pChannel.Wait() continue
if err != nil { }
log.Info("pChannel err:", err) userReceiveStream <- buf2
case <-ctx.Done():
return nil
}
} }
cancel()
wg.Wait()
log.Info("switch end") log.Info("switch end")
return err return nil
} }
...@@ -2,28 +2,9 @@ package core ...@@ -2,28 +2,9 @@ package core
import ( import (
"context" "context"
"io"
"sync" "sync"
"github.com/gliderlabs/ssh"
uuid "github.com/satori/go.uuid"
) )
type Conn interface {
SessionID() string
User() string
UUID() uuid.UUID
Pty() (ssh.Pty, <-chan ssh.Window, bool)
Context() context.Context
io.Reader
io.WriteCloser
}
type SessionHome interface { type SessionHome interface {
SessionID() string SessionID() string
AddConnection(c Conn) AddConnection(c Conn)
...@@ -33,21 +14,22 @@ type SessionHome interface { ...@@ -33,21 +14,22 @@ type SessionHome interface {
} }
func NewUserSessionHome(con Conn) *userSessionHome { func NewUserSessionHome(con Conn) *userSessionHome {
return &userSessionHome{ uHome := &userSessionHome{
readStream: make(chan []byte), readStream: make(chan []byte),
mainConn: con, mainConn: con,
connMap: map[string]Conn{con.UUID().String(): con}, connMap: new(sync.Map),
cancelMap: map[string]context.CancelFunc{}, cancelMap: new(sync.Map),
} }
uHome.connMap.Store(con.SessionID(), con)
return uHome
} }
type userSessionHome struct { type userSessionHome struct {
readStream chan []byte readStream chan []byte
mainConn Conn mainConn Conn
connMap map[string]Conn connMap *sync.Map
cancelMap map[string]context.CancelFunc cancelMap *sync.Map
sync.RWMutex
} }
func (r *userSessionHome) SessionID() string { func (r *userSessionHome) SessionID() string {
...@@ -57,9 +39,9 @@ func (r *userSessionHome) SessionID() string { ...@@ -57,9 +39,9 @@ func (r *userSessionHome) SessionID() string {
func (r *userSessionHome) AddConnection(c Conn) { func (r *userSessionHome) AddConnection(c Conn) {
key := c.SessionID() key := c.SessionID()
if _, ok := r.connMap[key]; !ok { if _, ok := r.connMap.Load(key); !ok {
log.Info("add connection ", c) log.Info("add connection ", c)
r.connMap[key] = c r.connMap.Store(key, c)
} else { } else {
log.Info("already add connection") log.Info("already add connection")
return return
...@@ -68,7 +50,7 @@ func (r *userSessionHome) AddConnection(c Conn) { ...@@ -68,7 +50,7 @@ func (r *userSessionHome) AddConnection(c Conn) {
log.Info("add conn session room: ", r.SessionID()) log.Info("add conn session room: ", r.SessionID())
ctx, cancelFunc := context.WithCancel(r.mainConn.Context()) ctx, cancelFunc := context.WithCancel(r.mainConn.Context())
r.cancelMap[key] = cancelFunc r.cancelMap.Store(key, cancelFunc)
defer r.RemoveConnection(c) defer r.RemoveConnection(c)
...@@ -82,10 +64,12 @@ func (r *userSessionHome) AddConnection(c Conn) { ...@@ -82,10 +64,12 @@ func (r *userSessionHome) AddConnection(c Conn) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Info("conn ctx done") log.Info(" user conn ctx done")
return return
default: default:
r.readStream <- buf[:nr] copyBuf := make([]byte, nr)
copy(copyBuf, buf[:nr])
r.readStream <- copyBuf
} }
...@@ -94,13 +78,13 @@ func (r *userSessionHome) AddConnection(c Conn) { ...@@ -94,13 +78,13 @@ func (r *userSessionHome) AddConnection(c Conn) {
} }
func (r *userSessionHome) RemoveConnection(c Conn) { func (r *userSessionHome) RemoveConnection(c Conn) {
r.Lock()
defer r.Unlock()
key := c.SessionID() key := c.SessionID()
if _, ok := r.connMap[key]; ok { if cancelFunc, ok := r.cancelMap.Load(key); ok {
delete(r.connMap, key) cancelFunc.(context.CancelFunc)()
delete(r.cancelMap, key)
} }
r.connMap.Delete(key)
} }
func (r *userSessionHome) SendRequestChannel(ctx context.Context) <-chan []byte { func (r *userSessionHome) SendRequestChannel(ctx context.Context) <-chan []byte {
...@@ -118,7 +102,9 @@ func (r *userSessionHome) SendRequestChannel(ctx context.Context) <-chan []byte ...@@ -118,7 +102,9 @@ func (r *userSessionHome) SendRequestChannel(ctx context.Context) <-chan []byte
case <-ctx.Done(): case <-ctx.Done():
return return
default: default:
r.readStream <- buf[:nr] var respCopy []byte
respCopy = append(respCopy, buf[:nr]...)
r.readStream <- respCopy
} }
} }
...@@ -132,11 +118,10 @@ func (r *userSessionHome) ReceiveResponseChannel(ctx context.Context) chan<- []b ...@@ -132,11 +118,10 @@ func (r *userSessionHome) ReceiveResponseChannel(ctx context.Context) chan<- []b
writeStream := make(chan []byte) writeStream := make(chan []byte)
go func() { go func() {
defer func() { defer func() {
r.RLock() r.cancelMap.Range(func(key, cancelFunc interface{}) bool {
for _, cancel := range r.cancelMap { cancelFunc.(context.CancelFunc)()
cancel() return true
} })
r.RUnlock()
}() }()
for { for {
...@@ -147,13 +132,15 @@ func (r *userSessionHome) ReceiveResponseChannel(ctx context.Context) chan<- []b ...@@ -147,13 +132,15 @@ func (r *userSessionHome) ReceiveResponseChannel(ctx context.Context) chan<- []b
if !ok { if !ok {
return return
} }
for _, c := range r.connMap { r.connMap.Range(func(key, connItem interface{}) bool {
nw, err := c.Write(buf) nw, err := connItem.(Conn).Write(buf)
if err != nil || nw != len(buf) { if err != nil || nw != len(buf) {
log.Error("Write Conn err", c) log.Error("Write Conn err", connItem)
r.cancelMap[c.SessionID()]() r.RemoveConnection(connItem.(Conn))
}
} }
return true
})
} }
} }
......
...@@ -23,17 +23,17 @@ type Rule struct { ...@@ -23,17 +23,17 @@ type Rule struct {
action bool action bool
} }
func (w *Rule) Match(s string) bool { func (r *Rule) Match(s string) bool {
switch w.ruleType { switch r.ruleType {
case "command": case "command":
for _, content := range w.contents { for _, content := range r.contents {
if content == s { if content == s {
return true return true
} }
} }
return false return false
default: default:
for _, content := range w.contents { for _, content := range r.contents {
if matched, _ := regexp.MatchString(content, s); matched { if matched, _ := regexp.MatchString(content, s); matched {
return true return true
} }
...@@ -43,6 +43,6 @@ func (w *Rule) Match(s string) bool { ...@@ -43,6 +43,6 @@ func (w *Rule) Match(s string) bool {
} }
func (w *Rule) BlockCommand() bool { func (r *Rule) BlockCommand() bool {
return w.action return r.action
} }
package model
/*
{
"id": "135ce78d-c4fe-44ca-9be3-c86581cb4365",
"hostname": "coco2",
"ip": "127.0.0.1",
"port": 32769,
"system_users_granted": [{
"id": "fbd39f8c-fa3e-4c2b-948e-ce1e0380b4f9",
"name": "docker_root",
"username": "root",
"priority": 19,
"protocol": "ssh",
"comment": "screencast",
"login_mode": "auto"
}],
"is_active": true,
"system_users_join": "root",
"os": null,
"domain": null,
"platform": "Linux",
"comment": "",
"protocol": "ssh",
"org_id": "",
"org_name": "DEFAULT"
}
*/
type Asset struct {
Id string `json:"id"`
Hostname string `json:"hostname"`
Ip string `json:"ip"`
Port int `json:"port"`
SystemUsers []SystemUser `json:"system_users_granted"`
IsActive bool `json:"is_active"`
SystemUsersJoin string `json:"system_users_join"`
Os string `json:"os"`
Domain string `json:"domain"`
Platform string `json:"platform"`
Comment string `json:"comment"`
Protocol string `json:"protocol"`
OrgID string `json:"org_id"`
OrgName string `json:"org_name"`
}
package model
import (
"sort"
"strconv"
"strings"
)
type AssetNode struct {
Id string `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Value string `json:"value"`
Parent string `json:"parent"`
AssetsGranted []Asset `json:"assets_granted"`
AssetsAmount int `json:"assets_amount"`
OrgId string `json:"org_id"`
}
type nodeSortBy func(node1, node2 *AssetNode) bool
func (by nodeSortBy) Sort(assetNodes []AssetNode) {
nodeSorter := &AssetNodeSorter{
assetNodes: assetNodes,
sortBy: by,
}
sort.Sort(nodeSorter)
}
type AssetNodeSorter struct {
assetNodes []AssetNode
sortBy func(node1, node2 *AssetNode) bool
}
func (a *AssetNodeSorter) Len() int {
return len(a.assetNodes)
}
func (a *AssetNodeSorter) Swap(i, j int) {
a.assetNodes[i], a.assetNodes[j] = a.assetNodes[j], a.assetNodes[i]
}
func (a *AssetNodeSorter) Less(i, j int) bool {
return a.sortBy(&a.assetNodes[i], &a.assetNodes[j])
}
/*
key的排列顺序:
1 1:3 1:3:0 1:4 1:5 1:8
*/
func keySort(node1, node2 *AssetNode) bool {
node1Keys := strings.Split(node1.Key, ":")
node2Keys := strings.Split(node2.Key, ":")
for i := 0; i < len(node1Keys); i++ {
if i >= len(node2Keys) {
return false
}
node1num, _ := strconv.Atoi(node1Keys[i])
node2num, _ := strconv.Atoi(node2Keys[i])
if node1num == node2num {
continue
} else if node1num-node2num > 0 {
return false
} else {
return true
}
}
return true
}
func SortAssetNodesByKey(assetNodes []AssetNode) {
nodeSortBy(keySort).Sort(assetNodes)
}
package model
import (
"sort"
)
/*
{"id": "fbd39f8c-fa3e-4c2b-948e-ce1e0380b4f9",
"name": "docker_root",
"username": "root",
"priority": 19,
"protocol": "ssh",
"comment": "screencast",
"login_mode": "auto"}
*/
type SystemUser struct {
Id string `json:"id"`
Name string `json:"name"`
UserName string `json:"username"`
Priority int `json:"priority"`
Protocol string `json:"protocol"`
Comment string `json:"comment"`
LoginMode string `json:"login_mode"`
}
type SystemUserAuthInfo struct {
Id string `json:"id"`
Name string `json:"name"`
UserName string `json:"username"`
Protocol string `json:"protocol"`
LoginMode string `json:"login_mode"`
Password string `json:"password"`
PrivateKey string `json:"private_key"`
}
type systemUserSortBy func(user1, user2 *SystemUser) bool
func (by systemUserSortBy) Sort(users []SystemUser) {
nodeSorter := &systemUserSorter{
users: users,
sortBy: by,
}
sort.Sort(nodeSorter)
}
type systemUserSorter struct {
users []SystemUser
sortBy func(user1, user2 *SystemUser) bool
}
func (s *systemUserSorter) Len() int {
return len(s.users)
}
func (s *systemUserSorter) Swap(i, j int) {
s.users[i], s.users[j] = s.users[j], s.users[i]
}
func (s *systemUserSorter) Less(i, j int) bool {
return s.sortBy(&s.users[i], &s.users[j])
}
func systemUserPrioritySort(use1, user2 *SystemUser) bool {
return use1.Priority <= user2.Priority
}
func SortSystemUserByPriority(users []SystemUser) {
systemUserSortBy(systemUserPrioritySort).Sort(users)
}
package model
/*
{'id': '1f8e54a8-d99d-4074-b35d-45264adb4e34',
'name': 'EricdeMBP.lan',
'username': 'EricdeMBP.lan',
'email': 'EricdeMBP.lan@serviceaccount.local',
'groups': [],
'groups_display': '',
'role': 'App','role_display': '应用程序',
'avatar_url': '/static/img/avatar/user.png',
'wechat': '','phone': None, 'otp_level': 0,
'comment': '', 'source': 'local',
'source_display': 'Local',
'is_valid': True, 'is_expired': False,
'is_active': True, 'created_by': '',
'is_first_login': True, 'date_password_last_updated': '2019-04-08 18:18:24 +0800',
'date_expired': '2089-03-21 18:18:24 +0800'}
*/
type User struct {
Id string `json:"id"`
Name string `json:"name"`
UserName string `json:"username"`
Email string `json:"email"`
Role string `json:"role"`
IsValid bool `json:"is_valid"`
IsActive bool `json:"is_active"`
}
package sshd
const (
GreenColorCode = "\033[32m"
ColorEnd = "\033[0m"
Tab = "\t"
EndLine = "\r\n\r"
)
const (
AssetsMapKey = "AssetsMapKey"
AssetNodesMapKey = "AssetNodesKey"
)
This diff is collapsed.
package sshd
type CommandData struct {
Input string `json:"input"`
Output string `json:"output"`
Timestamp int64 `json:"timestamp"`
}
...@@ -2,41 +2,85 @@ package sshd ...@@ -2,41 +2,85 @@ package sshd
import ( import (
"cocogo/pkg/auth" "cocogo/pkg/auth"
"cocogo/pkg/config"
"cocogo/pkg/model"
"io"
"strconv" "strconv"
"sync" "sync"
"text/template" "text/template"
"github.com/sirupsen/logrus" "golang.org/x/crypto/ssh/terminal"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/sirupsen/logrus"
) )
var ( var (
SSHPort int conf *config.Config
SSHKeyPath string appService *auth.Service
log *logrus.Logger serverSig ssh.Signer
displayTemplate *template.Template displayTemplate *template.Template
authService *auth.Service log *logrus.Logger
sessionContainer sync.Map
Cached sync.Map
) )
func init() { func Initial(config *config.Config, service *auth.Service) {
log = logrus.New()
displayTemplate = template.Must(template.New("display").Parse(welcomeTemplate)) displayTemplate = template.Must(template.New("display").Parse(welcomeTemplate))
SSHPort = 2333 conf = config
SSHKeyPath = "data/host_rsa_key" appService = service
authService = auth.NewService() serverSig = parsePrivateKey(config.TermConfig.HostKey)
log = logrus.New()
if level, err := logrus.ParseLevel(config.LogLevel); err != nil {
log.SetLevel(logrus.InfoLevel)
} else {
log.SetLevel(level)
}
} }
func StartServer() { func StartServer() {
serverSig := getPrivateKey(SSHKeyPath)
ser := ssh.Server{ ser := ssh.Server{
Addr: "0.0.0.0:" + strconv.Itoa(SSHPort), Addr: conf.BindHost + ":" + strconv.Itoa(conf.SshPort),
PasswordHandler: authService.SSHPassword, PasswordHandler: appService.CheckSSHPassword,
PublicKeyHandler: appService.CheckSSHPublicKey,
HostSigners: []ssh.Signer{serverSig}, HostSigners: []ssh.Signer{serverSig},
Version: "coco-v1.4", Version: "coco-v1.4",
Handler: InteractiveHandler, Handler: connectHandler,
} }
log.Fatal(ser.ListenAndServe()) log.Fatal(ser.ListenAndServe())
} }
func connectHandler(sess ssh.Session) {
_, _, ptyOk := sess.Pty()
if ptyOk {
user, ok := sess.Context().Value("LoginUser").(model.User)
if !ok {
log.Info("Get current User failed")
return
}
userInteractive := &sshInteractive{
sess: sess,
term: terminal.NewTerminal(sess, "Opt>"),
user: user,
helpInfo: HelpInfo{UserName: sess.User(),
ColorCode: GreenColorCode,
ColorEnd: ColorEnd,
Tab: Tab,
EndLine: EndLine}}
log.Info("accept one session")
userInteractive.displayHelpInfo()
userInteractive.StartDispatch()
} else {
_, err := io.WriteString(sess, "No PTY requested.\n")
if err != nil {
return
}
}
}
...@@ -20,6 +20,7 @@ func (s *SSHConn) SessionID() string { ...@@ -20,6 +20,7 @@ func (s *SSHConn) SessionID() string {
func (s *SSHConn) User() string { func (s *SSHConn) User() string {
return s.conn.User() return s.conn.User()
} }
func (s *SSHConn) UUID() uuid.UUID { func (s *SSHConn) UUID() uuid.UUID {
return s.uuid return s.uuid
} }
......
package sshd package sshd
import ( import (
"io/ioutil"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
) )
func getPrivateKey(keyPath string) gossh.Signer { func parsePrivateKey(privateKey string) gossh.Signer {
privateBytes, err := ioutil.ReadFile(keyPath) private, err := gossh.ParsePrivateKey([]byte(privateKey))
if err != nil {
log.Fatal("Failed to load private key: ", err)
}
private, err := gossh.ParsePrivateKey(privateBytes)
if err != nil { if err != nil {
log.Fatal("Failed to parse private key: ", err) log.Info("Failed to parse private key: ", err)
} }
return private return private
} }
......
package sshd
const welcomeTemplate = `
{{.UserName}} Welcome to use Jumpserver open source fortress system{{.EndLine}}
{{.Tab}}1) Enter {{.ColorCode}}ID{{.ColorEnd}} directly login or enter {{.ColorCode}}part IP, Hostname, Comment{{.ColorEnd}} to search login(if unique). {{.EndLine}}
{{.Tab}}2) Enter {{.ColorCode}}/{{.ColorEnd}} + {{.ColorCode}}IP, Hostname{{.ColorEnd}} or {{.ColorCode}}Comment{{.ColorEnd}} search, such as: /ip. {{.EndLine}}
{{.Tab}}3) Enter {{.ColorCode}}p{{.ColorEnd}} to display the host you have permission.{{.EndLine}}
{{.Tab}}4) Enter {{.ColorCode}}g{{.ColorEnd}} to display the node that you have permission.{{.EndLine}}
{{.Tab}}5) Enter {{.ColorCode}}g{{.ColorEnd}} + {{.ColorCode}}NodeID{{.ColorEnd}} to display the host under the node, such as g1. {{.EndLine}}
{{.Tab}}6) Enter {{.ColorCode}}s{{.ColorEnd}} Chinese-english switch.{{.EndLine}}
{{.Tab}}7) Enter {{.ColorCode}}h{{.ColorEnd}} help.{{.EndLine}}
{{.Tab}}8) Enter {{.ColorCode}}r{{.ColorEnd}} to refresh your assets and nodes.{{.EndLine}}
{{.Tab}}0) Enter {{.ColorCode}}q{{.ColorEnd}} exit.{{.EndLine}}
`
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment