From 58e6f35585928aacca8bf9789fea4c16c800bd56 Mon Sep 17 00:00:00 2001
From: yequari 
Date: Sun, 29 Jun 2025 17:19:35 -0700
Subject: [PATCH 1/5] implement OIDC login flow
---
 .gitignore                                    |  5 +-
 cmd/web/handlers_user.go                      | 83 ++++++++++++++++
 cmd/web/helpers.go                            | 22 +++++
 cmd/web/main.go                               | 57 ++++++++++-
 cmd/web/routes.go                             |  2 +
 go.mod                                        |  5 +
 go.sum                                        | 21 +++-
 internal/models/mocks/users.go                | 16 +++
 internal/models/user.go                       | 98 +++++++++++++++++++
 .../000006_change_password_field.down.sql     |  4 +
 .../000006_change_password_field.up.sql       |  4 +
 migrations/000007_add_oidc_subject.down.sql   |  1 +
 migrations/000007_add_oidc_subject.up.sql     |  1 +
 ui/views/common.templ                         |  1 +
 ui/views/common_templ.go                      |  8 +-
 ui/views/users.templ                          |  3 +-
 ui/views/users_templ.go                       | 28 +++---
 17 files changed, 333 insertions(+), 26 deletions(-)
 create mode 100644 migrations/000006_change_password_field.down.sql
 create mode 100644 migrations/000006_change_password_field.up.sql
 create mode 100644 migrations/000007_add_oidc_subject.down.sql
 create mode 100644 migrations/000007_add_oidc_subject.up.sql
diff --git a/.gitignore b/.gitignore
index d803d48..319a781 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,4 +31,7 @@ tls/
 test.db.old
 .gitignore
 .nvim/session
