Skip to content

Commit 4a0507c

Browse files
committed
Correctness and robustness improvements across multiple subsystems
- cgroup: initialize map eagerly, fix swallowed error in cGroupWalk, add blank line before lock call for readability, declare err in var block - main: defer signal.Stop to clean up signal channel on exit - map: guard duration with 1ns floor to prevent division-by-zero in bitrate calculation; simplify checkBatchMapSupport return expression; add minDuration constant with explanatory comment - output: fix bitrate unit thresholds to use unit boundaries instead of 10x multiples; handle json.Encoder.Encode error instead of discarding - probe: warn when no KProbes were attached so output-is-empty is diagnosable - tui: replace func-pointer sort selector with atomic.Int32 index into a fixed sortFuncs array for safe concurrent access; introduce done channel so updateStatsTable goroutine exits cleanly when the app stops
1 parent c565374 commit 4a0507c

6 files changed

Lines changed: 70 additions & 28 deletions

File tree

cgroup.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ const (
4141
)
4242

4343
var (
44-
cGroupCache map[uint64]string
44+
cGroupCache = make(map[uint64]string)
4545
cGroupCacheLock sync.RWMutex
4646
cGroupInitOnce sync.Once
4747
)
@@ -131,6 +131,8 @@ func cGroupWalk(dir string, mapping map[uint64]string) error {
131131
if errors.Is(err, fs.ErrNotExist) {
132132
return nil
133133
}
134+
135+
return err
134136
}
135137

