Skip to content

Commit

Permalink
pass HttpHeaders to ServerBaseUrlCustomizer
Browse files Browse the repository at this point in the history
  • Loading branch information
mshima committed Apr 22, 2024
1 parent e14f975 commit 344ac78
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

package org.springdoc.core.customizers;

import org.springframework.http.HttpHeaders;

/**
* The interface Server Base URL customiser.
* @author skylar -stark
Expand All @@ -35,7 +37,8 @@ public interface ServerBaseUrlCustomizer {
* Customise.
*
* @param serverBaseUrl the serverBaseUrl.
* @param httpHeaders request headers.
* @return the customised serverBaseUrl
*/
String customize(String serverBaseUrl);
String customize(String serverBaseUrl, HttpHeaders httpHeaders);
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Controller;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.ControllerAdvice;
Expand Down Expand Up @@ -490,12 +491,12 @@ public Schema resolveProperties(Schema schema, Locale locale) {
*
* @param serverBaseUrl the server base url
*/
public void setServerBaseUrl(String serverBaseUrl) {
public void setServerBaseUrl(String serverBaseUrl, HttpHeaders httpHeaders) {
String customServerBaseUrl = serverBaseUrl;

if (serverBaseUrlCustomizers.isPresent()) {
for (ServerBaseUrlCustomizer customizer : serverBaseUrlCustomizers.get()) {
customServerBaseUrl = customizer.customize(customServerBaseUrl);
customServerBaseUrl = customizer.customize(customServerBaseUrl, httpHeaders);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

import org.springframework.beans.factory.ObjectFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpHeaders;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.bind.annotation.RequestMethod;

Expand Down Expand Up @@ -190,7 +191,7 @@ void preLoadingModeShouldNotOverwriteServers() throws InterruptedException {
doCallRealMethod().when(openAPIService).updateServers(any());
when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod();
doAnswer(new CallsRealMethods()).when(openAPIService).setServersPresent(true);
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any());
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any());
doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any());

String customUrl = "https://custom.com";
Expand All @@ -212,7 +213,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi
Thread.sleep(1_000);

// emulate generating base url
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new HttpHeaders());
openAPIService.updateServers(openAPI);
Locale locale = Locale.US;
OpenAPI after = resource.getOpenApi(locale);
Expand All @@ -224,7 +225,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi
void serverBaseUrlCustomisersTest() throws InterruptedException {
doCallRealMethod().when(openAPIService).updateServers(any());
when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod();
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any());
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any());
doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any());

SpringDocConfigProperties properties = new SpringDocConfigProperties();
Expand All @@ -247,37 +248,37 @@ springDocProviders, new SpringDocCustomizers(Optional.empty(),Optional.empty(),O

// Test that setting generated URL works fine with no customizers present
String generatedUrl = "https://generated-url.com/context-path";
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new HttpHeaders());
openAPIService.updateServers(openAPI);
OpenAPI after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is(generatedUrl));

// Test that adding a serverBaseUrlCustomizer has the desired effect
ServerBaseUrlCustomizer serverBaseUrlCustomizer = serverBaseUrl -> serverBaseUrl.replace("/context-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomizer = (serverBaseUrl, headers) -> serverBaseUrl.replace("/context-path", "");
List<ServerBaseUrlCustomizer> serverBaseUrlCustomizerList = new ArrayList<>();
serverBaseUrlCustomizerList.add(serverBaseUrlCustomizer);

ReflectionTestUtils.setField(openAPIService, "serverBaseUrlCustomizers", Optional.of(serverBaseUrlCustomizerList));
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new HttpHeaders());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com"));

// Test that serverBaseUrlCustomisers are performed in order
generatedUrl = "https://generated-url.com/context-path/second-path";
ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = serverBaseUrl -> serverBaseUrl.replace("/context-path/second-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = (serverBaseUrl, headers) -> serverBaseUrl.replace("/context-path/second-path", "");
serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser2);

openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new HttpHeaders());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com/second-path"));

// Test that all serverBaseUrlCustomisers in the List are performed
ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = serverBaseUrl -> serverBaseUrl.replace("/second-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = (serverBaseUrl, headers) -> serverBaseUrl.replace("/second-path", "");
serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser3);

openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new HttpHeaders());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.boot.actuate.endpoint.web.annotation.RestControllerEndpoint;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.bind.annotation.GetMapping;
Expand Down Expand Up @@ -131,7 +132,7 @@ public Mono<byte[]> openapiYaml(ServerHttpRequest serverHttpRequest, Locale loca
protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) {
super.initOpenAPIBuilder(locale);
URI uri = getActuatorURI(serverHttpRequest.getURI().getScheme(), serverHttpRequest.getURI().getHost());
openAPIService.setServerBaseUrl(uri.toString());
openAPIService.setServerBaseUrl(uri.toString(), HttpHeaders.readOnlyHttpHeaders(serverHttpRequest.getHeaders()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import reactor.core.publisher.Mono;

import org.springframework.beans.factory.ObjectFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.MimeType;
import org.springframework.web.bind.annotation.RequestMethod;
Expand Down Expand Up @@ -229,7 +230,7 @@ protected void getWebFluxRouterFunctionPaths(Locale locale, OpenAPI openAPI) {
protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) {
initOpenAPIBuilder(locale);
String serverUrl = getServerUrl(serverHttpRequest, apiDocsUrl);
openAPIService.setServerBaseUrl(serverUrl);
openAPIService.setServerBaseUrl(serverUrl, HttpHeaders.readOnlyHttpHeaders(serverHttpRequest.getHeaders()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.autoconfigure.web.reactive.WebFluxTest;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.result.method.RequestMappingInfo;
Expand Down Expand Up @@ -67,6 +68,7 @@ public void shouldGenerateOperationIdsDeterministically() throws Exception {
shuffleSpringHandlerMethods();

ServerHttpRequest request = mock(ServerHttpRequest.class);
when(request.getHeaders()).thenReturn(new HttpHeaders());
when(request.getURI()).thenReturn(URI.create("http://localhost"));

String expected = getContent("results/app81.json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@

import org.springframework.aop.support.AopUtils;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.web.bind.annotation.RequestMethod;
Expand Down Expand Up @@ -244,7 +246,8 @@ private Comparator<RequestMappingInfo> byReversedRequestMappingInfos() {
protected void calculateServerUrl(HttpServletRequest request, String apiDocsUrl, Locale locale) {
super.initOpenAPIBuilder(locale);
String calculatedUrl = getServerUrl(request, apiDocsUrl);
openAPIService.setServerBaseUrl(calculatedUrl);
ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request);
openAPIService.setServerBaseUrl(calculatedUrl, HttpHeaders.readOnlyHttpHeaders(serverRequest.getHeaders()));
}

/**
Expand Down

0 comments on commit 344ac78

Please sign in to comment.