• Eric_Lee's avatar
    Dev (#66) · 9ab4ea69
    Eric_Lee authored
    * fix bug; user auth login type
    
    * fix addr bugs
    
    * fix download many files bug
    Unverified
    9ab4ea69
sftp.go 4.81 KB
package handler

import (
	"io"
	"net"
	"os"
	"sync"
	"syscall"
	"time"

	"github.com/gliderlabs/ssh"
	"github.com/pkg/sftp"

	"github.com/jumpserver/koko/pkg/cctx"
	"github.com/jumpserver/koko/pkg/logger"
	"github.com/jumpserver/koko/pkg/model"
	"github.com/jumpserver/koko/pkg/service"
	"github.com/jumpserver/koko/pkg/srvconn"
)

// SftpHandler is the entry point for an SSH session's SFTP subsystem. It
// builds a per-session sftpHandler for the session's user (identified through
// cctx) and the client's host address, then serves SFTP requests on sess until
// the client disconnects or the request server fails.
//
// The request server and the per-session handler are released on every exit
// path. Previously they were only closed when Serve returned io.EOF, so an
// abnormal termination leaked both.
func SftpHandler(sess ssh.Session) {
	ctx, cancel := cctx.NewContext(sess)
	defer cancel()
	host, _, _ := net.SplitHostPort(sess.RemoteAddr().String())
	userSftp := NewSFTPHandler(ctx.User(), host)
	// Deferred calls run LIFO: req.Close() below runs first, then this one,
	// preserving the original close order.
	defer userSftp.Close()

	handlers := sftp.Handlers{
		FileGet:  userSftp,
		FilePut:  userSftp,
		FileCmd:  userSftp,
		FileList: userSftp,
	}

	req := sftp.NewRequestServer(sess, handlers)
	defer func() { _ = req.Close() }()

	err := req.Serve()
	if err == io.EOF {
		logger.Info("sftp client exited session.")
	} else if err != nil {
		logger.Error("sftp server completed with error:", err)
	}
}

// NewSFTPHandler creates the per-session SFTP handler for user, who connected
// from addr. The user's assets are fetched via service.GetUserAssets with the
// fixed arguments "1" and "" (their meaning is defined by that service, not
// here) and passed to srvconn.NewUserSFTP, whose result backs the handler.
func NewSFTPHandler(user *model.User, addr string) *sftpHandler {
	userAssets := service.GetUserAssets(user.ID, "1", "")
	backend := srvconn.NewUserSFTP(user, addr, userAssets...)
	return &sftpHandler{UserSftp: backend}
}

// sftpHandler serves the SFTP requests of one SSH session (SftpHandler creates
// one per session). It is registered as the FileGet, FilePut, FileCmd and
// FileList handler of the request server.
//
// The file operations its methods call (ReadDir, Stat, ReadLink, Open, Create,
// Rename, ...) are not defined in this file; presumably they are promoted from
// the embedded *srvconn.UserSftp — verify if methods are added elsewhere in the
// package.
type sftpHandler struct {
	*srvconn.UserSftp
}

// Filelist serves the SFTP "List", "Stat" and "Readlink" requests for
// r.Filepath. Every returned entry is wrapped in wrapperSFTPFileInfo, so all
// three methods expose owner uid/gid through Sys() the same way. Any other
// method yields sftp.ErrSshFxOpUnsupported.
//
// On error no lister is returned. This avoids handing the server a lister
// that contains a nil os.FileInfo (Stat) or a fake symlink built from an empty
// target (Readlink).
func (fs *sftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
	switch r.Method {
	case "List":
		logger.Debug("List method: ", r.Filepath)
		res, err := fs.ReadDir(r.Filepath)
		if err != nil {
			return nil, err
		}
		fileInfos := make(listerat, 0, len(res))
		for _, info := range res {
			fileInfos = append(fileInfos, &wrapperSFTPFileInfo{f: info})
		}
		return fileInfos, nil
	case "Stat":
		logger.Debug("stat method: ", r.Filepath)
		fsInfo, err := fs.Stat(r.Filepath)
		if err != nil {
			return nil, err
		}
		return listerat{&wrapperSFTPFileInfo{f: fsInfo}}, nil
	case "Readlink":
		logger.Debug("Readlink method", r.Filepath)
		filename, err := fs.ReadLink(r.Filepath)
		if err != nil {
			return nil, err
		}
		fsInfo := srvconn.NewFakeSymFile(filename)
		return listerat{&wrapperSFTPFileInfo{f: fsInfo}}, nil
	}
	return nil, sftp.ErrSshFxOpUnsupported
}

// Filecmd serves the SFTP requests that modify the file tree at r.Filepath:
// Rename, Rmdir, Remove, Mkdir (which creates missing parents) and Symlink.
// "Setstat" is acknowledged as successful without changing anything. Any
// other method is reported as sftp.ErrSshFxOpUnsupported instead of a silent
// (and false) success.
func (fs *sftpHandler) Filecmd(r *sftp.Request) error {
	logger.Debug("File cmd: ", r.Filepath)

	switch r.Method {
	case "Setstat":
		// Attribute changes are accepted but intentionally not applied.
		return nil
	case "Rename":
		// logger.Debug takes Print-style operands (as at every other call
		// site), so a printf format string would be logged verbatim.
		logger.Debug(r.Filepath, "=>", r.Target)
		return fs.Rename(r.Filepath, r.Target)
	case "Rmdir":
		return fs.RemoveDirectory(r.Filepath)
	case "Remove":
		return fs.Remove(r.Filepath)
	case "Mkdir":
		return fs.MkdirAll(r.Filepath)
	case "Symlink":
		logger.Debug(r.Filepath, "=>", r.Target)
		return fs.Symlink(r.Filepath, r.Target)
	default:
		return sftp.ErrSshFxOpUnsupported
	}
}

// Filewrite serves an SFTP upload. It creates r.Filepath on the backend and
// returns a WriterAt over the new file.
//
// On failure it returns a literal nil writer. Wrapping the nil *sftp.File
// would produce a non-nil io.WriterAt around a nil file (the nil-interface
// trap), which is also inconsistent with Fileread.
func (fs *sftpHandler) Filewrite(r *sftp.Request) (io.WriterAt, error) {
	logger.Debug("File write: ", r.Filepath)
	f, err := fs.Create(r.Filepath)
	if err != nil {
		return nil, err
	}
	return NewWriterAt(f), nil
}

// Fileread serves an SFTP download. It opens r.Filepath on the backend and
// returns a ReaderAt over it. The file is stat'ed up front so the reader knows
// its size. If the stat fails, the just-opened file is closed before the error
// is returned.
func (fs *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
	path := r.Filepath
	logger.Debug("File read: ", path)

	remote, openErr := fs.Open(path)
	if openErr != nil {
		return nil, openErr
	}

	info, statErr := remote.Stat()
	if statErr != nil {
		_ = remote.Close()
		return nil, statErr
	}
	return NewReaderAt(remote, info), nil
}

// Close delegates to the embedded *srvconn.UserSftp's Close. SftpHandler calls
// it once the SFTP request server has finished.
func (fs *sftpHandler) Close() {
	backend := fs.UserSftp
	backend.Close()
}

// listerat is an in-memory sftp.ListerAt over a fixed slice of entries.
type listerat []os.FileInfo

// ListAt copies entries starting at offset into ls and returns how many were
// copied. It returns io.EOF when offset is past the end, or when fewer than
// len(ls) entries remained (i.e. the listing is exhausted).
func (l listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if offset >= int64(len(l)) {
		return 0, io.EOF
	}
	copied := copy(ls, l[offset:])
	var err error
	if copied < len(ls) {
		err = io.EOF
	}
	return copied, err
}

// NewWriterAt adapts a remote *sftp.File into an io.WriterAt for uploads. The
// returned value also implements io.Closer, so the request server can release
// the remote handle when the client closes its file.
func NewWriterAt(f *sftp.File) io.WriterAt {
	return &clientReadWritAt{f: f, mu: new(sync.RWMutex)}
}

// NewReaderAt adapts a remote *sftp.File into an io.ReaderAt for downloads.
// fi is the file's stat result; reads at or past fi.Size() return io.EOF
// without touching the remote file. The returned value also implements
// io.Closer.
func NewReaderAt(f *sftp.File, fi os.FileInfo) io.ReaderAt {
	return &clientReadWritAt{f: f, mu: new(sync.RWMutex), fi: fi}
}

// clientReadWritAt provides positional I/O (ReadAt/WriteAt) plus Close on top
// of a *sftp.File, which this code drives through Seek followed by Read/Write.
// mu serializes each Seek+I/O pair so concurrent calls at different offsets
// cannot interleave.
type clientReadWritAt struct {
	f        *sftp.File
	mu       *sync.RWMutex
	fi       os.FileInfo
	firstErr error // first fatal I/O error; later calls fail fast with it
	closed   bool  // remote handle has already been closed
}

var (
	_ io.ReaderAt = (*clientReadWritAt)(nil)
	_ io.WriterAt = (*clientReadWritAt)(nil)
	_ io.Closer   = (*clientReadWritAt)(nil)
)

// WriteAt writes p at offset off in the remote file. A write error is sticky:
// it closes the remote handle and is returned by every later call.
func (c *clientReadWritAt) WriteAt(p []byte, off int64) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if err = c.usable(); err != nil {
		return 0, err
	}
	if _, err = c.f.Seek(off, io.SeekStart); err != nil {
		return 0, err
	}
	n, err = c.f.Write(p)
	if err != nil {
		c.fail(err)
	}
	return n, err
}

