|
10 | 10 | import java.util.List;
|
11 | 11 | import java.util.Map;
|
12 | 12 | import java.util.Objects;
|
13 |
| -import java.util.Set; |
14 | 13 | import java.util.UUID;
|
15 | 14 | import java.util.concurrent.ConcurrentHashMap;
|
16 | 15 | import java.util.concurrent.ConcurrentMap;
|
|
40 | 39 | import org.springframework.messaging.handler.annotation.MessageMapping;
|
41 | 40 | import org.springframework.messaging.rsocket.RSocketRequester;
|
42 | 41 | import org.springframework.messaging.rsocket.annotation.ConnectMapping;
|
| 42 | +import org.springframework.util.Base64Utils; |
| 43 | +import org.springframework.util.StringUtils; |
| 44 | +import org.springframework.web.bind.annotation.ExceptionHandler; |
43 | 45 | import org.springframework.web.bind.annotation.GetMapping;
|
44 | 46 | import org.springframework.web.bind.annotation.PathVariable;
|
45 | 47 | import org.springframework.web.bind.annotation.RequestMapping;
|
@@ -96,6 +98,7 @@ public Collection<UUID> requesters() {
|
96 | 98 |
|
97 | 99 | @RequestMapping(path = "**")
|
98 | 100 | public Mono<Void> proxy(ServerHttpRequest request, ServerHttpResponse response) throws Exception {
|
| 101 | + this.checkAuthorization(request); |
99 | 102 | final HttpHeaders httpHeaders = setForwardHeaders(request);
|
100 | 103 | final HttpRequestMetadata httpRequestMetadata = new HttpRequestMetadata(request.getMethod(), request.getURI(), httpHeaders);
|
101 | 104 | final Flux<DataBuffer> responseStream;
|
@@ -230,8 +233,48 @@ else if ("https".equals(scheme) || "wss".equals(scheme)) {
|
230 | 233 | httpHeaders.addAll(source);
|
231 | 234 | final String remoteAddress = request.getRemoteAddress().getAddress().getHostAddress();
|
232 | 235 | final String forwarded = String.format("for=%s;host=%s:%d;proto=%s", remoteAddress, uri.getHost(), port, scheme);
|
233 |
| - httpHeaders.set("Forwarded", forwarded); |
234 |
| - httpHeaders.set("X-Real-IP", remoteAddress); |
235 |
| - return httpHeaders; |
| 236 | + httpHeaders.set("Forwarded", forwarded); httpHeaders.set("X-Real-IP", remoteAddress); return httpHeaders; |
| 237 | + } |
| 238 | + |
| 239 | + |
| 240 | + @ExceptionHandler(AuthorizationException.class) |
| 241 | + public Mono<ResponseEntity<Map<String, ?>>> handleAuthorizationException(AuthorizationException e) { |
| 242 | + return Mono.just(ResponseEntity.status(HttpStatus.UNAUTHORIZED) |
| 243 | + .header(HttpHeaders.WWW_AUTHENTICATE, "Basic realm=\"Tsunagu API\"") |
| 244 | + .body(Map.of("error", Map.of("message", e.getMessage(), "type", "invalid_request_error", "code", e.code)))); |
| 245 | + } |
| 246 | + |
| 247 | + void checkAuthorization(ServerHttpRequest request) { |
| 248 | + if (StringUtils.hasText(this.props.getAuthorizationToken())) { |
| 249 | + final String authorization = request.getHeaders().getFirst(HttpHeaders.AUTHORIZATION); |
| 250 | + if (StringUtils.hasText(authorization)) { |
| 251 | + if (authorization.startsWith("Bearer") || authorization.startsWith("bearer")) { |
| 252 | + final String token = authorization.replace("Bearer ", "").replace("bearer ", ""); |
| 253 | + if (!Objects.equals(this.props.getAuthorizationToken(), token)) { |
| 254 | + throw new AuthorizationException("Incorrect API key provided: " + token, "invalid_api_key"); |
| 255 | + } |
| 256 | + } |
| 257 | + else if (authorization.startsWith("Basic") || authorization.startsWith("basic")) { |
| 258 | + final String basic = authorization.replace("Basic ", "").replace("basic ", ""); |
| 259 | + final String token = new String(Base64Utils.decodeFromString(basic)).split(":", 2)[1]; |
| 260 | + if (!Objects.equals(this.props.getAuthorizationToken(), token)) { |
| 261 | + throw new AuthorizationException("Incorrect API key provided: " + token, "invalid_api_key"); |
| 262 | + } |
| 263 | + } |
| 264 | + } |
| 265 | + else { |
| 266 | + throw new AuthorizationException("You didn't provide an API key. You need to provide your API key in an Authorization header using Bearer auth (i.e. Authorization: Bearer YOUR_KEY), or as the password field (with blank username) if you're accesing the API from your browser and are prompted for a username and password.", ""); |
| 267 | + } |
| 268 | + } |
| 269 | + } |
| 270 | + |
| 271 | + public static class AuthorizationException extends RuntimeException { |
| 272 | + |
| 273 | + |
| 274 | + private final String code; |
| 275 | + |
| 276 | + public AuthorizationException(String message, String code) { |
| 277 | + super(message); this.code = code; |
| 278 | + } |
236 | 279 | }
|
237 | 280 | }
|
0 commit comments