diff --git a/server/internal/config/websocket.go b/server/internal/config/websocket.go index 91050185..653411e4 100644 --- a/server/internal/config/websocket.go +++ b/server/internal/config/websocket.go @@ -12,6 +12,10 @@ type WebSocket struct { Locks []string ControlProtection bool + + FileTransfer bool + UnprivFileTransfer bool + FileTransferPath string } func (WebSocket) Init(cmd *cobra.Command) error { @@ -40,6 +44,21 @@ func (WebSocket) Init(cmd *cobra.Command) error { return err } + cmd.PersistentFlags().Bool("file_transfer", false, "allow file transfer for admins") + if err := viper.BindPFlag("file_transfer", cmd.PersistentFlags().Lookup("file_transfer")); err != nil { + return err + } + + cmd.PersistentFlags().Bool("unpriv_file_transfer", false, "allow file transfer for non admins") + if err := viper.BindPFlag("unpriv_file_transfer", cmd.PersistentFlags().Lookup("unpriv_file_transfer")); err != nil { + return err + } + + cmd.PersistentFlags().String("file_transfer_path", "/home/neko/Downloads", "path to use for file transfer") + if err := viper.BindPFlag("file_transfer_path", cmd.PersistentFlags().Lookup("file_transfer_path")); err != nil { + return err + } + return nil } @@ -50,4 +69,8 @@ func (s *WebSocket) Set() { s.Locks = viper.GetStringSlice("locks") s.ControlProtection = viper.GetBool("control_protection") + + s.FileTransfer = viper.GetBool("file_transfer") + s.UnprivFileTransfer = viper.GetBool("unpriv_file_transfer") + s.FileTransferPath = viper.GetString("file_transfer_path") } diff --git a/server/internal/http/http.go b/server/internal/http/http.go index 06b08638..4b8989c4 100644 --- a/server/internal/http/http.go +++ b/server/internal/http/http.go @@ -3,9 +3,11 @@ package http import ( "context" "encoding/json" + "fmt" "image/jpeg" "net/http" "os" + "regexp" "strconv" "github.com/go-chi/chi" @@ -17,6 +19,8 @@ import ( "m1k1o/neko/internal/types" ) +const FILE_UPLOAD_BUF_SIZE = 65000 + type Server struct { logger zerolog.Logger router *chi.Mux @@ -99,6 +103,99 @@ func New(conf *config.Server, webSocketHandler types.WebSocketHandler, desktop t } }) + router.Get("/file", func(w http.ResponseWriter, r *http.Request) { + password := r.URL.Query().Get("pwd") + isAuthorized, err := webSocketHandler.CanTransferFiles(password) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + + if !isAuthorized { + http.Error(w, "bad authorization", http.StatusUnauthorized) + return + } + + filename := r.URL.Query().Get("filename") + badChars, _ := regexp.MatchString(`(?m)\.\.(?:\/|$)`, filename) + if filename == "" || badChars { + http.Error(w, "bad filename", http.StatusBadRequest) + return + } + + path := webSocketHandler.MakeFilePath(filename) + f, err := os.Open(path) + if err != nil { + http.Error(w, "not found or unable to open", http.StatusNotFound) + return + } + defer f.Close() + fileinfo, err := f.Stat() + if err != nil { + http.Error(w, "unable to stat file", http.StatusInternalServerError) + return + } + + buffer := make([]byte, fileinfo.Size()) + _, err = f.Read(buffer) + if err != nil { + http.Error(w, "error reading file", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) + w.Write(buffer) + }) + + router.Post("/file", func(w http.ResponseWriter, r *http.Request) { + password := r.URL.Query().Get("pwd") + isAuthorized, err := webSocketHandler.CanTransferFiles(password) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + + if !isAuthorized { + http.Error(w, "bad authorization", http.StatusUnauthorized) + return + } + + r.ParseMultipartForm(32 << 20) + buffer := make([]byte, FILE_UPLOAD_BUF_SIZE) + for _, formheader := range r.MultipartForm.File["files"] { + formfile, err := formheader.Open() + if err != nil { + logger.Warn().Err(err).Msg("failed to open formdata file") + http.Error(w, "error writing file", http.StatusInternalServerError) + return + } + f, err := os.OpenFile(webSocketHandler.MakeFilePath(formheader.Filename), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + http.Error(w, "unable to open file for writing", http.StatusInternalServerError) + return + } + + var copied int64 = 0 + for copied < formheader.Size { + var limit int64 = int64(len(buffer)) + if limit > formheader.Size-copied { + limit = formheader.Size - copied + } + bytesRead, err := formfile.ReadAt(buffer[:limit], copied) + if err != nil { + logger.Warn().Err(err).Msg("failed copying file in upload") + http.Error(w, "error writing file", http.StatusInternalServerError) + return + } + f.Write(buffer[:bytesRead]) + copied += int64(bytesRead) + } + + formfile.Close() + f.Close() + } + }) + router.Get("/health", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("true")) }) diff --git a/server/internal/types/websocket.go b/server/internal/types/websocket.go index 968bbb8b..374d4f9e 100644 --- a/server/internal/types/websocket.go +++ b/server/internal/types/websocket.go @@ -34,4 +34,6 @@ type WebSocketHandler interface { Stats() Stats IsLocked(resource string) bool IsAdmin(password string) (bool, error) + CanTransferFiles(password string) (bool, error) + MakeFilePath(filename string) string } diff --git a/server/internal/websocket/websocket.go b/server/internal/websocket/websocket.go index a4ec5729..00e8edae 100644 --- a/server/internal/websocket/websocket.go +++ b/server/internal/websocket/websocket.go @@ -3,6 +3,7 @@ package websocket import ( "fmt" "net/http" + "os" "sync" "sync/atomic" "time" @@ -33,6 +34,14 @@ func New(sessions types.SessionManager, desktop types.DesktopManager, capture ty logger.Info().Msgf("control locked on behalf of control protection") } + if conf.FileTransferPath[len(conf.FileTransferPath)-1] != '/' { + conf.FileTransferPath += "/" + } + err := os.Mkdir(conf.FileTransferPath, 0755) + if err != nil && !os.IsExist(err) { + logger.Panic().Err(err).Msg("unable to create file transfer directory") + } + // apply default locks for _, lock := range conf.Locks { state.Lock(lock, "") // empty session ID @@ -314,6 +323,22 @@ func (ws *WebSocketHandler) IsAdmin(password string) (bool, error) { return false, fmt.Errorf("invalid password") } +func (ws *WebSocketHandler) CanTransferFiles(password string) (bool, error) { + if !ws.conf.FileTransfer { + return false, nil + } + + if !ws.conf.UnprivFileTransfer { + return ws.IsAdmin(password) + } + + return password == ws.conf.Password, nil +} + +func (ws *WebSocketHandler) MakeFilePath(filename string) string { + return fmt.Sprintf("%s%s", ws.conf.FileTransferPath, filename) +} + func (ws *WebSocketHandler) authenticate(r *http.Request) (bool, error) { passwords, ok := r.URL.Query()["password"] if !ok || len(passwords[0]) < 1 {