2
0
mirror of https://github.com/acepanel/panel.git synced 2026-02-04 07:57:21 +08:00

refactor: 消除 database_user.go 中的重复代码 (#1149)

This commit is contained in:
Catsayer
2025-12-04 14:43:02 +08:00
committed by GitHub
parent aa181ababc
commit 2c85070921
13 changed files with 333 additions and 394 deletions

View File

@@ -212,9 +212,7 @@ func (s *App) SetRootPassword(w http.ResponseWriter, r *http.Request) {
return
}
} else {
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
defer mysql.Close()
if err = mysql.UserPassword("root", req.Password, "localhost"); err != nil {
service.Error(w, http.StatusInternalServerError, "%v", err)
return

View File

@@ -256,9 +256,7 @@ func (r *backupRepo) createMySQL(to string, name string) error {
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
defer mysql.Close()
if exist, _ := mysql.DatabaseExists(name); !exist {
return errors.New(r.t.Get("database does not exist: %s", name))
}
@@ -302,10 +300,8 @@ func (r *backupRepo) createPostgres(to string, name string) error {
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
if exist, _ := postgres.DatabaseExist(name); !exist {
defer postgres.Close()
if exist, _ := postgres.DatabaseExists(name); !exist {
return errors.New(r.t.Get("database does not exist: %s", name))
}
size, err := postgres.DatabaseSize(name)
@@ -418,9 +414,7 @@ func (r *backupRepo) restoreMySQL(backup, target string) error {
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
defer mysql.Close()
if exist, _ := mysql.DatabaseExists(target); !exist {
return errors.New(r.t.Get("database does not exist: %s", target))
}
@@ -460,10 +454,8 @@ func (r *backupRepo) restorePostgres(backup, target string) error {
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
if exist, _ := postgres.DatabaseExist(target); !exist {
defer postgres.Close()
if exist, _ := postgres.DatabaseExists(target); !exist {
return errors.New(r.t.Get("database does not exist: %s", target))
}

View File

@@ -29,7 +29,7 @@ func NewDatabaseRepo(t *gotext.Locale, db *gorm.DB, server biz.DatabaseServerRep
}
}
func (r databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) {
func (r *databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) {
var databaseServer []*biz.DatabaseServer
if err := r.db.Model(&biz.DatabaseServer{}).Order("id desc").Find(&databaseServer).Error; err != nil {
return nil, 0, err
@@ -37,61 +37,43 @@ func (r databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) {
database := make([]*biz.Database, 0)
for _, server := range databaseServer {
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err == nil {
if databases, err := mysql.Databases(); err == nil {
for item := range slices.Values(databases) {
database = append(database, &biz.Database{
Type: biz.DatabaseTypeMysql,
Name: item.Name,
Server: server.Name,
ServerID: server.ID,
Encoding: item.CharSet,
})
}
}
_ = mysql.Close()
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err == nil {
if databases, err := postgres.Databases(); err == nil {
for item := range slices.Values(databases) {
database = append(database, &biz.Database{
Type: biz.DatabaseTypePostgresql,
Name: item.Name,
Server: server.Name,
ServerID: server.ID,
Encoding: item.Encoding,
Comment: item.Comment,
})
}
}
_ = postgres.Close()
operator, err := r.getOperator(server)
if err != nil {
continue
}
if databases, err := operator.Databases(); err == nil {
for item := range slices.Values(databases) {
database = append(database, &biz.Database{
Type: server.Type,
Name: item.Name,
Server: server.Name,
ServerID: server.ID,
Encoding: item.CharSet,
Comment: item.Comment,
})
}
}
operator.Close()
}
return database[(page-1)*limit:], int64(len(database)), nil
}
func (r databaseRepo) Create(req *request.DatabaseCreate) error {
func (r *databaseRepo) Create(req *request.DatabaseCreate) error {
server, err := r.server.Get(req.ServerID)
if err != nil {
return err
}
operator, err := r.getOperator(server)
if err != nil {
return err
}
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
if req.CreateUser {
if err = r.user.Create(&request.DatabaseUserCreate{
ServerID: req.ServerID,
@@ -102,22 +84,15 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
return err
}
}
if err = mysql.DatabaseCreate(req.Name); err != nil {
if err = operator.DatabaseCreate(req.Name); err != nil {
return err
}
if req.Username != "" {
if err = mysql.PrivilegesGrant(req.Username, req.Name, req.Host); err != nil {
if err = operator.PrivilegesGrant(req.Username, req.Name, req.Host); err != nil {
return err
}
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
if req.CreateUser {
if err = r.user.Create(&request.DatabaseUserCreate{
ServerID: req.ServerID,
@@ -128,15 +103,15 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
return err
}
}
if err = postgres.DatabaseCreate(req.Name); err != nil {
if err = operator.DatabaseCreate(req.Name); err != nil {
return err
}
if req.Username != "" {
if err = postgres.PrivilegesGrant(req.Username, req.Name); err != nil {
if err = operator.PrivilegesGrant(req.Username, req.Name); err != nil {
return err
}
}
if err = postgres.DatabaseComment(req.Name, req.Comment); err != nil {
if err = operator.(*db.Postgres).DatabaseComment(req.Name, req.Comment); err != nil {
return err
}
}
@@ -144,55 +119,56 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
return nil
}
func (r databaseRepo) Delete(serverID uint, name string) error {
func (r *databaseRepo) Delete(serverID uint, name string) error {
server, err := r.server.Get(serverID)
if err != nil {
return err
}
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
return mysql.DatabaseDrop(name)
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
return postgres.DatabaseDrop(name)
operator, err := r.getOperator(server)
if err != nil {
return err
}
return nil
return operator.DatabaseDrop(name)
}
func (r databaseRepo) Comment(req *request.DatabaseComment) error {
func (r *databaseRepo) Comment(req *request.DatabaseComment) error {
server, err := r.server.Get(req.ServerID)
if err != nil {
return err
}
operator, err := r.getOperator(server)
if err != nil {
return err
}
switch server.Type {
case biz.DatabaseTypeMysql:
return errors.New(r.t.Get("mysql not support database comment"))
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
return postgres.DatabaseComment(req.Name, req.Comment)
return operator.(*db.Postgres).DatabaseComment(req.Name, req.Comment)
}
return nil
}
func (r *databaseRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) {
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return nil, err
}
return mysql, nil
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return nil, err
}
return postgres, nil
default:
return nil, fmt.Errorf("unsupported database type: %s", server.Type)
}
}

View File

@@ -28,7 +28,7 @@ func NewDatabaseServerRepo(t *gotext.Locale, db *gorm.DB, log *slog.Logger) biz.
}
}
func (r databaseServerRepo) Count() (int64, error) {
func (r *databaseServerRepo) Count() (int64, error) {
var count int64
if err := r.db.Model(&biz.DatabaseServer{}).Count(&count).Error; err != nil {
return 0, err
@@ -37,7 +37,7 @@ func (r databaseServerRepo) Count() (int64, error) {
return count, nil
}
func (r databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64, error) {
func (r *databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64, error) {
databaseServer := make([]*biz.DatabaseServer, 0)
var total int64
err := r.db.Model(&biz.DatabaseServer{}).Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&databaseServer).Error
@@ -49,7 +49,7 @@ func (r databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64
return databaseServer, total, err
}
func (r databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) {
func (r *databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) {
databaseServer := new(biz.DatabaseServer)
if err := r.db.Where("id = ?", id).First(databaseServer).Error; err != nil {
return nil, err
@@ -60,7 +60,7 @@ func (r databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) {
return databaseServer, nil
}
func (r databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) {
func (r *databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) {
databaseServer := new(biz.DatabaseServer)
if err := r.db.Where("name = ?", name).First(databaseServer).Error; err != nil {
return nil, err
@@ -71,7 +71,7 @@ func (r databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error)
return databaseServer, nil
}
func (r databaseServerRepo) Create(req *request.DatabaseServerCreate) error {
func (r *databaseServerRepo) Create(req *request.DatabaseServerCreate) error {
databaseServer := &biz.DatabaseServer{
Name: req.Name,
Type: biz.DatabaseType(req.Type),
@@ -89,7 +89,7 @@ func (r databaseServerRepo) Create(req *request.DatabaseServerCreate) error {
return r.db.Create(databaseServer).Error
}
func (r databaseServerRepo) Update(req *request.DatabaseServerUpdate) error {
func (r *databaseServerRepo) Update(req *request.DatabaseServerUpdate) error {
server, err := r.Get(req.ID)
if err != nil {
return err
@@ -109,11 +109,11 @@ func (r databaseServerRepo) Update(req *request.DatabaseServerUpdate) error {
return r.db.Save(server).Error
}
func (r databaseServerRepo) UpdateRemark(req *request.DatabaseServerUpdateRemark) error {
func (r *databaseServerRepo) UpdateRemark(req *request.DatabaseServerUpdateRemark) error {
return r.db.Model(&biz.DatabaseServer{}).Where("id = ?", req.ID).Update("remark", req.Remark).Error
}
func (r databaseServerRepo) Delete(id uint) error {
func (r *databaseServerRepo) Delete(id uint) error {
if err := r.ClearUsers(id); err != nil {
return err
}
@@ -122,11 +122,11 @@ func (r databaseServerRepo) Delete(id uint) error {
}
// ClearUsers 删除指定服务器的所有用户,只是删除面板记录,不会实际删除
func (r databaseServerRepo) ClearUsers(serverID uint) error {
func (r *databaseServerRepo) ClearUsers(serverID uint) error {
return r.db.Where("server_id = ?", serverID).Delete(&biz.DatabaseUser{}).Error
}
func (r databaseServerRepo) Sync(id uint) error {
func (r *databaseServerRepo) Sync(id uint) error {
server, err := r.Get(id)
if err != nil {
return err
@@ -137,16 +137,15 @@ func (r databaseServerRepo) Sync(id uint) error {
return err
}
operator, err := r.getOperator(server)
if err != nil {
return err
}
defer operator.Close()
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
allUsers, err := mysql.Users()
allUsers, err := operator.Users()
if err != nil {
return err
}
@@ -166,24 +165,17 @@ func (r databaseServerRepo) Sync(id uint) error {
}
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
allUsers, err := postgres.Users()
allUsers, err := operator.Users()
if err != nil {
return err
}
for user := range slices.Values(allUsers) {
if !slices.ContainsFunc(users, func(a *biz.DatabaseUser) bool {
return a.Username == user.Role
}) && !slices.Contains([]string{"postgres"}, user.Role) {
return a.Username == user.User
}) && !slices.Contains([]string{"postgres"}, user.User) {
newUser := &biz.DatabaseUser{
ServerID: id,
Username: user.Role,
Username: user.User,
Remark: r.t.Get("sync from server %s", server.Name),
}
r.db.Create(newUser)
@@ -195,26 +187,19 @@ func (r databaseServerRepo) Sync(id uint) error {
}
// checkServer 检查服务器连接
func (r databaseServerRepo) checkServer(server *biz.DatabaseServer) bool {
func (r *databaseServerRepo) checkServer(server *biz.DatabaseServer) bool {
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
case biz.DatabaseTypeMysql, biz.DatabaseTypePostgresql:
operator, err := r.getOperator(server)
if err == nil {
_ = mysql.Close()
server.Status = biz.DatabaseServerStatusValid
return true
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err == nil {
_ = postgres.Close()
operator.Close()
server.Status = biz.DatabaseServerStatusValid
return true
}
case biz.DatabaseTypeRedis:
redis, err := db.NewRedis(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err == nil {
_ = redis.Close()
redis.Close()
server.Status = biz.DatabaseServerStatusValid
return true
}
@@ -223,3 +208,22 @@ func (r databaseServerRepo) checkServer(server *biz.DatabaseServer) bool {
server.Status = biz.DatabaseServerStatusInvalid
return false
}
func (r *databaseServerRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) {
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return nil, err
}
return mysql, nil
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return nil, err
}
return postgres, nil
default:
return nil, fmt.Errorf("unsupported database type: %s", server.Type)
}
}

View File

@@ -26,7 +26,7 @@ func NewDatabaseUserRepo(t *gotext.Locale, db *gorm.DB, server biz.DatabaseServe
}
}
func (r databaseUserRepo) Count() (int64, error) {
func (r *databaseUserRepo) Count() (int64, error) {
var count int64
if err := r.db.Model(&biz.DatabaseUser{}).Count(&count).Error; err != nil {
return 0, err
@@ -35,7 +35,7 @@ func (r databaseUserRepo) Count() (int64, error) {
return count, nil
}
func (r databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, error) {
func (r *databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, error) {
user := make([]*biz.DatabaseUser, 0)
var total int64
err := r.db.Model(&biz.DatabaseUser{}).Preload("Server").Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&user).Error
@@ -47,7 +47,7 @@ func (r databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, er
return user, total, err
}
func (r databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) {
func (r *databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) {
user := new(biz.DatabaseUser)
if err := r.db.Preload("Server").Where("id = ?", id).First(user).Error; err != nil {
return nil, err
@@ -58,74 +58,49 @@ func (r databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) {
return user, nil
}
func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
func (r *databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
server, err := r.server.Get(req.ServerID)
if err != nil {
return err
}
user := new(biz.DatabaseUser)
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
operator, err := r.getOperator(server)
if err != nil {
return err
}
defer operator.Close()
// 创建用户
if err = operator.UserCreate(req.Username, req.Password); err != nil {
return err
}
// 创建数据库并授权
for name := range slices.Values(req.Privileges) {
if err = operator.DatabaseCreate(name); err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
if err = mysql.UserCreate(req.Username, req.Password, req.Host); err != nil {
if err = operator.PrivilegesGrant(req.Username, name); err != nil {
return err
}
for name := range slices.Values(req.Privileges) {
if err = mysql.DatabaseCreate(name); err != nil {
return err
}
if err = mysql.PrivilegesGrant(req.Username, name, req.Host); err != nil {
return err
}
}
user = &biz.DatabaseUser{
ServerID: req.ServerID,
Username: req.Username,
Host: req.Host,
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
if err = postgres.UserCreate(req.Username, req.Password); err != nil {
return err
}
for name := range slices.Values(req.Privileges) {
if err = postgres.DatabaseCreate(name); err != nil {
return err
}
if err = postgres.PrivilegesGrant(req.Username, name); err != nil {
return err
}
}
user = &biz.DatabaseUser{
ServerID: req.ServerID,
Username: req.Username,
}
}
user := &biz.DatabaseUser{
ServerID: req.ServerID,
Username: req.Username,
Host: req.Host,
Password: req.Password,
Remark: req.Remark,
}
if err = r.db.FirstOrInit(user, user).Error; err != nil {
return err
}
user.Password = req.Password
user.Remark = req.Remark
return r.db.Save(user).Error
}
func (r databaseUserRepo) Update(req *request.DatabaseUserUpdate) error {
func (r *databaseUserRepo) Update(req *request.DatabaseUserUpdate) error {
user, err := r.Get(req.ID)
if err != nil {
return err
@@ -136,58 +111,36 @@ func (r databaseUserRepo) Update(req *request.DatabaseUserUpdate) error {
return err
}
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
operator, err := r.getOperator(server)
if err != nil {
return err
}
defer operator.Close()
// 更新密码
if req.Password != "" {
if err = operator.UserPassword(user.Username, req.Password); err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
if req.Password != "" {
if err = mysql.UserPassword(user.Username, req.Password, user.Host); err != nil {
return err
}
}
for name := range slices.Values(req.Privileges) {
if err = mysql.DatabaseCreate(name); err != nil {
return err
}
if err = mysql.PrivilegesGrant(user.Username, name, user.Host); err != nil {
return err
}
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
user.Password = req.Password
}
// 创建数据库并授权
for name := range slices.Values(req.Privileges) {
if err = operator.DatabaseCreate(name); err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
if req.Password != "" {
if err = postgres.UserPassword(user.Username, req.Password); err != nil {
return err
}
}
for name := range slices.Values(req.Privileges) {
if err = postgres.DatabaseCreate(name); err != nil {
return err
}
if err = postgres.PrivilegesGrant(user.Username, name); err != nil {
return err
}
if err = operator.PrivilegesGrant(user.Username, name); err != nil {
return err
}
}
user.Password = req.Password
user.Remark = req.Remark
return r.db.Save(user).Error
}
func (r databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) error {
func (r *databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) error {
user, err := r.Get(req.ID)
if err != nil {
return err
@@ -198,7 +151,7 @@ func (r databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) er
return r.db.Save(user).Error
}
func (r databaseUserRepo) Delete(id uint) error {
func (r *databaseUserRepo) Delete(id uint) error {
user, err := r.Get(id)
if err != nil {
return err
@@ -209,45 +162,31 @@ func (r databaseUserRepo) Delete(id uint) error {
return err
}
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
_ = mysql.UserDrop(user.Username, user.Host)
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
_ = postgres.UserDrop(user.Username)
operator, err := r.getOperator(server)
if err != nil {
return err
}
defer operator.Close()
_ = operator.UserDrop(user.Username)
return r.db.Where("id = ?", id).Delete(&biz.DatabaseUser{}).Error
}
func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
func (r *databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
server, err := r.server.Get(serverID)
if err != nil {
return err
}
operator, err := r.getOperator(server)
if err != nil {
return err
}
defer operator.Close()
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return err
}
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
users := make([]*biz.DatabaseUser, 0)
if err = r.db.Where("server_id = ? AND username IN ?", serverID, names).Find(&users).Error; err != nil {
return err
@@ -260,57 +199,42 @@ func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
break
}
}
_ = mysql.UserDrop(name, host)
_ = operator.UserDrop(name, host)
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return err
}
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
for name := range slices.Values(names) {
_ = postgres.UserDrop(name)
_ = operator.UserDrop(name)
}
}
return r.db.Where("server_id = ? AND username IN ?", serverID, names).Delete(&biz.DatabaseUser{}).Error
}
func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) {
func (r *databaseUserRepo) fillUser(user *biz.DatabaseUser) {
server, err := r.server.Get(user.ServerID)
if err == nil {
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err == nil {
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
privileges, _ := mysql.UserPrivileges(user.Username, user.Host)
operator, err := r.getOperator(server)
if err == nil {
defer operator.Close()
switch server.Type {
case biz.DatabaseTypeMysql:
privileges, _ := operator.UserPrivileges(user.Username, user.Host)
user.Privileges = privileges
}
if mysql2, err := db.NewMySQL(user.Username, user.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)); err == nil {
_ = mysql2.Close()
user.Status = biz.DatabaseUserStatusValid
} else {
user.Status = biz.DatabaseUserStatusInvalid
}
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err == nil {
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
privileges, _ := postgres.UserPrivileges(user.Username)
if mysql2, err := db.NewMySQL(user.Username, user.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)); err == nil {
mysql2.Close()
user.Status = biz.DatabaseUserStatusValid
} else {
user.Status = biz.DatabaseUserStatusInvalid
}
case biz.DatabaseTypePostgresql:
privileges, _ := operator.UserPrivileges(user.Username)
user.Privileges = privileges
}
if postgres2, err := db.NewPostgres(user.Username, user.Password, server.Host, server.Port); err == nil {
_ = postgres2.Close()
user.Status = biz.DatabaseUserStatusValid
} else {
user.Status = biz.DatabaseUserStatusInvalid
if postgres2, err := db.NewPostgres(user.Username, user.Password, server.Host, server.Port); err == nil {
postgres2.Close()
user.Status = biz.DatabaseUserStatusValid
} else {
user.Status = biz.DatabaseUserStatusInvalid
}
}
}
}
@@ -319,3 +243,22 @@ func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) {
user.Privileges = make([]string, 0)
}
}
func (r *databaseUserRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) {
switch server.Type {
case biz.DatabaseTypeMysql:
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
if err != nil {
return nil, err
}
return mysql, nil
case biz.DatabaseTypePostgresql:
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
if err != nil {
return nil, err
}
return postgres, nil
default:
return nil, fmt.Errorf("unsupported database type: %s", server.Type)
}
}

View File

@@ -145,9 +145,7 @@ func (s *HomeService) CountInfo(w http.ResponseWriter, r *http.Request) {
rootPassword, _ := s.settingRepo.Get(biz.SettingKeyMySQLRootPassword)
mysql, err := db.NewMySQL("root", rootPassword, "/tmp/mysql.sock", "unix")
if err == nil {
defer func(mysql *db.MySQL) {
_ = mysql.Close()
}(mysql)
defer mysql.Close()
databases, err := mysql.Databases()
if err == nil {
databaseCount += len(databases)
@@ -157,9 +155,7 @@ func (s *HomeService) CountInfo(w http.ResponseWriter, r *http.Request) {
if postgresqlInstalled {
postgres, err := db.NewPostgres("postgres", "", "127.0.0.1", 5432)
if err == nil {
defer func(postgres *db.Postgres) {
_ = postgres.Close()
}(postgres)
defer postgres.Close()
databases, err := postgres.Databases()
if err == nil {
databaseCount += len(databases)

29
pkg/db/db.go Normal file
View File

@@ -0,0 +1,29 @@
package db
import "database/sql"
type Operator interface {
Close()
Ping() error
Query(query string, args ...any) (*sql.Rows, error)
QueryRow(query string, args ...any) *sql.Row
Exec(query string, args ...any) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
DatabaseCreate(name string) error
DatabaseDrop(name string) error
DatabaseExists(name string) (bool, error)
DatabaseSize(name string) (int64, error)
UserCreate(user, password string, host ...string) error
UserDrop(user string, host ...string) error
UserPassword(user, password string, host ...string) error
UserPrivileges(user string, host ...string) ([]string, error)
PrivilegesGrant(user, database string, host ...string) error
PrivilegesRevoke(user, database string, host ...string) error
Users() ([]User, error)
Databases() ([]Database, error)
}

View File

@@ -7,8 +7,6 @@ import (
"slices"
_ "github.com/go-sql-driver/mysql"
"github.com/acepanel/panel/pkg/types"
)
type MySQL struct {
@@ -18,7 +16,7 @@ type MySQL struct {
address string
}
func NewMySQL(username, password, address string, typ ...string) (*MySQL, error) {
func NewMySQL(username, password, address string, typ ...string) (Operator, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s)/", username, password, address)
if len(typ) > 0 && typ[0] == "unix" {
dsn = fmt.Sprintf("%s:%s@unix(%s)/", username, password, address)
@@ -38,8 +36,8 @@ func NewMySQL(username, password, address string, typ ...string) (*MySQL, error)
}, nil
}
func (r *MySQL) Close() error {
return r.db.Close()
func (r *MySQL) Close() {
_ = r.db.Close()
}
func (r *MySQL) Ping() error {
@@ -101,44 +99,43 @@ func (r *MySQL) DatabaseSize(name string) (int64, error) {
return size, err
}
func (r *MySQL) UserCreate(user, password, host string) error {
_, err := r.Exec(fmt.Sprintf("CREATE USER IF NOT EXISTS '%s'@'%s' IDENTIFIED BY '%s'", user, host, password))
func (r *MySQL) UserCreate(user, password string, host ...string) error {
if len(host) == 0 {
host = append(host, "%")
}
_, err := r.Exec(fmt.Sprintf("CREATE USER IF NOT EXISTS '%s'@'%s' IDENTIFIED BY '%s'", user, host[0], password))
r.flushPrivileges()
return err
}
func (r *MySQL) UserDrop(user, host string) error {
_, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%s'", user, host))
func (r *MySQL) UserDrop(user string, host ...string) error {
if len(host) == 0 {
host = append(host, "%")
}
_, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%s'", user, host[0]))
r.flushPrivileges()
return err
}
func (r *MySQL) UserPassword(user, password, host string) error {
_, err := r.Exec(fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", user, host, password))
func (r *MySQL) UserPassword(user, password string, host ...string) error {
if len(host) == 0 {
host = append(host, "%")
}
_, err := r.Exec(fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", user, host[0], password))
r.flushPrivileges()
return err
}
func (r *MySQL) PrivilegesGrant(user, database, host string) error {
_, err := r.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s'", database, user, host))
r.flushPrivileges()
return err
}
func (r *MySQL) UserPrivileges(user string, host ...string) ([]string, error) {
if len(host) == 0 {
host = append(host, "%")
}
func (r *MySQL) PrivilegesRevoke(user, database, host string) error {
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host))
r.flushPrivileges()
return err
}
func (r *MySQL) UserPrivileges(user, host string) ([]string, error) {
rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host))
rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host[0]))
if err != nil {
return nil, err
}
defer func(rows *sql.Rows) {
_ = rows.Close()
}(rows)
defer func(rows *sql.Rows) { _ = rows.Close() }(rows)
re := regexp.MustCompile(`GRANT\s+ALL PRIVILEGES\s+ON\s+[\x60'"]?([^\s\x60'"]+)[\x60'"]?\.\*\s+TO\s+`)
var databases []string
@@ -166,16 +163,32 @@ func (r *MySQL) UserPrivileges(user, host string) ([]string, error) {
return slices.Compact(databases), nil
}
func (r *MySQL) Users() ([]types.MySQLUser, error) {
func (r *MySQL) PrivilegesGrant(user, database string, host ...string) error {
if len(host) == 0 {
host = append(host, "%")
}
_, err := r.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s'", database, user, host[0]))
r.flushPrivileges()
return err
}
func (r *MySQL) PrivilegesRevoke(user, database string, host ...string) error {
if len(host) == 0 {
host = append(host, "%")
}
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host[0]))
r.flushPrivileges()
return err
}
func (r *MySQL) Users() ([]User, error) {
rows, err := r.Query("SELECT user, host FROM mysql.user")
if err != nil {
return nil, err
}
defer func(rows *sql.Rows) {
_ = rows.Close()
}(rows)
defer func(rows *sql.Rows) { _ = rows.Close() }(rows)
var users []types.MySQLUser
var users []User
for rows.Next() {
var user, host string
if err := rows.Scan(&user, &host); err != nil {
@@ -186,7 +199,7 @@ func (r *MySQL) Users() ([]types.MySQLUser, error) {
continue
}
users = append(users, types.MySQLUser{
users = append(users, User{
User: user,
Host: host,
Grants: grants,
@@ -196,7 +209,7 @@ func (r *MySQL) Users() ([]types.MySQLUser, error) {
return users, nil
}
func (r *MySQL) Databases() ([]types.MySQLDatabase, error) {
func (r *MySQL) Databases() ([]Database, error) {
query := `
SELECT
SCHEMA_NAME,
@@ -214,9 +227,9 @@ func (r *MySQL) Databases() ([]types.MySQLDatabase, error) {
_ = rows.Close()
}(rows)
var databases []types.MySQLDatabase
var databases []Database
for rows.Next() {
var db types.MySQLDatabase
var db Database
if err = rows.Scan(&db.Name, &db.CharSet, &db.Collation); err != nil {
return nil, err
}

View File

@@ -9,7 +9,6 @@ import (
_ "github.com/lib/pq"
"github.com/acepanel/panel/pkg/systemctl"
"github.com/acepanel/panel/pkg/types"
)
type Postgres struct {
@@ -20,7 +19,7 @@ type Postgres struct {
port uint
}
func NewPostgres(username, password, address string, port uint) (*Postgres, error) {
func NewPostgres(username, password, address string, port uint) (Operator, error) {
username = strings.ReplaceAll(username, `'`, `\'`)
password = strings.ReplaceAll(password, `'`, `\'`)
dsn := fmt.Sprintf(`host=%s port=%d user='%s' password='%s' dbname=postgres sslmode=disable`, address, port, username, password)
@@ -46,8 +45,8 @@ func NewPostgres(username, password, address string, port uint) (*Postgres, erro
}, nil
}
func (r *Postgres) Close() error {
return r.db.Close()
func (r *Postgres) Close() {
_ = r.db.Close()
}
func (r *Postgres) Ping() error {
@@ -72,7 +71,7 @@ func (r *Postgres) Prepare(query string) (*sql.Stmt, error) {
func (r *Postgres) DatabaseCreate(name string) error {
// postgres 不支持 CREATE DATABASE IF NOT EXISTS但是为了保持与 MySQL 一致,先检查数据库是否存在
exist, err := r.DatabaseExist(name)
exist, err := r.DatabaseExists(name)
if err != nil {
return err
}
@@ -88,7 +87,7 @@ func (r *Postgres) DatabaseDrop(name string) error {
return err
}
func (r *Postgres) DatabaseExist(name string) (bool, error) {
func (r *Postgres) DatabaseExists(name string) (bool, error) {
var count int
if err := r.QueryRow("SELECT COUNT(*) FROM pg_database WHERE datname = $1", name).Scan(&count); err != nil {
return false, err
@@ -110,7 +109,7 @@ func (r *Postgres) DatabaseComment(name, comment string) error {
return err
}
func (r *Postgres) UserCreate(user, password string) error {
func (r *Postgres) UserCreate(user, password string, host ...string) error {
_, err := r.Exec(fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'", user, password))
if err != nil {
return err
@@ -119,7 +118,7 @@ func (r *Postgres) UserCreate(user, password string) error {
return nil
}
func (r *Postgres) UserDrop(user string) error {
func (r *Postgres) UserDrop(user string, host ...string) error {
_, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS %s", user))
if err != nil {
return err
@@ -128,12 +127,12 @@ func (r *Postgres) UserDrop(user string) error {
return systemctl.Reload("postgresql")
}
func (r *Postgres) UserPassword(user, password string) error {
func (r *Postgres) UserPassword(user, password string, host ...string) error {
_, err := r.Exec(fmt.Sprintf("ALTER USER %s WITH PASSWORD '%s'", user, password))
return err
}
func (r *Postgres) UserPrivileges(user string) ([]string, error) {
func (r *Postgres) UserPrivileges(user string, host ...string) ([]string, error) {
query := `
SELECT d.datname
FROM pg_catalog.pg_database d
@@ -169,7 +168,7 @@ func (r *Postgres) UserPrivileges(user string) ([]string, error) {
return databases, nil
}
func (r *Postgres) PrivilegesGrant(user, database string) error {
func (r *Postgres) PrivilegesGrant(user, database string, host ...string) error {
if _, err := r.Exec(fmt.Sprintf("ALTER DATABASE %s OWNER TO %s", database, user)); err != nil {
return err
}
@@ -180,12 +179,12 @@ func (r *Postgres) PrivilegesGrant(user, database string) error {
return nil
}
func (r *Postgres) PrivilegesRevoke(user, database string) error {
func (r *Postgres) PrivilegesRevoke(user, database string, host ...string) error {
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s", database, user))
return err
}
func (r *Postgres) Users() ([]types.PostgresUser, error) {
func (r *Postgres) Users() ([]User, error) {
query := `
SELECT rolname,
rolsuper,
@@ -204,11 +203,11 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) {
_ = rows.Close()
}(rows)
var users []types.PostgresUser
var users []User
for rows.Next() {
var user types.PostgresUser
var user User
var super, canCreateRole, canCreateDb, replication, bypassRls bool
if err = rows.Scan(&user.Role, &super, &canCreateRole, &canCreateDb, &replication, &bypassRls); err != nil {
if err = rows.Scan(&user.User, &super, &canCreateRole, &canCreateDb, &replication, &bypassRls); err != nil {
return nil, err
}
@@ -221,12 +220,12 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) {
}
for perm, enabled := range permissions {
if enabled {
user.Attributes = append(user.Attributes, perm)
user.Grants = append(user.Grants, perm)
}
}
if len(user.Attributes) == 0 {
user.Attributes = append(user.Attributes, "")
if len(user.Grants) == 0 {
user.Grants = append(user.Grants, "None")
}
users = append(users, user)
@@ -239,7 +238,7 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) {
return users, nil
}
func (r *Postgres) Databases() ([]types.PostgresDatabase, error) {
func (r *Postgres) Databases() ([]Database, error) {
query := `
SELECT
d.datname,
@@ -257,10 +256,10 @@ func (r *Postgres) Databases() ([]types.PostgresDatabase, error) {
_ = rows.Close()
}(rows)
var databases []types.PostgresDatabase
var databases []Database
for rows.Next() {
var db types.PostgresDatabase
if err := rows.Scan(&db.Name, &db.Owner, &db.Encoding, &db.Comment); err != nil {
var db Database
if err := rows.Scan(&db.Name, &db.Owner, &db.CharSet, &db.Comment); err != nil {
return nil, err
}
if slices.Contains([]string{"template0", "template1", "postgres"}, db.Name) {

View File

@@ -39,8 +39,8 @@ func NewRedis(username, password, address string) (*Redis, error) {
}, nil
}
func (r *Redis) Close() error {
return r.conn.Close()
func (r *Redis) Close() {
_ = r.conn.Close()
}
func (r *Redis) Exec(command string, args ...any) (any, error) {

15
pkg/db/types.go Normal file
View File

@@ -0,0 +1,15 @@
package db
type User struct {
User string `json:"user"` // 用户名PG 里面对应 Role
Host string `json:"host"` // 主机PG 这个字段为空
Grants []string `json:"grants"` // 权限列表
}
type Database struct {
Name string `json:"name"` // 数据库名
Owner string `json:"owner"` // 所有者MySQL 这个字段为空
CharSet string `json:"char_set"` // 字符集PG 里面对应 Encoding
Collation string `json:"collation"` // 校对集PG 这个字段为空
Comment string `json:"comment"`
}

View File

@@ -1,13 +0,0 @@
package types
type MySQLUser struct {
User string `json:"user"`
Host string `json:"host"`
Grants []string `json:"grants"`
}
type MySQLDatabase struct {
Name string `json:"name"`
CharSet string `json:"char_set"`
Collation string `json:"collation"`
}

View File

@@ -1,13 +0,0 @@
package types
type PostgresUser struct {
Role string `json:"role"`
Attributes []string `json:"attributes"`
}
type PostgresDatabase struct {
Name string `json:"name"`
Owner string `json:"owner"`
Encoding string `json:"encoding"`
Comment string `json:"comment"`
}