2
0
mirror of https://github.com/acepanel/panel.git synced 2026-02-04 11:27:17 +08:00

feat: 添加storage包

This commit is contained in:
2025-09-18 03:53:02 +08:00
parent e5d16fdf9b
commit 38edda3650
8 changed files with 1635 additions and 1 deletions

341
pkg/storage/ftp.go Normal file
View File

@@ -0,0 +1,341 @@
package storage
import (
"bytes"
"fmt"
"io"
"mime"
"path/filepath"
"strings"
"time"
"github.com/jlaffaye/ftp"
)
type FTPConfig struct {
Host string // FTP 服务器地址
Port int // FTP 端口,默认 21
Username string // 用户名
Password string // 密码
BasePath string // 基础路径
}
type FTP struct {
config FTPConfig
}
func NewFTP(config FTPConfig) (Storage, error) {
if config.Port == 0 {
config.Port = 21
}
config.BasePath = strings.Trim(config.BasePath, "/")
f := &FTP{
config: config,
}
if err := f.ensureBasePath(); err != nil {
return nil, fmt.Errorf("failed to ensure base path: %w", err)
}
return f, nil
}
// connect 建立 FTP 连接
func (f *FTP) connect() (*ftp.ServerConn, error) {
addr := fmt.Sprintf("%s:%d", f.config.Host, f.config.Port)
conn, err := ftp.Dial(addr)
if err != nil {
return nil, err
}
err = conn.Login(f.config.Username, f.config.Password)
if err != nil {
conn.Quit()
return nil, err
}
return conn, nil
}
// ensureBasePath 确保基础路径存在
func (f *FTP) ensureBasePath() error {
conn, err := f.connect()
if err != nil {
return err
}
defer conn.Quit()
// 递归创建路径
parts := strings.Split(f.config.BasePath, "/")
currentPath := ""
for _, part := range parts {
if part == "" {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
_ = conn.MakeDir(currentPath)
}
return nil
}
// getRemotePath 获取远程路径
func (f *FTP) getRemotePath(path string) string {
path = strings.TrimPrefix(path, "/")
if f.config.BasePath == "" {
return path
}
if path == "" {
return f.config.BasePath
}
return fmt.Sprintf("%s/%s", f.config.BasePath, path)
}
// MakeDirectory 创建目录
func (f *FTP) MakeDirectory(directory string) error {
conn, err := f.connect()
if err != nil {
return err
}
defer conn.Quit()
remotePath := f.getRemotePath(directory)
// 递归创建目录
parts := strings.Split(remotePath, "/")
currentPath := ""
for _, part := range parts {
if part == "" {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
// 尝试创建目录
_ = conn.MakeDir(currentPath)
}
return nil
}
// DeleteDirectory 删除目录
func (f *FTP) DeleteDirectory(directory string) error {
conn, err := f.connect()
if err != nil {
return err
}
defer conn.Quit()
remotePath := f.getRemotePath(directory)
return conn.RemoveDir(remotePath)
}
// Copy 复制文件到新位置
func (f *FTP) Copy(oldFile, newFile string) error {
// FTP 不支持直接复制,需要下载再上传
data, err := f.Get(oldFile)
if err != nil {
return err
}
return f.Put(newFile, string(data))
}
// Delete 删除文件
func (f *FTP) Delete(files ...string) error {
conn, err := f.connect()
if err != nil {
return err
}
defer conn.Quit()
for _, file := range files {
remotePath := f.getRemotePath(file)
if err := conn.Delete(remotePath); err != nil {
return err
}
}
return nil
}
// Exists 检查文件是否存在
func (f *FTP) Exists(file string) bool {
conn, err := f.connect()
if err != nil {
return false
}
defer conn.Quit()
remotePath := f.getRemotePath(file)
_, err = conn.FileSize(remotePath)
return err == nil
}
// Files 获取目录下的所有文件
func (f *FTP) Files(path string) ([]string, error) {
conn, err := f.connect()
if err != nil {
return nil, err
}
defer conn.Quit()
remotePath := f.getRemotePath(path)
entries, err := conn.List(remotePath)
if err != nil {
return nil, err
}
var files []string
for _, entry := range entries {
if entry.Type == ftp.EntryTypeFile {
files = append(files, entry.Name)
}
}
return files, nil
}
// Get 读取文件内容
func (f *FTP) Get(file string) ([]byte, error) {
conn, err := f.connect()
if err != nil {
return nil, err
}
defer conn.Quit()
remotePath := f.getRemotePath(file)
resp, err := conn.Retr(remotePath)
if err != nil {
return nil, err
}
defer resp.Close()
return io.ReadAll(resp)
}
// LastModified 获取文件最后修改时间
func (f *FTP) LastModified(file string) (time.Time, error) {
conn, err := f.connect()
if err != nil {
return time.Time{}, err
}
defer conn.Quit()
remotePath := f.getRemotePath(file)
entries, err := conn.List(filepath.Dir(remotePath))
if err != nil {
return time.Time{}, err
}
fileName := filepath.Base(remotePath)
for _, entry := range entries {
if entry.Name == fileName {
return entry.Time, nil
}
}
return time.Time{}, fmt.Errorf("file not found: %s", file)
}
// MimeType 获取文件的 MIME 类型
func (f *FTP) MimeType(file string) (string, error) {
ext := filepath.Ext(file)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
return "application/octet-stream", nil
}
return mimeType, nil
}
// Missing 检查文件是否不存在
func (f *FTP) Missing(file string) bool {
return !f.Exists(file)
}
// Move 移动文件到新位置
func (f *FTP) Move(oldFile, newFile string) error {
conn, err := f.connect()
if err != nil {
return err
}
defer conn.Quit()
oldPath := f.getRemotePath(oldFile)
newPath := f.getRemotePath(newFile)
// 确保目标目录存在
newDir := filepath.Dir(newPath)
if newDir != "." {
f.createDirectoryPath(conn, newDir)
}
return conn.Rename(oldPath, newPath)
}
// createDirectoryPath 递归创建目录路径
func (f *FTP) createDirectoryPath(conn *ftp.ServerConn, path string) {
parts := strings.Split(path, "/")
currentPath := ""
for _, part := range parts {
if part == "" {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
_ = conn.MakeDir(currentPath)
}
}
// Path 获取文件的完整路径
func (f *FTP) Path(file string) string {
return fmt.Sprintf("ftp://%s:%d/%s", f.config.Host, f.config.Port, f.getRemotePath(file))
}
// Put 写入文件内容
func (f *FTP) Put(file, content string) error {
conn, err := f.connect()
if err != nil {
return err
}
defer conn.Quit()
remotePath := f.getRemotePath(file)
// 确保目录存在
remoteDir := filepath.Dir(remotePath)
if remoteDir != "." {
f.createDirectoryPath(conn, remoteDir)
}
return conn.Stor(remotePath, bytes.NewReader([]byte(content)))
}
// Size 获取文件大小
func (f *FTP) Size(file string) (int64, error) {
conn, err := f.connect()
if err != nil {
return 0, err
}
defer conn.Quit()
remotePath := f.getRemotePath(file)
return conn.FileSize(remotePath)
}

175
pkg/storage/local.go Normal file
View File

@@ -0,0 +1,175 @@
package storage
import (
"io"
"mime"
"os"
"path/filepath"
"time"
)
type Local struct {
basePath string
}
func NewLocal(basePath string) Storage {
if basePath == "" {
basePath = "/"
}
return &Local{
basePath: basePath,
}
}
// MakeDirectory 创建目录
func (n *Local) MakeDirectory(directory string) error {
fullPath := n.fullPath(directory)
return os.MkdirAll(fullPath, 0755)
}
// DeleteDirectory 删除目录
func (n *Local) DeleteDirectory(directory string) error {
fullPath := n.fullPath(directory)
return os.RemoveAll(fullPath)
}
// Copy 复制文件到新位置
func (n *Local) Copy(oldFile, newFile string) error {
srcPath := n.fullPath(oldFile)
dstPath := n.fullPath(newFile)
// 确保目标目录存在
if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil {
return err
}
src, err := os.Open(srcPath)
if err != nil {
return err
}
defer func() { _ = src.Close() }()
dst, err := os.Create(dstPath)
if err != nil {
return err
}
defer func() { _ = dst.Close() }()
_, err = io.Copy(dst, src)
return err
}
// Delete 删除文件
func (n *Local) Delete(files ...string) error {
for _, file := range files {
fullPath := n.fullPath(file)
if err := os.Remove(fullPath); err != nil && !os.IsNotExist(err) {
return err
}
}
return nil
}
// Exists 检查文件是否存在
func (n *Local) Exists(file string) bool {
fullPath := n.fullPath(file)
_, err := os.Stat(fullPath)
return !os.IsNotExist(err)
}
// Files 获取目录下的所有文件
func (n *Local) Files(path string) ([]string, error) {
fullPath := n.fullPath(path)
entries, err := os.ReadDir(fullPath)
if err != nil {
return nil, err
}
var files []string
for _, entry := range entries {
if !entry.IsDir() {
files = append(files, entry.Name())
}
}
return files, nil
}
// Get 读取文件内容
func (n *Local) Get(file string) ([]byte, error) {
fullPath := n.fullPath(file)
return os.ReadFile(fullPath)
}
// LastModified 获取文件最后修改时间
func (n *Local) LastModified(file string) (time.Time, error) {
fullPath := n.fullPath(file)
info, err := os.Stat(fullPath)
if err != nil {
return time.Time{}, err
}
return info.ModTime(), nil
}
// MimeType 获取文件的 MIME 类型
func (n *Local) MimeType(file string) (string, error) {
ext := filepath.Ext(file)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
return "application/octet-stream", nil
}
return mimeType, nil
}
// Missing 检查文件是否不存在
func (n *Local) Missing(file string) bool {
return !n.Exists(file)
}
// Move 移动文件到新位置
func (n *Local) Move(oldFile, newFile string) error {
oldPath := n.fullPath(oldFile)
newPath := n.fullPath(newFile)
// 确保目标目录存在
if err := os.MkdirAll(filepath.Dir(newPath), 0755); err != nil {
return err
}
return os.Rename(oldPath, newPath)
}
// Path 获取文件的完整路径
func (n *Local) Path(file string) string {
return n.fullPath(file)
}
// Put 写入文件内容
func (n *Local) Put(file, content string) error {
fullPath := n.fullPath(file)
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
return err
}
return os.WriteFile(fullPath, []byte(content), 0644)
}
// Size 获取文件大小
func (n *Local) Size(file string) (int64, error) {
fullPath := n.fullPath(file)
info, err := os.Stat(fullPath)
if err != nil {
return 0, err
}
return info.Size(), nil
}
// fullPath 获取文件的完整路径
func (n *Local) fullPath(file string) string {
if filepath.IsAbs(file) {
return file
}
return filepath.Join(n.basePath, file)
}

