Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions internal/datastore/proxy/checkingreplicated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,84 @@ import (
"github.com/authzed/spicedb/pkg/tuple"
)

func TestCheckingReplicatedWithNoReplicasReturnsPrimary(t *testing.T) {
Copy link
Copy Markdown
Contributor

@miparnisari miparnisari Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a godoc in the prod code that says

// NewCheckingReplicatedDatastore (...)
// NOTE: Be *very* careful when using this function. It is not safe to use this function without
// knowledge of the layout of the underlying datastore and its replicas.
// the replicas *must* point to a *stable* instance of the datastore (not a load balancer).

is it possible to encode this knowledge in a test?

Similarly, for NewStrictReplicatedDatastore:

// NewStrictReplicatedDatastore (...)
// This is useful when the read pool points to a load balancer that can transparently handle the request.

can we encode that in a test?

(see #2525 (comment) for more context)

primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil}

ds, err := NewCheckingReplicatedDatastore(primary)
require.NoError(t, err)
require.Equal(t, primary, ds)
}

func TestCheckingReplicatedRoundRobinsAcrossReplicas(t *testing.T) {
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil}
replicaA := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil}
replicaB := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil}

replicated, err := NewCheckingReplicatedDatastore(primary, replicaA, replicaB)
require.NoError(t, err)

crd, ok := replicated.(*checkingReplicatedDatastore)
require.True(t, ok)

// Two replicas produce two distinct cached wrappers, selected alternately.
first := selectReplica(crd.replicas, &crd.lastReplica)
second := selectReplica(crd.replicas, &crd.lastReplica)
third := selectReplica(crd.replicas, &crd.lastReplica)
fourth := selectReplica(crd.replicas, &crd.lastReplica)

require.NotSame(t, first, second)
require.Same(t, first, third)
require.Same(t, second, fourth)
}

func TestCheckingReplicatedReaderWrapsAllReadMethods(t *testing.T) {
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil}

replicated, err := NewCheckingReplicatedDatastore(primary, replica)
require.NoError(t, err)

reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("1"))

// fakeSnapshotReader returns "not implemented" for caveats and counters;
// the wrapper just needs to invoke them via the chosen reader.
_, _, err = reader.LegacyReadCaveatByName(t.Context(), "is_weekend")
require.ErrorContains(t, err, "not implemented")

_, err = reader.LegacyListAllCaveats(t.Context())
require.ErrorContains(t, err, "not implemented")

_, err = reader.LegacyLookupCaveatsWithNames(t.Context(), []string{"is_weekend"})
require.ErrorContains(t, err, "not implemented")

_, err = reader.CountRelationships(t.Context(), "filter")
require.ErrorContains(t, err, "not implemented")

_, err = reader.LookupCounters(t.Context())
require.ErrorContains(t, err, "not implemented")

// QueryRelationships should succeed (replica has revision 1).
iter, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{
OptionalResourceType: "resource",
})
require.NoError(t, err)
rels, err := datastore.IteratorToSlice(iter)
require.NoError(t, err)
require.Len(t, rels, 2)

// ReverseQueryRelationships should also succeed.
riter, err := reader.ReverseQueryRelationships(t.Context(), datastore.SubjectsFilter{
SubjectType: "user",
})
require.NoError(t, err)
rrels, err := datastore.IteratorToSlice(riter)
require.NoError(t, err)
require.Len(t, rrels, 2)

// The replica was selected, not the primary.
require.False(t, reader.(*checkingStableReader).chosePrimaryForTest)
}

func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *testing.T) {
primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil}
replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1"), nil}
Expand Down
295 changes: 295 additions & 0 deletions internal/datastore/proxy/relationshipintegrity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,301 @@ func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) {
}
}

