2
0
mirror of https://github.com/acepanel/panel.git synced 2026-02-04 09:13:49 +08:00
Files
panel/internal/http/middleware/must_login.go
2024-12-03 03:46:28 +08:00

84 lines
1.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"fmt"
"net"
"net/http"
"slices"
"strings"
"github.com/go-rat/chix"
"github.com/spf13/cast"
"golang.org/x/crypto/sha3"
"github.com/TheTNB/panel/internal/app"
)
// MustLogin is a middleware that requires a logged-in session for API requests.
//
// The session is loaded first for every request. If loading fails, it responds
// 500 with the error message and stops. This check happens before the
// whitelist check, so it applies to all paths.
//
// Requests whose path is in the whitelist below, or whose path does not start
// with "/api", are then passed to next unchanged. For all other requests it:
//   - responds 401 if the session has no "user_id", or if that value casts to 0;
//   - when the session's "safe_login" flag is true, responds 401 unless the
//     stored "safe_client" value is non-empty and equals the hex SHA3-256 of
//     "<client ip>|<User-Agent>" for the current request;
//   - otherwise stores the user ID (as uint) in the request context under the
//     string key "user_id" and calls next.
func MustLogin(next http.Handler) http.Handler {
	// Paths reachable without being logged in.
	whiteList := []string{
		"/api/user/key",
		"/api/user/login",
		"/api/user/logout",
		"/api/user/isLogin",
		"/api/dashboard/panel",
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sess, err := app.Session.GetSession(r)
		if err != nil {
			render := chix.NewRender(w)
			render.Status(http.StatusInternalServerError)
			render.JSON(chix.M{
				"message": err.Error(),
			})
			// Stop here: falling through would write a second response and
			// go on to use a session that failed to load.
			return
		}

		// Let whitelisted paths and non-API requests through.
		if slices.Contains(whiteList, r.URL.Path) || !strings.HasPrefix(r.URL.Path, "/api") {
			next.ServeHTTP(w, r)
			return
		}

		if sess.Missing("user_id") {
			render := chix.NewRender(w)
			render.Status(http.StatusUnauthorized)
			render.JSON(chix.M{
				"message": "会话已过期,请重新登录",
			})
			return
		}

		// cast.ToUint yields 0 for values it cannot convert, so 0 is treated as
		// an invalid session.
		userID := cast.ToUint(sess.Get("user_id"))
		if userID == 0 {
			render := chix.NewRender(w)
			render.Status(http.StatusUnauthorized)
			render.JSON(chix.M{
				"message": "会话无效,请重新登录",
			})
			return
		}

		// "Safe login" binds the session to the client's IP and User-Agent.
		if cast.ToBool(sess.Get("safe_login")) {
			safeClientHash := cast.ToString(sess.Get("safe_client"))
			// NOTE(review): the SplitHostPort error is ignored. If RemoteAddr
			// carries no port, ip is "" and only the User-Agent is effectively
			// bound. This derivation must match the hash stored at login time;
			// confirm with the login handler before changing it.
			ip, _, _ := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
			ua := r.Header.Get("User-Agent")
			clientHash := fmt.Sprintf("%x", sha3.Sum256([]byte(ip+"|"+ua)))
			if safeClientHash != clientHash || safeClientHash == "" {
				render := chix.NewRender(w)
				render.Status(http.StatusUnauthorized)
				render.JSON(chix.M{
					"message": "客户端IP/UA变化请重新登录",
				})
				return
			}
		}

		// A plain string key is kept deliberately: downstream handlers may read
		// the value via Context().Value("user_id").
		r = r.WithContext(context.WithValue(r.Context(), "user_id", userID)) // nolint:staticcheck
		next.ServeHTTP(w, r)
	})
}