403
pkg/storage/s3.go Normal file
View File

@@ -0,0 +1,403 @@
package storage
import (
"bytes"
"context"
"fmt"
"io"
"mime"
"path/filepath"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
// S3AddressingStyle S3 地址模式
type S3AddressingStyle string
const (
// S3AddressingStylePath Path 模式https://s3.region.amazonaws.com/bucket/key
S3AddressingStylePath S3AddressingStyle = "path"
// S3AddressingStyleVirtualHosted Virtual Hosted 模式https://bucket.s3.region.amazonaws.com/key
S3AddressingStyleVirtualHosted S3AddressingStyle = "virtual-hosted"
)
type S3Config struct {
Region string // AWS 区域
Bucket string // S3 存储桶名称
AccessKeyID string // 访问密钥 ID
SecretAccessKey string // 访问密钥
Endpoint string // 自定义端点(如 MinIO
BasePath string // 基础路径前缀
AddressingStyle S3AddressingStyle // 地址模式
ForcePathStyle bool // 强制使用 Path 模式(兼容旧版本)
}
type S3 struct {
client *s3.Client
config S3Config
}
func NewS3(cfg S3Config) (Storage, error) {
// 设置默认地址模式
if cfg.AddressingStyle == "" {
if cfg.ForcePathStyle {
cfg.AddressingStyle = S3AddressingStylePath
} else {
cfg.AddressingStyle = S3AddressingStyleVirtualHosted
}
}
cfg.BasePath = strings.Trim(cfg.BasePath, "/")
var awsCfg aws.Config
var err error
if cfg.Endpoint != "" {
// 自定义端点(如 MinIO
awsCfg, err = config.LoadDefaultConfig(context.TODO(),
config.WithRegion(cfg.Region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
cfg.AccessKeyID, cfg.SecretAccessKey, "")),
config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{
URL: cfg.Endpoint,
SigningRegion: cfg.Region,
}, nil
})),
)
} else {
// 标准 AWS S3
awsCfg, err = config.LoadDefaultConfig(context.TODO(),
config.WithRegion(cfg.Region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
cfg.AccessKeyID, cfg.SecretAccessKey, "")),
)
}
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %w", err)
}
// 根据地址模式配置客户端
usePathStyle := cfg.AddressingStyle == S3AddressingStylePath || cfg.ForcePathStyle
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
o.UsePathStyle = usePathStyle
})
s := &S3{
client: client,
config: cfg,
}
if s.config.BasePath != "" {
if err := s.ensureBasePath(); err != nil {
return nil, fmt.Errorf("failed to ensure base path: %w", err)
}
}
return s, nil
}
// ensureBasePath 确保基础路径存在
func (s *S3) ensureBasePath() error {
key := s.config.BasePath + "/"
_, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
Bucket: aws.String(s.config.BasePath),
Key: aws.String(key),
Body: bytes.NewReader([]byte{}),
})
return err
}
// getKey 获取完整的对象键
func (s *S3) getKey(file string) string {
file = strings.TrimPrefix(file, "/")
if s.config.BasePath == "" {
return file
}
if file == "" {
return s.config.BasePath
}
return fmt.Sprintf("%s/%s", s.config.BasePath, file)
}
// MakeDirectory 创建目录S3中实际创建一个空的目录标记对象
func (s *S3) MakeDirectory(directory string) error {
key := s.getKey(directory)
if !strings.HasSuffix(key, "/") {
key += "/"
}
_, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
Body: bytes.NewReader([]byte{}),
})
return err
}
// DeleteDirectory 删除目录
func (s *S3) DeleteDirectory(directory string) error {
prefix := s.getKey(directory)
if prefix != "" && !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
// 列出所有文件
var objects []types.ObjectIdentifier
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
Bucket: aws.String(s.config.Bucket),
Prefix: aws.String(prefix),
})
for paginator.HasMorePages() {
output, err := paginator.NextPage(context.TODO())
if err != nil {
return err
}
for _, obj := range output.Contents {
if obj.Key != nil {
objects = append(objects, types.ObjectIdentifier{
Key: obj.Key,
})
}
}
}
if len(objects) == 0 {
return nil
}
// 批量删除
_, err := s.client.DeleteObjects(context.TODO(), &s3.DeleteObjectsInput{
Bucket: aws.String(s.config.Bucket),
Delete: &types.Delete{
Objects: objects,
},
})
return err
}
// Copy 复制文件到新位置
func (s *S3) Copy(oldFile, newFile string) error {
sourceKey := s.getKey(oldFile)
destKey := s.getKey(newFile)
_, err := s.client.CopyObject(context.TODO(), &s3.CopyObjectInput{
Bucket: aws.String(s.config.Bucket),
CopySource: aws.String(fmt.Sprintf("%s/%s", s.config.Bucket, sourceKey)),
Key: aws.String(destKey),
})
return err
}
// Delete 删除文件
func (s *S3) Delete(files ...string) error {
if len(files) == 0 {
return nil
}
// 批量删除
var objects []types.ObjectIdentifier
for _, file := range files {
key := s.getKey(file)
objects = append(objects, types.ObjectIdentifier{
Key: aws.String(key),
})
}
_, err := s.client.DeleteObjects(context.TODO(), &s3.DeleteObjectsInput{
Bucket: aws.String(s.config.Bucket),
Delete: &types.Delete{
Objects: objects,
},
})
return err
}
// Exists 检查文件是否存在
func (s *S3) Exists(file string) bool {
key := s.getKey(file)
_, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
})
return err == nil
}
// Files 获取目录下的所有文件
func (s *S3) Files(path string) ([]string, error) {
prefix := s.getKey(path)
if prefix != "" && !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
var files []string
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
Bucket: aws.String(s.config.Bucket),
Prefix: aws.String(prefix),
Delimiter: aws.String("/"),
})
for paginator.HasMorePages() {
output, err := paginator.NextPage(context.TODO())
if err != nil {
return nil, err
}
for _, obj := range output.Contents {
if obj.Key != nil && !strings.HasSuffix(*obj.Key, "/") {
fileName := strings.TrimPrefix(*obj.Key, prefix)
if fileName != "" && !strings.Contains(fileName, "/") {
files = append(files, fileName)
}
}
}
}
return files, nil
}
// Get 读取文件内容
func (s *S3) Get(file string) ([]byte, error) {
key := s.getKey(file)
output, err := s.client.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
})
if err != nil {
return nil, err
}
defer output.Body.Close()
return io.ReadAll(output.Body)
}
// LastModified 获取文件最后修改时间
func (s *S3) LastModified(file string) (time.Time, error) {
key := s.getKey(file)
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
})
if err != nil {
return time.Time{}, err
}
if output.LastModified != nil {
return *output.LastModified, nil
}
return time.Time{}, nil
}
// MimeType 获取文件的 MIME 类型
func (s *S3) MimeType(file string) (string, error) {
key := s.getKey(file)
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
})
if err != nil {
return "", err
}
if output.ContentType != nil {
return *output.ContentType, nil
}
// 根据文件扩展名推断
ext := filepath.Ext(file)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
return "application/octet-stream", nil
}
return mimeType, nil
}
// Missing 检查文件是否不存在
func (s *S3) Missing(file string) bool {
return !s.Exists(file)
}
// Move 移动文件到新位置
func (s *S3) Move(oldFile, newFile string) error {
// 先复制
if err := s.Copy(oldFile, newFile); err != nil {
return err
}
// 再删除原文件
return s.Delete(oldFile)
}
// Path 获取文件的完整路径
func (s *S3) Path(file string) string {
// 根据地址模式返回不同的 URL 格式
key := s.getKey(file)
if s.config.Endpoint != "" {
// 自定义端点
return fmt.Sprintf("%s/%s/%s", strings.TrimSuffix(s.config.Endpoint, "/"), s.config.Bucket, key)
}
switch s.config.AddressingStyle {
case S3AddressingStyleVirtualHosted:
// Virtual Hosted 模式https://bucket.s3.region.amazonaws.com/key
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", s.config.Bucket, s.config.Region, key)
case S3AddressingStylePath:
// Path 模式https://s3.region.amazonaws.com/bucket/key
return fmt.Sprintf("https://s3.%s.amazonaws.com/%s/%s", s.config.Region, s.config.Bucket, key)
default:
// 默认返回 s3:// 协议格式
return fmt.Sprintf("s3://%s/%s", s.config.Bucket, key)
}
}
// Put 写入文件内容
func (s *S3) Put(file, content string) error {
key := s.getKey(file)
// 推断 MIME 类型
ext := filepath.Ext(file)
contentType := mime.TypeByExtension(ext)
if contentType == "" {
contentType = "application/octet-stream"
}
_, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
Body: bytes.NewReader([]byte(content)),
ContentType: aws.String(contentType),
})
return err
}
// Size 获取文件大小
func (s *S3) Size(file string) (int64, error) {
key := s.getKey(file)
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
Bucket: aws.String(s.config.Bucket),
Key: aws.String(key),
})
if err != nil {
return 0, err
}
if output.ContentLength != nil {
return *output.ContentLength, nil
}
return 0, nil
}

