From a178bede873a527bd22c30026b18e8a433a28b71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Fri, 19 Apr 2024 20:22:43 +0200 Subject: [PATCH] add filetransfer plugin. --- internal/plugins/filetransfer/config.go | 41 +++ internal/plugins/filetransfer/manager.go | 312 +++++++++++++++++++++++ internal/plugins/filetransfer/plugin.go | 35 +++ internal/plugins/filetransfer/types.go | 26 ++ internal/plugins/filetransfer/utils.go | 32 +++ internal/plugins/manager.go | 4 + 6 files changed, 450 insertions(+) create mode 100644 internal/plugins/filetransfer/config.go create mode 100644 internal/plugins/filetransfer/manager.go create mode 100644 internal/plugins/filetransfer/plugin.go create mode 100644 internal/plugins/filetransfer/types.go create mode 100644 internal/plugins/filetransfer/utils.go diff --git a/internal/plugins/filetransfer/config.go b/internal/plugins/filetransfer/config.go new file mode 100644 index 00000000..593d31fb --- /dev/null +++ b/internal/plugins/filetransfer/config.go @@ -0,0 +1,41 @@ +package filetransfer + +import ( + "path/filepath" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +type Config struct { + Enabled bool + RootDir string + RefreshInterval time.Duration +} + +func (Config) Init(cmd *cobra.Command) error { + cmd.PersistentFlags().Bool("filetransfer.enabled", false, "whether file transfer is enabled") + if err := viper.BindPFlag("filetransfer.enabled", cmd.PersistentFlags().Lookup("filetransfer.enabled")); err != nil { + return err + } + + cmd.PersistentFlags().String("filetransfer.dir", "/home/neko/Downloads", "root directory for file transfer") + if err := viper.BindPFlag("filetransfer.dir", cmd.PersistentFlags().Lookup("filetransfer.dir")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("filetransfer.refresh_interval", 30*time.Second, "interval to refresh file list") + if err := viper.BindPFlag("filetransfer.refresh_interval", cmd.PersistentFlags().Lookup("filetransfer.refresh_interval")); err != nil { + return err + } + + return nil +} + +func (s *Config) Set() { + s.Enabled = viper.GetBool("filetransfer.enabled") + rootDir := viper.GetString("filetransfer.dir") + s.RootDir = filepath.Clean(rootDir) + s.RefreshInterval = viper.GetDuration("filetransfer.refresh_interval") +} diff --git a/internal/plugins/filetransfer/manager.go b/internal/plugins/filetransfer/manager.go new file mode 100644 index 00000000..34fa7ee0 --- /dev/null +++ b/internal/plugins/filetransfer/manager.go @@ -0,0 +1,312 @@ +package filetransfer + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "sync" + "time" + + "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/utils" + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +const MULTIPART_FORM_MAX_MEMORY = 32 << 20 + +func NewManager( + sessions types.SessionManager, + config *Config, +) *Manager { + logger := log.With().Str("module", "filetransfer").Logger() + + return &Manager{ + logger: logger, + config: config, + sessions: sessions, + shutdown: make(chan struct{}), + } +} + +type Manager struct { + logger zerolog.Logger + config *Config + sessions types.SessionManager + shutdown chan struct{} + mu sync.RWMutex + fileList []Item +} + +func (m *Manager) isEnabledForSession(session types.Session) bool { + canTransfer := true + + profile, ok := session.Profile().Plugins["filetransfer"] + // by default, allow file transfer if the plugin config is not present + if ok { + canTransfer, ok = profile.(bool) + // if the plugin is present but not a boolean, allow file transfer + if !ok { + canTransfer = true + } + } + + return m.config.Enabled && canTransfer + // TODO: when locking is implemented + // && (session.Profile().IsAdmin || !h.state.IsLocked("file_transfer")) +} + +func (m *Manager) refresh() (error, bool) { + // if file transfer is disabled, return immediately without refreshing + if !m.config.Enabled { + return nil, false + } + + files, err := ListFiles(m.config.RootDir) + if err != nil { + return err, false + } + + m.mu.Lock() + defer m.mu.Unlock() + + // check if file list has changed (todo: use hash instead of comparing all fields) + changed := false + if len(files) == len(m.fileList) { + for i, file := range files { + if file.Name != m.fileList[i].Name || file.Size != m.fileList[i].Size { + changed = true + break + } + } + } else { + changed = true + } + + m.fileList = files + return nil, changed +} + +func (m *Manager) broadcastUpdate() { + m.mu.RLock() + defer m.mu.RUnlock() + + m.sessions.Broadcast(FILETRANSFER_UPDATE, Message{ + Enabled: m.config.Enabled, + RootDir: m.config.RootDir, + Files: m.fileList, + }) +} + +func (m *Manager) Start() error { + // send init message once a user connects + m.sessions.OnConnected(func(session types.Session) { + isEnabled := m.isEnabledForSession(session) + + // get file list + m.mu.RLock() + fileList := m.fileList + m.mu.RUnlock() + + // send init message + session.Send(FILETRANSFER_UPDATE, Message{ + Enabled: isEnabled, + RootDir: m.config.RootDir, + Files: fileList, + }) + }) + + // if file transfer is disabled, return immediately without starting the watcher + if !m.config.Enabled { + return nil + } + + if _, err := os.Stat(m.config.RootDir); os.IsNotExist(err) { + err = os.Mkdir(m.config.RootDir, os.ModePerm) + m.logger.Err(err).Msg("creating file transfer directory") + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("unable to start file transfer dir watcher: %w", err) + } + + go func() { + defer watcher.Close() + + // periodically refresh file list + ticker := time.NewTicker(m.config.RefreshInterval) + defer ticker.Stop() + + for { + select { + case <-m.shutdown: + m.logger.Info().Msg("shutting down file transfer manager") + return + case <-ticker.C: + err, changed := m.refresh() + if err != nil { + m.logger.Err(err).Msg("unable to refresh file transfer list") + } + if changed { + m.broadcastUpdate() + } + case e, ok := <-watcher.Events: + if !ok { + m.logger.Info().Msg("file transfer dir watcher closed") + return + } + + if e.Has(fsnotify.Create) || e.Has(fsnotify.Remove) || e.Has(fsnotify.Rename) { + m.logger.Debug().Str("event", e.String()).Msg("file transfer dir watcher event") + + err, changed := m.refresh() + if err != nil { + m.logger.Err(err).Msg("unable to refresh file transfer list") + } + + if changed { + m.broadcastUpdate() + } + } + case err := <-watcher.Errors: + m.logger.Err(err).Msg("error in file transfer dir watcher") + } + } + }() + + if err := watcher.Add(m.config.RootDir); err != nil { + return fmt.Errorf("unable to watch file transfer dir: %w", err) + } + + return nil +} + +func (m *Manager) Shutdown() error { + close(m.shutdown) + return nil +} + +func (m *Manager) Route(r types.Router) { + r.With(auth.AdminsOnly).Get("/", m.downloadFileHandler) + r.With(auth.AdminsOnly).Post("/", m.uploadFileHandler) +} + +func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool { + switch msg.Event { + case FILETRANSFER_UPDATE: + err, changed := m.refresh() + if err != nil { + m.logger.Err(err).Msg("unable to refresh file transfer list") + } + + if changed { + m.broadcastUpdate() + } else { + // get file list + m.mu.RLock() + fileList := m.fileList + m.mu.RUnlock() + + // send update message to this client only + session.Send(FILETRANSFER_UPDATE, Message{ + Enabled: m.config.Enabled, + RootDir: m.config.RootDir, + Files: fileList, + }) + } + return true + } + return false +} + +func (m *Manager) downloadFileHandler(w http.ResponseWriter, r *http.Request) error { + session, ok := auth.GetSession(r) + if !ok { + return utils.HttpUnauthorized("session not found") + } + + enabled := m.isEnabledForSession(session) + if !enabled { + return utils.HttpForbidden("file transfer is disabled") + } + + filename := r.URL.Query().Get("filename") + badChars, err := regexp.MatchString(`(?m)\.\.(?:\/|$)`, filename) + if filename == "" || badChars || err != nil { + return utils.HttpBadRequest(). + WithInternalErr(err). + Msg("bad filename") + } + + // ensure filename is clean and only contains the basename + filename = filepath.Clean(filename) + filename = filepath.Base(filename) + filePath := filepath.Join(m.config.RootDir, filename) + + http.ServeFile(w, r, filePath) + return nil +} + +func (m *Manager) uploadFileHandler(w http.ResponseWriter, r *http.Request) error { + session, ok := auth.GetSession(r) + if !ok { + return utils.HttpUnauthorized("session not found") + } + + enabled := m.isEnabledForSession(session) + if !enabled { + return utils.HttpForbidden("file transfer is disabled") + } + + err := r.ParseMultipartForm(MULTIPART_FORM_MAX_MEMORY) + if err != nil || r.MultipartForm == nil { + return utils.HttpBadRequest(). + WithInternalErr(err). + Msg("error parsing form") + } + + defer func() { + err = r.MultipartForm.RemoveAll() + if err != nil { + m.logger.Warn().Err(err).Msg("failed to clean up multipart form") + } + }() + + for _, formheader := range r.MultipartForm.File["files"] { + // ensure filename is clean and only contains the basename + filename := filepath.Clean(formheader.Filename) + filename = filepath.Base(filename) + filePath := filepath.Join(m.config.RootDir, filename) + + formfile, err := formheader.Open() + if err != nil { + return utils.HttpBadRequest(). + WithInternalErr(err). + Msg("error opening formdata file") + } + defer formfile.Close() + + f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error opening file for writing") + } + defer f.Close() + + _, err = io.Copy(f, formfile) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error writing file") + } + } + + return nil +} diff --git a/internal/plugins/filetransfer/plugin.go b/internal/plugins/filetransfer/plugin.go new file mode 100644 index 00000000..a98672a1 --- /dev/null +++ b/internal/plugins/filetransfer/plugin.go @@ -0,0 +1,35 @@ +package filetransfer + +import ( + "github.com/demodesk/neko/pkg/types" +) + +type Plugin struct { + config *Config + manager *Manager +} + +func NewPlugin() *Plugin { + return &Plugin{ + config: &Config{}, + } +} + +func (p *Plugin) Name() string { + return PluginName +} + +func (p *Plugin) Config() types.PluginConfig { + return p.config +} + +func (p *Plugin) Start(m types.PluginManagers) error { + p.manager = NewManager(m.SessionManager, p.config) + m.ApiManager.AddRouter("/filetransfer", p.manager.Route) + m.WebSocketManager.AddHandler(p.manager.WebSocketHandler) + return p.manager.Start() +} + +func (p *Plugin) Shutdown() error { + return p.manager.Shutdown() +} diff --git a/internal/plugins/filetransfer/types.go b/internal/plugins/filetransfer/types.go new file mode 100644 index 00000000..994bbe6d --- /dev/null +++ b/internal/plugins/filetransfer/types.go @@ -0,0 +1,26 @@ +package filetransfer + +const PluginName = "filetransfer" + +const ( + FILETRANSFER_UPDATE = "filetransfer/update" +) + +type Message struct { + Enabled bool `json:"enabled"` + RootDir string `json:"root_dir"` + Files []Item `json:"files"` +} + +type ItemType string + +const ( + ItemTypeFile ItemType = "file" + ItemTypeDir ItemType = "dir" +) + +type Item struct { + Name string `json:"name"` + Type ItemType `json:"type"` + Size int64 `json:"size,omitempty"` +} diff --git a/internal/plugins/filetransfer/utils.go b/internal/plugins/filetransfer/utils.go new file mode 100644 index 00000000..c4c828ae --- /dev/null +++ b/internal/plugins/filetransfer/utils.go @@ -0,0 +1,32 @@ +package filetransfer + +import "os" + +func ListFiles(path string) ([]Item, error) { + items, err := os.ReadDir(path) + if err != nil { + return nil, err + } + + out := make([]Item, len(items)) + for i, item := range items { + var itemType ItemType + var size int64 = 0 + if item.IsDir() { + itemType = ItemTypeDir + } else { + itemType = ItemTypeFile + info, err := item.Info() + if err == nil { + size = info.Size() + } + } + out[i] = Item{ + Name: item.Name(), + Type: itemType, + Size: size, + } + } + + return out, nil +} diff --git a/internal/plugins/manager.go b/internal/plugins/manager.go index 79d8cc28..908b9aa6 100644 --- a/internal/plugins/manager.go +++ b/internal/plugins/manager.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/cobra" "github.com/demodesk/neko/internal/config" + "github.com/demodesk/neko/internal/plugins/filetransfer" "github.com/demodesk/neko/pkg/types" ) @@ -42,6 +43,9 @@ func New(config *config.Plugins) *ManagerCtx { manager.logger.Info().Msgf("loading finished, total %d plugins", manager.plugins.len()) } + // add built-in plugins + manager.plugins.addPlugin(filetransfer.NewPlugin()) + return manager }