diff --git a/go.mod b/go.mod index f3e63111..428c1c7c 100644 --- a/go.mod +++ b/go.mod @@ -8,10 +8,10 @@ tool go.mau.fi/util/cmd/maubuild require ( github.com/beeper/poly1305 v0.0.0-20250815183548-d4eede7bbf3c + github.com/coder/websocket v1.8.14 github.com/gabriel-vasile/mimetype v1.4.13 github.com/google/go-querystring v1.2.0 github.com/google/uuid v1.6.0 - github.com/gorilla/websocket v1.5.0 github.com/mattn/go-colorable v0.1.14 github.com/rs/zerolog v1.35.1 github.com/tidwall/gjson v1.19.0 @@ -31,7 +31,6 @@ require ( require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/beeper/argo-go v1.1.2 // indirect - github.com/coder/websocket v1.8.14 // indirect github.com/coreos/go-systemd/v22 v22.7.0 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/lib/pq v1.12.3 // indirect diff --git a/go.sum b/go.sum index 55daa575..5d0010b5 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,6 @@ github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfh github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/pkg/messagix/dgw/dgwsocket.go b/pkg/messagix/dgw/dgwsocket.go index dcb4a79d..d7cc0c5c 100644 --- a/pkg/messagix/dgw/dgwsocket.go +++ b/pkg/messagix/dgw/dgwsocket.go @@ -14,8 +14,8 @@ import ( "sync/atomic" "time" + "github.com/coder/websocket" "github.com/google/uuid" - "github.com/gorilla/websocket" "github.com/rs/zerolog" "go.mau.fi/util/exsync" @@ -27,7 +27,7 @@ import ( // import cycle type MessagixClient interface { IsAuthenticated() bool - GetDialer() *websocket.Dialer + GetDialer() *websocket.DialOptions GetLogger() *zerolog.Logger GetCookies() *cookies.Cookies GetEndpoint(string) string @@ -56,11 +56,11 @@ func (s *Socket) CanConnect() error { } func (s *Socket) Connect(ctx context.Context) (err error) { - dialer := s.client.GetDialer() - headers := s.getConnHeaders() + opts := s.client.GetDialer() + opts.HTTPHeader = s.getConnHeaders() socketURL := s.getConnURL() - conn, resp, err := dialer.DialContext(ctx, socketURL, headers) + conn, resp, err := websocket.Dial(ctx, socketURL, opts) if err != nil { statusCode := 999 if resp != nil { @@ -68,6 +68,7 @@ func (s *Socket) Connect(ctx context.Context) (err error) { } return fmt.Errorf("DGW: %w: %w (status code %d)", socket.ErrDial, err, statusCode) } + conn.SetReadLimit(-1) s.conn = conn err = s.readLoop(ctx, conn) @@ -81,6 +82,7 @@ func (s *Socket) Connect(ctx context.Context) (err error) { const AckTimeout = 5 * time.Second const PingInterval = 15 * time.Second const PongTimeout = 30 * time.Second // from web client +const WriteTimeout = 20 * time.Second type XDTData struct { Data struct { @@ -130,7 +132,7 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error { errorOnce.Do(func() { s.err.Store(&err) close(done) - closeErr := conn.Close() + closeErr := conn.CloseNow() if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { s.client.GetLogger().Debug().Err(closeErr).Msg("Error closing DGW connection after " + err.Error()) } @@ -150,16 +152,16 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error { defer wg.Done() defer close(incoming) for { - msgtype, data, err := conn.ReadMessage() + msgtype, data, err := conn.Read(ctx) if err != nil { fatalError(fmt.Errorf("reading message: %w", err)) return } for len(data) > 0 { switch msgtype { - case websocket.TextMessage: + case websocket.MessageText: s.client.GetLogger().Warn().Bytes("bytes", data).Msg("Unexpected non-binary DGW message, dropping") - case websocket.BinaryMessage: + case websocket.MessageBinary: frame := CheckFrameType(data) data, err = frame.Unmarshal(data) if err != nil { @@ -187,7 +189,9 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error { s.client.GetLogger().Warn().Any("frame", frame).Msg("Failed to marshal outbound frame, dropping") continue } - err = conn.WriteMessage(websocket.BinaryMessage, b) + writeCtx, cancel := context.WithTimeout(ctx, WriteTimeout) + err = conn.Write(writeCtx, websocket.MessageBinary, b) + cancel() if err != nil { fatalError(fmt.Errorf("writing message: %w", err)) } @@ -394,8 +398,7 @@ func (s *Socket) getConnURL() string { func (s *Socket) Disconnect() { if s != nil && s.conn != nil { - _ = s.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(3*time.Second)) - _ = s.conn.Close() + _ = s.conn.Close(websocket.StatusNormalClosure, "") } } diff --git a/pkg/messagix/socket.go b/pkg/messagix/socket.go index dac6ec5f..d31039c0 100644 --- a/pkg/messagix/socket.go +++ b/pkg/messagix/socket.go @@ -15,7 +15,7 @@ import ( "sync/atomic" "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "go.mau.fi/util/exhttp" "go.mau.fi/util/ptr" @@ -93,24 +93,29 @@ func (s *Socket) CanConnect() error { return nil } -func (c *Client) GetDialer() *websocket.Dialer { - dialer := websocket.Dialer{HandshakeTimeout: HandshakeTimeout} +func (c *Client) GetDialer() *websocket.DialOptions { + transport := &http.Transport{} if c.httpProxy != nil { - dialer.Proxy = c.httpProxy + transport.Proxy = c.httpProxy } else if c.socksProxy != nil { - dialer.NetDial = c.socksProxy.Dial - - contextDialer, ok := c.socksProxy.(proxy.ContextDialer) - if ok { - dialer.NetDialContext = contextDialer.DialContext + if contextDialer, ok := c.socksProxy.(proxy.ContextDialer); ok { + transport.DialContext = contextDialer.DialContext + } else { + //nolint:staticcheck // fallback path: socksProxy doesn't implement ContextDialer + transport.Dial = c.socksProxy.Dial } } if DisableTLSVerification { - dialer.TLSClientConfig = &tls.Config{ + transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } } - return &dialer + return &websocket.DialOptions{ + HTTPClient: &http.Client{ + Timeout: HandshakeTimeout, + Transport: transport, + }, + } } func (s *Socket) Connect(ctx context.Context) error { @@ -119,15 +124,17 @@ func (s *Socket) Connect(ctx context.Context) error { return err } - headers := s.getConnHeaders() brokerUrl := s.BuildBrokerURL() - dialer := s.client.GetDialer() + opts := s.client.GetDialer() + opts.HTTPHeader = s.getConnHeaders() s.client.Logger.Debug().Str("broker", brokerUrl).Msg("Dialing socket") - conn, _, err := dialer.DialContext(ctx, brokerUrl, headers) + conn, _, err := websocket.Dial(ctx, brokerUrl, opts) if err != nil { return fmt.Errorf("%w: %w", socket.ErrDial, err) } + // Disable read size limit; MQTT messages can exceed the default 32KiB. + conn.SetReadLimit(-1) s.conn = conn err = s.sendConnectPacket() @@ -173,8 +180,7 @@ func (s *Socket) Disconnect() { (*fn)() } if s.conn != nil { - _ = s.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(3*time.Second)) - _ = s.conn.Close() + _ = s.conn.Close(websocket.StatusNormalClosure, "") } } @@ -188,17 +194,11 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error { closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("closed cleanly"))) closedCleanly.Store(true) })) - conn.SetCloseHandler(func(code int, text string) error { - closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("closed by server: %d %s", code, text))) - closedCleanly.Store(true) - s.client.Logger.Info().Int("code", code).Str("text", text).Msg("Websocket closed by server") - return nil - }) pongTimeoutTimer := time.NewTimer(PongTimeout) defer pongTimeoutTimer.Stop() wsQueue := make(chan any, 32) closeDueToError := func(reason string) { - err := conn.Close() + err := conn.CloseNow() if err != nil && !errors.Is(err, net.ErrClosed) { s.client.Logger.Debug().Err(err).Msg("Error closing connection after " + reason) } @@ -299,12 +299,19 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error { }() zerolog.Ctx(ctx).Debug().Msg("Connection established, starting read loop") for { - messageType, p, err := conn.ReadMessage() + messageType, p, err := conn.Read(ctx) if err != nil { - closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("failed to read message: %w", err))) - if !closedCleanly.Load() { - s.client.Logger.Err(err).Msg("Error reading message from socket") - closeDueToError("reading message failed") + var ce websocket.CloseError + if errors.As(err, &ce) { + closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("closed by server: %d %s", ce.Code, ce.Reason))) + closedCleanly.Store(true) + s.client.Logger.Info().Int("code", int(ce.Code)).Str("text", ce.Reason).Msg("Websocket closed by server") + } else { + closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("failed to read message: %w", err))) + if !closedCleanly.Load() { + s.client.Logger.Err(err).Msg("Error reading message from socket") + closeDueToError("reading message failed") + } } // Hacky sleep to give the ready handler time to run and set the best available error time.Sleep(100 * time.Millisecond) @@ -314,9 +321,9 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error { pongTimeoutTimer.Reset(PongTimeout) switch messageType { - case websocket.TextMessage: + case websocket.MessageText: s.client.Logger.Warn().Bytes("bytes", p).Msg("Unexpected text message in websocket") - case websocket.BinaryMessage: + case websocket.MessageBinary: handleBinaryMessage(p) } } @@ -329,12 +336,11 @@ func (s *Socket) sendData(data []byte) error { if conn == nil { return fmt.Errorf("not connected") } - if err := conn.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil { - return fmt.Errorf("failed to set write deadline: %w", err) - } - err := conn.WriteMessage(websocket.BinaryMessage, data) + ctx, cancel := context.WithTimeout(context.Background(), WriteTimeout) + defer cancel() + err := conn.Write(ctx, websocket.MessageBinary, data) if exhttp.IsNetworkError(err) { - closeErr := conn.Close() + closeErr := conn.CloseNow() if closeErr != nil && !errors.Is(err, net.ErrClosed) { s.client.Logger.Debug().Err(closeErr).Msg("Error closing connection after network error") return errors.Join(err, closeErr)