2
0
mirror of https://github.com/acepanel/panel.git synced 2026-02-04 06:47:20 +08:00
Files
panel/pkg/storage/sftp.go
2026-01-22 02:54:44 +08:00

221 lines
4.7 KiB
Go

package storage
import (
"fmt"
"io"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
type SFTPConfig struct {
Host string // SFTP 服务器地址
Port int // SFTP 端口,默认 22
Username string // 用户名
Password string // 密码
PrivateKey string // SSH 私钥
BasePath string // 基础路径
Timeout time.Duration // 连接超时时间
}
type SFTP struct {
config SFTPConfig
}
func NewSFTP(config SFTPConfig) (Storage, error) {
if config.Port == 0 {
config.Port = 22
}
if config.Timeout == 0 {
config.Timeout = 30 * time.Second
}
config.BasePath = strings.TrimSuffix(config.BasePath, "/")
if config.Username == "" || (config.Password == "" && config.PrivateKey == "") {
return nil, fmt.Errorf("username and either password or private key must be provided")
}
return &SFTP{config: config}, nil
}
// connect 建立 SFTP 连接,返回 client 和 cleanup 函数
func (s *SFTP) connect() (*sftp.Client, func(), error) {
var auth []ssh.AuthMethod
// 密码认证
if s.config.Password != "" {
auth = append(auth, ssh.Password(s.config.Password))
}
// 私钥认证
if s.config.PrivateKey != "" {
signer, err := ssh.ParsePrivateKey([]byte(s.config.PrivateKey))
if err != nil {
return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
}
auth = append(auth, ssh.PublicKeys(signer))
}
clientConfig := &ssh.ClientConfig{
User: s.config.Username,
Auth: auth,
Timeout: s.config.Timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
sshClient, err := ssh.Dial("tcp", addr, clientConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SSH server: %w", err)
}
sftpClient, err := sftp.NewClient(sshClient)
if err != nil {
_ = sshClient.Close()
return nil, nil, fmt.Errorf("failed to create SFTP client: %w", err)
}
cleanup := func() {
_ = sftpClient.Close()
_ = sshClient.Close()
}
return sftpClient, cleanup, nil
}
// Delete 删除文件
func (s *SFTP) Delete(files ...string) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
for _, file := range files {
remotePath := s.getRemotePath(file)
if err = client.Remove(remotePath); err != nil {
return err
}
}
return nil
}
// Exists 检查文件是否存在
func (s *SFTP) Exists(file string) bool {
client, cleanup, err := s.connect()
if err != nil {
return false
}
defer cleanup()
remotePath := s.getRemotePath(file)
_, err = client.Stat(remotePath)
return err == nil
}
// LastModified 获取文件最后修改时间
func (s *SFTP) LastModified(file string) (time.Time, error) {
client, cleanup, err := s.connect()
if err != nil {
return time.Time{}, err
}
defer cleanup()
remotePath := s.getRemotePath(file)
stat, err := client.Stat(remotePath)
if err != nil {
return time.Time{}, err
}
return stat.ModTime(), nil
}
// List 列出目录下的所有文件
func (s *SFTP) List(path string) ([]string, error) {
client, cleanup, err := s.connect()
if err != nil {
return nil, err
}
defer cleanup()
// 确保基础路径存在
if s.config.BasePath != "" {
_ = client.MkdirAll(s.config.BasePath)
}
remotePath := s.getRemotePath(path)
entries, err := client.ReadDir(remotePath)
if err != nil {
return nil, err
}
var files []string
for _, entry := range entries {
if !entry.IsDir() {
files = append(files, entry.Name())
}
}
return files, nil
}
// Put 写入文件内容
func (s *SFTP) Put(file string, content io.Reader) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
remotePath := s.getRemotePath(file)
// 确保目录存在
remoteDir := filepath.Dir(remotePath)
if remoteDir != "." {
_ = client.MkdirAll(remoteDir)
}
// 确保基础路径存在
if s.config.BasePath != "" {
_ = client.MkdirAll(s.config.BasePath)
}
remoteFile, err := client.Create(remotePath)
if err != nil {
return err
}
defer func() { _ = remoteFile.Close() }()
_, err = io.Copy(remoteFile, content)
return err
}
// Size 获取文件大小
func (s *SFTP) Size(file string) (int64, error) {
client, cleanup, err := s.connect()
if err != nil {
return 0, err
}
defer cleanup()
remotePath := s.getRemotePath(file)
stat, err := client.Stat(remotePath)
if err != nil {
return 0, err
}
return stat.Size(), nil
}
// getRemotePath 获取远程路径
func (s *SFTP) getRemotePath(path string) string {
path = strings.TrimPrefix(path, "/")
if s.config.BasePath == "" {
return path
}
if path == "" {
return s.config.BasePath
}
return filepath.Join(s.config.BasePath, path)
}