260 lines
6.7 KiB
Go
260 lines
6.7 KiB
Go
package models
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/mattn/go-sqlite3"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// UserSettings holds a user's preferences as read from, and written to, the
// user_settings table.
type UserSettings struct {
	// LocalTimezone is the user's preferred time zone. A nil pointer is
	// still safe to call String on; it reports "UTC".
	LocalTimezone *time.Location
}
|
|
|
|
// SettingUserTimezone is the settings.Description value of the time-zone
// setting. UserModel.Settings is keyed by description, so this is the key
// used to look up that setting's definition.
const SettingUserTimezone = "local_timezone"
|
|
|
|
type User struct {
|
|
ID int64
|
|
ShortId uint64
|
|
Username string
|
|
Email string
|
|
Deleted bool
|
|
IsBanned bool
|
|
HashedPassword []byte
|
|
Created time.Time
|
|
Settings UserSettings
|
|
}
|
|
|
|
type UserModel struct {
|
|
DB *sql.DB
|
|
Settings map[string]Setting
|
|
}
|
|
|
|
func (m *UserModel) InitializeSettingsMap() error {
|
|
if m.Settings == nil {
|
|
m.Settings = make(map[string]Setting)
|
|
}
|
|
stmt := `SELECT settings.Id, settings.Description, Constrained, d.Id, d.Description, g.Id, g.Description, MinValue, MaxValue
|
|
FROM settings
|
|
LEFT JOIN setting_data_types d ON settings.DataType = d.Id
|
|
LEFT JOIN setting_groups g ON settings.SettingGroup = g.Id
|
|
WHERE SettingGroup = (SELECT Id FROM setting_groups WHERE Description = 'user' LIMIT 1)`
|
|
result, err := m.DB.Query(stmt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for result.Next() {
|
|
var s Setting
|
|
var mn sql.NullString
|
|
var mx sql.NullString
|
|
err := result.Scan(&s.id, &s.description, &s.constrained, &s.dataType.id, &s.dataType.description, &s.settingGroup.id, &s.settingGroup.description, &mn, &mx)
|
|
if mn.Valid {
|
|
s.minValue = mn.String
|
|
}
|
|
if mx.Valid {
|
|
s.maxValue = mx.String
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.Settings[s.description] = s
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *UserModel) Insert(shortId uint64, username string, email string, password string, settings UserSettings) error {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stmt := `INSERT INTO users (ShortId, Username, Email, IsBanned, HashedPassword, Created)
|
|
VALUES (?, ?, ?, FALSE, ?, ?)`
|
|
result, err := m.DB.Exec(stmt, shortId, username, email, hashedPassword, time.Now().UTC())
|
|
if err != nil {
|
|
if sqliteError, ok := err.(sqlite3.Error); ok {
|
|
if sqliteError.ExtendedCode == 2067 && strings.Contains(sqliteError.Error(), "Email") {
|
|
return ErrDuplicateEmail
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = m.initializeUserSettings(id, settings)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *UserModel) Get(id uint64) (User, error) {
|
|
stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE ShortId = ? AND Deleted IS NULL`
|
|
row := m.DB.QueryRow(stmt, id)
|
|
var u User
|
|
err := row.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return User{}, ErrNoRecord
|
|
}
|
|
return User{}, err
|
|
}
|
|
settings, err := m.GetSettings(u.ID)
|
|
if err != nil {
|
|
return u, err
|
|
}
|
|
u.Settings = settings
|
|
return u, nil
|
|
}
|
|
|
|
func (m *UserModel) GetById(id int64) (User, error) {
|
|
stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE Id = ? AND Deleted IS NULL`
|
|
row := m.DB.QueryRow(stmt, id)
|
|
var u User
|
|
err := row.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return User{}, ErrNoRecord
|
|
}
|
|
return User{}, err
|
|
}
|
|
settings, err := m.GetSettings(u.ID)
|
|
if err != nil {
|
|
return u, err
|
|
}
|
|
u.Settings = settings
|
|
return u, nil
|
|
}
|
|
|
|
func (m *UserModel) GetAll() ([]User, error) {
|
|
stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE DELETED IS NULL`
|
|
rows, err := m.DB.Query(stmt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var users []User
|
|
for rows.Next() {
|
|
var u User
|
|
err = rows.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, u)
|
|
}
|
|
if err = rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (m *UserModel) Authenticate(email, password string) (int64, error) {
|
|
var id int64
|
|
var hashedPassword []byte
|
|
|
|
stmt := `SELECT Id, HashedPassword FROM users WHERE Email = ?`
|
|
err := m.DB.QueryRow(stmt, email).Scan(&id, &hashedPassword)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return 0, ErrInvalidCredentials
|
|
} else {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(password))
|
|
if err != nil {
|
|
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
|
return 0, ErrInvalidCredentials
|
|
} else {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return id, nil
|
|
}
|
|
|
|
func (m *UserModel) Exists(id int64) (bool, error) {
|
|
var exists bool
|
|
stmt := `SELECT EXISTS(SELECT true FROM users WHERE Id = ? AND DELETED IS NULL)`
|
|
err := m.DB.QueryRow(stmt, id).Scan(&exists)
|
|
return exists, err
|
|
}
|
|
|
|
func (m *UserModel) GetSettings(userId int64) (UserSettings, error) {
|
|
stmt := `SELECT u.SettingId, a.ItemValue, u.UnconstrainedValue FROM user_settings AS u
|
|
LEFT JOIN allowed_setting_values AS a ON u.AllowedSettingValueId = a.Id
|
|
WHERE UserId = ?`
|
|
var settings UserSettings
|
|
rows, err := m.DB.Query(stmt, userId)
|
|
if err != nil {
|
|
return settings, err
|
|
}
|
|
for rows.Next() {
|
|
var id int
|
|
var itemValue sql.NullString
|
|
var unconstrainedValue sql.NullString
|
|
err = rows.Scan(&id, &itemValue, &unconstrainedValue)
|
|
if err != nil {
|
|
return settings, err
|
|
}
|
|
switch id {
|
|
case m.Settings[SettingUserTimezone].id:
|
|
settings.LocalTimezone, err = time.LoadLocation(unconstrainedValue.String)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
return settings, err
|
|
}
|
|
|
|
func (m *UserModel) initializeUserSettings(userId int64, settings UserSettings) error {
|
|
stmt := `INSERT INTO user_settings (UserId, SettingId, AllowedSettingValueId, UnconstrainedValue)
|
|
VALUES (?, ?, ?, ?)`
|
|
_, err := m.DB.Exec(stmt, userId, m.Settings[SettingUserTimezone].id, nil, settings.LocalTimezone.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *UserModel) UpdateUserSettings(userId int64, settings UserSettings) error {
|
|
err := m.UpdateSetting(userId, m.Settings[SettingUserTimezone], settings.LocalTimezone.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *UserModel) UpdateSetting(userId int64, setting Setting, value string) error {
|
|
valid := setting.Validate(value)
|
|
if !valid {
|
|
return ErrInvalidSettingValue
|
|
}
|
|
stmt := `UPDATE user_settings SET
|
|
AllowedSettingValueId=IFNULL(
|
|
(SELECT Id FROM allowed_setting_values WHERE SettingId = user_settings.SettingId AND ItemValue = ?), AllowedSettingValueId
|
|
),
|
|
UnconstrainedValue=(SELECT ? FROM settings WHERE settings.Id = user_settings.SettingId AND settings.Constrained=0)
|
|
WHERE userId = ?
|
|
AND SettingId = (SELECT Id from Settings WHERE Description=?);`
|
|
result, err := m.DB.Exec(stmt, value, value, userId, setting.description)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows != 1 {
|
|
return ErrInvalidSettingValue
|
|
}
|
|
return nil
|
|
}
|