From 2c85070921ed28801c44dac29879235990372736 Mon Sep 17 00:00:00 2001 From: Catsayer <120027958+Catsayer-Chan@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:43:02 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=B6=88=E9=99=A4=20database=5Fuse?= =?UTF-8?q?r.go=20=E4=B8=AD=E7=9A=84=E9=87=8D=E5=A4=8D=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=20(#1149)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/apps/mysql/app.go | 4 +- internal/data/backup.go | 20 +-- internal/data/database.go | 144 +++++++---------- internal/data/database_server.go | 86 +++++----- internal/data/database_user.go | 261 ++++++++++++------------------- internal/service/home.go | 8 +- pkg/db/db.go | 29 ++++ pkg/db/mysql.go | 85 +++++----- pkg/db/postgres.go | 45 +++--- pkg/db/redis.go | 4 +- pkg/db/types.go | 15 ++ pkg/types/mysql.go | 13 -- pkg/types/postgres.go | 13 -- 13 files changed, 333 insertions(+), 394 deletions(-) create mode 100644 pkg/db/db.go create mode 100644 pkg/db/types.go delete mode 100644 pkg/types/mysql.go delete mode 100644 pkg/types/postgres.go diff --git a/internal/apps/mysql/app.go b/internal/apps/mysql/app.go index 9d13efbb..9209d109 100644 --- a/internal/apps/mysql/app.go +++ b/internal/apps/mysql/app.go @@ -212,9 +212,7 @@ func (s *App) SetRootPassword(w http.ResponseWriter, r *http.Request) { return } } else { - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) + defer mysql.Close() if err = mysql.UserPassword("root", req.Password, "localhost"); err != nil { service.Error(w, http.StatusInternalServerError, "%v", err) return diff --git a/internal/data/backup.go b/internal/data/backup.go index 679d0754..7c05bb59 100644 --- a/internal/data/backup.go +++ b/internal/data/backup.go @@ -256,9 +256,7 @@ func (r *backupRepo) createMySQL(to string, name string) error { if err != nil { return err } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) + defer mysql.Close() if exist, _ := mysql.DatabaseExists(name); !exist { return errors.New(r.t.Get("database does not exist: %s", name)) } @@ -302,10 +300,8 @@ func (r *backupRepo) createPostgres(to string, name string) error { if err != nil { return err } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - if exist, _ := postgres.DatabaseExist(name); !exist { + defer postgres.Close() + if exist, _ := postgres.DatabaseExists(name); !exist { return errors.New(r.t.Get("database does not exist: %s", name)) } size, err := postgres.DatabaseSize(name) @@ -418,9 +414,7 @@ func (r *backupRepo) restoreMySQL(backup, target string) error { if err != nil { return err } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) + defer mysql.Close() if exist, _ := mysql.DatabaseExists(target); !exist { return errors.New(r.t.Get("database does not exist: %s", target)) } @@ -460,10 +454,8 @@ func (r *backupRepo) restorePostgres(backup, target string) error { if err != nil { return err } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - if exist, _ := postgres.DatabaseExist(target); !exist { + defer postgres.Close() + if exist, _ := postgres.DatabaseExists(target); !exist { return errors.New(r.t.Get("database does not exist: %s", target)) } diff --git a/internal/data/database.go b/internal/data/database.go index 039c2981..8e2dc89c 100644 --- a/internal/data/database.go +++ b/internal/data/database.go @@ -29,7 +29,7 @@ func NewDatabaseRepo(t *gotext.Locale, db *gorm.DB, server biz.DatabaseServerRep } } -func (r databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) { +func (r *databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) { var databaseServer []*biz.DatabaseServer if err := r.db.Model(&biz.DatabaseServer{}).Order("id desc").Find(&databaseServer).Error; err != nil { return nil, 0, err @@ -37,61 +37,43 @@ func (r databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) { database := make([]*biz.Database, 0) for _, server := range databaseServer { - switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err == nil { - if databases, err := mysql.Databases(); err == nil { - for item := range slices.Values(databases) { - database = append(database, &biz.Database{ - Type: biz.DatabaseTypeMysql, - Name: item.Name, - Server: server.Name, - ServerID: server.ID, - Encoding: item.CharSet, - }) - } - } - _ = mysql.Close() - } - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err == nil { - if databases, err := postgres.Databases(); err == nil { - for item := range slices.Values(databases) { - database = append(database, &biz.Database{ - Type: biz.DatabaseTypePostgresql, - Name: item.Name, - Server: server.Name, - ServerID: server.ID, - Encoding: item.Encoding, - Comment: item.Comment, - }) - } - } - _ = postgres.Close() + operator, err := r.getOperator(server) + if err != nil { + continue + } + + if databases, err := operator.Databases(); err == nil { + for item := range slices.Values(databases) { + database = append(database, &biz.Database{ + Type: server.Type, + Name: item.Name, + Server: server.Name, + ServerID: server.ID, + Encoding: item.CharSet, + Comment: item.Comment, + }) } } + + operator.Close() } return database[(page-1)*limit:], int64(len(database)), nil } -func (r databaseRepo) Create(req *request.DatabaseCreate) error { +func (r *databaseRepo) Create(req *request.DatabaseCreate) error { server, err := r.server.Get(req.ServerID) if err != nil { return err } + operator, err := r.getOperator(server) + if err != nil { + return err + } + switch server.Type { case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { - return err - } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) if req.CreateUser { if err = r.user.Create(&request.DatabaseUserCreate{ ServerID: req.ServerID, @@ -102,22 +84,15 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error { return err } } - if err = mysql.DatabaseCreate(req.Name); err != nil { + if err = operator.DatabaseCreate(req.Name); err != nil { return err } if req.Username != "" { - if err = mysql.PrivilegesGrant(req.Username, req.Name, req.Host); err != nil { + if err = operator.PrivilegesGrant(req.Username, req.Name, req.Host); err != nil { return err } } case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) if req.CreateUser { if err = r.user.Create(&request.DatabaseUserCreate{ ServerID: req.ServerID, @@ -128,15 +103,15 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error { return err } } - if err = postgres.DatabaseCreate(req.Name); err != nil { + if err = operator.DatabaseCreate(req.Name); err != nil { return err } if req.Username != "" { - if err = postgres.PrivilegesGrant(req.Username, req.Name); err != nil { + if err = operator.PrivilegesGrant(req.Username, req.Name); err != nil { return err } } - if err = postgres.DatabaseComment(req.Name, req.Comment); err != nil { + if err = operator.(*db.Postgres).DatabaseComment(req.Name, req.Comment); err != nil { return err } } @@ -144,55 +119,56 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error { return nil } -func (r databaseRepo) Delete(serverID uint, name string) error { +func (r *databaseRepo) Delete(serverID uint, name string) error { server, err := r.server.Get(serverID) if err != nil { return err } - switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { - return err - } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) - return mysql.DatabaseDrop(name) - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - return postgres.DatabaseDrop(name) + operator, err := r.getOperator(server) + if err != nil { + return err } - return nil + return operator.DatabaseDrop(name) } -func (r databaseRepo) Comment(req *request.DatabaseComment) error { +func (r *databaseRepo) Comment(req *request.DatabaseComment) error { server, err := r.server.Get(req.ServerID) if err != nil { return err } + operator, err := r.getOperator(server) + if err != nil { + return err + } + switch server.Type { case biz.DatabaseTypeMysql: return errors.New(r.t.Get("mysql not support database comment")) case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - return postgres.DatabaseComment(req.Name, req.Comment) + return operator.(*db.Postgres).DatabaseComment(req.Name, req.Comment) } return nil } + +func (r *databaseRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) { + switch server.Type { + case biz.DatabaseTypeMysql: + mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) + if err != nil { + return nil, err + } + return mysql, nil + case biz.DatabaseTypePostgresql: + postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) + if err != nil { + return nil, err + } + return postgres, nil + default: + return nil, fmt.Errorf("unsupported database type: %s", server.Type) + } +} diff --git a/internal/data/database_server.go b/internal/data/database_server.go index 91694716..b0ec0464 100644 --- a/internal/data/database_server.go +++ b/internal/data/database_server.go @@ -28,7 +28,7 @@ func NewDatabaseServerRepo(t *gotext.Locale, db *gorm.DB, log *slog.Logger) biz. } } -func (r databaseServerRepo) Count() (int64, error) { +func (r *databaseServerRepo) Count() (int64, error) { var count int64 if err := r.db.Model(&biz.DatabaseServer{}).Count(&count).Error; err != nil { return 0, err @@ -37,7 +37,7 @@ func (r databaseServerRepo) Count() (int64, error) { return count, nil } -func (r databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64, error) { +func (r *databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64, error) { databaseServer := make([]*biz.DatabaseServer, 0) var total int64 err := r.db.Model(&biz.DatabaseServer{}).Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&databaseServer).Error @@ -49,7 +49,7 @@ func (r databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64 return databaseServer, total, err } -func (r databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) { +func (r *databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) { databaseServer := new(biz.DatabaseServer) if err := r.db.Where("id = ?", id).First(databaseServer).Error; err != nil { return nil, err @@ -60,7 +60,7 @@ func (r databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) { return databaseServer, nil } -func (r databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) { +func (r *databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) { databaseServer := new(biz.DatabaseServer) if err := r.db.Where("name = ?", name).First(databaseServer).Error; err != nil { return nil, err @@ -71,7 +71,7 @@ func (r databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) return databaseServer, nil } -func (r databaseServerRepo) Create(req *request.DatabaseServerCreate) error { +func (r *databaseServerRepo) Create(req *request.DatabaseServerCreate) error { databaseServer := &biz.DatabaseServer{ Name: req.Name, Type: biz.DatabaseType(req.Type), @@ -89,7 +89,7 @@ func (r databaseServerRepo) Create(req *request.DatabaseServerCreate) error { return r.db.Create(databaseServer).Error } -func (r databaseServerRepo) Update(req *request.DatabaseServerUpdate) error { +func (r *databaseServerRepo) Update(req *request.DatabaseServerUpdate) error { server, err := r.Get(req.ID) if err != nil { return err @@ -109,11 +109,11 @@ func (r databaseServerRepo) Update(req *request.DatabaseServerUpdate) error { return r.db.Save(server).Error } -func (r databaseServerRepo) UpdateRemark(req *request.DatabaseServerUpdateRemark) error { +func (r *databaseServerRepo) UpdateRemark(req *request.DatabaseServerUpdateRemark) error { return r.db.Model(&biz.DatabaseServer{}).Where("id = ?", req.ID).Update("remark", req.Remark).Error } -func (r databaseServerRepo) Delete(id uint) error { +func (r *databaseServerRepo) Delete(id uint) error { if err := r.ClearUsers(id); err != nil { return err } @@ -122,11 +122,11 @@ func (r databaseServerRepo) Delete(id uint) error { } // ClearUsers 删除指定服务器的所有用户,只是删除面板记录,不会实际删除 -func (r databaseServerRepo) ClearUsers(serverID uint) error { +func (r *databaseServerRepo) ClearUsers(serverID uint) error { return r.db.Where("server_id = ?", serverID).Delete(&biz.DatabaseUser{}).Error } -func (r databaseServerRepo) Sync(id uint) error { +func (r *databaseServerRepo) Sync(id uint) error { server, err := r.Get(id) if err != nil { return err @@ -137,16 +137,15 @@ func (r databaseServerRepo) Sync(id uint) error { return err } + operator, err := r.getOperator(server) + if err != nil { + return err + } + defer operator.Close() + switch server.Type { case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { - return err - } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) - allUsers, err := mysql.Users() + allUsers, err := operator.Users() if err != nil { return err } @@ -166,24 +165,17 @@ func (r databaseServerRepo) Sync(id uint) error { } } case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - allUsers, err := postgres.Users() + allUsers, err := operator.Users() if err != nil { return err } for user := range slices.Values(allUsers) { if !slices.ContainsFunc(users, func(a *biz.DatabaseUser) bool { - return a.Username == user.Role - }) && !slices.Contains([]string{"postgres"}, user.Role) { + return a.Username == user.User + }) && !slices.Contains([]string{"postgres"}, user.User) { newUser := &biz.DatabaseUser{ ServerID: id, - Username: user.Role, + Username: user.User, Remark: r.t.Get("sync from server %s", server.Name), } r.db.Create(newUser) @@ -195,26 +187,19 @@ func (r databaseServerRepo) Sync(id uint) error { } // checkServer 检查服务器连接 -func (r databaseServerRepo) checkServer(server *biz.DatabaseServer) bool { +func (r *databaseServerRepo) checkServer(server *biz.DatabaseServer) bool { switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) + case biz.DatabaseTypeMysql, biz.DatabaseTypePostgresql: + operator, err := r.getOperator(server) if err == nil { - _ = mysql.Close() - server.Status = biz.DatabaseServerStatusValid - return true - } - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err == nil { - _ = postgres.Close() + operator.Close() server.Status = biz.DatabaseServerStatusValid return true } case biz.DatabaseTypeRedis: redis, err := db.NewRedis(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) if err == nil { - _ = redis.Close() + redis.Close() server.Status = biz.DatabaseServerStatusValid return true } @@ -223,3 +208,22 @@ func (r databaseServerRepo) checkServer(server *biz.DatabaseServer) bool { server.Status = biz.DatabaseServerStatusInvalid return false } + +func (r *databaseServerRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) { + switch server.Type { + case biz.DatabaseTypeMysql: + mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) + if err != nil { + return nil, err + } + return mysql, nil + case biz.DatabaseTypePostgresql: + postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) + if err != nil { + return nil, err + } + return postgres, nil + default: + return nil, fmt.Errorf("unsupported database type: %s", server.Type) + } +} diff --git a/internal/data/database_user.go b/internal/data/database_user.go index 269c9cec..94e78349 100644 --- a/internal/data/database_user.go +++ b/internal/data/database_user.go @@ -26,7 +26,7 @@ func NewDatabaseUserRepo(t *gotext.Locale, db *gorm.DB, server biz.DatabaseServe } } -func (r databaseUserRepo) Count() (int64, error) { +func (r *databaseUserRepo) Count() (int64, error) { var count int64 if err := r.db.Model(&biz.DatabaseUser{}).Count(&count).Error; err != nil { return 0, err @@ -35,7 +35,7 @@ func (r databaseUserRepo) Count() (int64, error) { return count, nil } -func (r databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, error) { +func (r *databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, error) { user := make([]*biz.DatabaseUser, 0) var total int64 err := r.db.Model(&biz.DatabaseUser{}).Preload("Server").Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&user).Error @@ -47,7 +47,7 @@ func (r databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, er return user, total, err } -func (r databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) { +func (r *databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) { user := new(biz.DatabaseUser) if err := r.db.Preload("Server").Where("id = ?", id).First(user).Error; err != nil { return nil, err @@ -58,74 +58,49 @@ func (r databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) { return user, nil } -func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error { +func (r *databaseUserRepo) Create(req *request.DatabaseUserCreate) error { server, err := r.server.Get(req.ServerID) if err != nil { return err } - user := new(biz.DatabaseUser) - switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { + operator, err := r.getOperator(server) + if err != nil { + return err + } + defer operator.Close() + + // 创建用户 + if err = operator.UserCreate(req.Username, req.Password); err != nil { + return err + } + + // 创建数据库并授权 + for name := range slices.Values(req.Privileges) { + if err = operator.DatabaseCreate(name); err != nil { return err } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) - if err = mysql.UserCreate(req.Username, req.Password, req.Host); err != nil { + if err = operator.PrivilegesGrant(req.Username, name); err != nil { return err } - for name := range slices.Values(req.Privileges) { - if err = mysql.DatabaseCreate(name); err != nil { - return err - } - if err = mysql.PrivilegesGrant(req.Username, name, req.Host); err != nil { - return err - } - } - user = &biz.DatabaseUser{ - ServerID: req.ServerID, - Username: req.Username, - Host: req.Host, - } - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - if err = postgres.UserCreate(req.Username, req.Password); err != nil { - return err - } - for name := range slices.Values(req.Privileges) { - if err = postgres.DatabaseCreate(name); err != nil { - return err - } - if err = postgres.PrivilegesGrant(req.Username, name); err != nil { - return err - } - } - user = &biz.DatabaseUser{ - ServerID: req.ServerID, - Username: req.Username, - } + } + + user := &biz.DatabaseUser{ + ServerID: req.ServerID, + Username: req.Username, + Host: req.Host, + Password: req.Password, + Remark: req.Remark, } if err = r.db.FirstOrInit(user, user).Error; err != nil { return err } - user.Password = req.Password - user.Remark = req.Remark - return r.db.Save(user).Error } -func (r databaseUserRepo) Update(req *request.DatabaseUserUpdate) error { +func (r *databaseUserRepo) Update(req *request.DatabaseUserUpdate) error { user, err := r.Get(req.ID) if err != nil { return err @@ -136,58 +111,36 @@ func (r databaseUserRepo) Update(req *request.DatabaseUserUpdate) error { return err } - switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { + operator, err := r.getOperator(server) + if err != nil { + return err + } + defer operator.Close() + + // 更新密码 + if req.Password != "" { + if err = operator.UserPassword(user.Username, req.Password); err != nil { return err } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) - if req.Password != "" { - if err = mysql.UserPassword(user.Username, req.Password, user.Host); err != nil { - return err - } - } - for name := range slices.Values(req.Privileges) { - if err = mysql.DatabaseCreate(name); err != nil { - return err - } - if err = mysql.PrivilegesGrant(user.Username, name, user.Host); err != nil { - return err - } - } - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { + user.Password = req.Password + } + + // 创建数据库并授权 + for name := range slices.Values(req.Privileges) { + if err = operator.DatabaseCreate(name); err != nil { return err } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - if req.Password != "" { - if err = postgres.UserPassword(user.Username, req.Password); err != nil { - return err - } - } - for name := range slices.Values(req.Privileges) { - if err = postgres.DatabaseCreate(name); err != nil { - return err - } - if err = postgres.PrivilegesGrant(user.Username, name); err != nil { - return err - } + if err = operator.PrivilegesGrant(user.Username, name); err != nil { + return err } } - user.Password = req.Password user.Remark = req.Remark return r.db.Save(user).Error } -func (r databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) error { +func (r *databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) error { user, err := r.Get(req.ID) if err != nil { return err @@ -198,7 +151,7 @@ func (r databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) er return r.db.Save(user).Error } -func (r databaseUserRepo) Delete(id uint) error { +func (r *databaseUserRepo) Delete(id uint) error { user, err := r.Get(id) if err != nil { return err @@ -209,45 +162,31 @@ func (r databaseUserRepo) Delete(id uint) error { return err } - switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { - return err - } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) - _ = mysql.UserDrop(user.Username, user.Host) - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - _ = postgres.UserDrop(user.Username) + operator, err := r.getOperator(server) + if err != nil { + return err } + defer operator.Close() + + _ = operator.UserDrop(user.Username) return r.db.Where("id = ?", id).Delete(&biz.DatabaseUser{}).Error } -func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error { +func (r *databaseUserRepo) DeleteByNames(serverID uint, names []string) error { server, err := r.server.Get(serverID) if err != nil { return err } + operator, err := r.getOperator(server) + if err != nil { + return err + } + defer operator.Close() + switch server.Type { case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err != nil { - return err - } - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) users := make([]*biz.DatabaseUser, 0) if err = r.db.Where("server_id = ? AND username IN ?", serverID, names).Find(&users).Error; err != nil { return err @@ -260,57 +199,42 @@ func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error { break } } - _ = mysql.UserDrop(name, host) + _ = operator.UserDrop(name, host) } case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err != nil { - return err - } - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) for name := range slices.Values(names) { - _ = postgres.UserDrop(name) + _ = operator.UserDrop(name) } } return r.db.Where("server_id = ? AND username IN ?", serverID, names).Delete(&biz.DatabaseUser{}).Error } -func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) { +func (r *databaseUserRepo) fillUser(user *biz.DatabaseUser) { server, err := r.server.Get(user.ServerID) if err == nil { - switch server.Type { - case biz.DatabaseTypeMysql: - mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) - if err == nil { - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) - privileges, _ := mysql.UserPrivileges(user.Username, user.Host) + operator, err := r.getOperator(server) + if err == nil { + defer operator.Close() + switch server.Type { + case biz.DatabaseTypeMysql: + privileges, _ := operator.UserPrivileges(user.Username, user.Host) user.Privileges = privileges - } - if mysql2, err := db.NewMySQL(user.Username, user.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)); err == nil { - _ = mysql2.Close() - user.Status = biz.DatabaseUserStatusValid - } else { - user.Status = biz.DatabaseUserStatusInvalid - } - case biz.DatabaseTypePostgresql: - postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) - if err == nil { - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) - privileges, _ := postgres.UserPrivileges(user.Username) + if mysql2, err := db.NewMySQL(user.Username, user.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)); err == nil { + mysql2.Close() + user.Status = biz.DatabaseUserStatusValid + } else { + user.Status = biz.DatabaseUserStatusInvalid + } + case biz.DatabaseTypePostgresql: + privileges, _ := operator.UserPrivileges(user.Username) user.Privileges = privileges - } - if postgres2, err := db.NewPostgres(user.Username, user.Password, server.Host, server.Port); err == nil { - _ = postgres2.Close() - user.Status = biz.DatabaseUserStatusValid - } else { - user.Status = biz.DatabaseUserStatusInvalid + if postgres2, err := db.NewPostgres(user.Username, user.Password, server.Host, server.Port); err == nil { + postgres2.Close() + user.Status = biz.DatabaseUserStatusValid + } else { + user.Status = biz.DatabaseUserStatusInvalid + } } } } @@ -319,3 +243,22 @@ func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) { user.Privileges = make([]string, 0) } } + +func (r *databaseUserRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) { + switch server.Type { + case biz.DatabaseTypeMysql: + mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)) + if err != nil { + return nil, err + } + return mysql, nil + case biz.DatabaseTypePostgresql: + postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port) + if err != nil { + return nil, err + } + return postgres, nil + default: + return nil, fmt.Errorf("unsupported database type: %s", server.Type) + } +} diff --git a/internal/service/home.go b/internal/service/home.go index ba9d445c..23d053a7 100644 --- a/internal/service/home.go +++ b/internal/service/home.go @@ -145,9 +145,7 @@ func (s *HomeService) CountInfo(w http.ResponseWriter, r *http.Request) { rootPassword, _ := s.settingRepo.Get(biz.SettingKeyMySQLRootPassword) mysql, err := db.NewMySQL("root", rootPassword, "/tmp/mysql.sock", "unix") if err == nil { - defer func(mysql *db.MySQL) { - _ = mysql.Close() - }(mysql) + defer mysql.Close() databases, err := mysql.Databases() if err == nil { databaseCount += len(databases) @@ -157,9 +155,7 @@ func (s *HomeService) CountInfo(w http.ResponseWriter, r *http.Request) { if postgresqlInstalled { postgres, err := db.NewPostgres("postgres", "", "127.0.0.1", 5432) if err == nil { - defer func(postgres *db.Postgres) { - _ = postgres.Close() - }(postgres) + defer postgres.Close() databases, err := postgres.Databases() if err == nil { databaseCount += len(databases) diff --git a/pkg/db/db.go b/pkg/db/db.go new file mode 100644 index 00000000..4bd41d85 --- /dev/null +++ b/pkg/db/db.go @@ -0,0 +1,29 @@ +package db + +import "database/sql" + +type Operator interface { + Close() + Ping() error + + Query(query string, args ...any) (*sql.Rows, error) + QueryRow(query string, args ...any) *sql.Row + Exec(query string, args ...any) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + + DatabaseCreate(name string) error + DatabaseDrop(name string) error + DatabaseExists(name string) (bool, error) + DatabaseSize(name string) (int64, error) + + UserCreate(user, password string, host ...string) error + UserDrop(user string, host ...string) error + UserPassword(user, password string, host ...string) error + UserPrivileges(user string, host ...string) ([]string, error) + + PrivilegesGrant(user, database string, host ...string) error + PrivilegesRevoke(user, database string, host ...string) error + + Users() ([]User, error) + Databases() ([]Database, error) +} diff --git a/pkg/db/mysql.go b/pkg/db/mysql.go index 3ec73f73..277b8d22 100644 --- a/pkg/db/mysql.go +++ b/pkg/db/mysql.go @@ -7,8 +7,6 @@ import ( "slices" _ "github.com/go-sql-driver/mysql" - - "github.com/acepanel/panel/pkg/types" ) type MySQL struct { @@ -18,7 +16,7 @@ type MySQL struct { address string } -func NewMySQL(username, password, address string, typ ...string) (*MySQL, error) { +func NewMySQL(username, password, address string, typ ...string) (Operator, error) { dsn := fmt.Sprintf("%s:%s@tcp(%s)/", username, password, address) if len(typ) > 0 && typ[0] == "unix" { dsn = fmt.Sprintf("%s:%s@unix(%s)/", username, password, address) @@ -38,8 +36,8 @@ func NewMySQL(username, password, address string, typ ...string) (*MySQL, error) }, nil } -func (r *MySQL) Close() error { - return r.db.Close() +func (r *MySQL) Close() { + _ = r.db.Close() } func (r *MySQL) Ping() error { @@ -101,44 +99,43 @@ func (r *MySQL) DatabaseSize(name string) (int64, error) { return size, err } -func (r *MySQL) UserCreate(user, password, host string) error { - _, err := r.Exec(fmt.Sprintf("CREATE USER IF NOT EXISTS '%s'@'%s' IDENTIFIED BY '%s'", user, host, password)) +func (r *MySQL) UserCreate(user, password string, host ...string) error { + if len(host) == 0 { + host = append(host, "%") + } + _, err := r.Exec(fmt.Sprintf("CREATE USER IF NOT EXISTS '%s'@'%s' IDENTIFIED BY '%s'", user, host[0], password)) r.flushPrivileges() return err } -func (r *MySQL) UserDrop(user, host string) error { - _, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%s'", user, host)) +func (r *MySQL) UserDrop(user string, host ...string) error { + if len(host) == 0 { + host = append(host, "%") + } + _, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%s'", user, host[0])) r.flushPrivileges() return err } -func (r *MySQL) UserPassword(user, password, host string) error { - _, err := r.Exec(fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", user, host, password)) +func (r *MySQL) UserPassword(user, password string, host ...string) error { + if len(host) == 0 { + host = append(host, "%") + } + _, err := r.Exec(fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", user, host[0], password)) r.flushPrivileges() return err } -func (r *MySQL) PrivilegesGrant(user, database, host string) error { - _, err := r.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s'", database, user, host)) - r.flushPrivileges() - return err -} +func (r *MySQL) UserPrivileges(user string, host ...string) ([]string, error) { + if len(host) == 0 { + host = append(host, "%") + } -func (r *MySQL) PrivilegesRevoke(user, database, host string) error { - _, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host)) - r.flushPrivileges() - return err -} - -func (r *MySQL) UserPrivileges(user, host string) ([]string, error) { - rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host)) + rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host[0])) if err != nil { return nil, err } - defer func(rows *sql.Rows) { - _ = rows.Close() - }(rows) + defer func(rows *sql.Rows) { _ = rows.Close() }(rows) re := regexp.MustCompile(`GRANT\s+ALL PRIVILEGES\s+ON\s+[\x60'"]?([^\s\x60'"]+)[\x60'"]?\.\*\s+TO\s+`) var databases []string @@ -166,16 +163,32 @@ func (r *MySQL) UserPrivileges(user, host string) ([]string, error) { return slices.Compact(databases), nil } -func (r *MySQL) Users() ([]types.MySQLUser, error) { +func (r *MySQL) PrivilegesGrant(user, database string, host ...string) error { + if len(host) == 0 { + host = append(host, "%") + } + _, err := r.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s'", database, user, host[0])) + r.flushPrivileges() + return err +} + +func (r *MySQL) PrivilegesRevoke(user, database string, host ...string) error { + if len(host) == 0 { + host = append(host, "%") + } + _, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host[0])) + r.flushPrivileges() + return err +} + +func (r *MySQL) Users() ([]User, error) { rows, err := r.Query("SELECT user, host FROM mysql.user") if err != nil { return nil, err } - defer func(rows *sql.Rows) { - _ = rows.Close() - }(rows) + defer func(rows *sql.Rows) { _ = rows.Close() }(rows) - var users []types.MySQLUser + var users []User for rows.Next() { var user, host string if err := rows.Scan(&user, &host); err != nil { @@ -186,7 +199,7 @@ func (r *MySQL) Users() ([]types.MySQLUser, error) { continue } - users = append(users, types.MySQLUser{ + users = append(users, User{ User: user, Host: host, Grants: grants, @@ -196,7 +209,7 @@ func (r *MySQL) Users() ([]types.MySQLUser, error) { return users, nil } -func (r *MySQL) Databases() ([]types.MySQLDatabase, error) { +func (r *MySQL) Databases() ([]Database, error) { query := ` SELECT SCHEMA_NAME, @@ -214,9 +227,9 @@ func (r *MySQL) Databases() ([]types.MySQLDatabase, error) { _ = rows.Close() }(rows) - var databases []types.MySQLDatabase + var databases []Database for rows.Next() { - var db types.MySQLDatabase + var db Database if err = rows.Scan(&db.Name, &db.CharSet, &db.Collation); err != nil { return nil, err } diff --git a/pkg/db/postgres.go b/pkg/db/postgres.go index 49a2fe28..9a781647 100644 --- a/pkg/db/postgres.go +++ b/pkg/db/postgres.go @@ -9,7 +9,6 @@ import ( _ "github.com/lib/pq" "github.com/acepanel/panel/pkg/systemctl" - "github.com/acepanel/panel/pkg/types" ) type Postgres struct { @@ -20,7 +19,7 @@ type Postgres struct { port uint } -func NewPostgres(username, password, address string, port uint) (*Postgres, error) { +func NewPostgres(username, password, address string, port uint) (Operator, error) { username = strings.ReplaceAll(username, `'`, `\'`) password = strings.ReplaceAll(password, `'`, `\'`) dsn := fmt.Sprintf(`host=%s port=%d user='%s' password='%s' dbname=postgres sslmode=disable`, address, port, username, password) @@ -46,8 +45,8 @@ func NewPostgres(username, password, address string, port uint) (*Postgres, erro }, nil } -func (r *Postgres) Close() error { - return r.db.Close() +func (r *Postgres) Close() { + _ = r.db.Close() } func (r *Postgres) Ping() error { @@ -72,7 +71,7 @@ func (r *Postgres) Prepare(query string) (*sql.Stmt, error) { func (r *Postgres) DatabaseCreate(name string) error { // postgres 不支持 CREATE DATABASE IF NOT EXISTS,但是为了保持与 MySQL 一致,先检查数据库是否存在 - exist, err := r.DatabaseExist(name) + exist, err := r.DatabaseExists(name) if err != nil { return err } @@ -88,7 +87,7 @@ func (r *Postgres) DatabaseDrop(name string) error { return err } -func (r *Postgres) DatabaseExist(name string) (bool, error) { +func (r *Postgres) DatabaseExists(name string) (bool, error) { var count int if err := r.QueryRow("SELECT COUNT(*) FROM pg_database WHERE datname = $1", name).Scan(&count); err != nil { return false, err @@ -110,7 +109,7 @@ func (r *Postgres) DatabaseComment(name, comment string) error { return err } -func (r *Postgres) UserCreate(user, password string) error { +func (r *Postgres) UserCreate(user, password string, host ...string) error { _, err := r.Exec(fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'", user, password)) if err != nil { return err @@ -119,7 +118,7 @@ func (r *Postgres) UserCreate(user, password string) error { return nil } -func (r *Postgres) UserDrop(user string) error { +func (r *Postgres) UserDrop(user string, host ...string) error { _, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS %s", user)) if err != nil { return err @@ -128,12 +127,12 @@ func (r *Postgres) UserDrop(user string) error { return systemctl.Reload("postgresql") } -func (r *Postgres) UserPassword(user, password string) error { +func (r *Postgres) UserPassword(user, password string, host ...string) error { _, err := r.Exec(fmt.Sprintf("ALTER USER %s WITH PASSWORD '%s'", user, password)) return err } -func (r *Postgres) UserPrivileges(user string) ([]string, error) { +func (r *Postgres) UserPrivileges(user string, host ...string) ([]string, error) { query := ` SELECT d.datname FROM pg_catalog.pg_database d @@ -169,7 +168,7 @@ func (r *Postgres) UserPrivileges(user string) ([]string, error) { return databases, nil } -func (r *Postgres) PrivilegesGrant(user, database string) error { +func (r *Postgres) PrivilegesGrant(user, database string, host ...string) error { if _, err := r.Exec(fmt.Sprintf("ALTER DATABASE %s OWNER TO %s", database, user)); err != nil { return err } @@ -180,12 +179,12 @@ func (r *Postgres) PrivilegesGrant(user, database string) error { return nil } -func (r *Postgres) PrivilegesRevoke(user, database string) error { +func (r *Postgres) PrivilegesRevoke(user, database string, host ...string) error { _, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s", database, user)) return err } -func (r *Postgres) Users() ([]types.PostgresUser, error) { +func (r *Postgres) Users() ([]User, error) { query := ` SELECT rolname, rolsuper, @@ -204,11 +203,11 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) { _ = rows.Close() }(rows) - var users []types.PostgresUser + var users []User for rows.Next() { - var user types.PostgresUser + var user User var super, canCreateRole, canCreateDb, replication, bypassRls bool - if err = rows.Scan(&user.Role, &super, &canCreateRole, &canCreateDb, &replication, &bypassRls); err != nil { + if err = rows.Scan(&user.User, &super, &canCreateRole, &canCreateDb, &replication, &bypassRls); err != nil { return nil, err } @@ -221,12 +220,12 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) { } for perm, enabled := range permissions { if enabled { - user.Attributes = append(user.Attributes, perm) + user.Grants = append(user.Grants, perm) } } - if len(user.Attributes) == 0 { - user.Attributes = append(user.Attributes, "无") + if len(user.Grants) == 0 { + user.Grants = append(user.Grants, "None") } users = append(users, user) @@ -239,7 +238,7 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) { return users, nil } -func (r *Postgres) Databases() ([]types.PostgresDatabase, error) { +func (r *Postgres) Databases() ([]Database, error) { query := ` SELECT d.datname, @@ -257,10 +256,10 @@ func (r *Postgres) Databases() ([]types.PostgresDatabase, error) { _ = rows.Close() }(rows) - var databases []types.PostgresDatabase + var databases []Database for rows.Next() { - var db types.PostgresDatabase - if err := rows.Scan(&db.Name, &db.Owner, &db.Encoding, &db.Comment); err != nil { + var db Database + if err := rows.Scan(&db.Name, &db.Owner, &db.CharSet, &db.Comment); err != nil { return nil, err } if slices.Contains([]string{"template0", "template1", "postgres"}, db.Name) { diff --git a/pkg/db/redis.go b/pkg/db/redis.go index 8a33058e..21087641 100644 --- a/pkg/db/redis.go +++ b/pkg/db/redis.go @@ -39,8 +39,8 @@ func NewRedis(username, password, address string) (*Redis, error) { }, nil } -func (r *Redis) Close() error { - return r.conn.Close() +func (r *Redis) Close() { + _ = r.conn.Close() } func (r *Redis) Exec(command string, args ...any) (any, error) { diff --git a/pkg/db/types.go b/pkg/db/types.go new file mode 100644 index 00000000..85404689 --- /dev/null +++ b/pkg/db/types.go @@ -0,0 +1,15 @@ +package db + +type User struct { + User string `json:"user"` // 用户名,PG 里面对应 Role + Host string `json:"host"` // 主机,PG 这个字段为空 + Grants []string `json:"grants"` // 权限列表 +} + +type Database struct { + Name string `json:"name"` // 数据库名 + Owner string `json:"owner"` // 所有者,MySQL 这个字段为空 + CharSet string `json:"char_set"` // 字符集,PG 里面对应 Encoding + Collation string `json:"collation"` // 校对集,PG 这个字段为空 + Comment string `json:"comment"` +} diff --git a/pkg/types/mysql.go b/pkg/types/mysql.go deleted file mode 100644 index 2e2d7f8b..00000000 --- a/pkg/types/mysql.go +++ /dev/null @@ -1,13 +0,0 @@ -package types - -type MySQLUser struct { - User string `json:"user"` - Host string `json:"host"` - Grants []string `json:"grants"` -} - -type MySQLDatabase struct { - Name string `json:"name"` - CharSet string `json:"char_set"` - Collation string `json:"collation"` -} diff --git a/pkg/types/postgres.go b/pkg/types/postgres.go deleted file mode 100644 index bbd0901f..00000000 --- a/pkg/types/postgres.go +++ /dev/null @@ -1,13 +0,0 @@ -package types - -type PostgresUser struct { - Role string `json:"role"` - Attributes []string `json:"attributes"` -} - -type PostgresDatabase struct { - Name string `json:"name"` - Owner string `json:"owner"` - Encoding string `json:"encoding"` - Comment string `json:"comment"` -}