diff --git a/internal/data/user_token.go b/internal/data/user_token.go index 083c8cce..4272c600 100644 --- a/internal/data/user_token.go +++ b/internal/data/user_token.go @@ -42,9 +42,10 @@ func (r userTokenRepo) List(userID, page, limit uint) ([]*biz.UserToken, int64, } func (r userTokenRepo) Create(userID uint, ips []string, expired time.Time) (*biz.UserToken, error) { + token := str.Random(32) userToken := &biz.UserToken{ UserID: userID, - Token: str.Random(32), + Token: token, IPs: ips, ExpiredAt: expired, } @@ -52,6 +53,8 @@ func (r userTokenRepo) Create(userID uint, ips []string, expired time.Time) (*bi return nil, err } + userToken.Token = token // 返回的值是加密的,这里覆盖为原始值 + return userToken, nil } @@ -98,11 +101,17 @@ func (r userTokenRepo) ValidateReq(req *http.Request) (uint, error) { return 0, errors.New(r.t.Get("invalid header: %v", err)) } if algorithm != "HMAC-SHA256" { - return 0, errors.New(r.t.Get("invalid algorithm, must be HMAC-SHA256")) + return 0, errors.New(r.t.Get("invalid signature")) } // 获取用户令牌 - userToken, _ := r.Get(id) // 不应报错,防止猜测令牌ID + userToken, err := r.Get(id) + if err != nil { + return 0, errors.New(r.t.Get("invalid signature")) // 不应返回原始报错,防止猜测令牌ID + } + if userToken.ExpiredAt.Before(time.Now()) { + return 0, errors.New(r.t.Get("token expired")) + } // 步骤一:构造规范化请求 body, err := io.ReadAll(req.Body) @@ -121,12 +130,12 @@ func (r userTokenRepo) ValidateReq(req *http.Request) (uint, error) { // 步骤四:验证签名 if subtle.ConstantTimeCompare([]byte(signature), []byte(validSignature)) != 1 { - return 0, errors.New(r.t.Get("invalid api signature")) + return 0, errors.New(r.t.Get("invalid signature")) } // 步骤五:验证时间戳 if timestamp == 0 || timestamp < (time.Now().Unix()-300) { - return 0, errors.New(r.t.Get("api signature expired")) + return 0, errors.New(r.t.Get("signature expired")) } // 步骤六:验证IP diff --git a/internal/http/middleware/entrance.go b/internal/http/middleware/entrance.go index 3c9cebac..f7f909a2 100644 --- a/internal/http/middleware/entrance.go +++ b/internal/http/middleware/entrance.go @@ -1,6 +1,7 @@ package middleware import ( + "github.com/go-chi/chi/v5" "net" "net/http" "slices" @@ -30,7 +31,13 @@ func Entrance(t *gotext.Locale, conf *koanf.Koanf, session *sessions.Manager) fu return } - entrance := conf.String("http.entrance") + entrance := strings.TrimSuffix(conf.String("http.entrance"), "/") + if entrance == "" { + entrance = "/" + } + if !strings.HasPrefix(entrance, "/") { + entrance = "/" + entrance + } // 情况一:设置了绑定域名、IP、UA,且请求不符合要求,返回错误 host, _, err := net.SplitHostPort(r.Host) @@ -75,7 +82,7 @@ func Entrance(t *gotext.Locale, conf *koanf.Koanf, session *sessions.Manager) fu } // 情况二:请求路径与入口路径相同,标记通过验证并重定向到登录页面 - if strings.TrimSuffix(r.URL.Path, "/") == strings.TrimSuffix(entrance, "/") { + if strings.TrimSuffix(r.URL.Path, "/") == entrance { sess.Put("verify_entrance", true) render := chix.NewRender(w, r) defer render.Release() @@ -85,7 +92,13 @@ func Entrance(t *gotext.Locale, conf *koanf.Koanf, session *sessions.Manager) fu // 情况三:通过APIKey+入口路径访问,重写请求路径并跳过验证 if strings.HasPrefix(r.URL.Path, entrance) && r.Header.Get("Authorization") != "" { - r.URL.Path = strings.TrimPrefix(r.URL.Path, entrance) + // 只在设置了入口路径的情况下,才进行重写 + if entrance != "/" { + if rctx := chi.RouteContext(r.Context()); rctx != nil { + rctx.RoutePath = strings.TrimPrefix(rctx.RoutePath, entrance) + } + r.URL.Path = strings.TrimPrefix(r.URL.Path, entrance) + } next.ServeHTTP(w, r) return }