unit testing for oidc and fix test cases for remote comments

This commit is contained in:
yequari 2025-07-17 17:02:35 -07:00
parent db1d4e1ad2
commit f6e332b76a
8 changed files with 289 additions and 65 deletions

View File

@ -1,9 +1,12 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"testing" "testing"
"git.32bit.cafe/32bitcafe/guestbook/internal/assert" "git.32bit.cafe/32bitcafe/guestbook/internal/assert"
@ -150,9 +153,6 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
ts := newTestServer(t, app.routes()) ts := newTestServer(t, app.routes())
defer ts.Close() defer ts.Close()
_, _, body := ts.get(t, fmt.Sprintf("/websites/%s/guestbook", shortIdToSlug(1)))
validCSRFToken := extractCSRFToken(t, body)
const ( const (
validAuthorName = "John Test" validAuthorName = "John Test"
validAuthorEmail = "test@example.com" validAuthorEmail = "test@example.com"
@ -166,8 +166,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
authorEmail string authorEmail string
authorSite string authorSite string
content string content string
csrfToken string
wantCode int wantCode int
wantBody string
}{ }{
{ {
name: "Valid input", name: "Valid input",
@ -175,8 +175,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
authorEmail: validAuthorEmail, authorEmail: validAuthorEmail,
authorSite: validAuthorSite, authorSite: validAuthorSite,
content: validContent, content: validContent,
csrfToken: validCSRFToken, wantCode: http.StatusOK,
wantCode: http.StatusSeeOther, wantBody: "Comment successfully posted",
}, },
{ {
name: "Blank name", name: "Blank name",
@ -184,8 +184,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
authorEmail: validAuthorEmail, authorEmail: validAuthorEmail,
authorSite: validAuthorSite, authorSite: validAuthorSite,
content: validContent, content: validContent,
csrfToken: validCSRFToken, wantCode: http.StatusOK,
wantCode: http.StatusUnprocessableEntity, wantBody: "An error occurred",
}, },
{ {
name: "Blank email", name: "Blank email",
@ -193,8 +193,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
authorEmail: "", authorEmail: "",
authorSite: validAuthorSite, authorSite: validAuthorSite,
content: validContent, content: validContent,
csrfToken: validCSRFToken, wantCode: http.StatusOK,
wantCode: http.StatusSeeOther, wantBody: "Comment successfully posted",
}, },
{ {
name: "Blank site", name: "Blank site",
@ -202,8 +202,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
authorEmail: validAuthorEmail, authorEmail: validAuthorEmail,
authorSite: "", authorSite: "",
content: validContent, content: validContent,
csrfToken: validCSRFToken, wantCode: http.StatusOK,
wantCode: http.StatusSeeOther, wantBody: "Comment successfully posted",
}, },
{ {
name: "Blank content", name: "Blank content",
@ -211,21 +211,39 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
authorEmail: validAuthorEmail, authorEmail: validAuthorEmail,
authorSite: validAuthorSite, authorSite: validAuthorSite,
content: "", content: "",
csrfToken: validCSRFToken, wantCode: http.StatusOK,
wantCode: http.StatusUnprocessableEntity, wantBody: "An error occurred",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
form := url.Values{} form := url.Values{}
form.Add("authorname", tt.authorName) form.Add("authorname", tt.authorName)
form.Add("authoremail", tt.authorEmail) form.Add("authoremail", tt.authorEmail)
form.Add("authorsite", tt.authorSite) form.Add("authorsite", tt.authorSite)
form.Add("content", tt.content) form.Add("content", tt.content)
form.Add("csrf_token", tt.csrfToken) r, err := http.NewRequest("POST", ts.URL, strings.NewReader(form.Encode()))
code, _, body := ts.postForm(t, fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1)), form) if err != nil {
assert.Equal(t, code, tt.wantCode) t.Fatal(err)
assert.Equal(t, body, body) }
r.URL.Path = fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.Header.Set("Origin", "http://example.com")
resp, err := ts.Client().Do(r)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
body = bytes.TrimSpace(body)
assert.Equal(t, resp.StatusCode, tt.wantCode)
assert.StringContains(t, string(body), tt.wantBody)
}) })
} }
} }

