diff --git a/internal/datastore/proxy/checkingreplicated_test.go b/internal/datastore/proxy/checkingreplicated_test.go index 3cd8259a5f..a6f5de5b41 100644 --- a/internal/datastore/proxy/checkingreplicated_test.go +++ b/internal/datastore/proxy/checkingreplicated_test.go @@ -16,6 +16,84 @@ import ( "github.com/authzed/spicedb/pkg/tuple" ) +func TestCheckingReplicatedWithNoReplicasReturnsPrimary(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} + + ds, err := NewCheckingReplicatedDatastore(primary) + require.NoError(t, err) + require.Equal(t, primary, ds) +} + +func TestCheckingReplicatedRoundRobinsAcrossReplicas(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} + replicaA := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil} + replicaB := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil} + + replicated, err := NewCheckingReplicatedDatastore(primary, replicaA, replicaB) + require.NoError(t, err) + + crd, ok := replicated.(*checkingReplicatedDatastore) + require.True(t, ok) + + // Two replicas produce two distinct cached wrappers, selected alternately. + first := selectReplica(crd.replicas, &crd.lastReplica) + second := selectReplica(crd.replicas, &crd.lastReplica) + third := selectReplica(crd.replicas, &crd.lastReplica) + fourth := selectReplica(crd.replicas, &crd.lastReplica) + + require.NotSame(t, first, second) + require.Same(t, first, third) + require.Same(t, second, fourth) +} + +func TestCheckingReplicatedReaderWrapsAllReadMethods(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil} + + replicated, err := NewCheckingReplicatedDatastore(primary, replica) + require.NoError(t, err) + + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + + // fakeSnapshotReader returns "not implemented" for caveats and counters; + // the wrapper just needs to invoke them via the chosen reader. + _, _, err = reader.LegacyReadCaveatByName(t.Context(), "is_weekend") + require.ErrorContains(t, err, "not implemented") + + _, err = reader.LegacyListAllCaveats(t.Context()) + require.ErrorContains(t, err, "not implemented") + + _, err = reader.LegacyLookupCaveatsWithNames(t.Context(), []string{"is_weekend"}) + require.ErrorContains(t, err, "not implemented") + + _, err = reader.CountRelationships(t.Context(), "filter") + require.ErrorContains(t, err, "not implemented") + + _, err = reader.LookupCounters(t.Context()) + require.ErrorContains(t, err, "not implemented") + + // QueryRelationships should succeed (replica has revision 1). + iter, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.NoError(t, err) + rels, err := datastore.IteratorToSlice(iter) + require.NoError(t, err) + require.Len(t, rels, 2) + + // ReverseQueryRelationships should also succeed. + riter, err := reader.ReverseQueryRelationships(t.Context(), datastore.SubjectsFilter{ + SubjectType: "user", + }) + require.NoError(t, err) + rrels, err := datastore.IteratorToSlice(riter) + require.NoError(t, err) + require.Len(t, rrels, 2) + + // The replica was selected, not the primary. + require.False(t, reader.(*checkingStableReader).chosePrimaryForTest) +} + func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *testing.T) { primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1"), nil} diff --git a/internal/datastore/proxy/relationshipintegrity_test.go b/internal/datastore/proxy/relationshipintegrity_test.go index 4c8f09abd7..7a70de820e 100644 --- a/internal/datastore/proxy/relationshipintegrity_test.go +++ b/internal/datastore/proxy/relationshipintegrity_test.go @@ -285,6 +285,301 @@ func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) { } } +func TestNewRelationshipIntegrityProxyErrors(t *testing.T) { + validKey := DefaultKeyForTesting + + expTime, err := time.Parse("2006-01-02", "2021-01-01") + require.NoError(t, err) + + cases := []struct { + name string + current KeyConfig + expired []KeyConfig + errContains string + }{ + { + name: "empty current key bytes", + current: KeyConfig{ID: "k1", Bytes: nil}, + errContains: "current key file cannot be empty", + }, + { + name: "empty current key ID", + current: KeyConfig{ID: "", Bytes: validKey.Bytes}, + errContains: "current key ID cannot be empty", + }, + { + name: "expired key empty bytes", + current: validKey, + expired: []KeyConfig{ + {ID: "exp", Bytes: nil, ExpiredAt: &expTime}, + }, + errContains: "expired key cannot be empty", + }, + { + name: "expired key empty ID", + current: validKey, + expired: []KeyConfig{ + {ID: "", Bytes: validKey.Bytes, ExpiredAt: &expTime}, + }, + errContains: "expired key ID cannot be empty", + }, + { + name: "expired key missing expiration", + current: validKey, + expired: []KeyConfig{ + {ID: "exp", Bytes: validKey.Bytes, ExpiredAt: nil}, + }, + errContains: "expired key missing expiration time", + }, + { + name: "duplicate key ID", + current: validKey, + expired: []KeyConfig{ + {ID: validKey.ID, Bytes: validKey.Bytes, ExpiredAt: &expTime}, + }, + errContains: "found duplicate key ID", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + _, err = NewRelationshipIntegrityProxy(ds, tc.current, tc.expired) + require.Error(t, err) + require.ErrorContains(t, err, tc.errContains) + }) + } + + // "Current key has an expiration" panics via MustBugf rather than + // returning an error, so test it separately. + t.Run("current key with expiration panics", func(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + require.Panics(t, func() { + _, _ = NewRelationshipIntegrityProxy(ds, KeyConfig{ + ID: "k1", + Bytes: validKey.Bytes, + ExpiredAt: &expTime, + }, nil) + }) + }) +} + +func TestRelationshipIntegrityProxyPassThroughs(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) + require.NoError(t, err) + + ctx := t.Context() + + metricsID, err := pds.MetricsID() + require.NoError(t, err) + require.NotEmpty(t, metricsID) + + uniqueID, err := pds.UniqueID(ctx) + require.NoError(t, err) + require.NotEmpty(t, uniqueID) + + features, err := pds.Features(ctx) + require.NoError(t, err) + require.NotNil(t, features) + + offline, err := pds.OfflineFeatures() + require.NoError(t, err) + require.NotNil(t, offline) + + headRev, err := pds.HeadRevision(ctx) + require.NoError(t, err) + require.NotNil(t, headRev) + + require.NoError(t, pds.CheckRevision(ctx, headRev)) + + optRev, err := pds.OptimizedRevision(ctx) + require.NoError(t, err) + require.NotNil(t, optRev) + + readyState, err := pds.ReadyState(ctx) + require.NoError(t, err) + require.True(t, readyState.IsReady) + + roundTripped, err := pds.RevisionFromString(headRev.String()) + require.NoError(t, err) + require.True(t, roundTripped.Equal(headRev)) + + _, err = pds.Statistics(ctx) + require.NoError(t, err) + + unwrapper, ok := pds.(datastore.UnwrappableDatastore) + require.True(t, ok) + require.Equal(t, ds, unwrapper.Unwrap()) + + require.NoError(t, pds.Close()) +} + +func TestRelationshipIntegrityReaderPassThroughs(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) + require.NoError(t, err) + + headRev, err := pds.HeadRevision(t.Context()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + + // Each of these simply delegates to the wrapped reader; memdb returns + // empty/zero results at head on a fresh database. + _, err = reader.CountRelationships(t.Context(), "nonexistent") + require.Error(t, err) // memdb errors on unknown counter + + caveats, err := reader.LegacyListAllCaveats(t.Context()) + require.NoError(t, err) + require.Empty(t, caveats) + + namespaces, err := reader.LegacyListAllNamespaces(t.Context()) + require.NoError(t, err) + require.Empty(t, namespaces) + + lookupC, err := reader.LegacyLookupCaveatsWithNames(t.Context(), []string{"missing"}) + require.NoError(t, err) + require.Empty(t, lookupC) + + counters, err := reader.LookupCounters(t.Context()) + require.NoError(t, err) + require.Empty(t, counters) + + lookupN, err := reader.LegacyLookupNamespacesWithNames(t.Context(), []string{"missing"}) + require.NoError(t, err) + require.Empty(t, lookupN) + + _, _, err = reader.LegacyReadCaveatByName(t.Context(), "missing") + require.Error(t, err) + + _, _, err = reader.LegacyReadNamespaceByName(t.Context(), "missing") + require.Error(t, err) +} + +func TestRelationshipIntegrityReverseQueryValidatesHash(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) + require.NoError(t, err) + + // Write a valid relationship through the proxy. + _, err = pds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + return tx.WriteRelationships(t.Context(), []tuple.RelationshipUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + }) + }) + require.NoError(t, err) + + // Bypass the proxy to insert a relationship with an invalid hash. + _, err = ds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + invalid := tuple.MustParse("resource:foo#viewer@user:fred") + invalid.OptionalIntegrity = &core.RelationshipIntegrity{ + KeyId: "defaultfortest", + Hash: append([]byte{0x01}, []byte("someinvalidhashaasd")[0:hashLength]...), + HashedAt: timestamppb.Now(), + } + return tx.WriteRelationships(t.Context(), []tuple.RelationshipUpdate{ + tuple.Create(invalid), + }) + }) + require.NoError(t, err) + + headRev, err := pds.HeadRevision(t.Context()) + require.NoError(t, err) + + reader := pds.SnapshotReader(headRev) + iter, err := reader.ReverseQueryRelationships(t.Context(), datastore.SubjectsFilter{ + SubjectType: "user", + }) + require.NoError(t, err) + + _, err = datastore.IteratorToSlice(iter) + require.Error(t, err) + require.ErrorContains(t, err, "has invalid integrity hash") +} + +// stubBulkSource is a trivial BulkWriteRelationshipSource used to verify +// that the integrity proxy decorates the iterator correctly. +type stubBulkSource struct { + rels []tuple.Relationship + idx int +} + +func (s *stubBulkSource) Next(_ context.Context) (*tuple.Relationship, error) { + if s.idx >= len(s.rels) { + return nil, nil + } + rel := s.rels[s.idx] + s.idx++ + return &rel, nil +} + +func TestRelationshipIntegrityBulkLoad(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) + require.NoError(t, err) + + src := &stubBulkSource{rels: []tuple.Relationship{ + tuple.MustParse("resource:foo#viewer@user:tom"), + tuple.MustParse("resource:foo#viewer@user:fred"), + }} + + _, err = pds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + loaded, err := tx.BulkLoad(ctx, src) + require.NoError(t, err) + require.Equal(t, uint64(2), loaded) + return nil + }) + require.NoError(t, err) + + // Integrity metadata should have been added and verifiable on readback. + headRev, err := pds.HeadRevision(t.Context()) + require.NoError(t, err) + + iter, err := pds.SnapshotReader(headRev).QueryRelationships( + t.Context(), + datastore.RelationshipsFilter{OptionalResourceType: "resource"}, + ) + require.NoError(t, err) + rels, err := datastore.IteratorToSlice(iter) + require.NoError(t, err) + require.Len(t, rels, 2) +} + +func TestRelationshipIntegrityBulkLoadRejectsPrehashed(t *testing.T) { + ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + + pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) + require.NoError(t, err) + + prehashed := tuple.MustParse("resource:foo#viewer@user:tom") + prehashed.OptionalIntegrity = &core.RelationshipIntegrity{KeyId: "other"} + + src := &stubBulkSource{rels: []tuple.Relationship{prehashed}} + + // spiceerrors.MustBugf panics; BulkLoad propagates that through the iterator's + // Next call. + require.Panics(t, func() { + _, _ = pds.ReadWriteTx(t.Context(), func(ctx context.Context, tx datastore.ReadWriteTransaction) error { + _, _ = tx.BulkLoad(ctx, src) + return nil + }) + }) +} + func BenchmarkQueryRelsWithIntegrity(b *testing.B) { for _, withIntegrity := range []bool{true, false} { b.Run(fmt.Sprintf("withIntegrity=%t", withIntegrity), func(b *testing.B) { diff --git a/internal/datastore/proxy/strictreplicated_test.go b/internal/datastore/proxy/strictreplicated_test.go index 4ec25b60bd..3891315e5f 100644 --- a/internal/datastore/proxy/strictreplicated_test.go +++ b/internal/datastore/proxy/strictreplicated_test.go @@ -9,6 +9,16 @@ import ( "github.com/authzed/spicedb/pkg/datastore/revisionparsing" ) +// nonStrictDatastore wraps a fakeDatastore but reports strict read mode disabled, +// used to exercise the IsStrictReadModeEnabled==false rejection path. +type nonStrictDatastore struct { + fakeDatastore +} + +func (nonStrictDatastore) IsStrictReadModeEnabled() bool { + return false +} + func TestStrictReplicatedReaderWithOnlyPrimary(t *testing.T) { primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} @@ -101,3 +111,64 @@ func TestStrictReplicatedQueryNonFallbackError(t *testing.T) { }) require.ErrorContains(t, err, "raising an expected error") } + +func TestStrictReplicatedRejectsReplicaWithoutStrictMode(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} + replica := nonStrictDatastore{fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1"), nil}} + + _, err := NewStrictReplicatedDatastore(primary, replica) + require.Error(t, err) + require.ErrorContains(t, err, "does not have strict read mode enabled") +} + +// TestStrictReplicatedReaderWrapperMethods exercises the legacy caveat/namespace +// wrappers, CountRelationships, and LookupCounters on a strict replicated reader. +// The fake replica returns "not implemented" (not a RevisionUnavailableError), so +// these calls do not trigger a primary fallback; they simply confirm the wrappers +// invoke the replica's reader. +func TestStrictReplicatedReaderWrapperMethods(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("2"), nil} + + replicated, err := NewStrictReplicatedDatastore(primary, replica) + require.NoError(t, err) + + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + + _, _, err = reader.LegacyReadCaveatByName(t.Context(), "is_weekend") + require.ErrorContains(t, err, "not implemented") + + _, err = reader.LegacyListAllCaveats(t.Context()) + require.ErrorContains(t, err, "not implemented") + + _, err = reader.LegacyLookupCaveatsWithNames(t.Context(), []string{"is_weekend"}) + require.ErrorContains(t, err, "not implemented") + + // LegacyListAllNamespaces returns nil on the fake replica (no error), so happy-path. + ns, err := reader.LegacyListAllNamespaces(t.Context()) + require.NoError(t, err) + require.Empty(t, ns) + + _, err = reader.CountRelationships(t.Context(), "filter") + require.ErrorContains(t, err, "not implemented") + + _, err = reader.LookupCounters(t.Context()) + require.ErrorContains(t, err, "not implemented") +} + +// TestStrictReplicatedReaderFallsbackForNamespaceLookups ensures the fallback +// path in LegacyLookupNamespacesWithNames kicks in when the replica returns a +// RevisionUnavailableError. The fake returns that error when queried beyond +// revision 2. +func TestStrictReplicatedReaderFallsbackForNamespaceLookups(t *testing.T) { + primary := fakeDatastore{"primary", revisionparsing.MustParseRevisionForTest("2"), nil} + replica := fakeDatastore{"replica", revisionparsing.MustParseRevisionForTest("1"), nil} + + replicated, err := NewStrictReplicatedDatastore(primary, replica) + require.NoError(t, err) + + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + ns, err := reader.LegacyLookupNamespacesWithNames(t.Context(), []string{"ns1"}) + require.NoError(t, err) + require.Len(t, ns, 1) +} diff --git a/internal/dispatch/combined/combined_test.go b/internal/dispatch/combined/combined_test.go index 2629c55815..e9c04f0811 100644 --- a/internal/dispatch/combined/combined_test.go +++ b/internal/dispatch/combined/combined_test.go @@ -2,12 +2,17 @@ package combined import ( "testing" + "time" "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/dispatch/graph" + "github.com/authzed/spicedb/internal/dispatch/keys" "github.com/authzed/spicedb/internal/testfixtures" + "github.com/authzed/spicedb/pkg/cache" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/datalayer" core "github.com/authzed/spicedb/pkg/proto/core/v1" dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -58,3 +63,117 @@ func TestCombinedRecursiveCall(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "max depth exceeded") } + +// TestNewDispatcher_AppliesAllOptions_NoUpstream exercises the previously +// uncovered option setters on the no-upstream branch and verifies the +// dispatcher constructs successfully. +func TestNewDispatcher_AppliesAllOptions_NoUpstream(t *testing.T) { + cfg := &cache.Config{ + NumCounters: 100, + MaxCost: 1024, + DefaultTTL: 1 * time.Second, + } + + dispatchCache, err := cache.NewStandardCache[keys.DispatchCacheKey, any](&cache.Config{ + NumCounters: 100, + MaxCost: 1024, + DefaultTTL: 1 * time.Second, + }) + require.NoError(t, err) + + dispatcher, err := NewDispatcher( + MetricsEnabled(true), + PrometheusSubsystem("test_subsystem"), + DispatchChunkSize(50), + RelationshipChunkCacheConfig(cfg), + CaveatTypeSet(caveattypes.Default.TypeSet), + Cache(dispatchCache), + ConcurrencyLimits(graph.ConcurrencyLimits{Check: 10, LookupResources: 5}), + ) + require.NoError(t, err) + require.NotNil(t, dispatcher) + t.Cleanup(func() { dispatcher.Close() }) +} + +// TestNewDispatcher_WithProvidedRelationshipChunkCache hits the branch where +// the caller supplies a pre-built relationship chunk cache instead of a config. +func TestNewDispatcher_WithProvidedRelationshipChunkCache(t *testing.T) { + c, err := cache.NewStandardCache[cache.StringKey, any](&cache.Config{ + NumCounters: 100, + MaxCost: 1024, + DefaultTTL: 1 * time.Second, + }) + require.NoError(t, err) + + dispatcher, err := NewDispatcher( + RelationshipChunkCache(c), + ) + require.NoError(t, err) + require.NotNil(t, dispatcher) + t.Cleanup(func() { dispatcher.Close() }) +} + +// TestNewDispatcher_WithUpstream exercises the upstream branch using an +// insecure preshared-key setup. grpc.DialContext is lazy, so a bogus but +// syntactically valid address is acceptable. +func TestNewDispatcher_WithUpstream(t *testing.T) { + dispatcher, err := NewDispatcher( + UpstreamAddr("localhost:0"), + GrpcPresharedKey("test-key"), + RemoteDispatchTimeout(5*time.Second), + StartingPrimaryHedgingDelay(10*time.Millisecond), + GrpcDialOpts(), + ) + require.NoError(t, err) + require.NotNil(t, dispatcher) + t.Cleanup(func() { dispatcher.Close() }) +} + +// TestNewDispatcher_UpstreamCAPathMissing exercises the TLS branch by pointing +// UpstreamCAPath at a non-existent file; grpcutil.WithCustomCerts will error. +func TestNewDispatcher_UpstreamCAPathMissing(t *testing.T) { + _, err := NewDispatcher( + UpstreamAddr("localhost:0"), + UpstreamCAPath("/nonexistent/path/to/ca.pem"), + GrpcPresharedKey("test-key"), + ) + require.Error(t, err) +} + +// TestNewDispatcher_SecondaryInvalidHedgingDelay covers the parse-error branch +// for secondary upstream maximum primary hedging delays. +func TestNewDispatcher_SecondaryInvalidHedgingDelay(t *testing.T) { + _, err := NewDispatcher( + UpstreamAddr("localhost:0"), + GrpcPresharedKey("test-key"), + SecondaryUpstreamAddrs(map[string]string{"sec": "localhost:1"}), + SecondaryMaximumPrimaryHedgingDelays(map[string]string{"sec": "not-a-duration"}), + ) + require.Error(t, err) + require.ErrorContains(t, err, "error parsing maximum primary hedging delay") +} + +// TestNewDispatcher_SecondaryNegativeHedgingDelay covers the zero/negative +// hedging delay branch. +func TestNewDispatcher_SecondaryNegativeHedgingDelay(t *testing.T) { + _, err := NewDispatcher( + UpstreamAddr("localhost:0"), + GrpcPresharedKey("test-key"), + SecondaryUpstreamAddrs(map[string]string{"sec": "localhost:1"}), + SecondaryMaximumPrimaryHedgingDelays(map[string]string{"sec": "0s"}), + ) + require.Error(t, err) + require.ErrorContains(t, err, "must be greater than 0") +} + +// TestNewDispatcher_InvalidSecondaryDispatchExpr covers the parse-error branch +// for secondary dispatch expressions. +func TestNewDispatcher_InvalidSecondaryDispatchExpr(t *testing.T) { + _, err := NewDispatcher( + UpstreamAddr("localhost:0"), + GrpcPresharedKey("test-key"), + SecondaryUpstreamExprs(map[string]string{"check": "not a valid CEL expression @#$"}), + ) + require.Error(t, err) + require.ErrorContains(t, err, "error parsing secondary dispatch expr") +} diff --git a/internal/graph/errors_test.go b/internal/graph/errors_test.go new file mode 100644 index 0000000000..318347eb13 --- /dev/null +++ b/internal/graph/errors_test.go @@ -0,0 +1,144 @@ +package graph + +import ( + "errors" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +func TestCheckFailureError(t *testing.T) { + base := errors.New("boom") + err := NewCheckFailureErr(base) + + var cfe CheckFailureError + require.ErrorAs(t, err, &cfe) + require.ErrorIs(t, err, base) + require.Contains(t, err.Error(), "error performing check") + require.Contains(t, err.Error(), "boom") +} + +func TestExpansionFailureError(t *testing.T) { + base := errors.New("boom") + err := NewExpansionFailureErr(base) + + var efe ExpansionFailureError + require.ErrorAs(t, err, &efe) + require.ErrorIs(t, err, base) + require.Contains(t, err.Error(), "error performing expand") + require.Contains(t, err.Error(), "boom") +} + +func TestAlwaysFailError(t *testing.T) { + err := NewAlwaysFailErr() + + var afe AlwaysFailError + require.ErrorAs(t, err, &afe) + require.Equal(t, "always fail", err.Error()) +} + +func TestRelationNotFoundError(t *testing.T) { + err := NewRelationNotFoundErr("document", "viewer") + + var rnfe RelationNotFoundError + require.ErrorAs(t, err, &rnfe) + require.Equal(t, "document", rnfe.NamespaceName()) + require.Equal(t, "viewer", rnfe.NotFoundRelationName()) + require.Contains(t, err.Error(), "viewer") + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_or_permission_name": "viewer", + }, rnfe.DetailsMetadata()) + + require.NotPanics(t, func() { + rnfe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestRelationMissingTypeInfoError(t *testing.T) { + err := NewRelationMissingTypeInfoErr("document", "viewer") + + var rmte RelationMissingTypeInfoError + require.ErrorAs(t, err, &rmte) + require.Equal(t, "document", rmte.NamespaceName()) + require.Equal(t, "viewer", rmte.RelationName()) + require.Contains(t, err.Error(), "missing type information") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_name": "viewer", + }, rmte.DetailsMetadata()) + + require.NotPanics(t, func() { + rmte.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestWildcardNotAllowedError(t *testing.T) { + err := NewWildcardNotAllowedErr("wildcard not allowed here", "resource.subject") + + var wne WildcardNotAllowedError + require.ErrorAs(t, err, &wne) + require.Contains(t, err.Error(), "wildcard not allowed here") + + status := wne.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestUnimplementedError(t *testing.T) { + base := errors.New("not yet") + err := NewUnimplementedErr(base) + + var ue UnimplementedError + require.ErrorAs(t, err, &ue) + require.ErrorIs(t, err, base) + require.Equal(t, "not yet", err.Error()) +} + +func TestInvalidCursorError(t *testing.T) { + err := NewInvalidCursorErr(2, &dispatch.Cursor{DispatchVersion: 1}) + + var ice InvalidCursorError + require.ErrorAs(t, err, &ice) + require.Contains(t, err.Error(), "cursor is no longer valid") + + status := ice.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestCheckFailureErrorUnwrap(t *testing.T) { + base := errors.New("underlying") + err := NewCheckFailureErr(base) + + var cfe CheckFailureError + require.ErrorAs(t, err, &cfe) + require.Error(t, cfe.Unwrap()) + require.ErrorIs(t, cfe.Unwrap(), base) +} + +func TestExpansionFailureErrorUnwrap(t *testing.T) { + base := errors.New("underlying") + err := NewExpansionFailureErr(base) + + var efe ExpansionFailureError + require.ErrorAs(t, err, &efe) + require.Error(t, efe.Unwrap()) + require.ErrorIs(t, efe.Unwrap(), base) +} + +func TestUnimplementedErrorUnwrap(t *testing.T) { + base := errors.New("underlying") + err := NewUnimplementedErr(base) + + var ue UnimplementedError + require.ErrorAs(t, err, &ue) + require.Equal(t, base, ue.Unwrap()) +} diff --git a/internal/namespace/errors_test.go b/internal/namespace/errors_test.go new file mode 100644 index 0000000000..6573e0508a --- /dev/null +++ b/internal/namespace/errors_test.go @@ -0,0 +1,100 @@ +package namespace + +import ( + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" +) + +func TestNamespaceNotFoundError(t *testing.T) { + err := NewNamespaceNotFoundErr("document") + + var nnfe NamespaceNotFoundError + require.ErrorAs(t, err, &nnfe) + require.Equal(t, "document", nnfe.NotFoundNamespaceName()) + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{"definition_name": "document"}, nnfe.DetailsMetadata()) + + require.NotPanics(t, func() { + nnfe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestRelationNotFoundError(t *testing.T) { + err := NewRelationNotFoundErr("document", "viewer") + + var rnfe RelationNotFoundError + require.ErrorAs(t, err, &rnfe) + require.Equal(t, "document", rnfe.NamespaceName()) + require.Equal(t, "viewer", rnfe.NotFoundRelationName()) + require.Contains(t, err.Error(), "viewer") + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_or_permission_name": "viewer", + }, rnfe.DetailsMetadata()) + + require.NotPanics(t, func() { + rnfe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestDuplicateRelationError(t *testing.T) { + err := NewDuplicateRelationError("document", "viewer") + + var dre DuplicateRelationError + require.ErrorAs(t, err, &dre) + require.Contains(t, err.Error(), "duplicate") + require.Contains(t, err.Error(), "viewer") + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_or_permission_name": "viewer", + }, dre.DetailsMetadata()) + + require.NotPanics(t, func() { + dre.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestPermissionsCycleError(t *testing.T) { + err := NewPermissionsCycleErr("document", []string{"view", "edit", "admin"}) + + var pce PermissionsCycleError + require.ErrorAs(t, err, &pce) + require.Contains(t, err.Error(), "cycle") + require.Contains(t, err.Error(), "view") + require.Contains(t, err.Error(), "edit") + require.Contains(t, err.Error(), "admin") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "permission_names": "view,edit,admin", + }, pce.DetailsMetadata()) + + require.NotPanics(t, func() { + pce.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestUnusedCaveatParameterError(t *testing.T) { + err := NewUnusedCaveatParameterErr("is_weekend", "day") + + var ucpe UnusedCaveatParameterError + require.ErrorAs(t, err, &ucpe) + require.Contains(t, err.Error(), "day") + require.Contains(t, err.Error(), "is_weekend") + + require.Equal(t, map[string]string{ + "caveat_name": "is_weekend", + "parameter_name": "day", + }, ucpe.DetailsMetadata()) + + require.NotPanics(t, func() { + ucpe.MarshalZerologObject(zerolog.Dict()) + }) +} diff --git a/internal/services/v1/errors_test.go b/internal/services/v1/errors_test.go new file mode 100644 index 0000000000..e7344529da --- /dev/null +++ b/internal/services/v1/errors_test.go @@ -0,0 +1,265 @@ +package v1 + +import ( + "errors" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" +) + +func TestExceedsMaximumLimitError(t *testing.T) { + err := NewExceedsMaximumLimitErr(100, 50) + + require.Contains(t, err.Error(), "100") + require.Contains(t, err.Error(), "50") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) + + require.NotPanics(t, func() { + err.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestExceedsMaximumChecksError(t *testing.T) { + err := NewExceedsMaximumChecksErr(100, 50) + + require.Contains(t, err.Error(), "100") + require.Contains(t, err.Error(), "50") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) + + require.NotPanics(t, func() { + err.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestExceedsMaximumUpdatesError(t *testing.T) { + err := NewExceedsMaximumUpdatesErr(1000, 500) + + require.Contains(t, err.Error(), "1000") + require.Contains(t, err.Error(), "500") + require.Contains(t, err.Error(), "ImportBulkRelationships") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) + + require.NotPanics(t, func() { + err.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestExceedsMaximumPreconditionsError(t *testing.T) { + err := NewExceedsMaximumPreconditionsErr(50, 10) + + require.Contains(t, err.Error(), "50") + require.Contains(t, err.Error(), "10") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) + + require.NotPanics(t, func() { + err.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestPreconditionFailedError_Minimal(t *testing.T) { + precondition := &v1.Precondition{ + Operation: v1.Precondition_OPERATION_MUST_MATCH, + Filter: &v1.RelationshipFilter{ + ResourceType: "document", + }, + } + err := func() PreconditionFailedError { + var target PreconditionFailedError + _ = errors.As(NewPreconditionFailedErr(precondition), &target) + return target + }() + + require.Contains(t, err.Error(), "precondition") + + status := err.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + require.NotEmpty(t, status.Details()) + + require.NotPanics(t, func() { + err.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestPreconditionFailedError_FullFilter(t *testing.T) { + precondition := &v1.Precondition{ + Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH, + Filter: &v1.RelationshipFilter{ + ResourceType: "document", + OptionalResourceId: "first", + OptionalResourceIdPrefix: "doc-", + OptionalRelation: "viewer", + OptionalSubjectFilter: &v1.SubjectFilter{ + SubjectType: "user", + OptionalSubjectId: "alice", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "member", + }, + }, + }, + } + err := func() PreconditionFailedError { + var target PreconditionFailedError + _ = errors.As(NewPreconditionFailedErr(precondition), &target) + return target + }() + + status := err.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestDuplicateRelationshipError(t *testing.T) { + update := &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_CREATE, + Relationship: &v1.Relationship{ + Resource: &v1.ObjectReference{ + ObjectType: "document", + ObjectId: "first", + }, + Relation: "viewer", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: "user", + ObjectId: "alice", + }, + }, + }, + } + err := NewDuplicateRelationshipErr(update) + + require.Contains(t, err.Error(), "document") + require.Contains(t, err.Error(), "alice") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestErrMaxRelationshipContextError(t *testing.T) { + update := &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_CREATE, + Relationship: &v1.Relationship{ + Resource: &v1.ObjectReference{ObjectType: "document", ObjectId: "first"}, + Relation: "viewer", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ObjectType: "user", ObjectId: "alice"}, + }, + }, + } + err := NewMaxRelationshipContextError(update, 1024) + + require.Contains(t, err.Error(), "1024") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestCouldNotTransactionallyDeleteError_Minimal(t *testing.T) { + filter := &v1.RelationshipFilter{ResourceType: "document"} + err := NewCouldNotTransactionallyDeleteErr(filter, 100) + + require.Contains(t, err.Error(), "100") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestCouldNotTransactionallyDeleteError_FullFilter(t *testing.T) { + filter := &v1.RelationshipFilter{ + ResourceType: "document", + OptionalResourceId: "first", + OptionalRelation: "viewer", + OptionalSubjectFilter: &v1.SubjectFilter{ + SubjectType: "user", + OptionalSubjectId: "alice", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "member", + }, + }, + } + err := NewCouldNotTransactionallyDeleteErr(filter, 100) + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestInvalidCursorError(t *testing.T) { + err := NewInvalidCursorErr("malformed") + + require.Contains(t, err.Error(), "malformed") + + status := err.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestInvalidFilterError(t *testing.T) { + err := NewInvalidFilterErr("bad filter", "filter-repr") + + require.Contains(t, err.Error(), "bad filter") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestEmptyPreconditionError(t *testing.T) { + err := NewEmptyPreconditionErr() + + require.Equal(t, "one of the specified preconditions is empty", err.Error()) + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) +} + +func TestNotAPermissionError(t *testing.T) { + err := NewNotAPermissionError("viewer") + + require.Contains(t, err.Error(), "viewer") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) +} + +func TestTransactionMetadataTooLargeError(t *testing.T) { + err := NewTransactionMetadataTooLargeErr(1024, 512) + + require.Contains(t, err.Error(), "1024") + require.Contains(t, err.Error(), "512") + + status := err.GRPCStatus() + require.Equal(t, codes.InvalidArgument, status.Code()) + require.NotEmpty(t, status.Details()) + + require.NotPanics(t, func() { + err.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestDefaultIfZero(t *testing.T) { + require.Equal(t, "fallback", defaultIfZero("", "fallback")) + require.Equal(t, "explicit", defaultIfZero("explicit", "fallback")) + + require.Equal(t, 5, defaultIfZero(0, 5)) + require.Equal(t, 7, defaultIfZero(7, 5)) +} diff --git a/pkg/datalayer/datalayer_test.go b/pkg/datalayer/datalayer_test.go new file mode 100644 index 0000000000..34fcf204a1 --- /dev/null +++ b/pkg/datalayer/datalayer_test.go @@ -0,0 +1,459 @@ +package datalayer + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +func newTestDataLayer(t testing.TB) (DataLayer, datastore.Datastore) { + t.Helper() + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + t.Cleanup(func() { + _ = ds.Close() + }) + return NewDataLayer(ds), ds +} + +// TestDefaultDataLayer_PassThroughs exercises the trivial pass-through methods on +// defaultDataLayer in impl.go. +func TestDefaultDataLayer_PassThroughs(t *testing.T) { + dl, underlying := newTestDataLayer(t) + ctx := t.Context() + + readyState, err := dl.ReadyState(ctx) + require.NoError(t, err) + require.True(t, readyState.IsReady) + + features, err := dl.Features(ctx) + require.NoError(t, err) + require.NotNil(t, features) + + offline, err := dl.OfflineFeatures() + require.NoError(t, err) + require.NotNil(t, offline) + + _, err = dl.Statistics(ctx) + require.NoError(t, err) + + metricsID, err := dl.MetricsID() + require.NoError(t, err) + require.NotEmpty(t, metricsID) + + uniqueID, err := dl.UniqueID(ctx) + require.NoError(t, err) + require.NotEmpty(t, uniqueID) + + rev, err := dl.HeadRevision(ctx) + require.NoError(t, err) + require.NoError(t, dl.CheckRevision(ctx, rev)) + + optRev, err := dl.OptimizedRevision(ctx) + require.NoError(t, err) + require.NotNil(t, optRev) + + roundTripped, err := dl.RevisionFromString(rev.String()) + require.NoError(t, err) + require.True(t, roundTripped.Equal(rev)) + + watchCh, errCh := dl.Watch(ctx, rev, datastore.WatchOptions{Content: datastore.WatchRelationships}) + require.NotNil(t, watchCh) + require.NotNil(t, errCh) + + require.Equal(t, underlying, UnwrapDatastore(dl)) + require.NoError(t, dl.Close()) +} + +// TestCountingDataLayer_CountsQueryAndReverseQuery ensures QueryRelationships +// and ReverseQueryRelationships increment their counters on both the snapshot +// reader and the RW transaction. +func TestCountingDataLayer_CountsQueryAndReverseQuery(t *testing.T) { + baseDL, _ := newTestDataLayer(t) + counting, counts := NewCountingDataLayer(baseDL) + ctx := t.Context() + + _, err := counting.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + }) + }) + require.NoError(t, err) + + rev, err := counting.HeadRevision(ctx) + require.NoError(t, err) + + reader := counting.SnapshotReader(rev) + it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.NoError(t, err) + _, err = datastore.IteratorToSlice(it) + require.NoError(t, err) + require.Equal(t, uint64(1), counts.QueryRelationships()) + + rit, err := reader.ReverseQueryRelationships(ctx, datastore.SubjectsFilter{SubjectType: "user"}) + require.NoError(t, err) + _, err = datastore.IteratorToSlice(rit) + require.NoError(t, err) + require.Equal(t, uint64(1), counts.ReverseQueryRelationships()) + + // RWT also counts queries. + _, err = counting.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + it, err := rwt.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: "resource", + }) + require.NoError(t, err) + _, err = datastore.IteratorToSlice(it) + require.NoError(t, err) + + rit, err := rwt.ReverseQueryRelationships(ctx, datastore.SubjectsFilter{SubjectType: "user"}) + require.NoError(t, err) + _, err = datastore.IteratorToSlice(rit) + require.NoError(t, err) + + // CountRelationships and LookupCounters delegate without counting. + _, _ = rwt.CountRelationships(ctx, "nonexistent") + _, _ = rwt.LookupCounters(ctx) + return nil + }) + require.NoError(t, err) + + require.Equal(t, uint64(2), counts.QueryRelationships()) + require.Equal(t, uint64(2), counts.ReverseQueryRelationships()) +} + +// TestCountingDataLayer_PassThroughs exercises the trivial pass-through methods +// on countingDataLayer. +func TestCountingDataLayer_PassThroughs(t *testing.T) { + base, _ := newTestDataLayer(t) + counting, _ := NewCountingDataLayer(base) + ctx := t.Context() + + readyState, err := counting.ReadyState(ctx) + require.NoError(t, err) + require.True(t, readyState.IsReady) + + features, err := counting.Features(ctx) + require.NoError(t, err) + require.NotNil(t, features) + + offline, err := counting.OfflineFeatures() + require.NoError(t, err) + require.NotNil(t, offline) + + _, err = counting.Statistics(ctx) + require.NoError(t, err) + + metricsID, err := counting.MetricsID() + require.NoError(t, err) + require.NotEmpty(t, metricsID) + + uniqueID, err := counting.UniqueID(ctx) + require.NoError(t, err) + require.NotEmpty(t, uniqueID) + + rev, err := counting.HeadRevision(ctx) + require.NoError(t, err) + + require.NoError(t, counting.CheckRevision(ctx, rev)) + + optRev, err := counting.OptimizedRevision(ctx) + require.NoError(t, err) + require.NotNil(t, optRev) + + roundTripped, err := counting.RevisionFromString(rev.String()) + require.NoError(t, err) + require.True(t, roundTripped.Equal(rev)) + + watchCh, errCh := counting.Watch(ctx, rev, datastore.WatchOptions{Content: datastore.WatchRelationships}) + require.NotNil(t, watchCh) + require.NotNil(t, errCh) + + // UnwrapDatastore transitively returns the underlying datastore. + require.NotNil(t, UnwrapDatastore(counting)) + + reader := counting.SnapshotReader(rev) + _, err = reader.ReadSchema(ctx) + require.NoError(t, err) + + _, _ = reader.CountRelationships(ctx, "nonexistent") + _, err = reader.LookupCounters(ctx) + require.NoError(t, err) + + // Exercise DeleteRelationships on countingReadWriteTransaction. + _, err = counting.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + require.NoError(t, rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + })) + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ResourceType: "resource"}) + return err + }) + require.NoError(t, err) + + require.NoError(t, counting.Close()) +} + +// TestUnaryCountingInterceptor verifies the interceptor swaps in a counting +// datalayer, invokes the handler, and calls the exporter exactly once. +func TestUnaryCountingInterceptor(t *testing.T) { + base, _ := newTestDataLayer(t) + parentCtx := ContextWithDataLayer(t.Context(), base) + + var exportedCounts *MethodCounts + interceptor := UnaryCountingInterceptor(func(counts *MethodCounts) { + exportedCounts = counts + }) + + resp, err := interceptor(parentCtx, "req", &grpc.UnaryServerInfo{FullMethod: "/test.Service/Method"}, + func(ctx context.Context, _ any) (any, error) { + dl := MustFromContext(ctx) + _, err := dl.HeadRevision(ctx) + require.NoError(t, err) + return "ok", nil + }) + require.NoError(t, err) + require.Equal(t, "ok", resp) + + require.NotNil(t, exportedCounts, "exporter should have been called") + require.Equal(t, uint64(0), exportedCounts.QueryRelationships()) +} + +func TestUnaryCountingInterceptor_NilExporter(t *testing.T) { + base, _ := newTestDataLayer(t) + parentCtx := ContextWithDataLayer(t.Context(), base) + + interceptor := UnaryCountingInterceptor(nil) + + _, err := interceptor(parentCtx, "req", &grpc.UnaryServerInfo{FullMethod: "/x/y"}, + func(ctx context.Context, _ any) (any, error) { + return "ok", nil + }) + require.NoError(t, err) +} + +// streamStub implements the minimum of grpc.ServerStream needed by the +// interceptor — just Context(). +type streamStub struct { + grpc.ServerStream + ctx context.Context +} + +func (s *streamStub) Context() context.Context { return s.ctx } + +func TestStreamCountingInterceptor(t *testing.T) { + base, _ := newTestDataLayer(t) + streamCtx := ContextWithDataLayer(t.Context(), base) + + var exportedCounts *MethodCounts + interceptor := StreamCountingInterceptor(func(counts *MethodCounts) { + exportedCounts = counts + }) + + stub := &streamStub{ctx: streamCtx} + err := interceptor(nil, stub, &grpc.StreamServerInfo{FullMethod: "/test/Stream"}, + func(_ any, ss grpc.ServerStream) error { + dl := MustFromContext(ss.Context()) + _, err := dl.HeadRevision(ss.Context()) + return err + }) + require.NoError(t, err) + require.NotNil(t, exportedCounts) +} + +func TestStreamCountingInterceptor_NilExporter(t *testing.T) { + base, _ := newTestDataLayer(t) + streamCtx := ContextWithDataLayer(t.Context(), base) + + interceptor := StreamCountingInterceptor(nil) + err := interceptor(nil, &streamStub{ctx: streamCtx}, &grpc.StreamServerInfo{FullMethod: "/x"}, + func(_ any, _ grpc.ServerStream) error { return nil }) + require.NoError(t, err) +} + +// TestNewReadOnlyDataLayer_RejectsWrite ensures the read-only adapter returns a +// readonly error from ReadWriteTx and pass-throughs everything else. +func TestNewReadOnlyDataLayer_RejectsWrite(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + t.Cleanup(func() { _ = ds.Close() }) + + ro := NewReadOnlyDataLayer(ds) + ctx := t.Context() + + _, err = ro.ReadWriteTx(ctx, func(context.Context, ReadWriteTransaction) error { + return nil + }) + require.Error(t, err) + + rev, err := ro.HeadRevision(ctx) + require.NoError(t, err) + require.NoError(t, ro.CheckRevision(ctx, rev)) + + readyState, err := ro.ReadyState(ctx) + require.NoError(t, err) + require.True(t, readyState.IsReady) + + _, err = ro.Features(ctx) + require.NoError(t, err) + + _, err = ro.OfflineFeatures() + require.NoError(t, err) + + _, err = ro.Statistics(ctx) + require.NoError(t, err) + + _, err = ro.UniqueID(ctx) + require.NoError(t, err) + + _, err = ro.MetricsID() + require.NoError(t, err) + + optRev, err := ro.OptimizedRevision(ctx) + require.NoError(t, err) + require.NotNil(t, optRev) + + rtr, err := ro.RevisionFromString(rev.String()) + require.NoError(t, err) + require.True(t, rtr.Equal(rev)) + + watchCh, errCh := ro.Watch(ctx, rev, datastore.WatchOptions{Content: datastore.WatchRelationships}) + require.NotNil(t, watchCh) + require.NotNil(t, errCh) + + // UnwrapDatastore returns nil for the readonly adapter. + require.Nil(t, UnwrapDatastore(ro)) + + require.NoError(t, ro.Close()) +} + +// TestReadOnlyDataLayer_ReaderPassThroughs exercises the snapshot reader on the +// readonly adapter. +func TestReadOnlyDataLayer_ReaderPassThroughs(t *testing.T) { + ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + require.NoError(t, err) + t.Cleanup(func() { _ = ds.Close() }) + + ro := NewReadOnlyDataLayer(ds) + ctx := t.Context() + + rev, err := ro.HeadRevision(ctx) + require.NoError(t, err) + + reader := ro.SnapshotReader(rev) + + // ReadSchema on revisionedReader (shared between impl.go and readonly adapter). + sr, err := reader.ReadSchema(ctx) + require.NoError(t, err) + require.NotNil(t, sr) + + iter, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{OptionalResourceType: "resource"}) + require.NoError(t, err) + rels, err := datastore.IteratorToSlice(iter) + require.NoError(t, err) + require.Empty(t, rels) + + rit, err := reader.ReverseQueryRelationships(ctx, datastore.SubjectsFilter{SubjectType: "user"}) + require.NoError(t, err) + rrels, err := datastore.IteratorToSlice(rit) + require.NoError(t, err) + require.Empty(t, rrels) + + _, err = reader.CountRelationships(ctx, "missing") + require.Error(t, err) + + _, err = reader.LookupCounters(ctx) + require.NoError(t, err) +} + +// TestReadWriteTransaction_AllMethods exercises the RW transaction wrappers in +// impl.go: WriteRelationships, DeleteRelationships, ReadSchema, +// CountRelationships, LookupCounters, BulkLoad, counter registration, and the +// legacy schema writer. +func TestReadWriteTransaction_AllMethods(t *testing.T) { + dl, _ := newTestDataLayer(t) + ctx := t.Context() + + _, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + require.NoError(t, rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{ + tuple.Create(tuple.MustParse("resource:foo#viewer@user:tom")), + tuple.Create(tuple.MustParse("resource:bar#viewer@user:fred")), + })) + + _, _, err := rwt.DeleteRelationships(ctx, &v1.RelationshipFilter{ + ResourceType: "resource", + OptionalResourceId: "bar", + }) + require.NoError(t, err) + + sr, err := rwt.ReadSchema(ctx) + require.NoError(t, err) + require.NotNil(t, sr) + + _, err = rwt.CountRelationships(ctx, "missing") + require.Error(t, err) + + _, err = rwt.LookupCounters(ctx) + require.NoError(t, err) + + // BulkLoad via the passthrough iterator. + src := &stubBulkSourceDL{rels: []tuple.Relationship{ + tuple.MustParse("resource:bulk1#viewer@user:alice"), + tuple.MustParse("resource:bulk2#viewer@user:bob"), + }} + loaded, err := rwt.BulkLoad(ctx, src) + require.NoError(t, err) + require.Equal(t, uint64(2), loaded) + + // Counter registration + store + unregister. + counterName := "my_counter" + require.NoError(t, rwt.RegisterCounter(ctx, counterName, &core.RelationshipFilter{ + ResourceType: "resource", + })) + + // StoreCounterValue requires a revision; use a zero/no revision. + require.NoError(t, rwt.StoreCounterValue(ctx, counterName, 42, datastore.NoRevision)) + + require.NoError(t, rwt.UnregisterCounter(ctx, counterName)) + + // LegacySchemaWriter surface: exercise each method. memdb's legacy writer + // accepts empty inputs. + lw := rwt.LegacySchemaWriter() + require.NotNil(t, lw) + require.NoError(t, lw.LegacyWriteCaveats(ctx, nil)) + require.NoError(t, lw.LegacyWriteNamespaces(ctx)) + require.NoError(t, lw.LegacyDeleteCaveats(ctx, nil)) + require.NoError(t, lw.LegacyDeleteNamespaces(ctx, nil, datastore.DeleteNamespacesOnly)) + + // WriteSchema with no definitions should succeed via the legacy path. + require.NoError(t, rwt.WriteSchema(ctx, nil, "", nil)) + + return nil + }) + require.NoError(t, err) +} + +// stubBulkSourceDL is a minimal BulkWriteRelationshipSource for tests. +type stubBulkSourceDL struct { + rels []tuple.Relationship + idx int +} + +func (s *stubBulkSourceDL) Next(_ context.Context) (*tuple.Relationship, error) { + if s.idx >= len(s.rels) { + return nil, nil + } + rel := s.rels[s.idx] + s.idx++ + return &rel, nil +} diff --git a/pkg/query/arrow_test.go b/pkg/query/arrow_test.go index 388a5de09d..c888f202e6 100644 --- a/pkg/query/arrow_test.go +++ b/pkg/query/arrow_test.go @@ -8,6 +8,22 @@ import ( core "github.com/authzed/spicedb/pkg/proto/core/v1" ) +func TestNewSchemaArrow(t *testing.T) { + left := NewEmptyFixedIterator() + right := NewEmptyFixedIterator() + + arrow := NewSchemaArrow(left, right) + require.NotNil(t, arrow) + require.Same(t, left, arrow.left) + require.Same(t, right, arrow.right) + require.Equal(t, leftToRight, arrow.direction) + require.True(t, arrow.isSchemaArrow, "NewSchemaArrow should mark isSchemaArrow=true") + + // NewArrowIterator should, by contrast, mark isSchemaArrow=false. + subrelArrow := NewArrowIterator(left, right) + require.False(t, subrelArrow.isSchemaArrow) +} + // testArrowBothDirections runs the same test with both arrow directions func testArrowBothDirections(t *testing.T, name string, testFn func(t *testing.T, direction arrowDirection)) { t.Run(name+"_LTR", func(t *testing.T) { diff --git a/pkg/query/datastore_test.go b/pkg/query/datastore_test.go index 7f8d3d718a..aaed0c5b54 100644 --- a/pkg/query/datastore_test.go +++ b/pkg/query/datastore_test.go @@ -345,3 +345,14 @@ func TestDatastoreIterator_Types(t *testing.T) { require.Equal(tuple.Ellipsis, subjectTypes[0].Subrelation) // Ellipsis is preserved as-is }) } + +func TestDatastoreIterator_ReplaceSubiteratorsPanics(t *testing.T) { + baseRel := createTestBaseRelation("document", "viewer", "user", "") + iter := NewDatastoreIterator(baseRel) + + require.Empty(t, iter.Subiterators(), "DatastoreIterator is a leaf and has no subiterators") + + require.Panics(t, func() { + _, _ = iter.ReplaceSubiterators([]Iterator{NewEmptyFixedIterator()}) + }) +} diff --git a/pkg/schema/errors_test.go b/pkg/schema/errors_test.go new file mode 100644 index 0000000000..c6d62f1751 --- /dev/null +++ b/pkg/schema/errors_test.go @@ -0,0 +1,288 @@ +package schema + +import ( + "errors" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +func TestDefinitionNotFoundError(t *testing.T) { + err := NewDefinitionNotFoundErr("document") + + var dnfe DefinitionNotFoundError + require.ErrorAs(t, err, &dnfe) + require.Equal(t, "document", dnfe.NotFoundNamespaceName()) + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{"definition_name": "document"}, dnfe.DetailsMetadata()) + + require.NotPanics(t, func() { + dnfe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestRelationNotFoundError(t *testing.T) { + err := NewRelationNotFoundErr("document", "viewer") + + var rnfe RelationNotFoundError + require.ErrorAs(t, err, &rnfe) + require.Equal(t, "document", rnfe.NamespaceName()) + require.Equal(t, "viewer", rnfe.NotFoundRelationName()) + require.Contains(t, err.Error(), "viewer") + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_or_permission_name": "viewer", + }, rnfe.DetailsMetadata()) + + require.NotPanics(t, func() { + rnfe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestCaveatNotFoundError(t *testing.T) { + err := NewCaveatNotFoundErr("is_weekend") + + var cnfe CaveatNotFoundError + require.ErrorAs(t, err, &cnfe) + require.Equal(t, "is_weekend", cnfe.CaveatName()) + require.Contains(t, err.Error(), "is_weekend") + + require.Equal(t, map[string]string{"caveat_name": "is_weekend"}, cnfe.DetailsMetadata()) + + require.NotPanics(t, func() { + cnfe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestDuplicateRelationError(t *testing.T) { + err := NewDuplicateRelationError("document", "viewer") + + var dre DuplicateRelationError + require.ErrorAs(t, err, &dre) + require.Contains(t, err.Error(), "duplicate") + require.Contains(t, err.Error(), "viewer") + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_or_permission_name": "viewer", + }, dre.DetailsMetadata()) + + require.NotPanics(t, func() { + dre.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestDuplicateAllowedRelationError(t *testing.T) { + err := NewDuplicateAllowedRelationErr("document", "viewer", "user") + + var dare DuplicateAllowedRelationError + require.ErrorAs(t, err, &dare) + require.Contains(t, err.Error(), "duplicate") + require.Contains(t, err.Error(), "user") + require.Contains(t, err.Error(), "viewer") + require.Contains(t, err.Error(), "document") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_name": "viewer", + "allowed_relation": "user", + }, dare.DetailsMetadata()) + + require.NotPanics(t, func() { + dare.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestPermissionUsedOnLeftOfArrowError(t *testing.T) { + err := NewPermissionUsedOnLeftOfArrowErr("document", "view", "edit") + + var pulae PermissionUsedOnLeftOfArrowError + require.ErrorAs(t, err, &pulae) + require.Contains(t, err.Error(), "view") + require.Contains(t, err.Error(), "edit") + require.Contains(t, err.Error(), "document") + require.Contains(t, err.Error(), "left hand side of an arrow") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "permission_name": "view", + "used_permission_name": "edit", + }, pulae.DetailsMetadata()) + + require.NotPanics(t, func() { + pulae.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestWildcardUsedInArrowError(t *testing.T) { + err := NewWildcardUsedInArrowErr("document", "view", "parent", "user", "member") + + var wuiae WildcardUsedInArrowError + require.ErrorAs(t, err, &wuiae) + require.Contains(t, err.Error(), "wildcard") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "permission_name": "view", + "accessed_relation_name": "parent", + }, wuiae.DetailsMetadata()) + + require.NotPanics(t, func() { + wuiae.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestMissingAllowedRelationsError(t *testing.T) { + err := NewMissingAllowedRelationsErr("document", "viewer") + + var mare MissingAllowedRelationsError + require.ErrorAs(t, err, &mare) + require.Contains(t, err.Error(), "viewer") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_name": "viewer", + }, mare.DetailsMetadata()) + + require.NotPanics(t, func() { + mare.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestTransitiveWildcardError(t *testing.T) { + err := NewTransitiveWildcardErr("document", "viewer", "group", "member", "user", "owner") + + var twe TransitiveWildcardError + require.ErrorAs(t, err, &twe) + require.Contains(t, err.Error(), "wildcard") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "relation_name": "viewer", + }, twe.DetailsMetadata()) + + require.NotPanics(t, func() { + twe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestPermissionsCycleError(t *testing.T) { + err := NewPermissionsCycleErr("document", []string{"view", "edit", "admin"}) + + var pce PermissionsCycleError + require.ErrorAs(t, err, &pce) + require.Contains(t, err.Error(), "cycle") + require.Contains(t, err.Error(), "view") + require.Contains(t, err.Error(), "edit") + require.Contains(t, err.Error(), "admin") + + require.Equal(t, map[string]string{ + "definition_name": "document", + "permission_names": "view,edit,admin", + }, pce.DetailsMetadata()) + + require.NotPanics(t, func() { + pce.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestUnusedCaveatParameterError(t *testing.T) { + err := NewUnusedCaveatParameterErr("is_weekend", "day") + + var ucpe UnusedCaveatParameterError + require.ErrorAs(t, err, &ucpe) + require.Contains(t, err.Error(), "day") + require.Contains(t, err.Error(), "is_weekend") + + require.Equal(t, map[string]string{ + "caveat_name": "is_weekend", + "parameter_name": "day", + }, ucpe.DetailsMetadata()) + + require.NotPanics(t, func() { + ucpe.MarshalZerologObject(zerolog.Dict()) + }) +} + +func TestKnownEmptyFilterError(t *testing.T) { + err := NewKnownEmptyFilterErr("no matching relations") + + var kefe KnownEmptyFilterError + require.ErrorAs(t, err, &kefe) + require.Equal(t, "no matching relations", kefe.WarningMessage()) + require.Contains(t, err.Error(), "no matching relations") + + require.Equal(t, map[string]string{ + "warning_message": "no matching relations", + }, kefe.DetailsMetadata()) + + require.NotPanics(t, func() { + kefe.MarshalZerologObject(zerolog.Dict()) + }) + + // KnownEmptyFilterError is wrapped in a TypeError. + var te TypeError + require.ErrorAs(t, err, &te) +} + +func TestTypeErrorUnwrap(t *testing.T) { + base := errors.New("underlying") + te := TypeError{error: base} + + require.Equal(t, base, te.Unwrap()) + require.ErrorIs(t, te, base) +} + +func TestAsTypeErrorNil(t *testing.T) { + require.NoError(t, asTypeError(nil)) +} + +func TestAsTypeErrorDoubleWrap(t *testing.T) { + inner := errors.New("inner") + once := asTypeError(inner) + twice := asTypeError(once) + + // Wrapping an already-TypeError should return the same error, not double-wrap. + require.Equal(t, once, twice) +} + +func TestNewTypeWithSourceErrorNilPosition(t *testing.T) { + // A core.Relation has no source position set when zero-valued. + withSource := &core.Relation{} + require.Nil(t, withSource.GetSourcePosition()) + + err := NewTypeWithSourceError(errors.New("bad"), withSource, "relation foo") + require.Error(t, err) + + // The result should be wrapped as a TypeError. + var te TypeError + require.ErrorAs(t, err, &te) +} + +func TestNewTypeWithSourceErrorWithPosition(t *testing.T) { + withSource := &core.Relation{ + SourcePosition: &core.SourcePosition{ + ZeroIndexedLineNumber: 4, + ZeroIndexedColumnPosition: 7, + }, + } + + err := NewTypeWithSourceError(errors.New("bad"), withSource, "relation foo") + require.Error(t, err) + + var te TypeError + require.ErrorAs(t, err, &te) +} + +func TestBacktickNames(t *testing.T) { + require.Empty(t, backtickNames(nil)) + require.Equal(t, []string{"`a`"}, backtickNames([]string{"a"})) + require.Equal(t, []string{"`a`", "`b`", "`c`"}, backtickNames([]string{"a", "b", "c"})) +} diff --git a/pkg/schema/v2/flatten_test.go b/pkg/schema/v2/flatten_test.go index 091f993af2..942524a778 100644 --- a/pkg/schema/v2/flatten_test.go +++ b/pkg/schema/v2/flatten_test.go @@ -11,6 +11,49 @@ import ( "github.com/authzed/spicedb/pkg/schemadsl/input" ) +func TestWalkFlattenedSchema(t *testing.T) { + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("test"), + SchemaString: `definition user {} +definition document { + relation viewer: user + permission view = viewer +}`, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(t, err) + + schema, err := BuildSchemaFromCompiledSchema(*compiled) + require.NoError(t, err) + + resolved, err := ResolveSchema(schema) + require.NoError(t, err) + + flattened, err := FlattenSchema(resolved, FlattenSeparatorDoubleUnderscore) + require.NoError(t, err) + require.NotNil(t, flattened) + require.Same(t, flattened.resolvedSchema, flattened.ResolvedSchema()) + + visitor := &testVisitor{} + _, err = WalkFlattenedSchema(flattened, visitor, struct{}{}) + require.NoError(t, err) + + require.NotEmpty(t, visitor.schemas) + require.GreaterOrEqual(t, len(visitor.definitions), 2) +} + +func TestWalkFlattenedSchema_Nil(t *testing.T) { + visitor := &testVisitor{} + _, err := WalkFlattenedSchema[struct{}](nil, visitor, struct{}{}) + require.NoError(t, err) + require.Empty(t, visitor.schemas) +} + +func TestFlattenSchemaWithOptions_NilInput(t *testing.T) { + _, err := FlattenSchemaWithOptions(nil, FlattenOptions{Separator: FlattenSeparatorDollar}) + require.Error(t, err) + require.ErrorContains(t, err, "cannot flatten nil resolved schema") +} + func TestFlattenSchema(t *testing.T) { tests := []struct { name string diff --git a/pkg/schema/v2/walk_public_wrappers_test.go b/pkg/schema/v2/walk_public_wrappers_test.go new file mode 100644 index 0000000000..02b53202f4 --- /dev/null +++ b/pkg/schema/v2/walk_public_wrappers_test.go @@ -0,0 +1,106 @@ +package schema + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +const walkWrapperSchema = `caveat is_weekend(day int) { day == 6 || day == 7 } + +definition user {} + +definition document { + relation viewer: user + relation editor: user with is_weekend + permission view = viewer + editor +}` + +func compileWalkWrapperSchema(t *testing.T) *Schema { + t.Helper() + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("test"), + SchemaString: walkWrapperSchema, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(t, err) + + schema, err := BuildSchemaFromCompiledSchema(*compiled) + require.NoError(t, err) + return schema +} + +func TestWalkCaveat_PublicWrapper(t *testing.T) { + schema := compileWalkWrapperSchema(t) + caveat, ok := schema.caveats["is_weekend"] + require.True(t, ok) + + visitor := &testVisitor{} + _, err := WalkCaveat(caveat, visitor, struct{}{}) + require.NoError(t, err) + require.Len(t, visitor.caveats, 1) + require.Same(t, caveat, visitor.caveats[0]) +} + +func TestWalkCaveatWithOptions_PublicWrapper(t *testing.T) { + schema := compileWalkWrapperSchema(t) + caveat := schema.caveats["is_weekend"] + + visitor := &testVisitor{} + opts, err := NewWalkOptions().WithStrategy(WalkPostOrder).Build() + require.NoError(t, err) + + _, err = WalkCaveatWithOptions(caveat, visitor, struct{}{}, opts) + require.NoError(t, err) + require.Len(t, visitor.caveats, 1) +} + +func TestWalkBaseRelation_PublicWrapper(t *testing.T) { + schema := compileWalkWrapperSchema(t) + def := schema.definitions["document"] + rel := def.relations["viewer"] + require.NotEmpty(t, rel.baseRelations) + + visitor := &testVisitor{} + _, err := WalkBaseRelation(rel.baseRelations[0], visitor, struct{}{}) + require.NoError(t, err) + require.Len(t, visitor.baseRelations, 1) +} + +func TestWalkBaseRelationWithOptions_PublicWrapper(t *testing.T) { + schema := compileWalkWrapperSchema(t) + rel := schema.definitions["document"].relations["viewer"] + + visitor := &testVisitor{} + opts, err := NewWalkOptions().Build() + require.NoError(t, err) + + _, err = WalkBaseRelationWithOptions(rel.baseRelations[0], visitor, struct{}{}, opts) + require.NoError(t, err) + require.Len(t, visitor.baseRelations, 1) +} + +// TestWalkOptions_WithStrategyValueReceiver covers the value-receiver form of +// WithStrategy on WalkOptions (distinct from the *WalkOptionsBuilder method). +func TestWalkOptions_WithStrategyValueReceiver(t *testing.T) { + original := defaultWalkOptions() + require.Equal(t, WalkPreOrder, original.strategy) + + updated := original.WithStrategy(WalkPostOrder) + require.Equal(t, WalkPostOrder, updated.strategy) + // Original should remain unchanged (value receiver). + require.Equal(t, WalkPreOrder, original.strategy) +} + +func TestWalkOptionsBuilder_MustBuildPanicsOnError(t *testing.T) { + // Passing nil to WithTraverseArrowTargets triggers an error inside ResolveSchema. + builder := NewWalkOptions(). + WithStrategy(WalkPostOrder). + WithTraverseArrowTargets(nil) + + require.Panics(t, func() { + _ = builder.MustBuild() + }) +}