From c60bba465d1c3f745a0d605e07671cd7539421df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Sun, 23 Nov 2025 01:49:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=8D=A2websocket=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/ace/wire_gen.go | 2 +- go.mod | 4 +-- go.sum | 8 ++--- internal/service/ws.go | 79 +++++++++++++++++++----------------------- pkg/ssh/turn.go | 25 ++++++------- 5 files changed, 53 insertions(+), 65 deletions(-) diff --git a/cmd/ace/wire_gen.go b/cmd/ace/wire_gen.go index ee487039..c46ceafa 100644 --- a/cmd/ace/wire_gen.go +++ b/cmd/ace/wire_gen.go @@ -143,7 +143,7 @@ func initWeb() (*app.Web, error) { supervisorApp := supervisor.NewApp(locale) loader := bootstrap.NewLoader(codeserverApp, dockerApp, fail2banApp, frpApp, giteaApp, memcachedApp, minioApp, mysqlApp, nginxApp, php74App, php80App, php81App, php82App, php83App, php84App, phpmyadminApp, podmanApp, postgresqlApp, pureftpdApp, redisApp, rsyncApp, s3fsApp, supervisorApp) http := route.NewHttp(koanf, userService, userTokenService, dashboardService, taskService, websiteService, databaseService, databaseServerService, databaseUserService, backupService, certService, certDNSService, certAccountService, appService, cronService, processService, safeService, firewallService, sshService, containerService, containerComposeService, containerNetworkService, containerImageService, containerVolumeService, fileService, monitorService, settingService, systemctlService, toolboxSystemService, toolboxBenchmarkService, loader) - wsService := service.NewWsService(locale, koanf, sshRepo) + wsService := service.NewWsService(locale, koanf, logger, sshRepo) ws := route.NewWs(wsService) mux, err := bootstrap.NewRouter(locale, middlewares, http, ws) if err != nil { diff --git a/go.mod b/go.mod index 5f7b4303..190563ea 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.92.0 github.com/bddjr/hlfhr v1.4.0 github.com/beevik/ntp v1.5.0 + github.com/coder/websocket v1.8.12 github.com/creack/pty v1.1.24 github.com/expr-lang/expr v1.17.6 github.com/go-chi/chi/v5 v5.2.3 @@ -20,7 +21,6 @@ require ( github.com/google/wire v0.7.0 github.com/gookit/color v1.6.0 github.com/gookit/validate v1.5.6 - github.com/gorilla/websocket v1.5.3 github.com/hashicorp/go-version v1.7.0 github.com/jlaffaye/ftp v0.2.0 github.com/knadh/koanf/parsers/yaml v1.1.0 @@ -40,7 +40,7 @@ require ( github.com/libdns/westcn v1.0.2 github.com/libtnb/chix v1.3.2 github.com/libtnb/gormstore v1.1.1 - github.com/libtnb/sessions v1.2.1 + github.com/libtnb/sessions v1.2.2-0.20251122173530-a4002b1c459d github.com/libtnb/utils v1.2.1 github.com/mholt/acmez/v3 v3.1.4 github.com/moby/moby/api v1.52.0 diff --git a/go.sum b/go.sum index f37af105..b841a3b6 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVk github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -183,8 +185,6 @@ github.com/gookit/validate v1.5.6 h1:D6vbSZzreuKYpeeXm5FDDEJy3K5E4lcWsQE4saSMZbU github.com/gookit/validate v1.5.6/go.mod h1:WYEHndRNepIIkM+6CtgEX9MQ9ToIQRhXxmz5oLHF/fc= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= @@ -280,8 +280,8 @@ github.com/libtnb/gormstore v1.1.1 h1:FG/3P4PuWM6/vB4weVJ31meiSaoeXns1NQlP66quKe github.com/libtnb/gormstore v1.1.1/go.mod h1:8A5QzeZxi1MpSmjUVsHTDAL6KnU84feIXMutFLPawwA= github.com/libtnb/securecookie v1.2.0 h1:2uc0PBDm0foeSTrcZ9QTX1IEjf6kFEwfgEYSIXQSKrA= github.com/libtnb/securecookie v1.2.0/go.mod h1:ja+wNGnQzYqcqXQnJWu6icsaWi5JEBwNEMJ2ReTVDxA= -github.com/libtnb/sessions v1.2.1 h1:O9gkEIeZuqyaxopXrUJcGxlNxmNfRBI8BOK43yLJXDI= -github.com/libtnb/sessions v1.2.1/go.mod h1:45Bn9d6PseDINLIM1QaJrlCMbzSZ0NWpDbWkdrKJKw0= +github.com/libtnb/sessions v1.2.2-0.20251122173530-a4002b1c459d h1:PIS6RcMg03UlAkLuif8go4G5fv1x6xFZBK7koBwNd4c= +github.com/libtnb/sessions v1.2.2-0.20251122173530-a4002b1c459d/go.mod h1:qw+FWtBtrPDYCf6MfX0Lk5EhTArpvT72z5Ei4RUMTRg= github.com/libtnb/utils v1.2.1 h1:LJmReRREnpqfHyy9PZtNgBh3ZaIGct81b8ZaAsolMkM= github.com/libtnb/utils v1.2.1/go.mod h1:o6LEDeC42PXI21uLWdWJWTVYvR9BtAZfzzTGJVQoQiU= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= diff --git a/internal/service/ws.go b/internal/service/ws.go index 0caa7e33..953c973c 100644 --- a/internal/service/ws.go +++ b/internal/service/ws.go @@ -3,10 +3,11 @@ package service import ( "bufio" "context" + "log/slog" "net/http" "sync" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/knadh/koanf/v2" "github.com/leonelquinteros/gotext" stdssh "golang.org/x/crypto/ssh" @@ -20,13 +21,15 @@ import ( type WsService struct { t *gotext.Locale conf *koanf.Koanf + log *slog.Logger sshRepo biz.SSHRepo } -func NewWsService(t *gotext.Locale, conf *koanf.Koanf, ssh biz.SSHRepo) *WsService { +func NewWsService(t *gotext.Locale, conf *koanf.Koanf, log *slog.Logger, ssh biz.SSHRepo) *WsService { return &WsService{ t: t, conf: conf, + log: log, sshRepo: ssh, } } @@ -45,32 +48,28 @@ func (s *WsService) Session(w http.ResponseWriter, r *http.Request) { ws, err := s.upgrade(w, r) if err != nil { - ErrorSystem(w) + s.log.Warn("[Websocket] upgrade session ws error", slog.Any("error", err)) return } - defer func(ws *websocket.Conn) { - _ = ws.Close() - }(ws) + defer func(ws *websocket.Conn) { _ = ws.CloseNow() }(ws) client, err := ssh.NewSSHClient(info.Config) if err != nil { - _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error())) + _ = ws.Close(websocket.StatusNormalClosure, err.Error()) return } - defer func(client *stdssh.Client) { - _ = client.Close() - }(client) - - turn, err := ssh.NewTurn(ws, client) - if err != nil { - _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error())) - return - } - defer func(turn *ssh.Turn) { - _ = turn.Close() - }(turn) + defer func(client *stdssh.Client) { _ = client.Close() }(client) ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + turn, err := ssh.NewTurn(ctx, ws, client) + if err != nil { + _ = ws.Close(websocket.StatusNormalClosure, err.Error()) + return + } + defer func(turn *ssh.Turn) { _ = turn.Close() }(turn) + wg := sync.WaitGroup{} wg.Add(2) @@ -84,31 +83,29 @@ func (s *WsService) Session(w http.ResponseWriter, r *http.Request) { }() wg.Wait() - cancel() } func (s *WsService) Exec(w http.ResponseWriter, r *http.Request) { ws, err := s.upgrade(w, r) if err != nil { - ErrorSystem(w) + s.log.Warn("[Websocket] upgrade exec ws error", slog.Any("error", err)) return } - defer func(ws *websocket.Conn) { - _ = ws.Close() - }(ws) + defer func(ws *websocket.Conn) { _ = ws.CloseNow() }(ws) // 第一条消息是命令 - _, cmd, err := ws.ReadMessage() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, cmd, err := ws.Read(ctx) if err != nil { - _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, s.t.Get("failed to read command: %v", err))) + _ = ws.Close(websocket.StatusNormalClosure, s.t.Get("failed to read command: %v", err)) return } - ctx, cancel := context.WithCancel(context.Background()) out, err := shell.ExecfWithPipe(ctx, string(cmd)) if err != nil { - _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, s.t.Get("failed to run command: %v", err))) - cancel() + _ = ws.Close(websocket.StatusNormalClosure, s.t.Get("failed to run command: %v", err)) return } @@ -116,38 +113,34 @@ func (s *WsService) Exec(w http.ResponseWriter, r *http.Request) { scanner := bufio.NewScanner(out) for scanner.Scan() { line := scanner.Text() - _ = ws.WriteMessage(websocket.TextMessage, []byte(line)) + _ = ws.Write(ctx, websocket.MessageText, []byte(line)) } if err = scanner.Err(); err != nil { - _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, s.t.Get("failed to read command output: %v", err))) + _ = ws.Close(websocket.StatusNormalClosure, s.t.Get("failed to read command output: %v", err)) } }() - s.readLoop(ws) - cancel() + s.readLoop(ctx, ws) } func (s *WsService) upgrade(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { - upGrader := websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, + opts := &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionContextTakeover, } // debug 模式下不校验 origin,方便 vite 代理调试 if s.conf.Bool("app.debug") { - upGrader.CheckOrigin = func(r *http.Request) bool { - return true - } + opts.InsecureSkipVerify = true } - return upGrader.Upgrade(w, r, nil) + return websocket.Accept(w, r, opts) } // readLoop 阻塞直到客户端关闭连接 -func (s *WsService) readLoop(c *websocket.Conn) { +func (s *WsService) readLoop(ctx context.Context, c *websocket.Conn) { for { - if _, _, err := c.NextReader(); err != nil { - _ = c.Close() + if _, _, err := c.Read(ctx); err != nil { + _ = c.CloseNow() break } } diff --git a/pkg/ssh/turn.go b/pkg/ssh/turn.go index 28854673..33f407b6 100644 --- a/pkg/ssh/turn.go +++ b/pkg/ssh/turn.go @@ -7,7 +7,7 @@ import ( "fmt" "io" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "golang.org/x/crypto/ssh" ) @@ -18,12 +18,13 @@ type MessageResize struct { } type Turn struct { + ctx context.Context stdin io.WriteCloser session *ssh.Session ws *websocket.Conn } -func NewTurn(ws *websocket.Conn, client *ssh.Client) (*Turn, error) { +func NewTurn(ctx context.Context, ws *websocket.Conn, client *ssh.Client) (*Turn, error) { sess, err := client.NewSession() if err != nil { return nil, err @@ -34,7 +35,7 @@ func NewTurn(ws *websocket.Conn, client *ssh.Client) (*Turn, error) { return nil, err } - turn := &Turn{stdin: stdin, session: sess, ws: ws} + turn := &Turn{ctx: ctx, stdin: stdin, session: sess, ws: ws} sess.Stdout = turn sess.Stderr = turn @@ -54,33 +55,27 @@ func NewTurn(ws *websocket.Conn, client *ssh.Client) (*Turn, error) { } func (t *Turn) Write(p []byte) (n int, err error) { - writer, err := t.ws.NextWriter(websocket.TextMessage) - if err != nil { + if err = t.ws.Write(t.ctx, websocket.MessageText, p); err != nil { return 0, err } - defer func(writer io.WriteCloser) { - _ = writer.Close() - }(writer) - - return writer.Write(p) + return len(p), nil } func (t *Turn) Close() error { if t.session != nil { _ = t.session.Close() } - - return t.ws.Close() + return t.ws.CloseNow() } -func (t *Turn) Handle(context context.Context) error { +func (t *Turn) Handle(ctx context.Context) error { var resize MessageResize for { select { - case <-context.Done(): + case <-ctx.Done(): return errors.New("ssh context done exit") default: - _, data, err := t.ws.ReadMessage() + _, data, err := t.ws.Read(ctx) if err != nil { // 通常是客户端关闭连接 return fmt.Errorf("reading ws message err: %v", err)