From e30564081757cda2a9f3390ed572d65a4256a560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Mon, 1 Dec 2025 14:19:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E7=9A=84=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E8=A7=A3=E6=9E=90=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/copilot-instructions.md | 65 --- AGENTS.md | 53 ++ CLAUDE.md | 259 ++++++++++ pkg/nginx/parser.go | 26 +- pkg/nginx/setter.go | 2 +- pkg/webserver/LICENSE | 21 + pkg/webserver/apache/ast.go | 243 +++++++++ pkg/webserver/apache/data.go | 30 ++ pkg/webserver/apache/export.go | 347 +++++++++++++ pkg/webserver/apache/lexer.go | 349 +++++++++++++ pkg/webserver/apache/options.go | 27 + pkg/webserver/apache/parser.go | 429 ++++++++++++++++ pkg/webserver/apache/parser_test.go | 262 ++++++++++ pkg/webserver/apache/vhost.go | 654 ++++++++++++++++++++++++ pkg/webserver/apache/vhost_test.go | 387 ++++++++++++++ pkg/webserver/nginx/data.go | 41 ++ pkg/webserver/nginx/getter.go | 278 ++++++++++ pkg/webserver/nginx/parser.go | 232 +++++++++ pkg/webserver/nginx/parser_test.go | 233 +++++++++ pkg/webserver/nginx/setter.go | 601 ++++++++++++++++++++++ pkg/webserver/nginx/testdata/http.conf | 28 + pkg/webserver/nginx/testdata/https.conf | 36 ++ pkg/webserver/nginx/vhost.go | 491 ++++++++++++++++++ pkg/webserver/nginx/vhost_test.go | 349 +++++++++++++ pkg/webserver/types.go | 9 + pkg/webserver/types/proxy.go | 24 + pkg/webserver/types/redirect.go | 19 + pkg/webserver/types/vhost.go | 136 +++++ pkg/webserver/webserver.go | 21 + 29 files changed, 5583 insertions(+), 69 deletions(-) delete mode 100644 .github/copilot-instructions.md create mode 100644 AGENTS.md create mode 100644 CLAUDE.md create mode 100644 pkg/webserver/LICENSE create mode 100644 pkg/webserver/apache/ast.go create mode 100644 pkg/webserver/apache/data.go create mode 100644 pkg/webserver/apache/export.go create mode 100644 pkg/webserver/apache/lexer.go create mode 100644 pkg/webserver/apache/options.go create mode 100644 pkg/webserver/apache/parser.go create mode 100644 pkg/webserver/apache/parser_test.go create mode 100644 pkg/webserver/apache/vhost.go create mode 100644 pkg/webserver/apache/vhost_test.go create mode 100644 pkg/webserver/nginx/data.go create mode 100644 pkg/webserver/nginx/getter.go create mode 100644 pkg/webserver/nginx/parser.go create mode 100644 pkg/webserver/nginx/parser_test.go create mode 100644 pkg/webserver/nginx/setter.go create mode 100644 pkg/webserver/nginx/testdata/http.conf create mode 100644 pkg/webserver/nginx/testdata/https.conf create mode 100644 pkg/webserver/nginx/vhost.go create mode 100644 pkg/webserver/nginx/vhost_test.go create mode 100644 pkg/webserver/types.go create mode 100644 pkg/webserver/types/proxy.go create mode 100644 pkg/webserver/types/redirect.go create mode 100644 pkg/webserver/types/vhost.go create mode 100644 pkg/webserver/webserver.go diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md deleted file mode 100644 index 9316a424..00000000 --- a/.github/copilot-instructions.md +++ /dev/null @@ -1,65 +0,0 @@ -你是一名专业的 AI 编程助手,专门使用 Go 语言的 github.com/gofiber/fiber/v3 包和 gorm.io/gorm 包开发项目。 -始终使用最新稳定版本的 go1.24,并熟悉 RESTful API 设计原则、最佳实践和 Go 语言习惯用法。 - -- 严格按照用户要求一字不差地执行。 -- 始终使用简体中文进行回复及编写代码注释。 -- 首先进行逐步思考 - 详细描述你的 API 结构、端点和数据流计划,用伪代码详细写出。 -- 确认计划后,编写代码! -- 为 API 编写正确、最新、无错误、功能完整、安全且高效的 Go 代码。 -- 使用 github.com/gofiber/fiber/v3 包进行 API 开发: - - fiber v3 在 handler 中使用 c fiber.Ctx 而不是 c *fiber.Ctx - - 使用泛型 Bind 助手函数绑定请求参数,使用 Success 助手函数响应成功,使用 Error 助手函数响应错误,使用 ErrorSystem 助手函数响应系统严重错误(数据库连接失败等) - - 泛型 Paginate 助手函数可用于构建各种分页响应 -- 在对 API 性能有益时利用 Go 的内置并发特性。 -- 遵循 RESTful API 设计原则和最佳实践。 -- 包含必要的导入、包声明和任何必需的设置代码。 -- 如需记录日志,使用标准库的 slog 包进行日志记录(注入 *slog.Logger)。 -- 考虑为横切关注点(如日志记录、认证)实现中间件。 -- 在适当时实现速率限制和认证/授权(JWT)。 -- 在 API 实现中不留任何待办事项、占位符或缺失部分。 -- 解释时简明扼要,但为复杂逻辑或 Go 特定习惯用法提供简短注释。 -- 如果对最佳实践或实现细节不确定,请说明而不是猜测。 -- 提供使用 Go 测试包测试 API 端点的建议。 - -## 项目描述 - -本项目是基于 Go 语言的 Fiber 框架和 wire 依赖注入开发的 AcePanel Linux 服务器运维管理面板,目前正在进行 v3 版本重构。 - -v3 版本需要完成以下重构任务: -1. 使用 Fiber v3 替换目前的 go-chi 路由 -2. 全新的项目模块,支持运行 Java/Go/Python 等项目 -3. 网站模块重构,支持多 Web 服务器(Apache/OLS/Kangle) -4. 备份模块重构,需要支持 s3 和 ftp/sftp 备份途径 -5. 计划任务模块重构,支持管理备份任务和自定义脚本任务等 - -## 项目结构 - -├── cmd/ -│ ├── ace/ 面板主程序 -│ └── cli/ 面板命令行工具 -├── internal/ -│ ├── app/ 应用入口 -│ ├── apps/ 面板各子应用的实现 -│ ├── biz/ 业务逻辑的接口和数据库模型定义,类似 DDD 的 domain 层,data 类似 DDD 的 repo,而业务接口在这里定义,使用依赖倒置的原则 -│ ├── bootstrap/ 各个模块的启动引导 -│ ├── data/ 业务数据访问,包含 cache、db 等封装,实现了 biz 的业务接口。我们可能会把 data 与 dao 混淆在一起,data 偏重业务的含义,它所要做的是将领域对象重新拿出来,我们去掉了 DDD 的 infra 层 -│ ├── http/ -│ │ ├── middleware/ 自定义路由中间件 -│ │ ├── request/ 请求结构体 -│ │ └── rule/ 自定义验证规则 -│ ├── job/ 面板后台任务 -│ ├── migration/ 数据库迁移定义 -│ ├── queuejob/ 面板任务队列 -│ ├── route/ 路由定义 -│ └── service/ 实现了路由定义的服务层,类似 DDD 的 application 层,处理 DTO 到 biz 领域实体的转换(DTO -> DO),同时协同各类 biz 交互,但是不应处理复杂逻辑 -├── mocks/ 模拟数据,目前没有使用 -├── pkg/ 工具函数及包 -├── storage/ 数据存储 -└── web/ 前端项目 - -## 开发新需求时的流程 - -1. 在 route/http 中添加新的路由和注入需要的服务 -2. 在 service 中添加新的服务方法,先读取已存在的其他服务方法,以参考它们的实现方式 -3. 在 biz 中添加新的业务逻辑需要的接口等,先读取已存在的其他接口,以参考它们的实现方式 -4. 在 data 中实现 biz 的接口,先读取已存在的其他实现,以参考它们的实现方式 diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..38218e83 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,53 @@ +# AGENTS 指南 + +## 基本要求 +- 所有回复、文档、代码注释必须使用简体中文。 +- 项目处于 v3 重构期(网站/备份/计划任务等模块),保持现有架构和风格,避免随意改动。 + +## 项目概览与分层 +- 技术栈:后端 Go 1.25 + go-chi + GORM + Wire;前端 Vue 3 + Vite + UnoCSS + Naive UI + pnpm(Node 24)。 +- 分层:route -> service -> biz <- data,服务层只做编排/DTO 转换,业务逻辑放在 biz,数据访问在 data。 +- 目录:`cmd/ace`/`cmd/cli` 入口;`internal/app` 配置/启动;`internal/route|service|biz|data|http|apps|bootstrap|migration|job|queuejob` 按职责拆分;`pkg/` 通用库与内嵌资源;`web/` 前端;`mocks/` 为 Mockery 生成的仓库接口;构建后的前端复制到 `pkg/embed/frontend`;多语言在 `pkg/embed/locales` 与 `web/src/locales`。 +- 配置示例:`config.example.yml`;CI 脚本见 `.github/workflows/`。 + +## 开发约束 +- 禁止在本地直接运行主程序,只允许在远程 Linux 服务器运行。 +- 开发前准备:`cp config.example.yml config.yml`;前端开发可复制 `.env.development`(或按需 `.env.production`)为 `.env`,必要时复制 `settings/proxy-config.example.ts` 为 `settings/proxy-config.ts`。 + +## 构建与测试 +- 后端: + ```bash + go test ./... + go build ./cmd/ace + go build ./cmd/cli + # 如需注入版本信息可使用 go build -ldflags 方案,保持 -trimpath/-buildvcs=false 一致 + ``` +- 前端: + ```bash + cd web + pnpm install + pnpm type-check + pnpm lint + pnpm dev + pnpm build # 产物输出 dist 并自动复制到 ../pkg/embed/frontend + ``` +- 后端单元测试仅覆盖 `pkg/` 公共包,`internal/` 无需测试;前端不写单元测试,依赖 TS 类型检查与 ESLint。 + +## 开发流程 +1. 在 `internal/route/` 添加路由,注入所需 service。 +2. 在 `internal/service/` 实现编排:参数校验、DTO 处理,返回 `Success`/`Error`/`ErrorSystem`。 +3. 在 `internal/biz/` 定义接口与领域模型,保持精简,接口由 data 实现。 +4. 在 `internal/data/` 用 GORM/缓存等实现仓库逻辑,遵循依赖倒置。 +5. 更新对应 `wire.go` 并运行 `go generate ./...` 完成依赖注入。 + +## 编码规范 +- Go:使用 `gofmt` 与 `golangci-lint`;导出符号需中文注释;错误返回统一使用 `error` 并添加上下文;避免循环依赖,包名短小;日志使用标准库 `slog`,可用 `samber/lo` 辅助;文件按领域拆分(如 `container_*`)。 +- 前端:TypeScript + Vue SFC(组合式 API);样式使用 UnoCSS/Naive UI 主题;状态集中在 `store/`(Pinia);请求使用 Alova;命名采用帕斯卡组件名;Prettier 2 空格 + ESLint 规则。 + +## 数据与安全 +- 默认数据库 SQLite(`github.com/ncruces/go-sqlite3`),通过 GORM 迁移与访问。 +- 需要关注认证/授权(JWT)、SQL 注入防护、XSS/CSRF 防护、速率限制(`github.com/sethvargo/go-limiter`)。 + +## 提交与 PR +- 提交信息遵循惯例式格式(如 `chore(deps): ...`、`feat: ...`、`fix: ...`),一次提交聚焦单一主题。 +- PR 应包含:变更摘要、关联 Issue/需求、测试命令与结果、前端可视化改动的截图;确保 CI(lint/test/build)在干净环境可复现。 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..85c4306a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,259 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## 项目概述 + +AcePanel 是基于 Go 语言开发的新一代 Linux 服务器运维管理面板。项目采用前后端分离架构: +- 后端:Go 1.25 + go-chi 路由 + GORM + Wire 依赖注入 +- 前端:Vue 3 + Vite + Naive UI + pnpm + +**重要提示:** 项目目前正在进行 v3 版本重构,主要包括重构网站/备份/计划任务模块等。 + +## 语言和编码规范 + +**所有代码注释、文档和回复必须使用简体中文。** + +## 构建和测试 + +### 后端构建 + +构建主程序: +```bash +go build -o ace ./cmd/ace +``` + +构建 CLI 工具: +```bash +go build -o cli ./cmd/cli +``` + +构建时注入版本信息: +```bash +VERSION="1.0.0" +BUILD_TIME="$(date -u '+%F %T UTC')" +COMMIT_HASH="$(git rev-parse --short HEAD)" +GO_VERSION="$(go version | cut -d' ' -f3)" + +LDFLAGS="-s -w --extldflags '-static'" +LDFLAGS="${LDFLAGS} -X 'github.com/acepanel/panel/internal/app.Version=${VERSION}'" +LDFLAGS="${LDFLAGS} -X 'github.com/acepanel/panel/internal/app.BuildTime=${BUILD_TIME}'" +LDFLAGS="${LDFLAGS} -X 'github.com/acepanel/panel/internal/app.CommitHash=${COMMIT_HASH}'" +LDFLAGS="${LDFLAGS} -X 'github.com/acepanel/panel/internal/app.GoVersion=${GO_VERSION}'" + +go build -trimpath -buildvcs=false -ldflags "${LDFLAGS}" -o ace ./cmd/ace +``` + +### 运行测试 + +运行所有测试: +```bash +go test -v ./... +``` + +运行测试并生成覆盖率报告: +```bash +go test -v -coverprofile="coverage.out" ./... +``` + +运行单个测试: +```bash +go test -v -run TestFunctionName ./path/to/package +``` + +### 前端开发 + +进入前端目录: +```bash +cd web +``` + +安装依赖: +```bash +pnpm install +``` + +开发模式(带热重载): +```bash +pnpm dev +``` + +类型检查: +```bash +pnpm type-check +``` + +代码检查: +```bash +pnpm lint +``` + +构建生产版本: +```bash +pnpm build +``` + +## 代码架构 + +项目采用类 DDD 分层架构,依赖关系为:route -> service -> biz <- data + +### 核心目录结构 + +- **`cmd/`**: 程序入口 + - `ace/`: 面板主程序 + - `cli/`: 命令行工具 + +- **`internal/app/`**: 应用入口和配置 + +- **`internal/route/`**: HTTP 路由定义 + - 定义路由规则 + - 注入所需的 service 依赖 + +- **`internal/service/`**: 服务层(类似 DDD 的 application 层) + - 处理 HTTP 请求/响应 + - DTO 到 DO 的转换 + - 协调多个 biz 接口完成业务流程 + - **不应处理复杂业务逻辑** + +- **`internal/biz/`**: 业务逻辑层(类似 DDD 的 domain 层) + - 定义业务接口(Repository 模式) + - 定义领域模型和数据结构 + - 使用依赖倒置原则:biz 定义接口,data 实现接口 + +- **`internal/data/`**: 数据访问层(类似 DDD 的 repository 层) + - 实现 biz 中定义的业务接口 + - 封装数据库、缓存等操作 + - 处理数据持久化逻辑 + +- **`internal/http/`**: HTTP 相关 + - `middleware/`: 自定义中间件 + - `request/`: 请求结构体定义 + - `rule/`: 自定义验证规则 + +- **`internal/apps/`**: 面板子应用实现 + +- **`internal/bootstrap/`**: 各模块启动引导 + +- **`internal/migration/`**: 数据库迁移 + +- **`internal/job/`**: 后台任务 + +- **`internal/queuejob/`**: 任务队列 + +- **`pkg/`**: 工具函数和通用包 + - 包含各种独立的工具模块 + - 可被项目任何部分引用 + +- **`web/`**: Vue 3 前端项目 + +## 开发新功能的标准流程 + +1. **在 `internal/route/` 中添加路由** + - 参考已有路由文件(如 `http.go`) + - 注入需要的 service 依赖 + - 定义路由规则和 handler 映射 + +2. **在 `internal/service/` 中实现服务方法** + - **先阅读已有的类似服务**以了解代码风格 + - 处理请求验证和响应格式化 + - 使用 `Success()` 返回成功响应 + - 使用 `Error()` 返回错误响应 + - 使用 `ErrorSystem()` 返回系统严重错误 + - 调用 biz 层接口完成业务逻辑 + +3. **在 `internal/biz/` 中定义业务接口** + - **先阅读已有的类似接口定义** + - 定义 Repository 接口(如 `WebsiteRepo`) + - 定义领域模型结构体(如 `Website`) + - 保持接口简洁明确 + +4. **在 `internal/data/` 中实现 biz 接口** + - **先阅读已有的类似实现** + - 创建 repo 结构体(如 `websiteRepo`) + - 实现构造函数(如 `NewWebsiteRepo`) + - 实现所有接口方法 + - 处理数据库操作和缓存逻辑 + +5. **使用 Wire 进行依赖注入** + - 在对应的 wire.go 文件中添加 provider + - 运行 `go generate` 生成依赖注入代码 + +## 技术栈特定注意事项 + +### Go 语言规范 + +- 使用 Go 1.25 稳定版本 +- 遵循 Go 标准库和习惯用法 +- 日志使用标准库的 `slog` 包 +- 使用 `github.com/samber/lo` 进行函数式编程辅助 + +### 当前框架 + +- 路由:`github.com/go-chi/chi/v5` +- ORM:`gorm.io/gorm` +- 依赖注入:`github.com/google/wire` +- 配置管理:`github.com/knadh/koanf/v2` +- 验证:`github.com/gookit/validate` + +### 助手函数(service 层) + +在 service 层使用以下助手函数: +- `Success(w, data)`: 返回成功响应 +- `Error(w, statusCode, format, args...)`: 返回错误响应 +- `ErrorSystem(w, format, args...)`: 返回系统严重错误(500) +- `Bind[T](r)`: 绑定请求参数到泛型类型 T +- `Paginate[T](...)`: 构建分页响应 + +### 数据库 + +- 使用 SQLite(`github.com/ncruces/go-sqlite3`) +- 使用 GORM 进行数据库迁移和操作 + +### 安全性 + +- 实现认证/授权(JWT) +- 防止 SQL 注入(使用 GORM 参数化查询) +- 防止 XSS 和 CSRF 攻击 +- 实现速率限制(`github.com/sethvargo/go-limiter`) + +## 代码风格 + +- 所有代码注释必须使用简体中文 +- 遵循 Go 官方代码风格 +- 使用 `gofmt` 格式化代码 +- 复杂逻辑添加注释说明 +- 导出的函数和类型必须有注释 + +## Wire 依赖注入 + +项目使用 Wire 进行依赖注入。当添加新的依赖时: + +1. 在 `cmd/ace/wire.go` 或 `cmd/cli/wire.go` 中添加 provider +2. 运行生成命令: +```bash +go generate ./... +``` + +## 前端开发注意事项 + +- 使用 Vue 3 Composition API +- UI 框架:Naive UI +- 状态管理:Pinia +- HTTP 请求:Alova +- 图标:@iconify/vue +- 终端:xterm.js +- 遵循项目已有的组件结构和编码风格 + +## 配置文件 + +开发时需要准备配置文件: +```bash +cp config.example.yml config.yml +``` + +前端开发配置: +```bash +cd web +cp .env.production .env +cp settings/proxy-config.example.ts settings/proxy-config.ts +``` diff --git a/pkg/nginx/parser.go b/pkg/nginx/parser.go index 4496dcae..04b94cf7 100644 --- a/pkg/nginx/parser.go +++ b/pkg/nginx/parser.go @@ -122,7 +122,7 @@ func (p *Parser) Clear(key string) error { // Set 通过表达式设置配置 // e.g. Set("server.server_name", []directive) -func (p *Parser) Set(key string, directives []*config.Directive) error { +func (p *Parser) Set(key string, directives []*config.Directive, after ...string) error { parts := strings.Split(key, ".") var block *config.Block @@ -144,9 +144,30 @@ func (p *Parser) Set(key string, directives []*config.Directive) error { blockDirective = sub[0] } + iDirectives := make([]config.IDirective, 0, len(directives)) for _, directive := range directives { directive.SetParent(blockDirective) - block.Directives = append(block.Directives, directive) + iDirectives = append(iDirectives, directive) + } + + if len(after) == 0 { + block.Directives = append(block.Directives, iDirectives...) + } else { + insertIndex := -1 + for i, d := range block.Directives { + if d.GetName() == after[0] { + insertIndex = i + 1 + break + } + } + if insertIndex == -1 { + return fmt.Errorf("after directive %s not found", after[0]) + } + + block.Directives = append( + block.Directives[:insertIndex], + append(iDirectives, block.Directives[insertIndex:]...)..., + ) } return nil @@ -157,7 +178,6 @@ func (p *Parser) Sort() { } func (p *Parser) Dump() string { - p.Sort() return dumper.DumpConfig(p.c, dumper.IndentedStyle) } diff --git a/pkg/nginx/setter.go b/pkg/nginx/setter.go index 3e7927a4..7a9dafb2 100644 --- a/pkg/nginx/setter.go +++ b/pkg/nginx/setter.go @@ -222,7 +222,7 @@ func (p *Parser) SetHTTPS(cert, key string) error { Name: "ssl_early_data", Parameters: []config.Parameter{{Value: "on"}}, }, - }) + }, "root") } func (p *Parser) SetHTTPSProtocols(protocols []string) error { diff --git a/pkg/webserver/LICENSE b/pkg/webserver/LICENSE new file mode 100644 index 00000000..b4a1cc19 --- /dev/null +++ b/pkg/webserver/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 AcePanel + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/pkg/webserver/apache/ast.go b/pkg/webserver/apache/ast.go new file mode 100644 index 00000000..ca174228 --- /dev/null +++ b/pkg/webserver/apache/ast.go @@ -0,0 +1,243 @@ +package apache + +import ( + "strings" +) + +// Config Apache 配置文件的 AST 根节点 +type Config struct { + Directives []*Directive `json:"directives"` + VirtualHosts []*VirtualHost `json:"virtual_hosts"` + Comments []*Comment `json:"comments"` + Includes []*Include `json:"includes"` +} + +// Directive Apache 指令 +type Directive struct { + Name string `json:"name"` + Args []string `json:"args"` + Line int `json:"line"` + Column int `json:"column"` + Block *Block `json:"block,omitempty"` // 对于有块的指令如 +} + +// VirtualHost 虚拟主机配置 +type VirtualHost struct { + Name string `json:"name"` + Args []string `json:"args"` // 通常是 IP:Port + Line int `json:"line"` + Column int `json:"column"` + Directives []*Directive `json:"directives"` + Comments []*Comment `json:"comments"` // 虚拟主机内的注释 +} + +// Block 配置块,如 , 等 +type Block struct { + Type string `json:"type"` // Directory, Location, Files 等 + Args []string `json:"args"` // 块的参数 + Directives []*Directive `json:"directives"` + Comments []*Comment `json:"comments"` // 块内注释 + Line int `json:"line"` + Column int `json:"column"` +} + +// Comment 注释 +type Comment struct { + Text string `json:"text"` + Line int `json:"line"` + Column int `json:"column"` +} + +// Include 包含其他配置文件的指令 +type Include struct { + Path string `json:"path"` + Line int `json:"line"` + Column int `json:"column"` +} + +// GetDirective 根据名称查找指令 +func (c *Config) GetDirective(name string) *Directive { + for _, dir := range c.Directives { + if strings.EqualFold(dir.Name, name) { + return dir + } + } + return nil +} + +// GetDirectives 根据名称查找所有匹配的指令 +func (c *Config) GetDirectives(name string) []*Directive { + var result []*Directive + for _, dir := range c.Directives { + if strings.EqualFold(dir.Name, name) { + result = append(result, dir) + } + } + return result +} + +// GetVirtualHost 根据参数查找虚拟主机 +func (c *Config) GetVirtualHost(args ...string) *VirtualHost { + for _, vhost := range c.VirtualHosts { + if len(vhost.Args) == len(args) { + match := true + for i, arg := range args { + if vhost.Args[i] != arg { + match = false + break + } + } + if match { + return vhost + } + } + } + return nil +} + +// AddVirtualHost 添加新虚拟主机到配置 +func (c *Config) AddVirtualHost(args ...string) *VirtualHost { + vhost := &VirtualHost{ + Name: "VirtualHost", + Args: args, + Directives: make([]*Directive, 0), + } + c.VirtualHosts = append(c.VirtualHosts, vhost) + return vhost +} + +// AddDirective 为虚拟主机添加指令 +func (v *VirtualHost) AddDirective(name string, args ...string) *Directive { + directive := &Directive{ + Name: name, + Args: args, + } + v.Directives = append(v.Directives, directive) + return directive +} + +// GetDirective 在虚拟主机中根据名称查找指令 +func (v *VirtualHost) GetDirective(name string) *Directive { + for _, dir := range v.Directives { + if strings.EqualFold(dir.Name, name) { + return dir + } + } + return nil +} + +// GetDirectives 在虚拟主机中根据名称查找所有匹配的指令 +func (v *VirtualHost) GetDirectives(name string) []*Directive { + var result []*Directive + for _, dir := range v.Directives { + if strings.EqualFold(dir.Name, name) { + result = append(result, dir) + } + } + return result +} + +// SetDirective 设置指令(如果存在则更新,不存在则添加) +func (v *VirtualHost) SetDirective(name string, args ...string) *Directive { + // 查找现有指令 + for _, dir := range v.Directives { + if strings.EqualFold(dir.Name, name) { + dir.Args = args + return dir + } + } + // 不存在,添加新指令 + return v.AddDirective(name, args...) +} + +// RemoveDirective 删除指令 +func (v *VirtualHost) RemoveDirective(name string) bool { + for i, dir := range v.Directives { + if strings.EqualFold(dir.Name, name) { + v.Directives = append(v.Directives[:i], v.Directives[i+1:]...) + return true + } + } + return false +} + +// RemoveDirectives 删除所有匹配名称的指令 +func (v *VirtualHost) RemoveDirectives(name string) int { + count := 0 + newDirectives := make([]*Directive, 0, len(v.Directives)) + for _, dir := range v.Directives { + if strings.EqualFold(dir.Name, name) { + count++ + } else { + newDirectives = append(newDirectives, dir) + } + } + v.Directives = newDirectives + return count +} + +// HasDirective 检查是否存在指定指令 +func (v *VirtualHost) HasDirective(name string) bool { + return v.GetDirective(name) != nil +} + +// GetDirectiveValue 获取指令的第一个参数值 +func (v *VirtualHost) GetDirectiveValue(name string) string { + dir := v.GetDirective(name) + if dir != nil && len(dir.Args) > 0 { + return dir.Args[0] + } + return "" +} + +// GetDirectiveValues 获取指令的所有参数值 +func (v *VirtualHost) GetDirectiveValues(name string) []string { + dir := v.GetDirective(name) + if dir != nil { + return dir.Args + } + return nil +} + +// AddBlock 添加块指令(如 Directory, Location 等) +func (v *VirtualHost) AddBlock(blockType string, args ...string) *Directive { + block := &Block{ + Type: blockType, + Args: args, + Directives: make([]*Directive, 0), + Comments: make([]*Comment, 0), + } + directive := &Directive{ + Name: blockType, + Args: args, + Block: block, + } + v.Directives = append(v.Directives, directive) + return directive +} + +// GetBlock 获取块指令 +func (v *VirtualHost) GetBlock(blockType string, args ...string) *Block { + for _, dir := range v.Directives { + if dir.Block != nil && strings.EqualFold(dir.Block.Type, blockType) { + // 如果指定了参数,需要匹配 + if len(args) > 0 { + if len(dir.Block.Args) != len(args) { + continue + } + match := true + for i, arg := range args { + if dir.Block.Args[i] != arg { + match = false + break + } + } + if !match { + continue + } + } + return dir.Block + } + } + return nil +} diff --git a/pkg/webserver/apache/data.go b/pkg/webserver/apache/data.go new file mode 100644 index 00000000..0d409b0f --- /dev/null +++ b/pkg/webserver/apache/data.go @@ -0,0 +1,30 @@ +package apache + +// DisableConfName 禁用配置文件名 +const DisableConfName = "00-disable.conf" + +// DisableConfContent 禁用配置内容 +const DisableConfContent = `# 网站已停止 +RewriteEngine on +RewriteRule ^.*$ - [R=503,L] +` + +// DefaultVhostConf 默认配置模板 +const DefaultVhostConf = ` + ServerName localhost + DocumentRoot /opt/ace/sites/default/public + DirectoryIndex index.php index.html + + ErrorLog /opt/ace/sites/default/log/error.log + CustomLog /opt/ace/sites/default/log/access.log combined + + # custom configs + IncludeOptional /opt/ace/sites/default/config/server.d/*.conf + + + Options -Indexes +FollowSymLinks + AllowOverride All + Require all granted + + +` diff --git a/pkg/webserver/apache/export.go b/pkg/webserver/apache/export.go new file mode 100644 index 00000000..255bbd1f --- /dev/null +++ b/pkg/webserver/apache/export.go @@ -0,0 +1,347 @@ +package apache + +import ( + "fmt" + "sort" + "strings" +) + +// ExportOptions 定义导出选项 +type ExportOptions struct { + // IndentStyle 缩进样式:使用空格还是制表符 + IndentStyle string // "spaces" 或 "tabs" + + // IndentSize 缩进大小(仅当IndentStyle为"spaces"时有效) + IndentSize int + + // SortDirectives 是否对指令进行排序 + SortDirectives bool + + // IncludeComments 是否包含注释 + IncludeComments bool + + // PreserveEmptyLines 是否保留空行 + PreserveEmptyLines bool + + // FormatStyle 格式化风格 + FormatStyle string // "compact", "standard", "verbose" +} + +// DefaultExportOptions 返回默认的导出选项 +func DefaultExportOptions() *ExportOptions { + return &ExportOptions{ + IndentStyle: "spaces", + IndentSize: 4, + SortDirectives: false, + IncludeComments: true, + PreserveEmptyLines: true, + FormatStyle: "standard", + } +} + +// Export 导出整个配置为Apache配置文件格式 +func (c *Config) Export() string { + return c.ExportWithOptions(DefaultExportOptions()) +} + +// ExportWithOptions 使用指定选项导出配置 +func (c *Config) ExportWithOptions(options *ExportOptions) string { + var builder strings.Builder + var items []exportItem + + // 收集所有需要导出的项目 + // 添加全局注释 + if options.IncludeComments { + for _, comment := range c.Comments { + items = append(items, exportItem{ + line: comment.Line, + item: comment, + typ: "comment", + }) + } + } + + // 添加全局指令 + for _, directive := range c.Directives { + items = append(items, exportItem{ + line: directive.Line, + item: directive, + typ: "directive", + }) + } + + // 添加虚拟主机 + for _, vhost := range c.VirtualHosts { + items = append(items, exportItem{ + line: vhost.Line, + item: vhost, + typ: "virtualhost", + }) + } + + // 如果不需要保持原始顺序,按行号排序 + if !options.SortDirectives { + sort.Slice(items, func(i, j int) bool { + return items[i].line < items[j].line + }) + } + + // 导出所有项目 + for i, item := range items { + switch item.typ { + case "comment": + comment := item.item.(*Comment) + builder.WriteString(comment.ExportWithOptions(options, 0)) + case "directive": + directive := item.item.(*Directive) + builder.WriteString(directive.ExportWithOptions(options, 0)) + case "virtualhost": + vhost := item.item.(*VirtualHost) + builder.WriteString(vhost.ExportWithOptions(options, 0)) + } + + // 添加换行符 + if options.PreserveEmptyLines && i < len(items)-1 { + // 检查是否需要添加空行 + nextItem := items[i+1] + if shouldAddEmptyLine(item, nextItem, options) { + builder.WriteString("\n") + } + } + } + + return strings.TrimSpace(builder.String()) +} + +// ExportWithOptions 导出指令 +func (d *Directive) ExportWithOptions(options *ExportOptions, indent int) string { + var builder strings.Builder + + // 如果是块指令,使用Block的导出方法 + if d.Block != nil { + return d.Block.ExportWithOptions(options, indent) + } + + // 添加缩进 + builder.WriteString(getIndent(options, indent)) + + // 指令名称 + builder.WriteString(d.Name) + + // 指令参数 + if len(d.Args) > 0 { + builder.WriteString(" ") + for i, arg := range d.Args { + if i > 0 { + builder.WriteString(" ") + } + + // 如果参数包含空格,需要引用 + if strings.Contains(arg, " ") && !strings.HasPrefix(arg, "\"") { + builder.WriteString(fmt.Sprintf(`"%s"`, arg)) + } else { + builder.WriteString(arg) + } + } + } + + builder.WriteString("\n") + return builder.String() +} + +// ExportWithOptions 导出虚拟主机 +func (v *VirtualHost) ExportWithOptions(options *ExportOptions, indent int) string { + var builder strings.Builder + + // 开始标签 + builder.WriteString(getIndent(options, indent)) + builder.WriteString("<") + builder.WriteString(v.Name) + if len(v.Args) > 0 { + builder.WriteString(" ") + builder.WriteString(strings.Join(v.Args, " ")) + } + builder.WriteString(">\n") + + // 收集虚拟主机内的项目 + var items []exportItem + + // 添加注释 + if options.IncludeComments { + for _, comment := range v.Comments { + items = append(items, exportItem{ + line: comment.Line, + item: comment, + typ: "comment", + }) + } + } + + // 添加指令 + for _, directive := range v.Directives { + items = append(items, exportItem{ + line: directive.Line, + item: directive, + typ: "directive", + }) + } + + // 排序 + if !options.SortDirectives { + sort.Slice(items, func(i, j int) bool { + return items[i].line < items[j].line + }) + } + + // 导出虚拟主机内容 + for _, item := range items { + switch item.typ { + case "comment": + comment := item.item.(*Comment) + builder.WriteString(comment.ExportWithOptions(options, indent+1)) + case "directive": + directive := item.item.(*Directive) + builder.WriteString(directive.ExportWithOptions(options, indent+1)) + } + } + + // 结束标签 + builder.WriteString(getIndent(options, indent)) + builder.WriteString("\n") + + return builder.String() +} + +// ExportWithOptions 导出注释 +func (c *Comment) ExportWithOptions(options *ExportOptions, indent int) string { + var builder strings.Builder + + // 添加缩进 + builder.WriteString(getIndent(options, indent)) + + // 注释内容 + builder.WriteString("# ") + builder.WriteString(c.Text) + builder.WriteString("\n") + + return builder.String() +} + +// ExportWithOptions 导出块指令 +func (b *Block) ExportWithOptions(options *ExportOptions, indent int) string { + var builder strings.Builder + + // 开始标签 + builder.WriteString(getIndent(options, indent)) + builder.WriteString("<") + builder.WriteString(b.Type) + if len(b.Args) > 0 { + builder.WriteString(" ") + for i, arg := range b.Args { + if i > 0 { + builder.WriteString(" ") + } + // 如果参数包含空格,需要引用 + if strings.Contains(arg, " ") && !strings.HasPrefix(arg, "\"") { + builder.WriteString(fmt.Sprintf(`"%s"`, arg)) + } else { + builder.WriteString(arg) + } + } + } + builder.WriteString(">\n") + + // 块内指令和注释 + allItems := make([]exportItem, 0, len(b.Directives)+len(b.Comments)) + + // 添加指令 + for _, directive := range b.Directives { + allItems = append(allItems, exportItem{item: directive, line: directive.Line, typ: "directive"}) + } + + // 添加注释 + if options.IncludeComments { + for _, comment := range b.Comments { + allItems = append(allItems, exportItem{item: comment, line: comment.Line, typ: "comment"}) + } + } + + // 按行号排序 + if options.SortDirectives { + sort.Slice(allItems, func(i, j int) bool { + return allItems[i].line < allItems[j].line + }) + } + + // 导出所有项目 + for i, item := range allItems { + switch item.typ { + case "comment": + comment := item.item.(*Comment) + builder.WriteString(comment.ExportWithOptions(options, indent+1)) + case "directive": + directive := item.item.(*Directive) + builder.WriteString(directive.ExportWithOptions(options, indent+1)) + } + + // 添加空行(如果需要) + if options.PreserveEmptyLines && i < len(allItems)-1 { + nextItem := allItems[i+1] + if shouldAddEmptyLine(item, nextItem, options) { + builder.WriteString("\n") + } + } + } + + // 结束标签 + builder.WriteString(getIndent(options, indent)) + builder.WriteString("\n") + + return builder.String() +} + +// exportItem 用于排序的导出项目 +type exportItem struct { + line int + item interface{} + typ string +} + +// getIndent 获取缩进字符串 +func getIndent(options *ExportOptions, level int) string { + if level <= 0 { + return "" + } + + if options.IndentStyle == "tabs" { + return strings.Repeat("\t", level) + } + return strings.Repeat(" ", level*options.IndentSize) +} + +// shouldAddEmptyLine 判断是否应该添加空行 +func shouldAddEmptyLine(current, next exportItem, options *ExportOptions) bool { + if !options.PreserveEmptyLines { + return false + } + + // 根据格式化风格决定 + switch options.FormatStyle { + case "verbose": + return true + case "compact": + return false + case "standard": + // 在不同类型之间添加空行(除了注释) + if current.typ != next.typ && current.typ != "comment" && next.typ != "comment" { + return true + } + return false + } + + return false +} diff --git a/pkg/webserver/apache/lexer.go b/pkg/webserver/apache/lexer.go new file mode 100644 index 00000000..f627c392 --- /dev/null +++ b/pkg/webserver/apache/lexer.go @@ -0,0 +1,349 @@ +package apache + +import ( + "io" + "strings" + "unicode" +) + +// TokenType 表示 token 的类型 +type TokenType int + +const ( + ILLEGAL TokenType = iota + EOF + NEWLINE + COMMENT + DIRECTIVE + STRING + LBRACE // < + RBRACE // > + SLASH // / + COLON // : + SEMICOLON // ; + EQUAL // = + QUOTE // " + VIRTUALHOST + BLOCKDIRECTIVE // Directory, Location 等块指令 +) + +// Token 表示一个词法单元 +type Token struct { + Type TokenType + Value string + Line int + Column int +} + +// Lexer 词法分析器 +type Lexer struct { + current rune + line int + column int + buf []rune + pos int + content string +} + +// NewLexer 创建一个新的词法分析器 +func NewLexer(input io.Reader) (*Lexer, error) { + // 读取全部内容到字符串 + content := new(strings.Builder) + _, err := io.Copy(content, input) + if err != nil { + return nil, err + } + + l := &Lexer{ + line: 1, + column: 0, + content: content.String(), + buf: []rune(content.String()), + pos: -1, + } + + l.readChar() // 初始化第一个字符 + return l, nil +} + +// readChar 读取下一个字符 +func (l *Lexer) readChar() { + l.pos++ + if l.pos >= len(l.buf) { + l.current = 0 // EOF + } else { + l.current = l.buf[l.pos] + } + + l.column++ + if l.current == '\n' { + l.line++ + l.column = 0 + } +} + +// peekChar 预览下一个字符而不移动位置 +func (l *Lexer) peekChar() rune { + if l.pos+1 >= len(l.buf) { + return 0 + } + return l.buf[l.pos+1] +} + +// skipWhitespace 跳过空白字符 +func (l *Lexer) skipWhitespace() { + for l.current == ' ' || l.current == '\t' || l.current == '\r' { + l.readChar() + } +} + +// readString 读取字符串字面量 +func (l *Lexer) readString(delimiter rune) string { + var result strings.Builder + l.readChar() // 跳过开始的引号 + + for l.current != delimiter && l.current != 0 { + if l.current == '\\' { + l.readChar() + if l.current != 0 { + // 保持转义字符的原始形式 + result.WriteRune('\\') + result.WriteRune(l.current) + l.readChar() + } + } else { + result.WriteRune(l.current) + l.readChar() + } + } + + return result.String() +} + +// readIdentifier 读取标识符或指令名 +func (l *Lexer) readIdentifier() string { + var result strings.Builder + + for unicode.IsLetter(l.current) || unicode.IsDigit(l.current) || l.current == '_' || l.current == '-' || l.current == '.' || l.current == ':' || l.current == '/' || l.current == '$' || l.current == '@' || l.current == '%' || l.current == '{' || l.current == '}' || l.current == '?' || l.current == '&' || l.current == '=' || l.current == '+' { + result.WriteRune(l.current) + l.readChar() + } + + return result.String() +} + +// readWord 读取单词(可能包含特殊字符) +func (l *Lexer) readWord() string { + var result strings.Builder + + for l.current != 0 && l.current != ' ' && l.current != '\t' && l.current != '\n' && l.current != '\r' && + l.current != '<' && l.current != '>' && l.current != '"' && l.current != '\'' { + result.WriteRune(l.current) + l.readChar() + } + + return result.String() +} + +// readComment 读取注释 +func (l *Lexer) readComment() string { + var result strings.Builder + l.readChar() // 跳过 # + + // 跳过 # 后面的第一个空格(如果有的话) + if l.current == ' ' { + l.readChar() + } + + for l.current != '\n' && l.current != 0 { + result.WriteRune(l.current) + l.readChar() + } + + return result.String() +} + +// isVirtualHostDirective 检查是否是虚拟主机指令 +func (l *Lexer) isVirtualHostDirective(identifier string) bool { + return strings.EqualFold(identifier, "VirtualHost") +} + +// isBlockDirective 检查是否是块指令 +func (l *Lexer) isBlockDirective(identifier string) bool { + blockDirectives := []string{ + "Directory", "DirectoryMatch", "Location", "LocationMatch", + "Files", "FilesMatch", "Limit", "LimitExcept", "RequireAll", "RequireAny", "RequireNone", + "IfModule", "IfDefine", "IfVersion", "Proxy", + } + + for _, blockDir := range blockDirectives { + if strings.EqualFold(identifier, blockDir) { + return true + } + } + return false +} + +// NextToken 获取下一个 token +func (l *Lexer) NextToken() Token { + var tok Token + + l.skipWhitespace() + + tok.Line = l.line + tok.Column = l.column + + switch l.current { + case '#': + tok.Type = COMMENT + tok.Value = l.readComment() + case '\n': + tok.Type = NEWLINE + tok.Value = "\n" + l.readChar() + case '<': + // 检查是否是虚拟主机或目录块 + l.readChar() // 跳过 < + + // 检查是否是结束标签 + isClosing := false + if l.current == '/' { + isClosing = true + l.readChar() + } + + identifier := l.readIdentifier() + + // 如果无法读取到有效的标识符,这可能是无效语法 + if identifier == "" { + // 将此作为ILLEGAL token处理 + tok.Type = ILLEGAL + tok.Value = "<" + return tok + } + + // 跳过空白字符和参数 + l.skipWhitespace() + var args []string + for l.current != '>' && l.current != 0 { + // 记录当前位置,防止无限循环 + oldPos := l.pos + + if l.current == '"' || l.current == '\'' { + // 保留引号 + quoteChar := l.current + arg := string(quoteChar) + l.readString(l.current) + string(quoteChar) + args = append(args, arg) + l.readChar() // 跳过结束引号 + } else { + arg := l.readWord() + if arg != "" { + args = append(args, arg) + } + } + l.skipWhitespace() + + // 如果位置没有前进,说明遇到了无法处理的字符,退出循环防止死循环 + if l.pos == oldPos { + // 尝试跳过一个字符继续 + if l.current != 0 { + l.readChar() + } + break + } + } + + if l.current == '>' { + l.readChar() // 跳过 > + } + + if l.isVirtualHostDirective(identifier) { + tok.Type = VIRTUALHOST + if isClosing { + tok.Value = "/" + identifier + } else { + tok.Value = identifier + if len(args) > 0 { + tok.Value += " " + strings.Join(args, " ") + } + } + } else if l.isBlockDirective(identifier) { + // 识别为块指令 + tok.Type = BLOCKDIRECTIVE + if isClosing { + tok.Value = "/" + identifier + } else { + tok.Value = identifier + if len(args) > 0 { + tok.Value += " " + strings.Join(args, " ") + } + } + } else { + tok.Type = DIRECTIVE + if isClosing { + tok.Value = "/" + identifier + } else { + tok.Value = identifier + if len(args) > 0 { + tok.Value += " " + strings.Join(args, " ") + } + } + } + case '>': + tok.Type = RBRACE + tok.Value = ">" + l.readChar() + case '"': + tok.Type = STRING + // 保留引号 + tok.Value = `"` + l.readString('"') + `"` + l.readChar() + case '\'': + tok.Type = STRING + // 保留引号 + tok.Value = "'" + l.readString('\'') + "'" + l.readChar() + case 0: + tok.Type = EOF + tok.Value = "" + default: + if unicode.IsLetter(l.current) { + identifier := l.readIdentifier() + tok.Type = DIRECTIVE + tok.Value = identifier + } else { + // 读取其他类型的单词 + word := l.readWord() + if word != "" { + tok.Type = STRING + tok.Value = word + } else { + tok.Type = ILLEGAL + tok.Value = string(l.current) + l.readChar() + } + } + } + + return tok +} + +// PeekToken 预览下一个 token 而不移动位置 +func (l *Lexer) PeekToken() Token { + // 保存当前状态 + savedPos := l.pos + savedLine := l.line + savedColumn := l.column + savedCurrent := l.current + + // 获取下一个 token + token := l.NextToken() + + // 恢复状态 + l.pos = savedPos + l.line = savedLine + l.column = savedColumn + l.current = savedCurrent + + return token +} diff --git a/pkg/webserver/apache/options.go b/pkg/webserver/apache/options.go new file mode 100644 index 00000000..93c752c0 --- /dev/null +++ b/pkg/webserver/apache/options.go @@ -0,0 +1,27 @@ +package apache + +import ( + "os" +) + +// ParseOptions 定义解析器选项 +type ParseOptions struct { + // ProcessIncludes 是否处理Include指令,递归加载包含的文件 + ProcessIncludes bool + + // BaseDir 基础目录,用于解析相对路径的Include文件 + BaseDir string + + // MaxIncludeDepth 最大包含深度,防止无限递归 + MaxIncludeDepth int +} + +// DefaultParseOptions 返回默认的解析选项 +func DefaultParseOptions() *ParseOptions { + wd, _ := os.Getwd() + return &ParseOptions{ + ProcessIncludes: false, // 默认不处理Include + BaseDir: wd, + MaxIncludeDepth: 10, + } +} diff --git a/pkg/webserver/apache/parser.go b/pkg/webserver/apache/parser.go new file mode 100644 index 00000000..ecf71294 --- /dev/null +++ b/pkg/webserver/apache/parser.go @@ -0,0 +1,429 @@ +package apache + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// Parser Apache 配置文件解析器 +type Parser struct { + lexer *Lexer + options *ParseOptions + includeStack []string // 用于检测循环包含 + currentDepth int // 当前包含深度 + currentBaseDir string // 当前文件的基础目录 +} + +// NewParser 创建一个新的 Apache 配置解析器(使用默认选项) +func NewParser(input io.Reader) (*Parser, error) { + return NewParserWithOptions(input, DefaultParseOptions()) +} + +// NewParserWithOptions 创建一个带选项的 Apache 配置解析器 +func NewParserWithOptions(input io.Reader, options *ParseOptions) (*Parser, error) { + lexer, err := NewLexer(input) + if err != nil { + return nil, err + } + + return &Parser{ + lexer: lexer, + options: options, + includeStack: make([]string, 0), + currentDepth: 0, + currentBaseDir: options.BaseDir, + }, nil +} + +// Parse 解析 Apache 配置文件并返回 AST +func (p *Parser) Parse() (*Config, error) { + config := &Config{ + Directives: make([]*Directive, 0), + VirtualHosts: make([]*VirtualHost, 0), + Comments: make([]*Comment, 0), + } + + for { + token := p.lexer.NextToken() + if token.Type == EOF { + break + } + + switch token.Type { + case COMMENT: + comment := &Comment{ + Text: token.Value, + Line: token.Line, + Column: token.Column, + } + config.Comments = append(config.Comments, comment) + + case DIRECTIVE: + directive, err := p.parseDirective(token) + if err != nil { + return nil, fmt.Errorf("error parsing directive: %w", err) + } + + // 检查是否是Include类指令 + if p.options.ProcessIncludes && (strings.EqualFold(directive.Name, "Include") || strings.EqualFold(directive.Name, "IncludeOptional")) { + includeConfig, err := p.processInclude(directive) + if err != nil { + // 对于IncludeOptional,如果文件不存在则忽略错误 + if strings.EqualFold(directive.Name, "IncludeOptional") && os.IsNotExist(err) { + // 仍然记录Include指令,但不合并内容 + include := &Include{ + Path: directive.Args[0], + Line: directive.Line, + Column: directive.Column, + } + config.Includes = append(config.Includes, include) + continue + } + return nil, fmt.Errorf("error processing include file '%s': %w", directive.Args[0], err) + } + + if includeConfig != nil { + // 合并包含的配置到当前配置 + config.Directives = append(config.Directives, includeConfig.Directives...) + config.VirtualHosts = append(config.VirtualHosts, includeConfig.VirtualHosts...) + config.Comments = append(config.Comments, includeConfig.Comments...) + config.Includes = append(config.Includes, includeConfig.Includes...) + } + + // 记录Include指令 + include := &Include{ + Path: directive.Args[0], + Line: directive.Line, + Column: directive.Column, + } + config.Includes = append(config.Includes, include) + } else { + config.Directives = append(config.Directives, directive) + } + + case VIRTUALHOST: + vhost, err := p.parseVirtualHost(token) + if err != nil { + return nil, fmt.Errorf("error parsing virtual host: %w", err) + } + config.VirtualHosts = append(config.VirtualHosts, vhost) + + case BLOCKDIRECTIVE: + block, err := p.parseBlockDirective(token) + if err != nil { + return nil, fmt.Errorf("error parsing block directive: %w", err) + } + // 将块指令作为带Block的Directive添加到配置中 + directive := &Directive{ + Name: block.Type, + Args: block.Args, + Line: block.Line, + Column: block.Column, + Block: block, + } + config.Directives = append(config.Directives, directive) + + case NEWLINE: + // 跳过换行符 + continue + + case ILLEGAL: + return nil, fmt.Errorf("invalid syntax: '%s' at line %d, column %d", token.Value, token.Line, token.Column) + + default: + return nil, fmt.Errorf("unknown token type: %v at line %d, column %d", token.Type, token.Line, token.Column) + } + } + + return config, nil +} + +// parseDirective 解析单个指令 +func (p *Parser) parseDirective(token Token) (*Directive, error) { + directive := &Directive{ + Name: token.Value, + Line: token.Line, + Column: token.Column, + Args: make([]string, 0), + } + + // 读取指令参数 + for { + nextToken := p.lexer.PeekToken() + if nextToken.Type == NEWLINE || nextToken.Type == EOF { + break + } + + argToken := p.lexer.NextToken() + if argToken.Type == STRING || argToken.Type == DIRECTIVE { + directive.Args = append(directive.Args, argToken.Value) + } + } + + return directive, nil +} + +// parseVirtualHost 解析虚拟主机配置 +func (p *Parser) parseVirtualHost(token Token) (*VirtualHost, error) { + // 从 token.Value 中提取虚拟主机名称和参数 + parts := strings.Fields(token.Value) + vhost := &VirtualHost{ + Name: "VirtualHost", + Line: token.Line, + Column: token.Column, + Directives: make([]*Directive, 0), + Comments: make([]*Comment, 0), + } + + // 如果有参数,添加到 Args 中 + if len(parts) > 1 { + vhost.Args = parts[1:] + } + + // 跳过换行符到虚拟主机内容 + for { + nextToken := p.lexer.NextToken() + if nextToken.Type == EOF { + return nil, fmt.Errorf("unexpected end of virtual host") + } + // 检查结束标签(可能是VIRTUALHOST或DIRECTIVE类型) + if (nextToken.Type == VIRTUALHOST && strings.HasPrefix(nextToken.Value, "/VirtualHost")) || + (nextToken.Type == DIRECTIVE && strings.HasPrefix(nextToken.Value, "/VirtualHost")) { + // 遇到结束标签 + break + } + if nextToken.Type == NEWLINE { + continue + } + + if nextToken.Type == DIRECTIVE { + directive, err := p.parseDirective(nextToken) + if err != nil { + return nil, err + } + + // 检查是否是Include类指令 + if p.options.ProcessIncludes && (strings.EqualFold(directive.Name, "Include") || strings.EqualFold(directive.Name, "IncludeOptional")) { + includeConfig, err := p.processInclude(directive) + if err != nil { + // 对于IncludeOptional,如果文件不存在则忽略错误 + if strings.EqualFold(directive.Name, "IncludeOptional") && os.IsNotExist(err) { + continue + } + return nil, fmt.Errorf("error processing include file '%s': %w", directive.Args[0], err) + } + + if includeConfig != nil { + // 合并包含的配置到虚拟主机 + vhost.Directives = append(vhost.Directives, includeConfig.Directives...) + vhost.Comments = append(vhost.Comments, includeConfig.Comments...) + // 注意:虚拟主机内的Include不应该包含其他虚拟主机,但如果包含了也要处理 + if len(includeConfig.VirtualHosts) > 0 { + return nil, fmt.Errorf("include files inside virtual host cannot contain other virtual hosts") + } + } + } else { + vhost.Directives = append(vhost.Directives, directive) + } + } else if nextToken.Type == COMMENT { + // 收集虚拟主机内的注释 + comment := &Comment{ + Text: nextToken.Value, + Line: nextToken.Line, + Column: nextToken.Column, + } + vhost.Comments = append(vhost.Comments, comment) + } + } + + return vhost, nil +} + +// ParseFile 从文件解析 Apache 配置 +func ParseFile(filename string) (*Config, error) { + file, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer func(file *os.File) { _ = file.Close() }(file) + + parser, err := NewParser(file) + if err != nil { + return nil, err + } + + return parser.Parse() +} + +// ParseString 从字符串解析 Apache 配置 +func ParseString(content string) (*Config, error) { + reader := strings.NewReader(content) + parser, err := NewParser(reader) + if err != nil { + return nil, err + } + + return parser.Parse() +} + +// ParseStringWithOptions 从字符串解析 Apache 配置(带选项) +func ParseStringWithOptions(content string, options *ParseOptions) (*Config, error) { + reader := strings.NewReader(content) + parser, err := NewParserWithOptions(reader, options) + if err != nil { + return nil, err + } + + return parser.Parse() +} + +// ParseFileWithOptions 从文件解析 Apache 配置(带选项) +func ParseFileWithOptions(filename string, options *ParseOptions) (*Config, error) { + file, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer func(file *os.File) { _ = file.Close() }(file) + + // 设置基础目录为文件所在目录 + if options.BaseDir == "" { + options.BaseDir = filepath.Dir(filename) + } + + parser, err := NewParserWithOptions(file, options) + if err != nil { + return nil, err + } + + // 设置当前文件路径用于循环检测 + absPath, _ := filepath.Abs(filename) + parser.includeStack = append(parser.includeStack, absPath) + + return parser.Parse() +} + +// processInclude 处理Include指令 +func (p *Parser) processInclude(directive *Directive) (*Config, error) { + if len(directive.Args) == 0 { + return nil, fmt.Errorf("include directive missing file path argument") + } + + // 检查递归深度 + if p.currentDepth >= p.options.MaxIncludeDepth { + return nil, fmt.Errorf("include nesting depth exceeds limit %d", p.options.MaxIncludeDepth) + } + + includePath := directive.Args[0] + + // 解析包含文件的完整路径 + var fullPath string + var err error + + if filepath.IsAbs(includePath) { + fullPath = includePath + } else { + // 相对路径基于当前文件所在目录 + fullPath = filepath.Join(p.currentBaseDir, includePath) + } + + // 获取绝对路径用于循环检测 + absPath, err := filepath.Abs(fullPath) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path of file '%s': %w", fullPath, err) + } + + // 检查循环包含 + for _, stackPath := range p.includeStack { + if stackPath == absPath { + return nil, fmt.Errorf("circular include detected: %s", absPath) + } + } + + // 检查文件是否存在 + if _, err := os.Stat(fullPath); err != nil { + return nil, err + } + + // 打开并解析包含的文件 + file, err := os.Open(fullPath) + if err != nil { + return nil, fmt.Errorf("failed to open include file: %w", err) + } + defer func(file *os.File) { _ = file.Close() }(file) + + // 创建新的解析器选项,继承当前选项但更新基础目录 + includeOptions := *p.options + includeOptions.BaseDir = filepath.Dir(fullPath) + + // 创建新的解析器实例 + includeParser, err := NewParserWithOptions(file, &includeOptions) + if err != nil { + return nil, fmt.Errorf("failed to create include file parser: %w", err) + } + + // 设置包含解析器的状态 + includeParser.includeStack = append(p.includeStack, absPath) + includeParser.currentDepth = p.currentDepth + 1 + includeParser.currentBaseDir = filepath.Dir(fullPath) + + // 解析包含的文件 + return includeParser.Parse() +} + +// parseBlockDirective 解析块指令(如Directory, Location等) +func (p *Parser) parseBlockDirective(token Token) (*Block, error) { + // 从token.Value中提取块类型和参数 + parts := strings.Fields(token.Value) + block := &Block{ + Type: parts[0], + Line: token.Line, + Column: token.Column, + Directives: make([]*Directive, 0), + Comments: make([]*Comment, 0), + } + + // 如果有参数,添加到Args中 + if len(parts) > 1 { + block.Args = parts[1:] + } + + // 跳过换行符到块内容 + for { + nextToken := p.lexer.NextToken() + if nextToken.Type == EOF { + return nil, fmt.Errorf("unexpected end of block directive") + } + + // 检查结束标签 + if (nextToken.Type == BLOCKDIRECTIVE && strings.HasPrefix(nextToken.Value, "/"+block.Type)) || + (nextToken.Type == DIRECTIVE && strings.HasPrefix(nextToken.Value, "/"+block.Type)) { + // 遇到结束标签 + break + } + + if nextToken.Type == NEWLINE { + continue + } + + if nextToken.Type == DIRECTIVE { + directive, err := p.parseDirective(nextToken) + if err != nil { + return nil, err + } + block.Directives = append(block.Directives, directive) + } else if nextToken.Type == COMMENT { + // 处理块内注释 + comment := &Comment{ + Text: nextToken.Value, + Line: nextToken.Line, + Column: nextToken.Column, + } + block.Comments = append(block.Comments, comment) + } + } + + return block, nil +} diff --git a/pkg/webserver/apache/parser_test.go b/pkg/webserver/apache/parser_test.go new file mode 100644 index 00000000..b1c9b54f --- /dev/null +++ b/pkg/webserver/apache/parser_test.go @@ -0,0 +1,262 @@ +package apache + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/suite" +) + +type ParserTestSuite struct { + suite.Suite +} + +func TestParserTestSuite(t *testing.T) { + suite.Run(t, &ParserTestSuite{}) +} + +func (s *ParserTestSuite) TestParseSimpleDirective() { + input := "ServerName www.example.com" + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.Directives, 1) + + directive := config.Directives[0] + s.Equal("ServerName", directive.Name) + s.Equal([]string{"www.example.com"}, directive.Args) +} + +func (s *ParserTestSuite) TestParseDirectiveWithMultipleArgs() { + input := "Listen 192.168.1.100:80" + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.Directives, 1) + + directive := config.Directives[0] + s.Equal("Listen", directive.Name) + s.Equal([]string{"192.168.1.100:80"}, directive.Args) +} + +func (s *ParserTestSuite) TestParseComment() { + input := "# This is a comment\nServerName www.example.com" + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.Comments, 1) + s.Len(config.Directives, 1) + + comment := config.Comments[0] + s.Equal("This is a comment", comment.Text) + s.Equal(1, comment.Line) + + directive := config.Directives[0] + s.Equal("ServerName", directive.Name) +} + +func (s *ParserTestSuite) TestParseVirtualHost() { + input := ` + ServerName www.example.com + DocumentRoot /var/www/html +` + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.VirtualHosts, 1) + + vhost := config.VirtualHosts[0] + s.Equal("VirtualHost", vhost.Name) + s.Equal([]string{"*:80"}, vhost.Args) + s.Len(vhost.Directives, 2) + + serverName := vhost.Directives[0] + s.Equal("ServerName", serverName.Name) + s.Equal([]string{"www.example.com"}, serverName.Args) + + docRoot := vhost.Directives[1] + s.Equal("DocumentRoot", docRoot.Name) + s.Equal([]string{"/var/www/html"}, docRoot.Args) +} + +func (s *ParserTestSuite) TestParseComplexConfig() { + input := `# Apache 配置示例 +ServerRoot /etc/apache2 +ServerName www.example.com:80 + +# SSL 配置 +LoadModule ssl_module modules/mod_ssl.so + + + ServerName www.example.com + DocumentRoot /var/www/html + ErrorLog logs/error.log + CustomLog logs/access.log common + + + + ServerName www.example.com + DocumentRoot /var/www/html + SSLEngine on + SSLCertificateFile /path/to/certificate.crt + SSLCertificateKeyFile /path/to/private.key +` + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + // 检查注释 + s.Len(config.Comments, 2) + s.Equal("Apache 配置示例", config.Comments[0].Text) + s.Equal("SSL 配置", config.Comments[1].Text) + + // 检查全局指令 + s.Len(config.Directives, 3) + s.Equal("ServerRoot", config.Directives[0].Name) + s.Equal("ServerName", config.Directives[1].Name) + s.Equal("LoadModule", config.Directives[2].Name) + + // 检查虚拟主机 + s.Len(config.VirtualHosts, 2) + + // HTTP 虚拟主机 + httpVhost := config.VirtualHosts[0] + s.Equal([]string{"*:80"}, httpVhost.Args) + s.Len(httpVhost.Directives, 4) + + // HTTPS 虚拟主机 + httpsVhost := config.VirtualHosts[1] + s.Equal([]string{"*:443"}, httpsVhost.Args) + s.Len(httpsVhost.Directives, 5) + + // 检查 SSL 指令 + sslEngine := httpsVhost.Directives[2] + s.Equal("SSLEngine", sslEngine.Name) + s.Equal([]string{"on"}, sslEngine.Args) +} + +func (s *ParserTestSuite) TestParseQuotedStrings() { + input := `ServerName "www.example.com" +CustomLog "/var/log/apache2/access.log" combined` + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.Directives, 2) + + serverName := config.Directives[0] + s.Equal("ServerName", serverName.Name) + s.Equal([]string{"\"www.example.com\""}, serverName.Args) + + customLog := config.Directives[1] + s.Equal("CustomLog", customLog.Name) + s.Equal([]string{"\"/var/log/apache2/access.log\"", "combined"}, customLog.Args) +} + +func (s *ParserTestSuite) TestParseEmptyConfig() { + input := "" + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.Directives, 0) + s.Len(config.VirtualHosts, 0) + s.Len(config.Comments, 0) +} + +func (s *ParserTestSuite) TestParseOnlyComments() { + input := `# 这是第一个注释 +# 这是第二个注释` + + config, err := ParseString(input) + s.NoError(err) + s.NotNil(config) + + s.Len(config.Comments, 2) + s.Len(config.Directives, 0) + s.Len(config.VirtualHosts, 0) + + s.Equal("这是第一个注释", config.Comments[0].Text) + s.Equal("这是第二个注释", config.Comments[1].Text) +} + +func (s *ParserTestSuite) TestConfigGetMethods() { + input := `ServerName www.example.com +ServerAdmin admin@example.com +ServerName backup.example.com + + + ServerName www.example.com + DocumentRoot /var/www/html + + + + ServerName www.example.com + DocumentRoot /var/www/secure +` + + config, err := ParseString(input) + s.NoError(err) + + // 测试 GetDirective + serverName := config.GetDirective("ServerName") + s.NotNil(serverName) + s.Equal("ServerName", serverName.Name) + s.Equal([]string{"www.example.com"}, serverName.Args) + + // 测试 GetDirectives + serverNames := config.GetDirectives("ServerName") + s.Len(serverNames, 2) + + // 测试 GetVirtualHost + vhost := config.GetVirtualHost("*:80") + s.NotNil(vhost) + s.Equal([]string{"*:80"}, vhost.Args) + + // 测试虚拟主机中的 GetDirective + vhostServerName := vhost.GetDirective("ServerName") + s.NotNil(vhostServerName) + s.Equal([]string{"www.example.com"}, vhostServerName.Args) +} + +func (s *ParserTestSuite) TestLexerTokens() { + input := `# Comment +ServerName www.example.com + + DocumentRoot "/var/www/html" +` + + lexer, err := NewLexer(strings.NewReader(input)) + s.NoError(err) + + // 测试第一个 token - 注释 + token := lexer.NextToken() + s.Equal(COMMENT, token.Type) + s.Equal("Comment", token.Value) + s.Equal(1, token.Line) + + // 跳过换行 + token = lexer.NextToken() + s.Equal(NEWLINE, token.Type) + + // 测试指令 + token = lexer.NextToken() + s.Equal(DIRECTIVE, token.Type) + s.Equal("ServerName", token.Value) + + // 测试参数 + token = lexer.NextToken() + s.Equal(DIRECTIVE, token.Type) + s.Equal("www.example.com", token.Value) +} diff --git a/pkg/webserver/apache/vhost.go b/pkg/webserver/apache/vhost.go new file mode 100644 index 00000000..898f944f --- /dev/null +++ b/pkg/webserver/apache/vhost.go @@ -0,0 +1,654 @@ +package apache + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + + "github.com/acepanel/panel/pkg/webserver/types" +) + +// Vhost Apache 虚拟主机实现 +type Vhost struct { + config *Config + vhost *VirtualHost + configDir string // 配置目录 +} + +// NewVhost 创建 Apache 虚拟主机实例 +// configDir: 配置目录路径 +func NewVhost(configDir string) (*Vhost, error) { + v := &Vhost{ + configDir: configDir, + } + + // 加载配置 + var config *Config + var err error + + if v.configDir != "" { + // 从配置目录加载主配置文件 + configFile := filepath.Join(v.configDir, "apache.conf") + if _, statErr := os.Stat(configFile); statErr == nil { + config, err = ParseFile(configFile) + if err != nil { + return nil, fmt.Errorf("failed to parse apache config: %w", err) + } + } + } + + // 如果没有配置文件,使用默认配置 + if config == nil { + config, err = ParseString(DefaultVhostConf) + if err != nil { + return nil, fmt.Errorf("failed to parse default config: %w", err) + } + } + + v.config = config + + // 获取第一个虚拟主机 + if len(config.VirtualHosts) > 0 { + v.vhost = config.VirtualHosts[0] + } else { + // 创建默认虚拟主机 + v.vhost = config.AddVirtualHost("*:80") + } + + return v, nil +} + +// ========== VhostCore 接口实现 ========== + +func (v *Vhost) Enable() bool { + // 检查禁用配置文件是否存在 + disableFile := filepath.Join(v.configDir, "server.d", DisableConfName) + _, err := os.Stat(disableFile) + return os.IsNotExist(err) +} + +func (v *Vhost) SetEnable(enable bool, _ ...string) error { + serverDir := filepath.Join(v.configDir, "server.d") + disableFile := filepath.Join(serverDir, DisableConfName) + + if enable { + // 启用:删除禁用配置文件 + if err := os.Remove(disableFile); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove disable config: %w", err) + } + return nil + } + + // 禁用:创建禁用配置文件 + // 确保目录存在 + if err := os.MkdirAll(serverDir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + // 写入禁用配置 + if err := os.WriteFile(disableFile, []byte(DisableConfContent), 0644); err != nil { + return fmt.Errorf("写入禁用配置失败: %w", err) + } + + return nil +} + +func (v *Vhost) Listen() []types.Listen { + var result []types.Listen + + // Apache 的监听配置通常在 VirtualHost 的参数中 + // 例如: + for _, arg := range v.vhost.Args { + listen := types.Listen{ + Address: arg, + Options: make(map[string]string), + } + + // 检查是否是 HTTPS + if strings.Contains(arg, ":443") || v.HTTPS() { + listen.Protocol = "https" + } else { + listen.Protocol = "http" + } + + result = append(result, listen) + } + + return result +} + +func (v *Vhost) SetListen(listens []types.Listen) error { + var args []string + for _, l := range listens { + args = append(args, l.Address) + } + v.vhost.Args = args + return nil +} + +func (v *Vhost) ServerName() []string { + var names []string + + // 获取 ServerName + serverName := v.vhost.GetDirectiveValue("ServerName") + if serverName != "" { + names = append(names, serverName) + } + + // 获取 ServerAlias(可能有多个值) + aliases := v.vhost.GetDirectives("ServerAlias") + for _, alias := range aliases { + names = append(names, alias.Args...) + } + + return names +} + +func (v *Vhost) SetServerName(serverName []string) error { + if len(serverName) == 0 { + return nil + } + + // 设置主域名 + v.vhost.SetDirective("ServerName", serverName[0]) + + // 删除现有的 ServerAlias + v.vhost.RemoveDirectives("ServerAlias") + + // 设置别名 + if len(serverName) > 1 { + v.vhost.AddDirective("ServerAlias", serverName[1:]...) + } + + return nil +} + +func (v *Vhost) Index() []string { + values := v.vhost.GetDirectiveValues("DirectoryIndex") + if values != nil { + return values + } + return nil +} + +func (v *Vhost) SetIndex(index []string) error { + if len(index) == 0 { + v.vhost.RemoveDirective("DirectoryIndex") + return nil + } + v.vhost.SetDirective("DirectoryIndex", index...) + return nil +} + +func (v *Vhost) Root() string { + return v.vhost.GetDirectiveValue("DocumentRoot") +} + +func (v *Vhost) SetRoot(root string) error { + v.vhost.SetDirective("DocumentRoot", root) + + // 同时更新 Directory 块 + dirBlock := v.vhost.GetBlock("Directory") + if dirBlock != nil { + // 更新现有的 Directory 块路径 + dirBlock.Args = []string{root} + } else { + // 添加新的 Directory 块 + block := v.vhost.AddBlock("Directory", root) + if block.Block != nil { + block.Block.Directives = append(block.Block.Directives, + &Directive{Name: "Options", Args: []string{"-Indexes", "+FollowSymLinks"}}, + &Directive{Name: "AllowOverride", Args: []string{"All"}}, + &Directive{Name: "Require", Args: []string{"all", "granted"}}, + ) + } + } + + return nil +} + +func (v *Vhost) Includes() []types.IncludeFile { + var result []types.IncludeFile + + // 获取所有 Include 和 IncludeOptional 指令 + for _, dir := range v.vhost.GetDirectives("Include") { + if len(dir.Args) > 0 { + result = append(result, types.IncludeFile{ + Path: dir.Args[0], + }) + } + } + for _, dir := range v.vhost.GetDirectives("IncludeOptional") { + if len(dir.Args) > 0 { + result = append(result, types.IncludeFile{ + Path: dir.Args[0], + }) + } + } + + return result +} + +func (v *Vhost) SetIncludes(includes []types.IncludeFile) error { + // 删除现有的 Include 指令 + v.vhost.RemoveDirectives("Include") + v.vhost.RemoveDirectives("IncludeOptional") + + // 添加新的 Include 指令 + for _, inc := range includes { + v.vhost.AddDirective("Include", inc.Path) + } + + return nil +} + +func (v *Vhost) AccessLog() string { + return v.vhost.GetDirectiveValue("CustomLog") +} + +func (v *Vhost) SetAccessLog(accessLog string) error { + v.vhost.SetDirective("CustomLog", accessLog, "combined") + return nil +} + +func (v *Vhost) ErrorLog() string { + return v.vhost.GetDirectiveValue("ErrorLog") +} + +func (v *Vhost) SetErrorLog(errorLog string) error { + v.vhost.SetDirective("ErrorLog", errorLog) + return nil +} + +func (v *Vhost) Save() error { + if v.configDir == "" { + return fmt.Errorf("配置目录为空,无法保存") + } + + configFile := filepath.Join(v.configDir, "apache.conf") + content := v.config.Export() + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + return fmt.Errorf("保存配置文件失败: %w", err) + } + + return nil +} + +func (v *Vhost) Reload() error { + // 重载 Apache 配置 + cmds := []string{ + "/opt/ace/apps/apache/bin/apachectl graceful", + "/usr/sbin/apachectl graceful", + "apachectl graceful", + "systemctl reload apache2", + "systemctl reload httpd", + } + + var lastErr error + for _, cmd := range cmds { + parts := strings.Fields(cmd) + if len(parts) < 1 { + continue + } + + // 检查命令是否存在 + if _, err := os.Stat(parts[0]); err == nil || parts[0] == "systemctl" { + err := exec.Command(parts[0], parts[1:]...).Run() + if err == nil { + return nil + } + lastErr = err + } + } + + if lastErr != nil { + return fmt.Errorf("重载 Apache 配置失败: %w", lastErr) + } + return fmt.Errorf("未找到 apachectl 或 apache2 命令") +} + +func (v *Vhost) Reset() error { + // 重置配置为默认值 + config, err := ParseString(DefaultVhostConf) + if err != nil { + return fmt.Errorf("重置配置失败: %w", err) + } + + v.config = config + if len(config.VirtualHosts) > 0 { + v.vhost = config.VirtualHosts[0] + } + + return nil +} + +// ========== VhostSSL 接口实现 ========== + +func (v *Vhost) HTTPS() bool { + // 检查是否有 SSL 相关配置 + return v.vhost.HasDirective("SSLEngine") && + strings.EqualFold(v.vhost.GetDirectiveValue("SSLEngine"), "on") +} + +func (v *Vhost) SSLConfig() *types.SSLConfig { + if !v.HTTPS() { + return nil + } + + config := &types.SSLConfig{ + Cert: v.vhost.GetDirectiveValue("SSLCertificateFile"), + Key: v.vhost.GetDirectiveValue("SSLCertificateKeyFile"), + } + + // 获取协议 + protocols := v.vhost.GetDirectiveValues("SSLProtocol") + if protocols != nil { + config.Protocols = protocols + } + + // 获取加密套件 + config.Ciphers = v.vhost.GetDirectiveValue("SSLCipherSuite") + + // 检查 HSTS + headers := v.vhost.GetDirectives("Header") + for _, h := range headers { + if len(h.Args) >= 3 && strings.Contains(strings.Join(h.Args, " "), "Strict-Transport-Security") { + config.HSTS = true + break + } + } + + // 检查 OCSP + config.OCSP = strings.EqualFold(v.vhost.GetDirectiveValue("SSLUseStapling"), "on") + + // 检查 HTTP 重定向(通常在 HTTP 虚拟主机中配置) + redirects := v.vhost.GetDirectives("RewriteRule") + for _, r := range redirects { + if len(r.Args) >= 2 && strings.Contains(strings.Join(r.Args, " "), "https://") { + config.HTTPRedirect = true + break + } + } + + return config +} + +func (v *Vhost) SetSSLConfig(cfg *types.SSLConfig) error { + if cfg == nil { + return fmt.Errorf("SSL 配置不能为空") + } + + // 启用 SSL + v.vhost.SetDirective("SSLEngine", "on") + + // 设置证书 + if cfg.Cert != "" { + v.vhost.SetDirective("SSLCertificateFile", cfg.Cert) + } + if cfg.Key != "" { + v.vhost.SetDirective("SSLCertificateKeyFile", cfg.Key) + } + + // 设置协议 + if len(cfg.Protocols) > 0 { + v.vhost.SetDirective("SSLProtocol", cfg.Protocols...) + } else { + v.vhost.SetDirective("SSLProtocol", "all", "-SSLv2", "-SSLv3", "-TLSv1", "-TLSv1.1") + } + + // 设置加密套件 + if cfg.Ciphers != "" { + v.vhost.SetDirective("SSLCipherSuite", cfg.Ciphers) + } else { + v.vhost.SetDirective("SSLCipherSuite", "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384") + } + + // 设置 HSTS + if cfg.HSTS { + // 只移除现有的 HSTS Header,保留其他 Header + newDirectives := make([]*Directive, 0, len(v.vhost.Directives)) + for _, dir := range v.vhost.Directives { + if strings.EqualFold(dir.Name, "Header") { + if len(dir.Args) >= 3 && strings.Contains(strings.Join(dir.Args, " "), "Strict-Transport-Security") { + continue + } + } + newDirectives = append(newDirectives, dir) + } + v.vhost.Directives = newDirectives + v.vhost.AddDirective("Header", "always", "set", "Strict-Transport-Security", `"max-age=31536000"`) + } + + // 设置 OCSP + if cfg.OCSP { + v.vhost.SetDirective("SSLUseStapling", "on") + } + + // 设置 HTTP 重定向(需要 mod_rewrite) + if cfg.HTTPRedirect { + v.vhost.SetDirective("RewriteEngine", "on") + v.vhost.AddDirective("RewriteCond", "%{HTTPS}", "off") + v.vhost.AddDirective("RewriteRule", "^(.*)$", "https://%{HTTP_HOST}%{REQUEST_URI}", "[R=301,L]") + } + + // 更新监听端口为 443 + hasSSLPort := false + for _, arg := range v.vhost.Args { + if strings.Contains(arg, ":443") { + hasSSLPort = true + break + } + } + if !hasSSLPort { + v.vhost.Args = append(v.vhost.Args, "*:443") + } + + return nil +} + +func (v *Vhost) ClearHTTPS() error { + // 移除 SSL 相关指令 + v.vhost.RemoveDirective("SSLEngine") + v.vhost.RemoveDirective("SSLCertificateFile") + v.vhost.RemoveDirective("SSLCertificateKeyFile") + v.vhost.RemoveDirective("SSLProtocol") + v.vhost.RemoveDirective("SSLCipherSuite") + v.vhost.RemoveDirective("SSLUseStapling") + + // 只移除 HSTS 相关的 Header 指令 + newDirectives := make([]*Directive, 0, len(v.vhost.Directives)) + for _, dir := range v.vhost.Directives { + if strings.EqualFold(dir.Name, "Header") { + // 检查是否是 HSTS Header + if len(dir.Args) >= 3 && strings.Contains(strings.Join(dir.Args, " "), "Strict-Transport-Security") { + continue // 跳过 HSTS Header + } + } + newDirectives = append(newDirectives, dir) + } + v.vhost.Directives = newDirectives + + // 移除重定向规则 + v.vhost.RemoveDirective("RewriteEngine") + v.vhost.RemoveDirectives("RewriteCond") + v.vhost.RemoveDirectives("RewriteRule") + + // 更新监听端口,移除 443 + var newArgs []string + for _, arg := range v.vhost.Args { + if !strings.Contains(arg, ":443") { + newArgs = append(newArgs, arg) + } + } + if len(newArgs) == 0 { + newArgs = []string{"*:80"} + } + v.vhost.Args = newArgs + + return nil +} + +// ========== VhostPHP 接口实现 ========== + +func (v *Vhost) PHP() int { + // Apache 通常通过 FilesMatch 块配置 PHP + // 或者通过 SetHandler 指令 + handler := v.vhost.GetDirectiveValue("SetHandler") + if handler != "" && strings.Contains(handler, "php") { + // 尝试从 handler 中提取版本号 + // 例如: proxy:unix:/run/php/php8.4-fpm.sock|fcgi://localhost + if idx := strings.Index(handler, "php"); idx != -1 { + versionStr := "" + for i := idx + 3; i < len(handler); i++ { + c := handler[i] + if (c >= '0' && c <= '9') || c == '.' { + versionStr += string(c) + } else { + break + } + } + if versionStr != "" { + // 转换版本号,如 "8.4" -> 84 + parts := strings.Split(versionStr, ".") + if len(parts) >= 2 { + major, _ := strconv.Atoi(parts[0]) + minor, _ := strconv.Atoi(parts[1]) + return major*10 + minor + } + } + } + // 如果有 PHP 处理器但无法确定版本,返回默认值 + return 1 + } + + // 检查 FilesMatch 块中的 PHP 配置 + for _, dir := range v.vhost.Directives { + if dir.Block != nil && strings.EqualFold(dir.Block.Type, "FilesMatch") { + for _, d := range dir.Block.Directives { + if strings.EqualFold(d.Name, "SetHandler") && len(d.Args) > 0 { + if strings.Contains(d.Args[0], "php") { + return 1 // 有 PHP,但版本未知 + } + } + } + } + } + + return 0 +} + +func (v *Vhost) SetPHP(version int) error { + // 移除现有的 PHP 配置 + v.vhost.RemoveDirective("SetHandler") + + // 移除 FilesMatch 块中的 PHP 配置 + for i := len(v.vhost.Directives) - 1; i >= 0; i-- { + dir := v.vhost.Directives[i] + if dir.Block != nil && strings.EqualFold(dir.Block.Type, "FilesMatch") { + if len(dir.Block.Args) > 0 && strings.Contains(dir.Block.Args[0], "php") { + v.vhost.Directives = append(v.vhost.Directives[:i], v.vhost.Directives[i+1:]...) + } + } + } + + if version == 0 { + return nil // 禁用 PHP + } + + // 添加 PHP-FPM 配置 + major := version / 10 + minor := version % 10 + socketPath := fmt.Sprintf("/run/php/php%d.%d-fpm.sock", major, minor) + + // 添加 FilesMatch 块 + block := v.vhost.AddBlock("FilesMatch", `\.php$`) + if block.Block != nil { + block.Block.Directives = append(block.Block.Directives, + &Directive{ + Name: "SetHandler", + Args: []string{fmt.Sprintf("proxy:unix:%s|fcgi://localhost", socketPath)}, + }, + ) + } + + return nil +} + +// ========== VhostAdvanced 接口实现 ========== + +func (v *Vhost) RateLimit() *types.RateLimit { + // Apache 使用 mod_ratelimit + rate := v.vhost.GetDirectiveValue("SetOutputFilter") + if rate != "RATE_LIMIT" { + return nil + } + + rateLimit := &types.RateLimit{ + Options: make(map[string]string), + } + + // 获取速率限制值 + rateValue := v.vhost.GetDirectiveValue("SetEnv") + if rateValue != "" { + rateLimit.Rate = rateValue + } + + return rateLimit +} + +func (v *Vhost) SetRateLimit(limit *types.RateLimit) error { + if limit == nil { + // 清除限速配置 + v.vhost.RemoveDirective("SetOutputFilter") + v.vhost.RemoveDirectives("SetEnv") + return nil + } + + // 设置 mod_ratelimit + v.vhost.SetDirective("SetOutputFilter", "RATE_LIMIT") + if limit.Rate != "" { + v.vhost.SetDirective("SetEnv", "rate-limit", limit.Rate) + } + + return nil +} + +func (v *Vhost) BasicAuth() map[string]string { + authType := v.vhost.GetDirectiveValue("AuthType") + if authType == "" || !strings.EqualFold(authType, "Basic") { + return nil + } + + return map[string]string{ + "realm": v.vhost.GetDirectiveValue("AuthName"), + "user_file": v.vhost.GetDirectiveValue("AuthUserFile"), + } +} + +func (v *Vhost) SetBasicAuth(auth map[string]string) error { + if auth == nil || len(auth) == 0 { + // 清除基本认证配置 + v.vhost.RemoveDirective("AuthType") + v.vhost.RemoveDirective("AuthName") + v.vhost.RemoveDirective("AuthUserFile") + v.vhost.RemoveDirective("Require") + return nil + } + + realm := auth["realm"] + userFile := auth["user_file"] + + if realm == "" { + realm = "Restricted" + } + + v.vhost.SetDirective("AuthType", "Basic") + v.vhost.SetDirective("AuthName", fmt.Sprintf(`"%s"`, realm)) + v.vhost.SetDirective("AuthUserFile", userFile) + v.vhost.SetDirective("Require", "valid-user") + + return nil +} diff --git a/pkg/webserver/apache/vhost_test.go b/pkg/webserver/apache/vhost_test.go new file mode 100644 index 00000000..c812e516 --- /dev/null +++ b/pkg/webserver/apache/vhost_test.go @@ -0,0 +1,387 @@ +package apache + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/acepanel/panel/pkg/webserver/types" +) + +type VhostTestSuite struct { + suite.Suite + vhost *Vhost + configDir string +} + +func TestVhostTestSuite(t *testing.T) { + suite.Run(t, &VhostTestSuite{}) +} + +func (s *VhostTestSuite) SetupTest() { + // 创建临时配置目录 + configDir, err := os.MkdirTemp("", "apache-test-*") + s.Require().NoError(err) + s.configDir = configDir + + // 创建 server.d 目录 + err = os.MkdirAll(filepath.Join(configDir, "server.d"), 0755) + s.Require().NoError(err) + + vhost, err := NewVhost(configDir) + s.Require().NoError(err) + s.Require().NotNil(vhost) + s.vhost = vhost +} + +func (s *VhostTestSuite) TearDownTest() { + // 清理临时目录 + if s.configDir != "" { + s.NoError(os.RemoveAll(s.configDir)) + } +} + +func (s *VhostTestSuite) TestNewVhost() { + s.Equal(s.configDir, s.vhost.configDir) + s.NotNil(s.vhost.config) + s.NotNil(s.vhost.vhost) +} + +func (s *VhostTestSuite) TestEnable() { + // 默认应该是启用状态(没有 00-disable.conf) + s.True(s.vhost.Enable()) + + // 禁用网站 + err := s.vhost.SetEnable(false) + s.NoError(err) + s.False(s.vhost.Enable()) + + // 验证禁用文件存在 + disableFile := filepath.Join(s.configDir, "server.d", DisableConfName) + _, err = os.Stat(disableFile) + s.NoError(err) + + // 重新启用 + err = s.vhost.SetEnable(true) + s.NoError(err) + s.True(s.vhost.Enable()) + + // 验证禁用文件已删除 + _, err = os.Stat(disableFile) + s.True(os.IsNotExist(err)) +} + +func (s *VhostTestSuite) TestDisableConfigContent() { + // 禁用网站 + err := s.vhost.SetEnable(false) + s.NoError(err) + + // 读取禁用配置内容 + disableFile := filepath.Join(s.configDir, "server.d", DisableConfName) + content, err := os.ReadFile(disableFile) + s.NoError(err) + + // 验证内容包含 503 返回 + s.Contains(string(content), "503") + s.Contains(string(content), "RewriteRule") +} + +func (s *VhostTestSuite) TestServerName() { + names := []string{"example.com", "www.example.com", "api.example.com"} + err := s.vhost.SetServerName(names) + s.NoError(err) + + got := s.vhost.ServerName() + s.Len(got, 3) + s.Equal("example.com", got[0]) + s.Equal("www.example.com", got[1]) + s.Equal("api.example.com", got[2]) +} + +func (s *VhostTestSuite) TestServerNameEmpty() { + err := s.vhost.SetServerName([]string{}) + s.NoError(err) +} + +func (s *VhostTestSuite) TestRoot() { + root := "/var/www/html" + err := s.vhost.SetRoot(root) + s.NoError(err) + s.Equal(root, s.vhost.Root()) +} + +func (s *VhostTestSuite) TestIndex() { + index := []string{"index.html", "index.php", "default.html"} + err := s.vhost.SetIndex(index) + s.NoError(err) + + got := s.vhost.Index() + s.Len(got, 3) + s.Equal(index, got) +} + +func (s *VhostTestSuite) TestIndexEmpty() { + err := s.vhost.SetIndex([]string{}) + s.NoError(err) + s.Nil(s.vhost.Index()) +} + +func (s *VhostTestSuite) TestListen() { + listens := []types.Listen{ + {Address: "*:80", Protocol: "http"}, + {Address: "*:443", Protocol: "https"}, + } + err := s.vhost.SetListen(listens) + s.NoError(err) + + got := s.vhost.Listen() + s.Len(got, 2) + s.Equal("*:80", got[0].Address) + s.Equal("*:443", got[1].Address) +} + +func (s *VhostTestSuite) TestHTTPS() { + s.False(s.vhost.HTTPS()) + s.Nil(s.vhost.SSLConfig()) +} + +func (s *VhostTestSuite) TestSetSSLConfig() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + HSTS: true, + OCSP: true, + } + err := s.vhost.SetSSLConfig(sslConfig) + s.NoError(err) + + s.True(s.vhost.HTTPS()) + + got := s.vhost.SSLConfig() + s.NotNil(got) + s.Equal(sslConfig.Cert, got.Cert) + s.Equal(sslConfig.Key, got.Key) + s.True(got.HSTS) + s.True(got.OCSP) +} + +func (s *VhostTestSuite) TestSetSSLConfigNil() { + err := s.vhost.SetSSLConfig(nil) + s.Error(err) +} + +func (s *VhostTestSuite) TestClearHTTPS() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + HSTS: true, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + s.True(s.vhost.HTTPS()) + + err := s.vhost.ClearHTTPS() + s.NoError(err) + s.False(s.vhost.HTTPS()) +} + +func (s *VhostTestSuite) TestClearHTTPSPreservesOtherHeaders() { + // 添加一个非 HSTS 的 Header + s.vhost.vhost.AddDirective("Header", "set", "X-Custom-Header", "value") + + // 设置 SSL 和 HSTS + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + HSTS: true, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + // 清除 HTTPS + err := s.vhost.ClearHTTPS() + s.NoError(err) + + // 检查自定义 Header 是否保留 + headers := s.vhost.vhost.GetDirectives("Header") + s.NotEmpty(headers) + found := false + for _, h := range headers { + if len(h.Args) >= 2 && h.Args[1] == "X-Custom-Header" { + found = true + break + } + } + s.True(found, "自定义 Header 应该被保留") +} + +func (s *VhostTestSuite) TestPHP() { + s.Equal(0, s.vhost.PHP()) + + err := s.vhost.SetPHP(84) + s.NoError(err) + s.NotEqual(0, s.vhost.PHP()) + + err = s.vhost.SetPHP(0) + s.NoError(err) + s.Equal(0, s.vhost.PHP()) +} + +func (s *VhostTestSuite) TestAccessLog() { + accessLog := "/var/log/apache/access.log" + err := s.vhost.SetAccessLog(accessLog) + s.NoError(err) + s.Equal(accessLog, s.vhost.AccessLog()) +} + +func (s *VhostTestSuite) TestErrorLog() { + errorLog := "/var/log/apache/error.log" + err := s.vhost.SetErrorLog(errorLog) + s.NoError(err) + s.Equal(errorLog, s.vhost.ErrorLog()) +} + +func (s *VhostTestSuite) TestIncludes() { + includes := []types.IncludeFile{ + {Path: "/etc/apache/conf.d/ssl.conf"}, + {Path: "/etc/apache/conf.d/php.conf"}, + } + err := s.vhost.SetIncludes(includes) + s.NoError(err) + + got := s.vhost.Includes() + s.Len(got, 2) + s.Equal(includes[0].Path, got[0].Path) + s.Equal(includes[1].Path, got[1].Path) +} + +func (s *VhostTestSuite) TestBasicAuth() { + s.Nil(s.vhost.BasicAuth()) + + auth := map[string]string{ + "realm": "Test Realm", + "user_file": "/etc/htpasswd", + } + err := s.vhost.SetBasicAuth(auth) + s.NoError(err) + + got := s.vhost.BasicAuth() + s.NotNil(got) + s.Equal(auth["user_file"], got["user_file"]) + + err = s.vhost.SetBasicAuth(nil) + s.NoError(err) + s.Nil(s.vhost.BasicAuth()) +} + +func (s *VhostTestSuite) TestRateLimit() { + s.Nil(s.vhost.RateLimit()) + + limit := &types.RateLimit{ + Rate: "512", + } + err := s.vhost.SetRateLimit(limit) + s.NoError(err) + + got := s.vhost.RateLimit() + s.NotNil(got) + + err = s.vhost.SetRateLimit(nil) + s.NoError(err) + s.Nil(s.vhost.RateLimit()) +} + +func (s *VhostTestSuite) TestReset() { + s.NoError(s.vhost.SetServerName([]string{"modified.com"})) + s.NoError(s.vhost.SetRoot("/modified/path")) + + err := s.vhost.Reset() + s.NoError(err) + + names := s.vhost.ServerName() + s.NotContains(names, "modified.com") +} + +func (s *VhostTestSuite) TestSave() { + s.NoError(s.vhost.SetServerName([]string{"save-test.com"})) + + err := s.vhost.Save() + s.NoError(err) + + // 验证配置文件已保存 + configFile := filepath.Join(s.configDir, "apache.conf") + content, err := os.ReadFile(configFile) + s.NoError(err) + s.Contains(string(content), "save-test.com") +} + +func (s *VhostTestSuite) TestExport() { + s.NoError(s.vhost.SetServerName([]string{"export-test.com"})) + s.NoError(s.vhost.SetRoot("/var/www/export-test")) + + content := s.vhost.config.Export() + s.NotEmpty(content) + s.Contains(content, "export-test.com") + s.Contains(content, "/var/www/export-test") + s.Contains(content, "") +} + +func (s *VhostTestSuite) TestExportWithSSL() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + content := s.vhost.config.Export() + s.Contains(content, "SSLEngine on") + s.Contains(content, "SSLCertificateFile") + s.Contains(content, "SSLCertificateKeyFile") +} + +func (s *VhostTestSuite) TestListenProtocolDetection() { + listens := []types.Listen{ + {Address: "*:443", Protocol: "https"}, + } + s.NoError(s.vhost.SetListen(listens)) + + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + got := s.vhost.Listen() + s.Len(got, 1) + s.Equal("https", got[0].Protocol) +} + +func (s *VhostTestSuite) TestDirectoryBlock() { + root := "/var/www/test-dir" + err := s.vhost.SetRoot(root) + s.NoError(err) + + content := s.vhost.config.Export() + s.Contains(content, "") + s.Contains(content, "") +} + +func (s *VhostTestSuite) TestPHPFilesMatchBlock() { + err := s.vhost.SetPHP(84) + s.NoError(err) + + content := s.vhost.config.Export() + s.Contains(content, " 0 { + realm = realmDir.GetParameters()[0].GetValue() + } + + file := "" + if len(fileDir.GetParameters()) > 0 { + file = fileDir.GetParameters()[0].GetValue() + } + + return realm, file +} diff --git a/pkg/webserver/nginx/parser.go b/pkg/webserver/nginx/parser.go new file mode 100644 index 00000000..67e567a2 --- /dev/null +++ b/pkg/webserver/nginx/parser.go @@ -0,0 +1,232 @@ +package nginx + +import ( + "errors" + "fmt" + "os" + "strings" + + "github.com/tufanbarisyildirim/gonginx/config" + "github.com/tufanbarisyildirim/gonginx/dumper" + "github.com/tufanbarisyildirim/gonginx/parser" +) + +// Parser Nginx vhost 配置解析器 +type Parser struct { + cfg *config.Config + cfgPath string // 配置文件路径 +} + +func NewParser(website ...string) (*Parser, error) { + str := DefaultConf + cfgPath := "" + if len(website) != 0 && website[0] != "" { + cfgPath = fmt.Sprintf("/opt/ace/sites/%s/config/nginx.conf", website[0]) + if cfg, err := os.ReadFile(cfgPath); err == nil { + str = string(cfg) + } else { + return nil, err + } + } + + p := parser.NewStringParser(str, parser.WithSkipIncludeParsingErr(), parser.WithSkipValidDirectivesErr()) + cfg, err := p.Parse() + if err != nil { + return nil, err + } + + return &Parser{cfg: cfg, cfgPath: cfgPath}, nil +} + +// NewParserFromFile 从指定文件路径创建解析器 +func NewParserFromFile(filePath string) (*Parser, error) { + content, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + p := parser.NewStringParser(string(content), parser.WithSkipIncludeParsingErr(), parser.WithSkipValidDirectivesErr()) + cfg, err := p.Parse() + if err != nil { + return nil, fmt.Errorf("failed to parse config file: %w", err) + } + + return &Parser{cfg: cfg, cfgPath: filePath}, nil +} + +func (p *Parser) Config() *config.Config { + return p.cfg +} + +// Find 查找指令,如: Find("server.listen") +func (p *Parser) Find(key string) ([]config.IDirective, error) { + parts := strings.Split(key, ".") + var block *config.Block + var ok bool + block = p.cfg.Block + for i := 0; i < len(parts)-1; i++ { + key = parts[i] + directives := block.FindDirectives(key) + if len(directives) == 0 { + return nil, fmt.Errorf("given key %s not found", key) + } + if len(directives) > 1 { + return nil, errors.New("multiple directives found") + } + block, ok = directives[0].GetBlock().(*config.Block) + if !ok { + return nil, errors.New("block is not *config.Block") + } + } + + var result []config.IDirective + for _, dir := range block.GetDirectives() { + if dir.GetName() == parts[len(parts)-1] { + result = append(result, dir) + } + } + + return result, nil +} + +// FindOne 查找单个指令,如: FindOne("server.server_name") +func (p *Parser) FindOne(key string) (config.IDirective, error) { + directives, err := p.Find(key) + if err != nil { + return nil, err + } + if len(directives) == 0 { + return nil, fmt.Errorf("given key %s not found", key) + } + if len(directives) > 1 { + return nil, fmt.Errorf("multiple directives found for %s", key) + } + + return directives[0], nil +} + +// Clear 移除指令,如: Clear("server.server_name") +func (p *Parser) Clear(key string) error { + parts := strings.Split(key, ".") + last := parts[len(parts)-1] + parts = parts[:len(parts)-1] + + var block *config.Block + var ok bool + block = p.cfg.Block + for i := 0; i < len(parts); i++ { + directives := block.FindDirectives(parts[i]) + if len(directives) == 0 { + return fmt.Errorf("given key %s not found", parts[i]) + } + if len(directives) > 1 { + return fmt.Errorf("multiple directives found for %s", parts[i]) + } + block, ok = directives[0].GetBlock().(*config.Block) + if !ok { + return errors.New("block is not *config.Block") + } + } + + var newDirectives []config.IDirective + for _, directive := range block.GetDirectives() { + if directive.GetName() != last { + newDirectives = append(newDirectives, directive) + } + } + block.Directives = newDirectives + + return nil +} + +// Set 设置指令,如: Set("server.index", []Directive{...}) +func (p *Parser) Set(key string, directives []*config.Directive, after ...string) error { + parts := strings.Split(key, ".") + + var block *config.Block + var blockDirective config.IDirective + var ok bool + block = p.cfg.Block + for i := 0; i < len(parts); i++ { + sub := block.FindDirectives(parts[i]) + if len(sub) == 0 { + return fmt.Errorf("given key %s not found", parts[i]) + } + if len(sub) > 1 { + return fmt.Errorf("multiple directives found for %s", parts[i]) + } + block, ok = sub[0].GetBlock().(*config.Block) + if !ok { + return errors.New("block is not *config.Block") + } + blockDirective = sub[0] + } + + iDirectives := make([]config.IDirective, 0, len(directives)) + for _, directive := range directives { + directive.SetParent(blockDirective) + iDirectives = append(iDirectives, directive) + } + + if len(after) == 0 { + block.Directives = append(block.Directives, iDirectives...) + } else { + insertIndex := -1 + for i, d := range block.Directives { + if d.GetName() == after[0] { + insertIndex = i + 1 + break + } + } + if insertIndex == -1 { + return fmt.Errorf("after directive %s not found", after[0]) + } + + block.Directives = append( + block.Directives[:insertIndex], + append(iDirectives, block.Directives[insertIndex:]...)..., + ) + } + + return nil +} + +// Dump 将指令结构导出为配置内容 +func (p *Parser) Dump() string { + return dumper.DumpConfig(p.cfg, dumper.IndentedStyle) +} + +func (p *Parser) slices2Parameters(slices []string) []config.Parameter { + var parameters []config.Parameter + for _, slice := range slices { + parameters = append(parameters, config.Parameter{Value: slice}) + } + return parameters +} + +func (p *Parser) parameters2Slices(parameters []config.Parameter) []string { + var s []string + for _, parameter := range parameters { + s = append(s, parameter.Value) + } + return s +} + +// Save 保存配置到文件 +func (p *Parser) Save() error { + if p.cfgPath == "" { + return fmt.Errorf("config file path is empty, cannot save") + } + + content := p.Dump() + if err := os.WriteFile(p.cfgPath, []byte(content), 0644); err != nil { + return fmt.Errorf("failed to save config file: %w", err) + } + + return nil +} + +// SetConfigPath 设置配置文件路径 +func (p *Parser) SetConfigPath(path string) { + p.cfgPath = path +} diff --git a/pkg/webserver/nginx/parser_test.go b/pkg/webserver/nginx/parser_test.go new file mode 100644 index 00000000..3a18bd75 --- /dev/null +++ b/pkg/webserver/nginx/parser_test.go @@ -0,0 +1,233 @@ +package nginx + +import ( + "os" + "testing" + + "github.com/stretchr/testify/suite" +) + +type NginxTestSuite struct { + suite.Suite +} + +func TestNginxTestSuite(t *testing.T) { + suite.Run(t, &NginxTestSuite{}) +} + +func (s *NginxTestSuite) TestListen() { + parser, err := NewParser() + s.NoError(err) + listen, err := parser.GetListen() + s.NoError(err) + s.Equal([][]string{{"80"}}, listen) + s.NoError(parser.SetListen([][]string{{"80"}, {"443"}})) + listen, err = parser.GetListen() + s.NoError(err) + s.Equal([][]string{{"80"}, {"443"}}, listen) +} + +func (s *NginxTestSuite) TestServerName() { + parser, err := NewParser() + s.NoError(err) + serverName, err := parser.GetServerName() + s.NoError(err) + s.Equal([]string{"localhost"}, serverName) + s.NoError(parser.SetServerName([]string{"example.com"})) + serverName, err = parser.GetServerName() + s.NoError(err) + s.Equal([]string{"example.com"}, serverName) +} + +func (s *NginxTestSuite) TestIndex() { + parser, err := NewParser() + s.NoError(err) + index, err := parser.GetIndex() + s.NoError(err) + s.Equal([]string{"index.php", "index.html"}, index) + s.NoError(parser.SetIndex([]string{"index.html", "index.php"})) + index, err = parser.GetIndex() + s.NoError(err) + s.Equal([]string{"index.html", "index.php"}, index) +} + +func (s *NginxTestSuite) TestIndexWithComment() { + parser, err := NewParser() + s.NoError(err) + index, comment, err := parser.GetIndexWithComment() + s.NoError(err) + s.Equal([]string{"index.php", "index.html"}, index) + s.Equal([]string(nil), comment) + s.NoError(parser.SetIndexWithComment([]string{"index.html", "index.php"}, []string{"# 测试"})) + index, comment, err = parser.GetIndexWithComment() + s.NoError(err) + s.Equal([]string{"index.html", "index.php"}, index) + s.Equal([]string{"# 测试"}, comment) +} + +func (s *NginxTestSuite) TestRoot() { + parser, err := NewParser() + s.NoError(err) + root, err := parser.GetRoot() + s.NoError(err) + s.Equal("/opt/ace/sites/default/public", root) + s.NoError(parser.SetRoot("/www/wwwroot/test")) + root, err = parser.GetRoot() + s.NoError(err) + s.Equal("/www/wwwroot/test", root) +} + +func (s *NginxTestSuite) TestRootWithComment() { + parser, err := NewParser() + s.NoError(err) + root, comment, err := parser.GetRootWithComment() + s.NoError(err) + s.Equal("/opt/ace/sites/default/public", root) + s.Equal([]string(nil), comment) + s.NoError(parser.SetRootWithComment("/www/wwwroot/test", []string{"# 测试"})) + root, comment, err = parser.GetRootWithComment() + s.NoError(err) + s.Equal("/www/wwwroot/test", root) + s.Equal([]string{"# 测试"}, comment) +} + +func (s *NginxTestSuite) TestIncludes() { + parser, err := NewParser() + s.NoError(err) + includes, comments, err := parser.GetIncludes() + s.NoError(err) + s.Equal([]string{"/opt/ace/sites/default/config/server.d/*.conf"}, includes) + s.Equal([][]string{{"# custom configs"}}, comments) + s.NoError(parser.SetIncludes([]string{"/www/server/vhost/rewrite/default.conf"}, nil)) + includes, comments, err = parser.GetIncludes() + s.NoError(err) + s.Equal([]string{"/www/server/vhost/rewrite/default.conf"}, includes) + s.Equal([][]string{[]string(nil)}, comments) + s.NoError(parser.SetIncludes([]string{"/www/server/vhost/rewrite/test.conf"}, [][]string{{"# 伪静态规则测试"}})) + includes, comments, err = parser.GetIncludes() + s.NoError(err) + s.Equal([]string{"/www/server/vhost/rewrite/test.conf"}, includes) + s.Equal([][]string{{"# 伪静态规则测试"}}, comments) +} + +func (s *NginxTestSuite) TestPHP() { + parser, err := NewParser() + s.NoError(err) + s.Equal(0, parser.GetPHP()) + s.NoError(parser.SetPHP(80)) + s.Equal(80, parser.GetPHP()) + s.NoError(parser.SetPHP(0)) + s.Equal(0, parser.GetPHP()) +} + +func (s *NginxTestSuite) TestHTTP() { + parser, err := NewParser() + s.NoError(err) + expect, err := os.ReadFile("testdata/http.conf") + s.NoError(err) + s.Equal(string(expect), parser.Dump()) +} + +func (s *NginxTestSuite) TestHTTPS() { + parser, err := NewParser() + s.NoError(err) + s.False(parser.GetHTTPS()) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.True(parser.GetHTTPS()) + expect, err := os.ReadFile("testdata/https.conf") + s.NoError(err) + s.Equal(string(expect), parser.Dump()) +} + +func (s *NginxTestSuite) TestHTTPSProtocols() { + parser, err := NewParser() + s.NoError(err) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.Equal([]string{"TLSv1.2", "TLSv1.3"}, parser.GetHTTPSProtocols()) + s.NoError(parser.SetHTTPSProtocols([]string{"TLSv1.3"})) + s.Equal([]string{"TLSv1.3"}, parser.GetHTTPSProtocols()) +} + +func (s *NginxTestSuite) TestHTTPSCiphers() { + parser, err := NewParser() + s.NoError(err) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.Equal("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305", parser.GetHTTPSCiphers()) + s.NoError(parser.SetHTTPSCiphers("TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384")) + s.Equal("TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384", parser.GetHTTPSCiphers()) +} + +func (s *NginxTestSuite) TestOCSP() { + parser, err := NewParser() + s.NoError(err) + s.NoError(err) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.False(parser.GetOCSP()) + s.NoError(parser.SetOCSP(false)) + s.False(parser.GetOCSP()) + s.NoError(parser.SetOCSP(true)) + s.True(parser.GetOCSP()) + s.NoError(parser.SetOCSP(false)) + s.False(parser.GetOCSP()) +} + +func (s *NginxTestSuite) TestHSTS() { + parser, err := NewParser() + s.NoError(err) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.False(parser.GetHSTS()) + s.NoError(parser.SetHSTS(false)) + s.False(parser.GetHSTS()) + s.NoError(parser.SetHSTS(true)) + s.True(parser.GetHSTS()) + s.NoError(parser.SetHSTS(false)) + s.False(parser.GetHSTS()) +} + +func (s *NginxTestSuite) TestHTTPSRedirect() { + parser, err := NewParser() + s.NoError(err) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.False(parser.GetHTTPSRedirect()) + s.NoError(parser.SetHTTPSRedirect(false)) + s.False(parser.GetHTTPSRedirect()) + s.NoError(parser.SetHTTPSRedirect(true)) + s.True(parser.GetHTTPSRedirect()) + s.NoError(parser.SetHTTPSRedirect(false)) + s.False(parser.GetHTTPSRedirect()) +} + +func (s *NginxTestSuite) TestAltSvc() { + parser, err := NewParser() + s.NoError(err) + s.NoError(parser.SetHTTPSCert("/www/server/vhost/cert/default.pem", "/www/server/vhost/cert/default.key")) + s.Equal("", parser.GetAltSvc()) + s.NoError(parser.SetAltSvc(`'h3=":$server_port"; ma=2592000'`)) + s.Equal(`'h3=":$server_port"; ma=2592000'`, parser.GetAltSvc()) + s.NoError(parser.SetAltSvc("")) + s.Equal("", parser.GetAltSvc()) +} + +func (s *NginxTestSuite) TestAccessLog() { + parser, err := NewParser() + s.NoError(err) + log, err := parser.GetAccessLog() + s.NoError(err) + s.Equal("/opt/ace/sites/default/log/access.log", log) + s.NoError(parser.SetAccessLog("/www/wwwlogs/access.log")) + log, err = parser.GetAccessLog() + s.NoError(err) + s.Equal("/www/wwwlogs/access.log", log) +} + +func (s *NginxTestSuite) TestErrorLog() { + parser, err := NewParser() + s.NoError(err) + log, err := parser.GetErrorLog() + s.NoError(err) + s.Equal("/opt/ace/sites/default/log/error.log", log) + s.NoError(parser.SetErrorLog("/www/wwwlogs/error.log")) + log, err = parser.GetErrorLog() + s.NoError(err) + s.Equal("/www/wwwlogs/error.log", log) +} diff --git a/pkg/webserver/nginx/setter.go b/pkg/webserver/nginx/setter.go new file mode 100644 index 00000000..305462d3 --- /dev/null +++ b/pkg/webserver/nginx/setter.go @@ -0,0 +1,601 @@ +package nginx + +import ( + "fmt" + "slices" + "strings" + + "github.com/tufanbarisyildirim/gonginx/config" +) + +func (p *Parser) SetListen(listen [][]string) error { + var directives []*config.Directive + for _, l := range listen { + directives = append(directives, &config.Directive{ + Name: "listen", + Parameters: p.slices2Parameters(l), + }) + } + + if err := p.Clear("server.listen"); err != nil { + return err + } + + return p.Set("server", directives) +} + +func (p *Parser) SetServerName(serverName []string) error { + if err := p.Clear("server.server_name"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "server_name", + Parameters: p.slices2Parameters(serverName), + }, + }) +} + +func (p *Parser) SetIndex(index []string) error { + if err := p.Clear("server.index"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "index", + Parameters: p.slices2Parameters(index), + }, + }) +} + +func (p *Parser) SetIndexWithComment(index []string, comment []string) error { + if err := p.Clear("server.index"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "index", + Parameters: p.slices2Parameters(index), + Comment: comment, + }, + }) +} + +func (p *Parser) SetRoot(root string) error { + if err := p.Clear("server.root"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "root", + Parameters: []config.Parameter{{Value: root}}, + }, + }) +} + +func (p *Parser) SetRootWithComment(root string, comment []string) error { + if err := p.Clear("server.root"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "root", + Parameters: []config.Parameter{{Value: root}}, + Comment: comment, + }, + }) +} + +func (p *Parser) SetIncludes(includes []string, comments [][]string) error { + if err := p.Clear("server.include"); err != nil { + return err + } + + var directives []*config.Directive + for i, item := range includes { + var comment []string + if i < len(comments) { + comment = comments[i] + } + directives = append(directives, &config.Directive{ + Name: "include", + Parameters: []config.Parameter{{Value: item}}, + Comment: comment, + }) + } + + return p.Set("server", directives) +} + +func (p *Parser) SetPHP(php int) error { + old, err := p.Find("server.include") + if err != nil { + return err + } + if err = p.Clear("server.include"); err != nil { + return err + } + + var directives []*config.Directive + var foundFlag bool + for _, item := range old { + // 查找enable-php的配置 + if slices.ContainsFunc(p.parameters2Slices(item.GetParameters()), func(s string) bool { + return strings.HasPrefix(s, "enable-php-") && strings.HasSuffix(s, ".conf") + }) { + foundFlag = true + directives = append(directives, &config.Directive{ + Name: item.GetName(), + Parameters: []config.Parameter{{Value: fmt.Sprintf("enable-php-%d.conf", php)}}, + Comment: item.GetComment(), + }) + } else { + // 其余的原样保留 + directives = append(directives, &config.Directive{ + Name: item.GetName(), + Parameters: item.GetParameters(), + Comment: item.GetComment(), + }) + } + } + + // 如果没有找到enable-php的配置,直接添加一个 + if !foundFlag { + directives = append(directives, &config.Directive{ + Name: "include", + Parameters: []config.Parameter{{Value: fmt.Sprintf("enable-php-%d.conf", php)}}, + }) + } + + return p.Set("server", directives) +} + +func (p *Parser) ClearHTTPS() error { + if err := p.Clear("server.ssl_certificate"); err != nil { + return err + } + if err := p.Clear("server.ssl_certificate_key"); err != nil { + return err + } + if err := p.Clear("server.ssl_session_timeout"); err != nil { + return err + } + if err := p.Clear("server.ssl_session_cache"); err != nil { + return err + } + if err := p.Clear("server.ssl_protocols"); err != nil { + return err + } + if err := p.Clear("server.ssl_ciphers"); err != nil { + return err + } + if err := p.Clear("server.ssl_prefer_server_ciphers"); err != nil { + return err + } + if err := p.Clear("server.ssl_early_data"); err != nil { + return err + } + + return nil +} + +func (p *Parser) SetHTTPSCert(cert, key string) error { + if err := p.ClearHTTPS(); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "ssl_certificate", + Parameters: []config.Parameter{{Value: cert}}, + }, + { + Name: "ssl_certificate_key", + Parameters: []config.Parameter{{Value: key}}, + }, + { + Name: "ssl_session_timeout", + Parameters: []config.Parameter{{Value: "1d"}}, + }, + { + Name: "ssl_session_cache", + Parameters: []config.Parameter{{Value: "shared:SSL:10m"}}, + }, + { + Name: "ssl_protocols", + Parameters: []config.Parameter{{Value: "TLSv1.2"}, {Value: "TLSv1.3"}}, + }, + { + Name: "ssl_ciphers", + Parameters: []config.Parameter{{Value: "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305"}}, + }, + { + Name: "ssl_prefer_server_ciphers", + Parameters: []config.Parameter{{Value: "off"}}, + }, + { + Name: "ssl_early_data", + Parameters: []config.Parameter{{Value: "on"}}, + }, + }, "root") +} + +func (p *Parser) SetHTTPSProtocols(protocols []string) error { + if err := p.Clear("server.ssl_protocols"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "ssl_protocols", + Parameters: p.slices2Parameters(protocols), + }, + }) +} + +func (p *Parser) SetHTTPSCiphers(ciphers string) error { + if err := p.Clear("server.ssl_ciphers"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "ssl_ciphers", + Parameters: []config.Parameter{{Value: ciphers}}, + }, + }) +} + +func (p *Parser) SetOCSP(ocsp bool) error { + if err := p.Clear("server.ssl_stapling"); err != nil { + return err + } + if err := p.Clear("server.ssl_stapling_verify"); err != nil { + return err + } + + if ocsp { + return p.Set("server", []*config.Directive{ + { + Name: "ssl_stapling", + Parameters: []config.Parameter{{Value: "on"}}, + }, + { + Name: "ssl_stapling_verify", + Parameters: []config.Parameter{{Value: "on"}}, + }, + }) + } + + return nil +} + +func (p *Parser) SetHSTS(hsts bool) error { + old, err := p.Find("server.add_header") + if err != nil { + return err + } + if err = p.Clear("server.add_header"); err != nil { + return err + } + + var directives []*config.Directive + var foundFlag bool + for _, dir := range old { + if slices.Contains(p.parameters2Slices(dir.GetParameters()), "Strict-Transport-Security") { + foundFlag = true + if hsts { + directives = append(directives, &config.Directive{ + Name: dir.GetName(), + Parameters: []config.Parameter{{Value: "Strict-Transport-Security"}, {Value: "max-age=31536000"}}, + Comment: dir.GetComment(), + }) + } + } else { + directives = append(directives, &config.Directive{ + Name: dir.GetName(), + Parameters: dir.GetParameters(), + Comment: dir.GetComment(), + }) + } + } + + if !foundFlag && hsts { + directives = append(directives, &config.Directive{ + Name: "add_header", + Parameters: []config.Parameter{{Value: "Strict-Transport-Security"}, {Value: "max-age=31536000"}}, + }) + } + + return p.Set("server", directives) +} + +func (p *Parser) SetHTTPSRedirect(httpRedirect bool) error { + // if 重定向 + ifs, err := p.Find("server.if") + if err != nil { + return err + } + if err = p.Clear("server.if"); err != nil { + return err + } + + var directives []*config.Directive + var foundFlag bool + for _, dir := range ifs { // 所有 if + if !httpRedirect { + if len(dir.GetParameters()) == 3 && dir.GetParameters()[0].GetValue() == "($scheme" && dir.GetParameters()[1].GetValue() == "=" && dir.GetParameters()[2].GetValue() == "http)" { + continue + } + } + var ifDirectives []config.IDirective + for _, dir2 := range dir.GetBlock().GetDirectives() { // 每个 if 中所有指令 + if !httpRedirect { + // 不启用http重定向,则判断并移除特定的return指令 + if dir2.GetName() != "return" && !slices.Contains(p.parameters2Slices(dir2.GetParameters()), "https://$host$request_uri") { + ifDirectives = append(ifDirectives, dir2) + } + } else { + // 启用http重定向,需要检查防止重复添加 + if dir2.GetName() == "return" && slices.Contains(p.parameters2Slices(dir2.GetParameters()), "https://$host$request_uri") { + foundFlag = true + } + ifDirectives = append(ifDirectives, dir2) + } + } + // 写回 if 指令 + if block, ok := dir.GetBlock().(*config.Block); ok { + block.Directives = ifDirectives + } + directives = append(directives, &config.Directive{ + Block: dir.GetBlock(), + Name: dir.GetName(), + Parameters: dir.GetParameters(), + Comment: dir.GetComment(), + }) + } + + if !foundFlag && httpRedirect { + ifDir := &config.Directive{ + Name: "if", + Block: &config.Block{}, + Parameters: []config.Parameter{{Value: "($scheme"}, {Value: "="}, {Value: "http)"}}, + } + redirectDir := &config.Directive{ + Name: "return", + Parameters: []config.Parameter{{Value: "308"}, {Value: "https://$host$request_uri"}}, + } + redirectDir.SetParent(ifDir.GetParent()) + ifBlock := ifDir.GetBlock().(*config.Block) + ifBlock.Directives = append(ifBlock.Directives, redirectDir) + directives = append(directives, ifDir) + } + + if err = p.Set("server", directives); err != nil { + return err + } + + // error_page 497 重定向 + directives = nil + errorPages, err := p.Find("server.error_page") + if err != nil { + return err + } + if err = p.Clear("server.error_page"); err != nil { + return err + } + var found497 bool + for _, dir := range errorPages { + if !httpRedirect { + // 不启用https重定向,则判断并移除特定的return指令 + if !slices.Contains(p.parameters2Slices(dir.GetParameters()), "497") && !slices.Contains(p.parameters2Slices(dir.GetParameters()), "https://$host:$server_port$request_uri") { + directives = append(directives, &config.Directive{ + Block: dir.GetBlock(), + Name: dir.GetName(), + Parameters: dir.GetParameters(), + Comment: dir.GetComment(), + }) + } + } else { + // 启用https重定向,需要检查防止重复添加 + if slices.Contains(p.parameters2Slices(dir.GetParameters()), "497") && slices.Contains(p.parameters2Slices(dir.GetParameters()), "https://$host:$server_port$request_uri") { + found497 = true + } + directives = append(directives, &config.Directive{ + Block: dir.GetBlock(), + Name: dir.GetName(), + Parameters: dir.GetParameters(), + Comment: dir.GetComment(), + }) + } + } + + if !found497 && httpRedirect { + directives = append(directives, &config.Directive{ + Name: "error_page", + Parameters: []config.Parameter{{Value: "497"}, {Value: "=308"}, {Value: "https://$host:$server_port$request_uri"}}, + }) + } + + return p.Set("server", directives) +} + +func (p *Parser) SetAltSvc(altSvc string) error { + old, err := p.Find("server.add_header") + if err != nil { + return err + } + if err = p.Clear("server.add_header"); err != nil { + return err + } + + var directives []*config.Directive + var foundFlag bool + for _, dir := range old { + if slices.Contains(p.parameters2Slices(dir.GetParameters()), "Alt-Svc") { + foundFlag = true + if altSvc != "" { // 为空表示要删除 + directives = append(directives, &config.Directive{ + Name: dir.GetName(), + Parameters: []config.Parameter{{Value: "Alt-Svc"}, {Value: altSvc}}, + Comment: dir.GetComment(), + }) + } + } else { + directives = append(directives, &config.Directive{ + Name: dir.GetName(), + Parameters: dir.GetParameters(), + Comment: dir.GetComment(), + }) + } + } + + if !foundFlag && altSvc != "" { + directives = append(directives, &config.Directive{ + Name: "add_header", + Parameters: []config.Parameter{{Value: "Alt-Svc"}, {Value: altSvc}}, + }) + } + + return p.Set("server", directives) +} + +func (p *Parser) SetAccessLog(accessLog string) error { + if err := p.Clear("server.access_log"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "access_log", + Parameters: []config.Parameter{{Value: accessLog}}, + }, + }) +} + +func (p *Parser) SetErrorLog(errorLog string) error { + if err := p.Clear("server.error_log"); err != nil { + return err + } + + return p.Set("server", []*config.Directive{ + { + Name: "error_log", + Parameters: []config.Parameter{{Value: errorLog}}, + }, + }) +} + +// SetReturn 设置 return 指令(用于禁用网站) +func (p *Parser) SetReturn(code, url string) error { + if err := p.Clear("server.return"); err != nil { + // 忽略不存在的错误 + } + + directives := []*config.Directive{ + { + Name: "return", + Parameters: []config.Parameter{{Value: code}, {Value: url}}, + }, + } + + // 在 server 块的最开始插入 return 指令 + // 获取 server 块 + serverDirs := p.cfg.Block.FindDirectives("server") + if len(serverDirs) == 0 { + return fmt.Errorf("server block not found") + } + + serverBlock, ok := serverDirs[0].GetBlock().(*config.Block) + if !ok { + return fmt.Errorf("server block is not *config.Block") + } + + // 设置父节点 + for _, d := range directives { + d.SetParent(serverDirs[0]) + } + + // 在开头插入 + newDirectives := make([]config.IDirective, 0, len(directives)+len(serverBlock.Directives)) + for _, d := range directives { + newDirectives = append(newDirectives, d) + } + newDirectives = append(newDirectives, serverBlock.Directives...) + serverBlock.Directives = newDirectives + + return nil +} + +// SetLimitRate 设置限速配置 +func (p *Parser) SetLimitRate(limitRate string) error { + if err := p.Clear("server.limit_rate"); err != nil { + // 忽略不存在的错误 + } + + if limitRate == "" { + return nil // 清除限速配置 + } + + return p.Set("server", []*config.Directive{ + { + Name: "limit_rate", + Parameters: []config.Parameter{{Value: limitRate}}, + }, + }) +} + +// SetLimitConn 设置并发连接数限制 +func (p *Parser) SetLimitConn(limitConn [][]string) error { + if err := p.Clear("server.limit_conn"); err != nil { + // 忽略不存在的错误 + } + + if len(limitConn) == 0 { + return nil // 清除限流配置 + } + + var directives []*config.Directive + for _, limit := range limitConn { + if len(limit) >= 2 { + directives = append(directives, &config.Directive{ + Name: "limit_conn", + Parameters: p.slices2Parameters(limit), + }) + } + } + + return p.Set("server", directives) +} + +// SetBasicAuth 设置基本认证 +func (p *Parser) SetBasicAuth(realm, userFile string) error { + // 清除现有配置 + if err := p.Clear("server.auth_basic"); err != nil { + // 忽略不存在的错误 + } + if err := p.Clear("server.auth_basic_user_file"); err != nil { + // 忽略不存在的错误 + } + + // 如果 realm 为空,表示禁用基本认证 + if realm == "" || userFile == "" { + return nil + } + + return p.Set("server", []*config.Directive{ + { + Name: "auth_basic", + Parameters: []config.Parameter{{Value: realm}}, + }, + { + Name: "auth_basic_user_file", + Parameters: []config.Parameter{{Value: userFile}}, + }, + }) +} diff --git a/pkg/webserver/nginx/testdata/http.conf b/pkg/webserver/nginx/testdata/http.conf new file mode 100644 index 00000000..fee718ff --- /dev/null +++ b/pkg/webserver/nginx/testdata/http.conf @@ -0,0 +1,28 @@ +include /opt/ace/sites/default/config/http.d/*.conf; +server { + listen 80; + server_name localhost; + index index.php index.html; + root /opt/ace/sites/default/public; + # error page + error_page 404 /404.html; + # custom configs + include /opt/ace/sites/default/config/server.d/*.conf; + # browser cache + location ~ .*\.(bmp|jpg|jpeg|png|gif|svg|ico|tiff|webp|avif|heif|heic|jxl)$ { + expires 30d; + access_log /dev/null; + error_log /dev/null; + } + location ~ .*\.(js|css|ttf|otf|woff|woff2|eot)$ { + expires 6h; + access_log /dev/null; + error_log /dev/null; + } + # deny sensitive files + location ~ ^/(\.user.ini|\.htaccess|\.git|\.svn|\.env) { + return 404; + } + access_log /opt/ace/sites/default/log/access.log; + error_log /opt/ace/sites/default/log/error.log; +} \ No newline at end of file diff --git a/pkg/webserver/nginx/testdata/https.conf b/pkg/webserver/nginx/testdata/https.conf new file mode 100644 index 00000000..c73828ba --- /dev/null +++ b/pkg/webserver/nginx/testdata/https.conf @@ -0,0 +1,36 @@ +include /opt/ace/sites/default/config/http.d/*.conf; +server { + listen 80; + server_name localhost; + index index.php index.html; + root /opt/ace/sites/default/public; + ssl_certificate /www/server/vhost/cert/default.pem; + ssl_certificate_key /www/server/vhost/cert/default.key; + ssl_session_timeout 1d; + ssl_session_cache shared:SSL:10m; + ssl_protocols TLSv1.2 TLSv1.3; + ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305; + ssl_prefer_server_ciphers off; + ssl_early_data on; + # error page + error_page 404 /404.html; + # custom configs + include /opt/ace/sites/default/config/server.d/*.conf; + # browser cache + location ~ .*\.(bmp|jpg|jpeg|png|gif|svg|ico|tiff|webp|avif|heif|heic|jxl)$ { + expires 30d; + access_log /dev/null; + error_log /dev/null; + } + location ~ .*\.(js|css|ttf|otf|woff|woff2|eot)$ { + expires 6h; + access_log /dev/null; + error_log /dev/null; + } + # deny sensitive files + location ~ ^/(\.user.ini|\.htaccess|\.git|\.svn|\.env) { + return 404; + } + access_log /opt/ace/sites/default/log/access.log; + error_log /opt/ace/sites/default/log/error.log; +} \ No newline at end of file diff --git a/pkg/webserver/nginx/vhost.go b/pkg/webserver/nginx/vhost.go new file mode 100644 index 00000000..41e89f07 --- /dev/null +++ b/pkg/webserver/nginx/vhost.go @@ -0,0 +1,491 @@ +package nginx + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/acepanel/panel/pkg/webserver/types" +) + +// Vhost Nginx 虚拟主机实现 +type Vhost struct { + parser *Parser + configDir string // 配置目录 +} + +// NewVhost 创建 Nginx 虚拟主机实例 +// configDir: 配置目录路径 +func NewVhost(configDir string) (*Vhost, error) { + v := &Vhost{ + configDir: configDir, + } + + // 加载配置 + var parser *Parser + var err error + + if v.configDir != "" { + // 从配置目录加载主配置文件 + configFile := filepath.Join(v.configDir, "nginx.conf") + if _, statErr := os.Stat(configFile); statErr == nil { + parser, err = NewParserFromFile(configFile) + if err != nil { + return nil, fmt.Errorf("failed to load nginx config: %w", err) + } + } + } + + // 如果没有配置文件,使用默认配置 + if parser == nil { + // 使用空字符串创建默认配置,而不尝试读取文件 + parser, err = NewParser("") + if err != nil { + return nil, fmt.Errorf("failed to load default config: %w", err) + } + // 如果有 configDir,设置配置文件路径 + if v.configDir != "" { + parser.SetConfigPath(filepath.Join(v.configDir, "nginx.conf")) + } + } + + v.parser = parser + return v, nil +} + +// ========== VhostCore 接口实现 ========== + +func (v *Vhost) Enable() bool { + // 检查禁用配置文件是否存在 + disableFile := filepath.Join(v.configDir, "server.d", DisableConfName) + _, err := os.Stat(disableFile) + return os.IsNotExist(err) +} + +func (v *Vhost) SetEnable(enable bool, _ ...string) error { + serverDir := filepath.Join(v.configDir, "server.d") + disableFile := filepath.Join(serverDir, DisableConfName) + + if enable { + // 启用:删除禁用配置文件 + if err := os.Remove(disableFile); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove disable config: %w", err) + } + return nil + } + + // 禁用:创建禁用配置文件 + // 确保目录存在 + if err := os.MkdirAll(serverDir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + // 写入禁用配置 + if err := os.WriteFile(disableFile, []byte(DisableConfContent), 0644); err != nil { + return fmt.Errorf("failed to write disable config: %w", err) + } + + return nil +} + +func (v *Vhost) Listen() []types.Listen { + listens, err := v.parser.GetListen() + if err != nil { + return nil + } + + var result []types.Listen + for _, l := range listens { + if len(l) == 0 { + continue + } + + listen := types.Listen{ + Address: l[0], + Options: make(map[string]string), + } + + // 解析 Nginx 特有的选项 + for i := 1; i < len(l); i++ { + switch l[i] { + case "ssl": + listen.Protocol = "https" + case "http2": + listen.Protocol = "http2" + case "http3", "quic": + listen.Protocol = "http3" + default: + listen.Options[l[i]] = "true" + } + } + + // 如果没有指定协议,默认为 http + if listen.Protocol == "" { + listen.Protocol = "http" + } + + result = append(result, listen) + } + + return result +} + +func (v *Vhost) SetListen(listens []types.Listen) error { + // 将通用 Listen 转换为 Nginx 格式 + var nginxListens [][]string + for _, l := range listens { + listen := []string{l.Address} + + // 添加协议标识 + switch l.Protocol { + case "https": + listen = append(listen, "ssl") + case "http2": + listen = append(listen, "http2") + case "http3": + listen = append(listen, "http3") + } + + // 添加其他选项 + for k, v := range l.Options { + if v == "true" { + listen = append(listen, k) + } else { + listen = append(listen, fmt.Sprintf("%s=%s", k, v)) + } + } + + nginxListens = append(nginxListens, listen) + } + + return v.parser.SetListen(nginxListens) +} + +func (v *Vhost) ServerName() []string { + names, err := v.parser.GetServerName() + if err != nil { + return nil + } + return names +} + +func (v *Vhost) SetServerName(serverName []string) error { + return v.parser.SetServerName(serverName) +} + +func (v *Vhost) Index() []string { + index, err := v.parser.GetIndex() + if err != nil { + return nil + } + return index +} + +func (v *Vhost) SetIndex(index []string) error { + return v.parser.SetIndex(index) +} + +func (v *Vhost) Root() string { + root, err := v.parser.GetRoot() + if err != nil { + return "" + } + return root +} + +func (v *Vhost) SetRoot(root string) error { + return v.parser.SetRoot(root) +} + +func (v *Vhost) Includes() []types.IncludeFile { + includes, comments, err := v.parser.GetIncludes() + if err != nil { + return nil + } + + var result []types.IncludeFile + for i, inc := range includes { + file := types.IncludeFile{ + Path: inc, + } + if i < len(comments) { + file.Comment = comments[i] + } + result = append(result, file) + } + + return result +} + +func (v *Vhost) SetIncludes(includes []types.IncludeFile) error { + var paths []string + var comments [][]string + + for _, inc := range includes { + paths = append(paths, inc.Path) + comments = append(comments, inc.Comment) + } + + return v.parser.SetIncludes(paths, comments) +} + +func (v *Vhost) AccessLog() string { + log, err := v.parser.GetAccessLog() + if err != nil { + return "" + } + return log +} + +func (v *Vhost) SetAccessLog(accessLog string) error { + return v.parser.SetAccessLog(accessLog) +} + +func (v *Vhost) ErrorLog() string { + log, err := v.parser.GetErrorLog() + if err != nil { + return "" + } + return log +} + +func (v *Vhost) SetErrorLog(errorLog string) error { + return v.parser.SetErrorLog(errorLog) +} + +func (v *Vhost) Save() error { + return v.parser.Save() +} + +func (v *Vhost) Reload() error { + // 重载 Nginx 配置 + // 优先使用 openresty,如果不存在则使用 nginx + cmds := []string{ + "/opt/ace/apps/openresty/bin/openresty -s reload", + "/usr/sbin/nginx -s reload", + "nginx -s reload", + } + + var lastErr error + for _, cmd := range cmds { + parts := strings.Fields(cmd) + if len(parts) < 2 { + continue + } + + // 检查命令是否存在 + if _, err := os.Stat(parts[0]); err == nil { + // 执行重载命令 + err := exec.Command(parts[0], parts[1:]...).Run() + if err == nil { + return nil + } + lastErr = err + } + } + + if lastErr != nil { + return fmt.Errorf("failed to reload nginx config: %w", lastErr) + } + return fmt.Errorf("nginx or openresty command not found") +} + +func (v *Vhost) Reset() error { + // 重置配置为默认值 + parser, err := NewParser("") + if err != nil { + return fmt.Errorf("failed to reset config: %w", err) + } + + // 如果有 configDir,设置配置文件路径 + if v.configDir != "" { + parser.SetConfigPath(filepath.Join(v.configDir, "nginx.conf")) + } + + v.parser = parser + return nil +} + +// ========== VhostSSL 接口实现 ========== + +func (v *Vhost) HTTPS() bool { + return v.parser.GetHTTPS() +} + +func (v *Vhost) SSLConfig() *types.SSLConfig { + if !v.HTTPS() { + return nil + } + + return &types.SSLConfig{ + Protocols: v.parser.GetHTTPSProtocols(), + Ciphers: v.parser.GetHTTPSCiphers(), + HSTS: v.parser.GetHSTS(), + OCSP: v.parser.GetOCSP(), + HTTPRedirect: v.parser.GetHTTPSRedirect(), + AltSvc: v.parser.GetAltSvc(), + } +} + +func (v *Vhost) SetSSLConfig(cfg *types.SSLConfig) error { + if cfg == nil { + return fmt.Errorf("SSL config cannot be nil") + } + + // 设置证书和私钥 + if err := v.parser.SetHTTPSCert(cfg.Cert, cfg.Key); err != nil { + return err + } + + // 设置协议 + if len(cfg.Protocols) > 0 { + if err := v.parser.SetHTTPSProtocols(cfg.Protocols); err != nil { + return err + } + } + + // 设置加密套件 + if cfg.Ciphers != "" { + if err := v.parser.SetHTTPSCiphers(cfg.Ciphers); err != nil { + return err + } + } + + // 设置 HSTS + if err := v.parser.SetHSTS(cfg.HSTS); err != nil { + return err + } + + // 设置 OCSP + if err := v.parser.SetOCSP(cfg.OCSP); err != nil { + return err + } + + // 设置 HTTP 跳转 + if err := v.parser.SetHTTPSRedirect(cfg.HTTPRedirect); err != nil { + return err + } + + // 设置 Alt-Svc + if cfg.AltSvc != "" { + if err := v.parser.SetAltSvc(cfg.AltSvc); err != nil { + return err + } + } + + return nil +} + +func (v *Vhost) ClearHTTPS() error { + return v.parser.ClearHTTPS() +} + +// ========== VhostPHP 接口实现 ========== + +func (v *Vhost) PHP() int { + return v.parser.GetPHP() +} + +func (v *Vhost) SetPHP(version int) error { + // 先移除所有 PHP 相关的 include + includes := v.Includes() + var newIncludes []types.IncludeFile + for _, inc := range includes { + // 过滤掉 enable-php-*.conf + if !strings.HasPrefix(inc.Path, "enable-php-") || !strings.HasSuffix(inc.Path, ".conf") { + newIncludes = append(newIncludes, inc) + } + } + + // 如果版本不为 0,添加新的 PHP include + if version > 0 { + newIncludes = append(newIncludes, types.IncludeFile{ + Path: fmt.Sprintf("enable-php-%d.conf", version), + Comment: []string{fmt.Sprintf("# Enable PHP %d.%d", version/10, version%10)}, + }) + } + + return v.SetIncludes(newIncludes) +} + +// ========== VhostAdvanced 接口实现 ========== + +func (v *Vhost) RateLimit() *types.RateLimit { + rate := v.parser.GetLimitRate() + limitConn := v.parser.GetLimitConn() + + if rate == "" && len(limitConn) == 0 { + return nil + } + + rateLimit := &types.RateLimit{ + Rate: rate, + Options: make(map[string]string), + } + + // 解析 limit_conn 配置 + for _, limit := range limitConn { + if len(limit) >= 2 { + // limit_conn zone connections + // 例如: limit_conn perip 10 + rateLimit.Options[limit[0]] = limit[1] + } + } + + return rateLimit +} + +func (v *Vhost) SetRateLimit(limit *types.RateLimit) error { + if limit == nil { + // 清除限流配置 + if err := v.parser.SetLimitRate(""); err != nil { + return err + } + return v.parser.SetLimitConn(nil) + } + + // 设置限速 + if err := v.parser.SetLimitRate(limit.Rate); err != nil { + return err + } + + // 设置并发连接数限制 + var limitConns [][]string + for zone, connections := range limit.Options { + limitConns = append(limitConns, []string{zone, connections}) + } + + return v.parser.SetLimitConn(limitConns) +} + +func (v *Vhost) BasicAuth() map[string]string { + realm, userFile := v.parser.GetBasicAuth() + if realm == "" || userFile == "" { + return nil + } + + // 返回基本认证配置 + // 注意:这里只返回配置路径,不解析用户文件内容 + return map[string]string{ + "realm": realm, + "user_file": userFile, + } +} + +func (v *Vhost) SetBasicAuth(auth map[string]string) error { + if auth == nil || len(auth) == 0 { + // 清除基本认证配置 + return v.parser.SetBasicAuth("", "") + } + + realm := auth["realm"] + userFile := auth["user_file"] + + if realm == "" { + realm = "Restricted" + } + + return v.parser.SetBasicAuth(realm, userFile) +} diff --git a/pkg/webserver/nginx/vhost_test.go b/pkg/webserver/nginx/vhost_test.go new file mode 100644 index 00000000..7e24d5ba --- /dev/null +++ b/pkg/webserver/nginx/vhost_test.go @@ -0,0 +1,349 @@ +package nginx + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/acepanel/panel/pkg/webserver/types" +) + +type VhostTestSuite struct { + suite.Suite + vhost *Vhost + configDir string +} + +func TestVhostTestSuite(t *testing.T) { + suite.Run(t, &VhostTestSuite{}) +} + +func (s *VhostTestSuite) SetupTest() { + // 创建临时配置目录 + configDir, err := os.MkdirTemp("", "nginx-test-*") + s.Require().NoError(err) + s.configDir = configDir + + // 创建 server.d 目录 + err = os.MkdirAll(filepath.Join(configDir, "server.d"), 0755) + s.Require().NoError(err) + + vhost, err := NewVhost(configDir) + s.Require().NoError(err) + s.Require().NotNil(vhost) + s.vhost = vhost +} + +func (s *VhostTestSuite) TearDownTest() { + // 清理临时目录 + if s.configDir != "" { + s.NoError(os.RemoveAll(s.configDir)) + } +} + +func (s *VhostTestSuite) TestNewVhost() { + s.Equal(s.configDir, s.vhost.configDir) + s.NotNil(s.vhost.parser) +} + +func (s *VhostTestSuite) TestEnable() { + // 默认应该是启用状态(没有 00-disable.conf) + s.True(s.vhost.Enable()) + + // 禁用网站 + s.NoError(s.vhost.SetEnable(false)) + s.False(s.vhost.Enable()) + + // 验证禁用文件存在 + disableFile := filepath.Join(s.configDir, "server.d", DisableConfName) + _, err := os.Stat(disableFile) + s.NoError(err) + + // 重新启用 + s.NoError(s.vhost.SetEnable(true)) + s.True(s.vhost.Enable()) + + // 验证禁用文件已删除 + _, err = os.Stat(disableFile) + s.True(os.IsNotExist(err)) +} + +func (s *VhostTestSuite) TestDisableConfigContent() { + // 禁用网站 + s.NoError(s.vhost.SetEnable(false)) + + // 读取禁用配置内容 + disableFile := filepath.Join(s.configDir, "server.d", DisableConfName) + content, err := os.ReadFile(disableFile) + s.NoError(err) + + // 验证内容包含 503 返回 + s.Contains(string(content), "503") + s.Contains(string(content), "return") +} + +func (s *VhostTestSuite) TestServerName() { + names := []string{"example.com", "www.example.com", "api.example.com"} + s.NoError(s.vhost.SetServerName(names)) + + got := s.vhost.ServerName() + s.Len(got, 3) + s.Equal("example.com", got[0]) + s.Equal("www.example.com", got[1]) + s.Equal("api.example.com", got[2]) +} + +func (s *VhostTestSuite) TestServerNameEmpty() { + s.NoError(s.vhost.SetServerName([]string{})) +} + +func (s *VhostTestSuite) TestRoot() { + root := "/var/www/html" + s.NoError(s.vhost.SetRoot(root)) + s.Equal(root, s.vhost.Root()) +} + +func (s *VhostTestSuite) TestIndex() { + index := []string{"index.html", "index.php", "default.html"} + s.NoError(s.vhost.SetIndex(index)) + + got := s.vhost.Index() + s.Len(got, 3) + s.Equal(index, got) +} + +func (s *VhostTestSuite) TestIndexEmpty() { + s.NoError(s.vhost.SetIndex([]string{})) +} + +func (s *VhostTestSuite) TestListen() { + listens := []types.Listen{ + {Address: "80", Protocol: "http"}, + {Address: "443", Protocol: "https"}, + } + s.NoError(s.vhost.SetListen(listens)) + + got := s.vhost.Listen() + s.Len(got, 2) +} + +func (s *VhostTestSuite) TestListenWithHTTP3() { + listens := []types.Listen{ + {Address: "443", Protocol: "http3"}, + } + s.NoError(s.vhost.SetListen(listens)) + + got := s.vhost.Listen() + s.Len(got, 1) + s.Equal("http3", got[0].Protocol) +} + +func (s *VhostTestSuite) TestHTTPS() { + s.False(s.vhost.HTTPS()) + s.Nil(s.vhost.SSLConfig()) +} + +func (s *VhostTestSuite) TestSetSSLConfig() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + HSTS: true, + OCSP: true, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + s.True(s.vhost.HTTPS()) + + got := s.vhost.SSLConfig() + s.NotNil(got) + s.True(got.HSTS) + s.True(got.OCSP) +} + +func (s *VhostTestSuite) TestSetSSLConfigNil() { + err := s.vhost.SetSSLConfig(nil) + s.Error(err) +} + +func (s *VhostTestSuite) TestClearHTTPS() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + HSTS: true, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + s.True(s.vhost.HTTPS()) + + s.NoError(s.vhost.ClearHTTPS()) + s.False(s.vhost.HTTPS()) +} + +func (s *VhostTestSuite) TestPHP() { + s.Equal(0, s.vhost.PHP()) + + s.NoError(s.vhost.SetPHP(84)) + + // Nginx 的 PHP 实现使用 include 文件 + includes := s.vhost.Includes() + found := false + for _, inc := range includes { + if strings.Contains(inc.Path, "enable-php-84.conf") { + found = true + break + } + } + s.True(found, "PHP include file should exist") + + s.NoError(s.vhost.SetPHP(0)) +} + +func (s *VhostTestSuite) TestAccessLog() { + accessLog := "/var/log/nginx/access.log" + s.NoError(s.vhost.SetAccessLog(accessLog)) + s.Equal(accessLog, s.vhost.AccessLog()) +} + +func (s *VhostTestSuite) TestErrorLog() { + errorLog := "/var/log/nginx/error.log" + s.NoError(s.vhost.SetErrorLog(errorLog)) + s.Equal(errorLog, s.vhost.ErrorLog()) +} + +func (s *VhostTestSuite) TestIncludes() { + includes := []types.IncludeFile{ + {Path: "/etc/nginx/conf.d/ssl.conf"}, + {Path: "/etc/nginx/conf.d/php.conf"}, + } + s.NoError(s.vhost.SetIncludes(includes)) + + got := s.vhost.Includes() + s.Len(got, 2) + s.Equal(includes[0].Path, got[0].Path) + s.Equal(includes[1].Path, got[1].Path) +} + +func (s *VhostTestSuite) TestBasicAuth() { + s.Nil(s.vhost.BasicAuth()) + + auth := map[string]string{ + "realm": "Test Realm", + "user_file": "/etc/nginx/htpasswd", + } + s.NoError(s.vhost.SetBasicAuth(auth)) + + got := s.vhost.BasicAuth() + s.NotNil(got) + s.Equal(auth["user_file"], got["user_file"]) + + s.NoError(s.vhost.SetBasicAuth(nil)) + s.Nil(s.vhost.BasicAuth()) +} + +func (s *VhostTestSuite) TestRateLimit() { + s.Nil(s.vhost.RateLimit()) + + limit := &types.RateLimit{ + Rate: "512k", + Options: map[string]string{ + "perip": "10", + }, + } + s.NoError(s.vhost.SetRateLimit(limit)) + + got := s.vhost.RateLimit() + s.NotNil(got) + s.Equal("512k", got.Rate) + + s.NoError(s.vhost.SetRateLimit(nil)) + s.Nil(s.vhost.RateLimit()) +} + +func (s *VhostTestSuite) TestReset() { + err := s.vhost.SetServerName([]string{"modified.com"}) + s.NoError(err) + err = s.vhost.SetRoot("/modified/path") + s.NoError(err) + + err = s.vhost.Reset() + s.NoError(err) + + names := s.vhost.ServerName() + s.NotContains(names, "modified.com") +} + +func (s *VhostTestSuite) TestSave() { + // 设置配置文件路径 + configFile := filepath.Join(s.configDir, "nginx.conf") + s.vhost.parser.SetConfigPath(configFile) + + s.NoError(s.vhost.SetServerName([]string{"save-test.com"})) + + s.NoError(s.vhost.Save()) + + // 验证配置文件已保存 + content, err := os.ReadFile(configFile) + s.NoError(err) + s.Contains(string(content), "save-test.com") +} + +func (s *VhostTestSuite) TestDump() { + err := s.vhost.SetServerName([]string{"dump-test.com"}) + s.NoError(err) + err = s.vhost.SetRoot("/var/www/dump-test") + s.NoError(err) + + content := s.vhost.parser.Dump() + s.NotEmpty(content) + s.Contains(content, "dump-test.com") + s.Contains(content, "/var/www/dump-test") + s.Contains(content, "server") +} + +func (s *VhostTestSuite) TestDumpWithSSL() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + content := s.vhost.parser.Dump() + s.Contains(content, "ssl_certificate") + s.Contains(content, "ssl_certificate_key") +} + +func (s *VhostTestSuite) TestHTTPSRedirect() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + HTTPRedirect: true, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + got := s.vhost.SSLConfig() + s.NotNil(got) + s.True(got.HTTPRedirect) +} + +func (s *VhostTestSuite) TestAltSvc() { + sslConfig := &types.SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + AltSvc: `h3=":$server_port"; ma=2592000`, + } + s.NoError(s.vhost.SetSSLConfig(sslConfig)) + + got := s.vhost.SSLConfig() + s.NotNil(got) + s.Contains(got.AltSvc, "h3=") +} + +func (s *VhostTestSuite) TestDefaultConfIncludesServerD() { + // 验证默认配置包含 server.d 的 include + s.Contains(DefaultConf, "server.d") + s.Contains(DefaultConf, "include") +} diff --git a/pkg/webserver/types.go b/pkg/webserver/types.go new file mode 100644 index 00000000..d610bcf6 --- /dev/null +++ b/pkg/webserver/types.go @@ -0,0 +1,9 @@ +package webserver + +// Type Web 服务器类型 +type Type string + +const ( + TypeNginx Type = "nginx" + TypeApache Type = "apache" +) diff --git a/pkg/webserver/types/proxy.go b/pkg/webserver/types/proxy.go new file mode 100644 index 00000000..31c62f4d --- /dev/null +++ b/pkg/webserver/types/proxy.go @@ -0,0 +1,24 @@ +package types + +import "time" + +// Proxy 反向代理配置 +type Proxy struct { + AutoRefresh bool // 是否自动刷新解析 + Pass string // 代理地址,如: "http://example.com", "http://backend" + Host string // 代理 Host,如: "example.com" + SNI string // 代理 SNI,如: "example.com" + Cache bool // 是否启用缓存 + Buffering bool // 是否启用缓冲 + Resolver []string // 自定义 DNS 解析器配置,如: ["8.8.8.8", "ipv6=off"] + ResolverTimeout time.Duration // DNS 解析超时时间,如: 5 * time.Second + Replaces map[string]string // 响应内容替换,如: map["/old"] = "/new" +} + +// Upstream 上游服务器配置 +type Upstream struct { + Name string // 上游名称,如: "backend" + Servers map[string]string // 上游服务器及权重,如: map["server1"] = "weight=5" + Algo string // 负载均衡算法,如: "least_conn", "ip_hash" + Keepalive int // 保持连接数,如: 32 +} diff --git a/pkg/webserver/types/redirect.go b/pkg/webserver/types/redirect.go new file mode 100644 index 00000000..afd71823 --- /dev/null +++ b/pkg/webserver/types/redirect.go @@ -0,0 +1,19 @@ +package types + +// RedirectType 重定向类型 +type RedirectType string + +const ( + RedirectType404 RedirectType = "404" // 404 重定向 + RedirectTypeHost RedirectType = "host" // 主机名重定向 + RedirectTypeURL RedirectType = "url" // URL 重定向 +) + +// Redirect 重定向配置 +type Redirect struct { + Type RedirectType // 重定向类型 + From string // 源地址,如: "example.com", "http://example.com", "/old" + To string // 目标地址,如: "https://example.com" + KeepURI bool // 是否保持 URI 不变(即保留请求参数) + StatusCode int // 自定义状态码,如: 301, 302, 307, 308,默认 308 +} diff --git a/pkg/webserver/types/vhost.go b/pkg/webserver/types/vhost.go new file mode 100644 index 00000000..f1f554b9 --- /dev/null +++ b/pkg/webserver/types/vhost.go @@ -0,0 +1,136 @@ +package types + +// VhostType 虚拟主机类型 +type VhostType string + +const ( + VhostTypeStatic VhostType = "static" + VhostTypePHP VhostType = "php" + VhostTypeProxy VhostType = "proxy" +) + +// Vhost 虚拟主机完整接口 +type Vhost interface { + VhostCore + VhostSSL + VhostPHP + VhostAdvanced +} + +// VhostCore 核心接口 +type VhostCore interface { + // Enable 取启用状态 + Enable() bool + // SetEnable 设置启用状态及停止页路径 + SetEnable(enable bool, stopPage ...string) error + + // Listen 取监听配置 + Listen() []Listen + // SetListen 设置监听配置 + SetListen(listen []Listen) error + + // ServerName 取服务器名称,如: ["example.com", "www.example.com"] + ServerName() []string + // SetServerName 设置服务器名称 + SetServerName(serverName []string) error + + // Index 取默认首页,如: ["index.php", "index.html"] + Index() []string + // SetIndex 设置默认首页 + SetIndex(index []string) error + + // Root 取网站根目录,如: "/opt/ace/sites/example/public" + Root() string + // SetRoot 设置网站根目录 + SetRoot(root string) error + + // Includes 取包含的文件配置 + Includes() []IncludeFile + // SetIncludes 设置包含的文件配置 + SetIncludes(includes []IncludeFile) error + + // AccessLog 取访问日志路径,如: "/opt/ace/sites/example/log/access.log" + AccessLog() string + // SetAccessLog 设置访问日志路径 + SetAccessLog(accessLog string) error + + // ErrorLog 取错误日志路径,如: "/opt/ace/sites/example/log/error.log" + ErrorLog() string + // SetErrorLog 设置错误日志路径 + SetErrorLog(errorLog string) error + + // Save 保存配置到文件 + Save() error + // Reload 重载配置(重启或重载服务器) + Reload() error + // Reset 重置配置为默认值 + Reset() error +} + +// VhostSSL SSL/TLS 相关接口 +type VhostSSL interface { + // HTTPS 取 HTTPS 启用状态 + HTTPS() bool + // SSLConfig 取 SSL 配置 + SSLConfig() *SSLConfig + // SetSSLConfig 设置 SSL 配置(自动启用 HTTPS) + SetSSLConfig(cfg *SSLConfig) error + // ClearHTTPS 清除 HTTPS 配置 + ClearHTTPS() error +} + +// VhostPHP PHP 相关接口 +type VhostPHP interface { + // PHP 取 PHP 版本,如: 84, 81, 80, 0 表示未启用 PHP + PHP() int + // SetPHP 设置 PHP 版本 + SetPHP(version int) error +} + +// VhostAdvanced 高级功能接口 +type VhostAdvanced interface { + // RateLimit 取限流限速配置 + RateLimit() *RateLimit + // SetRateLimit 设置限流限速配置 + SetRateLimit(limit *RateLimit) error + + // BasicAuth 取基本认证配置 + BasicAuth() map[string]string + // SetBasicAuth 设置基本认证 + SetBasicAuth(auth map[string]string) error +} + +// Listen 监听配置 +type Listen struct { + Address string // 监听地址,如: "80", "0.0.0.0:80", "[::]:443" + Protocol string // 协议类型,如: "http", "https", "http2", "http3" + Options map[string]string // 服务器特定选项,如: map["default_server"] = "true" +} + +// SSLConfig SSL/TLS 配置 +type SSLConfig struct { + Cert string // 证书路径 + Key string // 私钥路径 + Protocols []string // 支持的协议,如: ["TLSv1.2", "TLSv1.3"] + Ciphers string // 加密套件 + + // 高级选项 + HSTS bool // HTTP 严格传输安全 + OCSP bool // OCSP Stapling + HTTPRedirect bool // HTTP 强制跳转 HTTPS + AltSvc string // Alt-Svc 配置,如: 'h3=":443"; ma=86400' +} + +// RateLimit 限流限速配置 +type RateLimit struct { + Rate string // 速率限制,如: "512k", "10r/s" + Burst int // 突发限制 + Concurrent int // 并发连接数限制 + Options map[string]string // 服务器特定选项 +} + +// IncludeFile 包含文件配置 +type IncludeFile struct { + Path string // 文件路径 + Comment []string // 注释说明 +} diff --git a/pkg/webserver/webserver.go b/pkg/webserver/webserver.go new file mode 100644 index 00000000..72f54657 --- /dev/null +++ b/pkg/webserver/webserver.go @@ -0,0 +1,21 @@ +package webserver + +import ( + "fmt" + + "github.com/acepanel/panel/pkg/webserver/apache" + "github.com/acepanel/panel/pkg/webserver/nginx" + "github.com/acepanel/panel/pkg/webserver/types" +) + +// NewVhost 创建虚拟主机管理实例 +func NewVhost(serverType Type, configDir string) (types.Vhost, error) { + switch serverType { + case TypeNginx: + return nginx.NewVhost(configDir) + case TypeApache: + return apache.NewVhost(configDir) + default: + return nil, fmt.Errorf("unsupported server type: %s", serverType) + } +}