Skip to content
585 changes: 585 additions & 0 deletions router-tests/security/subgraph_grpc_mtls_test.go

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/encoding/protojson"

"github.com/wundergraph/cosmo/demo/pkg/subgraphs"
Expand Down Expand Up @@ -422,6 +423,10 @@ type SubgraphConfig struct {
// TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS()
// instead of Start(). This is useful for testing mTLS between the router and subgraphs.
TLSConfig *tls.Config

// GRPCTLSConfig enables TLS on the gRPC subgraph server. When set, the gRPC server
// uses TLS credentials instead of plain connections.
GRPCTLSConfig *tls.Config
}

type LogObservationConfig struct {
Expand Down Expand Up @@ -630,7 +635,7 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) {
)

if cfg.EnableGRPC {
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor)
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig)
}

replacements := map[string]string{
Expand Down Expand Up @@ -1068,7 +1073,7 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) {
)

if cfg.EnableGRPC {
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor)
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig)
}

replacements := map[string]string{
Expand Down Expand Up @@ -1766,7 +1771,7 @@ func makeSubgraphTestServer(_ testing.TB, handler http.Handler, tlsConfig *tls.C
return s
}

func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor) (*grpc.Server, string) {
func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor, tlsConfig *tls.Config) (*grpc.Server, string) {
t.Helper()

// We could use freeport here, but it is easy to use ephemeral port and get the endpoint
Expand All @@ -1781,6 +1786,9 @@ func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interce
if interceptor != nil {
opts = append(opts, grpc.ChainUnaryInterceptor(interceptor))
}
if tlsConfig != nil {
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}

s := grpc.NewServer(opts...)
s.RegisterService(sd, service)
Expand Down
73 changes: 58 additions & 15 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -120,6 +121,8 @@ type BuildGraphMuxOptions struct {
ConfigSubgraphs []*nodev1.Subgraph
RoutingUrlGroupings map[string]map[string]bool
ReloadPersistentState *ReloadPersistentState
defaultClientTLS *tls.Config
perSubgraphTLS map[string]*tls.Config
}

func (b BuildGraphMuxOptions) IsBaseGraph() bool {
Expand All @@ -133,6 +136,8 @@ type buildMultiGraphHandlerOptions struct {
reloadPersistentState *ReloadPersistentState
currentGraphMuxes map[string]*graphMux
changes *routerconfig.Changes
defaultClientTLS *tls.Config
perSubgraphTLS map[string]*tls.Config
}

// reusedGraphMux holds a graph mux from the previous server that the new server
Expand Down Expand Up @@ -163,11 +168,23 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig
}

// Build subgraph client TLS configs (mTLS for outbound subgraph connections)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, &r.tls.settings.Client)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(
r.logger,
&r.tls.settings.Client,
)
if err != nil {
return nil, fmt.Errorf("could not build subgraph client TLS config: %w", err)
}

// Build gRPC subgraph client TLS configs
defaultGRPCClientTLS, perSubgraphGRPCTLS, err := buildSubgraphTLSConfigs(
r.logger,
&r.tls.settings.ClientGRPC,
)
if err != nil {
return nil, fmt.Errorf("could not build gRPC subgraph client TLS config: %w", err)
}

// Base transport
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS)

