diff --git a/destination.go b/destination.go index e5a1b2d..131d012 100644 --- a/destination.go +++ b/destination.go @@ -20,7 +20,7 @@ import ( ) var ( - ErrDestinationClosed = fmt.Errorf("destination closed: %w", errEndpointClosed) + ErrDestinationClosed = fmt.Errorf("destination closed: %w: %w", errEndpointClosed, net.ErrClosed) errDestinationConnUpdated = errors.New("destination connection updated") errDestinationConnRemoved = errors.New("destination connection removed") ) @@ -71,7 +71,7 @@ func (d *Destination) AcceptContext(ctx context.Context) (net.Conn, error) { return nil, ctx.Err() case conn, ok := <-d.acceptCh: if !ok { - return nil, fmt.Errorf("destination %s is closed: %w", d.cfg.Endpoint, net.ErrClosed) + return nil, fmt.Errorf("destination %s is closed: %w", d.cfg.Endpoint, ErrDestinationClosed) } return conn, nil } diff --git a/pkg/e2e/e2e_test.go b/pkg/e2e/e2e_test.go index 99a1ca7..2d1b204 100644 --- a/pkg/e2e/e2e_test.go +++ b/pkg/e2e/e2e_test.go @@ -293,11 +293,11 @@ func TestE2E(t *testing.T) { require.Empty(t, cl.Sources()) acceptConn, acceptErr := dst.Accept() - require.ErrorIs(t, acceptErr, net.ErrClosed) + require.ErrorIs(t, acceptErr, connet.ErrDestinationClosed) require.Nil(t, acceptConn) dialConn, dialErr := src.Dial("", "") - require.ErrorIs(t, dialErr, connet.ErrSourceNoActiveDestinations) + require.ErrorIs(t, dialErr, connet.ErrSourceClosed) require.Nil(t, dialConn) }) t.Run("cancel-client", func(t *testing.T) { @@ -333,11 +333,11 @@ func TestE2E(t *testing.T) { require.Empty(t, cl.Sources()) acceptConn, acceptErr := dst.Accept() - require.ErrorIs(t, acceptErr, net.ErrClosed) + require.ErrorIs(t, acceptErr, connet.ErrDestinationClosed) require.Nil(t, acceptConn) dialConn, dialErr := src.Dial("", "") - require.ErrorIs(t, dialErr, connet.ErrSourceNoActiveDestinations) + require.ErrorIs(t, dialErr, connet.ErrSourceClosed) require.Nil(t, dialConn) }) t.Run("close-dst", func(t *testing.T) { diff --git a/source.go b/source.go index 2b94fde..755a22b 100644 --- a/source.go +++ b/source.go @@ -27,7 +27,7 @@ import ( ) var ( - ErrSourceClosed = fmt.Errorf("source closed: %w", errEndpointClosed) + ErrSourceClosed = fmt.Errorf("source closed: %w: %w", errEndpointClosed, net.ErrClosed) ErrSourceConnect = errors.New("cannot connect destination") ErrSourceConnectDestinations = fmt.Errorf("%w: all peer connections were unreachable", ErrSourceConnect) ErrSourceNoActiveDestinations = fmt.Errorf("%w: no active peer connections", ErrSourceConnect) @@ -88,6 +88,10 @@ func (s *Source) Dial(network, address string) (net.Conn, error) { // DialContext dials into any available destination. Both network and address are ignored. // Blocks until connection can be established. func (s *Source) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if err := s.ep.ctx.Err(); err != nil { + return nil, fmt.Errorf("dial context: %w", context.Cause(s.ep.ctx)) + } + if s.cfg.DestinationPolicy == NoPolicy { conns, err := s.findActive() if err != nil { @@ -157,6 +161,20 @@ func (s *Source) runActive(ctx context.Context) { } func (s *Source) runActiveErr(ctx context.Context) error { + defer func() { + s.conns.Store(new([]sourceConn)) + }() + + isConnsTracking := s.cfg.DestinationPolicy == LeastConnsPolicy + if isConnsTracking { + defer func() { + s.connsTrackingMu.Lock() + defer s.connsTrackingMu.Unlock() + + s.connsTracking = map[peerID]*atomic.Int32{} + }() + } + return s.ep.peer.activeConnsListen(ctx, func(active map[peerConnKey]*quic.Conn) error { s.logger.Debug("active conns", "len", len(active)) @@ -166,7 +184,7 @@ func (s *Source) runActiveErr(ctx context.Context) error { } s.conns.Store(&conns) - if s.connsTracking != nil { + if isConnsTracking { activePeers := map[peerID]struct{}{} for peer := range active { activePeers[peer.id] = struct{}{}