Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions solr/core/src/java/org/apache/solr/security/MultiAuthPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.solr.security;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand All @@ -26,6 +27,7 @@
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpRequest;
import org.apache.http.protocol.HttpContext;
import org.apache.lucene.util.ResourceLoader;
Expand Down Expand Up @@ -59,6 +61,8 @@ public class MultiAuthPlugin extends AuthenticationPlugin
private static final String UNKNOWN_SCHEME = "";

private final Map<String, AuthenticationPlugin> pluginMap = new LinkedHashMap<>();
private final Map<String, String> realms = new LinkedHashMap<>();
private final List<String> WWWAuthenticateHeaders = new ArrayList<>();
private final ResourceLoader loader;
// the first of our plugins that allows anonymous requests
private AuthenticationPlugin allowsUnknown = null;
Expand Down Expand Up @@ -141,6 +145,7 @@ public void init(Map<String, Object> pluginConfig) {
}
initPluginForScheme((Map<String, Object>) s);
}
initWWWAuthenticateHeaders();
}

protected void initPluginForScheme(Map<String, Object> schemeMap) {
Expand All @@ -158,6 +163,11 @@ protected void initPluginForScheme(Map<String, Object> schemeMap) {
ErrorCode.SERVER_ERROR, "'class' is a required attribute: " + schemeMap);
}

String realm = (String) schemeConfig.remove("realm");
if (!StrUtils.isNullOrEmpty(realm)) {
realms.put(scheme, realm);
}

AuthenticationPlugin pluginForScheme = loader.newInstance(clazz, AuthenticationPlugin.class);
pluginForScheme.init(schemeConfig);
pluginMap.put(scheme.toLowerCase(Locale.ROOT), pluginForScheme);
Expand All @@ -171,6 +181,20 @@ protected void initPluginForScheme(Map<String, Object> schemeMap) {
}
}

private void initWWWAuthenticateHeaders() {
for (String scheme : pluginMap.keySet()) {
String realm = realms.get(scheme);
String realmStr = realm == null ? "" : " realm=\"" + realm + "\"";
WWWAuthenticateHeaders.add(scheme + realmStr);
}
}

private void addWWWAuthenticateHeaders(HttpServletResponse response) {
for (String wwwAuthHeader : WWWAuthenticateHeaders) {
response.addHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthHeader);
}
}

@Override
public void initializeMetrics(SolrMetricsContext parentContext, String scope) {
for (AuthenticationPlugin plugin : pluginMap.values()) {
Expand Down Expand Up @@ -202,6 +226,7 @@ public boolean doAuthenticate(
pluginInRequest.set(plugin);
result = plugin.doAuthenticate(request, response, filterChain);
} else {
addWWWAuthenticateHeaders(response);
response.sendError(ErrorCode.UNAUTHORIZED.code, "No Authorization header");
}
return result;
Expand All @@ -210,6 +235,7 @@ public boolean doAuthenticate(
final String scheme = getSchemeFromAuthHeader(authHeader);
final AuthenticationPlugin plugin = pluginMap.get(scheme);
if (plugin == null) {
addWWWAuthenticateHeaders(response);
response.sendError(
ErrorCode.UNAUTHORIZED.code, "Authorization scheme '" + scheme + "' not supported!");
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"class": "solr.MultiAuthPlugin",
"schemes": [{
"scheme": "basic",
"realm": "basicRealm",
"blockUnknown": true,
"class": "solr.BasicAuthPlugin",
"credentials": {
Expand All @@ -11,6 +12,11 @@
"forwardCredentials": false
},{
"scheme": "mock",
"realm": "mockRealm",
"class": "org.apache.solr.security.MultiAuthPluginTest$MockAuthPluginForTesting",
"blockUnknown": true
},{
"scheme": "mockNoRealm",
"class": "org.apache.solr.security.MultiAuthPluginTest$MockAuthPluginForTesting",
"blockUnknown": true
}]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.http.Header;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
Expand Down Expand Up @@ -105,6 +110,10 @@ public void testMultiAuthEditAPI() throws Exception {
new SecurityConfHandler.SecurityConfig()
.setData(Utils.fromJSONString(multiAuthPluginSecurityJson)));
securityConfHandler.securityConfEdited();

// verify "WWW-Authenticate" headers are returned
testWWWAuthenticateHeaders(httpClient, baseUrl);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this test actually be starting with verify? Instead of test? like verifySecurityStatus??

Copy link
Contributor Author

@laminelam laminelam Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think that make sense as testWWWAuthenticateHeaders is not a UnitTest method. Will change it.


verifySecurityStatus(
httpClient,
baseUrl + authcPrefix,
Expand Down Expand Up @@ -269,6 +278,39 @@ private int doHttpGetAnonymous(HttpClient cl, String url) throws IOException {
return statusCode;
}

private void testWWWAuthenticateHeaders(HttpClient httpClient, String baseUrl) throws Exception {
HttpGet httpGet = new HttpGet(baseUrl + "/admin/info/system");
HttpResponse response = httpClient.execute(httpGet);
Header[] headers = response.getHeaders(HttpHeaders.WWW_AUTHENTICATE);
List<String> actualSchemes =
Arrays.stream(headers).map(Header::getValue).collect(Collectors.toList());

List<String> expectedSchemes = generateExpectedSchemes();
actualSchemes.sort(String.CASE_INSENSITIVE_ORDER);
expectedSchemes.sort(String.CASE_INSENSITIVE_ORDER);

assertEquals(
"The actual schemes and realms should match the expected ones exactly",
expectedSchemes.stream().map(s -> s.toLowerCase(Locale.ROOT)).collect(Collectors.toList()),
actualSchemes.stream().map(s -> s.toLowerCase(Locale.ROOT)).collect(Collectors.toList()));
}

@SuppressWarnings("unchecked")
private List<String> generateExpectedSchemes() {
Map<String, Object> data = securityConfHandler.getSecurityConfig(false).getData();
Map<String, Object> authentication = (Map<String, Object>) data.get("authentication");
List<Map<String, Object>> schemes = (List<Map<String, Object>>) authentication.get("schemes");

return schemes.stream()
.map(
schemeMap -> {
String scheme = (String) schemeMap.get("scheme");
String realm = (String) schemeMap.get("realm");
return realm != null ? scheme + " realm=\"" + realm + "\"" : scheme;
})
.collect(Collectors.toList());
}

private static final class MockPrincipal implements Principal, Serializable {
@Override
public String getName() {
Expand Down