mirror of
https://github.com/acepanel/panel.git
synced 2026-02-04 04:22:33 +08:00
fix: basic auth问题
This commit is contained in:
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/leonelquinteros/gotext"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/acepanel/panel/internal/app"
|
||||
@@ -217,7 +216,7 @@ func (r *websiteRepo) Get(id uint) (*types.WebsiteSetting, error) {
|
||||
// 高级设置(限流限速、真实 IP、基本认证)
|
||||
setting.RateLimit = vhost.RateLimit()
|
||||
setting.RealIP = vhost.RealIP()
|
||||
// 读取基本认证用户列表(从 htpasswd 文件)
|
||||
// 读取基本认证用户列表
|
||||
setting.BasicAuth = r.readBasicAuthUsers(website.Name)
|
||||
|
||||
// 自定义配置
|
||||
@@ -697,7 +696,7 @@ func (r *websiteRepo) Update(ctx context.Context, req *request.WebsiteUpdate) er
|
||||
}
|
||||
// 基本认证创建 htpasswd 文件
|
||||
if len(req.BasicAuth) > 0 {
|
||||
htpasswdPath := filepath.Join(app.Root, "sites", website.Name, "config", "htpasswd")
|
||||
htpasswdPath := filepath.Join(app.Root, "sites", website.Name, "htpasswd")
|
||||
if err = r.writeBasicAuthUsers(htpasswdPath, req.BasicAuth); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -706,7 +705,7 @@ func (r *websiteRepo) Update(ctx context.Context, req *request.WebsiteUpdate) er
|
||||
}
|
||||
} else {
|
||||
// 清除基本认证配置和 htpasswd 文件
|
||||
htpasswdPath := filepath.Join(app.Root, "sites", website.Name, "config", "htpasswd")
|
||||
htpasswdPath := filepath.Join(app.Root, "sites", website.Name, "htpasswd")
|
||||
_ = io.Remove(htpasswdPath)
|
||||
if err = vhost.ClearBasicAuth(); err != nil {
|
||||
return err
|
||||
@@ -1158,7 +1157,7 @@ func (r *websiteRepo) reloadWebServer() error {
|
||||
|
||||
// readBasicAuthUsers 读取 htpasswd 文件中的用户列表
|
||||
func (r *websiteRepo) readBasicAuthUsers(siteName string) map[string]string {
|
||||
htpasswdPath := filepath.Join(app.Root, "sites", siteName, "config", "htpasswd")
|
||||
htpasswdPath := filepath.Join(app.Root, "sites", siteName, "htpasswd")
|
||||
if !io.Exists(htpasswdPath) {
|
||||
return nil
|
||||
}
|
||||
@@ -1176,11 +1175,10 @@ func (r *websiteRepo) readBasicAuthUsers(siteName string) map[string]string {
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
// htpasswd 格式: username:encrypted_password
|
||||
// htpasswd 格式: username:password
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
// 返回空密码,前端显示为占位符
|
||||
users[parts[0]] = ""
|
||||
users[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1192,54 +1190,17 @@ func (r *websiteRepo) readBasicAuthUsers(siteName string) map[string]string {
|
||||
|
||||
// writeBasicAuthUsers 将用户凭证写入 htpasswd 文件
|
||||
func (r *websiteRepo) writeBasicAuthUsers(htpasswdPath string, users map[string]string) error {
|
||||
// 读取现有用户密码
|
||||
existingUsers := make(map[string]string)
|
||||
if io.Exists(htpasswdPath) {
|
||||
file, err := os.Open(htpasswdPath)
|
||||
if err == nil {
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
existingUsers[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
_ = file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
for username, password := range users {
|
||||
if username == "" {
|
||||
if username == "" || password == "" {
|
||||
continue
|
||||
}
|
||||
var hashedPassword string
|
||||
if password == "" {
|
||||
// 密码为空,保留现有密码
|
||||
if existing, ok := existingUsers[username]; ok {
|
||||
hashedPassword = existing
|
||||
} else {
|
||||
// 新用户但没有密码,跳过
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// 新密码,使用 bcrypt 加密
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password for user %s: %w", username, err)
|
||||
}
|
||||
hashedPassword = string(hash)
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%s:%s", username, hashedPassword))
|
||||
lines = append(lines, fmt.Sprintf("%s:%s", username, password))
|
||||
}
|
||||
|
||||
content := strings.Join(lines, "\n")
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
return io.Write(htpasswdPath, content, 0600)
|
||||
return io.Write(htpasswdPath, content, 0644) // 必须 0644,Nginx 在运行中以 www 用户读取
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user