mirror of
https://github.com/acepanel/panel.git
synced 2026-02-04 12:40:25 +08:00
241 lines
6.1 KiB
Go
241 lines
6.1 KiB
Go
package storage
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/aws/aws-sdk-go-v2/aws"
|
||
"github.com/aws/aws-sdk-go-v2/config"
|
||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||
)
|
||
|
||
// S3AddressingStyle selects where the bucket name is placed in request URLs.
type S3AddressingStyle string

// Supported S3 addressing styles.
const (
	// S3AddressingStylePath places the bucket in the URL path, e.g.
	// https://s3.region.amazonaws.com/bucket/key
	S3AddressingStylePath S3AddressingStyle = "path"

	// S3AddressingStyleVirtualHosted places the bucket in the host name, e.g.
	// https://bucket.s3.region.amazonaws.com/key
	S3AddressingStyleVirtualHosted S3AddressingStyle = "virtual-hosted"
)
|
||
|
||
type S3Config struct {
|
||
Region string // AWS 区域
|
||
Bucket string // S3 存储桶名称
|
||
AccessKeyID string // 访问密钥 ID
|
||
SecretAccessKey string // 访问密钥
|
||
Endpoint string // 自定义端点
|
||
BasePath string // 基础路径前缀
|
||
AddressingStyle S3AddressingStyle // 地址模式
|
||
}
|
||
|
||
type S3 struct {
|
||
client *s3.Client
|
||
config S3Config
|
||
}
|
||
|
||
func NewS3(cfg S3Config) (Storage, error) {
|
||
if cfg.AddressingStyle == "" {
|
||
cfg.AddressingStyle = S3AddressingStyleVirtualHosted
|
||
}
|
||
|
||
cfg.BasePath = strings.Trim(cfg.BasePath, "/")
|
||
|
||
var awsCfg aws.Config
|
||
var err error
|
||
|
||
awsCfg, err = config.LoadDefaultConfig(context.TODO(),
|
||
config.WithRegion(cfg.Region),
|
||
config.WithCredentialsProvider(
|
||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||
),
|
||
config.WithRequestChecksumCalculation(aws.RequestChecksumCalculationWhenRequired),
|
||
config.WithResponseChecksumValidation(aws.ResponseChecksumValidationWhenRequired),
|
||
config.WithRetryMaxAttempts(10),
|
||
)
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
||
}
|
||
|
||
var client *s3.Client
|
||
if cfg.Endpoint != "" {
|
||
// 自定义端点
|
||
client = s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||
o.UsePathStyle = cfg.AddressingStyle == S3AddressingStylePath
|
||
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
||
})
|
||
} else {
|
||
// 标准 AWS S3
|
||
client = s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||
o.UsePathStyle = cfg.AddressingStyle == S3AddressingStylePath
|
||
})
|
||
}
|
||
|
||
return &S3{
|
||
client: client,
|
||
config: cfg,
|
||
}, nil
|
||
}
|
||
|
||
// Delete 删除文件
|
||
func (s *S3) Delete(files ...string) error {
|
||
if len(files) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 批量删除
|
||
var objects []types.ObjectIdentifier
|
||
for _, file := range files {
|
||
key := s.getKey(file)
|
||
objects = append(objects, types.ObjectIdentifier{
|
||
Key: aws.String(key),
|
||
})
|
||
}
|
||
|
||
_, err := s.client.DeleteObjects(context.TODO(), &s3.DeleteObjectsInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Delete: &types.Delete{
|
||
Objects: objects,
|
||
},
|
||
})
|
||
|
||
waiter := s3.NewObjectNotExistsWaiter(s.client)
|
||
for _, file := range files {
|
||
key := s.getKey(file)
|
||
err = waiter.Wait(context.TODO(), &s3.HeadObjectInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Key: aws.String(key),
|
||
}, 30*time.Second)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
// Exists 检查文件是否存在
|
||
func (s *S3) Exists(file string) bool {
|
||
key := s.getKey(file)
|
||
_, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Key: aws.String(key),
|
||
})
|
||
return err == nil
|
||
}
|
||
|
||
// LastModified 获取文件最后修改时间
|
||
func (s *S3) LastModified(file string) (time.Time, error) {
|
||
key := s.getKey(file)
|
||
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Key: aws.String(key),
|
||
})
|
||
if err != nil {
|
||
return time.Time{}, err
|
||
}
|
||
|
||
if output.LastModified != nil {
|
||
return *output.LastModified, nil
|
||
}
|
||
return time.Time{}, nil
|
||
}
|
||
|
||
// List 列出目录下的所有文件
|
||
func (s *S3) List(path string) ([]string, error) {
|
||
prefix := s.getKey(path)
|
||
if prefix != "" && !strings.HasSuffix(prefix, "/") {
|
||
prefix += "/"
|
||
}
|
||
|
||
var files []string
|
||
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Prefix: aws.String(prefix),
|
||
Delimiter: aws.String("/"),
|
||
})
|
||
|
||
for paginator.HasMorePages() {
|
||
page, err := paginator.NextPage(context.TODO())
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, obj := range page.Contents {
|
||
key := aws.ToString(obj.Key)
|
||
// 跳过目录本身
|
||
if key == prefix {
|
||
continue
|
||
}
|
||
// 提取文件名
|
||
name := strings.TrimPrefix(key, prefix)
|
||
if name != "" && !strings.Contains(name, "/") {
|
||
files = append(files, name)
|
||
}
|
||
}
|
||
}
|
||
|
||
return files, nil
|
||
}
|
||
|
||
// Put 写入文件内容
|
||
func (s *S3) Put(file string, content io.Reader) error {
|
||
key := s.getKey(file)
|
||
|
||
// For S3-compatible providers, disable automatic checksum calculation on the Uploader.
|
||
// The S3 client's RequestChecksumCalculation setting only affects single-part uploads.
|
||
// Multipart uploads via the Uploader require this separate setting (added in s3/manager v1.20.0).
|
||
// See: https://github.com/aws/aws-sdk-go-v2/issues/3007
|
||
uploader := manager.NewUploader(s.client, func(u *manager.Uploader) {
|
||
u.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||
})
|
||
_, err := uploader.Upload(context.TODO(), &s3.PutObjectInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Key: aws.String(key),
|
||
Body: content,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
waiter := s3.NewObjectExistsWaiter(s.client)
|
||
err = waiter.Wait(context.TODO(), &s3.HeadObjectInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Key: aws.String(key),
|
||
}, 30*time.Second)
|
||
|
||
return err
|
||
}
|
||
|
||
// Size 获取文件大小
|
||
func (s *S3) Size(file string) (int64, error) {
|
||
key := s.getKey(file)
|
||
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
|
||
Bucket: aws.String(s.config.Bucket),
|
||
Key: aws.String(key),
|
||
})
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return aws.ToInt64(output.ContentLength), nil
|
||
}
|
||
|
||
// getKey 获取完整的对象键
|
||
func (s *S3) getKey(file string) string {
|
||
file = strings.TrimPrefix(file, "/")
|
||
if s.config.BasePath == "" {
|
||
return file
|
||
}
|
||
if file == "" {
|
||
return s.config.BasePath
|
||
}
|
||
return fmt.Sprintf("%s/%s", s.config.BasePath, file)
|
||
}
|