Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ tool go.mau.fi/util/cmd/maubuild

require (
github.com/beeper/poly1305 v0.0.0-20250815183548-d4eede7bbf3c
github.com/coder/websocket v1.8.14
github.com/gabriel-vasile/mimetype v1.4.13
github.com/google/go-querystring v1.2.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/mattn/go-colorable v0.1.14
github.com/rs/zerolog v1.35.1
github.com/tidwall/gjson v1.19.0
Expand All @@ -31,7 +31,6 @@ require (
require (
filippo.io/edwards25519 v1.2.0 // indirect
github.com/beeper/argo-go v1.1.2 // indirect
github.com/coder/websocket v1.8.14 // indirect
github.com/coreos/go-systemd/v22 v22.7.0 // indirect
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
github.com/lib/pq v1.12.3 // indirect
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfh
github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
Expand Down
27 changes: 15 additions & 12 deletions pkg/messagix/dgw/dgwsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
"sync/atomic"
"time"

"github.com/coder/websocket"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"go.mau.fi/util/exsync"

Expand All @@ -27,7 +27,7 @@ import (
// import cycle
type MessagixClient interface {
IsAuthenticated() bool
GetDialer() *websocket.Dialer
GetDialer() *websocket.DialOptions
GetLogger() *zerolog.Logger
GetCookies() *cookies.Cookies
GetEndpoint(string) string
Expand Down Expand Up @@ -56,18 +56,19 @@ func (s *Socket) CanConnect() error {
}

func (s *Socket) Connect(ctx context.Context) (err error) {
dialer := s.client.GetDialer()
headers := s.getConnHeaders()
opts := s.client.GetDialer()
opts.HTTPHeader = s.getConnHeaders()
socketURL := s.getConnURL()

conn, resp, err := dialer.DialContext(ctx, socketURL, headers)
conn, resp, err := websocket.Dial(ctx, socketURL, opts)
if err != nil {
statusCode := 999
if resp != nil {
statusCode = resp.StatusCode
}
return fmt.Errorf("DGW: %w: %w (status code %d)", socket.ErrDial, err, statusCode)
}
conn.SetReadLimit(-1)
s.conn = conn

err = s.readLoop(ctx, conn)
Expand All @@ -81,6 +82,7 @@ func (s *Socket) Connect(ctx context.Context) (err error) {
const AckTimeout = 5 * time.Second
const PingInterval = 15 * time.Second
const PongTimeout = 30 * time.Second // from web client
const WriteTimeout = 20 * time.Second

type XDTData struct {
Data struct {
Expand Down Expand Up @@ -130,7 +132,7 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error {
errorOnce.Do(func() {
s.err.Store(&err)
close(done)
closeErr := conn.Close()
closeErr := conn.CloseNow()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
s.client.GetLogger().Debug().Err(closeErr).Msg("Error closing DGW connection after " + err.Error())
}
Expand All @@ -150,16 +152,16 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error {
defer wg.Done()
defer close(incoming)
for {
msgtype, data, err := conn.ReadMessage()
msgtype, data, err := conn.Read(ctx)
if err != nil {
fatalError(fmt.Errorf("reading message: %w", err))
return
}
for len(data) > 0 {
switch msgtype {
case websocket.TextMessage:
case websocket.MessageText:
s.client.GetLogger().Warn().Bytes("bytes", data).Msg("Unexpected non-binary DGW message, dropping")
case websocket.BinaryMessage:
case websocket.MessageBinary:
frame := CheckFrameType(data)
data, err = frame.Unmarshal(data)
if err != nil {
Expand Down Expand Up @@ -187,7 +189,9 @@ func (s *Socket) readLoop(ctx context.Context, conn *websocket.Conn) error {
s.client.GetLogger().Warn().Any("frame", frame).Msg("Failed to marshal outbound frame, dropping")
continue
}
err = conn.WriteMessage(websocket.BinaryMessage, b)
writeCtx, cancel := context.WithTimeout(ctx, WriteTimeout)
err = conn.Write(writeCtx, websocket.MessageBinary, b)
cancel()
if err != nil {
fatalError(fmt.Errorf("writing message: %w", err))
}
Expand Down Expand Up @@ -394,8 +398,7 @@ func (s *Socket) getConnURL() string {

func (s *Socket) Disconnect() {
if s != nil && s.conn != nil {
_ = s.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(3*time.Second))
_ = s.conn.Close()
_ = s.conn.Close(websocket.StatusNormalClosure, "")
}
}

Expand Down
76 changes: 41 additions & 35 deletions pkg/messagix/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"sync/atomic"
"time"

"github.com/gorilla/websocket"
"github.com/coder/websocket"
"github.com/rs/zerolog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/ptr"
Expand Down Expand Up @@ -93,24 +93,29 @@
return nil
}

func (c *Client) GetDialer() *websocket.Dialer {
dialer := websocket.Dialer{HandshakeTimeout: HandshakeTimeout}
func (c *Client) GetDialer() *websocket.DialOptions {
transport := &http.Transport{}
if c.httpProxy != nil {
dialer.Proxy = c.httpProxy
transport.Proxy = c.httpProxy
} else if c.socksProxy != nil {
dialer.NetDial = c.socksProxy.Dial

contextDialer, ok := c.socksProxy.(proxy.ContextDialer)
if ok {
dialer.NetDialContext = contextDialer.DialContext
if contextDialer, ok := c.socksProxy.(proxy.ContextDialer); ok {
transport.DialContext = contextDialer.DialContext
} else {
//nolint:staticcheck // fallback path: socksProxy doesn't implement ContextDialer
transport.Dial = c.socksProxy.Dial

Check failure on line 105 in pkg/messagix/socket.go

View workflow job for this annotation

GitHub Actions / Lint (old)

transport.Dial has been deprecated since Go 1.7: Use DialContext instead, which allows the transport to cancel dials as soon as they are no longer needed. If both are set, DialContext takes priority. (SA1019)

Check failure on line 105 in pkg/messagix/socket.go

View workflow job for this annotation

GitHub Actions / Lint (latest)

transport.Dial has been deprecated since Go 1.7: Use DialContext instead, which allows the transport to cancel dials as soon as they are no longer needed. If both are set, DialContext takes priority. (SA1019)

Check failure on line 105 in pkg/messagix/socket.go

View workflow job for this annotation

GitHub Actions / Lint (latest)

transport.Dial has been deprecated since Go 1.7: Use DialContext instead, which allows the transport to cancel dials as soon as they are no longer needed. If both are set, DialContext takes priority. (SA1019)

Check failure on line 105 in pkg/messagix/socket.go

View workflow job for this annotation

GitHub Actions / Lint (old)

transport.Dial has been deprecated since Go 1.7: Use DialContext instead, which allows the transport to cancel dials as soon as they are no longer needed. If both are set, DialContext takes priority. (SA1019)
}
}
if DisableTLSVerification {
dialer.TLSClientConfig = &tls.Config{
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
return &dialer
return &websocket.DialOptions{
HTTPClient: &http.Client{
Timeout: HandshakeTimeout,
Transport: transport,
},
}
}

func (s *Socket) Connect(ctx context.Context) error {
Expand All @@ -119,15 +124,17 @@
return err
}

headers := s.getConnHeaders()
brokerUrl := s.BuildBrokerURL()

dialer := s.client.GetDialer()
opts := s.client.GetDialer()
opts.HTTPHeader = s.getConnHeaders()
s.client.Logger.Debug().Str("broker", brokerUrl).Msg("Dialing socket")
conn, _, err := dialer.DialContext(ctx, brokerUrl, headers)
conn, _, err := websocket.Dial(ctx, brokerUrl, opts)
if err != nil {
return fmt.Errorf("%w: %w", socket.ErrDial, err)
}
// Disable read size limit; MQTT messages can exceed the default 32KiB.
conn.SetReadLimit(-1)

s.conn = conn
err = s.sendConnectPacket()
Expand Down Expand Up @@ -173,8 +180,7 @@
(*fn)()
}
if s.conn != nil {
_ = s.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(3*time.Second))
_ = s.conn.Close()
_ = s.conn.Close(websocket.StatusNormalClosure, "")
}
}

Expand All @@ -188,17 +194,11 @@
closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("closed cleanly")))
closedCleanly.Store(true)
}))
conn.SetCloseHandler(func(code int, text string) error {
closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("closed by server: %d %s", code, text)))
closedCleanly.Store(true)
s.client.Logger.Info().Int("code", code).Str("text", text).Msg("Websocket closed by server")
return nil
})
pongTimeoutTimer := time.NewTimer(PongTimeout)
defer pongTimeoutTimer.Stop()
wsQueue := make(chan any, 32)
closeDueToError := func(reason string) {
err := conn.Close()
err := conn.CloseNow()
if err != nil && !errors.Is(err, net.ErrClosed) {
s.client.Logger.Debug().Err(err).Msg("Error closing connection after " + reason)
}
Expand Down Expand Up @@ -299,12 +299,19 @@
}()
zerolog.Ctx(ctx).Debug().Msg("Connection established, starting read loop")
for {
messageType, p, err := conn.ReadMessage()
messageType, p, err := conn.Read(ctx)
if err != nil {
closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("failed to read message: %w", err)))
if !closedCleanly.Load() {
s.client.Logger.Err(err).Msg("Error reading message from socket")
closeDueToError("reading message failed")
var ce websocket.CloseError
if errors.As(err, &ce) {
closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("closed by server: %d %s", ce.Code, ce.Reason)))
closedCleanly.Store(true)
s.client.Logger.Info().Int("code", int(ce.Code)).Str("text", ce.Reason).Msg("Websocket closed by server")
} else {
closeErr.CompareAndSwap(nil, ptr.Ptr(fmt.Errorf("failed to read message: %w", err)))
if !closedCleanly.Load() {
s.client.Logger.Err(err).Msg("Error reading message from socket")
closeDueToError("reading message failed")
}
}
// Hacky sleep to give the ready handler time to run and set the best available error
time.Sleep(100 * time.Millisecond)
Expand All @@ -314,9 +321,9 @@
pongTimeoutTimer.Reset(PongTimeout)

switch messageType {
case websocket.TextMessage:
case websocket.MessageText:
s.client.Logger.Warn().Bytes("bytes", p).Msg("Unexpected text message in websocket")
case websocket.BinaryMessage:
case websocket.MessageBinary:
handleBinaryMessage(p)
}
}
Expand All @@ -329,12 +336,11 @@
if conn == nil {
return fmt.Errorf("not connected")
}
if err := conn.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err)
}
err := conn.WriteMessage(websocket.BinaryMessage, data)
ctx, cancel := context.WithTimeout(context.Background(), WriteTimeout)
defer cancel()
err := conn.Write(ctx, websocket.MessageBinary, data)
if exhttp.IsNetworkError(err) {
closeErr := conn.Close()
closeErr := conn.CloseNow()
if closeErr != nil && !errors.Is(err, net.ErrClosed) {
s.client.Logger.Debug().Err(closeErr).Msg("Error closing connection after network error")
return errors.Join(err, closeErr)
Expand Down
Loading