diff --git a/cmd/web/handlers_guestbook_test.go b/cmd/web/handlers_guestbook_test.go index 182f081..12870c8 100644 --- a/cmd/web/handlers_guestbook_test.go +++ b/cmd/web/handlers_guestbook_test.go @@ -1,9 +1,12 @@ package main import ( + "bytes" "fmt" + "io" "net/http" "net/url" + "strings" "testing" "git.32bit.cafe/32bitcafe/guestbook/internal/assert" @@ -150,9 +153,6 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { ts := newTestServer(t, app.routes()) defer ts.Close() - _, _, body := ts.get(t, fmt.Sprintf("/websites/%s/guestbook", shortIdToSlug(1))) - validCSRFToken := extractCSRFToken(t, body) - const ( validAuthorName = "John Test" validAuthorEmail = "test@example.com" @@ -166,8 +166,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { authorEmail string authorSite string content string - csrfToken string wantCode int + wantBody string }{ { name: "Valid input", @@ -175,8 +175,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { authorEmail: validAuthorEmail, authorSite: validAuthorSite, content: validContent, - csrfToken: validCSRFToken, - wantCode: http.StatusSeeOther, + wantCode: http.StatusOK, + wantBody: "Comment successfully posted", }, { name: "Blank name", @@ -184,8 +184,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { authorEmail: validAuthorEmail, authorSite: validAuthorSite, content: validContent, - csrfToken: validCSRFToken, - wantCode: http.StatusUnprocessableEntity, + wantCode: http.StatusOK, + wantBody: "An error occurred", }, { name: "Blank email", @@ -193,8 +193,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { authorEmail: "", authorSite: validAuthorSite, content: validContent, - csrfToken: validCSRFToken, - wantCode: http.StatusSeeOther, + wantCode: http.StatusOK, + wantBody: "Comment successfully posted", }, { name: "Blank site", @@ -202,8 +202,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { authorEmail: validAuthorEmail, authorSite: "", content: validContent, - csrfToken: validCSRFToken, - wantCode: http.StatusSeeOther, + wantCode: http.StatusOK, + wantBody: "Comment successfully posted", }, { name: "Blank content", @@ -211,21 +211,39 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) { authorEmail: validAuthorEmail, authorSite: validAuthorSite, content: "", - csrfToken: validCSRFToken, - wantCode: http.StatusUnprocessableEntity, + wantCode: http.StatusOK, + wantBody: "An error occurred", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + form := url.Values{} form.Add("authorname", tt.authorName) form.Add("authoremail", tt.authorEmail) form.Add("authorsite", tt.authorSite) form.Add("content", tt.content) - form.Add("csrf_token", tt.csrfToken) - code, _, body := ts.postForm(t, fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1)), form) - assert.Equal(t, code, tt.wantCode) - assert.Equal(t, body, body) + r, err := http.NewRequest("POST", ts.URL, strings.NewReader(form.Encode())) + if err != nil { + t.Fatal(err) + } + r.URL.Path = fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.Header.Set("Origin", "http://example.com") + + resp, err := ts.Client().Do(r) + if err != nil { + t.Fatal(err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + body = bytes.TrimSpace(body) + assert.Equal(t, resp.StatusCode, tt.wantCode) + assert.StringContains(t, string(body), tt.wantBody) }) } } diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go index 141d459..132f099 100644 --- a/cmd/web/handlers_user.go +++ b/cmd/web/handlers_user.go @@ -173,13 +173,35 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req app.serverError(w, r, err) return } - id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject) - if err != nil { + + // search for user by subject + id, err := app.users.GetBySubject(t.Subject) + if err != nil && errors.Is(err, models.ErrNoRecord) { + // if no user is found, check if they have signed up by email already + id, err = app.users.GetByEmail(t.Email) + if err == nil { + // if user is found by email, update subject to match them in the first step next time + err2 := app.users.UpdateSubject(id, t.Subject) + if err2 != nil { + app.serverError(w, r, err2) + return + } + } + } else if err != nil { + app.serverError(w, r, err) + return + } + if err != nil && errors.Is(err, models.ErrNoRecord) { + // if no user is found by subject or email, create a new user id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings()) if err != nil { app.serverError(w, r, err) } + } else if err != nil { + app.serverError(w, r, err) + return } + app.sessionManager.Put(r.Context(), "authenticatedUserId", id) http.Redirect(w, r, "/", http.StatusSeeOther) } diff --git a/cmd/web/handlers_user_test.go b/cmd/web/handlers_user_test.go index 8a97224..32adfe9 100644 --- a/cmd/web/handlers_user_test.go +++ b/cmd/web/handlers_user_test.go @@ -1,11 +1,17 @@ package main import ( + "context" + "crypto/rsa" "net/http" "net/url" "testing" + "time" "git.32bit.cafe/32bitcafe/guestbook/internal/assert" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/coreos/go-oidc/v3/oidc/oidctest" + "golang.org/x/oauth2" ) func TestUserSignup(t *testing.T) { @@ -119,3 +125,126 @@ func TestUserSignup(t *testing.T) { } } + +type OAuth2Mock struct { + Srv *testServer + Priv *rsa.PrivateKey +} + +func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + return "" +} + +func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + tkn := oauth2.Token{ + AccessToken: "AccessToken", + Expiry: time.Now().Add(1 * time.Hour), + } + m := make(map[string]interface{}) + var rawClaims = `{ + "iss": "` + o.Srv.URL + `", + "aud": "my-client-id", + "sub": "foo", + "email": "foo@example.com", + "email_verified": true, + "nonce": "nonce" + }` + + m["id_token"] = oidctest.SignIDToken(o.Priv, "test-key", oidc.RS256, rawClaims) + return tkn.WithExtra(m), nil +} + +func (o *OAuth2Mock) Client(ctx context.Context, t *oauth2.Token) *http.Client { + return nil +} + +func TestUserOIDCCallback(t *testing.T) { + app := newTestApplication(t) + ts := newTestServer(t, app.routes()) + + priv := newTestKey(t) + srv := newTestOIDCServer(t, priv) + + defer srv.Close() + defer ts.Close() + ctx := context.Background() + + p, err := oidc.NewProvider(ctx, srv.URL) + if err != nil { + t.Fatal(err) + } + cfg := &oidc.Config{ + ClientID: "my-client-id", + SkipExpiryCheck: true, + } + v := p.VerifierContext(ctx, cfg) + app.config.oauth = applicationOauthConfig{ + ctx: context.Background(), + oidcConfig: cfg, + config: &OAuth2Mock{ + Srv: srv, + Priv: priv, + }, + provider: p, + verifier: v, + } + app.config.oauthEnabled = true + + const ( + validSubject = "goodSubject" + validUserId = 1 + validEmail = "test@example.com" + validState = "goodState" + ) + + tests := []struct { + name string + subject string + email string + state string + wantCode int + }{ + { + name: "Found Subject", + subject: validSubject, + email: validEmail, + state: validState, + wantCode: http.StatusSeeOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(*testing.T) { + r, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + r.URL.Path = "/users/login/oidc/callback" + q := r.URL.Query() + q.Add("state", tt.state) + r.URL.RawQuery = q.Encode() + c := &http.Cookie{ + Name: "state", + Value: validState, + MaxAge: int(time.Hour.Seconds()), + Secure: r.TLS != nil, + HttpOnly: true, + } + d := &http.Cookie{ + Name: "nonce", + Value: "nonce", + MaxAge: int(time.Hour.Seconds()), + Secure: r.TLS != nil, + HttpOnly: true, + } + r.AddCookie(c) + r.AddCookie(d) + resp, err := ts.Client().Do(r) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, resp.StatusCode, tt.wantCode) + }) + } + +} diff --git a/cmd/web/main.go b/cmd/web/main.go index b764069..b15619b 100644 --- a/cmd/web/main.go +++ b/cmd/web/main.go @@ -16,6 +16,7 @@ import ( "time" "unicode" + "git.32bit.cafe/32bitcafe/guestbook/internal/auth" "git.32bit.cafe/32bitcafe/guestbook/internal/models" "github.com/alexedwards/scs/sqlite3store" "github.com/alexedwards/scs/v2" @@ -28,9 +29,9 @@ import ( type applicationOauthConfig struct { ctx context.Context - config oauth2.Config - provider *oidc.Provider oidcConfig *oidc.Config + config auth.OAuth2ConfigInterface + provider *oidc.Provider verifier *oidc.IDTokenVerifier } @@ -204,7 +205,7 @@ func setupConfig(addr string) (applicationConfig, error) { ClientID: clientID, } o.verifier = provider.Verifier(o.oidcConfig) - o.config = oauth2.Config{ + o.config = &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, Endpoint: provider.Endpoint(), diff --git a/cmd/web/testutils_test.go b/cmd/web/testutils_test.go index b5267b6..afab66a 100644 --- a/cmd/web/testutils_test.go +++ b/cmd/web/testutils_test.go @@ -2,6 +2,8 @@ package main import ( "bytes" + "crypto/rand" + "crypto/rsa" "html" "io" "log/slog" @@ -15,6 +17,8 @@ import ( "git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks" "github.com/alexedwards/scs/v2" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/coreos/go-oidc/v3/oidc/oidctest" "github.com/gorilla/schema" ) @@ -34,9 +38,35 @@ func newTestApplication(t *testing.T) *application { guestbookComments: &mocks.GuestbookCommentModel{}, formDecoder: formDecoder, timezones: getAvailableTimezones(), + config: applicationConfig{ + localAuthEnabled: true, + }, } } +func newTestKey(t *testing.T) *rsa.PrivateKey { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + return priv +} + +func newTestOIDCServer(t *testing.T, priv *rsa.PrivateKey) *testServer { + s := &oidctest.Server{ + PublicKeys: []oidctest.PublicKey{ + { + PublicKey: priv.Public(), + KeyID: "test-key", + Algorithm: oidc.ES256, + }, + }, + } + ts := httptest.NewServer(s) + s.SetIssuer(ts.URL) + return &testServer{ts} +} + type testServer struct { *httptest.Server } diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..9704cf4 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,14 @@ +package auth + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" +) + +type OAuth2ConfigInterface interface { + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + Client(ctx context.Context, t *oauth2.Token) *http.Client +} diff --git a/internal/models/mocks/users.go b/internal/models/mocks/user.go similarity index 84% rename from internal/models/mocks/users.go rename to internal/models/mocks/user.go index e0ab8a8..931e841 100644 --- a/internal/models/mocks/users.go +++ b/internal/models/mocks/user.go @@ -1,6 +1,7 @@ package mocks import ( + "errors" "time" "git.32bit.cafe/32bitcafe/guestbook/internal/models" @@ -76,13 +77,6 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) { return 0, models.ErrInvalidCredentials } -func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) { - if email == "test@example.com" { - return 1, nil - } - return 0, models.ErrInvalidCredentials -} - func (m *UserModel) Exists(id int64) (bool, error) { switch id { case 1: @@ -103,3 +97,24 @@ func (m *UserModel) UpdateUserSettings(userId int64, settings models.UserSetting func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error { return nil } + +func (m *UserModel) GetBySubject(subject string) (int64, error) { + if subject == "goodSubject" { + return 1, nil + } + return -1, models.ErrNoRecord +} + +func (m *UserModel) GetByEmail(email string) (int64, error) { + if email == "test@example.com" { + return 1, nil + } + return -1, models.ErrNoRecord +} + +func (m *UserModel) UpdateSubject(userId int64, subject string) error { + if userId == 1 { + return nil + } + return errors.New("invalid") +} diff --git a/internal/models/user.go b/internal/models/user.go index 12a86e0..13eacc5 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -49,12 +49,14 @@ type UserModelInterface interface { InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error) Get(shortId uint64) (User, error) GetById(id int64) (User, error) + GetByEmail(email string) (int64, error) + GetBySubject(subject string) (int64, error) GetAll() ([]User, error) Authenticate(email, password string) (int64, error) - AuthenticateByOIDC(email, subject string) (int64, error) Exists(id int64) (bool, error) UpdateUserSettings(userId int64, settings UserSettings) error UpdateSetting(userId int64, setting Setting, value string) error + UpdateSubject(userId int64, subject string) error } func (m *UserModel) InitializeSettingsMap() error { @@ -292,51 +294,44 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) { return id, nil } -func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) { +func (m *UserModel) GetByEmail(email string) (int64, error) { var id int64 - var s sql.NullString - tx, err := m.DB.Begin() - if err != nil { - return -1, err - } - stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?` - err = tx.QueryRow(stmt, email, subject).Scan(&id, &s) + stmt := `SELECT Id FROM users WHERE Email = ?` + err := m.DB.QueryRow(stmt, email).Scan(&id) if err != nil { if errors.Is(err, sql.ErrNoRows) { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return -1, err - } return -1, ErrNoRecord } else { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return -1, err - } return -1, err } } - - if !s.Valid { - stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?` - _, err = tx.Exec(stmt, subject, id) - if err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return -1, err - } - return -1, err - } - } else if subject != s.String { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return -1, ErrInvalidCredentials - } - } - - err = tx.Commit() - if err != nil { - return -1, err - } return id, nil } +func (m *UserModel) GetBySubject(subject string) (int64, error) { + var id int64 + var s sql.NullString + stmt := `SELECT Id, OIDCSubject FROM users WHERE OIDCSubject = ?` + err := m.DB.QueryRow(stmt, subject).Scan(&id, &s) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return -1, ErrNoRecord + } else { + return -1, err + } + } + return id, nil +} + +func (m *UserModel) UpdateSubject(userId int64, subject string) error { + stmt := `UPDATE users SET OIDCSubject = ? WHERE Id = ?` + _, err := m.DB.Exec(stmt, subject, userId) + if err != nil { + return err + } + return nil +} + func (m *UserModel) Exists(id int64) (bool, error) { var exists bool stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`