package srvconn

import (
	"bytes"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"os/user"
	"strconv"
	"sync"
	"syscall"
	"time"

	"github.com/creack/pty"

	"github.com/jumpserver/koko/pkg/logger"
)

const (
	mysqlPrompt = "Enter password: "
)

func NewMysqlServer(ops ...SqlOption) *ServerMysqlConnection {
	args := &SqlOptions{
		Username: os.Getenv("USER"),
		Password: os.Getenv("PASSWORD"),
		Host:     "127.0.0.1",
		Port:     3306,
		DBName:   "",
	}
	for _, setter := range ops {
		setter(args)
	}
	return &ServerMysqlConnection{options: args, onceClose: new(sync.Once)}
}

type ServerMysqlConnection struct {
	options   *SqlOptions
	ptyFD     *os.File
	onceClose *sync.Once
	cmd       *exec.Cmd
}

func (dbconn *ServerMysqlConnection) Connect() (err error) {
	dbconn.cmd = exec.Command("mysql", dbconn.options.CommandArgs()...)
	nobody, err := user.Lookup("nobody")
	if err != nil {
		logger.Errorf("lookup nobody user err: %s", err)
		return errors.New("nobody user does not exist")
	}
	dbconn.cmd.SysProcAttr = &syscall.SysProcAttr{}
	uid, _ := strconv.Atoi(nobody.Uid)
	gid, _ := strconv.Atoi(nobody.Gid)
	dbconn.cmd.SysProcAttr.Credential = &syscall.Credential{Uid: uint32(uid), Gid: uint32(gid)}
	dbconn.ptyFD, err = pty.Start(dbconn.cmd)
	if err != nil {
		logger.Errorf("pty start err: %s", err)
		return fmt.Errorf("start local pty err: %s", err)
	}
	prompt := [len(mysqlPrompt)]byte{}
	nr, err := dbconn.ptyFD.Read(prompt[:])
	if err != nil {
		_ = dbconn.ptyFD.Close()
		_ = dbconn.cmd.Process.Kill()
		logger.Errorf("read mysql pty local fd err: %s", err)
		return fmt.Errorf("mysql conn err: %s", err)
	}
	if !bytes.Equal(prompt[:nr], []byte(mysqlPrompt)) {
		_ = dbconn.cmd.Process.Kill()
		_ = dbconn.ptyFD.Close()
		logger.Errorf("mysql login prompt characters did not match: %s", prompt[:nr])
		return errors.New("failed login mysql")
	}
	// 输入密码, 登录mysql
	_, err = dbconn.ptyFD.Write([]byte(dbconn.options.Password + "\r\n"))
	if err != nil {
		_ = dbconn.ptyFD.Close()
		_ = dbconn.cmd.Process.Kill()
		logger.Errorf("mysql local pty write err: %s", err)
		return fmt.Errorf("mysql conn err: %s", err)
	}
	logger.Infof("Connect mysql database %s success ", dbconn.options.Host)
	return
}

func (dbconn *ServerMysqlConnection) Read(p []byte) (int, error) {
	if dbconn.ptyFD == nil {
		return 0, fmt.Errorf("not connect init")
	}
	return dbconn.ptyFD.Read(p)
}

func (dbconn *ServerMysqlConnection) Write(p []byte) (int, error) {
	if dbconn.ptyFD == nil {
		return 0, fmt.Errorf("not connect init")
	}
	return dbconn.ptyFD.Write(p)
}

func (dbconn *ServerMysqlConnection) SetWinSize(w, h int) error {
	if dbconn.ptyFD == nil {
		return fmt.Errorf("not connect init")
	}
	win := pty.Winsize{
		Rows: uint16(h),
		Cols: uint16(w),
	}
	return pty.Setsize(dbconn.ptyFD, &win)
}

func (dbconn *ServerMysqlConnection) Close() (err error) {
	dbconn.onceClose.Do(func() {
		if dbconn.ptyFD == nil {
			return
		}
		_ = dbconn.ptyFD.Close()
		err = dbconn.cmd.Process.Kill()
	})
	return
}

func (dbconn *ServerMysqlConnection) Timeout() time.Duration {
	return time.Duration(10) * time.Second
}

func (dbconn *ServerMysqlConnection) Protocol() string {
	return "mysql"
}

type SqlOptions struct {
	Username string
	Password string
	DBName   string
	Host     string
	Port     int
}

func (opts *SqlOptions) CommandArgs() []string {
	return []string{
		fmt.Sprintf("--user=%s", opts.Username),
		fmt.Sprintf("--host=%s", opts.Host),
		fmt.Sprintf("--port=%d", opts.Port),
		"--password",
		opts.DBName,
	}
}

type SqlOption func(*SqlOptions)

func SqlUsername(username string) SqlOption {
	return func(args *SqlOptions) {
		args.Username = username
	}
}

func SqlPassword(password string) SqlOption {
	return func(args *SqlOptions) {
		args.Password = password
	}
}

func SqlDBName(dbName string) SqlOption {
	return func(args *SqlOptions) {
		args.DBName = dbName
	}
}

func SqlHost(host string) SqlOption {
	return func(args *SqlOptions) {
		args.Host = host
	}
}

func SqlPort(port int) SqlOption {
	return func(args *SqlOptions) {
		args.Port = port
	}
}