package apache

import (
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/stretchr/testify/suite"

	"github.com/acepanel/panel/pkg/webserver/types"
)

// VhostTestSuite exercises PHPVhost against a throwaway configuration
// directory that is recreated for every test.
type VhostTestSuite struct {
	suite.Suite
	vhost     *PHPVhost
	configDir string
}

// TestVhostTestSuite is the `go test` entry point that runs VhostTestSuite.
func TestVhostTestSuite(t *testing.T) {
	suite.Run(t, &VhostTestSuite{})
}

// SetupTest creates a fresh temporary config directory containing a "vhost"
// subdirectory, then builds a PHPVhost rooted at it. Any failure here aborts
// the test, since nothing else can run without a vhost.
func (s *VhostTestSuite) SetupTest() {
	// Create a temporary configuration directory.
	configDir, err := os.MkdirTemp("", "apache-test-*")
	s.Require().NoError(err)
	s.configDir = configDir

	// Create the vhost directory.
	err = os.MkdirAll(filepath.Join(configDir, "vhost"), 0755)
	s.Require().NoError(err)

	vhost, err := NewPHPVhost(configDir)
	s.Require().NoError(err)
	s.Require().NotNil(vhost)
	s.vhost = vhost
}

// TearDownTest removes the temporary configuration directory created by SetupTest.
func (s *VhostTestSuite) TearDownTest() {
	// Clean up the temporary directory.
	if s.configDir != "" {
		s.NoError(os.RemoveAll(s.configDir))
	}
}

// TestNewVhost checks that NewPHPVhost records the config dir and initialises
// its parsed config and vhost handles.
func (s *VhostTestSuite) TestNewVhost() {
	s.Equal(s.configDir, s.vhost.configDir)
	s.NotNil(s.vhost.config)
	s.NotNil(s.vhost.vhost)
}

// TestEnable toggles the site off and back on, checking that DisableConfName
// is created in, and then removed from, the vhost directory.
func (s *VhostTestSuite) TestEnable() {
	// Enabled by default (no 000-disable.conf present).
	s.True(s.vhost.Enable())

	// Disable the site.
	s.NoError(s.vhost.SetEnable(false))
	s.False(s.vhost.Enable())

	// The disable file must now exist.
	disableFile := filepath.Join(s.configDir, "vhost", DisableConfName)
	s.FileExists(disableFile)

	// Re-enable the site.
	s.NoError(s.vhost.SetEnable(true))
	s.True(s.vhost.Enable())

	// The disable file must have been deleted.
	_, err := os.Stat(disableFile)
	s.ErrorIs(err, os.ErrNotExist)
}

// TestDisableConfigContent checks that the generated disable config answers
// with a 503 via a RewriteRule.
func (s *VhostTestSuite) TestDisableConfigContent() {
	// Disable the site; the file checks below are meaningless if this fails.
	s.Require().NoError(s.vhost.SetEnable(false))

	// Read the generated disable config.
	disableFile := filepath.Join(s.configDir, "vhost", DisableConfName)
	content, err := os.ReadFile(disableFile)
	s.Require().NoError(err)

	// The content must return a 503.
	body := string(content)
	s.Contains(body, "503")
	s.Contains(body, "RewriteRule")
}

// TestServerName round-trips several server names and checks their order is kept.
func (s *VhostTestSuite) TestServerName() {
	names := []string{"example.com", "www.example.com", "api.example.com"}
	s.NoError(s.vhost.SetServerName(names))

	got := s.vhost.ServerName()
	// Require the length so a short result fails cleanly instead of panicking.
	s.Require().Len(got, 3)
	s.Equal(names, got)
}

// TestServerNameEmpty checks that setting an empty server-name list is accepted.
func (s *VhostTestSuite) TestServerNameEmpty() {
	s.NoError(s.vhost.SetServerName([]string{}))
}

// TestRoot round-trips the document root.
func (s *VhostTestSuite) TestRoot() {
	root := "/var/www/html"
	s.NoError(s.vhost.SetRoot(root))
	s.Equal(root, s.vhost.Root())
}

// TestIndex round-trips the index file list.
func (s *VhostTestSuite) TestIndex() {
	index := []string{"index.html", "index.php", "default.html"}
	s.NoError(s.vhost.SetIndex(index))

	got := s.vhost.Index()
	s.Len(got, 3)
	s.Equal(index, got)
}

// TestIndexEmpty checks that an empty index list reads back as nil.
func (s *VhostTestSuite) TestIndexEmpty() {
	s.NoError(s.vhost.SetIndex([]string{}))
	s.Nil(s.vhost.Index())
}

