diff --git a/internal/models/user.go b/internal/models/user.go index 318202d..acd9178 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -74,8 +74,15 @@ func (m *UserModel) Insert(shortId uint64, username string, email string, passwo } stmt := `INSERT INTO users (ShortId, Username, Email, IsBanned, HashedPassword, Created) VALUES (?, ?, ?, FALSE, ?, ?)` - result, err := m.DB.Exec(stmt, shortId, username, email, hashedPassword, time.Now().UTC()) + tx, err := m.DB.Begin() if err != nil { + return err + } + result, err := tx.Exec(stmt, shortId, username, email, hashedPassword, time.Now().UTC()) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return err + } if sqliteError, ok := err.(sqlite3.Error); ok { if sqliteError.ExtendedCode == 2067 && strings.Contains(sqliteError.Error(), "Email") { return ErrDuplicateEmail @@ -85,50 +92,91 @@ func (m *UserModel) Insert(shortId uint64, username string, email string, passwo } id, err := result.LastInsertId() if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return err + } return err } - err = m.initializeUserSettings(id, settings) + stmt = `INSERT INTO user_settings (UserId, SettingId, AllowedSettingValueId, UnconstrainedValue) + VALUES (?, ?, ?, ?)` + _, err = tx.Exec(stmt, id, m.Settings[SettingUserTimezone].id, nil, settings.LocalTimezone.String()) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return err + } + return err + } + err = tx.Commit() if err != nil { return err } + return nil } -func (m *UserModel) Get(id uint64) (User, error) { +func (m *UserModel) Get(shortId uint64) (User, error) { stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE ShortId = ? AND Deleted IS NULL` - row := m.DB.QueryRow(stmt, id) - var u User - err := row.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created) + tx, err := m.DB.Begin() if err != nil { + return User{}, err + } + row := tx.QueryRow(stmt, shortId) + var u User + err = row.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return User{}, err + } if errors.Is(err, sql.ErrNoRows) { return User{}, ErrNoRecord } return User{}, err } - settings, err := m.GetSettings(u.ID) + settings, err := m.getSettings(tx, u.ID) if err != nil { - return u, err + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return User{}, err + } + return User{}, err } u.Settings = settings + err = tx.Commit() + if err != nil { + return User{}, err + } return u, nil } func (m *UserModel) GetById(id int64) (User, error) { stmt := `SELECT Id, ShortId, Username, Email, Created FROM users WHERE Id = ? AND Deleted IS NULL` - row := m.DB.QueryRow(stmt, id) - var u User - err := row.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created) + tx, err := m.DB.Begin() if err != nil { + return User{}, err + } + row := tx.QueryRow(stmt, id) + var u User + err = row.Scan(&u.ID, &u.ShortId, &u.Username, &u.Email, &u.Created) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return User{}, err + } if errors.Is(err, sql.ErrNoRows) { return User{}, ErrNoRecord } return User{}, err } - settings, err := m.GetSettings(u.ID) + settings, err := m.getSettings(tx, u.ID) if err != nil { - return u, err + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return User{}, err + } + return User{}, err } u.Settings = settings + err = tx.Commit() + if err != nil { + return User{}, err + } return u, nil } @@ -186,12 +234,12 @@ func (m *UserModel) Exists(id int64) (bool, error) { return exists, err } -func (m *UserModel) GetSettings(userId int64) (UserSettings, error) { +func (m *UserModel) getSettings(tx *sql.Tx, userId int64) (UserSettings, error) { stmt := `SELECT u.SettingId, a.ItemValue, u.UnconstrainedValue FROM user_settings AS u LEFT JOIN allowed_setting_values AS a ON u.AllowedSettingValueId = a.Id WHERE UserId = ?` var settings UserSettings - rows, err := m.DB.Query(stmt, userId) + rows, err := tx.Query(stmt, userId) if err != nil { return settings, err } @@ -214,16 +262,6 @@ func (m *UserModel) GetSettings(userId int64) (UserSettings, error) { return settings, err } -func (m *UserModel) initializeUserSettings(userId int64, settings UserSettings) error { - stmt := `INSERT INTO user_settings (UserId, SettingId, AllowedSettingValueId, UnconstrainedValue) - VALUES (?, ?, ?, ?)` - _, err := m.DB.Exec(stmt, userId, m.Settings[SettingUserTimezone].id, nil, settings.LocalTimezone.String()) - if err != nil { - return err - } - return nil -} - func (m *UserModel) UpdateUserSettings(userId int64, settings UserSettings) error { err := m.UpdateSetting(userId, m.Settings[SettingUserTimezone], settings.LocalTimezone.String()) if err != nil {