diff --git a/internal/member/manager.go b/internal/member/manager.go index 6315bb95..fd502dea 100644 --- a/internal/member/manager.go +++ b/internal/member/manager.go @@ -41,86 +41,91 @@ func New(sessions types.SessionManager, config *config.Member) *MemberManagerCtx } type MemberManagerCtx struct { - logger zerolog.Logger - sessions types.SessionManager - config *config.Member - mu sync.Mutex - provider types.MemberProvider + logger zerolog.Logger + sessions types.SessionManager + config *config.Member + providerMu sync.Mutex + provider types.MemberProvider + sessionMu sync.Mutex } func (manager *MemberManagerCtx) Connect() error { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.Connect() } func (manager *MemberManagerCtx) Disconnect() error { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.Disconnect() } func (manager *MemberManagerCtx) Authenticate(username string, password string) (string, types.MemberProfile, error) { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.Authenticate(username, password) } func (manager *MemberManagerCtx) Insert(username string, password string, profile types.MemberProfile) (string, error) { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.Insert(username, password, profile) } func (manager *MemberManagerCtx) Select(id string) (types.MemberProfile, error) { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.Select(id) } func (manager *MemberManagerCtx) SelectAll(limit int, offset int) (map[string]types.MemberProfile, error) { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.SelectAll(limit, offset) } func (manager *MemberManagerCtx) UpdateProfile(id string, profile types.MemberProfile) error { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() // update corresponding session, if exists + manager.sessionMu.Lock() if _, ok := manager.sessions.Get(id); ok { if err := manager.sessions.Update(id, profile); err != nil { manager.logger.Err(err).Msg("error while updating session") } } + manager.sessionMu.Unlock() return manager.provider.UpdateProfile(id, profile) } func (manager *MemberManagerCtx) UpdatePassword(id string, password string) error { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() return manager.provider.UpdatePassword(id, password) } func (manager *MemberManagerCtx) Delete(id string) error { - manager.mu.Lock() - defer manager.mu.Unlock() + manager.providerMu.Lock() + defer manager.providerMu.Unlock() // destroy corresponding session, if exists + manager.sessionMu.Lock() if _, ok := manager.sessions.Get(id); ok { if err := manager.sessions.Delete(id); err != nil { manager.logger.Err(err).Msg("error while deleting session") } } + manager.sessionMu.Unlock() return manager.provider.Delete(id) } @@ -130,15 +135,31 @@ func (manager *MemberManagerCtx) Delete(id string) error { // func (manager *MemberManagerCtx) Login(username string, password string) (types.Session, string, error) { + manager.sessionMu.Lock() + defer manager.sessionMu.Unlock() + id, profile, err := manager.provider.Authenticate(username, password) if err != nil { return nil, "", err } + session, ok := manager.sessions.Get(id) + if ok { + if session.State().IsConnected { + return nil, "", fmt.Errorf("session is already connected") + } + + // delete existing session + manager.sessions.Delete(id) + } + return manager.sessions.Create(id, profile) } func (manager *MemberManagerCtx) Logout(id string) error { + manager.sessionMu.Lock() + defer manager.sessionMu.Unlock() + if _, ok := manager.sessions.Get(id); !ok { return fmt.Errorf("session not found") }