unit testing for oidc and fix test cases for remote comments
This commit is contained in:
		
							parent
							
								
									db1d4e1ad2
								
							
						
					
					
						commit
						f6e332b76a
					
				@ -1,9 +1,12 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
 | 
						"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
 | 
				
			||||||
@ -150,9 +153,6 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
	ts := newTestServer(t, app.routes())
 | 
						ts := newTestServer(t, app.routes())
 | 
				
			||||||
	defer ts.Close()
 | 
						defer ts.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, _, body := ts.get(t, fmt.Sprintf("/websites/%s/guestbook", shortIdToSlug(1)))
 | 
					 | 
				
			||||||
	validCSRFToken := extractCSRFToken(t, body)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	const (
 | 
						const (
 | 
				
			||||||
		validAuthorName  = "John Test"
 | 
							validAuthorName  = "John Test"
 | 
				
			||||||
		validAuthorEmail = "test@example.com"
 | 
							validAuthorEmail = "test@example.com"
 | 
				
			||||||
@ -166,8 +166,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
		authorEmail string
 | 
							authorEmail string
 | 
				
			||||||
		authorSite  string
 | 
							authorSite  string
 | 
				
			||||||
		content     string
 | 
							content     string
 | 
				
			||||||
		csrfToken   string
 | 
					 | 
				
			||||||
		wantCode    int
 | 
							wantCode    int
 | 
				
			||||||
 | 
							wantBody    string
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:        "Valid input",
 | 
								name:        "Valid input",
 | 
				
			||||||
@ -175,8 +175,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
			authorEmail: validAuthorEmail,
 | 
								authorEmail: validAuthorEmail,
 | 
				
			||||||
			authorSite:  validAuthorSite,
 | 
								authorSite:  validAuthorSite,
 | 
				
			||||||
			content:     validContent,
 | 
								content:     validContent,
 | 
				
			||||||
			csrfToken:   validCSRFToken,
 | 
								wantCode:    http.StatusOK,
 | 
				
			||||||
			wantCode:    http.StatusSeeOther,
 | 
								wantBody:    "Comment successfully posted",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:        "Blank name",
 | 
								name:        "Blank name",
 | 
				
			||||||
@ -184,8 +184,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
			authorEmail: validAuthorEmail,
 | 
								authorEmail: validAuthorEmail,
 | 
				
			||||||
			authorSite:  validAuthorSite,
 | 
								authorSite:  validAuthorSite,
 | 
				
			||||||
			content:     validContent,
 | 
								content:     validContent,
 | 
				
			||||||
			csrfToken:   validCSRFToken,
 | 
								wantCode:    http.StatusOK,
 | 
				
			||||||
			wantCode:    http.StatusUnprocessableEntity,
 | 
								wantBody:    "An error occurred",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:        "Blank email",
 | 
								name:        "Blank email",
 | 
				
			||||||
@ -193,8 +193,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
			authorEmail: "",
 | 
								authorEmail: "",
 | 
				
			||||||
			authorSite:  validAuthorSite,
 | 
								authorSite:  validAuthorSite,
 | 
				
			||||||
			content:     validContent,
 | 
								content:     validContent,
 | 
				
			||||||
			csrfToken:   validCSRFToken,
 | 
								wantCode:    http.StatusOK,
 | 
				
			||||||
			wantCode:    http.StatusSeeOther,
 | 
								wantBody:    "Comment successfully posted",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:        "Blank site",
 | 
								name:        "Blank site",
 | 
				
			||||||
@ -202,8 +202,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
			authorEmail: validAuthorEmail,
 | 
								authorEmail: validAuthorEmail,
 | 
				
			||||||
			authorSite:  "",
 | 
								authorSite:  "",
 | 
				
			||||||
			content:     validContent,
 | 
								content:     validContent,
 | 
				
			||||||
			csrfToken:   validCSRFToken,
 | 
								wantCode:    http.StatusOK,
 | 
				
			||||||
			wantCode:    http.StatusSeeOther,
 | 
								wantBody:    "Comment successfully posted",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:        "Blank content",
 | 
								name:        "Blank content",
 | 
				
			||||||
@ -211,21 +211,39 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 | 
				
			|||||||
			authorEmail: validAuthorEmail,
 | 
								authorEmail: validAuthorEmail,
 | 
				
			||||||
			authorSite:  validAuthorSite,
 | 
								authorSite:  validAuthorSite,
 | 
				
			||||||
			content:     "",
 | 
								content:     "",
 | 
				
			||||||
			csrfToken:   validCSRFToken,
 | 
								wantCode:    http.StatusOK,
 | 
				
			||||||
			wantCode:    http.StatusUnprocessableEntity,
 | 
								wantBody:    "An error occurred",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			form := url.Values{}
 | 
								form := url.Values{}
 | 
				
			||||||
			form.Add("authorname", tt.authorName)
 | 
								form.Add("authorname", tt.authorName)
 | 
				
			||||||
			form.Add("authoremail", tt.authorEmail)
 | 
								form.Add("authoremail", tt.authorEmail)
 | 
				
			||||||
			form.Add("authorsite", tt.authorSite)
 | 
								form.Add("authorsite", tt.authorSite)
 | 
				
			||||||
			form.Add("content", tt.content)
 | 
								form.Add("content", tt.content)
 | 
				
			||||||
			form.Add("csrf_token", tt.csrfToken)
 | 
								r, err := http.NewRequest("POST", ts.URL, strings.NewReader(form.Encode()))
 | 
				
			||||||
			code, _, body := ts.postForm(t, fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1)), form)
 | 
								if err != nil {
 | 
				
			||||||
			assert.Equal(t, code, tt.wantCode)
 | 
									t.Fatal(err)
 | 
				
			||||||
			assert.Equal(t, body, body)
 | 
								}
 | 
				
			||||||
 | 
								r.URL.Path = fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1))
 | 
				
			||||||
 | 
								r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
								r.Header.Set("Origin", "http://example.com")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								resp, err := ts.Client().Do(r)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								defer resp.Body.Close()
 | 
				
			||||||
 | 
								body, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								body = bytes.TrimSpace(body)
 | 
				
			||||||
 | 
								assert.Equal(t, resp.StatusCode, tt.wantCode)
 | 
				
			||||||
 | 
								assert.StringContains(t, string(body), tt.wantBody)
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -173,13 +173,35 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
 | 
				
			|||||||
		app.serverError(w, r, err)
 | 
							app.serverError(w, r, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject)
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						// search for user by subject
 | 
				
			||||||
 | 
						id, err := app.users.GetBySubject(t.Subject)
 | 
				
			||||||
 | 
						if err != nil && errors.Is(err, models.ErrNoRecord) {
 | 
				
			||||||
 | 
							// if no user is found, check if they have signed up by email already
 | 
				
			||||||
 | 
							id, err = app.users.GetByEmail(t.Email)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								// if user is found by email, update subject to match them in the first step next time
 | 
				
			||||||
 | 
								err2 := app.users.UpdateSubject(id, t.Subject)
 | 
				
			||||||
 | 
								if err2 != nil {
 | 
				
			||||||
 | 
									app.serverError(w, r, err2)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if err != nil {
 | 
				
			||||||
 | 
							app.serverError(w, r, err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil && errors.Is(err, models.ErrNoRecord) {
 | 
				
			||||||
 | 
							// if no user is found by subject or email, create a new user
 | 
				
			||||||
		id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
 | 
							id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			app.serverError(w, r, err)
 | 
								app.serverError(w, r, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						} else if err != nil {
 | 
				
			||||||
 | 
							app.serverError(w, r, err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
 | 
						app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
 | 
				
			||||||
	http.Redirect(w, r, "/", http.StatusSeeOther)
 | 
						http.Redirect(w, r, "/", http.StatusSeeOther)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,17 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"crypto/rsa"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
 | 
						"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
 | 
				
			||||||
 | 
						"github.com/coreos/go-oidc/v3/oidc"
 | 
				
			||||||
 | 
						"github.com/coreos/go-oidc/v3/oidc/oidctest"
 | 
				
			||||||
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestUserSignup(t *testing.T) {
 | 
					func TestUserSignup(t *testing.T) {
 | 
				
			||||||
@ -119,3 +125,126 @@ func TestUserSignup(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OAuth2Mock struct {
 | 
				
			||||||
 | 
						Srv  *testServer
 | 
				
			||||||
 | 
						Priv *rsa.PrivateKey
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
 | 
				
			||||||
 | 
						tkn := oauth2.Token{
 | 
				
			||||||
 | 
							AccessToken: "AccessToken",
 | 
				
			||||||
 | 
							Expiry:      time.Now().Add(1 * time.Hour),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						m := make(map[string]interface{})
 | 
				
			||||||
 | 
						var rawClaims = `{
 | 
				
			||||||
 | 
							"iss": "` + o.Srv.URL + `",
 | 
				
			||||||
 | 
							"aud": "my-client-id",
 | 
				
			||||||
 | 
							"sub": "foo",
 | 
				
			||||||
 | 
							"email": "foo@example.com",
 | 
				
			||||||
 | 
							"email_verified": true,
 | 
				
			||||||
 | 
							"nonce": "nonce"
 | 
				
			||||||
 | 
							}`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m["id_token"] = oidctest.SignIDToken(o.Priv, "test-key", oidc.RS256, rawClaims)
 | 
				
			||||||
 | 
						return tkn.WithExtra(m), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (o *OAuth2Mock) Client(ctx context.Context, t *oauth2.Token) *http.Client {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUserOIDCCallback(t *testing.T) {
 | 
				
			||||||
 | 
						app := newTestApplication(t)
 | 
				
			||||||
 | 
						ts := newTestServer(t, app.routes())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						priv := newTestKey(t)
 | 
				
			||||||
 | 
						srv := newTestOIDCServer(t, priv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer srv.Close()
 | 
				
			||||||
 | 
						defer ts.Close()
 | 
				
			||||||
 | 
						ctx := context.Background()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p, err := oidc.NewProvider(ctx, srv.URL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						cfg := &oidc.Config{
 | 
				
			||||||
 | 
							ClientID:        "my-client-id",
 | 
				
			||||||
 | 
							SkipExpiryCheck: true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						v := p.VerifierContext(ctx, cfg)
 | 
				
			||||||
 | 
						app.config.oauth = applicationOauthConfig{
 | 
				
			||||||
 | 
							ctx:        context.Background(),
 | 
				
			||||||
 | 
							oidcConfig: cfg,
 | 
				
			||||||
 | 
							config: &OAuth2Mock{
 | 
				
			||||||
 | 
								Srv:  srv,
 | 
				
			||||||
 | 
								Priv: priv,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							provider: p,
 | 
				
			||||||
 | 
							verifier: v,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						app.config.oauthEnabled = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						const (
 | 
				
			||||||
 | 
							validSubject = "goodSubject"
 | 
				
			||||||
 | 
							validUserId  = 1
 | 
				
			||||||
 | 
							validEmail   = "test@example.com"
 | 
				
			||||||
 | 
							validState   = "goodState"
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name     string
 | 
				
			||||||
 | 
							subject  string
 | 
				
			||||||
 | 
							email    string
 | 
				
			||||||
 | 
							state    string
 | 
				
			||||||
 | 
							wantCode int
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:     "Found Subject",
 | 
				
			||||||
 | 
								subject:  validSubject,
 | 
				
			||||||
 | 
								email:    validEmail,
 | 
				
			||||||
 | 
								state:    validState,
 | 
				
			||||||
 | 
								wantCode: http.StatusSeeOther,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(*testing.T) {
 | 
				
			||||||
 | 
								r, err := http.NewRequest("GET", ts.URL, nil)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								r.URL.Path = "/users/login/oidc/callback"
 | 
				
			||||||
 | 
								q := r.URL.Query()
 | 
				
			||||||
 | 
								q.Add("state", tt.state)
 | 
				
			||||||
 | 
								r.URL.RawQuery = q.Encode()
 | 
				
			||||||
 | 
								c := &http.Cookie{
 | 
				
			||||||
 | 
									Name:     "state",
 | 
				
			||||||
 | 
									Value:    validState,
 | 
				
			||||||
 | 
									MaxAge:   int(time.Hour.Seconds()),
 | 
				
			||||||
 | 
									Secure:   r.TLS != nil,
 | 
				
			||||||
 | 
									HttpOnly: true,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								d := &http.Cookie{
 | 
				
			||||||
 | 
									Name:     "nonce",
 | 
				
			||||||
 | 
									Value:    "nonce",
 | 
				
			||||||
 | 
									MaxAge:   int(time.Hour.Seconds()),
 | 
				
			||||||
 | 
									Secure:   r.TLS != nil,
 | 
				
			||||||
 | 
									HttpOnly: true,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								r.AddCookie(c)
 | 
				
			||||||
 | 
								r.AddCookie(d)
 | 
				
			||||||
 | 
								resp, err := ts.Client().Do(r)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								assert.Equal(t, resp.StatusCode, tt.wantCode)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -16,6 +16,7 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
	"unicode"
 | 
						"unicode"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.32bit.cafe/32bitcafe/guestbook/internal/auth"
 | 
				
			||||||
	"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 | 
						"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 | 
				
			||||||
	"github.com/alexedwards/scs/sqlite3store"
 | 
						"github.com/alexedwards/scs/sqlite3store"
 | 
				
			||||||
	"github.com/alexedwards/scs/v2"
 | 
						"github.com/alexedwards/scs/v2"
 | 
				
			||||||
@ -28,9 +29,9 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type applicationOauthConfig struct {
 | 
					type applicationOauthConfig struct {
 | 
				
			||||||
	ctx        context.Context
 | 
						ctx        context.Context
 | 
				
			||||||
	config     oauth2.Config
 | 
					 | 
				
			||||||
	provider   *oidc.Provider
 | 
					 | 
				
			||||||
	oidcConfig *oidc.Config
 | 
						oidcConfig *oidc.Config
 | 
				
			||||||
 | 
						config     auth.OAuth2ConfigInterface
 | 
				
			||||||
 | 
						provider   *oidc.Provider
 | 
				
			||||||
	verifier   *oidc.IDTokenVerifier
 | 
						verifier   *oidc.IDTokenVerifier
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -204,7 +205,7 @@ func setupConfig(addr string) (applicationConfig, error) {
 | 
				
			|||||||
		ClientID: clientID,
 | 
							ClientID: clientID,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	o.verifier = provider.Verifier(o.oidcConfig)
 | 
						o.verifier = provider.Verifier(o.oidcConfig)
 | 
				
			||||||
	o.config = oauth2.Config{
 | 
						o.config = &oauth2.Config{
 | 
				
			||||||
		ClientID:     clientID,
 | 
							ClientID:     clientID,
 | 
				
			||||||
		ClientSecret: clientSecret,
 | 
							ClientSecret: clientSecret,
 | 
				
			||||||
		Endpoint:     provider.Endpoint(),
 | 
							Endpoint:     provider.Endpoint(),
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,8 @@ package main
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"crypto/rsa"
 | 
				
			||||||
	"html"
 | 
						"html"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"log/slog"
 | 
						"log/slog"
 | 
				
			||||||
@ -15,6 +17,8 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
 | 
						"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
 | 
				
			||||||
	"github.com/alexedwards/scs/v2"
 | 
						"github.com/alexedwards/scs/v2"
 | 
				
			||||||
 | 
						"github.com/coreos/go-oidc/v3/oidc"
 | 
				
			||||||
 | 
						"github.com/coreos/go-oidc/v3/oidc/oidctest"
 | 
				
			||||||
	"github.com/gorilla/schema"
 | 
						"github.com/gorilla/schema"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -34,9 +38,35 @@ func newTestApplication(t *testing.T) *application {
 | 
				
			|||||||
		guestbookComments: &mocks.GuestbookCommentModel{},
 | 
							guestbookComments: &mocks.GuestbookCommentModel{},
 | 
				
			||||||
		formDecoder:       formDecoder,
 | 
							formDecoder:       formDecoder,
 | 
				
			||||||
		timezones:         getAvailableTimezones(),
 | 
							timezones:         getAvailableTimezones(),
 | 
				
			||||||
 | 
							config: applicationConfig{
 | 
				
			||||||
 | 
								localAuthEnabled: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestKey(t *testing.T) *rsa.PrivateKey {
 | 
				
			||||||
 | 
						priv, err := rsa.GenerateKey(rand.Reader, 2048)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return priv
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestOIDCServer(t *testing.T, priv *rsa.PrivateKey) *testServer {
 | 
				
			||||||
 | 
						s := &oidctest.Server{
 | 
				
			||||||
 | 
							PublicKeys: []oidctest.PublicKey{
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									PublicKey: priv.Public(),
 | 
				
			||||||
 | 
									KeyID:     "test-key",
 | 
				
			||||||
 | 
									Algorithm: oidc.ES256,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ts := httptest.NewServer(s)
 | 
				
			||||||
 | 
						s.SetIssuer(ts.URL)
 | 
				
			||||||
 | 
						return &testServer{ts}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type testServer struct {
 | 
					type testServer struct {
 | 
				
			||||||
	*httptest.Server
 | 
						*httptest.Server
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										14
									
								
								internal/auth/auth.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								internal/auth/auth.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					package auth
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OAuth2ConfigInterface interface {
 | 
				
			||||||
 | 
						AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
 | 
				
			||||||
 | 
						Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
 | 
				
			||||||
 | 
						Client(ctx context.Context, t *oauth2.Token) *http.Client
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
package mocks
 | 
					package mocks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 | 
						"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 | 
				
			||||||
@ -76,13 +77,6 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
 | 
				
			|||||||
	return 0, models.ErrInvalidCredentials
 | 
						return 0, models.ErrInvalidCredentials
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) {
 | 
					 | 
				
			||||||
	if email == "test@example.com" {
 | 
					 | 
				
			||||||
		return 1, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return 0, models.ErrInvalidCredentials
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m *UserModel) Exists(id int64) (bool, error) {
 | 
					func (m *UserModel) Exists(id int64) (bool, error) {
 | 
				
			||||||
	switch id {
 | 
						switch id {
 | 
				
			||||||
	case 1:
 | 
						case 1:
 | 
				
			||||||
@ -103,3 +97,24 @@ func (m *UserModel) UpdateUserSettings(userId int64, settings models.UserSetting
 | 
				
			|||||||
func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error {
 | 
					func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error {
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *UserModel) GetBySubject(subject string) (int64, error) {
 | 
				
			||||||
 | 
						if subject == "goodSubject" {
 | 
				
			||||||
 | 
							return 1, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return -1, models.ErrNoRecord
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *UserModel) GetByEmail(email string) (int64, error) {
 | 
				
			||||||
 | 
						if email == "test@example.com" {
 | 
				
			||||||
 | 
							return 1, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return -1, models.ErrNoRecord
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *UserModel) UpdateSubject(userId int64, subject string) error {
 | 
				
			||||||
 | 
						if userId == 1 {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return errors.New("invalid")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -49,12 +49,14 @@ type UserModelInterface interface {
 | 
				
			|||||||
	InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
 | 
						InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
 | 
				
			||||||
	Get(shortId uint64) (User, error)
 | 
						Get(shortId uint64) (User, error)
 | 
				
			||||||
	GetById(id int64) (User, error)
 | 
						GetById(id int64) (User, error)
 | 
				
			||||||
 | 
						GetByEmail(email string) (int64, error)
 | 
				
			||||||
 | 
						GetBySubject(subject string) (int64, error)
 | 
				
			||||||
	GetAll() ([]User, error)
 | 
						GetAll() ([]User, error)
 | 
				
			||||||
	Authenticate(email, password string) (int64, error)
 | 
						Authenticate(email, password string) (int64, error)
 | 
				
			||||||
	AuthenticateByOIDC(email, subject string) (int64, error)
 | 
					 | 
				
			||||||
	Exists(id int64) (bool, error)
 | 
						Exists(id int64) (bool, error)
 | 
				
			||||||
	UpdateUserSettings(userId int64, settings UserSettings) error
 | 
						UpdateUserSettings(userId int64, settings UserSettings) error
 | 
				
			||||||
	UpdateSetting(userId int64, setting Setting, value string) error
 | 
						UpdateSetting(userId int64, setting Setting, value string) error
 | 
				
			||||||
 | 
						UpdateSubject(userId int64, subject string) error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *UserModel) InitializeSettingsMap() error {
 | 
					func (m *UserModel) InitializeSettingsMap() error {
 | 
				
			||||||
@ -292,51 +294,44 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
 | 
				
			|||||||
	return id, nil
 | 
						return id, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) {
 | 
					func (m *UserModel) GetByEmail(email string) (int64, error) {
 | 
				
			||||||
	var id int64
 | 
						var id int64
 | 
				
			||||||
	var s sql.NullString
 | 
						stmt := `SELECT Id FROM users WHERE Email = ?`
 | 
				
			||||||
	tx, err := m.DB.Begin()
 | 
						err := m.DB.QueryRow(stmt, email).Scan(&id)
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return -1, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?`
 | 
					 | 
				
			||||||
	err = tx.QueryRow(stmt, email, subject).Scan(&id, &s)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		if errors.Is(err, sql.ErrNoRows) {
 | 
							if errors.Is(err, sql.ErrNoRows) {
 | 
				
			||||||
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
					 | 
				
			||||||
				return -1, err
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return -1, ErrNoRecord
 | 
								return -1, ErrNoRecord
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
					 | 
				
			||||||
			return -1, err
 | 
								return -1, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
			return -1, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !s.Valid {
 | 
					 | 
				
			||||||
		stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
 | 
					 | 
				
			||||||
		_, err = tx.Exec(stmt, subject, id)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
					 | 
				
			||||||
				return -1, err
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return -1, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else if subject != s.String {
 | 
					 | 
				
			||||||
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
 | 
					 | 
				
			||||||
			return -1, ErrInvalidCredentials
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = tx.Commit()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return -1, err
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return id, nil
 | 
						return id, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *UserModel) GetBySubject(subject string) (int64, error) {
 | 
				
			||||||
 | 
						var id int64
 | 
				
			||||||
 | 
						var s sql.NullString
 | 
				
			||||||
 | 
						stmt := `SELECT Id, OIDCSubject FROM users WHERE OIDCSubject = ?`
 | 
				
			||||||
 | 
						err := m.DB.QueryRow(stmt, subject).Scan(&id, &s)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if errors.Is(err, sql.ErrNoRows) {
 | 
				
			||||||
 | 
								return -1, ErrNoRecord
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return -1, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return id, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *UserModel) UpdateSubject(userId int64, subject string) error {
 | 
				
			||||||
 | 
						stmt := `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
 | 
				
			||||||
 | 
						_, err := m.DB.Exec(stmt, subject, userId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *UserModel) Exists(id int64) (bool, error) {
 | 
					func (m *UserModel) Exists(id int64) (bool, error) {
 | 
				
			||||||
	var exists bool
 | 
						var exists bool
 | 
				
			||||||
	stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
 | 
						stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user