Skip to content
585 changes: 585 additions & 0 deletions router-tests/security/subgraph_grpc_mtls_test.go

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/encoding/protojson"

"github.com/wundergraph/cosmo/demo/pkg/subgraphs"
Expand Down Expand Up @@ -422,6 +423,10 @@ type SubgraphConfig struct {
// TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS()
// instead of Start(). This is useful for testing mTLS between the router and subgraphs.
TLSConfig *tls.Config

// GRPCTLSConfig enables TLS on the gRPC subgraph server. When set, the gRPC server
// uses TLS credentials instead of plain connections.
GRPCTLSConfig *tls.Config
}

type LogObservationConfig struct {
Expand Down Expand Up @@ -628,7 +633,7 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) {
)

if cfg.EnableGRPC {
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor)
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig)
}

replacements := map[string]string{
Expand Down Expand Up @@ -1071,7 +1076,7 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) {
)

if cfg.EnableGRPC {
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor)
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig)
}

replacements := map[string]string{
Expand Down Expand Up @@ -1767,7 +1772,7 @@ func makeSubgraphTestServer(_ testing.TB, handler http.Handler, tlsConfig *tls.C
return s
}

func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor) (*grpc.Server, string) {
func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor, tlsConfig *tls.Config) (*grpc.Server, string) {
t.Helper()

// We could use freeport here, but it is easy to use ephemeral port and get the endpoint
Expand All @@ -1782,6 +1787,9 @@ func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interce
if interceptor != nil {
opts = append(opts, grpc.ChainUnaryInterceptor(interceptor))
}
if tlsConfig != nil {
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}

s := grpc.NewServer(opts...)
s.RegisterService(sd, service)
Expand Down
32 changes: 26 additions & 6 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -105,6 +106,8 @@ type (
connector *grpcconnector.Connector
circuitBreakerManager *circuit.Manager
headerPropagation *HeaderPropagation
defaultGRPCClientTLS *tls.Config
perSubgraphGRPCTLS map[string]*tls.Config
}
)

Expand Down Expand Up @@ -147,12 +150,18 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
traceDialer = NewTraceDialer()
}

// Build subgraph client TLS configs (mTLS for outbound subgraph connections)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, &r.subgraphTLSConfiguration)
// Build subgraph client TLS configs (mTLS for outbound HTTP subgraph connections)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, r.subgraphTLSConfiguration.All, r.subgraphTLSConfiguration.Subgraphs)
if err != nil {
return nil, fmt.Errorf("could not build subgraph client TLS config: %w", err)
}

// Build gRPC subgraph client TLS configs
defaultGRPCClientTLS, perSubgraphGRPCTLS, err := buildSubgraphTLSConfigs(r.logger, r.subgraphGRPCTLSConfiguration.All, r.subgraphGRPCTLSConfiguration.Subgraphs)
if err != nil {
return nil, fmt.Errorf("could not build gRPC subgraph client TLS config: %w", err)
}

// Base transport
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS)

Expand Down Expand Up @@ -193,8 +202,10 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
HostName: r.hostName,
ListenAddress: r.listenAddr,
},
storageProviders: &r.storageProviders,
headerPropagation: r.headerPropagation,
storageProviders: &r.storageProviders,
headerPropagation: r.headerPropagation,
defaultGRPCClientTLS: defaultGRPCClientTLS,
perSubgraphGRPCTLS: perSubgraphGRPCTLS,
}

