Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 139 additions & 1 deletion inside.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,75 @@ import (
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/udp"
)

// consumeInsidePacketBatched is the batching variant of consumeInsidePacket.
// It validates and filters one packet read from the inside device and, when
// the firewall allows it, passes it to sendNoMetricsBatched along with
// pendingPackets instead of sending it right away.
//
// The caller owns pendingPackets and is responsible for flushing it with
// WriteBatch once the batch has been processed.
//
// packet is the raw inside packet, fwPacket is scratch space filled in by
// newPacket, nb and out are scratch buffers forwarded to the send path, q is
// the index of the reader the packet came from, and localCache is the
// conntrack cache handed to the firewall.
func (f *Interface) consumeInsidePacketBatched(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache, pendingPackets *[]udp.BatchPacket) {
	if err := newPacket(packet, false, fwPacket); err != nil {
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
		}
		return
	}

	dst := fwPacket.RemoteAddr

	// Local broadcast traffic is silently dropped when configured to do so.
	if f.dropLocalBroadcast && f.myBroadcastAddrsTable.Contains(dst) {
		return
	}

	// Traffic addressed to one of our own vpn addresses never leaves the host;
	// it is either written straight back to the reader or dropped.
	if f.myVpnAddrsTable.Contains(dst) {
		if immediatelyForwardToSelf {
			if _, writeErr := f.readers[q].Write(packet); writeErr != nil {
				f.l.WithError(writeErr).Error("Failed to forward to tun")
			}
		}
		return
	}

	// Multicast traffic is silently dropped when configured to do so.
	if f.dropMulticast && dst.IsMulticast() {
		return
	}

	// While a handshake is pending the packet is cached on the handshake host
	// info; the cached copy is later sent through sendMessageNow.
	cacheForHandshake := func(hh *HandshakeHostInfo) {
		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
	}
	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, cacheForHandshake)

	switch {
	case hostinfo == nil:
		f.rejectInside(packet, out, q)
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
				WithField("fwPacket", fwPacket).
				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
		}
		return
	case !ready:
		// Tunnel is not usable yet; nothing more to do for this packet here.
		return
	}

	if dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache); dropReason != nil {
		f.rejectInside(packet, out, q)
		if f.l.Level >= logrus.DebugLevel {
			hostinfo.logger(f.l).
				WithField("fwPacket", fwPacket).
				WithField("reason", dropReason).
				Debugln("dropping outbound packet")
		}
		return
	}

	// Allowed by the firewall: encrypt and queue for the batched UDP write.
	f.sendNoMetricsBatched(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q, pendingPackets)
}

func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
Expand Down Expand Up @@ -69,7 +136,6 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)

} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
Expand Down Expand Up @@ -410,3 +476,75 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}

// sendNoMetricsBatched is the batching counterpart of sendNoMetrics: it
// encrypts p for hostinfo and, instead of writing the result to the socket,
// appends a copy of it to pendingPackets. The caller must flush
// pendingPackets with WriteBatch.
//
// The destination is remote when it is valid, otherwise hostinfo.remote. When
// neither is valid the packet is not queued; it is sent immediately through
// the first relay that can be resolved for hostinfo.
//
// Nothing is done when ci.eKey is nil. nb and out are scratch buffers; q is
// accepted for signature parity but is not used by this function.
func (f *Interface) sendNoMetricsBatched(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, pendingPackets *[]udp.BatchPacket) {
	if ci.eKey == nil {
		return
	}

	// Keep a handle on the whole buffer; the relay path below needs it.
	fullOut := out
	if !remote.IsValid() && !hostinfo.remote.IsValid() {
		// No direct address is known, so leave header.Len bytes free at the
		// front of the buffer and build this packet behind them.
		if len(out) < header.Len {
			out = out[:header.Len]
		}
		out = out[header.Len:]
	}

	// The write lock (when required) spans counter allocation through
	// encryption and is released right after EncryptDanger returns.
	if noiseutil.EncryptLockNeeded {
		ci.writeLock.Lock()
	}
	counter := ci.messageCounter.Add(1)

	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, counter)
	f.connectionManager.Out(hostinfo)

	// A changed rebind counter triggers a lighthouse query for this host,
	// except when the message being sent is a CloseTunnel.
	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
		hostinfo.lastRebindCount = f.rebindCount
		if f.l.Level >= logrus.DebugLevel {
			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
		}
	}

	ciphertext, err := ci.eKey.EncryptDanger(out, out, p, counter, nb)
	if noiseutil.EncryptLockNeeded {
		ci.writeLock.Unlock()
	}
	if err != nil {
		hostinfo.logger(f.l).WithError(err).
			WithField("udpAddr", remote).WithField("counter", counter).
			WithField("attemptedCounter", counter).
			Error("Failed to encrypt outgoing packet")
		return
	}

	var dest netip.AddrPort
	switch {
	case remote.IsValid():
		dest = remote
	case hostinfo.remote.IsValid():
		dest = hostinfo.remote
	default:
		// Relay path: sent immediately rather than queued. Relays that cannot
		// be resolved are removed; the first one that resolves is used.
		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
			relayHostInfo, relay, lookupErr := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
			if lookupErr != nil {
				hostinfo.relayState.DeleteRelay(relayIP)
				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(lookupErr).Info("sendNoMetricsBatched failed to find HostInfo")
				continue
			}
			f.SendVia(relayHostInfo, relay, ciphertext, nb, fullOut[:header.Len+len(ciphertext)], true)
			return
		}
		return
	}

	// The out buffer is reused for the next packet, so the queued entry needs
	// its own copy of the ciphertext.
	payload := append(make([]byte, 0, len(ciphertext)), ciphertext...)
	*pendingPackets = append(*pendingPackets, udp.BatchPacket{Payload: payload, Addr: dest})
}
77 changes: 73 additions & 4 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type InterfaceConfig struct {

ConntrackCacheTimeout time.Duration
l *logrus.Logger

tunBatchSize int // batch size for TUN read/write batching, 0 to disable
}

