5
5
import java .util .Collections ;
6
6
import java .util .Set ;
7
7
import java .util .concurrent .ConcurrentHashMap ;
8
+ import java .util .function .BiFunction ;
8
9
9
10
import com .fasterxml .jackson .databind .ObjectMapper ;
10
11
import io .netty .buffer .ByteBufUtil ;
12
+ import io .netty .buffer .PooledByteBufAllocator ;
13
+ import io .netty .buffer .Unpooled ;
11
14
import io .rsocket .Payload ;
15
+ import io .rsocket .core .RSocketClient ;
12
16
import io .rsocket .util .DefaultPayload ;
17
+ import org .reactivestreams .Publisher ;
13
18
import org .slf4j .Logger ;
14
19
import org .slf4j .LoggerFactory ;
15
20
import reactor .core .publisher .Flux ;
16
21
import reactor .core .publisher .Mono ;
22
+ import reactor .core .publisher .Signal ;
17
23
18
24
import org .springframework .core .io .buffer .DataBuffer ;
19
- import org .springframework .core .io .buffer .DataBufferFactory ;
20
25
import org .springframework .core .io .buffer .DataBufferUtils ;
21
- import org .springframework .core .io .buffer .DefaultDataBufferFactory ;
26
+ import org .springframework .core .io .buffer .NettyDataBufferFactory ;
22
27
import org .springframework .http .HttpStatus ;
23
28
import org .springframework .http .ResponseEntity ;
24
29
import org .springframework .http .server .reactive .ServerHttpRequest ;
@@ -37,22 +42,36 @@ public class TsunaguController {
37
42
38
43
private final ObjectMapper objectMapper ;
39
44
40
- private final DataBufferFactory dataBufferFactory = DefaultDataBufferFactory . sharedInstance ;
45
+ private final NettyDataBufferFactory dataBufferFactory = new NettyDataBufferFactory ( PooledByteBufAllocator . DEFAULT ) ;
41
46
42
47
public TsunaguController (ObjectMapper objectMapper ) {
43
48
this .objectMapper = objectMapper ;
44
49
}
45
50
46
51
@ RequestMapping (path = "**" )
47
- public Mono <ResponseEntity <?>> proxy (ServerHttpRequest request ) {
48
- final RSocketRequester requester = this .getRequester ();
52
+ public Mono <ResponseEntity <?>> proxy (ServerHttpRequest request ) throws Exception {
53
+ final RSocketClient rsocketClient = this .getRequester (). rsocketClient ();
49
54
final HttpRequestMetadata httpRequestMetadata = new HttpRequestMetadata (request .getMethod (), request .getURI (), request .getHeaders ());
50
- final Mono <Payload > requestPayload = Mono .fromCallable (() -> this .objectMapper .writeValueAsBytes (httpRequestMetadata ))
51
- .map (metadata -> DefaultPayload .create (new byte [] {}, metadata ));
52
- final Flux <Payload > responseStream = requester .rsocketClient ().requestStream (requestPayload );
53
- return responseStream .<ResponseEntity <?>>switchOnFirst ((signal , flux ) -> {
55
+ final byte [] metadata = this .objectMapper .writeValueAsBytes (httpRequestMetadata );
56
+ final Flux <Payload > responseStream ;
57
+ if (httpRequestMetadata .hasBody ()) {
58
+ final Flux <Payload > requestPayload = request .getBody ()
59
+ .map (NettyDataBufferFactory ::toByteBuf )
60
+ .map (data -> DefaultPayload .create (data , Unpooled .copiedBuffer (metadata )))
61
+ .switchIfEmpty (Mono .fromCallable (() -> DefaultPayload .create (new byte [] {}, metadata )));
62
+ responseStream = rsocketClient .requestChannel (requestPayload );
63
+ }
64
+ else {
65
+ final Mono <Payload > requestPayload = Mono .just (DefaultPayload .create (new byte [] {}, metadata ));
66
+ responseStream = rsocketClient .requestStream (requestPayload );
67
+ }
68
+ return responseStream .switchOnFirst (this .handleResponse (httpRequestMetadata )).single ();
69
+ }
70
+
71
+ BiFunction <Signal <? extends Payload >, Flux <Payload >, Publisher <? extends ResponseEntity <?>>> handleResponse (HttpRequestMetadata httpRequestMetadata ) {
72
+ return (signal , flux ) -> {
54
73
final byte [] httpResponseMetadataBytes = ByteBufUtil .getBytes (signal .get ().metadata ());
55
- final Mono <DataBuffer > bodyMono = DataBufferUtils .join (flux .map (payload -> dataBufferFactory .wrap (payload .getData ())));
74
+ final Mono <DataBuffer > bodyMono = DataBufferUtils .join (flux .map (payload -> dataBufferFactory .wrap (payload .data ())));
56
75
try {
57
76
final HttpResponseMetadata httpResponseMetadata = this .objectMapper .readValue (httpResponseMetadataBytes , HttpResponseMetadata .class );
58
77
log .info ("\n request:\t {}\n response:\t {}" , httpRequestMetadata , httpResponseMetadata );
@@ -63,7 +82,7 @@ public Mono<ResponseEntity<?>> proxy(ServerHttpRequest request) {
63
82
catch (IOException e ) {
64
83
throw new UncheckedIOException (e );
65
84
}
66
- }). single () ;
85
+ };
67
86
}
68
87
69
88
private RSocketRequester getRequester () {
0 commit comments