// ReadAt reads into p from offset off in the remote file.
//
// io.EOF is deliberately NOT sticky. It only describes this offset, and
// io.ReaderAt callers may issue reads at other (earlier) offsets in any order.
// Caching it, as the old code did, made those reads fail. Any other read error
// is sticky and closes the remote handle.
func (c *clientReadWritAt) ReadAt(p []byte, off int64) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if err = c.usable(); err != nil {
		return 0, err
	}
	if c.fi != nil && off >= c.fi.Size() {
		return 0, io.EOF
	}
	if _, err = c.f.Seek(off, io.SeekStart); err != nil {
		return 0, err
	}
	n, err = c.f.Read(p)
	if err != nil && err != io.EOF {
		c.fail(err)
	}
	return n, err
}

// Close releases the remote file handle. Without this method the handle was
// only closed after an I/O error, so every successful transfer leaked one.
// Calling Close more than once is safe.
func (c *clientReadWritAt) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.closeLocked()
}

// usable returns the error that prevents further I/O, or nil. Caller holds mu.
func (c *clientReadWritAt) usable() error {
	if c.firstErr != nil {
		return c.firstErr
	}
	if c.closed {
		return os.ErrClosed
	}
	return nil
}

// fail records err as the sticky error and releases the remote handle. The
// I/O error takes precedence, so any close error is dropped. Caller holds mu.
func (c *clientReadWritAt) fail(err error) {
	c.firstErr = err
	_ = c.closeLocked()
}

