Commit 334493e0 authored by Eric's avatar Eric

fix conflicts

parents c489f6c4 b9a7b49b
......@@ -11,11 +11,11 @@
[[projects]]
branch = "dev"
digest = "1:6385b015971978bc1cc76802e9bc7b29984c104d35fe941ef311d0e758641a2b"
digest = "1:b728156c8c642481f6b30c46bfd5f6117f7e3c8495b1f0c47795e2db958c6729"
name = "github.com/gliderlabs/ssh"
packages = ["."]
pruneopts = "UT"
revision = "624e955a0ae6c0e82fe0095ad5276051483ae67d"
revision = "1c00c8e8b607f53b40cdcb98816d682b0d007d12"
source = "github.com/ibuler/ssh"
[[projects]]
......
package auth
import (
"cocogo/pkg/cctx"
"cocogo/pkg/i18n"
"strings"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
"cocogo/pkg/cctx"
"cocogo/pkg/common"
"cocogo/pkg/logger"
"cocogo/pkg/service"
)
var mfaInstruction = i18n.T("Please enter 6 digits.")
var mfaQuestion = i18n.T("[MFA auth]: ")
var contentKeyMFASeed = "MFASeed"
func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult) {
username := ctx.User()
remoteAddr := strings.Split(ctx.RemoteAddr().String(), ":")[0]
......@@ -24,7 +30,12 @@ func checkAuth(ctx ssh.Context, password, publicKey string) (res ssh.AuthResult)
}
if err != nil {
action = "Failed"
res = ssh.AuthFailed
} else if resp.Seed != "" && resp.Token == "" {
ctx.SetValue(contentKeyMFASeed, resp.Seed)
res = ssh.AuthPartiallySuccessful
} else {
res = ssh.AuthSuccessful
}
if resp != nil {
switch resp.User.IsMFA {
......@@ -57,17 +68,41 @@ func CheckUserPublicKey(ctx ssh.Context, key ssh.PublicKey) ssh.AuthResult {
}
func CheckMFA(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) ssh.AuthResult {
answers, err := challenger(ctx.User(), "Please enter 6 digits.", []string{"[MFA auth]: "}, []bool{true})
username := ctx.User()
answers, err := challenger(username, mfaInstruction, []string{mfaQuestion}, []bool{true})
if err != nil {
return ssh.AuthFailed
}
seed := ctx.Value(cctx.ContextKeySeed).(string)
code := answers[0]
res, err := service.AuthenticateMFA(seed, code, "T")
if err != nil || res != nil {
if len(answers) != 0 {
return ssh.AuthFailed
}
mfaCode := answers[0]
seed, ok := ctx.Value(contentKeyMFASeed).(string)
if !ok {
logger.Error("Mfa Auth failed, may be user password or publickey auth failed")
return ssh.AuthFailed
}
resp, err := service.CheckUserOTP(seed, mfaCode)
//ok := checkAuth(ctx, "admin", "")
if err != nil {
logger.Error("Mfa Auth failed: ", err)
return ssh.AuthFailed
}
if resp.Token != "" {
return ssh.AuthSuccessful
}
return ssh.AuthFailed
}
func CheckUserNeedMFA(ctx ssh.Context) (methods []string) {
username := ctx.User()
user, err := service.GetUserByUsername(username)
if err != nil {
return
}
if user.OTPLevel > 0 {
return []string{"keyboard-interactive"}
}
return
}
......@@ -68,28 +68,32 @@ func (c *Client) marshalData(data interface{}) (reader io.Reader, error error) {
return
}
func (c *Client) ParseUrlQuery(url string, query map[string]string) string {
var paramSlice []string
for k, v := range query {
paramSlice = append(paramSlice, fmt.Sprintf("%s=%s", k, v))
func (c *Client) parseUrlQuery(url string, params []map[string]string) string {
if len(params) != 1 {
return url
}
var query []string
for k, v := range params[0] {
query = append(query, fmt.Sprintf("%s=%s", k, v))
}
param := strings.Join(paramSlice, "&")
param := strings.Join(query, "&")
if strings.Contains(url, "?") {
url += "&" + param
} else {
url += "?" + param
}
if c.BaseHost != "" {
url = strings.TrimRight(c.BaseHost, "/") + url
}
return url
}
func (c *Client) ParseUrl(url string) string {
func (c *Client) ParseUrl(url string, params []map[string]string) string {
url = c.parseUrlQuery(url, params)
if c.BaseHost != "" {
url = strings.TrimRight(c.BaseHost, "/") + url
}
return url
}
func (c *Client) SetAuthHeader(r *http.Request, params ...map[string]string) {
func (c *Client) setAuthHeader(r *http.Request) {
if len(c.cookie) != 0 {
cookie := make([]string, 0)
for k, v := range c.cookie {
......@@ -108,7 +112,7 @@ func (c *Client) SetAuthHeader(r *http.Request, params ...map[string]string) {
}
}
func (c *Client) SetReqHeaders(req *http.Request, params ...map[string]string) {
func (c *Client) SetReqHeaders(req *http.Request) {
if len(c.Headers) != 0 {
for k, v := range c.Headers {
req.Header.Set(k, v)
......@@ -118,11 +122,11 @@ func (c *Client) SetReqHeaders(req *http.Request, params ...map[string]string) {
req.Header.Set("Content-Type", "application/json")
}
req.Header.Set("user-Agent", "coco-client")
c.SetAuthHeader(req)
c.setAuthHeader(req)
}
func (c *Client) NewRequest(method, url string, body interface{}) (req *http.Request, err error) {
url = c.ParseUrl(url)
func (c *Client) NewRequest(method, url string, body interface{}, params []map[string]string) (req *http.Request, err error) {
url = c.ParseUrl(url, params)
reader, err := c.marshalData(body)
if err != nil {
return
......@@ -136,7 +140,7 @@ func (c *Client) NewRequest(method, url string, body interface{}) (req *http.Req
// params:
// 1. query string if set {"name": "ibuler"}
func (c *Client) Do(method, url string, data, res interface{}, params ...map[string]string) (err error) {
req, err := c.NewRequest(method, url, data)
req, err := c.NewRequest(method, url, data, params)
resp, err := c.http.Do(req)
if err != nil {
return
......@@ -154,7 +158,7 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str
if res != nil {
err = json.Unmarshal(body, res)
if err != nil {
msg := fmt.Sprintf("%s %s failed, unmarshal `%s` response failed", req.Method, req.URL, string(body)[:50])
msg := fmt.Sprintf("%s %s failed, unmarshal '%s' response failed: %s", req.Method, req.URL, body, err)
err = errors.New(msg)
return
}
......@@ -162,24 +166,24 @@ func (c *Client) Do(method, url string, data, res interface{}, params ...map[str
return
}
func (c *Client) Get(url string, res interface{}) (err error) {
return c.Do("GET", url, nil, res)
func (c *Client) Get(url string, res interface{}, params ...map[string]string) (err error) {
return c.Do("GET", url, nil, res, params...)
}
func (c *Client) Post(url string, data interface{}, res interface{}) (err error) {
return c.Do("POST", url, data, res)
func (c *Client) Post(url string, data interface{}, res interface{}, params ...map[string]string) (err error) {
return c.Do("POST", url, data, res, params...)
}
func (c *Client) Delete(url string, res interface{}) (err error) {
return c.Do("DELETE", url, nil, res)
func (c *Client) Delete(url string, res interface{}, params ...map[string]string) (err error) {
return c.Do("DELETE", url, nil, res, params...)
}
func (c *Client) Put(url string, data interface{}, res interface{}) (err error) {
return c.Do("PUT", url, data, res)
func (c *Client) Put(url string, data interface{}, res interface{}, params ...map[string]string) (err error) {
return c.Do("PUT", url, data, res, params...)
}
func (c *Client) Patch(url string, data interface{}, res interface{}) (err error) {
return c.Do("PATCH", url, data, res)
func (c *Client) Patch(url string, data interface{}, res interface{}, params ...map[string]string) (err error) {
return c.Do("PATCH", url, data, res, params...)
}
func (c *Client) PostForm(url string, data interface{}, res interface{}) (err error) {
......
......@@ -12,7 +12,8 @@ import (
const (
username = "admin"
password = "admin"
usersUrl = "http://localhost/api/v1/users"
baseHost = "http://localhost"
usersUrl = "/api/v1/users"
)
type User struct {
......@@ -26,7 +27,7 @@ var user = User{ID: 2, Name: "Jumpserver", Age: 5}
var userDeleteUrl = fmt.Sprintf("%s/%d", usersUrl, user.ID)
func TestClient_Do(t *testing.T) {
c := NewClient(10)
c := NewClient(10, "")
err := c.Do("GET", usersUrl, nil, nil)
if err == nil {
t.Error("Failed Do(), want get err but not")
......@@ -43,7 +44,7 @@ func TestClient_Do(t *testing.T) {
}
func TestClient_Get(t *testing.T) {
c := NewClient(10)
c := NewClient(10, baseHost)
err := c.Get(usersUrl, nil)
if err == nil {
t.Errorf("Failed Get(%s): want get err but not", usersUrl)
......@@ -56,7 +57,7 @@ func TestClient_Get(t *testing.T) {
}
func TestClient_Post(t *testing.T) {
c := NewClient(10)
c := NewClient(10, baseHost)
var userCreated User
err := c.Post(usersUrl, user, &userCreated)
if err != nil {
......@@ -68,7 +69,7 @@ func TestClient_Post(t *testing.T) {
}
func TestClient_Put(t *testing.T) {
c := NewClient(10)
c := NewClient(10, "")
var userUpdated User
err := c.Put(usersUrl, user, &userUpdated)
if err != nil {
......@@ -80,7 +81,7 @@ func TestClient_Put(t *testing.T) {
}
func TestClient_Delete(t *testing.T) {
c := NewClient(10)
c := NewClient(10, baseHost)
c.SetBasicAuth(username, password)
err := c.Delete(userDeleteUrl, nil)
if err != nil {
......
......@@ -26,10 +26,10 @@ type AuthResponse struct {
type User struct {
Id string `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
OTPLevel int `json:"otp_level"`
Role string `json:"role"`
IsValid bool `json:"is_valid"`
IsActive bool `json:"is_active"`
......
......@@ -8,7 +8,7 @@ import (
)
func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.SystemUserAuthInfo) {
Url := authClient.ParseUrlQuery(fmt.Sprintf(SystemUserAssetAuthURL, systemUserID, assetID), nil)
Url := fmt.Sprintf(SystemUserAssetAuthURL, systemUserID, assetID)
err := authClient.Get(Url, &info)
if err != nil {
logger.Error("Get system user Asset auth info failed")
......@@ -17,8 +17,7 @@ func GetSystemUserAssetAuthInfo(systemUserID, assetID string) (info model.System
}
func GetSystemUserAuthInfo(systemUserID string) (info model.SystemUserAuthInfo) {
Url := authClient.ParseUrlQuery(fmt.Sprintf(SystemUserAuthInfoURL, systemUserID), nil)
Url := fmt.Sprintf(SystemUserAuthInfoURL, systemUserID)
err := authClient.Get(Url, &info)
if err != nil {
logger.Error("Get system user auth info failed")
......@@ -67,7 +66,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
"filter": "de7693ca-75d5-4639-986b-44ed390260a0"
}
]`*/
Url := authClient.ParseUrlQuery(fmt.Sprintf(SystemUserCmdFilterRules, systemUserID), nil)
Url := fmt.Sprintf(SystemUserCmdFilterRules, systemUserID)
err = authClient.Get(Url, &rules)
if err != nil {
......@@ -77,7 +76,7 @@ func GetSystemUserFilterRules(systemUserID string) (rules []model.SystemUserFilt
}
func GetSystemUser(systemUserID string) (info model.SystemUser) {
Url := authClient.ParseUrlQuery(fmt.Sprintf(SystemUser, systemUserID), nil)
Url := fmt.Sprintf(SystemUser, systemUserID)
err := authClient.Get(Url, &info)
if err != nil {
logger.Errorf("Get system user %s failed", systemUserID)
......@@ -86,7 +85,7 @@ func GetSystemUser(systemUserID string) (info model.SystemUser) {
}
func GetAsset(assetID string) (asset model.Asset) {
Url := authClient.ParseUrlQuery(fmt.Sprintf(Asset, assetID), nil)
Url := fmt.Sprintf(Asset, assetID)
err := authClient.Get(Url, &asset)
if err != nil {
logger.Errorf("Get Asset %s failed", assetID)
......@@ -95,7 +94,7 @@ func GetAsset(assetID string) (asset model.Asset) {
}
func GetTokenAsset(token string) (tokenUser model.TokenUser) {
Url := authClient.ParseUrlQuery(fmt.Sprintf(TokenAsset, token), nil)
Url := fmt.Sprintf(TokenAsset, token)
err := authClient.Get(Url, &tokenUser)
if err != nil {
logger.Error("Get Token Asset info failed")
......
......@@ -29,18 +29,23 @@ func Initial() {
}
func ValidateAccessAuth() {
maxTry := 30
count := 0
for count < 100 {
user := getTerminalProfile()
if user.Id != "" {
for count < maxTry {
user, err := GetProfile()
if err == nil && user.Role == "App" {
break
}
msg := `Connect server error or access key is invalid,
remove %s run again`
if err != nil {
msg := `Connect server error or access key is invalid, remove %s run again`
logger.Errorf(msg, config.Conf.AccessKeyFile)
}
if user.Role != "App" {
logger.Error("Access role is not App, is: ", user.Role)
}
time.Sleep(3 * time.Second)
count++
if count >= 3 {
if count >= maxTry {
os.Exit(1)
}
}
......
......@@ -11,9 +11,9 @@ func GetUserAssets(userId, cachePolicy string) (assets model.AssetList) {
if cachePolicy == "" {
cachePolicy = "0"
}
params := map[string]string{"cache_policy": cachePolicy}
Url := authClient.ParseUrlQuery(fmt.Sprintf(UserAssetsURL, userId), params)
err := authClient.Get(Url, &assets)
payload := map[string]string{"cache_policy": cachePolicy}
Url := fmt.Sprintf(UserAssetsURL, userId)
err := authClient.Get(Url, &assets, payload)
if err != nil {
logger.Error("GetUserAssets---err")
}
......@@ -24,9 +24,9 @@ func GetUserNodes(userId, cachePolicy string) (nodes model.NodeList) {
if cachePolicy == "" {
cachePolicy = "0"
}
params := map[string]string{"cache_policy": cachePolicy}
Url := authClient.ParseUrlQuery(fmt.Sprintf(UserNodesAssetsURL, userId), params)
err := authClient.Get(Url, &nodes)
payload := map[string]string{"cache_policy": cachePolicy}
Url := fmt.Sprintf(UserNodesAssetsURL, userId)
err := authClient.Get(Url, &nodes, payload)
if err != nil {
logger.Error("GetUserNodes err")
}
......@@ -34,17 +34,17 @@ func GetUserNodes(userId, cachePolicy string) (nodes model.NodeList) {
}
func ValidateUserAssetPermission(userId, assetId, systemUserId, action string) bool {
params := map[string]string{
payload := map[string]string{
"user_id": userId,
"asset_id": assetId,
"system_user_id": systemUserId,
"action_name": action,
}
Url := authClient.ParseUrlQuery(ValidateUserAssetPermissionURL, params)
Url := ValidateUserAssetPermissionURL
var res struct {
Msg bool `json:"msg"`
}
err := authClient.Get(Url, &res)
err := authClient.Get(Url, &res, payload)
if err != nil {
logger.Error(err)
......
......@@ -13,18 +13,7 @@ func RegisterTerminal(name, token, comment string) (res model.Terminal) {
}
client.Headers["Authorization"] = fmt.Sprintf("BootstrapToken %s", token)
data := map[string]string{"name": name, "comment": comment}
Url := client.ParseUrlQuery(TerminalRegisterURL, nil)
err := client.Post(Url, data, &res)
if err != nil {
logger.Error(err)
}
return
}
func getTerminalProfile() (user model.User) {
Url := authClient.ParseUrlQuery(UserProfileURL, nil)
err := authClient.Get(Url, &user)
err := client.Post(TerminalRegisterURL, data, &res)
if err != nil {
logger.Error(err)
}
......@@ -36,8 +25,7 @@ func TerminalHeartBeat(sIds []string) (res []model.TerminalTask) {
data := map[string][]string{
"sessions": sIds,
}
Url := authClient.ParseUrlQuery(TerminalHeartBeatURL, nil)
err := authClient.Post(Url, data, &res)
err := authClient.Post(TerminalHeartBeatURL, data, &res)
if err != nil {
logger.Error(err)
}
......@@ -46,8 +34,7 @@ func TerminalHeartBeat(sIds []string) (res []model.TerminalTask) {
func CreateSession(data map[string]interface{}) bool {
var res map[string]interface{}
Url := authClient.ParseUrlQuery(SessionListURL, nil)
err := authClient.Post(Url, data, &res)
err := authClient.Post(SessionListURL, data, &res)
if err == nil {
return true
}
......@@ -61,7 +48,7 @@ func FinishSession(sid, dataEnd string) {
"is_finished": true,
"date_end": dataEnd,
}
Url := authClient.ParseUrlQuery(fmt.Sprintf(SessionDetailURL, sid), nil)
Url := fmt.Sprintf(SessionDetailURL, sid)
err := authClient.Patch(Url, data, &res)
if err != nil {
logger.Error(err)
......@@ -71,7 +58,7 @@ func FinishSession(sid, dataEnd string) {
func FinishReply(sid string) bool {
var res map[string]interface{}
data := map[string]bool{"has_replay": true}
Url := authClient.ParseUrlQuery(fmt.Sprintf(SessionDetailURL, sid), nil)
Url := fmt.Sprintf(SessionDetailURL, sid)
err := authClient.Patch(Url, data, &res)
if err != nil {
logger.Error(err)
......@@ -83,7 +70,7 @@ func FinishReply(sid string) bool {
func FinishTask(tid string) bool {
var res map[string]interface{}
data := map[string]bool{"is_finished": true}
Url := authClient.ParseUrlQuery(fmt.Sprintf(FinishTaskURL, tid), nil)
Url := fmt.Sprintf(FinishTaskURL, tid)
err := authClient.Patch(Url, data, res)
if err != nil {
logger.Error(err)
......@@ -93,8 +80,7 @@ func FinishTask(tid string) bool {
}
func LoadConfigFromServer() (res model.TerminalConf) {
Url := authClient.ParseUrlQuery(TerminalConfigURL, nil)
err := authClient.Get(Url, &res)
err := authClient.Get(TerminalConfigURL, &res)
if err != nil {
logger.Error(err)
}
......
......@@ -4,6 +4,7 @@ const (
UserAuthURL = "/api/users/v1/auth/" // post 验证用户登陆
UserProfileURL = "/api/users/v1/profile/" // 获取当前用户的基本信息
UserUserURL = "/api/users/v1/users/%s/" // 获取用户信息
UserAuthOTPURL = "/api/users/v1/otp/auth/" // 验证OTP
AuthMFAURL = "/api/authentication/v1/otp/auth/" // MFA 验证用户信息
......
......@@ -3,11 +3,19 @@ package service
import (
"fmt"
"github.com/pkg/errors"
"cocogo/pkg/logger"
"cocogo/pkg/model"
)
func Authenticate(username, password, publicKey, remoteAddr, loginType string) (resp *model.AuthResponse, err error) {
type AuthResp struct {
Token string `json:"token"`
Seed string `json:"seed"`
User *model.User `json:"user"`
}
func Authenticate(username, password, publicKey, remoteAddr, loginType string) (resp *AuthResp, err error) {
data := map[string]string{
"username": username,
"password": password,
......@@ -15,8 +23,9 @@ func Authenticate(username, password, publicKey, remoteAddr, loginType string) (
"remote_addr": remoteAddr,
"login_type": loginType,
}
Url := client.ParseUrlQuery(UserAuthURL, nil)
err = client.Post(Url, data, resp)
Url := client.ParseUrl(UserAuthURL, nil)
err = client.Post(Url, data, &resp)
if err != nil {
logger.Error(err)
}
......@@ -39,7 +48,7 @@ func AuthenticateMFA(seed, code, loginType string) (resp *model.AuthResponse, er
"login_type": loginType,
}
Url := client.ParseUrlQuery(AuthMFAURL, nil)
Url := client.ParseUrl(AuthMFAURL, nil)
err = client.Post(Url, data, resp)
if err != nil {
logger.Error(err)
......@@ -49,19 +58,50 @@ func AuthenticateMFA(seed, code, loginType string) (resp *model.AuthResponse, er
}
func GetUserProfile(userId string) (user *model.User) {
Url := authClient.ParseUrlQuery(fmt.Sprintf(UserUserURL, userId), nil)
err := authClient.Get(Url, &user)
Url := fmt.Sprintf(UserUserURL, userId)
err := authClient.Get(Url, user)
if err != nil {
logger.Error(err)
}
return
}
func GetProfile() (user *model.User, err error) {
err = authClient.Get(UserProfileURL, &user)
return
}
func GetUserByUsername(username string) (user *model.User, err error) {
var users []*model.User
payload := map[string]string{"username": username}
err = authClient.Get(UserUserURL, &users, payload)
if err != nil {
return
}
if len(users) != 1 {
err = errors.New(fmt.Sprintf("Not found user by username: %s", username))
} else {
user = users[0]
}
return
}
func CheckUserOTP(seed, code string) (resp *AuthResp, err error) {
data := map[string]string{
"seed": seed,
"otp_code": code,
}
err = client.Post(UserAuthOTPURL, data, resp)
if err != nil {
return
}
return
}
func CheckUserCookie(sessionId, csrfToken string) (user *model.User) {
client.SetCookie("csrftoken", csrfToken)
client.SetCookie("sessionid", sessionId)
Url := client.ParseUrlQuery(UserProfileURL, nil)
err := client.Get(Url, &user)
err := client.Get(UserProfileURL, &user)
if err != nil {
logger.Error(err)
}
......
......@@ -6,7 +6,6 @@ import (
"time"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
"cocogo/pkg/auth"
"cocogo/pkg/config"
......@@ -16,24 +15,6 @@ import (
const version = "v1.4.0"
func defaultConfig(ctx ssh.Context) (conf *gossh.ServerConfig) {
conf = new(gossh.ServerConfig)
conf.AuthLogCallback = func(conn gossh.ConnMetadata, method string, err error) {
fmt.Println(err)
fmt.Println(method)
result := "failed"
if err == nil {
result = "success"
}
logger.Debugf("%s use AuthMethod %s %s\n", conn.User(), method, result)
}
conf.NextAuthMethodsCallback = func(conn gossh.ConnMetadata) (methods []string) {
fmt.Println("Username: ", conn.User())
return []string{"keyboard-interactive"}
}
return conf
}
var (
conf = config.Conf
)
......@@ -55,7 +36,7 @@ func StartServer() {
PasswordHandler: auth.CheckUserPassword,
PublicKeyHandler: auth.CheckUserPublicKey,
KeyboardInteractiveHandler: auth.CheckMFA,
DefaultServerConfigCallback: defaultConfig,
NextAuthMethodsHandler: auth.CheckUserNeedMFA,
HostSigners: []ssh.Signer{signer},
Handler: handler.SessionHandler,
SubsystemHandlers: map[string]ssh.SubsystemHandler{},
......
package cocogo
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment