mirror of
https://github.com/acepanel/panel.git
synced 2026-02-04 06:47:20 +08:00
182 lines
3.7 KiB
Go
182 lines
3.7 KiB
Go
package data
|
||
|
||
import (
|
||
"errors"
|
||
"image"
|
||
|
||
"github.com/go-rat/utils/hash"
|
||
"github.com/leonelquinteros/gotext"
|
||
"github.com/pquerna/otp"
|
||
"github.com/pquerna/otp/totp"
|
||
"github.com/spf13/cast"
|
||
"gorm.io/gorm"
|
||
|
||
"github.com/tnb-labs/panel/internal/biz"
|
||
)
|
||
|
||
type userRepo struct {
|
||
t *gotext.Locale
|
||
db *gorm.DB
|
||
hasher hash.Hasher
|
||
}
|
||
|
||
func NewUserRepo(t *gotext.Locale, db *gorm.DB) biz.UserRepo {
|
||
return &userRepo{
|
||
t: t,
|
||
db: db,
|
||
hasher: hash.NewArgon2id(),
|
||
}
|
||
}
|
||
|
||
func (r *userRepo) List(page, limit uint) ([]*biz.User, int64, error) {
|
||
users := make([]*biz.User, 0)
|
||
var total int64
|
||
err := r.db.Model(&biz.User{}).Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&users).Error
|
||
return users, total, err
|
||
}
|
||
|
||
func (r *userRepo) Get(id uint) (*biz.User, error) {
|
||
user := new(biz.User)
|
||
if err := r.db.First(user, id).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
func (r *userRepo) Create(username, password string) (*biz.User, error) {
|
||
value, err := r.hasher.Make(password)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
user := &biz.User{
|
||
Username: username,
|
||
Password: value,
|
||
}
|
||
if err = r.db.Create(user).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
func (r *userRepo) UpdatePassword(id uint, password string) error {
|
||
value, err := r.hasher.Make(password)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
user, err := r.Get(id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
user.Password = value
|
||
return r.db.Save(user).Error
|
||
}
|
||
|
||
func (r *userRepo) UpdateEmail(id uint, email string) error {
|
||
user, err := r.Get(id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
user.Email = email
|
||
return r.db.Save(user).Error
|
||
}
|
||
|
||
func (r *userRepo) Delete(id uint) error {
|
||
if id == 1 {
|
||
return errors.New(r.t.Get("please don't do this"))
|
||
}
|
||
|
||
user := new(biz.User)
|
||
if err := r.db.First(user, id).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return r.db.Delete(user).Error
|
||
}
|
||
|
||
func (r *userRepo) CheckPassword(username, password string) (*biz.User, error) {
|
||
user := new(biz.User)
|
||
if err := r.db.Where("username = ?", username).First(user).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New(r.t.Get("username or password error"))
|
||
} else {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
if !r.hasher.Check(password, user.Password) {
|
||
return nil, errors.New(r.t.Get("username or password error"))
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
func (r *userRepo) IsTwoFA(username string) (bool, error) {
|
||
user := new(biz.User)
|
||
if err := r.db.Where("username = ?", username).First(user).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return false, errors.New(r.t.Get("username or password error"))
|
||
} else {
|
||
return false, err
|
||
}
|
||
}
|
||
|
||
return user.TwoFA != "", nil
|
||
}
|
||
|
||
func (r *userRepo) GenerateTwoFA(id uint) (image.Image, string, string, error) {
|
||
key, err := totp.Generate(totp.GenerateOpts{
|
||
Issuer: "RatPanel",
|
||
AccountName: cast.ToString(id),
|
||
SecretSize: 32,
|
||
Algorithm: otp.AlgorithmSHA256,
|
||
})
|
||
if err != nil {
|
||
return nil, "", "", err
|
||
}
|
||
|
||
img, err := key.Image(200, 200)
|
||
if err != nil {
|
||
return nil, "", "", err
|
||
}
|
||
|
||
return img, key.URL(), key.Secret(), nil
|
||
}
|
||
|
||
func (r *userRepo) UpdateTwoFA(id uint, code, secret string) error {
|
||
user, err := r.Get(id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 保存前先验证一次,防止错误开启
|
||
if secret != "" && !totp.Validate(code, secret) {
|
||
return errors.New(r.t.Get("invalid 2fa code"))
|
||
}
|
||
|
||
user.TwoFA = secret
|
||
return r.db.Save(user).Error
|
||
}
|
||
|
||
func (r *userRepo) CheckTwoFA(id uint, code string) (bool, error) {
|
||
user, err := r.Get(id)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
if user.TwoFA == "" {
|
||
return true, nil // 未开启2FA,无需验证
|
||
}
|
||
|
||
if !totp.Validate(code, user.TwoFA) {
|
||
return false, errors.New(r.t.Get("invalid 2fa code"))
|
||
}
|
||
|
||
return true, nil
|
||
}
|