Commit e69d14c5 authored by Eric's avatar Eric

[Update] add tips for confirmation

parent 29121de5
...@@ -2,6 +2,7 @@ package auth ...@@ -2,6 +2,7 @@ package auth
import ( import (
"net" "net"
"strings"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
...@@ -16,6 +17,9 @@ import ( ...@@ -16,6 +17,9 @@ import (
var mfaInstruction = "Please enter 6 digits." var mfaInstruction = "Please enter 6 digits."
var mfaQuestion = "[MFA auth]: " var mfaQuestion = "[MFA auth]: "
var confirmInstruction = "Please wait for your admin to confirm."
var confirmQuestion = "[YES or NO]: "
const ( const (
actionAccepted = "Accepted" actionAccepted = "Accepted"
actionFailed = "Failed" actionFailed = "Failed"
...@@ -34,20 +38,24 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult) ...@@ -34,20 +38,24 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult)
userClient, ok := ctx.Value(model.ContextKeyClient).(*service.SessionClient) userClient, ok := ctx.Value(model.ContextKeyClient).(*service.SessionClient)
if !ok { if !ok {
sessionClient := service.NewSessionClient(service.Username(username), sessionClient := service.NewSessionClient(service.Username(username),
service.Password(password), service.PublicKey(publicKey),
service.RemoteAddr(remoteAddr), service.LoginType("T")) service.RemoteAddr(remoteAddr), service.LoginType("T"))
userClient = &sessionClient userClient = &sessionClient
ctx.SetValue(model.ContextKeyClient, userClient)
} }
userClient.SetOption(service.Password(password), service.PublicKey(publicKey))
user, authStatus := userClient.Authenticate(ctx) user, authStatus := userClient.Authenticate(ctx)
switch authStatus { switch authStatus {
case service.AuthMFARequired: case service.AuthMFARequired:
ctx.SetValue(model.ContextKeyClient, &userClient)
action = actionPartialAccepted action = actionPartialAccepted
res = ssh.AuthPartiallySuccessful res = ssh.AuthPartiallySuccessful
case service.AuthSuccess: case service.AuthSuccess:
res = ssh.AuthSuccessful res = ssh.AuthSuccessful
ctx.SetValue(model.ContextKeyUser, &user) ctx.SetValue(model.ContextKeyUser, &user)
case service.AuthConfirmRequired:
required := true
ctx.SetValue(model.ContextKeyConfirmRequired, &required)
action = actionPartialAccepted
res = ssh.AuthPartiallySuccessful
default: default:
action = actionFailed action = actionFailed
} }
...@@ -73,37 +81,64 @@ func CheckUserPublicKey(ctx ssh.Context, key ssh.PublicKey) ssh.AuthResult { ...@@ -73,37 +81,64 @@ func CheckUserPublicKey(ctx ssh.Context, key ssh.PublicKey) ssh.AuthResult {
} }
func CheckMFA(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) (res ssh.AuthResult) { func CheckMFA(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) (res ssh.AuthResult) {
if value, ok := ctx.Value(model.ContextKeyConfirmFailed).(*bool); ok && *value {
return ssh.AuthFailed
}
username := ctx.User() username := ctx.User()
remoteAddr, _, _ := net.SplitHostPort(ctx.RemoteAddr().String()) remoteAddr, _, _ := net.SplitHostPort(ctx.RemoteAddr().String())
res = ssh.AuthFailed res = ssh.AuthFailed
defer func() {
authMethod := "MFA" var confirmAction bool
if res == ssh.AuthSuccessful { instruction := mfaInstruction
action := actionAccepted question := mfaQuestion
logger.Infof("%s %s for %s from %s", action, authMethod, username, remoteAddr)
} else { client, ok := ctx.Value(model.ContextKeyClient).(*service.SessionClient)
action := actionFailed if !ok {
logger.Errorf("%s %s for %s from %s", action, authMethod, username, remoteAddr) logger.Errorf("User %s Mfa Auth failed: not found session client.", username, )
return
} }
}() value, ok := ctx.Value(model.ContextKeyConfirmRequired).(*bool)
answers, err := challenger(username, mfaInstruction, []string{mfaQuestion}, []bool{true}) if ok && *value {
confirmAction = true
instruction = confirmInstruction
question = confirmQuestion
}
answers, err := challenger(username, instruction, []string{question}, []bool{true})
if err != nil || len(answers) != 1 { if err != nil || len(answers) != 1 {
return return
} }
mfaCode := answers[0] if confirmAction {
client, ok := ctx.Value(model.ContextKeyClient).(*service.SessionClient) switch strings.TrimSpace(strings.ToLower(answers[0])) {
if !ok { case "yes", "y", "":
logger.Errorf("User %s Mfa Auth failed: not found session client.", username, ) user, authStatus := client.CheckConfirm(ctx)
switch authStatus {
case service.AuthSuccess:
res = ssh.AuthSuccessful
ctx.SetValue(model.ContextKeyUser, &user)
return return
} }
default:
client.CancelConfirm()
}
failed := true
ctx.SetValue(model.ContextKeyConfirmFailed, &failed)
return
}
mfaCode := answers[0]
user, authStatus := client.CheckUserOTP(ctx, mfaCode) user, authStatus := client.CheckUserOTP(ctx, mfaCode)
switch authStatus { switch authStatus {
case service.AuthSuccess: case service.AuthSuccess:
res = ssh.AuthSuccessful res = ssh.AuthSuccessful
ctx.SetValue(model.ContextKeyUser, &user) ctx.SetValue(model.ContextKeyUser, &user)
logger.Infof("User %s Mfa Auth success", username) logger.Infof("%s MFA for %s from %s", actionAccepted, username, remoteAddr)
case service.AuthConfirmRequired:
res = ssh.AuthPartiallySuccessful
required := true
ctx.SetValue(model.ContextKeyConfirmRequired, &required)
logger.Infof("%s MFA for %s from %s", actionPartialAccepted, username, remoteAddr)
default: default:
logger.Errorf("User %s Mfa Auth failed", username) logger.Errorf("%s MFA for %s from %s", actionFailed, username, remoteAddr)
} }
return return
} }
......
...@@ -6,4 +6,6 @@ const ( ...@@ -6,4 +6,6 @@ const (
ContextKeyUser contextKey = iota + 1 ContextKeyUser contextKey = iota + 1
ContextKeyRemoteAddr ContextKeyRemoteAddr
ContextKeyClient ContextKeyClient
ContextKeyConfirmRequired
ContextKeyConfirmFailed
) )
...@@ -6,6 +6,7 @@ const ( ...@@ -6,6 +6,7 @@ const (
AuthSuccess AuthStatus = iota + 1 AuthSuccess AuthStatus = iota + 1
AuthFailed AuthFailed
AuthMFARequired AuthMFARequired
AuthConfirmRequired
) )
type SessionOption func(*SessionOptions) type SessionOption func(*SessionOptions)
......
...@@ -55,6 +55,12 @@ type SessionClient struct { ...@@ -55,6 +55,12 @@ type SessionClient struct {
authOptions map[string]AuthOptions authOptions map[string]AuthOptions
} }
func (u *SessionClient) SetOption(setters ...SessionOption) {
for _, setter := range setters {
setter(u.option)
}
}
func (u *SessionClient) Authenticate(ctx context.Context) (user model.User, authStatus AuthStatus) { func (u *SessionClient) Authenticate(ctx context.Context) (user model.User, authStatus AuthStatus) {
authStatus = AuthFailed authStatus = AuthFailed
data := map[string]string{ data := map[string]string{
...@@ -73,12 +79,8 @@ func (u *SessionClient) Authenticate(ctx context.Context) (user model.User, auth ...@@ -73,12 +79,8 @@ func (u *SessionClient) Authenticate(ctx context.Context) (user model.User, auth
if resp.Err != "" { if resp.Err != "" {
switch resp.Err { switch resp.Err {
case ErrLoginConfirmWait: case ErrLoginConfirmWait:
if !u.checkConfirm(ctx) { logger.Infof("User %s login need confirmation", u.option.Username)
logger.Errorf("User %s login confirm required err", u.option.Username) authStatus = AuthConfirmRequired
return
}
logger.Infof("User %s login confirm required success", u.option.Username)
return u.Authenticate(ctx)
case ErrMFARequired: case ErrMFARequired:
for _, item := range resp.Data.Choices { for _, item := range resp.Data.Choices {
u.authOptions[item] = AuthOptions{ u.authOptions[item] = AuthOptions{
...@@ -129,16 +131,14 @@ func (u *SessionClient) CheckUserOTP(ctx context.Context, code string) (user mod ...@@ -129,16 +131,14 @@ func (u *SessionClient) CheckUserOTP(ctx context.Context, code string) (user mod
return return
} }
func (u *SessionClient) checkConfirm(ctx context.Context) (ok bool) { func (u *SessionClient) CheckConfirm(ctx context.Context) (user model.User, authStatus AuthStatus) {
var err error var err error
for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
_, err = u.client.Delete(UserConfirmAuthURL, nil) logger.Errorf("User %s exit and cancel confirmation", u.option.Username)
if err != nil { u.CancelConfirm()
logger.Errorf("User %s cancel confirmation err: %s", u.option.Username, err)
return return
}
logger.Infof("User %s cancel confirm request", u.option.Username)
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
var resp authResponse var resp authResponse
_, err = u.client.Get(UserConfirmAuthURL, &resp) _, err = u.client.Get(UserConfirmAuthURL, &resp)
...@@ -150,7 +150,7 @@ func (u *SessionClient) checkConfirm(ctx context.Context) (ok bool) { ...@@ -150,7 +150,7 @@ func (u *SessionClient) checkConfirm(ctx context.Context) (ok bool) {
switch resp.Err { switch resp.Err {
case ErrLoginConfirmWait: case ErrLoginConfirmWait:
logger.Infof("User %s still wait confirm", u.option.Username) logger.Infof("User %s still wait confirm", u.option.Username)
return u.checkConfirm(ctx) continue
case ErrLoginConfirmRejected: case ErrLoginConfirmRejected:
logger.Infof("User %s confirmation was rejected by admin", u.option.Username) logger.Infof("User %s confirmation was rejected by admin", u.option.Username)
default: default:
...@@ -160,10 +160,19 @@ func (u *SessionClient) checkConfirm(ctx context.Context) (ok bool) { ...@@ -160,10 +160,19 @@ func (u *SessionClient) checkConfirm(ctx context.Context) (ok bool) {
} }
if resp.Msg == "ok" { if resp.Msg == "ok" {
logger.Infof("User %s confirmation was accepted", u.option.Username) logger.Infof("User %s confirmation was accepted", u.option.Username)
return true return u.Authenticate(ctx)
}
} }
} }
}
func (u *SessionClient) CancelConfirm() {
_, err := u.client.Delete(UserConfirmAuthURL, nil)
if err != nil {
logger.Errorf("Cancel User %s confirmation err: %s", u.option.Username, err)
return return
}
logger.Infof("Cancel User %s confirmation success", u.option.Username)
} }
func GetUserDetail(userID string) (user *model.User) { func GetUserDetail(userID string) (user *model.User) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment