260 lines
6.7 KiB
Go

package models
import (
"database/sql"
"errors"
"strings"
"time"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
)
// UserSettings holds the per-user preferences carried on a User.
type UserSettings struct {
// LocalTimezone is the user's time zone. GetSettings fills it from the
// stored value via time.LoadLocation.
LocalTimezone *time.Location
}
const (
// SettingUserTimezone is the key used to look up the time-zone setting in
// UserModel.Settings, which is keyed by setting description.
SettingUserTimezone = "local_timezone"
)
// User is a single row of the users table together with the user's settings.
//
// Get, GetById and GetAll in this file select only Id, ShortId, Username,
// Email and Created; Deleted, IsBanned and HashedPassword are left at their
// zero values by those methods.
type User struct {
// ID is the value of the users.Id column.
ID int64
// ShortId is the value of the users.ShortId column; Get looks users up by it.
ShortId uint64
Username string
Email string
Deleted bool
IsBanned bool
// HashedPassword is the bcrypt hash written by Insert.
HashedPassword []byte
// Created is the creation timestamp; Insert stores it in UTC.
Created time.Time
// Settings is populated by Get and GetById via GetSettings (not by GetAll).
Settings UserSettings
}
// UserModel provides database access for users and their settings.
type UserModel struct {
// DB is the connection pool every query in this model runs against.
DB *sql.DB
// Settings caches setting definitions keyed by description. It is filled by
// InitializeSettingsMap and read when loading or writing user settings.
Settings map[string]Setting
}
// InitializeSettingsMap loads the definition of every setting in the 'user'
// setting group and caches it in m.Settings, keyed by setting description.
//
// The map is allocated on first use. Entries already present are kept unless
// a loaded setting has the same description, in which case it is overwritten.
// It returns the first query, scan, or row-iteration error encountered;
// settings stored before the failure remain in the map.
func (m *UserModel) InitializeSettingsMap() error {
	if m.Settings == nil {
		m.Settings = make(map[string]Setting)
	}

	stmt := `SELECT settings.Id, settings.Description, Constrained, d.Id, d.Description, g.Id, g.Description, MinValue, MaxValue
FROM settings
LEFT JOIN setting_data_types d ON settings.DataType = d.Id
LEFT JOIN setting_groups g ON settings.SettingGroup = g.Id
WHERE SettingGroup = (SELECT Id FROM setting_groups WHERE Description = 'user' LIMIT 1)`

	rows, err := m.DB.Query(stmt)
	if err != nil {
		return err
	}
	// Release the underlying connection even when the loop exits early.
	defer rows.Close()

	for rows.Next() {
		var s Setting
		var minValue, maxValue sql.NullString
		err := rows.Scan(&s.id, &s.description, &s.constrained, &s.dataType.id, &s.dataType.description, &s.settingGroup.id, &s.settingGroup.description, &minValue, &maxValue)
		// Check the scan result before touching any scanned value.
		if err != nil {
			return err
		}
		// MinValue/MaxValue may be NULL; keep the zero value in that case.
		if minValue.Valid {
			s.minValue = minValue.String
		}
		if maxValue.Valid {
			s.maxValue = maxValue.String
		}
		m.Settings[s.description] = s
	}
	// Next reports false both at end of data and on failure; Err tells them apart.
	return rows.Err()
}
// Insert creates a new user row and then that user's initial settings.
//
// The password is hashed with bcrypt (cost 12) before it is stored, and the
// creation time is recorded in UTC. ErrDuplicateEmail is returned when SQLite
// reports a UNIQUE-constraint violation whose message contains "Email"; every
// other failure is returned unchanged.
//
// NOTE(review): the user row and the settings row are written by two separate
// statements on m.DB, not inside a transaction, so a failure while writing
// the settings leaves the user row in place.
func (m *UserModel) Insert(shortId uint64, username string, email string, password string, settings UserSettings) error {
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 12)
	if err != nil {
		return err
	}

	stmt := `INSERT INTO users (ShortId, Username, Email, IsBanned, HashedPassword, Created)
VALUES (?, ?, ?, FALSE, ?, ?)`

	result, err := m.DB.Exec(stmt, shortId, username, email, hashedPassword, time.Now().UTC())
	if err != nil {
		// errors.As also matches a driver error that has been wrapped, which a
		// plain type assertion would miss.
		var sqliteErr sqlite3.Error
		if errors.As(err, &sqliteErr) &&
			sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique &&
			strings.Contains(sqliteErr.Error(), "Email") {
			return ErrDuplicateEmail
		}
		return err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return err
	}
	return m.initializeUserSettings(id, settings)
}
// Get looks up a non-deleted user by ShortId and loads that user's settings.
//
// ErrNoRecord is returned when no row matches. When the settings lookup
// fails, the user read so far is returned together with the error.
func (m *UserModel) Get(id uint64) (User, error) {
	const query = `SELECT Id, ShortId, Username, Email, Created FROM users WHERE ShortId = ? AND Deleted IS NULL`

	var user User
	scanErr := m.DB.QueryRow(query, id).Scan(&user.ID, &user.ShortId, &user.Username, &user.Email, &user.Created)
	switch {
	case scanErr == nil:
		// Row found; continue with the settings lookup below.
	case errors.Is(scanErr, sql.ErrNoRows):
		return User{}, ErrNoRecord
	default:
		return User{}, scanErr
	}

	prefs, err := m.GetSettings(user.ID)
	if err != nil {
		return user, err
	}
	user.Settings = prefs
	return user, nil
}
// GetById looks up a non-deleted user by the Id column and loads that user's
// settings.
//
// ErrNoRecord is returned when no row matches. When the settings lookup
// fails, the user read so far is returned together with the error.
func (m *UserModel) GetById(id int64) (User, error) {
	const query = `SELECT Id, ShortId, Username, Email, Created FROM users WHERE Id = ? AND Deleted IS NULL`

	var user User
	if scanErr := m.DB.QueryRow(query, id).Scan(&user.ID, &user.ShortId, &user.Username, &user.Email, &user.Created); scanErr != nil {
		// A missing row is reported with the package's own sentinel error.
		if errors.Is(scanErr, sql.ErrNoRows) {
			return User{}, ErrNoRecord
		}
		return User{}, scanErr
	}

	prefs, err := m.GetSettings(user.ID)
	if err != nil {
		return user, err
	}
	user.Settings = prefs
	return user, nil
}
// GetAll returns every non-deleted user.
//
// Only the columns Id, ShortId, Username, Email and Created are loaded; the
// Settings field of each returned User is left at its zero value. The result
// is a nil slice when there are no matching rows. On any query, scan, or
// iteration error it returns nil and that error.
func (m *UserModel) GetAll() ([]User, error) {
	stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE DELETED IS NULL`

	rows, err := m.DB.Query(stmt)
	if err != nil {
		return nil, err
	}
	// Without this, a Scan failure below would return with the result set
	// (and its connection) still open.
	defer rows.Close()

	var users []User
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created); err != nil {
			return nil, err
		}
		users = append(users, u)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}
// Authenticate checks an email/password pair and returns the matching user's
// Id on success.
//
// ErrInvalidCredentials is returned both when no user has that email and when
// the password does not match the stored bcrypt hash. Any other database or
// bcrypt error is returned unchanged, with an Id of 0.
//
// NOTE(review): the lookup does not filter on Deleted or IsBanned; confirm
// that callers enforce those checks separately.
func (m *UserModel) Authenticate(email, password string) (int64, error) {
	const query = `SELECT Id, HashedPassword FROM users WHERE Email = ?`

	var (
		userID     int64
		storedHash []byte
	)
	if err := m.DB.QueryRow(query, email).Scan(&userID, &storedHash); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return 0, ErrInvalidCredentials
		}
		return 0, err
	}

	if err := bcrypt.CompareHashAndPassword(storedHash, []byte(password)); err != nil {
		if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
			return 0, ErrInvalidCredentials
		}
		return 0, err
	}
	return userID, nil
}
// Exists reports whether a non-deleted user with the given Id is present.
// The boolean is only meaningful when the returned error is nil.
func (m *UserModel) Exists(id int64) (bool, error) {
	const query = `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`

	var found bool
	err := m.DB.QueryRow(query, id).Scan(&found)
	return found, err
}
// GetSettings loads the stored settings of the user with the given Id.
//
// Each user_settings row is matched against the known setting ids in
// m.Settings; rows with an unrecognised id are ignored. For the time-zone
// setting the unconstrained value is resolved with time.LoadLocation. A NULL
// value is read as "", which LoadLocation resolves to UTC.
//
// It returns the settings gathered so far together with the first query,
// scan, time-zone lookup, or row-iteration error. A time-zone name that
// cannot be loaded is reported as an error rather than a panic.
func (m *UserModel) GetSettings(userId int64) (UserSettings, error) {
	stmt := `SELECT u.SettingId, a.ItemValue, u.UnconstrainedValue FROM user_settings AS u
LEFT JOIN allowed_setting_values AS a ON u.AllowedSettingValueId = a.Id
WHERE UserId = ?`

	var settings UserSettings
	rows, err := m.DB.Query(stmt, userId)
	if err != nil {
		return settings, err
	}
	// Release the underlying connection on every exit path.
	defer rows.Close()

	for rows.Next() {
		var id int
		var itemValue, unconstrainedValue sql.NullString
		if err := rows.Scan(&id, &itemValue, &unconstrainedValue); err != nil {
			return settings, err
		}
		switch id {
		case m.Settings[SettingUserTimezone].id:
			// Bad stored data is a runtime condition, not a programmer bug, so
			// hand it back to the caller.
			loc, err := time.LoadLocation(unconstrainedValue.String)
			if err != nil {
				return settings, err
			}
			settings.LocalTimezone = loc
		}
	}
	// Next reports false both at end of data and on failure; Err tells them apart.
	return settings, rows.Err()
}
// initializeUserSettings writes the initial user_settings row for a user: the
// time-zone setting, stored as an unconstrained value (the name of
// settings.LocalTimezone) with AllowedSettingValueId left NULL.
// It returns the error from the INSERT, if any.
func (m *UserModel) initializeUserSettings(userId int64, settings UserSettings) error {
	const query = `INSERT INTO user_settings (UserId, SettingId, AllowedSettingValueId, UnconstrainedValue)
VALUES (?, ?, ?, ?)`

	timezoneSetting := m.Settings[SettingUserTimezone]
	_, err := m.DB.Exec(query, userId, timezoneSetting.id, nil, settings.LocalTimezone.String())
	return err
}
// UpdateUserSettings persists the given settings for a user. At present that
// is the time zone only, written through UpdateSetting using the name of
// settings.LocalTimezone. Any error from UpdateSetting is returned as is.
func (m *UserModel) UpdateUserSettings(userId int64, settings UserSettings) error {
	timezoneSetting := m.Settings[SettingUserTimezone]
	timezoneName := settings.LocalTimezone.String()
	return m.UpdateSetting(userId, timezoneSetting, timezoneName)
}
// UpdateSetting stores a new value for one setting of one user.
//
// The value is first checked with setting.Validate; a rejected value yields
// ErrInvalidSettingValue without touching the database. The UPDATE then:
//   - sets AllowedSettingValueId to the allowed value whose ItemValue equals
//     value for this setting, or keeps the current id when there is none;
//   - sets UnconstrainedValue to value when the setting has Constrained=0
//     (the subquery yields no row, i.e. NULL, otherwise).
//
// ErrInvalidSettingValue is also returned when the statement does not affect
// exactly one row. Database errors are returned unchanged.
func (m *UserModel) UpdateSetting(userId int64, setting Setting, value string) error {
	if !setting.Validate(value) {
		return ErrInvalidSettingValue
	}

	const query = `UPDATE user_settings SET
AllowedSettingValueId=IFNULL(
(SELECT Id FROM allowed_setting_values WHERE SettingId = user_settings.SettingId AND ItemValue = ?), AllowedSettingValueId
),
UnconstrainedValue=(SELECT ? FROM settings WHERE settings.Id = user_settings.SettingId AND settings.Constrained=0)
WHERE userId = ?
AND SettingId = (SELECT Id from Settings WHERE Description=?);`

	res, err := m.DB.Exec(query, value, value, userId, setting.description)
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected != 1:
		// Exactly one user_settings row is expected to match.
		return ErrInvalidSettingValue
	}
	return nil
}