diff --git a/cmd/web/handlers_user_test.go b/cmd/web/handlers_user_test.go index 32adfe9..4624a20 100644 --- a/cmd/web/handlers_user_test.go +++ b/cmd/web/handlers_user_test.go @@ -127,8 +127,10 @@ func TestUserSignup(t *testing.T) { } type OAuth2Mock struct { - Srv *testServer - Priv *rsa.PrivateKey + Srv *testServer + Priv *rsa.PrivateKey + Subject string + Email string } func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { @@ -140,12 +142,12 @@ func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.A AccessToken: "AccessToken", Expiry: time.Now().Add(1 * time.Hour), } - m := make(map[string]interface{}) + m := make(map[string]any) var rawClaims = `{ "iss": "` + o.Srv.URL + `", "aud": "my-client-id", - "sub": "foo", - "email": "foo@example.com", + "sub": "` + o.Subject + `", + "email": "` + o.Email + `", "email_verified": true, "nonce": "nonce" }` @@ -178,23 +180,26 @@ func TestUserOIDCCallback(t *testing.T) { SkipExpiryCheck: true, } v := p.VerifierContext(ctx, cfg) + oMock := &OAuth2Mock{ + Srv: srv, + Priv: priv, + Subject: "foo", + } app.config.oauth = applicationOauthConfig{ ctx: context.Background(), oidcConfig: cfg, - config: &OAuth2Mock{ - Srv: srv, - Priv: priv, - }, - provider: p, - verifier: v, + config: oMock, + provider: p, + verifier: v, } app.config.oauthEnabled = true const ( - validSubject = "goodSubject" - validUserId = 1 - validEmail = "test@example.com" - validState = "goodState" + validSubject = "goodSubject" + unknownSubject = "foo" + validUserId = 1 + validEmail = "test@example.com" + validState = "goodState" ) tests := []struct { @@ -205,16 +210,46 @@ func TestUserOIDCCallback(t *testing.T) { wantCode int }{ { - name: "Found Subject", + name: "By Subject", subject: validSubject, + email: "", + state: validState, + wantCode: http.StatusSeeOther, + }, + { + name: "By Email", + subject: unknownSubject, email: validEmail, state: validState, wantCode: http.StatusSeeOther, }, + { + name: "No User", + subject: unknownSubject, + email: "", + state: validState, + wantCode: http.StatusSeeOther, + }, + { + name: "Invalid State", + subject: unknownSubject, + email: validEmail, + state: "", + wantCode: http.StatusInternalServerError, + }, + { + name: "Unknown Subject & Email", + subject: unknownSubject, + email: "", + state: validState, + wantCode: http.StatusInternalServerError, + }, } for _, tt := range tests { t.Run(tt.name, func(*testing.T) { + oMock.Subject = tt.subject + oMock.Email = tt.email r, err := http.NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) @@ -223,6 +258,7 @@ func TestUserOIDCCallback(t *testing.T) { q := r.URL.Query() q.Add("state", tt.state) r.URL.RawQuery = q.Encode() + c := &http.Cookie{ Name: "state", Value: validState, diff --git a/internal/models/mocks/user.go b/internal/models/mocks/user.go index 931e841..0ca4d30 100644 --- a/internal/models/mocks/user.go +++ b/internal/models/mocks/user.go @@ -101,8 +101,10 @@ func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value st func (m *UserModel) GetBySubject(subject string) (int64, error) { if subject == "goodSubject" { return 1, nil + } else if subject == "foo" { + return -1, models.ErrNoRecord } - return -1, models.ErrNoRecord + return -1, errors.New("Unexpected Error") } func (m *UserModel) GetByEmail(email string) (int64, error) {