Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
615 changes: 615 additions & 0 deletions router-tests/security/subgraph_grpc_mtls_test.go

Large diffs are not rendered by default.

28 changes: 14 additions & 14 deletions router-tests/security/subgraph_mtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ import (

var (
clientTLSAllInsecureSkipVerify = config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
},
},
}
clientTLSEmployeesInsecureSkipVerifyWithTestdataCert = config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
Subgraphs: map[string]config.TLSClientCertConfiguration{
"employees": {
InsecureSkipCaVerification: true,
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: false,
},
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
CaFile: certPath,
},
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
CertFile: "../testdata/tls/cert.pem",
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
CertFile: "../testdata/tls/cert-2.pem",
Expand Down Expand Up @@ -262,7 +262,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
CertFile: "../testdata/tls/cert-2.pem",
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
CertFile: "../testdata/tls/cert.pem",
Expand Down Expand Up @@ -339,7 +339,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
CertFile: "../testdata/tls/cert.pem",
Expand Down Expand Up @@ -388,7 +388,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
CaFile: certPath,
},
Expand Down Expand Up @@ -424,7 +424,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
Subgraphs: map[string]config.TLSClientCertConfiguration{
"employees": {
CaFile: certPath,
Expand Down Expand Up @@ -460,7 +460,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
InsecureSkipCaVerification: true,
},
Expand Down Expand Up @@ -511,7 +511,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
All: config.TLSClientCertConfiguration{
CaFile: certPath,
CertFile: "../testdata/tls/cert.pem",
Expand Down Expand Up @@ -555,7 +555,7 @@ func TestSubgraphMTLS(t *testing.T) {
},
RouterOptions: []core.Option{
core.WithTLSConfig(config.TLSConfiguration{
Client: config.ClientTLSConfiguration{
Client: config.HTTPClientTLSConfiguration{
Subgraphs: map[string]config.TLSClientCertConfiguration{
"employees": {
CaFile: certPath,
Expand Down
14 changes: 11 additions & 3 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/encoding/protojson"

"github.com/wundergraph/cosmo/demo/pkg/subgraphs"
Expand Down Expand Up @@ -422,6 +423,10 @@ type SubgraphConfig struct {
// TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS()
// instead of Start(). This is useful for testing mTLS between the router and subgraphs.
TLSConfig *tls.Config

// GRPCTLSConfig enables TLS on the gRPC subgraph server. When set, the gRPC server
// uses TLS credentials instead of plain connections.
GRPCTLSConfig *tls.Config
}

type LogObservationConfig struct {
Expand Down Expand Up @@ -630,7 +635,7 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) {
)

if cfg.EnableGRPC {
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor)
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig)
}

replacements := map[string]string{
Expand Down Expand Up @@ -1068,7 +1073,7 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) {
)

if cfg.EnableGRPC {
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor)
projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig)
}

replacements := map[string]string{
Expand Down Expand Up @@ -1766,7 +1771,7 @@ func makeSubgraphTestServer(_ testing.TB, handler http.Handler, tlsConfig *tls.C
return s
}

func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor) (*grpc.Server, string) {
func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor, tlsConfig *tls.Config) (*grpc.Server, string) {
t.Helper()

// We could use freeport here, but it is easy to use ephemeral port and get the endpoint
Expand All @@ -1781,6 +1786,9 @@ func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interce
if interceptor != nil {
opts = append(opts, grpc.ChainUnaryInterceptor(interceptor))
}
if tlsConfig != nil {
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}

s := grpc.NewServer(opts...)
s.RegisterService(sd, service)
Expand Down
73 changes: 58 additions & 15 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -120,6 +121,8 @@ type BuildGraphMuxOptions struct {
ConfigSubgraphs []*nodev1.Subgraph
RoutingUrlGroupings map[string]map[string]bool
ReloadPersistentState *ReloadPersistentState
defaultClientTLS *tls.Config
perSubgraphTLS map[string]*tls.Config
}

func (b BuildGraphMuxOptions) IsBaseGraph() bool {
Expand All @@ -133,6 +136,8 @@ type buildMultiGraphHandlerOptions struct {
reloadPersistentState *ReloadPersistentState
currentGraphMuxes map[string]*graphMux
changes *routerconfig.Changes
defaultClientTLS *tls.Config
perSubgraphTLS map[string]*tls.Config
}

// reusedGraphMux holds a graph mux from the previous server that the new server
Expand Down Expand Up @@ -163,11 +168,23 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig
}

// Build subgraph client TLS configs (mTLS for outbound subgraph connections)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, &r.tls.settings.Client)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(
r.logger,
&r.tls.settings.Client,
)
if err != nil {
return nil, fmt.Errorf("could not build subgraph client TLS config: %w", err)
}

// Build gRPC subgraph client TLS configs
defaultGRPCClientTLS, perSubgraphGRPCTLS, err := buildSubgraphTLSConfigs(
r.logger,
&r.tls.settings.ClientGRPC,
)
if err != nil {
return nil, fmt.Errorf("could not build gRPC subgraph client TLS config: %w", err)
}

// Base transport
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS)

Expand Down Expand Up @@ -337,6 +354,8 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig
ConfigSubgraphs: response.Config.GetSubgraphs(),
RoutingUrlGroupings: routingUrlGroupings,
ReloadPersistentState: r.reloadPersistentState,
defaultClientTLS: defaultGRPCClientTLS,
perSubgraphTLS: perSubgraphGRPCTLS,
Comment thread
dkorittki marked this conversation as resolved.
})
if err != nil {
return nil, fmt.Errorf("failed to build base mux: %w", err)
Expand All @@ -358,6 +377,8 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig
reloadPersistentState: r.reloadPersistentState,
currentGraphMuxes: currentMuxes,
changes: response.Changes,
defaultClientTLS: defaultGRPCClientTLS,
perSubgraphTLS: perSubgraphGRPCTLS,
})
if err != nil {
return nil, fmt.Errorf("failed to build feature flag handler: %w", err)
Expand Down Expand Up @@ -578,6 +599,8 @@ func (s *graphServer) buildMultiGraphHandler(
EngineConfig: executionConfig.GetEngineConfig(),
ConfigSubgraphs: executionConfig.Subgraphs,
ReloadPersistentState: opts.reloadPersistentState,
defaultClientTLS: opts.defaultClientTLS,
perSubgraphTLS: opts.perSubgraphTLS,
})
if err != nil {
return nil, nil, fmt.Errorf("failed to build mux for feature flag '%s': %w", featureFlagName, err)
Expand Down Expand Up @@ -1385,7 +1408,15 @@ func (s *graphServer) buildGraphMux(
subgraphTippers[subgraph] = subgraphTransport
}

if err := s.setupConnector(s.graphServerCtx, opts.EngineConfig, opts.ConfigSubgraphs, telemetryAttExpressions, tracingAttExpressions); err != nil {
err = s.setupConnector(s.graphServerCtx, setupConnectorOpts{
config: opts.EngineConfig,
configSubgraphs: opts.ConfigSubgraphs,
telemetryAttributeExpressions: telemetryAttExpressions,
tracingAttributeExpressions: tracingAttExpressions,
defaultClientTLS: opts.defaultClientTLS,
perSubgraphTLS: opts.perSubgraphTLS,
})
if err != nil {
return nil, fmt.Errorf("failed to setup plugin host: %w", err)
Comment thread
dkorittki marked this conversation as resolved.
}

Expand Down Expand Up @@ -1850,24 +1881,27 @@ func (s *graphServer) buildGraphMux(
return gm, nil
}

func (s *graphServer) setupConnector(
ctx context.Context,
config *nodev1.EngineConfiguration,
configSubgraphs []*nodev1.Subgraph,
telemetryAttributeExpressions *attributeExpressions,
tracingAttributeExpressions *attributeExpressions,
) error {
type setupConnectorOpts struct {
config *nodev1.EngineConfiguration
configSubgraphs []*nodev1.Subgraph
telemetryAttributeExpressions *attributeExpressions
tracingAttributeExpressions *attributeExpressions
defaultClientTLS *tls.Config
perSubgraphTLS map[string]*tls.Config
}

func (s *graphServer) setupConnector(ctx context.Context, opts setupConnectorOpts) error {
s.connector = grpcconnector.NewConnector()

for _, dsConfig := range config.DatasourceConfigurations {
for _, dsConfig := range opts.config.DatasourceConfigurations {
grpcConfig := dsConfig.GetCustomGraphql().GetGrpc()
if grpcConfig == nil {
continue
}

var sg *nodev1.Subgraph

for _, subgraph := range configSubgraphs {
for _, subgraph := range opts.configSubgraphs {
if subgraph.Id == dsConfig.Id {
sg = subgraph
break
Expand All @@ -1880,9 +1914,18 @@ func (s *graphServer) setupConnector(

pluginConfig := grpcConfig.GetPlugin()
if pluginConfig == nil {
// Resolve per-subgraph gRPC TLS config, falling back to the default.
var grpcTLS *tls.Config
if sgTLS, ok := opts.perSubgraphTLS[sg.Name]; ok {
grpcTLS = sgTLS
} else {
grpcTLS = opts.defaultClientTLS
}

remoteProvider, err := grpcremote.NewRemoteGRPCProvider(grpcremote.RemoteGRPCProviderConfig{
Logger: s.logger,
Endpoint: sg.RoutingUrl,
Logger: s.logger,
Endpoint: sg.RoutingUrl,
TLSConfig: grpcTLS,
})
if err != nil {
return fmt.Errorf("failed to create standalone plugin for subgraph %s: %w", dsConfig.Id, err)
Expand Down Expand Up @@ -1911,8 +1954,8 @@ func (s *graphServer) setupConnector(
tracer := s.tracerProvider.Tracer("wundergraph/cosmo/router/engine/grpc", oteltrace.WithInstrumentationVersion("0.0.1"))

getTraceAttributes := CreateGRPCTraceGetter(
telemetryAttributeExpressions,
tracingAttributeExpressions,
opts.telemetryAttributeExpressions,
opts.tracingAttributeExpressions,
s.spanNameFormatter,
)

Expand Down
Loading
Loading