341
pkg/storage/sftp.go Normal file
View File

@@ -0,0 +1,341 @@
package storage
import (
"bytes"
"fmt"
"io"
"mime"
"os"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
type SFTPConfig struct {
Host string // SFTP 服务器地址
Port int // SFTP 端口,默认 22
Username string // 用户名
Password string // 密码
PrivateKey string // SSH 私钥路径或内容
BasePath string // 基础路径
Timeout time.Duration // 连接超时时间
}
type SFTP struct {
config SFTPConfig
}
func NewSFTP(config SFTPConfig) (Storage, error) {
if config.Port == 0 {
config.Port = 22
}
if config.Timeout == 0 {
config.Timeout = 30 * time.Second
}
config.BasePath = strings.Trim(config.BasePath, "/")
s := &SFTP{
config: config,
}
if err := s.ensureBasePath(); err != nil {
return nil, fmt.Errorf("failed to ensure base path: %w", err)
}
return s, nil
}
// connect 建立 SFTP 连接
func (s *SFTP) connect() (*sftp.Client, func(), error) {
var auth []ssh.AuthMethod
// 密码认证
if s.config.Password != "" {
auth = append(auth, ssh.Password(s.config.Password))
}
// 私钥认证
if s.config.PrivateKey != "" {
var signer ssh.Signer
var err error
if _, statErr := os.Stat(s.config.PrivateKey); statErr == nil {
// 私钥文件路径
keyBytes, err := os.ReadFile(s.config.PrivateKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to read private key file: %w", err)
}
signer, err = ssh.ParsePrivateKey(keyBytes)
} else {
// 私钥内容
signer, err = ssh.ParsePrivateKey([]byte(s.config.PrivateKey))
}
if err != nil {
return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
}
auth = append(auth, ssh.PublicKeys(signer))
}
clientConfig := &ssh.ClientConfig{
User: s.config.Username,
Auth: auth,
Timeout: s.config.Timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
sshClient, err := ssh.Dial("tcp", addr, clientConfig)
if err != nil {
return nil, nil, err
}
sftpClient, err := sftp.NewClient(sshClient)
if err != nil {
_ = sshClient.Close()
return nil, nil, err
}
cleanup := func() {
_ = sftpClient.Close()
_ = sshClient.Close()
}
return sftpClient, cleanup, nil
}
// ensureBasePath 确保基础路径存在
func (s *SFTP) ensureBasePath() error {
if s.config.BasePath == "" {
return nil
}
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
return client.MkdirAll(s.config.BasePath)
}
// getRemotePath 获取远程路径
func (s *SFTP) getRemotePath(path string) string {
path = strings.TrimPrefix(path, "/")
if s.config.BasePath == "" {
return path
}
if path == "" {
return s.config.BasePath
}
return filepath.Join(s.config.BasePath, path)
}
// MakeDirectory 创建目录
func (s *SFTP) MakeDirectory(directory string) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
remotePath := s.getRemotePath(directory)
return client.MkdirAll(remotePath)
}
// DeleteDirectory 删除目录
func (s *SFTP) DeleteDirectory(directory string) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
remotePath := s.getRemotePath(directory)
return client.RemoveDirectory(remotePath)
}
// Copy 复制文件到新位置
func (s *SFTP) Copy(oldFile, newFile string) error {
// SFTP 不支持直接复制,需要读取再写入
data, err := s.Get(oldFile)
if err != nil {
return err
}
return s.Put(newFile, string(data))
}
// Delete 删除文件
func (s *SFTP) Delete(files ...string) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
for _, file := range files {
remotePath := s.getRemotePath(file)
if err := client.Remove(remotePath); err != nil {
return err
}
}
return nil
}
// Exists 检查文件是否存在
func (s *SFTP) Exists(file string) bool {
client, cleanup, err := s.connect()
if err != nil {
return false
}
defer cleanup()
remotePath := s.getRemotePath(file)
_, err = client.Stat(remotePath)
return err == nil
}
// Files 获取目录下的所有文件
func (s *SFTP) Files(path string) ([]string, error) {
client, cleanup, err := s.connect()
if err != nil {
return nil, err
}
defer cleanup()
remotePath := s.getRemotePath(path)
entries, err := client.ReadDir(remotePath)
if err != nil {
return nil, err
}
var files []string
for _, entry := range entries {
if !entry.IsDir() {
files = append(files, entry.Name())
}
}
return files, nil
}
// Get 读取文件内容
func (s *SFTP) Get(file string) ([]byte, error) {
client, cleanup, err := s.connect()
if err != nil {
return nil, err
}
defer cleanup()
remotePath := s.getRemotePath(file)
remoteFile, err := client.Open(remotePath)
if err != nil {
return nil, err
}
defer func() { _ = remoteFile.Close() }()
return io.ReadAll(remoteFile)
}
// LastModified 获取文件最后修改时间
func (s *SFTP) LastModified(file string) (time.Time, error) {
client, cleanup, err := s.connect()
if err != nil {
return time.Time{}, err
}
defer cleanup()
remotePath := s.getRemotePath(file)
stat, err := client.Stat(remotePath)
if err != nil {
return time.Time{}, err
}
return stat.ModTime(), nil
}
// MimeType 获取文件的 MIME 类型
func (s *SFTP) MimeType(file string) (string, error) {
ext := filepath.Ext(file)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
return "application/octet-stream", nil
}
return mimeType, nil
}
// Missing 检查文件是否不存在
func (s *SFTP) Missing(file string) bool {
return !s.Exists(file)
}
// Move 移动文件到新位置
func (s *SFTP) Move(oldFile, newFile string) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
oldPath := s.getRemotePath(oldFile)
newPath := s.getRemotePath(newFile)
// 确保目标目录存在
newDir := filepath.Dir(newPath)
if newDir != "." {
_ = client.MkdirAll(newDir)
}
return client.Rename(oldPath, newPath)
}
// Path 获取文件的完整路径
func (s *SFTP) Path(file string) string {
return fmt.Sprintf("sftp://%s:%d/%s", s.config.Host, s.config.Port, s.getRemotePath(file))
}
// Put 写入文件内容
func (s *SFTP) Put(file, content string) error {
client, cleanup, err := s.connect()
if err != nil {
return err
}
defer cleanup()
remotePath := s.getRemotePath(file)
// 确保目录存在
remoteDir := filepath.Dir(remotePath)
if remoteDir != "." {
_ = client.MkdirAll(remoteDir)
}
remoteFile, err := client.Create(remotePath)
if err != nil {
return err
}
defer func() { _ = remoteFile.Close() }()
_, err = io.Copy(remoteFile, bytes.NewReader([]byte(content)))
return err
}
// Size 获取文件大小
func (s *SFTP) Size(file string) (int64, error) {
client, cleanup, err := s.connect()
if err != nil {
return 0, err
}
defer cleanup()
remotePath := s.getRemotePath(file)
stat, err := client.Stat(remotePath)
if err != nil {
return 0, err
}
return stat.Size(), nil
}

