diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index e2de738182..29f71654fe 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -360,7 +360,7 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { return resolve.FetchConfiguration{} } - dataSource, err = grpcdatasource.NewDataSource(p.grpcClient, grpcdatasource.DataSourceConfig{ + dataSource, err = grpcdatasource.NewDataSource(grpcdatasource.NewGRPCTransport(p.grpcClient), grpcdatasource.DataSourceConfig{ Operation: &opDocument, Definition: p.config.schemaConfiguration.upstreamSchemaAst, Mapping: p.config.grpc.Mapping, diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 527b3e2448..276e612092 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -16,7 +16,6 @@ import ( "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" "golang.org/x/sync/errgroup" - "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/wundergraph/astjson" @@ -44,7 +43,7 @@ var _ resolve.DataSource = (*DataSource)(nil) // transforms the responses back to GraphQL format. type DataSource struct { plan *RPCExecutionPlan - cc grpc.ClientConnInterface + transport RPCTransport rc *RPCCompiler mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations @@ -68,8 +67,8 @@ type DataSourceConfig struct { Disabled bool } -// NewDataSource creates a new gRPC datasource -func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*DataSource, error) { +// NewDataSource creates a new datasource with the given RPCTransport. +func NewDataSource(transport RPCTransport, config DataSourceConfig) (*DataSource, error) { planner, err := NewPlanner(config.SubgraphName, config.Mapping, config.FederationConfigs) if err != nil { return nil, err @@ -81,7 +80,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D return &DataSource{ plan: plan, - cc: client, + transport: transport, rc: config.Compiler, mapping: config.Mapping, definition: config.Definition, @@ -152,7 +151,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte builder := newJSONBuilder(item.Arena, d.mapping, variables) errGrp.Go(func() error { // Invoke the gRPC method - this will populate serviceCall.Output - err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) + err := d.transport.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) if err != nil { return err } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go index 786fe338cf..ffd077d166 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go @@ -317,7 +317,7 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -617,7 +617,7 @@ func Test_DataSource_Load_WithEntity_Calls_WithCompositeTypes(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1406,7 +1406,7 @@ func Test_DataSource_Load_WithEntity_Calls_And_Requires(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1821,7 +1821,7 @@ func Test_DataSource_Load_WithEntity_Calls_And_Requires_And_FieldResolvers(t *te } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1965,7 +1965,7 @@ func Test_DataSource_Load_WithEntity_Calls_And_Requires_AbstractTypes(t *testing } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go index 2ad9172916..1ce2d9f744 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go @@ -150,7 +150,7 @@ func Test_DataSource_Load_NullMetrics_NestedResolversNotInvoked(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -191,7 +191,7 @@ func Test_DataSource_Load_NullCategory_FieldResolversNotInvoked(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -226,7 +226,7 @@ func Test_DataSource_Load_ArgumentLessFieldResolversCalled(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -267,7 +267,7 @@ func Test_DataSource_Load_NullCategory_ArgumentLessFieldResolversNotInvoked(t *t compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 112ce26d62..df252df130 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -44,7 +44,7 @@ func Benchmark_DataSource_Load(b *testing.B) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(b), testMapping()) require.NoError(b, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -81,7 +81,7 @@ func Benchmark_DataSource_Load_WithFieldArguments(b *testing.B) { const subgraphName = "Products" - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: subgraphName, @@ -192,7 +192,7 @@ func Test_DataSource_Load(t *testing.T) { } mi := mockInterface{} - ds, err := NewDataSource(mi, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(mi), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -247,7 +247,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { } // 2. Create a datasource with the real gRPC client connection - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -337,7 +337,7 @@ func Test_DataSource_Load_WithRecursiveInputType(t *testing.T) { t.Fatalf("failed to compile proto: %v", err) } - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -388,7 +388,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { } // 2. Create a datasource with the real gRPC client connection - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -489,7 +489,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { } // 3. Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -837,7 +837,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1107,7 +1107,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1243,7 +1243,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1323,7 +1323,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1413,7 +1413,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1882,7 +1882,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -2260,7 +2260,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -3561,7 +3561,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -4795,7 +4795,7 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -5004,7 +5004,7 @@ func Test_Datasource_Load_WithHeaders(t *testing.T) { require.NoError(t, err) // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -5055,7 +5055,7 @@ func Test_Datasource_Load_PreservesExistingContextMetadata(t *testing.T) { require.NoError(t, err) // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", diff --git a/v2/pkg/engine/datasource/grpc_datasource/transport.go b/v2/pkg/engine/datasource/grpc_datasource/transport.go new file mode 100644 index 0000000000..6bdd0208d1 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/transport.go @@ -0,0 +1,36 @@ +package grpcdatasource + +import ( + "context" + "errors" + + "google.golang.org/grpc" + protoref "google.golang.org/protobuf/reflect/protoreflect" +) + +// RPCTransport abstracts the transport protocol for RPC calls. +// Both gRPC and Connect protocol implement this interface. +type RPCTransport interface { + Invoke(ctx context.Context, methodFullName string, input, output protoref.Message) error +} + +// grpcTransport wraps grpc.ClientConnInterface to implement RPCTransport. +type grpcTransport struct { + cc grpc.ClientConnInterface +} + +// NewGRPCTransport creates an RPCTransport that delegates to a gRPC ClientConnInterface. +func NewGRPCTransport(cc grpc.ClientConnInterface) RPCTransport { + return &grpcTransport{cc: cc} +} + +func (t *grpcTransport) Invoke(ctx context.Context, method string, input, output protoref.Message) error { + if t.cc == nil { + return errors.New("grpc transport: nil client connection") + } + // grpc.ClientConnInterface.Invoke accepts (ctx, method, args any, reply any, opts ...grpc.CallOption). + // protoref.Message satisfies the any constraint; variadic opts can be omitted. + // This wrapper intentionally does not forward grpc.CallOption, as RPCTransport + // is protocol-agnostic. The existing grpc_datasource code does not use any CallOption at the Invoke site. + return t.cc.Invoke(ctx, method, input, output) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/transport_test.go b/v2/pkg/engine/datasource/grpc_datasource/transport_test.go new file mode 100644 index 0000000000..923d067942 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/transport_test.go @@ -0,0 +1,54 @@ +package grpcdatasource + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + protoref "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" +) + +// newTestCompiler builds an RPCCompiler bound to the grpctest fixture. +// It is shared by every transport-level test in this package. +func newTestCompiler(t *testing.T) *RPCCompiler { + t.Helper() + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + require.NoError(t, err) + return compiler +} + +// findMessageDesc resolves a fully-qualified message name from the compiled +// proto document. Used by tests to construct dynamicpb.Message instances +// for transport.Invoke calls without depending on the generated Go types. +func findMessageDesc(t *testing.T, compiler *RPCCompiler, fullName string) protoref.MessageDescriptor { + t.Helper() + for _, m := range compiler.doc.Messages { + if string(m.Desc.FullName()) == fullName { + return m.Desc + } + } + t.Fatalf("message %q not found in proto document", fullName) + return nil +} + +// TestGRPCTransport_Invoke is a smoke test for the gRPC RPCTransport +// implementation; it goes through the data source's mockInterface (defined +// in grpc_datasource_test.go) so the assertion is just that Invoke returns +// no error for a well-formed request. +func TestGRPCTransport_Invoke(t *testing.T) { + mi := mockInterface{} + transport := NewGRPCTransport(mi) + + compiler := newTestCompiler(t) + reqDesc := findMessageDesc(t, compiler, "productv1.QueryComplexFilterTypeRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryComplexFilterTypeResponse") + + inputMsg := dynamicpb.NewMessage(reqDesc) + outputMsg := dynamicpb.NewMessage(respDesc) + + err := transport.Invoke(context.Background(), "/productv1.ProductService/QueryComplexFilterType", inputMsg, outputMsg) + require.NoError(t, err) +}