Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions router/core/cache_warmup_cdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var _ CacheWarmupSource = (*CDNSource)(nil)

type CDNSource struct {
cdnURL *url.URL
cdnFallbackURL *url.URL
authenticationToken string
// federatedGraphID is the ID of the federated graph that was obtained
// from the token, already url-escaped
Expand All @@ -34,12 +35,20 @@ type CDNSource struct {
httpClient *http.Client
}

func NewCDNSource(endpoint, token string, logger *zap.Logger) (*CDNSource, error) {
func NewCDNSource(endpoint, fallbackEndpoint, token string, logger *zap.Logger) (*CDNSource, error) {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}

var fu *url.URL = nil
if fallbackEndpoint != "" {
fu, err = url.Parse(fallbackEndpoint)
if err != nil {
return nil, err
}
}

claims, err := jwt.ExtractFederatedGraphTokenClaims(token)
if err != nil {
return nil, err
Expand All @@ -51,6 +60,7 @@ func NewCDNSource(endpoint, token string, logger *zap.Logger) (*CDNSource, error

return &CDNSource{
cdnURL: u,
cdnFallbackURL: fu,
authenticationToken: token,
federatedGraphID: claims.FederatedGraphID,
organizationID: claims.OrganizationID,
Expand All @@ -62,14 +72,45 @@ func (c *CDNSource) LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.O
span := trace.SpanFromContext(ctx)
defer span.End()

operationsPath := fmt.Sprintf("/%s/%s/cache_warmup/operations.json", c.organizationID, c.federatedGraphID)
resp, body, err := c.fetchOperationsJSON(ctx, log, c.cdnURL)

if err != nil && c.cdnFallbackURL != nil && httpclient.IsCDNFallbackEligible(resp, err) {
log.Warn("Primary CDN failed, attempting fallback CDN",
zap.Error(err),
zap.String("fallback_url", c.cdnFallbackURL.String()),
)
span.AddEvent("cdn.fallback", trace.WithAttributes(
semconv.HTTPURL(c.cdnFallbackURL.String()),
))
_, body, err = c.fetchOperationsJSON(ctx, log, c.cdnFallbackURL)
}

if err != nil {
return nil, err
}
if body == nil {
return nil, nil
}

var warmupOperations nodev1.CacheWarmerOperations
unmarshalOpts := protojson.UnmarshalOptions{DiscardUnknown: true}
if err := unmarshalOpts.Unmarshal(body, &warmupOperations); err != nil {
return nil, err
}

operationURL := c.cdnURL.ResolveReference(&url.URL{Path: operationsPath})
return warmupOperations.GetOperations(), nil
}

func (c *CDNSource) fetchOperationsJSON(ctx context.Context, log *zap.Logger, baseURL *url.URL) (*http.Response, []byte, error) {
span := trace.SpanFromContext(ctx)

operationsPath := fmt.Sprintf("/%s/%s/cache_warmup/operations.json", c.organizationID, c.federatedGraphID)
operationURL := baseURL.ResolveReference(&url.URL{Path: operationsPath})
log.Debug("Loading cache warmup config", zap.String("url", operationURL.String()))

req, err := http.NewRequestWithContext(ctx, "GET", operationURL.String(), nil)
if err != nil {
return nil, err
return nil, nil, err
}

span.SetAttributes(
Expand All @@ -84,7 +125,7 @@ func (c *CDNSource) LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.O

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
return nil, nil, err
}
defer func() { _ = resp.Body.Close() }()

Expand All @@ -95,29 +136,23 @@ func (c *CDNSource) LoadItems(ctx context.Context, log *zap.Logger) ([]*nodev1.O

if resp.StatusCode == http.StatusNotFound {
log.Debug("Cache warmup config not found", zap.String("url", operationURL.String()))
return nil, nil
return resp, nil, nil
}
if resp.StatusCode == http.StatusUnauthorized {
return nil, errors.New("could not authenticate against CDN")
return resp, nil, errors.New("could not authenticate against CDN")
}
if resp.StatusCode == http.StatusBadRequest {
return nil, errors.New("bad request")
return resp, nil, errors.New("bad request")
}
return nil, fmt.Errorf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode)
return resp, nil, fmt.Errorf("unexpected status code when loading persisted operation, statusCode: %d", resp.StatusCode)
}

body, err := c.readResponse(resp)
if err != nil {
return nil, err
}

var warmupOperations nodev1.CacheWarmerOperations
unmarshalOpts := protojson.UnmarshalOptions{DiscardUnknown: true}
if err := unmarshalOpts.Unmarshal(body, &warmupOperations); err != nil {
return nil, err
return resp, nil, err
}

return warmupOperations.GetOperations(), nil
return resp, body, nil
}

func (c *CDNSource) readResponse(resp *http.Response) ([]byte, error) {
Expand Down
228 changes: 228 additions & 0 deletions router/core/cache_warmup_cdn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
package core

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

// validCacheWarmupJSON is a valid CacheWarmerOperations protobuf JSON with one operation.
const validCacheWarmupJSON = `{
"operations": [
{
"request": {
"query": "query { hello }"
}
}
]
}`

func newTestCDNSource(primaryURL string, fallbackURL string) *CDNSource {
u, _ := url.Parse(primaryURL)
var fu *url.URL
if fallbackURL != "" {
fu, _ = url.Parse(fallbackURL)
}
return &CDNSource{
cdnURL: u,
cdnFallbackURL: fu,
authenticationToken: "test-token",
federatedGraphID: "test-graph",
organizationID: "test-org",
httpClient: http.DefaultClient,
}
}

func TestCDNSource_LoadItems(t *testing.T) {
t.Parallel()

t.Run("primary 200 without fallback", func(t *testing.T) {
t.Parallel()
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer primary.Close()

source := newTestCDNSource(primary.URL, "")
items, err := source.LoadItems(context.Background(), zap.NewNop())
require.NoError(t, err)
require.Len(t, items, 1)
assert.Equal(t, "query { hello }", items[0].Request.Query)
})

t.Run("primary 200 with fallback configured", func(t *testing.T) {
t.Parallel()
var fallbackCalled atomic.Bool

primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fallbackCalled.Store(true)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
items, err := source.LoadItems(context.Background(), zap.NewNop())
require.NoError(t, err)
require.Len(t, items, 1)
assert.False(t, fallbackCalled.Load(), "fallback should not be called when primary succeeds")
})

t.Run("primary 404 with fallback does not trigger fallback", func(t *testing.T) {
t.Parallel()
var fallbackCalled atomic.Bool

primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fallbackCalled.Store(true)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
items, err := source.LoadItems(context.Background(), zap.NewNop())
assert.NoError(t, err)
assert.Nil(t, items)
assert.False(t, fallbackCalled.Load(), "fallback should not be called on 404")
})

t.Run("primary 401 with fallback does not trigger fallback", func(t *testing.T) {
t.Parallel()
var fallbackCalled atomic.Bool

primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fallbackCalled.Store(true)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
_, err := source.LoadItems(context.Background(), zap.NewNop())
assert.Error(t, err)
assert.Contains(t, err.Error(), "authenticate")
assert.False(t, fallbackCalled.Load(), "fallback should not be called on 401")
})

t.Run("primary 503 without fallback returns error", func(t *testing.T) {
t.Parallel()
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer primary.Close()

source := newTestCDNSource(primary.URL, "")
_, err := source.LoadItems(context.Background(), zap.NewNop())
assert.Error(t, err)
assert.Contains(t, err.Error(), "503")
})

t.Run("primary 503 with fallback 200 returns items from fallback", func(t *testing.T) {
t.Parallel()
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
items, err := source.LoadItems(context.Background(), zap.NewNop())
require.NoError(t, err)
require.Len(t, items, 1)
assert.Equal(t, "query { hello }", items[0].Request.Query)
})

t.Run("primary 429 with fallback 200 returns items from fallback", func(t *testing.T) {
t.Parallel()
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
}))
defer primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
items, err := source.LoadItems(context.Background(), zap.NewNop())
require.NoError(t, err)
require.Len(t, items, 1)
})

t.Run("primary 503 fallback 503 returns error", func(t *testing.T) {
t.Parallel()
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
_, err := source.LoadItems(context.Background(), zap.NewNop())
assert.Error(t, err)
assert.Contains(t, err.Error(), "503")
})

t.Run("primary network error with fallback 200 returns items from fallback", func(t *testing.T) {
t.Parallel()
// Use an immediately-closed server to simulate network error
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
primary.Close()

fallback := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(validCacheWarmupJSON))
}))
defer fallback.Close()

