diff --git a/router-tests/security/subgraph_grpc_mtls_test.go b/router-tests/security/subgraph_grpc_mtls_test.go new file mode 100644 index 0000000000..d4fabcc629 --- /dev/null +++ b/router-tests/security/subgraph_grpc_mtls_test.go @@ -0,0 +1,615 @@ +package integration + +import ( + "crypto/tls" + "testing" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +const projectsExpectedData = `{"data":{"projects":[{"id":"1","name":"Cloud Migration Overhaul"},{"id":"2","name":"Microservices Revolution"},{"id":"3","name":"AI-Powered Analytics"},{"id":"4","name":"DevOps Transformation"},{"id":"5","name":"Security Overhaul"},{"id":"6","name":"Mobile App Development"},{"id":"7","name":"Data Lake Implementation"}]}}` + +func TestSubgraphGRPCmTLS(t *testing.T) { + t.Parallel() + + t.Run("InsecureSkipVerify", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("connects when enabled", func(t *testing.T) { + t.Parallel() + // Router skips cert verification (InsecureSkipCaVerification: true) and successfully queries a TLS-only gRPC subgraph. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, false) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + + t.Run("fails when disabled", func(t *testing.T) { + t.Parallel() + // Router has no TLS config at all for a TLS-only gRPC subgraph, connection fails. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, false) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: false, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, "Failed to fetch from Subgraph") + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("overrides global CaFile", func(t *testing.T) { + t.Parallel() + // Global config points to a wrong CA (would fail); per-subgraph InsecureSkipCaVerification overrides it and succeeds. + + // Global config uses CaFile that does NOT match the gRPC server cert, + // so it would fail. Per-subgraph overrides with InsecureSkipVerify. + wrongCertPath, _ := generateServerCert(t) + serverTLS, _ := grpcSubgraphTLSServerConfig(t, false) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CaFile: wrongCertPath, + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + InsecureSkipCaVerification: true, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + }) + }) + + t.Run("Client certificate", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("presents correct cert to mTLS subgraph", func(t *testing.T) { + t.Parallel() + // Router presents the right client cert, mTLS subgraph accepts it. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + + t.Run("fails without client cert", func(t *testing.T) { + t.Parallel() + // Subgraph requires a client cert, router sends none, connection fails. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, "Failed to fetch from Subgraph") + }) + }) + + t.Run("fails with wrong client cert", func(t *testing.T) { + t.Parallel() + // Router presents a cert signed by the wrong CA, subgraph rejects it. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert-2.pem", + KeyFile: "../testdata/tls/key-2.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, "Failed to fetch from Subgraph") + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("correct config without global", func(t *testing.T) { + t.Parallel() + // No global config, per-subgraph config with correct client cert succeeds. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + + t.Run("correct config overrides incorrect global", func(t *testing.T) { + t.Parallel() + // Global has wrong cert, per-subgraph has correct cert, per-subgraph wins and succeeds. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert-2.pem", + KeyFile: "../testdata/tls/key-2.pem", + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + + t.Run("incorrect config overrides correct global", func(t *testing.T) { + t.Parallel() + // Global has correct cert, per-subgraph has wrong cert, per-subgraph wins and fails. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert-2.pem", + KeyFile: "../testdata/tls/key-2.pem", + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, "Failed to fetch from Subgraph") + }) + }) + + t.Run("override without cert fails even when global has cert", func(t *testing.T) { + t.Parallel() + // Per-subgraph config with no cert fully replaces global (no field inheritance), mTLS fails. + + serverTLS, _ := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + InsecureSkipCaVerification: true, + // NO CertFile/KeyFile — proves fields are NOT inherited from All + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, "Failed to fetch from Subgraph") + }) + }) + }) + }) + + t.Run("CaFile", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("trusts gRPC subgraph server", func(t *testing.T) { + t.Parallel() + // Router's CaFile matches the server's self-signed cert, connection is verified and succeeds. + + serverTLS, certPath := grpcSubgraphTLSServerConfig(t, false) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CaFile: certPath, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("trusts gRPC subgraph server without global config", func(t *testing.T) { + t.Parallel() + // No global config, per-subgraph CaFile trusts the server cert, succeeds. + + serverTLS, certPath := grpcSubgraphTLSServerConfig(t, false) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + CaFile: certPath, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + + t.Run("overrides global InsecureSkipVerify with proper verification", func(t *testing.T) { + t.Parallel() + // Global uses InsecureSkipCaVerification, per-subgraph replaces it with a proper CaFile check — proves per-subgraph can be more secure than global. + + serverTLS, certPath := grpcSubgraphTLSServerConfig(t, false) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + CaFile: certPath, + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + }) + }) + + t.Run("Full mTLS", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("with CaFile and client certificate", func(t *testing.T) { + t.Parallel() + // Production-like: router verifies the server cert via CaFile and presents a client cert for mutual authentication, both sides verified. + + serverTLS, certPath := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CaFile: certPath, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("with CaFile and client certificate without global config", func(t *testing.T) { + t.Parallel() + // Same full mTLS scenario but configured only at the per-subgraph level with no global config. + + serverTLS, certPath := grpcSubgraphTLSServerConfig(t, true) + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithGRPCJSONTemplate, + EnableGRPC: true, + Subgraphs: testenv.SubgraphsConfig{ + Projects: testenv.SubgraphConfig{ + GRPCTLSConfig: serverTLS, + }, + }, + RouterOptions: []core.Option{ + core.WithTLSConfig(config.TLSConfiguration{ + ClientGRPC: config.GRPCClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "projects": { + CaFile: certPath, + CertFile: "../testdata/tls/cert.pem", + KeyFile: "../testdata/tls/key.pem", + }, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { projects { id name } }`, + }) + require.JSONEq(t, projectsExpectedData, res.Body) + }) + }) + }) + }) +} + +// grpcSubgraphTLSServerConfig creates a tls.Config for a gRPC subgraph test server. +// It generates a self-signed certificate valid for 127.0.0.1 and returns both the +// TLS config and the path to the cert PEM file (for use as CaFile on the router). +// If requireClientCert is true, the server requires the router to present a valid +// client certificate signed by the CA in testdata/tls/cert.pem. +func grpcSubgraphTLSServerConfig(t *testing.T, requireClientCert bool) (serverTLSConfig *tls.Config, certPath string) { + t.Helper() + + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + cfg := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + + if requireClientCert { + caPool := loadSubgraphMTLSCACertPool(t, "../testdata/tls/cert.pem") + cfg.ClientCAs = caPool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + return cfg, certPath +} diff --git a/router-tests/security/subgraph_mtls_test.go b/router-tests/security/subgraph_mtls_test.go index d82d6a7b52..a66931a3de 100644 --- a/router-tests/security/subgraph_mtls_test.go +++ b/router-tests/security/subgraph_mtls_test.go @@ -25,14 +25,14 @@ import ( var ( clientTLSAllInsecureSkipVerify = config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, }, }, } clientTLSEmployeesInsecureSkipVerifyWithTestdataCert = config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ Subgraphs: map[string]config.TLSClientCertConfiguration{ "employees": { InsecureSkipCaVerification: true, @@ -84,7 +84,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: false, }, @@ -119,7 +119,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ CaFile: certPath, }, @@ -158,7 +158,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, CertFile: "../testdata/tls/cert.pem", @@ -207,7 +207,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, CertFile: "../testdata/tls/cert-2.pem", @@ -262,7 +262,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, CertFile: "../testdata/tls/cert-2.pem", @@ -299,7 +299,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, CertFile: "../testdata/tls/cert.pem", @@ -339,7 +339,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, CertFile: "../testdata/tls/cert.pem", @@ -388,7 +388,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ CaFile: certPath, }, @@ -424,7 +424,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ Subgraphs: map[string]config.TLSClientCertConfiguration{ "employees": { CaFile: certPath, @@ -460,7 +460,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, }, @@ -511,7 +511,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ All: config.TLSClientCertConfiguration{ CaFile: certPath, CertFile: "../testdata/tls/cert.pem", @@ -555,7 +555,7 @@ func TestSubgraphMTLS(t *testing.T) { }, RouterOptions: []core.Option{ core.WithTLSConfig(config.TLSConfiguration{ - Client: config.ClientTLSConfiguration{ + Client: config.HTTPClientTLSConfiguration{ Subgraphs: map[string]config.TLSClientCertConfiguration{ "employees": { CaFile: certPath, diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 1c34137589..71e11e48fe 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -52,6 +52,7 @@ import ( "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/protobuf/encoding/protojson" "github.com/wundergraph/cosmo/demo/pkg/subgraphs" @@ -422,6 +423,10 @@ type SubgraphConfig struct { // TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS() // instead of Start(). This is useful for testing mTLS between the router and subgraphs. TLSConfig *tls.Config + + // GRPCTLSConfig enables TLS on the gRPC subgraph server. When set, the gRPC server + // uses TLS credentials instead of plain connections. + GRPCTLSConfig *tls.Config } type LogObservationConfig struct { @@ -630,7 +635,7 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) { ) if cfg.EnableGRPC { - projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor) + projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig) } replacements := map[string]string{ @@ -1068,7 +1073,7 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) { ) if cfg.EnableGRPC { - projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor) + projectServer, endpoint = makeSafeGRPCServer(t, &projects.ProjectsService_ServiceDesc, &service.ProjectsService{}, cfg.Subgraphs.Projects.GRPCInterceptor, cfg.Subgraphs.Projects.GRPCTLSConfig) } replacements := map[string]string{ @@ -1766,7 +1771,7 @@ func makeSubgraphTestServer(_ testing.TB, handler http.Handler, tlsConfig *tls.C return s } -func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor) (*grpc.Server, string) { +func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interceptor grpc.UnaryServerInterceptor, tlsConfig *tls.Config) (*grpc.Server, string) { t.Helper() // We could use freeport here, but it is easy to use ephemeral port and get the endpoint @@ -1781,6 +1786,9 @@ func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any, interce if interceptor != nil { opts = append(opts, grpc.ChainUnaryInterceptor(interceptor)) } + if tlsConfig != nil { + opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig))) + } s := grpc.NewServer(opts...) s.RegisterService(sd, service) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 45c7179c9e..b4ceecf7ca 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/ecdsa" + "crypto/tls" "errors" "fmt" "net/http" @@ -120,6 +121,8 @@ type BuildGraphMuxOptions struct { ConfigSubgraphs []*nodev1.Subgraph RoutingUrlGroupings map[string]map[string]bool ReloadPersistentState *ReloadPersistentState + defaultClientTLS *tls.Config + perSubgraphTLS map[string]*tls.Config } func (b BuildGraphMuxOptions) IsBaseGraph() bool { @@ -133,6 +136,8 @@ type buildMultiGraphHandlerOptions struct { reloadPersistentState *ReloadPersistentState currentGraphMuxes map[string]*graphMux changes *routerconfig.Changes + defaultClientTLS *tls.Config + perSubgraphTLS map[string]*tls.Config } // reusedGraphMux holds a graph mux from the previous server that the new server @@ -163,11 +168,23 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig } // Build subgraph client TLS configs (mTLS for outbound subgraph connections) - defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, &r.tls.settings.Client) + defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs( + r.logger, + &r.tls.settings.Client, + ) if err != nil { return nil, fmt.Errorf("could not build subgraph client TLS config: %w", err) } + // Build gRPC subgraph client TLS configs + defaultGRPCClientTLS, perSubgraphGRPCTLS, err := buildSubgraphTLSConfigs( + r.logger, + &r.tls.settings.ClientGRPC, + ) + if err != nil { + return nil, fmt.Errorf("could not build gRPC subgraph client TLS config: %w", err) + } + // Base transport baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS) @@ -337,6 +354,8 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig ConfigSubgraphs: response.Config.GetSubgraphs(), RoutingUrlGroupings: routingUrlGroupings, ReloadPersistentState: r.reloadPersistentState, + defaultClientTLS: defaultGRPCClientTLS, + perSubgraphTLS: perSubgraphGRPCTLS, }) if err != nil { return nil, fmt.Errorf("failed to build base mux: %w", err) @@ -358,6 +377,8 @@ func newGraphServer(routerCtx context.Context, r *Router, response *routerconfig reloadPersistentState: r.reloadPersistentState, currentGraphMuxes: currentMuxes, changes: response.Changes, + defaultClientTLS: defaultGRPCClientTLS, + perSubgraphTLS: perSubgraphGRPCTLS, }) if err != nil { return nil, fmt.Errorf("failed to build feature flag handler: %w", err) @@ -578,6 +599,8 @@ func (s *graphServer) buildMultiGraphHandler( EngineConfig: executionConfig.GetEngineConfig(), ConfigSubgraphs: executionConfig.Subgraphs, ReloadPersistentState: opts.reloadPersistentState, + defaultClientTLS: opts.defaultClientTLS, + perSubgraphTLS: opts.perSubgraphTLS, }) if err != nil { return nil, nil, fmt.Errorf("failed to build mux for feature flag '%s': %w", featureFlagName, err) @@ -1385,7 +1408,15 @@ func (s *graphServer) buildGraphMux( subgraphTippers[subgraph] = subgraphTransport } - if err := s.setupConnector(s.graphServerCtx, opts.EngineConfig, opts.ConfigSubgraphs, telemetryAttExpressions, tracingAttExpressions); err != nil { + err = s.setupConnector(s.graphServerCtx, setupConnectorOpts{ + config: opts.EngineConfig, + configSubgraphs: opts.ConfigSubgraphs, + telemetryAttributeExpressions: telemetryAttExpressions, + tracingAttributeExpressions: tracingAttExpressions, + defaultClientTLS: opts.defaultClientTLS, + perSubgraphTLS: opts.perSubgraphTLS, + }) + if err != nil { return nil, fmt.Errorf("failed to setup plugin host: %w", err) } @@ -1850,16 +1881,19 @@ func (s *graphServer) buildGraphMux( return gm, nil } -func (s *graphServer) setupConnector( - ctx context.Context, - config *nodev1.EngineConfiguration, - configSubgraphs []*nodev1.Subgraph, - telemetryAttributeExpressions *attributeExpressions, - tracingAttributeExpressions *attributeExpressions, -) error { +type setupConnectorOpts struct { + config *nodev1.EngineConfiguration + configSubgraphs []*nodev1.Subgraph + telemetryAttributeExpressions *attributeExpressions + tracingAttributeExpressions *attributeExpressions + defaultClientTLS *tls.Config + perSubgraphTLS map[string]*tls.Config +} + +func (s *graphServer) setupConnector(ctx context.Context, opts setupConnectorOpts) error { s.connector = grpcconnector.NewConnector() - for _, dsConfig := range config.DatasourceConfigurations { + for _, dsConfig := range opts.config.DatasourceConfigurations { grpcConfig := dsConfig.GetCustomGraphql().GetGrpc() if grpcConfig == nil { continue @@ -1867,7 +1901,7 @@ func (s *graphServer) setupConnector( var sg *nodev1.Subgraph - for _, subgraph := range configSubgraphs { + for _, subgraph := range opts.configSubgraphs { if subgraph.Id == dsConfig.Id { sg = subgraph break @@ -1880,9 +1914,18 @@ func (s *graphServer) setupConnector( pluginConfig := grpcConfig.GetPlugin() if pluginConfig == nil { + // Resolve per-subgraph gRPC TLS config, falling back to the default. + var grpcTLS *tls.Config + if sgTLS, ok := opts.perSubgraphTLS[sg.Name]; ok { + grpcTLS = sgTLS + } else { + grpcTLS = opts.defaultClientTLS + } + remoteProvider, err := grpcremote.NewRemoteGRPCProvider(grpcremote.RemoteGRPCProviderConfig{ - Logger: s.logger, - Endpoint: sg.RoutingUrl, + Logger: s.logger, + Endpoint: sg.RoutingUrl, + TLSConfig: grpcTLS, }) if err != nil { return fmt.Errorf("failed to create standalone plugin for subgraph %s: %w", dsConfig.Id, err) @@ -1911,8 +1954,8 @@ func (s *graphServer) setupConnector( tracer := s.tracerProvider.Tracer("wundergraph/cosmo/router/engine/grpc", oteltrace.WithInstrumentationVersion("0.0.1")) getTraceAttributes := CreateGRPCTraceGetter( - telemetryAttributeExpressions, - tracingAttributeExpressions, + opts.telemetryAttributeExpressions, + opts.tracingAttributeExpressions, s.spanNameFormatter, ) diff --git a/router/core/tls.go b/router/core/tls.go index 7f4d8a97e8..58a4794b32 100644 --- a/router/core/tls.go +++ b/router/core/tls.go @@ -13,7 +13,7 @@ import ( ) // buildTLSClientConfig creates a *tls.Config from a TLSClientCertConfiguration. -func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Config, error) { +func buildTLSClientConfig(clientCfg config.TLSClientCertConfiguration) (*tls.Config, error) { tlsConfig := &tls.Config{ InsecureSkipVerify: clientCfg.InsecureSkipCaVerification, } @@ -43,10 +43,18 @@ func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Co return tlsConfig, nil } +type clientTLSConfiguration interface { + GetAll() config.TLSClientCertConfiguration + GetSubgraphs() map[string]config.TLSClientCertConfiguration + Enabled() bool +} + // buildSubgraphTLSConfigs builds the default and per-subgraph TLS configs from raw configuration. // Returns (defaultClientTLS, perSubgraphTLS, error). -func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfiguration) (*tls.Config, map[string]*tls.Config, error) { - hasAll := (cfg.All.CertFile != "" && cfg.All.KeyFile != "") || cfg.All.CaFile != "" || cfg.All.InsecureSkipCaVerification +func buildSubgraphTLSConfigs[K clientTLSConfiguration](logger *zap.Logger, cfg K) ( + *tls.Config, map[string]*tls.Config, error) { + hasAll := (cfg.GetAll().CertFile != "" && cfg.GetAll().KeyFile != "") || + cfg.GetAll().CaFile != "" || cfg.GetAll().InsecureSkipCaVerification // If no global TLS config is provided and there are no subgraph specific TLS configs if !cfg.Enabled() { @@ -57,24 +65,26 @@ func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfigurat perSubgraphTLS := make(map[string]*tls.Config) if hasAll { - if cfg.All.InsecureSkipCaVerification { - logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. This is not recommended for production environments.") + if cfg.GetAll().InsecureSkipCaVerification { + logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. " + + "This is not recommended for production environments.") } - defaultTLS, err := buildTLSClientConfig(&cfg.All) + defaultTLS, err := buildTLSClientConfig(cfg.GetAll()) if err != nil { return nil, nil, fmt.Errorf("failed to build global subgraph TLS config: %w", err) } defaultClientTLS = defaultTLS } - for name, sgCfg := range cfg.Subgraphs { + for name, sgCfg := range cfg.GetSubgraphs() { if sgCfg.InsecureSkipCaVerification { - logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from global config. This is not recommended for production environments.", + logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from "+ + "global config. This is not recommended for production environments.", zap.String("subgraph", name)) } - subgraphTLS, err := buildTLSClientConfig(&sgCfg) + subgraphTLS, err := buildTLSClientConfig(sgCfg) if err != nil { return nil, nil, fmt.Errorf("failed to build TLS config for subgraph %q: %w", name, err) } diff --git a/router/core/tls_test.go b/router/core/tls_test.go index d478e58b3f..8ca7d00395 100644 --- a/router/core/tls_test.go +++ b/router/core/tls_test.go @@ -7,15 +7,16 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "go.uber.org/zap/zaptest/observer" "math/big" "os" "path/filepath" "testing" "time" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/config" ) @@ -26,7 +27,7 @@ func TestBuildTLSClientConfig(t *testing.T) { t.Run("returns config with insecure_skip_ca_verification only", func(t *testing.T) { t.Parallel() - tlsCfg, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + tlsCfg, err := buildTLSClientConfig(config.TLSClientCertConfiguration{ InsecureSkipCaVerification: true, }) @@ -42,7 +43,7 @@ func TestBuildTLSClientConfig(t *testing.T) { certPath, keyPath := generateTestCert(t, "client") - tlsCfg, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + tlsCfg, err := buildTLSClientConfig(config.TLSClientCertConfiguration{ CertFile: certPath, KeyFile: keyPath, }) @@ -57,7 +58,7 @@ func TestBuildTLSClientConfig(t *testing.T) { certPath, _ := generateTestCert(t, "ca") - tlsCfg, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + tlsCfg, err := buildTLSClientConfig(config.TLSClientCertConfiguration{ CaFile: certPath, }) require.NoError(t, err) @@ -68,7 +69,7 @@ func TestBuildTLSClientConfig(t *testing.T) { t.Run("errors on invalid cert path", func(t *testing.T) { t.Parallel() - _, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + _, err := buildTLSClientConfig(config.TLSClientCertConfiguration{ CertFile: "/nonexistent/cert.pem", KeyFile: "/nonexistent/key.pem", }) @@ -79,7 +80,7 @@ func TestBuildTLSClientConfig(t *testing.T) { t.Run("errors on invalid CA path", func(t *testing.T) { t.Parallel() - _, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + _, err := buildTLSClientConfig(config.TLSClientCertConfiguration{ CaFile: "/nonexistent/ca.pem", }) require.Error(t, err) @@ -89,8 +90,7 @@ func TestBuildTLSClientConfig(t *testing.T) { t.Run("returns nil when no TLS configured", func(t *testing.T) { t.Parallel() - cfg := &config.ClientTLSConfiguration{} - defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), &config.HTTPClientTLSConfiguration{}) require.NoError(t, err) require.Nil(t, defaultTLS) require.Nil(t, perSubgraphTLS) @@ -102,15 +102,13 @@ func TestBuildTLSClientConfig(t *testing.T) { certPath, keyPath := generateTestCert(t, "client") caPath, _ := generateTestCert(t, "ca") - cfg := &config.ClientTLSConfiguration{ - All: config.TLSClientCertConfiguration{ - CertFile: certPath, - KeyFile: keyPath, - CaFile: caPath, - }, + all := config.TLSClientCertConfiguration{ + CertFile: certPath, + KeyFile: keyPath, + CaFile: caPath, } - defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), &config.HTTPClientTLSConfiguration{All: all}) require.NoError(t, err) require.NotNil(t, defaultTLS) require.Len(t, defaultTLS.Certificates, 1) @@ -123,16 +121,14 @@ func TestBuildTLSClientConfig(t *testing.T) { certPath, keyPath := generateTestCert(t, "products") - cfg := &config.ClientTLSConfiguration{ - Subgraphs: map[string]config.TLSClientCertConfiguration{ - "products": { - CertFile: certPath, - KeyFile: keyPath, - }, + subgraphs := map[string]config.TLSClientCertConfiguration{ + "products": { + CertFile: certPath, + KeyFile: keyPath, }, } - defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), &config.HTTPClientTLSConfiguration{Subgraphs: subgraphs}) require.NoError(t, err) require.Nil(t, defaultTLS) require.Contains(t, perSubgraphTLS, "products") @@ -145,20 +141,18 @@ func TestBuildTLSClientConfig(t *testing.T) { globalCert, globalKey := generateTestCert(t, "global") productsCert, productsKey := generateTestCert(t, "products") - cfg := &config.ClientTLSConfiguration{ - All: config.TLSClientCertConfiguration{ - CertFile: globalCert, - KeyFile: globalKey, - }, - Subgraphs: map[string]config.TLSClientCertConfiguration{ - "products": { - CertFile: productsCert, - KeyFile: productsKey, - }, + all := config.TLSClientCertConfiguration{ + CertFile: globalCert, + KeyFile: globalKey, + } + subgraphs := map[string]config.TLSClientCertConfiguration{ + "products": { + CertFile: productsCert, + KeyFile: productsKey, }, } - defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), &config.HTTPClientTLSConfiguration{All: all, Subgraphs: subgraphs}) require.NoError(t, err) require.NotNil(t, defaultTLS) require.Contains(t, perSubgraphTLS, "products") @@ -167,14 +161,12 @@ func TestBuildTLSClientConfig(t *testing.T) { t.Run("errors on invalid global cert", func(t *testing.T) { t.Parallel() - cfg := &config.ClientTLSConfiguration{ - All: config.TLSClientCertConfiguration{ - CertFile: "/nonexistent/cert.pem", - KeyFile: "/nonexistent/key.pem", - }, + all := config.TLSClientCertConfiguration{ + CertFile: "/nonexistent/cert.pem", + KeyFile: "/nonexistent/key.pem", } - _, _, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + _, _, err := buildSubgraphTLSConfigs(zap.NewNop(), &config.HTTPClientTLSConfiguration{All: all}) require.Error(t, err) require.EqualError(t, err, "failed to build global subgraph TLS config: failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory") }) @@ -182,16 +174,14 @@ func TestBuildTLSClientConfig(t *testing.T) { t.Run("errors on invalid per-subgraph cert", func(t *testing.T) { t.Parallel() - cfg := &config.ClientTLSConfiguration{ - Subgraphs: map[string]config.TLSClientCertConfiguration{ - "products": { - CertFile: "/nonexistent/cert.pem", - KeyFile: "/nonexistent/key.pem", - }, + subgraphs := map[string]config.TLSClientCertConfiguration{ + "products": { + CertFile: "/nonexistent/cert.pem", + KeyFile: "/nonexistent/key.pem", }, } - _, _, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + _, _, err := buildSubgraphTLSConfigs(zap.NewNop(), &config.HTTPClientTLSConfiguration{Subgraphs: subgraphs}) require.Error(t, err) require.EqualError(t, err, `failed to build TLS config for subgraph "products": failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory`) }) @@ -202,13 +192,11 @@ func TestBuildTLSClientConfig(t *testing.T) { core, logs := observer.New(zapcore.WarnLevel) logger := zap.New(core) - cfg := &config.ClientTLSConfiguration{ - All: config.TLSClientCertConfiguration{ - InsecureSkipCaVerification: true, - }, + all := config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, } - defaultTLS, _, err := buildSubgraphTLSConfigs(logger, cfg) + defaultTLS, _, err := buildSubgraphTLSConfigs(logger, &config.HTTPClientTLSConfiguration{All: all}) require.NoError(t, err) require.NotNil(t, defaultTLS) require.True(t, defaultTLS.InsecureSkipVerify) @@ -223,15 +211,13 @@ func TestBuildTLSClientConfig(t *testing.T) { core, logs := observer.New(zapcore.WarnLevel) logger := zap.New(core) - cfg := &config.ClientTLSConfiguration{ - Subgraphs: map[string]config.TLSClientCertConfiguration{ - "products": { - InsecureSkipCaVerification: true, - }, + subgraphs := map[string]config.TLSClientCertConfiguration{ + "products": { + InsecureSkipCaVerification: true, }, } - _, perSubgraphTLS, err := buildSubgraphTLSConfigs(logger, cfg) + _, perSubgraphTLS, err := buildSubgraphTLSConfigs(logger, &config.HTTPClientTLSConfiguration{Subgraphs: subgraphs}) require.NoError(t, err) require.Contains(t, perSubgraphTLS, "products") require.True(t, perSubgraphTLS["products"].InsecureSkipVerify) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 15dbb93f3c..131aadb1d2 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -902,26 +902,60 @@ type TLSClientCertConfiguration struct { InsecureSkipCaVerification bool `yaml:"insecure_skip_ca_verification" envDefault:"false" env:"INSECURE_SKIP_CA_VERIFICATION"` } -type ClientTLSConfiguration struct { +type HTTPClientTLSConfiguration struct { // All applies to all subgraph connections. All TLSClientCertConfiguration `yaml:"all" envPrefix:"TLS_CLIENT_ALL_"` // Subgraphs overrides per-subgraph TLS config. Key is the subgraph name. Subgraphs map[string]TLSClientCertConfiguration `yaml:"subgraphs,omitempty"` } -// Enabled returns true if anything in s has been configured.© -func (s ClientTLSConfiguration) Enabled() bool { - allConfigured := s.All.InsecureSkipCaVerification || - s.All.CaFile != "" || - s.All.KeyFile != "" || - s.All.CertFile != "" +func (c *HTTPClientTLSConfiguration) GetAll() TLSClientCertConfiguration { + return c.All +} + +func (c *HTTPClientTLSConfiguration) GetSubgraphs() map[string]TLSClientCertConfiguration { + return c.Subgraphs +} + +// Enabled returns true if anything in c has been configured. +func (c *HTTPClientTLSConfiguration) Enabled() bool { + allConfigured := c.All.InsecureSkipCaVerification || + c.All.CaFile != "" || + c.All.KeyFile != "" || + c.All.CertFile != "" + + return allConfigured || len(c.Subgraphs) > 0 +} + +type GRPCClientTLSConfiguration struct { + // All applies to all gRPC subgraph connections. + All TLSClientCertConfiguration `yaml:"all" envPrefix:"TLS_CLIENT_GRPC_ALL_"` + // Subgraphs overrides per-subgraph gRPC TLS config. Key is the subgraph name. + Subgraphs map[string]TLSClientCertConfiguration `yaml:"subgraphs,omitempty"` +} + +func (c *GRPCClientTLSConfiguration) GetAll() TLSClientCertConfiguration { + return c.All +} + +func (c *GRPCClientTLSConfiguration) GetSubgraphs() map[string]TLSClientCertConfiguration { + return c.Subgraphs +} + +// Enabled returns true if anything in c has been configured. +func (c *GRPCClientTLSConfiguration) Enabled() bool { + allConfigured := c.All.InsecureSkipCaVerification || + c.All.CaFile != "" || + c.All.KeyFile != "" || + c.All.CertFile != "" - return allConfigured || len(s.Subgraphs) > 0 + return allConfigured || len(c.Subgraphs) > 0 } type TLSConfiguration struct { - Server TLSServerConfiguration `yaml:"server"` - Client ClientTLSConfiguration `yaml:"client"` + Server TLSServerConfiguration `yaml:"server"` + Client HTTPClientTLSConfiguration `yaml:"client"` + ClientGRPC GRPCClientTLSConfiguration `yaml:"client_grpc"` } type SubgraphErrorPropagationMode string diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 28b9f91ced..7e79a29b92 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -564,6 +564,78 @@ } } } + }, + "client_grpc": { + "type": "object", + "description": "The TLS configuration for outbound gRPC connections from the router to gRPC subgraphs. Enables TLS/mTLS by presenting a client certificate when connecting to gRPC subgraphs.", + "additionalProperties": false, + "properties": { + "all": { + "type": "object", + "description": "TLS configuration applied to all gRPC subgraph connections.", + "additionalProperties": false, + "properties": { + "cert_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client certificate chain file. Used to authenticate the router to gRPC subgraphs. May include intermediate certificates." + }, + "key_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client private key file." + }, + "ca_file": { + "type": "string", + "format": "file-path", + "description": "The path to the CA certificate file. Used to verify gRPC subgraph server certificates. If not set, the system's root CAs are used." + }, + "insecure_skip_ca_verification": { + "type": "boolean", + "default": false, + "description": "Skip verification of the gRPC subgraph server certificate. Only use for development or testing." + } + }, + "dependencies": { + "cert_file": ["key_file"], + "key_file": ["cert_file"] + } + }, + "subgraphs": { + "type": "object", + "description": "Per-subgraph gRPC TLS configuration overrides. Each key is a subgraph name. Fully overrides the 'all' config for that subgraph.", + "additionalProperties": { + "type": "object", + "additionalProperties": false, + "properties": { + "cert_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client certificate chain file for this gRPC subgraph. May include intermediate certificates." + }, + "key_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client private key file for this gRPC subgraph." + }, + "ca_file": { + "type": "string", + "format": "file-path", + "description": "The path to the CA certificate file for verifying this gRPC subgraph's server certificate." + }, + "insecure_skip_ca_verification": { + "type": "boolean", + "default": false, + "description": "Skip verification of this gRPC subgraph's server certificate. Only use for development or testing." + } + }, + "dependencies": { + "cert_file": ["key_file"], + "key_file": ["cert_file"] + } + } + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index d1ff08a786..5acec9ec49 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -49,6 +49,18 @@ tls: key_file: '/path/to/products-client.key' ca_file: '/path/to/products-ca.crt' insecure_skip_ca_verification: false + client_grpc: + all: + cert_file: '/path/to/grpc-client.crt' + key_file: '/path/to/grpc-client.key' + ca_file: '/path/to/grpc-ca.crt' + insecure_skip_ca_verification: false + subgraphs: + my-grpc-subgraph: + cert_file: '/path/to/my-grpc-subgraph-client.crt' + key_file: '/path/to/my-grpc-subgraph-client.key' + ca_file: '/path/to/my-grpc-subgraph-ca.crt' + insecure_skip_ca_verification: false instance_id: '' graphql_metrics: enabled: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 3edcd11623..d444f2bb4c 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -140,6 +140,15 @@ "InsecureSkipCaVerification": false }, "Subgraphs": null + }, + "ClientGRPC": { + "All": { + "CertFile": "", + "KeyFile": "", + "CaFile": "", + "InsecureSkipCaVerification": false + }, + "Subgraphs": null } }, "CacheControl": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 14ae1e3529..303971ee25 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -180,6 +180,22 @@ "InsecureSkipCaVerification": false } } + }, + "ClientGRPC": { + "All": { + "CertFile": "/path/to/grpc-client.crt", + "KeyFile": "/path/to/grpc-client.key", + "CaFile": "/path/to/grpc-ca.crt", + "InsecureSkipCaVerification": false + }, + "Subgraphs": { + "my-grpc-subgraph": { + "CertFile": "/path/to/my-grpc-subgraph-client.crt", + "KeyFile": "/path/to/my-grpc-subgraph-client.key", + "CaFile": "/path/to/my-grpc-subgraph-ca.crt", + "InsecureSkipCaVerification": false + } + } } }, "CacheControl": { diff --git a/router/pkg/grpcconnector/grpcremote/grpc_remote.go b/router/pkg/grpcconnector/grpcremote/grpc_remote.go index c6d64bf12a..05a5dfa589 100644 --- a/router/pkg/grpcconnector/grpcremote/grpc_remote.go +++ b/router/pkg/grpcconnector/grpcremote/grpc_remote.go @@ -2,6 +2,7 @@ package grpcremote import ( "context" + "crypto/tls" "fmt" "io" "sync" @@ -9,6 +10,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/grpcconnector" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) @@ -18,14 +20,17 @@ type RemoteGRPCProviderConfig struct { Logger *zap.Logger // Endpoint is the URL of the gRPC server to connect to. Endpoint string + // TLSConfig is the TLS configuration for the gRPC connection. When nil, an insecure connection is used. + TLSConfig *tls.Config } // RemoteGRPCProvider is a client provider that manages a gRPC client connection to a standalone gRPC server. // It is used to connect to a standalone gRPC server that is not part of the cosmo cluster. // The provider maintains a single client connection and provides thread-safe access to it. type RemoteGRPCProvider struct { - logger *zap.Logger - endpoint string + logger *zap.Logger + endpoint string + tlsConfig *tls.Config cc grpc.ClientConnInterface mu sync.RWMutex @@ -46,8 +51,9 @@ func NewRemoteGRPCProvider(config RemoteGRPCProviderConfig) (*RemoteGRPCProvider } return &RemoteGRPCProvider{ - logger: config.Logger, - endpoint: config.Endpoint, + logger: config.Logger, + endpoint: config.Endpoint, + tlsConfig: config.TLSConfig, }, nil } @@ -61,10 +67,18 @@ func (g *RemoteGRPCProvider) GetClient() grpc.ClientConnInterface { } // Start initializes the gRPC client connection if it hasn't been created yet. -// It parses the endpoint URL and creates a new insecure gRPC connection. +// It creates a new gRPC connection using TLS when a TLSConfig is provided, +// otherwise uses an insecure connection. func (g *RemoteGRPCProvider) Start(ctx context.Context) error { if g.cc == nil { - clientConn, err := grpc.NewClient(g.endpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) + var transportCreds grpc.DialOption + if g.tlsConfig != nil { + transportCreds = grpc.WithTransportCredentials(credentials.NewTLS(g.tlsConfig)) + } else { + transportCreds = grpc.WithTransportCredentials(insecure.NewCredentials()) + } + + clientConn, err := grpc.NewClient(g.endpoint, transportCreds) if err != nil { return fmt.Errorf("failed to create client connection: %w", err) }