mirror of
https://github.com/acepanel/panel.git
synced 2026-02-04 09:13:49 +08:00
feat: 重写s3客户端
This commit is contained in:
@@ -1,18 +1,13 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/rhnvrm/simples3"
|
||||
)
|
||||
|
||||
// S3AddressingStyle S3 地址模式
|
||||
@@ -31,56 +26,57 @@ type S3Config struct {
|
||||
AccessKeyID string // 访问密钥 ID
|
||||
SecretAccessKey string // 访问密钥
|
||||
Endpoint string // 自定义端点
|
||||
Scheme string // 协议 http 或 https
|
||||
BasePath string // 基础路径前缀
|
||||
AddressingStyle S3AddressingStyle // 地址模式
|
||||
}
|
||||
|
||||
type S3 struct {
|
||||
client *s3.Client
|
||||
client *simples3.S3
|
||||
config S3Config
|
||||
bucket string // bucket 用于 API 调用
|
||||
}
|
||||
|
||||
func NewS3(cfg S3Config) (Storage, error) {
|
||||
if cfg.AddressingStyle == "" {
|
||||
cfg.AddressingStyle = S3AddressingStyleVirtualHosted
|
||||
}
|
||||
if cfg.Scheme == "" {
|
||||
cfg.Scheme = "https"
|
||||
}
|
||||
|
||||
cfg.BasePath = strings.Trim(cfg.BasePath, "/")
|
||||
|
||||
var awsCfg aws.Config
|
||||
var err error
|
||||
client := simples3.New(cfg.Region, cfg.AccessKeyID, cfg.SecretAccessKey)
|
||||
|
||||
awsCfg, err = config.LoadDefaultConfig(context.TODO(),
|
||||
config.WithRegion(cfg.Region),
|
||||
config.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
config.WithRequestChecksumCalculation(aws.RequestChecksumCalculationWhenRequired),
|
||||
config.WithResponseChecksumValidation(aws.ResponseChecksumValidationWhenRequired),
|
||||
config.WithRetryMaxAttempts(10),
|
||||
)
|
||||
// bucket 用于 API 调用
|
||||
// Virtual Hosted Style 时 bucket 已在 endpoint 中,API 调用时传空
|
||||
// Path Style 时 bucket 需要作为路径的一部分
|
||||
bucket := cfg.Bucket
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
||||
}
|
||||
|
||||
var client *s3.Client
|
||||
if cfg.Endpoint != "" {
|
||||
// 自定义端点
|
||||
client = s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
o.UsePathStyle = cfg.AddressingStyle == S3AddressingStylePath
|
||||
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
||||
})
|
||||
// 自定义 Endpoint
|
||||
if cfg.AddressingStyle == S3AddressingStyleVirtualHosted {
|
||||
// Virtual Hosted Style: https://{bucket}.{endpoint}/{key}
|
||||
client.SetEndpoint(fmt.Sprintf("%s://%s.%s", cfg.Scheme, cfg.Bucket, cfg.Endpoint))
|
||||
bucket = ""
|
||||
} else {
|
||||
// Path Style: https://{endpoint}/{bucket}/{key}
|
||||
client.SetEndpoint(fmt.Sprintf("%s://%s", cfg.Scheme, cfg.Endpoint))
|
||||
}
|
||||
} else {
|
||||
// 标准 AWS S3
|
||||
client = s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
o.UsePathStyle = cfg.AddressingStyle == S3AddressingStylePath
|
||||
})
|
||||
// AWS S3
|
||||
if cfg.AddressingStyle == S3AddressingStyleVirtualHosted {
|
||||
// Virtual Hosted Style: https://{bucket}.s3.{region}.amazonaws.com/{key}
|
||||
client.SetEndpoint(fmt.Sprintf("https://%s.s3.%s.amazonaws.com", cfg.Bucket, cfg.Region))
|
||||
bucket = ""
|
||||
}
|
||||
}
|
||||
|
||||
return &S3{
|
||||
client: client,
|
||||
config: cfg,
|
||||
bucket: bucket,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -91,42 +87,27 @@ func (s *S3) Delete(files ...string) error {
|
||||
}
|
||||
|
||||
// 批量删除
|
||||
var objects []types.ObjectIdentifier
|
||||
var objects []string
|
||||
for _, file := range files {
|
||||
key := s.getKey(file)
|
||||
objects = append(objects, types.ObjectIdentifier{
|
||||
Key: aws.String(key),
|
||||
})
|
||||
objects = append(objects, key)
|
||||
}
|
||||
|
||||
_, err := s.client.DeleteObjects(context.TODO(), &s3.DeleteObjectsInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Delete: &types.Delete{
|
||||
Objects: objects,
|
||||
},
|
||||
_, err := s.client.DeleteObjects(simples3.DeleteObjectsInput{
|
||||
Bucket: s.bucket,
|
||||
Objects: objects,
|
||||
Quiet: true,
|
||||
})
|
||||
|
||||
waiter := s3.NewObjectNotExistsWaiter(s.client)
|
||||
for _, file := range files {
|
||||
key := s.getKey(file)
|
||||
err = waiter.Wait(context.TODO(), &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Key: aws.String(key),
|
||||
}, 30*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Exists 检查文件是否存在
|
||||
func (s *S3) Exists(file string) bool {
|
||||
key := s.getKey(file)
|
||||
_, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Key: aws.String(key),
|
||||
_, err := s.client.FileDetails(simples3.DetailsInput{
|
||||
Bucket: s.bucket,
|
||||
ObjectKey: key,
|
||||
})
|
||||
return err == nil
|
||||
}
|
||||
@@ -134,18 +115,29 @@ func (s *S3) Exists(file string) bool {
|
||||
// LastModified 获取文件最后修改时间
|
||||
func (s *S3) LastModified(file string) (time.Time, error) {
|
||||
key := s.getKey(file)
|
||||
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Key: aws.String(key),
|
||||
output, err := s.client.FileDetails(simples3.DetailsInput{
|
||||
Bucket: s.bucket,
|
||||
ObjectKey: key,
|
||||
})
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
if output.LastModified != nil {
|
||||
return *output.LastModified, nil
|
||||
if output.LastModified == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
return time.Time{}, nil
|
||||
|
||||
// 解析 HTTP 日期格式
|
||||
t, err := time.Parse(time.RFC1123, output.LastModified)
|
||||
if err != nil {
|
||||
// 尝试其他格式
|
||||
t, err = time.Parse(time.RFC1123Z, output.LastModified)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to parse LastModified: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// List 列出目录下的所有文件
|
||||
@@ -156,31 +148,29 @@ func (s *S3) List(path string) ([]string, error) {
|
||||
}
|
||||
|
||||
var files []string
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Prefix: aws.String(prefix),
|
||||
Delimiter: aws.String("/"),
|
||||
seq, finish := s.client.ListAll(simples3.ListInput{
|
||||
Bucket: s.bucket,
|
||||
Prefix: prefix,
|
||||
Delimiter: "/",
|
||||
})
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(context.TODO())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
for obj := range seq {
|
||||
key := obj.Key
|
||||
// 跳过目录本身
|
||||
if key == prefix {
|
||||
continue
|
||||
}
|
||||
for _, obj := range page.Contents {
|
||||
key := aws.ToString(obj.Key)
|
||||
// 跳过目录本身
|
||||
if key == prefix {
|
||||
continue
|
||||
}
|
||||
// 提取文件名
|
||||
name := strings.TrimPrefix(key, prefix)
|
||||
if name != "" && !strings.Contains(name, "/") {
|
||||
files = append(files, name)
|
||||
}
|
||||
// 提取文件名
|
||||
name := strings.TrimPrefix(key, prefix)
|
||||
if name != "" && !strings.Contains(name, "/") {
|
||||
files = append(files, name)
|
||||
}
|
||||
}
|
||||
|
||||
if err := finish(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
@@ -188,27 +178,13 @@ func (s *S3) List(path string) ([]string, error) {
|
||||
func (s *S3) Put(file string, content io.Reader) error {
|
||||
key := s.getKey(file)
|
||||
|
||||
// For S3-compatible providers, disable automatic checksum calculation on the Uploader.
|
||||
// The S3 client's RequestChecksumCalculation setting only affects single-part uploads.
|
||||
// Multipart uploads via the Uploader require this separate setting (added in s3/manager v1.20.0).
|
||||
// See: https://github.com/aws/aws-sdk-go-v2/issues/3007
|
||||
uploader := manager.NewUploader(s.client, func(u *manager.Uploader) {
|
||||
u.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
_, err := s.client.FileUploadMultipart(simples3.MultipartUploadInput{
|
||||
Bucket: s.bucket,
|
||||
ObjectKey: key,
|
||||
ContentType: "application/octet-stream",
|
||||
Body: content,
|
||||
Concurrency: 5,
|
||||
})
|
||||
_, err := uploader.Upload(context.TODO(), &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Key: aws.String(key),
|
||||
Body: content,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
waiter := s3.NewObjectExistsWaiter(s.client)
|
||||
err = waiter.Wait(context.TODO(), &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Key: aws.String(key),
|
||||
}, 30*time.Second)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -216,15 +192,20 @@ func (s *S3) Put(file string, content io.Reader) error {
|
||||
// Size 获取文件大小
|
||||
func (s *S3) Size(file string) (int64, error) {
|
||||
key := s.getKey(file)
|
||||
output, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.config.Bucket),
|
||||
Key: aws.String(key),
|
||||
output, err := s.client.FileDetails(simples3.DetailsInput{
|
||||
Bucket: s.bucket,
|
||||
ObjectKey: key,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return aws.ToInt64(output.ContentLength), nil
|
||||
size, err := strconv.ParseInt(output.ContentLength, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse ContentLength: %w", err)
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// getKey 获取完整的对象键
|
||||
|
||||
@@ -9,6 +9,7 @@ type BackupStorageInfo struct {
|
||||
Style string `json:"style"` // virtual-hosted, path
|
||||
Region string `json:"region"` // 地区
|
||||
Endpoint string `json:"endpoint"` // 端点
|
||||
Scheme string `json:"scheme"` // http, https
|
||||
Bucket string `json:"bucket"` // 存储桶
|
||||
|
||||
// SFTP / WebDAV
|
||||
|
||||
Reference in New Issue
Block a user