-*templ.txt
\ No newline at end of file
+*templ.txt
+
+# env files
+.env*
diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go
index 8690065..3c21fde 100644
--- a/cmd/web/handlers_user.go
+++ b/cmd/web/handlers_user.go
@@ -9,6 +9,7 @@ import (
 	"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 	"git.32bit.cafe/32bitcafe/guestbook/internal/validator"
 	"git.32bit.cafe/32bitcafe/guestbook/ui/views"
+	"github.com/coreos/go-oidc/v3/oidc"
 )
 
 func (app *application) getUserRegister(w http.ResponseWriter, r *http.Request) {
@@ -92,6 +93,88 @@ func (app *application) postUserLogin(w http.ResponseWriter, r *http.Request) {
 	http.Redirect(w, r, "/", http.StatusSeeOther)
 }
 
+func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
+	state, err := randString(16)
+	if err != nil {
+		app.serverError(w, r, err)
+		return
+	}
+	nonce, err := randString(16)
+	if err != nil {
+		app.serverError(w, r, err)
+		return
+	}
+
+	setCallbackCookie(w, r, "state", state)
+	setCallbackCookie(w, r, "nonce", nonce)
+
+	http.Redirect(w, r, app.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
+}
+
+func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Request) {
+	state, err := r.Cookie("state")
+	if err != nil {
+		app.clientError(w, http.StatusBadRequest)
+		return
+	}
+	if r.URL.Query().Get("state") != state.Value {
+		app.clientError(w, http.StatusBadRequest)
+		return
+	}
+
+	oauth2Token, err := app.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
+	if err != nil {
+		app.logger.Error("Failed to exchange token")
+		app.serverError(w, r, err)
+		return
+	}
+	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
+	if !ok {
+		app.serverError(w, r, errors.New("No id_token field in oauth2 token"))
+		return
+	}
+	idToken, err := app.oauth.verifier.Verify(r.Context(), rawIDToken)
+	if err != nil {
+		app.logger.Error("Failed to verify ID token")
+		app.serverError(w, r, err)
+		return
+	}
+
+	nonce, err := r.Cookie("nonce")
+	if err != nil {
+		app.logger.Error("nonce not found")
+		app.serverError(w, r, err)
+		return
+	}
+	if idToken.Nonce != nonce.Value {
+		app.serverError(w, r, errors.New("nonce did not match"))
+		return
+	}
+
+	oauth2Token.AccessToken = "*REDACTED*"
+
+	var t models.UserIdToken
+	if err := idToken.Claims(&t); err != nil {
+		app.serverError(w, r, err)
+		return
+	}
+
+	err = app.sessionManager.RenewToken(r.Context())
+	if err != nil {
+		app.serverError(w, r, err)
+		return
+	}
+	id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject)
+	if err != nil {
+		id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
+		if err != nil {
+			app.serverError(w, r, err)
+		}
+	}
+	app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
+	http.Redirect(w, r, "/", http.StatusSeeOther)
+}
+
 func (app *application) postUserLogout(w http.ResponseWriter, r *http.Request) {
 	err := app.sessionManager.RenewToken(r.Context())
 	if err != nil {
diff --git a/cmd/web/helpers.go b/cmd/web/helpers.go
index f86b71b..239aea2 100644
--- a/cmd/web/helpers.go
+++ b/cmd/web/helpers.go
@@ -1,8 +1,11 @@
 package main
 
 import (
+	"crypto/rand"
+	"encoding/base64"
 	"errors"
 	"fmt"
+	"io"
 	"math"
 	"net/http"
 	"net/url"
@@ -140,3 +143,22 @@ func matchOrigin(origin string, u *url.URL) bool {
 	}
 	return true
 }
+
+func randString(nByte int) (string, error) {
+	b := make([]byte, nByte)
+	if _, err := io.ReadFull(rand.Reader, b); err != nil {
+		return "", err
+	}
+	return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value string) {
+	c := &http.Cookie{
+		Name:     name,
+		Value:    value,
+		MaxAge:   int(time.Hour.Seconds()),
+		Secure:   r.TLS != nil,
+		HttpOnly: true,
+	}
+	http.SetCookie(w, c)
+}
diff --git a/cmd/web/main.go b/cmd/web/main.go
index 5fe6dac..e999e92 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -1,9 +1,12 @@
 package main
 
 import (
+	"context"
 	"crypto/tls"
 	"database/sql"
+	"errors"
 	"flag"
+	"fmt"
 	"log/slog"
 	"net/http"
 	"os"
@@ -14,10 +17,21 @@ import (
 	"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 	"github.com/alexedwards/scs/sqlite3store"
 	"github.com/alexedwards/scs/v2"
+	"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/gorilla/schema"
+	"github.com/joho/godotenv"
 	_ "github.com/mattn/go-sqlite3"
+	"golang.org/x/oauth2"
 )
 
+type applicationOauthConfig struct {
+	ctx        context.Context
+	config     oauth2.Config
+	provider   *oidc.Provider
+	oidcConfig *oidc.Config
+	verifier   *oidc.IDTokenVerifier
+}
+
 type application struct {
 	sequence          uint16
 	logger            *slog.Logger
@@ -26,6 +40,7 @@ type application struct {
 	guestbookComments models.GuestbookCommentModelInterface
 	sessionManager    *scs.SessionManager
 	formDecoder       *schema.Decoder
+	oauth             applicationOauthConfig
 	debug             bool
 	timezones         []string
 	rootUrl           string
@@ -35,10 +50,11 @@ func main() {
 	addr := flag.String("addr", ":3000", "HTTP network address")
 	dsn := flag.String("dsn", "guestbook.db", "data source name")
 	debug := flag.Bool("debug", false, "enable debug mode")
-	root := flag.String("root", "localhost:3000", "root URL of application")
+	root := flag.String("root", "https://localhost:3000", "root URL of application")
 	flag.Parse()
 
 	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
+	godotenv.Load(".env.dev")
 
 	db, err := openDB(*dsn)
 	if err != nil {
@@ -67,6 +83,13 @@ func main() {
 		rootUrl:           *root,
 	}
 
+	o, err := setupOauth(app.rootUrl)
+	if err != nil {
+		logger.Error(err.Error())
+		os.Exit(1)
+	}
+	app.oauth = o
+
 	err = app.users.InitializeSettingsMap()
 	if err != nil {
 		logger.Error(err.Error())
@@ -114,6 +137,38 @@ func openDB(dsn string) (*sql.DB, error) {
 	return db, nil
 }
 
+func setupOauth(rootUrl string) (applicationOauthConfig, error) {
+	var c applicationOauthConfig
+	var (
+		oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
+		clientID       = os.Getenv("OAUTH2_CLIENT_ID")
+		clientSecret   = os.Getenv("OAUTH2_CLIENT_SECRET")
+	)
+	if oauth2Provider == "" || clientID == "" || clientSecret == "" {
+		return applicationOauthConfig{}, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+	}
+
+	c.ctx = context.Background()
+	provider, err := oidc.NewProvider(c.ctx, oauth2Provider)
+	if err != nil {
+		return applicationOauthConfig{}, err
+	}
+	c.provider = provider
+	c.oidcConfig = &oidc.Config{
+		ClientID: clientID,
+	}
+	c.verifier = provider.Verifier(c.oidcConfig)
+	c.config = oauth2.Config{
+		ClientID:     clientID,
+		ClientSecret: clientSecret,
+		Endpoint:     provider.Endpoint(),
+		RedirectURL:  fmt.Sprintf("%s/users/login/oidc/callback", rootUrl),
+		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
+	}
+	return c, nil
+
+}
+
 func getAvailableTimezones() []string {
 	var zones []string
 	var zoneDirs = []string{
diff --git a/cmd/web/routes.go b/cmd/web/routes.go
index f9e768e..4f32a1b 100644
--- a/cmd/web/routes.go
+++ b/cmd/web/routes.go
@@ -27,6 +27,8 @@ func (app *application) routes() http.Handler {
 	mux.Handle("POST /users/register", dynamic.ThenFunc(app.postUserRegister))
 	mux.Handle("GET /users/login", dynamic.ThenFunc(app.getUserLogin))
 	mux.Handle("POST /users/login", dynamic.ThenFunc(app.postUserLogin))
+	mux.Handle("/users/login/oidc", dynamic.ThenFunc(app.userLoginOIDC))
+	mux.Handle("/users/login/oidc/callback", dynamic.ThenFunc(app.userLoginOIDCCallback))
 	mux.Handle("GET /help", dynamic.ThenFunc(app.notImplemented))
 
 	protected := dynamic.Append(app.requireAuthentication)
diff --git a/go.mod b/go.mod
index 46407e0..dd2b0fc 100644
--- a/go.mod
+++ b/go.mod
@@ -6,9 +6,14 @@ require (
 	github.com/a-h/templ v0.3.833
 	github.com/alexedwards/scs/sqlite3store v0.0.0-20250212122300-421ef1d8611c
 	github.com/alexedwards/scs/v2 v2.8.0
+	github.com/coreos/go-oidc/v3 v3.14.1
 	github.com/gorilla/schema v1.4.1
+	github.com/joho/godotenv v1.5.1
 	github.com/justinas/alice v1.2.0
 	github.com/justinas/nosurf v1.1.1
 	github.com/mattn/go-sqlite3 v1.14.24
 	golang.org/x/crypto v0.36.0
+	golang.org/x/oauth2 v0.30.0
 )
+
+require github.com/go-jose/go-jose/v4 v4.0.5 // indirect
diff --git a/go.sum b/go.sum
index 19a290f..f2165b3 100644
--- a/go.sum
+++ b/go.sum
@@ -1,24 +1,35 @@
 github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU=
 github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk=
-github.com/alexedwards/scs/sqlite3store v0.0.0-20240316134038-7e11d57e8885 h1:+DCxWg/ojncqS+TGAuRUoV7OfG/S4doh0pcpAwEcow0=
-github.com/alexedwards/scs/sqlite3store v0.0.0-20240316134038-7e11d57e8885/go.mod h1:Iyk7S76cxGaiEX/mSYmTZzYehp4KfyylcLaV3OnToss=
 github.com/alexedwards/scs/sqlite3store v0.0.0-20250212122300-421ef1d8611c h1:0gBCIsmH3+aaWK55APhhY7/Z+uv5IdbMqekI97V9shU=
 github.com/alexedwards/scs/sqlite3store v0.0.0-20250212122300-421ef1d8611c/go.mod h1:Iyk7S76cxGaiEX/mSYmTZzYehp4KfyylcLaV3OnToss=
 github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw=
 github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
+github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
+github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
+github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
 github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
 github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
 github.com/justinas/nosurf v1.1.1 h1:92Aw44hjSK4MxJeMSyDa7jwuI9GR2J/JCQiaKvXXSlk=
 github.com/justinas/nosurf v1.1.1/go.mod h1:ALpWdSbuNGy2lZWtyXdjkYv4edL23oSEgfBT1gPJ5BQ=
-github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
 github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
 github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
 github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
-golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
 golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/models/mocks/users.go b/internal/models/mocks/users.go
index 2e5ff45..e0ab8a8 100644
--- a/internal/models/mocks/users.go
+++ b/internal/models/mocks/users.go
@@ -38,6 +38,15 @@ func (m *UserModel) Insert(shortId uint64, username string, email string, passwo
 	}
 }
 
+func (m *UserModel) InsertWithoutPassword(shortId uint64, username string, email string, password string, settings models.UserSettings) (int64, error) {
+	switch email {
+	case "dupe@example.com":
+		return -1, models.ErrDuplicateEmail
+	default:
+		return 2, nil
+	}
+}
+
 func (m *UserModel) Get(shortId uint64) (models.User, error) {
 	switch shortId {
 	case 1:
@@ -67,6 +76,13 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
 	return 0, models.ErrInvalidCredentials
 }
 
+func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) {
+	if email == "test@example.com" {
+		return 1, nil
+	}
+	return 0, models.ErrInvalidCredentials
+}
+
 func (m *UserModel) Exists(id int64) (bool, error) {
 	switch id {
 	case 1:
diff --git a/internal/models/user.go b/internal/models/user.go
index f7c5ab0..12a86e0 100644
--- a/internal/models/user.go
+++ b/internal/models/user.go
@@ -35,13 +35,23 @@ type UserModel struct {
 	Settings map[string]Setting
 }
 
+type UserIdToken struct {
+	Subject       string   `json:"sub"`
+	Email         string   `json:"email"`
+	EmailVerified bool     `json:"email_verified"`
+	Username      string   `json:"preferred_username"`
+	Groups        []string `json:"groups"`
+}
+
 type UserModelInterface interface {
 	InitializeSettingsMap() error
 	Insert(shortId uint64, username string, email string, password string, settings UserSettings) error
+	InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
 	Get(shortId uint64) (User, error)
 	GetById(id int64) (User, error)
 	GetAll() ([]User, error)
 	Authenticate(email, password string) (int64, error)
+	AuthenticateByOIDC(email, subject string) (int64, error)
 	Exists(id int64) (bool, error)
 	UpdateUserSettings(userId int64, settings UserSettings) error
 	UpdateSetting(userId int64, setting Setting, value string) error
@@ -126,6 +136,49 @@ func (m *UserModel) Insert(shortId uint64, username string, email string, passwo
 	return nil
 }
 
+func (m *UserModel) InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error) {
+	stmt := `INSERT INTO users (ShortId, Username, Email, IsBanned, OIDCSubject, Created)
+    VALUES (?, ?, ?, FALSE, ?, ?)`
+	tx, err := m.DB.Begin()
+	if err != nil {
+		return -1, err
+	}
+	result, err := tx.Exec(stmt, shortId, username, email, subject, time.Now().UTC())
+	if err != nil {
+		if rollbackErr := tx.Rollback(); rollbackErr != nil {
+			return -1, err
+		}
+		if sqliteError, ok := err.(sqlite3.Error); ok {
+			if sqliteError.ExtendedCode == 2067 && strings.Contains(sqliteError.Error(), "Email") {
+				return -1, ErrDuplicateEmail
+			}
+		}
+		return -1, err
+	}
+	id, err := result.LastInsertId()
+	if err != nil {
+		if rollbackErr := tx.Rollback(); rollbackErr != nil {
+			return -1, err
+		}
+		return -1, err
+	}
+	stmt = `INSERT INTO user_settings (UserId, SettingId, AllowedSettingValueId, UnconstrainedValue) 
+		VALUES (?, ?, ?, ?)`
+	_, err = tx.Exec(stmt, id, m.Settings[SettingUserTimezone].id, nil, settings.LocalTimezone.String())
+	if err != nil {
+		if rollbackErr := tx.Rollback(); rollbackErr != nil {
+			return -1, err
+		}
+		return -1, err
+	}
+	err = tx.Commit()
+	if err != nil {
+		return -1, err
+	}
+
+	return id, nil
+}
+
 func (m *UserModel) Get(shortId uint64) (User, error) {
 	stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE ShortId = ? AND Deleted IS NULL`
 	tx, err := m.DB.Begin()
@@ -239,6 +292,51 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
 	return id, nil
 }
 
+func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) {
+	var id int64
+	var s sql.NullString
+	tx, err := m.DB.Begin()
+	if err != nil {
+		return -1, err
+	}
+	stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?`
+	err = tx.QueryRow(stmt, email, subject).Scan(&id, &s)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			if rollbackErr := tx.Rollback(); rollbackErr != nil {
+				return -1, err
+			}
+			return -1, ErrNoRecord
+		} else {
+			if rollbackErr := tx.Rollback(); rollbackErr != nil {
+				return -1, err
+			}
+			return -1, err
+		}
+	}
+
+	if !s.Valid {
+		stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
+		_, err = tx.Exec(stmt, subject, id)
+		if err != nil {
+			if rollbackErr := tx.Rollback(); rollbackErr != nil {
+				return -1, err
+			}
+			return -1, err
+		}
+	} else if subject != s.String {
+		if rollbackErr := tx.Rollback(); rollbackErr != nil {
+			return -1, ErrInvalidCredentials
+		}
+	}
+
+	err = tx.Commit()
+	if err != nil {
+		return -1, err
+	}
+	return id, nil
+}
+
 func (m *UserModel) Exists(id int64) (bool, error) {
 	var exists bool
 	stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
diff --git a/migrations/000006_change_password_field.down.sql b/migrations/000006_change_password_field.down.sql
new file mode 100644
index 0000000..794b222
--- /dev/null
+++ b/migrations/000006_change_password_field.down.sql
@@ -0,0 +1,4 @@
+ALTER TABLE users RENAME COLUMN HashedPassword TO HashedPasswordOld;
+ALTER TABLE users ADD COLUMN HashedPassword char(60) NOT NULL DEFAULT '0000';
+UPDATE users SET HashedPassword=HashedPasswordOld;
+ALTER TABLE users DROP COLUMN HashedPasswordOld;
diff --git a/migrations/000006_change_password_field.up.sql b/migrations/000006_change_password_field.up.sql
new file mode 100644
index 0000000..e80c883
--- /dev/null
+++ b/migrations/000006_change_password_field.up.sql
@@ -0,0 +1,4 @@
+ALTER TABLE users RENAME COLUMN HashedPassword TO HashedPasswordOld;
+ALTER TABLE users ADD COLUMN HashedPassword char(60) NULL;
+UPDATE users SET HashedPassword=HashedPasswordOld;
+ALTER TABLE users DROP COLUMN HashedPasswordOld;
diff --git a/migrations/000007_add_oidc_subject.down.sql b/migrations/000007_add_oidc_subject.down.sql
new file mode 100644
index 0000000..4a6a7e9
--- /dev/null
+++ b/migrations/000007_add_oidc_subject.down.sql
@@ -0,0 +1 @@
+ALTER TABLE users DROP COLUMN OIDCSubject;
diff --git a/migrations/000007_add_oidc_subject.up.sql b/migrations/000007_add_oidc_subject.up.sql
new file mode 100644
index 0000000..6edefa4
--- /dev/null
+++ b/migrations/000007_add_oidc_subject.up.sql
@@ -0,0 +1 @@
+ALTER TABLE users ADD COLUMN OIDCSubject varchar(255);
diff --git a/ui/views/common.templ b/ui/views/common.templ
index a9111f1..5df9f50 100644
--- a/ui/views/common.templ
+++ b/ui/views/common.templ
@@ -12,6 +12,7 @@ type CommonData struct {
 	CSRFToken       string
 	CurrentUser     *models.User
 	IsHtmx          bool
+	RootUrl         string
 }
 
 func shortIdToSlug(shortId uint64) string {
diff --git a/ui/views/common_templ.go b/ui/views/common_templ.go
index e79039c..6a7d8aa 100644
--- a/ui/views/common_templ.go
+++ b/ui/views/common_templ.go
@@ -102,7 +102,7 @@ func topNav(data CommonData) templ.Component {
 			var templ_7745c5c3_Var3 string
 			templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(data.CurrentUser.Username)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 44, Col: 40}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 45, Col: 40}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
 			if templ_7745c5c3_Err != nil {
@@ -121,7 +121,7 @@ func topNav(data CommonData) templ.Component {
 			var templ_7745c5c3_Var4 string
 			templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(hxHeaders)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 52, Col: 62}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 53, Col: 62}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
 			if templ_7745c5c3_Err != nil {
@@ -202,7 +202,7 @@ func base(title string, data CommonData) templ.Component {
 		var templ_7745c5c3_Var7 string
 		templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(title)
 		if templ_7745c5c3_Err != nil {
-			return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 71, Col: 17}
+			return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 72, Col: 17}
 		}
 		_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
 		if templ_7745c5c3_Err != nil {
@@ -232,7 +232,7 @@ func base(title string, data CommonData) templ.Component {
 			var templ_7745c5c3_Var8 string
 			templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(data.Flash)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 83, Col: 36}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 84, Col: 36}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
 			if templ_7745c5c3_Err != nil {
diff --git a/ui/views/users.templ b/ui/views/users.templ
index 4fb839b..a0d60b3 100644
--- a/ui/views/users.templ
+++ b/ui/views/users.templ
@@ -30,7 +30,8 @@ templ UserLogin(title string, data CommonData, form forms.UserLoginForm) {
 				
 			
 			
 		
 	}
diff --git a/ui/views/users_templ.go b/ui/views/users_templ.go
index c791e2d..391f030 100644
--- a/ui/views/users_templ.go
+++ b/ui/views/users_templ.go
@@ -143,7 +143,7 @@ func UserLogin(title string, data CommonData, form forms.UserLoginForm) templ.Co
 					return templ_7745c5c3_Err
 				}
 			}
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
@@ -197,7 +197,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 			var templ_7745c5c3_Var10 string
 			templ_7745c5c3_Var10, templ_7745c5c3_Err = templ.JoinStringErrs(data.CSRFToken)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 43, Col: 64}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 44, Col: 64}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var10))
 			if templ_7745c5c3_Err != nil {
@@ -220,7 +220,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 				var templ_7745c5c3_Var11 string
 				templ_7745c5c3_Var11, templ_7745c5c3_Err = templ.JoinStringErrs(error)
 				if templ_7745c5c3_Err != nil {
-					return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 48, Col: 33}
+					return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 49, Col: 33}
 				}
 				_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var11))
 				if templ_7745c5c3_Err != nil {
@@ -238,7 +238,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 			var templ_7745c5c3_Var12 string
 			templ_7745c5c3_Var12, templ_7745c5c3_Err = templ.JoinStringErrs(form.Name)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 50, Col: 70}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 51, Col: 70}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var12))
 			if templ_7745c5c3_Err != nil {
@@ -261,7 +261,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 				var templ_7745c5c3_Var13 string
 				templ_7745c5c3_Var13, templ_7745c5c3_Err = templ.JoinStringErrs(error)
 				if templ_7745c5c3_Err != nil {
-					return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 56, Col: 33}
+					return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 57, Col: 33}
 				}
 				_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var13))
 				if templ_7745c5c3_Err != nil {
@@ -279,7 +279,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 			var templ_7745c5c3_Var14 string
 			templ_7745c5c3_Var14, templ_7745c5c3_Err = templ.JoinStringErrs(form.Email)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 58, Col: 65}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 59, Col: 65}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var14))
 			if templ_7745c5c3_Err != nil {
@@ -302,7 +302,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 				var templ_7745c5c3_Var15 string
 				templ_7745c5c3_Var15, templ_7745c5c3_Err = templ.JoinStringErrs(error)
 				if templ_7745c5c3_Err != nil {
-					return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 64, Col: 33}
+					return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 65, Col: 33}
 				}
 				_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var15))
 				if templ_7745c5c3_Err != nil {
@@ -367,7 +367,7 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
 			var templ_7745c5c3_Var18 string
 			templ_7745c5c3_Var18, templ_7745c5c3_Err = templ.JoinStringErrs(user.Username)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 77, Col: 21}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 21}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var18))
 			if templ_7745c5c3_Err != nil {
@@ -380,7 +380,7 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
 			var templ_7745c5c3_Var19 string
 			templ_7745c5c3_Var19, templ_7745c5c3_Err = templ.JoinStringErrs(user.Email)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 17}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 79, Col: 17}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var19))
 			if templ_7745c5c3_Err != nil {
@@ -441,7 +441,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
 			var templ_7745c5c3_Var22 string
 			templ_7745c5c3_Var22, templ_7745c5c3_Err = templ.JoinStringErrs(data.CSRFToken)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 88, Col: 65}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 89, Col: 65}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var22))
 			if templ_7745c5c3_Err != nil {
@@ -460,7 +460,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
 					var templ_7745c5c3_Var23 string
 					templ_7745c5c3_Var23, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
 					if templ_7745c5c3_Err != nil {
-						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 93, Col: 25}
+						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 94, Col: 25}
 					}
 					_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var23))
 					if templ_7745c5c3_Err != nil {
@@ -473,7 +473,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
 					var templ_7745c5c3_Var24 string
 					templ_7745c5c3_Var24, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
 					if templ_7745c5c3_Err != nil {
-						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 93, Col: 48}
+						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 94, Col: 48}
 					}
 					_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var24))
 					if templ_7745c5c3_Err != nil {
@@ -491,7 +491,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
 					var templ_7745c5c3_Var25 string
 					templ_7745c5c3_Var25, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
 					if templ_7745c5c3_Err != nil {
-						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 95, Col: 25}
+						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 96, Col: 25}
 					}
 					_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var25))
 					if templ_7745c5c3_Err != nil {
@@ -504,7 +504,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
 					var templ_7745c5c3_Var26 string
 					templ_7745c5c3_Var26, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
 					if templ_7745c5c3_Err != nil {
-						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 95, Col: 32}
+						return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 96, Col: 32}
 					}
 					_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var26))
 					if templ_7745c5c3_Err != nil {
-- 
2.30.2
From c56a445c6a0641485591bb0ed59d97f3cf9ccadb Mon Sep 17 00:00:00 2001
From: yequari 
Date: Sun, 29 Jun 2025 18:12:32 -0700
Subject: [PATCH 2/5] add app configuration
---
 cmd/web/handlers_user.go | 15 +++++--
 cmd/web/helpers.go       | 16 ++++---
 cmd/web/main.go          | 96 +++++++++++++++++++++++++++++-----------
 ui/views/common.templ    | 20 +++++----
 ui/views/common_templ.go | 50 ++++++++++++---------
 ui/views/users.templ     |  4 +-
 ui/views/users_templ.go  | 94 +++++++++++++++++++++------------------
 7 files changed, 186 insertions(+), 109 deletions(-)
diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go
index 3c21fde..141d459 100644
--- a/cmd/web/handlers_user.go
+++ b/cmd/web/handlers_user.go
@@ -13,12 +13,18 @@ import (
 )
 
 func (app *application) getUserRegister(w http.ResponseWriter, r *http.Request) {
+	if !app.config.localAuthEnabled {
+		http.Redirect(w, r, "/users/login/oidc", http.StatusFound)
+	}
 	form := forms.UserRegistrationForm{}
 	data := app.newCommonData(r)
 	views.UserRegistration("User Registration", data, form).Render(r.Context(), w)
 }
 
 func (app *application) getUserLogin(w http.ResponseWriter, r *http.Request) {
+	if !app.config.localAuthEnabled {
+		http.Redirect(w, r, "/users/login/oidc", http.StatusFound)
+	}
 	views.UserLogin("Login", app.newCommonData(r), forms.UserLoginForm{}).Render(r.Context(), w)
 }
 
@@ -94,6 +100,9 @@ func (app *application) postUserLogin(w http.ResponseWriter, r *http.Request) {
 }
 
 func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
+	if !app.config.oauthEnabled {
+		http.Redirect(w, r, "/users/login", http.StatusFound)
+	}
 	state, err := randString(16)
 	if err != nil {
 		app.serverError(w, r, err)
@@ -108,7 +117,7 @@ func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
 	setCallbackCookie(w, r, "state", state)
 	setCallbackCookie(w, r, "nonce", nonce)
 
-	http.Redirect(w, r, app.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
+	http.Redirect(w, r, app.config.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
 }
 
 func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Request) {
@@ -122,7 +131,7 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
 		return
 	}
 
-	oauth2Token, err := app.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
+	oauth2Token, err := app.config.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
 	if err != nil {
 		app.logger.Error("Failed to exchange token")
 		app.serverError(w, r, err)
@@ -133,7 +142,7 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
 		app.serverError(w, r, errors.New("No id_token field in oauth2 token"))
 		return
 	}
-	idToken, err := app.oauth.verifier.Verify(r.Context(), rawIDToken)
+	idToken, err := app.config.oauth.verifier.Verify(r.Context(), rawIDToken)
 	if err != nil {
 		app.logger.Error("Failed to verify ID token")
 		app.serverError(w, r, err)
diff --git a/cmd/web/helpers.go b/cmd/web/helpers.go
index 239aea2..8e66c3c 100644
--- a/cmd/web/helpers.go
+++ b/cmd/web/helpers.go
@@ -107,13 +107,15 @@ func (app *application) getCurrentUser(r *http.Request) *models.User {
 
 func (app *application) newCommonData(r *http.Request) views.CommonData {
 	return views.CommonData{
-		CurrentYear:     time.Now().Year(),
-		Flash:           app.sessionManager.PopString(r.Context(), "flash"),
-		IsAuthenticated: app.isAuthenticated(r),
-		CSRFToken:       nosurf.Token(r),
-		CurrentUser:     app.getCurrentUser(r),
-		IsHtmx:          r.Header.Get("Hx-Request") == "true",
-		RootUrl:         app.rootUrl,
+		CurrentYear:      time.Now().Year(),
+		Flash:            app.sessionManager.PopString(r.Context(), "flash"),
+		IsAuthenticated:  app.isAuthenticated(r),
+		CSRFToken:        nosurf.Token(r),
+		CurrentUser:      app.getCurrentUser(r),
+		IsHtmx:           r.Header.Get("Hx-Request") == "true",
+		RootUrl:          app.config.rootUrl,
+		LocalAuthEnabled: app.config.localAuthEnabled,
+		OIDCEnabled:      app.config.oauthEnabled,
 	}
 }
 
diff --git a/cmd/web/main.go b/cmd/web/main.go
index e999e92..7df1d2b 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -9,7 +9,9 @@ import (
 	"fmt"
 	"log/slog"
 	"net/http"
+	"net/url"
 	"os"
+	"strconv"
 	"strings"
 	"time"
 	"unicode"
@@ -32,6 +34,13 @@ type applicationOauthConfig struct {
 	verifier   *oidc.IDTokenVerifier
 }
 
+type applicationConfig struct {
+	oauthEnabled     bool
+	localAuthEnabled bool
+	oauth            applicationOauthConfig
+	rootUrl          string
+}
+
 type application struct {
 	sequence          uint16
 	logger            *slog.Logger
@@ -40,21 +49,24 @@ type application struct {
 	guestbookComments models.GuestbookCommentModelInterface
 	sessionManager    *scs.SessionManager
 	formDecoder       *schema.Decoder
-	oauth             applicationOauthConfig
+	config            applicationConfig
 	debug             bool
 	timezones         []string
-	rootUrl           string
 }
 
 func main() {
 	addr := flag.String("addr", ":3000", "HTTP network address")
 	dsn := flag.String("dsn", "guestbook.db", "data source name")
 	debug := flag.Bool("debug", false, "enable debug mode")
-	root := flag.String("root", "https://localhost:3000", "root URL of application")
 	flag.Parse()
 
 	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
 	godotenv.Load(".env.dev")
+	cfg, err := setupConfig(*addr)
+	if err != nil {
+		logger.Error(err.Error())
+		os.Exit(1)
+	}
 
 	db, err := openDB(*dsn)
 	if err != nil {
@@ -78,18 +90,11 @@ func main() {
 		users:             &models.UserModel{DB: db, Settings: make(map[string]models.Setting)},
 		guestbookComments: &models.GuestbookCommentModel{DB: db},
 		formDecoder:       formDecoder,
+		config:            cfg,
 		debug:             *debug,
 		timezones:         getAvailableTimezones(),
-		rootUrl:           *root,
 	}
 
-	o, err := setupOauth(app.rootUrl)
-	if err != nil {
-		logger.Error(err.Error())
-		os.Exit(1)
-	}
-	app.oauth = o
-
 	err = app.users.InitializeSettingsMap()
 	if err != nil {
 		logger.Error(err.Error())
@@ -137,36 +142,73 @@ func openDB(dsn string) (*sql.DB, error) {
 	return db, nil
 }
 
-func setupOauth(rootUrl string) (applicationOauthConfig, error) {
-	var c applicationOauthConfig
+func setupConfig(addr string) (applicationConfig, error) {
+	var c applicationConfig
+
 	var (
-		oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
-		clientID       = os.Getenv("OAUTH2_CLIENT_ID")
-		clientSecret   = os.Getenv("OAUTH2_CLIENT_SECRET")
+		rootUrl           = os.Getenv("ROOT_URL")
+		oidcEnabled       = os.Getenv("ENABLE_OIDC")
+		localLoginEnabled = os.Getenv("ENABLE_LOCAL_LOGIN")
+		oauth2Provider    = os.Getenv("OAUTH2_PROVIDER")
+		clientID          = os.Getenv("OAUTH2_CLIENT_ID")
+		clientSecret      = os.Getenv("OAUTH2_CLIENT_SECRET")
 	)
-	if oauth2Provider == "" || clientID == "" || clientSecret == "" {
-		return applicationOauthConfig{}, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+	if rootUrl != "" {
+		c.rootUrl = rootUrl
+	} else {
+		u, err := url.Parse(fmt.Sprintf("https://localhost%s", addr))
+		if err != nil {
+			return c, err
+		}
+		c.rootUrl = u.String()
 	}
 
-	c.ctx = context.Background()
-	provider, err := oidc.NewProvider(c.ctx, oauth2Provider)
+	oauthEnabled, err := strconv.ParseBool(oidcEnabled)
 	if err != nil {
-		return applicationOauthConfig{}, err
+		c.oauthEnabled = false
 	}
-	c.provider = provider
-	c.oidcConfig = &oidc.Config{
+	c.oauthEnabled = oauthEnabled
+
+	localAuthEnabled, err := strconv.ParseBool(localLoginEnabled)
+	if err != nil {
+		c.localAuthEnabled = true
+	}
+	c.localAuthEnabled = localAuthEnabled
+
+	if !c.oauthEnabled && !c.localAuthEnabled {
+		return c, errors.New("Either ENABLE_OIDC or ENABLE_LOCAL_LOGIN must be set to true")
+	}
+
+	// if OIDC is disabled, no more configuration needs to be read
+	if !oauthEnabled {
+		return c, nil
+	}
+
+	var o applicationOauthConfig
+	if oauth2Provider == "" || clientID == "" || clientSecret == "" {
+		return c, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+	}
+
+	o.ctx = context.Background()
+	provider, err := oidc.NewProvider(o.ctx, oauth2Provider)
+	if err != nil {
+		return c, err
+	}
+	o.provider = provider
+	o.oidcConfig = &oidc.Config{
 		ClientID: clientID,
 	}
-	c.verifier = provider.Verifier(c.oidcConfig)
-	c.config = oauth2.Config{
+	o.verifier = provider.Verifier(o.oidcConfig)
+	o.config = oauth2.Config{
 		ClientID:     clientID,
 		ClientSecret: clientSecret,
 		Endpoint:     provider.Endpoint(),
-		RedirectURL:  fmt.Sprintf("%s/users/login/oidc/callback", rootUrl),
+		RedirectURL:  fmt.Sprintf("%s/users/login/oidc/callback", c.rootUrl),
 		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
 	}
-	return c, nil
 
+	c.oauth = o
+	return c, nil
 }
 
 func getAvailableTimezones() []string {
diff --git a/ui/views/common.templ b/ui/views/common.templ
index 5df9f50..bafe90f 100644
--- a/ui/views/common.templ
+++ b/ui/views/common.templ
@@ -6,13 +6,15 @@ import "fmt"
 import "strings"
 
 type CommonData struct {
-	CurrentYear     int
-	Flash           string
-	IsAuthenticated bool
-	CSRFToken       string
-	CurrentUser     *models.User
-	IsHtmx          bool
-	RootUrl         string
+	CurrentYear      int
+	Flash            string
+	IsAuthenticated  bool
+	CSRFToken        string
+	CurrentUser      *models.User
+	IsHtmx           bool
+	RootUrl          string
+	LocalAuthEnabled bool
+	OIDCEnabled      bool
 }
 
 func shortIdToSlug(shortId uint64) string {
@@ -52,7 +54,9 @@ templ topNav(data CommonData) {
 				Settings | 
 				Logout
 			} else {
-				Create an Account | 
+				if data.LocalAuthEnabled {
+					Create an Account | 
+				}
 				Login
 			}
 		
diff --git a/ui/views/common_templ.go b/ui/views/common_templ.go
index 6a7d8aa..e4131f2 100644
--- a/ui/views/common_templ.go
+++ b/ui/views/common_templ.go
@@ -14,13 +14,15 @@ import "fmt"
 import "strings"
 
 type CommonData struct {
-	CurrentYear     int
-	Flash           string
-	IsAuthenticated bool
-	CSRFToken       string
-	CurrentUser     *models.User
-	IsHtmx          bool
-	RootUrl         string
+	CurrentYear      int
+	Flash            string
+	IsAuthenticated  bool
+	CSRFToken        string
+	CurrentUser      *models.User
+	IsHtmx           bool
+	RootUrl          string
+	LocalAuthEnabled bool
+	OIDCEnabled      bool
 }
 
 func shortIdToSlug(shortId uint64) string {
@@ -102,7 +104,7 @@ func topNav(data CommonData) templ.Component {
 			var templ_7745c5c3_Var3 string
 			templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(data.CurrentUser.Username)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 45, Col: 40}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 47, Col: 40}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
 			if templ_7745c5c3_Err != nil {
@@ -121,7 +123,7 @@ func topNav(data CommonData) templ.Component {
 			var templ_7745c5c3_Var4 string
 			templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(hxHeaders)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 53, Col: 62}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 55, Col: 62}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
 			if templ_7745c5c3_Err != nil {
@@ -132,12 +134,18 @@ func topNav(data CommonData) templ.Component {
 				return templ_7745c5c3_Err
 			}
 		} else {
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "Create an Account |  Login")
+			if data.LocalAuthEnabled {
+				templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "Create an Account | ")
+				if templ_7745c5c3_Err != nil {
+					return templ_7745c5c3_Err
+				}
+			}
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, " Login")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
 		}
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
@@ -166,7 +174,7 @@ func commonFooter() templ.Component {
 			templ_7745c5c3_Var5 = templ.NopComponent
 		}
 		ctx = templ.ClearChildren(ctx)
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
@@ -195,20 +203,20 @@ func base(title string, data CommonData) templ.Component {
 			templ_7745c5c3_Var6 = templ.NopComponent
 		}
 		ctx = templ.ClearChildren(ctx)
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
 		var templ_7745c5c3_Var7 string
 		templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(title)
 		if templ_7745c5c3_Err != nil {
-			return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 72, Col: 17}
+			return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 76, Col: 17}
 		}
 		_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, " - webweav.ing")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, " - webweav.ing")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
@@ -220,25 +228,25 @@ func base(title string, data CommonData) templ.Component {
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
 		if data.Flash != "" {
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "
")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
 			var templ_7745c5c3_Var8 string
 			templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(data.Flash)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 84, Col: 36}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 88, Col: 36}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "
")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "
 ")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
@@ -247,7 +255,7 @@ func base(title string, data CommonData) templ.Component {
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
@@ -255,7 +263,7 @@ func base(title string, data CommonData) templ.Component {
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
-		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "")
+		templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 17, "")
 		if templ_7745c5c3_Err != nil {
 			return templ_7745c5c3_Err
 		}
diff --git a/ui/views/users.templ b/ui/views/users.templ
index a0d60b3..299d282 100644
--- a/ui/views/users.templ
+++ b/ui/views/users.templ
@@ -31,7 +31,9 @@ templ UserLogin(title string, data CommonData, form forms.UserLoginForm) {
 			
 			
 		
 	}
diff --git a/ui/views/users_templ.go b/ui/views/users_templ.go
index 391f030..ec96f47 100644
--- a/ui/views/users_templ.go
+++ b/ui/views/users_templ.go
@@ -143,7 +143,17 @@ func UserLogin(title string, data CommonData, form forms.UserLoginForm) templ.Co
 					return templ_7745c5c3_Err
 				}
 			}
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
@@ -190,130 +200,130 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
 				}()
 			}
 			ctx = templ.InitializeContext(ctx)
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "User Registration
")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
@@ -360,33 +370,33 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
 				}()
 			}
 			ctx = templ.InitializeContext(ctx)
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 29, "")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 31, "")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
 			var templ_7745c5c3_Var18 string
 			templ_7745c5c3_Var18, templ_7745c5c3_Err = templ.JoinStringErrs(user.Username)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 21}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 80, Col: 21}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var18))
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 30, "
")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 32, "
")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
 			var templ_7745c5c3_Var19 string
 			templ_7745c5c3_Var19, templ_7745c5c3_Err = templ.JoinStringErrs(user.Email)
 			if templ_7745c5c3_Err != nil {
-				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 79, Col: 17}
+				return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 81, Col: 17}
 			}
 			_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var19))
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 31, "
")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 33, "
")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
@@ -434,89 +444,89 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
 				}()
 			}
 			ctx = templ.InitializeContext(ctx)