Expand Down Expand Up @@ -337,6 +354,8 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig
ConfigSubgraphs: response.Config.GetSubgraphs(),
RoutingUrlGroupings: routingUrlGroupings,
ReloadPersistentState: r.reloadPersistentState,
defaultClientTLS: defaultGRPCClientTLS,
perSubgraphTLS: perSubgraphGRPCTLS,
Comment thread
dkorittki marked this conversation as resolved.
})
if err != nil {
return nil, fmt.Errorf("failed to build base mux: %w", err)
Expand All @@ -358,6 +377,8 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig
reloadPersistentState: r.reloadPersistentState,
currentGraphMuxes: currentMuxes,
changes: response.Changes,
defaultClientTLS: defaultGRPCClientTLS,
perSubgraphTLS: perSubgraphGRPCTLS,
})
if err != nil {
return nil, fmt.Errorf("failed to build feature flag handler: %w", err)
Expand Down Expand Up @@ -578,6 +599,8 @@ func (s *graphServer) buildMultiGraphHandler(
EngineConfig: executionConfig.GetEngineConfig(),
ConfigSubgraphs: executionConfig.Subgraphs,
ReloadPersistentState: opts.reloadPersistentState,
defaultClientTLS: opts.defaultClientTLS,
perSubgraphTLS: opts.perSubgraphTLS,
})
if err != nil {
return nil, nil, fmt.Errorf("failed to build mux for feature flag '%s': %w", featureFlagName, err)
Expand Down Expand Up @@ -1385,7 +1408,15 @@ func (s *graphServer) buildGraphMux(
subgraphTippers[subgraph] = subgraphTransport
}

if err := s.setupConnector(s.graphServerCtx, opts.EngineConfig, opts.ConfigSubgraphs, telemetryAttExpressions, tracingAttExpressions); err != nil {
err = s.setupConnector(s.graphServerCtx, setupConnectorOpts{
config: opts.EngineConfig,
configSubgraphs: opts.ConfigSubgraphs,
telemetryAttributeExpressions: telemetryAttExpressions,
tracingAttributeExpressions: tracingAttExpressions,
defaultClientTLS: opts.defaultClientTLS,
perSubgraphTLS: opts.perSubgraphTLS,
})
if err != nil {
return nil, fmt.Errorf("failed to setup plugin host: %w", err)
Comment thread
dkorittki marked this conversation as resolved.
}

Expand Down Expand Up @@ -1850,24 +1881,27 @@ func (s *graphServer) buildGraphMux(
return gm, nil
}

func (s *graphServer) setupConnector(
ctx context.Context,
config *nodev1.EngineConfiguration,
configSubgraphs []*nodev1.Subgraph,
telemetryAttributeExpressions *attributeExpressions,
tracingAttributeExpressions *attributeExpressions,
) error {
type setupConnectorOpts struct {
config *nodev1.EngineConfiguration
configSubgraphs []*nodev1.Subgraph
telemetryAttributeExpressions *attributeExpressions
tracingAttributeExpressions *attributeExpressions
defaultClientTLS *tls.Config
perSubgraphTLS map[string]*tls.Config
}

func (s *graphServer) setupConnector(ctx context.Context, opts setupConnectorOpts) error {
s.connector = grpcconnector.NewConnector()

for _, dsConfig := range config.DatasourceConfigurations {
for _, dsConfig := range opts.config.DatasourceConfigurations {
grpcConfig := dsConfig.GetCustomGraphql().GetGrpc()
if grpcConfig == nil {
continue
}

var sg *nodev1.Subgraph

for _, subgraph := range configSubgraphs {
for _, subgraph := range opts.configSubgraphs {
if subgraph.Id == dsConfig.Id {
sg = subgraph
break
Expand All @@ -1880,9 +1914,18 @@ func (s *graphServer) setupConnector(

pluginConfig := grpcConfig.GetPlugin()
if pluginConfig == nil {
// Resolve per-subgraph gRPC TLS config, falling back to the default.
var grpcTLS *tls.Config
if sgTLS, ok := opts.perSubgraphTLS[sg.Name]; ok {
grpcTLS = sgTLS
} else {
grpcTLS = opts.defaultClientTLS
}

remoteProvider, err := grpcremote.NewRemoteGRPCProvider(grpcremote.RemoteGRPCProviderConfig{
Logger: s.logger,
Endpoint: sg.RoutingUrl,
Logger: s.logger,
Endpoint: sg.RoutingUrl,
TLSConfig: grpcTLS,
})
if err != nil {
return fmt.Errorf("failed to create standalone plugin for subgraph %s: %w", dsConfig.Id, err)
Expand Down Expand Up @@ -1911,8 +1954,8 @@ func (s *graphServer) setupConnector(
tracer := s.tracerProvider.Tracer("wundergraph/cosmo/router/engine/grpc", oteltrace.WithInstrumentationVersion("0.0.1"))

getTraceAttributes := CreateGRPCTraceGetter(
telemetryAttributeExpressions,
tracingAttributeExpressions,
opts.telemetryAttributeExpressions,
opts.tracingAttributeExpressions,
s.spanNameFormatter,
)

Expand Down
28 changes: 19 additions & 9 deletions router/core/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

// buildTLSClientConfig creates a *tls.Config from a TLSClientCertConfiguration.
func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Config, error) {
func buildTLSClientConfig(clientCfg config.TLSClientCertConfiguration) (*tls.Config, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: clientCfg.InsecureSkipCaVerification,
}
Expand Down Expand Up @@ -43,10 +43,18 @@ func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Co
return tlsConfig, nil
}

type clientTLSConfiguration interface {
GetAll() config.TLSClientCertConfiguration
GetSubgraphs() map[string]config.TLSClientCertConfiguration
Enabled() bool
}

// buildSubgraphTLSConfigs builds the default and per-subgraph TLS configs from raw configuration.
// Returns (defaultClientTLS, perSubgraphTLS, error).
func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfiguration) (*tls.Config, map[string]*tls.Config, error) {
hasAll := (cfg.All.CertFile != "" && cfg.All.KeyFile != "") || cfg.All.CaFile != "" || cfg.All.InsecureSkipCaVerification
func buildSubgraphTLSConfigs[K clientTLSConfiguration](logger *zap.Logger, cfg K) (
*tls.Config, map[string]*tls.Config, error) {
hasAll := (cfg.GetAll().CertFile != "" && cfg.GetAll().KeyFile != "") ||
cfg.GetAll().CaFile != "" || cfg.GetAll().InsecureSkipCaVerification

// If no global TLS config is provided and there are no subgraph specific TLS configs
if !cfg.Enabled() {
Expand All @@ -57,24 +65,26 @@ func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfigurat
perSubgraphTLS := make(map[string]*tls.Config)

if hasAll {
if cfg.All.InsecureSkipCaVerification {
logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. This is not recommended for production environments.")
if cfg.GetAll().InsecureSkipCaVerification {
logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. " +
"This is not recommended for production environments.")
}

defaultTLS, err := buildTLSClientConfig(&cfg.All)
defaultTLS, err := buildTLSClientConfig(cfg.GetAll())
if err != nil {
return nil, nil, fmt.Errorf("failed to build global subgraph TLS config: %w", err)
}
defaultClientTLS = defaultTLS
}

for name, sgCfg := range cfg.Subgraphs {
for name, sgCfg := range cfg.GetSubgraphs() {
if sgCfg.InsecureSkipCaVerification {
logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from global config. This is not recommended for production environments.",
logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from "+
"global config. This is not recommended for production environments.",
zap.String("subgraph", name))
Comment thread
dkorittki marked this conversation as resolved.
}

subgraphTLS, err := buildTLSClientConfig(&sgCfg)
subgraphTLS, err := buildTLSClientConfig(sgCfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to build TLS config for subgraph %q: %w", name, err)
}
Expand Down
Loading
Loading