Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass HttpRequest to ServerBaseUrlCustomizer #2589

Merged
merged 1 commit into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

package org.springdoc.core.customizers;

import org.springframework.http.HttpRequest;

/**
* The interface Server Base URL customiser.
* @author skylar -stark
Expand All @@ -35,7 +37,8 @@ public interface ServerBaseUrlCustomizer {
* Customise.
*
* @param serverBaseUrl the serverBaseUrl.
* @param request the request.
* @return the customised serverBaseUrl
*/
String customize(String serverBaseUrl);
String customize(String serverBaseUrl, HttpRequest request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.http.HttpRequest;
import org.springframework.stereotype.Controller;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.ControllerAdvice;
Expand Down Expand Up @@ -490,12 +491,12 @@ public Schema resolveProperties(Schema schema, Locale locale) {
*
* @param serverBaseUrl the server base url
*/
public void setServerBaseUrl(String serverBaseUrl) {
public void setServerBaseUrl(String serverBaseUrl, HttpRequest httpRequest) {
String customServerBaseUrl = serverBaseUrl;

if (serverBaseUrlCustomizers.isPresent()) {
for (ServerBaseUrlCustomizer customizer : serverBaseUrlCustomizers.get()) {
customServerBaseUrl = customizer.customize(customServerBaseUrl);
customServerBaseUrl = customizer.customize(customServerBaseUrl, httpRequest);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.springframework.context.ApplicationContext;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.mock.http.client.MockClientHttpRequest;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
Expand Down Expand Up @@ -190,7 +191,7 @@ void preLoadingModeShouldNotOverwriteServers() throws InterruptedException {
doCallRealMethod().when(openAPIService).updateServers(any());
when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod();
doAnswer(new CallsRealMethods()).when(openAPIService).setServersPresent(true);
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any());
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any());
doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any());

String customUrl = "https://custom.com";
Expand All @@ -212,7 +213,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi
Thread.sleep(1_000);

// emulate generating base url
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
Locale locale = Locale.US;
OpenAPI after = resource.getOpenApi(locale);
Expand All @@ -224,7 +225,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi
void serverBaseUrlCustomisersTest() throws InterruptedException {
doCallRealMethod().when(openAPIService).updateServers(any());
when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod();
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any());
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any());
doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any());

SpringDocConfigProperties properties = new SpringDocConfigProperties();
Expand All @@ -247,37 +248,37 @@ springDocProviders, new SpringDocCustomizers(Optional.empty(),Optional.empty(),O

// Test that setting generated URL works fine with no customizers present
String generatedUrl = "https://generated-url.com/context-path";
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
OpenAPI after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is(generatedUrl));

// Test that adding a serverBaseUrlCustomizer has the desired effect
ServerBaseUrlCustomizer serverBaseUrlCustomizer = serverBaseUrl -> serverBaseUrl.replace("/context-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomizer = (serverBaseUrl, request) -> serverBaseUrl.replace("/context-path", "");
List<ServerBaseUrlCustomizer> serverBaseUrlCustomizerList = new ArrayList<>();
serverBaseUrlCustomizerList.add(serverBaseUrlCustomizer);

ReflectionTestUtils.setField(openAPIService, "serverBaseUrlCustomizers", Optional.of(serverBaseUrlCustomizerList));
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com"));

// Test that serverBaseUrlCustomisers are performed in order
generatedUrl = "https://generated-url.com/context-path/second-path";
ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = serverBaseUrl -> serverBaseUrl.replace("/context-path/second-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = (serverBaseUrl, request) -> serverBaseUrl.replace("/context-path/second-path", "");
serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser2);

openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com/second-path"));

// Test that all serverBaseUrlCustomisers in the List are performed
ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = serverBaseUrl -> serverBaseUrl.replace("/second-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = (serverBaseUrl, request) -> serverBaseUrl.replace("/second-path", "");
serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser3);

openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public Mono<byte[]> openapiYaml(ServerHttpRequest serverHttpRequest, Locale loca
protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) {
super.initOpenAPIBuilder(locale);
URI uri = getActuatorURI(serverHttpRequest.getURI().getScheme(), serverHttpRequest.getURI().getHost());
openAPIService.setServerBaseUrl(uri.toString());
openAPIService.setServerBaseUrl(uri.toString(), serverHttpRequest);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ protected void getWebFluxRouterFunctionPaths(Locale locale, OpenAPI openAPI) {
protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) {
initOpenAPIBuilder(locale);
String serverUrl = getServerUrl(serverHttpRequest, apiDocsUrl);
openAPIService.setServerBaseUrl(serverUrl);
openAPIService.setServerBaseUrl(serverUrl, serverHttpRequest);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

import org.springframework.aop.support.AopUtils;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.web.bind.annotation.RequestMethod;
Expand Down Expand Up @@ -244,7 +245,8 @@ private Comparator<RequestMappingInfo> byReversedRequestMappingInfos() {
protected void calculateServerUrl(HttpServletRequest request, String apiDocsUrl, Locale locale) {
super.initOpenAPIBuilder(locale);
String calculatedUrl = getServerUrl(request, apiDocsUrl);
openAPIService.setServerBaseUrl(calculatedUrl);
ServletServerHttpRequest serverRequest = request != null ? new ServletServerHttpRequest(request) : null;
openAPIService.setServerBaseUrl(calculatedUrl, serverRequest);
}

/**
Expand Down