diff --git a/internal/biz/database_user.go b/internal/biz/database_user.go index 8f87c0b4..b80ddd7e 100644 --- a/internal/biz/database_user.go +++ b/internal/biz/database_user.go @@ -17,16 +17,16 @@ const ( ) type DatabaseUser struct { - ID uint `gorm:"primaryKey" json:"id"` - ServerID uint `gorm:"not null" json:"server_id"` - Username string `gorm:"not null" json:"username"` - Password string `gorm:"not null" json:"password"` - Host string `gorm:"not null" json:"host"` // 仅 mysql - Status DatabaseUserStatus `gorm:"-:all" json:"status"` // 仅显示 - Privileges map[string][]string `gorm:"-:all" json:"privileges"` // 仅显示 - Remark string `gorm:"not null" json:"remark"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uint `gorm:"primaryKey" json:"id"` + ServerID uint `gorm:"not null" json:"server_id"` + Username string `gorm:"not null" json:"username"` + Password string `gorm:"not null" json:"password"` + Host string `gorm:"not null" json:"host"` // 仅 mysql + Status DatabaseUserStatus `gorm:"-:all" json:"status"` // 仅显示 + Privileges []string `gorm:"-:all" json:"privileges"` // 仅显示 + Remark string `gorm:"not null" json:"remark"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` Server *DatabaseServer `gorm:"foreignKey:ServerID;references:ID" json:"server"` } @@ -59,5 +59,6 @@ type DatabaseUserRepo interface { Update(req *request.DatabaseUserUpdate) error UpdateRemark(req *request.DatabaseUserUpdateRemark) error Delete(id uint) error + DeleteByNames(serverID uint, names []string) error DeleteByServerID(serverID uint) error } diff --git a/internal/data/database.go b/internal/data/database.go index 0733e417..3bf28694 100644 --- a/internal/data/database.go +++ b/internal/data/database.go @@ -78,7 +78,12 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error { return err } if req.CreateUser { - if err = mysql.UserCreate(req.Username, req.Password, req.Host); err != nil { + if err = NewDatabaseUserRepo().Create(&request.DatabaseUserCreate{ + ServerID: req.ServerID, + Username: req.Username, + Password: req.Password, + Host: req.Host, + }); err != nil { return err } } @@ -96,7 +101,12 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error { return err } if req.CreateUser { - if err = postgres.UserCreate(req.Username, req.Password); err != nil { + if err = NewDatabaseUserRepo().Create(&request.DatabaseUserCreate{ + ServerID: req.ServerID, + Username: req.Username, + Password: req.Password, + Host: req.Host, + }); err != nil { return err } } diff --git a/internal/data/database_user.go b/internal/data/database_user.go index e58d03e4..c0b732f6 100644 --- a/internal/data/database_user.go +++ b/internal/data/database_user.go @@ -56,6 +56,7 @@ func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error { return err } + user := new(biz.DatabaseUser) switch server.Type { case biz.DatabaseTypeMysql: mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) @@ -70,6 +71,11 @@ func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error { return err } } + user = &biz.DatabaseUser{ + ServerID: req.ServerID, + Username: req.Username, + Host: req.Host, + } case biz.DatabaseTypePostgresql: postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) if err != nil { @@ -83,12 +89,10 @@ func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error { return err } } - } - - user := &biz.DatabaseUser{ - ServerID: req.ServerID, - Username: req.Username, - Host: req.Host, + user = &biz.DatabaseUser{ + ServerID: req.ServerID, + Username: req.Username, + } } if err = app.Orm.FirstOrInit(user, user).Error; err != nil { @@ -191,6 +195,46 @@ func (r databaseUserRepo) Delete(id uint) error { return app.Orm.Where("id = ?", id).Delete(&biz.DatabaseUser{}).Error } +func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error { + server, err := NewDatabaseServerRepo().Get(serverID) + if err != nil { + return err + } + + switch server.Type { + case biz.DatabaseTypeMysql: + mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) + if err != nil { + return err + } + users := make([]*biz.DatabaseUser, 0) + if err = app.Orm.Where("server_id = ? AND username IN ?", serverID, names).Find(&users).Error; err != nil { + return err + } + for name := range slices.Values(names) { + host := "localhost" + for u := range slices.Values(users) { + if u.Username == name { + host = u.Host + break + } + } + _ = mysql.UserDrop(name, host) + } + case biz.DatabaseTypePostgresql: + postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) + if err != nil { + return err + } + for name := range slices.Values(names) { + _ = postgres.UserDrop(name) + } + } + + return app.Orm.Where("server_id = ? AND username IN ?", serverID, names).Delete(&biz.DatabaseUser{}).Error +} + +// DeleteByServerID 删除指定服务器的所有用户,只是删除面板记录,不会实际删除 func (r databaseUserRepo) DeleteByServerID(serverID uint) error { return app.Orm.Where("server_id = ?", serverID).Delete(&biz.DatabaseUser{}).Error } @@ -225,6 +269,6 @@ func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) { } // 初始化,防止 nil if user.Privileges == nil { - user.Privileges = make(map[string][]string) + user.Privileges = make([]string, 0) } } diff --git a/internal/data/website.go b/internal/data/website.go index 458c06ee..7acd45f3 100644 --- a/internal/data/website.go +++ b/internal/data/website.go @@ -19,7 +19,6 @@ import ( "github.com/TheTNB/panel/internal/http/request" "github.com/TheTNB/panel/pkg/acme" "github.com/TheTNB/panel/pkg/cert" - "github.com/TheTNB/panel/pkg/db" "github.com/TheTNB/panel/pkg/io" "github.com/TheTNB/panel/pkg/nginx" "github.com/TheTNB/panel/pkg/punycode" @@ -507,17 +506,14 @@ func (r *websiteRepo) Delete(req *request.WebsiteDelete) error { _ = io.Remove(website.Path) } if req.DB { - rootPassword, err := NewSettingRepo().Get(biz.SettingKeyMySQLRootPassword) - if err != nil { - return err + repo := NewDatabaseServerRepo() + if mysql, err := repo.GetByName("local_mysql"); err == nil { + _ = NewDatabaseUserRepo().DeleteByNames(mysql.ID, []string{website.Name}) + _ = NewDatabaseRepo().Delete(mysql.ID, website.Name) } - if mysql, err := db.NewMySQL("root", rootPassword, "/tmp/mysql.sock", "unix"); err == nil { - _ = mysql.UserDrop(website.Name, "localhost") - _ = mysql.DatabaseDrop(website.Name) - } - if postgres, err := db.NewPostgres("postgres", "", "127.0.0.1", 5432); err == nil { - _ = postgres.UserDrop(website.Name) - _ = postgres.DatabaseDrop(website.Name) + if postgres, err := repo.GetByName("local_postgresql"); err == nil { + _ = NewDatabaseUserRepo().DeleteByNames(postgres.ID, []string{website.Name}) + _ = NewDatabaseRepo().Delete(postgres.ID, website.Name) } } diff --git a/internal/route/http.go b/internal/route/http.go index 297c16a4..58bbf4c6 100644 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -74,6 +74,7 @@ func Http(r chi.Router) { server := service.NewDatabaseServerService() r.Get("/", server.List) r.Post("/", server.Create) + r.Get("/{id}", server.Get) r.Put("/{id}", server.Update) r.Put("/{id}/remark", server.UpdateRemark) r.Delete("/{id}", server.Delete) @@ -84,6 +85,7 @@ func Http(r chi.Router) { user := service.NewDatabaseUserService() r.Get("/", user.List) r.Post("/", user.Create) + r.Get("/{id}", user.Get) r.Put("/{id}", user.Update) r.Put("/{id}/remark", user.UpdateRemark) r.Delete("/{id}", user.Delete) diff --git a/internal/service/database_server.go b/internal/service/database_server.go index 0869ccdf..321ae4ef 100644 --- a/internal/service/database_server.go +++ b/internal/service/database_server.go @@ -54,6 +54,22 @@ func (s *DatabaseServer) Create(w http.ResponseWriter, r *http.Request) { Success(w, nil) } +func (s *DatabaseServer) Get(w http.ResponseWriter, r *http.Request) { + req, err := Bind[request.ID](r) + if err != nil { + Error(w, http.StatusUnprocessableEntity, "%v", err) + return + } + + server, err := s.databaseServerRepo.Get(req.ID) + if err != nil { + Error(w, http.StatusInternalServerError, "%v", err) + return + } + + Success(w, server) +} + func (s *DatabaseServer) Update(w http.ResponseWriter, r *http.Request) { req, err := Bind[request.DatabaseServerUpdate](r) if err != nil { diff --git a/internal/service/database_user.go b/internal/service/database_user.go index 1ec01d03..c3665ad7 100644 --- a/internal/service/database_user.go +++ b/internal/service/database_user.go @@ -54,6 +54,22 @@ func (s *DatabaseUser) Create(w http.ResponseWriter, r *http.Request) { Success(w, nil) } +func (s *DatabaseUser) Get(w http.ResponseWriter, r *http.Request) { + req, err := Bind[request.ID](r) + if err != nil { + Error(w, http.StatusUnprocessableEntity, "%v", err) + return + } + + user, err := s.databaseUserRepo.Get(req.ID) + if err != nil { + Error(w, http.StatusInternalServerError, "%v", err) + return + } + + Success(w, user) +} + func (s *DatabaseUser) Update(w http.ResponseWriter, r *http.Request) { req, err := Bind[request.DatabaseUserUpdate](r) if err != nil { diff --git a/pkg/db/mysql.go b/pkg/db/mysql.go index a83ed98b..3f07f334 100644 --- a/pkg/db/mysql.go +++ b/pkg/db/mysql.go @@ -3,9 +3,9 @@ package db import ( "database/sql" "fmt" - "strings" - _ "github.com/go-sql-driver/mysql" + "regexp" + "slices" "github.com/TheTNB/panel/pkg/types" ) @@ -122,52 +122,43 @@ func (r *MySQL) PrivilegesGrant(user, database, host string) error { return err } -func (r *MySQL) UserPrivileges(user, host string) (map[string][]string, error) { +func (r *MySQL) PrivilegesRevoke(user, database, host string) error { + _, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host)) + r.flushPrivileges() + return err +} + +func (r *MySQL) UserPrivileges(user, host string) ([]string, error) { rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host)) if err != nil { return nil, err } defer rows.Close() - privileges := make(map[string][]string) + re := regexp.MustCompile(`GRANT\s+ALL PRIVILEGES\s+ON\s+[\x60'"]?([^\s\x60'"]+)[\x60'"]?\.\*\s+TO\s+`) + var databases []string for rows.Next() { var grant string if err = rows.Scan(&grant); err != nil { return nil, err } - if !strings.HasPrefix(grant, "GRANT ") { - continue + + // 使用正则表达式匹配 + matches := re.FindStringSubmatch(grant) + if len(matches) == 2 { + dbName := matches[1] + if dbName != "*" { + databases = append(databases, dbName) + } } - - parts := strings.Split(grant, " ON ") - if len(parts) < 2 { - continue - } - - privList := strings.TrimPrefix(parts[0], "GRANT ") - privs := strings.Split(privList, ", ") - - dbPart := strings.Split(parts[1], " TO")[0] - // *.* 表示全局权限 - if dbPart == "*.*" { - dbPart = "*" - } - - dbPart = strings.Trim(dbPart, "`") - privileges[dbPart] = append(privileges[dbPart], privs...) } if err = rows.Err(); err != nil { return nil, err } - return privileges, nil -} - -func (r *MySQL) PrivilegesRevoke(user, database, host string) error { - _, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host)) - r.flushPrivileges() - return err + slices.Sort(databases) + return slices.Compact(databases), nil } func (r *MySQL) Users() ([]types.MySQLUser, error) { diff --git a/pkg/db/postgres.go b/pkg/db/postgres.go index e0248098..7af3e7d1 100644 --- a/pkg/db/postgres.go +++ b/pkg/db/postgres.go @@ -3,10 +3,8 @@ package db import ( "database/sql" "fmt" - "slices" - "strings" - _ "github.com/lib/pq" + "slices" "github.com/TheTNB/panel/pkg/systemctl" "github.com/TheTNB/panel/pkg/types" @@ -123,14 +121,16 @@ func (r *Postgres) UserPassword(user, password string) error { return err } -func (r *Postgres) UserPrivileges(user string) (map[string][]string, error) { +func (r *Postgres) UserPrivileges(user string) ([]string, error) { query := ` - SELECT - table_catalog as database_name, - string_agg(DISTINCT privilege_type, ',') as privileges - FROM information_schema.role_database_privileges - WHERE grantee = $1 - GROUP BY table_catalog` + SELECT d.datname + FROM pg_catalog.pg_database d + JOIN pg_catalog.pg_roles r ON d.datdba = r.oid + WHERE r.rolname = $1 + AND d.datistemplate = false + AND d.datname NOT IN ('template0', 'template1', 'postgres') + ORDER BY d.datname; + ` rows, err := r.Query(query, user) if err != nil { @@ -138,22 +138,21 @@ func (r *Postgres) UserPrivileges(user string) (map[string][]string, error) { } defer rows.Close() - privileges := make(map[string][]string) + var databases []string for rows.Next() { - var db, privilegeStr string - if err = rows.Scan(&db, &privilegeStr); err != nil { + var dbName string + if err = rows.Scan(&dbName); err != nil { return nil, err } - - privileges[db] = strings.Split(privilegeStr, ",") + databases = append(databases, dbName) } if err = rows.Err(); err != nil { return nil, err } - return privileges, nil + return databases, nil } func (r *Postgres) PrivilegesGrant(user, database string) error { diff --git a/web/src/api/panel/database/index.ts b/web/src/api/panel/database/index.ts index 1c153be0..5362bb3d 100644 --- a/web/src/api/panel/database/index.ts +++ b/web/src/api/panel/database/index.ts @@ -15,6 +15,8 @@ export default { http.Get('/databaseServer', { params: { page, limit } }), // 创建数据库服务器 serverCreate: (data: any) => http.Post('/databaseServer', data), + // 获取数据库服务器 + serverGet: (id: number) => http.Get(`/databaseServer/${id}`), // 更新数据库服务器 serverUpdate: (id: number, data: any) => http.Put(`/databaseServer/${id}`, data), // 删除数据库服务器 @@ -28,6 +30,8 @@ export default { userList: (page: number, limit: number) => http.Get('/databaseUser', { params: { page, limit } }), // 创建数据库用户 userCreate: (data: any) => http.Post('/databaseUser', data), + // 获取数据库用户 + userGet: (id: number) => http.Get(`/databaseUser/${id}`), // 更新数据库用户 userUpdate: (id: number, data: any) => http.Put(`/databaseUser/${id}`, data), // 删除数据库用户 diff --git a/web/src/views/database/CreateDatabaseModal.vue b/web/src/views/database/CreateDatabaseModal.vue index 0d431741..d4be69dc 100644 --- a/web/src/views/database/CreateDatabaseModal.vue +++ b/web/src/views/database/CreateDatabaseModal.vue @@ -28,16 +28,21 @@ const handleCreate = () => { }) } -onMounted(() => { - database.serverList(1, 10000).then((data: any) => { - for (const server of data.items) { - servers.value.push({ - label: server.name, - value: server.id +watch( + () => show.value, + (value) => { + if (value) { + database.serverList(1, 10000).then((data: any) => { + for (const server of data.items) { + servers.value.push({ + label: server.name, + value: server.id + }) + } }) } - }) -}) + } +)