-			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 32, "User Settings
")
+			templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 42, " ")
 			if templ_7745c5c3_Err != nil {
 				return templ_7745c5c3_Err
 			}
-- 
2.30.2
From db1d4e1ad2fe4f38d2762cf0b05403aaaccea2d6 Mon Sep 17 00:00:00 2001
From: yequari 
Date: Sun, 29 Jun 2025 20:16:37 -0700
Subject: [PATCH 3/5] env file config
---
 cmd/web/main.go | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/cmd/web/main.go b/cmd/web/main.go
index 7df1d2b..b764069 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -58,10 +58,15 @@ func main() {
 	addr := flag.String("addr", ":3000", "HTTP network address")
 	dsn := flag.String("dsn", "guestbook.db", "data source name")
 	debug := flag.Bool("debug", false, "enable debug mode")
+	env := flag.String("env", ".env", ".env file path")
 	flag.Parse()
 
 	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
-	godotenv.Load(".env.dev")
+	err := godotenv.Load(*env)
+	if err != nil {
+		logger.Error(err.Error())
+		os.Exit(1)
+	}
 	cfg, err := setupConfig(*addr)
 	if err != nil {
 		logger.Error(err.Error())
-- 
2.30.2
From f6e332b76abcbfb47f123dfdbf055f2d411aa3ba Mon Sep 17 00:00:00 2001
From: yequari 
Date: Thu, 17 Jul 2025 17:02:35 -0700
Subject: [PATCH 4/5] unit testing for oidc and fix test cases for remote
 comments
---
 cmd/web/handlers_guestbook_test.go          |  54 +++++---
 cmd/web/handlers_user.go                    |  26 +++-
 cmd/web/handlers_user_test.go               | 129 ++++++++++++++++++++
 cmd/web/main.go                             |   7 +-
 cmd/web/testutils_test.go                   |  30 +++++
 internal/auth/auth.go                       |  14 +++
 internal/models/mocks/{users.go => user.go} |  29 +++--
 internal/models/user.go                     |  65 +++++-----
 8 files changed, 289 insertions(+), 65 deletions(-)
 create mode 100644 internal/auth/auth.go
 rename internal/models/mocks/{users.go => user.go} (84%)
diff --git a/cmd/web/handlers_guestbook_test.go b/cmd/web/handlers_guestbook_test.go
index 182f081..12870c8 100644
--- a/cmd/web/handlers_guestbook_test.go
+++ b/cmd/web/handlers_guestbook_test.go
@@ -1,9 +1,12 @@
 package main
 
 import (
+	"bytes"
 	"fmt"
+	"io"
 	"net/http"
 	"net/url"
+	"strings"
 	"testing"
 
 	"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
@@ -150,9 +153,6 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 	ts := newTestServer(t, app.routes())
 	defer ts.Close()
 
-	_, _, body := ts.get(t, fmt.Sprintf("/websites/%s/guestbook", shortIdToSlug(1)))
-	validCSRFToken := extractCSRFToken(t, body)
-
 	const (
 		validAuthorName  = "John Test"
 		validAuthorEmail = "test@example.com"
@@ -166,8 +166,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 		authorEmail string
 		authorSite  string
 		content     string
-		csrfToken   string
 		wantCode    int
+		wantBody    string
 	}{
 		{
 			name:        "Valid input",
@@ -175,8 +175,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 			authorEmail: validAuthorEmail,
 			authorSite:  validAuthorSite,
 			content:     validContent,
-			csrfToken:   validCSRFToken,
-			wantCode:    http.StatusSeeOther,
+			wantCode:    http.StatusOK,
+			wantBody:    "Comment successfully posted",
 		},
 		{
 			name:        "Blank name",
@@ -184,8 +184,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 			authorEmail: validAuthorEmail,
 			authorSite:  validAuthorSite,
 			content:     validContent,
-			csrfToken:   validCSRFToken,
-			wantCode:    http.StatusUnprocessableEntity,
+			wantCode:    http.StatusOK,
+			wantBody:    "An error occurred",
 		},
 		{
 			name:        "Blank email",
@@ -193,8 +193,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 			authorEmail: "",
 			authorSite:  validAuthorSite,
 			content:     validContent,
-			csrfToken:   validCSRFToken,
-			wantCode:    http.StatusSeeOther,
+			wantCode:    http.StatusOK,
+			wantBody:    "Comment successfully posted",
 		},
 		{
 			name:        "Blank site",
@@ -202,8 +202,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 			authorEmail: validAuthorEmail,
 			authorSite:  "",
 			content:     validContent,
-			csrfToken:   validCSRFToken,
-			wantCode:    http.StatusSeeOther,
+			wantCode:    http.StatusOK,
+			wantBody:    "Comment successfully posted",
 		},
 		{
 			name:        "Blank content",
@@ -211,21 +211,39 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
 			authorEmail: validAuthorEmail,
 			authorSite:  validAuthorSite,
 			content:     "",
-			csrfToken:   validCSRFToken,
-			wantCode:    http.StatusUnprocessableEntity,
+			wantCode:    http.StatusOK,
+			wantBody:    "An error occurred",
 		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+
 			form := url.Values{}
 			form.Add("authorname", tt.authorName)
 			form.Add("authoremail", tt.authorEmail)
 			form.Add("authorsite", tt.authorSite)
 			form.Add("content", tt.content)
-			form.Add("csrf_token", tt.csrfToken)
-			code, _, body := ts.postForm(t, fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1)), form)
-			assert.Equal(t, code, tt.wantCode)
-			assert.Equal(t, body, body)
+			r, err := http.NewRequest("POST", ts.URL, strings.NewReader(form.Encode()))
+			if err != nil {
+				t.Fatal(err)
+			}
+			r.URL.Path = fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1))
+			r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+			r.Header.Set("Origin", "http://example.com")
+
+			resp, err := ts.Client().Do(r)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			defer resp.Body.Close()
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				t.Fatal(err)
+			}
+			body = bytes.TrimSpace(body)
+			assert.Equal(t, resp.StatusCode, tt.wantCode)
+			assert.StringContains(t, string(body), tt.wantBody)
 		})
 	}
 }
diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go
index 141d459..132f099 100644
--- a/cmd/web/handlers_user.go
+++ b/cmd/web/handlers_user.go
@@ -173,13 +173,35 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
 		app.serverError(w, r, err)
 		return
 	}
-	id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject)
-	if err != nil {
+
+	// search for user by subject
+	id, err := app.users.GetBySubject(t.Subject)
+	if err != nil && errors.Is(err, models.ErrNoRecord) {
+		// if no user is found, check if they have signed up by email already
+		id, err = app.users.GetByEmail(t.Email)
+		if err == nil {
+			// if user is found by email, update subject to match them in the first step next time
+			err2 := app.users.UpdateSubject(id, t.Subject)
+			if err2 != nil {
+				app.serverError(w, r, err2)
+				return
+			}
+		}
+	} else if err != nil {
+		app.serverError(w, r, err)
+		return
+	}
+	if err != nil && errors.Is(err, models.ErrNoRecord) {
+		// if no user is found by subject or email, create a new user
 		id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
 		if err != nil {
 			app.serverError(w, r, err)
 		}
+	} else if err != nil {
+		app.serverError(w, r, err)
+		return
 	}
+
 	app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
 	http.Redirect(w, r, "/", http.StatusSeeOther)
 }
diff --git a/cmd/web/handlers_user_test.go b/cmd/web/handlers_user_test.go
index 8a97224..32adfe9 100644
--- a/cmd/web/handlers_user_test.go
+++ b/cmd/web/handlers_user_test.go
@@ -1,11 +1,17 @@
 package main
 
 import (
+	"context"
+	"crypto/rsa"
 	"net/http"
 	"net/url"
 	"testing"
+	"time"
 
 	"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
+	"github.com/coreos/go-oidc/v3/oidc"
+	"github.com/coreos/go-oidc/v3/oidc/oidctest"
+	"golang.org/x/oauth2"
 )
 
 func TestUserSignup(t *testing.T) {
@@ -119,3 +125,126 @@ func TestUserSignup(t *testing.T) {
 	}
 
 }