36
pkg/storage/types.go Normal file
View File

@@ -0,0 +1,36 @@
package storage
import (
"time"
)
type Storage interface {
// MakeDirectory creates a directory.
MakeDirectory(directory string) error
// DeleteDirectory deletes the given directory.
DeleteDirectory(directory string) error
// Copy the given file to a new location.
Copy(oldFile, newFile string) error
// Delete deletes the given file(s).
Delete(file ...string) error
// Exists determines if a file exists.
Exists(file string) bool
// Files gets all the files from the given directory.
Files(path string) ([]string, error)
// Get gets the contents of a file.
Get(file string) ([]byte, error)
// LastModified gets the file's last modified time.
LastModified(file string) (time.Time, error)
// MimeType gets the file's mime type.
MimeType(file string) (string, error)
// Missing determines if a file is missing.
Missing(file string) bool
// Move a file to a new location.
Move(oldFile, newFile string) error
// Path gets the full path for the file.
Path(file string) string
// Put writes the contents of a file.
Put(file, content string) error
// Size gets the file size of a given file.
Size(file string) (int64, error)
}

212
pkg/storage/webdav.go Normal file
View File

@@ -0,0 +1,212 @@
package storage
import (
"bytes"
"fmt"
"io"
"mime"
"path/filepath"
"strings"
"time"
"github.com/studio-b12/gowebdav"
)
type WebDavConfig struct {
URL string // WebDAV 服务器 URL
Username string // 用户名
Password string // 密码
BasePath string // 基础路径
Timeout time.Duration // 连接超时时间
}
type WebDav struct {
client *gowebdav.Client
config WebDavConfig
}
func NewWebDav(config WebDavConfig) (Storage, error) {
if config.Timeout == 0 {
config.Timeout = 30 * time.Second
}
config.BasePath = strings.Trim(config.BasePath, "/")
client := gowebdav.NewClient(config.URL, config.Username, config.Password)
client.SetTimeout(config.Timeout)
w := &WebDav{
client: client,
config: config,
}
if err := w.ensureBasePath(); err != nil {
return nil, fmt.Errorf("failed to ensure base path: %w", err)
}
return w, nil
}
// ensureBasePath 确保基础路径存在
func (w *WebDav) ensureBasePath() error {
if w.config.BasePath == "" {
return nil
}
return w.client.MkdirAll(w.config.BasePath, 0755)
}
// getRemotePath 获取远程路径
func (w *WebDav) getRemotePath(path string) string {
path = strings.TrimPrefix(path, "/")
if w.config.BasePath == "" {
return path
}
if path == "" {
return w.config.BasePath
}
return filepath.Join(w.config.BasePath, path)
}
// MakeDirectory 创建目录
func (w *WebDav) MakeDirectory(directory string) error {
remotePath := w.getRemotePath(directory)
return w.client.MkdirAll(remotePath, 0755)
}
// DeleteDirectory 删除目录
func (w *WebDav) DeleteDirectory(directory string) error {
remotePath := w.getRemotePath(directory)
return w.client.RemoveAll(remotePath)
}
// Copy 复制文件到新位置
func (w *WebDav) Copy(oldFile, newFile string) error {
oldPath := w.getRemotePath(oldFile)
newPath := w.getRemotePath(newFile)
// 确保目标目录存在
newDir := filepath.Dir(newPath)
if newDir != "." {
_ = w.client.MkdirAll(newDir, 0755)
}
return w.client.Copy(oldPath, newPath, false)
}
// Delete 删除文件
func (w *WebDav) Delete(files ...string) error {
for _, file := range files {
remotePath := w.getRemotePath(file)
if err := w.client.Remove(remotePath); err != nil {
return err
}
}
return nil
}
// Exists 检查文件是否存在
func (w *WebDav) Exists(file string) bool {
remotePath := w.getRemotePath(file)
_, err := w.client.Stat(remotePath)
return err == nil
}
// Files 获取目录下的所有文件
func (w *WebDav) Files(path string) ([]string, error) {
remotePath := w.getRemotePath(path)
entries, err := w.client.ReadDir(remotePath)
if err != nil {
return nil, err
}
var files []string
for _, entry := range entries {
if !entry.IsDir() {
files = append(files, entry.Name())
}
}
return files, nil
}
// Get 读取文件内容
func (w *WebDav) Get(file string) ([]byte, error) {
remotePath := w.getRemotePath(file)
reader, err := w.client.ReadStream(remotePath)
if err != nil {
return nil, err
}
defer func() { _ = reader.Close() }()
return io.ReadAll(reader)
}
// LastModified 获取文件最后修改时间
func (w *WebDav) LastModified(file string) (time.Time, error) {
remotePath := w.getRemotePath(file)
stat, err := w.client.Stat(remotePath)
if err != nil {
return time.Time{}, err
}
return stat.ModTime(), nil
}
// MimeType 获取文件的 MIME 类型
func (w *WebDav) MimeType(file string) (string, error) {
ext := filepath.Ext(file)
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
return "application/octet-stream", nil
}
return mimeType, nil
}
// Missing 检查文件是否不存在
func (w *WebDav) Missing(file string) bool {
return !w.Exists(file)
}
// Move 移动文件到新位置
func (w *WebDav) Move(oldFile, newFile string) error {
oldPath := w.getRemotePath(oldFile)
newPath := w.getRemotePath(newFile)
// 确保目标目录存在
newDir := filepath.Dir(newPath)
if newDir != "." {
_ = w.client.MkdirAll(newDir, 0755)
}
return w.client.Rename(oldPath, newPath, false)
}
// Path 获取文件的完整路径
func (w *WebDav) Path(file string) string {
remotePath := w.getRemotePath(file)
return fmt.Sprintf("%s/%s", strings.TrimSuffix(w.config.URL, "/"), remotePath)
}
// Put 写入文件内容
func (w *WebDav) Put(file, content string) error {
remotePath := w.getRemotePath(file)
// 确保目录存在
remoteDir := filepath.Dir(remotePath)
if remoteDir != "." {
_ = w.client.MkdirAll(remoteDir, 0755)
}
return w.client.WriteStream(remotePath, bytes.NewReader([]byte(content)), 0644)
}
// Size 获取文件大小
func (w *WebDav) Size(file string) (int64, error) {
remotePath := w.getRemotePath(file)
stat, err := w.client.Stat(remotePath)
if err != nil {
return 0, err
}
return stat.Size(), nil
}