// TestListen round-trips listen addresses and checks their order is kept.
func (s *VhostTestSuite) TestListen() {
	listens := []types.Listen{
		{Address: "*:80"},
		{Address: "*:443"},
	}
	s.NoError(s.vhost.SetListen(listens))

	got := s.vhost.Listen()
	// Require the length so indexing below cannot panic.
	s.Require().Len(got, 2)
	s.Equal("*:80", got[0].Address)
	s.Equal("*:443", got[1].Address)
}

// TestSSL checks that a fresh vhost has SSL disabled and no SSL config.
func (s *VhostTestSuite) TestSSL() {
	s.False(s.vhost.SSL())
	s.Nil(s.vhost.SSLConfig())
}

// TestSetSSLConfig enables SSL and checks the certificate paths and the
// HSTS/OCSP flags read back.
func (s *VhostTestSuite) TestSetSSLConfig() {
	sslConfig := &types.SSLConfig{
		Cert:      "/etc/ssl/cert.pem",
		Key:       "/etc/ssl/key.pem",
		Protocols: []string{"TLSv1.2", "TLSv1.3"},
		HSTS:      true,
		OCSP:      true,
	}
	s.NoError(s.vhost.SetSSLConfig(sslConfig))
	s.True(s.vhost.SSL())

	got := s.vhost.SSLConfig()
	// got is dereferenced below; a nil result must stop the test, not panic.
	s.Require().NotNil(got)
	s.Equal(sslConfig.Cert, got.Cert)
	s.Equal(sslConfig.Key, got.Key)
	s.True(got.HSTS)
	s.True(got.OCSP)
}

// TestSetSSLConfigNil checks that a nil SSL config is rejected.
func (s *VhostTestSuite) TestSetSSLConfigNil() {
	s.Error(s.vhost.SetSSLConfig(nil))
}

// TestClearSSL enables SSL and then checks that ClearSSL turns it off.
func (s *VhostTestSuite) TestClearSSL() {
	sslConfig := &types.SSLConfig{
		Cert: "/etc/ssl/cert.pem",
		Key:  "/etc/ssl/key.pem",
		HSTS: true,
	}
	s.NoError(s.vhost.SetSSLConfig(sslConfig))
	s.True(s.vhost.SSL())

	s.NoError(s.vhost.ClearSSL())
	s.False(s.vhost.SSL())
}

// TestClearHTTPSPreservesOtherHeaders checks that ClearSSL keeps unrelated
// "Header" directives in place.
func (s *VhostTestSuite) TestClearHTTPSPreservesOtherHeaders() {
	// Add a Header directive that is not HSTS.
	s.vhost.vhost.AddDirective("Header", "set", "X-Custom-Header", "value")

	// Enable SSL with HSTS.
	sslConfig := &types.SSLConfig{
		Cert: "/etc/ssl/cert.pem",
		Key:  "/etc/ssl/key.pem",
		HSTS: true,
	}
	s.NoError(s.vhost.SetSSLConfig(sslConfig))

	// Clear HTTPS.
	s.NoError(s.vhost.ClearSSL())

	// The custom Header must still be present; Args[1] holds the header name.
	headers := s.vhost.vhost.GetDirectives("Header")
	s.NotEmpty(headers)

	found := false
	for _, h := range headers {
		if len(h.Args) >= 2 && h.Args[1] == "X-Custom-Header" {
			found = true
			break
		}
	}
	s.True(found, "自定义 Header 应该被保留")
}

// TestPHP checks that PHP starts at 0, becomes non-zero after SetPHP(84),
// and returns to 0 after SetPHP(0).
func (s *VhostTestSuite) TestPHP() {
	s.Equal(uint(0), s.vhost.PHP())

	s.NoError(s.vhost.SetPHP(84))
	s.NotEqual(uint(0), s.vhost.PHP())

	s.NoError(s.vhost.SetPHP(0))
	s.Equal(uint(0), s.vhost.PHP())
}

// TestAccessLog round-trips the access log path.
func (s *VhostTestSuite) TestAccessLog() {
	accessLog := "/var/log/apache/access.log"
	s.NoError(s.vhost.SetAccessLog(accessLog))
	s.Equal(accessLog, s.vhost.AccessLog())
}

// TestErrorLog round-trips the error log path.
func (s *VhostTestSuite) TestErrorLog() {
	errorLog := "/var/log/apache/error.log"
	s.NoError(s.vhost.SetErrorLog(errorLog))
	s.Equal(errorLog, s.vhost.ErrorLog())
}

// TestIncludes round-trips include file paths and checks their order is kept.
func (s *VhostTestSuite) TestIncludes() {
	includes := []types.IncludeFile{
		{Path: "/etc/apache/conf.d/ssl.conf"},
		{Path: "/etc/apache/conf.d/php.conf"},
	}
	s.NoError(s.vhost.SetIncludes(includes))

	got := s.vhost.Includes()
	// Require the length so indexing below cannot panic.
	s.Require().Len(got, 2)
	s.Equal(includes[0].Path, got[0].Path)
	s.Equal(includes[1].Path, got[1].Path)
}

