fixes #14
This commit is contained in:
@ -3,13 +3,13 @@ package websocket
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) adminLock(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminLock(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -33,8 +33,8 @@ func (h *MessageHandler) adminLock(id string, session *session.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminUnlock(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminUnlock(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -58,8 +58,8 @@ func (h *MessageHandler) adminUnlock(id string, session *session.Session) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminControl(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminControl(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -73,7 +73,7 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_CONTROL,
|
||||
ID: id,
|
||||
Target: host.ID,
|
||||
Target: host.ID(),
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_CONTROL)
|
||||
return err
|
||||
@ -92,8 +92,8 @@ func (h *MessageHandler) adminControl(id string, session *session.Session) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminRelease(id string, session *session.Session) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminRelease(id string, session types.Session) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -107,7 +107,7 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_RELEASE,
|
||||
ID: id,
|
||||
Target: host.ID,
|
||||
Target: host.ID(),
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_RELEASE)
|
||||
return err
|
||||
@ -126,8 +126,8 @@ func (h *MessageHandler) adminRelease(id string, session *session.Session) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminGive(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminGive(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -154,8 +154,8 @@ func (h *MessageHandler) adminGive(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminMute(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminMute(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -166,17 +166,17 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Admin {
|
||||
if target.Admin() {
|
||||
h.logger.Debug().Msg("target is an admin, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
target.Muted = true
|
||||
target.SetMuted(true)
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_MUTE,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
||||
@ -186,8 +186,8 @@ func (h *MessageHandler) adminMute(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminUnmute(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminUnmute(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -198,12 +198,12 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
|
||||
return nil
|
||||
}
|
||||
|
||||
target.Muted = false
|
||||
target.SetMuted(false)
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_UNMUTE,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_UNMUTE)
|
||||
@ -213,8 +213,8 @@ func (h *MessageHandler) adminUnmute(id string, session *session.Session, payloa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminKick(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminKick(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -225,22 +225,19 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Admin {
|
||||
if target.Admin() {
|
||||
h.logger.Debug().Msg("target is an admin, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := target.Kick(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "You have been kicked",
|
||||
}); err != nil {
|
||||
if err := target.Kick("You have been kicked"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_KICK,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, []string{payload.ID}); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_KICK)
|
||||
@ -250,8 +247,8 @@ func (h *MessageHandler) adminKick(id string, session *session.Session, payload
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) adminBan(id string, session *session.Session, payload *message.Admin) error {
|
||||
if !session.Admin {
|
||||
func (h *MessageHandler) adminBan(id string, session types.Session, payload *message.Admin) error {
|
||||
if !session.Admin() {
|
||||
h.logger.Debug().Msg("user not admin")
|
||||
return nil
|
||||
}
|
||||
@ -262,12 +259,12 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Admin {
|
||||
if target.Admin() {
|
||||
h.logger.Debug().Msg("target is an admin, baling")
|
||||
return nil
|
||||
}
|
||||
|
||||
remote := target.RemoteAddr()
|
||||
remote := target.Address()
|
||||
if remote == nil {
|
||||
h.logger.Debug().Msg("no remote address, baling")
|
||||
return nil
|
||||
@ -283,17 +280,14 @@ func (h *MessageHandler) adminBan(id string, session *session.Session, payload *
|
||||
|
||||
h.banned[address[0]] = true
|
||||
|
||||
if err := target.Kick(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "You have been banned",
|
||||
}); err != nil {
|
||||
if err := target.Kick("You have been banned"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.sessions.Brodcast(
|
||||
message.AdminTarget{
|
||||
Event: event.ADMIN_BAN,
|
||||
Target: target.ID,
|
||||
Target: target.ID(),
|
||||
ID: id,
|
||||
}, []string{payload.ID}); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.ADMIN_BAN)
|
||||
|
@ -1,13 +1,13 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) chat(id string, session *session.Session, payload *message.ChatRecieve) error {
|
||||
if session.Muted {
|
||||
func (h *MessageHandler) chat(id string, session types.Session, payload *message.ChatRecieve) error {
|
||||
if session.Muted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -23,8 +23,8 @@ func (h *MessageHandler) chat(id string, session *session.Session, payload *mess
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) chatEmote(id string, session *session.Session, payload *message.EmoteRecieve) error {
|
||||
if session.Muted {
|
||||
func (h *MessageHandler) chatEmote(id string, session types.Session, payload *message.EmoteRecieve) error {
|
||||
if session.Muted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) controlRelease(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) controlRelease(id string, session types.Session) error {
|
||||
|
||||
// check if session is host
|
||||
if !h.sessions.IsHost(id) {
|
||||
@ -31,7 +31,7 @@ func (h *MessageHandler) controlRelease(id string, session *session.Session) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) controlRequest(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) controlRequest(id string, session types.Session) error {
|
||||
// check for host
|
||||
if !h.sessions.HasHost() {
|
||||
// set host
|
||||
@ -57,7 +57,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
||||
// tell session there is a host
|
||||
if err := session.Send(message.Control{
|
||||
Event: event.CONTROL_REQUEST,
|
||||
ID: host.ID,
|
||||
ID: host.ID(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Err(err).Str("id", id).Msgf("sending event %s has failed", event.CONTROL_REQUEST)
|
||||
return err
|
||||
@ -68,7 +68,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
||||
Event: event.CONTROL_REQUESTING,
|
||||
ID: id,
|
||||
}); err != nil {
|
||||
h.logger.Warn().Err(err).Str("id", host.ID).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
|
||||
h.logger.Warn().Err(err).Str("id", host.ID()).Msgf("sending event %s has failed", event.CONTROL_REQUESTING)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -76,7 +76,7 @@ func (h *MessageHandler) controlRequest(id string, session *session.Session) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) controlGive(id string, session *session.Session, payload *message.Control) error {
|
||||
func (h *MessageHandler) controlGive(id string, session types.Session, payload *message.Control) error {
|
||||
// check if session is host
|
||||
if !h.sessions.IsHost(id) {
|
||||
h.logger.Debug().Str("id", id).Msg("is not the host")
|
||||
|
@ -1,235 +1,142 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/config"
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
"n.eko.moe/neko/internal/webrtc"
|
||||
)
|
||||
|
||||
func New(sessions *session.SessionManager, webrtc *webrtc.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||
logger := log.With().Str("module", "websocket").Logger()
|
||||
|
||||
return &WebSocketHandler{
|
||||
logger: logger,
|
||||
conf: conf,
|
||||
sessions: sessions,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
handler: &MessageHandler{
|
||||
logger: logger.With().Str("subsystem", "handler").Logger(),
|
||||
sessions: sessions,
|
||||
webrtc: webrtc,
|
||||
banned: make(map[string]bool),
|
||||
locked: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
const pingPeriod = 60 * time.Second
|
||||
|
||||
type WebSocketHandler struct {
|
||||
type MessageHandler struct {
|
||||
logger zerolog.Logger
|
||||
upgrader websocket.Upgrader
|
||||
handler *MessageHandler
|
||||
conf *config.WebSocket
|
||||
sessions *session.SessionManager
|
||||
shutdown chan bool
|
||||
sessions types.SessionManager
|
||||
webrtc types.WebRTCManager
|
||||
banned map[string]bool
|
||||
locked bool
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Start() error {
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ws.logger.Info().Msg("shutdown")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ws.shutdown:
|
||||
return
|
||||
}
|
||||
func (h *MessageHandler) Connected(id string, socket *WebSocket) (bool, string, error) {
|
||||
address := socket.Address()
|
||||
if address == nil {
|
||||
h.logger.Debug().Msg("no remote address, baling")
|
||||
} else {
|
||||
ok, banned := h.banned[*address]
|
||||
if ok && banned {
|
||||
h.logger.Debug().Str("address", *address).Msg("banned")
|
||||
return false, "This IP has been banned", nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ws.sessions.OnCreated(func(id string, session *session.Session) {
|
||||
if err := ws.handler.SessionCreated(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session created")
|
||||
}
|
||||
})
|
||||
if h.locked {
|
||||
h.logger.Debug().Msg("server locked")
|
||||
return false, "Server is currently locked", nil
|
||||
}
|
||||
|
||||
ws.sessions.OnConnected(func(id string, session *session.Session) {
|
||||
if err := ws.handler.SessionConnected(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session connected")
|
||||
}
|
||||
})
|
||||
|
||||
ws.sessions.OnDestroy(func(id string) {
|
||||
if err := ws.handler.SessionDestroyed(id); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Shutdown() error {
|
||||
ws.shutdown <- true
|
||||
return nil
|
||||
func (h *MessageHandler) Disconnected(id string) error {
|
||||
return h.sessions.Destroy(id)
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
||||
ws.logger.Debug().Msg("attempting to upgrade connection")
|
||||
|
||||
socket, err := ws.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
||||
return err
|
||||
}
|
||||
|
||||
id, admin, err := ws.authenticate(r)
|
||||
if err != nil {
|
||||
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
|
||||
|
||||
if err = socket.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "invalid password",
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = socket.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, reason, err := ws.handler.SocketConnected(id, socket)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("connection failed")
|
||||
func (h *MessageHandler) Message(id string, raw []byte) error {
|
||||
header := message.Message{}
|
||||
if err := json.Unmarshal(raw, &header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := h.sessions.Get(id)
|
||||
if !ok {
|
||||
if err = socket.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: reason,
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = socket.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
errors.Errorf("unknown session id %s", id)
|
||||
}
|
||||
|
||||
ws.sessions.New(id, admin, socket)
|
||||
switch header.Event {
|
||||
// Signal Events
|
||||
case event.SIGNAL_PROVIDE:
|
||||
payload := &message.Signal{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.createPeer(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
// Identity Events
|
||||
case event.IDENTITY_DETAILS:
|
||||
payload := &message.IdentityDetails{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.identityDetails(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", socket.RemoteAddr().String()).
|
||||
Msg("new connection created")
|
||||
// Control Events
|
||||
case event.CONTROL_RELEASE:
|
||||
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_REQUEST:
|
||||
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_GIVE:
|
||||
payload := &message.Control{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.controlGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
defer func() {
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", socket.RemoteAddr().String()).
|
||||
Msg("session ended")
|
||||
}()
|
||||
// Chat Events
|
||||
case event.CHAT_MESSAGE:
|
||||
payload := &message.ChatRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chat(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.CHAT_EMOTE:
|
||||
payload := &message.EmoteRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chatEmote(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
ws.handle(socket, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
||||
id, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
passwords, ok := r.URL.Query()["password"]
|
||||
if !ok || len(passwords[0]) < 1 {
|
||||
return "", false, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.AdminPassword {
|
||||
return id, true, nil
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.Password {
|
||||
return id, false, nil
|
||||
}
|
||||
|
||||
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) handle(socket *websocket.Conn, id string) {
|
||||
bytes := make(chan []byte)
|
||||
cancel := make(chan struct{})
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
ws.logger.Debug().Str("address", socket.RemoteAddr().String()).Msg("handle socket ending")
|
||||
ws.handler.SocketDisconnected(id)
|
||||
}()
|
||||
|
||||
for {
|
||||
_, raw, err := socket.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
ws.logger.Warn().Err(err).Msg("read message error")
|
||||
} else {
|
||||
ws.logger.Debug().Err(err).Msg("read message error")
|
||||
}
|
||||
close(cancel)
|
||||
break
|
||||
}
|
||||
bytes <- raw
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case raw := <-bytes:
|
||||
ws.logger.Debug().
|
||||
Str("session", id).
|
||||
Str("raw", string(raw)).
|
||||
Msg("recieved message from client")
|
||||
if err := ws.handler.Message(id, raw); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("message handler has failed")
|
||||
}
|
||||
case <-cancel:
|
||||
return
|
||||
case _ = <-ticker.C:
|
||||
if err := socket.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Admin Events
|
||||
case event.ADMIN_LOCK:
|
||||
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_UNLOCK:
|
||||
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_CONTROL:
|
||||
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_RELEASE:
|
||||
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_GIVE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_BAN:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminBan(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_KICK:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminKick(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_MUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminMute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_UNMUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminUnmute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
default:
|
||||
return errors.Errorf("unknown message event %s", header.Event)
|
||||
}
|
||||
}
|
||||
|
@ -1,13 +1,34 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) identityDetails(id string, session *session.Session, payload *message.IdentityDetails) error {
|
||||
if _, err := h.sessions.SetName(id, payload.Username); err != nil {
|
||||
func (h *MessageHandler) identityDetails(id string, session types.Session, payload *message.IdentityDetails) error {
|
||||
if err := session.SetName(payload.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) createPeer(id string, session types.Session, payload *message.Signal) error {
|
||||
sdp, peer, err := h.webrtc.CreatePeer(id, payload.SDP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.SetPeer(peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.Send(message.Signal{
|
||||
Event: event.SIGNAL_ANSWER,
|
||||
SDP: sdp,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,149 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
"n.eko.moe/neko/internal/webrtc"
|
||||
)
|
||||
|
||||
type MessageHandler struct {
|
||||
logger zerolog.Logger
|
||||
sessions *session.SessionManager
|
||||
webrtc *webrtc.WebRTCManager
|
||||
banned map[string]bool
|
||||
locked bool
|
||||
}
|
||||
|
||||
func (h *MessageHandler) SocketConnected(id string, socket *websocket.Conn) (bool, string, error) {
|
||||
remote := socket.RemoteAddr().String()
|
||||
if remote != "" {
|
||||
address := strings.SplitN(remote, ":", -1)
|
||||
if len(address[0]) < 1 {
|
||||
h.logger.Debug().Str("address", remote).Msg("no remote address, baling")
|
||||
} else {
|
||||
|
||||
ok, banned := h.banned[address[0]]
|
||||
if ok && banned {
|
||||
h.logger.Debug().Str("address", remote).Msg("banned")
|
||||
return false, "This IP has been banned", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.locked {
|
||||
h.logger.Debug().Str("address", remote).Msg("locked")
|
||||
return false, "Server is currently locked", nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) SocketDisconnected(id string) error {
|
||||
return h.sessions.Destroy(id)
|
||||
}
|
||||
|
||||
func (h *MessageHandler) Message(id string, raw []byte) error {
|
||||
header := message.Message{}
|
||||
if err := json.Unmarshal(raw, &header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := h.sessions.Get(id)
|
||||
if !ok {
|
||||
errors.Errorf("unknown session id %s", id)
|
||||
}
|
||||
|
||||
switch header.Event {
|
||||
// Signal Events
|
||||
case event.SIGNAL_PROVIDE:
|
||||
payload := message.Signal{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(&payload, raw, func() error {
|
||||
return h.webrtc.CreatePeer(id, payload.SDP)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Identity Events
|
||||
case event.IDENTITY_DETAILS:
|
||||
payload := &message.IdentityDetails{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.identityDetails(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Control Events
|
||||
case event.CONTROL_RELEASE:
|
||||
return errors.Wrapf(h.controlRelease(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_REQUEST:
|
||||
return errors.Wrapf(h.controlRequest(id, session), "%s failed", header.Event)
|
||||
case event.CONTROL_GIVE:
|
||||
payload := &message.Control{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.controlGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Chat Events
|
||||
case event.CHAT_MESSAGE:
|
||||
payload := &message.ChatRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chat(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.CHAT_EMOTE:
|
||||
payload := &message.EmoteRecieve{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.chatEmote(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
|
||||
// Admin Events
|
||||
case event.ADMIN_LOCK:
|
||||
return errors.Wrapf(h.adminLock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_UNLOCK:
|
||||
return errors.Wrapf(h.adminUnlock(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_CONTROL:
|
||||
return errors.Wrapf(h.adminControl(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_RELEASE:
|
||||
return errors.Wrapf(h.adminRelease(id, session), "%s failed", header.Event)
|
||||
case event.ADMIN_GIVE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminGive(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_BAN:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminBan(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_KICK:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminKick(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_MUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminMute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
case event.ADMIN_UNMUTE:
|
||||
payload := &message.Admin{}
|
||||
return errors.Wrapf(
|
||||
utils.Unmarshal(payload, raw, func() error {
|
||||
return h.adminUnmute(id, session, payload)
|
||||
}), "%s failed", header.Event)
|
||||
default:
|
||||
return errors.Errorf("unknown message event %s", header.Event)
|
||||
}
|
||||
}
|
@ -1,12 +1,12 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"n.eko.moe/neko/internal/event"
|
||||
"n.eko.moe/neko/internal/message"
|
||||
"n.eko.moe/neko/internal/session"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
)
|
||||
|
||||
func (h *MessageHandler) SessionCreated(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) SessionCreated(id string, session types.Session) error {
|
||||
if err := session.Send(message.Identity{
|
||||
Event: event.IDENTITY_PROVIDE,
|
||||
ID: id,
|
||||
@ -17,11 +17,11 @@ func (h *MessageHandler) SessionCreated(id string, session *session.Session) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MessageHandler) SessionConnected(id string, session *session.Session) error {
|
||||
func (h *MessageHandler) SessionConnected(id string, session types.Session) error {
|
||||
// send list of members to session
|
||||
if err := session.Send(message.MembersList{
|
||||
Event: event.MEMBER_LIST,
|
||||
Memebers: h.sessions.GetConnected(),
|
||||
Memebers: h.sessions.Members(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.MEMBER_LIST)
|
||||
return err
|
||||
@ -32,7 +32,7 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
|
||||
if ok {
|
||||
if err := session.Send(message.Control{
|
||||
Event: event.CONTROL_LOCKED,
|
||||
ID: host.ID,
|
||||
ID: host.ID(),
|
||||
}); err != nil {
|
||||
h.logger.Warn().Str("id", id).Err(err).Msgf("sending event %s has failed", event.CONTROL_LOCKED)
|
||||
return err
|
||||
@ -42,8 +42,8 @@ func (h *MessageHandler) SessionConnected(id string, session *session.Session) e
|
||||
// let everyone know there is a new session
|
||||
if err := h.sessions.Brodcast(
|
||||
message.Member{
|
||||
Event: event.MEMBER_CONNECTED,
|
||||
Session: session,
|
||||
Event: event.MEMBER_CONNECTED,
|
||||
Member: session.Member(),
|
||||
}, nil); err != nil {
|
||||
h.logger.Warn().Err(err).Msgf("brodcasting event %s has failed", event.CONTROL_RELEASE)
|
||||
return err
|
||||
|
37
server/internal/websocket/socket.go
Normal file
37
server/internal/websocket/socket.go
Normal file
@ -0,0 +1,37 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WebSocket struct {
|
||||
id string
|
||||
connection *websocket.Conn
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Address() *string {
|
||||
remote := socket.connection.RemoteAddr()
|
||||
address := strings.SplitN(remote.String(), ":", -1)
|
||||
if len(address[0]) < 1 {
|
||||
return nil
|
||||
}
|
||||
return &address[0]
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Send(v interface{}) error {
|
||||
if socket.connection == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return socket.connection.WriteJSON(v)
|
||||
}
|
||||
|
||||
func (socket *WebSocket) Destroy() error {
|
||||
if socket.connection == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return socket.connection.Close()
|
||||
}
|
239
server/internal/websocket/websocket.go
Normal file
239
server/internal/websocket/websocket.go
Normal file
@ -0,0 +1,239 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"n.eko.moe/neko/internal/config"
|
||||
"n.eko.moe/neko/internal/types"
|
||||
"n.eko.moe/neko/internal/types/event"
|
||||
"n.eko.moe/neko/internal/types/message"
|
||||
"n.eko.moe/neko/internal/utils"
|
||||
)
|
||||
|
||||
func New(sessions types.SessionManager, webrtc types.WebRTCManager, conf *config.WebSocket) *WebSocketHandler {
|
||||
logger := log.With().Str("module", "websocket").Logger()
|
||||
|
||||
return &WebSocketHandler{
|
||||
logger: logger,
|
||||
conf: conf,
|
||||
sessions: sessions,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
handler: &MessageHandler{
|
||||
logger: logger.With().Str("subsystem", "handler").Logger(),
|
||||
sessions: sessions,
|
||||
webrtc: webrtc,
|
||||
banned: make(map[string]bool),
|
||||
locked: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
const pingPeriod = 60 * time.Second
|
||||
|
||||
type WebSocketHandler struct {
|
||||
logger zerolog.Logger
|
||||
upgrader websocket.Upgrader
|
||||
sessions types.SessionManager
|
||||
conf *config.WebSocket
|
||||
handler *MessageHandler
|
||||
shutdown chan bool
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Start() error {
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ws.logger.Info().Msg("shutdown")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ws.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ws.sessions.OnCreated(func(id string, session types.Session) {
|
||||
if err := ws.handler.SessionCreated(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session created with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session created")
|
||||
}
|
||||
})
|
||||
|
||||
ws.sessions.OnConnected(func(id string, session types.Session) {
|
||||
if err := ws.handler.SessionConnected(id, session); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session connected with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session connected")
|
||||
}
|
||||
})
|
||||
|
||||
ws.sessions.OnDestroy(func(id string) {
|
||||
if err := ws.handler.SessionDestroyed(id); err != nil {
|
||||
ws.logger.Warn().Str("id", id).Err(err).Msg("session destroyed with and error")
|
||||
} else {
|
||||
ws.logger.Debug().Str("id", id).Msg("session destroyed")
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Shutdown() error {
|
||||
ws.shutdown <- true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) Upgrade(w http.ResponseWriter, r *http.Request) error {
|
||||
ws.logger.Debug().Msg("attempting to upgrade connection")
|
||||
|
||||
connection, err := ws.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to upgrade connection")
|
||||
return err
|
||||
}
|
||||
|
||||
id, admin, err := ws.authenticate(r)
|
||||
if err != nil {
|
||||
ws.logger.Warn().Err(err).Msg("authenticatetion failed")
|
||||
|
||||
if err = connection.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: "invalid password",
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = connection.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
socket := &WebSocket{
|
||||
id: id,
|
||||
connection: connection,
|
||||
}
|
||||
|
||||
ok, reason, err := ws.handler.Connected(id, socket)
|
||||
if err != nil {
|
||||
ws.logger.Error().Err(err).Msg("connection failed")
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err = connection.WriteJSON(message.Disconnect{
|
||||
Event: event.SYSTEM_DISCONNECT,
|
||||
Message: reason,
|
||||
}); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("failed to send disconnect")
|
||||
}
|
||||
|
||||
if err = connection.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ws.sessions.New(id, admin, socket)
|
||||
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Msg("new connection created")
|
||||
|
||||
defer func() {
|
||||
ws.logger.
|
||||
Debug().
|
||||
Str("session", id).
|
||||
Str("address", connection.RemoteAddr().String()).
|
||||
Msg("session ended")
|
||||
}()
|
||||
|
||||
ws.handle(connection, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) authenticate(r *http.Request) (string, bool, error) {
|
||||
id, err := utils.NewUID(32)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
passwords, ok := r.URL.Query()["password"]
|
||||
if !ok || len(passwords[0]) < 1 {
|
||||
return "", false, fmt.Errorf("no password provided")
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.AdminPassword {
|
||||
return id, true, nil
|
||||
}
|
||||
|
||||
if passwords[0] == ws.conf.Password {
|
||||
return id, false, nil
|
||||
}
|
||||
|
||||
return "", false, fmt.Errorf("invalid password: %s", passwords[0])
|
||||
}
|
||||
|
||||
func (ws *WebSocketHandler) handle(connection *websocket.Conn, id string) {
|
||||
bytes := make(chan []byte)
|
||||
cancel := make(chan struct{})
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
ws.logger.Debug().Str("address", connection.RemoteAddr().String()).Msg("handle socket ending")
|
||||
ws.handler.Disconnected(id)
|
||||
}()
|
||||
|
||||
for {
|
||||
_, raw, err := connection.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
ws.logger.Warn().Err(err).Msg("read message error")
|
||||
} else {
|
||||
ws.logger.Debug().Err(err).Msg("read message error")
|
||||
}
|
||||
close(cancel)
|
||||
break
|
||||
}
|
||||
bytes <- raw
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case raw := <-bytes:
|
||||
ws.logger.Debug().
|
||||
Str("session", id).
|
||||
Str("raw", string(raw)).
|
||||
Msg("recieved message from client")
|
||||
if err := ws.handler.Message(id, raw); err != nil {
|
||||
ws.logger.Error().Err(err).Msg("message handler has failed")
|
||||
}
|
||||
case <-cancel:
|
||||
return
|
||||
case _ = <-ticker.C:
|
||||
if err := connection.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user