Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions pkg/tcpip/network/ipv4/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,21 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
localAddressTemporary := pkt.NetworkPacketInfo.LocalAddressTemporary
localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast

// It's possible that a raw socket or custom defaultHandler expects to
// receive this packet.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
// It's possible that a raw socket or per-stack default handler expects
// to receive this packet.
defaultHandlerHandled := false
if dispatcher, ok := e.dispatcher.(stack.TransportDispatcherWithDefaultHandlerResult); ok {
_, defaultHandlerHandled = dispatcher.DeliverTransportPacketWithDefaultHandlerResult(header.ICMPv4ProtocolNumber, pkt)
} else {
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
}
pkt = nil

// Skip direct ICMP echo reply if the packet was received with a temporary
// address, allowing custom handlers to take over.
if localAddressTemporary {
// Skip the built-in ICMP echo reply if the request was consumed by a
// per-stack default handler. Also preserve the IPv4 behavior for
// temporary local addresses: the packet is delivered above, but the
// stack does not synthesize an echo reply for it.
if defaultHandlerHandled || localAddressTemporary {
return
}

Expand Down
313 changes: 313 additions & 0 deletions pkg/tcpip/network/ipv4/ipv4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3886,6 +3886,319 @@ func TestCloseLocking(t *testing.T) {
}()
}

func TestICMPEchoDefaultHandlerControlsReply(t *testing.T) {
var (
localAddr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
remoteAddr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 24,
},
}
)

tests := []struct {
name string
installHandler bool
handled bool
wantHandlerCalled bool
wantReply bool
}{
{
name: "no default handler",
wantReply: true,
},
{
name: "default handler handled",
installHandler: true,
handled: true,
wantHandlerCalled: true,
wantReply: false,
},
{
name: "default handler not handled",
installHandler: true,
handled: false,
wantHandlerCalled: true,
wantReply: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
Clock: clock,
})
defer func() {
s.Close()
s.Wait()
refs.DoRepeatedLeakCheck()
}()

const ident = 1234
handlerCalled := false
if test.installHandler {
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
handlerCalled = true
if got := id.LocalPort; got != ident {
t.Errorf("got id.LocalPort = %d, want = %d", got, ident)
}
return test.handled
})
}

e := channel.New(1, defaultMTU, "")
defer e.Close()
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
if err := s.AddProtocolAddress(nicID, localAddr, stack.AddressProperties{}); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, localAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: localAddr.AddressWithPrefix.Subnet(),
NIC: nicID,
}})

totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := prependable.New(totalLength)
icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
icmpH.SetIdent(ident)
icmpH.SetType(header.ICMPv4Echo)
icmpH.SetCode(header.ICMPv4UnusedCode)
icmpH.SetChecksum(0)
icmpH.SetChecksum(^checksum.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
Protocol: uint8(icmp.ProtocolNumber4),
TTL: ipv4.DefaultTTL,
SrcAddr: remoteAddr.AddressWithPrefix.Address,
DstAddr: localAddr.AddressWithPrefix.Address,
})
ip.SetChecksum(^ip.CalculateChecksum())
echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(hdr.View()),
})
e.InjectInbound(header.IPv4ProtocolNumber, echoPkt)
echoPkt.DecRef()
clock.RunImmediatelyScheduledJobs()

if got, want := handlerCalled, test.wantHandlerCalled; got != want {
t.Fatalf("got handlerCalled = %t, want = %t", got, want)
}

p := e.Read()
if !test.wantReply {
if p != nil {
p.DecRef()
t.Fatalf("got unexpected ICMP echo reply")
}
return
}
if p == nil {
t.Fatalf("expected ICMP echo reply")
}
defer p.DecRef()
payload := stack.PayloadSince(p.NetworkHeader())
defer payload.Release()
checker.IPv4(t, payload,
checker.SrcAddr(localAddr.AddressWithPrefix.Address),
checker.DstAddr(remoteAddr.AddressWithPrefix.Address),
checker.ICMPv4(
checker.ICMPv4Type(header.ICMPv4EchoReply),
checker.ICMPv4Code(header.ICMPv4UnusedCode)))
})
}
}