// TestBasicAuth sets, reads back, and clears basic-auth settings.
func (s *VhostTestSuite) TestBasicAuth() {
	s.Nil(s.vhost.BasicAuth())

	auth := map[string]string{
		"realm":     "Test Realm",
		"user_file": "/etc/htpasswd",
	}
	s.NoError(s.vhost.SetBasicAuth(auth))

	got := s.vhost.BasicAuth()
	s.NotNil(got)
	s.Equal(auth["user_file"], got["user_file"])

	s.NoError(s.vhost.ClearBasicAuth())
	s.Nil(s.vhost.BasicAuth())
}

// TestRateLimit sets, reads back, and clears a rate limit.
func (s *VhostTestSuite) TestRateLimit() {
	s.Nil(s.vhost.RateLimit())

	limit := &types.RateLimit{
		Rate: "512",
	}
	s.NoError(s.vhost.SetRateLimit(limit))

	got := s.vhost.RateLimit()
	s.NotNil(got)

	s.NoError(s.vhost.ClearRateLimit())
	s.Nil(s.vhost.RateLimit())
}

// TestReset checks that Reset discards a previously set server name.
func (s *VhostTestSuite) TestReset() {
	s.NoError(s.vhost.SetServerName([]string{"modified.com"}))
	s.NoError(s.vhost.SetRoot("/modified/path"))

	s.NoError(s.vhost.Reset())

	names := s.vhost.ServerName()
	s.NotContains(names, "modified.com")
}

// TestSave checks that Save writes apache.conf into the config dir with the
// current server name.
func (s *VhostTestSuite) TestSave() {
	s.NoError(s.vhost.SetServerName([]string{"save-test.com"}))
	// Reading the file only makes sense if Save succeeded.
	s.Require().NoError(s.vhost.Save())

	// The configuration file must have been written.
	configFile := filepath.Join(s.configDir, "apache.conf")
	content, err := os.ReadFile(configFile)
	s.Require().NoError(err)
	s.Contains(string(content), "save-test.com")
}

// TestExport checks that the exported config contains the server name, the
// document root and a VirtualHost section.
func (s *VhostTestSuite) TestExport() {
	s.NoError(s.vhost.SetServerName([]string{"export-test.com"}))
	s.NoError(s.vhost.SetRoot("/var/www/export-test"))

	content := s.vhost.config.Export()
	s.NotEmpty(content)
	s.Contains(content, "export-test.com")
	s.Contains(content, "/var/www/export-test")
	// Asserting on "" always passes; check for the enclosing section instead.
	s.Contains(content, "<VirtualHost")
}

// TestExportWithSSL checks that enabling SSL emits the engine and certificate
// directives.
func (s *VhostTestSuite) TestExportWithSSL() {
	sslConfig := &types.SSLConfig{
		Cert:      "/etc/ssl/cert.pem",
		Key:       "/etc/ssl/key.pem",
		Protocols: []string{"TLSv1.2", "TLSv1.3"},
	}
	s.NoError(s.vhost.SetSSLConfig(sslConfig))

	content := s.vhost.config.Export()
	s.Contains(content, "SSLEngine on")
	s.Contains(content, "SSLCertificateFile")
	s.Contains(content, "SSLCertificateKeyFile")
}

// TestListenProtocolDetection checks that a *:443 listen address survives
// enabling SSL unchanged.
func (s *VhostTestSuite) TestListenProtocolDetection() {
	listens := []types.Listen{
		{Address: "*:443"},
	}
	s.NoError(s.vhost.SetListen(listens))

	sslConfig := &types.SSLConfig{
		Cert: "/etc/ssl/cert.pem",
		Key:  "/etc/ssl/key.pem",
	}
	s.NoError(s.vhost.SetSSLConfig(sslConfig))

	got := s.vhost.Listen()
	// Require the length so indexing below cannot panic.
	s.Require().Len(got, 1)
	s.Equal("*:443", got[0].Address)
}

// TestDirectoryBlock checks that setting a root produces a Directory section
// in the exported config.
func (s *VhostTestSuite) TestDirectoryBlock() {
	root := "/var/www/test-dir"
	s.NoError(s.vhost.SetRoot(root))

	content := s.vhost.config.Export()
	// Asserting on "" always passes; check for the opening and closing tags.
	s.Contains(content, "<Directory")
	s.Contains(content, "</Directory>")
}

// TestPHPFilesMatchBlock checks the exported configuration after SetPHP(84).
func (s *VhostTestSuite) TestPHPFilesMatchBlock() {
	s.NoError(s.vhost.SetPHP(84))

	content := s.vhost.config.Export()
	s.Contains(content, "