func TestNewRelationshipIntegrityProxyErrors(t *testing.T) {
validKey := DefaultKeyForTesting

expTime, err := time.Parse("2006-01-02", "2021-01-01")
require.NoError(t, err)

cases := []struct {
name string
current KeyConfig
expired []KeyConfig
errContains string
}{
{
name: "empty current key bytes",
current: KeyConfig{ID: "k1", Bytes: nil},
errContains: "current key file cannot be empty",
},
{
name: "empty current key ID",
current: KeyConfig{ID: "", Bytes: validKey.Bytes},
errContains: "current key ID cannot be empty",
},
{
name: "expired key empty bytes",
current: validKey,
expired: []KeyConfig{
{ID: "exp", Bytes: nil, ExpiredAt: &expTime},
},
errContains: "expired key cannot be empty",
},
{
name: "expired key empty ID",
current: validKey,
expired: []KeyConfig{
{ID: "", Bytes: validKey.Bytes, ExpiredAt: &expTime},
},
errContains: "expired key ID cannot be empty",
},
{
name: "expired key missing expiration",
current: validKey,
expired: []KeyConfig{
{ID: "exp", Bytes: validKey.Bytes, ExpiredAt: nil},
},
errContains: "expired key missing expiration time",
},
{
name: "duplicate key ID",
current: validKey,
expired: []KeyConfig{
{ID: validKey.ID, Bytes: validKey.Bytes, ExpiredAt: &expTime},
},
errContains: "found duplicate key ID",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

_, err = NewRelationshipIntegrityProxy(ds, tc.current, tc.expired)
require.Error(t, err)
require.ErrorContains(t, err, tc.errContains)
})
}

// "Current key has an expiration" panics via MustBugf rather than
// returning an error, so test it separately.
t.Run("current key with expiration panics", func(t *testing.T) {
ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

require.Panics(t, func() {
_, _ = NewRelationshipIntegrityProxy(ds, KeyConfig{
ID: "k1",
Bytes: validKey.Bytes,
ExpiredAt: &expTime,
}, nil)
})
})
}

func TestRelationshipIntegrityProxyPassThroughs(t *testing.T) {
ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil)
require.NoError(t, err)

ctx := t.Context()

metricsID, err := pds.MetricsID()
require.NoError(t, err)
require.NotEmpty(t, metricsID)

uniqueID, err := pds.UniqueID(ctx)
require.NoError(t, err)
require.NotEmpty(t, uniqueID)

features, err := pds.Features(ctx)
require.NoError(t, err)
require.NotNil(t, features)

offline, err := pds.OfflineFeatures()
require.NoError(t, err)
require.NotNil(t, offline)

headRev, err := pds.HeadRevision(ctx)
require.NoError(t, err)
require.NotNil(t, headRev)

require.NoError(t, pds.CheckRevision(ctx, headRev))

optRev, err := pds.OptimizedRevision(ctx)
require.NoError(t, err)
require.NotNil(t, optRev)

readyState, err := pds.ReadyState(ctx)
require.NoError(t, err)
require.True(t, readyState.IsReady)

roundTripped, err := pds.RevisionFromString(headRev.String())
require.NoError(t, err)
require.True(t, roundTripped.Equal(headRev))

_, err = pds.Statistics(ctx)
require.NoError(t, err)

unwrapper, ok := pds.(datastore.UnwrappableDatastore)
require.True(t, ok)
require.Equal(t, ds, unwrapper.Unwrap())

require.NoError(t, pds.Close())
}

func TestRelationshipIntegrityReaderPassThroughs(t *testing.T) {
ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil)
require.NoError(t, err)

headRev, err := pds.HeadRevision(t.Context())
require.NoError(t, err)

reader := pds.SnapshotReader(headRev)

// Each of these simply delegates to the wrapped reader; memdb returns
// empty/zero results at head on a fresh database.
_, err = reader.CountRelationships(t.Context(), "nonexistent")
require.Error(t, err) // memdb errors on unknown counter

caveats, err := reader.LegacyListAllCaveats(t.Context())
require.NoError(t, err)
require.Empty(t, caveats)

namespaces, err := reader.LegacyListAllNamespaces(t.Context())
require.NoError(t, err)
require.Empty(t, namespaces)

lookupC, err := reader.LegacyLookupCaveatsWithNames(t.Context(), []string{"missing"})
require.NoError(t, err)
require.Empty(t, lookupC)

counters, err := reader.LookupCounters(t.Context())
require.NoError(t, err)
require.Empty(t, counters)