136138
if !d.IsDir() {
@@ -181,9 +183,11 @@ func cGroupWatcher(objs cgroupObjects) (*perf.Reader, error) {
181183
}
182184

183185
go func() {
184-
var event cgroupCgroupevent
185-
186-
var r perf.Record
186+
var (
187+
event cgroupCgroupevent
188+
r perf.Record
189+
err error
190+
)
187191

188192
for {
189193
r, err = rd.Read()
@@ -201,6 +205,7 @@ func cGroupWatcher(objs cgroupObjects) (*perf.Reader, error) {
201205
}
202206

203207
path := bsliceToString(event.Path[:])
208+
204209
cGroupCacheLock.Lock()
205210
cGroupCache[event.Cgroupid] = strings.TrimPrefix(path, CGroupRootPath)
206211
cGroupCacheLock.Unlock()

main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ func main() {
144144
drawTUI(objsCounter, startTime)
145145
} else {
146146
signalCh := make(chan os.Signal, 1)
147+
147148
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
149+
defer signal.Stop(signalCh)
148150

149151
go func() {
150152
s := <-signalCh

map.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ import (
2929
"github.com/cilium/ebpf"
3030
)
3131

32-
const batchSize = 4096
32+
const (
33+
batchSize = 4096
34+
minDuration = 1e-9 // 1 nanosecond floor to avoid division by zero in bitrate
35+
)
3336

3437
type batchBuffers struct {
3538
keys []counterStatkey
@@ -67,11 +70,7 @@ func checkBatchMapSupport(m *ebpf.Map) bool {
6770
// BPF_MAP_LOOKUP_BATCH support requires v5.6 kernel
6871
_, err := m.BatchLookup(&cursor, keys, values, nil)
6972

70-
if err != nil && errors.Is(err, ebpf.ErrNotSupported) {
71-
return false
72-
}
73-
74-
return true
73+
return !errors.Is(err, ebpf.ErrNotSupported)
7574
}
7675

7776
// listMap lists all the entries in the given ebpf.Map, converting the counter
@@ -112,6 +111,10 @@ func listMapBatch(m *ebpf.Map, start time.Time) ([]statEntry, error) {
112111
values := buf.values
113112

114113
dur := time.Since(start).Seconds()
114+
if dur < minDuration {
115+
dur = minDuration
116+
}
117+
115118
stats := make([]statEntry, 0, m.MaxEntries())
116119

117120
var cursor ebpf.MapBatchCursor
@@ -156,6 +159,10 @@ func listMapIterate(m *ebpf.Map, start time.Time) ([]statEntry, error) {
156159
)
157160

158161
dur := time.Since(start).Seconds()
162+
if dur < minDuration {
163+
dur = minDuration
164+
}
165+
159166
stats := make([]statEntry, 0, m.MaxEntries())
160167

161168
iter := m.Iterate()

output.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"bytes"
2626
"cmp"
2727
"encoding/json"
28+
"fmt"
2829
"os"
2930
"slices"
3031
"strconv"
@@ -124,14 +125,12 @@ func formatBitrate(b float64) string {
124125
switch {
125126
case b < Kbps:
126127
return strconv.FormatFloat(b, 'f', 2, 64) + " bps"
127-
case b < 10*Kbps:
128+
case b < Mbps:
128129
return strconv.FormatFloat(b/Kbps, 'f', 2, 64) + " Kbps"
129-
case b < 10*Mbps:
130+
case b < Gbps:
130131
return strconv.FormatFloat(b/Mbps, 'f', 2, 64) + " Mbps"
131-
case b < 10*Gbps:
132+
case b < Tbps:
132133
return strconv.FormatFloat(b/Gbps, 'f', 2, 64) + " Gbps"
133-
case b < 10*Tbps:
134-
return strconv.FormatFloat(b/Tbps, 'f', 2, 64) + " Tbps"
135134
}
136135

137136
return strconv.FormatFloat(b/Tbps, 'f', 2, 64) + " Tbps"
@@ -230,7 +229,10 @@ func outputPlain(m []statEntry) string {
230229
func outputJSON(m []statEntry) {
231230
enc := json.NewEncoder(os.Stdout)
232231
enc.SetEscapeHTML(false)
233-
_ = enc.Encode(m)
232+
233+
if err := enc.Encode(m); err != nil {
234+
_, _ = fmt.Fprintf(os.Stderr, "Error encoding JSON output: %v\n", err)
235+
}
234236
}
235237

236238
// bsliceToString converts a slice of int8 values to a string by first

probe.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ func startKProbes(hooks []kprobeHook, links []link.Link) []link.Link {
6767

6868
links = slices.Grow(links, len(hooks))
6969

70+
initialLen := len(links)
71+
7072
for _, kp := range hooks {
7173
l, err = link.Kprobe(kp.kprobe, kp.prog, nil)
7274
if err != nil {
@@ -78,6 +80,10 @@ func startKProbes(hooks []kprobeHook, links []link.Link) []link.Link {
7880
links = append(links, l)
7981
}
8082

83+
if len(links) == initialLen {
84+
log.Printf("Warning: no KProbes were successfully attached; output will be empty")
85+
}
86+
8187
log.Printf("Using KProbes mode w/ PID/comm tracking")
8288

8389
return links

tui.go

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ package main
2525
import (
2626
"fmt"
2727
"strconv"
28+
"sync/atomic"
2829
"time"
2930

3031
"github.com/gdamore/tcell/v2"
@@ -54,9 +55,19 @@ const (
5455
// The TUI is interactive: pressing 'q' or 'x' will exit the program,
5556
// pressing 'r' or 'l' will redraw the TUI, and pressing any other key will
5657
// do nothing.
58+
// sortFuncs maps the atomic sort index to the corresponding sort function.
59+
var sortFuncs = [...]func([]statEntry){
60+
bitrateSort, // 0
61+
packetSort, // 1
62+
bytesSort, // 2
63+
srcIPSort, // 3
64+
dstIPSort, // 4
65+
}
66+
5767
func drawTUI(objs counterObjects, startTime time.Time) {
5868
app := tview.NewApplication()
59-
tableSort := bitrateSort
69+
70+
var tableSortIdx atomic.Int32 // 0 = bitrateSort (default)
6071

6172
// packet statistics
6273
statsTable := tview.NewTable().
@@ -72,23 +83,23 @@ func drawTUI(objs counterObjects, startTime time.Time) {
7283
statsTable.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
7384
switch event.Rune() {
7485
case '0':
75-
tableSort = bitrateSort
86+
tableSortIdx.Store(0)
7687

7788
statsTable.Select(0, 0)
7889
case '1':
79-
tableSort = packetSort
90+
tableSortIdx.Store(1)
8091

8192
statsTable.Select(0, 0)
8293
case '2':
83-
tableSort = bytesSort
94+
tableSortIdx.Store(2)
8495

8596
statsTable.Select(0, 0)
8697
case '3':
87-
tableSort = srcIPSort
98+
tableSortIdx.Store(3)
8899

89100
statsTable.Select(0, 0)
90101
case '4':
91-
tableSort = dstIPSort
102+
tableSortIdx.Store(4)
92103

93104
statsTable.Select(0, 0)
94105
case 'q', 'x', 'Q', 'X':
@@ -131,12 +142,17 @@ func drawTUI(objs counterObjects, startTime time.Time) {
131142
AddItem(statsTable, 1, 0, 1, 1, 0, 0, true).
132143
AddItem(naviView, 2, 0, 1, 1, 0, 0, false)
133144

134-
// start the update loop
135-
go updateStatsTable(app, statsTable, &tableSort, objs, startTime)
145+
// start the update loop; done is closed when app.Run() returns so the
146+
// goroutine can exit instead of blocking on a stopped application.
147+
done := make(chan struct{})
148+
149+
go updateStatsTable(app, statsTable, &tableSortIdx, objs, startTime, done)
136150

137151
_ = app.SetRoot(grid, true).
138152
SetFocus(statsTable).
139153
Run()
154+
155+
close(done)
140156
}
141157

142158
// updateStatsTable starts an infinite loop that updates the given table with
@@ -159,8 +175,8 @@ func drawTUI(objs counterObjects, startTime time.Time) {
159175
//
160176
// Note that the table is cleared and recreated on each iteration, so any cell
161177
// attributes are lost on each iteration.
162-
func updateStatsTable(app *tview.Application, table *tview.Table, tableSort *func(stats []statEntry),
163-
objs counterObjects, startTime time.Time,
178+
func updateStatsTable(app *tview.Application, table *tview.Table, tableSortIdx *atomic.Int32,
179+
objs counterObjects, startTime time.Time, done <-chan struct{},
164180
) {
165181
ticker := time.NewTicker(*refresh)
166182
defer ticker.Stop()
@@ -186,7 +202,7 @@ func updateStatsTable(app *tview.Application, table *tview.Table, tableSort *fun
186202

187203
for {
188204
// read eBPF map outside the draw closure so the UI goroutine is not blocked on the syscall
189-
snapshot, _ := processMap(objs.PktCount, startTime, *tableSort)
205+
snapshot, _ := processMap(objs.PktCount, startTime, sortFuncs[tableSortIdx.Load()])
190206

191207
app.QueueUpdateDraw(func() {
192208
table.Clear()
@@ -273,6 +289,10 @@ func updateStatsTable(app *tview.Application, table *tview.Table, tableSort *fun
273289
}
274290
})
275291

276-
<-ticker.C
292+
select {
293+
case <-ticker.C:
294+
case <-done:
295+
return
296+
}
277297
}
278298
}

0 commit comments

Comments
 (0)