// closeLocked closes the remote handle at most once. Caller holds mu.
func (c *clientReadWritAt) closeLocked() error {
	if c.closed {
		return nil
	}
	c.closed = true
	return c.f.Close()
}

// wrapperSFTPFileInfo decorates an os.FileInfo. Every method passes through
// unchanged except Sys, which always returns a *syscall.Stat_t carrying only
// the owner uid/gid (zero when the wrapped value has neither an
// *sftp.FileStat nor a *syscall.Stat_t). Presumably this lets the SFTP
// request server report ownership — confirm against pkg/sftp's attribute
// encoding.
type wrapperSFTPFileInfo struct {
	f os.FileInfo
}

func (w *wrapperSFTPFileInfo) Name() string       { return w.f.Name() }
func (w *wrapperSFTPFileInfo) Size() int64        { return w.f.Size() }
func (w *wrapperSFTPFileInfo) Mode() os.FileMode  { return w.f.Mode() }
func (w *wrapperSFTPFileInfo) ModTime() time.Time { return w.f.ModTime() }
func (w *wrapperSFTPFileInfo) IsDir() bool        { return w.f.IsDir() }

// Sys normalizes the wrapped Sys() value into a fresh *syscall.Stat_t with
// only Uid and Gid populated.
func (w *wrapperSFTPFileInfo) Sys() interface{} {
	switch st := w.f.Sys().(type) {
	case *sftp.FileStat:
		return &syscall.Stat_t{Uid: st.UID, Gid: st.GID}
	case *syscall.Stat_t:
		return &syscall.Stat_t{Uid: st.Uid, Gid: st.Gid}
	default:
		return &syscall.Stat_t{}
	}
}