2
0
mirror of https://github.com/acepanel/panel.git synced 2026-02-04 20:48:42 +08:00
Files
panel/pkg/storage/s3.go
2026-01-20 02:18:13 +08:00

241 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package storage
import (
"context"
"fmt"
"io"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
// S3AddressingStyle selects how the bucket name is placed in request URLs.
type S3AddressingStyle string

// Supported addressing styles.
const (
	// S3AddressingStylePath puts the bucket in the URL path,
	// e.g. https://s3.region.amazonaws.com/bucket/key.
	S3AddressingStylePath = S3AddressingStyle("path")
	// S3AddressingStyleVirtualHosted puts the bucket in the host name,
	// e.g. https://bucket.s3.region.amazonaws.com/key.
	S3AddressingStyleVirtualHosted = S3AddressingStyle("virtual-hosted")
)
// S3Config holds the settings used by NewS3 to build an S3-backed Storage.
type S3Config struct {
	Region          string            // AWS region
	Bucket          string            // S3 bucket name
	AccessKeyID     string            // access key ID (used as a static credential)
	SecretAccessKey string            // secret access key (used as a static credential)
	Endpoint        string            // custom endpoint; when empty the SDK's default AWS endpoint is used
	BasePath        string            // key prefix prepended to every file; NewS3 trims leading/trailing "/"
	AddressingStyle S3AddressingStyle // URL addressing style; NewS3 defaults an empty value to virtual-hosted
}
// S3 is the Storage implementation backed by an S3 (or S3-compatible) bucket.
// Values are created by NewS3.
type S3 struct {
	client *s3.Client // SDK client configured by NewS3
	config S3Config   // normalized config (BasePath trimmed, AddressingStyle defaulted)
}
// NewS3 builds an S3-backed Storage from cfg.
//
// An empty AddressingStyle defaults to virtual-hosted, and leading/trailing
// slashes are removed from BasePath. Requests are signed with the static key
// pair from cfg. Checksums are only calculated/validated when an operation
// requires them, which keeps S3-compatible providers working. Requests are
// attempted up to 10 times. A non-empty Endpoint replaces the default AWS
// endpoint.
func NewS3(cfg S3Config) (Storage, error) {
	if cfg.AddressingStyle == "" {
		cfg.AddressingStyle = S3AddressingStyleVirtualHosted
	}
	cfg.BasePath = strings.Trim(cfg.BasePath, "/")

	staticCreds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
	awsCfg, err := config.LoadDefaultConfig(context.TODO(),
		config.WithRegion(cfg.Region),
		config.WithCredentialsProvider(staticCreds),
		config.WithRequestChecksumCalculation(aws.RequestChecksumCalculationWhenRequired),
		config.WithResponseChecksumValidation(aws.ResponseChecksumValidationWhenRequired),
		config.WithRetryMaxAttempts(10),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load AWS config: %w", err)
	}

	// One options function covers both the standard AWS endpoint and a
	// custom (S3-compatible) endpoint.
	client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
		o.UsePathStyle = cfg.AddressingStyle == S3AddressingStylePath
		if cfg.Endpoint != "" {
			o.BaseEndpoint = aws.String(cfg.Endpoint)
		}
	})

	return &S3{client: client, config: cfg}, nil
}
// Delete removes the given files from the bucket.
//
// Keys are sent to DeleteObjects in batches of at most 1000 keys (the S3
// per-request limit). A request error, or any per-key failure reported in the
// response, is returned immediately. Only after every batch succeeds does it
// wait (up to 30s per key) for HeadObject to report each object as gone.
// Calling it with no files is a no-op.
func (s *S3) Delete(files ...string) error {
	if len(files) == 0 {
		return nil
	}

	// S3 DeleteObjects accepts at most 1000 keys per request.
	const maxBatch = 1000

	keys := make([]string, len(files))
	for i, file := range files {
		keys[i] = s.getKey(file)
	}

	for start := 0; start < len(keys); start += maxBatch {
		end := start + maxBatch
		if end > len(keys) {
			end = len(keys)
		}

		objects := make([]types.ObjectIdentifier, 0, end-start)
		for _, key := range keys[start:end] {
			objects = append(objects, types.ObjectIdentifier{Key: aws.String(key)})
		}

		output, err := s.client.DeleteObjects(context.TODO(), &s3.DeleteObjectsInput{
			Bucket: aws.String(s.config.Bucket),
			Delete: &types.Delete{Objects: objects},
		})
		if err != nil {
			return fmt.Errorf("delete objects: %w", err)
		}
		// The request as a whole can succeed while individual keys fail;
		// those failures are only visible in output.Errors.
		if len(output.Errors) > 0 {
			first := output.Errors[0]
			return fmt.Errorf("delete objects: %d key(s) failed, first %q: %s: %s",
				len(output.Errors), aws.ToString(first.Key),
				aws.ToString(first.Code), aws.ToString(first.Message))
		}
	}

	// Wait until each deleted object is no longer visible, so callers observe
	// the deletion right away.
	waiter := s3.NewObjectNotExistsWaiter(s.client)
	for _, key := range keys {
		err := waiter.Wait(context.TODO(), &s3.HeadObjectInput{
			Bucket: aws.String(s.config.Bucket),
			Key:    aws.String(key),
		}, 30*time.Second)
		if err != nil {
			return err
		}
	}
	return nil
}
// Exists reports whether a HeadObject request for file succeeds.
// Any error (not only a missing object) makes it return false.
func (s *S3) Exists(file string) bool {
	input := &s3.HeadObjectInput{
		Bucket: aws.String(s.config.Bucket),
		Key:    aws.String(s.getKey(file)),
	}
	if _, err := s.client.HeadObject(context.TODO(), input); err != nil {
		return false
	}
	return true
}
// LastModified returns the object's Last-Modified time as reported by
// HeadObject. If the response carries no timestamp, it returns the zero time
// and a nil error.
func (s *S3) LastModified(file string) (time.Time, error) {
	head, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
		Bucket: aws.String(s.config.Bucket),
		Key:    aws.String(s.getKey(file)),
	})
	if err != nil {
		return time.Time{}, err
	}
	// aws.ToTime dereferences the pointer, yielding the zero time for nil.
	return aws.ToTime(head.LastModified), nil
}
// List returns the names (relative to path) of the objects directly under
// path. It walks every ListObjectsV2 page.
//
// Because the listing uses "/" as the delimiter, deeper "subdirectories" come
// back as common prefixes, which are ignored, so only file names are returned.
// The result is nil when nothing matches.
func (s *S3) List(path string) ([]string, error) {
	dir := s.getKey(path)
	if dir != "" && !strings.HasSuffix(dir, "/") {
		dir += "/"
	}

	pages := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
		Bucket:    aws.String(s.config.Bucket),
		Prefix:    aws.String(dir),
		Delimiter: aws.String("/"),
	})

	var names []string
	for pages.HasMorePages() {
		page, err := pages.NextPage(context.TODO())
		if err != nil {
			return nil, err
		}
		for _, object := range page.Contents {
			rel := strings.TrimPrefix(aws.ToString(object.Key), dir)
			// A placeholder object for the directory itself trims to "";
			// anything still containing "/" is not a direct child.
			if rel == "" || strings.Contains(rel, "/") {
				continue
			}
			names = append(names, rel)
		}
	}
	return names, nil
}
// Put uploads content to file through the s3/manager Uploader. It then waits
// up to 30s for HeadObject to see the new object before returning.
func (s *S3) Put(file string, content io.Reader) error {
	key := s.getKey(file)

	// For S3-compatible providers, disable automatic checksum calculation on the Uploader.
	// The S3 client's RequestChecksumCalculation setting only affects single-part uploads.
	// Multipart uploads via the Uploader require this separate setting (added in s3/manager v1.20.0).
	// See: https://github.com/aws/aws-sdk-go-v2/issues/3007
	uploader := manager.NewUploader(s.client, func(u *manager.Uploader) {
		u.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
	})

	if _, err := uploader.Upload(context.TODO(), &s3.PutObjectInput{
		Bucket: aws.String(s.config.Bucket),
		Key:    aws.String(key),
		Body:   content,
	}); err != nil {
		return err
	}

	return s3.NewObjectExistsWaiter(s.client).Wait(context.TODO(), &s3.HeadObjectInput{
		Bucket: aws.String(s.config.Bucket),
		Key:    aws.String(key),
	}, 30*time.Second)
}
// Size returns the object's Content-Length in bytes as reported by
// HeadObject. A missing length is reported as 0.
func (s *S3) Size(file string) (int64, error) {
	input := &s3.HeadObjectInput{
		Bucket: aws.String(s.config.Bucket),
		Key:    aws.String(s.getKey(file)),
	}
	head, err := s.client.HeadObject(context.TODO(), input)
	if err != nil {
		return 0, err
	}
	return aws.ToInt64(head.ContentLength), nil
}
// getKey maps a storage-relative file path to the full object key.
//
// Every leading "/" is stripped from file, matching how NewS3 trims slashes
// from BasePath. So "a", "/a" and "//a" all address the same object rather
// than distinct keys. An empty file maps to BasePath itself, and an empty
// BasePath leaves the key unprefixed.
func (s *S3) getKey(file string) string {
	file = strings.TrimLeft(file, "/")
	switch {
	case s.config.BasePath == "":
		return file
	case file == "":
		return s.config.BasePath
	default:
		return fmt.Sprintf("%s/%s", s.config.BasePath, file)
	}
}