func TestICMPEchoRegisteredEndpointDoesNotSuppressReply(t *testing.T) {
localAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
remoteAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 24,
},
}

clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
Clock: clock,
})
defer func() {
s.Close()
s.Wait()
refs.DoRepeatedLeakCheck()
}()

e := channel.New(1, defaultMTU, "")
defer e.Close()
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
if err := s.AddProtocolAddress(nicID, localAddr, stack.AddressProperties{}); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, localAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: localAddr.AddressWithPrefix.Subnet(),
NIC: nicID,
}})

const ident = 1234
var wq waiter.Queue
ep, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Addr: localAddr.AddressWithPrefix.Address, Port: ident}); err != nil {
t.Fatalf("ep.Bind(...) = %s", err)
}

handlerCalled := false
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, func(stack.TransportEndpointID, *stack.PacketBuffer) bool {
handlerCalled = true
return true
})

totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := prependable.New(totalLength)
icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
icmpH.SetIdent(ident)
icmpH.SetType(header.ICMPv4Echo)
icmpH.SetCode(header.ICMPv4UnusedCode)
icmpH.SetChecksum(0)
icmpH.SetChecksum(^checksum.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
Protocol: uint8(icmp.ProtocolNumber4),
TTL: ipv4.DefaultTTL,
SrcAddr: remoteAddr.AddressWithPrefix.Address,
DstAddr: localAddr.AddressWithPrefix.Address,
})
ip.SetChecksum(^ip.CalculateChecksum())
echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(hdr.View()),
})
e.InjectInbound(header.IPv4ProtocolNumber, echoPkt)
echoPkt.DecRef()
clock.RunImmediatelyScheduledJobs()

if handlerCalled {
t.Fatalf("default handler was unexpectedly called")
}

p := e.Read()
if p == nil {
t.Fatalf("expected ICMP echo reply")
}
defer p.DecRef()
payload := stack.PayloadSince(p.NetworkHeader())
defer payload.Release()
checker.IPv4(t, payload,
checker.SrcAddr(localAddr.AddressWithPrefix.Address),
checker.DstAddr(remoteAddr.AddressWithPrefix.Address),
checker.ICMPv4(
checker.ICMPv4Type(header.ICMPv4EchoReply),
checker.ICMPv4Code(header.ICMPv4UnusedCode)))
}

func TestICMPEchoTemporaryAddressSuppressesReply(t *testing.T) {
assignedAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
temporaryAddr := tcpip.AddrFromSlice(net.ParseIP("192.168.0.99").To4())
remoteAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 24,
},
}

clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
Clock: clock,
})
defer func() {
s.Close()
s.Wait()
refs.DoRepeatedLeakCheck()
}()

e := channel.New(1, defaultMTU, "")
defer e.Close()
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
if err := s.AddProtocolAddress(nicID, assignedAddr, stack.AddressProperties{}); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, assignedAddr, err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: assignedAddr.AddressWithPrefix.Subnet(),
NIC: nicID,
}})

const ident = 1234
totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := prependable.New(totalLength)
icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
icmpH.SetIdent(ident)
icmpH.SetType(header.ICMPv4Echo)
icmpH.SetCode(header.ICMPv4UnusedCode)
icmpH.SetChecksum(0)
icmpH.SetChecksum(^checksum.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
Protocol: uint8(icmp.ProtocolNumber4),
TTL: ipv4.DefaultTTL,
SrcAddr: remoteAddr.AddressWithPrefix.Address,
DstAddr: temporaryAddr,
})
ip.SetChecksum(^ip.CalculateChecksum())
echoPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(hdr.View()),
})
e.InjectInbound(header.IPv4ProtocolNumber, echoPkt)
echoPkt.DecRef()
clock.RunImmediatelyScheduledJobs()

if p := e.Read(); p != nil {
p.DecRef()
t.Fatalf("got unexpected ICMP echo reply")
}
}

func TestIcmpRateLimit(t *testing.T) {
var (
host1IPv4Addr = tcpip.ProtocolAddress{
Expand Down
Loading