View File

@ -173,13 +173,35 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
app.serverError(w, r, err) app.serverError(w, r, err)
return return
} }
id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject)
if err != nil { // search for user by subject
id, err := app.users.GetBySubject(t.Subject)
if err != nil && errors.Is(err, models.ErrNoRecord) {
// if no user is found, check if they have signed up by email already
id, err = app.users.GetByEmail(t.Email)
if err == nil {
// if user is found by email, update subject to match them in the first step next time
err2 := app.users.UpdateSubject(id, t.Subject)
if err2 != nil {
app.serverError(w, r, err2)
return
}
}
} else if err != nil {
app.serverError(w, r, err)
return
}
if err != nil && errors.Is(err, models.ErrNoRecord) {
// if no user is found by subject or email, create a new user
id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings()) id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
if err != nil { if err != nil {
app.serverError(w, r, err) app.serverError(w, r, err)
} }
} else if err != nil {
app.serverError(w, r, err)
return
} }
app.sessionManager.Put(r.Context(), "authenticatedUserId", id) app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
http.Redirect(w, r, "/", http.StatusSeeOther) http.Redirect(w, r, "/", http.StatusSeeOther)
} }

View File

@ -1,11 +1,17 @@
package main package main
import ( import (
"context"
"crypto/rsa"
"net/http" "net/http"
"net/url" "net/url"
"testing" "testing"
"time"
"git.32bit.cafe/32bitcafe/guestbook/internal/assert" "git.32bit.cafe/32bitcafe/guestbook/internal/assert"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-oidc/v3/oidc/oidctest"
"golang.org/x/oauth2"
) )
func TestUserSignup(t *testing.T) { func TestUserSignup(t *testing.T) {
@ -119,3 +125,126 @@ func TestUserSignup(t *testing.T) {
} }
} }
type OAuth2Mock struct {
Srv *testServer
Priv *rsa.PrivateKey
}
func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return ""
}
func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
tkn := oauth2.Token{
AccessToken: "AccessToken",
Expiry: time.Now().Add(1 * time.Hour),
}
m := make(map[string]interface{})
var rawClaims = `{
"iss": "` + o.Srv.URL + `",
"aud": "my-client-id",
"sub": "foo",
"email": "foo@example.com",
"email_verified": true,
"nonce": "nonce"
}`
m["id_token"] = oidctest.SignIDToken(o.Priv, "test-key", oidc.RS256, rawClaims)
return tkn.WithExtra(m), nil
}
func (o *OAuth2Mock) Client(ctx context.Context, t *oauth2.Token) *http.Client {
return nil
}
func TestUserOIDCCallback(t *testing.T) {
app := newTestApplication(t)
ts := newTestServer(t, app.routes())
priv := newTestKey(t)
srv := newTestOIDCServer(t, priv)
defer srv.Close()
defer ts.Close()
ctx := context.Background()
p, err := oidc.NewProvider(ctx, srv.URL)
if err != nil {
t.Fatal(err)
}
cfg := &oidc.Config{
ClientID: "my-client-id",
SkipExpiryCheck: true,
}
v := p.VerifierContext(ctx, cfg)
app.config.oauth = applicationOauthConfig{
ctx: context.Background(),
oidcConfig: cfg,
config: &OAuth2Mock{
Srv: srv,
Priv: priv,
},
provider: p,
verifier: v,
}
app.config.oauthEnabled = true
const (
validSubject = "goodSubject"
validUserId = 1
validEmail = "test@example.com"
validState = "goodState"
)
tests := []struct {
name string
subject string
email string
state string
wantCode int
}{
{
name: "Found Subject",
subject: validSubject,
email: validEmail,
state: validState,
wantCode: http.StatusSeeOther,
},
}
for _, tt := range tests {
t.Run(tt.name, func(*testing.T) {
r, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
r.URL.Path = "/users/login/oidc/callback"
q := r.URL.Query()
q.Add("state", tt.state)
r.URL.RawQuery = q.Encode()
c := &http.Cookie{
Name: "state",
Value: validState,
MaxAge: int(time.Hour.Seconds()),
Secure: r.TLS != nil,
HttpOnly: true,
}
d := &http.Cookie{
Name: "nonce",
Value: "nonce",
MaxAge: int(time.Hour.Seconds()),
Secure: r.TLS != nil,
HttpOnly: true,
}
r.AddCookie(c)
r.AddCookie(d)
resp, err := ts.Client().Do(r)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, resp.StatusCode, tt.wantCode)
})
}
}

