diff --git a/internal/apps/phpmyadmin/service.go b/internal/apps/phpmyadmin/service.go index 276fe35d..82456fe8 100644 --- a/internal/apps/phpmyadmin/service.go +++ b/internal/apps/phpmyadmin/service.go @@ -78,8 +78,9 @@ func (s *Service) UpdatePort(w http.ResponseWriter, r *http.Request) { fw := firewall.NewFirewall() err = fw.Port(firewall.FireInfo{ - Port: req.Port, - Protocol: "tcp", + PortStart: req.Port, + PortEnd: req.Port, + Protocol: "tcp", }, firewall.OperationAdd) if err != nil { service.Error(w, http.StatusInternalServerError, "%v", err) diff --git a/internal/apps/pureftpd/service.go b/internal/apps/pureftpd/service.go index 0861cf95..642f16a9 100644 --- a/internal/apps/pureftpd/service.go +++ b/internal/apps/pureftpd/service.go @@ -156,8 +156,9 @@ func (s *Service) UpdatePort(w http.ResponseWriter, r *http.Request) { fw := firewall.NewFirewall() err = fw.Port(firewall.FireInfo{ - Port: req.Port, - Protocol: "tcp", + PortStart: req.Port, + PortEnd: req.Port, + Protocol: "tcp", }, firewall.OperationAdd) if err != nil { service.Error(w, http.StatusInternalServerError, "%v", err) diff --git a/internal/data/setting.go b/internal/data/setting.go index 872169b4..fb7eb3c5 100644 --- a/internal/data/setting.go +++ b/internal/data/setting.go @@ -205,8 +205,9 @@ func (r *settingRepo) UpdatePanelSetting(ctx context.Context, setting *request.P // 放行端口 fw := firewall.NewFirewall() err = fw.Port(firewall.FireInfo{ - Port: uint(config.HTTP.Port), - Protocol: "tcp", + PortStart: uint(config.HTTP.Port), + PortEnd: uint(config.HTTP.Port), + Protocol: "tcp", }, firewall.OperationAdd) if err != nil { return false, err diff --git a/internal/http/request/firewall.go b/internal/http/request/firewall.go index 244ec621..9401ddf5 100644 --- a/internal/http/request/firewall.go +++ b/internal/http/request/firewall.go @@ -5,6 +5,9 @@ type FirewallStatus struct { } type FirewallRule struct { - Port uint `json:"port" validate:"required"` - Protocol string `json:"protocol" validate:"required"` + PortStart uint `json:"port_start" validate:"required,gte=1,lte=65535"` + PortEnd uint `json:"port_end" validate:"required,gte=1,lte=65535"` + Protocols []string `json:"protocols" validate:"min=1,dive,oneof=tcp udp"` + Address string `json:"address"` + Strategy string `json:"strategy" validate:"required,oneof=accept drop"` } diff --git a/internal/service/firewall.go b/internal/service/firewall.go index 3a53347c..bc82aaa0 100644 --- a/internal/service/firewall.go +++ b/internal/service/firewall.go @@ -2,6 +2,7 @@ package service import ( "net/http" + "slices" "github.com/go-rat/chix" @@ -15,7 +16,6 @@ type FirewallService struct { } func NewFirewallService() *FirewallService { - return &FirewallService{ firewall: firewall.NewFirewall(), } @@ -80,9 +80,13 @@ func (s *FirewallService) CreateRule(w http.ResponseWriter, r *http.Request) { return } - if err = s.firewall.Port(firewall.FireInfo{Port: req.Port, Protocol: req.Protocol}, "add"); err != nil { - Error(w, http.StatusInternalServerError, "%v", err) - return + for protocol := range slices.Values(req.Protocols) { + if err = s.firewall.Port(firewall.FireInfo{ + PortStart: req.PortStart, PortEnd: req.PortEnd, Protocol: protocol, Address: req.Address, Strategy: req.Strategy, + }, firewall.OperationAdd); err != nil { + Error(w, http.StatusInternalServerError, "%v", err) + return + } } Success(w, nil) @@ -95,9 +99,13 @@ func (s *FirewallService) DeleteRule(w http.ResponseWriter, r *http.Request) { return } - if err = s.firewall.Port(firewall.FireInfo{Port: req.Port, Protocol: req.Protocol}, "remove"); err != nil { - Error(w, http.StatusInternalServerError, "%v", err) - return + for protocol := range slices.Values(req.Protocols) { + if err = s.firewall.Port(firewall.FireInfo{ + PortStart: req.PortStart, PortEnd: req.PortEnd, Protocol: protocol, Address: req.Address, Strategy: req.Strategy, + }, firewall.OperationRemove); err != nil { + Error(w, http.StatusInternalServerError, "%v", err) + return + } } Success(w, nil) diff --git a/pkg/firewall/consts.go b/pkg/firewall/consts.go index 87c259c4..83e819d8 100644 --- a/pkg/firewall/consts.go +++ b/pkg/firewall/consts.go @@ -1,22 +1,24 @@ package firewall type FireInfo struct { - Family string `json:"family"` // ipv4 ipv6 - Address string `json:"address"` - Port uint `json:"port"` // 1-65535 - Protocol string `json:"protocol"` // tcp udp tcp/udp - Strategy string `json:"strategy"` // accept drop + Family string `json:"family"` // ipv4 ipv6 + Address string `json:"address"` // 源地址或目标地址 + PortStart uint `json:"port_start"` // 1-65535 + PortEnd uint `json:"port_end"` // 1-65535 + Protocol string `json:"protocol"` // tcp udp tcp/udp + Strategy string `json:"strategy"` // accept drop + Direction string `json:"direction"` // in out 入站或出站 +} - Num string `json:"num"` +type FireForwardInfo struct { + Address string `json:"address"` + Port uint `json:"port"` // 1-65535 + Protocol string `json:"protocol"` // tcp udp tcp/udp TargetIP string `json:"targetIP"` TargetPort string `json:"targetPort"` // 1-65535 - - UsedStatus string `json:"usedStatus"` - Description string `json:"description"` } type Forward struct { - Num string `json:"num"` Protocol string `json:"protocol"` Port uint `json:"port"` // 1-65535 TargetIP string `json:"targetIP"` diff --git a/pkg/firewall/firewall.go b/pkg/firewall/firewall.go index 4e403433..6f7794ec 100644 --- a/pkg/firewall/firewall.go +++ b/pkg/firewall/firewall.go @@ -16,8 +16,8 @@ import ( type Operation string var ( - OperationAdd Operation = "add" - OperationDel Operation = "remove" + OperationAdd Operation = "add" + OperationRemove Operation = "remove" ) type Firewall struct { @@ -58,13 +58,21 @@ func (r *Firewall) ListRule() ([]FireInfo, error) { if len(port) == 0 { continue } - var itemPort FireInfo + var item FireInfo if strings.Contains(port, "/") { - itemPort.Port = cast.ToUint(strings.Split(port, "/")[0]) - itemPort.Protocol = strings.Split(port, "/")[1] + ruleItem := strings.Split(port, "/") + portItem := strings.Split(ruleItem[0], "-") + if len(portItem) > 1 { + item.PortStart = cast.ToUint(portItem[0]) + item.PortEnd = cast.ToUint(portItem[1]) + } else { + item.PortStart = cast.ToUint(ruleItem[0]) + item.PortEnd = cast.ToUint(ruleItem[0]) + } + item.Protocol = ruleItem[1] } - itemPort.Strategy = "accept" - data = append(data, itemPort) + item.Strategy = "accept" + data = append(data, item) } }() go func() { @@ -73,7 +81,6 @@ func (r *Firewall) ListRule() ([]FireInfo, error) { if err != nil { return } - data = append(data, rich...) }() @@ -81,13 +88,13 @@ func (r *Firewall) ListRule() ([]FireInfo, error) { return data, nil } -func (r *Firewall) ListForward() ([]FireInfo, error) { +func (r *Firewall) ListForward() ([]FireForwardInfo, error) { out, err := shell.Execf("firewall-cmd --zone=public --list-forward-ports") if err != nil { return nil, err } - var data []FireInfo + var data []FireForwardInfo for _, line := range strings.Split(out, "\n") { line = strings.TrimFunc(line, func(r rune) bool { return r <= 32 @@ -100,7 +107,7 @@ func (r *Firewall) ListForward() ([]FireInfo, error) { if len(match[4]) == 0 { match[4] = "127.0.0.1" } - data = append(data, FireInfo{ + data = append(data, FireForwardInfo{ Port: cast.ToUint(match[1]), Protocol: match[2], TargetIP: match[4], @@ -132,10 +139,17 @@ func (r *Firewall) ListRichRule() ([]FireInfo, error) { return data, nil } -func (r *Firewall) Port(port FireInfo, operation Operation) error { - stdout, err := shell.Execf("firewall-cmd --zone=public --%s-port=%d/%s --permanent", operation, port.Port, port.Protocol) +func (r *Firewall) Port(rule FireInfo, operation Operation) error { + if rule.PortEnd == 0 { + rule.PortEnd = rule.PortStart + } + // 不支持的切换使用rich rules + if rule.Direction != "in" || rule.Family != "ipv4" || rule.Address != "" || rule.Strategy != "accept" { + return r.RichRules(rule, operation) + } + stdout, err := shell.Execf("firewall-cmd --zone=public --%s-port=%d-%d/%s --permanent", operation, rule.PortStart, rule.PortEnd, rule.Protocol) if err != nil { - return fmt.Errorf("%s port %d/%s failed, err: %s", operation, port.Port, port.Protocol, stdout) + return fmt.Errorf("%s port %d-%d/%s failed, err: %s", operation, rule.PortStart, rule.PortEnd, rule.Protocol, stdout) } _, err = shell.Execf("firewall-cmd --reload") @@ -146,22 +160,29 @@ func (r *Firewall) RichRules(rule FireInfo, operation Operation) error { families := strings.Split(rule.Family, "/") // ipv4 ipv6 for _, family := range families { - var ruleStr strings.Builder - ruleStr.WriteString(fmt.Sprintf(`rule family="%s" `, family)) + var ruleBuilder strings.Builder + ruleBuilder.WriteString(fmt.Sprintf(`rule family="%s" `, family)) + if len(rule.Address) != 0 { - ruleStr.WriteString(fmt.Sprintf(`source address="%s" `, rule.Address)) + if rule.Direction == "in" { + ruleBuilder.WriteString(fmt.Sprintf(`source address="%s" `, rule.Address)) + } else if rule.Direction == "out" { + ruleBuilder.WriteString(fmt.Sprintf(`destination address="%s" `, rule.Address)) + } else { + return fmt.Errorf("invalid direction: %s", rule.Direction) + } } - if rule.Port != 0 { - ruleStr.WriteString(fmt.Sprintf(`port port="%d" `, rule.Port)) + if rule.PortStart != 0 && rule.PortEnd != 0 { + ruleBuilder.WriteString(fmt.Sprintf(`port port="%d-%d" `, rule.PortStart, rule.PortEnd)) } if len(rule.Protocol) != 0 { - ruleStr.WriteString(fmt.Sprintf(`protocol="%s" `, rule.Protocol)) + ruleBuilder.WriteString(fmt.Sprintf(`protocol="%s" `, rule.Protocol)) } - ruleStr.WriteString(rule.Strategy) - _, err := shell.Execf("firewall-cmd --zone=public --%s-rich-rule '%s' --permanent", operation, ruleStr.String()) + ruleBuilder.WriteString(rule.Strategy) + _, err := shell.Execf("firewall-cmd --zone=public --%s-rich-rule '%s' --permanent", operation, ruleBuilder.String()) if err != nil { - return fmt.Errorf("%s rich rules (%s) failed, err: %v", operation, ruleStr.String(), err) + return fmt.Errorf("%s rich rules (%s) failed, err: %v", operation, ruleBuilder.String(), err) } } @@ -202,7 +223,14 @@ func (r *Firewall) parseRichRule(line string) (*FireInfo, error) { itemRule.Family = match[1] itemRule.Address = match[2] - itemRule.Port = cast.ToUint(match[3]) + ports := strings.Split(match[3], "-") + if len(ports) > 1 { + itemRule.PortStart = cast.ToUint(ports[0]) + itemRule.PortEnd = cast.ToUint(ports[1]) + } else { + itemRule.PortStart = cast.ToUint(match[3]) + itemRule.PortEnd = cast.ToUint(match[3]) + } itemRule.Protocol = match[4] itemRule.Strategy = match[5] }