+
+type OAuth2Mock struct {
+	Srv  *testServer
+	Priv *rsa.PrivateKey
+}
+
+func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
+	return ""
+}
+
+func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
+	tkn := oauth2.Token{
+		AccessToken: "AccessToken",
+		Expiry:      time.Now().Add(1 * time.Hour),
+	}
+	m := make(map[string]interface{})
+	var rawClaims = `{
+		"iss": "` + o.Srv.URL + `",
+		"aud": "my-client-id",
+		"sub": "foo",
+		"email": "foo@example.com",
+		"email_verified": true,
+		"nonce": "nonce"
+		}`
+
+	m["id_token"] = oidctest.SignIDToken(o.Priv, "test-key", oidc.RS256, rawClaims)
+	return tkn.WithExtra(m), nil
+}
+
+func (o *OAuth2Mock) Client(ctx context.Context, t *oauth2.Token) *http.Client {
+	return nil
+}
+
+func TestUserOIDCCallback(t *testing.T) {
+	app := newTestApplication(t)
+	ts := newTestServer(t, app.routes())
+
+	priv := newTestKey(t)
+	srv := newTestOIDCServer(t, priv)
+
+	defer srv.Close()
+	defer ts.Close()
+	ctx := context.Background()
+
+	p, err := oidc.NewProvider(ctx, srv.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	cfg := &oidc.Config{
+		ClientID:        "my-client-id",
+		SkipExpiryCheck: true,
+	}
+	v := p.VerifierContext(ctx, cfg)
+	app.config.oauth = applicationOauthConfig{
+		ctx:        context.Background(),
+		oidcConfig: cfg,
+		config: &OAuth2Mock{
+			Srv:  srv,
+			Priv: priv,
+		},
+		provider: p,
+		verifier: v,
+	}
+	app.config.oauthEnabled = true
+
+	const (
+		validSubject = "goodSubject"
+		validUserId  = 1
+		validEmail   = "test@example.com"
+		validState   = "goodState"
+	)
+
+	tests := []struct {
+		name     string
+		subject  string
+		email    string
+		state    string
+		wantCode int
+	}{
+		{
+			name:     "Found Subject",
+			subject:  validSubject,
+			email:    validEmail,
+			state:    validState,
+			wantCode: http.StatusSeeOther,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(*testing.T) {
+			r, err := http.NewRequest("GET", ts.URL, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			r.URL.Path = "/users/login/oidc/callback"
+			q := r.URL.Query()
+			q.Add("state", tt.state)
+			r.URL.RawQuery = q.Encode()
+			c := &http.Cookie{
+				Name:     "state",
+				Value:    validState,
+				MaxAge:   int(time.Hour.Seconds()),
+				Secure:   r.TLS != nil,
+				HttpOnly: true,
+			}
+			d := &http.Cookie{
+				Name:     "nonce",
+				Value:    "nonce",
+				MaxAge:   int(time.Hour.Seconds()),
+				Secure:   r.TLS != nil,
+				HttpOnly: true,
+			}
+			r.AddCookie(c)
+			r.AddCookie(d)
+			resp, err := ts.Client().Do(r)
+			if err != nil {
+				t.Fatal(err)
+			}
+			assert.Equal(t, resp.StatusCode, tt.wantCode)
+		})
+	}
+
+}
diff --git a/cmd/web/main.go b/cmd/web/main.go
index b764069..b15619b 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -16,6 +16,7 @@ import (
 	"time"
 	"unicode"
 
+	"git.32bit.cafe/32bitcafe/guestbook/internal/auth"
 	"git.32bit.cafe/32bitcafe/guestbook/internal/models"
 	"github.com/alexedwards/scs/sqlite3store"
 	"github.com/alexedwards/scs/v2"
@@ -28,9 +29,9 @@ import (
 
 type applicationOauthConfig struct {
 	ctx        context.Context
-	config     oauth2.Config
-	provider   *oidc.Provider
 	oidcConfig *oidc.Config
+	config     auth.OAuth2ConfigInterface
+	provider   *oidc.Provider
 	verifier   *oidc.IDTokenVerifier
 }
 
@@ -204,7 +205,7 @@ func setupConfig(addr string) (applicationConfig, error) {
 		ClientID: clientID,
 	}
 	o.verifier = provider.Verifier(o.oidcConfig)
-	o.config = oauth2.Config{
+	o.config = &oauth2.Config{
 		ClientID:     clientID,
 		ClientSecret: clientSecret,
 		Endpoint:     provider.Endpoint(),
diff --git a/cmd/web/testutils_test.go b/cmd/web/testutils_test.go
index b5267b6..afab66a 100644
--- a/cmd/web/testutils_test.go
+++ b/cmd/web/testutils_test.go
@@ -2,6 +2,8 @@ package main
 
 import (
 	"bytes"
+	"crypto/rand"
+	"crypto/rsa"
 	"html"
 	"io"
 	"log/slog"
@@ -15,6 +17,8 @@ import (
 
 	"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
 	"github.com/alexedwards/scs/v2"
+	"github.com/coreos/go-oidc/v3/oidc"
+	"github.com/coreos/go-oidc/v3/oidc/oidctest"
 	"github.com/gorilla/schema"
 )
 
@@ -34,9 +38,35 @@ func newTestApplication(t *testing.T) *application {
 		guestbookComments: &mocks.GuestbookCommentModel{},
 		formDecoder:       formDecoder,
 		timezones:         getAvailableTimezones(),
+		config: applicationConfig{
+			localAuthEnabled: true,
+		},
 	}
 }
 
+func newTestKey(t *testing.T) *rsa.PrivateKey {
+	priv, err := rsa.GenerateKey(rand.Reader, 2048)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return priv
+}
+
+func newTestOIDCServer(t *testing.T, priv *rsa.PrivateKey) *testServer {
+	s := &oidctest.Server{
+		PublicKeys: []oidctest.PublicKey{
+			{
+				PublicKey: priv.Public(),
+				KeyID:     "test-key",
+				Algorithm: oidc.ES256,
+			},
+		},
+	}
+	ts := httptest.NewServer(s)
+	s.SetIssuer(ts.URL)
+	return &testServer{ts}
+}
+
 type testServer struct {
 	*httptest.Server
 }
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
new file mode 100644
index 0000000..9704cf4
--- /dev/null
+++ b/internal/auth/auth.go
@@ -0,0 +1,14 @@
+package auth
+
+import (
+	"context"
+	"net/http"
+
+	"golang.org/x/oauth2"
+)
+
+type OAuth2ConfigInterface interface {
+	AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
+	Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
+	Client(ctx context.Context, t *oauth2.Token) *http.Client
+}
diff --git a/internal/models/mocks/users.go b/internal/models/mocks/user.go
similarity index 84%
rename from internal/models/mocks/users.go
rename to internal/models/mocks/user.go
index e0ab8a8..931e841 100644
--- a/internal/models/mocks/users.go
+++ b/internal/models/mocks/user.go
@@ -1,6 +1,7 @@
 package mocks
 
 import (
+	"errors"
 	"time"
 
 	"git.32bit.cafe/32bitcafe/guestbook/internal/models"
@@ -76,13 +77,6 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
 	return 0, models.ErrInvalidCredentials
 }
 
-func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) {
-	if email == "test@example.com" {
-		return 1, nil
-	}
-	return 0, models.ErrInvalidCredentials
-}
-
 func (m *UserModel) Exists(id int64) (bool, error) {
 	switch id {
 	case 1:
@@ -103,3 +97,24 @@ func (m *UserModel) UpdateUserSettings(userId int64, settings models.UserSetting
 func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error {
 	return nil
 }
+
+func (m *UserModel) GetBySubject(subject string) (int64, error) {
+	if subject == "goodSubject" {
+		return 1, nil
+	}
+	return -1, models.ErrNoRecord
+}
+
+func (m *UserModel) GetByEmail(email string) (int64, error) {
+	if email == "test@example.com" {
+		return 1, nil
+	}
+	return -1, models.ErrNoRecord
+}
+
+func (m *UserModel) UpdateSubject(userId int64, subject string) error {
+	if userId == 1 {
+		return nil
+	}
+	return errors.New("invalid")
+}
diff --git a/internal/models/user.go b/internal/models/user.go
index 12a86e0..13eacc5 100644
--- a/internal/models/user.go
+++ b/internal/models/user.go
@@ -49,12 +49,14 @@ type UserModelInterface interface {
 	InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
 	Get(shortId uint64) (User, error)
 	GetById(id int64) (User, error)
+	GetByEmail(email string) (int64, error)
+	GetBySubject(subject string) (int64, error)
 	GetAll() ([]User, error)
 	Authenticate(email, password string) (int64, error)
-	AuthenticateByOIDC(email, subject string) (int64, error)
 	Exists(id int64) (bool, error)
 	UpdateUserSettings(userId int64, settings UserSettings) error
 	UpdateSetting(userId int64, setting Setting, value string) error
+	UpdateSubject(userId int64, subject string) error
 }
 
 func (m *UserModel) InitializeSettingsMap() error {
@@ -292,51 +294,44 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
 	return id, nil
 }
 
-func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) {
+func (m *UserModel) GetByEmail(email string) (int64, error) {
 	var id int64
-	var s sql.NullString
-	tx, err := m.DB.Begin()
-	if err != nil {
-		return -1, err
-	}
-	stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?`
-	err = tx.QueryRow(stmt, email, subject).Scan(&id, &s)
+	stmt := `SELECT Id FROM users WHERE Email = ?`
+	err := m.DB.QueryRow(stmt, email).Scan(&id)
 	if err != nil {
 		if errors.Is(err, sql.ErrNoRows) {
-			if rollbackErr := tx.Rollback(); rollbackErr != nil {
-				return -1, err
-			}
 			return -1, ErrNoRecord
 		} else {
-			if rollbackErr := tx.Rollback(); rollbackErr != nil {
-				return -1, err
-			}
 			return -1, err
 		}
 	}
-
-	if !s.Valid {
-		stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
-		_, err = tx.Exec(stmt, subject, id)
-		if err != nil {
-			if rollbackErr := tx.Rollback(); rollbackErr != nil {
-				return -1, err
-			}
-			return -1, err
-		}
-	} else if subject != s.String {
-		if rollbackErr := tx.Rollback(); rollbackErr != nil {
-			return -1, ErrInvalidCredentials
-		}
-	}
-
-	err = tx.Commit()
-	if err != nil {
-		return -1, err
-	}
 	return id, nil
 }
 
+func (m *UserModel) GetBySubject(subject string) (int64, error) {
+	var id int64
+	var s sql.NullString
+	stmt := `SELECT Id, OIDCSubject FROM users WHERE OIDCSubject = ?`
+	err := m.DB.QueryRow(stmt, subject).Scan(&id, &s)
+	if err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			return -1, ErrNoRecord
+		} else {
+			return -1, err
+		}
+	}
+	return id, nil
+}
+
+func (m *UserModel) UpdateSubject(userId int64, subject string) error {
+	stmt := `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
+	_, err := m.DB.Exec(stmt, subject, userId)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
 func (m *UserModel) Exists(id int64) (bool, error) {
 	var exists bool
 	stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
-- 
2.30.2
From 2759127cf97ae396ea9d6d9b86be0ac592a19efc Mon Sep 17 00:00:00 2001
From: yequari 
Date: Sat, 19 Jul 2025 15:36:50 -0700
Subject: [PATCH 5/5] add oidc unit tests
---
 cmd/web/handlers_user_test.go | 68 ++++++++++++++++++++++++++---------
 internal/models/mocks/user.go |  4 ++-
 2 files changed, 55 insertions(+), 17 deletions(-)
diff --git a/cmd/web/handlers_user_test.go b/cmd/web/handlers_user_test.go
index 32adfe9..4624a20 100644
--- a/cmd/web/handlers_user_test.go
+++ b/cmd/web/handlers_user_test.go
@@ -127,8 +127,10 @@ func TestUserSignup(t *testing.T) {
 }
 
 type OAuth2Mock struct {
-	Srv  *testServer
-	Priv *rsa.PrivateKey
+	Srv     *testServer
+	Priv    *rsa.PrivateKey
+	Subject string
+	Email   string
 }
 
 func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
@@ -140,12 +142,12 @@ func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.A
 		AccessToken: "AccessToken",
 		Expiry:      time.Now().Add(1 * time.Hour),
 	}
-	m := make(map[string]interface{})
+	m := make(map[string]any)
 	var rawClaims = `{
 		"iss": "` + o.Srv.URL + `",
 		"aud": "my-client-id",
-		"sub": "foo",
-		"email": "foo@example.com",
+		"sub": "` + o.Subject + `",
+		"email": "` + o.Email + `",
 		"email_verified": true,
 		"nonce": "nonce"
 		}`
@@ -178,23 +180,26 @@ func TestUserOIDCCallback(t *testing.T) {
 		SkipExpiryCheck: true,
 	}
 	v := p.VerifierContext(ctx, cfg)
+	oMock := &OAuth2Mock{
+		Srv:     srv,
+		Priv:    priv,
+		Subject: "foo",
+	}
 	app.config.oauth = applicationOauthConfig{
 		ctx:        context.Background(),
 		oidcConfig: cfg,
-		config: &OAuth2Mock{
-			Srv:  srv,
-			Priv: priv,
-		},
-		provider: p,
-		verifier: v,
+		config:     oMock,
+		provider:   p,
+		verifier:   v,
 	}
 	app.config.oauthEnabled = true
 
 	const (
-		validSubject = "goodSubject"
-		validUserId  = 1
-		validEmail   = "test@example.com"
-		validState   = "goodState"
+		validSubject   = "goodSubject"
+		unknownSubject = "foo"
+		validUserId    = 1
+		validEmail     = "test@example.com"
+		validState     = "goodState"
 	)
 
 	tests := []struct {
@@ -205,16 +210,46 @@ func TestUserOIDCCallback(t *testing.T) {
 		wantCode int
 	}{
 		{
-			name:     "Found Subject",
+			name:     "By Subject",
 			subject:  validSubject,
+			email:    "",
+			state:    validState,
+			wantCode: http.StatusSeeOther,
+		},
+		{
+			name:     "By Email",
+			subject:  unknownSubject,
 			email:    validEmail,
 			state:    validState,
 			wantCode: http.StatusSeeOther,
 		},
+		{
+			name:     "No User",
+			subject:  unknownSubject,
+			email:    "",
+			state:    validState,
+			wantCode: http.StatusSeeOther,
+		},
+		{
+			name:     "Invalid State",
+			subject:  unknownSubject,
+			email:    validEmail,
+			state:    "",
+			wantCode: http.StatusInternalServerError,
+		},
+		{
+			name:     "Unknown Subject & Email",
+			subject:  unknownSubject,
+			email:    "",
+			state:    validState,
+			wantCode: http.StatusInternalServerError,
+		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(*testing.T) {
+			oMock.Subject = tt.subject
+			oMock.Email = tt.email
 			r, err := http.NewRequest("GET", ts.URL, nil)
 			if err != nil {
 				t.Fatal(err)
@@ -223,6 +258,7 @@ func TestUserOIDCCallback(t *testing.T) {
 			q := r.URL.Query()
 			q.Add("state", tt.state)
 			r.URL.RawQuery = q.Encode()
+
 			c := &http.Cookie{
 				Name:     "state",
 				Value:    validState,
diff --git a/internal/models/mocks/user.go b/internal/models/mocks/user.go
index 931e841..0ca4d30 100644
--- a/internal/models/mocks/user.go
+++ b/internal/models/mocks/user.go
@@ -101,8 +101,10 @@ func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value st
 func (m *UserModel) GetBySubject(subject string) (int64, error) {
 	if subject == "goodSubject" {
 		return 1, nil
+	} else if subject == "foo" {
+		return -1, models.ErrNoRecord
 	}
-	return -1, models.ErrNoRecord
+	return -1, errors.New("Unexpected Error")
 }
 
 func (m *UserModel) GetByEmail(email string) (int64, error) {
-- 
2.30.2