baseOtelAttributes := []attribute.KeyValue{
Expand Down Expand Up @@ -1769,9 +1780,18 @@ func (s *graphServer) setupConnector(

pluginConfig := grpcConfig.GetPlugin()
if pluginConfig == nil {
// Resolve per-subgraph gRPC TLS config, falling back to the default.
var grpcTLS *tls.Config
if sgTLS, ok := s.perSubgraphGRPCTLS[sg.Name]; ok {
grpcTLS = sgTLS
} else {
grpcTLS = s.defaultGRPCClientTLS
}

remoteProvider, err := grpcremote.NewRemoteGRPCProvider(grpcremote.RemoteGRPCProviderConfig{
Logger: s.logger,
Endpoint: sg.RoutingUrl,
Logger: s.logger,
Endpoint: sg.RoutingUrl,
TLSConfig: grpcTLS,
})

if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2307,6 +2307,12 @@ func WithSubgraphTLSConfiguration(cfg config.ClientTLSConfiguration) Option {
}
}

func WithSubgraphGRPCTLSConfiguration(cfg config.GRPCClientTLSConfiguration) Option {
return func(r *Router) {
r.subgraphGRPCTLSConfiguration = cfg
}
}

func WithTelemetryAttributes(attributes []config.CustomAttribute) Option {
return func(r *Router) {
r.telemetryAttributes = attributes
Expand Down
1 change: 1 addition & 0 deletions router/core/router_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ type Config struct {
tlsServerConfig *tls.Config
tlsConfig *TlsConfig
subgraphTLSConfiguration config.ClientTLSConfiguration
subgraphGRPCTLSConfiguration config.GRPCClientTLSConfiguration
telemetryAttributes []config.CustomAttribute
tracePropagators []propagation.TextMapPropagator
compositePropagator propagation.TextMapPropagator
Expand Down
1 change: 1 addition & 0 deletions router/core/supervisor_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config, reloadPersi
WithStreamsHandlerConfiguration(config.Events.Handlers),
WithReloadPersistentState(reloadPersistentState),
WithSubgraphTLSConfiguration(config.TLS.Client),
WithSubgraphGRPCTLSConfiguration(config.TLS.ClientGRPC),
}

return options
Expand Down
12 changes: 6 additions & 6 deletions router/core/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,30 @@ func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Co

// buildSubgraphTLSConfigs builds the default and per-subgraph TLS configs from raw configuration.
// Returns (defaultClientTLS, perSubgraphTLS, error).
func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfiguration) (*tls.Config, map[string]*tls.Config, error) {
hasAll := (cfg.All.CertFile != "" && cfg.All.KeyFile != "") || cfg.All.CaFile != "" || cfg.All.InsecureSkipCaVerification
func buildSubgraphTLSConfigs(logger *zap.Logger, all config.TLSClientCertConfiguration, subgraphs map[string]config.TLSClientCertConfiguration) (*tls.Config, map[string]*tls.Config, error) {
hasAll := (all.CertFile != "" && all.KeyFile != "") || all.CaFile != "" || all.InsecureSkipCaVerification

// If no global TLS config is provided and there are no subgraph specific TLS configs
if !hasAll && len(cfg.Subgraphs) == 0 {
if !hasAll && len(subgraphs) == 0 {
return nil, nil, nil
}

var defaultClientTLS *tls.Config
perSubgraphTLS := make(map[string]*tls.Config)

if hasAll {
if cfg.All.InsecureSkipCaVerification {
if all.InsecureSkipCaVerification {
logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. This is not recommended for production environments.")
}

defaultTLS, err := buildTLSClientConfig(&cfg.All)
defaultTLS, err := buildTLSClientConfig(&all)
if err != nil {
return nil, nil, fmt.Errorf("failed to build global subgraph TLS config: %w", err)
}
defaultClientTLS = defaultTLS
}

for name, sgCfg := range cfg.Subgraphs {
for name, sgCfg := range subgraphs {
if sgCfg.InsecureSkipCaVerification {
logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from global config. This is not recommended for production environments.",
zap.String("subgraph", name))
Comment thread
dkorittki marked this conversation as resolved.
Expand Down
87 changes: 36 additions & 51 deletions router/core/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ func TestBuildTLSClientConfig(t *testing.T) {
t.Run("returns nil when no TLS configured", func(t *testing.T) {
t.Parallel()

cfg := &config.ClientTLSConfiguration{}
defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg)
defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), config.TLSClientCertConfiguration{}, nil)
require.NoError(t, err)
require.Nil(t, defaultTLS)
require.Nil(t, perSubgraphTLS)
Expand All @@ -102,15 +101,13 @@ func TestBuildTLSClientConfig(t *testing.T) {
certPath, keyPath := generateTestCert(t, "client")
caPath, _ := generateTestCert(t, "ca")

cfg := &config.ClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
CertFile: certPath,
KeyFile: keyPath,
CaFile: caPath,
},
all := config.TLSClientCertConfiguration{
CertFile: certPath,
KeyFile: keyPath,
CaFile: caPath,
}

defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg)
defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), all, nil)
require.NoError(t, err)
require.NotNil(t, defaultTLS)
require.Len(t, defaultTLS.Certificates, 1)
Expand All @@ -123,16 +120,14 @@ func TestBuildTLSClientConfig(t *testing.T) {

certPath, keyPath := generateTestCert(t, "products")

cfg := &config.ClientTLSConfiguration{
Subgraphs: map[string]config.TLSClientCertConfiguration{
"products": {
CertFile: certPath,
KeyFile: keyPath,
},
subgraphs := map[string]config.TLSClientCertConfiguration{
"products": {
CertFile: certPath,
KeyFile: keyPath,
},
}

defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg)
defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), config.TLSClientCertConfiguration{}, subgraphs)
require.NoError(t, err)
require.Nil(t, defaultTLS)
require.Contains(t, perSubgraphTLS, "products")
Expand All @@ -145,20 +140,18 @@ func TestBuildTLSClientConfig(t *testing.T) {
globalCert, globalKey := generateTestCert(t, "global")
productsCert, productsKey := generateTestCert(t, "products")

cfg := &config.ClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
CertFile: globalCert,
KeyFile: globalKey,
},
Subgraphs: map[string]config.TLSClientCertConfiguration{
"products": {
CertFile: productsCert,
KeyFile: productsKey,
},
all := config.TLSClientCertConfiguration{
CertFile: globalCert,
KeyFile: globalKey,
}
subgraphs := map[string]config.TLSClientCertConfiguration{
"products": {
CertFile: productsCert,
KeyFile: productsKey,
},
}

defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg)
defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), all, subgraphs)
require.NoError(t, err)
require.NotNil(t, defaultTLS)
require.Contains(t, perSubgraphTLS, "products")
Expand All @@ -167,31 +160,27 @@ func TestBuildTLSClientConfig(t *testing.T) {
t.Run("errors on invalid global cert", func(t *testing.T) {
t.Parallel()

cfg := &config.ClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
CertFile: "/nonexistent/cert.pem",
KeyFile: "/nonexistent/key.pem",
},
all := config.TLSClientCertConfiguration{
CertFile: "/nonexistent/cert.pem",
KeyFile: "/nonexistent/key.pem",
}

_, _, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg)
_, _, err := buildSubgraphTLSConfigs(zap.NewNop(), all, nil)
require.Error(t, err)
require.EqualError(t, err, "failed to build global subgraph TLS config: failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory")
})

t.Run("errors on invalid per-subgraph cert", func(t *testing.T) {
t.Parallel()

cfg := &config.ClientTLSConfiguration{
Subgraphs: map[string]config.TLSClientCertConfiguration{
"products": {
CertFile: "/nonexistent/cert.pem",
KeyFile: "/nonexistent/key.pem",
},
subgraphs := map[string]config.TLSClientCertConfiguration{
"products": {
CertFile: "/nonexistent/cert.pem",
KeyFile: "/nonexistent/key.pem",
},
}

_, _, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg)
_, _, err := buildSubgraphTLSConfigs(zap.NewNop(), config.TLSClientCertConfiguration{}, subgraphs)
require.Error(t, err)
require.EqualError(t, err, `failed to build TLS config for subgraph "products": failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory`)
})
Expand All @@ -202,13 +191,11 @@ func TestBuildTLSClientConfig(t *testing.T) {
core, logs := observer.New(zapcore.WarnLevel)
logger := zap.New(core)

cfg := &config.ClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
},
all := config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
}

defaultTLS, _, err := buildSubgraphTLSConfigs(logger, cfg)
defaultTLS, _, err := buildSubgraphTLSConfigs(logger, all, nil)
require.NoError(t, err)
require.NotNil(t, defaultTLS)
require.True(t, defaultTLS.InsecureSkipVerify)
Expand All @@ -223,15 +210,13 @@ func TestBuildTLSClientConfig(t *testing.T) {
core, logs := observer.New(zapcore.WarnLevel)
logger := zap.New(core)

cfg := &config.ClientTLSConfiguration{
Subgraphs: map[string]config.TLSClientCertConfiguration{
"products": {
InsecureSkipCaVerification: true,
},
subgraphs := map[string]config.TLSClientCertConfiguration{
"products": {
InsecureSkipCaVerification: true,
},
}

_, perSubgraphTLS, err := buildSubgraphTLSConfigs(logger, cfg)
_, perSubgraphTLS, err := buildSubgraphTLSConfigs(logger, config.TLSClientCertConfiguration{}, subgraphs)
require.NoError(t, err)
require.Contains(t, perSubgraphTLS, "products")
require.True(t, perSubgraphTLS["products"].InsecureSkipVerify)
Expand Down
12 changes: 10 additions & 2 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -890,9 +890,17 @@ type ClientTLSConfiguration struct {
Subgraphs map[string]TLSClientCertConfiguration `yaml:"subgraphs,omitempty"`
}

type GRPCClientTLSConfiguration struct {
// All applies to all gRPC subgraph connections.
All TLSClientCertConfiguration `yaml:"all" envPrefix:"TLS_CLIENT_GRPC_ALL_"`
// Subgraphs overrides per-subgraph gRPC TLS config. Key is the subgraph name.
Subgraphs map[string]TLSClientCertConfiguration `yaml:"subgraphs,omitempty"`
}

type TLSConfiguration struct {
Server TLSServerConfiguration `yaml:"server"`
Client ClientTLSConfiguration `yaml:"client"`
Server TLSServerConfiguration `yaml:"server"`
Client ClientTLSConfiguration `yaml:"client"`
ClientGRPC GRPCClientTLSConfiguration `yaml:"client_grpc"`
}

type SubgraphErrorPropagationMode string
Expand Down
Loading
Loading