Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
)

var (
ErrDestinationClosed = fmt.Errorf("destination closed: %w", errEndpointClosed)
ErrDestinationClosed = fmt.Errorf("destination closed: %w: %w", errEndpointClosed, net.ErrClosed)
errDestinationConnUpdated = errors.New("destination connection updated")
errDestinationConnRemoved = errors.New("destination connection removed")
)
Expand Down Expand Up @@ -71,7 +71,7 @@ func (d *Destination) AcceptContext(ctx context.Context) (net.Conn, error) {
return nil, ctx.Err()
case conn, ok := <-d.acceptCh:
if !ok {
return nil, fmt.Errorf("destination %s is closed: %w", d.cfg.Endpoint, net.ErrClosed)
return nil, fmt.Errorf("destination %s is closed: %w", d.cfg.Endpoint, ErrDestinationClosed)
}
return conn, nil
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,11 @@ func TestE2E(t *testing.T) {
require.Empty(t, cl.Sources())

acceptConn, acceptErr := dst.Accept()
require.ErrorIs(t, acceptErr, net.ErrClosed)
require.ErrorIs(t, acceptErr, connet.ErrDestinationClosed)
require.Nil(t, acceptConn)

dialConn, dialErr := src.Dial("", "")
require.ErrorIs(t, dialErr, connet.ErrSourceNoActiveDestinations)
require.ErrorIs(t, dialErr, connet.ErrSourceClosed)
require.Nil(t, dialConn)
})
t.Run("cancel-client", func(t *testing.T) {
Expand Down Expand Up @@ -333,11 +333,11 @@ func TestE2E(t *testing.T) {
require.Empty(t, cl.Sources())

acceptConn, acceptErr := dst.Accept()
require.ErrorIs(t, acceptErr, net.ErrClosed)
require.ErrorIs(t, acceptErr, connet.ErrDestinationClosed)
require.Nil(t, acceptConn)

dialConn, dialErr := src.Dial("", "")
require.ErrorIs(t, dialErr, connet.ErrSourceNoActiveDestinations)
require.ErrorIs(t, dialErr, connet.ErrSourceClosed)
require.Nil(t, dialConn)
})
t.Run("close-dst", func(t *testing.T) {
Expand Down
22 changes: 20 additions & 2 deletions source.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

var (
ErrSourceClosed = fmt.Errorf("source closed: %w", errEndpointClosed)
ErrSourceClosed = fmt.Errorf("source closed: %w: %w", errEndpointClosed, net.ErrClosed)
ErrSourceConnect = errors.New("cannot connect destination")
ErrSourceConnectDestinations = fmt.Errorf("%w: all peer connections were unreachable", ErrSourceConnect)
ErrSourceNoActiveDestinations = fmt.Errorf("%w: no active peer connections", ErrSourceConnect)
Expand Down Expand Up @@ -88,6 +88,10 @@ func (s *Source) Dial(network, address string) (net.Conn, error) {
// DialContext dials into any available destination. Both network and address are ignored.
// Blocks until connection can be established.
func (s *Source) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if err := s.ep.ctx.Err(); err != nil {
return nil, fmt.Errorf("dial context: %w", context.Cause(s.ep.ctx))
}

if s.cfg.DestinationPolicy == NoPolicy {
conns, err := s.findActive()
if err != nil {
Expand Down Expand Up @@ -157,6 +161,20 @@ func (s *Source) runActive(ctx context.Context) {
}

func (s *Source) runActiveErr(ctx context.Context) error {
defer func() {
s.conns.Store(new([]sourceConn))
}()

isConnsTracking := s.cfg.DestinationPolicy == LeastConnsPolicy
if isConnsTracking {
defer func() {
s.connsTrackingMu.Lock()
defer s.connsTrackingMu.Unlock()

s.connsTracking = map[peerID]*atomic.Int32{}
}()
}

return s.ep.peer.activeConnsListen(ctx, func(active map[peerConnKey]*quic.Conn) error {
s.logger.Debug("active conns", "len", len(active))

Expand All @@ -166,7 +184,7 @@ func (s *Source) runActiveErr(ctx context.Context) error {
}
s.conns.Store(&conns)

if s.connsTracking != nil {
if isConnsTracking {
activePeers := map[peerID]struct{}{}
for peer := range active {
activePeers[peer.id] = struct{}{}
Expand Down
Loading