lookupN, err := reader.LegacyLookupNamespacesWithNames(t.Context(), []string{"missing"})
require.NoError(t, err)
require.Empty(t, lookupN)

_, _, err = reader.LegacyReadCaveatByName(t.Context(), "missing")
require.Error(t, err)

_, _, err = reader.LegacyReadNamespaceByName(t.Context(), "missing")
require.Error(t, err)
}

func TestRelationshipIntegrityReverseQueryValidatesHash(t *testing.T) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we already have a test for this scenario: TestBasicIntegrityFailureDueToInvalidHashSignature. please remove this one

ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil)
require.NoError(t, err)

// Write a valid relationship through the proxy.
_, err = pds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error {
return tx.WriteRelationships(t.Context(), []tuple.RelationshipUpdate{
tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")),
})
})
require.NoError(t, err)

// Bypass the proxy to insert a relationship with an invalid hash.
_, err = ds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error {
invalid := tuple.MustParse("resource:foo#viewer@user:fred")
invalid.OptionalIntegrity = &core.RelationshipIntegrity{
KeyId: "defaultfortest",
Hash: append([]byte{0x01}, []byte("someinvalidhashaasd")[0:hashLength]...),
HashedAt: timestamppb.Now(),
}
return tx.WriteRelationships(t.Context(), []tuple.RelationshipUpdate{
tuple.Create(invalid),
})
})
require.NoError(t, err)

headRev, err := pds.HeadRevision(t.Context())
require.NoError(t, err)

reader := pds.SnapshotReader(headRev)
iter, err := reader.ReverseQueryRelationships(t.Context(), datastore.SubjectsFilter{
SubjectType: "user",
})
require.NoError(t, err)

_, err = datastore.IteratorToSlice(iter)
require.Error(t, err)
require.ErrorContains(t, err, "has invalid integrity hash")
}

// stubBulkSource is a trivial BulkWriteRelationshipSource used to verify
// that the integrity proxy decorates the iterator correctly.
type stubBulkSource struct {
rels []tuple.Relationship
idx int
}

func (s *stubBulkSource) Next(_ context.Context) (*tuple.Relationship, error) {
if s.idx >= len(s.rels) {
return nil, nil
}
rel := s.rels[s.idx]
s.idx++
return &rel, nil
}

func TestRelationshipIntegrityBulkLoad(t *testing.T) {
ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil)
require.NoError(t, err)

src := &stubBulkSource{rels: []tuple.Relationship{
tuple.MustParse("resource:foo#viewer@user:tom"),
tuple.MustParse("resource:foo#viewer@user:fred"),
}}

_, err = pds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error {
loaded, err := tx.BulkLoad(ctx, src)
require.NoError(t, err)
require.Equal(t, uint64(2), loaded)
return nil
})
require.NoError(t, err)

// Integrity metadata should have been added and verifiable on readback.
headRev, err := pds.HeadRevision(t.Context())
require.NoError(t, err)

iter, err := pds.SnapshotReader(headRev).QueryRelationships(
t.Context(),
datastore.RelationshipsFilter{OptionalResourceType: "resource"},
)
require.NoError(t, err)
rels, err := datastore.IteratorToSlice(iter)
require.NoError(t, err)
require.Len(t, rels, 2)
}

func TestRelationshipIntegrityBulkLoadRejectsPrehashed(t *testing.T) {
ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour)
require.NoError(t, err)

pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil)
require.NoError(t, err)

prehashed := tuple.MustParse("resource:foo#viewer@user:tom")
prehashed.OptionalIntegrity = &core.RelationshipIntegrity{KeyId: "other"}

src := &stubBulkSource{rels: []tuple.Relationship{prehashed}}

// spiceerrors.MustBugf panics; BulkLoad propagates that through the iterator's
// Next call.
require.Panics(t, func() {
_, _ = pds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error {
_, _ = tx.BulkLoad(ctx, src)
return nil
})
})
}

func BenchmarkQueryRelsWithIntegrity(b *testing.B) {
for _, withIntegrity := range []bool{true, false} {
b.Run(fmt.Sprintf("withIntegrity=%t", withIntegrity), func(b *testing.B) {
Expand Down
Loading
Loading