From 58e6f35585928aacca8bf9789fea4c16c800bd56 Mon Sep 17 00:00:00 2001
From: yequari
Date: Sun, 29 Jun 2025 17:19:35 -0700
Subject: [PATCH 1/3] implement OIDC login flow
---
.gitignore | 5 +-
cmd/web/handlers_user.go | 83 ++++++++++++++++
cmd/web/helpers.go | 22 +++++
cmd/web/main.go | 57 ++++++++++-
cmd/web/routes.go | 2 +
go.mod | 5 +
go.sum | 21 +++-
internal/models/mocks/users.go | 16 +++
internal/models/user.go | 98 +++++++++++++++++++
.../000006_change_password_field.down.sql | 4 +
.../000006_change_password_field.up.sql | 4 +
migrations/000007_add_oidc_subject.down.sql | 1 +
migrations/000007_add_oidc_subject.up.sql | 1 +
ui/views/common.templ | 1 +
ui/views/common_templ.go | 8 +-
ui/views/users.templ | 3 +-
ui/views/users_templ.go | 28 +++---
17 files changed, 333 insertions(+), 26 deletions(-)
create mode 100644 migrations/000006_change_password_field.down.sql
create mode 100644 migrations/000006_change_password_field.up.sql
create mode 100644 migrations/000007_add_oidc_subject.down.sql
create mode 100644 migrations/000007_add_oidc_subject.up.sql
diff --git a/.gitignore b/.gitignore
index d803d48..319a781 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,4 +31,7 @@ tls/
test.db.old
.gitignore
.nvim/session
-*templ.txt
\ No newline at end of file
+*templ.txt
+
+# env files
+.env*
diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go
index 8690065..3c21fde 100644
--- a/cmd/web/handlers_user.go
+++ b/cmd/web/handlers_user.go
@@ -9,6 +9,7 @@ import (
"git.32bit.cafe/32bitcafe/guestbook/internal/models"
"git.32bit.cafe/32bitcafe/guestbook/internal/validator"
"git.32bit.cafe/32bitcafe/guestbook/ui/views"
+ "github.com/coreos/go-oidc/v3/oidc"
)
func (app *application) getUserRegister(w http.ResponseWriter, r *http.Request) {
@@ -92,6 +93,88 @@ func (app *application) postUserLogin(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", http.StatusSeeOther)
}
+func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
+ state, err := randString(16)
+ if err != nil {
+ app.serverError(w, r, err)
+ return
+ }
+ nonce, err := randString(16)
+ if err != nil {
+ app.serverError(w, r, err)
+ return
+ }
+
+ setCallbackCookie(w, r, "state", state)
+ setCallbackCookie(w, r, "nonce", nonce)
+
+ http.Redirect(w, r, app.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
+}
+
+func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Request) {
+ state, err := r.Cookie("state")
+ if err != nil {
+ app.clientError(w, http.StatusBadRequest)
+ return
+ }
+ if r.URL.Query().Get("state") != state.Value {
+ app.clientError(w, http.StatusBadRequest)
+ return
+ }
+
+ oauth2Token, err := app.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
+ if err != nil {
+ app.logger.Error("Failed to exchange token")
+ app.serverError(w, r, err)
+ return
+ }
+ rawIDToken, ok := oauth2Token.Extra("id_token").(string)
+ if !ok {
+ app.serverError(w, r, errors.New("No id_token field in oauth2 token"))
+ return
+ }
+ idToken, err := app.oauth.verifier.Verify(r.Context(), rawIDToken)
+ if err != nil {
+ app.logger.Error("Failed to verify ID token")
+ app.serverError(w, r, err)
+ return
+ }
+
+ nonce, err := r.Cookie("nonce")
+ if err != nil {
+ app.logger.Error("nonce not found")
+ app.serverError(w, r, err)
+ return
+ }
+ if idToken.Nonce != nonce.Value {
+ app.serverError(w, r, errors.New("nonce did not match"))
+ return
+ }
+
+ oauth2Token.AccessToken = "*REDACTED*"
+
+ var t models.UserIdToken
+ if err := idToken.Claims(&t); err != nil {
+ app.serverError(w, r, err)
+ return
+ }
+
+ err = app.sessionManager.RenewToken(r.Context())
+ if err != nil {
+ app.serverError(w, r, err)
+ return
+ }
+ id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject)
+ if err != nil {
+ id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
+ if err != nil {
+ app.serverError(w, r, err)
+ }
+ }
+ app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+}
+
func (app *application) postUserLogout(w http.ResponseWriter, r *http.Request) {
err := app.sessionManager.RenewToken(r.Context())
if err != nil {
diff --git a/cmd/web/helpers.go b/cmd/web/helpers.go
index f86b71b..239aea2 100644
--- a/cmd/web/helpers.go
+++ b/cmd/web/helpers.go
@@ -1,8 +1,11 @@
package main
import (
+ "crypto/rand"
+ "encoding/base64"
"errors"
"fmt"
+ "io"
"math"
"net/http"
"net/url"
@@ -140,3 +143,22 @@ func matchOrigin(origin string, u *url.URL) bool {
}
return true
}
+
+func randString(nByte int) (string, error) {
+ b := make([]byte, nByte)
+ if _, err := io.ReadFull(rand.Reader, b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value string) {
+ c := &http.Cookie{
+ Name: name,
+ Value: value,
+ MaxAge: int(time.Hour.Seconds()),
+ Secure: r.TLS != nil,
+ HttpOnly: true,
+ }
+ http.SetCookie(w, c)
+}
diff --git a/cmd/web/main.go b/cmd/web/main.go
index 5fe6dac..e999e92 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -1,9 +1,12 @@
package main
import (
+ "context"
"crypto/tls"
"database/sql"
+ "errors"
"flag"
+ "fmt"
"log/slog"
"net/http"
"os"
@@ -14,10 +17,21 @@ import (
"git.32bit.cafe/32bitcafe/guestbook/internal/models"
"github.com/alexedwards/scs/sqlite3store"
"github.com/alexedwards/scs/v2"
+ "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/schema"
+ "github.com/joho/godotenv"
_ "github.com/mattn/go-sqlite3"
+ "golang.org/x/oauth2"
)
+type applicationOauthConfig struct {
+ ctx context.Context
+ config oauth2.Config
+ provider *oidc.Provider
+ oidcConfig *oidc.Config
+ verifier *oidc.IDTokenVerifier
+}
+
type application struct {
sequence uint16
logger *slog.Logger
@@ -26,6 +40,7 @@ type application struct {
guestbookComments models.GuestbookCommentModelInterface
sessionManager *scs.SessionManager
formDecoder *schema.Decoder
+ oauth applicationOauthConfig
debug bool
timezones []string
rootUrl string
@@ -35,10 +50,11 @@ func main() {
addr := flag.String("addr", ":3000", "HTTP network address")
dsn := flag.String("dsn", "guestbook.db", "data source name")
debug := flag.Bool("debug", false, "enable debug mode")
- root := flag.String("root", "localhost:3000", "root URL of application")
+ root := flag.String("root", "https://localhost:3000", "root URL of application")
flag.Parse()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
+ godotenv.Load(".env.dev")
db, err := openDB(*dsn)
if err != nil {
@@ -67,6 +83,13 @@ func main() {
rootUrl: *root,
}
+ o, err := setupOauth(app.rootUrl)
+ if err != nil {
+ logger.Error(err.Error())
+ os.Exit(1)
+ }
+ app.oauth = o
+
err = app.users.InitializeSettingsMap()
if err != nil {
logger.Error(err.Error())
@@ -114,6 +137,38 @@ func openDB(dsn string) (*sql.DB, error) {
return db, nil
}
+func setupOauth(rootUrl string) (applicationOauthConfig, error) {
+ var c applicationOauthConfig
+ var (
+ oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
+ clientID = os.Getenv("OAUTH2_CLIENT_ID")
+ clientSecret = os.Getenv("OAUTH2_CLIENT_SECRET")
+ )
+ if oauth2Provider == "" || clientID == "" || clientSecret == "" {
+ return applicationOauthConfig{}, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+ }
+
+ c.ctx = context.Background()
+ provider, err := oidc.NewProvider(c.ctx, oauth2Provider)
+ if err != nil {
+ return applicationOauthConfig{}, err
+ }
+ c.provider = provider
+ c.oidcConfig = &oidc.Config{
+ ClientID: clientID,
+ }
+ c.verifier = provider.Verifier(c.oidcConfig)
+ c.config = oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ Endpoint: provider.Endpoint(),
+ RedirectURL: fmt.Sprintf("%s/users/login/oidc/callback", rootUrl),
+ Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
+ }
+ return c, nil
+
+}
+
func getAvailableTimezones() []string {
var zones []string
var zoneDirs = []string{
diff --git a/cmd/web/routes.go b/cmd/web/routes.go
index f9e768e..4f32a1b 100644
--- a/cmd/web/routes.go
+++ b/cmd/web/routes.go
@@ -27,6 +27,8 @@ func (app *application) routes() http.Handler {
mux.Handle("POST /users/register", dynamic.ThenFunc(app.postUserRegister))
mux.Handle("GET /users/login", dynamic.ThenFunc(app.getUserLogin))
mux.Handle("POST /users/login", dynamic.ThenFunc(app.postUserLogin))
+ mux.Handle("/users/login/oidc", dynamic.ThenFunc(app.userLoginOIDC))
+ mux.Handle("/users/login/oidc/callback", dynamic.ThenFunc(app.userLoginOIDCCallback))
mux.Handle("GET /help", dynamic.ThenFunc(app.notImplemented))
protected := dynamic.Append(app.requireAuthentication)
diff --git a/go.mod b/go.mod
index 46407e0..dd2b0fc 100644
--- a/go.mod
+++ b/go.mod
@@ -6,9 +6,14 @@ require (
github.com/a-h/templ v0.3.833
github.com/alexedwards/scs/sqlite3store v0.0.0-20250212122300-421ef1d8611c
github.com/alexedwards/scs/v2 v2.8.0
+ github.com/coreos/go-oidc/v3 v3.14.1
github.com/gorilla/schema v1.4.1
+ github.com/joho/godotenv v1.5.1
github.com/justinas/alice v1.2.0
github.com/justinas/nosurf v1.1.1
github.com/mattn/go-sqlite3 v1.14.24
golang.org/x/crypto v0.36.0
+ golang.org/x/oauth2 v0.30.0
)
+
+require github.com/go-jose/go-jose/v4 v4.0.5 // indirect
diff --git a/go.sum b/go.sum
index 19a290f..f2165b3 100644
--- a/go.sum
+++ b/go.sum
@@ -1,24 +1,35 @@
github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU=
github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk=
-github.com/alexedwards/scs/sqlite3store v0.0.0-20240316134038-7e11d57e8885 h1:+DCxWg/ojncqS+TGAuRUoV7OfG/S4doh0pcpAwEcow0=
-github.com/alexedwards/scs/sqlite3store v0.0.0-20240316134038-7e11d57e8885/go.mod h1:Iyk7S76cxGaiEX/mSYmTZzYehp4KfyylcLaV3OnToss=
github.com/alexedwards/scs/sqlite3store v0.0.0-20250212122300-421ef1d8611c h1:0gBCIsmH3+aaWK55APhhY7/Z+uv5IdbMqekI97V9shU=
github.com/alexedwards/scs/sqlite3store v0.0.0-20250212122300-421ef1d8611c/go.mod h1:Iyk7S76cxGaiEX/mSYmTZzYehp4KfyylcLaV3OnToss=
github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw=
github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
+github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
+github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
+github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
github.com/justinas/nosurf v1.1.1 h1:92Aw44hjSK4MxJeMSyDa7jwuI9GR2J/JCQiaKvXXSlk=
github.com/justinas/nosurf v1.1.1/go.mod h1:ALpWdSbuNGy2lZWtyXdjkYv4edL23oSEgfBT1gPJ5BQ=
-github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
-golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/models/mocks/users.go b/internal/models/mocks/users.go
index 2e5ff45..e0ab8a8 100644
--- a/internal/models/mocks/users.go
+++ b/internal/models/mocks/users.go
@@ -38,6 +38,15 @@ func (m *UserModel) Insert(shortId uint64, username string, email string, passwo
}
}
+func (m *UserModel) InsertWithoutPassword(shortId uint64, username string, email string, password string, settings models.UserSettings) (int64, error) {
+ switch email {
+ case "dupe@example.com":
+ return -1, models.ErrDuplicateEmail
+ default:
+ return 2, nil
+ }
+}
+
func (m *UserModel) Get(shortId uint64) (models.User, error) {
switch shortId {
case 1:
@@ -67,6 +76,13 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
return 0, models.ErrInvalidCredentials
}
+func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) {
+ if email == "test@example.com" {
+ return 1, nil
+ }
+ return 0, models.ErrInvalidCredentials
+}
+
func (m *UserModel) Exists(id int64) (bool, error) {
switch id {
case 1:
diff --git a/internal/models/user.go b/internal/models/user.go
index f7c5ab0..12a86e0 100644
--- a/internal/models/user.go
+++ b/internal/models/user.go
@@ -35,13 +35,23 @@ type UserModel struct {
Settings map[string]Setting
}
+type UserIdToken struct {
+ Subject string `json:"sub"`
+ Email string `json:"email"`
+ EmailVerified bool `json:"email_verified"`
+ Username string `json:"preferred_username"`
+ Groups []string `json:"groups"`
+}
+
type UserModelInterface interface {
InitializeSettingsMap() error
Insert(shortId uint64, username string, email string, password string, settings UserSettings) error
+ InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
Get(shortId uint64) (User, error)
GetById(id int64) (User, error)
GetAll() ([]User, error)
Authenticate(email, password string) (int64, error)
+ AuthenticateByOIDC(email, subject string) (int64, error)
Exists(id int64) (bool, error)
UpdateUserSettings(userId int64, settings UserSettings) error
UpdateSetting(userId int64, setting Setting, value string) error
@@ -126,6 +136,49 @@ func (m *UserModel) Insert(shortId uint64, username string, email string, passwo
return nil
}
+func (m *UserModel) InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error) {
+ stmt := `INSERT INTO users (ShortId, Username, Email, IsBanned, OIDCSubject, Created)
+ VALUES (?, ?, ?, FALSE, ?, ?)`
+ tx, err := m.DB.Begin()
+ if err != nil {
+ return -1, err
+ }
+ result, err := tx.Exec(stmt, shortId, username, email, subject, time.Now().UTC())
+ if err != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, err
+ }
+ if sqliteError, ok := err.(sqlite3.Error); ok {
+ if sqliteError.ExtendedCode == 2067 && strings.Contains(sqliteError.Error(), "Email") {
+ return -1, ErrDuplicateEmail
+ }
+ }
+ return -1, err
+ }
+ id, err := result.LastInsertId()
+ if err != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, err
+ }
+ return -1, err
+ }
+ stmt = `INSERT INTO user_settings (UserId, SettingId, AllowedSettingValueId, UnconstrainedValue)
+ VALUES (?, ?, ?, ?)`
+ _, err = tx.Exec(stmt, id, m.Settings[SettingUserTimezone].id, nil, settings.LocalTimezone.String())
+ if err != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, err
+ }
+ return -1, err
+ }
+ err = tx.Commit()
+ if err != nil {
+ return -1, err
+ }
+
+ return id, nil
+}
+
func (m *UserModel) Get(shortId uint64) (User, error) {
stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE ShortId = ? AND Deleted IS NULL`
tx, err := m.DB.Begin()
@@ -239,6 +292,51 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
return id, nil
}
+func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) {
+ var id int64
+ var s sql.NullString
+ tx, err := m.DB.Begin()
+ if err != nil {
+ return -1, err
+ }
+ stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?`
+ err = tx.QueryRow(stmt, email, subject).Scan(&id, &s)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, err
+ }
+ return -1, ErrNoRecord
+ } else {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, err
+ }
+ return -1, err
+ }
+ }
+
+ if !s.Valid {
+ stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
+ _, err = tx.Exec(stmt, subject, id)
+ if err != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, err
+ }
+ return -1, err
+ }
+ } else if subject != s.String {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return -1, ErrInvalidCredentials
+ }
+ }
+
+ err = tx.Commit()
+ if err != nil {
+ return -1, err
+ }
+ return id, nil
+}
+
func (m *UserModel) Exists(id int64) (bool, error) {
var exists bool
stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
diff --git a/migrations/000006_change_password_field.down.sql b/migrations/000006_change_password_field.down.sql
new file mode 100644
index 0000000..794b222
--- /dev/null
+++ b/migrations/000006_change_password_field.down.sql
@@ -0,0 +1,4 @@
+ALTER TABLE users RENAME COLUMN HashedPassword TO HashedPasswordOld;
+ALTER TABLE users ADD COLUMN HashedPassword char(60) NOT NULL DEFAULT '0000';
+UPDATE users SET HashedPassword=HashedPasswordOld;
+ALTER TABLE users DROP COLUMN HashedPasswordOld;
diff --git a/migrations/000006_change_password_field.up.sql b/migrations/000006_change_password_field.up.sql
new file mode 100644
index 0000000..e80c883
--- /dev/null
+++ b/migrations/000006_change_password_field.up.sql
@@ -0,0 +1,4 @@
+ALTER TABLE users RENAME COLUMN HashedPassword TO HashedPasswordOld;
+ALTER TABLE users ADD COLUMN HashedPassword char(60) NULL;
+UPDATE users SET HashedPassword=HashedPasswordOld;
+ALTER TABLE users DROP COLUMN HashedPasswordOld;
diff --git a/migrations/000007_add_oidc_subject.down.sql b/migrations/000007_add_oidc_subject.down.sql
new file mode 100644
index 0000000..4a6a7e9
--- /dev/null
+++ b/migrations/000007_add_oidc_subject.down.sql
@@ -0,0 +1 @@
+ALTER TABLE users DROP COLUMN OIDCSubject;
diff --git a/migrations/000007_add_oidc_subject.up.sql b/migrations/000007_add_oidc_subject.up.sql
new file mode 100644
index 0000000..6edefa4
--- /dev/null
+++ b/migrations/000007_add_oidc_subject.up.sql
@@ -0,0 +1 @@
+ALTER TABLE users ADD COLUMN OIDCSubject varchar(255);
diff --git a/ui/views/common.templ b/ui/views/common.templ
index a9111f1..5df9f50 100644
--- a/ui/views/common.templ
+++ b/ui/views/common.templ
@@ -12,6 +12,7 @@ type CommonData struct {
CSRFToken string
CurrentUser *models.User
IsHtmx bool
+ RootUrl string
}
func shortIdToSlug(shortId uint64) string {
diff --git a/ui/views/common_templ.go b/ui/views/common_templ.go
index e79039c..6a7d8aa 100644
--- a/ui/views/common_templ.go
+++ b/ui/views/common_templ.go
@@ -102,7 +102,7 @@ func topNav(data CommonData) templ.Component {
var templ_7745c5c3_Var3 string
templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(data.CurrentUser.Username)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 44, Col: 40}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 45, Col: 40}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
if templ_7745c5c3_Err != nil {
@@ -121,7 +121,7 @@ func topNav(data CommonData) templ.Component {
var templ_7745c5c3_Var4 string
templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(hxHeaders)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 52, Col: 62}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 53, Col: 62}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
if templ_7745c5c3_Err != nil {
@@ -202,7 +202,7 @@ func base(title string, data CommonData) templ.Component {
var templ_7745c5c3_Var7 string
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(title)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 71, Col: 17}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 72, Col: 17}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
if templ_7745c5c3_Err != nil {
@@ -232,7 +232,7 @@ func base(title string, data CommonData) templ.Component {
var templ_7745c5c3_Var8 string
templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(data.Flash)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 83, Col: 36}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 84, Col: 36}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
if templ_7745c5c3_Err != nil {
diff --git a/ui/views/users.templ b/ui/views/users.templ
index 4fb839b..a0d60b3 100644
--- a/ui/views/users.templ
+++ b/ui/views/users.templ
@@ -30,7 +30,8 @@ templ UserLogin(title string, data CommonData, form forms.UserLoginForm) {
}
diff --git a/ui/views/users_templ.go b/ui/views/users_templ.go
index c791e2d..391f030 100644
--- a/ui/views/users_templ.go
+++ b/ui/views/users_templ.go
@@ -143,7 +143,7 @@ func UserLogin(title string, data CommonData, form forms.UserLoginForm) templ.Co
return templ_7745c5c3_Err
}
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -197,7 +197,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
var templ_7745c5c3_Var10 string
templ_7745c5c3_Var10, templ_7745c5c3_Err = templ.JoinStringErrs(data.CSRFToken)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 43, Col: 64}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 44, Col: 64}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var10))
if templ_7745c5c3_Err != nil {
@@ -220,7 +220,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
var templ_7745c5c3_Var11 string
templ_7745c5c3_Var11, templ_7745c5c3_Err = templ.JoinStringErrs(error)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 48, Col: 33}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 49, Col: 33}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var11))
if templ_7745c5c3_Err != nil {
@@ -238,7 +238,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
var templ_7745c5c3_Var12 string
templ_7745c5c3_Var12, templ_7745c5c3_Err = templ.JoinStringErrs(form.Name)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 50, Col: 70}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 51, Col: 70}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var12))
if templ_7745c5c3_Err != nil {
@@ -261,7 +261,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
var templ_7745c5c3_Var13 string
templ_7745c5c3_Var13, templ_7745c5c3_Err = templ.JoinStringErrs(error)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 56, Col: 33}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 57, Col: 33}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var13))
if templ_7745c5c3_Err != nil {
@@ -279,7 +279,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
var templ_7745c5c3_Var14 string
templ_7745c5c3_Var14, templ_7745c5c3_Err = templ.JoinStringErrs(form.Email)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 58, Col: 65}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 59, Col: 65}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var14))
if templ_7745c5c3_Err != nil {
@@ -302,7 +302,7 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
var templ_7745c5c3_Var15 string
templ_7745c5c3_Var15, templ_7745c5c3_Err = templ.JoinStringErrs(error)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 64, Col: 33}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 65, Col: 33}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var15))
if templ_7745c5c3_Err != nil {
@@ -367,7 +367,7 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
var templ_7745c5c3_Var18 string
templ_7745c5c3_Var18, templ_7745c5c3_Err = templ.JoinStringErrs(user.Username)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 77, Col: 21}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 21}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var18))
if templ_7745c5c3_Err != nil {
@@ -380,7 +380,7 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
var templ_7745c5c3_Var19 string
templ_7745c5c3_Var19, templ_7745c5c3_Err = templ.JoinStringErrs(user.Email)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 17}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 79, Col: 17}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var19))
if templ_7745c5c3_Err != nil {
@@ -441,7 +441,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
var templ_7745c5c3_Var22 string
templ_7745c5c3_Var22, templ_7745c5c3_Err = templ.JoinStringErrs(data.CSRFToken)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 88, Col: 65}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 89, Col: 65}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var22))
if templ_7745c5c3_Err != nil {
@@ -460,7 +460,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
var templ_7745c5c3_Var23 string
templ_7745c5c3_Var23, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 93, Col: 25}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 94, Col: 25}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var23))
if templ_7745c5c3_Err != nil {
@@ -473,7 +473,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
var templ_7745c5c3_Var24 string
templ_7745c5c3_Var24, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 93, Col: 48}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 94, Col: 48}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var24))
if templ_7745c5c3_Err != nil {
@@ -491,7 +491,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
var templ_7745c5c3_Var25 string
templ_7745c5c3_Var25, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 95, Col: 25}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 96, Col: 25}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var25))
if templ_7745c5c3_Err != nil {
@@ -504,7 +504,7 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
var templ_7745c5c3_Var26 string
templ_7745c5c3_Var26, templ_7745c5c3_Err = templ.JoinStringErrs(tz)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 95, Col: 32}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 96, Col: 32}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var26))
if templ_7745c5c3_Err != nil {
--
2.30.2
From c56a445c6a0641485591bb0ed59d97f3cf9ccadb Mon Sep 17 00:00:00 2001
From: yequari
Date: Sun, 29 Jun 2025 18:12:32 -0700
Subject: [PATCH 2/3] add app configuration
---
cmd/web/handlers_user.go | 15 +++++--
cmd/web/helpers.go | 16 ++++---
cmd/web/main.go | 96 +++++++++++++++++++++++++++++-----------
ui/views/common.templ | 20 +++++----
ui/views/common_templ.go | 50 ++++++++++++---------
ui/views/users.templ | 4 +-
ui/views/users_templ.go | 94 +++++++++++++++++++++------------------
7 files changed, 186 insertions(+), 109 deletions(-)
diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go
index 3c21fde..141d459 100644
--- a/cmd/web/handlers_user.go
+++ b/cmd/web/handlers_user.go
@@ -13,12 +13,18 @@ import (
)
func (app *application) getUserRegister(w http.ResponseWriter, r *http.Request) {
+ if !app.config.localAuthEnabled {
+ http.Redirect(w, r, "/users/login/oidc", http.StatusFound)
+ }
form := forms.UserRegistrationForm{}
data := app.newCommonData(r)
views.UserRegistration("User Registration", data, form).Render(r.Context(), w)
}
func (app *application) getUserLogin(w http.ResponseWriter, r *http.Request) {
+ if !app.config.localAuthEnabled {
+ http.Redirect(w, r, "/users/login/oidc", http.StatusFound)
+ }
views.UserLogin("Login", app.newCommonData(r), forms.UserLoginForm{}).Render(r.Context(), w)
}
@@ -94,6 +100,9 @@ func (app *application) postUserLogin(w http.ResponseWriter, r *http.Request) {
}
func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
+ if !app.config.oauthEnabled {
+ http.Redirect(w, r, "/users/login", http.StatusFound)
+ }
state, err := randString(16)
if err != nil {
app.serverError(w, r, err)
@@ -108,7 +117,7 @@ func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
setCallbackCookie(w, r, "state", state)
setCallbackCookie(w, r, "nonce", nonce)
- http.Redirect(w, r, app.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
+ http.Redirect(w, r, app.config.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
}
func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Request) {
@@ -122,7 +131,7 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
return
}
- oauth2Token, err := app.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
+ oauth2Token, err := app.config.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil {
app.logger.Error("Failed to exchange token")
app.serverError(w, r, err)
@@ -133,7 +142,7 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
app.serverError(w, r, errors.New("No id_token field in oauth2 token"))
return
}
- idToken, err := app.oauth.verifier.Verify(r.Context(), rawIDToken)
+ idToken, err := app.config.oauth.verifier.Verify(r.Context(), rawIDToken)
if err != nil {
app.logger.Error("Failed to verify ID token")
app.serverError(w, r, err)
diff --git a/cmd/web/helpers.go b/cmd/web/helpers.go
index 239aea2..8e66c3c 100644
--- a/cmd/web/helpers.go
+++ b/cmd/web/helpers.go
@@ -107,13 +107,15 @@ func (app *application) getCurrentUser(r *http.Request) *models.User {
func (app *application) newCommonData(r *http.Request) views.CommonData {
return views.CommonData{
- CurrentYear: time.Now().Year(),
- Flash: app.sessionManager.PopString(r.Context(), "flash"),
- IsAuthenticated: app.isAuthenticated(r),
- CSRFToken: nosurf.Token(r),
- CurrentUser: app.getCurrentUser(r),
- IsHtmx: r.Header.Get("Hx-Request") == "true",
- RootUrl: app.rootUrl,
+ CurrentYear: time.Now().Year(),
+ Flash: app.sessionManager.PopString(r.Context(), "flash"),
+ IsAuthenticated: app.isAuthenticated(r),
+ CSRFToken: nosurf.Token(r),
+ CurrentUser: app.getCurrentUser(r),
+ IsHtmx: r.Header.Get("Hx-Request") == "true",
+ RootUrl: app.config.rootUrl,
+ LocalAuthEnabled: app.config.localAuthEnabled,
+ OIDCEnabled: app.config.oauthEnabled,
}
}
diff --git a/cmd/web/main.go b/cmd/web/main.go
index e999e92..7df1d2b 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -9,7 +9,9 @@ import (
"fmt"
"log/slog"
"net/http"
+ "net/url"
"os"
+ "strconv"
"strings"
"time"
"unicode"
@@ -32,6 +34,13 @@ type applicationOauthConfig struct {
verifier *oidc.IDTokenVerifier
}
+type applicationConfig struct {
+ oauthEnabled bool
+ localAuthEnabled bool
+ oauth applicationOauthConfig
+ rootUrl string
+}
+
type application struct {
sequence uint16
logger *slog.Logger
@@ -40,21 +49,24 @@ type application struct {
guestbookComments models.GuestbookCommentModelInterface
sessionManager *scs.SessionManager
formDecoder *schema.Decoder
- oauth applicationOauthConfig
+ config applicationConfig
debug bool
timezones []string
- rootUrl string
}
func main() {
addr := flag.String("addr", ":3000", "HTTP network address")
dsn := flag.String("dsn", "guestbook.db", "data source name")
debug := flag.Bool("debug", false, "enable debug mode")
- root := flag.String("root", "https://localhost:3000", "root URL of application")
flag.Parse()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
godotenv.Load(".env.dev")
+ cfg, err := setupConfig(*addr)
+ if err != nil {
+ logger.Error(err.Error())
+ os.Exit(1)
+ }
db, err := openDB(*dsn)
if err != nil {
@@ -78,18 +90,11 @@ func main() {
users: &models.UserModel{DB: db, Settings: make(map[string]models.Setting)},
guestbookComments: &models.GuestbookCommentModel{DB: db},
formDecoder: formDecoder,
+ config: cfg,
debug: *debug,
timezones: getAvailableTimezones(),
- rootUrl: *root,
}
- o, err := setupOauth(app.rootUrl)
- if err != nil {
- logger.Error(err.Error())
- os.Exit(1)
- }
- app.oauth = o
-
err = app.users.InitializeSettingsMap()
if err != nil {
logger.Error(err.Error())
@@ -137,36 +142,73 @@ func openDB(dsn string) (*sql.DB, error) {
return db, nil
}
-func setupOauth(rootUrl string) (applicationOauthConfig, error) {
- var c applicationOauthConfig
+func setupConfig(addr string) (applicationConfig, error) {
+ var c applicationConfig
+
var (
- oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
- clientID = os.Getenv("OAUTH2_CLIENT_ID")
- clientSecret = os.Getenv("OAUTH2_CLIENT_SECRET")
+ rootUrl = os.Getenv("ROOT_URL")
+ oidcEnabled = os.Getenv("ENABLE_OIDC")
+ localLoginEnabled = os.Getenv("ENABLE_LOCAL_LOGIN")
+ oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
+ clientID = os.Getenv("OAUTH2_CLIENT_ID")
+ clientSecret = os.Getenv("OAUTH2_CLIENT_SECRET")
)
- if oauth2Provider == "" || clientID == "" || clientSecret == "" {
- return applicationOauthConfig{}, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+ if rootUrl != "" {
+ c.rootUrl = rootUrl
+ } else {
+ u, err := url.Parse(fmt.Sprintf("https://localhost%s", addr))
+ if err != nil {
+ return c, err
+ }
+ c.rootUrl = u.String()
}
- c.ctx = context.Background()
- provider, err := oidc.NewProvider(c.ctx, oauth2Provider)
+ oauthEnabled, err := strconv.ParseBool(oidcEnabled)
if err != nil {
- return applicationOauthConfig{}, err
+ c.oauthEnabled = false
}
- c.provider = provider
- c.oidcConfig = &oidc.Config{
+ c.oauthEnabled = oauthEnabled
+
+ localAuthEnabled, err := strconv.ParseBool(localLoginEnabled)
+ if err != nil {
+ c.localAuthEnabled = true
+ }
+ c.localAuthEnabled = localAuthEnabled
+
+ if !c.oauthEnabled && !c.localAuthEnabled {
+ return c, errors.New("Either ENABLE_OIDC or ENABLE_LOCAL_LOGIN must be set to true")
+ }
+
+ // if OIDC is disabled, no more configuration needs to be read
+ if !oauthEnabled {
+ return c, nil
+ }
+
+ var o applicationOauthConfig
+ if oauth2Provider == "" || clientID == "" || clientSecret == "" {
+ return c, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+ }
+
+ o.ctx = context.Background()
+ provider, err := oidc.NewProvider(o.ctx, oauth2Provider)
+ if err != nil {
+ return c, err
+ }
+ o.provider = provider
+ o.oidcConfig = &oidc.Config{
ClientID: clientID,
}
- c.verifier = provider.Verifier(c.oidcConfig)
- c.config = oauth2.Config{
+ o.verifier = provider.Verifier(o.oidcConfig)
+ o.config = oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: provider.Endpoint(),
- RedirectURL: fmt.Sprintf("%s/users/login/oidc/callback", rootUrl),
+ RedirectURL: fmt.Sprintf("%s/users/login/oidc/callback", c.rootUrl),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
- return c, nil
+ c.oauth = o
+ return c, nil
}
func getAvailableTimezones() []string {
diff --git a/ui/views/common.templ b/ui/views/common.templ
index 5df9f50..bafe90f 100644
--- a/ui/views/common.templ
+++ b/ui/views/common.templ
@@ -6,13 +6,15 @@ import "fmt"
import "strings"
type CommonData struct {
- CurrentYear int
- Flash string
- IsAuthenticated bool
- CSRFToken string
- CurrentUser *models.User
- IsHtmx bool
- RootUrl string
+ CurrentYear int
+ Flash string
+ IsAuthenticated bool
+ CSRFToken string
+ CurrentUser *models.User
+ IsHtmx bool
+ RootUrl string
+ LocalAuthEnabled bool
+ OIDCEnabled bool
}
func shortIdToSlug(shortId uint64) string {
@@ -52,7 +54,9 @@ templ topNav(data CommonData) {
Settings |
Logout
} else {
- Create an Account |
+ if data.LocalAuthEnabled {
+ Create an Account |
+ }
Login
}
diff --git a/ui/views/common_templ.go b/ui/views/common_templ.go
index 6a7d8aa..e4131f2 100644
--- a/ui/views/common_templ.go
+++ b/ui/views/common_templ.go
@@ -14,13 +14,15 @@ import "fmt"
import "strings"
type CommonData struct {
- CurrentYear int
- Flash string
- IsAuthenticated bool
- CSRFToken string
- CurrentUser *models.User
- IsHtmx bool
- RootUrl string
+ CurrentYear int
+ Flash string
+ IsAuthenticated bool
+ CSRFToken string
+ CurrentUser *models.User
+ IsHtmx bool
+ RootUrl string
+ LocalAuthEnabled bool
+ OIDCEnabled bool
}
func shortIdToSlug(shortId uint64) string {
@@ -102,7 +104,7 @@ func topNav(data CommonData) templ.Component {
var templ_7745c5c3_Var3 string
templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(data.CurrentUser.Username)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 45, Col: 40}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 47, Col: 40}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
if templ_7745c5c3_Err != nil {
@@ -121,7 +123,7 @@ func topNav(data CommonData) templ.Component {
var templ_7745c5c3_Var4 string
templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(hxHeaders)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 53, Col: 62}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 55, Col: 62}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
if templ_7745c5c3_Err != nil {
@@ -132,12 +134,18 @@ func topNav(data CommonData) templ.Component {
return templ_7745c5c3_Err
}
} else {
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "Create an Account | Login")
+ if data.LocalAuthEnabled {
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "Create an Account | ")
+ if templ_7745c5c3_Err != nil {
+ return templ_7745c5c3_Err
+ }
+ }
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, " Login")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -166,7 +174,7 @@ func commonFooter() templ.Component {
templ_7745c5c3_Var5 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -195,20 +203,20 @@ func base(title string, data CommonData) templ.Component {
templ_7745c5c3_Var6 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var7 string
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(title)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 72, Col: 17}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 76, Col: 17}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, " - webweav.ing")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, " - webweav.ing")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -220,25 +228,25 @@ func base(title string, data CommonData) templ.Component {
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
if data.Flash != "" {
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var8 string
templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(data.Flash)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 84, Col: 36}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 88, Col: 36}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -247,7 +255,7 @@ func base(title string, data CommonData) templ.Component {
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -255,7 +263,7 @@ func base(title string, data CommonData) templ.Component {
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 17, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
diff --git a/ui/views/users.templ b/ui/views/users.templ
index a0d60b3..299d282 100644
--- a/ui/views/users.templ
+++ b/ui/views/users.templ
@@ -31,7 +31,9 @@ templ UserLogin(title string, data CommonData, form forms.UserLoginForm) {
}
diff --git a/ui/views/users_templ.go b/ui/views/users_templ.go
index 391f030..ec96f47 100644
--- a/ui/views/users_templ.go
+++ b/ui/views/users_templ.go
@@ -143,7 +143,17 @@ func UserLogin(title string, data CommonData, form forms.UserLoginForm) templ.Co
return templ_7745c5c3_Err
}
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -190,130 +200,130 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
}()
}
ctx = templ.InitializeContext(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "User Registration
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -360,33 +370,33 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
}()
}
ctx = templ.InitializeContext(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 29, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 31, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var18 string
templ_7745c5c3_Var18, templ_7745c5c3_Err = templ.JoinStringErrs(user.Username)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 21}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 80, Col: 21}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var18))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 30, "
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 32, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var19 string
templ_7745c5c3_Var19, templ_7745c5c3_Err = templ.JoinStringErrs(user.Email)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 79, Col: 17}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 81, Col: 17}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var19))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 31, "
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 33, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -434,89 +444,89 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
}()
}
ctx = templ.InitializeContext(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 32, "User Settings
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 42, " ")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
--
2.30.2
From db1d4e1ad2fe4f38d2762cf0b05403aaaccea2d6 Mon Sep 17 00:00:00 2001
From: yequari
Date: Sun, 29 Jun 2025 20:16:37 -0700
Subject: [PATCH 3/3] env file config
---
cmd/web/main.go | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/cmd/web/main.go b/cmd/web/main.go
index 7df1d2b..b764069 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -58,10 +58,15 @@ func main() {
addr := flag.String("addr", ":3000", "HTTP network address")
dsn := flag.String("dsn", "guestbook.db", "data source name")
debug := flag.Bool("debug", false, "enable debug mode")
+ env := flag.String("env", ".env", ".env file path")
flag.Parse()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
- godotenv.Load(".env.dev")
+ err := godotenv.Load(*env)
+ if err != nil {
+ logger.Error(err.Error())
+ os.Exit(1)
+ }
cfg, err := setupConfig(*addr)
if err != nil {
logger.Error(err.Error())
--
2.30.2