From 03c0b191fe16d8576d46a41c48ad8b168f4abb3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Mon, 26 May 2025 00:34:24 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81cidr=EF=BC=8Cclose=20?= =?UTF-8?q?#764?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/data/user_token.go | 23 +++++++++++++++++++++-- internal/http/middleware/entrance.go | 27 ++++++++++++++++++++++++--- internal/http/request/setting.go | 2 +- internal/http/request/user_token.go | 4 ++-- internal/http/rule/ip_cidr.go | 22 ++++++++++++++++++++++ internal/http/rule/rule.go | 2 ++ 6 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 internal/http/rule/ip_cidr.go diff --git a/internal/data/user_token.go b/internal/data/user_token.go index 4272c600..42140b8f 100644 --- a/internal/data/user_token.go +++ b/internal/data/user_token.go @@ -11,7 +11,7 @@ import ( "io" "net" "net/http" - "slices" + "strings" "time" "github.com/go-rat/utils/str" @@ -144,7 +144,26 @@ func (r userTokenRepo) ValidateReq(req *http.Request) (uint, error) { if err != nil { ip = req.RemoteAddr } - if !slices.Contains(userToken.IPs, ip) { + allowed := false + requestIP := net.ParseIP(ip) + if requestIP != nil { + for _, allowedIP := range userToken.IPs { + if strings.Contains(allowedIP, "/") { + // CIDR + if _, ipNet, err := net.ParseCIDR(allowedIP); err == nil && ipNet.Contains(requestIP) { + allowed = true + break + } + } else { + // IP + if allowedIP == ip { + allowed = true + break + } + } + } + } + if !allowed { return 0, errors.New(r.t.Get("invalid request ip: %s", ip)) } } diff --git a/internal/http/middleware/entrance.go b/internal/http/middleware/entrance.go index 847e8cea..3558bdc6 100644 --- a/internal/http/middleware/entrance.go +++ b/internal/http/middleware/entrance.go @@ -47,9 +47,30 @@ func Entrance(t *gotext.Locale, conf *koanf.Koanf, session *sessions.Manager) fu if err != nil { ip = r.RemoteAddr } - if len(conf.Strings("http.bind_ip")) > 0 && !slices.Contains(conf.Strings("http.bind_ip"), ip) { - Abort(w, http.StatusTeapot, t.Get("invalid request ip: %s", ip)) - return + if len(conf.Strings("http.bind_ip")) > 0 { + allowed := false + requestIP := net.ParseIP(ip) + if requestIP != nil { + for _, allowedIP := range conf.Strings("http.bind_ip") { + if strings.Contains(allowedIP, "/") { + // CIDR + if _, ipNet, err := net.ParseCIDR(allowedIP); err == nil && ipNet.Contains(requestIP) { + allowed = true + break + } + } else { + // IP + if allowedIP == ip { + allowed = true + break + } + } + } + } + if !allowed { + Abort(w, http.StatusTeapot, t.Get("invalid request ip: %s", ip)) + return + } } if len(conf.Strings("http.bind_ua")) > 0 && !slices.Contains(conf.Strings("http.bind_ua"), r.UserAgent()) { Abort(w, http.StatusTeapot, t.Get("invalid request user agent: %s", r.UserAgent())) diff --git a/internal/http/request/setting.go b/internal/http/request/setting.go index cf735f6f..5b1282c7 100644 --- a/internal/http/request/setting.go +++ b/internal/http/request/setting.go @@ -25,7 +25,7 @@ type SettingPanel struct { func (r *SettingPanel) Rules(_ *http.Request) map[string]string { return map[string]string{ "BindDomain.*": "required", - "BindIP.*": "required|ip", + "BindIP.*": "required|ipcidr", "BindUA.*": "required", } } diff --git a/internal/http/request/user_token.go b/internal/http/request/user_token.go index f741ce78..53ccb688 100644 --- a/internal/http/request/user_token.go +++ b/internal/http/request/user_token.go @@ -15,7 +15,7 @@ type UserTokenCreate struct { func (r *UserTokenCreate) Rules(_ *http.Request) map[string]string { return map[string]string{ - "IPs.*": "required|ip", + "IPs.*": "required|ipcidr", } } @@ -27,6 +27,6 @@ type UserTokenUpdate struct { func (r *UserTokenUpdate) Rules(_ *http.Request) map[string]string { return map[string]string{ - "IPs.*": "required|ip", + "IPs.*": "required|ipcidr", } } diff --git a/internal/http/rule/ip_cidr.go b/internal/http/rule/ip_cidr.go new file mode 100644 index 00000000..b001fb57 --- /dev/null +++ b/internal/http/rule/ip_cidr.go @@ -0,0 +1,22 @@ +package rule + +import "net" + +// IPCIDR 验证一个值是否是一个有效的 IP 或 CIDR 格式 +type IPCIDR struct{} + +func NewIPCIDR() *IPCIDR { + return &IPCIDR{} +} + +func (r *IPCIDR) Passes(val any, options ...any) bool { + if str, ok := val.(string); ok { + if ip := net.ParseIP(str); ip != nil { + return true // 是有效的 IP + } + if _, _, err := net.ParseCIDR(str); err == nil { + return true // 是有效的 CIDR + } + } + return false // 既不是 IP 也不是 CIDR +} diff --git a/internal/http/rule/rule.go b/internal/http/rule/rule.go index 07e02fe0..65b6d652 100644 --- a/internal/http/rule/rule.go +++ b/internal/http/rule/rule.go @@ -11,11 +11,13 @@ func GlobalRules(db *gorm.DB) { "notExists": NewNotExists(db).Passes, "password": NewPassword().Passes, "cron": NewCron().Passes, + "ipcidr": NewIPCIDR().Passes, }) validate.AddGlobalMessages(map[string]string{ "exists": "{field} 不存在", "notExists": "{field} 已存在", "password": "密码不满足要求(8-20位,至少包含字母、数字、特殊字符中的两种)", "cron": "Cron 表达式不合法", + "ipcidr": "IP 或 CIDR 格式不合法", }) }