add oidc unit tests

This commit is contained in:
yequari 2025-07-19 15:36:50 -07:00
parent f6e332b76a
commit 2759127cf9
2 changed files with 55 additions and 17 deletions

View File

@ -129,6 +129,8 @@ func TestUserSignup(t *testing.T) {
type OAuth2Mock struct { type OAuth2Mock struct {
Srv *testServer Srv *testServer
Priv *rsa.PrivateKey Priv *rsa.PrivateKey
Subject string
Email string
} }
func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
@ -140,12 +142,12 @@ func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.A
AccessToken: "AccessToken", AccessToken: "AccessToken",
Expiry: time.Now().Add(1 * time.Hour), Expiry: time.Now().Add(1 * time.Hour),
} }
m := make(map[string]interface{}) m := make(map[string]any)
var rawClaims = `{ var rawClaims = `{
"iss": "` + o.Srv.URL + `", "iss": "` + o.Srv.URL + `",
"aud": "my-client-id", "aud": "my-client-id",
"sub": "foo", "sub": "` + o.Subject + `",
"email": "foo@example.com", "email": "` + o.Email + `",
"email_verified": true, "email_verified": true,
"nonce": "nonce" "nonce": "nonce"
}` }`
@ -178,13 +180,15 @@ func TestUserOIDCCallback(t *testing.T) {
SkipExpiryCheck: true, SkipExpiryCheck: true,
} }
v := p.VerifierContext(ctx, cfg) v := p.VerifierContext(ctx, cfg)
oMock := &OAuth2Mock{
Srv: srv,
Priv: priv,
Subject: "foo",
}
app.config.oauth = applicationOauthConfig{ app.config.oauth = applicationOauthConfig{
ctx: context.Background(), ctx: context.Background(),
oidcConfig: cfg, oidcConfig: cfg,
config: &OAuth2Mock{ config: oMock,
Srv: srv,
Priv: priv,
},
provider: p, provider: p,
verifier: v, verifier: v,
} }
@ -192,6 +196,7 @@ func TestUserOIDCCallback(t *testing.T) {
const ( const (
validSubject = "goodSubject" validSubject = "goodSubject"
unknownSubject = "foo"
validUserId = 1 validUserId = 1
validEmail = "test@example.com" validEmail = "test@example.com"
validState = "goodState" validState = "goodState"
@ -205,16 +210,46 @@ func TestUserOIDCCallback(t *testing.T) {
wantCode int wantCode int
}{ }{
{ {
name: "Found Subject", name: "By Subject",
subject: validSubject, subject: validSubject,
email: "",
state: validState,
wantCode: http.StatusSeeOther,
},
{
name: "By Email",
subject: unknownSubject,
email: validEmail, email: validEmail,
state: validState, state: validState,
wantCode: http.StatusSeeOther, wantCode: http.StatusSeeOther,
}, },
{
name: "No User",
subject: unknownSubject,
email: "",
state: validState,
wantCode: http.StatusSeeOther,
},
{
name: "Invalid State",
subject: unknownSubject,
email: validEmail,
state: "",
wantCode: http.StatusInternalServerError,
},
{
name: "Unknown Subject & Email",
subject: unknownSubject,
email: "",
state: validState,
wantCode: http.StatusInternalServerError,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(*testing.T) { t.Run(tt.name, func(*testing.T) {
oMock.Subject = tt.subject
oMock.Email = tt.email
r, err := http.NewRequest("GET", ts.URL, nil) r, err := http.NewRequest("GET", ts.URL, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -223,6 +258,7 @@ func TestUserOIDCCallback(t *testing.T) {
q := r.URL.Query() q := r.URL.Query()
q.Add("state", tt.state) q.Add("state", tt.state)
r.URL.RawQuery = q.Encode() r.URL.RawQuery = q.Encode()
c := &http.Cookie{ c := &http.Cookie{
Name: "state", Name: "state",
Value: validState, Value: validState,

View File

@ -101,9 +101,11 @@ func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value st
func (m *UserModel) GetBySubject(subject string) (int64, error) { func (m *UserModel) GetBySubject(subject string) (int64, error) {
if subject == "goodSubject" { if subject == "goodSubject" {
return 1, nil return 1, nil
} } else if subject == "foo" {
return -1, models.ErrNoRecord return -1, models.ErrNoRecord
} }
return -1, errors.New("Unexpected Error")
}
func (m *UserModel) GetByEmail(email string) (int64, error) { func (m *UserModel) GetByEmail(email string) (int64, error) {
if email == "test@example.com" { if email == "test@example.com" {