diff --git a/router/core/cache_warmup_cdn.go b/router/core/cache_warmup_cdn.go index 526e618d87..a5fff8dbce 100644 --- a/router/core/cache_warmup_cdn.go +++ b/router/core/cache_warmup_cdn.go @@ -24,6 +24,7 @@ var _ CacheWarmupSource = (*CDNSource)(nil) type CDNSource struct { cdnURL *url.URL + cdnFallbackURL *url.URL authenticationToken string // federatedGraphID is the ID of the federated graph that was obtained // from the token, already url-escaped @@ -34,12 +35,20 @@ type CDNSource struct { httpClient *http.Client } -func NewCDNSource(endpoint, token string, logger *zap.Logger) (*CDNSource, error) { +func NewCDNSource(endpoint, fallbackEndpoint, token string, logger *zap.Logger) (*CDNSource, error) { u, err := url.Parse(endpoint) if err != nil { return nil, err } + var fu *url.URL = nil + if fallbackEndpoint != "" { + fu, err = url.Parse(fallbackEndpoint) + if err != nil { + return nil, err + } + } + claims, err := jwt.ExtractFederatedGraphTokenClaims(token) if err != nil { return nil, err @@ -51,6 +60,7 @@ func NewCDNSource(endpoint, token string, logger *zap.Logger) (*CDNSource, error return &CDNSource{ cdnURL: u, + cdnFallbackURL: fu, authenticationToken: token, federatedGraphID: claims.FederatedGraphID, organizationID: claims.OrganizationID, @@ -62,14 +72,52 @@ func (c *CDNSource) LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.O span := trace.SpanFromContext(ctx) defer span.End() - operationsPath := fmt.Sprintf("/%s/%s/cache_warmup/operations.json", c.organizationID, c.federatedGraphID) + resp, body, err := c.fetchOperationsJSON(ctx, log, c.cdnURL) + + if err != nil && c.cdnFallbackURL != nil && httpclient.IsCDNFallbackEligible(resp, err) { + primaryErr := err + log.Warn("Primary CDN failed, attempting fallback CDN", + zap.Error(err), + zap.String("fallback_url", c.cdnFallbackURL.String()), + ) + span.AddEvent("cdn.fallback", trace.WithAttributes( + semconv.HTTPURL(c.cdnFallbackURL.String()), + )) + _, body, err = c.fetchOperationsJSON(ctx, log, 
c.cdnFallbackURL) + if err != nil { + return nil, fmt.Errorf("primary CDN failed: %w; fallback CDN also failed: %v", primaryErr, err) + } + if body == nil { + return nil, fmt.Errorf("primary CDN failed: %w; fallback CDN returned no cache warmup config", primaryErr) + } + } + + if err != nil { + return nil, err + } + if body == nil { + return nil, nil + } + + var warmupOperations nodev1.CacheWarmerOperations + unmarshalOpts := protojson.UnmarshalOptions{DiscardUnknown: true} + if err := unmarshalOpts.Unmarshal(body, &warmupOperations); err != nil { + return nil, err + } + + return warmupOperations.GetOperations(), nil +} - operationURL := c.cdnURL.ResolveReference(&url.URL{Path: operationsPath}) +func (c *CDNSource) fetchOperationsJSON(ctx context.Context, log *zap.Logger, baseURL *url.URL) (*http.Response, []byte, error) { + span := trace.SpanFromContext(ctx) + + operationsPath := fmt.Sprintf("/%s/%s/cache_warmup/operations.json", c.organizationID, c.federatedGraphID) + operationURL := baseURL.ResolveReference(&url.URL{Path: operationsPath}) log.Debug("Loading cache warmup config", zap.String("url", operationURL.String())) req, err := http.NewRequestWithContext(ctx, "GET", operationURL.String(), nil) if err != nil { - return nil, err + return nil, nil, err } span.SetAttributes( @@ -84,7 +132,7 @@ func (c *CDNSource) LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.O resp, err := c.httpClient.Do(req) if err != nil { - return nil, err + return nil, nil, err } defer func() { _ = resp.Body.Close() }() @@ -95,29 +143,23 @@ func (c *CDNSource) LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.O if resp.StatusCode == http.StatusNotFound { log.Debug("Cache warmup config not found", zap.String("url", operationURL.String())) - return nil, nil + return resp, nil, nil } if resp.StatusCode == http.StatusUnauthorized { - return nil, errors.New("could not authenticate against CDN") + return resp, nil, errors.New("could not authenticate against CDN") } if 
resp.StatusCode == http.StatusBadRequest { - return nil, errors.New("bad request") + return resp, nil, errors.New("bad request") } - return nil, fmt.Errorf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode) + return resp, nil, fmt.Errorf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode) } body, err := c.readResponse(resp) if err != nil { - return nil, err + return resp, nil, err } - var warmupOperations nodev1.CacheWarmerOperations - unmarshalOpts := protojson.UnmarshalOptions{DiscardUnknown: true} - if err := unmarshalOpts.Unmarshal(body, &warmupOperations); err != nil { - return nil, err - } - - return warmupOperations.GetOperations(), nil + return resp, body, nil } func (c *CDNSource) readResponse(resp *http.Response) ([]byte, error) { diff --git a/router/core/cache_warmup_cdn_test.go b/router/core/cache_warmup_cdn_test.go new file mode 100644 index 0000000000..2ba4690ae5 --- /dev/null +++ b/router/core/cache_warmup_cdn_test.go @@ -0,0 +1,267 @@ +package core + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// validCacheWarmupJSON is a valid CacheWarmerOperations protobuf JSON with one operation. 
+const validCacheWarmupJSON = `{ + "operations": [ + { + "request": { + "query": "query { hello }" + } + } + ] +}` + +func newTestCDNSource(primaryURL string, fallbackURL string) *CDNSource { + u, _ := url.Parse(primaryURL) + var fu *url.URL + if fallbackURL != "" { + fu, _ = url.Parse(fallbackURL) + } + return &CDNSource{ + cdnURL: u, + cdnFallbackURL: fu, + authenticationToken: "test-token", + federatedGraphID: "test-graph", + organizationID: "test-org", + httpClient: http.DefaultClient, + } +} + +func TestCDNSource_LoadItems(t *testing.T) { + t.Parallel() + + t.Run("primary 200 without fallback", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer primary.Close() + + source := newTestCDNSource(primary.URL, "") + items, err := source.LoadItems(context.Background(), zap.NewNop()) + require.NoError(t, err) + require.Len(t, items, 1) + assert.Equal(t, "query { hello }", items[0].Request.Query) + }) + + t.Run("primary 200 with fallback configured", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + items, err := source.LoadItems(context.Background(), zap.NewNop()) + require.NoError(t, err) + require.Len(t, items, 1) + assert.False(t, fallbackCalled.Load(), "fallback should not be called when primary succeeds") + }) + + t.Run("primary 404 with fallback does not trigger fallback", func(t 
*testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + items, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.NoError(t, err) + assert.Nil(t, items) + assert.False(t, fallbackCalled.Load(), "fallback should not be called on 404") + }) + + t.Run("primary 401 with fallback does not trigger fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + _, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "authenticate") + assert.False(t, fallbackCalled.Load(), "fallback should not be called on 401") + }) + + t.Run("primary 503 without fallback returns error", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + source := newTestCDNSource(primary.URL, "") + _, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "503") + }) + + 
t.Run("primary 503 with fallback 200 returns items from fallback", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + items, err := source.LoadItems(context.Background(), zap.NewNop()) + require.NoError(t, err) + require.Len(t, items, 1) + assert.Equal(t, "query { hello }", items[0].Request.Query) + }) + + t.Run("primary 429 with fallback 200 returns items from fallback", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + items, err := source.LoadItems(context.Background(), zap.NewNop()) + require.NoError(t, err) + require.Len(t, items, 1) + }) + + t.Run("primary 503 fallback 503 returns primary error", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + _, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.Error(t, err) + 
assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "fallback CDN also failed") + }) + + t.Run("primary 503 fallback 404 preserves primary error", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + _, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "503") + }) + + t.Run("primary 503 fallback 401 preserves primary error", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + _, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "503") + }) + + t.Run("primary network error with fallback 200 returns items from fallback", func(t *testing.T) { + t.Parallel() + // Use an immediately-closed server to simulate network error + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validCacheWarmupJSON)) + })) + defer 
fallback.Close() + + source := newTestCDNSource(primary.URL, fallback.URL) + items, err := source.LoadItems(context.Background(), zap.NewNop()) + require.NoError(t, err) + require.Len(t, items, 1) + }) + + t.Run("primary network error without fallback returns error", func(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + source := newTestCDNSource(primary.URL, "") + _, err := source.LoadItems(context.Background(), zap.NewNop()) + assert.Error(t, err) + }) +} diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 23812b8fe4..41deb0da7c 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1460,7 +1460,7 @@ func (s *graphServer) buildGraphMux( warmupConfig.FallbackSource = NewPlanSource(opts.ReloadPersistentState.inMemoryPlanCacheFallback.getPlanCacheForFF(opts.FeatureFlagName)) opts.ReloadPersistentState.inMemoryPlanCacheFallback.setPlanCacheForFF(opts.FeatureFlagName, gm.planFallbackCache) } - cdnSource, err := NewCDNSource(s.cdnConfig.URL, s.graphApiToken, s.logger) + cdnSource, err := NewCDNSource(s.cdnConfig.URL, s.cdnConfig.FallbackURL, s.graphApiToken, s.logger) if err != nil { return nil, fmt.Errorf("failed to create cdn source: %w", err) } diff --git a/router/core/init_config_poller.go b/router/core/init_config_poller.go index b8358ccae0..3f0776328a 100644 --- a/router/core/init_config_poller.go +++ b/router/core/init_config_poller.go @@ -29,6 +29,7 @@ func getConfigClient(r *Router, registry *ProviderRegistry, providerID string, i Logger: r.logger, SignatureKey: r.routerConfigPollerConfig.GraphSignKey, RouterCompatibilityVersion: execution_config.RouterCompatibilityVersionThreshold, + FallbackEndpoint: provider.FallbackURL, }) if err != nil { return nil, err @@ -103,6 +104,7 @@ func getConfigClient(r *Router, registry *ProviderRegistry, providerID string, i Logger: r.logger, SignatureKey: 
r.routerConfigPollerConfig.GraphSignKey, RouterCompatibilityVersion: execution_config.RouterCompatibilityVersionThreshold, + FallbackEndpoint: r.cdnConfig.FallbackURL, }) if err != nil { return nil, err diff --git a/router/core/router.go b/router/core/router.go index cb173417d3..9939fac122 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -1214,7 +1214,8 @@ func (r *Router) buildPersistedOpsClient(registry *ProviderRegistry) (persistedo } c, err := cdn.NewClient(provider.URL, r.graphApiToken, cdn.Options{ - Logger: r.logger, + Logger: r.logger, + FallbackEndpoint: provider.FallbackURL, }) if err != nil { return nil, nil, fmt.Errorf("failed to create CDN client: %w", err) @@ -1267,7 +1268,8 @@ func (r *Router) buildPersistedOpsClient(registry *ProviderRegistry) (persistedo } c, err := cdn.NewClient(r.cdnConfig.URL, r.graphApiToken, cdn.Options{ - Logger: r.logger, + Logger: r.logger, + FallbackEndpoint: r.cdnConfig.FallbackURL, }) if err != nil { return nil, nil, fmt.Errorf("failed to create CDN client: %w", err) @@ -1373,7 +1375,7 @@ func (r *Router) buildManifestStore(ctx context.Context, registry *ProviderRegis return nil, errors.New("graph token is required for PQL manifest") } - fetcher, err := pqlmanifest.NewFetcher(r.cdnConfig.URL, r.graphApiToken, r.logger) + fetcher, err := pqlmanifest.NewFetcher(r.cdnConfig.URL, r.cdnConfig.FallbackURL, r.graphApiToken, r.logger) if err != nil { return nil, fmt.Errorf("failed to create PQL manifest fetcher: %w", err) } diff --git a/router/internal/httpclient/fallback.go b/router/internal/httpclient/fallback.go new file mode 100644 index 0000000000..123d2b7e45 --- /dev/null +++ b/router/internal/httpclient/fallback.go @@ -0,0 +1,34 @@ +package httpclient + +import ( + "context" + "errors" + "net/http" +) + +// IsCDNFallbackEligible returns true if the error or response indicates a +// server-side failure that warrants retrying against a fallback CDN URL. 
+// It returns true for HTTP 5xx, 429, and network errors. +// It returns false for client errors (401, 400, 404), context cancellation, +// and context deadline exceeded. +func IsCDNFallbackEligible(resp *http.Response, err error) bool { + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + // If we have a response, use the status code to decide + if resp != nil { + return isServerErrorStatus(resp.StatusCode) + } + // No response means network error → fallback + return true + } + if resp != nil { + return isServerErrorStatus(resp.StatusCode) + } + return false +} + +func isServerErrorStatus(statusCode int) bool { + return statusCode >= 500 || statusCode == http.StatusTooManyRequests +} diff --git a/router/internal/httpclient/fallback_test.go b/router/internal/httpclient/fallback_test.go new file mode 100644 index 0000000000..71848195d4 --- /dev/null +++ b/router/internal/httpclient/fallback_test.go @@ -0,0 +1,121 @@ +package httpclient + +import ( + "context" + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsCDNFallbackEligible(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + resp *http.Response + err error + expected bool + }{ + { + name: "500 status code", + resp: &http.Response{StatusCode: http.StatusInternalServerError}, + err: errors.New("unexpected status code"), + expected: true, + }, + { + name: "502 status code", + resp: &http.Response{StatusCode: http.StatusBadGateway}, + err: errors.New("unexpected status code"), + expected: true, + }, + { + name: "503 status code", + resp: &http.Response{StatusCode: http.StatusServiceUnavailable}, + err: errors.New("unexpected status code"), + expected: true, + }, + { + name: "429 status code", + resp: &http.Response{StatusCode: http.StatusTooManyRequests}, + err: errors.New("unexpected status code"), + expected: true, + }, + { + name: "5xx response without error", + resp: 
&http.Response{StatusCode: http.StatusInternalServerError}, + err: nil, + expected: true, + }, + { + name: "200 status code", + resp: &http.Response{StatusCode: http.StatusOK}, + err: nil, + expected: false, + }, + { + name: "401 status code", + resp: &http.Response{StatusCode: http.StatusUnauthorized}, + err: errors.New("unauthorized"), + expected: false, + }, + { + name: "400 status code", + resp: &http.Response{StatusCode: http.StatusBadRequest}, + err: errors.New("bad request"), + expected: false, + }, + { + name: "404 status code", + resp: &http.Response{StatusCode: http.StatusNotFound}, + err: errors.New("not found"), + expected: false, + }, + { + name: "304 status code", + resp: &http.Response{StatusCode: http.StatusNotModified}, + err: errors.New("not modified"), + expected: false, + }, + { + name: "network error no response", + resp: nil, + err: errors.New("connection refused"), + expected: true, + }, + { + name: "context canceled", + resp: nil, + err: context.Canceled, + expected: false, + }, + { + name: "context deadline exceeded", + resp: nil, + err: context.DeadlineExceeded, + expected: false, + }, + { + name: "wrapped context canceled", + resp: nil, + err: fmt.Errorf("request failed: %w", context.Canceled), + expected: false, + }, + { + name: "nil response nil error", + resp: nil, + err: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := IsCDNFallbackEligible(tt.resp, tt.err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/router/internal/persistedoperation/operationstorage/cdn/client.go b/router/internal/persistedoperation/operationstorage/cdn/client.go index 34f69df234..32655fce72 100644 --- a/router/internal/persistedoperation/operationstorage/cdn/client.go +++ b/router/internal/persistedoperation/operationstorage/cdn/client.go @@ -23,7 +23,8 @@ import ( ) type Options struct { - Logger *zap.Logger + Logger *zap.Logger + FallbackEndpoint string } // 
Deprecated: The CDN-based persisted operation Client is deprecated. @@ -33,6 +34,7 @@ var _ persistedoperation.StorageClient = (*Client)(nil) type Client struct { cdnURL *url.URL + cdnFallbackURL *url.URL authenticationToken string // federatedGraphID is the ID of the federated graph that was obtained // from the token, already url-escaped @@ -53,6 +55,14 @@ func NewClient(endpoint string, token string, opts Options) (*Client, error) { return nil, fmt.Errorf("invalid CDN URL %q: %w", endpoint, err) } + var fu *url.URL + if opts.FallbackEndpoint != "" { + fu, err = url.Parse(opts.FallbackEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid CDN fallback URL %q: %w", opts.FallbackEndpoint, err) + } + } + if opts.Logger == nil { opts.Logger = zap.NewNop() } @@ -67,13 +77,14 @@ func NewClient(endpoint string, token string, opts Options) (*Client, error) { zap.String("url", endpoint), ) - fetcher, err := pqlmanifest.NewFetcher(endpoint, token, logger) + fetcher, err := pqlmanifest.NewFetcher(endpoint, opts.FallbackEndpoint, token, logger) if err != nil { return nil, fmt.Errorf("failed to create manifest fetcher: %w", err) } return &Client{ cdnURL: u, + cdnFallbackURL: fu, authenticationToken: token, federatedGraphID: url.PathEscape(claims.FederatedGraphID), organizationID: url.PathEscape(claims.OrganizationID), @@ -93,7 +104,40 @@ func (cdn *Client) PersistedOperation(ctx context.Context, clientName string, sh } func (cdn *Client) persistedOperation(ctx context.Context, clientName string, sha256Hash string) ([]byte, error) { + span := trace.SpanFromContext(ctx) + + resp, body, err := cdn.doPersistedOperation(ctx, clientName, sha256Hash, cdn.cdnURL) + + if err != nil && cdn.cdnFallbackURL != nil && httpclient.IsCDNFallbackEligible(resp, err) { + primaryErr := err + cdn.logger.Warn("Primary CDN failed, attempting fallback CDN for persisted operation", + zap.String("fallback_url", cdn.cdnFallbackURL.String()), + ) + span.AddEvent("cdn.fallback", 
trace.WithAttributes( + semconv.HTTPURL(cdn.cdnFallbackURL.String()), + )) + var fallbackErr error + _, body, fallbackErr = cdn.doPersistedOperation(ctx, clientName, sha256Hash, cdn.cdnFallbackURL) + if fallbackErr != nil { + return nil, fmt.Errorf("primary CDN failed: %w; fallback CDN also failed: %v", primaryErr, fallbackErr) + } + span.SetStatus(codes.Ok, "") + err = nil + } + + if err != nil { + return nil, err + } + + var po persistedoperation.PersistedOperation + if err := json.Unmarshal(body, &po); err != nil { + return nil, err + } + return []byte(po.Body), nil +} + +func (cdn *Client) doPersistedOperation(ctx context.Context, clientName string, sha256Hash string, baseURL *url.URL) (*http.Response, []byte, error) { span := trace.SpanFromContext(ctx) operationPath := fmt.Sprintf("/%s/%s/operations/%s/%s.json", @@ -101,11 +145,11 @@ func (cdn *Client) persistedOperation(ctx context.Context, clientName string, sh cdn.federatedGraphID, url.PathEscape(clientName), url.PathEscape(sha256Hash)) - operationURL := cdn.cdnURL.ResolveReference(&url.URL{Path: operationPath}) + operationURL := baseURL.ResolveReference(&url.URL{Path: operationPath}) req, err := http.NewRequestWithContext(ctx, "GET", operationURL.String(), nil) if err != nil { - return nil, err + return nil, nil, err } span.SetAttributes( @@ -118,7 +162,7 @@ func (cdn *Client) persistedOperation(ctx context.Context, clientName string, sh resp, err := cdn.httpClient.Do(req) if err != nil { - return nil, err + return nil, nil, err } defer func() { _ = resp.Body.Close() @@ -130,38 +174,32 @@ func (cdn *Client) persistedOperation(ctx context.Context, clientName string, sh span.SetStatus(codes.Error, fmt.Sprintf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode)) if resp.StatusCode == http.StatusNotFound { - return nil, &persistedoperation.PersistentOperationNotFoundError{ + return resp, nil, &persistedoperation.PersistentOperationNotFoundError{ ClientName: clientName, 
Sha256Hash: sha256Hash, } } if resp.StatusCode == http.StatusUnauthorized { - return nil, errors.New("could not authenticate against CDN") + return resp, nil, errors.New("could not authenticate against CDN") } if resp.StatusCode == http.StatusBadRequest { - return nil, errors.New("bad request") + return resp, nil, errors.New("bad request") } - return nil, fmt.Errorf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode) + return resp, nil, fmt.Errorf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode) } reader, cleanup, err := gzipAwareReader(resp) if err != nil { - return nil, err + return resp, nil, err } defer cleanup() body, err := io.ReadAll(reader) if err != nil { - return nil, errors.New("could not read the response body. " + err.Error()) + return resp, nil, errors.New("could not read the response body. " + err.Error()) } - var po persistedoperation.PersistedOperation - err = json.Unmarshal(body, &po) - if err != nil { - return nil, err - } - - return []byte(po.Body), nil + return resp, body, nil } // setCDNHeaders sets the common headers for CDN requests. diff --git a/router/internal/persistedoperation/operationstorage/cdn/client_test.go b/router/internal/persistedoperation/operationstorage/cdn/client_test.go new file mode 100644 index 0000000000..a80f771594 --- /dev/null +++ b/router/internal/persistedoperation/operationstorage/cdn/client_test.go @@ -0,0 +1,284 @@ +package cdn + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/persistedoperation" + "go.uber.org/zap" +) + +// validPersistedOpJSON is a valid PersistedOperation JSON response. 
+var validPersistedOpJSON = mustMarshalPersistedOp(persistedoperation.PersistedOperation{ + Version: 1, + Body: "query { hello }", +}) + +func mustMarshalPersistedOp(po persistedoperation.PersistedOperation) []byte { + data, err := json.Marshal(po) + if err != nil { + panic(err) + } + return data +} + +func newTestPersistedOpsClient(primaryURL, fallbackURL string) *Client { + u, _ := url.Parse(primaryURL) + var fu *url.URL + if fallbackURL != "" { + fu, _ = url.Parse(fallbackURL) + } + return &Client{ + cdnURL: u, + cdnFallbackURL: fu, + authenticationToken: "test-token", + federatedGraphID: "test-graph", + organizationID: "test-org", + httpClient: http.DefaultClient, + logger: zap.NewNop(), + } +} + +func TestPersistedOperation_Fallback(t *testing.T) { + t.Parallel() + + t.Run("primary 200 without fallback succeeds normally", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer primary.Close() + + c := newTestPersistedOpsClient(primary.URL, "") + body, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.NoError(t, err) + assert.Equal(t, "query { hello }", string(body)) + }) + + t.Run("primary 200 with fallback configured does not call fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + body, err := c.PersistedOperation(context.Background(), "client1", "abc123") + 
require.NoError(t, err) + assert.Equal(t, "query { hello }", string(body)) + assert.False(t, fallbackCalled.Load()) + }) + + t.Run("primary 404 with fallback does not trigger fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + var notFoundErr *persistedoperation.PersistentOperationNotFoundError + assert.ErrorAs(t, err, &notFoundErr) + assert.False(t, fallbackCalled.Load()) + }) + + t.Run("primary 401 with fallback does not trigger fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + assert.Contains(t, err.Error(), "authenticate") + assert.False(t, fallbackCalled.Load()) + }) + + t.Run("primary 503 without fallback returns error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) +
defer primary.Close() + + c := newTestPersistedOpsClient(primary.URL, "") + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + assert.Contains(t, err.Error(), "503") + }) + + t.Run("primary 503 with fallback 200 returns from fallback", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + body, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.NoError(t, err) + assert.Equal(t, "query { hello }", string(body)) + }) + + t.Run("primary 429 with fallback 200 returns from fallback", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + body, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.NoError(t, err) + assert.Equal(t, "query { hello }", string(body)) + }) + + t.Run("primary 503 fallback 503 returns primary error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.WriteHeader(http.StatusServiceUnavailable) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "fallback CDN also failed") + }) + + t.Run("primary 503 fallback 404 preserves primary error not NotFound", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "503") + // Must NOT be a PersistentOperationNotFoundError + var notFoundErr *persistedoperation.PersistentOperationNotFoundError + assert.False(t, errors.As(err, &notFoundErr), "should not return PersistentOperationNotFoundError when primary was 503") + }) + + t.Run("primary 503 fallback 401 preserves primary error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + 
assert.Contains(t, err.Error(), "503") + }) + + t.Run("primary network error without fallback returns error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + c := newTestPersistedOpsClient(primary.URL, "") + _, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.Error(t, err) + }) + + t.Run("primary network error with fallback 200 returns from fallback", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(validPersistedOpJSON) + })) + defer fallback.Close() + + c := newTestPersistedOpsClient(primary.URL, fallback.URL) + body, err := c.PersistedOperation(context.Background(), "client1", "abc123") + require.NoError(t, err) + assert.Equal(t, "query { hello }", string(body)) + }) +} diff --git a/router/internal/persistedoperation/pqlmanifest/fetcher.go b/router/internal/persistedoperation/pqlmanifest/fetcher.go index 956d08ca80..51e1dab156 100644 --- a/router/internal/persistedoperation/pqlmanifest/fetcher.go +++ b/router/internal/persistedoperation/pqlmanifest/fetcher.go @@ -17,6 +17,7 @@ import ( type Fetcher struct { cdnURL *url.URL + cdnFallbackURL *url.URL authenticationToken string // federatedGraphID is the ID of the federated graph that was obtained // from the token, already url-escaped @@ -30,12 +31,20 @@ type Fetcher struct { // NewFetcher creates a new manifest fetcher. It reuses JWT extraction and HTTP client // setup patterns from the CDN persisted operations client. 
-func NewFetcher(endpoint, token string, logger *zap.Logger) (*Fetcher, error) { +func NewFetcher(endpoint, fallbackEndpoint, token string, logger *zap.Logger) (*Fetcher, error) { u, err := url.Parse(endpoint) if err != nil { return nil, fmt.Errorf("invalid CDN URL %q: %w", endpoint, err) } + var fu *url.URL + if fallbackEndpoint != "" { + fu, err = url.Parse(fallbackEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid CDN fallback URL %q: %w", fallbackEndpoint, err) + } + } + claims, err := jwt.ExtractFederatedGraphTokenClaims(token) if err != nil { return nil, err @@ -52,6 +61,7 @@ func NewFetcher(endpoint, token string, logger *zap.Logger) (*Fetcher, error) { return &Fetcher{ cdnURL: u, + cdnFallbackURL: fu, authenticationToken: token, federatedGraphID: url.PathEscape(claims.FederatedGraphID), organizationID: url.PathEscape(claims.OrganizationID), @@ -64,12 +74,53 @@ func NewFetcher(endpoint, token string, logger *zap.Logger) (*Fetcher, error) { // with Bearer auth, using If-None-Match for conditional requests. The CDN returns 304 Not Modified // when the ETag matches, avoiding a full download. Returns (manifest, changed, err). 
func (f *Fetcher) Fetch(ctx context.Context, currentRevision string) (*Manifest, bool, error) { + resp, body, err := f.doFetch(ctx, currentRevision, f.cdnURL) + + if err != nil && f.cdnFallbackURL != nil && httpclient.IsCDNFallbackEligible(resp, err) { + primaryErr := err + f.logger.Warn("Primary CDN failed, attempting fallback CDN for PQL manifest", + zap.Error(err), + zap.String("fallback_url", f.cdnFallbackURL.String()), + ) + var fallbackErr error + _, body, fallbackErr = f.doFetch(ctx, currentRevision, f.cdnFallbackURL) + if fallbackErr != nil { + return nil, false, fmt.Errorf("primary CDN failed: %w; fallback CDN also failed: %v", primaryErr, fallbackErr) + } + err = nil + if body == nil { + // Fallback returned 304 Not Modified + return nil, false, nil + } + } + + if err != nil { + return nil, false, err + } + if body == nil { + // 304 Not Modified + return nil, false, nil + } + + var manifest Manifest + if err := json.Unmarshal(body, &manifest); err != nil { + return nil, false, fmt.Errorf("could not unmarshal PQL manifest: %w", err) + } + + if err := validateManifest(&manifest); err != nil { + return nil, false, fmt.Errorf("invalid PQL manifest: %w", err) + } + + return &manifest, true, nil +} + +func (f *Fetcher) doFetch(ctx context.Context, currentRevision string, baseURL *url.URL) (*http.Response, []byte, error) { manifestPath := fmt.Sprintf("/%s/%s/operations/manifest.json", f.organizationID, f.federatedGraphID) - manifestURL := f.cdnURL.ResolveReference(&url.URL{Path: manifestPath}) + manifestURL := baseURL.ResolveReference(&url.URL{Path: manifestPath}) req, err := http.NewRequestWithContext(ctx, "GET", manifestURL.String(), nil) if err != nil { - return nil, false, err + return nil, nil, err } req.Header.Set("Authorization", "Bearer "+f.authenticationToken) @@ -80,27 +131,27 @@ func (f *Fetcher) Fetch(ctx context.Context, currentRevision string) (*Manifest, resp, err := f.httpClient.Do(req) if err != nil { - return nil, false, err + return nil, nil, 
err } defer func() { _ = resp.Body.Close() }() if resp.StatusCode == http.StatusNotModified { - return nil, false, nil + return resp, nil, nil } if resp.StatusCode != http.StatusOK { if resp.StatusCode == http.StatusNotFound { - return nil, false, errors.New("PQL manifest not found on CDN") + return resp, nil, errors.New("PQL manifest not found on CDN") } if resp.StatusCode == http.StatusUnauthorized { - return nil, false, errors.New("could not authenticate against CDN") + return resp, nil, errors.New("could not authenticate against CDN") } if resp.StatusCode == http.StatusBadRequest { - return nil, false, errors.New("bad request") + return resp, nil, errors.New("bad request") } - return nil, false, fmt.Errorf("unexpected status code when loading PQL manifest, statusCode: %d", resp.StatusCode) + return resp, nil, fmt.Errorf("unexpected status code when loading PQL manifest, statusCode: %d", resp.StatusCode) } var reader io.Reader = resp.Body @@ -108,7 +159,7 @@ func (f *Fetcher) Fetch(ctx context.Context, currentRevision string) (*Manifest, if resp.Header.Get("Content-Encoding") == "gzip" { r, err := gzip.NewReader(resp.Body) if err != nil { - return nil, false, fmt.Errorf("could not create gzip reader: %w", err) + return resp, nil, fmt.Errorf("could not create gzip reader: %w", err) } defer func() { _ = r.Close() @@ -118,21 +169,12 @@ func (f *Fetcher) Fetch(ctx context.Context, currentRevision string) (*Manifest, body, err := io.ReadAll(reader) if err != nil { - return nil, false, fmt.Errorf("could not read response body: %w", err) + return resp, nil, fmt.Errorf("could not read response body: %w", err) } if len(body) == 0 { - return nil, false, errors.New("empty response body") + return resp, nil, errors.New("empty response body") } - var manifest Manifest - if err := json.Unmarshal(body, &manifest); err != nil { - return nil, false, fmt.Errorf("could not unmarshal PQL manifest: %w", err) - } - - if err := validateManifest(&manifest); err != nil { - return nil, 
false, fmt.Errorf("invalid PQL manifest: %w", err) - } - - return &manifest, true, nil + return resp, body, nil } diff --git a/router/internal/persistedoperation/pqlmanifest/fetcher_test.go b/router/internal/persistedoperation/pqlmanifest/fetcher_test.go index b3555a678d..7fd5b306aa 100644 --- a/router/internal/persistedoperation/pqlmanifest/fetcher_test.go +++ b/router/internal/persistedoperation/pqlmanifest/fetcher_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "sync/atomic" "testing" "github.com/stretchr/testify/require" @@ -25,6 +26,15 @@ func newTestFetcher(serverURL string) *Fetcher { } } +func newTestFetcherWithFallback(primaryURL, fallbackURL string) *Fetcher { + f := newTestFetcher(primaryURL) + if fallbackURL != "" { + fu, _ := url.Parse(fallbackURL) + f.cdnFallbackURL = fu + } + return f +} + // mustMarshalManifest marshals a Manifest to JSON, panicking on error. func mustMarshalManifest(m *Manifest) []byte { data, err := json.Marshal(m) @@ -207,3 +217,189 @@ func TestFetch_UsesGETMethod(t *testing.T) { require.NoError(t, err) require.Equal(t, "GET", receivedMethod) } + +func TestFetch_Fallback_503PrimaryFallsBackToSecondary(t *testing.T) { + t.Parallel() + m := &Manifest{ + Version: 1, + Revision: "rev-fallback", + GeneratedAt: "2025-01-01T00:00:00Z", + Operations: map[string]string{"hash1": "query { fallback }"}, + } + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(newETagCDNHandler(m)) + defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + result, changed, err := f.Fetch(context.Background(), "") + + require.NoError(t, err) + require.True(t, changed) + require.NotNil(t, result) + require.Equal(t, "rev-fallback", result.Revision) + require.Equal(t, "query { fallback }", result.Operations["hash1"]) +} + +func 
TestFetch_Fallback_429PrimaryFallsBackToSecondary(t *testing.T) { + t.Parallel() + m := &Manifest{ + Version: 1, + Revision: "rev-429", + GeneratedAt: "2025-01-01T00:00:00Z", + Operations: map[string]string{"hash1": "query { rate_limited }"}, + } + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer primary.Close() + + fallback := httptest.NewServer(newETagCDNHandler(m)) + defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + result, changed, err := f.Fetch(context.Background(), "") + + require.NoError(t, err) + require.True(t, changed) + require.NotNil(t, result) +} + +func TestFetch_Fallback_NotTriggeredOn404(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + })) + defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + _, _, err := f.Fetch(context.Background(), "") + + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + require.False(t, fallbackCalled.Load()) +} + +func TestFetch_Fallback_NotTriggeredOn401(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + })) + defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + _, _, err := f.Fetch(context.Background(), "") + + 
require.Error(t, err) + require.Contains(t, err.Error(), "authenticate") + require.False(t, fallbackCalled.Load()) +} + +func TestFetch_Fallback_NetworkErrorFallsBack(t *testing.T) { + t.Parallel() + m := &Manifest{ + Version: 1, + Revision: "rev-net", + GeneratedAt: "2025-01-01T00:00:00Z", + Operations: map[string]string{"hash1": "query { net }"}, + } + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() // close immediately to cause network error + + fallback := httptest.NewServer(newETagCDNHandler(m)) + defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + result, changed, err := f.Fetch(context.Background(), "") + + require.NoError(t, err) + require.True(t, changed) + require.NotNil(t, result) +} + +func TestFetch_Fallback_NetworkErrorWithoutFallbackReturnsError(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + f := newTestFetcher(primary.URL) // no fallback + _, _, err := f.Fetch(context.Background(), "") + + require.Error(t, err) +} + +func TestFetch_Fallback_503WithoutFallbackReturnsError(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + f := newTestFetcher(primary.URL) // no fallback + _, _, err := f.Fetch(context.Background(), "") + + require.Error(t, err) + require.Contains(t, err.Error(), "503") +} + +func TestFetch_Fallback_503Then404PreservesPrimaryError(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + 
defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + _, _, err := f.Fetch(context.Background(), "") + + require.Error(t, err) + require.Contains(t, err.Error(), "primary CDN failed") + require.Contains(t, err.Error(), "503") +} + +func TestFetch_Fallback_503Then401PreservesPrimaryError(t *testing.T) { + t.Parallel() + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer fallback.Close() + + f := newTestFetcherWithFallback(primary.URL, fallback.URL) + _, _, err := f.Fetch(context.Background(), "") + + require.Error(t, err) + require.Contains(t, err.Error(), "primary CDN failed") + require.Contains(t, err.Error(), "503") +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 225f68c6b4..7764c46983 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -666,8 +666,9 @@ type RateLimitOverride struct { } type CDNConfiguration struct { - URL string `yaml:"url" env:"CDN_URL" envDefault:"https://cosmo-cdn.wundergraph.com"` - CacheSize BytesString `yaml:"cache_size,omitempty" env:"CDN_CACHE_SIZE" envDefault:"100MB"` + URL string `yaml:"url" env:"CDN_URL" envDefault:"https://cosmo-cdn.wundergraph.com"` + FallbackURL string `yaml:"fallback_url" env:"CDN_FALLBACK_URL" envDefault:""` + CacheSize BytesString `yaml:"cache_size,omitempty" env:"CDN_CACHE_SIZE" envDefault:"100MB"` } type NatsTokenBasedAuthentication struct { @@ -952,8 +953,9 @@ type S3StorageProvider struct { } type CDNStorageProvider struct { - ID string `yaml:"id,omitempty" env:"ID"` - URL string `yaml:"url,omitempty" env:"URL" envDefault:"https://cosmo-cdn.wundergraph.com"` + ID string `yaml:"id,omitempty" env:"ID"` + URL string `yaml:"url,omitempty" env:"URL" 
envDefault:"https://cosmo-cdn.wundergraph.com"` + FallbackURL string `yaml:"fallback_url,omitempty" env:"FALLBACK_URL" envDefault:""` } type FileSystemStorageProvider struct { @@ -968,7 +970,8 @@ type RedisStorageProvider struct { } type PersistedOperationsCDNProvider struct { - URL string `yaml:"url,omitempty" envDefault:"https://cosmo-cdn.wundergraph.com"` + URL string `yaml:"url,omitempty" envDefault:"https://cosmo-cdn.wundergraph.com"` + FallbackURL string `yaml:"fallback_url,omitempty" envDefault:""` } type ExecutionConfigStorage struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 54fa556e93..ac8f08ffdd 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -48,6 +48,11 @@ "type": "string", "description": "The provider URL. The URL is used to connect to the provider.", "format": "url" + }, + "fallback_url": { + "type": "string", + "description": "The fallback provider URL. Used when the primary URL fails with a server error (5xx), rate limiting (429), or a network/connection failure.", + "format": "url" } } } @@ -2326,6 +2331,11 @@ "format": "http-url", "description": "The URL of the CDN. The URL is used to register the router on the control-plane. The URL is specified as a string with the format 'scheme://host:port'." }, + "fallback_url": { + "type": "string", + "format": "http-url", + "description": "The fallback URL of the CDN. Used when the primary URL fails with a server error (5xx), rate limiting (429), or a network/connection failure." 
+ }, "cache_size": { "type": "string", "default": "100MB", diff --git a/router/pkg/routerconfig/cdn/client.go b/router/pkg/routerconfig/cdn/client.go index 3bd8289aba..f8ae4e751d 100644 --- a/router/pkg/routerconfig/cdn/client.go +++ b/router/pkg/routerconfig/cdn/client.go @@ -40,10 +40,12 @@ type Options struct { Logger *zap.Logger SignatureKey string RouterCompatibilityVersion int + FallbackEndpoint string } type Client struct { cdnURL *url.URL + cdnFallbackURL *url.URL authenticationToken string // federatedGraphID is the ID of the federated graph that was obtained // from the token, already url-escaped @@ -85,6 +87,14 @@ func NewClient(endpoint string, token string, opts *Options) (routerconfig.Clien return nil, fmt.Errorf("invalid CDN URL %q: %w", endpoint, err) } + var fu *url.URL + if opts.FallbackEndpoint != "" { + fu, err = url.Parse(opts.FallbackEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid CDN fallback URL %q: %w", opts.FallbackEndpoint, err) + } + } + if opts.Logger == nil { opts.Logger = zap.NewNop() } @@ -98,6 +108,7 @@ func NewClient(endpoint string, token string, opts *Options) (routerconfig.Clien c := &Client{ cdnURL: u, + cdnFallbackURL: fu, authenticationToken: token, federatedGraphID: url.PathEscape(claims.FederatedGraphID), organizationID: url.PathEscape(claims.OrganizationID), @@ -114,23 +125,82 @@ func NewClient(endpoint string, token string, opts *Options) (routerconfig.Clien } func (cdn *Client) getRouterConfig(ctx context.Context, version string, _ time.Time) ([]byte, error) { + resp, body, err := cdn.doGetRouterConfig(ctx, version, cdn.cdnURL) + + if err != nil && cdn.cdnFallbackURL != nil && httpclient.IsCDNFallbackEligible(resp, err) { + primaryErr := err + cdn.logger.Warn("Primary CDN failed, attempting fallback CDN for router config", + zap.Error(err), + zap.String("fallback_url", cdn.cdnFallbackURL.String()), + ) + fallbackResp, fallbackBody, fallbackErr := cdn.doGetRouterConfig(ctx, version, cdn.cdnFallbackURL) + if 
fallbackErr == nil || errors.Is(fallbackErr, configpoller.ErrConfigNotModified) { + resp, body, err = fallbackResp, fallbackBody, fallbackErr + } else { + return nil, fmt.Errorf("primary CDN failed: %w; fallback CDN also failed: %v", primaryErr, fallbackErr) + } + } + + if err != nil { + return nil, err + } + + if cdn.hash != nil { + configSignature := resp.Header.Get(sigResponseHeaderName) + if configSignature == "" { + cdn.logger.Error( + "Signature header not found in CDN response. Ensure that your Admission Controller was able to sign the config. Open the compositions page in the Studio to check the status of the last deployment", + zap.Error(ErrMissingSignatureHeader), + ) + return nil, ErrMissingSignatureHeader + } + + if _, err := cdn.hash.Write(body); err != nil { + return nil, fmt.Errorf("could not write config body to hmac: %w", err) + } + dataHmac := cdn.hash.Sum(nil) + cdn.hash.Reset() + + rawSignature, err := base64.StdEncoding.DecodeString(configSignature) + if err != nil { + return nil, fmt.Errorf("could not hex decode signature key: %w", err) + } + + if subtle.ConstantTimeCompare(rawSignature, dataHmac) != 1 { + cdn.logger.Error( + "Invalid config signature, potential tampering detected. Ensure that your Admission Controller has signed the config correctly. 
Open the compositions page in the Studio to check the status of the last deployment", + zap.Error(ErrInvalidSignature), + ) + return nil, ErrInvalidSignature + } + + cdn.logger.Info("Config signature validation successful", + zap.String("federatedGraphID", cdn.federatedGraphID), + zap.String("signature", configSignature), + ) + } + + return body, nil +} + +func (cdn *Client) doGetRouterConfig(ctx context.Context, version string, baseURL *url.URL) (*http.Response, []byte, error) { routerConfigPath := fmt.Sprintf("/%s/%s/routerconfigs/%slatest.json", cdn.organizationID, cdn.federatedGraphID, routerconfig.VersionPath(cdn.routerCompatibilityVersion), ) - routerConfigURL := cdn.cdnURL.ResolveReference(&url.URL{Path: routerConfigPath}) + routerConfigURL := baseURL.ResolveReference(&url.URL{Path: routerConfigPath}) body, err := json.Marshal(getRouterConfigRequestBody{ Version: version, }) if err != nil { - return nil, err + return nil, nil, err } req, err := http.NewRequestWithContext(ctx, "POST", routerConfigURL.String(), bytes.NewBuffer(body)) if err != nil { - return nil, err + return nil, nil, err } req.Header.Set("Content-Type", "application/json; charset=UTF-8") @@ -139,7 +209,7 @@ func (cdn *Client) getRouterConfig(ctx context.Context, version string, _ time.T resp, err := cdn.httpClient.Do(req) if err != nil { - return nil, err + return nil, nil, err } defer func() { _ = resp.Body.Close() @@ -147,19 +217,19 @@ func (cdn *Client) getRouterConfig(ctx context.Context, version string, _ time.T if resp.StatusCode != http.StatusOK { if resp.StatusCode == http.StatusNotFound { - return nil, ErrConfigNotFound + return resp, nil, ErrConfigNotFound } if resp.StatusCode == http.StatusUnauthorized { - return nil, errors.New("could not authenticate against CDN") + return resp, nil, errors.New("could not authenticate against CDN") } if resp.StatusCode == http.StatusBadRequest { - return nil, errors.New("bad request") + return resp, nil, errors.New("bad request") } if 
resp.StatusCode == http.StatusNotModified { - return nil, configpoller.ErrConfigNotModified + return resp, nil, configpoller.ErrConfigNotModified } - return nil, fmt.Errorf("unexpected status code when loading router config, statusCode: %d", resp.StatusCode) + return resp, nil, fmt.Errorf("unexpected status code when loading router config, statusCode: %d", resp.StatusCode) } var reader io.Reader = resp.Body @@ -167,7 +237,7 @@ func (cdn *Client) getRouterConfig(ctx context.Context, version string, _ time.T if resp.Header.Get("Content-Encoding") == "gzip" { r, err := gzip.NewReader(resp.Body) if err != nil { - return nil, fmt.Errorf("could not create gzip reader: %w", err) + return resp, nil, fmt.Errorf("could not create gzip reader: %w", err) } defer func() { _ = r.Close() @@ -177,55 +247,14 @@ func (cdn *Client) getRouterConfig(ctx context.Context, version string, _ time.T body, err = io.ReadAll(reader) if err != nil { - return nil, fmt.Errorf("could not read the response body: %w", err) + return resp, nil, fmt.Errorf("could not read the response body: %w", err) } if len(body) == 0 { - return nil, errors.New("empty response body") - } - - /** - * If a signature key is set, we need to validate the signature of the received config - */ - - if cdn.hash != nil { - configSignature := resp.Header.Get(sigResponseHeaderName) - if configSignature == "" { - cdn.logger.Error( - "Signature header not found in CDN response. Ensure that your Admission Controller was able to sign the config. 
Open the compositions page in the Studio to check the status of the last deployment", - zap.Error(ErrMissingSignatureHeader), - ) - return nil, ErrMissingSignatureHeader - } - - // create a signature of the received config body - if _, err := cdn.hash.Write(body); err != nil { - return nil, fmt.Errorf("could not write config body to hmac: %w", err) - } - dataHmac := cdn.hash.Sum(nil) - cdn.hash.Reset() - - // compare received signature with the one we calculated with the private signature key - rawSignature, err := base64.StdEncoding.DecodeString(configSignature) - if err != nil { - return nil, fmt.Errorf("could not hex decode signature key: %w", err) - } - - if subtle.ConstantTimeCompare(rawSignature, dataHmac) != 1 { - cdn.logger.Error( - "Invalid config signature, potential tampering detected. Ensure that your Admission Controller has signed the config correctly. Open the compositions page in the Studio to check the status of the last deployment", - zap.Error(ErrInvalidSignature), - ) - return nil, ErrInvalidSignature - } - - cdn.logger.Info("Config signature validation successful", - zap.String("federatedGraphID", cdn.federatedGraphID), - zap.String("signature", configSignature), - ) + return resp, nil, errors.New("empty response body") } - return body, nil + return resp, body, nil } func (cdn *Client) RouterConfig(ctx context.Context, version string, modifiedSince time.Time) (*routerconfig.Response, error) { diff --git a/router/pkg/routerconfig/cdn/client_test.go b/router/pkg/routerconfig/cdn/client_test.go new file mode 100644 index 0000000000..cdee50bad1 --- /dev/null +++ b/router/pkg/routerconfig/cdn/client_test.go @@ -0,0 +1,267 @@ +package cdn + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// validRouterConfigJSON is a minimal valid execution config JSON. 
+const validRouterConfigJSON = `{"version":"1","engineConfig":{},"subgraphs":[]}` + +func newTestConfigClient(primaryURL, fallbackURL string) *Client { + u, _ := url.Parse(primaryURL) + var fu *url.URL + if fallbackURL != "" { + fu, _ = url.Parse(fallbackURL) + } + return &Client{ + cdnURL: u, + cdnFallbackURL: fu, + authenticationToken: "test-token", + federatedGraphID: "test-graph", + organizationID: "test-org", + httpClient: http.DefaultClient, + logger: zap.NewNop(), + routerCompatibilityVersion: 0, + } +} + +func TestGetRouterConfig_Fallback(t *testing.T) { + t.Parallel() + + t.Run("primary 200 does not trigger fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + body, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.NoError(t, err) + require.NotEmpty(t, body) + assert.False(t, fallbackCalled.Load()) + }) + + t.Run("primary 503 with fallback 200 returns from fallback", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + body, err := c.getRouterConfig(context.Background(), "", time.Time{}) + 
require.NoError(t, err) + require.NotEmpty(t, body) + }) + + t.Run("primary 429 with fallback 200 returns from fallback", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + body, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.NoError(t, err) + require.NotEmpty(t, body) + }) + + t.Run("primary 404 does not trigger fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + assert.False(t, fallbackCalled.Load()) + }) + + t.Run("primary 401 does not trigger fallback", func(t *testing.T) { + t.Parallel() + var fallbackCalled atomic.Bool + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fallbackCalled.Store(true) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer fallback.Close() + + c := 
newTestConfigClient(primary.URL, fallback.URL) + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "authenticate") + assert.False(t, fallbackCalled.Load()) + }) + + t.Run("primary 200 without fallback succeeds normally", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer primary.Close() + + c := newTestConfigClient(primary.URL, "") + body, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.NoError(t, err) + require.NotEmpty(t, body) + }) + + t.Run("primary 503 without fallback returns error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + c := newTestConfigClient(primary.URL, "") + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "503") + }) + + t.Run("primary network error without fallback returns error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + c := newTestConfigClient(primary.URL, "") + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + }) + + t.Run("primary network error with fallback 200 returns from fallback", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(validRouterConfigJSON)) + })) + defer fallback.Close() + + c := 
newTestConfigClient(primary.URL, fallback.URL) + body, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.NoError(t, err) + require.NotEmpty(t, body) + }) + + t.Run("primary 503 fallback 503 returns primary error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "fallback CDN also failed") + }) + + t.Run("primary 503 fallback 404 preserves primary error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "503") + }) + + t.Run("primary 503 fallback 401 preserves primary error", func(t *testing.T) { + t.Parallel() + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer primary.Close() + + fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.WriteHeader(http.StatusUnauthorized) + })) + defer fallback.Close() + + c := newTestConfigClient(primary.URL, fallback.URL) + _, err := c.getRouterConfig(context.Background(), "", time.Time{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "primary CDN failed") + assert.Contains(t, err.Error(), "503") + }) +}