Skip to content

Commit b2299e2

Browse files
authored
Merge branch 'main' into poc/istio-step4
2 parents 40a118a + 233b5d8 commit b2299e2

20 files changed

+461
-481
lines changed

grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/GrpcMessageMarshaller.java

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -112,31 +112,19 @@ public ByteBuf serializeRequest(I message) throws IOException {
112112
}
113113

114114
public I deserializeRequest(DeframedMessage message, boolean grpcWebText) throws IOException {
115-
InputStream messageStream = message.stream();
115+
final InputStream messageStream = message.stream();
116116
final ByteBuf buf = message.buf();
117117
if (buf != null) {
118-
try {
119-
switch (requestType) {
120-
case PROTOBUF:
121-
final PrototypeMarshaller<I> marshaller = (PrototypeMarshaller<I>) requestMarshaller;
122-
// PrototypeMarshaller<I>.getMessagePrototype will always parse to I
123-
@SuppressWarnings("unchecked")
124-
final I msg = (I) deserializeProto(marshaller, buf);
125-
return msg;
126-
default:
127-
// Fallback to using the method's stream marshaller.
128-
messageStream = new ByteBufInputStream(buf.retain(), true);
129-
break;
130-
}
131-
} finally {
132-
if (!unsafeWrapDeserializedBuffer || grpcWebText) {
133-
buf.release();
134-
}
135-
}
118+
return deserializeRequest(buf, grpcWebText);
136119
}
137120

138121
assert messageStream != null;
139-
try (InputStream msg = messageStream) {
122+
return deserializeRequest(messageStream);
123+
}
124+
125+
public I deserializeRequest(InputStream message) throws IOException {
126+
requireNonNull(message, "message");
127+
try (InputStream msg = message) {
140128
if (isProto) {
141129
return method.parseRequest(msg);
142130
} else {
@@ -146,6 +134,29 @@ public I deserializeRequest(DeframedMessage message, boolean grpcWebText) throws
146134
}
147135
}
148136

137+
public I deserializeRequest(ByteBuf buf, boolean grpcWebText) throws IOException {
138+
requireNonNull(buf, "buf");
139+
try {
140+
switch (requestType) {
141+
case PROTOBUF:
142+
final PrototypeMarshaller<I> marshaller = (PrototypeMarshaller<I>) requestMarshaller;
143+
// PrototypeMarshaller<I>.getMessagePrototype will always parse to I
144+
@SuppressWarnings("unchecked")
145+
final I msg = (I) deserializeProto(marshaller, buf);
146+
return msg;
147+
default:
148+
// Fallback to using the method's stream marshaller.
149+
try (InputStream messageStream = new ByteBufInputStream(buf.retain(), true)) {
150+
return deserializeRequest(messageStream);
151+
}
152+
}
153+
} finally {
154+
if (!unsafeWrapDeserializedBuffer || grpcWebText) {
155+
buf.release();
156+
}
157+
}
158+
}
159+
149160
public ByteBuf serializeResponse(O message) throws IOException {
150161
switch (responseType) {
151162
case PROTOBUF:
@@ -170,32 +181,19 @@ public ByteBuf serializeResponse(O message) throws IOException {
170181
}
171182

172183
public O deserializeResponse(DeframedMessage message, boolean grpcWebText) throws IOException {
173-
InputStream messageStream = message.stream();
184+
final InputStream messageStream = message.stream();
174185
final ByteBuf buf = message.buf();
175186
if (buf != null) {
176-
try {
177-
switch (responseType) {
178-
case PROTOBUF:
179-
final PrototypeMarshaller<O> marshaller =
180-
(PrototypeMarshaller<O>) method.getResponseMarshaller();
181-
// PrototypeMarshaller<I>.getMessagePrototype will always parse to I
182-
@SuppressWarnings("unchecked")
183-
final O msg = (O) deserializeProto(marshaller, buf);
184-
return msg;
185-
default:
186-
// Fallback to using the method's stream marshaller.
187-
messageStream = new ByteBufInputStream(buf.retain(), true);
188-
break;
189-
}
190-
} finally {
191-
if (!unsafeWrapDeserializedBuffer || grpcWebText) {
192-
buf.release();
193-
}
194-
}
187+
return deserializeResponse(buf, grpcWebText);
195188
}
196189

197190
assert messageStream != null;
198-
try (InputStream msg = messageStream) {
191+
return deserializeResponse(messageStream);
192+
}
193+
194+
public O deserializeResponse(InputStream message) throws IOException {
195+
requireNonNull(message, "message");
196+
try (InputStream msg = message) {
199197
if (isProto) {
200198
return method.parseResponse(msg);
201199
} else {
@@ -205,6 +203,30 @@ public O deserializeResponse(DeframedMessage message, boolean grpcWebText) throw
205203
}
206204
}
207205

206+
public O deserializeResponse(ByteBuf buf, boolean grpcWebText) throws IOException {
207+
requireNonNull(buf, "buf");
208+
try {
209+
switch (responseType) {
210+
case PROTOBUF:
211+
final PrototypeMarshaller<O> marshaller =
212+
(PrototypeMarshaller<O>) method.getResponseMarshaller();
213+
// PrototypeMarshaller<I>.getMessagePrototype will always parse to I
214+
@SuppressWarnings("unchecked")
215+
final O msg = (O) deserializeProto(marshaller, buf);
216+
return msg;
217+
default:
218+
// Fallback to using the method's stream marshaller.
219+
try (InputStream messageStream = new ByteBufInputStream(buf.retain(), true)) {
220+
return deserializeResponse(messageStream);
221+
}
222+
}
223+
} finally {
224+
if (!unsafeWrapDeserializedBuffer || grpcWebText) {
225+
buf.release();
226+
}
227+
}
228+
}
229+
208230
private <T> ByteBuf serializeProto(PrototypeMarshaller<T> marshaller, Message message) throws IOException {
209231
if (isProto) {
210232
final int serializedSize = message.getSerializedSize();

grpc/src/main/java/com/linecorp/armeria/server/grpc/DelegatingHttpJsonTranscodingServiceBuilder.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public final class DelegatingHttpJsonTranscodingServiceBuilder {
5151

5252
private HttpJsonTranscodingOptions options = HttpJsonTranscodingOptions.of();
5353
private HttpService fallback = DEFAULT_FALLBACK;
54+
private boolean protoSerialization = true;
5455

5556
/**
5657
* Creates a new builder for the specified delegate.
@@ -77,6 +78,15 @@ public DelegatingHttpJsonTranscodingServiceBuilder options(HttpJsonTranscodingOp
7778
return this;
7879
}
7980

81+
/**
82+
* Sets whether to use Protocol Buffers serialization for transcoded gRPC requests sent to the delegate.
83+
* If {@code false}, JSON serialization is used instead. Defaults to {@code true}.
84+
*/
85+
public DelegatingHttpJsonTranscodingServiceBuilder protoSerialization(boolean protoSerialization) {
86+
this.protoSerialization = protoSerialization;
87+
return this;
88+
}
89+
8090
/**
8191
* Adds the {@link ServiceDescriptor}s that define HTTP/JSON mappings.
8292
*/
@@ -110,6 +120,7 @@ public DelegatingHttpJsonTranscodingService build() {
110120
final HttpJsonTranscoder transcoder =
111121
new HttpJsonTranscoderBuilder()
112122
.options(options)
123+
.protoSerialization(protoSerialization)
113124
.serviceDescriptors(serviceDescriptors)
114125
.build();
115126
if (transcoder == null) {

grpc/src/main/java/com/linecorp/armeria/server/grpc/GrpcServiceBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,7 @@ public GrpcService build() {
10991099
final HttpJsonTranscoder transcoder =
11001100
new HttpJsonTranscoderBuilder()
11011101
.options(httpJsonTranscodingOptions)
1102+
.protoSerialization(false)
11021103
.serviceDefinitions(grpcService.services())
11031104
.build();
11041105
if (transcoder != null) {

grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscoder.java

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static com.google.common.collect.ImmutableSet.toImmutableSet;
2323
import static java.util.Objects.requireNonNull;
2424

25+
import java.io.ByteArrayInputStream;
2526
import java.io.IOException;
2627
import java.util.AbstractMap.SimpleImmutableEntry;
2728
import java.util.Base64;
@@ -69,12 +70,14 @@
6970
import com.linecorp.armeria.common.RequestHeaders;
7071
import com.linecorp.armeria.common.RequestHeadersBuilder;
7172
import com.linecorp.armeria.common.ResponseHeaders;
73+
import com.linecorp.armeria.common.SerializationFormat;
7274
import com.linecorp.armeria.common.annotation.Nullable;
7375
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
7476
import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames;
7577
import com.linecorp.armeria.common.logging.RequestLogProperty;
7678
import com.linecorp.armeria.common.util.SafeCloseable;
7779
import com.linecorp.armeria.internal.common.JacksonUtil;
80+
import com.linecorp.armeria.internal.common.grpc.GrpcMessageMarshaller;
7881
import com.linecorp.armeria.internal.server.grpc.HttpEndpointSpecification;
7982
import com.linecorp.armeria.internal.server.grpc.HttpEndpointSpecification.Parameter;
8083
import com.linecorp.armeria.internal.server.grpc.HttpEndpointSupport;
@@ -90,6 +93,7 @@
9093
import com.linecorp.armeria.unsafe.PooledObjects;
9194

9295
import io.grpc.Status;
96+
import io.netty.buffer.ByteBuf;
9397
import io.netty.util.AttributeKey;
9498

9599
final class HttpJsonTranscoder implements HttpEndpointSupport {
@@ -146,12 +150,15 @@ private static String grpcPath(MethodDescriptor methodDescriptor) {
146150
AttributeKey.valueOf(FramedGrpcService.class, "HTTP_JSON_GRPC_METHOD_INFO");
147151

148152
private final HttpJsonTranscodingOptions options;
153+
private final boolean protoSerialization;
149154
private final Map<Route, TranscodingSpec> routeAndSpecs;
150155
private final Set<Route> routes;
151156

152157
HttpJsonTranscoder(Map<Route, TranscodingSpec> routeAndSpecs,
153-
HttpJsonTranscodingOptions httpJsonTranscodingOptions) {
158+
HttpJsonTranscodingOptions httpJsonTranscodingOptions,
159+
boolean protoSerialization) {
154160
options = requireNonNull(httpJsonTranscodingOptions, "httpJsonTranscodingOptions");
161+
this.protoSerialization = protoSerialization;
155162
this.routeAndSpecs = routeAndSpecs;
156163

157164
final LinkedHashSet<Route> linkedHashSet = new LinkedHashSet<>(routeAndSpecs.size());
@@ -214,8 +221,10 @@ HttpResponse serve(ServiceRequestContext ctx, HttpRequest req,
214221
"gRPC encoding is not supported for non-framed requests.");
215222
}
216223

224+
final boolean useProto = protoSerialization;
217225
grpcHeaders.method(HttpMethod.POST)
218-
.contentType(GrpcSerializationFormats.JSON.mediaType());
226+
.contentType(useProto ? GrpcSerializationFormats.PROTO.mediaType()
227+
: GrpcSerializationFormats.JSON.mediaType());
219228
grpcHeaders.path(grpcPath(spec.methodDescriptor));
220229
// All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just
221230
// clear the header if it's present.
@@ -244,10 +253,20 @@ HttpResponse serve(ServiceRequestContext ctx, HttpRequest req,
244253
requestContent = convertToJson(ctx, clientRequest, spec);
245254
}
246255

256+
final JsonProtoMarshaller jsonProtoMarshaller =
257+
useProto ? new JsonProtoMarshaller(ctx, method) : null;
258+
final HttpData transcodedRequestContent;
259+
if (useProto) {
260+
transcodedRequestContent = convertToProto(requestContent, jsonProtoMarshaller);
261+
} else {
262+
transcodedRequestContent = requestContent;
263+
}
247264
final ResponseHandler responseHandler =
248-
new HttpJsonResponseHandler(ctx, responseFuture, spec);
265+
new HttpJsonResponseHandler(ctx, responseFuture, spec,
266+
useProto, jsonProtoMarshaller);
249267
UnframedGrpcSupport.frameAndServe(
250-
delegate, ctx, grpcHeaders.build(), requestContent, responseHandler);
268+
delegate, ctx, grpcHeaders.build(), transcodedRequestContent,
269+
responseHandler);
251270
} catch (IllegalArgumentException iae) {
252271
responseFuture.completeExceptionally(
253272
HttpStatusException.of(HttpStatus.BAD_REQUEST, iae));
@@ -261,7 +280,17 @@ HttpResponse serve(ServiceRequestContext ctx, HttpRequest req,
261280
return HttpResponse.of(responseFuture);
262281
}
263282

264-
private AggregatedHttpResponse convertResponse(TranscodingSpec spec, AggregatedHttpResponse httpResponse) {
283+
private static AggregatedHttpResponse convertResponse(
284+
TranscodingSpec spec, boolean useProto, AggregatedHttpResponse httpResponse,
285+
@Nullable JsonProtoMarshaller jsonProtoMarshaller) throws Exception {
286+
final AggregatedHttpResponse jsonResponse =
287+
useProto ? convertProtoResponseToJson(httpResponse, jsonProtoMarshaller)
288+
: httpResponse;
289+
return convertJsonResponse(spec, jsonResponse);
290+
}
291+
292+
private static AggregatedHttpResponse convertJsonResponse(
293+
TranscodingSpec spec, AggregatedHttpResponse httpResponse) {
265294
// Ignore the spec if the method is HttpBody. The response body is already in the correct format.
266295
if (HttpBody.getDescriptor().equals(spec.methodDescriptor.getOutputType())) {
267296
final HttpData data = httpResponse.content();
@@ -304,6 +333,60 @@ private AggregatedHttpResponse convertResponse(TranscodingSpec spec, AggregatedH
304333
}
305334
}
306335

336+
private static HttpData convertToProto(HttpData requestContent,
337+
@Nullable JsonProtoMarshaller jsonProtoMarshaller)
338+
throws IOException {
339+
final byte[] jsonBuf = requestContent.array();
340+
assert jsonProtoMarshaller != null;
341+
final ByteBuf protoBuf = jsonProtoMarshaller.jsonToProto(jsonBuf);
342+
return HttpData.wrap(protoBuf);
343+
}
344+
345+
private static AggregatedHttpResponse convertProtoResponseToJson(
346+
AggregatedHttpResponse httpResponse,
347+
@Nullable JsonProtoMarshaller jsonProtoMarshaller) throws Exception {
348+
final HttpData data = httpResponse.content();
349+
final ByteBuf protoBuf = data.byteBuf();
350+
assert jsonProtoMarshaller != null;
351+
final ByteBuf jsonBuf = jsonProtoMarshaller.protoToJson(protoBuf, true);
352+
return AggregatedHttpResponse.of(httpResponse.headers(), HttpData.wrap(jsonBuf));
353+
}
354+
355+
@SuppressWarnings("unchecked")
356+
private static GrpcMessageMarshaller<Object, Object> newMessageMarshaller(
357+
ServiceRequestContext ctx, SerializationFormat serializationFormat,
358+
HttpJsonGrpcMethod method) {
359+
return new GrpcMessageMarshaller<>(
360+
ctx.alloc(),
361+
serializationFormat,
362+
(io.grpc.MethodDescriptor<Object, Object>) method.grpcMethodDescriptor,
363+
method.jsonMarshaller,
364+
false,
365+
true);
366+
}
367+
368+
private static final class JsonProtoMarshaller {
369+
private final GrpcMessageMarshaller<Object, Object> json;
370+
private final GrpcMessageMarshaller<Object, Object> proto;
371+
372+
JsonProtoMarshaller(ServiceRequestContext ctx, HttpJsonGrpcMethod method) {
373+
json = newMessageMarshaller(ctx, GrpcSerializationFormats.JSON, method);
374+
proto = newMessageMarshaller(ctx, GrpcSerializationFormats.PROTO, method);
375+
}
376+
377+
ByteBuf jsonToProto(byte[] jsonBytes) throws IOException {
378+
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(jsonBytes)) {
379+
final Object message = json.deserializeRequest(inputStream);
380+
return proto.serializeRequest(message);
381+
}
382+
}
383+
384+
ByteBuf protoToJson(ByteBuf protoBuf, boolean grpcWebText) throws IOException {
385+
final Object message = proto.deserializeResponse(protoBuf, grpcWebText);
386+
return json.serializeResponse(message);
387+
}
388+
}
389+
307390
private static HttpData convertToHttpBody(AggregatedHttpRequest request) throws IOException {
308391
final ObjectNode body = mapper.createObjectNode();
309392

@@ -464,7 +547,7 @@ private static void setParametersToNode(ObjectNode root,
464547
continue;
465548
}
466549

467-
if (field.javaType == JavaType.MESSAGE) {
550+
if (field.type() == JavaType.MESSAGE) {
468551
throw new IllegalArgumentException(
469552
"Unsupported message type: " + field.descriptor.getFullName());
470553
}
@@ -759,13 +842,19 @@ private class HttpJsonResponseHandler implements ResponseHandler {
759842
private final ServiceRequestContext ctx;
760843
private final CompletableFuture<HttpResponse> responseFuture;
761844
private final TranscodingSpec spec;
845+
private final boolean useProto;
846+
@Nullable
847+
private final JsonProtoMarshaller jsonProtoMarshaller;
762848

763849
HttpJsonResponseHandler(ServiceRequestContext ctx,
764850
CompletableFuture<HttpResponse> responseFuture,
765-
TranscodingSpec spec) {
851+
TranscodingSpec spec, boolean useProto,
852+
@Nullable JsonProtoMarshaller jsonProtoMarshaller) {
766853
this.ctx = ctx;
767854
this.responseFuture = responseFuture;
768855
this.spec = spec;
856+
this.useProto = useProto;
857+
this.jsonProtoMarshaller = jsonProtoMarshaller;
769858
}
770859

771860
@Override
@@ -786,7 +875,8 @@ public void handle(@Nullable AggregatedHttpResponse aggregatedResponse,
786875
return;
787876
}
788877
assert aggregatedResponse != null;
789-
final AggregatedHttpResponse convertedResponse = convertResponse(spec, aggregatedResponse);
878+
final AggregatedHttpResponse convertedResponse =
879+
convertResponse(spec, useProto, aggregatedResponse, jsonProtoMarshaller);
790880
responseFuture.complete(convertedResponse.toHttpResponse());
791881
} catch (Exception e) {
792882
responseFuture.completeExceptionally(e);

0 commit comments

Comments
 (0)