98 lines
2.3 KiB
Go
98 lines
2.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"html"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.32bit.cafe/32bitcafe/guestbook/internal/models/mocks"
|
|
"github.com/alexedwards/scs/v2"
|
|
"github.com/gorilla/schema"
|
|
)
|
|
|
|
func newTestApplication(t *testing.T) *application {
|
|
formDecoder := schema.NewDecoder()
|
|
formDecoder.IgnoreUnknownKeys(true)
|
|
|
|
sessionManager := scs.New()
|
|
sessionManager.Lifetime = 12 * time.Hour
|
|
sessionManager.Cookie.Secure = true
|
|
|
|
return &application{
|
|
logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
|
sessionManager: sessionManager,
|
|
websites: &mocks.WebsiteModel{},
|
|
guestbooks: &mocks.GuestbookModel{},
|
|
users: &mocks.UserModel{},
|
|
guestbookComments: &mocks.GuestbookCommentModel{},
|
|
formDecoder: formDecoder,
|
|
timezones: getAvailableTimezones(),
|
|
}
|
|
}
|
|
|
|
// testServer wraps an httptest.Server with request helpers for end-to-end
// handler tests. Fields and methods of the embedded server (URL, Client,
// Close, ...) are promoted onto testServer.
type testServer struct {
	*httptest.Server
}

// newTestServer starts a TLS test server for h and configures its client to
// keep cookies between requests and to hand redirect responses back to the
// caller instead of following them.
//
// The server is closed automatically via t.Cleanup when the test (or
// subtest) finishes. Callers that additionally `defer ts.Close()` remain
// correct because httptest.Server.Close is safe to call more than once.
func newTestServer(t *testing.T, h http.Handler) *testServer {
	t.Helper()

	ts := httptest.NewTLSServer(h)
	// Register cleanup before anything below can call t.Fatal, so the
	// server is never leaked on an early failure.
	t.Cleanup(ts.Close)

	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatalf("creating cookie jar: %v", err)
	}

	// Client returns the server's shared client, so these settings apply
	// to every later ts.Client() call.
	client := ts.Client()
	client.Jar = jar
	// Return 3xx responses to the test rather than following them.
	client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	}

	return &testServer{ts}
}

// get sends a GET request for urlPath to the test server and returns the
// response status code, headers, and body with surrounding whitespace
// trimmed. A transport or read error fails the test immediately.
func (ts *testServer) get(t *testing.T, urlPath string) (int, http.Header, string) {
	t.Helper()

	rs, err := ts.Client().Get(ts.URL + urlPath)
	if err != nil {
		t.Fatalf("GET %s: %v", urlPath, err)
	}
	return ts.readResponse(t, rs)
}

// postForm sends form as an application/x-www-form-urlencoded POST to
// urlPath and returns the response status code, headers, and body with
// surrounding whitespace trimmed. A transport or read error fails the test
// immediately.
func (ts *testServer) postForm(t *testing.T, urlPath string, form url.Values) (int, http.Header, string) {
	t.Helper()

	rs, err := ts.Client().PostForm(ts.URL+urlPath, form)
	if err != nil {
		t.Fatalf("POST %s: %v", urlPath, err)
	}
	return ts.readResponse(t, rs)
}

// readResponse reads and closes rs.Body, returning the status code, headers,
// and body with leading and trailing whitespace removed.
func (ts *testServer) readResponse(t *testing.T, rs *http.Response) (int, http.Header, string) {
	t.Helper()
	defer rs.Body.Close()

	body, err := io.ReadAll(rs.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	return rs.StatusCode, rs.Header, string(bytes.TrimSpace(body))
}
|
|
|
|
// csrfTokenRX captures the value attribute of the hidden csrf_token input in
// a rendered form. It expects the attributes in the exact order
// type, name, value, with the tag closed by `">` (not `"/>`).
var csrfTokenRX = regexp.MustCompile(`<input type="hidden" name="csrf_token" value="(.+?)">`)

// extractCSRFToken returns the CSRF token embedded in an HTML page body and
// fails the test if no token input is present. The captured value is
// HTML-unescaped, presumably because the rendered page entity-escapes
// characters such as '+' (e.g. as &#43;) that the token may contain.
func extractCSRFToken(t *testing.T, body string) string {
	t.Helper()

	m := csrfTokenRX.FindStringSubmatch(body)
	if len(m) < 2 {
		t.Fatal("no csrf token found in body")
	}
	return html.UnescapeString(m[1])
}
|