session id to session.

This commit is contained in:
Miroslav Šedivý 2020-11-14 16:03:12 +01:00
parent fdf5839547
commit e5eaf5c60c
3 changed files with 12 additions and 18 deletions

View File

@ -42,7 +42,7 @@ type MessageHandlerCtx struct {
locked bool locked bool
} }
func (h *MessageHandlerCtx) Connected(id string, socket types.WebSocket) (bool, string) { func (h *MessageHandlerCtx) Connected(session types.Session, socket types.WebSocket) (bool, string) {
address := socket.Address() address := socket.Address()
if address != "" { if address != "" {
ok, banned := h.banned[address] ok, banned := h.banned[address]
@ -54,12 +54,9 @@ func (h *MessageHandlerCtx) Connected(id string, socket types.WebSocket) (bool,
h.logger.Debug().Msg("no remote address") h.logger.Debug().Msg("no remote address")
} }
if h.locked { if h.locked && !session.Admin(){
session, ok := h.sessions.Get(id) h.logger.Debug().Msg("server locked")
if !ok || !session.Admin() { return false, "locked"
h.logger.Debug().Msg("server locked")
return false, "locked"
}
} }
return true, "" return true, ""
@ -74,17 +71,12 @@ func (h *MessageHandlerCtx) Disconnected(id string) error {
return h.sessions.Destroy(id) return h.sessions.Destroy(id)
} }
func (h *MessageHandlerCtx) Message(id string, raw []byte) error { func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) error {
header := message.Message{} header := message.Message{}
if err := json.Unmarshal(raw, &header); err != nil { if err := json.Unmarshal(raw, &header); err != nil {
return err return err
} }
session, ok := h.sessions.Get(id)
if !ok {
return errors.Errorf("unknown session id %s", id)
}
switch header.Event { switch header.Event {
// Signal Events // Signal Events
case event.SIGNAL_ANSWER: case event.SIGNAL_ANSWER:

View File

@ -149,13 +149,13 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e
// } // }
socket := &WebSocketCtx{ socket := &WebSocketCtx{
id: session.ID(), session: session,
ws: ws, ws: ws,
address: ip, address: ip,
connection: connection, connection: connection,
} }
ok, reason := ws.handler.Connected(session.ID(), socket) ok, reason := ws.handler.Connected(session, socket)
if !ok { if !ok {
// TODO: Refactor // TODO: Refactor
if err = connection.WriteJSON(message.Disconnect{ if err = connection.WriteJSON(message.Disconnect{
@ -226,7 +226,7 @@ func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types.
Str("raw", string(raw)). Str("raw", string(raw)).
Msg("received message from client") Msg("received message from client")
if err := ws.handler.Message(session.ID(), raw); err != nil { if err := ws.handler.Message(session, raw); err != nil {
ws.logger.Error().Err(err).Msg("message handler has failed") ws.logger.Error().Err(err).Msg("message handler has failed")
} }
case <-cancel: case <-cancel:

View File

@ -6,10 +6,12 @@ import (
"sync" "sync"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"demodesk/neko/internal/types"
) )
type WebSocketCtx struct { type WebSocketCtx struct {
id string session types.Session
address string address string
ws *WebSocketManagerCtx ws *WebSocketManagerCtx
connection *websocket.Conn connection *websocket.Conn
@ -40,7 +42,7 @@ func (socket *WebSocketCtx) Send(v interface{}) error {
} }
socket.ws.logger.Debug(). socket.ws.logger.Debug().
Str("session", socket.id). Str("session", socket.session.ID()).
Str("address", socket.connection.RemoteAddr().String()). Str("address", socket.connection.RemoteAddr().String()).
Str("raw", string(raw)). Str("raw", string(raw)).
Msg("sending message to client") Msg("sending message to client")