From c56a445c6a0641485591bb0ed59d97f3cf9ccadb Mon Sep 17 00:00:00 2001
From: yequari
Date: Sun, 29 Jun 2025 18:12:32 -0700
Subject: [PATCH] add app configuration
---
cmd/web/handlers_user.go | 15 +++++--
cmd/web/helpers.go | 16 ++++---
cmd/web/main.go | 96 +++++++++++++++++++++++++++++-----------
ui/views/common.templ | 20 +++++----
ui/views/common_templ.go | 50 ++++++++++++---------
ui/views/users.templ | 4 +-
ui/views/users_templ.go | 94 +++++++++++++++++++++------------------
7 files changed, 186 insertions(+), 109 deletions(-)
diff --git a/cmd/web/handlers_user.go b/cmd/web/handlers_user.go
index 3c21fde..141d459 100644
--- a/cmd/web/handlers_user.go
+++ b/cmd/web/handlers_user.go
@@ -13,12 +13,18 @@ import (
)
func (app *application) getUserRegister(w http.ResponseWriter, r *http.Request) {
+ if !app.config.localAuthEnabled {
+ http.Redirect(w, r, "/users/login/oidc", http.StatusFound)
+ }
form := forms.UserRegistrationForm{}
data := app.newCommonData(r)
views.UserRegistration("User Registration", data, form).Render(r.Context(), w)
}
func (app *application) getUserLogin(w http.ResponseWriter, r *http.Request) {
+ if !app.config.localAuthEnabled {
+ http.Redirect(w, r, "/users/login/oidc", http.StatusFound)
+ }
views.UserLogin("Login", app.newCommonData(r), forms.UserLoginForm{}).Render(r.Context(), w)
}
@@ -94,6 +100,9 @@ func (app *application) postUserLogin(w http.ResponseWriter, r *http.Request) {
}
func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
+ if !app.config.oauthEnabled {
+ http.Redirect(w, r, "/users/login", http.StatusFound)
+ }
state, err := randString(16)
if err != nil {
app.serverError(w, r, err)
@@ -108,7 +117,7 @@ func (app *application) userLoginOIDC(w http.ResponseWriter, r *http.Request) {
setCallbackCookie(w, r, "state", state)
setCallbackCookie(w, r, "nonce", nonce)
- http.Redirect(w, r, app.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
+ http.Redirect(w, r, app.config.oauth.config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound)
}
func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Request) {
@@ -122,7 +131,7 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
return
}
- oauth2Token, err := app.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
+ oauth2Token, err := app.config.oauth.config.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil {
app.logger.Error("Failed to exchange token")
app.serverError(w, r, err)
@@ -133,7 +142,7 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
app.serverError(w, r, errors.New("No id_token field in oauth2 token"))
return
}
- idToken, err := app.oauth.verifier.Verify(r.Context(), rawIDToken)
+ idToken, err := app.config.oauth.verifier.Verify(r.Context(), rawIDToken)
if err != nil {
app.logger.Error("Failed to verify ID token")
app.serverError(w, r, err)
diff --git a/cmd/web/helpers.go b/cmd/web/helpers.go
index 239aea2..8e66c3c 100644
--- a/cmd/web/helpers.go
+++ b/cmd/web/helpers.go
@@ -107,13 +107,15 @@ func (app *application) getCurrentUser(r *http.Request) *models.User {
func (app *application) newCommonData(r *http.Request) views.CommonData {
return views.CommonData{
- CurrentYear: time.Now().Year(),
- Flash: app.sessionManager.PopString(r.Context(), "flash"),
- IsAuthenticated: app.isAuthenticated(r),
- CSRFToken: nosurf.Token(r),
- CurrentUser: app.getCurrentUser(r),
- IsHtmx: r.Header.Get("Hx-Request") == "true",
- RootUrl: app.rootUrl,
+ CurrentYear: time.Now().Year(),
+ Flash: app.sessionManager.PopString(r.Context(), "flash"),
+ IsAuthenticated: app.isAuthenticated(r),
+ CSRFToken: nosurf.Token(r),
+ CurrentUser: app.getCurrentUser(r),
+ IsHtmx: r.Header.Get("Hx-Request") == "true",
+ RootUrl: app.config.rootUrl,
+ LocalAuthEnabled: app.config.localAuthEnabled,
+ OIDCEnabled: app.config.oauthEnabled,
}
}
diff --git a/cmd/web/main.go b/cmd/web/main.go
index e999e92..7df1d2b 100644
--- a/cmd/web/main.go
+++ b/cmd/web/main.go
@@ -9,7 +9,9 @@ import (
"fmt"
"log/slog"
"net/http"
+ "net/url"
"os"
+ "strconv"
"strings"
"time"
"unicode"
@@ -32,6 +34,13 @@ type applicationOauthConfig struct {
verifier *oidc.IDTokenVerifier
}
+type applicationConfig struct {
+ oauthEnabled bool
+ localAuthEnabled bool
+ oauth applicationOauthConfig
+ rootUrl string
+}
+
type application struct {
sequence uint16
logger *slog.Logger
@@ -40,21 +49,24 @@ type application struct {
guestbookComments models.GuestbookCommentModelInterface
sessionManager *scs.SessionManager
formDecoder *schema.Decoder
- oauth applicationOauthConfig
+ config applicationConfig
debug bool
timezones []string
- rootUrl string
}
func main() {
addr := flag.String("addr", ":3000", "HTTP network address")
dsn := flag.String("dsn", "guestbook.db", "data source name")
debug := flag.Bool("debug", false, "enable debug mode")
- root := flag.String("root", "https://localhost:3000", "root URL of application")
flag.Parse()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
godotenv.Load(".env.dev")
+ cfg, err := setupConfig(*addr)
+ if err != nil {
+ logger.Error(err.Error())
+ os.Exit(1)
+ }
db, err := openDB(*dsn)
if err != nil {
@@ -78,18 +90,11 @@ func main() {
users: &models.UserModel{DB: db, Settings: make(map[string]models.Setting)},
guestbookComments: &models.GuestbookCommentModel{DB: db},
formDecoder: formDecoder,
+ config: cfg,
debug: *debug,
timezones: getAvailableTimezones(),
- rootUrl: *root,
}
- o, err := setupOauth(app.rootUrl)
- if err != nil {
- logger.Error(err.Error())
- os.Exit(1)
- }
- app.oauth = o
-
err = app.users.InitializeSettingsMap()
if err != nil {
logger.Error(err.Error())
@@ -137,36 +142,73 @@ func openDB(dsn string) (*sql.DB, error) {
return db, nil
}
-func setupOauth(rootUrl string) (applicationOauthConfig, error) {
- var c applicationOauthConfig
+func setupConfig(addr string) (applicationConfig, error) {
+ var c applicationConfig
+
var (
- oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
- clientID = os.Getenv("OAUTH2_CLIENT_ID")
- clientSecret = os.Getenv("OAUTH2_CLIENT_SECRET")
+ rootUrl = os.Getenv("ROOT_URL")
+ oidcEnabled = os.Getenv("ENABLE_OIDC")
+ localLoginEnabled = os.Getenv("ENABLE_LOCAL_LOGIN")
+ oauth2Provider = os.Getenv("OAUTH2_PROVIDER")
+ clientID = os.Getenv("OAUTH2_CLIENT_ID")
+ clientSecret = os.Getenv("OAUTH2_CLIENT_SECRET")
)
- if oauth2Provider == "" || clientID == "" || clientSecret == "" {
- return applicationOauthConfig{}, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+ if rootUrl != "" {
+ c.rootUrl = rootUrl
+ } else {
+ u, err := url.Parse(fmt.Sprintf("https://localhost%s", addr))
+ if err != nil {
+ return c, err
+ }
+ c.rootUrl = u.String()
}
- c.ctx = context.Background()
- provider, err := oidc.NewProvider(c.ctx, oauth2Provider)
+ oauthEnabled, err := strconv.ParseBool(oidcEnabled)
if err != nil {
- return applicationOauthConfig{}, err
+ c.oauthEnabled = false
}
- c.provider = provider
- c.oidcConfig = &oidc.Config{
+ c.oauthEnabled = oauthEnabled
+
+ localAuthEnabled, err := strconv.ParseBool(localLoginEnabled)
+ if err != nil {
+ c.localAuthEnabled = true
+ }
+ c.localAuthEnabled = localAuthEnabled
+
+ if !c.oauthEnabled && !c.localAuthEnabled {
+ return c, errors.New("Either ENABLE_OIDC or ENABLE_LOCAL_LOGIN must be set to true")
+ }
+
+ // if OIDC is disabled, no more configuration needs to be read
+ if !oauthEnabled {
+ return c, nil
+ }
+
+ var o applicationOauthConfig
+ if oauth2Provider == "" || clientID == "" || clientSecret == "" {
+ return c, errors.New("OAUTH2_PROVIDER, OAUTH2_CLIENT_ID, and OAUTH2_CLIENT_SECRET must be specified as environment variables.")
+ }
+
+ o.ctx = context.Background()
+ provider, err := oidc.NewProvider(o.ctx, oauth2Provider)
+ if err != nil {
+ return c, err
+ }
+ o.provider = provider
+ o.oidcConfig = &oidc.Config{
ClientID: clientID,
}
- c.verifier = provider.Verifier(c.oidcConfig)
- c.config = oauth2.Config{
+ o.verifier = provider.Verifier(o.oidcConfig)
+ o.config = oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: provider.Endpoint(),
- RedirectURL: fmt.Sprintf("%s/users/login/oidc/callback", rootUrl),
+ RedirectURL: fmt.Sprintf("%s/users/login/oidc/callback", c.rootUrl),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
- return c, nil
+ c.oauth = o
+ return c, nil
}
func getAvailableTimezones() []string {
diff --git a/ui/views/common.templ b/ui/views/common.templ
index 5df9f50..bafe90f 100644
--- a/ui/views/common.templ
+++ b/ui/views/common.templ
@@ -6,13 +6,15 @@ import "fmt"
import "strings"
type CommonData struct {
- CurrentYear int
- Flash string
- IsAuthenticated bool
- CSRFToken string
- CurrentUser *models.User
- IsHtmx bool
- RootUrl string
+ CurrentYear int
+ Flash string
+ IsAuthenticated bool
+ CSRFToken string
+ CurrentUser *models.User
+ IsHtmx bool
+ RootUrl string
+ LocalAuthEnabled bool
+ OIDCEnabled bool
}
func shortIdToSlug(shortId uint64) string {
@@ -52,7 +54,9 @@ templ topNav(data CommonData) {
Settings |
Logout
} else {
- Create an Account |
+ if data.LocalAuthEnabled {
+ Create an Account |
+ }
Login
}
diff --git a/ui/views/common_templ.go b/ui/views/common_templ.go
index 6a7d8aa..e4131f2 100644
--- a/ui/views/common_templ.go
+++ b/ui/views/common_templ.go
@@ -14,13 +14,15 @@ import "fmt"
import "strings"
type CommonData struct {
- CurrentYear int
- Flash string
- IsAuthenticated bool
- CSRFToken string
- CurrentUser *models.User
- IsHtmx bool
- RootUrl string
+ CurrentYear int
+ Flash string
+ IsAuthenticated bool
+ CSRFToken string
+ CurrentUser *models.User
+ IsHtmx bool
+ RootUrl string
+ LocalAuthEnabled bool
+ OIDCEnabled bool
}
func shortIdToSlug(shortId uint64) string {
@@ -102,7 +104,7 @@ func topNav(data CommonData) templ.Component {
var templ_7745c5c3_Var3 string
templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(data.CurrentUser.Username)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 45, Col: 40}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 47, Col: 40}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
if templ_7745c5c3_Err != nil {
@@ -121,7 +123,7 @@ func topNav(data CommonData) templ.Component {
var templ_7745c5c3_Var4 string
templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(hxHeaders)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 53, Col: 62}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 55, Col: 62}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
if templ_7745c5c3_Err != nil {
@@ -132,12 +134,18 @@ func topNav(data CommonData) templ.Component {
return templ_7745c5c3_Err
}
} else {
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "Create an Account | Login")
+ if data.LocalAuthEnabled {
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "Create an Account | ")
+ if templ_7745c5c3_Err != nil {
+ return templ_7745c5c3_Err
+ }
+ }
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, " Login")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -166,7 +174,7 @@ func commonFooter() templ.Component {
templ_7745c5c3_Var5 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -195,20 +203,20 @@ func base(title string, data CommonData) templ.Component {
templ_7745c5c3_Var6 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var7 string
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(title)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 72, Col: 17}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 76, Col: 17}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, " - webweav.ing")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, " - webweav.ing")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -220,25 +228,25 @@ func base(title string, data CommonData) templ.Component {
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
if data.Flash != "" {
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var8 string
templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(data.Flash)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 84, Col: 36}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/common.templ`, Line: 88, Col: 36}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -247,7 +255,7 @@ func base(title string, data CommonData) templ.Component {
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -255,7 +263,7 @@ func base(title string, data CommonData) templ.Component {
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 17, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
diff --git a/ui/views/users.templ b/ui/views/users.templ
index a0d60b3..299d282 100644
--- a/ui/views/users.templ
+++ b/ui/views/users.templ
@@ -31,7 +31,9 @@ templ UserLogin(title string, data CommonData, form forms.UserLoginForm) {
}
diff --git a/ui/views/users_templ.go b/ui/views/users_templ.go
index 391f030..ec96f47 100644
--- a/ui/views/users_templ.go
+++ b/ui/views/users_templ.go
@@ -143,7 +143,17 @@ func UserLogin(title string, data CommonData, form forms.UserLoginForm) templ.Co
return templ_7745c5c3_Err
}
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -190,130 +200,130 @@ func UserRegistration(title string, data CommonData, form forms.UserRegistration
}()
}
ctx = templ.InitializeContext(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "User Registration
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -360,33 +370,33 @@ func UserProfile(title string, data CommonData, user models.User) templ.Componen
}()
}
ctx = templ.InitializeContext(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 29, "")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 31, "")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var18 string
templ_7745c5c3_Var18, templ_7745c5c3_Err = templ.JoinStringErrs(user.Username)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 78, Col: 21}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 80, Col: 21}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var18))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 30, "
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 32, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var19 string
templ_7745c5c3_Var19, templ_7745c5c3_Err = templ.JoinStringErrs(user.Email)
if templ_7745c5c3_Err != nil {
- return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 79, Col: 17}
+ return templ.Error{Err: templ_7745c5c3_Err, FileName: `ui/views/users.templ`, Line: 81, Col: 17}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var19))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 31, "
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 33, "
")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -434,89 +444,89 @@ func UserSettingsView(data CommonData, timezones []string) templ.Component {
}()
}
ctx = templ.InitializeContext(ctx)
- templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 32, "User Settings
")
+ templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 42, " ")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}