Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ func (s *SQLiteDB) LoadAll() ([]*Link, error) {
if err != nil {
return nil, err
}
defer rows.Close()

for rows.Next() {
link := new(Link)
var created, lastEdit int64
Expand Down Expand Up @@ -178,6 +180,8 @@ func (s *SQLiteDB) LoadStats() (ClickStats, error) {
if err != nil {
return nil, err
}
defer rows.Close()

stats := make(map[string]int)
for rows.Next() {
var id string
Expand All @@ -192,6 +196,36 @@ func (s *SQLiteDB) LoadStats() (ClickStats, error) {
return stats, rows.Err()
}

// LoadStatsSince returns click stats for links since the given time.
func (s *SQLiteDB) LoadStatsSince(since time.Time) (ClickStats, error) {
s.mu.RLock()
defer s.mu.RUnlock()

rows, err := s.db.Query(`
SELECT Links.Short, sum(Stats.Clicks)
FROM Stats
JOIN Links ON Stats.ID = Links.ID
WHERE Stats.Created >= ?
GROUP BY Links.ID, Links.Short
`, since.Unix())
if err != nil {
return nil, err
}
Comment thread
dfcarney marked this conversation as resolved.
defer rows.Close()

stats := make(map[string]int)
for rows.Next() {
var short string
var clicks int
err := rows.Scan(&short, &clicks)
if err != nil {
return nil, err
}
stats[short] = clicks
}
return stats, rows.Err()
}

// SaveStats records click stats for links. The provided map includes
// incremental clicks that have occurred since the last time SaveStats
// was called.
Expand Down Expand Up @@ -236,6 +270,8 @@ func (s *SQLiteDB) GetLinksByOwner(owner string) ([]*Link, error) {
if err != nil {
return nil, err
}
defer rows.Close()

for rows.Next() {
link := new(Link)
var created, lastEdit int64
Expand Down
55 changes: 55 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ package golink
import (
"path"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/tstest"
)

// Test saving, loading, and deleting links for SQLiteDB.
Expand Down Expand Up @@ -126,6 +128,59 @@ func Test_SQLiteDB_SaveLoadDeleteStats(t *testing.T) {
}
}

// Test LoadStatsSince returns only stats since a given time.
func Test_SQLiteDB_LoadStatsSince(t *testing.T) {
clock := tstest.NewClock(tstest.ClockOpts{
Start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
})

db, err := NewSQLiteDB(path.Join(t.TempDir(), "links.db"))
if err != nil {
t.Fatal(err)
}
db.clock = clock

// preload links
for _, link := range []*Link{{Short: "a"}, {Short: "b"}} {
if err := db.Save(link); err != nil {
t.Fatal(err)
}
}

// save stats at initial time (old stats)
if err := db.SaveStats(ClickStats{"a": 5, "b": 3}); err != nil {
t.Fatal(err)
}

// advance 20 days and save more stats (recent stats)
clock.Advance(20 * 24 * time.Hour)
if err := db.SaveStats(ClickStats{"a": 2, "b": 7, "missing": 11}); err != nil {
t.Fatal(err)
}

// LoadStatsSince 10 days ago should only return the recent stats
since := clock.Now().Add(-10 * 24 * time.Hour)
got, err := db.LoadStatsSince(since)
if err != nil {
t.Fatal(err)
}
want := ClickStats{"a": 2, "b": 7}
if !cmp.Equal(got, want) {
t.Errorf("LoadStatsSince got %v, want %v", got, want)
}

// LoadStatsSince 30 days ago should return all stats
since = clock.Now().Add(-30 * 24 * time.Hour)
got, err = db.LoadStatsSince(since)
if err != nil {
t.Fatal(err)
}
want = ClickStats{"a": 7, "b": 10}
if !cmp.Equal(got, want) {
t.Errorf("LoadStatsSince got %v, want %v", got, want)
}
}

// Test GetLinksByOwner functionality
func Test_SQLiteDB_GetLinksByOwner(t *testing.T) {
db, err := NewSQLiteDB(path.Join(t.TempDir(), "links.db"))
Expand Down
52 changes: 45 additions & 7 deletions golink.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"sync"
texttemplate "text/template"
Expand All @@ -45,6 +46,7 @@ import (

const (
defaultHostname = "go"
maxStatsPeriod = 365 * 24 * time.Hour

// Used as a placeholder short name for generating the XSRF defense token,
// when creating new links.
Expand Down Expand Up @@ -344,6 +346,7 @@ type homeData struct {
XSRF string
ReadOnly bool
User string
Period string
}

// deleteData is the data used by deleteTmpl.
Expand Down Expand Up @@ -535,14 +538,32 @@ func serveHandler() http.Handler {
func serveHome(w http.ResponseWriter, r *http.Request, short string) {
var clicks []visitData

stats.mu.Lock()
for short, numClicks := range stats.clicks {
clicks = append(clicks, visitData{
Short: short,
NumClicks: numClicks,
})
period := r.URL.Query().Get("period")
if periodDuration, ok := parseStatsPeriod(period); ok {
since := db.Now().Add(-periodDuration)

clickStats, err := db.LoadStatsSince(since)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
for s, numClicks := range clickStats {
clicks = append(clicks, visitData{
Short: s,
NumClicks: numClicks,
})
}
} else {
period = ""
stats.mu.Lock()
for short, numClicks := range stats.clicks {
clicks = append(clicks, visitData{
Short: short,
NumClicks: numClicks,
})
}
stats.mu.Unlock()
Comment on lines +542 to +565
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: I know it's in the PR description of a 1 minute potential staleness. A comment here describing the delta would help also for future devs.

}
stats.mu.Unlock()

sort.Slice(clicks, func(i, j int) bool {
if clicks[i].NumClicks != clicks[j].NumClicks {
Expand Down Expand Up @@ -580,9 +601,26 @@ func serveHome(w http.ResponseWriter, r *http.Request, short string) {
XSRF: xsrftoken.Generate(xsrfKey, cu.login, newShortName),
ReadOnly: *readonly,
User: cu.login,
Period: period,
})
}

func parseStatsPeriod(period string) (time.Duration, bool) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue(minor, non-blocking): I think there's a subtle UI bug here that may or may not be worth addressing.

If someone hits the url with ?period=7d the UI template will correctly highlight.

But if that URL is hit with some other period that is 7 days ?period=168h, ?period=7d0h0m, the DB query would be correct, but the UI would not highlight anything because the raw Period from line 541 is returned in homeData on line 604, it's not normalized for the template to handle correctly.

if period == "" {
return 0, false
}
if d, err := time.ParseDuration(period); err == nil && d > 0 && d <= maxStatsPeriod {
return d, true
}
if days, ok := strings.CutSuffix(period, "d"); ok {
n, err := strconv.Atoi(days)
if err == nil && n > 0 && n <= int(maxStatsPeriod/(24*time.Hour)) {
return time.Duration(n) * 24 * time.Hour, true
}
}
return 0, false
}

func serveAll(w http.ResponseWriter, _ *http.Request) {
if err := flushStats(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
166 changes: 166 additions & 0 deletions golink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,172 @@ func TestServeSearch(t *testing.T) {
}
}

func TestServeHomePeriodFilter(t *testing.T) {
clock := tstest.NewClock(tstest.ClockOpts{
Start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
})

var err error
db, err = NewSQLiteDB(":memory:")
if err != nil {
t.Fatal(err)
}
db.clock = clock

if err := db.Save(&Link{Short: "old-link", Long: "http://old/"}); err != nil {
t.Fatalf("saving old-link: %v", err)
}
if err := db.Save(&Link{Short: "new-link", Long: "http://new/"}); err != nil {
t.Fatalf("saving new-link: %v", err)
}

// save old stats (40 days ago)
if err := db.SaveStats(ClickStats{"old-link": 10}); err != nil {
t.Fatal(err)
}

// advance 40 days and save recent stats
clock.Advance(40 * 24 * time.Hour)
if err := db.SaveStats(ClickStats{"new-link": 5}); err != nil {
t.Fatal(err)
}

if err := initStats(); err != nil {
t.Fatal(err)
}

tests := []struct {
name string
period string
wantStatus int
wantContains []string
wantNotContains []string
}{
{
name: "all time shows all links",
period: "",
wantStatus: http.StatusOK,
wantContains: []string{"old-link", "new-link"},
},
{
name: "7d shows only recent link",
period: "7d",
wantStatus: http.StatusOK,
wantContains: []string{"new-link"},
wantNotContains: []string{"old-link"},
},
{
name: "duration shows only recent link",
period: "168h",
wantStatus: http.StatusOK,
wantContains: []string{"new-link"},
wantNotContains: []string{"old-link"},
},
{
name: "30d shows only recent link",
period: "30d",
wantStatus: http.StatusOK,
wantContains: []string{"new-link"},
wantNotContains: []string{"old-link"},
},
{
name: "invalid period falls back to all time",
period: "nope",
wantStatus: http.StatusOK,
wantContains: []string{"old-link", "new-link"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := "/"
if tt.period != "" {
url = "/?period=" + tt.period
}
r := httptest.NewRequest("GET", url, nil)
w := httptest.NewRecorder()
serveHandler().ServeHTTP(w, r)

if w.Code != tt.wantStatus {
t.Errorf("serveHome(?period=%s) = %d; want %d", tt.period, w.Code, tt.wantStatus)
}

body := w.Body.String()
for _, s := range tt.wantContains {
if !strings.Contains(body, s) {
t.Errorf("serveHome(?period=%s) body missing %q", tt.period, s)
}
}
for _, s := range tt.wantNotContains {
if strings.Contains(body, s) {
t.Errorf("serveHome(?period=%s) body unexpectedly contains %q", tt.period, s)
}
}
})
}
}

func TestParseStatsPeriod(t *testing.T) {
tests := []struct {
name string
period string
want time.Duration
wantOK bool
}{
{
name: "empty",
period: "",
},
{
name: "days",
period: "7d",
want: 7 * 24 * time.Hour,
wantOK: true,
},
{
name: "duration",
period: "168h",
want: 168 * time.Hour,
wantOK: true,
},
{
name: "max duration",
period: "8760h",
want: maxStatsPeriod,
wantOK: true,
},
{
name: "too long duration",
period: "8761h",
},
{
name: "too many days",
period: "366d",
},
{
name: "negative",
period: "-7d",
},
{
name: "overflowing days",
period: "100000000000000000000d",
},
{
name: "invalid",
period: "nope",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := parseStatsPeriod(tt.period)
if got != tt.want || ok != tt.wantOK {
t.Errorf("parseStatsPeriod(%q) = %v, %v; want %v, %v", tt.period, got, ok, tt.want, tt.wantOK)
}
})
}
}

func TestParseAdvertiseTags(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading