diff --git a/cmd/web/wire_gen.go b/cmd/web/wire_gen.go index fb932fa5..e384fde5 100644 --- a/cmd/web/wire_gen.go +++ b/cmd/web/wire_gen.go @@ -70,10 +70,10 @@ func initWeb() (*app.Web, error) { queue := bootstrap.NewQueue() taskRepo := data.NewTaskRepo(locale, db, logger, queue) appRepo := data.NewAppRepo(locale, db, cacheRepo, taskRepo) - middlewares := middleware.NewMiddlewares(koanf, logger, manager, appRepo) + userTokenRepo := data.NewUserTokenRepo(locale, db) + middlewares := middleware.NewMiddlewares(koanf, logger, manager, appRepo, userTokenRepo) userRepo := data.NewUserRepo(locale, db) userService := service.NewUserService(locale, koanf, manager, userRepo) - userTokenRepo := data.NewUserTokenRepo(locale, db) userTokenService := service.NewUserTokenService(locale, userTokenRepo) databaseServerRepo := data.NewDatabaseServerRepo(locale, db, logger) databaseUserRepo := data.NewDatabaseUserRepo(locale, db, databaseServerRepo) diff --git a/go.sum b/go.sum index 3669d7aa..bb5a64a4 100644 --- a/go.sum +++ b/go.sum @@ -121,7 +121,6 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI= github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA= @@ -404,8 +403,6 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= -golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -515,8 +512,6 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= -golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= -golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/biz/user.go b/internal/biz/user.go index 0191b455..8ddb43c0 100644 --- a/internal/biz/user.go +++ b/internal/biz/user.go @@ -16,6 +16,8 @@ type User struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` + + Tokens []*UserToken `gorm:"foreignKey:UserID" json:"-"` } type UserRepo interface { diff --git a/internal/biz/user_token.go b/internal/biz/user_token.go index 7b3c7615..733e8eb5 100644 --- a/internal/biz/user_token.go +++ b/internal/biz/user_token.go @@ -1,10 +1,13 @@ package biz import ( + "net/http" "time" - "github.com/go-rat/utils/hash" + "github.com/go-rat/utils/crypt" "gorm.io/gorm" + + "github.com/tnb-labs/panel/internal/app" ) type UserToken struct { @@ -18,14 +21,30 @@ type UserToken struct { } func (r *UserToken) BeforeSave(tx *gorm.DB) error { - hasher := hash.NewArgon2id() - var err error - - r.Token, err = hasher.Make(r.Token) + crypter, err := crypt.NewXChacha20Poly1305([]byte(app.Key)) if err != nil { return err } + r.Token, err = crypter.Encrypt([]byte(r.Token)) + if err != nil { + return err + } + + return nil +} + +func (r *UserToken) AfterFind(tx *gorm.DB) error { + crypter, err := crypt.NewXChacha20Poly1305([]byte(app.Key)) + if err != nil { + return err + } + + token, err := crypter.Decrypt(r.Token) + if err == nil { + r.Token = string(token) + } + return nil } @@ -35,4 +54,5 @@ type UserTokenRepo interface { Get(id uint) (*UserToken, error) Delete(id uint) error Update(id uint, ips []string, expired time.Time) (*UserToken, error) + ValidateReq(req *http.Request) (uint, error) } diff --git a/internal/data/user.go b/internal/data/user.go index 287388bf..defe618a 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -88,16 +88,25 @@ func (r *userRepo) UpdateEmail(id uint, email string) error { } func (r *userRepo) Delete(id uint) error { - if id == 1 { + var count int64 + if err := r.db.Model(&biz.User{}).Count(&count).Error; err != nil { + return err + } + if count <= 1 { return errors.New(r.t.Get("please don't do this")) } user := new(biz.User) - if err := r.db.First(user, id).Error; err != nil { + if err := r.db.Preload("Tokens").First(user, id).Error; err != nil { return err } - return r.db.Delete(user).Error + return r.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&user).Association("Tokens").Delete(); err != nil { + return err + } + return tx.Delete(&user).Error + }) } func (r *userRepo) CheckPassword(username, password string) (*biz.User, error) { diff --git a/internal/data/user_token.go b/internal/data/user_token.go index 4fddd369..41ae2bff 100644 --- a/internal/data/user_token.go +++ b/internal/data/user_token.go @@ -1,27 +1,36 @@ package data import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "net/http" + "slices" "time" - "github.com/go-rat/utils/hash" "github.com/go-rat/utils/str" "github.com/leonelquinteros/gotext" + "github.com/spf13/cast" "gorm.io/gorm" "github.com/tnb-labs/panel/internal/biz" ) type userTokenRepo struct { - t *gotext.Locale - db *gorm.DB - hasher hash.Hasher + t *gotext.Locale + db *gorm.DB } func NewUserTokenRepo(t *gotext.Locale, db *gorm.DB) biz.UserTokenRepo { return &userTokenRepo{ - t: t, - db: db, - hasher: hash.NewArgon2id(), + t: t, + db: db, } } @@ -33,24 +42,16 @@ func (r userTokenRepo) List(userID, page, limit uint) ([]*biz.UserToken, int64, } func (r userTokenRepo) Create(userID uint, ips []string, expired time.Time) (*biz.UserToken, error) { - token := str.Random(32) - hashedToken, err := r.hasher.Make(token) - if err != nil { - return nil, err - } - userToken := &biz.UserToken{ UserID: userID, - Token: hashedToken, + Token: str.Random(32), IPs: ips, ExpiredAt: expired, } - if err = r.db.Create(userToken).Error; err != nil { + if err := r.db.Create(userToken).Error; err != nil { return nil, err } - userToken.Token = token - return userToken, nil } @@ -87,3 +88,63 @@ func (r userTokenRepo) Update(id uint, ips []string, expired time.Time) (*biz.Us return userToken, nil } + +func (r userTokenRepo) ValidateReq(req *http.Request) (uint, error) { + // Authorization: HMAC-SHA256 Credential=, Signature= + var algorithm string + var id uint + var signature string + if _, err := fmt.Sscanf(req.Header.Get("Authorization"), "%s Credential=%d, Signature=%s", &algorithm, &id, &signature); err != nil { + return 0, errors.New(r.t.Get("invalid Authorization header: %v", err)) + } + if algorithm != "HMAC-SHA256" { + return 0, errors.New(r.t.Get("invalid Authorization algorithm, must be HMAC-SHA256")) + } + + // 获取用户令牌 + userToken, _ := r.Get(id) // 不应报错,防止猜测令牌ID + + // 步骤一:构造规范化请求 + body, err := io.ReadAll(req.Body) + if err != nil { + return 0, err + } + req.Body = io.NopCloser(bytes.NewReader(body)) + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s", req.Method, req.URL.Path, req.URL.Query().Encode(), str.SHA256(string(body))) + + // 步骤二:构造待签名字符串 + timestamp := cast.ToInt64(req.Header.Get("X-Timestamp")) + stringToSign := fmt.Sprintf("%s\n%d\n%s", "HMAC-SHA256", cast.ToInt64(timestamp), str.SHA256(canonicalRequest)) + + // 步骤三:计算签名 + validSignature := r.hmacsha256(stringToSign, userToken.Token) + + // 步骤四:验证签名 + if subtle.ConstantTimeCompare([]byte(signature), []byte(validSignature)) != 1 { + return 0, errors.New(r.t.Get("invalid api signature")) + } + + // 步骤五:验证时间戳 + if timestamp == 0 || timestamp < (time.Now().Unix()-300) { + return 0, errors.New(r.t.Get("api signature expired")) + } + + // 步骤六:验证IP + if len(userToken.IPs) > 0 { + ip, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + ip = req.RemoteAddr + } + if !slices.Contains(userToken.IPs, ip) { + return 0, errors.New(r.t.Get("invalid request ip: %s", ip)) + } + } + + return userToken.UserID, nil +} + +func (r userTokenRepo) hmacsha256(data string, secret string) string { + h := hmac.New(sha256.New, []byte(secret)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/internal/http/middleware/entrance.go b/internal/http/middleware/entrance.go index f54b669a..3c9cebac 100644 --- a/internal/http/middleware/entrance.go +++ b/internal/http/middleware/entrance.go @@ -60,7 +60,7 @@ func Entrance(t *gotext.Locale, conf *koanf.Koanf, session *sessions.Manager) fu defer render.Release() render.Status(http.StatusTeapot) render.JSON(chix.M{ - "message": t.Get("invalid request ip: %s", r.RemoteAddr), + "message": t.Get("invalid request ip: %s", ip), }) return } diff --git a/internal/http/middleware/middleware.go b/internal/http/middleware/middleware.go index 241c0960..e5f20466 100644 --- a/internal/http/middleware/middleware.go +++ b/internal/http/middleware/middleware.go @@ -19,18 +19,20 @@ import ( var ProviderSet = wire.NewSet(NewMiddlewares) type Middlewares struct { - conf *koanf.Koanf - log *slog.Logger - session *sessions.Manager - app biz.AppRepo + conf *koanf.Koanf + log *slog.Logger + session *sessions.Manager + app biz.AppRepo + userToken biz.UserTokenRepo } -func NewMiddlewares(conf *koanf.Koanf, log *slog.Logger, session *sessions.Manager, app biz.AppRepo) *Middlewares { +func NewMiddlewares(conf *koanf.Koanf, log *slog.Logger, session *sessions.Manager, app biz.AppRepo, userToken biz.UserTokenRepo) *Middlewares { return &Middlewares{ - conf: conf, - log: log, - session: session, - app: app, + conf: conf, + log: log, + session: session, + app: app, + userToken: userToken, } } @@ -49,7 +51,7 @@ func (r *Middlewares) Globals(t *gotext.Locale, mux *chi.Mux) []func(http.Handle middleware.Recoverer, Status(t), Entrance(t, r.conf, r.session), - MustLogin(t, r.session), + MustLogin(t, r.session, r.userToken), MustInstall(t, r.app), } } diff --git a/internal/http/middleware/must_login.go b/internal/http/middleware/must_login.go index 46093272..e723b2c2 100644 --- a/internal/http/middleware/must_login.go +++ b/internal/http/middleware/must_login.go @@ -1,29 +1,24 @@ package middleware import ( - "bytes" "context" - "crypto/hmac" "crypto/sha256" - "crypto/subtle" - "encoding/hex" "fmt" - "io" "net" "net/http" "slices" "strings" - "time" "github.com/go-rat/chix" "github.com/go-rat/sessions" - "github.com/go-rat/utils/str" "github.com/leonelquinteros/gotext" "github.com/spf13/cast" + + "github.com/tnb-labs/panel/internal/biz" ) // MustLogin 确保已登录 -func MustLogin(t *gotext.Locale, session *sessions.Manager) func(next http.Handler) http.Handler { +func MustLogin(t *gotext.Locale, session *sessions.Manager, userToken biz.UserTokenRepo) func(next http.Handler) http.Handler { // 白名单 whiteList := []string{ "/api/user/key", @@ -54,51 +49,16 @@ func MustLogin(t *gotext.Locale, session *sessions.Manager) func(next http.Handl userID := uint(0) if r.Header.Get("Authorization") != "" { - signature := strings.TrimPrefix(r.Header.Get("Authorization"), "HMAC-SHA256 ") - - // 步骤一:构造规范化请求 - body, err := io.ReadAll(r.Body) - if err != nil { + // API 请求验证 + if userID, err = userToken.ValidateReq(r); err != nil { render := chix.NewRender(w) defer render.Release() - render.Status(http.StatusInternalServerError) + render.Status(http.StatusUnauthorized) render.JSON(chix.M{ "message": err.Error(), }) return } - r.Body = io.NopCloser(bytes.NewReader(body)) - canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s", r.Method, r.URL.Path, r.URL.Query().Encode(), str.SHA256(string(body))) - - // 步骤二:构造待签名字符串 - stringToSign := fmt.Sprintf("%s\n%d\n%s", "HMAC-SHA256", cast.ToInt64(r.Header.Get("X-Timestamp")), str.SHA256(canonicalRequest)) - - // 步骤三:计算签名 - validSignature := hmacsha256(stringToSign, cast.ToString(sess.Get("api_secret"))) - - // 步骤四:验证签名 - if subtle.ConstantTimeCompare([]byte(signature), []byte(validSignature)) != 1 { - render := chix.NewRender(w) - defer render.Release() - render.Status(http.StatusUnauthorized) - render.JSON(chix.M{ - "message": t.Get("invalid api signature"), - }) - return - } - timestamp := cast.ToInt64(r.Header.Get("X-Timestamp")) - if timestamp == 0 || timestamp < (time.Now().Unix()-60) { - render := chix.NewRender(w) - defer render.Release() - render.Status(http.StatusUnauthorized) - render.JSON(chix.M{ - "message": t.Get("api signature expired"), - }) - return - } - - // 步骤五:验证通过 - userID = 1 } else { if sess.Missing("user_id") { render := chix.NewRender(w) @@ -144,9 +104,3 @@ func MustLogin(t *gotext.Locale, session *sessions.Manager) func(next http.Handl }) } } - -func hmacsha256(data string, secret string) string { - h := hmac.New(sha256.New, []byte(secret)) - h.Write([]byte(data)) - return hex.EncodeToString(h.Sum(nil)) -}