diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go new file mode 100644 index 00000000..94f242ad --- /dev/null +++ b/pkg/auth/auth_test.go @@ -0,0 +1,356 @@ +package auth + +import ( + "fmt" + "net/http" + "reflect" + "testing" + + "gitlab.com/demodesk/neko/server/internal/config" + "gitlab.com/demodesk/neko/server/internal/session" + "gitlab.com/demodesk/neko/server/pkg/types" +) + +var i = 0 +var sessionManager = session.New(&config.Session{}) + +func rWithSession(profile types.MemberProfile) (*http.Request, types.Session, error) { + i++ + r := &http.Request{} + session, _, err := sessionManager.Create(fmt.Sprintf("id-%d", i), profile) + ctx := SetSession(r, session) + r = r.WithContext(ctx) + return r, session, err +} + +func TestSessionCtx(t *testing.T) { + r, session, err := rWithSession(types.MemberProfile{}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + sess, ok := GetSession(r) + if !ok { + t.Errorf("session not found") + return + } + + if !reflect.DeepEqual(sess, session) { + t.Errorf("sessions not equal") + return + } +} + +func TestAdminsOnly(t *testing.T) { + r1, _, err := rWithSession(types.MemberProfile{IsAdmin: false}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + r2, _, err := rWithSession(types.MemberProfile{IsAdmin: true}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + tests := []struct { + name string + r *http.Request + wantErr bool + }{ + { + name: "is not admin", + r: r1, + wantErr: true, + }, + { + name: "is admin", + r: r2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := AdminsOnly(nil, tt.r) + if (err != nil) != tt.wantErr { + t.Errorf("AdminsOnly() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestHostsOnly(t *testing.T) { + r1, _, err := rWithSession(types.MemberProfile{CanHost: true}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + r2, session, err := rWithSession(types.MemberProfile{CanHost: true}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + // r2 is host + sessionManager.SetHost(session) + + r3, _, err := rWithSession(types.MemberProfile{CanHost: false}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + tests := []struct { + name string + r *http.Request + wantErr bool + }{ + { + name: "is not hosting", + r: r1, + wantErr: true, + }, + { + name: "is hosting", + r: r2, + wantErr: false, + }, + { + name: "cannot host", + r: r3, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := HostsOnly(nil, tt.r) + if (err != nil) != tt.wantErr { + t.Errorf("HostsOnly() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestCanWatchOnly(t *testing.T) { + r1, _, err := rWithSession(types.MemberProfile{CanWatch: false}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + r2, _, err := rWithSession(types.MemberProfile{CanWatch: true}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + tests := []struct { + name string + r *http.Request + wantErr bool + }{ + { + name: "can not watch", + r: r1, + wantErr: true, + }, + { + name: "can watch", + r: r2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CanWatchOnly(nil, tt.r) + if (err != nil) != tt.wantErr { + t.Errorf("CanWatchOnly() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestCanHostOnly(t *testing.T) { + r1, _, err := rWithSession(types.MemberProfile{CanHost: false}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + r2, _, err := rWithSession(types.MemberProfile{CanHost: true}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + tests := []struct { + name string + r *http.Request + wantErr bool + privateMode bool + }{ + { + name: "can not host", + r: r1, + wantErr: true, + }, + { + name: "can host", + r: r2, + wantErr: false, + }, + { + name: "private mode enabled: can not host", + r: r1, + wantErr: true, + privateMode: true, + }, + { + name: "private mode enabled: can host", + r: r2, + wantErr: true, + privateMode: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + settings := sessionManager.Settings() + settings.PrivateMode = tt.privateMode + sessionManager.UpdateSettings(settings) + + _, err := CanHostOnly(nil, tt.r) + if (err != nil) != tt.wantErr { + t.Errorf("CanHostOnly() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestCanAccessClipboardOnly(t *testing.T) { + r1, _, err := rWithSession(types.MemberProfile{CanAccessClipboard: false}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + r2, _, err := rWithSession(types.MemberProfile{CanAccessClipboard: true}) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + tests := []struct { + name string + r *http.Request + wantErr bool + }{ + { + name: "can not access clipboard", + r: r1, + wantErr: true, + }, + { + name: "can access clipboard", + r: r2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CanAccessClipboardOnly(nil, tt.r) + if (err != nil) != tt.wantErr { + t.Errorf("CanAccessClipboardOnly() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestPluginsGenericOnly(t *testing.T) { + r1, _, err := rWithSession(types.MemberProfile{ + Plugins: map[string]any{ + "foo.bar": 1, + }, + }) + if err != nil { + t.Errorf("could not create session %s", err.Error()) + return + } + + t.Run("test if exists", func(t *testing.T) { + key := "foo.bar" + val := 1 + wantErr := false + + handler := PluginsGenericOnly(key, val) + _, err := handler(nil, r1) + if (err != nil) != wantErr { + t.Errorf("PluginsGenericOnly(%q, %v) error = %v, wantErr %v", key, val, err, wantErr) + return + } + }) + + t.Run("test when gets different value", func(t *testing.T) { + key := "foo.bar" + val := 2 + wantErr := true + + handler := PluginsGenericOnly(key, val) + _, err := handler(nil, r1) + if (err != nil) != wantErr { + t.Errorf("PluginsGenericOnly(%q, %v) error = %v, wantErr %v", key, val, err, wantErr) + return + } + }) + + t.Run("test when gets different type", func(t *testing.T) { + key := "foo.bar" + val := "1" + wantErr := true + + handler := PluginsGenericOnly(key, val) + _, err := handler(nil, r1) + if (err != nil) != wantErr { + t.Errorf("PluginsGenericOnly(%q, %v) error = %v, wantErr %v", key, val, err, wantErr) + return + } + }) + + t.Run("test if does not exists", func(t *testing.T) { + key := "foo.bar_not_extist" + val := 1 + wantErr := true + + handler := PluginsGenericOnly(key, val) + _, err := handler(nil, r1) + if (err != nil) != wantErr { + t.Errorf("PluginsGenericOnly(%q, %v) error = %v, wantErr %v", key, val, err, wantErr) + return + } + }) + + t.Run("test if session does not exists", func(t *testing.T) { + key := "foo.bar_not_extist" + val := 1 + wantErr := true + + handler := PluginsGenericOnly(key, val) + _, err := handler(nil, &http.Request{}) + if (err != nil) != wantErr { + t.Errorf("PluginsGenericOnly(%q, %v) error = %v, wantErr %v", key, val, err, wantErr) + return + } + }) +}