unit testing for oidc and fix test cases for remote comments
This commit is contained in:
parent
db1d4e1ad2
commit
f6e332b76a
@ -1,9 +1,12 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
|
"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
|
||||||
@ -150,9 +153,6 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
ts := newTestServer(t, app.routes())
|
ts := newTestServer(t, app.routes())
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
_, _, body := ts.get(t, fmt.Sprintf("/websites/%s/guestbook", shortIdToSlug(1)))
|
|
||||||
validCSRFToken := extractCSRFToken(t, body)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
validAuthorName = "John Test"
|
validAuthorName = "John Test"
|
||||||
validAuthorEmail = "test@example.com"
|
validAuthorEmail = "test@example.com"
|
||||||
@ -166,8 +166,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
authorEmail string
|
authorEmail string
|
||||||
authorSite string
|
authorSite string
|
||||||
content string
|
content string
|
||||||
csrfToken string
|
|
||||||
wantCode int
|
wantCode int
|
||||||
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Valid input",
|
name: "Valid input",
|
||||||
@ -175,8 +175,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
authorEmail: validAuthorEmail,
|
authorEmail: validAuthorEmail,
|
||||||
authorSite: validAuthorSite,
|
authorSite: validAuthorSite,
|
||||||
content: validContent,
|
content: validContent,
|
||||||
csrfToken: validCSRFToken,
|
wantCode: http.StatusOK,
|
||||||
wantCode: http.StatusSeeOther,
|
wantBody: "Comment successfully posted",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Blank name",
|
name: "Blank name",
|
||||||
@ -184,8 +184,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
authorEmail: validAuthorEmail,
|
authorEmail: validAuthorEmail,
|
||||||
authorSite: validAuthorSite,
|
authorSite: validAuthorSite,
|
||||||
content: validContent,
|
content: validContent,
|
||||||
csrfToken: validCSRFToken,
|
wantCode: http.StatusOK,
|
||||||
wantCode: http.StatusUnprocessableEntity,
|
wantBody: "An error occurred",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Blank email",
|
name: "Blank email",
|
||||||
@ -193,8 +193,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
authorEmail: "",
|
authorEmail: "",
|
||||||
authorSite: validAuthorSite,
|
authorSite: validAuthorSite,
|
||||||
content: validContent,
|
content: validContent,
|
||||||
csrfToken: validCSRFToken,
|
wantCode: http.StatusOK,
|
||||||
wantCode: http.StatusSeeOther,
|
wantBody: "Comment successfully posted",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Blank site",
|
name: "Blank site",
|
||||||
@ -202,8 +202,8 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
authorEmail: validAuthorEmail,
|
authorEmail: validAuthorEmail,
|
||||||
authorSite: "",
|
authorSite: "",
|
||||||
content: validContent,
|
content: validContent,
|
||||||
csrfToken: validCSRFToken,
|
wantCode: http.StatusOK,
|
||||||
wantCode: http.StatusSeeOther,
|
wantBody: "Comment successfully posted",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Blank content",
|
name: "Blank content",
|
||||||
@ -211,21 +211,39 @@ func TestPostGuestbookCommentCreateRemote(t *testing.T) {
|
|||||||
authorEmail: validAuthorEmail,
|
authorEmail: validAuthorEmail,
|
||||||
authorSite: validAuthorSite,
|
authorSite: validAuthorSite,
|
||||||
content: "",
|
content: "",
|
||||||
csrfToken: validCSRFToken,
|
wantCode: http.StatusOK,
|
||||||
wantCode: http.StatusUnprocessableEntity,
|
wantBody: "An error occurred",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
form.Add("authorname", tt.authorName)
|
form.Add("authorname", tt.authorName)
|
||||||
form.Add("authoremail", tt.authorEmail)
|
form.Add("authoremail", tt.authorEmail)
|
||||||
form.Add("authorsite", tt.authorSite)
|
form.Add("authorsite", tt.authorSite)
|
||||||
form.Add("content", tt.content)
|
form.Add("content", tt.content)
|
||||||
form.Add("csrf_token", tt.csrfToken)
|
r, err := http.NewRequest("POST", ts.URL, strings.NewReader(form.Encode()))
|
||||||
code, _, body := ts.postForm(t, fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1)), form)
|
if err != nil {
|
||||||
assert.Equal(t, code, tt.wantCode)
|
t.Fatal(err)
|
||||||
assert.Equal(t, body, body)
|
}
|
||||||
|
r.URL.Path = fmt.Sprintf("/websites/%s/guestbook/comments/create/remote", shortIdToSlug(1))
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
r.Header.Set("Origin", "http://example.com")
|
||||||
|
|
||||||
|
resp, err := ts.Client().Do(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
body = bytes.TrimSpace(body)
|
||||||
|
assert.Equal(t, resp.StatusCode, tt.wantCode)
|
||||||
|
assert.StringContains(t, string(body), tt.wantBody)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -173,13 +173,35 @@ func (app *application) userLoginOIDCCallback(w http.ResponseWriter, r *http.Req
|
|||||||
app.serverError(w, r, err)
|
app.serverError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id, err := app.users.AuthenticateByOIDC(t.Email, t.Subject)
|
|
||||||
if err != nil {
|
// search for user by subject
|
||||||
|
id, err := app.users.GetBySubject(t.Subject)
|
||||||
|
if err != nil && errors.Is(err, models.ErrNoRecord) {
|
||||||
|
// if no user is found, check if they have signed up by email already
|
||||||
|
id, err = app.users.GetByEmail(t.Email)
|
||||||
|
if err == nil {
|
||||||
|
// if user is found by email, update subject to match them in the first step next time
|
||||||
|
err2 := app.users.UpdateSubject(id, t.Subject)
|
||||||
|
if err2 != nil {
|
||||||
|
app.serverError(w, r, err2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
app.serverError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil && errors.Is(err, models.ErrNoRecord) {
|
||||||
|
// if no user is found by subject or email, create a new user
|
||||||
id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
|
id, err = app.users.InsertWithoutPassword(app.createShortId(), t.Username, t.Email, t.Subject, DefaultUserSettings())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.serverError(w, r, err)
|
app.serverError(w, r, err)
|
||||||
}
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
app.serverError(w, r, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
|
app.sessionManager.Put(r.Context(), "authenticatedUserId", id)
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
|
"git.32bit.cafe/32bitcafe/guestbook/internal/assert"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc/oidctest"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUserSignup(t *testing.T) {
|
func TestUserSignup(t *testing.T) {
|
||||||
@ -119,3 +125,126 @@ func TestUserSignup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OAuth2Mock struct {
|
||||||
|
Srv *testServer
|
||||||
|
Priv *rsa.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OAuth2Mock) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OAuth2Mock) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||||
|
tkn := oauth2.Token{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
Expiry: time.Now().Add(1 * time.Hour),
|
||||||
|
}
|
||||||
|
m := make(map[string]interface{})
|
||||||
|
var rawClaims = `{
|
||||||
|
"iss": "` + o.Srv.URL + `",
|
||||||
|
"aud": "my-client-id",
|
||||||
|
"sub": "foo",
|
||||||
|
"email": "foo@example.com",
|
||||||
|
"email_verified": true,
|
||||||
|
"nonce": "nonce"
|
||||||
|
}`
|
||||||
|
|
||||||
|
m["id_token"] = oidctest.SignIDToken(o.Priv, "test-key", oidc.RS256, rawClaims)
|
||||||
|
return tkn.WithExtra(m), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OAuth2Mock) Client(ctx context.Context, t *oauth2.Token) *http.Client {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserOIDCCallback(t *testing.T) {
|
||||||
|
app := newTestApplication(t)
|
||||||
|
ts := newTestServer(t, app.routes())
|
||||||
|
|
||||||
|
priv := newTestKey(t)
|
||||||
|
srv := newTestOIDCServer(t, priv)
|
||||||
|
|
||||||
|
defer srv.Close()
|
||||||
|
defer ts.Close()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
p, err := oidc.NewProvider(ctx, srv.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cfg := &oidc.Config{
|
||||||
|
ClientID: "my-client-id",
|
||||||
|
SkipExpiryCheck: true,
|
||||||
|
}
|
||||||
|
v := p.VerifierContext(ctx, cfg)
|
||||||
|
app.config.oauth = applicationOauthConfig{
|
||||||
|
ctx: context.Background(),
|
||||||
|
oidcConfig: cfg,
|
||||||
|
config: &OAuth2Mock{
|
||||||
|
Srv: srv,
|
||||||
|
Priv: priv,
|
||||||
|
},
|
||||||
|
provider: p,
|
||||||
|
verifier: v,
|
||||||
|
}
|
||||||
|
app.config.oauthEnabled = true
|
||||||
|
|
||||||
|
const (
|
||||||
|
validSubject = "goodSubject"
|
||||||
|
validUserId = 1
|
||||||
|
validEmail = "test@example.com"
|
||||||
|
validState = "goodState"
|
||||||
|
)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
subject string
|
||||||
|
email string
|
||||||
|
state string
|
||||||
|
wantCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Found Subject",
|
||||||
|
subject: validSubject,
|
||||||
|
email: validEmail,
|
||||||
|
state: validState,
|
||||||
|
wantCode: http.StatusSeeOther,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(*testing.T) {
|
||||||
|
r, err := http.NewRequest("GET", ts.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
r.URL.Path = "/users/login/oidc/callback"
|
||||||
|
q := r.URL.Query()
|
||||||
|
q.Add("state", tt.state)
|
||||||
|
r.URL.RawQuery = q.Encode()
|
||||||
|
c := &http.Cookie{
|
||||||
|
Name: "state",
|
||||||
|
Value: validState,
|
||||||
|
MaxAge: int(time.Hour.Seconds()),
|
||||||
|
Secure: r.TLS != nil,
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
d := &http.Cookie{
|
||||||
|
Name: "nonce",
|
||||||
|
Value: "nonce",
|
||||||
|
MaxAge: int(time.Hour.Seconds()),
|
||||||
|
Secure: r.TLS != nil,
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
r.AddCookie(c)
|
||||||
|
r.AddCookie(d)
|
||||||
|
resp, err := ts.Client().Do(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, resp.StatusCode, tt.wantCode)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"git.32bit.cafe/32bitcafe/guestbook/internal/auth"
|
||||||
"git.32bit.cafe/32bitcafe/guestbook/internal/models"
|
"git.32bit.cafe/32bitcafe/guestbook/internal/models"
|
||||||
"github.com/alexedwards/scs/sqlite3store"
|
"github.com/alexedwards/scs/sqlite3store"
|
||||||
"github.com/alexedwards/scs/v2"
|
"github.com/alexedwards/scs/v2"
|
||||||
@ -28,9 +29,9 @@ import (
|
|||||||
|
|
||||||
type applicationOauthConfig struct {
|
type applicationOauthConfig struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config oauth2.Config
|
|
||||||
provider *oidc.Provider
|
|
||||||
oidcConfig *oidc.Config
|
oidcConfig *oidc.Config
|
||||||
|
config auth.OAuth2ConfigInterface
|
||||||
|
provider *oidc.Provider
|
||||||
verifier *oidc.IDTokenVerifier
|
verifier *oidc.IDTokenVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,7 +205,7 @@ func setupConfig(addr string) (applicationConfig, error) {
|
|||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
}
|
}
|
||||||
o.verifier = provider.Verifier(o.oidcConfig)
|
o.verifier = provider.Verifier(o.oidcConfig)
|
||||||
o.config = oauth2.Config{
|
o.config = &oauth2.Config{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
Endpoint: provider.Endpoint(),
|
Endpoint: provider.Endpoint(),
|
||||||
|
@ -2,6 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"html"
|
"html"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@ -15,6 +17,8 @@ import (
|
|||||||
|
|
||||||
"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
|
"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
|
||||||
"github.com/alexedwards/scs/v2"
|
"github.com/alexedwards/scs/v2"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc/oidctest"
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -34,9 +38,35 @@ func newTestApplication(t *testing.T) *application {
|
|||||||
guestbookComments: &mocks.GuestbookCommentModel{},
|
guestbookComments: &mocks.GuestbookCommentModel{},
|
||||||
formDecoder: formDecoder,
|
formDecoder: formDecoder,
|
||||||
timezones: getAvailableTimezones(),
|
timezones: getAvailableTimezones(),
|
||||||
|
config: applicationConfig{
|
||||||
|
localAuthEnabled: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestKey(t *testing.T) *rsa.PrivateKey {
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return priv
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestOIDCServer(t *testing.T, priv *rsa.PrivateKey) *testServer {
|
||||||
|
s := &oidctest.Server{
|
||||||
|
PublicKeys: []oidctest.PublicKey{
|
||||||
|
{
|
||||||
|
PublicKey: priv.Public(),
|
||||||
|
KeyID: "test-key",
|
||||||
|
Algorithm: oidc.ES256,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ts := httptest.NewServer(s)
|
||||||
|
s.SetIssuer(ts.URL)
|
||||||
|
return &testServer{ts}
|
||||||
|
}
|
||||||
|
|
||||||
type testServer struct {
|
type testServer struct {
|
||||||
*httptest.Server
|
*httptest.Server
|
||||||
}
|
}
|
||||||
|
14
internal/auth/auth.go
Normal file
14
internal/auth/auth.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuth2ConfigInterface interface {
|
||||||
|
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||||
|
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||||
|
Client(ctx context.Context, t *oauth2.Token) *http.Client
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.32bit.cafe/32bitcafe/guestbook/internal/models"
|
"git.32bit.cafe/32bitcafe/guestbook/internal/models"
|
||||||
@ -76,13 +77,6 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
|
|||||||
return 0, models.ErrInvalidCredentials
|
return 0, models.ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UserModel) AuthenticateByOIDC(email, subject string) (int64, error) {
|
|
||||||
if email == "test@example.com" {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
return 0, models.ErrInvalidCredentials
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *UserModel) Exists(id int64) (bool, error) {
|
func (m *UserModel) Exists(id int64) (bool, error) {
|
||||||
switch id {
|
switch id {
|
||||||
case 1:
|
case 1:
|
||||||
@ -103,3 +97,24 @@ func (m *UserModel) UpdateUserSettings(userId int64, settings models.UserSetting
|
|||||||
func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error {
|
func (m *UserModel) UpdateSetting(userId int64, setting models.Setting, value string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *UserModel) GetBySubject(subject string) (int64, error) {
|
||||||
|
if subject == "goodSubject" {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
return -1, models.ErrNoRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UserModel) GetByEmail(email string) (int64, error) {
|
||||||
|
if email == "test@example.com" {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
return -1, models.ErrNoRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UserModel) UpdateSubject(userId int64, subject string) error {
|
||||||
|
if userId == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("invalid")
|
||||||
|
}
|
@ -49,12 +49,14 @@ type UserModelInterface interface {
|
|||||||
InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
|
InsertWithoutPassword(shortId uint64, username string, email string, subject string, settings UserSettings) (int64, error)
|
||||||
Get(shortId uint64) (User, error)
|
Get(shortId uint64) (User, error)
|
||||||
GetById(id int64) (User, error)
|
GetById(id int64) (User, error)
|
||||||
|
GetByEmail(email string) (int64, error)
|
||||||
|
GetBySubject(subject string) (int64, error)
|
||||||
GetAll() ([]User, error)
|
GetAll() ([]User, error)
|
||||||
Authenticate(email, password string) (int64, error)
|
Authenticate(email, password string) (int64, error)
|
||||||
AuthenticateByOIDC(email, subject string) (int64, error)
|
|
||||||
Exists(id int64) (bool, error)
|
Exists(id int64) (bool, error)
|
||||||
UpdateUserSettings(userId int64, settings UserSettings) error
|
UpdateUserSettings(userId int64, settings UserSettings) error
|
||||||
UpdateSetting(userId int64, setting Setting, value string) error
|
UpdateSetting(userId int64, setting Setting, value string) error
|
||||||
|
UpdateSubject(userId int64, subject string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UserModel) InitializeSettingsMap() error {
|
func (m *UserModel) InitializeSettingsMap() error {
|
||||||
@ -292,51 +294,44 @@ func (m *UserModel) Authenticate(email, password string) (int64, error) {
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UserModel) AuthenticateByOIDC(email string, subject string) (int64, error) {
|
func (m *UserModel) GetByEmail(email string) (int64, error) {
|
||||||
var id int64
|
var id int64
|
||||||
var s sql.NullString
|
stmt := `SELECT Id FROM users WHERE Email = ?`
|
||||||
tx, err := m.DB.Begin()
|
err := m.DB.QueryRow(stmt, email).Scan(&id)
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
stmt := `SELECT Id, OIDCSubject FROM users WHERE Email = ?`
|
|
||||||
err = tx.QueryRow(stmt, email, subject).Scan(&id, &s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return -1, ErrNoRecord
|
return -1, ErrNoRecord
|
||||||
} else {
|
} else {
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.Valid {
|
|
||||||
stmt = `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
|
|
||||||
_, err = tx.Exec(stmt, subject, id)
|
|
||||||
if err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
} else if subject != s.String {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return -1, ErrInvalidCredentials
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
}
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *UserModel) GetBySubject(subject string) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
var s sql.NullString
|
||||||
|
stmt := `SELECT Id, OIDCSubject FROM users WHERE OIDCSubject = ?`
|
||||||
|
err := m.DB.QueryRow(stmt, subject).Scan(&id, &s)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return -1, ErrNoRecord
|
||||||
|
} else {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UserModel) UpdateSubject(userId int64, subject string) error {
|
||||||
|
stmt := `UPDATE users SET OIDCSubject = ? WHERE Id = ?`
|
||||||
|
_, err := m.DB.Exec(stmt, subject, userId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *UserModel) Exists(id int64) (bool, error) {
|
func (m *UserModel) Exists(id int64) (bool, error) {
|
||||||
var exists bool
|
var exists bool
|
||||||
stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
|
stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
|
||||||
|
Loading…
x
Reference in New Issue
Block a user