source := newTestCDNSource(primary.URL, fallback.URL)
items, err := source.LoadItems(context.Background(), zap.NewNop())
require.NoError(t, err)
require.Len(t, items, 1)
})

t.Run("primary network error without fallback returns error", func(t *testing.T) {
t.Parallel()
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
primary.Close()

source := newTestCDNSource(primary.URL, "")
_, err := source.LoadItems(context.Background(), zap.NewNop())
assert.Error(t, err)
})
}
2 changes: 1 addition & 1 deletion router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ func (s *graphServer) buildGraphMux(
warmupConfig.FallbackSource = NewPlanSource(opts.ReloadPersistentState.inMemoryPlanCacheFallback.getPlanCacheForFF(opts.FeatureFlagName))
opts.ReloadPersistentState.inMemoryPlanCacheFallback.setPlanCacheForFF(opts.FeatureFlagName, gm.planFallbackCache)
}
cdnSource, err := NewCDNSource(s.cdnConfig.URL, s.graphApiToken, s.logger)
cdnSource, err := NewCDNSource(s.cdnConfig.URL, s.cdnConfig.FallbackURL, s.graphApiToken, s.logger)
if err != nil {
return nil, fmt.Errorf("failed to create cdn source: %w", err)
}
Expand Down
2 changes: 2 additions & 0 deletions router/core/init_config_poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func getConfigClient(r *Router, registry *ProviderRegistry, providerID string, i
Logger: r.logger,
SignatureKey: r.routerConfigPollerConfig.GraphSignKey,
RouterCompatibilityVersion: execution_config.RouterCompatibilityVersionThreshold,
FallbackEndpoint: provider.FallbackURL,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -103,6 +104,7 @@ func getConfigClient(r *Router, registry *ProviderRegistry, providerID string, i
Logger: r.logger,
SignatureKey: r.routerConfigPollerConfig.GraphSignKey,
RouterCompatibilityVersion: execution_config.RouterCompatibilityVersionThreshold,
FallbackEndpoint: r.cdnConfig.FallbackURL,
})
if err != nil {
return nil, err
Expand Down
Loading
Loading