type Interface struct {
Expand Down Expand Up @@ -86,8 +88,9 @@ type Interface struct {

conntrackCacheTimeout time.Duration

writers []udp.Conn
readers []io.ReadWriteCloser
writers []udp.Conn
readers []io.ReadWriteCloser
tunBatchSize int // batch size for TUN read/write batching

metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
Expand Down Expand Up @@ -187,6 +190,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
relayManager: c.relayManager,
connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
tunBatchSize: c.tunBatchSize,

metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
Expand Down Expand Up @@ -244,6 +248,15 @@ func (f *Interface) activate() {
f.readers[i] = reader
}

// Enable batch reading on all readers if batch size > 1
if f.tunBatchSize > 1 {
for i := 0; i < f.routines; i++ {
if err := overlay.EnableBatchReading(f.readers[i]); err != nil {
f.l.WithError(err).WithField("routine", i).Warn("Failed to enable batch reading, falling back to single reads")
}
}
}

if err := f.inside.Activate(); err != nil {
f.inside.Close()
f.l.Fatal(err)
Expand Down Expand Up @@ -287,13 +300,21 @@ func (f *Interface) listenOut(i int) {
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread()

conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)

// Check if batch reading is available and enabled
batchReader := overlay.AsBatchReader(reader)
if batchReader != nil && f.tunBatchSize > 1 {
f.listenInBatched(reader, batchReader, i, conntrackCache)
return
}

// Fallback to single-packet reading
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)

conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)

for {
n, err := reader.Read(packet)
if err != nil {
Expand All @@ -310,6 +331,54 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
}
}

// listenInBatched is the read loop used when batch reading is available: it
// reads up to f.tunBatchSize packets at a time from batchReader, runs each one
// through consumeInsidePacketBatched, and then hands everything that was
// queued to f.writers[i].WriteBatch in a single call.
//
// The loop returns only when ReadBatch fails with os.ErrClosed after the
// interface has been marked closed; any other read error is logged and the
// process exits with status 2. The reader parameter is not used here; all
// reads go through batchReader. i selects the writer and is passed through as
// the queue index.
func (f *Interface) listenInBatched(reader io.ReadWriteCloser, batchReader overlay.BatchReader, i int, conntrackCache *firewall.ConntrackCacheTicker) {
	capacity := f.tunBatchSize

	// Read-side buffers: one mtu-sized slot per packet plus its length.
	bufs := make([][]byte, capacity)
	for idx := range bufs {
		bufs[idx] = make([]byte, mtu)
	}
	lengths := make([]int, capacity)

	// Scratch space shared by every packet processed on this routine.
	var (
		out      = make([]byte, mtu)
		fwPacket = &firewall.Packet{}
		nb       = make([]byte, 12, 12)
	)

	// Outgoing queue, reused across iterations.
	pending := make([]udp.BatchPacket, 0, capacity)

	for {
		count, err := batchReader.ReadBatch(bufs, lengths)
		if err != nil {
			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
				return
			}

			f.l.WithError(err).Error("Error while reading outbound packets")
			os.Exit(2)
		}

		// NOTE(review): an empty batch is retried immediately; if ReadBatch
		// can return (0, nil) without blocking this loop will spin — confirm.
		if count == 0 {
			continue
		}

		// One conntrack cache lookup is shared by the whole batch.
		cache := conntrackCache.Get(f.l)
		for k := 0; k < count; k++ {
			f.consumeInsidePacketBatched(bufs[k][:lengths[k]], fwPacket, nb, out, i, cache, &pending)
		}

		if len(pending) == 0 {
			continue
		}

		// NOTE(review): whatever WriteBatch returns is discarded here —
		// confirm that dropping a write failure silently is intended.
		f.writers[i].WriteBatch(pending)
		pending = pending[:0]
	}
}

func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
tunBatchSize: c.GetInt("listen.batch", 64),
}

var ifce *Interface
Expand Down
35 changes: 35 additions & 0 deletions overlay/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,38 @@ type Device interface {
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

// BatchReader is an optional capability a device reader may provide so that
// several packets can be fetched with one call, which can cut syscall
// overhead considerably under high load.
type BatchReader interface {
	// ReadBatch fills at most len(packets) buffers. Packet i is written into
	// packets[i] and its length is recorded in sizes[i]. It returns how many
	// packets were read, or an error.
	// A (0, nil) result means no packets were available (non-blocking).
	ReadBatch(packets [][]byte, sizes []int) (int, error)
}

// AsBatchReader reports r as a BatchReader when r implements that interface
// and returns nil when it does not.
func AsBatchReader(r io.ReadWriteCloser) BatchReader {
	// A failed comma-ok assertion leaves batch as the nil interface value.
	batch, _ := r.(BatchReader)
	return batch
}

// BatchEnabler is an optional interface for devices that need explicit
// enabling of batch read support (e.g., setting non-blocking mode).
type BatchEnabler interface {
	// EnableBatchReading switches the device into whatever mode it needs for
	// batch reads and reports any failure to do so.
	EnableBatchReading() error
}

// EnableBatchReading enables batch reading on d when d implements
// BatchEnabler and returns that call's result. For any other value,
// including nil, it is a no-op and returns nil.
func EnableBatchReading(d any) error {
	if be, ok := d.(BatchEnabler); ok {
		return be.EnableBatchReading()
	}
	return nil
}
Loading
Loading