Commit 08f3f0b2 authored by Eric's avatar Eric

[Update] update user auth

parent b4ed7c7a
...@@ -6,10 +6,10 @@ import ( ...@@ -6,10 +6,10 @@ import (
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/common" "github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/config"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model"
"github.com/jumpserver/koko/pkg/service" "github.com/jumpserver/koko/pkg/service"
) )
...@@ -31,28 +31,26 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult) ...@@ -31,28 +31,26 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult)
authMethod = "password" authMethod = "password"
} }
remoteAddr, _, _ := net.SplitHostPort(ctx.RemoteAddr().String()) remoteAddr, _, _ := net.SplitHostPort(ctx.RemoteAddr().String())
userClient := service.NewSessionClient(service.Username(username),
service.Password(password), service.PublicKey(publicKey),
service.RemoteAddr(remoteAddr), service.LoginType("T"))
user, authStatus := userClient.Authenticate(ctx)
resp, err := service.Authenticate(username, password, publicKey, remoteAddr, "T") switch authStatus {
if err != nil { case service.AuthMFARequired:
ctx.SetValue(model.ContextKeyClient, &userClient)
action = actionPartialAccepted
res = ssh.AuthPartiallySuccessful
case service.AuthSuccess:
res = ssh.AuthSuccessful
ctx.SetValue(model.ContextKeyUser, &user)
default:
action = actionFailed action = actionFailed
logger.Infof("%s %s for %s from %s", action, authMethod, username, remoteAddr)
return
}
if resp != nil {
switch resp.User.OTPLevel {
case 0:
res = ssh.AuthSuccessful
case 1, 2:
action = actionPartialAccepted
res = ssh.AuthPartiallySuccessful
default:
}
ctx.SetValue(cctx.ContextKeyUser, resp.User)
ctx.SetValue(cctx.ContextKeySeed, resp.Seed)
ctx.SetValue(cctx.ContextKeyToken, resp.Token)
} }
logger.Infof("%s %s for %s from %s", action, authMethod, username, remoteAddr) logger.Infof("%s %s for %s from %s", action, authMethod, username, remoteAddr)
return res return
} }
func CheckUserPassword(ctx ssh.Context, password string) ssh.AuthResult { func CheckUserPassword(ctx ssh.Context, password string) ssh.AuthResult {
...@@ -90,19 +88,19 @@ func CheckMFA(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) (r ...@@ -90,19 +88,19 @@ func CheckMFA(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) (r
return return
} }
mfaCode := answers[0] mfaCode := answers[0]
seed, ok := ctx.Value(cctx.ContextKeySeed).(string) client, ok := ctx.Value(model.ContextKeyClient).(*service.SessionClient)
if !ok { if !ok {
logger.Error("Mfa Auth failed, may be user password or publickey auth failed") logger.Errorf("User %s Mfa Auth failed: not found session client.", username, )
return return
} }
resp, err := service.CheckUserOTP(seed, mfaCode, remoteAddr, "T") user, authStatus := client.CheckUserOTP(ctx, mfaCode)
if err != nil { switch authStatus {
logger.Error("Mfa Auth failed: ", err) case service.AuthSuccess:
return
}
if resp.Token != "" {
res = ssh.AuthSuccessful res = ssh.AuthSuccessful
return ctx.SetValue(model.ContextKeyUser, &user)
logger.Infof("User %s Mfa Auth success", username)
default:
logger.Errorf("User %s Mfa Auth failed", username)
} }
return return
} }
......
package cctx
import (
"context"
"github.com/gliderlabs/ssh"
"github.com/jumpserver/koko/pkg/model"
)
type contextKey struct {
name string
}
var (
ContextKeyUser = &contextKey{"user"}
ContextKeyAsset = &contextKey{"asset"}
ContextKeySystemUser = &contextKey{"systemUser"}
ContextKeySSHSession = &contextKey{"sshSession"}
ContextKeyLocalAddr = &contextKey{"localAddr"}
ContextKeyRemoteAddr = &contextKey{"RemoteAddr"}
ContextKeySSHCtx = &contextKey{"sshCtx"}
ContextKeySeed = &contextKey{"seed"}
ContextKeyToken = &contextKey{"token"}
)
type Context interface {
context.Context
User() *model.User
Asset() *model.Asset
SystemUser() *model.SystemUser
SSHSession() *ssh.Session
SSHCtx() *ssh.Context
SetValue(key, value interface{})
}
// Context coco内部使用的Context
type CocoContext struct {
context.Context
}
// user 返回当前连接的用户model
func (ctx *CocoContext) User() *model.User {
return ctx.Value(ContextKeyUser).(*model.User)
}
func (ctx *CocoContext) Asset() *model.Asset {
return ctx.Value(ContextKeyAsset).(*model.Asset)
}
func (ctx *CocoContext) SystemUser() *model.SystemUser {
return ctx.Value(ContextKeySystemUser).(*model.SystemUser)
}
func (ctx *CocoContext) SSHSession() *ssh.Session {
return ctx.Value(ContextKeySSHSession).(*ssh.Session)
}
func (ctx *CocoContext) SSHCtx() *ssh.Context {
return ctx.Value(ContextKeySSHCtx).(*ssh.Context)
}
func (ctx *CocoContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
}
func applySessionMetadata(ctx *CocoContext, sess ssh.Session) {
ctx.SetValue(ContextKeySSHSession, &sess)
ctx.SetValue(ContextKeySSHCtx, sess.Context())
ctx.SetValue(ContextKeyLocalAddr, sess.LocalAddr())
}
func NewContext(sess ssh.Session) (*CocoContext, context.CancelFunc) {
sshCtx, cancel := context.WithCancel(sess.Context())
ctx := &CocoContext{sshCtx}
applySessionMetadata(ctx, sess)
return ctx, cancel
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"io/ioutil" "io/ioutil"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/cookiejar"
neturl "net/url" neturl "net/url"
"os" "os"
"path/filepath" "path/filepath"
...@@ -38,8 +39,10 @@ type UrlParser interface { ...@@ -38,8 +39,10 @@ type UrlParser interface {
func NewClient(timeout time.Duration, baseHost string) Client { func NewClient(timeout time.Duration, baseHost string) Client {
headers := make(map[string]string) headers := make(map[string]string)
jar, _ := cookiejar.New(nil)
client := http.Client{ client := http.Client{
Timeout: timeout * time.Second, Timeout: timeout * time.Second,
Jar: jar,
} }
return Client{ return Client{
BaseHost: baseHost, BaseHost: baseHost,
...@@ -103,11 +106,10 @@ func (c *Client) parseUrl(url string, params []map[string]string) string { ...@@ -103,11 +106,10 @@ func (c *Client) parseUrl(url string, params []map[string]string) string {
func (c *Client) setAuthHeader(r *http.Request) { func (c *Client) setAuthHeader(r *http.Request) {
if len(c.cookie) != 0 { if len(c.cookie) != 0 {
cookie := make([]string, 0)
for k, v := range c.cookie { for k, v := range c.cookie {
cookie = append(cookie, fmt.Sprintf("%s=%s", k, v)) c := http.Cookie{Name: k, Value: v,}
r.AddCookie(&c)
} }
r.Header.Add("Cookie", strings.Join(cookie, ";"))
} }
if len(c.basicAuth) == 2 { if len(c.basicAuth) == 2 {
r.SetBasicAuth(c.basicAuth[0], c.basicAuth[1]) r.SetBasicAuth(c.basicAuth[0], c.basicAuth[1])
...@@ -162,10 +164,9 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str ...@@ -162,10 +164,9 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if resp.StatusCode >= 400 { if resp.StatusCode >= 500 {
msg := fmt.Sprintf("%s %s failed, get code: %d, %s", req.Method, req.URL, resp.StatusCode, string(body)) msg := fmt.Sprintf("%s %s failed, get code: %d, %s", req.Method, req.URL, resp.StatusCode, body)
err = errors.New(msg) err = errors.New(msg)
return return
} }
...@@ -176,7 +177,7 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str ...@@ -176,7 +177,7 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str
return return
} }
// Unmarshal response body to result struct // Unmarshal response body to result struct
if res != nil && resp.StatusCode >= 200 && resp.StatusCode <= 300 { if res != nil && strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
err = json.Unmarshal(body, res) err = json.Unmarshal(body, res)
if err != nil { if err != nil {
msg := fmt.Sprintf("%s %s failed, unmarshal '%s' response failed: %s", req.Method, req.URL, body[:12], err) msg := fmt.Sprintf("%s %s failed, unmarshal '%s' response failed: %s", req.Method, req.URL, body[:12], err)
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/xlab/treeprint" "github.com/xlab/treeprint"
"github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/common" "github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/config"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
...@@ -21,13 +20,16 @@ import ( ...@@ -21,13 +20,16 @@ import (
) )
func SessionHandler(sess ssh.Session) { func SessionHandler(sess ssh.Session) {
user, ok := sess.Context().Value(model.ContextKeyUser).(*model.User)
if !ok && user == nil {
logger.Errorf("SSH User %s not found, exit.", sess.User())
return
}
pty, _, ok := sess.Pty() pty, _, ok := sess.Pty()
if ok { if ok {
ctx, cancel := cctx.NewContext(sess) handler := newInteractiveHandler(sess, user)
defer cancel()
handler := newInteractiveHandler(sess, ctx.User())
logger.Infof("Request %s: User %s request pty %s", handler.sess.ID(), sess.User(), pty.Term) logger.Infof("Request %s: User %s request pty %s", handler.sess.ID(), sess.User(), pty.Term)
handler.Dispatch(ctx) handler.Dispatch(sess.Context())
} else { } else {
utils.IgnoreErrWriteString(sess, "No PTY requested.\n") utils.IgnoreErrWriteString(sess, "No PTY requested.\n")
return return
...@@ -66,7 +68,7 @@ type interactiveHandler struct { ...@@ -66,7 +68,7 @@ type interactiveHandler struct {
func (h *interactiveHandler) Initial() { func (h *interactiveHandler) Initial() {
h.assetLoadPolicy = strings.ToLower(config.GetConf().AssetLoadPolicy) h.assetLoadPolicy = strings.ToLower(config.GetConf().AssetLoadPolicy)
h.displayBanner() h.displayBanner()
h.winWatchChan = make(chan bool) h.winWatchChan = make(chan bool, 1)
h.loadDataDone = make(chan struct{}) h.loadDataDone = make(chan struct{})
go h.firstLoadData() go h.firstLoadData()
} }
...@@ -130,7 +132,7 @@ func (h *interactiveHandler) resumeWatchWinSize() { ...@@ -130,7 +132,7 @@ func (h *interactiveHandler) resumeWatchWinSize() {
h.winWatchChan <- true h.winWatchChan <- true
} }
func (h *interactiveHandler) Dispatch(ctx cctx.Context) { func (h *interactiveHandler) Dispatch(ctx context.Context) {
go h.watchWinSizeChange() go h.watchWinSizeChange()
defer logger.Infof("Request %s: User %s stop interactive", h.sess.ID(), h.user.Name) defer logger.Infof("Request %s: User %s stop interactive", h.sess.ID(), h.user.Name)
for { for {
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"github.com/pkg/sftp" "github.com/pkg/sftp"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
"github.com/jumpserver/koko/pkg/service" "github.com/jumpserver/koko/pkg/service"
...@@ -20,10 +19,13 @@ import ( ...@@ -20,10 +19,13 @@ import (
) )
func SftpHandler(sess ssh.Session) { func SftpHandler(sess ssh.Session) {
ctx, cancel := cctx.NewContext(sess) currentUser, ok := sess.Context().Value(model.ContextKeyUser).(*model.User)
defer cancel() if !ok {
logger.Errorf("SFTP User not found, exit.")
return
}
host, _, _ := net.SplitHostPort(sess.RemoteAddr().String()) host, _, _ := net.SplitHostPort(sess.RemoteAddr().String())
userSftp := NewSFTPHandler(ctx.User(), host) userSftp := NewSFTPHandler(currentUser, host)
handlers := sftp.Handlers{ handlers := sftp.Handlers{
FileGet: userSftp, FileGet: userSftp,
FilePut: userSftp, FilePut: userSftp,
...@@ -34,7 +36,7 @@ func SftpHandler(sess ssh.Session) { ...@@ -34,7 +36,7 @@ func SftpHandler(sess ssh.Session) {
logger.Infof("SFTP request %s: Handler start", reqID) logger.Infof("SFTP request %s: Handler start", reqID)
req := sftp.NewRequestServer(sess, handlers) req := sftp.NewRequestServer(sess, handlers)
if err := req.Serve(); err == io.EOF { if err := req.Serve(); err == io.EOF {
logger.Debug("SFTP request %s: Exited session.", reqID) logger.Debugf("SFTP request %s: Exited session.", reqID)
} else if err != nil { } else if err != nil {
logger.Errorf("SFTP request %s: Server completed with error %s", reqID, err) logger.Errorf("SFTP request %s: Server completed with error %s", reqID, err)
} }
...@@ -82,7 +84,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) { ...@@ -82,7 +84,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) {
case "Setstat": case "Setstat":
return return
case "Rename": case "Rename":
logger.Debug("%s=>%s", r.Filepath, r.Target) logger.Debugf("%s=>%s", r.Filepath, r.Target)
return fs.Rename(r.Filepath, r.Target) return fs.Rename(r.Filepath, r.Target)
case "Rmdir": case "Rmdir":
err = fs.RemoveDirectory(r.Filepath) err = fs.RemoveDirectory(r.Filepath)
...@@ -91,7 +93,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) { ...@@ -91,7 +93,7 @@ func (fs *sftpHandler) Filecmd(r *sftp.Request) (err error) {
case "Mkdir": case "Mkdir":
err = fs.MkdirAll(r.Filepath) err = fs.MkdirAll(r.Filepath)
case "Symlink": case "Symlink":
logger.Debug("%s=>%s", r.Filepath, r.Target) logger.Debugf("%s=>%s", r.Filepath, r.Target)
err = fs.Symlink(r.Filepath, r.Target) err = fs.Symlink(r.Filepath, r.Target)
default: default:
return return
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"github.com/LeeEirc/elfinder" "github.com/LeeEirc/elfinder"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jumpserver/koko/pkg/cctx"
"github.com/jumpserver/koko/pkg/common" "github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/config" "github.com/jumpserver/koko/pkg/config"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
...@@ -45,8 +44,8 @@ func AuthDecorator(handler http.HandlerFunc) http.HandlerFunc { ...@@ -45,8 +44,8 @@ func AuthDecorator(handler http.HandlerFunc) http.HandlerFunc {
} else { } else {
remoteIP = strings.Split(request.RemoteAddr, ":")[0] remoteIP = strings.Split(request.RemoteAddr, ":")[0]
} }
ctx := context.WithValue(request.Context(), cctx.ContextKeyUser, user) ctx := context.WithValue(request.Context(), model.ContextKeyUser, user)
ctx = context.WithValue(ctx, cctx.ContextKeyRemoteAddr, remoteIP) ctx = context.WithValue(ctx, model.ContextKeyRemoteAddr, remoteIP)
handler(responseWriter, request.WithContext(ctx)) handler(responseWriter, request.WithContext(ctx))
} }
} }
...@@ -66,8 +65,8 @@ func sftpFinder(wr http.ResponseWriter, req *http.Request) { ...@@ -66,8 +65,8 @@ func sftpFinder(wr http.ResponseWriter, req *http.Request) {
func sftpHostConnectorView(wr http.ResponseWriter, req *http.Request) { func sftpHostConnectorView(wr http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req) vars := mux.Vars(req)
hostID := vars["host"] hostID := vars["host"]
user := req.Context().Value(cctx.ContextKeyUser).(*model.User) user := req.Context().Value(model.ContextKeyUser).(*model.User)
remoteIP := req.Context().Value(cctx.ContextKeyRemoteAddr).(string) remoteIP := req.Context().Value(model.ContextKeyRemoteAddr).(string)
switch req.Method { switch req.Method {
case "GET": case "GET":
if err := req.ParseForm(); err != nil { if err := req.ParseForm(); err != nil {
......
package model
type contextKey int64
const (
ContextKeyUser contextKey = iota + 1
ContextKeyRemoteAddr
ContextKeyClient
)
...@@ -18,12 +18,6 @@ package model ...@@ -18,12 +18,6 @@ package model
'date_expired': '2089-03-21 18:18:24 +0800'} 'date_expired': '2089-03-21 18:18:24 +0800'}
*/ */
type AuthResponse struct {
Token string `json:"token"`
Seed string `json:"seed"`
User *User `json:"user"`
}
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
......
...@@ -70,10 +70,10 @@ func (ak *AccessKey) SaveToFile() error { ...@@ -70,10 +70,10 @@ func (ak *AccessKey) SaveToFile() error {
} }
} }
f, err := os.Create(ak.Path) f, err := os.Create(ak.Path)
defer f.Close()
if err != nil { if err != nil {
return err return err
} }
defer f.Close()
_, err = f.WriteString(fmt.Sprintf("%s:%s", ak.ID, ak.Secret)) _, err = f.WriteString(fmt.Sprintf("%s:%s", ak.ID, ak.Secret))
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
......
package service
import (
"sync"
"github.com/jumpserver/koko/pkg/model"
)
type assetsCacheContainer struct {
mapData map[string]model.AssetList
mapETag map[string]string
mu *sync.RWMutex
}
func (c *assetsCacheContainer) Get(key string) (model.AssetList, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapData[key]
return value, ok
}
func (c *assetsCacheContainer) GetETag(key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapETag[key]
return value, ok
}
func (c *assetsCacheContainer) SetValue(key string, value model.AssetList) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapData[key] = value
}
func (c *assetsCacheContainer) SetETag(key string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapETag[key] = value
}
type nodesCacheContainer struct {
mapData map[string]model.NodeList
mapETag map[string]string
mu *sync.RWMutex
}
func (c *nodesCacheContainer) Get(key string) (model.NodeList, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapData[key]
return value, ok
}
func (c *nodesCacheContainer) GetETag(key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.mapETag[key]
return value, ok
}
func (c *nodesCacheContainer) SetValue(key string, value model.NodeList) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapData[key] = value
}
func (c *nodesCacheContainer) SetETag(key string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
c.mapETag[key] = value
}
package service
const (
ErrLoginConfirmWait = "login_confirm_wait"
ErrLoginConfirmRejected = "login_confirm_rejected"
ErrLoginConfirmRequired = "login_confirm_required"
ErrMFARequired = "mfa_required"
ErrPasswordFailed = "password_failed"
)
...@@ -13,15 +13,12 @@ import ( ...@@ -13,15 +13,12 @@ import (
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
) )
var client = common.NewClient(30, "")
var authClient = common.NewClient(30, "") var authClient = common.NewClient(30, "")
func Initial(ctx context.Context) { func Initial(ctx context.Context) {
cf := config.GetConf() cf := config.GetConf()
keyPath := cf.AccessKeyFile keyPath := cf.AccessKeyFile
client.BaseHost = cf.CoreHost
authClient.BaseHost = cf.CoreHost authClient.BaseHost = cf.CoreHost
client.SetHeader("X-JMS-ORG", "ROOT")
authClient.SetHeader("X-JMS-ORG", "ROOT") authClient.SetHeader("X-JMS-ORG", "ROOT")
if !path.IsAbs(cf.AccessKeyFile) { if !path.IsAbs(cf.AccessKeyFile) {
......
package service
type AuthStatus int64
const (
AuthSuccess AuthStatus = iota + 1
AuthFailed
AuthMFARequired
)
type SessionOption func(*SessionOptions)
func Username(username string) SessionOption {
return func(args *SessionOptions) {
args.Username = username
}
}
func Password(password string) SessionOption {
return func(args *SessionOptions) {
args.Password = password
}
}
func PublicKey(publicKey string) SessionOption {
return func(args *SessionOptions) {
args.PublicKey = publicKey
}
}
func RemoteAddr(remoteAddr string) SessionOption {
return func(args *SessionOptions) {
args.RemoteAddr = remoteAddr
}
}
func LoginType(loginType string) SessionOption {
return func(args *SessionOptions) {
args.LoginType = loginType
}
}
type SessionOptions struct {
Username string
Password string
PublicKey string
RemoteAddr string
LoginType string
}
...@@ -8,9 +8,7 @@ import ( ...@@ -8,9 +8,7 @@ import (
) )
func RegisterTerminal(name, token, comment string) (res model.Terminal) { func RegisterTerminal(name, token, comment string) (res model.Terminal) {
if client.Headers == nil { client := newClient()
client.Headers = make(map[string]string)
}
client.Headers["Authorization"] = fmt.Sprintf("BootstrapToken %s", token) client.Headers["Authorization"] = fmt.Sprintf("BootstrapToken %s", token)
data := map[string]string{"name": name, "comment": comment} data := map[string]string{"name": name, "comment": comment}
_, err := client.Post(TerminalRegisterURL, data, &res) _, err := client.Post(TerminalRegisterURL, data, &res)
......
...@@ -38,3 +38,9 @@ const ( ...@@ -38,3 +38,9 @@ const (
const ( const (
UserAssetSystemUsersURL = "/api/v1/perms/users/%s/assets/%s/system-users/" // 获取用户授权资产的系统用户列表 UserAssetSystemUsersURL = "/api/v1/perms/users/%s/assets/%s/system-users/" // 获取用户授权资产的系统用户列表
) )
// 1.5.5
const (
UserTokenAuthURL = "/api/v1/authentication/tokens/" // 用户登录验证
UserConfirmAuthURL = "/api/v1/authentication/order/auth/"
)
package service package service
import ( import (
"context"
"fmt" "fmt"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/jumpserver/koko/pkg/common"
"github.com/jumpserver/koko/pkg/logger" "github.com/jumpserver/koko/pkg/logger"
"github.com/jumpserver/koko/pkg/model" "github.com/jumpserver/koko/pkg/model"
) )
type AuthResp struct { type AuthResponse struct {
Token string `json:"token"` Err string `json:"error,omitempty"`
Seed string `json:"seed"` Msg string `json:"msg,omitempty"`
User *model.User `json:"user"` Data ResponseData `json:"data,omitempty"`
Username string `json:"username,omitempty"`
Token string `json:"token,omitempty"`
Keyword string `json:"keyword,omitempty"`
DateExpired string `json:"date_expired,omitempty"`
}
type ResponseData struct {
Choices []string `json:"choices,omitempty"`
Url string `json:"url,omitempty"`
}
type AuthOptions struct {
Name string
Url string
}
func NewSessionClient(setters ...SessionOption) SessionClient {
option := &SessionOptions{}
for _, setter := range setters {
setter(option)
}
conn := newClient()
return SessionClient{
option: option,
client: &conn,
authOptions: make(map[string]AuthOptions),
}
}
type SessionClient struct {
option *SessionOptions
client *common.Client
authOptions map[string]AuthOptions
}
func (u *SessionClient) Authenticate(ctx context.Context) (user model.User, authStatus AuthStatus) {
authStatus = AuthFailed
data := map[string]string{
"username": u.option.Username,
"password": u.option.Password,
"public_key": u.option.PublicKey,
"remote_addr": u.option.RemoteAddr,
"login_type": u.option.LoginType,
}
var resp AuthResponse
_, err := u.client.Post(UserTokenAuthURL, data, &resp)
if err != nil {
logger.Errorf("User %s Authenticate err: %s", u.option.Username, err)
return
}
fmt.Printf("%v\n", resp)
if resp.Err != "" {
switch resp.Err {
case ErrLoginConfirmRequired:
if !u.checkConfirm(ctx) {
logger.Errorf("User %s login confirm required err", u.option.Username)
return
}
logger.Infof("User %s login confirm required success", u.option.Username)
authStatus = AuthSuccess
case ErrLoginConfirmWait:
if !u.checkConfirm(ctx) {
logger.Errorf("User %s login confirm Wait check err", u.option.Username)
return
}
logger.Infof("User %s login confirm wait check success", u.option.Username)
authStatus = AuthSuccess
case ErrMFARequired:
for _, item := range resp.Data.Choices {
u.authOptions[item] = AuthOptions{
Name: item,
Url: resp.Data.Url,
}
}
logger.Infof("User %s login need MFA", u.option.Username)
authStatus = AuthMFARequired
}
return
}
if resp.Token != "" {
return user, AuthSuccess
}
return
} }
func Authenticate(username, password, publicKey, remoteAddr, loginType string) (resp *AuthResp, err error) { func (u *SessionClient) CheckUserOTP(ctx context.Context, code string) (user model.User, authStatus AuthStatus) {
var err error
authStatus = AuthFailed
data := map[string]string{ data := map[string]string{
"username": username, "code": code,
"password": password, }
"public_key": publicKey, for name, authData := range u.authOptions {
"remote_addr": remoteAddr, var resp AuthResponse
"login_type": loginType, switch name {
case "opt":
data["type"] = name
}
_, err = u.client.Post(authData.Url, data, &resp)
if err != nil {
return
}
if resp.Err != "" {
return
}
if resp.Msg == "ok" {
return u.Authenticate(ctx)
}
} }
_, err = client.Post(UserAuthURL, data, &resp)
return return
} }
func (u *SessionClient) checkConfirm(ctx context.Context) bool {
doneChan := make(chan bool, 1)
go func() {
var err error
for {
select {
case <-ctx.Done():
doneChan <- false
case <-time.After(5 * time.Second):
var resp AuthResponse
_, err = u.client.Get(UserConfirmAuthURL, &resp)
if err != nil {
logger.Errorf("User %s check confirm err: %s", u.option.Username, err)
doneChan <- false
return
}
if resp.Err != "" {
switch resp.Err {
case ErrLoginConfirmWait:
logger.Infof("User %s wait confirm", u.option.Username)
continue
case ErrLoginConfirmRejected:
default:
}
logger.Infof("User %s confirm rejected %s", u.option.Username, resp.Err)
doneChan <- false
return
}
if resp.Msg == "ok" {
logger.Infof("User %s confirm accepted", u.option.Username)
doneChan <- true
return
}
}
}
}()
return <-doneChan
}
func GetUserDetail(userID string) (user *model.User) { func GetUserDetail(userID string) (user *model.User) {
Url := fmt.Sprintf(UserDetailURL, userID) Url := fmt.Sprintf(UserDetailURL, userID)
_, err := authClient.Get(Url, &user) _, err := authClient.Get(Url, &user)
...@@ -56,20 +199,6 @@ func GetUserByUsername(username string) (user *model.User, err error) { ...@@ -56,20 +199,6 @@ func GetUserByUsername(username string) (user *model.User, err error) {
return return
} }
func CheckUserOTP(seed, code, remoteAddr, loginType string) (resp *AuthResp, err error) {
data := map[string]string{
"seed": seed,
"otp_code": code,
"remote_addr": remoteAddr,
"login_type": loginType,
}
_, err = client.Post(UserAuthOTPURL, data, &resp)
if err != nil {
return
}
return
}
func CheckUserCookie(sessionID, csrfToken string) (user *model.User, err error) { func CheckUserCookie(sessionID, csrfToken string) (user *model.User, err error) {
cli := newClient() cli := newClient()
cli.SetCookie("csrftoken", csrfToken) cli.SetCookie("csrftoken", csrfToken)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment