2222import static com .google .common .collect .ImmutableSet .toImmutableSet ;
2323import static java .util .Objects .requireNonNull ;
2424
25+ import java .io .ByteArrayInputStream ;
2526import java .io .IOException ;
2627import java .util .AbstractMap .SimpleImmutableEntry ;
2728import java .util .Base64 ;
6970import com .linecorp .armeria .common .RequestHeaders ;
7071import com .linecorp .armeria .common .RequestHeadersBuilder ;
7172import com .linecorp .armeria .common .ResponseHeaders ;
73+ import com .linecorp .armeria .common .SerializationFormat ;
7274import com .linecorp .armeria .common .annotation .Nullable ;
7375import com .linecorp .armeria .common .grpc .GrpcSerializationFormats ;
7476import com .linecorp .armeria .common .grpc .protocol .GrpcHeaderNames ;
7577import com .linecorp .armeria .common .logging .RequestLogProperty ;
7678import com .linecorp .armeria .common .util .SafeCloseable ;
7779import com .linecorp .armeria .internal .common .JacksonUtil ;
80+ import com .linecorp .armeria .internal .common .grpc .GrpcMessageMarshaller ;
7881import com .linecorp .armeria .internal .server .grpc .HttpEndpointSpecification ;
7982import com .linecorp .armeria .internal .server .grpc .HttpEndpointSpecification .Parameter ;
8083import com .linecorp .armeria .internal .server .grpc .HttpEndpointSupport ;
9093import com .linecorp .armeria .unsafe .PooledObjects ;
9194
9295import io .grpc .Status ;
96+ import io .netty .buffer .ByteBuf ;
9397import io .netty .util .AttributeKey ;
9498
9599final class HttpJsonTranscoder implements HttpEndpointSupport {
@@ -146,12 +150,15 @@ private static String grpcPath(MethodDescriptor methodDescriptor) {
146150 AttributeKey .valueOf (FramedGrpcService .class , "HTTP_JSON_GRPC_METHOD_INFO" );
147151
148152 private final HttpJsonTranscodingOptions options ;
153+ private final boolean protoSerialization ;
149154 private final Map <Route , TranscodingSpec > routeAndSpecs ;
150155 private final Set <Route > routes ;
151156
152157 HttpJsonTranscoder (Map <Route , TranscodingSpec > routeAndSpecs ,
153- HttpJsonTranscodingOptions httpJsonTranscodingOptions ) {
158+ HttpJsonTranscodingOptions httpJsonTranscodingOptions ,
159+ boolean protoSerialization ) {
154160 options = requireNonNull (httpJsonTranscodingOptions , "httpJsonTranscodingOptions" );
161+ this .protoSerialization = protoSerialization ;
155162 this .routeAndSpecs = routeAndSpecs ;
156163
157164 final LinkedHashSet <Route > linkedHashSet = new LinkedHashSet <>(routeAndSpecs .size ());
@@ -214,8 +221,10 @@ HttpResponse serve(ServiceRequestContext ctx, HttpRequest req,
214221 "gRPC encoding is not supported for non-framed requests." );
215222 }
216223
224+ final boolean useProto = protoSerialization ;
217225 grpcHeaders .method (HttpMethod .POST )
218- .contentType (GrpcSerializationFormats .JSON .mediaType ());
226+ .contentType (useProto ? GrpcSerializationFormats .PROTO .mediaType ()
227+ : GrpcSerializationFormats .JSON .mediaType ());
219228 grpcHeaders .path (grpcPath (spec .methodDescriptor ));
220229 // All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just
221230 // clear the header if it's present.
@@ -244,10 +253,20 @@ HttpResponse serve(ServiceRequestContext ctx, HttpRequest req,
244253 requestContent = convertToJson (ctx , clientRequest , spec );
245254 }
246255
256+ final JsonProtoMarshaller jsonProtoMarshaller =
257+ useProto ? new JsonProtoMarshaller (ctx , method ) : null ;
258+ final HttpData transcodedRequestContent ;
259+ if (useProto ) {
260+ transcodedRequestContent = convertToProto (requestContent , jsonProtoMarshaller );
261+ } else {
262+ transcodedRequestContent = requestContent ;
263+ }
247264 final ResponseHandler responseHandler =
248- new HttpJsonResponseHandler (ctx , responseFuture , spec );
265+ new HttpJsonResponseHandler (ctx , responseFuture , spec ,
266+ useProto , jsonProtoMarshaller );
249267 UnframedGrpcSupport .frameAndServe (
250- delegate , ctx , grpcHeaders .build (), requestContent , responseHandler );
268+ delegate , ctx , grpcHeaders .build (), transcodedRequestContent ,
269+ responseHandler );
251270 } catch (IllegalArgumentException iae ) {
252271 responseFuture .completeExceptionally (
253272 HttpStatusException .of (HttpStatus .BAD_REQUEST , iae ));
@@ -261,7 +280,17 @@ HttpResponse serve(ServiceRequestContext ctx, HttpRequest req,
261280 return HttpResponse .of (responseFuture );
262281 }
263282
264- private AggregatedHttpResponse convertResponse (TranscodingSpec spec , AggregatedHttpResponse httpResponse ) {
283+ private static AggregatedHttpResponse convertResponse (
284+ TranscodingSpec spec , boolean useProto , AggregatedHttpResponse httpResponse ,
285+ @ Nullable JsonProtoMarshaller jsonProtoMarshaller ) throws Exception {
286+ final AggregatedHttpResponse jsonResponse =
287+ useProto ? convertProtoResponseToJson (httpResponse , jsonProtoMarshaller )
288+ : httpResponse ;
289+ return convertJsonResponse (spec , jsonResponse );
290+ }
291+
292+ private static AggregatedHttpResponse convertJsonResponse (
293+ TranscodingSpec spec , AggregatedHttpResponse httpResponse ) {
265294 // Ignore the spec if the method is HttpBody. The response body is already in the correct format.
266295 if (HttpBody .getDescriptor ().equals (spec .methodDescriptor .getOutputType ())) {
267296 final HttpData data = httpResponse .content ();
@@ -304,6 +333,60 @@ private AggregatedHttpResponse convertResponse(TranscodingSpec spec, AggregatedH
304333 }
305334 }
306335
336+ private static HttpData convertToProto (HttpData requestContent ,
337+ @ Nullable JsonProtoMarshaller jsonProtoMarshaller )
338+ throws IOException {
339+ final byte [] jsonBuf = requestContent .array ();
340+ assert jsonProtoMarshaller != null ;
341+ final ByteBuf protoBuf = jsonProtoMarshaller .jsonToProto (jsonBuf );
342+ return HttpData .wrap (protoBuf );
343+ }
344+
345+ private static AggregatedHttpResponse convertProtoResponseToJson (
346+ AggregatedHttpResponse httpResponse ,
347+ @ Nullable JsonProtoMarshaller jsonProtoMarshaller ) throws Exception {
348+ final HttpData data = httpResponse .content ();
349+ final ByteBuf protoBuf = data .byteBuf ();
350+ assert jsonProtoMarshaller != null ;
351+ final ByteBuf jsonBuf = jsonProtoMarshaller .protoToJson (protoBuf , true );
352+ return AggregatedHttpResponse .of (httpResponse .headers (), HttpData .wrap (jsonBuf ));
353+ }
354+
355+ @ SuppressWarnings ("unchecked" )
356+ private static GrpcMessageMarshaller <Object , Object > newMessageMarshaller (
357+ ServiceRequestContext ctx , SerializationFormat serializationFormat ,
358+ HttpJsonGrpcMethod method ) {
359+ return new GrpcMessageMarshaller <>(
360+ ctx .alloc (),
361+ serializationFormat ,
362+ (io .grpc .MethodDescriptor <Object , Object >) method .grpcMethodDescriptor ,
363+ method .jsonMarshaller ,
364+ false ,
365+ true );
366+ }
367+
368+ private static final class JsonProtoMarshaller {
369+ private final GrpcMessageMarshaller <Object , Object > json ;
370+ private final GrpcMessageMarshaller <Object , Object > proto ;
371+
372+ JsonProtoMarshaller (ServiceRequestContext ctx , HttpJsonGrpcMethod method ) {
373+ json = newMessageMarshaller (ctx , GrpcSerializationFormats .JSON , method );
374+ proto = newMessageMarshaller (ctx , GrpcSerializationFormats .PROTO , method );
375+ }
376+
377+ ByteBuf jsonToProto (byte [] jsonBytes ) throws IOException {
378+ try (ByteArrayInputStream inputStream = new ByteArrayInputStream (jsonBytes )) {
379+ final Object message = json .deserializeRequest (inputStream );
380+ return proto .serializeRequest (message );
381+ }
382+ }
383+
384+ ByteBuf protoToJson (ByteBuf protoBuf , boolean grpcWebText ) throws IOException {
385+ final Object message = proto .deserializeResponse (protoBuf , grpcWebText );
386+ return json .serializeResponse (message );
387+ }
388+ }
389+
307390 private static HttpData convertToHttpBody (AggregatedHttpRequest request ) throws IOException {
308391 final ObjectNode body = mapper .createObjectNode ();
309392
@@ -464,7 +547,7 @@ private static void setParametersToNode(ObjectNode root,
464547 continue ;
465548 }
466549
467- if (field .javaType == JavaType .MESSAGE ) {
550+ if (field .type () == JavaType .MESSAGE ) {
468551 throw new IllegalArgumentException (
469552 "Unsupported message type: " + field .descriptor .getFullName ());
470553 }
@@ -759,13 +842,19 @@ private class HttpJsonResponseHandler implements ResponseHandler {
759842 private final ServiceRequestContext ctx ;
760843 private final CompletableFuture <HttpResponse > responseFuture ;
761844 private final TranscodingSpec spec ;
845+ private final boolean useProto ;
846+ @ Nullable
847+ private final JsonProtoMarshaller jsonProtoMarshaller ;
762848
763849 HttpJsonResponseHandler (ServiceRequestContext ctx ,
764850 CompletableFuture <HttpResponse > responseFuture ,
765- TranscodingSpec spec ) {
851+ TranscodingSpec spec , boolean useProto ,
852+ @ Nullable JsonProtoMarshaller jsonProtoMarshaller ) {
766853 this .ctx = ctx ;
767854 this .responseFuture = responseFuture ;
768855 this .spec = spec ;
856+ this .useProto = useProto ;
857+ this .jsonProtoMarshaller = jsonProtoMarshaller ;
769858 }
770859
771860 @ Override
@@ -786,7 +875,8 @@ public void handle(@Nullable AggregatedHttpResponse aggregatedResponse,
786875 return ;
787876 }
788877 assert aggregatedResponse != null ;
789- final AggregatedHttpResponse convertedResponse = convertResponse (spec , aggregatedResponse );
878+ final AggregatedHttpResponse convertedResponse =
879+ convertResponse (spec , useProto , aggregatedResponse , jsonProtoMarshaller );
790880 responseFuture .complete (convertedResponse .toHttpResponse ());
791881 } catch (Exception e ) {
792882 responseFuture .completeExceptionally (e );
0 commit comments