View File

@ -16,6 +16,7 @@ import (
"time" "time"
"unicode" "unicode"
"git.32bit.cafe/32bitcafe/guestbook/internal/auth"
"git.32bit.cafe/32bitcafe/guestbook/internal/models" "git.32bit.cafe/32bitcafe/guestbook/internal/models"
"github.com/alexedwards/scs/sqlite3store" "github.com/alexedwards/scs/sqlite3store"
"github.com/alexedwards/scs/v2" "github.com/alexedwards/scs/v2"
@ -28,9 +29,9 @@ import (
type applicationOauthConfig struct { type applicationOauthConfig struct {
ctx context.Context ctx context.Context
config oauth2.Config
provider *oidc.Provider
oidcConfig *oidc.Config oidcConfig *oidc.Config
config auth.OAuth2ConfigInterface
provider *oidc.Provider
verifier *oidc.IDTokenVerifier verifier *oidc.IDTokenVerifier
} }
@ -204,7 +205,7 @@ func setupConfig(addr string) (applicationConfig, error) {
ClientID: clientID, ClientID: clientID,
} }
o.verifier = provider.Verifier(o.oidcConfig) o.verifier = provider.Verifier(o.oidcConfig)
o.config = oauth2.Config{ o.config = &oauth2.Config{
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
Endpoint: provider.Endpoint(), Endpoint: provider.Endpoint(),

View File

@ -2,6 +2,8 @@ package main
import ( import (
"bytes" "bytes"
"crypto/rand"
"crypto/rsa"
"html" "html"
"io" "io"
"log/slog" "log/slog"
@ -15,6 +17,8 @@ import (
"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks" "git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
"github.com/alexedwards/scs/v2" "github.com/alexedwards/scs/v2"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-oidc/v3/oidc/oidctest"
"github.com/gorilla/schema" "github.com/gorilla/schema"
) )
@ -34,9 +38,35 @@ func newTestApplication(t *testing.T) *application {
guestbookComments: &mocks.GuestbookCommentModel{}, guestbookComments: &mocks.GuestbookCommentModel{},
formDecoder: formDecoder, formDecoder: formDecoder,
timezones: getAvailableTimezones(), timezones: getAvailableTimezones(),
config: applicationConfig{
localAuthEnabled: true,
},
} }
} }
func newTestKey(t *testing.T) *rsa.PrivateKey {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
return priv
}
func newTestOIDCServer(t *testing.T, priv *rsa.PrivateKey) *testServer {
s := &oidctest.Server{
PublicKeys: []oidctest.PublicKey{
{
PublicKey: priv.Public(),
KeyID: "test-key",
Algorithm: oidc.ES256,
},
},
}
ts := httptest.NewServer(s)
s.SetIssuer(ts.URL)
return &testServer{ts}
}
type testServer struct { type testServer struct {
*httptest.Server *httptest.Server
} }

14
internal/auth/auth.go Normal file
View File

@ -0,0 +1,14 @@
package auth
import (
"context"
"net/http"
"golang.org/x/oauth2"
)
type OAuth2ConfigInterface interface {
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
Client(ctx context.Context, t *oauth2.Token) *http.Client
}

View File

@ -1,6 +1,7 @@
package mocks package mocks
import ( import (
"errors"
"time" "time"
"git.32bit.cafe/32bitcafe/guestbook/internal/models" "git.32bit.cafe/32bitcafe/guestbook/internal/models"
@ -76,13 +77,6 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
return 0, models.ErrInvalidCredentials return 0, models.ErrInvalidCredentials
} }
func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) {
if email == "test@example.com" {
return 1, nil
}
return 0, models.ErrInvalidCredentials
}
func (m *UserModel) Exists(id int64) (bool, error) { func (m *UserModel) Exists(id int64) (bool, error) {
switch id { switch id {
case 1: case 1:
@ -103,3 +97,24 @@ func (m *UserModel) UpdateUserSettings(userId int64, settings models.UserSetting
func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error { func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error {
return nil return nil
} }
func (m *UserModel) GetBySubject(subject string) (int64, error) {
if subject == "goodSubject" {
return 1, nil
}
return -1, models.ErrNoRecord
}
func (m *UserModel) GetByEmail(email string) (int64, error) {
if email == "test@example.com" {
return 1, nil
}
return -1, models.ErrNoRecord
}
func (m *UserModel) UpdateSubject(userId int64, subject string) error {
if userId == 1 {
return nil
}
return errors.New("invalid")
}

View File

@ -49,12 +49,14 @@ type UserModelInterface interface {
InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error) InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
Get(shortId uint64) (User, error) Get(shortId uint64) (User, error)
GetById(id int64) (User, error) GetById(id int64) (User, error)
GetByEmail(email string) (int64, error)
GetBySubject(subject string) (int64, error)
GetAll() ([]User, error) GetAll() ([]User, error)
Authenticate(email, password string) (int64, error) Authenticate(email, password string) (int64, error)
AuthenticateByOIDC(email, subject string) (int64, error)
Exists(id int64) (bool, error) Exists(id int64) (bool, error)
UpdateUserSettings(userId int64, settings UserSettings) error UpdateUserSettings(userId int64, settings UserSettings) error
UpdateSetting(userId int64, setting Setting, value string) error UpdateSetting(userId int64, setting Setting, value string) error
UpdateSubject(userId int64, subject string) error
} }
func (m *UserModel) InitializeSettingsMap() error { func (m *UserModel) InitializeSettingsMap() error {
@ -292,51 +294,44 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
return id, nil return id, nil
} }
func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) { func (m *UserModel) GetByEmail(email string) (int64, error) {
var id int64 var id int64
var s sql.NullString stmt := `SELECT Id FROM users WHERE Email = ?`
tx, err := m.DB.Begin() err := m.DB.QueryRow(stmt, email).Scan(&id)
if err != nil {
return -1, err
}
stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?`
err = tx.QueryRow(stmt, email, subject).Scan(&id, &s)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return -1, err
}
return -1, ErrNoRecord return -1, ErrNoRecord
} else { } else {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return -1, err
}
return -1, err return -1, err
} }
} }
if !s.Valid {
stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
_, err = tx.Exec(stmt, subject, id)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return -1, err
}
return -1, err
}
} else if subject != s.String {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return -1, ErrInvalidCredentials
}
}
err = tx.Commit()
if err != nil {
return -1, err
}
return id, nil return id, nil
} }
func (m *UserModel) GetBySubject(subject string) (int64, error) {
var id int64
var s sql.NullString
stmt := `SELECT Id, OIDCSubject FROM users WHERE OIDCSubject = ?`
err := m.DB.QueryRow(stmt, subject).Scan(&id, &s)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return -1, ErrNoRecord
} else {
return -1, err
}
}
return id, nil
}
func (m *UserModel) UpdateSubject(userId int64, subject string) error {
stmt := `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
_, err := m.DB.Exec(stmt, subject, userId)
if err != nil {
return err
}
return nil
}
func (m *UserModel) Exists(id int64) (bool, error) { func (m *UserModel) Exists(id int64) (bool, error) {
var exists bool var exists bool
stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)` stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`