mirror of
https://github.com/acepanel/panel.git
synced 2026-02-04 07:57:21 +08:00
refactor: 消除 database_user.go 中的重复代码 (#1149)
This commit is contained in:
@@ -212,9 +212,7 @@ func (s *App) SetRootPassword(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
defer mysql.Close()
|
||||
if err = mysql.UserPassword("root", req.Password, "localhost"); err != nil {
|
||||
service.Error(w, http.StatusInternalServerError, "%v", err)
|
||||
return
|
||||
|
||||
@@ -256,9 +256,7 @@ func (r *backupRepo) createMySQL(to string, name string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
defer mysql.Close()
|
||||
if exist, _ := mysql.DatabaseExists(name); !exist {
|
||||
return errors.New(r.t.Get("database does not exist: %s", name))
|
||||
}
|
||||
@@ -302,10 +300,8 @@ func (r *backupRepo) createPostgres(to string, name string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
if exist, _ := postgres.DatabaseExist(name); !exist {
|
||||
defer postgres.Close()
|
||||
if exist, _ := postgres.DatabaseExists(name); !exist {
|
||||
return errors.New(r.t.Get("database does not exist: %s", name))
|
||||
}
|
||||
size, err := postgres.DatabaseSize(name)
|
||||
@@ -418,9 +414,7 @@ func (r *backupRepo) restoreMySQL(backup, target string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
defer mysql.Close()
|
||||
if exist, _ := mysql.DatabaseExists(target); !exist {
|
||||
return errors.New(r.t.Get("database does not exist: %s", target))
|
||||
}
|
||||
@@ -460,10 +454,8 @@ func (r *backupRepo) restorePostgres(backup, target string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
if exist, _ := postgres.DatabaseExist(target); !exist {
|
||||
defer postgres.Close()
|
||||
if exist, _ := postgres.DatabaseExists(target); !exist {
|
||||
return errors.New(r.t.Get("database does not exist: %s", target))
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewDatabaseRepo(t *gotext.Locale, db *gorm.DB, server biz.DatabaseServerRep
|
||||
}
|
||||
}
|
||||
|
||||
func (r databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) {
|
||||
func (r *databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) {
|
||||
var databaseServer []*biz.DatabaseServer
|
||||
if err := r.db.Model(&biz.DatabaseServer{}).Order("id desc").Find(&databaseServer).Error; err != nil {
|
||||
return nil, 0, err
|
||||
@@ -37,61 +37,43 @@ func (r databaseRepo) List(page, limit uint) ([]*biz.Database, int64, error) {
|
||||
|
||||
database := make([]*biz.Database, 0)
|
||||
for _, server := range databaseServer {
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err == nil {
|
||||
if databases, err := mysql.Databases(); err == nil {
|
||||
for item := range slices.Values(databases) {
|
||||
database = append(database, &biz.Database{
|
||||
Type: biz.DatabaseTypeMysql,
|
||||
Name: item.Name,
|
||||
Server: server.Name,
|
||||
ServerID: server.ID,
|
||||
Encoding: item.CharSet,
|
||||
})
|
||||
}
|
||||
}
|
||||
_ = mysql.Close()
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err == nil {
|
||||
if databases, err := postgres.Databases(); err == nil {
|
||||
for item := range slices.Values(databases) {
|
||||
database = append(database, &biz.Database{
|
||||
Type: biz.DatabaseTypePostgresql,
|
||||
Name: item.Name,
|
||||
Server: server.Name,
|
||||
ServerID: server.ID,
|
||||
Encoding: item.Encoding,
|
||||
Comment: item.Comment,
|
||||
})
|
||||
}
|
||||
}
|
||||
_ = postgres.Close()
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if databases, err := operator.Databases(); err == nil {
|
||||
for item := range slices.Values(databases) {
|
||||
database = append(database, &biz.Database{
|
||||
Type: server.Type,
|
||||
Name: item.Name,
|
||||
Server: server.Name,
|
||||
ServerID: server.ID,
|
||||
Encoding: item.CharSet,
|
||||
Comment: item.Comment,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
operator.Close()
|
||||
}
|
||||
|
||||
return database[(page-1)*limit:], int64(len(database)), nil
|
||||
}
|
||||
|
||||
func (r databaseRepo) Create(req *request.DatabaseCreate) error {
|
||||
func (r *databaseRepo) Create(req *request.DatabaseCreate) error {
|
||||
server, err := r.server.Get(req.ServerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
if req.CreateUser {
|
||||
if err = r.user.Create(&request.DatabaseUserCreate{
|
||||
ServerID: req.ServerID,
|
||||
@@ -102,22 +84,15 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = mysql.DatabaseCreate(req.Name); err != nil {
|
||||
if err = operator.DatabaseCreate(req.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.Username != "" {
|
||||
if err = mysql.PrivilegesGrant(req.Username, req.Name, req.Host); err != nil {
|
||||
if err = operator.PrivilegesGrant(req.Username, req.Name, req.Host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
if req.CreateUser {
|
||||
if err = r.user.Create(&request.DatabaseUserCreate{
|
||||
ServerID: req.ServerID,
|
||||
@@ -128,15 +103,15 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = postgres.DatabaseCreate(req.Name); err != nil {
|
||||
if err = operator.DatabaseCreate(req.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.Username != "" {
|
||||
if err = postgres.PrivilegesGrant(req.Username, req.Name); err != nil {
|
||||
if err = operator.PrivilegesGrant(req.Username, req.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = postgres.DatabaseComment(req.Name, req.Comment); err != nil {
|
||||
if err = operator.(*db.Postgres).DatabaseComment(req.Name, req.Comment); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -144,55 +119,56 @@ func (r databaseRepo) Create(req *request.DatabaseCreate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r databaseRepo) Delete(serverID uint, name string) error {
|
||||
func (r *databaseRepo) Delete(serverID uint, name string) error {
|
||||
server, err := r.server.Get(serverID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
return mysql.DatabaseDrop(name)
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
return postgres.DatabaseDrop(name)
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return operator.DatabaseDrop(name)
|
||||
}
|
||||
|
||||
func (r databaseRepo) Comment(req *request.DatabaseComment) error {
|
||||
func (r *databaseRepo) Comment(req *request.DatabaseComment) error {
|
||||
server, err := r.server.Get(req.ServerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
return errors.New(r.t.Get("mysql not support database comment"))
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
return postgres.DatabaseComment(req.Name, req.Comment)
|
||||
return operator.(*db.Postgres).DatabaseComment(req.Name, req.Comment)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *databaseRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) {
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mysql, nil
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", server.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewDatabaseServerRepo(t *gotext.Locale, db *gorm.DB, log *slog.Logger) biz.
|
||||
}
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) Count() (int64, error) {
|
||||
func (r *databaseServerRepo) Count() (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&biz.DatabaseServer{}).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
@@ -37,7 +37,7 @@ func (r databaseServerRepo) Count() (int64, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64, error) {
|
||||
func (r *databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64, error) {
|
||||
databaseServer := make([]*biz.DatabaseServer, 0)
|
||||
var total int64
|
||||
err := r.db.Model(&biz.DatabaseServer{}).Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&databaseServer).Error
|
||||
@@ -49,7 +49,7 @@ func (r databaseServerRepo) List(page, limit uint) ([]*biz.DatabaseServer, int64
|
||||
return databaseServer, total, err
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) {
|
||||
func (r *databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) {
|
||||
databaseServer := new(biz.DatabaseServer)
|
||||
if err := r.db.Where("id = ?", id).First(databaseServer).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -60,7 +60,7 @@ func (r databaseServerRepo) Get(id uint) (*biz.DatabaseServer, error) {
|
||||
return databaseServer, nil
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) {
|
||||
func (r *databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error) {
|
||||
databaseServer := new(biz.DatabaseServer)
|
||||
if err := r.db.Where("name = ?", name).First(databaseServer).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -71,7 +71,7 @@ func (r databaseServerRepo) GetByName(name string) (*biz.DatabaseServer, error)
|
||||
return databaseServer, nil
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) Create(req *request.DatabaseServerCreate) error {
|
||||
func (r *databaseServerRepo) Create(req *request.DatabaseServerCreate) error {
|
||||
databaseServer := &biz.DatabaseServer{
|
||||
Name: req.Name,
|
||||
Type: biz.DatabaseType(req.Type),
|
||||
@@ -89,7 +89,7 @@ func (r databaseServerRepo) Create(req *request.DatabaseServerCreate) error {
|
||||
return r.db.Create(databaseServer).Error
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) Update(req *request.DatabaseServerUpdate) error {
|
||||
func (r *databaseServerRepo) Update(req *request.DatabaseServerUpdate) error {
|
||||
server, err := r.Get(req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -109,11 +109,11 @@ func (r databaseServerRepo) Update(req *request.DatabaseServerUpdate) error {
|
||||
return r.db.Save(server).Error
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) UpdateRemark(req *request.DatabaseServerUpdateRemark) error {
|
||||
func (r *databaseServerRepo) UpdateRemark(req *request.DatabaseServerUpdateRemark) error {
|
||||
return r.db.Model(&biz.DatabaseServer{}).Where("id = ?", req.ID).Update("remark", req.Remark).Error
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) Delete(id uint) error {
|
||||
func (r *databaseServerRepo) Delete(id uint) error {
|
||||
if err := r.ClearUsers(id); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,11 +122,11 @@ func (r databaseServerRepo) Delete(id uint) error {
|
||||
}
|
||||
|
||||
// ClearUsers 删除指定服务器的所有用户,只是删除面板记录,不会实际删除
|
||||
func (r databaseServerRepo) ClearUsers(serverID uint) error {
|
||||
func (r *databaseServerRepo) ClearUsers(serverID uint) error {
|
||||
return r.db.Where("server_id = ?", serverID).Delete(&biz.DatabaseUser{}).Error
|
||||
}
|
||||
|
||||
func (r databaseServerRepo) Sync(id uint) error {
|
||||
func (r *databaseServerRepo) Sync(id uint) error {
|
||||
server, err := r.Get(id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -137,16 +137,15 @@ func (r databaseServerRepo) Sync(id uint) error {
|
||||
return err
|
||||
}
|
||||
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer operator.Close()
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
allUsers, err := mysql.Users()
|
||||
allUsers, err := operator.Users()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -166,24 +165,17 @@ func (r databaseServerRepo) Sync(id uint) error {
|
||||
}
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
allUsers, err := postgres.Users()
|
||||
allUsers, err := operator.Users()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for user := range slices.Values(allUsers) {
|
||||
if !slices.ContainsFunc(users, func(a *biz.DatabaseUser) bool {
|
||||
return a.Username == user.Role
|
||||
}) && !slices.Contains([]string{"postgres"}, user.Role) {
|
||||
return a.Username == user.User
|
||||
}) && !slices.Contains([]string{"postgres"}, user.User) {
|
||||
newUser := &biz.DatabaseUser{
|
||||
ServerID: id,
|
||||
Username: user.Role,
|
||||
Username: user.User,
|
||||
Remark: r.t.Get("sync from server %s", server.Name),
|
||||
}
|
||||
r.db.Create(newUser)
|
||||
@@ -195,26 +187,19 @@ func (r databaseServerRepo) Sync(id uint) error {
|
||||
}
|
||||
|
||||
// checkServer 检查服务器连接
|
||||
func (r databaseServerRepo) checkServer(server *biz.DatabaseServer) bool {
|
||||
func (r *databaseServerRepo) checkServer(server *biz.DatabaseServer) bool {
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
case biz.DatabaseTypeMysql, biz.DatabaseTypePostgresql:
|
||||
operator, err := r.getOperator(server)
|
||||
if err == nil {
|
||||
_ = mysql.Close()
|
||||
server.Status = biz.DatabaseServerStatusValid
|
||||
return true
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err == nil {
|
||||
_ = postgres.Close()
|
||||
operator.Close()
|
||||
server.Status = biz.DatabaseServerStatusValid
|
||||
return true
|
||||
}
|
||||
case biz.DatabaseTypeRedis:
|
||||
redis, err := db.NewRedis(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err == nil {
|
||||
_ = redis.Close()
|
||||
redis.Close()
|
||||
server.Status = biz.DatabaseServerStatusValid
|
||||
return true
|
||||
}
|
||||
@@ -223,3 +208,22 @@ func (r databaseServerRepo) checkServer(server *biz.DatabaseServer) bool {
|
||||
server.Status = biz.DatabaseServerStatusInvalid
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *databaseServerRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) {
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mysql, nil
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", server.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewDatabaseUserRepo(t *gotext.Locale, db *gorm.DB, server biz.DatabaseServe
|
||||
}
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) Count() (int64, error) {
|
||||
func (r *databaseUserRepo) Count() (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&biz.DatabaseUser{}).Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
@@ -35,7 +35,7 @@ func (r databaseUserRepo) Count() (int64, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, error) {
|
||||
func (r *databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, error) {
|
||||
user := make([]*biz.DatabaseUser, 0)
|
||||
var total int64
|
||||
err := r.db.Model(&biz.DatabaseUser{}).Preload("Server").Order("id desc").Count(&total).Offset(int((page - 1) * limit)).Limit(int(limit)).Find(&user).Error
|
||||
@@ -47,7 +47,7 @@ func (r databaseUserRepo) List(page, limit uint) ([]*biz.DatabaseUser, int64, er
|
||||
return user, total, err
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) {
|
||||
func (r *databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) {
|
||||
user := new(biz.DatabaseUser)
|
||||
if err := r.db.Preload("Server").Where("id = ?", id).First(user).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -58,74 +58,49 @@ func (r databaseUserRepo) Get(id uint) (*biz.DatabaseUser, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
|
||||
func (r *databaseUserRepo) Create(req *request.DatabaseUserCreate) error {
|
||||
server, err := r.server.Get(req.ServerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user := new(biz.DatabaseUser)
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer operator.Close()
|
||||
|
||||
// 创建用户
|
||||
if err = operator.UserCreate(req.Username, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建数据库并授权
|
||||
for name := range slices.Values(req.Privileges) {
|
||||
if err = operator.DatabaseCreate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
if err = mysql.UserCreate(req.Username, req.Password, req.Host); err != nil {
|
||||
if err = operator.PrivilegesGrant(req.Username, name); err != nil {
|
||||
return err
|
||||
}
|
||||
for name := range slices.Values(req.Privileges) {
|
||||
if err = mysql.DatabaseCreate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = mysql.PrivilegesGrant(req.Username, name, req.Host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
user = &biz.DatabaseUser{
|
||||
ServerID: req.ServerID,
|
||||
Username: req.Username,
|
||||
Host: req.Host,
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
if err = postgres.UserCreate(req.Username, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
for name := range slices.Values(req.Privileges) {
|
||||
if err = postgres.DatabaseCreate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = postgres.PrivilegesGrant(req.Username, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
user = &biz.DatabaseUser{
|
||||
ServerID: req.ServerID,
|
||||
Username: req.Username,
|
||||
}
|
||||
}
|
||||
|
||||
user := &biz.DatabaseUser{
|
||||
ServerID: req.ServerID,
|
||||
Username: req.Username,
|
||||
Host: req.Host,
|
||||
Password: req.Password,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
|
||||
if err = r.db.FirstOrInit(user, user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = req.Password
|
||||
user.Remark = req.Remark
|
||||
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) Update(req *request.DatabaseUserUpdate) error {
|
||||
func (r *databaseUserRepo) Update(req *request.DatabaseUserUpdate) error {
|
||||
user, err := r.Get(req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -136,58 +111,36 @@ func (r databaseUserRepo) Update(req *request.DatabaseUserUpdate) error {
|
||||
return err
|
||||
}
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer operator.Close()
|
||||
|
||||
// 更新密码
|
||||
if req.Password != "" {
|
||||
if err = operator.UserPassword(user.Username, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
if req.Password != "" {
|
||||
if err = mysql.UserPassword(user.Username, req.Password, user.Host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for name := range slices.Values(req.Privileges) {
|
||||
if err = mysql.DatabaseCreate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = mysql.PrivilegesGrant(user.Username, name, user.Host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
user.Password = req.Password
|
||||
}
|
||||
|
||||
// 创建数据库并授权
|
||||
for name := range slices.Values(req.Privileges) {
|
||||
if err = operator.DatabaseCreate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
if req.Password != "" {
|
||||
if err = postgres.UserPassword(user.Username, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for name := range slices.Values(req.Privileges) {
|
||||
if err = postgres.DatabaseCreate(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = postgres.PrivilegesGrant(user.Username, name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = operator.PrivilegesGrant(user.Username, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
user.Password = req.Password
|
||||
user.Remark = req.Remark
|
||||
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) error {
|
||||
func (r *databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) error {
|
||||
user, err := r.Get(req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -198,7 +151,7 @@ func (r databaseUserRepo) UpdateRemark(req *request.DatabaseUserUpdateRemark) er
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) Delete(id uint) error {
|
||||
func (r *databaseUserRepo) Delete(id uint) error {
|
||||
user, err := r.Get(id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -209,45 +162,31 @@ func (r databaseUserRepo) Delete(id uint) error {
|
||||
return err
|
||||
}
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
_ = mysql.UserDrop(user.Username, user.Host)
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
_ = postgres.UserDrop(user.Username)
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer operator.Close()
|
||||
|
||||
_ = operator.UserDrop(user.Username)
|
||||
|
||||
return r.db.Where("id = ?", id).Delete(&biz.DatabaseUser{}).Error
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
|
||||
func (r *databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
|
||||
server, err := r.server.Get(serverID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator, err := r.getOperator(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer operator.Close()
|
||||
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
users := make([]*biz.DatabaseUser, 0)
|
||||
if err = r.db.Where("server_id = ? AND username IN ?", serverID, names).Find(&users).Error; err != nil {
|
||||
return err
|
||||
@@ -260,57 +199,42 @@ func (r databaseUserRepo) DeleteByNames(serverID uint, names []string) error {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = mysql.UserDrop(name, host)
|
||||
_ = operator.UserDrop(name, host)
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
for name := range slices.Values(names) {
|
||||
_ = postgres.UserDrop(name)
|
||||
_ = operator.UserDrop(name)
|
||||
}
|
||||
}
|
||||
|
||||
return r.db.Where("server_id = ? AND username IN ?", serverID, names).Delete(&biz.DatabaseUser{}).Error
|
||||
}
|
||||
|
||||
func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) {
|
||||
func (r *databaseUserRepo) fillUser(user *biz.DatabaseUser) {
|
||||
server, err := r.server.Get(user.ServerID)
|
||||
if err == nil {
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err == nil {
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
privileges, _ := mysql.UserPrivileges(user.Username, user.Host)
|
||||
operator, err := r.getOperator(server)
|
||||
if err == nil {
|
||||
defer operator.Close()
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
privileges, _ := operator.UserPrivileges(user.Username, user.Host)
|
||||
user.Privileges = privileges
|
||||
}
|
||||
if mysql2, err := db.NewMySQL(user.Username, user.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)); err == nil {
|
||||
_ = mysql2.Close()
|
||||
user.Status = biz.DatabaseUserStatusValid
|
||||
} else {
|
||||
user.Status = biz.DatabaseUserStatusInvalid
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err == nil {
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
privileges, _ := postgres.UserPrivileges(user.Username)
|
||||
if mysql2, err := db.NewMySQL(user.Username, user.Password, fmt.Sprintf("%s:%d", server.Host, server.Port)); err == nil {
|
||||
mysql2.Close()
|
||||
user.Status = biz.DatabaseUserStatusValid
|
||||
} else {
|
||||
user.Status = biz.DatabaseUserStatusInvalid
|
||||
}
|
||||
case biz.DatabaseTypePostgresql:
|
||||
privileges, _ := operator.UserPrivileges(user.Username)
|
||||
user.Privileges = privileges
|
||||
}
|
||||
if postgres2, err := db.NewPostgres(user.Username, user.Password, server.Host, server.Port); err == nil {
|
||||
_ = postgres2.Close()
|
||||
user.Status = biz.DatabaseUserStatusValid
|
||||
} else {
|
||||
user.Status = biz.DatabaseUserStatusInvalid
|
||||
if postgres2, err := db.NewPostgres(user.Username, user.Password, server.Host, server.Port); err == nil {
|
||||
postgres2.Close()
|
||||
user.Status = biz.DatabaseUserStatusValid
|
||||
} else {
|
||||
user.Status = biz.DatabaseUserStatusInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -319,3 +243,22 @@ func (r databaseUserRepo) fillUser(user *biz.DatabaseUser) {
|
||||
user.Privileges = make([]string, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *databaseUserRepo) getOperator(server *biz.DatabaseServer) (db.Operator, error) {
|
||||
switch server.Type {
|
||||
case biz.DatabaseTypeMysql:
|
||||
mysql, err := db.NewMySQL(server.Username, server.Password, fmt.Sprintf("%s:%d", server.Host, server.Port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mysql, nil
|
||||
case biz.DatabaseTypePostgresql:
|
||||
postgres, err := db.NewPostgres(server.Username, server.Password, server.Host, server.Port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", server.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,9 +145,7 @@ func (s *HomeService) CountInfo(w http.ResponseWriter, r *http.Request) {
|
||||
rootPassword, _ := s.settingRepo.Get(biz.SettingKeyMySQLRootPassword)
|
||||
mysql, err := db.NewMySQL("root", rootPassword, "/tmp/mysql.sock", "unix")
|
||||
if err == nil {
|
||||
defer func(mysql *db.MySQL) {
|
||||
_ = mysql.Close()
|
||||
}(mysql)
|
||||
defer mysql.Close()
|
||||
databases, err := mysql.Databases()
|
||||
if err == nil {
|
||||
databaseCount += len(databases)
|
||||
@@ -157,9 +155,7 @@ func (s *HomeService) CountInfo(w http.ResponseWriter, r *http.Request) {
|
||||
if postgresqlInstalled {
|
||||
postgres, err := db.NewPostgres("postgres", "", "127.0.0.1", 5432)
|
||||
if err == nil {
|
||||
defer func(postgres *db.Postgres) {
|
||||
_ = postgres.Close()
|
||||
}(postgres)
|
||||
defer postgres.Close()
|
||||
databases, err := postgres.Databases()
|
||||
if err == nil {
|
||||
databaseCount += len(databases)
|
||||
|
||||
29
pkg/db/db.go
Normal file
29
pkg/db/db.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package db
|
||||
|
||||
import "database/sql"
|
||||
|
||||
type Operator interface {
|
||||
Close()
|
||||
Ping() error
|
||||
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...any) *sql.Row
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
|
||||
DatabaseCreate(name string) error
|
||||
DatabaseDrop(name string) error
|
||||
DatabaseExists(name string) (bool, error)
|
||||
DatabaseSize(name string) (int64, error)
|
||||
|
||||
UserCreate(user, password string, host ...string) error
|
||||
UserDrop(user string, host ...string) error
|
||||
UserPassword(user, password string, host ...string) error
|
||||
UserPrivileges(user string, host ...string) ([]string, error)
|
||||
|
||||
PrivilegesGrant(user, database string, host ...string) error
|
||||
PrivilegesRevoke(user, database string, host ...string) error
|
||||
|
||||
Users() ([]User, error)
|
||||
Databases() ([]Database, error)
|
||||
}
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"slices"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
|
||||
"github.com/acepanel/panel/pkg/types"
|
||||
)
|
||||
|
||||
type MySQL struct {
|
||||
@@ -18,7 +16,7 @@ type MySQL struct {
|
||||
address string
|
||||
}
|
||||
|
||||
func NewMySQL(username, password, address string, typ ...string) (*MySQL, error) {
|
||||
func NewMySQL(username, password, address string, typ ...string) (Operator, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/", username, password, address)
|
||||
if len(typ) > 0 && typ[0] == "unix" {
|
||||
dsn = fmt.Sprintf("%s:%s@unix(%s)/", username, password, address)
|
||||
@@ -38,8 +36,8 @@ func NewMySQL(username, password, address string, typ ...string) (*MySQL, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *MySQL) Close() error {
|
||||
return r.db.Close()
|
||||
func (r *MySQL) Close() {
|
||||
_ = r.db.Close()
|
||||
}
|
||||
|
||||
func (r *MySQL) Ping() error {
|
||||
@@ -101,44 +99,43 @@ func (r *MySQL) DatabaseSize(name string) (int64, error) {
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (r *MySQL) UserCreate(user, password, host string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("CREATE USER IF NOT EXISTS '%s'@'%s' IDENTIFIED BY '%s'", user, host, password))
|
||||
func (r *MySQL) UserCreate(user, password string, host ...string) error {
|
||||
if len(host) == 0 {
|
||||
host = append(host, "%")
|
||||
}
|
||||
_, err := r.Exec(fmt.Sprintf("CREATE USER IF NOT EXISTS '%s'@'%s' IDENTIFIED BY '%s'", user, host[0], password))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MySQL) UserDrop(user, host string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%s'", user, host))
|
||||
func (r *MySQL) UserDrop(user string, host ...string) error {
|
||||
if len(host) == 0 {
|
||||
host = append(host, "%")
|
||||
}
|
||||
_, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%s'", user, host[0]))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MySQL) UserPassword(user, password, host string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", user, host, password))
|
||||
func (r *MySQL) UserPassword(user, password string, host ...string) error {
|
||||
if len(host) == 0 {
|
||||
host = append(host, "%")
|
||||
}
|
||||
_, err := r.Exec(fmt.Sprintf("ALTER USER '%s'@'%s' IDENTIFIED BY '%s'", user, host[0], password))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MySQL) PrivilegesGrant(user, database, host string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s'", database, user, host))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
func (r *MySQL) UserPrivileges(user string, host ...string) ([]string, error) {
|
||||
if len(host) == 0 {
|
||||
host = append(host, "%")
|
||||
}
|
||||
|
||||
func (r *MySQL) PrivilegesRevoke(user, database, host string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MySQL) UserPrivileges(user, host string) ([]string, error) {
|
||||
rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host))
|
||||
rows, err := r.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host[0]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func(rows *sql.Rows) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
defer func(rows *sql.Rows) { _ = rows.Close() }(rows)
|
||||
|
||||
re := regexp.MustCompile(`GRANT\s+ALL PRIVILEGES\s+ON\s+[\x60'"]?([^\s\x60'"]+)[\x60'"]?\.\*\s+TO\s+`)
|
||||
var databases []string
|
||||
@@ -166,16 +163,32 @@ func (r *MySQL) UserPrivileges(user, host string) ([]string, error) {
|
||||
return slices.Compact(databases), nil
|
||||
}
|
||||
|
||||
func (r *MySQL) Users() ([]types.MySQLUser, error) {
|
||||
func (r *MySQL) PrivilegesGrant(user, database string, host ...string) error {
|
||||
if len(host) == 0 {
|
||||
host = append(host, "%")
|
||||
}
|
||||
_, err := r.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s'", database, user, host[0]))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MySQL) PrivilegesRevoke(user, database string, host ...string) error {
|
||||
if len(host) == 0 {
|
||||
host = append(host, "%")
|
||||
}
|
||||
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON %s.* FROM '%s'@'%s'", database, user, host[0]))
|
||||
r.flushPrivileges()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MySQL) Users() ([]User, error) {
|
||||
rows, err := r.Query("SELECT user, host FROM mysql.user")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func(rows *sql.Rows) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
defer func(rows *sql.Rows) { _ = rows.Close() }(rows)
|
||||
|
||||
var users []types.MySQLUser
|
||||
var users []User
|
||||
for rows.Next() {
|
||||
var user, host string
|
||||
if err := rows.Scan(&user, &host); err != nil {
|
||||
@@ -186,7 +199,7 @@ func (r *MySQL) Users() ([]types.MySQLUser, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, types.MySQLUser{
|
||||
users = append(users, User{
|
||||
User: user,
|
||||
Host: host,
|
||||
Grants: grants,
|
||||
@@ -196,7 +209,7 @@ func (r *MySQL) Users() ([]types.MySQLUser, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *MySQL) Databases() ([]types.MySQLDatabase, error) {
|
||||
func (r *MySQL) Databases() ([]Database, error) {
|
||||
query := `
|
||||
SELECT
|
||||
SCHEMA_NAME,
|
||||
@@ -214,9 +227,9 @@ func (r *MySQL) Databases() ([]types.MySQLDatabase, error) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var databases []types.MySQLDatabase
|
||||
var databases []Database
|
||||
for rows.Next() {
|
||||
var db types.MySQLDatabase
|
||||
var db Database
|
||||
if err = rows.Scan(&db.Name, &db.CharSet, &db.Collation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/acepanel/panel/pkg/systemctl"
|
||||
"github.com/acepanel/panel/pkg/types"
|
||||
)
|
||||
|
||||
type Postgres struct {
|
||||
@@ -20,7 +19,7 @@ type Postgres struct {
|
||||
port uint
|
||||
}
|
||||
|
||||
func NewPostgres(username, password, address string, port uint) (*Postgres, error) {
|
||||
func NewPostgres(username, password, address string, port uint) (Operator, error) {
|
||||
username = strings.ReplaceAll(username, `'`, `\'`)
|
||||
password = strings.ReplaceAll(password, `'`, `\'`)
|
||||
dsn := fmt.Sprintf(`host=%s port=%d user='%s' password='%s' dbname=postgres sslmode=disable`, address, port, username, password)
|
||||
@@ -46,8 +45,8 @@ func NewPostgres(username, password, address string, port uint) (*Postgres, erro
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Postgres) Close() error {
|
||||
return r.db.Close()
|
||||
func (r *Postgres) Close() {
|
||||
_ = r.db.Close()
|
||||
}
|
||||
|
||||
func (r *Postgres) Ping() error {
|
||||
@@ -72,7 +71,7 @@ func (r *Postgres) Prepare(query string) (*sql.Stmt, error) {
|
||||
|
||||
func (r *Postgres) DatabaseCreate(name string) error {
|
||||
// postgres 不支持 CREATE DATABASE IF NOT EXISTS,但是为了保持与 MySQL 一致,先检查数据库是否存在
|
||||
exist, err := r.DatabaseExist(name)
|
||||
exist, err := r.DatabaseExists(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -88,7 +87,7 @@ func (r *Postgres) DatabaseDrop(name string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Postgres) DatabaseExist(name string) (bool, error) {
|
||||
func (r *Postgres) DatabaseExists(name string) (bool, error) {
|
||||
var count int
|
||||
if err := r.QueryRow("SELECT COUNT(*) FROM pg_database WHERE datname = $1", name).Scan(&count); err != nil {
|
||||
return false, err
|
||||
@@ -110,7 +109,7 @@ func (r *Postgres) DatabaseComment(name, comment string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Postgres) UserCreate(user, password string) error {
|
||||
func (r *Postgres) UserCreate(user, password string, host ...string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'", user, password))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -119,7 +118,7 @@ func (r *Postgres) UserCreate(user, password string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Postgres) UserDrop(user string) error {
|
||||
func (r *Postgres) UserDrop(user string, host ...string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("DROP USER IF EXISTS %s", user))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -128,12 +127,12 @@ func (r *Postgres) UserDrop(user string) error {
|
||||
return systemctl.Reload("postgresql")
|
||||
}
|
||||
|
||||
func (r *Postgres) UserPassword(user, password string) error {
|
||||
func (r *Postgres) UserPassword(user, password string, host ...string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("ALTER USER %s WITH PASSWORD '%s'", user, password))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Postgres) UserPrivileges(user string) ([]string, error) {
|
||||
func (r *Postgres) UserPrivileges(user string, host ...string) ([]string, error) {
|
||||
query := `
|
||||
SELECT d.datname
|
||||
FROM pg_catalog.pg_database d
|
||||
@@ -169,7 +168,7 @@ func (r *Postgres) UserPrivileges(user string) ([]string, error) {
|
||||
return databases, nil
|
||||
}
|
||||
|
||||
func (r *Postgres) PrivilegesGrant(user, database string) error {
|
||||
func (r *Postgres) PrivilegesGrant(user, database string, host ...string) error {
|
||||
if _, err := r.Exec(fmt.Sprintf("ALTER DATABASE %s OWNER TO %s", database, user)); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -180,12 +179,12 @@ func (r *Postgres) PrivilegesGrant(user, database string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Postgres) PrivilegesRevoke(user, database string) error {
|
||||
func (r *Postgres) PrivilegesRevoke(user, database string, host ...string) error {
|
||||
_, err := r.Exec(fmt.Sprintf("REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s", database, user))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Postgres) Users() ([]types.PostgresUser, error) {
|
||||
func (r *Postgres) Users() ([]User, error) {
|
||||
query := `
|
||||
SELECT rolname,
|
||||
rolsuper,
|
||||
@@ -204,11 +203,11 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var users []types.PostgresUser
|
||||
var users []User
|
||||
for rows.Next() {
|
||||
var user types.PostgresUser
|
||||
var user User
|
||||
var super, canCreateRole, canCreateDb, replication, bypassRls bool
|
||||
if err = rows.Scan(&user.Role, &super, &canCreateRole, &canCreateDb, &replication, &bypassRls); err != nil {
|
||||
if err = rows.Scan(&user.User, &super, &canCreateRole, &canCreateDb, &replication, &bypassRls); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -221,12 +220,12 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) {
|
||||
}
|
||||
for perm, enabled := range permissions {
|
||||
if enabled {
|
||||
user.Attributes = append(user.Attributes, perm)
|
||||
user.Grants = append(user.Grants, perm)
|
||||
}
|
||||
}
|
||||
|
||||
if len(user.Attributes) == 0 {
|
||||
user.Attributes = append(user.Attributes, "无")
|
||||
if len(user.Grants) == 0 {
|
||||
user.Grants = append(user.Grants, "None")
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
@@ -239,7 +238,7 @@ func (r *Postgres) Users() ([]types.PostgresUser, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *Postgres) Databases() ([]types.PostgresDatabase, error) {
|
||||
func (r *Postgres) Databases() ([]Database, error) {
|
||||
query := `
|
||||
SELECT
|
||||
d.datname,
|
||||
@@ -257,10 +256,10 @@ func (r *Postgres) Databases() ([]types.PostgresDatabase, error) {
|
||||
_ = rows.Close()
|
||||
}(rows)
|
||||
|
||||
var databases []types.PostgresDatabase
|
||||
var databases []Database
|
||||
for rows.Next() {
|
||||
var db types.PostgresDatabase
|
||||
if err := rows.Scan(&db.Name, &db.Owner, &db.Encoding, &db.Comment); err != nil {
|
||||
var db Database
|
||||
if err := rows.Scan(&db.Name, &db.Owner, &db.CharSet, &db.Comment); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if slices.Contains([]string{"template0", "template1", "postgres"}, db.Name) {
|
||||
|
||||
@@ -39,8 +39,8 @@ func NewRedis(username, password, address string) (*Redis, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Redis) Close() error {
|
||||
return r.conn.Close()
|
||||
func (r *Redis) Close() {
|
||||
_ = r.conn.Close()
|
||||
}
|
||||
|
||||
func (r *Redis) Exec(command string, args ...any) (any, error) {
|
||||
|
||||
15
pkg/db/types.go
Normal file
15
pkg/db/types.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package db
|
||||
|
||||
type User struct {
|
||||
User string `json:"user"` // 用户名,PG 里面对应 Role
|
||||
Host string `json:"host"` // 主机,PG 这个字段为空
|
||||
Grants []string `json:"grants"` // 权限列表
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
Name string `json:"name"` // 数据库名
|
||||
Owner string `json:"owner"` // 所有者,MySQL 这个字段为空
|
||||
CharSet string `json:"char_set"` // 字符集,PG 里面对应 Encoding
|
||||
Collation string `json:"collation"` // 校对集,PG 这个字段为空
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package types
|
||||
|
||||
type MySQLUser struct {
|
||||
User string `json:"user"`
|
||||
Host string `json:"host"`
|
||||
Grants []string `json:"grants"`
|
||||
}
|
||||
|
||||
type MySQLDatabase struct {
|
||||
Name string `json:"name"`
|
||||
CharSet string `json:"char_set"`
|
||||
Collation string `json:"collation"`
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package types
|
||||
|
||||
type PostgresUser struct {
|
||||
Role string `json:"role"`
|
||||
Attributes []string `json:"attributes"`
|
||||
}
|
||||
|
||||
type PostgresDatabase struct {
|
||||
Name string `json:"name"`
|
||||
Owner string `json:"owner"`
|
||||
Encoding string `json:"encoding"`
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
Reference in New Issue
Block a user